ZigZagK的博客
[链分治+分治NTT]LOJ6289【花朵】题解
2022年10月9日 19:48
LOJ
查看标签

题目概述

LOJ6289

解题报告

这题显然是一个树形背包 $f_{x,j,0},f_{x,j,1}$ 表示 $x$ 没选 / 选了,子树里选了 $j$ 个的方案数。

这个背包可以写成多项式形式,$B_x$ 为 $x$ 的权值:

$$ F_{x,0}=\prod_{u\in son(x)}(F_{u,0}+F_{u,1})\\ F_{x,1}=B_xx\prod_{u\in son(x)}F_{u,0} $$

但是不能对于每一个点都求出 $F$ ,由于我们只需要求出 $F_1$ ,因此需要想办法进行优化。考虑链分治:

对于重链,每个节点记录 $A_x=\prod_{u\in son(x)}(F_{u,0}+F_{u,1}),B_x=B_xx\prod_{u\in son(x)}F_{u,0}$ ,其中 $u$ 是轻儿子。即对于每个点都记录下一层的信息。求出下一层信息之后,考虑一条重链上的转移,$sh_x$ 为 $x$ 的重儿子:

$$ F_{x,0}=A_x(F_{sh_x,0}+F_{sh_x,1})\\ F_{x,1}=B_xF_{sh_x,0}\\ \begin{bmatrix}F_{sh_x,0}&F_{sh_x,1}\end{bmatrix}\begin{bmatrix}A_x&B_x\\A_x&0\end{bmatrix}=\begin{bmatrix}F_{x,0}&F_{x,1}\end{bmatrix} $$

因此,对于 $A_x,B_x$ ,我们可以通过分治NTT求出。而重链上我们只需要顶点的 $F$ 即可,因此可以通过分治NTT矩乘求出。

链分治总共有 $O(\log_2n)$ 层,每层 $O(n\log^2_2n)$ 处理 $A_x,B_x,F$ ,总复杂度 $O(n\log_2^3n)$ 。实际上由于链分治常数小跑得飞快。

示例程序

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=80000,maxt=1<<17,MOD=998244353;

int n,m,a[maxn+5],si[maxn+5],SH[maxn+5],que[maxn+5];
int E,lnk[maxn+5],nxt[(maxn<<1)+5],to[(maxn<<1)+5];
int wn[maxt+5],temA[maxt+5],temB[maxt+5];
PN f[maxn+5][2],A[maxn+5],B[maxn+5];

inline void Add(int x,int y) {to[++E]=y;nxt[E]=lnk[x];lnk[x]=E;}
inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void NTTPre(){
    int x=Pow(3,(MOD-1)/maxt);
    wn[maxt>>1]=1;
    for (int i=(maxt>>1)+1;i<maxt;i++) wn[i]=MUL(wn[i-1],x);
    for (int i=(maxt>>1)-1;i;i--) wn[i]=wn[i<<1];
}
void NTT(int *a,int n,int f){
    if (f>0){
        for (int k=n>>1;k;k>>=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=a[i+j+k];
                    a[i+j+k]=MUL(x+MOD-y,wn[k+j]);
                    a[i+j]=ADD(x,y);
                }
    } else {
        for (int k=1;k<n;k<<=1)
            for (int i=0;i<n;i+=k<<1)
                for (int j=0;j<k;j++){
                    int x=a[i+j],y=MUL(a[i+j+k],wn[k+j]);
                    a[i+j+k]=ADD(x,MOD-y);
                    a[i+j]=ADD(x,y);
                }
        for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
        reverse(a+1,a+n);
    }
}
inline PN operator + (const PN &a,const PN &b){
    static PN c;c.resize(max(a.size(),b.size()));
    for (int i=0;i<c.size();i++) c[i]=ADD(i<a.size()?a[i]:0,i<b.size()?b[i]:0);
    return c;
}
PN operator * (const PN &a,const PN &b){
    static PN c;
    int n=a.size(),m=b.size(),t;
    for (t=1;t<n+m-1;t<<=1);
    for (int i=0;i<n;i++) temA[i]=a[i];for (int i=n;i<t;i++) temA[i]=0;
    for (int i=0;i<m;i++) temB[i]=b[i];for (int i=m;i<t;i++) temB[i]=0;
    NTT(temA,t,1);NTT(temB,t,1);
    for (int i=0;i<t;i++) temA[i]=MUL(temA[i],temB[i]);
    NTT(temA,t,-1);
    c.resize(n+m-1);for (int i=0;i<n+m-1;i++) c[i]=temA[i];
    return c;
}
struct Matrix{
    PN s[2][2];
    void zero() {s[0][0]=s[0][1]=s[1][0]=s[1][1]={0};}
}tem,res;
Matrix operator * (const Matrix &a,const Matrix &b){
    static Matrix c;c.zero();
    for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
            for (int k=0;k<2;k++)
                c.s[i][j]=c.s[i][j]+a.s[i][k]*b.s[k][j];
    return c;
}
void DFS(int x,int pre=0){
    si[x]=1;
    for (int j=lnk[x];j;j=nxt[j])
        if (to[j]!=pre){
            DFS(to[j],x);si[x]+=si[to[j]];
            if (si[to[j]]>si[SH[x]]) SH[x]=to[j];
        }
}
PN CalcA(int L,int R){
    if (L==R) return f[que[L]][0]+f[que[L]][1];
    int mid=L+(R-L>>1);
    return CalcA(L,mid)*CalcA(mid+1,R);
}
PN CalcB(int L,int R){
    if (L==R) return f[que[L]][0];
    int mid=L+(R-L>>1);
    return CalcB(L,mid)*CalcB(mid+1,R);
}
Matrix Calc(int L,int R){
    if (L==R){
        tem.s[0][0]=tem.s[1][0]=A[que[L]];
        tem.s[0][1]=B[que[L]];tem.s[1][1]={0};
        return tem;
    }
    int mid=L+(R-L>>1);
    return Calc(L,mid)*Calc(mid+1,R);
}
void Solve(int x,int pre=0,bool fl=true){
    if (SH[x]) Solve(SH[x],x,false);
    for (int j=lnk[x];j;j=nxt[j])
        if (to[j]!=pre && to[j]!=SH[x])
            Solve(to[j],x,true);
    int cnt=0;
    for (int j=lnk[x];j;j=nxt[j])
        if (to[j]!=pre && to[j]!=SH[x])
            que[++cnt]=to[j];
    if (cnt) A[x]=CalcA(1,cnt),B[x]=PN({0,a[x]})*CalcB(1,cnt);
    else A[x]={1},B[x]={0,a[x]};
    if (fl){
        cnt=0;for (int i=x;i;i=SH[i]) que[++cnt]=i;
        reverse(que+1,que+1+cnt);res=Calc(1,cnt);
        f[x][0]=res.s[0][0];f[x][1]=res.s[0][1];
    }
}
int main(){
    NTTPre();
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%d",&a[i]);
    for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),Add(x,y),Add(y,x);
    DFS(1);Solve(1);
    f[1][0]=f[1][0]+f[1][1];
    printf("%d\n",m<f[1][0].size()?f[1][0][m]:0);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!