题意:给你n个字符串,求出在超过一半的字符串中出现的所有子串中最长的子串,按字典序输出。
这道题算是我的一个黑历史了吧,以前我的做法是对这n个字符串建广义后缀自动机,然后在自动机上dfs,交上去AC了,然而事后发现算法假了,出了个数据把自己给hack了...
之前写的太烂了,决定重写一遍。
正确的操作是对n个串倒序建广义后缀自动机,建好以后把每个串放到自动机上跑一遍,把所有覆盖到的状态结点打上标记(每个串只标记一次,用vis判重),记录每个状态在多少个串中出现过,然后在后缀树(fail树)上按字典序dfs一遍就好了。
注意每添加一个字符串,需要把last指向根节点,而且在每次往后添加结点的时候判断当前结点是否存在过,如果存在则需要特殊处理(源自洛谷zcysky大神的思路)
复杂度$O(n\sqrt n)$,但上界很松,跑起来速度还是很快滴~
1 #include<bits/stdc++.h>
2 using namespace std;
3 typedef long long ll;
4 const int N=2e5+10,M=26;
5 int n,ka;
6 char s[105][1010];
7 struct SAM {
8 int fa[N],go[N][M],mxl[N],last,tot,ch[N][M],pos[N],cc[N],nc,vis[N],cnt[N],mx;
9 int newnode(int l,int p) {
10 int u=++tot;
11 mxl[u]=l,pos[u]=p,cnt[u]=0;
12 memset(go[u],0,sizeof go[u]);
13 memset(ch[u],0,sizeof ch[u]);
14 return u;
15 }
16 void init() {tot=nc=0,last=newnode(0,-1);}
17 void add(int ch) {
18 cc[++nc]=ch;
19 int p=last;
20 if(go[p][ch]) {
21 int q=go[p][ch];
22 if(mxl[q]==mxl[p]+1)last=q;
23 else {
24 int nq=newnode(mxl[p]+1,pos[q]);
25 memcpy(go[nq],go[q],sizeof go[q]);
26 fa[nq]=fa[q],fa[q]=nq;
27 for(; p&&go[p][ch]==q; p=fa[p])go[p][ch]=nq;
28 last=nq;
29 }
30 } else {
31 int np=last=newnode(mxl[p]+1,nc);
32 for(; p&&!go[p][ch]; p=fa[p])go[p][ch]=np;
33 if(!p)fa[np]=1;
34 else {
35 int q=go[p][ch];
36 if(mxl[q]==mxl[p]+1)fa[np]=q;
37 else {
38 int nq=newnode(mxl[p]+1,pos[q]);
39 memcpy(go[nq],go[q],sizeof go[q]);
40 fa[nq]=fa[q],fa[q]=fa[np]=nq;
41 for(; p&&go[p][ch]==q; p=fa[p])go[p][ch]=nq;
42 }
43 }
44 }
45 }
46 void dfs(int u) {
47 if(mxl[u]==mx&&cnt[u]>n/2) {
48 for(int i=pos[u]; i>pos[u]-mxl[u]; --i)printf("%c",cc[i]+'a');
49 puts("");
50 }
51 for(int i=0; i<M; ++i)if(ch[u][i])dfs(ch[u][i]);
52 }
53 void run() {
54 for(int i=0; i<n; ++i) {
55 last=1;
56 int l=strlen(s[i]);
57 reverse(s[i],s[i]+l);
58 for(int j=0; j<l; ++j)add(s[i][j]-'a');
59 }
60 memset(vis,-1,sizeof vis);
61 for(int i=0; i<n; ++i)
62 for(int j=0,u=1; s[i][j]; u=go[u][s[i][j]-'a'],++j)
63 for(int v=go[u][s[i][j]-'a']; v!=1&&vis[v]!=i; v=fa[v])vis[v]=i,++cnt[v];
64 mx=-1;
65 for(int i=2; i<=tot; ++i)if(cnt[i]>n/2)mx=max(mx,mxl[i]);
66 for(int i=2; i<=tot; ++i)ch[fa[i]][cc[pos[i]-mxl[fa[i]]]]=i;
67 if(!~mx)puts("?");
68 else {
69 memset(vis,0,sizeof vis);
70 dfs(1);
71 }
72 }
73 } sam;
74 int main() {
75 while(scanf("%d",&n),n) {
76 ka?puts(""):++ka;
77 sam.init();
78 for(int i=0; i<n; ++i)scanf("%s",s[i]);
79 sam.run();
80 }
81 return 0;
82 }
还有一种做法是利用后缀数组。
把这n个串用不同的字符连接在一起求后缀数组,并给每个后缀i赋一个值a[i]表示它是哪个字符串里的。然后对排好序的后缀进行尺取并维护区间不同值的个数,一旦区间不同值的个数>n/2,就输出长度为左右端点lcp的字符串。(需要尺取两次,第一次求出最大长度,第二次输出)
但是这样做可能会有重复的串被输出,怎么去重呢?用哈希固然可以,可有没有优雅一点的做法呢?当然。只要每次输出的时候记录一下当前子串的左端点la,下次准备输出的时候和la求一次lcp,如果lcp=最大长度的话,就跳过。
复杂度$O(nlogn+n)$
1 #include<bits/stdc++.h>
2 using namespace std;
3 typedef long long ll;
4 const int N=1e5+1000,mod=998244353;
5 char buf[N];
6 int s[N],sa[N],buf1[N],buf2[N],c[N],n,m,k,rnk[N],ht[N],ST[N][20],Log[N],a[N],cnt,ka;
7 void Sort(int* x,int* y,int m) {
8 for(int i=0; i<m; ++i)c[i]=0;
9 for(int i=0; i<n; ++i)++c[x[i]];
10 for(int i=1; i<m; ++i)c[i]+=c[i-1];
11 for(int i=n-1; i>=0; --i)sa[--c[x[y[i]]]]=y[i];
12 }
13 void da(int* s,int n,int m=1000) {
14 int *x=buf1,*y=buf2;
15 x[n]=y[n]=-1;
16 for(int i=0; i<n; ++i)x[i]=s[i],y[i]=i;
17 Sort(x,y,m);
18 for(int k=1; k<n; k<<=1) {
19 int p=0;
20 for(int i=n-k; i<n; ++i)y[p++]=i;
21 for(int i=0; i<n; ++i)if(sa[i]>=k)y[p++]=sa[i]-k;
22 Sort(x,y,m),p=1,y[sa[0]]=0;
23 for(int i=1; i<n; ++i)y[sa[i]]=x[sa[i-1]]==x[sa[i]]&&x[sa[i-1]+k]==x[sa[i]+k]?p-1:p++;
24 if(p==n)break;
25 swap(x,y),m=p;
26 }
27 }
28 void getht() {
29 for(int i=0; i<n; ++i)rnk[sa[i]]=i;
30 ht[0]=0;
31 for(int i=0,k=0; i<n; ++i) {
32 if(k)--k;
33 if(!rnk[i])continue;
34 for(; s[i+k]==s[sa[rnk[i]-1]+k]; ++k);
35 ht[rnk[i]]=k;
36 }
37 }
38 void initST() {
39 for(int i=1; i<n; ++i)ST[i][0]=ht[i];
40 for(int j=1; (1<<j)<=n; ++j)
41 for(int i=1; i+(1<<j)-1<n; ++i)
42 ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1]);
43 }
44 int lcp(int l,int r) {
45 if(l==r)return n-sa[l];
46 if(l>r)swap(l,r);
47 l++;
48 int k=Log[r-l+1];
49 return min(ST[l][k],ST[r-(1<<k)+1][k]);
50 }
51 void add(int x,int f) {
52 if(!x)return;
53 if(!c[x])++cnt;
54 if(!(c[x]-=f))--cnt;
55 }
56 int main() {
57 Log[0]=-1;
58 for(int i=1; i<N; ++i)Log[i]=Log[i>>1]+1;
59 while(scanf("%d",&m),m) {
60 if(ka++)puts("");
61 memset(a,0,sizeof a);
62 n=0;
63 for(int i=0; i<m; ++i) {
64 if(i)s[n++]='z'+i;
65 scanf("%s",buf),k=strlen(buf);
66 for(int j=0; j<k; ++j)a[n]=i+1,s[n++]=buf[j];
67 }
68 s[n]=0;
69 da(s,n),getht(),initST();
70 memset(c,0,sizeof c);
71 cnt=0;
72 int mx=0;
73 for(int i=0,j=0; i<n; ++i) {
74 if(!a[sa[i]])break;
75 for(; j<n&&cnt<=m/2; ++j)add(a[sa[j]],1);
76 add(a[sa[i]],-1);
77 mx=max(mx,lcp(i,j-1));
78 }
79 if(!mx)puts("?");
80 else {
81 for(int i=0,j=0,k,la=-1; i<n; ++i) {
82 if(!a[sa[i]])break;
83 for(; j<n&&cnt<=m/2; ++j)add(a[sa[j]],1);
84 if(lcp(i,j-1)==mx) {
85 if(!~la||lcp(la,j-1)!=mx) {
86 for(k=0; k<lcp(i,j-1); ++k)printf("%c",s[sa[i]+k]);
87 puts("");
88 }
89 la=i;
90 }
91 add(a[sa[i]],-1);
92 }
93 }
94 }
95 return 0;
96 }