akaSonny

[Pytorch] checkpoint 불러와서 모델 이어서 학습하기 본문

Study (Programming)/Pytorch

[Pytorch] checkpoint 불러와서 모델 이어서 학습하기

Jihyeoning 2024. 2. 19. 11:12

모델 돌리고 퇴근하고 와서 봤는데 네트워크 문제 등등으로 모델 훈련이 중단되어 있는 경우 ,,

저장된 checkpoint를 이용해서 다시 훈련시켜보자 

 

- checkpoint 파일: 'checkpoint.pth'

 

model = Model()

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint)

for epoch in range(n_epoch):
	model.train()
	...

 

그런데 나 같은 경우에는 이렇게 하니까 RuntimeError 가 발생했다.

RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: ~~
Unexpected key(s) in state_dict: ~~

 

state_dict - key 가 다를 경우 발생하는 에러라고 해서, key를 바꿔주는 코드를 추가하였다.

나 같은 경우에는 module. 글자가 unexpected keys에 다 붙어있어서, 이거를 제거해주는 코드를 작성했다.

에러에 맞게 수정하면 될 것 같다.

checkpoint = torch.load('checkpoint.pth')

for key in list(checkpoint.keys()):
    if 'module.' in key:
    checkpoint[key.replace('module.', '')] = checkpoint[key]
    del checkpoint[key]
    
model.load_state_dict(checkpoint)

 


참고

https://cocoa-t.tistory.com/entry/RuntimeError-Errors-in-loading-statedict

'Study (Programming) > Pytorch' 카테고리의 다른 글

완전 간단한 Linear model 만들기  (0) 2022.05.20
Pytorch 기초부터 시작!  (0) 2022.05.20