今天和大家分享一个开源项目,利用ResNet-18进行猫狗分类。整个项目基于Pytorch1.7实现。

项目地址为:https://gitee.com/cv365/cat-dog_classification

1.背景

给定25000张图片,其中一半图片内容为猫,另外一半图片内容为狗,如下图所示:

机器学习猫狗分类 猫狗分类网络训练过程_深度学习

想要训练一个分类器,用于分类图片中的动物是猫还是狗。

我们将这25000张图片分成2部分,第一部分为20000张,2种类别各10000张,用于训练;第二部分为5000张,2种类别各2500张,用于测试网络性能。

基于ResNet-18进行训练,论文中ResNet-18中最后的FC层输出为1000,将其改为2,因为我们要进行2分类任务。

除了最后的FC层外,其余层的权重均使用在ImageNet上预训练的权重,FC层的权重使用Pytorch中默认的初始化方法进行初始化。

2.代码结构

代码结构如下图所示:

机器学习猫狗分类 猫狗分类网络训练过程_卷积神经网络_02

共有4个.py文件:

DogCatDataset.py用于读取数据

prepare_data.py用于将25000张图片分为训练集和测试集,运行该程序生成上图中的data/newtrain文件夹和data/newtest文件夹

train.py用于训练网络

test.py用于测试训练好的网络,在运行时需要resnet18_Cat_Dog.pth文件

3.训练、测试方法

(1)从gitee上下载代码,下载地址:https://gitee.com/cv365/cat-dog_classification

(2)下载数据集,下载地址:https://pan.baidu.com/s/19eG-kbPifVfIRGcgS21gZA,提取码:hjag

(3)将下载完的数据集存放至[工程主目录]/data路径下,并在该路径下解压数据集

(4)运行prepare_data.py文件完成数据集划分,运行完成后,在data路径下会生成newtrainnewtest这2个文件夹,分别存放训练集和测试集

训练命令
python train.py

训练完成后,在工程主目录下会生成名为resnet18_Cat_Dog.pth的权重文件,推理时会读取该文件。

测试命令
python test.py

推理完成后会打印出推理的正确率。

若不想训练,而想直接测试,也可以下载训练好的权重,下载地址:https://pan.baidu.com/s/1DykBh0ht5URLzdludSVPNQ,提取码:s7h2