网络模型库

对于深度学习,torchvision.models库提供了众多经典的网络结构与预训练模型,例如VGG、ResNet和Inception等,利用这些模型可以快速搭建物体检测网络,不需要逐层手动实现。
例: Resnet_50

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from torchvision import models
resnet = models.resnet50()
# 查看网络框架
print(len(resnet.layer1))
# 可以通过出现的顺序直接索引每一层
print(resnet.layer1[0])

'''
result:

3
Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)

加载预训练模型

  1. torchvision.models中自带的预训练模型,只需要在使用时赋予pretrained参数为True即可。
    例:
    1
    2
    from torchvision import models
    resnet = models.resnet50(pretrained = True)
  2. 如果想要使用自己的本地预训练模型,或者之前训练过的模型,则可以通过model.load_state_dict()函数操作
    例:
    1
    2
    3
    4
    5
    import torch
    from torchvision import models
    resnet = models.resnet50(pretrained = False)
    state_dict = torch.load('Resnet_four_classification/model1.pth')
    resnet.load_state_dict(state_dict,False)
  • 通常来讲,对于不同的检测任务,卷积网络的前两三层的作用是非常类似的,都是提取图像的边缘信息等,因此为了保证模型训练中能够更加稳定,一般会固定预训练网络的前两三个卷积层而不进行参数的学习
    例:
    1
    2
    3
    4
    frozen_layers = [resnet.conv1,resnet.layer2,resnet.layer3]
    for layer in frozen_layers:
    for name, value in layer.named_parameters():
    value.requires_grad = False

模型保存

  1. 直接保存整个模型
    例:torch.save(model_ft, 'model.pth')
  2. 选择性保存
    例:
    1
    2
    3
    4
    5
    6
    7
    torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
    }, PATH)

数据处理,加载,变换与增强

详细信息请见这两篇文章:

数据集图片变换与增强[transform][augmentation]
训练数据准备与导入[自定义数据集][分类问题]

Reference

深度学习之PyTorch物体检测实战 - 董洪义