[EZOI][1217NOI模拟赛]math(生成函数+分治FFT+高精度)

§ 1 题意

对于集合 $S$,定义 $P(S)=\prod\limits_{x\in S}x$,即 $S$ 中元素之积。特别地,定义 $P(\emptyset)=1$。

记 $[n]=\lbrace 1,2,3,\cdots,n \rbrace$,对于 $n \in \mathbb{Z},\ 0\leq k\leq n$,定义 $F(n,k)=\sum\limits_{S\subseteq[n],\ |S|=k}P(S)$。

给定 $n​$ 以及质数 $p​$,求存在多少 $[0,n]​$ 上的整数 $k​$ 满足 $p\nmid F(n,k)​$,答案对 $998244353​$ 取模。

$1\leq n\leq 10^{1000},\ \ 2\leq p\leq 10^5$。


§ 2 分析

由题意可得, $F(n,k)$ 表示从 $1\sim n$ 中选出 $k$ 个相乘的值。考虑其生成函数,显然有

题目要求满足 $p\nmid F(n,k)$ 的整数 $k$ 的个数,就等价于求生成函数有多少项系数不为 $0$。

为简化计算,我们考虑将上式的高次项和低次项交换,得到

该多项式系数不为 $0$ 的项数显然与原多项式相同。

令 $n=a·p+b$,则有

考虑 $\prod\limits_{i=1}^p(x+i)$,可以发现该式在 ${\rm mod}\ p$ 意义下有且仅有 $p$ 个零点:$0,1,\cdots,p-1$。

由费马小定理可得 $x^{p-1}\equiv 1\quad({\rm mod}\ p)\ (0\lt x\lt p)$。

那么我们考虑 $x·(x^{p-1}-1)$,可以发现该式在 ${\rm mod\ p}$ 意义下有且仅有 $p$ 个零点:$0,1,\cdots,p-1$。

又有 $\prod\limits_{i=1}^p(x+i)$ 与 $x·(x^{p-1}-1)$ 都是 $p$ 次多项式,最高次均为 $1$,故能被 $p$ 个不同点确定。

所以 $\prod\limits_{i=1}^p(x+i)\equiv x·(x^{p-1}-1)\quad({\rm mod}\ p)$,于是有


首先考虑 $x^a$,对该项系数是否为 $0$ 无影响。


然后考虑 $(x^{p-1}-1)^a$,用二项式定理展开得

发现只有 $\binom{a}{i}\not\equiv 0\ ({\rm mod}\ p)$ 时才有可能系数不为 $0$。

由 ${\rm Lucas}$ 定理可得 $\binom{a}{i}\equiv\binom{a\,\%\,p}{i\,\%\,p}·\binom{a\,/\,p}{i\,/\,p}$,即以 $p$ 进制分解组合数。

所以 $\binom{a}{i}\not\equiv 0\ ({\rm mod}\ p)$ 等价于 $p$ 进制下 $a$ 的每一位都不小于 $i$ 的同一位。

若 $a=\overline{a_ka_{k-1}\cdots a_0}_{(p)}$,则贡献为 $\prod\limits_{i=0}^k(a_i+1)$。


最后考虑 $\prod\limits_{i=1}^b(x+i)$。

若 $b\lt p-1$,由于 $(x^{p-1}-1)^a$ 中非 $0$ 项的次数一定为 $p-1$ 的倍数,而 $\prod\limits_{i=1}^b(x+i)$ 中非 $0$ 项的次数小于 $p-1$,一定不会出现重叠项。

那么原多项式中非 $0$ 项个数即为 $(x^{p-1}-1)^a$ 与 $\prod\limits_{i=1}^b(x+i)$ 非 $0$ 项个数之积。

后者直接分治 ${\rm FFT}$ 求解即可。

若 $b =p-1$,则 $\prod\limits_{i=1}^b(x+i)\equiv x^{p-1}-1\ ({\rm mod}\ p)$,证明同上。

计算 $(x^{p-1}-1)^a$ 时将 $a$ 加一即可。


使用 long double 版 ${\rm FFT}$ 即可通过,T老师的高精除法转 $p$ 进制代码精妙。

总时间复杂度为 $O(p\log^2p)$。


§ 3 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const long double PI = acos(-1.0L);
const int MOD = 998244353;

struct cpl{
long double r, i;

inline cpl operator + (const cpl &rhs) const {return (cpl){r + rhs.r, i + rhs.i};}
inline cpl operator - (const cpl &rhs) const {return (cpl){r - rhs.r, i - rhs.i};}
inline cpl operator * (const cpl &rhs) const {return (cpl){r * rhs.r - i * rhs.i, r * rhs.i + i * rhs.r};}
} W[262150];

char n[1005];
int p, len, a[3333], cnt = 0, ans = 1, rev[262150];

inline void carryBit(){
for(register int i = 1; i <= cnt + 5; i++) a[i + 1] += a[i] / p, a[i] %= p;
for(cnt += 5; !a[cnt]; cnt--);
}

inline void FFT(vector<cpl> &a, int lim, int f){
a.resize(lim);
for(register int i = 1; i < lim; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(register int i = 1; i < lim; i <<= 1){
W[0] = (cpl){1, 0}, W[1] = (cpl){cos(PI / i), f * sin(PI / i)};
for(register int j = 2; j < i; j++) W[j] = W[j - 1] * W[1];
for(register int j = 0; j < lim; j += i << 1)
for(register int k = j; k < j + i; k++){
const cpl t = W[k - j] * a[k + i];
a[k + i] = a[k] - t, a[k] = a[k] + t;
}
}
if(f == -1) for(register int i = 0; i < lim; i++) a[i].r /= lim, a[i].i = 0;
}

inline vector<cpl> mul(vector<cpl> f, vector<cpl> g){
register int siz = f.size() + g.size() - 1, lim = 1, s = 0;
while(lim < siz) lim <<= 1, s++;
for(register int i = 1; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (s - 1);
FFT(f, lim, 1), FFT(g, lim, 1);
for(register int i = 0; i < lim; i++) f[i] = f[i] * g[i];
FFT(f, lim, -1), f.resize(siz);
for(auto &i: f) i.r = (long long)(i.r + 0.5) % p;
return f;
}

vector<cpl> solve(int l, int r){
if(l == r){
static vector<cpl> p(2);
return p[0] = (cpl){(long double)l, 0}, p[1] = (cpl){1, 0}, p;
}
const int mid = l + r >> 1;
return mul(solve(l, mid), solve(mid + 1, r));
}

int main(){
scanf("%s%d", n, &p), len = strlen(n);
for(register int i = 0; i < len; i++){
for(register int j = 1; j <= cnt; j++) a[j] *= 10;
a[1] += n[i] & 15, carryBit();
} // High-precise division (n / p)
if(a[1] == p - 1) a[2]++, a[1] = 0, carryBit();
for(register int i = 2; i <= cnt; i++) ans = ans * (a[i] + 1LL) % MOD;
if(a[1]){
auto res = solve(1, a[1]); int cnt = 0;
for(auto i: res) cnt += ((int)(i.r + 0.5) != 0);
ans = ans * (long long)cnt % MOD;
}
return printf("%d\n", ans), 0;
}