Pytorch基本变量类型FloatTensor与Variable用法

 更新时间:2020年01月08日 10:42:06   投稿:jingxian  
今天小编就为大家分享一篇Pytorch基本变量类型FloatTensor与Variable用法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

pytorch中基本的变量类型当属FloatTensor(以下都用floattensor),而Variable(以下都用variable)是floattensor的封装,除了包含floattensor还包含有梯度信息

pytorch中的dochi给出一些对于floattensor的基本的操作,比如四则运算以及平方等(链接),这些操作对于floattensor是十分的不友好,有时候需要写一个正则化的项需要写很长的一串,比如两个floattensor之间的相加需要用torch.add()来实现

然而正确的打开方式并不是这样

韩国一位大神写了一个pytorch的turorial,其中包含style transfer的一个代码实现

for step in range(config.total_step):

    
    # Extract multiple(5) conv feature vectors
    target_features = vgg(target)  # 每一次输入到网络中的是同样一张图片,反传优化的目标是输入的target
    content_features = vgg(Variable(content))
    style_features = vgg(Variable(style))

    style_loss = 0
    content_loss = 0
    for f1, f2, f3 in zip(target_features, content_features, style_features):
      # Compute content loss (target and content image)
      content_loss += torch.mean((f1 - f2)**2) # square 可以进行直接加-操作?可以,并且mean对所有的元素进行均值化造作

      # Reshape conv features
      _, c, h, w = f1.size() # channel height width
      f1 = f1.view(c, h * w) # reshape a vector
      f3 = f3.view(c, h * w) # reshape a vector

      # Compute gram matrix 
      f1 = torch.mm(f1, f1.t())
      f3 = torch.mm(f3, f3.t())

      # Compute style loss (target and style image)
      style_loss += torch.mean((f1 - f3)**2) / (c * h * w)  # 总共元素的数目?

其中f1与f2,f3的变量类型是Variable,作者对其直接用四则运算符进行加减,并且用python内置的**进行平方操作,然后

# -*-coding: utf-8 -*-
import torch
from torch.autograd import Variable

# dtype = torch.FloatTensor
dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype) # 两个权重矩阵
w2 = torch.randn(D_in, H).type(dtype)
# operate with +-*/ and **
w3 = w1-2*w2
w4 = w3**2
w5 = w4/w1


# operate the Variable with +-*/ and **
w6 = Variable(torch.randn(N, D_in).type(dtype))
w7 = Variable(torch.randn(N, D_in).type(dtype))
w8 = w6 + w7
w9 = w6*w7
w10 = w9**2
print(1)

基本上调试的结果与预期相符

所以,对于floattensor以及variable进行普通的+-×/以及**没毛病

以上这篇Pytorch基本变量类型FloatTensor与Variable用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python实现特殊字符判断并去掉非字母和数字的特殊字符

    Python实现特殊字符判断并去掉非字母和数字的特殊字符

    在 Python 中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家有所帮助
    2025-04-04
  • Windows下安装Scrapy

    Windows下安装Scrapy

    今天小编就为大家分享一篇关于Windows下安装Scrapy,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-10-10
  • Python实现视频裁剪的示例代码

    Python实现视频裁剪的示例代码

    这篇文章主要介绍了如何通过Python实现视频裁剪,可以将视频按照自定义尺寸进行裁剪,文中的示例代码简洁易懂,感兴趣的可以了解一下
    2022-01-01
  • 浅谈如何重构冗长的Python代码

    浅谈如何重构冗长的Python代码

    这篇文章主要介绍了浅谈如何重构冗长的Python代码,编写干净的 Pythonic 代码就是尽可能使其易于理解,但又简洁,过长的代码如何做到简洁高效,需要的朋友可以参考下
    2023-04-04
  • Python3非对称加密算法RSA实例详解

    Python3非对称加密算法RSA实例详解

    这篇文章主要介绍了Python3非对称加密算法RSA,结合实例形式分析了Python3 RSA加密相关模块安装及使用操作技巧,需要的朋友可以参考下
    2018-12-12
  • python并发场景锁的使用方法

    python并发场景锁的使用方法

    这篇文章主要介绍了python并发场景锁的使用方法,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-07-07
  • Python中最神秘missing()函数介绍

    Python中最神秘missing()函数介绍

    大家好,本篇文章主要讲的是Python中最神秘missing()函数介绍,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
    2021-12-12
  • python SQLAlchemy 数据库连接池的实现

    python SQLAlchemy 数据库连接池的实现

    SSQLAlchemy提供了强大的连接池和连接管理功能,可以有效地管理数据库连接,本文主要介绍了python SQLAlchemy 数据库连接池的实现,具有一定的参考价值,感兴趣的可以了解一下
    2025-03-03
  • Python单例模式实例分析

    Python单例模式实例分析

    这篇文章主要介绍了Python单例模式,以实例形式分析了Python单例模式的具体使用技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-01-01
  • python实现井字棋游戏

    python实现井字棋游戏

    这篇文章主要为大家详细介绍了python实现井字棋游戏的相关资料,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2016-02-02

最新评论