pytorch:预训练权重、冻结训练和断点恢复
知乎—吵鸡凶鸭OvO  侵删
01
If I have seen further, it is by standing on the shoulders of giants. 
02
# 第一步:读取当前模型参数model_dict = model.state_dict()# 第二步:读取预训练模型pretrained_dict = torch.load(model_path, map_location = device)pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}# 第三步:使用预训练的模型更新当前模型参数model_dict.update(pretrained_dict)# 第四步:加载模型参数model.load_state_dict(model_dict)
model_dict = model.state_dict()pretrained_dict = torch.load(model_path, map_location=device)temp = {}for k, v in pretrained_dict.items():try:if np.shape(model_dict[k]) == np.shape(v):temp[k]=vexcept:passmodel_dict.update(temp)
03
# 冻结阶段训练参数,learning_rate和batch_size可以设置大一点Init_Epoch = 0Freeze_Epoch = 50Freeze_batch_size = 8Freeze_lr = 1e-3# 解冻阶段训练参数,learning_rate和batch_size设置小一点UnFreeze_Epoch = 100Unfreeze_batch_size = 4Unfreeze_lr = 1e-4# 可以加一个变量控制是否进行冻结训练Freeze_Train = True# 冻结一部分进行训练batch_size = Freeze_batch_sizelr = Freeze_lrstart_epoch = Init_Epochend_epoch = Freeze_Epochif Freeze_Train:for param in model.backbone.parameters():param.requires_grad = False# 解冻后训练batch_size = Unfreeze_batch_sizelr = Unfreeze_lrstart_epoch = Freeze_Epochend_epoch = UnFreeze_Epochif Freeze_Train:for param in model.backbone.parameters():param.requires_grad = True
04
torch.save(model.state_dict(), "你要保存到的路径")05
猜您喜欢:
附下载 |《TensorFlow 2.0 深度学习算法实战》
评论
