P4221 [WC2018]州区划分

P4221 [WC2018]州区划分

我们考虑一个 dpdp ,设 dp[s]dp[s] 表示 ss 这个集合城市得到的答案,f[s]f[s] 表示 ss 集合城市的权值和(如果联通且存在欧拉路径我们视之为 0 )。

于是有:

dp[S]=1f[S]sSdp[s]f[Ss]dp[S] = \frac{1}{f[S]}\sum_{s\subset S} dp[s]f[S\setminus s]

这东西形式看起来很像分治 FFT 啊(只是换成了 FWT 而已。。)。

我们考虑子集卷积最后推出的式子是:

Cx=IFMT(0ixFMT(Ai)FMT(Bxi))C_x' = IFMT(\sum_{0\le i\le x} FMT(A'_i)\cdot FMT(B'_{x-i}))

在这里,我们把 f,dpf,dp 分别扩展到二维,因为 ss 不能等于 SS ,所以就是:

dpx,s=ij=s,p(i)+p(j)=x,p(i)xdpp(i),ifp(j),jdp_{x,s} = \sum_{i|j=s,p(i)+p(j)=x,p(i)\neq x} dp_{p(i),i}f_{p(j),j}

写回那 FMT 形式:

dpx=IFMT(0i<xdpi×orfxi)dpx=IFMT(0i<xFMT(dpi)FMT(fxi))dp_{x} = IFMT( \sum_{0\le i< x} dp_i \times_{or} f_{x-i} )\\dp_x = IFMT( \sum_{0\le i<x} FMT(dp_i)\cdot FMT(f_{x-i}) )

所以为了算 dpxdp_x 拿前 0i<x0\le i<xdpdpff 来分别卷一下就好了。

预处理 FMT(fi)FMT(f_i) 复杂度 O(n22n)O(n^22^n)

于是算每个 FMT(dpx)FMT(dp_x) 需要的仅仅是算个点积,复杂度是 O(n22n)O(n^22^n)

然后对于每个 dpxdp_x 都需要 IFMT 回去 乘上前面那个系数 1f[S]\frac 1 {f[S]} 再做回来,复杂度还是 O(n22n)O(n^22^n)

于是总复杂度仍然是 O(n22n)O(n^22^n)

不开 O2 稳 T()

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN ( 1 << 21 ) + 6
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
#define min( a , b ) ( (a) < (b) ? (a) : (b) )
#define max( a , b ) ( (a) > (b) ? (a) : (b) )
#define P 998244353
typedef long long ll;
int n , m , p;

int Pow( int x , int a ) {
    int cur = x % P , ans = 1;
    while( a ) {
        if( a & 1 ) ans = 1ll * ans * cur % P;
        cur = 1ll * cur * cur % P , a >>= 1;
    }
    return ans;
}

int dp[22][MAXN] , f[MAXN] , F[22][MAXN] , w[MAXN];
int G[23][23];

int fa[23] , deg[23];
int find( int x ) { return x == fa[x] ? x : fa[x] = find( fa[x] ); }
int chk( int s ) {
    rep( i , 1 , n ) fa[i] = i , deg[i] = 0;
    rep( i , 1 , n ) if( s & ( 1 << i - 1 ) )
            rep( j , i + 1 , n ) if( ( s & ( 1 << j - 1 ) ) && G[i][j] ) {
                    ++ deg[i] , ++ deg[j] , fa[find( j )] = find( i );
                }
    int t = __builtin_ctz( s ) + 1;
    rep( i , 1 , n ) if( s & ( 1 << i - 1 ) ) {
            if( ( deg[i] & 1 ) || find( i ) != find( t ) ) return 1;
        }
    return 0;
}

void FWTor( int* A , int len ) {
    for( int mid = 2 ; mid <= len ; mid <<= 1 )
        for( int i = 0 ; i < len ; i += mid )
            for( int j = i ; j < i + ( mid >> 1 ) ; ++ j )
                ( A[j + ( mid >> 1 )] += A[j] ) %= P;
}
void IFWTor( int* A , int len ) {
    for( int mid = 2 ; mid <= len ; mid <<= 1 )
        for( int i = 0 ; i < len ; i += mid )
            for( int j = i ; j < i + ( mid >> 1 ) ; ++ j )
                ( A[j + ( mid >> 1 )] += P - A[j] ) %= P;
}

int calc( int x ) {
    return !p ? 1 : ( p == 1 ? x : 1ll * x * x % P );
}

int inv[MAXN];

void solve() {
    cin >> n >> m >> p;
    int u , v;
    rep( i , 1 , m ) {
        scanf("%d%d",&u,&v);
        G[u][v] = G[v][u] = 1;
    }
    rep( i , 1 , n ) scanf("%d",w + i);
    rep( i , 1 , ( 1 << n ) - 1 ) {
        f[i] = (f[i ^ (i & -i)] + w[__builtin_ctz(i & -i) + 1]), F[__builtin_popcount(i)][i] = chk(i) * calc(f[i]);
        inv[i] = Pow( calc( f[i] ) , P - 2 );
    }
    int len = ( 1 << n );
    rep( i , 0 , n ) FWTor( F[i] , ( 1 << n ) );
    dp[0][0] = 1; FWTor( dp[0] , len );
    rep( i , 1 , n ) {
        rep( j , 0 , i - 1 ) {
            rep( k , 0 , len - 1 )
                ( dp[i][k] += 1ll * dp[j][k] * F[i - j][k] % P ) %= P;
        }
        IFWTor( dp[i] , len );
        rep( k , 0 , len - 1 ) dp[i][k] = 1ll * dp[i][k] * inv[k] % P;
        if( i != n ) FWTor( dp[i] , len );
    }
    printf("%d\n",dp[n][( 1 << n ) - 1]);
}

signed main() {
//    freopen("walk2.in","r",stdin);
//    freopen("fuckout","w",stdout);
//    int T;cin >> T;while( T-- ) solve();
    solve();
}

\