ZigZagK的博客
[矩阵维护DP+树链剖分+不重叠ST表]2021牛客暑期多校训练营6 K【Starch Cat】题解
2021年8月5日 18:51
牛客
查看标签

题目概述

Starch Cat

解题报告

正解是猫树,但是我不会,所以我选择黑科技。

首先我们能想到一个比较暴力的做法,每次询问 $x,y$ 路径上的最大独立集,就是求 $x\to LCA(x,y),y\to LCA(x,y)$ 的两个最大独立集,然后在 $LCA$ 处讨论一下。

定义 $f_{i,0/1}$ 表示 $i$ 点不选/选的最大独立集,然后考虑 $x\to fa_x$ 的DP:

$$ f_{fa_x,0}=\max\{f_{x,0},f_{x,1}\}\\ f_{fa_x,1}=f_{x,0}+a_{fa_x} $$

这可以用矩阵维护(当然,矩乘定义是 $c_{i,j}=\max_{k}\{a_{i,k}+b_{k,j}\}$ ):

$$ \begin{bmatrix}f_{x,0}&f_{x,1}\end{bmatrix}\begin{bmatrix}0&a_{fa_x}\\0&-\infty\end{bmatrix}=\begin{bmatrix}f_{fa_x,0}&f_{fa_x,1}\end{bmatrix} $$

令 $x$ 的矩阵为 $\begin{bmatrix}0&a_{fa_x}\\0&-\infty\end{bmatrix}$ ,那么求出 $x\to LCA(x,y),y\to LCA(x,y)$​ 路径上矩阵的乘积就可以知道DP的值。

所以暴力的做法就很显然了:用倍增求出矩阵乘积即可。

但是这道题数据范围太大,倍增求LCA的做法复杂度稳定太大了,所以过不了。

倍增复杂度太大,我们就考虑常数小很多的树链剖分,但是这样就无法维护ST表了(ST表有重复部分),

接下来利用黑科技不重叠ST表来维护区间矩阵乘积就行了。

示例程序

#include<cstdio>
#include<cctype>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxn=500000,LOG=19,maxs=1<<LOG,maxm=maxn<<1,MOD=998244353;

int n,m,seed,a[maxn+5],dep[maxn+5],si[maxn+5],fa[maxn+5];
int E,lnk[maxn+5],nxt[maxm+5],son[maxm+5];
int dfn,Lt[maxn+5],Rt[maxn+5],SH[maxn+5],top[maxn+5],lg[maxs+5];
struct Matrix{
    LL s[2][2];
    void build(LL a,LL b,LL c,LL d) {s[0][0]=a;s[0][1]=b;s[1][0]=c;s[1][1]=d;}
    void clear() {for (int i=0;i<2;i++) for (int j=0;j<2;j++) s[i][j]=-(1e18);}
    void unit() {clear();for (int i=0;i<2;i++) s[i][i]=0;}
};
Matrix M[maxn+5],A[LOG+1][maxn+5],B[LOG+1][maxn+5];
struct Rand{
    unsigned int n,seed;
    Rand(unsigned int n,unsigned int seed)
    :n(n),seed(seed){}
    int get(long long lastans){
        seed ^= seed << 13;
        seed ^= seed >> 17;
        seed ^= seed << 5;
        return (seed^lastans)%n+1;
    }
};

#define EOLN(x) ((x)==10 || (x)==13 || (x)==EOF)
inline char readc(){
    static char buf[100000],*l=buf,*r=buf;
    return l==r && (r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<typename T> int readi(T &x){
    T tot=0;char ch=readc(),lst='+';
    while (!isdigit(ch)) {if (ch==EOF) return EOF;lst=ch;ch=readc();}
    while (isdigit(ch)) tot=(tot<<3)+(tot<<1)+(ch^48),ch=readc();
    lst=='-'?x=-tot:x=tot;return EOLN(ch);
}
inline void Add(int x,int y) {son[++E]=y;nxt[E]=lnk[x];lnk[x]=E;}
inline Matrix operator * (const Matrix &a,const Matrix &b){
    static Matrix c;c.clear();
    for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
            c.s[i][j]=max(a.s[i][0]+b.s[0][j],a.s[i][1]+b.s[1][j]);
    return c;
}
void DFS(int x,int pre=0){
    dep[x]=dep[pre]+1;si[x]=1;fa[x]=pre;
    for (int j=lnk[x],u;j;j=nxt[j])
        if ((u=son[j])!=pre){
            DFS(u,x);si[x]+=si[u];
            if (si[u]>si[SH[x]]) SH[x]=u;
        }
}
void HLD(int x,int lst,int pre=0){
    Lt[x]=++dfn;top[x]=lst;
    if (pre) M[dfn].build(0,a[pre],0,-(1e18));
    if (SH[x]) HLD(SH[x],lst,x);
    for (int j=lnk[x],u;j;j=nxt[j])
        if ((u=son[j])!=pre && u!=SH[x]) HLD(u,u,x);
    Rt[x]=dfn;
}
inline Matrix Ask(int L,int R) {return L==R?M[L]:A[lg[L^R]][R]*B[lg[L^R]][L];}
LL Query(int x,int y){
    static Matrix X,Y;X.unit();Y.unit();
    int s=x,t=y;
    while (top[x]!=top[y])
        if (dep[top[x]]>dep[top[y]]) X=X*Ask(Lt[top[x]],Lt[x]),x=fa[top[x]];
        else Y=Y*Ask(Lt[top[y]],Lt[y]),y=fa[top[y]];
    int lca=x;
    if (Lt[x]<Lt[y]) lca=x,Y=Y*Ask(Lt[x]+1,Lt[y]);
    if (Lt[x]>Lt[y]) lca=y,X=X*Ask(Lt[y]+1,Lt[x]);
    LL f[2]={max(X.s[0][0],X.s[1][0]+a[s]),max(X.s[0][1],X.s[1][1]+a[s])};
    LL g[2]={max(Y.s[0][0],Y.s[1][0]+a[t]),max(Y.s[0][1],Y.s[1][1]+a[t])};
    return max(f[0]+g[0],f[1]+g[1]-a[lca]);
}
#define L(i,k) ((i)>>(k)<<(k))
#define R(i,k) (((i)+(1<<(k))>>(k)<<(k))-1)
int main(){
    readi(n);readi(m);readi(seed);
    for (int i=1;i<=n;i++) readi(a[i]);
    for (int i=2,x;i<=n;i++) readi(x),Add(x,i),Add(i,x);
    DFS(1);dfn=-1;HLD(1,1);
    for (int j=1;j<LOG;j++) lg[1<<j]=j;
    for (int i=1;i<maxs;i++) if (!lg[i]) lg[i]=lg[i-1];
    for (int i=0;i<n;i++) A[0][i]=B[0][i]=M[i];
    for (int j=1;(1<<j)<=n;j++){
        A[j][0].unit();for (int i=0;i<=n-1;i++) A[j][i]=(L(i,j)==i?M[i]:M[i]*A[j][i-1]);
        B[j][n].unit();for (int i=n-1;i>=0;i--) B[j][i]=(R(i,j)==i?M[i]:B[j][i+1]*M[i]);
    }
    LL lstans=0,ans=0;
    Rand rand(n,seed);
    for (int i=1;i<=m;i++){
        int u=rand.get(lstans);
        int v=rand.get(lstans);
        int x=rand.get(lstans);
        lstans=Query(u,v);
        ans=(ans+lstans%MOD*x)%MOD;
    }
    printf("%lld\n",ans);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!