Python之torch.no_grad()函数使用和示例
torch.no_grad()函数使用和示例
torch.no_grad() 是 PyTorch 中的一个上下文管理器,用于在进入该上下文时禁用梯度计算。
这在你只关心评估模型,而不是训练模型时非常有用,因为它可以显著减少内存使用并加速计算。
当你在 torch.no_grad() 上下文管理器中执行张量操作时,PyTorch 不会为这些操作计算梯度。
这意味着不会在 .grad 属性中累积梯度,并且操作会更快地执行。
使用torch.no_grad()
import torch
# 创建一个需要梯度的张量
x = torch.tensor([1.0], requires_grad=True)
# 使用 no_grad() 上下文管理器
with torch.no_grad():
y = x * 2
y.backward()
print(x.grad)
输出:
RuntimeError Traceback (most recent call last)
Cell In[52], line 11
7 with torch.no_grad():
8 y = x * 2
---> 11 y.backward()
13 print(x.grad)File E:\anaconda\lib\site-packages\torch\_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
387 if has_torch_function_unary(self):
388 return handle_torch_function(
389 Tensor.backward,
390 (self,),
(...)
394 create_graph=create_graph,
395 inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)File E:\anaconda\lib\site-packages\torch\autograd\__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
168 retain_graph = create_graph
170 # The reason we repeat same the comment below is that
171 # some Python versions print out the first line of a multi-line function
172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
174 tensors, grad_tensors_, retain_graph, create_graph, inputs,
175 allow_unreachable=True, accumulate_grad=True)RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
输出错误,因为使用了with torch.no_grad():。
不使用torch.no_grad()
import torch # 创建一个需要梯度的张量 x = torch.tensor([1.0], requires_grad=True) # 使用 no_grad() 上下文管理器 y = x * 2 y.backward() print(x.grad)
输出:
tensor([2.])
@torch.no_grad()
with torch.no_grad()或者@torch.no_grad()中的数据不需要计算梯度,也不会进行反向传播
model.eval() with torch.no_grad(): ...
等价于
@torch.no_grad()
def eval():
...总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
在Python的Flask框架下使用sqlalchemy库的简单教程
这篇文章主要介绍了在Python的Flask框架下使用sqlalchemy库的简单教程,用来简洁地连接与操作数据库,需要的朋友可以参考下2015-04-04
Python抓新型冠状病毒肺炎疫情数据并绘制全国疫情分布的代码实例
在本篇文章里小编给大家整理了一篇关于Python抓新型冠状病毒肺炎疫情数据并绘制全国疫情分布的代码实例,有兴趣的朋友们可以学习下。2020-02-02
Django报错TemplateDoesNotExist的问题及解决
这篇文章主要介绍了Django报错TemplateDoesNotExist的问题及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教2023-08-08
Python使用SQLAlchemy操作Mysql数据库的操作示例
SQLAlchemy是Python的SQL工具包和对象关系映射(ORM)库,它提供了全套的企业级持久性模型,用于高效、灵活且优雅地与关系型数据库进行交互,这篇文章主要介绍了Python使用SQLAlchemy操作Mysql数据库,需要的朋友可以参考下2024-08-08


最新评论