

裸的求公共祖先的题目。
思路:用被增法求最近公共祖先,用fa[i][j]表示从i开始,向上走2j 步能走到的所有节点,其中1 <= j <= logn(下取整)

例如上图,f[6][0] = 4, f[6][1] = 2, f[6][2] = -1表示不存在。
做法:1.首先我们需要预处理一个fa数组,采用递推的方式。fa[i][j]表示从i开始向上走2j 步,那么我们可以拆成两部分,先走2j -1 步再走2j - 1 步,也就是fa[fa[i][j - 1][j - 1]。于是我们便得到递推公式fa[i][i] = fa[fa[i][j - 1]][j - 1]。只需将j从小到大枚举一遍就可以了。
ps:本质是用二进制拼凑路径长度,和多重背包的二进制优化思想大致。
2.同时我们还需要预处理一个depth数组,来表示当前点的深度,例如depth[1] = 1.depth[6] = 4。就是到根节点的距离+1。同时我们再设置两个“哨兵“,如果从i跳过了根节点,那么fa[i][j] = 0, depth[0] = 0。
3.预处理完两个数组后,进行操作,先将两个点跳到同一层,我们就统一将a看做较低的节点。b看做较高的节点。
4.两个节点所在层数相同后,如果还不是公共祖先,就一起往上跳到最近公共祖先的下一层。(比较好判断,因为当f[a][k] = f[b][k]时,表示跳到了公共节点,但不一定是最近的,而跳到第一个公共节点前一层,再往上就一定是最近的公共节点,即fa[a][0]就是答案。
预处理的复杂度是O(logn),查询也是O(logn)。
1 #include <iostream>
2 #include <algorithm>
3 #include <cstring>
4 #include <queue>
5
6 using namespace std;
7
8 const int N = 40010, M = 2 * N;
9
10 int e[M], ne[M], h[M], idx;
11 int depth[N], fa[N][16];
12
13 void add(int a, int b)
14 {
15 e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
16 }
17
18 void bfs(int root)
19 {
20 memset(depth, 0x3f, sizeof depth);
21 queue<int> q;
22
23 q.push(root);
24 depth[0] = 0, depth[root] = 1;//初始化
25
26 while(q.size())
27 {
28 int t = q.front();
29 q.pop();
30
31 for(int i = h[t] ; ~i ; i = ne[i])
32 {
33 int j = e[i];
34 if(depth[j] > depth[t] + 1)
35 {
36 depth[j] = depth[t] + 1;
37 fa[j][0] = t;
38 q.push(j);
39 for(int k = 1 ; k <= 15 ; k ++)
40 fa[j][k] = fa[fa[j][k - 1]][k - 1];
41 }
42 }
43 }
44 }
45
46 int lca(int a, int b)
47 {
48 if(depth[a] < depth[b])swap(a, b);//默认a都是深度较深的节点
49 for(int k = 15 ; k >= 0 ; k --)
50 if(depth[fa[a][k]] >= depth[b])//将a跳到和b同一层,因为默认depth都是>=1的,所以depth[] = 0时不会满足条件。“哨兵”的用处。
51 a = fa[a][k];
52 if(a == b)return a;//相同了就返回其中一个,就是最近祖先节点
53 for(int k = 15 ; k >= 0 ; k --)
54 if(fa[a][k] != fa[b][k])
55 {
56 a = fa[a][k];
57 b = fa[b][k];
58 }
59 return fa[a][0];
60 }
61
62
63 int main(){
64 int n, m, root = 0;
65 cin >> n;
66 memset(h, -1, sizeof h);
67
68 while(n --)
69 {
70 int a, b;
71 cin >> a >> b;
72 if(b == -1)root = a;
73 else add(a, b), add(b, a);
74 }
75
76 bfs(root);
77
78 cin >> m;
79 while(m --)
80 {
81 int a, b;
82 cin >> a >> b;
83 int p = lca(a, b);
84 if(p == a)puts("1");
85 else if(p == b)puts("2");
86 else puts("0");
87 }
88 return 0;
89 }

求树上两点间的距离,可以转换成求两点间的公共祖先的方法来解决。
这里用到tarjan算法(对求公共祖先的向上标记法的优化)。
j

将树上的节点分为三类:1.已经遍历过的节点(绿色)。2.正在遍历的节点(红色)。3.还未遍历的节点(紫色)。
我们可以发现,求红色部分和绿色部分的最近公共祖先,就是红色部分上的那些节点,也就是橙色部分圈起来的子树,可以合并到他的祖宗节点去。(集合合并)所以我们可以用并查集做。
除此之外,我们还需预处理一个dis数组,dis[i]表示根节点到i的距离。

之后求x和y之间的距离,只需要dis[x] + dis[y] - 2 * dis[lca],lca是x和y的最近公共祖先。
并查集合并和查询的操作都是O(1)所以算法复杂度是线性的O(n + m)。
代码:
1 #include <iostream>
2 #include <algorithm>
3 #include <cstring>
4 #include <vector>
5
6 using namespace std;
7
8 typedef pair<int, int> PII;
9
10 const int N = 10010, M = 2 * N;
11
12 int e[M], ne[M], w[M], h[M], idx;
13 int ans[M], dis[N], st[N], f[N];
14 vector<PII> query[N];
15 int n, m;
16
17 void add(int a, int b, int c)
18 {
19 e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
20 }
21
22 int find(int x)
23 {
24 return f[x] == x ? x : f[x] = find(f[x]);
25 }
26
27
28 void dfs(int u, int fa)
29 {
30 for(int i = h[u] ; ~i ; i = ne[i])
31 {
32 int j = e[i];
33 if(j == fa)continue;
34 dis[j] = dis[u] + w[i];
35 dfs(j, u);
36 }
37 }
38
39 void tarjan(int u)
40 {
41 st[u] = true;
42 for(int i = h[u] ; ~i ; i = ne[i])
43 {
44 int j = e[i];
45 if(!st[j])
46 {
47 tarjan(j);
48 f[j] = u;//将子树合并到祖先节点
49 }
50 }
51
52 for(auto item : query[u])
53 {
54 int y = item.first, id = item.second;
55 if(st[y] == 2)//标记为2说明是完成搜索的
56 {
57 int lca = find(y);
58 ans[id] = dis[u] + dis[y] - 2 * dis[lca];
59 }
60 }
61
62 st[u] = 2;//完成搜索后标记为2
63 }
64
65 int main(){
66 cin >> n >> m;
67
68 memset(h, -1, sizeof h);
69
70 for(int i = 0 ; i < n - 1; i ++)
71 {
72 int a, b, c;
73 cin >> a >> b >> c;
74 add(a, b, c), add(b, a, c);
75 }
76
77 for(int i = 0 ; i < m ; i ++)
78 {
79 int a, b;
80 cin >> a >> b;
81 query[a].push_back({b, i});//存入和这个点相关的点和查询的下标
82 query[b].push_back({a, i});
83 }
84
85 for(int i = 1 ; i <= n ; i ++)f[i] = i;
86
87 dfs(1, -1);//预处理dis数组
88 tarjan(1);
89
90 for(int i = 0 ; i < m ; i ++)cout << ans[i] << endl;
91
92 return 0;
93 }