浅谈动态DP

§ 1 前言

本文大量参考 txc 巨爷的《基于变换合并的树上动态DP的链分治算法和全局平衡二叉树学习笔记》一文。

在某些问题中,我们需要实现对某种 DP 的权值修改,以及快速询问全局或子结构的 DP 值。

如果我们能找到一种满足结合律的运算来描述转移过程的话,就可以用数据结构维护合并,降低复杂度。


§ 2 例题一

=> luogu 4719 【模板】动态dp

§ 2.1 题意

给定一棵有 $i$ 个点的树,第 $i$ 个点的点权为 $a_i$。

有 $m$ 次操作,每次操作给定 $x,y$,表示将点 $x$ 的权值修改为 $y$。

求每次操作后这棵树的最大权独立集的权值,其中独立集指任意两个顶点不相邻的点集。

$n,m \leq 10^5$。

§ 2.2 分析

本题的模型为树上带修最大权独立集,是动态 DP 的模板题。

§ 2.2.1 弱化版

首先考虑如果没有修改操作,我们可以直接进行 $\Theta(n)$ 的树形 DP。

令 $f_{u,0/1}$ 表示以 $u$ 为根的子树中,不选 / 选 $u$ 时的最大权独立集的权值,则有

§ 2.2.2 重写转移

考虑树链剖分,求出 $u$ 的重儿子 $hson_u​$。单独取出重儿子贡献的一项,则 DP 转化为

则有

§ 2.2.3 矩阵优化

我们重新定义矩阵乘法 $C=A*B$ 来描述这个转移,令

$\begin{align} c_{i,j}=\max \limits_k(a_{i,k}+b_{k,j}) \end{align} $

则转移可写为

可以证明重定义的矩阵乘法是具有结合律的,且存在单位矩阵

求一个点的 DP 值时,只需将该点沿重链走到底的矩阵相乘即可。

由于重链在 dfs 序上连续,可以用线段树维护区间矩阵乘积。修改类似树链剖分,重复以下步骤:

  1. 更新当前点的矩阵;
  2. 跳到重链顶端,同时计算 DP 值,更新父节点的 $g$ 值,并跳到父节点。

这样单次修改复杂度为 $O(\log^2n)$,查询复杂度为 $O(\log n)$。

§ 2.3 代码

注:leafdfn[u] 表示点 $u​$ 所在重链底部节点的 dfs 序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define inf 0x3f3f3f3f
using namespace std;

inline void getint(int &num){
register int ch, neg = 0;
while(!isdigit(ch = getchar())) if(ch == '-') neg = 1;
num = ch & 15;
while(isdigit(ch = getchar())) num = num * 10 + (ch & 15);
if(neg) num = -num;
}

int n, m, a[100005], f[100005][2], g[100005][2];
int fa[100005], siz[100005], hson[100005];
int dfn[100005], dfstime = 0, id[100005], top[100005], leafdfn[100005];
struct Edge {int np; Edge *nxt;} E[200005], *V[100005];

inline void addedge(const int &u, const int &v){
static int tope = 0;
E[++tope].np = v, E[tope].nxt = V[u], V[u] = E + tope;
}

struct Matrix{
int v[2][2];

inline Matrix() {v[0][0] = v[1][1] = 0, v[0][1] = v[1][0] = -inf;}
inline Matrix(int g0, int g1) {v[0][0] = v[0][1] = g0, v[1][0] = g1, v[1][1] = -inf;}

inline Matrix operator * (const Matrix &mat) const{
static Matrix res;
res.v[0][0] = max(v[0][0] + mat.v[0][0], v[0][1] + mat.v[1][0]);
res.v[0][1] = max(v[0][0] + mat.v[0][1], v[0][1] + mat.v[1][1]);
res.v[1][0] = max(v[1][0] + mat.v[0][0], v[1][1] + mat.v[1][0]);
res.v[1][1] = max(v[1][0] + mat.v[0][1], v[1][1] + mat.v[1][1]);
return res;
}
} s[400005];

void dfs1(int u){
siz[u] = 1, hson[u] = 0, f[u][1] = a[u];
for(register Edge *ne = V[u]; ne; ne = ne->nxt) if(ne->np != fa[u]){
fa[ne->np] = u, dfs1(ne->np), siz[u] += siz[ne->np];
if(siz[ne->np] > siz[hson[u]]) hson[u] = ne->np;
f[u][0] += max(f[ne->np][0], f[ne->np][1]), f[u][1] += f[ne->np][0];
}
}

void dfs2(int u){
id[dfn[u] = ++dfstime] = u, g[u][1] = a[u];
if(hson[u]) top[hson[u]] = top[u], dfs2(hson[u]), leafdfn[u] = leafdfn[hson[u]];
else leafdfn[u] = dfstime;
for(register Edge *ne = V[u]; ne; ne = ne->nxt) if(ne->np != fa[u] && ne->np != hson[u]){
top[ne->np] = ne->np, dfs2(ne->np);
g[u][0] += max(f[ne->np][0], f[ne->np][1]), g[u][1] += f[ne->np][0];
}
}

#define lch (u << 1)
#define rch (u << 1 | 1)

void build(int u, int l, int r){
if(l == r) {s[u] = Matrix(g[id[l]][0], g[id[l]][1]); return;}
const int mid = l + r >> 1;
build(lch, l, mid), build(rch, mid + 1, r);
s[u] = s[lch] * s[rch];
}

void modify(int u, int l, int r, int pos){
if(l == r) {s[u] = Matrix(g[id[l]][0], g[id[l]][1]); return;}
const int mid = l + r >> 1;
if(pos <= mid) modify(lch, l, mid, pos); else modify(rch, mid + 1, r, pos);
s[u] = s[lch] * s[rch];
}

Matrix query(int u, int l, int r, int ql, int qr){
if(l == ql && r == qr) return s[u];
const int mid = l + r >> 1;
if(qr <= mid) return query(lch, l, mid, ql, qr);
if(ql > mid) return query(rch, mid + 1, r, ql, qr);
return query(lch, l, mid, ql, mid) * query(rch, mid + 1, r, mid + 1, qr);
}

int main(){
getint(n), getint(m);
for(register int i = 1; i <= n; i++) getint(a[i]);
for(register int i = 1; i < n; i++){
int u, v; getint(u), getint(v);
addedge(u, v), addedge(v, u);
}
dfs1(1), top[1] = 1, dfs2(1);
build(1, 1, n);
while(m--){
int x, y; getint(x), getint(y);
g[x][1] += y - a[x], a[x] = y;
while(x){
modify(1, 1, n, dfn[x]), x = top[x];
Matrix res = query(1, 1, n, dfn[x], leafdfn[x]);
g[fa[x]][0] -= max(f[x][0], f[x][1]), g[fa[x]][1] -= f[x][0];
f[x][0] = res.v[0][0], f[x][1] = res.v[1][0];
g[fa[x]][0] += max(f[x][0], f[x][1]), g[fa[x]][1] += f[x][0];
x = fa[x];
}
printf("%d\n", max(f[1][0], f[1][1]));
}
return 0;
}

§ 3 例题二

=> luogu 4751 动态dp【加强版】

§ 3.1 题意

同例题一。强制在线,每次操作给出 $x$,实际修改的点为 $x \oplus lastans$。

$n,m \leq 10^6$。

§ 3.2 分析

本题的数据范围要求了更优秀的复杂度。

考虑 LCT,复杂度 $O((n\log n+q\log n)$ 符合题目要求,但因常数过大而表现不理想。

由于本题不需要动态的 link,cut 以及换根操作,我们可以构造一种类似 LCT 的静态数据结构。

§ 3.2.1 全局平衡二叉树

类似 LCT,我们将树的每条重链用一棵辅助二叉树维护,辅助树之间用虚边连接。

每个节点维护所在重链的辅助树的子树矩阵积。

事实上前面的线段树也是一种类似的结构,但每棵都保证了局部的绝对平衡,导致单次复杂度为 $O(\log^2n)$。

所以我们需要找到一种合适的构造方法,做到所有辅助树全体的总深度平衡。

辅助树的构造方法:

  • 定义点 $u$ 的权重 $w_u=size_u-size_{hson_u}$,即所有轻儿子的 $size$ 和 $+1$。
  • 对于每条重链,以带权重心为辅助树根,左右递归构造即可。

容易证明,这样构造出的全局平衡二叉树的深度为 $O(\log n)$。

这样修改时只需在辅助树上单步向上跳并更新答案,跳虚边时更新父亲的 $g$ 值即可。

总时间复杂度为 $O(n\log n+q\log n)$。

§ 3.3 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define inf 0x3f3f3f3f
using namespace std;

namespace fastio{
const int BUFSIZE = (1 << 22) + 1;
char ibuf[BUFSIZE], *iS, *iT, obuf[BUFSIZE], *oS = obuf, *oT = obuf + BUFSIZE, ch;
int stk[20], tops = 0, neg;

inline char getch(){
if(iS != iT) return *iS++;
iT = (iS = ibuf) + fread(ibuf, 1, BUFSIZE, stdin);
return iS == iT ? EOF : *iS++;
}

inline void flush() {fwrite(obuf, 1, oS - obuf, stdout), oS = obuf;}

inline void putch(const char &ch) {*oS++ = ch; if(oS == oT) flush();}

#define isdgt(ch) ((ch) >= '0' && (ch) <= '9')

inline void getint(int &num){
neg = 0; while(!isdigit(ch = getch())) if(ch == '-') neg = 1;
num = ch & 15; while(isdigit(ch = getch())) num = num * 10 + (ch & 15);
if(neg) num = -num;
}

inline void putint(int num){
if(!num) return putch('0');
if(num < 0) putch('-'), num = -num;
while(num) stk[++tops] = num % 10 | 48, num /= 10;
while(tops) putch(stk[tops--]);
}
}

using fastio::getint;
using fastio::putint;
using fastio::putch;

int n, m, a[1000005], f[1000005][2], g[1000005][2], ch[1000005][2], root;
int fa[1000005], siz[1000005], hson[1000005], top[1000005];
int w[1000005], s[1000005], p[1000005];
struct Edge {int np; Edge *nxt;} E[2000005], *V[1000005];

inline void addedge(const int &u, const int &v){
static int tope = 0;
E[++tope].np = v, E[tope].nxt = V[u], V[u] = E + tope;
}

struct Matrix{
int v[2][2];

inline Matrix() {v[0][0] = v[1][1] = 0, v[0][1] = v[1][0] = -inf;}
inline Matrix(int g0, int g1) {v[0][0] = v[0][1] = g0, v[1][0] = g1, v[1][1] = -inf;}

inline Matrix operator * (const Matrix &mat) const {
static Matrix res;
res.v[0][0] = max(v[0][0] + mat.v[0][0], v[0][1] + mat.v[1][0]);
res.v[0][1] = max(v[0][0] + mat.v[0][1], v[0][1] + mat.v[1][1]);
res.v[1][0] = max(v[1][0] + mat.v[0][0], v[1][1] + mat.v[1][0]);
res.v[1][1] = max(v[1][0] + mat.v[0][1], v[1][1] + mat.v[1][1]);
return res;
}
} F[1000005], G[1000005];

inline void update(int u) {F[u] = F[ch[u][0]] * G[u] * F[ch[u][1]];}

int build(int l, int r){
if(l > r) return 0;
int smid = (s[l - 1] + s[r] + 1) >> 1, L = l, R = r;
while(L < R){
const int mid = L + R >> 1;
if(s[mid] >= smid) R = mid; else L = mid + 1;
}
const int u = p[L];
fa[ch[u][0] = build(l, L - 1)] = u, fa[ch[u][1] = build(L + 1, r)] = u;
return update(u), u;
}

void dfs1(int u){
siz[u] = 1, hson[u] = 0, f[u][1] = a[u];
for(register Edge *ne = V[u]; ne; ne = ne->nxt) if(ne->np != fa[u]){
fa[ne->np] = u, dfs1(ne->np), siz[u] += siz[ne->np];
if(siz[ne->np] > siz[hson[u]]) hson[u] = ne->np;
f[u][0] += max(f[ne->np][0], f[ne->np][1]), f[u][1] += f[ne->np][0];
}
w[u] = siz[u] - siz[hson[u]];
}

void dfs2(int u){
g[u][1] = a[u];
if(hson[u]) top[hson[u]] = top[u], dfs2(hson[u]);
for(register Edge *ne = V[u]; ne; ne = ne->nxt) if(ne->np != fa[u] && ne->np != hson[u]){
top[ne->np] = ne->np, dfs2(ne->np);
g[u][0] += max(f[ne->np][0], f[ne->np][1]), g[u][1] += f[ne->np][0];
}
G[u] = Matrix(g[u][0], g[u][1]);
if(top[u] == u){
int cnt = 0, fu = fa[u];
for(register int v = u; v; v = hson[v])
p[++cnt] = v, s[cnt] = s[cnt - 1] + w[v];
fa[root = build(1, cnt)] = fu;
}
}

int main(){
getint(n), getint(m);
for(register int i = 1; i <= n; i++) getint(a[i]);
for(register int i = 1; i < n; i++){
int u, v; getint(u), getint(v);
addedge(u, v), addedge(v, u);
}
dfs1(1), top[1] = 1, dfs2(1);
int lastans = 0;
while(m--){
int x, y; getint(x), getint(y), x ^= lastans;
g[x][1] += y - a[x], a[x] = y, G[x] = Matrix(g[x][0], g[x][1]);
while(x){
int z = fa[x];
if(ch[z][0] != x && ch[z][1] != x)
g[z][0] -= max(F[x].v[0][0], F[x].v[1][0]), g[z][1] -= F[x].v[0][0];
update(x);
if(ch[z][0] != x && ch[z][1] != x){
g[z][0] += max(F[x].v[0][0], F[x].v[1][0]), g[z][1] += F[x].v[0][0];
G[z] = Matrix(g[z][0], g[z][1]);
}
x = z;
}
putint(lastans = max(F[root].v[0][0], F[root].v[1][0])), putch('\n');
}
return fastio::flush(), 0;
}