摆明了要你点分治……每次统计 $x$ 子树的时候记录两个值:到 $x$ 路径最大值和到 $x$ 路径点权和。按照最大值排序之后,枚举 $i$ ,只要看前面有多少 $dis_j\equiv max_i+a_{x}-dis_i(mod\ P)$ 就行了。
#include<cstdio>
#include<algorithm>
#define fr first
#define sc second
#define mp make_pair
using namespace std;
typedef long long LL;
const int maxn=100000,Log=17,maxp=10000000;
int n,MOD,a[maxn+5],dep[maxn+5],ST[maxn+5][Log+5],val[maxn+5][Log+5];LL dis[maxn+5];
int S,si[maxn+5],MAX[maxn+5],ro,m;bool vis[maxn+5];pair<int,LL> p[maxn+5];LL ans;
int E,lnk[maxn+5],son[(maxn<<1)+5],nxt[(maxn<<1)+5],num[maxp+5];
#define Add(x,y) (son[++E]=(y),nxt[E]=lnk[x],lnk[x]=E)
void getST(int x,int pre=0){
dep[x]=dep[pre]+1;dis[x]=dis[pre]+a[x];ST[x][0]=pre;val[x][0]=a[x];
for (int j=1;j<=Log;j++) ST[x][j]=ST[ST[x][j-1]][j-1],val[x][j]=max(val[x][j-1],val[ST[x][j-1]][j-1]);
for (int j=lnk[x];j;j=nxt[j]) if (son[j]!=pre) getST(son[j],x);
}
inline int LCA(int x,int y){
if (dep[x]<dep[y]) swap(x,y);for (int j=Log;~j&&dep[x]>dep[y];j--) if (dep[ST[x][j]]>=dep[y]) x=ST[x][j];
if (x==y) return x;for (int j=Log;~j;j--) if (ST[x][j]!=ST[y][j]) x=ST[x][j],y=ST[y][j];return ST[x][0];
}
inline LL Dis(int x,int y) {int lca=LCA(x,y);return dis[x]+dis[y]-(dis[lca]<<1)+a[lca];}
inline int Max(int x,int y){
int lca=LCA(x,y),MAX=a[lca];
for (int j=Log;~j&&dep[x]>dep[lca];j--) if (dep[ST[x][j]]>=dep[lca]) MAX=max(MAX,val[x][j]),x=ST[x][j];
for (int j=Log;~j&&dep[y]>dep[lca];j--) if (dep[ST[y][j]]>=dep[lca]) MAX=max(MAX,val[y][j]),y=ST[y][j];
return MAX;
}
void getro(int x,int pre=0){
for (int j=(si[x]=1,MAX[x]=0,lnk[x]),u;j;j=nxt[j])
if (!vis[u=son[j]]&&u!=pre) getro(u,x),si[x]+=si[u],MAX[x]=max(MAX[x],si[u]);
MAX[x]=max(MAX[x],S-si[x]);if (!ro||MAX[ro]>MAX[x]) ro=x;
}
void Join(int x,int y,int pre=0){
p[++m]=mp(Max(x,y),Dis(x,y));
for (int j=lnk[x],u;j;j=nxt[j])
if (!vis[u=son[j]]&&u!=pre) Join(u,y,x);
}
inline void Count(int x,int y,int P,int f){
m=0;Join(x,y);sort(p+1,p+1+m);
for (int i=1;i<=m;i++){
int now=(P+p[i].fr+MOD-p[i].sc%MOD)%MOD;
ans+=f*num[now];num[p[i].sc%MOD]++;
}for (int i=1;i<=m;i++) num[p[i].sc%MOD]=0;
}
void Dfs(int x,int pre=0){
vis[x]=true;Count(x,x,a[x]%MOD,1);int sum=S;
for (int j=lnk[x],u;j;j=nxt[j])
if (!vis[u=son[j]]&&u!=pre){
Count(u,x,a[x]%MOD,-1);S=si[x]>si[u]?si[u]:sum-si[x];
ro=0;getro(u,x);Dfs(ro,x);
}
}
int main(){
freopen("program.in","r",stdin);freopen("program.out","w",stdout);
scanf("%d%d",&n,&MOD);for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),Add(x,y),Add(y,x);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);getST(1);
S=n;ro=0;getro(1);Dfs(ro);printf("%lld\n",ans+n);return 0;
}