在Colaboratory上运行PyTorch:从数据加载到模型训练的完整流程
在这篇文章中,我们将介绍如何在Google Colaboratory上使用PyTorch来解决一个具体的问题:手写数字识别。我们将从数据加载开始,一直到模型训练和评估的整个过程。通过本文,你将学会如何在Colaboratory环境中运行PyTorch,并且掌握训练深度学习模型的基本流程。
步骤一:准备数据集
首先,我们需要准备一个用于手写数字识别的数据集。这里我们将使用MNIST数据集,它包含了大量的手写数字图片和对应的标签。我们可以使用PyTorch内置的torchvision
库来加载MNIST数据集。
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
步骤二:构建模型
接下来,我们需要构建一个简单的卷积神经网络模型来进行手写数字识别。这里我们定义一个包含两个卷积层和两个全连接层的LeNet模型。
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet()
步骤三:模型训练
现在我们已经准备好数据集和模型,接下来就是训练我们的模型。在Colaboratory上,我们可以使用GPU来加速训练过程。
import torch.optim as optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
num_epochs = 5
for epoch in range(num_epochs):
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
步骤四:模型评估
最后,我们需要对训练好的模型进行评估,看看它在测试集上的表现如何。
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
print(f'Accuracy on the test set: {100 * accuracy}%')
通过以上代码示例,我们展示了在Colaboratory上运行PyTorch的完整流程:从数据加载到模型训练和评估。希望通过本文的指导,你能更加熟练地使用PyTorch进行深度学习任务的处理。
时间安排
gantt
title PyTorch模型训练时间安排
dateFormat YYYY-MM-DD
section 训练过程
数据加载 :done, 2022-01-01, 1