개발/머신러닝

파이썬 머신러닝 - softmax 로 처리한 결과값을 레이블인코딩, 오버피팅을 처리하는 콜백클래스 이용

웅'jk 2022. 12. 29. 11:39

이전 포스트에서 작성하였던 이미지 관련 학습데이터를 이용합니다.

https://mokokodevelop.tistory.com/110

 

파이썬 머신러닝 - 이미지 학습을 위한 flatten, 3개 이상 분류 액티베이션함수 softmax 와 loss 셋팅방

파이썬에서 이미지를 학습하기 위해서 먼저 이미지의 구조를 알아야 합니다. 파이썬에서 이미지는 가로 x 세로의 좌표값에 색상코드를 가지고 있게 됩니다. (0~ 255) 즉 이미지는 기본적으로 2차

mokokodevelop.tistory.com

 

만들어 놓은 model 로 X_test 값을 넣어 예측을 해봅니다.

 

예측한 결과물은 아웃레이어의 노드수가 10개였으므로 1개의 데이터당 10개씩 나오게 됩니다.

( 1번 데이터를 넣으면 결과물로 나온게 10가지라는 뜻입니다.)

 

이걸 이제 y_test와 비교하여 정확도를 알고 싶습니다만 10개의 데이터로 나오기 때문에 

레이블 인코딩을 해야합니다. 

(결과물은 2차원 데이터이고, 우리가 비교할 y_test는 1차원 데이터이기 때문입니다.)

 

굉장히 간단하게 numpy에 argmax 라는 함수를 이용하여 처리합니다.

y_pred = model.predict(X_test)
y_pred = y_pred.argmax(axis=1)

 

지난번 이용한 ANN모델에 epochs를 5개로 설정하여 학습을 하였다면 이번에는 30개로 늘려서 알아봅시다.

# model

def build_model() :
  model = Sequential()
  model.add(Flatten() )
  model.add(Dense(units=128 , activation = 'relu'))
  model.add(Dense(units=64 , activation = 'relu'))
  model.add(Dense(units=10 , activation = 'softmax'))
  model.compile(optimizer='adam',loss = 'sparse_categorical_crossentropy',metrics=['accuracy'] )
  return model
  
  # 빈 ann 모델 생성
  model = build_model()
  # 학습 결과를 epoch_history 에 저장
  epoch_history = model.fit(X_train,y_train,epochs=30,validation_split=0.2)

학습한 데이터의 결과를 loss , accuracy 로 나누어 차트로 확인하면 다음과 같습니다.

좌. loss , 우. accuracy

학습된 데이터의 loss 는 줄고 테스트한 loss는 내려갔다가 오히려 올라가고 있습니다.

accuracy 또한 학습된 데이터는 올라가나 테스트한 데이터는 올라갔다가 내려가기도 하고 변동이 심합니다.

 

이러한 사실을 통해 오버피팅이 발생하고 있음을 알 수 있습니다.

 

따라서 30번의 epochs은 의미가 없는 학습이게 됩니다.

이러한 문제를 해결하기 위해서 정확도가 일정 수치 이상 안올라가면 바로 종료하도록 콜백클래스를 이용하겠습니다.

class myCallback(tf.keras.callbacks.Callback) :
  def on_epoch_end(self,epoch,logs={}) :
    if logs['val_accuracy'] > 0.88 :
      print('\n 내가 정한 정확도에 도달했으니, 학습을 멈춘다.')
      self.model.stop_training = True

my_cb = myCallback()

epoch_history = model.fit(X_train,y_train,epochs=30,validation_split=0.2,callbacks=[my_cb] )

클래스를 아직 배우지 않았기 때문에 이렇게 사용한다고만 알고 넘어갑니다.

 

우리가 볼 부분은 on_epoch_end 함수의 정의부분입니다.

if 문을 통해 조건을 걸었습니다. 

 

logs['val_accuracy'] 의 값이 0.88 보다 크다면 학습을 멈춰달라고 코드를 작성하였습니다.

 

위 코드르 실행시키면 에포크 30횟수가 아닌 accuracy가 0.88이 넘는 순간 멈추게됩니다.