1.为什么要使用hook()函数
-
Pytorch在进行完一次反向传播后,出于节省内存的考虑,只会存储叶子节点的梯度信息,并不会存储中间变量的梯度信息。然而有些时候我们又不得不使用中间变量的梯度信息完成某些工作(如获取中间层的梯度,获取中间层的特征图),这时候hook()函数就可以派上用场啦 -
hook()函数翻译成中文叫做钩子函数,这非常形象:我们的主任务是反向传播更新梯度,而钩子函数就是挂在主任务上的辅任务 - 主要有四种钩子函数:①
torch.Tensor.register_hook②torch.nn.Module.register_backward_hook③torch.nn.Module.register_forward_hook④torch.nn.Module.register_forward_pre_hook,接下来分别对他们进行介绍
2. torch.Tensor.register_hook
2.1 使用说明
- 为某个需要梯度的中间变量注册一个钩子
(Registers a backward hook),每次计算出中间变量的一个梯度后都会调用钩子函数 - 钩子函数无法改变传入的实参值,但可通过新建变量的方式对传入的实参值进行相关计算从而返回新的梯度以取代原始梯度值
- 值得注意的是,在实际使用的时候,
register_hook将一个函数作为形参,该函数以register_hook获取到的中间函数的梯度作为实参 - 使用完
register_hook()函数后要及时清除,清除时使用hook.remove()
2.2 示例代码
import torch
def doubleGrad(grad):
print('The gradient of y by using hook:', 2*grad)
x = torch.tensor([1.0], dtype=torch.float32, requires_grad=True)
y = x + 1
y.register_hook(doubleGrad)
z = x + y**2
z.backward()
print('The gradient of x:',x.grad)
print('The gradient of y:',y.grad)
y.remove() # 清除hook函数输出如下:
The gradient of y by using hook: tensor([8.])
The gradient of x: tensor([5.])
The gradient of y: None3. torch.nn.Module.register_backward_hook
3.1 使用说明
- 为网络中某个模块注册一个反向传播钩子,用于获得反向传播时该模块的梯度
- 反向传播每次经过该模块,该模块注册的钩子都会被调用
- 该函数注册的钩子函数具有如下的形式
hook(module, grad_input, grad_output) -> Tensor or None
/***
1.grad_input和grad_output都是tuple,形状分别与module的输入和输出对应
2.钩子函数不能直接改变传入的实参,但可返回一个新的梯度用以取代原有梯度- 钩子函数返回一个
handle,该handle可调用remove()函数来清除注册的钩子函数 - 对复杂的模块,使用这个钩子函数可能会出现bug,最好直接在想要获取梯度的变量上使用
torch.Tensor.register_hook
3.2 示例代码
官方文档说这个函数还有一些bug,并不推荐使用
4. torch.nn.Module.register_forward_hook(hook)
4.1 使用说明
- 为某模块注册一个前向传播的hook,每次该模块进行前向传播后,hook获得该模块前向传播的值
- 调用形式如下,注册的hook可以对输出进行修改,但是不会对前向传播产生影响,这是因为在该模块前向传播结束后才会调用hook
hook(module, input, output) -> None or modified output- 返回一个
handle,可以通过handle移除注册的hook,语法为handle.remove() - 前向传播的hook有个重要的应用:获得深度神经网络中间层特征
4.2 示例代码
import torch as t
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def hook(module, inputdata, output):
'''把这层的输出拷贝到features中'''
print(output.shape)
net = LeNet()
handle = net.conv2.register_forward_hook(hook)
img = t.rand(1,1,32,32)
net(img)
# 用完hook后删除
handle.remove()5. torch.nn.Module.register_forward_pre_hook(hook)
5.1 使用说明
- 调用形式如下:
hook(module, input) -> None or modified input- 与
register_forward_hook的区别是,register_forward_pre_hook是在调用该hook的模块前向传播完成前执行,因此可对传递至该模块的输入进行修改,再将修改后的输入通过该模块向后传递

















