How to Use Test-Time Augmentation to Make Better Predictions

Data augmentation is a technique often used to improve performance and reduce generalization error when training neural network models for computer vision problems.

The image data augmentation technique can also be applied when making predictions with a fit model in order to allow the model to make predictions for multiple different versions of each image in the test dataset. The predictions on the augmented images can be averaged, which can result in better predictive performance.

In this tutorial, you will discover test-time augmentation for improving the performance of models for image classification tasks.

After completing this tutorial, you will know:

  • Test-time augmentation is the application, at prediction time, of data augmentation techniques normally used during training.
  • How to implement test-time augmentation from scratch in Keras.
  • How to use test-time augmentation to improve the performance of a convolutional neural network model on a standard image classification task.

Let’s get started.

How to Use Test-Time Augmentation to Improve Model Performance for Image Classification
Photo by daveynin, some rights reserved.

Tutorial Overview

This tutorial is divided into five parts; they are:

  1. Test-Time Augmentation
  2. Test-Time Augmentation in Keras
  3. Dataset and Baseline Model
  4. Example of Test-Time Augmentation
  5. How to Tune Test-Time Augmentation Configuration

Test-Time Augmentation

Data augmentation is an approach typically used during the training of the model that expands the training set with modified copies of samples from the training dataset.

Data augmentation is often performed with image data, where copies of images in the training dataset are created with some image manipulation techniques performed, such as zooms, flips, shifts, and more.

The artificially expanded training dataset can result in a more skillful model, as often the performance of deep learning models continues to scale in concert with the size of the training dataset. In addition, the modified or augmented versions of the images in the training dataset assist the model in extracting and learning features in a way that is invariant to their position, lighting, and more.

Test-time augmentation, or TTA for short, is an application of data augmentation to the test dataset.

Specifically, it involves creating multiple augmented copies of each image in the test set, having the model make a prediction for each, then returning an ensemble of those predictions.

Augmentations are chosen to give the model the best opportunity to classify a given image correctly. The number of copies of an image for which a model must make a prediction is often small, such as fewer than 10 or 20.

Often, a single simple test-time augmentation is performed, such as a shift, crop, or image flip.

In their 2015 paper titled “Very Deep Convolutional Networks for Large-Scale Image Recognition,” which achieved then state-of-the-art results on the ILSVRC dataset, the authors use horizontal flip test-time augmentation:

We also augment the test set by horizontal flipping of the images; the soft-max class posteriors of the original and flipped images are averaged to obtain the final scores for the image.
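The averaging described in this quote can be sketched in a few lines of NumPy. This is a minimal illustration, not the VGG authors' code. It assumes a batch of images shaped [samples][rows][cols][channels] and a `predict_fn` callable that returns softmax probabilities, such as the predict method of a fitted Keras model. The name flip_tta_predict is hypothetical.

```python
import numpy as np


def flip_tta_predict(predict_fn, images):
    """Average class probabilities for original and horizontally flipped images.

    images: batch shaped [samples, rows, cols, channels].
    predict_fn: hypothetical callable returning probabilities shaped
        [samples, classes], e.g. a fitted Keras model's predict method.
    """
    # Reversing the column axis (axis 2) mirrors each image left-to-right.
    flipped = images[:, :, ::-1, :]
    probs_original = predict_fn(images)
    probs_flipped = predict_fn(flipped)
    # Average the two sets of posteriors to get the final scores per image.
    return (probs_original + probs_flipped) / 2.0
```

With a Keras model, the call would look like flip_tta_predict(model.predict, testX).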

Similarly, in their 2015 paper on the inception architecture titled “Rethinking the Inception Architecture for Computer Vision,” the authors at Google use cropping test-time augmentation, which they refer to as multi-crop evaluation.
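One common multi-crop scheme takes the four corner crops and the center crop of an image, then averages the model's predictions over them. The sketch below assumes that scheme and a model whose input size equals the crop size. It illustrates the idea only and does not reproduce the exact crop configuration used in the Inception paper. The names five_crops and multi_crop_predict are hypothetical.

```python
import numpy as np


def five_crops(image, size):
    """Return the four corner crops and the center crop of one image.

    image: array shaped [rows, cols, channels]; size: side of the square crop.
    """
    rows, cols = image.shape[0], image.shape[1]
    if size > rows or size > cols:
        raise ValueError("crop size must not exceed the image size")
    top, left = (rows - size) // 2, (cols - size) // 2
    crops = [
        image[:size, :size],                       # top-left
        image[:size, cols - size:],                # top-right
        image[rows - size:, :size],                # bottom-left
        image[rows - size:, cols - size:],         # bottom-right
        image[top:top + size, left:left + size],   # center
    ]
    return np.stack(crops)


def multi_crop_predict(predict_fn, image, size):
    """Average class probabilities over the five crops of one image.

    predict_fn is a hypothetical callable mapping a batch to probabilities.
    """
    probs = predict_fn(five_crops(image, size))
    return probs.mean(axis=0)
```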

Test-Time Augmentation in Keras

Test-time augmentation is not provided natively in the Keras deep learning library but can be implemented easily.

The ImageDataGenerator class can be used to configure the choice of test-time augmentation. For example, a data generator created with horizontal_flip=True performs horizontal flip image data augmentation.
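In Keras, ImageDataGenerator(horizontal_flip=True) mirrors each image left-to-right with a probability of 0.5. The NumPy stand-in below shows what that augmentation does without a Keras dependency. The function name random_horizontal_flip is hypothetical, and images are assumed to be laid out as [samples][rows][cols][channels].

```python
import numpy as np


def random_horizontal_flip(images, rng):
    """Flip each image left-to-right with probability 0.5.

    images: batch shaped [samples, rows, cols, channels].
    rng: a numpy.random.Generator, e.g. np.random.default_rng(1).
    """
    out = images.copy()
    # Decide independently for every image whether to flip it.
    flip_mask = rng.random(len(images)) < 0.5
    # Reverse the column axis only for the selected images.
    out[flip_mask] = out[flip_mask][:, :, ::-1, :]
    return out
```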

The augmentation can then be applied to each sample in the test dataset separately.

First, the dimensions of the single image can be expanded from [rows][cols][channels] to [samples][rows][cols][channels], where the number of samples is one, for the single image. This transforms the array for the image into an array of samples with one image.
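The reshaping step is a single NumPy call. Here is a small sketch, assuming a 32×32 color image like those in CIFAR-10:

```python
import numpy as np

# A single (placeholder) 32x32 RGB image laid out as [rows][cols][channels].
image = np.zeros((32, 32, 3), dtype="float32")

# Add a leading samples axis so the model receives a batch holding one image.
samples = np.expand_dims(image, 0)
print(samples.shape)  # (1, 32, 32, 3)
```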

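Next, an iterator can be created for the sample, for example with the generator's flow() function. The goal is to generate a small number of augmented images, such as 10. Because the array holds only one image, however, each batch drawn from the iterator contains a single augmented copy, whatever the batch_size argument. Producing 10 augmented images therefore means drawing 10 batches, for example by calling next() on the iterator in a loop.

The augmented images can then be given to the model to make predictions. Older Keras versions provide predict_generator(), which consumes the iterator directly; in TensorFlow 2 it is deprecated in favor of predict(). The drawn images can also be stacked into one array and passed to predict(). Either way, the model makes a prediction for each augmented image.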
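A single-image iterator yields one augmented copy per step, so one way to make the mechanics explicit is to draw the copies in a loop. The sketch below is a framework-free illustration of that idea. augmented_copies is a hypothetical name, and augment is assumed to be any function that maps a (batch, random generator) pair to an augmented batch of the same shape.

```python
import numpy as np


def augmented_copies(image, n_copies, augment, rng):
    """Return n_copies independently augmented versions of one image.

    image: array shaped [rows, cols, channels].
    augment: hypothetical function (batch, rng) -> augmented batch.
    Each copy comes from its own one-image batch, mirroring how a Keras
    iterator over a single sample yields one augmented image per step.
    """
    samples = np.expand_dims(image, 0)
    copies = [augment(samples, rng)[0] for _ in range(n_copies)]
    return np.stack(copies)
```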

Finally, an ensemble prediction can be made. In multiclass image classification, the model returns one prediction per augmented image, and each prediction holds the probability of the image belonging to each class.

An ensemble prediction can be made using soft voting. The probabilities of each class are summed across the predictions, and a class prediction is made by calculating the argmax() of the summed predictions, which returns the index (class number) of the largest summed probability.
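Soft voting reduces to a sum and an argmax. The minimal sketch below uses made-up probabilities for three augmented copies of one image. The column sums are roughly 0.9, 1.5, and 0.6, so class 1 wins. Averaging instead of summing would pick the same class, because dividing by the number of copies does not change the argmax.

```python
import numpy as np


def soft_vote(yhats):
    """Sum class probabilities over the augmented copies; return the top class index."""
    summed = np.sum(yhats, axis=0)
    return int(np.argmax(summed))


# Made-up probabilities: one row per augmented copy, one column per class.
yhats = np.array([
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
])
print(soft_vote(yhats))  # 1
```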

We can tie these elements together into a function that will take a configured data generator, fit model, and single image, and will return a class prediction (integer) using test-time augmentation.
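One possible form for such a function is sketched below. It differs from a Keras-specific version in two assumptions. First, the model is passed as a predict_fn callable (for a Keras model, model.predict). Second, the augmentation is a hypothetical augment(batch, rng) function rather than a configured ImageDataGenerator, so the logic can run without Keras.

```python
import numpy as np


def tta_prediction(predict_fn, augment, image, n_examples, rng):
    """Predict a class for one image with test-time augmentation.

    predict_fn: maps a [samples, rows, cols, channels] batch to class
        probabilities, e.g. the predict method of a fitted Keras model.
    augment: hypothetical helper mapping (batch, rng) to an augmented batch.
    Returns the integer class with the largest summed probability.
    """
    samples = np.expand_dims(image, 0)
    # Build n_examples augmented copies of the single image.
    batch = np.concatenate([augment(samples, rng) for _ in range(n_examples)])
    yhats = predict_fn(batch)
    # Soft voting: sum probabilities across copies, then take the argmax.
    return int(np.argmax(np.sum(yhats, axis=0)))
```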

Now that we know how to make predictions in Keras using test-time augmentation, let’s work through an example to demonstrate the approach.

Dataset and Baseline Model

We can demonstrate test-time augmentation using a standard computer vision dataset and a convolutional neural network.

Before we can do that, we must select a dataset and a baseline model.

We will use the CIFAR-10 dataset, which consists of 60,000 32×32 pixel color photographs of objects from 10 classes, such as frogs, birds, cats, ships, etc. CIFAR-10 is a well-understood dataset and widely used for benchmarking computer vision algorithms in the field of machine learning. The problem is “solved.” Top performance on the problem is achieved by deep learning convolutional neural networks with a classification accuracy above 96% or 97% on the test dataset.

We will also use a convolutional neural network, or CNN, model that is capable of achieving good (better than random) results, but not state-of-the-art results, on the problem. This will be sufficient to demonstrate the lift in performance that test-time augmentation can provide.

The CIFAR-10 dataset can be loaded easily via the Keras API by calling the cifar10.load_data() function, which returns a tuple with the training and test datasets split into input (images) and output (class labels) components.

It is good practice to normalize the pixel values from the range 0-255 down to the range 0-1 prior to modeling. This ensures that the inputs are small and close to zero, and will, in turn, mean that the weights of the model will be kept small, leading to faster and better learning.
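Here is a minimal sketch of the scaling step. The helper name prep_pixels is hypothetical, and the inputs are assumed to be uint8 arrays as returned by cifar10.load_data():

```python
import numpy as np


def prep_pixels(train, test):
    """Scale uint8 pixel values from the range 0-255 to floats in the range 0-1."""
    train_norm = train.astype("float32") / 255.0
    test_norm = test.astype("float32") / 255.0
    return train_norm, test_norm
```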

The class labels are integers and must be converted to a one hot encoding prior to modeling.

This can be achieved using the to_categorical() Keras utility function.
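For illustration, the same encoding can be written in NumPy. This sketch should produce the same array as to_categorical() for these inputs, but it is not a replacement for it. CIFAR-10 labels load with shape [samples][1], so the sketch flattens them first. one_hot is a hypothetical name.

```python
import numpy as np


def one_hot(labels, n_classes=10):
    """One-hot encode integer class labels (a NumPy sketch of to_categorical())."""
    labels = np.asarray(labels).reshape(-1)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(n_classes, dtype="float32")[labels]


y = np.array([[6], [9], [0]])
print(one_hot(y).shape)  # (3, 10)
```

Calling np.argmax(..., axis=1) on the result recovers the integer labels. This is how the one-hot encoding is reversed when predictions are compared to labels.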

We are now ready to define a model for this multi-class classification problem.

The model has a convolutional layer with 32 filter maps with a 3×3 kernel using the rectified linear activation, “same” padding so the output is the same size as the input, and the He weight initialization. This is followed by a batch normalization layer and a max pooling layer.

This pattern is repeated with a convolutional, batch norm, and max pooling layer, although the number of filters is increased to 64. The output is then flattened before being interpreted by a dense layer and finally provided to the output layer to make a prediction.
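It can help to work out the size of the flattened feature vector that reaches the dense layer. The sketch below assumes 2×2 max pooling with the default stride of 2; the pooling size is not stated in the description. With “same” padding the convolutions keep the spatial size, so only the pooling layers shrink it. flattened_size is a hypothetical helper.

```python
def flattened_size(rows, cols, filters_per_block, pool=2):
    """Track the feature-map shape through conv ('same' padding) + max-pool blocks."""
    channels = None
    for filters in filters_per_block:
        # 'same' padding keeps rows and cols; the conv sets the channel count.
        channels = filters
        # A 2x2 max pool with stride 2 halves rows and cols, rounding down.
        rows, cols = rows // pool, cols // pool
    return rows, cols, channels, rows * cols * channels


print(flattened_size(32, 32, [32, 64]))  # (8, 8, 64, 4096)
```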

The Adam variation of stochastic gradient descent is used to find the model weights.

The categorical cross-entropy loss function is used, as required for multi-class classification with one-hot encoded labels, and classification accuracy is monitored during training.

The model is fit for three training epochs and a large batch size of 128 images is used.

Once fit, the model is evaluated on the test dataset.

The complete example runs easily on a CPU in a few minutes.

Running the example shows that the model is capable of learning the problem well and quickly.

A test set accuracy of about 66% is achieved, which is okay, but not terrific. The chosen model configuration has already started to overfit and could benefit from the use of regularization and further tuning. Nevertheless, this provides a good starting point for demonstrating test-time augmentation.

Neural networks are stochastic algorithms and the same model fit on the same data multiple times may find a different set of weights and, in turn, have different performance each time.

In order to even out the estimate of model performance, we can change the example to re-run the fit and evaluation of the model multiple times and report the mean and standard deviation of the distribution of scores on the test dataset.

First, we can define a function named load_dataset() that will load the CIFAR-10 dataset and prepare it for modeling.

Next, we can define a function named define_model() that will define a model for the CIFAR-10 dataset, ready to be fit and then evaluated.

Next, an evaluate_model() function is defined that will fit the defined model on the training dataset and then evaluate it on the test dataset, returning the estimated classification accuracy for the run.

Next, we can define a function with new behavior to repeatedly define, fit, and evaluate a new model and return the distribution of accuracy scores.

The repeated_evaluation() function implements this, taking the dataset and using a default of 10 repeated evaluations.

Finally, we can call the load_dataset() function to prepare the dataset, then repeated_evaluation() to get a distribution of accuracy scores that can be summarized by reporting the mean and standard deviation.
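The following is a framework-free sketch of the repeated evaluation loop and its summary. It assumes a variant of the design described here: instead of taking the dataset, repeated_evaluation() takes a hypothetical zero-argument callable, evaluate_once, that defines, fits, and evaluates a fresh model and returns its test accuracy. In the tutorial's setting that callable might call define_model() and then evaluate_model() on the loaded data. The exact signatures of those functions are not shown here, so treat that wiring as an assumption.

```python
import numpy as np


def repeated_evaluation(evaluate_once, n_repeats=10):
    """Call evaluate_once() n_repeats times and collect the accuracy scores."""
    scores = []
    for i in range(n_repeats):
        score = evaluate_once()
        print(f"> repeat {i + 1}: {score:.3f}")
        scores.append(score)
    return scores


def summarize(scores):
    """Return the mean and standard deviation of the collected scores."""
    return float(np.mean(scores)), float(np.std(scores))
```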

Tying all of this together gives the complete code example of repeatedly evaluating a CNN model on the CIFAR-10 dataset.

Running the example may take a while on modern CPU hardware and is much faster on GPU hardware.

The accuracy of the model is reported for each repeated evaluation and the final mean model performance is reported.

In this case, we can see that the mean accuracy of the chosen model configuration is about 68%, which is close to the estimate from a single model run.

Now that we have developed a baseline model for a standard dataset, let’s look at updating the example to use test-time augmentation.

Example of Test-Time Augmentation

We can now update our repeated evaluation of the CNN model on CIFAR-10 to use test-time augmentation.

The tta_prediction() function developed in the section above on how to implement test-time augmentation in Keras can be used directly.

We can develop a function that will drive the test-time augmentation by defining the ImageDataGenerator configuration and calling tta_prediction() for each image in the test dataset.

It is important to consider the types of image augmentations that may benefit a model fit on the CIFAR-10 dataset. Augmentations that cause minor modifications to the photographs might be useful. This might include augmentations such as zooms, shifts, and horizontal flips.

In this example, we will only use horizontal flips.

We will configure the image generator to create seven augmented images, from which the mean prediction for each example in the test set will be made.

The tta_evaluate_model() function configures the ImageDataGenerator, then enumerates the test dataset, making a class label prediction for each image. The accuracy is then calculated by comparing the predicted class labels to the class labels in the test dataset. This requires that we reverse the one-hot encoding performed in load_dataset() by using argmax().
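The evaluation loop can be sketched without Keras as follows. It assumes that predict_fn stands in for a fitted model's predict method, that random_flip is a hypothetical horizontal flip applied to each copy, and that test_y is one-hot encoded. The sketch makes one model call per test image. This helps explain why evaluation with test-time augmentation is noticeably slower than a single batched call.

```python
import numpy as np


def random_flip(batch, rng):
    """Hypothetical augmentation: mirror the batch left-to-right with probability 0.5."""
    return batch[:, :, ::-1, :] if rng.random() < 0.5 else batch


def tta_evaluate(predict_fn, test_x, test_y, n_examples=7, augment=random_flip, seed=1):
    """Classification accuracy when each test image is predicted with TTA.

    predict_fn: maps a [samples, rows, cols, channels] batch to class
        probabilities (e.g. a fitted Keras model's predict method).
    test_y: one-hot encoded labels, reversed here with argmax.
    """
    rng = np.random.default_rng(seed)
    yhats = []
    for image in test_x:
        samples = np.expand_dims(image, 0)
        # n_examples augmented copies of this one image.
        batch = np.concatenate([augment(samples, rng) for _ in range(n_examples)])
        probs = predict_fn(batch)
        # Soft voting over the copies.
        yhats.append(int(np.argmax(np.sum(probs, axis=0))))
    labels = np.argmax(test_y, axis=1)
    return float(np.mean(np.array(yhats) == labels))
```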

The evaluate_model() function can then be updated to call tta_evaluate_model() in order to get model accuracy scores.

Tying all of this together gives the complete example of the repeated evaluation of a CNN for CIFAR-10 with test-time augmentation.

Running the example may take some time given the repeated evaluation and the slower manual test-time augmentation used to evaluate each model.

In this case, we can see a modest lift in performance from about 68.6% on the test set without test-time augmentation to about 69.8% accuracy on the test set with test-time augmentation.

How to Tune Test-Time Augmentation Configuration

Choosing the augmentation configurations that give the biggest lift in model performance can be challenging.

Not only are there many augmentation methods to choose from, each with its own configuration options, but fitting and evaluating a model for a single set of configuration options can take a long time, even on a fast GPU.

Instead, I recommend fitting the model once and saving it to file, for example with the model's save() method.

Then, in a separate script, load the model and evaluate different test-time augmentation schemes on a small validation dataset or a small subset of the test set.
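A small harness for this kind of search might look like the sketch below. It assumes three things: the loaded model is exposed as a predict_fn callable, each candidate augmentation is a hypothetical augment(batch, rng) function, and the validation labels are integers. It scores every combination of augmentation and number of copies so the best one can be picked.

```python
import numpy as np


def compare_tta_configs(predict_fn, augmenters, val_x, val_y, counts=(3, 5, 7), seed=1):
    """Score every (augmentation, number of copies) pair on a small validation set.

    augmenters: dict mapping a name to a hypothetical (batch, rng) -> batch function.
    val_y: integer class labels. Returns {(name, n_copies): accuracy}.
    """
    results = {}
    for name, augment in augmenters.items():
        for n in counts:
            # Reset the generator so every configuration sees the same random draws.
            rng = np.random.default_rng(seed)
            correct = 0
            for image, label in zip(val_x, val_y):
                batch = np.concatenate([augment(image[None], rng) for _ in range(n)])
                correct += int(np.argmax(predict_fn(batch).sum(axis=0)) == label)
            results[(name, n)] = correct / len(val_x)
    return results
```

The best configuration is then max(results, key=results.get).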


Once you find a set of augmentation options that give the biggest lift, you can then evaluate the model on the whole test set or trial a repeated evaluation experiment as above.

Test-time augmentation configuration not only includes the options for the ImageDataGenerator, but also the number of images generated from which the average prediction will be made for each example in the test set.

I used this approach to choose the test-time augmentation in the previous section, discovering that seven examples worked better than three or five, and that random zooming and random shifts appeared to decrease model accuracy.

Remember, if you also use image data augmentation for the training dataset and that augmentation uses a type of pixel scaling that involves calculating statistics on the dataset (e.g. you call datagen.fit()), then those same statistics and pixel scaling techniques must also be used during test-time augmentation.
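For the featurewise options of ImageDataGenerator, fit() computes per-channel statistics on the training images. The NumPy sketch below approximates that behavior to illustrate the key point: the statistics are computed once from the training data, then reused unchanged for every test image and every augmented copy. The function names are hypothetical.

```python
import numpy as np


def fit_featurewise_stats(train_x):
    """Per-channel mean and std computed on the training images only.

    train_x: array shaped [samples, rows, cols, channels].
    """
    mean = train_x.mean(axis=(0, 1, 2), keepdims=True)
    std = train_x.std(axis=(0, 1, 2), keepdims=True)
    return mean, std


def standardize(x, mean, std, eps=1e-6):
    """Apply the training-set statistics to any batch, including TTA copies."""
    return (x - mean) / (std + eps)
```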

Summary

In this tutorial, you discovered test-time augmentation for improving the performance of models for image classification tasks.

Specifically, you learned:

  • Test-time augmentation is the application, at prediction time, of data augmentation techniques normally used during training.
  • How to implement test-time augmentation from scratch in Keras.
  • How to use test-time augmentation to improve the performance of a convolutional neural network model on a standard image classification task.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.

12 Responses to How to Use Test-Time Augmentation to Make Better Predictions

  1. sang-min Park September 30, 2019 at 11:49 am

    Hello. My name is sang-min.

    I read your post with interest.

    I want to try TTA (test-time augmentation) with object detection, but it's not working well.

    In my opinion, the key to TTA is averaging the model's result scores.

    But object detection requires extracting bounding boxes, so I'm not sure how to average the boxes.

    If you have a good solution, please advise me.

    Thank you!!!

    • Jason Brownlee September 30, 2019 at 2:28 pm

      Great question.

      Perhaps check the literature for good ways to use TTA for object detection?

  2. San April 10, 2020 at 4:59 pm

    I have an imbalanced multi-class problem: there are 3 classes with a single instance each, another 3 classes with 2 instances each, and so on. The majority class has 84 instances. After preprocessing, the number of instances in the dataset drops from 339 to 309.

    Whatever method I try, it doesn't give me good performance. I tried RandomOverSampling, but it doesn't give me good performance either. I cannot use SMOTE or any other SMOTE-related techniques because I have a lot of classes with very few samples, and they all give an error. I tried hierarchical classification and many other algorithms, but they all seem to perform quite well on the majority class and don't work for the minority classes at all. The overall performance I get is an F1 score of 0.40 on the test set, and the model seems to overfit a lot.

    I found a dataset created with a Bayesian Network Generator, which creates artificial data similar to an original dataset. The attribute distributions of this artificial dataset look very similar to the original one.

    The problem is related to primary tumor classification. In that case, is it okay if I take a sample from this artificial dataset and combine it with my original dataset for training, testing, or development?

    If this is possible, what is the correct way to combine the artificial instances with my original dataset?


    • Jason Brownlee April 11, 2020 at 6:10 am

      If it results in better results on your test harness and your test harness is robust, then go for it.

  3. Nikhil May 23, 2020 at 7:04 pm

    Hi, how does the TTA code you wrote lead to better training of the model itself? The model is only trained in the model.fit() step, while the other steps just evaluate the model on different augmentations of the data. Could you please explain a bit?

    • Jason Brownlee May 24, 2020 at 6:04 am

      In some cases (some models/some datasets), TTA will provide multiple different attempts at making a prediction on the same dataset, which can result in better performance.

      The model gets to see the input from different perspectives.

  4. SUKH May 26, 2020 at 7:39 pm


    • Jason Brownlee May 27, 2020 at 7:45 am

      It can be.

      I have a tutorial on the topic coming. Stay tuned!

  5. Abhinay Kumar December 22, 2020 at 12:02 am

    Thanks for the tutorial.
    However, I was looking for the generator part of the code, and the code provided above does not produce 10 augmented samples from 1 image. The image data generator loops through the sample only once and creates a batch of size 1 (whatever we set the batch_size and steps variables to). Hence, I had to put the generator in a loop and take the sum over the loop to obtain the ‘summed’ variable.

    • Jason Brownlee December 22, 2020 at 6:48 am

      You can change the code to loop as many times as you like.

  6. Harsha February 13, 2024 at 12:12 am

    Hey, how can I apply this to an image enhancement task? I have reconstructed images and ground-truth images, and I'm using MSE loss and the Adam optimizer with a UNET architecture. I want to perform TTA when dealing with new datasets.

    • James Carmichael February 13, 2024 at 8:47 am

      Hi Harsha… Test-time augmentation (TTA) is a technique commonly used in machine learning to improve model performance during inference by applying various transformations to the input data and averaging the predictions. This can reduce the variance of individual predictions and make them more robust to small changes in the input. Here’s how you can apply test-time augmentation in Python:

      1. **Prepare your model**: First, make sure you have a trained model ready for inference. This could be a neural network model trained using libraries like TensorFlow, PyTorch, or scikit-learn.

      2. **Define augmentation transformations**: Define a set of augmentation transformations you want to apply to your test data. These transformations could include techniques like rotation, flipping, scaling, cropping, or color jittering.

      3. **Perform TTA during inference**: During inference, apply each augmentation transformation to the input data and make predictions using your model. Then, average the predictions to get the final output.

      Here’s a Python code example using the imgaug library for image augmentation and applying TTA to image classification:

      import numpy as np
      import imgaug.augmenters as iaa

      # Define your model
      # Example:
      # from keras.models import load_model
      # model = load_model('path_to_your_model.h5')

      # Define augmentation transformations
      augmentation = iaa.Sequential([
          iaa.Fliplr(0.5),  # horizontally flip 50% of images
          iaa.Affine(rotate=(-20, 20)),  # rotate images by -20 to +20 degrees
          iaa.GaussianBlur(sigma=(0, 3.0)),  # apply gaussian blur with sigma between 0 and 3.0
      ])

      def predict_with_tta(model, images, n_augmentations=5):
          all_predictions = []

          for _ in range(n_augmentations):
              # Draw a new random augmentation of every image and predict on it
              augmented_images = augmentation(images=images)
              predictions = model.predict(augmented_images)
              all_predictions.append(predictions)

          # Average predictions across the augmented copies
          avg_predictions = np.mean(all_predictions, axis=0)
          return avg_predictions

      # Example usage:
      # images = load_test_images()
      # predictions = predict_with_tta(model, images)
      # print(predictions)

      In this example:

      – We use the imgaug library to define a set of augmentation transformations.
      – The predict_with_tta function takes a trained model, input images, and the number of augmentations to apply. It applies augmentation transformations to the input images, makes predictions using the model for each augmented version, and averages the predictions.

      You can adjust the augmentation techniques and parameters based on your specific use case and data characteristics. Additionally, make sure to adapt the code according to the specific requirements and APIs of your model and data.
