本文是ICLR2019的一篇文章,对MMD-GAN进行了优化,下面是对这篇文章的阅读笔记。

生成式对抗性网络(GANs)被广泛用于学习数据采样过程,在有限的计算预算下,其性能在很大程度上取决于损失函数。本研究修正了以最大均值误差(MMD)作为GAN损失函数的MMD-GAN,主要有两大贡献:首先,作者认为,现有的MMD损失函数可能阻碍学习数据中的细节,因为它试图缩小真实数据的鉴别输出。为了解决这一问题,提出了一个排斥损失函数,通过简单地重新排列MMD中的项来主动学习实际数据之间的差异。其次,受Hinge loss的启发,提出了一种有界高斯核来稳定具有排斥损失函数的MMD-GAN的训练。

GAN:GAN的理念是联合训练一个试图生成人工样本的生成器网络,以及一个区分生成的样本和真实样本的鉴别器网络。与基于极大似然的方法相比,GAN生成的样本具有更清晰、更生动的细节,但同时也需要更多的训练。

本文主要研究一种称为最大平均偏差(MMD)的损失函数,即两个概率分布之间的距离度量。MMD的基本假设是:如果对于所有以分布生成的样本空间为输入的函数f,如果两个分布生成的足够多的样本在f上的对应的像的均值都相等,那么可以认为这两个分布是同一个分布。现在一般用于度量两个分布之间的相似性。理论上,当且仅当两种分布相等时,MMD达到全局最小值为零。在这篇文章中,我们把MMD损失函数的优化过程解释为吸引和排斥过程的结合,类似于线性判别分析(LDA)。我们认为,现有的MMD损失可能阻碍学习数据中的细节,因为甄别者试图将其输出的真实数据的组内方差最小化,换言之,MMD方法的鉴别器更注重于学习源域数据上的共性,而忽视了其细节上的差异。我们提出了一个斥性损失的鉴别器,显式地探索不同的真实数据分布。在此基础上,提出了一种有界高斯核函数来稳定鉴别器的训练。在训练过程中使用单一核函数就足够了,正因如此大大降低了计算过程中的时间损耗。

MMD-GAN:MMD使用特征核 来实现核嵌入。在源域P与目标域Q之间的MMD平方损失函数为 ,核k(a,b)衡量的是两个样例a与b之间的相似程度。MMD-GAN里,鉴别器D被视为一个新内核k_D(a,b)=k(D(a),D(b)),如果D是单射的,k_D就是每一个值一一对应的,而M_(k_D )(P1,P2)=0,只有在P1=P2时才会取到。这样G与D的目标函数就是

条件生成对抗网络 噪音_损失函数


MMD损失函数相较于原先的JS散度这一约束函数而言是一个更弱的一个约束,因为JS散度实际上就隐含了MMD损失函数,正因如此,MMD损失提供了更多关于调整模型以适应数据分布的信息。在本节中,我们将MMD-GAN的训练解释为吸引和排斥过程的结合,并通过重新排列中的成分,为鉴别器提出了一个新的排斥损失函数。排斥损失函数:在这一部分把MMD-GAN的训练过程归结为吸引与排斥两个阶段。并为鉴别器引入了一个全新的损失函数L_drep来为L_datt的组成部分重新组合。首先,使用一个LDA模型作为辨别器,鉴别器的目标是找到一个变量w来最大化类间差异

条件生成对抗网络 噪音_机器学习_02

并最小化类内个体之间的差异

条件生成对抗网络 噪音_损失函数_03

条件生成对抗网络 噪音_GAN_04


条件生成对抗网络 噪音_损失函数_05

分别是类平均值与协方差。注意这里的类指的只有两类,源域类与目标域类。

在MMD_GAN里,神经网络鉴别器以与LDA类似的方法来实现域鉴别。通过最小化

条件生成对抗网络 噪音_损失函数_06

,域鉴别器D实现了两个任务:1.D减小了

条件生成对抗网络 噪音_机器学习_07

,使源域样例与目标域样例相互背弃,换言之,使两类的协方差变大。2.D增加了

条件生成对抗网络 噪音_机器学习_08


条件生成对抗网络 噪音_损失函数_09

,使源域与目标域的样例类内间距减小,换言之,最小化域内协方差。

但这种损失函数会拖慢GAN的训练过程,因为:1.鉴别器D更倾向于找到真实样例里的相似性而不是寻找能分离它们的细节。最初,生成器产生的生成低仿真度的图像就能够使D鉴别出它们之间的差异。而只有源域分布与目标域分布足够接近时才可能学到这些细节,就会导致最后生成器学不到真实样例的细节。上述两个损失的梯度在训练时可能相反,导致总和接近于0,使训练过程中,可能出现两个域间差异很大,但是由于梯度接近0,而很难在继续训练的情况。为了解决上述问题,为鉴别器提出了一个排斥损失函数,来鼓励减低真实数据上得分。

条件生成对抗网络 噪音_机器学习_10


生成器还是用之前MMD-GAN的损失函数进行约束。这样通过最小化

条件生成对抗网络 噪音_机器学习_11

来制约D(y)得分而生成器不断拟合分布来使D(y)得分不断提高。这样就可以积极拟合真实样本中的细节,并能使G梯度更利于训练。最后,全新的鉴别器损失函数就是

条件生成对抗网络 噪音_条件生成对抗网络 噪音_12


λ在这里是一个超参数,当<0时,使整个式子是吸引属性的,>0时是排斥属性的,当>1时相当于一个正则项,防止D(x)的成对距离过于远离感兴趣区域。

MMD方法与鉴别器的正则化方法

这一部分就是提出两个方法是MMD-GAN的训练更加稳定。1.一个有界核,以避免由过于置信的鉴频器引起的饱和问题,2. 本文提出了一种广义幂次迭代法来估计卷积核的谱范数。

MMD-GAN用了两个核函数:1.高斯径向基核(RBF)

条件生成对抗网络 噪音_MMD_13

,2.

条件生成对抗网络 噪音_损失函数_14

,之前的研究都使用了五种不同核尺度的核的线性组合。目的是为了减少单一核带来的过拟合问题。但是之前均只是在参数训练时采用正则化防止过拟合,这里提出一个有界RBF核函数来约束鉴别器,来防止鉴别器本身带来的过拟合。

条件生成对抗网络 噪音_机器学习_15


条件生成对抗网络 噪音_机器学习_16

分别是吸引与排斥两种情况的有界RBF核的表达式。RBF-B内核仅用于鉴别器,以防止它过于置信。谱归一化是利用每一层的谱范数来表示每一层的权矩阵,该权矩阵对每一层的输出量和梯度有一个上界。本文提出了一种直接估计卷积核谱范数的广义幂迭代方法,并将谱归一化方法应用到滤波器的鉴频器中。来保训练过程稳定。之后就是实验,来证明这一方法的有效性。