Strategies for Handling Argmax in PyTorch

Now, let’s discuss the strategies we can use to address the challenges posed by the non-differentiable nature of argmax in PyTorch.

1. Straight-Through Estimator (STE)

This technique is used to handle non-differentiable operations during backpropagation in neural networks. It uses a differentiable approximation during the backward pass to enable gradient flow through the argmax operation.

How it works:

In this method,

  • During the forward pass, we apply the argmax operation to determine the discrete category.
  • During the backward pass, we let the gradient pass through unchanged, i.e. the operation acts as an identity function: it takes the gradient coming from the layer above and passes it to the layer below without any modification. In other words, we backpropagate through the argmax operation as if it were a continuous function (see the sketch below this list).
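A minimal sketch of this idea using torch.autograd.Function is shown below; the tensor shapes and names are purely illustrative and not tied to any particular model:

```python
import torch

class ArgmaxSTE(torch.autograd.Function):
    """Straight-Through Estimator for argmax (illustrative sketch)."""

    @staticmethod
    def forward(ctx, logits):
        # Forward pass: hard one-hot vector obtained via argmax
        index = logits.argmax(dim=-1, keepdim=True)
        return torch.zeros_like(logits).scatter_(-1, index, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: act as the identity and pass the upstream gradient through unchanged
        return grad_output

logits = torch.randn(2, 5, requires_grad=True)
one_hot = ArgmaxSTE.apply(logits)
one_hot.sum().backward()
print(logits.grad)  # non-zero gradients reach the logits despite the hard argmax
```

The forward pass is exactly the hard argmax, while the backward pass simply copies the incoming gradient, which is the identity behaviour described above.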

2. Gumbel Softmax

Gumbel Softmax is another technique for handling non-differentiable operations, particularly in the context of discrete variables and the argmax operation. It introduces stochasticity into the argmax operation by sampling from the Gumbel-Softmax distribution.

How it works:

This method gives us a differentiable way to sample from a categorical (discrete) distribution. Let's understand it with the help of a simple example:

  • Assume that there is a discrete variable with two outcomes, X1 and X2.
  • Let's say we have a model with a hidden layer that produces a score for each outcome of the discrete variable.
  • Let these scores be S(X1) and S(X2). We take the highest score (argmax). The probability distribution is given by:
    • P(Xi) = softmax(S(X1), S(X2))
    • S(X1) and S(X2) are the logits obtained from the model.
    • P(X1) and P(X2) form the probability distribution over the categorical outcomes.
  • The problem with the above approach is that the selection process is deterministic: we always get the outcome with the maximum score.
  • What if we instead want to actually sample, so that we get [1, 0] a fraction P(X1) of the time and [0, 1] a fraction P(X2) of the time? This is where the Gumbel-Max trick comes in.
  • The Gumbel-Max equation is:
    • Y = one_hot(argmax_i (S(X_i) + G_i))
    • where the G_i are i.i.d. samples from the Gumbel(0, 1) distribution.
    • NOTE: Mathematically it can be shown that Y is distributed according to P(Xi) = softmax(S(X1), S(X2)), i.e. the equation gives the output [1, 0] a fraction P(X1) of the time and [0, 1] a fraction P(X2) of the time.
  • Notice that there is still an argmax inside Gumbel-Max, which still makes it non-differentiable. Therefore, we use a softmax function to approximate this argmax:
    • y_i = exp((S(X_i) + G_i) / τ) / \sum_j exp((S(X_j) + G_j) / τ)
  • Here τ is a temperature hyperparameter. It controls the output variability.
    • When τ → 0, the softmax becomes an argmax and the Gumbel-Softmax distribution becomes the categorical distribution.
    • During training, in order to let gradients flow past the sample, we start with a large value of τ and then gradually anneal the temperature (but not all the way to 0, as the gradients would blow up).
  • It’s important to observe that the output of the Gumbel Softmax function produces a vector that sums to 1, somewhat resembling a one-hot vector (although it technically isn’t one). This approximation does not completely replace the argmax operation.
  • To genuinely obtain a pure one-hot vector, the Straight-Through (ST) Gumbel Trick is applied.
    • During the forward pass, the code utilizes an argmax operation to obtain an actual one-hot vector.
    • However, during backpropagation, the softmax approximation of the argmax is used to maintain differentiability.

Thus, this method introduces stochasticity (via the Gumbel distribution) into the discrete decision-making process while using a differentiable approximation (softmax) to the argmax operation.

Gumbel-Softmax is often used in scenarios where you want a stochastic decision-making process over discrete variables, for example in variational autoencoders. It enables backpropagation through random samples of discrete variables. The Gumbel-Max trick is very similar to the reparameterization trick, in that we combine a deterministic part (the model score) with a stochastic part (the Gumbel noise).
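The steps above can be sketched in a few lines of PyTorch; the scores tensor, batch size, and temperature below are illustrative, and the last sample relies on the built-in F.gumbel_softmax, whose hard=True option implements the Straight-Through Gumbel trick:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 5, requires_grad=True)   # illustrative model scores S(X_i)
tau = 0.5                                        # temperature hyperparameter

# Gumbel-Max: add i.i.d. Gumbel(0, 1) noise, then take argmax (non-differentiable)
uniform = torch.rand_like(scores).clamp_min(1e-10)
gumbel_noise = -torch.log(-torch.log(uniform))
hard_indices = (scores + gumbel_noise).argmax(dim=-1)

# Gumbel-Softmax: replace the argmax with a temperature-controlled softmax (differentiable)
soft_sample = F.softmax((scores + gumbel_noise) / tau, dim=-1)

# Straight-Through Gumbel: one-hot in the forward pass, softmax gradients in the backward pass
st_sample = F.gumbel_softmax(scores, tau=tau, hard=True)

st_sample.sum().backward()
print(hard_indices)               # stochastic hard choices
print(soft_sample.sum(dim=-1))    # each relaxed sample sums to 1
print(scores.grad is not None)    # gradients reach the scores despite the hard sample
```

In practice, τ would be annealed over the course of training rather than kept fixed, as described above.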

3. Customized Operations

We can create custom PyTorch operations that mimic the behavior of argmax but are differentiable.

How it works:

Here we design a custom operation that approximates the argmax operation while remaining differentiable. A common example is softmax.

The softmax function converts the input values into values between 0 and 1 that sum to 1.

Given an input vector z = [z_1, z_2, …, z_n], where n is the number of classes, the softmax function computes the probability p_i for class i as follows:

p_i = e^{z_i} / \sum_{j=1}^{n} e^{z_j}

In this equation:

  • e^{z_i} represents the exponential of the ith element of the input vector.
  • The denominator is the sum of the exponentials of all elements in the input vector, which ensures that the probabilities sum to 1.

Once we have converted the outputs to probabilities, we use the negative log-likelihood to calculate the loss:

L(y, p) = -\sum_i (y_i * \log(p_i))

In this equation:

  • L(y, p) is the negative log-likelihood loss for the example.
  • y is a vector of true labels, where y_i is 1 for the true class and 0 for the other classes (one-hot encoded).
  • p is a vector of predicted probabilities, where p_i is the predicted probability for class i.

This loss is largest when the predicted probability assigned to the true class is lowest, and smallest when that probability is highest.

Because this log-based loss function is continuous and differentiable, we can backpropagate it through the neural network.
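As a rough illustration (the logits and target indices below are made up), the softmax + negative log-likelihood surrogate looks like this in PyTorch; F.cross_entropy combines the same two steps into one call:

```python
import torch
import torch.nn.functional as F

# Illustrative logits z for a batch of 3 examples and 4 classes
logits = torch.randn(3, 4, requires_grad=True)
targets = torch.tensor([2, 0, 3])          # true class indices (the one-hot y in the formula)

log_probs = F.log_softmax(logits, dim=-1)  # softmax followed by log, computed stably
loss = F.nll_loss(log_probs, targets)      # L(y, p) = -sum_i y_i * log(p_i), averaged over the batch
loss.backward()                            # the surrogate is differentiable, so gradients flow

# The hard, non-differentiable decision is still available at inference time
predictions = logits.argmax(dim=-1)
print(loss.item(), predictions)
```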
