Submission #1317217

# | Time | Username | Problem | Language | Result | Execution time | Memory
1317217 | (not shown) | starplatinum | 휴가 (IOI14_holiday) | C++20 | 100 / 100 | 298 ms | 42972 KiB
#include <algorithm>
#include <climits>
#include <cstdint>
#include <utility>
#include <vector>
#ifdef LOCAL
#include <functional>
#include <iostream>
#include <random>
#endif

namespace {

// Persistent segment tree over coordinate-compressed values [0, m).
// Every update() produces a new version (root) that shares all untouched
// nodes with the previous one. Each node stores how many inserted elements
// fall in its value range (cnt) and their total (sum), so the difference of
// two versions describes the multiset of elements inserted between them.
struct PersistentSegTree {
    int m = 0;         // number of distinct values (leaves)
    int node_cnt = 0;  // nodes allocated so far; node 0 is the shared null node
    std::vector<int> lch, rch;
    std::vector<int> cnt;
    std::vector<long long> sum;
    std::vector<long long> values;  // original value for each compressed index

    // Resets the tree to the empty version over `sorted_unique_vals` and
    // reserves room for about `n_elements` insertions (one node per level each).
    void init(std::vector<long long> sorted_unique_vals, int n_elements) {
        values = std::move(sorted_unique_vals);
        m = static_cast<int>(values.size());
        int levels = 0;
        while ((1 << levels) < std::max(1, m)) ++levels;
        const int cap = std::max((n_elements + 5) * (levels + 3), 32);
        lch.assign(cap, 0);
        rch.assign(cap, 0);
        cnt.assign(cap, 0);
        sum.assign(cap, 0);
        node_cnt = 0;
    }

    // Allocates a new node initialised as a copy of node `from`, doubling the
    // node pool (to at least 32 nodes) when the reserved capacity runs out.
    int new_node(int from) {
        const int idx = ++node_cnt;
        if (idx >= static_cast<int>(cnt.size())) {
            const int new_cap = std::max(32, static_cast<int>(cnt.size()) * 2);
            lch.resize(new_cap);
            rch.resize(new_cap);
            cnt.resize(new_cap);
            sum.resize(new_cap);
        }
        lch[idx] = lch[from];
        rch[idx] = rch[from];
        cnt[idx] = cnt[from];
        sum[idx] = sum[from];
        return idx;
    }

    // Returns the root of a new version equal to `prev` plus one element with
    // compressed position `pos` and value `val`; `prev` stays valid.
    int update(int prev, int seg_l, int seg_r, int pos, long long val) {
        const int cur = new_node(prev);
        cnt[cur] += 1;
        sum[cur] += val;
        if (seg_l == seg_r) return cur;
        const int mid = (seg_l + seg_r) / 2;
        // Build the child first and store it afterwards: the recursive call may
        // reallocate lch/rch, so no element reference is held across it.
        if (pos <= mid) {
            const int child = update(lch[prev], seg_l, mid, pos, val);
            lch[cur] = child;
        } else {
            const int child = update(rch[prev], mid + 1, seg_r, pos, val);
            rch[cur] = child;
        }
        return cur;
    }

    // Sum of the k largest elements of the multiset
    // version(right_root) minus version(left_root), restricted to the
    // compressed range [seg_l, seg_r]. Returns 0 for k <= 0 or an empty
    // range; if fewer than k elements exist, all of them are summed.
    long long query_top_k(int left_root, int right_root, int seg_l, int seg_r, int k) const {
        if (k <= 0) return 0;
        const int available = cnt[right_root] - cnt[left_root];
        if (available <= 0) return 0;
        if (seg_l == seg_r) {
            return static_cast<long long>(std::min(k, available)) * values[seg_l];
        }
        const int mid = (seg_l + seg_r) / 2;
        const int right_l = rch[left_root];
        const int right_r = rch[right_root];
        const int right_cnt = cnt[right_r] - cnt[right_l];
        // Larger values live in the right child, so take from it first.
        if (k <= right_cnt) {
            return query_top_k(right_l, right_r, mid + 1, seg_r, k);
        }
        const long long right_sum = sum[right_r] - sum[right_l];
        return right_sum +
               query_top_k(lch[left_root], lch[right_root], seg_l, mid, k - right_cnt);
    }
};

// Handles trips that first walk left from s to some i <= s, then right to
// some j >= s. The walk costs s + j - 2*i days; every remaining day collects
// one of the largest attractions in [i, j]. The opposite turning order is
// obtained by running this solver on the reversed array.
struct OneDirectionSolver {
    int n = 0;
    int s = 0;
    int d = 0;
    std::vector<int> roots;  // roots[t]: tree version containing a[0..t-1]
    PersistentSegTree pst;
    long long best = 0;

    // Largest right end j for which walking s -> i -> j takes at most d days,
    // clamped to [s, n - 1].
    int jMax(int i) const {
        long long jm = static_cast<long long>(d) - s + 2LL * i;
        if (jm < s) jm = s;
        if (jm > n - 1) jm = n - 1;
        return static_cast<int>(jm);
    }

    // Best total for the trip s -> i -> j, or a large negative sentinel if the
    // walk alone exceeds d days.
    long long eval(int i, int j) const {
        const long long moves = static_cast<long long>(s) + j - 2LL * i;
        if (moves > d) return LLONG_MIN / 4;
        const int len = j - i + 1;
        const int k = static_cast<int>(std::min<long long>(d - moves, len));
        if (k <= 0) return 0;
        return pst.query_top_k(roots[i], roots[j + 1], 0, pst.m - 1, k);
    }

    // Divide and conquer over left ends i in [iL, iR]. It relies on the best
    // right end being non-decreasing in i: after finding bestJ for the middle
    // i, the left half searches [jL, bestJ] and the right half [bestJ, jR].
    void solve(int iL, int iR, int jL, int jR) {
        if (iL > iR) return;
        const int mid = (iL + iR) / 2;
        const int upper = std::min(jR, jMax(mid));
        int bestJ = jL;
        long long bestVal = LLONG_MIN / 4;
        for (int j = jL; j <= upper; ++j) {
            const long long val = eval(mid, j);
            if (val > bestVal) {
                bestVal = val;
                bestJ = j;
            }
        }
        best = std::max(best, bestVal);
        solve(iL, mid - 1, jL, bestJ);
        solve(mid + 1, iR, bestJ, jR);
    }

    // Best total over all left-then-right trips on `arr` from `start` with
    // `days` days (0 if no city can be visited).
    long long run(const std::vector<int>& arr, int start, int days) {
        n = static_cast<int>(arr.size());
        s = start;
        d = days;
        best = 0;

        // Coordinate compression; the tree keeps the sorted distinct values.
        std::vector<long long> vals(arr.begin(), arr.end());
        std::sort(vals.begin(), vals.end());
        vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
        pst.init(std::move(vals), n);
        const std::vector<long long>& sorted_vals = pst.values;

        // roots[t] is the version holding the multiset of arr[0..t-1].
        roots.assign(n + 1, 0);
        for (int t = 0; t < n; ++t) {
            const long long value = arr[t];
            const int pos = static_cast<int>(
                std::lower_bound(sorted_vals.begin(), sorted_vals.end(), value) -
                sorted_vals.begin());
            roots[t + 1] = pst.update(roots[t], 0, pst.m - 1, pos, value);
        }

        // Going left to i and back costs at least 2 * (s - i) days.
        const int iMin = std::max(0, s - d / 2);
        const int iMax = s;
        if (iMin > iMax) return 0;
        const int jL = s;
        const int jR = jMax(s);
        if (jL > jR) return 0;
        solve(iMin, iMax, jL, jR);
        return best;
    }
};

}  // namespace

// Maximum total attraction over all trips of at most d days starting at city
// `start` of the n cities in `attraction`. A trip covers an interval [L, R]
// around `start`: reaching both ends costs (R - L) + min(start - L, R - start)
// days, and each remaining day collects one attraction inside the interval.
// Returns 0 when n <= 0.
long long int findMaxAttraction(int n, int start, int d, int attraction[]) {
    if (n <= 0) return 0;
    std::vector<int> a(attraction, attraction + n);
    OneDirectionSolver solver;
    const long long left_first = solver.run(a, start, d);
    // Mirroring the array turns "right first, then left" into "left first".
    std::reverse(a.begin(), a.end());
    const long long right_first = solver.run(a, n - 1 - start, d);
    return std::max(left_first, right_first);
}

#ifdef LOCAL
// Reference solution: tries every interval [L, R] containing s and sums its
// k largest attractions, where k is the number of days left after walking.
static long long brute_solve(const std::vector<int>& a, int s, int d) {
    const int n = static_cast<int>(a.size());
    long long best = 0;
    for (int L = 0; L <= s; ++L) {
        for (int R = s; R < n; ++R) {
            const int move = (R - L) + std::min(s - L, R - s);
            if (move > d) continue;
            const int k = std::min(d - move, R - L + 1);
            std::vector<int> v(a.begin() + L, a.begin() + R + 1);
            std::sort(v.begin(), v.end(), std::greater<int>());
            long long total = 0;
            for (int i = 0; i < k; ++i) total += v[i];
            best = std::max(best, total);
        }
    }
    return best;
}

// Randomised self-check of findMaxAttraction against brute_solve.
// Exits with status 1 on the first mismatch.
int main() {
    std::mt19937 rng(1);
    for (int t = 0; t < 2000; ++t) {
        const int n = std::uniform_int_distribution<int>(2, 10)(rng);
        std::vector<int> a(n);
        for (int i = 0; i < n; ++i) a[i] = std::uniform_int_distribution<int>(0, 20)(rng);
        const int s = std::uniform_int_distribution<int>(0, n - 1)(rng);
        const int d = std::uniform_int_distribution<int>(0, 2 * n + n / 2)(rng);
        const long long fast = findMaxAttraction(n, s, d, a.data());
        const long long brute = brute_solve(a, s, d);
        if (fast != brute) {
            std::cerr << "Mismatch!\n";
            std::cerr << "n=" << n << " s=" << s << " d=" << d << "\n";
            std::cerr << "a: ";
            for (int x : a) std::cerr << x << ' ';
            std::cerr << "\nfast=" << fast << " brute=" << brute << "\n";
            return 1;
        }
    }
    std::cerr << "All tests passed.\n";
}
#endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...