
Last Updated on November 28, 2022

Structuring the data pipeline so that it can be easily linked to your deep learning model is an important aspect of any deep learning-based system. PyTorch provides everything you need to do just that.

In the previous tutorial we used simple datasets, but in real-world scenarios we’ll need to work with larger datasets in order to fully exploit the potential of deep learning and neural networks.

In this tutorial, you’ll learn how to build custom datasets in PyTorch. Although the focus here is on image data, the concepts covered in this session apply to any form of dataset, such as text or tabular data. Specifically, you’ll learn:

• How to work with preloaded image datasets in PyTorch.
• How to apply torchvision transforms to preloaded datasets.
• How to build a custom image dataset class in PyTorch and apply various transforms to it.

Let’s get started.

Picture by Uriel SC. Some rights reserved.

## Overview

This tutorial is in three parts; they are:

• Preloaded Datasets in PyTorch
• Applying Torchvision Transforms on Image Datasets
• Building Custom Image Datasets

## Preloaded Datasets in PyTorch

A variety of preloaded datasets, such as CIFAR-10, MNIST, and Fashion-MNIST, are available in torchvision, PyTorch’s computer vision domain library. You can import them from torchvision and use them in your experiments. They are also useful for benchmarking your models.

Let’s begin by importing the Fashion-MNIST dataset from torchvision. Fashion-MNIST contains 70,000 grayscale images of 28×28 pixels, divided into ten classes of 7,000 images each. Of these, 60,000 images are for training and 10,000 are for testing.

Let’s start by importing a few libraries we’ll use in this tutorial.

Let’s also define a helper function to display the sample elements in the dataset using matplotlib.
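The helper’s code is not shown either. A minimal sketch, assuming each dataset element is an (image tensor, label) pair with a single-channel image, could look like this; the name `imshow` and its `shape` parameter are hypothetical choices.

```python
import matplotlib.pyplot as plt
import torch


def imshow(sample_element, shape=(28, 28)):
    """Plot one (image, label) pair taken from a dataset.

    Hypothetical helper: expects a single-channel image tensor whose
    pixels can be reshaped to `shape` (28x28 for Fashion-MNIST).
    """
    image, label = sample_element
    fig, ax = plt.subplots()
    # Drop the channel dimension so matplotlib receives a 2D array
    ax.imshow(image.numpy().reshape(shape), cmap="gray")
    ax.set_title(f"Label = {label}")
    ax.axis("off")
    plt.show()
    return ax
```

Once the dataset is loaded, `imshow(dataset[0])` plots its first element; for 16×16 crops, pass `shape=(16, 16)`.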

Now, we’ll load the Fashion-MNIST dataset using the FashionMNIST class from torchvision.datasets. Its constructor takes several arguments:

• root: the path where the data is (or will be) stored.
• train: whether to load the training split or the test split. We’ll set it to False, since we don’t need the data for training yet.
• download: if True, the data is downloaded from the internet when it is not already present under root.
• transform: the transform (or composition of transforms) to apply to each image in the dataset.

Let’s check the class names in the Fashion-MNIST dataset, which are stored in the dataset’s `classes` attribute, along with the mapping from each class name to its integer label, which is stored in `class_to_idx`.

Here is how we can visualize the first element of the dataset with its corresponding label using the helper function defined above.

First element of the Fashion MNIST dataset

## Applying Torchvision Transforms on Image Datasets

In many cases, we’ll have to apply several transforms before feeding images to a neural network. For instance, we often need to randomly crop the images (the RandomCrop transform) for data augmentation.

PyTorch lets us choose from a wide variety of transforms. Listing the contents of the torchvision.transforms module shows all the available transform functions.

As an example, let’s apply the RandomCrop transform to the Fashion-MNIST images and convert them to tensors. We can use transforms.Compose to combine multiple transforms, as we learned in the previous tutorial.

Printing the shape of the first image in the transformed dataset shows that it has now been cropped to $16\times 16$ pixels. Now, let’s plot the first element of the dataset to see how it has been randomly cropped.

This shows the following image

Cropped image from Fashion MNIST dataset

Putting everything together, the complete code is as follows:

## Building Custom Image Datasets

So far we have been working with prebuilt datasets in PyTorch, but what if we need to build a custom dataset class for our own image data? In the previous tutorial we only gave a brief overview of the components of the Dataset class; here we’ll build a custom image dataset class from scratch.

First, we define the class’s constructor. The `__init__` method runs when the Dataset object is instantiated; it stores the directory where the images and annotations are located, along with any transforms we may want to apply to the dataset later. Here we assume we have some images in a directory structure like the following:

and that the annotations are stored in a CSV file like the following, located in the root directory of the images (i.e., “attface” above):

where the first column of the CSV data is the path to the image and the second column is the label.

Similarly, we define the `__len__` method, which returns the total number of samples in our image dataset, and the `__getitem__` method, which reads and returns the data element at a given index.
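The class itself is not shown. Here is a minimal sketch of what it could look like; the class name `CustomImageDataset`, the parameter names `annotations_file` and `root_dir`, the use of Python’s `csv` module, the absence of a header row in the CSV, and the conversion to grayscale are all assumptions.

```python
import csv
import os

from PIL import Image
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    """Image dataset described by a CSV of (relative image path, label) rows.

    Hypothetical sketch: assumes the CSV has no header row, that paths are
    relative to `root_dir`, and that labels are integers.
    """

    def __init__(self, annotations_file, root_dir, transform=None):
        # Read all annotation rows once, skipping blank lines
        with open(annotations_file, newline="") as f:
            self.annotations = [row for row in csv.reader(f) if row]
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # Total number of samples = number of annotation rows
        return len(self.annotations)

    def __getitem__(self, index):
        rel_path, label = self.annotations[index]
        img_path = os.path.join(self.root_dir, rel_path)
        # Load the file and convert to grayscale ("L"); use "RGB" for color data
        with Image.open(img_path) as img:
            image = img.convert("L")
        if self.transform is not None:
            image = self.transform(image)
        return image, int(label)
```

With the layout assumed in the text, the dataset could be created with `CustomImageDataset("attface/imagedata.csv", "attface")`. Reading the small CSV once in `__init__` while opening each image only in `__getitem__` keeps memory use low, since images are loaded on demand.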

Now, we can create our dataset object and apply the transforms on it. We assume the image data are located under the directory named “attface” and the annotation CSV file is at “attface/imagedata.csv”. Then the dataset is created as follows:

Optionally, you can add the transform function to the dataset as well:

You can use this custom image dataset class with any image dataset stored in a directory and apply whatever transforms your task requires.

## Summary

In this tutorial, you learned how to work with image datasets and transforms in PyTorch. In particular, you learned:

• How to work with preloaded image datasets in PyTorch.
• How to apply torchvision transforms to preloaded datasets.
• How to build a custom image dataset class in PyTorch and apply various transforms to it.