본문 바로가기

Deep Learning

Initialize specific parameters of a neural network with pre-trained values and freeze them so they are not updated during training

반응형
# Copy the pre-trained weights of every 'shallow' layer into `model`, then
# freeze those layers.
# NOTE(review): `torch` and `model` are assumed to be imported/defined earlier
# in the script; they are not visible in this snippet.

# pre-trained & target state dict
file_name = './saved_models/nuscenes_CVT_model1100/saved_chk_point_27.pt'
# torch.load unpickles the file; only load checkpoints from a trusted source.
pre_state_dict = torch.load(file_name, map_location=torch.device('cpu'))['model_state_dict']
target_state_dict = model.state_dict()

# Checkpoints saved from a DataParallel/DistributedDataParallel wrapper prefix
# every key with "module.". Strip only that leading prefix: str.replace would
# also delete any "module." that appears later inside the key.
prefix = 'module.'

# copy from pre-trained to target
for name, param in pre_state_dict.items():
    if 'shallow' not in name:
        continue
    target_name = name[len(prefix):] if name.startswith(prefix) else name
    if target_name not in target_state_dict:
        raise KeyError(f"{name} not found in target model")
    target = target_state_dict[target_name]
    # Tensor.copy_ broadcasts, so a broadcastable-but-different shape would be
    # copied silently; reject any shape mismatch explicitly instead.
    if target.shape != param.shape:
        raise ValueError(
            f"shape mismatch for {name}: checkpoint {tuple(param.shape)} "
            f"vs model {tuple(target.shape)}"
        )
    # state_dict() holds references to the module's own parameters/buffers,
    # so this in-place copy writes straight into the model's tensors.
    target.copy_(param)
# Re-apply the (already updated) state dict; effectively a no-op copy, kept
# for explicitness.
model.load_state_dict(target_state_dict)

# stop specific layers from learning: autograd no longer computes gradients
# for these parameters.
# NOTE(review): this only affects parameters. Buffers such as BatchNorm running
# statistics still update in train() mode. The optimizer should also be built
# after this loop (or be given only params with requires_grad=True). Confirm
# against the training code.
for name, param in model.named_parameters():
    if 'shallow' in name:
        param.requires_grad = False