The attention mechanism, introduced by Bahdanau et al. in 2014, significantly improved sequence-to-sequence (seq2seq) models. In this post, you’ll learn how to build and train a seq2seq model with attention for language translation, focusing on:
- Why attention mechanisms are essential
- How to implement attention in a seq2seq model
Kick-start your project with my book Building Transformer Models From Scratch with PyTorch. It provides self-study tutorials with working code.
Let’s get started.

Building a Seq2Seq Model with Attention for Language Translation
Photo by Esther T. Some rights reserved.
Overview
This post is divided into four parts; they are:
- Why Attention Matters: Limitations of Basic Seq2Seq Models
- Implementing Seq2Seq Model with Attention
- Training and Evaluating the Model
- Using the Model
Why Attention Matters: Limitations of Basic Seq2Seq Models
Traditional seq2seq models use an encoder-decoder architecture where the encoder compresses the input sequence into a single context vector, which the decoder then uses to generate the output sequence. This approach has a critical limitation: the decoder must rely on this single context vector regardless of the output sequence length.
This becomes problematic with longer sequences as the model struggles to retain important details from earlier parts of the sequence. Consider English to French translation: The decoder uses the context vector as its initial state to generate the first token, then uses each previous output as input for subsequent tokens. As the hidden state updates, the decoder gradually loses information from the original context vector.
Attention mechanisms solve this by:
- Giving the decoder access to all encoder hidden states during generation
- Allowing focus on relevant input parts for each output token
- Eliminating reliance on a single context vector
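To see what "focus" means concretely, here is a toy sketch (not the post's model) that scores three encoder hidden states against one decoder query and forms a weighted context vector. The dot-product scoring used here is a simplification; the Bahdanau mechanism implemented later uses a learned additive score instead:

```python
import torch
import torch.nn.functional as F

# Toy sketch: three encoder hidden states (dim 4) and one decoder query
enc_states = torch.randn(3, 4)  # the decoder can see ALL of these, not one summary vector
query = torch.randn(1, 4)       # current decoder state

scores = query @ enc_states.T        # one relevance score per input position, shape (1, 3)
weights = F.softmax(scores, dim=-1)  # normalized attention weights
context = weights @ enc_states       # weighted sum of encoder states, shape (1, 4)

print(weights.shape, context.shape)
```

The context vector is recomputed at every decoding step, so each output token can draw on a different mixture of input positions.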
Implementing Seq2Seq Model with Attention
Let’s implement a seq2seq model with attention following Bahdanau et al. (2014). You’ll use GRU (Gated Recurrent Unit) modules instead of LSTM for their simplicity and faster training while maintaining comparable performance.
Using the same training dataset as before, the encoder is implemented much like the one in the plain seq2seq model from the previous post:
```python
class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_seq):
        embedded = self.dropout(self.embedding(input_seq))
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden
```
The dropout module prevents overfitting by being applied to the embedding layer output. The RNN uses nn.GRU with batch_first=True to accept input shaped as (batch_size, seq_len, embedding_dim). The encoder's forward() method returns:
- A 3D tensor of shape (batch_size, seq_len, hidden_dim) containing the RNN outputs
- A 3D tensor of shape (1, batch_size, hidden_dim) containing the final hidden state
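You can verify these shapes with a bare nn.GRU standing in for the encoder's RNN; the dimensions below are arbitrary:

```python
import torch
import torch.nn as nn

# Verify the encoder's output shapes with a bare nn.GRU (dimensions are arbitrary)
batch_size, seq_len, emb_dim, hidden_dim = 2, 7, 16, 32
rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
embedded = torch.randn(batch_size, seq_len, emb_dim)  # stands in for the embedded input
outputs, hidden = rnn(embedded)
print(outputs.shape)  # (2, 7, 32): one hidden state per input position
print(hidden.shape)   # (1, 2, 32): final hidden state of the single GRU layer
```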
The Bahdanau attention mechanism differs from modern transformer attention. Here’s its implementation:
```python
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.transpose(1, 2)  # shape of scores = [B, 1, S]
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)
        return context, weights
```
The attention mechanism is defined mathematically as:
$$
y = \textrm{softmax}\big(V_a \tanh(W_a Q + U_a K)\big) K
$$
where $W_a$, $U_a$, and $V_a$ correspond to the three linear layers in the code. Unlike scaled dot-product attention, it scores each position using a summed projection of the query and key passed through a tanh.
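To make the contrast concrete, the sketch below computes both scoring functions on the same random tensors. The layer names mirror the module above; the scaled dot-product line shows the standard alternative and is not part of this post's model:

```python
import math
import torch
import torch.nn as nn

B, S, H = 2, 5, 8                 # batch, source length, hidden size
query = torch.randn(B, 1, H)      # decoder state
keys = torch.randn(B, S, H)       # encoder outputs

# Additive (Bahdanau) scoring: project query and keys, sum, tanh, reduce to a scalar
Wa, Ua, Va = nn.Linear(H, H), nn.Linear(H, H), nn.Linear(H, 1)
additive_scores = Va(torch.tanh(Wa(query) + Ua(keys))).transpose(1, 2)  # (B, 1, S)

# Scaled dot-product scoring: a single batched matrix multiply, no learned parameters
dot_scores = query @ keys.transpose(1, 2) / math.sqrt(H)                # (B, 1, S)

print(additive_scores.shape, dot_scores.shape)  # both give one score per source position
```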
With the Bahdanau attention module, the decoder is implemented as follows:
```python
class DecoderRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention = BahdanauAttention(hidden_dim)
        self.gru = nn.GRU(embedding_dim + hidden_dim, hidden_dim, batch_first=True)
        self.out_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq, hidden, enc_out):
        """Single token input, single token output"""
        embedded = self.dropout(self.embedding(input_seq))
        context, attn_weights = self.attention(hidden.transpose(0, 1), enc_out)
        rnn_input = torch.cat([embedded, context], dim=-1)
        rnn_output, hidden = self.gru(rnn_input, hidden)
        output = self.out_proj(rnn_output)
        return output, hidden
```
The decoder’s forward() method expects three inputs: a single-token input sequence, the latest RNN hidden state, and the encoder’s full output sequence. It aligns the input token with the encoder’s output sequence using attention to produce a context vector. This context vector, concatenated with the input token's embedding, is then fed to the GRU module to generate the next token. Finally, the GRU output is projected to a logit vector of the same size as the vocabulary.
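The concatenation step can be checked in isolation. The sketch below uses arbitrary dimensions and shows why the decoder's GRU is constructed with input size embedding_dim + hidden_dim:

```python
import torch

# Arbitrary sizes for this sketch
batch_size, emb_dim, hidden_dim = 2, 16, 32
embedded = torch.randn(batch_size, 1, emb_dim)    # embedding of one input token
context = torch.randn(batch_size, 1, hidden_dim)  # context vector from attention
rnn_input = torch.cat([embedded, context], dim=-1)
print(rnn_input.shape)  # last dimension is emb_dim + hidden_dim = 48
```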
The seq2seq model is then built by connecting the encoder and decoder modules, as follows:
```python
class Seq2SeqRNN(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, target_seq):
        """Given the partial target sequence, predict the next token"""
        batch_size, target_len = target_seq.shape
        device = target_seq.device
        # list for storing the output logits
        outputs = []
        # encoder forward pass
        enc_out, hidden = self.encoder(input_seq)
        dec_hidden = hidden
        # decoder forward pass
        for t in range(target_len-1):
            dec_in = target_seq[:, t].unsqueeze(1)
            dec_out, dec_hidden = self.decoder(dec_in, dec_hidden, enc_out)
            outputs.append(dec_out)
        outputs = torch.cat(outputs, dim=1)
        return outputs
```
The seq2seq model employs teacher forcing during training, where ground-truth tokens (instead of decoder outputs from the previous step) are used as inputs to accelerate learning. In this implementation, the encoder is invoked once, but the decoder is invoked multiple times to generate the output sequence.
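The shift between decoder inputs and labels under teacher forcing can be illustrated with a toy target sequence (the token IDs below are made up for illustration):

```python
# Teacher forcing: the decoder input is the target sequence shifted right by one.
# Toy token IDs (made up): 0=[start], 1=[end], 5="je", 9="t'aime"
target = [0, 5, 9, 1]          # [start] je t'aime [end]
decoder_inputs = target[:-1]   # [0, 5, 9]: fed to the decoder one token at a time
decoder_labels = target[1:]    # [5, 9, 1]: compared against the output logits
print(decoder_inputs, decoder_labels)
```

This is why the loop runs for target_len-1 steps and why the loss later compares the logits against fr_ids[:, 1:].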
Training and Evaluating the Model
With the modules you created in the previous section, you can initialize a seq2seq model:
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc_vocab = len(en_tokenizer.get_vocab())
dec_vocab = len(fr_tokenizer.get_vocab())
emb_dim = 256
hidden_dim = 256
dropout = 0.1

# Create model
encoder = EncoderRNN(enc_vocab, emb_dim, hidden_dim, dropout).to(device)
decoder = DecoderRNN(dec_vocab, emb_dim, hidden_dim, dropout).to(device)
model = Seq2SeqRNN(encoder, decoder).to(device)
```
The training loop is very similar to the one in the previous post:
```python
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
N_EPOCHS = 50

for epoch in range(N_EPOCHS):
    model.train()
    epoch_loss = 0
    for en_ids, fr_ids in dataloader:
        # Move the "sentences" to device
        en_ids = en_ids.to(device)
        fr_ids = fr_ids.to(device)
        # zero the grad, then forward pass
        optimizer.zero_grad()
        outputs = model(en_ids, fr_ids)
        # compute the loss: compare 3D logits to 2D targets
        loss = loss_fn(outputs.reshape(-1, dec_vocab), fr_ids[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss {epoch_loss/len(dataloader)}; Latest loss {loss.item()}")
    torch.save(model.state_dict(), f"seq2seq_attn-epoch-{epoch+1}.pth")

    # Test
    if (epoch+1) % 5 != 0:
        continue
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for en_ids, fr_ids in dataloader:
            en_ids = en_ids.to(device)
            fr_ids = fr_ids.to(device)
            outputs = model(en_ids, fr_ids)
            loss = loss_fn(outputs.reshape(-1, dec_vocab), fr_ids[:, 1:].reshape(-1))
            epoch_loss += loss.item()
        print(f"Eval loss: {epoch_loss/len(dataloader)}")
```
The training process uses cross-entropy loss to compare the output logits with the ground-truth French translation. The decoder begins with [start] and predicts one token at a time, so the logits are compared with fr_ids[:, 1:], the target sequence shifted by one position. Note that the [pad] token is included in the loss calculation, but you can skip it by specifying the ignore_index parameter when you create the loss function.
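Here is a minimal sketch of the ignore_index effect, using an assumed pad ID of 2 for illustration (in the actual code you would pass fr_tokenizer.token_to_id("[pad]")):

```python
import torch
import torch.nn as nn

PAD_ID = 2   # assumed pad token id for this sketch only
vocab = 10
logits = torch.randn(4, vocab)                  # 4 positions, 10-way classification
targets = torch.tensor([3, 7, PAD_ID, PAD_ID])  # last two positions are padding

loss_all = nn.CrossEntropyLoss()(logits, targets)
loss_no_pad = nn.CrossEntropyLoss(ignore_index=PAD_ID)(logits, targets)
# loss_no_pad averages over the two real tokens only, so padded positions
# contribute no gradient signal
print(loss_all.item(), loss_no_pad.item())
```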
The model is trained for 50 epochs. Evaluation is performed once every five epochs. Since you don’t have a separate test set, you can use the training data for evaluation. You should toggle the model to evaluation mode and use the model under torch.no_grad() to avoid computing the gradients.
Using the Model
A well-trained model typically achieves a mean cross-entropy loss around 0.1. While the training loop in the previous section outlines how you can use a model, you should use the encoder and decoder separately for inference since the forward() method of the Seq2SeqRNN class is created for training. Here’s how to use the trained model for translation:
```python
import random

model.eval()
N_SAMPLES = 5
MAX_LEN = 60
with torch.no_grad():
    start_token = torch.tensor([fr_tokenizer.token_to_id("[start]")]).to(device)
    for en, true_fr in random.sample(text_pairs, N_SAMPLES):
        en_ids = torch.tensor(en_tokenizer.encode(en).ids).unsqueeze(0).to(device)
        enc_out, hidden = model.encoder(en_ids)
        pred_ids = []
        prev_token = start_token.unsqueeze(0)
        for _ in range(MAX_LEN):
            output, hidden = model.decoder(prev_token, hidden, enc_out)
            output = output.argmax(dim=2)
            pred_ids.append(output.item())
            prev_token = output
            # early stop if the predicted token is the end token
            if pred_ids[-1] == fr_tokenizer.token_to_id("[end]"):
                break
        # Decode the predicted IDs
        pred_fr = fr_tokenizer.decode(pred_ids)
        print(f"English: {en}")
        print(f"French: {true_fr}")
        print(f"Predicted: {pred_fr}")
        print()
```
During inference, you pass a tensor of sequence length 1 and batch size 1 as the input to the decoder at each step. The decoder returns a logit vector of sequence length 1 and batch size 1, and you use argmax() to decode the output token ID. This output token is then used as the input to the next iteration of the loop, until the [end] token is generated or the maximum length is reached.
Sample outputs below demonstrate the model’s capabilities:
```
English: we'll all die sooner or later.
French: nous mourrons tous tôt ou tard.
Predicted: nous mourronsrons tôt ou tard.

English: tom made room for mary on the bench.
French: tom fit de la place pour marie sur le banc.
Predicted: tom fit fait sa pour pour sur le banc banc.

English: keep quiet!
French: restez tranquille !
Predicted: ailles tranquille
```
To further improve the model’s performance, you can:
- Increase the size of the vocabulary in the tokenizer
- Revise the model architecture, e.g., a larger embedding dimension, a larger hidden state dimension, or more GRU layers
- Improve the training process, e.g., adjust the learning rate or number of epochs, try a different optimizer, or use a separate test set for evaluation
For completeness, below is the complete code you created in this post:
```python
import random
import os
import re
import unicodedata
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tokenizers
import tqdm

#
# Data preparation
#

# Download dataset provided by Anki: https://www.manythings.org/anki/ with requests
if not os.path.exists("fra-eng.zip"):
    url = "http://storage.googleapis.com/download.tensorflow.org/data/fra-eng.zip"
    response = requests.get(url)
    with open("fra-eng.zip", "wb") as f:
        f.write(response.content)

# Normalize text
# each line of the file is in the format "<english>\t<french>"
# We convert text to lowercase, normalize unicode (NFKC)
def normalize(line):
    """Normalize a line of text and split into two at the tab character"""
    line = unicodedata.normalize("NFKC", line.strip().lower())
    eng, fra = line.split("\t")
    return eng.lower().strip(), fra.lower().strip()

text_pairs = []
with zipfile.ZipFile("fra-eng.zip", "r") as zip_ref:
    for line in zip_ref.read("fra.txt").decode("utf-8").splitlines():
        eng, fra = normalize(line)
        text_pairs.append((eng, fra))

#
# Tokenization with BPE
#
if os.path.exists("en_tokenizer.json") and os.path.exists("fr_tokenizer.json"):
    en_tokenizer = tokenizers.Tokenizer.from_file("en_tokenizer.json")
    fr_tokenizer = tokenizers.Tokenizer.from_file("fr_tokenizer.json")
else:
    en_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
    fr_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
    # Configure pre-tokenizer to split on whitespace and punctuation, add space at beginning of the sentence
    en_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
    fr_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
    # Configure decoder: So that word boundary symbol "Ġ" will be removed
    en_tokenizer.decoder = tokenizers.decoders.ByteLevel()
    fr_tokenizer.decoder = tokenizers.decoders.ByteLevel()
    # Train BPE for English and French using the same trainer
    VOCAB_SIZE = 8000
    trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[start]", "[end]", "[pad]"],
        show_progress=True
    )
    en_tokenizer.train_from_iterator([x[0] for x in text_pairs], trainer=trainer)
    fr_tokenizer.train_from_iterator([x[1] for x in text_pairs], trainer=trainer)
    en_tokenizer.enable_padding(pad_id=en_tokenizer.token_to_id("[pad]"), pad_token="[pad]")
    fr_tokenizer.enable_padding(pad_id=fr_tokenizer.token_to_id("[pad]"), pad_token="[pad]")
    # Save the trained tokenizers
    en_tokenizer.save("en_tokenizer.json", pretty=True)
    fr_tokenizer.save("fr_tokenizer.json", pretty=True)

# Test the tokenizer
print("Sample tokenization:")
en_sample, fr_sample = random.choice(text_pairs)
encoded = en_tokenizer.encode(en_sample)
print(f"Original: {en_sample}")
print(f"Tokens: {encoded.tokens}")
print(f"IDs: {encoded.ids}")
print(f"Decoded: {en_tokenizer.decode(encoded.ids)}")
print()
encoded = fr_tokenizer.encode("[start] " + fr_sample + " [end]")
print(f"Original: {fr_sample}")
print(f"Tokens: {encoded.tokens}")
print(f"IDs: {encoded.ids}")
print(f"Decoded: {fr_tokenizer.decode(encoded.ids)}")
print()

#
# Create PyTorch dataset for the BPE-encoded translation pairs
#
class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, text_pairs):
        self.text_pairs = text_pairs

    def __len__(self):
        return len(self.text_pairs)

    def __getitem__(self, idx):
        eng, fra = self.text_pairs[idx]
        return eng, "[start] " + fra + " [end]"

def collate_fn(batch):
    en_str, fr_str = zip(*batch)
    en_enc = en_tokenizer.encode_batch(en_str, add_special_tokens=True)
    fr_enc = fr_tokenizer.encode_batch(fr_str, add_special_tokens=True)
    en_ids = [enc.ids for enc in en_enc]
    fr_ids = [enc.ids for enc in fr_enc]
    return torch.tensor(en_ids), torch.tensor(fr_ids)

BATCH_SIZE = 32
dataset = TranslationDataset(text_pairs)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, collate_fn=collate_fn)

#
# Create seq2seq model with attention for translation
#
class EncoderRNN(nn.Module):
    """A RNN encoder with an embedding layer"""
    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.1):
        """
        Args:
            vocab_size: The size of the input vocabulary
            embedding_dim: The dimension of the embedding vector
            hidden_dim: The dimension of the hidden state
            dropout: The dropout rate
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_seq):
        # input seq = [batch_size, seq_len] -> embedded = [batch_size, seq_len, embedding_dim]
        embedded = self.dropout(self.embedding(input_seq))
        # outputs = [batch_size, seq_len, hidden_dim]
        # hidden = [1, batch_size, hidden_dim]
        outputs, hidden = self.rnn(embedded)
        return outputs, hidden

class BahdanauAttention(nn.Module):
    """Bahdanau Attention

    https://arxiv.org/pdf/1409.0473.pdf

    The forward function takes query and keys only, and they should be
    the same shape (B,S,H)
    """
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        """Bahdanau Attention

        Args:
            query: [B, 1, H]
            keys: [B, S, H]

        Returns:
            context: [B, 1, H]
            weights: [B, 1, S]
        """
        B, S, H = keys.shape
        assert query.shape == (B, 1, H)
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.transpose(1, 2)  # scores = [B, 1, S]
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)
        return context, weights

class DecoderRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention = BahdanauAttention(hidden_dim)
        self.gru = nn.GRU(embedding_dim + hidden_dim, hidden_dim, batch_first=True)
        self.out_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq, hidden, enc_out):
        """Single token input, single token output"""
        # input seq = [batch_size, 1] -> embedded = [batch_size, 1, embedding_dim]
        embedded = self.dropout(self.embedding(input_seq))
        # hidden = [1, batch_size, hidden_dim]
        # context = [batch_size, 1, hidden_dim]
        context, attn_weights = self.attention(hidden.transpose(0, 1), enc_out)
        # rnn_input = [batch_size, 1, embedding_dim + hidden_dim]
        rnn_input = torch.cat([embedded, context], dim=-1)
        # rnn_output = [batch_size, 1, hidden_dim]
        rnn_output, hidden = self.gru(rnn_input, hidden)
        output = self.out_proj(rnn_output)
        return output, hidden

class Seq2SeqRNN(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, target_seq):
        """Given the partial target sequence, predict the next token"""
        # input seq = [batch_size, seq_len]
        # target seq = [batch_size, seq_len]
        batch_size, target_len = target_seq.shape
        device = target_seq.device
        # list for storing the output logits
        outputs = []
        # encoder forward pass
        enc_out, hidden = self.encoder(input_seq)
        dec_hidden = hidden
        # decoder forward pass
        for t in range(target_len-1):
            # during training, use the ground truth token as the input (teacher forcing)
            dec_in = target_seq[:, t].unsqueeze(1)
            # last target token and hidden states -> next token
            dec_out, dec_hidden = self.decoder(dec_in, dec_hidden, enc_out)
            # store the prediction
            outputs.append(dec_out)
        outputs = torch.cat(outputs, dim=1)
        return outputs

# Initialize model parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc_vocab = len(en_tokenizer.get_vocab())
dec_vocab = len(fr_tokenizer.get_vocab())
emb_dim = 256
hidden_dim = 256
dropout = 0.1

# Create model
encoder = EncoderRNN(enc_vocab, emb_dim, hidden_dim, dropout).to(device)
decoder = DecoderRNN(dec_vocab, emb_dim, hidden_dim, dropout).to(device)
model = Seq2SeqRNN(encoder, decoder).to(device)
print(model)
print("Model created with:")
print(f"  Input vocabulary size: {enc_vocab}")
print(f"  Output vocabulary size: {dec_vocab}")
print(f"  Embedding dimension: {emb_dim}")
print(f"  Hidden dimension: {hidden_dim}")
print(f"  Dropout: {dropout}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optionally initialize model parameters with a normal distribution
#for name, param in model.named_parameters():
#    if param.dim() > 1:
#        nn.init.normal_(param.data, mean=0, std=0.01)

# Train unless seq2seq_attn.pth exists
if os.path.exists("seq2seq_attn.pth"):
    model.load_state_dict(torch.load("seq2seq_attn.pth"))
else:
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    loss_fn = nn.CrossEntropyLoss()  # ignore_index=fr_tokenizer.token_to_id("[pad]")
    N_EPOCHS = 100
    for epoch in range(N_EPOCHS):
        model.train()
        epoch_loss = 0
        for en_ids, fr_ids in tqdm.tqdm(dataloader, desc="Training"):
            # Move the "sentences" to device
            en_ids = en_ids.to(device)
            fr_ids = fr_ids.to(device)
            # zero the grad, then forward pass
            optimizer.zero_grad()
            outputs = model(en_ids, fr_ids)
            # compute the loss: compare 3D logits to 2D targets
            loss = loss_fn(outputs.reshape(-1, dec_vocab), fr_ids[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss {epoch_loss/len(dataloader)}; Latest loss {loss.item()}")
        torch.save(model.state_dict(), f"seq2seq_attn-epoch-{epoch+1}.pth")

        # Test
        if (epoch+1) % 5 != 0:
            continue
        model.eval()
        epoch_loss = 0
        with torch.no_grad():
            for en_ids, fr_ids in tqdm.tqdm(dataloader, desc="Evaluating"):
                en_ids = en_ids.to(device)
                fr_ids = fr_ids.to(device)
                outputs = model(en_ids, fr_ids)
                loss = loss_fn(outputs.reshape(-1, dec_vocab), fr_ids[:, 1:].reshape(-1))
                epoch_loss += loss.item()
            print(f"Eval loss: {epoch_loss/len(dataloader)}")
    torch.save(model.state_dict(), "seq2seq_attn.pth")

# Test for a few samples
model.eval()
N_SAMPLES = 5
MAX_LEN = 60
with torch.no_grad():
    start_token = torch.tensor([fr_tokenizer.token_to_id("[start]")]).to(device)
    for en, true_fr in random.sample(text_pairs, N_SAMPLES):
        en_ids = torch.tensor(en_tokenizer.encode(en).ids).unsqueeze(0).to(device)
        enc_out, hidden = model.encoder(en_ids)
        pred_ids = []
        prev_token = start_token.unsqueeze(0)
        for _ in range(MAX_LEN):
            output, hidden = model.decoder(prev_token, hidden, enc_out)
            output = output.argmax(dim=2)
            pred_ids.append(output.item())
            prev_token = output
            # early stop if the predicted token is the end token
            if pred_ids[-1] == fr_tokenizer.token_to_id("[end]"):
                break
        # Decode the predicted IDs
        pred_fr = fr_tokenizer.decode(pred_ids)
        print(f"English: {en}")
        print(f"French: {true_fr}")
        print(f"Predicted: {pred_fr}")
        print()
```
Note that the code above uses GRU as the RNN module in both the encoder and decoder. You can also use other RNN modules, such as LSTM or a bidirectional RNN. All you need to do is swap the nn.GRU module in the encoder and decoder with a different module. Below is an implementation of the encoder and decoder using LSTM and scaled dot-product attention. Because the LSTM carries a cell state in addition to the hidden state, you will also need to adjust the Seq2SeqRNN class and the inference loop to pass the cell state through each decoder call.
```python
...

class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)

    def forward(self, input_seq):
        embedded = self.embedding(input_seq)
        outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, hidden, cell

class DecoderRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=1, dropout=dropout,
                                               batch_first=True)
        self.lstm = nn.LSTM(embedding_dim + hidden_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.out_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq, hidden, cell, enc_out):
        embedded = self.embedding(input_seq)
        context = self.attention(hidden.transpose(0, 1), enc_out, enc_out)[0]
        rnn_input = torch.cat([embedded, context], dim=-1)
        output, (hidden, cell) = self.lstm(rnn_input, (hidden, cell))
        output = self.out_proj(output)
        return output, hidden, cell
```
Further Readings
Below are some resources that you may find useful:
- Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et al., 2014 paper)
- Sequence to Sequence Learning with Neural Networks
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
- PyTorch Tutorial on Seq2Seq Translation
Summary
In this post, you learned how to build and train an attention-based seq2seq model for English to French translation. Specifically, you learned about:
- How to build an encoder-decoder architecture with GRU
- Implementing attention mechanisms to help the model focus on relevant input
- Building a complete translation model in PyTorch
- Training effectively using teacher forcing
Attention mechanisms significantly improve translation by enabling dynamic focus on relevant input parts during generation.






