GANs

生成对抗网络(Generative Adversarial Networks, GANs) 是一种用于捕获训练数据的**分布(distribution)**的神经网络。通过学习到的分布,可以创造新的数据。GAN由两个部分组成:

  • 生成器(generator):用基于ga的神经网络参数调整 gans神经网络_数据表示,输入是(一般为正态分布采样的)随机噪声基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_02,输出是和训练数据等大的“fake”数据;
  • 判别器(discriminator):用基于ga的神经网络参数调整 gans神经网络_损失函数_03表示,用来判断输入数据基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_04是否为真正的训练数据,输出是一个基于ga的神经网络参数调整 gans神经网络_损失函数_05区间的标量,输出值越大表示基于ga的神经网络参数调整 gans神经网络_数据_06判定基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_04更可能是真实训练数据,越小则更可能是假的数据。

基于ga的神经网络参数调整 gans神经网络_损失函数_08

生成器和判别器之间是一个零和博弈的过程:生成器的效果越好,则判别器的正确率越低,反之亦然。在训练过程中,生成器的目的是“让判别器判断错误”,因此会生成越来越接近真实训练数据的假数据,这个过程也是学习训练数据分布的过程;判别器的目的则是“更好地区分真实数据和假数据”,因此学习过程会提高它的判别能力。

以上是一个感性的认知,在神经网络学习的框架下,需要定义一个具体的损失函数(loss),来对生成器和判别器的参数进行梯度更新。定义损失函数:
基于ga的神经网络参数调整 gans神经网络_生成器_09
式中的基于ga的神经网络参数调整 gans神经网络_数据_10取自训练数据,基于ga的神经网络参数调整 gans神经网络_数据_11为生成器基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_12的随机噪声(一般满足正态分布)输入,基于ga的神经网络参数调整 gans神经网络_损失函数_13表示期望。

对于判别器基于ga的神经网络参数调整 gans神经网络_损失函数_14,我们的目的是:真实数据基于ga的神经网络参数调整 gans神经网络_生成器_15基于ga的神经网络参数调整 gans神经网络_数据_16尽量更大;假数据基于ga的神经网络参数调整 gans神经网络_生成器_17基于ga的神经网络参数调整 gans神经网络_数据_16尽量更小,亦即基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_19尽量更大。因此判别器的训练过程的损失函数为基于ga的神经网络参数调整 gans神经网络_数据_20

对于生成器基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_12,我们的目的是:真实数据基于ga的神经网络参数调整 gans神经网络_生成器_15,与基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_12无关,可以当做一个常数;假数据基于ga的神经网络参数调整 gans神经网络_生成器_17,希望判别器“认为它是真实数据”,也就是希望基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_19尽量更小。因此生成器的训练过程的损失函数为基于ga的神经网络参数调整 gans神经网络_损失函数_26

综上,GAN的训练本质上是一种极大极小博弈(minimax game):
基于ga的神经网络参数调整 gans神经网络_数据_27
理论上,训练最终会收敛于基于ga的神经网络参数调整 gans神经网络_生成器_28,也就是生成器学习的概率分布与训练数据的一致,而判别器的输出等价于随机判定真假。但实际上GAN的训练过程很不稳定。



一个例子

使用CelebA人脸数据集训练一个DCGAN。所谓DCGAN,就是生成器和判别器都是卷积神经网络的GAN。训练后的生成器可以用随机噪声生成和训练集相似的人脸图像。

训练使用二进制交叉熵损失函数(BCELoss):
基于ga的神经网络参数调整 gans神经网络_数据_29
这和之前讲到的基于ga的神经网络参数调整 gans神经网络_损失函数_26非常相似,通过合理地选择标签基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_31就可以等价的表示基于ga的神经网络参数调整 gans神经网络_损失函数_26。训练过程的一个iter的过程如下(关键步骤):

  1. 输入数据基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_33从训练图像数据中得到,为真实数据,取标签基于ga的神经网络参数调整 gans神经网络_生成器_34,前向传播(forward)得到基于ga的神经网络参数调整 gans神经网络_数据_35,得到损失基于ga的神经网络参数调整 gans神经网络_生成器_36
  2. 输入数据基于ga的神经网络参数调整 gans神经网络_生成器_37,为假数据,取标签基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_38,前向传播得到基于ga的神经网络参数调整 gans神经网络_生成器_39,得到损失基于ga的神经网络参数调整 gans神经网络_损失函数_40; (上面的基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_41为单个图像的输出,基于ga的神经网络参数调整 gans神经网络_数据_42为一个batch的输出。
  3. 从1和2我们得到了判别器基于ga的神经网络参数调整 gans神经网络_数据_06的损失基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_44,对判别器进行参数更新;
  4. 更新后的判别器重新对基于ga的神经网络参数调整 gans神经网络_生成器_45进行判定,得到输出基于ga的神经网络参数调整 gans神经网络_损失函数_46,取标签基于ga的神经网络参数调整 gans神经网络_生成器_34,得到损失基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_48,前面我们说了生成器的目的是最小化基于ga的神经网络参数调整 gans神经网络_损失函数_49,而基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_50这一项为常数项,因此生成器的目的等价于最小化基于ga的神经网络参数调整 gans神经网络_生成器_51,又等价于最小化基于ga的神经网络参数调整 gans神经网络_基于ga的神经网络参数调整_52,这正好对应基于ga的神经网络参数调整 gans神经网络_生成器_53
  5. 从4我们得到了生成器基于ga的神经网络参数调整 gans神经网络_损失函数_54的损失基于ga的神经网络参数调整 gans神经网络_生成器_53,对生成器进行参数更新。

训练后的模型的生成器的输出和原始图片数据的对比:

基于ga的神经网络参数调整 gans神经网络_生成器_56

代码这里



参考资料

  1. DCGAN TUTORIAL
  2. 网络资料。