文章目录
- 论文地址
- 希望路过的学霸们可以看看最后一句话
- 一、基本概念
- 二、GAN的原理
- 1、Generator(生成器)
- 2、Discriminator(判别器)
- 3、强强联合
- 4、算法
- 三、公式推导
- 1、KL散度、JS散度
- 2、求解Generator
- 3、求解Discriminator
- 4、整个训练的具体步骤
- 四、存在的缺点与问题
- 五、GAN的改进
- 1、Conditional GAN
- 2、Deep Convolutional GAN
- 3、DRAGAN(On Convergence and Stability of GANs)
- 4、Cycle Gan
论文地址
https://arxiv.org/abs/1406.2661
希望路过的学霸们可以看看最后一句话
一、基本概念
生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出。模型通过框架中两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。生成对抗网络被认为是当前最具前景、最具活跃度的模型之一,目前主要应用于样本数据生成、图像生成、图像修复、图像转换、文本生成等方向。
二、GAN的原理
1、Generator(生成器)
随机输入高斯噪声(注意这个噪声一定要是来自简单的分布,容易采样的,例如normal distribution),生成图片(概率分布 distribution)。
我们将生成的图片的数据分布叫做Pg,真实图片的数据分布叫做Pdata。
我们的目的:通过Generator生成高质量的图片。也就是说,通过Generator生成的数据分布,与真实图片的数据分布,越接近越好
2、Discriminator(判别器)
输入图片,给图片打分,输出为一个Scalar(数字)。
如果图片来自真实图片,给高分;来自生成的图片,给低分。
3、强强联合
在GAN中,我们就是在不断的训练G和D,让G生成的图片的分布尽可能逼近真实图片的分布。
在训练过程中:
D学着去分辨真实的图片和G生成的图片,企图给真实的图片打高分,给G生成的图片打低分。
G学着去骗过D,企图让D分辨不出来图片是来自真实的数据集还是G生成的。达到“以假乱真”的效果。
总结:G其实就是在学会做高仿品。D也在不断变强,增强自己辨真假的能力。这就是“对抗”,毕竟,只有足够强大的对手才能逼出更强大的自己。
小例子:
这个例子来自李宏毅老师的机器学习课程,链接在这(b站):
李宏毅2021/2022春机器学习课程—生成式对抗网络(GAN) 20分20秒处
4、算法
在每轮的迭代过程中:
Step1:固定Generator,更新Discriminator;
Step2:固定Discriminator,更新Generator;
整个迭代步骤:
终极目标:通过G生成D难以分辨的图片,D打分为0.5,因为它不知道是真是假。
三、公式推导
刚才说到,我们希望通过G生成的分布与真实图片的分布,越接近越好。
那么,如何衡量两个分布之间的相似程度(或者说差异性)呢?
巧了,在数学中有一些东西,可以做到这件事情,它叫做KL散度(JS散度等等也可以,不止这一种)。
1、KL散度、JS散度
KL散度(KL Divergence):也称相对熵、KL距离。对于两个概率分布P和Q之间的差异性(也可以简单理解成相似性),二者越相似,KL散度越小。
关于KL散度具体的内容,推导,可以看看b站王木头讲的,个人觉得讲的很清楚了,链接如下(其实我觉得弄懂了好像也没啥用,可能是我太菜了):
“交叉熵”如何做损失函数?打包理解“信息量”、“比特”、“熵”、“KL散度”、“交叉熵”
这里,只列出KL散度的计算公式(离散随机变量):
JS散度(JS Divergence):JS散度是KL散度的一种变体,与KL散度相似,P和Q越相似,JS散度越小。
与KL散度相比,JS散度具有对称性,也就是说,P和Q可以交换位置,而KL散度不具有这个特点。
2、求解Generator
其实现在,想必大家也猜到了G的优化公式。那就是,最小化Pg和Pdata的KL散度。
其实这个divergence就是Generator中的loss function。
Generator也是一个神经网络,我们就是在找出一组权值(weight)和阈值(bias),使得Pg和Pdata之间的KL散度,越小越好。
但是,通过KL散度的公式,我们会发现,要求解KL散度,必须要知道Pg和Pdata的分布,你要事先知道他们俩长什么样,所以,我们很难直接算出Pg和Pdata之间的KL散度。那么,该如何得到KL散度呢?
这就是GAN神奇的地方,它告诉你说,你不需要知道Pg和Pdata的公式长什么样子,就可以计算KL散度,这就要依靠Discriminator的力量。
3、求解Discriminator
GAN提出的对于D的优化是这样的:
我们希望给来自Pdata的样本高分,给来自Pg的样本低分,所以就是要最大化这个目标方程。
GAN神奇的对方就在于,这个最大化的目标方程,是和JS散度(KL散度也可以)有关的。
接下来就来求解一下最优的D,并求解这个V(G,D):
注意:上文说过,这里是固定G,求解D
D* 就是使得目标方程 V(G,D) 最大的D
我们把 D* 带入到 V(G,D) 中
可以看到,求解得到的 max V(G,D) 与JS散度有关,那么,求解 max V(G,D) ,其实就是在求Pg和Pdata的JS散度,G的优化函数就可以写成如下形式:
也就得到了,论文中提到的那个方程:
4、整个训练的具体步骤
注:
(1)这里用采样后求均值的方法代替分布的期望
(2)更新G的时候,不能更新太多。我们更新G的时候,是固定了D,那么如果G更新过多的话, V(G,D) 可能会发生很大变化,以至于当前的D可能已经不是使 V(G,D) 达到最大值的D,那么这时候我们通过梯度下降减小的也就不是JS散度了。所以,这里实际上是假设了,梯度下降的每一步更新后,使得 V(G,D) 达到最大的D是基本没变的。如下图所示:
四、存在的缺点与问题
(1)没有显示地表达Pg
(2)D必须与G同步训练,且G不能更新太多。最终需要达到纳什平衡,但是有时候是做不到的,所以训练起来有时候是不稳定的,且生成器和图像质量之间缺乏相关性。
(3)模式崩溃问题。就是说G会偷懒,它发现有一种方法可以无限次骗过D,那它就会得寸进尺,每次都用这种方法来骗D,换到我们这个场景里面来说就是,生成的图片比较单一,缺乏多样性。就像下图所示:
生成的样本全部聚集在左边的峰下,这时虽然生成样本的质量比较高,但是生成器完全没有捕捉到右边的峰的模式。(如果使用多种猫的图像训练GAN,最终GAN只能产生逼真的英短,而无法产生其他品种)。
五、GAN的改进
1、Conditional GAN
论文地址:https://arxiv.org/abs/1411.1784 在原始的GAN基础上加了条件y,分别加到G和D中。
2、Deep Convolutional GAN
论文地址:https://arxiv.org/abs/1511.06434 将CNN加入到GAN中。
3、DRAGAN(On Convergence and Stability of GANs)
论文地址:https://arxiv.org/abs/1705.07215 用梯度惩罚方案解决了模式崩溃问题。
4、Cycle Gan
论文地址:https://arxiv.org/abs/1703.10593 图像转换。
还有很多GAN,我也没看几个,就把最原始的GAN仔细的从头到尾看了下,希望这篇可以对大家有所帮助。
接下来准备看Diffusion Model,希望同学们可以多多帮助,分享下学习经验,paper,视频等等,谢谢了!研一菜鸟一枚。