CF566C A Logistical Questions

考虑当前在 uu ,如何确定带权重心在哪个子树中。考虑往某个子树移动一个微小距离,那么所有点到这个点距离的变化是

tSv(xt1.5(xtΔx)1.5)tSv((xt+Δx)1.5xt1.5)\sum_{t\notin S_v} (x_t^{1.5} - (x_t - \Delta x)^{1.5} ) - \sum _{t \in S_v} ((x_t + \Delta x)^{1.5} - x_t^{1.5})

如果这个东西大于 00 ,就意味着往这个点移是优的,也就是

tSv(xt1.5(xtΔx)1.5)tSv((xt+Δx)1.5xt1.5)tSv(xt1.5(xtΔx)1.5)tSv((xt+Δx)1.5xt1.5)Δx>0\sum_{t\notin S_v} (x_t^{1.5} - (x_t - \Delta x)^{1.5} ) - \sum _{t \in S_v} ((x_t + \Delta x)^{1.5} - x_t^{1.5})\\ \frac{\sum_{t\notin S_v} (x_t^{1.5} - (x_t - \Delta x)^{1.5} ) - \sum _{t \in S_v} ((x_t + \Delta x)^{1.5} - x_t^{1.5})}{\Delta x} > 0\\

如果我们设 f(x)=x1.5f(x) = x^{1.5} ,那么就是

tSvf(xt)tSvf(xt)>0S2tSvf(xt)>0\sum_{t \notin S_v} f'(x_t) - \sum_{t\in S_v} f'(x_t) > 0\\ S - 2\sum_{t \in S_v} f'(x_t) > 0

我们点分治一下,这样就只需要跳 logn\log n 次了。

可能还需要证明一下能取到最小值的位置只有 11 个位置,因为 f(x)f(x) 是凸函数,所以只有一个点能取到最优的位置。

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "assert.h"
#include "cmath"
using namespace std;
#define MAXN 200006
int n;
int A[MAXN];
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , wto[MAXN << 1] , ecn;
void ade( int u , int v , int w ) {
    wto[++ ecn] = w , to[ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}
int vis[MAXN] , siz[MAXN] , sz;
int p , mx;
void dfs( int u , int fa ) {
    siz[u] = 1;
    int s = 0;
    for( int i = head[u] ; i ; i = nex[i] ) {
        int v = to[i];
        if( v == fa || vis[v] ) continue;
        dfs( v , u );
        siz[u] += siz[v];
        s = max( s , siz[v] );
    }
    s = max( s , sz - siz[u] );
    if( s < mx ) p = u , mx = s;
}
int dep[MAXN]; double all , pia;
int via = 0 , pre;
double getit( int u , int fa ) {
    if( u == pre ) via = 1;
    double ret = A[u] * 1.5 * sqrt( 1.0 * dep[u] );
    all += A[u] * sqrt( 1.0 * dep[u] ) * dep[u];
    pia += ret;
    for( int i = head[u] ; i ; i = nex[i] ) {
        int v = to[i];
        if( v == fa ) continue;
        dep[v] = dep[u] + wto[i];
        ret += getit( v , u );
    }
    return ret;
}
double re[MAXN];
double ans = 8e18; int ps;
void solve( int u ) {
//    cout << u << endl;
    vis[u] = 1;
    all = pia = 0.0;
    int tr = u;
    for( int i = head[u] ; i ; i = nex[i] ) {
        int v = to[i];
        dep[v] = wto[i];
        re[v] = getit( v , u );
        if( via ) tr = v;
    }
    if( all < ans ) ans = all , ps = u;
    for( int i = head[u] ; i ; i = nex[i] ) {
        int v = to[i];
        if( vis[v] ) continue;
        if( pia - 2 * re[v] <= 1e-7 ) {
            p = 0 , mx = 0x3f3f3f3f;
            sz = siz[u];
            dfs( v , u );
            pre = u;
            solve( p );
            dep[tr] = 0.0;
            all = 0;
            getit( tr , tr );
            if( all < ans ) ans = all , ps = tr;
        }
    }
}
int main() {
    cin >> n;
    for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&A[i]);
    for( int i = 1 , u , v , w ; i < n ; ++ i ) {
        scanf("%d%d%d",&u,&v,&w) , ade( u , v , w ) , ade( v , u , w );
    }
    mx = 0x3f3f3f3f;
    sz = n;
    dfs( 1 , 1 );
    solve( p );
    printf("%d %.7lf",ps,ans);
}
\