题目链接
https://atcoder.jp/contests/agc035/tasks/agc035_d
题解
想了两小时憋出来一个状压DP,发现人家怎么空间才十几MB,原来暴力就行了。。。
考虑原序列那个操作,我们可以建一个图,一开始有\(n\)个点没有边,每次选一个点向其左右未被选的点加两条边,一个点的贡献次数就是它到左右两个终点的路径数。
那么我们可以枚举最后选的点,把原序列分裂成两个区间,因此使用区间DP. 现在的问题是我们需要方便地统计一个点对总和的贡献。我们观察到,从\([l,r]\)这一层到\([1,n]\)最底层,每次要么拓展左边(左端点\(l\)连向新的\(l'\),\(r\)连向\(l'\)),要么拓展右边,这样总共的过程可以表达为一个长度不超过\((n-1)\)的01
串,且从\(l\)或\(r\)到达\(1\)或\(n\)的方案数可由这个串得到(具体地,维护两个变量\(x=1,y=1\), 拓展左边时\((x,y)\rightarrow (x+y,y)\), 拓展右边时\((x,y)\rightarrow (x,x+y)\))。那么预处理每个01
串对应的系数,就可以快速计算了。最终的DP状态是,\(f[l][r][k][S]\)表示区间\([l,r]\),01
串长度为\(k\),串本身为\(S\).
总共状态数是\(\sum^n_{i=1}2^i(n-i+1)=O(2^nn)\)的,转移需要\(O(n)\)的复杂度(其实远远不满),总时间复杂度\(O(2^nn^2)\).
但是如果我们设\(f[l][r][k][S]\)记忆化的话空间复杂度就变成\(O(2^nn^3)\)了,怎么办?智障的我就把后两维压到一起了,卡着空间限制过了……
然而事实是重复遍历一个状态的情况很少,不记忆化不仅能过,而且速度还快一倍。(如果不记忆化的话不需要预处理数组,把\(S\)串直接改成左右端点分别被算几次就行了)
实测效果如下:
(上:不记忆化 下:记忆化)枯了
代码
记忆化:
#include<bits/stdc++.h> #define llong long long #define mkpr make_pair #define riterator reverse_iterator #define U ((1<<n)-1) using namespace std; inline int read() { int x = 0,f = 1; char ch = getchar(); for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;} for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;} return x*f; } const int N = 18; const llong INF = 1e17; llong f[N+1][N+1][(1<<N)+3]; llong val[(1<<N)+3]; llong a[N+3]; int n; void updmin(llong &x,llong y) {x = min(x,y);} llong dfs(int l,int r,int sta) { if(f[l][r][sta]<INF) {return f[l][r][sta];} if(r-l<=1) return f[l][r][sta]=0ll; for(int i=l+1; i<r; i++) { updmin(f[l][r][sta],dfs(l,i,((sta<<1)|1)&U)+dfs(i,r,(sta<<1)&U)+a[i]*val[sta]); } return f[l][r][sta]; } int main() { scanf("%d",&n); for(int i=0; i<(1<<n)-1; i++) { int len = n; while(i&(1<<len-1)) {len--;} len--; llong x = 1ll,y = 1ll; for(int j=0; j<len; j++) i&(1<<j)?y+=x:x+=y; val[i] = x+y; } for(int i=0; i<n; i++) scanf("%lld",&a[i]); memset(f,10,sizeof(f)); printf("%lld\n",dfs(0,n-1,(1<<n)-2)+a[0]+a[n-1]); return 0; }
不记忆化
#include<bits/stdc++.h> #define llong long long #define mkpr make_pair #define riterator reverse_iterator #define U ((1<<n)-1) using namespace std; inline int read() { int x = 0,f = 1; char ch = getchar(); for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;} for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;} return x*f; } const int N = 18; const llong INF = 1e17; llong a[N+3]; int n; void updmin(llong &x,llong y) {x = min(x,y);} llong dfs(int l,int r,llong x,llong y) { if(r-l<=1) return 0ll; llong ret = INF; for(int i=l+1; i<r; i++) { updmin(ret,dfs(l,i,x,x+y)+dfs(i,r,x+y,y)+a[i]*(x+y)); } return ret; } int main() { scanf("%d",&n); for(int i=0; i<n; i++) scanf("%lld",&a[i]); printf("%lld\n",dfs(0,n-1,1ll,1ll)+a[0]+a[n-1]); return 0; }
来源:https://www.cnblogs.com/suncongbo/p/12297613.html