在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)

项目链接:https://github.com/kvmanohar22/ Generative-Models
变分自编码器(VAE)与生成对抗网络(GAN)是复杂分布上无监督学习最具前景的两类方法。
本项目总结了使用变分自编码器(Variational Autoencode,VAE)和生成对抗网络(GAN)对给定数据分布进行建模,并且对比了这些模型的性能。你可能会问:我们已经有了数百万张图像,为什么还要从给定数据分布中生成图像呢?正如 Ian Goodfellow 在 NIPS 2016 教程中指出的那样,实际上有很多应用。我觉得比较有趣的一种是使用 GAN 模拟可能的未来,就像强化学习中使用策略梯度的智能体那样。
本文组织架构:

  • 变分自编码器(VAE)
  • 生成对抗网络(GAN)
  • 训练普通 GAN 的难点
  • 训练细节
  • 在 MNIST 上进行 VAE 和 GAN 对比实验
    • 在无标签的情况下训练 GAN 判别器
    • 在有标签的情况下训练 GAN 判别器
  • 在 CIFAR 上进行 VAE 和 GAN 实验
  • 延伸阅读

VAE

变分自编码器可用于对先验数据分布进行建模。从名字上就可以看出,它包括两部分:编码器和解码器。编码器将数据分布的高级表征映射到数据的低级表征,低级表征叫作本征向量(latent vector)。解码器吸收数据的低级表征,然后输出同样数据的高级表征。
从数学上来讲,让 X 作为编码器的输入,z 作为本征向量,X′作为解码器的输出。
图 1 是 VAE 的可视化图。

这与标准自编码器有何不同?关键区别在于我们对本征向量的约束。如果是标准自编码器,那么我们主要关注重建损失(reconstruction loss),即:

而在变分自编码器的情况中,我们希望本征向量遵循特定的分布,通常是单位高斯分布(unit Gaussian distribution),使下列损失得到优化:

p(z′)∼N(0,I) 中 I 指单位矩阵(identity matrx),q(z∣X) 是本征向量的分布,其中。和由神经网络来计算。KL(A,B) 是分布 B 到 A 的 KL 散度。
由于损失函数中还有其他项,因此存在模型生成图像的精度,同本征向量的分布与单位高斯分布的接近程度之间存在权衡(trade-off)。这两部分由两个超参数λ_1 和λ_2 来控制。

GAN

GAN 是根据给定的先验分布生成数据的另一种方式,包括同时进行的两部分:判别器和生成器。
判别器用于对「真」图像和「伪」图像进行分类,生成器从随机噪声中生成图像(随机噪声通常叫作本征向量或代码,该噪声通常从均匀分布(uniform distribution)或高斯分布中获取)。生成器的任务是生成可以以假乱真的图像,令判别器也无法区分出来。也就是说,生成器和判别器是互相对抗的。判别器非常努力地尝试区分真伪图像,同时生成器尽力生成更加逼真的图像,目的是使判别器将这些图像也分类为「真」图像。
图 2 是 GAN 的典型结构。

生成器包括利用代码输出图像的解卷积层。图 3 是生成器的架构图。

训练 GAN 的难点

训练 GAN 时我们会遇到一些挑战,我认为其中最大的挑战在于本征向量/代码的采样。代码只是从先验分布中对本征变量的噪声采样。有很多种方法可以克服该挑战,包括:使用 VAE 对本征变量进行编码,学习数据的先验分布。这听起来要好一些,因为编码器能够学习数据分布,现在我们可以从分布中进行采样,而不是生成随机噪声。

训练细节

我们知道两个分布 p(真实分布)和 q(估计分布)之间的交叉熵通过以下公式计算:

  • 对于二元分类:

  • 对于 GAN,我们假设分布的一半来自真实数据分布,一半来自估计分布,因此:

    训练 GAN 需要同时优化两个损失函数。
    按照极小极大值算法:

    这里,判别器需要区分图像的真伪,不管图像是否包含真实物体,都没有注意力。当我们在 CIFAR 上检查 GAN 生成的图像时会明显看到这一点。
    我们可以重新定义判别器损失目标,使之包含标签。这被证明可以提高主观样本的质量。如:在 MNIST 或 CIFAR-10(两个数据集都有 10 个类别)。
    上述 Python 损失函数在 TensorFlow 中的实现:

    def VAE_loss(true_images, logits, mean, std):
      """
        Args:
          true_images : batch of input images
          logits      : linear output of the decoder network (the constructed images)
          mean        : mean of the latent code
          std         : standard deviation of the latent code
      """
      imgs_flat    = tf.reshape(true_images, [-1, img_himg_wimg_d])
      encoder_loss = 0.5 * tf.reduce_sum(tf.square(mean)+tf.square(std)
                     -tf.log(tf.square(std))-1, 1)
      decoder_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                     logits=logits, labels=img_flat), 1)
      return tf.reduce_mean(encoder_loss + decoder_loss)
    def GAN_loss_without_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a column vector)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a column vector)
      """
    
      true_prob = tf.nn.sigmoid(true_logit)
      fake_prob = tf.nn.sigmoid(fake_logit)
      d_loss = tf.reduce_mean(-tf.log(true_prob)-tf.log(1-fake_prob))
      g_loss = tf.reduce_mean(-tf.log(fake_prob))
      return d_loss, g_loss
    def GAN_loss_with_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a matrix now)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a matrix now)
      """
      d_true_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.true_logit, dim=1)
      d_fake_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=1-self.labels, logits=self.fake_logit, dim=1)
      g_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.fake_logit, dim=1)
    
      d_loss = d_true_loss + d_fake_loss      return tf.reduce_mean(d_loss), tf.reduce_mean(g_loss)
    

在 MNIST 上进行 VAE 与 GAN 对比实验

1. 不使用标签训练判别器
我在 MNIST 上训练了一个 VAE。代码地址:https://github.com/kvmanohar22/Generative-Models
实验使用了 MNIST 的 28×28 图像,下图中:

  • 左侧:数据分布的 64 张原始图像
  • 中间:VAE 生成的 64 张图像
  • 右侧:GAN 生成的 64 张图像

第 1 次迭代:

第 2 次迭代:

第 3 次迭代:

第 4 次迭代:

第 100 次迭代:

VAE(125)和 GAN(368)训练的最终结果:

根据GAN迭代次数生成的gif图:

显然,VAE 生成的图像与 GAN 生成的图像相比,前者更加模糊。这个结果在预料之中,因为 VAE 模型生成的所有输出都是分布平均。为了减少图像的模糊度,我们可以使用 L1 损失来代替 L2 损失。
在第一个实验后,作者还将在近期研究使用标签训练判别器,并在 CIFAR 数据集上测试 VAE 与 GAN 的性能。
使用
下载 MNIST 和 CIFAR 数据集
使用 MNIST 训练 VAE 请运行:

python main.py --train --model vae --dataset mnist

使用 MNIST 训练 GAN 请运行:

python main.py --train --model gan --dataset mnist

想要获取完整的命令行选项,请运行:

python main.py --help

该模型由 generate_frq 决定生成图片的频率,默认值为 1。

GAN 在 MNIST 上的训练结果

MNIST 数据集中的样本图像:

上方是 VAE 生成的图像,下方的图展示了 GAN 生成图像的过程:

原文发布时间为:2017-10-29
本文来自合作伙伴“数据派THU”,了解相关信息可以关注“数据派THU”微信公众号

时间: 2024-09-10 23:35:03

在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)的相关文章

directx-x文件中有两个独立模型,在d3d中用什么代码能单独控制其中一个,会的@下,模我有

问题描述 x文件中有两个独立模型,在d3d中用什么代码能单独控制其中一个,会的@下,模我有 x文件中有两个独立模型,在d3d中用什么代码能单独控制其中一个,会的@下,模我有, 有没写d3d游戏的高手,菜鸟求指导,只看过和改过浅墨的例子,有很多问题都不懂,我是自学的,有没有乐于指导的朋友,求引导

王建宙批露中移动两大基地详情:占地均1千多亩

12月24日消息,在今天举行的中国移动南方基地IT支撑中心开通仪式上,中国移动总裁王建宙透露了目前建设的两大基地的情况,其中北方基地更鲜为人知,目前刚刚奠基. 王建宙表示,2005年中国移动决策在广州建南方基地,当时决定不采用在城市中心建大厦的方式,而是在郊区建设园区的方式. 他表示,南方基地已施工3年,体现集中化管理的原则,这是第一次对省公司实行远程管理.异地服务,在园区将留很大地方给合作伙伴. 中国移动计划建设部总经理董昕透露,中国移动目前正在建设南方基地和北方基地,南方基地建设较早,外界比

中兴欲成全球品牌前三 中美两大市场对抗苹果

腾讯科技讯(明轩)北京时间12月29日消息,据国外媒体报道,中兴通讯计划通过高端设备和与运营商建立更加紧密的关系,在竞争激烈的美国智能手机市场加强竞争力.中兴通讯此举旨在成为全球智能手机三大品牌之一. 中兴通讯以生产廉价手机闻名,目前占据着美国智能手机市场大约5%的份额,并面临着苹果和三星电子等公司的激烈对抗.中兴通讯执行副总裁何士友在接受采访时表示,在中美两国市场获得成功,是实现公司抱负的关键,并补充称中兴通讯希望美国市场超越中国市场,成为公司智能手机业务营收的第一大市场.负责中兴通讯手机业务

XP安装过程中的两大潜在危险_WindowsXP

Windows XP是微软推出的视窗操作系统中,迄今以来体积最大.安装所需时间最长,功能也号称最强大的产品.安装XP的时间基本需要50-80分钟左右,那么在这么长的时间里,XP到底干了些什么呢?为什么有的人声称安装XP破坏了他们原来的系统或是文件呢,我们就来仔细地看看XP安装时候的关键步骤,让大家明白安装XP操作的安全要点: 一.解压数据包.拷贝临时文件 安装程序主要是在C盘先建立一个临时目录,把安装程序中某些压缩包内的文件释放到该目录里,为安装做好准备.XP的压缩安装文件已经达到了数百兆,拷贝

大数据量分页存储过程效率测试附测试代码与结果

测试环境 硬件:CPU 酷睿双核T5750 内存:2G 软件:Windows server 2003 + sql server 2005 OK,我们首先创建一数据库:data_Test,并在此数据库中创建一表:tb_TestTable 复制代码 代码如下: create database data_Test --创建数据库 data_Test GO use data_Test GO create table tb_TestTable --创建表 (id int identity(1,1) pri

大数据量分页存储过程效率测试附测试代码与结果_MsSql

测试环境 硬件:CPU 酷睿双核T5750 内存:2G 软件:Windows server 2003 + sql server 2005 OK,我们首先创建一数据库:data_Test,并在此数据库中创建一表:tb_TestTable 复制代码 代码如下: create database data_Test --创建数据库 data_Test  GO use data_Test GO create table tb_TestTable --创建表 (id int identity(1,1) pr

在IIS中改变ASP.NET程序版本的实现方法附批处理代码_服务器

在windows2003的iis6.0当中,在装过.NETFRAMEWORK1.1,和.NETFRAMEWORK2.0之后,在新建的ASP.NET应用程序中,查看属性,会出现ASP.NET的选项卡,在此可以更改该WEB应用程序是基于哪个框架运行的. 但是最近在装了64位的WINDOWSXP PROFESSIONAL后,找不到此选项卡了,只能通过如下方式修改: 在IIS管理器中,选择指定的WEB应用程序,右键-->属性-->configuration-->mapping-->根据需要

国内域名种类繁多 拼音和数字成两大支柱

中介交易 SEO诊断 淘宝客 云主机 技术大厅 近日,杨子和黄圣依所属的巨力影视公司成功收购了数字域名5888.com,交易价达6位数以上,目前域名已经转发至官网,易宝支付花高价购买yibao.com双拼域名,可见,目前国内成功交易的域名种类当中,当属数字域名和拼音域名较为常见,同时,这两种类域名也是国内域名市场中的两大支柱. 数字域名越来越受国内投资者热衷,越来越多不同行业的人涉及域名行业投资,甚至娱乐业也已和域名行业投资相关联.杨子涉猎域名行业,其在收购5888.com域名之前,就已经收购了

中美两大会计远程教育巨头联姻

中国和美国两大会计远程教育巨头--正保远程和Becker的联姻,或许能为面临国际化挑战的国内财务高管们提供了一条新的职场进化路线. 文/马丽 随着越来越多的国际会计认证机构在国内落地开花,美国注册会计师(AICPA)也正在全力拓展在中国的布局. 2009年12月8日,中国最大的CPA资格考试培训教育机构--正保远程教育机构(以下简称"正保远程")宣布与美国迪弗莱公司 (DevryInc)旗下子公司Becker ProfessionalEducation(以下简称"Becker