Pytorch+PyG实现GIN过程示例详解

 更新时间:2023年04月21日 09:57:03   作者:实力  
这篇文章主要为大家介绍了Pytorch+PyG实现GIN过程示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

GIN简介

GIN(Graph Isomorphism Network)是一类基于图同构的神经网络。在传统的神经网络中,每个节点的特征只依赖于其自身特征,但在图数据中,节点的特征还与其邻居节点有关系。GIN网络通过定义可重复均值池化运算来学习节点及其邻居的特征表示,并使用多层感知器(MLP)作为逐层转换函数进行特征提取。

实现步骤

数据准备

这里我们仍然选用Cora数据集作为示例数据。由于GIN采用基于点、简单且无参数的邻域聚合方式,因此不需要额外对数据做处理,直接使用即可。

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加载Cora数据集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 将nx.Graph形式的图转换成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 获取节点数量和特征向量维度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要训练的节点分割数据集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

实现模型

接下来,我们需要定义GIN模型。

from torch_geometric.nn import global_mean_pool
class GIN(torch.nn.Module):
    def __init__(self, hidden_dim, num_layers):
        super(GIN, self).__init__()
        self.conv1 = GINConv(mlp=nn.Sequential(nn.Linear(num_features, hidden_dim),
                                                nn.ReLU(),
                                                nn.Linear(hidden_dim, hidden_dim)))
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GINConv(mlp=nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                        nn.ReLU(),
                                                        nn.Linear(hidden_dim, hidden_dim))))
        self.classify = nn.Sequential(nn.Linear(hidden_dim, num_classes))
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        out = global_mean_pool(x, batch)
        return self.classify(out)

在上述代码中,我们实现了多层GIN的“可重复均值池化”结构,并使用MLP作为转换函数进行多层特征提取。

模型训练

定义好模型后,可以开始针对Cora数据集进行模型训练了。训练模型前先设置好优化器和损失函数,并指定训练周期及其过程中需要记录输出信息的参数。

from torch_geometric.nn import GINConv, global_add_pool
# 初始化GIN并指定参数
num_layers = 5
hidden_dim = 1024
model = GIN(hidden_dim=hidden_dim, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-06)
loss_func = nn.CrossEntropyLoss()
# 开始训练
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    pred = model(train_data)
    loss = loss_func(pred[train_mask], train_labels)
    loss.backward()
    optimizer.step()
    # 在各个测试阶段检测一下准确率
    with torch.no_grad():
        model.eval()
        pred = model(test_data)
        test_loss = loss_func(pred[test_mask], test_labels).item()
        pred = pred.argmax(dim=-1, keepdim=True)
        correct = float(pred[test_mask].eq(test_labels.view(-1, 1)[test_mask]).sum().item())
        acc = correct / test_mask.sum().item()
        if epoch % 10 == 0:
            print("Epoch {:03d}, Train Loss {:.4f}, Test Loss {:.4f}, Test Acc {:.4f}".format(
                epoch, loss.item(), test_loss, acc))

以上就是Pytorch+PyG实现GIN过程示例详解的详细内容,更多关于Pytorch PyG实现GIN的资料请关注脚本之家其它相关文章!

相关文章

  • django中url映射规则和服务端响应顺序的实现

    django中url映射规则和服务端响应顺序的实现

    这篇文章主要介绍了django中url映射规则和服务端响应顺序的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • Ubuntu权限不足无法创建文件夹解决方案

    Ubuntu权限不足无法创建文件夹解决方案

    这篇文章主要介绍了Ubuntu权限不足无法创建文件夹解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • TensorFlow的环境配置与安装教程详解(win10+GeForce GTX1060+CUDA 9.0+cuDNN7.3+tensorflow-gpu 1.12.0+python3.5.5)

    TensorFlow的环境配置与安装教程详解(win10+GeForce GTX1060+CUDA 9.0+cuDNN7

    这篇文章主要介绍了TensorFlow的环境配置与安装(win10+GeForce GTX1060+CUDA 9.0+cuDNN7.3+tensorflow-gpu 1.12.0+python3.5.5),本文通过图文并茂的形式给大家介绍的非常详细,需要的朋友可以参考下
    2020-06-06
  • python pandas实现excel转为html格式的方法

    python pandas实现excel转为html格式的方法

    今天小编就为大家分享一篇python pandas实现excel转为html格式的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python中的logging模块实现日志打印

    Python中的logging模块实现日志打印

    这篇文章主要介绍了Python中的logging模块实现日志打印,其实不止print打印日志方便排查问题,Python自带的logging模块,也可以很简单就能实现日志的配置和打印,下面来看看具体的实现过程吧,需要的朋友可以参考一下
    2022-03-03
  • python处理写入数据代码讲解

    python处理写入数据代码讲解

    在本篇文章里小编给大家整理的是一篇关于python处理写入数据代码讲解内容,有兴趣的朋友们可以学习下。
    2020-10-10
  • Python关于excel和shp的使用在matplotlib

    Python关于excel和shp的使用在matplotlib

    今天小编就为大家分享一篇关于Python关于excel和shp的使用在matplotlib,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-01-01
  • Python中的xlrd模块使用整理

    Python中的xlrd模块使用整理

    今天给大家带来的文章是关于Python的相关知识,文章围绕着xlrd模块的使用展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
    2021-06-06
  • 基于pandas数据样本行列选取的方法

    基于pandas数据样本行列选取的方法

    下面小编就为大家分享一篇基于pandas数据样本行列选取的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Selenium元素的常用操作方法分析

    Selenium元素的常用操作方法分析

    这篇文章主要介绍了Selenium元素的常用操作方法,结合实例形式分析Selenium在获取元素之后针对点击、输入、提交、属性获取等常见操作相关实现技巧,需要的朋友可以参考下
    2018-08-08

最新评论