ZigZagK的博客
[容斥+DP+多项式求逆]2022“杭电杯”中国大学生算法设计超级联赛(1)1010【Walk】题解
2022年7月22日 15:53
HDU
查看标签

题目概述

HDU7147

解题报告

从来没做过这种容斥,长见识了😭。

定义 $f_i$ 表示走了 $i$ 行都合法的权值和,以及 $g_i$ 表示走了 $i$ 行全非法的权值和(特殊的,如果 $i=1$ 也认为非法)。

先考虑前 $n-1$ 行都合法,则权值为 $f_{n-1}g_1$ ,这样有个问题,就是 $n-1$ 行到 $n$ 行时不一定合法,因此我们强制 $n-1$ 到 $n$ 非法,即减去 $f_{n-2}g_2$ ,但是类似的会发现 $f_{n-2}g_2$ 里包含了 $n-2$ 到 $n-1$ 非法的情况是我们多减的,所以加上 $f_{n-3}g_3$ 。以此类推会发现这是个容斥:

$$ f_{n}=\sum_{i=1}^{n}(-1)^{i-1}f_{n-i}g_i=\sum_{i=1}^{n}f_{n-i}(-1)^{i-1}g_{i}\\ h_{i}=(-1)^{i-1}g_i\\ f_n=\sum_{i=1}^{n}f_{n-i}h_i $$

因此 $f=f*h+1$ ,即 $f={1\over 1-h}$ 。现在的问题只剩下求出 $g$ 。

由于 $g$ 是全非法的状态,所以所有行均满足下一个 $P>y+S(S(S(y)))$ ,也就是说列是递增的,我们可以在列上考虑DP。

不难发现 $S(S(S(y)))$ 最多只有 $2$ ,因此选了一个列之后最多不选两次就可以再次选下一个列。

定义 $F_{i,0/1/2}$ 表示 选到第 $i$ 列时还剩下 $0/1/2$ 次才可以再次选 的生成函数。那么考虑转移(假设 $S(S(S(i)))=k$ ,$i$ 权值为 $v_i$ ):

$$ vx\cdot F_{i-1,0}\to F_{i,k}\\ F_{i-1,0}\to F_{i,0}\\ F_{i-1,1}\to F_{i,0}\\ F_{i-1,2}\to F_{i,1} $$

不难发现可以利用矩阵进行转移,进一步的,我们把每个位置的转移矩阵 $M_i$ 存下来之后做分治矩阵乘法就可以得到 $F_{m}$ 。

令 $G=F_{m,0}+F_{m,1}+F_{m,2}$ ,则 $g_i=G_i$ ,求出 $h$ 之后多项式求逆就可以得到 $f$ 。

最后 $f_n$ 就是答案。

示例程序

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=100000,maxt=1<<18,MOD=998244353;

int n,m,a[maxn+5];
int wn[maxt+5],tem[maxt+5];
struct Matrix{
    PN s[3][3];
    void zero() {for (int i=0;i<3;i++) for (int j=0;j<3;j++) s[i][j]={0};}
};
Matrix M[maxn+5],res;
PN G;int g[maxt+5],f[maxt+5];

inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void Make(){
    int x=Pow(3,(MOD-1)/maxt);
    wn[maxt/2]=1;
    for (int i=maxt/2+1;i<maxt;i++) wn[i]=MUL(wn[i-1],x);
    for (int i=maxt/2-1;i>=0;i--) wn[i]=wn[i<<1];
}
void NTT(int *a,int n,int f){
    if (f>0){
        for (int k=n>>1;k;k>>=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=a[i+j+k];
                    a[i+j+k]=MUL(x+MOD-y,wn[k+j]);
                    a[i+j]=ADD(x,y);
                }
    } else {
        for (int k=1;k<n;k<<=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=MUL(a[i+j+k],wn[k+j]);
                    a[i+j+k]=ADD(x,MOD-y);
                    a[i+j]=ADD(a[i+j],y);
                }
        for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
        reverse(a+1,a+n);
    }
}
PN operator * (const PN &a,const PN &b){
    static int A[maxt+5],B[maxt+5];static PN c;
    int n=a.size(),m=b.size(),t;
    for (t=1;t<n+m-1;t<<=1);
    for (int i=0;i<n;i++) A[i]=a[i];for (int i=n;i<t;i++) A[i]=0;
    for (int i=0;i<m;i++) B[i]=b[i];for (int i=m;i<t;i++) B[i]=0;
    NTT(A,t,1);NTT(B,t,1);
    for (int i=0;i<t;i++) A[i]=MUL(A[i],B[i]);NTT(A,t,-1);
    c.clear();for (int i=0;i<n+m-1;i++) c.push_back(A[i]);
    return c;
}
void operator += (PN &a,const PN &b){
    if (b.size()>a.size()) a.resize(b.size());
    for (int i=0,si=b.size();i<si;i++) a[i]=ADD(a[i],b[i]);
}
Matrix operator * (const Matrix &a,const Matrix &b){
    static Matrix c;c.zero();
    for (int i=0;i<3;i++)
        for (int j=0;j<3;j++)
            for (int k=0;k<3;k++)
                c.s[i][j]+=a.s[i][k]*b.s[k][j];
    return c;
}
Matrix Solve(int L,int R) {if (L==R) return M[L];int mid=L+(R-L>>1);return Solve(L,mid)*Solve(mid+1,R);}
void Inv(int *a,int *b,int n){ // a=1/b
    if (n==1) {a[0]=Pow(b[0],MOD-2);return;}
    Inv(a,b,n>>1);
    for (int i=0;i<n;i++) tem[i]=b[i],tem[i+n]=a[i+n]=0;
    NTT(tem,n<<1,1);NTT(a,n<<1,1);
    for (int i=0;i<(n<<1);i++) tem[i]=MUL(a[i],2+MOD-MUL(tem[i],a[i]));
    NTT(tem,n<<1,-1);for (int i=0;i<n;i++) a[i]=tem[i],a[i+n]=0;
}
int main(){
    Make();
    scanf("%d%d",&n,&m);
    for (int i=1;i<=m;i++) scanf("%d",&a[i]);
    for (int i=1;i<=m;i++){
        M[i].zero();M[i].s[1][0]=M[i].s[2][1]={1};
        if (i<16) M[i].s[0][0]={1,a[i]}; else
        if (i<65536) M[i].s[0][0]={1},M[i].s[0][1]={0,a[i]};
        else M[i].s[0][0]={1},M[i].s[0][2]={0,a[i]};
    }
    res=Solve(1,m);G={0};
    for (int i=0;i<3;i++) G+=res.s[0][i];
    g[0]=1;for (int i=1,si=G.size();i<si;i++) g[i]=(i&1?MOD-G[i]:G[i])%MOD;
    int t;for (t=1;t<=n;t<<=1);
    Inv(f,g,t);
    printf("%d\n",f[n]);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!