[머신러닝]실습-KNN : 붓꽃 분류

sk-learn 라이브러리는 붓꽃에 대한 데이터 셋을 제공한다.

목표

붓꽃의 꽃잎 길이, 꽃잎 너비, 꽃받침 길이, 꽃받침 너비 특징을 활용하여 3가지 품종 분류


※KNN : 붓꽃 데이터 학습

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from sklearn.datasets import load_iris #붓꽃 데이터셋
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier #KNN
from sklearn import metrics #성능 지표 확인
from sklearn.model_selection import train_test_split #데이터 분배 라이브러리
 
i_data = load_iris()
#총 150개의 데이터 셋으로 구성
#꽃받침의 길이, 꽃받침의 넓이, 꽃잎의 길이, 꽃잎의 넓이로 특성이 구분됨
 
#데이터셋 분리
= i_data.data
= i_data.target
 
#데이터 프레임 생성
i_df = pd.DataFrame(i_data.data, columns = i_data.feature_names)
 
#train, test 데이터 셋 분배 : 7:3
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3, random_state = 1#순서 고정
 
#K는 추후 조정 가능
k_model = KNeighborsClassifier(n_neighbors=10#이웃 하는 데이터 수 10개
k_model.fit(X_train,y_train) #학습 수행
pred = k_model.predict(X_test) #예측 수행
 
#성능지표 확인 : 1에 가까울수록 성능이 좋음
print(metrics.accuracy_score(pred,y_test))
 
cs


KNN 모델에 대해 최적의 K(판단을 위한 인접 이웃 수)를 찾을 필요가 있음
->최적의 K의 경우 모델의 성능이 제일 높음

※최적K 탐색

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from sklearn.datasets import load_iris #붓꽃 데이터셋
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier #KNN
from sklearn import metrics #성능 지표 확인
from sklearn.model_selection import train_test_split #데이터 분배 라이브러리
import matplotlib.pyplot as plt #시각화
 
test_l = []#test에 대한 정확도
train_l = []#train에 대한 정확도
i_data = load_iris()
 
= i_data.data
= i_data.target
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3, random_state = 1#순서 고정
 
#k가 1~80(2씩 증가)일때의 성능 측정
for k in range(1,80,2):
    m = KNeighborsClassifier(n_neighbors = k )
    m.fit(X_train, y_train)
    train_pre = m.predict(X_train)
    test_pre = m.predict(X_test)
    
    s1 = metrics.accuracy_score(train_pre, y_train)
    s2 = metrics.accuracy_score(test_pre, y_test)
    
    train_l.append(s1)
    test_l.append(s2)
 
#시각화 실행
plt.figure(figsize=(10,10))
plt.plot(range(1,80,2),train_l, label='train')
plt.plot(range(1,80,2),test_l, label='test')
plt.legend()
plt.show()
cs


※ 시각화 결과
























- 인접 이웃 수가 많아 질수록 대체적으로 성능 감소(과대 적합)
- train data, test data에 대한 K의 정확도 값이 최대 값인 0~10 사이의 K가 적합한 값으로 보임
- 최적의 K에 대한 정확도는 대략 0.977777....

댓글