ZigZagK的博客
[分治+矩乘NTT]2022牛客暑期多校训练营(加赛)L【Lndjy and the mex】题解
2022年10月13日 20:43
牛客
查看标签

题目概述

Lndjy and the mex

解题报告

首先不难发现长度为 $len$ 的区间的方案数是相同的,因此只需要考虑区间长度而不需要考虑区间位置(乘 $n-len+1$ 即可)。

然后考虑长度为 $len$ 的区间中,如果 $\text{mex}$ 为 $M$ ,则 $[0,M-1]$ 一定出现过了,$M$ 一定没有出现,$[M+1,n]$ 随意出不出现。

设 $c_i$ 表示 $i$ 数字在区间中出现了多少个,则贡献为:

$$ (n-len+1)M\sum_{\sum c_i=len}{len!\over\prod c_i!}{(n-len)!\over \prod(a_i-c_i)!} $$

不难发现可以转成指数型生成函数的形式:

$$ A_i(x)=\sum_{j=1}^{a_i}{1\over j!(a_i-j)!}x^j,B_i(x)=\sum_{j=0}^{a_i}{1\over j!(a_i-j)!}x^j\\ (n-len+1)(n-len)!{M\over a_M!}[{x^{len}\over len!}]\prod_{i=0}^{M-1}A_i(x)\prod_{i=M+1}^{n}B_i(x) $$

如果枚举 $M$ ,则从多项式中除掉 $A_M(x)$ 以及求前后缀多项式都是不现实的,而回退背包在这题也由于多项式次数较大而无法进行。因此只能考虑递推求出所有多项式的和,定义 $F_{i,0}(x),F_{i,1}(x)$ 表示前 $i$ 个还没定 $M$ 的多项式和,前 $i$ 个已经确定 $M$ 的多项式和,则:

$$ \begin{bmatrix}F_{i-1,0}(x)&F_{i-1,1}(x)\end{bmatrix}\begin{bmatrix}A_i(x)&{i\over a_i!}\\0&B_i(x)\end{bmatrix}=\begin{bmatrix}F_{i,0}(x)&F_{i,1}(x)\end{bmatrix} $$

因此可以用分治矩乘NTT来快速求出 $F_{n,0},F_{n,1}$,则答案为(因为这题一定有一个没出现,即 $M\le n$ 所以不用考虑 $F_{n,0}$ ):

$$ \sum_{len=1}^{n}(n-len+1)(n-len)![{x^{len}\over len!}]F_{n,1}(x) $$

示例程序

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=100000,maxt=1<<17,MOD=998244353;

int n,a[maxn+5],fac[maxn+5],INV[maxn+5],ans;
int wn[maxt+5],temA[maxt+5],temB[maxt+5];
struct Matrix{
    PN s[2][2];
    void zero() {s[0][0]=s[0][1]=s[1][0]=s[1][1]={0};}
}M[maxn+5];
PN F;

inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void NTTPre(){
    int x=Pow(3,(MOD-1)/maxt);
    wn[maxt>>1]=1;
    for (int i=(maxt>>1)+1;i<maxt;i++) wn[i]=MUL(wn[i-1],x);
    for (int i=(maxt>>1)-1;i;i--) wn[i]=wn[i<<1];
}
void NTT(int *a,int n,int f){
    if (f>0){
        for (int k=n>>1;k;k>>=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=a[i+j+k];
                    a[i+j+k]=MUL(x+MOD-y,wn[k+j]);
                    a[i+j]=ADD(x,y);
                }
    } else {
        for (int k=1;k<n;k<<=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=MUL(a[i+j+k],wn[k+j]);
                    a[i+j+k]=ADD(x,MOD-y);
                    a[i+j]=ADD(x,y);
                }
        for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
        reverse(a+1,a+n);
    }
}
inline PN operator + (const PN &a,const PN &b){
    static PN c;c.resize(max(a.size(),b.size()));
    for (int i=0;i<c.size();i++) c[i]=ADD(i<a.size()?a[i]:0,i<b.size()?b[i]:0);
    return c;
}
PN operator * (const PN &a,const PN &b){
    static PN c;
    int n=a.size(),m=b.size(),t;
    for (t=1;t<n+m-1;t<<=1);
    for (int i=0;i<n;i++) temA[i]=a[i];for (int i=n;i<t;i++) temA[i]=0;
    for (int i=0;i<m;i++) temB[i]=b[i];for (int i=m;i<t;i++) temB[i]=0;
    NTT(temA,t,1);NTT(temB,t,1);
    for (int i=0;i<t;i++) temA[i]=MUL(temA[i],temB[i]);
    NTT(temA,t,-1);
    c.resize(n+m-1);for (int i=0;i<n+m-1;i++) c[i]=temA[i];
    return c;
}
Matrix operator * (const Matrix &a,const Matrix &b){
    static Matrix c;c.zero();
    for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
            for (int k=0;k<2;k++)
                c.s[i][j]=c.s[i][j]+a.s[i][k]*b.s[k][j];
    return c;
}
void Make(int n){
    INV[0]=INV[1]=1;for (int i=2;i<=n;i++) INV[i]=MUL(MOD-MOD/i,INV[MOD%i]);
    fac[0]=1;for (int i=1;i<=n;i++) fac[i]=MUL(fac[i-1],i),INV[i]=MUL(INV[i-1],INV[i]);
}
Matrix Solve(int L,int R){
    if (L==R) return M[L];
    int mid=L+(R-L>>1);
    return Solve(L,mid)*Solve(mid+1,R);
}
int main(){
    scanf("%d",&n);
    NTTPre();Make(n);
    for (int i=0;i<=n;i++){
        scanf("%d",&a[i]);
        M[i].s[0][0].resize(a[i]+1);M[i].s[0][0][0]=0;
        for (int j=1;j<=a[i];j++) M[i].s[0][0][j]=MUL(INV[a[i]-j],INV[j]);
        M[i].s[0][1]={MUL(i,INV[a[i]])};M[i].s[1][0]={0};
        M[i].s[1][1]=M[i].s[0][0];M[i].s[1][1][0]=MUL(INV[a[i]],INV[0]);
    }
    F=Solve(0,n).s[0][1];
    for (int i=1,si=F.size();i<=n && i<si;i++)
        ans=ADD(ans,MUL(MUL(F[i],fac[i]),MUL(fac[n-i],n-i+1)));
    printf("%d\n",ans);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!