1. Introduction


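2. Overview

2.1. What is a GAN (Generative Adversarial Network)

GANs are a framework for teaching a deep learning model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets.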



Now, let's define some notation that will be used throughout the tutorial, starting with the discriminator. Let $x$ be data representing an image. $D(x)$ is the discriminator network, which outputs the (scalar) probability that $x$ came from the training data rather than from the generator. Since we are dealing with images, the input to $D(x)$ is an image of CHW size 3x64x64. Intuitively, $D(x)$ should be high when $x$ comes from the training data and low when $x$ comes from the generator. $D(x)$ can also be thought of as a traditional binary classifier.

For the generator's notation, let $z$ be a latent space vector sampled from a standard normal distribution. $G(z)$ represents the generator function, which maps the latent vector $z$ to data space. The goal of $G$ is to estimate the distribution that the training data comes from ($p_{data}$), so it can generate fake samples from that estimated distribution ($p_g$).

So, $D(G(z))$ is the probability (a scalar) that the output of the generator $G$ is a real image. As described in Goodfellow's paper, $D$ and $G$ play a minimax game. $D$ tries to maximize the probability that it correctly classifies reals and fakes ($\log D(x)$). $G$ tries to minimize the probability that $D$ will predict its outputs are fake ($\log(1-D(G(z)))$). In the paper, the GAN loss function is:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

In theory, the solution to this minimax game is $p_g = p_{data}$. At that point the discriminator can only guess randomly whether its inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in practice models do not always train to this point.
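To make the equilibrium claim concrete, here is a small numerical sketch on a toy discrete distribution. The three-point p_data below is an assumption for illustration, not CelebA. The sketch relies on two standard results from Goodfellow's paper:

- For a fixed $G$, the optimal discriminator is $D^*(x) = p_{data}(x) / (p_{data}(x) + p_g(x))$.
- The second expectation can be rewritten over $x = G(z) \sim p_g$.

When $p_g = p_{data}$, $D^*$ equals 0.5 everywhere and $V$ equals $-\log 4$.

```python
import math

def optimal_d(p_data_x, p_g_x):
    """Optimal discriminator for a fixed G: D*(x) = p_data(x) / (p_data(x) + p_g(x))."""
    return p_data_x / (p_data_x + p_g_x)

def expected_value(p_data, p_g):
    """V(D*, G) = E_{x~p_data}[log D*(x)] + E_{x~p_g}[log(1 - D*(x))] on a discrete support."""
    total = 0.0
    for pd, pg in zip(p_data, p_g):
        d = optimal_d(pd, pg)
        if pd > 0:
            total += pd * math.log(d)
        if pg > 0:
            total += pg * math.log(1.0 - d)
    return total

# Toy data distribution over three points (assumed for illustration only)
p_data = [0.2, 0.5, 0.3]

# If the generator matches the data exactly, D* is 0.5 everywhere
print(expected_value(p_data, p_data), -math.log(4))  # both approximately -1.3863
```

Any $p_g \ne p_{data}$ gives a larger value, for example `expected_value(p_data, [0.6, 0.2, 0.2])`. The gap is twice the Jensen-Shannon divergence between the two distributions, which is why $p_g = p_{data}$ is the generator's minimizer.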

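2.2. What is a DCGAN (Deep Convolutional Generative Adversarial Network)

A DCGAN is a variant of the GAN described above. It explicitly uses convolutional layers in the discriminator and transposed convolution layers in the generator. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.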


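The generator is made up of transposed convolution layers, batch norm layers, and ReLU activations. The input is a latent vector $z$ drawn from a standard normal distribution, and the output is a 3x64x64 RGB image. The strided transposed convolution layers transform the latent vector into a volume with the same shape as a real image. In the paper, the authors also give tips on how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights. All of these are explained in the coming sections.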

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)


Random Seed:  999

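3. Inputs

  • dataroot: path to the root of the dataset folder
  • workers: number of worker processes the DataLoader uses to load data
  • batch_size: batch size used in training. The DCGAN paper uses 128
  • image_size: spatial size of the training images. The default is 64x64. If you need another size, the structures of D and G must be changed
  • nc: number of color channels in the input images. This is 3 for color images
  • nz: length of the latent vector
  • ngf: relates to the depth of the feature maps carried through the generator
  • ndf: sets the depth of the feature maps propagated through the discriminator
  • num_epochs: total number of training epochs. Training for longer will probably lead to better results but will also take much longer
  • lr: learning rate. The DCGAN paper uses 0.0002
  • beta1: the beta1 hyperparameter of the Adam optimizers. The paper uses 0.5
  • ngpu: number of GPUs available. If 0, the code runs in CPU mode; if greater than 0, it runs on that number of GPUs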
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transform.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

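4. Data

In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded from its official website or from Google Drive. The download is a zip file named img_align_celeba.zip.

1. Create a directory named celeba and extract the zip file into it.
2. Set dataroot to that celeba directory, not to img_align_celeba.

This matters because ImageFolder treats each subdirectory of its root as one class, so the images must sit one level below dataroot. The resulting directory structure should be: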

    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg


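# We can use an image folder dataset the way we have it set up.
# Create the dataset: resize, center-crop, convert to a tensor in [0, 1], then normalize to [-1, 1]
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))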
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
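The Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform in the dataset code above puts the training images into the same [-1, 1] range that the generator's final tanh layer produces. Below is a minimal sketch of the per-channel arithmetic. It assumes pixel values have already been scaled to [0, 1] by ToTensor(), and it uses plain Python floats in place of tensors.

```python
def normalize(v, mean=0.5, std=0.5):
    """Per-channel mapping applied by transforms.Normalize: (v - mean) / std."""
    return (v - mean) / std

def denormalize(v, mean=0.5, std=0.5):
    """Inverse mapping, e.g. to turn values in [-1, 1] back into [0, 1]."""
    return v * std + mean

# ToTensor() yields pixel values in [0, 1]; mean=0.5, std=0.5 maps them onto [-1, 1]
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```

For display, vutils.make_grid(..., normalize=True) shifts the values back into [0, 1] using the minimum and maximum of the grid. That is why the plotting code never calls an explicit inverse transform.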

5. Implementation


5.1. Weight Initialization


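# custom weights initialization called on netG and netD
# (DCGAN paper: weights drawn from a zero-centered normal distribution with stdev=0.02)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Conv2d and ConvTranspose2d: weight ~ N(0, 0.02)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm2d: scale ~ N(1, 0.02), shift = 0
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)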
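weights_init dispatches on substrings of the class name. Because 'Conv' matches both Conv2d and ConvTranspose2d, the generator's transposed convolutions get the same N(0, 0.02) initialization as the discriminator's convolutions. Layers without weights, such as ReLU, LeakyReLU, Tanh and Sigmoid, fall through both branches. The sketch below mirrors that branch logic on plain strings. init_rule is a hypothetical helper for illustration, not part of the tutorial.

```python
def init_rule(classname):
    """Hypothetical helper: report which branch of weights_init a class name would hit."""
    if classname.find('Conv') != -1:
        return "weight ~ N(0.0, 0.02)"
    elif classname.find('BatchNorm') != -1:
        return "weight ~ N(1.0, 0.02), bias = 0"
    return "unchanged (PyTorch default init)"

# Module.apply visits every submodule and the module itself, so containers are checked too
for name in ["ConvTranspose2d", "Conv2d", "BatchNorm2d", "ReLU", "LeakyReLU", "Tanh", "Sequential"]:
    print(f"{name:16s} -> {init_rule(name)}")
```

Matching on names is convenient, but it would also catch any custom class whose name happens to contain 'Conv'. An isinstance check against nn.Conv2d, nn.ConvTranspose2d and nn.BatchNorm2d is a stricter alternative.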

5.2. Generator

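The generator $G$ maps the latent space vector $z$ to data space. Since our data are images, converting $z$ to data space means creating an RGB image with the same size as the training images (i.e., 3x64x64).

In practice, this is done with a series of strided two-dimensional transposed convolution layers. Each layer except the last is paired with a 2d batch norm layer and a ReLU activation. The output of the generator is passed through a tanh function, which returns it to the input data range of [-1, 1].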
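The spatial sizes in the generator follow from the ConvTranspose2d output-size formula in the PyTorch documentation:

$$H_{out} = (H_{in} - 1) \times stride - 2 \times padding + dilation \times (kernel\_size - 1) + output\_padding + 1$$

The short sketch below traces a 1x1 latent input up to 64x64. It assumes the (kernel, stride, padding) settings used by the generator code in this section and tracks only the spatial size, not the channels.

```python
def conv_transpose_out(h_in, kernel, stride, padding, output_padding=0, dilation=1):
    """Output size of ConvTranspose2d along one spatial dimension."""
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# (kernel, stride, padding) of the five ConvTranspose2d layers in the generator
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # z is fed in as an nz x 1 x 1 volume
sizes = [size]
for k, s, p in layers:
    size = conv_transpose_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer (stride 1, no padding) turns the 1x1 input into a 4x4 map. Each later layer with kernel 4, stride 2 and padding 1 exactly doubles the spatial size.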


[Figure: DCGAN generator architecture]

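Note how the values set in the Inputs section (nz, ngf and nc) shape the generator architecture in code:

- nz is the length of the input vector $z$.
- ngf relates to the size of the feature maps propagated through the generator.
- nc is the number of channels in the output image (3 for RGB images).

Below is the code for the generator.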

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            # squash the output into [-1, 1] to match the normalized training data
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
    def forward(self, input):
        return self.main(input)


# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
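The printed model also lets us count the generator's learnable parameters by hand:

- Each ConvTranspose2d with bias=False holds in_channels x out_channels x 4 x 4 weights.
- Each BatchNorm2d holds one weight and one bias per channel. Its running mean and variance are buffers, not parameters.

Below is a sketch using the nz, ngf and nc values from the Inputs section.

```python
nz, ngf, nc = 100, 64, 3  # values from the Inputs section

def conv_params(c_in, c_out, k=4):
    """Weights of a (transposed) conv layer with bias=False."""
    return c_in * c_out * k * k

def bn_params(c):
    """Learnable affine parameters of BatchNorm2d: one weight and one bias per channel."""
    return 2 * c

# Channel progression through the generator: nz -> ngf*8 -> ngf*4 -> ngf*2 -> ngf -> nc
channels = [nz, ngf * 8, ngf * 4, ngf * 2, ngf, nc]
conv_total = sum(conv_params(a, b) for a, b in zip(channels, channels[1:]))
bn_total = sum(bn_params(c) for c in channels[1:-1])  # no BatchNorm after the last layer
total = conv_total + bn_total
print(total)  # 3576704
```

Under these settings the total should agree with `sum(p.numel() for p in netG.parameters())`. The 512 to 256 layer alone holds more than half of the generator's parameters.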

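5.3. Discriminator

As mentioned, the discriminator $D$ is a binary classification network. It takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake).

Here, $D$ takes a 3x64x64 input image and processes it through a series of Conv2d, BatchNorm2d and LeakyReLU layers. It then outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if the problem requires it, but the use of strided convolution, BatchNorm and LeakyReLU is significant. The DCGAN paper says it is good practice to downsample with strided convolution rather than pooling, because this lets the network learn its own pooling function. Batch norm and leaky ReLU also promote healthy gradient flow, which is critical for the learning process of both $G$ and $D$.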
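On the downsampling side, the Conv2d output-size formula from the PyTorch documentation shows how each stride-2 convolution halves the spatial size:

$$H_{out} = \lfloor (H_{in} + 2 \times padding - dilation \times (kernel\_size - 1) - 1) / stride \rfloor + 1$$

The sketch below assumes the (kernel, stride, padding) settings of the discriminator code that follows.

```python
def conv_out(h_in, kernel, stride, padding, dilation=1):
    """Output size of Conv2d along one spatial dimension (floor division)."""
    return (h_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# (kernel, stride, padding) of the five Conv2d layers in the discriminator
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input image is nc x 64 x 64
sizes = [size]
for k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [64, 32, 16, 8, 4, 1]
```

The final 4x4 convolution with no padding collapses the 4x4 map to a single value per image, and the Sigmoid turns that value into the probability $D(x)$. The training loop's .view(-1) then flattens the N x 1 x 1 x 1 output into a vector of length N.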


class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # map the single output value to a probability in (0, 1)
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.main(input)


# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)


Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

5.4. Loss Functions and Optimizers

With $D$ and $G$ set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss), which PyTorch defines as:

$$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -\left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$$

This function provides both log components of the objective function, i.e. $\log(D(x))$ and $\log(1-D(G(z)))$. The $y$ input selects which part of the BCE equation is used. This happens in the upcoming training loop. The key point is that changing $y$ (i.e., the GT labels) alone chooses which component we compute.
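The sketch below implements the per-element formula with plain Python floats rather than tensors, to show how the label $y$ selects a term. d_out is a hypothetical discriminator output.

```python
import math

def bce(x, y):
    """Per-element binary cross entropy: l = -[y*log(x) + (1-y)*log(1-x)]."""
    return -(y * math.log(x) + (1 - y) * math.log(1 - x))

d_out = 0.8  # hypothetical discriminator output for one image

# y = 1 keeps only the log(x) term; y = 0 keeps only the log(1 - x) term
print(math.isclose(bce(d_out, 1.0), -math.log(d_out)))      # True
print(math.isclose(bce(d_out, 0.0), -math.log(1 - d_out)))  # True
```

nn.BCELoss also averages over the batch by default (reduction='mean') and clamps its log outputs to be at least -100. This sketch reproduces neither behavior.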

Next, we define the real label as 1 and the fake label as 0. These labels are used when calculating the losses of $D$ and $G$, following the convention of the original GAN paper.

Finally, we set up two separate optimizers, one for $D$ and one for $G$. As specified in the DCGAN paper, both are Adam optimizers with a learning rate of 0.0002 and Beta1 = 0.5. To track the generator's learning progress, we also create a fixed batch of latent vectors drawn from a Gaussian distribution (fixed_noise). During training we periodically feed fixed_noise into $G$, and over the iterations we will see images form out of the noise.

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

5.5. Training


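Here we closely follow Algorithm 1 from Goodfellow's paper, while applying some of the best practices shown in ganhacks. Specifically, we construct separate mini-batches for real and fake images, and we adjust $G$'s objective function to maximize $\log D(G(z))$. Training is split into two main parts: Part 1 updates the discriminator, and Part 2 updates the generator.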

5.5.1. Part 1 - Train the Discriminator


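In practice, we want to maximize $\log(D(x)) + \log(1-D(G(z)))$. Following the separate mini-batch suggestion from ganhacks, we compute this in two steps:

1. Build a batch of real samples from the training set and forward pass it through $D$. Calculate the loss ($\log(D(x))$), then calculate the gradients in a backward pass.
2. Build a batch of fake samples with the current generator and forward pass it through $D$. Calculate the loss ($\log(1-D(G(z)))$), and accumulate the gradients with a second backward pass.

With the gradients from both the all-real and all-fake batches accumulated, we call a step of the discriminator's optimizer.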

5.5.2. Part 2 - Train the Generator

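As stated in the original paper, we want to train the generator by minimizing $\log(1-D(G(z)))$ so that it produces better fakes. However, Goodfellow showed that this does not provide sufficient gradients, especially early in the learning process. The fix is to maximize $\log(D(G(z)))$ instead.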
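A quick way to see the gradient problem is to write $D(G(z)) = \sigma(a)$, where $a$ is the input to $D$'s final Sigmoid. Then:

- $\frac{\partial}{\partial a}\log(1-\sigma(a)) = -\sigma(a)$
- $\frac{\partial}{\partial a}\log\sigma(a) = 1-\sigma(a)$

Early in training, $D$ rejects fakes confidently, so $\sigma(a)$ is close to 0. The first gradient then nearly vanishes, while the second stays close to 1. The sketch below evaluates both; the logit value -6.0 is an arbitrary choice.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_minimax(a):
    """d/da log(1 - sigmoid(a)) = -sigmoid(a): the term G would minimize in the original game."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da log(sigmoid(a)) = 1 - sigmoid(a): the term G maximizes instead."""
    return 1.0 - sigmoid(a)

a = -6.0  # assumed logit at which D confidently rejects a fake (D(G(z)) is about 0.0025)
print(f"D(G(z))                = {sigmoid(a):.4f}")
print(f"|grad| minimax         = {abs(grad_minimax(a)):.4f}")
print(f"|grad| non-saturating  = {abs(grad_non_saturating(a)):.4f}")
```

Both objectives push $D(G(z))$ in the same direction. Only the magnitude of the learning signal differs in the regime where the generator is still poor.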

In code, we do this in four steps:

1. Classify the generator output from Part 1 with the discriminator.
2. Compute $G$'s loss using the real labels as GT.
3. Compute $G$'s gradients in a backward pass.
4. Update $G$'s parameters with an optimizer step.

Using the real labels as GT for the loss may seem counter-intuitive. However, it lets us use the $\log(x)$ part of BCELoss (rather than the $\log(1-x)$ part), which is exactly what we want.

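Finally, we report some statistics. Every 500 iterations, and once more at the end of the last epoch, we push the fixed_noise batch through the generator to visually track $G$'s training progress. The training statistics reported are:

  • Loss_D - discriminator loss, calculated as the sum of the losses on the all-real and all-fake batches. Through BCELoss this equals $-\left[\log(D(x)) + \log(1-D(G(z)))\right]$, with each term averaged over its batch.
  • Loss_G - generator loss, calculated through BCELoss as $-\log(D(G(z)))$, averaged over the batch.
  • D(x) - the average discriminator output over the all-real batch. It should start close to 1 and then, in theory, converge to 0.5 as $G$ gets better. Think about why.
  • D(G(z)) - the average discriminator output over the all-fake batch. The first number is measured before $D$ is updated and the second after $D$ is updated. These numbers should start near 0 and converge to 0.5 as $G$ gets better. Think about why.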
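The sketch below computes what Loss_D and Loss_G would be at the theoretical equilibrium, where $D$ outputs 0.5 for every input. These are per-sample values; the training loop averages them over each batch. They give a reference point for reading the training log.

```python
import math

def loss_d(d_real, d_fake):
    """Loss_D as computed in the loop: BCE(D(x), 1) + BCE(D(G(z)), 0)."""
    return -math.log(d_real) - math.log(1.0 - d_fake)

def loss_g(d_fake):
    """Loss_G as computed in the loop: BCE(D(G(z)), 1)."""
    return -math.log(d_fake)

# At the theoretical equilibrium D outputs 0.5 for real and fake inputs alike
print(round(loss_d(0.5, 0.5), 4), round(loss_g(0.5), 4))  # 1.3863 0.6931
```

Real runs rarely settle exactly at these values. Treat them only as a yardstick for interpreting the reported statistics, not as a convergence target.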


# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D (detach so no gradients flow into G here)
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with the previous ones
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as the sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1


Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.9847  Loss_G: 5.5914  D(x): 0.6004    D(G(z)): 0.6680 / 0.0062
[0/5][50/1583]  Loss_D: 0.4017  Loss_G: 17.8778 D(x): 0.8368    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 2.8508  Loss_G: 22.8236 D(x): 0.9634    D(G(z)): 0.8460 / 0.0000
[0/5][150/1583] Loss_D: 0.2360  Loss_G: 5.4596  D(x): 0.8440    D(G(z)): 0.0308 / 0.0090
[0/5][200/1583] Loss_D: 1.6425  Loss_G: 4.7064  D(x): 0.3414    D(G(z)): 0.0079 / 0.0176
[0/5][250/1583] Loss_D: 0.2731  Loss_G: 4.4791  D(x): 0.9431    D(G(z)): 0.1680 / 0.0225
[0/5][300/1583] Loss_D: 0.6051  Loss_G: 4.6251  D(x): 0.8278    D(G(z)): 0.2424 / 0.0230
[0/5][350/1583] Loss_D: 0.7070  Loss_G: 1.6842  D(x): 0.6204    D(G(z)): 0.0824 / 0.2560
[0/5][400/1583] Loss_D: 0.6758  Loss_G: 4.0679  D(x): 0.9354    D(G(z)): 0.3946 / 0.0288
[0/5][450/1583] Loss_D: 0.5348  Loss_G: 5.7453  D(x): 0.9625    D(G(z)): 0.3514 / 0.0083
[0/5][500/1583] Loss_D: 0.6896  Loss_G: 7.8784  D(x): 0.9364    D(G(z)): 0.4080 / 0.0012
[0/5][550/1583] Loss_D: 0.4377  Loss_G: 8.1336  D(x): 0.9425    D(G(z)): 0.2840 / 0.0007
[0/5][600/1583] Loss_D: 1.8797  Loss_G: 2.5577  D(x): 0.3201    D(G(z)): 0.0123 / 0.1258
[0/5][650/1583] Loss_D: 1.3832  Loss_G: 10.6947 D(x): 0.9770    D(G(z)): 0.7006 / 0.0001
[0/5][700/1583] Loss_D: 0.3195  Loss_G: 3.7833  D(x): 0.8474    D(G(z)): 0.0844 / 0.0789
[0/5][750/1583] Loss_D: 0.2142  Loss_G: 4.1755  D(x): 0.8942    D(G(z)): 0.0813 / 0.0232
[0/5][800/1583] Loss_D: 1.4535  Loss_G: 2.3077  D(x): 0.4024    D(G(z)): 0.0111 / 0.1806
[0/5][850/1583] Loss_D: 0.4109  Loss_G: 6.3312  D(x): 0.9002    D(G(z)): 0.2153 / 0.0048
[0/5][900/1583] Loss_D: 2.7930  Loss_G: 4.5548  D(x): 0.1428    D(G(z)): 0.0022 / 0.0240
[0/5][950/1583] Loss_D: 0.3493  Loss_G: 5.5976  D(x): 0.8767    D(G(z)): 0.1498 / 0.0080
[0/5][1000/1583]        Loss_D: 0.6749  Loss_G: 5.0457  D(x): 0.6349    D(G(z)): 0.0215 / 0.0194
[0/5][1050/1583]        Loss_D: 0.4009  Loss_G: 4.5791  D(x): 0.7669    D(G(z)): 0.0484 / 0.0260
[0/5][1100/1583]        Loss_D: 0.3453  Loss_G: 2.7277  D(x): 0.8885    D(G(z)): 0.1408 / 0.1219
[0/5][1150/1583]        Loss_D: 0.2484  Loss_G: 5.0396  D(x): 0.8727    D(G(z)): 0.0595 / 0.0174
[0/5][1200/1583]        Loss_D: 0.6760  Loss_G: 3.2315  D(x): 0.7052    D(G(z)): 0.1756 / 0.0688
[0/5][1250/1583]        Loss_D: 0.5845  Loss_G: 3.1392  D(x): 0.7576    D(G(z)): 0.2018 / 0.0673
[0/5][1300/1583]        Loss_D: 0.2762  Loss_G: 4.9311  D(x): 0.8666    D(G(z)): 0.0933 / 0.0136
[0/5][1350/1583]        Loss_D: 0.4753  Loss_G: 4.7346  D(x): 0.8595    D(G(z)): 0.2228 / 0.0170
[0/5][1400/1583]        Loss_D: 0.3764  Loss_G: 5.9964  D(x): 0.7758    D(G(z)): 0.0109 / 0.0098
[0/5][1450/1583]        Loss_D: 0.4025  Loss_G: 3.8804  D(x): 0.8158    D(G(z)): 0.1413 / 0.0320
[0/5][1500/1583]        Loss_D: 0.6678  Loss_G: 2.7302  D(x): 0.6980    D(G(z)): 0.1486 / 0.1040
[0/5][1550/1583]        Loss_D: 0.6062  Loss_G: 3.1664  D(x): 0.7235    D(G(z)): 0.1305 / 0.0783
[1/5][0/1583]   Loss_D: 0.6615  Loss_G: 8.0512  D(x): 0.9412    D(G(z)): 0.3797 / 0.0007
[1/5][50/1583]  Loss_D: 0.8057  Loss_G: 2.1089  D(x): 0.5929    D(G(z)): 0.0869 / 0.1893
[1/5][100/1583] Loss_D: 0.4206  Loss_G: 3.3245  D(x): 0.7409    D(G(z)): 0.0554 / 0.0640
[1/5][150/1583] Loss_D: 0.6361  Loss_G: 4.0774  D(x): 0.7830    D(G(z)): 0.2605 / 0.0256
[1/5][200/1583] Loss_D: 1.7394  Loss_G: 7.5861  D(x): 0.9685    D(G(z)): 0.7499 / 0.0014
[1/5][250/1583] Loss_D: 0.4597  Loss_G: 3.1064  D(x): 0.7053    D(G(z)): 0.0265 / 0.0844
[1/5][300/1583] Loss_D: 0.4190  Loss_G: 2.2869  D(x): 0.7942    D(G(z)): 0.1163 / 0.1660
[1/5][350/1583] Loss_D: 0.4724  Loss_G: 4.3673  D(x): 0.8292    D(G(z)): 0.2106 / 0.0213
[1/5][400/1583] Loss_D: 0.2877  Loss_G: 4.3217  D(x): 0.8823    D(G(z)): 0.1125 / 0.0225
[1/5][450/1583] Loss_D: 0.8508  Loss_G: 0.8635  D(x): 0.5397    D(G(z)): 0.0390 / 0.5324
[1/5][500/1583] Loss_D: 0.4317  Loss_G: 3.1585  D(x): 0.7646    D(G(z)): 0.0931 / 0.0767
[1/5][550/1583] Loss_D: 0.8256  Loss_G: 6.1484  D(x): 0.9395    D(G(z)): 0.4563 / 0.0051
[1/5][600/1583] Loss_D: 0.9765  Loss_G: 1.5017  D(x): 0.4807    D(G(z)): 0.0076 / 0.2843
[1/5][650/1583] Loss_D: 1.8020  Loss_G: 8.8270  D(x): 0.9480    D(G(z)): 0.7248 / 0.0003
[1/5][700/1583] Loss_D: 0.3680  Loss_G: 3.7401  D(x): 0.7991    D(G(z)): 0.0949 / 0.0404
[1/5][750/1583] Loss_D: 0.5763  Loss_G: 2.0559  D(x): 0.6739    D(G(z)): 0.0851 / 0.1882
[1/5][800/1583] Loss_D: 0.7773  Loss_G: 5.0999  D(x): 0.9399    D(G(z)): 0.4335 / 0.0142
[1/5][850/1583] Loss_D: 0.3901  Loss_G: 3.4356  D(x): 0.8537    D(G(z)): 0.1744 / 0.0491
[1/5][900/1583] Loss_D: 0.7268  Loss_G: 6.5356  D(x): 0.9635    D(G(z)): 0.4428 / 0.0027
[1/5][950/1583] Loss_D: 0.4570  Loss_G: 3.8893  D(x): 0.8707    D(G(z)): 0.2376 / 0.0304
[1/5][1000/1583]        Loss_D: 1.3551  Loss_G: 7.2447  D(x): 0.9333    D(G(z)): 0.6422 / 0.0030
[1/5][1050/1583]        Loss_D: 0.3905  Loss_G: 3.3360  D(x): 0.8183    D(G(z)): 0.1462 / 0.0537
[1/5][1100/1583]        Loss_D: 1.3858  Loss_G: 0.9796  D(x): 0.3336    D(G(z)): 0.0259 / 0.4584
[1/5][1150/1583]        Loss_D: 0.5776  Loss_G: 2.6197  D(x): 0.6443    D(G(z)): 0.0532 / 0.1051
[1/5][1200/1583]        Loss_D: 0.5647  Loss_G: 3.5713  D(x): 0.8026    D(G(z)): 0.2450 / 0.0428
[1/5][1250/1583]        Loss_D: 0.4568  Loss_G: 3.6666  D(x): 0.8934    D(G(z)): 0.2581 / 0.0403
[1/5][1300/1583]        Loss_D: 0.7197  Loss_G: 1.8175  D(x): 0.6211    D(G(z)): 0.1035 / 0.2184
[1/5][1350/1583]        Loss_D: 0.5255  Loss_G: 3.2736  D(x): 0.8141    D(G(z)): 0.2233 / 0.0574
[1/5][1400/1583]        Loss_D: 0.8241  Loss_G: 3.0776  D(x): 0.7807    D(G(z)): 0.3659 / 0.0743
[1/5][1450/1583]        Loss_D: 0.4302  Loss_G: 3.3777  D(x): 0.9058    D(G(z)): 0.2518 / 0.0519
[1/5][1500/1583]        Loss_D: 0.4173  Loss_G: 2.5610  D(x): 0.7916    D(G(z)): 0.1358 / 0.1058
[1/5][1550/1583]        Loss_D: 0.7993  Loss_G: 5.1228  D(x): 0.8527    D(G(z)): 0.4162 / 0.0104
[2/5][0/1583]   Loss_D: 0.4844  Loss_G: 2.2263  D(x): 0.7645    D(G(z)): 0.1510 / 0.1426
[2/5][50/1583]  Loss_D: 0.6756  Loss_G: 2.4608  D(x): 0.5915    D(G(z)): 0.0657 / 0.1248
[2/5][100/1583] Loss_D: 0.4391  Loss_G: 3.0181  D(x): 0.7901    D(G(z)): 0.1486 / 0.0744
[2/5][150/1583] Loss_D: 0.5683  Loss_G: 1.8918  D(x): 0.7083    D(G(z)): 0.1411 / 0.1858
[2/5][200/1583] Loss_D: 0.5932  Loss_G: 3.3342  D(x): 0.9111    D(G(z)): 0.3576 / 0.0522
[2/5][250/1583] Loss_D: 0.7331  Loss_G: 2.3817  D(x): 0.6635    D(G(z)): 0.1665 / 0.1397
[2/5][300/1583] Loss_D: 0.5493  Loss_G: 2.3824  D(x): 0.7491    D(G(z)): 0.1742 / 0.1196
[2/5][350/1583] Loss_D: 0.6197  Loss_G: 1.8560  D(x): 0.6443    D(G(z)): 0.1018 / 0.1972
[2/5][400/1583] Loss_D: 0.6172  Loss_G: 3.0777  D(x): 0.8482    D(G(z)): 0.3251 / 0.0621
[2/5][450/1583] Loss_D: 0.5047  Loss_G: 3.2941  D(x): 0.9174    D(G(z)): 0.3116 / 0.0566
[2/5][500/1583] Loss_D: 0.7335  Loss_G: 1.2796  D(x): 0.5676    D(G(z)): 0.0575 / 0.3470
[2/5][550/1583] Loss_D: 0.7716  Loss_G: 1.9450  D(x): 0.5513    D(G(z)): 0.0580 / 0.1922
[2/5][600/1583] Loss_D: 0.4425  Loss_G: 2.0531  D(x): 0.8015    D(G(z)): 0.1640 / 0.1686
[2/5][650/1583] Loss_D: 1.0964  Loss_G: 4.4602  D(x): 0.9096    D(G(z)): 0.5833 / 0.0163
[2/5][700/1583] Loss_D: 0.4745  Loss_G: 2.8636  D(x): 0.8492    D(G(z)): 0.2403 / 0.0770
[2/5][750/1583] Loss_D: 0.4947  Loss_G: 3.6931  D(x): 0.8803    D(G(z)): 0.2732 / 0.0364
[2/5][800/1583] Loss_D: 0.9355  Loss_G: 4.3906  D(x): 0.9120    D(G(z)): 0.5168 / 0.0195
[2/5][850/1583] Loss_D: 0.9213  Loss_G: 1.6006  D(x): 0.4645    D(G(z)): 0.0339 / 0.2467
[2/5][900/1583] Loss_D: 0.5337  Loss_G: 3.7601  D(x): 0.9101    D(G(z)): 0.3310 / 0.0314
[2/5][950/1583] Loss_D: 1.2562  Loss_G: 4.9530  D(x): 0.9432    D(G(z)): 0.6244 / 0.0144
[2/5][1000/1583]        Loss_D: 0.4187  Loss_G: 2.4701  D(x): 0.8454    D(G(z)): 0.1945 / 0.1129
[2/5][1050/1583]        Loss_D: 0.5796  Loss_G: 2.3732  D(x): 0.7714    D(G(z)): 0.2253 / 0.1216
[2/5][1100/1583]        Loss_D: 0.6325  Loss_G: 2.5824  D(x): 0.8307    D(G(z)): 0.3235 / 0.0939
[2/5][1150/1583]        Loss_D: 0.7639  Loss_G: 3.9487  D(x): 0.9031    D(G(z)): 0.4398 / 0.0291
[2/5][1200/1583]        Loss_D: 0.7040  Loss_G: 3.3561  D(x): 0.8073    D(G(z)): 0.3403 / 0.0500
[2/5][1250/1583]        Loss_D: 1.0567  Loss_G: 4.7122  D(x): 0.9292    D(G(z)): 0.5656 / 0.0155
[2/5][1300/1583]        Loss_D: 0.5431  Loss_G: 2.4260  D(x): 0.7628    D(G(z)): 0.2028 / 0.1116
[2/5][1350/1583]        Loss_D: 0.7633  Loss_G: 4.1670  D(x): 0.9257    D(G(z)): 0.4404 / 0.0237
[2/5][1400/1583]        Loss_D: 2.1958  Loss_G: 0.5288  D(x): 0.1539    D(G(z)): 0.0147 / 0.6404
[2/5][1450/1583]        Loss_D: 0.6991  Loss_G: 1.8573  D(x): 0.5818    D(G(z)): 0.0621 / 0.1980
[2/5][1500/1583]        Loss_D: 0.8286  Loss_G: 3.6899  D(x): 0.8805    D(G(z)): 0.4440 / 0.0364
[2/5][1550/1583]        Loss_D: 0.5100  Loss_G: 2.5931  D(x): 0.7721    D(G(z)): 0.1862 / 0.0989
[3/5][0/1583]   Loss_D: 0.7136  Loss_G: 2.6315  D(x): 0.8178    D(G(z)): 0.3462 / 0.1034
[3/5][50/1583]  Loss_D: 0.6472  Loss_G: 2.6359  D(x): 0.7572    D(G(z)): 0.2460 / 0.0962
[3/5][100/1583] Loss_D: 0.5211  Loss_G: 1.7793  D(x): 0.7275    D(G(z)): 0.1402 / 0.2050
[3/5][150/1583] Loss_D: 0.9620  Loss_G: 4.0717  D(x): 0.9423    D(G(z)): 0.5500 / 0.0243
[3/5][200/1583] Loss_D: 0.5469  Loss_G: 2.1994  D(x): 0.7581    D(G(z)): 0.1972 / 0.1359
[3/5][250/1583] Loss_D: 0.3941  Loss_G: 2.7071  D(x): 0.7281    D(G(z)): 0.0401 / 0.0902
[3/5][300/1583] Loss_D: 0.6482  Loss_G: 1.4858  D(x): 0.6275    D(G(z)): 0.1085 / 0.2802
[3/5][350/1583] Loss_D: 1.2781  Loss_G: 4.7393  D(x): 0.9594    D(G(z)): 0.6587 / 0.0120
[3/5][400/1583] Loss_D: 0.5942  Loss_G: 2.8406  D(x): 0.7861    D(G(z)): 0.2579 / 0.0784
[3/5][450/1583] Loss_D: 0.5395  Loss_G: 1.9849  D(x): 0.6755    D(G(z)): 0.0854 / 0.1764
[3/5][500/1583] Loss_D: 0.7941  Loss_G: 2.5871  D(x): 0.7891    D(G(z)): 0.3784 / 0.1006
[3/5][550/1583] Loss_D: 0.6556  Loss_G: 3.9228  D(x): 0.9328    D(G(z)): 0.4053 / 0.0254
[3/5][600/1583] Loss_D: 0.6489  Loss_G: 3.2773  D(x): 0.8385    D(G(z)): 0.3419 / 0.0490
[3/5][650/1583] Loss_D: 0.9217  Loss_G: 1.3858  D(x): 0.4992    D(G(z)): 0.0854 / 0.3095
[3/5][700/1583] Loss_D: 0.4947  Loss_G: 2.2791  D(x): 0.7948    D(G(z)): 0.2035 / 0.1332
[3/5][750/1583] Loss_D: 0.9676  Loss_G: 1.6087  D(x): 0.4641    D(G(z)): 0.0363 / 0.2599
[3/5][800/1583] Loss_D: 0.5918  Loss_G: 1.8852  D(x): 0.7019    D(G(z)): 0.1637 / 0.1948
[3/5][850/1583] Loss_D: 0.7856  Loss_G: 3.4243  D(x): 0.8672    D(G(z)): 0.4219 / 0.0512
[3/5][900/1583] Loss_D: 0.5023  Loss_G: 2.7348  D(x): 0.8372    D(G(z)): 0.2416 / 0.0851
[3/5][950/1583] Loss_D: 0.9028  Loss_G: 1.8348  D(x): 0.5362    D(G(z)): 0.1219 / 0.2110
[3/5][1000/1583]        Loss_D: 0.8118  Loss_G: 3.9327  D(x): 0.9092    D(G(z)): 0.4586 / 0.0306
[3/5][1050/1583]        Loss_D: 0.8709  Loss_G: 3.1103  D(x): 0.8752    D(G(z)): 0.4686 / 0.0639
[3/5][1100/1583]        Loss_D: 0.4286  Loss_G: 2.9141  D(x): 0.8379    D(G(z)): 0.1912 / 0.0741
[3/5][1150/1583]        Loss_D: 0.6005  Loss_G: 1.8091  D(x): 0.7044    D(G(z)): 0.1727 / 0.2042
[3/5][1200/1583]        Loss_D: 0.7432  Loss_G: 3.8108  D(x): 0.9088    D(G(z)): 0.4344 / 0.0297
[3/5][1250/1583]        Loss_D: 0.6872  Loss_G: 1.8717  D(x): 0.7355    D(G(z)): 0.2731 / 0.1789
[3/5][1300/1583]        Loss_D: 0.5740  Loss_G: 3.4426  D(x): 0.8874    D(G(z)): 0.3380 / 0.0422
[3/5][1350/1583]        Loss_D: 0.5689  Loss_G: 2.0738  D(x): 0.6823    D(G(z)): 0.0966 / 0.1621
[3/5][1400/1583]        Loss_D: 0.5023  Loss_G: 3.1107  D(x): 0.9225    D(G(z)): 0.3231 / 0.0565
[3/5][1450/1583]        Loss_D: 0.7466  Loss_G: 3.1208  D(x): 0.8441    D(G(z)): 0.3891 / 0.0634
[3/5][1500/1583]        Loss_D: 0.7135  Loss_G: 2.8145  D(x): 0.8924    D(G(z)): 0.4117 / 0.0765
[3/5][1550/1583]        Loss_D: 0.7881  Loss_G: 4.0945  D(x): 0.9332    D(G(z)): 0.4717 / 0.0258
[4/5][0/1583]   Loss_D: 0.6309  Loss_G: 2.2672  D(x): 0.7764    D(G(z)): 0.2761 / 0.1311
[4/5][50/1583]  Loss_D: 0.8068  Loss_G: 1.4844  D(x): 0.5595    D(G(z)): 0.1015 / 0.2795
[4/5][100/1583] Loss_D: 0.4912  Loss_G: 2.0030  D(x): 0.7526    D(G(z)): 0.1516 / 0.1674
[4/5][150/1583] Loss_D: 3.0392  Loss_G: 0.6172  D(x): 0.0896    D(G(z)): 0.0134 / 0.6503
[4/5][200/1583] Loss_D: 0.6768  Loss_G: 2.5170  D(x): 0.7543    D(G(z)): 0.2852 / 0.0986
[4/5][250/1583] Loss_D: 1.2451  Loss_G: 0.9252  D(x): 0.3817    D(G(z)): 0.0554 / 0.4569
[4/5][300/1583] Loss_D: 0.5916  Loss_G: 1.7704  D(x): 0.6588    D(G(z)): 0.1113 / 0.2144
[4/5][350/1583] Loss_D: 1.3058  Loss_G: 0.6935  D(x): 0.3416    D(G(z)): 0.0394 / 0.5486
[4/5][400/1583] Loss_D: 0.6206  Loss_G: 3.0787  D(x): 0.8405    D(G(z)): 0.3261 / 0.0609
[4/5][450/1583] Loss_D: 0.5866  Loss_G: 1.4752  D(x): 0.6981    D(G(z)): 0.1565 / 0.2718
[4/5][500/1583] Loss_D: 0.5616  Loss_G: 3.0459  D(x): 0.8869    D(G(z)): 0.3223 / 0.0650
[4/5][550/1583] Loss_D: 0.6073  Loss_G: 3.2580  D(x): 0.7503    D(G(z)): 0.2344 / 0.0500
[4/5][600/1583] Loss_D: 0.6905  Loss_G: 3.0939  D(x): 0.8591    D(G(z)): 0.3762 / 0.0589
[4/5][650/1583] Loss_D: 0.5836  Loss_G: 1.7048  D(x): 0.6781    D(G(z)): 0.1227 / 0.2282
[4/5][700/1583] Loss_D: 0.8543  Loss_G: 3.7586  D(x): 0.8876    D(G(z)): 0.4712 / 0.0337
[4/5][750/1583] Loss_D: 0.8484  Loss_G: 2.3787  D(x): 0.6606    D(G(z)): 0.2724 / 0.1192
[4/5][800/1583] Loss_D: 0.5562  Loss_G: 2.1677  D(x): 0.7446    D(G(z)): 0.1887 / 0.1533
[4/5][850/1583] Loss_D: 0.7600  Loss_G: 1.4960  D(x): 0.5447    D(G(z)): 0.0559 / 0.2722
[4/5][900/1583] Loss_D: 0.5677  Loss_G: 3.0179  D(x): 0.8308    D(G(z)): 0.2804 / 0.0664
[4/5][950/1583] Loss_D: 0.5381  Loss_G: 2.9582  D(x): 0.7989    D(G(z)): 0.2345 / 0.0711
[4/5][1000/1583]        Loss_D: 0.8333  Loss_G: 2.8499  D(x): 0.7720    D(G(z)): 0.3700 / 0.0786
[4/5][1050/1583]        Loss_D: 0.5125  Loss_G: 1.8930  D(x): 0.7287    D(G(z)): 0.1387 / 0.1848
[4/5][1100/1583]        Loss_D: 0.4527  Loss_G: 3.0039  D(x): 0.8639    D(G(z)): 0.2413 / 0.0614
[4/5][1150/1583]        Loss_D: 0.7072  Loss_G: 0.8361  D(x): 0.5589    D(G(z)): 0.0563 / 0.4846
[4/5][1200/1583]        Loss_D: 0.8619  Loss_G: 4.9323  D(x): 0.9385    D(G(z)): 0.4880 / 0.0112
[4/5][1250/1583]        Loss_D: 0.6864  Loss_G: 2.4925  D(x): 0.7232    D(G(z)): 0.2431 / 0.1152
[4/5][1300/1583]        Loss_D: 0.5835  Loss_G: 3.1599  D(x): 0.8430    D(G(z)): 0.3018 / 0.0644
[4/5][1350/1583]        Loss_D: 0.9119  Loss_G: 4.7225  D(x): 0.9409    D(G(z)): 0.5082 / 0.0154
[4/5][1400/1583]        Loss_D: 0.3856  Loss_G: 3.1007  D(x): 0.8980    D(G(z)): 0.2238 / 0.0584
[4/5][1450/1583]        Loss_D: 1.3314  Loss_G: 5.1061  D(x): 0.9395    D(G(z)): 0.6621 / 0.0094
[4/5][1500/1583]        Loss_D: 0.5882  Loss_G: 1.7242  D(x): 0.6443    D(G(z)): 0.0785 / 0.2306
[4/5][1550/1583]        Loss_D: 0.5792  Loss_G: 2.0347  D(x): 0.7582    D(G(z)): 0.2143 / 0.1594

6. Results

Finally, let's check out how we did. We will look at three different results:

1. How $D$'s and $G$'s losses changed during training.
2. An animation of $G$'s output on the fixed_noise batch at the checkpoints saved during training.
3. A batch of real data side by side with a batch of fake data from $G$.

6.1. Loss versus Training Iteration

Below is a plot of $D$'s and $G$'s losses versus training iterations.

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

[Figure: Generator and Discriminator Loss During Training]

6.2. Visualizing G's Progress

Recall that we saved the generator's output on the fixed_noise batch every 500 iterations during training. We can now visualize $G$'s training progress with an animation.

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())


[Animation: G's output on fixed_noise over the course of training]

6.3. Real Images vs. Fake Images


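# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last saved checkpoint
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()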

[Figure: Real Images vs. Fake Images]

7. Where to Go Next


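  • Train for longer and see how good the results get
  • Modify this model to use a different dataset, possibly changing the image size and the model architecture
  • Check out some other cool GAN projects
  • Create GANs that generate music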

