A Gentle Introduction to the tensorflow.data API

Last Updated on August 6, 2022

When you build and train a Keras deep learning model, you can provide the training data in several different ways. Presenting the data as a NumPy array or a TensorFlow tensor is common. Another way is to make a Python generator function and let the training loop read data from it. Yet another way of providing data is to use a tf.data dataset.

In this tutorial, you will see how you can use a tf.data dataset with a Keras model. After finishing this tutorial, you will know:

  • How to create and use the tf.data dataset
  • The benefit of doing so compared to a generator function

Let’s get started.



This article is divided into four sections; they are:

  • Training a Keras Model with NumPy Array and Generator Function
  • Creating a Dataset using tf.data
  • Creating a Dataset from Generator Function
  • Dataset with Prefetch

Training a Keras Model with NumPy Array and Generator Function

Before you see how the tf.data API works, let’s review how you might usually train a Keras model.

First, you need a dataset. An example is the fashion MNIST dataset that comes with the Keras API. This dataset has 60,000 training samples and 10,000 test samples of 28×28 pixels in grayscale, and the corresponding classification label is encoded with integers 0 to 9.

The dataset is loaded as NumPy arrays. You can then build a Keras model for classification and pass the arrays as data to the model’s fit() function.

The complete code is as follows:
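As a minimal sketch of this workflow, the code below trains a small classifier directly from NumPy arrays. It assumes random stand-in arrays shaped like fashion MNIST (so that it runs offline), a placeholder network, and only two epochs; none of these are the article's actual settings. To use the real data, load it with tf.keras.datasets.fashion_mnist.load_data() instead.

```python
import numpy as np
import tensorflow as tf

# Assumption: random stand-in data with the same shapes and dtypes as fashion MNIST.
rng = np.random.default_rng(0)
train_image = rng.integers(0, 256, size=(256, 28, 28), dtype=np.uint8)
train_label = rng.integers(0, 10, size=(256,), dtype=np.uint8)
test_image = rng.integers(0, 256, size=(64, 28, 28), dtype=np.uint8)
test_label = rng.integers(0, 10, size=(64,), dtype=np.uint8)

# Assumption: a tiny placeholder network, not the model used in the article.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Rescaling(1.0 / 255),   # scale pixel values to [0, 1]
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# With NumPy arrays, fit() takes the data and labels as separate arguments.
history = model.fit(train_image, train_label,
                    batch_size=32, epochs=2,
                    validation_data=(test_image, test_label),
                    verbose=0)
print(history.history["val_accuracy"])
```

The history object returned by fit() records one value per epoch for each metric, which is what a validation-accuracy plot would be drawn from.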

Running this code will print out the following:

It also creates the following plot of validation accuracy over the 50 epochs of training:

The other way of training the same network is to provide the data from a Python generator function instead of a NumPy array. A generator function is one that uses a yield statement to emit data: each time the consumer asks for the next item, the function resumes, yields a value, and pauses again. A generator for the fashion MNIST dataset can be created as follows:
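A minimal sketch of such a generator is shown here. The parameter names and the small zero-filled stand-in arrays are assumptions for illustration; only the function name batch_generator and the batch size of 32 come from the text.

```python
import numpy as np

def batch_generator(image, label, batchsize):
    """Yield (image, label) batches forever, wrapping around at the end."""
    n_samples = len(image)
    start = 0
    while True:
        end = start + batchsize
        # The last batch of a pass may be smaller than batchsize
        yield image[start:end], label[start:end]
        # Restart from the beginning once the arrays are exhausted
        start = end if end < n_samples else 0

# Assumption: stand-in arrays of 100 samples instead of the real dataset
images = np.zeros((100, 28, 28), dtype=np.uint8)
labels = (np.arange(100) % 10).astype(np.uint8)

gen = batch_generator(images, labels, 32)
shapes = [next(gen)[0].shape for _ in range(5)]
print(shapes)
```

With 100 samples and a batch size of 32, the fourth batch holds the remaining 4 samples and the fifth batch starts over from the first sample.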

This function is meant to be called as batch_generator(train_image, train_label, 32). It will scan the input arrays in batches indefinitely. Once it reaches the end of the arrays, it will restart from the beginning.

Training a Keras model with a generator is similar to using the fit() function:

Instead of providing the data and label separately, you just need to provide the generator, as it emits both. When data are presented as a NumPy array, Keras can tell how many samples there are from the length of the array, so it knows an epoch is complete once the entire dataset has been used once. Your generator function, however, emits batches indefinitely, so you need to tell Keras when an epoch ends by passing the steps_per_epoch argument to the fit() function.
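The following sketch shows one way this could look. The stand-in random data, the placeholder network, and the two-epoch run are assumptions chosen so that the example runs quickly offline; steps_per_epoch is computed as the number of batches needed to cover the training array once.

```python
import math
import numpy as np
import tensorflow as tf

def batch_generator(image, label, batchsize):
    """Yield (image, label) batches forever, wrapping around at the end."""
    start = 0
    while True:
        end = start + batchsize
        yield image[start:end], label[start:end]
        start = end if end < len(image) else 0

# Assumption: random stand-in data shaped like fashion MNIST
rng = np.random.default_rng(0)
train_image = rng.integers(0, 256, size=(256, 28, 28), dtype=np.uint8)
train_label = rng.integers(0, 10, size=(256,), dtype=np.uint8)
test_image = rng.integers(0, 256, size=(64, 28, 28), dtype=np.uint8)
test_label = rng.integers(0, 10, size=(64,), dtype=np.uint8)

# Assumption: a tiny placeholder network
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Rescaling(1.0 / 255),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

batch_size = 32
# One epoch = enough batches to see every training sample once
steps_per_epoch = math.ceil(len(train_image) / batch_size)

history = model.fit(batch_generator(train_image, train_label, batch_size),
                    steps_per_epoch=steps_per_epoch,
                    epochs=2,
                    validation_data=(test_image, test_label),
                    verbose=0)
print(steps_per_epoch, history.history["loss"])
```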

In the above code, the validation data was provided as a NumPy array, but you can use a generator instead and specify the validation_steps argument.

The following is the complete code using a generator function; its output is the same as in the previous example:

Creating a Dataset Using tf.data

Given that you have the fashion MNIST data loaded, you can convert it into a tf.data dataset, like the following:
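A short sketch of the conversion, assuming zero-filled stand-in arrays with the same shape and dtype as the fashion MNIST training data (the variable names train_image and train_label follow the earlier text):

```python
import numpy as np
import tensorflow as tf

# Assumption: stand-in arrays in place of the real fashion MNIST data
train_image = np.zeros((100, 28, 28), dtype=np.uint8)
train_label = np.zeros((100,), dtype=np.uint8)

# Passing a tuple makes each dataset element an (image, label) pair
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))
print(dataset.element_spec)
```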

This prints the dataset’s spec as follows:

You can see that each element is a tuple (because a tuple was passed as the argument to the from_tensor_slices() function), where the first element has the shape (28,28) and the second is a scalar. Both are stored as 8-bit unsigned integers.

You do not have to pass the two NumPy arrays as a tuple when you create the dataset; you can also combine them later. The following creates the same dataset, but first builds separate datasets for the image data and the labels and then combines them:
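A sketch of this two-step construction, again assuming zero-filled stand-in arrays in place of the real data:

```python
import numpy as np
import tensorflow as tf

# Assumption: stand-in arrays in place of the real fashion MNIST data
train_image = np.zeros((100, 28, 28), dtype=np.uint8)
train_label = np.zeros((100,), dtype=np.uint8)

# Build one dataset per array, then pair them element by element
train_image_data = tf.data.Dataset.from_tensor_slices(train_image)
train_label_data = tf.data.Dataset.from_tensor_slices(train_label)
dataset = tf.data.Dataset.zip((train_image_data, train_label_data))
print(dataset.element_spec)
```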

This will print the same spec:

The zip() function in the dataset is like the zip() function in Python because it matches data one by one from multiple datasets into a tuple.

One benefit of using the tf.data dataset is the flexibility in handling the data. Below is the complete code on how you can train a Keras model using a dataset, in which the batch size is set on the dataset itself rather than in the call to fit():
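A minimal sketch of training from a dataset follows. As before, the random stand-in data, the placeholder network, and the two-epoch run are assumptions made to keep the example small; the point is that batch(32) is applied to the dataset and fit() receives no batch_size argument.

```python
import numpy as np
import tensorflow as tf

# Assumption: random stand-in data shaped like fashion MNIST
rng = np.random.default_rng(0)
train_image = rng.integers(0, 256, size=(256, 28, 28), dtype=np.uint8)
train_label = rng.integers(0, 10, size=(256,), dtype=np.uint8)
test_image = rng.integers(0, 256, size=(64, 28, 28), dtype=np.uint8)
test_label = rng.integers(0, 10, size=(64,), dtype=np.uint8)

dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_label))

# Assumption: a tiny placeholder network
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Rescaling(1.0 / 255),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# The dataset yields (image, label) batches, so fit() needs neither
# a separate label argument nor a batch_size argument.
history = model.fit(dataset.batch(32),
                    epochs=2,
                    validation_data=test_dataset.batch(32),
                    verbose=0)
n_batches = int(dataset.batch(32).cardinality())
print(n_batches, history.history["val_accuracy"])
```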

This is the simplest use case of a dataset. If you dive deeper, you can see that a dataset is just an iterable. Therefore, you can print out each sample in a dataset using the following:
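For instance, a plain for loop works. This sketch assumes small stand-in arrays and uses take(3) only to keep the printout short:

```python
import numpy as np
import tensorflow as tf

# Assumption: stand-in arrays; labels cycle through 0..9
train_image = np.zeros((100, 28, 28), dtype=np.uint8)
train_label = (np.arange(100) % 10).astype(np.uint8)
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))

first_labels = []
for image, label in dataset.take(3):
    # Each element is a tuple of eager tensors
    print(image.shape, label.numpy())
    first_labels.append(int(label.numpy()))
```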

The dataset has many functions built in. The batch() used before is one of them. If you create batches from a dataset and print them, you have the following:
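A sketch of what batching does to the elements, assuming a stand-in dataset of 100 samples:

```python
import numpy as np
import tensorflow as tf

# Assumption: stand-in arrays in place of the real data
train_image = np.zeros((100, 28, 28), dtype=np.uint8)
train_label = np.zeros((100,), dtype=np.uint8)
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))

batched = dataset.batch(32)
batch_shapes = []
for images, labels in batched:
    # The leading dimension is now the batch size
    print(images.shape, labels.shape)
    batch_shapes.append(images.shape.as_list())
```

Because 100 is not a multiple of 32, the final batch contains only the 4 remaining samples.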

Here, each item from the dataset is not a sample but a batch of samples. You also have functions such as map(), filter(), and reduce() for sequence transformation, or concatenate() and interleave() for combining with another dataset. There are also repeat(), take(), take_while(), and skip(), like their familiar counterparts from Python’s itertools module. A full list of the functions can be found in the API documentation.
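A few of these functions can be tried on a toy dataset of the integers 0 to 9. The use of tf.data.Dataset.range() here is an assumption made to keep the results easy to read:

```python
import tensorflow as tf

numbers = tf.data.Dataset.range(10)

# batch(): group consecutive elements
batches = [b.tolist() for b in numbers.batch(4).as_numpy_iterator()]
print(batches)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# filter() then map(): keep even numbers and square them
even_squares = numbers.filter(lambda x: x % 2 == 0).map(lambda x: x * x)
even_squares = [int(x) for x in even_squares.as_numpy_iterator()]
print(even_squares)   # [0, 4, 16, 36, 64]

# take(), skip(), concatenate(): first two elements followed by the last two
joined = numbers.take(2).concatenate(numbers.skip(8))
joined = [int(x) for x in joined.as_numpy_iterator()]
print(joined)   # [0, 1, 8, 9]
```

Each call returns a new dataset, so the transformations can be chained without modifying the original.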

Creating a Dataset from Generator Function

So far, you saw how a dataset could be used in place of a NumPy array in training a Keras model. Indeed, a dataset can also be created out of a generator function. But instead of a generator function that generates a batch, as you saw in one of the examples above, you now make a generator function that generates one sample at a time. The following is the function:
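One possible form of such a function is sketched below. The name shuffle_generator, its seed parameter, and the stand-in arrays are hypothetical and chosen for illustration.

```python
import numpy as np

def shuffle_generator(image, label, seed=None):
    """Yield one (image, label) sample at a time in random order, then stop."""
    idx = np.arange(len(image))
    # Shuffle the index vector rather than the arrays themselves
    np.random.default_rng(seed).shuffle(idx)
    for i in idx:
        yield image[i], label[i]

# Assumption: stand-in arrays; labels are 0..99 so the order is visible
images = np.zeros((100, 28, 28), dtype=np.uint8)
labels = np.arange(100, dtype=np.uint8)

emitted = [int(lbl) for _, lbl in shuffle_generator(images, labels, seed=42)]
print(emitted[:10])
```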

This function randomizes the input array by shuffling the index vector. Then it generates one sample at a time. Unlike the previous example, this generator will end when the samples from the array are exhausted.

You can create a dataset from the function using from_generator(). You need to provide the name of the generator function (instead of an instantiated generator) and also the output signature of the dataset. This is required because the tf.data.Dataset API cannot infer the dataset spec before the generator is consumed.
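A sketch of this step is given below, assuming a hypothetical shuffle_generator function and stand-in arrays. Here the generator function is wrapped in a zero-argument lambda so that from_generator() receives a callable that builds a fresh generator each time the dataset is iterated; passing arguments through the args parameter of from_generator() is an alternative.

```python
import numpy as np
import tensorflow as tf

def shuffle_generator(image, label, seed=None):
    """Hypothetical generator: one (image, label) sample at a time, shuffled."""
    idx = np.arange(len(image))
    np.random.default_rng(seed).shuffle(idx)
    for i in idx:
        yield image[i], label[i]

# Assumption: stand-in arrays in place of the real fashion MNIST data
train_image = np.zeros((100, 28, 28), dtype=np.uint8)
train_label = (np.arange(100) % 10).astype(np.uint8)

dataset = tf.data.Dataset.from_generator(
    lambda: shuffle_generator(train_image, train_label),
    # The spec must be stated up front; it cannot be inferred from a generator
    output_signature=(
        tf.TensorSpec(shape=(28, 28), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.uint8),
    ),
)
print(dataset.element_spec)
n_samples = sum(1 for _ in dataset)
print(n_samples)
```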

Running the above code will print the same spec as before:

Such a dataset is functionally equivalent to the dataset that you created previously. Hence you can use it for training as before. The following is the complete code:

Dataset with Prefetch

The real benefit of using a dataset is to use prefetch().

Using a NumPy array for training is probably the best in performance. However, this means you need to load all data into memory. Using a generator function for training allows you to prepare one batch at a time, so the data can, for example, be loaded from disk on demand. However, training a Keras model from a generator function means that only one of the two, the training loop or the generator function, is running at any given time. It is not easy to make the generator function and Keras’s training loop run in parallel.

Dataset is the API that allows the generator and the training loop to run in parallel. If you have a generator that is computationally expensive (e.g., doing image augmentation in real time), you can create a dataset from such a generator function and then use it with prefetch(), as follows:
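A minimal sketch, assuming a hypothetical shuffle_generator function and stand-in arrays; a real generator would do its expensive work (such as augmentation) before each yield:

```python
import numpy as np
import tensorflow as tf

def shuffle_generator(image, label, seed=None):
    """Hypothetical generator: one (image, label) sample at a time, shuffled."""
    idx = np.arange(len(image))
    np.random.default_rng(seed).shuffle(idx)
    for i in idx:
        # Expensive per-sample preparation would go here
        yield image[i], label[i]

# Assumption: stand-in arrays in place of the real data
train_image = np.zeros((100, 28, 28), dtype=np.uint8)
train_label = (np.arange(100) % 10).astype(np.uint8)

dataset = tf.data.Dataset.from_generator(
    lambda: shuffle_generator(train_image, train_label),
    output_signature=(
        tf.TensorSpec(shape=(28, 28), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.uint8),
    ),
)

# Batch first, then prefetch: the buffer holds up to 3 ready-made batches
prefetched = dataset.batch(32).prefetch(3)
batch_sizes = [int(images.shape[0]) for images, _ in prefetched]
print(batch_sizes)
```

The prefetched dataset can be passed to fit() in the same way as any other batched dataset.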

The numeric argument to prefetch() is the size of the buffer. Here, the dataset is asked to keep three batches in memory, ready for the training loop to consume. Whenever a batch is consumed, the dataset API will resume the generator function to refill the buffer asynchronously in the background. Therefore, you can allow the training loop and the data preparation algorithm inside the generator function to run in parallel.

It’s worth mentioning that, in the previous section, you created a shuffling generator for the dataset API. Indeed the dataset API also has a shuffle() function to do the same, but you may not want to use it unless the dataset is small enough to fit in memory.

The shuffle() function, like prefetch(), takes a buffer-size argument. The shuffle algorithm fills the buffer with elements from the dataset and draws one element randomly from it. The consumed element is then replaced with the next element from the dataset. Hence the buffer needs to be as large as the dataset itself to produce a truly random shuffle. This limitation is demonstrated with the following snippet:
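One way to see the effect is sketched here, assuming a toy dataset of the integers 0 to 9999 and a shuffle buffer of only 10 elements:

```python
import tensorflow as tf

numbers = tf.data.Dataset.range(10000)

# A buffer of 10 means each draw comes from only 10 candidate elements
shuffled = numbers.shuffle(10).take(20)
first_20 = [int(n) for n in shuffled.as_numpy_iterator()]
print(first_20)
```

Because the buffer starts with elements 0 to 9 and gains one new element per draw, the k-th value drawn (counting from zero) is always smaller than 10 + k.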

The output from the above looks like the following:

Here you can see that the numbers are shuffled only within their neighborhood, and large numbers never appear in the output.

Further Reading

More about the tf.data dataset can be found in its API documentation.


In this post, you have seen how you can use the tf.data dataset and how it can be used in training a Keras model.

Specifically, you learned:

  • How to train a model using data from a NumPy array, a generator, and a dataset
  • How to create a dataset using a NumPy array or a generator function
  • How to use prefetch with a dataset to make the generator and training loop run in parallel

4 Responses to A Gentle Introduction to the tensorflow.data API

  1. Lukas August 18, 2022 at 4:33 am #

    Sorry, didnt see that Adrian wrote the tutorial.
    Thanks to Adrian my man!

    • James Carmichael August 18, 2022 at 10:54 am #

      Great feedback Lukas! It is appreciated!

  2. david August 21, 2022 at 1:27 am #

    Hi Dr.Jason
    Thank you for all your great posts. i truly enjoy them.

    • James Carmichael August 21, 2022 at 7:49 am #

      Hi David…Thank you for your feedback and support! We greatly appreciate it!
