线段树合并,就是将已有的两棵线段树合并为一棵,相同位置的信息整合到一起,通常是权值线段树
比较裸的,就是将一棵线段树的每一个位置取出来插入另一棵中
但比较高效的线段树合并可以参照可并堆的合并方式
线段树合并的原理十分简单,具体步骤如下:
对于两颗树的节点u和v
①如果u为空,返回v
②如果v为空,返回u
③否则,新建节点t,整合u和v的信息,然后递归合并u和v的左右子树
代码如下:
int merge(int u,int v){ if (!u) return v; if (!v) return u; int t = ++cnt; sum[t] = sum[u] + sum[v]; ls[t] = merge(ls[u],ls[v]); rs[t] = merge(rs[u],rs[v]); return t; }
容易发现,这样合并的复杂度取决于两棵线段树重合的部分的大小
每有一个位置权值同样存在,就要\(O(logn)\)的复杂度
不过,由于权值线段树中被更新的位置通常很均匀分布,所以合并的两棵线段树通常具有很小的相似性
可以用这一道水题入门线段树合并
BZOJ4756
#include<iostream> #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> #define LL long long int #define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt) #define REP(i,n) for (int i = 1; i <= (n); i++) #define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts(""); using namespace std; const int maxn = 100005,maxm = 10000005,INF = 1000000000; inline int read(){ int out = 0,flag = 1; char c = getchar(); while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();} while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();} return out * flag; } int val[maxn],b[maxn],tot = 1; int fa[maxn],lsn[maxn],rbr[maxn]; int sum[maxm],ls[maxm],rs[maxm],rt[maxn],cnt; int n,ans[maxn]; int getn(int x){return lower_bound(b + 1,b + 1 + tot,x) - b;} void modify(int& u,int l,int r,int pos){ if (!u) u = ++cnt; sum[u]++; if (l == r) return; int mid = l + r >> 1; if (mid >= pos) modify(ls[u],l,mid,pos); else modify(rs[u],mid + 1,r,pos); } int query(int u,int l,int r,int L){ if (!u) return 0; if (l >= L) return sum[u]; int mid = l + r >> 1; if (mid >= L) return query(ls[u],l,mid,L) + query(rs[u],mid + 1,r,L); return query(rs[u],mid + 1,r,L); } int merge(int u,int v){ if (!u) return v; if (!v) return u; int t = ++cnt; sum[t] = sum[u] + sum[v]; ls[t] = merge(ls[u],ls[v]); rs[t] = merge(rs[u],rs[v]); return t; } void dfs(int u){ for (int k = lsn[u]; k; k = rbr[k]){ dfs(k); rt[u] = merge(rt[u],rt[k]); } ans[u] = query(rt[u],1,tot,val[u] + 1); modify(rt[u],1,tot,val[u]); } int main(){ n = read(); for (int i = 1; i <= n; i++) b[i] = val[i] = read(); for (int i = 2; i <= n; i++){ fa[i] = read(); rbr[i] = lsn[fa[i]]; lsn[fa[i]] = i; } sort(b + 1,b + 1 + n); for (int i = 2; i <= n; i++) if (b[i] != b[tot]) b[++tot] = b[i]; for (int i = 1; i <= n; i++) val[i] = getn(val[i]); dfs(1); for (int i = 1; i <= n; i++) printf("%d\n",ans[i]); return 0; }
来源:https://www.cnblogs.com/Mychael/p/8665589.html