Pytorch+PyG实现GraphSAGE过程示例详解

 更新时间:2023年04月21日 09:54:31   作者:实力  
这篇文章主要为大家介绍了Pytorch+PyG实现GraphSAGE过程示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

GraphSAGE简介

GraphSAGE(Graph Sampling and Aggregation)是一种常见的图神经网络模型,主要用于结点级别的表征学习。该模型基于采样和聚合策略,将一个结点及其邻居节点信息融合在一起,得到其表征表示,并通过多轮迭代更新来提高表征的精度。

实现步骤

数据准备

在本次实现中,我们仍然使用Cora数据集作为示例进行测试,由于GraphSage主要聚焦于单一节点特征的更新,因此这里不需要对数据集做特别处理,只需要将数据转化成PyG格式即可。

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx, to_networkx
# 加载cora数据集
dataset = Planetoid(root='./cora', name='Cora')
data = dataset[0]
# 将nx.Graph形式的图转换成PyG需要的格式
graph = to_networkx(data)
data = from_networkx(graph)
# 获取节点数量和特征向量维度
num_nodes = data.num_nodes
num_features = dataset.num_features
num_classes = dataset.num_classes
# 建立需要训练的节点分割数据集
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[:num_nodes - 1000] = True
data.test_mask[-1000:] = True
data.val_mask[num_nodes - 2000: num_nodes - 1000] = True

实现模型

接下来,我们需要定义GraphSAGE模型。与传统的GCN中只需要一层卷积操作不同,GraphSAGE包含两层卷积和采样(也称“聚合”)操作。

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = hidden_channels if i != 0 else num_features
            out_channels = num_classes if i == num_layers - 1 else hidden_channels
            self.convs.append(SAGEConv(in_channels, out_channels))
    def forward(self, x, edge_index):
        for _, conv in enumerate(self.convs[:-1]):
            x = F.relu(conv(x, edge_index))
        # 最后一层不用激活函数
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=-1)

在上述代码中,我们实现了多层GraphSAGE卷积和相应的聚合函数,并使用ReLU和softmax函数来进行特征提取和分类分数的输出。

模型训练

定义好模型之后,就可以开始针对Cora数据集进行模型训练。首先还是需要先指定优化器和损失函数,并设定一些参数用于记录训练过程中的信息,如Epochs、Batch size、学习率等。

# 初始化GraphSage并指定参数
num_layers = 2
hidden_channels = 256
model = GraphSAGE(hidden_channels, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
# 训练过程
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
    loss.backward()
    optimizer.step()
    # 在各个测试阶段检测一下准确率
    if epoch % 10 == 0:
        with torch.no_grad():
            _, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
            correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
            acc = correct / data.test_mask.sum().item()
            print("Epoch {:03d}, Train Loss {:.4f}, Test Acc {:.4f}".format(
                epoch, loss.item(), acc))

在上述代码中,我们使用有标记的训练数据拟合GraphSAGE模型,在各个验证阶段测试准确率,并通过梯度下降法优化损失函数。

以上就是Pytorch+PyG实现GraphSAGE过程示例详解的详细内容,更多关于Pytorch PyG实现GraphSAGE的资料请关注脚本之家其它相关文章!

相关文章

  • Python中的 pass 占位语句

    Python中的 pass 占位语句

    这篇文章主要介绍了Python中的 pass 占位语句,Python pass是空语句,是为了保持程序结构的完整性,下文具体的相关内容介绍需要的小伙伴可以参考一下
    2022-04-04
  • python里使用正则表达式的组嵌套实例详解

    python里使用正则表达式的组嵌套实例详解

    这篇文章主要介绍了python里使用正则表达式的组嵌套实例详解的相关资料,希望通过本文能帮助到大家,需要的朋友可以参考下
    2017-10-10
  • 使用python将csv数据导入mysql数据库

    使用python将csv数据导入mysql数据库

    这篇文章主要为大家详细介绍了如何使用python将csv数据导入mysql数据库,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2024-05-05
  • python spotlight库简化交互式方法探索数据分析

    python spotlight库简化交互式方法探索数据分析

    这篇文章主要为大家介绍了python spotlight库简化的交互式方法探索数据,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • 基于Python制作一个端午节相关的小游戏

    基于Python制作一个端午节相关的小游戏

    端午节快乐,今天我将为大家带来一篇有关端午节的编程文章,希望能够为大家献上一份小小的惊喜,我们将会使用Python来实现一个与端午粽子相关的小应用程序,在本文中,我将会介绍如何用Python代码制做一个“粽子拆解器”,感兴趣的小伙伴欢迎阅读
    2023-06-06
  • python实现简易通讯录修改版

    python实现简易通讯录修改版

    这篇文章主要为大家详细介绍了python实现简易通讯录的修改版,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • Python实现将16进制字符串转化为ascii字符的方法分析

    Python实现将16进制字符串转化为ascii字符的方法分析

    这篇文章主要介绍了Python实现将16进制字符串转化为ascii字符的方法,结合实例形式分析了Python 16进制字符串转换为ascii字符的实现方法与相关注意事项,需要的朋友可以参考下
    2017-07-07
  • Python实现完全数的示例详解

    Python实现完全数的示例详解

    完全数,又称完美数,定义为:这个数的所有因数(不包括这个数本身)加起来刚好等于这个数。本文就来用Python实现计算完全数,需要的可以参考一下
    2023-01-01
  • Python基于xlutils修改表格内容过程解析

    Python基于xlutils修改表格内容过程解析

    这篇文章主要介绍了Python基于xlutils修改表格内容过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • python操作excel的方法(xlsxwriter包的使用)

    python操作excel的方法(xlsxwriter包的使用)

    这篇文章主要为大家详细介绍了python操作excel的方法,xlsxwriter包的使用方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-06-06

最新评论