给你长度为N的一个序列,让你将其分成连续的k段,每段的价值为其中数字种类的个数,求最大价值总和。
首先能想到n^2复杂度的dp
设定dp[i][j]表示到位子i,分成j段的最大价值总和。
dp[i][j]=max( dp[i][j],dp[k][j-1]+val(k+1,i) );k为这个数上一次出现的位置
可以用线段树加速转移。
考虑val(k+1,j).
我们遍历到第j个位子的时候,我们显然树上第k个位子表示的是dp[k][j-1]+val(k+1,i),那么考虑第i个数,它会对区间
(pre[a[i]] ,i-1)区间内的树上的位子有所影响。
那么我们遍历到第i个位子的时候,将树上区间(pre[a[i]],i)的值都+1。
这里pre[a[i]]表示的是a[i]这个数上一次出现的位子。
#include<bits/stdc++.h> using namespace std; const int maxn = 35555; int a[maxn<<2],n,k; int delt[maxn<<2],pre[maxn],pos[maxn]; int dp[maxn][52]; void pushup(int rt) { a[rt]=max(a[rt*2],a[rt*2+1]); } void pushdown(int rt) { delt[rt<<1]+=delt[rt]; delt[rt<<1|1]+=delt[rt]; a[rt<<1]+=delt[rt]; a[rt<<1|1]+=delt[rt]; delt[rt]=0; } void update(int rt,int x,int y,int l, int r, int val) { if (x<=l&&r<=y) { a[rt]+=val; delt[rt]+=val; return; } pushdown(rt); int mid=(l+r)>>1; if (x<=mid) update(rt<<1,x,y,l,mid,val); if(y>mid) update(rt<<1|1,x,y,mid+1,r,val); pushup(rt); } int query(int rt,int x,int y,int l,int r) { if (x<=l&&r<=y) return a[rt]; pushdown(rt); int mid=(l+r)>>1; int ans=0; if(x<=mid) ans=query(rt<<1,x,y,l,mid); if(y>mid) ans=max(ans,query(rt<<1|1,x,y,mid+1,r)); return ans; } int main() { while(~scanf("%d%d",&n,&k)) { memset(pos,0,sizeof(pos)); for(int i=1;i<=n;i++) { int x; scanf("%d",&x); pre[i]=pos[x]; pos[x]=i; } memset(dp,0,sizeof(dp)); for(int i=1;i<=n;i++) { dp[i][1]=dp[i-1][1]; if(!pre[i]) dp[i][1]++; } for(int j=2;j<=k;j++) { memset(a,0,sizeof(a)); memset(delt,0,sizeof(delt)); for(int i=1;i<=n;i++) update(1,i,i,1,n,dp[i][j-1]); for(int i=j;i<=n;i++) { update(1,pre[i],i-1,1,n,1); dp[i][j]=query(1,j-1,i-1,1,n); } } printf("%d\n",dp[n][k]); } return 0; }