original link - https://codeforces.com/problemset/problem/1218/E
题意:
$10$ 次查询,每次构造出长度为 $n$ 的数组 $a$,共有 $C_n^k$ 个 $k$ 元组,求 $\sum_{i=1}^{C_n^k}\prod_{j=1}^{k} a_{p_j}$(其中 $p_1<p_2<\dots<p_k$ 是第 $i$ 个 $k$ 元组的下标),也就是任意选择 $k$ 个数,求它们乘积之和。
解析:
怎么求任选 $k$ 个数的乘积之和?用生成函数。
$a_1,a_2,\dots,a_n\to(1+a_1x)(1+a_2x)\cdots(1+a_nx)$,展开后 $x^k$ 的系数就是答案。(很简单~)
把这个乘积分治计算,因为要取模,所以用 $NTT$ 做多项式乘法,时间复杂度 $O(n\log^2 n)$。
注意区间长度为 $l$ 时,乘积是 $l$ 次多项式,共有 $l+1$ 项系数(长度为 $1$ 时就已经有两项),所以 NTT 长度 $len$ 必须满足 $len\geq l+1$;取 $len\geq 2l$ 也足够。
代码:
/*
 * Author : Jk_Chen
 * Date   : 2019-10-02-14.52.51
 *
 * Codeforces 1218E.
 * For every query, build b_i = q - a_i (with the query's temporary change
 * applied) and print the coefficient of x^k in prod_{i=1}^{n} (1 + b_i x)
 * modulo 998244353, i.e. the sum over all k-subsets of the product of their
 * elements. The product is computed by divide and conquer with NTT-based
 * polynomial multiplication.
 */
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using LL = long long;

const LL mod = 998244353;  // 998244353 = 119 * 2^23 + 1, so NTT lengths up to 2^23 work
const int maxn = 4e4 + 9;  // historical capacity hint; storage below is sized from n
const int G = 3;           // root of mod

// Reads one integer from stdin, skipping any characters before the digits.
// A '-' immediately preceding the digits makes the value negative.
// Returns 0 if end of input is reached before any digit.
LL rd() {
    LL ans = 0;
    int last = ' ';
    int ch = std::getchar();
    while (ch != EOF && !(ch >= '0' && ch <= '9')) {
        last = ch;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        ans = ans * 10 + (ch - '0');
        ch = std::getchar();
    }
    return last == '-' ? -ans : ans;
}

// Returns a^b mod `mod` by binary exponentiation (b >= 0).
LL Pow(LL a, LL b) {
    LL res = 1;
    a %= mod;
    if (a < 0) a += mod;
    for (; b > 0; b >>= 1, a = a * a % mod) {
        if (b & 1) res = res * a % mod;
    }
    return res;
}

// Permutes a[0..n-1] into bit-reversed index order (n must be a power of two),
// as required by the iterative in-place NTT below.
void rader(std::vector<LL>& a, int n) {
    for (int i = 1, j = n >> 1; i < n - 1; i++) {
        if (i < j) std::swap(a[i], a[j]);
        // j is the bit reversal of i; advance it to the reversal of i + 1.
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        if (j < bit) j += bit;
    }
}

// In-place iterative NTT over a[0..n-1]; n must be a power of two dividing 2^23.
// on == 1 performs the forward transform, on == -1 the inverse transform
// (including the final division by n).
// Expects every a[i] in [0, mod) and keeps every value in [0, mod).
void NTT(std::vector<LL>& a, int n, int on) {
    rader(a, n);
    for (int h = 2; h <= n; h <<= 1) {
        const int hh = h >> 1;
        // Principal h-th root of unity, or its inverse for the inverse transform.
        const LL wn = Pow(G, on == -1 ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
        for (int i = 0; i < n; i += h) {
            LL w = 1;
            for (int j = i; j < i + hh; j++) {
                const LL x = a[j];
                const LL y = w * a[j + hh] % mod;
                a[j] = (x + y) % mod;
                a[j + hh] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if (on == -1) {
        const LL inv = Pow(n, mod - 2);
        for (int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    }
}

int n, k;
std::vector<LL> a, b;  // 1-indexed; resized to n + 1 in main

// Maps any integer to its canonical residue in [0, mod).
LL norm(LL x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

// Returns the coefficients (lowest degree first) of prod_{i=l}^{r} (1 + b[i] x).
// The result has exactly r - l + 2 entries; an empty range (l > r) yields {1}.
std::vector<LL> divide(int l, int r) {
    if (l > r) return {1};
    if (l == r) return {1, b[l]};
    const int mid = l + (r - l) / 2;
    std::vector<LL> L = divide(l, mid);
    std::vector<LL> R = divide(mid + 1, r);
    // The product has degree r - l + 1, i.e. r - l + 2 coefficients; a cyclic
    // convolution of at least that length cannot wrap around.
    const int need = r - l + 2;
    int len = 1;
    while (len < need) len <<= 1;
    L.resize(len, 0);
    R.resize(len, 0);
    NTT(L, len, 1);
    NTT(R, len, 1);
    for (int i = 0; i < len; i++) L[i] = L[i] * R[i] % mod;
    NTT(L, len, -1);
    L.resize(need);
    return L;
}

int main() {
    n = static_cast<int>(rd());
    k = static_cast<int>(rd());
    if (n < 0) n = 0;
    a.assign(n + 1, 0);
    b.assign(n + 1, 0);
    for (int i = 1; i <= n; i++) a[i] = rd();
    int queries = static_cast<int>(rd());
    while (queries-- > 0) {
        const int op = static_cast<int>(rd());
        const LL q = rd();
        // a[] is never modified: each query's change is applied only to b[].
        for (int i = 1; i <= n; i++) b[i] = norm(q - a[i]);
        if (op == 1) {
            // Type 1: a[I] is replaced by d for this query only.
            const int I = static_cast<int>(rd());
            const LL d = rd();
            if (I >= 1 && I <= n) b[I] = norm(q - d);
        } else {
            // Type 2: d is added to a[lo..hi] for this query only.
            const int lo = static_cast<int>(rd());
            const int hi = static_cast<int>(rd());
            const LL d = rd();
            for (int i = std::max(lo, 1); i <= std::min(hi, n); i++) {
                b[i] = norm(q - (a[i] + d));
            }
        }
        const std::vector<LL> ans = divide(1, n);
        // Coefficients beyond degree n (or negative k) are zero.
        const LL res = (k >= 0 && k < static_cast<int>(ans.size())) ? ans[k] : 0LL;
        std::printf("%lld\n", res);
    }
    return 0;
}