P4248 [AHOI2013]差异(SAMSA)

mac2024-12-20  6

题意:

给出一个长度为 n n n的字符串,求 ∑ 1 < = i < j < = n l e n ( T i ) + l e n ( T j ) − 2 ∗ l c p ( T i , T j ) \sum_{1<=i<j<=n} len(Ti)+len(Tj)−2∗lcp(Ti,Tj) 1<=i<j<=nlen(Ti)+len(Tj)2lcp(Ti,Tj)

题解:
方法一:

考虑把原串翻转过来,那么要求的就是所有前缀公共后缀了,两个前缀的最长公共后缀的答案就是这两个前缀所代表的点在后缀树上 L C A LCA LCA的长度,题目就转化为求树上某一个点是多少点对的 L C A LCA LCA再乘上 l e n [ L C A ] len[LCA] len[LCA]我们可以直接求出前半部分为: n ( n + 1 ) ( n − 1 ) 2 \frac{n(n+1)(n-1)}{2} 2n(n+1)(n1),然后我们对每个点考虑它作为 L C A LCA LCA的贡献值,再减去就好了。

AC代码:

#include<bits/stdc++.h> using namespace std; typedef long long LL; const int MAXN = 1e6+50; char s[MAXN]; struct Suffix{ int nxt[MAXN][26],fa[MAXN],len[MAXN],endpos[MAXN]; int last=1,tot=1; vector<int> g[MAXN]; LL res; inline void Insert(int x){ int p = last, np = ++tot; last = np,len[np] = len[p] + 1; for(;p && !nxt[p][x];p = fa[p]) nxt[p][x] = np; if(!p) fa[np] = 1; else { int q = nxt[p][x]; if(len[p]+1 == len[q]) fa[np] = q; else { int nq = ++tot; len[nq] = len[p] + 1; memcpy(nxt[nq],nxt[q],sizeof(nxt[q])); fa[nq] = fa[q]; fa[q] = fa[np] = nq; for(;nxt[p][x]==q;p = fa[p]) nxt[p][x] = nq; } } endpos[np] = 1; } inline void dfs(int u){ for(int i=0;i<(int)g[u].size();i++){ int v = g[u][i]; dfs(v); res -= 2LL*endpos[u]*endpos[v]*len[u]; endpos[u] += endpos[v]; } } inline void Solve(int n){ for(int i=2;i<=tot;i++) g[fa[i]].push_back(i); res = 1LL*(n+1)*(n-1)*n/2; dfs(1); printf("%lld\n",res); } }SAM; int main(){ scanf("%s",s+1); int len = strlen(s+1); for(int i=len;i>=1;i--) SAM.Insert(s[i]-'a'); SAM.Solve(len); return 0; }
方法二:

也是构造出反串的 p a r e n t parent parent树,然后根据父子关系dp出每个后缀的累加次数即可

AC代码:

#include<bits/stdc++.h> using namespace std; typedef long long LL; const int MAXN = 1e6+50; char s[MAXN]; struct Suffix{ int nxt[MAXN][26],fa[MAXN],len[MAXN],endpos[MAXN]; int c[MAXN],a[MAXN]; LL dp[MAXN]; int last=1,tot=1; vector<int> g[MAXN]; LL res; inline void Insert(int x){ int p = last, np = ++tot; last = np,len[np] = len[p] + 1; for(;p && !nxt[p][x];p = fa[p]) nxt[p][x] = np; if(!p) fa[np] = 1; else { int q = nxt[p][x]; if(len[p]+1 == len[q]) fa[np] = q; else { int nq = ++tot; len[nq] = len[p] + 1; memcpy(nxt[nq],nxt[q],sizeof(nxt[q])); fa[nq] = fa[q]; fa[q] = fa[np] = nq; for(;nxt[p][x]==q;p = fa[p]) nxt[p][x] = nq; } } endpos[np] = 1; } inline void Solve(int n){ for(int i=1;i<=tot;i++) c[len[i]]++; for(int i=1;i<=tot;i++) c[i]+=c[i-1]; for(int i=1;i<=tot;i++) a[c[len[i]]--]=i; for(int i=1;i<=tot;i++) dp[i]=endpos[i]; for(int i=tot;i;i--) dp[fa[a[i]]] += dp[a[i]]; LL res = 1LL*(n-1)*(n+1)*n/2; for(int i=2;i<=tot;i++) res -= 1LL*(dp[i]-1)*dp[i]*(len[i]-len[fa[i]]); printf("%lld\n",res); } }SAM; int main(){ scanf("%s",s+1); int len = strlen(s+1); for(int i=len;i>=1;i--) SAM.Insert(s[i]-'a'); SAM.Solve(len); return 0; }
方法三:

后缀数组+单调栈

和上面一样,我们只需要求出每个 l c p lcp lcp即可,不难发现,对于排名为 i i i j j j的两个后缀,它们的 l c p lcp lcp应该是 m i n ( h [ i ] , h [ i + 1 ] , h [ i + 2 ] . . . h [ j ] ) min(h[i] ,h[i+1],h[i+2]... h[j]) min(h[i],h[i+1],h[i+2]...h[j]),很明显这样求时间复杂度太大,所以我们可以考虑换个角度考虑,因为如果 h [ k ] h[k] h[k]是一段 m i n ( h [ i ] , h [ i + 1 ] , h [ i + 2 ] . . . h [ j ] ) min(h[i] ,h[i+1],h[i+2]... h[j]) min(h[i],h[i+1],h[i+2]...h[j]) 的最小值,那么我们认为h[k]产生了贡献值。这样就可以用单调栈维护啦!

AC代码:

#include<bits/stdc++.h> using namespace std; typedef long long LL; const int MAXN = 1e6+50; const int INF = 0x3f3f3f3f; int a[MAXN],c[MAXN],rk[MAXN],y[MAXN],sa[MAXN],h[MAXN]; int l[MAXN],r[MAXN]; char s[MAXN]; inline void SA(int n,int m){ for(int i=1;i<=n;i++) rk[i]=s[i],++c[rk[i]]; for(int i=1;i<=m;i++) c[i]+=c[i-1]; for(int i=n;i>=1;i--) sa[c[rk[i]]--]=i; for(int k=1;k<=n;k<<=1){ int num = 0; for(int i=n-k+1;i<=n;i++) y[++num]=i; for(int i=1;i<=n;i++) if(sa[i]>k) y[++num]=sa[i]-k; for(int i=1;i<=m;i++) c[i]=0; for(int i=1;i<=n;i++) ++c[rk[i]]; for(int i=1;i<=m;i++) c[i]+=c[i-1]; for(int i=n;i>=1;i--) sa[c[rk[y[i]]]--]=y[i]; swap(rk,y); rk[sa[1]] = num = 1; for(int i=2;i<=n;i++) rk[sa[i]] = (y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k] ? num : ++num); if(num==n) break; m = num; } for(int i=1;i<=n;i++) rk[sa[i]] = i; int k = 0; for(int i=1;i<=n;i++){ if(rk[i]==1) continue; if(k) --k; int j = sa[rk[i]-1]; while(j+k<=n && i+k<=n && s[j+k]==s[i+k]) ++k; h[rk[i]] = k; } } int main(){ scanf("%s",s+1); int n = strlen(s+1),m = 'z'; SA(n,m); h[0] = h[n+1] = -INF; LL res = 0; for(int i=1;i<=n;i++) l[i]=i-1,r[i]=i+1; for(int i=2;i<=n;i++) while(h[l[i]]>h[i]) l[i]=l[l[i]]; for(int i=n;i>=2;i--) while(h[r[i]]>=h[i]) r[i]=r[r[i]]; for(int i=2;i<=n;i++) res += 2LL*h[i]*(r[i]-i)*(i-l[i]); printf("%lld\n",1LL*(n+1)*(n-1)*n/2-res); return 0; }
最新回复(0)