Tensor和Autograd

  • Tensor的数据结构
  • 自动求导Autograd
  • Variable
  • 计算图
  • 扩展Autograd
  • pytorch实现线性回归


Tensor可简单的认为是支持高效计算的数组,可以是标量、向量、矩阵或更高维的数组。Tensor与Numpy数组具有很高的相似性,彼此共享内存,所以遇到Tensor不支持的操作时,可以先将其转换为Numpy数组,处理后再转回Tensor,其转换开销很小,与Numpy不同的是,Pytorch的Tensor支持GPU加速。CPU tensor和GPU tensor之间的相互转换通过tensor.cuda和tensor.cpu的方法实现。

Tensor的数据结构

Tensor分为头信息区(Tensor)和存储区(Storage),信息区主要保存Tensor的形状、步长、数据类型等信息,真正的数据则保存为连续数组。不同的Tensor头信息一般不同,但却可能使用相同的Storage。

绝大部分操作并不修改Tensor的数据,只是修改了Tensor的头信息,这种做法更节省内存,同时提高了处理速度。此外,有些操作会导致Tensor不连续,这时需要调用tensor.contiguous方法将其变为连续数据。

pytorch的tensor经过linear层之后变成nan pytorch tensor grad_数组

自动求导Autograd

torch.autograd能够根据输入和前向传播过程,自动构建计算图,执行反向传播。

Variable

autograd中的核心数据结构是Variable,Variable封装了Tensor并记录对Tensor的操作用来构建计算图。Variable主要包含三个属性:

在pytorch 0.4.0版本更新后,Variable已经合并到了Tensor中,Tensor具有下文提到的Variable所有属性。创建Tensor时,只需设置属性requires_grad即可指定该Tensor完成autograd。

  • data:保存Variable所包含的Tensor
  • grad:保存data对应的梯度,形状与data一致
  • grad_fn: 指向一个Function,记录Variable的操作历史,用于构建计算图

计算图

autograd的底层采用了计算图,计算图作为现代深度学习框架的核心,为自动求导算法——反向传播,提供了理论支持。它是一种特殊的有向无环图,用于记录算子与变量之间的关系。一般用矩形表示算子,椭圆形表示变量。如表达式 z = wx + b, 其计算图如下所示:

pytorch的tensor经过linear层之后变成nan pytorch tensor grad_反向传播_02

wxb是叶子结点,这些节点通常由用户自己创建,不依赖于其他变量,z为根节点,是计算图的最终目标。有了计算图之后,即可利用反向传播自动完成链式求导,获得各个叶子节点的梯度。

在Pytorch实现中,Autograd会随着用户的操作,记录生成当前Variable的所有操作,并由此建立计算图,用户每进行一个操作,相应的计算图就会发生改变。更底层的实现中,图中记录了操作Function,每一个前向传播操作函数都有与之对应的反向传播函数用来计算输入的各个Variable梯度,这些函数名通常以Backward结尾。

扩展Autograd

目前,绝大多数函数都能够使用autograd实现反向传播,当需要自定义函数时,可借助pytorch中用于自定义张量操作函数的类torch.autograd.Function,通过定义前向传播和反向传播过程,完成函数的自动求导。

在实际应用中,自定义Function可被理解为自定义网络操作,算子扩充

# -*- coding: UTF-8 -*- 
""" 
create on 2021-06-15
@author: yang
"""
import torch
from torch.autograd import Function

'''
    自定义Tensor操作函数,通过实现forward和backward方法分别定义前向计算
    和反向求导的过程,从而实现自定义网络层整个动态计算图中的计算传递
'''
class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        print('========>Forward')
        output = w * x + b
        ctx.save_for_backward(w, x)     # 保存需要给backward的变量
        return output

    @staticmethod
    def backward(ctx, grad_output):
        print('==========>Backward')
        w, x = ctx.saved_tensors        # 获取保存的变量
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        print("garad_output: ", grad_output)
        return grad_w, grad_x, grad_b

if __name__ == '__main__':
    x = torch.rand(1)
    w = torch.rand(1, requires_grad=True)
    b = torch.rand(1, requires_grad=True)
    print('Start Forward!!!')
    z = MultiplyAdd.apply(x, w, b)
    print('Start Backward!!!')
    z.backward()

    print(x.data, w.data, b.data)
    print(x.grad, w.grad, b.grad)

pytorch实现线性回归

# -*- coding: utf-8 -*-
# create on 2021-06-23
# author: yang

import torch
from matplotlib import pyplot as plt

torch.manual_seed(1000)

def get_fake_data(bath_size = 8):
    x = torch.rand(bath_size, 1) * 20
    y = x * 2 + (1 + torch.randn(bath_size, 1)) * 3
    return x, y

if __name__ == '__main__':

    w = torch.rand(1, 1, requires_grad=True)
    b = torch.zeros(1, 1, requires_grad=True)

    lr = 0.001

    for ii in range(20000):
        x, y = get_fake_data()

        # forward
        y_pred = x.mm(w) + b.expand_as(y)
        loss = 0.5 * (y_pred - y) ** 2
        loss = loss.sum()

        # Manual backward
        # dloss = 1
        # dy_pred = dloss * (y_pred - y)
        # dw = x.t().mm(dy_pred)
        # db = dy_pred.sum()
        # w.data.sub_(lr * dw)
        # b.data.sub_(lr * db)

        # Auto backward()
        loss.backward()
        w.data.sub_(lr * w.grad.data)
        b.data.sub_(lr * b.grad.data)
        w.grad.data.zero_()
        b.grad.data.zero_()

        if ii % 1000 == 0:
            plt.clf()
            x = torch.arange(0, 20).view(-1, 1)
            x = x.float()
            y = x.mm(w) + b.expand_as(x)
            plt.plot(x.detach().numpy(), y.detach().numpy())

            x2, y2 = get_fake_data(bath_size=20)
            plt.scatter(x2.numpy(), y2.numpy())

            plt.xlim(0, 20)
            plt.ylim(0, 42)
            plt.show()
            plt.pause(1)

    print(w.data, b.data)