Apr 3, 2026

Speculative Decoding - Making LLMs Think Faster

How speculative decoding breaks the autoregressive bottleneck by drafting cheap tokens and verifying them in parallel, achieving 2-3x speedups with mathematically identical output.


Every token an LLM generates requires a full forward pass through the entire model. A 70-billion parameter model. For one token. Then it does it again. And again. Hundreds or thousands of times, sequentially, to produce a single response.

In the KV caching post, we fixed the quadratic blowup of recomputing keys and values. That was a necessary optimization. But even with a perfectly cached model, the fundamental bottleneck remains: autoregressive generation is sequential. The model cannot start producing token t+1 until token t exists, because token t+1's probability distribution depends on the entire preceding context including t.

This post is about a technique that breaks that bottleneck without changing the model, without retraining, and without any approximation. The output is mathematically identical to standard decoding. The technique is called speculative decoding (Leviathan et al., 2023; Chen et al., 2023), and it is one of the most elegant ideas in modern LLM inference.

The Bottleneck: Why Inference is Memory-Bound

Let's revisit why generating tokens is slow. During the decode phase, at each step, the model:

  1. Takes one token embedding as input
  2. Passes it through every layer (projections, attention, FFN)
  3. Produces logits over the vocabulary
  4. Samples or selects one token
  5. Repeats

For a model like Llama 2 70B, each forward pass involves loading roughly 140 GB of model weights from GPU memory (HBM) into the compute units. But the actual arithmetic -- a matrix-vector multiply for each weight matrix -- is tiny. The compute-to-memory ratio is abysmal.

This is what it means to be memory-bound: the GPU spends most of its time waiting for data to arrive from memory, not performing calculations. The arithmetic intensity (FLOPs per byte loaded) during single-token decode is far below what the hardware can handle.

Here are the numbers for Llama 2 70B on an A100-80GB:

python
# Llama 2 70B decode step analysis
model_size_bytes = 70e9 * 2   # ~140 GB in fp16
hbm_bandwidth = 2e12          # 2 TB/s on A100
compute_flops = 312e12        # 312 TFLOPS fp16 on A100

# Time to load all weights for one token
time_memory = model_size_bytes / hbm_bandwidth   # 0.07 s = ~70 ms

# Time to compute (matrix-vector products)
time_compute = 2 * 70e9 / compute_flops          # ~0.45 ms

# The GPU is ~99.4% idle waiting for memory
utilization = time_compute / time_memory         # ~0.6%

The GPU's compute units are utilized at less than 1% during autoregressive generation. This is why a forward pass over 100 tokens takes roughly the same wall-clock time as a forward pass over a single token: the pass is dominated by weight loading, and you load the same weights regardless of how many tokens you process.
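A toy latency model makes this concrete. This is a sketch built on the A100 numbers above (70 ms to load weights, ~0.45 ms of compute per token); real kernels add overheads it ignores:

```python
# Toy model: a forward pass costs max(weight-load time, K * per-token compute).
# The 70 ms / 0.45 ms figures come from the A100 estimate above.
def pass_time_ms(K, weight_load_ms=70.0, compute_per_token_ms=0.45):
    # Weights are loaded once regardless of K; compute scales with K.
    return max(weight_load_ms, K * compute_per_token_ms)

print(pass_time_ms(1))    # 70.0 -- one token: pure weight loading
print(pass_time_ms(100))  # 70.0 -- verifying 100 tokens: same cost
print(pass_time_ms(200))  # 90.0 -- compute only dominates past ~155 tokens
```

In this model, anything under roughly 155 tokens per pass rides along for free, which is exactly the slack speculative decoding exploits.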

This is the key insight that makes speculative decoding possible:

Verification is (nearly) free. Processing K tokens in one forward pass costs almost the same as processing 1 token, because the bottleneck is loading weights, not computing with them.

The Core Idea: Draft Then Verify

Speculative decoding exploits this asymmetry with a beautifully simple two-phase approach:

Phase 1 -- Draft: A small, fast "draft" model generates K candidate tokens autoregressively. Because it's small (say, 1B parameters vs 70B), it runs 10-50x faster per token.

Phase 2 -- Verify: The large "target" model processes all K candidate tokens in a single forward pass. It computes the target probability for each token position and decides whether to accept or reject each draft.

If all K tokens are accepted, we've generated K tokens for the cost of K cheap draft steps plus one expensive target step -- instead of K expensive target steps. If some are rejected, we still make progress: we keep all tokens up to the first rejection and sample a correction from the target model.

Let's trace through a concrete example. Suppose we're generating the sentence "The quick brown fox jumps over" with K = 4:

Standard decoding (4 target forward passes):

  1. Target model: context = "The" → generates "quick" (40ms)
  2. Target model: context = "The quick" → generates "brown" (40ms)
  3. Target model: context = "The quick brown" → generates "fox" (40ms)
  4. Target model: context = "The quick brown fox" → generates "jumps" (40ms)
  5. Total: 160ms for 4 tokens

Speculative decoding (K=4 draft steps + 1 target verification):

  1. Draft model generates 4 candidates: "quick", "brown", "fox", "jumped" (4 × 2ms = 8ms)
  2. Target model verifies all 4 in one pass (40ms)
  3. First 3 accepted, "jumped" rejected → sample correction "jumps"
  4. Total: 48ms for 4 tokens (3.3x speedup!)

The draft model got 3 out of 4 right. We accepted those, rejected the wrong one, and sampled the correct token from the target model as a bonus. We got 4 tokens for the price of 48ms instead of 160ms.
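The arithmetic of this trace is easy to check directly (using the illustrative 40 ms and 2 ms latencies from the example, not measured numbers):

```python
# Illustrative latencies from the example above (not benchmarks)
target_ms, draft_ms, K = 40, 2, 4

standard_ms = K * target_ms                 # four sequential target passes
speculative_ms = K * draft_ms + target_ms   # four draft steps + one verify pass

print(standard_ms, speculative_ms)             # 160 48
print(round(standard_ms / speculative_ms, 2))  # 3.33
```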

The Math: Modified Rejection Sampling

The magic of speculative decoding is not just that it's fast -- it's that the output distribution is exactly identical to standard decoding from the target model. Not approximately. Exactly. Let's see why.

Setup

Let p(x | x_{<t}) be the target model's distribution at position t, and q(x | x_{<t}) be the draft model's distribution. For brevity, I'll write p(x) and q(x).

The draft model samples a token x ~ q. We want to decide whether to accept this sample as if it came from p.

The Acceptance Rule

For each draft token xx, compute:

\text{Accept with probability} \quad \min\left(1, \frac{p(x)}{q(x)}\right)

This means:

  • If p(x) ≥ q(x): Always accept. The draft model assigned less probability to this token than the target would, so the draft is "under-confident" here. Accepting always is fine.
  • If p(x) < q(x): Accept with probability p(x)/q(x) < 1. The draft model is "over-confident" about this token relative to the target. We randomly reject it to correct the bias.

This is a form of rejection sampling, a classic technique in statistics for sampling from a target distribution using proposals from an easier distribution.
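As code, the rule is a one-liner. A minimal sketch (the function name and the injectable `rng` parameter are illustrative, not from any library):

```python
import random

def accept_draft(p_x, q_x, rng=random.random):
    """Accept a token sampled from q with probability min(1, p(x)/q(x))."""
    # rng() draws uniformly from [0, 1); when p >= q the threshold is 1.0,
    # so the draw always falls below it and the token is always accepted.
    return rng() < min(1.0, p_x / q_x)

# Under-confident draft (p >= q): always accepted
assert accept_draft(0.5, 0.3, rng=lambda: 0.999)
# Over-confident draft (p < q): accepted only when the draw falls below p/q = 0.6
assert accept_draft(0.3, 0.5, rng=lambda: 0.59)
assert not accept_draft(0.3, 0.5, rng=lambda: 0.61)
```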

What Happens on Rejection?

When we reject a draft token at position i, we discard tokens i, i+1, …, K (all remaining drafts after the rejection point) and sample a single token from the recovery distribution:

p'(x) = \frac{\max\left(0, \; p(x) - q(x)\right)}{\sum_{x'} \max\left(0, \; p(x') - q(x')\right)}

This is the normalized "residual" distribution: it contains exactly the probability mass that the draft model under-represented. Tokens where p(x) > q(x) get positive mass; tokens where p(x) ≤ q(x) get zero mass.

The recovery distribution ensures we fill in exactly the probability that rejection sampling "missed." Together, the accept + recovery process produces samples from the exact target distribution p.

Working Through an Example

Let's say the vocabulary is {A, B, C} and the distributions at some position are:

Token   p(x) (target)   q(x) (draft)   Accept prob   max(0, p − q)
A       0.5             0.3            1.0           0.2
B       0.3             0.5            0.6           0.0
C       0.2             0.2            1.0           0.0

Suppose the draft samples B (q(B) = 0.5). The acceptance probability is min(1, 0.3/0.5) = 0.6.

  • With probability 0.6: Accept B.
  • With probability 0.4: Reject B and sample from p'.

The recovery distribution p' is:

p'(A) = \frac{0.2}{0.2} = 1.0, \quad p'(B) = 0, \quad p'(C) = 0

If we reject B, we always sample A. This makes sense: the draft over-represents B and under-represents A, so when we reject, we correct by sampling A.

Let's verify the overall probability of producing each token. For a single position, the probability of outputting token x is:

\Pr[\text{output} = x] = \underbrace{q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right)}_{\text{accept draft } x} + \underbrace{\sum_{x'} q(x') \cdot \left(1 - \min\left(1, \frac{p(x')}{q(x')}\right)\right) \cdot p'(x)}_{\text{reject draft } x', \text{ then sample } x \text{ from } p'}

For token A:

  • Accept A directly: q(A) · min(1, p(A)/q(A)) = 0.3 · 1.0 = 0.3
  • Reject B, sample A from p': q(B) · (1 − 0.6) · p'(A) = 0.5 · 0.4 · 1.0 = 0.2
  • Reject C, sample A from p': q(C) · (1 − 1.0) · p'(A) = 0
  • Total: 0.3 + 0.2 = 0.5 = p(A)

For token B:

  • Accept B directly: q(B) · min(1, p(B)/q(B)) = 0.5 · 0.6 = 0.3
  • Sample B from p': p'(B) = 0, so no contribution
  • Total: 0.3 = p(B)

For token C:

  • Accept C directly: q(C) · min(1, p(C)/q(C)) = 0.2 · 1.0 = 0.2
  • Sample C from p': p'(C) = 0, so no contribution
  • Total: 0.2 = p(C)

The output distribution matches the target exactly. This is not a coincidence -- it's a mathematical guarantee.
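The same check can be run numerically: simulate the draft-accept-recover process many times and compare the empirical frequencies against p. This is a quick Monte Carlo sanity check on the A/B/C example, not a proof:

```python
import random

p = {"A": 0.5, "B": 0.3, "C": 0.2}  # target distribution
q = {"A": 0.3, "B": 0.5, "C": 0.2}  # draft distribution

def recovery(p, q):
    # Normalized residual: p'(x) = max(0, p(x) - q(x)) / Z
    residual = {x: max(0.0, p[x] - q[x]) for x in p}
    Z = sum(residual.values())
    return {x: r / Z for x, r in residual.items()}

def one_token(rng):
    x = rng.choices(list(q), weights=list(q.values()))[0]  # sample from draft
    if rng.random() < min(1.0, p[x] / q[x]):               # accept?
        return x
    r = recovery(p, q)                                     # else: recover
    return rng.choices(list(r), weights=list(r.values()))[0]

rng = random.Random(0)
N = 200_000
counts = {x: 0 for x in p}
for _ in range(N):
    counts[one_token(rng)] += 1

for x in p:
    print(x, round(counts[x] / N, 3))  # within ~0.01 of 0.5, 0.3, 0.2
```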

Why It's Lossless: The Proof

Let's prove this in general. For any token x in the vocabulary:

\Pr[\text{output} = x] = q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right) + p'(x) \cdot \sum_{x'} q(x') \cdot \max\left(0, 1 - \frac{p(x')}{q(x')}\right)

Case 1: p(x) ≤ q(x) (draft is over-confident)

The first term gives q(x) · p(x)/q(x) = p(x).

Since p(x) ≤ q(x), we have p'(x) = max(0, p(x) − q(x)) / Z = 0.

So Pr[output = x] = p(x) + 0 = p(x). ✓

Case 2: p(x) > q(x) (draft is under-confident)

The first term gives q(x) · 1 = q(x).

For the second term, we need p'(x) = (p(x) − q(x)) / Z, where Z = Σ_{x'} max(0, p(x') − q(x')).

The total rejection probability is:

\sum_{x'} q(x') \cdot \max\left(0, 1 - \frac{p(x')}{q(x')}\right) = \sum_{x' : q(x') > p(x')} \left(q(x') - p(x')\right)

Now, since Σ_{x'} p(x') = Σ_{x'} q(x') = 1:

\sum_{x' : q(x') > p(x')} \left(q(x') - p(x')\right) = \sum_{x' : p(x') > q(x')} \left(p(x') - q(x')\right) = Z

(The total surplus where q>pq > p must equal the total deficit where p>qp > q, because both distributions sum to 1.)

So the second term becomes ((p(x) − q(x)) / Z) · Z = p(x) − q(x).

Total: q(x) + (p(x) − q(x)) = p(x). ✓

Both cases yield p(x)p(x). The output distribution is exactly the target distribution, regardless of how good or bad the draft model is. A bad draft model just means more rejections (slower speed), not different output.
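The identity can also be checked exactly in code, by enumerating every branch of the process instead of sampling. A small sketch over list-valued distributions (the function name is illustrative):

```python
def output_distribution(p, q):
    """Exact Pr[output = x] under draft-accept-recover, for distributions
    given as lists over the same vocabulary."""
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    Z = sum(residual)
    out = []
    for x in range(len(p)):
        # Branch 1: draft proposes x and it is accepted
        prob = q[x] * min(1.0, p[x] / q[x]) if q[x] > 0 else 0.0
        # Branch 2: draft proposes some x', it is rejected, recovery yields x
        total_reject = sum(
            q[xp] * (1.0 - min(1.0, p[xp] / q[xp]))
            for xp in range(len(p)) if q[xp] > 0
        )
        prob += total_reject * (residual[x] / Z if Z > 0 else 0.0)
        out.append(prob)
    return out

p = [0.5, 0.3, 0.2]
q = [0.3, 0.5, 0.2]
print(output_distribution(p, q))  # recovers p exactly, up to float rounding
```

Swapping in any other pair of distributions over the same vocabulary gives the same result: the output always equals p, regardless of q.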

Expected Speedup: How Many Tokens Per Step?

The expected number of accepted tokens depends on the acceptance rate. Define α_i as the acceptance probability of the i-th draft token (which depends on how aligned p and q are at that position).

For simplicity, assume a constant acceptance rate α across positions. The probability that the first i tokens are all accepted is α^i. The expected number of accepted tokens per speculation step is:

\mathbb{E}[\text{accepted tokens}] = \sum_{i=1}^{K} \alpha^i = \frac{\alpha\left(1 - \alpha^K\right)}{1 - \alpha}

Plus the bonus token (either the (K+1)-th token from the target when all drafts are accepted, or the recovery token when one is rejected):

\mathbb{E}[\text{total tokens per step}] = \frac{\alpha\left(1 - \alpha^K\right)}{1 - \alpha} + 1

Let's compute this for various acceptance rates with K = 5:

python
def expected_tokens(alpha, K):
    """Expected tokens per speculative decoding step."""
    if alpha == 1.0:
        return K + 1
    accepted = alpha * (1 - alpha**K) / (1 - alpha)
    return accepted + 1  # +1 for bonus/recovery token

# K = 5 draft tokens
for alpha in [0.5, 0.7, 0.8, 0.9, 0.95]:
    tokens = expected_tokens(alpha, K=5)
    print(f"  alpha={alpha:.2f}: {tokens:.2f} tokens/step")

# Output:
#   alpha=0.50: 1.97 tokens/step
#   alpha=0.70: 2.94 tokens/step
#   alpha=0.80: 3.69 tokens/step
#   alpha=0.90: 4.69 tokens/step
#   alpha=0.95: 5.30 tokens/step

At an 80% acceptance rate, we get about 3.7 tokens per verification step. If the draft model is 20x faster than the target, the speedup is significant.

The overall speedup formula:

\text{Speedup} = \frac{\mathbb{E}[\text{tokens per step}] \cdot T_{\text{target}}}{K \cdot T_{\text{draft}} + T_{\text{target}}}

where T_target is the target model's forward-pass latency and T_draft is the draft model's per-token latency.
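Plugging in the running pairing (α = 0.8, K = 5, a 40 ms target pass, a 4 ms draft step; illustrative numbers, not benchmarks):

```python
alpha, K = 0.8, 5
T_target, T_draft = 40.0, 4.0  # ms, illustrative

# E[tokens per step] = alpha * (1 - alpha^K) / (1 - alpha) + 1
tokens_per_step = alpha * (1 - alpha**K) / (1 - alpha) + 1
speedup = tokens_per_step * T_target / (K * T_draft + T_target)

print(round(tokens_per_step, 2))  # 3.69 tokens per step
print(round(speedup, 2))          # 2.46x
```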

Choosing the Draft Model

The choice of draft model is critical. You need a model that is simultaneously:

  1. Fast: The whole point is that K draft steps are cheaper than K target steps
  2. Accurate: Higher acceptance rate means more tokens per verification step
  3. Compatible: Must use the same tokenizer and vocabulary as the target model

These requirements create a fundamental tension. A more capable draft model has a higher acceptance rate but runs slower. A tiny draft model is lightning fast but might get rejected constantly.

Same Tokenizer Requirement

This is a hard constraint. If the draft model uses a different tokenizer, the tokens don't correspond, and verification is impossible. In practice, this usually means the draft model must be from the same model family.

Common pairings:

Target Model    Draft Model    Speed Ratio   Typical α
Llama 2 70B     Llama 2 7B     ~10x          0.7-0.85
GPT-4           GPT-3.5        ~5-8x         0.6-0.8
Codex 175B      Codex 12B      ~12x          0.8-0.9 (code)
Gemma 2 27B     Gemma 2 2B     ~8x           0.65-0.8

Note that code generation tends to have higher acceptance rates -- code is more predictable than natural language, with boilerplate, common patterns, and strict syntax.

The Speed-Accuracy Tradeoff

python
def optimal_K(alpha, draft_ms, target_ms, max_K=20):
    """Find the draft length K that maximizes speedup."""
    best_K, best_speedup = 1, 0
    for K in range(1, max_K + 1):
        tokens = expected_tokens(alpha, K)
        time = K * draft_ms + target_ms
        speedup = (tokens * target_ms) / time
        if speedup > best_speedup:
            best_K, best_speedup = K, speedup
    return best_K, best_speedup

# Llama 2 70B (40ms) + Llama 2 7B (4ms), alpha=0.8
K, speedup = optimal_K(0.8, draft_ms=4, target_ms=40)
print(f"Optimal K={K}, speedup={speedup:.2f}x")
# Optimal K=6, speedup=2.47x

# With a worse draft model (alpha=0.5)
K, speedup = optimal_K(0.5, draft_ms=4, target_ms=40)
print(f"Optimal K={K}, speedup={speedup:.2f}x")
# Optimal K=2, speedup=1.46x

Key insight: with a bad draft model, the optimal K is small. You should only speculate a few tokens ahead because most will be rejected anyway. With a strong draft model, you can speculate further and reap bigger rewards.

Beyond Small Models: Alternative Drafters

The draft model doesn't have to be a smaller neural network. Several alternatives exist:

N-gram models: A simple lookup table that predicts the next token based on the previous n tokens. Essentially free to evaluate. Works surprisingly well for repetitive text, code, and common phrases. The acceptance rate is lower, but the near-zero cost compensates.

python
from collections import defaultdict, Counter

class NGramDrafter:
    def __init__(self, n=4):
        self.n = n
        self.counts = defaultdict(Counter)

    def train(self, token_ids):
        for i in range(len(token_ids) - self.n):
            context = tuple(token_ids[i:i+self.n])
            next_token = token_ids[i+self.n]
            self.counts[context][next_token] += 1

    def predict(self, context):
        key = tuple(context[-self.n:])
        if key in self.counts:
            total = sum(self.counts[key].values())
            return {t: c/total for t, c in self.counts[key].items()}
        return None  # Fall back to target model
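A self-contained miniature of the same counting scheme shows why this works on repetitive input: after the table has seen a pattern a few times, it predicts the continuation with probability 1 (toy token ids, illustrative only):

```python
from collections import defaultdict, Counter

n = 2
counts = defaultdict(Counter)
token_ids = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]  # a repeating pattern

# Count (context -> next token) occurrences, as the drafter above does
for i in range(len(token_ids) - n):
    counts[tuple(token_ids[i:i+n])][token_ids[i+n]] += 1

# The context (1, 2) has only ever been followed by 3
dist = counts[(1, 2)]
total = sum(dist.values())
print({t: c / total for t, c in dist.items()})  # {3: 1.0}
```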

Retrieval-based drafters: Look up similar contexts in a database and predict the most likely continuation. This is the idea behind REST (He et al., 2023).

Quantized versions of the target: Aggressively quantize the target model (e.g., from fp16 to int4) and use the quantized version as the drafter. Same architecture, same tokenizer, much faster, but somewhat less accurate.

Medusa: No Separate Draft Model Needed

What if we could eliminate the draft model entirely? Medusa (Cai et al., 2024) does exactly this by adding multiple lightweight "prediction heads" directly to the target model.

The Idea

Standard LLMs have a single prediction head: the final linear layer that maps the last hidden state to vocabulary logits. This head predicts the next token.

Medusa adds K additional heads, each trained to predict tokens further into the future:

  • Head 1: predicts token t+1 (same as the original head)
  • Head 2: predicts token t+2 given the hidden state at position t
  • Head 3: predicts token t+3 given the hidden state at position t
  • ...and so on

Each head is a small MLP (typically 1-2 layers) that operates on the same hidden state. The overhead of running K extra heads is minimal compared to the rest of the model.

python
import torch.nn as nn

class MedusaHead(nn.Module):
    """One Medusa prediction head."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.linear2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size)
        x = self.act(self.linear1(hidden_states))
        return self.linear2(x)  # (batch, seq_len, vocab_size)

class MedusaModel(nn.Module):
    """Target model with Medusa heads."""
    def __init__(self, base_model, num_heads, hidden_size, vocab_size):
        super().__init__()
        self.base = base_model
        self.medusa_heads = nn.ModuleList([
            MedusaHead(hidden_size, vocab_size) for _ in range(num_heads)
        ])

    def forward(self, input_ids, **kwargs):
        # Run base model
        outputs = self.base(input_ids, **kwargs)
        hidden = outputs.last_hidden_state
        # Run Medusa heads
        medusa_logits = [head(hidden) for head in self.medusa_heads]
        return outputs.logits, medusa_logits

Tree Attention: Exploring Multiple Candidates

Here's where it gets clever. Each Medusa head provides a probability distribution over the vocabulary. Rather than greedily picking one token per head, Medusa takes the top-k candidates from each head and forms a tree of possible continuations.

For example, with 3 heads and top-2 candidates each:

  • Head 1 suggests: {"quick": 0.45, "lazy": 0.30}
  • Head 2 suggests: {"brown": 0.60, "red": 0.25}
  • Head 3 suggests: {"fox": 0.72, "dog": 0.18}

This creates a tree with up to 2 × 2 × 2 = 8 candidate paths. Using a specially constructed tree attention mask, the target model can verify all paths simultaneously in a single forward pass.

The tree attention mask is the key innovation. In standard causal attention, token i can attend to tokens 1, 2, …, i. In tree attention, each node in the tree can attend to its ancestors (its path from root to self) but not to nodes on other branches. This prevents information leaking between candidate paths.

python
import torch

def build_tree_attention_mask(tree_candidates):
    """
    Build the attention mask for tree verification.

    Each candidate token attends to:
      1. All prefix tokens (the already-generated context)
      2. Its ancestors in the tree (root -> parent -> self)
    But NOT to:
      - Siblings (other branches at the same depth)
      - Tokens in other subtrees
    """
    n_candidates = len(tree_candidates)
    # This mask covers only the candidate block; the prefix tokens are
    # handled by the usual causal mask (all candidates attend to the prefix).
    mask = torch.zeros(n_candidates, n_candidates, dtype=torch.bool)
    for i, candidate in enumerate(tree_candidates):
        # Each candidate can attend to its ancestors...
        for ancestor_idx in candidate.ancestor_indices:
            mask[i, ancestor_idx] = True
        # ...and to itself
        mask[i, i] = True
    return mask

The beauty of tree attention is that it transforms a sequential search (try path 1, then path 2, then path 3...) into a parallel verification. The model loads its weights once, processes all candidate paths, and determines which tokens to accept -- all in one pass.

Medusa Training

The Medusa heads are trained on the same data as the base model, but with an offset in the target labels:

python
import torch.nn.functional as F

def medusa_training_loss(model, input_ids, labels):
    """
    Train Medusa heads to predict future tokens.
    Head k (0-indexed) predicts the token at position t+k+2
    given the hidden state at position t.
    """
    base_logits, medusa_logits = model(input_ids)
    vocab_size = base_logits.size(-1)

    # Standard next-token loss for the base head
    loss = F.cross_entropy(
        base_logits[:, :-1].reshape(-1, vocab_size),
        labels[:, 1:].reshape(-1)
    )

    # Medusa head losses: head k predicts k+2 steps ahead
    for k, head_logits in enumerate(medusa_logits):
        offset = k + 2  # Head 0 predicts t+2, head 1 predicts t+3, etc.
        if offset < labels.size(1):
            head_loss = F.cross_entropy(
                head_logits[:, :-offset].reshape(-1, vocab_size),
                labels[:, offset:].reshape(-1)
            )
            loss += head_loss
    return loss

Crucially, the base model weights are frozen during Medusa training. Only the lightweight heads are trained, typically requiring a few hours on a single GPU. This makes Medusa much cheaper to adopt than training a separate draft model from scratch.

Self-Speculative Decoding: The Model Drafts for Itself

Medusa requires training additional heads. Self-speculative decoding (Zhang et al., 2023) takes a different approach: use the target model itself as the drafter by performing early exit.

The idea is simple. A 70B model has, say, 80 transformer layers. The representations at layer 20 already encode a lot of information about the next token. If we attach a lightweight prediction head to layer 20, we get a "cheap" draft model that:

  • Uses the same tokenizer (trivially)
  • Shares most of the computation with the target model
  • Requires no separate model in memory
python
import torch.nn as nn

class SelfSpeculativeDecoder:
    def __init__(self, model, exit_layer=20, num_layers=80):
        self.model = model
        self.exit_layer = exit_layer
        # Small head trained on intermediate representations
        self.draft_head = nn.Linear(model.config.hidden_size,
                                    model.config.vocab_size)

    def draft_step(self, input_ids, kv_cache):
        """Run only the first `exit_layer` layers for a cheap prediction."""
        hidden = self.model.embed(input_ids)
        for i in range(self.exit_layer):
            hidden, kv_cache[i] = self.model.layers[i](hidden, kv_cache[i])
        logits = self.draft_head(hidden[:, -1:])
        return logits, kv_cache

    def verify_step(self, input_ids, kv_cache):
        """Run the full model on all candidates."""
        return self.model(input_ids, kv_cache=kv_cache)

The tradeoff is that the draft quality depends on which layer we exit at. Exit too early and the acceptance rate plummets. Exit too late and we save little computation. Research suggests that exiting around 25-40% of the way through the model provides a good balance.

Practical Speedups: When Does It Actually Help?

Speculative decoding sounds great in theory. In practice, the speedup depends heavily on the workload.

Best Case: Predictable Text

When the text is highly predictable -- code, boilerplate, common phrases, formatted data -- the draft model matches the target model well, acceptance rates are high, and speedups of 2-3x are common.

python
# Code completion: highly predictable
prompt = "def fibonacci(n):\n    if n <= 1:\n        return"
# Draft model easily predicts: " n\n    return fibonacci(n-1) + fibonacci(n-2)"
# Acceptance rate: ~85-95%
# Speedup: 2.5-3.5x

# JSON/structured output: very predictable
prompt = '{"name": "Alice", "age": 30, "address": {'
# Draft model predicts keys, braces, quotes easily
# Acceptance rate: ~90-95%
# Speedup: 3-4x

Worst Case: Creative/Diverse Text

When the text is creative, unusual, or requires reasoning, the draft model's predictions diverge from the target, acceptance rates drop, and the overhead of running the draft model eats into the speedup.

python
# Creative writing: less predictable
prompt = "Write a surreal poem about quantum mechanics:"
# Draft model often disagrees with target on word choice
# Acceptance rate: ~40-60%
# Speedup: 1.2-1.5x (still positive, but modest)

# Reasoning/math: hard to predict
prompt = "Prove that there are infinitely many primes:"
# Each logical step is hard to predict
# Acceptance rate: ~30-50%
# Speedup: 1.0-1.3x (barely worth it)

Batch Size Matters

This is the most important practical consideration. Speculative decoding shines at batch size 1 (single user, single request), where decode is heavily memory-bound. As batch size increases:

  1. The target model's forward pass becomes more compute-bound (processing many sequences)
  2. The cost of verification increases proportionally to the number of candidate tokens
  3. The draft model's overhead is no longer negligible
python
# Batch size impact on speculative decoding speedup
# (approximate, varies by hardware and model)
batch_sizes = [1, 2, 4, 8, 16, 32]
speedups    = [2.8, 2.4, 2.0, 1.5, 1.2, 1.0]

# At batch size 32, speculative decoding barely helps
# because the GPU is already well-utilized

In production serving systems with continuous batching (like vLLM), the effective batch size is often large enough that speculative decoding provides diminished returns. The technique is most impactful for:

  • Interactive single-user applications (chatbots, coding assistants)
  • Latency-sensitive applications where throughput matters less
  • Edge/mobile deployment where batch size is always 1

Temperature and Sampling

Speculative decoding works with any sampling strategy -- greedy, temperature sampling, top-k, top-p (nucleus). The key is that both the draft and target models must use the same sampling parameters.

At temperature 0 (greedy decoding), the acceptance rule simplifies: a draft token is accepted if and only if it matches the target's argmax. There's no probabilistic acceptance -- it's binary.
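A tiny sketch of that greedy case with hypothetical toy distributions (collapsing each distribution to one-hot mimics temperature → 0):

```python
def greedy_accept_prob(p, q):
    """At temperature 0, p and q collapse to one-hot vectors, so
    min(1, p(x)/q(x)) for the drafted token x = argmax(q) is 1 when
    the argmaxes match and 0 when they differ."""
    draft_token = max(range(len(q)), key=q.__getitem__)   # what the draft emits
    target_token = max(range(len(p)), key=p.__getitem__)  # the target's argmax
    p_x = 1.0 if draft_token == target_token else 0.0     # one-hot target prob
    q_x = 1.0                                             # one-hot draft prob
    return min(1.0, p_x / q_x)

print(greedy_accept_prob([0.5, 0.3, 0.2], [0.6, 0.3, 0.1]))  # 1.0: argmaxes match
print(greedy_accept_prob([0.5, 0.3, 0.2], [0.1, 0.8, 0.1]))  # 0.0: they differ
```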

At higher temperatures, both distributions become flatter, which actually helps alignment. The acceptance rate tends to increase with temperature because both models assign more uniform probabilities, making their distributions more similar.

Full Implementation

Here's a complete PyTorch implementation of speculative decoding:

python
import torch
import torch.nn.functional as F

@torch.no_grad()
def speculative_decode(
    target_model,
    draft_model,
    input_ids,            # (1, seq_len) - the prompt
    max_new_tokens=100,
    K=5,                  # speculation length
    temperature=1.0,
):
    """
    Speculative decoding with a separate draft model.
    Returns tokens sampled from the exact target distribution.
    """
    generated = input_ids.clone()

    # Initialize KV caches
    target_cache = None
    draft_cache = None
    tokens_generated = 0

    while tokens_generated < max_new_tokens:
        # ---- Phase 1: Draft ----
        # Generate K candidate tokens autoregressively with the draft model
        draft_tokens = []
        draft_probs = []
        draft_input = generated
        for _ in range(K):
            draft_logits, draft_cache = draft_model(
                draft_input[:, -1:] if draft_cache else draft_input,
                kv_cache=draft_cache
            )
            # Get distribution from draft model
            draft_dist = F.softmax(draft_logits[:, -1] / temperature, dim=-1)
            # Sample from draft
            token = torch.multinomial(draft_dist, 1)
            draft_tokens.append(token)
            draft_probs.append(draft_dist)
            draft_input = torch.cat([draft_input, token], dim=-1)

        draft_tokens_tensor = torch.cat(draft_tokens, dim=-1)  # (1, K)

        # ---- Phase 2: Verify ----
        # Run target model on all K candidates in ONE forward pass
        verify_input = torch.cat([generated, draft_tokens_tensor], dim=-1)
        target_logits, target_cache = target_model(
            verify_input[:, -(K+1):] if target_cache else verify_input,
            kv_cache=target_cache
        )
        # Get target distributions for each position
        target_dists = F.softmax(
            target_logits[:, -(K+1):] / temperature, dim=-1
        )

        # ---- Phase 3: Accept/Reject ----
        n_accepted = 0
        for i in range(K):
            token_id = draft_tokens_tensor[0, i].item()
            p_i = target_dists[0, i, token_id].item()
            q_i = draft_probs[i][0, token_id].item()

            # Acceptance probability
            if q_i == 0:
                accept = True  # draft assigned 0 prob but we somehow sampled it
            else:
                accept_prob = min(1.0, p_i / q_i)
                accept = torch.rand(1).item() < accept_prob

            if accept:
                n_accepted += 1
            else:
                # Reject: sample from recovery distribution
                # p'(x) = max(0, p(x) - q(x)) / Z
                p_dist = target_dists[0, i]
                q_dist = draft_probs[i][0]
                residual = torch.clamp(p_dist - q_dist, min=0)
                residual_sum = residual.sum()
                if residual_sum > 0:
                    recovery_dist = residual / residual_sum
                else:
                    recovery_dist = p_dist  # fallback
                recovery_token = torch.multinomial(recovery_dist.unsqueeze(0), 1)

                # Accept tokens up to i, plus the recovery token
                accepted = draft_tokens_tensor[:, :i]
                generated = torch.cat(
                    [generated, accepted, recovery_token], dim=-1
                )
                tokens_generated += i + 1
                # Roll back draft cache to position of rejection
                draft_cache = None   # Simplification: reset draft cache
                # Roll back target cache
                target_cache = None  # In practice, truncate to correct length
                break
        else:
            # All K tokens accepted! Sample bonus token from target's
            # distribution at position K+1
            bonus_dist = target_dists[0, -1]
            bonus_token = torch.multinomial(bonus_dist.unsqueeze(0), 1)
            generated = torch.cat(
                [generated, draft_tokens_tensor, bonus_token], dim=-1
            )
            tokens_generated += K + 1
            # Reset draft cache for next round
            draft_cache = None

    return generated

KV Cache Management

The trickiest part of implementation is managing the KV caches for both models. When tokens are rejected:

  1. Target model cache: Must be truncated to remove entries for rejected tokens. In the simplified code above, we reset it entirely, but in practice you'd slice the cache tensors.

  2. Draft model cache: Must be completely reset or rebuilt, because the draft model's internal state diverged from the accepted sequence at the point of rejection.

python
def truncate_kv_cache(kv_cache, keep_length):
    """Remove cache entries beyond keep_length."""
    if kv_cache is None:
        return None
    truncated = []
    for layer_cache in kv_cache:
        k, v = layer_cache
        truncated.append((
            k[:, :, :keep_length],  # (batch, heads, seq, dim)
            v[:, :, :keep_length]
        ))
    return truncated

Efficient cache management is what makes the difference between a textbook implementation and a production-ready one. Systems like vLLM and TensorRT-LLM have sophisticated cache management specifically optimized for speculative decoding.

Advanced Variants

The field has evolved rapidly since the original speculative decoding papers. Here is a survey of the major variants.

Staged Speculative Decoding

Instead of one draft model, use a cascade: an extremely fast "level-0" drafter (n-gram model), a medium-speed "level-1" drafter (small neural network), and the target model for verification. Each stage filters candidates, reducing the work the target model needs to do.

Speculative Decoding with Tree Drafts

Instead of generating a single chain of K tokens, the draft model generates a tree of candidates (similar to Medusa but using a separate draft model). The target model verifies all paths using tree attention. This increases the probability that at least one long path is fully accepted.

Online Speculative Decoding (Liu et al., 2024)

The draft model is continuously fine-tuned during serving to better match the target model's behavior on the current distribution of queries. As the draft model improves, acceptance rates increase over time.

SpecInfer (Miao et al., 2024)

Combines multiple draft models (of different architectures) into an ensemble, boosting the diversity of candidate tokens. The target model's single verification pass checks candidates from all draft models.

Lookahead Decoding (Fu et al., 2024)

Eliminates the draft model entirely by exploiting Jacobi iteration. The idea: guess all K tokens simultaneously, run a forward pass, and check which positions converged. Repeat until all positions agree. Provably generates from the correct distribution.

Summary

Speculative decoding is a rare case of getting something for nothing. The output is mathematically identical to standard decoding -- same distribution, same quality, same everything -- but faster.

The key ideas:

  1. Single-token decode is memory-bound. The GPU loads all model weights for tiny matrix-vector multiplies. Most compute capacity is wasted.

  2. Verification is almost free. Processing K tokens costs roughly the same as processing 1 token, because the bottleneck is weight loading, not arithmetic.

  3. Draft-then-verify exploits the asymmetry. A cheap draft model proposes candidates; the expensive target model verifies in parallel.

  4. Modified rejection sampling guarantees losslessness. The accept/reject/recovery mechanism produces samples from the exact target distribution.

  5. The speedup scales with prediction quality. Predictable text (code, boilerplate) sees 2-3x speedup. Creative text sees 1.2-1.5x. Batch size 1 benefits most.

The technique has rapidly moved from research to production. As of early 2026, speculative decoding (or its variants like Medusa) is deployed in virtually every major LLM serving system: vLLM, TensorRT-LLM, HuggingFace TGI, and proprietary systems at OpenAI, Google, and Anthropic.

The autoregressive bottleneck may be fundamental to how language models work, but speculative decoding proves that the cost of that bottleneck is not. We can decode sequentially while still keeping our GPUs busy.