Image Augmentation for Deep Learning With Keras

Data preparation is required when working with neural network and deep learning models. Increasingly, data augmentation is also required on more complex object recognition tasks.

In this post you will discover how to use data preparation and data augmentation with your image datasets when developing and evaluating deep learning models in Python with Keras.

After reading this post, you will know:

  • About the image augmentation API provided by Keras and how to use it with your models.
  • How to perform feature standardization.
  • How to perform ZCA whitening of your images.
  • How to augment data with random rotations, shifts and flips.
  • How to save augmented image data to disk.

Let’s get started.

  • Update: The examples in this post were updated for the latest Keras API. The function was removed.
  • Update Oct/2016: Updated examples for Keras 1.1.0, TensorFlow 0.10.0 and scikit-learn v0.18.
  • Update Jan/2017: Updated examples for Keras 1.2.0 and TensorFlow 0.12.1.

Keras Image Augmentation API

Like the rest of Keras, the image augmentation API is simple and powerful.

Keras provides the ImageDataGenerator class that defines the configuration for image data preparation and augmentation. This includes capabilities such as:

  • Sample-wise standardization.
  • Feature-wise standardization.
  • ZCA whitening.
  • Random rotation, shifts, shear and flips.
  • Dimension reordering.
  • Save augmented images to disk.

An augmented image generator can be created as follows:
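A minimal sketch of creating a generator, assuming the tensorflow.keras import path (older standalone installs import from keras.preprocessing.image instead). The argument values are illustrative assumptions, not recommendations.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Configure the augmentation. Nothing is computed until data flows through it.
datagen = ImageDataGenerator(
    rotation_range=15,       # random rotations of up to +/- 15 degrees
    width_shift_range=0.1,   # horizontal shifts of up to 10% of the image width
    horizontal_flip=True,    # random left-right flips
)
```

Calling ImageDataGenerator() with no arguments also works and gives you a generator that applies no transforms.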

Rather than performing the operations on your entire image dataset in memory, the API is designed to be iterated by the deep learning model fitting process, creating augmented image data for you just-in-time. This reduces your memory overhead, but adds some additional time cost during model training.

After you have created and configured your ImageDataGenerator, you must fit it on your data. This will calculate any statistics required to actually perform the transforms to your image data. You can do this by calling the fit() function on the data generator and passing it your training dataset.

The data generator itself is in fact an iterator, returning batches of image samples when requested. We can configure the batch size and prepare the data generator and get batches of images by calling the flow() function.
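The fit and flow steps might look like the sketch below. It assumes the tensorflow.keras import path, that SciPy is installed (the generator uses it when producing batches), and it uses random arrays as a stand-in for a real training set so that it runs without downloading anything.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in data (an assumption): 100 random 28x28 grayscale images, 10 classes.
rng = np.random.default_rng(0)
X_train = rng.random((100, 28, 28, 1)).astype("float32")
y_train = rng.integers(0, 10, size=100)

datagen = ImageDataGenerator(featurewise_center=True,
                             featurewise_std_normalization=True)

# Compute the dataset-wide statistics that the transforms need.
datagen.fit(X_train)

# flow() returns an iterator that yields (images, labels) batches.
batches = datagen.flow(X_train, y_train, batch_size=32)
X_batch, y_batch = next(batches)
print(X_batch.shape, y_batch.shape)
```

Note that next() is called on the iterator returned by flow(), not on the ImageDataGenerator itself.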

Finally we can make use of the data generator. Instead of calling the fit() function on our model, we must call the fit_generator() function and pass in the data generator and the desired length of an epoch as well as the total number of epochs on which to train.
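A sketch of the training call, under a few assumptions: a recent TensorFlow release in which fit() accepts the iterator directly (on the older Keras versions this post targets, fit_generator() plays the same role), random stand-in data, and a deliberately tiny model that is only there to make the example runnable.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in data (an assumption): 64 random 28x28 grayscale images, 10 classes.
rng = np.random.default_rng(0)
X_train = rng.random((64, 28, 28, 1)).astype("float32")
y_train = rng.integers(0, 10, size=64)

datagen = ImageDataGenerator(horizontal_flip=True)

# A toy classifier, only here so that there is something to train.
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

batch_size = 16
history = model.fit(
    datagen.flow(X_train, y_train, batch_size=batch_size),
    steps_per_epoch=len(X_train) // batch_size,  # length of one epoch in batches
    epochs=2,
    verbose=0,
)
```

The steps_per_epoch value tells Keras how many batches make up one epoch, because the iterator itself loops over the data indefinitely.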

You can learn more about the Keras image data generator API in the Keras documentation.

Point of Comparison for Image Augmentation

Now that you know how the image augmentation API in Keras works, let’s look at some examples.

We will use the MNIST handwritten digit recognition task in these examples. To begin with, let’s take a look at the first 9 images in the training dataset.
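A sketch of the baseline plot. The helper name plot_first_nine is hypothetical, and the block draws random stand-in images so that it runs without a download.

```python
import numpy as np
from matplotlib import pyplot

def plot_first_nine(images):
    """Plot the first 9 images of an (N, H, W) array in a 3x3 grid."""
    fig = pyplot.figure()
    for i in range(9):
        # 330 + 1 + i selects cell i + 1 of a 3x3 grid (left to right, top to bottom).
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(images[i], cmap=pyplot.get_cmap("gray"))
    return fig

# Stand-in for the MNIST training images (an assumption for a self-contained run).
demo_images = np.random.default_rng(0).integers(0, 256, size=(9, 28, 28), dtype=np.uint8)
fig = plot_first_nine(demo_images)
```

With the real dataset you would load the images with mnist.load_data() from tensorflow.keras.datasets, pass X_train to the helper, and then call pyplot.show().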

Running this example provides the following image that we can use as a point of comparison with the image preparation and augmentation in the examples below.

Example MNIST images

Feature Standardization

It is also possible to standardize pixel values across the entire dataset. This is called feature standardization and mirrors the type of standardization often performed for each column in a tabular dataset.

You can perform feature standardization by setting the featurewise_center and featurewise_std_normalization arguments on the ImageDataGenerator class to True. Both arguments default to False in current Keras releases, so set them explicitly rather than relying on a no-argument ImageDataGenerator.
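A minimal sketch of feature standardization, assuming the tensorflow.keras import path and an installed SciPy. The helper name standardized_batch is hypothetical, and random pixel values stand in for MNIST.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def standardized_batch(X, y, batch_size=9):
    """Return one batch standardized with the dataset-wide mean and std."""
    datagen = ImageDataGenerator(featurewise_center=True,
                                 featurewise_std_normalization=True)
    datagen.fit(X)  # computes the mean and standard deviation from X
    return next(datagen.flow(X, y, batch_size=batch_size, shuffle=False))

# Stand-in for MNIST (an assumption): random pixel values in [0, 255].
rng = np.random.default_rng(0)
X_demo = rng.uniform(0, 255, size=(100, 28, 28, 1)).astype("float32")
y_demo = rng.integers(0, 10, size=100)

X_batch, y_batch = standardized_batch(X_demo, y_demo, batch_size=100)
print(X_batch.mean(), X_batch.std())
```

Because the batch here covers the whole stand-in dataset, its mean should come out close to 0 and its standard deviation close to 1. With MNIST you would first reshape X_train to (samples, 28, 28, 1) and convert it to float32.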

Running this example you can see that the effect is different, seemingly darkening and lightening different digits.

Standardized Feature MNIST Images

ZCA Whitening

A whitening transform of an image is a linear algebra operation that reduces the redundancy in the matrix of pixel values.

Less redundancy in the image is intended to better highlight the structures and features in the image to the learning algorithm.

Typically, image whitening is performed using the Principal Component Analysis (PCA) technique. More recently, an alternative called ZCA (learn more in Appendix A of this tech report) has shown better results. Unlike PCA, ZCA keeps all of the original dimensions, so the transformed images still look like their originals.

You can perform a ZCA whitening transform by setting the zca_whitening argument to True.
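A sketch of ZCA whitening, assuming the tensorflow.keras import path, an installed SciPy, and random stand-in images. Fitting is required here because the whitening matrix is computed from the training data.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in for MNIST (an assumption): 200 random 28x28 grayscale images.
rng = np.random.default_rng(0)
X_demo = rng.random((200, 28, 28, 1)).astype("float32")
y_demo = rng.integers(0, 10, size=200)

datagen = ImageDataGenerator(zca_whitening=True)

# fit() computes the whitening matrix from the flattened images.
datagen.fit(X_demo)

X_batch, y_batch = next(datagen.flow(X_demo, y_demo, batch_size=9))
print(X_batch.shape)
```

The fit step works on a matrix with one row and column per pixel, so it becomes expensive for large images.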

Running the example, you can see the same general structure in the images and how the outline of each digit has been highlighted.

ZCA Whitening MNIST Images

Random Rotations

Sometimes the objects in your sample images appear at different rotations in the scene.

You can train your model to better handle rotations of images by artificially and randomly rotating images from your dataset during training.

The example below creates random rotations of the MNIST digits up to 90 degrees by setting the rotation_range argument.
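A minimal sketch of random rotations, assuming the tensorflow.keras import path and an installed SciPy (the rotation itself is done with SciPy). Random arrays stand in for the MNIST digits.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in for nine MNIST digits (an assumption).
rng = np.random.default_rng(0)
X_demo = rng.random((9, 28, 28, 1)).astype("float32")
y_demo = np.arange(9)

# Each image is rotated by a random angle of up to 90 degrees in either direction.
datagen = ImageDataGenerator(rotation_range=90)

# shuffle=False keeps the batch in the same order as the input.
X_batch, y_batch = next(datagen.flow(X_demo, y_demo, batch_size=9, shuffle=False))
print(X_batch.shape)
```

No call to fit() is needed, because a rotation does not depend on any dataset statistics.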

Running the example, you can see that images have been rotated left and right up to a limit of 90 degrees. This is not helpful on this problem because the MNIST digits have a normalized orientation, but this transform might be of help when learning from photographs where the objects may have different orientations.

Random Rotations of MNIST Images

Random Shifts

Objects in your images may not be centered in the frame. They may be off-center in a variety of different ways.

You can train your deep learning network to expect and correctly handle off-center objects by artificially creating shifted versions of your training data. Keras supports separate horizontal and vertical random shifting of training data by the width_shift_range and height_shift_range arguments.
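A sketch of random shifts under the same assumptions as before: the tensorflow.keras import path, an installed SciPy, and random stand-in images. The shift fraction of 0.2 is an illustrative choice.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in for nine MNIST digits (an assumption).
rng = np.random.default_rng(0)
X_demo = rng.random((9, 28, 28, 1)).astype("float32")
y_demo = np.arange(9)

# Fractions below 1 are read as a share of the image width or height.
shift = 0.2
datagen = ImageDataGenerator(width_shift_range=shift, height_shift_range=shift)

X_batch, y_batch = next(datagen.flow(X_demo, y_demo, batch_size=9, shuffle=False))
print(X_batch.shape)
```

The shifted images keep their original size. The pixels exposed at the edges are filled in according to the fill_mode argument.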

Running this example creates shifted versions of the digits. Again, this is not required for MNIST as the handwritten digits are already centered, but you can see how this might be useful on more complex problem domains.

Random Shifted MNIST Images

Random Flips

Another augmentation to your image data that can improve performance on large and complex problems is to create random flips of images in your training data.

Keras supports random flipping along both the vertical and horizontal axes using the vertical_flip and horizontal_flip arguments.
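A minimal sketch of random flips, assuming the tensorflow.keras import path, an installed SciPy, and random stand-in images.

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in for nine MNIST digits (an assumption).
rng = np.random.default_rng(0)
X_demo = rng.random((9, 28, 28, 1)).astype("float32")
y_demo = np.arange(9)

# Each flip is applied independently and at random to every image.
datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)

X_batch, y_batch = next(datagen.flow(X_demo, y_demo, batch_size=9, shuffle=False))
print(X_batch.shape)
```

A flip only reorders pixels, so every output image is the input itself or one of its mirrored versions.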

Running this example you can see flipped digits. Flipping digits is not useful as they will always have the correct left and right orientation, but this may be useful for problems with photographs of objects in a scene that can have a varied orientation.

Randomly Flipped MNIST Images

Saving Augmented Images to File

The data preparation and augmentation is performed just in time by Keras.

This is efficient in terms of memory, but you may require the exact images used during training. For example, perhaps you would like to use them with a different software package later or only generate them once and use them on multiple different deep learning models or configurations.

Keras allows you to save the images generated during training. The directory, filename prefix and image file type can be specified to the flow() function before training. Then, during training, the generated images will be written to file.

The example below demonstrates this and writes 9 images to an “images” subdirectory with the prefix “aug” and the file type of PNG.
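A sketch of saving generated images, assuming the tensorflow.keras import path, installed SciPy and Pillow (Pillow writes the files), and random stand-in images. The sketch creates the directory first, since the target directory has to exist before batches are generated.

```python
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Stand-in for nine MNIST digits (an assumption).
rng = np.random.default_rng(0)
X_demo = rng.random((9, 28, 28, 1)).astype("float32")
y_demo = np.arange(9)

# Create the output directory up front.
os.makedirs("images", exist_ok=True)

datagen = ImageDataGenerator()
batches = datagen.flow(X_demo, y_demo, batch_size=9,
                       save_to_dir="images", save_prefix="aug", save_format="png")

# Files are written only when a batch is actually generated.
X_batch, y_batch = next(batches)
saved = [f for f in os.listdir("images") if f.startswith("aug_") and f.endswith(".png")]
print(len(saved))
```

Requesting a second batch would write a further set of files, which is why the images only appear as training proceeds.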

Running the example you can see that images are only written when they are generated.

Augmented MNIST Images Saved To File

Tips For Augmenting Image Data with Keras

Image data is unique in that you can review the data and transformed copies of the data and quickly get an idea of how it may be perceived by your model.

Below are some tips for getting the most from image data preparation and augmentation for deep learning.

  • Review Dataset. Take some time to review your dataset in great detail. Look at the images. Take note of image preparation and augmentations that might benefit the training process of your model, such as the need to handle different shifts, rotations or flips of objects in the scene.
  • Review Augmentations. Review sample images after the augmentation has been performed. It is one thing to intellectually know what image transforms you are using; it is a very different thing to look at examples. Review images both with the individual augmentations you are using and with the full set of augmentations you plan to use. You may see ways to simplify or further enhance your model training process.
  • Evaluate a Suite of Transforms. Try more than one image data preparation and augmentation scheme. Often you can be surprised by results of a data preparation scheme you did not think would be beneficial.


In this post you discovered image data preparation and augmentation.

You discovered a range of techniques that you can use easily in Python with Keras for deep learning models. You learned about:

  • The ImageDataGenerator API in Keras for generating transformed images just in time.
  • Sample-wise and feature-wise pixel standardization.
  • The ZCA whitening transform.
  • Random rotations, shifts and flips of images.
  • How to save transformed images to file for later reuse.

Do you have any questions about image data augmentation or this post? Ask your questions in the comments and I will do my best to answer.

26 Responses to Image Augmentation for Deep Learning With Keras

  1. Andy August 2, 2016 at 7:34 am #

    Interesting tutorial.

    I’m working through the step to standardize images across the dataset and run into the following error:

    AttributeError Traceback (most recent call last)
    in ()
    18 datagen.flow(X_train, y_train, batch_size=9)
    19 # retrieve one batch of images
    —> 20 X_batch, y_batch =
    21 # create a grid of 3×3 images
    22 for i in range(0, 9):

    AttributeError: ‘ImageDataGenerator’ object has no attribute ‘next’

    I have checked the Keras documentation and see no mention of a next attribute.

    Perhaps I’m missing something.

    Thanks for the great tutorials!

  2. narayan August 9, 2016 at 6:38 pm #

    for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
    File “/usr/local/lib/python2.7/dist-packages/keras/preprocessing/”, line 475, in next
    x = self.image_data_generator.random_transform(x.astype(‘float32’))
    File “/usr/local/lib/python2.7/dist-packages/keras/preprocessing/”, line 346, in random_transform
    fill_mode=self.fill_mode, cval=self.cval)
    File “/usr/local/lib/python2.7/dist-packages/keras/preprocessing/”, line 109, in apply_transform
    x = np.stack(channel_images, axis=0)
    AttributeError: ‘module’ object has no attribute ‘stack’

    how to solve this error …?

    • Jason Brownlee August 15, 2016 at 11:13 am #

      I have not seen an error like that before. Perhaps there is a problem with your environment?

      Consider re-installing Theano and/or Keras.

      • narayan August 26, 2016 at 9:02 pm #

        i solved this error by updating numpy version ….previously it means it should be more than 1.9.0

  3. narayan August 26, 2016 at 9:05 pm #

    Now i have question that how to decide value for this parameter So that i can get good testing accuracy ..i have training dataset with 110 category with 32000 images ..


    Waiting for your positive reply…

    • Jason Brownlee August 27, 2016 at 11:34 am #

      My advice is to try a suite of different configurations and see what works best on your problem.

  4. Walid Ahmed November 9, 2016 at 2:08 am #

    Thanks a lot.
    all worked fine except the last code to save images to file, I got the following exception

    Walids-MacBook-Pro:DataAugmentation walidahmed$ python
    Using TensorFlow backend.
    Traceback (most recent call last):
    File “”, line 20, in
    for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9, save_to_dir=’images’, save_prefix=’aug’, save_format=’png’):
    File “/usr/local/lib/python2.7/site-packages/keras/preprocessing/”, line 490, in next
    img = array_to_img(batch_x[i], self.dim_ordering, scale=True)
    File “/usr/local/lib/python2.7/site-packages/keras/preprocessing/”, line 140, in array_to_img
    raise Exception(‘Unsupported channel number: ‘, x.shape[2])
    Exception: (‘Unsupported channel number: ‘, 28)

    Any advice?
    thanks again

    • Jason Brownlee November 9, 2016 at 9:52 am #

      Double check your version of Keras is 1.1.0 and TensorFlow is 0.10.

  5. Sudesh November 11, 2016 at 9:37 pm #

    Hello Jason,

    Thanks a lot for your tutorial. It is helping me in many ways.

    I had question on mask image or target Y for training image X
    Can i also transform Y along with X. Helps in the case of training for segmentation

    • Sudesh November 15, 2016 at 5:25 am #

      I managed to do it.

      datagen = ImageDataGenerator(shear_range=0.02,dim_ordering=K._image_dim_ordering,rotation_range=5,width_shift_range=0.05, height_shift_range=0.05,zoom_range=0.3,fill_mode=’constant’, cval=0)

      for samples in range(0,100):
      seed = rd.randint(low=10,high=100000)
      for imags_batch in datagen.flow(imgs_train,batch_size=batch_size,save_to_dir=’augmented’,save_prefix=’aug’,seed=seed,save_format=’tif’):
      for imgs_mask_batch in datagen.flow(imgs_mask_train, batch_size=batch_size, save_to_dir=’augmented’,seed=seed, save_prefix=’mask_aug’,save_format=’tif’):

  6. Addie November 29, 2016 at 6:01 am #

    This is great stuff but I wonder if you could provide an example like this with an RGB image with three channels? I am getting some really buggy results personally with this ImageGenerator.

  7. Lucas December 24, 2016 at 9:02 am #

    I wonder what channel_shift_range is about. The doc says “shift range for each channels”, but what does this actually mean? Is it adding a random value to each channel or doing something else?

    • Jason Brownlee December 26, 2016 at 7:37 am #

      I have not used this one yet, sorry Lucas.

      You could try experimenting with it or dive into the source to see what it’s all about.

  8. Indra December 26, 2016 at 5:30 pm #


    Thanks for the post. I’ve one question i.e., we do feature standardization in the training set, so while testing, we need those standardized values to apply on testing images ?

    • Jason Brownlee December 27, 2016 at 5:22 am #

      Yes Indra, any transforms like standardization performed on the data prior to modeling will also need to be performed on new data when testing or making predictions.

      In the case of standardization, we need to keep track of means and standard deviations.

  9. Dan March 11, 2017 at 11:01 pm #

    Thanks again Jason. Why do we subplot 330+1+i? Thanks

    • Jason Brownlee March 12, 2017 at 8:24 am #

      This is matplotlib syntax.

      The 33 creates a grid of 3×3 images. The number after that (1-9) indicates the position in that grid to place the next image (left to right, top to bottom ordering).

      I hope that helps.

  10. Vineeth March 13, 2017 at 7:52 pm #

    How do I save the augmented images into a directory with a class label prefix or even better into a subdirectory of class name?

    • Jason Brownlee March 14, 2017 at 8:15 am #

      Great question Vineeth,

      You can specify any directory and filename prefix you like in the call to flow()

  11. Richa March 21, 2017 at 10:45 pm #

    can we augment data of a particular class. I mean images of a class which are less, to deal with the class imbalance problem.

    • Jason Brownlee March 22, 2017 at 8:06 am #

      Great idea.

      Yes, but you may need to prepare the data for each class separately.
