【Algorithm Notes】Tree DP Algorithm Summary & Detailed Explanation

Translation Notice
This article was machine-translated using DeepSeek-R1.

  • Original Version: Authored in Chinese by myself
  • Accuracy Advisory: Potential discrepancies may exist between translations
  • Precedence: The Chinese text shall prevail in case of ambiguity
  • Feedback: Technical suggestions regarding translation quality are welcomed

Definition

Tree DP, also known as Tree-shaped DP, refers to DP (Dynamic Programming) performed on trees, and is one of the more complex types in DP algorithms.

Basics

Let $f[u]$ denote the DP data associated with tree node $u$. The DP is evaluated bottom-up (from the leaves towards the root), so that all children of a node already have their final DP values before the node itself is updated. This is typically implemented with a post-order DFS:

// Generic tree-DP skeleton (pseudocode: "..." and update() are placeholders).
// Assumes G[v] lists only the children of v (edges stored parent -> child).
// Post-order: each child's dp value is final before it is folded into dp[v].
void dfs(int v) { // Traverse node v
    dp[v] = ...; // Initialize
    for(int u: G[v]) { // Traverse all children of v
        dfs(u);
        update(u, v); // Update current node's DP value using child's DP value
    }
}

【Example 1.1】Subtree Size

Given a tree with $N$ nodes rooted at node 1, compute, for each $i=1,2,\dots,N$, the size of the subtree rooted at node $i$. (The code below reads each of the $N-1$ edges as `u v` with $u$ being the parent of $v$.)

$$ f[v]=1+\sum_{i=1}^{\deg_v} f\big[G[v][i]\big] $$
where $\deg_v$ is the number of children of $v$ and $G[v][i]$ is its $i$-th child.
#include <cstdio>
#include <vector>
// Nodes are numbered 1..n, so index n itself must be a valid slot. Keep a few
// entries of slack beyond the largest supported n (assumed n <= 100 — TODO
// confirm against the problem's constraints), matching the "+5" convention
// used by the other examples.
#define maxn 105
using namespace std;

vector<int> G[maxn]; // Adjacency list: G[v] holds the children of v
int sz[maxn]; // DP array, sz[v] = size of subtree v

/**
 * Fills sz[] for every node in the subtree rooted at v (post-order).
 * @param v root of the subtree to process. G[v] must list only v's children
 *          (edges stored parent -> child); a back-edge to the parent would
 *          make the recursion bounce between the two nodes forever.
 */
void dfs(int v)
{
    sz[v] = 1; // v itself
    for(int u: G[v])
    {
        dfs(u);         // finish the child's subtree first...
        sz[v] += sz[u]; // ...then fold its size into v
    }
}

int main()
{
    int n;
    scanf("%d", &n);
    // n - 1 edges, each given as "parent child".
    for(int left = n - 1; left > 0; --left)
    {
        int parent, child;
        scanf("%d%d", &parent, &child);
        G[parent].push_back(child);
    }
    dfs(1); // node 1 is the root
    for(int node = 1; node <= n; ++node)
        printf("%d\n", sz[node]);
    return 0;
}

【Example 1.2】Luogu P1352 No Boss’s Party

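This is the maximum-weight independent set problem on a tree: each employee $v$ has a happiness rating $r_v$, and no employee may attend together with their direct supervisor.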

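Let $f(v)$ be the best total rating obtainable in the subtree of $v$ when $v$ attends, and $g(v)$ the best total when $v$ does not attend. Summing over the children $u$ of $v$, the transitions are:

  • $g(v)=\sum_{u}\max\{f(u),g(u)\}$ (if $v$ stays home, each child may attend or not);
  • $f(v)=r_v+\sum_{u} g(u)$ (if $v$ attends, none of its direct subordinates may).

The answer is $\max\{f(\text{root}),g(\text{root})\}$.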
#include <cstdio>
#include <vector>
#define maxn 6005
using namespace std;

// Larger of two ints.
inline int max(int x, int y)
{
    if(x > y) return x;
    return y;
}

vector<int> G[maxn];  // G[v]: direct subordinates (children) of v
bool bad[maxn];       // bad[v]: v has a supervisor, so it cannot be the root
int f[maxn], g[maxn]; // f[v]: best total in v's subtree with v attending (main
                      // preloads it with v's own rating); g[v]: best with v absent

// Post-order DP for the maximum-weight independent set of v's subtree.
void dfs(int v)
{
    int withV = 0, withoutV = 0;
    for(size_t k = 0; k < G[v].size(); ++k)
    {
        const int c = G[v][k];
        dfs(c);
        withV += g[c];               // v attends -> child c must stay home
        withoutV += max(f[c], g[c]); // v absent -> c may go either way
    }
    f[v] += withV;
    g[v] += withoutV;
}

int main()
{
    int n;
    scanf("%d", &n);
    for(int i = 0; i < n; ++i)
        scanf("%d", &f[i]); // f[i] starts as employee i's own rating
    for(int e = 1; e < n; ++e)
    {
        int child, parent;
        scanf("%d%d", &child, &parent); // pair is read as "child parent"
        --child;
        --parent;
        G[parent].push_back(child);
        bad[child] = true; // has a supervisor -> not the root
    }
    // The root is the first employee without a supervisor (-1 if none).
    int root = -1;
    for(int i = 0; i < n && root < 0; ++i)
        if(!bad[i]) root = i;
    dfs(root);
    const int best = max(f[root], g[root]);
    printf("%d\n", best);
    return 0;
}

Exercises

Tree Knapsack

【Example 2.1】Luogu P2014 / AcWing 286 Course Selection

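Make each course a child of its prerequisite, and attach every course without a prerequisite to a virtual root 0 worth 0 credits. This turns the forest into a single tree rooted at 0. Let $f[i][j]$ be the maximum total credit obtainable by choosing $j$ courses in the subtree of $i$, with $i$ itself always among them (a course can only be taken together with its ancestors). Merging children into their parent is a group knapsack. The virtual root 0 is always counted as one of the chosen nodes, so increase $M$ by 1; the answer is $f[0][M+1]$.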

#include <cstdio>
#include <vector>
#include <algorithm>
#define maxn 305
using namespace std;

// Raise x to y if y is larger.
inline void setmax(int& x, int y)
{
    if(y > x)
        x = y;
}

vector<int> G[maxn]; // G[u]: courses whose prerequisite is u (0 = virtual root)
int n, m, f[maxn][maxn]; // f[u][j]: best credit choosing j nodes of u's subtree, u included

/**
 * Tree (group) knapsack: merges every child's table into f[u][*].
 * @return number of nodes in u's subtree, including u.
 */
int dfs(int u)
{
    int merged = 1; // nodes already merged into f[u][*]; starts with u itself
    for(size_t k = 0; k < G[u].size(); ++k)
    {
        const int v = G[u][k];
        const int child = dfs(v);
        // i runs downward, so f[u][i] still holds its value from before this
        // child was merged when it is read (0/1-knapsack style).
        for(int i = merged < m ? merged : m; i >= 1; --i)
        {
            const int room = m - i < child ? m - i : child;
            for(int j = 1; j <= room; ++j)
                setmax(f[u][i + j], f[u][i] + f[v][j]);
        }
        merged += child;
    }
    return merged;
}

int main()
{
    scanf("%d%d", &n, &m);
    for(int course = 1; course <= n; ++course)
    {
        int prereq;
        scanf("%d%d", &prereq, &f[course][1]); // f[course][1] = its own credit
        G[prereq].push_back(course);
    }
    ++m; // the virtual root 0 is always one of the selected nodes
    dfs(0);
    printf("%d\n", f[0][m]);
    return 0;
}

Exercises

Rerooting DP

【Example 3.1】Luogu P3478 [POI2008] STA-Station

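Find a node that, taken as the root, maximizes the sum of the depths of all nodes. Recomputing the depth sum from scratch for every root costs $\mathcal O(N^2)$; two DFS passes do it in $\mathcal O(N)$. The first pass computes every subtree size $s(u)$ and the depth sum $f(1)$ with node 1 as the root. The second pass moves the root from $v$ to each child $u$. The $s(u)$ nodes in $u$'s subtree get one step closer and the other $N-s(u)$ nodes get one step farther, so $f(u)=f(v)+N-2s(u)$.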

#include <cstdio>
#include <vector>
#define maxn 1000005
using namespace std;

using LL = long long;

vector<int> G[maxn];  // Undirected adjacency list (0-based nodes)
LL sz[maxn], f[maxn]; // sz[v]: subtree size under the first rooting; f[v]: depth sum with v as root
int n, ans;           // node count; best root found so far (0-based)

/**
 * First pass: treats v as the root of its part of the tree (par is the
 * neighbour to exclude, -1 for none). Fills sz[] for every node reached and
 * returns the sum of their depths, where v itself has depth d.
 *
 * Iterative on purpose: with n up to ~1e6 the tree can be a path, and one
 * recursive call per node can overflow the call stack.
 */
LL dfs1(int v, int d, int par)
{
    struct Item { int node, parent, depth; };
    vector<Item> order; // pre-order: every node appears before its descendants
    vector<Item> todo;
    todo.push_back({v, par, d});
    LL s = 0;
    while(!todo.empty())
    {
        const Item cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        sz[cur.node] = 1;
        s += cur.depth;
        for(int u: G[cur.node])
            if(u != cur.parent)
                todo.push_back({u, cur.node, cur.depth + 1});
    }
    // A reverse pre-order sweep completes each subtree before folding it into
    // its parent. order[0] is v itself, whose size must not leak into par.
    for(size_t i = order.size(); i-- > 1; )
        sz[order[i].parent] += sz[order[i].node];
    return s;
}

void dfs2(int v, int par)
{
    if(f[v] > f[ans]) ans = v;
    for(int u: G[v])
        if(u != par)
        {
            f[u] = f[v] + n - (sz[u] << 1LL);
            dfs2(u, v);
        }
}

int main()
{
    scanf("%d", &n);
    int edgesLeft = n;
    while(--edgesLeft) // n - 1 undirected edges, 1-based in the input
    {
        int a, b;
        scanf("%d%d", &a, &b);
        --a;
        --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    f[0] = dfs1(0, 0, -1); // depth sum with node 0 (input node 1) as the root
    dfs2(0, -1);
    printf("%d\n", ans + 1); // back to 1-based numbering
    return 0;
}

Exercises

Postscript

Turns out it’s not as hard as it initially seemed… Remember to like and share!

References:

Built with Hugo
Theme Stack designed by Jimmy