树链剖分

树链剖分,就是将树剖分为若干条链,用来维护树上信息,常搭配树上值域线段树,如:

  1. 修改树上两点间的最短路径上的节点权值

  2. 查询树上两点间的最短路径上的点权和

  3. 修改以某个点为根的子树的每个节点的点权

  4. 查询以某个点为根的子树的节点权值和

其中,操作3和4可以直接建立树上值域线段树解决,操作1和2需要进行树链剖分。

树链剖分有三种方法:重链剖分(复杂度 \(O(\log n)\))、长链剖分(复杂度 \(O(\sqrt n)\))和实链剖分(常用于LCT维护)。其中,重链剖分最为常见,因此本节主要记录重链剖分的学习笔记

一、基础定义

重儿子:一个节点的所有儿子中,子树大小最大的那一个儿子。如有多种选择,就只选一个儿子

轻儿子:一个节点的所有儿子中,不是重儿子的节点。根节点也是轻儿子。

重链:从一个轻儿子开始,沿着重儿子走,连出的极大子链。

轻链:不是重链的子链。

重链定理\(\quad\) 除了根节点以外的任何一个节点的父亲一定在一条重链上。

二、重链剖分

重链剖分,需要我们维护一下内容:

  1. fa[MAXN],即节点的父节点。
  2. dep[MAXN],即节点深度。
  3. son[MAXN],即该节点的重儿子编号,如果是叶子节点,则 son[p]=0
  4. top[MAXN],即该节点所在重链的链头。
  5. sz[MAXN],即以该节点为根的子树的大小。
  6. dfn[MAXN],该节点进行 \(\text{dfs}\) 的时间戳,即该节点的 \(\text{dfs}\) 序。
  7. w[MAXN],即在 \(\text{dfs}\) 序中,该序号节点的权值。
  8. tick,即 \(\text{dfs}\) 时间戳。

前面几个信息可以打包进一个结构体,然后线段树需要另一个结构体。

重链剖分要求重链上的时间戳一定要连续(方便在线段树上区间修改和查询),所以需要进行两次 \(\text{dfs}\)

2.1\(\quad\) 第一次 \(\text{dfs}\)

第一次 \(\text{dfs}\) 需要处理出重链剖分的前置信息。

从根节点开始遍历整棵树。记录节点父亲、子树大小、深度,还有重儿子。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void dfs1(int u,int fa){
T[u].fa=fa;
T[u].sz=1;
T[u].dep=T[fa].dep+1;
int tmp=-1;
for(auto v:g[u]){
if(v==fa) continue;
dfs1(v,u);
T[u].sz+=T[v].sz;
if(T[v].sz>tmp){
tmp=T[v].sz;
T[u].son=v;
}
}
}

2.2\(\quad\) 第二次 $ $

第二次 \(\text{dfs}\) 就可以剖分这棵树了。

我们进行重链剖分,记录该节点所在的重链的链头和时间戳。

1
2
3
4
5
6
7
8
9
10
11
void dfs2(int u,int t){
T[u].top=t;
dfn[u]=++tick;
w[tick]=a[u];
if(!T[u].son) return;
dfs2(T[u].son,t);
for(auto v:g[u]){
if(v==T[u].fa||v==T[u].son) continue;
dfs2(v,v);
}
}

2.3\(\quad\) 建立树上值域线段树

因为子树的 \(\text{dfs}\) 序一个区间,我们就可以建立值域线段树。

1
2
3
4
5
6
7
8
9
10
11
void build(int l,int r,int p){
tree[p].l=l,tree[p].r=r;
tree[p].tag=0;
if(l==r){
tree[p].sum=w[l];
return;
}
build(l,mid,ls);
build(mid+1,r,rs);
update(p);//整合子树信息
}

三、维护信息

3.1\(\quad\) 进行子树加操作

因为子树的 \(\text{dfs}\) 序是一个区间,可以在线段树上进行区间修改操作(\(\text{modify}\)),修改的区间就是 ,其\([\text{dfn}[p],\text{dfn}[p]+\text{sz}[p]-1]\)\(p\) 为子树根节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
void modify(int l,int r,int k,int p){
if(l<=tree[p].l&&r>=tree[p].r){
tree[p].sum=(tree[p].sum+k*len(p))%mod;
tree[p].tag=(tree[p].tag+k)%mod;
return;
}
pushdown(p);//懒标记下传
if(l<=mid) modify(l,r,k,ls);
if(r>mid) modify(l,r,k,rs);
update(p);
}
//in main:
modify(dfn[x],dfn[x]+T[x].sz-1,z,1);

3.2\(\quad\)进行子树求和

类比子树加操作,在线段树上进行区间求和(\(\text{query}\)),求和区间就是 \([\text{dfn}[p],\text{dfn}[p]+\text{sz}[p]-1]\),其中 \(p\) 为子树根节点。

1
2
3
4
5
6
7
8
9
10
int query(int l,int r,int p){
if(l<=tree[p].l&&r>=tree[p].r) return tree[p].sum;
pushdown(p);
int ans=0;
if(l<=mid) ans=(ans+query(l,r,ls))%mod;
if(r>mid) ans=(ans+query(l,r,rs))%mod;
return ans;
}
//in main:
cout<<query(dfn[x],dfn[x]+T[x].sz-1,1)<<endl;

3.3\(\quad\) 进行路径修改操作

根据重连定理,除了根节点以外的任何一个节点的父亲一定在一条重链上。所以我们就可以进行重链到重链的转换,从而一点一点地在每一条链上进行区间修改。

考虑每次选择链头深度高的那条链,将该节点跳到链头并区间修改,此时就改掉了这条链(也就是路径的一部分)上的值,修改区间为 \([\text{dfn}[\text{top}[p]],\text{dfn}[p]]\),其中,\(p\) 为该节点,而后跳到链头的父亲,此时就在另一条链上了,可以重复操作直到两节点在同一条重链上。

如果两节点在同一跳重链上,则可以直接进行区间修改,修改区间为 \([\text{dfn}[x],\text{dfn}[y]]\)$,其中 \(x,y\) 是两个节点,且防止无效修改操作, \(\text{dep}[x]<\text{dep}[y]\)

1
2
3
4
5
6
7
8
9
10
void addOnTree(int x,int y,int c){
c%=mod;
while(T[x].top!=T[y].top){
if(T[T[x].top].dep<T[T[y].top].dep) swap(x,y);
modify(dfn[T[x].top],dfn[x],c,1);
x=T[T[x].top].fa;
}
if(T[x].dep>T[y].dep) swap(x,y);
modify(dfn[x],dfn[y],c,1);
}

3.4\(\quad\) 进行路径求和操作

思想类似路径修改,只不过把修改操作改成求和。

1
2
3
4
5
6
7
8
9
10
11
int getSumOnTree(int x,int y){
int ans=0;
while(T[x].top!=T[y].top){
if(T[T[x].top].dep<T[T[y].top].dep) swap(x,y);
ans=(ans+query(dfn[T[x].top],dfn[x],1))%mod;
x=T[T[x].top].fa;
}
if(T[x].dep>T[y].dep) swap(x,y);
ans=(ans+query(dfn[x],dfn[y],1))%mod;
return ans;
}

四、参考代码

本代码为树链剖分/重链剖分模板。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include<iostream>
#include<vector>
using namespace std;
#define MAXN 100005
#define mid ((tree[p].l+tree[p].r)>>1)
#define ls (p<<1)
#define rs (p<<1|1)
#define len(x) (tree[x].r-tree[x].l+1)
struct G{
int sz,dep,top,son,fa;
}T[MAXN];
struct F{
int l,r,sum,tag;
}tree[MAXN<<2];
int n,m,r,mod,a[MAXN],dfn[MAXN],tick,w[MAXN];
vector<int> g[MAXN];
void dfs1(int u,int fa){
T[u].fa=fa;
T[u].sz=1;
T[u].dep=T[fa].dep+1;
int tmp=-1;
for(auto v:g[u]){
if(v==fa) continue;
dfs1(v,u);
T[u].sz+=T[v].sz;
if(T[v].sz>tmp){
tmp=T[v].sz;
T[u].son=v;
}
}
}
void dfs2(int u,int t){
T[u].top=t;
dfn[u]=++tick;
w[tick]=a[u];
if(!T[u].son) return;
dfs2(T[u].son,t);
for(auto v:g[u]){
if(v==T[u].fa||v==T[u].son) continue;
dfs2(v,v);
}
}
void update(int p){
tree[p].sum=tree[ls].sum+tree[rs].sum;
}
void build(int l,int r,int p){
tree[p].l=l,tree[p].r=r;
tree[p].tag=0;
if(l==r){
tree[p].sum=w[l];
return;
}
build(l,mid,ls);
build(mid+1,r,rs);
update(p);
}
void pushdown(int p){
if(!tree[p].tag) return;
tree[ls].sum=(tree[ls].sum+tree[p].tag*len(ls))%mod;
tree[rs].sum=(tree[rs].sum+tree[p].tag*len(rs))%mod;
tree[ls].tag=(tree[ls].tag+tree[p].tag)%mod;
tree[rs].tag=(tree[rs].tag+tree[p].tag)%mod;
tree[p].tag=0;
}
void modify(int l,int r,int k,int p){
if(l<=tree[p].l&&r>=tree[p].r){
tree[p].sum=(tree[p].sum+k*len(p))%mod;
tree[p].tag=(tree[p].tag+k)%mod;
return;
}
pushdown(p);
if(l<=mid) modify(l,r,k,ls);
if(r>mid) modify(l,r,k,rs);
update(p);
}
void addOnTree(int x,int y,int c){
c%=mod;
while(T[x].top!=T[y].top){
if(T[T[x].top].dep<T[T[y].top].dep) swap(x,y);
modify(dfn[T[x].top],dfn[x],c,1);
x=T[T[x].top].fa;
}
if(T[x].dep>T[y].dep) swap(x,y);
modify(dfn[x],dfn[y],c,1);
}
int query(int l,int r,int p){
if(l<=tree[p].l&&r>=tree[p].r) return tree[p].sum;
pushdown(p);
int ans=0;
if(l<=mid) ans=(ans+query(l,r,ls))%mod;
if(r>mid) ans=(ans+query(l,r,rs))%mod;
return ans;
}
int getSumOnTree(int x,int y){
int ans=0;
while(T[x].top!=T[y].top){
if(T[T[x].top].dep<T[T[y].top].dep) swap(x,y);
ans=(ans+query(dfn[T[x].top],dfn[x],1))%mod;
x=T[T[x].top].fa;
}
if(T[x].dep>T[y].dep) swap(x,y);
ans=(ans+query(dfn[x],dfn[y],1))%mod;
return ans;
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>m>>r>>mod;
for(int i=1;i<=n;i++) cin>>a[i];
for(int x,y,i=1;i<n;i++){
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
}
dfs1(r,0);
dfs2(r,r);
build(1,n,1);
for(int op,x,y,z,i=1;i<=m;i++){
cin>>op;
switch(op){
case 1:{
cin>>x>>y>>z;
addOnTree(x,y,z);
break;
}
case 2:{
cin>>x>>y;
cout<<getSumOnTree(x,y)<<endl;
break;
}
case 3:{
cin>>x>>z;
modify(dfn[x],dfn[x]+T[x].sz-1,z,1);
break;
}
case 4:{
cin>>x;
cout<<query(dfn[x],dfn[x]+T[x].sz-1,1)<<endl;
break;
}
}
}
return 0;
}