Decoder-only language models like Llama are usually trained using self-supervised learning objectives on large amounts of text. This is called pretraining to distinguish it from later fine-tuning steps on specific tasks. In this article, you will learn how to pretrain a Llama model on a local GPU. Specifically, you will learn how to:
- Prepare the training data
- Run the pretraining
Let’s get started.

Pretraining a Llama Model on Your Local GPU
Photo by Hongbin. Some rights reserved.
Overview
This article is divided into three parts; they are:
- Training a Tokenizer with Special Tokens
- Preparing the Training Data
- Running the Pretraining
Training a Tokenizer with Special Tokens
The model architecture you will use is the same as the one created in the previous post. This is a 12-layer Llama model with a vocabulary size of 50,000. The data you will use for pretraining is the HuggingFaceFW/fineweb dataset.
To prepare the training data, you first need to set up the tokenizer. To recap, the following code trains a BPE tokenizer on the HuggingFaceFW/fineweb dataset and saves it to a file:
```python
from typing import Iterator

import datasets
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, normalizers

# Load FineWeb 10B sample (using only a slice for demo to save memory)
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train", streaming=True)

def get_texts(dataset: datasets.Dataset, limit: int = 100_000) -> Iterator[str]:
    """Get texts from the dataset until the limit is reached or the dataset is exhausted."""
    count = 0
    for sample in dataset:
        yield sample["text"]
        count += 1
        if limit and count >= limit:
            break

# Initialize a BPE model
tokenizer = Tokenizer(models.BPE(byte_fallback=True, unk_token="[UNK]"))
tokenizer.normalizer = normalizers.NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True, use_regex=False)
tokenizer.decoder = decoders.ByteLevel()

# Trainer
trainer = trainers.BpeTrainer(
    vocab_size=50_000,
    min_frequency=2,
    special_tokens=["[PAD]", "[BOT]", "[EOT]", "[UNK]"],
    show_progress=True,
)

# Train and save the tokenizer to disk
texts = get_texts(dataset, limit=100_000)
tokenizer.train_from_iterator(texts, trainer=trainer)
tokenizer.save("bpe_50k.json")
```
This tokenizer uses the BPE (byte-pair encoding) algorithm at the byte level. Normally, it would not emit any unknown tokens, but you still set a special token for them. Additionally, you set special tokens for the beginning of text ([BOT]), end of text ([EOT]), and padding ([PAD]). These are useful for next-token prediction.
This code automatically uses all CPU cores. Running this code will take a few minutes on a high-end computer. The trained tokenizer will be saved to a file named bpe_50k.json. Once trained, you can load it back with the following code:
```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("bpe_50k.json")
```
Note that you trained the tokenizer with a vocabulary size of 50,000. This is generally sufficient for a single-language model. However, if you intend to train a model for multiple languages, a larger vocabulary size is preferred.
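As a quick sanity check, you can verify how the trainer registers the special tokens. The sketch below is a toy run of the same BPE setup on a two-sentence corpus (not the FineWeb training above), just to show that the special tokens receive the lowest IDs in the order they were listed:

```python
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, normalizers

# Toy run of the same BPE configuration on a tiny corpus,
# only to inspect where the special tokens land in the vocabulary
tokenizer = Tokenizer(models.BPE(byte_fallback=True, unk_token="[UNK]"))
tokenizer.normalizer = normalizers.NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True, use_regex=False)
tokenizer.decoder = decoders.ByteLevel()
trainer = trainers.BpeTrainer(
    vocab_size=1000,
    min_frequency=1,
    special_tokens=["[PAD]", "[BOT]", "[EOT]", "[UNK]"],
    show_progress=False,
)
tokenizer.train_from_iterator(["hello world", "hello there"], trainer=trainer)

# Special tokens are assigned the first IDs, in the order given
print([tokenizer.token_to_id(t) for t in ["[PAD]", "[BOT]", "[EOT]", "[UNK]"]])  # [0, 1, 2, 3]
```

The same ordering holds for the full 50,000-token tokenizer, which is why the padding token ID can be looked up reliably later with `token_to_id("[PAD]")`.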
Preparing the Training Data
Pretraining a language model means training it to predict the next token in a sequence. From the training data, you need to tokenize the text into a tensor of integer token IDs, with a shift-by-one version of the same tensor serving as the prediction target.
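The shift-by-one idea can be sketched with plain Python lists before any tensors are involved (the token IDs below are made up for illustration):

```python
# Next-token prediction: input x and target y are the same sequence,
# offset by one position (IDs here are hypothetical)
tokens = [1, 17, 42, 99, 2]  # e.g. [BOT], three text tokens, [EOT]
x = tokens[:-1]  # what the model sees
y = tokens[1:]   # what it should predict at each position
print(x)  # [1, 17, 42, 99]
print(y)  # [17, 42, 99, 2]
```

At position i, the model reads `x[:i+1]` and is trained to output `y[i]`, which is simply the next token in the original sequence.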
As you can see from the previous section, you can load the dataset and print out the text as strings by iterating over the dataset object:
```python
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")
for sample in dataset:
    print(sample["text"])
    break
```
This dataset is small compared to those usually used for language model training. However, it is still large enough to contain diverse samples of human language.
For pretraining, you need to create a PyTorch Dataset object so that your model can consume it, as follows:
```python
class PretrainingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, tokenizer, seq_length, device):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Get a sequence of token ids from the dataset.

        [BOT] and [EOT] tokens are added. Clipped and padded to the sequence length.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
        # pad to target sequence length
        toklen = len(tokens)
        if toklen < self.seq_length+1:
            pad_length = self.seq_length+1 - toklen
            tokens += [self.pad] * pad_length
        # return the sequence
        x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64, device=self.device)
        y = torch.tensor(tokens[1:self.seq_length+1], dtype=torch.int64, device=self.device)
        return x, y
```
This is the simplest way to tokenize text data for pretraining. You wrap the Hugging Face dataset object and report its number of samples in the __len__ method. In the __getitem__ method, you tokenize one text sample into a tensor of integer token IDs. You add the beginning-of-text and end-of-text tokens to help with pretraining: given just the beginning-of-text token, the model learns to predict the first token of a sequence, and given the entire sequence, it learns to predict the end-of-text token.
A transformer model does not limit the length you pass to it, except for a maximum sequence length that the positional encoding can handle. However, when you pass multiple sequences as a batch, you need to ensure all sequences have the same length so you can stack them into a single tensor. You add padding tokens to shorter sequences and clip longer sequences to the target sequence length.
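The padding-and-clipping logic can be sketched independently of the dataset class (the PAD ID and sequence length below are arbitrary toy values):

```python
PAD = 0         # hypothetical padding token ID
seq_length = 6  # target length; each row keeps seq_length + 1 tokens

def fix_length(tokens: list[int]) -> list[int]:
    """Clip long sequences and pad short ones to a fixed length."""
    tokens = tokens[:seq_length + 1]
    return tokens + [PAD] * (seq_length + 1 - len(tokens))

print(fix_length([1, 5, 9, 2]))        # [1, 5, 9, 2, 0, 0, 0]
print(fix_length(list(range(1, 11))))  # [1, 2, 3, 4, 5, 6, 7]
```

Keeping seq_length + 1 tokens leaves room for both the input slice and its shift-by-one target.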
Pretraining is self-supervised learning: the label for the expected output is already in the input sequence. Therefore, you set x as the input sequence and its shift-by-one version as the target sequence y. You want them to be PyTorch tensors instead of Python lists so you can use them with a PyTorch data loader. You must also set the data type to int64 because PyTorch's CrossEntropyLoss requires its targets to be integer class indices of this type, including the padding token IDs that it will later be told to ignore when computing the training loss.
You can test the dataset by creating a DataLoader object and drawing a batch from it:
```python
batch_size = 8
seq_length = 512
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

dataloader = torch.utils.data.DataLoader(
    PretrainingDataset(dataset, tokenizer, seq_length, device),
    batch_size=batch_size
)
for x, y in dataloader:
    print(x)
    print(y)
    break
```
Running the Pretraining
Once you have the input and target data ready from the dataset, running pretraining on a language model is no different from training other deep learning models.
Using the model code from the previous post, let’s first create a model object:
```python
# Create pretraining model with default config
model_config = LlamaConfig()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = LlamaForPretraining(model_config).to(device)
```
This is a small model for demonstration purposes. It has only 171 million parameters, much smaller than any large language model you can find on the internet.
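You can reproduce the parameter count with a back-of-envelope calculation from the default config (hidden size 768, 12 layers, 12 query heads, 3 KV heads, vocabulary 50,000). The raw total is about 179 million; the 171M figure corresponds to dividing by 1024², which is how the model-size printout in the full code at the end of this post reports it. The small RMSNorm weights are ignored here as negligible:

```python
hidden, inter, layers, vocab = 768, 4 * 768, 12, 50_000
head_dim = hidden // 12  # 64
kv_dim = 3 * head_dim    # grouped-query attention: 3 KV heads -> 192

attn = 2 * hidden * hidden + 2 * hidden * kv_dim  # q_proj, o_proj + k_proj, v_proj
mlp = 3 * hidden * inter                          # gate_proj, up_proj, down_proj
per_layer = attn + mlp

total = layers * per_layer + 2 * vocab * hidden   # + embedding and LM head
print(f"{total:,} parameters ~= {total / 1024**2:.0f}M")  # 179,429,376 parameters ~= 171M
```

Note how the embedding table and LM head alone account for over 40% of the parameters in a model this small.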
Next, you should define the training parameters. Depending on your hardware, you may want to adjust the batch size, but keeping the sequence length moderately long helps the model learn context. Here is the strategy to use:
- This dataset has only a training split. For simplicity, the data is not shuffled, no holdout set is created, and the training loop does not contain any evaluation step.
- Next-token prediction is a classification problem over the entire vocabulary. Naturally, the loss function is cross-entropy. You should ensure that padding tokens are not used in computing the loss, as they are not valid inputs.
- Set the sequence length to 512. The resources required to train a model scale as $O(N^2)$ with sequence length. Therefore, you prefer to keep it short, but a sequence length that is too short prevents the model from understanding longer contexts.
- Following best practices for training large language models, use a cosine learning rate scheduler with a warmup period. The warmup period can be set to a fixed number of steps or to a percentage of the total training steps (e.g., 0.1%-2%). Let’s set it to 1,000 steps here.
- Once the sequence length is determined, adjust the batch size to fit your GPU memory. You can start with 8, which empirically fits into 12GB of VRAM.
- With about 14 million samples and 10 billion tokens in the HuggingFaceFW/fineweb sample-10BT dataset, you probably do not need to train for many epochs. In fact, many large language models are trained for only 1-3 epochs on very large datasets.
Let’s put these parameters together to define the training configuration:
```python
# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 8
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")

# DataLoader, optimizer, scheduler, and loss function
model.train()
dataloader = torch.utils.data.DataLoader(
    PretrainingDataset(dataset, tokenizer, seq_length, device),
    batch_size=batch_size
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01
)
num_training_steps = len(dataloader) * epochs
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=num_warmup_steps
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - num_warmup_steps,
    eta_min=0
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps]
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)
```
The AdamW optimizer is configured with a peak learning rate of 1e-3. Other parameters are set to their defaults. The cosine scheduler from PyTorch is combined with a linear scheduler to implement the warmup period. They are combined using the SequentialLR scheduler and configured to switch from a linear to a cosine schedule at the 1,000th step.
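To see the shape of this schedule without running real training, you can step the same scheduler combination on a dummy optimizer with toy numbers (5 warmup steps and 20 total steps instead of 1,000 and millions):

```python
import torch
import torch.optim.lr_scheduler as lr_scheduler

# Toy version of the warmup-then-cosine schedule:
# 5 warmup steps, 20 steps in total, peak learning rate 1e-3
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
warmup = lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=5)
cosine = lr_scheduler.CosineAnnealingLR(optimizer, T_max=15, eta_min=0)
scheduler = lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[5])

lrs = []
for _ in range(20):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# The learning rate ramps up linearly, then decays along a cosine curve
print([round(lr, 6) for lr in lrs])
```

Printing the collected rates shows the linear ramp to the peak, followed by the cosine decay toward zero.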
Note that you did not set streaming=True when loading the dataset for training, nor did you shuffle the dataset. This makes the DataLoader object deterministic. This way, you can easily determine the total number of training steps, which helps you set up the learning rate scheduler.
The loss function uses nn.CrossEntropyLoss with the padding token ID set as the ignore index. This means whenever the reference target is a padding token, the loss is not computed. This is important to match the behavior you defined when you created the dataset object in the previous section.
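A minimal check of the ignore_index behavior (the vocabulary size and token IDs here are made up): positions whose target is the padding ID are excluded from the mean, so the result equals the cross-entropy over the non-padding positions only.

```python
import torch
import torch.nn as nn

PAD = 0  # hypothetical padding token ID
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)

logits = torch.randn(4, 10)               # 4 positions, vocabulary of 10
targets = torch.tensor([3, 7, PAD, PAD])  # last two positions are padding

# The padded positions contribute nothing: the loss equals the mean
# cross-entropy over just the first two positions
loss = loss_fn(logits, targets)
reference = nn.CrossEntropyLoss()(logits[:2], targets[:2])
print(torch.allclose(loss, reference))  # True
```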
This is a small model and a small dataset by large language model standards. However, the training is still very slow. Running the training from scratch on a single GPU will take several hundred hours. It is important that you can checkpoint the model and resume training. Let’s implement this in a training loop:
```python
# look for last checkpoint
if os.path.exists("llama_pretraining_checkpoint.pth"):
    checkpoint = torch.load("llama_pretraining_checkpoint.pth")
    begin_epoch = checkpoint["epoch"]
    begin_batch = checkpoint["batch"]
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    del checkpoint
    print(f"Resuming training from epoch {begin_epoch} and batch {begin_batch}")
else:
    begin_epoch = 0
    begin_batch = 0

# start training
for epoch in range(begin_epoch, epochs):
    dataloader = torch.utils.data.DataLoader(
        PretrainingDataset(
            dataset.skip(begin_batch * batch_size),
            tokenizer,
            seq_length,
            device,
        ),
        batch_size=batch_size
    )
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        if (begin_batch + batch_id) % 1000 == 0:
            # checkpoint the model and optimizer state
            torch.save({
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "batch": batch_id + begin_batch,
            }, "llama_pretraining_checkpoint.pth")
        # get batched data
        input_ids, target_ids = batch
        # create attention mask: causal mask + padding mask
        attn_mask = create_causal_mask(input_ids.shape[1], device) + \
                    create_padding_mask(input_ids, PAD_TOKEN_ID, device)
        # extract output from model
        logits = model(input_ids, attn_mask)
        # compute loss: cross-entropy between logits and target, ignoring padding tokens
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # backward with loss and apply gradient clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
        pbar.update(1)
    begin_batch = 0  # reset for next epoch
    pbar.close()
```
When you checkpoint the training, you need to save the model state, the optimizer state, and the scheduler state. You also need to remember the epoch and batch index so you can resume from the same batch in the dataset.
You visualize the training progress with a progress bar from the tqdm library. During training, you pull a pair of input and target tensors from the DataLoader object. The datasets library allows you to skip an arbitrary number of samples. You use this to create a DataLoader object to resume from the previous checkpoint.
Then you create an attention mask to mask out padding tokens and enable causal masking to control the self-attention mechanism. The model output is a 3D tensor with the same batch size and sequence length as your input. You need to reshape it for the loss function, then update the model with the computed loss. Everything is standard for training a deep learning model.
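How the two masks combine can be sketched with simplified re-implementations of the helpers, using a hypothetical 4-token batch with pad ID 0:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # -inf above the diagonal blocks attention to future positions
    return torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

def padding_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
    # -inf wherever a padding token appears as query or key
    padded = torch.zeros(batch.shape).masked_fill(batch == pad_id, float("-inf"))
    return (padded[:, :, None] + padded[:, None, :])[:, None, :, :]

ids = torch.tensor([[5, 6, 0, 0]])  # one sequence, last two tokens are padding
mask = causal_mask(4) + padding_mask(ids, 0)
print(mask.shape)        # torch.Size([1, 1, 4, 4])
print(mask[0, 0, 1, 0])  # tensor(0.) -- position 1 may attend to position 0
print(mask[0, 0, 0, 2])  # tensor(-inf) -- padding key is blocked
```

Adding the two masks works because -inf dominates: a position is visible only where both the causal and the padding mask leave a zero.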
At the end, you can save the model so you can reuse it for inference:
```python
torch.save(model.state_dict(), "llama_pretraining_model.pth")
torch.save(model.base_model.state_dict(), "llama_model.pth")
```
Depending on your use case, you may want to save the base model, the pretraining model, or both. The base model is useful for other tasks, while the pretraining model is useful as a generative model.
For completeness, below is the full code for the training:
```python
import dataclasses
import os

import datasets
import tqdm
import tokenizers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch import Tensor

# Load the tokenizer
tokenizer = tokenizers.Tokenizer.from_file("bpe_50k.json")

# Load the dataset
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Build the model
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000              # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768               # Dimension of hidden layers
    intermediate_size: int = 4*768       # Dimension of MLP's hidden layer
    num_hidden_layers: int = 12          # Number of transformer layers
    num_attention_heads: int = 12        # Number of attention heads
    num_key_value_heads: int = 3         # Number of key-value heads for GQA

def rotate_half(x: Tensor) -> Tensor:
    """Rotates half the hidden dims of the input.

    This is a helper function for rotary position embeddings (RoPE).
    For a tensor of shape (..., d), it returns a tensor where the last
    d/2 dimensions are rotated by swapping and negating.

    Args:
        x: Input tensor of shape (..., d)

    Returns:
        Tensor of same shape with rotated last dimension
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)  # Concatenate with rotation

class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module

        Args:
            dim: The hidden dimension of the input tensor to which RoPE is applied
            max_position_embeddings: The maximum sequence length of the input tensor
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # compute a matrix of n\theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = x.shape
        dtype = x.dtype
        # transform the cosine and sine matrices to 4D tensor and the same dtype as x
        cos = self.cos.to(dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(dtype)[:seq_len].view(1, seq_len, 1, -1)
        # apply RoPE to x
        output = (x * cos) + (rotate_half(x) * sin)
        return output

class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q
        # hidden_size must be divisible by num_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size
        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        bs, seq_len, dim = hidden_states.size()
        # Project inputs to Q, K, V
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        # Apply rotary position embeddings
        query_states = rope(query_states)
        key_states = rope(key_states)
        # Transpose tensors from BSHD to BHSD dimension for scaled_dot_product_attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # Use PyTorch's optimized attention implementation
        # setting is_causal=True is incompatible with setting explicit attention mask
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attn_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )
        # Transpose output tensor from BHSD to BSHD dimension, reshape to 3D, and then project output
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SwiGLU activation function
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU activation: multiply gate and up-projected inputs
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # First residual block: Self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
        hidden_states = attn_outputs + residual
        # Second residual block: MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states) + residual
        return hidden_states

class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, then the final norm layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        hidden_states = self.norm(hidden_states)
        # Return the final hidden states
        return hidden_states

class LlamaForPretraining(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids, attn_mask)
        return self.lm_head(hidden_states)

def create_causal_mask(seq_len: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.

    Args:
        seq_len: Length of the sequence
        device: Device to create the mask on
        dtype: Data type of the mask

    Returns:
        Causal mask of shape (seq_len, seq_len)
    """
    mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype) \
                .triu(diagonal=1)
    return mask

def create_padding_mask(batch, padding_token_id, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token

    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    padded = torch.zeros_like(batch, device=device, dtype=dtype) \
                  .masked_fill(batch == padding_token_id, float('-inf'))
    mask = padded[:,:,None] + padded[:,None,:]
    return mask[:, None, :, :]

# Dataset class to create padded sequences of fixed length
class PretrainingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                 seq_length: int, device: torch.device = None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Get a sequence of token ids from the dataset.

        [BOT] and [EOT] tokens are added. Clipped and padded to the sequence length.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
        # pad to target sequence length
        toklen = len(tokens)
        if toklen < self.seq_length+1:
            pad_length = self.seq_length+1 - toklen
            tokens += [self.pad] * pad_length
        # return the sequence
        x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64, device=self.device)
        y = torch.tensor(tokens[1:self.seq_length+1], dtype=torch.int64, device=self.device)
        return x, y

# Create pretraining model with default config
model_config = LlamaConfig()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = LlamaForPretraining(model_config).to(device)

# print the model size
print(f"Model parameters size: {sum(p.numel() for p in model.parameters()) / 1024**2:.2f} M")
print(f"Model buffers size: {sum(p.numel() for p in model.buffers()) / 1024**2:.2f} M")

# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 8
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")

# DataLoader, optimizer, scheduler, and loss function
model.train()
dataloader = torch.utils.data.DataLoader(
    PretrainingDataset(dataset, tokenizer, seq_length, device),
    batch_size=batch_size
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01
)
num_training_steps = len(dataloader) * epochs
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=num_warmup_steps
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - num_warmup_steps,
    eta_min=0
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps]
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

# look for last checkpoint
if os.path.exists("llama_pretraining_checkpoint.pth"):
    checkpoint = torch.load("llama_pretraining_checkpoint.pth")
    begin_epoch = checkpoint["epoch"]
    begin_batch = checkpoint["batch"]
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    del checkpoint
    print(f"Resuming training from epoch {begin_epoch} and batch {begin_batch}")
else:
    begin_epoch = 0
    begin_batch = 0

# start training
for epoch in range(begin_epoch, epochs):
    dataloader = torch.utils.data.DataLoader(
        PretrainingDataset(
            dataset.skip(begin_batch * batch_size),
            tokenizer,
            seq_length,
            device,
        ),
        batch_size=batch_size
    )
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        if (begin_batch + batch_id) % 1000 == 0:
            # checkpoint the model and optimizer state
            torch.save({
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "batch": batch_id + begin_batch,
            }, "llama_pretraining_checkpoint.pth")
        # get batched data
        input_ids, target_ids = batch
        # create attention mask: causal mask + padding mask
        attn_mask = create_causal_mask(input_ids.shape[1], device) + \
                    create_padding_mask(input_ids, PAD_TOKEN_ID, device)
        # extract output from model
        logits = model(input_ids, attn_mask)
        # compute loss: cross-entropy between logits and target, ignoring padding tokens
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # backward with loss and apply gradient clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
        pbar.update(1)
    begin_batch = 0  # reset for next epoch
    pbar.close()

# Save the model
torch.save(model.state_dict(), "llama_pretraining_model.pth")
torch.save(model.base_model.state_dict(), "llama_model.pth")
```
Note that this is a simplified training recipe. A professional model training process would use a much larger dataset on a much larger model. For example, Llama 2 models with 7B-70B parameters are trained on 2 trillion tokens. The hyperparameters for training, such as the learning rate, would be tuned before they are finalized for actual training.
Moreover, it would be more efficient to train the model with shorter sequence lengths first and expand to longer ones later. It is also common practice to train the model on lower-quality data initially and switch to higher-quality data toward the end to make the model more capable. None of these techniques is implemented in the code above. You can refer to the previous post for techniques to improve the training.
Further Reading
Below are some further reading materials that you may find useful:
- Liu et al (2024) Understanding LLMs: A Comprehensive Overview from Training to Inference
- Grattafiori et al (2024) The Llama 3 Herd of Models
- Groeneveld et al (2024) OLMo: Accelerating the Science of Language Models
- Sebastian Raschka, Build a Large Language Model (From Scratch). Manning Publications 2024
Summary
In this article, you learned how to pretrain a Llama model on a single GPU. Specifically, you learned how to:
- Train a tokenizer with special tokens for next-token prediction
- Prepare the training data for pretraining
- Run the pretraining on a single GPU with checkpointing





