D. Bear and Rectangle Strips
给定一个$2\times n$的矩形阵,你要选择出尽量多的不相交的子矩阵,每个矩阵内的和为0
题解读了好长时间都没读懂……然后顺着题解的意思拍脑袋拍出来了……
首先,我们知道一个naïve的做法就是$dp_{i,j}$表示上面到第i行,下面到第j行得到的结果,这个做法很好想,但是是$O(n^2)$的,我们需要把它降到$O(n)$
考虑原来$n^2$的转移,我们可以发现,用$(C_i, L_i)$表示只取上下两行前$i$个格子,同时上面那行的格子以$i$结尾,下面那行的格子以$L_i$结束,最多可以选$C_i$个子矩阵,用$(D_i, U_i)$表示只去上下两行前$i$个格子,同时下面那行的格子以$i$结尾,上面那行的格子以$U_i$结尾,最多可以选$D_i$个子矩形。
当然,如果只是这样做,那么依然是$O(n^2)$的——我们要枚举$i$跟$U_i, L_i$,更新$C_i, D_i, L_i, U_i$。那么有没有什么办法可以优化到$O(n)$呢?
这里,我们发现,可以使用单调性DP来做——
- 如果$i>j, C_i \le C_j$,那么$i$状态不会成为最优解。
- 如果$L_i’ > L_i$,那么$i$状态下面不会以$L_i’$结束。
在这里,我们引入原题解中的结论:只需要维护对于$i$,$C_{i+k} = C_i + 1$的$k$即可——这就说明我们的状态可以下降到$O(1)$的程度——对于给定的$i$和$L_i$,我们总可以在预处理之后快速求得这个$k$——lower_bound
即可(当然是预处理之后的)。
对于$D_i$和$U_i$,我们一样的处理即可,需要注意的是$C_i, D_i$和与其对应的$L_i, U_i$要交叉的更新,以及要注意两个0平行的情况——这时候要加2而不是1,这是为了避免更新时的一些问题。
我在最后做这题的时候还多维护了一个变量,来表示上下两排都在$i$处结尾的情况。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 300000 + 5;
int c[maxn][3];
int nxt[maxn][3];
int mat[maxn][2];
long long sum[maxn][3];
map<long long, int> gnxt[3];
int li[maxn][3], lj[maxn][3];
int n;
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &mat[i][0]);
for (int i = 1; i <= n; i++) scanf("%d", &mat[i][1]);
for (int i = 1; i <= n; i++)
{
sum[i][0] = sum[i - 1][0] + mat[i][0];
sum[i][1] = sum[i - 1][1] + mat[i][1];
sum[i][2] = sum[i][0] + sum[i][1];
}
nxt[n][0] = nxt[n][1] = nxt[n][2] = n + 1;
for (int i = n; i >= 0; i--)
{
if (i != n)
{
nxt[i][0] = gnxt[0][sum[i][0]];
nxt[i][1] = gnxt[1][sum[i][1]];
nxt[i][2] = gnxt[2][sum[i][2]];
}
gnxt[0][sum[i][0]] = i;
gnxt[1][sum[i][1]] = i;
gnxt[2][sum[i][2]] = i;
}
for (int i = n - 1; i >= 0; i--)
{
if (nxt[i][0]) nxt[i][0] = min(nxt[i][0], nxt[i + 1][0]);
else nxt[i][0] = nxt[i + 1][0];
if (nxt[i][1]) nxt[i][1] = min(nxt[i][1], nxt[i + 1][1]);
else nxt[i][1] = nxt[i + 1][1];
if (nxt[i][2]) nxt[i][2] = min(nxt[i][2], nxt[i + 1][2]);
else nxt[i][2] = nxt[i + 1][2];
}
// for (int i = 0; i <= n; i++)
// printf("%d %d %d\n", nxt[i][0], nxt[i][1], nxt[i][2]);
//--
for (int i = 0; i <= n; i++)
{
for (int k = 0; k < 3; k++)
{
int j1 = nxt[li[i][k]][0];
int j2 = nxt[lj[i][k]][1];
if (j1 == j2)
{
if (c[j1][2] < c[i][k] + 2)
{
c[j1][2] = c[i][k] + 2;
li[j1][2] = lj[j1][2] = j1;
}
} else
{
if (c[j1][0] < c[i][k] + 1 || (c[j1][0] == c[i][k] + 1 && lj[i][k] < lj[j1][0]))
{
c[j1][0] = c[i][k] + 1;
li[j1][0] = j1;
lj[j1][0] = lj[i][k];
}
if (c[j2][1] < c[i][k] + 1 || (c[j2][1] == c[i][k] + 1 && li[i][k] < li[j2][1]))
{
c[j2][1] = c[i][k] + 1;
li[j2][1] = li[i][k];
lj[j2][1] = j2;
}
}
int h2 = nxt[i][2];
if (c[h2][2] < c[i][k] + 1)
{
c[h2][2] = c[i][k] + 1;
li[h2][2] = lj[h2][2] = h2;
}
}
}
int ans = 0;
for (int i = 0; i <= n; i++) for (int k = 0; k < 3; k++) ans = max(ans, c[i][k]);
cout << ans << endl;
return 0;
}