pytorch构建ANN pytorch搭建gan_生成模型


原论文地址:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

GitHub:https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

一、GAN 有什么用?

GAN 即 Generative Adversarial Nets,生成对抗网络,从名字上我们可以得到两个信息:

  1. 首先,它是一个生成模型
  2. 其次,它的训练是通过“对抗”完成的

何为生成模型?即,给个服从某种分布(比如正态分布)随机数,模型就可以给你生成一张人脸、一段文字 etc。它要做的,是找到一种映射关系,将随机数的分布映射成数据的分布。

何为对抗?GAN 除了包含一个生成模型 G 外还包含一个 判别模型 D ,G 输入随机数生成数据,D 输入数据输出置信度,1 表示是真实数据,0 表示为 G 伪造的数据;二者通过反复地对抗,最终理想情况下, G 生成的数据与真实数据非常接近,分布也相同,而 D 无论输出真实数据还是 G 伪造的数据都输出0.5。

二、GAN 的目标函数及流程


pytorch构建ANN pytorch搭建gan_pytorch构建ANN_02


  • max 部分的含义是,D 要尽可能正确地识别出真实数据和 G 伪造的数据。
  • min 部分的含义是,G 要尽可能缩小自己生成的数据与真实数据的差别,让 D 真假难别。

整个训练流程如图:


pytorch构建ANN pytorch搭建gan_随机数_03


在每一步的训练中:

  1. 取 m 个真实数据,使用 G 和 m 组随机数(一般使用服从正态分布的随机数)生成 m 个假数据
  2. 根据 max 部分的目标更新 D 的参数,提高 D 的分辨能力
  3. 根据 min 部分的目标更新 G 的参数,使 G 生成的数据更有迷惑性

三、GAN 的 Pytorch 实现(使用 mnist 数据集)


import


在这个实现中需要注意的一点是,原论文中 G 的训练是希望减小 log(1-D(G(z)),而代码中是使用二值交叉熵BCE(G(z), 1),即希望提高-log(D(G(x))),虽然都是希望让 D(G(x)) 趋近于1 ,但数值上还是有细微的不同。