How to Implement Wasserstein Loss for Generative Adversarial Networks

The Wasserstein Generative Adversarial Network, or Wasserstein GAN, is an extension to the generative adversarial network that both improves the stability when training the model and provides a loss function that correlates with the quality of generated images.

It is an important extension to the GAN model and requires a conceptual shift away from a discriminator that predicts the probability of a generated image being “real” and toward the idea of a critic model that scores the “realness” of a given image.

This conceptual shift is motivated mathematically by the earth mover distance, or Wasserstein distance, which measures the distance between the data distribution observed in the training dataset and the distribution observed in the generated examples, and which the GAN is trained to minimize.

In this post, you will discover how to implement Wasserstein loss for Generative Adversarial Networks.

After reading this post, you will know:

  • The conceptual shift in the WGAN from discriminator predicting a probability to a critic predicting a score.
  • The implementation details for the WGAN as minor changes to the standard deep convolutional GAN.
  • The intuition behind the Wasserstein loss function and how to implement it from scratch.

Let’s get started.

Overview

This tutorial is divided into five parts; they are:

  1. GAN Stability and the Discriminator
  2. What Is a Wasserstein GAN?
  3. Implementation Details of the Wasserstein GAN
  4. How to Implement Wasserstein Loss
  5. Common Point of Confusion With Expected Labels

GAN Stability and the Discriminator

Generative Adversarial Networks, or GANs, are challenging to train.

The discriminator model must classify a given input image as real (from the dataset) or fake (generated), and the generator model must generate new and plausible images.

The reason GANs are difficult to train is that the architecture involves the simultaneous training of a generator and a discriminator model in a zero-sum game. Stable training requires finding and maintaining an equilibrium between the capabilities of the two models.

The discriminator model is a neural network that learns a binary classification problem, using a sigmoid activation function in the output layer, and is fit using a binary cross entropy loss function. As such, the model predicts the probability that a given input is real (or fake, as 1 minus the predicted probability) as a value between 0 and 1.

The loss function has the effect of penalizing the model proportionally to how far the predicted probability distribution differs from the expected probability distribution for a given image. This provides the basis for the error that is back propagated through the discriminator and the generator in order to perform better on the next batch.
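As a rough illustration of that penalty, the following NumPy sketch computes binary cross entropy for a handful of made-up discriminator outputs. The function name and the sample probabilities are illustrative assumptions, not code from this tutorial.

```python
import numpy as np

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross entropy between 0/1 labels and predicted probabilities."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip so that log(0) is never evaluated.
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0 - eps)
    losses = y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob)
    return float(-np.mean(losses))

# Hypothetical outputs for two real (label 1) and two fake (label 0) images.
labels = [1, 1, 0, 0]
confident = binary_cross_entropy(labels, [0.9, 0.8, 0.1, 0.2])
unsure = binary_cross_entropy(labels, [0.6, 0.5, 0.5, 0.4])
print(confident, unsure)  # the confident, correct predictions give the smaller loss
```

Predictions that sit closer to the expected label give a smaller loss, and that difference is the error signal that is propagated back through the models.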

The WGAN relaxes the role of the discriminator when training a GAN and proposes the alternative of a critic.

What Is a Wasserstein GAN?

The Wasserstein GAN, or WGAN for short, was introduced by Martin Arjovsky, et al. in their 2017 paper titled “Wasserstein GAN.”

It is an extension of the GAN that seeks an alternate way of training the generator model to better approximate the distribution of data observed in a given training dataset.

Instead of using a discriminator to classify or predict the probability of generated images as being real or fake, the WGAN changes or replaces the discriminator model with a critic that scores the realness or fakeness of a given image.

This change is motivated by a mathematical argument that training the generator should seek a minimization of the distance between the distribution of the data observed in the training dataset and the distribution observed in generated examples. The argument contrasts different distribution distance measures, such as Kullback-Leibler (KL) divergence, Jensen-Shannon (JS) divergence, and the Earth-Mover (EM) distance, referred to as Wasserstein distance.

The most fundamental difference between such distances is their impact on the convergence of sequences of probability distributions.

Wasserstein GAN, 2017.

They demonstrate that a critic neural network can be trained to approximate the Wasserstein distance, and, in turn, used to effectively train a generator model.
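For reference, the quantity that the critic approximates can be written out as below. This is the Earth-Mover distance and its Kantorovich-Rubinstein dual form as used in the paper, where P_r is the data distribution, P_θ is the distribution of generated examples, Π(P_r, P_θ) is the set of joint distributions with those marginals, and f ranges over 1-Lipschitz functions.

```latex
W(P_r, P_\theta)
  = \inf_{\gamma \in \Pi(P_r, P_\theta)} \mathbb{E}_{(x, y) \sim \gamma}\big[\lVert x - y \rVert\big]
  = \sup_{\lVert f \rVert_L \le 1} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_\theta}[f(x)]
```

In practice, the supremum is replaced by a maximization over a parameterized critic f_w whose weights are kept in a compact set, which yields the distance up to a constant factor.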

… we define a form of GAN called Wasserstein-GAN that minimizes a reasonable and efficient approximation of the EM distance, and we theoretically show that the corresponding optimization problem is sound.

Wasserstein GAN, 2017.

Importantly, the Wasserstein distance is continuous and differentiable almost everywhere, so the critic continues to provide a useful gradient for the generator even after the critic is well trained.

The fact that the EM distance is continuous and differentiable a.e. means that we can (and should) train the critic till optimality. […] the more we train the critic, the more reliable gradient of the Wasserstein we get, which is actually useful by the fact that Wasserstein is differentiable almost everywhere.

Wasserstein GAN, 2017.

This is unlike the discriminator model that, once trained, may fail to provide useful gradient information for updating the generator model.

The discriminator learns very quickly to distinguish between fake and real, and as expected provides no reliable gradient information. The critic, however, can’t saturate, and converges to a linear function that gives remarkably clean gradients everywhere.

Wasserstein GAN, 2017.
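One simplified way to picture this saturation is to compare the derivative of a sigmoid output with that of a linear output. The sketch below looks only at the output activation and ignores the rest of the network, so it is an intuition aid rather than a reproduction of the paper's experiment; the function names are assumptions.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def sigmoid_grad(a):
    """Derivative of the sigmoid output with respect to its input a."""
    s = sigmoid(a)
    return s * (1.0 - s)

# Pre-activation values for increasingly confident "fake" decisions.
for a in [0.0, -5.0, -10.0]:
    # A linear output layer has derivative 1.0 regardless of a.
    print(a, sigmoid_grad(a), 1.0)
```

As the discriminator becomes more confident, the sigmoid derivative shrinks toward zero, whereas a linear critic output passes the gradient through unchanged.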

The benefit of the WGAN is that the training process is more stable and less sensitive to model architecture and choice of hyperparameter configurations.

… training WGANs does not require maintaining a careful balance in training of the discriminator and the generator, and does not require a careful design of the network architecture either. The mode dropping phenomenon that is typical in GANs is also drastically reduced.

Wasserstein GAN, 2017.

Perhaps most importantly, the loss of the critic appears to relate to the quality of images created by the generator.

Specifically, the lower the loss of the critic when evaluating generated images, the higher the expected quality of the generated images. This is important as unlike other GANs that seek stability in terms of finding an equilibrium between two models, the WGAN seeks convergence, lowering generator loss.

To our knowledge, this is the first time in GAN literature that such a property is shown, where the loss of the GAN shows properties of convergence. This property is extremely useful when doing research in adversarial networks as one does not need to stare at the generated samples to figure out failure modes and to gain information on which models are doing better over others.

Wasserstein GAN, 2017.

Implementation Details of the Wasserstein GAN

Although the theoretical grounding for the WGAN is dense, implementing a WGAN requires only a few minor changes to the standard deep convolutional GAN, or DCGAN.

Those changes are as follows:

  • Use a linear activation function in the output layer of the critic model (instead of sigmoid).
  • Use Wasserstein loss to train the critic and generator models, which promotes a larger difference between scores for real and generated images.
  • Constrain critic model weights to a limited range after each mini batch update (e.g. [-0.01,0.01]).

In order to have parameters w lie in a compact space, something simple we can do is clamp the weights to a fixed box (say W = [−0.01, 0.01]l ) after each gradient update.

Wasserstein GAN, 2017.

  • Update the critic model more times than the generator each iteration (e.g. 5).
  • Use the RMSProp version of gradient descent with small learning rate and no momentum (e.g. 0.00005).

… we report that WGAN training becomes unstable at times when one uses a momentum based optimizer such as Adam […] We therefore switched to RMSProp …

Wasserstein GAN, 2017.

The image below provides a summary of the main training loop for training a WGAN, taken from the paper. Note the listing of recommended hyperparameters used in the model.

Algorithm for the Wasserstein Generative Adversarial Networks.
Taken from: Wasserstein GAN.
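The mechanics listed in this section (several critic updates per generator update, weight clipping after each critic update, and opposite update directions for the two models) can be sketched on a toy one-dimensional problem. This is a minimal illustration under stated assumptions, not the paper's implementation: the critic is a single weight, the generator is a single offset, the data and latent samples are fixed grids, plain gradient steps stand in for RMSProp, and the learning rates are much larger than the 0.00005 suggested in the paper so that the toy finishes quickly. It follows the paper's sign convention (the critic ascends, the generator descends), and all names are hypothetical.

```python
import numpy as np

def train_toy_wgan(n_iterations=400, n_critic=5, clip=0.01,
                   critic_lr=0.005, generator_lr=1.0):
    """Toy 1-D WGAN: critic f_w(x) = w * x, generator g(z) = z + theta."""
    real = np.linspace(2.0, 4.0, 64)   # fixed "real" samples with mean 3
    z = np.linspace(-1.0, 1.0, 64)     # fixed latent samples with mean 0
    w, theta = 0.0, 0.0
    for _iteration in range(n_iterations):
        # Update the critic more often than the generator.
        for _step in range(n_critic):
            fake = z + theta
            # Gradient of mean(f_w(real)) - mean(f_w(fake)) with respect to w.
            grad_w = real.mean() - fake.mean()
            w += critic_lr * grad_w              # ascent step for the critic
            w = float(np.clip(w, -clip, clip))   # clamp the weight after each update
        # Gradient of -mean(f_w(g(z))) with respect to theta is -w.
        grad_theta = -w
        theta -= generator_lr * grad_theta       # descent step for the generator
    return theta, w

theta, w = train_toy_wgan()
print(theta, w)  # theta should end up close to the real mean of 3
```

The generator offset drifts toward the mean of the real samples and then hovers around it, while clipping keeps the critic weight inside [-0.01, 0.01] throughout.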

How to Implement Wasserstein Loss

The Wasserstein loss function seeks to increase the gap between the scores for real and generated images.

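The paper describes the critic as maximizing the gap between its average score on real images and its average score on generated images (gradient ascent), so that real images receive larger scores. This tutorial uses gradient descent throughout and the mirrored convention, in which the critic assigns small scores to real images and large scores to fake images. Under that convention, the two losses to minimize are:

  • Critic Loss = [average critic score on real images] – [average critic score on fake images]
  • Generator Loss = [average critic score on fake images]

Where the average scores are calculated across a mini-batch of samples.

Negating every critic score recovers the form given in the paper, so the two conventions are equivalent. Up to this choice of sign, this is how the loss is implemented in graph-based deep learning frameworks such as PyTorch and TensorFlow.

The calculations are straightforward to interpret once we recall that stochastic gradient descent seeks to minimize loss.

In the case of the generator, a smaller score from the critic results in a smaller loss for the generator, encouraging the generator to produce images that the critic scores as real. For example, an average critic score of 50 on generated images gives a generator loss of 50, whereas an average score of 10 gives a loss of 10, which is smaller and therefore better.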

In the case of the critic, a larger score for real images results in a larger loss for the critic, penalizing the model. This encourages the critic to output smaller scores for real images. For example, an average score of 20 for real images and 50 for fake images results in a loss of -30; an average score of 10 for real images and 50 for fake images results in a loss of -40, which is better, and so on.
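The arithmetic in this example can be checked with a few lines of NumPy. The helper name below is an assumption made for illustration; it simply subtracts the average score on fake images from the average score on real images.

```python
import numpy as np

def critic_loss(real_scores, fake_scores):
    """Average critic score on real images minus average score on fake images."""
    return float(np.mean(real_scores) - np.mean(fake_scores))

print(critic_loss([20.0, 20.0], [50.0, 50.0]))  # -30.0
print(critic_loss([10.0, 10.0], [50.0, 50.0]))  # -40.0
```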

The sign of the loss does not matter in this case, as long as the score for real images is a small number and the score for fake images is a large number. The Wasserstein loss encourages the critic to separate these numbers.

We can also reverse the situation and encourage the critic to output a large score for real images and a small score for fake images and achieve the same result. Some implementations make this change.

In the Keras deep learning library (and some others), it is not straightforward to implement the Wasserstein loss function in this combined form, because a standard Keras loss function receives only the expected labels and the predictions for a single batch. Instead, we can achieve the same effect by updating the critic separately on a batch of real images and a batch of fake images, so that the loss for each batch is calculated independently of the other.

A good way to think about this is a negative score for real images and a positive score for fake images, although this negative/positive split of scores learned during training is not required; just larger and smaller is sufficient.

  • Small Critic Score (e.g. < 0): Real
  • Large Critic Score (e.g. > 0): Fake

We can multiply the average predicted score by -1 in the case of fake images so that larger averages become smaller averages and the gradient is in the correct direction, i.e. minimizing loss. For example, average scores on fake images of [0.5, 0.8, and 1.0] across three batches of fake images would become [-0.5, -0.8, and -1.0] when calculating weight updates.

  • Loss For Fake Images = -1 * Average Critic Score

No change is needed for the case of real scores, as we want to encourage smaller average scores for real images.

  • Loss For Real Images = Average Critic Score

This can be implemented consistently by assigning an expected outcome target of -1 for fake images and 1 for real images and implementing the loss function as the expected label multiplied by the average score. The -1 label will be multiplied by the average score for fake images and encourage a larger predicted average, and the +1 label will be multiplied by the average score for real images and have no effect, encouraging a smaller predicted average.

  • Wasserstein Loss = Label * Average Critic Score

Or

  • Wasserstein Loss(Real Images) = 1 * Average Predicted Score
  • Wasserstein Loss(Fake Images) = -1 * Average Predicted Score

We can implement this in Keras by assigning the expected labels of -1 and 1 for fake and real images respectively. The inverse labels could be used to the same effect, e.g. -1 for real and +1 for fake to encourage small scores for fake images and large scores for real images. Some developers do implement the WGAN in this alternate way, which is just as correct.

The loss function can be implemented by multiplying the expected label for each sample by the predicted score (element wise), then calculating the mean.
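A minimal NumPy sketch of that calculation is shown here, using the labels described above (+1 for real, -1 for fake). The function name and the critic scores are illustrative assumptions.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    """Mean of (expected label * predicted score) across the batch."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(y_true * y_pred))

# Hypothetical critic scores for one batch of real and one batch of fake images.
real_batch_loss = wasserstein_loss([1, 1, 1], [0.2, -0.4, 0.5])
fake_batch_loss = wasserstein_loss([-1, -1, -1], [0.5, 0.8, 1.0])
print(real_batch_loss, fake_batch_loss)
```

For the real batch the loss equals the average score, and for the fake batch it equals the negated average score, matching the two cases listed earlier.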

The above function is the elegant way to implement the loss function; an alternative, less-elegant implementation that might be more intuitive is as follows:
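One form that matches this description is to average the labels and the scores separately and then multiply the two averages. This is offered as an assumption about what such an alternative could look like, not as the author's listing, and the function name is hypothetical.

```python
import numpy as np

def wasserstein_loss_alt(y_true, y_pred):
    """Mean label multiplied by mean predicted score."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(y_true) * np.mean(y_pred))

# All labels in the batch are the same, as when the critic is updated on
# real and fake images in separate batches.
fake_loss = wasserstein_loss_alt([-1, -1, -1], [0.5, 0.8, 1.0])
# A mixed batch: the mean label is 0, so the result is 0 regardless of the scores.
mixed_loss = wasserstein_loss_alt([1, -1], [2.0, 4.0])
print(fake_loss, mixed_loss)
```

The two forms agree only when every label in the batch is identical. For a batch that mixes real and fake images, the element-wise version is the one that still separates the two groups.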

In Keras, the mean function can be implemented using the Keras backend API to ensure the mean is calculated across samples in the provided tensors; for example:

Now that we know how to implement the Wasserstein loss function in Keras, let’s clarify one common point of misunderstanding.

Common Point of Confusion With Expected Labels

Recall we are using the expected labels of -1 for fake images and +1 for real images.

A common point of confusion is that a perfect critic model will output -1 for every fake image and +1 for every real image.

This is incorrect.

Again, recall we are using stochastic gradient descent to find the set of weights in the critic (and generator) models that minimize the loss function.

We have established that we want the critic model to output larger scores on average for fake images and smaller scores on average for real images. We then designed a loss function to encourage this outcome.

This is the key point about loss functions used to train neural network models. They encourage a desired model behavior, and they do not have to achieve this by providing the expected outcomes. In this case, we defined our Wasserstein loss function to interpret the average score predicted by the critic model and used labels for the real and fake cases to help with this interpretation.

So what is a good loss for real and fake images under Wasserstein loss?

Wasserstein loss is not an absolute measure that can be compared across GAN models. Instead, it is relative and depends on your model configuration and dataset. What is important is that it is consistent for a given critic model, and that convergence of the generator (a better loss) does correlate with better generated image quality.

It could be negative scores for real images and positive scores for fake images, but this is not required. All scores could be positive or all scores could be negative.

The loss function only encourages a separation between scores for fake and real images as larger and smaller, not necessarily positive and negative.
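To make this concrete, the sketch below adds a constant to every critic score and shows that the combined critic loss does not change. It assumes the label convention used in this tutorial (+1 for real, -1 for fake) and sums the loss from one real batch and one fake batch; the function names and scores are hypothetical.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    """Mean of (expected label * predicted score) across the batch."""
    return float(np.mean(np.asarray(y_true, dtype=float) * np.asarray(y_pred, dtype=float)))

def total_critic_loss(real_scores, fake_scores):
    """Loss on a real batch (label +1) plus loss on a fake batch (label -1)."""
    real_loss = wasserstein_loss(np.ones(len(real_scores)), real_scores)
    fake_loss = wasserstein_loss(-np.ones(len(fake_scores)), fake_scores)
    return real_loss + fake_loss

real = np.array([-2.0, -1.0, -3.0])   # small scores for real images
fake = np.array([4.0, 5.0, 3.0])      # large scores for fake images
for shift in [0.0, 10.0, -10.0]:
    # Shifting all scores makes them all positive or all negative.
    print(shift, total_critic_loss(real + shift, fake + shift))
```

Each line prints the same loss, because only the gap between the average real score and the average fake score enters the calculation.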

Summary

In this post, you discovered how to implement Wasserstein loss for Generative Adversarial Networks.

Specifically, you learned:

  • The conceptual shift in the WGAN from discriminator predicting a probability to a critic predicting a score.
  • The implementation details for the WGAN as minor changes to the standard deep convolutional GAN.
  • The intuition behind the Wasserstein loss function and how to implement it from scratch.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.

36 Responses to How to Implement Wasserstein Loss for Generative Adversarial Networks

  1. Avatar
    Oleg July 17, 2019 at 3:40 am #

    Are you have full example of implimentation WGAN based on, for example, https://github.com/eriklindernoren/Keras-GAN/blob/master/dcgan/dcgan.py

  2. Avatar
    Joseph July 19, 2019 at 5:30 am #

    I’m trying to understand… How we using -1 for true, and -1 for false calculations. Why this can’t be more better defined. Math wise.

  3. Avatar
    Vincent January 31, 2020 at 6:32 am #

    Correct me if I’m wrong but it looks like the paper does gradient acsent instead of descent in line 6 of Algrithm 1. They’re seeking to maximize the critic loss and minimize generator loss.

    • Avatar
      Jason Brownlee January 31, 2020 at 7:59 am #

      Yes, gradient descent.

    • Avatar
      François October 1, 2020 at 1:35 am #

      You’re right, Vincent: it’s “+alpha” (ascent) for the discriminator and “-alpha” (descent) for the generator.

      This is how the adversarial training is implemented: we have ” -fw(fake)” (with the same sign) in both losses, but the directions of the gradient update are inverted.

      So this part only concerns the generator:
      > The calculations are straightforward to interpret once we recall that stochastic gradient descent seeks to minimize loss.

      By looking at the equations:
      – the discriminator wants to maximize -fw(fake) minimize fw(fake)
      – the generator wants to minimize -fw(fake)= maximize fw(fake)
      so it seems fw is a score for the realness of an image: a bigger score is equivalent to “realer”.

      So the explanation for the generator is correct:
      > In the case of the generator, a larger score from the critic will result in a smaller loss for the generator, encouraging the critic to output larger scores for fake images. For example, an average score of 10 becomes -10, an average score of 50 becomes -50, which is smaller, and so on.

      Though I would rephrase it this way to make it more clear:
      > In the case of the generator, a larger score from the critic will result in a smaller loss for the generator, encouraging the generator to synthesize images with a high score (meaning realistic images).

      And the explanation for the discriminator is inverted:
      > In the case of the critic, a larger score for real images results in a larger resulting loss for the critic, ***penalizing the model. This encourages the critic to output smaller scores for real images.*** For example, an average score of 20 for real images and 50 for fake images results in a loss of -30; an average score of 10 for real images and 50 for fake images results in a loss of -40, ***which is better***, and so on. The sign of the loss does not matter in this case, as long as ***loss for real images is a small number and the loss for fake images is a large number***. The Wasserstein loss encourages the critic to separate these numbers.

      I hope I make sense…

      • Avatar
        Jason Brownlee October 1, 2020 at 6:30 am #

        Thanks for sharing!

      • Avatar
        Vincent Roca November 19, 2020 at 8:52 pm #

        Thanks for this remark. Without it I would have thought about the absence of adversity between generator and critic for many hours.

      • Avatar
        Andreas September 14, 2021 at 4:13 pm #

        If more realistic images get a higher score, then why would real image have score 20 and fake have score 50.

  4. Avatar
    Hashem Hashemi June 17, 2020 at 12:12 pm #

    So the whole of WGAN can be summed up in (a) set targets to -1/+1 instead of 0/1 and (b) clip the discriminator weights? Why do a lot of these ML ideas feel like the authors did a ton of experiments, found some tweak that happens to work a bit better, then set about justifying it with obscure math, often named after German scientists? I mean the whole earth-mover explanation seems kid of hand-wavy. How does any of this help avoid mode collapse — which appears to be due to the generator getting desensitized to the latent space and relying on backprop to switch high-level features on/off instead. I’m not seeing how WGAN helps avoid that in any way.

    • Avatar
      Hashem Hashemi June 17, 2020 at 12:13 pm #

      Thanks for the pretty descent walkthrough, by the way. 🙂 Most readable I’ve seen.

    • Avatar
      Jason Brownlee June 17, 2020 at 1:42 pm #

      That is the majority of science: small tweaks! 🙂

      Why does it work!? We cannot answer that well for many things at all.

      I don’t even know why my car engine works. It may have a computer in it.

  5. Avatar
    Parnian October 18, 2020 at 12:56 pm #

    Hi Jason. Thanks for the tutorial. I just do not understand why the original loss version cannot be implemented in Keras.

  6. Avatar
    ali November 1, 2020 at 8:57 am #

    Hi Jason, thank you for your useful post.
    I have found a different implementation of wgan loss function here:
    https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py#L134
    do you have any idea about this?

    • Avatar
      Jason Brownlee November 1, 2020 at 1:14 pm #

      Sorry I do not, I recommend talking to the authors directly.

  7. Avatar
    Ailsor March 30, 2021 at 10:40 pm #

    How can we implement this in Semi-Supervised GANs?
    I’m following your tutorial in https://machinelearningmastery.com/semi-supervised-generative-adversarial-network/

    • Avatar
      Jason Brownlee March 31, 2021 at 6:04 am #

      Sorry, I cannot write the code for you.

      Perhaps try adapting the tutorial to use this alternate loss.

  8. Avatar
    Ori June 8, 2021 at 9:56 pm #

    Hi, thank you for the article. The formula for the critic loss function is wrong – it should be the opposite.

  9. Avatar
    Ori June 8, 2021 at 9:58 pm #

    And also the description after the formula is wrong. The score needs to be larger for the real images and smaller for the fake ones.

  10. Avatar
    farnaz July 13, 2021 at 5:41 pm #

    Hi, Thank you for the tutorial, Do you have any implementation for WGAN_GP loss?

  11. Avatar
    Taylor August 24, 2023 at 7:44 am #

    Hi Jason. Thanks for the great article as always. As others have pointed out, one of the loss functions signs is wrong IF we assume we will do gradient descent to find a local minimum. Since the original paper (the image you show) has gradient ASCENT for the critic (plus sign in line 6) and gradient DESCENT for the generator (minus sign in line 11), if we were to optimize based on gradient descent, which is what you describe, the sign of your critic loss function needs to be reversed.

    • Avatar
      James Carmichael August 24, 2023 at 8:59 am #

      Thank you for your feedback Taylor!

  12. Avatar
    Francesco October 5, 2023 at 9:09 pm #

    Hello Jason,

    Thank you for the amazing tutorial!

    I tried re adapting it and found that the critic loss and the generator loss reach values of around -2000. Is this normal or did I do something wrong in the implementation?

    • Avatar
      James Carmichael October 6, 2023 at 9:10 am #

      Hi Francesco…Did you copy and paste code of did you type it in? Also, what modifications did you make?

  13. Avatar
    Michio January 19, 2024 at 2:57 am #

    Hi Jason,
    Thanks for this tutorial.
    I just wonder why you call GANs zero-sum game?
    I understand that GANs plays mini-max game.
    While the discriminator (the original GANs) attempts to maximize the probability that the it correctly classifies the real samples (Class 1) and the synthetic samples (Class 0), the generator attempts to minimize that the discriminator classifies the synthetic data as synthetic (Class 0).
    As a result, the generator learns the probability distribution of the real data.
    Thus, it is generating a value (positive value), a realistic synthetic data (positive value) to the level were the discriminator simply get confused (zero value).
    Zero sum game would not generate any positive value as a whole (positive plus zero).