模型调用

首先再models/__init__.py中,有以下代码:

from .alexnet import AlexNet
from .resnet34 import ResNet34
from .squeezenet import SqueezeNet
# from torchvision.models import InceptinV3
# from torchvision.models import alexnet as AlexNet

这样主函数在调用时就可以写成:

from models import AlexNetg
#或
import models
model = models.AlexNet()
#或
import models
model = getattr(models,'AlexNet')().eval()

书中比较推荐的是第三种方法,但是比较绕,个人理解如下:

首先是getattr和eval两个函数的理解:

getattr(x,'foobar')等效于x.foobar

model.eval()是评估模式,用于验证集和测试集,表示我们用验证集或测试集进行评估时,不改变dropout和batch normalization的参数。

(工程代码中getattr的第二个参数是opt.model,这需要读取配置文件中的model参数,小白还没有具体研究,待研究明白再来这里做下记录。)这样如果再使用别的模型的时候,只需要修改一个字符串就好了,便于后期维护。

模型定义

在说具体的模型之前,先来研究一下basic_module.py中的用来加载和保存模型的接口,首先是以下两个函数

state_dict(destination = None,prefix = '' ,keep_vars = False)

返回一个包含整个模块训练参数的字典表,key值和参数name相一致,prefix添加前缀

load_state_dict(state_dict,strcict =True)

从传入的 state_dict 中复制 网络训练参数和缓存 到当前网络及其子类中去,如果strict为True,传入的训练参数的key值必须准确匹配当前网络参数的key值

初始化时将类属性model_name设置为模型默认的名字,在config.py中可配置

保存模型时,保存模型的所有参数,文件保存在checkpoints文件下,文件用‘alexnet_月日_时.分.秒.pth’命名,需要注意的是windows系统中文件名不能用‘:’,所以第21行中的冒号改为了点号

加载模型时,传入模型路径即可加载

优化器的选择也封装在这里,采用Adam优化器,后面还有一个Flat类,reshape输入数据用,还没研究明白,搞懂回来补充,上代码!

#coding:utf8
import torch as t
import time


class BasicModule(t.nn.Module):
    """
    封装了nn.Module,主要是提供了save和load两个方法
    """

    def __init__(self):
        super(BasicModule,self).__init__()
        self.model_name=str(type(self))# 模型的默认名字

    def save(self, name=None):
        """
        保存模型,默认使用“模型名字+时间”作为文件名
        """
        if name is None:
            prefix = 'checkpoints/' + self.model_name + '_'
            name = time.strftime(prefix + '%m%d_%H.%M.%S.pth')
        t.save(self.state_dict(), name)
        return name

    def load(self, path):
        """
        可加载指定路径的模型
        """
        self.load_state_dict(t.load(path))

    def get_optimizer(self, lr, weight_decay):
        return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)


class Flat(t.nn.Module):
    """
    把输入reshape成(batch_size,dim_length)
    """

    def __init__(self):
        super(Flat, self).__init__()
        #self.size = size

    def forward(self, x):
        return x.view(x.size(0), -1)