#include <bits/stdc++.h>
using namespace std;
// Disjoint-set union over the elements 0..n-1 with path compression.
// merge(u, v) always makes the root of u the root of the union ("first retains
// its identity"), so callers can rely on which representative survives.
// Because of that contract there is no union-by-rank, and chains can grow to
// length n; root() is therefore iterative so it never recurses that deep.
class dsu {
private:
    int n;
    vector<int> par;

public:
    explicit dsu(int n) : n(n), par(n) {
        iota(par.begin(), par.end(), 0);
    }

    // Returns the representative of u's set and points every node on the
    // walked path directly at it.
    int root(int u) {
        int r = u;
        while (par[r] != r) {
            r = par[r];
        }
        while (par[u] != r) {  // second walk: compress the path
            int next = par[u];
            par[u] = r;
            u = next;
        }
        return r;
    }

    // Unites the sets of u and v; the root of u becomes the root of the union.
    void merge(int u, int v) {
        u = root(u), v = root(v);
        if (u != v) {
            par[v] = u;
        }
    }
};
// Sentinel larger than any line position or answer used in this program.
constexpr int inf = 1'000'000'000;
// Point-update / range-sum tree over indices 0..n-1, stored bottom-up in a
// flat array of 2n cells: leaves live at [n, 2n), node p covers 2p and 2p+1.
class segment_tree {
private:
    int n;
    vector<int> seg;

public:
    segment_tree(int n) : n(n), seg(2 * n) {}

    // Adds x to element i and refreshes every ancestor sum.
    void add(int i, int x) {
        int node = i + n;
        seg[node] += x;
        while (node > 1) {
            node /= 2;
            seg[node] = seg[2 * node] + seg[2 * node + 1];
        }
    }

    // Returns the sum of elements in the inclusive index range [l, r].
    int query(int l, int r) {
        int total = 0;
        int lo = l + n;      // half-open [lo, hi) over leaf cells
        int hi = r + n + 1;
        while (lo < hi) {
            if (lo % 2 == 1) {
                total += seg[lo++];
            }
            if (hi % 2 == 1) {
                total += seg[--hi];
            }
            lo /= 2;
            hi /= 2;
        }
        return total;
    }
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, k;
cin >> n >> k;
vector<vector<int>> adj(n + 1);
vector<int> deg(n + 1);
for (int i = 0, u, v; i < n - 1; ++i) {
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
++deg[u], ++deg[v];
}
vector<int> line, loc(n + 1);
auto dfs = [&](auto &&self, int u, int p) -> void {
loc[u] = line.size();
line.push_back(u);
for (int &i : adj[u]) {
if (i != p) {
self(self, i, u);
}
}
};
for (int i = 1; i <= n; ++i) {
if (deg[i] == 1) {
dfs(dfs, i, 0);
break;
}
}
vector<int> c(n + 1), city_st(k + 1, inf), city_en(k + 1, -inf);
for (int i = 1; i <= n; ++i) {
cin >> c[i];
city_st[c[i]] = min(city_st[c[i]], loc[i]);
city_en[c[i]] = max(city_en[c[i]], loc[i]);
}
dsu dsu(k + 1);
vector<pair<int, int>> ints(k);
vector<int> int_of(n + 1);
for (int i = 1; i <= k; ++i) {
ints[i - 1] = {city_st[i], city_en[i]};
int_of[city_st[i]] = i;
}
// merge
// 1 2 3 4 5
// |---|
// |-----|
sort(ints.begin(), ints.end(), [&](pair<int, int> a, pair<int, int> b) { return a.second > b.second; });
set<int> lefts;
for (int i = 0; i < ints.size(); ++i) {
auto [l, r] = ints[i];
auto it = lefts.upper_bound(l);
if (it != lefts.end() && *it < r) { // have intersection
int u = int_of[l], v = int_of[*it];
dsu.merge(u, v);
lefts.erase(it);
}
lefts.insert(l);
}
// rebuild ints
ints.clear();
int_of.assign(n + 1, 0);
vector<vector<int>> comps(k + 1);
for (int i = 1; i <= k; ++i) {
comps[dsu.root(i)].push_back(i);
}
for (auto &comp : comps) {
if (comp.empty()) {
continue;
}
int l = city_st[comp.front()], r = city_en[comp.front()];
for (int &x : comp) {
l = min(l, city_st[x]);
r = max(r, city_en[x]);
}
ints.push_back({l, r});
int_of[r] = comp.front();
}
// merge
// 1 2 3 4 5
// |---|
// |-----|
sort(ints.begin(), ints.end(), [&](pair<int, int> a, pair<int, int> b) { return a.first < b.first; });
set<int> rights;
for (int i = 0; i < ints.size(); ++i) {
auto [l, r] = ints[i];
auto it = rights.lower_bound(r);
if (it != rights.begin()) {
--it;
if (*it > l) {
int u = dsu.root(r), v = dsu.root(*it);
dsu.merge(u, v);
rights.erase(it);
}
}
rights.insert(r);
}
// now do complete interval intersection thing
struct interval {
int l, r, c;
};
comps = vector<vector<int>>(k + 1);
for (int i = 1; i <= k; ++i) {
comps[dsu.root(i)].push_back(i);
}
vector<interval> fin; // final intervals
for (auto &comp : comps) {
if (comp.empty()) {
continue;
}
int l = city_st[comp.front()], r = city_en[comp.front()];
for (int &x : comp) {
l = min(l, city_st[x]);
r = max(r, city_en[x]);
}
fin.push_back({l, r, int(comp.size())});
}
// cerr << "so far, we have:\n";
// for (auto &[l, r, c] : fin) {
// cerr << "interval [" << l << ", " << r << "] with cost " << c << '\n';
// }
int ans = inf;
segment_tree st(n + 1); // add at r, iterate over reverse l
sort(fin.begin(), fin.end(), [&](interval a, interval b) { return a.l > b.l; });
for (auto &[l, r, c] : fin) {
ans = min(ans, st.query(0, r) + c);
st.add(r, c);
}
cout << ans - 1 << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |