Pytorch 使用tensor特定条件判断索引
torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”
区别于python numpy中的where()直接可以找到特定条件元素的index
想要实现numpy中where()的功能,可以借助nonzero()
对应numpy中的where()操作效果:
补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法
detach
detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来
需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度
import torch as t a = t.ones(10,) b = a.detach() print(b) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
那么这个函数有什么作用?
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法
a = A(input) a = detach() b = B(a) loss = criterion(b, target) loss.backward()
来看一个实际的例子:
import torch as t x = t.ones(1, requires_grad=True) x.requires_grad #True y = t.ones(1, requires_grad=True) y.requires_grad #True x = x.detach() #分离之后 x.requires_grad #False y = x+y #tensor([2.]) y.requires_grad #我还是True y.retain_grad() #y不是叶子张量,要加上这一行 z = t.pow(y, 2) z.backward() #反向传播 y.grad #tensor([4.]) x.grad #None
以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None
既然谈到了修改模型的权重问题,那么还有一种情况是:
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?
这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.
for param in B.parameters(): param.requires_grad = False a = A(input) b = B(a) loss = criterion(b, target) loss.backward()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
相关文章
解决Python3.8用pip安装turtle-0.0.2出现错误问题
turtle库是python的基础绘图库,这个库被介绍为一个最常用的用来给孩子们介绍编程知识的方法库,这篇文章主要介绍了解决Python3.8用pip安装turtle-0.0.2出现错误问题,需要的朋友可以参考下2020-02-02python 解决flask uwsgi 获取不到全局变量的问题
今天小编就为大家分享一篇python 解决flask uwsgi 获取不到全局变量的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2019-12-12pybaobabdt库基于python的决策树随机森林可视化工具使用
这篇文章主要为大家介绍了pybaobabdt库基于python的决策树随机森林可视化工具使用探索,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪2024-02-02关于jupyter lab安装及导入tensorflow找不到模块的问题
这篇文章主要介绍了关于jupyter lab安装及导入tensorflow找不到模块的问题,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2021-03-03
最新评论