Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 185 additions & 5 deletions examples/models/eagle3/draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,56 @@ class Eagle3Config:
norm_before_residual: bool = True
norm_before_fc: bool = False
has_own_embed: bool = False
max_seq_len: int = 4096


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)


class Eagle3KVCache(nn.Module):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What distinguishes this from our standard KVCache? Can we just import the standard one?

"""Flat KV cache for the single EAGLE-3 draft decoder layer.

``update`` writes the new K/V at ``input_pos`` and returns the whole buffer;
an explicit causal mask (built by the draft) selects the valid positions, so
the same path serves both prefill (T>1) and single-step draft decode (T=1).
"""

def __init__(
self,
max_batch_size: int,
max_seq_len: int,
num_kv_heads: int,
head_dim: int,
):
super().__init__()
shape = (max_batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer("k_cache", torch.zeros(shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(shape), persistent=False)

def update(
    self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Write new K/V into the cache and return the full buffers.

    ``k_val`` and ``v_val`` are copied in place into ``k_cache`` and
    ``v_cache`` along dimension 2 at the indices given by ``input_pos``.

    Returns:
        The ``(k_cache, v_cache)`` buffers themselves (not copies).
    """
    for buffer, fresh in ((self.k_cache, k_val), (self.v_cache, v_val)):
        buffer.index_copy_(2, input_pos, fresh)
    return self.k_cache, self.v_cache

def allocate(self, dtype: torch.dtype, device) -> None:
    """Replace both cache buffers with zeroed tensors of ``dtype`` on ``device``.

    The shape of the current ``k_cache`` is reused for both buffers, and the
    replacements stay non-persistent.
    """
    cache_shape = self.k_cache.shape
    for buffer_name in ("k_cache", "v_cache"):
        fresh = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.register_buffer(buffer_name, fresh, persistent=False)

def reset(self) -> None:
    """Zero both cache buffers in place."""
    for buffer in (self.k_cache, self.v_cache):
        buffer.zero_()


class Eagle3Attention(nn.Module):
"""Llama GQA attention; q/k/v project from the doubled-width (2*hidden) input."""

Expand All @@ -75,7 +118,16 @@ def __init__(self, config: Eagle3Config):
)
self.register_buffer("inv_freq", inv_freq, persistent=False)

def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
self.kv_cache = Eagle3KVCache(
max_batch_size=1,
max_seq_len=config.max_seq_len,
num_kv_heads=self.n_kv_heads,
head_dim=self.head_dim,
)

def _project_rope(
self, x: torch.Tensor, positions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, _ = x.shape
q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
Expand All @@ -87,8 +139,30 @@ def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
sin = emb.sin().to(q.dtype)
q = q * cos + _rotate_half(q) * sin
k = k * cos + _rotate_half(k) * sin
return q, k, v

y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_mask: torch.Tensor = None,
) -> torch.Tensor:
"""Causal attention.

With ``attn_mask=None`` runs stateless over the full sequence
(``is_causal``). With an explicit mask, K/V are written to the KV cache
at ``positions`` and read back, so the same call serves prefill and
single-step draft decode.
"""
B, T, _ = x.shape
q, k, v = self._project_rope(x, positions)
if attn_mask is None:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
else:
k, v = self.kv_cache.update(positions, k, v)
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, enable_gqa=True
)
y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
return self.o_proj(y)

Expand Down Expand Up @@ -129,12 +203,13 @@ def forward(
input_embeds: torch.Tensor,
feature: torch.Tensor,
positions: torch.Tensor,
attn_mask: torch.Tensor = None,
) -> torch.Tensor:
normed_embeds = self.input_layernorm(input_embeds)
normed_feature = self.hidden_norm(feature)
residual = normed_feature if self.norm_before_residual else feature
x = torch.cat((normed_embeds, normed_feature), dim=-1)
x = self.self_attn(x, positions)
x = self.self_attn(x, positions, attn_mask)
x = residual + x

residual = x
Expand Down Expand Up @@ -170,6 +245,15 @@ def __init__(self, config: Eagle3Config):
persistent=False,
)
self.register_buffer("t2d", torch.zeros(1, dtype=torch.bool), persistent=False)
# cache_positions[i] = i; used to build the causal mask over the KV cache
# without introducing dynamic-shape index tensors at runtime.
self.register_buffer(
"cache_positions",
torch.arange(config.max_seq_len, dtype=torch.long),
persistent=False,
)
# Eager-only end of the valid contiguous cache prefix (see forward_cached).
self._cache_valid_end = 0

def fuse(self, aux: torch.Tensor) -> torch.Tensor:
"""Fuse concatenated target aux hidden states (B,T,3*D) -> feature (B,T,D)."""
Expand All @@ -193,23 +277,112 @@ def forward(
input_embeds: torch.Tensor,
feature: torch.Tensor,
positions: torch.Tensor,
attn_mask: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run the midlayer over a sequence.

``attn_mask=None`` runs stateless over the full sequence; an explicit
mask uses the incremental KV cache (see ``forward_cached``).

Returns (draft_logits, g):
draft_logits: (B, T, draft_vocab_size) over the reduced vocab.
g: (B, T, hidden) midlayer output — the recurrent feature.
"""
g = self.midlayer(input_embeds, feature, positions)
g = self.midlayer(input_embeds, feature, positions, attn_mask)
draft_logits = self.lm_head(self.norm(g))
return draft_logits, g

def _build_causal_mask(self, positions: torch.Tensor) -> torch.Tensor:
"""Boolean (1, 1, T, max_seq_len) causal mask (True = attend).

Query position p attends to cache slot j iff j <= p. This is correct
only under the contiguous-from-0 invariant of ``forward_cached``: a query
at p attends to slots 0..p, all of which must already hold this
sequence's K/V. Rejected speculative tokens sit at slots > p (the next
query's p shrinks on rollback) and are excluded by the causal bound, so
they need no extra masking. A non-contiguous seed (e.g. writing only
slot 10 after reset) would wrongly attend to the zeroed slots 0..9 — see
``forward_cached``.
"""
q_pos = positions.unsqueeze(1) # (T, 1)
cache_pos = self.cache_positions.unsqueeze(0) # (1, max_seq_len)
return (q_pos >= cache_pos).unsqueeze(0).unsqueeze(0)

def forward_cached(
self,
input_embeds: torch.Tensor,
feature: torch.Tensor,
positions: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""KV-cached forward for prefill (T>1) and single-step draft decode (T=1).

Writes K/V at ``positions`` and attends over the cache. Like the target's
KV cache, attention scores against the whole ``max_seq_len`` buffer under
a static causal mask (export-friendly, no data-dependent slicing); what
this avoids versus a full recompute is re-running the prefix's
projections and MLP, not the attention score width.

Invariant: writes must be contiguous from position 0. Seed a fresh
sequence (after ``reset_cache``) starting at position 0 and only ever
extend with the next contiguous positions; offset or gapped seeds attend
to unwritten (zeroed) slots and are not supported. Batch size must be 1.
"""
if input_embeds.shape[0] != 1:
raise ValueError("forward_cached supports batch size 1 only")
if not torch.compiler.is_compiling():
self._validate_contiguous(positions)
return self.forward(
input_embeds, feature, positions, self._build_causal_mask(positions)
)

def _validate_contiguous(self, positions: torch.Tensor) -> None:
"""Eager-only guard for the contiguous-from-0 cache invariant.

Tracks the end of the valid contiguous prefix (reset by ``reset_cache``).
A write may overwrite already-written slots (speculative rollback) but
must be contiguous and ascending and must not start beyond the valid
prefix, which would leave unwritten (zeroed) slots below it in the
attention window. A rollback overwrite truncates the valid prefix to the
end of the write, so a slot above it is treated as stale and a later
write that skips it is rejected until it is rewritten. Skipped under
export/compile, where positions are traced tensors and the runner owns
the contract.
"""
start = int(positions[0])
length = int(positions.shape[0])
expected = torch.arange(start, start + length, device=positions.device)
if not torch.equal(positions, expected):
raise ValueError(
f"forward_cached positions must be contiguous ascending, "
f"got {positions.tolist()}"
)
if start > self._cache_valid_end:
raise ValueError(
f"non-contiguous cache seed: positions start at {start} but only "
f"{self._cache_valid_end} slot(s) are valid; seed from 0 after "
f"reset_cache"
)
# A write defines the valid prefix up to its end; slots above it (from an
# earlier longer write that this one rolled back) are now stale.
self._cache_valid_end = start + length

def reset_cache(self) -> None:
    """Zero the draft layer's KV cache and mark no cache slots as valid."""
    attention = self.midlayer.self_attn
    attention.kv_cache.reset()
    self._cache_valid_end = 0

def allocate_kv_cache(self, dtype: torch.dtype, device) -> None:
    """(Re)allocate the KV cache in a given dtype/device (zeroed).

    Reallocation discards everything previously written, so the tracked
    valid prefix is reset as well. Otherwise the eager contiguity guard
    would keep accepting writes that start at the old prefix end and attend
    to the freshly zeroed slots below it.
    """
    self.midlayer.self_attn.kv_cache.allocate(dtype, device)
    self._cache_valid_end = 0

def draft_to_target(self, draft_ids: torch.Tensor) -> torch.Tensor:
    """Shift each draft id by the offset stored for it in ``d2t``."""
    offsets = self.d2t[draft_ids]
    return draft_ids + offsets

@staticmethod
def from_checkpoint(
model_dir: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16
model_dir: str,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_seq_len: int = 4096,
) -> tuple["Eagle3Draft", Eagle3Config]:
import json

Expand All @@ -231,6 +404,7 @@ def from_checkpoint(
aux_hidden_state_layers=cfg["eagle_aux_hidden_state_layer_ids"],
norm_before_residual=cfg.get("norm_before_residual", False),
norm_before_fc=cfg.get("norm_before_fc", False),
max_seq_len=max_seq_len,
)
if config.norm_before_fc:
# This checkpoint variant requires an input RMSNorm before fc.
Expand All @@ -256,6 +430,12 @@ def from_checkpoint(
model.register_buffer("d2t", state_dict.pop("d2t"), persistent=False)
model.register_buffer("t2d", state_dict.pop("t2d"), persistent=False)
model.load_state_dict(state_dict, strict=True, assign=True)
# Allocate the KV cache directly in the compute dtype on the target
# device *before* moving weights, so the float32 placeholder cache from
# __init__ is freed without ever being copied to the device. The
# subsequent .to(device) is a no-op for the (already-placed) cache and
# carries no dtype argument, so inv_freq stays float32.
model.allocate_kv_cache(dtype, device)
model = model.to(device)
assert (
model.midlayer.self_attn.inv_freq.dtype == torch.float32
Expand Down
113 changes: 113 additions & 0 deletions examples/models/eagle3/test_draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,119 @@ def test_norm_before_residual_changes_output():
assert not torch.allclose(outs[0], outs[1]), "norm_before_residual had no effect"


def test_kv_cache_matches_full_recompute():
    """Cached prefill plus one-token decode steps must equal the stateless forward."""
    torch.manual_seed(0)
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    seq_len, prefill_len = 6, 3
    aux_width = len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size
    feat = draft.fuse(torch.randn(1, seq_len, aux_width))
    token_ids = torch.randint(0, cfg.target_vocab_size, (seq_len,))
    emb = draft.embed(token_ids).unsqueeze(0)

    with torch.no_grad():
        # Reference: one stateless pass over the whole sequence.
        ref_logits, ref_g = draft(emb, feat, torch.arange(seq_len))

        draft.reset_cache()
        head = slice(0, prefill_len)
        logits, g = draft.forward_cached(
            emb[:, head], feat[:, head], torch.arange(prefill_len)
        )
        torch.testing.assert_close(logits, ref_logits[:, head])
        torch.testing.assert_close(g, ref_g[:, head])

        # Decode the remaining positions one token at a time.
        for step in range(prefill_len, seq_len):
            cur = slice(step, step + 1)
            logits, g = draft.forward_cached(
                emb[:, cur], feat[:, cur], torch.arange(step, step + 1)
            )
            torch.testing.assert_close(logits, ref_logits[:, cur])
            torch.testing.assert_close(g, ref_g[:, cur])


def test_reset_cache_isolates_sequences():
    """After ``reset_cache``, a cached run must match a fresh stateless run."""
    torch.manual_seed(0)
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    seq_len = 4
    aux_width = len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size
    feat = draft.fuse(torch.randn(1, seq_len, aux_width))
    token_ids = torch.randint(0, cfg.target_vocab_size, (seq_len,))
    emb = draft.embed(token_ids).unsqueeze(0)
    with torch.no_grad():
        draft.forward_cached(emb, feat, torch.arange(seq_len))  # pollute the cache
        draft.reset_cache()
        cached_logits, _ = draft.forward_cached(emb, feat, torch.arange(seq_len))
        ref_logits, _ = draft(emb, feat, torch.arange(seq_len))
    torch.testing.assert_close(cached_logits, ref_logits)


def test_offset_seed_after_reset_is_rejected():
    """An offset seed right after a reset must be rejected in eager mode.

    Such a seed would attend to zeroed slots, so ``forward_cached`` enforces
    the contiguous-from-0 invariant by raising.
    """
    torch.manual_seed(0)
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    seq_len = 4
    aux_width = len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size
    feat = draft.fuse(torch.randn(1, seq_len, aux_width))
    token_ids = torch.randint(0, cfg.target_vocab_size, (seq_len,))
    emb = draft.embed(token_ids).unsqueeze(0)
    draft.reset_cache()
    offset = 10
    with pytest.raises(ValueError, match="non-contiguous cache seed"):
        draft.forward_cached(emb, feat, torch.arange(offset, offset + seq_len))


def test_gapped_positions_rejected():
    """Positions with a hole in them must be rejected as non-contiguous."""
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    emb = torch.randn(1, 3, cfg.hidden_size)
    feat = torch.randn(1, 3, cfg.hidden_size)
    draft.reset_cache()
    gapped = [0, 2, 3]  # position 1 is skipped
    with pytest.raises(ValueError, match="contiguous ascending"):
        draft.forward_cached(emb, feat, torch.tensor(gapped))


def test_rollback_reseed_is_allowed():
    """Overwriting already-written slots (speculative rollback) is accepted."""
    torch.manual_seed(0)
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    seq_len = 6
    token_ids = torch.randint(0, cfg.target_vocab_size, (seq_len,))
    emb = draft.embed(token_ids).unsqueeze(0)
    aux_width = len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size
    feat = draft.fuse(torch.randn(1, seq_len, aux_width))
    with torch.no_grad():
        draft.reset_cache()
        draft.forward_cached(emb, feat, torch.arange(seq_len))  # write slots 0..5
        # Re-decode at slot 4 (a rejected proposal rolled back) — allowed.
        rolled_back = slice(4, 5)
        draft.forward_cached(
            emb[:, rolled_back], feat[:, rolled_back], torch.arange(4, 5)
        )


def test_post_rollback_gap_is_rejected():
    """A rollback overwrite shrinks the valid prefix.

    After writing 0..5 and then re-decoding slot 4, slot 5 holds stale
    (rejected) K/V, so a write starting at 6 is refused until slot 5 has
    been rewritten.
    """
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    draft.reset_cache()
    check = draft._validate_contiguous
    check(torch.arange(6))  # write slots 0..5
    check(torch.arange(4, 5))  # rollback overwrite of slot 4
    with pytest.raises(ValueError, match="non-contiguous"):
        check(torch.arange(6, 7))  # slot 5 is stale -> rejected
    check(torch.arange(5, 6))  # rewrite slot 5 -> ok
    check(torch.arange(6, 7))  # slot 6 is now reachable -> ok


def test_forward_cached_rejects_batch_gt_1():
    """``forward_cached`` must refuse inputs whose batch size is not 1."""
    cfg = tiny_config()
    draft = Eagle3Draft(cfg).to(torch.float32).eval()
    batch, seq_len = 2, 3
    emb = torch.randn(batch, seq_len, cfg.hidden_size)
    feat = torch.randn(batch, seq_len, cfg.hidden_size)
    with pytest.raises(ValueError, match="batch size 1"):
        draft.forward_cached(emb, feat, torch.arange(seq_len))


def test_draft_to_target_mapping():
model = Eagle3Draft(tiny_config()).eval()
model.d2t.copy_(torch.arange(model.config.draft_vocab_size)) # offset = id
Expand Down
Loading