只需 130 行代码,用 GAN 生成二维样本的小例子

50行GAN代码的问题

Dev Nag 写的 50 行代码的 GAN,大概是网上流传最广的,关于GAN最简单的小例子。这是一份用一维均匀样本作为特征空间(latent space)样本,经过生成网络变换后,生成高斯分布样本的代码。结构非常清晰,却有一个奇怪的问题,就是判别器(Discriminator)的输入不是2维样本,而是把整个mini-batch整体作为一个维度是batch size(代码中batch size等于cardinality)那么大的样本。也就是说判别网络要判别的不是一个一维的目标分布,而是batch size那么大维度的分布:

...
d_input_size = 100   # Minibatch size - cardinality of distributions
...
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))
...
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
...
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward()  # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

...

不知作者是疏忽了还是有意为之,总之这么做的结果就是如此简单的例子收敛都好。可能作者自己也察觉了收敛问题,就想把方差信息也放进来,于是又写了个预处理函数(decorate_with_diffs)计算出每个样本距离一批样本中心的距离平方,作为给判别网络的额外输入,其实这样还增加了输入维度。结果当然是加不加这个方差信息都能勉强收敛,但是都不稳定。甚至作者自己贴出来的生成样本分布(下图)都不令人满意:

如果直接把这份代码改成二维的,就会发现除了简单的对称分布以外,其他分布基本都无法生成。

理论上讲神经网络作为一种通用的近似函数,只要capacity够,学习多少维分布都不成问题,但是这样写法显然极大增加了收敛难度。更自然的做法应该是:判别网络只接受单个二维样本,通过batch size或是多步迭代学习分布信息。

另:这份代码其实有130行。

从自定义的二维分布采样

不管怎样Dev Nag的代码还是提供了一个用于理解和试验GAN的很好的框架,做一些修改就可以得到一份更适合直观演示,且更容易收敛的代码,也就是本文的例子。

从可视化的角度二维显然比一维更直观,所以我们采用二维样本。第一步,当然是要设定一个目标分布,作为二维的例子,分布的定义方式应该尽量自由,这个例子中我们的思路是通过灰度图像定义的概率密度,进而来产生样本,比如下面这样:

二维情况下,这种采样的一个实现方法是:求一个维度上的边缘(marginal)概率+另一维度上近似的条件概率。比如把图像中白色像素的值作为概率密度的相对大小,然后沿着x求和,然后在y轴上求出marginal probability density,接着再根据y的位置,近似得到对应x关于y的条件概率。采样的时候先采y的值,再采x的值就能近似得到符合图像描述的分布的样本。具体细节就不展开讲解了,看代码:

from functools import partial
import numpy
from skimage import transform

EPS = 1e-6
RESOLUTION = 0.001
num_grids = int(1/RESOLUTION+0.5)

def generate_lut(img):
    """
    linear approximation of CDF & marginal
    :param density_img:
    :return: lut_y, lut_x
    """
    density_img = transform.resize(img, (num_grids, num_grids))
    x_accumlation = numpy.sum(density_img, axis=1)
    sum_xy = numpy.sum(x_accumlation)
    y_cdf_of_accumulated_x = [[0., 0.]]
    accumulated = 0
    for ir, i in enumerate(range(num_grids-1, -1, -1)):
        accumulated += x_accumlation[i]
        if accumulated == 0:
            y_cdf_of_accumulated_x[0][0] = float(ir+1)/float(num_grids)
        elif EPS < accumulated < sum_xy - EPS:
            y_cdf_of_accumulated_x.append([float(ir+1)/float(num_grids), accumulated/sum_xy])
        else:
            break
    y_cdf_of_accumulated_x.append([float(ir+1)/float(num_grids), 1.])
    y_cdf_of_accumulated_x = numpy.array(y_cdf_of_accumulated_x)

    x_cdfs = []
    for j in range(num_grids):
        x_freq = density_img[num_grids-j-1]
        sum_x = numpy.sum(x_freq)
        x_cdf = [[0., 0.]]
        accumulated = 0
        for i in range(num_grids):
            accumulated += x_freq[i]
            if accumulated == 0:
                x_cdf[0][0] = float(i+1) / float(num_grids)
            elif EPS < accumulated < sum_xy - EPS:
                x_cdf.append([float(i+1)/float(num_grids), accumulated/sum_x])
            else:
                break
        x_cdf.append([float(i+1)/float(num_grids), 1.])
        if accumulated > EPS:
            x_cdf = numpy.array(x_cdf)
            x_cdfs.append(x_cdf)
        else:
            x_cdfs.append(None)

    y_lut = partial(numpy.interp, xp=y_cdf_of_accumulated_x[:, 1], fp=y_cdf_of_accumulated_x[:, 0])
    x_luts = [partial(numpy.interp, xp=x_cdfs[i][:, 1], fp=x_cdfs[i][:, 0]) if x_cdfs[i] is not None else None for i in range(num_grids)]

    return y_lut, x_luts

def sample_2d(lut, N):
    y_lut, x_luts = lut
    u_rv = numpy.random.random((N, 2))
    samples = numpy.zeros(u_rv.shape)
    for i, (x, y) in enumerate(u_rv):
        ys = y_lut(y)
        x_bin = int(ys/RESOLUTION)
        xs = x_luts[x_bin](x)
        samples[i][0] = xs
        samples[i][1] = ys

    return samples

if __name__ == '__main__':
    from skimage import io
    density_img = io.imread('batman.jpg', True)
    lut_2d = generate_lut(density_img)
    samples = sample_2d(lut_2d, 10000)

    from matplotlib import pyplot
    fig, (ax0, ax1) = pyplot.subplots(ncols=2, figsize=(9, 4))
    fig.canvas.set_window_title('Test 2D Sampling')
    ax0.imshow(density_img, cmap='gray')
    ax0.xaxis.set_major_locator(pyplot.NullLocator())
    ax0.yaxis.set_major_locator(pyplot.NullLocator())

    ax1.axis('equal')
    ax1.axis([0, 1, 0, 1])
    ax1.plot(samples[:, 0], samples[:, 1], 'k,')
    pyplot.show()

二维GAN的小例子

虽然网上到处都有,这里还是贴一下GAN的公式:

就是一个你追我赶的零和博弈,这在Dev Nag的代码里体现得很清晰:判别网络训一拨,然后生成网络训一拨,不断往复。按照上节所述,本文例子在Dev Nag代码的基础上,把判别网络每次接受一个batch作为输入的方式变成了:每次接受一个二维样本,通过每个batch的多个样本计算loss。GAN部分的训练代码如下:

DIMENSION = 2

...

generator = SimpleMLP(input_size=z_dim, hidden_size=args.g_hidden_size, output_size=DIMENSION)
discriminator = SimpleMLP(input_size=DIMENSION, hidden_size=args.d_hidden_size, output_size=1)

...

for train_iter in range(args.iterations):
    for d_index in range(args.d_steps):
        # 1. Train D on real+fake
        discriminator.zero_grad()

        #  1A: Train D on real
        real_samples = sample_2d(lut_2d, bs)
        d_real_data = Variable(torch.Tensor(real_samples))
        d_real_decision = discriminator(d_real_data)
        labels = Variable(torch.ones(bs))
        d_real_loss = criterion(d_real_decision, labels)  # ones = true

        #  1B: Train D on fake
        latent_samples = torch.randn(bs, z_dim)
        d_gen_input = Variable(latent_samples)
        d_fake_data = generator(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = discriminator(d_fake_data)
        labels = Variable(torch.zeros(bs))
        d_fake_loss = criterion(d_fake_decision, labels)  # zeros = fake

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()

        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(args.g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        generator.zero_grad()

        latent_samples = torch.randn(bs, z_dim)
        g_gen_input = Variable(latent_samples)
        g_fake_data = generator(g_gen_input)
        g_fake_decision = discriminator(g_fake_data)
        labels = Variable(torch.ones(bs))
        g_loss = criterion(g_fake_decision, labels)  # we want to fool, so pretend it's all genuine

        g_loss.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    ...

...

和Dev Nag的版本比起来除了上面提到的判别网络,和样本维度的修改,还加了可视化方便直观演示和理解,比如用一个二维高斯分布产生一个折线形状的分布,执行:

python gan_demo.py inputs/zig.jpg

训练过程的可视化如下:

更多可视化例子可以参考如下链接:

http://t.cn/Ro8aNJz

Conditional GAN

对于一些复杂的分布,原始的GAN就会很吃力,比如用一个二维高斯分布产生两坨圆形的分布:

因为latent space的分布就是一坨二维的样本,所以即使模型有很强的非线性,也难以把这个分布“切开”并变换成两个很好的圆形分布。因此在上面的动图里能看到生成的两坨样本中间总是有一些残存的样本,像是两个天体在交换物质。要改进这种情况,比较直接的想法是增加模型复杂度,或是提高latent space维度。也许模型可以学习到用其中部分维度产生一个圆形,用另一部分维度产生另一个圆形。不过我自己试了下,效果都不好。

其实这个例子人眼一看就知道是两个分布在一个图里,假设我们已经知道这个信息,那么生成依据的就是个条件概率。把这个条件加到GAN里,就是Conditional GAN,公式如下:

示意图如下:

条件信息变相降低了生成样本的难度,所以生成的样本效果好很多。

在网络中加入条件的方式没有固定的原则,这里我们采用的是可能最常见的方法:用one-hot方式将条件编码成一个向量,然后和原始的输入拼一下。注意对于判别网络和生成网络都要这么做,所以上面公式和C-GAN原文简化过度的公式比起来多了两个y,避免造成迷惑。

C-GAN的代码实现就是GAN的版本基础上,利用pytorch的torch.cat()对条件和输入进行拼接。其中条件的输入就是多张图片,每张定义一部分分布的PDF。比如对于上面两坨分布的例子,就拆成两张图像来定义PDF:

具体实现就不贴这里了,参考本文的Github页面:

http://t.cn/Ro8Svq4

加入条件信息后,两坨分布的生成就轻松搞定了,执行:

python cgan_demo.py inputs/binary

得到下面的训练过程可视化:

对于一些更复杂的分布也不在话下,比如:

这两个图案对应的原始GAN和C-GAN的训练可视化对比可以在这里看到。

应用样例

其实现在能见到的基于 GAN 的有意思应用基本都是 Conditional GAN,下篇打算介绍基于 C-GAN 的一个实(dan)用(teng)例子:

提高驾驶技术:用GAN去除(爱情)动作片中的马赛克和衣服

本文完整代码

http://t.cn/Ro8Svq4



开发者专场 | 英伟达深度学习学院现场授课

英伟达 DLI 高级工程师现场指导,理论结合实践,一举入门深度学习!

课程链接:http://www.mooc.ai/course/90

====================================分割线================================

本文作者:AI研习社

本文转自雷锋网禁止二次转载,原文链接

时间: 2024-09-17 04:16:12

只需 130 行代码,用 GAN 生成二维样本的小例子的相关文章

php简简单单搞定中英文混排字符串截取,只需2行代码!

提到中英文混排计数.截取,大家首先想到的是ascii.16进制.正则匹配.循环计数.   今天我给大家分享的是php的mb扩展,教你如何轻松处理字符串.       先给大家介绍用到的函数:   mb_strwidth($str, $encoding) 返回字符串的宽度   $str 要计算的字符串   $encoding 要使用的编码,如 utf8.gbk   mb_strimwidth($str, $start, $width, $tail, $encoding) 按宽度截取字符串   $s

只需20行代码就可以写出CSS覆盖率测试脚本_基础知识

document.styleSheets里保存了当前页面上所有CSS规则的集合.通过它可以遍历出页面<style>里定义的所有selector,访问selectorText属性可得选择器的匹配规则.然后将规则规则传递给 document.querySelectorAll 即可获取页面内匹配此规则的元素列表. 这里我们只求CSS规则的覆盖率,所以访问 querySelectorAll().length 即可.通过排序就可看出各个CSS使用情况. 代码很简单. 复制代码 代码如下: var usa

利用google api生成二维码名片例子

二维条码/二维码可以分为堆叠式/行排式二维条码和矩阵式二维条码.堆叠式/行排式二维条码形态上是由多行短截的一维条码堆叠而成:矩阵式二维条码以矩阵的形式组成,在矩阵相应元素位置上用"点"表示二进制"1",用"空"表示二进制"0","点"和"空"的排列组成代码. 堆叠式/行排式二维条码,如,Code 16K.Code 49.PDF417等. 矩阵式二维码,最流行莫过于QR CODE. 矩阵式

Symfony生成二维码的方法_php实例

本文实例讲述了Symfony生成二维码的方法.分享给大家供大家参考,具体如下: 现在网上能搜到很多关于使用PHP生成二维码的例子,主要是两种方法: 第一种:google开放api,如下: $urlToEncode="http://blog.it985.com"; generateQRfromGoogle($urlToEncode); function generateQRfromGoogle($chl, $widhtHeight = '150', $EC_level = 'L', $m

使用PHP生成二维码的方法汇总_php技巧

随着科技的进步,二维码应用领域越来越广泛,本站之前已有文章介绍通过使用jQuery插件来生成二维码,今天我给大家分享下如何使用PHP生成二维码,以及如何生成中间带LOGO图像的二维码. 利用Google API生成二维码 Google提供了较为完善的二维码生成接口,调用API接口很简单,以下是调用代码: $urlToEncode="http://www.jb51.net"; generateQRfromGoogle($urlToEncode); /** * google api 二维码

用JAVA 设计生成二维码详细教程_java

教你一步一步用 java 设计生成二维码 在物联网的时代,二维码是个很重要的东西了,现在无论什么东西都要搞个二维码标志,唯恐落伍,就差人没有用二维码识别了.也许有一天生分证或者户口本都会用二维码识别了.今天心血来潮,看见别人都为自己的博客添加了二维码,我也想搞一个测试一下. 主要用来实现两点: 1. 生成任意文字的二维码. 2. 在二维码的中间加入图像. 一.准备工作. 准备QR二维码3.0 版本的core包和一张jpg图片. 下载QR二维码包. 首先得下载 zxing.jar 包, 我这里用的

【圣诞特辑】Keras+树莓派,130行代码找到圣诞老人

今天这篇文章是使用Keras在Raspberry Pi上运行深度神经网络的一个完整指南. 我把这个项目当做一个"不是圣诞老人"(Not Santa)检测器,教你如何实际地实现它(并且过程中乐趣无穷). 第一部分,我们说一下什么是"圣诞老人检测器"(可能你不熟悉热播美剧<硅谷>里的"不是热狗"识别App,现在已经有人把它实现了). 然后,我们将通过安装TensorFlow.Keras和其他一些条件来配置树莓派进行深度学习. 树莓派为深度

.Net中生成二维表格的代码

找了很久才找到的在.NET中生成二维表格的代码,不敢独享,现在就贴出来给大家看看,相信对大家有所帮助. 代码如下: void Page_Load(object o, EventArgs e) ...{ DataTable dt = GetData(); //assume GetData returns the DataTable //probably better to use Hashtable for depts and months too, but to keep the order,

.Net中生成二维的表格的代码

找了很久才找到的在.NET中生成二维表格的代码,不敢独享,现在就贴出来给大家看看,相信对大家有所帮助.   代码如下: 复制代码 代码如下: void Page_Load(object o, EventArgs e) ...{ DataTable dt = GetData(); //assume GetData returns the DataTable //probably better to use Hashtable for depts and months too, but to kee