Implementation: From PyTorch to PyTorch Lightning
Let's illustrate the difference in code between a basic PyTorch script and its PyTorch Lightning equivalent by comparing the training and validation loops for a simple 3-layer neural network on the MNIST dataset. The key ingredients in both versions are the model, the dataset (MNIST), the optimizer, and the loss function. The two scripts below also differ in their optimizer (SGD versus Adam) and preprocessing (the Lightning version skips normalization), so their accuracy figures are not a strict like-for-like comparison.
PyTorch
1. Importing the necessary libraries and modules
The code starts by importing the necessary libraries and modules for building and training the neural network. These include torch, torch.nn, torch.optim, and torchvision.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
2. Defining the neural network model
Next, the code defines a simple neural network model using PyTorch's nn.Module class. The model consists of three fully connected layers (fc1, fc2, and fc3) that map the 784 input pixels (a flattened 28×28 image) to 256, 128, and 10 neurons, respectively. The 10 output neurons correspond to the 10 digit classes in MNIST.
The forward method defines the forward pass: it flattens each input image, passes it through fc1 and fc2 with a ReLU activation after each, and returns the raw scores (logits) produced by fc3, which is the form nn.CrossEntropyLoss expects.
# Define the model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
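As a quick sanity check on the architecture described above, the sketch below (assuming only that torch is installed) feeds a random batch shaped like MNIST images through the same NeuralNetwork and counts its trainable parameters. The random input and the helper name count_parameters are illustrative additions, not part of the original script.

```python
import torch
import torch.nn as nn


class NeuralNetwork(nn.Module):
    """Same 784-256-128-10 network as in the article."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)          # (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)               # raw logits, no softmax


def count_parameters(model: nn.Module) -> int:
    # Hypothetical helper: total number of trainable weights and biases
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = NeuralNetwork()
dummy_batch = torch.randn(64, 1, 28, 28)   # fake batch with MNIST's shape
logits = model(dummy_batch)
print(logits.shape)             # torch.Size([64, 10])
print(count_parameters(model))  # 235146
```

The count breaks down as 784·256 + 256 = 200,960 for fc1, 256·128 + 128 = 32,896 for fc2, and 128·10 + 10 = 1,290 for fc3.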
3. Loading the dataset
The code then loads the MNIST dataset using torchvision.datasets.MNIST. The dataset is preprocessed with transforms.Compose, which chains a series of transformations: ToTensor() converts each image to a tensor with pixel values in [0, 1], and Normalize((0.1307,), (0.3081,)) subtracts 0.1307 (the mean pixel value of the MNIST training set) and divides by 0.3081 (its standard deviation), so the normalized data has roughly zero mean and unit variance.
The train_dataset and test_dataset objects are wrapped in train_loader and test_loader, PyTorch DataLoader objects that split the data into batches of 64; only the training loader shuffles its data.
# Load the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
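Two details of the loading code are easy to miss: Normalize computes (x - 0.1307) / 0.3081 for every pixel, and the last batch is smaller whenever the dataset size is not a multiple of 64. The sketch below reproduces both on a synthetic stand-in for the 10,000-image test split (random pixel values instead of real MNIST, so no download is needed), applying the normalization formula by hand rather than through torchvision.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

MEAN, STD = 0.1307, 0.3081  # MNIST statistics used in the article

# ToTensor() maps pixels to [0, 1]; Normalize then applies (x - mean) / std
print((0.0 - MEAN) / STD)  # darkest pixel -> about -0.4242
print((1.0 - MEAN) / STD)  # brightest pixel -> about 2.8215

# Synthetic stand-in for the test split: 10,000 random "images" and labels
torch.manual_seed(0)
images = torch.rand(10_000, 1, 28, 28)
labels = torch.randint(0, 10, (10_000,))
normalized = (images - MEAN) / STD
loader = DataLoader(TensorDataset(normalized, labels), batch_size=64, shuffle=False)

batch_sizes = [x.size(0) for x, _ in loader]
print(len(loader))        # 157 batches
print(batch_sizes[-1])    # 16: the final, partial batch (10,000 = 156 * 64 + 16)
```

The same arithmetic gives 938 batches for the 60,000-image training split, with a final batch of 32.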
4. Initializing the model, optimizer, and loss function
The neural network model is initialized, and the stochastic gradient descent (SGD) optimizer (learning rate 0.01) and the cross-entropy loss function are defined.
# Initialize the model
model = NeuralNetwork()
# Define the optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
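nn.CrossEntropyLoss expects raw logits and applies log-softmax internally, and optim.SGD with lr=0.01 (and no momentum) updates each parameter as w ← w − 0.01·∂L/∂w. The sketch below checks both facts on a tiny random example; the tensor sizes and the stand-in nn.Linear layer are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# 1) CrossEntropyLoss on raw logits equals the NLL of log-softmax
logits = torch.randn(4, 10)                 # 4 samples, 10 classes
target = torch.tensor([3, 0, 9, 1])
ce = criterion(logits, target)
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(ce, manual))           # True

# 2) One plain SGD step (no momentum): w_new = w_old - lr * grad
layer = nn.Linear(10, 10)
optimizer = optim.SGD(layer.parameters(), lr=0.01)
inputs = torch.randn(4, 10)
w_old = layer.weight.detach().clone()
optimizer.zero_grad()
loss = criterion(layer(inputs), target)
loss.backward()
grad = layer.weight.grad.detach().clone()
optimizer.step()
print(torch.allclose(layer.weight.detach(), w_old - 0.01 * grad))  # True
```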
5. Training and validating the model
The code defines two functions, train and validate, which handle the training and validation of the neural network.
- The train function takes the model, training data loader, optimizer, and loss function, and trains the model one batch at a time: for each batch it clears the previous gradients with optimizer.zero_grad(), computes the loss, backpropagates with loss.backward(), and updates the weights with optimizer.step().
- The validate function takes the model, test data loader, and loss function, switches the model to evaluation mode, and evaluates it on the test data with gradient tracking disabled. It accumulates the test loss and computes the accuracy as the percentage of correct predictions.
Finally, the code trains and validates the neural network for 10 epochs using the train and validate functions.
# Training loop
def train(model, train_loader, optimizer, criterion):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Validation loop
def validate(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))

# Train and validate the model
for epoch in range(10):
    train(model, train_loader, optimizer, criterion)
    validate(model, test_loader, criterion)
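PyTorch adds each new gradient into a parameter's .grad instead of overwriting it, which is why train calls optimizer.zero_grad() at the start of every batch. The sketch below makes the accumulation visible on a single toy parameter; the parameter value and the loss w² are illustrative only.

```python
import torch

# A single toy parameter; loss = w**2, so d(loss)/dw = 2 * w = 4.0 at w = 2
w = torch.tensor(2.0, requires_grad=True)

(w ** 2).backward()
first = w.grad.item()        # 4.0

(w ** 2).backward()          # second backward WITHOUT clearing .grad
second = w.grad.item()       # 8.0: the new gradient was added to the old one

w.grad.zero_()               # clear the gradient, the job optimizer.zero_grad() does per batch
(w ** 2).backward()
third = w.grad.item()        # 4.0 again

print(first, second, third)  # 4.0 8.0 4.0
```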
Output:
Validation set: Average loss: 0.0052, Accuracy: 9050/10000 (90.50%)
Validation set: Average loss: 0.0041, Accuracy: 9267/10000 (92.67%)
Validation set: Average loss: 0.0034, Accuracy: 9369/10000 (93.69%)
Validation set: Average loss: 0.0030, Accuracy: 9445/10000 (94.45%)
Validation set: Average loss: 0.0026, Accuracy: 9506/10000 (95.06%)
Validation set: Average loss: 0.0024, Accuracy: 9555/10000 (95.55%)
Validation set: Average loss: 0.0021, Accuracy: 9597/10000 (95.97%)
Validation set: Average loss: 0.0020, Accuracy: 9641/10000 (96.41%)
Validation set: Average loss: 0.0018, Accuracy: 9663/10000 (96.63%)
Validation set: Average loss: 0.0017, Accuracy: 9680/10000 (96.80%)
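Because criterion returns the mean loss of each batch, the "Average loss" printed by validate is the sum of 157 batch means divided by 10,000 samples, which is roughly 64 times (the batch size) smaller than the true per-sample loss; the accuracy values are unaffected. The sketch below shows one possible correction: a hypothetical validate_per_sample function that sums per-sample losses with reduction="sum" before dividing. It runs on a tiny random linear model and dataset instead of the real MNIST network, so no download is needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def validate_per_sample(model, test_loader):
    """Hypothetical variant of validate that reports the true per-sample loss."""
    model.eval()
    total_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Sum (not average) the per-sample losses in this batch
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
    n = len(test_loader.dataset)
    return total_loss / n, 100.0 * correct / n


# Tiny random stand-ins for the real model and test set
torch.manual_seed(0)
model = nn.Linear(784, 10)
dataset = TensorDataset(torch.randn(1000, 784), torch.randint(0, 10, (1000,)))
loader = DataLoader(dataset, batch_size=64)

avg_loss, accuracy = validate_per_sample(model, loader)

# For comparison, the article's formula: sum of batch means / dataset size
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    article_loss = sum(criterion(model(x), y).item() for x, y in loader) / len(dataset)
print(avg_loss, article_loss)   # article_loss is many times smaller than avg_loss
```

An equivalent fix is to keep criterion and multiply each batch's loss by data.size(0) before adding it to the running total.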
PyTorch Lightning
1. Importing the necessary libraries and modules
The code imports PyTorch, PyTorch Lightning, the MNIST dataset class, the ToTensor transform, DataLoader, and the Adam optimizer.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam
import pytorch_lightning as pl
2. Defining the neural network model
Next, the MyModel class is defined, which inherits from pl.LightningModule. This class defines the neural network architecture, the forward pass, the training step, the validation step, and the configuration of the optimizer.
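- In the __init__ method, the neural network architecture is defined using PyTorch's nn.Sequential module. It consists of three fully connected layers with ReLU activations between them, followed by a softmax layer that turns the outputs into probabilities. Because nn.CrossEntropyLoss already applies log-softmax to raw logits, this final softmax is redundant and is usually omitted.
- The forward method flattens each 28×28 input image into a 784-element vector with x.view(x.size(0), -1) and passes it through the network in self.model.
- The training_step method computes the cross-entropy loss for one batch and logs it as train_loss; Lightning calls backward() and the optimizer step on the returned loss automatically.
- The validation_step method computes and logs the validation loss and the batch accuracy; the accuracy is aggregated over the epoch and shown in the progress bar.
- The configure_optimizers method returns the Adam optimizer with a learning rate of 0.001.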
class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.Softmax(dim=1)  # Redundant: nn.CrossEntropyLoss applies log-softmax itself
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Reshape the input
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('val_loss', loss)
        # Calculate accuracy
        correct = (y_hat.argmax(1) == y).sum().item()
        total = y.size(0)
        self.log('accuracy', correct / total, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.001)
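nn.CrossEntropyLoss treats its input as logits and applies log-softmax itself, so when the network's output has already passed through nn.Softmax, even a perfectly confident prediction becomes a one-hot vector whose loss cannot drop below log(e + 9) − 1 ≈ 1.46 for 10 classes. That floor is consistent with the val_loss of about 1.49 reported at the end of this example despite roughly 97% accuracy. The sketch below compares the two setups on a hand-made, highly confident prediction; the logit values ±10 are arbitrary.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
target = torch.tensor([3])

# A very confident prediction for class 3, expressed as raw logits
logits = torch.full((1, 10), -10.0)
logits[0, 3] = 10.0

loss_on_logits = criterion(logits, target).item()
loss_on_probs = criterion(nn.Softmax(dim=1)(logits), target).item()  # softmax applied twice

print(loss_on_logits)                 # close to 0
print(loss_on_probs)                  # about 1.4612
print(math.log(math.e + 9) - 1)       # theoretical floor, about 1.4612
```

Removing nn.Softmax(dim=1) from the nn.Sequential does not change the predicted class, because softmax preserves the argmax, so the accuracy computation in validation_step behaves the same either way.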
3. Loading the dataset
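The MNIST dataset is loaded with the MNIST class from torchvision.datasets: the training split (train=True) is used for training and the test split (train=False) for validation, and each is wrapped in its own DataLoader. Unlike the plain PyTorch version, only ToTensor() is applied here, with no normalization.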
# Load the MNIST dataset
train_dataset = MNIST(root='.', train=True, transform=ToTensor(), download=True)
val_dataset = MNIST(root='.', train=False, transform=ToTensor())
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
4. Initializing the model
# Initialize the model
model = MyModel()
5. Creating the trainer
The trainer object is created using pl.Trainer. The max_epochs argument is set to 10, so the model is trained for 10 epochs. The accelerator argument is set to "gpu" if a CUDA GPU is available and to "cpu" otherwise.
# Initialize the trainer
trainer = pl.Trainer(max_epochs=10, accelerator="gpu" if torch.cuda.is_available() else "cpu")
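In plain PyTorch the same hardware choice is made by hand: pick a torch.device and move both the model and every batch onto it. The sketch below mirrors the accelerator expression above; the small nn.Linear model and the random batch are placeholders, not the article's network.

```python
import torch
import torch.nn as nn

# Same rule as the Trainer's accelerator argument: GPU if present, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(784, 10).to(device)    # placeholder model
batch = torch.randn(64, 784).to(device)  # every batch must be moved as well
output = model(batch)
print(output.device)
```

With Lightning, this per-batch bookkeeping disappears; the Trainer also accepts accelerator="auto", which selects an available accelerator on its own.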
6. Model fitting
Finally, the model is trained with trainer.fit, which takes the model along with train_loader and val_loader. Lightning then runs the training loop and, at the end of each epoch, a validation pass over val_loader.
# Train the model
trainer.fit(model, train_loader, val_loader)
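Roughly speaking, trainer.fit takes over the loop that the plain PyTorch version wrote by hand: for each epoch it runs the training step on every batch, calls backward and the optimizer step, then evaluates the validation loader with gradients disabled. The sketch below is a heavily simplified, hypothetical simple_fit function in plain PyTorch that mimics only that control flow; it omits everything else Lightning handles (logging, checkpointing, device placement, callbacks, sanity checks) and uses a small synthetic dataset in place of MNIST.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def simple_fit(model, optimizer, train_loader, val_loader, max_epochs):
    """Hypothetical, heavily simplified stand-in for trainer.fit."""
    criterion = nn.CrossEntropyLoss()
    val_history = []
    for _ in range(max_epochs):
        model.train()
        for x, y in train_loader:                 # the "training_step" part
            loss = criterion(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():                     # the "validation_step" part
            losses = [criterion(model(x), y).item() for x, y in val_loader]
        val_history.append(sum(losses) / len(losses))
    return val_history


# Synthetic 3-class problem that a linear model can learn
torch.manual_seed(0)
true_w = torch.randn(20, 3)
x_all = torch.randn(512, 20)
y_all = (x_all @ true_w).argmax(dim=1)
train_loader = DataLoader(TensorDataset(x_all[:384], y_all[:384]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x_all[384:], y_all[384:]), batch_size=32)

model = nn.Linear(20, 3)
history = simple_fit(model, torch.optim.Adam(model.parameters(), lr=0.01),
                     train_loader, val_loader, max_epochs=10)
print(history)  # mean validation loss per epoch; should trend downward
```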
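7. Validating the model
trainer.validate runs one pass over val_loader and returns a list with one dictionary holding the epoch-level metrics logged in validation_step, here val_loss and accuracy.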
trainer.validate(model, val_loader)
Output:
[{'val_loss': 1.4881926774978638, 'accuracy': 0.9729999899864197}]
PyTorch vs PyTorch Lightning
PyTorch Lightning, created by William Falcon and now maintained by Lightning AI, was introduced to reduce the boilerplate of hand-written PyTorch training loops and to provide a more organized and standardized approach. This article covers the major differences between PyTorch and PyTorch Lightning.