保序回归入门题
题目链接:P7294 [USACO21JAN] Minimum Cost Paths P
题意:
现有 $n\times m$ 的一个矩形,第 $i$ 列有一个代价 $c_i$。如果现在你在 $(x,y)$,可以花费 $x^2$ 的代价到 $(x,y+1)$,或 $c_y$ 的代价到 $(x+1,y)$。
现有 $q$ 次询问,每次给定 $(x_i,y_i)$,求从 $(1,1)$ 到 $(x_i,y_i)$ 的最小代价。
$2\le n\le 10^9,2\le m,q\le 10^5$。
本题解提供两种做法,第一种是 $\text{B}\red{\text{enq}}$ 给出的正解,还有一种是更为无脑的保序回归。由于作者很懒所以只写了第二种的代码
听神仙zzm说本质是一样的 代码还差不多
Solution 1
本部分参照 $\text{B}\red{\text{enq}}$ 在 $\text{USACO}$ 发布的官方题解。
令从 $(1,1)$ 到 $(x,y)$ 的最小代价为 $\text{ans}_y(x)$,那么能够发现,对于一个固定的 $y$,其函数值随着 $x$ 单调增且 $\text{ans}_{y}(x)-\text{ans}_{y}(x-1)\le \text{ans}_{y}(x+1)-\text{ans}_{y}(x)$。可以这样解释:
考虑从 $\text{ans}_y(x)$ 转移到 $\text{ans}_{y+1}(x)$ 的时候:
- 先令 $\text{ans}_{y+1}(x)=\text{ans}_y(x)+x^2$。
- 接着令 $\text{ans}_{y+1}(x)=\min(\text{ans}_{y+1}(x),\text{ans}_{y+1}(x-1)+c_y)$。
对于操作 $2$,相当于拿一条斜率为 $c_y$ 的直线去切函数图像,然后把过高的部分砍下来。同时,看到 $1$ 中 $\text{ans}_{y+1}(x)$ 上加了 $x^2$ 就可以大概发现 $\text{ans}_{y}(x)-\text{ans}_{y}(x-1)\le \text{ans}_{y}(x+1)-\text{ans}_{y}(x)$ 的规律。
那么现在可以使用一个栈来维护这样一个类似凸壳的东西,每次在栈里二分找到被直线给砍掉的位置,然后加入一个新点。至于每次的 $1$ 操作,也就是 $+x^2$,可以记录一个时间戳,然后和当前时间戳作差。
是的代码咕了
Solution 2
本部分来自 $\text{zxyhymzg}$ 考场想法。
萌新刚学保序回归 有锅请轻喷
令 $s_i$ 为我们在第 $i$ 行处走到 $(i,s_i)$,其中 $s_0=1$。那么显然有 $s_{i-1}\le s_i$,因为只能向右或者向下走。
计算一下代价,对于从 $(1,1)$ 到 $(x,y)$,$\forall s_i\le x$,代价为:
$$\sum_{i=1}^{y-1} s_i^2+(s_i-s_{i-1})\times c_i$$
后面的部分交换和式,把一项变成只和 $s_i$ 相关:
$$x\times c_y-c_1+\sum_{i=1}^{y-1} s_i^2+(c_i-c_{i+1})\times s_i$$
和式外面是常数,那么只要最小化和式内部。发现是二次函数,先写成顶点式:
$$(s_i-\dfrac{c_{i+1}-c_i}{2})^2-\dfrac{(c_{i+1}-c_i)^2}{4}$$
后面的也是常数,不管。变成最小化:
$$\sum_{i=1}^{y-1} (s_i-\dfrac{c_{i+1}-c_i}{2})^2$$
$\dfrac{c_{i+1}-c_i}{2}$ 是常数,那么就变成了经典的保序回归中的特殊的 $L_2$ 问题,在 $\text{IOI2018}$ 国家候选队论文集中有写到。
需要注意的是,现在对于一个固定的 $y$,发现其 $s_i$ 是固定的。而我们要求 $1\le s_i\le x$,实际上可能会不满足。可以发现将 $s_i$ 向 $[1,x]$ 取整,同时浮点数四舍五入即可。
那么同样的,在保序回归该特殊情况的贪心求解中,我们需要维护一个单调栈,因此对于向 $[1,x]$ 取整的操作可以在单调栈上二分,取整后一段区间的贡献是可以 $\mathcal{O}(1)$ 计算的。
至此即可解决本题,时间复杂度 $\mathcal{O}(n\log n)$,实际上好像效率海星?
$\text{Code:}$
//Code By CXY07
#include<bits/stdc++.h>
using namespace std;
//#define FILE
#define int long long
#define debug(x) cout << #x << " = " << x << endl
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define LINE() cout << "LINE = " << __LINE__ << endl
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) qpow((x),mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define vec vector
const int MAXN = 2e5 + 10;
const int INF = 2e9;
const double PI = acos(-1);
const double eps = 1e-6;
//const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;
struct Query {
int x, y, id;
Query(int _x = 0, int _y = 0, int _id = 0) : x(_x), y(_y), id(_id) {}
bool operator < (const Query &b) const {return y < b.y;}
} q[MAXN];
int n, m, qs, top;
int c[MAXN], sum[MAXN], stk[MAXN];
//为了尽量避免浮点数计算,sum数组中先不将c_{i+1}-c_i除2
int Ans[MAXN], num[MAXN], val[MAXN];
template<typename T> inline bool read(T &a) {
a = 0; char c = getchar(); int f = 1;
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
a *= f;
return 1;
}
template<typename A, typename ...B>
inline bool read(A &x, B &...y) {return read(x) && read(y...);}
double aver(int l, int r) {
return 1. * (sum[r] - sum[l - 1]) / (2. * (r - l + 1));
}
int calc(int l, int r, int x) {
if(l > r) return 0;
return x * x * (r - l + 1) - x * (sum[r] - sum[l - 1]);
}
void Insert(int x) {
while(top && aver(stk[top - 1] + 1, stk[top]) + eps > aver(stk[top] + 1, x)) top--;
stk[++top] = x;
num[top] = (int)round(aver(stk[top - 1] + 1, stk[top]));
val[top] = val[top - 1] + calc(stk[top - 1] + 1, stk[top], num[top]);
}
signed main () {
#ifdef FILE
freopen(".in","r",stdin);
freopen(".out","w",stdout);
#endif
read(n), read(m);
for(int i = 1; i <= m; ++i) read(c[i]);
for(int i = 1; i <= m; ++i) sum[i] = sum[i - 1] + c[i + 1] - c[i];
read(qs);
for(int i = 1, x, y; i <= qs; ++i) {
read(x), read(y);
q[i] = Query(x, y, i);
}
sort(q + 1, q + qs + 1);
int nowy = 0;
for(int p = 1; p <= qs; ++p) {
if(q[p].y == 1) {
Ans[q[p].id] = (q[p].x - 1) * c[1];
continue;
}
while(nowy + 1 < q[p].y) {
nowy++;
Insert(nowy);
}
Ans[q[p].id] = val[top] + q[p].x * c[q[p].y] - c[1];
int L = 0, R = top, mid;
while(L < R) {
mid = (L + R + 1) >> 1;
if(num[mid] > q[p].x) R = mid - 1;
else L = mid;
}
Ans[q[p].id] += calc(stk[L] + 1, nowy, q[p].x) - (val[top] - val[L]);
L = 1, R = top + 1;
while(L < R) {
mid = (L + R) >> 1;
if(num[mid] <= 0) L = mid + 1;
else R = mid;
}
Ans[q[p].id] += calc(1, stk[L - 1], 1) - val[L - 1];
}
for(int i = 1; i <= qs; ++i)
printf("%lld\n", Ans[i]);
return 0;
}