换根dp

[学习笔记]换根dp - 洛谷专栏 (luogu.com)

引入

我们来看这个题

[POI2008] STA-Station

题目描述

给定一个 \(n\) 个点的树,请求出一个结点,使得以这个结点为根时,所有结点的深度之和最大。

一个结点的深度之定义为该节点到根的简单路径上边的数量。

输入格式

第一行有一个整数,表示树的结点个数 \(n\)
接下来 \((n - 1)\) 行,每行两个整数 \(u, v\),表示存在一条连接 \(u, v\) 的边。

输出格式

本题存在 Special Judge

输出一行一个整数表示你选择的结点编号。如果有多个结点符合要求,输出任意一个即可。

样例 #1

样例输入 #1

1
2
3
4
5
6
7
8
8
1 4
5 6
4 5
6 7
6 8
2 4
3 4

样例输出 #1

1
7

提示

样例 1 解释

输出 \(7\)\(8\) 都是正确答案。

数据规模与约定

对于全部的测试点,保证 \(1 \leq n \leq 10^6\)\(1 \leq u, v \leq n\),给出的是一棵树。

给出一个 N 个点的树,找出一个点来,以这个点为根的树时,所有点的深度之和最大

我们可以枚举每一个点,将其作为根,然后求出每一个点的深度然后统计深度和就行了

但通过不了。

换根dp

我们可以利用一些技巧来把这一类问题优化到 Θ(n) 的时间内解决

换根dp一般分为三个步骤

1、先指定一个根节点 2、一次dfs统计子树内的节点对当前节点的贡献 3、一次dfs统计父亲节点对当前节点的贡献并合并统计最终答案

我们先来看一个节点 \(u\) 的子树里面的节点对它的贡献:

很明显,这个贡献就是子树里所有节点到 \(u\) 的深度和,我们把它记为 \(g_u\)

接着,我们记 \(sz_u\) 为以 \(u\) 为根的子树大小,\(dep_u\) 为点 \(u\) 到 1 号节点(我们指定的根节点)的深度(之后有用)

接下来,我们来考虑第2次dfs,也就是计算父亲对它的贡献

我们令 \(f_u\) 为以 $u $ 为根节点的深度和

所以,我们令 \(v\)\(u\) 的一个儿子节点,可得 \(f_v=g_v+(f_u-(g_v+sz_v))+(sz_1-sz_v)\)

为啥打上了括号呢?是因为这样更好理解了

如果看不懂,我来解释一下

\(g_v\) 是以 \(v\) 为根的子树深度和,显然要加上
\(f_u\) 是父亲 \(u\) 节点的答案,显然要减去我们 \(v\) 子树里的信息(不然就多算了)
那么,\(g_v+sz_v\) 就是我们 \(v\) 子树的信息了

有人就要问了,子树里的的信息不就是 \(g_v\) 吗?

但是我们是从父亲的答案减去,显然在以 \(u\) 为根时,\(v\) 的子树中的所有节点的深度都有加一,于是就增加 \(sz_v\)

同理,后面的 \(sz_1-sz_v\) 就是非 \(v\) 节点子树中的点啦,由于移动后深度加一。

解释完了,我们化简一下这个柿子 \(f_v=f_u+sz_1-2\times sz_v\)

发现可以不用记录 \(g\)

于是,我们就可以在 \(\Theta(n)\) 的时间内搞定啦

换根的部分需要注意子树深度减少,其他子树深度加加。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define frep(i,a,n) for(int i=a;i>=n;i--)
#define int long long
#define PII pair<int,int>
#define lowbit(x) (x&(-x))
const int mod=1e9+7;
const double pai=acos(-1.0);
#define ios ios::sync_with_stdio(0); cin.tie(0),cout.tie(0);
#define LF(x) fixed<<setprecision(x)
const int N=1e6+100;
struct
{
int to, nxt;
}e[N<<1];
int n,cnt;
int head[N];
void add(int u,int v)
{
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
int dep[N];
int sz[N];
int f[N];
void dfs1(int u,int fa)
{
sz[u]=1;
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
continue;
dep[v]=dep[u]+1;
dfs1(v,u);
sz[u]+=sz[v];
}
}
void dfs2(int u,int fa)
{
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa)
continue;
f[v]=f[u]-2*sz[v]+sz[1];
dfs2(v,u);
}
}
signed main()
{
ios;
//换根dp
cin>>n;
rep(i,1,n-1)
{
int u,v;
cin>>u>>v;
add(u,v);
add(v,u);
}
dfs1(1,0);
rep(i,1,n)
{
f[1]+=dep[i];
//递推要用
}
dfs2(1,0);
int ans=-1e18;
int id=0;
rep(i,1,n)
{
if(ans<f[i])
{
ans=f[i];
id=i;
}
}
cout<<id<<'\n';
return 0;
}