好题。也是一个会一半的题目,最后一半是一个没怎么见过的套路。
考虑字典序比较,常见的方法是枚举 长度,然后固定下一位小于,然后剩下的随便放。
在这个题中,我们一样可以枚举 的长度,比如长度为 ,也就意味着第 小在 中是相等的,而 的第 小比 的小。
我们考虑对两个集合分别排序后处理。对于每个 ,我们考虑 处 的取值对答案的影响,设这个位置到左边第一个已经确定的位置中间有 个空位,这个位置到右边第一个有 个空位,这个位置到两边的值域分别由 个数,那么贡献是是
而这个位置一旦确定,也就是枚举到了 ,那么相当于是删除前后位置间这一段的贡献,加入新的两段的贡献。明显这个贡献是个组合数。
那么就有了一个不怎么行的算法,用 维护这些段,然后每次暴力枚举这个上面说的贡献,复杂度 。
继续观察,会发现这个操作本质上是在值域上进行一次分裂,把值域 分裂成了 。我们可以进行启发式分裂,也就是每次枚举小的那一边的话,复杂度就会被降到 。可以发现这个式子是一个经典的组合恒等式
所以可以直接减掉一部分得到另一部分。
最后总复杂度变成了 。
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
#include "cassert"
using namespace std;
#define MAXN 500006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
const int P = 998244353;
int n , m;
int A[MAXN] , B[MAXN] , p[MAXN];
int Pow( int x , int a ) {
int ret = 1;
while( a ) {
if( a & 1 ) ret = ret * 1ll * x % P;
x = x * 1ll * x % P , a >>= 1;
}
return ret;
}
int J[MAXN] , iJ[MAXN];
int C( int a , int b ) {
if( a < 0 || b < 0 || a - b < 0 ) return 0;
return J[a] * 1ll * iJ[b] % P * iJ[a - b] % P;
}
struct tcurts {
int l , r , L , R;
tcurts( int l = 0 , int r = 0 , int L = 0 , int R = 0 ) : l(l) , r(r) , L(L) , R(R) {}
bool operator < ( const tcurts& s ) const {
return l == s.l ? r < s.r : l < s.l;
}
};
set<tcurts> S;
void solve() {
J[0] = iJ[0] = 1;
rep( i , 1 , MAXN - 1 ) J[i] = J[i - 1] * 1ll * i % P , iJ[i] = Pow( J[i] , P - 2 ) % P;
cin >> n >> m;
rep( i , 1 , n ) scanf("%d",B + i);
sort( B + 1 , B + 1 + n );
rep( i , 1 , n ) scanf("%d",p + i);
S.insert( tcurts( 1 , n , 1 , m ) );
int as = C( m , n ) , ans = 0;
rep( i , 1 , n ) {
int t = p[i] , c = B[p[i]];
auto s = S.upper_bound( tcurts( t , 0x3f3f3f3f ) ); -- s;
int l = t - s -> l , r = s -> r - t , L = s -> L , R = s -> R;
int tot = C( R - L + 1 , l + r + 1 );
as = as * 1ll * Pow( tot , P - 2 ) % P;
if( c - L <= R - c + 1 ) {
int s = 0;
per( j , c - 1 , L )
s = ( s + C( j - L , l ) * 1ll * C( R - j , r ) ) % P;
ans = ( ans + as * 1ll * s ) % P;
} else {
int s = 0;
rep( j , c , R )
s = ( s + C( j - L , l ) * 1ll * C( R - j , r ) ) % P;
s = ( tot + P - s ) % P;
ans = ( ans + as * 1ll * s ) % P;
}
S.erase( s );
if( l ) S.insert( tcurts( t - l , t - 1 , L , c - 1 ) ) , as = as * 1ll * C( c - L , l ) % P;
if( r ) S.insert( tcurts( t + 1 , t + r , c + 1 , R ) ) , as = as * 1ll * C( R - c , r ) % P;
if( !as ) break;
}
cout << ans << endl;
}
signed main() {
// int T;cin >> T;while( T-- ) solve();
solve();
}