A. Accelerator

题意

给定 $n$ 个正整数 $a_1,a_2,…,a_n$,它们总共有 $n!$ 种 permutation。

对于每一种 permutation $a_{k_1},a_{k_2},…,a_{k_n}$,我们有一个初始值 $v=0$。

从左向右遍历这个 permutation,每遍到一个数 $a_{k_i}$,就令 $v = (v+1) * a_{k_i}$。

求所有 permutation 中,最终的 $v$ 的期望值?

其中,$n \leq 10^5, a_i \in [1, 10^9]$。

题解

期望的分母我们知道是 $n!$,我们只要求分子即可。

我们先看一种简单的例子:

$n=3$ 的情况下,假设一种 permutation 是 $a_1,a_3,a_2$,那么最终的答案 $v$ 就等于:

$$v=(((v+1)*a_1+1)*a_3+1)*a_2=a_1a_2a_3+a_2a_3+a_2$$

很容易发现规律,$v$ 由 $n$ 个数的和组成,每个数都是 $k = \{1,2,…n\}$ 个 $a_i$ 的乘积。

我们设 $b_i$ 为 $i$ 个 $a$ 的乘积组成的项的和。

例如,$n=3$ 时,

$$b_3=a_1a_2a_3$$ $$b_2=a_1a_2+a_1a_3+a_2a_3$$ $$b_1=a_1+a_2+a_3$$

又因为全排列总共有 $n!$ 个,而对于每一个全排列,其中的每一项,都恰好会出现 $1$ 次(比如上面例子中,对于全排列 $a_1,a_3,a_2$,出现的项就是 $a_1a_2a_3,a_2a_3,a_2$)。

所以对于 $b_i$,它总共出现了 $\frac{n!}{C_n^i}$ 次。

所以分子的值就是:

$$\sum\limits_{i=1}^n \frac{n!}{C_n^i}b_i$$

现在我们只要求出所有的 $b_i$ 即可。


我们注意到 $b_i$ 是 乘起来,然后加起来 的形式,于是我们想到 生成函数

定义

$$f(x)=(1+a_1x)(1+a_2x)…(1+a_nx)$$

我们会发现 $x^i$ 的系数就是所求的 $b_i$。


最后注意到这是 $n$ 个 deg = 1 的多项式相乘。

然而,两个 deg 分别为 $n,m$ 的多项式相乘,令 $d=\max(n,m)$,则复杂度是 $O(d\log d)$ 的。

所以我们不能直接把这些多项式乘一块,我们应该分治着相乘,即 solve(L, R) = solve(L, mid) * solve(mid+1, R)

代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
const int maxn = 1e5+5;

Z fac[maxn], inv_fac[maxn];
Z C(int n, int m) {
    return fac[n] * inv_fac[n-m] * inv_fac[m];
}
Z invC(int n, int m) {
    return inv_fac[n] * fac[n-m] * fac[m];
}
 
int n;
Poly solve(int l, int r) {
    if (l == r) {
        int x; cin >> x;
        Poly f(2);
        f[0] = 1;
        f[1] = x;
        return f;
    }
    int mid = (l+r) >> 1;
    Poly f = solve(l, mid);
    return f * solve(mid+1, r);
}
 
int main() {
    int T; cin >> T;
    fac[0] = inv_fac[0] = 1;
    for (int i = 1; i <= 1e5; i++) fac[i] = fac[i-1] * i;
    inv_fac[100000] = fac[100000].inv();
    for (int i = 99999; i >= 1; i--) inv_fac[i] = inv_fac[i+1] * (i+1);
    while (T--) {
        cin >> n;
        Poly f = solve(1, n);
        Z res = 0;
        for (int i = 1; i <= n; i++) {
            res = res + (Z(f[i].v) * invC(n, i));
        }
        cout << res.val() << "\n";
    }
}