Post

15. 트라이 (Trie)

15. 트라이 (Trie)

개념

트라이(Trie)는 문자열을 효율적으로 저장하고 검색하는 트리 자료구조이다. 접두사 트리(Prefix Tree)라고도 한다.

구조

1
2
3
4
5
6
7
8
9
10
        (root)
       /  |  \
      a   b   c
     /|       |
    p n       a
   /  |       |
  p   t       r
 /
l (apple)     (car)
   (ant)

특징

특징설명
접두사 공유공통 접두사를 가진 문자열은 경로 공유
검색 시간문자열 길이 L에 대해 O(L)
공간문자 집합 크기(Σ)에 비례

해시 테이블과 비교

항목트라이해시 테이블
검색O(L)O(L) 평균
접두사 검색O(L)O(N * L)
공간많음적음
정렬 순서유지없음

핵심 연산 & 시간복잡도

연산시간복잡도설명
삽입O(L)L = 문자열 길이
검색O(L)정확히 일치하는 문자열
접두사 검색O(L)접두사로 시작하는 문자열 존재 여부
삭제O(L)문자열 제거

구현 (No-STL)

배열 기반 트라이 (소문자만)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
const int MAX_NODE = 1000001;  // 최대 노드 수
const int ALPHA = 26;  // 알파벳 크기

int trie[MAX_NODE][ALPHA];
bool isEnd[MAX_NODE];  // 단어의 끝인지
int nodeCnt;

void init() {
    nodeCnt = 1;  // 0은 루트
    for (int i = 0; i < ALPHA; i++) {
        trie[0][i] = 0;
    }
    isEnd[0] = false;
}

int newNode() {
    for (int i = 0; i < ALPHA; i++) {
        trie[nodeCnt][i] = 0;
    }
    isEnd[nodeCnt] = false;
    return nodeCnt++;
}

void insert(const char* s) {
    int cur = 0;

    for (int i = 0; s[i]; i++) {
        int c = s[i] - 'a';

        if (trie[cur][c] == 0) {
            trie[cur][c] = newNode();
        }
        cur = trie[cur][c];
    }

    isEnd[cur] = true;
}

bool search(const char* s) {
    int cur = 0;

    for (int i = 0; s[i]; i++) {
        int c = s[i] - 'a';

        if (trie[cur][c] == 0) {
            return false;
        }
        cur = trie[cur][c];
    }

    return isEnd[cur];
}

bool startsWith(const char* prefix) {
    int cur = 0;

    for (int i = 0; prefix[i]; i++) {
        int c = prefix[i] - 'a';

        if (trie[cur][c] == 0) {
            return false;
        }
        cur = trie[cur][c];
    }

    return true;
}

단어 개수 저장

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
const int MAX_NODE = 1000001;
const int ALPHA = 26;

int trie[MAX_NODE][ALPHA];
int wordCount[MAX_NODE];  // 이 노드에서 끝나는 단어 수
int prefixCount[MAX_NODE];  // 이 노드를 지나는 단어 수
int nodeCnt;

void init() {
    nodeCnt = 1;
    for (int i = 0; i < ALPHA; i++) {
        trie[0][i] = 0;
    }
    wordCount[0] = 0;
    prefixCount[0] = 0;
}

int newNode() {
    for (int i = 0; i < ALPHA; i++) {
        trie[nodeCnt][i] = 0;
    }
    wordCount[nodeCnt] = 0;
    prefixCount[nodeCnt] = 0;
    return nodeCnt++;
}

void insert(const char* s) {
    int cur = 0;

    for (int i = 0; s[i]; i++) {
        int c = s[i] - 'a';
        if (trie[cur][c] == 0) {
            trie[cur][c] = newNode();
        }
        cur = trie[cur][c];
        prefixCount[cur]++;
    }

    wordCount[cur]++;
}

int countWord(const char* s) {
    int cur = 0;
    for (int i = 0; s[i]; i++) {
        int c = s[i] - 'a';
        if (trie[cur][c] == 0) return 0;
        cur = trie[cur][c];
    }
    return wordCount[cur];
}

int countPrefix(const char* prefix) {
    int cur = 0;
    for (int i = 0; prefix[i]; i++) {
        int c = prefix[i] - 'a';
        if (trie[cur][c] == 0) return 0;
        cur = trie[cur][c];
    }
    return prefixCount[cur];
}

숫자/비트 트라이 (XOR 최댓값)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
const int MAX_NODE = 3200001;  // 32비트 * 100000개

int trie[MAX_NODE][2];
int nodeCnt;

void init() {
    nodeCnt = 1;
    trie[0][0] = trie[0][1] = 0;
}

int newNode() {
    trie[nodeCnt][0] = trie[nodeCnt][1] = 0;
    return nodeCnt++;
}

void insert(int num) {
    int cur = 0;

    for (int i = 30; i >= 0; i--) {
        int bit = (num >> i) & 1;

        if (trie[cur][bit] == 0) {
            trie[cur][bit] = newNode();
        }
        cur = trie[cur][bit];
    }
}

int maxXor(int num) {
    int cur = 0;
    int result = 0;

    for (int i = 30; i >= 0; i--) {
        int bit = (num >> i) & 1;
        int want = 1 - bit;  // XOR 최대화를 위해 반대 비트

        if (trie[cur][want] != 0) {
            result |= (1 << i);
            cur = trie[cur][want];
        } else {
            cur = trie[cur][bit];
        }
    }

    return result;
}

구조체로 캡슐화

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
const int MAX_NODE = 1000001;
const int ALPHA = 26;

struct Trie {
    int next[MAX_NODE][ALPHA];
    bool isEnd[MAX_NODE];
    int nodeCnt;

    void init() {
        nodeCnt = 1;
        for (int i = 0; i < ALPHA; i++) next[0][i] = 0;
        isEnd[0] = false;
    }

    int newNode() {
        for (int i = 0; i < ALPHA; i++) next[nodeCnt][i] = 0;
        isEnd[nodeCnt] = false;
        return nodeCnt++;
    }

    void insert(const char* s) {
        int cur = 0;
        for (int i = 0; s[i]; i++) {
            int c = s[i] - 'a';
            if (next[cur][c] == 0) next[cur][c] = newNode();
            cur = next[cur][c];
        }
        isEnd[cur] = true;
    }

    bool search(const char* s) {
        int cur = 0;
        for (int i = 0; s[i]; i++) {
            int c = s[i] - 'a';
            if (next[cur][c] == 0) return false;
            cur = next[cur][c];
        }
        return isEnd[cur];
    }
};

STL

vector + map 기반

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <map>
#include <string>

struct TrieNode {
    std::map<char, TrieNode*> children;
    bool isEnd = false;
};

class Trie {
    TrieNode* root;

public:
    Trie() { root = new TrieNode(); }

    void insert(const std::string& word) {
        TrieNode* cur = root;
        for (char c : word) {
            if (cur->children.find(c) == cur->children.end()) {
                cur->children[c] = new TrieNode();
            }
            cur = cur->children[c];
        }
        cur->isEnd = true;
    }

    bool search(const std::string& word) {
        TrieNode* cur = root;
        for (char c : word) {
            if (cur->children.find(c) == cur->children.end()) {
                return false;
            }
            cur = cur->children[c];
        }
        return cur->isEnd;
    }

    bool startsWith(const std::string& prefix) {
        TrieNode* cur = root;
        for (char c : prefix) {
            if (cur->children.find(c) == cur->children.end()) {
                return false;
            }
            cur = cur->children[c];
        }
        return true;
    }
};

배열 기반 (빠름)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <cstring>

class Trie {
    static const int MAX_NODE = 1000001;
    static const int ALPHA = 26;

    int next[MAX_NODE][ALPHA];
    bool isEnd[MAX_NODE];
    int nodeCnt;

public:
    Trie() {
        nodeCnt = 1;
        memset(next[0], 0, sizeof(next[0]));
        isEnd[0] = false;
    }

    void insert(const std::string& s) {
        int cur = 0;
        for (char ch : s) {
            int c = ch - 'a';
            if (next[cur][c] == 0) {
                next[cur][c] = nodeCnt;
                memset(next[nodeCnt], 0, sizeof(next[nodeCnt]));
                isEnd[nodeCnt] = false;
                nodeCnt++;
            }
            cur = next[cur][c];
        }
        isEnd[cur] = true;
    }

    bool search(const std::string& s) {
        int cur = 0;
        for (char ch : s) {
            int c = ch - 'a';
            if (next[cur][c] == 0) return false;
            cur = next[cur][c];
        }
        return isEnd[cur];
    }
};

사용 예시

자동 완성

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 접두사로 시작하는 모든 단어 찾기
void autocomplete(int cur, char* buffer, int len, char results[][100], int& resultCnt) {
    if (isEnd[cur]) {
        buffer[len] = '\0';
        int i = 0;
        while (buffer[i]) {
            results[resultCnt][i] = buffer[i];
            i++;
        }
        results[resultCnt][i] = '\0';
        resultCnt++;
    }

    for (int c = 0; c < ALPHA; c++) {
        if (trie[cur][c] != 0) {
            buffer[len] = 'a' + c;
            autocomplete(trie[cur][c], buffer, len + 1, results, resultCnt);
        }
    }
}

void findByPrefix(const char* prefix, char results[][100], int& resultCnt) {
    int cur = 0;
    int len = 0;
    char buffer[100];

    for (int i = 0; prefix[i]; i++) {
        int c = prefix[i] - 'a';
        if (trie[cur][c] == 0) {
            resultCnt = 0;
            return;
        }
        buffer[len++] = prefix[i];
        cur = trie[cur][c];
    }

    resultCnt = 0;
    autocomplete(cur, buffer, len, results, resultCnt);
}

최장 공통 접두사

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int longestCommonPrefix() {
    int cur = 0;
    int length = 0;

    while (true) {
        int childCount = 0;
        int nextNode = -1;

        for (int c = 0; c < ALPHA; c++) {
            if (trie[cur][c] != 0) {
                childCount++;
                nextNode = trie[cur][c];
            }
        }

        // 자식이 하나이고, 현재 노드가 단어의 끝이 아닐 때만 계속
        if (childCount == 1 && !isEnd[cur]) {
            length++;
            cur = nextNode;
        } else {
            break;
        }
    }

    return length;
}

XOR 최댓값 쿼리

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 배열에서 arr[i] XOR arr[j]의 최댓값

int maxXorPair(int arr[], int n) {
    init();

    for (int i = 0; i < n; i++) {
        insert(arr[i]);
    }

    int maxVal = 0;
    for (int i = 0; i < n; i++) {
        int val = maxXor(arr[i]);
        if (val > maxVal) maxVal = val;
    }

    return maxVal;
}

사전식 순서 K번째 문자열

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
int countSubtree[MAX_NODE];  // 서브트리에 있는 단어 수

void calcCount(int cur) {
    countSubtree[cur] = wordCount[cur];
    for (int c = 0; c < ALPHA; c++) {
        if (trie[cur][c] != 0) {
            calcCount(trie[cur][c]);
            countSubtree[cur] += countSubtree[trie[cur][c]];
        }
    }
}

void kthWord(int cur, int k, char* result, int len) {
    if (k <= wordCount[cur]) {
        result[len] = '\0';
        return;
    }
    k -= wordCount[cur];

    for (int c = 0; c < ALPHA; c++) {
        if (trie[cur][c] != 0) {
            if (countSubtree[trie[cur][c]] >= k) {
                result[len] = 'a' + c;
                kthWord(trie[cur][c], k, result, len + 1);
                return;
            }
            k -= countSubtree[trie[cur][c]];
        }
    }
}

주의사항 / Edge Cases

노드 수 계산

1
2
3
4
5
// 최대 노드 수 = 문자열 총 길이
// N개 문자열, 평균 길이 L → 최대 N * L 노드

// 예: 10만 개 문자열, 최대 길이 10
// MAX_NODE = 100000 * 10 = 1000000

메모리 최적화

1
2
3
4
5
6
7
// 배열 대신 map 사용 (노드당 메모리 절약)
// 단, 접근 시간 O(log Σ) → O(1)보다 느림

struct TrieNode {
    std::map<int, int> children;  // 동적 할당
    bool isEnd;
};

빈 문자열

1
2
3
4
5
6
7
8
9
// 빈 문자열 삽입 시 루트가 isEnd가 됨
void insert(const char* s) {
    int cur = 0;
    // s가 빈 문자열이면 루프 실행 안 됨
    for (int i = 0; s[i]; i++) {
        // ...
    }
    isEnd[cur] = true;  // cur = 0 (루트)
}

대소문자 혼합

1
2
3
4
5
6
7
// 대소문자 모두 처리
const int ALPHA = 52;  // a-z, A-Z

int charToIndex(char c) {
    if (c >= 'a' && c <= 'z') return c - 'a';
    return c - 'A' + 26;
}

삭제 연산

1
2
3
4
5
6
7
8
9
10
11
12
13
// 단순 삭제: isEnd만 false로
void erase(const char* s) {
    int cur = 0;
    for (int i = 0; s[i]; i++) {
        int c = s[i] - 'a';
        if (trie[cur][c] == 0) return;
        cur = trie[cur][c];
    }
    isEnd[cur] = false;
}

// 노드까지 삭제하려면 더 복잡한 구현 필요
// (prefixCount 사용하여 리프 노드 정리)

추천 문제

난이도문제링크
Silver문자열 집합BOJ 14425
Gold전화번호 목록BOJ 5052
Gold개미굴BOJ 14725
Platinum휴대폰 자판BOJ 5670
PlatinumXOR 합BOJ 13505
This post is licensed under CC BY 4.0 by the author.