Argmax Operation

In mathematical terms, the argmax function returns the point in the input domain at which a given function attains its maximum value; applied to a tensor, it returns the index of the largest element. The argmax function is not continuous at points where the maximum switches from one candidate to another: it jumps from one value to another as you move across such points. Because of these discontinuities, the argmax function is not differentiable.
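To see the jump concretely, consider a tiny sketch: nudging one entry past another flips the argmax in a single step.

Python3

import torch

for a in [0.49, 0.51]:
    print(a, torch.argmax(torch.tensor([a, 0.5])).item())
# 0.49 -> 1, 0.51 -> 0: an arbitrarily small change in the input flips the output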

Consider the example below, in which we find the index of the maximum value along each row of a tensor.

Python3

import torch
 
torch.argmax(torch.tensor([[32, 11, 12, 14], [1, 123, 12, 212]]), dim=1)

Output:

tensor([0, 3])

where

  • the first value, 0, is the index of the largest element of [32, 11, 12, 14], i.e. 32;
  • the second value, 3, is the index of the largest element of [1, 123, 12, 212], i.e. 212.
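For comparison, passing dim=0 takes the argmax down each column instead; a small sketch (with one value nudged to avoid a tie):

Python3

import torch

t = torch.tensor([[32, 11, 12, 14],
                  [1, 123, 15, 212]])
print(torch.argmax(t, dim=0))  # tensor([0, 1, 1, 1]): row index of the largest value per column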

Non-Differentiable

The argmax function is discontinuous: its value changes abruptly, and where those jumps occur depends on the values present in the tensor. Thus, the argmax function is non-differentiable; it does not have a well-defined derivative.

For training a neural network we need differentiable operations so that the loss can be backpropagated for parameter updates. If the forward pass contains an argmax, the loss cannot be backpropagated through the network to update the weights.
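A quick check makes this concrete: the output of torch.argmax carries no grad_fn, so autograd has nothing to differentiate.

Python3

import torch

x = torch.randn(5, requires_grad=True)
y = torch.argmax(x)

print(y.grad_fn)        # None: argmax is not recorded in the autograd graph
print(y.requires_grad)  # False: no gradient can flow back to x through y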

Application

In the context of neural networks, argmax is primarily a tool for interpreting or extracting the most likely result from the probabilistic outputs generated by a model. Some use cases where one can use argmax are listed below; a minimal classification sketch follows the list.

  • Classification tasks: After training a neural network, argmax is often used in the inference phase to make predictions. For example, in a classification task, the argmax function determines which class the network’s output corresponds to.
  • NLP: In sequence-to-sequence models for text generation, argmax is often used to choose the most likely next word or token at each step during greedy decoding.
  • Discrete sampling: In generative models such as GANs and VAEs that work with discrete outputs, an argmax-style selection step often appears when turning the model’s distribution into a concrete sample.
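As an illustration of the classification case, here is a minimal sketch (the logits tensor is a made-up stand-in for a network's raw output):

Python3

import torch

# Hypothetical raw outputs (logits) for a batch of 2 samples over 4 classes
logits = torch.tensor([[1.2, 0.3, 2.5, -0.7],
                       [0.1, 3.0, -1.2, 0.4]])

predicted_classes = torch.argmax(logits, dim=1)
print(predicted_classes)  # tensor([2, 1])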

Since argmax is paramount in various neural network architectures, it becomes imperative to find ways to make it differentiable (or to approximate it with a differentiable surrogate) so that the loss can be backpropagated through the neural network and the model can update its weights.

Let us understand why argmax is not differentiable using code, and then we will discuss ways to counter it.

Error: Training a Neural Network with Argmax

Installations

!pip install torchinfo

Let us try to train a neural network with argmax as output and see what happens.

Import the necessary libraries

Python3

# Importing the required libraries
import torchinfo
import torch
from torch import nn, optim
import torch.nn.functional as F
  • Next, we create a model for classifying our images. We use two convolution layers, each followed by a ReLU activation and max pooling, and then three fully connected layers. The final linear layer outputs a vector of dimension 10, which we pass through an argmax.

Python3

# Creating our own LeNet5
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # in_channels, out_channels, kernel_size
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(6, 16, 5)  # in_channels, out_channels, kernel_size
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d((2, 2))

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        # Flatten all dimensions except the batch dimension
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        # Non-differentiable: this breaks the autograd graph
        x = torch.argmax(x, dim=1)

        return x

model = LeNet5()
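Since torchinfo was installed above, we can optionally inspect the architecture; a quick sketch (the input size assumes CIFAR-10's 3×32×32 images with a batch of 16):

Python3

from torchinfo import summary

summary(model, input_size=(16, 3, 32, 32))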

                    
  • Let us load the CIFAR-10 dataset. It consists of 60,000 32×32 colour images in 10 classes, with 6,000 images per class; 50,000 are training images and 10,000 are test images. We use transforms to convert the images to tensors and normalize them.

Python3

# Loading the dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])
train_data = CIFAR10(root="./train/", train=True, download=True,
                     transform=train_transforms)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
  • Define our loss function and optimizer

Python3

# Our loss function (mean squared error on the predicted indices)
def my_loss(output, target):
    # Re-wrapping in torch.tensor() creates a new leaf tensor with no grad_fn,
    # so even this conversion cannot restore a gradient path
    output = torch.tensor(output, dtype=torch.float)
    loss = ((output - target) ** 2).mean()
    return loss


# Our optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  • Train our model

Python3

# Training the model
N_EPOCHS = 2
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()

        outputs = model(inputs)

        loss = my_loss(outputs, labels.float())
        loss.backward()  # fails: the graph was broken by argmax
        optimizer.step()
        epoch_loss += loss.item()
    print("Epoch: {} Loss: {}".format(epoch, epoch_loss / len(trainloader)))

Output:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If we try to train this model, it raises a RuntimeError. This is because we used argmax inside the model: argmax has no gradient function (grad_fn), so autograd cannot propagate the loss back through it.
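A common remedy (a sketch, not the only option) is to keep the raw logits during training and defer argmax to inference:

Python3

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()  # works on raw logits and is differentiable

# Stand-in for the model output with the final argmax removed
logits = torch.randn(16, 10, requires_grad=True)
labels = torch.randint(0, 10, (16,))

loss = criterion(logits, labels)
loss.backward()  # succeeds: every operation in the graph is differentiable

# argmax is applied only at inference time, outside the training graph
preds = torch.argmax(logits, dim=1)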

How Does PyTorch Backprop Through Argmax?

Backpropagation is a fundamental algorithm in training neural networks, allowing them to learn from data. It iteratively updates the weights of a network to minimize the difference between predicted and actual outputs, relying on the chain rule to compute gradients that determine how much each parameter should be adjusted. This allows the network to learn and improve its predictions over time.

However, certain operations, such as argmax, present challenges during backpropagation due to their non-differentiable nature. In the rest of this article, we delve into how PyTorch handles backpropagation through the argmax operation and explore techniques like the Straight-Through Estimator (STE) that make this possible.
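To preview the idea, here is a minimal sketch of the straight-through trick (the logits tensor is a hypothetical model output): the forward pass outputs a hard one-hot argmax, while the backward pass reuses the gradient of a softmax.

Python3

import torch
import torch.nn.functional as F

logits = torch.randn(16, 10, requires_grad=True)  # hypothetical model output

y_soft = F.softmax(logits, dim=1)
index = y_soft.argmax(dim=1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(1, index, 1.0)

# Forward value equals the hard one-hot; gradients flow through y_soft only
y = y_hard - y_soft.detach() + y_soft

y.sum().backward()              # succeeds: gradients reach the logits
print(logits.grad is not None)  # True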
