显然是个期望DP,倒着考虑正常一点(因为在 $n-1$ 处结束,步数为 $0$ ),所以我们倒着DP。
定义 $f(i)$ 表示 $i$ 走到 $n-1$ 状态的期望步数,$p(i)$ 表示选到 $i$ 的概率,那么:
$$ f(i)=\sum_{j\le i}[f(i)+1]\cdot p(i\ xor\ j)+\sum_{j>i}[f(j)+1]\cdot p(i\ xor\ j) $$
第一部分是 $i$ 原地踏步的期望,第二部分是上一次为 $j$ 的期望。
令 $s(i)=\sum_{j>i}p(i\ xor\ j)$ ,整理一下:
$$ f(i)s(i)=1+\sum_{j>i}f(j)\cdot p(i\ xor\ j)\\ f(i)s(i)=1+\sum_{j\ xor\ k=i(j>i)}f(j)\cdot p(k) $$
不难发现这是个异或卷积,但是有 $j>i$ 的限制,因此可以用分治+FWT解决。
同时 $s(i)=\sum_{j\ xor\ k=i(j>i)}p(k)$ ,也是个分治+FWT。
分治+FWT 和 分治+FFT的实现类似,当分治到 $[L,R]$ ,对应二进制位为 $d$ 的时候,不难发现 $[L,mid]$ 第 $d$ 位是 $0$ ,$[mid+1,R]$ 第 $d$ 位是 $1$ ,且 $d$ 更高位相同。同时,由于 $d$ 更高位相同,所以 $p(k)$ 的 $k$ 中 $d$ 的更高位一定是 $0$ ,而 $d$ 位一定是 $1$ 。因此只需要考虑 $d$ 低位的 $2^d$ 个数即可。
#include<cstdio>
using namespace std;
typedef long long LL;
const int maxn=1<<16,MOD=1e9+7;
int te,n,m,sum,p[maxn+5],s[maxn+5],f[maxn+5],A[maxn+5],B[maxn+5];
#define ADD(x,y) (((x)+(y))%MOD)
#define MUL(x,y) ((LL)(x)*(y)%MOD)
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void FWT(int *a,int n,int f){
for (int k=1;k<n;k<<=1)
for (int i=0;i<n;i+=k<<1)
for (int j=0,x,y;j<k;j++)
x=a[i+j],y=a[i+j+k],a[i+j]=ADD(x,y),a[i+j+k]=ADD(x,MOD-y);
if (f<0) for (int i=0,INV=Pow(n,MOD-2);i<n;i++) a[i]=MUL(a[i],INV);
}
void SolveS(int L,int R,int d){
if (L==R) return;
int mid=L+(R-L>>1);
SolveS(mid+1,R,d-1);
for (int i=0;i<(1<<d);i++) A[i]=1;
for (int i=0;i<(1<<d);i++) B[i]=p[i+(1<<d)];
FWT(A,1<<d,1);FWT(B,1<<d,1);
for (int i=0;i<(1<<d);i++) A[i]=MUL(A[i],B[i]);
FWT(A,1<<d,-1);
for (int i=L;i<=mid;i++) s[i]=ADD(s[i],A[i-L]);
SolveS(L,mid,d-1);
}
void SolveF(int L,int R,int d){
if (L==R) {f[L]=ADD(f[L],s[L]);return;}
int mid=L+(R-L>>1);
SolveF(mid+1,R,d-1);
for (int i=mid+1;i<=R;i++) A[i-mid-1]=f[i];
for (int i=0;i<(1<<d);i++) B[i]=p[i+(1<<d)];
FWT(A,1<<d,1);FWT(B,1<<d,1);
for (int i=0;i<(1<<d);i++) A[i]=MUL(A[i],B[i]);
FWT(A,1<<d,-1);
for (int i=L;i<=mid;i++) f[i]=ADD(f[i],MUL(A[i-L],s[i]));
SolveF(L,mid,d-1);
}
int main(){
for (scanf("%d",&te);te;te--){
scanf("%d",&n);sum=0;
for (int i=0;i<n;i++) scanf("%d",&p[i]),sum+=p[i];
sum=Pow(sum,MOD-2);for (int i=0;i<n;i++) p[i]=MUL(p[i],sum);
for (m=1;(1<<m)<n;m++);m--;
for (int i=0;i<n;i++) s[i]=f[i]=0;
SolveS(0,n-1,m);
for (int i=0;i<n;i++) s[i]=Pow(s[i],MOD-2);
SolveF(0,n-1,m);
printf("%d\n",f[0]);
}
return 0;
}