树链剖分

2025 7 7 学习报告

树链剖分

是指一种对树进行划分的算法,它先通过轻重边剖分(Heavy-Light Decomposition)将树分为多条链,保证每个点属于且只属于其中一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。

定义

size[i] 以结点i为根的子树中结点的个数;
son[i] 结点i的重儿子;
dep[i] 结点i的深度,根的深度为1;
top[i] 结点i所在重链的链首结点;
fa[i] 结点i的父结点;
pos[i] 在DFS找重链的过程中为结点i重新编的号码,每条重链上的结点编号是连续的。

那么,令 的儿子节点中size值最大的节点,那么称 的重儿子重边,除此之外,所有的边称为轻边。全部由重边构成的路径称为重链

树链剖分的过程

  1. 第一次DFS(预处理)
    从根节点出发,递归遍历整棵树,计算每个节点的sizedepfa,并找出每个节点的重儿子son(即子树最大的儿子)。

  2. 第二次DFS(分配编号)
    再次从根节点出发,优先遍历重儿子,将每个节点分配一个新的编号pos,并确定每个节点所在重链的链首top

  3. 建树维护
    利用线段树、树状数组等数据结构,按照pos数组的顺序建立数据结构,实现对树上路径或子树的区间操作。

  4. 树上操作转化为区间操作
    查询或修改树上两点间路径时,将路径拆分为若干条重链上的区间,依次在数据结构上进行操作。

伪代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
void dfs1(int u, int father) {
size[u] = 1;
fa[u] = father;
dep[u] = dep[father] + 1;
int max_size = -1;
for (int v : G[u]) {
if (v == father) continue;
dfs1(v, u);
size[u] += size[v];
if (size[v] > max_size) {
max_size = size[v];
son[u] = v;
}
}
}

void dfs2(int u, int topf) {
pos[u] = ++cnt;
top[u] = topf;
if (!son[u]) return;
dfs2(son[u], topf);
for (int v : G[u]) {
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}

线段树在树链剖分中的维护方法

树链剖分后,通常将每条边的权值映射到其子节点(即边 的权值存储在 上),这样便于用线段树维护。

1. 修改某条边的权值

假设要修改边 的权值,设 ,则只需在线段树上修改 位置的值。

2. 修改某条路径所有边的权值

将路径拆分为若干条重链区间,对每段区间(注意每段区间的左端点要+1,因为边权存储在子节点)进行区间修改。

3. 查询原树中某条路径所有边权的最大值

同样将路径拆分为若干条重链区间,分别查询最大值,取最大即可。

4. 查询/修改某个子树中的边

子树内所有边都可以映射为 区间( 的子树内所有节点,除了 本身)。

例题 P2590 [ZJOI2008] 树的统计

纯板子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 5e5 + 5;
int head[N], nxt[2 * N], to[2 * N], num = 0;
void add(int x, int y) {
num++;
to[num] = y;
nxt[num] = head[x];
head[x] = num;
}
int size[N], son[N], dep[N], top[N], fa[N], pos[N], rev[N];
void dfs1(int x, int fat, int d) {
size[x] = 1;
dep[x] = d;
fa[x] = fat;
son[x] = 0;
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if (v == fat) continue;
dfs1(v, x, d + 1);
size[x] += size[v];
if (size[v] > size[son[x]]) {
son[x] = v;
}
}
}
int label = 0;
void dfs2(int x, int ance) {
pos[x] = ++label;
rev[label] = x;
top[x] = ance;
if (son[x]) dfs2(son[x], ance);
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if (v == fa[x] || v == son[x]) continue;
dfs2(v, v);
}
}

int a[N], tree1[4 * N], tree2[4 * N]; // tree1 max tree2 sum
void build(int p, int l, int r) {
if (l == r) {
tree1[p] = tree2[p] = a[rev[l]];
return;
}
int mid = (l + r) / 2;
build(p * 2, l, mid);
build(p * 2 + 1, mid + 1, r);
tree1[p] = max(tree1[p * 2], tree1[p * 2 + 1]);
tree2[p] = tree2[p * 2] + tree2[p * 2 + 1];
}
void change(int p, int l, int r, int x, int y) {
if (l == r) {
tree1[p] = y;
tree2[p] = y;
return;
}
int mid = (l + r) / 2;
if (x <= mid) change(p * 2, l, mid, x, y);
else change(p * 2 + 1, mid + 1, r, x, y);
tree1[p] = max(tree1[p * 2], tree1[p * 2 + 1]);
tree2[p] = tree2[p * 2] + tree2[p * 2 + 1];
}
int askmax(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return tree1[p];
}
int mid = (l + r) / 2;
int ans = -1e9;
if (x <= mid) ans = max(ans, askmax(p * 2, l, mid, x, y));
if (mid + 1 <= y) ans = max(ans, askmax(p * 2 + 1, mid + 1, r, x, y));
return ans;
}
int asksum(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return tree2[p];
}
int mid = (l + r) / 2;
int ans = 0;
if (x <= mid) ans += asksum(p * 2, l, mid, x, y);
if (mid + 1 <= y) ans += asksum(p * 2 + 1, mid + 1, r, x, y);
return ans;
}
int qmax(int u, int v, int n) {
int ans = -1e9;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
ans = max(ans, askmax(1, 1, n, pos[top[u]], pos[u]));
u = fa[top[u]];
}
if (dep[u] > dep[v]) swap(u, v);
ans = max(ans, askmax(1, 1, n, pos[u], pos[v]));
return ans;
}
int qsum(int u, int v, int n) {
int ans = 0;
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
ans += asksum(1, 1, n, pos[top[u]], pos[u]);
u = fa[top[u]];
}
if (dep[u] > dep[v]) swap(u, v);
ans += asksum(1, 1, n, pos[u], pos[v]);
return ans;
}
int main() {
int n;
cin >> n;
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
dfs1(1, 0, 1);
dfs2(1, 1);
build(1, 1, n);
int q;
cin >> q;
while (q--) {
string op;
int u, v;
cin >> op >> u >> v;
if (op == "CHANGE") {
change(1, 1, n, pos[u], v);
} else if (op == "QMAX") {
cout << qmax(u, v, n) << endl;
} else {
cout << qsum(u, v, n) << endl;
}
}
return 0;
}
上一篇
下一篇