ZigZagK的博客
[线性代数+多项式取模]洛谷4723【常系数齐次线性递推】题解
2021年9月29日 18:02
洛谷
查看标签

题目概述

给出递推 $f$ 的 $[0,K-1]$ 项,给出系数 $\{a_K\}$ ,$f_n=\sum_{i=1}^{K}a_if_{n-i}(n\ge K)$ ,求第 $n$ 项。

$n\le 10^9,K\le32000$

解题报告

这种题学过一遍就不会想学第二遍了= =。

如果 $K$ 比较小的话,我们可以直接用矩阵快速幂来转移DP,但是这样复杂度是 $O(K^3\log_2n)$ 的。

我们先写出转移的矩阵形式:

$$ A=\begin{bmatrix}a_1&a_2&\cdots&a_{K-1}&a_K\\1&0&\cdots&0&0\\\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&\cdots&1&0\end{bmatrix},A\cdot\begin{bmatrix}f_{n-1}\\f_{n-2}\\\vdots\\f_{n-K}\end{bmatrix}=\begin{bmatrix}f_{n}\\f_{n-1}\\\vdots\\f_{n-K+1}\end{bmatrix} $$

我们需要求出 $A^n$ 。

设 $A$ 的特征多项式为 $F(x)=\prod_{i=1}^{K}(\lambda_i-x)$ 。

根据多项式除法可设:$x^n=D(x)F(x)+R(x)$ ,该式子可用 $A$ 代入得到 $A^n=D(A)F(A)+R(A)$ 。

根据哈密尔顿-凯莱定理,将 $A$ 代入 $A$ 的特征多项式可得 $F(A)=O$ 。

因此,$A^n=R(A)$ ,而 $R(x)=x^n\bmod F(x)$ ,可以用多项式取模求得。

不过直接将 $A$ 代进 $R(x)$ 还是需要用到矩乘,复杂度又爆炸了,我们考虑答案式子。

$$ X=\begin{bmatrix}f_{K-1}&f_{K-2}&\cdots&f_0\end{bmatrix}^{T}\\ f_n=(A^nX)_{K,1}=[R(A)X]_{K,1}=(\sum_{i=0}^{K-1}r_iA^{i}X)_{K,1}=\sum_{i=0}^{K-1}r_i(A^iX)_{K,1}=\sum_{i=0}^{K-1}r_if_{i} $$

我们愉快的发现根本不需要把 $A$ 代入到 $R(x)$ 里,只需要求出 $R(x)$ 的系数就行了。


要求 $R(x)$ 的系数,首先要求出 $F(x)$ 的系数。

$$ \lambda E-A=\begin{bmatrix}\lambda-a_1&-a_2&\cdots&-a_{K-1}&-a_K\\-1&\lambda&\cdots&0&0\\\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&\cdots&-1&\lambda\end{bmatrix} $$

用第一行展开,手玩一下可以发现 $F(x)=(\lambda-a_1)\lambda^{K-1}+\sum_{i=2}^{K}(-a_i)\lambda^{K-i}=\sum_{i=0}^{K-1}(-a_{K-i})\lambda^i+\lambda^K$ 。

然后用快速幂就可以求出 $x^n\bmod F(x)$ 了。

复杂度:暴力取模 $O(K^2\log_2n)$ ,多项式取模 $O(K\log_2K\log_2n)$(常数起飞)。

示例程序

#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxk=32000,maxt=65536,MOD=998244353;

int n,K,f[maxk+5],a[maxk+5],p[maxt+5],s[maxt+5],w[maxt+5],ans;
int rev[maxt+5],pw[2][maxt+5],tem[maxt+5],P[maxt+5];

#define ADD(x,y) (((x)+(y))%MOD)
#define MUL(x,y) ((LL)(x)*(y)%MOD)
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void Pre(int n){
    for (int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    for (int i=2;i<=n;i<<=1) pw[0][i]=Pow(3,(MOD-1)/i),pw[1][i]=Pow(pw[0][i],MOD-2);
}
void NTT(int *a,int n,int f){
    for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int k=1;k<n;k<<=1){
        int gn=pw[f<0][k<<1],g=1,x,y;
        for (int i=0;i<n;i+=k<<1,g=1)
            for (int j=0;j<k;j++,g=MUL(g,gn))
                x=a[i+j],y=MUL(a[i+j+k],g),a[i+j]=ADD(x,y),a[i+j+k]=ADD(x,MOD-y);
    }
    if (f<0) for (int i=0,INV=Pow(n,MOD-2);i<n;i++) a[i]=MUL(a[i],INV);
}
void Inv(int *a,int *b,int n){
    if (n==1) {b[0]=Pow(a[0],MOD-2);return;}
    Inv(a,b,n>>1);Pre(n<<1);
    for (int i=0;i<n;i++) tem[i]=a[i],tem[i+n]=b[i+n]=0;
    NTT(tem,n<<1,1);NTT(b,n<<1,1);
    for (int i=0;i<(n<<1);i++) tem[i]=MUL(b[i],MOD+2-MUL(b[i],tem[i]));
    NTT(tem,n<<1,-1);
    for (int i=0;i<n;i++) b[i]=tem[i],b[i+n]=0;
}
void Div(int *a,int *b,int *c,int n,int m){
    static int A[maxt+5],B[maxt+5],C[maxt+5],t;
    for (t=1;t<n-m+1;t<<=1);t<<=1;
    for (int i=0;i<t;i++) A[i]=B[i]=0;
    for (int i=0;i<n-m+1;i++) A[i]=a[n-i],B[i]=m-i<0?0:b[m-i];
    Inv(B,C,t>>1);Pre(t);NTT(A,t,1);NTT(C,t,1);
    for (int i=0;i<t;i++) C[i]=MUL(C[i],A[i]);
    NTT(C,t,-1);for (int i=0;i<n-m+1;i++) c[i]=C[n-m-i];
}
void Mul(int *res,int *a,int *b){
    static int A[maxt+5],B[maxt+5],C[maxt+5],t;
    for (t=1;t<K+K-1;t<<=1);
    for (int i=0;i<K;i++) A[i]=a[i],B[i]=b[i];
    for (int i=K;i<t;i++) A[i]=B[i]=0;
    Pre(t);NTT(A,t,1);NTT(B,t,1);
    for (int i=0;i<t;i++) A[i]=MUL(A[i],B[i]);
    NTT(A,t,-1);Div(A,p,C,K+K-2,K);
    for (int i=K-1;i<t;i++) C[i]=0;
    Pre(t);NTT(C,t,1);
    for (int i=0;i<t;i++) C[i]=MUL(C[i],P[i]);
    NTT(C,t,-1);
    for (int i=0;i<K;i++) res[i]=ADD(A[i],MOD-C[i]);
}
int main(){
    scanf("%d%d",&n,&K);
    for (int i=1;i<=K;i++) scanf("%d",&a[i]),a[i]=(a[i]%MOD+MOD)%MOD;
    for (int i=0;i<K;i++) scanf("%d",&f[i]),f[i]=(f[i]%MOD+MOD)%MOD;
    if (n<K) {printf("%d\n",f[n]);return 0;}
    for (int i=0;i<K;i++) p[i]=(MOD-a[K-i])%MOD;p[K]=1;

    int t=1;for (t=1;t<K+K-1;t<<=1);
    for (int i=0;i<=K;i++) P[i]=p[i];
    Pre(t);NTT(P,t,1);

    K==1?w[0]=(MOD-p[0])%MOD:w[1]=1;
    for (s[0]=1;n;n>>=1,Mul(w,w,w)) if (n&1) Mul(s,s,w);
    for (int i=0;i<K;i++) ans=ADD(ans,MUL(s[i],f[i]));
    printf("%d\n",ans);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!