如果对于每个询问跑一次$dp$,那么$dp[i]$为断开$i$这棵子树的最小花费。
这样的复杂度为$O(n*m)$,过于臃肿。
所以我们要对于每次询问降低这次询问的复杂度。
我们可以发现$m$个关键点,最多有$m-1$个$lca$。
简单证明一下,如果有两个点,会有$1$个$lca$点,如果有三个点,则第三个点会和上一个$lca$产生一个$lca$。
所以以这$2*m-1$个点构建一棵树,在这个树上跑$dp$
虚树的构建推荐一个巨巨的博客
1 #include <bits/stdc++.h>
2 using namespace std;
3 typedef long long ll;
4 const int maxn = 250010;
5 const ll inf = 2e18 + 10;
6 struct node {
7 int s, e, next;
8 ll w;
9 }edge[maxn * 2];
10 int head[maxn], len;
11 void init() {
12 memset(head, -1, sizeof(head));
13 len = 0;
14 }
15 void add(int s, int e, ll w) {
16 edge[len] = { s,e,head[s],w };
17 head[s] = len++;
18 }
19 int fat[maxn], son[maxn], siz[maxn], top[maxn], tid[maxn], dep[maxn], dfx;
20 ll a[maxn], st[maxn], dp[maxn], Min[maxn], stop;
21 vector<int>mp[maxn];
22 void dfs1(int x, int fa, int d) {
23 siz[x] = 1, fat[x] = fa;
24 dep[x] = d, son[x] = -1;
25 for (int i = head[x]; i != -1; i = edge[i].next) {
26 int y = edge[i].e;
27 if (y == fa)continue;
28 Min[y] = min(Min[x], edge[i].w);
29 dfs1(y, x, d + 1);
30 siz[x] += siz[y];
31 if (son[x] == -1 || siz[son[x]] < siz[y])
32 son[x] = y;
33 }
34 }
35 void dfs2(int x, int c) {
36 top[x] = c;
37 tid[x] = ++dfx;
38 if (son[x] == -1)return;
39 dfs2(son[x], c);
40 for (int i = head[x]; i != -1; i = edge[i].next) {
41 int y = edge[i].e;
42 if (fat[x] == y || y == son[x])continue;
43 dfs2(y, y);
44 }
45 }
46 int LCA(int x, int y) {
47 while (top[x] != top[y]) {
48 if (dep[top[x]] < dep[top[y]])
49 swap(x, y);
50 x = fat[top[x]];
51 }
52 if (dep[x] > dep[y])swap(x, y);
53 return x;
54
55 }
56 bool cmp(int x, int y) {
57 return tid[x] < tid[y];
58 }
59 void dfs3(int x) {
60 if (mp[x].size() == 0) {
61 dp[x] = Min[x];
62 return;
63 }
64 ll sum = 0;
65 for (int i = 0; i < mp[x].size(); i++) {
66 int y = mp[x][i];
67 dfs3(y);
68 sum += dp[y];
69 }
70 mp[x].clear();
71 dp[x] = min(Min[x], sum);
72 }
73 void insert(int x) {
74 if (stop == 1) {
75 st[++stop] = x;
76 return;
77 }
78 int lca = LCA(x, st[stop]);
79 if (lca == st[stop])
80 return;
81 while (stop > 1 && tid[st[stop - 1]] >= tid[lca])
82 mp[st[stop - 1]].push_back(st[stop]), stop--;
83 if (lca != st[stop])
84 mp[lca].push_back(st[stop]), st[stop] = lca;
85 st[++stop] = x;
86 }
87 int main() {
88 init();
89 int n, m, t, x, y, z;
90 scanf("%d", &n);
91 for (int i = 1; i < n; i++) {
92 scanf("%d%d%d", &x, &y, &z);
93 add(x, y, z);
94 add(y, x, z);
95 }
96 Min[1] = inf;
97 dfs1(1, 0, 1);
98 dfs2(1, 1);
99 scanf("%d", &m);
100 while (m--) {
101 scanf("%d", &t);
102 for (int i = 1; i <= t; i++)
103 scanf("%d", &a[i]);
104 sort(a + 1, a + 1 + t, cmp);
105 st[stop = 1] = 1;
106 for (int i = 1; i <= t; i++)
107 insert(a[i]);
108 while (stop > 1)
109 mp[st[stop - 1]].push_back(st[stop]), stop--;
110 dfs3(1);
111 printf("%lld\n", dp[1]);
112 }
113 }