AI

PyTorch | Model save, checkpoints, freezing

cstory-bo 2024. 1. 12. 14:38

모델 불러오기

학습 시키다가 갑자기 데이터가 모두 날아가는 경험 한 번씩은 있을 것이다...

이렇게 날아가지 않도록 중간중간 저장시켜주기 위해서 사용하는 함수가 save()이다.

Model.save()

  • 학습의 결과를 저장하기 위한 함수
  • 모델 형태와 파라미터를 저장한다.
  • 모델 학습 중간 과정의 저장을 통해 최선의 결과모델을 선택할 수 있다.
  • 만들어진 모델을 외부 연구자와 공유하여 학습 재연경 향상시켜준다.

이 save함수를 사용하여 아래와 같은 코드를 쓸 수 있다.

  • state_dict : 모델의 파라미터를 표시
  • load_state_dict : 같은 모델의 형태에서 파라미터만 load
  • save(model.state_dict(),~) → 모델의 파라미터를 저장
  • save(model,~) → 모델의 architecuture와 함게 저장
  • load(path) : 모델의 architecutre와 함께 load
for param_tensor in model.state_dict():
	print(~)

# 이 방식 선호
torch.save(model.state_dict(),os.path.join(model_path,'model.pt'))

new_model = TheModelClass()
new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt")))

# 결과와 코드도 공유된다.
torch.save(model, os.path.join(MODEL_PATH,"model.pt"))
model = torch.load(os.path.join(MODEL_PATH, "model.pt"))

Checkpoints

  • 학습의 중간 결과를 저장하여 최선의 결과를 선택할 수 있다.
  • earlystopping 기법 사용시 이전 학습의 결과물을 저장한다.
  • 일반적으로 epoch, loss, metric을 함께 저장하여 확인한다.
torch.save({
	'epoch':e,
	'model_state_dict':model.state_dict(),
	'optimizer_state_dict':optimizer.state_dict(),
	'loss':epoch_loss,
	}
f'saved/checkpoint_model_{e}_{epoch_loss/len(dataloader)}_{epoch_acc/len(dataloader)}
)

Pretrained Learning

  • 다른 데이터셋으로 만든 모델을 현재 데이터에 적용
  • 대용량 데이터셋으로 만들어진 모델의 성능이 일반적으로 높다.
  • backbone architecture가 잘 학습된 모델에서 일부분만 변경하여 학습을 수행한다.

NLP는 보통 HuggingFace를 사용한다.

Freezing

pretrained model을 활용하려면 모델의 일부분을 frozen 시켜야한다.

특정 위치까지 멈춰서 일부분 파라미터 값이 바뀌지 않게 한다.

# 모델을 할당한다.
vgg = models.vgg16(pretrained=True).to(device)

# 마지막에 Linear Layer 추가한다.
class MyNewNet(nn.Module):
	def __init__(self):
		super(MyNewNet,self).__init__()
		self.vgg19 = models.vgg19(pretrained=True)
		self.linear_layers = nn.Linear(1000,1)

	def forward(self, x):
		x = self.vgg19(x)
		return self.linear_layers(x)

# 마지막 레이어를 제외하고 frozen 시켜준다.
for param in my_model.parameters():
	param.requires_grad = False
for param in my_model.linear_layers.parameters():
	param.requires_grad = True