《深度学习的数学》(涌井良幸、涌井贞美著) 神经网络计算pytorch示例二_涌井良幸

涌井良幸、涌井贞美著的《深度学习的数学》这本书,浅显易懂。书中还用Excel示例(如下图)神经网络的计算,真是不错。但光有Excel示例还是有点欠缺的,如果有pytorch代码演示就更好了。

《深度学习的数学》(涌井良幸、涌井贞美著) 神经网络计算pytorch示例二_深度学习的数学_02

百度了半天在网上没找到别人写的,只好自己撸一个(使用python + pytorch),供同样在学习神经网络的初学者参考。

(注,这是书中5-6节:体验卷积神经网络误差反向传播法,数据是96个6x6的1、2和3,用平方误差的总和作为代价函数, 用 Sigmoid 函数作为激活函数

(书中4-4神经网络计算pytorch示例一请参考:https://blog.51cto.com/oldycat/8133220

(看这本书前建议可以先看立石贤吾著的《白话机器学习的数学》,再看这本书会变得很简单)

demo56.py:

import torch
import torch.nn as nn
import torch.optim as optimal
from torch import cosine_similarity

import demo56data as demo


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.activation = nn.Sigmoid()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(3 * (2 * 2), 3)  # 输出层有3个神经元,对应数字0、1、2

        self.conv1.weight.data = demo.get_param_co().resize_(3, 1, 3, 3)  # 若用正态分布,注释此行
        self.conv1.bias.data = demo.get_param_co_bias()  # 若用正态分布,注释此行
        self.fc.weight.data = demo.get_param_op().resize_(3, 12)  # 若用正态分布,注释此行
        self.fc.bias.data = demo.get_param_o_bias()  # 若用正态分布,注释此行

    def forward(self, x):
        x = self.conv1(x)
        demo.print_x("zF=", x)
        x = self.activation(x)
        x = self.pool(x)
        demo.print_x("aF=", x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        demo.print_x("zO=", x)
        x = self.activation(x)
        demo.print_x("aO=", x)
        return x


# 创建模型实例
model = CNN()
# for param in model.parameters():
#     print(param)

# 定义损失函数和优化器
criterion = nn.MSELoss(reduction='sum')
optimizer = optimal.SGD(model.parameters(), lr=0.2)

# 转换输入数据为张量
train_data = demo.get_data()
train_labels = demo.get_result()

# 开始训练
num_epochs = 1000
for epoch in range(num_epochs):
    print("\nepoch=", epoch + 1)
    optimizer.zero_grad()
    outputs = model(train_data)

    loss = criterion(outputs, train_labels) / 2
    print("Loss: {:.4f}".format(loss.item()))
    loss.backward()
    optimizer.step()

    if (epoch + 1) == num_epochs or loss.item() < 0.5:
        print("Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, loss.item()))
        break

# 使用训练好的模型进行预测
model.eval()
print()
output = model(demo.get_test()).data
print(output.argmax(dim=1) + 1)

print("\n=======  比对全部结果  ======")
test_data = demo.get_data()
predictions = model(test_data)
result = (predictions.argmax(dim=1) + 1)
print()
print(result.data)
print("差异:")
print((demo.get_result2() - result).long())
print()
print("准确度:", (torch.round(
    cosine_similarity(result.unsqueeze(0), demo.get_result2().unsqueeze(0)).mean() * 10000) / 100).data.numpy(),
      "%")

demo56data.py

import torch


def get_param_co():
    return torch.tensor([[
        -1.277, -0.454, 0.358,
        1.138, -2.398, -1.664,
        -0.794, 0.899, 0.675
    ], [
        -1.274, 2.338, 2.301,
        0.649, -0.339, -2.054,
        -1.022, -1.204, -1.900
    ], [
        -1.869, 2.044, -1.290,
        -1.710, -2.091, -2.946,
        0.201, -1.323, 0.207
    ]])


def get_param_co_bias():
    return torch.tensor([-3.363, -3.176, -1.739])


def get_param_op():
    return torch.tensor([

        [[
            -0.276, 0.124,
            - 0.961, 0.718
        ], [
            -3.680, - 0.594,
            0.280, - 0.782

        ], [
            -1.475, - 2.010,
            - 1.085, - 0.188
        ]],

        [[
            0.010, 0.661,
            - 1.591, 2.189
        ], [
            1.728, 0.003,
            - 0.250, 1.898
        ], [
            0.238, 1.589,
            2.246, - 0.093
        ]],

        [[
            -1.322, - 0.218,
            3.527, 0.061
        ], [
            0.613, 0.218,
            - 2.130, - 1.678
        ], [
            1.236, - 0.486,
            - 0.144, - 1.235
        ]]

    ])


def get_param_o_bias():
    return torch.tensor([2.060, -2.746, -1.818])


def get_test():
    return (torch.tensor([
        [
            1.0, 1, 1, 1, 0, 0,
            1, 1, 0, 0, 1, 0,
            0, 0, 0, 0, 1, 0,
            0, 0, 0, 1, 1, 0,
            1, 1, 0, 0, 1, 0,
            1, 1, 1, 1, 0, 0], [

            0, 0, 1, 1, 1, 0,
            0, 1, 0, 0, 1, 1,
            0, 0, 0, 1, 1, 0,
            0, 0, 0, 0, 1, 0,
            0, 1, 0, 0, 1, 1,
            0, 0, 1, 1, 1, 0]]
    ).resize_(2, 1, 6, 6))


def print_x(name, x):
    if x.dim() > 3:
        print(name, end='')
        # for i in range(x.size()[0]):
        for j in range(x.size()[1]):
            print("\t[", end='')
            for k in range(x.size()[2]):
                print("", x[0, j, k, :].data.numpy(), end='')
            print("]\n  ", end='')
        print()
    elif x.dim() > 1:
        print(name, end='')
        print("", x[0, :].data.numpy(), end='')
        for i in range(x.size()[0]):
            if i > 0:
                print("\t\t", x[i, :].data.numpy(), end='')
        print()


def print_params(params):
    for param in params:
        if param.dim() > 1:
            for i in range(param.size()[0]):
                print('\t[', end='')
                print(param[i, 0].data.numpy(), end='')
                for j in range(param.size()[1]):
                    if j > 0:
                        print('\t', param[i, j].data.numpy(), end='')
                print('] ', end='')
            print()
        else:
            print('\t[', end='')
            print(param[0].data.numpy(), end='')
            for i in range(param.size()[0]):
                if i > 0:
                    print('\t', param[i].data.numpy(), end='')
            print('] ')
    print()


def get_data():
    return torch.tensor([[
        0.0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,  # 10
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 1, 0], [

        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0], [

        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0], [

        0, 0, 1, 0, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 0, 0, 0,  # 20
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 0, 1, 0, 0, 0,
        0, 1, 1, 0, 0, 0,
        0, 1, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0], [

        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0], [

        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0], [

        0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0], [

        0, 1, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0], [

        0, 1, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0], [

        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,  # 30
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0], [

        0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 1], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        1, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        1, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        1, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        1, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 1], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 1], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 1, 1,
        0, 0, 0, 1, 1, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 1, 1,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 1], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 1, 1, 0, 0,
        1, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 1, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1,
        0, 0, 0, 1, 1, 1,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 1, 1,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1,
        0, 0, 0, 1, 1, 1,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 0, 1,
        0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 1, 1,
        0, 1, 0, 0, 0, 1,
        0, 0, 1, 1, 1, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 0, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        1, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 0, 1,
        0, 0, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 0,
        0, 0, 1, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        1, 0, 0, 0, 1, 0,
        0, 0, 1, 1, 1, 0,
        0, 0, 1, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        1, 1, 1, 1, 0, 0], [

        1, 1, 1, 1, 0, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 1, 1, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 1,
        0, 1, 1, 1, 1, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 1, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        0, 1, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 1, 1, 0, 1, 0,
        0, 0, 1, 1, 0, 0], [

        0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 1, 1,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 1, 0], [

        1, 1, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 0, 0], [

        1, 1, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        1, 1, 1, 1, 0, 0], [

        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 1, 1, 1, 0]]
    ).resize_(96, 1, 6, 6)


def get_result():
    return torch.tensor([[
        1.0, 0.0, 0.0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        1, 0, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 1, 0], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1], [
        0, 0, 1]])


def get_result2():
    return torch.tensor([
        1.0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3
    ])

运行结果:

《深度学习的数学》(涌井良幸、涌井贞美著) 神经网络计算pytorch示例二_深度学习的数学_03

《深度学习的数学》(涌井良幸、涌井贞美著) 神经网络计算pytorch示例二_涌井良幸_04

《深度学习的数学》(涌井良幸、涌井贞美著) 神经网络计算pytorch示例二_涌井贞美_05