pytorch的java版怎么用 java pytorch模型_python

pytorch是深度学习训练的常用框架,其代码书写有一些可以学习的套路。这个系列的博客将总结pytorch构建深度学习网络并训练的几种套路。


目录

  • 数据准备
  • 模型构建
  • 方法1 Class
  • 方法2 Sequential
  • 损失函数与训练
  • 查看模型细节
  • 查看参数
  • 保存模型
  • 参考


数据准备

实验数据中的x是特征,y是标签,二者均维1维标量,方便对实验结果进行查看。

x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

模型构建

方法1 Class

这里创建一个Net类,此类继承自nn.Module,因而在第一行需要写上super(Net,self).__init__()。模型中一共包括2个隐藏层,2个激活层,1个输出层。

class Net(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden1 = nn.Linear(n_input,n_hidden)
        self.hidden2 = nn.Linear(n_hidden,n_hidden)
        self.predict = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.sigmoid(out)
        out =self.predict(out)
        return out

方法2 Sequential

这里的模型model通过Sequential方法构建,这种方法相较于上一种方法更加简单。

x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

n_input = 1
n_hidden = 10
n_output = 1
model = nn.Sequential(
    nn.Linear(n_input,n_hidden),
    nn.ReLU(),
    nn.Linear(n_hidden,n_hidden),
    nn.ReLU(),
    nn.Linear(n_hidden,n_output),
)
# loss_func = torch.nn.MSELoss()
# optimizer = torch.optim.SGD(model.parameters(),lr = 0.1)

损失函数与训练

损失函数一般都使用pytorch自己定义的一些损失函数,例如cross entropy等等。对于想要自己设计损失函数的情况,需要查看源码中的损失函数的写法,新的损失函数需要按照老损失函数进行定义。2种模型构建方法训练方法一致。

optimizer = torch.optim.SGD(net.parameters(),lr = 0.1)
loss_func = torch.nn.MSELoss()

for t in range(5000):
    prediction = net(x)
    loss = loss_func(prediction,y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

训练结束后可以打印实验结果

plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss = %.4f' % loss.data, fontdict={'size': 20, 'color': 'red'})

pytorch的java版怎么用 java pytorch模型_数据_02

查看模型细节

查看参数

采用方法1构建的模型,可以调用list(net.named_parameters())查看参数。该指令可以把模型的参数按照层数存储起来。

pytorch的java版怎么用 java pytorch模型_pytorch的java版怎么用_03


⚠️该指令只能打印命了名的层,没命名的层无法打印,所以在模型构造的时候尽量给每一层命名。采用方法2构建的模型,模型是以list的形式存储的,可以像遍历列表那样遍历模型,并打印。Sequential格式进行构建的模型还可以打印中间结果。

pytorch的java版怎么用 java pytorch模型_python_04

保存模型

torch.save(net,'net.pkl')