When you work on a machine learning problem related to images, you need not only to collect images as training data but also to use augmentation to create variations of them. This is especially true for more complex object recognition problems.
There are many ways to augment images. You may use external libraries or write your own functions. TensorFlow and Keras also provide modules for augmentation.
In this post, you will discover how you can use the Keras preprocessing layers as well as the tf.image module in TensorFlow for image augmentation.
After reading this post, you will know:
- What the Keras preprocessing layers are, and how to use them
- What functions the tf.image module provides for image augmentation
- How to use augmentation together with the tf.data dataset
Let’s get started.

Image augmentation with Keras preprocessing layers and tf.image.
Photo by Steven Kamenar. Some rights reserved.
Overview
This article is divided into five sections; they are:
- Getting Images
- Visualizing the Images
- Keras Preprocessing Layers
- Using tf.image API for Augmentation
- Using Preprocessing Layers in Neural Networks
Getting Images
Before you see how you can do augmentation, you need to get the images. Ultimately, you need the images to be represented as arrays, for example, of shape H×W×3 with 8-bit integers for the RGB pixel values. There are many ways to get the images. Some can be downloaded as a ZIP file. If you’re using TensorFlow, you may get some image datasets from the tensorflow_datasets library.
In this tutorial, you will use the citrus leaves images, which is a small dataset of less than 100MB. It can be downloaded from tensorflow_datasets as follows:
```python
import tensorflow_datasets as tfds
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)
```
Running this code the first time will download the image dataset into your computer with the following output:
```
Downloading and preparing dataset 63.87 MiB (download: 63.87 MiB, generated: 37.89 MiB, total: 101.76 MiB) to ~/tensorflow_datasets/citrus_leaves/0.1.2...
Extraction completed...: 100%|██████████████████████████████| 1/1 [00:06<00:00, 6.54s/ file]
Dl Size...: 100%|██████████████████████████████████████████| 63/63 [00:06<00:00, 9.63 MiB/s]
Dl Completed...: 100%|███████████████████████████████████████| 1/1 [00:06<00:00, 6.54s/ url]
Dataset citrus_leaves downloaded and prepared to ~/tensorflow_datasets/citrus_leaves/0.1.2. Subsequent calls will reuse this data.
```
The function above returns the images as a tf.data dataset object, together with the metadata. This is a classification dataset. You can print the class labels with the following:
```python
...
for i in range(meta.features['label'].num_classes):
    print(meta.features['label'].int2str(i))
```
This prints:
```
Black spot
canker
greening
healthy
```
If you run this code again at a later time, you will reuse the downloaded images. Another way to load the downloaded images into a tf.data dataset is to use the image_dataset_from_directory() function.
As you can see from the screen output above, the dataset is downloaded into the directory ~/tensorflow_datasets. If you look at the directory, you see the directory structure as follows:
```
.../Citrus/Leaves
├── Black spot
├── Melanose
├── canker
├── greening
└── healthy
```
The directories are the labels, and the images are files stored under their corresponding directory. You can let the function read the directory recursively into a dataset:
```python
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

# set to fixed image size 256x256
PATH = ".../Citrus/Leaves"
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="bilinear",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)
```
You may want to set batch_size=None if you do not want the dataset to be batched. Usually, you want the dataset to be batched for training a neural network model.
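To see what batching does to the samples, here is a small sketch that uses synthetic data in place of the leaf images. The 70 dummy images and their 8×8 size are assumptions made only for illustration; they are not part of the citrus leaves dataset.

```python
import tensorflow as tf

# 70 dummy "images" of 8x8 RGB pixels stand in for a real image folder
images = tf.zeros((70, 8, 8, 3))
labels = tf.zeros((70,), dtype=tf.int32)

# Group the samples into batches of 32
batched = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)
batch_sizes = [int(img.shape[0]) for img, _ in batched]
print(batch_sizes)  # the last batch holds the remainder

# unbatch() turns a batched dataset back into individual samples
unbatched = batched.unbatch()
num_samples = sum(1 for _ in unbatched)
print(num_samples)
```

If a dataset was created with batching and you later need single images, unbatch() is one way to get them back without reloading the files.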
Visualizing the Images
It is important to visualize the augmentation result so you can verify that it is what you want it to be. You can use matplotlib for this.
In matplotlib, you have the imshow() function to display an image. However, for an RGB image to be displayed correctly, the pixel values should be either 8-bit unsigned integers (uint8) in the range 0 to 255 or floats in the range 0 to 1. The examples in this post cast the images to uint8.
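As a rough illustration of the data type issue, the sketch below prepares a float image for display in two ways. The tiny 2×2 array is a made-up stand-in for a real image, and only NumPy is used so the conversion itself is easy to inspect.

```python
import numpy as np

# A dummy 2x2 RGB image stored as float32 in the 0-255 range
# (an assumption that mimics what an image pipeline may produce)
float_image = np.array([[[0.0, 127.6, 255.0], [12.3, 45.6, 78.9]],
                        [[255.0, 255.0, 255.0], [0.0, 0.0, 0.0]]], dtype="float32")

# Option 1: cast to uint8 (the fractional part is truncated)
as_uint8 = float_image.astype("uint8")

# Option 2: keep floats but scale them into the 0-1 range
as_unit_float = float_image / 255.0

print(as_uint8.dtype, as_uint8.max())
print(as_unit_float.dtype, float(as_unit_float.max()))
```

Either array can be handed to imshow(); the rest of this post uses the uint8 cast.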
Given that you have a dataset created using image_dataset_from_directory(), you can get the first batch (of 32 images) and display a few of them using imshow(), as follows:
```python
...
import matplotlib.pyplot as plt

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for images, labels in ds.take(1):
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(ds.class_names[labels[i*3+j]])
plt.show()
```
Here, you see a display of nine images in a grid, each titled with its classification label from ds.class_names. The images should be converted to NumPy arrays of uint8 for display. This code displays an image like the following:
The complete code from loading the image to display is as follows:
```python
from tensorflow.keras.utils import image_dataset_from_directory
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Take one batch from dataset and display the images
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for images, labels in ds.take(1):
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(ds.class_names[labels[i*3+j]])
plt.show()
```
Note that if you’re using tensorflow_datasets to get the image, the samples are presented as a dictionary instead of a tuple of (image, label). You should change your code slightly to the following:
```python
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# use tfds.load() or image_dataset_from_directory() to load images
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)
ds = ds.batch(32)

# Take one batch from dataset and display the images
fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(5,5))

for sample in ds.take(1):
    images, labels = sample["image"], sample["label"]
    for i in range(3):
        for j in range(3):
            ax[i][j].imshow(images[i*3+j].numpy().astype("uint8"))
            ax[i][j].set_title(meta.features['label'].int2str(labels[i*3+j]))
plt.show()
```
For the rest of this post, assume the dataset is created using image_dataset_from_directory(). You may need to tweak the code slightly if your dataset is created differently.
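One possible tweak, sketched below, is to convert dictionary samples into (image, label) tuples with map() so that the later examples can be used unchanged. The synthetic dataset and the helper name to_tuple are assumptions for illustration; with tensorflow_datasets you would apply the same function to the dataset returned by tfds.load().

```python
import tensorflow as tf

# Synthetic stand-in for a tensorflow_datasets dataset: each sample is a dict
dict_ds = tf.data.Dataset.from_tensor_slices({
    "image": tf.zeros((6, 16, 16, 3), dtype=tf.uint8),
    "label": tf.constant([0, 1, 2, 3, 0, 1], dtype=tf.int64),
})

def to_tuple(sample):
    # Convert {"image": ..., "label": ...} into (image, label)
    return sample["image"], sample["label"]

tuple_ds = dict_ds.map(to_tuple).batch(3)

for images, labels in tuple_ds.take(1):
    first_batch_shape = tuple(images.shape)
    first_labels = labels.numpy().tolist()
    print(first_batch_shape, first_labels)
```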
Keras Preprocessing Layers
Keras comes with many neural network layers, such as convolution layers, that you need to train. There are also layers with no parameters to train, such as flatten layers to convert an array like an image into a vector.
The preprocessing layers in Keras are specifically designed for use in the early stages of a neural network. You can use them for image preprocessing, such as to resize or rotate the image or adjust the brightness and contrast. While the preprocessing layers are supposed to be part of a larger neural network, you can also use them as functions. Below is how you can use the resizing layer as a function to transform some images and display them side-by-side with the original:
```python
...
# create a resizing layer
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)

# show original vs resized
fig, ax = plt.subplots(2, 3, figsize=(6,4))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
plt.show()
```
The images are 256×256 pixels, and the resizing layer turns them into images 128 pixels high and 256 pixels wide, since its arguments are given as height followed by width. The output of the above code is as follows:
Since the resizing layer can be called like a function, you can also apply it to the dataset itself. For example,
```python
...
def augment(image, label):
    return resize(image), label

resized_ds = ds.map(augment)

for image, label in resized_ds:
    ...
```
The dataset ds has samples in the form of (image, label). Hence you created a function that takes in such a tuple and preprocesses the image with the resizing layer. You then passed this function as an argument to map() on the dataset. When you draw a sample from the new dataset created with the map() function, the image will be a transformed one.
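To check that map() really produces transformed samples, you can inspect the shape of a batch drawn from the mapped dataset. The following is a minimal sketch that assumes a small random dataset in place of the leaf images, so it can run without any files on disk.

```python
import tensorflow as tf

# Dummy batched dataset: 4 images of 256x256 RGB, in batches of 2
images = tf.random.uniform((4, 256, 256, 3), maxval=255.0)
labels = tf.constant([0, 1, 0, 1])
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

resize = tf.keras.layers.Resizing(128, 256)  # (height, width)

def augment(image, label):
    # Only the image is transformed; the label passes through untouched
    return resize(image), label

resized_ds = ds.map(augment)

for image, label in resized_ds.take(1):
    resized_shape = tuple(image.shape)
    print(resized_shape)  # batch, height, width, channels
```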
There are more preprocessing layers available. Some are demonstrated below.
As you saw above, you can resize the image. You can also randomly enlarge or shrink the height or width of an image. Similarly, you can zoom in or zoom out on an image. Below is an example of manipulating the image size in various ways for a maximum of 30% increase or decrease:
```python
...
# Create preprocessing layers
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
height = tf.keras.layers.RandomHeight(0.3)
width = tf.keras.layers.RandomWidth(0.3)
zoom = tf.keras.layers.RandomZoom(0.3)

# Visualize images and augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # height
        ax[2][i].imshow(height(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("height")
        # width
        ax[3][i].imshow(width(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("width")
        # zoom
        ax[4][i].imshow(zoom(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("zoom")
plt.show()
```
This code shows images as follows:
While you specified a fixed dimension in resize, you have a random amount of manipulation in other augmentations.
You can also do flipping, rotation, cropping, and geometric translation using preprocessing layers:
```python
...
# Create preprocessing layers
flip = tf.keras.layers.RandomFlip("horizontal_and_vertical") # or "horizontal", "vertical"
rotate = tf.keras.layers.RandomRotation(0.2)
crop = tf.keras.layers.RandomCrop(out_height, out_width)
translation = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)

# Visualize augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        ax[1][i].imshow(flip(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("flip")
        # crop
        ax[2][i].imshow(crop(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # translation
        ax[3][i].imshow(translation(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("translation")
        # rotate
        ax[4][i].imshow(rotate(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("rotate")
plt.show()
```
This code shows the following images:
And finally, you can do augmentations on color adjustments as well:
```python
...
brightness = tf.keras.layers.RandomBrightness([-0.8,0.8])
contrast = tf.keras.layers.RandomContrast(0.2)

# Visualize augmentation
fig, ax = plt.subplots(3, 3, figsize=(6,7))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        ax[1][i].imshow(brightness(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(contrast(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
plt.show()
```
This shows the images as follows:
For completeness, below is the code to display the result of various augmentations:
```python
from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Create preprocessing layers
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
height = tf.keras.layers.RandomHeight(0.3)
width = tf.keras.layers.RandomWidth(0.3)
zoom = tf.keras.layers.RandomZoom(0.3)

flip = tf.keras.layers.RandomFlip("horizontal_and_vertical")
rotate = tf.keras.layers.RandomRotation(0.2)
crop = tf.keras.layers.RandomCrop(out_height, out_width)
translation = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)

brightness = tf.keras.layers.RandomBrightness([-0.8,0.8])
contrast = tf.keras.layers.RandomContrast(0.2)

# Visualize images and augmentations
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        ax[1][i].imshow(resize(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # height
        ax[2][i].imshow(height(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("height")
        # width
        ax[3][i].imshow(width(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("width")
        # zoom
        ax[4][i].imshow(zoom(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("zoom")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        ax[1][i].imshow(flip(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("flip")
        # crop
        ax[2][i].imshow(crop(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # translation
        ax[3][i].imshow(translation(images[i]).numpy().astype("uint8"))
        ax[3][i].set_title("translation")
        # rotate
        ax[4][i].imshow(rotate(images[i]).numpy().astype("uint8"))
        ax[4][i].set_title("rotate")
plt.show()

fig, ax = plt.subplots(3, 3, figsize=(6,7))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        ax[1][i].imshow(brightness(images[i]).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(contrast(images[i]).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
plt.show()
```
Finally, it is important to point out that most neural network models work better if the input images are scaled. While you usually use 8-bit unsigned integers for the pixel values in an image (e.g., for display using imshow() as above), a neural network prefers the pixel values to be between 0 and 1 or between -1 and +1. This can be done with preprocessing layers too. Below is how you can update one of the examples above to add the scaling layer into the augmentation:
```python
...
out_height, out_width = 128,256
resize = tf.keras.layers.Resizing(out_height, out_width)
rescale = tf.keras.layers.Rescaling(1/127.5, offset=-1)  # rescale pixel values to [-1,1]

def augment(image, label):
    return rescale(resize(image)), label

rescaled_resized_ds = ds.map(augment)

for image, label in rescaled_resized_ds:
    ...
```
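The arithmetic behind the rescaling is simple enough to verify by hand. The sketch below reproduces it in plain Python, assuming the layer computes output = input × scale + offset; the helper name rescale() is made up for this illustration and is not a Keras function.

```python
def rescale(pixel, scale=1 / 127.5, offset=-1.0):
    # Same arithmetic as a rescaling layer: output = pixel * scale + offset
    return pixel * scale + offset

# Map the darkest, middle, and brightest 8-bit values into the [-1, 1] range
low, mid, high = rescale(0), rescale(127.5), rescale(255)
print(low, mid, high)

# Alternative convention: map 0..255 to 0..1
unit_high = rescale(255, scale=1 / 255, offset=0.0)
print(unit_high)
```

Note that a scale of 1/127.5 sends 255 to +1, whereas 1/127 would overshoot slightly above +1.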
Using tf.image API for Augmentation
Besides the preprocessing layers, the tf.image module also provides some functions for augmentation. Unlike the preprocessing layers, these functions are intended to be used in a user-defined function and applied to a dataset using map(), as you saw above.
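As a minimal sketch of that pattern, the function below combines two tf.image calls and is applied with map(). The random dummy dataset, the max_delta of 25, and the clipping step are choices made for this illustration rather than anything prescribed by the post.

```python
import tensorflow as tf

# Dummy dataset of 4 RGB images (float32 in 0-255), batched in pairs
images = tf.random.uniform((4, 64, 64, 3), maxval=255.0)
labels = tf.constant([0, 1, 2, 3])
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=25.0)
    # Brightness shifts can leave the valid range, so clip back to 0-255
    image = tf.clip_by_value(image, 0.0, 255.0)
    return image, label

augmented_ds = ds.map(augment)

for image, label in augmented_ds.take(1):
    out_shape = tuple(image.shape)
    out_min = float(tf.reduce_min(image))
    out_max = float(tf.reduce_max(image))
    print(out_shape, out_min, out_max)
```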
The functions provided by tf.image are not duplicates of the preprocessing layers, although there is some overlap. Below is an example of using the tf.image functions to resize and crop images:
```python
...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        # original
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        h = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        w = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        ax[1][i].imshow(tf.image.resize(images[i], [h,w]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # crop (offsets and sizes are at least 1 so the crop is never empty)
        y, x, h, w = tf.random.uniform((4,), minval=1, maxval=128, dtype=tf.int32).numpy()
        ax[2][i].imshow(tf.image.crop_to_bounding_box(images[i], y, x, h, w).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # central crop
        x = tf.random.uniform([], minval=0.4, maxval=1.0)
        ax[3][i].imshow(tf.image.central_crop(images[i], x).numpy().astype("uint8"))
        ax[3][i].set_title("central crop")
        # crop to (h,w) at random offset (sizes are at least 1)
        h, w = tf.random.uniform((2,), minval=1, maxval=256, dtype=tf.int32).numpy()
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[4][i].imshow(tf.image.stateless_random_crop(images[i], [h,w,3], seed).numpy().astype("uint8"))
        ax[4][i].set_title("random crop")
plt.show()
```
Below is the output of the above code:
While the display of images matches what you might expect from the code, the use of tf.image functions is quite different from that of the preprocessing layers. Every tf.image function has its own calling convention. For example, the crop_to_bounding_box() function takes pixel coordinates, but the central_crop() function takes a fraction as the argument.
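The difference between the two conventions is easy to see from the output shapes. Here is a small sketch on a blank 256×256 image; the specific offsets and sizes are arbitrary values picked for the example.

```python
import tensorflow as tf

image = tf.zeros((256, 256, 3))

# Pixel coordinates: top-left corner at (row 10, column 20), 100 rows by 50 columns
boxed = tf.image.crop_to_bounding_box(image, 10, 20, 100, 50)

# Fraction: keep the central 50% along each spatial dimension
central = tf.image.central_crop(image, 0.5)

boxed_shape = tuple(boxed.shape)
central_shape = tuple(central.shape)
print(boxed_shape, central_shape)
```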
These functions also differ in the way randomness is handled. Some of them have no random behavior at all. Therefore, for a random resize, the output size has to be generated separately with a random number generator before calling the resize function. Some other functions, such as stateless_random_crop(), can do augmentation randomly, but a seed in the form of a pair of int32 values needs to be specified explicitly.
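The sketch below illustrates what the explicit seed means in practice: calling a stateless function twice with the same seed gives the same result, so you need a fresh seed for each call to get variation. Using tf.random.Generator to produce seeds is one option among several and is an assumption of this example, not something the examples in this post rely on.

```python
import tensorflow as tf

# A tiny 1x4 single-channel "image" makes a flip easy to see
image = tf.reshape(tf.range(4, dtype=tf.float32), (1, 4, 1))

seed = (1, 2)  # stateless ops take a seed of shape [2]
first = tf.image.stateless_random_flip_left_right(image, seed)
second = tf.image.stateless_random_flip_left_right(image, seed)

# Same seed, same decision: the two results are identical
same_result = bool(tf.reduce_all(tf.equal(first, second)))
print(same_result)

# To get variation, draw a fresh seed per call, e.g. from a generator
rng = tf.random.Generator.from_seed(42)
fresh_seed = rng.make_seeds(2)[0]
third = tf.image.stateless_random_flip_left_right(image, fresh_seed)
third_shape = tuple(third.shape)
print(third_shape)
```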
To continue the example, there are the functions for flipping an image and extracting the Sobel edges:
```python
...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_flip_left_right(images[i], seed).numpy().astype("uint8"))
        ax[1][i].set_title("flip left-right")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[2][i].imshow(tf.image.stateless_random_flip_up_down(images[i], seed).numpy().astype("uint8"))
        ax[2][i].set_title("flip up-down")
        # sobel edge
        sobel = tf.image.sobel_edges(images[i:i+1])
        ax[3][i].imshow(sobel[0, ..., 0].numpy().astype("uint8"))
        ax[3][i].set_title("sobel y")
        # sobel edge
        ax[4][i].imshow(sobel[0, ..., 1].numpy().astype("uint8"))
        ax[4][i].set_title("sobel x")
plt.show()
```
This shows the following:
And the following are the functions to manipulate the brightness, contrast, and colors:
```python
...
fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_brightness(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(tf.image.stateless_random_contrast(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
        # saturation
        ax[3][i].imshow(tf.image.stateless_random_saturation(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[3][i].set_title("saturation")
        # hue
        ax[4][i].imshow(tf.image.stateless_random_hue(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[4][i].set_title("hue")
plt.show()
```
This code shows the following:
Below is the complete code to display all of the above:
```python
from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf
import matplotlib.pyplot as plt

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

# Visualize tf.image augmentations

fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        # original
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # resize
        h = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        w = int(256 * tf.random.uniform([], minval=0.8, maxval=1.2))
        ax[1][i].imshow(tf.image.resize(images[i], [h,w]).numpy().astype("uint8"))
        ax[1][i].set_title("resize")
        # crop (offsets and sizes are at least 1 so the crop is never empty)
        y, x, h, w = tf.random.uniform((4,), minval=1, maxval=128, dtype=tf.int32).numpy()
        ax[2][i].imshow(tf.image.crop_to_bounding_box(images[i], y, x, h, w).numpy().astype("uint8"))
        ax[2][i].set_title("crop")
        # central crop
        x = tf.random.uniform([], minval=0.4, maxval=1.0)
        ax[3][i].imshow(tf.image.central_crop(images[i], x).numpy().astype("uint8"))
        ax[3][i].set_title("central crop")
        # crop to (h,w) at random offset (sizes are at least 1)
        h, w = tf.random.uniform((2,), minval=1, maxval=256, dtype=tf.int32).numpy()
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[4][i].imshow(tf.image.stateless_random_crop(images[i], [h,w,3], seed).numpy().astype("uint8"))
        ax[4][i].set_title("random crop")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_flip_left_right(images[i], seed).numpy().astype("uint8"))
        ax[1][i].set_title("flip left-right")
        # flip
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[2][i].imshow(tf.image.stateless_random_flip_up_down(images[i], seed).numpy().astype("uint8"))
        ax[2][i].set_title("flip up-down")
        # sobel edge
        sobel = tf.image.sobel_edges(images[i:i+1])
        ax[3][i].imshow(sobel[0, ..., 0].numpy().astype("uint8"))
        ax[3][i].set_title("sobel y")
        # sobel edge
        ax[4][i].imshow(sobel[0, ..., 1].numpy().astype("uint8"))
        ax[4][i].set_title("sobel x")
plt.show()

fig, ax = plt.subplots(5, 3, figsize=(6,14))

for images, labels in ds.take(1):
    for i in range(3):
        ax[0][i].imshow(images[i].numpy().astype("uint8"))
        ax[0][i].set_title("original")
        # brightness
        seed = tf.random.uniform((2,), minval=0, maxval=65536).numpy().astype("int32")
        ax[1][i].imshow(tf.image.stateless_random_brightness(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[1][i].set_title("brightness")
        # contrast
        ax[2][i].imshow(tf.image.stateless_random_contrast(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[2][i].set_title("contrast")
        # saturation
        ax[3][i].imshow(tf.image.stateless_random_saturation(images[i], 0.7, 1.3, seed).numpy().astype("uint8"))
        ax[3][i].set_title("saturation")
        # hue
        ax[4][i].imshow(tf.image.stateless_random_hue(images[i], 0.3, seed).numpy().astype("uint8"))
        ax[4][i].set_title("hue")
plt.show()
```
These augmentation functions should be enough for most uses. But if you have some specific ideas on augmentation, you would probably need a more full-featured image processing library. OpenCV and Pillow are common and powerful libraries that give you more ways to transform images.
Using Preprocessing Layers in Neural Networks
You used the Keras preprocessing layers as functions in the examples above. But they can also be used as layers in a neural network. It is trivial to use. Below is an example of how you can incorporate a preprocessing layer into a classification network and train it using a dataset:
```python
from tensorflow.keras.utils import image_dataset_from_directory
import tensorflow as tf

# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves'  # modify to your path
ds = image_dataset_from_directory(PATH,
                                  validation_split=0.2, subset="training",
                                  image_size=(256,256), interpolation="mitchellcubic",
                                  crop_to_aspect_ratio=True,
                                  seed=42, shuffle=True, batch_size=32)

AUTOTUNE = tf.data.AUTOTUNE
ds = ds.cache().prefetch(buffer_size=AUTOTUNE)

num_classes = 5
model = tf.keras.Sequential([
  # declare the input shape explicitly so the model can be built up front
  tf.keras.Input(shape=(256, 256, 3)),
  tf.keras.layers.RandomFlip("horizontal_and_vertical"),
  tf.keras.layers.RandomRotation(0.2),
  tf.keras.layers.Rescaling(1/127.5, offset=-1),  # rescale pixel values to [-1,1]
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(ds, epochs=3)
```
Running this code gives the following output:
```
Found 609 files belonging to 5 classes.
Using 488 files for training.
Epoch 1/3
16/16 [==============================] - 5s 253ms/step - loss: 1.4114 - accuracy: 0.4283
Epoch 2/3
16/16 [==============================] - 4s 259ms/step - loss: 0.8101 - accuracy: 0.6475
Epoch 3/3
16/16 [==============================] - 4s 267ms/step - loss: 0.7015 - accuracy: 0.7111
```
In the code above, you created the dataset with cache() and prefetch(). This is a performance technique: cache() keeps the loaded images in memory after the first pass, and prefetch() lets the dataset prepare the next batches asynchronously while the neural network is trained. This would be significant if the dataset has some other augmentation assigned using the map() function.
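When a random augmentation is added with map(), the order of the calls matters. The sketch below, which assumes a small random dataset in place of the leaf images, places cache() before map() so that the stored data is the un-augmented input and the random flip is re-drawn on every pass. Caching after the map() call would instead store one fixed set of augmented images.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Dummy data in place of the image folder: 8 images, batches of 4
images = tf.random.uniform((8, 32, 32, 3), maxval=255.0)
labels = tf.zeros((8,), dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

def augment(image, label):
    return tf.image.random_flip_left_right(image), label

pipeline = (
    ds.cache()                                    # store the un-augmented batches
      .map(augment, num_parallel_calls=AUTOTUNE)  # fresh randomness on every pass
      .prefetch(AUTOTUNE)                         # overlap preparation with training
)

num_batches = sum(1 for _ in pipeline)
print(num_batches)
```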
You will see some improvement in training accuracy if you remove the RandomFlip and RandomRotation layers because you make the problem easier. However, as you want the network to predict well on a wide variation of image quality and properties, using augmentation can help your resulting network become more robust.
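One detail worth knowing is that the random layers are meant to be active only during training. The sketch below checks this on a stand-alone pair of augmentation layers with a random dummy batch; it is an illustration under the assumption that the layers pass their input through when called with training=False, which is how Keras documents them.

```python
import tensorflow as tf

augmenter = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
])

images = tf.random.uniform((2, 32, 32, 3), maxval=255.0)

# Inference mode: the random layers are expected to pass the input through
inference_out = augmenter(images, training=False)
unchanged = bool(tf.reduce_all(tf.equal(inference_out, images)))

# Training mode: augmentation is applied, but the shape is preserved
training_out = augmenter(images, training=True)
training_shape = tuple(training_out.shape)
print(unchanged, training_shape)
```

This is why the same model can be used for prediction without removing the augmentation layers: model.fit() runs them in training mode, while model.predict() and model.evaluate() run them in inference mode.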
Further Reading
Below is some documentation from TensorFlow that is related to the examples above:
- tf.data.Dataset API
- Citrus leaves dataset
- Load and preprocess images
- Data augmentation
- tf.image API
- tf.data performance
Summary
In this post, you have seen how you can use the tf.data dataset with image augmentation functions from Keras and TensorFlow.
Specifically, you learned:
- How to use the preprocessing layers from Keras, both as a function and as part of a neural network
- How to create your own image augmentation function and apply it to the dataset using the map() function
- How to use the functions provided by the tf.image module for image augmentation
A very helpful blog but I have 3 problems with the code.
Problem 1: I had to upgrade to TensorFlow 2.9.0 for the RandomBrightness function to work.
Problem 2: When I run the Sequential model you create for the last example, it complains about not having the ‘shape’ defined for the data input to the Model. In looking at the Model, I don’t see where you’ve defined the input shape or the input data. Where has this been done?
The error indicates, “You must provide an input_shape argument.”
Problem 3: I cannot make the following code work:
# use image_dataset_from_directory() to load images, with image size scaled to 256x256
PATH='.../Citrus/Leaves' # modify to your path
ds = image_dataset_from_directory(PATH,
validation_split=0.2, subset="training",
image_size=(256,256), interpolation="mitchellcubic",
crop_to_aspect_ratio=True,
seed=42, shuffle=True, batch_size=32)
As you suggest, I have to replace the previous code with the following code to make the other programs work.
ds, meta = tfds.load('citrus_leaves', with_info=True, split='train', shuffle_files=True)
ds = ds.batch(3*3)
for sample in ds.take(1):
images, labels = sample["image"], sample["label"]...
Hi Terry…Thank you for your feedback! Did you type the code samples in or copy and paste them? Also, have you tried the code samples in Google Colab to rule out any issues you may have with your local Python environment?
Hi Dear James,
while implementing the “layers.experimental.preprocessing.RandomContrast(factor=0.8)”, I get the following error:
A deterministic GPU implementation of AdjustContrastv2 is not currently available.
[[{{node model/sequential_1/random_contrast/adjust_contrast}}]] [Op:__inference_train_function_3974]
Any suggestions to overcome it? Due to this error, I am not able to run
os.environ['TF_DETERMINISTIC_OPS'] = '1'
thanks and regards