티스토리 뷰

알고리즘

KD Tree 정리

4567은 소수 2024. 5. 1. 20:25

KD Tree : 일반적인 바이너리 트리를 k 차원으로 확장한 트리

 

데이터 타입이 n 차원 벡터일 때 (x1, ..., xn)

x1을 기준으로 트리 구성, x2를 기준으로 트리 구성, ..., xn을 기준으로 트리 구성, x1을 기준으로 트리 구성, ... 을 반복하여 트리를 만드는 것

 

ex) 2차원에 대한 KD Tree 만들기

input : (5, 3), (5, 4), (3, 9), (4, 6), (6, 2), (1, 1), (4, 1), (6, 7)

root : (5, 3)

left child : xi 차원의 값을 기준으로 작은 것, right child : xi 차원의 값을 기준으로 이상인 것

 

 

위와 같이 차원을 분리하여 각 차원에 대한 대소 비교를 통해 트리를 구성하는 것이 KD Tree 

 

KD Tree 응용 분야 : 

클러스터링, 그래픽스, ....

ex) DBSCAN 알고리즘 

 

KD Tree 적용 분야 : 

kNN 알고리즘, nearest neighbor 찾기, radius neighbor 찾기, .....

 

즉, 현재 주어진 모든 점과 거리를 비교하여 조건을 만족하는 특정 점을 찾는 것이 아닌, 차원을 분리하여 일부 점에 대해서만 조건을 만족하는 지 여부를 판단하는 곳에 사용

 

하지만 2,3차원 벡터로 이루어진 데이터 타입이 아닌, 차원이 매우 큰 경우, 각 차원마다 모두 탐색하여야 하기 때문에 매우 비효율적일 수 있음 

(이에 대한 대안 : ball tree)

 

KD Tree와 Ball Tree에 N개의 점이 있고, K 차원으로 각 점이 이루어져 있을 때,

가장 가까운 점을 찾는 방법의 복잡도는 다음과 같다. 

 

KD Tree : 평균 : O(log N), 최악 : O(N^(1-1/K)) 

최악의 경우 : K가 매우 클 때, 데이터가 특정 차원에 몰려 있을 때

(크다의 기준은 애매하지만 보통 10차원 이상이면 크다고 보는 듯)

 

Ball Tree : 평균 : O(log N), 최악 : O(N)

최악의 경우 : 점이 매우 불균일하게 있을 때, 즉 클러스터링이 전혀 되지 않을 때 (점 1개당 클러스터 1개)

Ball Tree는 차원이 크더라도 데이터가 특정 차원에 몰려 있지 않다면, 즉, 차원이 크더라도 비교적 일반적으로 클러스터링이 이루어진다면 탐색에 평균적으로 O(log N) 


 

KD Tree C++ 구현

 

1. 노드 생성 및 KD Tree 구성

KD Tree 생성 시 다음과 같이 선언 후 이후 insert로 노드를 삽입한다.

- Node *root = nullptr;

struct Node {
	int points[K];
	Node *left;
	Node *right;
};

Node* createNode(int points[])
{
	if((sizeof(points) / sizeof(int)) != K) {
		cout << "Error : input size is not " << K << "\n";
		return nullptr;	
	}
	
	Node *node = new Node;
	for(int i = 0; i < K; i++)
		node->points[i] = points[i];
	
	node->left = nullptr;
	node->right = nullptr;
	return node;
}

 

2. 노드 삽입

노드 삽입은 다음과 같다.

- x 차원 기준인 경우, x 차원 값 대소 비교하여 탐색

- y 차원 기준인 경우, y 차원 값 대소 비교하여 탐색

- null일 때까지 recursive하게 탐색

- null인 경우 해당 위치에 새 노드 삽입

Node* insertNode(Node *root, int points[], int depth)
{
	if(root == nullptr)
		return createNode(points);
	
	int cd = depth % K;
	
	// 축 기준 작은 쪽 (왼쪽 서브 트리) 
	if(points[cd] < root->points[cd])
		root->left = insertNode(root->left, points, depth+1);
	// 축 기준 큰 쪽 (오른쪽 서브 트리) 
	else
		root->right = insertNode(root->right, points, depth+1);
	
	return root;
}

 

 

3. 노드 삭제

노드 삭제는 다음과 같다.

- 삭제할 노드 찾기

- 삭제할 게 리프인 경우, 단순 삭제 후 종료

- 삭제할 노드의 기준이 x축이라면, x축 기준으로 큰 것 중 가장 작은 값

- y축이라면, y축 기준으로 큰 것 중 가장 작은 값 

- 즉, 오른쪽 서브 트리 중 해당 축 값 가장 작은 값

 

- 오른쪽 노드가 없다면 왼쪽 노드 중 해당 축의 값 중 가장 작은 값

- 해당 노드를 삭제한 노드 위치로 이동

- 탐색했던 나머지를 오른쪽 서브 트리로 구성

- 해당 노드를 삭제로 보고 위 내용을 반복

bool eqPoints(int points1[], int points2[])
{
	bool res = true;
	for(int i = 0; i < K; i++) {
		if(points1[i] != points2[i]) {
			res = false;
			break;
		}
	}
	return res;
}

void copyPoints(int points1[], int points2[])
{
	for(int i = 0; i < K; i++)
		points1[i] = points2[i];
}

Node* compareMinNode(Node *x, Node *y, Node *z, int d)
{
	// x,y,z 노드 중 d 축 기준 가장 작은값 갖는 노드 리턴 
	Node *res = x;
	if(y != nullptr && y->points[d] < res->points[d])
		res = y;
	if(z != nullptr && z->points[d] < res->points[d])
		res = z;
		
	return res; 
}

Node* findMin(Node *root, int d, int depth)
{
	// d : 현재 탐색 중인 root의 기준 축
	if(root == nullptr)
		return nullptr;
	
	int cd = depth % K;
	
	if(cd == d) {
		// 현재 탐색 중인 root가 가장 작은 값 가짐 
		if(root->left == nullptr)
			return root;
		return findMin(root->left, d, depth+1);
	}
	
	Node *leftMinNode = findMin(root->left, d, depth+1);
	Node *rightMinNode = findMin(root->right, d, depth+1);
	
	Node *minNode = compareMinNode(root, leftMinNode, rightMinNode, d);
	return minNode;
}

Node* deleteNode(Node *root, int points[], int depth)
{
	if(root == nullptr)
		return nullptr;
		
	int cd = depth % K;
	
	// 삭제할 노드 찾음
	if(eqPoints(root->points, points)) {
		// 해당 차원의 값이 큰 것 중 가장 작은 것
		// 즉, 오른쪽 서브 트리 중 가장 작은 것
		// 없으면 왼쪽 서브 트리를 기준으로 탐색 후, 오른쪽 서브 트리로 변경 
		if(root->right != nullptr) {
			Node *minNode = findMin(root->right, cd, 0);
			// 해당 값 복사  
			copyPoints(root->points, minNode->points);
			// 옮겨진 노드도 삭제로 보고 이를 반복 
			root->right = deleteNode(root->right, minNode->points, depth+1);  
		}  
		else if(root->left != nullptr) {
			Node *minNode = findMin(root->left, cd, 0);
			copyPoints(root->points, minNode->points);
			// 삭제된 노드의 왼쪽 서브트리를 오른쪽 서브트리로 구성 
			root->right = deleteNode(root->left, minNode->points, depth+1);
			// 왼쪽 서브트리는 null로 지정 
			root->left = nullptr;
		}
		else {
			// 리프인 경우 
			delete root;
			return nullptr; 
		}
		
		return root;
	}
	
	// 아직 못 찾음 
	if(points[cd] < root->points[cd]) 
		root->left = deleteNode(root->left, points, depth+1);
	else 
		root->right = deleteNode(root->right, points, depth+1);
	return root;
}

 

4. 가장 가까운 노드 찾기

- 거리 기준에 따라 가장 가까운 노드 찾기 

- 현재 탐색 차원의 값과 비교하며 거리 가장 가까운 노드로 업데이트

 

- 아래는 manhattan distance 기준으로 가장 가까운 노드 찾는 케이스

int distance(int points1[], int points2[])
{
	// manhattan distance
	int res = 0;
	for(int i = 0; i < K; i++)
		res += abs(points1[i] - points2[i]);
	return res;
}

void nearest(Node *root, int target[], Node **best, int &bestDist, int depth)
{
	// 최근접 노드 계산 
	if(root == nullptr)
		return;
	
	int dist = distance(root->points, target);
	
	if(dist < bestDist) {
		bestDist = dist;
		*best = root;
	}
	
	int cd = depth % K;
	
	Node *good = nullptr;
	if(target[cd] < root->points[cd])
		good = root->left;
	else 
		good = root->right;
	
	Node *bad = nullptr;
	if(target[cd] < root->points[cd])
		bad = root->right;
	else 
		bad = root->left;
		
	nearest(good, target, best, bestDist, depth+1);
	
	// 현재 차원에 대한 확인 
	if(abs(target[cd] - root->points[cd]) < bestDist)
		nearest(bad, target, best, bestDist, depth+1);
}

pair<Node*, int> findNearest(Node *root, int target[])
{
	Node *best = nullptr;
	int bestDist = 987654321;
	int depth = 0;
	nearest(root, target, &best, bestDist, depth);
	
	return {best, bestDist};
}

 

 

5. 특정 거리 이내의 모든 노드 찾기

- 거리 차이가 radius 이내인 경우, 왼쪽 서브트리, 오른쪽 서브트리 순으로 탐색

- 아닌 경우, 오른쪽 서브트리, 왼쪽 서브트리 순으로 탐색

 

- 아래는 manhattan distance 기준으로 가장 가까운 노드 찾는 케이스

void findWithinDistance(Node *root, int target[], int radius, int depth, vector<Node*> &res)
{
	if(root == nullptr)
		return;
		
	int dist = distance(root->points, target);
	
	if(dist <= radius)
		res.push_back(root);
	
	int cd = depth % K;
	
	if(abs(target[cd] - root->points[cd]) <= radius) {
		findWithinDistance(root->left, target, radius, depth+1, res);
		findWithinDistance(root->right, target, radius, depth+1, res);
	}
	else {
		findWithinDistance(root->right, target, radius, depth+1, res);
		findWithinDistance(root->left, target, radius, depth+1, res);
	}
}

 

6. 메모리 해제

KD Tree의 각 노드의 오른, 왼쪽 하위 노드에 대해 null이 아닌 경우 recursive하게 탐색하며 메모리를 해제한다.

void deleteKDTree(Node *root)
{
	if(root == nullptr)
		return;
	if(root->left != nullptr)
		deleteKDTree(root->left);
	if(root->right != nullptr)
		deleteKDTree(root->right);
	
	delete root;
	return;
}

전체 코드

#include <iostream>
#include <vector>
#include <algorithm> 

using namespace std;

const int K = 2;

struct Node {
	int points[K];
	Node *left;
	Node *right;
};

Node* createNode(int points[])
{
	if((sizeof(points) / sizeof(int)) != K) {
		cout << "Error : input size is not " << K << "\n";
		return nullptr;	
	}
	
	Node *node = new Node;
	for(int i = 0; i < K; i++)
		node->points[i] = points[i];
	
	node->left = nullptr;
	node->right = nullptr;
	return node;
}

Node* insertNode(Node *root, int points[], int depth)
{
	if(root == nullptr)
		return createNode(points);
	
	int cd = depth % K;
	
	// 축 기준 작은 쪽 (왼쪽 서브 트리) 
	if(points[cd] < root->points[cd])
		root->left = insertNode(root->left, points, depth+1);
	// 축 기준 큰 쪽 (오른쪽 서브 트리) 
	else
		root->right = insertNode(root->right, points, depth+1);
	
	return root;
}

bool eqPoints(int points1[], int points2[])
{
	bool res = true;
	for(int i = 0; i < K; i++) {
		if(points1[i] != points2[i]) {
			res = false;
			break;
		}
	}
	return res;
}

void copyPoints(int points1[], int points2[])
{
	for(int i = 0; i < K; i++)
		points1[i] = points2[i];
}

Node* compareMinNode(Node *x, Node *y, Node *z, int d)
{
	// x,y,z 노드 중 d 축 기준 가장 작은값 갖는 노드 리턴 
	Node *res = x;
	if(y != nullptr && y->points[d] < res->points[d])
		res = y;
	if(z != nullptr && z->points[d] < res->points[d])
		res = z;
		
	return res; 
}

Node* findMin(Node *root, int d, int depth)
{
	// d : 현재 탐색 중인 root의 기준 축
	if(root == nullptr)
		return nullptr;
	
	int cd = depth % K;
	
	if(cd == d) {
		// 현재 탐색 중인 root가 가장 작은 값 가짐 
		if(root->left == nullptr)
			return root;
		return findMin(root->left, d, depth+1);
	}
	
	Node *leftMinNode = findMin(root->left, d, depth+1);
	Node *rightMinNode = findMin(root->right, d, depth+1);
	
	Node *minNode = compareMinNode(root, leftMinNode, rightMinNode, d);
	return minNode;
}

Node* deleteNode(Node *root, int points[], int depth)
{
	if(root == nullptr)
		return nullptr;
		
	int cd = depth % K;
	
	// 삭제할 노드 찾음
	if(eqPoints(root->points, points)) {
		// 해당 차원의 값이 큰 것 중 가장 작은 것
		// 즉, 오른쪽 서브 트리 중 가장 작은 것
		// 없으면 왼쪽 서브 트리를 기준으로 탐색 후, 오른쪽 서브 트리로 변경 
		if(root->right != nullptr) {
			Node *minNode = findMin(root->right, cd, 0);
			// 해당 값 복사  
			copyPoints(root->points, minNode->points);
			// 옮겨진 노드도 삭제로 보고 이를 반복 
			root->right = deleteNode(root->right, minNode->points, depth+1);  
		}  
		else if(root->left != nullptr) {
			Node *minNode = findMin(root->left, cd, 0);
			copyPoints(root->points, minNode->points);
			// 삭제된 노드의 왼쪽 서브트리를 오른쪽 서브트리로 구성 
			root->right = deleteNode(root->left, minNode->points, depth+1);
			// 왼쪽 서브트리는 null로 지정 
			root->left = nullptr;
		}
		else {
			// 리프인 경우 
			delete root;
			return nullptr; 
		}
		
		return root;
	}
	
	// 아직 못 찾음 
	if(points[cd] < root->points[cd]) 
		root->left = deleteNode(root->left, points, depth+1);
	else 
		root->right = deleteNode(root->right, points, depth+1);
	return root;
}

void printKDTree(Node *root, int space, int depth)
{
	if(root == nullptr)
		return;
	
	space += 10;
	
	printKDTree(root->right, space, depth+1);
	
	cout << "\n";
	for(int i = 10; i < space; i++) 
		cout << " ";
	
	for(int i = 0; i < K; i++)
		cout << root->points[i] << ",";
	
	printKDTree(root->left, space, depth+1);
}

int distance(int points1[], int points2[])
{
	// manhattan distance
	int res = 0;
	for(int i = 0; i < K; i++)
		res += abs(points1[i] - points2[i]);
	return res;
}

void nearest(Node *root, int target[], Node **best, int &bestDist, int depth)
{
	// 최근접 노드 계산 
	if(root == nullptr)
		return;
	
	int dist = distance(root->points, target);
	
	if(dist < bestDist) {
		bestDist = dist;
		*best = root;
	}
	
	int cd = depth % K;
	
	Node *good = nullptr;
	if(target[cd] < root->points[cd])
		good = root->left;
	else 
		good = root->right;
	
	Node *bad = nullptr;
	if(target[cd] < root->points[cd])
		bad = root->right;
	else 
		bad = root->left;
		
	nearest(good, target, best, bestDist, depth+1);
	
	// 현재 차원에 대한 확인 
	if(abs(target[cd] - root->points[cd]) < bestDist)
		nearest(bad, target, best, bestDist, depth+1);
}

pair<Node*, int> findNearest(Node *root, int target[])
{
	Node *best = nullptr;
	int bestDist = 987654321;
	int depth = 0;
	nearest(root, target, &best, bestDist, depth);
	
	return {best, bestDist};
}

void findWithinDistance(Node *root, int target[], int radius, int depth, vector<Node*> &res)
{
	if(root == nullptr)
		return;
		
	int dist = distance(root->points, target);
	
	if(dist <= radius)
		res.push_back(root);
	
	int cd = depth % K;
	
	if(abs(target[cd] - root->points[cd]) <= radius) {
		findWithinDistance(root->left, target, radius, depth+1, res);
		findWithinDistance(root->right, target, radius, depth+1, res);
	}
	else {
		findWithinDistance(root->right, target, radius, depth+1, res);
		findWithinDistance(root->left, target, radius, depth+1, res);
	}
}

void deleteKDTree(Node *root)
{
	if(root == nullptr)
		return;
	if(root->left != nullptr)
		deleteKDTree(root->left);
	if(root->right != nullptr)
		deleteKDTree(root->right);
	
	delete root;
	return;
}

void printFormat()
{
	cout << "   ^\n";
	cout << "   |\n";
	cout << "right\n";
	cout << "   |\n";
	cout << "root\n";
	cout << "   |\n";
	cout << "left\n";
	cout << "   |\n";
	cout << "   v\n";
	return;
}

int main()
{
	printFormat();
	
	Node *root = nullptr;
	
	int depth = 0;
	int points[8][2] = {
	{5,3},{5,4},{3,9},{4,6},{6,2},{1,1},{4,1},{6,7}
	};
	
	for(int i = 0; i < 8; i++) {
		root = insertNode(root, points[i], depth);
	}
	
	int space = 0;
	printKDTree(root, space, depth);
	
	cout << "\n=====================\n";
	cout << "delete : (3,9)\n";
	
	root = deleteNode(root, points[2], depth);
	
	printKDTree(root, space, depth);
	
	cout << "\n=====================\n";
	cout << "delete : (5,4)\n";
	
	root = deleteNode(root, points[1], depth);
	
	printKDTree(root, space, depth);
	
	cout << "\n=====================\n";
	cout << "nearest target : (5,8)\n";
	
	int target[2] = {5,8};
	pair<Node*, int> best = findNearest(root, target);
	
	for(int i = 0; i < K; i++) 
		cout << (best.first)->points[i] << ",";
	cout << "\n";
	cout << "dist : " << best.second << "\n";
	
	cout << "\n=====================\n";
	cout << "within radius : 3, target : (4,5)\n";
	
	int withinTarget[2] = {4,5};
	int radius = 3;
	vector<Node*> res;
	
	findWithinDistance(root, withinTarget, radius, depth, res);
	
	for(auto node : res) {
		for(int i = 0; i < K; i++) {
			cout << node->points[i] << ",";
		}
		cout << "\n";
	}
	
	deleteKDTree(root);
}

 

결과

'알고리즘' 카테고리의 다른 글

KMP 알고리즘 정리  (1) 2024.05.01
Trie 정리  (1) 2024.05.01
C++ vector, priority_queue 정렬 정리  (1) 2024.04.21
MST 구하기 정리 C++  (0) 2024.04.21
유니온 파인드 정리 C++  (0) 2024.04.21
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
TAG
more
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
글 보관함