给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
先套路一波:把两个串接在一起(加个分隔符),求后缀数组。然后答案就是 $\sum_{i=1}^{n}\sum_{j}LCP(i,j)$ ,其中 $i$ 和 $j$ 是在分隔符前后的后缀。
这个可以用两棵线段树搞啊,一个后缀在后缀数组中往后统计贡献,$LCP$ 会越来越小,所以当 $Height_i$ 比上一次小的时候,$Height_i+1$ 后面的贡献就没用了,清空。然后每次在对应线段树中给 $[1,Height_{i+1}]$ 均加上 $1$ 的权值就行了。注意这是双tag线段树,定义清空tag优先级大于加tag优先级就行了。
好像还可以SAM搞,没有想过Orz。
#include<cstdio>
#include<cctype>
using namespace std;
typedef long long LL;
const int maxn=400000,maxt=maxn<<2;
int pre,n;char s[maxn+5];LL ans;
int SA[maxn+5],rk[maxn+5],t[maxn+5],ha[maxn+5],H[maxn+5];
#define LS (p<<1)
#define RS (p<<1|1)
#define Addtag(p) (sum[p]=0,tag[p]=true,add[p]=0)
#define Addadd(p,len,v) (sum[p]+=(len)*(v),add[p]+=(v))
struct SegmentTree{
int sum[maxt+5],add[maxt+5];bool tag[maxt+5];
inline void Pushdown(int p,int l,int r){
if (tag[p]) Addtag(LS),Addtag(RS),tag[p]=false;
if (add[p]) {int mid=l+(r-l>>1);Addadd(LS,mid-l+1,add[p]);Addadd(RS,r-mid,add[p]);add[p]=0;}
}
void Clear(int L,int R,int l=1,int r=n,int p=1){
if (R<l||r<L) return;if (L<=l&&r<=R) {Addtag(p);return;}Pushdown(p,l,r);
int mid=l+(r-l>>1);Clear(L,R,l,mid,LS);Clear(L,R,mid+1,r,RS);sum[p]=sum[LS]+sum[RS];
}
void Insert(int L,int R,int k,int l=1,int r=n,int p=1){
if (R<l||r<L) return;if (L<=l&&r<=R) {Addadd(p,r-l+1,k);return;}Pushdown(p,l,r);
int mid=l+(r-l>>1);Insert(L,R,k,l,mid,LS);Insert(L,R,k,mid+1,r,RS);sum[p]=sum[LS]+sum[RS];
}
int Ask(int L,int R,int l=1,int r=n,int p=1){
if (R<l||r<L) return 0;if (L<=l&&r<=R) return sum[p];Pushdown(p,l,r);
int mid=l+(r-l>>1);return Ask(L,R,l,mid,LS)+Ask(L,R,mid+1,r,RS);
}
};
SegmentTree A,B;
inline void Sort(int n,int m){
for (int i=0;i<=m;i++) ha[i]=0;for (int i=1;i<=n;i++) ha[rk[i]]++;
for (int i=1;i<=m;i++) ha[i]+=ha[i-1];for (int i=n;i;i--) SA[ha[rk[t[i]]]--]=t[i];
}
#define Diff(i,j) (t[SA[i]]!=t[SA[j]]||SA[i]+k>n||SA[j]+k>n||t[SA[i]+k]!=t[SA[j]+k])
inline void Make(char *s){
int m=255;for (int i=1;i<=n;i++) rk[i]=s[i],t[i]=i;Sort(n,m);
for (int k=1,p=0;p<n;k<<=1,m=p){
p=0;for (int i=n-k+1;i<=n;i++) t[++p]=i;
for (int i=1;i<=n;i++) if (SA[i]>k) t[++p]=SA[i]-k;
Sort(n,m);for (int i=1;i<=n;i++) t[i]=rk[i];rk[SA[1]]=p=1;
for (int i=2;i<=n;rk[SA[i++]]=p) p+=Diff(i-1,i);
}
for (int i=1,k=0;i<=n;i++) {if (k) k--;while (s[i+k]==s[SA[rk[i]-1]+k]) k++;H[rk[i]]=k;}
}
int main(){
freopen("program.in","r",stdin);
freopen("program.out","w",stdout);
char ch=getchar();while (!islower(ch)) ch=getchar();
while (islower(ch)) s[++n]=ch,pre++,ch=getchar();
s[++n]='%';while (!islower(ch)) ch=getchar();
while (islower(ch)) s[++n]=ch,ch=getchar();Make(s);H[1]=0;
//分隔符'%'比'a'小,所以第一个串一定是'%XXX',就不用管了XD
for (int i=2;i<=n;i++){
if (SA[i]<=pre){
ans+=B.Ask(1,H[i]);if (H[i]<H[i-1]) A.Clear(H[i]+1,n),B.Clear(H[i]+1,n);
A.Insert(1,H[i+1],1);
} else{
ans+=A.Ask(1,H[i]);if (H[i]<H[i-1]) A.Clear(H[i]+1,n),B.Clear(H[i]+1,n);
B.Insert(1,H[i+1],1);
}
}
return printf("%lld\n",ans),0;
}