ZigZagK的博客
[KMP+后缀数组+主席树]2022牛客暑期多校训练营6 L【Striking String Problem】题解
2022年8月12日 16:42
牛客
查看标签

题目概述

Striking String Problem

解题报告

神仙题,根本想不到。

记 $n$ 为 $S$ 长度,$m$ 为 $T$ 长度,$U_i$ 表示 $S[l_1,r_1]+\cdots+S[l_i,r_i]$ ,$len_i=r_i-l_i+1$ 。处理出 $T$ 的 $fail$ 数组,对 $S$ 先做一遍KMP,记录 $pre_i$ 表示 $S[1,i]$ 与 $T$ 的最大匹配长度,$cnt_i$ 表示 $S[1,i]$ 与 $T$ 的匹配次数。


考虑在这个接起来的串上做KMP,定义 $M_i$ 表示 $U_i$ 与 $T$ 的最大匹配长度(即最大的 $len$ 使得 $U_i$ 末尾 $len$ 个与 $T$ 前 $len$ 个相同)。然后考虑从 $M_{i-1}$ 推到 $M_i$ 。

第一种情况:$M_i$ 和 $M_{i-1}$ 无关,则我们需要求出 $S[l_i,r_i]$ 与 $T$ 的最大匹配长度 $len(len\le len_i)$ ,可以通过 $pre_{r_i}$ 跳 $fail$ 树得到。

第二种情况:$M_i$ 从 $M_{i-1}$ 接一部分过来,假设保留的长度为 $D$ ,需要满足:

  1. $D$ 是 $M_{i-1}$ 的 $fail$ 树祖先,这样才能将 $T$ 平移过来。
  2. $T[D+1,D+len_i]=S[l_i,r_i]$ ,即 $LCP(T[D+1,m],S[l_i,n])\ge len_i$ ,可以通过求 $S+T$ 串的后缀数组求出LCP。

我们对于每一个 $S[l_i,n]$ ,找出 $[rkL_i,rkR_i]$ 表示和 $S[l_i,n]$ 的 $LCP\ge len_i$ 的区间,那么只要 $T[D+1,m]$ 的 $rank$ 在这个区间内就说明 $T[D+1,D+len_i]$ 可以和 $S[l_i,r_i]$ 匹配。而我们要从 $M_{i-1}$ 的祖先中找出最大的满足的 $D$ ,则 $M_i=D+len_i$ 。

考虑用权值主席树,每棵树存 $rank$ 对应的信息,就可以求出 $x$ 到根中 $[rkL_i,rkR_i]$ 中的最大值 $D$ 。

在这两种情况中挑个大的就可以得到 $M_i$ ,为了后面处理方便,如果 $M_i=m$ ,则令 $M_i=fail[M_i]$ 。


求出 $M_i$ 之后我们考虑如何求出答案。首先为了统一询问,我们把 $[L,R]$ 拆成 $[1,R]$ 减去 $[1,L+m-2]$ 。

然后考虑求 $U[1,P]$ 的答案,我们可以先求出整段的 $U_i$ ,然后最后补上不完整的一段。

继续考虑KMP的过程,之前已经匹配了 $M_{i-1}$ ,定义 $Solve(M,l,r)$ 表示在 $M$ 的基础上添上了 $S[l,r]$ ,产生的匹配数。类似上面,我们讨论两种情况:

第一种情况:没有用到 $M$ ,则匹配数为 $cnt_{r}-cnt_{l+m-2}$ 。

第二种情况:用到了 $M$ ,假设用到长度为 $D$ ,需要满足:

  1. $D$ 是 $M$ 的祖先。
  2. $T[D+1,m]=S[l,l+(m-D)-1]$ ,即 $LCP(T[D+1,m],S[l,n])\ge m-D$ 。
  3. $l+(m-D)-1\le r$ ,即 $D\ge m-(r-l+1)$ 。

这次条件只和 $D$ 有关,因此我们需要另一个权值主席树,用于数出 $M$ 的祖先中包含 $rank(S[l,n])$ 的个数,可以通过标记永久化和单点查询求出。$M$ 祖先中满足条件的范围也可以通过 $fail$ 树上倍增求出。

这两种情况累加起来就是 $Solve(M,l,r)$ ,令 $ans_i$ 表示 $U_i$ 的答案,则 $ans_i=ans_{i-1}+Solve(M_{i-1},l_i,r_i)$ 。

最后查询 $U[1,P]$ 就是找到对应的 $ans_i$ ,并加上不完整的一段。

示例程序

#include<cstdio>
#include<cctype>
using namespace std;
typedef long long LL;
const int maxn=1e6,maxm=5e5,maxk=5e5,maxl=maxn+maxm,maxt=6e7,LOG=20;

int n,m,K,Q,pre[maxn+5],cnt[maxn+5];
int l[maxk+5],r[maxk+5],len[maxk+5],L[maxk+5],R[maxk+5],M[maxk+5];
char s[maxn+5],t[maxm+5],a[maxl+5];
int SA[maxl+5],rk[maxl+5],ha[maxl+5],sc[(maxl<<1)+5];
int lg[maxl+5],RMQ[LOG+1][maxl+5];
int E,lnk[maxm+5],nxt[maxm+5],to[maxm+5];
int fai[maxm+5],ST[LOG+1][maxm+5];
int pl,ro[2][maxm+5],ls[maxt+5],rs[maxt+5],val[maxt+5];
LL pos[maxk+5],ans[maxk+5];

#define EOLN(x) ((x)==10 || (x)==13 || (x)==EOF)
inline char readc(){
    static char buf[1<<16],*l=buf,*r=buf;
    return l==r && (r=(l=buf)+fread(buf,1,1<<16,stdin),l==r)?EOF:*l++;
}
template<typename T> int readi(T &x){
    T tot=0;char ch=readc(),lst='+';
    while (!isdigit(ch)) {if (ch==EOF) return EOF;lst=ch;ch=readc();}
    while (isdigit(ch)) tot=(tot<<3)+(tot<<1)+(ch^48),ch=readc();
    lst=='-'?x=-tot:x=tot;return EOLN(ch);
}
int reads(char *s){
    int len=0;char ch=readc();
    while (!islower(ch)) {if (ch==EOF) return EOF;ch=readc();}
    while (islower(ch)) s[++len]=ch,ch=readc();
    s[len+1]=0;return len;
}
struct fastO{
    int si;char buf[1<<16];
    fastO() {si=0;}
    void putc(char ch){
        if (si==(1<<16)) fwrite(buf,1,si,stdout),si=0;
        buf[si++]=ch;
    }
    ~fastO() {fwrite(buf,1,si,stdout);}
}fo;
template<typename T> void writei(T x,char ch='\n'){
    int len=0,buf[100];
    if (x<0) fo.putc('-'),x=-x;
    do buf[len++]=x%10,x/=10; while (x);
    while (len) fo.putc(buf[--len]+48);
    fo.putc(ch);
}
inline int min(int x,int y) {return x<y?x:y;}
inline int max(int x,int y) {return x>y?x:y;}
void Sort(int n,int m){
    for (int i=0;i<=m;i++) ha[i]=0;
    for (int i=1;i<=n;i++) ha[rk[i]]++;
    for (int i=1;i<=m;i++) ha[i]+=ha[i-1];
    for (int i=n;i;i--) SA[ha[rk[sc[i]]]--]=sc[i];
}
void MakeST(char *s,int n,int m=255){
    for (int i=1;i<=n;i++) rk[i]=s[i],sc[i]=i,sc[i+n]=0;Sort(n,m);
    for (int k=1,p=0;p<n;m=p,k<<=1){
        p=0;for (int i=n-k+1;i<=n;i++) sc[++p]=i;
        for (int i=1;i<=n;i++) if (SA[i]>k) sc[++p]=SA[i]-k;
        Sort(n,m);for (int i=1;i<=n;i++) sc[i]=rk[i];rk[SA[1]]=p=1;
        for (int i=2;i<=n;i++)
            rk[SA[i]]=(p+=sc[SA[i-1]]!=sc[SA[i]] || sc[SA[i-1]+k]!=sc[SA[i]+k]);
    }
    for (int i=1,k=0;i<=n;i++){
        if (k) k--;
        while (s[i+k]==s[SA[rk[i]-1]+k]) k++;
        RMQ[0][rk[i]]=k;
    }
    for (int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
    for (int j=1;(1<<j)<n;j++)
        for (int i=2;i+(1<<j)-1<=n;i++)
            RMQ[j][i]=min(RMQ[j-1][i],RMQ[j-1][i+(1<<j-1)]);
}
int LCP(int x,int y){
    x++;int k=lg[y-x+1];
    return min(RMQ[k][x],RMQ[k][y-(1<<k)+1]);
}
int FindL(int x,int len){
    int L=1,R=x-1;
    for (int mid=L+(R-L>>1);L<=R;mid=L+(R-L>>1))
        LCP(mid,x)>=len?R=mid-1:L=mid+1;
    return L;
}
int FindR(int x,int len){
    int L=x+1,R=n+m;
    for (int mid=L+(R-L>>1);L<=R;mid=L+(R-L>>1))
        LCP(x,mid)>=len?L=mid+1:R=mid-1;
    return R;
}
inline void Add(int x,int y) {to[++E]=y;nxt[E]=lnk[x];lnk[x]=E;}
int Build(int L,int R,int k){
    int p=++pl;val[p]=k;
    if (L==R) return p;
    int mid=L+(R-L>>1);
    ls[p]=Build(L,mid,k);rs[p]=Build(mid+1,R,k);
    return p;
}
int Insert(int p,int pos,int k,int l=1,int r=n+m){
    int now=++pl;ls[now]=ls[p];rs[now]=rs[p];val[now]=val[p];
    if (l==r) {val[now]=k;return now;}
    int mid=l+(r-l>>1);
    pos<=mid?ls[now]=Insert(ls[p],pos,k,l,mid):rs[now]=Insert(rs[p],pos,k,mid+1,r);
    val[now]=max(val[ls[now]],val[rs[now]]);
    return now;
}
int Askmax(int p,int L,int R,int l=1,int r=n+m){
    if (L==l && r==R) return val[p];
    int mid=l+(r-l>>1);
    if (R<=mid) return Askmax(ls[p],L,R,l,mid); else if (L>mid) return Askmax(rs[p],L,R,mid+1,r);
    else return max(Askmax(ls[p],L,mid,l,mid),Askmax(rs[p],mid+1,R,mid+1,r));
}
int Update(int p,int L,int R,int l=1,int r=n+m){
    int now=++pl;ls[now]=ls[p];rs[now]=rs[p];val[now]=val[p];
    if (L==l && r==R) {val[now]++;return now;}
    int mid=l+(r-l>>1);
    if (R<=mid) ls[now]=Update(ls[p],L,R,l,mid); else if (L>mid) rs[now]=Update(rs[p],L,R,mid+1,r);
    else ls[now]=Update(ls[p],L,mid,l,mid),rs[now]=Update(rs[p],mid+1,R,mid+1,r);
    return now;
}
int Asksum(int A,int B,int pos,int sum=0,int l=1,int r=n+m){
    sum+=val[B]-val[A];if (l==r) return sum;
    int mid=l+(r-l>>1);
    return pos<=mid?Asksum(ls[A],ls[B],pos,sum,l,mid):Asksum(rs[A],rs[B],pos,sum,mid+1,r);
}
void DFS(int x,int pre){
    ro[0][x]=Insert(ro[0][pre],rk[x+1+n],x);
    ro[1][x]=ro[1][pre];
    if (x>0) ro[1][x]=Update(ro[1][x],FindL(rk[x+1+n],m-x),FindR(rk[x+1+n],m-x));
    for (int j=lnk[x];j;j=nxt[j]) if (to[j]<m) DFS(to[j],x);
}
int Solve(int M,int l,int r){
    if (M+r-l+1<m) return 0;
    int ans=0;
    if (r-l+1>=m) ans+=cnt[r]-cnt[l+m-2];
    int x=M;for (int j=LOG;~j;j--) if (ST[j][x]>=m-(r-l+1)) x=ST[j][x];
    return ans+Asksum(ro[1][fai[x]],ro[1][M],rk[l]);
}
LL Sum(LL P){
    if (!P) return 0;
    int L=1,R=K;
    for (int mid=L+(R-L>>1);L<=R;mid=L+(R-L>>1))
        P<=pos[mid]?R=mid-1:L=mid+1;
    if (P==pos[L]) return ans[L];
    return ans[L-1]+Solve(M[L-1],l[L],l[L]+P-pos[L-1]-1);
}
int main(){
    n=reads(s);m=reads(t);
    for (int i=1;i<=n;i++) a[i]=s[i];
    for (int i=1;i<=m;i++) a[i+n]=t[i];
    MakeST(a,n+m);
    readi(K);readi(Q);
    for (int i=1;i<=K;i++){
        readi(l[i]);readi(r[i]);len[i]=r[i]-l[i]+1;
        L[i]=FindL(rk[l[i]],len[i]);
        R[i]=FindR(rk[l[i]],len[i]);
    }
    Add(fai[1],1);
    for (int i=2,j=0;i<=m;i++){
        while (j && t[j+1]!=t[i]) j=fai[j];
        j+=(t[j+1]==t[i]);
        fai[i]=j;Add(fai[i],i);
    }
    for (int i=1;i<=m;i++) ST[0][i]=fai[i];
    for (int j=1;j<=LOG;j++)
        for (int i=1;i<=m;i++)
            ST[j][i]=ST[j-1][ST[j-1][i]];
    for (int i=1,j=0;i<=n;i++){
        while (j && t[j+1]!=s[i]) j=fai[j];
        j+=(t[j+1]==s[i]);
        pre[i]=j;
        if (j==m) cnt[i]++,j=fai[j];
    }
    for (int i=1;i<=n;i++) cnt[i]+=cnt[i-1];
    fai[0]=m+1;ro[0][m+1]=Build(1,n+m,-1);ro[1][m+1]=Build(1,n+m,0);
    DFS(0,m+1);
    for (int i=1;i<=K;i++){
        int MAX=Askmax(ro[0][M[i-1]],L[i],R[i]);
        int x=pre[r[i]];
        for (int j=LOG;~j;j--)
            if (ST[j][x]>len[i]) x=ST[j][x];
        if (x>len[i]) x=fai[x];
        M[i]=x;if (~MAX) M[i]=max(M[i],MAX+len[i]);
        if (M[i]==m) M[i]=fai[M[i]];
    }
    for (int i=1;i<=K;i++){
        pos[i]=pos[i-1]+r[i]-l[i]+1;
        ans[i]=ans[i-1]+Solve(M[i-1],l[i],r[i]);
//        printf("M[%d]=%d ans[%d]=%lld\n",i,M[i],i,ans[i]);
    }
    for (int t=1;t<=Q;t++){
        LL x,y;readi(x);readi(y);
        if (y-x+1<m) {writei(0);continue;}
        writei(Sum(y)-Sum(x+m-2));
    }
    return 0;
}
版权声明:本博客所有文章除特别声明外,均采用 CC BY 4.0 CN协议 许可协议。转载请注明出处!
请不要发毫无意义或内容不文明的评论。与本文无关评论请发留言板!
2022-08-23 11:03:48Owen_codeisking
2022-08-23 11:03:48

感谢大佬!查了好多篇题解都不说人话,您这篇终于看懂了!

访客
2022-08-23 11:40:14ZigZagK
2022-08-23 11:40:14
@Owen_codeisking 

其实官方发的讲的挺清楚的,我这篇写的和官方的也差不多
无奈.jpg

博主
2022-08-23 12:33:52Owen_codeisking
2022-08-23 12:33:52
@ZigZagK 

好像确实,应该是结合代码好理解一些 QAQ

访客