Building a Softmax Classifier for Images in PyTorch

Last Updated on April 8, 2023

A softmax classifier is a type of classifier used in supervised learning. It is an important building block in deep learning networks and a popular choice among deep learning practitioners.

A softmax classifier is suitable for multiclass classification: it outputs a probability for each of the classes.
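To make this concrete, here is a minimal sketch of the softmax function, which maps a vector of raw scores $z$ to probabilities via $\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$. The scores below are made-up values for a three-class problem, not output from any model in this tutorial.

```python
import torch

# Hypothetical raw scores (logits) for one sample and three classes
logits = torch.tensor([2.0, 1.0, 0.1])

# Manual softmax: exponentiate, then normalize so the values sum to 1
manual = torch.exp(logits) / torch.exp(logits).sum()

# PyTorch's built-in softmax computes the same thing
builtin = torch.softmax(logits, dim=0)

print(manual)
print(builtin)
print(builtin.sum())  # the probabilities sum to 1 (up to rounding)
```

The class with the largest raw score also gets the largest probability, so the predicted class is simply the position of the maximum.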

This tutorial will teach you how to build a softmax classifier for image data. You will learn how to prepare the dataset and then how to implement a softmax classifier using PyTorch. In particular, you’ll learn:

  • About the Fashion-MNIST dataset.
  • How you can use a softmax classifier for images in PyTorch.
  • How to build and train a multiclass image classifier in PyTorch.
  • How to plot the results after model training.

Let’s get started.


This tutorial is in three parts; they are

    • Preparing the Dataset
    • Build the Model
    • Train the Model

Preparing the Dataset

The dataset you will use here is Fashion-MNIST. It is a pre-processed and well-organized dataset consisting of 70,000 images, with 60,000 images for training data and 10,000 images for testing data.

Each example in the dataset is a $28\times 28$-pixel grayscale image, for a total of 784 pixels. The dataset has 10 classes, and each image is labelled as a fashion item with an integer label from 0 through 9.
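Because the one-layer model used later in this tutorial takes a flat vector as input, each $28\times 28$ image has to be reshaped into 784 values before it reaches the model. A minimal sketch, using a random tensor as a stand-in for a batch of images (the batch size of 16 is only for illustration):

```python
import torch

# Stand-in for a batch of 16 grayscale images: (batch, channel, height, width)
images = torch.rand(16, 1, 28, 28)

# Flatten each image into a vector of 28 * 28 = 784 pixel values
flat = images.view(-1, 28 * 28)

print(flat.shape)  # torch.Size([16, 784])
```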

This dataset can be loaded from torchvision. To make the training faster, we limit the dataset to 4000 samples:

The first time you fetch the Fashion-MNIST dataset, you will see PyTorch download it from the Internet and save it to a local directory named data.

The dataset train_data above is a list of tuples, in which each tuple holds an image (as a Python Imaging Library object) and an integer label.

Let’s plot the first 10 images in the dataset with matplotlib.
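A possible sketch of such a plot is shown here. The function name `plot_first_images` and the 2-by-5 grid layout are assumptions made for illustration; the function expects each image to be a PIL image or a 2D array, which is what `imshow` can draw directly.

```python
import matplotlib.pyplot as plt

def plot_first_images(dataset, n=10, cols=5):
    """Draw the first n (image, label) pairs of a dataset in a grid."""
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")  # hide ticks on every cell, used or not
    for i, ax in zip(range(n), axes.ravel()):
        image, label = dataset[i]
        ax.imshow(image, cmap="gray")
        ax.set_title(str(label))  # integer class label as the title
    return fig
```

With the dataset loaded, `plot_first_images(train_data)` followed by `plt.show()` would display the grid.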

You should see an image like the following:

PyTorch needs the dataset as PyTorch tensors. Hence you will convert the data by applying the ToTensor() transform from torchvision.transforms. This conversion can be done transparently through torchvision’s dataset API:

Before proceeding to the model, let’s also split the data into training and validation sets so that the first 3500 images form the training set and the remaining 500 are used for validation. Normally you would shuffle the data before the split, but this step is skipped here to keep the code concise.
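The split itself is plain list slicing. In the sketch below, random tensors stand in for the 4000 real samples, and the variable names `train_data` and `val_data` are assumptions.

```python
import torch

# Stand-in for the 4000-sample dataset: random "images" and random labels
images = torch.rand(4000, 1, 28, 28)
labels = torch.randint(0, 10, (4000,))
dataset = list(zip(images, labels))  # list of (image, label) tuples

# First 3500 samples for training, remaining 500 for validation
train_data = dataset[:3500]
val_data = dataset[3500:]

print(len(train_data), len(val_data))  # 3500 500
```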

Build the Model

In order to build a custom softmax module for image classification, we’ll use nn.Module from the PyTorch library. To keep things simple, we build a model of just one layer.
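A one-layer model of this kind might look like the following sketch. The class name `Softmax` and the attribute name `linear` are assumptions made for illustration.

```python
import torch
from torch import nn

class Softmax(nn.Module):
    """One linear layer mapping a flattened image to one score per class."""

    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.linear = nn.Linear(n_inputs, n_outputs)

    def forward(self, x):
        # Return raw scores (logits); nn.CrossEntropyLoss applies log-softmax itself
        return self.linear(x)
```

Note that the forward pass returns raw scores rather than probabilities. That pairs correctly with `nn.CrossEntropyLoss`, which expects unnormalized scores; `torch.softmax(scores, dim=1)` turns them into probabilities when those are needed.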

Now, let’s instantiate our model object. It takes a one-dimensional vector of 784 pixel values as input and outputs a score for each of the 10 classes. Let’s also check how the parameters are initialized.
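To see which parameters such a model holds, the sketch below inspects a single `nn.Linear` layer with 784 inputs and 10 outputs, on the assumption that this is the only layer inside the model.

```python
from torch import nn

# The single layer behind the model: 784 input pixels, 10 output classes
layer = nn.Linear(28 * 28, 10)

# Print the name and shape of every parameter
for name, tensor in layer.state_dict().items():
    print(name, tuple(tensor.shape))
# weight (10, 784)
# bias (10,)
```

The values differ from run to run because of random initialization, but the shapes stay the same.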

You should see that the model’s weights are randomly initialized, but their shapes should look like the following:

Train the Model

You will use stochastic gradient descent for model training along with cross-entropy loss. Let’s fix the learning rate at 0.01. To help training, let’s also load the data into a dataloader for both training and validation sets, and set the batch size at 16.
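This setup could be sketched as follows. The helper name `make_training_setup` is hypothetical, and shuffling the training batches is a choice made for this sketch rather than something the text specifies.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

def make_training_setup(model, train_set, val_set, lr=0.01, batch_size=16):
    """Create the optimizer, the loss function, and one loader per data split."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # shuffle=True is an assumption of this sketch, not taken from the text
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    return optimizer, criterion, train_loader, val_loader
```

Each loader yields batches of up to 16 `(images, labels)` pairs, with the images stacked into one tensor.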

Now, let’s put everything together and train our model for 200 epochs.
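A training loop along these lines is sketched below. The function name `train`, its arguments, and the choice to record the average training loss and the validation accuracy per epoch are assumptions; the model is expected to return raw class scores.

```python
import torch

def train(model, train_loader, val_loader, optimizer, criterion,
          epochs=200, log_every=10):
    """Train `model`; return per-epoch training loss and validation accuracy (%)."""
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for x, y in train_loader:
            optimizer.zero_grad()
            scores = model(x.view(x.size(0), -1))  # flatten each image to a vector
            loss = criterion(scores, y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        loss_history.append(epoch_loss / len(train_loader))

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                predicted = model(x.view(x.size(0), -1)).argmax(dim=1)
                correct += (predicted == y).sum().item()
                total += y.size(0)
        acc_history.append(100 * correct / total)

        if epoch % log_every == 0:
            print(f"Epoch {epoch}: loss {loss_history[-1]:.4f}, "
                  f"accuracy {acc_history[-1]:.2f}%")
    return loss_history, acc_history
```

The two returned lists hold one value per epoch, which is the form needed for plotting the training progress afterwards.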

You should see the progress printed once every 10 epochs:

As you can see, the accuracy of the model generally increases and its loss decreases as training progresses. Here, the accuracy achieved by the softmax image classifier is around 85 percent. Using more data and training for more epochs may improve the accuracy further. Now let’s see what the plots for loss and accuracy look like.

First the loss plot:
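A small plotting helper for per-epoch values could look like the sketch below. The function name `plot_history` is hypothetical, and the loss values are made up for illustration; they are not results from this tutorial.

```python
import matplotlib.pyplot as plt

def plot_history(values, ylabel, title):
    """Plot one value per epoch and return the figure."""
    fig, ax = plt.subplots()
    ax.plot(range(len(values)), values)
    ax.set_xlabel("epochs")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    return fig

# Made-up loss values for illustration only
example_loss = [2.1, 1.4, 1.0, 0.8, 0.7]
fig = plot_history(example_loss, "loss", "Training loss")
```

Calling `plt.show()` displays the figure. The same helper works for the accuracy values by passing a different list and label.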

which should look like the following:

Here is the model accuracy plot:

which is like the one below:

Putting everything together, the following is the complete code:


In this tutorial, you learned how to build a softmax classifier for image data. In particular, you learned:

  • About the Fashion-MNIST dataset.
  • How you can use a softmax classifier for images in PyTorch.
  • How to build and train a multiclass image classifier in PyTorch.
  • How to plot the results after model training.
