ZigZagK的博客
[原根+NTT+快速幂]BZOJ3992(SDOI2015)【序列统计】题解
2019年1月7日 19:03
BZOJ
查看标签

题目概述

有 $S$ 个数,取 $n$ 次(可重复取),将得到的数乘起来模 $m$ 为 $x$ 的概率。

解题报告

做过这道题的弱化版……这道题只需要把循环矩乘换成NTT就行了。

我原来以为这道题是循环卷积,被法老D了之后发现这只是NTT的应用而已……

$c_i=\sum_{j+k\ mod\ m-1=i}a_jb_k$ ,所以只要先做一遍NTT,然后把 $>m-1$ 的加回前面就行了。

示例程序

#include<cmath>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxs=8000,maxt=1<<14,MOD=1004535809;

int n,m,X,S,a[maxs+5],Lg[maxs+5],t,rev[maxt+5],A[maxt+5],B[maxt+5],C[maxt+5],pw[2][maxt+5];
struct Matrix {int s[maxs+5];inline void Unit() {for (int i=0;i<m;i++) s[i]=0;s[0]=1;}} T;

inline int Pow(int w,int b,int MOD) {int s=1;for (;b;b>>=1,w=(LL)w*w%MOD) b&1?s=(LL)s*w%MOD:0;return s;}
inline int getG(int p,int ans=1){
    for (int s=p-1;;ans++,s=p-1){
        for (int i=2,S=sqrt(s);s>1&&i<=S;i++)
            if (!(s%i)) {if (Pow(ans,(p-1)/i,p)==1) goto END;while (!(s%i)) s/=i;}
        if (s>1&&Pow(ans,(p-1)/s,p)==1) goto END;break;END:;
    }return ans;
}
inline void Pre(int n){
    for (int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    for (int i=2;i<=n;i<<=1) pw[0][i]=Pow(3,(MOD-1)/i,MOD),pw[1][i]=Pow(pw[0][i],MOD-2,MOD);
}
inline void AMOD(int &x,int y) {if ((x+=y)>=MOD) x-=MOD;}
inline void NTT(int *a,int n,int f){
    for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int k=1;k<n;k<<=1){
        int gn=pw[f<0][k<<1],g=1,x,y;
        for (int i=0;i<n;i+=k<<1,g=1)
            for (int j=0;j<k;j++,g=(LL)g*gn%MOD)
                x=a[i+j],y=(LL)a[i+j+k]*g%MOD,AMOD(a[i+j]=x,y),AMOD(a[i+j+k]=x,MOD-y);
    }if (f<0) for (int inv=Pow(n,MOD-2,MOD),i=0;i<n;i++) a[i]=(LL)a[i]*inv%MOD;
}
Matrix operator * (const Matrix &a,const Matrix &b){
    static Matrix s;for (int i=0;i<m;i++) A[i]=a.s[i],B[i]=b.s[i];
    for (int i=m;i<t;i++) A[i]=B[i]=0;NTT(A,t,1);NTT(B,t,1);
    for (int i=0;i<t;i++) C[i]=(LL)A[i]*B[i]%MOD;NTT(C,t,-1);
    for (int i=0;i<m;i++) s.s[i]=C[i];for (int i=m;i<t;i++) AMOD(s.s[i%m],C[i]);return s;
}
inline Matrix Pow(Matrix w,int b) {static Matrix s;s.Unit();for (;b;b>>=1,w=w*w) if (b&1) s=s*w;return s;}
int main(){
    freopen("program.in","r",stdin);freopen("program.out","w",stdout);
    scanf("%d%d%d%d",&n,&m,&X,&S);for (int i=1;i<=S;i++) scanf("%d",&a[i]);
    int g=getG(m);Lg[1]=0;for (int i=1,pw=g;pw>1;i++,pw=(LL)pw*g%m) Lg[pw]=i;m--;
    for (t=1;t<=(m<<1);t<<=1);Pre(t);for (int i=1;i<=S;i++) if (a[i]) T.s[Lg[a[i]]]++;
    T=Pow(T,n);printf("%d\n",T.s[Lg[X]]);return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!