Implementation using Gumbel Softmax

A variational autoencoder is an architecture composed of both an encoder and a decoder that is trained to minimize the reconstruction error between the encoded-decoded data and the initial data. In a VAE, the Gumbel-Softmax is commonly used to sample from categorical distributions that represent discrete latent variables.

  1. The input is encoded as a distribution over the latent space. In a standard VAE the encoded distributions are chosen to be normal, so the encoder is trained to return the mean and the covariance matrix that describe these Gaussians. In the categorical VAE built here, the encoder instead returns the logits of categorical distributions.
  2. A point from the latent space is sampled from that distribution. Here we use Gumbel-Softmax for sampling:
    1. Gumbel sampling: Gumbel noise is generated as g = −log(−log(u)), where u is sampled from a uniform distribution on (0, 1).
    2. Gumbel-Softmax: The Gumbel noise is added to the output of the encoder, and the softmax function is then applied to obtain a differentiable sample from the categorical distribution. The softmax has a temperature parameter (τ) that controls the smoothness of the approximation. Initially the temperature is kept at 1 so that varied latent vectors can be generated; it is then slowly annealed toward zero, where the Gumbel-Softmax distribution approaches a true categorical distribution. (The code in this article stops annealing at a minimum temperature of 0.5.)
  3. The sampled point is decoded and the reconstruction error is computed. The loss function consists of two terms:
    1. Generative loss: Compares the model output with the model input.
    2. Latent loss: Compares the latent distribution returned by the encoder with a prior, so that the encoded distributions stay close to that prior. In a standard VAE the prior is a zero-mean, unit-variance Gaussian; for the categorical latent variables used here, the loss function in this article uses a uniform categorical prior.
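
The effect of the temperature in step 2 can be illustrated with a small, dependency-free sketch. The logits below are hypothetical values for a single 3-way latent variable, and the helper names `gumbel_noise` and `softmax` are introduced only for this illustration. The same noise is reused at every temperature so that only τ changes.

```python
import math
import random

def gumbel_noise(rng):
    # g = -log(-log(u)), u ~ Uniform(0, 1); clamp u away from 0 to keep the logs finite
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))

def softmax(values, temperature):
    # Numerically stable softmax of values / temperature
    scaled = [v / temperature for v in values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [1.0, 0.5, 0.1]                           # hypothetical encoder outputs
noisy = [l + gumbel_noise(rng) for l in logits]    # logits perturbed by Gumbel noise

# Lower temperatures push the relaxed sample toward a one-hot vector
samples = {tau: softmax(noisy, tau) for tau in (5.0, 1.0, 0.1)}
for tau, y in samples.items():
    print(tau, [round(p, 3) for p in y])
```

For a fixed noise draw, the largest entry of the sample grows as τ shrinks, which is why annealing moves the relaxed samples toward discrete ones.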

Now, let us train the model using Gumbel-Softmax. The notebook is available at Notebook link.

1. Importing Necessary Libraries


import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torch.nn.functional as F
import numpy as np
import pandas as pd
import math


2. Loading the data

We will use MNIST handwritten digit data to train our VAE.

MNIST Dataset Loading:

  • mnist_dataset_train: Loads the MNIST training dataset. Here, transforms.ToTensor() converts the images to PyTorch tensors.
  • mnist_dataset_test: Similar to mnist_dataset_train, but for the testing dataset.

Batch Size and DataLoader:

  • batch_size = 128: Defines the batch size for training and testing.
  • train_loader and test_loader: These are instances of torch.utils.data.DataLoader, which is used to efficiently load and batch data. Each takes an MNIST dataset and uses the specified batch size. The shuffle=True argument ensures that the data is randomly shuffled during training, which can improve the learning process.


from torch.utils.data import DataLoader
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the MNIST training dataset
mnist_dataset_train = datasets.MNIST(
    root='./data', train=True, download=True, transform=transforms.ToTensor())
# Load the MNIST test dataset (train=False selects the held-out split)
mnist_dataset_test = datasets.MNIST(
    root='./data', train=False, download=True, transform=transforms.ToTensor())
batch_size = 128
train_loader = DataLoader(
    mnist_dataset_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
    mnist_dataset_test, batch_size=batch_size, shuffle=True)
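
To see what such a loader yields without downloading MNIST, here is a small sketch that uses random tensors as a stand-in dataset. The names `fake_images`, `fake_labels` and `fake_loader` and the dataset size of 300 are assumptions made purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 300 random "images" of shape [1, 28, 28] with integer labels
fake_images = torch.rand(300, 1, 28, 28)
fake_labels = torch.randint(0, 10, (300,))
fake_loader = DataLoader(
    TensorDataset(fake_images, fake_labels), batch_size=128, shuffle=True)

# Each iteration yields an (images, labels) pair; the last batch can be smaller
batch_sizes = [images.shape[0] for images, _ in fake_loader]
print(batch_sizes)
```

Because the final batch is usually smaller than batch_size, code that reshapes a batch is safer when it reads the size from the tensor rather than hard-coding 128.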


3. Define Gumbel Softmax

Here is a breakdown of the three functions:

sample_gumbel

  • This function generates samples from a Gumbel distribution.
  • U is drawn from a uniform distribution, and eps is a small constant that keeps the logarithms finite.
  • The function returns -torch.log(-torch.log(U + eps) + eps), which is a sample from the Gumbel distribution.

gumbel_softmax_sample

  • This function performs Gumbel-Softmax sampling.
  • logits: The logits representing the output of the encoder block.
  • temperature: A temperature parameter controlling the level of smoothing in the sampling process.
  • Adds Gumbel noise to the logits and returns the softmax of the result at the specified temperature.

gumbel_softmax

  • This function represents the Gumbel-Softmax distribution and allows for sampling from it.
  • temperature: A temperature parameter controlling the level of smoothing in the sampling process.
  • If hard is False, returns the soft sample as a flattened tensor.
  • If hard is True, returns a one-hot vector by finding the index of the maximum value in each row of the soft sample. Gradients are still taken with respect to the soft sample.


def sample_gumbel(shape, eps=1e-20):
    # U ~ Uniform[0, 1); eps keeps both logarithms finite
    U = torch.rand(shape).to(device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    # Perturb the logits with Gumbel noise, then apply a tempered softmax
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    """
    input: [*, n_class]
    return: flattened [*, n_class] tensor; one-hot per variable if hard=True
    """
    y = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y.view(-1, latent_dim * categorical_dim)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through: forward pass uses y_hard, gradients flow through y
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, latent_dim * categorical_dim)
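
Why adding Gumbel noise gives valid categorical samples can be checked empirically: taking the argmax of logits plus Gumbel noise should pick class k with probability softmax(logits)[k]. The following is a minimal sketch of that check; the logits and the sample count are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.0])   # arbitrary example logits
n_samples = 20000

# Standard Gumbel noise via g = -log(-log(u)); clamp keeps the logs finite
u = torch.rand(n_samples, logits.numel()).clamp_min(1e-20)
gumbel = -torch.log(-torch.log(u))

# Gumbel-max trick: the index of the largest perturbed logit is a categorical sample
winners = (logits + gumbel).argmax(dim=-1)
empirical = torch.bincount(winners, minlength=logits.numel()).float() / n_samples
expected = F.softmax(logits, dim=-1)

print('empirical:', empirical)
print('expected: ', expected)
```

The Gumbel-Softmax functions above replace this hard argmax with a tempered softmax so that the sample stays differentiable.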


Check the Gumbel Softmax performance


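latent_dim = 2
categorical_dim = 5  # assumed value, chosen so that latent_dim * categorical_dim = 10
temp = 1.0           # assumed temperature for this check
logits = torch.rand(1, latent_dim * categorical_dim)
print('Input:', logits)
# With hard=False
print('\nGumbel Softmax with hard=False\n', gumbel_softmax(logits, temp, False))
# With hard=True
print('\nGumbel Softmax with hard=True\n', gumbel_softmax(logits, temp, True))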



Input: tensor([[0.2976, 0.1200, 0.3637, 0.3025, 0.3605, 0.7416, 0.5763, 0.8461, 0.9991,

Gumbel Softmax with hard=False
tensor([[0.0973, 0.0399, 0.0676, 0.0300, 0.0443, 0.0617, 0.2640, 0.1006, 0.2470,

Gumbel Softmax with hard=True
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
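
PyTorch also ships a built-in torch.nn.functional.gumbel_softmax with tau and hard arguments. As a hedged cross-check rather than a replacement for the functions above, the sketch below confirms that with hard=True the forward output is one-hot while gradients still reach the logits. The tensor shapes and the `weights` vector are arbitrary choices for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)

# Forward pass: one-hot rows; backward pass: gradients of the soft sample
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)

weights = torch.arange(10.0)          # arbitrary downstream weights
loss = (y_hard * weights).sum()
loss.backward()

print(y_hard)
print(logits.grad.abs().sum())
```

A plain argmax would leave logits.grad empty; the straight-through construction is what lets the encoder receive a training signal.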

4. Define the VAE

The encode method transforms input data into a latent space, the decode method reconstructs data from the latent space, and the forward method combines these steps, also returning the Gumbel-Softmax samples and softmax probabilities.

  1. Encoder
    • Applies three linear layers, each followed by a ReLU activation. It reduces the size from 784 to latent_dim*categorical_dim.
  2. Sampling
    • z = gumbel_softmax(q_y, temp, hard): Samples from the Gumbel-Softmax distribution. This is the latent variable, which is then passed to the decoder.
  3. Decoder
    • Reconstructs the image using three linear layers that mirror the encoder architecture, ending in a sigmoid so that each pixel lies in [0, 1].


class VAE_gumbel(nn.Module):
    def __init__(self, temp):
        super(VAE_gumbel, self).__init__()
        # Encoder: 784 -> 512 -> 256 -> latent_dim * categorical_dim
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
        # Decoder mirrors the encoder
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    # temp defaults to 1.0; a default of 0 would divide by zero in the softmax
    def forward(self, x, temp=1.0, hard=False):
        q = self.encode(x.view(-1, 784))
        # One row of logits per categorical latent variable
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        z = gumbel_softmax(q_y, temp, hard)
        # Return the reconstruction and the flattened class probabilities
        return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())
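
The reshaping inside forward is easy to misread, so here is a small shape walk-through. It uses a random tensor as a stand-in for the encoder output and the sizes latent_dim = 20 and categorical_dim = 10 that this article sets in its loss-function step; the batch size of 3 is arbitrary.

```python
import torch
import torch.nn.functional as F

batch, latent_dim, categorical_dim = 3, 20, 10

q = torch.randn(batch, latent_dim * categorical_dim)   # stand-in for encoder output
q_y = q.view(batch, latent_dim, categorical_dim)       # one row of logits per latent variable
probs = F.softmax(q_y, dim=-1)                         # normalise within each variable
z = probs.reshape(batch, latent_dim * categorical_dim) # flattened again for the decoder

print(q_y.shape, probs.sum(dim=-1).shape, z.shape)
```

The softmax runs over the last dimension only, so each of the 20 latent variables gets its own distribution over 10 classes.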


5. Define the loss function

This is the standard VAE loss. The loss function consists of two components: the BCE (binary cross-entropy) loss, which measures the difference between the reconstructed data and the original data, and the KL divergence loss, which penalizes the divergence between the distribution of the latent variables and a chosen prior distribution, here a uniform categorical distribution over the categorical_dim classes. The final loss is the sum of these two components.


latent_dim = 20
categorical_dim = 10  # one-of-K vector
temp = 1
temp_min = 0.5
ANNEAL_RATE = 0.00003
model = VAE_gumbel(temp).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Reconstruction + KL divergence losses, summed over all elements and averaged over the batch
def loss_function(recon_x, x, qy):
    # reduction='sum' replaces the deprecated size_average=False
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]
    # KL(q || uniform prior) = sum q * log(q * K), with K = categorical_dim
    log_ratio = torch.log(qy * categorical_dim + 1e-20)
    KLD = torch.sum(qy * log_ratio, dim=-1).mean()
    return BCE + KLD
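
The KL term above is the divergence between each categorical distribution q and a uniform prior over K classes, KL(q || Uniform(K)) = Σ q_k log(q_k · K). A dependency-free sketch of its two extremes follows; the helper name `kl_to_uniform` is introduced here only for illustration.

```python
import math

def kl_to_uniform(q, eps=1e-20):
    """KL(q || Uniform(K)) = sum_k q_k * log(q_k * K) for one categorical variable."""
    k = len(q)
    return sum(p * math.log(p * k + eps) for p in q)

K = 10
uniform_q = [1.0 / K] * K              # matches the prior exactly
one_hot_q = [1.0] + [0.0] * (K - 1)    # fully confident in one class

kl_uniform = kl_to_uniform(uniform_q)
kl_one_hot = kl_to_uniform(one_hot_q)
print(kl_uniform, kl_one_hot, math.log(K))
```

The penalty is zero when the encoder outputs the prior and reaches log K per latent variable when it is fully confident, so the term discourages overconfident codes.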


6. Define the train function

This training function iterates over batches of training data, performs forward and backward passes, updates the model parameters, and prints training progress. It also includes temperature annealing to control the exploration of the Gumbel-Softmax distribution during training.

  • model.train(): Sets the model to training mode.
  • Moves the input data to the specified device and zeroes out the gradients left over from the previous step.
  • Runs a forward pass through the model to obtain the reconstructed batch and the softmax probabilities q_y.
  • Computes the loss as the sum of the reconstruction loss and the KL divergence loss.
  • Runs a backward pass to compute the gradients and updates the parameters with the optimizer.
  • Performs temperature annealing once every 100 batches.
  • Prints training progress every 100 batches.


def train(epoch, model, train_loader, optimizer, temp, cuda=True, hard=False):
  model.train()
  train_loss = 0
  for batch_idx, (data, _) in enumerate(train_loader):
    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, q_y = model(data, temp, hard)
    loss = loss_function(recon_batch, data, q_y)
    loss.backward()
    optimizer.step()
    train_loss += loss.item() * len(data)
    # Anneal the temperature once every 100 batches, never below temp_min
    if batch_idx % 100 == 1:
        temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)
    if batch_idx % 100 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))
  print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))
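
To get a feel for how far this schedule moves the temperature, the sketch below replays the annealing rule on its own for one epoch. It assumes 60000 training images at batch_size = 128 (469 batches) and reuses the constants defined earlier.

```python
import math

ANNEAL_RATE = 0.00003
temp_min = 0.5
batches_per_epoch = math.ceil(60000 / 128)   # 469 batches

temp = 1.0
for batch_idx in range(batches_per_epoch):
    if batch_idx % 100 == 1:                 # fires at batches 1, 101, 201, 301, 401
        temp = max(temp * math.exp(-ANNEAL_RATE * batch_idx), temp_min)

print(round(temp, 4))
```

With these constants the temperature falls only to roughly 0.97 within an epoch. Note also that temp is a local variable inside train and is not returned, so as written each epoch starts again from the value passed in; returning it and feeding it back in would be one way to carry the schedule across epochs.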


7. Define the test function

This code evaluates a given model on a test dataset, computing and printing the average test loss. Additionally, it saves a visual comparison of original and reconstructed images for the first batch.

  • model.eval(): Sets the model to evaluation mode. This changes the behaviour of layers such as dropout and batch normalization, which act differently during training and evaluation. It does not by itself disable gradient calculation; that is what torch.no_grad() is for.
  • test_loss = 0: Initializes a variable to accumulate the test loss.
  • The function then iterates over the batches in the test_loader. For each batch:
    • Passes the input data through the model to obtain the reconstructed batch (recon_batch) and the softmax probabilities qy.
    • Computes the loss between the reconstructed batch and the original data using loss_function. The loss is then added to the test_loss variable, scaled by the batch size.
    • Additionally, the temperature (temp) is annealed: if i % 100 == 1, the temperature is updated with the same schedule used during training.
  • Finally, if i == 0, that is, for the first batch, we visualize the reconstruction by concatenating the original and reconstructed images and saving them using the save_image function.


def test(epoch, model, test_loader, temp, cuda=True, hard=False):
  model.eval()
  test_loss = 0
  with torch.no_grad():  # no gradients are needed for evaluation
    for i, (data, _) in enumerate(test_loader):
      data = data.to(device)
      recon_batch, qy = model(data, temp, hard)
      test_loss += loss_function(recon_batch, data, qy).item() * len(data)
      if i % 100 == 1:
          temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)
      if i == 0:
          # Stack the first n originals on top of their reconstructions
          n = min(data.size(0), 8)
          comparison = torch.cat([data[:n], recon_batch.view(-1, 1, 28, 28)[:n]])
          save_image(comparison.cpu(), f"./reconstruction_{epoch:03d}.png", nrow=n)
  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))
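
The difference between evaluation mode and disabling gradients is easy to confuse, so here is a minimal sketch on a toy layer. The small `layer` module is hypothetical and unrelated to the VAE above.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

layer.eval()                  # dropout becomes the identity; autograd is still on
out_eval_a = layer(x)
out_eval_b = layer(x)

with torch.no_grad():         # no computation graph is recorded inside this block
    out_no_grad = layer(x)

print(torch.equal(out_eval_a, out_eval_b))
print(out_eval_a.requires_grad, out_no_grad.requires_grad)
```

In evaluation mode the two outputs agree because dropout is inactive, yet they still track gradients; only the output computed under torch.no_grad() is detached from autograd.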


8. Training

We run the train and test functions for 10 epochs.


epochs = 10
# The constructor expects the temperature, and the model must live on the same device as the data
model = VAE_gumbel(temp).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(1, epochs + 1):
    train(epoch, model, train_loader, optimizer, temp, True)
    test(epoch, model, test_loader, temp, True)



Train Epoch: 1 [0/60000 (0%)]    Loss: 543.420044
Train Epoch: 1 [12800/60000 (21%)] Loss: 203.588531
Train Epoch: 1 [25600/60000 (43%)] Loss: 200.128128
Train Epoch: 1 [38400/60000 (64%)] Loss: 193.303040
Train Epoch: 1 [51200/60000 (85%)] Loss: 186.463058
====> Epoch: 1 Average loss: 202.1153
====> Test set loss: 178.8977
Train Epoch: 2 [0/60000 (0%)] Loss: 178.230606
Train Epoch: 2 [12800/60000 (21%)] Loss: 168.383270
Train Epoch: 2 [25600/60000 (43%)] Loss: 149.937363
Train Epoch: 2 [38400/60000 (64%)] Loss: 144.123688
Train Epoch: 2 [51200/60000 (85%)] Loss: 146.283325
====> Epoch: 2 Average loss: 156.7775
====> Test set loss: 143.5305
Train Epoch: 3 [0/60000 (0%)] Loss: 145.444778
Train Epoch: 3 [12800/60000 (21%)] Loss: 141.762360


In this article we saw how different methods can make argmax differentiable. Making the argmax operation differentiable allows gradients to flow during backpropagation. Three commonly used methods are the Straight-Through Estimator, Gumbel-Softmax, and Custom Operations:

Straight-Through Estimator (STE): The Straight-Through Estimator is a simple and intuitive method for making the argmax operation differentiable. It applies the hard operation in the forward pass and lets the incoming gradient pass through unchanged in the backward pass.

Gumbel-Softmax: The Gumbel-Softmax method uses the reparameterization trick to introduce Gumbel noise so that sampling becomes differentiable, and uses softmax as a differentiable relaxation of argmax. We saw a detailed implementation of this using a VAE.

Custom Operations: We can create custom, differentiable operations that approximate the argmax operation in a way that allows gradients to flow through.
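
As an illustration of the first and third items, a straight-through argmax can be written as a custom autograd operation. This is a minimal sketch, not the only way to do it; the class name `StraightThroughArgmax` and the example tensors are assumptions made for this illustration.

```python
import torch
import torch.nn.functional as F

class StraightThroughArgmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores):
        # Hard one-hot encoding of the argmax along the last dimension
        index = scores.argmax(dim=-1)
        return F.one_hot(index, num_classes=scores.shape[-1]).to(scores.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the incoming gradient on unchanged
        return grad_output

scores = torch.tensor([[0.1, 2.0, -1.0]], requires_grad=True)
one_hot = StraightThroughArgmax.apply(scores)

upstream = torch.tensor([[1.0, 2.0, 3.0]])   # arbitrary downstream weights
(one_hot * upstream).sum().backward()

print(one_hot)
print(scores.grad)
```

The forward result is a discrete one-hot vector, while scores.grad simply equals the upstream gradient, which is the approximation the Straight-Through Estimator makes.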

How Does PyTorch Backprop Through Argmax?

Backpropagation is a fundamental algorithm in training neural networks, allowing them to learn from data. Backpropagation involves iteratively updating the weights of a neural network to minimize the difference between predicted and actual outputs. It relies on the chain rule to calculate gradients, determining how much each parameter should be adjusted. This allows the network to learn and improve its predictions over time.

However, certain operations, such as argmax, present challenges during backpropagation due to their non-differentiable nature. In this article, we delve into how PyTorch handles backpropagation through the argmax operation and explore techniques like the Straight-Through Estimator (STE) that make this possible.

