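I previously used GANs and many of their variants to generate ECG, EMG, EEG, microseismic, mechanical vibration, and radar signals. However, the generated signals looked poor in their spectra and time-frequency representations. So for now I will set these complex signals aside and use handwritten digit images as the example. There are already plenty of Python resources, so I won't add to the crowd. The programming environment used here is MATLAB R2021b.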
First, let's look at the Adversarial AutoEncoder (AAE). For a general introduction to AAE, see the following article:
A Brief Explanation of AAE (Adversarial Autoencoders), by 嘎嘎小鱼仔, Zhihu
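AAE grew out of the variational autoencoder (VAE). Its key addition is adversarial training: instead of the VAE's KL-divergence penalty, it uses a discriminator to shape the latent distribution.
The upper part of the AAE architecture is a simple, typical autoencoder (AE). It consists of an input layer, an encoder layer, a hidden layer, a decoder layer, and an output layer. The encoder maps an input x drawn from the data distribution to a latent code z in the hidden layer, and the decoder maps z back to a reconstruction of x. What distinguishes AAE is that it applies adversarial training to the hidden layer to shape the distribution of z. A discriminator must tell apart samples drawn from the prior p(z) (the "real" data) and latent codes produced by the encoder, which plays the role of the generator (the "fake" data). The encoder is in turn trained to fool the discriminator. This adversarial game drives the encoder's distribution q(z | x), aggregated over the data, toward the prior p(z).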
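For reference, the regularization phase can be written as the usual GAN minimax game, with the encoder acting as the generator on the latent space. Below is a sketch in this post's notation. En denotes the encoder, D the discriminator, p_d(x) the data distribution, and q(z) the aggregated posterior that the AAE paper matches to the prior:

```latex
q(z) = \int_x q(z \mid x)\, p_d(x)\, dx

\min_{\mathrm{En}} \max_{D}\;
  \mathbb{E}_{z \sim p(z)}\big[\log D(z)\big]
  + \mathbb{E}_{x \sim p_d(x),\, z \sim q(z \mid x)}\big[\log\big(1 - D(z)\big)\big]
```

In the MATLAB code that follows, p(z) is a standard normal N(0, I) (`latent_real` is drawn with `randn`). The discriminator minimizes one half of the negated bracketed sum (`d_loss`). The encoder instead uses the common non-saturating term -log D(z), weighted by 0.001 and added to a 0.999-weighted reconstruction loss (`g_loss`).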
Adversarial Autoencoders paper: https://arxiv.org/abs/1511.05644
Now for the code.
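First, load the MNIST handwritten digit images:
```matlab
% Loads a struct named mnist with the fields train_images, train_labels,
% test_images and test_labels
load('mnistAll.mat')
```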
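Then preprocess the training and test images:
```matlab
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;   % training labels (not used by the unsupervised AAE below)
testX  = preprocess(mnist.test_images);
testY  = mnist.test_labels;    % test labels
```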
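`preprocess` is the normalization function. It scales pixel values from [0, 255] to [-1, 1], matching the range of the decoder's final tanh. It also flattens each 28-by-28 image into a 784-element column:
```matlab
function x = preprocess(x)
    x = double(x)/255;          % [0, 255] -> [0, 1]
    x = (x-.5)/.5;              % [0, 1]   -> [-1, 1]
    x = reshape(x,28*28,[]);    % one image per column
end
```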
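A quick way to confirm the scaling is to run the same formula on a few fake images. The check below assumes the images arrive as a 28-by-28-by-N uint8 array. It also defines `toPixels`, a hypothetical inverse mapping (not part of the original code) that turns values in [-1, 1], such as the decoder's tanh output, back into 8-bit pixels:

```matlab
% Same formula as preprocess, written as an anonymous function for the check.
% Assumption: images are stored as a 28-by-28-by-N uint8 array.
preprocessSketch = @(x) reshape((double(x)/255 - .5)/.5, 28*28, []);
% Hypothetical inverse mapping from [-1, 1] back to uint8 pixels.
toPixels = @(x) uint8(round((x*.5 + .5)*255));

imgs = uint8(randi([0 255], 28, 28, 4));   % four fake "MNIST" images
imgs(1,1,1) = 0;                           % force one black pixel
imgs(28,28,4) = 255;                       % and one white pixel

x = preprocessSketch(imgs);
disp(size(x))                  % each image becomes one 784-element column
disp([min(x(:)) max(x(:))])    % black maps to -1, white maps to 1

back = reshape(toPixels(x), 28, 28, []);
disp(isequal(back, imgs))      % the round trip recovers the original pixels
```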
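Next, set the hyperparameters. These include the latent-space dimension, the batch size, the Adam learning rates and decay rates, and the maximum number of epochs:
```matlab
settings.latent_dim = 10;                       % dimension of the latent code z
settings.batch_size = 32; settings.image_size = [28,28,1];
settings.lrD = 0.0002; settings.lrG = 0.0002;   % learning rates: discriminator / encoder+decoder
settings.beta1 = 0.5;                           % Adam gradient decay rate
settings.beta2 = 0.999; settings.maxepochs = 50; % Adam squared-gradient decay rate, number of epochs
```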
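To see how much training these settings imply, here is the iteration arithmetic. It assumes the standard MNIST training split of 60,000 images and uses the same `floor` formula as `numIterations` in the training loop:

```matlab
% Iteration budget implied by the settings.
% Assumption: the standard MNIST training split of 60,000 images.
numTrain  = 60000;
batchSize = 32;    % settings.batch_size
maxEpochs = 50;    % settings.maxepochs

itersPerEpoch = floor(numTrain/batchSize);            % numIterations in the training loop
dropped       = numTrain - itersPerEpoch*batchSize;   % images skipped in each epoch
totalIters    = itersPerEpoch*maxEpochs;              % final value of global_iter

fprintf('%d iterations per epoch, %d images dropped, %d iterations in total\n', ...
    itersPerEpoch, dropped, totalIters);
```

Because 60,000 is divisible by 32, no image is skipped. With another batch size, the remainder would be left out of each shuffled epoch.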
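Next, initialize the encoder. The code is straightforward: three fully connected layers map each 784-pixel image to 2*latent_dim outputs, which are later split into the mean and the spread of the latent code.
```matlab
% Encoder: 784 -> 512 -> 512 -> 2*latent_dim
paramsEn.FCW1 = dlarray(initializeGaussian([512,...
    prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));
```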
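The encoder's last layer produces 2*latent_dim = 20 values per image. `modelGradients` uses the first half as the mean and the second half directly as the spread when sampling latent codes. The shape check below uses plain arrays in place of dlarray, which is an assumption made purely for illustration. It shows why the column subscript `,:` matters when slicing the second half: without it, MATLAB's linear indexing returns only 10 entries from the first column. It also shows why the noise must be multiplied elementwise:

```matlab
% Shape check for the reparameterization step; plain arrays stand in for dlarray.
latentDim = 10; batchSize = 32;
H = randn(2*latentDim, batchSize);        % encoder output: 20-by-32

mu    = H(1:latentDim, :);                % first half: mean of q(z|x)
sigma = H(latentDim+1:2*latentDim, :);    % second half: spread of q(z|x)

% Pitfall: linear indexing without ",:" picks 10 entries of column 1 only.
wrongSlice = H(latentDim+1:2*latentDim);
disp(size(wrongSlice))                    % 1-by-10, not 10-by-32

% Elementwise reparameterization: z = mu + sigma .* noise, noise ~ N(0, I)
z = mu + sigma .* randn(latentDim, batchSize);
disp(size(z))                             % 10-by-32: one latent code per image
```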
Initialize the decoder, which mirrors the encoder and maps a latent code back to a 784-pixel image:
```matlab
% Decoder: latent_dim -> 512 -> 512 -> 784
paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));
```
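As a rough sense of model size, the learnable parameters of the autoencoder part can be counted from the layer sizes in the initialization code. Each fully connected layer contributes a weight matrix plus a bias vector:

```matlab
% Count the learnable parameters of the fully connected encoder and decoder.
% Layer sizes are copied from the initialization code.
latentDim = 10; imgDim = 28*28;
fcParams = @(nIn, nOut) nOut*nIn + nOut;      % weights + biases of one layer

encSizes = [imgDim 512 512 2*latentDim];      % 784 -> 512 -> 512 -> 20
decSizes = [latentDim 512 512 imgDim];        % 10  -> 512 -> 512 -> 784

nEnc = sum(arrayfun(fcParams, encSizes(1:end-1), encSizes(2:end)));
nDec = sum(arrayfun(fcParams, decSizes(1:end-1), decSizes(2:end)));
fprintf('encoder: %d parameters, decoder: %d parameters\n', nEnc, nDec);
```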
Initialize the discriminator, which maps a latent code to a single score, and the Adam state:
```matlab
% Discriminator: latent_dim -> 512 -> 256 -> 1
paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));
% Running averages of the gradients and of the squared gradients (Adam state)
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];
```
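The `avgG` and `avgGS` fields hold Adam's running averages of the gradients and squared gradients between calls to `adamupdate`. The sketch below writes one step of the standard bias-corrected Adam rule by hand for a toy parameter vector. It assumes epsilon = 1e-8, which is, as far as I know, also `adamupdate`'s default. It is meant to show what those arrays store, not to replace `adamupdate`:

```matlab
% One hand-written Adam step (standard bias-corrected form, epsilon assumed 1e-8).
lr = 0.0002; beta1 = 0.5; beta2 = 0.999; epsilon = 1e-8;

p = [0.1 -0.3 0.7];        % toy parameters
g = [0.5 -2.0 3.0];        % toy gradients
avgG  = zeros(size(p));    % plays the role of avgG.Dis / avgG.En / avgG.De
avgGS = zeros(size(p));    % plays the role of avgGS.*
t = 1;                     % iteration counter (global_iter in the training loop)

avgG  = beta1*avgG  + (1-beta1)*g;       % running mean of gradients
avgGS = beta2*avgGS + (1-beta2)*g.^2;    % running mean of squared gradients
mHat = avgG  / (1 - beta1^t);            % bias correction for the zero start
vHat = avgGS / (1 - beta2^t);
p = p - lr * mHat ./ (sqrt(vHat) + epsilon);

disp(p)   % on the first step each entry moves by about lr against sign(g)
```

On the first step the bias-corrected moments reduce to g and g.^2, so every parameter moves by roughly the learning rate regardless of the gradient's magnitude.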
Start training. `gpdl` moves every batch onto the GPU, so as written the loop requires a supported GPU:
```matlab
% Optional sanity check: pass one training image through the encoder
dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);

numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
    tic;
    % Reshuffle the training images at the start of every epoch
    shuffleid = randperm(size(trainX,2));
    trainXshuffle = trainX(:,shuffleid);
    fprintf('Epoch %d\n',epoch)
    for i = 1:numIterations
        global_iter = global_iter+1;
        idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
        XBatch = gpdl(single(trainXshuffle(:,idx)),'CB');
        [GradEn,GradDe,GradDis] = ...
            dlfeval(@modelGradients,XBatch,...
            paramsEn,paramsDe,paramsDis,settings);
        % Update the discriminator parameters
        [paramsDis,avgG.Dis,avgGS.Dis] = ...
            adamupdate(paramsDis, GradDis, ...
            avgG.Dis, avgGS.Dis, global_iter, ...
            settings.lrD, settings.beta1, settings.beta2);
        % Update the encoder parameters
        [paramsEn,avgG.En,avgGS.En] = ...
            adamupdate(paramsEn, GradEn, ...
            avgG.En, avgGS.En, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        % Update the decoder parameters
        [paramsDe,avgG.De,avgGS.De] = ...
            adamupdate(paramsDe, GradDe, ...
            avgG.De, avgGS.De, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        % Refresh the grid of generated samples every 20 iterations
        if i==1 || rem(i,20)==0
            progressplot(paramsDe,settings);
            if i==1
                % Capture the figure
                h = gcf;
                frame = getframe(h);
                im = frame2im(frame);
                [imind,cm] = rgb2ind(im,256);
                % Write to the GIF file (one frame per epoch)
                if epoch == 0
                    imwrite(imind,cm,'AAEmnist.gif','gif','Loopcount',inf);
                else
                    imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append');
                end
            end
        end
    end
    elapsedTime = toc;
    disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime+"s")
    epoch = epoch+1;
    if epoch == settings.maxepochs
        out = true;
    end
end
```
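The batch slicing in the loop is easy to check on toy sizes (10 samples, batch size 3, chosen only for illustration). The index formula is the same as `idx` in the loop:

```matlab
% How the training loop slices a shuffled data set into minibatches (toy sizes).
numSamples = 10; batchSize = 3;
data = 1:numSamples;                      % stand-in for the columns of trainX
shuffled = data(randperm(numSamples));    % same role as trainXshuffle

numIterations = floor(numSamples/batchSize);
batches = zeros(numIterations, batchSize);
for i = 1:numIterations
    idx = (i-1)*batchSize+1 : i*batchSize;   % same index formula as the training loop
    batches(i,:) = shuffled(idx);
end
disp(batches)
```

Because of `floor`, the last `numSamples - numIterations*batchSize` samples of each shuffled epoch are skipped. Reshuffling every epoch means a different subset is skipped each time.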
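Below are the complete helper functions. If everything is kept in a single script, these local functions must be placed at the end of the file.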
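Model gradient function. It samples latent codes with the reparameterization trick and computes the discriminator loss and the weighted encoder/decoder loss. It then returns the gradients for all three networks. The spread must be sliced with `,:` and multiplied elementwise (`.*`) so that every image in the batch gets its own noise:
```matlab
function [GradEn,GradDe,GradDis] = modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
    dly = Encoder(x,paramsEn);
    % Reparameterization: mean (first half) + spread (second half) .* noise
    latent_fake = dly(1:settings.latent_dim,:) + ...
        dly(settings.latent_dim+1:2*settings.latent_dim,:) .* ...
        randn(settings.latent_dim,settings.batch_size);
    % Samples from the prior p(z) = N(0, I) act as the "real" data
    latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');
    % Discriminator loss
    d_output_fake = Discriminator(latent_fake,paramsDis);
    d_output_real = Discriminator(latent_real,paramsDis);
    d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));
    % Encoder/decoder loss: weighted reconstruction error + adversarial term
    x_ = Decoder(latent_fake,paramsDe);
    g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));
    % Gradients of each loss with respect to each network's parameters
    [GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
    GradDis = dlgradient(d_loss,paramsDis);
end
```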
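The two loss formulas can be sanity-checked on made-up discriminator outputs instead of a trained network, which is an assumption for illustration only. A discriminator that outputs 0.5 for everything, i.e. one that cannot tell prior samples from encoder codes, gives `d_loss = log(2)`. A confident, correct one drives it toward 0:

```matlab
% Same discriminator loss formula as in modelGradients, on plain arrays.
lossD = @(dReal, dFake) -.5*mean(log(dReal+eps) + log(1-dFake+eps));

chance = lossD(0.5*ones(1,32), 0.5*ones(1,32));     % undecided discriminator
good   = lossD(0.99*ones(1,32), 0.01*ones(1,32));   % confident, correct discriminator

% Same weighted encoder/decoder loss, on toy data.
x     = 2*rand(784,32) - 1;            % toy "images" in [-1, 1]
x_    = x + 0.1*randn(784,32);         % toy reconstructions
dFake = 0.3*ones(1,32);                % toy discriminator scores on encoder codes
gLoss = .999*mean(mean(.5*(x_-x).^2,1)) - .001*mean(log(dFake+eps));

fprintf('chance: %.4f, good: %.4f, gLoss: %.4f\n', chance, good, gLoss);
```

The 0.999/0.001 weights keep reconstruction dominant, so the adversarial term acts as a mild regularizer on the latent codes.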
Data extraction function (returns the underlying data of a dlarray as a plain CPU array):
```matlab
function x = gatext(x)
    x = gather(extractdata(x));
end
```
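Wrapper that creates a deep learning array on the GPU (`gpuArray` requires Parallel Computing Toolbox and a supported GPU):
```matlab
function dlx = gpdl(x,labels)
    dlx = gpuArray(dlarray(x,labels));
end
```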
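Weight initialization function (zero-mean Gaussian with standard deviation sigma, 0.05 by default):
```matlab
function parameter = initializeGaussian(parameterSize,sigma)
    if nargin < 2
        sigma = 0.05;
    end
    parameter = randn(parameterSize, 'single') .* sigma;
end
```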
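A quick empirical check of `initializeGaussian`: the entries should be single precision with mean close to 0 and standard deviation close to sigma. The body is reproduced inline with sigma = 0.02, the value used for the first layer of each network. Layers initialized without a second argument get the default 0.05:

```matlab
% Inline copy of initializeGaussian([512,784],.02), followed by its statistics.
sigma = 0.02;
W = randn([512, 784], 'single') .* sigma;
fprintf('class %s, mean %.5f, std %.5f\n', class(W), mean(W(:)), std(W(:)));
```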
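Dropout function. It is defined here but not called anywhere in the AAE code. It zeroes each element with probability p: it draws integers from 1 to d, where n/d = rat(p), and drops those that are at most n. The surviving elements are not rescaled.
```matlab
function dly = dropout(dlx,p)
    if nargin < 2
        p = .3;
    end
    [n,d] = rat(p);                  % p = n/d as a rational approximation
    mask = randi([1,d],size(dlx));   % uniform integers in 1..d
    mask(mask<=n) = 0;               % drop with probability n/d
    mask(mask>n) = 1;                % keep otherwise
    dly = dlx.*mask;
end
```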
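If dropout were ever added to these networks, a common alternative is inverted dropout. It divides the kept units by 1 - p so that the expected activation stays the same, and no rescaling is needed at test time. The sketch below is a hypothetical helper (`invertedDropout`, not part of the original code). It uses a uniform `rand` mask instead of the `rat`/`randi` construction:

```matlab
% Hypothetical inverted-dropout helper: drop with probability p, rescale survivors.
invertedDropout = @(x, p) x .* (rand(size(x)) >= p) ./ (1 - p);

x = ones(1000, 1000);
y = invertedDropout(x, 0.3);
fprintf('fraction dropped: %.3f, mean activation: %.3f\n', mean(y(:) == 0), mean(y(:)));
```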
Encoder function:
```matlab
function dly = Encoder(dlx,params)
    dly = fullyconnect(dlx,params.FCW1,params.FCb1);
    dly = leakyrelu(dly,.2);
    dly = fullyconnect(dly,params.FCW2,params.FCb2);
    dly = leakyrelu(dly,.2);
    dly = fullyconnect(dly,params.FCW3,params.FCb3);
    dly = leakyrelu(dly,.2);
end
```
Decoder function (the final tanh keeps the outputs in (-1, 1), the same range as the preprocessed images):
```matlab
function dly = Decoder(dlx,params)
    dly = fullyconnect(dlx,params.FCW1,params.FCb1);
    dly = leakyrelu(dly,.2);
    dly = fullyconnect(dly,params.FCW2,params.FCb2);
    dly = leakyrelu(dly,.2);
    dly = fullyconnect(dly,params.FCW3,params.FCb3);
    dly = leakyrelu(dly,.2);
    dly = tanh(dly);
end
```
Discriminator function (the final sigmoid outputs the probability that a latent code was drawn from the prior p(z)):
```matlab
function dly = Discriminator(dlx,params)
    dly = fullyconnect(dlx,params.FCW1,params.FCb1);
    dly = leakyrelu(dly,.2);
    dly = fullyconnect(dly,params.FCW2,params.FCb2);
    dly = leakyrelu(dly,.2);
    dly = fullyconnect(dly,params.FCW3,params.FCb3);
    dly = sigmoid(dly);
end
```
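To see what these three functions compute without the Deep Learning Toolbox, here is a plain-array sketch of the discriminator's forward pass. It assumes that, for 'CB'-formatted inputs, `fullyconnect(X,W,b)` computes `W*X + b` and `leakyrelu(X,.2)` computes `max(X, .2*X)`. The weight scales follow the initialization code (0.02 for the first layer, the 0.05 default otherwise):

```matlab
% Plain-array sketch of the Discriminator forward pass.
latentDim = 10; batchSize = 32;
W1 = 0.02*randn(512, latentDim); b1 = zeros(512,1);
W2 = 0.05*randn(256, 512);       b2 = zeros(256,1);
W3 = 0.05*randn(1, 256);         b3 = 0;

lrelu    = @(x) max(x, .2*x);           % leaky ReLU with slope 0.2
logistic = @(x) 1 ./ (1 + exp(-x));     % sigmoid

z = randn(latentDim, batchSize);        % e.g. samples from the prior p(z)
h = lrelu(W1*z + b1);                   % 512-by-32
h = lrelu(W2*h + b2);                   % 256-by-32
score = logistic(W3*h + b3);            % 1-by-32: one probability per latent code
disp(size(score))
```

The encoder and decoder follow the same pattern with different layer sizes and final activations.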
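Dynamic progress plot (decodes 25 random samples from the prior and shows them as a 5-by-5 grid):
```matlab
function progressplot(paramsDe,settings)
    r = 5; c = 5;
    % Decode r*c random samples from the prior p(z)
    noise = gpdl(randn([settings.latent_dim,r*c]),'CB');
    gen_imgs = Decoder(noise,paramsDe);
    gen_imgs = reshape(gen_imgs,28,28,[]);
    % Clear the previous plot
    fig = gcf;
    if ~isempty(fig.Children)
        delete(fig.Children)
    end
    % Tile the images into one grid and display it
    I = imtile(gatext(gen_imgs));
    I = rescale(I);
    imagesc(I)
    title("Generated Images")
    colormap gray
    drawnow;
end
```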
Finally, let's look at the generated GIF animation, saved as AAEmnist.gif.
Upcoming posts will cover:
(1) Auxiliary Classifier Generative Adversarial Network
(2) Conditional Generative Adversarial Network
(3) Deep Convolutional Generative Adversarial Network
(4) the most basic Generative Adversarial Network
(5) Info Generative Adversarial Network
(6) Least Squares Generative Adversarial Network
(7) the well-known Pixels-to-Pixels
(8) Semi-Supervised Generative Adversarial Network
(9) the well-known Wasserstein Generative Adversarial Network
References:
- Y. LeCun and C. Cortes, "MNIST handwritten digit database," 2010. [MNIST]
- J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei, "ImageNet: A Large-Scale Hierarchical Image Database," in CVPR, 2009. [Apple2Orange (ImageNet)]
- R. Tyleček and R. Šára, "Spatial pattern templates for recognition of objects with regular structure," in Proc. GCPR, Saarbrücken, Germany, 2013. [Facade]
- Z. Liu, P. Luo, X. Wang, and X. Tang, "Deep learning face attributes in the wild," in Proc. International Conference on Computer Vision (ICCV), Dec. 2015. [CelebA]
- I. J. Goodfellow et al., "Generative Adversarial Networks," arXiv:1406.2661, 2014. [GAN]
- A. Radford et al., "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks," arXiv:1511.06434, 2015. [DCGAN]
- E. L. Denton et al., "Semi-Supervised Learning with Context-Conditional Generative Adversarial Networks," arXiv:1611.06430, 2016. [CCGAN]
- A. Odena et al., "Conditional Image Synthesis with Auxiliary Classifier GANs," in Proc. ICML, 2017. [ACGAN]
- X. Chen et al., "InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets," in NIPS, 2016. [InfoGAN]
- A. Makhzani et al., "Adversarial Autoencoders," arXiv:1511.05644, 2015. [AAE]
- P. Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks," in Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017, pp. 5967-5976. [Pix2Pix]
- J.-Y. Zhu, T. Park, P. Isola, and A. A. Efros, "Unpaired image-to-image translation using cycle-consistent adversarial networks," 2017. [CycleGAN]
- M. Arjovsky et al., "Wasserstein GAN," arXiv:1701.07875, 2017. [WGAN]
- A. Odena, "Semi-Supervised Learning with Generative Adversarial Networks," arXiv:1606.01583, 2016. [SGAN]
For more details, see the Zhihu article:
MATLAB Generative Adversarial Network Series (continuously updated), by 哥廷根数学学派, Zhihu: https://zhuanlan.zhihu.com/p/565101258