pytorch自定义二值化网络层方式

 更新时间:2020年01月07日 13:44:24   作者:ChLee98  
今天小编就为大家分享一篇pytorch自定义二值化网络层方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python安装及建立虚拟环境的完整步骤

    Python安装及建立虚拟环境的完整步骤

    在使用 Python 开发时,建议在开发环境和生产环境下都使用虚拟环境来管理项目的依赖,下面这篇文章主要给大家介绍了关于Python安装及建立虚拟环境的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • Python文件夹与文件的相关操作(推荐)

    Python文件夹与文件的相关操作(推荐)

    下面小编就为大家带来一篇Python文件夹与文件的相关操作(推荐)。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2016-07-07
  • Python实现深度遍历和广度遍历的方法

    Python实现深度遍历和广度遍历的方法

    今天小编就为大家分享一篇Python实现深度遍历和广度遍历的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • 解决Python print 输出文本显示 gbk 编码错误问题

    解决Python print 输出文本显示 gbk 编码错误问题

    这篇文章主要介绍了解决Python print 输出文本显示 gbk 编码错误问题,本文给出了三种解决方法,需要的朋友可以参考下
    2018-07-07
  • Django 1.10以上版本 url 配置注意事项详解

    Django 1.10以上版本 url 配置注意事项详解

    这篇文章主要介绍了Django 1.10以上版本 url 配置注意事项详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • LeetCode百钱买百鸡python递归解法示例

    LeetCode百钱买百鸡python递归解法示例

    这篇文章主要为大家介绍了LeetCode百钱买百鸡题目python递归解法示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-11-11
  • pycharm远程连接服务器运行pytorch的过程详解

    pycharm远程连接服务器运行pytorch的过程详解

    这篇文章主要介绍了在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorch的过程,包括安装PyTorch、CUDA以及配置PyCharm远程开发环境的详细步骤,需要的朋友可以参考下
    2025-02-02
  • Python中__init__.py文件的作用

    Python中__init__.py文件的作用

    这篇文章主要介绍了Python中__init__.py文件的作用,在PyCharm中,带有__init__.py这个文件的目录被认为是Python的包目录,与普通目录的图标有不一样的显示
    2022-09-09
  • YOLOv5中SPP/SPPF结构源码详析(内含注释分析)

    YOLOv5中SPP/SPPF结构源码详析(内含注释分析)

    其实关于YOLOv5的网络结构其实网上相关的讲解已经有很多了,但是觉着还是有必要再给大家介绍下,下面这篇文章主要给大家介绍了关于YOLOv5中SPP/SPPF结构源码的相关资料,需要的朋友可以参考下
    2022-05-05
  • python turtle工具绘制四叶草的实例分享

    python turtle工具绘制四叶草的实例分享

    在本篇文章里小编给各位整理的是关于python turtle工具绘制四叶草的实例分享,有兴趣的朋友们可以跟着学习下。
    2020-02-02

最新评论