题意: 给出一颗树,求树上的两点间最近距离不超过k的点对数量。
思路: 对于一条树路径 只有经过或不经过一个点的情况,对于不经过的情况,把一棵树按这个点拆成好几棵分治就行了,考虑经过这个点的情况,对于这题 可以对这个点延伸出的几棵子树各做一次dfs,记录子树中出现的距离值,对于一棵树的距离值数组,把它排序求一次ans1,再对每棵子树分别求一个自己对自己的ans2, ( a n s 1 − ∑ a n s 2 ) (ans1-\sum ans2) (ans1−∑ans2)即为最后的ans。
#include<bits/stdc++.h> #define ll long long using namespace std; const int N = 2e4 + 10; int n, k, h[N], cnt, rt, sz[N], mx; int dep[N], bel[N], vis[N], d[N], tot, ans; struct node { int v, w, nt; } no[N]; void add(int u, int v, int w) { no[cnt] = node{v, w, h[u]}; h[u] = cnt++; } void getroot(int u, int fa) { sz[u] = 1; int ma = 0; for(int i = h[u]; ~i; i = no[i].nt) { int v = no[i].v; if(v != fa && !vis[v]) { getroot(v, u); sz[u] += sz[v]; ma = max(ma, sz[v]); } } ma = max(ma, n - sz[u]); if(ma < mx) { mx = ma; rt = u; } } void getdep(int u, int fa) { d[++tot] = dep[u]; for(int i = h[u]; ~i; i = no[i].nt) { int v = no[i].v; if(v != fa && !vis[v]) dep[v] = dep[u] + no[i].w, getdep(v, u); } } int calc(int u) { tot = 0, getdep(u, 0); sort(d + 1, d + tot + 1); int l = 1, r = tot, sum = 0; while(l < r) { if(d[l] + d[r] <= k) sum += r - l, l++; else r--; } return sum; } void dfs(int u) { dep[u] = 0, vis[u] = 1, ans += calc(u); for(int i = h[u]; ~i; i = no[i].nt) { int v = no[i].v; if(!vis[v]) mx = 1e9, dep[v] = no[i].w, ans -= calc(v), n = sz[v], rt = 0, getroot(v, 0), dfs(rt); } } int main() { while(~scanf("%d%d", &n, &k) && (n + k)) { memset(h, -1, sizeof h); cnt = 0, rt = 0, ans = 0; memset(vis, 0, sizeof vis); for(int u, v, w, i = 1; i < n; i++) { scanf("%d%d%d", &u, &v, &w); add(u, v, w), add(v, u, w); } mx = 1e9; getroot(1, 0), dfs(rt); printf("%d\n", ans); } return 0; }