#include <bits/stdc++.h>
using namespace std;
#define all(x) x.begin(), x.end()

// Capacity limits: vertex ids must stay below mxN; mxB binary-lifting levels
// let kth() jump up to 2^mxB - 1 ancestors.
constexpr int mxN = 2e5 + 100;
constexpr int mxB = 21;

vector<int> adj[mxN];   // adjacency lists of the tree
int st[mxN];            // pre-order entry time assigned by dfs()
int sum[mxN];           // per-vertex accumulator written by update()
int depth[mxN];         // depth[v] = depth[parent of v] + 1
int dp[mxN][mxB];       // dp[v][j] = 2^j-th ancestor of v
int parent[mxN];        // direct parent of v
int total = 0;          // running sum of every update() contribution
int tin = 0;            // pre-order counter used by dfs()
void dfs(int u, int par){
parent[u] = dp[u][0] = par, st[u] = ++tin, depth[u] = depth[par] + 1;
for(int i = 1; i < mxB; ++i) dp[u][i] = dp[dp[u][i - 1]][i - 1];
for(auto it : adj[u]){
if(it ^ par) dfs(it, u);
}
}
// Returns the k-th ancestor of u. Each set bit of k selects the matching
// binary-lifting jump dp[][bit]. Only bits 0..mxB-1 of k are consulted.
int kth(int u, int k){
    int node = u;
    for(int bit = 0; bit < mxB; ++bit)
        if((k >> bit) & 1) node = dp[node][bit];
    return node;
}
// Lowest common ancestor of u and v via binary lifting.
// Step 1: lift the deeper endpoint to the other's depth.
// Step 2: lift both endpoints together while their 2^j-ancestors differ.
//         Their direct parent is then the answer.
int lca(int u, int v){
    if(depth[v] > depth[u]) swap(u, v);   // ensure u is the deeper endpoint
    u = kth(u, depth[u] - depth[v]);
    if(u == v) return u;
    for(int j = mxB; j-- > 0; ){
        const int upU = dp[u][j];
        const int upV = dp[v][j];
        if(upU != upV){
            u = upU;
            v = upV;
        }
    }
    return dp[u][0];
}
// Walks from u up through parent[] until a non-positive id is reached. For
// each vertex visited it adds x to sum[] and adds x to the global total.
void update(int u, int x){
    for(int w = u; w > 0; w = parent[w]){
        sum[w] += x;
        total += x;
    }
}
// Adds x along u's root path and along v's root path, then subtracts x along
// the root path of lca(u, v). Net effect on sum[]:
//   - every vertex on the u..v path, including the lca, gains x;
//   - every strict ancestor of the lca also keeps +x.
// NOTE(review): if a pure u..v path update was intended, parent(lca) would
// also need -x. Confirm against the intended formula in main().
void update(int u, int v, int x){
    const int meet = lca(u, v);   // update() does not touch dp/depth, so computing it first is safe
    update(u, x);
    update(v, x);
    update(meet, -x);
}
// Reads a tree with n vertices and a colour per vertex. For each colour
// 1..k, it applies update() to consecutive vertices of that colour in
// pre-order (cyclically), records total / 2 - (#vertices of that colour),
// then rolls the updates back. Prints the minimum recorded value.
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(nullptr);
    int n, k, ans = 1e9;
    cin >> n >> k;
    // Vertices grouped by colour. A vector of vectors replaces the original
    // variable-length array, which is a GCC extension and not standard C++.
    // It is sized so any colour in [0, max(n, k)] is addressable.
    vector<vector<int>> c(max(n, k) + 1);
    for(int i = 0, u, v; i < n - 1; ++i){
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for(int i = 1, x; i <= n; ++i){
        cin >> x;
        if(x >= 0 && x < (int) c.size()) c[x].push_back(i);   // ignore colours we cannot index
    }
    // Root at vertex 1 with vertex 0 as the "no parent" sentinel. Ancestor
    // jumps past the root then land on 0 instead of indexing depth[-1] /
    // dp[-1], and update()'s `while(u > 0)` loop stops there.
    dfs(1, 0);
    for(int i = 1; i <= k; ++i){
        vector<int> &vs = c[i];
        const int cur = (int) vs.size();
        if(cur == 0) continue;   // no vertex has colour i (vs[0] would be out of range)
        sort(all(vs), [&](int x, int y){
            return st[x] < st[y];
        });
        // Pair consecutive vertices in pre-order, wrapping the last back to the first.
        for(int j = 0; j < cur; ++j) update(vs[j], vs[(j + 1) % cur], 1);
        // NOTE(review): update() walks to the root, so this costs O(cur * depth).
        // `total` is an int and may overflow on deep trees with large colour
        // classes. Confirm the constraints.
        ans = min(ans, total / 2 - cur);
        // Undo the updates so total/sum are clean for the next colour.
        for(int j = 0; j < cur; ++j) update(vs[j], vs[(j + 1) % cur], -1);
    }
    cout << ans << "\n";
    return 0;
}
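/*
 * Online-judge results table pasted after the program (results had not loaded yet).
 * Kept verbatim inside a comment so the file still compiles.
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 */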