题目大意是问在$S$串中找区间$[i,j]$,在$T$串中找位置$k$,使得$S[i,j]$和$T[1,k]$可以组成回文串,并且$j-i+1>k$,求这样的三元组$(i,j,k)$的个数。
一开始有点懵,但是仔细一想,因为$j-i+1>k$,所以$S[i,j]$中一定包含了回文串后半段的一部分,即$S[i,j]$中一定有后缀是回文串。
如果回文串是$S[x,j]$,则剩余的$S[i,x-1]$与$T[1,k]$应该也能组成回文串。如果将串$S$倒置,则串$S^{'}$上的原$S[i,x-1]$位置与$T[1,k]$应该相同。
所以解题方式应该比较明了,将串$S$倒置,然后求扩展$kmp$,得到串$S^{'}$每个后缀与串$T$的最长公共前缀。然后对串$S^{'}$构建回文自动机。
可以得到串$S^{'}$每个位置作为回文子串的结尾时的回文串个数。然后枚举串$S^{'}$每个位置$i$,以当前位置作为上文中的$x$,然后计算当前位置对答案的贡献。
1 #include<bits/stdc++.h>
2 using namespace std;
3 typedef long long ll;
4 const int maxn = 1e6 + 100;
5 int Next[maxn];
6 int Ex[maxn];
7 void getN(char* s1) {//求子串与自身匹配
8 int i = 0, j, p, len = strlen(s1);
9 Next[0] = len;
10 while (i + 1 < len && s1[i] == s1[i + 1])
11 i++;
12 Next[1] = i;
13 p = 1;
14 for (i = 2; i < len; i++) {
15 if (Next[i - p] + i < Next[p] + p)
16 Next[i] = Next[i - p];
17 else {
18 j = Next[p] + p - i;
19 if (j < 0)
20 j = 0;
21 while (i + j < len && s1[j] == s1[i + j])
22 j++;
23 Next[i] = j;
24 p = i;
25 }
26 }
27 }
28 void getE(char* s1, char* s2) {//求子串与主串匹配
29 int i = 0, j, p, len1 = strlen(s1), len2 = strlen(s2);
30 while (i < len1 && i < len2 && s1[i] == s2[i])
31 i++;
32 Ex[0] = i;
33 p = 0;
34 for (i = 1; i < len1; i++) {
35 if (Next[i - p] + i < Ex[p] + p)
36 Ex[i] = Next[i - p];
37 else {
38 j = Ex[p] + p - i;
39 if (j < 0)
40 j = 0;
41 while (i + j < len1 && j < len2 && s1[i + j] == s2[j])
42 j++;
43 Ex[i] = j;
44 p = i;
45 }
46 }
47 }
48 struct Palindromic_Tree {
49 int next[maxn][26];//指向的串为当前串两端加上同一个字符构成
50 int fail[maxn];//fail指针,失配后跳转到fail指针指向的节点
51 int cnt[maxn]; //表示节点i表示的本质不同的串的个数,最后用count统计
52 int num[maxn]; //表示节点i表示的最长回文串的最右端点为回文串结尾的回文串个数
53 int len[maxn];//len[i]表示节点i表示的回文串的长度
54 int id[maxn];//表示数组下标i在自动机的哪个位置
55 int S[maxn];
56 int last;//指向上一个字符所在的节点,方便下一次add
57 int n; int p;
58 int newnode(int x) {
59 for (int i = 0; i < 26; ++i) next[p][i] = 0;
60 cnt[p] = 0; num[p] = 0; len[p] = x;
61 return p++;
62 }
63 void init() {//初始化
64 p = 0;
65 newnode(0); newnode(-1);
66 last = 0; n = 0;
67 S[n] = -1;
68 fail[0] = 1;
69 }
70 int get_fail(int x) {//失配后找一个最长的
71 while (S[n - len[x] - 1] != S[n]) x = fail[x];
72 return x;
73 }
74 void add(int x) {
75 S[++n] = x;
76 int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
77 if (!next[cur][x]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
78 int now = newnode(len[cur] + 2);//新建节点
79 id[n - 1] = now;
80 fail[now] = next[get_fail(fail[cur])][x];//建立fail指针,以便失配后跳转
81 next[cur][x] = now;
82 num[now] = num[fail[now]] + 1;
83 }
84 else
85 id[n - 1] = next[cur][x];
86 last = next[cur][x];
87 cnt[last]++;
88 }
89 void count() {
90 for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i];
91 }
92
93 }a;
94 char s[maxn], s1[maxn], t[maxn];
95 int main() {
96 scanf("%s%s", s, t);
97 int n = strlen(s), m = strlen(t);
98 for (int i = 0; i < n; i++)
99 s1[i] = s[n - i - 1];
100 getN(t);
101 getE(s1, t);
102 a.init();
103 for (int i = 0; i < n; i++)
104 a.add(s1[i] - 'a');
105 a.count();
106 ll ans = 0;
107 for (int i = n - 1; i >= 0; i--) {
108 int w = Ex[i];
109 ans += 1LL * w * a.num[a.id[i - 1]];
110 }
111 printf("%lld\n", ans);
112 }
来源:oschina
链接:https://my.oschina.net/u/4275872/blog/3379868