#include "bits/stdc++.h"
using namespace std;
#ifdef DEBUG
// Debug-only printers for pairs and for any type with .end() (containers).
// They use C++20 abbreviated function templates and are compiled only when
// DEBUG is defined.
auto&operator<<(auto &o, pair<auto, auto> p) {o << "(" << p.first << ", " << p.second << ")"; return o;}
auto operator<<(auto&o,auto x)->decltype(x.end(),o){o<<"{"; for(auto e : x) o<<e<<", "; return o<<"}";}
// Prints "[expr]: value" to stderr. It expands to a single expression with no
// trailing ';', so `debug(x);` is exactly one statement and is safe inside
// an if/else.
#define debug(X) (cerr << "[" #X "]: " << (X) << '\n')
#else
// Release build: every `cerr << ...` statement is compiled out.
// The `if (true) {} else` form is used instead of a bare `if (0)` because it
// is already a complete if/else. A caller's following `else` therefore still
// binds to the caller's own `if`, so `if (c) cerr << x; else f();` keeps its
// meaning.
#define cerr if (true) {} else std::cerr
// No-op expression; its argument is not evaluated.
#define debug(X) ((void)0)
#endif
using ll = long long;
#define all(v) (v).begin(), (v).end()
#define ssize(x) int((x).size())
#define fi first
#define se second
#define mp make_pair
#define eb emplace_back
// Sentinel cost for unreachable DP states. 1e9 is below INT_MAX / 2, so two
// sentinels plus a small constant (as summed in comb) still fit in an int.
const int inf = 1e9;

// DP state of one rooted subtree: dp[col][flag] is the cheapest cost found for
// a state in which the subtree's root has colour index `col` (0 or 1), with an
// extra binary `flag`.
// NOTE(review): the flag presumably marks "the walk's endpoint lies in this
// part"; confirm against the problem statement.
struct node {
    int dp[2][2];

    // Every state unreachable.
    node() {
        dp[0][0] = dp[0][1] = dp[1][0] = dp[1][1] = inf;
    }

    // A single vertex whose initial colour is x, which must be 0 or 1:
    //   - keeping colour x costs 0 (flag 1); colour x with flag 0 is unreachable;
    //   - the opposite colour 1 - x costs 1 with either flag.
    explicit node(int x) {
        // Any other value would make dp[1 - x] / dp[x] index out of bounds.
        assert(x == 0 || x == 1);
        dp[1 - x][0] = dp[1 - x][1] = 1;
        dp[x][1] = 0;
        dp[x][0] = inf;
    }

    // Prints as {{dp[0][0], dp[0][1]}, {dp[1][0], dp[1][1]}}; used by debug().
    friend std::ostream &operator<<(std::ostream &os, const node &a) {
        return os << "{{" << a.dp[0][0] << ", " << a.dp[0][1] << "}, {"
                  << a.dp[1][0] << ", " << a.dp[1][1] << "}}";
    }
};
// Merges the DP of one already-solved child subtree `b` into the DP `a`
// accumulated so far for its parent, and returns the combined parent state.
//
// Indexing: a.dp[i][j] has the parent's colour index i and flag j;
// b.dp[x][y] has the child root's colour index x and flag y. solve_dp reads
// index 1 of the root as the goal colour. At most one of the two parts may
// carry flag 1, so pairs with j && y are skipped.
// NOTE(review): the flag presumably means "the walk's endpoint lies in this
// part"; confirm against the problem statement.
//
// Overflow: if every entry of `a` and `b` is <= inf (1e9), each candidate
// a + b + add is at most 2e9 + 3 < INT_MAX. Because `res` starts at inf,
// every entry of the result also stays <= inf.
node comb(const node &a, const node &b) {
    node res;  // every entry starts at inf (unreachable)
    for (int i = 0; i < 2; ++i) {              // parent colour index
        for (int j = 0; j < 2; ++j) {          // parent flag
            for (int x = 0; x < 2; ++x) {      // child-root colour index
                for (int y = 0; y < 2; ++y) {  // child flag
                    if (j && y) continue;  // the flag cannot be held by both parts
                    int col = i, add = 0;
                    // Disabled earlier attempt, kept for reference:
                    // if (!j && !y) {
                    // res.dp[col][1] = min(res.dp[col][1], a.dp[i][j] + b.dp[x][y]);
                    // res.dp[col^1][1] = min(res.dp[col^1][1], a.dp[i][j] + b.dp[x][y] + 2);
                    // }
                    // Child root ends on colour index 0 (not the goal colour):
                    // pay 2 extra and flip the parent's colour index.
                    // NOTE(review): presumably one extra step to the child and back.
                    if (!x) {
                        col ^= 1;
                        add += 2;
                    }
                    if (y) {
                        // The child holds the flag, so the merged state is flagged
                        // and the parent colour is not changed further.
                        res.dp[col][1] = min(res.dp[col][1], a.dp[i][j] + b.dp[x][y] + add);
                    }
                    else {
                        // Neither part is flagged (j == 0): the merged state may
                        // become flagged at the current cost and colour.
                        if (!j)
                            res.dp[col][1] = min(res.dp[col][1], a.dp[i][j] + b.dp[x][y] + add);
                        // For every y == 0 case, also record the state reached after
                        // paying 1 more and flipping the parent's colour index once
                        // more; the parent's flag j is kept.
                        add++;
                        col ^= 1;
                        res.dp[col][j] = min(res.dp[col][j], a.dp[i][j] + b.dp[x][y] + add);
                    }
                }
            }
        }
    }
    return res;
}
void dfs(int v, int p, const vector<vector<int>> &g, vector<node> &dp) {
for (int u : g[v]) {
if (u == p) continue;
dfs(u, v, g, dp);
dp[v] = comb(dp[v], dp[u]);
}
debug(v);
debug(dp[v]);
}
int solve_dp(int s, const vector<vector<int>> &g, const vector<int> &c) {
int n = ssize(g);
vector<node> dp(n);
for (int i = 0; i < n; ++i) dp[i] = node(c[i]);
debug(s);
dfs(s, -1, g, dp);
cerr << endl;
return min(dp[s].dp[1][0], dp[s].dp[1][1]);
}
int main() {
ios_base::sync_with_stdio(false); cin.tie(nullptr);
int n;
cin >> n;
vector<int> c(n);
vector<vector<int>> g(n);
for (int i = 0; i < n; ++i) {
char x;
cin >> x;
c[i] = int(x - '0');
}
for (int i = 0; i < n-1; ++i) {
int u, v;
cin >> u >> v;
g[u-1].eb(v-1);
g[v-1].eb(u-1);
}
int res = inf;
for (int i = 0; i < n; ++i) res = min(res, solve_dp(i, g, c));
cout << res << '\n';
return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |