听说这题卡倍增SA,表示强烈谴责😡(不过我没想到SA怎么做)。
由于需要求任意两个字符串之间的最长公共子串,因此考虑广义SAM,在构造时,记录一下 $i$ 号串出现在哪些节点。
对于广义SAM的一个节点 $x$ ,如果以这个节点作为LCS(长度为 $MAX_x$ ),那么Parent树中 $x$ 的子树里,所有串之间都存在 $MAX_x$ 的边。由于要求最大生成树,因此只要优先让 $MAX_x$ 大的节点先被考虑,就可以保证权值最大(因为 $x$ 的祖先虽然覆盖了 $x$ 的所有串,但是权值是小于 $MAX_x$ 的)。
接下来考虑如何进行并查集合并,肯定是不能暴力合并的,优秀一些的想法是利用启发式合并维护 $x$ 子树中所有串的编号。不过我们不难意识到其实根本没必要维护 $x$ 子树中的编号,只需要随便记录一个编号就行了,因为按照 $MAX_x$ 从大到小考虑时,正好也是按照Parent树从深到浅考虑,因此 $x$ 的儿子的子树中肯定已经全部属于同一个连通块了,所以只需要随便记录一个编号,然后合并的时候按照记录的编号进行加边。不过需要注意的是如果 $x$ 是叶子,那么需要把 $x$ 中出现的串合并一下。
#include<set>
#include<cstdio>
#include<cctype>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxn=2000000,maxt=maxn<<1;
int n;char s[maxn+5];LL ans;
int pl=1,ro=1,son[maxt+5][26],fai[maxt+5],MAX[maxt+5];
int ID[maxt+5],who[maxt+5],fat[maxn+5];vector<int> S[maxt+5],e[maxt+5];
#define EOLN(x) ((x)==10 || (x)==13 || (x)==EOF)
inline char readc(){
static char buf[100000],*l=buf,*r=buf;
return l==r && (r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
int reads(char *s){
int len=0;char ch=readc();
while (!islower(ch)) {if (ch==EOF) return EOF;ch=readc();}
while (islower(ch)) s[++len]=ch,ch=readc();
s[len+1]=0;return len;
}
int Extend(int p,int c,int ID){
if (son[p][c]){
int q=son[p][c];if (MAX[p]+1==MAX[q]) {S[q].push_back(ID);return q;}
int nq=++pl;MAX[nq]=MAX[p]+1;for (int i=0;i<26;i++) son[nq][i]=son[q][i];
fai[nq]=fai[q];fai[q]=nq;while (p && son[p][c]==q) son[p][c]=nq,p=fai[p];
S[nq].push_back(ID);return nq;
} else {
int np=++pl;MAX[np]=MAX[p]+1;
while (p && !son[p][c]) son[p][c]=np,p=fai[p];
if (!p) {fai[np]=ro;S[np].push_back(ID);return np;}
int q=son[p][c];if (MAX[p]+1==MAX[q]) {fai[np]=q;S[np].push_back(ID);return np;}
int nq=++pl;MAX[nq]=MAX[p]+1;for (int i=0;i<26;i++) son[nq][i]=son[q][i];
fai[nq]=fai[q];fai[q]=fai[np]=nq;while (p && son[p][c]==q) son[p][c]=nq,p=fai[p];
S[np].push_back(ID);return np;
}
}
int getfa(int x) {return x==fat[x]?x:fat[x]=getfa(fat[x]);}
void Merge(int x,int y,int z){
x=getfa(x);y=getfa(y);
if (x==y) return;
fat[x]=y;ans+=z;
}
inline bool cmp(const int &i,const int &j) {return MAX[i]>MAX[j];}
int main(){
scanf("%d",&n);
for (int i=1;i<=n;i++){
reads(s);fat[i]=i;
for (int j=1,p=ro;s[j];j++) p=Extend(p,s[j]-'a',i);
}
for (int i=1;i<=pl;i++){
if (fai[i]) e[fai[i]].push_back(i);
ID[i]=i;
}
sort(ID+1,ID+1+pl,cmp);
for (int i=1;i<=pl;i++){
int x=ID[i];
if (!S[x].empty()){
int fr=*S[x].begin();
for (auto y:S[x]) Merge(y,fr,MAX[x]);
who[x]=fr;
}
for (auto u:e[x]){
if (!who[x]) who[x]=who[u];
Merge(who[x],who[u],MAX[x]);
}
}
printf("%lld\n",ans);
return 0;
}