生成对抗网络是一种产生模型。它由两部分组成,分别称为“生成器”和“分判器”。生成器以随机值为输入,并将输入转换为可以作为训练数据的输出。分判器将样作为输入并尽量区分真实的训练样本和生成器产生的样本。它们两一起训练。分判器越来越能判别真假,生成器越来越能骗分判器。
条件GAN (CGAN)允许增加输入到生成器和分判器使它们的输出是有条件的。例如,可能是类的标签,GAN试图学习不同类的数据分布的变化。
例如我们将产生包含2D椭圆的数据分布集,位置、形状和方向是随机的。每个类对应不同的椭圆。我们来随机的产生椭圆。每个椭圆我们随机的选择一个中心,X和Y大小,旋转角。然后我们产生变换矩阵将单位圆映射到椭圆。
In [1]:
import deepchem as dc
import numpy as np
import tensorflow as tf
n_classes = 4
class_centers = np.random.uniform(-4, 4, (n_classes, 2))
class_transforms = []
for i in range(n_classes):
xscale = np.random.uniform(0.5, 2)
yscale = np.random.uniform(0.5, 2)
angle = np.random.uniform(0, np.pi)
m = [[xscale*np.cos(angle), -yscale*np.sin(angle)],
[xscale*np.sin(angle), yscale*np.cos(angle)]]
class_transforms.append(m)
class_transforms = np.array(class_transforms)
这个函数从分布中随机产生数字。每个点选择随机的类,然后随机的位置。
In [2]:
def generate_data(n_points):
classes = np.random.randint(n_classes, size=n_points)
r = np.random.random(n_points)
angle = 2*np.pi*np.random.random(n_points)
points = (r*np.array([np.cos(angle), np.sin(angle)])).T
points = np.einsum('ijk,ik->ij', class_transforms[classes], points)
points += class_centers[classes]
return classes, points
我们从这个分布中作一些随机的点来看它是什么样子的。点的着色基于它们的类的标签。
In [3]:
%matplotlib inline
import matplotlib.pyplot as plot
classes, points = generate_data(1000)
plot.scatter(x=points[:,0], y=points[:,1], c=classes)
Out[3]:
<matplotlib.collections.PathCollection at 0x1584692d0>
现在我们来创建CGAN模型。DeepChem的GAN类使这个工作非常容易。我们只要子类化它并实施新的方法。两个要点是:
create_generator()构建模型实施生成器。模型任务的输放为一批随机噪音加上任何的条件变量(我们的情形,每个样本的one-hot encoded类)。它的输出是合成的样本被认为是训练数据。
create_discriminator() 构建一个模型实施分判器。模型的输入为评估样本(真实的数据或是生成器产生的合成样本)和条件变量。它的输出是每个样本的数值,被认为是样本是真实训练样本的概率。
这种情况下,我们使用简单的模型。它们只是将输入连接在一起并传递到一些全链接层。注意分判器的最后一层使用sigmoid激活函数。这确保它产生一个0到1之间的输出可以被解释为概率。
我们也要实施一些方法来确定不同输入的形状。我们指定提供给生成器的随机噪音应包含10个数字;每个数据点包含两个数字(2D空间中点的X和Y坐标),并且每个样本(one-hot encoded索引)条件输入包含n_classes个数值
In [4]:
from tensorflow.keras.layers import Concatenate, Dense, Input
class ExampleGAN(dc.models.GAN):
def get_noise_input_shape(self):
return (10,)
def get_data_input_shapes(self):
return [(2,)]
def get_conditional_input_shapes(self):
return [(n_classes,)]
def create_generator(self):
noise_in = Input(shape=(10,))
conditional_in = Input(shape=(n_classes,))
gen_in = Concatenate()([noise_in, conditional_in])
gen_dense1 = Dense(30, activation=tf.nn.relu)(gen_in)
gen_dense2 = Dense(30, activation=tf.nn.relu)(gen_dense1)
generator_points = Dense(2)(gen_dense2)
return tf.keras.Model(inputs=[noise_in, conditional_in], outputs=[generator_points])
def create_discriminator(self):
data_in = Input(shape=(2,))
conditional_in = Input(shape=(n_classes,))
discrim_in = Concatenate()([data_in, conditional_in])
discrim_dense1 = Dense(30, activation=tf.nn.relu)(discrim_in)
discrim_dense2 = Dense(30, activation=tf.nn.relu)(discrim_dense1)
discrim_prob = Dense(1, activation=tf.sigmoid)(discrim_dense2)
return tf.keras.Model(inputs=[data_in, conditional_in], outputs=[discrim_prob])
gan = ExampleGAN(learning_rate=1e-4)
现在来拟合模型。我们通过调用fit_gan()来完成。参数为迭代器它产生多批的训练数据。更具体的,它需要生产字典映射击所有的输入和条件输入到我们要使用的值。就我们的情况,我们可以很容易的产生尽可能多的随机数,所以我们确定生成器设用前面确定的generate_data()函数。
In [5]:
def iterbatches(batches):
for i in range(batches):
classes, points = generate_data(gan.batch_size)
classes = dc.metrics.to_one_hot(classes, n_classes)
yield {gan.data_inputs[0]: points, gan.conditional_inputs[0]: classes}
gan.fit_gan(iterbatches(5000))
Ending global_step 999: generator average loss 0.87121, discriminator average loss 1.08472
Ending global_step 1999: generator average loss 0.968357, discriminator average loss 1.17393
Ending global_step 2999: generator average loss 0.710444, discriminator average loss 1.37858
Ending global_step 3999: generator average loss 0.699195, discriminator average loss 1.38131
Ending global_step 4999: generator average loss 0.694203, discriminator average loss 1.3871
TIMING: model fitting took 31.352 s
已经有训练的模型产生一些数据,看它如何配匹我们前面画的训练分布。
In [6]:
classes, points = generate_data(1000)
one_hot_classes = dc.metrics.to_one_hot(classes, n_classes)
gen_points = gan.predict_gan_generator(conditional_inputs=[one_hot_classes])
plot.scatter(x=gen_points[:,0], y=gen_points[:,1], c=classes)
Out[6]:
<matplotlib.collections.PathCollection at 0x160dedf50>