
PyTorch Lightning Hyperparameter Optimization with Optuna
PyTorch Lightning is a high-level library built on top of PyTorch that simplifies the process of training, validating, and deploying deep learning models. When it comes to hyperparameter optimization, that is, the process of finding the set of hyperparameters that maximizes a model's performance on a given task, Optuna is a great companion to PyTorch Lightning: it integrates seamlessly and provides efficient search algorithms for finding the best configuration among a vast number of possibilities.
This article shows how to use PyTorch Lightning and Optuna together to guide the hyperparameter optimization process for a deep learning model. A basic knowledge of building and training neural networks, ideally with PyTorch, is recommended.
Step-by-Step Process
The process starts by installing and importing a series of necessary libraries and modules, including PyTorch Lightning and Optuna. The initial installation process may take some time to complete.
```bash
pip install pytorch_lightning
pip install optuna
pip install "optuna-integration[pytorch_lightning]"
```
Now, the many imports:
```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

import optuna
from optuna.integration import PyTorchLightningPruningCallback
```
When building neural network models with PyTorch Lightning, it is a common practice to set a random seed for reproducibility. You can do this by adding pl.seed_everything(42) at the start of your code, right after the imports.
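To see why seeding matters, here is a tiny, self-contained sketch using torch.manual_seed, one of the seeds that pl.seed_everything sets under the hood: re-seeding with the same value reproduces the same random draws.

```python
import torch

# Seed, then draw a random tensor
torch.manual_seed(42)
a = torch.randn(3)

# Re-seed with the same value and draw again
torch.manual_seed(42)
b = torch.randn(3)

print(torch.equal(a, b))  # True: both draws are identical
```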
Next, we define our neural network model architecture by creating a class that inherits from pl.LightningModule, Lightning's counterpart to PyTorch's nn.Module class.
```python
class MNISTClassifier(pl.LightningModule):
    def __init__(self, layer_1_size=128, layer_2_size=256, learning_rate=1e-3, dropout_rate=0.5):
        super().__init__()
        self.save_hyperparameters()

        # Neural network architecture
        self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_size)
        self.layer_2 = nn.Linear(self.hparams.layer_1_size, self.hparams.layer_2_size)
        self.layer_3 = nn.Linear(self.hparams.layer_2_size, 10)
        self.dropout = nn.Dropout(self.hparams.dropout_rate)

    def forward(self, x):
        # Flatten layer
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)

        # Forward pass
        x = F.relu(self.layer_1(x))
        x = self.dropout(x)
        x = F.relu(self.layer_2(x))
        x = self.dropout(x)
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return {'test_loss': loss, 'test_acc': acc}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
```
We named our class MNISTClassifier because it is a simple feed-forward neural network classifier that we will train on the MNIST dataset for low-resolution image classification. It consists of a few linear layers with ReLU activations in between, preceded by a flattening step that turns each two-dimensional image into a vector. The class defines the forward() method for forward passes, along with methods implementing the training, validation, and test steps.
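As a quick sanity check, the flatten-then-linear-layers structure can be sketched with plain PyTorch modules, using the default layer sizes of 128 and 256 assumed above: a batch of 28x28 images comes out as one log-probability vector of length 10 per image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer stack as MNISTClassifier with its default sizes
layer_1 = nn.Linear(28 * 28, 128)
layer_2 = nn.Linear(128, 256)
layer_3 = nn.Linear(256, 10)

x = torch.randn(4, 1, 28, 28)  # a fake batch of 4 grayscale images
x = x.view(x.size(0), -1)      # flatten to shape (4, 784)
x = F.relu(layer_1(x))
x = F.relu(layer_2(x))
logits = F.log_softmax(layer_3(x), dim=1)

print(logits.shape)  # torch.Size([4, 10])
```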
Outside the newly defined class, the following helper function calculates the mean accuracy over a batch of predictions:
```python
def accuracy(preds, y):
    return (preds == y).float().mean()
```
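For example, with hypothetical predictions where two of four labels match:

```python
import torch

def accuracy(preds, y):
    return (preds == y).float().mean()

preds = torch.tensor([1, 0, 2, 1])
y = torch.tensor([1, 1, 2, 0])

print(accuracy(preds, y))  # tensor(0.5000), since 2 of 4 predictions match
```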
The following function establishes a data preparation pipeline that we will later apply to the MNIST dataset. It converts the dataset to tensors, normalizes it based on a priori known statistics of this dataset, downloads and splits the training data into training and validation sets, and creates and returns one DataLoader object for each of the three data subsets.
```python
def prepare_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # Mean and stdev of MNIST dataset
    ])

    # Download the dataset
    mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)

    # Split original training data into training and validation sets
    mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

    # Create DataLoaders for PyTorch data management
    train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
    val_loader = DataLoader(mnist_val, batch_size=64)
    test_loader = DataLoader(mnist_test, batch_size=64)

    # Return the loaders so that objective() and test_best_model() can use them
    return train_loader, val_loader, test_loader
```
The objective() function is the core element Optuna requires for defining a hyperparameter optimization framework. We define it by first defining a search space: the hyperparameters, and the range of values to try for each. These can be architectural hyperparameters, such as the sizes of the neural network's layers, or hyperparameters related to training and regularization, like the learning rate and the dropout rate, which helps fight overfitting.
Inside this function, we initialize and train a model for each hyperparameter setting the trial suggests, with a callback included for early stopping if the model stabilizes early. Unlike in standard PyTorch, where you write the training loop yourself, PyTorch Lightning provides a Trainer class that handles the training process.
```python
# Define the objective function for Optuna
def objective(trial):
    # Set the hyperparameters to optimize
    layer_1_size = trial.suggest_int('layer_1_size', 64, 256)
    layer_2_size = trial.suggest_int('layer_2_size', 128, 512)
    learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True)
    dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.7)

    # Create the model with trial hyperparameters
    model = MNISTClassifier(
        layer_1_size=layer_1_size,
        layer_2_size=layer_2_size,
        learning_rate=learning_rate,
        dropout_rate=dropout_rate
    )

    # Early stopping callback
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=False,
        mode='min'
    )

    # Optuna pruning callback
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor='val_loss')

    # Logger
    logger = TensorBoardLogger(save_dir=os.getcwd(), name=f"optuna_logs/trial_{trial.number}")

    # Create trainer
    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[early_stop_callback, pruning_callback],
        logger=logger,
        enable_progress_bar=False,
        enable_model_summary=False
    )

    # Preparing the data
    train_loader, val_loader, test_loader = prepare_data()

    # Training the model
    trainer.fit(model, train_loader, val_loader)

    # Final validation loss
    return trainer.callback_metrics['val_loss'].item()
We also define a run_optimization() function to govern the optimization process: it creates an Optuna study with a pruner attached, indicates the number of trials to run based on the specifications in the objective function above, and prints the best result found.
```python
def run_optimization(n_trials=20):
    pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
    study = optuna.create_study(direction='minimize', pruner=pruner)
    study.optimize(objective, n_trials=n_trials)

    print("Best trial:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    return study
```
Once the hyperparameter optimization process has completed and the best model configuration is identified, another function is needed to take those results, retrain the model, and evaluate its performance on the test set as a final evaluation.
```python
def test_best_model(study):
    # Getting the best hyperparameters
    best_params = study.best_trial.params

    # Creating the model with the best hyperparameters
    model = MNISTClassifier(
        layer_1_size=best_params['layer_1_size'],
        layer_2_size=best_params['layer_2_size'],
        learning_rate=best_params['learning_rate'],
        dropout_rate=best_params['dropout_rate']
    )

    # Creating trainer instance
    trainer = pl.Trainer(max_epochs=10)

    # Preparing the data
    train_loader, val_loader, test_loader = prepare_data()

    # Training the model with the best hyperparameters
    trainer.fit(model, train_loader, val_loader)

    # Testing the model with the test data
    results = trainer.test(model, test_loader)

    return results
```
Now that we have all the classes and functions we need, we finalize with a demo that puts it all together.
```python
study = run_optimization(n_trials=5)

# Visualize the results (each plot_* call returns a plotly figure)
try:
    # Plot optimization history
    optuna.visualization.plot_optimization_history(study).show()

    # Plot parameter importances
    optuna.visualization.plot_param_importances(study).show()

    # Plot parallel coordinate plot
    optuna.visualization.plot_parallel_coordinate(study).show()
except ImportError:
    print("Visualization requires plotly. Install with: pip install plotly")

# Test the best model
results = test_best_model(study)
print(f"Test results with best hyperparameters: {results}")
```
Here is a TL;DR breakdown of the workflow:
- Run a study (set of Optuna experiments) specifying the number of trials
- Visualize the optimization process: its history, parameter importances, and parameter interactions
- Once the best configuration is found, retrain the model and evaluate it on the test data
Wrapping Up
This article illustrated how to use PyTorch Lightning and Optuna together to perform efficient and effective hyperparameter optimization for neural network models. Optuna provides efficient search and pruning algorithms for model tuning, while PyTorch Lightning builds on top of PyTorch to simplify neural network modeling at a higher level.





