Classifying the UC-Merced dataset with a pretrained GoogLeNet in PyTorch
- Dataset preparation
- Defining the network
- Using the GPU
- Training results
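Dataset preparation
The dataset is small (see: UC-Merced dataset and introduction). Because torchvision does not ship UC-Merced, you have to download it in advance and use it as a custom dataset.
Here I first build list (index) files for the dataset and then load them through a Dataset subclass.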
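import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms

# Values not given in the original snippet; adjust them to your setup.
root = './UCMerced_LandUse/'  # hypothetical folder that holds the list files
image_size = 224              # assumed: the usual GoogLeNet input size

def default_loader(path):
    # open an image file and force 3-channel RGB
    return Image.open(path).convert('RGB')

class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                # strip the trailing newline, then split on the tab into (path, label)
                line = line.strip('\n')
                if not line:
                    continue
                words = line.split('\t')
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        # resize every image first, then apply the caller's transform (if any)
        steps = [transforms.Resize((image_size, image_size))]
        if transform is not None:
            steps.append(transform)
        self.transform = transforms.Compose(steps)
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        # load the image with the loader passed in (default_loader by default)
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)

train_dataset = MyDataset(txt=root + 'Train List.txt', transform=transforms.ToTensor())
test_dataset = MyDataset(txt=root + 'Test List.txt', transform=transforms.ToTensor())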
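You must override both methods, __getitem__(self, index) and __len__(self).
Each line of the txt file holds one image's path (first column) and its label (second column), separated by a tab.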
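The original does not show how the list files are produced. Below is a minimal sketch of one way to write them, assuming the downloaded archive keeps one sub-folder per class (for example UCMerced_LandUse/Images/agricultural/...). The helper name write_list_files, the 15% test ratio, the fixed seed and the label order (sorted folder names) are hypothetical choices, not the author's.

```python
import os
import random

def write_list_files(image_root, train_txt, test_txt, test_ratio=0.15, seed=0,
                     exts=(".tif", ".tiff", ".png", ".jpg")):
    """Write '<path>\t<label>' lines (one image per line) for a train/test split.

    Assumes one sub-directory per class under image_root. Labels are assigned
    in sorted class-name order, and the split is done per class (stratified).
    Returns the class names so that label ids can be mapped back.
    """
    classes = sorted(d for d in os.listdir(image_root)
                     if os.path.isdir(os.path.join(image_root, d)))
    rng = random.Random(seed)  # fixed seed -> reproducible split
    train_lines, test_lines = [], []
    for label, cls in enumerate(classes):
        cls_dir = os.path.join(image_root, cls)
        files = sorted(f for f in os.listdir(cls_dir) if f.lower().endswith(exts))
        rng.shuffle(files)
        n_test = int(round(len(files) * test_ratio))
        for i, name in enumerate(files):
            line = f"{os.path.join(cls_dir, name)}\t{label}\n"
            # the first n_test shuffled files of each class go to the test list
            (test_lines if i < n_test else train_lines).append(line)
    with open(train_txt, "w") as f:
        f.writelines(train_lines)
    with open(test_txt, "w") as f:
        f.writelines(test_lines)
    return classes
```

Splitting per class keeps all 21 classes represented in both lists; the returned class list is worth saving so that predicted label ids can be turned back into class names.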
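Defining the network
A CNN I wrote myself performed poorly, reaching only about 30% test accuracy, and tuning its hyperparameters brought no improvement. So I switched to one of the pretrained networks bundled with torchvision.
After loading the pretrained model, print(net) shows its structure.
Remove the network's last fully connected layer and add a custom self.classifier, so that the output has 21 classes and parts of the network can be trained separately. If you hit errors such as "mat1 and mat2 shapes cannot be multiplied", add more print calls to check the tensor shapes and adjust accordingly.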
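Instead of guessing in_features for the new classifier, you can measure it. The following is a small sketch (infer_feature_dim is a hypothetical helper, not part of PyTorch): it pushes one dummy image through a backbone and returns the flattened feature size, which is the number the first nn.Linear must accept to avoid the mat1/mat2 error.

```python
import torch
import torch.nn as nn

def infer_feature_dim(backbone: nn.Module, image_size: int = 224) -> int:
    """Return the flattened output size of `backbone` for one 3-channel image."""
    was_training = backbone.training
    backbone.eval()  # eval mode: no Dropout, BatchNorm uses running statistics
    with torch.no_grad():
        out = backbone(torch.zeros(1, 3, image_size, image_size))
    backbone.train(was_training)  # restore the caller's mode
    return out.flatten(1).shape[1]
```

For the truncated GoogLeNet used here (all children except the final fc), this should report 1024, matching the in_features of the classifier below.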
Use .to(device), the approach PyTorch recommends, to copy the model's tensors to the GPU.
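# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
    def __init__(self, model):
        super(Net, self).__init__()
        # all GoogLeNet children except the final fc (ends with global avg-pool + dropout)
        self.GolN_layer = nn.Sequential(*list(model.children())[:-1])
        # new head: 1024 GoogLeNet features -> 21 UC-Merced classes
        self.classifier = nn.Sequential(nn.Linear(in_features=1024, out_features=1024),
                                        nn.ReLU(inplace=True),
                                        nn.Linear(in_features=1024, out_features=21)
                                        )

    def forward(self, x):
        # print('before:', x.size())
        x = self.GolN_layer(x)   # (N, 1024, 1, 1)
        # print('after GoogLeNet:', x.size())
        x = torch.flatten(x, 1)  # (N, 1024)
        # print('before classifier:', x.size())
        x = self.classifier(x)
        return x

# torchvision >= 0.13; older versions use models.googlenet(pretrained=True)
gnet_pretrained = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
mynet = Net(gnet_pretrained)
mynet = mynet.to(device)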
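The text says the separate classifier makes partial training easier. One way to do that, sketched here as an assumption rather than the author's confirmed setup, is to freeze every parameter outside self.classifier and give only the remaining ones to the optimizer. freeze_all_but is a hypothetical helper name.

```python
import torch.nn as nn

def freeze_all_but(model: nn.Module, trainable: str = "classifier"):
    """Freeze every parameter outside the submodule named `trainable`.

    Returns the parameters that stay trainable, ready to pass to an optimizer.
    """
    kept = []
    prefix = trainable + "."  # e.g. matches "classifier.0.weight"
    for name, param in model.named_parameters():
        param.requires_grad_(name.startswith(prefix))
        if param.requires_grad:
            kept.append(param)
    return kept
```

Usage, with an illustrative optimizer and learning rate: optimizer = torch.optim.SGD(freeze_all_but(mynet, "classifier"), lr=0.01, momentum=0.9). Note that frozen weights do not stop the BatchNorm layers inside GolN_layer from updating their running statistics in train mode; calling mynet.GolN_layer.eval() after mynet.train() keeps those fixed as well.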
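Using the GPU
During training, besides moving the model instance defined above, you also have to move every batch of tensors to the GPU.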
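# train_loader, optimizer and criterion are defined elsewhere (not shown in the original)
for batch_idx, (data, target) in enumerate(train_loader):
    # re-enable training mode every step: the periodic validation seen in the
    # log below presumably switches the model to eval mode mid-epoch
    mynet.train()
    # the batch must live on the same device as the model
    inp, label = data.to(device), target.to(device)
    optimizer.zero_grad()
    outp = mynet(inp)
    loss = criterion(outp, label)
    loss.backward()
    optimizer.step()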
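The loop above is only the inner batch loop. The sketch below shows one way to wrap it into per-epoch training and evaluation that report accuracies like those in the log. It assumes nn.CrossEntropyLoss as the criterion (a common choice for 21-way logits, not stated in the original); train_one_epoch and evaluate are hypothetical helper names.

```python
import torch
import torch.nn as nn

def train_one_epoch(model, loader, criterion, optimizer, device):
    """One pass over `loader`; returns (mean loss, training accuracy in %)."""
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for data, target in loader:
        inp, label = data.to(device), target.to(device)
        optimizer.zero_grad()
        outp = model(inp)
        loss = criterion(outp, label)
        loss.backward()
        optimizer.step()
        # accumulate per-sample loss and correct predictions
        total_loss += loss.item() * label.size(0)
        correct += (outp.argmax(dim=1) == label).sum().item()
        seen += label.size(0)
    return total_loss / seen, 100.0 * correct / seen

@torch.no_grad()
def evaluate(model, loader, device):
    """Accuracy in % with Dropout off and BatchNorm using running statistics."""
    model.eval()
    correct, seen = 0, 0
    for data, target in loader:
        inp, label = data.to(device), target.to(device)
        correct += (model(inp).argmax(dim=1) == label).sum().item()
        seen += label.size(0)
    return 100.0 * correct / seen
```

Per epoch you would call train_one_epoch(mynet, train_loader, criterion, optimizer, device) and then evaluate(mynet, test_loader, device), where the loaders are DataLoader objects over train_dataset and test_dataset (the batch size is not given in the original). Because evaluation here runs only after each epoch, model.train() is called once per epoch rather than once per batch.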
Training results
Training epoch: 15 [0/1784 (0%)] Loss: 0.273342 Train acc: 100.00% Val acc: 92.17%
Training epoch: 15 [400/1784 (22%)] Loss: 0.023752 Train acc: 95.79% Val acc: 92.17%
Training epoch: 15 [800/1784 (45%)] Loss: 0.521515 Train acc: 95.52% Val acc: 88.70%
Training epoch: 15 [1200/1784 (67%)] Loss: 0.958793 Train acc: 95.35% Val acc: 93.04%
Training epoch: 15 [1600/1784 (90%)] Loss: 0.035471 Train acc: 95.45% Val acc: 89.57%
Current Training epoch: 15, Current Testing Acc: 95.50%
-----------------------------------
Training epoch: 16 [0/1784 (0%)] Loss: 0.037701 Train acc: 100.00% Val acc: 90.43%
Training epoch: 16 [400/1784 (22%)] Loss: 0.143321 Train acc: 96.04% Val acc: 93.04%
Training epoch: 16 [800/1784 (45%)] Loss: 0.045256 Train acc: 96.02% Val acc: 93.04%
Training epoch: 16 [1200/1784 (67%)] Loss: 0.407427 Train acc: 96.76% Val acc: 89.57%
Training epoch: 16 [1600/1784 (90%)] Loss: 0.040848 Train acc: 96.82% Val acc: 95.65%
Training epoch: 17 [0/1784 (0%)] Loss: 0.104182 Train acc: 100.00% Val acc: 93.04%
Training epoch: 17 [400/1784 (22%)] Loss: 0.457662 Train acc: 95.79% Val acc: 90.43%
Training epoch: 17 [800/1784 (45%)] Loss: 0.293050 Train acc: 96.52% Val acc: 89.57%
Training epoch: 17 [1200/1784 (67%)] Loss: 0.053018 Train acc: 96.84% Val acc: 94.78%
Training epoch: 17 [1600/1784 (90%)] Loss: 0.067211 Train acc: 96.51% Val acc: 88.70%
Training epoch: 18 [0/1784 (0%)] Loss: 0.022396 Train acc: 100.00% Val acc: 93.04%
Training epoch: 18 [400/1784 (22%)] Loss: 0.780139 Train acc: 97.77% Val acc: 92.17%
Training epoch: 18 [800/1784 (45%)] Loss: 1.022786 Train acc: 97.26% Val acc: 89.57%
Training epoch: 18 [1200/1784 (67%)] Loss: 0.057382 Train acc: 97.34% Val acc: 92.17%
Training epoch: 18 [1600/1784 (90%)] Loss: 0.026848 Train acc: 97.13% Val acc: 90.43%
Training epoch: 19 [0/1784 (0%)] Loss: 0.106134 Train acc: 100.00% Val acc: 87.83%
Training epoch: 19 [400/1784 (22%)] Loss: 0.646356 Train acc: 97.77% Val acc: 90.43%
Training epoch: 19 [800/1784 (45%)] Loss: 0.143579 Train acc: 98.26% Val acc: 94.78%
Training epoch: 19 [1200/1784 (67%)] Loss: 0.352177 Train acc: 97.92% Val acc: 89.57%
Training epoch: 19 [1600/1784 (90%)] Loss: 0.040774 Train acc: 98.00% Val acc: 89.57%
Training epoch: 20 [0/1784 (0%)] Loss: 0.155777 Train acc: 100.00% Val acc: 90.43%
Training epoch: 20 [400/1784 (22%)] Loss: 0.032036 Train acc: 97.28% Val acc: 90.43%
Training epoch: 20 [800/1784 (45%)] Loss: 0.219169 Train acc: 96.52% Val acc: 93.04%
Training epoch: 20 [1200/1784 (67%)] Loss: 1.495021 Train acc: 97.01% Val acc: 90.43%
Training epoch: 20 [1600/1784 (90%)] Loss: 0.022227 Train acc: 96.76% Val acc: 93.91%
Current Training epoch: 20, Current Testing Acc: 93.00%
-----------------------------------
End of Training...