题意
求长度为n+m的回文串个数,要求这个回文串包含某个给定的长度为m的串S为子序列。
m
≤
200
,
n
≤
1
0
9
m\leq200,n\leq10^9
m≤200,n≤109
思路
先考虑一个
O
(
n
m
2
)
O(nm^2)
O(nm2)的DP:设
h
[
i
]
[
j
]
[
k
]
h[i][j][k]
h[i][j][k]表示已经决定了回文串中的第1~i个字符,并且串S从左开始匹配了j个,右开始匹配了k个。这个dp可以用一张带权DAG描述,要求的就是走(n+m+1)/2步后到达终止状态的方案数(_的地方意思是已经匹配) 红点的意思是左右两边下一个待匹配字符不一致,这样有24种选择回到自己。绿点则相反,有25种选择回到自己并且只有一条出边。将所有链的到达方案数求和即为答案。一条链上,绿点x与红点y的个数满足:
2
x
+
y
=
n
或
n
+
1
2x+y=n或n+1
2x+y=n或n+1。之所以会有n+1是因为有可能最后一个点是绿点,并且此时没匹配的字符仅剩一个字符。为了简单起见不特殊处理,直接按照n+1安排。注意有一种情况是非法的:当n+m是奇数时,最后一步不能使得左右两边匹配同时+1.(因为实际上只有一个字符)绿点与红点的顺序是无所谓的,因此先求出每条本质不同的路径有多少条。接下来虽然有
O
(
n
)
O(n)
O(n)种本质不同的路径,但这些路径都长得很像: 做一次矩阵乘法求出上述图中两两到达的方案。每一条路径与某个红点到达某个蓝点的路径相同。至于那种非法的情况,类似在图上求一下方案减去即可。
O
(
n
3
(
log
n
+
∑
)
)
O(n^3(\log n+\sum))
O(n3(logn+∑))
#include <bits/stdc++.h>
using namespace std
;
const int N
= 210, mo
= 1e4 + 7;
char s
[N
];
int n
, m
;
int f
[N
][N
][N
];
int odd
, ans
;
#define add(x, y) ((x) = ((x) + (y)) % mo)
void get_count() {
f
[0][0][0] = 1;
for(int x
= 0; x
< n
; x
++) {
for(int y
= 0; x
+ y
< n
; y
++) {
for(int k
= 0; 2 * k
< n
; k
++) if(f
[x
][y
][k
]) {
for(int i
= 'a'; i
<= 'z'; i
++)
if ((s
[x
+ 1] == i
) || (s
[n
- y
] == i
)) {
int _x
= x
+ (s
[x
+ 1] == i
);
int _y
= y
+ (s
[n
- y
] == i
);
add(f
[_x
][_y
][k
+ (s
[x
+ 1] == s
[n
- y
])],
f
[x
][y
][k
]);
}
}
}
}
}
int sz
, cnta
, cntb
;
typedef int mat
[2 * N
][2 * N
];
mat a
;
void mult(mat a
, mat b
, mat c
) {
static mat ret
;
memset(ret
, 0, sizeof ret
);
for(int k
= 1; k
<= sz
; k
++) {
for(int i
= 1; i
<= k
; i
++) if(a
[i
][k
]) {
for(int j
= k
; j
<= sz
; j
++) {
add(ret
[i
][j
], a
[i
][k
] * b
[k
][j
]);
}
}
}
memcpy(c
, ret
, sizeof ret
);
}
void ksm(mat x
, int y
) {
static mat ret
;
memset(ret
, 0, sizeof ret
);
for(int i
= 1; i
<= sz
; i
++) ret
[i
][i
] = 1;
for(; y
; y
>>=1) {
if (y
& 1) mult(ret
, x
, ret
);
mult(x
, x
, x
);
}
memcpy(x
, ret
, sizeof ret
);
}
void build_graph() {
cnta
= n
, cntb
= (n
+ 1) / 2;
memset(a
, 0, sizeof a
);
for(int i
= 1; i
< cnta
+ cntb
; i
++)
a
[i
][i
+ 1] = 1;
for(int i
= 1; i
<= cnta
; i
++) a
[i
][i
] = 24;
for(int i
= cnta
+ 1; i
<= cnta
+ cntb
; i
++) {
a
[i
][i
] = 25;
a
[i
][i
+ cntb
] = 1;
a
[i
+ cntb
][i
+ cntb
] = 26;
}
sz
= cnta
+ cntb
+ cntb
;
}
void calc() {
build_graph();
ksm(a
, (n
+ m
+ 1) / 2);
for(int k
= 0; k
* 2 <= n
+ 1; k
++) {
int sum
[2]; sum
[0] = sum
[1] = 0;
for(int x
= 0; x
<= n
+ 1; x
++) {
if (n
- x
>= 0) add(sum
[0], f
[x
][n
- x
][k
]);
add(sum
[1], f
[x
][n
+ 1 - x
][k
]);
}
for(int s
= n
; s
<= n
+ 1; s
++) {
int z
= s
- 2 * k
;
if (z
>= 0)
add(ans
, a
[cnta
- z
+ 1][cnta
+ k
+ cntb
] * sum
[s
- n
]);
}
}
}
void calc2() {
build_graph();
ksm(a
, (n
+ m
) / 2);
for(int k
= 1; k
* 2 <= n
; k
++) {
int z
= n
- 2 * k
, sum
= 0;
for(int x
= 0; x
< n
- 1; x
++)
if (s
[x
+ 1] == s
[x
+ 2]) {
add(sum
, f
[x
][n
- 2 - x
][k
- 1]);
}
if (z
>= 0)
add(ans
, - a
[cnta
- z
+ 1][cnta
+ k
] * sum
);
}
}
int main() {
freopen("e.in", "r", stdin);
cin
>>s
+1>>m
;
n
= strlen(s
+ 1); odd
= (n
+ m
) % 2;
get_count();
calc();
if (odd
) calc2();
cout
<< (ans
+ mo
)%mo
<< endl
;
}