基于PyTorch的鸢尾花分类

鸢尾花(Iris)数据集是机器学习中一个经典的案例,常用于分类和聚类的算法研究。本文将使用PyTorch构建一个简单的神经网络模型来识别鸢尾花的品种,并通过代码示例来演示具体实现。

鸢尾花数据集简介

鸢尾花数据集包含150个样本,每个样本有4个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。这些特征用于分类三种鸢尾花品种:山鸢尾(Iris Setosa)、变色鸢尾(Iris Versicolor)和维吉尼亚鸢尾(Iris Virginica)。

代码实例

在以下代码中,我们将使用PyTorch构建和训练一个简单的神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder

# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target.reshape(-1, 1)

# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)
encoder = OneHotEncoder(sparse=False)
y = encoder.fit_transform(y)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 转换为张量
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.FloatTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.FloatTensor(y_test)

# 定义神经网络模型
class IrisModel(nn.Module):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 模型、损失函数和优化器
model = IrisModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(1000):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)
    loss.backward()
    optimizer.step()

# 测试模型
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor)
    predicted = torch.argmax(torch.sigmoid(test_outputs), dim=1)
    actual = torch.argmax(y_test_tensor, dim=1)
    accuracy = (predicted == actual).float().mean()

print(f'测试准确率:{accuracy:.2f}')

代码说明

  • 首先,我们导入必要的库,并加载鸢尾花数据集。
  • 接着,我们对输入特征进行标准化,同时将目标变量进行独热编码。
  • 然后,构建一个包含两层全连接层的神经网络模型,并通过 ReLU 激活函数进行非线性变换。
  • 接下来,我们训练模型,并在测试集上评估其性能。

类图

为了更好地理解代码结构,以下是使用 Mermaid 语法描述的类图:

classDiagram
    class IrisModel {
        +__init__()
        +forward(x)
    }

Gantt图

下面是训练和测试过程的甘特图,经过预处理、训练和测试的时间安排如下:

gantt
    title 鸢尾花分类模型训练与测试
    dateFormat  YYYY-MM-DD
    section 数据处理
    数据加载          :a1, 2023-10-01, 1d
    数据预处理        :after a1  , 1d
    section 模型训练
    模型定义          :a2, 2023-10-02, 1d
    模型训练          :after a2  , 3d
    section 模型测试
    测试模型          :2023-10-05, 1d

结论

通过以上的代码示例和项目结构图,我们可以看到如何使用PyTorch进行简单的鸢尾花分类。这种方法不仅可以帮助我们理解机器学习的基本概念,还可以为更复杂的模型构建打下基础。希望本篇文章能为你的机器学习之旅提供一小部分帮助与启发。