PyTorch Official Tutorials
Written while following the official PyTorch Tutorials, with small changes to the code and added comments for my own understanding. The IDE used is Jupyter Notebook.
Optimizing Model Parameters
Now that we have a model and data, it is time to train, validate, and test the model by optimizing its parameters on the data.
Training a model is an iterative process: in each iteration (an epoch), the model makes a guess about the output, calculates the error of its guess (the loss), collects the derivatives of the error with respect to its parameters (see the previous section), and optimizes these parameters using gradient descent.
For a more detailed walkthrough of this process, watch the video backpropagation from 3Blue1Brown.
Prerequisite Code
We load the code from the Datasets & DataLoaders and Build Model sections.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
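As a quick sanity check that the dataloaders are set up correctly, we can pull one batch and inspect its shape (a small sketch added here, not part of the original tutorial code):

    # Fetch one batch from the train dataloader and check the shapes
    X, y = next(iter(train_dataloader))
    print(X.shape)  # torch.Size([64, 1, 28, 28]) - 64 grayscale 28x28 images
    print(y.shape)  # torch.Size([64]) - one class label per image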
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)  # output layer: one logit per FashionMNIST class
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
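A quick check that the model produces one logit per class for a random input (a minimal sketch assuming the model defined above):

    # Feed one random "image" through the untrained model
    X = torch.rand(1, 28, 28)
    logits = model(X)
    print(logits.shape)  # torch.Size([1, 10]) - one logit per class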
Hyperparameters
Hyperparameters are adjustable parameters that let you control the model optimization process.
Different hyperparameter values can impact model training and convergence speed. (Read more about hyperparameter tuning.)
We define the following hyperparameters for training:
- Epochs: the number of times to iterate over the dataset
- Batch Size: the number of samples propagated through the network before the parameters are updated
- Learning Rate: how much to update the model parameters at each batch/epoch. Smaller values yield slow learning speed, while large values may cause unpredictable behavior during training
learning_rate = 1e-3
batch_size = 64
epochs = 10
Optimization Loop
Once we set our hyperparameters, we can train and optimize the model with an optimization loop. Each iteration of the optimization loop is called an epoch.
Each epoch consists of two main parts:
- The Train Loop: iterate over the training dataset and try to converge to optimal parameters
- The Validation/Test Loop: iterate over the test dataset to check whether model performance is improving
Let's briefly familiarize ourselves with some of the concepts used in the training loop, or jump ahead to the Full Implementation section to see the optimization loop in full.
Loss Function
When presented with some training data, an untrained network is likely not to give the correct answer. The loss function measures the degree of dissimilarity between the obtained result and the target value, and it is the loss function that we want to minimize during training. To calculate the loss, we make a prediction on the input data and compare it against the true data label.
Common loss functions include:
- nn.MSELoss (Mean Square Error), for regression tasks
- nn.NLLLoss (Negative Log Likelihood), for classification
- nn.CrossEntropyLoss, which combines nn.LogSoftmax and nn.NLLLoss
We pass the model's output logits to nn.CrossEntropyLoss, which normalizes the logits and computes the prediction error.
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
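As a sanity check of the claim above that nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss, the two formulations can be compared on made-up logits (a minimal sketch; the tensors here are purely illustrative):

    # CrossEntropyLoss on raw logits vs. LogSoftmax followed by NLLLoss
    logits = torch.randn(3, 10)       # fake batch: 3 samples, 10 classes
    target = torch.tensor([1, 4, 9])  # fake class labels
    ce = nn.CrossEntropyLoss()(logits, target)
    nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
    print(torch.allclose(ce, nll))    # True - the two losses match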
优化器
优化是在训练这一步中调整模型参数来减小模型误差的过程 优化算法决定了这个过程是如何完成的(在这个例子里使用Stohastic Gradient Descent)
所有的优化罗辑都封装在optimizer
中
这里使用SGD优化器 此外 pytorch中还有很多其他的优化器诸如 ADAM RMSProp 根据模型和数据来具体选择
通过传入模型中需要被训练的参数来初始化优化器 然后 传入学习率
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
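Swapping in one of the other optimizers mentioned above only changes this one line; for example, Adam with the same learning rate (shown commented out, since the run recorded below uses SGD):

    # Alternative optimizer: Adam, registered the same way (not used in the run below)
    # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)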
Inside the training loop, optimization happens in three steps (put together in the single-batch sketch after this list):
- Call optimizer.zero_grad() to reset the gradients of the model parameters. Gradients add up by default; to prevent double-counting, we explicitly zero them at every iteration
- Backpropagate the prediction error by calling loss.backward(). PyTorch deposits the gradients of the loss with respect to each parameter
- Once we have the gradients, call optimizer.step() to adjust the parameters using the gradients collected in the backward pass
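On a single batch, the three steps look like this (a minimal sketch using the model, loss_fn, and optimizer defined above; train_loop below wraps the same steps in a loop):

    # One optimization step on a single batch
    X, y = next(iter(train_dataloader))
    pred = model(X)            # forward pass
    loss = loss_fn(pred, y)    # compute the prediction error
    optimizer.zero_grad()      # 1. reset parameter gradients
    loss.backward()            # 2. backpropagate the loss
    optimizer.step()           # 3. adjust parameters by the gradients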
Full Implementation
We define train_loop to loop over the optimization code, and test_loop to evaluate the model's performance on the test data.
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
In train_loop:
- enumerate indexes the iterable, giving us the batch index batch
- size = len(dataloader.dataset) takes dataloader.dataset as the argument, not dataloader; len(dataloader) gives the number of batches instead (see the sketch after this list)
- The first two steps in the loop: pred = model(X), loss = loss_fn(pred, y)
- The last three steps in the loop: optimizer.zero_grad(), loss.backward(), optimizer.step()
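The distinction in the second bullet is easy to verify (sizes assume the FashionMNIST setup above):

    print(len(train_dataloader.dataset))  # 60000 - number of samples
    print(len(train_dataloader))          # 938 - number of batches, ceil(60000 / 64)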
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Disable gradient tracking during evaluation
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
In test_loop:
- loss_fn returns a 0-dimensional tensor rather than a plain number, so call .item() first
- pred is the tensor of logits produced by the network
- The f-string format specs limit the printed precision (see the sketch after this list)
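A small sketch of the pieces these notes refer to, on made-up logits (the tensor values here are purely illustrative):

    pred = torch.tensor([[0.1, 2.0, -1.0],
                         [1.5, 0.0, 0.2]])  # fake logits: batch of 2, 3 classes
    y = torch.tensor([1, 0])
    print(loss_fn(pred, y))         # a 0-dim tensor, not a plain number
    print(loss_fn(pred, y).item())  # .item() extracts the Python float
    print(pred.argmax(1))           # tensor([1, 0]) - predicted class per sample
    print(f"{0.123456789:>8f}")     # '0.123457' - :>8f limits the printed precision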
We initialize the loss function and optimizer and pass them to train_loop and test_loop.
You can increase epochs to give the model more passes over the data and track its improving performance.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 6.275609 [ 0/60000]
loss: 6.168862 [ 6400/60000]
loss: 5.964859 [12800/60000]
loss: 5.773824 [19200/60000]
loss: 5.461776 [25600/60000]
loss: 4.683303 [32000/60000]
loss: 4.824953 [38400/60000]
loss: 4.399492 [44800/60000]
loss: 4.707501 [51200/60000]
loss: 3.621667 [57600/60000]
Test Error:
Accuracy: 30.9%, Avg loss: 4.041796
Epoch 2
-------------------------------
loss: 4.483642 [ 0/60000]
loss: 4.351563 [ 6400/60000]
loss: 3.972307 [12800/60000]
loss: 3.827237 [19200/60000]
loss: 4.043344 [25600/60000]
loss: 3.517235 [32000/60000]
loss: 3.886236 [38400/60000]
loss: 3.817020 [44800/60000]
loss: 4.296854 [51200/60000]
loss: 2.978844 [57600/60000]
Test Error:
Accuracy: 41.8%, Avg loss: 3.652236
Epoch 3
-------------------------------
loss: 4.169164 [ 0/60000]
loss: 4.003831 [ 6400/60000]
loss: 3.545149 [12800/60000]
loss: 3.330601 [19200/60000]
loss: 3.725735 [25600/60000]
loss: 2.887691 [32000/60000]
loss: 3.344828 [38400/60000]
loss: 3.088535 [44800/60000]
loss: 3.371561 [51200/60000]
loss: 2.630687 [57600/60000]
Test Error:
Accuracy: 50.5%, Avg loss: 3.087908
Epoch 4
-------------------------------
loss: 3.314708 [ 0/60000]
loss: 3.212036 [ 6400/60000]
loss: 2.757336 [12800/60000]
loss: 2.843891 [19200/60000]
loss: 3.640403 [25600/60000]
loss: 2.574634 [32000/60000]
loss: 3.171757 [38400/60000]
loss: 2.906542 [44800/60000]
loss: 3.238701 [51200/60000]
loss: 2.501581 [57600/60000]
Test Error:
Accuracy: 51.1%, Avg loss: 2.960645
Epoch 5
-------------------------------
loss: 3.149806 [ 0/60000]
loss: 3.076319 [ 6400/60000]
loss: 2.606539 [12800/60000]
loss: 2.740151 [19200/60000]
loss: 3.574028 [25600/60000]
loss: 2.449526 [32000/60000]
loss: 3.079580 [38400/60000]
loss: 2.810920 [44800/60000]
loss: 3.157782 [51200/60000]
loss: 2.422556 [57600/60000]
Test Error:
Accuracy: 51.6%, Avg loss: 2.881615
Epoch 6
-------------------------------
loss: 3.047944 [ 0/60000]
loss: 2.995542 [ 6400/60000]
loss: 2.515550 [12800/60000]
loss: 2.674568 [19200/60000]
loss: 3.520339 [25600/60000]
loss: 2.367092 [32000/60000]
loss: 3.016242 [38400/60000]
loss: 2.746255 [44800/60000]
loss: 3.101627 [51200/60000]
loss: 2.366339 [57600/60000]
Test Error:
Accuracy: 52.0%, Avg loss: 2.825696
Epoch 7
-------------------------------
loss: 2.976136 [ 0/60000]
loss: 2.939208 [ 6400/60000]
loss: 2.453444 [12800/60000]
loss: 2.600638 [19200/60000]
loss: 3.161774 [25600/60000]
loss: 2.004015 [32000/60000]
loss: 2.540197 [38400/60000]
loss: 2.131633 [44800/60000]
loss: 2.471781 [51200/60000]
loss: 2.057839 [57600/60000]
Test Error:
Accuracy: 57.9%, Avg loss: 2.264959
Epoch 8
-------------------------------
loss: 2.224742 [ 0/60000]
loss: 2.217500 [ 6400/60000]
loss: 2.009215 [12800/60000]
loss: 2.406226 [19200/60000]
loss: 2.637452 [25600/60000]
loss: 1.753832 [32000/60000]
loss: 2.384584 [38400/60000]
loss: 2.012477 [44800/60000]
loss: 2.422630 [51200/60000]
loss: 2.003280 [57600/60000]
Test Error:
Accuracy: 58.4%, Avg loss: 2.199475
Epoch 9
-------------------------------
loss: 2.142583 [ 0/60000]
loss: 2.140580 [ 6400/60000]
loss: 1.950389 [12800/60000]
loss: 2.352949 [19200/60000]
loss: 2.579938 [25600/60000]
loss: 1.703772 [32000/60000]
loss: 2.337671 [38400/60000]
loss: 1.968013 [44800/60000]
loss: 2.375862 [51200/60000]
loss: 1.962950 [57600/60000]
Test Error:
Accuracy: 58.9%, Avg loss: 2.152021
Epoch 10
-------------------------------
loss: 2.090175 [ 0/60000]
loss: 2.091544 [ 6400/60000]
loss: 1.891107 [12800/60000]
loss: 2.240966 [19200/60000]
loss: 2.236654 [25600/60000]
loss: 1.534315 [32000/60000]
loss: 2.068044 [38400/60000]
loss: 1.734532 [44800/60000]
loss: 1.799461 [51200/60000]
loss: 1.695580 [57600/60000]
Test Error:
Accuracy: 59.3%, Avg loss: 1.720667
Done!