Argmax Operation
In mathematical terms, the argmax function returns the input value at which a given function attains its maximum value. For a tensor, that input is the index of the largest element. The argmax function is piecewise constant: its output stays fixed until a different element becomes the largest, and then it jumps from one index to another. At those jump points the function is discontinuous and has no derivative, and everywhere else its derivative is zero, so argmax provides no useful gradient.
Consider the example below, which finds the index of the maximum value in each row of a tensor.
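The jump described above can be seen with a small sketch. The helper `argmax_of_pair` and the sample values of `t` are made up for illustration; the sketch sweeps `t` across 0.5 for the two-element vector `[t, 1 - t]` and records the index returned by `torch.argmax`.

```python
import torch

def argmax_of_pair(t: float) -> int:
    """Index of the larger entry in the two-element vector [t, 1 - t]."""
    return int(torch.argmax(torch.tensor([t, 1.0 - t])))

# Sweep t across 0.5: the returned index flips from 1 to 0 in a single step,
# with no intermediate values in between.
ts = [0.40, 0.49, 0.51, 0.60]
indices = [argmax_of_pair(t) for t in ts]
print(indices)  # [1, 1, 0, 0]
```

A tiny change in the input (0.49 to 0.51) changes the output by a whole index, while larger changes on either side of 0.5 change nothing.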
Python3
import torch

torch.argmax(torch.tensor([[32, 11, 12, 14], [1, 123, 12, 212]]), dim=1)
Output:
tensor([0, 3])
where
- the first value, 0, is the index of the largest element, 32, in [32, 11, 12, 14].
- the second value, 3, is the index of the largest element, 212, in [1, 123, 12, 212].
Non-Differentiable
The argmax function is discontinuous: its output changes abruptly whenever a different element becomes the largest. Where these jumps occur depends on the values present in the tensor. At the jumps the derivative is undefined, and between them the output is constant, so argmax does not have a well-defined, useful derivative.
To train a neural network, the loss must be backpropagated through differentiable operations so that the parameters can be updated. If the network contains an argmax, the loss cannot be backpropagated through it to update the weights.
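As a quick check of this point, the sketch below compares `torch.softmax` and `torch.argmax` on the same input. The `logits` values are arbitrary; the attributes inspected (`requires_grad`, `grad_fn`, `dtype`) are standard tensor attributes.

```python
import torch

logits = torch.tensor([[2.0, 0.5, 1.0]], requires_grad=True)

soft = torch.softmax(logits, dim=1)  # differentiable: stays in the autograd graph
hard = torch.argmax(logits, dim=1)   # integer indices: not part of the graph

print(soft.requires_grad, soft.grad_fn is not None)  # True True
print(hard.dtype, hard.requires_grad, hard.grad_fn)  # torch.int64 False None
```

The softmax output carries a `grad_fn`, so gradients can flow back to `logits`. The argmax output is an integer tensor with no `grad_fn`, so anything computed from it is disconnected from the parameters that produced `logits`.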
Application
In the context of neural networks, argmax is primarily a tool for interpreting or extracting the most likely results from the probabilistic outputs generated by the models. Some common use cases are:
- Classification tasks: After training a neural network, argmax is often used in the inference phase to make predictions. For example, in a classification task, the argmax function is used to determine which class the network’s output corresponds to.
- NLP: In sequence-to-sequence models for text generation, argmax is often used to choose the most likely next word or token at each step during decoding.
- Conditional Sampling: In some generative neural networks, such as GANs and VAEs, argmax is used during sampling to pick a discrete category from the model's output.
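The classification use case can be sketched as follows. The `logits` tensor is a hypothetical batch of model outputs (3 samples, 4 classes), not the output of any model in this article.

```python
import torch

# Hypothetical logits for a batch of 3 samples over 4 classes.
logits = torch.tensor([[1.2, 0.3, -0.5, 2.1],
                       [0.1, 3.0, 0.2, 0.4],
                       [2.5, 2.4, 0.0, -1.0]])

probs = torch.softmax(logits, dim=1)            # class probabilities per sample
pred_from_probs = torch.argmax(probs, dim=1)    # most likely class per sample
pred_from_logits = torch.argmax(logits, dim=1)  # softmax preserves ordering, so same result

print(pred_from_logits)  # tensor([3, 1, 0])
```

Because softmax does not change which entry is largest, argmax can be applied directly to the logits at inference time. Greedy decoding in sequence models follows the same pattern, applied once per generated token.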
Because argmax appears in so many neural network architectures, it becomes important to find ways to make it differentiable so that the loss can be backpropagated through the neural network and the model can update its weights.
Let us first see in code why argmax is not differentiable, and then discuss ways to work around it.
Error: Training a Neural Network with Argmax
Installations
!pip install torchinfo
Let us try to train a neural network with argmax as output and see what happens.
Import the necessary libraries
Python3
# Importing required libraries
import torchinfo
import torch
from torch import nn, optim
import torch.nn.functional as F
- Next, we create a model to classify our images. We use two convolution layers, each followed by a ReLU activation and max pooling, and then three fully connected layers. The final linear layer produces an output of dimension 10, which we pass through an argmax.
Python3
# Creating our own LeNet5
class LeNet5(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # in channels, out channels, kernel size
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(6, 16, 5)  # in channels, out channels, kernel size
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d((2, 2))
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Equivalent functional form of the two convolution blocks:
        # x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        x = x.view(x.shape[0], -1)    # flatten to (batch, features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = torch.argmax(x, dim=1)    # non-differentiable: returns class indices
        return x

model = LeNet5()
- Let us load the CIFAR-10 dataset. The CIFAR-10 dataset consists of 60000 32×32 color images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. We use transforms to convert the CIFAR-10 images to tensors and normalize them.
Python3
# Loading the dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])
train_data = CIFAR10(root="./train/", train=True, download=True,
                     transform=train_transforms)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
- Define our loss function and optimizer
Python3
# Our loss function: mean squared error between predicted and true class indices
def my_loss(output, target):
    # argmax returns integer indices, so cast to float before subtracting
    output = output.float()
    loss = ((output - target) ** 2).mean()
    return loss

# Our optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
- Train our model
Python3
# Training the model
N_EPOCHS = 2
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = my_loss(outputs, labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print("Epoch: {} Loss: {}".format(epoch, epoch_loss / len(trainloader)))
Output:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
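Training the model raises a RuntimeError. This is because the model ends with argmax, which returns integer indices that have no gradient function (grad_fn). The loss computed from them is disconnected from the model's weights, so calling backward() on it fails.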
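One common way to avoid this error, shown here as a minimal sketch and not necessarily the approach this article goes on to use, is to have the model return raw logits, train with `nn.CrossEntropyLoss`, and apply argmax only when reading off predictions. The tiny `nn.Sequential` model and the random `inputs` and `labels` are stand-ins chosen so the sketch runs without downloading CIFAR-10.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Tiny stand-in classifier (hypothetical, not the LeNet5 above) that returns raw logits.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()  # expects logits and integer class labels
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Random tensors shaped like one CIFAR-10 batch, used only so the sketch runs offline.
inputs = torch.randn(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))

optimizer.zero_grad()
logits = model(inputs)             # no argmax inside the model
loss = criterion(logits, labels)
loss.backward()                    # works: the logits carry a grad_fn
optimizer.step()

# argmax is applied only at prediction time, outside the training graph.
with torch.no_grad():
    preds = torch.argmax(model(inputs), dim=1)
print(loss.item(), preds.shape)
```

The same change applies to the LeNet5 above: removing the `torch.argmax` call from `forward` and swapping `my_loss` for a loss that accepts logits would let `loss.backward()` run.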
How Does PyTorch Backprop Through Argmax?
Backpropagation is a fundamental algorithm in training neural networks, allowing them to learn from data. Backpropagation involves iteratively updating the weights of a neural network to minimize the difference between predicted and actual outputs. It relies on the chain rule to calculate gradients, determining how much each parameter should be adjusted. This allows the network to learn and improve its predictions over time.
However, certain operations, such as argmax, present challenges during backpropagation due to their non-differentiable nature. In this article, we delve into how PyTorch handles backpropagation through the argmax operation and explore techniques like the Straight-Through Estimator (STE) that make this possible.
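A minimal sketch of the straight-through idea follows, assuming the softmax gradient is used as the surrogate in the backward pass. The function name `argmax_ste`, the `scores` tensor, and the toy loss are illustrative and not taken from PyTorch or from this article.

```python
import torch
import torch.nn.functional as F

def argmax_ste(logits: torch.Tensor) -> torch.Tensor:
    """One-hot argmax in the forward pass, softmax gradient in the backward pass."""
    soft = torch.softmax(logits, dim=-1)
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=logits.shape[-1]).to(soft.dtype)
    # The forward value is (numerically) `hard`. The detached term contributes
    # no gradient, so the backward pass sees only `soft`.
    return hard + soft - soft.detach()

logits = torch.tensor([[2.0, 0.5, 1.0]], requires_grad=True)
one_hot = argmax_ste(logits)

# Toy loss: weight each class by a hypothetical score.
scores = torch.tensor([0.0, 1.0, 2.0])
loss = (one_hot * scores).sum()
loss.backward()

print(one_hot)      # approximately one-hot at index 0
print(logits.grad)  # gradient that flowed through the softmax surrogate
```

The forward pass still makes a hard, discrete choice, while the backward pass treats that choice as if it had been the softmax output. The resulting gradient is biased, but it lets the layers before the argmax receive a training signal.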