今天才发现自己根本不会树形背包,我太菜了。
一般的树形背包是这样做的:

看上去,它的复杂度是 $O(nk^2)$ 的。
第一种优化:
这里,如果第二维的大小和子树大小有关,同时又不超过一个常数 $k$ 。例如:第二维表示子树内选了多少个点,那么通过一些精妙的分析和上界优化,复杂度就可以变成 $O(nk)$ 了。
以下的 $siz_x$ 表示合并 $son$ 这个子树前 $x$ 子树的大小(注意:不是 $x$ 的真实子树大小,这里很重要)。

这样分析出来的复杂度就是 $O(nk)$ .
证明:摘自这里;
首先,定义 $T(n)$ 为处理 $n$ 这棵子树时所用的时间,$f(n)$ 为处理 $n$ 这个点时所用的时间。
$T(x)=\left(\sum_{f_y=x} T_{y}\right)+f(x)\\f(x)=\min(m,siz(y_1))\times \min(m,siz(y_1))+\min(m,siz(y_1)+siz(y_2))\times \min(m,siz(y_1))\\ ~~~~~~~~~~~+\cdots+\min(m,siz(x))\times \min(m,siz(y_n))$
现在进行一番放缩,把每个乘法的前一项统一变成 $\min(m,siz(x))$ ,这样显然只会使答案变大,所以分析出来的复杂度上界就应该是正确的。
$f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} \min(m,siz(y))\right)$
再次放缩,把后面括号里的 $min$ 直接扔掉,得:
$f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} siz(y)\right)\\~~~~~~~~=\min(m,siz(x))\times siz(x)$
对于 $siz(x)<m$ 的点,首先考虑他的子树都是叶子的情况:
$T(x)=siz(x)^2+\sum 1$
对于任意 $siz(x)<m$ 的点,递归证明,由于 “平方和小于和的平方” ,所以 $T(x)$ 与 $siz(x)^2$ 同阶;
对于 $siz(x)>m$ 的点,首先考虑它的所有子树都小于 $m$ 的情况:
$T(x)=m\times siz(x)+\sum siz(j)^2$
接着放缩可得,$T(x)$ 与 $m\times siz(x)$ 同阶;
继续使用递归证明的技巧,考虑某一层出现了子树大于 $m$ 的情况:
$T(x)=m\times siz(x)+\sum siz(j)^2+\sum siz(j)\times m$
所以,$T(x)$ 还是与 $m\times siz(x)$ 同阶;
综上所述,这种做法的复杂度是 $n\times k$ 。
选课加强版:https://www.luogu.org/problem/U53204

1 # include <cstdio>
2 # include <iostream>
3 # include <cstring>
4 # include <vector>
5 # define R register int
6
7 using namespace std;
8
9 const int N=100005;
10 struct edge
11 {
12 int to,nex;
13 };
14 int si,h=0,n,m;
15 edge g[N<<1];
16 int firs[N],a[N],siz[N];
17 bool vis[N];
18 int dp[100000100];
19
20 void add(int u,int v)
21 {
22 g[++h].to=v;
23 g[h].nex=firs[u];
24 firs[u]=h;
25 }
26
27 void dfs(int x)
28 {
29 dp[x*(m+1)+1]=a[x];
30 siz[x]=1;
31 vis[x]=true;
32 int j;
33 for (R i=firs[x];i;i=g[i].nex)
34 {
35 j=g[i].to;
36 if(vis[j]) continue;
37 dfs(j);
38 for (int k=min(siz[x]+siz[j],m);k>=1;--k)
39 for (int z=max(1,k-siz[x]);z<=min(siz[j],k-1);++z)
40 dp[x*(m+1)+k]=max(dp[x*(m+1)+k],dp[x*(m+1)+k-z]+dp[j*(m+1)+z]);
41 siz[x]+=siz[j];
42 }
43 }
44
45 int read()
46 {
47 int x=0;
48 char c=getchar();
49 while (!isdigit(c)) c=getchar();
50 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
51 return x;
52 }
53
54 int main()
55 {
56 scanf("%d%d",&n,&m); m++;
57 memset(g,0,sizeof(g));
58 for (R i=1;i<=n;i++)
59 {
60 si=read(),a[i]=read();
61 add(i,si);
62 add(si,i);
63 }
64 dfs(0);
65 printf("%d",dp[m]);
66 return 0;
67 }
这种做法比较好写,而且还有一个优点,就是它事实上求出了每棵子树的 $dp$ 值,换句话说,它可以统计到每个连通块的答案。当然,它也有一定的局限性,那就是第二维必须和子树的大小有关,否则复杂度就不对了。下面,再来介绍另一种不要求第二维大小的做法。
第二种优化:
首先对树求出后序遍历序,设 $f[i][j]$ 表示:dfs序编号在i之前的点当前都满足依赖条件时的背包;$j$ 表示什么因题目而异;看上去有点难以理解?解释一下,“当前满足依赖条件”是指,在仅考虑前 $i$ 个点构成的森林的情况下,每个点都满足依赖关系(当前已经出现的祖先都被选了,还没出现的祖先不用考虑)。转移方程十分简单,在往森林里一个点时,如果不选它,那它的子树就都不能选,因为它的子树的dfs序是一段连续的区间,我们直接跳回到还没有考虑过这棵子树时的状态;如果选它,那就从上一个点进行转移即可。复杂度显然为 $n\times m$ 。

这种方法比上一种还好写,但是它也有一个问题,那就是只能算出以指定点为根时的答案,而不能做任意联通块。

1 # include <cstdio>
2 # include <iostream>
3 # include <cstring>
4 # include <vector>
5 # define R register int
6
7 using namespace std;
8
9 const int N=100005;
10 struct edge
11 {
12 int to,nex;
13 };
14 int si,h=0,n,m;
15 edge g[N<<1];
16 int firs[N],a[N],siz[N];
17 bool vis[N];
18 int dp[100000100];
19 int no[N],cnt;
20
21 void add(int u,int v)
22 {
23 g[++h].to=v;
24 g[h].nex=firs[u];
25 firs[u]=h;
26 }
27
28 int read()
29 {
30 int x=0;
31 char c=getchar();
32 while (!isdigit(c)) c=getchar();
33 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
34 return x;
35 }
36
37 void dfs (int x)
38 {
39 int j;
40 siz[x]=1;
41 for (R i=firs[x];i;i=g[i].nex)
42 {
43 j=g[i].to;
44 if(vis[j]) continue;
45 vis[j]=1;
46 dfs(j);
47 siz[x]+=siz[j];
48 }
49 no[++cnt]=x;
50 }
51
52 int main()
53 {
54 scanf("%d%d",&n,&m); m++;
55 memset(g,0,sizeof(g));
56 for (R i=1;i<=n;i++)
57 {
58 si=read(),a[i]=read();
59 add(i,si);
60 add(si,i);
61 }
62 vis[0]=1;
63 dfs(0);
64 int x;
65 for (R i=1;i<=cnt;++i)
66 {
67 x=no[i];
68 for (R j=1;j<=m;++j)
69 dp[i*(m+1)+j]=max(dp[(i-1)*(m+1)+j-1]+a[x],dp[(i-siz[x])*(m+1)+j]);
70 }
71 printf("%d",dp[cnt*(m+1)+m]);
72 return 0;
73 }
学习了以上知识后,我们来做一道题?
Shopping:https://www.lydsy.com/JudgeOnline/problem.php?id=4182
这好像是个权限题?那我来概述一下题意:
给定一棵 $n$ 个点的树,每个点上有一种物品 $(w,c,d)$ 表示它的价值是 $w$ ,价格是 $c$ ,有 $d$ 个。你有 $m$ 元钱,并希望它们能买到价值和最大的物品,还有一个限制是买了物品的点必须是树上的一个连通块,求最大价值。$n<=500,m<=4000,d<=100$
一个显然的思路是直接上树形背包的第一种做法,因为它事实上是在每个连通块最高的点处对这个连通块进行了处理,可以直接求出这道题的答案。
不过,别忘了第一种优化的前提,如果你以为它任何条件下都适用,那就会 $TLE$ 得很惨。在这道题中,即使是很小的子树也可以有满的 $dp$ 数组,所以复杂度就是 $O(nm^2)$
看起来第一种做法已经走进死路,让我们来考虑一下第二种做法吧。
第二种做法可以枚举根,复杂度 $n^2m$ ,感觉已经有了不少改进呢!可以发现,枚举根是一个比较愚蠢的方法,因为在做第一次的时候,就已经把所有与这个根有交的连通块都算过了,接下来只需要对每个子树再做就好了。子树大小有可能不平均?点分治!

1 # include <cstdio>
2 # include <iostream>
3 # include <cstring>
4 # include <vector>
5 # define R register int
6
7 using namespace std;
8
9 const int N=502;
10 int T,n,m,h,x,y,cnt,rt,ans,S,d;
11 int firs[N],siz[N],no[N],vis[N],w[N],c[N],maxs[N];
12 int dp[N][4005];
13 struct edge
14 {
15 int too,nex;
16 }g[N<<1];
17 struct thi
18 {
19 int c,w;
20 thi (int a=0,int b=0) { c=a; w=b; }
21 };
22 vector <thi> v[N];
23
24 void add (int x,int y)
25 {
26 g[++h].nex=firs[x];
27 firs[x]=h;
28 g[h].too=y;
29 }
30
31 int read()
32 {
33 int x=0;
34 char c=getchar();
35 while (!isdigit(c)) c=getchar();
36 while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
37 return x;
38 }
39
40 void get_root (int x,int f)
41 {
42 siz[x]=1,maxs[x]=0;
43 int j;
44 for (R i=firs[x];i;i=g[i].nex)
45 {
46 j=g[i].too;
47 if(vis[j]||f==j) continue;
48 get_root(j,x);
49 siz[x]+=siz[j];
50 maxs[x]=max(maxs[x],siz[j]);
51 }
52 maxs[x]=max(maxs[x],S-siz[x]);
53 if(maxs[x]<maxs[rt]) rt=x;
54 }
55
56 void dfs (int x,int f)
57 {
58 int j; siz[x]=1;
59 for (R i=firs[x];i;i=g[i].nex)
60 {
61 j=g[i].too;
62 if(vis[j]||j==f) continue;
63 dfs(j,x);
64 siz[x]+=siz[j];
65 }
66 no[++cnt]=x;
67 }
68
69 void pdc (int x)
70 {
71 cnt=0;
72 dfs(x,0);
73 for (R i=0;i<=cnt;++i)
74 for (R j=0;j<=m;++j)
75 dp[i][j]=0;
76 for (R i=1;i<=cnt;++i)
77 {
78 int a=no[i],vs=v[a].size();
79 for (R j=0;j<=m;++j)
80 dp[i][j]=max(dp[i][j],dp[ i-siz[a] ][j]);
81 for (R k=0;k<vs;++k)
82 for (R j=m;j>=v[a][k].c;--j)
83 dp[i][j]=max(dp[i][j],max(dp[i-1][ j-v[a][k].c ]+v[a][k].w,dp[i][ j-v[a][k].c ]+v[a][k].w));
84 }
85 for (R i=1;i<=m;++i)
86 ans=max(ans,dp[cnt][i]);
87 }
88
89 void solve (int x)
90 {
91 vis[x]=1;
92 pdc(x);
93 int j;
94 for (R i=firs[x];i;i=g[i].nex)
95 {
96 j=g[i].too;
97 if(vis[j]) continue;
98 rt=0; maxs[rt]=n; S=siz[j];
99 get_root(j,0);
100 solve(rt);
101 }
102 }
103
104 void t4182()
105 {
106 n=read(),m=read();
107 memset(firs,0,sizeof(firs));
108 memset(g,0,sizeof(g));
109 memset(vis,0,sizeof(vis));
110 h=0;
111 for (R i=1;i<=n;++i)
112 v[i].clear();
113 for (R i=1;i<=n;++i)
114 w[i]=read();
115 for (R i=1;i<=n;++i)
116 c[i]=read();
117 for (R i=1;i<=n;++i)
118 {
119 d=read();
120 int x=1;
121 while(x<=d)
122 {
123 d-=x;
124 v[i].push_back(thi(c[i]*x,w[i]*x));
125 x<<=1;
126 }
127 if(d>0) v[i].push_back(thi(c[i]*d,w[i]*d));
128 }
129 for (R i=1;i<n;++i)
130 {
131 x=read(),y=read();
132 add(x,y); add(y,x);
133 }
134 rt=0;
135 S=maxs[rt]=n;
136 get_root(1,0);
137 ans=0;
138 solve(rt);
139 printf("%d\n",ans);
140 }
141
142 int main()
143 {
144 scanf("%d",&T);
145 while(T--)
146 t4182();
147 return 0;
148 }
---shzr
来源:https://www.cnblogs.com/shzr/p/11475359.html
