原论文地址:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
GitHub:https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py
一、GAN 有什么用?
GAN 即 Generative Adversarial Nets,生成对抗网络,从名字上我们可以得到两个信息:
- 首先,它是一个生成模型
- 其次,它的训练是通过“对抗”完成的
何为生成模型?即,给个服从某种分布(比如正态分布)随机数,模型就可以给你生成一张人脸、一段文字 etc。它要做的,是找到一种映射关系,将随机数的分布映射成数据的分布。
何为对抗?GAN 除了包含一个生成模型 G 外还包含一个 判别模型 D ,G 输入随机数生成数据,D 输入数据输出置信度,1 表示是真实数据,0 表示为 G 伪造的数据;二者通过反复地对抗,最终理想情况下, G 生成的数据与真实数据非常接近,分布也相同,而 D 无论输出真实数据还是 G 伪造的数据都输出0.5。
二、GAN 的目标函数及流程
- max 部分的含义是,D 要尽可能正确地识别出真实数据和 G 伪造的数据。
- min 部分的含义是,G 要尽可能缩小自己生成的数据与真实数据的差别,让 D 真假难别。
整个训练流程如图:
在每一步的训练中:
- 取 m 个真实数据,使用 G 和 m 组随机数(一般使用服从正态分布的随机数)生成 m 个假数据
- 根据 max 部分的目标更新 D 的参数,提高 D 的分辨能力
- 根据 min 部分的目标更新 G 的参数,使 G 生成的数据更有迷惑性
三、GAN 的 Pytorch 实现(使用 mnist 数据集)
import
在这个实现中需要注意的一点是,原论文中 G 的训练是希望减小 log(1-D(G(z)),而代码中是使用二值交叉熵BCE(G(z), 1),即希望提高-log(D(G(x))),虽然都是希望让 D(G(x)) 趋近于1 ,但数值上还是有细微的不同。