akaSonny

[Tensorflow/keras] callbacks 사용하기 (훈련 중 모델 저장 / EarlyStop) 본문

Study (Programming)/Tensorflow-Keras

[Tensorflow/keras] callbacks 사용하기 (훈련 중 모델 저장 / EarlyStop)

Jihyeoning 2024. 2. 1. 13:42

나는 항상 모델 훈련을 하면 전체 에폭이 돌고, 모델 저장을 해서 테스트를 했는데,

생각해보니 "마지막에 저장된 모델이 과적합돼서 오히려 성능이 떨어진 모델을 사용하는게 아닐까..?" 하는 의문이 들었다.

 

그래서 훈련 중 두 가지 처리를 시도해보았다.

  1. 훈련 중 모델 저장하기
  2. EarlyStop 사용하기

 

1. 훈련 중 모델 저장하기 - tensorflow.keras.callbacks.ModelCheckpoint()

tf.keras.callbacks.ModelCheckpoint()
    filepath,
    monitor: str = 'val_loss',
    verbose: int = 0,
    save_best_only: bool = False,
    save_weights_only: bool = False,
    mode: str = 'auto',
    save_freq='epoch',
    options=None,
    initial_value_threshold=None,
    **kwargs
)

 

    • 필수 입력 인자는 filepath로, checkpoint가 어디에 저장될지 지정하는 인자이다. 확장자 (.ckpt) 까지 작성
    • save_freq 는 'epoch'이 default이고 이는 한 에폭이 끝날 때 체크포인트가 저장된다. int로도 입력할 수 있는데 그럴 경우 입력된 batch 가 훈련이 완료되면 체크포인트가 저장된다.

아래는 사용법 

from tensorflow.keras import callbacks

checkpoint_path = path + 'cp.ckpt'
cp_callback = callbakcs.ModelCheckpoint(filepath=ckeckpoint_path, 
                                        save_weiths_only=True,
                                        save_freq='epoch')

hist = model.fit(x, y, epochs=100, batch_size=1, callbacks=[cp_callback])

 

 

2. EarlyStop - tensorflow.keras.callbacks.EarlyStopping

 

loss가 증가하면 훈련을 종료해서 과적합을 막아주는 함수 (tf 2.15 ver)

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=0,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=False,
    start_from_epoch=0
)

 

  • monitor : 어떤 metric을 볼 것인지 정하는 인자. 그냥 loss를 봐도 되고,  loss가 아닌 사용자가 지정한 metric이 있다면 그걸로 설정해도 됨.
  • patience : loss가 증가할 때 어느 epoch 까지 지켜볼 것인지 정하는 인자. 예를 들어, 5라고 설정했으면 5번의 에폭 동안 loss 가 증가한다면 훈련 종료.
  • start_from_epoch : warm-up이 필요하다고 생각하면, 특정 epoch 이후로 early stop을 수행하도록 설정  (나는 tf 2.4를 쓰는데 이 버전에는 없다 ㅜㅜ .. 버전 체크해보기)

사용법

앞에서 만들었던 cp_callback 을 함께 사용하자!

fit 함수에 callbacks 의 인자에 리스트로 넣어주면 된다. 

from tensorflow.keras import callbacks

earlystop = callbakcs.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

hist = model.fit(x, y, epochs=100, batch_size=1, callbacks=[earlystop, cp_callback])