Pytorch的torch.nn.embedding()如何实现词嵌入层

 更新时间:2024年02月27日 15:49:11   作者:#苦行僧  
这篇文章主要介绍了Pytorch的torch.nn.embedding()如何实现词嵌入层问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.nn.embedding()实现词嵌入层

nn.embedding()其实是NLP中常用的词嵌入层,在实现词嵌入的过程中embedding层的权重用于随机初始化词的向量,该embedding层的权重参数在后续训练时会不断更新调整,并被优化。 

nn.embedding:这是一个矩阵类,该开始时里面初始化了一个随机矩阵,矩阵的长是字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。

类实例化之后可以根据字典中元素的下标来查找元素对应的向量。 

因为输入的句子长度不一,有的长有的短。

长了截断,不够长补齐(我文中用’'填充,然后在nn.embedding层将其补0,也就是用它来表示无意义的词,这样在后面的max-pooling层也就自然而然会把其过滤掉,这样就不用担心他会影响识别。)

这里说一下它的用法:

nn.embedding()主要3个参数

  • 第一个参数num_embeddings是指词表大小 
  • 第二个参数embedding_dim是指你需要用多少维来表示一个符号
  • 第三个参数pading_idx即需要用0填充的符号在词表中的位置,如下,输出中后面两个’'都有被填充为了0.
import torch
import torch.nn as nn


#词表
word_to_id = {'hello':0, '<PAD>':1,'world':2}
embeds = nn.Embedding(len(word_to_id), 4,padding_idx=word_to_id['<PAD>'])

text = 'hello world <PAD> <PAD>'
hello_idx = torch.LongTensor([word_to_id[i] for i in text.split()])
#词嵌入得到词向量
hello_embed = embeds(hello_idx)
print(hello_embed)

从以下输出可以看到,每行代表句子中一个单词的词嵌入向量,句子中的每个单词都有4维度,最后两个0向量是时用来填充补齐的没意义。

所以embedding层其实相当于将前面用索引编码的句子表示乘上embedding层的可训练权重得到的就是词嵌入的结果

输出:

tensor([[-1.1436, 1.4588, -1.2755, 0.0077],
[-0.9600, -1.9986, -1.1087, -0.1520],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=)

你也可以使用nn.Embedding.from_pretrained()加载预训练好的模型,如word2vec,glove等,在训练的过程中也可以边训练,边更新词向量,加快模型的收敛。

本文用的只是简单的nn.embedding()

然后具体使用 nn.embedding() 时,写在初始化搭建网络里

如下:

class Network(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__(nvocab,embed)
        self.filter_sizes = (2, 3, 4)
        self.embed = embed
        self.num_filters = 256
        self.dropout = 0.5
        self.num_classes = num_classes
        self.n_vocab = nvocab
        #通过padding_idx将<PAD>字符填充为0,因为他没意义哦,后面max-pooling自然而然会把他过滤掉哦
        self.embedding = nn.Embedding(self.n_vocab, self.embed, padding_idx=word2idx['<PAD>'])
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, self.num_filters, (k, self.embed)) for k in self.filter_sizes])
        
        self.dropout = nn.Dropout(self.dropout)
        self.fc = nn.Linear(self.num_filters * len(self.filter_sizes), self.num_classes)
        
    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x
        
    def forward(self, x):
        out = self.embedding(x)
        out = out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python中的迭代和可迭代对象代码示例

    python中的迭代和可迭代对象代码示例

    这篇文章主要介绍了python中的迭代和可迭代对象代码示例,具有一定借鉴价值,需要的朋友可以参考下
    2017-12-12
  • Python 开发工具PyCharm安装教程图文详解(新手必看)

    Python 开发工具PyCharm安装教程图文详解(新手必看)

    PyCharm是一种Python IDE,带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具,比如调试、语法高亮、Project管理、代码跳转、智能提示、自动完成、单元测试、版本控制。今天通过本文给大家分享PyCharm安装教程,一起看看吧
    2020-02-02
  • 在python中路径含有空格的问题及解决

    在python中路径含有空格的问题及解决

    这篇文章主要介绍了在python中路径含有空格的问题及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02
  • Python动态导入模块的方法实例分析

    Python动态导入模块的方法实例分析

    这篇文章主要介绍了Python动态导入模块的方法,结合实例形式较为详细的分析了Python动态导入系统模块、自定义模块以及模块列表的相关操作技巧,需要的朋友可以参考下
    2018-06-06
  • pandas数据聚合与分组运算的实现

    pandas数据聚合与分组运算的实现

    本文主要介绍了pandas数据聚合与分组运算的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-01-01
  • python执行系统命令后获取返回值的几种方式集合

    python执行系统命令后获取返回值的几种方式集合

    今天小编就为大家分享一篇python执行系统命令后获取返回值的几种方式集合,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • Python从视频中提取音频的操作

    Python从视频中提取音频的操作

    这篇文章主要介绍了Python从视频中提取音频的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04
  • 使用keras和tensorflow保存为可部署的pb格式

    使用keras和tensorflow保存为可部署的pb格式

    这篇文章主要介绍了使用keras和tensorflow保存为可部署的pb格式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • PIL包中Image模块的convert()函数的具体使用

    PIL包中Image模块的convert()函数的具体使用

    这篇文章主要介绍了PIL包中Image模块的convert()函数的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • Django 聚合函数的具体使用

    Django 聚合函数的具体使用

    orm模型中的聚合函数跟MySQL中的聚合函数作用是一致的,也有像Sum、Avg、Count、Max、Min,接下来我们逐个介绍,下面就一起来了解一下
    2021-05-05

最新评论