TopCoder SRM717 Div1 Middle

Posted on By 二价氢

题意

给定两个数$n$, $m$

令$D_i$为集合$\{1, 2, \ldots, n + i\}$的全排列$p$中满足$\forall j \in [1, i], p[j] \neq j$的排列的个数。

令$B_i = D_i / n!$,易证$B_i$为整数。

求$B_1 \oplus B_2 \oplus B_3 \oplus \ldots \oplus B_m$的值(下面的代码中每个$B_i$都对$10^9+7$取模后再异或)。

注意:异或和只是为了减少答案大小,标程计算了所有的$B_i$的值。

思路

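考虑$D_i$的表达式:对前$i$个位置做容斥,有$D_i = \sum_{k=0}^{i} (-1)^k \binom{i}{k} (n+i-k)!$,于是$B_i = \frac{D_i}{n!} = i! \sum_{k=0}^{i} \frac{(-1)^k}{k!} \cdot \frac{(n+i-k)!}{(i-k)!\,n!}$。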

最后发现拿$\frac{n!}{0!n!},\frac{(n+1)!}{1!n!},\frac{(n+2)!}{2!n!},\ldots,\frac{(n+m)!}{m!n!}$和$1/0!, -1/1!, 1/2!, -1/3!, \ldots, (-1)^m/m!$卷一下,再把卷积结果的第$i$项乘上$i!$就好了。

Code

耗时: 12107 ms

内存: 37480 kB

#include <bits/stdc++.h>
using namespace std;

// Modulus applied to every modular reduction in this file (1e9 + 7).
const int mod = 1000000007;
// pi, obtained as acos(-1).
const double PI = acos(-1);
// Complex number type used by the FFT routines.
typedef complex<double> Complex;
// N:    fixed transform length (262144 = 2^18) and size of every work array.
// L:    number of low bits split off each coefficient inside mul().
// MASK: bit mask selecting those low L bits.
const int N = 262144, L = 15, MASK = (1 << L) - 1;
// Twiddle-factor table; FFTInit() sets w[k] = (cos(2*pi*k/N), sin(2*pi*k/N)).
Complex w[N];
// Fills the twiddle table used by FFT():
// w[k] = (cos(2*pi*k/N), sin(2*pi*k/N)) for every k in [0, N).
void FFTInit()
{
	int k = 0;
	while (k < N)
	{
		// Same expression order as a direct 2*k*PI/N evaluation.
		const double angle = 2 * k * PI / N;
		w[k] = Complex(cos(angle), sin(angle));
		++k;
	}
}

// In-place iterative radix-2 transform of p[0..n) using the twiddle table w:
// on return p[k] holds sum over t of p_in[t] * exp(+2*pi*sqrt(-1)*t*k/n).
//
// p - array of at least n elements, overwritten with the result.
// n - transform length; must be a power of two with n <= N.
//
// FFTInit() must have been called beforehand so that w[] is populated.
void FFT(Complex p[], int n)
{
	// Bit-reversal permutation: j walks through the bit-reversed values of i.
	for (int i = 1, j = 0; i < n - 1; i++)
	{
		// Increment j as a reversed counter: clear leading set bits from the
		// top, then set the first clear bit.
		int bit = n >> 1;
		for (; j & bit; bit >>= 1) j ^= bit;
		j ^= bit;
		if (i < j) swap(p[i], p[j]);
	}
	// Butterfly passes; half is the half-size of the blocks merged in this pass.
	for (int half = 1; half < n; half <<= 1)
	{
		const int span = half * 2;
		// w[] is tabulated for length N, so the step between consecutive
		// twiddles of a length-span block is N / span. Deriving the step from
		// n (as n / span) is only correct when n == N.
		const int stride = N / span;
		for (int base = 0; base < n; base += span) {
			for (int j = 0; j < half; j++) {
				Complex &hi = p[base + j + half], &lo = p[base + j];
				const Complex t = w[stride * j] * hi;
				hi = lo - t;
				lo = lo + t;
			}
		}
	}
}

// Scratch buffers for mul(): A and B hold the packed inputs and their
// transforms; C and D hold the combined spectra and, after the second
// transform pass, the (unscaled) partial products.
Complex A[N], B[N], C[N], D[N];
// Cyclic (length-N) convolution of a and b modulo mod; the result overwrites a.
//
// Each input value is split into a high part (v >> L) and a low part
// (v & MASK) that are packed into the real and imaginary components of one
// complex number, so two forward transforms and two back transforms suffice.
// Uses the global scratch buffers A, B, C, D and requires FFTInit() to have
// been called.
void mul(int a[N], int b[N])
{
	// Pack: real = high bits, imaginary = low L bits.
	for (int k = 0; k < N; ++k) {
		A[k] = Complex(a[k] >> L, a[k] & MASK);
		B[k] = Complex(b[k] >> L, b[k] & MASK);
	}
	FFT(A, N);
	FFT(B, N);

	const Complex negHalfI(0, -0.5), half(0.5, 0), unitI(0, 1);
	for (int k = 0; k < N; ++k)
	{
		const int mirror = (N - k) % N;
		// Separate the spectra of the low (imaginary) and high (real) parts.
		const Complex aLow = (A[k] - conj(A[mirror])) * negHalfI;
		const Complex aHigh = (A[k] + conj(A[mirror])) * half;
		const Complex bLow = (B[k] - conj(B[mirror])) * negHalfI;
		const Complex bHigh = (B[k] + conj(B[mirror])) * half;
		// Writing to the mirrored index turns the next FFT() call into an
		// inverse transform scaled by N.
		C[mirror] = aLow * bHigh + aLow * bLow * unitI;
		D[mirror] = aHigh * bHigh + aHigh * bLow * unitI;
	}
	FFT(C, N);
	FFT(D, N);

	for (int k = 0; k < N; ++k) {
		// Order: high*high, high*low, low*high, low*low.
		const double raw[4] = { D[k].real(), D[k].imag(), C[k].real(), C[k].imag() };
		long long part[4];
		for (int t = 0; t < 4; ++t)
			part[t] = (long long)(raw[t] / N + 0.5) % mod;
		const long long cross = part[2] + part[1];
		// Recombine: hh * 2^(2L) + (lh + hl) * 2^L + ll, reduced modulo mod.
		a[k] = ((part[0] << (L * 2)) + (cross << L) + part[3]) % mod;
	}
}

// Returns a raised to the power b, reduced modulo mod, by binary
// exponentiation (square-and-multiply over the bits of b).
long long pow_mod(long long a, long long b)
{
	long long result = 1;
	for (; b; b >>= 1, a = a * a % mod)
	{
		// Multiply in the current power of a when the lowest bit of b is set.
		if (b & 1)
			result = result * a % mod;
	}
	return result;
}

// Work arrays for DerangementsStrikeBack::count (all values reduced modulo mod):
//   ta, tb   - the two sequences handed to mul(); ta receives the product
//   nn[i]    - (n + 1) * (n + 2) * ... * (n + i)
//   frac[i]  - i!
//   ifrac[i] - modular inverse of i!
int ta[N], tb[N], nn[N], frac[N], ifrac[N];

struct DerangementsStrikeBack{
	int count(int n, int m)
	{
		FFTInit();
		memset(ta, 0, sizeof ta);
		memset(tb, 0, sizeof tb);
		memset(nn, 0, sizeof nn);
		memset(frac, 0, sizeof frac);
		memset(ifrac, 0, sizeof ifrac);
		nn[0] = frac[0] = ifrac[0] = 1;
		for (int i = 1; i <= m; i++)
		{
			frac[i] = ((long long)frac[i - 1]) * i % mod;
			ifrac[i] = pow_mod(frac[i], mod - 2);
			nn[i] = ((long long)(nn[i - 1])) * (n + i) % mod;
		}
		for (int i = 0; i <= m; i++)
		{
			ta[i] = ((long long)nn[i]) * ifrac[i] % mod;
			if (i & 1)
				tb[i] = (mod - ifrac[i]) % mod;
			else
				tb[i] = ifrac[i];
		}
		mul(ta, tb);
		int ans = 0;
		for (int i = 1; i <= m; i++)
			ans ^= (((long long)ta[i]) * frac[i] % mod);
		return ans;
	}
};