这题感觉有利于理解FWT对序列本质上的影响。
链接:BZOJ 4589
1<=n<=10^9, 2<=m<=50000。 80组数据
我看到是没想到怎么用FWT做,看了TT的题解懂得。
感觉还是和FFT之类一样,要有往这个方向想的思维才可以。
首先,定义函数 f ( x ) = a 0 + a 1 x + a 2 x 2 + . . . + a n x n f(x) = a_0 + a_1x + a_2x^2+...+a_nx^n f(x)=a0+a1x+a2x2+...+anxn 初始化 a i a_i ai为 [ i i s p r i m e ] [i is prime] [iisprime]
那么f(x)和自己做一次异或FWT得到的结果: F ( 2 k ) = ∑ i ⨁ j = k a i ∗ b j F\binom{2}{k} = \sum_{i\bigoplus j = k}a_i*b_j F(k2)=∑i⨁j=kai∗bj 这个结果的意义是什么?
我们先假设一个这是fft的结果,那么这个意义就是,两个素数的和为k的组合的个数。
那么换到FWT也是一样的,两个素数的异或值为K的组合个数!
那么我们继续,用$F\binom{2}{k} 和 和 和f$做一次FWT,得到的结果意义就是,三个素数的异或值为K的组合个数
通过读题我们知道,我们要求的就是n个素数的异或值为0的组合个数,那么经过FWT的转化就变成了求a序列n次FWT的结果中的a[0]的值。
由于n较大,我们发现在中间计算点值的时候因为自己和自己做卷积,所以实际上n次卷积就是a[i]^n,那么我们用快速幂来优化一下就可以了。
fwt(a,len); for(int i=0;i<len;i++) { a[i] = ksm(a[i],n); } ifwt(a,len);完整代码
#include <bits/stdc++.h> #define mem(a,b) memset((a),(b),sizeof(a)) using namespace std; typedef long long ll; const int N = (1<<20) + 10; const ll mod = 1e9+7; void fwt(ll a[],int n){ for(int d=1;d<n;d<<=1){ for(int m=d<<1,i=0;i<n;i+=m){ for(int j=0;j<d;j++){ ll x=a[i+j],y=a[i+j+d]; a[i+j] = (x + y)%mod, a[i+j+d] = (x - y + mod) % mod; } } } //xor : a[i+j] = x + y, a[i+j+d] = (x - y + mod) % mod; //and : a[i+j] = x + y; //or : a[i+j+d] = x + y; } ll inv2; void ifwt(ll a[],int n){ for(int d=1;d<n;d<<=1){ for(int m=d<<1,i=0;i<n;i+=m){ for(int j=0;j<d;j++){ ll x=a[i+j],y=a[i+j+d]; a[i+j] = (x + y) * inv2 % mod, a[i+j+d] = ((x - y + mod)%mod) * inv2 % mod; } } } //inv2 = 2^(-1) //xor : a[i+j] = (x + y) / 2, a[i+j+d] = (x - y) / 2; //and : a[i+j] = x - y; //or : a[i+j+d] = y - x; } const int MAXN = 66666; ll vis[MAXN + 10]; void init() { for(int i = 2;i < MAXN;i ++) { if(!vis[i]) { for(int j = i+i;j < MAXN;j += i) vis[j] = 1; } } } ll ksm(ll a,ll b) { ll ret = 1; while(b) { if(b & 1) ret = ret * a % mod; a = a * a % mod; b >>= 1; } return ret; } ll a[N],b[N],c[N],d[N]; int main(){ inv2 = 1ll*(mod+1ll)/2ll;//2的逆元 int n,m; init(); while(scanf("%d%d",&n,&m) == 2) { mem(a,0); m ++; int len = 1; while(len < m) len <<= 1; for(int i = 2;i < m;i ++) { a[i] = !vis[i];//因为有效位只有m个,不要一直到len // printf("i:%d %lld\n",i,a[i]); } fwt(a,len); for(int i=0;i<len;i++) { a[i] = ksm(a[i],n); } ifwt(a,len); printf("%lld\n",a[0]); } } /* 3 7 4 13 */