#实践中,受限于数据集规模的约束,我们很少从头开始端到端的训练一个神经网络。通常情况下,
# 我们会选择在ImageNet数据集上预训练好的网络模型上进行适当的修改,使其适用于目标数据集。

#首先,修改网络模型的最后一个全连接层,使其适应于目标数据集,
# 使用预训练的网络权重来初始化网络模型的权重,用自己的图像数据来微调训练网络。微调网络主要有以下两种做法:

#1.只训练最后一个全连接层,冻结除最后一个全连接层外的所有层的权重。
#2.所有网络层都参与训练,不过最后一个全连接层在训练时使用更大的学习率,通常最后一个全连接层的学习率是前面层学习率的10倍。

#下面基于迁移学习实现一个ResNet18来对蜜蜂和蚂蚁分类,点击这里下载数据集。蚂蚁和蜜蜂大约均有120幅训练图像。每个类别有75幅验证图像。

from __future__ import print_function, division

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import time
import os
import copy

# 是否使用gpu运算
use_gpu = torch.cuda.is_available()
# 数据预处理,Pytorch提供了一个数据预处理的操作对象。定义如下:
data_transforms = {
    'train': transforms.Compose([
        # 随机在图像上裁剪出224*224大小的图像
        transforms.RandomResizedCrop(224),
        # 将图像随机翻转
        transforms.RandomHorizontalFlip(),
        # 将图像数据,转换为网络训练所需的tensor向量
        transforms.ToTensor(),
        # 图像归一化处理
        # 个人理解,前面是3个通道的均值,后面是3个通道的方差
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 读取数据
# 这种数据读取方法,需要有train和val两个文件夹,
# 每个文件夹下一类图像存在一个文件夹下
#在对分类的数据进行处理的时候,可以使用Pytorch提供的ImageFolder类来实现数据预处理。
#首先需要定义数据集的根目录:
data_dir = '../data/hymenoptera_data'
#然后,对于train和val这两个分别使用ImageFolder处理.这时,ImageFolder已经完成了照片数据的分类,并将这些图片的分类信息放倒了image_datasets变量中,
#可以看到,ImageFolder类已经将ants,bees做好了分类,并赋值为0和1。并且,训练数据以及测试数据被很好的分开。
#data_transforms对象在ImageFolder进行数据处理的时候作为参数传入,可以将上面数据处理的代码改为如下形式:
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

#有了ImageFolder获取到的image_datasets,这里只是找到了数据的路径以及相对应的类别,
# Pytorch还提供了DataLoader类,用于在训练时,实时获取数据对应的训练数据。代码如下:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
#DataLoader的第一个参数为上面获取到的image_datasets,第二个参数为batch_size,
#表示的是批训练时每批样本的数量。参数shuffle表示的是是否打乱数据的顺序,True表示打乱。参数num_workers表示参与计算的CPU核心数。

# 读取数据集大小 train:244,val:153
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# 数据类别 ['ants','bees']
class_names = image_datasets['train'].classes

# 训练与验证网络(所有层都参加训练)
def train_model(model, criterion, optimizer, scheduler, num_epochs):
    since = time.time() #返回的是毫秒
    # 保存网络训练最好的权重
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 每训练一个epoch,测试一下网络模型的准确率
        for phase in ['train', 'val']: #phase=='train'
            if phase == 'train':
                # 学习率更新方式
                scheduler.step()
                #  调用模型训练
                model.train(True)
            else:
                # 调用模型测试
                model.train(False)

            running_loss = 0.0
            running_corrects = 0
            # 依次获取所有图像,参与模型训练或测试
            for data in dataloaders[phase]:
                # 获取输入
                inputs, labels = data
                # 判断是否使用gpu
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                # 梯度清零
                optimizer.zero_grad()

                # 网络前向运行
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1) #获取最大值索引
                # 计算Loss值,交叉熵损失函数,其内部会自动加上Sofrmax层
                loss = criterion(outputs, labels)

                # 反传梯度,更新权重
                if phase == 'train':
                    # 反传梯度
                    loss.backward()
                    # 更新权重
                    optimizer.step()

                # 计算一个epoch的loss值和准确率,inputs.size(0)=4,
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            # 计算Loss和准确率的均值
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = float(running_corrects) / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # 保存测试阶段,准确率最高的模型
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # 网络导入最好的网络权重
    model.load_state_dict(best_model_wts)
    return model

# 微调网络
if __name__ ==  '__main__':

    # 导入Pytorch中自带的resnet18网络模型
    model_ft = models.resnet18(pretrained=True)
    # 将网络模型的各层的梯度更新置为False
    for param in model_ft.parameters():
        param.requires_grad = False

    # 修改网络模型的最后一个全连接层
    # 获取最后一个全连接层的输入通道数
    num_ftrs = model_ft.fc.in_features
    # 修改最后一个全连接层的的输出数为2
    model_ft.fc = nn.Linear(num_ftrs, 2)
    # 是否使用gpu
    if use_gpu:
        model_ft = model_ft.cuda()

    # 定义网络模型的损失函数
    criterion = nn.CrossEntropyLoss()

    # 只训练最后一个层
    # 采用随机梯度下降的方式,来优化网络模型
    optimizer_ft = torch.optim.SGD(model_ft.fc.parameters(), lr=0.001, momentum=0.9)

    # 定义学习率的更新方式,每5个epoch修改一次学习率
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1)
    # 训练网络模型
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=10)
    # 存储网络模型的权重
    torch.save(model_ft.state_dict(),"model_only_fc.pkl")