浅谈FWT

§ 1 前言

对于离散卷积式

我们可以直接使用 $\rm FFT$ 求解。

考虑该问题的扩展,如何求解将 $i+j=n$ 限制中的 $+$ 换成其他运算符的卷积式。

对于 $c_n=\sum\limits_{i-j=n}a_ib_j$,可以将多项式 $B$ 系数反转后进行 $\rm FFT$。

对于 $c_n=\sum\limits_{i\times j=n} a_ib_j$,在模 $p$ 意义下,且 $p$ 有原根 $G$ 时可以将 $i$ 位置的值放到 $\log_G i\pmod{p}$ 位置进行 $\rm NTT$,然后再放回来。

而对于位运算卷积,即

可以使用快速沃尔什变换 $(\rm Fast\ Walsh\ Transform,\ FWT)$ 求解。


§ 2 快速沃尔什变换

首先声明下文中用到的表示法。

对于一个 $n-1$ 次多项式

将其表示为

定义多项式加法 $(\rm operator\ +)$ 为

定义多项式减法 $(\rm operator\ -)$ 为

定义多项式对应系数乘法 $(\rm operator\ \times)$ 为

对于一个位运算符 $\boxtimes\in\lbrace\ |,\ \&,\oplus\rbrace$,定义位运算卷积

易证得,$\boxtimes$ 运算具有分配律。

定义多项式拼接运算 $(A,B)$ 为

对于一个 $2^k-1$ 次的多项式 $A$,令 $A_0$ 为其前 $2^{k-1}$ 项,$A_1$ 为其后 $2^{k-1}$ 项。

代码中出现的常量、变量及函数定义如下:

1
2
3
4
5
typedef long long ll;
const int inv2 = 499122177, MOD = 998244353;

inline int add(const int &x, const int &y) {return x + y >= MOD ? x + y - MOD : x + y;}
inline int mul(const int &x, const int &y) {return x * (ll)y % MOD;}

§ 2.1 按位或卷积

§ 2.1.1 构造

按位或卷积形式如下

定义 ${\rm FWT}(A)$ 为

容易发现 ${\rm FWT}(A\ |\ B)={\rm FWT}(A)\times{\rm FWT}(B)$,证明如下

所以我们只需寻找 $A\Leftrightarrow{\rm FWT}(A)$ 之间快速变换的方法即可计算出卷积。

§ 2.1.2 快速变换

容易发现,对于一个 $2^n-1$ 次多项式 $A$,有如下结论

  • 对于 ${\rm FWT}(A)_0$,只有 ${\rm FWT}(A_0)$ 的对应项有贡献。
  • 对于 ${\rm FWT}(A)_1$,${\rm FWT}(A_0)$ 及 ${\rm FWT}(A_1)$ 的对应项均有贡献。

所以有

1
2
3
4
5
6
7
8
9
10
11
12
13
inline void FWT_or(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k + i] = add(A[j + k + i] , A[j + k]);
}

inline void IFWT_or(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k + i] = add(A[j + k + i] , MOD - A[j + k]);
}

§ 2.2 按位与卷积

按位与卷积形式如下

处理方式和按位或卷积基本相同,下面直接给出结论

1
2
3
4
5
6
7
8
9
10
11
12
13
inline void FWT_and(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k] = add(A[j + k] , A[j + k + i]);
}

inline void IFWT_and(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k] = add(A[j + k] , MOD - A[j + k + i]);
}

§ 2.3 按位异或卷积

按位异或卷积形式如下

处理方式和前两个位运算卷积基本相同,证明见 zyh 巨爷的博文,下面直接给出结论

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
inline void FWT_xor(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++){
const int L = A[j + k], R = A[j + k + i];
A[j + k] = add(L, R), A[j + k + i] = add(L, MOD - R);
}
}

inline void IFWT_xor(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++){
const int L = A[j + k], R = A[j + k + i];
A[j + k] = add(L, R), A[j + k + i] = add(L, MOD - R);
A[j + k] = mul(A[j + k], inv2), A[j + k + i] = mul(A[j + k + i], inv2);
}
}

§ 3 模板题

=> luogu 4717 【模板】快速沃尔什变换

§ 3.1 题意

给定长度为 $2^n$ 的两个多项式 $A,\ B$,设 $C_i=\sum\limits_{j\boxtimes k=i}A_jB_k$。

求出当 $\boxtimes$ 分别为 ${\rm or},\ {\rm and},\ {\rm xor}$ 时的 $C$。

$0\leq n\leq 17$。

§ 3.2 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int inv2 = 499122177, MOD = 998244353;

inline void getint(int &num){
register int ch, neg = 0;
while(!isdigit(ch = getchar())) if(ch == '-') neg = 1;
num = ch & 15;
while(isdigit(ch = getchar())) num = num * 10 + (ch & 15);
if(neg) num = -num;
}

inline int add(const int &x, const int &y) {return x + y >= MOD ? x + y - MOD : x + y;}
inline int mul(const int &x, const int &y) {return x * (ll)y % MOD;}

int n, lim, A[131080], B[131080], C[131080];

inline void FWT_or(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k + i] = add(A[j + k + i] , A[j + k]);
}

inline void IFWT_or(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k + i] = add(A[j + k + i] , MOD - A[j + k]);
}

inline void FWT_and(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k] = add(A[j + k] , A[j + k + i]);
}

inline void IFWT_and(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++)
A[j + k] = add(A[j + k] , MOD - A[j + k + i]);
}

inline void FWT_xor(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++){
const int L = A[j + k], R = A[j + k + i];
A[j + k] = add(L, R), A[j + k + i] = add(L, MOD - R);
}
}

inline void IFWT_xor(int *A){
for(register int i = 1; i < lim; i <<= 1)
for(register int j = 0; j < lim; j += i << 1)
for(register int k = 0; k < i; k++){
const int L = A[j + k], R = A[j + k + i];
A[j + k] = add(L, R), A[j + k + i] = add(L, MOD - R);
A[j + k] = mul(A[j + k], inv2), A[j + k + i] = mul(A[j + k + i], inv2);
}
}

int main(){
getint(n), lim = 1 << n;
for(register int i = 0; i < lim; i++) getint(A[i]);
for(register int i = 0; i < lim; i++) getint(B[i]);

FWT_or(A), FWT_or(B);
for(register int i = 0; i < lim; i++) C[i] = mul(A[i], B[i]);
IFWT_or(A), IFWT_or(B), IFWT_or(C);
for(register int i = 0; i < lim; i++) printf("%d ", C[i]); puts("");

FWT_and(A), FWT_and(B);
for(register int i = 0; i < lim; i++) C[i] = mul(A[i], B[i]);
IFWT_and(A), IFWT_and(B), IFWT_and(C);
for(register int i = 0; i < lim; i++) printf("%d ", C[i]); puts("");

FWT_xor(A), FWT_xor(B);
for(register int i = 0; i < lim; i++) C[i] = mul(A[i], B[i]);
IFWT_xor(C);
for(register int i = 0; i < lim; i++) printf("%d ", C[i]); puts("");
return 0;
}