PyTorch VAE实现

介绍

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,可用于从高维数据中学习潜在表示,并生成具有相似特征的新样本。在本文中,我们将使用PyTorch实现一个简单的VAE模型。

变分自编码器

VAE是一种概率生成模型,由一个编码器和一个解码器组成。编码器将输入数据映射到潜在空间中的概率分布,而解码器则从潜在空间中的样本生成输出数据。

VAE的核心思想是使用变分推断来近似学习潜在空间的分布。它引入了一个随机变量z,也称为潜在变量,来表示输入数据的潜在特征。VAE假设输入数据x是由潜在变量z生成的,即p(x|z)。然后,VAE使用一个编码器q(z|x)来近似学习潜在变量的后验分布p(z|x)。最后,VAE使用一个解码器p(x|z)来生成输出数据。

PyTorch实现

导入依赖

首先,我们需要导入PyTorch和其他必要的依赖库。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

定义VAE模型

接下来,我们将定义一个VAE模型。VAE模型由一个编码器和一个解码器组成。编码器将输入数据映射到潜在空间中的概率分布,而解码器从潜在空间中的样本生成输出数据。

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, latent_size * 2)
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def forward(self, x):
        x = self.encoder(x)
        mu, log_var = torch.chunk(x, 2, dim=1)
        z = self.reparameterize(mu, log_var)
        x = self.decoder(z)
        return x, mu, log_var

定义损失函数和优化器

我们将使用负对数似然损失函数来训练VAE模型,并使用Adam优化器进行参数优化。

def loss_function(x_hat, x, mu, log_var):
    BCE = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

model = VAE(input_size=784, hidden_size=256, latent_size=20)
optimizer = optim.Adam(model.parameters(), lr=0.001)

数据加载和训练

在训练VAE模型之前,我们需要加载数据集(如MNIST)并定义训练循环。

# 加载数据集
train_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='.', train=True, transform=torchvision.transforms.ToTensor(), download=True),
    batch_size=128, shuffle=True
)

# 训练循环
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, 784)
        data = Variable(data)

        optimizer.zero_grad()
        recon_batch, mu, log_var = model(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: