#include <bits/stdc++.h>
using namespace std;
// n: number of vertices (stored 1-indexed: vertices 1..n, 0 is a sentinel).
// LOG: bit length of n; __lca scans lifting levels LOG-1 down to 0.
int n, LOG;
// Tree adjacency list: adj[u] holds (neighbor, edge length) pairs.
vector <vector <pair <int, int>>> adj;
// Unweighted depth; the root (vertex 1) has depth 1 and sentinel vertex 0 has depth 0.
vector <int> depth;
// Weighted distance from the root (vertex 1).
vector <long long> d;
// par[i][u] = 2^i-th ancestor of u; stays 0 (the sentinel) when that jump passes the root.
vector <vector <int>> par;
// time_in[u]: preorder entry index; time_out[u]: largest entry index inside u's subtree.
vector <int> time_in, time_out;
// Lowest common ancestor of u and v via binary lifting over par[][].
// Relies on the sentinel vertex 0 having depth 0, so a jump past the root
// is never taken in the levelling phase and never differs in the second phase.
int __lca(int u, int v) {
    // Let u be the deeper (or equally deep) endpoint.
    if (depth[u] < depth[v]) swap(u, v);

    // Lift u until it sits at v's depth, taking the largest jumps first.
    for (int k = LOG - 1; k >= 0; --k) {
        const int up = par[k][u];
        if (depth[up] >= depth[v]) u = up;
    }
    if (u == v) return u;

    // Lift both while their ancestors still differ; they end just below the LCA.
    for (int k = LOG - 1; k >= 0; --k) {
        const int pu = par[k][u];
        const int pv = par[k][v];
        if (pu == pv) continue;
        u = pu;
        v = pv;
    }
    return par[0][u];
}
// Weighted tree distance between u and v: the climb from each endpoint
// up to their lowest common ancestor.
long long dist(int u, int v) {
    const int meet = __lca(u, v);
    const long long from_u = d[u] - d[meet];
    const long long from_v = d[v] - d[meet];
    return from_u + from_v;
}
// True when u is an ancestor of v or u == v (not only its direct parent):
// v's preorder entry time falls inside u's subtree interval
// [time_in[u], time_out[u]].
bool is_par(int u, int v) {
    const int entry = time_in[v];
    const bool before_subtree = entry < time_in[u];
    const bool after_subtree = entry > time_out[u];
    return !(before_subtree || after_subtree);
}
// Builds the tree and all per-vertex tables used by __lca / is_par / Query.
//   N        number of vertices (0-indexed in the input, stored 1-indexed).
//   A, B, D  the N-1 edges: A[i] -- B[i] with length D[i].
// Fills adj, depth, d (root distance), par (binary lifting), time_in/time_out
// (preorder interval), rooting the tree at vertex 1 with sentinel parent 0.
//
// The DFS is iterative: a recursive walk reaches depth N on a path-shaped
// tree (N up to 500000 in the grader), which can overflow the call stack.
void Init(int N, int A[], int B[], int D[]) {
    n = N;
    LOG = 32 - __builtin_clz(n);
    adj = vector <vector <pair <int, int>>> (n + 1);
    for (int i = 0; i < n - 1; ++i) {
        adj[A[i] + 1].emplace_back(B[i] + 1, D[i]);
        adj[B[i] + 1].emplace_back(A[i] + 1, D[i]);
    }
    par = vector <vector <int>> (LOG + 1, vector <int> (n + 1));
    depth = time_in = time_out = vector <int> (n + 1);
    d = vector <long long> (n + 1, 0);

    // Local so a repeated Init starts the preorder numbering afresh.
    int timer = 0;

    // Visits u (child of p) for the first time: preorder index, depth, and
    // every 2^i ancestor that lies at or above the sentinel.
    auto enter = [&] (int u, int p) {
        time_in[u] = ++timer;
        depth[u] = depth[p] + 1;
        par[0][u] = p;
        for (int i = 1; (1 << i) <= depth[u]; ++i) {
            par[i][u] = par[i - 1][par[i - 1][u]];
        }
    };

    // next_edge[u] is the position in adj[u] of the next neighbour to try,
    // which reproduces the recursive traversal order exactly.
    vector <size_t> next_edge(n + 1, 0);
    vector <int> stk;
    stk.reserve(n);
    enter(1, 0);
    stk.push_back(1);
    while (!stk.empty()) {
        const int u = stk.back();
        if (next_edge[u] == adj[u].size()) {
            // All children finished: the subtree's last preorder index is final.
            time_out[u] = timer;
            stk.pop_back();
            continue;
        }
        const auto [v, w] = adj[u][next_edge[u]++];
        if (v == par[0][u]) continue;  // do not walk back to the parent
        d[v] = d[u] + w;
        enter(v, u);
        stk.push_back(v);
    }
}
long long Query(int S, int X[], int T, int Y[]) {
static vector <int> dd(n + 1, -1);
vector <int> p;
for (int i = 0; i < S; ++i) {
p.emplace_back(X[i] + 1);
dd[X[i] + 1] = 0;
}
for (int i = 0; i < T; ++i) {
p.emplace_back(Y[i] + 1);
dd[Y[i] + 1] = 1;
}
sort(p.begin(), p.end(), [&] (const int &x, const int &y) {
return time_in[x] < time_in[y];
});
int len = p.size();
for (int i = 0; i < len - 1; ++i) {
p.emplace_back(__lca(p[i], p[i + 1]));
}
sort(p.begin(), p.end(), [&] (const int &x, const int &y) {
return time_in[x] < time_in[y];
});
p.erase(unique(p.begin(), p.end()), p.end());
static vector <vector <pair <int, long long>>> adj(n + 1);
static vector <vector <long long>> dp(n + 1, vector <long long> (2));
stack <int> st;
if(p[0] != 1) st.push(1);
for (auto u : p) {
while(not st.empty() and !is_par(st.top(), u)) st.pop();
if(not st.empty()) adj[st.top()].emplace_back(u, dist(u, st.top()));
// cout << u << " " << st.top() << endl;
st.push(u);
}
// return 1;
long long ans = 1e18;
function <void(int)> dfs = [&] (int u) {
dp[u][0] = dp[u][1] = 1e18;
if(dd[u] != -1) {
dp[u][dd[u]] = 0;
}
for (auto [v, w] : adj[u]) {
dfs(v);
// cout << u - 1 << " " << v - 1 << " " << w << endl;
ans = min(ans, dp[u][0] + w + dp[v][1]);
ans = min(ans, dp[u][1] + w + dp[v][0]);
dp[u][0] = min(dp[u][0], dp[v][0] + w);
dp[u][1] = min(dp[u][1], dp[v][1] + w);
}
};
// for (int i = 1; i <= n; ++i) cout << dp[i][0] << " " << dp[i][1] << endl;
dfs(1);
for (int u : p) {
dd[u] = -1;
adj[u].clear();
}
return ans;
}
#ifdef LOCAL
#include <stdio.h>
#include <stdlib.h>
// Upper bound on N enforced by main.
#define MAX_N 500000
// Upper bound on Q enforced by main.
#define MAX_Q 100000
// Upper bound on the running sums S[0]+...+S[j] and T[0]+...+T[j].
#define MAX_SUM_ST 1000000
// Upper bound on each edge length D[i].
#define MAX_VALUE 1000000000
// Vertex count and number of queries.
static int N, Q;
// Edge list passed to Init: A[i] -- B[i] with length D[i].
static int A[MAX_N], B[MAX_N], D[MAX_N];
// Per-query sizes of the two vertex sets, indexed by query.
static int S[MAX_N];
static int T[MAX_N];
// All queries' X / Y vertices, concatenated in query order.
static int X[MAX_SUM_ST];
static int Y[MAX_SUM_ST];
// One query's X / Y vertices, copied here before calling Query (MAX_N slots each).
static int Qx[MAX_N];
static int Qy[MAX_N];
// Local grader: reads N, Q, the N-1 weighted edges and Q queries (S, T, then
// the S vertices of X and the T vertices of Y), validates them, calls Init
// once and prints Query's answer for each query on its own line.
int main() {
  // freopen("TASK.inp","r",stdin);
  int i, j, k;
  int STop, TTop;
  if (2 != scanf("%d%d", &N, &Q)) {
    fprintf(stderr, "error: cannot read N and Q.\n");
    exit(1);
  }
  if (!(2 <= N && N <= MAX_N)) {
    fprintf(stderr, "error: N is out of bounds.\n");
    exit(1);
  }
  if (!(1 <= Q && Q <= MAX_Q)) {
    fprintf(stderr, "error: Q is out of bounds.\n");
    exit(1);
  }
  for (i = 0; i < N - 1; ++i) {
    if (1 != scanf("%d", &A[i])) {
      fprintf(stderr, "error: cannot read A[%d].\n", i);
      exit(1);
    }
    if (!(0 <= A[i] && A[i] <= N - 1)) {
      fprintf(stderr, "error: A[%d] is out of bounds.\n", i);
      exit(1);
    }
    if (1 != scanf("%d", &B[i])) {
      fprintf(stderr, "error: cannot read B[%d].\n", i);
      exit(1);
    }
    if (!(0 <= B[i] && B[i] <= N - 1)) {
      fprintf(stderr, "error: B[%d] is out of bounds.\n", i);
      exit(1);
    }
    if (A[i] == B[i]) {
      fprintf(stderr, "error: B[%d] is equal to A[%d].\n", i, i);
      exit(1);
    }
    if (1 != scanf("%d", &D[i])) {
      fprintf(stderr, "error: cannot read D[%d].\n", i);
      exit(1);
    }
    if (!(1 <= D[i] && D[i] <= MAX_VALUE)) {
      fprintf(stderr, "error: D[%d] is out of bounds.\n", i);
      exit(1);
    }
  }
  STop = 0;
  TTop = 0;
  for (j = 0; j < Q; ++j) {
    if (2 != scanf("%d%d", &S[j], &T[j])) {
      fprintf(stderr, "error: cannot read S[%d] and T[%d].\n", j, j);
      exit(1);
    }
    // Each query is later copied into Qx / Qy, which have only MAX_N slots;
    // the running-sum checks below do not bound a single query's size.
    if (!(0 <= S[j] && S[j] <= MAX_N)) {
      fprintf(stderr, "error: S[%d] is out of bounds.\n", j);
      exit(1);
    }
    if (!(0 <= T[j] && T[j] <= MAX_N)) {
      fprintf(stderr, "error: T[%d] is out of bounds.\n", j);
      exit(1);
    }
    if (STop + S[j] > MAX_SUM_ST) {
      fprintf(stderr, "error: S[0] + S[1] + ... + S[%d] is out of bounds.\n", j);
      exit(1);
    }
    if (TTop + T[j] > MAX_SUM_ST) {
      fprintf(stderr, "error: T[0] + T[1] + ... + T[%d] is out of bounds.\n", j);
      exit(1);
    }
    for (k = 0; k < S[j]; ++k, ++STop) {
      if (1 != scanf("%d", &X[STop])) {
        fprintf(stderr, "error: cannot read X[%d][%d].\n", j, k);
        exit(1);
      }
      if (!(0 <= X[STop] && X[STop] <= N - 1)) {
        fprintf(stderr, "error: X[%d][%d] is out of bounds.\n", j, k);
        exit(1);
      }
    }
    for (k = 0; k < T[j]; ++k, ++TTop) {
      if (1 != scanf("%d", &Y[TTop])) {
        fprintf(stderr, "error: cannot read Y[%d][%d].\n", j, k);
        exit(1);
      }
      if (!(0 <= Y[TTop] && Y[TTop] <= N - 1)) {
        fprintf(stderr, "error: Y[%d][%d] is out of bounds.\n", j, k);
        exit(1);
      }
    }
  }
  STop = 0;
  TTop = 0;
  Init(N, A, B, D);
  for (j = 0; j < Q; ++j) {
    for (k = 0; k < S[j]; k++) {
      Qx[k] = X[STop++];
    }
    for (k = 0; k < T[j]; k++) {
      Qy[k] = Y[TTop++];
    }
    printf("%lld\n", Query(S[j], Qx, T[j], Qy));
  }
  return 0;
}
#endif
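/*
 * Judge results table captured alongside this source (the results had not
 * loaded yet). Kept as a comment so the file still compiles:
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 */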