详解使用Pytorch Geometric实现GraphSAGE模型

 更新时间:2023年04月24日 10:31:39   作者:实力  
这篇文章主要为大家介绍了详解使用Pytorch Geometric实现GraphSAGE模型示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

正文

GraphSAGE是一种用于图神经网络中的节点嵌入学习方法。它通过聚合节点邻居的信息来生成节点的低维表示,使节点表示能够更好地应用于各种下游任务,如节点分类、链路预测等。

图构建

在使用GraphSAGE对节点进行嵌入学习之前,我们需要先将原始数据转换为图结构,并将其存储为Pytorch Tensor格式。例如,我们可以使用networkx库来构建一个简单的图:

import networkx as nx

G = nx.karate_club_graph()

然后,我们可以使用Pytorch Geometric库将NetworkX图转换为Pytorch Tensor格式。首先,我们需要安装Pytorch Geometric并导入所需的类:

!pip install torch-geometric

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils.convert import from_networkx

接着,我们可以使用from_networkx函数将NetworkX图转换为Pytorch Tensor格式:

data = from_networkx(G)

此时,data对象包含了关于节点、边及其属性的信息,例如:

data.edge_index: 2x(#edges)的长整型张量,表示边的起点和终点

  • data.x: n×dn \times dn×d 的浮点型张量,表示每个节点的特征向量(其中nnn是节点数量,ddd是特征维度)

注意,此时的data对象并未包含邻居信息。接下来,我们将介绍如何使用Sampler方法采样节点邻居。

Sampler方法

GraphSAGE使用Sampler方法来聚合邻居信息。在Pytorch Geometric中,可以使用Various Sampling方法来实现Sampler。例如,使用ClusterData方法将图分成多个子图,然后对每个子图进行采样操作。

以下是ClusterData的使用示例:

from torch_geometric.utils import degree, to_undirected
from torch_geometric.transforms import ClusterData

# Convert the graph to an undirected graph, so we can aggregate neighbors in both directions.
G = to_undirected(G)

# Compute the degree of each node.
deg = degree(data.edge_index[0], num_nodes=data.num_nodes)

# Use METIS algorithm to partition the graph into multiple subgraphs.
cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(),
                           degree=deg)

这里我们将原始图分成两个子图,并对每个子图进行规范化特征转换。注意,在使用ClusterData方法之前,需要将原始图转换为无向图。

另一个常用的Sampler方法是在随机游动时对邻居进行采样,这种方法被称为随机游走采样(Random Walk Sampling)。以下是随机游走采样的示例代码:

from torch_geometric.utils import random_walk

# Perform random walk sampling to obtain node neighbor samples.
walk_length = 20  # The length of random walk trail.
num_steps = 4     # The number of nodes to sample from each step.
data.batch = None
data.edge_index = to_undirected(data.edge_index)  # Use undirected edge for random walk.

rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)

这里我们将使用一个长度为20、每个步骤采样4个邻居的随机游走方法。注意,在使用随机游走方法进行采样之前,需要使用无向边。

GraphSAGE模型定义

GraphSAGE模型包含3个部分:1)图卷积层;2)聚合器(Aggregator);3)输出层。我们将在本节中介绍如何使用Pytorch实现这些组件。

首先,让我们定义一个图卷积层。图卷积层的输入是节点特征矩阵、邻接矩阵和聚合器,输出是新的节点特征矩阵。以下是图卷积层的代码实现:

import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool

class GraphSageConv(MessagePassing):
    def __init__(self, in_channels, out_channels, aggr='mean'):
        super(GraphSageConv, self).__init__(aggr=aggr)
        self.lin = nn.Linear(in_channels, out_channels)
        
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_j):
        return x_j
    
    def update(self, aggr_out, x):
        return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))

这里我们继承了MessagePassing类,并在__init__函数中定义了一个全连接层,用于将输入特征矩阵x从 dind_{in}din​ 维映射到 doutd_{out}dout​ 维。在forward函数中,我们使用propagate方法来实现消息传递操作;在message函数中,我们仅向下游节点发送原始特征数据;在update函数中,我们首先对聚合结果进行ReLU非线性变换,然后再通过全连接层进行节点特征的更新。

接下来,让我们定义一个聚合器。聚合器的输入是采样得到的邻居特征矩阵,输出是新的节点嵌入向量。以下是聚合器的代码实现:

class MeanAggregator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MeanAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lin = nn.Linear(input_dim, output_dim)
        
    def forward(self, neigh_mean):
        out = F.relu(self.lin(neigh_mean))
        return out

这里我们定义了一个简单的均值聚合器,其将邻居特征矩阵中每列的均值作为节点嵌入向量,并使用全连接层进行维度变换。

最后,让我们定义整个GraphSage模型。GraphSage模型包含2个图卷积层和1个输出层。以下是模型的代码实现:

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super(GraphSAGE, self).__init__()
        self.conv1 = GraphSageConv(in_channels, hidden_channels)
        self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels)
        self.conv2 = GraphSageConv(hidden_channels, out_channels)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = global_mean_pool(x, edge_index)  # Compute global mean over nodes.
        x = self.aggreg1(x)
        x = self.conv2(x, edge_index)
        return x

这里我们定义了一个包含2层GraphSAGE Conv层的神经网络。在最后一层GraphSAGE Conv层之后,我们使用global_mean_pool函数来计算节点嵌入的全局平均值。注意,在本示例中,我们仅保留了一个输出节点,因此输出矩阵的大小为1。如果需要输出多个节点,则需要设置global_mean_pool函数中的参数。

模型训练与测试

在定义好模型后,我们可以使用Pytorch进行模型训练和测试。首先,让我们定义一个损失函数和优化器:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

这里我们使用交叉熵作为损失函数,并使用Adam优化器来更新模型参数。

接着,我们可以开始训练模型。以下是训练过程的代码实现:

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))

这里我们遍历所有数据样本,计算预测结果和真实标签之间的交叉熵损失,并使用反向传播来更新权重。我们在每个epoch结束后打印出当前损失值。

最后,我们可以对模型进行测试。以下是测试过程的代码实现:

model.eval()

with torch.no_grad():
    pred = model(data.x, data.edge_index)
    pred = pred.argmax(dim=1)

acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
print('Test accuracy: {:.4f}'.format(acc))

这里我们使用测试集来计算模型的准确率。注意,在执行model.eval()后,我们需要使用torch.no_grad()包装代码块,以禁止梯度计算。

总结

介绍了如何使用Pytorch Geometric实现GraphSAGE模型,包括构建图、定义Sampler方法、定义模型、训练和测试模型等步骤。GraphSAGE模型是一种常用的节点嵌入学习方法,可以应用于各种下游任务中。

以上就是详解使用Pytorch Geometric实现GraphSAGE模型的详细内容,更多关于Pytorch Geometric GraphSAGE的资料请关注脚本之家其它相关文章!

相关文章

  • openstack中的rpc远程调用的方法

    openstack中的rpc远程调用的方法

    今天通过本文给大家分享openstack中的rpc远程调用的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2021-07-07
  • Python 递归函数详解及实例

    Python 递归函数详解及实例

    这篇文章主要介绍了Python 递归函数详解及实例的相关资料,需要的朋友可以参考下
    2016-12-12
  • python如何代码集体右移

    python如何代码集体右移

    在本篇文章里小编给各位分享的是一篇关于python如何代码集体右移的相关知识点文章,需要的朋友们可以学习下。
    2020-07-07
  • 详解Pytorch显存动态分配规律探索

    详解Pytorch显存动态分配规律探索

    这篇文章主要介绍了Pytorch显存动态分配规律探索,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11
  • 教你如何使用Python selenium

    教你如何使用Python selenium

    今天教大家如何使用Python selenium,本文会以艺龙旅游网为对象,进行selenium的学习,目的:爬取艺龙网中南阳市唐河县的酒店信息,包括:名字,电话,标间价格,地址,介绍,图片,需要的朋友可以参考下
    2021-06-06
  • 使用Python实现为PDF文件添加图章

    使用Python实现为PDF文件添加图章

    在日常工作中,我们经常需要给PDF文档添加一些标识,比如公司的图章或水印图章,所以本文就来为大家详细介绍一下如何使用Python实现为PDF文件添加图章,需要的可以参考下
    2023-11-11
  • Python 中 and, or, &, |, ^ 的使用小结

    Python 中 and, or, &, |, ^ 

    这篇文章主要介绍了Python 中 and, or, &, |, ^ 的使用小结,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2024-01-01
  • python多进程控制学习小结

    python多进程控制学习小结

    这篇文章主要介绍了python多进程控制学习小结,想要充分利用多核CPU资源,Python中大部分情况下都需要使用多进程,Python中提供了multiprocessing这个包实现多进程。感兴趣的小伙伴们可以参考一下
    2018-10-10
  • Python函数必须先定义,后调用说明(函数调用函数例外)

    Python函数必须先定义,后调用说明(函数调用函数例外)

    这篇文章主要介绍了Python函数必须先定义,后调用说明(函数调用函数例外),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python光学仿真wxpython之DC绘图

    Python光学仿真wxpython之DC绘图

    这篇文章主要为大家介绍了Python光学仿真wxpython之DC绘图的基本概念及用法详解,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-10-10

最新评论