20210813 -

0. 引言

最近在实现对抗自编码器的代码,想法是从最简单的模板开始。同时为了能够先找到点感觉,先看看怎么处理MNIST数据。

1. 代码示例

针对对抗自编码器的代码,找到了两份代码,分别是tensorflow实现和keras实现。其实最开始是弄的keras版本,但是判别器的判别准确率基本上一直稳定在100%,就挺奇怪的。所以,就有弄了个tensorflow来看看,不过这个问题还是没有解答。先把整理代码的过程来记录下,因为代码并不能直接跑。代码地址分别位于[1]和[2]。

1.1 Keras版本代码

这个版本的代码有一个错误,也不算错误把,属于API的版本问题。

def build_encoder(self):
        # Encoder

        img = Input(shape=self.img_shape)

        h = Flatten()(img)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        latent_repr = Dense(self.latent_dim, activation='tanh')(h)
        #mu = Dense(self.latent_dim)(h)
        #log_var = Dense(self.latent_dim)(h)
        #latent_repr = merge([mu, log_var], mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2), output_shape=lambda p: p[0])
        return Model(img, latent_repr)

他的代码部分,变量latent_repr是由注释部分的代码来形成的,但是函数merge在新版中已经不可用了,这段代码可以使用Lambda层来实现。不过,在看了另外一篇文章[3]中,其指出,对于编码器的内容,可以通过3中方式来实现,最后一种不太明白,前面两种分别是决定性的(翻译是否正确有待商榷),或者类似变分自编码器的形式,将输出再链接到两个层,正是前面代码的注释部分。决定性,就是我代码中正使用的部分。

1.2 Tensorflow版本代码

首先要说明的是,tensorflow版本的代码中编码器是决定性的。但是这个代码是使用低版本tensorlofw写的(1.7.0好像是),在我2.3的环境上跑不起来。所以要代码进行一些修改。修改的部分有两个。

  • example.tutorials
  • 1.0api兼容性

这两个部分都可以在文章[4]中找到答案。因为使用了MNIST数据集,所以第一次运行的时候需要下载,这个下载过程,如果出现错误,可以多运行几次。

1.3 运行

通过上面的修改之后,两个版本的代码都能正常运行。不过其中他们的损失函数,我有些看不懂。其实还是训练过程中,对GAN的内在原理不是很清晰,还是需要在看看。

(未完待续,后续将记录实际的损失函数变化过程分析。。。)