Natasha, Sasha and the Prefix Sums
链接
视频
试题
试题大意
定义一个数列 $a$ 的最大前缀和为 $f(a)=\max(0,\max\limits_{1\le i\le |a|} \sum\limits_{1\le j\le i} a_j )$, 对于包含 $n$ 个+1与 $m$ 个 −1 的所有 $a$ ,求对应的 $f(a)$ 的和, 答案对998244853取模, $0\le n, m\le 2000$,时间限制2秒
一种思路
试题的作者使用了动态规划的做法来解决此题,但使用纯组合数学方法亦可解决此题.
考虑一个退化版问题,即对于包含 $n$ 个1与 $m$ 个−1的所有子串,前缀和最大值恰为 $x$ 的序列种数
对于这个问题,我们不妨继续简化情况,询问有 $n$ 个1与 $m$ 个-1的所有序列中,最大前缀和恰为 n-m 的情况种数
可以证明简化后问题的答案为 $C^{n+m}_m−C^{n+m}_{m−1}$ 即,序列的任意后缀中,1的个数不少于-1的个数,这个答案的证明可以参考卡塔兰数的证明方法。
那么,对于包含 $n$ 个1与 $m$ 个-1的所有子串,前缀和最大值恰为 $x$ 的序列种数,可以通过枚举最大值 $x$ 第一次出现的位置得到,如果假设位置是 $y$ 的话
可以知道这个前缀包含 $(x+y)/2$ 个+1与 $(y-x)/2$ 个-1 那么这个前缀的种数即使之前简化问题的解 $C^{y-1}_{(y-x)/2} - C^{y-1}_{(y-x)/2-1}$
而对于答案的后半部分,就是从这个位置开始数,遇到的-1个数不会少于遇到的1的个数,如果后半部分有 $p=n−(x+y)/2$ 个+1与 $q=m−(y−x)/2$ 个-1,答案就是 $C^{p+q}_{p} - C^{p+q}_{p-1}$
至此,可以得知,前缀和最大值为 $x$ 的序列方案数为
\[\sum_{1\le y\le n+m} \left(C^{y-1}_{(y-x)/2} - C^{y-1}_{(y-x)/2 - 1}\right) \times \left(C^{n+m-y}_{n-(x+y)/2} - C^{n+m-y}_{n-(x+y)/2-1}\right)\]进一步枚举𝑥,可以得到问题的答案为
\[\sum_{1\le x \le n-m}\sum_{1\le y\le n+m} \left(C^{y-1}_{(y-x)/2} - C^{y-1}_{(y-x)/2 - 1}\right)\left(C^{n+m-y}_{n-(x+y)/2} - C^{n+m-y}_{n-(x+y)/2-1}\right)\left[(x\mod 2)=(y\mod 2)\right]\]复杂度分析
我们枚举了 $x$ 与 $y$,其中$x=\Theta(n) ,y=\Theta(n)$,计算组合数在预处理后计常数复杂度,复杂度约为$O(n^2)$ 在Codeforces上,本题使用了249ms
代码
#include <bits/stdc++.h>
using namespace std;
const long long mod = 998244853;
const int maxn= 4000 + 5;
long long prod[maxn], iprod[maxn];
long long pmod(long long a, long long b) {
long long ret = 1;
while (b) {
if (b & 1) ret = ret * a % mod;
b >>= 1;
a = a * a % mod;
}
return ret;
}
long long C(long long n, long long m) {
if (n < m) return 0;
if (n < 0 || m < 0) return 0;
return prod[n] * iprod[m] % mod * iprod[n - m] % mod;
}
long long f(long long m, long long n) {
if (m < n) return 0;
return (C(m + n, m) - C(m + n, m + 1) + mod) % mod;
}
long long sol(long long n, long long m) {
long long cnt = 0;
for (long long i = 1; i <= n; i++) {
for (long long j = 0; j <= m && j <= i; j++) {
long long ans1, ans2;
ans1 = f(i - 1, j) * (i - j) % mod;
ans2 = f(m - j, n - i);
(cnt += ans1 * ans2 % mod) %= mod;
}
}
return cnt;
}
int main() {
int n, m;
prod[0] = iprod[0] = 1;
for (long long i = 1; i < maxn; i++) {
prod[i] = prod[i - 1] * i % mod;
iprod[i] = pmod(prod[i], mod - 2);
}
cin >> n >> m;
cout << sol(n, m) << endl;
return 0;
}