基本信息

项目 内容
论文标题 Neural Discrete Representation Learning
作者 Aaron van den Oord, Oriol Vinyals, 和 Koray Kavukcuoglu
作者单位
发表会议/期刊 2017
论文链接
别名 Vector Quantized-Variational Autoencoder

方法概览

特点 文章性质
输入 单张 RGB 图像
输出 分类、分割
所属领域 视觉 Transformer

背景

  • 标准 VAE 的局限:
    • 潜在变量 z 是连续的(通常是高斯分布)。
    • 这导致生成的样本(尤其是图像)往往比较模糊
    • 连续潜在空间可能难以捕捉数据中固有的离散结构(如物体类别、音素、单词)。
  • VQ-VAE 的解决方案:
    • 放弃连续潜在变量:VQ-VAE 的编码器输出的不是分布参数,而是一个连续的潜在向量 z_e
    • 引入离散潜在空间:这个连续向量 z_e 会通过一个向量量化 (Vector Quantization) 过程,被映射到一个有限的、离散的码本 (Codebook) 中最接近的嵌入向量 (Embedding Vector) e 上。这个离散的 e 就是真正的潜在表示。
    • 解码器使用离散的 e:解码器接收这个离散的 e 来重构原始数据 x

创新点

  1. 解决了标准 VAE 潜在空间连续性带来的问题,并允许学习离散的、高维的潜在表示,特别适用于生成高质量图像、音频和视频等复杂数据。
  2. 一个 VQ-VAE 模型包含三个核心组件:
    1. 编码器 (Encoder - Enc):
      • 输入:原始数据 x (如图像)。
      • 输出:一个连续的潜在表示 z_e = Enc(x)。注意,这里没有输出方差或进行重参数化。
    2. 码本 (Codebook - E):
      • 一个可学习的嵌入矩阵 E ∈ R^(K×D),其中 K 是码本大小(嵌入向量的数量),D 是每个嵌入向量的维度。
      • 码本中的每一行 e_k (k=1,2,...,K) 是一个 D 维的嵌入向量。
    3. 解码器 (Decoder - Dec):
      • 输入:量化的潜在向量 z_q
      • 输出:重构的数据 x' = Dec(z_q)

怎么使用VAE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
super(VectorQuantizer, self).__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost

# 初始化码本
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

def forward(self, z_e):
"""
z_e: shape [B, D, H, W] (假设是特征图)
"""
# 1. 将 z_e 展平以便计算距离
z_e_flat = z_e.permute(0, 2, 3, 1).contiguous() # [B, H, W, D]
z_e_flat = z_e_flat.view(-1, self.embedding_dim) # [B*H*W, D]

# 2. 计算 z_e_flat 与码本中所有嵌入的距离
distances = (torch.sum(z_e_flat**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))
# distances.shape = [B*H*W, K]

# 3. 找到最近的嵌入索引
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) # [B*H*W, 1]

# 4. 获取量化的向量 z_q
z_q = self.embedding(encoding_indices).squeeze(1) # [B*H*W, D]
z_q = z_q.view_as(z_e_flat) # [B*H*W, D]
z_q = z_q.permute(0, 2, 1).contiguous() # [B, D, H, W] - 恢复形状
z_q = z_q.view_as(z_e) # 确保形状完全匹配

# 5. 计算损失 (在 forward 中计算,便于返回)
# 码本损失: ||sg[z_e] - E||^2
loss_codebook = F.mse_loss(z_q.detach(), z_e)
# 承诺损失: γ ||z_e - sg[E]||^2
loss_commitment = self.commitment_cost * F.mse_loss(z_e, z_q.detach())

# 6. 直通估计器 (Straight-Through Estimator)
# 让梯度从 z_q "直通" 到 z_e
z_q = z_e + (z_q - z_e).detach()

return z_q, loss_codebook, loss_commitment, encoding_indices

class VQVAE(nn.Module):
def __init__(self, in_channels=3, hidden_dim=128, num_embeddings=512, embedding_dim=64, commitment_cost=0.25):
super(VQVAE, self).__init__()
# 简化的卷积编码器/解码器
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(hidden_dim, hidden_dim, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(hidden_dim, embedding_dim, 1) # 输出 embedding_dim 维的 z_e
)
self.vector_quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
self.decoder = nn.Sequential(
nn.Conv2d(embedding_dim, hidden_dim, 1),
nn.ReLU(),
nn.ConvTranspose2d(hidden_dim, hidden_dim, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(hidden_dim, in_channels, 4, 2, 1),
nn.Sigmoid() # 假设输出在 [0,1]
)

def forward(self, x):
z_e = self.encoder(x)
z_q, loss_codebook, loss_commitment, indices = self.vector_quantizer(z_e)
x_recon = self.decoder(z_q)
return x_recon, loss_codebook, loss_commitment, indices

# 实例化模型
model = VQVAE(in_channels=3, hidden_dim=128, num_embeddings=512, embedding_dim=64)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

model.train()
for epoch in range(100):
for batch_idx, (data, _) in enumerate(train_loader):
optimizer.zero_grad()

recon_data, loss_codebook, loss_commitment, _ = model(data)

# 总损失
loss_recon = F.mse_loss(recon_data, data)
loss = loss_recon + loss_codebook + loss_commitment

loss.backward()
optimizer.step()

if batch_idx % 100 == 0:
print(f'Loss: {loss.item():.4f} (Recon: {loss_recon.item():.4f}, '
f'Codebook: {loss_codebook.item():.4f}, Commitment: {loss_commitment.item():.4f})')

使用训练好的 VQ-VAE

  • 生成新样本:
    1. 从码本 E 中随机选择索引 k,得到 z_q = e_k
    2. 或者,训练一个自回归模型(如PixelCNN、Transformer)在离散的潜在代码序列上。这个自回归模型学习 p(z_q) 或 p(indices)。生成时,先用自回归模型采样一个 indices 序列,得到 z_q,再用 VQ-VAE 的解码器生成最终图像。这是 VQ-VAE 生成高质量样本的标准方式。
  • 重构数据: 直接将输入数据通过 model(data) 得到 recon_data
  • 获取离散表示: indices 就是输入数据的离散潜在代码,可用于聚类、检索等任务。

VAE

这里有一个非常好的解释为什么要做变分的自编码器:【VAE学习笔记】全面通透地理解VAE(Variational Auto Encoder)_vae架构-CSDN博客