How to handle overfitting in PyTorch models using Early Stopping

Overfitting is a common challenge in machine learning: a model performs well on its training data but poorly on unseen data because it has learned noise and incidental details of the training set rather than the underlying pattern.

In the context of deep learning with PyTorch, one effective method to combat overfitting is implementing early stopping. This article explains how early stopping works, demonstrates how to implement it in PyTorch, and explores its benefits and considerations.

Table of Contents

  • What is Early Stopping?
  • Benefits of Early Stopping
  • Steps Needed to Implement Early Stopping in PyTorch
    • Step 1: Import Libraries
    • Step 2: Define the Neural Network Architecture
    • Step 3: Implement Early Stopping
    • Step 4: Load the Data
    • Step 5: Initialize the Model, Loss Function, and Optimizer
    • Step 6: Train the Model with Early Stopping
    • Step 7: Evaluate the Model
    • Building and Training a Simple Neural Network with Early Stopping in PyTorch
  • Conclusion

What is Early Stopping?

Early stopping is a regularization technique used to avoid overfitting during the training process. It involves stopping the training phase if the model’s performance on a validation set does not improve for a specified number of consecutive epochs, called the “patience” period. This ensures the model does not learn the noise and specific details of the training data, thereby enhancing its generalization capabilities.
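
In pseudocode, the idea reduces to tracking the best validation loss seen so far and counting epochs without improvement. Here is a minimal sketch; train_one_epoch, evaluate, and max_epochs are placeholders for your own routines and settings:

best_val_loss = float('inf')
epochs_without_improvement = 0
patience = 5

for epoch in range(max_epochs):
    train_one_epoch(model)          # placeholder: one pass over the training data
    val_loss = evaluate(model)      # placeholder: loss on the validation set
    if val_loss < best_val_loss:
        best_val_loss = val_loss    # improvement: remember it and reset the counter
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break                   # patience exhausted: stop training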

Benefits of Early Stopping

  1. Prevents Overfitting: By halting training once validation performance plateaus, early stopping keeps the model from continuing to fit noise in the training data.
  2. Saves Time and Resources: It reduces unnecessary training time and computational resources by stopping the training early.
  3. Optimizes Model Performance: Helps in selecting the version of the model that performs best on unseen data.

Steps Needed to Implement Early Stopping in PyTorch

In this section, we are going to walk through the process of creating, training, and evaluating a simple neural network using PyTorch, focusing on the implementation of early stopping to prevent overfitting.

Step 1: Import Libraries

First, we import the necessary libraries:

import copy  # used to snapshot the best model weights in EarlyStopping

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

Step 2: Define the Neural Network Architecture

Next, we define a simple neural network class using PyTorch’s nn.Module. The neural network has:

  • fc1, fc2, fc3: Fully connected layers with ReLU activations.
  • forward method: Defines the forward pass of the network.

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten 28x28 images to 784-dim vectors
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
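
A quick sanity check of the forward pass on a dummy batch confirms the expected shapes (two fake 1x28x28 images in, two 10-dimensional logit vectors out):

net = SimpleNN()
dummy = torch.randn(2, 1, 28, 28)  # a fake batch of two MNIST-sized images
print(net(dummy).shape)            # torch.Size([2, 10])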

Step 3: Implement Early Stopping

We implement an EarlyStopping class to halt training if the validation loss stops improving. Here the parameters are:

  • patience: Number of epochs to wait before stopping if no improvement.
  • delta: Minimum change in the monitored quantity to qualify as an improvement.
  • best_score, best_model_state: Track the best validation score and model state.
  • __call__ method: Updates the counter, best score, and saved weights after each validation pass.

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss  # higher score = lower validation loss
        if self.best_score is None:
            # First epoch: snapshot the weights. deepcopy is required because
            # state_dict() returns references to the live parameter tensors.
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            # No sufficient improvement: count this epoch against patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: update the snapshot and reset the counter.
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)
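
You can sanity-check the class in isolation by feeding it a sequence of non-improving losses (SimpleNN from Step 2 stands in for any model here):

stopper = EarlyStopping(patience=2)
net = SimpleNN()
for val_loss in [0.5, 0.6, 0.7]:
    stopper(val_loss, net)
print(stopper.early_stop)  # True: two consecutive epochs without improvement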

Step 4: Load the Data

We load and transform the MNIST dataset.

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
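
For simplicity, this tutorial monitors the loss on the MNIST test set during training. A more rigorous setup carves a separate validation split out of the training data, for example as follows (a sketch; the 90/10 split is an arbitrary choice):

from torch.utils.data import random_split

train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size])
train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)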

Step 5: Initialize the Model, Loss Function, and Optimizer

We set up the model, criterion, and optimizer.

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
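
If a GPU is available, you would typically also move the model to it here, and move each batch inside the loops that follow (a minimal sketch):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# inside the training and evaluation loops, also move each batch:
# data, target = data.to(device), target.to(device)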

Step 6: Train the Model with Early Stopping

We train the model, incorporating early stopping.

In the code below:

  • Train loop: Train the model, update the weights, and accumulate the training loss.
  • Validation loop: Evaluate the model on the held-out data (here, the test set serves as the validation set) and compute the validation loss.
  • Early stopping check: Apply the early-stopping logic after each epoch.

early_stopping = EarlyStopping(patience=5, delta=0.01)
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)

    val_loss /= len(test_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

early_stopping.load_best_model(model)
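
If the best weights should also survive the Python process, you can persist the snapshot to disk once training stops (a sketch; 'best_model.pt' is an arbitrary filename):

torch.save(early_stopping.best_model_state, 'best_model.pt')
# later, in a fresh process:
# model.load_state_dict(torch.load('best_model.pt'))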

Step 7: Evaluate the Model

Finally, we evaluate the model’s accuracy on the test dataset. The evaluation loop computes the accuracy by comparing predicted labels with true labels.

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
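
As a quick spot check beyond aggregate accuracy, you can also inspect a single prediction (a small sketch using the first test image):

image, label = test_dataset[0]
with torch.no_grad():
    logits = model(image.unsqueeze(0))  # add a batch dimension
print('predicted:', logits.argmax(1).item(), 'actual:', label)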

Building and Training a Simple Neural Network with Early Stopping in PyTorch

The complete program, combining all of the steps above:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy  # used to snapshot the best model weights

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

        
# Data loading
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping
early_stopping = EarlyStopping(patience=5, delta=0.01)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)

    val_loss /= len(test_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Load the best model
early_stopping.load_best_model(model)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

Output:

Epoch 1, Train Loss: 0.3897, Val Loss: 0.2323
Epoch 2, Train Loss: 0.1778, Val Loss: 0.1427
Epoch 3, Train Loss: 0.1339, Val Loss: 0.1281
Epoch 4, Train Loss: 0.1068, Val Loss: 0.1271
Epoch 5, Train Loss: 0.0911, Val Loss: 0.1028
Epoch 6, Train Loss: 0.0807, Val Loss: 0.0852
Epoch 7, Train Loss: 0.0699, Val Loss: 0.0973
Epoch 8, Train Loss: 0.0659, Val Loss: 0.0941
Epoch 9, Train Loss: 0.0593, Val Loss: 0.0834
Epoch 10, Train Loss: 0.0542, Val Loss: 0.0954
Epoch 11, Train Loss: 0.0496, Val Loss: 0.0879
Early stopping
Accuracy of the model on the test images: 97.24%

Conclusion

In this tutorial, we demonstrated how to build, train, and evaluate a simple neural network using PyTorch, with a focus on implementing early stopping to prevent overfitting. This approach helps achieve better generalization by halting training when the validation performance stops improving.



