基于PyTorch的鸢尾花分类
鸢尾花(Iris)数据集是机器学习中一个经典的案例,常用于分类和聚类的算法研究。本文将使用PyTorch构建一个简单的神经网络模型来识别鸢尾花的品种,并通过代码示例来演示具体实现。
鸢尾花数据集简介
鸢尾花数据集包含150个样本,每个样本有4个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。这些特征用于分类三种鸢尾花品种:山鸢尾(Iris Setosa)、变色鸢尾(Iris Versicolor)和维吉尼亚鸢尾(Iris Virginica)。
代码实例
在以下代码中,我们将使用PyTorch构建和训练一个简单的神经网络模型。
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target.reshape(-1, 1)
# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)
encoder = OneHotEncoder(sparse=False)
y = encoder.fit_transform(y)
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 转换为张量
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.FloatTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.FloatTensor(y_test)
# 定义神经网络模型
class IrisModel(nn.Module):
def __init__(self):
super(IrisModel, self).__init__()
self.fc1 = nn.Linear(4, 10)
self.fc2 = nn.Linear(10, 3)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 模型、损失函数和优化器
model = IrisModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(1000):
model.train()
optimizer.zero_grad()
outputs = model(X_train_tensor)
loss = criterion(outputs, y_train_tensor)
loss.backward()
optimizer.step()
# 测试模型
model.eval()
with torch.no_grad():
test_outputs = model(X_test_tensor)
predicted = torch.argmax(torch.sigmoid(test_outputs), dim=1)
actual = torch.argmax(y_test_tensor, dim=1)
accuracy = (predicted == actual).float().mean()
print(f'测试准确率:{accuracy:.2f}')
代码说明
- 首先,我们导入必要的库,并加载鸢尾花数据集。
- 接着,我们对输入特征进行标准化,同时将目标变量进行独热编码。
- 然后,构建一个包含两层全连接层的神经网络模型,并通过 ReLU 激活函数进行非线性变换。
- 接下来,我们训练模型,并在测试集上评估其性能。
类图
为了更好地理解代码结构,以下是使用 Mermaid 语法描述的类图:
classDiagram
class IrisModel {
+__init__()
+forward(x)
}
Gantt图
下面是训练和测试过程的甘特图,经过预处理、训练和测试的时间安排如下:
gantt
title 鸢尾花分类模型训练与测试
dateFormat YYYY-MM-DD
section 数据处理
数据加载 :a1, 2023-10-01, 1d
数据预处理 :after a1 , 1d
section 模型训练
模型定义 :a2, 2023-10-02, 1d
模型训练 :after a2 , 3d
section 模型测试
测试模型 :2023-10-05, 1d
结论
通过以上的代码示例和项目结构图,我们可以看到如何使用PyTorch进行简单的鸢尾花分类。这种方法不仅可以帮助我们理解机器学习的基本概念,还可以为更复杂的模型构建打下基础。希望本篇文章能为你的机器学习之旅提供一小部分帮助与启发。