How to Code the GAN Training Algorithm and Loss Functions

The Generative Adversarial Network, or GAN for short, is an architecture for training a generative model.

The architecture is composed of two models: the generator that we are interested in, and a discriminator model that is used to assist in the training of the generator. Initially, both the generator and discriminator models were implemented as Multilayer Perceptrons (MLP), although more recently the models are implemented as deep convolutional neural networks.

It can be challenging to understand how a GAN is trained, and exactly how to interpret and implement the loss functions for the generator and discriminator models.

In this tutorial, you will discover how to implement the generative adversarial network training algorithm and loss functions.

After completing this tutorial, you will know:

  • How to implement the training algorithm for a generative adversarial network.
  • How the loss functions for the discriminator and generator work.
  • How to implement weight updates for the discriminator and generator models in practice.

Let’s get started.

  • Update Jan/2020: Fixed small typo in description of training algorithm.

How to Code the Generative Adversarial Network Training Algorithm and Loss Functions
Photo by Hilary Charlotte, some rights reserved.

Tutorial Overview

This tutorial is divided into three parts; they are:

  1. How to Implement the GAN Training Algorithm
  2. Understanding the GAN Loss Function
  3. How to Train GAN Models in Practice

Note: The code examples in this tutorial are snippets only, not standalone runnable examples. They are designed to help you develop an intuition for the algorithm and they can be used as the starting point for implementing the GAN training algorithm on your own project.

How to Implement the GAN Training Algorithm

The GAN training algorithm involves training both the discriminator and the generator model in parallel.

The algorithm is summarized in the figure below, taken from the original 2014 paper by Goodfellow, et al. titled “Generative Adversarial Networks.”

Summary of the Generative Adversarial Network Training Algorithm. Taken from: Generative Adversarial Networks.

Let’s take some time to unpack and get comfortable with this algorithm.

The outer loop of the algorithm involves iterating over steps to train the models in the architecture. One cycle through this loop is not an epoch: it is a single update made up of specific batch updates to the discriminator and generator models.

An epoch is defined as one cycle through a training dataset, where the samples in a training dataset are used to update the model weights in mini-batches. For example, a training dataset of 100 samples used to train a model with a mini-batch size of 10 samples would involve 10 mini-batch updates per epoch. The model would be fit for a given number of epochs, such as 500.

This is often hidden from you by the automated training of a model through a call to the fit() function, where you specify the number of epochs and the size of each mini-batch.

In the case of the GAN, the number of training iterations must be defined based on the size of your training dataset and batch size. In the case of a dataset with 100 samples, a batch size of 10, and 500 training epochs, we would first calculate the number of batches per epoch and use this to calculate the total number of training iterations using the number of epochs.

For example, in the case of a dataset of 100 samples, a batch size of 10, and 500 epochs, the GAN would be trained for floor(100 / 10) * 500, or 5,000, total iterations.
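As a quick check of this arithmetic, the calculation can be sketched in a few lines of Python. The function name total_iterations is chosen here for illustration only.

```python
def total_iterations(n_samples, batch_size, n_epochs):
    """Number of GAN training iterations for a dataset size, batch size, and epoch count."""
    # integer division drops any final partial batch, matching floor()
    batches_per_epoch = n_samples // batch_size
    return batches_per_epoch * n_epochs

n_steps = total_iterations(n_samples=100, batch_size=10, n_epochs=500)
print(n_steps)  # 5000
```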

Next, we can see that one iteration of training results in possibly multiple updates to the discriminator and one update to the generator, where the number of updates to the discriminator is a hyperparameter (k in the paper) that the authors set to 1 in their experiments.

The training process consists of simultaneous SGD. On each step, two minibatches are sampled: a minibatch of x values from the dataset and a minibatch of z values drawn from the model’s prior over latent variables. Then two gradient steps are made simultaneously …

NIPS 2016 Tutorial: Generative Adversarial Networks, 2016.

We can therefore summarize the training algorithm with Python pseudocode as follows:
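A minimal runnable sketch of this loop is given here. The two update steps are passed in as plain callables because their details come later; the names train_gan, update_discriminator, update_generator, and n_discriminator_updates are placeholders chosen for illustration and are not part of any library.

```python
def train_gan(update_discriminator, update_generator, n_steps, n_discriminator_updates=1):
    """Run the outer GAN training loop for n_steps iterations."""
    for step in range(n_steps):
        # update the discriminator k times per iteration (k defaults to 1)
        for _ in range(n_discriminator_updates):
            update_discriminator()
        # then update the generator once
        update_generator()

# counters stand in for real model updates so the call pattern is visible
counts = {"discriminator": 0, "generator": 0}

def count_discriminator_update():
    counts["discriminator"] += 1

def count_generator_update():
    counts["generator"] += 1

train_gan(count_discriminator_update, count_generator_update, n_steps=5000)
print(counts)  # {'discriminator': 5000, 'generator': 5000}
```

In a real project, the two callables would sample batches and update the models, as described in the steps that follow.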

An alternative approach may involve enumerating the number of training epochs and splitting the training dataset into batches for each epoch.
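One way this alternative might look is sketched here, assuming the dataset is split into contiguous index ranges. The helper name iterate_batches is hypothetical, and in practice you would probably also shuffle the dataset at the start of each epoch.

```python
def iterate_batches(n_samples, batch_size, n_epochs):
    """Yield (epoch, batch, start, end) for every mini-batch in every epoch."""
    batches_per_epoch = n_samples // batch_size
    for epoch in range(n_epochs):
        for batch in range(batches_per_epoch):
            start = batch * batch_size
            yield epoch, batch, start, start + batch_size

# each yielded tuple marks one training iteration
# (one discriminator update step followed by one generator update)
steps = list(iterate_batches(n_samples=100, batch_size=10, n_epochs=2))
print(len(steps))   # 20
print(steps[0])     # (0, 0, 0, 10)
print(steps[-1])    # (1, 9, 90, 100)
```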

Updating the discriminator model involves a few steps.

First, a batch of random points from the latent space must be selected for use as input to the generator model to provide the basis for the generated or ‘fake’ samples. Then a batch of samples from the training dataset must be selected for input to the discriminator as the ‘real’ samples.

Next, the discriminator model must make predictions for the real and fake samples and the weights of the discriminator must be updated in proportion to how correct or incorrect those predictions were. The predictions are probabilities and we will get into the nature of the predictions and the loss function that is minimized in the next section. For now, we can outline what these steps actually look like in practice.

We need a generator and a discriminator model, such as Keras models. These can be provided as arguments to the training function.

Next, we must generate points from the latent space and then use the generator model in its current form to generate some fake images. For example:

Note that the size of the latent dimension is also provided as a hyperparameter to the training algorithm.
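A sketch of these two steps using NumPy follows. It assumes a standard Gaussian latent space and a generator object that exposes a Keras-style predict() method; the StandInGenerator class is a hypothetical placeholder used only so that the snippet runs on its own.

```python
import numpy as np

def generate_latent_points(latent_dim, n_samples, rng):
    """Sample n_samples points from a standard Gaussian latent space (an assumption)."""
    return rng.standard_normal((n_samples, latent_dim))

def generate_fake_samples(generator, latent_dim, n_samples, rng):
    """Use the generator in its current form to create fake samples, labeled class 0."""
    z = generate_latent_points(latent_dim, n_samples, rng)
    X = generator.predict(z)
    y = np.zeros((n_samples, 1))
    return X, y

class StandInGenerator:
    """Hypothetical placeholder: maps latent points to tiny 2-value 'images'."""
    def predict(self, z):
        return np.tanh(z[:, :2])

rng = np.random.default_rng(1)
X_fake, y_fake = generate_fake_samples(StandInGenerator(), latent_dim=100, n_samples=10, rng=rng)
print(X_fake.shape, y_fake.shape)  # (10, 2) (10, 1)
```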

We then must select a batch of real samples, and this too will be wrapped into a function.
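Such a function might look like the following sketch, which assumes the training data is held in a NumPy array with one sample per row. Rows are drawn at random with replacement here, which is a simplification; the toy dataset is made up for illustration.

```python
import numpy as np

def select_real_samples(dataset, n_samples, rng):
    """Pick a random batch of rows from the dataset, labeled class 1."""
    idx = rng.integers(0, dataset.shape[0], size=n_samples)
    X = dataset[idx]
    y = np.ones((n_samples, 1))
    return X, y

# a toy array of 100 samples with 4 features stands in for real training data
dataset = np.arange(400, dtype=float).reshape(100, 4)
rng = np.random.default_rng(1)
X_real, y_real = select_real_samples(dataset, n_samples=10, rng=rng)
print(X_real.shape, y_real.shape)  # (10, 4) (10, 1)
```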

The discriminator model must then make a prediction for each of the generated and real images and the weights must be updated.
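The update itself can be sketched as below, assuming the discriminator exposes a Keras-style train_on_batch(X, y) method that performs one weight update and returns the batch loss. The RecordingModel class is a hypothetical stand-in that only records what it is given, so the snippet can run without a deep learning library.

```python
import numpy as np

def update_discriminator(discriminator, X_real, X_fake):
    """Update the discriminator on one batch of real and one batch of fake samples."""
    y_real = np.ones((len(X_real), 1))    # class 1 for real
    y_fake = np.zeros((len(X_fake), 1))   # class 0 for fake
    loss_real = discriminator.train_on_batch(X_real, y_real)
    loss_fake = discriminator.train_on_batch(X_fake, y_fake)
    return loss_real, loss_fake

class RecordingModel:
    """Hypothetical stand-in that records the batches passed to it."""
    def __init__(self):
        self.calls = []

    def train_on_batch(self, X, y):
        self.calls.append((X.shape, float(y.mean())))
        return 0.0  # a real model would return the loss for this batch

model = RecordingModel()
update_discriminator(model, np.random.rand(10, 4), np.random.rand(10, 4))
print(model.calls)  # [((10, 4), 1.0), ((10, 4), 0.0)]
```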

Next, the generator model must be updated.

Again, a batch of random points from the latent space must be selected and passed to the generator to generate fake images, and then passed to the discriminator to classify.

The response can then be used to update the weights of the generator model.
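A sketch of the generator update follows. It assumes a composite object, called gan_model here, that feeds the generator output into the discriminator and updates only the generator weights when its Keras-style train_on_batch() method is called. The generated samples are labeled as real (class 1); the reason for this is covered in the sections on the loss function. StandInComposite is a hypothetical placeholder so that the snippet runs on its own.

```python
import numpy as np

def update_generator(gan_model, latent_dim, n_samples, rng):
    """Update the generator through a composite model using one batch of latent points."""
    z = rng.standard_normal((n_samples, latent_dim))
    # the generated samples are presented with the 'real' label of 1
    y = np.ones((n_samples, 1))
    return gan_model.train_on_batch(z, y)

class StandInComposite:
    """Hypothetical placeholder that records the batch it is given."""
    def __init__(self):
        self.last_batch = None

    def train_on_batch(self, z, y):
        self.last_batch = (z.shape, y.shape)
        return 0.0  # a real model would return the loss for this batch

gan_model = StandInComposite()
update_generator(gan_model, latent_dim=100, n_samples=10, rng=np.random.default_rng(1))
print(gan_model.last_batch)  # ((10, 100), (10, 1))
```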

It is interesting that the discriminator is updated with two batches of samples each training iteration whereas the generator is only updated with a single batch of samples per training iteration.

Now that we have defined the training algorithm for the GAN, we need to understand how the model weights are updated. This requires understanding the loss function used to train the GAN.

Understanding the GAN Loss Function

The discriminator is trained to correctly classify real and fake images.

This is achieved by maximizing the log of predicted probability of real images and the log of the inverted probability of fake images, averaged over each mini-batch of examples.
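Written out in the notation of the original paper, with m samples per mini-batch, real samples x, latent points z, discriminator D, and generator G, the quantity that the discriminator maximizes is:

```latex
\max_{D} \; \frac{1}{m} \sum_{i=1}^{m} \left[ \log D\left(x^{(i)}\right) + \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right) \right]
```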

Recall that adding log probabilities is equivalent to multiplying probabilities, but without the product vanishing into very small numbers. We can therefore understand this objective as seeking probabilities close to 1.0 for real images and probabilities close to 0.0 for fake images, where the latter are inverted (subtracted from 1.0) so that they too become values close to 1.0. Because the log of a probability is never positive, the sum is at most zero, and larger average values (closer to zero) indicate better performance of the discriminator.

Negating this objective turns it into a minimization problem, which should not be surprising if you are familiar with developing neural networks for binary classification, as this is exactly the approach used.

This is just the standard cross-entropy cost that is minimized when training a standard binary classifier with a sigmoid output. The only difference is that the classifier is trained on two minibatches of data; one coming from the dataset, where the label is 1 for all examples, and one coming from the generator, where the label is 0 for all examples.

NIPS 2016 Tutorial: Generative Adversarial Networks, 2016.
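To make this equivalence concrete, the following sketch computes the binary cross-entropy by hand for a real batch (labels of 1) and a fake batch (labels of 0) and compares the total with the negated discriminator objective. The predicted probabilities are made-up values for illustration, and the helper assumes predictions strictly between 0 and 1.

```python
import math

def binary_cross_entropy(y_true, y_pred):
    """Mean binary cross-entropy over a batch of labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

# made-up discriminator outputs
d_real = [0.9, 0.8, 0.95]   # D(x) for real samples, target label 1
d_fake = [0.1, 0.3, 0.05]   # D(G(z)) for fake samples, target label 0

loss_real = binary_cross_entropy([1, 1, 1], d_real)
loss_fake = binary_cross_entropy([0, 0, 0], d_fake)

# the objective the discriminator maximizes, averaged over the mini-batch
objective = sum(math.log(r) + math.log(1 - f) for r, f in zip(d_real, d_fake)) / len(d_real)

# the two printed values agree up to floating-point rounding
print(loss_real + loss_fake)
print(-objective)
```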

The loss for the generator is trickier.

The GAN algorithm defines the generator model’s loss as minimizing the log of the inverted probability of the discriminator’s prediction of fake images, averaged over a mini-batch.
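In the same notation as the original paper, with m latent points z per mini-batch, this generator loss can be written as:

```latex
\min_{G} \; \frac{1}{m} \sum_{i=1}^{m} \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)
```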

This is straightforward, but according to the authors, it is not effective in practice when the generator is poor and the discriminator is good at rejecting fake images with high confidence. The loss function no longer gives good gradient information that the generator can use to adjust weights and instead saturates.

In this case, log(1 − D(G(z))) saturates. Rather than training G to minimize log(1 − D(G(z))) we can train G to maximize log D(G(z)). This objective function results in the same fixed point of the dynamics of G and D but provides much stronger gradients early in learning.

Generative Adversarial Networks, 2014.
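The saturation can be seen numerically with a small sketch. It assumes the discriminator output is the sigmoid of a raw score a, so that D(G(z)) = sigmoid(a), and compares the gradient of each objective with respect to a. The function names are chosen for illustration.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_minimax(a):
    """Derivative of log(1 - sigmoid(a)) with respect to a (original generator loss)."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """Derivative of log(sigmoid(a)) with respect to a (alternative objective)."""
    return 1.0 - sigmoid(a)

# a = -6 means the discriminator confidently rejects the fake sample
for a in (-6.0, 0.0, 6.0):
    print(a, round(sigmoid(a), 4), round(grad_minimax(a), 4), round(grad_non_saturating(a), 4))
```

When the discriminator confidently rejects a fake (a strongly negative score), the gradient of the original loss is close to zero, whereas the gradient of the alternative objective is close to one, which is the stronger learning signal that the authors describe.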

Instead, the authors recommend maximizing the log of the discriminator’s predicted probability for fake images.

The change is subtle.

In the first case, the generator is trained to minimize the probability of the discriminator being correct. With this change to the loss function, the generator is trained to maximize the probability of the discriminator being incorrect.

In the minimax game, the generator minimizes the log-probability of the discriminator being correct. In this game, the generator maximizes the log probability of the discriminator being mistaken.

NIPS 2016 Tutorial: Generative Adversarial Networks, 2016.

The sign of this loss function can then be inverted to give a familiar minimizing loss function for training the generator. As such, this is sometimes referred to as the -log D trick for training GANs.
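In the notation of the original paper, the maximization recommended by the authors and its sign-inverted minimization form are:

```latex
\max_{G} \; \frac{1}{m} \sum_{i=1}^{m} \log D\left(G\left(z^{(i)}\right)\right)
\quad \Longleftrightarrow \quad
\min_{G} \; -\frac{1}{m} \sum_{i=1}^{m} \log D\left(G\left(z^{(i)}\right)\right)
```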

Our baseline comparison is DCGAN, a GAN with a convolutional architecture trained with the standard GAN procedure using the −log D trick.

Wasserstein GAN, 2017.

Now that we understand the GAN loss function, we can look at how the discriminator and the generator model can be updated in practice.

How to Train GAN Models in Practice

The practical implementation of the GAN loss function and model updates is straightforward.

We will look at examples using the Keras library.

We can implement the discriminator directly by configuring the discriminator model to predict a probability of 1 for real images and 0 for fake images and minimizing the cross-entropy loss, specifically the binary cross-entropy loss.

For example, in a Keras model definition for the discriminator, the output layer would have a single node with a sigmoid activation, and the model would be compiled with the binary_crossentropy loss function.

The defined model can be trained for each batch of real and fake samples providing arrays of 1s and 0s for the expected outcome.

The ones() and zeros() NumPy functions can be used to create these target labels, and the Keras function train_on_batch() can be used to update the model for each batch of samples.

The discriminator model will be trained to predict the probability of “realness” of a given input image that can be interpreted as a class label of class=0 for fake and class=1 for real.

The generator is trained to maximize the probability of “realness” that the discriminator predicts for generated images.

This is achieved by updating the generator via the discriminator with the class label of 1 for the generated images. The discriminator is not updated in this operation but provides the gradient information required to update the weights of the generator model.

For example, if the discriminator predicts a low average probability for the batch of generated images, then this will result in a large error signal propagated backward into the generator given the “expected probability” for the samples was 1.0 for real. This large error signal, in turn, results in relatively large changes to the generator to hopefully improve its ability at generating fake samples on the next batch.
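The size of this error signal can be illustrated with a short sketch. Binary cross-entropy with a target of 1 for every generated sample reduces to the mean of -log D(G(z)); the predicted probabilities below are made-up values, and the function name generator_loss is chosen for illustration.

```python
import math

def generator_loss(d_on_fake):
    """Binary cross-entropy with target 1 for every generated sample: mean of -log D(G(z))."""
    return -sum(math.log(p) for p in d_on_fake) / len(d_on_fake)

fooled_rarely = [0.05, 0.10, 0.02]   # discriminator is confident the samples are fake
fooled_often = [0.90, 0.85, 0.95]    # discriminator mostly believes the samples are real

# the first loss is much larger than the second
print(generator_loss(fooled_rarely))
print(generator_loss(fooled_often))
```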

This can be implemented in Keras by creating a composite model that combines the generator and discriminator models. This allows the output images from the generator to flow directly into the discriminator and, in turn, allows the error signals from the predicted probabilities of the discriminator to flow back through the weights of the generator model.

For example:
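The sketch below illustrates the idea without Keras, using a deliberately tiny setup in NumPy: a linear generator stacked on a logistic-regression discriminator whose weights are held fixed. All names, sizes, and weight values are made up for illustration. In Keras, the freezing is typically achieved by marking the discriminator as not trainable within the composite model.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def generator_step(z, W_g, b_g, w_d, b_d, lr=0.1):
    """One gradient step on the generator weights only, using a target label of 1."""
    x_fake = z @ W_g + b_g                  # generator: latent points -> samples
    p = sigmoid(x_fake @ w_d + b_d)         # frozen discriminator: samples -> probability
    loss = float(np.mean(-np.log(p)))       # binary cross-entropy with target 1
    grad_score = (p - 1.0) / len(z)         # gradient w.r.t. the pre-sigmoid score
    grad_x = grad_score @ w_d.T             # error flows back through the discriminator
    W_g = W_g - lr * (z.T @ grad_x)         # only the generator weights are changed
    b_g = b_g - lr * grad_x.sum(axis=0)
    return W_g, b_g, loss

rng = np.random.default_rng(1)
latent_dim, data_dim, n_batch = 5, 3, 64
W_g = rng.standard_normal((latent_dim, data_dim)) * 0.1
b_g = np.zeros(data_dim)
w_d = np.array([[1.0], [-2.0], [0.5]])      # hypothetical frozen discriminator weights
b_d = -1.0
w_d_before = w_d.copy()

z = rng.standard_normal((n_batch, latent_dim))
losses = []
for _ in range(100):
    W_g, b_g, loss = generator_step(z, W_g, b_g, w_d, b_d)
    losses.append(loss)

print(losses[0], losses[-1])                # the generator loss falls
print(np.array_equal(w_d, w_d_before))      # True: the discriminator is untouched
```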

The composite model can then be updated using fake images and real class labels.

That completes our tour of the GAN training algorithm, loss function, and weight update details for the discriminator and generator models.

Summary

In this tutorial, you discovered how to implement the generative adversarial network training algorithm and loss functions.

Specifically, you learned:

  • How to implement the training algorithm for a generative adversarial network.
  • How the loss functions for the discriminator and generator work.
  • How to implement weight updates for the discriminator and generator models in practice.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.

22 Responses to How to Code the GAN Training Algorithm and Loss Functions

  1. Kate July 12, 2019 at 8:18 am #

    Hey Jason,

    If I’ve worked through the time series book for python as well as Machine Learning Algorithms From Scratch. Do you think I would be able to understand the Develop Generative Adversarial Networks Today book or should I start with one the Deep Learning books first.

    Thanks!

    Kate

  2. Olufikayo Bolu August 10, 2019 at 3:52 pm #

    Hi Jason, thanks for your contributions and explanations. I am considering using GAN for non-image data. How best would you advise I start with the implementation, seeing this example was implemented with Image data. Also could you provide a complete sample code of your implemented example. Thanks in advance

  3. patti February 5, 2020 at 2:53 am #

    Hello Jason, thanks a lot for great tutorials.

    I’m a little confused about generators loss function, so it actually uses binary cross entropy in the train_gan section of code, right?

  4. The Coder heist May 19, 2020 at 5:19 am #

    Hi Jason, thanks for the contribution and explanation. I am trying to develop gan for a CNN LSTM spatio temporal frame prediction algorithm. Please tell me about its feasibility and some references

  5. kevin October 4, 2020 at 2:14 pm #

    Why, when we implement GAN training, do we need to set discriminator.trainable = False?

    • Jason Brownlee October 4, 2020 at 2:59 pm #

      So that we don’t update the discriminator when updating the generator.

      • Zeinab May 10, 2021 at 10:32 pm #

        Hello
        Do you have tutorials about text to image using GAN?

  6. Cara Evangeline March 9, 2021 at 12:23 am #

    Hi Jason, Thanks for this great blog!
    I have one question for GAN time-series preditcion. For discriminator we can declare the loss function as binary_crossentropy, what will be the case for generator network?
    Generator will be predicting the next time-step, so will MAE be the loss function or binary_entropy?

    • Jason Brownlee March 9, 2021 at 5:21 am #

      Sorry, I don’t know about GANs for time series, only image data.

  7. MIMI May 19, 2021 at 1:15 am #

    HELLO
    why some use HOG in GAN?
    if there is anyway to communicate i have lots to ask you

  8. Irshad May 19, 2021 at 10:34 pm #

    Thank you, Jason. Great post.

    Consider an example, where 3-dimensional data needs to be generated. As I understand from this article, the Discriminator model’s loss is considered to make the changes in the weights of the Generator model. But how does the Generator knows in which of the dimensions, it has to do the changes in weights. Does it do it for all the dimensions same changes in weights?

    • Jason Brownlee May 20, 2021 at 5:47 am #

      Generally GANs are used for generating image data, I don’t know about generating other data types sorry.

  9. prachi November 27, 2021 at 1:50 am #

    sir i am trying to create a synthetic data set for text detection from scenic images .using gans
    can you provide which technique will be used

    • Adrian Tam November 29, 2021 at 8:39 am #

      What did you tried? There are many different ways to do synthetic data and therefore, various techniques are related. If you first come up with an idea of how to synthesis, it would be easier to talk about the techniques.
