题目链接:HDU - 2896
题意:给你n个模式串,对应每个模式串由编号,给出m个文本串,然后你要输出对应所匹配出模式串序号,以及有多少个文本串中有模式串
思路:比起比较基本统计个数,这里我们可以用set或者map来统计模式串序号
这里总结一下新学习的last优化。 参考博客:(1)
last相当于一个超级fail指针:
因为我们只有到根节点时才会重新匹配一个字母,所以我们此时直接记录一个last ,直接结束当前匹配过程.直接省去原 Fail 指针到可以匹配的节点之间的距离.
同时结合路径压缩,在匹配时可以完全不使用原 Fail.可以看下参考博客里面的图片。
code: 详细注解
#include<bits/stdc++.h>
const int N=100000+5;
using namespace std;
struct AC_automaton{
int trie[N][128];//字典树
int val[N];//字符串结尾标记
int fail[N];//失配指针
int last[N];//last[i]=j表j节点表示的单词是i节点单词的后缀,且j节点是单词节点(last优化)
int tot;//编号
void init(){//初始化0号点
tot=1;
val[0]=fail[0]=last[0]=0;
memset(trie[0],0,sizeof(trie[0]));
}
void insert(char *s,int v){//构造trie与val数组,v需非0,表示一个单词节点
int len=strlen(s);
int root=0;
for(int i=0;i<len;i++){
int id=s[i];
if(trie[root][id]==0){
trie[root][id]=tot;
memset(trie[tot],0,sizeof(trie[tot]));
val[tot++]=0;
}
root=trie[root][id];
}
val[root]=v; //编号
//val[root]++; 个数
}
void build(){//构造fail与last
queue<int> q;
last[0]=fail[0]=0;
//先把第0个部分放进去
for(int i=0;i<128;i++){
int root=trie[0][i];
if(root!=0){
//初始化
fail[root]=0;
last[root]=0;
q.push(root);
}
}
while(!q.empty()){//bfs求fail
int k=q.front();
q.pop();
//ASCII编码范围
for(int i=0;i<128;i++){
int u=trie[k][i];//被取出结点k的子结点
if(u==0)
continue;
q.push(u);
int v=fail[k];//k位置的失配指针
//把子节点改成fail节点的子节点形成一个Trie图
while(v && trie[v][i]==0)
v=fail[v];
fail[u]=trie[v][i];//得到其儿子的失配结点
//last指针表示“在它顶上的fail边所指向的一串节点中,第一个真正的结束节点”
last[u]=val[fail[u]]?fail[u]:last[fail[u]];
}
}
}
void print(int i,set<int> &st){//递归找到存在结点i后缀相同的前缀节点编号
if(val[i]){
if( st.find(i)==st.end() )
st.insert(val[i]);
print(last[i],st);
}
}
void query(char *s,set<int> &st){//匹配
int len=strlen(s);
int j=0;
for(int i=0;i<len;i++){
int id=s[i];
while(j && trie[j][id]==0)
j=fail[j];
j=trie[j][id];
if(val[j])
print(j,st);
else if(last[j])
print(last[j],st);
}
}
}ac;
char P[N];
char T[N];
set<int>::iterator it;
int main(){
int n;
scanf("%d",&n);
ac.init();
for(int i=1;i<=n;i++){
scanf("%s",P);
ac.insert(P,i);
}
ac.build();
int m;
scanf("%d",&m);
int total=0;
for(int i=1;i<=m;i++){
scanf("%s",&T);
set<int> st;//保存文本串已经匹配到的模式串编号
ac.query(T,st);
if(!st.empty()){
total++;
printf("web %d:",i);
for(set<int>::iterator it=st.begin();it!=st.end();it++)
printf(" %d",(*it));
printf("%\n");
}
}
printf("total: %d\n",total);
return 0;
}