ZigZagK的博客
[生成函数+多项式除法求系数]2021牛客暑期多校训练营8 H【Scholomance Academy】题解
2021年9月29日 18:34
牛客
查看标签

题目概述

Scholomance Academy

解题报告

伪装成数论题的多项式全家桶题。

首先我们可以把 $\varphi(n)$ 拆成 $\prod\varphi(p_i^{a_i})$ ,因此不难证明 $F(n)$ 是个积性函数。

那么 $G(N)$ 也可以把 $F(p_i^{k_i})$ 全拆开。而 $\sum_{i=1}^{t}k_i=N$ 让我们想到生成函数:

$$ \begin{align} G_j(x)&=(\sum_{i=0}^{\infty}\varphi(p_j^i)x^i)^m=(1+{p_j-1\over p_j}\sum_{i=1}^{\infty}p_j^ix^i)^m\\&=(1+{p_j-1\over p_j}{p_jx\over 1-p_jx})^m=({1-x\over 1-p_jx})^m\\ \end{align} $$

那么我们要求的就是 $G(x)=\prod_{i=1}^{t}G_i(x)$ 的第 $N$ 项系数。

$$ G(x)=\prod_{i=1}^{t}({1-x\over 1-p_ix})^m={(1-x)^{mt}\over(\prod_{i=1}^{t}1-p_ix)^m} $$

令 $U(x)=(1-x)^{mt},D(x)=(\prod_{i=1}^{t}1-p_ix)^m$ 。$U(x)$ 可以通过二项式展开求出,$D(x)$ 可以通过分治+NTT求出。


要求 $G(x)={U(x)\over D(x)}$ 的第 $N$ 项系数,由于 $N$ 太大了,不可能通过多项式求逆来求出,我们先从式子里找一下性质。

$$ U(x)=G(x)D(x)\\ \Rightarrow\forall n>mt,\sum_{i+j=n}g_id_j=u_n=0\\ \Leftrightarrow\forall n>mt,\sum_{i=0}^{n}d_ig_{n-i}=0\\ \Leftrightarrow\forall n>mt,\sum_{i=0}^{mt}d_ig_{n-i}=0(d_i=0,i>mt)\\ \Leftrightarrow\forall n>mt,d_0g_n+\sum_{i=1}^{mt}d_ig_{n-i}=0\\ \Leftrightarrow\forall n>mt,g_n=\sum_{i=1}^{mt}{-d_i\over d_0}g_{n-i}\\ $$

所以我们可以先用多项式求逆求出 $[0,mt]$ 范围内的 $g_i$ ,然后用常系数线性齐次递推来求出 $g_N$ 。

示例程序

这个可以当常系数线性齐次递推和多项式除法求系数的板子了= =。

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=100000,maxm=maxn*5,maxt=1<<20,MOD=998244353;

int N,n,m,U[maxt+5],D[maxt+5];PN p[maxn+5],P;
int fac[maxm+5],INV[maxm+5];
int rev[maxt+5],pw[2][maxt+5],tem[maxt+5];

#define ADD(x,y) (((x)+(y))%MOD)
#define MUL(x,y) ((LL)(x)*(y)%MOD)
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
inline void Pre(int n){
    for (int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    for (int i=2;i<=n;i<<=1) pw[0][i]=Pow(3,(MOD-1)/i),pw[1][i]=Pow(pw[0][i],MOD-2);
}
void NTT(int *a,int n,int f){
    for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int k=1;k<n;k<<=1){
        int gn=pw[f<0][k<<1],g=1,x,y;
        for (int i=0;i<n;i+=k<<1,g=1)
            for (int j=0;j<k;j++,g=MUL(g,gn))
                x=a[i+j],y=MUL(a[i+j+k],g),a[i+j]=ADD(x,y),a[i+j+k]=ADD(x,MOD-y);
    }
    if (f<0) for (int i=0,INV=Pow(n,MOD-2);i<n;i++) a[i]=MUL(a[i],INV);
}
PN operator - (const PN &A,const PN &B){
    static PN c;c.resize(max(A.size(),B.size()));
    for (int i=0;i<c.size();i++)
        c[i]=ADD(i<A.size()?A[i]:0,i<B.size()?MOD-B[i]:0);
    return c;
}
PN operator * (const PN &a,const PN &b){
    static int A[maxt+5],B[maxt+5];static PN c;
    int n=a.size(),m=b.size(),t;
    for (t=1;t<n+m-1;t<<=1);
    for (int i=0;i<n;i++) A[i]=a[i];for (int i=n;i<t;i++) A[i]=0;
    for (int i=0;i<m;i++) B[i]=b[i];for (int i=m;i<t;i++) B[i]=0;
    Pre(t);NTT(A,t,1);NTT(B,t,1);
    for (int i=0;i<t;i++) A[i]=MUL(A[i],B[i]);NTT(A,t,-1);
    c.clear();for (int i=0;i<n+m-1;i++) c.push_back(A[i]);
    return c;
}
void Inv(int *a,int *b,int n){
    if (n==1) {b[0]=Pow(a[0],MOD-2);return;}
    Inv(a,b,n>>1);
    for (int i=0;i<n;i++) tem[i]=a[i],tem[i+n]=b[i+n]=0;
    Pre(n<<1);NTT(tem,n<<1,1);NTT(b,n<<1,1);
    for (int i=0;i<(n<<1);i++) tem[i]=MUL(b[i],MOD+2-MUL(b[i],tem[i]));NTT(tem,n<<1,-1);
    for (int i=0;i<n;i++) b[i]=tem[i],b[i+n]=0;
}
PN operator / (const PN &a,const PN &b){
    static int A[maxt+5],B[maxt+5],C[maxt+5];static PN c;
    if (a.size()<b.size()) {c.clear();c.push_back(0);return c;}
    int n=a.size()-1,m=b.size()-1,t;
    for (t=1;t<n-m+1;t<<=1);t<<=1;
    for (int i=0;i<t;i++) A[i]=B[i]=0;
    for (int i=0;i<n-m+1;i++) A[i]=a[n-i],B[i]=m-i<0?0:b[m-i];
    Inv(B,C,t>>1);Pre(t);NTT(A,t,1);NTT(C,t,1);
    for (int i=0;i<t;i++) C[i]=MUL(C[i],A[i]);NTT(C,t,-1);
    c.clear();for (int i=0;i<n-m+1;i++) c.push_back(C[n-m-i]);
    return c;
}
PN operator % (const PN &a,const PN &b) {static PN c;c=a-a/b*b;c.resize(b.size()-1);return c;}
void Make(int n){
    INV[0]=INV[1]=1;for (int i=2;i<=n;i++) INV[i]=MUL(MOD-MOD/i,INV[MOD%i]);
    fac[0]=1;for (int i=1;i<=n;i++) fac[i]=MUL(fac[i-1],i),INV[i]=MUL(INV[i-1],INV[i]);
}
#define C(x,y) ((x)<(y)?0:MUL(fac[x],MUL(INV[y],INV[(x)-(y)])))
PN Solve(int L,int R){
    if (L==R) return p[L];
    int mid=L+(R-L>>1);
    return Solve(L,mid)*Solve(mid+1,R);
}
int Linear(int *f,int *a,int K,int n){
    if (n<K) return f[n];
    static PN p,w,s;p.resize(K+1);
    for (int i=0;i<K;i++) p[i]=(MOD-a[K-i])%MOD;p[K]=1;
    w.resize(2);w[0]=0;w[1]=1;w=w%p;
    for (s.resize(1),s[0]=1;n;n>>=1,w=w*w%p) if (n&1) s=s*w%p;
    int ans=0;for (int i=0,si=s.size();i<K && i<si;i++) ans=ADD(ans,MUL(s[i],f[i]));
    return ans;
}
int Asknth(int *U,int *D,int n,int N){
    static int A[maxt+5];
    int t;for (t=1;t<=n;t<<=1);t<<=1;
    Inv(D,A,t>>1);Pre(t);NTT(A,t,1);NTT(U,t,1);
    for (int i=0;i<t;i++) A[i]=MUL(A[i],U[i]);NTT(A,t,-1);
    for (int i=0;i<n;i++) A[i]=A[i+1];
    for (int i=1,INV=Pow(D[0],MOD-2);i<=n;i++) D[i]=MUL(MOD-D[i],INV);
    return Linear(A,D,n,N-1);
}
int main(){
    scanf("%d%d%d",&N,&n,&m);Make(n*m);
    for (int i=1;i<=n;i++){
        int x;scanf("%d",&x);
        p[i].resize(2);p[i][0]=1;p[i][1]=MOD-x;
    }
    P=Solve(1,n);
    for (int i=0;i<=n*m;i++) U[i]=(i&1?MOD-C(n*m,i):C(n*m,i));
    for (int i=0,si=P.size();i<si;i++) D[i]=P[i];
    int t;for (t=1;t<=n*m;t<<=1);
    Pre(t);NTT(D,t,1);
    for (int i=0;i<t;i++) D[i]=Pow(D[i],m);
    NTT(D,t,-1);
    printf("%d\n",Asknth(U,D,n*m,N));
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!