1. 生成对抗网络的概念与公式1

我们每次看生成对抗的公式,都会出现一个疑问,每次看懂之后,过一段时间遇到,还是得看半天,md,这一次记录下来!

1.1 判别网络

生成对抗网络loss图 生成对抗网络loss曲线_生成对抗网络loss图


疑惑1: 13.30式,作者说判别网络的目标函数是最小化交叉熵。

我对于交叉熵的第一印象来源于相对熵,相对熵越小,两个分布的差异越小,交叉熵与相对熵差一个常数,所以交叉熵越小,两个分布的差异越小。

所以这里作者说判别网络的目标函数是最小化交叉熵,我第一反应是让生成的和真实的差异越小。我擦,判别网络的目的不是让生成和真实的差异越大吗?

解疑1:

  1. 首先判别网络求loss时,GT不是真实的图像。ok?
    对于输入生成的数据,GT是假,即GT是一个全0的矩阵
    对于输入真实的数据,GT是真,即GT是一个全1的矩阵
  2. 由上述分析,判别器求loss时,其实是让生成对抗网络loss图 生成对抗网络loss曲线_生成_02生成对抗网络loss图 生成对抗网络loss曲线_对抗_03,接近,让生成对抗网络loss图 生成对抗网络loss曲线_pytorch_04生成对抗网络loss图 生成对抗网络loss曲线_原理_05接近。通过这样的方式达到让生成与真实数据的距离远,而不是直接做loss,也就找了一个中间量。

这样做有什么好处,其实就是软化,直接做loss让生成的与真实的距离大?这么训练直接崩了。加一个网络处理一下,在处理结果上做loss,并且找一个中间量,假的数据->0,真的数据->1。看看后面的代码就明白了。

疑惑2:用两个BCELoss就能实现判别器Loss,熟悉BCELoss公式的话,你好像觉得13.32式少了什么东西。

解疑2

生成对抗网络loss图 生成对抗网络loss曲线_对抗_06


这是BCELoss的公式2,对比上图,发现第一项少了BCE公式的后半部分,第二项少了BCE公式的前半部分。注意第一项的生成对抗网络loss图 生成对抗网络loss曲线_pytorch_07下标表示:x来自真实样本,即BCE公式的生成对抗网络loss图 生成对抗网络loss曲线_原理_08,所以BCE公式的后半部分为0

生成对抗网络loss图 生成对抗网络loss曲线_生成对抗网络loss图_09

1.2 生成网络

生成对抗网络loss图 生成对抗网络loss曲线_pytorch_10

2. 代码3

# discriminator loss
dis_input_real = torch.cat((images, edges), dim=1)
dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
dis_real, dis_real_feat = self.discriminator(dis_input_real)        # in: (grayscale(1) + edge(1))
dis_fake, dis_fake_feat = self.discriminator(dis_input_fake)        # in: (grayscale(1) + edge(1))
dis_real_loss = self.adversarial_loss(dis_real, True, True)
dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
dis_loss += (dis_real_loss + dis_fake_loss) / 2


# generator adversarial loss
gen_input_fake = torch.cat((images, outputs), dim=1)
gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)        # in: (grayscale(1) + edge(1))
gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
gen_loss += gen_gan_loss
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """
    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge
        """
        super(AdversarialLoss, self).__init__()

        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()

        elif type == 'lsgan':
            self.criterion = nn.MSELoss()

        elif type == 'hinge':
            self.criterion = nn.ReLU()

    def __call__(self, outputs, is_real, is_disc=None):
        if self.type == 'hinge':
            if is_disc:
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            else:
                return (-outputs).mean()

        else:
            labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
            loss = self.criterion(outputs, labels)
            return loss

  1. https://nndl.github.io/nndl-book.pdf ↩︎
  2. https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html?highlight=bceloss#torch.nn.BCELoss ↩︎
  3. https://github.com/knazeri/edge-connect/blob/master/src/models.py ↩︎