pytorch中with torch.no_grad():的用法实例

 更新时间:2022年03月11日 10:18:36   作者:这是一只小菜鸡  
最近在看别人写的代码,遇到经常使用with torch.no_grad(),所以下面这篇文章主要给大家介绍了关于pytorch中with torch.no_grad():用法的相关资料,需要的朋友可以参考下

1.关于with

with是python中上下文管理器,简单理解,当要进行固定的进入,返回操作时,可以将对应需要的操作,放在with所需要的语句中。比如文件的写入(需要打开关闭文件)等。

以下为一个文件写入使用with的例子。

        with open (filename,'w') as sh:    
            sh.write("#!/bin/bash\n")
            sh.write("#$ -N "+'IC'+altas+str(patientNumber)+altas+'\n')
            sh.write("#$ -o "+pathSh+altas+'log.log\n') 
            sh.write("#$ -e "+pathSh+altas+'err.log\n') 
            sh.write('source ~/.bashrc\n')          
            sh.write('. "/home/kjsun/anaconda3/etc/profile.d/conda.sh"\n')
            sh.write('conda activate python27\n')
            sh.write('echo "to python"\n')
            sh.write('echo "finish"\n')
            sh.close()

with后部分,可以将with后的语句运行,将其返回结果给到as后的变量(sh),之后的代码块对close进行操作。

2.关于with torch.no_grad():

在使用pytorch时,并不是所有的操作都需要进行计算图的生成(计算过程的构建,以便梯度反向传播等操作)。而对于tensor的计算操作,默认是要进行计算图的构建的,在这种情况下,可以使用 with torch.no_grad():,强制之后的内容不进行计算图构建。

以下分别为使用和不使用的情况:

(1)使用with torch.no_grad():

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))        
print(outputs)

运行结果:

Accuracy of the network on the 10000 test images: 55 %
tensor([[-2.9141, -3.8210,  2.1426,  3.0883,  2.6363,  2.6878,  2.8766,  0.3396,
         -4.7505, -3.8502],
        [-1.4012, -4.5747,  1.8557,  3.8178,  1.1430,  3.9522, -0.4563,  1.2740,
         -3.7763, -3.3633],
        [ 1.3090,  0.1812,  0.4852,  0.1315,  0.5297, -0.3215, -2.0045,  1.0426,
         -3.2699, -0.5084],
        [-0.5357, -1.9851, -0.2835, -0.3110,  2.6453,  0.7452, -1.4148,  5.6919,
         -6.3235, -1.6220]])

此时的outputs没有 属性。

(2)不使用with torch.no_grad():

而对应的不使用的情况

for data in testloader:
    images, labels = data
    outputs = net(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
print(outputs)

结果如下:

Accuracy of the network on the 10000 test images: 55 %
tensor([[-2.9141, -3.8210,  2.1426,  3.0883,  2.6363,  2.6878,  2.8766,  0.3396,
         -4.7505, -3.8502],
        [-1.4012, -4.5747,  1.8557,  3.8178,  1.1430,  3.9522, -0.4563,  1.2740,
         -3.7763, -3.3633],
        [ 1.3090,  0.1812,  0.4852,  0.1315,  0.5297, -0.3215, -2.0045,  1.0426,
         -3.2699, -0.5084],
        [-0.5357, -1.9851, -0.2835, -0.3110,  2.6453,  0.7452, -1.4148,  5.6919,
         -6.3235, -1.6220]], grad_fn=<AddmmBackward>)

可以看到,此时有grad_fn=<AddmmBackward>属性,表示,计算的结果在一计算图当中,可以进行梯度反传等操作。但是,两者计算的结果实际上是没有区别的。

附:pytorch使用模型测试使用with torch.no_grad():

使用pytorch时,并不是所有的操作都需要进行计算图的生成(计算过程的构建,以便梯度反向传播等操作)。而对于tensor的计算操作,默认是要进行计算图的构建的,在这种情况下,可以使用 with torch.no_grad():,强制之后的内容不进行计算图构建。

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))        
print(outputs)

运行结果:

Accuracy of the network on the 10000 test images: 55 %
tensor([[-2.9141, -3.8210,  2.1426,  3.0883,  2.6363,  2.6878,  2.8766,  0.3396,
         -4.7505, -3.8502],
        [-1.4012, -4.5747,  1.8557,  3.8178,  1.1430,  3.9522, -0.4563,  1.2740,
         -3.7763, -3.3633],
        [ 1.3090,  0.1812,  0.4852,  0.1315,  0.5297, -0.3215, -2.0045,  1.0426,
         -3.2699, -0.5084],
        [-0.5357, -1.9851, -0.2835, -0.3110,  2.6453,  0.7452, -1.4148,  5.6919,
         -6.3235, -1.6220]])

总结

到此这篇关于pytorch中with torch.no_grad():用法的文章就介绍到这了,更多相关pytorch中with torch.no_grad():内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python实现局域网远程控制电脑

    Python实现局域网远程控制电脑

    这篇文章主要为大家详细介绍了如何利用Python编写一个工具,可以实现远程控制局域网电脑关机,重启,注销等功能,感兴趣的小伙伴可以参考一下
    2024-12-12
  • Python函数参数操作详解

    Python函数参数操作详解

    这篇文章主要介绍了Python函数参数操作,结合实例形式详细分析了Python形参、实参、默认参数、关键字参数、可变参数、对参数解包以及获取参数个数等相关操作技巧,需要的朋友可以参考下
    2018-08-08
  • python经典练习百题之猴子吃桃三种解法

    python经典练习百题之猴子吃桃三种解法

    这篇文章主要给大家介绍了关于python经典练习百题之猴子吃桃三种解法的相关资料, Python猴子吃桃子编程是一个趣味性十足的编程练习,在这个练习中,我们将要使用Python语言来模拟一只猴子吃桃子的过程,需要的朋友可以参考下
    2023-10-10
  • python 实现IP子网计算

    python 实现IP子网计算

    这篇文章主要介绍了python 实现IP子网计算的方法,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2021-02-02
  • Python获取时间戳代码实例

    Python获取时间戳代码实例

    这篇文章主要介绍了Python获取时间戳代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python实现批量把SVG格式转成png、pdf格式的代码分享

    Python实现批量把SVG格式转成png、pdf格式的代码分享

    这篇文章主要介绍了Python实现批量把SVG格式转成png、pdf格式的代码分享,本文代码需要引用一个第三方模块cairosvg,需要的朋友可以参考下
    2014-08-08
  • python 将list转成字符串,中间用符号分隔的方法

    python 将list转成字符串,中间用符号分隔的方法

    今天小编就为大家分享一篇python 将list转成字符串,中间用符号分隔的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python使用WebSocket和SSE实现HTTP服务器消息推送方式

    Python使用WebSocket和SSE实现HTTP服务器消息推送方式

    本文介绍了两种实时数据获取的技术:WebSocket和SSE,WebSocket是全双工通信协议,支持双向通信,但需要专门定义数据协议,SSE是一种单工通信技术,基于HTTP的流式数据传输,客户端开发简单,但只能单工通信
    2024-11-11
  • Python实现Word文档转换为图片(JPG、PNG、SVG等常见格式)

    Python实现Word文档转换为图片(JPG、PNG、SVG等常见格式)

    将Word文档以图片形式导出,既能方便信息的分享,也能保护数据安全,避免被二次编辑,文本将介绍如何使用 Spire.Doc for Python 库在Python程序中实现Word到图片的批量转换,需要的朋友可以参考下
    2024-06-06
  • Python如何遍历JSON所有数据

    Python如何遍历JSON所有数据

    这篇文章主要介绍了Python如何遍历JSON所有数据问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-08-08

最新评论