Post

12. 최소 공통 조상 (LCA)

12. 최소 공통 조상 (LCA)

개념

최소 공통 조상(Lowest Common Ancestor, LCA)은 트리에서 두 노드의 공통 조상 중 가장 깊은(가장 가까운) 노드를 찾는 문제이다.

LCA 예시

1
2
3
4
5
6
7
8
9
10
11
12
        1
       /|\
      2 3 4
     /|   |
    5 6   7
   /
  8

LCA(5, 6) = 2
LCA(5, 7) = 1
LCA(8, 6) = 2
LCA(6, 4) = 1

알고리즘 비교

알고리즘전처리쿼리공간특징
단순 방법O(N)O(N)O(N)구현 간단
희소 배열 (Sparse Table)O(N log N)O(log N)O(N log N)가장 많이 사용
오일러 투어 + RMQO(N log N)O(1)O(N log N)쿼리 최적

핵심 연산 & 시간복잡도

연산희소 배열단순 방법
전처리O(N log N)O(N)
LCA 쿼리O(log N)O(N)
두 노드 거리O(log N)O(N)

구현 (No-STL)

단순 방법 (깊이 맞추기)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
const int MAX_N = 100001;

int parent[MAX_N];
int depth[MAX_N];
int adj[MAX_N][100];  // 인접 리스트
int adjCnt[MAX_N];
int n;

void dfs(int cur, int par, int d) {
    parent[cur] = par;
    depth[cur] = d;

    for (int i = 0; i < adjCnt[cur]; i++) {
        int next = adj[cur][i];
        if (next != par) {
            dfs(next, cur, d + 1);
        }
    }
}

int lcaNaive(int u, int v) {
    // 깊이 맞추기
    while (depth[u] > depth[v]) u = parent[u];
    while (depth[v] > depth[u]) v = parent[v];

    // 같이 올라가기
    while (u != v) {
        u = parent[u];
        v = parent[v];
    }

    return u;
}

희소 배열 (Sparse Table) - 핵심

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
const int MAX_N = 100001;
const int LOG = 17;  // log2(MAX_N) + 1

int parent[MAX_N][LOG];  // parent[v][k] = v의 2^k번째 조상
int depth[MAX_N];
int adj[MAX_N][100];
int adjCnt[MAX_N];
int n;

void dfs(int cur, int par, int d) {
    parent[cur][0] = par;
    depth[cur] = d;

    for (int i = 0; i < adjCnt[cur]; i++) {
        int next = adj[cur][i];
        if (next != par) {
            dfs(next, cur, d + 1);
        }
    }
}

void preprocess(int root) {
    dfs(root, 0, 0);  // 루트의 부모는 0 (또는 -1)

    // 희소 배열 구축
    for (int k = 1; k < LOG; k++) {
        for (int v = 1; v <= n; v++) {
            if (parent[v][k - 1] != 0) {
                parent[v][k] = parent[parent[v][k - 1]][k - 1];
            }
        }
    }
}

int lca(int u, int v) {
    // u가 더 깊도록 swap
    if (depth[u] < depth[v]) {
        int temp = u; u = v; v = temp;
    }

    // 깊이 맞추기 (2^k씩 점프)
    int diff = depth[u] - depth[v];
    for (int k = 0; k < LOG; k++) {
        if ((diff >> k) & 1) {
            u = parent[u][k];
        }
    }

    if (u == v) return u;

    // 같이 올라가기 (LCA 바로 아래까지)
    for (int k = LOG - 1; k >= 0; k--) {
        if (parent[u][k] != parent[v][k]) {
            u = parent[u][k];
            v = parent[v][k];
        }
    }

    return parent[u][0];
}

구조체로 캡슐화

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
const int MAX_N = 100001;
const int LOG = 17;

struct LCA {
    int parent[MAX_N][LOG];
    int depth[MAX_N];
    int adj[MAX_N][100];
    int adjCnt[MAX_N];
    int n;

    void init(int size) {
        n = size;
        for (int i = 0; i <= n; i++) {
            adjCnt[i] = 0;
            for (int k = 0; k < LOG; k++) parent[i][k] = 0;
        }
    }

    void addEdge(int u, int v) {
        adj[u][adjCnt[u]++] = v;
        adj[v][adjCnt[v]++] = u;
    }

    void dfs(int cur, int par, int d) {
        parent[cur][0] = par;
        depth[cur] = d;
        for (int i = 0; i < adjCnt[cur]; i++) {
            if (adj[cur][i] != par) {
                dfs(adj[cur][i], cur, d + 1);
            }
        }
    }

    void build(int root) {
        dfs(root, 0, 0);
        for (int k = 1; k < LOG; k++) {
            for (int v = 1; v <= n; v++) {
                if (parent[v][k - 1]) {
                    parent[v][k] = parent[parent[v][k - 1]][k - 1];
                }
            }
        }
    }

    int query(int u, int v) {
        if (depth[u] < depth[v]) { int t = u; u = v; v = t; }
        int diff = depth[u] - depth[v];
        for (int k = 0; k < LOG; k++) {
            if ((diff >> k) & 1) u = parent[u][k];
        }
        if (u == v) return u;
        for (int k = LOG - 1; k >= 0; k--) {
            if (parent[u][k] != parent[v][k]) {
                u = parent[u][k];
                v = parent[v][k];
            }
        }
        return parent[u][0];
    }
};

두 노드 사이 거리

1
2
3
4
int distance(int u, int v) {
    int l = lca(u, v);
    return depth[u] + depth[v] - 2 * depth[l];
}

K번째 조상 찾기

1
2
3
4
5
6
7
8
9
10
// v의 K번째 조상 (희소 배열 활용)
int kthAncestor(int v, int k) {
    for (int i = 0; i < LOG; i++) {
        if ((k >> i) & 1) {
            v = parent[v][i];
            if (v == 0) return -1;  // 조상이 없음
        }
    }
    return v;
}

경로 상의 K번째 노드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// u에서 v로 가는 경로에서 K번째 노드 (0-indexed)
int kthNodeOnPath(int u, int v, int k) {
    int l = lca(u, v);
    int distToLca = depth[u] - depth[l];

    if (k <= distToLca) {
        // u에서 LCA 방향으로 k번째
        return kthAncestor(u, k);
    } else {
        // LCA에서 v 방향으로
        int distFromLca = k - distToLca;
        int totalToV = depth[v] - depth[l];
        return kthAncestor(v, totalToV - distFromLca);
    }
}

STL

vector 기반 인접 리스트

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <vector>

const int MAX_N = 100001;
const int LOG = 17;

std::vector<int> adj[MAX_N];
int parent[MAX_N][LOG];
int depth[MAX_N];
int n;

void dfs(int cur, int par, int d) {
    parent[cur][0] = par;
    depth[cur] = d;

    for (int next : adj[cur]) {
        if (next != par) {
            dfs(next, cur, d + 1);
        }
    }
}

void build(int root) {
    dfs(root, 0, 0);

    for (int k = 1; k < LOG; k++) {
        for (int v = 1; v <= n; v++) {
            if (parent[v][k - 1]) {
                parent[v][k] = parent[parent[v][k - 1]][k - 1];
            }
        }
    }
}

int lca(int u, int v) {
    if (depth[u] < depth[v]) std::swap(u, v);

    int diff = depth[u] - depth[v];
    for (int k = 0; k < LOG; k++) {
        if ((diff >> k) & 1) u = parent[u][k];
    }

    if (u == v) return u;

    for (int k = LOG - 1; k >= 0; k--) {
        if (parent[u][k] != parent[v][k]) {
            u = parent[u][k];
            v = parent[v][k];
        }
    }

    return parent[u][0];
}

가중치 있는 트리

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <vector>

const int MAX_N = 100001;
const int LOG = 17;

std::vector<std::pair<int, int>> adj[MAX_N];  // {next, weight}
int parent[MAX_N][LOG];
int depth[MAX_N];
long long dist[MAX_N];  // 루트에서의 거리
int n;

void dfs(int cur, int par, int d, long long distance) {
    parent[cur][0] = par;
    depth[cur] = d;
    dist[cur] = distance;

    for (auto& [next, weight] : adj[cur]) {
        if (next != par) {
            dfs(next, cur, d + 1, distance + weight);
        }
    }
}

long long pathWeight(int u, int v) {
    return dist[u] + dist[v] - 2 * dist[lca(u, v)];
}

사용 예시

트리에서 경로 합

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 각 노드에 값이 있을 때, u에서 v로 가는 경로의 합
// prefixSum[v] = 루트에서 v까지의 합

long long prefixSum[MAX_N];

void dfs(int cur, int par, long long sum) {
    prefixSum[cur] = sum + value[cur];
    // ...
}

long long pathSum(int u, int v) {
    int l = lca(u, v);
    return prefixSum[u] + prefixSum[v] - 2 * prefixSum[l] + value[l];
}

트리에서 쿼리 처리

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// Q개의 쿼리: u, v 사이의 거리

void solve() {
    build(1);  // 전처리

    int q;
    scanf("%d", &q);

    while (q--) {
        int u, v;
        scanf("%d %d", &u, &v);
        printf("%d\n", distance(u, v));
    }
}

경로 상 최솟값 (희소 배열 확장)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
const int MAX_N = 100001;
const int LOG = 17;
const int INF = 1e9;

int parent[MAX_N][LOG];
int minEdge[MAX_N][LOG];  // v에서 2^k번째 조상까지의 최소 간선
int depth[MAX_N];

void dfs(int cur, int par, int d, int weight) {
    parent[cur][0] = par;
    minEdge[cur][0] = weight;
    depth[cur] = d;
    // ...
}

void build(int root) {
    dfs(root, 0, 0, INF);

    for (int k = 1; k < LOG; k++) {
        for (int v = 1; v <= n; v++) {
            if (parent[v][k - 1]) {
                parent[v][k] = parent[parent[v][k - 1]][k - 1];
                minEdge[v][k] = (minEdge[v][k - 1] < minEdge[parent[v][k - 1]][k - 1])
                              ? minEdge[v][k - 1] : minEdge[parent[v][k - 1]][k - 1];
            }
        }
    }
}

int pathMin(int u, int v) {
    int result = INF;

    if (depth[u] < depth[v]) { int t = u; u = v; v = t; }

    int diff = depth[u] - depth[v];
    for (int k = 0; k < LOG; k++) {
        if ((diff >> k) & 1) {
            result = (result < minEdge[u][k]) ? result : minEdge[u][k];
            u = parent[u][k];
        }
    }

    if (u == v) return result;

    for (int k = LOG - 1; k >= 0; k--) {
        if (parent[u][k] != parent[v][k]) {
            result = (result < minEdge[u][k]) ? result : minEdge[u][k];
            result = (result < minEdge[v][k]) ? result : minEdge[v][k];
            u = parent[u][k];
            v = parent[v][k];
        }
    }

    result = (result < minEdge[u][0]) ? result : minEdge[u][0];
    result = (result < minEdge[v][0]) ? result : minEdge[v][0];

    return result;
}

주의사항 / Edge Cases

LOG 값 설정

1
2
3
4
5
6
// LOG = ceil(log2(N)) + 1
// N = 100,000이면 LOG = 17 (2^17 = 131,072)
// N = 1,000,000이면 LOG = 20

const int LOG = 17;  // N ≤ 100,000
const int LOG = 20;  // N ≤ 1,000,000

루트 노드의 조상

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 루트의 부모를 0 또는 -1로 설정
// 0으로 설정하면 parent 배열 초기화 불필요

void dfs(int cur, int par, int d) {
    parent[cur][0] = par;  // 루트면 par = 0
    // ...
}

int lca(int u, int v) {
    // ...
    if (parent[u][k] != parent[v][k]) {  // 0과 비교되어도 OK
        // ...
    }
}

깊이 비교 순서

1
2
3
4
5
6
7
8
// u가 더 깊도록 swap하는 것이 일반적
if (depth[u] < depth[v]) {
    int temp = u; u = v; v = temp;
}

// 또는 절댓값 사용
int diff = depth[u] - depth[v];
if (diff < 0) diff = -diff;  // 위험: 둘 다 올려야 함

희소 배열 순서

1
2
3
4
5
6
7
8
// k = 0인 경우를 먼저 채우고, k를 증가시키며 채움
for (int k = 1; k < LOG; k++) {  // k = 1부터 시작
    for (int v = 1; v <= n; v++) {
        if (parent[v][k - 1] != 0) {  // 조상이 존재할 때만
            parent[v][k] = parent[parent[v][k - 1]][k - 1];
        }
    }
}

같은 노드 쿼리

1
2
3
4
5
6
7
8
int lca(int u, int v) {
    if (u == v) return u;  // 바로 반환해도 됨

    // 또는 깊이 맞춘 후 체크
    // ...
    if (u == v) return u;  // 깊이 맞춘 후에도 체크
    // ...
}

재귀 깊이

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// 트리가 편향되면 재귀 깊이 초과 가능 (N > 10만)
// BFS로 전처리하거나 스택 크기 늘리기

// 반복적 DFS
void bfs(int root) {
    int queue[MAX_N], front = 0, rear = 0;
    queue[rear++] = root;
    parent[root][0] = 0;
    depth[root] = 0;

    while (front < rear) {
        int cur = queue[front++];
        for (int i = 0; i < adjCnt[cur]; i++) {
            int next = adj[cur][i];
            if (next != parent[cur][0]) {
                parent[next][0] = cur;
                depth[next] = depth[cur] + 1;
                queue[rear++] = next;
            }
        }
    }
}

추천 문제

난이도문제링크
GoldLCABOJ 11437
PlatinumLCA 2BOJ 11438
Platinum정점들의 거리BOJ 1761
Platinum도로 네트워크BOJ 3176
Platinum트리와 쿼리BOJ 15480
This post is licensed under CC BY 4.0 by the author.