#include <algorithm>
#include <climits>
#include <cstdint>
#include <vector>
#ifdef LOCAL
#include <iostream>
#include <random>
#endif
using namespace std;
namespace {
struct PersistentSegTree {
    // Persistent (path-copying) segment tree over compressed value indices
    // [0..m-1]. Each node stores how many inserted elements fall in its value
    // range (cnt) and their total (sum). Node 0 is a shared all-zero "null"
    // node, so an empty version is represented by root 0.
    int m = 0;                  // number of distinct values (leaves)
    int node_cnt = 0;           // index of the last allocated node; 0 is null
    vector<int> lch, rch;       // child indices; 0 means "null child"
    vector<int> cnt;            // number of elements in the node's value range
    vector<long long> sum;      // sum of those elements
    vector<long long> values;   // original value for each compressed index

    PersistentSegTree() = default;

    // Resets the tree for a new value domain.
    //   sorted_unique_vals: the distinct values, ascending; position = compressed index.
    //   n_elements: expected number of update() calls. Used only to pre-size
    //               the node pool; the pool still grows on demand.
    void init(const vector<long long>& sorted_unique_vals, int n_elements) {
        values = sorted_unique_vals;
        m = (int)values.size();
        // lg = ceil(log2(max(1, m))): bounds the depth of a root-to-leaf path.
        int lg = 0;
        while ((1 << lg) < max(1, m)) ++lg;
        int cap = (n_elements + 5) * (lg + 3);
        cap = max(cap, 32);
        lch.assign(cap, 0);
        rch.assign(cap, 0);
        cnt.assign(cap, 0);
        sum.assign(cap, 0);
        node_cnt = 0;  // node 0 is the null node
    }

    // Allocates a new node as a copy of node `from` and returns its index.
    // May reallocate the pool, which invalidates any reference into
    // lch/rch/cnt/sum held across this call.
    int new_node(int from) {
        int idx = ++node_cnt;
        if (idx >= (int)cnt.size()) {
            // Geometric growth. The extra lower bounds keep this valid even for
            // a pool that was never init()-ed (size 0, where doubling gives 0).
            int new_cap = max(2 * (int)cnt.size(), idx + 1);
            new_cap = max(new_cap, 32);
            lch.resize(new_cap);
            rch.resize(new_cap);
            cnt.resize(new_cap);
            sum.resize(new_cap);
        }
        lch[idx] = lch[from];
        rch[idx] = rch[from];
        cnt[idx] = cnt[from];
        sum[idx] = sum[from];
        return idx;
    }

    // Returns the root of a new version equal to version `prev` plus one
    // element with compressed index `pos` and value `val`. Only the nodes on
    // the path to leaf `pos` are copied; all other subtrees are shared.
    int update(int prev, int segL, int segR, int pos, long long val) {
        int cur = new_node(prev);
        cnt[cur] += 1;
        sum[cur] += val;
        if (segL == segR) return cur;
        int mid = (segL + segR) >> 1;
        // Compute the child before storing it. The recursive call can resize
        // lch/rch, and before C++17 `lch[cur] = update(...)` was allowed to
        // evaluate `lch[cur]` (a reference into the old buffer) first.
        if (pos <= mid) {
            int child = update(lch[prev], segL, mid, pos, val);
            lch[cur] = child;
        } else {
            int child = update(rch[prev], mid + 1, segR, pos, val);
            rch[cur] = child;
        }
        return cur;
    }

    // Sum of the k largest elements in (version rightRoot) minus (version
    // leftRoot). With prefix versions this is the multiset of a subarray.
    // If fewer than k elements are present, all of them are summed.
    // Compressed indices are in ascending value order, so the right subtree
    // holds the larger values and is consumed first.
    long long query_top_k(int leftRoot, int rightRoot, int segL, int segR, int k) const {
        if (k <= 0) return 0LL;
        int cntDiff = cnt[rightRoot] - cnt[leftRoot];
        if (cntDiff <= 0) return 0LL;
        if (segL == segR) {
            // Every element in a leaf has the same value.
            int take = min(k, cntDiff);
            return (long long)take * values[segL];
        }
        int mid = (segL + segR) >> 1;
        int lL = lch[leftRoot], lR = lch[rightRoot];
        int rL = rch[leftRoot], rR = rch[rightRoot];
        int cntRight = cnt[rR] - cnt[rL];
        long long sumRight = sum[rR] - sum[rL];
        if (k <= cntRight) {
            return query_top_k(rL, rR, mid + 1, segR, k);
        }
        // Take the whole right half, then the remainder from the left half.
        return sumRight + query_top_k(lL, lR, segL, mid, k - cntRight);
    }
};
struct OneDirectionSolver {
    // Best total for tours that first walk left from s to a city i, then walk
    // right to a city j (i <= s <= j), and visit the best cities in [i, j].
    // The mirrored (right-first) tours are handled by the caller running this
    // on the reversed array.
    int n = 0;                     // number of cities
    int s = 0;                     // start city
    int d = 0;                     // days available
    vector<int> a;                 // attraction per city
    vector<long long> comp_vals;   // sorted distinct attraction values
    vector<int> roots;             // roots[t]: tree version holding a[0..t-1]
    PersistentSegTree pst;
    long long best = 0;            // best value found by the current run()

    // Largest right end worth trying for left end i. Walking costs
    // s + j - 2i days, so j <= d - s + 2i; the result is clamped to [s, n-1].
    int jMax(int i) const {
        long long limit = 2LL * i + (long long)d - (long long)s;
        limit = max<long long>(limit, s);
        limit = min<long long>(limit, n - 1);
        return static_cast<int>(limit);
    }

    // Value of the tour with left end i and right end j. Returns a large
    // negative sentinel when the walk alone exceeds d days.
    long long eval(int i, int j) const {
        const long long walk = (long long)s + j - 2LL * i;
        if (d < walk) return LLONG_MIN / 4;
        // Remaining days become visits, but no more than the cities in [i, j].
        const int visits = min(static_cast<int>(d - walk), j - i + 1);
        return visits > 0
                   ? pst.query_top_k(roots[i], roots[j + 1], 0, pst.m - 1, visits)
                   : 0LL;
    }

    // Divide and conquer over left ends [iL, iR] whose best right end lies in
    // [jL, jR]. Ties keep the smallest j; that j splits the j-range of the
    // two halves.
    void solve(int iL, int iR, int jL, int jR) {
        if (iL > iR) return;
        const int i = iL + (iR - iL) / 2;
        const int hi = min(jR, jMax(i));
        int argBest = jL;
        long long valBest = LLONG_MIN / 4;
        for (int j = jL; j <= hi; ++j) {
            const long long v = eval(i, j);
            if (v > valBest) {
                valBest = v;
                argBest = j;
            }
        }
        best = max(best, valBest);
        solve(iL, i - 1, jL, argBest);
        solve(i + 1, iR, argBest, jR);
    }

    // Solves the left-first problem for `arr` starting at `start` with `days`
    // days and returns the best total (never below 0).
    long long run(const vector<int>& arr, int start, int days) {
        a = arr;
        n = static_cast<int>(a.size());
        s = start;
        d = days;
        best = 0;

        // Coordinate-compress the attraction values.
        comp_vals.assign(a.begin(), a.end());
        sort(comp_vals.begin(), comp_vals.end());
        comp_vals.erase(unique(comp_vals.begin(), comp_vals.end()), comp_vals.end());
        pst.init(comp_vals, n);

        // One persistent version per prefix of the array.
        roots.assign(n + 1, 0);
        for (int t = 0; t < n; ++t) {
            const long long x = a[t];
            const int rank = static_cast<int>(
                lower_bound(comp_vals.begin(), comp_vals.end(), x) - comp_vals.begin());
            roots[t + 1] = pst.update(roots[t], 0, pst.m - 1, rank, x);
        }

        // A left end below s - d/2 is infeasible even with j = s, since that
        // walk alone costs 2(s - i) days.
        const int iLo = max(0, s - d / 2);
        const int iHi = s;
        const int jLo = s;
        const int jHi = jMax(s);
        if (iLo <= iHi && jLo <= jHi) solve(iLo, iHi, jLo, jHi);
        return best;
    }
};
} // namespace
// Maximum total attraction collectable in `d` days starting at city `start`
// among cities 0..n-1.
// Cost model (matching brute_solve): each step to an adjacent city costs one
// day, each visit costs one day, and a city's attraction is counted at most
// once.
// Invalid input (n <= 0, null array, start outside [0, n), d <= 0) yields 0.
// Previously, negative n threw from the vector constructor and an
// out-of-range start indexed out of bounds.
long long int findMaxAttraction(int n, int start, int d, int attraction[]) {
    if (n <= 0 || attraction == nullptr || start < 0 || start >= n || d <= 0) return 0;
    vector<int> a(attraction, attraction + n);
    OneDirectionSolver solver;
    // Tours that walk left first.
    const long long left_first = solver.run(a, start, d);
    // Tours that walk right first are left-first tours on the mirrored array.
    vector<int> rev(a.rbegin(), a.rend());
    const long long right_first = solver.run(rev, n - 1 - start, d);
    return max(left_first, right_first);
}
#ifdef LOCAL
// Brute-force reference: tries every window [L, R] containing s. The walk
// cost is the window span plus one extra pass over its shorter side. The
// remaining days are spent visiting the largest attractions in the window.
static long long brute_solve(const vector<int>& a, int s, int d) {
    const int n = (int)a.size();
    long long best = 0;
    for (int L = 0; L <= s; ++L) {
        for (int R = s; R < n; ++R) {
            const int move = (R - L) + min(s - L, R - s);
            if (move > d) continue;
            // 0 <= k <= window length because move <= d here.
            const int k = min(d - move, R - L + 1);
            vector<int> v(a.begin() + L, a.begin() + R + 1);
            // Only the k largest matter. A lambda avoids depending on
            // std::greater, which lives in <functional> (not included).
            partial_sort(v.begin(), v.begin() + k, v.end(),
                         [](int x, int y) { return x > y; });
            long long sum = 0;
            for (int t = 0; t < k; ++t) sum += v[t];
            best = max(best, sum);
        }
    }
    return best;
}
// Randomized self-check (LOCAL builds only). Compares findMaxAttraction
// against brute_solve on small random inputs and exits with status 1 on the
// first mismatch, so scripts and CI can detect failures.
int main() {
    std::mt19937 rng(1);
    for (int t = 0; t < 2000; ++t) {
        // n starts at 1 so the single-city case is exercised as well.
        int n = std::uniform_int_distribution<int>(1, 10)(rng);
        vector<int> a(n);
        for (int i = 0; i < n; ++i) a[i] = std::uniform_int_distribution<int>(0, 20)(rng);
        int s = std::uniform_int_distribution<int>(0, n - 1)(rng);
        int d = std::uniform_int_distribution<int>(0, 2 * n + n / 2)(rng);
        long long fast = findMaxAttraction(n, s, d, a.data());
        long long brute = brute_solve(a, s, d);
        if (fast != brute) {
            cerr << "Mismatch!\n";
            cerr << "n=" << n << " s=" << s << " d=" << d << "\n";
            cerr << "a: ";
            for (int x : a) cerr << x << ' ';
            cerr << "\nfast=" << fast << " brute=" << brute << "\n";
            return 1;  // non-zero: signal the failure to the caller
        }
    }
    cerr << "All tests passed.\n";
    return 0;
}
#endif
/*
Pasted grader-results table (not C++; kept as a comment so the file compiles).
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
*/