lca最近公共祖先(模板)

痴心易碎 提交于 2019-12-21 11:59:34

洛谷上的lca模板题——传送门

1.tarjan求lca

学了求lca的tarjan算法(离线),在洛谷上做模板题,结果后三个点超时。

又把询问改成链式前向星,才ok。

这个博客,tarjan分析的很详细。

附代码——

#include <cstdio>
#include <cstring>

const int maxn = 500001;

int n, m, cnt, s, cns;
int x, y, z[maxn];//z是x和y的lca 
int f[maxn], head[maxn], from[maxn];
bool vis[maxn];
struct node
{
    int to, next;
}e[2 * maxn];
struct Node
{
    int to, next, num;
}q[2 * maxn];

inline int read()//读入优化 
{
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

inline void ask(int u, int v, int i)//储存待询问的结构体,也是链式前向星优化 
{
    q[cns].num = i;//num表示第几次询问 
    q[cns].to = v;
    q[cns].next = from[u];
    from[u] = cns++;
}

inline void add(int u, int v)//
{
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt++;
}

inline int find(int a)
{
    return a == f[a] ? a : f[a] = find(f[a]);//路径压缩优化 
}

/*inline void Union(int a, int b)
{
    int fx = find(a), fy = find(b);
    if(fx == fy) return;
    f[fy] = fx;
}*/

inline void tarjan(int k)
{
    int i, j;
    vis[k] = 1;
    f[k] = k;
    for(i = head[k]; i != -1; i = e[i].next)
     if(!vis[e[i].to])
     {
          tarjan(e[i].to);
          //Union(k, e[i].to);
          f[e[i].to] = k;
     }
    for(i = from[k]; i != -1; i = q[i].next)
     if(vis[q[i].to] == 1)
      z[q[i].num] = find(q[i].to);
}

int main()
{
    int i, j, u, v;
    n = read();
    m = read();
    s = read();
    memset(head, -1, sizeof(head));
    memset(from, -1, sizeof(from));
    for(i = 1; i <= n - 1; i++)
    {
        u = read();
        v = read();
        add(u, v);//注意添加两遍 
        add(v, u);
    }
    for(i = 1; i <= m; i++)
    {
        x = read();
        y = read();
        ask(x, y, i);//两遍 
        ask(y, x, i);
    }
    tarjan(s);
    for(i = 1; i <= m; i++) printf("%d\n", z[i]);
    return 0;
}
View Code

 

进过培训,修改了代码

 1 # include <iostream>
 2 # include <cstdio>
 3 # include <cstring>
 4 # include <string>
 5 # include <cmath>
 6 # include <vector>
 7 # include <map>
 8 # include <queue>
 9 # include <cstdlib>
10 # define MAXN 500001
11 using namespace std;
12 
13 inline int get_num() {
14     int k = 0, f = 1;
15     char c = getchar();
16     for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
17     for(; isdigit(c); c = getchar()) k = k * 10 + c - '0';
18     return k * f;
19 }
20 
21 int n, m, s;
22 int fa[MAXN], qx[MAXN], qy[MAXN], ans[MAXN], f[MAXN];
23 vector <int> vec[MAXN], q[MAXN];
24 
25 inline int find(int x)
26 {
27     return x == fa[x] ? x : fa[x] = find(fa[x]);
28 }
29 
30 inline void dfs(int u)
31 {
32     int i, v;
33     fa[u] = u;
34     for(i = 0; i < vec[u].size(); i++)
35     {
36         v = vec[u][i];
37         if(f[u] != v) f[v] = u, dfs(v);
38     }
39     for(i = 0; i < q[u].size(); i++)
40         if(f[v = u ^ qx[q[u][i]] ^ qy[q[u][i]]])
41             ans[q[u][i]] = find(v);
42     fa[u] = f[u];
43 }
44 
45 int main()
46 {
47     int i, x, y;
48     n = get_num();
49     m = get_num();
50     s = get_num();
51     for(i = 1; i < n; i++)
52     {
53         x = get_num();
54         y = get_num();
55         vec[x].push_back(y);
56         vec[y].push_back(x);
57     }
58     for(i = 1; i <= m; i++)
59     {
60         qx[i] = get_num();
61         qy[i] = get_num();
62         q[qx[i]].push_back(i);
63         q[qy[i]].push_back(i);
64     }
65     dfs(s);
66     for(i = 1; i <= m; i++) printf("%d\n", ans[i]);
67     return 0;
68 }
View Code

 

其实上面两个代码有些重复运算,请手动把求lca的过程放到dfs上面(也就是遍历到这个节点就求lca,而不是遍历完再求)

 

2.倍增求lca

下面是求lca的倍增算法(在线)

1. DFS预处理出所有节点的深度和父节点

inline void dfs(int u)
{
    int i;
    for(i=head[u];i!=-1;i=next[i])  
    {  
        if (!deep[to[i]])
        {            
            deep[to[i]] = deep[u]+1;
            p[to[i]][0] = u; //p[x][0]保存x的父节点为u;
            dfs(to[i]);
        }
    }
}
dfs预处理

2. 初始各个点的2^j祖先是谁 ,其中 2^j (j =0...log(该点深度))倍祖先,1倍祖先就是父亲,2倍祖先是父亲的父亲......。

void init()
{
    int i,j;
    //p[i][j]表示i结点的第2^j祖先
    for(j=1;(1<<j)<=n;j++)
        for(i=1;i<=n;i++)
            if(p[i][j-1]!=-1)
                p[i][j]=p[p[i][j-1]][j-1];//i的第2^j祖先就是i的第2^(j-1)祖先的第2^(j-1)祖先
}
初始化

3.从深度大的节点上升至深度小的节点同层,如果此时两节点相同直接返回此节点,即lca。

否则,利用倍增法找到最小深度的 p[a][j]!=p[b][j],此时他们的父亲p[a][0]即lca。

int lca(int a,int b)//最近公共祖先
{
    int i,j;
    if(deep[a]<deep[b])swap(a,b);
    for(i=0;(1<<i)<=deep[a];i++);
    i--;
    //使a,b两点的深度相同
    for(j=i;j>=0;j--)
        if(deep[a]-(1<<j)>=deep[b])
            a=p[a][j];
    if(a==b)return a;
    //倍增法,每次向上进深度2^j,找到最近公共祖先的子结点
    for(j=i;j>=0;j--)
    {
        if(p[a][j]!=-1&&p[a][j]!=p[b][j])
        {
            a=p[a][j];
            b=p[b][j];
        }
    }
    return p[a][0];
}
倍增求lca

 

最后是完整代码,为了节约时间,就没有把p数组初始化为-1.

#include <cstdio>
#include <cstring>
#include <iostream>

const int maxn = 500001;
int n, m, cnt, s;
int next[2 * maxn], to[2 * maxn], head[2 * maxn], deep[maxn], p[maxn][21];

inline void add(int x, int y)
{
    to[cnt] = y;
    next[cnt] = head[x];
    head[x] = cnt++;
}

inline void dfs(int i)
{
    int j;
    for(j = head[i]; j != -1; j = next[j])
     if(!deep[to[j]])
     {
         deep[to[j]] = deep[i] + 1;
         p[to[j]][0] = i;
         dfs(to[j]);
     }
}

inline void init()
{
    int i, j;
    for(j = 1; (1 << j) <= n; j++)
     for(i = 1; i <= n; i++)
      p[i][j] = p[p[i][j - 1]][j - 1];
}

inline int lca(int a, int b)
{
    int i, j;
    if(deep[a] < deep[b]) std::swap(a, b);
    for(i = 0; (1 << i) <= deep[a]; i++);
    i--;
    for(j = i; j >= 0; j--)
     if(deep[a] - (1 << j) >= deep[b])
      a = p[a][j];
    if(a == b) return a;
    for(j = i; j >= 0; j--)
     if(p[a][j] != p[b][j])
     {
         a = p[a][j];
         b = p[b][j];
     }
    return p[a][0];
}

int main()
{
    int i, j, x, y;
    memset(head, -1, sizeof(head));
    scanf("%d %d %d", &n, &m, &s);
    for(i = 1; i <= n - 1; i++)
    {
        scanf("%d %d", &x, &y);
        add(x, y);
        add(y, x);
    }
    deep[s] = 1;
    dfs(s);
    init();
    for(i = 1; i <= m; i++)
    {
        scanf("%d %d", &x, &y);
        printf("%d\n", lca(x, y));
    }
    return 0;
}
View Code

 

经过培训,又改了改代码。

 1 # include <iostream>
 2 # include <cstdio>
 3 # include <cstring>
 4 # include <string>
 5 # include <cmath>
 6 # include <vector>
 7 # include <map>
 8 # include <queue>
 9 # include <cstdlib>
10 # define MAXN 500001
11 using namespace std;
12 
13 inline int get_num() {
14     int k = 0, f = 1;
15     char c = getchar();
16     for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
17     for(; isdigit(c); c = getchar()) k = k * 10 + c - '0';
18     return k * f;
19 }
20 
21 int n, m, s;
22 int f[MAXN][25], deep[MAXN];
23 vector <int> vec[MAXN];
24 
25 inline void dfs(int u)
26 {
27     int i, v;
28     deep[u] = deep[f[u][0]] + 1;
29     for(i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i];
30     for(i = 0; i < vec[u].size(); i++)
31     {
32         v = vec[u][i];
33         if(!deep[v]) f[v][0] = u, dfs(v);
34     }
35 }
36 
37 inline int lca(int x, int y)
38 {
39     int i;
40     if(deep[x] < deep[y]) swap(x, y);
41     for(i = 20; i >= 0; i--)
42         if(deep[f[x][i]] >= deep[y])
43             x = f[x][i];
44     if(x == y) return x;
45     for(i = 20; i >= 0; i--)
46         if(f[x][i] != f[y][i])
47             x = f[x][i], y = f[y][i];
48     return f[x][0];
49 }
50 
51 int main()
52 {
53     int i, x, y;
54     n = get_num();
55     m = get_num();
56     s = get_num();
57     for(i = 1; i < n; i++)
58     {
59         x = get_num();
60         y = get_num();
61         vec[x].push_back(y);
62         vec[y].push_back(x);
63     }
64     dfs(s);
65     for(i = 1; i <= m; i++)
66     {
67         scanf("%d %d", &x, &y);
68         printf("%d\n", lca(x, y));
69     }
70     return 0;
71 }
View Code

 

3.树剖法求lca

 1 # include <iostream>
 2 # include <cstdio>
 3 # include <cstring>
 4 # include <string>
 5 # include <cmath>
 6 # include <vector>
 7 # include <map>
 8 # include <queue>
 9 # include <cstdlib>
10 # define MAXN 500001
11 using namespace std;
12 
13 inline int get_num() {
14     int k = 0, f = 1;
15     char c = getchar();
16     for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
17     for(; isdigit(c); c = getchar()) k = k * 10 + c - '0';
18     return k * f;
19 }
20 
21 int n, m, s;
22 int f[MAXN], size[MAXN], top[MAXN], son[MAXN], deep[MAXN];
23 vector <int> vec[MAXN];
24 
25 inline void dfs1(int u)
26 {
27     int i, v;
28     size[u] = 1;
29     deep[u] = deep[f[u]] + 1;
30     for(i = 0; i < vec[u].size(); i++)
31     {
32         v = vec[u][i];
33         if(!deep[v])
34         {
35             f[v] = u;
36             dfs1(v);
37             size[u] += size[v];
38             if(size[son[u]] < size[v]) son[u] = v;
39         }
40     }
41 }
42 
43 inline void dfs2(int u, int tp)
44 {
45     int i, v;
46     top[u] = tp;
47     if(!son[u]) return;
48     dfs2(son[u], tp);
49     for(i = 0; i < vec[u].size(); i++)
50     {
51         v = vec[u][i];
52         if(v != son[u] && v != f[u]) dfs2(v, v);
53     }
54 }
55 
56 inline int lca(int x, int y)
57 {
58     while(top[x] != top[y])
59     {
60         if(deep[top[x]] < deep[top[y]]) swap(x, y);
61         x = f[top[x]];
62     }
63     if(deep[x] > deep[y]) swap(x, y);
64     return x;
65 }
66 
67 int main()
68 {
69     int i, x, y;
70     n = get_num();
71     m = get_num();
72     s = get_num();
73     for(i = 1; i < n; i++)
74     {
75         x = get_num();
76         y = get_num();
77         vec[x].push_back(y);
78         vec[y].push_back(x);
79     }
80     dfs1(s);
81     dfs2(s, s);
82     for(i = 1; i <= m; i++)
83     {
84         x = get_num();
85         y = get_num();
86         printf("%d\n", lca(x, y));
87     }
88     return 0;
89 }
View Code

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!