Linear model basics, based on the video series by 刘二大人, 《PyTorch深度学习实践》 complete playlist (p5)
Requirements
- Plot the loss curves of several optimization algorithms and compare them
- x-axis: training epoch; y-axis: loss

The code:
```python
import torch
import matplotlib.pyplot as plt

# Raw data
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])


class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # first 1: input dim; second 1: output dim

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


# Train with the given optimizer; `label` names the method in the plot legend
def train(model, optimizer, label):
    loss_list = []
    # size_average=False is deprecated; reduction='sum' is the equivalent
    criterion = torch.nn.MSELoss(reduction='sum')
    for epoch in range(100):
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        # print(epoch, loss.item())
        loss_list.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Plot the loss curve
    epochs = range(len(loss_list))
    plt.plot(epochs, loss_list, label=label)


# SGD curve
model = LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train(model, optimizer, 'SGD')
# Adam curve
model = LinearModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
train(model, optimizer, 'Adam')
# Adagrad curve
model = LinearModel()
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)
train(model, optimizer, 'Adagrad')
# Adamax curve
model = LinearModel()
optimizer = torch.optim.Adamax(model.parameters(), lr=0.01)
train(model, optimizer, 'Adamax')
# ASGD curve
model = LinearModel()
optimizer = torch.optim.ASGD(model.parameters(), lr=0.01)
train(model, optimizer, 'ASGD')
# LBFGS curve -- errors with this loop: LBFGS.step() requires a closure
# model = LinearModel()
# optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01)
# train(model, optimizer, 'LBFGS')
# RMSprop curve
model = LinearModel()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)
train(model, optimizer, 'RMSprop')
# Rprop curve
model = LinearModel()
optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)
train(model, optimizer, 'Rprop')

plt.legend()
plt.show()

# print('w=', model.linear.weight.item())
# print('b=', model.linear.bias.item())

# Predict with the last trained model (Rprop)
x_test = torch.Tensor([4.0])
y_test = model(x_test)
print('y_test=', y_test.data)
```
Output:

```
y_test= tensor([7.9975])
```
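The per-optimizer blocks above are nearly identical, differing only in the optimizer constructor. They can be collapsed into a loop over a dictionary of constructors. A minimal sketch (plotting omitted; the optimizer subset, variable names, and the `curves` dict are illustrative choices, not from the lecture):

```python
import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# Optimizers that work with a plain optimizer.step() call
optimizers = {
    'SGD': torch.optim.SGD,
    'Adam': torch.optim.Adam,
    'Adagrad': torch.optim.Adagrad,
    'RMSprop': torch.optim.RMSprop,
}

curves = {}
for name, opt_cls in optimizers.items():
    model = torch.nn.Linear(1, 1)  # fresh model per optimizer, for a fair comparison
    optimizer = opt_cls(model.parameters(), lr=0.01)
    criterion = torch.nn.MSELoss(reduction='sum')
    losses = []
    for epoch in range(100):
        loss = criterion(model(x_data), y_data)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    curves[name] = losses  # one loss curve per method, ready for plt.plot
```

Each entry in `curves` can then be passed to `plt.plot(range(100), curves[name], label=name)`, replacing the eight copy-pasted blocks.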
Summary & reflections:
- SGD converges quickly here; after increasing the number of epochs, all of the optimizers converge. My guess is that they suit different kinds of problems; to be verified later.
- LBFGS raises an error with this training loop: its `step()` requires a closure that re-evaluates the loss (not yet fixed here).
- The code is repetitive: the model/optimizer setup cannot be reused directly and has to be copy-pasted for each method (not yet resolved).
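On the LBFGS point: per the PyTorch docs, `torch.optim.LBFGS.step()` must be passed a closure that clears the gradients, recomputes the loss, calls `backward()`, and returns the loss, because LBFGS re-evaluates the function multiple times per step. A minimal sketch for the same data (the `lr` and epoch count are illustrative):

```python
import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01)

loss_list = []
for epoch in range(20):
    def closure():
        # LBFGS may call this several times per step()
        optimizer.zero_grad()
        loss = criterion(model(x_data), y_data)
        loss.backward()
        return loss
    loss = optimizer.step(closure)  # returns the loss from the closure
    loss_list.append(loss.item())
```

This loop cannot share the generic `train()` above without modification, which is likely why the commented-out LBFGS block failed.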