Pretrain a BERT Model from Scratch

BERT is a transformer-based model for NLP tasks. As an encoder-only model, it has a highly regular architecture. In this article, you will learn how to create and pretrain a BERT model from scratch using PyTorch.

Let’s get started.


Overview

This article is divided into three parts; they are:

  • Creating a BERT Model the Easy Way
  • Creating a BERT Model from Scratch with PyTorch
  • Pre-training the BERT Model

Creating a BERT Model the Easy Way

If your goal is to create a BERT model so that you can train it on your own data, using the Hugging Face transformers library is the easiest way to get started. First install the library, for example with pip install transformers.

You can then create a BERT model by using the BertModel class. For example, you can load a pretrained BERT model from the Hugging Face model hub with the following code:
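The original listing is not reproduced here. A minimal sketch, assuming the checkpoint name bert-base-uncased (the article does not say which checkpoint it used) and a hypothetical helper name load_pretrained_bert:

```python
from transformers import BertModel


def load_pretrained_bert(name: str = "bert-base-uncased") -> BertModel:
    """Fetch a checkpoint from the model hub (or the local cache) as a PyTorch module.

    The default checkpoint name is an assumption, not taken from the article.
    """
    model = BertModel.from_pretrained(name)
    model.eval()  # inference mode: disables dropout
    return model
```

Calling load_pretrained_bert() needs network access the first time; later calls reuse the cached files.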

This will download the BERT model from the Hugging Face model hub and load it into a PyTorch model object. You can also create a new BERT model with a different configuration by using the BertConfig class. For example, to create a BERT model with 12 layers, 768 hidden dimensions, and 12 attention heads, you can use the following code:
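The original listing is not reproduced here. A minimal sketch of the configuration described above, leaving every other BertConfig field at its library default:

```python
from transformers import BertConfig, BertModel

# Only the three values named in the text are set explicitly;
# vocabulary size, feed-forward width, etc. keep the library defaults.
config = BertConfig(
    num_hidden_layers=12,
    hidden_size=768,
    num_attention_heads=12,
)

# Constructing the model from a config initializes random weights;
# nothing is downloaded.
model = BertModel(config)
print(model.config.num_hidden_layers, model.config.hidden_size)
```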

This will create a new, untrained BERT model with the specified configuration.

Creating a BERT Model from Scratch with PyTorch

Using the transformers library is convenient, but you lose the flexibility to customize the model architecture. However, building a BERT model from scratch with PyTorch is not very difficult. Let’s revisit the architecture of BERT:

Figure: The BERT architecture

As you can see, BERT is a stack of transformer blocks. Each transformer block consists of a self-attention layer and a feed-forward layer with GELU activation. The blocks use post-norm, meaning LayerNorm is applied after each residual connection rather than before the sublayer. You can implement one transformer block in PyTorch with the following code:
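The original listing is not reproduced here. One way to write such a block is sketched below; the class name BertBlock, the argument names, and the default sizes are assumptions chosen for illustration, and nn.MultiheadAttention stands in for a hand-written attention layer:

```python
import torch
import torch.nn as nn


class BertBlock(nn.Module):
    """One post-norm transformer block: self-attention, then a GELU feed-forward layer."""

    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size),
        )
        self.ff_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, pad_mask=None):
        # x: (B, T, H); pad_mask: (B, T) bool, True at padding positions
        attn_out, _ = self.attention(x, x, x, key_padding_mask=pad_mask, need_weights=False)
        x = self.attn_norm(x + self.dropout(attn_out))  # post-norm: normalize after the residual
        x = self.ff_norm(x + self.dropout(self.ff(x)))
        return x


# Small sizes so the example runs quickly
block = BertBlock(hidden_size=32, num_heads=4, intermediate_size=64)
x = torch.randn(2, 5, 32)
pad_mask = torch.tensor([[False] * 5, [False, False, False, True, True]])
print(block(x, pad_mask).shape)
```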

The BERT model requires a pooler that transforms the hidden state of the [CLS] token for classification tasks. The [CLS] token is a special placeholder used to represent the entire sequence, so its representation should be distinguished from other token states. The pooler is simply a linear layer with a tanh activation function. You can implement it with the following code:
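The original listing is not reproduced here. A minimal sketch of such a pooler, assuming the [CLS] token sits at position 0 of every sequence (the class name BertPooler is chosen for illustration):

```python
import torch
import torch.nn as nn


class BertPooler(nn.Module):
    """Linear layer plus tanh applied to the hidden state of the first token."""

    def __init__(self, hidden_size=768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # hidden_states: (B, T, H); position 0 is assumed to hold [CLS]
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))  # (B, H), values in [-1, 1]


pooler = BertPooler(hidden_size=32)
pooled = pooler(torch.randn(2, 5, 32))
print(pooled.shape)
```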

Now, you can implement the BERT model using the above building blocks. The BERT model takes a sequence of integer tokens as input, which must be converted into embedding vectors before the transformer blocks process them. Moreover, the model applies an attention mask so that no position attends to padding tokens.

The BERT model can be implemented as follows:
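The original listing is not reproduced here. The sketch below is one possible layout, not the author's code. To stay self-contained it swaps in PyTorch's built-in nn.TransformerEncoderLayer (post-norm, GELU) for a hand-written block, and the config field names (vocab_size, num_types, max_seq_len, pad_id, and so on) are assumptions:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class BertModel(nn.Module):
    """Backbone sketch: summed embeddings -> post-norm encoder layers -> pooler."""

    def __init__(self, config):
        super().__init__()
        h = config.hidden_size
        self.pad_id = config.pad_id
        self.word_embeddings = nn.Embedding(config.vocab_size, h, padding_idx=config.pad_id)
        self.type_embeddings = nn.Embedding(config.num_types, h)
        self.position_embeddings = nn.Embedding(config.max_seq_len, h)  # learned positions
        self.embeddings_norm = nn.LayerNorm(h)
        self.embeddings_dropout = nn.Dropout(config.dropout)
        # Built-in layer used as a stand-in; norm_first=False selects post-norm.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=h, nhead=config.num_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout, activation="gelu",
                batch_first=True, norm_first=False,
            )
            for _ in range(config.num_layers)
        ])
        self.pooler = nn.Sequential(nn.Linear(h, h), nn.Tanh())

    def forward(self, input_ids, token_type_ids):
        pad_mask = input_ids == self.pad_id  # (B, T), True where attention is blocked
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        x = (self.word_embeddings(input_ids)
             + self.type_embeddings(token_type_ids)
             + self.position_embeddings(positions))
        x = self.embeddings_dropout(self.embeddings_norm(x))
        for block in self.blocks:
            x = block(x, src_key_padding_mask=pad_mask)
        pooled = self.pooler(x[:, 0])  # [CLS] is assumed to be at position 0
        return x, pooled


# Tiny hypothetical configuration so the example runs quickly
config = SimpleNamespace(vocab_size=100, hidden_size=32, num_layers=2, num_heads=4,
                         intermediate_size=64, max_seq_len=16, num_types=2,
                         dropout=0.1, pad_id=0)
model = BertModel(config)
input_ids = torch.tensor([[2, 15, 7, 3, 0, 0], [2, 9, 3, 21, 8, 3]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 1]])
sequence_output, pooled_output = model(input_ids, token_type_ids)
print(sequence_output.shape, pooled_output.shape)
```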

The BERT model embeds not only the input tokens but also the token type. Moreover, BERT uses learned position embeddings. You need to sum the three embeddings and pass the result to the transformer blocks. The normalization and dropout applied after the embeddings help regularize the model and stabilize training.

The model returns the hidden state of the entire sequence and the pooled output of the [CLS] token, which are useful for the MLM and NSP tasks, respectively.

Notice that the model is instantiated with a config object. This helps avoid listing all hyperparameters in the constructor of the BertModel class. The config object is simply defined as:
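The original definition is not reproduced here. A dataclass is one convenient way to write such a config object; the field names below are assumptions, and the defaults follow the published BERT-base sizes rather than the article's own settings:

```python
from dataclasses import dataclass


@dataclass
class BertConfig:
    """Hyperparameters for the backbone (field names are illustrative)."""
    vocab_size: int = 30522
    num_layers: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    intermediate_size: int = 3072  # feed-forward width, 4 * hidden_size
    max_seq_len: int = 512
    num_types: int = 2             # segment A / segment B
    dropout: float = 0.1
    pad_id: int = 0                # assumed ID of the padding token

    def __post_init__(self):
        # Each attention head works on hidden_size / num_heads dimensions
        if self.hidden_size % self.num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")


# Override only what differs from the defaults
small_config = BertConfig(num_layers=4, hidden_size=256, num_heads=4, intermediate_size=1024)
print(small_config)
```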

The above code defines the BERT model backbone. When you pre-train the model, you need to add pretraining heads to generate predictions for the MLM and NSP tasks. Let’s implement a pretraining model that uses the BERT backbone and adds pretraining heads.
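The original listing is not reproduced here. A sketch of one common layout, assuming the backbone returns the pair (sequence output, pooled output) as described earlier. The class names, the attribute name bert, and the exact shape of the MLM head are assumptions, and StandInBackbone is a hypothetical placeholder used only to make the example run:

```python
import torch
import torch.nn as nn


class BertPretrainingModel(nn.Module):
    """Wraps a backbone and adds an MLM head and an NSP head."""

    def __init__(self, backbone, hidden_size, vocab_size):
        super().__init__()
        self.bert = backbone
        # MLM head: transform each token state, then project to vocabulary logits
        self.mlm_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, vocab_size),
        )
        # NSP head: binary classification on the pooled [CLS] state
        self.nsp_head = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, token_type_ids):
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids)
        return self.mlm_head(sequence_output), self.nsp_head(pooled_output)


class StandInBackbone(nn.Module):
    """Hypothetical placeholder with the same (sequence, pooled) return signature."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids, token_type_ids):
        x = self.embed(input_ids)
        return x, torch.tanh(x[:, 0])


model = BertPretrainingModel(StandInBackbone(50, 16), hidden_size=16, vocab_size=50)
input_ids = torch.randint(0, 50, (3, 7))
token_type_ids = torch.zeros(3, 7, dtype=torch.long)
mlm_logits, nsp_logits = model(input_ids, token_type_ids)
print(mlm_logits.shape, nsp_logits.shape)
```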

Pre-training the BERT Model

Pre-training the BERT model requires a labeled dataset. See the previous post for instructions on creating the labeled dataset.

The first step is to create a data loader for the pretraining dataset. Like most other models, BERT operates on batches of data rather than individual samples. The data loader helps you shuffle and batch the data for training and allows you to customize the data to fit the training pipeline.

Let’s see how you can create a PyTorch DataLoader object with the labeled dataset.
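The original listing is not reproduced here. In the sketch below, a small in-memory list stands in for the parquet-backed dataset so that the example runs on its own. The sample values are invented, the sequences are assumed to be already padded to a common length, and storing one (batch index, position) pair per masked token is one reading of the description in the following paragraphs:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy samples with the same keys as the real dataset. With the real
# data you would instead load the parquet file, for example via
# datasets.load_dataset("parquet", data_files=...) with your own path.
toy_dataset = [
    {"tokens": [2, 11, 4, 12, 3, 13, 3, 0], "segment_ids": [0, 0, 0, 0, 0, 1, 1, -1],
     "is_random_next": False, "masked_positions": [2], "masked_labels": [15]},
    {"tokens": [2, 4, 21, 3, 22, 4, 23, 3], "segment_ids": [0, 0, 0, 0, 1, 1, 1, 1],
     "is_random_next": True, "masked_positions": [1, 5], "masked_labels": [20, 24]},
    {"tokens": [2, 31, 3, 4, 3, 0, 0, 0], "segment_ids": [0, 0, 0, 1, 1, -1, -1, -1],
     "is_random_next": False, "masked_positions": [3], "masked_labels": [32]},
]


def collate(batch):
    """Turn a list of sample dictionaries into tensors for one training step."""
    input_ids = torch.tensor([s["tokens"] for s in batch])
    token_type_ids = torch.tensor([s["segment_ids"] for s in batch])
    token_type_ids[token_type_ids < 0] = 1  # -1 marks padding; any valid ID works there
    is_random_next = torch.tensor([int(s["is_random_next"]) for s in batch])
    # The number of masked tokens differs per sample, so keep (batch index, position) pairs
    masked_pos = [(i, p) for i, s in enumerate(batch) for p in s["masked_positions"]]
    masked_labels = torch.tensor([t for s in batch for t in s["masked_labels"]])
    return input_ids, token_type_ids, is_random_next, masked_pos, masked_labels


# num_workers=0 is enough for the toy data; raise it for a real dataset
loader = DataLoader(toy_dataset, batch_size=2, shuffle=True, collate_fn=collate, num_workers=0)
for input_ids, token_type_ids, is_random_next, masked_pos, masked_labels in loader:
    print(input_ids.shape, len(masked_pos), masked_labels.shape)
```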

The dataset object is created using the Hugging Face datasets library, which loads the parquet file created in the previous post. The dataset is preprocessed to create labels for the MLM and NSP tasks. Each sample in the dataset is a Python dictionary with the following keys:

  • tokens: the integer tokens of the sequence
  • segment_ids: the segment labels
  • is_random_next: a Boolean label indicating whether the next sentence is from another document
  • masked_positions: a list of masked positions
  • masked_labels: the original tokens at the masked positions

PyTorch’s DataLoader can help you shuffle and batch the data. You should set num_workers appropriately to use multiple CPU cores, so data loading is not a bottleneck in your training.

You can set a custom collate function in the DataLoader to transform the data into tensors that can be fed into the model. Note that the segment_ids in the previous post use -1 for padding tokens, which is not a valid index into the token type embedding. Since the padding positions are ignored anyway, you can set those values to 1 for convenience.

Masked positions and labels must be handled differently. Each sample may have a different number of masked positions, so you cannot stack them into a single tensor. Instead, you keep the masked positions as a list of tuples, each containing the batch index and the positions. The masked labels are simply a flattened tensor of the original tokens at the masked positions. The reason will become clear when you implement the training loop, as follows:
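The original listing is not reproduced here. The sketch below shows the shape of such a loop under several assumptions: TinyStandIn is a hypothetical placeholder for the pretraining model, the batch is hand-made, masked positions are stored as (batch index, position) pairs with at least one pair per batch, the optimizer and scheduler settings are illustrative rather than the article's, and the tqdm progress bar is left out:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pretraining_loss(sequence_logits, nsp_logits, masked_pos, masked_labels, is_random_next):
    """Sum of the MLM loss (masked positions only) and the NSP loss."""
    batch_idx, token_idx = zip(*masked_pos)
    mlm_logits = sequence_logits[list(batch_idx), list(token_idx)]  # (M, V)
    mlm_loss = F.cross_entropy(mlm_logits, masked_labels)
    nsp_loss = F.cross_entropy(nsp_logits, is_random_next)          # (B, 2) vs (B,)
    return mlm_loss + nsp_loss


class TinyStandIn(nn.Module):
    """Hypothetical placeholder returning (B, T, V) MLM logits and (B, 2) NSP logits."""

    def __init__(self, vocab_size=32, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.type_embed = nn.Embedding(2, hidden_size)
        self.mlm_head = nn.Linear(hidden_size, vocab_size)
        self.nsp_head = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, token_type_ids):
        h = self.embed(input_ids) + self.type_embed(token_type_ids)
        return self.mlm_head(h), self.nsp_head(h[:, 0])


def train(model, batches, num_epochs=2, lr=1e-3):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.1, total_iters=num_epochs * len(batches)
    )
    model.train()
    losses = []
    for _ in range(num_epochs):
        for input_ids, token_type_ids, is_random_next, masked_pos, masked_labels in batches:
            sequence_logits, nsp_logits = model(input_ids, token_type_ids)
            loss = pretraining_loss(sequence_logits, nsp_logits,
                                    masked_pos, masked_labels, is_random_next)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            losses.append(loss.item())
    return losses


# One hand-made batch in the layout described above
batch = (
    torch.tensor([[2, 11, 4, 12, 3], [2, 4, 21, 3, 0]]),  # input_ids
    torch.tensor([[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]),     # token_type_ids
    torch.tensor([0, 1]),                                 # is_random_next
    [(0, 2), (1, 1)],                                     # masked positions
    torch.tensor([15, 20]),                               # masked labels
)
torch.manual_seed(0)
losses = train(TinyStandIn(), [batch], num_epochs=20, lr=1e-2)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Indexing the logits with the two index lists picks one row of vocabulary logits per masked token, which is why the labels can be kept as a single flattened tensor.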

This is a standard training loop in PyTorch, but much simplified compared with the training procedure described in section A.2 of the original BERT paper. You set up the optimizer, scheduler, and loss function, then iterate over the data loader and update the model parameters. The tqdm library is used to visualize training progress.

The pretraining model outputs a sequence of logits for the MLM task and a pair of logits per sample for the NSP task. Calculating the NSP loss is straightforward since the output is a tensor of shape (B, 2) and the target is a vector of 0s and 1s. The MLM loss is calculated only at the masked positions, which cover about 15% of the input tokens. You need to extract the logits at those positions as mlm_logits and then compare them with masked_labels. The overall loss is the sum of the MLM and NSP losses.

This training loop runs for 10 epochs (the original paper trained for about 40). Depending on your hardware, it may take an hour to complete even for the smaller WikiText-2 dataset. If you’re using the larger WikiText-103 dataset, it may take a day to complete. The trained model is saved to the file bert_pretraining_model.pth. However, you typically do not need the full pretraining model, since the pretraining heads are not useful for other tasks. You can extract the BERT model’s backbone and save it separately.
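Saving only the backbone could look like the sketch below. It assumes the pretraining model keeps its backbone in an attribute named bert; that attribute name, the stand-in model, and the file name bert_backbone.pth are all hypothetical:

```python
import os
import tempfile

import torch
import torch.nn as nn


class StandInPretrainingModel(nn.Module):
    """Hypothetical placeholder: a 'backbone' submodule plus two pretraining heads."""

    def __init__(self, hidden_size=16, vocab_size=32):
        super().__init__()
        self.bert = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.mlm_head = nn.Linear(hidden_size, vocab_size)
        self.nsp_head = nn.Linear(hidden_size, 2)


def save_backbone(pretraining_model, path, backbone_attr="bert"):
    """Save only the backbone weights, leaving the pretraining heads behind."""
    backbone = getattr(pretraining_model, backbone_attr)
    torch.save(backbone.state_dict(), path)
    return backbone


model = StandInPretrainingModel()
path = os.path.join(tempfile.mkdtemp(), "bert_backbone.pth")
save_backbone(model, path)

# Later: rebuild a backbone with the same architecture and load the weights into it
fresh_backbone = nn.Sequential(nn.Linear(16, 16), nn.Tanh())
fresh_backbone.load_state_dict(torch.load(path, map_location="cpu"))
print(sorted(fresh_backbone.state_dict().keys()))
```

Because only the backbone's state dict is written, the saved file contains no head parameters, and a downstream model can attach its own task-specific head.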

Summary

In this article, you learned how to create and pretrain a BERT model from scratch using PyTorch. Specifically, you learned:

  • How to build the BERT backbone from transformer blocks, embeddings, and a pooler
  • How to use PyTorch DataLoader with a custom collate function to batch data and handle a variable number of masked positions per sample
  • How to pre-train a BERT model using the MLM and NSP tasks
