akaSonny
[Tensorflow/keras] callbacks 사용하기 (훈련 중 모델 저장 / EarlyStop) 본문
Study (Programming)/Tensorflow-Keras
[Tensorflow/keras] callbacks 사용하기 (훈련 중 모델 저장 / EarlyStop)
Jihyeoning 2024. 2. 1. 13:42나는 항상 모델 훈련을 하면 전체 에폭이 돌고, 모델 저장을 해서 테스트를 했는데,
생각해보니 "마지막에 저장된 모델이 과적합돼서 오히려 성능이 떨어진 모델을 사용하는게 아닐까..?" 하는 의문이 들었다.
그래서 훈련 중 두 가지 처리를 시도해보았다.
- 훈련 중 모델 저장하기
- EarlyStop 사용하기
1. 훈련 중 모델 저장하기 - tensorflow.keras.callbacks.ModelCheckpoint()
tf.keras.callbacks.ModelCheckpoint()
filepath,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq='epoch',
options=None,
initial_value_threshold=None,
**kwargs
)
- 필수 입력 인자는
filepath
로, checkpoint가 어디에 저장될지 지정하는 인자이다. 확장자 (.ckpt) 까지 작성 save_freq
는 'epoch'이 default이고 이는 한 에폭이 끝날 때 체크포인트가 저장된다. int로도 입력할 수 있는데 그럴 경우 입력된 batch 가 훈련이 완료되면 체크포인트가 저장된다.
아래는 사용법
from tensorflow.keras import callbacks
checkpoint_path = path + 'cp.ckpt'
cp_callback = callbakcs.ModelCheckpoint(filepath=ckeckpoint_path,
save_weiths_only=True,
save_freq='epoch')
hist = model.fit(x, y, epochs=100, batch_size=1, callbacks=[cp_callback])
2. EarlyStop - tensorflow.keras.callbacks.EarlyStopping
loss가 증가하면 훈련을 종료해서 과적합을 막아주는 함수 (tf 2.15 ver)
tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
min_delta=0,
patience=0,
verbose=0,
mode='auto',
baseline=None,
restore_best_weights=False,
start_from_epoch=0
)
monitor
: 어떤 metric을 볼 것인지 정하는 인자. 그냥 loss를 봐도 되고, loss가 아닌 사용자가 지정한 metric이 있다면 그걸로 설정해도 됨.patience
: loss가 증가할 때 어느 epoch 까지 지켜볼 것인지 정하는 인자. 예를 들어, 5라고 설정했으면 5번의 에폭 동안 loss 가 증가한다면 훈련 종료.start_from_epoch
: warm-up이 필요하다고 생각하면, 특정 epoch 이후로 early stop을 수행하도록 설정 (나는 tf 2.4를 쓰는데 이 버전에는 없다 ㅜㅜ .. 버전 체크해보기)
사용법
앞에서 만들었던 cp_callback 을 함께 사용하자!
fit 함수에 callbacks 의 인자에 리스트로 넣어주면 된다.
from tensorflow.keras import callbacks
earlystop = callbakcs.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
hist = model.fit(x, y, epochs=100, batch_size=1, callbacks=[earlystop, cp_callback])
'Study (Programming) > Tensorflow-Keras' 카테고리의 다른 글
[TF/Keras] Custom Datagenerator로 Multiple Input 넣기 (0) | 2024.04.11 |
---|---|
[Tensorflow/Keras] MNIST 학습시키기 (1) | 2022.10.18 |
[Tensorflow/Keras] Custom loss function 만들기 (0) | 2022.09.30 |