diff --git a/examples/models/eagle3/draft.py b/examples/models/eagle3/draft.py index c372730b784..0f7ba817759 100644 --- a/examples/models/eagle3/draft.py +++ b/examples/models/eagle3/draft.py @@ -45,6 +45,7 @@ class Eagle3Config: norm_before_residual: bool = True norm_before_fc: bool = False has_own_embed: bool = False + max_seq_len: int = 4096 def _rotate_half(x: torch.Tensor) -> torch.Tensor: @@ -52,6 +53,48 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) +class Eagle3KVCache(nn.Module): + """Flat KV cache for the single EAGLE-3 draft decoder layer. + + ``update`` writes the new K/V at ``input_pos`` and returns the whole buffer; + an explicit causal mask (built by the draft) selects the valid positions, so + the same path serves both prefill (T>1) and single-step draft decode (T=1). + """ + + def __init__( + self, + max_batch_size: int, + max_seq_len: int, + num_kv_heads: int, + head_dim: int, + ): + super().__init__() + shape = (max_batch_size, num_kv_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(shape), persistent=False) + self.register_buffer("v_cache", torch.zeros(shape), persistent=False) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + self.k_cache.index_copy_(2, input_pos, k_val) + self.v_cache.index_copy_(2, input_pos, v_val) + return self.k_cache, self.v_cache + + def allocate(self, dtype: torch.dtype, device) -> None: + """Re-register the cache buffers with a given dtype/device (zeroed).""" + shape = self.k_cache.shape + self.register_buffer( + "k_cache", torch.zeros(shape, dtype=dtype, device=device), persistent=False + ) + self.register_buffer( + "v_cache", torch.zeros(shape, dtype=dtype, device=device), persistent=False + ) + + def reset(self) -> None: + self.k_cache.zero_() + self.v_cache.zero_() + + class Eagle3Attention(nn.Module): """Llama GQA attention; q/k/v project from the doubled-width (2*hidden) input.""" @@ -75,7 +118,16 @@ def __init__(self, config: Eagle3Config): ) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + self.kv_cache = Eagle3KVCache( + max_batch_size=1, + max_seq_len=config.max_seq_len, + num_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + ) + + def _project_rope( + self, x: torch.Tensor, positions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) @@ -87,8 +139,30 @@ def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: sin = emb.sin().to(q.dtype) q = q * cos + _rotate_half(q) * sin k = k * cos + _rotate_half(k) * sin + return q, k, v - y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True) + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + attn_mask: torch.Tensor = None, + ) -> torch.Tensor: + """Causal attention. + + With ``attn_mask=None`` runs stateless over the full sequence + (``is_causal``). With an explicit mask, K/V are written to the KV cache + at ``positions`` and read back, so the same call serves prefill and + single-step draft decode. + """ + B, T, _ = x.shape + q, k, v = self._project_rope(x, positions) + if attn_mask is None: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True) + else: + k, v = self.kv_cache.update(positions, k, v) + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, enable_gqa=True + ) y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) return self.o_proj(y) @@ -129,12 +203,13 @@ def forward( input_embeds: torch.Tensor, feature: torch.Tensor, positions: torch.Tensor, + attn_mask: torch.Tensor = None, ) -> torch.Tensor: normed_embeds = self.input_layernorm(input_embeds) normed_feature = self.hidden_norm(feature) residual = normed_feature if self.norm_before_residual else feature x = torch.cat((normed_embeds, normed_feature), dim=-1) - x = self.self_attn(x, positions) + x = self.self_attn(x, positions, attn_mask) x = residual + x residual = x @@ -170,6 +245,15 @@ def __init__(self, config: Eagle3Config): persistent=False, ) self.register_buffer("t2d", torch.zeros(1, dtype=torch.bool), persistent=False) + # cache_positions[i] = i; used to build the causal mask over the KV cache + # without introducing dynamic-shape index tensors at runtime. + self.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long), + persistent=False, + ) + # Eager-only end of the valid contiguous cache prefix (see forward_cached). + self._cache_valid_end = 0 def fuse(self, aux: torch.Tensor) -> torch.Tensor: """Fuse concatenated target aux hidden states (B,T,3*D) -> feature (B,T,D).""" @@ -193,23 +277,112 @@ def forward( input_embeds: torch.Tensor, feature: torch.Tensor, positions: torch.Tensor, + attn_mask: torch.Tensor = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Run the midlayer over a sequence. + ``attn_mask=None`` runs stateless over the full sequence; an explicit + mask uses the incremental KV cache (see ``forward_cached``). + Returns (draft_logits, g): draft_logits: (B, T, draft_vocab_size) over the reduced vocab. g: (B, T, hidden) midlayer output — the recurrent feature. """ - g = self.midlayer(input_embeds, feature, positions) + g = self.midlayer(input_embeds, feature, positions, attn_mask) draft_logits = self.lm_head(self.norm(g)) return draft_logits, g + def _build_causal_mask(self, positions: torch.Tensor) -> torch.Tensor: + """Boolean (1, 1, T, max_seq_len) causal mask (True = attend). + + Query position p attends to cache slot j iff j <= p. This is correct + only under the contiguous-from-0 invariant of ``forward_cached``: a query + at p attends to slots 0..p, all of which must already hold this + sequence's K/V. Rejected speculative tokens sit at slots > p (the next + query's p shrinks on rollback) and are excluded by the causal bound, so + they need no extra masking. A non-contiguous seed (e.g. writing only + slot 10 after reset) would wrongly attend to the zeroed slots 0..9 — see + ``forward_cached``. + """ + q_pos = positions.unsqueeze(1) # (T, 1) + cache_pos = self.cache_positions.unsqueeze(0) # (1, max_seq_len) + return (q_pos >= cache_pos).unsqueeze(0).unsqueeze(0) + + def forward_cached( + self, + input_embeds: torch.Tensor, + feature: torch.Tensor, + positions: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """KV-cached forward for prefill (T>1) and single-step draft decode (T=1). + + Writes K/V at ``positions`` and attends over the cache. Like the target's + KV cache, attention scores against the whole ``max_seq_len`` buffer under + a static causal mask (export-friendly, no data-dependent slicing); what + this avoids versus a full recompute is re-running the prefix's + projections and MLP, not the attention score width. + + Invariant: writes must be contiguous from position 0. Seed a fresh + sequence (after ``reset_cache``) starting at position 0 and only ever + extend with the next contiguous positions; offset or gapped seeds attend + to unwritten (zeroed) slots and are not supported. Batch size must be 1. + """ + if input_embeds.shape[0] != 1: + raise ValueError("forward_cached supports batch size 1 only") + if not torch.compiler.is_compiling(): + self._validate_contiguous(positions) + return self.forward( + input_embeds, feature, positions, self._build_causal_mask(positions) + ) + + def _validate_contiguous(self, positions: torch.Tensor) -> None: + """Eager-only guard for the contiguous-from-0 cache invariant. + + Tracks the end of the valid contiguous prefix (reset by ``reset_cache``). + A write may overwrite already-written slots (speculative rollback) but + must be contiguous and ascending and must not start beyond the valid + prefix, which would leave unwritten (zeroed) slots below it in the + attention window. A rollback overwrite truncates the valid prefix to the + end of the write, so a slot above it is treated as stale and a later + write that skips it is rejected until it is rewritten. Skipped under + export/compile, where positions are traced tensors and the runner owns + the contract. + """ + start = int(positions[0]) + length = int(positions.shape[0]) + expected = torch.arange(start, start + length, device=positions.device) + if not torch.equal(positions, expected): + raise ValueError( + f"forward_cached positions must be contiguous ascending, " + f"got {positions.tolist()}" + ) + if start > self._cache_valid_end: + raise ValueError( + f"non-contiguous cache seed: positions start at {start} but only " + f"{self._cache_valid_end} slot(s) are valid; seed from 0 after " + f"reset_cache" + ) + # A write defines the valid prefix up to its end; slots above it (from an + # earlier longer write that this one rolled back) are now stale. + self._cache_valid_end = start + length + + def reset_cache(self) -> None: + self.midlayer.self_attn.kv_cache.reset() + self._cache_valid_end = 0 + + def allocate_kv_cache(self, dtype: torch.dtype, device) -> None: + """(Re)allocate the KV cache in a given dtype/device (zeroed).""" + self.midlayer.self_attn.kv_cache.allocate(dtype, device) + def draft_to_target(self, draft_ids: torch.Tensor) -> torch.Tensor: return draft_ids + self.d2t[draft_ids] @staticmethod def from_checkpoint( - model_dir: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16 + model_dir: str, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + max_seq_len: int = 4096, ) -> tuple["Eagle3Draft", Eagle3Config]: import json @@ -231,6 +404,7 @@ def from_checkpoint( aux_hidden_state_layers=cfg["eagle_aux_hidden_state_layer_ids"], norm_before_residual=cfg.get("norm_before_residual", False), norm_before_fc=cfg.get("norm_before_fc", False), + max_seq_len=max_seq_len, ) if config.norm_before_fc: # This checkpoint variant requires an input RMSNorm before fc. @@ -256,6 +430,12 @@ def from_checkpoint( model.register_buffer("d2t", state_dict.pop("d2t"), persistent=False) model.register_buffer("t2d", state_dict.pop("t2d"), persistent=False) model.load_state_dict(state_dict, strict=True, assign=True) + # Allocate the KV cache directly in the compute dtype on the target + # device *before* moving weights, so the float32 placeholder cache from + # __init__ is freed without ever being copied to the device. The + # subsequent .to(device) is a no-op for the (already-placed) cache and + # carries no dtype argument, so inv_freq stays float32. + model.allocate_kv_cache(dtype, device) model = model.to(device) assert ( model.midlayer.self_attn.inv_freq.dtype == torch.float32 diff --git a/examples/models/eagle3/test_draft.py b/examples/models/eagle3/test_draft.py index eb2587f5605..2534ca18da2 100644 --- a/examples/models/eagle3/test_draft.py +++ b/examples/models/eagle3/test_draft.py @@ -73,6 +73,119 @@ def test_norm_before_residual_changes_output(): assert not torch.allclose(outs[0], outs[1]), "norm_before_residual had no effect" +def test_kv_cache_matches_full_recompute(): + # Cached prefill + single-step decode must equal stateless full recompute. + torch.manual_seed(0) + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + T, prefill = 6, 3 + feat = model.fuse( + torch.randn(1, T, len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size) + ) + emb = model.embed(torch.randint(0, cfg.target_vocab_size, (T,))).unsqueeze(0) + + with torch.no_grad(): + ref_logits, ref_g = model(emb, feat, torch.arange(T)) + + model.reset_cache() + pl, pg = model.forward_cached( + emb[:, :prefill], feat[:, :prefill], torch.arange(prefill) + ) + torch.testing.assert_close(pl, ref_logits[:, :prefill]) + torch.testing.assert_close(pg, ref_g[:, :prefill]) + + for i in range(prefill, T): + sl, sg = model.forward_cached( + emb[:, i : i + 1], feat[:, i : i + 1], torch.arange(i, i + 1) + ) + torch.testing.assert_close(sl, ref_logits[:, i : i + 1]) + torch.testing.assert_close(sg, ref_g[:, i : i + 1]) + + +def test_reset_cache_isolates_sequences(): + # A second sequence after reset_cache must match a fresh full recompute. + torch.manual_seed(0) + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + T = 4 + feat = model.fuse( + torch.randn(1, T, len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size) + ) + emb = model.embed(torch.randint(0, cfg.target_vocab_size, (T,))).unsqueeze(0) + with torch.no_grad(): + model.forward_cached(emb, feat, torch.arange(T)) # pollute the cache + model.reset_cache() + cached, _ = model.forward_cached(emb, feat, torch.arange(T)) + ref, _ = model(emb, feat, torch.arange(T)) + torch.testing.assert_close(cached, ref) + + +def test_offset_seed_after_reset_is_rejected(): + # The contiguous-from-0 invariant is enforced in eager: an offset seed would + # attend to zeroed slots, so forward_cached rejects it outright. + torch.manual_seed(0) + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + T = 4 + feat = model.fuse( + torch.randn(1, T, len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size) + ) + emb = model.embed(torch.randint(0, cfg.target_vocab_size, (T,))).unsqueeze(0) + model.reset_cache() + with pytest.raises(ValueError, match="non-contiguous cache seed"): + model.forward_cached(emb, feat, torch.arange(10, 10 + T)) + + +def test_gapped_positions_rejected(): + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + emb = torch.randn(1, 3, cfg.hidden_size) + feat = torch.randn(1, 3, cfg.hidden_size) + model.reset_cache() + with pytest.raises(ValueError, match="contiguous ascending"): + model.forward_cached(emb, feat, torch.tensor([0, 2, 3])) + + +def test_rollback_reseed_is_allowed(): + # Overwriting already-written slots (speculative rollback) must be accepted. + torch.manual_seed(0) + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + emb = model.embed(torch.randint(0, cfg.target_vocab_size, (6,))).unsqueeze(0) + feat = model.fuse( + torch.randn(1, 6, len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size) + ) + with torch.no_grad(): + model.reset_cache() + model.forward_cached(emb, feat, torch.arange(6)) # write slots 0..5 + # re-decode at slot 4 (a rejected proposal rolled back) — allowed. + model.forward_cached(emb[:, 4:5], feat[:, 4:5], torch.arange(4, 5)) + + +def test_post_rollback_gap_is_rejected(): + # A rollback overwrite shrinks the valid prefix: after writing 0..5 then + # re-decoding slot 4, slot 5 holds stale (rejected) K/V, so a write starting + # at 6 must be rejected until slot 5 is rewritten. + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + model.reset_cache() + model._validate_contiguous(torch.arange(6)) # write slots 0..5 + model._validate_contiguous(torch.arange(4, 5)) # rollback overwrite slot 4 + with pytest.raises(ValueError, match="non-contiguous"): + model._validate_contiguous(torch.arange(6, 7)) # slot 5 stale -> rejected + model._validate_contiguous(torch.arange(5, 6)) # rewrite slot 5 -> ok + model._validate_contiguous(torch.arange(6, 7)) # now slot 6 -> ok + + +def test_forward_cached_rejects_batch_gt_1(): + cfg = tiny_config() + model = Eagle3Draft(cfg).to(torch.float32).eval() + emb = torch.randn(2, 3, cfg.hidden_size) + feat = torch.randn(2, 3, cfg.hidden_size) + with pytest.raises(ValueError, match="batch size 1"): + model.forward_cached(emb, feat, torch.arange(3)) + + def test_draft_to_target_mapping(): model = Eagle3Draft(tiny_config()).eval() model.d2t.copy_(torch.arange(model.config.draft_vocab_size)) # offset = id