제출 #1316482

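# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (execution time) | 메모리 (memory)
1316482 | — | pvpro | JOI tour (JOI24_joitour) | C++20 | 64 / 100 | 3112 ms | 375216 KiB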
#ifndef LOCAL
// GCC pragma names are case-sensitive: `Optimize` / `Target` were silently
// ignored. Only GCC understands these pragmas, and the target list is
// x86-specific, so both are guarded.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC optimize("O3,Ofast,unroll-loops")
#if defined(__x86_64__) || defined(__i386__)
#pragma GCC target("bmi,bmi2,avx,avx2")
#endif
#endif
#endif
#include <utility>
#include <vector>
using namespace std;
using ll = long long;

// Number of centroid-decomposition levels; enough while n < 2^18.
const int LOG_N = 18;

/// 0-indexed Fenwick tree over values of type Val. The caller sizes `t`.
template <class Val>
struct BasicFenwick {
    vector<Val> t;
    BasicFenwick() = default;

    /// Sum of elements [0, r]; 0 when r < 0.
    Val get(int r) const {
        Val res = 0;
        for (; r >= 0; r = (r & (r + 1)) - 1) {
            res += t[r];
        }
        return res;
    }

    /// Adds x to element i.
    void upd(int i, Val x) {
        const int len = static_cast<int>(t.size());
        for (; i < len; i |= i + 1) {
            t[i] += x;
        }
    }
};
// Vertex counts fit in int. The weighted sums kept in zero/two do not: on a
// path they reach (k type-1) * (m type-2) per branch, i.e. ~2.5e9 for N = 2e5.
using F = BasicFenwick<int>;
using FL = BasicFenwick<ll>;

// Per-level centroid data (index lg = decomposition level):
//   CP[lg][v]    centroid of v's level-lg component,
//   fst[lg][v]   child of that centroid whose branch contains v (centroid: itself),
//   tin1/tout1   Euler-tour enter/exit stamps of v in that component, rooted at the centroid.
// Per centroid c:
//   T[t][c]      active type-t vertices indexed by tin1 (t = 0, 1, 2),
//   T[3 + t][c]  active type-t vertices indexed by tout1,
//   zero[c]      weights whose range sum over c's whole component or one branch equals,
//                over type-0 vertices w in range, the active type-1 vertices strictly between c and w,
//   two[c]       the same for type-2 vertices w,
//   zerotwo[c]   (type-0, type-2) pairs lying in different branches of c (c excluded).
vector<vector<int>> graph;
vector<int> CP[LOG_N], lvl, sz, f, tin, tout, binup[LOG_N], tin1[LOG_N], tout1[LOG_N], fst[LOG_N], act;
vector<ll> zerotwo;
vector<F> T[6];
vector<FL> two, zero;
ll ans = 0;
int Tm = 0;

// Iterative DFS from v over the whole tree: records parent (binup[0]), entry
// time tin and exit time tout (= Tm after the subtree). Children are visited in
// adjacency order. Iterative so path-shaped trees do not recurse n levels deep.
void dfs(int v, int prev = 0) {
    binup[0][v] = prev;
    tin[v] = Tm++;
    vector<pair<int, int>> stk;  // (vertex, next adjacency index)
    stk.push_back({v, 0});
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int i = stk.back().second++;
        if (i < static_cast<int>(graph[x].size())) {
            const int y = graph[x][i];
            if (tin[y] == -1) {
                binup[0][y] = x;
                tin[y] = Tm++;
                stk.push_back({y, 0});
            }
        } else {
            tout[x] = Tm;
            stk.pop_back();
        }
    }
}

// True when a is an ancestor of b (or a == b) in the DFS tree rooted at 0.
// (Previously compared tout[b] with itself, so any earlier-visited a matched.)
bool inside(int a, int b) { return tin[a] <= tin[b] && tout[b] <= tout[a]; }

// Binary-lifting LCA in the tree rooted at 0.
// NOTE(review): lca/calcUp/calcWay are not called by the tour counting below;
// they look like brute-force / debugging helpers.
int lca(int a, int b) {
    if (inside(a, b)) {
        return a;
    }
    for (int lg = LOG_N - 1; lg >= 0; --lg) {
        if (!inside(binup[lg][a], b)) {
            a = binup[lg][a];
        }
    }
    return binup[0][a];
}

// Number of type-t vertices on the path from a up to the root 0, inclusive. O(depth).
int calcUp(int a, int t) {
    int res = 0;
    while (a != 0) {
        res += (f[a] == t);
        a = binup[0][a];
    }
    return res + (f[a] == t);
}

// Number of type-t vertices on the path a..b, inclusive.
int calcWay(int a, int b, int t) {
    const int lc = lca(a, b);
    int res = calcUp(a, t) + calcUp(b, t) - calcUp(lc, t) * 2;
    if (f[lc] == t) {
        ++res;
    }
    return res;
}

// Active type-t strict ancestors of v in its level-lg component tree
// (rooted at the centroid, centroid included).
ll sumup(int v, int lg, int t) {
    return T[t][CP[lg][v]].get(tout1[lg][v]) - T[3 + t][CP[lg][v]].get(tout1[lg][v]);
}

// Active type-t vertices in v's subtree (v included) of its level-lg component tree.
ll sumsub(int v, int lg, int t) {
    return T[t][CP[lg][v]].get(tout1[lg][v] - 1) - T[t][CP[lg][v]].get(tin1[lg][v] - 1);
}

// Tour count contributed by v against the currently active vertices — presumably
// the number of tours in which v takes part (add/del adjust `ans` by it).
ll calc(int v) {
    ll res = 0;
    for (int lg = lvl[v] - 1; lg >= 0; --lg) {
        const int c = CP[lg][v];   // centroid of v's level-lg component
        const int b = fst[lg][v];  // root of v's branch under c
        if (f[v] == 0) {
            const ll onenum = sumup(v, lg, 1);
            const ll twonum = sumsub(c, lg, 2) - sumsub(b, lg, 2);
            res += onenum * twonum;
            res += two[c].get(tout1[lg][c] - 1);
            res -= two[c].get(tout1[lg][b] - 1) - two[c].get(tin1[lg][b] - 1);
        } else if (f[v] == 1) {
            res += sumsub(v, lg, 0) * (sumsub(c, lg, 2) - sumsub(b, lg, 2));
            res += sumsub(v, lg, 2) * (sumsub(c, lg, 0) - sumsub(b, lg, 0));
        } else {
            const ll onenum = sumup(v, lg, 1);
            const ll zeronum = sumsub(c, lg, 0) - sumsub(b, lg, 0);
            res += onenum * zeronum;
            res += zero[c].get(tout1[lg][c] - 1);
            res -= zero[c].get(tout1[lg][b] - 1) - zero[c].get(tin1[lg][b] - 1);
        }
    }
    // Level where v itself is the centroid.
    const int own = lvl[v];
    if (f[v] == 0) {
        res += two[v].get(tout1[own][v] - 1);
    } else if (f[v] == 2) {
        res += zero[v].get(tout1[own][v] - 1);
    } else {
        res += zerotwo[v];
    }
    return res;
}

// Deactivates v: subtracts calc(v) from ans, then removes v from every
// level's structures. Counts are read after v's own T entries are removed.
void del(int v) {
    ans -= calc(v);
    for (int lg = lvl[v]; lg >= 0; --lg) {
        const int c = CP[lg][v];
        T[f[v]][c].upd(tin1[lg][v], -1);
        T[f[v] + 3][c].upd(tout1[lg][v], -1);
        if (lg == lvl[v]) {
            continue;  // v is the centroid here: zero/two/zerotwo exclude it
        }
        const int b = fst[lg][v];
        if (f[v] == 0) {
            zerotwo[c] -= sumsub(c, lg, 2) - sumsub(b, lg, 2) - (f[c] == 2 && act[c]);
            zero[c].upd(tin1[lg][v], -(sumup(v, lg, 1) - (f[c] == 1 && act[c])));
        } else if (f[v] == 2) {
            zerotwo[c] -= sumsub(c, lg, 0) - sumsub(b, lg, 0) - (f[c] == 0 && act[c]);
            two[c].upd(tin1[lg][v], -(sumup(v, lg, 1) - (f[c] == 1 && act[c])));
        } else {
            zero[c].upd(tin1[lg][v], -sumsub(v, lg, 0));
            two[c].upd(tin1[lg][v], -sumsub(v, lg, 2));
        }
    }
    act[v] = false;
}

// Activates v (mirror of del): inserts v on every level, then adds calc(v) to ans.
void add(int v) {
    for (int lg = lvl[v]; lg >= 0; --lg) {
        const int c = CP[lg][v];
        T[f[v]][c].upd(tin1[lg][v], 1);
        T[f[v] + 3][c].upd(tout1[lg][v], 1);
        if (lg == lvl[v]) {
            continue;  // v is the centroid here: zero/two/zerotwo exclude it
        }
        const int b = fst[lg][v];
        if (f[v] == 0) {
            zerotwo[c] += sumsub(c, lg, 2) - sumsub(b, lg, 2) - (f[c] == 2 && act[c]);
            zero[c].upd(tin1[lg][v], (sumup(v, lg, 1) - (f[c] == 1 && act[c])));
        } else if (f[v] == 2) {
            zerotwo[c] += sumsub(c, lg, 0) - sumsub(b, lg, 0) - (f[c] == 0 && act[c]);
            two[c].upd(tin1[lg][v], (sumup(v, lg, 1) - (f[c] == 1 && act[c])));
        } else {
            zero[c].upd(tin1[lg][v], sumsub(v, lg, 0));
            two[c].upd(tin1[lg][v], sumsub(v, lg, 2));
        }
    }
    act[v] = true;
    ans += calc(v);
}

// Builds the tree (n vertices, edges u[i]-v[i], vertex types F), the
// centroid decomposition and all counting structures, then activates every
// vertex. Must be called once: global state is not reset.
void init(int n, vector<int> F, vector<int> u, vector<int> v, int q) {
    (void)q;  // the number of queries is not needed up front
    f = F;
    zerotwo.resize(n);
    act.resize(n);
    graph.resize(n);
    lvl.assign(n, -1);
    binup[0].resize(n);
    tin.assign(n, -1);
    tout.resize(n);
    for (int i = 0; i < 6; ++i) {
        T[i].resize(n);
    }
    two.resize(n);
    zero.resize(n);
    for (int i = 0; i < n - 1; ++i) {
        graph[u[i]].push_back(v[i]);
        graph[v[i]].push_back(u[i]);
    }
    dfs(0);
    for (int l = 1; l < LOG_N; ++l) {
        binup[l].resize(n);
        for (int i = 0; i < n; ++i) {
            binup[l][i] = binup[l - 1][binup[l - 1][i]];
        }
    }

    int lg = 0;
    int Tm1 = 0;
    vector<int> order;
    vector<int> par(n);
    vector<pair<int, int>> stk;  // (vertex, next adjacency index)

    // Subtree sizes (sz) of the undecomposed component containing root, rooted at root.
    auto calcSz = [&](int root) {
        order.clear();
        order.push_back(root);
        par[root] = root;
        for (size_t i = 0; i < order.size(); ++i) {
            const int x = order[i];
            sz[x] = 1;
            for (int y : graph[x]) {
                if (lvl[y] == -1 && y != par[x]) {
                    par[y] = x;
                    order.push_back(y);
                }
            }
        }
        // Reverse BFS order: every child is finalized before its parent.
        for (size_t i = order.size(); i-- > 1;) {
            sz[par[order[i]]] += sz[order[i]];
        }
    };

    // Walks from root toward any child holding more than half of the component.
    auto findCenter = [&](int root, int total) {
        int x = root;
        int from = root;
        bool moved = true;
        while (moved) {
            moved = false;
            for (int y : graph[x]) {
                if (lvl[y] == -1 && y != from && sz[y] * 2 > total) {
                    from = x;
                    x = y;
                    moved = true;
                    break;
                }
            }
        }
        return x;
    };

    // Euler tour of center's component at level lg: tin1 on entry, tout1 on
    // exit (one shared counter), plus CP and fst.
    auto paint = [&](int center) {
        par[center] = center;
        fst[lg][center] = center;
        tin1[lg][center] = Tm1++;
        CP[lg][center] = center;
        stk.clear();
        stk.push_back({center, 0});
        while (!stk.empty()) {
            const int x = stk.back().first;
            const int i = stk.back().second++;
            if (i < static_cast<int>(graph[x].size())) {
                const int y = graph[x][i];
                if (lvl[y] == -1 && y != par[x]) {
                    par[y] = x;
                    fst[lg][y] = (x == center) ? y : fst[lg][x];
                    tin1[lg][y] = Tm1++;
                    CP[lg][y] = center;
                    stk.push_back({y, 0});
                }
            } else {
                tout1[lg][x] = Tm1++;
                stk.pop_back();
            }
        }
    };

    for (; lg < LOG_N; ++lg) {
        CP[lg].resize(n);
        fst[lg].resize(n);
        tin1[lg].resize(n);
        tout1[lg].resize(n);
        sz.assign(n, -1);
        for (int i = 0; i < n; ++i) {
            if (lvl[i] != -1 || sz[i] != -1) {
                continue;  // already a centroid, or its component was handled this level
            }
            Tm1 = 0;
            calcSz(i);
            const int center = findCenter(i, sz[i]);
            lvl[center] = lg;
            calcSz(center);
            const int cap = sz[center] * 2;  // every vertex gets a tin1 and a tout1 slot
            for (int j = 0; j < 6; ++j) {
                T[j][center].t.resize(cap);
            }
            two[center].t.resize(cap);
            zero[center].t.resize(cap);
            paint(center);
        }
    }
    for (int i = 0; i < n; ++i) {
        add(i);
    }
}

// Sets the type of vertex v to x and updates the answer.
void change(int v, int x) {
    del(v);
    f[v] = x;
    add(v);
}

long long num_tours() { return ans; }

#ifdef LOCAL
#include <cassert>
#include <cstdio>
// NOTE(review): input is read inside assert(), so a -DNDEBUG build skips it.
int main() {
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    int N;
    assert(scanf("%d", &N) == 1);
    std::vector<int> F(N);
    for (int i = 0; i < N; i++) {
        assert(scanf("%d", &F[i]) == 1);
    }
    std::vector<int> U(N - 1), V(N - 1);
    for (int j = 0; j < N - 1; j++) {
        assert(scanf("%d %d", &U[j], &V[j]) == 2);
    }
    int Q;
    assert(scanf("%d", &Q) == 1);
    init(N, F, U, V, Q);
    printf("%lld\n", num_tours());
    fflush(stdout);
    for (int k = 0; k < Q; k++) {
        int X, Y;
        assert(scanf("%d %d", &X, &Y) == 2);
        change(X, Y);
        printf("%lld\n", num_tours());
        fflush(stdout);
    }
}
#endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...