This is from the answers in https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/16
After model_dict.update(pretrained_dict), model_dict may still contain keys that pretrained_dict doesn't have, so loading pretrained_dict directly will cause an error.
Assume the following situation:
pretrained_dict: ['A', 'B', 'C', 'D']
model_dict: ['A', 'B', 'C', 'E']
After pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} and model_dict.update(pretrained_dict), they are:
pretrained_dict: ['A', 'B', 'C']
model_dict: ['A', 'B', 'C', 'E']
So when calling model.load_state_dict(pretrained_dict), the model still expects key E, which pretrained_dict doesn't have, and the call fails because of the missing key.
So how about using model.load_state_dict(model_dict) instead of model.load_state_dict(pretrained_dict)?
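The key bookkeeping above can be reproduced with plain Python dicts. This is only a sketch: the float values are placeholders standing in for tensors, and `missing` is a name introduced here for illustration.

```python
# Placeholder values stand in for tensors; only the keys matter here.
pretrained_dict = {"A": 1.0, "B": 2.0, "C": 3.0, "D": 4.0}
model_dict = {"A": 0.0, "B": 0.0, "C": 0.0, "E": 0.0}

# Drop keys the model does not know about ('D').
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite the matching entries in the model's dict.
model_dict.update(pretrained_dict)

# Keys the model expects but the filtered pretrained_dict cannot provide.
missing = sorted(set(model_dict) - set(pretrained_dict))

print(sorted(pretrained_dict))  # ['A', 'B', 'C']
print(sorted(model_dict))       # ['A', 'B', 'C', 'E']
print(missing)                  # ['E']
```

Because model_dict holds every key the model expects, it is the safe one to pass to load_state_dict.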
The complete snippet is therefore as follows:
```python
pretrained_dict = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
```
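To try the three steps end to end, here is a minimal runnable sketch. `PretrainedNet` and `TargetNet` are hypothetical toy models made up for this example (they share a `backbone` layer but have different heads), and a freshly initialized `PretrainedNet` stands in for a real checkpoint. The last lines show an alternative that is not part of the snippet above: passing the filtered dict with `strict=False`, which reports the missing keys instead of raising.

```python
import torch
import torch.nn as nn


class PretrainedNet(nn.Module):  # hypothetical source model
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.old_head = nn.Linear(8, 2)


class TargetNet(nn.Module):  # hypothetical model to initialize
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.new_head = nn.Linear(8, 3)


torch.manual_seed(0)
# Stand-in for a checkpoint; in practice this would come from torch.load(...).
pretrained_dict = PretrainedNet().state_dict()
model = TargetNet()
model_dict = model.state_dict()

# 1. keep only the keys that the target model also has
filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite the matching entries in the target's state dict
model_dict.update(filtered_dict)
# 3. load the complete dict, so no key is missing
model.load_state_dict(model_dict)

backbone_copied = torch.equal(model.backbone.weight, pretrained_dict["backbone.weight"])
print(backbone_copied)

# Alternative: load only the filtered dict and let PyTorch report what is missing.
other = TargetNet()
result = other.load_state_dict(filtered_dict, strict=False)
print(result.missing_keys, result.unexpected_keys)
```

Note that the filter compares key names only. If a shared key has a different tensor shape in the two models, load_state_dict still raises a size-mismatch error, so such keys would need to be filtered out as well.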