diff --git a/examples/models/eagle3/__init__.py b/examples/models/eagle3/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/models/eagle3/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/eagle3/draft.py b/examples/models/eagle3/draft.py new file mode 100644 index 00000000000..c372730b784 --- /dev/null +++ b/examples/models/eagle3/draft.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""EAGLE-3 draft head for vLLM speculator checkpoints. + +The draft model fuses target auxiliary hidden states with ``fc``, runs one +Llama-style decoder layer over token embeddings plus the fused feature, and +projects the midlayer output to reduced-vocabulary draft logits. The midlayer +output ``g`` is reused as the recurrent feature for drafting; ``fc`` is used +only for target auxiliary hidden states. + +Draft ids map back to target ids with ``target_id = draft_id + d2t[draft_id]``. +Speculator checkpoints store the decoder layer under ``layers.0.*`` and may +include ``embed_tokens``, ``d2t``, and ``t2d``. +""" + +import os +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +@dataclass +class Eagle3Config: + hidden_size: int = 5376 + target_hidden_size: int = 5376 + intermediate_size: int = 21504 + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + head_dim: int = 256 + rope_theta: float = 10_000.0 + rms_norm_eps: float = 1e-6 + draft_vocab_size: int = 32000 + target_vocab_size: int = 262144 + aux_hidden_state_layers: list = field(default_factory=lambda: [2, 30, 57]) + # norm_before_residual: store the attention residual after hidden_norm. + # norm_before_fc: apply an RMSNorm over the concatenated aux features before + # fc (gpt-oss-style speculators checkpoints); not supported here. + # has_own_embed: the head ships its own embed_tokens (set during load). + norm_before_residual: bool = True + norm_before_fc: bool = False + has_own_embed: bool = False + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +class Eagle3Attention(nn.Module): + """Llama GQA attention; q/k/v project from the doubled-width (2*hidden) input.""" + + def __init__(self, config: Eagle3Config): + super().__init__() + self.n_heads = config.num_attention_heads + self.n_kv_heads = config.num_key_value_heads + self.head_dim = config.head_dim + in_dim = 2 * config.hidden_size + + self.q_proj = nn.Linear(in_dim, self.n_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(in_dim, self.n_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(in_dim, self.n_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear( + self.n_heads * self.head_dim, config.hidden_size, bias=False + ) + + inv_freq = 1.0 / ( + config.rope_theta + ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + B, T, _ = x.shape + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) + + freqs = torch.outer(positions.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(q.dtype) + sin = emb.sin().to(q.dtype) + q = q * cos + _rotate_half(q) * sin + k = k * cos + _rotate_half(k) * sin + + y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True) + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + +class Eagle3MLP(nn.Module): + def __init__(self, config: Eagle3Config): + super().__init__() + self.gate_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=False + ) + self.up_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=False + ) + self.down_proj = nn.Linear( + config.intermediate_size, config.hidden_size, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class Eagle3Midlayer(nn.Module): + """Single EAGLE-3 decoder layer with dual input norms over two streams.""" + + def __init__(self, config: Eagle3Config): + super().__init__() + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hidden_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = Eagle3Attention(config) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.mlp = Eagle3MLP(config) + self.norm_before_residual = config.norm_before_residual + + def forward( + self, + input_embeds: torch.Tensor, + feature: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + normed_embeds = self.input_layernorm(input_embeds) + normed_feature = self.hidden_norm(feature) + residual = normed_feature if self.norm_before_residual else feature + x = torch.cat((normed_embeds, normed_feature), dim=-1) + x = self.self_attn(x, positions) + x = residual + x + + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + return residual + x + + +class Eagle3Draft(nn.Module): + def __init__(self, config: Eagle3Config): + super().__init__() + self.config = config + self.fc = nn.Linear( + len(config.aux_hidden_state_layers) * config.target_hidden_size, + config.hidden_size, + bias=False, + ) + self.midlayer = Eagle3Midlayer(config) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear( + config.hidden_size, config.draft_vocab_size, bias=False + ) + if config.has_own_embed: + self.embed_tokens = nn.Embedding( + config.target_vocab_size, config.hidden_size + ) + # d2t/t2d are loaded from the checkpoint (assign=True adopts their + # shapes/dtypes): d2t[draft_id] is the offset to the target vocab id; + # t2d masks which target ids are in the draft vocab. + self.register_buffer( + "d2t", + torch.zeros(config.draft_vocab_size, dtype=torch.long), + persistent=False, + ) + self.register_buffer("t2d", torch.zeros(1, dtype=torch.bool), persistent=False) + + def fuse(self, aux: torch.Tensor) -> torch.Tensor: + """Fuse concatenated target aux hidden states (B,T,3*D) -> feature (B,T,D).""" + return self.fc(aux) + + def embed(self, ids: torch.Tensor) -> torch.Tensor: + """Embed token ids with the head's own table. + + Only valid when the checkpoint shipped its own ``embed_tokens``; heads + that reuse the target embedding must source embeddings from the target. + """ + if not self.config.has_own_embed: + raise RuntimeError( + "this draft head has no own embed_tokens (has_own_embed=False); " + "provide token embeddings from the target model instead" + ) + return self.embed_tokens(ids) + + def forward( + self, + input_embeds: torch.Tensor, + feature: torch.Tensor, + positions: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run the midlayer over a sequence. + + Returns (draft_logits, g): + draft_logits: (B, T, draft_vocab_size) over the reduced vocab. + g: (B, T, hidden) midlayer output — the recurrent feature. + """ + g = self.midlayer(input_embeds, feature, positions) + draft_logits = self.lm_head(self.norm(g)) + return draft_logits, g + + def draft_to_target(self, draft_ids: torch.Tensor) -> torch.Tensor: + return draft_ids + self.d2t[draft_ids] + + @staticmethod + def from_checkpoint( + model_dir: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16 + ) -> tuple["Eagle3Draft", Eagle3Config]: + import json + + with open(os.path.join(model_dir, "config.json")) as f: + cfg = json.load(f) + + tlc = cfg["transformer_layer_config"] + config = Eagle3Config( + hidden_size=tlc["hidden_size"], + target_hidden_size=cfg.get("target_hidden_size") or tlc["hidden_size"], + intermediate_size=tlc["intermediate_size"], + num_attention_heads=tlc["num_attention_heads"], + num_key_value_heads=tlc["num_key_value_heads"], + head_dim=tlc["head_dim"], + rope_theta=tlc["rope_parameters"]["rope_theta"], + rms_norm_eps=tlc["rms_norm_eps"], + draft_vocab_size=cfg["draft_vocab_size"], + target_vocab_size=tlc.get("vocab_size", 262144), + aux_hidden_state_layers=cfg["eagle_aux_hidden_state_layer_ids"], + norm_before_residual=cfg.get("norm_before_residual", False), + norm_before_fc=cfg.get("norm_before_fc", False), + ) + if config.norm_before_fc: + # This checkpoint variant requires an input RMSNorm before fc. + raise ValueError( + "norm_before_fc=True checkpoints are not supported " + "(would need an input RMSNorm before fc)" + ) + + raw = _load_safetensors(model_dir) + config.has_own_embed = "embed_tokens.weight" in raw + + # Cast checkpoint weights after module construction so inv_freq stays fp32. + model = Eagle3Draft(config) + # The single decoder layer is stored as layers.0.* on disk. + state_dict = { + (k.replace("layers.0.", "midlayer.") if k.startswith("layers.0.") else k): ( + v.to(dtype) if v.is_floating_point() else v + ) + for k, v in raw.items() + } + # d2t/t2d are index/mask tensors (their checkpoint shape differs from the + # placeholder buffers); register them directly, load the rest strict. + model.register_buffer("d2t", state_dict.pop("d2t"), persistent=False) + model.register_buffer("t2d", state_dict.pop("t2d"), persistent=False) + model.load_state_dict(state_dict, strict=True, assign=True) + model = model.to(device) + assert ( + model.midlayer.self_attn.inv_freq.dtype == torch.float32 + ), "RoPE inv_freq must remain float32" + return model.eval(), config + + +def _load_safetensors(model_dir: str) -> dict: + """Load a monolithic or sharded safetensors checkpoint to CPU tensors.""" + import json + + from safetensors import safe_open + + index = os.path.join(model_dir, "model.safetensors.index.json") + mono = os.path.join(model_dir, "model.safetensors") + if os.path.exists(mono): + shards = ["model.safetensors"] + elif os.path.exists(index): + with open(index) as f: + shards = sorted(set(json.load(f)["weight_map"].values())) + else: + raise FileNotFoundError( + f"no model.safetensors or model.safetensors.index.json in {model_dir}" + ) + raw = {} + for shard in shards: + with safe_open( + os.path.join(model_dir, shard), framework="pt", device="cpu" + ) as f: + for k in f.keys(): + if k in raw: + raise ValueError(f"duplicate tensor {k!r} across shards ({shard})") + raw[k] = f.get_tensor(k) + return raw diff --git a/examples/models/eagle3/test_draft.py b/examples/models/eagle3/test_draft.py new file mode 100644 index 00000000000..eb2587f5605 --- /dev/null +++ b/examples/models/eagle3/test_draft.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for the EAGLE-3 draft head and its checkpoint adapter.""" + +import json +import os + +import pytest +import torch + +from executorch.examples.models.eagle3.draft import Eagle3Config, Eagle3Draft + + +def tiny_config(norm_before_residual=True, has_own_embed=True) -> Eagle3Config: + return Eagle3Config( + hidden_size=32, + target_hidden_size=32, + intermediate_size=64, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + draft_vocab_size=16, + target_vocab_size=40, + aux_hidden_state_layers=[0, 1, 2], + norm_before_residual=norm_before_residual, + has_own_embed=has_own_embed, + ) + + +@pytest.mark.parametrize("norm_before_residual", [True, False]) +@pytest.mark.parametrize("has_own_embed", [True, False]) +def test_tiny_forward_shapes(norm_before_residual, has_own_embed): + torch.manual_seed(0) + cfg = tiny_config(norm_before_residual, has_own_embed) + model = Eagle3Draft(cfg).to(torch.float32).eval() + B, T = 1, 5 + + aux = torch.randn(B, T, len(cfg.aux_hidden_state_layers) * cfg.target_hidden_size) + feat = model.fuse(aux) + assert feat.shape == (B, T, cfg.hidden_size) + + if has_own_embed: + emb = model.embed(torch.randint(0, cfg.target_vocab_size, (T,))).unsqueeze(0) + assert emb.shape == (B, T, cfg.hidden_size) + else: + emb = torch.randn(B, T, cfg.hidden_size) + + with torch.no_grad(): + logits, g = model(emb, feat, torch.arange(T)) + assert logits.shape == (B, T, cfg.draft_vocab_size) + assert g.shape == (B, T, cfg.hidden_size) + assert torch.isfinite(logits).all() and torch.isfinite(g).all() + + +def test_norm_before_residual_changes_output(): + # Check residual-path wiring, not only output shape. + B, T = 1, 5 + aux = torch.randn(B, T, 3 * 32) + emb = torch.randn(B, T, 32) + outs = [] + for nbr in (True, False): + torch.manual_seed(1) # identical weights, only the flag differs + model = ( + Eagle3Draft(tiny_config(norm_before_residual=nbr)).to(torch.float32).eval() + ) + with torch.no_grad(): + _, g = model(emb, model.fuse(aux), torch.arange(T)) + outs.append(g) + assert not torch.allclose(outs[0], outs[1]), "norm_before_residual had no effect" + + +def test_draft_to_target_mapping(): + model = Eagle3Draft(tiny_config()).eval() + model.d2t.copy_(torch.arange(model.config.draft_vocab_size)) # offset = id + ids = torch.tensor([0, 3, 7]) + assert torch.equal(model.draft_to_target(ids), ids + ids) + + +def test_embed_requires_own_embed(): + model = Eagle3Draft(tiny_config(has_own_embed=False)).eval() + assert not hasattr(model, "embed_tokens") + with pytest.raises(RuntimeError, match="no own embed_tokens"): + model.embed(torch.tensor([0, 1, 2])) + + +def test_inv_freq_stays_fp32_under_assign_load(): + cfg = tiny_config() + model = Eagle3Draft(cfg) + assert model.midlayer.self_attn.inv_freq.dtype == torch.float32 + sd = { + k: (v.to(torch.bfloat16) if v.is_floating_point() else v) + for k, v in model.state_dict().items() + } + model.load_state_dict(sd, strict=True, assign=True) + assert model.midlayer.self_attn.inv_freq.dtype == torch.float32 + assert model.fc.weight.dtype == torch.bfloat16 + + +def _write_checkpoint(model_dir, cfg, *, sharded=False, norm_before_fc=False): + """Write a tiny speculators-format checkpoint (config.json + safetensors).""" + from safetensors.torch import save_file + + os.makedirs(model_dir, exist_ok=True) + torch.manual_seed(2) + src = Eagle3Draft(cfg) + disk = { + k.replace("midlayer.", "layers.0."): v.clone().contiguous() + for k, v in src.state_dict().items() + } + d2t = torch.arange(cfg.draft_vocab_size, dtype=torch.int64) + t2d = torch.zeros(cfg.target_vocab_size, dtype=torch.bool) + t2d[: cfg.draft_vocab_size] = True + disk["d2t"] = d2t + disk["t2d"] = t2d + + config = { + "draft_vocab_size": cfg.draft_vocab_size, + "target_hidden_size": cfg.target_hidden_size, + "eagle_aux_hidden_state_layer_ids": cfg.aux_hidden_state_layers, + "norm_before_residual": cfg.norm_before_residual, + "norm_before_fc": norm_before_fc, + "transformer_layer_config": { + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, + "head_dim": cfg.head_dim, + "vocab_size": cfg.target_vocab_size, + "rms_norm_eps": cfg.rms_norm_eps, + "rope_parameters": {"rope_theta": cfg.rope_theta}, + }, + } + with open(os.path.join(model_dir, "config.json"), "w") as f: + json.dump(config, f) + + if not sharded: + save_file(disk, os.path.join(model_dir, "model.safetensors")) + else: + keys = list(disk) + half = len(keys) // 2 + s1 = {k: disk[k] for k in keys[:half]} + s2 = {k: disk[k] for k in keys[half:]} + save_file(s1, os.path.join(model_dir, "model-00001-of-00002.safetensors")) + save_file(s2, os.path.join(model_dir, "model-00002-of-00002.safetensors")) + weight_map = {k: "model-00001-of-00002.safetensors" for k in s1} + weight_map.update({k: "model-00002-of-00002.safetensors" for k in s2}) + with open(os.path.join(model_dir, "model.safetensors.index.json"), "w") as f: + json.dump({"weight_map": weight_map}, f) + return src, d2t, t2d + + +@pytest.mark.parametrize("sharded", [False, True]) +def test_from_checkpoint_roundtrip(tmp_path, sharded): + cfg = tiny_config() + src, d2t, t2d = _write_checkpoint(str(tmp_path), cfg, sharded=sharded) + + model, loaded_cfg = Eagle3Draft.from_checkpoint( + str(tmp_path), device="cpu", dtype=torch.float32 + ) + assert loaded_cfg.has_own_embed + assert loaded_cfg.aux_hidden_state_layers == cfg.aux_hidden_state_layers + assert loaded_cfg.target_vocab_size == cfg.target_vocab_size + torch.testing.assert_close( + model.midlayer.self_attn.q_proj.weight, src.midlayer.self_attn.q_proj.weight + ) + torch.testing.assert_close(model.fc.weight, src.fc.weight) + assert torch.equal(model.d2t, d2t) + assert torch.equal(model.t2d, t2d) + assert model.midlayer.self_attn.inv_freq.dtype == torch.float32 + T = 4 + feat = model.fuse(torch.randn(1, T, 3 * cfg.target_hidden_size)) + emb = model.embed(torch.randint(0, cfg.target_vocab_size, (T,))).unsqueeze(0) + logits, g = model(emb, feat, torch.arange(T)) + assert logits.shape == (1, T, cfg.draft_vocab_size) + + +def test_from_checkpoint_rejects_norm_before_fc(tmp_path): + cfg = tiny_config() + _write_checkpoint(str(tmp_path), cfg, norm_before_fc=True) + with pytest.raises(ValueError, match="norm_before_fc"): + Eagle3Draft.from_checkpoint(str(tmp_path), device="cpu", dtype=torch.float32) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-q"])) diff --git a/pytest.ini b/pytest.ini index 05aea9d4da6..3dd960987f2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -96,6 +96,7 @@ testpaths = examples/models/llama3_2_vision/vision_encoder/test examples/models/llama3_2_vision/text_decoder/test examples/models/llava/test + examples/models/eagle3/test_draft.py # exir exir/