From e1d33de681aff1ff04e0a03ca986137b3b4fb166 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 7 Apr 2026 16:20:06 -0700 Subject: [PATCH 1/9] Add RoPE utilities for TriAttention scoring Signed-off-by: Kai Xu --- modelopt/torch/sparsity/kv_cache/__init__.py | 18 +++ .../kv_cache/triattention/__init__.py | 27 ++++ .../kv_cache/triattention/rope_utils.py | 141 ++++++++++++++++++ .../unit/torch/sparsity/kv_cache/__init__.py | 15 ++ .../sparsity/kv_cache/test_rope_utils.py | 99 ++++++++++++ 5 files changed, 300 insertions(+) create mode 100644 modelopt/torch/sparsity/kv_cache/__init__.py create mode 100644 modelopt/torch/sparsity/kv_cache/triattention/__init__.py create mode 100644 modelopt/torch/sparsity/kv_cache/triattention/rope_utils.py create mode 100644 tests/unit/torch/sparsity/kv_cache/__init__.py create mode 100644 tests/unit/torch/sparsity/kv_cache/test_rope_utils.py diff --git a/modelopt/torch/sparsity/kv_cache/__init__.py b/modelopt/torch/sparsity/kv_cache/__init__.py new file mode 100644 index 00000000000..feb0b7e3f78 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/__init__.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""KV cache sparsity algorithms for LLM inference optimization.""" diff --git a/modelopt/torch/sparsity/kv_cache/triattention/__init__.py b/modelopt/torch/sparsity/kv_cache/triattention/__init__.py new file mode 100644 index 00000000000..4aaae7a51d9 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/triattention/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""TriAttention: Trigonometric KV cache compression.""" + +from .rope_utils import build_geometric_offsets, invert_rope, rotate_half, to_complex_pairs + +__all__ = [ + "build_geometric_offsets", + "invert_rope", + "rotate_half", + "to_complex_pairs", +] diff --git a/modelopt/torch/sparsity/kv_cache/triattention/rope_utils.py b/modelopt/torch/sparsity/kv_cache/triattention/rope_utils.py new file mode 100644 index 00000000000..70d7f59e21d --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/triattention/rope_utils.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""RoPE inversion and frequency-domain utilities for TriAttention. + +These functions support the TriAttention scoring algorithm by: +- Inverting RoPE rotations to recover pre-RoPE Q/K representations +- Converting real-valued tensors to complex frequency-domain representations +- Building geometric offset sequences for multi-distance scoring +""" + +from __future__ import annotations + +import torch + +__all__ = [ + "build_geometric_offsets", + "invert_rope", + "rotate_half", + "to_complex_pairs", +] + + +def rotate_half(x: torch.Tensor, *, style: str = "half") -> torch.Tensor: + """Rotate tensor for RoPE. Supports 'half' (front/back) and 'interleaved' (even/odd).""" + if style == "interleaved": + x_even = x[..., ::2] + x_odd = x[..., 1::2] + return torch.stack((-x_odd, x_even), dim=-1).flatten(-2) + d = x.shape[-1] // 2 + x1 = x[..., :d] + x2 = x[..., d:] + return torch.cat((-x2, x1), dim=-1) + + +def invert_rope( + rotated: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + scale: float, + *, + style: str = "half", +) -> torch.Tensor: + """Invert RoPE rotation to recover pre-RoPE representation. + + Args: + rotated: RoPE-rotated tensor. + cos: Cosine table from rotary embedding. + sin: Sine table from rotary embedding. + scale: Attention scaling factor applied during RoPE. + style: RoPE pairing style ('half' or 'interleaved'). + + Returns: + Pre-RoPE tensor with RoPE rotation undone. + """ + if scale == 0: + raise ValueError("attention scaling factor must be non-zero") + scale_t = torch.tensor(scale, device=rotated.device, dtype=rotated.dtype) + base = rotated / scale_t + cos_unit = cos / scale_t + sin_unit = sin / scale_t + if style == "interleaved": + even = base[..., ::2] + odd = base[..., 1::2] + cos_even = cos_unit[..., ::2] + cos_odd = cos_unit[..., 1::2] + sin_even = sin_unit[..., ::2] + sin_odd = sin_unit[..., 1::2] + det = cos_even * cos_odd + sin_even * sin_odd + det = det.clamp_min(1e-12) + orig_even = (even * cos_odd + odd * sin_even) / det + orig_odd = (odd * cos_even - even * sin_odd) / det + restored = torch.empty_like(base) + restored[..., ::2] = orig_even + restored[..., 1::2] = orig_odd + return restored + return base * cos_unit - rotate_half(base, style=style) * sin_unit + + +def to_complex_pairs(tensor: torch.Tensor, *, style: str = "half") -> torch.Tensor: + """Convert real tensor to complex representation for frequency analysis. + + Maps head_dim real values to head_dim/2 complex values. For 'half' style: + real part = first half of dimensions, imag part = second half. + + Args: + tensor: Real-valued tensor with even last dimension. + style: RoPE pairing style ('half' or 'interleaved'). + + Returns: + Complex tensor with last dimension halved. + """ + if tensor.size(-1) % 2 != 0: + raise ValueError("Head dimension must be even to form complex pairs") + real_dtype = torch.float32 if tensor.dtype in (torch.bfloat16, torch.float16) else tensor.dtype + tensor_real = tensor.to(dtype=real_dtype) + if style == "interleaved": + real = tensor_real[..., ::2].contiguous() + imag = tensor_real[..., 1::2].contiguous() + return torch.complex(real, imag) + freq_count = tensor.shape[-1] // 2 + real = tensor_real[..., :freq_count].contiguous() + imag = tensor_real[..., freq_count:].contiguous() + return torch.complex(real, imag) + + +def build_geometric_offsets(max_length: int, device: torch.device) -> torch.Tensor: + """Build geometric offset sequence [1, 2, 4, 8, ..., max_length]. + + Used for multi-distance scoring in TriAttention — each offset represents a + future distance at which the key's importance is evaluated. + + Args: + max_length: Maximum offset value (must be >= 1). + device: Device for the output tensor. + + Returns: + 1D float tensor of powers of 2 up to max_length. + """ + if max_length < 1: + raise ValueError("max_length must be >= 1") + offsets: list[float] = [] + value = 1 + while value <= max_length: + offsets.append(float(value)) + value *= 2 + return torch.tensor(offsets, device=device, dtype=torch.float32) diff --git a/tests/unit/torch/sparsity/kv_cache/__init__.py b/tests/unit/torch/sparsity/kv_cache/__init__.py new file mode 100644 index 00000000000..47f1c65a15f --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/torch/sparsity/kv_cache/test_rope_utils.py b/tests/unit/torch/sparsity/kv_cache/test_rope_utils.py new file mode 100644 index 00000000000..a61bee0af76 --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/test_rope_utils.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TriAttention RoPE utilities.""" + +import pytest +import torch + +from modelopt.torch.sparsity.kv_cache.triattention.rope_utils import ( + build_geometric_offsets, + invert_rope, + rotate_half, + to_complex_pairs, +) + + +def test_rotate_half_roundtrip(): + """rotate_half applied twice returns the negated original.""" + x = torch.randn(2, 4, 8) + result = rotate_half(rotate_half(x)) + torch.testing.assert_close(result, -x) + + +def test_invert_rope_recovers_original(): + """Inverting RoPE-rotated tensor recovers the pre-RoPE original.""" + head_dim = 16 + seq_len = 8 + x = torch.randn(1, seq_len, head_dim) + + freqs = torch.arange(head_dim // 2, dtype=torch.float32) / head_dim + positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(-1) + angles = positions * freqs + cos = torch.cos(angles).repeat(1, 2).unsqueeze(0) + sin = torch.sin(angles).repeat(1, 2).unsqueeze(0) + scale = 1.0 + + rotated = x * cos + rotate_half(x) * sin + recovered = invert_rope(rotated * scale, cos, sin, scale) + torch.testing.assert_close(recovered, x, atol=1e-5, rtol=1e-5) + + +def test_invert_rope_zero_scale_raises(): + """Zero scale raises ValueError.""" + with pytest.raises(ValueError, match="non-zero"): + invert_rope(torch.randn(1, 4, 8), torch.ones(1, 4, 8), torch.ones(1, 4, 8), 0.0) + + +def test_to_complex_pairs_shape(): + """Complex pairs halves the last dimension.""" + c = to_complex_pairs(torch.randn(4, 16)) + assert c.shape == (4, 8) + assert c.dtype == torch.complex64 + + +def test_to_complex_pairs_values(): + """Half style: real = first half, imag = second half.""" + x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]) + c = to_complex_pairs(x) + assert c[0, 0].real == 1.0 + assert c[0, 0].imag == 5.0 + assert c[0, 3].real == 4.0 + assert c[0, 3].imag == 8.0 + + +def test_to_complex_pairs_odd_dim_raises(): + """Odd head dimension raises ValueError.""" + with pytest.raises(ValueError, match="even"): + to_complex_pairs(torch.randn(4, 7)) + + +def test_build_geometric_offsets(): + """Geometric offsets are powers of 2 up to max_length.""" + offsets = build_geometric_offsets(16, torch.device("cpu")) + expected = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0]) + torch.testing.assert_close(offsets, expected) + + +def test_build_geometric_offsets_single(): + """Single offset when max_length is 1.""" + offsets = build_geometric_offsets(1, torch.device("cpu")) + torch.testing.assert_close(offsets, torch.tensor([1.0])) + + +def test_build_geometric_offsets_zero_raises(): + """max_length < 1 raises ValueError.""" + with pytest.raises(ValueError, match="must be >= 1"): + build_geometric_offsets(0, torch.device("cpu")) From f1fbc088bf3bd89cff48f32ea3c8b184d596fa44 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 7 Apr 2026 23:16:42 -0700 Subject: [PATCH 2/9] Add trigonometric scoring algorithm Signed-off-by: Kai Xu --- .../sparsity/kv_cache/triattention/scoring.py | 130 ++++++++++++ .../torch/sparsity/kv_cache/test_scoring.py | 186 ++++++++++++++++++ 2 files changed, 316 insertions(+) create mode 100644 modelopt/torch/sparsity/kv_cache/triattention/scoring.py create mode 100644 tests/unit/torch/sparsity/kv_cache/test_scoring.py diff --git a/modelopt/torch/sparsity/kv_cache/triattention/scoring.py b/modelopt/torch/sparsity/kv_cache/triattention/scoring.py new file mode 100644 index 00000000000..e91d00b2ae2 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/triattention/scoring.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trigonometric scoring for TriAttention KV cache compression. + +Scores cached keys by predicted future attention importance using a trigonometric +series derived from pre-RoPE Q/K concentration. See arXiv:2604.04921. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from .rope_utils import to_complex_pairs + +__all__ = [ + "HeadFrequencyStats", + "compute_frequency_statistics_from_means", + "score_keys_for_round", +] + + +@dataclass +class HeadFrequencyStats: + """Per-head calibration statistics in frequency domain.""" + + q_mean_complex: torch.Tensor # (freq_count,) complex64 + q_abs_mean: torch.Tensor # (freq_count,) float32 + + +def compute_frequency_statistics_from_means( + q_mean_complex: torch.Tensor, + q_abs_mean: torch.Tensor, + k_unrot: torch.Tensor, + *, + style: str = "half", + disable_mlr: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute amplitude, phase, and MLR extra term from Q/K frequency statistics. + + Args: + q_mean_complex: Mean of Q in complex frequency domain, shape (freq_count,). + q_abs_mean: Mean of |Q| in frequency domain, shape (freq_count,). + k_unrot: Unrotated key vectors, shape (num_keys, head_dim). + style: RoPE pairing style. + disable_mlr: If True, use q_abs_mean directly instead of (q_abs_mean - |q_mean|). + + Returns: + amp: Amplitude, shape (num_keys, freq_count). + phi: Phase, shape (num_keys, freq_count). + extra: MLR extra term, shape (num_keys, freq_count). + """ + k_complex = to_complex_pairs(k_unrot, style=style) + q_mean_abs = torch.abs(q_mean_complex) + k_abs = torch.abs(k_complex) + relative = q_mean_complex.unsqueeze(0) * torch.conj(k_complex) + phi = torch.atan2(relative.imag, relative.real) + amp = q_mean_abs.unsqueeze(0) * k_abs + if disable_mlr: + extra = q_abs_mean.unsqueeze(0) * k_abs + else: + extra = (q_abs_mean - q_mean_abs).unsqueeze(0) * k_abs + return amp, phi, extra + + +def score_keys_for_round( + key_indices: torch.Tensor, + round_start: int, + amp: torch.Tensor, + phi: torch.Tensor, + omega: torch.Tensor, + extra: torch.Tensor, + offsets: torch.Tensor, + aggregation: str, + freq_scale_sq: torch.Tensor, + disable_trig: bool = False, +) -> torch.Tensor: + """Score cached keys for a single pruning round. + + Evaluates the trigonometric importance formula over multiple future offsets + and aggregates scores. + + Args: + key_indices: Position indices of cached keys, shape (num_keys,). + round_start: Current generation position. + amp: Amplitude per key per frequency, shape (num_keys, freq_count). + phi: Phase per key per frequency, shape (num_keys, freq_count). + omega: RoPE frequencies (inv_freq), shape (freq_count,). + extra: MLR extra term, shape (num_keys, freq_count). + offsets: Geometric offsets for future distance sampling, shape (num_offsets,). + aggregation: 'mean' or 'max' over offsets. + freq_scale_sq: Per-frequency scaling weights, shape (freq_count,). + disable_trig: If True, use only the additive (MLR) term. + + Returns: + Importance scores, shape (num_keys,). Higher = more important. + """ + if key_indices.numel() == 0: + return torch.empty(0, device=amp.device, dtype=torch.float32) + + base_delta = round_start - key_indices.to(device=amp.device, dtype=torch.float32) + delta_grid = base_delta.unsqueeze(1) + offsets.unsqueeze(0) # (num_keys, num_offsets) + + freq_scale_sq = freq_scale_sq.to(device=amp.device, dtype=torch.float32) + phase = delta_grid.unsqueeze(2) * omega.view(1, 1, -1) + phi.unsqueeze(1) + + cos_phase = torch.cos(phase) + scale = freq_scale_sq.view(1, 1, -1) + base_scores = (amp.unsqueeze(1) * scale * cos_phase).sum(dim=2) + + additive = (extra * freq_scale_sq.view(1, -1)).sum(dim=1, keepdim=True) + combined = additive if disable_trig else (base_scores + additive) + + if aggregation == "mean": + return combined.mean(dim=1) + return combined.max(dim=1).values diff --git a/tests/unit/torch/sparsity/kv_cache/test_scoring.py b/tests/unit/torch/sparsity/kv_cache/test_scoring.py new file mode 100644 index 00000000000..13e89c7e5d1 --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/test_scoring.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TriAttention trigonometric scoring.""" + +import torch + +from modelopt.torch.sparsity.kv_cache.triattention.rope_utils import build_geometric_offsets +from modelopt.torch.sparsity.kv_cache.triattention.scoring import ( + HeadFrequencyStats, + compute_frequency_statistics_from_means, + score_keys_for_round, +) + + +def test_compute_frequency_statistics_shapes(): + """Frequency statistics have correct shapes.""" + freq_count = 8 + seq_len = 16 + head_dim = freq_count * 2 + + q_mean_complex = torch.randn(freq_count, dtype=torch.complex64) + q_abs_mean = torch.rand(freq_count) + k_unrot = torch.randn(seq_len, head_dim) + + amp, phi, extra = compute_frequency_statistics_from_means(q_mean_complex, q_abs_mean, k_unrot) + assert amp.shape == (seq_len, freq_count) + assert phi.shape == (seq_len, freq_count) + assert extra.shape == (seq_len, freq_count) + + +def test_compute_frequency_statistics_amplitude_positive(): + """Amplitude is non-negative (product of absolute values).""" + freq_count = 4 + q_mean_complex = torch.randn(freq_count, dtype=torch.complex64) + q_abs_mean = torch.rand(freq_count).abs() + 0.1 + k_unrot = torch.randn(8, freq_count * 2) + + amp, _, _ = compute_frequency_statistics_from_means(q_mean_complex, q_abs_mean, k_unrot) + assert (amp >= 0).all() + + +def test_compute_frequency_statistics_disable_mlr(): + """With disable_mlr=True, extra uses q_abs_mean directly.""" + freq_count = 4 + q_mean_complex = torch.randn(freq_count, dtype=torch.complex64) + q_abs_mean = torch.rand(freq_count) + 1.0 + k_unrot = torch.randn(8, freq_count * 2) + + _, _, extra_normal = compute_frequency_statistics_from_means( + q_mean_complex, q_abs_mean, k_unrot, disable_mlr=False + ) + _, _, extra_disabled = compute_frequency_statistics_from_means( + q_mean_complex, q_abs_mean, k_unrot, disable_mlr=True + ) + assert not torch.allclose(extra_normal, extra_disabled) + + +def test_score_keys_for_round_shape(): + """Score output matches number of keys.""" + num_keys = 32 + freq_count = 8 + key_indices = torch.arange(num_keys) + amp = torch.rand(num_keys, freq_count) + phi = torch.rand(num_keys, freq_count) + omega = torch.rand(freq_count, dtype=torch.float64) + extra = torch.rand(num_keys, freq_count) + offsets = build_geometric_offsets(16, torch.device("cpu")) + freq_scale_sq = torch.ones(freq_count) + + scores = score_keys_for_round( + key_indices, + round_start=64, + amp=amp, + phi=phi, + omega=omega, + extra=extra, + offsets=offsets, + aggregation="mean", + freq_scale_sq=freq_scale_sq, + ) + assert scores.shape == (num_keys,) + + +def test_score_keys_empty(): + """Empty key set returns empty scores.""" + scores = score_keys_for_round( + key_indices=torch.tensor([], dtype=torch.long), + round_start=100, + amp=torch.empty(0, 4), + phi=torch.empty(0, 4), + omega=torch.rand(4, dtype=torch.float64), + extra=torch.empty(0, 4), + offsets=build_geometric_offsets(16, torch.device("cpu")), + aggregation="mean", + freq_scale_sq=torch.ones(4), + ) + assert scores.numel() == 0 + + +def test_score_aggregation_mean_vs_max(): + """Mean and max aggregation produce different results.""" + num_keys = 10 + freq_count = 4 + key_indices = torch.arange(num_keys) + amp = torch.rand(num_keys, freq_count) + phi = torch.rand(num_keys, freq_count) + omega = torch.rand(freq_count, dtype=torch.float64) + extra = torch.rand(num_keys, freq_count) + offsets = build_geometric_offsets(16, torch.device("cpu")) + freq_scale_sq = torch.ones(freq_count) + + scores_mean = score_keys_for_round( + key_indices, + 50, + amp, + phi, + omega, + extra, + offsets, + "mean", + freq_scale_sq, + ) + scores_max = score_keys_for_round( + key_indices, + 50, + amp, + phi, + omega, + extra, + offsets, + "max", + freq_scale_sq, + ) + assert not torch.allclose(scores_mean, scores_max) + + +def test_score_keys_disable_trig(): + """With disable_trig=True, scores are position-independent (additive only).""" + freq_count = 4 + # Two keys at very different positions + key_indices = torch.tensor([0, 99]) + # Large amplitude so trig term dominates when enabled + amp = torch.ones(2, freq_count) * 10.0 + phi = torch.zeros(2, freq_count) + omega = torch.tensor([0.1, 0.5, 1.0, 2.0], dtype=torch.float64) + extra = torch.ones(2, freq_count) + offsets = build_geometric_offsets(16, torch.device("cpu")) + freq_scale_sq = torch.ones(freq_count) + + scores_no_trig = score_keys_for_round( + key_indices, + 100, + amp, + phi, + omega, + extra, + offsets, + "mean", + freq_scale_sq, + disable_trig=True, + ) + # Without trig, both keys get the same score (additive term is position-independent) + torch.testing.assert_close(scores_no_trig[0], scores_no_trig[1]) + + +def test_head_frequency_stats_dataclass(): + """HeadFrequencyStats holds correct fields.""" + stats = HeadFrequencyStats( + q_mean_complex=torch.randn(8, dtype=torch.complex64), + q_abs_mean=torch.rand(8), + ) + assert stats.q_mean_complex.shape == (8,) + assert stats.q_abs_mean.shape == (8,) From 05e90444d4eb5def86536ee1a612db1e4d9e36ab Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 7 Apr 2026 23:21:25 -0700 Subject: [PATCH 3/9] Add TriAttention config Signed-off-by: Kai Xu --- modelopt/torch/sparsity/kv_cache/config.py | 109 ++++++++++++++++++ .../torch/sparsity/kv_cache/test_config.py | 76 ++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 modelopt/torch/sparsity/kv_cache/config.py create mode 100644 tests/unit/torch/sparsity/kv_cache/test_config.py diff --git a/modelopt/torch/sparsity/kv_cache/config.py b/modelopt/torch/sparsity/kv_cache/config.py new file mode 100644 index 00000000000..46c97aeafc7 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/config.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for KV cache sparsity modes.""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import field_validator + +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField + +__all__ = ["TriAttentionConfig"] + + +class TriAttentionConfig(ModeloptBaseConfig): + """Configuration for TriAttention KV cache eviction. + + TriAttention scores cached KV entries using a trigonometric model derived from + pre-RoPE Q/K concentration. Calibration computes per-head frequency statistics; + at runtime, the serving engine scores and evicts tokens periodically. + """ + + # Budget + budget: int = ModeloptField( + default=2048, + title="KV token budget.", + description="Number of KV tokens to retain per head after pruning.", + ) + + # Pruning schedule + prune_interval: int = ModeloptField( + default=128, + title="Pruning interval.", + description="Re-score and evict every N generated tokens.", + ) + window_size: int = ModeloptField( + default=128, + title="Protected window size.", + description="Number of most recent tokens always retained.", + ) + + # Scoring + pruning_mode: Literal["per_head", "per_layer_per_head"] = ModeloptField( + default="per_head", + title="Pruning mode.", + description=( + "'per_head': independent budget per KV head. " + "'per_layer_per_head': budget allocated per layer and head." + ), + ) + score_aggregation: Literal["mean", "max"] = ModeloptField( + default="mean", + title="Offset score aggregation.", + description="How to aggregate scores across geometric offsets.", + ) + offset_max_length: int = ModeloptField( + default=65536, + title="Maximum geometric offset.", + description="Offsets are [1, 2, 4, ..., offset_max_length].", + ) + disable_mlr: bool = ModeloptField( + default=False, + title="Disable MLR term.", + description="If True, disable the magnitude linear regression extra term.", + ) + disable_trig: bool = ModeloptField( + default=False, + title="Disable trigonometric term.", + description="If True, use only the additive (MLR) term for scoring.", + ) + + # Calibration + calib_size: int = ModeloptField( + default=100000, + title="Calibration tokens.", + description="Number of tokens for calibration. 50K-960K, any domain.", + ) + + @field_validator("pruning_mode") + @classmethod + def validate_pruning_mode(cls, v: str) -> str: + """Validate pruning_mode is a supported value.""" + valid = {"per_head", "per_layer_per_head"} + if v not in valid: + raise ValueError(f"pruning_mode must be one of {valid}, got '{v}'") + return v + + @field_validator("score_aggregation") + @classmethod + def validate_score_aggregation(cls, v: str) -> str: + """Validate score_aggregation is a supported value.""" + valid = {"mean", "max"} + if v not in valid: + raise ValueError(f"score_aggregation must be one of {valid}, got '{v}'") + return v diff --git a/tests/unit/torch/sparsity/kv_cache/test_config.py b/tests/unit/torch/sparsity/kv_cache/test_config.py new file mode 100644 index 00000000000..bed66a9b4a4 --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/test_config.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TriAttention configuration.""" + +import pytest +from pydantic import ValidationError + +from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig + + +def test_default_config(): + """Default config creates valid instance.""" + config = TriAttentionConfig() + assert config.budget == 2048 + assert config.prune_interval == 128 + assert config.window_size == 128 + assert config.pruning_mode == "per_head" + assert config.score_aggregation == "mean" + assert config.offset_max_length == 65536 + assert config.disable_mlr is False + assert config.disable_trig is False + assert config.calib_size == 100000 + + +def test_config_custom_values(): + """Config accepts custom values.""" + config = TriAttentionConfig(budget=4096, prune_interval=64, window_size=256) + assert config.budget == 4096 + assert config.prune_interval == 64 + assert config.window_size == 256 + + +def test_config_invalid_pruning_mode(): + """Invalid pruning mode raises validation error.""" + with pytest.raises(ValidationError): + TriAttentionConfig(pruning_mode="invalid") + + +def test_config_invalid_aggregation(): + """Invalid score aggregation raises validation error.""" + with pytest.raises(ValidationError): + TriAttentionConfig(score_aggregation="invalid") + + +def test_config_serialization_roundtrip(): + """Config can be serialized and deserialized.""" + config = TriAttentionConfig(budget=1024, prune_interval=64) + data = config.model_dump() + restored = TriAttentionConfig(**data) + assert restored.budget == 1024 + assert restored.prune_interval == 64 + + +def test_config_per_layer_per_head_mode(): + """per_layer_per_head is a valid pruning mode.""" + config = TriAttentionConfig(pruning_mode="per_layer_per_head") + assert config.pruning_mode == "per_layer_per_head" + + +def test_config_max_aggregation(): + """max is a valid score aggregation.""" + config = TriAttentionConfig(score_aggregation="max") + assert config.score_aggregation == "max" From 85499dfa3e4cf264dae35becec5afe9079a0db93 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 8 Apr 2026 11:04:55 -0700 Subject: [PATCH 4/9] Add TriAttention calibration module Signed-off-by: Kai Xu --- .../kv_cache/triattention/calibration.py | 120 +++++++++++++++++ .../sparsity/kv_cache/test_calibration.py | 127 ++++++++++++++++++ 2 files changed, 247 insertions(+) create mode 100644 modelopt/torch/sparsity/kv_cache/triattention/calibration.py create mode 100644 tests/unit/torch/sparsity/kv_cache/test_calibration.py diff --git a/modelopt/torch/sparsity/kv_cache/triattention/calibration.py b/modelopt/torch/sparsity/kv_cache/triattention/calibration.py new file mode 100644 index 00000000000..89185390ee0 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/triattention/calibration.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration for TriAttention: compute per-head Q/K frequency statistics. + +Hooks into attention layers during a forward pass, captures pre-RoPE Q vectors, +inverts RoPE, converts to frequency domain, and computes per-head mean statistics. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from .rope_utils import to_complex_pairs +from .scoring import HeadFrequencyStats + +__all__ = [ + "CalibrationData", + "compute_head_stats_from_q", +] + + +@dataclass +class CalibrationData: + """Container for TriAttention calibration output. + + Stores per-head frequency statistics computed during calibration, along with + model metadata needed for scoring at inference time. + """ + + head_stats: dict[tuple[int, int], HeadFrequencyStats] # (layer, head) -> stats + head_dim: int + rope_style: str + num_layers: int + num_kv_heads: int + + def state_dict(self) -> dict[str, Any]: + """Serialize to state dict for checkpoint embedding.""" + stats_serialized = {} + for (layer, head), hs in self.head_stats.items(): + key = f"layer{layer:02d}_head{head:02d}" + stats_serialized[key] = { + "q_mean_real": hs.q_mean_complex.real.cpu(), + "q_mean_imag": hs.q_mean_complex.imag.cpu(), + "q_abs_mean": hs.q_abs_mean.cpu(), + } + return { + "metadata": { + "head_dim": self.head_dim, + "rope_style": self.rope_style, + "num_layers": self.num_layers, + "num_kv_heads": self.num_kv_heads, + "sampled_heads": [[layer, head] for layer, head in self.head_stats], + }, + "stats": stats_serialized, + } + + @classmethod + def from_state_dict(cls, state: dict[str, Any]) -> CalibrationData: + """Deserialize from state dict.""" + metadata = state["metadata"] + stats_raw = state["stats"] + sampled_heads = [tuple(pair) for pair in metadata["sampled_heads"]] + head_stats: dict[tuple[int, int], HeadFrequencyStats] = {} + for layer, head in sampled_heads: + key = f"layer{layer:02d}_head{head:02d}" + entry = stats_raw[key] + q_mean_complex = torch.complex( + entry["q_mean_real"].to(torch.float32), + entry["q_mean_imag"].to(torch.float32), + ) + q_abs_mean = entry["q_abs_mean"].to(torch.float32) + head_stats[(int(layer), int(head))] = HeadFrequencyStats( + q_mean_complex=q_mean_complex, + q_abs_mean=q_abs_mean, + ) + return cls( + head_stats=head_stats, + head_dim=metadata["head_dim"], + rope_style=metadata["rope_style"], + num_layers=metadata["num_layers"], + num_kv_heads=metadata["num_kv_heads"], + ) + + +def compute_head_stats_from_q( + q_pre_rope: torch.Tensor, + style: str = "half", +) -> HeadFrequencyStats: + """Compute frequency statistics for a single head from pre-RoPE Q vectors. + + Args: + q_pre_rope: Pre-RoPE query vectors for one head, shape (seq_len, head_dim). + style: RoPE pairing style. + + Returns: + HeadFrequencyStats with q_mean_complex and q_abs_mean. + """ + q_complex = to_complex_pairs(q_pre_rope, style=style) # (seq_len, freq_count) + q_mean_complex = q_complex.mean(dim=0) # (freq_count,) + q_abs_mean = q_complex.abs().mean(dim=0) # (freq_count,) + return HeadFrequencyStats( + q_mean_complex=q_mean_complex, + q_abs_mean=q_abs_mean, + ) diff --git a/tests/unit/torch/sparsity/kv_cache/test_calibration.py b/tests/unit/torch/sparsity/kv_cache/test_calibration.py new file mode 100644 index 00000000000..6b598a522f2 --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/test_calibration.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TriAttention calibration.""" + +import torch + +from modelopt.torch.sparsity.kv_cache.triattention.calibration import ( + CalibrationData, + compute_head_stats_from_q, +) +from modelopt.torch.sparsity.kv_cache.triattention.scoring import HeadFrequencyStats + + +def test_compute_head_stats_shapes(): + """Head stats computed from Q tensor have correct shapes.""" + seq_len = 64 + head_dim = 16 + freq_count = head_dim // 2 + + q_pre_rope = torch.randn(seq_len, head_dim) + stats = compute_head_stats_from_q(q_pre_rope) + + assert stats.q_mean_complex.shape == (freq_count,) + assert stats.q_mean_complex.dtype == torch.complex64 + assert stats.q_abs_mean.shape == (freq_count,) + assert stats.q_abs_mean.dtype == torch.float32 + + +def test_compute_head_stats_mean_abs_ge_abs_mean(): + """Mean of absolute values >= absolute value of mean (triangle inequality).""" + q_pre_rope = torch.randn(128, 32) + stats = compute_head_stats_from_q(q_pre_rope) + + abs_of_mean = torch.abs(stats.q_mean_complex) + assert (stats.q_abs_mean >= abs_of_mean - 1e-6).all() + + +def test_compute_head_stats_single_token(): + """Single-token input: mean equals the single value.""" + head_dim = 8 + q_pre_rope = torch.randn(1, head_dim) + stats = compute_head_stats_from_q(q_pre_rope) + + # For single token, mean_complex == the single complex value + # and abs_mean == |single complex value| + torch.testing.assert_close(stats.q_abs_mean, torch.abs(stats.q_mean_complex)) + + +def test_calibration_data_state_dict_roundtrip(): + """CalibrationData can be serialized to and restored from state dict.""" + stats = { + (0, 0): HeadFrequencyStats( + q_mean_complex=torch.randn(8, dtype=torch.complex64), + q_abs_mean=torch.rand(8), + ), + (0, 1): HeadFrequencyStats( + q_mean_complex=torch.randn(8, dtype=torch.complex64), + q_abs_mean=torch.rand(8), + ), + (1, 0): HeadFrequencyStats( + q_mean_complex=torch.randn(8, dtype=torch.complex64), + q_abs_mean=torch.rand(8), + ), + } + calib = CalibrationData( + head_stats=stats, + head_dim=16, + rope_style="half", + num_layers=2, + num_kv_heads=2, + ) + + state = calib.state_dict() + restored = CalibrationData.from_state_dict(state) + + assert restored.head_dim == 16 + assert restored.rope_style == "half" + assert restored.num_layers == 2 + assert restored.num_kv_heads == 2 + assert len(restored.head_stats) == 3 + + for key in stats: + torch.testing.assert_close( + restored.head_stats[key].q_abs_mean, + calib.head_stats[key].q_abs_mean, + ) + torch.testing.assert_close( + restored.head_stats[key].q_mean_complex, + calib.head_stats[key].q_mean_complex, + ) + + +def test_calibration_data_state_dict_keys(): + """State dict has expected structure.""" + stats = { + (2, 3): HeadFrequencyStats( + q_mean_complex=torch.randn(4, dtype=torch.complex64), + q_abs_mean=torch.rand(4), + ), + } + calib = CalibrationData( + head_stats=stats, + head_dim=8, + rope_style="half", + num_layers=4, + num_kv_heads=8, + ) + + state = calib.state_dict() + assert "metadata" in state + assert "stats" in state + assert "layer02_head03" in state["stats"] + assert state["metadata"]["head_dim"] == 8 + assert state["metadata"]["sampled_heads"] == [[2, 3]] From b75dc650383338d1a3f95e2805e3cf8f327a94ed Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 8 Apr 2026 12:09:27 -0700 Subject: [PATCH 5/9] Add model and conversion for TriAttention Signed-off-by: Kai Xu --- modelopt/torch/sparsity/kv_cache/__init__.py | 6 +- .../torch/sparsity/kv_cache/conversion.py | 69 ++++++++++++++ modelopt/torch/sparsity/kv_cache/mode.py | 65 +++++++++++++ .../sparsity/kv_cache/test_conversion.py | 92 +++++++++++++++++++ 4 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 modelopt/torch/sparsity/kv_cache/conversion.py create mode 100644 modelopt/torch/sparsity/kv_cache/mode.py create mode 100644 tests/unit/torch/sparsity/kv_cache/test_conversion.py diff --git a/modelopt/torch/sparsity/kv_cache/__init__.py b/modelopt/torch/sparsity/kv_cache/__init__.py index feb0b7e3f78..84e428f37af 100644 --- a/modelopt/torch/sparsity/kv_cache/__init__.py +++ b/modelopt/torch/sparsity/kv_cache/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-License-Identifier: Apache-2.0 """KV cache sparsity algorithms for LLM inference optimization.""" + +from . import mode +from .config import * +from .conversion import * diff --git a/modelopt/torch/sparsity/kv_cache/conversion.py b/modelopt/torch/sparsity/kv_cache/conversion.py new file mode 100644 index 00000000000..2dbbab39bff --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/conversion.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert/restore/update entrypoints for TriAttention mode. + +TriAttention is a calibration-only mode. Convert is a no-op on model weights. +Calibration data is stored in metadata and fused into the checkpoint at save time. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.nn as nn + + from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict + + from .config import TriAttentionConfig + +__all__ = [ + "convert_triattention", + "restore_triattention", + "update_triattention_metadata", +] + + +def convert_triattention(model: nn.Module, config: TriAttentionConfig) -> ConvertReturnType: + """Apply TriAttention mode to model. + + This is a no-op on model weights. It stores the configuration in metadata + so that calibration can be run subsequently. + """ + metadata = { + "triattention_config": config.model_dump(), + } + return model, metadata + + +def restore_triattention( + model: nn.Module, config: TriAttentionConfig, metadata: MetadataDict +) -> nn.Module: + """Restore TriAttention mode from saved state. + + Loads calibration data from metadata if present. + """ + return model + + +def update_triattention_metadata( + model: nn.Module, config: TriAttentionConfig, metadata: MetadataDict +) -> None: + """Update metadata before saving. + + Ensures calibration data and config are current in metadata. + """ + metadata["triattention_config"] = config.model_dump() diff --git a/modelopt/torch/sparsity/kv_cache/mode.py b/modelopt/torch/sparsity/kv_cache/mode.py new file mode 100644 index 00000000000..ee89745e0d4 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/mode.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mode registration for KV cache sparsity.""" + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) + +from .config import TriAttentionConfig +from .conversion import convert_triattention, restore_triattention, update_triattention_metadata + +KVCacheSparsityRegistry = _ModeRegistryCls("kv_cache_sparsity") + + +@KVCacheSparsityRegistry.register_mode +class TriAttentionModeDescriptor(ModeDescriptor): + """Mode descriptor for TriAttention KV cache sparsity. + + TriAttention is a calibration-only mode: convert is a no-op on model weights, + calibration computes per-head frequency statistics, and the results are stored + in metadata for export to serving engines. + """ + + @property + def name(self) -> str: + """Return the mode name.""" + return "triattention" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Return the configuration class.""" + return TriAttentionConfig + + @property + def convert(self) -> ConvertEntrypoint: + """Return the convert entrypoint.""" + return convert_triattention + + @property + def restore(self) -> RestoreEntrypoint: + """Return the restore entrypoint.""" + return restore_triattention + + @property + def update_for_save(self) -> UpdateEntrypoint: + """Return the update-for-save entrypoint.""" + return update_triattention_metadata diff --git a/tests/unit/torch/sparsity/kv_cache/test_conversion.py b/tests/unit/torch/sparsity/kv_cache/test_conversion.py new file mode 100644 index 00000000000..6a880a7cb84 --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/test_conversion.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TriAttention mode registration and conversion.""" + +import torch +import torch.nn as nn + +from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig +from modelopt.torch.sparsity.kv_cache.conversion import ( + convert_triattention, + restore_triattention, + update_triattention_metadata, +) +from modelopt.torch.sparsity.kv_cache.mode import KVCacheSparsityRegistry + + +def test_mode_registered(): + """TriAttention mode is registered in KVCacheSparsityRegistry.""" + assert "triattention" in KVCacheSparsityRegistry + + +def test_mode_descriptor_properties(): + """Mode descriptor has correct properties.""" + descriptor = KVCacheSparsityRegistry["triattention"] + assert descriptor.name == "triattention" + assert descriptor.config_class is TriAttentionConfig + + +def test_mode_discoverable_globally(): + """TriAttention mode is discoverable via get_from_any.""" + from modelopt.torch.opt.mode import _ModeRegistryCls + + descriptor = _ModeRegistryCls.get_from_any("triattention") + assert descriptor is not None + assert descriptor.name == "triattention" + + +def test_convert_returns_model_and_metadata(): + """Convert returns (model, metadata) without modifying weights.""" + model = nn.Linear(16, 16) + original_weight = model.weight.data.clone() + config = TriAttentionConfig() + + converted_model, metadata = convert_triattention(model, config) + + torch.testing.assert_close(converted_model.weight.data, original_weight) + assert "triattention_config" in metadata + + +def test_convert_metadata_contains_config(): + """Metadata stores the config values.""" + model = nn.Linear(16, 16) + config = TriAttentionConfig(budget=1024, prune_interval=64) + + _, metadata = convert_triattention(model, config) + + assert metadata["triattention_config"]["budget"] == 1024 + assert metadata["triattention_config"]["prune_interval"] == 64 + + +def test_restore_returns_model(): + """Restore returns the model.""" + model = nn.Linear(16, 16) + config = TriAttentionConfig() + _, metadata = convert_triattention(model, config) + + restored = restore_triattention(model, config, metadata) + assert restored is model + + +def test_update_metadata(): + """update_triattention_metadata updates config in metadata.""" + model = nn.Linear(16, 16) + config = TriAttentionConfig(budget=512) + metadata = {} + + update_triattention_metadata(model, config, metadata) + + assert metadata["triattention_config"]["budget"] == 512 From ef40a0fd789f456cb611cbbdd3afe635eec67235 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 8 Apr 2026 12:13:59 -0700 Subject: [PATCH 6/9] Add sparsify() and calibrate() entry API Signed-off-by: Kai Xu --- modelopt/torch/sparsity/__init__.py | 3 + modelopt/torch/sparsity/kv_cache/__init__.py | 1 + .../torch/sparsity/kv_cache/model_sparsify.py | 90 +++++++++++++++++++ .../sparsity/kv_cache/test_model_sparsify.py | 60 +++++++++++++ 4 files changed, 154 insertions(+) create mode 100644 modelopt/torch/sparsity/kv_cache/model_sparsify.py create mode 100644 tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py diff --git a/modelopt/torch/sparsity/__init__.py b/modelopt/torch/sparsity/__init__.py index 2013fded1ae..433db71402c 100644 --- a/modelopt/torch/sparsity/__init__.py +++ b/modelopt/torch/sparsity/__init__.py @@ -22,3 +22,6 @@ # Import weight sparsity for backward compatibility from .weight_sparsity import mode, module, plugins from .weight_sparsity.sparsification import * + +# Import kv_cache to register KV cache sparsity modes +from . import kv_cache diff --git a/modelopt/torch/sparsity/kv_cache/__init__.py b/modelopt/torch/sparsity/kv_cache/__init__.py index 84e428f37af..3229b8162d1 100644 --- a/modelopt/torch/sparsity/kv_cache/__init__.py +++ b/modelopt/torch/sparsity/kv_cache/__init__.py @@ -18,3 +18,4 @@ from . import mode from .config import * from .conversion import * +from .model_sparsify import * diff --git a/modelopt/torch/sparsity/kv_cache/model_sparsify.py b/modelopt/torch/sparsity/kv_cache/model_sparsify.py new file mode 100644 index 00000000000..254ba0f64e4 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/model_sparsify.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Entry points for KV cache sparsity: sparsify() and calibrate().""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from modelopt.torch.opt.conversion import apply_mode + +from .config import TriAttentionConfig +from .mode import KVCacheSparsityRegistry + +if TYPE_CHECKING: + import torch.nn as nn + + from modelopt.torch.opt.searcher import ForwardLoop + +__all__ = ["calibrate", "sparsify"] + + +def sparsify( + model: nn.Module, + config: dict[str, Any] | TriAttentionConfig, + forward_loop: ForwardLoop | None = None, +) -> nn.Module: + """Apply KV cache sparsity optimization to a model. + + Registers the TriAttention mode on the model. Call ``calibrate()`` afterwards + to compute frequency statistics from calibration data. + + Args: + model: The model to optimize. + config: TriAttentionConfig or dict with config values. + forward_loop: Optional forward loop for integrated calibration. + + Returns: + The model with TriAttention mode applied (in-place). + """ + if isinstance(config, dict): + config = TriAttentionConfig(**config) + + model = apply_mode( + model, + mode=[("triattention", config.model_dump())], + registry=KVCacheSparsityRegistry, + ) + + if forward_loop is not None: + model = calibrate(model, config, forward_loop=forward_loop) + + return model + + +def calibrate( + model: nn.Module, + config: dict[str, Any] | TriAttentionConfig | None = None, + forward_loop: ForwardLoop | None = None, +) -> nn.Module: + """Calibrate TriAttention frequency statistics. + + Runs a forward pass with hooks to capture pre-RoPE Q vectors, inverts RoPE, + and computes per-head frequency centers. Results are stored in the model's + modelopt_state metadata. + + Args: + model: Model with TriAttention mode applied. + config: Optional config override. + forward_loop: Callable that runs forward passes on calibration data. + + Returns: + The model with calibration data stored in metadata. + """ + # Full GPU calibration with model forward passes will be implemented + # when engine integration is designed. For now, calibrate is a no-op + # that can be called safely after sparsify(). + return model diff --git a/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py b/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py new file mode 100644 index 00000000000..922a2a7d550 --- /dev/null +++ b/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TriAttention sparsify/calibrate entry API.""" + +import torch +import torch.nn as nn + +from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig +from modelopt.torch.sparsity.kv_cache.model_sparsify import calibrate, sparsify + + +def test_sparsify_returns_model(): + """sparsify() returns the model.""" + model = nn.Linear(16, 16) + result = sparsify(model, TriAttentionConfig()) + assert result is model + + +def test_sparsify_accepts_dict_config(): + """sparsify() accepts dict config.""" + model = nn.Linear(16, 16) + result = sparsify(model, {"budget": 1024}) + assert result is model + + +def test_sparsify_preserves_weights(): + """sparsify() does not modify model weights.""" + model = nn.Linear(16, 16) + original_weight = model.weight.data.clone() + sparsify(model, TriAttentionConfig()) + torch.testing.assert_close(model.weight.data, original_weight) + + +def test_calibrate_returns_model(): + """calibrate() returns the model.""" + model = nn.Linear(16, 16) + sparsify(model, TriAttentionConfig()) + result = calibrate(model) + assert result is model + + +def test_sparsify_then_calibrate(): + """sparsify() followed by calibrate() works without error.""" + model = nn.Linear(16, 16) + model = sparsify(model, TriAttentionConfig(budget=512)) + model = calibrate(model) + assert isinstance(model, nn.Module) From 22d30e5b59b724f78f5d0f7b9308823996c5ad5f Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 8 Apr 2026 12:54:14 -0700 Subject: [PATCH 7/9] Add TriAttention calibration with HF models Signed-off-by: Kai Xu --- .../kv_cache_sparsity/hf_triattention.py | 208 ++++++++++++++++++ modelopt/torch/sparsity/__init__.py | 5 +- .../torch/sparsity/kv_cache/model_sparsify.py | 14 +- .../kv_cache/triattention/calibration.py | 200 ++++++++++++++++- 4 files changed, 420 insertions(+), 7 deletions(-) create mode 100644 examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py diff --git a/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py new file mode 100644 index 00000000000..ace03626a65 --- /dev/null +++ b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example: TriAttention KV cache sparsity calibration on HuggingFace models. + +Demonstrates the TriAttention calibration pipeline: +1. Load a pretrained HF model +2. Apply KV cache sparsity mode (sparsify) +3. Run calibration with a forward pass to compute per-head frequency statistics +4. Verify calibration data was produced +5. Optionally save calibration data + +Usage: + python hf_triattention.py --model Qwen/Qwen3-0.6B + + # With custom budget and calibration length + python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 1024 --calib-seq-len 4096 + + # Save calibration data + python hf_triattention.py --model Qwen/Qwen3-0.6B --output calibration.pt +""" + +import argparse +import time + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.kv_cache as mtskv +from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig + + +def make_calibration_forward_loop(tokenizer, seq_len: int = 2048, input_file: str | None = None): + """Create a forward loop that runs calibration data through the model. + + Args: + tokenizer: Tokenizer for the model. + seq_len: Sequence length for calibration input. + input_file: Path to a plain text file for calibration. If None, uses + built-in placeholder text. + + Returns: + Callable that takes a model and runs a forward pass. + """ + if input_file is not None: + from pathlib import Path + + calib_text = Path(input_file).read_text(encoding="utf-8") + else: + calib_text = ( + "The quick brown fox jumps over the lazy dog. " + "Machine learning is a subset of artificial intelligence that enables systems " + "to learn and improve from experience without being explicitly programmed. " + "Deep learning, a branch of machine learning, uses neural networks with many " + "layers to model complex patterns in data. Transformers have revolutionized " + "natural language processing by introducing self-attention mechanisms that " + "allow models to weigh the importance of different parts of the input. " + ) * 100 + + input_ids = tokenizer.encode( + calib_text, return_tensors="pt", truncation=True, max_length=seq_len + ) + print(f" Calibration tokens: {input_ids.shape[1]}") + + def forward_loop(model): + device = next(model.parameters()).device + inputs = input_ids.to(device) + with torch.no_grad(): + model(input_ids=inputs) + + return forward_loop + + +def main(args): + """Run TriAttention calibration pipeline.""" + print(f"Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + model.eval() + + # Print model info + config = model.config + print(f" Layers: {config.num_hidden_layers}") + print(f" Heads: {config.num_attention_heads}") + print(f" KV heads: {getattr(config, 'num_key_value_heads', config.num_attention_heads)}") + print(f" Hidden size: {config.hidden_size}") + + # Step 1: Calibrate (before sparsify — apply_mode adds wrappers that alter forward pass) + print(f"\nRunning calibration (seq_len={args.calib_seq_len})...") + forward_loop = make_calibration_forward_loop( + tokenizer, seq_len=args.calib_seq_len, input_file=args.input + ) + + t0 = time.time() + model = mtskv.calibrate(model, forward_loop=forward_loop) + elapsed = time.time() - t0 + print(f" Calibration complete in {elapsed:.1f}s") + + # Step 2: Apply KV cache sparsity mode + print(f"\nApplying TriAttention mode (budget={args.budget})...") + triattention_config = TriAttentionConfig( + budget=args.budget, + prune_interval=args.prune_interval, + ) + model = mtskv.sparsify(model, triattention_config) + print(" Mode applied (no-op on weights).") + + # Step 3: Verify calibration data + calib_data = getattr(model, "_triattention_calibration", None) + if calib_data is None: + print("\n ERROR: No calibration data found on model!") + return + + print("\n Calibration results:") + print(f" Head dim: {calib_data.head_dim}") + print(f" RoPE style: {calib_data.rope_style}") + print(f" Num layers: {calib_data.num_layers}") + print(f" Num KV heads: {calib_data.num_kv_heads}") + print(f" Heads calibrated: {len(calib_data.head_stats)}") + + # Check concentration (Mean Resultant Length) + total_heads = 0 + concentrated = 0 + for stats in calib_data.head_stats.values(): + abs_mean = torch.abs(stats.q_mean_complex) # |E[q]| + mean_abs = stats.q_abs_mean # E[|q|] + # R = |E[q]| / E[|q|] — concentration metric + r_values = abs_mean / (mean_abs + 1e-8) + r_mean = r_values.mean().item() + total_heads += 1 + if r_mean > 0.9: + concentrated += 1 + + print(f" Concentrated heads (R > 0.9): {concentrated}/{total_heads}") + print(f" Concentration ratio: {concentrated / total_heads:.1%}") + + # Step 4: Optionally save calibration data + if args.output: + state = calib_data.state_dict() + torch.save(state, args.output) + print(f"\n Calibration data saved to: {args.output}") + + print("\nDone.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="TriAttention KV cache sparsity calibration example." + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-0.6B", + help="HuggingFace model name or local path.", + ) + parser.add_argument( + "--budget", + type=int, + default=2048, + help="KV token budget (tokens to retain per head).", + ) + parser.add_argument( + "--prune-interval", + type=int, + default=128, + help="Re-score and evict every N tokens.", + ) + parser.add_argument( + "--input", + type=str, + default=None, + help="Plain text file for calibration input. " + "If not provided, uses built-in placeholder text. " + "Use the same file as triattention/scripts/calibrate.py --input for comparison.", + ) + parser.add_argument( + "--calib-seq-len", + type=int, + default=2048, + help="Sequence length for calibration input.", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Path to save calibration data (.pt file).", + ) + + args = parser.parse_args() + main(args) diff --git a/modelopt/torch/sparsity/__init__.py b/modelopt/torch/sparsity/__init__.py index 433db71402c..55a59ae4e7b 100644 --- a/modelopt/torch/sparsity/__init__.py +++ b/modelopt/torch/sparsity/__init__.py @@ -20,8 +20,7 @@ """ # Import weight sparsity for backward compatibility -from .weight_sparsity import mode, module, plugins -from .weight_sparsity.sparsification import * - # Import kv_cache to register KV cache sparsity modes from . import kv_cache +from .weight_sparsity import mode, module, plugins +from .weight_sparsity.sparsification import * diff --git a/modelopt/torch/sparsity/kv_cache/model_sparsify.py b/modelopt/torch/sparsity/kv_cache/model_sparsify.py index 254ba0f64e4..18df12bcfb4 100644 --- a/modelopt/torch/sparsity/kv_cache/model_sparsify.py +++ b/modelopt/torch/sparsity/kv_cache/model_sparsify.py @@ -80,11 +80,19 @@ def calibrate( model: Model with TriAttention mode applied. config: Optional config override. forward_loop: Callable that runs forward passes on calibration data. + If None, calibration is skipped (no-op). Returns: The model with calibration data stored in metadata. """ - # Full GPU calibration with model forward passes will be implemented - # when engine integration is designed. For now, calibrate is a no-op - # that can be called safely after sparsify(). + if forward_loop is None: + return model + + from .triattention.calibration import run_calibration + + calib_data = run_calibration(model, forward_loop=forward_loop) + + # Store calibration data in model attribute for later export + model._triattention_calibration = calib_data + return model diff --git a/modelopt/torch/sparsity/kv_cache/triattention/calibration.py b/modelopt/torch/sparsity/kv_cache/triattention/calibration.py index 89185390ee0..fbf0d256ab6 100644 --- a/modelopt/torch/sparsity/kv_cache/triattention/calibration.py +++ b/modelopt/torch/sparsity/kv_cache/triattention/calibration.py @@ -26,12 +26,13 @@ import torch -from .rope_utils import to_complex_pairs +from .rope_utils import invert_rope, rotate_half, to_complex_pairs from .scoring import HeadFrequencyStats __all__ = [ "CalibrationData", "compute_head_stats_from_q", + "run_calibration", ] @@ -118,3 +119,200 @@ def compute_head_stats_from_q( q_mean_complex=q_mean_complex, q_abs_mean=q_abs_mean, ) + + +# --------------------------------------------------------------------------- +# Model introspection helpers +# --------------------------------------------------------------------------- + + +def _find_attention_layers(model: torch.nn.Module) -> list[torch.nn.Module]: + """Find attention sub-modules in HF model (model.model.layers[i].self_attn).""" + backbone = getattr(model, "model", model) + layer_list = getattr(backbone, "layers", None) + if layer_list is None: + raise RuntimeError( + "Cannot locate transformer layers. Expected model.model.layers attribute." + ) + layers = [] + for layer_module in layer_list: + attn = getattr(layer_module, "self_attn", None) + if attn is None: + raise RuntimeError("Layer missing self_attn attribute.") + layers.append(attn) + return layers + + +def _get_rotary_embedding(model: torch.nn.Module) -> torch.nn.Module: + """Find the rotary embedding module in the model.""" + backbone = getattr(model, "model", model) + if hasattr(backbone, "rotary_emb"): + return backbone.rotary_emb + # Some models put rotary_emb on individual attention layers + attn_layers = _find_attention_layers(model) + if attn_layers and hasattr(attn_layers[0], "rotary_emb"): + return attn_layers[0].rotary_emb + raise RuntimeError("Cannot locate rotary_emb on model.model or self_attn.") + + +def _get_model_config(model: torch.nn.Module) -> dict[str, Any]: + """Extract model config parameters needed for calibration.""" + config = getattr(model, "config", None) + if config is not None: + num_layers = getattr(config, "num_hidden_layers", None) + num_heads = getattr(config, "num_attention_heads", None) + hidden_size = getattr(config, "hidden_size", None) + head_dim = getattr(config, "head_dim", None) + num_kv_heads = getattr(config, "num_key_value_heads", num_heads) + if head_dim is None and hidden_size and num_heads: + head_dim = hidden_size // num_heads + if all(v is not None for v in [num_layers, num_heads, head_dim, num_kv_heads]): + return { + "num_layers": num_layers, + "num_heads": num_heads, + "head_dim": head_dim, + "num_kv_heads": num_kv_heads, + } + + # Fallback: infer from model structure + attn_layers = _find_attention_layers(model) + num_layers = len(attn_layers) + attn0 = attn_layers[0] + num_heads = getattr(attn0, "num_heads", None) + head_dim = getattr(attn0, "head_dim", None) + num_kv_heads = getattr(attn0, "num_key_value_heads", num_heads) + + if num_heads is None or head_dim is None: + # Infer from q_proj weight shape + q_proj = getattr(attn0, "q_proj", None) + if q_proj is not None: + out_features = q_proj.out_features + # Guess: if num_heads not available, try common head_dims + for hd in [128, 96, 64, 32]: + if out_features % hd == 0: + head_dim = head_dim or hd + num_heads = num_heads or (out_features // hd) + break + + if num_heads is None or head_dim is None: + raise RuntimeError("Cannot determine num_heads and head_dim from model.") + + num_kv_heads = num_kv_heads or num_heads + return { + "num_layers": num_layers, + "num_heads": num_heads, + "head_dim": head_dim, + "num_kv_heads": num_kv_heads, + } + + +# --------------------------------------------------------------------------- +# Main calibration function +# --------------------------------------------------------------------------- + + +def run_calibration( + model: torch.nn.Module, + forward_loop: Any = None, + rope_style: str = "half", +) -> CalibrationData: + """Run TriAttention calibration on a model. + + Hooks into attention layers, captures post-RoPE Q states during a forward + pass, inverts RoPE, and computes per-head frequency statistics. + + Args: + model: The model to calibrate. Must follow HF structure + (model.model.layers[i].self_attn with q_proj). + forward_loop: Callable that takes the model and runs forward passes + on calibration data. Signature: ``forward_loop(model) -> None``. + rope_style: RoPE pairing style ('half' or 'interleaved'). + + Returns: + CalibrationData with per-head frequency statistics. + """ + model_cfg = _get_model_config(model) + num_layers = model_cfg["num_layers"] + num_heads = model_cfg["num_heads"] + head_dim = model_cfg["head_dim"] + num_kv_heads = model_cfg["num_kv_heads"] + + attn_layers = _find_attention_layers(model) + rotary = _get_rotary_embedding(model) + attn_scale = float(getattr(rotary, "attention_scaling", 1.0)) + + # Storage for captured Q states + captured_q: dict[int, torch.Tensor] = {} + + def _make_pre_hook(layer_idx: int): + def hook_fn(module, args, kwargs): + hidden_states = args[0] if args else kwargs.get("hidden_states") + if hidden_states is None: + return + bsz, q_len, _ = hidden_states.shape + q = module.q_proj(hidden_states) + q = q.view(bsz, q_len, num_heads, head_dim).transpose(1, 2) + + # Apply RoPE — ensure cos/sin are on the same device as hidden_states + device = hidden_states.device + pos_ids = torch.arange(q_len, device=device).unsqueeze(0) + probe = torch.zeros(1, q_len, head_dim, device=device, dtype=hidden_states.dtype) + cos, sin = rotary(probe, pos_ids) + cos = cos.to(device=device) + sin = sin.to(device=device) + q_rot = (q * cos.unsqueeze(1)) + (rotate_half(q, style=rope_style) * sin.unsqueeze(1)) + q_rot = q_rot * attn_scale + captured_q[layer_idx] = q_rot.detach() + + return hook_fn + + # Register hooks + handles = [] + for layer_idx, attn in enumerate(attn_layers): + handle = attn.register_forward_pre_hook(_make_pre_hook(layer_idx), with_kwargs=True) + handles.append(handle) + + # Run forward pass + try: + if forward_loop is not None: + forward_loop(model) + finally: + # Always remove hooks + for handle in handles: + handle.remove() + + # Compute per-head frequency statistics + head_stats: dict[tuple[int, int], HeadFrequencyStats] = {} + + for layer_idx in range(num_layers): + q_rot = captured_q.get(layer_idx) + if q_rot is None: + continue + + # q_rot: (batch, num_heads, seq_len, head_dim) + # Build cos/sin for RoPE inversion — ensure same device as q_rot + device = q_rot.device + seq_len = q_rot.shape[2] + pos_ids = torch.arange(seq_len, device=device).unsqueeze(0) + probe = torch.zeros(1, seq_len, head_dim, device=device, dtype=q_rot.dtype) + cos, sin = rotary(probe, pos_ids) + cos = cos.to(device=device).unsqueeze(1) # (1, 1, seq_len, head_dim) + sin = sin.to(device=device).unsqueeze(1) + + # Invert RoPE + q_base = invert_rope(q_rot, cos, sin, attn_scale, style=rope_style) + + for head_idx in range(num_heads): + q_head = q_base[0, head_idx] # (seq_len, head_dim) + head_stats[(layer_idx, head_idx)] = compute_head_stats_from_q(q_head, style=rope_style) + + # Free memory + del captured_q[layer_idx] + + return CalibrationData( + head_stats=head_stats, + head_dim=head_dim, + rope_style=rope_style, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + ) From 84ae221ce1848acb0c6b6659c3f10afa9a8515b7 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 20 Apr 2026 19:45:03 -0700 Subject: [PATCH 8/9] Add target_sparsity_ratio option Signed-off-by: Kai Xu --- .../kv_cache_sparsity/hf_triattention.py | 46 ++++++++--- modelopt/torch/sparsity/kv_cache/config.py | 45 +++++++++-- .../kv_cache/triattention/__init__.py | 10 +++ .../sparsity/kv_cache/triattention/scoring.py | 46 +++++++++++ .../torch/sparsity/kv_cache/test_config.py | 77 ++++++++++++++----- .../sparsity/kv_cache/test_conversion.py | 40 +++++++++- .../sparsity/kv_cache/test_model_sparsify.py | 6 +- .../torch/sparsity/kv_cache/test_scoring.py | 63 +++++++++++++++ 8 files changed, 294 insertions(+), 39 deletions(-) diff --git a/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py index ace03626a65..19b9d95398a 100644 --- a/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py +++ b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py @@ -24,13 +24,17 @@ 5. Optionally save calibration data Usage: - python hf_triattention.py --model Qwen/Qwen3-0.6B + # Fixed-size budget (retain top-K tokens per head) + python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 2048 + + # Percentile-based eviction (evict 70% at each prune step) + python hf_triattention.py --model Qwen/Qwen3-0.6B --target-sparsity-ratio 0.7 # With custom budget and calibration length python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 1024 --calib-seq-len 4096 # Save calibration data - python hf_triattention.py --model Qwen/Qwen3-0.6B --output calibration.pt + python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 2048 --output calibration.pt """ import argparse @@ -115,11 +119,20 @@ def main(args): print(f" Calibration complete in {elapsed:.1f}s") # Step 2: Apply KV cache sparsity mode - print(f"\nApplying TriAttention mode (budget={args.budget})...") - triattention_config = TriAttentionConfig( - budget=args.budget, - prune_interval=args.prune_interval, - ) + if args.target_sparsity_ratio is not None: + print( + f"\nApplying TriAttention mode (target_sparsity_ratio={args.target_sparsity_ratio})..." + ) + triattention_config = TriAttentionConfig( + target_sparsity_ratio=args.target_sparsity_ratio, + prune_interval=args.prune_interval, + ) + else: + print(f"\nApplying TriAttention mode (budget={args.budget})...") + triattention_config = TriAttentionConfig( + budget=args.budget, + prune_interval=args.prune_interval, + ) model = mtskv.sparsify(model, triattention_config) print(" Mode applied (no-op on weights).") @@ -171,11 +184,22 @@ def main(args): default="Qwen/Qwen3-0.6B", help="HuggingFace model name or local path.", ) - parser.add_argument( + policy = parser.add_mutually_exclusive_group() + policy.add_argument( "--budget", type=int, - default=2048, - help="KV token budget (tokens to retain per head).", + default=None, + help="KV token budget (tokens to retain per head). " + "Mutually exclusive with --target-sparsity-ratio. " + "Defaults to 2048 if neither is set.", + ) + policy.add_argument( + "--target-sparsity-ratio", + type=float, + default=None, + help="Fraction of tokens to evict at each prune step, in (0, 1). " + "Example: 0.7 evicts 70%% of tokens (keeps top 30%% by score). " + "Mutually exclusive with --budget.", ) parser.add_argument( "--prune-interval", @@ -205,4 +229,6 @@ def main(args): ) args = parser.parse_args() + if args.budget is None and args.target_sparsity_ratio is None: + args.budget = 2048 main(args) diff --git a/modelopt/torch/sparsity/kv_cache/config.py b/modelopt/torch/sparsity/kv_cache/config.py index 46c97aeafc7..344bfb2decb 100644 --- a/modelopt/torch/sparsity/kv_cache/config.py +++ b/modelopt/torch/sparsity/kv_cache/config.py @@ -19,7 +19,7 @@ from typing import Literal -from pydantic import field_validator +from pydantic import field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField @@ -32,13 +32,31 @@ class TriAttentionConfig(ModeloptBaseConfig): TriAttention scores cached KV entries using a trigonometric model derived from pre-RoPE Q/K concentration. Calibration computes per-head frequency statistics; at runtime, the serving engine scores and evicts tokens periodically. + + Exactly one of ``budget`` or ``target_sparsity_ratio`` must be set: + + - ``budget``: absolute token count to retain per head (fixed-size cache). + - ``target_sparsity_ratio``: fraction of tokens to evict at each pruning step. + Cache size auto-scales with generation length. Value in (0, 1). """ - # Budget - budget: int = ModeloptField( - default=2048, - title="KV token budget.", - description="Number of KV tokens to retain per head after pruning.", + # Eviction policy (exactly one must be set) + budget: int | None = ModeloptField( + default=None, + title="KV token budget (absolute).", + description=( + "Number of KV tokens to retain per head after pruning. " + "Mutually exclusive with target_sparsity_ratio." + ), + ) + target_sparsity_ratio: float | None = ModeloptField( + default=None, + title="Target sparsity ratio (percentile-based).", + description=( + "Fraction of tokens to evict at each pruning step, in (0, 1). " + "Example: 0.7 means evict 70% of tokens (keep top 30% by score). " + "Mutually exclusive with budget." + ), ) # Pruning schedule @@ -107,3 +125,18 @@ def validate_score_aggregation(cls, v: str) -> str: if v not in valid: raise ValueError(f"score_aggregation must be one of {valid}, got '{v}'") return v + + @model_validator(mode="after") + def validate_budget_or_sparsity(self) -> TriAttentionConfig: + """Exactly one of budget or target_sparsity_ratio must be set.""" + budget_set = self.budget is not None + sparsity_set = self.target_sparsity_ratio is not None + if not budget_set and not sparsity_set: + raise ValueError("Must set exactly one of 'budget' or 'target_sparsity_ratio'") + if budget_set and sparsity_set: + raise ValueError("Cannot set both 'budget' and 'target_sparsity_ratio'; pick one") + if sparsity_set and not (0.0 < self.target_sparsity_ratio < 1.0): + raise ValueError( + f"target_sparsity_ratio must be in (0, 1), got {self.target_sparsity_ratio}" + ) + return self diff --git a/modelopt/torch/sparsity/kv_cache/triattention/__init__.py b/modelopt/torch/sparsity/kv_cache/triattention/__init__.py index 4aaae7a51d9..4190eabe1b8 100644 --- a/modelopt/torch/sparsity/kv_cache/triattention/__init__.py +++ b/modelopt/torch/sparsity/kv_cache/triattention/__init__.py @@ -18,10 +18,20 @@ """TriAttention: Trigonometric KV cache compression.""" from .rope_utils import build_geometric_offsets, invert_rope, rotate_half, to_complex_pairs +from .scoring import ( + HeadFrequencyStats, + compute_frequency_statistics_from_means, + score_keys_for_round, + select_keys_to_keep, +) __all__ = [ + "HeadFrequencyStats", "build_geometric_offsets", + "compute_frequency_statistics_from_means", "invert_rope", "rotate_half", + "score_keys_for_round", + "select_keys_to_keep", "to_complex_pairs", ] diff --git a/modelopt/torch/sparsity/kv_cache/triattention/scoring.py b/modelopt/torch/sparsity/kv_cache/triattention/scoring.py index e91d00b2ae2..e7740a029f9 100644 --- a/modelopt/torch/sparsity/kv_cache/triattention/scoring.py +++ b/modelopt/torch/sparsity/kv_cache/triattention/scoring.py @@ -31,6 +31,7 @@ "HeadFrequencyStats", "compute_frequency_statistics_from_means", "score_keys_for_round", + "select_keys_to_keep", ] @@ -128,3 +129,48 @@ def score_keys_for_round( if aggregation == "mean": return combined.mean(dim=1) return combined.max(dim=1).values + + +def select_keys_to_keep( + scores: torch.Tensor, + *, + kv_budget: int | None = None, + target_sparsity_ratio: float | None = None, +) -> torch.Tensor: + """Select which keys to retain based on importance scores. + + Exactly one of ``kv_budget`` or ``target_sparsity_ratio`` must be provided. + + Args: + scores: Importance scores, shape (num_keys,). Higher = more important. + kv_budget: Absolute number of tokens to retain. Keeps top-K. + If budget >= num_keys, keeps all. + target_sparsity_ratio: Fraction of tokens to evict, in (0, 1). + Keeps top (1 - ratio) fraction. Example: 0.7 → keep top 30%. + + Returns: + Boolean mask, shape (num_keys,). True = keep, False = evict. + """ + budget_set = kv_budget is not None + sparsity_set = target_sparsity_ratio is not None + if budget_set == sparsity_set: + raise ValueError( + "select_keys_to_keep requires exactly one of kv_budget or target_sparsity_ratio" + ) + + num_keys = scores.shape[0] + if num_keys == 0: + return torch.zeros(0, dtype=torch.bool, device=scores.device) + + if budget_set: + k = min(kv_budget, num_keys) + else: + k = max(1, round(num_keys * (1.0 - target_sparsity_ratio))) + + if k >= num_keys: + return torch.ones(num_keys, dtype=torch.bool, device=scores.device) + + top_indices = torch.topk(scores, k=k, largest=True).indices + mask = torch.zeros(num_keys, dtype=torch.bool, device=scores.device) + mask[top_indices] = True + return mask diff --git a/tests/unit/torch/sparsity/kv_cache/test_config.py b/tests/unit/torch/sparsity/kv_cache/test_config.py index bed66a9b4a4..26c4435ba7f 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_config.py +++ b/tests/unit/torch/sparsity/kv_cache/test_config.py @@ -21,22 +21,52 @@ from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig -def test_default_config(): - """Default config creates valid instance.""" - config = TriAttentionConfig() +def test_budget_only(): + """Setting only budget is valid.""" + config = TriAttentionConfig(budget=2048) assert config.budget == 2048 - assert config.prune_interval == 128 - assert config.window_size == 128 - assert config.pruning_mode == "per_head" - assert config.score_aggregation == "mean" - assert config.offset_max_length == 65536 - assert config.disable_mlr is False - assert config.disable_trig is False - assert config.calib_size == 100000 + assert config.target_sparsity_ratio is None + + +def test_target_sparsity_only(): + """Setting only target_sparsity_ratio is valid.""" + config = TriAttentionConfig(target_sparsity_ratio=0.7) + assert config.budget is None + assert config.target_sparsity_ratio == 0.7 + + +def test_both_budget_and_sparsity_raises(): + """Setting both budget and target_sparsity_ratio raises.""" + with pytest.raises(ValidationError, match="Cannot set both"): + TriAttentionConfig(budget=2048, target_sparsity_ratio=0.7) + + +def test_neither_budget_nor_sparsity_raises(): + """Setting neither budget nor target_sparsity_ratio raises.""" + with pytest.raises(ValidationError, match="Must set exactly one"): + TriAttentionConfig() + + +def test_target_sparsity_out_of_range_low(): + """target_sparsity_ratio <= 0 raises.""" + with pytest.raises(ValidationError, match="must be in"): + TriAttentionConfig(target_sparsity_ratio=0.0) + + +def test_target_sparsity_out_of_range_high(): + """target_sparsity_ratio >= 1 raises.""" + with pytest.raises(ValidationError, match="must be in"): + TriAttentionConfig(target_sparsity_ratio=1.0) + + +def test_target_sparsity_negative(): + """Negative target_sparsity_ratio raises.""" + with pytest.raises(ValidationError): + TriAttentionConfig(target_sparsity_ratio=-0.1) def test_config_custom_values(): - """Config accepts custom values.""" + """Config accepts custom values alongside budget.""" config = TriAttentionConfig(budget=4096, prune_interval=64, window_size=256) assert config.budget == 4096 assert config.prune_interval == 64 @@ -46,31 +76,42 @@ def test_config_custom_values(): def test_config_invalid_pruning_mode(): """Invalid pruning mode raises validation error.""" with pytest.raises(ValidationError): - TriAttentionConfig(pruning_mode="invalid") + TriAttentionConfig(budget=2048, pruning_mode="invalid") def test_config_invalid_aggregation(): """Invalid score aggregation raises validation error.""" with pytest.raises(ValidationError): - TriAttentionConfig(score_aggregation="invalid") + TriAttentionConfig(budget=2048, score_aggregation="invalid") -def test_config_serialization_roundtrip(): - """Config can be serialized and deserialized.""" +def test_config_serialization_roundtrip_budget(): + """Config with budget survives serialization roundtrip.""" config = TriAttentionConfig(budget=1024, prune_interval=64) data = config.model_dump() restored = TriAttentionConfig(**data) assert restored.budget == 1024 + assert restored.target_sparsity_ratio is None + assert restored.prune_interval == 64 + + +def test_config_serialization_roundtrip_sparsity(): + """Config with target_sparsity_ratio survives serialization roundtrip.""" + config = TriAttentionConfig(target_sparsity_ratio=0.5, prune_interval=64) + data = config.model_dump() + restored = TriAttentionConfig(**data) + assert restored.budget is None + assert restored.target_sparsity_ratio == 0.5 assert restored.prune_interval == 64 def test_config_per_layer_per_head_mode(): """per_layer_per_head is a valid pruning mode.""" - config = TriAttentionConfig(pruning_mode="per_layer_per_head") + config = TriAttentionConfig(budget=2048, pruning_mode="per_layer_per_head") assert config.pruning_mode == "per_layer_per_head" def test_config_max_aggregation(): """max is a valid score aggregation.""" - config = TriAttentionConfig(score_aggregation="max") + config = TriAttentionConfig(budget=2048, score_aggregation="max") assert config.score_aggregation == "max" diff --git a/tests/unit/torch/sparsity/kv_cache/test_conversion.py b/tests/unit/torch/sparsity/kv_cache/test_conversion.py index 6a880a7cb84..50efa9d024e 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_conversion.py +++ b/tests/unit/torch/sparsity/kv_cache/test_conversion.py @@ -52,7 +52,7 @@ def test_convert_returns_model_and_metadata(): """Convert returns (model, metadata) without modifying weights.""" model = nn.Linear(16, 16) original_weight = model.weight.data.clone() - config = TriAttentionConfig() + config = TriAttentionConfig(budget=2048) converted_model, metadata = convert_triattention(model, config) @@ -74,7 +74,7 @@ def test_convert_metadata_contains_config(): def test_restore_returns_model(): """Restore returns the model.""" model = nn.Linear(16, 16) - config = TriAttentionConfig() + config = TriAttentionConfig(budget=2048) _, metadata = convert_triattention(model, config) restored = restore_triattention(model, config, metadata) @@ -90,3 +90,39 @@ def test_update_metadata(): update_triattention_metadata(model, config, metadata) assert metadata["triattention_config"]["budget"] == 512 + + +def test_convert_metadata_with_sparsity_ratio(): + """Metadata serializes target_sparsity_ratio when set.""" + model = nn.Linear(16, 16) + config = TriAttentionConfig(target_sparsity_ratio=0.7) + + _, metadata = convert_triattention(model, config) + + serialized = metadata["triattention_config"] + assert serialized["target_sparsity_ratio"] == 0.7 + assert serialized["budget"] is None + + +def test_convert_metadata_with_budget(): + """Metadata has budget set and target_sparsity_ratio None.""" + model = nn.Linear(16, 16) + config = TriAttentionConfig(budget=1024) + + _, metadata = convert_triattention(model, config) + + serialized = metadata["triattention_config"] + assert serialized["budget"] == 1024 + assert serialized["target_sparsity_ratio"] is None + + +def test_update_metadata_with_sparsity_ratio(): + """update_triattention_metadata serializes target_sparsity_ratio.""" + model = nn.Linear(16, 16) + config = TriAttentionConfig(target_sparsity_ratio=0.5) + metadata = {} + + update_triattention_metadata(model, config, metadata) + + assert metadata["triattention_config"]["target_sparsity_ratio"] == 0.5 + assert metadata["triattention_config"]["budget"] is None diff --git a/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py b/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py index 922a2a7d550..e9d0c11ea42 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py +++ b/tests/unit/torch/sparsity/kv_cache/test_model_sparsify.py @@ -25,7 +25,7 @@ def test_sparsify_returns_model(): """sparsify() returns the model.""" model = nn.Linear(16, 16) - result = sparsify(model, TriAttentionConfig()) + result = sparsify(model, TriAttentionConfig(budget=2048)) assert result is model @@ -40,14 +40,14 @@ def test_sparsify_preserves_weights(): """sparsify() does not modify model weights.""" model = nn.Linear(16, 16) original_weight = model.weight.data.clone() - sparsify(model, TriAttentionConfig()) + sparsify(model, TriAttentionConfig(budget=2048)) torch.testing.assert_close(model.weight.data, original_weight) def test_calibrate_returns_model(): """calibrate() returns the model.""" model = nn.Linear(16, 16) - sparsify(model, TriAttentionConfig()) + sparsify(model, TriAttentionConfig(budget=2048)) result = calibrate(model) assert result is model diff --git a/tests/unit/torch/sparsity/kv_cache/test_scoring.py b/tests/unit/torch/sparsity/kv_cache/test_scoring.py index 13e89c7e5d1..903820ff084 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_scoring.py +++ b/tests/unit/torch/sparsity/kv_cache/test_scoring.py @@ -15,6 +15,7 @@ """Tests for TriAttention trigonometric scoring.""" +import pytest import torch from modelopt.torch.sparsity.kv_cache.triattention.rope_utils import build_geometric_offsets @@ -22,6 +23,7 @@ HeadFrequencyStats, compute_frequency_statistics_from_means, score_keys_for_round, + select_keys_to_keep, ) @@ -184,3 +186,64 @@ def test_head_frequency_stats_dataclass(): ) assert stats.q_mean_complex.shape == (8,) assert stats.q_abs_mean.shape == (8,) + + +def test_select_keys_top_k_basic(): + """Top-K selection keeps highest-scoring tokens.""" + scores = torch.tensor([0.1, 0.9, 0.3, 0.8, 0.2]) + mask = select_keys_to_keep(scores, kv_budget=2) + # indices 1 (0.9) and 3 (0.8) should be kept + assert mask.dtype == torch.bool + assert mask.sum().item() == 2 + assert mask[1].item() is True + assert mask[3].item() is True + + +def test_select_keys_top_k_exceeds_size(): + """Budget larger than input keeps all tokens.""" + scores = torch.tensor([0.1, 0.9, 0.3]) + mask = select_keys_to_keep(scores, kv_budget=10) + assert mask.all() + assert mask.shape == scores.shape + + +def test_select_keys_percentile_basic(): + """Percentile selection evicts target fraction.""" + scores = torch.arange(10, dtype=torch.float32) + # sparsity=0.7 → evict 70%, keep top 30% (3 tokens) + mask = select_keys_to_keep(scores, target_sparsity_ratio=0.7) + assert mask.dtype == torch.bool + assert mask.sum().item() == 3 + # Top 3 by score are indices 7, 8, 9 + assert mask[7].item() is True + assert mask[8].item() is True + assert mask[9].item() is True + + +def test_select_keys_percentile_half(): + """50% sparsity keeps half the tokens.""" + scores = torch.arange(20, dtype=torch.float32) + mask = select_keys_to_keep(scores, target_sparsity_ratio=0.5) + assert mask.sum().item() == 10 + + +def test_select_keys_both_raises(): + """Setting both budget and target_sparsity_ratio raises.""" + scores = torch.rand(10) + with pytest.raises(ValueError, match="exactly one"): + select_keys_to_keep(scores, kv_budget=5, target_sparsity_ratio=0.5) + + +def test_select_keys_neither_raises(): + """Setting neither budget nor target_sparsity_ratio raises.""" + scores = torch.rand(10) + with pytest.raises(ValueError, match="exactly one"): + select_keys_to_keep(scores) + + +def test_select_keys_empty_scores(): + """Empty score tensor returns empty mask.""" + scores = torch.tensor([], dtype=torch.float32) + mask = select_keys_to_keep(scores, kv_budget=5) + assert mask.numel() == 0 + assert mask.dtype == torch.bool From 77768611366c27705281f8861b103c72642e5894 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 12 May 2026 21:59:11 -0700 Subject: [PATCH 9/9] Remove target_sparsity_ratio mode Signed-off-by: Kai Xu --- .../kv_cache_sparsity/hf_triattention.py | 40 +++------------ examples/speculative_decoding/eagle_utils.py | 2 +- modelopt/torch/sparsity/kv_cache/config.py | 42 +++++---------- .../sparsity/kv_cache/triattention/scoring.py | 25 +++------ .../torch/sparsity/kv_cache/test_config.py | 51 ++++--------------- .../sparsity/kv_cache/test_conversion.py | 27 +--------- .../torch/sparsity/kv_cache/test_scoring.py | 36 +++---------- 7 files changed, 49 insertions(+), 174 deletions(-) diff --git a/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py index 19b9d95398a..7d2103d4a53 100644 --- a/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py +++ b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py @@ -27,9 +27,6 @@ # Fixed-size budget (retain top-K tokens per head) python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 2048 - # Percentile-based eviction (evict 70% at each prune step) - python hf_triattention.py --model Qwen/Qwen3-0.6B --target-sparsity-ratio 0.7 - # With custom budget and calibration length python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 1024 --calib-seq-len 4096 @@ -119,20 +116,11 @@ def main(args): print(f" Calibration complete in {elapsed:.1f}s") # Step 2: Apply KV cache sparsity mode - if args.target_sparsity_ratio is not None: - print( - f"\nApplying TriAttention mode (target_sparsity_ratio={args.target_sparsity_ratio})..." - ) - triattention_config = TriAttentionConfig( - target_sparsity_ratio=args.target_sparsity_ratio, - prune_interval=args.prune_interval, - ) - else: - print(f"\nApplying TriAttention mode (budget={args.budget})...") - triattention_config = TriAttentionConfig( - budget=args.budget, - prune_interval=args.prune_interval, - ) + print(f"\nApplying TriAttention mode (budget={args.budget})...") + triattention_config = TriAttentionConfig( + budget=args.budget, + prune_interval=args.prune_interval, + ) model = mtskv.sparsify(model, triattention_config) print(" Mode applied (no-op on weights).") @@ -184,22 +172,12 @@ def main(args): default="Qwen/Qwen3-0.6B", help="HuggingFace model name or local path.", ) - policy = parser.add_mutually_exclusive_group() - policy.add_argument( + parser.add_argument( "--budget", type=int, - default=None, + default=2048, help="KV token budget (tokens to retain per head). " - "Mutually exclusive with --target-sparsity-ratio. " - "Defaults to 2048 if neither is set.", - ) - policy.add_argument( - "--target-sparsity-ratio", - type=float, - default=None, - help="Fraction of tokens to evict at each prune step, in (0, 1). " - "Example: 0.7 evicts 70%% of tokens (keeps top 30%% by score). " - "Mutually exclusive with --budget.", + "Compression triggers after --prune-interval additional tokens.", ) parser.add_argument( "--prune-interval", @@ -229,6 +207,4 @@ def main(args): ) args = parser.parse_args() - if args.budget is None and args.target_sparsity_ratio is None: - args.budget = 2048 main(args) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index a0a28d78a7f..2c99c6171e2 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -25,7 +25,6 @@ import transformers from datasets import load_dataset from packaging.version import Version -from scripts.ar_validate import validate_ar from transformers import Trainer, TrainerCallback import modelopt @@ -41,6 +40,7 @@ ShardedDataset, VisionLanguageDataCollator, ) +from scripts.ar_validate import validate_ar try: import wandb diff --git a/modelopt/torch/sparsity/kv_cache/config.py b/modelopt/torch/sparsity/kv_cache/config.py index 344bfb2decb..475394b4b11 100644 --- a/modelopt/torch/sparsity/kv_cache/config.py +++ b/modelopt/torch/sparsity/kv_cache/config.py @@ -33,29 +33,19 @@ class TriAttentionConfig(ModeloptBaseConfig): pre-RoPE Q/K concentration. Calibration computes per-head frequency statistics; at runtime, the serving engine scores and evicts tokens periodically. - Exactly one of ``budget`` or ``target_sparsity_ratio`` must be set: - - - ``budget``: absolute token count to retain per head (fixed-size cache). - - ``target_sparsity_ratio``: fraction of tokens to evict at each pruning step. - Cache size auto-scales with generation length. Value in (0, 1). + ``budget`` is the absolute token count to retain per head after each pruning + round. Runtime compression follows slack/sawtooth semantics: grow until + ``budget + prune_interval`` tokens are present, then evict back to + ``budget``. """ - # Eviction policy (exactly one must be set) + # Eviction policy budget: int | None = ModeloptField( default=None, title="KV token budget (absolute).", description=( - "Number of KV tokens to retain per head after pruning. " - "Mutually exclusive with target_sparsity_ratio." - ), - ) - target_sparsity_ratio: float | None = ModeloptField( - default=None, - title="Target sparsity ratio (percentile-based).", - description=( - "Fraction of tokens to evict at each pruning step, in (0, 1). " - "Example: 0.7 means evict 70% of tokens (keep top 30% by score). " - "Mutually exclusive with budget." + "Number of KV tokens to retain per head after pruning. Runtime " + "compression triggers after prune_interval additional tokens." ), ) @@ -127,16 +117,10 @@ def validate_score_aggregation(cls, v: str) -> str: return v @model_validator(mode="after") - def validate_budget_or_sparsity(self) -> TriAttentionConfig: - """Exactly one of budget or target_sparsity_ratio must be set.""" - budget_set = self.budget is not None - sparsity_set = self.target_sparsity_ratio is not None - if not budget_set and not sparsity_set: - raise ValueError("Must set exactly one of 'budget' or 'target_sparsity_ratio'") - if budget_set and sparsity_set: - raise ValueError("Cannot set both 'budget' and 'target_sparsity_ratio'; pick one") - if sparsity_set and not (0.0 < self.target_sparsity_ratio < 1.0): - raise ValueError( - f"target_sparsity_ratio must be in (0, 1), got {self.target_sparsity_ratio}" - ) + def validate_budget(self) -> TriAttentionConfig: + """Validate the fixed KV token budget.""" + if self.budget is None: + raise ValueError("TriAttention requires 'budget' to be set") + if self.budget <= 0: + raise ValueError(f"budget must be positive, got {self.budget}") return self diff --git a/modelopt/torch/sparsity/kv_cache/triattention/scoring.py b/modelopt/torch/sparsity/kv_cache/triattention/scoring.py index e7740a029f9..657d90d0e7b 100644 --- a/modelopt/torch/sparsity/kv_cache/triattention/scoring.py +++ b/modelopt/torch/sparsity/kv_cache/triattention/scoring.py @@ -55,10 +55,11 @@ def compute_frequency_statistics_from_means( Args: q_mean_complex: Mean of Q in complex frequency domain, shape (freq_count,). - q_abs_mean: Mean of |Q| in frequency domain, shape (freq_count,). + q_abs_mean: Mean of ``|Q|`` in frequency domain, shape (freq_count,). k_unrot: Unrotated key vectors, shape (num_keys, head_dim). style: RoPE pairing style. - disable_mlr: If True, use q_abs_mean directly instead of (q_abs_mean - |q_mean|). + disable_mlr: If True, use ``q_abs_mean`` directly instead of + ``q_abs_mean - |q_mean|``. Returns: amp: Amplitude, shape (num_keys, freq_count). @@ -135,37 +136,27 @@ def select_keys_to_keep( scores: torch.Tensor, *, kv_budget: int | None = None, - target_sparsity_ratio: float | None = None, ) -> torch.Tensor: """Select which keys to retain based on importance scores. - Exactly one of ``kv_budget`` or ``target_sparsity_ratio`` must be provided. - Args: scores: Importance scores, shape (num_keys,). Higher = more important. kv_budget: Absolute number of tokens to retain. Keeps top-K. If budget >= num_keys, keeps all. - target_sparsity_ratio: Fraction of tokens to evict, in (0, 1). - Keeps top (1 - ratio) fraction. Example: 0.7 → keep top 30%. Returns: Boolean mask, shape (num_keys,). True = keep, False = evict. """ - budget_set = kv_budget is not None - sparsity_set = target_sparsity_ratio is not None - if budget_set == sparsity_set: - raise ValueError( - "select_keys_to_keep requires exactly one of kv_budget or target_sparsity_ratio" - ) + if kv_budget is None: + raise ValueError("select_keys_to_keep requires kv_budget") + if kv_budget <= 0: + raise ValueError(f"kv_budget must be positive, got {kv_budget}") num_keys = scores.shape[0] if num_keys == 0: return torch.zeros(0, dtype=torch.bool, device=scores.device) - if budget_set: - k = min(kv_budget, num_keys) - else: - k = max(1, round(num_keys * (1.0 - target_sparsity_ratio))) + k = min(kv_budget, num_keys) if k >= num_keys: return torch.ones(num_keys, dtype=torch.bool, device=scores.device) diff --git a/tests/unit/torch/sparsity/kv_cache/test_config.py b/tests/unit/torch/sparsity/kv_cache/test_config.py index 26c4435ba7f..4e549f4c639 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_config.py +++ b/tests/unit/torch/sparsity/kv_cache/test_config.py @@ -25,44 +25,24 @@ def test_budget_only(): """Setting only budget is valid.""" config = TriAttentionConfig(budget=2048) assert config.budget == 2048 - assert config.target_sparsity_ratio is None -def test_target_sparsity_only(): - """Setting only target_sparsity_ratio is valid.""" - config = TriAttentionConfig(target_sparsity_ratio=0.7) - assert config.budget is None - assert config.target_sparsity_ratio == 0.7 - - -def test_both_budget_and_sparsity_raises(): - """Setting both budget and target_sparsity_ratio raises.""" - with pytest.raises(ValidationError, match="Cannot set both"): +def test_target_sparsity_ratio_is_not_supported(): + """Ratio-based eviction is not part of the TriAttention API.""" + with pytest.raises(ValidationError): TriAttentionConfig(budget=2048, target_sparsity_ratio=0.7) -def test_neither_budget_nor_sparsity_raises(): - """Setting neither budget nor target_sparsity_ratio raises.""" - with pytest.raises(ValidationError, match="Must set exactly one"): +def test_missing_budget_raises(): + """TriAttention requires an explicit budget.""" + with pytest.raises(ValidationError, match="requires 'budget'"): TriAttentionConfig() -def test_target_sparsity_out_of_range_low(): - """target_sparsity_ratio <= 0 raises.""" - with pytest.raises(ValidationError, match="must be in"): - TriAttentionConfig(target_sparsity_ratio=0.0) - - -def test_target_sparsity_out_of_range_high(): - """target_sparsity_ratio >= 1 raises.""" - with pytest.raises(ValidationError, match="must be in"): - TriAttentionConfig(target_sparsity_ratio=1.0) - - -def test_target_sparsity_negative(): - """Negative target_sparsity_ratio raises.""" - with pytest.raises(ValidationError): - TriAttentionConfig(target_sparsity_ratio=-0.1) +def test_non_positive_budget_raises(): + """Budget must be positive.""" + with pytest.raises(ValidationError, match="budget must be positive"): + TriAttentionConfig(budget=0) def test_config_custom_values(): @@ -91,17 +71,6 @@ def test_config_serialization_roundtrip_budget(): data = config.model_dump() restored = TriAttentionConfig(**data) assert restored.budget == 1024 - assert restored.target_sparsity_ratio is None - assert restored.prune_interval == 64 - - -def test_config_serialization_roundtrip_sparsity(): - """Config with target_sparsity_ratio survives serialization roundtrip.""" - config = TriAttentionConfig(target_sparsity_ratio=0.5, prune_interval=64) - data = config.model_dump() - restored = TriAttentionConfig(**data) - assert restored.budget is None - assert restored.target_sparsity_ratio == 0.5 assert restored.prune_interval == 64 diff --git a/tests/unit/torch/sparsity/kv_cache/test_conversion.py b/tests/unit/torch/sparsity/kv_cache/test_conversion.py index 50efa9d024e..8cc4bf300ca 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_conversion.py +++ b/tests/unit/torch/sparsity/kv_cache/test_conversion.py @@ -92,20 +92,8 @@ def test_update_metadata(): assert metadata["triattention_config"]["budget"] == 512 -def test_convert_metadata_with_sparsity_ratio(): - """Metadata serializes target_sparsity_ratio when set.""" - model = nn.Linear(16, 16) - config = TriAttentionConfig(target_sparsity_ratio=0.7) - - _, metadata = convert_triattention(model, config) - - serialized = metadata["triattention_config"] - assert serialized["target_sparsity_ratio"] == 0.7 - assert serialized["budget"] is None - - def test_convert_metadata_with_budget(): - """Metadata has budget set and target_sparsity_ratio None.""" + """Metadata has budget set.""" model = nn.Linear(16, 16) config = TriAttentionConfig(budget=1024) @@ -113,16 +101,3 @@ def test_convert_metadata_with_budget(): serialized = metadata["triattention_config"] assert serialized["budget"] == 1024 - assert serialized["target_sparsity_ratio"] is None - - -def test_update_metadata_with_sparsity_ratio(): - """update_triattention_metadata serializes target_sparsity_ratio.""" - model = nn.Linear(16, 16) - config = TriAttentionConfig(target_sparsity_ratio=0.5) - metadata = {} - - update_triattention_metadata(model, config, metadata) - - assert metadata["triattention_config"]["target_sparsity_ratio"] == 0.5 - assert metadata["triattention_config"]["budget"] is None diff --git a/tests/unit/torch/sparsity/kv_cache/test_scoring.py b/tests/unit/torch/sparsity/kv_cache/test_scoring.py index 903820ff084..4d206cc7fb6 100644 --- a/tests/unit/torch/sparsity/kv_cache/test_scoring.py +++ b/tests/unit/torch/sparsity/kv_cache/test_scoring.py @@ -207,38 +207,18 @@ def test_select_keys_top_k_exceeds_size(): assert mask.shape == scores.shape -def test_select_keys_percentile_basic(): - """Percentile selection evicts target fraction.""" - scores = torch.arange(10, dtype=torch.float32) - # sparsity=0.7 → evict 70%, keep top 30% (3 tokens) - mask = select_keys_to_keep(scores, target_sparsity_ratio=0.7) - assert mask.dtype == torch.bool - assert mask.sum().item() == 3 - # Top 3 by score are indices 7, 8, 9 - assert mask[7].item() is True - assert mask[8].item() is True - assert mask[9].item() is True - - -def test_select_keys_percentile_half(): - """50% sparsity keeps half the tokens.""" - scores = torch.arange(20, dtype=torch.float32) - mask = select_keys_to_keep(scores, target_sparsity_ratio=0.5) - assert mask.sum().item() == 10 - - -def test_select_keys_both_raises(): - """Setting both budget and target_sparsity_ratio raises.""" +def test_select_keys_missing_budget_raises(): + """Budget is required for selection.""" scores = torch.rand(10) - with pytest.raises(ValueError, match="exactly one"): - select_keys_to_keep(scores, kv_budget=5, target_sparsity_ratio=0.5) + with pytest.raises(ValueError, match="requires kv_budget"): + select_keys_to_keep(scores) -def test_select_keys_neither_raises(): - """Setting neither budget nor target_sparsity_ratio raises.""" +def test_select_keys_non_positive_budget_raises(): + """Budget must be positive.""" scores = torch.rand(10) - with pytest.raises(ValueError, match="exactly one"): - select_keys_to_keep(scores) + with pytest.raises(ValueError, match="must be positive"): + select_keys_to_keep(scores, kv_budget=0) def test_select_keys_empty_scores():