浅谈多项式

§ 1 多项式基本操作

  • 多项式乘法
  • 多项式求逆
  • 多项式除法/取模
  • 多项式开根
  • 多项式 $\ln$
  • 多项式 $\exp$
  • 多项式 $k$ 次幂

为简化运算,所有操作均在 ${\rm mod}\ 998244353$ 意义下进行。实现时注意清空高次系数。

代码中出现的常量、变量及函数定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
typedef long long ll;
typedef struct {int r, i;} Pair;
const int G = 3, MOD = 998244353;
int lim, invlim, s, Wn[1 << 18], rev[1 << 18];

inline int add(const int &x, const int &y) {return x + y < MOD ? x + y : x + y - MOD;}
inline int sub(const int &x, const int &y) {return x >= y ? x - y : x - y + MOD;}
inline int mul(const int &x, const int &y) {return x * (ll)y % MOD;}
inline int getrand(const int &mx) {return (((ll)rand() << 15) ^ rand()) % mx + 1;}

inline int fastpow(int bas, int ex = MOD - 2){
register int res = 1; bas %= MOD;
for(; ex; ex >>= 1, bas = mul(bas, bas)) if(ex & 1) res = mul(res, bas);
return res;
}

§ 1.1 多项式乘法

直接 $\rm NTT$ 求卷积。

时间复杂度 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
inline void init(const int &n){
lim = 1, s = 0; while(lim < n) lim <<= 1, s++; invlim = fastpow(lim);
for(register int i = 1; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (s - 1);
}

inline void NTT(vector<int> &a, int f){
a.resize(lim);
for(register int i = 1; i < lim; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(register int i = 1; i < lim; i <<= 1){
Wn[0] = 1, Wn[1] = fastpow(G, (MOD - 1) / (i << 1));
for(register int j = 2; j < i; j++) Wn[j] = mul(Wn[j - 1], Wn[1]);
for(register int j = 0; j < lim; j += i << 1)
for(register int k = j; k < j + i; k++){
const int t = mul(Wn[k - j], a[k + i]);
a[k + i] = sub(a[k], t), a[k] = add(a[k], t);
}
}
if(f == -1){
reverse(a.begin() + 1, a.end());
for(auto &i: a) i = mul(i, invlim);
}
}

inline vector<int> polymul(const vector<int> &a, const vector<int> &b){
vector<int> A(a), B(b); init(a.size() + b.size() - 1);
NTT(A, 1), NTT(B, 1);
for(register int i = 0; i < lim; i++) A[i] = mul(A[i], B[i]);
NTT(A, -1);
return A.resize(a.size() + b.size() - 1), A;
}

§ 1.2 多项式求逆

给定多项式 $A(x)$,求 $A^{-1}(x)$ 满足

其中 ${\rm mod}\ x^n$ 表示保留 $0\sim n-1$ 次项。

当 $n=1$ 时,直接用费马小定理得 $a^{-1}\equiv a^{p-2}\pmod{p}$。

当 $n\gt 1$ 时,假设当前已求出 ${\rm mod}\ x^{\lceil\frac{n}{2}\rceil}$ 意义下的 $A(x)$ 的逆元 $B_0(x)$,即

需要求出 $B(x)$ 满足

两式相减得

考虑到 $A(x)\not\equiv 0\pmod{x^{\lceil\frac{n}{2}\rceil}}$,故有

两边平方,考虑卷积定义,可知同余式右边仍为 $0$,即

两边同乘 $A(x)$ 得

化简得

递归求解即可,时间复杂度为 $T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n)$。

1
2
3
4
5
6
7
8
9
inline vector<int> polyinv(const vector<int> &a, int n = -1){
if(n == -1) n = a.size(); vector<int> inv, A(a.begin(), a.begin() + n);
if(n == 1) return inv.push_back(fastpow(a[0])), inv;
inv = polyinv(a, n + 1 >> 1), init((n << 1) - 1);
NTT(inv, 1), NTT(A, 1);
for(register int i = 0; i < lim; i++) inv[i] = mul(inv[i], sub(2, mul(inv[i], A[i])));
NTT(inv, -1);
return inv.resize(n), inv;
}

§ 1.3 多项式除法/取模

给定 $n-1$ 次多项式 $A(x)$ 和 $m-1$ 次多项式 $B(x)$,求多项式 $D(x),\ R(x)$ 满足

其中 $D(x)$ 为 $n-m$ 次,$R(x)$ 至多 $m-2$ 次。

由于该式的余数 $R(x)$ 难以处理,考虑去除其影响。

定义 $A(x)$ 的系数反转多项式 $A^R(x)$ 为

则有

所以可在 ${\rm mod}\ x^{n-m+1}$ 意义下求逆得到 $D^R(x)$,反转即为 $D(x)$,然后代入原式计算 $R(x)$。

共需要一次求逆和两次乘法,时间复杂度为 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
inline void polydiv(const vector<int> &a, const vector<int> &b, vector<int> &d, vector<int> &r){
if(b.size() > a.size()) return r = a, d.clear();
vector<int> A(a), B(b), invB; int n = a.size(), m = b.size();
reverse(A.begin(), A.end()), reverse(B.begin(), B.end());
B.resize(n - m + 1), invB = polyinv(B, n - m + 1);
d = polymul(A, invB), d.resize(n - m + 1), reverse(d.begin(), d.end());
r = polymul(b, d);
for(register int i = 0; i < m - 1; i++) r[i] = sub(a[i], r[i]);
r.resize(m - 1);
}

§ 1.π 多项式牛顿迭代法

该方法可以用来推很多公式,如多项式求逆、开根以及 $\exp$ 等。

给出一个关于多项式 $f(x)$ 的方程 $g(f(x))=0$,假设已求出 $f(x)$ 的前 $n$ 项 $f_0(x)$,即

将函数 $g(f(x))$ 在 $f_0(x)$ 上泰勒展开得

注意到 $f(x)-f_0(x)$ 第 $0\sim n-1$ 项为 $0$,故 $(f(x)-f_0(x))^2$ 第 $0\sim 2n-1$ 项为 $0$,于是有

化简得

运用一次该公式即可将 $f(x)$ 的已知项数翻倍。


§ 1.4 多项式开根

给定多项式 $A(x)$,求 $B(x)=\sqrt{A(x)}$ 满足

令 $B(x)\equiv B_0(x)\pmod{x^n}$,直接运用牛顿迭代法得

若 $A(x)$ 的常数项不够优秀,递归到边界时还需用 ${\rm Cipolla}$ 算法计算二次剩余,见 T 老师的博文

时间复杂度为 $T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
inline Pair pmul(const Pair &a, const Pair &b, int t){
return (Pair){ add(mul(a.r, b.r), mul(mul(a.i, b.i), t)), add(mul(a.r, b.i), mul(a.i, b.r)) };
}

inline int quadres(const int &a){ // 计算二次剩余 (Quadratic residue)
if(a == 1) return 1;
if(fastpow(a, MOD - 1 >> 1) != 1) return -1; int x, t;
do x = getrand(a - 1); while(fastpow(t = sub(mul(x, x), a), MOD - 1 >> 1) == 1);
Pair res = (Pair){1, 0}, bas = (Pair){x, 1};
for(register int ex = MOD + 1 >> 1; ex; ex >>= 1, bas = pmul(bas, bas, t))
if(ex & 1) res = pmul(res, bas, t);
return min(res.r, MOD - res.r);
}

inline vector<int> polysqrt(const vector<int> &a, int n = -1){
if(n == -1) n = a.size(); vector<int> s, A(a.begin(), a.begin() + n);
if(n == 1) return s.push_back(quadres(a[0])), s;
s = polysqrt(a, n + 1 >> 1), s.resize(n), A = polymul(A, polyinv(s));
for(register int i = 0; i < n; i++) s[i] = add(s[i], A[i]);
for(auto &i: s) i = (i & 1 ? i + MOD >> 1 : i >> 1);
return s;
}

§ 1.5 多项式 $\ln$

给定多项式 $A(x)$,求

考虑直接计算

其中求导和积分都可以在 $O(n)$ 的时间内完成。

代码中省略了预处理逆元的步骤,故积分的复杂度为 $O(n\log n)$。

注意需要保证 $A(x)$ 常数项为 $1$,且默认 $\ln A(x)$ 的常数项为 $0$。

共需要一次求逆,时间复杂度为 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
inline vector<int> polyderiv(const vector<int> &a){
vector<int> deriv(a.size() - 1);
for(register int i = 1; i < a.size(); i++) deriv[i - 1] = mul(a[i], i);
return deriv;
}

inline vector<int> polyintegr(const vector<int> &a){
vector<int> integr(a.size() + 1);
for(register int i = 0; i < a.size(); i++) integr[i + 1] = mul(a[i], fastpow(i + 1));
return integr;
}

inline vector<int> polyln(const vector<int> &a){
vector<int> l = polymul(polyderiv(a), polyinv(a));
return l.resize(a.size() - 1), polyintegr(l);
}

§ 1.6 多项式 $\exp$

给定多项式 $A(x)$,求

令 $B(x)=e^{A(x)}\pmod{x^n}$,两边取对数得

然后令 $B(x)\equiv B_0(x)\pmod{x^n}$,运用牛顿迭代法得

注意需要保证 $A(x)$ 常数项为 $0$,且默认 $e^{A(x)}$ 的常数项为 $1$。

时间复杂度为 $T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n)$。

1
2
3
4
5
6
7
8
inline vector<int> polyexp(const vector<int> &a, int n = -1){
if(n == -1) n = a.size(); vector<int> e, A;
if(n == 1) return e.push_back(1), e;
e = polyexp(a, n + 1 >> 1), e.resize(n), A = polyln(e);
for(register int i = 0; i < n; i++) A[i] = sub(a[i], A[i]);
A[0] = add(A[0], 1), e = polymul(e, A);
return e.resize(n), e;
}

§ 1.7 多项式 $k$ 次幂

给定多项式 $A(x)$ 和正整数 $k$,求

直接快速幂,时间复杂度为 $O(n\log n\log k)$。

当 $f(x)$ 的常数项为 $1$ 时,有

时间复杂度为 $O(n\log n)$。

当 $f(x)$ 的常数项不为 $1$ 时,设 $f(x)$ 的最低次项为 $ax^d$,则将其提出

可以平移幂次处理后用上面的公式计算,时间复杂度为 $O(n\log n)$。

代码只给出 $f(x)$ 常数项为 $1$ 的情况。

1
2
3
4
5
inline vector<int> polypow(const vector<int> &a, const int &ex){
vector<int> p = polyln(a);
for(int &i: p) i = mul(i, ex);
return polyexp(p);
}

§ 2 模板题

=> LibreOJ #150 挑战多项式

§ 2.1 题意

给定一个 $n+1$ 次多项式 $F(x)$ 和一个正整数 $k$,求多项式 $G(x)$ 满足

保证 $F(x)$ 的常数项是 ${\rm mod}\ 998244353$ 的二次剩余。

注意到开根时 $\pm\sqrt{F(x)}$ 均为合法解,只需取常数项较小者计算即可。

所有运算在 ${\rm mod}\ 998244353$ 下进行。

$1\leq n\leq 10^5,\ \ 0\leq k\lt 998244353$。

§ 2.2 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define MAXN (1 << 18)
using namespace std;
typedef long long ll;

template <typename T> inline void getint(T &num){
register int ch, neg = 0;
while(!isdigit(ch = getchar())) if(ch == '-') neg = 1;
num = ch & 15;
while(isdigit(ch = getchar())) num = num * 10 + (ch & 15);
if(neg) num = -num;
}

namespace Poly{
typedef struct {int r, i;} Pair;
const int G = 3, MOD = 998244353;
int lim, invlim, s, Wn[MAXN], rev[MAXN];

inline int add(const int &x, const int &y) {return x + y < MOD ? x + y : x + y - MOD;}
inline int sub(const int &x, const int &y) {return x >= y ? x - y : x - y + MOD;}
inline int mul(const int &x, const int &y) {return x * (ll)y % MOD;}
inline int getrand(const int &mx) {return (((ll)rand() << 15) ^ rand()) % mx + 1;}

inline int fastpow(int bas, int ex = MOD - 2){
register int res = 1; bas %= MOD;
for(; ex; ex >>= 1, bas = mul(bas, bas)) if(ex & 1) res = mul(res, bas);
return res;
}

inline Pair pmul(const Pair &a, const Pair &b, int t){
return (Pair){ add(mul(a.r, b.r), mul(mul(a.i, b.i), t)), add(mul(a.r, b.i), mul(a.i, b.r)) };
}

inline int quadres(const int &a){
if(a == 1) return 1;
if(fastpow(a, MOD - 1 >> 1) != 1) return -1; int x, t;
do x = getrand(a - 1); while(fastpow(t = sub(mul(x, x), a), MOD - 1 >> 1) == 1);
Pair res = (Pair){1, 0}, bas = (Pair){x, 1};
for(register int ex = MOD + 1 >> 1; ex; ex >>= 1, bas = pmul(bas, bas, t))
if(ex & 1) res = pmul(res, bas, t);
return min(res.r, MOD - res.r);
}

inline void init(const int &n){
lim = 1, s = 0; while(lim < n) lim <<= 1, s++; invlim = fastpow(lim);
for(register int i = 1; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << (s - 1);
}

inline void NTT(vector<int> &a, int f){
a.resize(lim);
for(register int i = 1; i < lim; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(register int i = 1; i < lim; i <<= 1){
Wn[0] = 1, Wn[1] = fastpow(G, (MOD - 1) / (i << 1));
for(register int j = 2; j < i; j++) Wn[j] = mul(Wn[j - 1], Wn[1]);
for(register int j = 0; j < lim; j += i << 1)
for(register int k = j; k < j + i; k++){
const int t = mul(Wn[k - j], a[k + i]);
a[k + i] = sub(a[k], t), a[k] = add(a[k], t);
}
}
if(f == -1){
reverse(a.begin() + 1, a.end());
for(auto &i: a) i = mul(i, invlim);
}
}

inline vector<int> polymul(const vector<int> &a, const vector<int> &b){
vector<int> A(a), B(b); init(a.size() + b.size() - 1);
NTT(A, 1), NTT(B, 1);
for(register int i = 0; i < lim; i++) A[i] = mul(A[i], B[i]);
NTT(A, -1);
return A.resize(a.size() + b.size() - 1), A;
}

inline vector<int> polyinv(const vector<int> &a, int n = -1){
if(n == -1) n = a.size(); vector<int> inv, A(a.begin(), a.begin() + n);
if(n == 1) return inv.push_back(fastpow(a[0])), inv;
inv = polyinv(a, n + 1 >> 1), init((n << 1) - 1);
NTT(inv, 1), NTT(A, 1);
for(register int i = 0; i < lim; i++) inv[i] = mul(inv[i], sub(2, mul(inv[i], A[i])));
NTT(inv, -1);
return inv.resize(n), inv;
}

inline void polydiv(const vector<int> &a, const vector<int> &b, vector<int> &d, vector<int> &r){
if(b.size() > a.size()) return r = a, d.clear();
vector<int> A(a), B(b), invB; int n = a.size(), m = b.size();
reverse(A.begin(), A.end()), reverse(B.begin(), B.end());
B.resize(n - m + 1), invB = polyinv(B, n - m + 1);
d = polymul(A, invB), d.resize(n - m + 1), reverse(d.begin(), d.end());
r = polymul(b, d);
for(register int i = 0; i < m - 1; i++) r[i] = sub(a[i], r[i]);
r.resize(m - 1);
}

inline vector<int> polyderiv(const vector<int> &a){
vector<int> deriv(a.size() - 1);
for(register int i = 1; i < a.size(); i++) deriv[i - 1] = mul(a[i], i);
return deriv;
}

inline vector<int> polyintegr(const vector<int> &a){
vector<int> integr(a.size() + 1);
for(register int i = 0; i < a.size(); i++) integr[i + 1] = mul(a[i], fastpow(i + 1));
return integr;
}

inline vector<int> polyln(const vector<int> &a){
vector<int> l = polymul(polyderiv(a), polyinv(a));
return l.resize(a.size() - 1), polyintegr(l);
}

inline vector<int> polyexp(const vector<int> &a, int n = -1){
if(n == -1) n = a.size(); vector<int> e, A;
if(n == 1) return e.push_back(1), e;
e = polyexp(a, n + 1 >> 1), e.resize(n), A = polyln(e);
for(register int i = 0; i < n; i++) A[i] = sub(a[i], A[i]);
A[0] = add(A[0], 1), e = polymul(e, A);
return e.resize(n), e;
}

inline vector<int> polysqrt(const vector<int> &a, int n = -1){
if(n == -1) n = a.size(); vector<int> s, A(a.begin(), a.begin() + n);
if(n == 1) return s.push_back(quadres(a[0])), s;
s = polysqrt(a, n + 1 >> 1), s.resize(n), A = polymul(A, polyinv(s));
for(register int i = 0; i < n; i++) s[i] = add(s[i], A[i]);
for(auto &i: s) i = (i & 1 ? i + MOD >> 1 : i >> 1);
return s;
}

inline vector<int> polypow(const vector<int> &a, const int &ex){
vector<int> p = polyln(a);
for(int &i: p) i = mul(i, ex);
return polyexp(p);
}
}

int n, k;
vector<int> F, G;

inline vector<int> solve(const vector<int> &F, const int &k){
using namespace Poly;
vector<int> G = polyexp(polyintegr(polyinv(polysqrt(F))));
for(register int i = 1; i < F.size(); i++) G[i] = sub(F[i], G[i]);
return G = polyln(G), G[0] = add(G[0], 1), polyderiv(polypow(G, k));
}

int main(){
srand(19260817), getint(n), getint(k), F.resize(n + 1);
for(register int i = 0; i <= n; i++) getint(F[i]);
G = solve(F, k), G.resize(n);
while(*G.rbegin() == 0) G.pop_back();
for(auto i: G) printf("%d ", i);
return puts(""), 0;
}