pytorch自定义不可导激活函数的操作

 更新时间:2021年06月05日 14:46:53   作者:Luna_Lovegood_001  
这篇文章主要介绍了pytorch自定义不可导激活函数的操作,具有很好的参考价值,希望大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch自定义不可导激活函数

今天自定义不可导函数的时候遇到了一个大坑。

首先我需要自定义一个函数:sign_f

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
	# 我需要的module
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        
    def forward(self, inputs):
    	# 使用自定义函数
        outs = sign_f(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

结果报错

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我试了半天,发现自定义函数后面要加 apply ,详细见下面

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        self.r = sign_f.apply ### <-----注意此处
        
    def forward(self, inputs):
        outs = self.r(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

问题解决了!

PyTorch自定义带学习参数的激活函数(如sigmoid)

有的时候我们需要给损失函数设一个超参数但是又不想设固定阈值想和网络一起自动学习,例如给Sigmoid一个参数alpha进行调节

在这里插入图片描述

在这里插入图片描述

函数如下:

import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))

验证和Sigmoid的一致性

class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
   
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
        [0.4379, 0.1828, 0.4629],
        [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

输出结果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)

验证权重是不是会更新

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

输出结果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

会更新~

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 用Python实现童年贪吃蛇小游戏功能的实例代码

    用Python实现童年贪吃蛇小游戏功能的实例代码

    这篇文章主要介绍了用Python实现童年贪吃蛇小游戏功能的实例代码,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-12-12
  • python排序函数sort()与sorted()的区别

    python排序函数sort()与sorted()的区别

    这篇文章主要介绍了python排序函数sort()与sorted()的区别,需要的朋友可以参考下
    2018-09-09
  • 一篇文章从零开始创建conda环境、常用命令的使用及pycharm配置项目环境

    一篇文章从零开始创建conda环境、常用命令的使用及pycharm配置项目环境

    在Conda中创建新环境是一个非常有用的做法,尤其是当你需要为不同的项目安装不同版本的软件包时,这篇文章主要给大家介绍了关于从零开始创建conda环境、常用命令的使用及pycharm配置项目环境的相关资料,需要的朋友可以参考下
    2024-07-07
  • 教你如何在Pytorch中使用TensorBoard

    教你如何在Pytorch中使用TensorBoard

    TensorBoard是TensorFlow中强大的可视化工具,今天通过本文给大家介绍如何在Pytorch中使用TensorBoard,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友一起看看吧
    2021-08-08
  • 详解python 破解网站反爬虫的两种简单方法

    详解python 破解网站反爬虫的两种简单方法

    这篇文章主要介绍了详解python 破解网站反爬虫的两种简单方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • python创建与遍历二叉树的方法实例

    python创建与遍历二叉树的方法实例

    这篇文章主要给大家介绍了关于python创建与遍历二叉树的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03
  • 用Python把csv文件批量修改编码为UTF-8格式并转为Excel格式的方法

    用Python把csv文件批量修改编码为UTF-8格式并转为Excel格式的方法

    有时候用excel打开一个csv文件,中文全部显示乱码,然后手动用notepad++打开,修改编码为utf-8并保存后,再用excel打开显示正常,本文将给大家介绍一下用Python把csv文件批量修改编码为UTF-8格式并转为Excel格式的方法,需要的朋友可以参考下
    2023-09-09
  • 解决python-docx打包之后找不到default.docx的问题

    解决python-docx打包之后找不到default.docx的问题

    今天小编就为大家分享一篇解决python-docx打包之后找不到default.docx的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • Dataframe的行名及列名排序问题

    Dataframe的行名及列名排序问题

    这篇文章主要介绍了Dataframe的行名及列名排序问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • python入门字符串拼接\截取\转数字理解学习

    python入门字符串拼接\截取\转数字理解学习

    本篇内容我们主要讲有关Python字符串的用法,包括字符串的拼接、字符串怎么转数字、字符串的格式化、字符串函数等内容,有需要的朋友可以借鉴参考下
    2021-09-09

最新评论