import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z_e):
        """
        z_e: shape [B, D, H, W] (assumed to be an encoder feature map)
        """
        # Flatten spatial positions so each D-dim vector is quantized independently.
        z_e_perm = z_e.permute(0, 2, 3, 1).contiguous()      # [B, H, W, D]
        z_e_flat = z_e_perm.view(-1, self.embedding_dim)     # [B*H*W, D]

        # Squared Euclidean distance to every codebook entry:
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e
        distances = (torch.sum(z_e_flat ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(z_e_flat, self.embedding.weight.t()))

        # Nearest codebook entry for each spatial position.
        encoding_indices = torch.argmin(distances, dim=1)    # [B*H*W]
        z_q = self.embedding(encoding_indices)               # [B*H*W, D]

        # Restore the original [B, D, H, W] layout.
        z_q = z_q.view_as(z_e_perm).permute(0, 3, 1, 2).contiguous()

        # Codebook loss pulls the embeddings toward the (detached) encoder outputs;
        # commitment loss pulls the encoder outputs toward the (detached) embeddings.
        loss_codebook = F.mse_loss(z_q, z_e.detach())
        loss_commitment = self.commitment_cost * F.mse_loss(z_e, z_q.detach())

        # Straight-through estimator: gradients pass from z_q to z_e unchanged.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, loss_codebook, loss_commitment, encoding_indices
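As a quick sanity check, the quantizer can be run on its own: the straight-through output must match the input shape exactly, since the decoder consumes it in place of z_e. A minimal sketch; the batch and spatial sizes below are arbitrary assumptions, not part of the original code:

# Standalone shape check of the quantizer (sizes are illustrative only).
vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
z_e = torch.randn(4, 64, 8, 8)        # [B, D, H, W]
z_q, loss_cb, loss_commit, idx = vq(z_e)
assert z_q.shape == z_e.shape          # output keeps the input layout
assert idx.shape == (4 * 8 * 8,)       # one codebook index per spatial position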
class VQVAE(nn.Module):
    def __init__(self, in_channels=3, hidden_dim=128, num_embeddings=512,
                 embedding_dim=64, commitment_cost=0.25):
        super(VQVAE, self).__init__()
        # Encoder downsamples 4x (two stride-2 convs), then projects to embedding_dim.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, 1)
        )
        self.vector_quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        # Decoder mirrors the encoder and maps back to pixel space in [0, 1].
        self.decoder = nn.Sequential(
            nn.Conv2d(embedding_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, in_channels, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, loss_codebook, loss_commitment, indices = self.vector_quantizer(z_e)
        x_recon = self.decoder(z_q)
        return x_recon, loss_codebook, loss_commitment, indices
model = VQVAE(in_channels=3, hidden_dim=128, num_embeddings=512, embedding_dim=64)
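Since the quantizer returns the codebook and commitment terms separately, a training step just adds them to the reconstruction loss. A minimal sketch; the stand-in data, optimizer, and learning rate are assumptions for illustration, not from the original snippet:

# Minimal training-step sketch (data and hyperparameters are illustrative).
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
x = torch.rand(8, 3, 32, 32)  # stand-in batch; Sigmoid output expects targets in [0, 1]

x_recon, loss_codebook, loss_commitment, _ = model(x)
loss = F.mse_loss(x_recon, x) + loss_codebook + loss_commitment
optimizer.zero_grad()
loss.backward()
optimizer.step()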