P4437 [HNOI/AHOI2018]排列

考虑这个限制本质上就是限制了 iiaia_i 后被选,也就是如果从 iiaia_i 连边,然后就是必须先选择祖先再选择儿子的一个选择点的问题,使得最后 iAi\sum{iA_i} 尽量大。如果存在环,显然无解。

我们考虑全局最小的点,我们显然可以把它合并到它的父亲上,因为在选择它的父亲后选择这个点本身一定是最优的决策。如果一直合并,我们会得到很多的联通块。但是如何确定两个块谁先被合并?我们考虑合并 W,VW,V 的代价:

vu:vVwvpv+uWwu(ku+mv)uv:uWwupu+vVwv(kv+mu)v \to u : \sum_{v \in V} w_vp_v + \sum_{u \in W} w_u(k_u + m_v)\\ u \to v : \sum_{u \in W} w_up_u + \sum_{v \in V} w_v(k_v + m_u)\\

考虑下面大于上面,也就是说先选择 uu 更优秀,那么

vVwumvvWwvmu\sum_{v\in V} w_um_v \le \sum_{v \in W} w_vm_u

也就是

uWwumuvVwvmv\frac{\sum_{u\in W} w_u}{m_u} \le \frac{\sum_{v\in V}w_v}{m_v}

按照这个东西贪心,拿堆维护即可。

最后所有东西都会被合并到 00 上。

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 500006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
#define P 1000000007
int n , k;
int f[MAXN] , w[MAXN] , cnt;
struct node {
	ll w; int m , idx; ll as; int tim = 0;
	node( ll w = 0 , int m = 0 , int idx = 0 , ll as = 0 , int tm = 0 ) : w(w) , m(m) , idx(idx) , as(as) , tim(tm) {}
	bool operator < ( node x ) const {
		return 1ll * w * x.m > 1ll * x.w * m;
	}
	node operator + ( node x ) {
		return node( x.w + w , m + x.m , idx , as + x.as + x.w * 1ll * m , ++ cnt );
	}
} nd[MAXN] ;
priority_queue<node> Q;

int fa[MAXN];
int find( int x ) {
	return x == fa[x] ? x : fa[x] = find( fa[x] );
}

void solve( ) {
	cin >> n;
	rep( i , 1 , n ) fa[i] = i;
	rep( i , 1 , n ) {
		scanf("%d",&f[i]);
		if( find( i ) == find( f[i] ) ) puts("-1") , exit(0);
		else {
			fa[find( i )] = find( f[i] );
		}
	}
	rep( i , 1 , n ) fa[i] = i;
	rep( i , 1 , n ) {
		scanf("%d",w + i);
		Q.push( nd[i] = (node) { w[i] , 1 , i , w[i] , 0 } );
	}
	while( !Q.empty() ) {
		node tp = Q.top(); Q.pop();
		if( nd[tp.idx].tim != tp.tim ) continue;
		int x = find( f[tp.idx] );
		fa[tp.idx] = x;
		nd[x] = nd[x] + tp;
		if( x ) Q.push( nd[x] );
	}
	printf("%lld\n",nd[0].as);
}

signed main() {
//	freopen("perm6.in","r",stdin);
//	freopen("fuckout","w",stdout);
//    int T;cin >> T;while( T-- ) solve();
	solve();
}
\