Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 98 additions & 7 deletions examples/models/gemma4_31b/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ def update(
# Config


def validate_eagle_tap_layers(layers: list, num_hidden_layers: int) -> None:
    """Check a list of EAGLE-3 tap indices, raising on the first violation.

    Indices follow the HF/vLLM hidden-state convention, where 0 is the
    embedding output. A falsy ``layers`` (e.g. an empty list) is accepted
    without further checks.

    The checks run in this order, and later checks rely on earlier ones
    having passed:

    1. every entry is an ``int`` and not a ``bool``;
    2. no entry is repeated;
    3. every entry lies in ``[0, num_hidden_layers]`` (inclusive);
    4. entries are in ascending order, since their order defines the fc
       concatenation order.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if not layers:
        return

    for tap in layers:
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(tap, bool) or not isinstance(tap, int):
            raise ValueError(f"eagle_tap_layers must be non-bool ints, got {layers}")

    if len(layers) != len(set(layers)):
        raise ValueError(f"eagle_tap_layers has duplicates: {layers}")

    if not all(0 <= tap <= num_hidden_layers for tap in layers):
        raise ValueError(
            f"eagle_tap_layers {layers} out of range [0, {num_hidden_layers}]"
        )

    if sorted(layers) != list(layers):
        raise ValueError(f"eagle_tap_layers must be ascending (fc order): {layers}")


@dataclass
class Gemma4_31BConfig:
# Embedding / shape
Expand Down Expand Up @@ -144,6 +164,11 @@ class Gemma4_31BConfig:
# Runtime
max_seq_len: int = 4096

# EAGLE-3 auxiliary hidden-state taps. Indices use the HF/vLLM convention:
# 0 = embedding output, k = output after decoder layer k-1. Empty disables
# tap collection.
eagle_tap_layers: list = field(default_factory=list)

def __post_init__(self):
if not self.layer_types:
# Default hybrid pattern: 5 sliding then 1 full, repeated.
Expand All @@ -156,6 +181,7 @@ def __post_init__(self):
f"layer_types length {len(self.layer_types)} != "
f"num_hidden_layers {self.num_hidden_layers}"
)
validate_eagle_tap_layers(self.eagle_tap_layers, self.num_hidden_layers)

@staticmethod
def from_hf_config(config_path: str) -> "Gemma4_31BConfig":
Expand Down Expand Up @@ -466,6 +492,48 @@ def _build_masks(

return sliding_mask, full_mask

def _decode(
    self,
    tokens: torch.LongTensor,
    input_pos: torch.LongTensor,
    collect_taps: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run embedding, every decoder layer, and the final norm.

    Tap indices use the HF/vLLM hidden-state convention: index 0 is the
    scaled embedding output (before any decoder layer) and index k is the
    output *after* decoder layer k-1.

    Args:
        tokens: token ids fed to ``embed_tokens``.
        input_pos: positions passed to ``_build_masks`` and to each layer.
        collect_taps: when True, ``config.eagle_tap_layers`` is revalidated
            and the requested hidden states are gathered.

    Returns:
        A pair ``(hidden, taps)``. ``hidden`` is the normed output of the
        last layer. ``taps`` is the requested hidden states concatenated
        along the last dimension in ascending index order, i.e.
        ``(B, T, len(tap_layers) * hidden_size)``, or None when nothing was
        collected.

    Raises:
        ValueError: if the tap configuration is invalid, or if the number
            of collected taps does not match the number requested.
    """
    hidden = self.embed_tokens(tokens) * self.embed_normalizer

    requested = self.config.eagle_tap_layers
    collected = []
    if collect_taps:
        # The config list can be reassigned at runtime, so check it again
        # before relying on membership tests against it.
        validate_eagle_tap_layers(requested, len(self.layers))
        if 0 in requested:
            collected.append(hidden)  # hidden-state index 0: embedding output

    sliding_mask, full_mask = self._build_masks(input_pos)
    # Count from 1 so the loop variable is the hidden-state index directly:
    # the output of the n-th decoder layer is hidden-state index n.
    for state_idx, layer in enumerate(self.layers, start=1):
        hidden = layer(hidden, input_pos, sliding_mask, full_mask)
        if collect_taps and state_idx in requested:
            collected.append(hidden)

    if collect_taps and len(collected) != len(requested):
        raise ValueError(
            f"collected {len(collected)} taps but eagle_tap_layers requests "
            f"{len(requested)} ({requested}); check the index convention"
        )

    hidden = self.norm(hidden)
    if not collected:
        return hidden, None
    return hidden, torch.cat(collected, dim=-1)

def forward(
self,
tokens: torch.LongTensor,
Expand All @@ -482,18 +550,41 @@ def forward(
Returns:
(B, 1) sampled token IDs as float.
"""
x = self.embed_tokens(tokens) * self.embed_normalizer

sliding_mask, full_mask = self._build_masks(input_pos)
for layer in self.layers:
x = layer(x, input_pos, sliding_mask, full_mask)

x = self.norm(x)
x, _ = self._decode(tokens, input_pos, collect_taps=False)
last = self.lm_head(x[:, -1, :]).float()
cap = self.logit_softcap.float()
last = torch.tanh(last / cap) * cap
return sample(last, temperature)

def set_eagle_tap_layers(self, layers: list) -> None:
    """Validate ``layers`` and install them as the EAGLE-3 tap configuration.

    The argument is copied into a list exactly once, and that copy is both
    what gets validated and what gets stored. This keeps the stored value
    identical to the validated one and means a one-shot iterable is not
    exhausted by validation before it can be stored.

    Args:
        layers: tap indices in the HF/vLLM hidden-state convention
            (0 = embedding output).

    Raises:
        ValueError: if ``validate_eagle_tap_layers`` rejects the indices.
            The existing configuration is left untouched in that case,
            because assignment happens only after validation succeeds.
    """
    snapshot = list(layers)
    validate_eagle_tap_layers(snapshot, self.config.num_hidden_layers)
    self.config.eagle_tap_layers = snapshot

def forward_logits_taps(
self,
tokens: torch.LongTensor,
input_pos: torch.LongTensor,
last_logits_only: bool = True,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the tapped model, do we need all logits (not just the final position's) so that sampling can be done outside the model?

If so, why does last_logits_only default to True in forward_logits_taps?

) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Return soft-capped logits and EAGLE-3 tap features.

Defaults to final-position logits. Set ``last_logits_only=False`` to
materialize per-position float32 logits over the full vocabulary.

Returns:
logits: (B, 1, vocab_size) soft-capped float32, or (B, T, vocab_size)
when ``last_logits_only=False``.
taps: (B, T, len(eagle_tap_layers) * hidden_size) or None.
"""
x, taps = self._decode(tokens, input_pos, collect_taps=True)
if last_logits_only:
x = x[:, -1:, :]
logits = self.lm_head(x).float()
cap = self.logit_softcap.float()
logits = torch.tanh(logits / cap) * cap
return logits, taps

# ---------------- checkpoint loading ----------------

@staticmethod
Expand Down
141 changes: 141 additions & 0 deletions examples/models/gemma4_31b/test_eagle_tap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for the gemma4-31B EAGLE-3 hidden-state tap.

Covers the tap-index convention (HF/vLLM: index 0 = embedding, index k = output
after decoder layer k-1), exact concatenation order/content, config validation
(including the runtime-mutation path), and that the default decode path is
unaffected by enabling the tap.
"""

import pytest
import torch

from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig


def tiny_config(num_layers=6, tap_layers=None) -> Gemma4_31BConfig:
    """Return a small ``Gemma4_31BConfig`` for unit tests.

    Only the layer count and the tap indices vary; every other dimension is
    fixed to a small value. A falsy ``tap_layers`` (including None) becomes
    an empty list.
    """
    fixed_dims = {
        "vocab_size": 128,
        "hidden_size": 32,
        "intermediate_size": 64,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "head_dim": 8,
        "num_global_key_value_heads": 1,
        "global_head_dim": 8,
        "sliding_window": 8,
        "max_seq_len": 32,
    }
    taps = tap_layers if tap_layers else []
    return Gemma4_31BConfig(
        num_hidden_layers=num_layers,
        eagle_tap_layers=taps,
        **fixed_dims,
    )


def build(num_layers=6, tap_layers=None):
    """Construct a float32, eval-mode ``Gemma4_31B`` from ``tiny_config``.

    The global torch seed is reset to 0 before the model is constructed, so
    any randomness drawn afterwards starts from the same state in each call.
    """
    torch.manual_seed(0)
    config = tiny_config(num_layers, tap_layers)
    model = Gemma4_31B(config)
    model = model.to(torch.float32)
    return model.eval()


def reset_kv(model):
    """Zero, in place, every buffer whose qualified name contains ``.kv_cache.``.

    The KV caches are stateful, so clearing them keeps independent forward
    passes in a test from influencing each other.
    """
    kv_buffers = (
        buf for name, buf in model.named_buffers() if ".kv_cache." in name
    )
    for buf in kv_buffers:
        buf.zero_()


def reference_states(model, tokens, input_pos):
    """Independently recompute the hidden state at every tap index.

    Mirrors the indexing used by ``_decode``: key 0 holds the scaled
    embedding output and key k holds the output after decoder layer k-1.
    The final norm is not applied.

    Returns:
        dict mapping hidden-state index to the corresponding state.
    """
    hidden = model.embed_tokens(tokens) * model.embed_normalizer
    sliding_mask, full_mask = model._build_masks(input_pos)

    by_index = {0: hidden}
    for state_idx, layer in enumerate(model.layers, start=1):
        hidden = layer(hidden, input_pos, sliding_mask, full_mask)
        by_index[state_idx] = hidden
    return by_index


def test_tap_off_does_not_change_logits():
    """Logits must be identical whether or not taps are being collected.

    The same model is run twice on the same input: once with three tap
    layers configured, once after the tap list is emptied. KV caches are
    cleared before each run.
    """
    seq_len = 7
    model = build(tap_layers=[1, 2, 3])
    tokens = torch.randint(0, 128, (1, seq_len))
    pos = torch.arange(seq_len)

    def run_full_logits():
        reset_kv(model)
        return model.forward_logits_taps(tokens, pos, last_logits_only=False)

    with torch.no_grad():
        logits_on, taps_on = run_full_logits()
        model.config.eagle_tap_layers = []
        logits_off, taps_off = run_full_logits()

    assert taps_off is None
    assert taps_on.shape == (1, seq_len, 3 * model.config.hidden_size)
    torch.testing.assert_close(logits_on, logits_off)


@pytest.mark.parametrize(
    ("num_layers", "tap_layers"),
    [
        (6, [0, 1, 3]),
        (60, [2, 30, 57]),
    ],
)
def test_tap_collects_exact_states_in_order(num_layers, tap_layers):
    """Taps must equal the reference states, concatenated in config order.

    The comparison is exact (zero tolerance): the tap output is expected to
    be the very same values ``reference_states`` recomputes, not an
    approximation of them.
    """
    seq_len = 5
    model = build(num_layers=num_layers, tap_layers=tap_layers)
    tokens = torch.randint(0, 128, (1, seq_len))
    pos = torch.arange(seq_len)

    with torch.no_grad():
        reset_kv(model)
        _, taps = model.forward_logits_taps(tokens, pos)
        # Clear the caches again so the reference pass starts from the same
        # state as the pass under test.
        reset_kv(model)
        states = reference_states(model, tokens, pos)

    expected = torch.cat(tuple(states[idx] for idx in tap_layers), dim=-1)
    width = len(tap_layers) * model.config.hidden_size
    assert taps.shape == (1, seq_len, width)
    torch.testing.assert_close(taps, expected, rtol=0, atol=0)


def test_last_logits_only_default_matches_full():
    """Default (last-position) logits must match the last row of full logits."""
    seq_len = 4
    model = build(tap_layers=[1])
    tokens = torch.randint(0, 128, (1, seq_len))
    pos = torch.arange(seq_len)

    with torch.no_grad():
        reset_kv(model)
        all_positions, _ = model.forward_logits_taps(
            tokens, pos, last_logits_only=False
        )
        reset_kv(model)
        final_only, _ = model.forward_logits_taps(tokens, pos)

    vocab = model.config.vocab_size
    assert final_only.shape == (1, 1, vocab)
    torch.testing.assert_close(final_only[:, 0], all_positions[:, -1])


@pytest.mark.parametrize(
    "bad",
    [
        [99],  # far above the upper bound
        [7],  # first index past the inclusive upper bound (num_layers == 6)
        [-1],  # below the lower bound
        [1, 1],  # duplicate
        [1.0, 2],  # non-int entry
        [True],  # bool is an int subclass but must be rejected
        [3, 1],  # not ascending
    ],
)
def test_invalid_tap_config_rejected(bad):
    """Every kind of invalid tap list must be rejected when the config is built.

    Covers each validator branch: wrong type, duplicates, out of range on
    both sides (including the boundary just past ``num_hidden_layers``), and
    wrong order.
    """
    with pytest.raises(ValueError):
        tiny_config(num_layers=6, tap_layers=bad)


def test_set_eagle_tap_layers_validates():
    """The setter stores a valid list and raises on an unordered one."""
    model = build()

    accepted = [0, 2, 4]
    model.set_eagle_tap_layers(accepted)
    assert model.config.eagle_tap_layers == [0, 2, 4]

    descending = [4, 2]
    with pytest.raises(ValueError):
        model.set_eagle_tap_layers(descending)


def test_runtime_mutation_is_revalidated_in_decode():
    """An invalid tap list assigned directly to the config is caught on use."""
    seq_len = 4
    model = build(tap_layers=[1, 2])
    # Assign to the config attribute directly, bypassing the validating
    # setter, so the only remaining check is the one done at decode time.
    model.config.eagle_tap_layers = [True]
    tokens = torch.randint(0, 128, (1, seq_len))
    pos = torch.arange(seq_len)

    with pytest.raises(ValueError):
        model.forward_logits_taps(tokens, pos, last_logits_only=False)


if __name__ == "__main__":
    # Allow running this module directly; pytest's return value becomes the
    # process exit status.
    exit_status = pytest.main([__file__, "-q"])
    raise SystemExit(exit_status)
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ testpaths =
examples/models/llama3_2_vision/text_decoder/test
examples/models/llava/test
examples/models/eagle3/test_draft.py
examples/models/gemma4_31b/test_eagle_tap.py

# exir
exir/
Expand Down
Loading