2019牛客多校第四场I string(SAM PAM)

mac2025-11-13  5

题目描述: 给出一个字符串 求该字符串本质不同且不互为反串的子串的数目。 abac abac,b,a,ab,aba,bac,ac,c ab和ba互为反串只算一个 这个题目是SAM+PAM的板子题 SAM有两种使用的方法。 一: 输入原串然后输入一个特殊字符,然后输入原串反串。 这种我们在计算时候 a n s = ( s a m . s u m − 1 L L ∗ ( l e n + 1 ) ∗ ( l e n + 1 ) + p ) / 2 ans=(sam.sum-1LL*(len+1)*(len+1)+p)/2 ans=(sam.sum1LL(len+1)(len+1)+p)/2 二: 输入原串,sam.last置为1,输入原串的反串。 a n s = ( s a m . s u m + p ) / 2 ans=(sam.sum+p)/2 ans=(sam.sum+p)/2

PS: sam.sum为sam中本质不同的子串的数目 (len+1)*(len+1)是特殊字符贡献的子串的数目。 p为本质不同的回文子串的数目

详见代码及注释。

#include<bits/stdc++.h> using namespace std; namespace SAM { const int N_CHAR = 27; const int maxn = 2e5 + 50; struct Node { int nxt[N_CHAR], fail; int len; // Max Length of State int pos; // Appear Position of State, Indexed From 1 int cnt; // Appear Count of State }node[maxn * 4]; int numn, last, root; inline int newNode(int l, int p) { int x = ++numn; for (int i = 0; i < N_CHAR; i++) node[x].nxt[i] = 0; node[x].cnt = node[x].fail = 0; node[x].len = l; node[x].pos = p; return x; } inline void init() { root = last = newNode(numn = 0, 0); } inline void addChar(int c) { int p = last, np = newNode(node[p].len + 1, node[p].len + 1); while (p && node[p].nxt[c] == 0) node[p].nxt[c] = np, p = node[p].fail; if (p == 0) node[np].fail = root; else { int q = node[p].nxt[c]; if (node[p].len + 1 == node[q].len) { node[np].fail = q; } else { int nq = newNode(node[p].len + 1, node[q].pos); for (int i = 0; i < N_CHAR; i++) node[nq].nxt[i] = node[q].nxt[i]; node[nq].fail = node[q].fail; node[q].fail = node[np].fail = nq; while (p && node[p].nxt[c] == q) node[p].nxt[c] = nq, p = node[p].fail; } } last = np; node[np].cnt = 1; } } using namespace SAM; struct Pam { int nxt[maxn][30], fail[maxn], len[maxn], s[maxn], last, n, p; inline int newnode(int l) { memset(nxt[p], 0, sizeof(nxt[p])); len[p] = l; return p++; } void init() { p = 0; newnode(0), newnode(-1); last = n = 0; s[0] = -1; fail[0] = 1; } int getfail(int x) { while (s[n - len[x] - 1] != s[n]) x = fail[x]; return x; } void add(int c) { s[++n] = c; int cur = getfail(last); if (!nxt[cur][c]) { int now = newnode(len[cur] + 2); fail[now] = nxt[getfail(fail[cur])][c]; nxt[cur][c] = now; } last = nxt[cur][c]; } }pam; int main() { char s[maxn]; scanf("%s", s); init(); int len = strlen(s); for (int i = 0; i < len; i++) addChar(s[i] - 'a'); addChar(26);//方法2把这句改为last=1; for (int i = len - 1; i >= 0; i--)addChar(s[i] - 'a'); long long ans = 0; for (int i = numn; i >= 0; i--) { ans = ans + node[i].len - (node[node[i].fail].len); } pam.init(); for (int i = 0; i < len; i++) pam.add(s[i] - 'a'); int p = pam.p - 2; ans = ans - 1LL*(len + 1)*(len + 1) + p;//方法二改为ans=ans+p; printf("%lld\n", ans / 2); return 0; }
最新回复(0)