#include<bits/allocator.h>
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx2,fma,bmi,bmi2,popcnt,lzcnt,tune=native")
#include<bits/stdc++.h>
using namespace std;
// NOTE: despite the name, `ll` is an alias for plain `int`, not `long long`.
typedef int ll;
// Capacity of every per-vertex array and also the number of rows in dp
// (dp rows are indexed by bitmasks below 1 << t).
const int mn = (1 << 20) + 5;
// n          - vertex count read by solve(); n - 1 edges follow in the input.
// a          - not used anywhere in the visible code.
// st, en     - the two vertices read right after n; stored as the first and
//              last entries of sk.
// t          - number of edges that were read with w > 0 (== g.size()).
// p[i][j]    - vertex reached from i after 2^j parent steps; dfs is started
//              with the root as its own parent, so the root maps to itself.
// d[i][j]    - bitwise OR of the edge masks on the 2^j edges above i, where
//              an edge with index k contributes (1 << k) and an edge with
//              index -1 contributes 0.
// dp[m][s]   - -1 while unreached, otherwise the smallest cost found so far
//              for bitmask m and position slot s (0, 2k+1 or 2k+2).
// in[], out[]- DFS entry time of a vertex and the largest entry time inside
//              its subtree; used by anc().
// ti         - DFS clock, reset to 0 by solve() before dfs().
// f, o, dst  - tables indexed by positions in sk: LCA of the two vertices,
//              OR of the edge masks on the path between them, and the number
//              of edges on that path.
// dep[]      - depth assigned by dfs (dep[child] = dep[parent] + 1).
ll n, a[mn], st, en, t, p[mn][21], d[mn][21], dp[mn][42], in[mn], out[mn], ti, f[100][100], o[100][100], dst[100][100], dep[mn];
// Adjacency list: v[x] holds (neighbour, edge index), where the index is the
// position of the edge in g, or -1 for an edge read with w <= 0.
vector<pair<ll, ll> > v[mn];
// g  - the w value of every edge read with w > 0, in input order.
// sk - vertices the DP moves between: [st, x0, y0, w0, x1, y1, w1, ..., en],
//      i.e. edge k occupies positions 3k+1 (x), 3k+2 (y) and 3k+3 (w).
//      NOTE(review): w is used as a vertex id that must be visited before
//      edge k counts as usable - presumably a key location; confirm against
//      the problem statement.
vector<ll> g, sk;
// Depth-first traversal that fills the per-vertex tables for the subtree of i.
//   i  - vertex being entered
//   pa - its parent (the root is passed as its own parent)
//   s  - bitmask of the edge (pa, i): 1 << index for an indexed edge, else 0
// Writes in/out timestamps, dep, and the lifting tables p (ancestors) and
// d (OR of edge masks along the lifted segment).
void dfs(ll i, ll pa, ll s){
    in[i] = ++ti;
    dep[i] = dep[pa] + 1;
    p[i][0] = pa;
    d[i][0] = s;
    // Level lvl is built from two level-(lvl - 1) hops: i -> mid -> p[i][lvl].
    for(int lvl = 1; lvl <= 20; ++lvl){
        const ll mid = p[i][lvl - 1];
        p[i][lvl] = p[mid][lvl - 1];
        d[i][lvl] = (d[mid][lvl - 1] | d[i][lvl - 1]);
    }
    for(const auto &edge : v[i]){
        const ll child = edge.first;
        if(child == pa) continue;
        // Edges stored with index -1 contribute no bit to the path masks.
        const ll bit = (edge.second == -1) ? 0 : (1 << edge.second);
        dfs(child, i, bit);
    }
    out[i] = ti;
}
// Returns true when b lies in the DFS subtree of a (a vertex is treated as
// an ancestor of itself), using the in/out timestamps written by dfs.
bool anc(ll a, ll b){
    if(in[b] < in[a]) return false;
    return out[a] >= out[b];
}
// Lowest common ancestor of a and b, using the lifting table p and anc().
ll lca(ll a, ll b){
    if(anc(a, b)) return a;
    if(anc(b, a)) return b;
    // Lift a as high as possible while it is still not an ancestor of b;
    // the parent of the vertex reached is then the answer.
    int j = 20;
    while(j >= 0){
        const ll up = p[a][j];
        if(!anc(up, b)) a = up;
        --j;
    }
    return p[a][0];
}
// Bitwise OR of the edge masks on the two upward paths a -> u and b -> u.
// solve() calls this with u = lca(a, b), so the result covers the whole
// path between a and b.
ll cal(ll a, ll b, ll u){
    ll ans = 0;
    // Lifts x until it sits directly below u (or already is u), collecting
    // the mask of every jumped segment, then adds the last edge into u.
    auto climb = [&ans, u](ll x){
        for(int j = 20; j >= 0; --j){
            if(anc(p[x][j], u)) continue;
            ans |= d[x][j];
            x = p[x][j];
        }
        if(x != u) ans |= d[x][0];
    };
    climb(a);
    climb(b);
    return ans;
}
void solve(){
cin >> n >> st >> en;
sk.push_back(st);
for(int i = 1; i < n; ++i){
ll x, y, w; cin >> x >> y >> w;
if(w > 0) sk.push_back(x), sk.push_back(y), sk.push_back(w);
if(w > 0) g.push_back(w), w = (int)g.size() - 1;
else w = -1;
v[x].push_back({y, w});
v[y].push_back({x, w});
}
sk.push_back(en);
ti = 0;
t = g.size();
dfs(1, 1, 0);
memset(dp, -1, sizeof(dp));
dp[0][0] = 0;
ll ans = 1e9;
for(int i = 0; i < 3 * t + 2; ++i) for(int j = 0; j < 3 * t + 2; ++j){
f[i][j] = lca(sk[i], sk[j]);
o[i][j] = cal(sk[i], sk[j], f[i][j]);
dst[i][j] = dep[sk[i]] + dep[sk[j]] - 2 * dep[f[i][j]];
}
for(int i = 0; i < (1 << t); ++i) for(int j = 0; j <= 2 * t; ++j){
if(dp[i][j] == -1) continue;
ll s;
if(j == 0) s = 0;
else if(j % 2) s = 3 * ((j - 1) / 2) + 1;
else s = 3 * ((j - 1) / 2) + 2;
// cout << s << " " << sk[s] << "\n";
for(int k = 0; k < t; ++k){
if((i >> k) & 1) continue;
ll r = (o[s][3 * k + 3] | o[3 * k + 3][3 * k + 1]);
if((i & r) == r){
ll nw = dp[i][j] + dst[s][3 * k + 3] + dst[3 * k + 3][3 * k + 1], mask = (i ^ (1 << k));
if(dp[mask][2 * k + 1] == -1) dp[mask][2 * k + 1] = nw;
else dp[mask][2 * k + 1] = min(dp[mask][2 * k + 1], nw);
}
r = (o[s][3 * k + 3] | o[3 * k + 3][3 * k + 2]);
// cout << r << "\n";
if((i & r) == r){
ll nw = dp[i][j] + dst[s][3 * k + 3] + dst[3 * k + 3][3 * k + 2], mask = (i ^ (1 << k));
if(dp[mask][2 * k + 2] == -1) dp[mask][2 * k + 2] = nw;
else dp[mask][2 * k + 2] = min(dp[mask][2 * k + 2], nw);
}
}
ll r = o[s][3 * t + 1];
if((i & r) == r){
// cout << i << " " << j << "- \n";
ans = min(ans, dp[i][j] + dst[s][3 * t + 1]);
}
}
if(ans == 1e9) cout << -1;
else cout << ans;
return;
}
// Program entry point: sets up fast iostream I/O, optionally redirects
// stdin/stdout to the local files ".INP" / ".OUT" when ".INP" exists, and
// runs solve() once.
// Returns 0 on success, 1 if the file redirection could not be set up.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // Probe for the local input file; close the probe handle so it is not
    // leaked, then reopen the standard streams onto the files.
    if(FILE *probe = fopen(".INP", "r")){
        fclose(probe);
        if(freopen(".INP", "r", stdin) == nullptr) return 1;
        if(freopen(".OUT", "w", stdout) == nullptr) return 1;
    }
    int testCase = 1;
    //cin >> testCase;
    while(testCase--) solve();
    return 0;
}
Compilation message (stderr)
Main.cpp: In function 'int main()':
Main.cpp:116:24: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
116 | freopen(".INP", "r", stdin);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~
Main.cpp:117:24: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
117 | freopen(".OUT", "w", stdout);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |