980 changes: 950 additions & 30 deletions demos/BERT.ipynb

Large diffs are not rendered by default.

171 changes: 171 additions & 0 deletions tests/integration/test_hooked_encoder_properties.py
@@ -0,0 +1,171 @@
"""Convenience-property tests for ``HookedEncoder``.

Closes the last open ask in #277 — verify each ``W_*`` / ``b_*`` / circuit
property has the right shape AND aliases the right underlying parameter, so
property-level mech-interp work doesn't silently read the wrong tensor.

Uses a randomly-initialized small encoder (no HF download) so the tests run
fast and deterministically.
"""

from __future__ import annotations

import pytest
import torch

from transformer_lens import FactoredMatrix, HookedEncoder, HookedTransformerConfig

D_MODEL = 12
D_HEAD = 4
N_HEADS = D_MODEL // D_HEAD
D_MLP = 4 * D_MODEL
N_CTX = 5
N_LAYERS = 3
D_VOCAB = 22


@pytest.fixture
def model() -> HookedEncoder:
cfg = HookedTransformerConfig(
d_head=D_HEAD,
d_model=D_MODEL,
n_ctx=N_CTX,
n_layers=N_LAYERS,
act_fn="gelu",
d_vocab=D_VOCAB,
)
encoder = HookedEncoder(cfg)
# HookedEncoder allocates params with torch.empty() and does no init pass; the
# uninitialized memory can contain NaNs, which break torch.equal comparisons.
torch.manual_seed(0)
for p in encoder.parameters():
torch.nn.init.normal_(p, std=0.02)
return encoder


# ---------------------------------------------------------------------------
# Embed / unembed
# ---------------------------------------------------------------------------


def test_W_U(model: HookedEncoder):
assert model.W_U.shape == (D_MODEL, D_VOCAB)
assert model.W_U is model.unembed.W_U


def test_b_U(model: HookedEncoder):
assert model.b_U.shape == (D_VOCAB,)
assert model.b_U is model.unembed.b_U


def test_W_E(model: HookedEncoder):
assert model.W_E.shape == (D_VOCAB, D_MODEL)
assert model.W_E is model.embed.embed.W_E


def test_W_pos(model: HookedEncoder):
assert model.W_pos.shape == (N_CTX, D_MODEL)
assert model.W_pos is model.embed.pos_embed.W_pos


@pytest.mark.xfail(
reason=(
"HookedEncoder.W_E_pos return annotation 'd_vocab+n_ctx d_model' references "
"unbound dimension names (no input args supply them), so the jaxtyping import-hook "
"can't resolve the sum at runtime. Same annotation exists on HookedTransformer.W_E_pos; "
"fixing it is a separate API-touch."
),
strict=True,
)
def test_W_E_pos(model: HookedEncoder):
assert model.W_E_pos.shape == (D_VOCAB + N_CTX, D_MODEL)
# Concatenation, so identity doesn't apply — verify the slices match.
assert torch.equal(model.W_E_pos[:D_VOCAB], model.W_E)
assert torch.equal(model.W_E_pos[D_VOCAB:], model.W_pos)


# ---------------------------------------------------------------------------
# Per-layer attention weights/biases — stacked across blocks
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("attr", ["W_Q", "W_K", "W_V"])
def test_attn_qkv_weight(model: HookedEncoder, attr: str):
stacked = getattr(model, attr)
assert stacked.shape == (N_LAYERS, N_HEADS, D_MODEL, D_HEAD)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(stacked[layer_idx], getattr(block.attn, attr))


def test_W_O(model: HookedEncoder):
assert model.W_O.shape == (N_LAYERS, N_HEADS, D_HEAD, D_MODEL)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(model.W_O[layer_idx], block.attn.W_O)


@pytest.mark.parametrize("attr", ["b_Q", "b_K", "b_V"])
def test_attn_qkv_bias(model: HookedEncoder, attr: str):
stacked = getattr(model, attr)
assert stacked.shape == (N_LAYERS, N_HEADS, D_HEAD)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(stacked[layer_idx], getattr(block.attn, attr))


def test_b_O(model: HookedEncoder):
assert model.b_O.shape == (N_LAYERS, D_MODEL)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(model.b_O[layer_idx], block.attn.b_O)


# ---------------------------------------------------------------------------
# Per-layer MLP weights/biases — stacked across blocks
# ---------------------------------------------------------------------------


def test_W_in(model: HookedEncoder):
assert model.W_in.shape == (N_LAYERS, D_MODEL, D_MLP)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(model.W_in[layer_idx], block.mlp.W_in)


def test_W_out(model: HookedEncoder):
assert model.W_out.shape == (N_LAYERS, D_MLP, D_MODEL)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(model.W_out[layer_idx], block.mlp.W_out)


def test_b_in(model: HookedEncoder):
assert model.b_in.shape == (N_LAYERS, D_MLP)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(model.b_in[layer_idx], block.mlp.b_in)


def test_b_out(model: HookedEncoder):
assert model.b_out.shape == (N_LAYERS, D_MODEL)
for layer_idx, block in enumerate(model.blocks):
assert torch.equal(model.b_out[layer_idx], block.mlp.b_out)


# ---------------------------------------------------------------------------
# Factored circuits
# ---------------------------------------------------------------------------


def test_QK_circuit(model: HookedEncoder):
qk = model.QK
assert isinstance(qk, FactoredMatrix)
# Left factor is W_Q [..., d_model, d_head]; right factor is W_K transposed
# to [..., d_head, d_model]. Their product would be [..., d_model, d_model].
assert qk.A.shape == (N_LAYERS, N_HEADS, D_MODEL, D_HEAD)
assert qk.B.shape == (N_LAYERS, N_HEADS, D_HEAD, D_MODEL)
assert torch.equal(qk.A, model.W_Q)
assert torch.equal(qk.B, model.W_K.transpose(-2, -1))


def test_OV_circuit(model: HookedEncoder):
ov = model.OV
assert isinstance(ov, FactoredMatrix)
assert ov.A.shape == (N_LAYERS, N_HEADS, D_MODEL, D_HEAD)
assert ov.B.shape == (N_LAYERS, N_HEADS, D_HEAD, D_MODEL)
assert torch.equal(ov.A, model.W_V)
assert torch.equal(ov.B, model.W_O)
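
To make concrete what these property tests pin down, here is an illustrative sketch (not code from this PR, and not HookedEncoder's actual property implementation) that rebuilds two of the checked properties by hand from the per-block parameters, using the same config values as the fixture above:

# Illustrative only: the convenience properties behave as if built by stacking the
# per-block attention parameters, and QK is a FactoredMatrix pair over those stacks.
import torch
from transformer_lens import FactoredMatrix, HookedEncoder, HookedTransformerConfig

cfg = HookedTransformerConfig(
    d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu", d_vocab=22
)
model = HookedEncoder(cfg)
torch.manual_seed(0)
for p in model.parameters():
    torch.nn.init.normal_(p, std=0.02)

# Hand-stacked W_Q matches the convenience property layer by layer.
w_q_manual = torch.stack([block.attn.W_Q for block in model.blocks], dim=0)
assert torch.equal(w_q_manual, model.W_Q)  # [n_layers, n_heads, d_model, d_head]

# QK is a low-rank FactoredMatrix; .AB materializes the [d_model, d_model] product.
assert isinstance(model.QK, FactoredMatrix)
assert model.QK.AB.shape == (3, 3, 12, 12)
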
40 changes: 40 additions & 0 deletions tests/unit/components/mlps/test_gated_mlp.py
@@ -3,6 +3,7 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer_lens.components import GatedMLP, LayerNorm
from transformer_lens.utils import solu
@@ -39,3 +40,42 @@ def test_forward(cfg: Dict[str, Any]):
x = torch.randn(2, 10, cfg["d_model"])
output = model(x)
assert output.shape == (2, 10, cfg["d_model"])


def test_forward_matches_reference_equation():
"""Numeric equivalence vs a hand-rolled gated-MLP reference (issue #264).

Closes the original ask in the thread: build an "equivalent gated MLP in
pytorch" and confirm the component matches it under ``torch.allclose``.
Uses ``silu`` so the LN-activation branch is not exercised — that keeps the
reference equation to the documented form.
"""
cfg: Dict[str, Any] = {
"n_layers": 1,
"n_ctx": 16,
"d_head": 32,
"d_model": 64,
"d_mlp": 128,
"dtype": torch.float32,
"act_fn": "silu",
"normalization_type": None,
"load_in_4bit": False,
}
torch.manual_seed(0)
model = GatedMLP(cfg).eval()
# Randomize the params so the test isn't run against zero-bias defaults.
for p in model.parameters():
torch.nn.init.normal_(p, std=0.02)

x = torch.randn(2, 5, cfg["d_model"])
actual = model(x)

# Reference: mlp_out = (silu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out.
# GatedMLP uses F.linear with .T.contiguous() to match HF accumulation order;
# mirror that here so both compute graphs accumulate identically and the fp32
# outputs agree within the tight tolerance below.
pre_act = F.linear(x, model.W_gate.T.contiguous())
pre_linear = F.linear(x, model.W_in.T.contiguous())
post_act = F.silu(pre_act) * pre_linear + model.b_in
expected = F.linear(post_act, model.W_out.T.contiguous(), model.b_out)

assert torch.allclose(actual, expected, atol=1e-6)
6 changes: 6 additions & 0 deletions transformer_lens/config/HookedTransformerConfig.py
@@ -151,6 +151,12 @@ class HookedTransformerConfig(TransformerLensConfig):
use_hook_tokens (bool): Will add a hook point on the token input to
HookedTransformer.forward, which lets you cache or intervene on the tokens.
Defaults to False.
gated_mlp (bool): If True, the MLP layer uses a gated formulation
(SwiGLU/GeGLU-style): ``mlp_out = W_out @ (act_fn(W_gate @ x) * (W_in @ x))``,
with an extra ``W_gate`` weight matrix alongside ``W_in`` and ``W_out``. Used by
LLaMA, Mistral, Gemma, Qwen and similar families. When False (default), the MLP
is the plain ``mlp_out = W_out @ act_fn(W_in @ x)`` form. ``loading_from_pretrained``
sets this automatically per architecture; set it manually only for custom configs.
default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
methods of HookedTransformer process input text to tokenize (only when input is a string).
Defaults to True - even for models not explicitly trained with this, heads often use the
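
For orientation, a minimal usage sketch (illustrative, not part of this PR) of opting a hand-built config into the gated form the docstring above describes; it assumes nothing beyond ``gated_mlp`` being a regular config field:

from transformer_lens import HookedTransformer, HookedTransformerConfig

# Hypothetical custom config: gated_mlp=True gives each MLP a W_gate alongside
# W_in / W_out, so mlp_out = W_out @ (silu(W_gate @ x) * (W_in @ x)).
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    d_head=32,
    n_ctx=64,
    d_vocab=1000,
    act_fn="silu",
    gated_mlp=True,
)
model = HookedTransformer(cfg)  # randomly initialized; no pretrained weights involved
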
1 change: 1 addition & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
@@ -116,6 +116,7 @@
"Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
"StableLmForCausalLM": StableLmArchitectureAdapter,
"T5ForConditionalGeneration": T5ArchitectureAdapter,
"MT5ForConditionalGeneration": T5ArchitectureAdapter,
"XGLMForCausalLM": XGLMArchitectureAdapter,
"NanoGPTForCausalLM": NanogptArchitectureAdapter,
"MinGPTForCausalLM": MingptArchitectureAdapter,
17 changes: 17 additions & 0 deletions transformer_lens/model_bridge/generalized_components/attention.py
@@ -58,6 +58,8 @@ def __init__(
requires_position_embeddings: bool = False,
requires_attention_mask: bool = False,
attention_mask_4d: bool = False,
requires_relative_position_bias: bool = False,
is_cross_attention: bool = False,
optional: bool = False,
):
"""Initialize the attention bridge.
@@ -78,6 +80,9 @@
(e.g., GPTNeoX/Pythia). Defaults to False.
attention_mask_4d: If True, generate 4D attention_mask [batch, 1, tgt_len, src_len]
instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
requires_relative_position_bias: If True, the layer uses T5/mT5-style relative
attention; random inputs include a zero ``position_bias`` so HF's forward skips
its ``cache_position[-1]`` fallback. Defaults to False.
is_cross_attention: If True, the layer is encoder-decoder cross-attention and
random inputs include ``key_value_states``. Defaults to False.
"""
if conversion_rule is None:
conversion_rule = AttentionAutoConversion(config)
Expand Down Expand Up @@ -122,6 +127,8 @@ def __init__(
self.requires_position_embeddings = requires_position_embeddings
self.requires_attention_mask = requires_attention_mask
self.attention_mask_4d = attention_mask_4d
self.requires_relative_position_bias = requires_relative_position_bias
self.is_cross_attention = is_cross_attention
self._layer_idx: Optional[int] = None

def set_original_component(self, original_component: torch.nn.Module) -> None:
@@ -212,6 +219,16 @@ def get_random_inputs(
else:
# Generate 2D attention mask [batch, seq_len] for most models
inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device)
if self.requires_relative_position_bias:
# A zero position_bias makes HF's T5Attention skip the cache_position[-1] fallback it takes when position_bias is None.
n_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 1
inputs["position_bias"] = torch.zeros(
1, n_heads, seq_len, seq_len, device=device, dtype=dtype
)
if self.is_cross_attention:
inputs["key_value_states"] = torch.randn(
batch_size, seq_len, d_model, device=device, dtype=dtype
)
return inputs

def _setup_qkv_hook_reshaping(self) -> None:
1 change: 1 addition & 0 deletions transformer_lens/model_bridge/sources/transformers.py
@@ -232,6 +232,7 @@ def determine_architecture_from_hf_config(hf_config):
"openelm": "OpenELMForCausalLM",
"stablelm": "StableLmForCausalLM",
"t5": "T5ForConditionalGeneration",
"mt5": "MT5ForConditionalGeneration",
}
if model_type in model_type_mappings:
architectures.append(model_type_mappings[model_type])
4 changes: 4 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/t5.py
@@ -113,6 +113,7 @@ def __init__(self, cfg: Any) -> None:
"v": LinearBridge(name="v"),
"o": LinearBridge(name="o"),
},
requires_relative_position_bias=True,
),
"ln2": RMSNormalizationBridge(name="layer.1.layer_norm", config=self.cfg),
"mlp": encoder_mlp,
@@ -142,6 +143,7 @@ def __init__(self, cfg: Any) -> None:
"v": LinearBridge(name="v"),
"o": LinearBridge(name="o"),
},
requires_relative_position_bias=True,
),
"ln2": RMSNormalizationBridge(name="layer.1.layer_norm", config=self.cfg),
"cross_attn": AttentionBridge(
@@ -153,6 +155,8 @@ def __init__(self, cfg: Any) -> None:
"v": LinearBridge(name="v"),
"o": LinearBridge(name="o"),
},
requires_relative_position_bias=True,
is_cross_attention=True,
),
"ln3": RMSNormalizationBridge(name="layer.2.layer_norm", config=self.cfg),
"mlp": decoder_mlp,
1 change: 1 addition & 0 deletions transformer_lens/tools/model_registry/__init__.py
@@ -89,6 +89,7 @@
"Qwen3_5ForCausalLM",
"StableLmForCausalLM",
"T5ForConditionalGeneration",
"MT5ForConditionalGeneration",
"XGLMForCausalLM",
}
