Managing a PyTorch Training Process with Checkpoints and Early Stopping

A large deep learning model can take a long time to train. You lose a lot of work if the training process is interrupted in the middle. But sometimes, you actually want to interrupt training early because you know continuing would not give you a better model. In this post, you will discover how to control the training loop in PyTorch so that you can resume an interrupted process or stop the training loop early.

After completing this post, you will know:

  • The importance of checkpointing neural network models when training
  • How to checkpoint a model during training and restore it later
  • How to terminate the training loop early with checkpointing

Kick-start your project with my book Deep Learning with PyTorch. It provides self-study tutorials with working code.


Let’s get started.

Managing a PyTorch Training Process with Checkpoints and Early Stopping
Photo by Arron Choi. Some rights reserved.

Overview

This chapter is in two parts; they are:

  • Checkpointing Neural Network Models
  • Checkpointing with Early Stopping

Checkpointing Neural Network Models

A lot of systems have state. If you can save the entire state of a system and restore it later, you can always go back to a particular point in the system’s behavior. It is the same idea as saving multiple versions of a document in Microsoft Word because you are not sure whether you will want to revert your edits.

The same applies to long-running processes. Application checkpointing is a fault tolerance technique: a snapshot of the system’s state is taken so that, if there is a problem, you can resume from that snapshot. The checkpoint may be used directly or as the starting point for a new run, picking up where it left off. When training deep learning models, the checkpoint captures the weights of the model. These weights can be used to make predictions as-is or as the basis for further training.

PyTorch does not provide a dedicated checkpointing function, but it does have functions for retrieving and restoring the weights of a model, so you can implement checkpointing logic with them. Let’s make a checkpoint and a resume function, which simply save the weights from a model and load them back.
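
A minimal sketch of these two functions, using torch.save() and torch.load() on the model’s state dict (the function and argument names are simply convenient choices):

```python
import torch

def checkpoint(model, filename):
    # save only the model weights (its state dict) to a file
    torch.save(model.state_dict(), filename)

def resume(model, filename):
    # load the saved weights back into a model of the same architecture
    model.load_state_dict(torch.load(filename))
```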

Below is how you would usually train a PyTorch model. The dataset used is fetched from the OpenML platform; it is a binary classification dataset. A PyTorch DataLoader is used in this example to make the training loop more concise.
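
The sketch below shows one such loop. The specific OpenML dataset used here ("electricity"), the preprocessing, the network size, and the hyperparameters are illustrative assumptions rather than a prescribed setup; any numeric binary classification dataset would work the same way.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import fetch_openml
from torch.utils.data import DataLoader, TensorDataset, random_split

# fetch a binary classification dataset from OpenML
data = fetch_openml("electricity", version=1, parser="auto")
X = torch.tensor(data.data.astype("float32").values, dtype=torch.float32)
y = torch.tensor((data.target == "UP").astype("float32").values,
                 dtype=torch.float32).reshape(-1, 1)

# train-test split wrapped in DataLoaders
# (fractional lengths for random_split() need PyTorch 1.13 or later)
dataset = TensorDataset(X, y)
train_set, test_set = random_split(dataset, [0.7, 0.3])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

# a small binary classifier
model = nn.Sequential(
    nn.Linear(X.shape[1], 12),
    nn.ReLU(),
    nn.Linear(12, 1),
    nn.Sigmoid(),
)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 20
for epoch in range(n_epochs):
    # training pass over all batches
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # validation with the test set at the end of each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = model(X_batch)
            correct += ((y_pred > 0.5).float() == y_batch).sum().item()
            total += len(y_batch)
    acc = 100 * correct / total
    print(f"End of epoch {epoch}: accuracy = {acc:.1f}%")
```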

If you want to add checkpoints to the training loop above, you can do it at the end of the outer for-loop, where the model is validated against the test set.
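
For example, reusing the checkpoint() function sketched earlier and writing one epoch-numbered file per epoch (the file name pattern is just a convenient convention):

```python
for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # validation with the test set at the end of each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = model(X_batch)
            correct += ((y_pred > 0.5).float() == y_batch).sum().item()
            total += len(y_batch)
    acc = 100 * correct / total
    print(f"End of epoch {epoch}: accuracy = {acc:.1f}%")
    # checkpoint the model weights at the end of every epoch
    checkpoint(model, f"epoch-{epoch}.pth")
```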

You will see a number of files created in your working directory. This code checkpoints the model from epoch 7, for example, into the file epoch-7.pth. Each of these files is a ZIP archive containing the pickled model weights. Nothing forbids you from checkpointing inside the inner for-loop, but because of the overhead it incurs, it is not a good idea to checkpoint too frequently.

As a fault tolerance technique, you can resume from a particular epoch by adding a few lines of code before the training loop.
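
A sketch of those lines, assuming the epoch-numbered checkpoint files created above and the same resume() function:

```python
start_epoch = 0   # set this to the epoch you want to resume from
if start_epoch > 0:
    # the last completed epoch is start_epoch - 1, so load its checkpoint
    resume(model, f"epoch-{start_epoch - 1}.pth")

for epoch in range(start_epoch, n_epochs):
    ...  # the rest of the training loop is unchanged
```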

That is, if the training loop was interrupted in the middle of epoch 8, so that the last checkpoint is from epoch 7, setting start_epoch = 8 above will do.

Note that if you do so, the random_split() function that generates the training set and test set may give you a different split because of its random nature. If that is a concern for you, you should have a consistent way of creating the datasets (e.g., fix the random seed, or save the split data so you can reuse it).
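
If you take the seed-fixing route, for example, random_split() accepts a generator argument that makes the split reproducible across runs:

```python
# fix the seed used by random_split() so the same split is produced on every run
train_set, test_set = random_split(
    dataset, [0.7, 0.3],
    generator=torch.Generator().manual_seed(42),
)
```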

Sometimes there is state outside of the model that you may want to checkpoint as well. One particular example is the optimizer: in cases like Adam, it keeps dynamically adjusted momentum terms. If you restart your training loop, you may want to restore the optimizer’s momentum as well. It is not difficult to do. The idea is to make your checkpoint() function a bit more elaborate.
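
One possible sketch, saving the model weights and the optimizer state together in a single dict (the dict keys are arbitrary choices):

```python
def checkpoint(model, optimizer, filename):
    # save the model weights and the optimizer state in one file
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, filename)
```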

Correspondingly, change your resume() function to load both pieces of state back.
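
A matching sketch, restoring both the model weights and the optimizer state from the same file:

```python
def resume(model, optimizer, filename):
    # restore the model weights and the optimizer state
    state = torch.load(filename)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
```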

This works because in PyTorch, the torch.save() and torch.load() functions are backed by pickle, so you can use them with a list or dict container.

To put everything together, below is the complete code:
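
(The listing below is a self-contained sketch of the steps described above; the dataset, network size, hyperparameters, and file names are illustrative choices.)

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import fetch_openml
from torch.utils.data import DataLoader, TensorDataset, random_split

def checkpoint(model, optimizer, filename):
    # save model weights and optimizer state together
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, filename)

def resume(model, optimizer, filename):
    # restore model weights and optimizer state
    state = torch.load(filename)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])

# a binary classification dataset from OpenML
data = fetch_openml("electricity", version=1, parser="auto")
X = torch.tensor(data.data.astype("float32").values, dtype=torch.float32)
y = torch.tensor((data.target == "UP").astype("float32").values,
                 dtype=torch.float32).reshape(-1, 1)

dataset = TensorDataset(X, y)
# fixed seed so the split is identical when the script is restarted
train_set, test_set = random_split(dataset, [0.7, 0.3],
                                   generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

model = nn.Sequential(
    nn.Linear(X.shape[1], 12),
    nn.ReLU(),
    nn.Linear(12, 1),
    nn.Sigmoid(),
)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 20
start_epoch = 0   # set to a later epoch to resume from its previous checkpoint
if start_epoch > 0:
    resume(model, optimizer, f"epoch-{start_epoch - 1}.pth")

for epoch in range(start_epoch, n_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # validation with the test set at the end of each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = model(X_batch)
            correct += ((y_pred > 0.5).float() == y_batch).sum().item()
            total += len(y_batch)
    acc = 100 * correct / total
    print(f"End of epoch {epoch}: accuracy = {acc:.1f}%")
    # checkpoint at the end of every epoch
    checkpoint(model, optimizer, f"epoch-{epoch}.pth")
```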

Want to Get Started With Deep Learning with PyTorch?

Take my free email crash course now (with sample code).

Click to sign-up and also get a free PDF Ebook version of the course.

Checkpointing with Early Stopping

Checkpointing is not only for fault tolerance. You can also use it to keep your best model. What counts as the best model is subjective, but using the score from the test set is a sensible method. Say you want to keep only the best model ever found; you can modify the training loop as follows.
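
A sketch of such a loop, using the simple one-file checkpoint() and resume() functions from earlier (the file name best_model.pth is just a convention):

```python
best_accuracy = -1   # best test accuracy seen so far, as a percentage

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # evaluate accuracy on the test set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = model(X_batch)
            correct += ((y_pred > 0.5).float() == y_batch).sum().item()
            total += len(y_batch)
    acc = 100 * correct / total
    print(f"End of epoch {epoch}: accuracy = {acc:.1f}%")
    if acc > best_accuracy:
        # a new best model: keep a copy of its weights
        best_accuracy = acc
        checkpoint(model, "best_model.pth")

# restore the best model after training finishes
resume(model, "best_model.pth")
```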

The variable best_accuracy keeps track of the highest accuracy (acc) obtained so far, which is a percentage in the range 0 to 100. Whenever a higher accuracy is observed, the model is checkpointed to the file best_model.pth. The best model is restored after the entire training loop, via the resume() function you created before. Afterward, you can make predictions with the model on unseen data. Beware that if you are using a different metric for checkpointing, e.g., the cross entropy loss, a better model comes with a lower cross entropy. Thus you should keep track of the lowest cross entropy obtained.

You can also checkpoint the model every epoch unconditionally, in addition to the best-model checkpoint, since you are free to create multiple checkpoint files. Because the code above finds the best model and keeps a copy of it, a common further optimization of the training loop is to stop it early when the hope of seeing further improvement is slim. This is the early stopping technique, and it can save training time.

The code above validates the model against the test set at the end of each epoch and saves the best model found to a checkpoint file. The simplest strategy for early stopping is to set a threshold of $k$ epochs: if the model has not improved over the last $k$ epochs, you terminate the training loop in the middle. This can be implemented as follows.
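
One possible implementation, extending the best-model loop above; the variable names and the threshold of 5 epochs are illustrative:

```python
early_stop_thresh = 5   # stop if no improvement for this many epochs
best_accuracy = -1
best_epoch = -1

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # evaluate accuracy on the test set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = model(X_batch)
            correct += ((y_pred > 0.5).float() == y_batch).sum().item()
            total += len(y_batch)
    acc = 100 * correct / total
    if acc > best_accuracy:
        # new best model: remember its epoch and keep a copy of the weights
        best_accuracy = acc
        best_epoch = epoch
        checkpoint(model, "best_model.pth")
    elif epoch - best_epoch > early_stop_thresh:
        # no improvement for too long: terminate the outer loop
        print(f"Early stopped training at epoch {epoch}")
        break

# restore the best model found during training
resume(model, "best_model.pth")
```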

The threshold early_stop_thresh is set to 5 above. The variable best_epoch remembers the epoch of the best model. If the model has not improved for long enough, the outer for-loop is terminated.

This design relaxes one of the design parameters, n_epochs. You can now treat n_epochs as the maximum number of epochs to train the model, i.e., a larger number than needed, and be assured that the training loop will usually stop earlier. This is also a strategy to avoid overfitting: if the model indeed performs worse on the test set as you train it further, this early stopping logic will interrupt training and restore the best checkpoint.

Tying everything together, the following is the complete code for checkpointing with early stopping:
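
(Again, the listing below is a self-contained sketch following the steps above; the dataset, network size, and hyperparameters are illustrative, and n_epochs is deliberately set very high because early stopping is expected to end training long before it is reached.)

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import fetch_openml
from torch.utils.data import DataLoader, TensorDataset, random_split

def checkpoint(model, filename):
    # save only the model weights
    torch.save(model.state_dict(), filename)

def resume(model, filename):
    # load saved weights back into the model
    model.load_state_dict(torch.load(filename))

# a binary classification dataset from OpenML
data = fetch_openml("electricity", version=1, parser="auto")
X = torch.tensor(data.data.astype("float32").values, dtype=torch.float32)
y = torch.tensor((data.target == "UP").astype("float32").values,
                 dtype=torch.float32).reshape(-1, 1)

dataset = TensorDataset(X, y)
train_set, test_set = random_split(dataset, [0.7, 0.3],
                                   generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

model = nn.Sequential(
    nn.Linear(X.shape[1], 12),
    nn.ReLU(),
    nn.Linear(12, 1),
    nn.Sigmoid(),
)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 10000          # maximum number of epochs; early stopping ends much sooner
early_stop_thresh = 5     # stop if no improvement for this many epochs
best_accuracy = -1
best_epoch = -1

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # validate with the test set at the end of each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            y_pred = model(X_batch)
            correct += ((y_pred > 0.5).float() == y_batch).sum().item()
            total += len(y_batch)
    acc = 100 * correct / total
    print(f"End of epoch {epoch}: accuracy = {acc:.1f}%")
    if acc > best_accuracy:
        best_accuracy = acc
        best_epoch = epoch
        checkpoint(model, "best_model.pth")
    elif epoch - best_epoch > early_stop_thresh:
        print(f"Early stopped training at epoch {epoch}")
        break

# restore the best model found during training
resume(model, "best_model.pth")
```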

Running the above code, the training stopped at the end of epoch 17, with the best model obtained from epoch 11. Due to the stochastic nature of the algorithm, you may see a slightly different result. But even with the maximum number of epochs set to 10,000 above, the training loop indeed stopped much earlier.

Of course, you can design a more sophisticated early stopping strategy, e.g., run for at least $N$ epochs and only then allow early stopping after $k$ epochs without improvement. You have all the freedom to tweak the code above to build the training loop that best fits your needs.

Summary

In this chapter, you discovered the importance of checkpointing deep learning models for long training runs. You learned:

  • What checkpointing is and why it is useful
  • How to checkpoint your model and how to restore the checkpoint
  • Different strategies to use checkpoints
  • How to implement early stopping with checkpointing

Get Started on Deep Learning with PyTorch!

Deep Learning with PyTorch

Learn how to build deep learning models

...using the newly released PyTorch 2.0 library

Discover how in my new Ebook:
Deep Learning with PyTorch

It provides self-study tutorials with hundreds of working code examples to turn you from a novice into an expert. It equips you with tensor operations, training, evaluation, hyperparameter optimization, and much more...

Kick-start your deep learning journey with hands-on exercises


See What's Inside
