PyTorch 图像分类 单张图片推断

介绍

在本文中,我将向你介绍如何使用 PyTorch 进行图像分类,以及如何在单张图片上进行推断。无论你是一个刚入行的小白,还是一个经验丰富的开发者,这篇文章都能帮助你理解和应用该技术。

整体流程

下面是使用 PyTorch 进行图像分类和单张图片推断的整体流程的表格展示:

步骤 描述
步骤 1 导入必要的库和模块
步骤 2 加载预训练模型
步骤 3 加载并预处理单张图片
步骤 4 进行推断
步骤 5 显示推断结果

接下来,我会依次讲解每个步骤所需的代码和注释其意义。

步骤 1: 导入必要的库和模块

在开始之前,我们需要导入一些必要的库和模块,包括 PyTorch,以及其他用于图像处理的库。

import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
  • torch 是 PyTorch 库的主要模块。
  • torchvision.transforms 包含了一些用于图像处理的函数和类。
  • torchvision.models 包含了一些常用的预训练模型。
  • PIL.Image 是一个用于图像处理的库。

步骤 2: 加载预训练模型

在步骤 2 中,我们将加载一个预训练模型,例如 ResNet。首先,我们需要指定模型的名称和预训练参数。

model_name = 'resnet18'
pretrained = True

接下来,我们可以使用 torchvision.models 中的 resnet18 函数来加载预训练模型。我们还需要将模型设置为评估模式,以便在推断过程中使用。

model = models.resnet18(pretrained=pretrained)
model.eval()

步骤 3: 加载并预处理单张图片

在步骤 3 中,我们需要加载并预处理单张图片。首先,我们需要指定图片的路径。

image_path = 'path/to/image.jpg'

接下来,我们需要使用 PIL.Image 中的 open 函数打开图片,并使用 torchvision.transforms 中的 Compose 函数将图片转换为 PyTorch 张量。

image = Image.open(image_path)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_image = transform(image).unsqueeze(0)

在上述代码中,我们首先将图片调整为大小为 256x256 像素,并将其中心裁剪为 224x224 像素。然后,我们将图片转换为 PyTorch 张量,并进行归一化处理。

步骤 4: 进行推断

在步骤 4 中,我们将使用加载的模型对图片进行推断。我们只需要将预处理后的图片输入到模型中,然后获取模型的输出。

output = model(input_image)

步骤 5: 显示推断结果

在步骤 5 中,我们可以根据模型的输出来显示推断结果。通常,模型的输出是一个概率向量,表示输入图片属于每个类别的概率。我们可以从概率向量中选择最高的概率,并将其对应的类别作为推断结果。

_, predicted_index = torch.max(output, 1)
predicted_class = predicted_index.item()

接下来,我们可以使用预训练模型的类别列表来获取预测类别的标签。

class_labels = ['cat', 'dog', 'bird', 'horse', 'elephant']
predicted_label = class