pytorch:实现简单的GAN示例(MNIST数据集)
更新时间:2020年01月10日 09:17:37 作者:xckkcxxck
今天小编就为大家分享一篇pytorch:实现简单的GAN示例(MNIST数据集),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
我就废话不多说了,直接上代码吧!
# -*- coding: utf-8 -*- """ Created on Sat Oct 13 10:22:45 2018 @author: www """ import torch from torch import nn from torch.autograd import Variable import torchvision.transforms as tfs from torch.utils.data import DataLoader, sampler from torchvision.datasets import MNIST import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸 plt.rcParams['image.interpolation'] = 'nearest' plt.rcParams['image.cmap'] = 'gray' def show_images(images): # 定义画图工具 images = np.reshape(images, [images.shape[0], -1]) sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg = int(np.ceil(np.sqrt(images.shape[1]))) fig = plt.figure(figsize=(sqrtn, sqrtn)) gs = gridspec.GridSpec(sqrtn, sqrtn) gs.update(wspace=0.05, hspace=0.05) for i, img in enumerate(images): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(img.reshape([sqrtimg,sqrtimg])) return def preprocess_img(x): x = tfs.ToTensor()(x) return (x - 0.5) / 0.5 def deprocess_img(x): return (x + 1.0) / 2.0 class ChunkSampler(sampler.Sampler): # 定义一个取样的函数 """Samples elements sequentially from some offset. Arguments: num_samples: # of desired datapoints start: offset where we should start selecting from """ def __init__(self, num_samples, start=0): self.num_samples = num_samples self.start = start def __iter__(self): return iter(range(self.start, self.start + self.num_samples)) def __len__(self): return self.num_samples NUM_TRAIN = 50000 NUM_VAL = 5000 NOISE_DIM = 96 batch_size = 128 train_set = MNIST('E:/data', train=True, transform=preprocess_img) train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0)) val_set = MNIST('E:/data', train=True, transform=preprocess_img) val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN)) imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果 show_images(imgs) #判别网络 def discriminator(): net = nn.Sequential( nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1) ) return net #生成网络 def generator(noise_dim=NOISE_DIM): net = nn.Sequential( nn.Linear(noise_dim, 1024), nn.ReLU(True), nn.Linear(1024, 1024), nn.ReLU(True), nn.Linear(1024, 784), nn.Tanh() ) return net #判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1 bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数 def discriminator_loss(logits_real, logits_fake): # 判别器的 loss size = logits_real.shape[0] true_labels = Variable(torch.ones(size, 1)).float() false_labels = Variable(torch.zeros(size, 1)).float() loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels) return loss def generator_loss(logits_fake): # 生成器的 loss size = logits_fake.shape[0] true_labels = Variable(torch.ones(size, 1)).float() loss = bce_loss(logits_fake, true_labels) return loss # 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999 def get_optimizer(net): optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999)) return optimizer def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, noise_size=96, num_epochs=10): iter_count = 0 for epoch in range(num_epochs): for x, _ in train_data: bs = x.shape[0] # 判别网络 real_data = Variable(x).view(bs, -1) # 真实数据 logits_real = D_net(real_data) # 判别网络得分 sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed) # 生成的假的数据 logits_fake = D_net(fake_images) # 判别网络得分 d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss D_optimizer.zero_grad() d_total_error.backward() D_optimizer.step() # 优化判别网络 # 生成网络 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed) # 生成的假的数据 gen_logits_fake = D_net(fake_images) g_error = generator_loss(gen_logits_fake) # 生成网络的 loss G_optimizer.zero_grad() g_error.backward() G_optimizer.step() # 优化生成网络 if (iter_count % show_every == 0): print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item())) imgs_numpy = deprocess_img(fake_images.data.cpu().numpy()) show_images(imgs_numpy[0:16]) plt.show() print() iter_count += 1 D = discriminator() G = generator() D_optim = get_optimizer(D) G_optim = get_optimizer(G) train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
python之线程通过信号pyqtSignal刷新ui的方法
今天小编就为大家分享一篇python之线程通过信号pyqtSignal刷新ui的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2019-01-01浅谈对pytroch中torch.autograd.backward的思考
这篇文章主要介绍了对pytroch中torch.autograd.backward的思考,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2019-12-12python中import和from-import的区别解析
这篇文章主要介绍了python中import和from-import的区别解析,本文通过实例代码给大家讲解的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2022-12-12
最新评论