A Pix2Pix GAN Example

  • 1. Introduction to Pix2Pix
  • 2. Downloading the Satellite Maps Dataset
  • 3. Data Preprocessing
  • 4. Defining the Discriminator
  • 5. Defining the Generator
  • 6. Defining the GAN Model
  • 7. Loading Real Images and Generating Fake Images
  • 8. Generating Sample Images with the Generator Every Few Epochs to Check Progress
  • 9. The Training Process
  • 10. Results After Training
  • 11. The Complete Code


1. Introduction to Pix2Pix

Pix2Pix is a generative adversarial network (GAN) model designed for general-purpose image-to-image translation.

The approach was proposed by Phillip Isola et al. in their 2016 paper "Image-to-Image Translation with Conditional Adversarial Networks", which was presented at CVPR in 2017.

The GAN architecture consists of a generator model that outputs new plausible synthetic images and a discriminator model that classifies images as real (from the dataset) or fake (generated). The discriminator is updated directly, whereas the generator is updated through the discriminator. The two models are therefore trained simultaneously in an adversarial process, in which the generator tries to get better at fooling the discriminator and the discriminator tries to get better at detecting fake images.

The Pix2Pix model is a type of conditional GAN, or cGAN, where the generated output image is conditioned on an input, in this case a source image. The discriminator is given both a source image and a target image and must decide whether the target is a plausible transformation of the source.

The generator is trained with an adversarial loss, which encourages it to produce plausible images in the target domain. It is also updated with an L1 loss measured between the generated image and the expected output image. This additional loss encourages the generator to create plausible translations of the source image.
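As a rough numeric sketch of this composite objective (the numbers below are made up; the weight of 100 matches the loss_weights=[1,100] used when the GAN model is compiled in Section 6):

def generator_loss(adv_bce, l1_mae, lam=100.0):
	# total generator loss = adversarial binary cross-entropy + lambda * L1 error
	return adv_bce + lam * l1_mae

# e.g. an adversarial loss of 0.7 and a mean absolute error of 0.05:
print(generator_loss(0.7, 0.05))  # 5.7 -- the L1 term dominates early in training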

The Pix2Pix GAN has been demonstrated on a range of image-to-image translation tasks, such as converting maps to satellite photos, black-and-white photos to color, and product sketches to product photos.

Now that we are familiar with the Pix2Pix GAN, let's prepare a dataset that we can use for image-to-image translation.

2. Downloading the Satellite Maps Dataset

This dataset consists of satellite images of New York and their corresponding Google Maps pages. The translation problem involves converting satellite photos to Google Maps format, or Google Maps images back to satellite photos.

The dataset is provided on the pix2pix website and can be downloaded as a 255 MB archive:
Download Maps Dataset (maps.tar.gz)
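If you prefer to script the download, here is a minimal sketch. The URL is the commonly used Berkeley mirror for the pix2pix datasets; verify that it is still live before relying on it.

import urllib.request
import tarfile

url = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz'
urllib.request.urlretrieve(url, 'maps.tar.gz')
with tarfile.open('maps.tar.gz') as tar:
	tar.extractall()  # creates a maps/ folder with train/ and val/ subfolders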

After downloading and extracting, the directory structure looks like this:

[Figure: directory layout of the extracted maps dataset]


Enter either directory and open one of the images:

[Figure: a sample image, with the satellite photo on the left and the corresponding Google Maps image on the right]

3. Data Preprocessing

To make the images load faster during training, we save all of the downloaded images into a single compressed NumPy file, maps_256.npz.

from os import listdir
from numpy import asarray
from numpy import savez_compressed
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

# load all images in a directory into memory
def load_images(path, size=(256,512)):
	src_list, tar_list = list(), list()
	# enumerate filenames in directory, assume all are images
	for filename in listdir(path):
		# load and resize the image
		pixels = load_img(path + filename, target_size=size)
		# convert to numpy array
		pixels = img_to_array(pixels)
		# split into satellite and map
		sat_img, map_img = pixels[:, :256], pixels[:, 256:]
		src_list.append(sat_img)
		tar_list.append(map_img)
	return [asarray(src_list), asarray(tar_list)]

# dataset path
path = 'D:/ML/datasets/maps/train/'
# load dataset
[src_images, tar_images] = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# save as compressed numpy array
filename = 'maps_256.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)

The output is:

Loaded:  (1096, 256, 256, 3) (1096, 256, 256, 3)
Saved dataset:  maps_256.npz

Then run the code below to verify that the images display correctly.

# load the prepared dataset
from numpy import load
from matplotlib import pyplot
# load the dataset
data = load('maps_256.npz')
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + i)
	pyplot.axis('off')
	pyplot.imshow(src_images[i].astype('uint8'))
# plot target image
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + n_samples + i)
	pyplot.axis('off')
	pyplot.imshow(tar_images[i].astype('uint8'))
pyplot.show()

[Figure: three source satellite images (top row) and their corresponding target map images (bottom row)]

4. Defining the Discriminator

The discriminator is implemented as a PatchGAN discriminator model.

Note that it takes two images as input: in_src_image is the satellite image, and in_target_image is the Google Maps image.
They are merged by a Concatenate layer into 6 channels, 3 channels (RGB) per image.
LeakyReLU is used as the activation function, and every block except the first and the last uses BatchNormalization.
The output layer produces a patch map of shape (16,16,1). (The imports for this and the following snippets appear in the complete listing in Section 11.)
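Where does that (16,16,1) shape come from? The four stride-2 convolutions each halve the spatial size, and the final two convolutions use stride 1, so the grid stops shrinking at 16. A quick sanity check:

size = 256
for _ in range(4):  # C64, C128, C256, C512 all use strides=(2,2)
	size //= 2
print(size)  # 16 -> a 16x16 grid of patch predictions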

# define the discriminator model
def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_src_image = Input(shape=image_shape)
	# target image input
	in_target_image = Input(shape=image_shape)
	# concatenate images channel-wise
	merged = Concatenate()([in_src_image, in_target_image])
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	patch_out = Activation('sigmoid')(d)
	# define model
	model = Model([in_src_image, in_target_image], patch_out)
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
	return model
if __name__ == '__main__':
	d_model = define_discriminator((256,256,3))
	d_model.summary()

Its structure is:

[Figure: discriminator model summary]

5. Defining the Generator

The generator is an encoder-decoder model that uses a U-Net architecture. It takes a source image (e.g. a satellite photo) and generates a target image (e.g. a Google Maps image). It does this by first downsampling, or encoding, the input image to a bottleneck layer, then upsampling, or decoding, the bottleneck representation back to the size of the output image. The U-Net architecture means that skip connections are added between the encoding layers and the corresponding decoding layers, forming a U shape.

The figure below makes the skip connections clear, showing how the first layer of the encoder is connected to the last layer of the decoder, and so on.

[Figure: U-Net generator architecture with skip connections between encoder and decoder layers]

The encoder and decoder of the generator are built from standard blocks of convolutional, batch normalization, dropout, and activation layers. This standardization means we can develop helper functions to create each block of layers and call them repeatedly to build up the encoder and decoder parts of the model.

The define_generator() function below implements the U-Net encoder-decoder generator model. It uses the define_encoder_block() helper function to create blocks of layers for the encoder and the decoder_block() function to create blocks of layers for the decoder. The tanh activation function is used in the output layer, meaning that pixel values in the generated image will be in the range [-1,1].

The input is a satellite image; after passing through the encoder-decoder network, a Google Maps image comes out:
(256,256,3) -> Encoder -> (1,1,512) -> Decoder -> (256,256,3)
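This shape flow is easy to verify: the seven encoder blocks plus the bottleneck convolution each halve the resolution, and the decoder doubles it back up. A quick sketch:

size = 256
sizes = []
for _ in range(8):  # 7 encoder blocks + the bottleneck conv, all stride 2
	size //= 2
	sizes.append(size)
print(sizes)  # [128, 64, 32, 16, 8, 4, 2, 1] -- the decoder then mirrors this back to 256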

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add downsampling layer
	g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# conditionally add batch normalization
	if batchnorm:
		g = BatchNormalization()(g, training=True)
	# leaky relu activation
	g = LeakyReLU(alpha=0.2)(g)
	return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add upsampling layer
	g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# add batch normalization
	g = BatchNormalization()(g, training=True)
	# conditionally add dropout
	if dropout:
		g = Dropout(0.5)(g, training=True)
	# merge with skip connection
	g = Concatenate()([g, skip_in])
	# relu activation
	g = Activation('relu')(g)
	return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# encoder model
	e1 = define_encoder_block(in_image, 64, batchnorm=False)
	e2 = define_encoder_block(e1, 128)
	e3 = define_encoder_block(e2, 256)
	e4 = define_encoder_block(e3, 512)
	e5 = define_encoder_block(e4, 512)
	e6 = define_encoder_block(e5, 512)
	e7 = define_encoder_block(e6, 512)
	# bottleneck: no batch norm, ReLU activation
	b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
	b = Activation('relu')(b)
	# decoder model
	d1 = decoder_block(b, e7, 512)
	d2 = decoder_block(d1, e6, 512)
	d3 = decoder_block(d2, e5, 512)
	d4 = decoder_block(d3, e4, 512, dropout=False)
	d5 = decoder_block(d4, e3, 256, dropout=False)
	d6 = decoder_block(d5, e2, 128, dropout=False)
	d7 = decoder_block(d6, e1, 64, dropout=False)
	# output
	g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model
if __name__ == '__main__':
	g_model = define_generator((256,256,3))
	g_model.summary()

Its structure is:

[Figure: generator model summary]

6. Defining the GAN Model

The GAN model exists mainly to train the generator, so the discriminator is not trained through it (d_model.trainable = False).
The input is the satellite image, shape (256,256,3).
The outputs are dis_out, shape (16,16,1),
and gen_out, shape (256,256,3).

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
	# make weights in the discriminator not trainable
	d_model.trainable = False
	# define the source image
	in_src = Input(shape=image_shape)
	# connect the source image to the generator input
	gen_out = g_model(in_src)
	# connect the source input and generator output to the discriminator input
	dis_out = d_model([in_src, gen_out])
	# src image as input, generated image and classification output
	model = Model(in_src, [dis_out, gen_out])
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
	return model
if __name__ == '__main__':
	d_model = define_discriminator((256,256,3))
	g_model = define_generator((256,256,3))
	gan_model = define_gan(g_model, d_model, (256,256,3))
	g_model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 64) 3136        input_3[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_7[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  524544      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 32, 32, 256)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 512)  2097664     leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 512)  2048        conv2d_9[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 16, 16, 512)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 8, 8, 512)    4194816     leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8, 8, 512)    2048        conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 8, 8, 512)    0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 4, 4, 512)    4194816     leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 4, 4, 512)    2048        conv2d_11[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 4, 4, 512)    0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 2, 2, 512)    4194816     leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 2, 2, 512)    2048        conv2d_12[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 2, 2, 512)    0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 1, 1, 512)    4194816     leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 1, 1, 512)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 2, 2, 512)    4194816     activation_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 2, 2, 512)    2048        conv2d_transpose[0][0]           
__________________________________________________________________________________________________
dropout (Dropout)               (None, 2, 2, 512)    0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 2, 2, 1024)   0           dropout[0][0]                    
                                                                 leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 2, 2, 1024)   0           concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 4, 4, 512)    8389120     activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 4, 4, 512)    2048        conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 4, 4, 1024)   0           dropout_1[0][0]                  
                                                                 leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 4, 4, 1024)   0           concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 8, 8, 512)    8389120     activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 8, 8, 512)    2048        conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 8, 8, 512)    0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 8, 8, 1024)   0           dropout_2[0][0]                  
                                                                 leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 8, 8, 1024)   0           concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 16, 16, 512)  8389120     activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 16, 16, 1024) 0           batch_normalization_13[0][0]     
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 16, 16, 1024) 0           concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 32, 32, 256)  4194560     activation_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 256)  1024        conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 32, 32, 512)  0           batch_normalization_14[0][0]     
                                                                 leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 512)  0           concatenate_5[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 64, 64, 128)  1048704     activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 64, 64, 128)  512         conv2d_transpose_5[0][0]         
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 64, 64, 256)  0           batch_normalization_15[0][0]     
                                                                 leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 256)  0           concatenate_6[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 128, 128, 64) 262208      activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256         conv2d_transpose_6[0][0]         
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 128, 128, 128 0           batch_normalization_16[0][0]     
                                                                 leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 128, 128, 128 0           concatenate_7[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTrans (None, 256, 256, 3)  6147        activation_8[0][0]               
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 256, 256, 3)  0           conv2d_transpose_7[0][0]         
==================================================================================================
Total params: 54,429,315
Trainable params: 54,419,459
Non-trainable params: 9,856
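One Keras subtlety worth spelling out: the trainable flag is captured when a model is compiled. d_model was compiled with trainable weights before define_gan() set d_model.trainable = False, so calling d_model.train_on_batch() still updates the discriminator, while inside the compiled gan_model the discriminator weights stay frozen. A quick check, assuming the functions defined above:

d_model = define_discriminator((256,256,3))
g_model = define_generator((256,256,3))
gan_model = define_gan(g_model, d_model, (256,256,3))
# only the generator's weights are updated through the composite model
print(len(gan_model.trainable_weights))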

7. Loading Real Images and Generating Fake Images

The load_real_samples() function loads the real images and scales them to [-1,1].
The generate_real_samples() function selects a batch of real images; each label array is all 1s, with shape (16,16,1).
The generate_fake_samples() function generates fake images with the generator; each label array is all 0s, with shape (16,16,1).
The labels here are unusual: instead of a single number per sample, each label is a three-dimensional array of shape (16,16,1), matching the discriminator's patch output.

# load and prepare training images
def load_real_samples(filename):
	# load compressed arrays
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# unpack dataset
	trainA, trainB = dataset
	# choose random instances
	ix = randint(0, trainA.shape[0], n_samples)
	# retrieve selected images
	X1, X2 = trainA[ix], trainB[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
	# generate fake instance
	X = g_model.predict(samples)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y
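A quick usage check of these helpers (this assumes maps_256.npz from Section 3 is present):

dataset = load_real_samples('maps_256.npz')
[X_realA, X_realB], y_real = generate_real_samples(dataset, 2, 16)
print(X_realA.shape, X_realB.shape, y_real.shape)
# (2, 256, 256, 3) (2, 256, 256, 3) (2, 16, 16, 1)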

8. Generating Sample Images with the Generator Every Few Epochs to Check Progress

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
	# select a sample of input images
	[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
	# generate a batch of fake samples
	X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
	# scale all pixels from [-1,1] to [0,1]
	X_realA = (X_realA + 1) / 2.0
	X_realB = (X_realB + 1) / 2.0
	X_fakeB = (X_fakeB + 1) / 2.0
	# plot real source images
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realA[i])
	# plot generated target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_fakeB[i])
	# plot real target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realB[i])
	# save plot to file
	filename1 = 'pix2pix_plot_%06d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'pix2pix_model_%06d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))
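A saved generator checkpoint can later be reloaded for inference; the step number below is just an example of the filename pattern produced above:

from tensorflow.keras.models import load_model
g_model = load_model('pix2pix_model_109600.h5')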

9. The Training Process

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
	# determine the output square shape of the discriminator
	n_patch = d_model.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
		# update discriminator for real samples
		d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
		# update discriminator for generated samples
		d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
		# update the generator
		g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
		# summarize performance
		print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
		# summarize model performance
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, dataset)
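To get a feel for the schedule: with the 1,096 training images saved in Section 3 and n_batch=1, the loop runs as follows (simple arithmetic mirroring the code above):

bat_per_epo = 1096 // 1      # 1,096 discriminator/generator updates per epoch
n_steps = bat_per_epo * 100  # 109,600 training iterations in total
# summarize_performance fires every bat_per_epo * 10 = 10,960 steps, i.e. every 10 epochs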

10. Results After Training

After the first 10 epochs, the generated map images already look plausible, even though the street lines are not entirely straight and the images contain some blurring. Nevertheless, the large structures are in the right places with mostly the right colors.

[Figure: generated map images after 10 training epochs]


Images generated after about 50 training epochs begin to look very realistic, at least subjectively, and the quality appears to remain good for the remainder of the training process. Note the first generated image example below (right column, middle row), which contains more useful detail than the real Google Maps image.

[Figure: generated map images after about 50 training epochs]

11. The Complete Code

# example of pix2pix gan for satellite to map image-to-image translation
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose

from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot

# define the discriminator model
def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_src_image = Input(shape=image_shape)
	# target image input
	in_target_image = Input(shape=image_shape)
	# concatenate images channel-wise
	merged = Concatenate()([in_src_image, in_target_image])
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	patch_out = Activation('sigmoid')(d)
	# define model
	model = Model([in_src_image, in_target_image], patch_out)
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
	return model

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add downsampling layer
	g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# conditionally add batch normalization
	if batchnorm:
		g = BatchNormalization()(g, training=True)
	# leaky relu activation
	g = LeakyReLU(alpha=0.2)(g)
	return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add upsampling layer
	g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# add batch normalization
	g = BatchNormalization()(g, training=True)
	# conditionally add dropout
	if dropout:
		g = Dropout(0.5)(g, training=True)
	# merge with skip connection
	g = Concatenate()([g, skip_in])
	# relu activation
	g = Activation('relu')(g)
	return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# encoder model
	e1 = define_encoder_block(in_image, 64, batchnorm=False)
	e2 = define_encoder_block(e1, 128)
	e3 = define_encoder_block(e2, 256)
	e4 = define_encoder_block(e3, 512)
	e5 = define_encoder_block(e4, 512)
	e6 = define_encoder_block(e5, 512)
	e7 = define_encoder_block(e6, 512)
	# bottleneck: no batch norm, ReLU activation
	b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
	b = Activation('relu')(b)
	# decoder model
	d1 = decoder_block(b, e7, 512)
	d2 = decoder_block(d1, e6, 512)
	d3 = decoder_block(d2, e5, 512)
	d4 = decoder_block(d3, e4, 512, dropout=False)
	d5 = decoder_block(d4, e3, 256, dropout=False)
	d6 = decoder_block(d5, e2, 128, dropout=False)
	d7 = decoder_block(d6, e1, 64, dropout=False)
	# output
	g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
	# make weights in the discriminator not trainable
	d_model.trainable = False
	# define the source image
	in_src = Input(shape=image_shape)
	# connect the source image to the generator input
	gen_out = g_model(in_src)
	# connect the source input and generator output to the discriminator input
	dis_out = d_model([in_src, gen_out])
	# src image as input, generated image and classification output
	model = Model(in_src, [dis_out, gen_out])
	# compile model
	opt = Adam(learning_rate=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
	return model

# load and prepare training images
def load_real_samples(filename):
	# load compressed arrays
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# unpack dataset
	trainA, trainB = dataset
	# choose random instances
	ix = randint(0, trainA.shape[0], n_samples)
	# retrieve selected images
	X1, X2 = trainA[ix], trainB[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
	# generate fake instance
	X = g_model.predict(samples)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
	# select a sample of input images
	[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
	# generate a batch of fake samples
	X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
	# scale all pixels from [-1,1] to [0,1]
	X_realA = (X_realA + 1) / 2.0
	X_realB = (X_realB + 1) / 2.0
	X_fakeB = (X_fakeB + 1) / 2.0
	# plot real source images
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realA[i])
	# plot generated target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_fakeB[i])
	# plot real target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realB[i])
	# save plot to file
	filename1 = 'pix2pix_plot_%06d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'pix2pix_model_%06d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
	# determine the output square shape of the discriminator
	n_patch = d_model.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
		# update discriminator for real samples
		d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
		# update discriminator for generated samples
		d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
		# update the generator
		g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
		# summarize performance
		print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
		# summarize model performance
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, dataset)

def start_train():
	# load image data
	dataset = load_real_samples('maps_256.npz')
	print('Loaded', dataset[0].shape, dataset[1].shape)
	# define input shape based on the loaded dataset
	image_shape = dataset[0].shape[1:]
	# define the models
	d_model = define_discriminator(image_shape)
	d_model.summary()
	g_model = define_generator(image_shape)
	# define the composite model
	gan_model = define_gan(g_model, d_model, image_shape)
	# train model
	train(d_model, g_model, gan_model, dataset)

if __name__ == '__main__':
	start_train()
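Finally, a minimal inference sketch for applying a trained generator to a new satellite photo (the checkpoint name and 'satellite.jpg' are hypothetical placeholders):

from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from numpy import expand_dims
from matplotlib import pyplot

g_model = load_model('pix2pix_model_109600.h5')
pixels = img_to_array(load_img('satellite.jpg', target_size=(256,256)))
pixels = (pixels - 127.5) / 127.5              # scale to [-1,1], as in training
gen = g_model.predict(expand_dims(pixels, 0))  # (1, 256, 256, 3) in [-1,1]
pyplot.imshow((gen[0] + 1) / 2.0)              # back to [0,1] for display
pyplot.axis('off')
pyplot.show()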