ZigZagK的博客
[后缀自动机parent树+虚树]2022牛客暑期多校训练营1 B【Spirit Circle Observation】题解
2022年7月21日 12:23
牛客
查看标签

题目概述

Spirit Circle Observation

解题报告

找性质题最烦了.jpg。

首先建SAM,考虑一个直观的做法。直接考虑枚举SAM里的节点,然后在节点上考虑 $a999\cdots$ 这样的形式,并找到 $(a+1)000\cdots$ 对应节点,然后求 $a$ 前面部分和 $a+1$ 前面部分的重叠长度,并统计答案。这样会发现由于我们要从 $a$ 或 $a+1$ 前面的部分开始,这样难以快速找到对应的节点,所以这个做法难以实现。

不过,如果我们要求的是两个节点的最长公共后缀,那么只要求出他们在 $parent$ 树上的LCA,LCA的 $MAX$ 就是最长公共后缀。

所以如果我们考虑在节点后面添加 $a999\cdots$ 或 $(a+1)000\cdots$ ,这样想要求出重叠长度就会非常容易。为了避免计数重复,显然我们只需要在所有前缀节点上考虑添加。

顺着思路我们考虑枚举 $a$ 和前缀节点 $x$ ,然后看 $x$ 后面是不是 $a$ 或 $a+1$ ,如果是的话就找 $9$ 或 $0$ 的个数 $len$ ,表明 $[0,len]$ 都是可以考虑的范围。然后对于相同的 $len$ ,我们DFS一遍 $parent$ 树,对于节点 $p$ ,$p$ 任意两个子树在 $p$ 的重叠部分都是 $MAX_p$ ,方案数为 $MAX_p+1$ ,这样就可以统计出第一个为 $a$ ,长度为 $len$ 的方案数。

这样看起来是个 $O(10n^2)$ 暴力,但是分析一下发现,对于 $9$ 来说,$9$ 在整个串的出现次数 $<n$ ,并且连续一段 $9$ 只能被前面那个数 $a$ 用到,不能被其他数用到。$0$ 也同理。因此所有 $len$ 包含的前缀节点的个数和是 $O(n)$ 的,也就是说我们没必要遍历整个 $parent$ 树,只需要对每个 $len$ 建虚树就可以了。

复杂度 $O(10n+n\log_2n)$ 。

示例程序

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxn=400000,maxt=maxn<<1,maxe=maxt+(maxn<<1),LOG=18;

int n,suf0[maxn+5],suf9[maxn+5];char s[maxn+5];
int ro,pl,son[maxt+5][10],MAX[maxt+5],fai[maxt+5],pos[maxt+5];
int dep[maxt+5],lt[maxt+5],rt[maxt+5],ST[maxt+5][LOG+1];
vector<int> e0[maxn+5],e9[maxn+5];
int E,lnk[maxt+5],vt[maxt+5],nxt[maxe+5],to[maxe+5];
int tmpE,K,vn[maxt+5],top,stk[maxt+5];
int cnt[maxt+5][2];LL ans;

inline int newnode() {return ++pl;}
int Extend(int p,int c,int id){
    int np=newnode();MAX[np]=MAX[p]+1;pos[np]=id;
    while (p && !son[p][c]) son[p][c]=np,p=fai[p];
    if (!p) {fai[np]=ro;return np;}
    int q=son[p][c];if (MAX[p]+1==MAX[q]) {fai[np]=q;return np;}
    int nq=newnode();MAX[nq]=MAX[p]+1;pos[nq]=-1;
    for (int i=0;i<10;i++) son[nq][i]=son[q][i];
    fai[nq]=fai[q];fai[q]=fai[np]=nq;
    while (p && son[p][c]==q) son[p][c]=nq,p=fai[p];
    return np;
}
inline void Add(int *lnk,int x,int y) {to[++E]=y;nxt[E]=lnk[x];lnk[x]=E;}
void DFS(int x,int pre=0){
    lt[x]=++lt[0];dep[x]=dep[pre]+1;ST[x][0]=pre;
    for (int j=1;j<=LOG;j++) ST[x][j]=ST[ST[x][j-1]][j-1];
    for (int j=lnk[x];j;j=nxt[j]) DFS(to[j],x);
    rt[x]=lt[0];
}
int LCA(int x,int y){
    if (dep[x]<dep[y]) swap(x,y);
    for (int j=LOG;~j && dep[x]>dep[y];j--) if (dep[ST[x][j]]>=dep[y]) x=ST[x][j];
    if (x==y) return x;
    for (int j=LOG;~j;j--) if (ST[x][j]!=ST[y][j]) x=ST[x][j],y=ST[y][j];
    return ST[x][0];
}
#define Son(fa,x) (lt[fa]<=lt[x] && rt[x]<=rt[fa])
inline bool cmp(const int &i,const int &j) {return lt[i]<lt[j];}
void VT(){
    vn[++K]=ro;sort(vn+1,vn+1+K,cmp);
    int m=K;for (int i=2;i<=K;i++) vn[++m]=LCA(vn[i-1],vn[i]);
    sort(vn+1,vn+1+m,cmp);m=unique(vn+1,vn+1+m)-(vn+1);top=0;
    E=tmpE;for (int i=1;i<=m;i++) vt[vn[i]]=cnt[vn[i]][0]=cnt[vn[i]][1]=0;
    for (int i=1;i<=m;i++){
        while (top && !Son(stk[top],vn[i])) top--;
        if (top) Add(vt,stk[top],vn[i]);
        stk[++top]=vn[i];
    }
}
void Solve(int x){
    ans+=(LL)cnt[x][0]*cnt[x][1]*(MAX[x]+1);
    for (int j=vt[x],u;j;j=nxt[j]){
        Solve(u=to[j]);
        ans+=(LL)cnt[x][0]*cnt[u][1]*(MAX[x]+1);
        ans+=(LL)cnt[x][1]*cnt[u][0]*(MAX[x]+1);
        cnt[x][0]+=cnt[u][0];cnt[x][1]+=cnt[u][1];
    }
}
int main(){
    scanf("%d%s",&n,s+1);
    ro=newnode();for (int i=1,p=ro;i<=n;i++) p=Extend(p,s[i]-'0',i);
    for (int i=2;i<=pl;i++) Add(lnk,fai[i],i);tmpE=E;DFS(ro);
    for (int i=n;i;i--){
        suf0[i]=(s[i]=='0'?suf0[i+1]+1:0);
        suf9[i]=(s[i]=='9'?suf9[i+1]+1:0);
    }
    for (int a=0;a<9;a++){
        for (int i=0;i<n;i++) e0[i].clear(),e9[i].clear();
        for (int i=1;i<=pl;i++)
            if (pos[i]>=0 && s[pos[i]+1]-'0'==a)
                for (int j=0;j<=suf9[pos[i]+2];j++)
                    e9[j].push_back(i);
        for (int i=1;i<=pl;i++)
            if (pos[i]>=0 && s[pos[i]+1]-'0'==a+1)
                for (int j=0;j<=suf0[pos[i]+2];j++)
                    e0[j].push_back(i);
        for (int i=0;i<n;i++){
            if (e0[i].empty() || e9[i].empty()) continue;
            K=0;
            for (auto x:e0[i]) vn[++K]=x;
            for (auto x:e9[i]) vn[++K]=x;
//            printf("[%d,%d]\n",a,i);
//            printf("0:");for (auto x:e0[i]) printf("%d ",x);puts("");
//            printf("9:");for (auto x:e9[i]) printf("%d ",x);puts("");
            VT();
            for (auto x:e0[i]) cnt[x][0]++;
            for (auto x:e9[i]) cnt[x][1]++;
            Solve(ro);
        }
    }
    printf("%lld\n",ans);
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!