Translation Notice
This article was machine-translated using DeepSeek-R1.
- Original Version: Authored in Chinese by myself
- Accuracy Advisory: Potential discrepancies may exist between translations
- Precedence: The Chinese text shall prevail in case of ambiguity
- Feedback: Technical suggestions regarding translation quality are welcomed
Definition
Tree DP (tree-shaped dynamic programming) is dynamic programming performed on a tree; it is one of the more intricate kinds of DP.
Basics
Let $f[u]$ denote the DP data associated with tree node $u$. The $\text{DP}$ is evaluated bottom-up (from the leaves toward the root), so that all children of a node already hold their final DP values when the node itself is updated. This is typically implemented with a post-order DFS:
1
2
3
4
5
6
7
|
// Generic tree-DP skeleton (pseudo-code: "..." and update() are placeholders).
// Assumes G[v] lists only the children of v, i.e. edges are directed away from the root.
void dfs(int v) { // Traverse node v
dp[v] = ...; // Initialize
for(int u: G[v]) { // Traverse all children of v
dfs(u);
update(u, v); // Update current node's DP value using child's DP value
}
}
|
【Example 1.1】Subtree Size
Given a tree with $N$ nodes rooted at node 1. For each $i=1,2,\dots,N$, compute the size of the subtree rooted at node $i$.
$$
f[v]=1+\sum_{u\in\text{son}(v)} f[u]
$$
where $\text{son}(v)$ denotes the set of children of $v$.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
#include <cstdio>
#include <vector>
#define maxn 100
using namespace std;
vector<int> G[maxn]; // Adjacency list
int sz[maxn]; // DP array, sz[v] = size of subtree v
void dfs(int v)
{
sz[v] = 1; // Initial size is 1
for(int u: G[v]) // Traverse children
{
dfs(u);
sz[v] += sz[u]; // Update subtree size
}
}
int main()
{
int n;
scanf("%d", &n);
for(int i=1; i<n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
}
dfs(1);
for(int i=1; i<=n; i++)
printf("%d\n", sz[i]);
return 0;
}
|
【Example 1.2】Luogu P1352 No Boss’s Party
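This is the maximum-weight independent set problem on a tree.
Let $f(v)$ be the maximum total happiness in the subtree of $v$ when $v$ attends, and $g(v)$ the maximum when $v$ does not attend. With $u$ ranging over the children of $v$, the transitions are:
- $g(v)=\sum_u\max\{f(u),g(u)\}$ (if $v$ stays home, each child may either attend or not);
- $f(v)=r_v+\sum_u g(u)$ (if $v$ attends, none of its direct subordinates may attend).
The answer is $\max\{f(\text{root}),g(\text{root})\}$.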
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
|
#include <cstdio>
#include <vector>
#define maxn 6005
using namespace std;
// Returns the larger of two ints (either one when they are equal).
inline int max(int x, int y) { return x < y ? y : x; }
// Employees are 0-based after the input is converted in main().
vector<int> G[maxn];  // G[v]: direct subordinates (children) of v
bool bad[maxn];       // bad[i]: i has a superior, so i cannot be the root
int f[maxn], g[maxn]; // f[v]: best total with v attending (seeded with r_v); g[v]: best with v absent
void dfs(int v)
{
for(int u: G[v])
{
dfs(u);
f[v] += g[u];
g[v] += max(f[u], g[u]);
}
}
// Reads the input, locates the root (an employee without a superior),
// runs the DP and prints the best achievable total.
int main()
{
    int n;
    scanf("%d", &n);

    // f[i] is pre-seeded with employee i's value: the gain when i attends.
    for (int i = 0; i < n; ++i)
        scanf("%d", &f[i]);

    // n - 1 pairs "a b": b is the superior of a (converted to 0-based).
    for (int edge = 1; edge < n; ++edge)
    {
        int sub, boss;
        scanf("%d%d", &sub, &boss);
        --sub;
        --boss;
        G[boss].push_back(sub);
        bad[sub] = true; // sub has a superior, so it is not the root
    }

    // The root is the lowest-indexed employee with no superior (-1 if none).
    int root = 0;
    while (root < n && bad[root])
        ++root;
    if (root >= n)
        root = -1;

    dfs(root);
    printf("%d\n", max(f[root], g[root]));
    return 0;
}
|
Exercises
Tree Knapsack
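Build a tree by attaching every course to its prerequisite, with a virtual root $0$ for courses that have none. Let $f[i][j]$ be the maximum total credit obtainable by selecting $j$ courses in the subtree of $i$, with $i$ itself selected. Since the virtual root $0$ must also be selected (it contributes no credit), increase $M$ by $1$; the answer is then $f[0][M+1]$.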
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
#include <cstdio>
#include <vector>
#include <algorithm>
#define maxn 305
using namespace std;
// In-place maximum ("chmax"): after the call x holds max(x, y).
inline void setmax(int& x, int y)
{
    x = (x < y) ? y : x;
}
vector<int> G[maxn];     // G[u]: courses attached under u; node 0 is the virtual root
int n, m, f[maxn][maxn]; // f[u][j]: max credit choosing j nodes in u's subtree, u itself included
int dfs(int u)
{
int tot = 1;
for(int v: G[u])
{
int sz = dfs(v);
for(int i=min(tot, m); i>0; i--)
for(int j=1, lim=min(sz, m-i); j<=lim; j++)
setmax(f[u][i + j], f[u][i] + f[v][j]);
tot += sz;
}
return tot;
}
// Reads n courses (each with a parent index and a credit), hangs them under
// their parents (0 = the virtual root), runs the knapsack and prints f[0][m].
int main()
{
    scanf("%d%d", &n, &m);
    for (int course = 1; course <= n; ++course)
    {
        int pre;
        scanf("%d%d", &pre, &f[course][1]); // choosing only the course itself
        G[pre].push_back(course);
    }
    ++m; // the virtual root 0 takes one of the slots
    dfs(0);
    printf("%d\n", f[0][m]);
    return 0;
}
|
Exercises
Rerooting DP
【Example 3.1】Luogu P3478 [POI2008] STA-Station
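Find a root that maximizes the sum of the depths of all nodes. Evaluating this sum separately for every root takes $\mathcal O(N^2)$; rerooting with two DFS passes brings it down to $\mathcal O(N)$. The first pass computes the subtree sizes $sz$ and the depth sum $f$ for the initial root. The second pass derives each child $u$ of $v$ from its parent via $f(u)=f(v)+N-2\,sz(u)$: the $sz(u)$ nodes in $u$'s subtree move one level closer and the other $N-sz(u)$ nodes move one level farther.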
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
#include <cstdio>
#include <vector>
#define maxn 1000005
using namespace std;
using LL = long long;
vector<int> G[maxn];  // undirected adjacency list, 0-based nodes
LL sz[maxn], f[maxn]; // sz[v]: size of v's subtree (tree rooted at node 0); f[v]: sum of depths with v as root
int n, ans;           // n: number of nodes; ans: best root found so far (0-based)
// First pass of the rerooting DP.
// Fills sz[x] (subtree size) for every node x in the subtree of v, with the
// tree oriented so that par is v's parent (-1 for the global root), and
// returns the sum of depths of those nodes given that v sits at depth d.
// sz[par] itself is not touched.
//
// Uses an explicit stack instead of recursion: with up to ~10^6 nodes a
// path-shaped tree would recurse ~10^6 levels deep and overflow the stack.
LL dfs1(int v, int d, int par)
{
    struct Frame { int node, parent, depth; };
    vector<Frame> stk;
    vector<Frame> order; // nodes in the order they are first visited
    stk.push_back({v, par, d});
    LL s = 0;
    while (!stk.empty())
    {
        const Frame cur = stk.back();
        stk.pop_back();
        sz[cur.node] = 1;
        s += cur.depth;
        order.push_back(cur);
        for (int u : G[cur.node])
            if (u != cur.parent)
                stk.push_back({u, cur.node, cur.depth + 1});
    }
    // Every node appears after its parent in `order`, so a reverse sweep adds
    // each finished child size into its parent. order[0] is v itself and is
    // skipped so that sz[par] stays unchanged.
    for (size_t k = order.size(); k-- > 1; )
        sz[order[k].parent] += sz[order[k].node];
    return s;
}
void dfs2(int v, int par)
{
if(f[v] > f[ans]) ans = v;
for(int u: G[v])
if(u != par)
{
f[u] = f[v] + n - (sz[u] << 1LL);
dfs2(u, v);
}
}
// Reads an unrooted tree, evaluates the depth sum for every possible root
// and prints (1-based) a root with the largest sum.
int main()
{
    scanf("%d", &n);
    // n - 1 undirected edges, converted from 1-based to 0-based.
    for (int remaining = n - 1; remaining != 0; --remaining)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        --a;
        --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    // Pass 1: subtree sizes and the depth sum with node 0 as the root.
    f[0] = dfs1(0, 0, -1);
    // Pass 2: reroot to every other node and remember the best one.
    dfs2(0, -1);
    ++ans; // back to 1-based numbering
    printf("%d\n", ans);
    return 0;
}
|
Exercises
Postscript
Turns out it’s not as hard as it initially seemed… Remember to like and share!
References: