HDU 6061 RXD and functions

Posted on By 二价氢

题意

注意到$a_i$只以$\sum a_i$的形式出现,记$p = -\sum a_i$(在模$998244353$意义下),题意就变成了求$\sum_{i=0}^{n} c_i(x + p)^i$展开式中各个$x^i$项的系数。

思路

考虑变形:

\[\begin{align*} \sum_{i=0}^{n} c_i(x + p)^i &=& \sum_{i=0}^{n} c_i \sum_{j=0}^i \binom{i}{j} x^{i-j} p^{j} \\ &=& \sum_{i=0}^{n} c_i \sum_{j=0}^i \frac{i!}{j!(i-j)!} x^{i-j} p^{j} \\ (r\leftarrow i-j) & = & \sum_{r=0}^{n} \frac{x^r}{r!} \sum_{j=0}^{n-r} \frac{(r+j)!}{j!} c_{r+j} p^{j} \\ & = & \sum_{r=0}^{n} \frac{x^r}{r!} \sum_{j=0}^{n-r} c_{r+j} (r+j)! p^j (j!)^{-1} \\ \left(K_j\leftarrow p^{n-j}\left((n-j)!\right)^{-1}\right) & = & \sum_{r=0}^{n} \frac{x^r}{r!} \sum_{j=0}^{n-r} c_{r+j} (r+j)! K_{n-j} \\ \left(Q_{i}\leftarrow i!\,c_{i}\right) & = & \sum_{r=0}^{n} \frac{x^r}{r!} \sum_{j=0}^{n-r} Q_{r+j} K_{n-j} \\ \end{align*}\]

注意到最后的式子为卷积的形式(两个下标之和恒为$n+r$),结合模数$998244353$,考虑使用NTT来进行计算。构造$Q,K$两个数列并做卷积,卷积结果的第$n+r$项再乘以$(r!)^{-1}$即为$x^r$的系数。

Code

耗时: 592 ms

内存: 12736 K

#include <bits/stdc++.h>
using namespace std;

// g    : base from which the NTT derives its roots of unity.
// mod  : prime modulus used for all arithmetic in this file.
// maxn : capacity of every fixed-size table below.
const int g = 3, mod = 998244353, maxn = 262144;

/**
 * Computes x^b modulo `mod` by binary exponentiation.
 *
 * @param x Base; any value is accepted and is first reduced into [0, mod).
 * @param b Exponent; expected to be non-negative. A negative exponent is
 *          treated like 0 (the loop is skipped) instead of never terminating.
 * @return  x^b mod `mod`, always in [0, mod). x^0 is 1, including for x == 0.
 */
long long pow_mod(long long x, long long b)
{
	// Reduce the base up front: a negative base would otherwise yield a
	// negative result, and a base >= 2^32 would overflow in x * x.
	x %= mod;
	if (x < 0) x += mod;

	long long ret = 1;
	while (b > 0)
	{
		if (b & 1) ret = ret * x % mod;
		x = x * x % mod;
		b >>= 1;
	}
	return ret;
}

/**
 * In-place number-theoretic transform over the integers modulo `mod`,
 * using powers of `g` as roots of unity.
 *
 * The transform length n must be a power of two that divides mod - 1 and
 * must not exceed the size of `w`. Input values are expected in [0, mod).
 * The inverse transform is NOT scaled by 1/n; the caller must do that.
 */
struct NTT {
	// Twiddle table: w[i] = root^i, where root is the n-th root of unity
	// (or its inverse) selected by the most recent pre_processw call.
	long long w[262144];
	// Cache key (oper * n) of the table currently stored in w; 0 = none yet.
	int cur_w_dir = 0;

	/**
	 * Fills `w` with the powers of the n-th root of unity.
	 *
	 * @param oper 1 for the forward transform, -1 for the inverse
	 *             (the inverse uses the modular inverse of the root).
	 * @param n    Transform length.
	 */
	void pre_processw(int oper, int n) {
		// The table depends only on (oper, n); skip the rebuild when unchanged.
		if (cur_w_dir == oper * n) return;
		w[0] = 1;
		w[1] = pow_mod(g, (mod - 1) / n);
		if (oper == -1) w[1] = pow_mod(w[1], mod - 2);
		for (int i = 2; i < n; i++) w[i] = w[i - 1] * w[1] % mod;
		// Record what was built; without this the check above never hits and
		// the table is recomputed on every transform.
		cur_w_dir = oper * n;
	}

	/**
	 * Transforms P[0..n) in place.
	 *
	 * @param P    Array of n residues in [0, mod).
	 * @param n    Transform length (power of two).
	 * @param oper 1 = forward, -1 = inverse (unscaled).
	 */
	void fft(long long P[], int n, int oper)
	{
		pre_processw(oper, n);

		// Bit-reversal permutation. j is kept equal to the bit-reversed value
		// of i: the inner loop adds one to j starting from the highest bit,
		// carrying downwards while the flipped bit became 0.
		for (int i = 1, j = 0; i < n - 1; i++) {
			for (int s = n; j ^= s >>= 1, ~j & s;);
			if (i < j) std::swap(P[i], P[j]);
		}

		// Iterative butterflies; stage d merges blocks of half-size m = 2^d.
		for (int d = 0; (1 << d) < n; d++) {
			int m = 1 << d, m2 = m * 2, rm = n >> (d + 1);
			for (int i = 0; i < n; i += m2) {
				for (int j = 0; j < m; j++) {
					long long &P1 = P[i + j + m], &P2 = P[i + j];
					// w[rm * j] is the j-th power of the (2m)-th root of unity.
					long long t = w[rm * j] * P1 % mod;
					// Update P1 first: it needs the old value of P2.
					P1 = (P2 - t + mod) % mod;
					P2 = (P2 + t) % mod;
				}
			}
		}
	}
}ntt;

// frac[i] holds i! mod `mod`; finv[i] holds the modular inverse of frac[i].
// Both tables are filled once in main().
long long frac[maxn];
long long finv[maxn];
// Working buffers for one test case: A and B are the two sequences that get
// transformed and multiplied, C stores the coefficients as read from input.
long long A[maxn];
long long B[maxn];
long long C[maxn];
// Shift value derived from the sum of the m input numbers (reset per test case).
long long sigma = 0;
// Degree of the polynomial of the current test case (n + 1 coefficients).
int n;

/**
 * Solves one test case; the global `n` must already hold the degree.
 *
 * Reads from stdin: n + 1 coefficients c_0..c_n, then m, then m integers.
 * With sigma = -(sum of the m integers) mod `mod`, prints the n + 1
 * coefficients of sum_i c_i * (x + sigma)^i, lowest degree first, each
 * followed by a space, then a newline.
 *
 * Precondition: the transform length k computed below must not exceed the
 * capacity of A and B (maxn), i.e. 2 * (n + 1) < maxn.
 * If the input ends early, the function returns without printing.
 */
void sol()
{
	// Smallest power of two strictly greater than 2 * (n + 1); enough room
	// for the 2n + 1 terms of the product.
	int k = 1;
	while (k <= (n + 1) * 2) k *= 2;
	memset(B, 0, sizeof(long long) * k);
	memset(A, 0, sizeof(long long) * k);
	sigma = 0;

	for (int i = 0; i <= n; i++)
		if (scanf("%lld", &C[i]) != 1) return;
	// B[i] = i! * c_i
	for (int i = 0; i <= n; i++)
		B[i] = C[i] * frac[i] % mod;

	int m;
	if (scanf("%d", &m) != 1) return;
	for (int i = 0; i < m; i++)
	{
		int t;
		if (scanf("%d", &t) != 1) return;
		(sigma += t) %= mod;
	}
	// Negate the sum modulo `mod`.
	sigma = mod - sigma;
	sigma %= mod;

	// A[n - i] = sigma^i / i!, stored reversed so that the wanted sums show
	// up as ordinary convolution terms. The power is carried as a running
	// product instead of calling pow_mod for every i.
	long long power = 1;
	for (int i = 0; i <= n; i++)
	{
		A[n - i] = power * finv[i] % mod;
		power = power * sigma % mod;
	}

	// Multiply A and B via NTT; the inverse transform is unscaled, so
	// multiply by k^{-1} afterwards.
	long long kkinv = pow_mod(k, mod - 2);
	ntt.fft(A, k, 1);
	ntt.fft(B, k, 1);
	for (int i = 0; i < k; i++)
		(A[i] *= B[i]) %= mod;
	ntt.fft(A, k, -1);
	for (int i = 0; i < k; i++) A[i] = A[i] * kkinv % mod;

	// Term n + r of the product, divided by r!, is the coefficient of x^r.
	for (int i = n; i <= n + n; i++) printf("%lld ", A[i] * finv[i - n] % mod);
	printf("\n");
}

/**
 * Precomputes factorials and inverse factorials modulo `mod` for every index
 * below maxn, then solves test cases until no further degree can be read.
 */
int main()
{
	frac[0] = 1;
	for (int i = 1; i < maxn; i++) frac[i] = frac[i - 1] * i % mod;

	// Invert only the largest factorial, then walk down using
	// 1/(i-1)! = (1/i!) * i. This replaces one modular exponentiation per entry.
	finv[maxn - 1] = pow_mod(frac[maxn - 1], mod - 2);
	for (int i = maxn - 1; i > 0; i--) finv[i - 1] = finv[i] * i % mod;

	// Require a successful conversion: scanf returns 0 (not EOF) on a
	// non-numeric token, which would otherwise make this loop spin forever.
	while (scanf("%d", &n) == 1) sol();
	return 0;
}