网络模型库
对于深度学习,torchvision.models库提供了众多经典的网络结构与预训练模型,例如VGG、ResNet和Inception等,利用这些模型可以快速搭建物体检测网络,不需要逐层手动实现。
例: Resnet_50
1 | from torchvision import models |
加载预训练模型
- torchvision.models中自带的预训练模型,只需要在使用时赋予pretrained参数为True即可。
例:1
2from torchvision import models
resnet = models.resnet50(pretrained = True) - 如果想要使用自己的本地预训练模型,或者之前训练过的模型,则可以通过model.load_state_dict()函数操作
例:1
2
3
4
5import torch
from torchvision import models
resnet = models.resnet50(pretrained = False)
state_dict = torch.load('Resnet_four_classification/model1.pth')
resnet.load_state_dict(state_dict,False)
- 通常来讲,对于不同的检测任务,卷积网络的前两三层的作用是非常类似的,都是提取图像的边缘信息等,因此为了保证模型训练中能够更加稳定,一般会固定预训练网络的前两三个卷积层而不进行参数的学习
例:1
2
3
4frozen_layers = [resnet.conv1,resnet.layer2,resnet.layer3]
for layer in frozen_layers:
for name, value in layer.named_parameters():
value.requires_grad = False
模型保存
- 直接保存整个模型
例:torch.save(model_ft, 'model.pth')
- 选择性保存
例:1
2
3
4
5
6
7torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
数据处理,加载,变换与增强
详细信息请见这两篇文章:
Reference
深度学习之PyTorch物体检测实战 - 董洪义
money money money~ money money~
- 本文链接:http://yoursite.com/2020/08/19/DL_P9/
- 版权声明:本博客所有文章除特别声明外,均默认采用 许可协议。
若没有本文 Issue,您可以使用 Comment 模版新建。
GitHub Issues