解决办法:对抗神经网路
生成器和判别器,生成器用来生成图像(可以从文字生成图片,也可以图片生成图片),判别器用来判别该图像是否符合该描述
对于这样的哲学,不能出现对方比另外一方层次高太多,会导致对方无法进一步地提升
DCGAN
中间为反卷积操作
A为输入矩阵,W为卷积核,B为卷积输出
B的第一个输出可以表示成
卷积的操作可以看成是下面的矩阵的操作
把输入矩阵转化为一维向量。
B = W*AT
W∈[4,16],AT∈[16,1]
则B的输出可以表示成
下面的矩阵[4,16]
与输入的一维向量转置[16,1]做内积,得到[4,1]的B,就是
|B00 |B01 |
|B10 |B11 |
的展开反卷积
A = WT*B
WT∈[16,4],B∈[4,1],A属于[16,1]。
实际操作中,我们需要对B做padding,无论是卷积还是反卷积,核还是3X3。
卷积的正向传播,就是反卷积的反向传播。反卷积的正向传播就是卷积的反向传播
反卷积的扩展(padding的方式可以多变化)
下图中,strided>0
,0<fractional-strided<1
,在以前的笔记中提到,当卷积中strided>1
时,可以缩小图片,0<strided<1
时,可以放大图片。
实现效果,
如下图,取多张的图片的平均值效果会好得多
图像翻译Pix2Pix
在图像翻译问题中,一张图片的表达形式有很多种。如下图
相应的网络结构:D网络结构需要解决的是判断两张图片是否为一对,用普通的神经网络结构可以实现。而G则需要新的网络结构,如下图:U-Net
为Encoder-decoder
的变种。如图所示,变化的地方在于使用反卷积之后再与反卷积之前的网络层进行拼接
U-Net
网络结构如下
Pix2Pix模型的缺点在于数据必须成对而且是标签好的,但实际上要获取成对的数据是比较难的,大多数都是没有标签的数据或者是单一的数据。
CycleGAN
寻找解决办法CycleGAN的提出。
如下图CycleGAN所示,初始的x与卷积,反卷积之后的x’之间存在有一个loss,同样的,y与y’之间也存在一个loss
GAN
的loss
定义
cycleGAN
的loss
定义
合并得到
思考:CycleGAN为何会有效?
答:在GAN里面,图片是没有配对的。所以任意一张A领域的图片是没有映射到B领域的哪一张图片,如果没有loss的计算的话,映射的范围会变得很大。所以我们加入一致性约束之后,降低了搜索空间。A领域的一张图片映射到B领域的一张图片,在Cycle中,B领域那一张图片就要映射到A领域中相似的图片。
效果如下
为了解决多领域图像翻译的问题。
StartGAN
引用StarGAN的问题
如下图,输入一组图像,一张为真,另一张为假,判断器需要判断这组图像是否为真/假(是否为一对),还需要判断是否为同一类别的图像。
如下图所示,b为我们的生成器,d为我们的判别器,c的作用类似于CycleGAN的效果类似,实现一致性约束。
StartGAN详细结构, 如上图所示,初始时Target domain和Input image进行拼接,就是下图最左边的in。
loss损失函数
对于分类判别的loss损失函数有两个,对于生成器和判别器分别使用不同的loss损失函数
一个是对于生成后的图片
另外一个是对于原始(真实)的图片
另外地还有一个loss损失函数是重建图片后的loss损失函数(类似于cycleGAN中一致性损失函数)
在训练判别器D的时候使用下面这个损失函数
在训练生成器G的时候使用
利用两个不同的分类损失函数,这样分类器可以通过,这样分类器可以通过两种不同的分类损失函数得到不同的数据,从而学习到不同的信息。判别器D的训练可以使用真实的图片的loss,输入不会有误差,可以提高D。对于G生成器可以使用重建后的loss损失函数,由于D已经训练好了,从而也可以提升本身G
回到上述的一个重要的模型结构
从左到右的解释:
最左边的输入为词的embedding
和一个随机向量
加入随机向量是为了多样性,应对文本可以映射到不同领域的图像。
随机向量和词的embedding进行拼接,再进反卷积操作,最后得到一张图片。进入到判别器D,判别器进行的是判断该图片是否为真的图片,还需要判别词的embedding与该图片是否匹配。
训练技巧:可以先提升G生成器,在提高D。除了判断真假描述之外,还需要训练出真图与不匹配的描述,类似于数据增强,有利于神经网络的提升。
伪代码:
3:文字的embedding
4:不匹配的文本进行embedding。
5:生成随机向量。
6:输入到G生成器:,生成图像
7:真实图像和真实的描述组成Sr
8:真实图像和错误的描述组成Sw
9:假的图像和真实的描述组成Sf
10:将Sr,Sw,Sf都输到D中,计算loss损失函数
11:然后更新D
12:在利用Sf计算G的损失函数。
13:更新G