pytorch 禁止/允许计算局部梯度的操作

 更新时间:2021年05月12日 09:05:31   作者:Answerlzd  
这篇文章主要介绍了pytorch 禁止/允许计算局部梯度的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

一、禁止计算局部梯度

torch.autogard.no_grad: 禁用梯度计算的上下文管理器。

当确定不会调用Tensor.backward()计算梯度时,设置禁止计算梯度会减少内存消耗。如果需要计算梯度设置Tensor.requires_grad=True

两种禁用方法:

将不用计算梯度的变量放在with torch.no_grad()里

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
Out[12]:False

使用装饰器 @torch.no_gard()修饰的函数,在调用时不允许计算梯度

>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
Out[13]:False

二、禁止后允许计算局部梯度

torch.autogard.enable_grad :允许计算梯度的上下文管理器

在一个no_grad上下文中使能梯度计算。在no_grad外部此上下文管理器无影响.

用法和上面类似:

使用with torch.enable_grad()允许计算梯度

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...   with torch.enable_grad():
...     y = x * 2
>>> y.requires_grad
Out[14]:True
 
>>> y.backward()  # 计算梯度
>>> x.grad
Out[15]: tensor([2.])

在禁止计算梯度下调用被允许计算梯度的函数,结果可以计算梯度

>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
 
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
 
Out[16]:True

三、是否计算梯度

torch.autograd.set_grad_enable()

可以作为一个函数使用:

>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
Out[17]:False
 
>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
Out[18]:True
 
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
Out[19]:False

总结:

单独使用这三个函数时没有什么,但是若是嵌套,遵循就近原则。

x = torch.tensor([1.], requires_grad=True)
 
with torch.enable_grad():
    torch.set_grad_enabled(False)
    y = x * 2
    print(y.requires_grad)
Out[20]: False
 
torch.set_grad_enabled(True)
with torch.no_grad():
    z = x * 2
    print(z.requires_grad)
Out[21]:False

补充:pytorch局部范围内禁用梯度计算,no_grad、enable_grad、set_grad_enabled使用举例

在这里插入图片描述 在这里插入图片描述

原文及翻译

Locally disabling gradient computation
在局部区域内关闭(禁用)梯度的计算.
The context managers torch.no_grad(), torch.enable_grad(), 
and torch.set_grad_enabled() are helpful for locally disabling 
and enabling gradient computation. See Locally disabling gradient 
computation for more details on their usage. These context 
managers are thread local, so they won't work if you send 
work to another thread using the threading module, etc.
上下文管理器torch.no_grad()、torch.enable_grad()和
torch.set_grad_enabled()可以用来在局部范围内启用或禁用梯度计算.
在Locally disabling gradient computation章节中详细介绍了
局部禁用梯度计算的使用方式.这些上下文管理器具有线程局部性,
因此,如果你使用threading模块来将工作负载发送到另一个线程,
这些上下文管理器将不会起作用.

no_grad   Context-manager that disabled gradient calculation.
no_grad   用于禁用梯度计算的上下文管理器.
enable_grad  Context-manager that enables gradient calculation.
enable_grad  用于启用梯度计算的上下文管理器.
set_grad_enabled  Context-manager that sets gradient calculation to on or off.
set_grad_enabled  用于设置梯度计算打开或关闭状态的上下文管理器.

例子1

Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001A2E55A8870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> b = a * 2
>>> b
tensor([[ 0.5648, -0.7430,  1.8176, -3.5202],
        [-0.3612,  4.1874,  2.0812, -3.5303],
        [ 2.2433,  1.6879,  0.3567,  1.3718]], grad_fn=<MulBackward0>)
>>> b.requires_grad
True
>>> b.grad
__main__:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
>>> print(b.grad)
None
>>> a.requires_grad
True
>>> a.grad
>>> print(a.grad)
None
>>>
>>> with torch.no_grad():
...     c = a * 2
...
>>> c
tensor([[ 0.5648, -0.7430,  1.8176, -3.5202],
        [-0.3612,  4.1874,  2.0812, -3.5303],
        [ 2.2433,  1.6879,  0.3567,  1.3718]])
>>> c.requires_grad
False
>>> print(c.grad)
None
>>> a.grad
>>>
>>> print(a.grad)
None
>>> c.sum()
tensor(6.1559)
>>>
>>> c.sum().backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\Anaconda3\envs\pytorch_1.7.1_cu102\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "D:\Anaconda3\envs\pytorch_1.7.1_cu102\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>>
>>> b.sum()
tensor(6.1559, grad_fn=<SumBackward0>)
>>> b.sum().backward()
>>>
>>>
>>> a.grad
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
>>> a.requires_grad
True
>>>
>>>

例子2

Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000002109ABC8870>
>>>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> a.requires_grad
True
>>>
>>> with torch.set_grad_enabled(False):
...     b = a * 2
...
>>> b
tensor([[ 0.5648, -0.7430,  1.8176, -3.5202],
        [-0.3612,  4.1874,  2.0812, -3.5303],
        [ 2.2433,  1.6879,  0.3567,  1.3718]])
>>> b.requires_grad
False
>>>
>>> with torch.set_grad_enabled(True):
...     c = a * 3
...
>>> c
tensor([[ 0.8472, -1.1145,  2.7263, -5.2804],
        [-0.5418,  6.2810,  3.1219, -5.2954],
        [ 3.3649,  2.5319,  0.5350,  2.0576]], grad_fn=<MulBackward0>)
>>> c.requires_grad
True
>>>
>>> d = a * 4
>>> d.requires_grad
True
>>>
>>> torch.set_grad_enabled(True)  # this can also be used as a function
<torch.autograd.grad_mode.set_grad_enabled object at 0x00000210983982C8>
>>>
>>> # 以函数调用的方式来使用
>>>
>>> e = a * 5
>>> e
tensor([[ 1.4119, -1.8574,  4.5439, -8.8006],
        [-0.9030, 10.4684,  5.2031, -8.8257],
        [ 5.6082,  4.2198,  0.8917,  3.4294]], grad_fn=<MulBackward0>)
>>> e.requires_grad
True
>>>
>>> d
tensor([[ 1.1296, -1.4859,  3.6351, -7.0405],
        [-0.7224,  8.3747,  4.1625, -7.0606],
        [ 4.4866,  3.3759,  0.7133,  2.7435]], grad_fn=<MulBackward0>)
>>>
>>> torch.set_grad_enabled(False) # 以函数调用的方式来使用
<torch.autograd.grad_mode.set_grad_enabled object at 0x0000021098394C48>
>>>
>>> f = a * 6
>>> f
tensor([[  1.6943,  -2.2289,   5.4527, -10.5607],
        [ -1.0836,  12.5621,   6.2437, -10.5908],
        [  6.7298,   5.0638,   1.0700,   4.1153]])
>>> f.requires_grad
False
>>>
>>>
>>>

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • Python中的类与对象之描述符详解

    Python中的类与对象之描述符详解

    这篇文章主要介绍了Python中的描述符详解,属于Python学习过程中类与对象的基本知识,需要的朋友可以参考下
    2015-03-03
  • Python异常处理之常见异常类型绝佳实践详解

    Python异常处理之常见异常类型绝佳实践详解

    这篇文章主要为大家介绍了Python异常处理之常见异常类型绝佳实践详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-09-09
  • Python中的yield全方位解读

    Python中的yield全方位解读

    这篇文章主要介绍了Python中的yield全方位解读,在 Python 中,使用了 yield 的函数被称为生成器,跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器,需要的朋友可以参考下
    2023-08-08
  • 浅谈python标准库--functools.partial

    浅谈python标准库--functools.partial

    这篇文章主要介绍了python标准库--functools.partial,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • Python实现上传Minio和阿里Oss文件

    Python实现上传Minio和阿里Oss文件

    这篇文章主要介绍了如何通过Python上传Minio和阿里OSS文件,文中的示例代码介绍得很详细,对我们的工作和学习都有一定的价值,感兴趣的小伙伴可以了解一下
    2021-12-12
  • Python去除、替换字符串空格的处理方法

    Python去除、替换字符串空格的处理方法

    这篇文章主要介绍了Python去除、替换字符串空格的处理方法,去除字符串空格有两种方法,一种是 .replace(' old ',' new '),第二种方法也很简单,需要的朋友可以参考下
    2018-04-04
  • python生成器用法实例详解

    python生成器用法实例详解

    这篇文章主要介绍了python生成器用法,结合实例形式详细分析了Python生成器相关原理、创建、使用方法及操作注意事项,需要的朋友可以参考下
    2019-11-11
  • python 实现文件的递归拷贝实现代码

    python 实现文件的递归拷贝实现代码

    今天翻电脑时突然发现有个存了很多照片和视频的文件夹,想起来是去年换手机(流行的小5)时拷出来的。看了几张照片,往事又一幕幕的浮现在脑海,好吧,我是个感性的人
    2012-08-08
  • python爬虫之教你如何爬取地理数据

    python爬虫之教你如何爬取地理数据

    这篇文章主要介绍了python爬虫之教你如何爬取地理数据,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好的帮助,需要的朋友可以参考下
    2021-04-04
  • Python实现简单的四则运算计算器

    Python实现简单的四则运算计算器

    相信大家在学习数据结构时,就学习了简单四则运算表达式求解的一个算法,可惜一直没有自己动手实现过这个算法。最近重拾数据结构与算法,恰巧又正在用Python比较频繁,所幸就用它来实现这个算法,虽然网上有很多代码,不过作为一个学习者,还是应当亲自动手实现。
    2016-11-11

最新评论