PyTorch中clone()、detach()及相关扩展详解

 更新时间:2020年12月09日 10:02:05   作者:Chris_34  
这篇文章主要给大家介绍了关于PyTorch中clone()、detach()及相关扩展的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

clone() 与 detach() 对比

Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝,

首先我们来打印出来clone()操作后的数据类型定义变化:

(1). 简单打印类型

import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()
c = a.detach()
a.data *= 3
b += 1

print(a) # tensor(3., requires_grad=True)
print(b)
print(c)

'''
输出结果:
tensor(3., requires_grad=True)
tensor(2., grad_fn=<AddBackward0>)
tensor(3.) # detach()后的值随着a的变化出现变化
'''

grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。clone操作在一定程度上可以视为是一个identity-mapping函数。

detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变。

注意: 在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况。

import torch as t
a = t.tensor([1.0,2.0], requires_grad=True)
b = a.detach()
#c[:] = a.detach()
print(id(a))
print(id(b))
#140568935450520
140570337203616

显然直接打印出来的id不等,我们可以通过简单的赋值后观察数据变化进行判断。

(2). clone()的梯度回传

detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算。

而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None

import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
y = a**2
z = a ** 2+a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad)   # None. 中间variable,无grad
print(a.grad) 
'''
输出:
tensor(2.) 
None
tensor(7.) # 2*2+3=7
'''

使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下。

通常如果原tensor的requires_grad=True,则:

  • clone()操作后的tensor requires_grad=True
  • detach()操作后的tensor requires_grad=False。
import torch
torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053, 0.3793])
True
None
False
False
'''

另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导。

如下:

import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_() #require_grad=True
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) 
'''
输出:
None
tensor(2.)
'''

了解了两者的区别后我们常与其他函数进行搭配使用,实现数据拷贝后的其他需要。

比如我们经常使用view()函数对tensor进行reshape操作。返回的新Tensor与源Tensor可能有不同的size,但是是共享data的,即其中的一个发生变化,另外一个也会跟着改变。

需要注意的是view返回的Tensor与源Tensor是共享data的,但是依然是一个新的Tensor(因为Tensor除了包含data外还有一些其他属性),两者id(内存地址)并不一致。

x = torch.rand(2, 2)
y = x.view(4)
x += 1
print(x)
print(y) # 也加了1

view() 仅仅是改变了对这个张量的观察角度,内部数据并未改变。这时候想返回一个真正新的副本(即不共享data内存)该怎么办呢?Pytorch还提供了一个reshape()可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先用clone创造一个副本然后再使用view。参考此处

x = torch.rand(2, 2)
x_cp = x.clone().view(4)
x += 1
print(id(x))
print(id(x_cp))
print(x)
print(x_cp)
'''
140568935036464
140568935035816
tensor([[0.4963, 0.7682],
 [0.1320, 0.3074]])
tensor([[1.4963, 1.7682, 1.1320, 1.3074]]) 
'''

另外使用clone()会被记录在计算图中,即梯度回传到副本时也会传到源Tensor。在上一篇中有总结

总结:

  • torch.detach() — 新的tensor会脱离计算图,不会牵扯梯度计算
  • torch.clone() — 新的tensor充当中间变量,会保留在计算图中,参与梯度计算(回传叠加),但是一般不会保留自身梯度。
    原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面两者中执行都会引发错误或者警告。
  • 共享数据内存是底层设计,并不能简单的通过直接打印tensor的id地址进行判断,需要在进行赋值或运算操作后打印比较数据的变化进行判断。
  • 复制操作可以根据实际需要进行结合使用。

引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用。

像y = x + y这样的运算会新开内存,然后将y指向新内存。我们可以使用Python自带的id函数进行验证:如果两个实例的ID相同,则它们所对应的内存地址相同。

到此这篇关于PyTorch中clone()、detach()及相关扩展详解的文章就介绍到这了,更多相关PyTorch中clone()、detach()及相关扩展内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python+turtle绘制对称图形的示例代码

    Python+turtle绘制对称图形的示例代码

    这篇文章主要是带大家写一个利用Turtle库绘制一些有趣的对称图形,文中的示例代码讲解详细,对我们学习Python有一定帮助,感兴趣的可以了解一下
    2022-07-07
  • Numpy 数组索引的实现

    Numpy 数组索引的实现

    本文主要介绍了Numpy 数组索引的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-01-01
  • Python 利用邮件系统完成远程控制电脑的实现(关机、重启等)

    Python 利用邮件系统完成远程控制电脑的实现(关机、重启等)

    这篇文章主要介绍了Python 利用邮件系统完成远程控制电脑(关机、重启等),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-11-11
  • python中对%、~含义的解释

    python中对%、~含义的解释

    这篇文章主要介绍了python中对%、~含义的解释,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • 使用Python给PDF添加目录书签的实现方法

    使用Python给PDF添加目录书签的实现方法

    有时下载到扫描版的 PDF 是不带书签目录的,这样阅读起来很不方便,下面通过 python 实现一个半自动化添加书签目录的脚本,文中通过代码介绍的非常详细,具有一定的参考价值,需要的朋友可以参考下
    2023-10-10
  • 在Python中操作日期和时间之gmtime()方法的使用

    在Python中操作日期和时间之gmtime()方法的使用

    这篇文章主要介绍了在Python中操作日期和时间之gmtime()方法的使用,是Python入门学习中的基础知识,需要的朋友可以参考下
    2015-05-05
  • 读取json格式为DataFrame(可转为.csv)的实例讲解

    读取json格式为DataFrame(可转为.csv)的实例讲解

    今天小编就为大家分享一篇读取json格式为DataFrame(可转为.csv)的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 对Tensorflow中的矩阵运算函数详解

    对Tensorflow中的矩阵运算函数详解

    今天小编就为大家分享一篇对Tensorflow中的矩阵运算函数详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python文件特定行插入和替换实例详解

    python文件特定行插入和替换实例详解

    这篇文章主要介绍了python文件特定行插入和替换实例详解的相关资料,需要的朋友可以参考下
    2017-07-07
  • python kmeans聚类简单介绍和实现代码

    python kmeans聚类简单介绍和实现代码

    这篇文章主要为大家详细介绍了python kmeans聚类简单介绍和实现代码,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-02-02

最新评论