数据范围:
碰到这种给出多个字符串要求构造一个新的字符串求概率或者价值的题刘汝佳在<<算法竞赛入门经典>>里告诉我们:
做法一般都是先建出 A C AC AC自动机然后再来做 d p dp dp。
首先设出 d p dp dp方程:
设 f u , L f_{u,L} fu,L表示从节点 u u u开始走,走 L L L步后能得到的最大价值。把每个字符串的价值记录在 A C AC AC自动机上末尾节点上,注意如果一个字符串包含另一个字符,需要把另一个字符串的贡献也算上。
具体为在建 f a i l fail fail数组的 b f s bfs bfs时, w u + = w f a i l [ u ] w_{u}+=w_{fail[u]} wu+=wfail[u]。
那么 d p dp dp方程的转移为: f u , L = M a x v ∈ c h [ u ] { f v , L − 1 + w v } f_{u,L}=Max_{v \in ch[u]}\{ f_{v,L-1}+w_v\} fu,L=Maxv∈ch[u]{fv,L−1+wv}
最后答案为 f r o o t , n f_{root,n} froot,n。
代码实现上可以采用递归的形式。
bool vis[maxn][maxn]; long long f[maxn][maxn]; long long dp(int now,int L) { if(vis[now][L]) return f[now][L]; vis[now][L]=1; if(L==0) return 0; for(int i=0;i<27;i++) { if(f[now][L]<(dp(Trie::ch[now][i],L-1)+w[Trie::ch[now][i]])) from[now]=i; f[now][L]=max(f[now][L],dp(Trie::ch[now][i],L-1)+w[Trie::ch[now][i]]); } return f[now][L]; }这样可以拿到 70 p t s 70pts 70pts。
那么如果 n n n很大呢?
我们注意到每次 f f f更新都是同一个模式,所以可以采用矩阵优化。
这里是广义的矩阵。
我们观察下面的两个等式: a ( b + c ) = a b + a c a(b+c)=ab+ac a(b+c)=ab+ac
a + m a x ( b , c ) = m a x ( b + a , c + a ) a+max(b,c)=max(b+a,c+a) a+max(b,c)=max(b+a,c+a)
它们都满足结合律。事实上只要满足结合律就可以使用矩阵来优化。我们可以把原来的 ∗ * ∗重定义为 + + +,原来的 + + +重定义为 m a x max max即可。
注意单位矩阵变成了主对角线为 0 0 0,其余点为 − i n f -inf −inf的矩阵。
具体实现为:
Matrix operator*(Matrix p) const { Matrix t; for(int i=1;i<=200;i++) for(int j=1;j<=200;j++) t[i][j]=-inf; for(int k=1;k<=200;k++) { for(int i=1;i<=200;i++) { for(int j=1;j<=200;j++) { t[i][j]=max(t[i][j],a[i][k]+p[k][j]); } } } return t; }那么我们把 d p dp dp方程用行向量保存下来,用一个矩阵表示转移,然后做矩阵快速幂即可。
题外话:考前只知道可以广义矩阵的定义,没打过,结果单位矩阵和 A C AC AC自动机加权值搞错了, D E B U G 2 h DEBUG\ 2h DEBUG 2h!!!。
暴力和矩阵都在里面,显得略长。
/******************************* Author:galaxy yr LANG:C++ Created Time:2019年10月31日 星期四 08时54分15秒 *******************************/ #include<iostream> #include<cstdio> #include<cstring> #include<string> #include<queue> #define int long long //#define _70pts_ using namespace std; const long long inf=1e17; struct Matrix{ enum{size=203}; long long a[size][size]; long long * operator[](const int i) { return a[i]; } const long long * operator[](const int i) const { return a[i]; } Matrix() { for(int i=1;i<=200;i++) for(int j=1;j<=200;j++) a[i][j]=-inf; } Matrix(int p) { for(int i=1;i<=200;i++) for(int j=1;j<=200;j++) a[i][j]=-inf; for(int i=0;i<=200;i++) a[i][i]=0; } Matrix operator*(Matrix p) const { Matrix t; for(int i=1;i<=200;i++) for(int j=1;j<=200;j++) t[i][j]=-inf; for(int k=1;k<=200;k++) { for(int i=1;i<=200;i++) { for(int j=1;j<=200;j++) { t[i][j]=max(t[i][j],a[i][k]+p[k][j]); } } } return t; } }; const int maxn=505; long long n,m,a[maxn],fail[maxn*27],from[maxn]; string s; /*AC自动机*/ namespace Trie{ int ch[maxn][27],w[maxn*27],tot=1; void insert(string s,int val) { int now=1,c; for(int i=0;i<(int)s.size();i++) { c=s[i]-'a'; if(!ch[now][c]) ch[now][c]=++tot; now=ch[now][c]; } w[now]+=val; } }; using namespace Trie; void make_fail() { for(int i=0;i<=26;i++) Trie::ch[0][i]=1; queue<int>que; que.push(1); while(!que.empty()) { int x=que.front(); que.pop(); w[x]+=w[fail[x]]; for(int i=0;i<26;i++) if(ch[x][i]) { fail[ch[x][i]]=ch[fail[x]][i]; que.push(ch[x][i]); } else { ch[x][i]=ch[fail[x]][i]; } } } #ifdef _70pts_ //暴力dp bool vis[maxn][maxn]; long long f[maxn][maxn]; long long dp(int now,int L) { if(vis[now][L]) return f[now][L]; vis[now][L]=1; if(L==0) return 0; for(int i=0;i<27;i++) { if(f[now][L]<(dp(Trie::ch[now][i],L-1)+w[Trie::ch[now][i]])) from[now]=i; f[now][L]=max(f[now][L],dp(Trie::ch[now][i],L-1)+w[Trie::ch[now][i]]); } return f[now][L]; } #else Matrix f,T; bool vis[maxn]; void dfs(int now) { if(vis[now]) return; vis[now]=1; for(int i=0;i<27;i++) { T[ch[now][i]][now]=w[ch[now][i]]; if(!vis[ch[now][i]]) dfs(ch[now][i]); } } Matrix ksm(Matrix a,long long b) { Matrix res(1); while(b) { if(b&1) res=res*a; a=a*a; b>>=1; } return res; } void solve() { for(int i=1;i<=tot;i++) f[1][i]=0; dfs(1); f=f*ksm(T,n); printf("%lld\n",f[1][1]); } #endif signed main() { ios::sync_with_stdio(false); cin>>n>>m; for(int i=1;i<=m;i++) cin>>a[i]; for(int i=1;i<=m;i++) { cin>>s; Trie::insert(s,a[i]); } make_fail(); #ifdef _70pts_ printf("%lld\n",dp(1,n)); #else solve(); #endif return 0; }