树形DP有一个独特的优化,就是通过递归,枚举目前有效的元素个数,求dp[ i ][ j ] (表示 选取以i为根的子树中有选取j个元素的最大取值)
(搭配 siz 数组表示当前该节点的总共子孙数)
1.hdu1561(树形依赖背包裸题)
注意 siz 数组的运用,以及 u 点选择的节点数时要逆向枚举,就像01背包
复杂度看似O(n^3),实际是 O( n^2 ) 左右。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 250;
vector<int> g[maxn];
int dp[maxn][maxn];
int val[maxn];
int siz[maxn];
int n,m;
//dp[i][j] 表示 选取以i为根的子树中有选取j个元素的最大取值
void dfs(int u){
siz[u]=1;
dp[u][1] = val[u];
for(int i=0; i<g[u].size(); i++){
int v = g[u][i];
dfs(v); //这里的siz[u]不包括siz[v] ,并且是把效率很低的2^n举法用01背包来做
for(int i=siz[u]; i>=1; i--){ //这里就像01背包里,避免由这个点的情况递推这个点的更佳情况
for(int j=1; j<=siz[v]&&i+j<=m; j++){ //就比如要避免刚刚还说是从v取3个点推出的最优
dp[u][i+j] = max(dp[u][i+j], dp[u][i]+dp[v][j]); //后面又从前面的dp值而只从j中取1个点得出错误的更优解
}
}
siz[u] += siz[v];
}
}
int main(){
while(scanf("%d%d",&n,&m)!=EOF){
if(n==0&&m==0) break;
for(int i=0; i<=n; i++){
for(int j=0; j<=n; j++){
dp[i][j] = 0;
}
g[i].clear();
}
int t;
for(int i=1; i<=n; i++){
scanf("%d%d",&t,val+i);
g[t].push_back(i);
}
m++;
dfs(0);
printf("%d\n", dp[0][m]);
}
}
2.codeforces 815C (树形dp)
这个选取树上物品可以不需要有父子关系的,但使用优惠券和父子关系有关,所以可以把 dp数组多增加一维,表示是否能够使用优惠券。
只需要设置默认值为 inf ,再这样初始化:
dp[u][0][0]=0;
dp[u][1][0]=c[u];
dp[u][1][1]=c[u]-d[u];
就可以在枚举时考虑到 0 这个元素。

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
const int maxn = 5005;
int dp[maxn][maxn][2]; //dp[i][j]表示以i为根的子树中取j个元素的最大值
vector<int> g[maxn]; //再来一维表示是否购买根节点i这个元素,也就是用不用优惠券
int val[maxn],d[maxn];
int siz[maxn];
void dfs(int u){
siz[u] = 1;
dp[u][1][1] = val[u]-d[u];
dp[u][1][0] = val[u];
dp[u][0][0] = 0;
for(int i=0; i<g[u].size(); i++){
int v=g[u][i];
dfs(v); //这里的siz[u]不包括siz[v]
for(int i=siz[u]; i>=0; i--){ //这里的0是为了处理可以不取
for(int j=0; j<=siz[v]; j++){
dp[u][i+j][0] = min(dp[u][i+j][0], dp[u][i][0]+dp[v][j][0]);
dp[u][i+j][1] = min(dp[u][i+j][1], dp[u][i][1]+min(dp[v][j][0],dp[v][j][1]));
}
}
siz[u] += siz[v];
}
}
int main(){
int n,b;
scanf("%d%d",&n,&b);
scanf("%d%d",val+1,d+1);
for(int i=2; i<=n; i++){
int t;
scanf("%d%d%d",val+i,d+i,&t);
g[t].push_back(i);
}
memset(dp,0x3f,sizeof(dp));
dfs(1);
int ans=n;
while(dp[1][ans][1]>b&&dp[1][ans][0]>b){
ans--;
//printf("%d %d\n",dp[1][ans][1],dp[1][ans][0] );
}
printf("%d\n", ans);
}
