项目背景

  现在要对XXX疾病进行二分类,通常医学上称之为阴性(无XXX病),阳性(有XXX病)。对于分类任务来说,二分类是最简单的分类任务。经典的分类网络(VGG,ResNet,DenseNet等)都是在ImageNet进行1000类分类任务。因此,本项目拟采用经典网络ResNet系列网络结构进行二分类实验。

基本内容

数据采集:特定设备采集人体3D数据,渲染生成训练需要的各种类型的2D图片。那么应该生成哪种类型的数据进行分类训练?要根据实际的分类任务和目标而定。主要考虑如下几个方面:

  1. 图片特征是否明显?
  2. 特征是否具有差异化,也即是能否将不同的类别区分开?
  3. 图片尺寸和深度,单通道图片无法利用现有的预训练模型;

网络结构:针对新的任务而言,首先选择主流的网络结构进行实验,该项目选择ResNet-50系列变体。
程序结构:当数据和网络结构确定之后,开始实现整个项目,主要分为如下几个模块构建完整可训练的项目,

  1. train.py:主程序入口,构建整个项目的框架,每个功能模块的调用接口。
  2. DataLoader.py:数据读取模块,在一定硬件资源下,高效读取数据。
  3. Model.py:定义网络结构,以及各个模块的实现(比如,网络结构,损失函数,网络评估,日志保存等)。
  4. config.py:设定训练参数,测试参数和网络参数。
  5. predict.py:预测代码,加载测试数据,得到预测结果。

代码展示

下面的代码是train.py, 主要包括数据读取模型构建开启训练日志保存。该代码是基于TF1.X最基本的训练流程,能够满足大部分项目需求。在实现一个新的项目时,首先构建整个框架结构,然后实现每个模块。这样有助于理清思路,快速实现项目。

from Model import *
from DataLoader import *
from config import cfg
from collections import Counter

def train():
    tf.set_random_seed(-1)

    # ****************************************************************** 
    #                   1. Python多线程数据读取与数据预处理                  
    # ****************************************************************** 
    train_queue = train_set_queue()
    valid_queue = valid_set_queue()

    # ****************************************************************** 
    #                   2. 构建静态图                                     
    # ****************************************************************** 
    model_test = Model()
    print('variables: ', model_test.first_stage_trainable_var_list)

    # ****************************************************************** 
    #                   3. 开启会话, 执行静态图                            
    # ****************************************************************** 
    init = tf.global_variables_initializer()
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    with tf.Session(config=config_proto) as sess:
        sess.run(init)
        try:
            print('=> Restoring weights from: %s ... ' % model_test.pre_trained_model)
            model_test.loader.restore(sess, model_test.pre_trained_model)
            print('成功加载模型')
        except:
            print('Load Pretrained Weight Failed !!!\n')
            print('=> %s does not exist !!!' % model_test.pre_trained_model)
            print('=> Now it starts to train from scratch ...')
            cfg.Train.First_Stage_Epochs = 0  # 重头开始训练网络, 不需要预训练权重
        else:
            print('\nLoad Model Success !!!\n')

        start_epoch = 0

        # 训练日志保存目录
        save_path = cfg.Train.Log
        train_writer = tf.summary.FileWriter(save_path + 'train', sess.graph)
        valid_writer = tf.summary.FileWriter(save_path + 'valid')

        for epoch in range(start_epoch+1, 1 + cfg.Train.First_Stage_Epochs + cfg.Train.Second_Stage_Epochs):
            if epoch < cfg.Train.First_Stage_Epochs:
                train_op = model_test.train_op_with_frozen_variables
            else:
                train_op = model_test.train_op_with_all_variables

            # Train Process
            for step in range(0, cfg.Train.Train_Num//cfg.Train.Batch_Size):
                # 每次从队列中取出一个batch的数据
                _, _, image, label = train_queue.get()

                # 统计每个batch中标签的二分类分布情况
                # result = Counter(label)
				# 输入参数
                train_params = [train_op,
                                model_test.loss,
                                model_test.accuracy,
                                model_test.merged,
                                model_test.learn_rate,
                                model_test.global_step]
				# 输出Tensor
                train_dict = {model_test.inputs: image,
                              model_test.label_c: label,
                              model_test.trainable: True}

                _, train_loss, train_accuracy, merge_train, lr_, global_step_ = sess.run(train_params,
                                                                                         feed_dict=train_dict)

                train_writer.add_summary(merge_train, global_step_)
                print('iter:%2d/%d || train loss:%.4f || train accuracy:%.4f || lr:%g ' % (step,
                                                                                           epoch,
                                                                                           train_loss,
                                                                                           train_accuracy,
                                                                                           lr_))
                # print('阴性样本:%d || 阳性样本:%d ' % (result[0], result[1]))

            # Valid Process
            valid = []
            label_ = []
            predict_ = []
            for valid_step in range(cfg.Train.Valid_Num//cfg.Train.Batch_Size):
                _, _, valid_image, valid_label = valid_queue.get()

                valid_params = [model_test.accuracy,
                                model_test.loss,
                                model_test.merged,
                                tf.argmax(model_test.end_point_squeeze, 1),
                                model_test.label_c]

                valid_dict = {model_test.inputs: valid_image,
                              model_test.label_c: valid_label,
                              model_test.trainable: False}

                valid_accuracy, valid_loss, merge_valid, predict, label_c = sess.run(valid_params, feed_dict=valid_dict)

                valid_writer.add_summary(merge_valid, global_step_)
                print('==================================================================================')
                print('step=%2d/%d || valid accuracy=%g || loss=%.4f ' % (valid_step,
                                                                          epoch,
                                                                          valid_accuracy,
                                                                          valid_loss))
                print('==================================================================================')
                print('标签值:', label_c)
                print('预测值:', predict)

                label_.append(label_c)
                predict_.append(predict)
                valid.append(valid_accuracy)
			
			# 根据具体的项目需求,打印相应的输出,有助于动态观察网络的训练过程
            # 计算所有验证集的混淆矩阵
            print('mean valid accuracy: ', np.mean(valid))
            confusion_matrix = tf.confusion_matrix(np.hstack(label_), np.hstack(predict_), num_classes=2)
            confusion_matrix_ = sess.run(confusion_matrix)

            TN = confusion_matrix_[0][0]
            FP = confusion_matrix_[0][1]
            FN = confusion_matrix_[1][0]
            TP = confusion_matrix_[1][1]

            acc = (TP + TN) / (TP + TN + FP + FN)
            sensitive = TP / (TP + FN)
            specify = TN / (TN + FP)

            print('混淆矩阵\n', confusion_matrix_)
            print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)

            # 保存每个epoch的模型
            if epoch > 50 and epoch % 5 == 0:
                model_test.saver.save(sess, cfg.Train.Save_Model, global_step=epoch)


if __name__ == '__main__':
    train()