Part 6 - Saving and Loading Models


This is from Udacity's Deep Learning Repository which supports their Deep Learning Nanodegree.

In this notebook we're going to look at how to save and load models with PyTorch.

Set Up



from pathlib import Path


from dotenv import load_dotenv
import matplotlib.pyplot as pyplot
import seaborn
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms

Nano Program

from nano.pytorch import helper

This Project

from neurotic.tangles.data_paths import DataPathTwo
from fashion import (
    DropoutModel,
    HyperParameters,
    train,
)


The Data

Once again we're going to use the fashion-MNIST data.

The Path

path = DataPathTwo(folder_key="FASHION_MNIST")

Define a transform to normalize the data

# Fashion-MNIST images have a single channel, so one mean and one standard deviation.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
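
As a quick check on what the normalization does, here is a minimal sketch of the arithmetic that Normalize applies to each channel, (pixel - mean) / std. The sample pixel values are made up for illustration; with a mean and standard deviation of 0.5, values in [0, 1] end up in [-1, 1].

```python
import torch

# Hypothetical pixel intensities, already scaled to [0, 1] as ToTensor does.
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

mean, std = 0.5, 0.5
normalized = (pixels - mean) / std
print(normalized)
```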

Download and Load the Training Data

trainset = datasets.FashionMNIST(path.folder, download=True, train=True,
                                 transform=transform)
training = torch.utils.data.DataLoader(trainset, batch_size=64,
                                       shuffle=True)

Download and Load the Test Data

testset = datasets.FashionMNIST(path.folder, download=True, train=False,
                                transform=transform)
testing = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

Here's one of the images.

image, label = next(iter(training))
helper.imshow(image[0, :])



Training the Network

I'm re-using the DropoutModel from the previous lesson about avoiding over-fitting using dropout. I'm also re-using the (somewhat updated) train function.

model = DropoutModel()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model=model, optimizer=optimizer, criterion=criterion,
      train_batches=training, test_batches=testing, epochs=2)
Epoch: 1/30 Training loss: 2.41 Test Loss: 2.40 Test Accuracy: 0.09
Epoch: 2/30 Training loss: 2.41 Test Loss: 2.40 Test Accuracy: 0.09

Saving and loading networks

Rather than re-training your model every time you want to use it, you can save it and re-load the pre-trained model when you need it.

The parameters for PyTorch networks are stored in a model's state_dict.

print("Our model: \n\n", model, '\n')
print("The state dict keys: \n\n", model.state_dict().keys())
Our model: 

  (input_to_hidden): Linear(in_features=784, out_features=256, bias=True)
  (hidden_1_to_hidden_2): Linear(in_features=256, out_features=128, bias=True)
  (hidden_2_to_hidden_3): Linear(in_features=128, out_features=64, bias=True)
  (hidden_3_to_output): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2)

The state dict keys: 

 odict_keys(['input_to_hidden.weight', 'input_to_hidden.bias', 'hidden_1_to_hidden_2.weight', 'hidden_1_to_hidden_2.bias', 'hidden_2_to_hidden_3.weight', 'hidden_2_to_hidden_3.bias', 'hidden_3_to_output.weight', 'hidden_3_to_output.bias'])
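
To see what the state dict holds beyond the key names, here is a small sketch using a made-up two-layer network (TinyModel is hypothetical, not the DropoutModel used here). Each entry maps a parameter name to a tensor, and the dropout layer contributes nothing because it has no learnable parameters.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """A hypothetical network, much smaller than the one in this post."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 3)
        self.output = nn.Linear(3, 2)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        x = self.dropout(torch.relu(self.hidden(x)))
        return self.output(x)


tiny = TinyModel()

# Map each parameter name to the shape of its tensor.
shapes = {name: tuple(tensor.shape) for name, tensor in tiny.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
```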

The simplest thing to do is to save the state dict with torch.save, which uses Python's pickle to serialize the settings. PyTorch has an explanation for why you would prefer saving the settings instead of the entire model.

As an example, we can save our trained model's settings to a file checkpoint.pth.

file_name = "checkpoint.pth"
torch.save(model.state_dict(), file_name)
check_path = Path(file_name)
print("File Size: {} K".format(check_path.stat().st_size/10**3))
File Size: 972.392 K
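
The file size is roughly what you'd expect from the layer sizes printed earlier. As a rough check, assuming each parameter is stored as a 32-bit float (PyTorch's default for these layers), the weights and biases account for nearly all of the file:

```python
# (in_features, out_features) for each Linear layer in the printed model.
layers = [(784, 256), (256, 128), (128, 64), (64, 10)]

# Each layer has in * out weights plus one bias per output.
parameter_count = sum(n_in * n_out + n_out for n_in, n_out in layers)
parameter_bytes = parameter_count * 4  # assumes 4 bytes per float32 value

file_bytes = 972_392  # the file size reported above
print(parameter_count)               # 242762
print(parameter_bytes)               # 971048
print(file_bytes - parameter_bytes)  # 1344 bytes not taken up by the raw values
```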

I couldn't find an explanation for the file extension, but the PyTorch documentation mentions that it's a convention to use .pt and .pth as extensions. I'm assuming pt is for PyTorch and the h is for hyper-parameters, but I'm not really sure that's the case.

To load the model you can use torch.load.

state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())
odict_keys(['input_to_hidden.weight', 'input_to_hidden.bias', 'hidden_1_to_hidden_2.weight', 'hidden_1_to_hidden_2.bias', 'hidden_2_to_hidden_3.weight', 'hidden_2_to_hidden_3.bias', 'hidden_3_to_output.weight', 'hidden_3_to_output.bias'])

To load the state-dict you take your instantiated but untrained model and call its load_state_dict method.

model.load_state_dict(state_dict)
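
Here is a minimal, self-contained sketch of the whole round trip. It uses a single nn.Linear layer as a hypothetical stand-in for the trained model and an in-memory buffer instead of a file on disk, then checks that the freshly built layer ends up with the saved weights.

```python
import io

import torch
from torch import nn

# A hypothetical stand-in for the trained model.
trained = nn.Linear(4, 2)

# torch.save accepts a file-like object as well as a file name.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# A fresh model with the same architecture starts with its own random weights.
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(buffer))

# After loading, every tensor in the fresh model equals the saved one.
weights_match = all(
    torch.equal(trained.state_dict()[key], fresh.state_dict()[key])
    for key in trained.state_dict()
)
print(weights_match)
```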


Seems pretty straightforward, but as usual it's a bit more complicated. Loading the state dict works only if the model architecture is exactly the same as the checkpoint architecture. If you use a model with a different architecture, loading fails.

parameters = HyperParameters()
parameters.hidden_layer_1 = 400
bad_model = DropoutModel(parameters)
# This will throw an error because the tensor sizes are wrong!
bad_model.load_state_dict(state_dict)
RuntimeError: Error(s) in loading state_dict for DropoutModel:
        size mismatch for input_to_hidden.weight: copying a param of torch.Size([400, 784]) from checkpoint, where the shape is torch.Size([256, 784]) in current model.
        size mismatch for input_to_hidden.bias: copying a param of torch.Size([400]) from checkpoint, where the shape is torch.Size([256]) in current model.
        size mismatch for hidden_1_to_hidden_2.weight: copying a param of torch.Size([128, 400]) from checkpoint, where the shape is torch.Size([128, 256]) in current model.
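
The same failure can be reproduced on a small scale. This sketch uses two stand-alone nn.Linear layers (hypothetical stand-ins, not the DropoutModel) whose output sizes differ, and catches the RuntimeError that load_state_dict raises.

```python
from torch import nn

# Hypothetical layers standing in for the saved model and the altered one.
saved_layer = nn.Linear(784, 256)
altered_layer = nn.Linear(784, 400)

try:
    altered_layer.load_state_dict(saved_layer.state_dict())
    message = ""
except RuntimeError as error:
    message = str(error)

# The error text names each parameter whose shape does not match.
print("size mismatch" in message)
```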

This means we need to rebuild the model exactly as it was when trained. Information about the model architecture needs to be saved in the checkpoint, along with the state dict. To do this, you build a dictionary with all the information you need to completely rebuild the model.

Originally the bad-model was just called 'model' and that seems to have messed up the state-dict so I'm going to re-use the one we made before.

checkpoint = {'hyperparameters': HyperParameters,
              'state_dict': state_dict}
torch.save(checkpoint, file_name)

Remember that this is using pickle under the hood so whatever you save has to be pickleable. It probably would be safer to use parameters instead of a settings object like I did, but I didn't know we were going to be doing this.
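
To make the pickling constraint concrete, here is a small sketch using the standard library's pickle module directly. The settings dictionary is made up for illustration; plain values round-trip without trouble, while something like a lambda does not.

```python
import pickle

# Plain values such as ints, floats, strings, lists and dicts pickle fine.
plain_settings = {"hidden_layer_1": 256, "dropout": 0.2}
restored = pickle.loads(pickle.dumps(plain_settings))

# A lambda cannot be pickled, so a checkpoint containing one would fail to save.
try:
    pickle.dumps({"activation": lambda x: x})
    lambda_pickled = True
except (pickle.PicklingError, AttributeError):
    lambda_pickled = False

print(restored == plain_settings, lambda_pickled)
```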

Here's a function to load checkpoint-files.

def load_checkpoint(filepath: str) -> nn.Module:
    """Load the model checkpoint from disk

    Args:
     filepath: path to the saved checkpoint

    Returns:
     model: the model rebuilt from the checkpoint
    """
    checkpoint = torch.load(filepath)
    model = DropoutModel(checkpoint["hyperparameters"])
    model.load_state_dict(checkpoint["state_dict"])
    return model

You can see from the function that the checkpoint is really just pickling a dictionary, and we can add any arbitrary things we want to it. I'm not really sure what it gives that using pickle directly doesn't have.
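
As a sketch of how arbitrary entries can ride along in the checkpoint, the example below stores plain layer sizes and an epoch count next to the state dict, then rebuilds the network from them. The build_network helper, the key names and the sizes are all hypothetical, and an in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
from torch import nn


def build_network(input_size: int, hidden_size: int, output_size: int) -> nn.Module:
    """Build a small hypothetical network from plain integers."""
    return nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    )


sizes = {"input_size": 8, "hidden_size": 4, "output_size": 2}
network = build_network(**sizes)

# The checkpoint is an ordinary dictionary; the keys are our own choice.
checkpoint = {"sizes": sizes,
              "epochs_trained": 2,
              "state_dict": network.state_dict()}
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Rebuild the architecture from the saved sizes, then restore the weights.
loaded = torch.load(buffer)
rebuilt = build_network(**loaded["sizes"])
rebuilt.load_state_dict(loaded["state_dict"])
print(loaded["epochs_trained"])
print(rebuilt)
```

Storing plain integers rather than a settings object keeps the checkpoint independent of any class definition, which is the safer option mentioned earlier.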

model = load_checkpoint(file_name)
print(model)
  (input_to_hidden): Linear(in_features=784, out_features=256, bias=True)
  (hidden_1_to_hidden_2): Linear(in_features=256, out_features=128, bias=True)
  (hidden_2_to_hidden_3): Linear(in_features=128, out_features=64, bias=True)
  (hidden_3_to_output): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2)

PyTorch has more about saving and loading models in their documentation, including saving your model to continue training later (you need to save more than the model's settings).
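
As a rough sketch of what "more than the model's settings" can look like, the example below also saves the optimizer's state dict and an epoch counter. The model, the key names and the single training step are made up for illustration, and an in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
from torch import nn, optim

# Hypothetical model and optimizer standing in for the ones used in this post.
small_model = nn.Linear(4, 2)
optimizer = optim.Adam(small_model.parameters(), lr=0.001)

# Take one optimization step so the optimizer has internal state to save.
loss = small_model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()

checkpoint = {"model_state": small_model.state_dict(),
              "optimizer_state": optimizer.state_dict(),
              "epoch": 1}
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Later: rebuild both objects, then restore their saved state.
loaded = torch.load(buffer)
resumed_model = nn.Linear(4, 2)
resumed_model.load_state_dict(loaded["model_state"])
resumed_optimizer = optim.Adam(resumed_model.parameters(), lr=0.001)
resumed_optimizer.load_state_dict(loaded["optimizer_state"])
print(loaded["epoch"])
```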