线段树合并

杀马特。学长 韩版系。学妹 提交于 2019-12-04 06:39:04

线段树合并

应用范围:将子树的信息合并给父亲节点,并且权值线段树的下标值域和节点数相近。

CF600E Lomsat gelral

题目链接

题意:一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。

\(1 <= n<=1e5\)

解法:线段树合并,这个东西的时空复杂度都很玄学,姑且认为时间为\(O(nlogn)\),空间为(常数\(\times log_n\times n\)),常数一般为\(4-8\)

#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 100100
#define ll long long
int n;
int fir[maxn], nxt[maxn * 2], vv[maxn * 2];
int tot = 0;
void add(int u, int v)
{
    nxt[++tot] = fir[u];
    fir[u] = tot;
    vv[tot] = v;
}
int cnt = 0;
int root[maxn], col[maxn];
int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2];
ll ans[maxn * 17 * 2];
void pushup(int a)
{
    if(sum[lz[a]] < sum[rz[a]])
    {
        sum[a] = sum[rz[a]];
        ans[a] = ans[rz[a]];
    }
    if(sum[lz[a]] > sum[rz[a]])
    {
        sum[a] = sum[lz[a]];
        ans[a] = ans[lz[a]];
    }
    if(sum[lz[a]] == sum[rz[a]])
    {
        sum[a] = sum[lz[a]];
        ans[a] = ans[lz[a]] + ans[rz[a]];
    }
    return;
}
int merge(int a, int b, int l, int r)
{
    if(a == 0) return b;
    if(b == 0) return a;
    if(l == r)
    {
        sum[a] += sum[b];
        ans[a] = l;
        return a;
    }
    int mid = (l + r) >> 1;
    lz[a] = merge(lz[a], lz[b], l, mid);
    rz[a] = merge(rz[a], rz[b], mid + 1, r);
    pushup(a);
    return a;
}
void update(int &a, int l, int r, int v)
{
    if(!a) a = ++cnt;
    int mid = (l + r) >> 1;
    if(l == r)
    {
        sum[a] += 1;
        ans[a] = l;
        return;
    }
    if(mid >= v) update(lz[a], l, mid, v);
    if(mid < v) update(rz[a], mid + 1, r, v); 
    pushup(a);
    return;
}
void dfs(int u, int fa)
{
    for(int i = fir[u]; i; i = nxt[i])
    {
        int v = vv[i];
        if(v == fa) continue;
        dfs(v, u);
        merge(root[u], root[v], 1, 100000);
    }
    update(root[u], 1, 100000, col[u]);
    ans[u] = ans[root[u]];
}
int main()
{
    scanf("%d", &n); cnt = n;
    for(int i = 1; i <= n; i++) 
    {
        scanf("%d", &col[i]); root[i] = i;
    }
    for(int i = 1; i < n; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v); add(v, u);
    }
    dfs(1, 0);
    for(int i = 1; i <= n; i++) printf("%lld ", ans[i]);
    return 0;
}

雨天的尾巴

题目链接

注意\(ans\)要在\(dfs\)时计算,不然当前节点的\(root\)可能被父亲节点继承,然后就炸了。

#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 100100
#define ll long long
int n;
int fir[maxn], nxt[maxn * 2], vv[maxn * 2];
int tot = 0;
void add(int u, int v)
{
    nxt[++tot] = fir[u];
    fir[u] = tot;
    vv[tot] = v;
}
int cnt = 0;
int root[maxn], col[maxn];
int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2];
ll ans[maxn * 17 * 2];
void pushup(int a)
{
    if(sum[lz[a]] < sum[rz[a]])
    {
        sum[a] = sum[rz[a]];
        ans[a] = ans[rz[a]];
    }
    if(sum[lz[a]] > sum[rz[a]])
    {
        sum[a] = sum[lz[a]];
        ans[a] = ans[lz[a]];
    }
    if(sum[lz[a]] == sum[rz[a]])
    {
        sum[a] = sum[lz[a]];
        ans[a] = ans[lz[a]] + ans[rz[a]];
    }
    return;
}
int merge(int a, int b, int l, int r)
{
    if(a == 0) return b;
    if(b == 0) return a;
    if(l == r)
    {
        sum[a] += sum[b];
        ans[a] = l;
        return a;
    }
    int mid = (l + r) >> 1;
    lz[a] = merge(lz[a], lz[b], l, mid);
    rz[a] = merge(rz[a], rz[b], mid + 1, r);
    pushup(a);
    return a;
}
void update(int &a, int l, int r, int v)
{
    if(!a) a = ++cnt;
    int mid = (l + r) >> 1;
    if(l == r)
    {
        sum[a] += 1;
        ans[a] = l;
        return;
    }
    if(mid >= v) update(lz[a], l, mid, v);
    if(mid < v) update(rz[a], mid + 1, r, v); 
    pushup(a);
    return;
}
void dfs(int u, int fa)
{
    for(int i = fir[u]; i; i = nxt[i])
    {
        int v = vv[i];
        if(v == fa) continue;
        dfs(v, u);
        merge(root[u], root[v], 1, 100000);
    }
    update(root[u], 1, 100000, col[u]);
    ans[u] = ans[root[u]];
}
int main()
{
    scanf("%d", &n); cnt = n;
    for(int i = 1; i <= n; i++) 
    {
        scanf("%d", &col[i]); root[i] = i;
    }
    for(int i = 1; i < n; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v); add(v, u);
    }
    dfs(1, 0);
    for(int i = 1; i <= n; i++) printf("%lld ", ans[i]);
    return 0;
}
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!