ZigZagK的博客
[指数型生成函数+分治NTT+广义容斥]LOJ6503(雅礼集训 2018 Day4)【Magic】题解
2019年3月3日 18:25
LOJ
查看标签

题目概述

有 $n​$ 种颜色的膜法卡,每种颜色有 $a_i​$ 种,总共有 $m​$ 张。现在要把所有卡片排成一排,如果相邻两个卡片颜色相同则产生一个膜法对,求膜法对个数为 $k​$ 的排列个数。

解题报告

学弟CJJ秒了,首先根据套路我们可以先讨论膜法对个数 $\ge k​$ 的方案数,然后用广义容斥来求出 $=k​$ 的方案数。

定义 $f_{i,j}$ 表示前 $i$ 个颜色,分成了 $j$ 块的方案数,那么我们枚举 $i$ 这个颜色分成 $k$ 块,进行转移:

$$ f_{i,j}=\sum_{k}f_{i-1,j-k}{a_i-1\choose k-1}{j\choose k} $$

其中 ${a_i-1\choose k-1}$ 表示 $a_i$ 个分成 $k$ 块的方案数,${j\choose k}$ 表示 $k$ 块和 $j-k$ 块互相穿插(乱排)的方案数。

这样定义状态有一个好处就是这是个指数型生成函数的卷积形式,可以分治NTT来快速转移!而且对于 $f_{n,j}$ 只要根据 $m-j$ 就可以得知至少有 $m-j$ 个膜法对的方案数( $f_{n,j}$ 已经保证了有 $m-j$ 个膜法对,在乱排的时候还会出现新的膜法对)。所以答案就是:

$$ \sum_{i=k}^{m-1}(-1)^{i-k}{i\choose k}f_{n,m-i} $$

示例程序

#include<cstdio>
#include<vector>
#include<algorithm>
#define pb push_back
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=20000,maxm=100000,maxt=1<<18,MOD=998244353;

int n,m,K,a[maxn+5],fac[maxm+5],INV[maxm+5],f[maxm+5],ans;PN p[maxn+5],F;

inline int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=(LL)w*w%MOD) if (b&1) s=(LL)s*w%MOD;return s;}
inline void AMOD(int &x,int tem) {if ((x+=tem)>=MOD) x-=MOD;}
namespace Poly{
    int rev[maxt+5],pw[2][maxt+5],tem[maxt+5];
    inline void Pre(int n){
        for (int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
        for (int i=2;i<=n;i<<=1) pw[0][i]=Pow(3,(MOD-1)/i),pw[1][i]=Pow(pw[0][i],MOD-2);
    }
    inline void NTT(int *a,int n,int f){
        for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
        for (int k=1;k<n;k<<=1){
            int gn=pw[f<0][k<<1],g=1,x,y;
            for (int i=0;i<n;i+=k<<1,g=1)
                for (int j=0;j<k;j++,g=(LL)g*gn%MOD)
                    x=a[i+j],y=(LL)a[i+j+k]*g%MOD,AMOD(a[i+j]=x,y),AMOD(a[i+j+k]=x,MOD-y);
        }if (f<0) for (int i=0,INV=Pow(n,MOD-2);i<n;i++) a[i]=(LL)a[i]*INV%MOD;
    }
    inline PN operator * (const PN &A,const PN &B){
        static int n,a[maxt+5],m,b[maxt+5],c[maxt+5];static PN C;n=A.size();m=B.size();
        for (int i=0;i<n;i++) a[i]=A[i];for (int i=0;i<m;i++) b[i]=B[i];
        if (n<=250&&m<=250){
            for (int i=0;i<n+m-1;i++) c[i]=0;C.clear();
            for (int i=0;i<n;i++) for (int j=0;j<m;j++) AMOD(c[i+j],(LL)a[i]*b[j]%MOD);
            for (int i=0;i<n+m-1;i++) C.pb(c[i]);return C;
        }
        int t;for (t=1;t<n+m-1;t<<=1);Pre(t);
        for (int i=n;i<t;i++) a[i]=0;for (int i=m;i<t;i++) b[i]=0;
        NTT(a,t,1);NTT(b,t,1);for (int i=0;i<t;i++) c[i]=(LL)a[i]*b[i]%MOD;
        NTT(c,t,-1);C.clear();for (int i=0;i<n+m-1;i++) C.pb(c[i]);return C;
    }
}
using namespace Poly;
inline void Make(int n){
    INV[0]=INV[1]=1;for (int i=2;i<=n;i++) INV[i]=(LL)(MOD-MOD/i)*INV[MOD%i]%MOD;
    fac[0]=1;for (int i=1;i<=n;i++) fac[i]=(LL)fac[i-1]*i%MOD,INV[i]=(LL)INV[i-1]*INV[i]%MOD;
}
#define C(x,y) ((x)<(y)?0:(LL)fac[x]*INV[y]%MOD*INV[(x)-(y)]%MOD)
PN Solve(int L,int R) {if (L==R) return p[L];int mid=L+(R-L>>1);return Solve(L,mid)*Solve(mid+1,R);}
int main(){
    freopen("program.in","r",stdin);freopen("program.out","w",stdout);
    scanf("%d%d%d",&n,&m,&K);Make(m);for (int i=1;i<=n;i++) scanf("%d",&a[i]);
    for (int i=1;i<=n;i++) {p[i].pb(0);for (int j=1;j<=a[i];j++) p[i].pb((LL)C(a[i]-1,j-1)*INV[j]%MOD);}
    F=Solve(1,n);for (int i=0;i<m;i++) f[i]=(LL)F[m-i]*fac[m-i]%MOD;
    for (int i=K;i<=m;i++) AMOD(ans,(LL)(i-K&1?MOD-1:1)*C(i,K)%MOD*f[i]%MOD);printf("%d\n",ans);return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!