pytorch自定义不可导激活函数的操作

 更新时间:2021年06月05日 14:46:53   作者:Luna_Lovegood_001  
这篇文章主要介绍了pytorch自定义不可导激活函数的操作,具有很好的参考价值,希望大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch自定义不可导激活函数

今天自定义不可导函数的时候遇到了一个大坑。

首先我需要自定义一个函数:sign_f

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
	# 我需要的module
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        
    def forward(self, inputs):
    	# 使用自定义函数
        outs = sign_f(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

结果报错

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我试了半天,发现自定义函数后面要加 apply ,详细见下面

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        self.r = sign_f.apply ### <-----注意此处
        
    def forward(self, inputs):
        outs = self.r(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

问题解决了!

PyTorch自定义带学习参数的激活函数(如sigmoid)

有的时候我们需要给损失函数设一个超参数但是又不想设固定阈值想和网络一起自动学习,例如给Sigmoid一个参数alpha进行调节

在这里插入图片描述

在这里插入图片描述

函数如下:

import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))

验证和Sigmoid的一致性

class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
   
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
        [0.4379, 0.1828, 0.4629],
        [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

输出结果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)

验证权重是不是会更新

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

输出结果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

会更新~

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中的pathlib.Path为什么不继承str详解

    Python中的pathlib.Path为什么不继承str详解

    这篇文章主要给大家介绍了关于Python中pathlib.Path为什么不继承str的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用Python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-06-06
  • Python脚本导出为exe程序的方法

    Python脚本导出为exe程序的方法

    这篇文章主要介绍了如何把Python脚本导出为exe程序的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-03-03
  • python中使用正则表达式的连接符示例代码

    python中使用正则表达式的连接符示例代码

    在正则表达式中,匹配数字或者英文字母的书写非常不方便。因此,正则表达式引入了连接符“-”来定义字符的范围,下面这篇文章主要给大家介绍了关于python中如何使用正则表达式的连接符的相关资料,需要的朋友可以参考下。
    2017-10-10
  • Python爬虫框架Scrapy实战之批量抓取招聘信息

    Python爬虫框架Scrapy实战之批量抓取招聘信息

    网络爬虫又被称为网页蜘蛛,网络机器人,在FOAF社区中间,更经常的称为网页追逐者,是按照一定的规则,自动抓取万维网信息的程序或者脚本。这篇文章主要介绍Python爬虫框架Scrapy实战之批量抓取招聘信息,有需要的朋友可以参考下
    2015-08-08
  • 详解Django中间件执行顺序

    详解Django中间件执行顺序

    这篇文章主要介绍了详解Django中间件执行顺序,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-07-07
  • python中字符串的操作方法大全

    python中字符串的操作方法大全

    这篇文章主要给大家介绍了关于python中字符串操作方法的相关资料,文中通过示例代码详细介绍了关于python中字符串的大小写转换、isXXX判断、填充、子串搜索、替换、分割、join以及修剪:strip、lstrip和rstrip的相关内容,需要的朋友可以参考下
    2018-06-06
  • python高阶爬虫实战分析

    python高阶爬虫实战分析

    这篇文章给大家分享了python高阶爬虫实战的相关实例内容以及技巧分析,有兴趣的朋友参考下。
    2018-07-07
  • python使用cPickle模块序列化实例

    python使用cPickle模块序列化实例

    这篇文章主要介绍了python使用cPickle模块序列化的方法,是一个非常实用的技巧,需要的朋友可以参考下
    2014-09-09
  • python3的数据类型及数据类型转换实例详解

    python3的数据类型及数据类型转换实例详解

    在本文里小编给大家分享的是关于python3的数据类型及数据类型转换以及相关实例内容,有兴趣的朋友们可以学习下。
    2019-08-08
  • Python实现计算最小编辑距离

    Python实现计算最小编辑距离

    这篇文章主要介绍了Python实现计算最小编辑距离的相关代码,有需要的小伙伴可以参考下
    2016-03-03

最新评论