luogu P4081 感觉不到黑题的难度,做完这道题目,对SAM添加字符操作有了更加深刻的理解。 题目描述: 给你一个整数n,然后给你n个串,要求求出n个串每个的只属于其的本质不同的非空字串的数量。 所有串的总长度不超过1e5,n不超过1e5 不知道每条串的最大长度,1e5*1e5的数组太大 可以用一个1e5的数组存 然后记录下每条串的长度 如果想用SAM来做这道题目我们要清楚SAM的原理 要了解SAM添加字符的时候到底在干嘛 (搞懂之后就知道广义SAM为啥成立了)
inline void addChar(int c){ int p=last; int np=newNode(node[p].len+1, node[p].pos+1); while(p && node[p].nxt[c]==0) node[p].nxt[c]=np, p=node[p].fail; if (p==0) node[np].fail=root; else{ int q=node[p].nxt[c]; if (node[p].len+1 == node[q].len){ node[np].fail=q; } else { int nq=newNode(node[p].len+1, node[q].pos); for (int i=0; i<kind; i++) node[nq].nxt[i]=node[q].nxt[i]; node[nq].fail=node[q].fail; node[q].fail=node[np].fail=nq; while(p && node[p].nxt[c]==q) node[p].nxt[c]=nq, p=node[p].fail; } } last=np; node[np].cnt=1; }首先找到尾结点,然后必定要新建一个节点(np),这个节点(np)也将成为新的尾结点(添加最后将last置为np),它保证了所有节点的最大长度等于加入字符的总数量(可以这样理解,每一个曾被置为last的节点,都是添加这个字符时,当时最长长度最大的节点),这个节点的父亲节点如果没有一条边(边为当前添加字符)那就为其父亲节点添加,并向上遍历父亲节点的父亲节点,直到有一个节点有一条边(边为当前添加字符)或者遍历到0节点。 例如下面: SAM中的已经添加的串为ababa (1)如果要向其中添加b ab已经出现过,所以向上遍历父亲节点,会出现一个节点有代表b的一条边连向其他点。 (2)如果要向其中添加c ac没有出现过,那么会一直向上遍历父亲节点,直到父亲节点=0(根节点为1,父亲节点等于0相当于结束) 如果向上沿着fail树遍历的节点p等于0 那np的fail就是root 且明显不需要新建节点(因为相当于重新加入了一个从未出现过的字符) 当p!=0时 就需要分情况讨论了 令q为p节点沿添加字符的边走向的节点 如果p.len+1等于q.len根据后缀自动机的定义 这两个节点其实也是一样的 不需要新建节点 但如果不等于 我们就需要新建节点了 其实出现这种情况就是因为p表示的长度要大于q这个节点的长度+1 例子:q节点表示的最长串为ababa p节点表示的最长串为cccababac 我们需要新建一个节点nq使其表示的最长长度的串为ababac,且q以及其父亲节点经过新添加字符的边走向nq,并将p节点的父亲置为nq。(当我们操作多个字符串时,如果有这种情况新建的一个点nq,且p属于其他的串时,那这个节点nq所代表的所有子串就没有了贡献)。 理解了SAM添加字符时候的情况 我们将所有的串存入SAM(每次新串输入last置为1),这个是广义SAM和SAM的区别,因为我们再把last置为1后,之后申请的节点就如果和之前的串有公共子串,那么该子串对应的状态节点和我们之前前一条串申请的节点之间,会有fail边连接。 然后我们设置一个访问数组vis(大小为字符串长度的两倍(SAM性质)) 这个数组记录了谁访问过该节点 然后把每个串再跑一遍SAM因为之前输入过,所以串每个字符跑的时候一定有匹配节点,我们把当前状态节点s,以及当先串是第几串a,然后沿着fail树遍历s的父亲和祖先,如果节点未被访问过,将其vis置为a,如果其被其他节点访问过,将其vis置为-1。如果vis已经为-1了,那就不需要向上遍历了,因为当前节点已经被至少两个串访问过,其祖先节点必定也被至少两个串访问过。 之后我们只要找到vis值不是-1的节点,然后对其代表的子串的数量进行统计然后输出即可。
#include<bits/stdc++.h> using namespace std; typedef long long ll; namespace SAM { const int maxn=2e5; const int kind=26; struct Node{ int nxt[kind], fail; int len; // Max Length of State int pos; // Appear Position of State, Indexed From 1 int cnt; // Appear Count of State }node[maxn*2]; int numn, last, root; inline int newNode(int l, int p){ int x=++numn; for (int i=0; i<kind; i++) node[x].nxt[i]=0; node[x].cnt=node[x].fail=0; node[x].len=l; node[x].pos=p; return x; } inline void init(){ root=last=newNode(numn=0, 0); } inline void addChar(int c){ int p=last; int np=newNode(node[p].len+1, node[p].pos+1); while(p && node[p].nxt[c]==0) node[p].nxt[c]=np, p=node[p].fail; if (p==0) node[np].fail=root; else{ int q=node[p].nxt[c]; if (node[p].len+1 == node[q].len){ node[np].fail=q; } else { int nq=newNode(node[p].len+1, node[q].pos); for (int i=0; i<kind; i++) node[nq].nxt[i]=node[q].nxt[i]; node[nq].fail=node[q].fail; node[q].fail=node[np].fail=nq; while(p && node[p].nxt[c]==q) node[p].nxt[c]=nq, p=node[p].fail; } } last=np; node[np].cnt=1; } } using namespace SAM; long long ans[maxn]; char s[maxn],ss[maxn]; int len[maxn],vis[maxn]; int n; inline void cla(int x,int y){ for(;x&&vis[x]!=y&&vis[x]!=-1;x=node[x].fail){ if(vis[x]!=0)vis[x]=-1; else vis[x]=y; } } void solve(){ int tot=0; for(int i=1;i<=n;i++){ for(int j=1,x=1;j<=len[i];j++){ int v=s[tot]-'a'; tot++; cla(x=node[x].nxt[v],i); } } for(int i=1;i<=numn;i++) { if(vis[i]!=-1){ int x=vis[i]; ans[x]+=node[i].len-node[node[i].fail].len; } } for(int i=1;i<=n;i++)printf("%lld\n",ans[i]); } int main(){ scanf("%d",&n); int tot=0; init(); memset(len,0,sizeof(len)); memset(vis,0,sizeof(vis)); memset(ans,0,sizeof(ans)); for(int i=1;i<=n;i++){ scanf("%s",ss); last=1;//最核心的一句 int l=strlen(ss); for(int j=0;j<l;j++){ len[i]++; addChar(ss[j]-'a'); s[tot++]=ss[j]; } } s[tot]='\0'; solve(); return 0; }上一版代码很丑,而且其实因为是只是统计本质不同的子串,所以这个版本的写法是可以的,但是这种广义SAM写法会申请很多多余的节点,宏观感受下{ab,abc} 简易+不申请多余节点版
#include<bits/stdc++.h> using namespace std; const int maxn=5e5+150; const int kind=26; typedef long long ll; int tot1=1,las=1; int ch[maxn*2][kind]; int len[maxn*2],fa[maxn*2]; char s1[maxn],s2[maxn]; ll sum[maxn*2]; int d1[maxn*2],d2[maxn*2]; int n; inline int newn(int x){len[++tot1]=x;return tot1;} inline int newnq(int p,int w){ int nq=newn(len[p]+1); int q=ch[p][w]; for(int i=0;i<kind;i++)ch[nq][i]=ch[q][i]; fa[nq]=fa[q]; fa[q]=nq; while(p&&ch[p][w]==q)ch[p][w]=nq,p=fa[p]; return nq; } void sam_ins(int c){ int p=las; if(ch[p][c]){ int q=ch[p][c]; if (len[q]==len[p]+1)las=q; else las=newnq(p,c); return ; } int np=newn(len[las]+1);las=tot1; while(p&&!ch[p][c])ch[p][c]=np,p=fa[p]; if(!p)fa[np]=1; else{ int q=ch[p][c]; if(len[q]==len[p]+1) fa[np]=q; else{ fa[np]=newnq(p,c); } } } int vis[maxn*2]; inline void cla(int x,int y){ for(;x&&vis[x]!=y&&vis[x]!=-1;x=fa[x]){ if(vis[x]!=0)vis[x]=-1; else vis[x]=y; } } ll ans[maxn]; int le[maxn]; void solve(){ int dd=0; for(int i=1;i<=n;i++){ for(int j=1,x=1;j<=le[i];j++){ int v=s2[dd]-'a'; dd++; cla(x=ch[x][v],i); } } for(int i=1;i<=tot1;i++) { if(vis[i]!=-1){ int x=vis[i]; ans[x]+=len[i]-len[fa[i]]; } } for(int i=1;i<=n;i++)printf("%lld\n",ans[i]); } int main(){ scanf("%d",&n); int tot=0; for(int i=1;i<=n;i++){ scanf("%s",s1); las=1; int l=strlen(s1); for(int j=0;j<l;j++){ le[i]++; sam_ins(s1[j]-'a'); s2[tot++]=s1[j]; } } s2[tot]='\0'; solve(); return 0; }