CF1307F Cow and Vacation

首先在边上建点,于是距离就从 kk 变成了 2k2k

距离变成偶数后,对于两个位置,它们可达的条件就是从其中一个走 kk 另一个走 kk 能到的点有交。

首先考虑以每个点作为中转点可以到达的点。也就是只要某一步到达了这个点就一定可以通过转驿站来走到的点。做法是把所有驿站加入队列距离设置为 00 进行一次 bfs 。考虑当我们从一个点到达另一个已经访问过的点,那么这两个点所对应的联通块可以直接合并,因为总是可以通过这两个点的初始驿站互相到达。如果当前点与最近的驿站的距离已经超过了 kk 就弹出。直接用并查集维护即可。

然后考虑两个点之间是否可达。相当于两个点作为了一个初始的驿站。根据一个很显然的贪心,我们可以通过向对方走 kk 步然后判断这两个点作为中转点是否可达即可。

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
#include "stack"
#include "cassert"
#include "string"
using namespace std;
#define MAXN 500306
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
const int P = 924844033;
int n , k , r;
int N;
vi G[MAXN];

int g[MAXN][20] , dep[MAXN];
void dfs( int u , int f ) {
	dep[u] = dep[f] + 1;
	for( int v : G[u] ) if( v != f ) {
		g[v][0] = u;
		rep( k , 1 , 19 ) if( g[g[v][k - 1]][k - 1] ) g[v][k] = g[g[v][k - 1]][k - 1]; else break;
		dfs( v , u );
	}
}
int lca( int u , int v ) {
	if( dep[u] < dep[v] ) swap( u , v );
	for( int k = 19 ; k >= 0 ; -- k ) if( dep[g[u][k]] >= dep[v] ) u = g[u][k];
	if( u == v ) return u;
	for( int k = 19 ; k >= 0 ; -- k ) if( g[u][k] != g[v][k] ) u = g[u][k] , v = g[v][k];
	return g[u][0];
}

int d[MAXN];
int fa[MAXN];
int find( int x ) {
	return x == fa[x] ? x : fa[x] = find( fa[x] );
}

void solve() {
	cin >> n >> k >> r;
	N = n;
	rep( i , 2 , n ) {
		int u , v;
		scanf("%d%d",&u,&v);
		++ N;
		G[u].pb( N ) , G[N].pb( v );
		G[v].pb( N ) , G[N].pb( u );
	}
	dfs( 1 , 1 );
	queue<int> Q;
	memset( d , 0x3f , sizeof d );
	rep( i , 1 , r ) {
		int u;
		scanf("%d",&u);
		Q.push( u );
		d[u] = 0;
	}
	rep( i , 1 , N ) fa[i] = i;
	while( !Q.empty() ) {
		int u = Q.front(); Q.pop();
		if( d[u] >= k ) break;
		for( int v : G[u] ) {
			fa[find( v )] = find( u );
			if( d[v] > 1e9 ) d[v] = d[u] + 1 , Q.push( v );
		}
	}
	int L;
	auto dis = [&]( int u , int v ) {
		return dep[u] + dep[v] - 2 * dep[L];
	};
	auto walk = [&] ( int u , int v , int k ) {
		if( k > dep[u] - dep[L] ) {
			k -= ( dep[u] - dep[L] );
			k = ( dep[v] - dep[L] - k );
			for( int t = 19 ; t >= 0 ; -- t ) if( k >= ( 1 << t ) ) v = g[v][t] , k -= ( 1 << t );
			return v;
		} else {
			for( int t = 19 ; t >= 0 ; -- t ) if( k >= ( 1 << t ) )
				u = g[u][t] , k -= ( 1 << t );
			return u;
		}
	};
	int q; cin >> q;
	while( q-- ) {
		int u , v;
		scanf("%d%d",&u,&v);
		L = lca( u , v );
		if( dis( u , v ) <= 2 * k ) puts("YES");
		else if( find( walk( u , v , k ) ) == find( walk( v , u , k ) ) )
			puts("YES");
		else puts("NO");
	}
}

signed main() {
//	freopen("sample_wave10.in","r",stdin);
//	freopen("wave.out","w",stdout);
//	freopen("input","r",stdin);
//	freopen("sot","w",stdout);
//	int T;cin >> T;while( T-- ) solve();
	solve();
}
\