03-1 k-์ต๊ทผ์ ์ด์ ํ๊ท
์ง๊ธ๊น์ง๋ ์ฃผ์ด์ง ๊ฒ์ ๊ตฌ๋ถํ์ฌ ์ ๋ต์ ๊ณ ๋ฅด๋ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํด์๋ค. ๊ทธ๋ ๋ค๋ฉด ๋ฌด์ธ๊ฐ ๊ณ ๋ฅด๋ ๋ฌธ์ ๋ง๊ณ , ์์ธกํ๋ ๊ฒ๋ ํ ์ ์์๊น?
k-์ต๊ทผ์ ์ด์ ํ๊ท
์ง๋ ํ์ต ์๊ณ ๋ฆฌ์ฆ์ ํฌ๊ฒ ๋ถ๋ฅ์ ํ๊ท๋ก ๋๋๋ค. ๋ถ๋ฅ๋ ๋ง ๊ทธ๋๋ก ์ํ์ ๋ช ๊ฐ์ ํด๋์ค ์ค ํ๋๋ก ๋ถ๋ฅํ๋ ๋ฌธ์ ์ด๋ค. ํ๊ท๋ ํด๋์ค ์ค ํ๋๋ก ๋ถ๋ฅํ๋ ๊ฒ์ด ์๋๋ผ ์์์ ์ด๋ค ์ซ์๋ฅผ ์์ธกํ๋ ๋ฌธ์ ์ด๋ค.
k- ์ต๊ทผ์ ์ด์ ํ๊ท๋ ๋ถ๋ฅ์์์ ๋ง์ฐฌ๊ฐ์ง๋ก ์์ธกํ๋ ค๋ ์ํ์ ๊ฐ์ฅ ๊ฐ๊น์ด ์ํ k๊ฐ๋ฅผ ์ ํํ๋ค. ๊ทธ๋ฆฌ๊ณ ๊ทธ ์ํ๋ค์ ํ๊น๊ฐ์ ํ๊ท ์ ๊ตฌํ๋ ๋ฐฉ์์ผ๋ก ์์ธกํ๋ค.
๋ฐ์ดํฐ ์ค๋น
๋จผ์ ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ์ค๋นํ๊ณ , ์ฐ์ ๋๋ฅผ ๊ทธ๋ ค ๋ฐ์ดํฐ์ ํํ๋ฅผ ํ์ ํด๋ณด์.
import numpy as np
perch_length = np.array([8.4, 13.7, 15.0, 16.2, 17.4, 18.0, 18.7, 19.0, 19.6, 20.0, 21.0,
21.0, 21.0, 21.3, 22.0, 22.0, 22.0, 22.0, 22.0, 22.5, 22.5, 22.7,
23.0, 23.5, 24.0, 24.0, 24.6, 25.0, 25.6, 26.5, 27.3, 27.5, 27.5,
27.5, 28.0, 28.7, 30.0, 32.8, 34.5, 35.0, 36.5, 36.0, 37.0, 37.0,
39.0, 39.0, 39.0, 40.0, 40.0, 40.0, 40.0, 42.0, 43.0, 43.0, 43.5,
44.0])
perch_weight = np.array([5.9, 32.0, 40.0, 51.5, 70.0, 100.0, 78.0, 80.0, 85.0, 85.0, 110.0,
115.0, 125.0, 130.0, 120.0, 120.0, 130.0, 135.0, 110.0, 130.0,
150.0, 145.0, 150.0, 170.0, 225.0, 145.0, 188.0, 180.0, 197.0,
218.0, 300.0, 260.0, 265.0, 250.0, 250.0, 300.0, 320.0, 514.0,
556.0, 840.0, 685.0, 700.0, 700.0, 690.0, 900.0, 650.0, 820.0,
850.0, 900.0, 1015.0, 820.0, 1100.0, 1000.0, 1100.0, 1000.0,
1000.0])
import matplotlib.pyplot as plt
plt.scatter(perch_length, perch_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
์ฃผ์ด์ง ๋์ด ๋ฐ์ดํฐ๋ฅผ ๋ณด๋ฉด ๊ธธ์ด๊ฐ ์ปค์ง์ ๋ฐ๋ผ ๋ฌด๊ฒ๋ ๋์ด๋๋ ๊ฒ์ ํ์ธํ ์ ์๋ค. ๋ณธ๊ฒฉ์ ์ผ๋ก ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ ์ ์ ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ๋ก ๋๋์.
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(perch_length, perch_weight, random_state=42)