洛谷P3384 【模板】树链剖分

匿名 (未验证) 提交于 2019-12-02 23:55:01

题目传送门：https://www.luogu.com.cn/problem/P3384

// Luogu P3384 "[Template] Heavy-light decomposition".
//
// Input: n m r p, then n vertex values, then n-1 undirected edges, then m
// operations on the tree rooted at r (all answers are reported modulo p):
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the sum of the values on the path x..y
//   3 x z   : add z to every vertex in the subtree of x
//   4 x     : print the sum of the values in the subtree of x
//
// Heavy-light decomposition numbers the vertices so that every heavy chain and
// every subtree occupies a contiguous range of positions; a lazy range-add /
// range-sum segment tree over those positions then serves each operation.
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using ll = long long;

// Reduces v into [0, mod). Requires mod > 0.
ll normalize(ll v, ll mod) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

// Segment tree over positions [0, size) supporting "add z to a range" and
// "sum of a range". Every stored sum and pending tag is kept in [0, mod), so
// no intermediate value can grow without bound across many updates.
class LazySegmentTree {
public:
    LazySegmentTree() = default;

    // values[i] is the initial value at position i; it is reduced mod `mod`.
    LazySegmentTree(const std::vector<ll>& values, ll mod)
        : size_(static_cast<int>(values.size())),
          mod_(mod),
          sum_(4 * values.size() + 1, 0),
          lazy_(4 * values.size() + 1, 0) {
        if (size_ > 0) build(1, 0, size_ - 1, values);
    }

    // Adds z to every position in [l, r] (inclusive). Empty ranges are ignored.
    void add(int l, int r, ll z) {
        if (size_ == 0 || l > r) return;
        update(1, 0, size_ - 1, l, r, normalize(z, mod_));
    }

    // Returns the sum of positions [l, r] (inclusive), reduced modulo mod.
    ll sum(int l, int r) {
        if (size_ == 0 || l > r) return 0;
        return query(1, 0, size_ - 1, l, r);
    }

private:
    void build(int node, int lo, int hi, const std::vector<ll>& values) {
        if (lo == hi) {
            sum_[node] = normalize(values[lo], mod_);
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        build(2 * node, lo, mid, values);
        build(2 * node + 1, mid + 1, hi, values);
        sum_[node] = (sum_[2 * node] + sum_[2 * node + 1]) % mod_;
    }

    // Applies "+z to every position" to `node`, which covers [lo, hi].
    // Both factors are below mod and the length is computed in long long, so
    // the product stays below mod^2 (fits in long long for moduli up to 2^31).
    void apply(int node, int lo, int hi, ll z) {
        const ll length = static_cast<ll>(hi - lo + 1) % mod_;
        sum_[node] = (sum_[node] + length * z) % mod_;
        lazy_[node] = (lazy_[node] + z) % mod_;
    }

    // Moves the pending addition of `node` (covering [lo, hi], split at mid)
    // onto its two children.
    void push_down(int node, int lo, int mid, int hi) {
        if (lazy_[node] == 0) return;
        apply(2 * node, lo, mid, lazy_[node]);
        apply(2 * node + 1, mid + 1, hi, lazy_[node]);
        lazy_[node] = 0;
    }

    void update(int node, int lo, int hi, int l, int r, ll z) {
        if (l <= lo && hi <= r) {
            apply(node, lo, hi, z);
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        push_down(node, lo, mid, hi);
        if (l <= mid) update(2 * node, lo, mid, l, r, z);
        if (r > mid) update(2 * node + 1, mid + 1, hi, l, r, z);
        sum_[node] = (sum_[2 * node] + sum_[2 * node + 1]) % mod_;
    }

    ll query(int node, int lo, int hi, int l, int r) {
        if (l <= lo && hi <= r) return sum_[node];
        const int mid = lo + (hi - lo) / 2;
        push_down(node, lo, mid, hi);
        ll total = 0;
        if (l <= mid) total += query(2 * node, lo, mid, l, r);
        if (r > mid) total += query(2 * node + 1, mid + 1, hi, l, r);
        return total % mod_;
    }

    int size_ = 0;
    ll mod_ = 1;
    std::vector<ll> sum_;
    std::vector<ll> lazy_;
};

// Heavy-light decomposition of a tree with vertices 1..n. Index 0 is unused
// and doubles as the "no vertex" sentinel (parent of the root, heavy child of
// a leaf).
class HeavyLightDecomposition {
public:
    // adj: undirected adjacency lists, size n + 1.
    // values: initial vertex values, size n + 1 (values[0] unused).
    // Tree traversals use explicit stacks, so a deep tree (e.g. a path of 1e5
    // vertices) cannot overflow the call stack.
    HeavyLightDecomposition(const std::vector<std::vector<int>>& adj, int root,
                            const std::vector<ll>& values, ll mod)
        : mod_(mod),
          parent_(adj.size(), 0),
          depth_(adj.size(), 0),
          heavy_(adj.size(), 0),
          subtree_(adj.size(), 1),
          top_(adj.size(), 0),
          pos_(adj.size(), 0) {
        const std::vector<int> order = preorder(adj, root);
        compute_sizes_and_heavy_children(order);
        tree_ = LazySegmentTree(assign_positions(adj, root, values), mod_);
    }

    // Adds z to every vertex on the path x..y.
    void add_path(int x, int y, ll z) {
        for_each_path_range(x, y, [&](int l, int r) { tree_.add(l, r, z); });
    }

    // Returns the sum of the vertex values on the path x..y, modulo mod.
    ll sum_path(int x, int y) {
        ll total = 0;
        for_each_path_range(x, y, [&](int l, int r) {
            total = (total + tree_.sum(l, r)) % mod_;
        });
        return total;
    }

    // Adds z to every vertex in the subtree of x.
    void add_subtree(int x, ll z) {
        tree_.add(pos_[x], pos_[x] + subtree_[x] - 1, z);
    }

    // Returns the sum of the vertex values in the subtree of x, modulo mod.
    ll sum_subtree(int x) {
        return tree_.sum(pos_[x], pos_[x] + subtree_[x] - 1);
    }

private:
    // Fills parent_ / depth_ and returns the vertices in an order where every
    // parent precedes all of its children.
    std::vector<int> preorder(const std::vector<std::vector<int>>& adj, int root) {
        std::vector<int> order;
        order.reserve(adj.size());
        std::vector<int> pending(1, root);
        while (!pending.empty()) {
            const int v = pending.back();
            pending.pop_back();
            order.push_back(v);
            for (int u : adj[v]) {
                if (u == parent_[v]) continue;
                parent_[u] = v;
                depth_[u] = depth_[v] + 1;
                pending.push_back(u);
            }
        }
        return order;
    }

    // Accumulates subtree sizes bottom-up (reverse of a parent-first order)
    // and picks, for every vertex, the child with the largest subtree as its
    // heavy child.
    void compute_sizes_and_heavy_children(const std::vector<int>& order) {
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            const int p = parent_[v];
            if (p == 0) continue;  // the root has no parent to update
            subtree_[p] += subtree_[v];
            if (heavy_[p] == 0 || subtree_[v] > subtree_[heavy_[p]]) heavy_[p] = v;
        }
    }

    // Numbers vertices chain by chain: each heavy child gets the position
    // right after its parent, and every subtree gets a contiguous range.
    // Fills top_ / pos_ and returns the initial values in position order.
    std::vector<ll> assign_positions(const std::vector<std::vector<int>>& adj, int root,
                                     const std::vector<ll>& values) {
        std::vector<ll> ordered;
        ordered.reserve(adj.size());
        std::vector<int> chain_heads(1, root);
        while (!chain_heads.empty()) {
            const int head = chain_heads.back();
            chain_heads.pop_back();
            for (int v = head; v != 0; v = heavy_[v]) {
                top_[v] = head;
                pos_[v] = static_cast<int>(ordered.size());
                ordered.push_back(values[v]);
                // Light children start new chains. The LIFO stack finishes a
                // whole light subtree before the next one is started, and the
                // deepest chain vertex's light children are handled first,
                // which keeps every subtree's positions contiguous.
                for (int u : adj[v]) {
                    if (u != parent_[v] && u != heavy_[v]) chain_heads.push_back(u);
                }
            }
        }
        return ordered;
    }

    // Calls visit(l, r) for a set of position ranges that together cover
    // exactly the vertices on the path x..y (each range has l <= r).
    template <class Visit>
    void for_each_path_range(int x, int y, Visit&& visit) const {
        while (top_[x] != top_[y]) {
            // Always climb from the endpoint whose chain head is deeper.
            if (depth_[top_[x]] < depth_[top_[y]]) std::swap(x, y);
            visit(pos_[top_[x]], pos_[x]);
            x = parent_[top_[x]];
        }
        if (depth_[x] > depth_[y]) std::swap(x, y);
        visit(pos_[x], pos_[y]);
    }

    ll mod_;
    std::vector<int> parent_;
    std::vector<int> depth_;
    std::vector<int> heavy_;
    std::vector<int> subtree_;
    std::vector<int> top_;
    std::vector<int> pos_;
    LazySegmentTree tree_;
};

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    int root = 0;
    ll mod = 0;
    if (std::scanf("%d%d%d%lld", &n, &m, &root, &mod) != 4 || n <= 0 || mod <= 0) {
        return 0;
    }

    std::vector<ll> values(n + 1, 0);
    for (int v = 1; v <= n; ++v) {
        if (std::scanf("%lld", &values[v]) != 1) return 0;
    }

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d%d", &x, &y) != 2) return 0;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    HeavyLightDecomposition hld(adj, root, values, mod);

    for (int i = 0; i < m; ++i) {
        int op = 0;
        if (std::scanf("%d", &op) != 1) break;
        if (op == 1) {
            int x = 0;
            int y = 0;
            ll z = 0;
            if (std::scanf("%d%d%lld", &x, &y, &z) != 3) break;
            hld.add_path(x, y, z);
        } else if (op == 2) {
            int x = 0;
            int y = 0;
            if (std::scanf("%d%d", &x, &y) != 2) break;
            std::printf("%lld\n", hld.sum_path(x, y));
        } else if (op == 3) {
            int x = 0;
            ll z = 0;
            if (std::scanf("%d%lld", &x, &z) != 2) break;
            hld.add_subtree(x, z);
        } else {
            int x = 0;
            if (std::scanf("%d", &x) != 1) break;
            std::printf("%lld\n", hld.sum_subtree(x));
        }
    }
    return 0;
}
标签
易学教程内所有资源均来自网络或用户发布的内容，如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题？点击提问，说说你的问题，让更多的人一起探讨吧！