Saving and Loading Model

Method 1: Using torch.save() and torch.load()

The following code shows how to save and load the model with the built-in torch.save() and torch.load() functions. torch.save() serializes the entire model object to a file, and torch.load() reads it back into memory. Because the whole object is pickled, the SimpleCNN class definition must be available when the file is loaded.

Python




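# Save the entire model object (not just its state_dict)
torch.save(cnn_model, 'cnn_model.pth')
 
# Load the entire model; the SimpleCNN class must be defined or importable here.
# Recent PyTorch releases need weights_only=False to unpickle a full model;
# use it only with files you trust.
loaded_model = torch.load('cnn_model.pth', weights_only=False)
 
# Set the model to evaluation mode
loaded_model.eval()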


Output:

SimpleCNN(
  (conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
  (fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
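
To make the whole-object approach concrete, here is a minimal, self-contained sketch. It is not the article's SimpleCNN: `tiny_model` is a hypothetical stand-in built only from torch.nn layers, and the file lives in a temporary directory. It also assumes a PyTorch version in which torch.load() accepts the `weights_only` argument.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for a trained model, built only from torch.nn classes
tiny_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_model.pth")

    # Pickle the entire module object (class references plus weights)
    torch.save(tiny_model, path)

    # weights_only=False allows arbitrary objects to be unpickled,
    # so it should only be used with files from a trusted source
    restored_model = torch.load(path, weights_only=False)

restored_model.eval()

# Both models should produce the same output for the same input
sample = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(tiny_model(sample), restored_model(sample))
print(same_output)
```

Saving the whole object is convenient, but the file depends on the exact class and module layout that existed at save time, which is one reason the state_dict approach is often preferred.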

Method 2: Using model.state_dict()

Now, let us see another way to save and load the model, using the state_dict() method. A state_dict is a dictionary that maps each layer's parameter names to their tensors, so only the learned parameters (and buffers) are written to disk, not the model class itself. To load the model, first create a new model with the same architecture, then copy the stored parameters into it with load_state_dict(). Since only the parameters are stored, the saved file is compact and does not depend on how the class was pickled. The following code snippet illustrates this method.

Python




# Saving the model
torch.save(cnn_model.state_dict(), 'cnn_model.pth')
 
# Loading the model
loaded_model = SimpleCNN()
loaded_model.load_state_dict(torch.load('cnn_model.pth'))
print(loaded_model)


Output:

SimpleCNN(
  (conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
  (fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
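
To see what a state_dict actually contains, and to check that a round trip preserves the weights, consider the sketch below. `TinyNet` is a hypothetical one-layer model used only for illustration, and an in-memory buffer stands in for the .pth file. The `map_location="cpu"` argument is optional here; it is shown because it is a common way to load weights that were saved on a GPU onto a CPU-only machine.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical one-layer model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


source_model = TinyNet()

# A state_dict maps parameter names to tensors
for name, tensor in source_model.state_dict().items():
    print(name, tuple(tensor.shape))
# fc.weight (2, 4)
# fc.bias (2,)

# Save only the parameters (an in-memory buffer stands in for a file)
buffer = io.BytesIO()
torch.save(source_model.state_dict(), buffer)
buffer.seek(0)

# Build a fresh model with the same architecture, then copy the weights in
target_model = TinyNet()
state = torch.load(buffer, map_location="cpu")
load_result = target_model.load_state_dict(state)

# With the default strict=True, both key lists are empty on success
print(load_result.missing_keys, load_result.unexpected_keys)

weights_match = all(
    torch.equal(source_model.state_dict()[key], target_model.state_dict()[key])
    for key in source_model.state_dict()
)
print(weights_match)
```

If the architectures differ, load_state_dict() raises an error by default, which catches mismatched checkpoints early.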

Method 3: Saving and Loading Using Checkpoints

A checkpoint is a dictionary that bundles everything needed to resume training, such as the model state_dict, the optimizer state_dict, the current epoch, and the latest loss. To restore the model, load the checkpoint file and read each entry back into the corresponding object. This method is demonstrated below:

Python




# Saving the model
checkpoint = {
    'epoch': epoch,
    'model_state_dict': cnn_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # add any other information you want to keep
}
torch.save(checkpoint, 'checkpoint.pth')
 
# Loading the model
checkpoint = torch.load('checkpoint.pth')
cnn_model = SimpleCNN()
cnn_model.load_state_dict(checkpoint['model_state_dict'])
# The optimizer must be bound to the new model's parameters before its state
# is restored. Adam is an assumption here; use the optimizer class and
# hyperparameters from your own training code.
optimizer = torch.optim.Adam(cnn_model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(cnn_model)


Output:

SimpleCNN(
  (conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
  (fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
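
The main reason to checkpoint is to resume training later. The sketch below shows one way this might look from start to finish. Everything in it is illustrative rather than taken from the article: the linear model, the SGD optimizer with momentum, the random data, and the helper `make_model_and_optimizer` are all assumptions, and an in-memory buffer stands in for checkpoint.pth.

```python
import io

import torch
import torch.nn as nn


def make_model_and_optimizer():
    """Hypothetical helper: builds a small model and its optimizer."""
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


torch.manual_seed(0)
inputs = torch.randn(16, 3)   # made-up training data
targets = torch.randn(16, 1)
loss_fn = nn.MSELoss()

# Train for two epochs (epoch 0 and epoch 1)
model, optimizer = make_model_and_optimizer()
for epoch in range(2):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

# Save a checkpoint (an in-memory buffer stands in for a file)
buffer = io.BytesIO()
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),  # store a plain float rather than a tensor
    },
    buffer,
)
buffer.seek(0)

# Later: rebuild the objects, then restore their state from the checkpoint
resumed_model, resumed_optimizer = make_model_and_optimizer()
checkpoint = torch.load(buffer)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

# Continue training from the next epoch
resumed_model.train()
for epoch in range(start_epoch, start_epoch + 2):
    resumed_optimizer.zero_grad()
    loss = loss_fn(resumed_model(inputs), targets)
    loss.backward()
    resumed_optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```

Restoring the optimizer state matters for optimizers that keep running statistics, such as momentum buffers; without it, training would restart those statistics from zero.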
