#include <bits/stdc++.h>
// Line terminator for output (a plain newline, so no flush like endl).
constexpr char ent = '\n';
// #define int long long   // left disabled; the program works in 32-bit int
using namespace std;

// Capacity for towns / colours; presumably the input guarantees n, k < maxn (not checked).
constexpr int maxn = 200'020;
// Rows reserved for the lifting tables; the code itself only fills rows 0..17.
constexpr int lk = 19;

// g: tree adjacency lists over towns 1..n.
vector<int> g[maxn];
// e: directed graph over colour nodes (1..k) and helper nodes (k+1..N).
vector<int> e[maxn * lk];
// rev: transpose of e, built before the second Kosaraju pass.
vector<int> rev[maxn * lk];

int up[lk][maxn];   // up[i][v]: 2^i-th ancestor of v, 0 when it does not exist
int go[lk][maxn];   // go[i][v]: helper node for v and its next 2^i - 1 ancestors (0 if never allocated)

int c[maxn];        // c[v]: colour of town v
int tin[maxn];      // DFS entry stamp
int tout[maxn];     // largest entry stamp inside v's subtree
int lca[maxn];      // lca[col]: LCA of all towns of colour col

int n;
int k;
int N;              // number of graph nodes allocated so far
int timer;          // DFS stamp counter
void dfs(int v, int p) {
tin[v] = ++timer;
up[0][v] = p;
go[0][v] = ++N;
e[go[0][v]].push_back(c[v]);
for(int i = 1; i < 18; i++) {
up[i][v] = up[i - 1][up[i - 1][v]];
if(up[i - 1][v] != 0) {
go[i][v] = ++N;
e[go[i][v]].push_back(go[i - 1][v]);
e[go[i][v]].push_back(go[i - 1][up[i - 1][v]]);
}
}
for(int to : g[v]) {
if(to != p) {
dfs(to, v);
}
}
tout[v] = timer;
}
// Returns true when u is an ancestor of v (or u == v), judged by DFS stamps:
// v must be entered no earlier than u, and v's subtree must close inside u's.
bool check(int u, int v) {
    if (tin[u] > tin[v]) {
        return false;
    }
    return tout[u] >= tout[v];
}
// Lowest common ancestor of u and v via the up[][] lifting table.
// If one vertex is an ancestor of the other it is returned directly. Otherwise v is
// lifted to the highest ancestor that is still not an ancestor of u, and that
// vertex's parent is the answer.
int get_lca(int u, int v) {
    if (check(u, v)) return u;
    if (check(v, u)) return v;
    for (int lvl = 17; lvl >= 0; --lvl) {
        const int cand = up[lvl][v];
        if (cand == 0) continue;   // jump would leave the tree
        if (!check(cand, u)) {
            v = cand;
        }
    }
    return up[0][v];
}
// State for Kosaraju's SCC algorithm over the colour/helper graph.
vector<int> ord;          // nodes in order of DFS completion (first pass)
int comp[maxn * 20];      // component id per node; 0 = not assigned yet
int used[maxn * 20];      // visited flag for the first pass
int fg[maxn * 20];        // fg[id] != 0: component id has an edge into another component
int sz[maxn * 20];        // number of colour nodes (ids 1..k) in each component
// First Kosaraju pass: depth-first search over e[] starting at v. Marks used[] and
// appends each node to ord once all of its successors are finished (post-order).
//
// Iterative (explicit stack) because the graph has up to ~N helper nodes. Long
// dependency chains would otherwise recurse far deeper than a default call stack
// allows. The finishing order is identical to the recursive version: nodes are
// marked on entry and successors are scanned in the same order.
void dfs(int v) {
    // Reused across calls so repeated roots do not reallocate.
    static std::vector<std::pair<int, std::size_t>> frames;
    frames.clear();
    used[v] = true;
    frames.push_back({v, 0});
    while (!frames.empty()) {
        const int x = frames.back().first;
        std::size_t &next = frames.back().second;
        if (next < e[x].size()) {
            const int to = e[x][next++];
            if (!used[to]) {
                used[to] = true;
                frames.push_back({to, 0});   // invalidates `next`; it is not used again
            }
        } else {
            ord.push_back(x);
            frames.pop_back();
        }
    }
}
// Second Kosaraju pass: assign component id `col` to v and to every node that is
// reachable from v over rev[] and still has comp == 0 (unassigned).
//
// Explicit stack instead of recursion, to avoid deep call chains over the large
// helper graph. The set of labelled nodes does not depend on traversal order, so
// the resulting comp[] is the same as with the recursive version.
void calc(int v, int col) {
    static std::vector<int> pending;   // reused across calls
    pending.clear();
    comp[v] = col;
    pending.push_back(v);
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        for (int to : rev[x]) {
            if (!comp[to]) {
                comp[to] = col;
                pending.push_back(to);
            }
        }
    }
}
void solve() {
cin >> n >> k;
N = k;
for(int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
for(int i = 1; i <= n; i++) {
cin >> c[i];
}
dfs(1, 0);
for(int i = 1; i <= n; i++) {
if(!lca[c[i]]) {
lca[c[i]] = i;
}
lca[c[i]] = get_lca(lca[c[i]], i);
}
for(int v = 1; v <= n; v++) {
int cur = v;
for(int i = 17; i >= 0; i--) {
if(up[i][cur] != 0 && !check(up[i][cur], lca[c[v]])) {
e[c[v]].push_back(go[i][cur]);
cur = up[i][cur];
}
}
if(cur != lca[c[v]]) {
e[c[v]].push_back(go[1][cur]);
}
else {
e[c[v]].push_back(go[0][cur]);
}
}
for(int v = 1; v <= N; v++) {
if(!used[v]) {
dfs(v);
}
}
reverse(ord.begin(), ord.end());
for(int v = 1; v <= N; v++) {
for(int to : e[v]) {
rev[to].push_back(v);
}
vector<int> mus;
e[v].swap(mus);
}
int cnt = 0;
for(int v : ord) {
if(!comp[v]) {
calc(v, ++cnt);
}
}
for(int v = 1; v <= k; v++) {
sz[comp[v]]++;
}
int ans = 1e9;
for(int v = 1; v <= N; v++) {
for(int to : rev[v]) {
if(comp[v] != comp[to]) {
fg[comp[to]] = true;
}
}
}
for(int i = 1; i <= cnt; i++) {
if(!fg[i] && sz[i] > 0) {
ans = min(ans, sz[i]);
}
}
cout << ans - 1 << ent;
}
// Entry point: fast unsynchronised iostreams, then one test case
// (switch the commented line on for multi-test input).
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int tests = 1;
    // cin >> tests;
    for (int run = 0; run < tests; ++run) {
        solve();
    }
}
/*
Scraped online-judge results table (page residue, not part of the program):

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
*/