PyTorch VAE实现
介绍
变分自编码器(Variational Autoencoder, VAE)是一种生成模型,可用于从高维数据中学习潜在表示,并生成具有相似特征的新样本。在本文中,我们将使用PyTorch实现一个简单的VAE模型。
变分自编码器
VAE是一种概率生成模型,由一个编码器和一个解码器组成。编码器将输入数据映射到潜在空间中的概率分布,而解码器则从潜在空间中的样本生成输出数据。
VAE的核心思想是使用变分推断来近似学习潜在空间的分布。它引入了一个随机变量z,也称为潜在变量,来表示输入数据的潜在特征。VAE假设输入数据x是由潜在变量z生成的,即p(x|z)。然后,VAE使用一个编码器q(z|x)来近似学习潜在变量的后验分布p(z|x)。最后,VAE使用一个解码器p(x|z)来生成输出数据。
PyTorch实现
导入依赖
首先,我们需要导入PyTorch和其他必要的依赖库。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
定义VAE模型
接下来,我们将定义一个VAE模型。VAE模型由一个编码器和一个解码器组成。编码器将输入数据映射到潜在空间中的概率分布,而解码器从潜在空间中的样本生成输出数据。
class VAE(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super(VAE, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, latent_size * 2)
)
self.decoder = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, input_size),
nn.Sigmoid()
)
def reparameterize(self, mu, log_var):
std = torch.exp(log_var / 2)
eps = torch.randn_like(std)
z = mu + eps * std
return z
def forward(self, x):
x = self.encoder(x)
mu, log_var = torch.chunk(x, 2, dim=1)
z = self.reparameterize(mu, log_var)
x = self.decoder(z)
return x, mu, log_var
定义损失函数和优化器
我们将使用负对数似然损失函数来训练VAE模型,并使用Adam优化器进行参数优化。
def loss_function(x_hat, x, mu, log_var):
BCE = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return BCE + KLD
model = VAE(input_size=784, hidden_size=256, latent_size=20)
optimizer = optim.Adam(model.parameters(), lr=0.001)
数据加载和训练
在训练VAE模型之前,我们需要加载数据集(如MNIST)并定义训练循环。
# 加载数据集
train_loader = torch.utils.data.DataLoader(
dataset=torchvision.datasets.MNIST(root='.', train=True, transform=torchvision.transforms.ToTensor(), download=True),
batch_size=128, shuffle=True
)
# 训练循环
def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.view(-1, 784)
data = Variable(data)
optimizer.zero_grad()
recon_batch, mu, log_var = model(data)
loss = loss_function(recon_batch, data, mu, log_var)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
print('====> Epoch: