这题显然是一个树形背包 $f_{x,j,0},f_{x,j,1}$ 表示 $x$ 没选 / 选了,子树里选了 $j$ 个的方案数。
这个背包可以写成多项式形式,$B_x$ 为 $x$ 的权值:
$$ F_{x,0}=\prod_{u\in son(x)}(F_{u,0}+F_{u,1})\\ F_{x,1}=B_xx\prod_{u\in son(x)}F_{u,0} $$
但是不能对于每一个点都求出 $F$ ,由于我们只需要求出 $F_1$ ,因此需要想办法进行优化。考虑链分治:
对于重链,每个节点记录 $A_x=\prod_{u\in son(x)}(F_{u,0}+F_{u,1}),B_x=B_xx\prod_{u\in son(x)}F_{u,0}$ ,其中 $u$ 是轻儿子。即对于每个点都记录下一层的信息。求出下一层信息之后,考虑一条重链上的转移,$sh_x$ 为 $x$ 的重儿子:
$$ F_{x,0}=A_x(F_{sh_x,0}+F_{sh_x,1})\\ F_{x,1}=B_xF_{sh_x,0}\\ \begin{bmatrix}F_{sh_x,0}&F_{sh_x,1}\end{bmatrix}\begin{bmatrix}A_x&B_x\\A_x&0\end{bmatrix}=\begin{bmatrix}F_{x,0}&F_{x,1}\end{bmatrix} $$
因此,对于 $A_x,B_x$ ,我们可以通过分治NTT求出。而重链上我们只需要顶点的 $F$ 即可,因此可以通过分治NTT矩乘求出。
链分治总共有 $O(\log_2n)$ 层,每层 $O(n\log^2_2n)$ 处理 $A_x,B_x,F$ ,总复杂度 $O(n\log_2^3n)$ 。实际上由于链分治常数小跑得飞快。
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=80000,maxt=1<<17,MOD=998244353;
int n,m,a[maxn+5],si[maxn+5],SH[maxn+5],que[maxn+5];
int E,lnk[maxn+5],nxt[(maxn<<1)+5],to[(maxn<<1)+5];
int wn[maxt+5],temA[maxt+5],temB[maxt+5];
PN f[maxn+5][2],A[maxn+5],B[maxn+5];
inline void Add(int x,int y) {to[++E]=y;nxt[E]=lnk[x];lnk[x]=E;}
inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void NTTPre(){
int x=Pow(3,(MOD-1)/maxt);
wn[maxt>>1]=1;
for (int i=(maxt>>1)+1;i<maxt;i++) wn[i]=MUL(wn[i-1],x);
for (int i=(maxt>>1)-1;i;i--) wn[i]=wn[i<<1];
}
void NTT(int *a,int n,int f){
if (f>0){
for (int k=n>>1;k;k>>=1)
for (int i=0;i<n;i+=k<<1)
for (int j=0;j<k;j++){
int x=a[i+j],y=a[i+j+k];
a[i+j+k]=MUL(x+MOD-y,wn[k+j]);
a[i+j]=ADD(x,y);
}
} else {
for (int k=1;k<n;k<<=1)
for (int i=0;i<n;i+=k<<1)
for (int j=0;j<k;j++){
int x=a[i+j],y=MUL(a[i+j+k],wn[k+j]);
a[i+j+k]=ADD(x,MOD-y);
a[i+j]=ADD(x,y);
}
for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
reverse(a+1,a+n);
}
}
inline PN operator + (const PN &a,const PN &b){
static PN c;c.resize(max(a.size(),b.size()));
for (int i=0;i<c.size();i++) c[i]=ADD(i<a.size()?a[i]:0,i<b.size()?b[i]:0);
return c;
}
PN operator * (const PN &a,const PN &b){
static PN c;
int n=a.size(),m=b.size(),t;
for (t=1;t<n+m-1;t<<=1);
for (int i=0;i<n;i++) temA[i]=a[i];for (int i=n;i<t;i++) temA[i]=0;
for (int i=0;i<m;i++) temB[i]=b[i];for (int i=m;i<t;i++) temB[i]=0;
NTT(temA,t,1);NTT(temB,t,1);
for (int i=0;i<t;i++) temA[i]=MUL(temA[i],temB[i]);
NTT(temA,t,-1);
c.resize(n+m-1);for (int i=0;i<n+m-1;i++) c[i]=temA[i];
return c;
}
struct Matrix{
PN s[2][2];
void zero() {s[0][0]=s[0][1]=s[1][0]=s[1][1]={0};}
}tem,res;
Matrix operator * (const Matrix &a,const Matrix &b){
static Matrix c;c.zero();
for (int i=0;i<2;i++)
for (int j=0;j<2;j++)
for (int k=0;k<2;k++)
c.s[i][j]=c.s[i][j]+a.s[i][k]*b.s[k][j];
return c;
}
void DFS(int x,int pre=0){
si[x]=1;
for (int j=lnk[x];j;j=nxt[j])
if (to[j]!=pre){
DFS(to[j],x);si[x]+=si[to[j]];
if (si[to[j]]>si[SH[x]]) SH[x]=to[j];
}
}
PN CalcA(int L,int R){
if (L==R) return f[que[L]][0]+f[que[L]][1];
int mid=L+(R-L>>1);
return CalcA(L,mid)*CalcA(mid+1,R);
}
PN CalcB(int L,int R){
if (L==R) return f[que[L]][0];
int mid=L+(R-L>>1);
return CalcB(L,mid)*CalcB(mid+1,R);
}
Matrix Calc(int L,int R){
if (L==R){
tem.s[0][0]=tem.s[1][0]=A[que[L]];
tem.s[0][1]=B[que[L]];tem.s[1][1]={0};
return tem;
}
int mid=L+(R-L>>1);
return Calc(L,mid)*Calc(mid+1,R);
}
void Solve(int x,int pre=0,bool fl=true){
if (SH[x]) Solve(SH[x],x,false);
for (int j=lnk[x];j;j=nxt[j])
if (to[j]!=pre && to[j]!=SH[x])
Solve(to[j],x,true);
int cnt=0;
for (int j=lnk[x];j;j=nxt[j])
if (to[j]!=pre && to[j]!=SH[x])
que[++cnt]=to[j];
if (cnt) A[x]=CalcA(1,cnt),B[x]=PN({0,a[x]})*CalcB(1,cnt);
else A[x]={1},B[x]={0,a[x]};
if (fl){
cnt=0;for (int i=x;i;i=SH[i]) que[++cnt]=i;
reverse(que+1,que+1+cnt);res=Calc(1,cnt);
f[x][0]=res.s[0][0];f[x][1]=res.s[0][1];
}
}
int main(){
NTTPre();
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),Add(x,y),Add(y,x);
DFS(1);Solve(1);
f[1][0]=f[1][0]+f[1][1];
printf("%d\n",m<f[1][0].size()?f[1][0][m]:0);
return 0;
}