Pytorch+PyG实现EdgeCNN过程示例详解

 更新时间:2023年04月21日 09:46:01   作者:实力  
这篇文章主要为大家介绍了Pytorch+PyG实现EdgeCNN过程示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

1.EdgeCNN简介

EdgeCNN是一种用于图像点云处理的卷积神经网络(Convolutional Neural Network,CNN)模型。与传统的CNN仅能处理图片二维数据不同,EdgeCNN可以对三维点云中每个点周围的局部邻域进行操作,并适用于物体识别、深度估计、自动驾驶等多项任务。

2. 实现步骤

2.1 数据准备

在本实验中,我们使用了一个包含4万个点云的数据集ModelNet10,作为示例。与其它标准图像数据集不同的是,这个数据集中图形的构成量非常大,而且各图之间结构差异很大,因此需要进行大量的预处理工作。

# 导入模型数据集
from torch_geometric.datasets import ModelNet
# 加载ModelNet数据集
dataset = ModelNet(root='./modelnet', name='10')
data = dataset[0]
# 定义超级参数
num_points = 1024
batch_size = 32
train_dataset_size = 8000
# 将数据集分割成训练、验证及测试三个数据集
train_dataset = data[0:train_dataset_size]
val_dataset = data[train_dataset_size: 9000]
test_dataset = data[9000:]
# 定义数据加载批处理器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

通过上述代码,我们先是导入ModelNet数据集并将其分割成训练、验证及测试三个数据集,并创建了数据加载批处理器,以便于在训练过程中对这些数据进行有效的处理。

2.2 实现模型

在定义EdgeCNN模型时,我们需要根据图像点云经常使用的架构定义网络结构。同时,在实现卷积操作时应引入相应的邻域信息,来使得网络能够学习到系统中附近点之间的关系。

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import EdgeConv, global_max_pool
class EdgeCNN(torch.nn.Module):
    def __init__(self, dataset):
        super(EdgeCNN, self).__init__()
        # 定义基础参数
        self.input_dim = dataset.num_features
        self.output_dim = dataset.num_classes
        self.num_points = num_points
        # 定义模型结构
        self.conv1 = EdgeConv(Seq(Lin(self.input_dim, 32), ReLU()))
        self.conv2 = EdgeConv(Seq(Lin(32, 64), ReLU()))
        self.conv3 = EdgeConv(Seq(Lin(64, 128), ReLU()))
        self.conv4 = EdgeConv(Seq(Lin(128, 256), ReLU()))
        self.fc1 = torch.nn.Linear(256, 1024)
        self.fc2 = torch.nn.Linear(1024, self.output_dim)
    def forward(self, pos, batch):
        # 构造图
        edge_index = radius_graph(pos, r=0.6, batch=batch, loop=False)
        # 第一层CNN模型的卷积 + 池化处理
        x = F.relu(self.conv1(x=pos, edge_index=edge_index))
        x = global_max_pool(x, batch)
        # 第二层CNN模型的卷积 + 池化处理
        edge_index = radius_graph(x, r=0.9, batch=batch, loop=False)
        x = F.relu(self.conv2(x=x, edge_index=edge_index))
        x = global_max_pool(x, batch)
        # 第三层CNN模型的卷积 + 池化处理
        edge_index = radius_graph(x, r=1.2, batch=batch, loop=False)
        x = F.relu(self.conv3(x=x, edge_index=edge_index))
        x = global_max_pool(x, batch)
        # 第四层CNN模型的卷积 + 池化处理
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
        x = F.relu(self.conv4(x=x, edge_index=edge_index))
        # 定义全连接网络
        x = global_max_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

在上述代码中,实现了基于EdgeCNN的模型的各个卷积层和全连接层,并使用radius_graph等函数将局部区域问题归约到定义的卷积核检测范围之内,以便更好地对点进行分析和特征提取。最后结合全连接层输出一个维度为类别数的向量,并通过softmax函数来计算损失。

2.3 模型训练

在定义好EdgeCNN网络结构之后,我们还需要指定合适的优化器、损失函数,并控制训练轮数、批大小与学习率等超参数。同时也需要记录大量日志信息,方便后期跟踪及管理。

# 定义训练计划,包括损失函数、优化器及迭代次数等
train_epochs = 50
learning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(edge_cnn.parameters(), lr=learning_rate)
losses_per_epoch = []
accuracies_per_epoch = []
for epoch in range(train_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    count = 0.0
    for samples in train_loader:
        optimizer.zero_grad()
        pos, batch, label = samples.pos, samples.batch, samples.y.to(torch.long)
        out = edge_cnn(pos, batch)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() / len(train_dataset)
        running_corrects += torch.sum(torch.argmax(out, dim=1) == label).item() / len(train_dataset)
        count += 1
    losses_per_epoch.append(running_loss)
    accuracies_per_epoch.append(running_corrects)
    if (epoch + 1) % 5 == 0:
        print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format(
            epoch + 1, train_epochs, running_loss, running_corrects))

在训练过程中,我们遍历每个batch,通过反向传播算法进行优化,并更新loss及accuracy输出。同时,为了方便可视化与记录,需要将训练过程中的loss和accuracy输出到相应的容器中,以便后期进行分析和处理。

以上就是Pytorch+PyG实现EdgeCNN过程示例详解的详细内容,更多关于Pytorch PyG实现EdgeCNN的资料请关注脚本之家其它相关文章!

相关文章

  • python循环嵌套的多种使用方法解析

    python循环嵌套的多种使用方法解析

    这篇文章主要介绍了python循环嵌套的多种使用方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11
  • Python 高效编程技巧分享

    Python 高效编程技巧分享

    工作中经常要处理各种各样的数据,遇到项目赶进度的时候自己写函数容易浪费时间。Python 中有很多内置函数帮你提高工作效率。
    2020-09-09
  • 有关pycharm登录github时有的时候会报错connection reset的问题

    有关pycharm登录github时有的时候会报错connection reset的问题

    这篇文章主要介绍了有关pycharm登录github时有的时候会报错connection reset的问题,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-09-09
  • Python的Flask框架开发验证码登录的实现

    Python的Flask框架开发验证码登录的实现

    在本文我们介绍了如何使用Python的Flask框架开发一个简单的验证码登录功能,将涵盖生成验证码、处理用户输入、验证验证码以及实现安全的用户认证等方面,感兴趣的可以了解一下
    2023-11-11
  • Python计算机视觉SIFT尺度不变的图像特征变换

    Python计算机视觉SIFT尺度不变的图像特征变换

    这篇文章主要为大家介绍了Python计算机视觉SIFT尺度不变的图像特征变换,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • pandas多级分组实现排序的方法

    pandas多级分组实现排序的方法

    下面小编就为大家分享一篇pandas多级分组实现排序的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python绘图实现显示中文

    Python绘图实现显示中文

    今天小编就为大家分享一篇Python绘图实现显示中文,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python实现轻松防止屏幕截图的技巧分享

    Python实现轻松防止屏幕截图的技巧分享

    屏幕截图是一种常见的用于记录信息或者监控用户活动的方法,为了保护隐私和数据安全,可以通过使用Python编写一些防护措施来防止他人截取我们的屏幕,下面我们就来学习一下有哪些具体操作吧
    2023-12-12
  • Matlab中如何实现将长字符串换行写

    Matlab中如何实现将长字符串换行写

    这篇文章主要介绍了Matlab中如何实现将长字符串换行写问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-01-01
  • Python常用断言函数实例汇总

    Python常用断言函数实例汇总

    这篇文章主要介绍了Python常用断言函数实例汇总,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11

最新评论