si+sj中间有一个切割点,我们在t上枚举这个切割点i,即以t[i]作为最后一个字符时求有多少si可以匹配,以t[i+1]作为第一个字符时有多少sj可以匹配
那么对s串正着建一个ac自动机,反着建一个自动机,然后t正反各匹配一次,用sum[]数组记录t[i]作为最后一个字符可以匹配的串数量
注意:求sum数组时,暴力跳fail显然会t,考虑到跳fail是为了统计匹配串的后缀,那么我们在build时,就可以在处理fail指针时就可以把那个fail的end加到now的end上去,这样就避免了暴力跳fail
#include<bits/stdc++.h>
using namespace std;
#define N 200005
struct Trie{
int nxt[N][26],fail[N],end[N];
int root,L;
int newnode(){
memset(nxt[L],-1,sizeof nxt[L]);
end[L]=0;
return L++;
}
void init(){
L++;
root=newnode();
}
void insert(char buf[]){
int len=strlen(buf+1);
int now=root;
for(int i=1;i<=len;i++){
if(nxt[now][buf[i]-'a']==-1)
nxt[now][buf[i]-'a']=newnode();
now=nxt[now][buf[i]-'a'];
}
end[now]++;
}
void build(){
queue<int>q;
fail[root]=root;
for(int i=0;i<26;i++)
if(nxt[root][i]==-1)
nxt[root][i]=root;
else {
fail[nxt[root][i]]=root;
q.push(nxt[root][i]);
}
while(q.size()){
int now=q.front();
q.pop();
for(int i=0;i<26;i++)
if(nxt[now][i]==-1)
nxt[now][i]=nxt[fail[now]][i];
else {
fail[nxt[now][i]]=nxt[fail[now]][i];
end[nxt[now][i]]+=end[nxt[fail[now]][i]];
q.push(nxt[now][i]);
}
}
}
int sum[N];
int query(char buf[]){
int len=strlen(buf+1);
int now=root;
for(int i=1;i<=len;i++){
now=nxt[now][buf[i]-'a'];
sum[i]+=end[now];
}
}
};
char buf[N],t[N];
Trie t1,t2;
int n;
void reserve(char s[]){
int i=1,j=strlen(s+1);
while(i<j){
swap(s[i],s[j]);
++i,--j;
}
}
int main(){
t1.init();
t2.init();
scanf("%s%d",t+1,&n);
for(int i=1;i<=n;i++){
scanf("%s",buf+1);
t1.insert(buf);
reserve(buf);
t2.insert(buf);
}
t1.build();
t2.build();
t1.query(t);
reserve(t);
t2.query(t);
int len=strlen(t+1);
long long ans=0;
for(int i=0;i<len;i++)
ans+=(long long)t1.sum[i]*t2.sum[len-i];
cout<<ans<<'\n';
}