Today’s large language models are a simplified form of the transformer model. They are called decoder-only models because their role is similar to that of the decoder in the full transformer: given a partial sequence as input, they generate the output sequence. Architecturally, however, they are closer to the encoder part of the transformer model, since each layer has self-attention but no cross-attention. In this post, you will build a decoder-only transformer model for text generation with the same architecture as Meta’s Llama-2 or Llama-3. Specifically, you will learn:
- How to build a decoder-only model
- The variations in the architecture design of the decoder-only model
- How to train the model
Let’s get started.

Building a Decoder-Only Transformer Model for Text Generation
Photo by Jay. Some rights reserved.
Overview
This post is divided into five parts; they are:
- From a Full Transformer to a Decoder-Only Model
- Building a Decoder-Only Model
- Data Preparation for Self-Supervised Learning
- Training the Model
- Extensions
From a Full Transformer to a Decoder-Only Model
The transformer model originated as a sequence-to-sequence (seq2seq) model that converts an input sequence into a context vector, which is then used to generate a new sequence. In this architecture, the encoder part is responsible for converting the input sequence into a context vector, while the decoder part generates the new sequence from this context vector.
Instead of using the context vector to generate an entirely new sequence, can we project it into a vector of logits, one score for each token in the vocabulary, that a softmax turns into probabilities? This way, given a partial sequence as input, the model can predict the most likely next token. By iteratively feeding the extended sequence back into the model, we can generate coherent text one token at a time, much like the auto-complete function in text editors. This simplified architecture, which focuses solely on predicting the next token, is called a decoder-only model.
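To make the feed-back loop concrete, here is a minimal sketch that uses a bigram count table in place of a neural network. The corpus and the helper names `train_bigram` and `generate` are made up for illustration; only the loop structure (predict a token, append it, repeat) mirrors what a decoder-only model does.

```python
from collections import Counter, defaultdict

def train_bigram(tokens):
    """Count how often each token follows each other token."""
    table = defaultdict(Counter)
    for prev, nxt in zip(tokens, tokens[1:]):
        table[prev][nxt] += 1
    return table

def generate(table, prompt, max_new_tokens=5):
    """Repeatedly append the most likely next token to the sequence."""
    seq = list(prompt)
    for _ in range(max_new_tokens):
        candidates = table.get(seq[-1])
        if not candidates:          # no known continuation: stop early
            break
        next_token = candidates.most_common(1)[0][0]
        seq.append(next_token)      # the longer sequence is the next input
    return seq

# Hypothetical toy corpus, split into word-level tokens
corpus = "the cat sat on the mat and the cat sat down".split()
table = train_bigram(corpus)
print(generate(table, ["the"], max_new_tokens=2))  # ['the', 'cat', 'sat']
```

A real model conditions on the whole partial sequence rather than only the last token, and usually samples instead of always picking the top candidate.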
Building a Decoder-Only Model
A decoder-only model has a simpler architecture than a full transformer model. Starting with the full transformer architecture discussed in the previous post, you can create a decoder-only model by removing the encoder component entirely and adapting the decoder for standalone operation.
```python
import torch.nn as nn

# GQA, SwiGLU, and RotaryPositionalEncoding are defined in the complete
# listing at the end of this post.

class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads, dropout=0.1):
        super().__init__()
        self.self_attn = GQA(hidden_dim, num_heads, num_kv_heads, dropout)
        self.mlp = SwiGLU(hidden_dim, 4 * hidden_dim)
        self.norm1 = nn.RMSNorm(hidden_dim)
        self.norm2 = nn.RMSNorm(hidden_dim)

    def forward(self, x, mask=None, rope=None):
        # self-attention sublayer
        out = self.norm1(x)
        out = self.self_attn(out, out, out, mask, rope)
        x = out + x
        # MLP sublayer
        out = self.norm2(x)
        out = self.mlp(out)
        return out + x


class TextGenerationModel(nn.Module):
    def __init__(self, num_layers, num_heads, num_kv_heads, hidden_dim,
                 max_seq_len, vocab_size, dropout=0.1):
        super().__init__()
        self.rope = RotaryPositionalEncoding(hidden_dim // num_heads, max_seq_len)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.decoders = nn.ModuleList([
            DecoderLayer(hidden_dim, num_heads, num_kv_heads, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, ids, mask=None):
        x = self.embedding(ids)
        for decoder in self.decoders:
            x = decoder(x, mask, self.rope)
        x = self.norm(x)
        return self.out(x)
```
The implementation reuses a significant portion of the code from the full transformer model. The DecoderLayer class shares the same structure as the EncoderLayer from the previous implementation. The TextGenerationModel class features a simplified forward() method since it no longer needs to handle encoder-decoder interactions. It simply converts input token IDs into embeddings, processes them through the stacked decoder layers, and projects the output into logits, one score for each token in the vocabulary.
The model is illustrated in the picture below. It shares the same architectural design as the Llama-2/Llama-3 models proposed by Meta:

Decoder-Only Model Following the Architecture of Llama-2/Llama-3
Data Preparation for Self-Supervised Learning
Our goal is to create a model that can generate coherent paragraphs of text from a given prompt, even if that prompt is just a single word. To train such a model effectively, we need to consider our training approach and data requirements carefully.
The training technique we’ll use is called self-supervised learning. Unlike traditional supervised learning, which requires manually labeled data, self-supervised learning leverages the inherent structure of the text itself. When we input a sequence of text, the model learns to predict the next token, and the actual next token in the text serves as the ground truth. This eliminates the need for manual labeling.
The size of the training dataset is crucial. With a vocabulary size of $N$ tokens and a dataset containing $M$ tokens, each vocabulary entry appears approximately $M/N$ times on average. To ensure the model learns meaningful representations for all tokens, this ratio needs to be sufficiently large.
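As a rough illustration of this ratio, the sketch below computes $M/N$ for a tiny, made-up list of token IDs and also lists the tokens that fall below a chosen count. The function names and the data are hypothetical; the point is only that an average can hide individual tokens that are seen very rarely.

```python
from collections import Counter

def average_occurrences(token_ids, vocab_size):
    """Return M / N: number of tokens in the dataset over vocabulary size."""
    return len(token_ids) / vocab_size

def rare_tokens(token_ids, min_count):
    """Return the token IDs that appear fewer than min_count times."""
    counts = Counter(token_ids)
    return sorted(t for t, c in counts.items() if c < min_count)

# Hypothetical tokenized dataset: M = 8 tokens over a vocabulary of N = 4
token_ids = [0, 1, 2, 1, 0, 1, 3, 1]
print(average_occurrences(token_ids, vocab_size=4))  # 2.0
print(rare_tokens(token_ids, min_count=2))           # [2, 3]
```

Here the average is 2 occurrences per token, yet tokens 2 and 3 each appear only once.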
In this post, you will download some novels from Project Gutenberg and use them as the dataset to train the model.
```python
import os
import requests

# Download novels from Project Gutenberg
DATASOURCE = {
    "moby_dick": "https://www.gutenberg.org/ebooks/2701.txt.utf-8",
    "frankenstein": "https://www.gutenberg.org/ebooks/84.txt.utf-8",
    "dracula": "https://www.gutenberg.org/ebooks/345.txt.utf-8",
    "little_women": "https://www.gutenberg.org/ebooks/37106.txt.utf-8",
    "pride_and_prejudice": "https://www.gutenberg.org/ebooks/1342.txt.utf-8",
    "alice_in_wonderland": "https://www.gutenberg.org/ebooks/11.txt.utf-8",
    "crime_and_punishment": "https://www.gutenberg.org/ebooks/2554.txt.utf-8",
    "tom_sawyer": "https://www.gutenberg.org/ebooks/74.txt.utf-8",
    "tale_of_two_cities": "https://www.gutenberg.org/ebooks/98.txt.utf-8",
    "sherlock_holmes": "https://www.gutenberg.org/ebooks/1661.txt.utf-8",
    "war_and_peace": "https://www.gutenberg.org/ebooks/2600.txt.utf-8",
}
for filename, url in DATASOURCE.items():
    if not os.path.exists(f"{filename}.txt"):
        response = requests.get(url)
        response.raise_for_status()  # fail early instead of saving an error page
        with open(f"{filename}.txt", "wb") as f:
            f.write(response.content)
```
These public domain novels, written by various authors across different genres, provide a diverse dataset that will help our model learn a wide range of vocabulary and writing styles.
With these novels downloaded, you can extract the main content of each as a string and keep these strings in a list:
```python
# Read and preprocess the text
def preprocess_gutenberg(filename):
    with open(filename, "r", encoding="utf-8") as f:
        text = f.read()

    # Find the start and end of the actual content
    start = text.find("*** START OF THE PROJECT GUTENBERG EBOOK")
    start = text.find("\n", start) + 1
    end = text.find("*** END OF THE PROJECT GUTENBERG EBOOK")

    # Extract the main content
    text = text[start:end].strip()

    # Basic preprocessing
    # Remove multiple newlines and spaces
    text = "\n".join(line.strip() for line in text.split("\n") if line.strip())
    return text

def get_dataset_text():
    all_text = []
    for filename in DATASOURCE:
        text = preprocess_gutenberg(f"{filename}.txt")
        all_text.append(text)
    return all_text
```
The next step is to create a tokenizer. You can build a naive tokenizer by splitting the text into words. You can also use the Byte-Pair Encoding (BPE) algorithm to create a more sophisticated tokenizer, as follows:
```python
import tokenizers

# Tokenization with BPE
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())

# Configure the pre-tokenizer to add a space at the beginning of the sentence
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)

# Configure the decoder so that the word boundary symbol is removed
tokenizer.decoder = tokenizers.decoders.ByteLevel()

# Train BPE
VOCAB_SIZE = 10000
trainer = tokenizers.trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=["[pad]", "[eos]"],
    show_progress=True
)
text = get_dataset_text()
tokenizer.train_from_iterator(text, trainer=trainer)
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[pad]"), pad_token="[pad]")

# Save the trained tokenizer
tokenizer.save("gutenberg_tokenizer.json", pretty=True)
```
This uses the tokenizers library to train a BPE tokenizer. You call get_dataset_text() to get the text of all the novels and then train the tokenizer on it. The tokenizer also needs two special tokens: [pad] and [eos]. Most importantly, the [eos] token indicates the end of the sequence. If your model generates this token, you know you can stop the generation.
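To give a feel for what BPE training does, here is a much simplified sketch of its core step: find the most frequent adjacent pair of symbols and merge it into a new symbol. This is not how the tokenizers library is implemented (it works at byte level, with pre-tokenization and many optimizations); the helper names and the three-word corpus are assumptions for illustration.

```python
from collections import Counter

def most_frequent_pair(words):
    """words is a list of symbol lists; return the most common adjacent pair."""
    pairs = Counter()
    for symbols in words:
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += 1
    return pairs.most_common(1)[0][0] if pairs else None

def merge_pair(words, pair):
    """Replace every occurrence of the pair with a single merged symbol."""
    merged = []
    for symbols in words:
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged.append(out)
    return merged

# Hypothetical corpus, starting from single characters
words = [list("low"), list("lower"), list("lowest")]
for _ in range(2):                      # two merge steps
    words = merge_pair(words, most_frequent_pair(words))
print(words)  # [['low'], ['low', 'e', 'r'], ['low', 'e', 's', 't']]
```

After two merges, the frequent fragment "low" has become a single token, while the rarer suffixes remain split into characters. Training to a vocabulary size of 10,000 repeats this kind of merge until the vocabulary is full.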
Training the Model
With the tokenizer and the dataset ready, you can now train the model.
First, you need to create a Dataset object that can be used to train the model. PyTorch provides a framework for this.
```python
import torch

class GutenbergDataset(torch.utils.data.Dataset):
    def __init__(self, text, tokenizer, seq_len=512):
        self.seq_len = seq_len
        # Encode the entire text
        self.encoded = tokenizer.encode(text).ids

    def __len__(self):
        return len(self.encoded) - self.seq_len

    def __getitem__(self, idx):
        chunk = self.encoded[idx:idx + self.seq_len + 1]  # +1 for target
        x = torch.tensor(chunk[:-1])
        y = torch.tensor(chunk[1:])
        return x, y

BATCH_SIZE = 32
text = "\n".join(get_dataset_text())
# seq_len should match the maximum sequence length of the model (512 in this post)
dataset = GutenbergDataset(text, tokenizer, seq_len=512)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
```
This Dataset object is used to create a DataLoader object that can be used to train the model. The DataLoader object will automatically batch the data and shuffle it.
The Dataset object produces a pair of input and output sequences in the __getitem__() method. They are of the same length but offset by one token. When the input sequence is passed to the model, the model generates the next token for each position in the sequence. Hence, the ground truth output is from the same source, offset by one. This is how you can set up the self-supervised training.
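The offset-by-one pairing can be seen on a short list of made-up token IDs. The sketch below mirrors the slicing in `__getitem__()` with plain Python lists; `make_pairs` is a hypothetical helper, not part of the training code.

```python
def make_pairs(encoded, seq_len):
    """Return every (input, target) window; the target is shifted by one token."""
    pairs = []
    for idx in range(len(encoded) - seq_len):
        chunk = encoded[idx:idx + seq_len + 1]   # +1 so the target has seq_len tokens
        pairs.append((chunk[:-1], chunk[1:]))
    return pairs

encoded = [10, 11, 12, 13, 14, 15]   # hypothetical token IDs
pairs = make_pairs(encoded, seq_len=3)
for x, y in pairs:
    print(x, "->", y)
# [10, 11, 12] -> [11, 12, 13]
# [11, 12, 13] -> [12, 13, 14]
# [12, 13, 14] -> [13, 14, 15]
```

Because consecutive windows start one token apart, a text of about a million tokens yields about a million training samples, which is why a single epoch already contains many steps.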
Now you can create the model and train it. You can use this code to create a very large model. However, if you do not expect the model to be very powerful, you can design a smaller one. Let’s make one with:
- 8 layers
- Attention uses 8 query heads and 4 key-value heads
- Hidden dimension is 768
- Maximum sequence length is 512
- Set dropout at attention to 0.1
- Train with AdamW optimizer with initial learning rate 0.0005
- Learning rate scheduler with 2000 steps of warmup and then cosine annealing
- Train for 2 epochs with batch size 32, clip norm to 6.0
Everything above is typical. Training a decoder-only model typically requires a very large dataset, and the number of epochs may be as few as 1. It is the number of steps trained that matters. The training will use a linear warmup to gradually increase the learning rate at the beginning, which can reduce the effect of how the model is initialized. Then the cosine annealing will gradually decrease the learning rate such that at the end of the training, when the model is almost converged, it keeps the learning rate at a very small value to stabilize the result.
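The shape of this schedule can be sketched as a closed-form function of the step number. This is an approximation for illustration, not the code used for training: PyTorch's schedulers update the learning rate step by step and may differ slightly around the switch-over point, and the `total_steps` value below is a made-up number.

```python
import math

def lr_at(step, base_lr=0.0005, warmup_steps=2000, total_steps=100_000,
          start_factor=0.01, eta_min=0.0):
    """Learning rate at a given step: linear warmup, then cosine annealing."""
    if step < warmup_steps:
        # Scale factor grows linearly from start_factor to 1.0
        factor = start_factor + (1.0 - start_factor) * step / warmup_steps
        return base_lr * factor
    # Fraction of the cosine phase completed, from 0.0 to 1.0
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return eta_min + (base_lr - eta_min) * 0.5 * (1.0 + math.cos(math.pi * progress))

for step in (0, 1000, 2000, 51_000, 100_000):
    print(step, f"{lr_at(step):.6f}")
```

The learning rate starts at 1% of its peak, reaches the peak of 0.0005 at step 2000, falls to half of the peak midway through the cosine phase, and approaches zero at the last step.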
The code for model creation and training is as follows:
```python
import torch
import torch.nn as nn
import torch.optim as optim

# create_causal_mask() is defined in the complete listing at the end of this post

# Training configuration
model_config = {
    "num_layers": 8,
    "num_heads": 8,
    "num_kv_heads": 4,
    "hidden_dim": 768,
    "max_seq_len": 512,
    "vocab_size": len(tokenizer.get_vocab()),
    "dropout": 0.1,
}

# Initialize model, optimizer, etc.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextGenerationModel(**model_config).to(device)

# Create dataset and dataloader
BATCH_SIZE = 32
text = "\n".join(get_dataset_text())
dataset = GutenbergDataset(text, tokenizer, seq_len=model_config["max_seq_len"])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
N_EPOCHS = 2
LR = 0.0005
WARMUP_STEPS = 2000
CLIP_NORM = 6.0

optimizer = optim.AdamW(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[pad]"))

# Learning rate scheduling
warmup_scheduler = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, end_factor=1.0, total_iters=WARMUP_STEPS)
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=N_EPOCHS * len(dataloader) - WARMUP_STEPS, eta_min=0)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[WARMUP_STEPS])

print(f"Training for {N_EPOCHS} epochs with {len(dataloader)} steps per epoch")
best_loss = float('inf')

for epoch in range(N_EPOCHS):
    model.train()
    epoch_loss = 0

    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        # Create causal mask
        mask = create_causal_mask(x.shape[1], device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(x, mask.unsqueeze(0))

        # Compute loss
        loss = loss_fn(outputs.view(-1, outputs.shape[-1]), y.view(-1))

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), CLIP_NORM, error_if_nonfinite=True
        )
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss: {avg_loss:.4f}")

    # Save checkpoint if loss improved
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), "textgen_model.pth")
```
In the training loop, you did the usual forward and backward passes. The model will be saved whenever the loss is improved. For simplicity, no evaluation is implemented. You should evaluate the model regularly (not necessarily after every epoch) to monitor progress.
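The training loop builds a causal mask with create_causal_mask() so that each position can only attend to itself and earlier positions. The sketch below reproduces the same pattern with plain Python lists and shows its effect on one row of attention scores; the scores and the helper names are made up for illustration.

```python
import math

def causal_mask(seq_len):
    """0.0 on and below the diagonal, -inf above it (future positions)."""
    neg_inf = float("-inf")
    return [[neg_inf if col > row else 0.0 for col in range(seq_len)]
            for row in range(seq_len)]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(3)
for row in mask:
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]

# Hypothetical raw attention scores of position 0 against positions 0, 1, 2
scores = [1.0, 2.0, 3.0]
masked = [s + m for s, m in zip(scores, mask[0])]
print(softmax(masked))  # [1.0, 0.0, 0.0]
```

Adding the mask before the softmax drives the attention weights of future positions to zero, which is what allows all positions of a training sequence to be predicted in one forward pass without leaking the answers.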
Due to the large vocabulary size and sequence length, the training process is computationally intensive. Even on a high-end RTX 4090 GPU, each epoch takes approximately 10 hours to complete.
Once the training is done, you can load the model and generate text:
```python
import torch
import torch.nn.functional as F

# Generation function
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    model.eval()
    device = next(model.parameters()).device

    # Encode the prompt
    input_ids = torch.tensor(tokenizer.encode(prompt).ids).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # Use the same causal mask as in training so the model sees
            # the sequence the way it was trained
            mask = create_causal_mask(input_ids.shape[1], device)
            # Get model predictions for the next token as the last element of the output
            outputs = model(input_ids, mask.unsqueeze(0))
            next_token_logits = outputs[:, -1, :] / temperature
            # Sample from the distribution
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Append to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop if we predict the end token
            if next_token[0].item() == tokenizer.token_to_id("[eos]"):
                break

    return tokenizer.decode(input_ids[0].tolist())

# Test the model with some prompts
test_prompts = [
    "Once upon a time,",
    "We the people of the",
    "In the beginning was the",
]

print("\nGenerating sample texts:")
for prompt in test_prompts:
    generated = generate_text(model, tokenizer, prompt)
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated}")
    print("-" * 80)
```
The generate_text() function uses the model to generate text. It expects a partial sentence as the input prompt, and each iteration of the for-loop uses the model to generate the next token. The generation algorithm samples from the predicted probability distribution rather than always choosing the most likely token, which allows the model to generate more creative text. The temperature parameter controls the level of creativity: the logits are divided by the temperature before the softmax, so a lower temperature concentrates the probability on the most likely tokens.
The output from the model is a vector of logits, and the sampling process generates a vector of token IDs. This vector will be converted back to a string by the tokenizer.
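The effect of the temperature can be seen without a trained model. The following sketch applies a temperature-scaled softmax to three made-up logits and draws one token ID from the result; the function names are hypothetical and only the Python standard library is used.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Divide the logits by the temperature, then apply a softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)                              # subtract max for stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits, temperature=0.7, rng=random):
    """Draw one token ID according to the temperature-scaled probabilities."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]   # hypothetical logits for a 3-token vocabulary
for t in (0.2, 0.7, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])

print(sample_token(logits, temperature=0.7))
```

At a low temperature nearly all the probability goes to the token with the largest logit, so sampling behaves almost like greedy decoding. At a higher temperature the distribution flattens and less likely tokens are chosen more often.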
If you run this code, you may see the following output:
```
Generating sample texts:

Prompt: Once upon a time,
Generated: Once upon a time, and Tom rested with her, and they talked of home, and the friends there, and the comfortable beds and, above all, the light! Becky cried, and Tom tried to think of some way of comforting her, but all his encouragements were grown thread-bare with use, and sounded like sarcasms. Fatigue bore so heavily upon Becky that she drowsed off to sleep. Tom was grateful. He
--------------------------------------------------------------------------------

Prompt: We the people of the
Generated: We the people of the French near: going a terrible danger there is written in the case of the perilous event that we sprang from the apes. The soldiers who, afraid of being able to attack the French themselves. It was plain that the Russian nest was ruined and destroyed, but in place of the Russian order of life that had been destroyed, Pierre unconsciously felt that a quite different, firm, French order had been established over this ruined nest.
--------------------------------------------------------------------------------

Prompt: In the beginning was the
Generated: In the beginning was the first rummation of my beloved sister.” “One moment,” said Holmes, “are you sure about the Emperorer clean, and not you, and I need not trouble so far as I had a machine than Jane well enough, and of being with the truth of a woman asylum a man of not oftenionable, with any other, and the consequences of such a step must be waived for the fête as
--------------------------------------------------------------------------------
```
While the generated text shows some coherence and understanding of language patterns, it’s not perfect. However, considering the relatively small size of our model and limited training data, these results are encouraging.
For completeness, below is the full code of the model and the training:
```python
import os
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tokenizers
import tqdm

# Note: nn.RMSNorm requires PyTorch 2.4 or later, and the enable_gqa argument
# of F.scaled_dot_product_attention requires PyTorch 2.5 or later.

# Download novels from Project Gutenberg
DATASOURCE = {
    "moby_dick": "https://www.gutenberg.org/ebooks/2701.txt.utf-8",
    "frankenstein": "https://www.gutenberg.org/ebooks/84.txt.utf-8",
    "dracula": "https://www.gutenberg.org/ebooks/345.txt.utf-8",
    "little_women": "https://www.gutenberg.org/ebooks/37106.txt.utf-8",
    "pride_and_prejudice": "https://www.gutenberg.org/ebooks/1342.txt.utf-8",
    "alice_in_wonderland": "https://www.gutenberg.org/ebooks/11.txt.utf-8",
    "crime_and_punishment": "https://www.gutenberg.org/ebooks/2554.txt.utf-8",
    "tom_sawyer": "https://www.gutenberg.org/ebooks/74.txt.utf-8",
    "tale_of_two_cities": "https://www.gutenberg.org/ebooks/98.txt.utf-8",
    "sherlock_holmes": "https://www.gutenberg.org/ebooks/1661.txt.utf-8",
    "war_and_peace": "https://www.gutenberg.org/ebooks/2600.txt.utf-8",
}
for filename, url in DATASOURCE.items():
    if not os.path.exists(f"{filename}.txt"):
        response = requests.get(url)
        response.raise_for_status()  # fail early instead of saving an error page
        with open(f"{filename}.txt", "wb") as f:
            f.write(response.content)

# Read and preprocess the text
def preprocess_gutenberg(filename):
    with open(filename, "r", encoding="utf-8") as f:
        text = f.read()

    # Find the start and end of the actual content
    start = text.find("*** START OF THE PROJECT GUTENBERG EBOOK")
    start = text.find("\n", start) + 1
    end = text.find("*** END OF THE PROJECT GUTENBERG EBOOK")

    # Extract the main content
    text = text[start:end].strip()

    # Basic preprocessing
    # Remove multiple newlines and spaces
    text = "\n".join(line.strip() for line in text.split("\n") if line.strip())
    return text

def get_dataset_text():
    all_text = []
    for filename in DATASOURCE:
        text = preprocess_gutenberg(f"{filename}.txt")
        all_text.append(text)
    return all_text

# Tokenization with BPE
if os.path.exists("gutenberg_tokenizer.json"):
    tokenizer = tokenizers.Tokenizer.from_file("gutenberg_tokenizer.json")
else:
    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
    # Configure the pre-tokenizer to add a space at the beginning of the sentence
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
    # Configure the decoder so that the word boundary symbol is removed
    tokenizer.decoder = tokenizers.decoders.ByteLevel()
    # Train BPE
    VOCAB_SIZE = 10000
    trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[pad]", "[eos]"],
        show_progress=True
    )
    text = get_dataset_text()
    tokenizer.train_from_iterator(text, trainer=trainer)
    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[pad]"), pad_token="[pad]")
    # Save the trained tokenizer
    tokenizer.save("gutenberg_tokenizer.json", pretty=True)

# Create PyTorch dataset
class GutenbergDataset(torch.utils.data.Dataset):
    def __init__(self, text, tokenizer, seq_len=512):
        self.seq_len = seq_len
        # Encode the entire text
        self.encoded = tokenizer.encode(text).ids

    def __len__(self):
        return len(self.encoded) - self.seq_len

    def __getitem__(self, idx):
        chunk = self.encoded[idx:idx + self.seq_len + 1]  # +1 for target
        x = torch.tensor(chunk[:-1])
        y = torch.tensor(chunk[1:])
        return x, y

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        N = 10000
        inv_freq = 1. / (N ** (torch.arange(0, dim, 2).float() / dim))
        position = torch.arange(max_seq_len).float()
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        sinusoid_inp = torch.outer(position, inv_freq)
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x, seq_len=None):
        # x is expected in shape (batch, seq_len, num_heads, head_dim)
        if seq_len is None:
            seq_len = x.size(1)
        cos = self.cos[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin[:seq_len].view(1, seq_len, 1, -1)
        return apply_rotary_pos_emb(x, cos, sin)

class SwiGLU(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, intermediate_dim)
        self.up = nn.Linear(hidden_dim, intermediate_dim)
        self.down = nn.Linear(intermediate_dim, hidden_dim)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.gate(x)) * self.up(x)
        x = self.down(x)
        return x

class GQA(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_dim // num_heads
        self.num_groups = num_heads // self.num_kv_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        # K and V have only num_kv_heads heads, so their projections are narrower
        self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, q, k, v, mask=None, rope=None):
        q_batch_size, q_seq_len, hidden_dim = q.shape
        k_batch_size, k_seq_len, hidden_dim = k.shape
        v_batch_size, v_seq_len, hidden_dim = v.shape

        # projection into shape (batch, seq_len, heads, head_dim)
        q = self.q_proj(q).view(q_batch_size, q_seq_len, -1, self.head_dim)
        k = self.k_proj(k).view(k_batch_size, k_seq_len, -1, self.head_dim)
        v = self.v_proj(v).view(v_batch_size, v_seq_len, -1, self.head_dim)

        # apply rotary positional encoding while seq_len is still dimension 1
        if rope is not None:
            q = rope(q)
            k = rope(k)

        # move heads to dimension 1, then compute grouped query attention
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        output = F.scaled_dot_product_attention(q, k, v,
                                                attn_mask=mask,
                                                # the functional API does not check
                                                # model.eval(), so disable dropout here
                                                dropout_p=self.dropout if self.training else 0.0,
                                                enable_gqa=True)
        output = output.transpose(1, 2).reshape(q_batch_size, q_seq_len, hidden_dim).contiguous()
        output = self.out_proj(output)
        return output

class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads, dropout=0.1):
        super().__init__()
        self.self_attn = GQA(hidden_dim, num_heads, num_kv_heads, dropout)
        self.mlp = SwiGLU(hidden_dim, 4 * hidden_dim)
        self.norm1 = nn.RMSNorm(hidden_dim)
        self.norm2 = nn.RMSNorm(hidden_dim)

    def forward(self, x, mask=None, rope=None):
        # self-attention sublayer
        out = self.norm1(x)
        out = self.self_attn(out, out, out, mask, rope)
        x = out + x
        # MLP sublayer
        out = self.norm2(x)
        out = self.mlp(out)
        return out + x

class TextGenerationModel(nn.Module):
    def __init__(self, num_layers, num_heads, num_kv_heads, hidden_dim,
                 max_seq_len, vocab_size, dropout=0.1):
        super().__init__()
        self.rope = RotaryPositionalEncoding(hidden_dim // num_heads, max_seq_len)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.decoders = nn.ModuleList([
            DecoderLayer(hidden_dim, num_heads, num_kv_heads, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, ids, mask=None):
        x = self.embedding(ids)
        for decoder in self.decoders:
            x = decoder(x, mask, self.rope)
        x = self.norm(x)
        return self.out(x)

def create_causal_mask(seq_len, device):
    """Create a causal mask for autoregressive attention."""
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=device), diagonal=1)
    return mask

# Training configuration
model_config = {
    "num_layers": 8,
    "num_heads": 8,
    "num_kv_heads": 4,
    "hidden_dim": 768,
    "max_seq_len": 512,
    "vocab_size": len(tokenizer.get_vocab()),
    "dropout": 0.1,
}

# Initialize model, optimizer, etc.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextGenerationModel(**model_config).to(device)

# Create dataset and dataloader
BATCH_SIZE = 32
text = "\n".join(get_dataset_text())
dataset = GutenbergDataset(text, tokenizer, seq_len=model_config["max_seq_len"])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
if os.path.exists("textgen_model.pth"):
    model.load_state_dict(torch.load("textgen_model.pth", map_location=device))
else:
    N_EPOCHS = 2
    LR = 0.0005
    WARMUP_STEPS = 2000
    CLIP_NORM = 6.0

    optimizer = optim.AdamW(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[pad]"))

    # Learning rate scheduling
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, end_factor=1.0, total_iters=WARMUP_STEPS)
    cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=N_EPOCHS * len(dataloader) - WARMUP_STEPS, eta_min=0)
    scheduler = optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[WARMUP_STEPS])

    print(f"Training for {N_EPOCHS} epochs with {len(dataloader)} steps per epoch")
    best_loss = float('inf')

    for epoch in range(N_EPOCHS):
        model.train()
        epoch_loss = 0

        progress_bar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{N_EPOCHS}")
        for x, y in progress_bar:
            x = x.to(device)
            y = y.to(device)

            # Create causal mask
            mask = create_causal_mask(x.shape[1], device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(x, mask.unsqueeze(0))

            # Compute loss
            loss = loss_fn(outputs.view(-1, outputs.shape[-1]), y.view(-1))

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), CLIP_NORM, error_if_nonfinite=True
            )
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()

            # Show loss in tqdm
            progress_bar.set_postfix(loss=loss.item())

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss: {avg_loss:.4f}")

        # Save checkpoint if loss improved
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), "textgen_model.pth")

# Generation function
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    model.eval()
    device = next(model.parameters()).device

    # Encode the prompt
    input_ids = torch.tensor(tokenizer.encode(prompt).ids).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # Use the same causal mask as in training so the model sees
            # the sequence the way it was trained
            mask = create_causal_mask(input_ids.shape[1], device)
            # Get model predictions for the next token as the last element of the output
            outputs = model(input_ids, mask.unsqueeze(0))
            next_token_logits = outputs[:, -1, :] / temperature
            # Sample from the distribution
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Append to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop if we predict the end token
            if next_token[0].item() == tokenizer.token_to_id("[eos]"):
                break

    return tokenizer.decode(input_ids[0].tolist())

# Test the model with some prompts
test_prompts = [
    "Once upon a time,",
    "We the people of the",
    "In the beginning was the",
]

print("\nGenerating sample texts:")
for prompt in test_prompts:
    generated = generate_text(model, tokenizer, prompt)
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated}")
    print("-" * 80)
```
Extensions
While we’ve successfully implemented a basic decoder-only model, modern large language models (LLMs) are significantly more sophisticated. Here are key areas for improvement:
- Scale and Architecture: Modern LLMs use many more layers and larger hidden dimensions. They also incorporate advanced techniques beyond what we’ve implemented here, such as mixture of experts.
- Dataset Size and Diversity: Our current dataset, consisting of a few megabytes of novel text, is tiny compared to the terabyte-scale datasets used in modern LLMs. Production models are trained on diverse content types across multiple languages.
- Training Pipeline: What we’ve implemented is called “pretraining” in LLM development. Production models typically undergo additional fine-tuning phases for specific tasks, such as question-answering or instruction-following, using specialized datasets and tailored training objectives.
- Training Infrastructure: Training larger models requires sophisticated techniques for distributed training across multiple GPUs, gradient accumulation, and other optimizations that would require significant modifications to our training loop.
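As a small illustration of the gradient accumulation mentioned above, the sketch below checks, on a one-parameter linear model with made-up data, that averaging the gradients of equally sized micro-batches gives the same gradient as one large batch. It uses plain Python rather than PyTorch, and the helper `grad_mse` is hypothetical.

```python
import math

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Hypothetical data and current weight
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Gradient computed on the full batch of 4 samples
full_batch_grad = grad_mse(w, xs, ys)

# Same gradient accumulated from 2 micro-batches of 2 samples each
micro_batch_size = 2
num_micro_batches = len(xs) // micro_batch_size
accumulated_grad = 0.0
for i in range(0, len(xs), micro_batch_size):
    g = grad_mse(w, xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])
    accumulated_grad += g / num_micro_batches   # scale so the sum is an average

print(full_batch_grad, accumulated_grad)
print(math.isclose(full_batch_grad, accumulated_grad))  # True
```

In a PyTorch training loop, the corresponding change would be to divide each micro-batch loss by the number of accumulation steps and to call the optimizer step only once per group of micro-batches.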
Further Reading
Below are some links that you may find useful:
- LLaMA: Open and Efficient Foundation Language Models
- BloombergGPT: A Large Language Model for Finance
- SmolLM2: When Smol Goes Big — Data Centric Training of a Small Language Model
- What is the difference between pre-training, fine-tuning, and instruct-tuning exactly?
Summary
In this post, you’ve walked through the process of building a decoder-only transformer model for text generation. In particular, you’ve learned:
- How to simplify a full transformer architecture into a decoder-only model
- How to implement self-supervised learning for text generation tasks
- How to create a text generation pipeline using the trained model
This decoder-only architecture serves as the foundation for many modern large language models, making it a crucial concept to understand in the field of natural language processing.






