CF506 E Mr. Kitayuta's Gift (计数)

mac2024-08-19  62

题意

求长度为n+m的回文串个数,要求这个回文串包含某个给定的长度为m的串S为子序列。 m ≤ 200 , n ≤ 1 0 9 m\leq200,n\leq10^9 m200,n109

思路

先考虑一个 O ( n m 2 ) O(nm^2) O(nm2)的DP:设 h [ i ] [ j ] [ k ] h[i][j][k] h[i][j][k]表示已经决定了回文串中的第1~i个字符,并且串S从左开始匹配了j个,右开始匹配了k个。这个dp可以用一张带权DAG描述,要求的就是走(n+m+1)/2步后到达终止状态的方案数(_的地方意思是已经匹配) 红点的意思是左右两边下一个待匹配字符不一致,这样有24种选择回到自己。绿点则相反,有25种选择回到自己并且只有一条出边。将所有链的到达方案数求和即为答案。一条链上,绿点x与红点y的个数满足: 2 x + y = n 或 n + 1 2x+y=n或n+1 2x+y=nn+1。之所以会有n+1是因为有可能最后一个点是绿点,并且此时没匹配的字符仅剩一个字符。为了简单起见不特殊处理,直接按照n+1安排。注意有一种情况是非法的:当n+m是奇数时,最后一步不能使得左右两边匹配同时+1.(因为实际上只有一个字符)绿点与红点的顺序是无所谓的,因此先求出每条本质不同的路径有多少条。接下来虽然有 O ( n ) O(n) O(n)种本质不同的路径,但这些路径都长得很像: 做一次矩阵乘法求出上述图中两两到达的方案。每一条路径与某个红点到达某个蓝点的路径相同。至于那种非法的情况,类似在图上求一下方案减去即可。 O ( n 3 ( log ⁡ n + ∑ ) ) O(n^3(\log n+\sum)) O(n3(logn+)) #include <bits/stdc++.h> using namespace std; const int N = 210, mo = 1e4 + 7; char s[N]; int n, m; int f[N][N][N]; int odd, ans; #define add(x, y) ((x) = ((x) + (y)) % mo) void get_count() { f[0][0][0] = 1; for(int x = 0; x < n; x++) { for(int y = 0; x + y < n; y++) { for(int k = 0; 2 * k < n; k++) if(f[x][y][k]) { for(int i = 'a'; i <= 'z'; i++) if ((s[x + 1] == i) || (s[n - y] == i)) { int _x = x + (s[x + 1] == i); int _y = y + (s[n - y] == i); add(f[_x][_y][k + (s[x + 1] == s[n - y])], f[x][y][k]); } } } } } int sz, cnta, cntb; typedef int mat[2 * N][2 * N]; mat a; void mult(mat a, mat b, mat c) { static mat ret; memset(ret, 0, sizeof ret); for(int k = 1; k <= sz; k++) { for(int i = 1; i <= k; i++) if(a[i][k]) { for(int j = k; j <= sz; j++) { add(ret[i][j], a[i][k] * b[k][j]); } } } memcpy(c, ret, sizeof ret); } void ksm(mat x, int y) { static mat ret; memset(ret, 0, sizeof ret); for(int i = 1; i <= sz; i++) ret[i][i] = 1; for(; y; y>>=1) { if (y & 1) mult(ret, x, ret); mult(x, x, x); } memcpy(x, ret, sizeof ret); } void build_graph() { cnta = n, cntb = (n + 1) / 2; memset(a, 0, sizeof a); for(int i = 1; i < cnta + cntb; i++) a[i][i + 1] = 1; for(int i = 1; i <= cnta; i++) a[i][i] = 24; for(int i = cnta + 1; i <= cnta + cntb; i++) { a[i][i] = 25; a[i][i + cntb] = 1; a[i + cntb][i + cntb] = 26; } sz = cnta + cntb + cntb; } void calc() { build_graph(); ksm(a, (n + m + 1) / 2); for(int k = 0; k * 2 <= n + 1; k++) { int sum[2]; sum[0] = sum[1] = 0; for(int x = 0; x <= n + 1; x++) { if (n - x >= 0) add(sum[0], f[x][n - x][k]); add(sum[1], f[x][n + 1 - x][k]); } for(int s = n; s <= n + 1; s++) { int z = s - 2 * k; if (z >= 0) add(ans, a[cnta - z + 1][cnta + k + cntb] * sum[s - n]); } } } void calc2() { build_graph(); ksm(a, (n + m) / 2); for(int k = 1; k * 2 <= n; k++) { int z = n - 2 * k, sum = 0; for(int x = 0; x < n - 1; x++) if (s[x + 1] == s[x + 2]) { add(sum, f[x][n - 2 - x][k - 1]); } if (z >= 0) add(ans, - a[cnta - z + 1][cnta + k] * sum); } } int main() { freopen("e.in", "r", stdin); cin>>s+1>>m; n = strlen(s + 1); odd = (n + m) % 2; get_count(); calc(); if (odd) calc2(); cout << (ans + mo)%mo << endl; }
最新回复(0)