Fine-Tuning a BERT Model

BERT is a foundational NLP model trained to understand language, but it may not perform well on any specific task out of the box. However, you can build upon BERT by adding appropriate model heads and training it for a specific task. This process is called fine-tuning. In this article, you will learn how to fine-tune a BERT model for several NLP tasks.

Let’s get started.


Overview

This article is divided into two parts; they are:

  • Fine-tuning a BERT Model for GLUE Tasks
  • Fine-tuning a BERT Model for SQuAD Tasks

Fine-tuning a BERT Model for GLUE Tasks

GLUE is a benchmark for evaluating models on natural language understanding (NLU) tasks. It contains nine tasks, including sentiment analysis, paraphrase identification, and natural language inference. For each task, the model learns task-specific behavior from labeled training examples. GLUE keeps a held-out test set for each task, and results on it are reported on a public leaderboard.

Let’s take the “sst2” task (sentiment classification) in GLUE as an example.

You can load the dataset using the Hugging Face datasets library:
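A minimal sketch of the loading step, assuming the `datasets` package is installed and that GLUE is available on the Hugging Face Hub as the `nyu-mll/glue` repository with an `sst2` configuration. The helper names `load_sst2` and `describe` are made up for illustration, and the download is wrapped in a function so nothing is fetched until you call it:

```python
# Label meaning for SST-2 as described in this article
LABEL_NAMES = {0: "negative", 1: "positive"}

def load_sst2():
    """Download (or load from the local cache) the SST-2 task of GLUE."""
    from datasets import load_dataset  # lazy import: requires `pip install datasets`
    return load_dataset("nyu-mll/glue", "sst2")

def describe(sample):
    """Format one sample dictionary as 'label name: sentence'."""
    name = LABEL_NAMES.get(sample["label"], "unlabeled")
    return f"{name}: {sample['sentence'].strip()}"

# Usage (needs network access on the first run):
#   dataset = load_sst2()
#   print(dataset)                        # lists the splits and their sizes
#   print(describe(dataset["train"][0]))
```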

Running this code, the output is:

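The loaded dataset has three splits: train, validation, and test. Each sample in the dataset is a dictionary. The keys of interest are "sentence" and "label". In the train and validation splits, the label is either 0 or 1, representing negative or positive sentiment, respectively. The labels of the test split are hidden (stored as -1) because test predictions are scored by the GLUE leaderboard.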

This dataset cannot be used directly because the model requires you to convert text sentences into token sequences. Moreover, the training loop requires data in batches, so you need to create batches of shuffled and padded sequences. Let’s create a PyTorch DataLoader with a custom collate function:

To prepare the data, use the tokenizer you trained in the previous post. The BERT model you fine-tune must have been pre-trained with the same tokenizer, because the token IDs index directly into the model's embedding table.

The collate() function takes a batch of samples as a list of dictionaries. It converts text sentences into token sequences and pads them to the same length. Unlike BERT pre-training, you do not have a pair of sentences, but you still need to use the [CLS] and [SEP] tokens as delimiters in the output sequence. The output of the collate function is a tuple of two tensors: the input IDs as a 2D tensor and the labels as a 1D tensor.

You set up two DataLoader objects: one for the training set and one for the validation set. The training set is shuffled, but the validation set is not.
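The collate function and the DataLoader setup can be sketched as follows. This is not the tokenizer from the previous post: `toy_encode` is a made-up stand-in that produces one ID per word, and the special-token IDs `PAD_ID`, `CLS_ID`, and `SEP_ID` are assumed values that you would read from your own tokenizer's vocabulary.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader

# Hypothetical special-token IDs; read the real ones from your tokenizer's vocabulary
PAD_ID, CLS_ID, SEP_ID = 0, 1, 2

def collate(batch, encode, max_len=128):
    """Convert a list of {"sentence", "label"} dicts into (input_ids, labels)."""
    sequences = []
    for sample in batch:
        ids = encode(sample["sentence"])[: max_len - 2]  # reserve room for [CLS] and [SEP]
        sequences.append([CLS_ID] + ids + [SEP_ID])
    longest = max(len(seq) for seq in sequences)
    padded = [seq + [PAD_ID] * (longest - len(seq)) for seq in sequences]
    input_ids = torch.tensor(padded, dtype=torch.long)                    # (batch, seq_len)
    labels = torch.tensor([s["label"] for s in batch], dtype=torch.long)  # (batch,)
    return input_ids, labels

def toy_encode(text):
    """Toy stand-in for a trained tokenizer: one made-up ID per whitespace-separated word."""
    return [3 + len(word) % 5 for word in text.split()]

samples = [
    {"sentence": "a good movie", "label": 1},
    {"sentence": "bad", "label": 0},
]
# shuffle=False keeps this demo deterministic; use shuffle=True for the training set
loader = DataLoader(samples, batch_size=2, shuffle=False,
                    collate_fn=partial(collate, encode=toy_encode))
input_ids, labels = next(iter(loader))
print(input_ids.shape, labels.shape)
```

Padding to the longest sequence in each batch, rather than to a fixed global length, keeps short batches small. The `max_len` cutoff of 128 is an arbitrary choice for this sketch.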

Next, you need to set up a model for GLUE tasks. Since this is a sentence classification task, you need to add a linear layer on top of the BERT model to project the hidden state of the [CLS] token to the number of labels. Below is the implementation:

You reuse the BertModel and BertConfig classes defined in the previous post. In the BertForSequenceClassification class, you use the foundation BERT model to process the input sequence. The pooled output, corresponding to the [CLS] token, is then passed through a linear layer to project it to the number of labels. The sequence output, however, is unused. The model’s output is the logits for the classification task. In sentiment classification, this is a vector of two values per sample.
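A sketch of the classification head described above. It does not reproduce the `BertModel` class from the previous post: `ToyBackbone` is a hypothetical stand-in that only mimics the assumed interface of returning a `(sequence_output, pooled_output)` pair, and the dropout before the linear layer is a common choice assumed here rather than taken from the original code.

```python
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    """Hypothetical stand-in for BertModel; returns (sequence_output, pooled_output)."""
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.pooler = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        sequence_output = self.embed(input_ids)              # (batch, seq_len, hidden)
        cls_state = sequence_output[:, 0]                    # hidden state at the [CLS] position
        pooled_output = torch.tanh(self.pooler(cls_state))   # (batch, hidden)
        return sequence_output, pooled_output

class SequenceClassifier(nn.Module):
    """Backbone plus a linear layer that maps the pooled output to class logits."""
    def __init__(self, backbone, hidden_size, num_labels, dropout=0.1):
        super().__init__()
        self.bert = backbone
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids):
        _, pooled_output = self.bert(input_ids)   # the sequence output is not needed here
        return self.classifier(self.dropout(pooled_output))

model = SequenceClassifier(ToyBackbone(), hidden_size=16, num_labels=2)
logits = model(torch.randint(0, 100, (4, 10)))    # 4 sequences of 10 token IDs
print(logits.shape)                               # one pair of logits per sequence
```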

Note: Some variants of BERT, such as RoBERTa, drop the next sentence prediction objective, so nothing in pre-training teaches the [CLS] token to summarize the sequence. For such models, fine-tuning is the only stage at which the [CLS] token learns to represent the entire input sequence.

All fine-tuning of the BERT model follows a similar architecture. In the figure below from the BERT paper, the sequence classification model above corresponds to architecture (b):

Different fine-tuning architectures of BERT. Figure from the BERT paper.

Since you have already trained the foundation BERT model, you can instantiate the model for sequence classification and then load the pretrained weights for the foundation model:
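The weight-loading step might look like the sketch below. The tiny `make_backbone` module is a hypothetical stand-in for the foundation BERT model, and an in-memory buffer replaces the checkpoint file so the example runs on its own. With a real checkpoint you would pass its file path to `torch.load()` instead.

```python
import io

import torch
import torch.nn as nn

class Classifier(nn.Module):
    """Backbone plus a fresh classification head (forward() omitted for brevity)."""
    def __init__(self, backbone, hidden_size, num_labels):
        super().__init__()
        self.bert = backbone                                   # pretrained part
        self.classifier = nn.Linear(hidden_size, num_labels)   # new, randomly initialized

def make_backbone():
    """Hypothetical stand-in for the foundation BERT model."""
    return nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))

# Pretend this backbone was pre-trained and its weights were saved with torch.save()
pretrained = make_backbone()
checkpoint = io.BytesIO()          # stands in for a checkpoint file on disk
torch.save(pretrained.state_dict(), checkpoint)
checkpoint.seek(0)

model = Classifier(make_backbone(), hidden_size=16, num_labels=2)
state_dict = torch.load(checkpoint)
model.bert.load_state_dict(state_dict)   # restores only the backbone; the head stays random
```

Loading into `model.bert` rather than into `model` avoids key mismatches, because the pre-training checkpoint has no entries for the new classification layer.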

You can now run the training loop. Compared with pre-training, fine-tuning requires only a few epochs. Otherwise, the training loop is quite typical:

Running this code, you may see:

Since you have both a training and a validation set, train on the training set and evaluate on the validation set. Be sure to use model.train() and model.eval() to set the model to training or evaluation mode, respectively, since your model uses dropout layers.
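To illustrate the train/evaluate pattern, here is a self-contained sketch that runs on random toy data. `ToyClassifier`, the data, the three epochs, and the learning rate of 2e-5 are all assumptions chosen for the demo; only the structure of the loop is the point.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy stand-ins: random token IDs and labels instead of the tokenized SST-2 data
train_data = TensorDataset(torch.randint(0, 50, (32, 8)), torch.randint(0, 2, (32,)))
val_data = TensorDataset(torch.randint(0, 50, (16, 8)), torch.randint(0, 2, (16,)))
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data, batch_size=8, shuffle=False)

class ToyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(50, 16)
        self.dropout = nn.Dropout(0.1)
        self.head = nn.Linear(16, 2)

    def forward(self, input_ids):
        cls_state = self.embed(input_ids)[:, 0]   # first position plays the role of [CLS]
        return self.head(self.dropout(cls_state))

model = ToyClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(3):
    model.train()                                 # enable dropout
    for input_ids, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(input_ids), labels)
        loss.backward()
        optimizer.step()

    model.eval()                                  # disable dropout
    correct, total = 0, 0
    with torch.no_grad():                         # no gradients needed for evaluation
        for input_ids, labels in val_loader:
            predictions = model(input_ids).argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.numel()
    accuracy = correct / total
    print(f"epoch {epoch + 1}: validation accuracy {accuracy:.3f}")
```

Because the labels here are random, the accuracy carries no meaning. With the real model, the same loop structure applies once the loaders and the classifier are swapped in.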

That’s all you need to do to fine-tune a BERT model for GLUE tasks. Below is the complete code for sequence classification:

Fine-tuning a BERT Model for SQuAD Tasks

SQuAD is a question answering dataset. Each sample contains a question and a context paragraph. The answer to the question is a span of words within the context paragraph. This is not a general question-answering task, since the answer is always a substring within the context. If no such substring exists, the question has no answer.

Let’s take a look at one sample in the dataset:

Running this code, the output is:

The SQuAD dataset has only training and validation splits. Each sample is a dictionary with the keys "id", "title", "context", "question", and "answers". The value of "answers" is itself a dictionary containing the answer text and its character offset in the context.
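The relationship between the answer text and its offset can be checked with a few lines of Python. The sample below is hand-written to mimic the record layout and is not taken from the dataset. The field names `text` and `answer_start` follow the layout of the Hugging Face `squad` dataset, where both are lists because a question may have several reference answers.

```python
# Hand-written sample that mimics the SQuAD record layout (not taken from the dataset)
sample = {
    "id": "example-0",
    "title": "Example",
    "context": "BERT was introduced by researchers at Google in 2018.",
    "question": "Who introduced BERT?",
    "answers": {"text": ["researchers at Google"], "answer_start": [23]},
}

def answer_spans(sample):
    """Return (start, end, text) character spans for every reference answer."""
    spans = []
    answers = sample["answers"]
    for text, start in zip(answers["text"], answers["answer_start"]):
        end = start + len(text)
        # The offset counts characters in the context, so slicing recovers the answer
        assert sample["context"][start:end] == text
        spans.append((start, end, text))
    return spans

spans = answer_spans(sample)
print(spans)
```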

To train a model, you need to batch and process the data samples into tensors as you did for the GLUE tasks. Let’s create a custom collate function for the SQuAD dataset:

This collate function is more complex than the one for GLUE tasks because you need to pass a pair of sentences as input in the format [CLS] question [SEP] context [SEP]. The context may be clipped to fit the maximum length. The question, context, and answer are all converted into token sequences. In the inner for-loop, you find the position of the answer span in the context. If the answer is not found within the provided context, you mark the question as having no answer.

One of the most important roles of the collate function is to create tensors for the whole batch. Here, you produce four tensors: the input (including both the question and the context), the token type IDs (indicating which tokens belong to the question and which belong to the context), and the start and end positions of the answer span.
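The per-sample logic inside such a collate function can be sketched in plain Python as below. It works on token ID lists that have already been produced by a tokenizer. The special-token IDs, the `max_len` default, and the choice of position 0 (the [CLS] token) to mark "no answer" are assumptions of this sketch, and `find_sublist` and `encode_qa` are illustrative names.

```python
# Hypothetical special-token IDs; the real ones come from your tokenizer
CLS_ID, SEP_ID = 1, 2

def find_sublist(haystack, needle):
    """Index of the first occurrence of `needle` inside `haystack`, or -1."""
    if not needle:
        return -1
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1

def encode_qa(question_ids, context_ids, answer_ids, max_len=384):
    """Build [CLS] question [SEP] context [SEP] and locate the answer span in it."""
    room = max_len - len(question_ids) - 3        # 3 special tokens
    context_ids = context_ids[:max(room, 0)]      # clip the context to fit
    input_ids = [CLS_ID] + question_ids + [SEP_ID] + context_ids + [SEP_ID]
    # 0 marks [CLS] + question + [SEP]; 1 marks context + final [SEP]
    token_type_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)
    offset = len(question_ids) + 2                # index where the context starts
    pos = find_sublist(context_ids, answer_ids)
    if pos < 0:
        start = end = 0                           # assumption: position 0 means "no answer"
    else:
        start = offset + pos
        end = start + len(answer_ids) - 1         # inclusive end position
    return input_ids, token_type_ids, start, end

ids, types, start, end = encode_qa([10, 11], [20, 21, 22, 23, 24], [22, 23])
print(ids, types, start, end)
```

A full collate function would apply this to every sample in the batch, pad the ID lists to a common length, and stack the results into the four tensors. Note that when clipping removes the answer from the context, the sample is treated as unanswerable, which matches the behavior described above.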

Next, you need to set up a model for SQuAD tasks. The strategy is to process the sequence output from the BERT model such that each token is transformed into a probability of being the start or end of the answer span. You can then identify the highest-probability start and end tokens to form the answer span. The implementation is straightforward:

The foundation BERT model produces both sequence output and pooled output. You use only the sequence output for SQuAD tasks. You pass the output through a linear layer to produce the start and end position logits, then return them separately. To convert them to probabilities, apply softmax to the logits, which can be done outside the model. This model for fine-tuning follows the architecture (c) shown in the figure in the previous section.
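As a rough sketch of this head, the module below takes a sequence output tensor and returns the two sets of logits. The class name `SpanHead` and the tensor sizes are illustrative, and a random tensor stands in for the actual BERT sequence output.

```python
import torch
import torch.nn as nn

class SpanHead(nn.Module):
    """Maps each token's hidden state to a start logit and an end logit."""
    def __init__(self, hidden_size):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, sequence_output):
        logits = self.qa_outputs(sequence_output)          # (batch, seq_len, 2)
        start_logits, end_logits = logits.unbind(dim=-1)   # each (batch, seq_len)
        return start_logits, end_logits

head = SpanHead(hidden_size=16)
sequence_output = torch.randn(4, 10, 16)            # stand-in for the BERT sequence output
start_logits, end_logits = head(sequence_output)
start_probs = torch.softmax(start_logits, dim=-1)   # probabilities over token positions
print(start_logits.shape, end_logits.shape)
```

The softmax runs over the sequence dimension, so each row of `start_probs` is a distribution over token positions for one sample.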

As in the example for GLUE tasks above, you can instantiate the model and load the pretrained weights:

And finally, you can run the training loop for fine-tuning:

The training loop is similar to the one for GLUE tasks. Instead of using the pooled output, you now use the output corresponding to each token in the sequence. The token with the highest logit value is the predicted start or end position. The loss function is the sum of the cross-entropy losses for the predicted start and end positions.
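The loss and the prediction step can be sketched with `torch.nn.functional.cross_entropy`, treating each token position as a class. The function name `span_loss` and the example tensors are made up for illustration.

```python
import math

import torch
import torch.nn.functional as F

def span_loss(start_logits, end_logits, start_positions, end_positions):
    """Sum of the cross-entropy losses for the start and end positions."""
    return (F.cross_entropy(start_logits, start_positions)
            + F.cross_entropy(end_logits, end_positions))

# Uniform logits over 4 positions: each cross-entropy term equals ln(4)
start_logits = torch.zeros(2, 4)
end_logits = torch.zeros(2, 4)
start_positions = torch.tensor([1, 0])
end_positions = torch.tensor([2, 0])
loss = span_loss(start_logits, end_logits, start_positions, end_positions)
print(loss.item(), 2 * math.log(4))

# Prediction: the position with the highest logit in each row
example_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
predicted_start = example_logits.argmax(dim=-1)
```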

This is a simplified way to use the model’s output. You can refine the logic by finding the start-end pair with the highest combined score, subject to the constraint that the end position is greater than or equal to the start position. This may improve the model’s performance.
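One way to implement that refinement is an exhaustive search over all valid pairs, sketched below in plain Python for a single sample. The function name `best_span` and the example logits are illustrative.

```python
def best_span(start_logits, end_logits):
    """Return (start, end) maximizing start_logits[start] + end_logits[end] with end >= start."""
    best_score, best = float("-inf"), (0, 0)
    for start, start_score in enumerate(start_logits):
        for end in range(start, len(end_logits)):   # only ends at or after the start
            score = start_score + end_logits[end]
            if score > best_score:
                best_score, best = score, (start, end)
    return best

start_logits = [0.1, 0.2, 3.0, 0.5]
end_logits = [4.0, 0.3, 0.1, 2.0]

# Picking each position independently gives start=2 and end=0, which is not a valid span
independent = (start_logits.index(max(start_logits)), end_logits.index(max(end_logits)))
constrained = best_span(start_logits, end_logits)
print(independent, constrained)
```

The search is quadratic in the sequence length, which is manageable for the sequence lengths BERT handles.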

Running this code, you may see:

You may notice that the model’s performance is not very good. This is likely because the foundation BERT model was pre-trained only on the small WikiText-2 dataset, so its learned representations do not generalize well to more complex tasks. For better performance in practical applications, use the official pretrained weights.

Below is the complete code for fine-tuning a BERT model for SQuAD tasks:

Summary

In this article, you learned how to fine-tune a BERT model for GLUE and SQuAD tasks. Specifically, you learned:

  • How to build a new model on top of BERT for fine-tuning
  • How to run the training loop for fine-tuning
  • The GLUE and SQuAD datasets and the tasks they are designed for
