ZR1353 20三月省选 Day4 C dict

好题。也是一个会一半的题目,最后一半是一个没怎么见过的套路。

考虑字典序比较,常见的方法是枚举 LCP\text{LCP} 长度,然后固定下一位小于,然后剩下的随便放。

在这个题中,我们一样可以枚举 LCP\text{LCP} 的长度,比如长度为 kk ,也就意味着第 p1pkp_1 \dots p_k 小在 A,BA,B 中是相等的,而 AA 的第 pk+1p_{k+1} 小比 BB 的小。

我们考虑对两个集合分别排序后处理。对于每个 kk ,我们考虑 pk+1p_{k+1}AA 的取值对答案的影响,设这个位置到左边第一个已经确定的位置中间有 ll 个空位,这个位置到右边第一个有 rr 个空位,这个位置到两边的值域分别由 L,RL,R 个数,那么贡献是是

cBk+1(cLl)(Rcr)\sum_{c \ge B_{k+1}} \binom{c-L}{l} \binom{R-c}{r}

而这个位置一旦确定,也就是枚举到了 k+1k+1 ,那么相当于是删除前后位置间这一段的贡献,加入新的两段的贡献。明显这个贡献是个组合数。

那么就有了一个不怎么行的算法,用 set\text{set} 维护这些段,然后每次暴力枚举这个上面说的贡献,复杂度 O(nm)O(nm)

继续观察,会发现这个操作本质上是在值域上进行一次分裂,把值域 [L,R][L,R] 分裂成了 [L,c1],[c+1,R][L,c-1] , [c+1,R] 。我们可以进行启发式分裂,也就是每次枚举小的那一边的话,复杂度就会被降到 O(nlogm)O(n\log m) 。可以发现这个式子是一个经典的组合恒等式

RcL(cLl)(Rcr)=(RL+1r+l+1)\sum_{R \ge c \ge L} \binom{c-L} l \binom{R - c} r = \binom{R-L + 1}{r + l + 1}

所以可以直接减掉一部分得到另一部分。

最后总复杂度变成了 O(n(logm+logn))O(n(\log m + \log n))

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
#include "cassert"
using namespace std;
#define MAXN 500006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
const int P = 998244353;
int n , m;
int A[MAXN] , B[MAXN] , p[MAXN];

int Pow( int x , int a ) {
	int ret = 1;
	while( a ) {
		if( a & 1 ) ret = ret * 1ll * x % P;
		x = x * 1ll * x % P , a >>= 1;
	}
	return ret;
}
int J[MAXN] , iJ[MAXN];
int C( int a , int b ) {
	if( a < 0 || b < 0 || a - b < 0 ) return 0;
	return J[a] * 1ll * iJ[b] % P * iJ[a - b] % P;
}

struct tcurts {
	int l , r , L , R;
	tcurts( int l = 0 , int r = 0 , int L = 0 , int R = 0 ) : l(l) , r(r) , L(L) , R(R) {}
	bool operator < ( const tcurts& s ) const {
		return l == s.l ? r < s.r : l < s.l;
	}
};

set<tcurts> S;


void solve() {
	J[0] = iJ[0] = 1;
	rep( i , 1 , MAXN - 1 ) J[i] = J[i - 1] * 1ll * i % P , iJ[i] = Pow( J[i] , P - 2 ) % P;
	cin >> n >> m;
	rep( i , 1 , n ) scanf("%d",B + i);
	sort( B + 1 , B + 1 + n );
	rep( i , 1 , n ) scanf("%d",p + i);
	S.insert( tcurts( 1 , n , 1 , m ) );
	int as = C( m , n ) , ans = 0;
	rep( i , 1 , n ) {
		int t = p[i] , c = B[p[i]];
		auto s = S.upper_bound( tcurts( t , 0x3f3f3f3f ) ); -- s;
		int l = t - s -> l , r = s -> r - t , L = s -> L , R = s -> R;
		int tot = C( R - L + 1 , l + r + 1 );
		as = as * 1ll * Pow( tot , P - 2 ) % P;
		if( c - L <= R - c + 1 ) {
			int s = 0;
			per( j , c - 1 , L ) 
				s = ( s + C( j - L , l ) * 1ll * C( R - j , r ) ) % P;
			ans = ( ans + as * 1ll * s ) % P;
		} else {
			int s = 0;
			rep( j , c , R ) 
				s = ( s + C( j - L , l ) * 1ll * C( R - j , r ) ) % P;
			s = ( tot + P - s ) % P;
			ans = ( ans + as * 1ll * s ) % P;
		}
		S.erase( s );
		if( l ) S.insert( tcurts( t - l , t - 1 , L , c - 1 ) ) , as = as * 1ll * C( c - L , l ) % P;
		if( r ) S.insert( tcurts( t + 1 , t + r , c + 1 , R ) ) , as = as * 1ll * C( R - c , r ) % P;
		if( !as ) break;
	}
	cout << ans << endl;
}

signed main() {
//    int T;cin >> T;while( T-- ) solve();
    solve();
}

\