# Copy the 'shallow' weights from a pre-trained checkpoint into the model and freeze them.
# Pre-trained & target state dicts.
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints that come from a trusted source.
file_name = './saved_models/nuscenes_CVT_model1100/saved_chk_point_27.pt'
checkpoint = torch.load(file_name, map_location=torch.device('cpu'))
pre_state_dict = checkpoint['model_state_dict']
target_state_dict = model.state_dict()

# Copy every pre-trained entry whose name contains 'shallow' into the target.
for name, param in pre_state_dict.items():
    if 'shallow' not in name:
        continue

    # Checkpoint keys may carry a "module." component that the target model's
    # keys lack (presumably from a DataParallel-style wrapper -- confirm).
    target_name = name.replace("module.", "")
    if target_name not in target_state_dict:
        # A missing counterpart is treated as fatal: abort the whole script.
        sys.exit(f"Warning: {name} not found in target model")

    # In-place copy into the target tensor; raises if the shapes differ.
    target_state_dict[target_name].copy_(param)

model.load_state_dict(target_state_dict)

# Stop the transferred ('shallow') layers from learning.
for name, param in model.named_parameters():
    if 'shallow' in name:
        param.requires_grad = False