How to Implement Multi-Head Attention from Scratch in TensorFlow and Keras

We have already familiarized ourselves with the theory behind the Transformer model and its attention mechanism, and we have started our journey of implementing a complete model by seeing how to implement the scaled dot-product attention. We shall now progress one step further by encapsulating the scaled dot-product attention into a multi-head attention mechanism, which is a core component of the Transformer. Our end goal remains to apply the complete model to Natural Language Processing (NLP).

In this tutorial, you will discover how to implement multi-head attention from scratch in TensorFlow and Keras. 

After completing this tutorial, you will know:

  • The layers that form part of the multi-head attention mechanism.
  • How to implement the multi-head attention mechanism from scratch.   


Let’s get started. 

How to implement multi-head attention from scratch in TensorFlow and Keras
Photo by Everaldo Coelho, some rights reserved.

Tutorial Overview

This tutorial is divided into three parts; they are:

  • Recap of the Transformer Architecture
    • The Transformer Multi-Head Attention
  • Implementing Multi-Head Attention From Scratch
  • Testing Out the Code

Prerequisites

For this tutorial, we assume that you are already familiar with:

Recap of the Transformer Architecture

Recall that the Transformer architecture follows an encoder-decoder structure. The encoder, on the left-hand side, is tasked with mapping an input sequence to a sequence of continuous representations; the decoder, on the right-hand side, receives the output of the encoder together with the decoder output at the previous time step to generate an output sequence.

The encoder-decoder structure of the Transformer architecture
Taken from “Attention Is All You Need”

In generating an output sequence, the Transformer does not rely on recurrence and convolutions.

You have seen that the decoder part of the Transformer shares many similarities in its architecture with the encoder. One of the core mechanisms that both the encoder and decoder share is the multi-head attention mechanism. 

The Transformer Multi-Head Attention

Each multi-head attention block is made up of four consecutive levels:

  • On the first level, three linear (dense) layers that each receive the queries, keys, or values 
  • On the second level, a scaled dot-product attention function. The operations performed on both the first and second levels are repeated h times and performed in parallel, according to the number of heads composing the multi-head attention block. 
  • On the third level, a concatenation operation that joins the outputs of the different heads
  • On the fourth level, a final linear (dense) layer that produces the output

Multi-head attention
Taken from “Attention Is All You Need”

Recall as well the important components that will serve as building blocks for your implementation of the multi-head attention:

  • The queries, keys, and values: These are the inputs to each multi-head attention block. In the encoder stage, they each carry the same input sequence after it has been embedded and augmented by positional information. Similarly, on the decoder side, the queries, keys, and values fed into the first attention block represent the same target sequence after it has also been embedded and augmented by positional information. The second attention block of the decoder receives the encoder output in the form of keys and values, and the normalized output of the first decoder attention block as the queries. The dimensionality of the queries and keys is denoted by $d_k$, whereas the dimensionality of the values is denoted by $d_v$.
  • The projection matrices: When applied to the queries, keys, and values, these projection matrices generate different subspace representations of each. Each attention head then works on one of these projected versions of the queries, keys, and values. An additional projection matrix is applied to the output of the multi-head attention block after the outputs of the individual heads have been concatenated together. The projection matrices are learned during training.

Let’s now see how to implement the multi-head attention from scratch in TensorFlow and Keras.

Implementing Multi-Head Attention from Scratch

Let’s start by creating the class, MultiHeadAttention, which inherits from the Layer base class in Keras, and initializing several instance attributes that you shall be working with (attribute descriptions may be found in the comments):
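A minimal sketch of such a constructor follows. It assumes the convention of Vaswani et al., in which d_k and d_v are per-head sizes. Each input projection is therefore a single Dense layer with h * d_k (or h * d_v) units that serves all heads at once. The compact DotProductAttention stand-in is included only so that the snippet runs on its own.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class DotProductAttention(Layer):
    """Compact scaled dot-product attention, included so this snippet is self-contained."""

    def call(self, queries, keys, values, d_k, mask=None):
        scores = tf.matmul(queries, keys, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))
        if mask is not None:
            scores += -1e9 * mask  # assumed convention: 1 marks a position to hide
        return tf.matmul(tf.nn.softmax(scores, axis=-1), values)


class MultiHeadAttention(Layer):
    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super().__init__(**kwargs)
        self.attention = DotProductAttention()  # Scaled dot-product attention
        self.heads = h          # Number of attention heads to use
        self.d_k = d_k          # Per-head dimensionality of the projected queries and keys
        self.d_v = d_v          # Per-head dimensionality of the projected values
        self.d_model = d_model  # Dimensionality of the model
        # Assumption: one Dense layer per input projects for all heads at once,
        # hence h * d_k (or h * d_v) output units
        self.W_q = Dense(h * d_k)  # Learned projection for the queries
        self.W_k = Dense(h * d_k)  # Learned projection for the keys
        self.W_v = Dense(h * d_v)  # Learned projection for the values
        self.W_o = Dense(d_model)  # Learned projection for the concatenated head outputs


mha = MultiHeadAttention(h=8, d_k=64, d_v=64, d_model=512)
print(mha.W_q.units, mha.W_v.units, mha.W_o.units)  # 512 512 512
```

With h = 8 and d_k = 64, each input projection outputs 512 features. The heads later split these into eight slices of 64.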

Here, note that an instance of the DotProductAttention class that was implemented earlier has been created and assigned to the attention attribute. Recall that you implemented the DotProductAttention class as follows:
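As a reminder of what that layer computes, here is a minimal sketch of scaled dot-product attention as a Keras layer. It assumes a mask convention in which 1 marks a position to hide and 0 a position to keep. The masked check at the end shows that, under a look-ahead mask, the first query can only see the first value.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class DotProductAttention(Layer):
    def call(self, queries, keys, values, d_k, mask=None):
        # Similarity of every query with every key, scaled by sqrt(d_k)
        scores = tf.matmul(queries, keys, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))
        # Assumed mask convention: 1 hides a position, 0 keeps it
        if mask is not None:
            scores += -1e9 * mask
        # Turn the scores into attention weights over the keys
        weights = tf.nn.softmax(scores, axis=-1)
        # Weighted sum of the values
        return tf.matmul(weights, values)


queries = tf.random.normal((2, 5, 64))
keys = tf.random.normal((2, 5, 64))
values = tf.random.normal((2, 5, 64))
attention = DotProductAttention()

output = attention(queries, keys, values, 64)
print(tuple(output.shape))  # (2, 5, 64)

# Look-ahead mask: ones strictly above the diagonal hide future positions
mask = 1.0 - tf.linalg.band_part(tf.ones((5, 5)), -1, 0)
masked = attention(queries, keys, values, 64, mask)
```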

Next, you will be reshaping the linearly projected queries, keys, and values in such a manner as to allow the attention heads to be computed in parallel. 

The queries, keys, and values will be fed as input into the multi-head attention block having a shape of (batch size, sequence length, model dimensionality), where the batch size is a hyperparameter of the training process, the sequence length defines the maximum length of the input/output phrases, and the model dimensionality is the dimensionality of the outputs produced by all sub-layers of the model. They are then passed through the respective dense layer to be linearly projected to a shape of (batch size, sequence length, queries/keys/values dimensionality).
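A short sketch can check the effect of a dense projection on these shapes. The sizes are arbitrary test values. Comparing a per-head Dense layer with a single layer that covers all heads is an assumption made for illustration.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense

batch_size, seq_length, d_model = 4, 10, 512  # arbitrary test sizes
h, d_k = 8, 64  # assumed: 8 heads with per-head size d_k = d_model / h

# Input to the block: (batch size, sequence length, model dimensionality)
queries = tf.random.normal((batch_size, seq_length, d_model))

# A Dense layer only transforms the last axis; batch and sequence axes pass through
single_head = Dense(d_k)    # projection for one head
all_heads = Dense(h * d_k)  # one layer projecting for all heads at once

q_one = single_head(queries)
q_all = all_heads(queries)
print(tuple(q_one.shape))  # (4, 10, 64)
print(tuple(q_all.shape))  # (4, 10, 512)
```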

The linearly projected queries, keys, and values will be rearranged into (batch size, number of heads, sequence length, depth), by first reshaping them into (batch size, sequence length, number of heads, depth) and then transposing the second and third dimensions. For this purpose, you will create the class method, reshape_tensor, as follows:
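The following sketch shows this rearrangement as a free function rather than a class method, so that it runs on its own. The name and the flag argument mirror the description in the text. Using -1 for the inferred axis is a choice of this sketch: it lets TensorFlow compute the merged size, so the same code works even when d_k and d_v differ.

```python
import tensorflow as tf


def reshape_tensor(x, heads, flag):
    if flag:
        # (batch, seq, heads * depth) -> (batch, seq, heads, depth)
        x = tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], heads, -1))
        # Swap the sequence and heads axes -> (batch, heads, seq, depth)
        x = tf.transpose(x, perm=(0, 2, 1, 3))
    else:
        # Reverse: (batch, heads, seq, depth) -> (batch, seq, heads, depth)
        x = tf.transpose(x, perm=(0, 2, 1, 3))
        # Merge the last two axes, concatenating the heads -> (batch, seq, heads * depth)
        x = tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], -1))
    return x


x = tf.random.normal((2, 5, 8 * 64))
split = reshape_tensor(x, 8, True)
merged = reshape_tensor(split, 8, False)
print(tuple(split.shape), tuple(merged.shape))  # (2, 8, 5, 64) (2, 5, 512)
```

Reshaping straight to (batch size, number of heads, sequence length, depth) would not be equivalent. Because of the flat memory order, each "head" would then receive a mixture of features from different tokens. Reshaping first and transposing afterwards keeps every token's features together and gives each head a contiguous slice of them.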

The reshape_tensor method receives the linearly projected queries, keys, or values as input (while setting the flag to True) to be rearranged as previously explained. Once the multi-head attention output has been generated, this is also fed into the same function (this time setting the flag to False) to perform a reverse operation, effectively concatenating the results of all heads together. 

Hence, the next step is to feed the linearly projected queries, keys, and values into the reshape_tensor method to be rearranged, then feed them into the scaled dot-product attention function. In order to do so, let’s create another class method, call, as follows:
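The data flow inside call, up to the attention output, can be sketched with plain functions. The helpers split_heads and scaled_dot_product_attention are hypothetical names for this illustration. The per-head sizes follow the paper. Note that the scaling factor uses the per-head depth d_k, which is the length of the vectors whose dot products are taken.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense


def split_heads(x, heads):
    # (batch, seq, heads * depth) -> (batch, heads, seq, depth)
    x = tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], heads, -1))
    return tf.transpose(x, perm=(0, 2, 1, 3))


def scaled_dot_product_attention(queries, keys, values, d_k, mask=None):
    scores = tf.matmul(queries, keys, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))
    if mask is not None:
        scores += -1e9 * mask  # assumed convention: 1 hides a position
    return tf.matmul(tf.nn.softmax(scores, axis=-1), values)


batch_size, seq_length, d_model = 2, 5, 512
h, d_k, d_v = 8, 64, 64  # per-head sizes, as in the paper

W_q, W_k, W_v = Dense(h * d_k), Dense(h * d_k), Dense(h * d_v)
queries = keys = values = tf.random.normal((batch_size, seq_length, d_model))

# Project, then rearrange into (batch, heads, seq, depth)
q_reshaped = split_heads(W_q(queries), h)
k_reshaped = split_heads(W_k(keys), h)
v_reshaped = split_heads(W_v(values), h)

# All heads are handled by one batched matmul over the leading (batch, heads) axes
o_reshaped = scaled_dot_product_attention(q_reshaped, k_reshaped, v_reshaped, d_k)
print(tuple(o_reshaped.shape))  # (2, 8, 5, 64)
```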

Note that the call method can also receive a mask (whose value defaults to None) as input, in addition to the queries, keys, and values.

Recall that the Transformer model introduces a look-ahead mask to prevent the decoder from attending to succeeding words, such that the prediction for a particular word can only depend on known outputs for the words that come before it. Furthermore, since the input sequences are zero-padded to a specific sequence length, a padding mask also needs to be introduced to prevent the zero values from being processed along with the input. These look-ahead and padding masks can be passed on to the scaled dot-product attention through the mask argument.
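The following is a minimal sketch of both masks. It assumes that token id 0 denotes padding and that a mask value of 1 marks a position to hide. The helper names are hypothetical. For the decoder's first attention block, the two masks can be combined with an element-wise maximum.

```python
import tensorflow as tf


def padding_mask(token_ids):
    # 1.0 where the token id is 0 (padding), 0.0 elsewhere
    mask = tf.cast(tf.math.equal(token_ids, 0), tf.float32)
    # Add axes so the mask broadcasts over (batch, heads, query positions, key positions)
    return mask[:, tf.newaxis, tf.newaxis, :]


def lookahead_mask(seq_length):
    # Keep the lower triangle (incl. diagonal); the ones left above it hide future positions
    return 1.0 - tf.linalg.band_part(tf.ones((seq_length, seq_length)), -1, 0)


tokens = tf.constant([[5, 7, 9, 0, 0]])
print(padding_mask(tokens).numpy().squeeze())  # [0. 0. 0. 1. 1.]
print(lookahead_mask(3).numpy())  # ones strictly above the diagonal

# Decoder self-attention: hide a position if either mask hides it
combined = tf.maximum(padding_mask(tokens), lookahead_mask(5))
```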

Once you have generated the multi-head attention output from all the attention heads, the final steps are to concatenate all outputs back together into a tensor of shape (batch size, sequence length, values dimensionality) and to pass the result through one final dense layer. For this purpose, you will add the next two lines of code to the call method.
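In isolation, these two steps amount to the reverse rearrangement followed by a dense projection. The sketch below applies them to a random stand-in for the per-head output, assuming h * d_v = 512 as in the paper.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense

batch_size, h, seq_length, d_v, d_model = 2, 8, 5, 64, 512

# Stand-in for the attention output of all heads: (batch, heads, seq, d_v)
o_reshaped = tf.random.normal((batch_size, h, seq_length, d_v))

# Reverse rearrangement: (batch, heads, seq, d_v) -> (batch, seq, heads, d_v)
concatenated = tf.transpose(o_reshaped, perm=(0, 2, 1, 3))
# Merge the last two axes, placing the heads side by side: (batch, seq, h * d_v)
concatenated = tf.reshape(concatenated, (tf.shape(concatenated)[0], tf.shape(concatenated)[1], -1))

# Final learned projection back to the model dimensionality
W_o = Dense(d_model)
result = W_o(concatenated)
print(tuple(concatenated.shape), tuple(result.shape))  # (2, 5, 512) (2, 5, 512)
```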

Putting everything together, you have the following implementation of the multi-head attention:
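The following sketch assembles these pieces into one layer. It is not a verbatim copy of the tutorial's listing. It treats d_k and d_v as per-head sizes (as in the paper), sizes each projection to h * d_k or h * d_v, and scales the dot products by the per-head d_k. The mask convention (1 hides a position) is also an assumption.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class DotProductAttention(Layer):
    def call(self, queries, keys, values, d_k, mask=None):
        scores = tf.matmul(queries, keys, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))
        if mask is not None:
            scores += -1e9 * mask  # assumed convention: 1 hides a position
        return tf.matmul(tf.nn.softmax(scores, axis=-1), values)


class MultiHeadAttention(Layer):
    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super().__init__(**kwargs)
        self.attention = DotProductAttention()  # Scaled dot-product attention
        self.heads = h          # Number of attention heads
        self.d_k = d_k          # Per-head size of the projected queries and keys
        self.d_v = d_v          # Per-head size of the projected values
        self.d_model = d_model  # Dimensionality of the model
        self.W_q = Dense(h * d_k)  # Projection for the queries (all heads)
        self.W_k = Dense(h * d_k)  # Projection for the keys (all heads)
        self.W_v = Dense(h * d_v)  # Projection for the values (all heads)
        self.W_o = Dense(d_model)  # Projection for the concatenated output

    def reshape_tensor(self, x, heads, flag):
        if flag:
            # (batch, seq, heads * depth) -> (batch, heads, seq, depth)
            x = tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], heads, -1))
            return tf.transpose(x, perm=(0, 2, 1, 3))
        # (batch, heads, seq, depth) -> (batch, seq, heads * depth)
        x = tf.transpose(x, perm=(0, 2, 1, 3))
        return tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], -1))

    def call(self, queries, keys, values, mask=None):
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)
        # Attention for all heads at once, scaled by the per-head depth d_k
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k, mask)
        # Concatenate the heads and apply the final projection
        output = self.reshape_tensor(o_reshaped, self.heads, False)
        return self.W_o(output)


mha = MultiHeadAttention(8, 64, 64, 512)
x = tf.random.normal((2, 5, 512))
out = mha(x, x, x)
print(tuple(out.shape))  # (2, 5, 512)
```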


Testing Out the Code

You will be working with the parameter values specified in the paper, Attention Is All You Need, by Vaswani et al. (2017):
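A sketch of these settings follows. The paper sets h = 8 and d_k = d_v = d_model / h = 64 with d_model = 512. The batch size is not taken from the paper; it is an arbitrary choice for this test.

```python
h = 8          # Number of self-attention heads
d_k = 64       # Per-head dimensionality of the projected queries and keys
d_v = 64       # Per-head dimensionality of the projected values
d_model = 512  # Dimensionality of the model sub-layers' outputs
batch_size = 64  # Arbitrary batch size for testing (not from the paper)

# In the paper, the per-head sizes divide the model dimensionality evenly
assert d_k == d_v == d_model // h
```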

As for the sequence length and the queries, keys, and values, you will be working with dummy data for the time being until you arrive at the stage of training the complete Transformer model in a separate tutorial, at which point you will be using actual sentences:
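The dummy inputs can be generated with NumPy, as sketched below. This assumes, as in the paper, that the inputs to the block carry d_model features. The sequence length of 5 and the seed are arbitrary test values.

```python
import numpy as np

batch_size = 64
input_seq_length = 5  # Maximum length of the input sequence (arbitrary for testing)
d_model = 512

rng = np.random.default_rng(seed=0)
# Dummy inputs shaped (batch size, sequence length, model dimensionality)
queries = rng.random((batch_size, input_seq_length, d_model), dtype=np.float32)
keys = rng.random((batch_size, input_seq_length, d_model), dtype=np.float32)
values = rng.random((batch_size, input_seq_length, d_model), dtype=np.float32)
print(queries.shape)  # (64, 5, 512)
```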

In the complete Transformer model, values for the sequence length and the queries, keys, and values will be obtained through a process of word tokenization and embedding. We will be covering this in a separate tutorial. 

Returning to the testing procedure, the next step is to create a new instance of the MultiHeadAttention class, assigning its output to the multihead_attention variable:

Since the MultiHeadAttention class inherits from the Layer base class, the call() method of the former will be automatically invoked by the special __call__() method of the latter. The final step is to pass in the input arguments and print the result:
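A toy layer makes the dispatch from __call__() to call() visible. Scale is a hypothetical class used only for this illustration.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class Scale(Layer):
    """Hypothetical toy layer that multiplies its input by a fixed factor."""

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        # Invoked by Layer.__call__ when the layer object is called like a function
        return inputs * self.factor


scale = Scale(3.0)
print(scale(tf.constant([1.0, 2.0])).numpy())  # [3. 6.]
```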

Tying everything together produces the following code listing:
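Here is a sketch of a complete, self-contained test script. It again assumes per-head d_k and d_v, inputs with d_model features, and arbitrary test sizes for the batch and sequence length. Because of random initialization, the printed values will vary from run to run; only the shape is fixed.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class DotProductAttention(Layer):
    def call(self, queries, keys, values, d_k, mask=None):
        scores = tf.matmul(queries, keys, transpose_b=True) / tf.math.sqrt(tf.cast(d_k, tf.float32))
        if mask is not None:
            scores += -1e9 * mask  # assumed convention: 1 hides a position
        return tf.matmul(tf.nn.softmax(scores, axis=-1), values)


class MultiHeadAttention(Layer):
    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super().__init__(**kwargs)
        self.attention = DotProductAttention()
        self.heads, self.d_k, self.d_v, self.d_model = h, d_k, d_v, d_model
        self.W_q, self.W_k = Dense(h * d_k), Dense(h * d_k)  # query/key projections
        self.W_v, self.W_o = Dense(h * d_v), Dense(d_model)  # value/output projections

    def reshape_tensor(self, x, heads, flag):
        if flag:  # split: (batch, seq, heads * depth) -> (batch, heads, seq, depth)
            x = tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], heads, -1))
            return tf.transpose(x, perm=(0, 2, 1, 3))
        # merge: (batch, heads, seq, depth) -> (batch, seq, heads * depth)
        x = tf.transpose(x, perm=(0, 2, 1, 3))
        return tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[1], -1))

    def call(self, queries, keys, values, mask=None):
        q = self.reshape_tensor(self.W_q(queries), self.heads, True)
        k = self.reshape_tensor(self.W_k(keys), self.heads, True)
        v = self.reshape_tensor(self.W_v(values), self.heads, True)
        o = self.attention(q, k, v, self.d_k, mask)
        return self.W_o(self.reshape_tensor(o, self.heads, False))


# Parameters from the paper, plus arbitrary test sizes
h, d_k, d_v, d_model = 8, 64, 64, 512
batch_size, input_seq_length = 64, 5

# Dummy inputs with d_model features
queries = tf.random.uniform((batch_size, input_seq_length, d_model))
keys = tf.random.uniform((batch_size, input_seq_length, d_model))
values = tf.random.uniform((batch_size, input_seq_length, d_model))

multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
output = multihead_attention(queries, keys, values)
print(tuple(output.shape))  # (64, 5, 512)
```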

Running this code produces an output of shape (batch size, sequence length, model dimensionality). Note that you will likely see a different output due to the random initialization of the queries, keys, and values and the parameter values of the dense layers.

Further Reading

This section provides more resources on the topic if you are looking to go deeper.

Books

Papers

Summary

In this tutorial, you discovered how to implement multi-head attention from scratch in TensorFlow and Keras. 

Specifically, you learned:

  • The layers that form part of the multi-head attention mechanism
  • How to implement the multi-head attention mechanism from scratch 

Do you have any questions?
Ask your questions in the comments below, and I will do my best to answer.


28 Responses to How to Implement Multi-Head Attention from Scratch in TensorFlow and Keras

  1. Moisés October 9, 2022 at 5:27 am #

    Hi. Very good explanation! I have a little doubt. Why do you need to reshape the input and then calculate the transpose instead of just reshaping it to (batch size, heads, sequence lenght, -1)? Thanks.

    • James Carmichael October 10, 2022 at 11:09 am #

      Hi Moises…You may certainly proceed with your suggestion. Let us know what you find.

  2. Brett October 26, 2022 at 4:05 pm #

    Awesome tutorial. Thanks for pulling all this together!

    • James Carmichael October 27, 2022 at 7:39 am #

      You are very welcome Brett!

  3. Diego November 21, 2022 at 4:39 am #

    if we reshape the data, does that not mean that we still perform regular dot attention (1 head), but now reshape it as if it were 8 heads? I’m not seeing extra dimensions being created when I print out the reshaped query, key, value matrices’ shapes

  4. Diego November 21, 2022 at 11:16 am #

    I think
    self.W_q (d_k)
    self.W_k (d_k)
    self.W_v (d_v)

    should not have a static d_k value, but instead should be:

    self.head_dim = self.d_model // self.heads

    self.W_q (self.head_dim)
    self.W_k (self.head_dim)
    self.W_v (self.head_dim)

    otherwise your q,k,v matrices are not related to the embed space and the number of heads, but just to a preset value of the d_k dimensionality

    • Farid T. August 28, 2023 at 9:44 am #

      I agree. This needs to be explicitly pointed out.

  5. Cybernetic1 December 18, 2022 at 4:49 pm #

    In the last block of code:

    queries = random.random((batch_size, input_seq_length, d_k))

    Should it be this instead:

    queries = random.random((batch_size, input_seq_length, d_model)) ?

    In the article you wrote:

    “The queries, keys, and values will be fed as input into the multi-head attention block having a shape of (batch size, sequence length, model dimensionality), where the batch size is a hyperparameter of the training process, the sequence length defines the maximum length of the input/output phrases, and the model dimensionality is the dimensionality of the outputs produced by all sub-layers of the model. They are then passed through the respective dense layer to be linearly projected to a shape of (batch size, sequence length, queries/keys/values dimensionality).”

    In particular you said “model dimensionality” in the above quote, which is d_model, not d_k.

    In other words, I think W_q is a matrix of dimension 512 x 64. Where input dim = d_model = 512, output dim = d_k = 64.

    Am I mistaken? This is very confusing to me…. thanks if someone can clarify this for me.

    • Stefania Cristina
      Stefania Cristina December 19, 2022 at 12:58 am #

      Hi Cybernetic1, thank you for the interest.

      What I meant by, “The queries, keys, and values will be fed as input into the multi-head attention block having a shape of (batch size, sequence length, model dimensionality)…”, is that the output of the multi-head attention block is of shape (batch size, sequence length, model dimensionality) rather than the queries, keys or values. Otherwise, it may be confirmed from Vaswani’s paper that the queries and keys are of dimensionality d_k, and the values are of dimensionality d_v, where Vaswani et al. set d_k and d_v to a value that differs from that for d_model.

    • Vicki Huang January 6, 2023 at 2:18 pm #

      I agree with Cybernetic1.
      d_k is the dimensionality of the linearly projected queries and keys. But in this code ‘queries = random.random((batch_size, input_seq_length, d_k))’, queries is the input of the multihead_attention which have not been projected yet. They will be projected in the call function of MultiHeadAttention which is done by this code ‘q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)’. So in the first line of code, the last dimensionality should be d_model or d_k * h.

    • Lavero August 16, 2023 at 2:58 pm #

      There is a mistake in the code itself. The multihead class should be:

      class MultiHeadAttention(Layer):
          def __init__(self, h, d_model, **kwargs):
              super(MultiHeadAttention, self).__init__(**kwargs)
              self.attention = DotProductAttention() # Scaled dot product attention
              self.heads = h # Number of attention heads to use

              assert d_model % h == 0
              d_k = d_model // h # d_k should be calculated from d_model and h

              self.d_k = d_k # Dimensionality of the linearly projected queries and keys
              self.d_v = d_k # Dimensionality of the linearly projected values
              self.d_model = d_model # Dimensionality of the model

              # units of projection matrices should be d_model instead of d_k/d_v
              self.W_q = Dense(d_model) # Learned projection matrix for the queries
              self.W_k = Dense(d_model) # Learned projection matrix for the keys
              self.W_v = Dense(d_model) # Learned projection matrix for the values
              self.W_o = Dense(d_model) # Learned projection matrix for the multi-head output

      The other portion of the code seems to be ok.

      • yleo April 24, 2024 at 2:18 am #

        this is the correct answer

  6. John January 14, 2023 at 1:43 pm #

    Agree with Cybernetic1, but I really appreciate author’s effort putting these together!

  7. Ivan February 22, 2023 at 3:12 pm #

    Queries and Keys must have same dimensions, to be able successfully compute MatMul in DotProductAttention.

    Queries: (batch, time, d_q)
    Keys: (batch, time, d_k)

    To calculate DotProduct(Queries, Keys) = Queries @ Keys.T (d_q must be EQUAL to d_k)

    • James Carmichael February 23, 2023 at 8:24 am #

      Thank you for your feedback Ivan!

  8. Ivan February 22, 2023 at 3:15 pm #

    Hello,

    In PDF-version.

    def reshape_tensor…..
    …..
    —- False branch
    x = reshape(x, shape=(shape(x)[0], shape(x)[1], self.d_k))
    …..

    IMO d_k is incorrect , should be d_v. Due to reverse reshaping is done on Values, not on keys/queries.

    Book version throws exception is d_k not the same as d_v

    • James Carmichael February 23, 2023 at 8:24 am #

      Thank you for your feedback Ivan!

  9. Ivan February 26, 2023 at 2:05 am #

    Over-regularization at a scaled dot-product attention.

    In chapter 3.2.1 of the paper “Attention is all you need” there is a footer explaining the effect of regularization term (dk**-1/2)


    To illustrate why the dot products get large, assume that the components of q and k are independent random variables with mean 0 and variance 1. Then their dot product, q · k, has mean 0 and variance dk.

    We can observe this behavior:
    q = tf.random.normal([2,3,16,8])
    k = tf.random.normal([2,3,16,8])
    qk = q@tf.transpose(k, perm=[0,1,3,2])
    tf.math.reduce_variance(qk)

    # ——— almost equal to last dimension
    qk2 = q@tf.transpose(k, perm=[0,1,3,2])/tf.math.sqrt(8.)
    # ——— add regularization term
    tf.math.reduce_variance(qk2)

    # ——— dot-prod variance is close to variance of q and k

    Back to the Book.
    In the code MultiHeadAttention.call()

    # Compute the multi-head attention output using the reshaped queries, keys and values
    o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k, mask)
    # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)

    self.d_k could be two values:
    1-st option) before reshape: d_k = number of heads * query size
    2-nd option) after reshape: d_k = query size

    In the book “1-st option” is used. Which makes dot-product over-regularized and slows down the training, due to gradients are equal across all tokens.

    To implement 2-nd option, in MultiHeadAttention.call replace
    o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k, mask)
    # to the next line
    o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k / self.heads, mask)

    To investigate scores-values you can set breakpoint immediately after score calculations in the DotProductAttention.call and add per-head-variance of keys, queries and scores to a watch window:
    math.reduce_variance(keys, axis=[0,2,3])
    math.reduce_variance(queries, axis=[0,2,3])
    math.reduce_variance(scores, axis=[0,2,3])

    Observations with current code (1-st option):
    keys and queries variances are close to each other, but scores-variance is order of magnitude smaller. Scores variance is around 0.1.

    With 2-nd option implemented all three variances are in the same ball-park.

    Over-regularization influences softmax distribution, with each term is almost equal to each other and gradients flow will be equally distributed across all terms.

    • James Carmichael February 26, 2023 at 10:35 am #

      Thank you for your feedback Ivan! Let us know if we can address any specific questions regarding our content.

  10. Ivan February 27, 2023 at 1:12 pm #

    I think this is the correct code:

    o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, self.d_k / self.heads, mask)

  11. Florian May 17, 2023 at 3:44 pm #

    Hi! And thank you for your detailed and very clear introduction to Transformer models.

    I just wanted to point out a confusing typo. In the “Implementing Multi-Head Attention from Scratch” sub-section, you write “Note that the reshape_tensor method can also receive a mask”, but I think you are talking about the attention method, instead.

    Cheers,
    Florian

    • James Carmichael May 18, 2023 at 6:07 am #

      You are very welcome Florian! We appreciate the support, feedback and suggestions!

      • Farid T. August 28, 2023 at 6:34 am #

        Florian’s correction is an important one. I have seen this in many places on this website, where a consequential typo is pointed out or an important correction is made and you just thank them. Also instead of giving detailed responses you link to generic (sometimes useful) pages and “encourage” commenters to “try it out” and “let us know what you find”.
        I purchased the eBook and I am grateful for it. But I use the comments on this website as a much-needed “errata” section. It is a bit disappointing that you are not actively engaging with your readers or making corrections when errors are pointed out.

  12. Gabriel Nascimento December 10, 2023 at 4:01 pm #

    How to use this with after Keras Bi-LSTM layer?

  13. Kishan March 16, 2024 at 1:32 am #

    Why would I pass query, key and value as input to the multihead attention? The input to the multihead attention should be the training data right if dimension [batch_dim, max_token_length, embedding_dim]?


Machine Learning Mastery is part of Guiding Tech Media, a leading digital media publisher focused on helping people figure out technology. Visit our corporate website to learn more about our mission and team.