#include <bits/stdc++.h>
using namespace std;
#ifndef lisie_bimbi
#define endl '\n'
#pragma GCC optimize("O3")
#pragma GCC target("avx,avx2,bmi2,fma")
#endif
using ll = long long;
// Large sentinel; segtree::get returns it as the "minimum" of a range that
// does not intersect the query.
const ll inf = 1'000'000'000;
// Not referenced anywhere in the visible code.
// NOTE(review): presumably the intended upper bound on n — confirm.
const int N = 100'000;
// Summary kept at one segment-tree vertex.
struct node {
    int mn;   // minimum over the vertex's range (pending `mod` not yet included)
    int cnt;  // number of positions that attain `mn`
    int mod;  // pending addition, folded into `mn` by segtree::push
};
// Combines the summaries of two ranges into the summary of their union:
// the smaller minimum wins and keeps its count; equal minima add their counts.
// The result never carries a pending addition (mod == 0).
node merge(node a, node b) {
    node res;
    res.mod = 0;
    if (a.mn == b.mn) {
        res.mn = a.mn;
        res.cnt = a.cnt + b.cnt;
    } else {
        const node &lower = (a.mn < b.mn) ? a : b;
        res.mn = lower.mn;
        res.cnt = lower.cnt;
    }
    return res;
}
struct segtree{
int n;
node d[262144];
void push(int u, int l, int r){
if(l + 1 != r){
d[u * 2 + 1].mod += d[u].mod;
d[u * 2 + 2].mod += d[u].mod;
}
d[u].mn += d[u].mod;
d[u].mod = 0;
}
void build(int u, int l, int r){
d[u].mn = 0;
d[u].cnt = r - l;
d[u].mod = 0;
if(l + 1 == r){
}else{
int m = (l + r) / 2;
build(u * 2 + 1, l ,m);
build(u * 2 + 2, m, r);
}
}
void update(int u, int l, int r, int ql, int qr, int dd){
if((ql >= r) || (qr <= l)){
return;
}
if((ql <= l) && (r <= qr)){
d[u].mod += dd;
return;
}
push(u, l, r);
int m = (l + r) / 2;
update(u * 2 + 1, l, m, ql, qr, dd);
update(u * 2 + 2, m, r, ql, qr, dd);
push(u * 2 + 1, l, m);
push(u * 2 + 2, m, r);
d[u] = merge(d[u * 2 + 1], d[u * 2 + 2]);
}
node get(int u, int l, int r, int ql, int qr){
if((ql >= r) || (qr <= l)){
return {inf, 0};
}
push(u, l, r);
if((ql <= l) && (r <= qr)){
return d[u];
}
int m = (l + r) / 2;
return merge(get(u * 2 + 1, l, m, ql, qr), get(u * 2 + 2, m, r, ql, qr));
}
int gg(int l, int r){
auto [mn, cnt, mod] = get(0, 0, n, l, r);
if(mn == 0){
return cnt;
}else{
return 0;
}
}
void upd(int l, int r, int dd){
update(0, 0, n, l, r, dd);
}
};
// Adjacency lists of the tree: v[x] holds the 0-based neighbours of vertex x.
// Filled in solve() from the n - 1 input edges.
vector<vector<int>> v;
// c[i] is the i-th of the n integers read in solve() after the edges; in the
// visible code it is only consulted by the brute-force branch of solve().
vector<int> c;
// Explores from u without stepping back to par and returns
// {distance, vertex} of a farthest vertex found. Among equally far vertices
// the one reached through the earliest neighbour in v[u] is kept, because a
// later candidate has to be strictly farther to replace it. With no
// neighbours other than par the result is {0, u}.
pair<int, int> furthest(int u, int par) {
    int best = 0;
    int where = u;
    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        const pair<int, int> sub = furthest(to, u);
        const int cand = sub.first + 1;
        if (cand > best) {
            best = cand;
            where = sub.second;
        }
    }
    return {best, where};
}
// Fills d with edge distances from the start vertex: the initial call passes
// par == -1 and gets d[u] = 0; every other vertex gets d[parent] + 1.
// d must already have an entry for every vertex that is visited.
void calcdist(int u, int par, vector<int> &d) {
    d[u] = (par == -1) ? 0 : d[par] + 1;
    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        calcdist(to, u, d);
    }
}
// Post-order pass that raises d[u] to at least d[child] + 1 for every child
// of u (every neighbour except par). d[u] is never reset here, so with a
// zero-filled d on entry (as calc() and calc2() pass) d[u] ends up as the
// height of u's subtree in the tree rooted at the initial vertex.
void calcleaf(int u, int par, vector<int> &d) {
    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        calcleaf(to, u, d);
        const int through_child = d[to] + 1;
        if (through_child > d[u]) {
            d[u] = through_child;
        }
    }
}
// Number of tree vertices, read in solve().
int n;
// Segment tree shared by calc() and calcans(); calc() rebuilds it over [0, n)
// and calcans() addresses it by level (depth below the chosen root).
segtree st;
// DFS from u (parent par, depth `level` below the root) that fills ans[u] for
// u and everything beneath it, using the global segment tree `st` indexed by
// level. d is read-only here; calc() passes the subtree heights from calcleaf.
//
// At u: levels [max(level - d[u], 0), level] get +1, ans[u] is set to
// st.gg(0, level), and the +1 is taken back.
// For each child: let `other` be the largest d[x] + 1 over u's remaining
// children (0 if there are none); levels [max(0, level - other), level) get
// +1 for the duration of the recursive call and are restored afterwards.
// NOTE(review): the +1 marks appear to flag ancestor levels that have another
// vertex at the same distance, so ans[u] would be the number of ancestors
// that are alone at their distance from u — confirm.
void calcans(int u, int par, int level, vector<int> &d, vector<int> &ans) {
    const int own_from = max(level - d[u], 0);
    st.upd(own_from, level + 1, 1);
    ans[u] = st.gg(0, level);
    st.upd(own_from, level + 1, -1);

    // Largest and second largest value of d[child] + 1 over the children.
    int best = 0;
    int second = 0;
    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        const int reach = d[to] + 1;
        if (reach > best) {
            second = best;
            best = reach;
        } else if (reach > second) {
            second = reach;
        }
    }

    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        // A child that attains `best` is measured against the runner-up;
        // every other child is measured against `best`.
        const int other = (d[to] + 1 != best) ? best : second;
        const int from = max(0, level - other);
        st.upd(from, level, 1);
        calcans(to, u, level + 1, d, ans);
        st.upd(from, level, -1);
    }
}
// Roots the tree at s and returns the per-vertex values computed by
// calcans(): subtree heights come from calcleaf(), and the global segment
// tree is rebuilt over [0, n) before the traversal starts.
vector<int> calc(int s) {
    vector<int> height(n);
    calcleaf(s, -1, height);

    st.n = n;
    st.build(0, 0, n);

    vector<int> ans(n);
    calcans(s, -1, 0, height, ans);
    return ans;
}
// DFS from u (parent par, depth `level`) that sets ans[u] = 1 whenever
// NOW + d[u] < level; entries that do not satisfy this are left untouched.
// d is read-only here; calc2() passes the subtree heights from calcleaf.
//
// For each child, `other` is the largest d[x] + 1 over u's remaining children
// (0 if there are none). The child receives NOW unchanged when
// NOW + other < level, and `level` in its place otherwise.
// NOTE(review): NOW looks like the shallowest ancestor level that can still
// be alone at its distance — confirm.
void calcans2(int u, int par, int level, int NOW, vector<int> &d, vector<int> &ans) {
    if (level > NOW + d[u]) {
        ans[u] = 1;
    }

    // Largest and second largest value of d[child] + 1 over the children.
    int best = 0;
    int second = 0;
    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        const int reach = d[to] + 1;
        if (reach > best) {
            second = best;
            best = reach;
        } else if (reach > second) {
            second = reach;
        }
    }

    for (int to : v[u]) {
        if (to == par) {
            continue;
        }
        // A child that attains `best` is measured against the runner-up;
        // every other child is measured against `best`.
        const int other = (d[to] + 1 != best) ? best : second;
        const int next_now = (NOW + other < level) ? NOW : level;
        calcans2(to, u, level + 1, next_now, d, ans);
    }
}
// Roots the tree at s and returns the 0/1 flags produced by calcans2(),
// started at level 0 with NOW = 0 and the subtree heights from calcleaf().
vector<int> calc2(int s) {
    vector<int> height(n);
    calcleaf(s, -1, height);

    vector<int> flags(n);
    calcans2(s, -1, 0, 0, height, flags);
    return flags;
}
void solve(){
int m;
cin >> n >> m;
v.resize(n);
for(int i = 0; i < n - 1; i++){
int x, y;
cin >> x >> y;
x--;y--;
v[x].push_back(y);
v[y].push_back(x);
}
c.resize(n);
for(int i = 0; i < n; i++){
cin >> c[i];
}
if(n <= 2000){
for(int i = 0; i < n; i++){
vector<int> d(n);
calcdist(i, -1, d);
vector<vector<int>> z(n);
for(int j = 0; j < n; j++){
z[d[j]].push_back(c[j]);
}
set<int> cc;
for(int zz = 1; zz < n; zz++){
auto j = z[zz];
if(j.size() == 1){
cc.insert(j[0]);
}
}
cout << cc.size() << endl;
}
return;
}
auto [_, s1] = furthest(0, -1);
auto [__, s2] = furthest(s1, -1);
vector<int> d1(n), d2(n);
calcdist(s1, -1, d1);
calcdist(s2, -1, d2);
vector<int> ans1, ans2;
if(m == 1){
ans1 = calc2(s1);
ans2 = calc2(s2);
}else{
ans1 = calc(s1);
ans2 = calc(s2);
}
for(int i = 0; i < n; i++){
if(d1[i] > d2[i]){
cout << ans1[i] << endl;
}else{
cout << ans2[i] << endl;
}
}
}
// Entry point. When lisie_bimbi is defined, stdin/stdout are redirected to
// input.txt/output.txt; iostream is then untied from C stdio and solve() is
// run for the single test case.
signed main() {
#ifdef lisie_bimbi
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    cin.tie(nullptr)->sync_with_stdio(false);
    // Exactly one test case; a multi-test input would read the count here.
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |