Serving Multiple Users at Once: How Continuous Batching Keeps LLM Inference Efficient

In the previous article, we saw how a language model processes a prompt during prefill, generates tokens one at a time during decode, and uses a KV cache to avoid repeated computation. In the real world, inference servers handle hundreds or thousands of requests at the same time. How a server schedules those requests determines whether the GPU is doing useful work or sitting idle.

In this tutorial, we take a hands-on approach to understand:

  • Why static batching creates a bottleneck and wastes computation on padding tokens
  • How dynamic scheduling admits new requests the moment a slot opens
  • How ragged batching allows multiple prompts to be processed together

By the end of this tutorial, you will have working code that demonstrates how continuous batching works.

Let’s get started.

Photo by Petra Reid. Some rights reserved.

Overview

This article is divided into four parts; they are:

  • The Problem with Static Batching
  • Code Example of Static Batching
  • Continuous Batching: Dynamic Scheduling and Ragged Batching
  • Full Implementation

The Problem with Static Batching

The simplest way to serve multiple requests together is static batching: group the requests into fixed-size batches and process each batch as a unit.

For example, suppose there are three requests:

  • A: “The capital of France is” (generates 6 more tokens)
  • B: “Today’s weather is so” (generates 50 more tokens)
  • C: “In machine learning, a transformer is” (generates 300 more tokens)

In this batch, requests A and B finish early, but their slots cannot be freed. They sit idle while the GPU decodes the remaining tokens for C.

At some point, the decoding process looks like this:

Pos   A            B           C
1     <BOS>        <BOS>       <BOS>
2     <The>        <Today>     <In>
3     <capital>    <’s>        <machine>
4     <of>         <weather>   <learning>
5     <France>     <is>        <,>
6     <is>         <so>        <a>
7     <the>        <cold>      <transformer>
8     <capital>    <that>      <is>
9     <of>         <it>        <a>
10    <the>        <’s>        <type>
11    <French>     <hard>      <of>
12    <Republic>   <to>        <machine>
13    <EOS>        <see>       <learning>
14    <PAD>        <the>       <algorithm>
15    <PAD>        <sun>       <that>
16    <PAD>        <.>         <can>
17    <PAD>        <But>       <be>
18    <PAD>        <it>        <used>
19    <PAD>        <’s>        <to>
…300  …<PAD>       <…>         <…>

Note about the special tokens: <BOS>: Beginning of Sentence; <EOS>: End of Sentence; <PAD>: padding.

Notice that request A is padded all the way to the end of the batch, position 300 and beyond. With these padding tokens, the GPU is burning cycles on computation that contributes nothing to any result. Worse, the response to request A is most likely not delivered until request C finishes.
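To put a rough number on this waste, the sketch below counts decode slots for the three example requests. It is a simplification under stated assumptions: it counts only the newly generated tokens (6, 50, and 300), ignores the prompt and the <EOS> token, and the helper name padding_waste is made up for illustration.

```python
def padding_waste(new_tokens_per_request):
    """Count decode slots in one static batch (illustrative helper).

    Every request occupies its slot until the longest one finishes,
    so the batch runs for max(new_tokens_per_request) decode steps.
    """
    steps = max(new_tokens_per_request)
    total_slots = steps * len(new_tokens_per_request)
    useful_slots = sum(new_tokens_per_request)
    wasted_slots = total_slots - useful_slots
    return steps, total_slots, useful_slots, wasted_slots


# Requests A, B, and C from the example above
steps, total, useful, wasted = padding_waste([6, 50, 300])
print(f"{wasted}/{total} decode slots wasted ({wasted / total:.1%})")
# prints: 544/900 decode slots wasted (60.4%)
```

Under these assumptions, well over half of the decode slots in this batch are spent on padding.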

Code Example of Static Batching

We will run six requests of different lengths through a GPT-2 model using static batching, with a batch size of 3. For illustration, each request consists of a prompt paired with the maximum number of tokens to generate.

Here is the static batching function:

At the beginning of each iteration of the outer for-loop, wave is the static batch collected from the requests. The prompts variable is a list of strings to be tokenized into inputs. The LLM is invoked through the Hugging Face transformers library to generate tokens until the longest request in the wave is done (with do_sample=False, which selects greedy decoding instead of sampling).

Running this produces:

In this example, the short request (at most 6 new tokens) has to sit through every forward pass needed by the long request (at most 300 new tokens) in the same batch, and its result is not returned until the entire batch is finished.
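The latency effect can be reproduced without a model. The following is a minimal simulation, not the GPT-2 code described above: it assumes one forward pass per generated token, and the six request names and token budgets are made up for illustration (only the values 6 and 300 come from the text).

```python
def static_batching_schedule(requests, batch_size):
    """Simulate static batching.

    requests: list of (name, max_new_tokens) in arrival order.
    Returns {name: forward-pass count at which the result is delivered}.
    """
    delivered = {}
    clock = 0
    for start in range(0, len(requests), batch_size):
        wave = requests[start:start + batch_size]
        # The whole wave runs until its longest request is done
        clock += max(n for _, n in wave)
        for name, _ in wave:
            delivered[name] = clock
    return delivered


# Hypothetical workload: six requests, batch size 3
requests = [("A", 6), ("B", 50), ("C", 300), ("D", 10), ("E", 20), ("F", 100)]
for name, step in static_batching_schedule(requests, batch_size=3).items():
    print(f"{name}: delivered after {step} forward passes")
```

In this simulation, request A needs only 6 tokens but is delivered after 300 forward passes, and the second wave cannot start until the first one is finished.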

Continuous Batching: Dynamic Scheduling and Ragged Batching

Continuous batching addresses both of these problems. There are two ideas behind it: dynamic scheduling and ragged batching.

Dynamic Scheduling

Instead of waiting for the entire batch to finish before admitting new work, the scheduler checks after every single decode step. The moment a sequence finishes (it emits the <EOS> token or reaches its maximum token length), its slot is freed, and the next queued prompt is admitted immediately. Short requests do not hold a slot any longer than they need it.

To see how this works in practice, it helps to think of the scheduler as managing two data structures:

  • A waiting_queue of requests that have arrived but not yet started
  • A running_set of sequences currently being decoded, each carrying its own KV cache and position state

In pseudocode, the main loop looks like this:
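A minimal Python sketch of such a loop follows. It is a simulation under stated assumptions rather than a real scheduler: no model is called, each forward pass is assumed to add exactly one token to every running sequence, and the Request class, the max_slots parameter, and the step numbering in the comments are illustrative choices.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Request:
    name: str
    max_new_tokens: int
    generated: int = 0


def continuous_schedule(requests, max_slots):
    """Return {name: forward-pass count at which the request finishes}."""
    waiting_queue = deque(requests)
    running_set = []
    finished_at = {}
    step = 0
    # Fill the initial batch
    while waiting_queue and len(running_set) < max_slots:
        running_set.append(waiting_queue.popleft())
    while running_set:
        step += 1
        # Step 1: one forward pass; every running sequence gains a token
        for seq in running_set:
            seq.generated += 1
        # Step 2: retire finished sequences and free their slots
        still_running = []
        for seq in running_set:
            if seq.generated >= seq.max_new_tokens:
                finished_at[seq.name] = step
            else:
                still_running.append(seq)
        running_set = still_running
        # Step 3: admit queued requests into the freed slots
        while waiting_queue and len(running_set) < max_slots:
            running_set.append(waiting_queue.popleft())
    return finished_at


# Hypothetical workload: six requests sharing three slots
workload = [Request("A", 6), Request("B", 50), Request("C", 300),
            Request("D", 10), Request("E", 20), Request("F", 100)]
print(continuous_schedule(workload, max_slots=3))
```

In this simulation, request A is delivered after 6 forward passes, and D takes over its slot immediately instead of waiting for B and C.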

This loop runs at the iteration level: once per forward pass rather than once per request. At every step, the batch may look different from the last, because some sequences have finished and new ones have been admitted. However, this introduces a new problem in step 3 of the pseudocode above: when a new prompt is admitted mid-batch, it needs to go through prefill, while the other sequences are decoding just one token at a time. To run them together in a rectangular batch, the decoding sequences must be padded to the length of the incoming prompt, which wastes a lot of computation on padding tokens:

Position     1      2      3      4      5      6      7      8      9      10
B (decode)   <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  to
C (decode)   <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  <PAD>  algorithm
D (prefill)  Once   upon   a      time   in     a      land   far    away   ,

B and C are mid-generation: their previous tokens (“Today’s weather is so…” and “In machine learning, a transformer is…”) have already been processed and stored in the KV cache, so each only needs to submit 1 new query token in this step. Request D, however, is a new prompt entering the batch. None of its tokens are cached, so all 10 tokens have to be fed in for the prefill. As a result, 18 of the 30 tokens in this step are padding.

Ragged Batching

The solution to this problem is ragged batching, which concatenates the tokens of all sequences into one long sequence with no padding. However, we do not want the tokens of “In machine learning, a transformer is…” to attend to any tokens from “Once upon a time…”. Therefore, a block-diagonal causal mask is used to prevent that. Here’s an illustration of the attention mask (# = attend; . = blocked):

Keys (columns), left to right, in five groups:
  B cached (10): Today ’s weather is so cold that it ’s hard
  C cached (12): In machine learning , a transformer is a type of machine learning
  D new (10):    Once upon a time in a land far away ,
  B new (1):     to
  C new (1):     algorithm

Query (row)    B cached    C cached      D new       B new  C new
D: Once        ..........  ............  #.........  .      .
D: upon        ..........  ............  ##........  .      .
D: a           ..........  ............  ###.......  .      .
D: time        ..........  ............  ####......  .      .
D: in          ..........  ............  #####.....  .      .
D: a           ..........  ............  ######....  .      .
D: land        ..........  ............  #######...  .      .
D: far         ..........  ............  ########..  .      .
D: away        ..........  ............  #########.  .      .
D: ,           ..........  ............  ##########  .      .
B: to          ##########  ............  ..........  #      .
C: algorithm   ..........  ############  ..........  .      #
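The mask pattern illustrated above can be generated mechanically. The sketch below is one way to do it in plain Python, assuming the key layout from the illustration (cached tokens first, then the new tokens) and the request names B, C, and D used in the text. The function names are made up, and a real implementation would build a tensor rather than nested lists.

```python
def build_ragged_mask(cached, new):
    """Build a block-diagonal causal mask.

    cached: list of (seq_id, n_cached_tokens), in key order.
    new:    list of (seq_id, n_new_tokens), in query order.
    Returns mask[q][k] == True where query q may attend to key k.
    """
    # Tag every cached key with (sequence id, position in that sequence)
    keys = [(sid, pos) for sid, n in cached for pos in range(n)]
    cached_len = dict(cached)
    # New tokens continue from where each sequence's cache ends
    queries = []
    for sid, n in new:
        start = cached_len.get(sid, 0)
        queries.extend((sid, start + i) for i in range(n))
    # New tokens are also keys, appended after the cached ones
    keys = keys + queries
    # Attend only within the same sequence, and only to earlier-or-equal positions
    return [[q_sid == k_sid and k_pos <= q_pos for k_sid, k_pos in keys]
            for q_sid, q_pos in queries]


def render(mask):
    return ["".join("#" if allowed else "." for allowed in row) for row in mask]


# B and C are decoding (10 and 12 cached tokens); D is a 10-token prefill
mask = build_ragged_mask(cached=[("B", 10), ("C", 12)],
                         new=[("D", 10), ("B", 1), ("C", 1)])
for line in render(mask):
    print(line)
```

Each row is one query token and each column is one key token, so this example produces a 12-by-34 mask in which no token can see another request's tokens.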

In this case, the tensor for the attention operation is no longer a batched 4D tensor of BSHD shape (batch, sequence, heads, head dimension), but logically an unbatched 3D tensor of THD shape, where T is the total number of tokens across the concatenated sequences.
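As a small illustration of the token dimension T, the sketch below packs the new tokens of one step into a flat list and records where each sequence starts. It uses plain Python lists in place of tensors, and the pack_sequences helper and its offsets list are assumptions made for illustration, not part of any particular library.

```python
from itertools import accumulate


def pack_sequences(new_tokens):
    """Concatenate per-sequence token lists into one flat list.

    Returns (flat, offsets), where sequence i occupies
    flat[offsets[i]:offsets[i + 1]].
    """
    flat = [tok for toks in new_tokens.values() for tok in toks]
    offsets = [0] + list(accumulate(len(toks) for toks in new_tokens.values()))
    return flat, offsets


# One step: D is prefilling 10 tokens, B and C each decode 1 token
step = {
    "D": "Once upon a time in a land far away ,".split(),
    "B": ["to"],
    "C": ["algorithm"],
}
flat, offsets = pack_sequences(step)
print(len(flat))  # T = 12 tokens, with no padding
print(offsets)    # [0, 10, 11, 12]
```

Here T is 12, compared with the 30 token positions of the padded rectangular batch for the same step.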

Full Implementation

Here is the full Python implementation of static batching and continuous batching for comparison. Notice the difference in running time: continuous batching makes LLM inference much more efficient. Also notice that the outputs of the two versions are identical, because the block-diagonal mask guarantees that packing sequences together does not change what the model computes.

This is a long piece of code. Because the token sequences are concatenated, you need the Sequence data class to keep some essential data and a pointer to the KV cache for each original prompt. The static batching baseline is the same function as before. The function continuous_batching() is the continuous batching entry point, and its internal function _admit() loads a new request as soon as one in the previous batch finishes its generation.

The function visualize_ragged_steps() only prints the status of each step. The actual prefill and decode steps are implemented in _ragged_step().

In _ragged_step(), multiple sequences are concatenated into input_ids, and the corresponding block-diagonal causal mask is created as attn_mask. The model is invoked with use_cache=True, so it uses the KV cache provided through the past_key_values parameter. Since the input_ids passed to the model may correspond to a different set of requests at each step, the KV cache is rebuilt on each iteration. The latter half of the _ragged_step() function manages the KV cache.
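To make the cache bookkeeping more concrete, here is a simplified sketch of one way to rebuild a packed cache at each step. It is not the _ragged_step() code: plain strings stand in for key/value tensors, and the helper names gather_cache and scatter_new_kv are hypothetical.

```python
def gather_cache(active, cache):
    """Concatenate the per-sequence caches of the active sequences, in order."""
    return [kv for sid in active for kv in cache.get(sid, [])]


def scatter_new_kv(new_kv, new_lens, cache):
    """Split packed per-token entries back into per-sequence caches.

    new_kv:   flat list with one entry per new token, in packed order.
    new_lens: list of (seq_id, n_new_tokens) in the same order.
    """
    if len(new_kv) != sum(n for _, n in new_lens):
        raise ValueError("packed entries do not match the sequence lengths")
    offset = 0
    for sid, n in new_lens:
        cache.setdefault(sid, []).extend(new_kv[offset:offset + n])
        offset += n


# Strings stand in for the key/value tensors of each token
cache = {"B": ["b0", "b1"], "C": ["c0"]}

# One step: D prefills 3 tokens while B and C each decode 1 token
scatter_new_kv(["d0", "d1", "d2", "b2", "c1"],
               [("D", 3), ("B", 1), ("C", 1)], cache)

# Suppose C finishes; the next step packs only the remaining sequences
del cache["C"]
print(gather_cache(["B", "D"], cache))
```

Keeping one cache per sequence and packing it on demand means a finished sequence can be dropped without touching the others.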

Running this code, you will see:

You can observe how the batch changes (and correspondingly, the causal mask) in each step. Note that the attention operation has a complexity of $O(N^2)$ for the number of input tokens $N$. The time taken for generation in continuous batching is much shorter because you eliminate all padding tokens in the input, making the generation more efficient.

Summary

In this article, we walked through the two problems with static batching: short requests have to wait for longer requests in the same batch, and GPU cycles are wasted on padding tokens. Then, we built a working solution that addresses these problems with dynamic scheduling and ragged batching. Together, these two ideas keep the GPU working on real tokens instead of padding.
