题意是,有$n$个石头,每个石头有初始能量$E_i$,每秒能量增长$L_i$,以及能量上限$C_i$,有$m$个收能量的时间点,每次把区间$\left[S_i, T_i\right]$石头的能量都给收掉,石头的能量都置零重新开始增长。问最后收了多少能量。
看完题解觉得好有道理...我好菜...
考虑每个石头在多少个时间点收能量,然后每次收的能量就和这些时间点的时间间隔有关。
若时间间隔大于等于$\dfrac {C_i}{L_i}$,那么这一段对答案的贡献就是$C_i$了,统计有多少这样的段即可。
若时间间隔小于$\dfrac {C_i}{L_i}$那么对答案的贡献就是时间长度$t \times L_i$。
用两个权值树状数组可以维护对应时间长度的和及个数。
时间点可以用set维护。从前到后遍历,遇到一个$S_i$就把对应是时间加入,遇到一个$T_i + 1$就把时间删去,同时维护树状数组即可。感觉看代码就很好懂?

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5 + 7;
int n, m;
ll E[N], C[N], L[N];
set<int> st;
vector<int> G[N];
struct BIT {
ll tree1[N], tree2[N];
inline void clear() {
memset(tree1, 0, sizeof tree1);
memset(tree2, 0, sizeof tree2);
}
inline int lowbit(int x) {
return x & -x;
}
inline void add(int x, int val) {
if (!x) return;
for (int i = x; i < N; i += lowbit(i)) {
if (val > 0) tree1[i]++;
else tree1[i]--;
tree2[i] += val;
}
}
inline int query1(int x) {
int ans = 0;
for (int i = x; i; i -= lowbit(i))
ans += tree1[i];
return ans;
}
inline int query2(int x) {
int ans = 0;
for (int i = x; i; i -= lowbit(i))
ans += tree2[i];
return ans;
}
} bit;
inline void init() {
st.clear();
bit.clear();
for (int i = 0; i <= n; i++) G[i].clear();
}
void add(int x) {
if (st.empty()) {
st.insert(x);
return;
}
auto p = st.lower_bound(x);
if (p == st.begin()) {
bit.add((*p - x), (*p - x));
st.insert(x);
return;
}
if (p == st.end()) {
bit.add(x - (*prev(p)), x - (*prev(p)));
st.insert(x);
return;
}
int x1 = (*p) - x, x2 = x - (*prev(p));
bit.add(x1, x1);
bit.add(x2, x2);
bit.add(x1 + x2, -x1 - x2);
st.insert(x);
}
void del(int x) {
auto p = st.find(x);
if (st.size() == 1) {
st.erase(p);
return;
}
if (p == st.begin()) {
bit.add((*next(p)) - x, x -(*next(p)));
st.erase(p);
return;
}
if (p == prev(st.end())) {
bit.add(x - (*prev(p)), (*prev(p)) - x);
st.erase(p);
return;
}
int x1 = (*next(p)) - x, x2 = x - (*prev(p));
bit.add(x1, -x1);
bit.add(x2, -x2);
bit.add(x1 + x2, x1 + x2);
st.erase(p);
}
int main() {
int T, kase = 0;
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
init();
for (int i = 1; i <= n; i++)
scanf("%lld%lld%lld", &E[i], &L[i], &C[i]);
scanf("%d", &m);
for (int i = 1; i <= m; i++) {
int l, r, t;
scanf("%d%d%d", &t, &l, &r);
G[l].push_back(t); G[r + 1].push_back(-t);
}
ll ans = 0;
for (int i = 1; i <= n; i++) {
for (auto x: G[i]) {
if (x > 0) add(x);
else del(-x);
}
if (st.empty()) continue;
ans += min(C[i], 1LL * (*st.begin()) * L[i] + E[i]);
if (!L[i]) continue;
ans += (st.size() - 1 - bit.query1(C[i] / L[i])) * C[i] + bit.query2(C[i] / L[i]) * L[i];
}
printf("Case #%d: %lld\n", ++kase, ans);
}
return 0;
}
