详解Pytorch+PyG实现GAT过程示例

 更新时间:2023年04月21日 10:01:31   作者:实力  
这篇文章主要为大家介绍了Pytorch+PyG实现GAT过程示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

导入库和数据

GAT(图注意力网络)是常见的图神经网络结构之一,它使用注意力机制来对节点进行特征加权,并考虑其邻居节点的交互。

首先,我们需要导入PyTorch和PyG库,然后准备好我们的数据。例如,我们可以使用以下方式生成一个简单的随机数据集:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
train_loader = DataLoader(dataset[0], batch_size=128, shuffle=True)
test_loader = DataLoader(dataset[0], batch_size=128, shuffle=False)

其中, Planetoid 是PyG提供的图形数据集之一。这里我们选择了 Cora 数据集并存储到 /tmp/Cora 文件夹中。然后我们将该数据集分成训练集和测试集,设置相应的加载器。

定义模型结构

接下来,我们需要定义GAT模型的结构。通过PyTorch和PyG,我们可以自己定义完整的GAT模型或者利用现有的库函数快速构建模型。在这里,我们将使用 torch_geometric.nn.GATConv 函数逐层堆叠多个图注意力层来实现GAT模型。以下是GAT模型定义的示例代码:

import torch.nn.functional as F
from torch_geometric.nn import GATConv
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.num_layers = 2
        self.conv1 = GATConv(in_channels=in_channels, out_channels=16, heads=8, dropout=0.6)
        self.conv2 = GATConv(in_channels=16*8, out_channels=out_channels, heads=1, concat=False, dropout=0.6)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

上述代码中,我们定义了一个 Net 类用于构建GAT网络,接收输入通道数和输出通道数作为参数。例如,我们可以按照以下方式创建一个将 CORD 参量作为输入特征向量大小、64 个隐藏节点(每个注意力头)。并将数字类别作为输出大小的GAT模型:

model = Net(in_channels=dataset.num_features, out_channels=dataset.num_classes)

其中 num_featuresnum_classes 是PyG数据集中包含的属性。

定义训练函数

然后,我们需要定义训练函数来训练我们的GAT神经网络。在这里,我们将使用交叉熵损失和Adam优化器进行训练,并在每一个epoch结束时计算准确率并打印出来。以下是训练函数的示例代码:

import torch.optim as optim
from tqdm import tqdm
def train(model, loader, optimizer, loss_fn):
    model.train()
    correct = 0
    total_loss = 0
    for data in tqdm(loader, desc='Training'):
        optimizer.zero_grad()
        out = model(data)
        pred = out.argmax(dim=1)
        loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
        correct += pred[data.train_mask].eq(data.y[data.train_mask]).sum().item()
    return total_loss / len(loader.dataset), correct / len(data.train_mask)

在上述代码中,我们遍历加载器中的每个数据批次,并对模型进行培训。对于每个图数据批次,我们计算网络输出、预测和损失,然后通过反向传播来更新权重。最后,我们将总损失和正确率记录下来并返回。

定义测试函数

接下来,我们还需要定义测试函数来测试我们的GAT神经网络性能表现。我们将利用与训练函数相同的输出参数进行测试,并打印出最终的测试准确率。以下是测试函数的示例代码:

def test(model, loader, loss_fn):
    model.eval()
    correct = 0
    total_loss = 0
    with torch.no_grad():
        for data in tqdm(loader, desc='Testing'):
            out = model(data)
            pred = out.argmax(dim=1)
            loss = loss_fn(out[data.test_mask], data.y[data.test_mask])
            total_loss += loss.item() * data.num_graphs
            correct += pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
    return total_loss / len(loader.dataset), correct / len(data.test_mask)

在上述代码中,我们对测试数据集中的所有数据进行了循环,并计算网络的输出和预测。我们记录下总损失和正确分类的数据量,并返回损失和准确率之间的比率。

训练模型并评估训练结果

最后,我们可以使用前面定义过的函数来定义主函数,从而完成GAT神经网络的训练和测试。以下是主函数的示例代码:

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(in_channels=dataset.num_features, out_channels=dataset.num_classes).to(device)
    train_loader = DataLoader(dataset[0], batch_size=128, shuffle=True)
    test_loader = DataLoader(dataset[0], batch_size=128, shuffle=False)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(1, 201):
        train_loss, train_acc = train(model, train_loader, optimizer, loss_fn)
        test_loss, test_acc = test(model, test_loader, loss_fn)
        print(f'Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
              f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

通过上述代码,我们就可以完成GAT神经网络的训练和测试。我们使用 DataLoader 函数进行数据加载,设置学习率、损失函数、训练轮数等超参数。最后,我们可以在屏幕上看到每个时代的准确率和损失值,并通过它们评估模型的训练表现。

以上就是详解Pytorch+PyG实现GAT过程示例的详细内容,更多关于Pytorch PyG实现GAT的资料请关注脚本之家其它相关文章!

相关文章

  • python实现根据主机名字获得所有ip地址的方法

    python实现根据主机名字获得所有ip地址的方法

    这篇文章主要介绍了python实现根据主机名字获得所有ip地址的方法,涉及Python解析IP地址的相关技巧,需要的朋友可以参考下
    2015-06-06
  • python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法

    python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法

    今天小编就为大家分享一篇python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • python实现防截图的6种方法详解

    python实现防截图的6种方法详解

    防截图是指一组技术或方法,用于防止他人在未经允许的情况下在屏幕上截取或记录图像,这是一个重要的安全措施,它可以防止窃取敏感信息或监视个人信息,本文为大家整理了6种python可以防截图的方法,需要的可以参考下
    2023-10-10
  • python跨文件夹调用别的文件夹下py文件或参数方式详解

    python跨文件夹调用别的文件夹下py文件或参数方式详解

    这篇文章主要给大家介绍了关于python跨文件夹调用别的文件夹下py文件或参数方式的相关资料,在python中有时候我们需要调用另一.py文件中的方法或者类,需要的朋友可以参考下
    2023-08-08
  • python 异常捕获详解流程

    python 异常捕获详解流程

    异常即非正常状态,在Python中使用异常对象来表示异常。若程序在编译或运行过程中发生错误,程序的执行过程就会发生改变,抛出异常对象,程序流进入异常处理。如果异常对象没有被处理或捕捉,程序就会执行回溯(Traceback)来终止程序
    2022-03-03
  • 详解duck typing鸭子类型程序设计与Python的实现示例

    详解duck typing鸭子类型程序设计与Python的实现示例

    这篇文章主要介绍了详解duck typing鸭子类型程序设计与Python的实现示例,鸭子类型特指解释型语言中的一种编程风格,需要的朋友可以参考下
    2016-06-06
  • python正则表达式对字符串的查找匹配

    python正则表达式对字符串的查找匹配

    正则表达式是一种文本模式,包括普通字符(例如,a 到 z 之间的字母)和特殊字符(称为“元字符”),下面这篇文章主要给大家介绍了关于python正则表达式对字符串的查找匹配的相关资料,需要的朋友可以参考下
    2022-09-09
  • python中循环语句while用法实例

    python中循环语句while用法实例

    这篇文章主要介绍了python中循环语句while用法,实例分析了while语句的使用方法,需要的朋友可以参考下
    2015-05-05
  • Python实现的最近最少使用算法

    Python实现的最近最少使用算法

    这篇文章主要介绍了Python实现的最近最少使用算法,涉及节点、时间、流程控制等相关技巧,需要的朋友可以参考下
    2015-07-07
  • Python实现爬取逐浪小说的方法

    Python实现爬取逐浪小说的方法

    这篇文章主要介绍了Python实现爬取逐浪小说的方法,基于Python的正则匹配功能实现爬取小说页面标题、链接及正文等功能,需要的朋友可以参考下
    2015-07-07

最新评论