1. 概述

生成对抗网络GAN(Generative adversarial nets)[1]是由Goodfellow等人于2014年提出的基于深度学习模型的生成框架,可用于多种生成任务。从名称也不难看出,在GAN中包括了两个部分,分别为”生成”和“对抗”,整两个部分也分别对应了两个网络,即生成网络(Generator)生成对抗网络GAN_生成模型和判别网络(Discriminator)生成对抗网络GAN_计算机视觉_02,为描述简单,以图像生成为例:

  • 生成网络(Generator)生成对抗网络GAN_计算机视觉_03用于生成图片,其输入是一个随机的噪声生成对抗网络GAN_计算机视觉_04,通过这个噪声生成图片,记作生成对抗网络GAN_计算机视觉_05
  • 判别网络(Discriminator)生成对抗网络GAN_生成对抗网络_06用于判别一张图片是否是真实的,对应的,其输入是一整图片生成对抗网络GAN_生成模型_07,输出生成对抗网络GAN_深度学习_08表示的是图片生成对抗网络GAN_生成模型_07为真实图片的概率

在GAN框架的训练过程中,希望生成网络生成对抗网络GAN_生成模型生成的图片尽量真实,能够欺骗过判别网络生成对抗网络GAN_计算机视觉_02;而希望判别网络生成对抗网络GAN_计算机视觉_02能够把生成对抗网络GAN_生成模型生成的图片从真实图片中区分开。这样的一个过程就构成了一个动态的“博弈”。最终,GAN希望能够使得训练好的生成网络生成对抗网络GAN_生成模型生成的图片能够以假乱真,即对于判别网络生成对抗网络GAN_计算机视觉_02来说,无法判断生成对抗网络GAN_生成模型生成的网络是不是真实的。

综上,训练好的生成网络生成对抗网络GAN_生成模型便可以用于生成“以假乱真”的图片。

2. 算法原理

2.1. GAN的框架结构

GAN的框架是由生成网络生成对抗网络GAN_生成模型和判别网络生成对抗网络GAN_计算机视觉_02这两种网络结构组成,通过两种网络的“对抗”过程完成两个网络的训练,GAN框架由下图所示:

生成对抗网络GAN_计算机视觉_20

由生成网络生成对抗网络GAN_生成模型生成一张“Fake image”,判别网络生成对抗网络GAN_计算机视觉_02判断这张图片是否来自真实图片。

2.2. GAN框架的训练过程

在GAN的训练过程中,其最终的目标是使得训练出来的生成模型生成对抗网络GAN_生成模型生成的图片与真实图片具有相同的分布,其过程可通过下图描述[2]:

生成对抗网络GAN_深度学习_24

假设有一个先验分布生成对抗网络GAN_生成模型_25,如上图中的unit gaussian,通过采样得到其中的一个样本点生成对抗网络GAN_深度学习_26。对于真实的图片,事先对于其分布是未知的,即上图中的生成对抗网络GAN_生成对抗网络_27未知。为了使得能与真实图片具有相同的分布,通过一个生成模型将先验分布映射到另一个分布,生成模型记为生成对抗网络GAN_计算机视觉_28,其中生成对抗网络GAN_生成模型_29为生成模型的参数,这里的生成模型可以是一个前馈神经网络MLP,生成对抗网络GAN_生成模型_29便为该神经网络的参数。通过多次的采样,便可以刻画出生成的分布生成对抗网络GAN_生成对抗网络_31,此时需要计算其与真实的分布生成对抗网络GAN_生成对抗网络_27之间的相关性,即需要一个判别模型来定量表示两个分布之间的相关性,这里可以通过另一个前馈神经网络MLP,判别模型记为生成对抗网络GAN_深度学习_33,其中生成对抗网络GAN_深度学习_33的输出是一个标量,表示的是生成对抗网络GAN_计算机视觉_35来自真实的分布,而不是来自于生成模型构造出的分布的概率。

对于这样的一个过程中,有两个模型,分别为生成模型生成对抗网络GAN_计算机视觉_28和判别模型生成对抗网络GAN_深度学习_33,在GAN中,生成模型和判断模型分别对应了一个神经网络,以下都称为生成网络和判别模型。GAN希望的是对于判别网络,其能够正确判定数据是否来自真实的分布,对于生成网络,其能够尽可能使得生成的数据能够“以假乱真”,使得判别网络分辨不了。这样的训练过程是一个动态的“博弈”过程,通过交替训练,最终使得生成网络生成对抗网络GAN_生成模型生成的图片能够“以假乱真”,其具体过程如下图所示:

生成对抗网络GAN_计算机视觉_39

如上图(a)中,黑色的虚线表示的是从真实的分布生成对抗网络GAN_数据_40,绿色的实线表示的是需要训练的生成网络的生成的分布生成对抗网络GAN_深度学习_41,蓝色的虚线表示的是判别网络,最下面的横线生成对抗网络GAN_深度学习_26表示的是从一个先验分布(如图中是一个均匀分布)采样得到的数据点,中间的横线生成对抗网络GAN_计算机视觉_35表示的真实分布,两条横线之间的对应关系表示的是生成网络将先验分布映射成一个生成分布生成对抗网络GAN_深度学习_41。从图(a)到图(d)表示了一个完整的交替训练过程,首先,如图(a)所示,当通过先验分布采样后的数据经过生成网络生成对抗网络GAN_生成模型映射后得到了图上绿色的实线代表的分布,此时判别网络生成对抗网络GAN_计算机视觉_02并不能区分数据是否来自真实数据,通过对判别网络的训练,其能够正确地判断生成的数据是否来自真实数据,如图(b)所示;此时更新生成网络生成对抗网络GAN_生成模型,通过对先验分布重新映射到新的生成分布上,如图(c)中的绿色实线所示。依次交替完成上述步骤,当达到一定迭代的代数后,达到一个平衡状态,此时生成对抗网络GAN_数据_48,判别网络生成对抗网络GAN_计算机视觉_02将不能区分图片是否来自真实分布,且生成对抗网络GAN_数据_50

2.3. 价值函数

对于GAN框架,其价值函数生成对抗网络GAN_生成模型_51为:

生成对抗网络GAN_计算机视觉_52

其中,生成对抗网络GAN_数据_53表示的是生成对抗网络GAN_数据_54的期望,同理,生成对抗网络GAN_计算机视觉_55表示的是生成对抗网络GAN_生成对抗网络_56的期望。

假设从真实数据中采样生成对抗网络GAN_计算机视觉_57个样本生成对抗网络GAN_生成对抗网络_58,从噪音分布生成对抗网络GAN_生成对抗网络_59中同样采样生成对抗网络GAN_计算机视觉_57个样本,记为生成对抗网络GAN_深度学习_61,此时,上述价值函数可以近似表示为:

生成对抗网络GAN_计算机视觉_62

简化后为:

生成对抗网络GAN_数据_63

上述的交替训练过程如下流程所示:

生成对抗网络GAN_生成模型_64

3. GAN背后的数学原理

为了能够从数学的角度对上述过程做分析,首先对问题进行数学的描述:假设真实的数据分布为生成对抗网络GAN_计算机视觉_65,生成网络得到的分布为生成对抗网络GAN_深度学习_66,其中生成对抗网络GAN_数据_67为生成网络的参数,现在需要找到一个生成对抗网络GAN_数据_68使得生成对抗网络GAN_深度学习_69

3.1. 为什么会有这样的价值函数

由上可知,当生成网络生成对抗网络GAN_生成模型确定后,GAN的价值函数可以近似为:

生成对抗网络GAN_生成对抗网络_71

其来源可以追溯到二分类的损失函数,对于一个二分类来说,通常选择交叉墒作为其损失函数,交叉墒的一般形式为:

生成对抗网络GAN_计算机视觉_72

其中,生成对抗网络GAN_数据_73表示的是真实的样本标签,生成对抗网络GAN_计算机视觉_74表示的是模型的预测值。对于GAN来说,样本分为两个部分,一个是来自真实的样本生成对抗网络GAN_生成模型_75,将其带入到交叉墒的公式中(去除交叉墒的负号)为:

生成对抗网络GAN_深度学习_76

另一个则是来时生成模型生成对抗网络GAN_计算机视觉_77,将其带入到交叉墒的公式中(去除交叉墒的负号)为:

生成对抗网络GAN_生成模型_78

将两部分合并在一起,便是上述的价值函数。

3.2. KL散度

需要刻画两个分布是否相似,需要用到KL散度(KL divergence)。KL散度是统计学中的一个基本概念,用于衡量两个分布的相似程度,数值越小,表示两种概率分布越接近。对于离散的概率分布,定义如下:

生成对抗网络GAN_数据_79

对于连续的概率分布,定义如下:

生成对抗网络GAN_生成模型_80

3.3. 极大似然估计

极大似然估计(Maximum Likelihood Estimation),是一种概率论在统计学的应用,它是参数估计的方法之一。上述需要求解生成分布生成对抗网络GAN_深度学习_66中的参数生成对抗网络GAN_数据_67,需要用到极大似然估计。根据极大似然估计的方式,由于最终是希望生成的分布生成对抗网络GAN_深度学习_66与原始的真实分布生成对抗网络GAN_计算机视觉_65,首先从真实分布生成对抗网络GAN_计算机视觉_65采样生成对抗网络GAN_计算机视觉_57个数据点,记为生成对抗网络GAN_生成对抗网络_58,根据生成的分布,得到似然函数为:

生成对抗网络GAN_数据_88

取log后,得到等价的log似然:

生成对抗网络GAN_计算机视觉_89

此时,生成对抗网络GAN_数据_68为:

生成对抗网络GAN_生成模型_91

对上述的公式做一些修改,增加一个与生成对抗网络GAN_数据_67无关的项生成对抗网络GAN_计算机视觉_93,这样并不改变对生成对抗网络GAN_数据_68的求解,此时,公式变为:

生成对抗网络GAN_数据_95

将最大值求解变成最小值为:

生成对抗网络GAN_计算机视觉_96

通过积分公式的合并,得到:

生成对抗网络GAN_数据_97

由KL散度可知,上述可以表示为:

生成对抗网络GAN_深度学习_98

由此可以看出最小化KL散度等价于最大化似然函数。

3.4. 收敛性分析

当生成网络生成对抗网络GAN_生成模型确定后,价值函数可以表示为:

生成对抗网络GAN_计算机视觉_100

由于上述的积分与生成对抗网络GAN_计算机视觉_02无关,上述可以简化成求解:

生成对抗网络GAN_深度学习_102

求导数并令其为生成对抗网络GAN_深度学习_103,便可以得到最大的生成对抗网络GAN_计算机视觉_02:

生成对抗网络GAN_生成对抗网络_105

生成对抗网络GAN_数据_106,将其带入到价值函数中,可得

生成对抗网络GAN_数据_107

对上式简化,可得:

生成对抗网络GAN_生成对抗网络_108

通过对分子和分母分别除生成对抗网络GAN_数据_109,可得:

生成对抗网络GAN_计算机视觉_110

这里引入另一个符号:JS散度(Jensen-Shannon Divergence)

生成对抗网络GAN_生成对抗网络_111

其中,生成对抗网络GAN_计算机视觉_112。因此生成对抗网络GAN_计算机视觉_113可以表示为:

生成对抗网络GAN_生成对抗网络_114

已知JS散度是一个非负值,且值域为生成对抗网络GAN_生成对抗网络_115,当两个分布相同时取生成对抗网络GAN_深度学习_103,不同时取生成对抗网络GAN_生成对抗网络_117。对于生成对抗网络GAN_计算机视觉_113的最小值为当生成对抗网络GAN_数据_119时,即最小值是生成对抗网络GAN_数据_120。此时生成对抗网络GAN_深度学习_121,此时求得的生成网络生成对抗网络GAN_生成模型生成的数据分布与真实的数据分布差异性最小,即GAN所要求的目标:生成对抗网络GAN_数据_48

3. 总结

生成对抗网络GAN中通过生成网络生成对抗网络GAN_生成模型和判别网络生成对抗网络GAN_计算机视觉_02之间的“生成”和“对抗”过程,通过多次的迭代,最终达到平衡,使得训练出来的生成网络生成对抗网络GAN_生成模型能够生成“以假乱真”的数据,判别网络生成对抗网络GAN_计算机视觉_02不能将其从真实数据中区分开。

参考文献

[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[J]. Advances in neural information processing systems, 2014, 27.

[2] Generative Models

[2] PyTorch 学习笔记(十):初识生成对抗网络(GANs)

[3] 通俗理解生成对抗网络GAN