项目背景
现在要对XXX
疾病进行二分类,通常医学上称之为阴性(无XXX
病),阳性(有XXX
病)。对于分类任务来说,二分类是最简单的分类任务。经典的分类网络(VGG,ResNet,DenseNet等)都是在ImageNet进行1000类分类任务。因此,本项目拟采用经典网络ResNet系列网络结构进行二分类实验。
基本内容
数据采集:特定设备采集人体3D数据,渲染生成训练需要的各种类型的2D图片。那么应该生成哪种类型的数据进行分类训练?要根据实际的分类任务和目标而定。主要考虑如下几个方面:
- 图片特征是否明显?
- 特征是否具有差异化,也即是能否将不同的类别区分开?
- 图片尺寸和深度,单通道图片无法利用现有的预训练模型;
网络结构:针对新的任务而言,首先选择主流的网络结构进行实验,该项目选择ResNet-50
系列变体。
程序结构:当数据和网络结构确定之后,开始实现整个项目,主要分为如下几个模块构建完整可训练的项目,
-
train.py
:主程序入口,构建整个项目的框架,每个功能模块的调用接口。 -
DataLoader.py
:数据读取模块,在一定硬件资源下,高效读取数据。 -
Model.py
:定义网络结构,以及各个模块的实现(比如,网络结构,损失函数,网络评估,日志保存等)。 -
config.py
:设定训练参数,测试参数和网络参数。 -
predict.py
:预测代码,加载测试数据,得到预测结果。
代码展示
下面的代码是train.py
, 主要包括数据读取,模型构建,开启训练与日志保存。该代码是基于TF1.X
最基本的训练流程,能够满足大部分项目需求。在实现一个新的项目时,首先构建整个框架结构,然后实现每个模块。这样有助于理清思路,快速实现项目。
from Model import *
from DataLoader import *
from config import cfg
from collections import Counter
def train():
tf.set_random_seed(-1)
# ******************************************************************
# 1. Python多线程数据读取与数据预处理
# ******************************************************************
train_queue = train_set_queue()
valid_queue = valid_set_queue()
# ******************************************************************
# 2. 构建静态图
# ******************************************************************
model_test = Model()
print('variables: ', model_test.first_stage_trainable_var_list)
# ******************************************************************
# 3. 开启会话, 执行静态图
# ******************************************************************
init = tf.global_variables_initializer()
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
with tf.Session(config=config_proto) as sess:
sess.run(init)
try:
print('=> Restoring weights from: %s ... ' % model_test.pre_trained_model)
model_test.loader.restore(sess, model_test.pre_trained_model)
print('成功加载模型')
except:
print('Load Pretrained Weight Failed !!!\n')
print('=> %s does not exist !!!' % model_test.pre_trained_model)
print('=> Now it starts to train from scratch ...')
cfg.Train.First_Stage_Epochs = 0 # 重头开始训练网络, 不需要预训练权重
else:
print('\nLoad Model Success !!!\n')
start_epoch = 0
# 训练日志保存目录
save_path = cfg.Train.Log
train_writer = tf.summary.FileWriter(save_path + 'train', sess.graph)
valid_writer = tf.summary.FileWriter(save_path + 'valid')
for epoch in range(start_epoch+1, 1 + cfg.Train.First_Stage_Epochs + cfg.Train.Second_Stage_Epochs):
if epoch < cfg.Train.First_Stage_Epochs:
train_op = model_test.train_op_with_frozen_variables
else:
train_op = model_test.train_op_with_all_variables
# Train Process
for step in range(0, cfg.Train.Train_Num//cfg.Train.Batch_Size):
# 每次从队列中取出一个batch的数据
_, _, image, label = train_queue.get()
# 统计每个batch中标签的二分类分布情况
# result = Counter(label)
# 输入参数
train_params = [train_op,
model_test.loss,
model_test.accuracy,
model_test.merged,
model_test.learn_rate,
model_test.global_step]
# 输出Tensor
train_dict = {model_test.inputs: image,
model_test.label_c: label,
model_test.trainable: True}
_, train_loss, train_accuracy, merge_train, lr_, global_step_ = sess.run(train_params,
feed_dict=train_dict)
train_writer.add_summary(merge_train, global_step_)
print('iter:%2d/%d || train loss:%.4f || train accuracy:%.4f || lr:%g ' % (step,
epoch,
train_loss,
train_accuracy,
lr_))
# print('阴性样本:%d || 阳性样本:%d ' % (result[0], result[1]))
# Valid Process
valid = []
label_ = []
predict_ = []
for valid_step in range(cfg.Train.Valid_Num//cfg.Train.Batch_Size):
_, _, valid_image, valid_label = valid_queue.get()
valid_params = [model_test.accuracy,
model_test.loss,
model_test.merged,
tf.argmax(model_test.end_point_squeeze, 1),
model_test.label_c]
valid_dict = {model_test.inputs: valid_image,
model_test.label_c: valid_label,
model_test.trainable: False}
valid_accuracy, valid_loss, merge_valid, predict, label_c = sess.run(valid_params, feed_dict=valid_dict)
valid_writer.add_summary(merge_valid, global_step_)
print('==================================================================================')
print('step=%2d/%d || valid accuracy=%g || loss=%.4f ' % (valid_step,
epoch,
valid_accuracy,
valid_loss))
print('==================================================================================')
print('标签值:', label_c)
print('预测值:', predict)
label_.append(label_c)
predict_.append(predict)
valid.append(valid_accuracy)
# 根据具体的项目需求,打印相应的输出,有助于动态观察网络的训练过程
# 计算所有验证集的混淆矩阵
print('mean valid accuracy: ', np.mean(valid))
confusion_matrix = tf.confusion_matrix(np.hstack(label_), np.hstack(predict_), num_classes=2)
confusion_matrix_ = sess.run(confusion_matrix)
TN = confusion_matrix_[0][0]
FP = confusion_matrix_[0][1]
FN = confusion_matrix_[1][0]
TP = confusion_matrix_[1][1]
acc = (TP + TN) / (TP + TN + FP + FN)
sensitive = TP / (TP + FN)
specify = TN / (TN + FP)
print('混淆矩阵\n', confusion_matrix_)
print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)
# 保存每个epoch的模型
if epoch > 50 and epoch % 5 == 0:
model_test.saver.save(sess, cfg.Train.Save_Model, global_step=epoch)
if __name__ == '__main__':
train()