虽然pytorch可以自动求导,但是有时候一些操作是不可导的,这时候你需要自定义求导方式。也就是所谓的 "Extending torch.autograd"。
如果想要通过Function自定义一个操作,需要
①继承torch.autograd.Function这个类
from torch.autograd import Function
class LinearFunction(Function):
②实现forward()和backward()
属性(成员变量)
saved_tensors: 传给forward()的参数,在backward()中会用到。
needs_input_grad:长度为 :attr:num_inputs的bool元组,表示输出是否需要梯度。可以用于优化反向过程的缓存。
num_inputs: 传给函数 :func:forward的参数的数量。
num_outputs: 函数 :func:forward返回的值的数目。
requires_grad: 布尔值,表示函数 :func:backward 是否永远不会被调用。
成员函数
forward()
forward()可以有任意多个输入、任意多个输出,但是输入和输出必须是Variable。(官方给的例子中有只传入tensor作为参数的例子)
backward()
backward()的输入和输出的个数就是forward()函数的输出和输入的个数。其中,backward()输入表示关于forward()输出的梯度(计算图中上一节点的梯度),backward()的输出表示关于forward()的输入的梯度。在输入不需要梯度时(通过查看needs_input_grad参数)或者不可导时,可以返回None。
ctx is a context object that can be used to stash information for backward computation
ctx可以利用
save_for_backward
来保存tensors,在backward阶段可以进行获取例1
z
import torchfrom torch import nn
from torch.autograd import Function
import torch
class Exp(Function):
@staticmethod
def forward(ctx, input):
result = torch.exp(input)
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_output):
result, = ctx.saved_tensors
return grad_output * result
x = torch.rand(4,3,5,5)
exp = Exp.apply # Use it by calling the apply method:
output = exp(x)
print(output.shape)
自定义的forward和backward要用静态方法,网上也有别的人写成def forward(self, input_):这种形式,但是这种写法快要被Pytorch淘汰了
例2
import torchfrom torch import nn
from torch.autograd import Function
import torch
class MyReLU(Function):
@staticmethod
def forward(ctx, input_):
# 在forward中,需要定义MyReLU这个运算的forward计算过程
# 同时可以保存任何在后向传播中需要使用的变量值
ctx.save_for_backward(input_) # 将输入保存起来,在backward时使用
output = input_.clamp(min=0) # relu就是截断负数,让所有负数等于0
return output
@staticmethod
def backward(ctx, grad_output):
# 根据BP算法的推导(链式法则),dloss / dx = (dloss / doutput) * (doutput / dx)
# dloss / doutput就是输入的参数grad_output、
# 因此只需求relu的导数,在乘以grad_output
input_, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input_ < 0] = 0 # 上诉计算的结果就是左式。即ReLU在反向传播中可以看做一个通道选择函数,所有未达到阈值(激活值<0)的单元的梯度都为0
return grad_input
x = torch.rand(4,3,5,5)
myrelu = MyReLU.apply # Use it by calling the apply method:
output = myrelu(x)
print(output.shape)
例3
import torchfrom torch.autograd import Function
from torch.autograd import gradcheck
class LinearFunction(Function):
# 创建torch.autograd.Function类的一个子类
# 必须是staticmethod
@staticmethod
# 第一个是ctx,第二个是input,其他是可选参数。
# ctx在这里类似self,ctx的属性可以在backward中调用。
# 自己定义的Function中的forward()方法,所有的Variable参数将会转成tensor!因此这里的input也是tensor.在传入forward前,autograd engine会自动将Variable unpack成Tensor。
def forward(ctx, input, weight, bias=None):
ctx.save_for_backward(input, weight, bias) # 将Tensor转变为Variable保存到ctx中
output = input @ weight.t()
if bias is not None:
output += bias.unsqueeze(0).expand_as(output) #unsqueeze(0) 扩展处第0维
# expand_as(tensor)等价于expand(tensor.size()), 将原tensor按照新的size进行扩展
return output
@staticmethod
def backward(ctx, grad_output):
# grad_output为反向传播上一级计算得到的梯度值
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# 分别代表输入,权值,偏置三者的梯度
# 判断三者对应的Variable是否需要进行反向求导计算梯度
if ctx.needs_input_grad[0]:
grad_input = grad_output @ weight # 复合函数求导,链式法则
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t() @ input #复合函数求导,链式法则
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0).squeeze(0)
return grad_input, grad_weight, grad_bias
linear = LinearFunction.apply
# gradchek takes a tuple of tensor as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = torch.randn(20,20,requires_grad=True).double()
weight = torch.randn(20,20,requires_grad=True).double()
bias = torch.randn(20,requires_grad=True).double()
test = gradcheck(LinearFunction.apply, (input,weight,bias), eps=1e-6, atol=1e-4)
print(test) # 没问题的话输出True
ctx.needs_input_grad作为一个boolean型的表示也可以用来控制每一个input是否需要计算梯度,e.g., ctx.needs_input_grad[0] = False,表示forward里的第一个input不需要梯度,若此时我们return这个位置的梯度值的话,为None即可
import torch
from torch import autograd
class MyFunc(autograd.Function):
@staticmethod
def forward(ctx, inp):
return inp.clone()
@staticmethod
def backward(ctx, gO):
# Error during the backward pass
raise RuntimeError("Some error in backward")
return gO.clone()
def run_fn(a):
out = MyFunc.apply(a)
return out.sum()
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()
with autograd.detect_anomaly():
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()
Function与Module的差异与应用场景
Function与Module都可以对pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有十分重要的不同:
- Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络
- Function需要定义三个方法:__init__, forward, backward(需要自己写求导公式);Module:只需定义__init__和forward,而backward的计算由自动求导机制构成