牛客-139 I. Substring(后缀数组 or 后缀自动机)

mac2025-10-28  18

牛客-139 I. Substring(后缀数组 or 后缀自动机)

题目链接

题意

一个由 { a , b , c } \{a, b, c\} {a,b,c}组成的字符串 S S S,求S子串的最大的集合,使得集合里的字符串互不同构

思路一:后缀数组

一共有3种字符,所以一共有6中映射关系,首先将字符串转化为6个同构的字符串.

然后求这个字符串一共有多少个不同的字符串

非连续的子串" a a b aab aab",与他同构的有6种并且6种字符串都不相同

连续的子串"aaa",与他同构的有3种

所以连续的子串少算3个,求出不同字符串之后加上少算的字符串刚好是答案的6倍

#include <bits/stdc++.h> const int maxn = 3e5 + 10; const int inf = 0x3f3f3f3f; const int mod = 1e9 + 7; using namespace std; char s[maxn], ss[maxn], t[4] = "ABC"; struct SuffixArray{ // 下标1 int cntA[maxn], cntB[maxn], A[maxn], B[maxn]; int n, Sa[maxn], tSa[maxn], height[maxn], Rank[maxn]; // Sa[i] 排名第i的下标, Rank[i] 下标i的排名 void init(char *buf, int len) { // 下标1,sa,rank,height n = len; for (int i = 0; i < 500; ++i) cntA[i] = 0; for (int i = 1; i <= n; ++i) cntA[(int)buf[i]]++; for (int i = 1; i < 500; ++i) cntA[i] += cntA[i-1]; for (int i = n; i >= 1; --i) Sa[ cntA[(int)buf[i]]-- ] = i; Rank[ Sa[1] ] = 1; for (int i = 2; i <= n; ++i) { Rank[Sa[i]] = Rank[Sa[i-1]]; if (buf[Sa[i]] != buf[Sa[i-1]]) Rank[Sa[i]]++; } for (int l = 1; Rank[Sa[n]] < n; l <<= 1) { for (int i = 0; i <= n; ++i) cntA[i] = 0; for (int i = 0; i <= n; ++i) cntB[i] = 0; for (int i = 1; i <= n; ++i) { cntA[ A[i] = Rank[i] ]++; cntB[ B[i] = (i + l <= n) ? Rank[i+l] : 0]++; } for (int i = 1; i <= n; ++i) cntB[i] += cntB[i-1]; for (int i = n; i >= 1; --i) tSa[ cntB[B[i]]-- ] = i; for (int i = 1; i <= n; ++i) cntA[i] += cntA[i-1]; for (int i = n; i >= 1; --i) Sa[ cntA[A[tSa[i]]]-- ] = tSa[i]; Rank[ Sa[1] ] = 1; for (int i = 2; i <= n; ++i) { Rank[Sa[i]] = Rank[Sa[i-1]]; if (A[Sa[i]] != A[Sa[i-1]] || B[Sa[i]] != B[Sa[i-1]]) Rank[Sa[i]]++; } } for (int i = 1, j = 0; i <= n; ++i) { if (j) --j; int tmp = Sa[Rank[i] - 1]; while (i + j <= n && tmp + j <= n && buf[i+j] == buf[tmp+j]) ++j; height[Rank[i]] = j; } } }S; int main() { int n; while (scanf("%d%s", &n, s) != EOF) { int now = 0; char c = 'Z'; map<char, char> mp; do{ mp['a'] = t[0]; mp['b'] = t[1]; mp['c'] = t[2]; for (int i = 0; i < n; ++i) { ss[++now] = mp[s[i]]; } ss[++now] = ++c; }while (next_permutation(t, t+3)); S.init(ss, now); long long ans = 1ll * (n+1) * n * 3; for (int i = 1; i <= n*6; ++i) { ans -= S.height[i]; } int tmp = 1, cnt = 1; for (int i = 1; i < n; ++i) { if (s[i] == s[i-1]) tmp++; else tmp = 1; cnt = max(cnt, tmp); } ans += cnt * 3; printf("%lld\n", ans/6); } return 0; }

思路二:广义后缀自动机

对6个串建广义自动机,然后类似思路一求不同子串的个数,加上少算的除以6即可

#include <bits/stdc++.h> const int maxn = 3e5 + 5; const int inf = 0x3f3f3f3f; const int mod = 1e9 + 7; using namespace std; struct SAM{ int trans[maxn<<1][3], slink[maxn<<1], maxlen[maxn<<1]; // 用来求endpos int indegree[maxn<<1], endpos[maxn<<1], rank[maxn<<1], ans[maxn<<1]; // 计算所有子串的和(0-9表示) long sum[maxn<<1]; int last, now, root, len; inline void newnode (int v) { maxlen[++now] = v; } inline void init() { root = now = 1; memset(trans, 0, sizeof(trans)); // memset(slink, 0, sizeof(slink)); // memset(maxlen, 0, sizeof(maxlen)); } inline void extend(int c) { if (trans[last][c]) { // 广义自动机 节点合并 节省空间 int p = last, np = trans[last][c]; if (maxlen[last] + 1 == maxlen[np]) last = np; else { int q = trans[p][c]; newnode(maxlen[p]+1); int nq = now; memcpy(trans[nq], trans[q], sizeof(trans[q])); slink[nq] = slink[q]; slink[q] = last = nq; while (p && trans[p][c] == q) { trans[p][c] = nq; p = slink[p]; } } return; } newnode(maxlen[last] + 1); // 新建节点 int p = last, np = now; // 更新trans while (p && !trans[p][c]) { // last的slink只想新节点 trans[p][c] = np; p = slink[p]; } if (!p) slink[np] = root; else { // slink中存在节点有到c的边 int q = trans[p][c]; if (maxlen[p] + 1 != maxlen[q]) { // 判断是否为同一个endpos newnode(maxlen[p] + 1); // 将q点拆出nq,使得maxlen[p] + 1 == maxlen[q] int nq = now; memcpy(trans[nq], trans[q], sizeof(trans[q])); slink[nq] = slink[q]; slink[q] = slink[np] = nq; while (p && trans[p][c] == q) { trans[p][c] = nq; p = slink[p]; } }else slink[np] = q; } last = np; // 初始状态为可接受状态 endpos[np] = 1; } inline void insert(char *buf) { len = strlen(buf); last = 1; for (int i = 0; i < len; ++i) extend(buf[i] - '0'); // extend(s[i] - '1'); } // 求不同的子串种类 inline long long all () { long long ans = 0; for (int i = root+1; i <= now; ++i) { ans += maxlen[i] - maxlen[ slink[i] ]; } return ans; } }sam; char s[maxn], ss[maxn], t[4] = "012"; int main() { int n; while (scanf("%d%s", &n, s) != EOF) { sam.init(); do{ for (int i = 0; i < n; ++i) { ss[i] = t[s[i]-'a']; } sam.insert(ss); }while(next_permutation(t, t+3)); long long ans = sam.all(); int tmp = 1, sum = 1; for (int i = 1; i < n; ++i) { if (s[i] == s[i-1]) ++tmp; else tmp = 1; sum = max(sum, tmp); } ans += sum * 3; printf("%lld\n", ans/6); } return 0; }
最新回复(0)