参考网址:https://github.com/yunjey/pytorch-tutorial
gan的缺点:
其实使用过GAN的人应该知道,训练GAN有很多头疼的问题。例如:GAN的训练对超参数特别敏感,需要精心设计。GAN中关于生成模型和判别模型的迭代也很有问题,按照通常理解,如果判别模型训练地很好,应该对生成的提高有很大作用,但实际中恰恰相反,如果将判别模型训练地很充分,生成模型甚至会变差。那么问题出在哪里呢?
在ICLR 2017大会上有一篇口头报告论文提出了这个问题产生的机理和解决办法。问题就出在目标函数的设计上。这篇文章的作者证明,GAN的本质其实是优化真实样本分布和生成样本分布之间的差异,并最小化这个差异。特别需要指出的是,优化的目标函数是两个分布上的Jensen-Shannon距离,但这个距离有这样一个问题,如果两个分布的样本空间并不完全重合,这个距离是无法定义的。
作者接着证明了“真实分布与生成分布的样本空间并不完全重合”是一个极大概率事件,并证明在一些假设条件下,可以从理论层面推导出一些实际中遇到的现象。
既然知道了问题的关键所在,那么应该如何解决问题呢?该文章提出了一种解决方案:使用Wasserstein距离代替Jensen-Shannon距离。并依据Wasserstein距离设计了相应的算法,即WGAN。新的算法与原始GAN相比,参数更加不敏感,训练过程更加平滑。
一、网络上的介绍
二、DCGAN,有两个神经网络
Generator:生成器。用来生成伪造的图像。生成器的输入是gauss随机噪声,输出是一张图像,随着训练epochs的增加,伪图像会越来越像真正的图像。
Discriminator:鉴别器。是鉴别一张图像的真假。输入是一张真实的图像,输出则是{0,1} 输出为0时 鉴别器认为该图像是假图像,否则反之。
三、训练办法
四、训练网络核心代码
def train(self):
"""Train generator and discriminator."""
fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim))
total_step = len(self.data_loader)
for epoch in range(self.num_epochs):
for i, images in enumerate(self.data_loader):
#===================== Train D =====================#
images = self.to_variable(images) # 生成Variable对象
batch_size = images.size(0)
noise = self.to_variable(torch.randn(batch_size, self.z_dim))
# Train D to recognize real images as real.
outputs = self.discriminator(images)
real_loss = torch.mean((outputs - 1) ** 2) # L2 loss instead of Binary cross entropy loss (this is optional for stable training)
# Train D to recognize fake images as fake.
fake_images = self.generator(noise)
outputs = self.discriminator(fake_images)
fake_loss = torch.mean(outputs ** 2)
# Backprop + optimize
d_loss = real_loss + fake_loss
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
#===================== Train G =====================#
noise = self.to_variable(torch.randn(batch_size, self.z_dim))
# Train G so that D recognizes G(z) as real.
fake_images = self.generator(noise)
outputs = self.discriminator(fake_images)
g_loss = torch.mean((outputs - 1) ** 2)
# Backprop + optimize
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()
# print the log info
if (i+1) % self.log_step == 0:
print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '
'd_fake_loss: %.4f, g_loss: %.4f'
%(epoch+1, self.num_epochs, i+1, total_step,
real_loss.data[0], fake_loss.data[0], g_loss.data[0]))
# save the sampled images
if (i+1) % self.sample_step == 0:
fake_images = self.generator(fixed_noise)
torchvision.utils.save_image(self.denorm(fake_images.data),
os.path.join(self.sample_path,
'fake_samples-%d-%d.png' %(epoch+1, i+1)))
# save the model parameters for each epoch
g_path = os.path.join(self.model_path, 'generator-%d.pkl' %(epoch+1))
d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' %(epoch+1))
torch.save(self.generator.state_dict(), g_path)
torch.save(self.discriminator.state_dict(), d_path)
五、程序整体特性
给定gauss特征向量,就可以生成一张训练集类似的头像人脸。最终生成效果如下