CF932G

PAM,回文自动机

题意

给定一个串 SS2S1062\le|S|\le 10^6),把串分为偶数段,假设分为了s1,s2,s3,,sks_1,s_2,s_3,\dots,s_k 求,满足s1=sk,s2=sk1s_1=s_k,s_2=s_{k-1}\dots 的方案数。

题解

n=Sn=|S|,设 dpidp_{i} 表示搞定了 S[1:i]S[1:i] 的方案数,显然有:

dpij<idpj[S[j+1:i]=S[ni+1:nj]]dp_{i}\gets\sum_{j<i} dp_{j}\big[S[j+1:i]=S[n-i+1:n-j]\big]

然而这看着就不可做。

我们重新构造字符串 T=S1SnS2Sn1T=S_1S_nS_{2}S_{n-1}\dots

考虑我们选了 S[i:j]S[i:j] 段就必须满足 S[i,j]=S[nj+1,ni+1]S[i,j]=S[n-j+1,n-i+1] 。等价于在 T[2i1,2j]T[2i-1,2j] 是个回文串。

那么原问题等价转化为求 TT 偶长回文划分方案数。dpidp_i 表示 T[1:i]T[1:i] 偶长回文划分方案数,O(n2)O(n^2) 转移显然:

dpij<idpj[T[j+1:i] 是一个偶回文串]dp_i \gets \sum_{j<i} dp_{j}\big[\texttt{$T[j+1:i]$ 是一个偶回文串}\big]

其实不用判断选的那串是否是偶回文,只要在偶数位置更新 dpdp 值即可。

这类似于「最小回文划分」(回文树 oi-wiki)。利用性质「 ss 的回文后缀长度可以划分成 logs\log|s| 段等差数列」在 PAM 上跳 slink 的次数是 log\log 级别的。

CODE

还是看代码吧。。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1000010;
const int mod = 1e9+7;
template <typename Tp> void Add(Tp &x,Tp y) {x=(x+y)%mod;}

namespace PAM {
const int MAX = 1000010;
int tr[MAX][26],fail[MAX],len[MAX],diff[MAX],slink[MAX];
int sz,tot,lst;
char s[MAX];
int newnode(int l) {
++sz;len[sz]=l;fail[sz]=diff[sz]=slink[sz]=0;
memset(tr[sz],0,sizeof(tr[sz]));
return sz;
}
void init() {
sz=-1;lst=0;s[tot=0]='$';
newnode(0);newnode(-1);fail[0]=1;
}
int getfail(int x) {
while(s[tot]!=s[tot-len[x]-1]) x=fail[x];
return x;
}
void insert(char c) {
s[++tot]=c;
int cur=getfail(lst);
if(!tr[cur][c-'a']) {
int now=newnode(len[cur]+2);
fail[now]=tr[getfail(fail[cur])][c-'a'];
tr[cur][c-'a']=now;
diff[now]=len[now]-len[fail[now]];
if(diff[now]==diff[fail[now]]) slink[now]=slink[fail[now]];
else slink[now]=fail[now];
}
lst=tr[cur][c-'a'];
}
ll solve(char *s,int n) {
static ll g[N],dp[N];
init();
memset(g,0,sizeof(g));
memset(dp,0,sizeof(dp));
dp[0]=1;
for(int i=1;i<=n;++i) {
insert(s[i]);
for(int x=lst;x>1;x=slink[x]) {
g[x]=dp[i-len[slink[x]]-diff[x]];
if(diff[x]==diff[fail[x]]) Add(g[x],g[fail[x]]);
if(i%2==0) Add(dp[i],g[x]);
}
}
return dp[n];
}
}
char s[N],t[N];
int n;

int main() {
cin>>(s+1);
n=strlen(s+1);
if(n&1) {puts("0");return 0;}
for(int i=1;i<=n/2;++i) {t[i*2-1]=s[i];t[i*2]=s[n-i+1];}
cout<<PAM::solve(t,n)<<endl;
return 0;
}