

[PyTorch] Loading specific keys for NN initialization


This is taken from one of the answers in https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/16

 


After model_dict.update(pretrained_dict), model_dict may still have keys that pretrained_dict doesn't have, which will cause an error if pretrained_dict is loaded into the model directly.

Assume the following situation:

pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']

After pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} and model_dict.update(pretrained_dict), they are:

pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']

So when performing model.load_state_dict(pretrained_dict), model_dict still has key E that pretrained_dict doesn't have.
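
The key sets above can be reproduced with plain Python dictionaries. This is only an illustrative sketch: the string values are placeholders standing in for tensors, the variable missing_from_pretrained is a name made up here, and no PyTorch call is involved.

```python
# Plain dicts standing in for state dicts; the string values are placeholders, not tensors.
pretrained_dict = {"A": "pre_A", "B": "pre_B", "C": "pre_C", "D": "pre_D"}
model_dict = {"A": "init_A", "B": "init_B", "C": "init_C", "E": "init_E"}

# Keep only the pretrained entries whose keys the model also has (this drops "D").
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# Overwrite the matching entries in the model's dict; "E" keeps its initial value.
model_dict.update(pretrained_dict)

# Keys the model expects but the filtered pretrained dict cannot supply.
missing_from_pretrained = sorted(set(model_dict) - set(pretrained_dict))

print(sorted(pretrained_dict))
print(sorted(model_dict))
print(missing_from_pretrained)
```

The filtered pretrained_dict covers only part of what the model expects, while the updated model_dict covers all of it.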

So how about using model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?

The complete snippet is therefore as follows:

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(model_dict)
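
To try the three steps end to end, here is a minimal sketch on two toy models. The classes Pretrained and Target and their layer names are hypothetical, invented for illustration. In practice pretrained_dict would come from a saved checkpoint rather than from a freshly built module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Pretrained(nn.Module):
    """Hypothetical source model; its state dict plays the role of the checkpoint."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.old_head = nn.Linear(8, 2)  # has no counterpart in Target


class Target(nn.Module):
    """Hypothetical target model; shares `features` and adds `new_head`."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.new_head = nn.Linear(8, 3)


pretrained_dict = Pretrained().state_dict()
model = Target()
model_dict = model.state_dict()

# Remember the freshly initialised head so we can check it is left untouched.
new_head_before = model_dict["new_head.weight"].clone()

# 1. filter out keys the target model does not have (drops old_head.*)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite the matching entries in the target's state dict
model_dict.update(pretrained_dict)
# 3. load the merged dict, which covers every key the model expects
model.load_state_dict(model_dict)

print(sorted(pretrained_dict))
```

Note that the filter compares key names only. If a shared key had a different tensor shape in the two models, load_state_dict would still raise a size-mismatch error, so such keys would need to be filtered out as well.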