Pytorch复现扩散模型的示例详解

 更新时间:2023年04月25日 08:23:47   作者:嘟粥yyds  
这篇文章主要为大家详细介绍了如何利用Pytorch复现扩散模型,文中的示例代码讲解详细,具有一定的学习价值,感兴趣的可以跟随小编一起了解一下

开发环境

集成开发工具:jupyter notebook 6.5.2
集成开发环境:Python 3.10.6
第三方库:torch、matplotlib、sklearn、numpy

1 加载相关第三方库

# 使得在 notebook 中显示绘图,而不是在外部窗口中显示
%matplotlib inline  
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch
import torch.nn as nn
import io
from PIL import Image

2 加载数据集

这里选择S 形曲线数据集作为本次复现扩散模型所用数据集。

s_curve, _ = make_s_curve(10 ** 4, noise=0.1)

将数据集中的特征缩放到一个相对较小的范围内,以便于模型的训练和收敛。这样做可以避免数据的特征值之间差异过大,导致某些特征对模型的影响过大,而其他特征的影响被忽略的情况。同时,将数据的特征缩放到一个相对较小的范围内,也有助于提高模型的泛化能力,使其能够更好地适应新的未知数据。

s_curve = s_curve[:, [0, 2]] / 10.
print(F"shape of Moons:{np.shape(s_curve)}")

将数据集从原来的 (10000, 2) 转换为 (2, 10000),即每一列对应一个样本的所有特征值,这样的形状更适合一些深度学习框架的输入格式。同时还可以保持数据的连续性:对数据进行转置操作可以保持数据之间的连续性。在某些机器学习算法或深度学习框架中,连续的数据在内存中存储更加紧凑,可以更快地读取和处理数据,从而提高模型的训练和预测效率。

data = s_curve.T
 
# 绘制 S 形曲线数据集
fig, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolor='white')
ax.axis('off')

因为在深度学习中,通常使用 PyTorch 等深度学习框架来实现模型的训练和预测。而 PyTorch 中的数据处理对象是张量(Tensor),因此我们需要将原始数据集转换为张量对象才能进行后续的深度学习模型的训练和预测。另外,由于深度学习模型通常需要浮点数类型的数据作为输入,因此我们需要使用 float() 将张量的数据类型设置为浮点型。这样做可以保证输入数据类型的一致性,避免数据类型不匹配导致的错误。

dataset = torch.Tensor(s_curve).float()

3 确定超参数的值

首先,指定步数(num_step),这个步数可以根据 beta、分布的均值和标准差来共同确定。num_step 指定了扩散模型的最终状态的计算次数,每一次计算对应一个 beta 值。

接着,使用 torch.linspace() 函数生成一个等间隔的 num_step 个 beta 值。然后,通过对这些 beta 值执行 sigmoid 激活函数以及线性变换,将它们转换为介于 1e-5 到 0.5e-2 之间的浮点数。这些 beta 值将在后续计算中用于计算扩散模型的每一步的参数。

接下来,计算一些中间变量,包括 alphas、alphas_prod、alphas_prod_p、alphas_bar_sqrt、one_minus_alphas_bar_log 和 one_minus_alphas_bar_sqrt。其中,alphas 表示每一步的 alpha 值,alphas_prod 表示前 t 步的 alpha 值的累积乘积,alphas_prod_p 表示前 t-1 步的 alpha 值的累积乘积,alphas_bar_sqrt 表示前 t 步的 alpha 值的累积乘积的平方根,one_minus_alphas_bar_log 表示前 t 步的 alpha 值的累积乘积的对数的负值,one_minus_alphas_bar_sqrt 表示前 t 步的 alpha 值的累积乘积的差值的平方根。

最后,使用 assert 命令检查计算的所有变量的形状是否相同,并打印出 betas 变量的形状。

num_step = 100  # 一开始可以由beta、分布的均值和标准差来共同确定
 
# 指定每一步的beta
betas = torch.linspace(-6, 6, num_step)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
 
# 计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, dim=0)
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)  # p表示previous
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
 
assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape
print(f"all the same shape:{betas.shape}")

 4 确定扩散过程任意时刻的采样值

首先从正态分布中生成随机噪声。然后,根据参数重整化技巧,使用预先计算好的 alpha_bar_sqrt 和 one_minus_alphas_bar_sqrt,将初始值 x_0 进行变换,得到时刻 t 的采样值。最后,将噪声加入到采样值中,得到最终的采样值。

# 计算任意时刻的x的采样值,基于x_0和参数重整化技巧
def q_x(x_0, t):
    """可以基于x[0]得到任意时刻t的x[t]"""
    noise = torch.randn_like(x_0)  # noise是从正态分布中生成的随机噪声
    alphas_t = alphas_bar_sqrt[t]
    alphas_l_m_t = one_minus_alphas_bar_sqrt[t]
    return (alphas_t * x_0 + alphas_l_m_t * noise) # 在x[0]的基础上添加噪声

5 演示原始数据分布加噪100步后的效果

生成样本点随时间变化的演化过程图。生成一个大小为2x10的子图网格,每个子图显示了原始S曲线数据集在经过噪声添加和扩散操作后在某个时间点t时的图像。其中,num_shows变量指定了要显示的时间步数,这里为20,因此总共会显示20张子图。在每个子图中,使用q_x函数对原始数据集进行噪声添加和扩散操作,得到对应时间点t时的新数据集,然后在子图中以红色散点图的形式绘制出来。每个子图的标题显示了该子图所对应的时间步t。

num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 7))
plt.rc('text', color='blue')
# 共有10000个点,每个点包含两个坐标
# 生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):
    j = i // 10
    k = i % 10
    q_i = q_x(dataset, torch.tensor([i * num_step // num_shows]))  # 生成t时刻的采样数据
    axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')
    
    axs[j, k].set_axis_off()
    axs[j, k].set_title('$q(\mathbf{x}_{'+ str(i * num_step // num_shows) + '})$')

6 编写拟合逆扩散过程高斯分布的模型

在输入的基础上添加一个时间步长 t,并对此进行嵌入。具体来说,它使用了 3 个 nn.Embedding 层,分别对应于嵌入 t 的 3 个维度。

模型的 forward 方法接受一个输入 x 和一个时间步长 t,并返回输出 y。在 forward 方法中,输入 x 会经过一系列的全连接层(使用 nn.Linear 实现),其中每两个全连接层之间都有一个 ReLU 激活函数。在这些全连接层之前和之后,模型都会使用 nn.Embedding 层将 t 嵌入到向量中。最终的输出 y 是一个 2 维向量。

class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
        
        self.linears = nn.ModuleList(
        [
            nn.Linear(2, num_units),
            nn.ReLU(),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            nn.Linear(num_units, 2),
        ])
        
        self.step_embeddings = nn.ModuleList(
        [
            nn.Embedding(n_steps, num_units),
            nn.Embedding(n_steps, num_units),
            nn.Embedding(n_steps, num_units),
        ])
    
    def forward(self, x, t):
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)
        
        x = self.linears[-1](x)
        return x

7 编写训练的误差函数

def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """
    对任意时刻t进行采样计算loss
    param:
    model:模型 
    x_0:初始状态 
    alphas_bar_sqrt、one_minus_alphas_bar_sqrt: 参数 
    n_steps:时间步数
    return:损失值
    """
    batch_size = x_0.shape[0]
    # 随机采样一个时刻t,为了提高训练效率,这里确保t不重复
    # 对一个batchsize样本生成随机的时刻t,覆盖到更多不同的t
    t = torch.randint(0, n_steps, size=(batch_size // 2,))
    t = torch.cat([t, n_steps - 1 - t], dim=0)  # [batch]
    t = t.unsqueeze(-1)  # [batch, 1]
    # x0的系数
    a = alphas_bar_sqrt[t]
    # eps的系数
    aml = one_minus_alphas_bar_sqrt[t]
    # 生成随机噪声eps
    e = torch.randn_like(x_0)
    # 构造模型的输入
    x = x_0 * a + e * aml
    # 送入模型,得到t时刻的随机噪声预测值
    output = model(x, t.squeeze(-1))
    # 与真实噪声一起计算误差,求平均值
    return (e - output).square().mean()

8 编写逆扩散采样函数(inference过程)

进行扩散模型的采样。具体来说,p_sample_loop函数是从x[T]恢复x[T-1]、x[T-2]、...、x[0]的过程,其中x[T]是输入的初始值。在这个函数里,使用了一个for循环,从最后一个时刻T开始往前推,依次对每个时刻进行采样。在每个时刻,调用p_sample函数进行采样。

p_sample函数的主要作用是从x[T]采样t时刻的重构值,其中x[T]是输入的初始值,t表示当前时刻。具体来说,首先通过模型预测出eps_theta,然后通过一些计算,得到该时刻的重构值sample。其中,mean表示重构值的均值,z是服从标准正态分布的噪声,sigma_t是该时刻的标准差。最后,将sample作为当前时刻的重构值返回。

def p_sample_loop(model, shape, n_step, betas, one_minus_alphas_bar_sqrt):
    """从x[T]恢复x[T - 1]、x[T - 2]、...、x[0]"""
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_step)):
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq
 
def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """从x[T]采样t时刻的重构值"""
    t = torch.tensor([t])
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    eps_theta = model(x, t)
    mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    sample = mean + sigma_t * z
    return (sample)

9 开始训练模型,并打印loss及中间重构效果

这段代码定义了一个EMA(Exponential Moving Average,指数平滑移动平均)类,它用于对模型的参数进行平滑处理。构造函数中的 mu 参数控制平滑程度,shadow 是一个字典,用于存储参数的平滑后的值。

register 方法将参数 val 注册到 shadow 字典中,__call__方法对指定名称的参数 name 进行平滑处理。其中,x 是当前时刻参数的值。计算完成后,将结果存储在 shadow 字典中,并返回平滑后的值。

seed = 1234  # 确保程序在每次运行时生成的随机数序列都是一样的
 
class EMA():
    """构建一个参数平滑器,以便更好地泛化模型并减少过拟合"""
    def __init__(self, mu=0.01):
        self.mu = mu
        self.shadow = {}
        
    def register(self, name, val):
        self.shadow[name] = val.clone()
        
    def __call__(self, name, x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        return new_average
print('Training model.....')
 
 
batch_size = 512  # 批训练大小
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_epoch = 4000  # 定义迭代4000次
plt.rc('text', color='blue')
 
model = MLPDiffusion(num_step)  # 输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
for t in range(num_epoch):
    for idx, batch_x in enumerate(dataloader):
        loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_step)
        optimizer.zero_grad()  # 对梯度进行清零,防止网络权重更新过于迅速或不稳定,无法得到正确的收敛结果
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.)  # 对梯度进行裁剪,避免出现梯度爆炸
        optimizer.step()
    if (t % 100 == 0):
        print(loss)
        x_seq = p_sample_loop(model, dataset.shape, num_step, betas, one_minus_alphas_bar_sqrt)  # 共有100个元素
        
        fig, axs = plt.subplots(1, 5, figsize=(28, 7))
        for i in range(1, 6):
            cur_x = x_seq[i * 20].detach()
            axs[i - 1].scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white')
            axs[i - 1].set_axis_off()
            axs[i - 1].set_title('$q(\mathbf{x}_{'+str(i * 20)+'})$')

部分效果图:

10 动画演示扩散过程和逆扩散过程

# 生成前向过程,也就是逐步加噪声
imgs = []
for i in range(100):
    plt.clf()
    torch_i = q_x(dataset, torch.tensor([i]))
    plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf, format='png')
    img = Image.open(img_buf)
    imgs.append(img)

# 生成逆过程,也就是逐步复原
reverse = []
for i in range(100):
    plt.clf()
    cur_x = x_seq[i].detach()  # 拿到训练末尾阶段生成的x_seq
    plt.scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf, format='png')
    img = Image.open(img_buf)
    reverse.append(img)

imgs = imgs + reverse
 
imgs[0].save("diffusion.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

动画效果图:

以上就是Pytorch复现扩散模型的示例详解的详细内容,更多关于Pytorch扩散模型的资料请关注脚本之家其它相关文章!

相关文章

  • python实现简单的飞机大战游戏

    python实现简单的飞机大战游戏

    这篇文章主要为大家详细介绍了python实现简单的飞机大战游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-05-05
  • tensorflow模型保存、加载之变量重命名实例

    tensorflow模型保存、加载之变量重命名实例

    今天小编就为大家分享一篇tensorflow模型保存、加载之变量重命名实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • pygame加载中文名mp3文件出现error

    pygame加载中文名mp3文件出现error

    本文主要介绍了pygame加载中文名mp3文件出现error的解决方案。具有很好的参考价值,下面跟着小编一起来看下吧
    2017-03-03
  • Python 获取中文字拼音首个字母的方法

    Python 获取中文字拼音首个字母的方法

    今天小编就为大家分享一篇Python 获取中文字拼音首个字母的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • Python minidom模块用法示例【DOM写入和解析XML】

    Python minidom模块用法示例【DOM写入和解析XML】

    这篇文章主要介绍了Python minidom模块用法,结合实例形式分析了Python DOM创建、写入和解析XML文件相关操作技巧,需要的朋友可以参考下
    2019-03-03
  • Python技法-序列拆分详解

    Python技法-序列拆分详解

    Python中的任何序列(可迭代的对象)都可以通过赋值操作进行拆分,包括但不限于元组、列表、字符串、文件、迭代器、生成器等。
    2021-10-10
  • pytorch 如何实现HWC转CHW

    pytorch 如何实现HWC转CHW

    这篇文章主要介绍了pytorch HWC转CHW的实现方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python使用腾讯云API实现短信验证码功能

    Python使用腾讯云API实现短信验证码功能

    使用Python与腾讯云接口对接,实现短信验证码功能变得非常简单,只需要几行代码就能够轻松实现短信的发送,无须关心复杂的短信协议和底层实现,读者可以根据自己的实际需求,灵活使用腾讯云短信SDK提供的API来实现更丰富的短信功能
    2024-01-01
  • 计算pytorch标准化(Normalize)所需要数据集的均值和方差实例

    计算pytorch标准化(Normalize)所需要数据集的均值和方差实例

    今天小编就为大家分享一篇计算pytorch标准化(Normalize)所需要数据集的均值和方差实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python模块psycopg2连接postgresql的实现

    Python模块psycopg2连接postgresql的实现

    本文主要介绍了Python模块psycopg2连接postgresql的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-07-07

最新评论