ZigZagK的博客
[期望DP+分治FWT]2021牛客暑期多校训练营6 D【Gambling Monster】题解
2021年8月6日 14:49
牛客
查看标签

题目概述

Gambling Monster

解题报告

显然是个期望DP,倒着考虑正常一点(因为在 $n-1$ 处结束,步数为 $0$ ),所以我们倒着DP。

定义 $f(i)$​ 表示 $i$​ 走到 $n-1$​ 状态的期望步数,$p(i)$​ 表示选到 $i$​ 的概率,那么:

$$ f(i)=\sum_{j\le i}[f(i)+1]\cdot p(i\ xor\ j)+\sum_{j>i}[f(j)+1]\cdot p(i\ xor\ j) $$

第一部分是 $i$ 原地踏步的期望,第二部分是上一次为 $j$ 的期望。

令 $s(i)=\sum_{j>i}p(i\ xor\ j)$ ,整理一下:

$$ f(i)s(i)=1+\sum_{j>i}f(j)\cdot p(i\ xor\ j)\\ f(i)s(i)=1+\sum_{j\ xor\ k=i(j>i)}f(j)\cdot p(k) $$

不难发现这是个异或卷积,但是有 $j>i$ 的限制,因此可以用分治+FWT解决。

同时 $s(i)=\sum_{j\ xor\ k=i(j>i)}p(k)$​ ,也是个分治+FWT。


分治+FWT 和 分治+FFT的实现类似,当分治到 $[L,R]$ ,对应二进制位为 $d$ 的时候,不难发现 $[L,mid]$ 第 $d$ 位是 $0$ ,$[mid+1,R]$ 第 $d$ 位是 $1$ ,且 $d$ 更高位相同。同时,由于 $d$ 更高位相同,所以 $p(k)$ 的 $k$ 中 $d$ 的更高位一定是 $0$ ,而 $d$ 位一定是 $1$ 。因此只需要考虑 $d$ 低位的 $2^d$ 个数即可。

示例程序

#include<cstdio>
using namespace std;
typedef long long LL;
const int maxn=1<<16,MOD=1e9+7;

int te,n,m,sum,p[maxn+5],s[maxn+5],f[maxn+5],A[maxn+5],B[maxn+5];

#define ADD(x,y) (((x)+(y))%MOD)
#define MUL(x,y) ((LL)(x)*(y)%MOD)
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void FWT(int *a,int n,int f){
    for (int k=1;k<n;k<<=1)
        for (int i=0;i<n;i+=k<<1)
            for (int j=0,x,y;j<k;j++)
                x=a[i+j],y=a[i+j+k],a[i+j]=ADD(x,y),a[i+j+k]=ADD(x,MOD-y);
    if (f<0) for (int i=0,INV=Pow(n,MOD-2);i<n;i++) a[i]=MUL(a[i],INV);
}
void SolveS(int L,int R,int d){
    if (L==R) return;
    int mid=L+(R-L>>1);
    SolveS(mid+1,R,d-1);

    for (int i=0;i<(1<<d);i++) A[i]=1;
    for (int i=0;i<(1<<d);i++) B[i]=p[i+(1<<d)];
    FWT(A,1<<d,1);FWT(B,1<<d,1);
    for (int i=0;i<(1<<d);i++) A[i]=MUL(A[i],B[i]);
    FWT(A,1<<d,-1);
    for (int i=L;i<=mid;i++) s[i]=ADD(s[i],A[i-L]);

    SolveS(L,mid,d-1);
}
void SolveF(int L,int R,int d){
    if (L==R) {f[L]=ADD(f[L],s[L]);return;}
    int mid=L+(R-L>>1);
    SolveF(mid+1,R,d-1);

    for (int i=mid+1;i<=R;i++) A[i-mid-1]=f[i];
    for (int i=0;i<(1<<d);i++) B[i]=p[i+(1<<d)];
    FWT(A,1<<d,1);FWT(B,1<<d,1);
    for (int i=0;i<(1<<d);i++) A[i]=MUL(A[i],B[i]);
    FWT(A,1<<d,-1);
    for (int i=L;i<=mid;i++) f[i]=ADD(f[i],MUL(A[i-L],s[i]));

    SolveF(L,mid,d-1);
}
int main(){
    for (scanf("%d",&te);te;te--){
        scanf("%d",&n);sum=0;
        for (int i=0;i<n;i++) scanf("%d",&p[i]),sum+=p[i];
        sum=Pow(sum,MOD-2);for (int i=0;i<n;i++) p[i]=MUL(p[i],sum);
        for (m=1;(1<<m)<n;m++);m--;
        for (int i=0;i<n;i++) s[i]=f[i]=0;
        SolveS(0,n-1,m);
        for (int i=0;i<n;i++) s[i]=Pow(s[i],MOD-2);
        SolveF(0,n-1,m);
        printf("%d\n",f[0]);
    }
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!