티스토리 뷰
KD Tree : 일반적인 바이너리 트리를 k 차원으로 확장한 트리
데이터 타입이 n 차원 벡터일 때 (x1, ..., xn)
x1을 기준으로 트리 구성, x2를 기준으로 트리 구성, ..., xn을 기준으로 트리 구성, x1을 기준으로 트리 구성, ... 을 반복하여 트리를 만드는 것
ex) 2차원에 대한 KD Tree 만들기
input : (5, 3), (5, 4), (3, 9), (4, 6), (6, 2), (1, 1), (4, 1), (6, 7)
root : (5, 3)
left child : xi 차원의 값을 기준으로 작은 것, right child : xi 차원의 값을 기준으로 이상인 것
위와 같이 차원을 분리하여 각 차원에 대한 대소 비교를 통해 트리를 구성하는 것이 KD Tree
KD Tree 응용 분야 :
클러스터링, 그래픽스, ....
ex) DBSCAN 알고리즘
KD Tree 적용 분야 :
kNN 알고리즘, nearest neighbor 찾기, radius neighbor 찾기, .....
즉, 현재 주어진 모든 점과 거리를 비교하여 조건을 만족하는 특정 점을 찾는 것이 아닌, 차원을 분리하여 일부 점에 대해서만 조건을 만족하는 지 여부를 판단하는 곳에 사용
하지만 2,3차원 벡터로 이루어진 데이터 타입이 아닌, 차원이 매우 큰 경우, 각 차원마다 모두 탐색하여야 하기 때문에 매우 비효율적일 수 있음
(이에 대한 대안 : ball tree)
KD Tree와 Ball Tree에 N개의 점이 있고, K 차원으로 각 점이 이루어져 있을 때,
가장 가까운 점을 찾는 방법의 복잡도는 다음과 같다.
KD Tree : 평균 : O(log N), 최악 : O(N^(1-1/K))
최악의 경우 : K가 매우 클 때, 데이터가 특정 차원에 몰려 있을 때
(크다의 기준은 애매하지만 보통 10차원 이상이면 크다고 보는 듯)
Ball Tree : 평균 : O(log N), 최악 : O(N)
최악의 경우 : 점이 매우 불균일하게 있을 때, 즉 클러스터링이 전혀 되지 않을 때 (점 1개당 클러스터 1개)
Ball Tree는 차원이 크더라도 데이터가 특정 차원에 몰려 있지 않다면, 즉, 차원이 크더라도 비교적 일반적으로 클러스터링이 이루어진다면 탐색에 평균적으로 O(log N)
KD Tree C++ 구현
1. 노드 생성 및 KD Tree 구성
KD Tree 생성 시 다음과 같이 선언 후 이후 insert로 노드를 삽입한다.
- Node *root = nullptr;
struct Node {
int points[K];
Node *left;
Node *right;
};
Node* createNode(int points[])
{
if((sizeof(points) / sizeof(int)) != K) {
cout << "Error : input size is not " << K << "\n";
return nullptr;
}
Node *node = new Node;
for(int i = 0; i < K; i++)
node->points[i] = points[i];
node->left = nullptr;
node->right = nullptr;
return node;
}
2. 노드 삽입
노드 삽입은 다음과 같다.
- x 차원 기준인 경우, x 차원 값 대소 비교하여 탐색
- y 차원 기준인 경우, y 차원 값 대소 비교하여 탐색
- null일 때까지 recursive하게 탐색
- null인 경우 해당 위치에 새 노드 삽입
Node* insertNode(Node *root, int points[], int depth)
{
if(root == nullptr)
return createNode(points);
int cd = depth % K;
// 축 기준 작은 쪽 (왼쪽 서브 트리)
if(points[cd] < root->points[cd])
root->left = insertNode(root->left, points, depth+1);
// 축 기준 큰 쪽 (오른쪽 서브 트리)
else
root->right = insertNode(root->right, points, depth+1);
return root;
}
3. 노드 삭제
노드 삭제는 다음과 같다.
- 삭제할 노드 찾기
- 삭제할 게 리프인 경우, 단순 삭제 후 종료
- 삭제할 노드의 기준이 x축이라면, x축 기준으로 큰 것 중 가장 작은 값
- y축이라면, y축 기준으로 큰 것 중 가장 작은 값
- 즉, 오른쪽 서브 트리 중 해당 축 값 가장 작은 값
- 오른쪽 노드가 없다면 왼쪽 노드 중 해당 축의 값 중 가장 작은 값
- 해당 노드를 삭제한 노드 위치로 이동
- 탐색했던 나머지를 오른쪽 서브 트리로 구성
- 해당 노드를 삭제로 보고 위 내용을 반복
bool eqPoints(int points1[], int points2[])
{
bool res = true;
for(int i = 0; i < K; i++) {
if(points1[i] != points2[i]) {
res = false;
break;
}
}
return res;
}
void copyPoints(int points1[], int points2[])
{
for(int i = 0; i < K; i++)
points1[i] = points2[i];
}
Node* compareMinNode(Node *x, Node *y, Node *z, int d)
{
// x,y,z 노드 중 d 축 기준 가장 작은값 갖는 노드 리턴
Node *res = x;
if(y != nullptr && y->points[d] < res->points[d])
res = y;
if(z != nullptr && z->points[d] < res->points[d])
res = z;
return res;
}
Node* findMin(Node *root, int d, int depth)
{
// d : 현재 탐색 중인 root의 기준 축
if(root == nullptr)
return nullptr;
int cd = depth % K;
if(cd == d) {
// 현재 탐색 중인 root가 가장 작은 값 가짐
if(root->left == nullptr)
return root;
return findMin(root->left, d, depth+1);
}
Node *leftMinNode = findMin(root->left, d, depth+1);
Node *rightMinNode = findMin(root->right, d, depth+1);
Node *minNode = compareMinNode(root, leftMinNode, rightMinNode, d);
return minNode;
}
Node* deleteNode(Node *root, int points[], int depth)
{
if(root == nullptr)
return nullptr;
int cd = depth % K;
// 삭제할 노드 찾음
if(eqPoints(root->points, points)) {
// 해당 차원의 값이 큰 것 중 가장 작은 것
// 즉, 오른쪽 서브 트리 중 가장 작은 것
// 없으면 왼쪽 서브 트리를 기준으로 탐색 후, 오른쪽 서브 트리로 변경
if(root->right != nullptr) {
Node *minNode = findMin(root->right, cd, 0);
// 해당 값 복사
copyPoints(root->points, minNode->points);
// 옮겨진 노드도 삭제로 보고 이를 반복
root->right = deleteNode(root->right, minNode->points, depth+1);
}
else if(root->left != nullptr) {
Node *minNode = findMin(root->left, cd, 0);
copyPoints(root->points, minNode->points);
// 삭제된 노드의 왼쪽 서브트리를 오른쪽 서브트리로 구성
root->right = deleteNode(root->left, minNode->points, depth+1);
// 왼쪽 서브트리는 null로 지정
root->left = nullptr;
}
else {
// 리프인 경우
delete root;
return nullptr;
}
return root;
}
// 아직 못 찾음
if(points[cd] < root->points[cd])
root->left = deleteNode(root->left, points, depth+1);
else
root->right = deleteNode(root->right, points, depth+1);
return root;
}
4. 가장 가까운 노드 찾기
- 거리 기준에 따라 가장 가까운 노드 찾기
- 현재 탐색 차원의 값과 비교하며 거리 가장 가까운 노드로 업데이트
- 아래는 manhattan distance 기준으로 가장 가까운 노드 찾는 케이스
int distance(int points1[], int points2[])
{
// manhattan distance
int res = 0;
for(int i = 0; i < K; i++)
res += abs(points1[i] - points2[i]);
return res;
}
void nearest(Node *root, int target[], Node **best, int &bestDist, int depth)
{
// 최근접 노드 계산
if(root == nullptr)
return;
int dist = distance(root->points, target);
if(dist < bestDist) {
bestDist = dist;
*best = root;
}
int cd = depth % K;
Node *good = nullptr;
if(target[cd] < root->points[cd])
good = root->left;
else
good = root->right;
Node *bad = nullptr;
if(target[cd] < root->points[cd])
bad = root->right;
else
bad = root->left;
nearest(good, target, best, bestDist, depth+1);
// 현재 차원에 대한 확인
if(abs(target[cd] - root->points[cd]) < bestDist)
nearest(bad, target, best, bestDist, depth+1);
}
pair<Node*, int> findNearest(Node *root, int target[])
{
Node *best = nullptr;
int bestDist = 987654321;
int depth = 0;
nearest(root, target, &best, bestDist, depth);
return {best, bestDist};
}
5. 특정 거리 이내의 모든 노드 찾기
- 거리 차이가 radius 이내인 경우, 왼쪽 서브트리, 오른쪽 서브트리 순으로 탐색
- 아닌 경우, 오른쪽 서브트리, 왼쪽 서브트리 순으로 탐색
- 아래는 manhattan distance 기준으로 가장 가까운 노드 찾는 케이스
void findWithinDistance(Node *root, int target[], int radius, int depth, vector<Node*> &res)
{
if(root == nullptr)
return;
int dist = distance(root->points, target);
if(dist <= radius)
res.push_back(root);
int cd = depth % K;
if(abs(target[cd] - root->points[cd]) <= radius) {
findWithinDistance(root->left, target, radius, depth+1, res);
findWithinDistance(root->right, target, radius, depth+1, res);
}
else {
findWithinDistance(root->right, target, radius, depth+1, res);
findWithinDistance(root->left, target, radius, depth+1, res);
}
}
6. 메모리 해제
KD Tree의 각 노드의 오른, 왼쪽 하위 노드에 대해 null이 아닌 경우 recursive하게 탐색하며 메모리를 해제한다.
void deleteKDTree(Node *root)
{
if(root == nullptr)
return;
if(root->left != nullptr)
deleteKDTree(root->left);
if(root->right != nullptr)
deleteKDTree(root->right);
delete root;
return;
}
전체 코드
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int K = 2;
struct Node {
int points[K];
Node *left;
Node *right;
};
Node* createNode(int points[])
{
if((sizeof(points) / sizeof(int)) != K) {
cout << "Error : input size is not " << K << "\n";
return nullptr;
}
Node *node = new Node;
for(int i = 0; i < K; i++)
node->points[i] = points[i];
node->left = nullptr;
node->right = nullptr;
return node;
}
Node* insertNode(Node *root, int points[], int depth)
{
if(root == nullptr)
return createNode(points);
int cd = depth % K;
// 축 기준 작은 쪽 (왼쪽 서브 트리)
if(points[cd] < root->points[cd])
root->left = insertNode(root->left, points, depth+1);
// 축 기준 큰 쪽 (오른쪽 서브 트리)
else
root->right = insertNode(root->right, points, depth+1);
return root;
}
bool eqPoints(int points1[], int points2[])
{
bool res = true;
for(int i = 0; i < K; i++) {
if(points1[i] != points2[i]) {
res = false;
break;
}
}
return res;
}
void copyPoints(int points1[], int points2[])
{
for(int i = 0; i < K; i++)
points1[i] = points2[i];
}
Node* compareMinNode(Node *x, Node *y, Node *z, int d)
{
// x,y,z 노드 중 d 축 기준 가장 작은값 갖는 노드 리턴
Node *res = x;
if(y != nullptr && y->points[d] < res->points[d])
res = y;
if(z != nullptr && z->points[d] < res->points[d])
res = z;
return res;
}
Node* findMin(Node *root, int d, int depth)
{
// d : 현재 탐색 중인 root의 기준 축
if(root == nullptr)
return nullptr;
int cd = depth % K;
if(cd == d) {
// 현재 탐색 중인 root가 가장 작은 값 가짐
if(root->left == nullptr)
return root;
return findMin(root->left, d, depth+1);
}
Node *leftMinNode = findMin(root->left, d, depth+1);
Node *rightMinNode = findMin(root->right, d, depth+1);
Node *minNode = compareMinNode(root, leftMinNode, rightMinNode, d);
return minNode;
}
Node* deleteNode(Node *root, int points[], int depth)
{
if(root == nullptr)
return nullptr;
int cd = depth % K;
// 삭제할 노드 찾음
if(eqPoints(root->points, points)) {
// 해당 차원의 값이 큰 것 중 가장 작은 것
// 즉, 오른쪽 서브 트리 중 가장 작은 것
// 없으면 왼쪽 서브 트리를 기준으로 탐색 후, 오른쪽 서브 트리로 변경
if(root->right != nullptr) {
Node *minNode = findMin(root->right, cd, 0);
// 해당 값 복사
copyPoints(root->points, minNode->points);
// 옮겨진 노드도 삭제로 보고 이를 반복
root->right = deleteNode(root->right, minNode->points, depth+1);
}
else if(root->left != nullptr) {
Node *minNode = findMin(root->left, cd, 0);
copyPoints(root->points, minNode->points);
// 삭제된 노드의 왼쪽 서브트리를 오른쪽 서브트리로 구성
root->right = deleteNode(root->left, minNode->points, depth+1);
// 왼쪽 서브트리는 null로 지정
root->left = nullptr;
}
else {
// 리프인 경우
delete root;
return nullptr;
}
return root;
}
// 아직 못 찾음
if(points[cd] < root->points[cd])
root->left = deleteNode(root->left, points, depth+1);
else
root->right = deleteNode(root->right, points, depth+1);
return root;
}
void printKDTree(Node *root, int space, int depth)
{
if(root == nullptr)
return;
space += 10;
printKDTree(root->right, space, depth+1);
cout << "\n";
for(int i = 10; i < space; i++)
cout << " ";
for(int i = 0; i < K; i++)
cout << root->points[i] << ",";
printKDTree(root->left, space, depth+1);
}
int distance(int points1[], int points2[])
{
// manhattan distance
int res = 0;
for(int i = 0; i < K; i++)
res += abs(points1[i] - points2[i]);
return res;
}
void nearest(Node *root, int target[], Node **best, int &bestDist, int depth)
{
// 최근접 노드 계산
if(root == nullptr)
return;
int dist = distance(root->points, target);
if(dist < bestDist) {
bestDist = dist;
*best = root;
}
int cd = depth % K;
Node *good = nullptr;
if(target[cd] < root->points[cd])
good = root->left;
else
good = root->right;
Node *bad = nullptr;
if(target[cd] < root->points[cd])
bad = root->right;
else
bad = root->left;
nearest(good, target, best, bestDist, depth+1);
// 현재 차원에 대한 확인
if(abs(target[cd] - root->points[cd]) < bestDist)
nearest(bad, target, best, bestDist, depth+1);
}
pair<Node*, int> findNearest(Node *root, int target[])
{
Node *best = nullptr;
int bestDist = 987654321;
int depth = 0;
nearest(root, target, &best, bestDist, depth);
return {best, bestDist};
}
void findWithinDistance(Node *root, int target[], int radius, int depth, vector<Node*> &res)
{
if(root == nullptr)
return;
int dist = distance(root->points, target);
if(dist <= radius)
res.push_back(root);
int cd = depth % K;
if(abs(target[cd] - root->points[cd]) <= radius) {
findWithinDistance(root->left, target, radius, depth+1, res);
findWithinDistance(root->right, target, radius, depth+1, res);
}
else {
findWithinDistance(root->right, target, radius, depth+1, res);
findWithinDistance(root->left, target, radius, depth+1, res);
}
}
void deleteKDTree(Node *root)
{
if(root == nullptr)
return;
if(root->left != nullptr)
deleteKDTree(root->left);
if(root->right != nullptr)
deleteKDTree(root->right);
delete root;
return;
}
void printFormat()
{
cout << " ^\n";
cout << " |\n";
cout << "right\n";
cout << " |\n";
cout << "root\n";
cout << " |\n";
cout << "left\n";
cout << " |\n";
cout << " v\n";
return;
}
int main()
{
printFormat();
Node *root = nullptr;
int depth = 0;
int points[8][2] = {
{5,3},{5,4},{3,9},{4,6},{6,2},{1,1},{4,1},{6,7}
};
for(int i = 0; i < 8; i++) {
root = insertNode(root, points[i], depth);
}
int space = 0;
printKDTree(root, space, depth);
cout << "\n=====================\n";
cout << "delete : (3,9)\n";
root = deleteNode(root, points[2], depth);
printKDTree(root, space, depth);
cout << "\n=====================\n";
cout << "delete : (5,4)\n";
root = deleteNode(root, points[1], depth);
printKDTree(root, space, depth);
cout << "\n=====================\n";
cout << "nearest target : (5,8)\n";
int target[2] = {5,8};
pair<Node*, int> best = findNearest(root, target);
for(int i = 0; i < K; i++)
cout << (best.first)->points[i] << ",";
cout << "\n";
cout << "dist : " << best.second << "\n";
cout << "\n=====================\n";
cout << "within radius : 3, target : (4,5)\n";
int withinTarget[2] = {4,5};
int radius = 3;
vector<Node*> res;
findWithinDistance(root, withinTarget, radius, depth, res);
for(auto node : res) {
for(int i = 0; i < K; i++) {
cout << node->points[i] << ",";
}
cout << "\n";
}
deleteKDTree(root);
}
결과
'알고리즘' 카테고리의 다른 글
KMP 알고리즘 정리 (1) | 2024.05.01 |
---|---|
Trie 정리 (1) | 2024.05.01 |
C++ vector, priority_queue 정렬 정리 (1) | 2024.04.21 |
MST 구하기 정리 C++ (0) | 2024.04.21 |
유니온 파인드 정리 C++ (0) | 2024.04.21 |