akaSonny
[Pytorch] checkpoint 불러와서 모델 이어서 학습하기 본문
모델 돌리고 퇴근하고 와서 봤는데 네트워크 문제 등등으로 모델 훈련이 중단되어 있는 경우 ,,
저장된 checkpoint를 이용해서 다시 훈련시켜보자
- checkpoint 파일: 'checkpoint.pth'
model = Model()
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint)
for epoch in range(n_epoch):
model.train()
...
그런데 나 같은 경우에는 이렇게 하니까 RuntimeError
가 발생했다.
RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: ~~
Unexpected key(s) in state_dict: ~~
state_dict - key 가 다를 경우 발생하는 에러라고 해서, key를 바꿔주는 코드를 추가하였다.
나 같은 경우에는 module.
글자가 unexpected keys
에 다 붙어있어서, 이거를 제거해주는 코드를 작성했다.
에러에 맞게 수정하면 될 것 같다.
checkpoint = torch.load('checkpoint.pth')
for key in list(checkpoint.keys()):
if 'module.' in key:
checkpoint[key.replace('module.', '')] = checkpoint[key]
del checkpoint[key]
model.load_state_dict(checkpoint)
참고
https://cocoa-t.tistory.com/entry/RuntimeError-Errors-in-loading-statedict
'Study (Programming) > Pytorch' 카테고리의 다른 글
완전 간단한 Linear model 만들기 (0) | 2022.05.20 |
---|---|
Pytorch 기초부터 시작! (0) | 2022.05.20 |