PyTorch中的Variable变量详解

 更新时间:2020年01月07日 14:35:31   作者:Wei Ji  
今天小编就为大家分享一篇PyTorch中的Variable变量详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

一、了解Variable

顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。

具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。

【tensor 是一个多维矩阵】

用一个例子说明,Variable的定义:

import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)
 
print(tensor)
"""
 1 2
 3 4
[torch.FloatTensor of size 2x2]
"""
 
print(variable)
"""
Variable containing:
 1 2
 3 4
[torch.FloatTensor of size 2x2]
"""

注:tensor不能反向传播,variable可以反向传播。

二、Variable求梯度

Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。

v_out.backward() # 模拟 v_out 的误差反向传递

print(variable.grad) # 初始 Variable 的梯度
'''
 0.5000 1.0000
 1.5000 2.0000
'''

三、获取Variable里面的数据

直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。

print(variable)  # Variable 形式
"""
Variable containing:
 1 2
 3 4
[torch.FloatTensor of size 2x2]
"""
 
print(variable.data) # 将variable形式转为tensor 形式
"""
 1 2
 3 4
[torch.FloatTensor of size 2x2]
"""
 
print(variable.data.numpy()) # numpy 形式
"""
[[ 1. 2.]
 [ 3. 4.]]
"""

扩展

在PyTorch中计算图的特点总结如下:

autograd根据用户对Variable的操作来构建其计算图。

1、requires_grad

variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。

2、volatile

variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。

3、retain_graph

多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。

4、backward()

反向传播,求解Variable的梯度。放在中间缓存中。

以上这篇PyTorch中的Variable变量详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python 拷贝对象(深拷贝deepcopy与浅拷贝copy)

    Python 拷贝对象(深拷贝deepcopy与浅拷贝copy)

    Python中的对象之间赋值时是按引用传递的,如果需要拷贝对象,需要使用标准库中的copy模块。
    2008-09-09
  • 集调试共享及成本控制Prompt工具PromptLayer使用指南

    集调试共享及成本控制Prompt工具PromptLayer使用指南

    这篇文章主要介绍了集调试共享及成本控制Prompt工具PromptLayer使用指南,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • PyQt5 QListView 高亮显示某一条目的案例

    PyQt5 QListView 高亮显示某一条目的案例

    这篇文章主要介绍了PyQt5 QListView 高亮显示某一条目的案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Python使用jupyter notebook查看ipynb文件过程解析

    Python使用jupyter notebook查看ipynb文件过程解析

    这篇文章主要介绍了Python使用jupyter notebook查看ipynb文件过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • python 元组的使用方法

    python 元组的使用方法

    这篇文章主要介绍了python 元组的使用方法,文中讲解非常细致,代码帮助大家更好的参考和学习,感兴趣的朋友可以了解下
    2020-06-06
  • Python基础教程之内置函数locals()和globals()用法分析

    Python基础教程之内置函数locals()和globals()用法分析

    这篇文章主要介绍了Python基础教程之内置函数locals()和globals()用法,结合实例形式分析了locals()和globals()函数的功能、使用方法及相关操作注意事项,需要的朋友可以参考下
    2018-03-03
  • Python提取JSON格式数据实战案例

    Python提取JSON格式数据实战案例

    这篇文章主要给大家介绍了关于Python提取JSON格式数据的相关资料, Python提供了内置的json模块,用于处理JSON数据,文中给出了详细的代码示例,需要的朋友可以参考下
    2023-07-07
  • POC漏洞批量验证程序Python脚本编写

    POC漏洞批量验证程序Python脚本编写

    这篇文章主要为大家介绍了POC漏洞批量验证程序Python脚本编写的完整示例代码,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步
    2022-02-02
  • 利用Python实现原创工具的Logo与Help

    利用Python实现原创工具的Logo与Help

    这篇文章主要给大家介绍了关于如何利用Python实现原创工具的Logo与Help的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考借鉴,下面来一起看看吧
    2018-12-12
  • win10下opencv-python特定版本手动安装与pip自动安装教程

    win10下opencv-python特定版本手动安装与pip自动安装教程

    这篇文章主要介绍了win10下opencv-python特定版本手动安装与pip自动安装教程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-03-03

最新评论