PyTorch使用Torchdyn实现连续时间神经网络的代码示例

 更新时间:2025年02月05日 09:35:36   作者:deephub  
神经常微分方程(Neural ODEs)是深度学习领域的创新性模型架构,它将神经网络的离散变换扩展为连续时间动力系统,本文将基于Torchdyn(一个专门用于连续深度学习和平衡模型的PyTorch扩展库)介绍Neural ODE的实现与训练方法,需要的朋友可以参考下

Torchdyn概述

Torchdyn是基于PyTorch构建的专业库,专注于连续深度学习和隐式神经网络模型(如Neural ODEs)的开发。该库具有以下核心特性:

  • 支持深度不变性和深度可变性的ODE模型
  • 提供多种数值求解算法(如Runge-Kutta法,Dormand-Prince法)
  • 与PyTorch Lightning框架的无缝集成,便于训练流程管理

本教程将以经典的moons数据集为例,展示Neural ODEs在分类问题中的应用。

数据集构建

首先,我们使用Torchdyn内置的数据集生成工具创建实验数据:

 from torchdyn.datasets import ToyDataset  
 import matplotlib.pyplot as plt  
   
 # 生成示例数据
 d = ToyDataset()  
 X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons')  
 # 可视化数据集
 colors = ['orange', 'blue']  
 fig, ax = plt.subplots(figsize=(3, 3))  
 for i in range(len(X)):  
     ax.scatter(X[i, 0], X[i, 1], s=1, color=colors[yn[i].int()])  
 plt.show()

数据预处理

将生成的数据转换为PyTorch张量格式,并构建训练数据加载器。Torchdyn支持CPU和GPU计算,可根据硬件环境灵活选择:

 import torch  
 import torch.utils.data as data  
   
 device = torch.device("cpu")  # 如果使用GPU则改为'cuda'
 X_train = torch.Tensor(X).to(device)  
 y_train = torch.LongTensor(yn.long()).to(device)  
 train = data.TensorDataset(X_train, y_train)  
 trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

Neural ODE模型构建

Neural ODEs的核心组件是向量场(vector field),它通过神经网络定义了数据在连续深度域中的演化规律。以下代码展示了向量场的基本实现:

 import torch.nn as nn  
   
 # 定义向量场f
 f = nn.Sequential(  
     nn.Linear(2, 16),  
     nn.Tanh(),  
     nn.Linear(16, 2)  
 )

接下来,我们使用Torchdyn的

NeuralODE

类定义Neural ODE模型。这个类接收向量场和求解器设置作为输入。

 from torchdyn.core import NeuralODE  
   
 t_span = torch.linspace(0, 1, 5)  # 时间跨度
 model = NeuralODE(f, sensitivity='adjoint', solver='dopri5').to(device)

类来管理训练过程:

 import pytorch_lightning as pl  
   
 class Learner(pl.LightningModule):  
     def __init__(self, t_span: torch.Tensor, model: nn.Module):  
         super().__init__()  
         self.model, self.t_span = model, t_span  
     def forward(self, x):  
         return self.model(x)  
     def training_step(self, batch, batch_idx):  
         x, y = batch  
         t_eval, y_hat = self.model(x, self.t_span)  
         y_hat = y_hat[-1]  # 选择轨迹的最后一个点
         loss = nn.CrossEntropyLoss()(y_hat, y)  
         return {'loss': loss}  
     def configure_optimizers(self):  
         return torch.optim.Adam(self.model.parameters(), lr=0.01)  
     def train_dataloader(self):  
         return trainloader

最后训练模型:

 learn = Learner(t_span, model)  
 trainer = pl.Trainer(max_epochs=200)  
 trainer.fit(learn)

实验结果可视化

深度域轨迹分析

训练完成后,我们可以观察数据样本在深度域(即ODE的时间维度)中的演化轨迹:

 t_eval, trajectory = model(X_train, t_span)  
 trajectory = trajectory.detach().cpu()  
   
 fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 2))  
 for i in range(500):  
     ax0.plot(t_span, trajectory[:, i, 0], alpha=0.1, color=colors[int(yn[i])])  
     ax1.plot(t_span, trajectory[:, i, 1], alpha=0.1, color=colors[int(yn[i])])  
 ax0.set_title("维度 0")  
 ax1.set_title("维度 1")  
 plt.show()

向量场可视化

通过可视化学习得到的向量场,我们可以直观理解模型的动力学特性:

 x = torch.linspace(trajectory[:, :, 0].min(), trajectory[:, :, 0].max(), 50)  
 y = torch.linspace(trajectory[:, :, 1].min(), trajectory[:, :, 1].max(), 50)  
 X, Y = torch.meshgrid(x, y)  
 z = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], 1)  
 f_eval = model.vf(0, z.to(device)).cpu().detach()  
   
 fx, fy = f_eval[:, 0], f_eval[:, 1]  
 fx, fy = fx.reshape(50, 50), fy.reshape(50, 50)  
 fig, ax = plt.subplots(figsize=(4, 4))  
 ax.streamplot(X.numpy(), Y.numpy(), fx.numpy(), fy.numpy(), color='black')  
 plt.show()

Torchdyn进阶特性

Torchdyn框架的功能远不限于基础的Neural ODEs实现。它提供了丰富的高级特性,包括:

  • 高精度数值求解器
  • 平衡模型支持
  • 自定义微分方程系统

无论是物理模型的数值模拟,还是连续深度学习模型的开发,Torchdyn都提供了完整的工具链支持。

以上就是PyTorch使用Torchdyn实现连续时间神经网络的代码示例的详细内容,更多关于PyTorch Torchdyn连续时间神经网络的资料请关注脚本之家其它相关文章!

相关文章

  • Python3 sort和sorted用法+cmp_to_key()函数详解

    Python3 sort和sorted用法+cmp_to_key()函数详解

    这篇文章主要介绍了Python3 sort和sorted用法+cmp_to_key()函数详解,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-07-07
  • python 动态获取当前运行的类名和函数名的方法

    python 动态获取当前运行的类名和函数名的方法

    这篇文章主要介绍了python 动态获取当前运行的类名和函数名的方法,分别介绍使用内置方法、sys模块、修饰器、inspect模块等方法,需要的朋友可以参考下
    2014-04-04
  • 使用Python提取PDF表格到Excel文件的操作步骤

    使用Python提取PDF表格到Excel文件的操作步骤

    在对PDF中的表格进行再利用时,除了直接将PDF文档转换为Excel文件,我们还可以提取PDF文档中的表格数据并写入Excel工作表,本文将介绍如何使用Python提取PDF文档中的表格并写入Excel文件中,需要的朋友可以参考下
    2024-09-09
  • 在windows下Python打印彩色字体的方法

    在windows下Python打印彩色字体的方法

    这篇文章主要介绍了Python在windows下打印彩色字体的方法;具有很好的参考价值,希望对大家有所帮助,一起跟随小编过来看看吧
    2018-05-05
  • python使用IP归属地查询API追踪网络活动

    python使用IP归属地查询API追踪网络活动

    这篇文章主要为大家介绍了python使用IP归属地查询API追踪网络活动实现详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-09-09
  • Python入门之布尔值详解

    Python入门之布尔值详解

    Python中布尔值(Booleans)表示以下两个值之一:True或False。本文主要介绍布尔值(Booleans)的使用,和使用时需要注意的地方,需要的可以参考一下
    2023-02-02
  • Win10里python3创建虚拟环境的步骤

    Win10里python3创建虚拟环境的步骤

    在本篇文章里小编给大家整理的是一篇关于Win10里python3创建虚拟环境的步骤内容,需要的朋友们可以学习参考下。
    2020-01-01
  • Python Matplotlib条形图之垂直条形图和水平条形图详解

    Python Matplotlib条形图之垂直条形图和水平条形图详解

    这篇文章主要为大家详细介绍了Python Matplotlib条形图之垂直条形图和水平条形图,使用数据库,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-03-03
  • python中的break、continue、exit()、pass全面解析

    python中的break、continue、exit()、pass全面解析

    下面小编就为大家带来一篇python中的break、continue、exit()、pass全面解析。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-08-08
  • python plt如何保存为emf图像

    python plt如何保存为emf图像

    这篇文章主要介绍了python plt如何保存为emf图像问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09

最新评论