ZigZagK的博客
[背包回退+二进制状压DP+NTT]2022牛客暑期多校训练营1 H【Fly】题解
2022年7月19日 14:10
牛客
查看标签

题目概述

Fly

解题报告

首先想到DP:$f_{i,j,0/1/2}$ 表示前 $i$ 个二进制位(从低位到高位),到 $i$ 这位时背包大小为 $j$ ,并且状态为 $0$ 小于 $1$ 相等 $2$ 大于的方案数。

暴力做法:对于二进制位 $i$ ,求出没有限制的那些物品的 $a_i$ 的背包(分治NTT)$F$ ,然后将 $F$ 与 $f_{i-1,j,0/1/2}$ 做背包(NTT),得到这一位的状态。复杂度 $O(60\cdot2^{16}\cdot16\log_2n+60\cdot2^{17}\cdot17)$ ,看起来很不可过,但是如果使用奇技淫巧:由于 $\sum a_i=40000$ ,因此相同的 $a_i$ 不超过 $283$ 个,把相同的 $a_i$ 先做NTT得到对应的背包,然后再分治NTT求出总背包 $F$ 。这样可以减少分治NTT一半的层数,勉强通过此题。


为了避免每次都做分治NTT,我们采用背包回退。由于限制个数只有 $5000$ 个,因此只会回退 $5000$ 次,每次回退复杂度 $40000$ ,复杂度可以接受。

回退代码非常的简单:

for (int i=a[x];i<=maxn;i++) g[i]=ADD(g[i],MOD-g[i-a[x]]);

进一步优化,我们发现我们最终只需要 $0,1$ 两个状态。定义 $f_{i,j}$ 表示到 $i$ 这位时之前背包容量 $V-m$ 的值为 $j$ 的方案数,但是 $j=0$ 时,表示的不是 $V-m=0$ 而是 $V-m\le 0$ 。

每次将 $f_{i,j}$ 与 $F$ 做背包得到 $G$ ,此时我们需要考虑转移到下一位 $i+1$ 。

令 $m[i],V[i]$ 表示二进制第 $i$ 位。

  • $m[i]=0,V[i]=0$ :$G_j\to f_{i+1,j/2}$
  • $m[i]=0,V[i]=1$ :$G_j\to f_{i+1,(j+1)/2}$
  • $m[i]=1,V[i]=0$ :$G_{j}\to f_{i+1,j/2}$
    此时说明不管 $i$ 这位之前的 $V$ 是怎么样的,$V$ 一定比 $m$ 小,因此不需要考虑 $-1$ ,直接把 $i$ 位删除(即,认为 $i$ 前面都是 $0$ )
  • $m[i]=1,V[i]=1$ :$G_j\to f_{i+1,j/2}$

由于第三种情况时我们直接把后面全当成 $0$ ,因此 $j=0$ 的状态表示 $V-m\le 0$ 而不是 $V-m=0$ 。

这样最后我们只需要询问 $f_{60,0}$ 就可以得到 $V\le m$ 的答案,减少了 $3$ 倍常数。

示例程序

暴力做法+奇技淫巧

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
typedef unsigned long long ULL;
const int maxn=40000,LOG=60,maxs=1<<17,maxt=1<<17,MOD=998244353,BA=23333;

int n,K,a[maxn+5],M[LOG+5];LL m;
vector<int> e[LOG+5];
int ti,vis[maxn+5];
int w[maxt+5],A[maxt+5],B[maxt+5];
int tot,cnt[maxn+5];PN p[maxn+5],P;
int f[LOG+5][maxs+5][3],ans; // 0<,1=,2>

inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void Make(){
    int x=Pow(3,(MOD-1)/maxt);
    w[maxt/2]=1;
    for (int i=maxt/2+1;i<maxt;i++) w[i]=MUL(w[i-1],x);
    for (int i=maxt/2-1;i>=0;i--) w[i]=w[i<<1];
}
void NTT(int *a,int n,int f){
    if (f>0){
        for (int k=n>>1;k;k>>=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=a[i+j+k];
                    a[i+j+k]=MUL(x+MOD-y,w[k+j]);
                    a[i+j]=ADD(x,y);
                }
    } else {
        for (int k=1;k<n;k<<=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=MUL(a[i+j+k],w[k+j]);
                    a[i+j+k]=ADD(x,MOD-y);
                    a[i+j]=ADD(a[i+j],y);
                }
        for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
        reverse(a+1,a+n);
    }
}
PN operator * (const PN &a,const PN &b){
    static int A[maxt+5],B[maxt+5];static PN c;
    int n=a.size(),m=b.size(),t;
    for (t=1;t<n+m-1;t<<=1);
    for (int i=0;i<n;i++) A[i]=a[i];for (int i=n;i<t;i++) A[i]=0;
    for (int i=0;i<m;i++) B[i]=b[i];for (int i=m;i<t;i++) B[i]=0;
    NTT(A,t,1);NTT(B,t,1);
    for (int i=0;i<t;i++) A[i]=MUL(A[i],B[i]);NTT(A,t,-1);
    c.clear();for (int i=0;i<n+m-1;i++) c.push_back(A[i]);
    return c;
}
inline bool cmp(const int &i,const int &j) {return p[i].size()>p[j].size();}
PN Solve(int L,int R) {if (L==R) return p[L];int mid=L+(R-L>>1);return Solve(L,mid)*Solve(mid+1,R);}
int main(){
    Make();
    scanf("%d%lld%d",&n,&m,&K);
    for (int i=1;i<=n;i++) scanf("%d",&a[i]);
    for (int i=1,x,y;i<=K;i++) scanf("%d%d",&x,&y),e[y].push_back(x);
    for (int i=0;i<LOG;i++) M[i]=m>>i&1;
    for (int t=0;t<LOG;t++){
        ti++;for (auto x:e[t]) vis[x]=ti;
        for (int i=1;i<=maxn;i++) cnt[i]=0;
        for (int i=1;i<=n;i++) if (vis[i]<ti) cnt[a[i]]++;
        tot=0;
        for (int a=1;a<=maxn;a++)
            if (cnt[a]){
                int lim=a*cnt[a];
                for (int i=0;i<=lim;i++) A[i]=0;
                A[0]=A[a]=1;
                if (cnt[a]>1){
                    int len;for (len=1;len<=lim;len<<=1);
                    for (int i=lim+1;i<len;i++) A[i]=0;
                    NTT(A,len,1);
                    for (int i=0;i<len;i++) A[i]=Pow(A[i],cnt[a]);
                    NTT(A,len,-1);
                }
                tot++;p[tot].resize(lim+1);
                for (int i=0;i<=lim;i++) p[tot][i]=A[i];
            }
        if (!tot) P.resize(1),P[0]=1;
        else P=Solve(1,tot);
        int si=P.size();
        if (t==0){
            for (int i=0;i<si;i++)
                if ((i&1)>M[t]) f[t][i][2]=P[i]; else
                if ((i&1)<M[t]) f[t][i][0]=P[i];
                else f[t][i][1]=P[i];
            continue;
        }
        for (int i=0;i<si;i++) A[i]=P[i];
        for (int i=si;i<maxt;i++) A[i]=0;
        NTT(A,maxt,1);
        for (int j=0;j<3;j++){
            for (int i=0;i<maxt;i++) B[i]=0;
            for (int i=0;i<maxs;i++) B[i>>1]=ADD(B[i>>1],f[t-1][i][j]);
            NTT(B,maxt,1);
            for (int i=0;i<maxt;i++) B[i]=MUL(A[i],B[i]);
            NTT(B,maxt,-1);
            for (int i=0;i<maxs;i++)
                if ((i&1)>M[t]) f[t][i][2]=ADD(f[t][i][2],B[i]); else
                if ((i&1)<M[t]) f[t][i][0]=ADD(f[t][i][0],B[i]);
                else f[t][i][j]=ADD(f[t][i][j],B[i]);
        }
    }
    ans=ADD(f[LOG-1][0][0],f[LOG-1][0][1]);
    ans=ADD(ans,ADD(f[LOG-1][1][0],f[LOG-1][1][1]));
    printf("%d\n",ans);
    return 0;
}

背包回退+优化DP

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
typedef unsigned long long ULL;
const int maxn=40000,LOG=60,maxt=1<<17,MOD=998244353;

int n,K,a[maxn+5],M[LOG+5];LL m;
vector<int> e[LOG+5];
int ti,vis[maxn+5];
int w[maxt+5],A[maxt+5],B[maxt+5];
PN p[maxn+5],P;
int F[maxn+5],g[maxn+5];
int f[LOG+5][maxn+5]; // 0<,1=,2>

inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void Make(){
    int x=Pow(3,(MOD-1)/maxt);
    w[maxt/2]=1;
    for (int i=maxt/2+1;i<maxt;i++) w[i]=MUL(w[i-1],x);
    for (int i=maxt/2-1;i>=0;i--) w[i]=w[i<<1];
}
void NTT(int *a,int n,int f){
    if (f>0){
        for (int k=n>>1;k;k>>=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=a[i+j+k];
                    a[i+j+k]=MUL(x+MOD-y,w[k+j]);
                    a[i+j]=ADD(x,y);
                }
    } else {
        for (int k=1;k<n;k<<=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=MUL(a[i+j+k],w[k+j]);
                    a[i+j+k]=ADD(x,MOD-y);
                    a[i+j]=ADD(a[i+j],y);
                }
        for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
        reverse(a+1,a+n);
    }
}
PN operator * (const PN &a,const PN &b){
    static int A[maxt+5],B[maxt+5];static PN c;
    int n=a.size(),m=b.size(),t;
    for (t=1;t<n+m-1;t<<=1);
    for (int i=0;i<n;i++) A[i]=a[i];for (int i=n;i<t;i++) A[i]=0;
    for (int i=0;i<m;i++) B[i]=b[i];for (int i=m;i<t;i++) B[i]=0;
    NTT(A,t,1);NTT(B,t,1);
    for (int i=0;i<t;i++) A[i]=MUL(A[i],B[i]);NTT(A,t,-1);
    c.clear();for (int i=0;i<n+m-1;i++) c.push_back(A[i]);
    return c;
}
inline bool cmp(const int &i,const int &j) {return p[i].size()>p[j].size();}
PN Solve(int L,int R) {if (L==R) return p[L];int mid=L+(R-L>>1);return Solve(L,mid)*Solve(mid+1,R);}
int main(){
    Make();
    scanf("%d%lld%d",&n,&m,&K);
    for (int i=1;i<=n;i++) scanf("%d",&a[i]),p[i].resize(a[i]+1),p[i][0]=p[i][a[i]]=1;
    for (int i=1,x,y;i<=K;i++) scanf("%d%d",&x,&y),e[y].push_back(x);
    for (int i=0;i<LOG;i++) M[i]=m>>i&1;
    P=Solve(1,n);
    for (int i=0;i<P.size();i++) F[i]=P[i];
    f[0][0]=1;
    for (int t=0;t<LOG;t++){
        for (int i=0;i<=maxn;i++) g[i]=F[i];ti++;
        for (auto x:e[t])
            if (vis[x]<ti){
                vis[x]=ti;
                for (int i=a[x];i<=maxn;i++) g[i]=ADD(g[i],MOD-g[i-a[x]]);
            }
        for (int i=0;i<=maxn;i++) A[i]=g[i];
        for (int i=maxn+1;i<maxt;i++) A[i]=0;
        NTT(A,maxt,1);
        for (int i=0;i<=maxn;i++) B[i]=f[t][i];
        for (int i=maxn+1;i<maxt;i++) B[i]=0;
        NTT(B,maxt,1);
        for (int i=0;i<maxt;i++) B[i]=MUL(A[i],B[i]);
        NTT(B,maxt,-1);
        for (int i=0;i<=(maxn<<1);i++)
            if (M[t]) f[t+1][i>>1]=ADD(f[t+1][i>>1],B[i]);
            else f[t+1][i+1>>1]=ADD(f[t+1][i+1>>1],B[i]);
    }
    printf("%d\n",f[LOG][0]);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!