HDU4261 / Estimation

Posted on By 二价氢

题目:HDU 4261

给你一个数组\(A\),求一个数组\(B\),\(B\)被划分为k段,每段内的数字相等,求\(\Sigma \vert B_i - A_i\vert\)的最小值

我们知道,假设只有一段,那么数组\(B\)内元素的值肯定等于\(A_i\)中的元素的中位数,现在的问题在于,如果暴力求得中位数(\(O(n^3 \log n)\)),那么肯定是超时的,我们需要更快的中位数求法

设当前中位数为\(m\),现在维护两个集合,\(S_1\)与\(S_2\),其中\(S_1\)中的元素均大于等于\(m\)而\(S_2\)中的元素均小于等于\(m\)

当我们新得到一个数\(x\),假设\(x \le m\)则将它加入\(S_2\)中,否则加入到\(S_1\)中

如果两个集合里元素较多的那个集合元素个数比元素较少的那个多了2,则说明中位数发生了偏移,此时,我们将现在的中位数加入到元素较少的那个集合中,从元素较多的那个集合中取出并删去最大(小)的元素,并设其为新的中位数。

维护这个东西可以考虑用堆(优先队列)等数据结构来完成,速度很快,为\(O(n^2 \log n)\)

之后即是朴素的区间DP,此处不再赘述。

AC-Code

#include <bits/stdc++.h>
using namespace std;
const int maxn = 2000 + 5 , maxk = 25 + 5;
int a[maxn];
long long seg[maxn][maxn];
long long dp[maxk][maxn];
int n , k;
void work()
{
	memset(seg , 0 , sizeof(seg));
	for (int i = 1 ; i <= n ; i++)
		scanf("%d" , &a[i]);
	for (int i = 1 ; i <= n ; i++)
	{
		seg[i][i] = 0;
		long long lg = 0 , sm = 0 , m = a[i] , slg = 0 , ssm = 0;
		priority_queue<long long , vector<long long> , greater<long long> > A;
		priority_queue<long long> B;
		long long it;
		for (int j = i + 1 ; j <= n ; j++)
		{
			if (a[j] < m)
			{
				B.push(a[j]);
				sm++;
				ssm += a[j];
			}else
			{
				A.push(a[j]);
				lg++;
				slg += a[j];
			}
			while (lg - sm >= 2)
			{
				it = A.top();A.pop();
				B.push(m);
				slg -= it;lg--;
				ssm += m;sm++;
				m = it;
			}
			while (sm - lg >= 2)
			{
				it = B.top();B.pop();
				A.push(m);
				ssm -= it;sm--;
				slg += m;lg++;
				m = it;
			}
//			cerr << i << " " << j << " " << m << endl;
			seg[i][j] = m * sm - ssm + slg - m * lg;
//			printf("%d %d %lld %lld\n" , i , j , m , seg[i][j]);
		}
	}
	memset(dp , 0x7f , sizeof(dp));
	dp[0][0] = 0;
	for (int i = 1 ; i <= k ; i++)
		for (int j = 0 ; j <= n ; j++)
			for (int l = j + 1 ; l <= n ; l++)
				if (dp[i - 1][j] != 0x7f7f7f7f7f7f7f7fll)
				{
					if (dp[i][l] == 0x7f7f7f7f7f7f7f7fll)
						dp[i][l] = dp[i - 1][j] + seg[j + 1][l];
					else
						dp[i][l] = min(dp[i][l] , dp[i - 1][j] + seg[j + 1][l]);
				}
	printf("%lld\n" , dp[k][n]);
}
int main()
{
	while ((~scanf("%d%d" , &n , &k)) && n && k)
		work();
	return 0;
}