从来没做过这种容斥,长见识了😭。
定义 $f_i$ 表示走了 $i$ 行都合法的权值和,以及 $g_i$ 表示走了 $i$ 行全非法的权值和(特殊的,如果 $i=1$ 也认为非法)。
先考虑前 $n-1$ 行都合法,则权值为 $f_{n-1}g_1$ ,这样有个问题,就是 $n-1$ 行到 $n$ 行时不一定合法,因此我们强制 $n-1$ 到 $n$ 非法,即减去 $f_{n-2}g_2$ ,但是类似的会发现 $f_{n-2}g_2$ 里包含了 $n-2$ 到 $n-1$ 非法的情况是我们多减的,所以加上 $f_{n-3}g_3$ 。以此类推会发现这是个容斥:
$$ f_{n}=\sum_{i=1}^{n}(-1)^{i-1}f_{n-i}g_i=\sum_{i=1}^{n}f_{n-i}(-1)^{i-1}g_{i}\\ h_{i}=(-1)^{i-1}g_i\\ f_n=\sum_{i=1}^{n}f_{n-i}h_i $$
因此 $f=f*h+1$ ,即 $f={1\over 1-h}$ 。现在的问题只剩下求出 $g$ 。
由于 $g$ 是全非法的状态,所以所有行均满足下一个 $P>y+S(S(S(y)))$ ,也就是说列是递增的,我们可以在列上考虑DP。
不难发现 $S(S(S(y)))$ 最多只有 $2$ ,因此选了一个列之后最多不选两次就可以再次选下一个列。
定义 $F_{i,0/1/2}$ 表示 选到第 $i$ 列时还剩下 $0/1/2$ 次才可以再次选 的生成函数。那么考虑转移(假设 $S(S(S(i)))=k$ ,$i$ 权值为 $v_i$ ):
$$ vx\cdot F_{i-1,0}\to F_{i,k}\\ F_{i-1,0}\to F_{i,0}\\ F_{i-1,1}\to F_{i,0}\\ F_{i-1,2}\to F_{i,1} $$
不难发现可以利用矩阵进行转移,进一步的,我们把每个位置的转移矩阵 $M_i$ 存下来之后做分治矩阵乘法就可以得到 $F_{m}$ 。
令 $G=F_{m,0}+F_{m,1}+F_{m,2}$ ,则 $g_i=G_i$ ,求出 $h$ 之后多项式求逆就可以得到 $f$ 。
最后 $f_n$ 就是答案。
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;typedef vector<int> PN;
const int maxn=100000,maxt=1<<18,MOD=998244353;
int n,m,a[maxn+5];
int wn[maxt+5],tem[maxt+5];
struct Matrix{
PN s[3][3];
void zero() {for (int i=0;i<3;i++) for (int j=0;j<3;j++) s[i][j]={0};}
};
Matrix M[maxn+5],res;
PN G;int g[maxt+5],f[maxt+5];
inline int ADD(int x,int y) {return x+y>=MOD?x+y-MOD:x+y;}
inline int MUL(int x,int y) {return (LL)x*y%MOD;}
int Pow(int w,int b) {int s;for (s=1;b;b>>=1,w=MUL(w,w)) if (b&1) s=MUL(s,w);return s;}
void Make(){
int x=Pow(3,(MOD-1)/maxt);
wn[maxt/2]=1;
for (int i=maxt/2+1;i<maxt;i++) wn[i]=MUL(wn[i-1],x);
for (int i=maxt/2-1;i>=0;i--) wn[i]=wn[i<<1];
}
void NTT(int *a,int n,int f){
if (f>0){
for (int k=n>>1;k;k>>=1)
for (int i=0;i<n;i+=k<<1)
for (int j=0;j<k;j++){
int x=a[i+j],y=a[i+j+k];
a[i+j+k]=MUL(x+MOD-y,wn[k+j]);
a[i+j]=ADD(x,y);
}
} else {
for (int k=1;k<n;k<<=1)
for (int i=0;i<n;i+=k<<1)
for (int j=0;j<k;j++){
int x=a[i+j],y=MUL(a[i+j+k],wn[k+j]);
a[i+j+k]=ADD(x,MOD-y);
a[i+j]=ADD(a[i+j],y);
}
for (int i=0,INV=MOD-(MOD-1)/n;i<n;i++) a[i]=MUL(a[i],INV);
reverse(a+1,a+n);
}
}
PN operator * (const PN &a,const PN &b){
static int A[maxt+5],B[maxt+5];static PN c;
int n=a.size(),m=b.size(),t;
for (t=1;t<n+m-1;t<<=1);
for (int i=0;i<n;i++) A[i]=a[i];for (int i=n;i<t;i++) A[i]=0;
for (int i=0;i<m;i++) B[i]=b[i];for (int i=m;i<t;i++) B[i]=0;
NTT(A,t,1);NTT(B,t,1);
for (int i=0;i<t;i++) A[i]=MUL(A[i],B[i]);NTT(A,t,-1);
c.clear();for (int i=0;i<n+m-1;i++) c.push_back(A[i]);
return c;
}
void operator += (PN &a,const PN &b){
if (b.size()>a.size()) a.resize(b.size());
for (int i=0,si=b.size();i<si;i++) a[i]=ADD(a[i],b[i]);
}
Matrix operator * (const Matrix &a,const Matrix &b){
static Matrix c;c.zero();
for (int i=0;i<3;i++)
for (int j=0;j<3;j++)
for (int k=0;k<3;k++)
c.s[i][j]+=a.s[i][k]*b.s[k][j];
return c;
}
Matrix Solve(int L,int R) {if (L==R) return M[L];int mid=L+(R-L>>1);return Solve(L,mid)*Solve(mid+1,R);}
void Inv(int *a,int *b,int n){ // a=1/b
if (n==1) {a[0]=Pow(b[0],MOD-2);return;}
Inv(a,b,n>>1);
for (int i=0;i<n;i++) tem[i]=b[i],tem[i+n]=a[i+n]=0;
NTT(tem,n<<1,1);NTT(a,n<<1,1);
for (int i=0;i<(n<<1);i++) tem[i]=MUL(a[i],2+MOD-MUL(tem[i],a[i]));
NTT(tem,n<<1,-1);for (int i=0;i<n;i++) a[i]=tem[i],a[i+n]=0;
}
int main(){
Make();
scanf("%d%d",&n,&m);
for (int i=1;i<=m;i++) scanf("%d",&a[i]);
for (int i=1;i<=m;i++){
M[i].zero();M[i].s[1][0]=M[i].s[2][1]={1};
if (i<16) M[i].s[0][0]={1,a[i]}; else
if (i<65536) M[i].s[0][0]={1},M[i].s[0][1]={0,a[i]};
else M[i].s[0][0]={1},M[i].s[0][2]={0,a[i]};
}
res=Solve(1,m);G={0};
for (int i=0;i<3;i++) G+=res.s[0][i];
g[0]=1;for (int i=1,si=G.size();i<si;i++) g[i]=(i&1?MOD-G[i]:G[i])%MOD;
int t;for (t=1;t<=n;t<<=1);
Inv(f,g,t);
printf("%d\n",f[n]);
return 0;
}