CodeForces 626E Simple Skewness

Posted on By 二价氢

题目大意

给一个有重集\(S\),\(S\)的大小不超过\(2\times 10^5\),求\(S\)的一个子集(可以有重复元素),使得子集的均值减子集的中值最大。

做法

若要子集的均值减中值最大,则有子集的形式一定为\(\{S_{i-j},S_{i-j+1},\ldots,S_{i-1},S_{i},S_{n-j+1},S_{n-j+2},\ldots,S_{n}\}\),或\(\{S_{i-j+1},\ldots,S_{i-1},S_{i},S_{n-j+1},S_{n-j+2},\ldots,S_{n}\}\),当然,我们可以证明,如果将第二种情况可能为答案,则若去掉\(S_{n-j+1}\)这一项,答案不会变差。

不妨设\(x = S_{i},y=S_{n-j+1}\),均值为\(z\),则其中位数为\(\frac{x+y}{2}\),若其为答案备选项,则有\(\frac{x+y}{2}\le z\)

去掉\(S_{n-j+1}\)之后,新集合的均值为\(\frac{2zj - y}{2j-1}\),新集合的中位数为\(x\),显然,新的答案与旧的答案之差为

\[\frac{2zj - y}{2j-1}-x-z+\frac{x+y}{2} = \frac{4zj-2y-(4j-2)x - (4j-2)z+(2j-1)x+(2j-1)y}{4j-2}\]

化简得

\[\frac{2z+(2j-3)y-(2j-1)x}{4j-2} \ge \frac{(2j-1)y-(2j-1)x}{4j-2} \ge 0\]

即答案不会变劣。

我们进一步猜想,根据这个数据范围,\(O(n^2)\)是不可能的,如果是\(O(n)\),我们显然不可能在\(O(1)\)的时间内得到最佳的\(j\),那么有很大的可能,答案是\(O(n\log n)\)的,我们现在开始证明其凹凸性或单调性。

假设当前的平均值为\(z\),有\(n\)个数,新增加的数为\(a\)和\(b\),则我们可以得到新的平均值为\(\frac{nz + a + b}{n+2}\),答案的变化量为\(\Delta = \frac{a+b-2z}{n+2}\),设\(p = a-z,q=b-z\),那么\(\Delta = \frac{p + q}{2}\),当我们再加入下一个数的时候,显然,有\(p'\le p, q' \le q\),即\(\Delta' = \frac{p'+q'}{2} \le \frac{p+q}{2}\),即答案的变化量是单调的,又即,答案的导数是单调的,即答案是上凸函数。

因而,我们可以二分,找出答案的极值点,即是我们在中位数为\(S_i\)时的最优的\(j\)的取值,复杂度为\(O(n\log n)\)

代码

#include <bits/stdc++.h>
using namespace std;
const int maxn = 200000 + 1;
int n, mxi = 1, mxj = 0;
double mxa = 0;
int main() {
    ios::sync_with_stdio(0);
    cin >> n;
    vector<int> a(n + 1);
    vector<long long> prefix(n + 1, 0);
    for (int i = 1; i <= n; i++) cin >> a[i];
    sort(a.begin(), a.end());
    for (int i = 1; i <= n; i++)
        prefix[i] = prefix[i - 1] + a[i];
    for (int i = 1; i <= n; i++) {
        int tj = 0, rj = 0;
        double lxa = 0;
        for (int j = 1 << 30; j; j >>= 1) {
            int nj = tj + j;
            double txa;
            if (i - nj < 1 || i + nj > n) continue;
            txa = 1.0 * (prefix[n] - prefix[n - nj] + prefix[i] - prefix[i - nj - 1]) /
                        (nj * 2 + 1) - a[i];
            if (txa > lxa) {
                lxa = txa;
                rj = nj;
            }
            if (i - nj < 2 || i + nj > n - 1) {
                continue;
            } else {
                double txb = 1.0 * 
                         (prefix[n] - prefix[n - nj - 1] + prefix[i] - prefix[i - nj - 2]) /
                         (nj * 2 + 3) - a[i];
                if (txa < txb) tj = nj;
            }
        }
        if (lxa > mxa) {
            mxa = lxa;
            mxi = i;
            mxj = rj;
        }
    }
    cout << mxj * 2 + 1 << endl;
    for (int i = mxi - mxj; i < mxi; i++) cout << a[i] << ' ';
    cout << a[mxi];
    for (int i = n - mxj + 1; i <= n; i++) cout << ' ' << a[i];
    cout << endl;
    cerr << mxa << endl;
    return 0;
}