How to Implement CycleGAN Models From Scratch With Keras

The Cycle Generative Adversarial Network, or CycleGAN for short, is a generator model for converting images from one domain to another domain.

For example, the model can be used to translate images of horses to images of zebras, or photographs of city landscapes at night to city landscapes during the day.

The benefit of the CycleGAN model is that it can be trained without paired examples. That is, it does not require examples of photographs before and after the translation in order to train the model, e.g. photos of the same city landscape during the day and at night. Instead, it is able to use a collection of photographs from each domain and extract and harness the underlying style of images in the collection in order to perform the translation.

The model is very impressive but has an architecture that appears quite complicated to implement for beginners.

In this tutorial, you will discover how to implement the CycleGAN architecture from scratch using the Keras deep learning framework.

After completing this tutorial, you will know:

  • How to implement the discriminator and generator models.
  • How to define composite models to train the generator models via adversarial and cycle loss.
  • How to implement the training process to update model weights each training iteration.

Let’s get started.

How to Develop CycleGAN Models From Scratch With Keras
Photo by anokarina, some rights reserved.

Tutorial Overview

This tutorial is divided into five parts; they are:

  1. What Is the CycleGAN Architecture?
  2. How to Implement the CycleGAN Discriminator Model
  3. How to Implement the CycleGAN Generator Model
  4. How to Implement Composite Models for Least Squares and Cycle Loss
  5. How to Update Discriminator and Generator Models

What Is the CycleGAN Architecture?

The CycleGAN model was described by Jun-Yan Zhu, et al. in their 2017 paper titled “Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.”

The model architecture is comprised of two generator models: one generator (Generator-A) for generating images for the first domain (Domain-A) and the second generator (Generator-B) for generating images for the second domain (Domain-B).

  • Generator-A -> Domain-A
  • Generator-B -> Domain-B

The generator models perform image translation, meaning that the image generation process is conditional on an input image, specifically an image from the other domain. Generator-A takes an image from Domain-B as input and Generator-B takes an image from Domain-A as input.

  • Domain-B -> Generator-A -> Domain-A
  • Domain-A -> Generator-B -> Domain-B

Each generator has a corresponding discriminator model.

The first discriminator model (Discriminator-A) takes real images from Domain-A and generated images from Generator-A and predicts whether they are real or fake. The second discriminator model (Discriminator-B) takes real images from Domain-B and generated images from Generator-B and predicts whether they are real or fake.

  • Domain-A -> Discriminator-A -> [Real/Fake]
  • Domain-B -> Generator-A -> Discriminator-A -> [Real/Fake]
  • Domain-B -> Discriminator-B -> [Real/Fake]
  • Domain-A -> Generator-B -> Discriminator-B -> [Real/Fake]

The discriminator and generator models are trained in an adversarial zero-sum process, like normal GAN models.

The generators learn to better fool the discriminators and the discriminators learn to better detect fake images. Together, the models find an equilibrium during the training process.

Additionally, the generator models are regularized not just to create new images in the target domain, but to create translated versions of the input images from the source domain. This is achieved by using generated images as input to the other generator model and comparing the output image to the original image.

Passing an image through both generators is called a cycle. Together, the pair of generator models is trained to better reproduce the original source image, a property referred to as cycle consistency.

  • Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B
  • Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A

There is one further element to the architecture referred to as the identity mapping.

This is where a generator is provided with images as input from the target domain and is expected to generate the same image without change. This addition to the architecture is optional, although it results in a better matching of the color profile of the input image.

  • Domain-A -> Generator-A -> Domain-A
  • Domain-B -> Generator-B -> Domain-B

Now that we are familiar with the model architecture, we can take a closer look at each model in turn and how they can be implemented.

The paper provides a good description of the models and training process, although the official Torch implementation was used as the definitive description for each model and training process and provides the basis for the model implementations described below.

How to Implement the CycleGAN Discriminator Model

The discriminator model is responsible for taking a real or generated image as input and predicting whether it is real or fake.

The discriminator model is implemented as a PatchGAN model.

For the discriminator networks we use 70 × 70 PatchGANs, which aim to classify whether 70 × 70 overlapping image patches are real or fake.

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

The PatchGAN was described in the 2016 paper titled “Precomputed Real-time Texture Synthesis With Markovian Generative Adversarial Networks” and was used in the pix2pix model for image translation described in the 2016 paper titled “Image-to-Image Translation with Conditional Adversarial Networks.”

The architecture is described as discriminating an input image as real or fake by averaging the prediction for nxn squares or patches of the source image.

… we design a discriminator architecture – which we term a PatchGAN – that only penalizes structure at the scale of patches. This discriminator tries to classify if each NxN patch in an image is real or fake. We run this discriminator convolutionally across the image, averaging all responses to provide the ultimate output of D.

Image-to-Image Translation with Conditional Adversarial Networks, 2016.

This can be implemented directly by using a somewhat standard deep convolutional discriminator model.

Instead of outputting a single value like a traditional discriminator model, the PatchGAN discriminator model can output a square or one-channel feature map of predictions. The 70×70 refers to the effective receptive field of the model on the input, not the actual shape of the output feature map.

The receptive field of a convolutional layer refers to the number of pixels that one output of the layer maps to in the input to the layer. The effective receptive field refers to the mapping of one pixel in the output of a deep convolutional model (multiple layers) to the input image. Here, the PatchGAN is an approach to designing a deep convolutional network based on the effective receptive field, where one output activation of the model maps to a 70×70 patch of the input image, regardless of the size of the input image.
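
The effective receptive field can be worked out with a little arithmetic. The sketch below is a minimal illustration, not code from the paper: the function name is our own, and the layer list is an assumption (three 4×4 stride-2 layers followed by two 4×4 stride-1 layers) chosen because it reproduces the 70×70 figure. It may not match, layer for layer, the model built later in this tutorial.

```python
def effective_receptive_field(layers):
    """Return the input patch width seen by one output activation.

    layers is a list of (kernel_size, stride) pairs, ordered input to output.
    """
    field = 1
    # Walk backwards from the output: each layer widens the field it sits on.
    for kernel_size, stride in reversed(layers):
        field = (field - 1) * stride + kernel_size
    return field


# Assumed stack: three 4x4 stride-2 layers, then two 4x4 stride-1 layers.
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(effective_receptive_field(patchgan_layers))  # 70
```

Adding or removing stride-2 layers in the list changes the patch size, which is how differently sized PatchGAN models can be designed.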

The PatchGAN has the effect of predicting whether each 70×70 patch in the input image is real or fake. These predictions can then be averaged to give the output of the model (if needed) or compared directly to a matrix (or a vector if flattened) of expected values (e.g. 0 or 1 values).

The discriminator model described in the paper takes 256×256 color images as input and defines an explicit architecture that is used on all of the test problems. The architecture uses blocks of Conv2D-InstanceNorm-LeakyReLU layers, with 4×4 filters and a 2×2 stride.

Let Ck denote a 4×4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2. After the last layer, we apply a convolution to produce a 1-dimensional output. We do not use InstanceNorm for the first C64 layer. We use leaky ReLUs with a slope of 0.2.

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

The architecture for the discriminator is as follows:

  • C64-C128-C256-C512

This is referred to as a 3-layer PatchGAN in the CycleGAN and Pix2Pix nomenclature, as excluding the first hidden layer, the model has three hidden layers that could be scaled up or down to give different sized PatchGAN models.

Not listed in the paper, the model also has a final hidden layer C512 with a 1×1 stride, and an output layer C1, also with a 1×1 stride with a linear activation function. Given the model is mostly used with 256×256 sized images as input, the size of the output feature map of activations is 16×16. If 128×128 images were used as input, then the size of the output feature map of activations would be 8×8.
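
The output sizes quoted above can be checked by hand. The sketch below assumes that every convolution uses 'same' padding, so that a layer with stride s maps a width of n to ceil(n / s); the helper name is hypothetical.

```python
import math


def patch_output_size(image_size, strides):
    """Spatial width of the output feature map, assuming 'same' padding."""
    size = image_size
    for stride in strides:
        # With 'same' padding, a strided convolution yields ceil(size / stride).
        size = math.ceil(size / stride)
    return size


# Four stride-2 blocks (C64-C128-C256-C512), then two stride-1 layers.
strides = [2, 2, 2, 2, 1, 1]
print(patch_output_size(256, strides))  # 16
print(patch_output_size(128, strides))  # 8
```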

The model does not use batch normalization; instead, instance normalization is used.

Instance normalization was described in the 2016 paper titled “Instance Normalization: The Missing Ingredient for Fast Stylization.” It is a very simple type of normalization and involves standardizing (e.g. scaling to a standard Gaussian) the values on each feature map.

The intent is to remove image-specific contrast information from the image during image generation, resulting in better generated images.

The key idea is to replace batch normalization layers in the generator architecture with instance normalization layers, and to keep them at test time (as opposed to freeze and simplify them out as done for batch normalization). Intuitively, the normalization process allows to remove instance-specific contrast information from the content image, which simplifies generation. In practice, this results in vastly improved images.

Instance Normalization: The Missing Ingredient for Fast Stylization, 2016.

Although designed for generator models, it can also prove effective in discriminator models.
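
To make the idea concrete, the standardization step can be sketched in plain NumPy. This is only an illustration under our own assumptions: images are stored channels-last, the function name is hypothetical, and the learned scale and offset parameters that a real layer would normally include are left out.

```python
import numpy as np


def instance_normalize(x, epsilon=1e-5):
    """Standardize each feature map of each sample independently.

    x has shape (batch, height, width, channels), i.e. channels-last.
    """
    # Statistics are taken over the spatial axes only, so every
    # (sample, channel) pair gets its own mean and variance.
    mean = x.mean(axis=(1, 2), keepdims=True)
    variance = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon)


rng = np.random.default_rng(1)
images = rng.normal(loc=3.0, scale=2.0, size=(2, 8, 8, 4))
normalized = instance_normalize(images)
print(normalized.mean(axis=(1, 2)).round(3))  # close to 0 for every feature map
print(normalized.std(axis=(1, 2)).round(3))   # close to 1 for every feature map
```

Batch normalization would instead pool the statistics across all samples in the batch, which is the key difference between the two.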

An implementation of instance normalization is provided in the keras-contrib project that provides early access to community-supplied Keras features.

The keras-contrib library can be installed via pip as follows:

Or, if you are using an Anaconda virtual environment, such as on EC2:

The new InstanceNormalization layer can then be used as follows:

The “axis” argument is set to -1 to ensure that features are normalized per feature map.

The network weights are initialized to Gaussian random numbers with a standard deviation of 0.02, as is described for DCGANs more generally.

Weights are initialized from a Gaussian distribution N (0, 0.02).

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

The discriminator model is updated using a least squares loss (L2), as in the so-called Least Squares Generative Adversarial Network, or LSGAN.

… we replace the negative log likelihood objective by a least-squares loss. This loss is more stable during training and generates higher quality results.

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

This can be implemented using “mean squared error” between the target values of class=1 for real images and class=0 for fake images.

Additionally, the paper suggests dividing the loss for the discriminator by half during training, in an effort to slow down updates to the discriminator relative to the generator.

In practice, we divide the objective by 2 while optimizing D, which slows down the rate at which D learns, relative to the rate of G.

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

This can be achieved by setting the “loss_weights” argument to 0.5 when compiling the model. Note that this weighting does not appear to be implemented in the official Torch implementation when updating the discriminator models, as defined in the fDx_basic() function.
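
The effect of the least squares loss and the 0.5 weighting can be seen with a small NumPy calculation. This is a sketch of the arithmetic only, not the Keras training code; the function name and the prediction values are made up for illustration.

```python
import numpy as np


def discriminator_lsgan_loss(pred_real, pred_fake, weight=0.5):
    """Least squares discriminator loss with targets 1 (real) and 0 (fake)."""
    loss_real = np.mean((pred_real - 1.0) ** 2)
    loss_fake = np.mean((pred_fake - 0.0) ** 2)
    # The weight plays the role of the 0.5 loss weighting described above.
    return weight * loss_real, weight * loss_fake


# Hypothetical 16x16 patch predictions for one real and one fake image.
pred_real = np.full((1, 16, 16, 1), 0.75)
pred_fake = np.full((1, 16, 16, 1), 0.5)
loss_real, loss_fake = discriminator_lsgan_loss(pred_real, pred_fake)
print(loss_real, loss_fake)
```

Because the loss is averaged over the whole patch output, every 70×70 patch prediction contributes equally to the update.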

We can tie all of this together in the example below with a define_discriminator() function that defines the PatchGAN discriminator. The model configuration matches the description in the appendix of the paper with additional details from the official Torch implementation defined in the defineD_n_layers() function.

Note: the plot_model() function requires that both the pydot and pygraphviz libraries are installed. If this is a problem, you can comment out both the import and call to this function.

Running the example summarizes the model, showing the size of the inputs and outputs for each layer.

A plot of the model architecture is also created to help get an idea of the inputs, outputs, and transitions of the image data through the model.

Plot of the PatchGAN Discriminator Model for the CycleGAN

How to Implement the CycleGAN Generator Model

The CycleGAN Generator model takes an image as input and generates a translated image as output.

The model uses a sequence of downsampling convolutional blocks to encode the input image, a number of residual network (ResNet) convolutional blocks to transform the image, and a number of upsampling convolutional blocks to generate the output image.

Let c7s1-k denote a 7×7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1. dk denotes a 3×3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2. Reflection padding was used to reduce artifacts. Rk denotes a residual block that contains two 3 × 3 convolutional layers with the same number of filters on both layer. uk denotes a 3 × 3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters and stride 1/2.

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

The architecture for the 6-resnet block generator for 128×128 images is as follows:

  • c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
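
To see how this notation shapes the data, we can trace the width and channel count after each layer. The sketch below is a hypothetical helper that follows the paper's notation only; it assumes padding keeps the width unchanged in stride-1 layers, and the channel counts in a Keras model may differ depending on how the residual blocks are merged.

```python
def trace_generator_shapes(spec, image_size):
    """Return (layer, width, channels) after each layer in a CycleGAN spec."""
    size, shapes = image_size, []
    for token in spec.split(","):
        if token.startswith("c7s1-"):      # 7x7 convolution, stride 1
            channels = int(token[len("c7s1-"):])
        elif token.startswith("d"):        # downsample: stride 2
            channels, size = int(token[1:]), size // 2
        elif token.startswith("u"):        # upsample: stride 1/2
            channels, size = int(token[1:]), size * 2
        elif token.startswith("R"):        # residual block keeps the size
            channels = int(token[1:])
        else:
            raise ValueError("unknown layer token: " + token)
        shapes.append((token, size, channels))
    return shapes


spec = "c7s1-64,d128,d256," + ",".join(["R256"] * 6) + ",u128,u64,c7s1-3"
for token, size, channels in trace_generator_shapes(spec, 128):
    print(token, (size, size, channels))
```

The trace shows the image being reduced to a quarter of its width for the residual blocks, then restored to the input width with three output channels.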

First, we need a function to define the ResNet blocks. These are blocks comprised of two 3×3 CNN layers where the input to the block is concatenated to the output of the block, channel-wise.

This is implemented in the resnet_block() function that creates two Conv-InstanceNorm blocks with 3×3 filters and 1×1 stride and without a ReLU activation after the second block, matching the official Torch implementation in the build_conv_block() function. For simplicity, same padding is used instead of the reflection padding recommended in the paper.

Next, we can define a function that will create the 9-resnet block version for 256×256 input images. This can easily be changed to the 6-resnet block version by setting the image_shape argument to (128, 128, 3) and the n_resnet argument to 6.

Importantly, the model outputs pixel values with the same shape as the input, and the pixel values are in the range [-1, 1], typical for GAN generator models.

The generator model is not compiled as it is trained via a composite model, seen in the next section.

Tying this together, the complete example is listed below.

Running the example first summarizes the model.

A plot of the generator model is also created, showing the skip connections in the ResNet blocks.

Plot of the Generator Model for the CycleGAN

How to Implement Composite Models for Least Squares and Cycle Loss

The generator models are not updated directly. Instead, the generator models are updated via composite models.

An update to each generator model involves changes to the model weights based on four concerns:

  • Adversarial loss (L2 or mean squared error).
  • Identity loss (L1 or mean absolute error).
  • Forward cycle loss (L1 or mean absolute error).
  • Backward cycle loss (L1 or mean absolute error).

The adversarial loss is the standard approach for updating the generator via the discriminator, although in this case, the least squares loss function is used instead of the negative log likelihood (e.g. binary cross entropy).

First, we can use our function to define the two generators and two discriminators used in the CycleGAN.

A composite model is required for each generator model that is responsible for only updating the weights of that generator model, although it is required to share the weights with the related discriminator model and the other generator model.

This can be achieved by marking the weights of the other models as not trainable in the context of the composite model to ensure we are only updating the intended generator.

The model can be constructed piecewise using the Keras functional API.

The first step is to define the input of the real image from the source domain, pass it through our generator model, then connect the output of the generator to the discriminator and classify it as real or fake.

Next, we can connect the identity mapping element with a new input for the real image from the target domain, pass it through our generator model, and output the (hopefully) untranslated image directly.

So far, we have a composite model with two real image inputs and a discriminator classification and identity image output. Next, we need to add the forward and backward cycles.

The forward cycle can be achieved by connecting the output of our generator to the other generator, the output of which can be compared to the input to our generator and should be identical.

The backward cycle is more complex and involves the input for the real image from the target domain passing through the other generator, then passing through our generator, which should match the real image from the target domain.

That’s it.

We can then define this composite model with two inputs, one real image from the source domain and one from the target domain, and four outputs: one from the discriminator, one from our generator for the identity mapping, one from the other generator for the forward cycle, and one from our generator for the backward cycle.
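
The wiring of the two inputs and four outputs can be illustrated without Keras by treating each model as a plain function. This is a toy sketch under our own assumptions: the function and variable names are hypothetical, and the stand-in "generators" are simple arithmetic operations that happen to be exact inverses of each other.

```python
import numpy as np


def composite_outputs(g_model_1, d_model, g_model_2, input_gen, input_id):
    """Return the four outputs used to update g_model_1."""
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)               # adversarial: real or fake?
    output_id = g_model_1(input_id)            # identity mapping
    output_f = g_model_2(gen1_out)             # forward cycle, compare to input_gen
    output_b = g_model_1(g_model_2(input_id))  # backward cycle, compare to input_id
    return output_d, output_id, output_f, output_b


def double(x):
    """Toy stand-in for the generator being trained."""
    return x * 2.0


def halve(x):
    """Toy stand-in for the other generator (the inverse of double)."""
    return x / 2.0


def score(x):
    """Toy stand-in for the discriminator."""
    return x.mean(axis=(1, 2, 3))


source = np.ones((1, 4, 4, 3))
target = np.full((1, 4, 4, 3), 2.0)
_, identity, forward, backward = composite_outputs(double, score, halve, source, target)
print(np.abs(forward - source).mean())   # forward cycle L1 error
print(np.abs(backward - target).mean())  # backward cycle L1 error
```

Because the toy generators invert each other, both cycle errors are zero, which is the behavior the cycle losses push the real generators toward.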

The adversarial loss for the discriminator output uses least squares loss which is implemented as L2 or mean squared error. The outputs from the generators are compared to images and are optimized using L1 loss implemented as mean absolute error.

The generator is updated using a weighted sum of the four loss values. The adversarial loss is weighted normally, whereas the forward and backward cycle losses are weighted using a parameter called lambda that is set to 10, i.e. 10 times more important than the adversarial loss. The identity loss is weighted as a fraction of the lambda parameter and is set to 0.5 * 10, or 5, in the official Torch implementation.
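
The weighting described above can be written out as plain arithmetic. The function and argument names below are hypothetical and the loss values are invented; in Keras the same weighting would correspond to a list of loss weights of [1, 5, 10, 10] across the four outputs.

```python
def generator_total_loss(adversarial, identity, forward_cycle, backward_cycle,
                         lambda_cycle=10.0, identity_fraction=0.5):
    """Weighted sum of the four generator losses."""
    identity_weight = identity_fraction * lambda_cycle  # 0.5 * 10 = 5
    return (1.0 * adversarial
            + identity_weight * identity
            + lambda_cycle * forward_cycle
            + lambda_cycle * backward_cycle)


# Hypothetical loss values for a single training step.
print(generator_total_loss(0.5, 0.25, 0.125, 0.125))  # 4.25
```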

We can tie all of this together and define the function define_composite_model() for creating a composite model for training a given generator model.

This function can then be called to prepare a composite model for training both the g_model_AtoB generator model and the g_model_BtoA model; for example:

Summarizing or plotting the composite model is of little help, as the result is cluttered and does not show the inputs and outputs of the model clearly.

We can summarize the inputs and outputs for each of the composite models below. Recall that we are sharing or reusing the same set of weights if a given model is used more than once in the composite model.

Generator-A Composite Model

Only Generator-A weights are trainable and weights for other models are not trainable.

  • Adversarial Loss: Domain-B -> Generator-A -> Domain-A -> Discriminator-A -> [real/fake]
  • Identity Loss: Domain-A -> Generator-A -> Domain-A
  • Forward Cycle Loss: Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B
  • Backward Cycle Loss: Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A

Generator-B Composite Model

Only Generator-B weights are trainable and weights for other models are not trainable.

  • Adversarial Loss: Domain-A -> Generator-B -> Domain-B -> Discriminator-B -> [real/fake]
  • Identity Loss: Domain-B -> Generator-B -> Domain-B
  • Forward Cycle Loss: Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A
  • Backward Cycle Loss: Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B

A complete example of creating all of the models is listed below for completeness.

How to Update Discriminator and Generator Models

Training the defined models is relatively straightforward.

First, we must define a helper function that will select a batch of real images and the associated target (1.0).

Similarly, we need a function to generate a batch of fake images and the associated target (0.0).
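
A minimal NumPy sketch of these two helpers is shown below. The signatures are assumptions made for illustration: the generator is passed as a plain callable rather than a Keras model, a NumPy random generator is passed in explicitly, and generate_fake_samples() is our own name for the second helper.

```python
import numpy as np


def generate_real_samples(dataset, n_samples, patch_shape, rng):
    """Pick random real images and pair them with 'real' (1.0) targets."""
    ix = rng.integers(0, dataset.shape[0], n_samples)
    X = dataset[ix]
    y = np.ones((n_samples, patch_shape, patch_shape, 1))
    return X, y


def generate_fake_samples(generate, images, patch_shape):
    """Translate images with a generator and pair them with 'fake' (0.0) targets."""
    X = generate(images)
    y = np.zeros((len(X), patch_shape, patch_shape, 1))
    return X, y


rng = np.random.default_rng(0)
dataset = rng.uniform(-1.0, 1.0, size=(4, 256, 256, 3))  # stand-in images
X_real, y_real = generate_real_samples(dataset, 1, 16, rng)
X_fake, y_fake = generate_fake_samples(lambda x: -x, X_real, 16)  # toy generator
print(X_real.shape, y_real.shape, X_fake.shape, y_fake.shape)
```

The targets have the same square shape as the PatchGAN output, so each patch prediction is compared against its own 1.0 or 0.0 value.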

Now, we can define the steps of a single training iteration. We will model the order of updates on the OptimizeParameters() function in the official Torch implementation (note: the official code uses a more confusing inverted naming convention).

  1. Update Generator-B (A->B)
  2. Update Discriminator-B
  3. Update Generator-A (B->A)
  4. Update Discriminator-A

First, we must select a batch of real images by calling generate_real_samples() for both Domain-A and Domain-B.

Typically, the batch size (n_batch) is set to 1. In this case, we will assume 256×256 input images, which means the n_patch for the PatchGAN discriminator will be 16.

Next, we can use the batches of selected real images to generate corresponding batches of generated or fake images.

The paper describes using a pool of previously generated images from which examples are randomly selected and used to update the discriminator model, where the pool size was set to 50 images.

… [we] update the discriminators using a history of generated images rather than the ones produced by the latest generators. We keep an image buffer that stores the 50 previously created images.

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, 2017.

This can be implemented using a list for each domain and a function that populates the pool, then randomly replaces elements in the pool once it is at capacity.

The update_image_pool() function below implements this based on the official Torch implementation in image_pool.lua.
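
A sketch of such a pool, written with only the Python standard library, is given here. The 50 percent chance of using a new image directly once the pool is full is an assumption about the reference behavior, and the function returns a plain list, which would normally be converted to an array before being passed to a model.

```python
import random


def update_image_pool(pool, images, max_size=50):
    """Return images for the discriminator update, drawing on a history pool."""
    selected = []
    for image in images:
        if len(pool) < max_size:
            # Pool not yet full: store the new image and use it.
            pool.append(image)
            selected.append(image)
        elif random.random() < 0.5:
            # Use the new image directly without storing it.
            selected.append(image)
        else:
            # Use an older image and put the new one in its place.
            ix = random.randrange(len(pool))  # valid index: 0 .. len(pool) - 1
            selected.append(pool[ix])
            pool[ix] = image
    return selected


random.seed(1)
pool = []
for step in range(200):
    batch = update_image_pool(pool, ["image-%d" % step])
print(len(pool))  # 50
```

Note that random.randrange() excludes its upper bound, so the chosen index always falls inside the pool.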

We can then update our image pool with generated fake images, the results of which can be used to train the discriminator models.

Next, we can update Generator-A.

The train_on_batch() function will return a value for each of the four loss functions, one for each output, as well as the weighted sum (first value) used to update the model weights which we are interested in.

We can then update the discriminator model using the fake images that may or may not have come from the image pool.

We can then do the same for the other generator and discriminator models.

At the end of the training run, we can then report the current loss for the discriminator models on real and fake images and of each generator model.

Tying this all together, we can define a function named train() that takes an instance of each of the defined models and a loaded dataset (list of two NumPy arrays, one for each domain) and trains the model.

A batch size of 1 is used as is described in the paper and the models are fit for 100 training epochs.

The train function can then be called directly with our defined models and loaded dataset.

As an improvement, it may be desirable to combine the update to each discriminator model into a single operation as is performed in the fDx_basic() function of the official implementation.

Additionally, the paper describes updating the models for another 100 epochs (200 in total), where the learning rate is decayed to 0.0. This too can be added as a minor extension to the training process.
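
One way such a schedule could look is sketched below: the rate is held constant for the first 100 epochs, then decays linearly to zero over the next 100. The function name, the zero-based epoch numbering, and the base learning rate of 0.0002 are assumptions made for this illustration.

```python
def learning_rate_for_epoch(epoch, base_lr=0.0002, n_constant=100, n_decay=100):
    """Constant rate for n_constant epochs, then linear decay to zero."""
    if epoch < n_constant:
        return base_lr
    # Fraction of the decay phase remaining (1.0 at its start, 0.0 at its end).
    remaining = max(0.0, 1.0 - (epoch - n_constant) / n_decay)
    return base_lr * remaining


for epoch in (0, 99, 100, 150, 200):
    print(epoch, learning_rate_for_epoch(epoch))
```

The returned value could then be assigned to the optimizer of each model at the start of every epoch.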

Summary

In this tutorial, you discovered how to implement the CycleGAN architecture from scratch using the Keras deep learning framework.

Specifically, you learned:

  • How to implement the discriminator and generator models.
  • How to define composite models to train the generator models via adversarial and cycle loss.
  • How to implement the training process to update model weights each training iteration.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.

12 Responses to How to Implement CycleGAN Models From Scratch With Keras

  1. Dilip Rajkumar August 8, 2019 at 8:59 pm #

    Hi Jason, thanks for this great tutorial about CycleGAN. I have a physics-based regression problem (~8 input features and 1 response variable) with only 35 data points from the real world lab test results. We have a 1D simulation tool with which we can generate any number of low fidelity artificial data points.
    Unfortunately, the low fidelity artificial synthesised data points are not having the same distribution as the real world lab test results and there are differences due to domain shift between real world lab test and simulated data.

    Can you please give some tips on applying CycleGAN for a numeric dataset (regression problem) to make the low fidelity (synthesised data) to appear more like the real world lab test data but still being physically consistent?

    • Jason Brownlee August 9, 2019 at 8:10 am #

      Good question.

      Perhaps you can try adapting the above example for your data?

      Perhaps try a gaussian process or kde approach to modeling the distribution of points and sample it randomly?

  2. Prisilla August 9, 2019 at 10:21 am #

    Hi Jason,

    Can you elaborate about batch normalization and instance normalization? Why in instance normalization, axis is set to -1.

    Thanks,
    Prisilla

  3. Jonathan August 20, 2019 at 10:52 pm #

    Hi Jason,

    Thanks for the nice post. Just one question regarding update_image_pool function, wouldn’t be a possible issue to reach an index out of range of pool?

    ++ max_size is set to 50 (if len(pool) < max_size) , hence pool index are from 0 to 49.
    ++ ix = randint(0, len(pool)), will return an integer between 0 to 50

    So there may be a possibility to access pool[50] which is out of range?

    • Jason Brownlee August 21, 2019 at 6:44 am #

      Nice catch!

      I should do len(pool)-1.

      Nevertheless, python does not do out of bounds, an index of 50 will become an index of 0.

  4. Radu September 1, 2019 at 5:58 am #

    Hi Jason,

    First of all, I want to thank you. I really appreciate your post, it really helped me understand the whole “CycleGan” thing, you really did a great thing explaining everything.

    Secondly, I want to ask you something. Your training loop is quite different from the other implementations that I have seen (PyTorch official, Keras). Mainly, I refer to the way you are updating the models G1, D1, G2, D2 instead of G1, G2, D1, D2. Is there any specific reason you have decided to do it this way, or was it just modified for simplicity?

    Thank you,
    Radu

    • Jason Brownlee September 2, 2019 at 5:22 am #

      Thanks Radu.

      Yes, the update order of the models is based on the official implementation provided with the paper.

  5. Jon September 18, 2019 at 12:34 am #

    Hey Jason! Thanks for the wonderful and in-depth explanation of this complex topic.
    I feel a little stupid asking this but can you please tell me how we are passing the images to the code? We are supposed to feed them as lists but are those lists supposed to be made of 3-dim image data?
    Thank you.

  6. Raghu September 18, 2019 at 5:17 am #

    Thanks for the well explained article.

    I was trying to save and load the model using Keras’ save & load_model but on loading it fails to recognize the InstanceNormalization layer. Can you suggest a workaround?
    Also, any way to visualize the training images after some iterations?

    • Jason Brownlee September 18, 2019 at 6:32 am #

      Good question, use:
