使用Pytorch实现two-head(多输出)模型的操作

 更新时间:2021年05月28日 11:53:20   作者:XJTU-Qidong  
这篇文章主要介绍了使用Pytorch实现two-head(多输出)模型的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

A two-head model

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

    f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python面向对象之类的定义与继承用法示例

    Python面向对象之类的定义与继承用法示例

    这篇文章主要介绍了Python面向对象之类的定义与继承用法,结合实例形式分析了Python类的定义、实例化、继承等基本操作技巧,需要的朋友可以参考下
    2019-01-01
  • pytorch无法使用GPU问题的解决方法

    pytorch无法使用GPU问题的解决方法

    这篇文章主要介绍了如何解决pytorch 无法使用GPU 的问题,文中通过代码和图文给大家讲解的非常详细,对大家的学习或工作有一定的帮助,需要的朋友可以参考下
    2024-02-02
  • python实现图片横向和纵向拼接

    python实现图片横向和纵向拼接

    这篇文章主要为大家详细介绍了python实现图片横向和纵向拼接,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-03-03
  • 童年回忆录之python版4399吃豆豆小游戏

    童年回忆录之python版4399吃豆豆小游戏

    相信80,90后都玩过4399网站的小游戏,虽然游戏很low但是童年的回忆,今天小编带你一起用python自己写一个4399吃豆豆的小游戏,文中给大家介绍的非常详细,对大家的学习或工作具有一定的价值
    2021-09-09
  • python矩阵列的实现示例

    python矩阵列的实现示例

    在Python和NumPy库的帮助下,矩阵列可以很容易地进行各种操作,本文主要介绍了python矩阵列的实现示例,具有一定的参考价值,感兴趣的可以了解一下
    2024-02-02
  • Python有参函数使用代码实例

    Python有参函数使用代码实例

    这篇文章主要介绍了Python有参函数使用代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • 使用keras2.0 将Merge层改为函数式

    使用keras2.0 将Merge层改为函数式

    这篇文章主要介绍了使用keras2.0 将Merge层改为函数式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python深入浅出分析元类

    Python深入浅出分析元类

    在Python里一切都是对象(object),基本数据类型,如数字,字符串,函数都是对象。对象可以由类(class)进行创建。那么既然一切都是对象,那么类是对象吗?是的,类也是对象,那么又是谁创造了类呢?答案也很简单,也是类,一个能创作类的类,称之为(type)元类
    2022-07-07
  • 使用Python编制一个批处理文件管理器

    使用Python编制一个批处理文件管理器

    在软件开发和系统管理中,批处理文件(.bat)是一种常见且有用的工具,它们可以自动化重复性任务,简化复杂的操作流程,今天,我们将探讨如何使用Python和wxPython创建一个图形用户界面(GUI)应用程序来管理和执行批处理文件,需要的朋友可以参考下
    2025-01-01
  • python正则表达式常见的知识点汇总

    python正则表达式常见的知识点汇总

    正则表达式提供了一些可用的匹配模式,比如忽略大小写、多行匹配等,下面这篇文章主要给大家介绍了关于python正则表达式常见的知识点,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05

最新评论