参考网址:https://github.com/yunjey/pytorch-tutorial


gan的缺点:

其实使用过GAN的人应该知道,训练GAN有很多头疼的问题。例如:GAN的训练对超参数特别敏感,需要精心设计。GAN中关于生成模型和判别模型的迭代也很有问题,按照通常理解,如果判别模型训练地很好,应该对生成的提高有很大作用,但实际中恰恰相反,如果将判别模型训练地很充分,生成模型甚至会变差。那么问题出在哪里呢?

在ICLR 2017大会上有一篇口头报告论文提出了这个问题产生的机理和解决办法。问题就出在目标函数的设计上。这篇文章的作者证明,GAN的本质其实是优化真实样本分布和生成样本分布之间的差异,并最小化这个差异。特别需要指出的是,优化的目标函数是两个分布上的Jensen-Shannon距离,但这个距离有这样一个问题,如果两个分布的样本空间并不完全重合,这个距离是无法定义的。

作者接着证明了“真实分布与生成分布的样本空间并不完全重合”是一个极大概率事件,并证明在一些假设条件下,可以从理论层面推导出一些实际中遇到的现象。

既然知道了问题的关键所在,那么应该如何解决问题呢?该文章提出了一种解决方案:使用Wasserstein距离代替Jensen-Shannon距离。并依据Wasserstein距离设计了相应的算法,即WGAN。新的算法与原始GAN相比,参数更加不敏感,训练过程更加平滑。



一、网络上的介绍

生成对抗网络的反馈系统 生成对抗网络优缺点_生成对抗网络的反馈系统

二、DCGAN,有两个神经网络

Generator:生成器。用来生成伪造的图像。生成器的输入是gauss随机噪声,输出是一张图像,随着训练epochs的增加,伪图像会越来越像真正的图像。

Discriminator:鉴别器。是鉴别一张图像的真假。输入是一张真实的图像,输出则是{0,1} 输出为0时 鉴别器认为该图像是假图像,否则反之。

生成对抗网络的反馈系统 生成对抗网络优缺点_生成对抗网络的反馈系统_02

三、训练办法

生成对抗网络的反馈系统 生成对抗网络优缺点_生成对抗网络的反馈系统_03

四、训练网络核心代码

def train(self):
    """Train generator and discriminator."""
    fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim))
    total_step = len(self.data_loader)
    for epoch in range(self.num_epochs):
        for i, images in enumerate(self.data_loader):
            
            #===================== Train D =====================#
            images = self.to_variable(images) # 生成Variable对象
            batch_size = images.size(0)
            noise = self.to_variable(torch.randn(batch_size, self.z_dim))
            
            # Train D to recognize real images as real.
            outputs = self.discriminator(images)
            real_loss = torch.mean((outputs - 1) ** 2)      # L2 loss instead of Binary cross entropy loss (this is optional for stable training)

            # Train D to recognize fake images as fake.
            fake_images = self.generator(noise)
            outputs = self.discriminator(fake_images)
            fake_loss = torch.mean(outputs ** 2)

            # Backprop + optimize
            d_loss = real_loss + fake_loss
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()
            
            #===================== Train G =====================#
            noise = self.to_variable(torch.randn(batch_size, self.z_dim))
            
            # Train G so that D recognizes G(z) as real.
            fake_images = self.generator(noise)
            outputs = self.discriminator(fake_images)
            g_loss = torch.mean((outputs - 1) ** 2)

            # Backprop + optimize
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # print the log info
            if (i+1) % self.log_step == 0:
                print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, ' 
                      'd_fake_loss: %.4f, g_loss: %.4f' 
                      %(epoch+1, self.num_epochs, i+1, total_step, 
                        real_loss.data[0], fake_loss.data[0], g_loss.data[0]))

            # save the sampled images
            if (i+1) % self.sample_step == 0:
                fake_images = self.generator(fixed_noise)
                torchvision.utils.save_image(self.denorm(fake_images.data), 
                    os.path.join(self.sample_path,
                                 'fake_samples-%d-%d.png' %(epoch+1, i+1)))
        
        # save the model parameters for each epoch
        g_path = os.path.join(self.model_path, 'generator-%d.pkl' %(epoch+1))
        d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' %(epoch+1))
        torch.save(self.generator.state_dict(), g_path)
        torch.save(self.discriminator.state_dict(), d_path)



五、程序整体特性


给定gauss特征向量,就可以生成一张训练集类似的头像人脸。最终生成效果如下




生成对抗网络的反馈系统 生成对抗网络优缺点_生成对抗网络的反馈系统_04