K-Nearest Neighbor(KNN) 알고리즘은 Supervised Learning (지도학습) 머신러닝에서 가장 유명한 알고리즘 중 하나입니다. KNN 모델이 어떻게 동작하는지, 핵심적인 개념과 가장 적절한 K값을 찾는 방법, 그리고 사이킷런(Scikit-learn)으로 KNN 알고리즘을 적용해보는 예제를 다뤄보겠습니다.
KNN(K-nearest Neighbor) 알고리즘이란?
K-nearest neighbor (KNN) 알고리즘은 대표적인 Supervised Learning(지도학습) 알고리즘 중 하나입니다. KNN 알고리즘은 Classification이나 Regression에 사용되는 정말 간단하면서도 널리 활용되는 머신러닝 알고리즘입니다. KNN알고리즘은 test data에 있는 각 데이터 포인트 마다 모든 training 데이터에 있는 데이터 포인트까지의 거리를 계산한 후 가장 가까이에 위치한 K개의 데이터 포인터를 선택하고, 그 중 가장 많은 class로 최종 예측하는 과정을 거칩니다.
K-nearest neighbor(KNN)은 "lazy learner" 라고도 불립니다. 왜냐하면 KNN 알고리즘은 model training 시간에는 하는게 거의 없고, 예측(Prediction)할 때 모든 계산이 이루어지기 때문입니다.
예를들면, 학습(training) 데이터가 세가지 클래스로 나뉘어 라벨링 되어있다고 생각해 봅시다. 이 모델은 training할 때는 Initialization 하는 것 말고는 하는게 없습니다. Prediction을 위해 새로운 데이터가 들어오면 그 데이터로 부터 기존 학습 데이터들까지의 모든 거리를 측정합니다. 그리고 그 거리를 순서대로 나열해봅니다.
그 후 설정한 K개의 가장 가까운 데이터들을 선택합니다. 그 중에서 가장 많이 있는 class로 최종 예측을 하는 방식을 사용합니다. 이런 이유로 KNN(K-nearest neighbor) 모델은 머신러닝 모델 중 가장 이해하기 쉽고 간단하고 직관정인 모델이라고도 여겨집니다.
Small K vs Large K Trade-off
KNN(K-nearest neighbor) 알고리즘에서는 K값을 어떻게 설정하느냐에 따라 결과 값이 크게 달라질 수 있습니다. 너무 작은 K는 모델이 Overfitting할 가능성과, 높은 Variance를 가질 가능성을 높게 만듭니다. 너무 큰 K값은 모델이 Underfitting 하고, 높은 Bias를 가지게 만듭니다. 너무 큰 K값으로 예측을 하게되면, 지나치게 일반화된 예측값을 내놓기 때문입니다.
아래는 Iris 데이터를 KNN알고리즘으로 Classification을 수행한 결과입니다. K가 15일 때와 K가 3일 때 확연히 다른 결과를 가진다는 것을 볼 수 있습니다. 위에서 설명드린 것 처럼 너무 작은 K값은 모델이 Overfitting할 가능성을 높이게 된다는 것도 확인할 수 있습니다.
KNN(K-nearest neighbor) 알고리즘에서 적절한 K값을 정하는 방법?
어떻게 적절한 K를 고를 수 있는지 알려진 정답은 없습니다. 따라서 데이터에 따라 상황에 따라 조건을 바꿔서 알고리즘을 실행시켜보면서 적절한 K를 찾아나가야 합니다. 다음은 적절한 K를 찾을 수 있는 과정을 나열해본 것입니다.
- 먼저 랜덤 숫자의 K값을 정한다.
- K 값을 증가시켜가면서 에러율을 측정한다.
- 에러율이 가장 작은 K값으로 최종 선택한다.
KNN(K-nearest neighbor) 알고리즘이 거리를 계산하는 방법?
1. 유클리디안 거리(Euclidean Distance)
가장 많이 사용되는 방법으로, 새로운 데이터 포인트가 x고 기존 데이터가 y일 때 둘 사이의 거리의 차를 제곱한 것들의 합의 제곱근으로 거리를 구하는 방식입니다. 아래는 (1,2) 와 (4,6)의 데이터 포인트들 사이의 거리를 유클리디안 거리 계산 방식으로 구해본 예시입니다.
2. 맨해튼 거리(Manhattan Distance)
두번째로 많이 사용되는 방법으로, 새로운 데이터 포인트가 x고 기존 데이터가 y일 때 차원마다의 거리의 차를 절대값한 것을 모두 더해 구하는 방식입니다. 아래는 (1,2) 와 (4,6)의 데이터 포인트들 사이의 거리를 맨하탄 거리 계산 방식으로 구해본 예시입니다.
3. 해밍 거리(Hamming Distance)
해밍 거리는 카테고리 변수일 때 사용됩니다. 만약 x와 y의 값이 같으면 0, 다르면 1의 거리를 가집니다.
KNN(K-Nearest Neighbor) 사이킷런으로 코딩해보기
아래 예시 코드는 먼저 사이킷런(Scikit-learn)패키지에서 제공하는 아이리스(Iris)데이터를 로드를 한후, train과 test set = 7:3의 비율로 나눈다음, Best K를 찾기위해 (3,20)사이의 K를 가지고 cross-validation을 반복하며 실험을 해봅니다. 그리고 Best K를 찾은 후 KNN Classifier를 선언하고, test set에 대해 예측해본 예시입니다.
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# 아이리스 데이터 로드
iris = load_iris()
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=42)
# best K를 찾기 위해 3,20까지의 K값을 반복하여 수행해보고 5-fold cross validation 거침
k_values = list(range(3, 20))
accuracies = []
for k in k_values:
knn = KNeighborsClassifier(n_neighbors=k)
# Perform 5-fold cross-validation
cv_scores = cross_val_score(knn, X_train, y_train, cv=10)
accuracies.append(cv_scores.mean())
# 가장 높은 accuracy를 가지는 K값으로 정하기
best_k = k_values[accuracies.index(max(accuracies))]
print("Best K value: {}".format(best_k))
# 위에서 구한 K로 KNN 모델 선언
knn = KNeighborsClassifier(n_neighbors=best_k)
# 모델 학습
knn.fit(X_train, y_train)
# test 데이터에 Prediction하기
y_pred = knn.predict(X_test)
# Accuracy 계산하기
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: {:.2f}%".format(accuracy * 100))
# Generate the confusion matrix and classification report
cm = confusion_matrix(y_test, y_pred)
cr = classification_report(y_test, y_pred, target_names=iris.target_names)
print("Confusion Matrix:")
print(cm)
print("\nClassification Report:")
print(cr)
Summary
이번 포스팅에서는 머신러닝 Supervised Learning 알고리즘 중 가장 유명하고 널리쓰이는 알고리즘인 KNN알고리즘에 대해 알아보았습니다. 여기서 핵심적으로 이해해야할 내용은 KNN이 어떤 방식으로 데이터를 예측하는지, 왜 KNN이 Lazy learner라고 불리는지 그리고 가장 적합한 K를 찾기 위해서는 어떤 과정을 거쳐야하는지 입니다. K-nearest neighbor알고리즘을 이해하는 데 도움이 되셨길 바랍니다.
<참고하면 좋을 포스팅>
'머신러닝' 카테고리의 다른 글
머신러닝 Feature Selection 개념, 중요성, 하는 방법 (0) | 2023.02.22 |
---|---|
머신러닝 ROC curve(커브), AUC 개념 제대로 이해하기 (0) | 2023.02.22 |
머신러닝 서포트벡터머신 모델(Support Vector Machine) (0) | 2023.02.19 |
[파이썬 머신러닝] 사이킷런(Scikit-learn) 소개 (0) | 2023.02.18 |
노코드 머신러닝 툴 WEKA 사용방법(코딩 없이 Machine Learning하기) (0) | 2023.02.15 |
댓글