개발/머신러닝
파이썬. 머신러닝 - 내가 만든 Linear Regression으로 예측하기, joblib으로 저장
웅'jk
2022. 12. 1. 17:50
먼저 다시 한번 더 Linear Regression 을 학습하는 방법에 대해 알아봅시다.
1. 학습할 데이터를 가져옵니다.
2. 데이터의 NaN 값 여부를 확인
- NaN값이 존재할 경우 처리.
3. 예측할 컬럼을 y, 예측에 이용할 컬럼을 X축으로 둡니다.
4. X축 자료에 카테고리컬 데이터의 문자가 있다면, Encoding 을 통해 숫자로 바꿔줍니다.
없다면 그냥 사용하시면 됩니다.
- 1~2 개 -> label , 3 개이상 , One-Hot
https://mokokodevelop.tistory.com/45
5. X의 데이터를 Traing , Test 로 나눕니다.
https://mokokodevelop.tistory.com/47
6. 인공지능을 이용해 학습을 시작합니다.
-ex) LinearRegression 의 경우 LinearRegression()으로 비어있는 인공지능 생성
https://mokokodevelop.tistory.com/48
7. X_test 값으로 예측을 합니다.
8. 실제값과 예측값을 확인합니다.
9. 오류를 구해 인공지능의 성능을 체크합니다.
먼저 예측을 하기 전 위 순서대로 인공지능을 다시 한번 만들어봅시다.
#1. 라이브러리 import
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
#2. read data
df = pd.read_csv('../data/50_Startups.csv')
#3. NaN 값 확인
df.isna().sum()
#4. X , y 지정
X = df.loc[:,'R&D Spend':'State']
y = df['Profit']
# 문자의 갯수확인
X['State'].nunique()
X['State'].unique()
sorted(X['State'].unique())
# 5.문자를 숫자로 바꾸기
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
ct=ColumnTransformer( [ ('encoder' , OneHotEncoder() , [3]) ],
remainder = 'passthrough')
X = ct.fit_transform(X)
#6.training / test 분리
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = 0.2 , random_state = 1 )
#7. 모델링하기
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(X_train,y_train)
y_pred = regressor.predict(X_test)
#8. 에러 확인
error = y_test - y_pred
((error) **2).mean()
#9. 차트로 확인
plt.plot(y_test.values)
plt.plot(y_pred)
plt.legend(['real','pred'])
plt.savefig('chart1.jpg')
plt.show()
위 처럼 만든 인공지능을 이제 예측을 해봅니다.
# 새롭게 예측할 데이터
new_data = np.array([130000,150000,400000,'Florida'])
new_data = new_data.reshape(1,4)
# 문자열을 숫자로 인코딩
new_data = ct.transform(new_data).astype(float)
# 데이터 예측결과
regressor.predict(new_data)
위 인공지능과 , 문자열을 숫자로 바꿔준 ct를 저장하기 위해서는
joblib을 이용합니다.
import joblib
joblib.dump(regressor,'regressor.pkl')
joblib.dump(ct,'ct.pkl')