[SDOI2016] 生成魔咒
Description
初态串为空,每次在末尾追加一个字符,动态维护本质不同的子串数。
Solution
考虑时间倒流,并将串反转,则变为每次从开头删掉一个字符,即每次从后缀集合中删掉一个后缀。
预处理出后缀数组和高度数组后,用平衡树维护所有后缀集合(按照后缀排序),要删除一个后缀 \(S[sa[p],n]\) 时,找到它在平衡树上的前驱 \(u\) 和后继 \(v\) ,如果都存在,那么这一步的贡献就是
\[(n-sa[p]+1) - Max(h[p],h[v])\]
约定 \(h[p]\) 表示 \(S[sa[p],n]\) 与 \(S[sa[p-1],n]\) 的 LCP 长度。
如果 \(u\) 或 \(v\) 不存在,则当作 \(LCP\) 为零处理,仍然成立。
求 \(LCP\) 可以暴力用 ST 表维护。但考虑到这里每次删除操作最多只会再影响一个元素,我们可以顺便记录一下,即当我们删除 \(p\) 的时候令 \(h[v] = Min(h[p],h[v])\) 即可。
#include <bits/stdc++.h> using namespace std; #define int long long int n,m,sa[1000005],y[1000005],u[1000005],v[1000005],o[1000005],r[1000005],h[1000005],T; int str[1000005]; long long ans; map <int,int> mp; set <int> s; vector <int> an; signed main() { scanf("%lld",&n); m=n; for(int i=1; i<=n; i++) scanf("%lld",&str[i]); reverse(str+1,str+n+1); for(int i=1; i<=n; i++) mp[str[i]]++; int ind=0; for(map<int,int>::iterator it=mp.begin(); it!=mp.end(); ++it) it->second=++ind; for(int i=1; i<=n; i++) str[i]=mp[str[i]]; for(int i=1; i<=n; i++) u[str[i]]++; for(int i=1; i<=m; i++) u[i]+=u[i-1]; for(int i=n; i>=1; i--) sa[u[str[i]]--]=i; r[sa[1]]=1; for(int i=2; i<=n; i++) r[sa[i]]=r[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]); for(int l=1; r[sa[n]]<n; l<<=1) { memset(u,0,sizeof u); memset(v,0,sizeof v); memcpy(o,r,sizeof r); for(int i=1; i<=n; i++) u[r[i]]++, v[r[i+l]]++; for(int i=1; i<=n; i++) u[i]+=u[i-1], v[i]+=v[i-1]; for(int i=n; i>=1; i--) y[v[r[i+l]]--]=i; for(int i=n; i>=1; i--) sa[u[r[y[i]]]--]=y[i]; r[sa[1]]=1; for(int i=2; i<=n; i++) r[sa[i]]=r[sa[i-1]]+((o[sa[i]]!=o[sa[i-1]])||(o[sa[i]+l]!=o[sa[i-1]+l])); } { int i,j,k=0; for(int i=1; i<=n; h[r[i++]]=k) for(k?k--:0,j=sa[r[i]-1]; str[i+k]==str[j+k]; k++); } ans=(long long)n*(long long)(n+1)/(long long)2; for(int i=1; i<=n; i++) ans-=(long long)h[i]; an.push_back(ans); for(int i=1; i<=n; i++) s.insert(i); for(int i=1; i<=n; i++) { int p=r[i],u=0,v=0; set<int>::iterator it,it1,it2; it=s.find(p); it1=it; it2=it; if(it1!=s.begin()) { --it1; u=*it1; } if(it2!=s.end()) { ++it2; if(it2!=s.end()) v=*it2; } int tmp=max(h[p],h[v]); ans -= n-i+1 - tmp; h[v]=min(h[p],h[v]); s.erase(it); an.push_back(ans); } for(int i=n-1; i>=0; --i) printf("%lld\n",an[i]); }