copy() detach() clone()

Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的
如果需要开辟新的存储地址而不是引用,可以用clone()进行深拷贝

区别

clone()

解释说明: 返回一个原张量的副本,同时不破坏计算图,它能够维持反向传播计算梯度,
并且两个张量不共享内存.一个张量上值的改变不影响另一个张量.

copy_()

解释说明: 比如x4.copy_(x2), 将x2的数据复制到x4,并且会
修改计算图,使得反向传播自动计算梯度时,计算出x4的梯度后
再继续前向计算x2的梯度. 注意,复制完成之后,两者的值的改变互不影响,
因为他们并不共享内存.

detach()

解释说明: 比如x4 = x2.detach(),返回一个和原张量x2共享内存的新张量x4,
两者的改动可以相互可见, 一个张量上的值的改动会影响到另一个张量.
返回的新张量和计算图相互独立,即新张量和计算图不再关联,
因此也无法进行反向传播计算梯度.即从计算图上把这个张量x2拆
卸detach下来,非常形象.

detach_()

解释说明: detach()的原地操作版本,功能和detach()类似.
比如x4 = x2.detach_(),其实x2和x4是同一个对象,返回的是self,
x2和x4具有相同的id()值.

copy.copy


例子

a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach()
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""

detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变

a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.clone()
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], grad_fn=<CloneBackward>)
"""

grad_fn=<CloneBackward>表示clone后的返回值是个中间变量,因此支持梯度的回溯。

a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach().clone()
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""
a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach().clone().requires_grad_(True)
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
"""

clone()操作后的tensor requires_grad=True
detach()操作后的tensor requires_grad=False

import torch
torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053,  0.3793])
True
None
False
False
'''