In the previous article, we saw how a language model processes a prompt during prefill, generates tokens one at a time during decode, and uses a KV cache to avoid repeated computation. In the real world, inference servers handle hundreds or thousands of requests at the same time, and how a server schedules those requests determines whether the GPU is doing useful work or sitting idle.
In this tutorial, we take a hands-on approach to understand:
- Why static batching creates a bottleneck and wastes compute on padding tokens
- How dynamic scheduling admits new requests the moment a slot opens
- How ragged batching allows multiple prompts to be processed together
By the end of this tutorial, you will have working code that demonstrates how continuous batching works.
Let’s get started.

Serving Multiple Users at Once: How Continuous Batching Keeps LLM Inference Efficient
Photo by Petra Reid. Some rights reserved.
Overview
This article is divided into four parts; they are:
- The Problem with Static Batching
- Code Example of Static Batching
- Continuous Batching: Dynamic Scheduling and Ragged Batching
- Full Implementation
The Problem with Static Batching
The simplest way to serve multiple requests together is static batching: group them into fixed-size batches and process each batch as a unit.
For example, consider a batch of three requests:
- A: “The capital of France is” (generates 6 more tokens)
- B: “Today’s weather is so” (generates 50 more tokens)
- C: “In machine learning, a transformer is” (generates 300 more tokens)
In this batch, requests A and B finish early, but their slots cannot be freed. They sit idle while the GPU decodes the remaining tokens for C.
At some point, the decoding process looks like this:
| Request | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | …300 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| A | <BOS> | <The> | <capital> | <of> | <France> | <is> | <the> | <capital> | <of> | <the> | <French> | <Republic> | <EOS> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | …<PAD> |
| B | <BOS> | <Today> | <‘s> | <weather> | <is> | <so> | <cold> | <that> | <it> | <‘s> | <hard> | <to> | <see> | <the> | <sun> | <.> | <But> | <it> | <‘s> | <…> |
| C | <BOS> | <In> | <machine> | <learning> | <,> | <a> | <transformer> | <is> | <a> | <type> | <of> | <machine> | <learning> | <algorithm> | <that> | <can> | <be> | <used> | <to> | <…> |
Note on the special tokens: <BOS> marks the beginning of a sequence, <EOS> the end of a sequence, and <PAD> is padding.
Notice that request A is padded all the way to position 300 in this batch. Every padding token still goes through the forward pass, so the GPU burns cycles on computation that contributes nothing to any result. On top of that, the response to request A is typically not delivered until request C finishes.
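To put a number on that waste, here is a back-of-the-envelope sketch. It assumes every request generates exactly its token count (no early <EOS>) and counts only decode steps, ignoring the prompt tokens; static_batch_waste is a hypothetical helper, not part of this tutorial's code.

```python
def static_batch_waste(caps: list[int]) -> tuple[int, int, float]:
    """Count useful vs. wasted decode slot-steps in ONE static batch.

    Assumes every request generates exactly its cap (no early EOS) and that
    the whole batch keeps decoding until its longest request is done.
    """
    steps = max(caps)           # the batch runs as long as its longest member
    total = steps * len(caps)   # slot-steps the GPU actually computes
    useful = sum(caps)          # slot-steps that produce a real token
    wasted = total - useful     # slot-steps spent on padding
    return useful, wasted, wasted / total


useful, wasted, frac = static_batch_waste([6, 50, 300])
print(useful, wasted, f"{frac:.0%}")  # 356 544 60%
```

Under these assumptions, 544 of the 900 decode slot-steps for requests A, B, and C, roughly 60%, go to padding.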
Code Example of Static Batching
We will run six requests of different lengths through a GPT-2 model using static batching, with a batch size of 3. For illustration, each request consists of a prompt paired with the maximum number of tokens to generate.
```python
MODEL_ID = "openai-community/gpt2"
BATCH_SIZE = 3

requests = [
    ("The capital of France is", 6),
    ("Today's weather is so", 50),
    ("In machine learning, a transformer is", 300),
    ("Once upon a time in a land far away,", 30),
    ("Quantum computing differs from classical computing because", 180),
    ("The history of the Roman Empire began", 45),
]
```
Here is the static batching function:
```python
import torch


def static_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Baseline. Process requests BATCH_SIZE at a time; each wave runs together
    until its LONGEST request finishes, then a batch barrier clears before the
    next wave starts.

    Downside: short requests in a wave idle until the wave's longest is done -
    and no slot can be refilled until the whole wave clears the barrier.
    """
    if not requests:
        return []
    tokenizer.padding_side = "left"
    results: dict[int, str] = {}
    indexed = list(enumerate(requests))  # (req_id, (prompt, cap))

    for wave_start in range(0, len(indexed), BATCH_SIZE):
        wave = indexed[wave_start: wave_start + BATCH_SIZE]
        wave_max = max(cap for _req_id, (_prompt, cap) in wave)

        # Show which request occupies each slot in this wave.
        for slot, (req_id, (prompt, cap)) in enumerate(wave):
            print(f" ++ slot {slot} <- req {req_id} ({cap} tok cap): {prompt!r}", flush=True)

        prompts = [p for _, (p, _) in wave]
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=wave_max,  # whole wave decodes to the longest
                pad_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )

        width = inputs.input_ids.shape[1]
        print(
            f" *** batch barrier: all {len(wave)} slots wait for the longest "
            f"({wave_max} tokens) ***",
            flush=True,
        )
        for slot, ((req_id, (prompt, cap)), row) in enumerate(zip(wave, output_ids)):
            text = prompt + tokenizer.decode(row[width:width + cap], skip_special_tokens=True)
            results[req_id] = text
            print(
                f" -- slot {slot} done req {req_id} ({cap}/{wave_max} tokens): {text[:90]}",
                flush=True,
            )

    return [results[k] for k in sorted(results)]
```
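At the start of the outer for-loop, wave is the static batch taken from the request list. The prompts variable holds the prompt strings of that batch, which are tokenized with left padding into inputs. The model is then run through the generate() method of the Hugging Face transformers library with max_new_tokens set to the longest cap in the wave, so every request in the wave decodes for that many steps; each result is cut back to its own cap when the output is decoded. Setting do_sample=False (with the default single beam) gives greedy decoding rather than sampling.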
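Left padding matters for that final slicing step. With padding_side="left", every prompt in the wave ends at the same column, so the first generated token of every row lands at index width, which is what row[width:width + cap] relies on. The sketch below mimics this layout with plain Python lists; left_pad and the token ids are made up for illustration and stand in for the tokenizer rather than reproducing its API.

```python
def left_pad(batch: list[list[int]], pad_id: int) -> tuple[list[list[int]], list[list[int]]]:
    """Left-pad token-id lists to a common width and build a matching attention mask."""
    width = max(len(ids) for ids in batch)
    padded, mask = [], []
    for ids in batch:
        n_pad = width - len(ids)
        padded.append([pad_id] * n_pad + ids)      # pads go BEFORE the prompt
        mask.append([0] * n_pad + [1] * len(ids))  # 0 = ignore, 1 = real token
    return padded, mask


PAD = 0
prompts = [[11, 12], [21, 22, 23, 24]]  # two prompts of different lengths
padded, mask = left_pad(prompts, PAD)
width = len(padded[0])

# Pretend generation appended 3 new tokens to every row (decoder-only models
# return prompt + new tokens from generate()).
generated = [row + [100 + i, 200 + i, 300 + i] for i, row in enumerate(padded)]

# All prompts end at column `width`, so new tokens start there in every row.
caps = [1, 3]
outputs = [row[width:width + cap] for row, cap in zip(generated, caps)]
print(padded)   # [[0, 0, 11, 12], [21, 22, 23, 24]]
print(outputs)  # [[100], [101, 201, 301]]
```

With right padding instead, the short prompt's new tokens would start at a different column than the long prompt's, and a single width offset would no longer work.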
Running the static batching function on the six requests produces:
```text
 ++ slot 0 <- req 0 (6 tok cap): 'The capital of France is'
 ++ slot 1 <- req 1 (50 tok cap): "Today's weather is so"
 ++ slot 2 <- req 2 (300 tok cap): 'In machine learning, a transformer is'
 *** batch barrier: all 3 slots wait for the longest (300 tokens) ***
 -- slot 0 done req 0 (6/300 tokens): The capital of France is the capital of the French Republic
 -- slot 1 done req 1 (50/300 tokens): Today's weather is so cold that it's hard to see the sun. But it's not like we're going to
 -- slot 2 done req 2 (300/300 tokens): In machine learning, a transformer is a type of machine learning algorithm that can be use
 ++ slot 0 <- req 3 (30 tok cap): 'Once upon a time in a land far away,'
 ++ slot 1 <- req 4 (180 tok cap): 'Quantum computing differs from classical computing because'
 ++ slot 2 <- req 5 (45 tok cap): 'The history of the Roman Empire began'
 *** batch barrier: all 3 slots wait for the longest (180 tokens) ***
 -- slot 0 done req 3 (30/180 tokens): Once upon a time in a land far away, the sun was shining, and the moon was shining. The su
 -- slot 1 done req 4 (180/180 tokens): Quantum computing differs from classical computing because it is based on the notion of a
 -- slot 2 done req 5 (45/180 tokens): The history of the Roman Empire began in the fourth century B.C.E. with the arrival of the
```
In this example, the short request 0 (6-token cap) has to ride along through the forward passes of the long request 2 (300-token cap), and its result is not available until the entire wave finishes. Request 3 fares even worse: it cannot start until the first wave clears the barrier.
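The delay can be counted in decode steps. The sketch below is a simplified model, assuming one token per slot per step, no early <EOS>, and results released only when a wave clears the barrier, as in the log above; static_finish_steps is a hypothetical helper.

```python
def static_finish_steps(caps: list[int], batch_size: int) -> list[int]:
    """Decode step at which each request's result is released under static batching."""
    finish, clock = [], 0
    for start in range(0, len(caps), batch_size):
        wave = caps[start:start + batch_size]
        clock += max(wave)             # the wave runs until its longest request is done
        finish += [clock] * len(wave)  # everyone in the wave is released together
    return finish


print(static_finish_steps([6, 50, 300, 30, 180, 45], batch_size=3))
# [300, 300, 300, 480, 480, 480]
```

Request 0 needs only 6 tokens but is not released until step 300, and request 3 (30 tokens) waits until step 480.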
Continuous Batching: Dynamic Scheduling and Ragged Batching
Continuous batching addresses both problems. It combines two ideas: dynamic scheduling and ragged batching.
Dynamic Scheduling
Instead of waiting for the entire batch to finish before admitting new work, the scheduler checks after every decode step. The moment a sequence finishes (it emits the <EOS> token or reaches its maximum number of new tokens), its slot is freed and the next queued prompt is admitted immediately. Short requests hold a slot only as long as they need it.
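The same six requests can be replayed under dynamic scheduling with a small simulation. It is a sketch under simplifying assumptions: one token per active sequence per step, no early <EOS>, and a refilled slot starts generating on the following step (mirroring how the tutorial's continuous_batching() admits a prompt whose prefill runs in the next step). continuous_finish_steps is a hypothetical helper. For comparison, the static run above needs 300 + 180 = 480 steps to finish all six requests.

```python
from collections import deque
from typing import Optional


def continuous_finish_steps(caps: list[int], batch_size: int) -> list[int]:
    """Decode step at which each request finishes when freed slots are refilled at once."""
    queue = deque(enumerate(caps))  # (req_id, tokens still to generate)
    slots: list[Optional[tuple[int, int]]] = [None] * batch_size
    for i in range(batch_size):
        if queue:
            slots[i] = queue.popleft()

    finish = [0] * len(caps)
    step = 0
    while any(s is not None for s in slots):
        step += 1
        for i, s in enumerate(slots):
            if s is None:
                continue
            req_id, left = s
            left -= 1  # this sequence emits one token this step
            if left == 0:
                finish[req_id] = step  # done: free the slot...
                slots[i] = queue.popleft() if queue else None  # ...and refill it now
            else:
                slots[i] = (req_id, left)
    return finish


print(continuous_finish_steps([6, 50, 300, 30, 180, 45], batch_size=3))
# [6, 50, 300, 36, 216, 95]
```

In this model, request 0 returns at step 6 instead of 300, and the whole workload completes after 300 steps, bounded only by the single longest request.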
To see how this works in practice, it helps to think of the scheduler as managing two data structures:
- A waiting_queue of requests that have arrived but not yet started
- A running_set of sequences currently being decoded, each carrying its own KV cache and position state
In pseudocode, the main loop looks like this:
```python
while not (waiting_queue.empty() and running_set.empty()):
    # 1) Remove finished sequences
    for seq in list(running_set):
        if seq.done:  # hit EOS or max tokens
            running_set.remove(seq)
            release_kv_cache(seq)

    # 2) Admit new requests from waiting queue
    while not waiting_queue.empty():
        if len(running_set) >= max_num_seqs:
            break
        next_req = waiting_queue.peek()
        if would_violate_token_limit(next_req, running_set, max_num_batched_tokens):
            break
        waiting_queue.pop()
        init_seq_state(next_req)  # allocate KV, set step=0, etc.
        running_set.add(next_req)

    if len(running_set) == 0:
        break  # nothing left to run

    # 3) Form the current batch for this iteration
    batch = select_seqs_for_step(running_set, max_num_batched_tokens)

    # 4) Run ONE model forward step for all seqs in 'batch'
    logits = model.forward(batch.input_tokens, batch.kv_caches)

    # 5) For each sequence in batch:
    for seq, seq_logits in zip(batch.seqs, logits):
        next_token = sample_or_argmax(seq_logits, seq.sampling_params)
        seq.tokens.append(next_token)
        update_kv_cache(seq, next_token)
        if is_eos_or_max_len(seq, next_token):
            seq.done = True
```
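Step 2 of the loop is where the admission policy lives. Below is one possible runnable reading of it: a request is admitted only while a sequence slot is free (max_num_seqs) and the tokens the next step would feed stay within a per-step budget (max_num_batched_tokens). Here a new request is assumed to cost its whole prompt (prefill) and a running one a single token (decode). The Req class, admit(), and this exact cost rule are assumptions for illustration, not the API of any particular engine.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Req:
    req_id: int
    prompt_len: int
    started: bool = False  # False: next step is prefill, True: next step is decode

    def step_tokens(self) -> int:
        """Tokens this request feeds into the next forward pass."""
        return 1 if self.started else self.prompt_len


def admit(waiting: deque, running: list, max_num_seqs: int, max_num_batched_tokens: int) -> list[int]:
    """Move requests from `waiting` to `running` in FIFO order while both limits hold."""
    used = sum(r.step_tokens() for r in running)
    admitted = []
    while waiting:
        if len(running) >= max_num_seqs:
            break  # no free slot
        nxt = waiting[0]  # peek at the head of the queue
        if used + nxt.step_tokens() > max_num_batched_tokens:
            break  # this prefill would exceed the per-step token budget
        waiting.popleft()
        running.append(nxt)
        admitted.append(nxt.req_id)
        used += nxt.step_tokens()
    return admitted


running = [Req(1, prompt_len=6, started=True), Req(2, prompt_len=8, started=True)]
waiting = deque([Req(3, prompt_len=10), Req(4, prompt_len=9)])
print(admit(waiting, running, max_num_seqs=4, max_num_batched_tokens=16))  # [3]
```

Request 4 stays queued even though a slot is still free, because its 9-token prefill would push this step to 21 tokens. Stopping at the head of the queue, rather than skipping ahead to a smaller request, keeps admission first-come, first-served.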
The main loop runs at the iteration level: scheduling decisions are made once per forward pass rather than once per request, so the batch may look different from one step to the next as some sequences finish and new ones are admitted. However, this creates a new problem in step 3 of the pseudocode above. A newly admitted prompt must go through prefill and feed all of its tokens, while the other sequences are decoding just one token each. To run them together in a rectangular batch, the decoding rows must be padded to the length of the incoming prompt:
| Request | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
|---|---|---|---|---|---|---|---|---|---|---|
| B (decode) | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | to |
| C (decode) | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | <PAD> | algorithm |
| D (prefill) | Once | upon | a | time | in | a | land | far | away | , |
B and C are mid-generation: their previous tokens (“Today’s weather is so…” and “In machine learning, a transformer is…”) have already been processed and stored in the KV cache, so each needs to feed only one new token in this step. Request D, however, is a new prompt entering the batch; none of its tokens are cached, so all 10 must be fed for prefill. As a result, 18 of the 30 token slots in this step are padding.
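The arithmetic generalizes: a rectangular batch pads every row to the longest one, while a packed batch feeds only the real tokens. A tiny sketch, where q_lens is a hypothetical list of how many tokens each sequence must feed in this step:

```python
def rectangular_vs_packed(q_lens: list[int]) -> tuple[int, int, int]:
    """Tokens computed by a padded [B, max_len] batch vs. a packed row, for one step."""
    rectangular = len(q_lens) * max(q_lens)  # every row padded to the longest
    packed = sum(q_lens)                     # only real tokens, concatenated
    return rectangular, packed, rectangular - packed


# B and C decode one token each, D prefills its 10-token prompt.
print(rectangular_vs_packed([1, 1, 10]))  # (30, 12, 18)
```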
Ragged Batching
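Ragged batching solves this by concatenating the tokens of all sequences into a single row with no padding. Concatenation alone, however, would let tokens attend across sequences: for instance, “In machine learning, a transformer is…” could attend to tokens from “Once upon a time…”. A block-diagonal causal mask prevents that: a query may attend to a key only if both belong to the same sequence and the key is not later than the query. Here is the attention mask for the step above (rows are queries, columns are keys; # = attend, . = blocked):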
| Query / Key | B: Today | B: ‘s | B: weather | B: is | B: so | B: cold | B: that | B: it | B: ‘s | B: hard | C: In | C: machine | C: learning | C: , | C: a | C: transformer | C: is | C: a | C: type | C: of | C: machine | C: learning | D: Once | D: upon | D: a | D: time | D: in | D: a | D: land | D: far | D: away | D: , | B: to | C: algorithm |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| D: Once | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | . | . | . | . | . | . | . | . | . | . | . |
| D: upon | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | . | . | . | . | . | . | . | . | . | . |
| D: a | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | . | . | . | . | . | . | . | . | . |
| D: time | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | . | . | . | . | . | . | . | . |
| D: in | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | . | . | . | . | . | . | . |
| D: a | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | # | . | . | . | . | . | . |
| D: land | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | # | # | . | . | . | . | . |
| D: far | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | # | # | # | . | . | . | . |
| D: away | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | # | # | # | # | . | . | . |
| D: , | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | # | # | # | # | # | . | . |
| B: to | # | # | # | # | # | # | # | # | # | # | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | . | # | . |
| C: algorithm | . | . | . | . | . | . | . | . | . | . | # | # | # | # | # | # | # | # | # | # | # | # | . | . | . | . | . | . | . | . | . | . | . | # |
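The mask follows from two facts per token: which sequence it belongs to and its position within that sequence. The sketch below is a pure-Python version of the rule (same sequence AND key position <= query position) using the same key layout as the table, [cached tokens | this step's new tokens]; block_causal_mask is a hypothetical helper, while the full implementation later in this article builds the same mask with torch tensors.

```python
def block_causal_mask(kv_lens: list[int], q_lens: list[int]) -> list[list[bool]]:
    """Boolean mask [total_q][total_kv + total_q]; True means the query may attend to the key.

    kv_lens[i]: tokens of sequence i already in the KV cache
    q_lens[i]:  tokens of sequence i fed in this step
    Keys are laid out as [all cached tokens | all new tokens], each grouped by sequence.
    """
    keys = [(i, p) for i, n in enumerate(kv_lens) for p in range(n)]               # cached block
    queries = [(i, kv_lens[i] + j) for i, n in enumerate(q_lens) for j in range(n)]
    keys += queries                                                                # new block
    return [[ks == qs and kp <= qp for ks, kp in keys] for qs, qp in queries]


# Same layout as the table: D prefills 10 tokens, B and C decode 1 token each
# on top of 10 and 12 cached tokens.
mask = block_causal_mask(kv_lens=[0, 10, 12], q_lens=[10, 1, 1])
for row in mask:
    print("".join("#" if allowed else "." for allowed in row))
```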
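With ragged batching, the attention inputs are no longer a padded 4D tensor of shape BSHD (batch, sequence length, heads, head dimension) but, logically, an unbatched 3D tensor of shape THD, where T is the total number of tokens across all concatenated sequences. (The implementation below still carries a batch dimension of size 1 to fit the Hugging Face model interface.)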
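Concretely, packing turns per-sequence token lists into one flat row along the T axis, plus the bookkeeping needed to keep sequences apart: each token's position within its own sequence (fed as position ids) and the offset where each sequence starts. The start offsets below follow a cumulative-length convention, similar to the cu_seqlens arrays used by variable-length attention kernels such as FlashAttention's varlen interface; the tutorial's code instead keeps a plain offsets list and relies on the explicit mask. pack_ragged is a hypothetical helper using plain Python lists.

```python
def pack_ragged(pending: list[list[int]], kv_lens: list[int]):
    """Flatten per-sequence pending tokens into one row (the T axis of a THD layout).

    pending[i]: token ids sequence i feeds this step
    kv_lens[i]: tokens of sequence i already cached (= its next position)
    Returns (flat_ids, position_ids, cu_seqlens).
    """
    flat_ids, position_ids, cu_seqlens = [], [], [0]
    for ids, kv_len in zip(pending, kv_lens):
        flat_ids.extend(ids)
        position_ids.extend(range(kv_len, kv_len + len(ids)))  # positions restart per sequence
        cu_seqlens.append(cu_seqlens[-1] + len(ids))           # where the next sequence starts
    return flat_ids, position_ids, cu_seqlens


flat, pos, cu = pack_ragged([[5, 6, 7], [42], [9]], kv_lens=[0, 10, 12])
print(flat)  # [5, 6, 7, 42, 9]
print(pos)   # [0, 1, 2, 10, 12]
print(cu)    # [0, 3, 4, 5]
```

Sequence i occupies flat[cu[i]:cu[i + 1]], so per-sequence results, such as the logits of each sequence's last token, can be read back from the packed output.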
Full Implementation
Here is the full Python implementation of static batching and continuous batching for comparison. Notice the difference in elapsed time between the two approaches, which the script prints at the end of each run. Also notice that the generated text is the same in both versions: the block-diagonal mask guarantees that packing sequences together does not change what the model computes, up to floating-point rounding. (Reduced precision can occasionally flip a greedy tie, which is why the script keeps float32 on Apple's MPS backend.)
```python
"""
Continuous batching = iteration-level scheduling + ragged (packed) batching.

Two approaches are compared (both run BATCH_SIZE sequences concurrently, so the
comparison is slot-for-slot fair):

1. Static batching (baseline):
   Prompts are processed BATCH_SIZE at a time. Each wave is padded to a common
   length and run together until the LONGEST request in that wave finishes; a
   hard "batch barrier" then has to clear before the next wave starts. Short
   requests sit idle behind the barrier.

2. Continuous batching (production-aligned):
   Two ideas combine to keep the GPU busy.

   (a) Iteration-level scheduling: the moment a sequence finishes it frees its
       slot, and the next queued prompt is admitted on the SAME step - no
       waiting for the rest of the batch.

   (b) Ragged / packed batching - the part that makes it truly "continuous":
       instead of padding every sequence into a rectangular [B, max_len]
       tensor, ALL in-flight tokens are concatenated into a single unpadded
       [1, total_tokens] row and run in ONE forward pass. A block-diagonal
       causal attention mask stops tokens from attending across sequence
       boundaries, so packing is mathematically identical to running each
       sequence on its own (verified: greedy output matches per-prompt
       generation token-for-token).

       Because attention is governed entirely by the mask, a newly admitted
       prompt's multi-token PREFILL rides along in the same forward pass as
       every other sequence's single-token DECODE step. Prefill and decode are
       fused: no padding, no separate prefill pass.

KV cache: each sequence keeps its own DynamicCache; every step the caches are
concatenated along the time axis into one packed cache, and the newly computed
KV is scattered back per sequence. (Real engines store the cache in fixed-size
pages - "paged attention" - to avoid this per-step reassembly, but the
attention/masking logic is exactly what you see here.)
"""

import time
import torch
from dataclasses import dataclass, field
from typing import Optional

from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from transformers.cache_utils import DynamicLayer

MODEL_ID = "openai-community/gpt2"  # swap for any causal LM
BATCH_SIZE = 3  # max concurrent sequences (slots)


def _device_sync(model) -> None:
    """Block until queued GPU work finishes, so timings are accurate."""
    if model.device.type == "cuda":
        torch.cuda.synchronize()
    elif model.device.type == "mps":
        torch.mps.synchronize()


def static_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Baseline. Process requests BATCH_SIZE at a time; each wave runs together
    until its LONGEST request finishes, then a batch barrier clears before the
    next wave starts.

    Downside: short requests in a wave idle until the wave's longest is done -
    and no slot can be refilled until the whole wave clears the barrier.
    """
    if not requests:
        return []
    tokenizer.padding_side = "left"
    results: dict[int, str] = {}
    indexed = list(enumerate(requests))  # (req_id, (prompt, cap))

    for wave_start in range(0, len(indexed), BATCH_SIZE):
        wave = indexed[wave_start: wave_start + BATCH_SIZE]
        wave_max = max(cap for _, (_, cap) in wave)

        # Show which request occupies each slot in this wave.
        for slot, (req_id, (prompt, cap)) in enumerate(wave):
            print(f" ++ slot {slot} <- req {req_id} ({cap} tok cap): {prompt!r}", flush=True)

        prompts = [p for _, (p, _) in wave]
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=wave_max,  # whole wave decodes to the longest
                pad_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )

        width = inputs.input_ids.shape[1]
        print(
            f" *** batch barrier: all {len(wave)} slots wait for the longest "
            f"({wave_max} tokens) ***",
            flush=True,
        )
        for slot, ((req_id, (prompt, cap)), row) in enumerate(zip(wave, output_ids)):
            text = prompt + tokenizer.decode(row[width:width + cap], skip_special_tokens=True)
            results[req_id] = text
            print(
                f" -- slot {slot} done req {req_id} ({cap}/{wave_max} tokens): {text[:90]}",
                flush=True,
            )

    return [results[k] for k in sorted(results)]


@dataclass
class Sequence:
    """State for a single in-flight sequence."""
    req_id: int  # original request index (for ordering results)
    prompt: str
    max_new_tokens: int  # per-request cap so short requests finish early
    # Tokens to feed on the NEXT step: the whole prompt right after admission
    # (prefill), then a single token per step (decode).
    pending_ids: list[int]
    # Per-sequence KV-cache; None until this sequence has run once.
    kv_cache: Optional[DynamicCache] = None
    kv_len: int = 0  # number of cached tokens (prompt + generated)
    tokens_generated: int = 0
    output_ids: list[int] = field(default_factory=list)


def _make_cache(layers_kv: list[tuple[torch.Tensor, torch.Tensor]]) -> DynamicCache:
    """Build a DynamicCache from explicit per-layer (keys, values) tensors.

    We SET the tensors directly instead of calling DynamicLayer.update() (which
    would append), because we are assembling caches from scratch each step.
    """
    cache = DynamicCache()
    for k, v in layers_kv:
        layer = DynamicLayer()
        layer.lazy_initialization(k, v)
        layer.keys = k
        layer.values = v
        cache.layers.append(layer)
    return cache


def _ragged_step(seqs: list[Sequence], model, device, dtype) -> list[int]:
    """Run ONE packed forward pass over every active sequence.

    All sequences are flattened into a single row (batch dim = 1):
        input_ids        [1, total_q]                         - every sequence's pending tokens
        position_ids     [1, total_q]                         - each token's position in ITS sequence
        attention_mask   [1, 1, total_q, total_kv + total_q]  - block-diagonal causal
        past_key_values  packed cache [1, H, total_kv, D]

    total_q  = sum of pending tokens (1 per decoding seq, prompt_len per new seq)
    total_kv = sum of already-cached tokens across sequences

    Returns the next greedy token for each sequence (same order as ``seqs``).
    """
    q_lens = [len(s.pending_ids) for s in seqs]
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)

    # Packed inputs: concatenate every sequence's pending tokens into one row.
    flat_ids = [t for s in seqs for t in s.pending_ids]
    input_ids = torch.tensor([flat_ids], dtype=torch.long, device=device)

    # Tag every KEY and every QUERY token with (sequence index, position-in-sequence).
    # Key space is laid out as [ cached tokens | this step's new tokens ], matching
    # how the model appends new KV to the end of the packed cache.
    key_seq, key_pos = [], []
    for si, s in enumerate(seqs):  # cached block
        for p in range(s.kv_len):
            key_seq.append(si)
            key_pos.append(p)
    q_seq, q_pos = [], []
    for si, s in enumerate(seqs):  # new block (also queries)
        for j in range(len(s.pending_ids)):
            pos = s.kv_len + j
            q_seq.append(si)
            q_pos.append(pos)
            key_seq.append(si)
            key_pos.append(pos)

    q_seq_t = torch.tensor(q_seq, device=device)
    q_pos_t = torch.tensor(q_pos, device=device)
    key_seq_t = torch.tensor(key_seq, device=device)
    key_pos_t = torch.tensor(key_pos, device=device)

    # Each token's positional embedding uses its own sequence position, not its
    # offset in the packed row.
    position_ids = q_pos_t.unsqueeze(0)  # [1, total_q]

    # Block-diagonal causal mask: a query may attend to a key only if they belong
    # to the SAME sequence (block-diagonal) and the key is not in the future
    # (causal). This is the whole trick - it makes packing equivalent to running
    # each sequence separately. 0.0 = attend, large-negative = blocked (additive).
    same = q_seq_t[:, None] == key_seq_t[None, :]
    causal = key_pos_t[None, :] <= q_pos_t[:, None]
    allowed = same & causal  # [total_q, total_kv + total_q]
    attn_mask = torch.zeros(1, 1, total_q, total_kv + total_q, dtype=dtype, device=device)
    attn_mask.masked_fill_(~allowed[None, None], torch.finfo(dtype).min)

    # Packed KV-cache: concatenate each sequence's cache along the time axis.
    # Freshly admitted sequences (kv_len == 0) contribute nothing here.
    cached = [s for s in seqs if s.kv_len > 0]
    if cached:
        num_layers = len(cached[0].kv_cache.layers)
        layers_kv = []
        for l in range(num_layers):
            ks = torch.cat([s.kv_cache.layers[l].keys for s in cached], dim=2)
            vs = torch.cat([s.kv_cache.layers[l].values for s in cached], dim=2)
            layers_kv.append((ks, vs))
        past = _make_cache(layers_kv)
    else:
        past = DynamicCache()

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )

    # Greedy next token for each sequence: read the logits at its LAST pending
    # token (for a prefilling sequence that is the final prompt token).
    logits = out.logits[0]  # [total_q, vocab]
    offsets, last_idx, off = [], [], 0
    for ql in q_lens:
        offsets.append(off)
        last_idx.append(off + ql - 1)
        off += ql
    next_tokens = [int(logits[i].argmax()) for i in last_idx]

    # Scatter the newly computed KV back to each sequence. The output cache is
    # [ old packed block | new packed block ]; slice this step's new block per
    # sequence and append it to that sequence's own cache.
    out_kv = out.past_key_values
    num_layers = len(out_kv.layers)
    for si, s in enumerate(seqs):
        o, ql = offsets[si], q_lens[si]
        layers_kv = []
        for l in range(num_layers):
            k_new = out_kv.layers[l].keys[:, :, total_kv + o: total_kv + o + ql, :]
            v_new = out_kv.layers[l].values[:, :, total_kv + o: total_kv + o + ql, :]
            if s.kv_cache is None:
                layers_kv.append((k_new, v_new))
            else:
                layers_kv.append((
                    torch.cat([s.kv_cache.layers[l].keys, k_new], dim=2),
                    torch.cat([s.kv_cache.layers[l].values, v_new], dim=2),
                ))
        s.kv_cache = _make_cache(layers_kv)
        s.kv_len += ql

    return next_tokens


def visualize_ragged_step(seqs: list[Sequence], tokenizer, title: str, slot_ids: list[int]) -> None:
    """Illustrative print of ONE packed step: the concatenated input row and the
    block-diagonal causal attention mask.

    This mirrors the masking logic in _ragged_step (recomputed here as a boolean
    grid purely for display) so you can SEE that sequences are packed together
    yet isolated by the mask. Each sequence gets a letter A, B, C, ...
        # = a query may attend to that key
        . = blocked
    """
    labels = [chr(ord("A") + s.req_id) for s in seqs]
    q_lens = [len(s.pending_ids) for s in seqs]
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)

    print(f"\n{'=' * 72}\n {title}")
    print(f" total_q={total_q} tokens fed this step | total_kv={total_kv} cached")
    print(f" {len(seqs)} sequences packed into ONE unpadded row of shape [1, {total_q}]:\n")

    # The concatenated tokens, grouped per sequence (this is the "ragged" row).
    for i, s in enumerate(seqs):
        kind = f"PREFILL({q_lens[i]})" if s.kv_len == 0 else f"decode({q_lens[i]})"
        toks = " ".join(repr(tokenizer.decode([t])) for t in s.pending_ids)
        if len(toks) > 66:
            toks = toks[:63] + "..."
        print(f" {labels[i]} = slot {slot_ids[i]} {kind:<11} {toks}")

    # Rebuild the block-diagonal causal mask as a boolean grid for display.
    key_seq, key_pos = [], []
    for si, s in enumerate(seqs):  # cached keys
        key_seq += [si] * s.kv_len
        key_pos += list(range(s.kv_len))
    q_seq, q_pos = [], []
    for si, s in enumerate(seqs):  # new keys / queries
        for j in range(q_lens[si]):
            q_seq.append(si)
            q_pos.append(s.kv_len + j)
    key_seq += q_seq
    key_pos += q_pos

    q_seq_t, q_pos_t = torch.tensor(q_seq), torch.tensor(q_pos)
    key_seq_t, key_pos_t = torch.tensor(key_seq), torch.tensor(key_pos)
    allowed = (q_seq_t[:, None] == key_seq_t[None, :]) & (key_pos_t[None, :] <= q_pos_t[:, None])
    K = len(key_seq)

    def row_str(cells):
        # Space between sequence groups; ' | ' at the cached -> new-tokens split.
        out = []
        for ki in range(K):
            if total_kv > 0 and ki == total_kv:
                out.append(" | ")
            elif ki > 0 and key_seq[ki] != key_seq[ki - 1]:
                out.append(" ")
            out.append(cells[ki])
        return "".join(out)

    def line(left, cells):
        return f"{left:>7} " + row_str(cells)

    print(f"\n block-diagonal causal mask (row = query, col = key) # attend . blocked")
    if total_kv > 0:
        print(f" key layout: [ cached KV | this step's new tokens ]")
    print(line("keys:", [labels[key_seq[ki]] for ki in range(K)]))
    for qi in range(total_q):
        cells = ["#" if allowed[qi, ki] else "." for ki in range(K)]
        print(line(f"{labels[q_seq[qi]]} p{q_pos[qi]}", cells))


def continuous_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Ragged continuous batching: dynamic scheduling + packed prefill/decode.

    Scheduling policy:
      - Up to BATCH_SIZE sequences run concurrently.
      - A newly admitted sequence is queued with its full prompt as the next
        tokens to feed; its prefill then happens packed into the next step
        alongside everyone else's decode.
      - Every step runs ONE packed forward pass across all active slots.
      - When a sequence finishes it is immediately replaced by the next prompt.

    The admission log shows slots being reused (iteration-level scheduling).
    Two representative steps are visualized: the first step (all prompts being
    prefilled at once) and the first step that fuses a new prompt's prefill
    with other sequences' decode tokens.
    """
    device = model.device
    dtype = next(model.parameters()).dtype
    queue = list(enumerate(requests))  # (req_id, (prompt, max_new_tokens))
    slots: list[Optional[Sequence]] = [None] * BATCH_SIZE
    results: dict[int, str] = {}

    def _admit(slot_idx: int) -> None:
        if not queue:
            slots[slot_idx] = None
            return
        req_id, (prompt, max_new_tokens) = queue.pop(0)
        prompt_ids = tokenizer(prompt)["input_ids"]
        slots[slot_idx] = Sequence(
            req_id=req_id,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            pending_ids=list(prompt_ids),  # prefill rides the next step
        )
        print(
            f" ++ [step {step:3d}] slot {slot_idx} <- admit req {req_id} "
            f"({max_new_tokens} tok cap): {prompt!r}",
            flush=True,
        )

    # Fill the pool with the first batch of prompts (step 0 = before any decode).
    step = 0
    for i in range(BATCH_SIZE):
        _admit(i)

    printed_mixed = False
    while any(s is not None for s in slots):
        step += 1
        active = [(i, s) for i, s in enumerate(slots) if s is not None]
        seqs = [s for _, s in active]
        slot_ids = [i for i, _ in active]

        # Visualize a couple of representative steps so the packing is visible
        # (printing every step would be far too much output).
        mixed = any(s.kv_len == 0 for s in seqs) and any(s.kv_len > 0 for s in seqs)
        if step == 1:
            visualize_ragged_step(
                seqs, tokenizer, f"STEP {step} - prompts packed together (all PREFILL)", slot_ids)
        elif mixed and not printed_mixed:
            visualize_ragged_step(
                seqs, tokenizer, f"STEP {step} - PREFILL + DECODE fused in one pass", slot_ids)
            printed_mixed = True

        # ONE packed forward pass (prefill + decode fused, no padding).
        next_tokens = _ragged_step(seqs, model, device, dtype)

        for (slot_idx, seq), tok in zip(active, next_tokens):
            seq.output_ids.append(tok)
            seq.tokens_generated += 1
            seq.pending_ids = [tok]  # next step: a single decode token
            if tok == tokenizer.eos_token_id or seq.tokens_generated >= seq.max_new_tokens:
                result_text = seq.prompt + \
                    tokenizer.decode(seq.output_ids, skip_special_tokens=True)
                results[seq.req_id] = result_text
                print(
                    f" -- [step {step:3d}] slot {slot_idx} done req {seq.req_id} "
                    f"({seq.tokens_generated}/{seq.max_new_tokens} tokens): {result_text[:90]}",
                    flush=True,
                )
                _admit(slot_idx)  # refill the freed slot immediately

    return [results[k] for k in sorted(results)]


def main():
    print(f"Loading {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token

    # Pick the fastest available device. On Apple Silicon (M1/M2/...) this is
    # the MPS GPU. We keep float32 on MPS on purpose: float16 there flips a few
    # greedy ties, which would break the "static == continuous, token-for-token"
    # property this demo relies on.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    elif torch.backends.mps.is_available():
        device, dtype = "mps", torch.float32
    else:
        device, dtype = "cpu", torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=dtype,
        attn_implementation="eager",  # use our custom 4D mask directly
    )
    model.eval()
    model.to(device)
    print(f"Running on {device} ({dtype})\n")

    requests = [
        ("The capital of France is", 6),
        ("Today's weather is so", 50),
        ("In machine learning, a transformer is", 300),
        ("Once upon a time in a land far away,", 30),
        ("Quantum computing differs from classical computing because", 180),
        ("The history of the Roman Empire began", 45),
    ]

    print("=== Static batching ===")
    _device_sync(model)
    start = time.perf_counter()
    static_batching(requests, tokenizer, model)
    _device_sync(model)
    static_elapsed = time.perf_counter() - start
    print(f"\nStatic batching elapsed: {static_elapsed:.2f}s\n")

    print("=== Continuous batching (ragged) ===")
    _device_sync(model)
    start = time.perf_counter()
    continuous_batching(requests, tokenizer, model)
    _device_sync(model)
    continuous_elapsed = time.perf_counter() - start
    print(f"\nContinuous batching elapsed: {continuous_elapsed:.2f}s")


if __name__ == "__main__":
    main()
```
This is a long piece of code. Because the token sequences of different requests are concatenated, you need the Sequence data class to keep each request's essential state, along with a reference to its KV cache. The static batching baseline is the same function as before. The function continuous_batching() is the entry point for continuous batching; its internal function _admit() loads the next waiting request into a slot as soon as the request occupying that slot finishes generating.
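The slot-refill policy behind _admit() can be studied without a model at all. Below is a minimal, model-free sketch that assumes every request generates exactly its max_new_tokens (no early EOS), so a request admitted at step s finishes at step s + max_new_tokens. The functions simulate_continuous() and simulate_static() are hypothetical helpers for illustration, not part of the code above.

```python
from collections import deque


def simulate_continuous(caps, num_slots):
    """Model-free simulation of slot-refilling (continuous) scheduling.

    caps: max_new_tokens of each request, in arrival order.
    Assumes every request emits exactly `cap` tokens (no early EOS), so a
    request admitted at step s finishes at step s + cap.
    Returns (events, total_steps); each event is (step, slot_idx, req_id).
    """
    queue = deque(enumerate(caps))
    slots = [None] * num_slots  # per slot: (req_id, finish_step) or None
    events = []
    step = 0

    def admit(slot_idx):
        # Hand a free slot to the next waiting request, if there is one.
        if queue:
            req_id, cap = queue.popleft()
            slots[slot_idx] = (req_id, step + cap)
        else:
            slots[slot_idx] = None

    for i in range(num_slots):
        admit(i)
    while any(s is not None for s in slots):
        step += 1
        for i, s in enumerate(slots):
            if s is not None and s[1] <= step:
                events.append((step, i, s[0]))
                admit(i)  # refill right away: no batch barrier
    return events, step


def simulate_static(caps, batch_size):
    """Each fixed batch runs until its longest request finishes."""
    return sum(max(caps[i:i + batch_size]) for i in range(0, len(caps), batch_size))


caps = [6, 50, 300, 30, 180, 45]  # token caps of the six demo requests
events, steps = simulate_continuous(caps, num_slots=3)
for step, slot, req in events:
    print(f"step {step:3d}: slot {slot} done req {req}")
print("continuous steps:", steps)                 # 300
print("static steps:", simulate_static(caps, 3))  # 480 (= 300 + 180)
```

Under the no-early-EOS assumption, the finish steps this produces (6, 36, 50, 95, 216, 300) line up with the continuous-batching log of the demo, while the two static batches need 300 + 180 = 480 steps for the same requests.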
The function visualize_ragged_steps() only prints the status of each step; the actual prefill and decode steps are implemented in _ragged_step().
In _ragged_step(), the pending tokens of all active sequences are concatenated into input_ids, and the corresponding block-diagonal causal mask is built as attn_mask. The model is invoked with use_cache=True so that it reuses the KV cache passed in through the past_key_values parameter. Because the tokens in input_ids can belong to a different mix of requests at every step, the KV cache is rebuilt on each iteration. The latter half of _ragged_step() manages the KV cache.
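The packing half of such a step can be sketched in plain Python. The helper pack_ragged() below is hypothetical, an illustration rather than the article's _ragged_step(): it flattens each sequence's pending tokens into one unpadded row, assigns every token its position inside its own sequence (continuing after that sequence's cached tokens, as the p10 and p12 row labels in the demo output suggest), and records the row index of each sequence's last token.

```python
def pack_ragged(seqs):
    """Pack the pending tokens of several sequences into one unpadded row.

    seqs: list of (pending_ids, n_cached) pairs; n_cached is how many tokens
    of that sequence are already in the KV cache.
    Returns (input_ids, position_ids, last_idx):
      input_ids: flat token list, to be shaped [1, total_q]
      position_ids: each token's position inside its own sequence
      last_idx: row index of each sequence's final token
    """
    input_ids, position_ids, last_idx = [], [], []
    for pending_ids, n_cached in seqs:
        input_ids.extend(pending_ids)
        # Positions continue after this sequence's cached tokens,
        # independent of where the tokens land in the packed row.
        position_ids.extend(range(n_cached, n_cached + len(pending_ids)))
        last_idx.append(len(input_ids) - 1)
    return input_ids, position_ids, last_idx


# Dummy token IDs; only the lengths mirror STEP 7 of the demo output.
step7 = [
    (list(range(100, 110)), 0),  # D: PREFILL(10), nothing cached yet
    ([201], 10),                 # B: decode(1), 10 tokens cached
    ([301], 12),                 # C: decode(1), 12 tokens cached
]
ids, pos, last = pack_ragged(step7)
print(len(ids))  # 12
print(pos)       # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]
print(last)      # [9, 10, 11]
```

For a Hugging Face causal LM whose logits have shape [1, total_q, vocab_size], reading logits[0, last_idx] gives exactly one next-token prediction per sequence, which is how a single forward pass can advance a prefill and several decodes together.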
Running this code, you will see:
```
Running on mps (torch.float32)

=== Static batching ===
++ slot 0 <- req 0 (6 tok cap): 'The capital of France is'
++ slot 1 <- req 1 (50 tok cap): "Today's weather is so"
++ slot 2 <- req 2 (300 tok cap): 'In machine learning, a transformer is'
*** batch barrier: all 3 slots wait for the longest (300 tokens) ***
-- slot 0 done req 0 (6/300 tokens): The capital of France is the capital of the French Republic
-- slot 1 done req 1 (50/300 tokens): Today's weather is so cold that it's hard to see the sun. But it's not like we're going to
-- slot 2 done req 2 (300/300 tokens): In machine learning, a transformer is a type of machine learning algorithm that can be use
++ slot 0 <- req 3 (30 tok cap): 'Once upon a time in a land far away,'
++ slot 1 <- req 4 (180 tok cap): 'Quantum computing differs from classical computing because'
++ slot 2 <- req 5 (45 tok cap): 'The history of the Roman Empire began'
*** batch barrier: all 3 slots wait for the longest (180 tokens) ***
-- slot 0 done req 3 (30/180 tokens): Once upon a time in a land far away, the sun was shining, and the moon was shining. The su
-- slot 1 done req 4 (180/180 tokens): Quantum computing differs from classical computing because it is based on the notion of a
-- slot 2 done req 5 (45/180 tokens): The history of the Roman Empire began in the fourth century B.C.E. with the arrival of the

Static batching elapsed: 61.80s

=== Continuous batching (ragged) ===
++ [step 0] slot 0 <- admit req 0 (6 tok cap): 'The capital of France is'
++ [step 0] slot 1 <- admit req 1 (50 tok cap): "Today's weather is so"
++ [step 0] slot 2 <- admit req 2 (300 tok cap): 'In machine learning, a transformer is'
========================================================================
STEP 1 - prompts packed together (all PREFILL)
total_q=17 tokens fed this step | total_kv=0 cached
3 sequences packed into ONE unpadded row of shape [1, 17]:
A = slot 0 PREFILL(5) 'The' ' capital' ' of' ' France' ' is'
B = slot 1 PREFILL(5) 'Today' "'s" ' weather' ' is' ' so'
C = slot 2 PREFILL(7) 'In' ' machine' ' learning' ',' ' a' ' transformer' ' is'
block-diagonal causal mask (row = query, col = key)   # attend   . blocked
keys: AAAAA BBBBB CCCCCCC
A p0  #.... ..... .......
A p1  ##... ..... .......
A p2  ###.. ..... .......
A p3  ####. ..... .......
A p4  ##### ..... .......
B p0  ..... #.... .......
B p1  ..... ##... .......
B p2  ..... ###.. .......
B p3  ..... ####. .......
B p4  ..... ##### .......
C p0  ..... ..... #......
C p1  ..... ..... ##.....
C p2  ..... ..... ###....
C p3  ..... ..... ####...
C p4  ..... ..... #####..
C p5  ..... ..... ######.
C p6  ..... ..... #######
-- step 6] slot 0 done req 0 (6/6 tokens): The capital of France is the capital of the French Republic
++ [step 6] slot 0 <- admit req 3 (30 tok cap): 'Once upon a time in a land far away,'
========================================================================
STEP 7 - PREFILL + DECODE fused in one pass
total_q=12 tokens fed this step | total_kv=22 cached
3 sequences packed into ONE unpadded row of shape [1, 12]:
D = slot 0 PREFILL(10) 'Once' ' upon' ' a' ' time' ' in' ' a' ' land' ' far' ' away' ','
B = slot 1 decode(1) ' to'
C = slot 2 decode(1) ' algorithm'
block-diagonal causal mask (row = query, col = key)   # attend   . blocked
key layout: [ cached KV | this step's new tokens ]
keys: BBBBBBBBBB CCCCCCCCCCCC | DDDDDDDDDD B C
D p0  .......... ............ | #......... . .
D p1  .......... ............ | ##........ . .
D p2  .......... ............ | ###....... . .
D p3  .......... ............ | ####...... . .
D p4  .......... ............ | #####..... . .
D p5  .......... ............ | ######.... . .
D p6  .......... ............ | #######... . .
D p7  .......... ............ | ########.. . .
D p8  .......... ............ | #########. . .
D p9  .......... ............ | ########## . .
B p10 ########## ............ | .......... # .
C p12 .......... ############ | .......... . #
-- step 36] slot 0 done req 3 (30/30 tokens): Once upon a time in a land far away, the sun was shining, and the moon was shining. The su
++ [step 36] slot 0 <- admit req 4 (180 tok cap): 'Quantum computing differs from classical computing because'
-- step 50] slot 1 done req 1 (50/50 tokens): Today's weather is so cold that it's hard to see the sun. But it's not like we're going to
++ [step 50] slot 1 <- admit req 5 (45 tok cap): 'The history of the Roman Empire began'
-- step 95] slot 1 done req 5 (45/45 tokens): The history of the Roman Empire began in the fourth century B.C.E. with the arrival of the
-- step 216] slot 0 done req 4 (180/180 tokens): Quantum computing differs from classical computing because it is based on the notion of a
-- step 300] slot 2 done req 2 (300/300 tokens): In machine learning, a transformer is a type of machine learning algorithm that can be use

Continuous batching elapsed: 9.54s
```
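You can observe how the batch, and correspondingly the causal mask, changes from step to step. Note that the attention operation has a complexity of $O(N^2)$ in the number of input tokens $N$, so every padding token fed to the model adds cost without producing any output. In this run, continuous batching finished in 9.54s versus 61.80s for static batching: no padding tokens are fed to the model, and a slot freed by a finished request is refilled immediately instead of idling until the longest request in its batch completes.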
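The masks printed in the log follow one rule: a query token may attend to all cached keys of its own sequence, plus its own sequence's new tokens up to and including itself, and to nothing from any other sequence. Here is a minimal pure-Python sketch of that rule. The helpers ragged_mask() and render() are hypothetical, and they assume the key layout [cached KV of each sequence | new tokens of each sequence] shown in the STEP 7 printout.

```python
def ragged_mask(seqs):
    """Block-diagonal causal mask for sequences packed into one row.

    seqs: list of (n_cached, n_new) per sequence, in packing order.
    Key layout: cached keys of every sequence, then new keys of every sequence.
    Returns one list of booleans per query (new) token; True = may attend.
    """
    cached_start, new_start, offset = [], [], 0
    for n_cached, _ in seqs:
        cached_start.append(offset)
        offset += n_cached
    for _, n_new in seqs:
        new_start.append(offset)
        offset += n_new
    total_kv = offset

    rows = []
    for s, (n_cached, n_new) in enumerate(seqs):
        for j in range(n_new):
            row = [False] * total_kv
            # All cached keys of the same sequence lie in the past.
            for k in range(cached_start[s], cached_start[s] + n_cached):
                row[k] = True
            # Causal part: this step's tokens 0..j of the same sequence.
            for k in range(new_start[s], new_start[s] + j + 1):
                row[k] = True
            rows.append(row)
    return rows


def render(seqs, rows):
    """Draw each row as '#'/'.' groups, with '|' between cached and new keys."""
    cached = [c for c, _ in seqs if c > 0]  # sequences without cache have no group
    new = [n for _, n in seqs]
    lines = []
    for row in rows:
        parts, k = [], 0
        for size in cached:
            parts.append("".join("#" if b else "." for b in row[k:k + size]))
            k += size
        if cached:
            parts.append("|")
        for size in new:
            parts.append("".join("#" if b else "." for b in row[k:k + size]))
            k += size
        lines.append(" ".join(parts))
    return lines


step1 = [(0, 5), (0, 5), (0, 7)]     # A, B, C: all PREFILL
for line in render(step1, ragged_mask(step1)):
    print(line)

step7 = [(0, 10), (10, 1), (12, 1)]  # D PREFILL(10); B, C decode(1)
print(render(step7, ragged_mask(step7))[-1])  # .......... ############ | .......... . #
```

Converting this boolean matrix into the 4D attention mask tensor a model actually consumes is deliberately left out of the sketch.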
Further Readings
Below are some resources that you may find useful:
- How continuous batching enables 23x throughput in LLM inference while reducing p50 latency, by Daniel et al., Anyscale blog, 2023
- Static, dynamic and continuous batching, LLM Inference Handbook
- LLM Serving (1): Continuous batching, by Ludovico Bessi, 2025
- LLM Inference: Continuous Batching and PagedAttention, by Insu Jang, 2024
- The Existential Problems in LLM Serving, by Kukil, 2025
- LLM Serving: Why So Hard??, by Or Zipori, 2026
- Continuous Batching: Optimizing LLM Inference Throughput, by Michael Brenndoerfer, 2026
- Model execution and inference flow, vLLM documentation
- Does the continuous batching technology in the vLLM online service scenario contain the concept of batch size?, vLLM issue #2257 on GitHub, 2023
- Continuous Batching, by Reboul et al., Hugging Face blog, 2025
- PagedAttention vs Continuous Batching vs vLLM vs SGLang — A Practical Breakdown, by Varun Rao, 2025
Summary
In this article, we walked through the two problems with static batching: short requests must wait for longer ones in the same batch, and GPU cycles are wasted on padding tokens. Then, we built a working solution that addresses both with dynamic scheduling and ragged batching. Together, these ideas keep the GPU busy computing real tokens instead of padding or idle slots.