不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)

生成对抗网络(Generative Adversarial Networks,GAN)最早由 Ian Goodfellow 在 2014 年提出,是目前深度学习领域最具潜力的研究成果之一。它的核心思想是:同时训练两个相互协作、同时又相互竞争的深度神经网络(一个称为生成器 Generator,另一个称为判别器 Discriminator)来处理无监督学习的相关问题。在训练过程中,两个网络最终都要学习如何处理任务。

通常,我们会用下面这个例子来说明 GAN 的原理:将警察视为判别器,制造假币的犯罪分子视为生成器。一开始,犯罪分子会首先向警察展示一张假币。警察识别出该假币,并向犯罪分子反馈哪些地方是假的。接着,根据警察的反馈,犯罪分子改进工艺,制作一张更逼真的假币给警方检查。这时警方再反馈,犯罪分子再改进工艺。不断重复这一过程,直到警察识别不出真假,那么模型就训练成功了。

虽然 GAN 的核心思想看起来非常简单,但要搭建一个真正可用的 GAN 网络却并不容易。因为毕竟在 GAN 中有两个相互耦合的深度神经网络,同时对这两个网络进行梯度的反向传播,也就比一般场景困难两倍。

为此,本文将以深度卷积生成对抗网络(Deep Convolutional GAN,DCGAN)为例,介绍如何基于 Keras 2.0 框架,以 Tensorflow 为后端,在 200 行代码内搭建一个真实可用的 GAN 模型,并以该模型为基础自动生成 MNIST 手写体数字。

  判别器

判别器的作用是判断一个模型生成的图像和真实图像比,有多逼真。它的基本结构就是如下图所示的卷积神经网络(Convolutional Neural Network,CNN)。对于 MNIST 数据集来说,模型输入是一个 28x28 像素的单通道图像。Sigmoid 函数的输出值在 0-1 之间,表示图像真实度的概率,其中 0 表示肯定是假的,1 表示肯定是真的。与典型的 CNN 结构相比,这里去掉了层之间的 max-pooling,而是采用了步进卷积来进行下采样。这里每个 CNN 层都以 LeakyReLU 为激活函数。而且为了防止过拟合和记忆效应,层之间的 dropout 值均被设置在 0.4-0.7 之间。具体在 Keras 中的实现代码如下。

self.D = Sequential()
depth = 64
dropout = 0.4
# In: 28 x 28 x 1, depth = 1
# Out: 10 x 10 x 1, depth=64
input_shape = (self.img_rows, self.img_cols, self.channel)
self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,\
padding='same', activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*2, 5, strides=2, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*4, 5, strides=2, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
self.D.add(Conv2D(depth*8, 5, strides=1, padding='same',\
activation=LeakyReLU(alpha=0.2)))
self.D.add(Dropout(dropout))
# Out: 1-dim probability
self.D.add(Flatten())
self.D.add(Dense(1))
self.D.add(Activation('sigmoid'))
self.D.summary()

  生成器

生成器的作用是合成假的图像,其基本机构如下图所示。图中,我们使用了卷积的倒数,即转置卷积(transposed convolution),从 100 维的噪声(满足 -1 至 1 之间的均匀分布)中生成了假图像。如在 DCGAN 模型中提到的那样,去掉微步进卷积,这里我们采用了模型前三层之间的上采样来合成更逼真的手写图像。在层与层之间,我们采用了批量归一化的方法来平稳化训练过程。以 ReLU 函数为每一层结构之后的激活函数。最后一层 Sigmoid 函数输出最后的假图像。第一层设置了 0.3-0.5 之间的 dropout 值来防止过拟合。具体代码如下。

self.G = Sequential()
dropout = 0.4
depth = 64+64+64+64
dim = 7
# In: 100
# Out: dim x dim x depth
self.G.add(Dense(dim*dim*depth, input_dim=100))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Reshape((dim, dim, depth)))
self.G.add(Dropout(dropout))
# In: dim x dim x depth
# Out: 2*dim x 2*dim x depth/2
self.G.add(UpSampling2D())
self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(UpSampling2D())
self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
self.G.add(BatchNormalization(momentum=0.9))
self.G.add(Activation('relu'))
# Out: 28 x 28 x 1 grayscale image [0.0,1.0] per pix
self.G.add(Conv2DTranspose(1, 5, padding='same'))
self.G.add(Activation('sigmoid'))
self.G.summary()
return self.G

  生成 GAN 模型

下面我们生成真正的 GAN 模型。如上所述,这里我们需要搭建两个模型:一个是判别器模型,代表警察;另一个是对抗模型,代表制造假币的犯罪分子。

判别器模型

下面代码展示了如何在 Keras 框架下生成判别器模型。上文定义的判别器是为模型训练定义的损失函数。这里由于判别器的输出为 Sigmoid 函数,因此采用了二进制交叉熵为损失函数。在这种情况下,以 RMSProp 作为优化算法可以生成比 Adam 更逼真的假图像。这里我们将学习率设置在 0.0008,同时还设置了权值衰减和clipvalue等参数来稳定后期的训练过程。如果你需要调节学习率,那么也必须同步调节其他相关参数。

optimizer = RMSprop(lr=0.0008, clipvalue=1.0, decay=6e-8)
self.DM = Sequential()
self.DM.add(self.discriminator())
self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
metrics=['accuracy'])

对抗模型

如图所示,对抗模型的基本结构是判别器和生成器的叠加。生成器试图骗过判别器,同时从其反馈中提升自己。如下代码中演示了如何基于 Keras 框架实现这一部分功能。其中,除了学习速率的降低和相对权值衰减之外,训练参数与判别器模型中的训练参数完全相同。

optimizer = RMSprop(lr=0.0004, clipvalue=1.0, decay=3e-8)
self.AM = Sequential()
self.AM.add(self.generator())
self.AM.add(self.discriminator())
self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
metrics=['accuracy'])

训练

搭好模型之后,训练是最难实现的部分。这里我们首先用真实图像和假图像对判别器模型单独进行训练,以判断其正确性。接着,对判别器模型和对抗模型轮流展开训练。如下图展示了判别器模型训练的基本流程。在 Keras 框架下的实现代码如下所示。

images_train = self.x_train[np.random.randint(0,
self.x_train.shape[0], size=batch_size), :, :, :]
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
images_fake = self.generator.predict(noise)
x = np.concatenate((images_train, images_fake))
y = np.ones([2*batch_size, 1])
y[batch_size:, :] = 0
d_loss = self.discriminator.train_on_batch(x, y)
y = np.ones([batch_size, 1])
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
a_loss = self.adversarial.train_on_batch(noise, y)

训练过程中需要非常耐心,这里列出一些常见问题和解决方案:

问题1:最终生成的图像噪点太多。

解决:尝试在判别器和生成器模型上引入 dropout,一般更小的 dropout 值(0.3-0.6)可以产生更逼真的图像。

问题2:判别器的损失函数迅速收敛为零,导致发生器无法训练。

解决:不要对判别器进行预训练。而是调整学习率,使判别器的学习率大于对抗模型的学习率。也可以尝试对生成器换一个不同的训练噪声样本。

问题3:生成器输出的图像仍然看起来像噪声。

解决:检查激活函数、批量归一化和 dropout 的应用流程是否正确。

问题4:如何确定正确的模型/训练参数。

解决:尝试从一些已经发表的论文或代码中找到参考,调试时每次只调整一个参数。在进行 2000 步以上的训练时,注意观察在 500 或 1000 步左右参数值调整的效果。

  输出情况

下图展示了在训练过程中,整个模型的输出变化情况。可以看到,GAN 在自己学习如何生成手写体数字。

完整代码地址:

https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/GAN/dcgan_mnist.py 

本文作者:恒亮

本文转自雷锋网禁止二次转载,原文链接

时间: 2024-08-01 19:51:25

不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)的相关文章

200 行代码实现一个简单的区块链应用

区块链的基础概念很简单:一个分布式数据库,存储一个不断加长的 list,list 中包含着许多有序的记录.然而,在通常情况下,当我们谈到区块链的时候也会谈起使用区块链来解决的问题,这两者很容易混淆.像流行的比特币和以太坊这样基于区块链的项目就是这样."区块链"这个术语通常和像交易.智能合约.加密货币这样的概念紧紧联系在一起. 这就令理解区块链变得不必要得复杂起来,特别是当你想理解源码的时候.下面我将通过 200 行 JS 实现的超级简单的区块链来帮助大家理解它,我给这段代码起名为 Na

最简单易懂的GAN(生成对抗网络)教程:从理论到实践(附代码)

  之前 GAN网络是近两年深度学习领域的新秀,火的不行,本文旨在浅显理解传统GAN,分享学习心得.现有GAN网络大多数代码实现使用Python.torch等语言,这里,后面用matlab搭建一个简单的GAN网络,便于理解GAN原理. GAN的鼻祖之作是2014年NIPS一篇文章:Generative Adversarial Net,可以细细品味. ● 分享一个目前各类GAN的一个论文整理集合 ● 再分享一个目前各类GAN的一个代码整理集合   开始 我们知道GAN的思想是是一种二人零和博弈思想

GAN 很复杂?如何用不到 50 行代码训练 GAN(基于 PyTorch)

本文作者为前谷歌高级工程师.AI 初创公司 Wavefront 创始人兼 CTO Dev Nag,介绍了他是如何用不到五十行代码,在 PyTorch 平台上完成对 GAN 的训练.雷锋网编译整理. Dev Nag 什么是 GAN? 在进入技术层面之前,为照顾新入门的开发者,雷锋网先来介绍下什么是 GAN. 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative Adversarial Nets>,这标志着生成对抗网络

实践指南!16位资深行业者教你如何学习使用TensorFlow

首发地址:https://yq.aliyun.com/articles/71257 更多深度文章,请关注:https://yq.aliyun.com/cloud 如何开始学习使用TensorFlow? 相关回答: Harrison Kinsley --PythonProgramming.net的创始人 TensorFlow官方网站有相当多的文档和教程,但这些往往认为读者掌握了一些机器学习和人工智能知识.除了知道ML和AI,你也应该对Python编程语言非常熟练.因此,在开始学习如何使用Tenso

从把三千行代码重构成15行代码谈起

从把三千行代码重构成15行代码谈起 如果你认为这是一个标题党,那么我真诚的恳请你耐心的把文章的第一部分读完,然后再下结论.如果你认为能够戳中您的G点,那么请随手点个赞. 把三千行代码重构为15行 那年我刚毕业,进了现在这个公司.公司是搞数据中心环境监控的,里面充斥着嵌入式.精密空调.总线.RFID的概念,我一个都不懂.还好,公司之前用Delphi写的老客户端因为太慢,然后就搞了个Webform的替代,恰好我对Asp.Net还算了解,我对业务的不了解并不妨碍我称成为这个公司的一个程序员.小公司也有

你的每行代码值多少钱?

我知道,"line of code"(LoC)是一种非常愚蠢的计量方式.不要急着喷我,请大家先听我讲讲我最近参与的两个项目,看一下一些非常有意思的数字. 项目#1:传统的同地协作 第一个项目是由一组程序员通过传统的同地协作来执行的.人数为20(不包括项目经理.分析人员.产品负责人.SCRUM大师等等).该项目是一个大流量的网络拍卖网站(每天有超过200万的页面访问量). 代码库的大小约为20万行,其中15万是PHP,3万5是JavaScript,其余则是CSS.XML以及Ruby等.这

【圣诞特辑】Keras+树莓派,130行代码找到圣诞老人

今天这篇文章是使用Keras在Raspberry Pi上运行深度神经网络的一个完整指南. 我把这个项目当做一个"不是圣诞老人"(Not Santa)检测器,教你如何实际地实现它(并且过程中乐趣无穷). 第一部分,我们说一下什么是"圣诞老人检测器"(可能你不熟悉热播美剧<硅谷>里的"不是热狗"识别App,现在已经有人把它实现了). 然后,我们将通过安装TensorFlow.Keras和其他一些条件来配置树莓派进行深度学习. 树莓派为深度

多少行代码才能完成下列项目?1亿行代码是神马概念?

class="post_content" itemprop="articleBody"> 多少行代码才能完成下列项目?据统计,平均一个 iPhone APP 是4万行代码,PS CS6 是5百万行,facebook 的总体项目则高达6000多万行代码-- 我们也可以从这张信息图中看到,早期的 Windows 3.1 只有200多万行代码,到了 Vista 年代则是将近5000万行,也可以看出 Windows 7 比 viata 精简了许多.

几行代码搞定一棵漂亮的树

程序名:JTree(树状控件)结合了XML的长处,使您只需几行代码就可以拥有像Windows的资源管理器一样的Treeview了. 之前,本人曾写过一个Treeview,但是,不够美观,这一版本,在外观上做了很大的改进,很漂亮.运行速度很快. 详细功能请见示例示例打包下载 JTree在onclick时,有两个值可以用: var myTree=new JTree("showTree","vogueType.xml");myTree.setPicPath("i