#include <bits/stdc++.h>
using namespace std;
// child[v]: direct subordinates of employee v (0-based indices).
std::vector<std::vector<int>> child;
// joke[v]: joke type of employee v, 0-based; expected to lie in 0..99.
std::vector<int> joke;
// Per-node DP results for the subtree rooted at v (with v itself included):
//   lans[v] - set of reachable inclusive LEFT endpoints l  (bits 0..99)
//   rans[v] - set of reachable exclusive RIGHT endpoints r+1 (bits 1..100)
// Children with a smaller joke only extend the left side and children with a
// larger joke only extend the right side. So the reachable intervals are
// every pairing of a left endpoint with a right endpoint, and main() counts
// them as lans[root].count() * rans[root].count().
std::vector<std::bitset<101>> lans, rans;

// Fills lans[x] / rans[x] from the already-final sets of x's children.
static void merge_children(int x) {
    const int v = joke[x];
    std::vector<int> lo;  // children whose joke is below v
    std::vector<int> hi;  // children whose joke is above v
    for (int y : child[x]) {
        if (joke[y] < v) {
            lo.push_back(y);
        } else if (joke[y] > v) {
            hi.push_back(y);
        }
        // A child with joke == v is never used.
    }

    // Base case: the interval consisting of v alone. bitset::set() is
    // bounds-checked, so an invalid joke throws std::out_of_range instead of
    // writing past the end of the bitset.
    lans[x].set(v);
    rans[x].set(v + 1);

    // Grow leftwards. Children are visited nearest-to-v first (descending
    // joke). Any chain of child intervals covering [l, v-1] has jokes that
    // increase toward v, so a single pass finds every reachable l.
    std::sort(lo.begin(), lo.end(),
              [](int a, int b) { return joke[a] > joke[b]; });
    for (int y : lo) {
        // y fits if one of its intervals ends right before a reachable left endpoint.
        if ((lans[x] & rans[y]).any()) {
            lans[x] |= lans[y];
        }
    }

    // Grow rightwards, symmetrically: ascending joke, nearest-to-v first.
    std::sort(hi.begin(), hi.end(),
              [](int a, int b) { return joke[a] < joke[b]; });
    for (int y : hi) {
        // y fits if one of its intervals starts right after a reachable right endpoint.
        if ((rans[x] & lans[y]).any()) {
            rans[x] |= rans[y];
        }
    }
}

// Computes lans/rans for every node in the subtree rooted at x.
// Uses an explicit stack instead of recursion, so a deep (path-shaped)
// hierarchy cannot overflow the call stack. Each node is still merged only
// after all of its children have been finalized.
void dfs(int x) {
    std::vector<int> order;        // pre-order: every parent precedes its children
    std::vector<int> pending{x};
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        for (int y : child[u]) {
            pending.push_back(y);
        }
    }
    // Reverse pre-order = children before parents.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        merge_children(*it);
    }
}
// Reads the input and prints the number of reachable joke intervals for the
// root. Input layout, as consumed below: N, then N joke types (1-based),
// then N-1 pairs "a b", where b is added as a subordinate of a (1-based).
// Employee 1 (index 0) is the root of the DP.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N = 0;
    if (!(cin >> N) || N <= 0) {
        // No employees (or unreadable header): nothing to evaluate, and
        // calling dfs(0) on empty tables would be out of bounds.
        cout << 0 << "\n";
        return 0;
    }
    joke.resize(N);
    for (int i = 0; i < N; i++) {
        cin >> joke[i];
        --joke[i];  // make joke types 0-based
    }
    child.assign(N, {});
    for (int i = 0; i < N - 1; i++) {
        int a = 0;
        int b = 0;
        if (!(cin >> a >> b)) {
            break;  // truncated input: keep the edges read so far
        }
        if (a < 1 || a > N || b < 1 || b > N) {
            continue;  // endpoint names no employee; indexing it would be UB
        }
        child[a - 1].push_back(b - 1);
    }
    lans.assign(N, bitset<101>());
    rans.assign(N, bitset<101>());
    dfs(0);
    // Left and right endpoints combine freely, so the count is their product.
    long long ans = 1LL * lans[0].count() * rans[0].count();
    cout << ans << "\n";
    return 0;
}
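// Judge results table (pasted from the submission page; results were still loading):
// | # | Verdict | Execution time | Memory | Grader output |
// |---|---------|----------------|--------|---------------|
// | Fetching results... |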