图像分类通用测试代码

设备选择

运用生成器的格式,选择GPU其中哪片"cuda:0" 或者cpu "cpu",然后,输出使用的设备。

device = torch.device("cuda:0" if torch.cuda.is_avalible() else "cpu")
print("using {} device.".formate(device))

图片转换操作

定义字典形式的data_transform,运用transforms进行图片转换起到图片增强效果。其中,有"train""val"两个键,注意这里数据增强只是在训练集中,在测试集中,直接缩减大小,所以分成俩组。键值transforms.Compose中参数是列表格式[],

data_transform = {
	"train":transforms.Compose([transforms.RandomResizeCrop(224,224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Narmolize((0.5,0.5,0.5)(0.5,0.5,0.5))])
	"val":transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Narmolize((0.5,0.5,0.5)(0.5,0.5,0.5))])
}

其中,transforms的相关用法进行详细展开。


transforms的二十二个方法

一、 裁剪——Crop

1.随机裁剪:transforms.RandomCrop

class torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=‘constant’)
功能:依据给定的size随机裁剪
参数:
size- (sequence or int),若为sequence,则为(h,w),若为int,则(size,size)
padding-(sequence or int, optional),此参数是设置填充多少个pixel。
当为int时,图像上下左右均填充int个,例如padding=4,则上下左右均填充4个pixel,若为3232,则会变成4040。
当为sequence时,若有2个数,则第一个数表示左右扩充多少,第二个数表示上下的。当有4个数时,则为左,上,右,下。
fill- (int or tuple) 填充的值是什么(仅当填充模式为constant时有用)。int时,各通道均填充该值,当长度为3的tuple时,表示RGB通道需要填充的值。
padding_mode- 填充模式,这里提供了4种填充模式,1.constant,常量。2.edge 按照图片边缘的像素值来填充。3.reflect,暂不了解。 4. symmetric,暂不了解。

2.中心裁剪:transforms.CenterCrop

class torchvision.transforms.CenterCrop(size)
功能:依据给定的size从中心裁剪
参数:
size- (sequence or int),若为sequence,则为(h,w),若为int,则(size,size)

3.随机长宽比裁剪 transforms.RandomResizedCrop

class torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)
功能:随机大小,随机长宽比裁剪原始图片,最后将图片resize到设定好的size
参数:
size- 输出的分辨率
scale- 随机crop的大小区间,如scale=(0.08, 1.0),表示随机crop出来的图片会在的0.08倍至1倍之间。
ratio- 随机长宽比设置
interpolation- 插值的方法,默认为双线性插值(PIL.Image.BILINEAR)

4.上下左右中心裁剪:transforms.FiveCrop

class torchvision.transforms.FiveCrop(size)
功能:对图片进行上下左右以及中心裁剪,获得5张图片,返回一个4D-tensor
参数:
size- (sequence or int),若为sequence,则为(h,w),若为int,则(size,size)

5.上下左右中心裁剪后翻转: transforms.TenCrop

class torchvision.transforms.TenCrop(size, vertical_flip=False)
功能:对图片进行上下左右以及中心裁剪,然后全部翻转(水平或者垂直),获得10张图片,返回一个4D-tensor。
参数:
size- (sequence or int),若为sequence,则为(h,w),若为int,则(size,size)
vertical_flip (bool) - 是否垂直翻转,默认为flase,即默认为水平翻转

二、翻转和旋转——Flip and Rotation

6.依概率p水平翻转transforms.RandomHorizontalFlip

class torchvision.transforms.RandomHorizontalFlip(p=0.5)
功能:依据概率p对PIL图片进行水平翻转
参数:
p- 概率,默认值为0.5

7.依概率p垂直翻转transforms.RandomVerticalFlip

class torchvision.transforms.RandomVerticalFlip(p=0.5)
功能:依据概率p对PIL图片进行垂直翻转
参数:
p- 概率,默认值为0.5

8.随机旋转:transforms.RandomRotation

class torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None)
功能:依degrees随机旋转一定角度
参数:
degress- (sequence or float or int) ,若为单个数,如 30,则表示在(-30,+30)之间随机旋转
若为sequence,如(30,60),则表示在30-60度之间随机旋转
resample- 重采样方法选择,可选 PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC,默认为最近邻
expand- ?
center- 可选为中心旋转还是左上角旋转

三、图像变换

9.resize:transforms.Resize

class torchvision.transforms.Resize(size, interpolation=2)
功能:重置图像分辨率
参数:
size- If size is an int, if height > width, then image will be rescaled to (size * height / width, size),所以建议size设定为h*w
interpolation- 插值方法选择,默认为PIL.Image.BILINEAR

10.标准化:transforms.Normalize

class torchvision.transforms.Normalize(mean, std)
功能:对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc

11.转为tensor:transforms.ToTensor

class torchvision.transforms.ToTensor
功能:将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
注意事项:归一化至[0-1]是直接除以255,若自己的ndarray数据尺度有变化,则需要自行修改。

12.填充:transforms.Pad

class torchvision.transforms.Pad(padding, fill=0, padding_mode=‘constant’)
功能:对图像进行填充
参数:
padding-(sequence or int, optional),此参数是设置填充多少个pixel。
当为int时,图像上下左右均填充int个,例如padding=4,则上下左右均填充4个pixel,若为3232,则会变成4040。
当为sequence时,若有2个数,则第一个数表示左右扩充多少,第二个数表示上下的。当有4个数时,则为左,上,右,下。
fill- (int or tuple) 填充的值是什么(仅当填充模式为constant时有用)。int时,各通道均填充该值,当长度为3的tuple时,表示RGB通道需要填充的值。
padding_mode- 填充模式,这里提供了4种填充模式,1.constant,常量。2.edge 按照图片边缘的像素值来填充。3.reflect,? 4. symmetric,?

13.修改亮度、对比度和饱和度:transforms.ColorJitter

class torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
功能:修改修改亮度、对比度和饱和度

14.转灰度图:transforms.Grayscale

class torchvision.transforms.Grayscale(num_output_channels=1)
功能:将图片转换为灰度图
参数:
num_output_channels- (int) ,当为1时,正常的灰度图,当为3时, 3 channel with r == g == b

15.线性变换:transforms.LinearTransformation()

class torchvision.transforms.LinearTransformation(transformation_matrix)
功能:对矩阵做线性变化,可用于白化处理! whitening: zero-center the data, compute the data covariance matrix
参数:
transformation_matrix (Tensor) – tensor [D x D], D = C x H x W

16.仿射变换:transforms.RandomAffine

class torchvision.transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)
功能:仿射变换

17.依概率p转为灰度图:transforms.RandomGrayscale

class torchvision.transforms.RandomGrayscale(p=0.1)
功能:依概率p将图片转换为灰度图,若通道数为3,则3 channel with r == g == b

18.将数据转换为PILImage:transforms.ToPILImage

class torchvision.transforms.ToPILImage(mode=None)
功能:将tensor 或者 ndarray的数据转换为 PIL Image 类型数据
参数:
mode- 为None时,为1通道, mode=3通道默认转换为RGB,4通道默认转换为RGBA

19.transforms.Lambda

Apply a user-defined lambda as a transform.
暂不了解,待补充。

四、对transforms操作,使数据增强更灵活

PyTorch不仅可设置对图片的操作,还可以对这些操作进行随机选择、组合

20.transforms.RandomChoice(transforms)

功能:从给定的一系列transforms中选一个进行操作,randomly picked from a list

21.transforms.RandomApply(transforms, p=0.5)

功能:给一个transform加上概率,以一定的概率执行该操作

22.transforms.RandomOrder

功能:将transforms中的操作顺序随机打乱

路径拼接

运用os库指令,

  • os.path.join() 拼接路径,将目录与文件名整合在一起。
  • os.path.abspath() 获取绝对路径(完整路径),os.path.abspath无法获取指定文件的绝对路径,而是需要加文件路径os.path.abspath(path)。
  • os.path.getcwd() 返回当前工作目录。os.getcwd()返回的是当前目录并不是指脚本所在的目录,而是所运行脚本的目录。
  • os.path.listdir()函数获得指定目录中的内容。
  • os.path.basename()去掉目录路径,返回文件名。
  • os.path.dirname()去掉文件名,返回目录路径。
  • os.path.split() 返回目录路径和文件名的元组
  • os.path.getatime() 返回文件最近的访问时间
  • os.path.getctime() #返回文件的创建时间
  • os.path.getmtime() #返回文件的修改时间
  • os.path.getsize() #返回文件的大小单位为字节
  • os.path.exists() #指定路径(文件或目录)是否存在
  • os.path.isfile() #指定的路径是否为一个文件
  • os.path.samefile() #两个路径名是否指向同一个文件

assert()用法: 先计算expression表达式的值,如果计算结果为真,继续运行下面的程序;如果计算结果为假,则程序终止运行。在这里如果路径没有,即为假,输出后面句子终止。

data_root = os.path.abspath(os.path.join(os.path.getcwd(),"../.."))
image_root = os.path.join(data_root,"data_set","flower_data")
assert os.path.exist(image_path), "{} path does not exist.".formate(image_path)

数据导入

ImageFolder是一个通用的数据加载器,它要求以特殊格式来组织数据集的训练、验证或者测试图片。假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名。
它主要有四个参数:

  1. root:在root指定的路径下寻找图片
  2. transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
  3. target_transform:对label的转换
  4. loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象。

成员变量:

  • self.classes --用一个list保存类名
  • self.class_to_idx -- 类名对应的索引
  • self.imgs -- 保存(img-path, class) tuple的list。

DataLoader数据读取,主要用来将自定义的数据读取的输出或者PyTorch已有的数据读取的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此有点承上启下的作用,比较重要。又称为数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
输入参数

  • 1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。
  • 2、batch_size,根据具体情况设置即可。
  • 3、shuffle,一般在训练数据中会采用。
  • 4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。
  • 5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。
  • 6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。
  • 7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。
  • 8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。
  • 9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
train_dataset = Datasets.ImageFolder(root=os.path.join(image_root,"train"), transform=data_transform["train"])
train_num = len(train_dataset)
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0 ,8 ])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size ,shuffle=True, num_workers=nw)
val_dataset = Datasets.ImageFolder(root=os.path.join(image_root,"val"), transform=data_transform["val"])
val_num = len(train_dataset)
val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)
print("using {} image for training, {} image for validation.".format(train_num,val_num))

生成JSON格式文件

JSON (JavaScript Object Notation)指的是 JavaScript 对象表示法。JSON 是轻量级的数据存储格式,与开发语言无关。首先一个花括号{},整个代表一个对象,同时里面是一种Key-Value的存储形式,它还有不同的数据类型来区分。
优点:易于人的阅读和编写,易于程序解析与生产
json.dumps() 是把python对象转换成json对象的一个过程,生成的是字符串。
参数:

  • obj:转化成json的对象。
  • sort_keys =True:是告诉编码器按照字典排序(a到z)输出。如果是字典类型的python对象,就把关键字按照字典排序。
  • indent:参数根据数据格式缩进显示,读起来更加清晰。
  • separators:是分隔符的意思,参数意思分别为不同dict项之间的分隔符和dict项内key和value之间的分隔符,把:和,后面的空格都除去了。
flower_list = train_datasets.class_to_idx
cla_list =dict((val,key) for key,val in flower_list.items())
json_str = json.dumps(cla_list,indent=4)
with open ('class_indices.json','w') as json_file:
	json_file.write(json_str)

总代码

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = { "train":transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]),
"val":transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])}
    data_root = os.path.abspath(os.path.join(os.getcwd(),"../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    assert os.path.exists(image_path),"{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path,"train"),transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val,key) for key,val in flower_list.items())
    json_str = json.dumps(cla_dict,indent=4)
    with open('class_indices.json','w') as json_file:
        json_file.write(json_str)
    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size >1 else 0, 8])
    print('Using {} dataloadet workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4, shuffle=True,num_workers=nw)
    print("using {} images for training, {} images fot validation.".format(train_num,val_num))
    net = AlexNet(num_classes=5,init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    save_path = './AlexNet.pth'
    best_acc = 0.0
    for epoch in range(10):
        net.train()
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, data in enumerate(train_loader, start=0):
            images,labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs,labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            rate = (step + 1) / len(train_loader)
            a ="*" * int(rate * 50)
            b ="." * int((1-rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.f}".format(int(rate * 100), a, b, loss), end="")
        print()
        print(time.perf_counter()-t1)
        net.eval()
        acc = 0.0
        with torch.no_grad():
            for val_data in validate_loader:
                val_images,val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f test_accuracy;%.3f' % (epoch + 1, running_loss / step, val_accurate))
    print('Finished Training')