基于pytorch的图像分类框架-更新日志

源码地址 github


pytorch-classifier v1.1 更新日志

  • 2022.11.8
  1. 修改processing.py的分配数据集逻辑,之前是先分出test_size的数据作为测试集,然后再从剩下的数据里面分val_size的数据作为验证集,这种分数据的方式,当我们的val_size=0.2和test_size=0.2,最后出来的数据集比例不是严格等于6:2:2,现在修改为等比例的划分,也就是现在的逻辑分割数据集后严格等于6:2:2.
  2. 参考yolov5,训练中的模型保存改为FP16保存.(在精度基本保持不变的情况下,模型相比FP32小一半)
  3. metrice.py和predict.py新增支持FP16推理.(在精度基本保持不变的情况下,速度更加快)
  • 2022.11.9
  1. 支持albumentations库的数据增强.
  2. 训练过程新增R-Drop,具体在main.py中添加–rdrop参数即可.
  • 2022.11.10
  1. 利用Pycm库进行修改metrice.py中的可视化内容.增加指标种类.
  • 2022.11.11
  1. 支持EMA(Exponential Moving Average),具体在main.py中添加–ema参数即可.
  2. 修改早停法中的–patience机制,当–patience参数为0时,停止使用早停法.
  3. 知识蒸馏中增加了一些实验数据.
  4. 修复一些bug.

FP16推理实验:

实验环境:

System

CPU

GPU

RAM

Ubuntu

i9-12900KF

RTX-3090

32G

训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练resnext50:

python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练RepVGG-A0:

python main.py --model_name RepVGG-A0 --config config/config.py --save_path runs/RepVGG-A0 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

训练densenet121:

python main.py --model_name densenet121 --config config/config.py --save_path runs/densenet121 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

计算各个模型的指标:

python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/RepVGG-A0
    python metrice.py --task val --save_path runs/densenet121

    python metrice.py --task val --save_path runs/mobilenetv2 --half
    python metrice.py --task val --save_path runs/resnext50 --half
    python metrice.py --task val --save_path runs/RepVGG-A0 --half
    python metrice.py --task val --save_path runs/densenet121 --half

计算各个模型的fps:

python metrice.py --task fps --save_path runs/mobilenetv2
    python metrice.py --task fps --save_path runs/resnext50
    python metrice.py --task fps --save_path runs/RepVGG-A0
    python metrice.py --task fps --save_path runs/densenet121

    python metrice.py --task fps --save_path runs/mobilenetv2 --half
    python metrice.py --task fps --save_path runs/resnext50 --half
    python metrice.py --task fps --save_path runs/RepVGG-A0 --half
    python metrice.py --task fps --save_path runs/densenet121 --half

model

val accuracy(train stage)

val accuracy(test stage)

val accuracy half(test stage)

FP32 FPS(batch_size=64)

FP16 FPS(batch_size=64)

mobilenetv2

0.74284

0.74340

0.74396

52.43

92.80

resnext50

0.80966

0.80966

0.80966

19.48

30.28

RepVGG-A0

0.73666

0.73666

0.73666

54.74

98.87

densenet121

0.77035

0.77148

0.77035

18.87

32.75

R-Drop实验:

训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd --rdrop

训练resnext50:

python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

训练efficientnet_v2_s:

python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_rdrop --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --rdrop

计算各个模型的指标:

python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/mobilenetv2_rdrop
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/resnext50_rdrop
    python metrice.py --task val --save_path runs/ghostnet
    python metrice.py --task val --save_path runs/ghostnet_rdrop
    python metrice.py --task val --save_path runs/efficientnet_v2_s
    python metrice.py --task val --save_path runs/efficientnet_v2_s_rdrop

    python metrice.py --task test --save_path runs/mobilenetv2
    python metrice.py --task test --save_path runs/mobilenetv2_rdrop
    python metrice.py --task test --save_path runs/resnext50
    python metrice.py --task test --save_path runs/resnext50_rdrop
    python metrice.py --task test --save_path runs/ghostnet
    python metrice.py --task test --save_path runs/ghostnet_rdrop
    python metrice.py --task test --save_path runs/efficientnet_v2_s
    python metrice.py --task test --save_path runs/efficientnet_v2_s_rdrop

model

val accuracy

val accuracy(r-drop)

test accuracy

test accuracy(r-drop)

mobilenetv2

0.74340

0.75126

0.73784

0.73741

resnext50

0.80966

0.81134

0.82437

0.82092

ghostnet

0.77597

0.76698

0.76625

0.77012

efficientnet_v2_s

0.84166

0.85289

0.84460

0.85837

EMA实验:

训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd --ema

训练resnext50:

python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50 --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name resnext50 --config config/config.py --save_path runs/resnext50_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

训练efficientnet_v2_s:

python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd

    python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s_ema --lr 1e-4 --Augment AutoAugment --epoch 150 \
    --pretrained --amp --warmup --imagenet_meanstd  --ema

计算各个模型的指标:

python metrice.py --task val --save_path runs/mobilenetv2
    python metrice.py --task val --save_path runs/mobilenetv2_ema
    python metrice.py --task val --save_path runs/resnext50
    python metrice.py --task val --save_path runs/resnext50_ema
    python metrice.py --task val --save_path runs/ghostnet
    python metrice.py --task val --save_path runs/ghostnet_ema
    python metrice.py --task val --save_path runs/efficientnet_v2_s
    python metrice.py --task val --save_path runs/efficientnet_v2_s_ema

    python metrice.py --task test --save_path runs/mobilenetv2
    python metrice.py --task test --save_path runs/mobilenetv2_ema
    python metrice.py --task test --save_path runs/resnext50
    python metrice.py --task test --save_path runs/resnext50_ema
    python metrice.py --task test --save_path runs/ghostnet
    python metrice.py --task test --save_path runs/ghostnet_ema
    python metrice.py --task test --save_path runs/efficientnet_v2_s
    python metrice.py --task test --save_path runs/efficientnet_v2_s_ema

model

val accuracy

val accuracy(ema)

test accuracy

test accuracy(ema)

mobilenetv2

0.74340

0.74958

0.73784

0.73870

resnext50

0.80966

0.81246

0.82437

0.82307

ghostnet

0.77597

0.77765

0.76625

0.77142

efficientnet_v2_s

0.84166

0.83998

0.84460

0.83986


pytorch-classifier v1.2 更新日志

  1. 新增export.py,支持导出(onnx, torchscript, tensorrt)模型.
  2. metrice.py支持onnx,torchscript,tensorrt的推理.
此处在predict.py中暂不支持onnx,torchscript,tensorrt的推理的推理,原因是因为predict.py中的热力图可视化没办法在onnx、torchscript、tensorrt中实现,后续单独推理部分会额外写一部分代码.
 在metrice.py中,onnx和torchscript和tensorrt的推理也不支持tsne的可视化,那么我在metrice.py中添加onnx,torchscript,tensorrt的推理的目的是为了测试fps和精度.
 所以简单来说,使用metrice.py最好还是直接用torch模型,torchscript和onnx和tensorrt的推理的推理模型后续会写一个单独的推理代码.
  1. main.py,metrice.py,predict.py,export.py中增加–device参数,可以指定设备.
  2. 优化程序和修复一些bug.
训练命令:
python main.py --model_name efficientnet_v2_s --config config/config.py --batch_size 128 --Augment AutoAugment --save_path runs/efficientnet_v2_s --device 0 \
--pretrained --amp --warmup --ema --imagenet_meanstd
GPU 推理速度测试 sh脚本:
batch_size=1 # 1 2 4 8 16 32 64
python metrice.py --task fps --save_path runs/efficientnet_v2_s --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --half --model_type torchscript --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --half --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type onnx --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --batch_size $batch_size
python export.py --save_path runs/efficientnet_v2_s --export tensorrt --simplify --half --batch_size $batch_size 
python metrice.py --task fps --save_path runs/efficientnet_v2_s --model_type tensorrt --half --batch_size $batch_size
CPU 推理速度测试 sh脚本:
python export.py --save_path runs/efficientnet_v2_s --export onnx --simplify --dynamic --device cpu
batch_size=1
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=2
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=4
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=8
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size
batch_size=16
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type torchscript --batch_size $batch_size
python metrice.py --task fps --save_path runs/efficientnet_v2_s --device cpu --model_type onnx --batch_size $batch_size

各导出模型在cpu和gpu上的fps实验:

实验环境:

System

CPU

GPU

RAM

Model

Ubuntu20.04

i7-12700KF

RTX-3090

32G DDR5 6400

efficientnet_v2_s

GPU

model

Torch FP32 FPS

Torch FP16 FPS

TorchScript FP32 FPS

TorchScript FP16 FPS

ONNX FP32 FPS

ONNX FP16 FPS

TensorRT FP32 FPS

TensorRT FP16 FPS

batch-size 1

93.77

105.65

233.21

260.07

177.41

308.52

311.60

789.19

batch-size 2

94.32

108.35

208.53

253.83

166.23

258.98

275.93

713.71

batch-size 4

95.98

108.31

171.99

255.05

130.43

190.03

212.75

573.88

batch-size 8

94.03

85.76

118.79

210.58

87.65

122.31

147.36

416.71

batch-size 16

61.93

76.25

75.45

125.05

50.33

69.01

87.25

260.94

batch-size 32

34.56

58.11

41.93

72.29

26.91

34.46

48.54

151.35

batch-size 64

18.64

31.57

23.15

38.90

12.67

15.90

26.19

85.47

CPU

model

Torch FP32 FPS

Torch FP16 FPS

TorchScript FP32 FPS

TorchScript FP16 FPS

ONNX FP32 FPS

ONNX FP16 FPS

TensorRT FP32 FPS

TensorRT FP16 FPS

batch-size 1

27.91

Not Support

46.10

Not Support

79.27

Not Support

Not Support

Not Support

batch-size 2

25.26

Not Support

24.98

Not Support

45.62

Not Support

Not Support

Not Support

batch-size 4

14.02

Not Support

13.84

Not Support

23.90

Not Support

Not Support

Not Support

batch-size 8

7.53

Not Support

7.35

Not Support

12.01

Not Support

Not Support

Not Support

batch-size 16

3.07

Not Support

3.64

Not Support

5.72

Not Support

Not Support

Not Support


pytorch-classifier v1.3 更新日志

  1. 增加repghost模型.
  2. 推理阶段把模型中的conv和bn进行fuse.
  3. 发现mnasnet0_5有点问题,暂停使用.
  4. torch.no_grad()更换成torch.inference_mode().

pytorch-classifier v1.4 更新日志

  1. predict.py支持检测灰度图,其读取后会检测是否为RGB通道,不是的话会进行转换.
  2. 更新readme.md.
  3. 修复一些bug.

Knowledge Distillation Experiment

为了测试知识蒸馏的可用性,基于CUB-200-2011百度网盘链接数据集进行实验.

stduent为mobilenetv2,teacher为resnet50.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练resnet50:

python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/resnet50_admaw

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD1 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/resnet50_admaw

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw

计算通过resnet50蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta

model

val accuracy

val mpa

test accuracy

test mpa

test accuracy(TTA)

test mpa(TTA)

mobilenetv2

0.74116

0.74200

0.73483

0.73452

0.77012

0.76979

resnet50

0.78720

0.78744

0.77744

0.77670

0.81231

0.81162

teacher->resnet50

student->mobilenetv2

SoftTarget

0.77092

0.77179

0.75248

0.75191

0.77787

0.77752

teacher->resnet50

student->mobilenetv2

MGD

0.78888

0.78994

0.78390

0.78296

0.79940

0.79890

teacher->resnet50

student->mobilenetv2

AT

0.74789

0.74878

0.73870

0.73795

0.76324

0.76244

stduent为mobilenetv2,teacher为ghostnet.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/ghostnet_admaw

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.2 --teacher_path runs/ghostnet_admaw

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000.0 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST
python metrice.py --task test --save_path runs/mobilenetv2_admaw_ST --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD
python metrice.py --task test --save_path runs/mobilenetv2_admaw_MGD --test_tta
python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT --test_tta

model

val accuracy

val mpa

test accuracy

test mpa

test accuracy(TTA)

test mpa(TTA)

mobilenetv2

0.74116

0.74200

0.73483

0.73452

0.77012

0.76979

ghostnet

0.77709

0.77756

0.76367

0.76277

0.78046

0.77958

teacher->ghostnet

student->mobilenetv2

SoftTarget

0.77878

0.77968

0.76108

0.76022

0.77916

0.77807

teacher->ghostnet

student->mobilenetv2

MGD

0.75632

0.75723

0.74688

0.74638

0.77357

0.77302

teacher->ghostnet

student->mobilenetv2

AT

0.74846

0.74945

0.73827

0.73782

0.76625

0.76534

由于SP蒸馏开启AMP时,kd_loss大概率会出现nan,所在SP蒸馏实验中,我们把所有模型都不开启AMP.

stduent为mobilenetv2,teacher为ghostnet.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练ghostnet:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnet_admaw
python metrice.py --task test --save_path runs/ghostnetadmaw --test_tta

知识蒸馏, ghostnet作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP
python metrice.py --task test --save_path runs/mobilenetv2_admaw_SP --test_tta

model

val accuracy

val mpa

test accuracy

test mpa

test accuracy(TTA)

test mpa(TTA)

mobilenetv2

0.74509

0.74568

0.73827

0.73761

0.76969

0.76903

ghostnet

0.77821

0.77881

0.75807

0.75708

0.77873

0.77805

teacher->ghostnet

student->mobilenetv2

SP

0.74733

0.74836

0.73267

0.73198

0.75893

0.75850

stduent为mobilenetv2,teacher为resnet50.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw
python metrice.py --task test --save_path runs/mobilenetv2_admaw --test_tta

普通训练resnet50:

python main.py --model_name resnet50 --config config/config.py --save_path runs/resnet50_admaw --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd

计算resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw
python metrice.py --task test --save_path runs/resnet50_admaw --test_tta

知识蒸馏, resnet50作为teacher, mobilenetv2作为student, 使用SP进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_SP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --warmup --imagenet_meanstd \
--kd --kd_method SP --kd_ratio 10.0 --teacher_path runs/resnet50_admaw

model

val accuracy

val mpa

test accuracy

test mpa

test accuracy(TTA)

test mpa(TTA)

mobilenetv2

0.74509

0.74568

0.73827

0.73761

0.76969

0.76903

resnet50

0.78720

0.78707

0.77400

0.77321

0.81231

0.81138

teacher->resnet50

student->mobilenetv2

SP

0.74116

0.74200

0.74042

0.73969

0.76840

0.76753

以下实验是通过训练好的自身模型再作为教师模型进行训练.

知识蒸馏, resnet50作为teacher, resnet50作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/resnet50_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/resnet50_admaw

计算通过resnet50蒸馏resnet50指标:

python metrice.py --task val --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self
python metrice.py --task test --save_path runs/resnet50_admaw_AT_self --test_tta

知识蒸馏, mobilenetv2作为teacher, mobilenetv2作为student, 使用AT进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 100 --teacher_path runs/mobilenetv2_admaw

计算通过mobilenetv2蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self
python metrice.py --task test --save_path runs/mobilenetv2_admaw_AT_self --test_tta

知识蒸馏, ghostnet作为teacher, ghostnet作为student, 使用AT进行蒸馏:

python main.py --model_name ghostnet --config config/config.py --save_path runs/ghostnet_admaw_AT_self --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method AT --kd_ratio 1000 --teacher_path runs/ghostnet_admaw

计算通过ghostnet蒸馏ghostnet指标:

python metrice.py --task val --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self
python metrice.py --task test --save_path runs/ghostnet_admaw_AT_self --test_tta

model

val accuracy

val mpa

test accuracy

test mpa

test accuracy(TTA)

test mpa(TTA)

mobilenetv2

0.74116

0.74200

0.73483

0.73452

0.77012

0.76979

teacher->mobilenetv2

student->mobilenetv2

AT

0.74677

0.74758

0.74430

0.74342

0.77012

0.76926

resnet50

0.78720

0.78744

0.77744

0.77670

0.81231

0.81162

teacher->resnet50

student->resnet50

AT

0.79057

0.79091

0.79165

0.79026

0.81102

0.81030

ghostnet

0.77709

0.77756

0.76367

0.76277

0.78046

0.77958

teacher->ghostnet

student->ghostnet

AT

0.78046

0.78080

0.77142

0.77069

0.78820

0.78742

在V1.1版本的测试中发现efficientnet_v2网络作为teacher网络效果还不错.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2 --test_tta

普通训练efficientnet_v2_s:

python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算efficientnet_v2_s指标:

python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s --test_tta

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

计算通过efficientnet_v2_s蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP --test_tta

model

val accuracy

val mpa

test accuracy

test mpa

test accuracy(TTA)

test mpa(TTA)

mobilenetv2

0.74116

0.74200

0.73483

0.73452

0.77012

0.76979

efficientnet_v2_s

0.84166

0.84191

0.84460

0.84441

0.86483

0.86484

teacher->efficientnet_v2_s

student->mobilenetv2

ST

0.76137

0.76209

0.75161

0.75088

0.77830

0.77715

teacher->efficientnet_v2_s

student->mobilenetv2

MGD

0.77204

0.77288

0.77529

0.77464

0.79337

0.79261

teacher->efficientnet_v2_s

student->mobilenetv2

MGD(EMA)

0.77204

0.77267

0.77744

0.77671

0.80284

0.80201

teacher->efficientnet_v2_s

student->mobilenetv2

MGD(RDrop)

0.77204

0.77288

0.77529

0.77464

0.79337

0.79261

teacher->efficientnet_v2_s

student->mobilenetv2

MGD(EMA,RDrop)

0.77204

0.77267

0.77744

0.77671

0.80284

0.80201

关于Knowledge Distillation的一些解释

实验解释:

  1. 对于AT和SP蒸馏方法,上述实验都是使用block3和block4的特征层进行蒸馏.
  2. MPA是平均类别精度,在类别不平衡的情况下非常有用,当类别基本平衡的情况下,跟accuracy差不多.
  3. 当蒸馏loss出现nan的时候请不要开启AMP,AMP可能会导致浮点溢出导致的nan.

目前支持的类型有:

Name

Method

paper

SoftTarget

logits

https://arxiv.org/pdf/1503.02531.pdf

MGD

features

https://arxiv.org/abs/2205.01529.pdf

SP

features

https://arxiv.org/pdf/1907.09682.pdf

AT

features

https://arxiv.org/pdf/1612.03928.pdf

蒸馏学习跟模型,参数,蒸馏的方法,蒸馏的层都有关系,效果不好需要自行调整,其中SP和AT都可以对模型中的四个block进行组合计算蒸馏损失具体代码在utils/utils_fit.py的fitting_distill函数中可以进行修改.