树链剖分-重链剖分

重链剖分

是指一种对树进行划分的算法, 它先通过轻重边剖分(Heavy-Light Decomposition)将树分为多条链, 保证每个点属于且只属于其中一条链, 然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链.

定义

size[i] 以结点i为根的子树中结点的个数.
son[i] 结点i的重儿子.
dep[i] 结点i的深度, 根的深度为1.
top[i] 结点i所在重链的链首结点.
fa[i] 结点i的父结.
dfn[i] 在DFS找重链的过程中为结点i重新编的号码, 每条重链上的结点编号是连续的.
rnk[i] 链上节点对应的树上节点编号, 用于访问初值.

那么, 令 $V$ 是 $U$ 的儿子节点中size值最大的节点, 称 $V$ 是 $U$ 的重儿子重边, 除此之外, 所有的边称为轻边. 全部由重边构成的路径称为重链.

预处理

  1. 第一次DFS(预处理)
    从根节点出发, 递归遍历整棵树, 计算每个节点的sizedepfa, 并找出每个节点的重儿子son(即子树最大的儿子).

  2. 第二次DFS(分配编号)
    再次从根节点出发, 优先遍历重儿子, 将每个节点分配一个新的编号dfn, 并确定每个节点所在重链的链首top.

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
void dfs1(int u, int f) {
size[u] = 1;
for (int i = head[u]; i; i = nxt[i]) {
int v = to[i];
if (v == f) continue;
dep[v] = dep[u] + 1;
fa[v] = u;
dfs1(v, u);
size[u] += size[v];
if (size[v] > size[son[u]]) {
son[u] = v;
}
}
}
void dfs2(int u, int t) {
top[u] = t;
dfn[u] = ++tot;
rnk[tot] = u;
if (!son[u]) return;
dfs2(son[u], t);
for (int i = head[u]; i; i = nxt[i]) {
int v = to[i];
if (v != fa[u] && v != son[u]) {
dfs2(v, v);
}
}
}

// in int main()
dfs1(rt,0);
dfs2(rt,rt);

常见操作

单点处理

要处理 x 节点上的值, 直接访问 dfn[x] 即可.

路径上维护

由于链上的 dfn 都是连续的, 可以用线段树/树状数组直接维护这些重链.

但是大多数时间任意两个节点形成的路径并不是一条链. 我们的目标是让这两个节点到同一个重链上以便我们维护. 让 dep[top] 更大的节点往上跳到top, 同时修改这一段重链, 直到两个点到同一个重链上. 到同一重链后也要操作一下.

注意, 上述操作应跳至fa[top[x]],修改dfn[top[x]] ~ dfn[x], 具体原因请自行思考.

子树上维护

dfs2的编号过程可知, 一个子树的dfn值使连续的. 由此可知, 我们可以直接对一个子树进行更改. 更改范围: $dfn_x \sim dfn_x+size[x]-1$

例题 P2590 [ZJOI2008] 树的统计

按照题意维护即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <algorithm>
#include <iostream>
using namespace std;
const int N = 5e5 + 5;
int head[N], nxt[2 * N], to[2 * N], num = 0;
void add(int x, int y) {
num++;
to[num] = y;
nxt[num] = head[x];
head[x] = num;
}
int size[N], son[N], dep[N], top[N], fa[N], pos[N], rev[N];
void dfs1(int x, int f) {
size[x] = 1;
dep[x] = dep[f] + 1;
fa[x] = f;
son[x] = 0;
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if (v == f) continue;
dfs1(v, x);
size[x] += size[v];
if (size[v] > size[son[x]]) {
son[x] = v;
}
}
}
int label = 0;
void dfs2(int x, int ance) {
pos[x] = ++label;
rev[label] = x;
top[x] = ance;
if (son[x]) dfs2(son[x], ance);
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if (v == fa[x] || v == son[x]) continue;
dfs2(v, v);
}
}

int a[N], tree1[4 * N], tree2[4 * N]; // tree1 max tree2 sum
void build(int p, int l, int r) {
if (l == r) {
tree1[p] = tree2[p] = a[rev[l]];
return;
}
int mid = (l + r) / 2;
build(p * 2, l, mid);
build(p * 2 + 1, mid + 1, r);
tree1[p] = max(tree1[p * 2], tree1[p * 2 + 1]);
tree2[p] = tree2[p * 2] + tree2[p * 2 + 1];
}
void change(int p, int l, int r, int x, int y) {
if (l == r) {
tree1[p] = y;
tree2[p] = y;
return;
}
int mid = (l + r) / 2;
if (x <= mid)
change(p * 2, l, mid, x, y);
else
change(p * 2 + 1, mid + 1, r, x, y);
tree1[p] = max(tree1[p * 2], tree1[p * 2 + 1]);
tree2[p] = tree2[p * 2] + tree2[p * 2 + 1];
}
int askmax(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return tree1[p];
}
int mid = (l + r) / 2;
int ans = -1e9;
if (x <= mid) ans = max(ans, askmax(p * 2, l, mid, x, y));
if (mid + 1 <= y) ans = max(ans, askmax(p * 2 + 1, mid + 1, r, x, y));
return ans;
}
int asksum(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return tree2[p];
}
int mid = (l + r) / 2;
int ans = 0;
if (x <= mid) ans += asksum(p * 2, l, mid, x, y);
if (mid + 1 <= y) ans += asksum(p * 2 + 1, mid + 1, r, x, y);
return ans;
}
int qmax(int u, int v, int n) {
int ans = -1e9;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
ans = max(ans, askmax(1, 1, n, pos[top[u]], pos[u]));
u = fa[top[u]];
}
if (dep[u] > dep[v]) swap(u, v);
ans = max(ans, askmax(1, 1, n, pos[u], pos[v]));
return ans;
}
int qsum(int u, int v, int n) {
int ans = 0;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
ans += asksum(1, 1, n, pos[top[u]], pos[u]);
u = fa[top[u]];
}
if (dep[u] > dep[v]) swap(u, v);
ans += asksum(1, 1, n, pos[u], pos[v]);
return ans;
}
int main() {
int n;
cin >> n;
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
dfs1(1, 0);
dfs2(1, 1);
build(1, 1, n);
int q;
cin >> q;
while (q--) {
string op;
int u, v;
cin >> op >> u >> v;
if (op == "CHANGE") {
change(1, 1, n, pos[u], v);
} else if (op == "QMAX") {
cout << qmax(u, v, n) << endl;
} else {
cout << qsum(u, v, n) << endl;
}
}
return 0;
}