应该算是远古时期的一道题了吧,不过感觉挺经典的。
题意是给出三一个字符串s,a,b,求以a开头b结尾的本质不同的字符串数。
由于n不算大,用hash就可以搞,不过这道题是存在复杂度$O(nlogn)$的做法的。
由于要求本质不同,所以可以考虑使用后缀数组来不重复地枚举字符串。
首先用两个不同的其他字符将s,a,b拼起来求后缀数组,这样就可以知道任意两个后缀的lcp了。然后将s中所有b出现的末尾位置置1,求个后缀和suf。将s中所有后缀按名次从小到大存到一个vector里。对于s中的每个后缀,设其名次为x,a的长度为la,b的长度为lb,若$lcp(x,rnk[ia])=la$,则其对答案的贡献为$suf[sa[i]+max(lcp(x,vec[i-1]),la-1,lb-1)]$。其中la-1和lb-1是为了保证字符串长度比a和b都大,lcp(x,vec[i-1])是为了保证不重复枚举,相当于没有将s与a,b拼起来时的height[x]。特别地,当x=0时为0。
1 #include<bits/stdc++.h>
2 using namespace std;
3 typedef long long ll;
4 const int N=1e4+10;
5 char buf[N];
6 int s[N],sa[N],buf1[N],buf2[N],c[N],n,rnk[N],ht[N],ST[N][20],Log[N],ia,ib,m,la,lb,suf[N];
7 void Sort(int* x,int* y,int m) {
8 for(int i=0; i<m; ++i)c[i]=0;
9 for(int i=0; i<n; ++i)++c[x[i]];
10 for(int i=1; i<m; ++i)c[i]+=c[i-1];
11 for(int i=n-1; i>=0; --i)sa[--c[x[y[i]]]]=y[i];
12 }
13 void da(int* s,int n,int m=1000) {
14 int *x=buf1,*y=buf2;
15 s[n]=x[n]=y[n]=-1;
16 for(int i=0; i<n; ++i)x[i]=s[i],y[i]=i;
17 Sort(x,y,m);
18 for(int k=1; k<n; k<<=1) {
19 int p=0;
20 for(int i=n-k; i<n; ++i)y[p++]=i;
21 for(int i=0; i<n; ++i)if(sa[i]>=k)y[p++]=sa[i]-k;
22 Sort(x,y,m),p=1,y[sa[0]]=0;
23 for(int i=1; i<n; ++i)y[sa[i]]=x[sa[i-1]]==x[sa[i]]&&x[sa[i-1]+k]==x[sa[i]+k]?p-1:p++;
24 if(p==n)break;
25 swap(x,y),m=p;
26 }
27 }
28 void getht() {
29 for(int i=0; i<n; ++i)rnk[sa[i]]=i;
30 ht[0]=0;
31 for(int i=0,k=0; i<n; ++i) {
32 if(k)--k;
33 if(!rnk[i])continue;
34 for(; s[i+k]==s[sa[rnk[i]-1]+k]; ++k);
35 ht[rnk[i]]=k;
36 }
37 }
38 void initST() {
39 for(int i=1; i<n; ++i)ST[i][0]=ht[i];
40 for(int j=1; (1<<j)<=n; ++j)
41 for(int i=1; i+(1<<j)-1<n; ++i)
42 ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1]);
43 }
44 int lcp(int l,int r) {
45 if(l==r)return n-sa[l];
46 if(l>r)swap(l,r);
47 l++;
48 int k=Log[r-l+1];
49 return min(ST[l][k],ST[r-(1<<k)+1][k]);
50 }
51 vector<int> vec;
52 int main() {
53 Log[0]=-1;
54 for(int i=1; i<N; ++i)Log[i]=Log[i>>1]+1;
55 scanf("%s",buf),m=strlen(buf);
56 for(int i=0; i<m; ++i)s[n++]=buf[i];
57 s[n++]='z'+1,ia=n;
58 scanf("%s",buf),m=strlen(buf),la=m;
59 for(int i=0; i<m; ++i)s[n++]=buf[i];
60 s[n++]='z'+2,ib=n;
61 scanf("%s",buf),m=strlen(buf),lb=m;
62 for(int i=0; i<m; ++i)s[n++]=buf[i];
63 s[n]=0;
64 da(s,n),getht(),initST();
65 for(int i=0; i<ia-1; ++i)if(lcp(rnk[i],rnk[ib])==lb)suf[i+lb-1]=1;
66 for(int i=ia-3; i>=0; --i)suf[i]+=suf[i+1];
67 for(int i=0; i<ia-1; ++i)vec.push_back(rnk[i]);
68 sort(vec.begin(),vec.end());
69 int ans=0;
70 for(int i=0; i<vec.size(); ++i) {
71 int x=vec[i];
72 if(lcp(x,rnk[ia])==la)ans+=suf[sa[x]+max(i?lcp(x,vec[i-1]):0,max(la-1,lb-1))];
73 }
74 printf("%d\n",ans);
75 return 0;
76 }