Question Answering is a crucial natural language processing task that enables machines to understand and respond to human questions by extracting relevant information from a given context. DistilBERT, a distilled version of BERT, offers an excellent balance between performance and computational efficiency for building Q&A systems.
In this tutorial, you will learn how to build a powerful Question Answering (Q&A) system using DistilBERT and the transformers library. You will cover everything from a basic implementation to more advanced features. In particular, you will learn:
- How to implement a basic Q&A system with DistilBERT
- Advanced techniques for improving answer quality
Let’s get started.

Building Q&A Systems with DistilBERT and Transformers
Photo by Ana Municio. Some rights reserved.
Overview
This post is in three parts; they are:
- Building a Simple Q&A System
- Handling Large Contexts
- Building an Expert System
Building a Simple Q&A System
A question answering system is not just about throwing a question at a model and getting an answer back. You want the answer to be accurate and well-supported. The way to achieve this is to provide a “context” in which the answer should be found. While this prevents the model from answering open-ended questions, it also prevents it from hallucinating an answer. A model that can do this task must understand both the question and the context, which requires more than plain language modeling.
A model that can do this is BERT. Below, you will use the DistilBERT model to build a simple Q&A system:
```python
import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, pipeline

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# DistilBERT fine-tuned on SQuAD for extractive question answering
model_name = "distilbert-base-uncased-distilled-squad"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForQuestionAnswering.from_pretrained(model_name)
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)

max_answer_length = 50
top_k = 3
question = "What is the capital of France?"
context = "France is a country in Western Europe. Its capital is Paris, which is known for its art, fashion, gastronomy and culture."

# Ask the model to extract the answer to the question from the context
result = qa_pipeline(question=question, context=context, max_answer_len=max_answer_length, top_k=top_k)
print(f"Question: {question}")
print(f"Context: {context}")
print(result)
```
This is the output you will get:
```
Device set to use cpu
Question: What is the capital of France?
Context: France is a country in Western Europe. Its capital is Paris, which is known for its art, fashion, gastronomy and culture.
[{'score': 0.9776948690414429, 'start': 54, 'end': 59, 'answer': 'Paris'},
 {'score': 0.017595181241631508, 'start': 54, 'end': 60, 'answer': 'Paris,'},
 {'score': 0.0026904228143393993, 'start': 39, 'end': 59, 'answer': 'Its capital is Paris'}]
```
The model we used is distilbert-base-uncased-distilled-squad, which is a DistilBERT model fine-tuned on the SQuAD dataset. It is an “uncased” model, which means it treats input as case-insensitive. It was fine-tuned with an additional knowledge distillation step, so it retains much of the accuracy of a larger teacher model at a fraction of the size. Hence, it is particularly good for question-answering tasks that require understanding both the question and the context.
To use it, you created a pipeline using the transformers library. You requested a question-answering pipeline but specified the model and tokenizer to use, rather than letting the pipeline() function pick them for you.
When you invoke the pipeline, you provide the question and the context. The model will find the answer in the context and return it. However, instead of a simple answer, it returns the positions in the context where the answer is found, together with a score (between 0 and 1) for the answer. Since top_k is set to 3, three such answers are returned.
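If you are curious what the pipeline does behind the scenes, below is a minimal sketch, assuming the model, tokenizer, question, and context variables from the code above. The model predicts a start logit and an end logit for every token, and the span between the best start and end positions is decoded as the answer; the pipeline additionally normalizes these logits into the scores you see and maps token positions back to character positions.

```python
# A minimal sketch of what the question-answering pipeline does internally.
# Assumes `model`, `tokenizer`, `question`, and `context` from the code above.
import torch

inputs = tokenizer(question, context, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)

# Pick the most likely start and end token positions, then decode that span
start_idx = outputs.start_logits.argmax()
end_idx = outputs.end_logits.argmax()
answer_ids = inputs["input_ids"][0][start_idx:end_idx + 1]
print(tokenizer.decode(answer_ids))  # should print "paris" for the example above
```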
From the output, you can see that the answer with the highest score is simply “Paris” (character positions 54 to 59 in the context string), while the other answers are not wrong, just presented differently. You can modify the code above to pick the best answer based on the score, as in the sketch below.
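For example, this is one way you might select the single best answer and reject low-confidence ones; the 0.5 threshold here is an arbitrary choice for illustration:

```python
# A minimal sketch: pick the highest-scoring answer from the pipeline result above
best = max(result, key=lambda ans: ans["score"])
if best["score"] >= 0.5:  # hypothetical confidence threshold
    print(f"Best answer: {best['answer']} (score: {best['score']:.2f})")
else:
    print("No confident answer found")
```

The same idea of a confidence threshold reappears in the more complete implementation later in this post.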
Handling Large Contexts
The problem with this simple Q&A system is that it can only handle short contexts. The model has a limit on the maximum sequence length it can accept, which for this particular model is 512 tokens.
Usually, this limit is not a problem for the question but for the context: the context is typically a large piece of text serving as background information, while the question is a single sentence whose answer you want to find in that context. To handle this, you can “chunk” the context, that is, split the long context string into smaller pieces and feed them into the Q&A model one by one. You repeat the same question for each chunk and collect the candidate answers.
With top_k=3, you can expect up to 3 answers from each chunk. Since each answer has a score, you can simply pick the answer with the highest score. You can also discard answers with low scores before picking the best one. In this way, you can tell when the context does not provide enough information to answer the question.
Let’s see how to implement this:
```python
import time
from dataclasses import dataclass

import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, pipeline


@dataclass
class QAConfig:
    """Configuration for QA settings"""
    max_sequence_length: int = 512
    max_answer_length: int = 50
    top_k: int = 3
    threshold: float = 0.5


class QASystem:
    """Q&A system with chunking"""
    def __init__(self, model_name="distilbert-base-uncased-distilled-squad", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertForQuestionAnswering.from_pretrained(model_name)
        # Initialize pipeline for simple queries and answer cache
        self.qa_pipeline = pipeline("question-answering", model=self.model,
                                    tokenizer=self.tokenizer, device=self.device)
        self.answer_cache = {}

    def preprocess_context(self, context, max_length=512):
        """Split long contexts into chunks below max_length"""
        chunks = []
        current_chunk = []
        current_length = 0
        for word in context.split():
            if current_length + 1 + len(word) > max_length:
                chunks.append(" ".join(current_chunk))
                current_chunk = [word]
                current_length = len(word)
            else:
                current_chunk.append(word)
                current_length += 1 + len(word)  # length of space + word
        # Add the last chunk if it's not empty
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks

    def get_answer(self, question, context, config):
        """Get answer with confidence score"""
        # Check cache
        cache_key = (question, context)
        if cache_key in self.answer_cache:
            return self.answer_cache[cache_key]

        # Preprocess context into chunks
        context_chunks = self.preprocess_context(context, config.max_sequence_length)

        # Get answers from all chunks
        answers = []
        for chunk in context_chunks:
            result = self.qa_pipeline(question=question, context=chunk,
                                      max_answer_len=config.max_answer_length,
                                      top_k=config.top_k)
            assert isinstance(result, list)
            for answer in result:
                if answer["score"] >= config.threshold:
                    answers.append(answer)

        # Return the best answer or indicate no answer found
        if answers:
            best_answer = max(answers, key=lambda x: x["score"])
            result = {
                "answer": best_answer["answer"],
                "confidence": best_answer["score"],
            }
        else:
            result = {
                "answer": "No answer found",
                "confidence": 0.0,
            }

        # Cache the result
        self.answer_cache[cache_key] = result
        return result


config = QAConfig(max_sequence_length=512, max_answer_length=50, threshold=0.5)
qa_system = QASystem()
context = """
The Python programming language was created by Guido van Rossum and was released
in 1991. Python is known for its simple syntax and readability. It has become one
of the most popular programming languages, especially in fields like data science
and machine learning. The language is maintained by the Python Steering Council
and developed by a large community of contributors.
"""
questions = [
    "Who created Python?",
    "When was Python released?",
    "Why is Python popular?",
    "What is Python known for?"
]
for question in questions:
    start_time = time.time()
    answer = qa_system.get_answer(question, context, config)
    duration = time.time() - start_time
    print(f"Question: {question}")
    print(f"Answer: {answer['answer']}")
    print(f"Confidence: {answer['confidence']:.2f}")
    print(f"Duration: {duration:.2f}s")
    print("-" * 50)
```
This wraps the workflow into a class to make it easier to use. You pass the question and the context to the get_answer() method, and it will return the answer with the highest score.
In the get_answer() method, the answer is returned immediately if it is already in the cache. Otherwise, the context is preprocessed into chunks by splitting at spaces, keeping each chunk below the length limit. Each chunk is then paired with the question to obtain candidate answers (with scores) from the Q&A model. Only answers with scores above the threshold are considered valid, and the best of them is picked. If no answer has a high enough score, the result is marked as “No answer found”.
For convenience, the parameters are stored in a dataclass object. Note that max_sequence_length is set to 512. Since the chunking code counts characters rather than tokens, a 512-character chunk is a conservative choice: the model can handle up to 512 tokens, which corresponds to roughly 1500 characters. However, keeping chunks short also helps the model run more efficiently, since the time and memory complexity of transformer models is quadratic in the sequence length.
The output of this code is:
```
Device set to use cuda
Question: Who created Python?
Answer: Guido van Rossum
Confidence: 1.00
Duration: 0.10s
--------------------------------------------------
Question: When was Python released?
Answer: 1991
Confidence: 0.98
Duration: 0.00s
--------------------------------------------------
Question: Why is Python popular?
Answer: No answer found
Confidence: 0.00
Duration: 0.00s
--------------------------------------------------
Question: What is Python known for?
Answer: No answer found
Confidence: 0.00
Duration: 0.00s
--------------------------------------------------
```
You may notice a problem with the implementation above: a chunk may be split in the middle of the sentence where the most appropriate answer lies. In that case, the Q&A model may fail to find the answer, or return a suboptimal one. This is a limitation of the preprocess_context() method. You may consider using a longer chunk size or creating chunks with overlapping words; one possible approach is sketched below, and you can refine it as an exercise.
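The sketch below is one way you might create overlapping chunks; the 50-word overlap is an arbitrary choice and not part of the implementation above.

```python
# A possible sketch of chunking with overlap, so a sentence cut at a chunk
# boundary still appears intact in the next chunk. The overlap size is arbitrary.
def preprocess_context_overlap(context, max_length=512, overlap_words=50):
    words = context.split()
    chunks = []
    start = 0
    while start < len(words):
        chunk_words = []
        length = 0
        i = start
        while i < len(words) and length + 1 + len(words[i]) <= max_length:
            chunk_words.append(words[i])
            length += 1 + len(words[i])
            i += 1
        if not chunk_words:  # a single word longer than max_length: take it anyway
            chunk_words = [words[i]]
            i += 1
        chunks.append(" ".join(chunk_words))
        if i >= len(words):
            break
        # Step back a few words so text near the boundary appears in both chunks
        start = max(i - overlap_words, start + 1)
    return chunks
```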
Building an Expert System
With the Q&A system above as a building block, you can automate the process of constructing a context for a question. With a database of documents that can be used as context for Q&A, you can build an expert system that can answer a wide range of questions.
Building a good expert system is a complex task that involves a lot of considerations. However, the high-level framework is not difficult to understand. This is similar to the idea of RAG, retrieval-augmented generation, where the context is retrieved from a database of documents, and the answer is generated by the model. One key component is a database that can retrieve the most relevant context for a question. Let’s see how you can build one:
```python
import collections


class ContextManager:
    def __init__(self, max_contexts=10):
        self.contexts = collections.OrderedDict()
        self.max_contexts = max_contexts

    def add_context(self, context_id, context):
        """Add context with automatic cleanup"""
        if len(self.contexts) >= self.max_contexts:
            self.contexts.popitem(last=False)
        self.contexts[context_id] = context

    def get_context(self, context_id):
        """Get context by ID"""
        return self.contexts.get(context_id)

    def search_relevant_context(self, question, top_k=3):
        """Search for relevant contexts based on relevance score"""
        relevant_contexts = []
        for context_id, context in self.contexts.items():
            relevance_score = self._calculate_relevance(question, context)
            relevant_contexts.append((relevance_score, context_id))
        return sorted(relevant_contexts, reverse=True)[:top_k]

    def _calculate_relevance(self, question, context):
        """Calculate relevance score between question and context.

        This simply counts the number of overlapping words.
        """
        question_words = set(question.lower().split())
        context_words = set(context.lower().split())
        return len(question_words.intersection(context_words)) / len(question_words)
```
This class is named ContextManager. You can add a piece of text to it with a context ID, and the context manager keeps only a limited number of contexts. You can get the text back using the context ID. The most important method, however, is search_relevant_context(), which searches for the most relevant contexts given a question. You can use a different algorithm to calculate the relevance score. Here a simple one is used: the fraction of question words that also appear in the context, a crude word-overlap measure similar in spirit to the Jaccard similarity but normalized by the question length only.
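As an illustration of a different relevance score, the sketch below uses TF-IDF vectors and cosine similarity from scikit-learn. This is not part of the ContextManager above, just one common alternative you could drop into _calculate_relevance().

```python
# A sketch of an alternative relevance score using TF-IDF and cosine similarity.
# Requires scikit-learn; not part of the ContextManager class above.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def tfidf_relevance(question, context):
    vectorizer = TfidfVectorizer()
    vectors = vectorizer.fit_transform([question, context])
    return float(cosine_similarity(vectors[0], vectors[1])[0, 0])

print(tfidf_relevance("Who created Python?", "Python was created by Guido van Rossum."))
```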
With this class, you can build an expert system that can answer a wide range of questions. Here is an example of how to use it:
```python
...

context_manager = ContextManager(max_contexts=10)
context_manager.add_context("python", """
Python is a high-level, interpreted programming language created by Guido van Rossum
and released in 1991. Python's design philosophy emphasizes code readability with its
notable use of significant whitespace. Python features a dynamic type system and
automatic memory management and supports multiple programming paradigms, including
structured, object-oriented, and functional programming.
""")
context_manager.add_context("machine_learning", """
Machine learning is a field of study that gives computers the ability to learn without
being explicitly programmed. It is a branch of artificial intelligence based on the
idea that systems can learn from data, identify patterns and make decisions with
minimal human intervention.
""")

config = QAConfig(max_sequence_length=512, max_answer_length=50, threshold=0.5)
qa_system = QASystem()

question = "Who created Python?"
relevant_contexts = context_manager.search_relevant_context(question, top_k=1)
if relevant_contexts:
    relevance, context_id = relevant_contexts[0]
    context = context_manager.get_context(context_id)
    print(f"Question: {question}")
    print(f"Most relevant context: {context_id} (relevance: {relevance:.2f})")
    print(context)
    answer = qa_system.get_answer(question, context, config)
    print(f"Answer: {answer['answer']}")
    print(f"Confidence: {answer['confidence']:.2f}")
else:
    print("No relevant context found.")
```
You first add some contexts to the context manager. Depending on the maximum size you configure, you can store a lot of text in the system. Given a question, you then search for the most relevant context and feed both the question and that context to the Q&A system to get the answer, just as in the previous section; the chunking and the selection of the best answer happen behind the scenes.
You can extend this to try more than just the top context, so that the answer is searched for across a wider range of contexts; a sketch of this follows the output below. This is a simple way to avoid missing an answer that happens to sit in a context that did not score the best. However, if you have a better way to score the relevance of a context, such as using a neural network model to compute the relevance score, you may not need to try as many contexts.
The output of the above will be:
```
Question: Who created Python?
Most relevant context: python (relevance: 0.33)

Python is a high-level, interpreted programming language created by Guido van Rossum
and released in 1991. Python's design philosophy emphasizes code readability with its
notable use of significant whitespace. Python features a dynamic type system and
automatic memory management and supports multiple programming paradigms, including
structured, object-oriented, and functional programming.

Answer: Guido van Rossum
Confidence: 1.00
```
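As mentioned above, you could query the top few contexts instead of only the best one and keep whichever answer has the highest confidence. Below is a minimal sketch that reuses the context_manager, qa_system, and config objects defined above.

```python
# A minimal sketch: run the Q&A system over the top few contexts and keep the
# most confident answer. Reuses `context_manager`, `qa_system`, and `config`.
question = "Who created Python?"
best = {"answer": "No answer found", "confidence": 0.0}
for relevance, context_id in context_manager.search_relevant_context(question, top_k=3):
    context = context_manager.get_context(context_id)
    candidate = qa_system.get_answer(question, context, config)
    if candidate["confidence"] > best["confidence"]:
        best = candidate
print(f"Answer: {best['answer']} (confidence: {best['confidence']:.2f})")
```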
Putting it all together, below is the complete code:
```python
import collections
import time
from dataclasses import dataclass

import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, pipeline


@dataclass
class QAConfig:
    """Configuration for QA settings"""
    max_sequence_length: int = 512
    max_answer_length: int = 50
    top_k: int = 3
    threshold: float = 0.5


class QASystem:
    """Q&A system with chunking"""
    def __init__(self, model_name="distilbert-base-uncased-distilled-squad", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertForQuestionAnswering.from_pretrained(model_name)
        # Initialize pipeline for simple queries and answer cache
        self.qa_pipeline = pipeline("question-answering", model=self.model,
                                    tokenizer=self.tokenizer, device=self.device)
        self.answer_cache = {}

    def preprocess_context(self, context, max_length=512):
        """Split long contexts into chunks below max_length"""
        chunks = []
        current_chunk = []
        current_length = 0
        for word in context.split():
            if current_length + 1 + len(word) > max_length:
                chunks.append(" ".join(current_chunk))
                current_chunk = [word]
                current_length = len(word)
            else:
                current_chunk.append(word)
                current_length += 1 + len(word)  # length of space + word
        # Add the last chunk if it's not empty
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks

    def get_answer(self, question, context, config):
        """Get answer with confidence score"""
        # Check cache
        cache_key = (question, context)
        if cache_key in self.answer_cache:
            return self.answer_cache[cache_key]

        # Preprocess context into chunks
        context_chunks = self.preprocess_context(context, config.max_sequence_length)

        # Get answers from all chunks
        answers = []
        for chunk in context_chunks:
            result = self.qa_pipeline(question=question, context=chunk,
                                      max_answer_len=config.max_answer_length,
                                      top_k=config.top_k)
            assert isinstance(result, list)
            for answer in result:
                if answer["score"] >= config.threshold:
                    answers.append(answer)

        # Return the best answer or indicate no answer found
        if answers:
            best_answer = max(answers, key=lambda x: x["score"])
            result = {
                "answer": best_answer["answer"],
                "confidence": best_answer["score"],
            }
        else:
            result = {
                "answer": "No answer found",
                "confidence": 0.0,
            }

        # Cache the result
        self.answer_cache[cache_key] = result
        return result


class ContextManager:
    def __init__(self, max_contexts=10):
        self.contexts = collections.OrderedDict()
        self.max_contexts = max_contexts

    def add_context(self, context_id, context):
        """Add context with automatic cleanup"""
        if len(self.contexts) >= self.max_contexts:
            self.contexts.popitem(last=False)
        self.contexts[context_id] = context

    def get_context(self, context_id):
        """Get context by ID"""
        return self.contexts.get(context_id)

    def search_relevant_context(self, question, top_k=3):
        """Search for relevant contexts based on relevance score"""
        relevant_contexts = []
        for context_id, context in self.contexts.items():
            relevance_score = self._calculate_relevance(question, context)
            relevant_contexts.append((relevance_score, context_id))
        return sorted(relevant_contexts, reverse=True)[:top_k]

    def _calculate_relevance(self, question, context):
        """Calculate relevance score between question and context.

        This simply counts the number of overlapping words.
        """
        question_words = set(question.lower().split())
        context_words = set(context.lower().split())
        return len(question_words.intersection(context_words)) / len(question_words)


context_manager = ContextManager(max_contexts=10)
context_manager.add_context("python", """
Python is a high-level, interpreted programming language created by Guido van Rossum
and released in 1991. Python's design philosophy emphasizes code readability with its
notable use of significant whitespace. Python features a dynamic type system and
automatic memory management and supports multiple programming paradigms, including
structured, object-oriented, and functional programming.
""")
context_manager.add_context("machine_learning", """
Machine learning is a field of study that gives computers the ability to learn without
being explicitly programmed. It is a branch of artificial intelligence based on the
idea that systems can learn from data, identify patterns and make decisions with
minimal human intervention.
""")

config = QAConfig(max_sequence_length=512, max_answer_length=50, threshold=0.5)
qa_system = QASystem()

question = "Who created Python?"
relevant_contexts = context_manager.search_relevant_context(question, top_k=1)
if relevant_contexts:
    relevance, context_id = relevant_contexts[0]
    context = context_manager.get_context(context_id)
    print(f"Question: {question}")
    print(f"Most relevant context: {context_id} (relevance: {relevance:.2f})")
    print(context)
    answer = qa_system.get_answer(question, context, config)
    print(f"Answer: {answer['answer']}")
    print(f"Confidence: {answer['confidence']:.2f}")
else:
    print("No relevant context found.")
```
Further Readings
Below are some resources that you may find useful:
- DistilBERT model used in this tutorial
- What is RAG (Retrieval-Augmented Generation)?
- Question Answering Pipeline in transformers documentation
Summary
In this tutorial, you have built a comprehensive Q&A system using DistilBERT. In particular, you learned how to:
- Build a Q&A system using the pipeline function from the transformers library
- Handle large contexts by chunking
- Use a context manager to organize contexts and build an expert system on top of it