pytorch GAN生成对抗网络实例

 更新时间:2020年01月10日 09:11:18   作者:全栈的方向  
今天小编就为大家分享一篇pytorch GAN生成对抗网络实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])

def artist_works():
	a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
	paintings = a*np.power(PAINT_POINTS,2) + (a-1)
	paintings = torch.from_numpy(paintings).float()
	return Variable(paintings)

G = nn.Sequential(
	nn.Linear(N_IDEAS,128),
	nn.ReLU(),
	nn.Linear(128,ART_COMPONENTS),
)

D = nn.Sequential(
	nn.Linear(ART_COMPONENTS,128),
	nn.ReLU(),
	nn.Linear(128,1),
	nn.Sigmoid(),
)

opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)

plt.ion()

for step in range(10000):
	artist_paintings = artist_works()
	G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
	G_paintings = G(G_ideas)

	prob_artist0 = D(artist_paintings)
	prob_artist1 = D(G_paintings)

	D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
	G_loss = torch.mean(torch.log(1 - prob_artist1))

	opt_D.zero_grad()
	D_loss.backward(retain_variables=True)
	opt_D.step()

	opt_G.zero_grad()
	G_loss.backward()
	opt_G.step()

	if step % 50 == 0:
		plt.cla()
		plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generated painting',)
		plt.plot(PAINT_POINTS[0],2 * np.power(PAINT_POINTS[0], 2) + 1,c='#74BCFF',lw=3,label='upper bound',)
		plt.plot(PAINT_POINTS[0],1 * np.power(PAINT_POINTS[0], 2) + 0,c='#FF9359',lw=3,label='lower bound',)
		plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size':15})
		plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})
		plt.ylim((0,3))
		plt.legend(loc='upper right', fontsize=12)
		plt.draw()
		plt.pause(0.01)

plt.ioff()
plt.show()

以上这篇pytorch GAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python提取具有某种特定字符串的行数据方法

    python提取具有某种特定字符串的行数据方法

    今天小编就为大家分享一篇python提取具有某种特定字符串的行数据方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • 超详细讲解python正则表达式

    超详细讲解python正则表达式

    这篇文章主要介绍了python正则表达式,利用正则表达式实现文本的查找和替換功能会相对于比较简单,效率也会更高。感兴趣的小伙伴一起来学习学习吧
    2021-08-08
  • 在Gnumeric下使用Python脚本操作表格的教程

    在Gnumeric下使用Python脚本操作表格的教程

    这篇文章主要介绍了在Gnumeric下使用Python脚本操作表格的教程,本文来自于IBM官方网站,需要的朋友可以参考下
    2015-04-04
  • 分析Python读取文件时的路径问题

    分析Python读取文件时的路径问题

    本篇文章通过图文实例的方式给大家详细分析了Python读取文件时的路径问题,对此有需求的朋友可以参考学习下。
    2018-02-02
  • 使用Python进行同期群分析(Cohort Analysis)

    使用Python进行同期群分析(Cohort Analysis)

    同期群(Cohort)的字面意思(有共同特点或举止类同的)一群人,比如不同性别,不同年龄。这篇文章主要介绍了用Python语言来进行同期群分析,感兴趣的同学可以阅读参考一下本文
    2023-03-03
  • Python 实现将某一列设置为str类型

    Python 实现将某一列设置为str类型

    这篇文章主要介绍了Python 实现将某一列设置为str类型,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • Python获取邮件地址的方法

    Python获取邮件地址的方法

    这篇文章主要介绍了Python获取邮件地址的方法,通过自定义函数分析提取字符串中邮件地址的相关技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-07-07
  • python实时监控logstash日志代码

    python实时监控logstash日志代码

    这篇文章主要介绍了python实时监控logstash日志代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • python提取word文件中的图片并上传阿里云OSS

    python提取word文件中的图片并上传阿里云OSS

    这篇文章主要介绍了通过Python提取Word文件中的所有图片,并将其上传至阿里云OSS。文中的示例代码对学习Python有一定的帮助,快跟随小编一起学习一下吧
    2021-12-12
  • 分享PyCharm的几个使用技巧

    分享PyCharm的几个使用技巧

    这篇文章主要介绍了分享PyCharm的几个使用技巧,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-11-11

最新评论