15. 트라이 (Trie)
15. 트라이 (Trie)
개념
트라이(Trie)는 문자열을 효율적으로 저장하고 검색하는 트리 자료구조이다. 접두사 트리(Prefix Tree)라고도 한다.
구조
1
2
3
4
5
6
7
8
9
10
(root)
/ | \
a b c
/| |
p n a
/ | |
p t r
/
l (apple) (car)
(ant)
특징
| 특징 | 설명 |
|---|---|
| 접두사 공유 | 공통 접두사를 가진 문자열은 경로 공유 |
| 검색 시간 | 문자열 길이 L에 대해 O(L) |
| 공간 | 문자 집합 크기(Σ)에 비례 |
해시 테이블과 비교
| 항목 | 트라이 | 해시 테이블 |
|---|---|---|
| 검색 | O(L) | O(L) 평균 |
| 접두사 검색 | O(L) | O(N * L) |
| 공간 | 많음 | 적음 |
| 정렬 순서 | 유지 | 없음 |
핵심 연산 & 시간복잡도
| 연산 | 시간복잡도 | 설명 |
|---|---|---|
| 삽입 | O(L) | L = 문자열 길이 |
| 검색 | O(L) | 정확히 일치하는 문자열 |
| 접두사 검색 | O(L) | 접두사로 시작하는 문자열 존재 여부 |
| 삭제 | O(L) | 문자열 제거 |
구현 (No-STL)
배열 기반 트라이 (소문자만)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
const int MAX_NODE = 1000001; // 최대 노드 수
const int ALPHA = 26; // 알파벳 크기
int trie[MAX_NODE][ALPHA];
bool isEnd[MAX_NODE]; // 단어의 끝인지
int nodeCnt;
void init() {
nodeCnt = 1; // 0은 루트
for (int i = 0; i < ALPHA; i++) {
trie[0][i] = 0;
}
isEnd[0] = false;
}
int newNode() {
for (int i = 0; i < ALPHA; i++) {
trie[nodeCnt][i] = 0;
}
isEnd[nodeCnt] = false;
return nodeCnt++;
}
void insert(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (trie[cur][c] == 0) {
trie[cur][c] = newNode();
}
cur = trie[cur][c];
}
isEnd[cur] = true;
}
bool search(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (trie[cur][c] == 0) {
return false;
}
cur = trie[cur][c];
}
return isEnd[cur];
}
bool startsWith(const char* prefix) {
int cur = 0;
for (int i = 0; prefix[i]; i++) {
int c = prefix[i] - 'a';
if (trie[cur][c] == 0) {
return false;
}
cur = trie[cur][c];
}
return true;
}
단어 개수 저장
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
const int MAX_NODE = 1000001;
const int ALPHA = 26;
int trie[MAX_NODE][ALPHA];
int wordCount[MAX_NODE]; // 이 노드에서 끝나는 단어 수
int prefixCount[MAX_NODE]; // 이 노드를 지나는 단어 수
int nodeCnt;
void init() {
nodeCnt = 1;
for (int i = 0; i < ALPHA; i++) {
trie[0][i] = 0;
}
wordCount[0] = 0;
prefixCount[0] = 0;
}
int newNode() {
for (int i = 0; i < ALPHA; i++) {
trie[nodeCnt][i] = 0;
}
wordCount[nodeCnt] = 0;
prefixCount[nodeCnt] = 0;
return nodeCnt++;
}
void insert(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (trie[cur][c] == 0) {
trie[cur][c] = newNode();
}
cur = trie[cur][c];
prefixCount[cur]++;
}
wordCount[cur]++;
}
int countWord(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (trie[cur][c] == 0) return 0;
cur = trie[cur][c];
}
return wordCount[cur];
}
int countPrefix(const char* prefix) {
int cur = 0;
for (int i = 0; prefix[i]; i++) {
int c = prefix[i] - 'a';
if (trie[cur][c] == 0) return 0;
cur = trie[cur][c];
}
return prefixCount[cur];
}
숫자/비트 트라이 (XOR 최댓값)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
const int MAX_NODE = 3200001; // 32비트 * 100000개
int trie[MAX_NODE][2];
int nodeCnt;
void init() {
nodeCnt = 1;
trie[0][0] = trie[0][1] = 0;
}
int newNode() {
trie[nodeCnt][0] = trie[nodeCnt][1] = 0;
return nodeCnt++;
}
void insert(int num) {
int cur = 0;
for (int i = 30; i >= 0; i--) {
int bit = (num >> i) & 1;
if (trie[cur][bit] == 0) {
trie[cur][bit] = newNode();
}
cur = trie[cur][bit];
}
}
int maxXor(int num) {
int cur = 0;
int result = 0;
for (int i = 30; i >= 0; i--) {
int bit = (num >> i) & 1;
int want = 1 - bit; // XOR 최대화를 위해 반대 비트
if (trie[cur][want] != 0) {
result |= (1 << i);
cur = trie[cur][want];
} else {
cur = trie[cur][bit];
}
}
return result;
}
구조체로 캡슐화
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
const int MAX_NODE = 1000001;
const int ALPHA = 26;
struct Trie {
int next[MAX_NODE][ALPHA];
bool isEnd[MAX_NODE];
int nodeCnt;
void init() {
nodeCnt = 1;
for (int i = 0; i < ALPHA; i++) next[0][i] = 0;
isEnd[0] = false;
}
int newNode() {
for (int i = 0; i < ALPHA; i++) next[nodeCnt][i] = 0;
isEnd[nodeCnt] = false;
return nodeCnt++;
}
void insert(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (next[cur][c] == 0) next[cur][c] = newNode();
cur = next[cur][c];
}
isEnd[cur] = true;
}
bool search(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (next[cur][c] == 0) return false;
cur = next[cur][c];
}
return isEnd[cur];
}
};
STL
vector + map 기반
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <map>
#include <string>
struct TrieNode {
std::map<char, TrieNode*> children;
bool isEnd = false;
};
class Trie {
TrieNode* root;
public:
Trie() { root = new TrieNode(); }
void insert(const std::string& word) {
TrieNode* cur = root;
for (char c : word) {
if (cur->children.find(c) == cur->children.end()) {
cur->children[c] = new TrieNode();
}
cur = cur->children[c];
}
cur->isEnd = true;
}
bool search(const std::string& word) {
TrieNode* cur = root;
for (char c : word) {
if (cur->children.find(c) == cur->children.end()) {
return false;
}
cur = cur->children[c];
}
return cur->isEnd;
}
bool startsWith(const std::string& prefix) {
TrieNode* cur = root;
for (char c : prefix) {
if (cur->children.find(c) == cur->children.end()) {
return false;
}
cur = cur->children[c];
}
return true;
}
};
배열 기반 (빠름)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <cstring>
class Trie {
static const int MAX_NODE = 1000001;
static const int ALPHA = 26;
int next[MAX_NODE][ALPHA];
bool isEnd[MAX_NODE];
int nodeCnt;
public:
Trie() {
nodeCnt = 1;
memset(next[0], 0, sizeof(next[0]));
isEnd[0] = false;
}
void insert(const std::string& s) {
int cur = 0;
for (char ch : s) {
int c = ch - 'a';
if (next[cur][c] == 0) {
next[cur][c] = nodeCnt;
memset(next[nodeCnt], 0, sizeof(next[nodeCnt]));
isEnd[nodeCnt] = false;
nodeCnt++;
}
cur = next[cur][c];
}
isEnd[cur] = true;
}
bool search(const std::string& s) {
int cur = 0;
for (char ch : s) {
int c = ch - 'a';
if (next[cur][c] == 0) return false;
cur = next[cur][c];
}
return isEnd[cur];
}
};
사용 예시
자동 완성
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 접두사로 시작하는 모든 단어 찾기
void autocomplete(int cur, char* buffer, int len, char results[][100], int& resultCnt) {
if (isEnd[cur]) {
buffer[len] = '\0';
int i = 0;
while (buffer[i]) {
results[resultCnt][i] = buffer[i];
i++;
}
results[resultCnt][i] = '\0';
resultCnt++;
}
for (int c = 0; c < ALPHA; c++) {
if (trie[cur][c] != 0) {
buffer[len] = 'a' + c;
autocomplete(trie[cur][c], buffer, len + 1, results, resultCnt);
}
}
}
void findByPrefix(const char* prefix, char results[][100], int& resultCnt) {
int cur = 0;
int len = 0;
char buffer[100];
for (int i = 0; prefix[i]; i++) {
int c = prefix[i] - 'a';
if (trie[cur][c] == 0) {
resultCnt = 0;
return;
}
buffer[len++] = prefix[i];
cur = trie[cur][c];
}
resultCnt = 0;
autocomplete(cur, buffer, len, results, resultCnt);
}
최장 공통 접두사
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int longestCommonPrefix() {
int cur = 0;
int length = 0;
while (true) {
int childCount = 0;
int nextNode = -1;
for (int c = 0; c < ALPHA; c++) {
if (trie[cur][c] != 0) {
childCount++;
nextNode = trie[cur][c];
}
}
// 자식이 하나이고, 현재 노드가 단어의 끝이 아닐 때만 계속
if (childCount == 1 && !isEnd[cur]) {
length++;
cur = nextNode;
} else {
break;
}
}
return length;
}
XOR 최댓값 쿼리
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 배열에서 arr[i] XOR arr[j]의 최댓값
int maxXorPair(int arr[], int n) {
init();
for (int i = 0; i < n; i++) {
insert(arr[i]);
}
int maxVal = 0;
for (int i = 0; i < n; i++) {
int val = maxXor(arr[i]);
if (val > maxVal) maxVal = val;
}
return maxVal;
}
사전식 순서 K번째 문자열
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
int countSubtree[MAX_NODE]; // 서브트리에 있는 단어 수
void calcCount(int cur) {
countSubtree[cur] = wordCount[cur];
for (int c = 0; c < ALPHA; c++) {
if (trie[cur][c] != 0) {
calcCount(trie[cur][c]);
countSubtree[cur] += countSubtree[trie[cur][c]];
}
}
}
void kthWord(int cur, int k, char* result, int len) {
if (k <= wordCount[cur]) {
result[len] = '\0';
return;
}
k -= wordCount[cur];
for (int c = 0; c < ALPHA; c++) {
if (trie[cur][c] != 0) {
if (countSubtree[trie[cur][c]] >= k) {
result[len] = 'a' + c;
kthWord(trie[cur][c], k, result, len + 1);
return;
}
k -= countSubtree[trie[cur][c]];
}
}
}
주의사항 / Edge Cases
노드 수 계산
1
2
3
4
5
// 최대 노드 수 = 문자열 총 길이
// N개 문자열, 평균 길이 L → 최대 N * L 노드
// 예: 10만 개 문자열, 최대 길이 10
// MAX_NODE = 100000 * 10 = 1000000
메모리 최적화
1
2
3
4
5
6
7
// 배열 대신 map 사용 (노드당 메모리 절약)
// 단, 접근 시간 O(log Σ) → O(1)보다 느림
struct TrieNode {
std::map<int, int> children; // 동적 할당
bool isEnd;
};
빈 문자열
1
2
3
4
5
6
7
8
9
// 빈 문자열 삽입 시 루트가 isEnd가 됨
void insert(const char* s) {
int cur = 0;
// s가 빈 문자열이면 루프 실행 안 됨
for (int i = 0; s[i]; i++) {
// ...
}
isEnd[cur] = true; // cur = 0 (루트)
}
대소문자 혼합
1
2
3
4
5
6
7
// 대소문자 모두 처리
const int ALPHA = 52; // a-z, A-Z
int charToIndex(char c) {
if (c >= 'a' && c <= 'z') return c - 'a';
return c - 'A' + 26;
}
삭제 연산
1
2
3
4
5
6
7
8
9
10
11
12
13
// 단순 삭제: isEnd만 false로
void erase(const char* s) {
int cur = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (trie[cur][c] == 0) return;
cur = trie[cur][c];
}
isEnd[cur] = false;
}
// 노드까지 삭제하려면 더 복잡한 구현 필요
// (prefixCount 사용하여 리프 노드 정리)
추천 문제
This post is licensed under CC BY 4.0 by the author.