diff --git a/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py new file mode 100644 index 00000000000..7d2103d4a53 --- /dev/null +++ b/examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example: TriAttention KV cache sparsity calibration on HuggingFace models. + +Demonstrates the TriAttention calibration pipeline: +1. Load a pretrained HF model +2. Run calibration with a forward pass to compute per-head frequency statistics +3. Apply KV cache sparsity mode (sparsify) +4. Verify calibration data was produced +5. 
Optionally save calibration data + +Usage: + # Fixed-size budget (retain top-K tokens per head) + python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 2048 + + # With custom budget and calibration length + python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 1024 --calib-seq-len 4096 + + # Save calibration data + python hf_triattention.py --model Qwen/Qwen3-0.6B --budget 2048 --output calibration.pt +""" + +import argparse +import time + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.kv_cache as mtskv +from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig + + +def make_calibration_forward_loop(tokenizer, seq_len: int = 2048, input_file: str | None = None): + """Create a forward loop that runs calibration data through the model. + + Args: + tokenizer: Tokenizer for the model. + seq_len: Sequence length for calibration input. + input_file: Path to a plain text file for calibration. If None, uses + built-in placeholder text. + + Returns: + Callable that takes a model and runs a forward pass. + """ + if input_file is not None: + from pathlib import Path + + calib_text = Path(input_file).read_text(encoding="utf-8") + else: + calib_text = ( + "The quick brown fox jumps over the lazy dog. " + "Machine learning is a subset of artificial intelligence that enables systems " + "to learn and improve from experience without being explicitly programmed. " + "Deep learning, a branch of machine learning, uses neural networks with many " + "layers to model complex patterns in data. Transformers have revolutionized " + "natural language processing by introducing self-attention mechanisms that " + "allow models to weigh the importance of different parts of the input. 
" + ) * 100 + + input_ids = tokenizer.encode( + calib_text, return_tensors="pt", truncation=True, max_length=seq_len + ) + print(f" Calibration tokens: {input_ids.shape[1]}") + + def forward_loop(model): + device = next(model.parameters()).device + inputs = input_ids.to(device) + with torch.no_grad(): + model(input_ids=inputs) + + return forward_loop + + +def main(args): + """Run TriAttention calibration pipeline.""" + print(f"Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + model.eval() + + # Print model info + config = model.config + print(f" Layers: {config.num_hidden_layers}") + print(f" Heads: {config.num_attention_heads}") + print(f" KV heads: {getattr(config, 'num_key_value_heads', config.num_attention_heads)}") + print(f" Hidden size: {config.hidden_size}") + + # Step 1: Calibrate (before sparsify — apply_mode adds wrappers that alter forward pass) + print(f"\nRunning calibration (seq_len={args.calib_seq_len})...") + forward_loop = make_calibration_forward_loop( + tokenizer, seq_len=args.calib_seq_len, input_file=args.input + ) + + t0 = time.time() + model = mtskv.calibrate(model, forward_loop=forward_loop) + elapsed = time.time() - t0 + print(f" Calibration complete in {elapsed:.1f}s") + + # Step 2: Apply KV cache sparsity mode + print(f"\nApplying TriAttention mode (budget={args.budget})...") + triattention_config = TriAttentionConfig( + budget=args.budget, + prune_interval=args.prune_interval, + ) + model = mtskv.sparsify(model, triattention_config) + print(" Mode applied (no-op on weights).") + + # Step 3: Verify calibration data + calib_data = getattr(model, "_triattention_calibration", None) + if calib_data is None: + print("\n ERROR: No calibration data found on model!") + return + + print("\n Calibration results:") + print(f" Head dim: 
{calib_data.head_dim}") + print(f" RoPE style: {calib_data.rope_style}") + print(f" Num layers: {calib_data.num_layers}") + print(f" Num KV heads: {calib_data.num_kv_heads}") + print(f" Heads calibrated: {len(calib_data.head_stats)}") + + # Check concentration (Mean Resultant Length) + total_heads = 0 + concentrated = 0 + for stats in calib_data.head_stats.values(): + abs_mean = torch.abs(stats.q_mean_complex) # |E[q]| + mean_abs = stats.q_abs_mean # E[|q|] + # R = |E[q]| / E[|q|] — concentration metric + r_values = abs_mean / (mean_abs + 1e-8) + r_mean = r_values.mean().item() + total_heads += 1 + if r_mean > 0.9: + concentrated += 1 + + print(f" Concentrated heads (R > 0.9): {concentrated}/{total_heads}") + print(f" Concentration ratio: {concentrated / total_heads:.1%}") + + # Step 4: Optionally save calibration data + if args.output: + state = calib_data.state_dict() + torch.save(state, args.output) + print(f"\n Calibration data saved to: {args.output}") + + print("\nDone.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="TriAttention KV cache sparsity calibration example." + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-0.6B", + help="HuggingFace model name or local path.", + ) + parser.add_argument( + "--budget", + type=int, + default=2048, + help="KV token budget (tokens to retain per head). " + "Compression triggers after --prune-interval additional tokens.", + ) + parser.add_argument( + "--prune-interval", + type=int, + default=128, + help="Re-score and evict every N tokens.", + ) + parser.add_argument( + "--input", + type=str, + default=None, + help="Plain text file for calibration input. " + "If not provided, uses built-in placeholder text. 
" + "Use the same file as triattention/scripts/calibrate.py --input for comparison.", + ) + parser.add_argument( + "--calib-seq-len", + type=int, + default=2048, + help="Sequence length for calibration input.", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Path to save calibration data (.pt file).", + ) + + args = parser.parse_args() + main(args) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index a0a28d78a7f..2c99c6171e2 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -25,7 +25,6 @@ import transformers from datasets import load_dataset from packaging.version import Version -from scripts.ar_validate import validate_ar from transformers import Trainer, TrainerCallback import modelopt @@ -41,6 +40,7 @@ ShardedDataset, VisionLanguageDataCollator, ) +from scripts.ar_validate import validate_ar try: import wandb diff --git a/modelopt/torch/sparsity/__init__.py b/modelopt/torch/sparsity/__init__.py index 2013fded1ae..55a59ae4e7b 100644 --- a/modelopt/torch/sparsity/__init__.py +++ b/modelopt/torch/sparsity/__init__.py @@ -20,5 +20,7 @@ """ # Import weight sparsity for backward compatibility +# Import kv_cache to register KV cache sparsity modes +from . import kv_cache from .weight_sparsity import mode, module, plugins from .weight_sparsity.sparsification import * diff --git a/modelopt/torch/sparsity/kv_cache/__init__.py b/modelopt/torch/sparsity/kv_cache/__init__.py new file mode 100644 index 00000000000..3229b8162d1 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TriAttentionConfig(ModeloptBaseConfig):
    """Configuration for TriAttention KV cache eviction.

    TriAttention scores cached KV entries using a trigonometric model derived from
    pre-RoPE Q/K concentration. Calibration computes per-head frequency statistics;
    at runtime, the serving engine scores and evicts tokens periodically.

    ``budget`` is the absolute token count to retain per head after each pruning
    round. Runtime compression follows slack/sawtooth semantics: grow until
    ``budget + prune_interval`` tokens are present, then evict back to
    ``budget``.

    ``budget`` is required even though it is typed as optional: the model
    validator rejects ``None``. The schedule and calibration sizes must be
    positive and ``window_size`` must be non-negative.
    """

    # Eviction policy
    budget: int | None = ModeloptField(
        default=None,
        title="KV token budget (absolute).",
        description=(
            "Number of KV tokens to retain per head after pruning. Runtime "
            "compression triggers after prune_interval additional tokens."
        ),
    )

    # Pruning schedule
    prune_interval: int = ModeloptField(
        default=128,
        title="Pruning interval.",
        description="Re-score and evict every N generated tokens.",
    )
    window_size: int = ModeloptField(
        default=128,
        title="Protected window size.",
        description="Number of most recent tokens always retained.",
    )

    # Scoring
    pruning_mode: Literal["per_head", "per_layer_per_head"] = ModeloptField(
        default="per_head",
        title="Pruning mode.",
        description=(
            "'per_head': independent budget per KV head. "
            "'per_layer_per_head': budget allocated per layer and head."
        ),
    )
    score_aggregation: Literal["mean", "max"] = ModeloptField(
        default="mean",
        title="Offset score aggregation.",
        description="How to aggregate scores across geometric offsets.",
    )
    offset_max_length: int = ModeloptField(
        default=65536,
        title="Maximum geometric offset.",
        description="Offsets are [1, 2, 4, ..., offset_max_length].",
    )
    disable_mlr: bool = ModeloptField(
        default=False,
        title="Disable MLR term.",
        description="If True, disable the magnitude linear regression extra term.",
    )
    disable_trig: bool = ModeloptField(
        default=False,
        title="Disable trigonometric term.",
        description="If True, use only the additive (MLR) term for scoring.",
    )

    # Calibration
    calib_size: int = ModeloptField(
        default=100000,
        title="Calibration tokens.",
        description="Number of tokens for calibration. 50K-960K, any domain.",
    )

    @field_validator("prune_interval", "offset_max_length", "calib_size")
    @classmethod
    def validate_positive(cls, v: int) -> int:
        """Reject zero or negative counts (pydantic reports the field name)."""
        if v <= 0:
            raise ValueError(f"must be a positive integer, got {v}")
        return v

    @field_validator("window_size")
    @classmethod
    def validate_window_size(cls, v: int) -> int:
        """Reject a negative protected window; zero (no protected window) is allowed."""
        if v < 0:
            raise ValueError(f"window_size must be non-negative, got {v}")
        return v

    @field_validator("pruning_mode")
    @classmethod
    def validate_pruning_mode(cls, v: str) -> str:
        """Validate pruning_mode is a supported value."""
        # Duplicates the Literal annotation on the field; kept as an explicit guard.
        valid = {"per_head", "per_layer_per_head"}
        if v not in valid:
            raise ValueError(f"pruning_mode must be one of {valid}, got '{v}'")
        return v

    @field_validator("score_aggregation")
    @classmethod
    def validate_score_aggregation(cls, v: str) -> str:
        """Validate score_aggregation is a supported value."""
        # Duplicates the Literal annotation on the field; kept as an explicit guard.
        valid = {"mean", "max"}
        if v not in valid:
            raise ValueError(f"score_aggregation must be one of {valid}, got '{v}'")
        return v

    @model_validator(mode="after")
    def validate_budget(self) -> TriAttentionConfig:
        """Validate the fixed KV token budget.

        Raises:
            ValueError: If ``budget`` is unset or not positive.
        """
        if self.budget is None:
            raise ValueError("TriAttention requires 'budget' to be set")
        if self.budget <= 0:
            raise ValueError(f"budget must be positive, got {self.budget}")
        return self
def convert_triattention(model: nn.Module, config: TriAttentionConfig) -> ConvertReturnType:
    """Apply TriAttention mode to a model without touching it.

    Args:
        model: Model being converted; returned as-is.
        config: Mode configuration; its ``model_dump()`` is recorded.

    Returns:
        ``(model, metadata)`` where ``metadata["triattention_config"]`` holds
        the dumped configuration.
    """
    return model, {"triattention_config": config.model_dump()}


def restore_triattention(
    model: nn.Module, config: TriAttentionConfig, metadata: MetadataDict
) -> nn.Module:
    """Restore TriAttention mode from saved state.

    Currently a pass-through: neither ``config`` nor ``metadata`` is read and
    the model is returned unchanged.
    """
    return model


def update_triattention_metadata(
    model: nn.Module, config: TriAttentionConfig, metadata: MetadataDict
) -> None:
    """Refresh the stored configuration in ``metadata`` before saving.

    Only the ``"triattention_config"`` entry is (over)written, in place;
    ``model`` is not inspected.
    """
    metadata.update(triattention_config=config.model_dump())
@KVCacheSparsityRegistry.register_mode
class TriAttentionModeDescriptor(ModeDescriptor):
    """Mode descriptor for TriAttention KV cache sparsity.

    Registered in ``KVCacheSparsityRegistry`` under the name ``"triattention"``.
    Each property below simply returns the config class or entrypoint function
    imported from ``.config`` / ``.conversion``; the descriptor holds no state
    of its own.

    TriAttention is a calibration-only mode: convert is a no-op on model weights
    and records the configuration in metadata.
    """

    @property
    def name(self) -> str:
        """Return the mode name."""
        return "triattention"

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """Return the configuration class."""
        return TriAttentionConfig

    @property
    def convert(self) -> ConvertEntrypoint:
        """Return the convert entrypoint."""
        return convert_triattention

    @property
    def restore(self) -> RestoreEntrypoint:
        """Return the restore entrypoint."""
        return restore_triattention

    @property
    def update_for_save(self) -> UpdateEntrypoint:
        """Return the update-for-save entrypoint."""
        return update_triattention_metadata
def sparsify(
    model: nn.Module,
    config: dict[str, Any] | TriAttentionConfig,
    forward_loop: ForwardLoop | None = None,
) -> nn.Module:
    """Apply KV cache sparsity optimization to a model.

    Registers the ``"triattention"`` mode on the model through ``apply_mode``
    and, when a forward loop is supplied, calibrates right afterwards.

    Args:
        model: The model to optimize.
        config: TriAttentionConfig, or a dict of its field values.
        forward_loop: Optional forward loop for integrated calibration.

    Returns:
        The model returned by ``apply_mode`` (calibrated if ``forward_loop``
        was given).
    """
    cfg = TriAttentionConfig(**config) if isinstance(config, dict) else config

    model = apply_mode(
        model,
        mode=[("triattention", cfg.model_dump())],
        registry=KVCacheSparsityRegistry,
    )

    if forward_loop is None:
        return model
    return calibrate(model, cfg, forward_loop=forward_loop)


def calibrate(
    model: nn.Module,
    config: dict[str, Any] | TriAttentionConfig | None = None,
    forward_loop: ForwardLoop | None = None,
) -> nn.Module:
    """Calibrate TriAttention frequency statistics.

    Delegates to ``run_calibration`` and attaches its result to the model as
    the ``_triattention_calibration`` attribute.

    Args:
        model: Model to calibrate.
        config: Accepted for API symmetry with ``sparsify``; not used here.
        forward_loop: Callable that runs forward passes on calibration data.
            If None, calibration is skipped (no-op).

    Returns:
        The same model object, with ``_triattention_calibration`` set when
        calibration ran.
    """
    if forward_loop is not None:
        # Imported lazily so the calibration module loads only when needed.
        from .triattention.calibration import run_calibration

        # Store calibration data in model attribute for later export
        model._triattention_calibration = run_calibration(model, forward_loop=forward_loop)

    return model
+# SPDX-License-Identifier: Apache-2.0 +"""TriAttention: Trigonometric KV cache compression.""" + +from .rope_utils import build_geometric_offsets, invert_rope, rotate_half, to_complex_pairs +from .scoring import ( + HeadFrequencyStats, + compute_frequency_statistics_from_means, + score_keys_for_round, + select_keys_to_keep, +) + +__all__ = [ + "HeadFrequencyStats", + "build_geometric_offsets", + "compute_frequency_statistics_from_means", + "invert_rope", + "rotate_half", + "score_keys_for_round", + "select_keys_to_keep", + "to_complex_pairs", +] diff --git a/modelopt/torch/sparsity/kv_cache/triattention/calibration.py b/modelopt/torch/sparsity/kv_cache/triattention/calibration.py new file mode 100644 index 00000000000..fbf0d256ab6 --- /dev/null +++ b/modelopt/torch/sparsity/kv_cache/triattention/calibration.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration for TriAttention: compute per-head Q/K frequency statistics. + +Hooks into attention layers during a forward pass, captures pre-RoPE Q vectors, +inverts RoPE, converts to frequency domain, and computes per-head mean statistics. 
@dataclass
class CalibrationData:
    """Container for TriAttention calibration output.

    Holds the per-head frequency statistics together with the model shape
    information recorded at calibration time, and converts to/from a plain
    dictionary of tensors and builtins.
    """

    head_stats: dict[tuple[int, int], HeadFrequencyStats]  # (layer, head) -> stats
    head_dim: int
    rope_style: str
    num_layers: int
    num_kv_heads: int

    def state_dict(self) -> dict[str, Any]:
        """Serialize to state dict for checkpoint embedding.

        Complex means are split into real and imaginary tensors; every tensor
        is moved to CPU. Entries are keyed ``layerLL_headHH``.
        """
        stats_serialized = {
            f"layer{layer:02d}_head{head:02d}": {
                "q_mean_real": hs.q_mean_complex.real.cpu(),
                "q_mean_imag": hs.q_mean_complex.imag.cpu(),
                "q_abs_mean": hs.q_abs_mean.cpu(),
            }
            for (layer, head), hs in self.head_stats.items()
        }
        metadata = {
            "head_dim": self.head_dim,
            "rope_style": self.rope_style,
            "num_layers": self.num_layers,
            "num_kv_heads": self.num_kv_heads,
            "sampled_heads": [[layer, head] for layer, head in self.head_stats],
        }
        return {"metadata": metadata, "stats": stats_serialized}

    @classmethod
    def from_state_dict(cls, state: dict[str, Any]) -> CalibrationData:
        """Deserialize from state dict.

        Rebuilds the complex means from their float32 real/imaginary parts,
        visiting heads in the order listed under ``metadata["sampled_heads"]``.
        """
        metadata = state["metadata"]
        stats_raw = state["stats"]
        sampled_heads = [tuple(pair) for pair in metadata["sampled_heads"]]

        head_stats: dict[tuple[int, int], HeadFrequencyStats] = {}
        for layer, head in sampled_heads:
            entry = stats_raw[f"layer{layer:02d}_head{head:02d}"]
            real_part = entry["q_mean_real"].to(torch.float32)
            imag_part = entry["q_mean_imag"].to(torch.float32)
            head_stats[(int(layer), int(head))] = HeadFrequencyStats(
                q_mean_complex=torch.complex(real_part, imag_part),
                q_abs_mean=entry["q_abs_mean"].to(torch.float32),
            )

        return cls(
            head_stats=head_stats,
            head_dim=metadata["head_dim"],
            rope_style=metadata["rope_style"],
            num_layers=metadata["num_layers"],
            num_kv_heads=metadata["num_kv_heads"],
        )


def compute_head_stats_from_q(
    q_pre_rope: torch.Tensor,
    style: str = "half",
) -> HeadFrequencyStats:
    """Compute frequency statistics for a single head from pre-RoPE Q vectors.

    Args:
        q_pre_rope: Pre-RoPE query vectors for one head, shape (seq_len, head_dim).
        style: RoPE pairing style, forwarded to ``to_complex_pairs``.

    Returns:
        HeadFrequencyStats holding the mean complex value and the mean
        magnitude, both reduced over the first dimension.
    """
    pairs = to_complex_pairs(q_pre_rope, style=style)  # (seq_len, freq_count)
    return HeadFrequencyStats(
        q_mean_complex=pairs.mean(dim=0),  # (freq_count,)
        q_abs_mean=pairs.abs().mean(dim=0),  # (freq_count,)
    )
def _find_attention_layers(model: torch.nn.Module) -> list[torch.nn.Module]:
    """Find attention sub-modules in HF model (model.model.layers[i].self_attn).

    Raises:
        RuntimeError: If no ``layers`` attribute is found, or a layer has no
            ``self_attn``.
    """
    backbone = getattr(model, "model", model)
    layer_list = getattr(backbone, "layers", None)
    if layer_list is None:
        raise RuntimeError(
            "Cannot locate transformer layers. Expected model.model.layers attribute."
        )
    layers = []
    for layer_module in layer_list:
        attn = getattr(layer_module, "self_attn", None)
        if attn is None:
            raise RuntimeError("Layer missing self_attn attribute.")
        layers.append(attn)
    return layers


def _get_rotary_embedding(model: torch.nn.Module) -> torch.nn.Module:
    """Find the rotary embedding module in the model.

    Looks on the backbone first, then on the first attention layer.

    Raises:
        RuntimeError: If neither location has a ``rotary_emb`` attribute.
    """
    backbone = getattr(model, "model", model)
    if hasattr(backbone, "rotary_emb"):
        return backbone.rotary_emb
    # Some models put rotary_emb on individual attention layers
    attn_layers = _find_attention_layers(model)
    if attn_layers and hasattr(attn_layers[0], "rotary_emb"):
        return attn_layers[0].rotary_emb
    raise RuntimeError("Cannot locate rotary_emb on model.model or self_attn.")


def _get_model_config(model: torch.nn.Module) -> dict[str, Any]:
    """Extract model config parameters needed for calibration.

    Prefers ``model.config``; falls back to inspecting the first attention
    layer when the config is missing or incomplete.

    Returns:
        Dict with ``num_layers``, ``num_heads``, ``head_dim`` and ``num_kv_heads``.

    Raises:
        RuntimeError: If the layers cannot be located, the model has no
            attention layers, or the head geometry cannot be determined.
    """
    config = getattr(model, "config", None)
    if config is not None:
        num_layers = getattr(config, "num_hidden_layers", None)
        num_heads = getattr(config, "num_attention_heads", None)
        hidden_size = getattr(config, "hidden_size", None)
        head_dim = getattr(config, "head_dim", None)
        # A config may carry num_key_value_heads=None; treat that as "same as
        # num_heads" instead of discarding the whole config.
        num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
        if head_dim is None and hidden_size and num_heads:
            head_dim = hidden_size // num_heads
        if all(v is not None for v in [num_layers, num_heads, head_dim, num_kv_heads]):
            return {
                "num_layers": num_layers,
                "num_heads": num_heads,
                "head_dim": head_dim,
                "num_kv_heads": num_kv_heads,
            }

    # Fallback: infer from model structure
    attn_layers = _find_attention_layers(model)
    if not attn_layers:
        raise RuntimeError("Model has no attention layers to inspect.")
    num_layers = len(attn_layers)
    attn0 = attn_layers[0]
    num_heads = getattr(attn0, "num_heads", None)
    head_dim = getattr(attn0, "head_dim", None)
    num_kv_heads = getattr(attn0, "num_key_value_heads", num_heads)

    if num_heads is None or head_dim is None:
        # Infer from q_proj weight shape
        q_proj = getattr(attn0, "q_proj", None)
        out_features = getattr(q_proj, "out_features", None)
        if out_features:
            if head_dim:
                # head_dim is known: derive the matching head count.
                if out_features % head_dim == 0:
                    num_heads = out_features // head_dim
            elif num_heads:
                # num_heads is known: derive the matching head_dim.
                if out_features % num_heads == 0:
                    head_dim = out_features // num_heads
            else:
                # Guess: neither is available, try common head_dims
                for hd in (128, 96, 64, 32):
                    if out_features % hd == 0:
                        head_dim = hd
                        num_heads = out_features // hd
                        break

    if num_heads is None or head_dim is None:
        raise RuntimeError("Cannot determine num_heads and head_dim from model.")

    num_kv_heads = num_kv_heads or num_heads
    return {
        "num_layers": num_layers,
        "num_heads": num_heads,
        "head_dim": head_dim,
        "num_kv_heads": num_kv_heads,
    }
def run_calibration(
    model: torch.nn.Module,
    forward_loop: Any = None,
    rope_style: str = "half",
) -> CalibrationData:
    """Run TriAttention calibration on a model.

    A forward pre-hook on every attention layer recomputes the layer's queries
    from its input hidden states (``q_proj``, plus ``q_norm`` when the layer
    has one), applies RoPE with positions ``0..seq_len-1`` and stores the
    result. After the forward loop the stored queries are passed through
    ``invert_rope`` and reduced to per-head frequency statistics.

    Limitations visible in this implementation: each layer keeps only the
    queries from its most recent forward call, and only batch element 0 is
    used for the statistics.

    Args:
        model: The model to calibrate. Must follow HF structure
            (model.model.layers[i].self_attn with q_proj).
        forward_loop: Callable that takes the model and runs forward passes
            on calibration data. Signature: ``forward_loop(model) -> None``.
            If None, no forward pass runs and the result has no head stats.
        rope_style: RoPE pairing style ('half' or 'interleaved').

    Returns:
        CalibrationData with per-head frequency statistics.
    """
    model_cfg = _get_model_config(model)
    num_layers = model_cfg["num_layers"]
    num_heads = model_cfg["num_heads"]
    head_dim = model_cfg["head_dim"]
    num_kv_heads = model_cfg["num_kv_heads"]

    attn_layers = _find_attention_layers(model)
    rotary = _get_rotary_embedding(model)
    attn_scale = float(getattr(rotary, "attention_scaling", 1.0))

    # Storage for captured Q states
    captured_q: dict[int, torch.Tensor] = {}

    def _project_queries(module, hidden_states):
        """Return pre-RoPE queries shaped (batch, num_heads, seq_len, head_dim)."""
        bsz, q_len, _ = hidden_states.shape
        head_shape = (bsz, q_len, num_heads, head_dim)
        q = module.q_proj(hidden_states)
        q_norm = getattr(module, "q_norm", None)
        if q_norm is None:
            return q.view(head_shape).transpose(1, 2)
        # QK-norm models normalize queries before RoPE. The norm width tells
        # whether it acts per head (after the reshape) or on the full projection.
        weight = getattr(q_norm, "weight", None)
        norm_width = weight.shape[-1] if weight is not None else head_dim
        if norm_width == head_dim:
            q = q_norm(q.view(head_shape))
        else:
            q = q_norm(q).view(head_shape)
        return q.transpose(1, 2)

    def _make_pre_hook(layer_idx: int):
        def hook_fn(module, args, kwargs):
            hidden_states = args[0] if args else kwargs.get("hidden_states")
            if hidden_states is None:
                return
            q_len = hidden_states.shape[1]
            q = _project_queries(module, hidden_states)

            # Apply RoPE — ensure cos/sin are on the same device as hidden_states
            device = hidden_states.device
            pos_ids = torch.arange(q_len, device=device).unsqueeze(0)
            probe = torch.zeros(1, q_len, head_dim, device=device, dtype=hidden_states.dtype)
            cos, sin = rotary(probe, pos_ids)
            cos = cos.to(device=device)
            sin = sin.to(device=device)
            q_rot = (q * cos.unsqueeze(1)) + (rotate_half(q, style=rope_style) * sin.unsqueeze(1))
            q_rot = q_rot * attn_scale
            # Overwrites any earlier capture for this layer.
            captured_q[layer_idx] = q_rot.detach()

        return hook_fn

    # Register hooks
    handles = []
    for layer_idx, attn in enumerate(attn_layers):
        handle = attn.register_forward_pre_hook(_make_pre_hook(layer_idx), with_kwargs=True)
        handles.append(handle)

    # Run forward pass
    try:
        if forward_loop is not None:
            forward_loop(model)
    finally:
        # Always remove hooks
        for handle in handles:
            handle.remove()

    # Compute per-head frequency statistics
    head_stats: dict[tuple[int, int], HeadFrequencyStats] = {}

    for layer_idx in range(num_layers):
        # pop() releases the captured tensor once the layer is processed.
        q_rot = captured_q.pop(layer_idx, None)
        if q_rot is None:
            continue

        # q_rot: (batch, num_heads, seq_len, head_dim)
        # Build cos/sin for RoPE inversion — ensure same device as q_rot
        device = q_rot.device
        seq_len = q_rot.shape[2]
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        probe = torch.zeros(1, seq_len, head_dim, device=device, dtype=q_rot.dtype)
        cos, sin = rotary(probe, pos_ids)
        cos = cos.to(device=device).unsqueeze(1)  # (1, 1, seq_len, head_dim)
        sin = sin.to(device=device).unsqueeze(1)

        # Invert RoPE
        q_base = invert_rope(q_rot, cos, sin, attn_scale, style=rope_style)

        for head_idx in range(num_heads):
            q_head = q_base[0, head_idx]  # (seq_len, head_dim)
            head_stats[(layer_idx, head_idx)] = compute_head_stats_from_q(q_head, style=rope_style)

    if forward_loop is not None and not head_stats:
        import warnings

        # Surface an empty calibration instead of returning it silently.
        warnings.warn(
            "TriAttention calibration captured no query states; "
            "check that forward_loop actually runs the model.",
            stacklevel=2,
        )

    return CalibrationData(
        head_stats=head_stats,
        head_dim=head_dim,
        rope_style=rope_style,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
    )
__all__ = [
    "build_geometric_offsets",
    "invert_rope",
    "rotate_half",
    "to_complex_pairs",
]


def rotate_half(x: torch.Tensor, *, style: str = "half") -> torch.Tensor:
    """Build the rotated companion of ``x`` used in the RoPE formula.

    For ``style="interleaved"`` each (even, odd) pair ``(a, b)`` becomes
    ``(-b, a)``. For any other style the last dimension is split at its
    midpoint into ``(front, back)`` and ``(-back, front)`` is returned.
    """
    if style != "interleaved":
        midpoint = x.shape[-1] // 2
        front, back = x[..., :midpoint], x[..., midpoint:]
        return torch.cat((-back, front), dim=-1)
    evens = x[..., ::2]
    odds = x[..., 1::2]
    return torch.stack((-odds, evens), dim=-1).flatten(-2)


def invert_rope(
    rotated: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    scale: float,
    *,
    style: str = "half",
) -> torch.Tensor:
    """Undo a RoPE rotation and return the pre-RoPE tensor.

    ``rotated``, ``cos`` and ``sin`` are each divided by ``scale`` before the
    inverse rotation is applied.

    Args:
        rotated: RoPE-rotated tensor.
        cos: Cosine table from rotary embedding.
        sin: Sine table from rotary embedding.
        scale: Attention scaling factor applied during RoPE.
        style: RoPE pairing style ('half' or 'interleaved').

    Returns:
        Tensor with the RoPE rotation removed.

    Raises:
        ValueError: If ``scale`` is zero.
    """
    if scale == 0:
        raise ValueError("attention scaling factor must be non-zero")
    divisor = torch.tensor(scale, device=rotated.device, dtype=rotated.dtype)
    unscaled = rotated / divisor
    cos_n = cos / divisor
    sin_n = sin / divisor

    if style != "interleaved":
        # Inverse rotation: same formula as the forward pass with sin negated.
        return unscaled * cos_n - rotate_half(unscaled, style=style) * sin_n

    # Interleaved pairs: solve the 2x2 system per (even, odd) pair explicitly.
    r_even, r_odd = unscaled[..., ::2], unscaled[..., 1::2]
    c_even, c_odd = cos_n[..., ::2], cos_n[..., 1::2]
    s_even, s_odd = sin_n[..., ::2], sin_n[..., 1::2]
    # Determinant is floored to avoid division by (near-)zero.
    det = (c_even * c_odd + s_even * s_odd).clamp_min(1e-12)
    out = torch.empty_like(unscaled)
    out[..., ::2] = (r_even * c_odd + r_odd * s_even) / det
    out[..., 1::2] = (r_odd * c_even - r_even * s_odd) / det
    return out


def to_complex_pairs(tensor: torch.Tensor, *, style: str = "half") -> torch.Tensor:
    """View a real tensor as complex numbers along its last dimension.

    The last dimension of size ``head_dim`` becomes ``head_dim / 2`` complex
    values. With 'interleaved' the even entries are the real parts and the odd
    entries the imaginary parts; otherwise the first half holds the real parts
    and the second half the imaginary parts. bfloat16/float16 inputs are
    promoted to float32 first.

    Args:
        tensor: Real-valued tensor with even last dimension.
        style: RoPE pairing style ('half' or 'interleaved').

    Returns:
        Complex tensor with last dimension halved.

    Raises:
        ValueError: If the last dimension is odd.
    """
    width = tensor.size(-1)
    if width % 2 != 0:
        raise ValueError("Head dimension must be even to form complex pairs")

    low_precision = tensor.dtype in (torch.bfloat16, torch.float16)
    values = tensor.to(dtype=torch.float32 if low_precision else tensor.dtype)

    if style == "interleaved":
        re_part, im_part = values[..., ::2], values[..., 1::2]
    else:
        half = tensor.shape[-1] // 2
        re_part, im_part = values[..., :half], values[..., half:]
    return torch.complex(re_part.contiguous(), im_part.contiguous())


def build_geometric_offsets(max_length: int, device: torch.device) -> torch.Tensor:
    """Return the powers of two ``1, 2, 4, ...`` that do not exceed ``max_length``.

    Used for multi-distance scoring in TriAttention — each offset represents a
    future distance at which the key's importance is evaluated.

    Args:
        max_length: Maximum offset value (must be >= 1).
        device: Device for the output tensor.

    Returns:
        1D float32 tensor of powers of 2 up to max_length.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be >= 1")
    powers: list[float] = []
    step = 1
    while step <= max_length:
        powers.append(float(step))
        step += step  # doubling
    return torch.tensor(powers, device=device, dtype=torch.float32)
__all__ = [
    "HeadFrequencyStats",
    "compute_frequency_statistics_from_means",
    "score_keys_for_round",
    "select_keys_to_keep",
]


@dataclass
class HeadFrequencyStats:
    """Per-head calibration statistics in frequency domain."""

    q_mean_complex: torch.Tensor  # (freq_count,) complex64
    q_abs_mean: torch.Tensor  # (freq_count,) float32


def compute_frequency_statistics_from_means(
    q_mean_complex: torch.Tensor,
    q_abs_mean: torch.Tensor,
    k_unrot: torch.Tensor,
    *,
    style: str = "half",
    disable_mlr: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute amplitude, phase, and MLR extra term from Q/K frequency statistics.

    Args:
        q_mean_complex: Mean of Q in complex frequency domain, shape (freq_count,).
        q_abs_mean: Mean of ``|Q|`` in frequency domain, shape (freq_count,).
        k_unrot: Unrotated key vectors, shape (num_keys, head_dim).
        style: RoPE pairing style.
        disable_mlr: If True, use ``q_abs_mean`` directly instead of
            ``q_abs_mean - |q_mean|``.

    Returns:
        amp: Amplitude ``|q_mean| * |k|``, shape (num_keys, freq_count).
        phi: Phase of ``q_mean * conj(k)``, shape (num_keys, freq_count).
        extra: MLR extra term, shape (num_keys, freq_count).
    """
    k_complex = to_complex_pairs(k_unrot, style=style)
    q_mean_abs = torch.abs(q_mean_complex)
    k_abs = torch.abs(k_complex)
    # Broadcast the per-frequency Q mean against every key.
    relative = q_mean_complex.unsqueeze(0) * torch.conj(k_complex)
    phi = torch.atan2(relative.imag, relative.real)
    amp = q_mean_abs.unsqueeze(0) * k_abs
    if disable_mlr:
        extra = q_abs_mean.unsqueeze(0) * k_abs
    else:
        extra = (q_abs_mean - q_mean_abs).unsqueeze(0) * k_abs
    return amp, phi, extra


def score_keys_for_round(
    key_indices: torch.Tensor,
    round_start: int,
    amp: torch.Tensor,
    phi: torch.Tensor,
    omega: torch.Tensor,
    extra: torch.Tensor,
    offsets: torch.Tensor,
    aggregation: str,
    freq_scale_sq: torch.Tensor,
    disable_trig: bool = False,
) -> torch.Tensor:
    """Score cached keys for a single pruning round.

    Evaluates the trigonometric importance formula over multiple future offsets
    and aggregates scores. When ``disable_trig`` is set the trigonometric grid
    is not evaluated at all; only the additive term is aggregated.

    Args:
        key_indices: Position indices of cached keys, shape (num_keys,).
        round_start: Current generation position.
        amp: Amplitude per key per frequency, shape (num_keys, freq_count).
        phi: Phase per key per frequency, shape (num_keys, freq_count).
        omega: RoPE frequencies (inv_freq), shape (freq_count,).
        extra: MLR extra term, shape (num_keys, freq_count).
        offsets: Geometric offsets for future distance sampling, shape (num_offsets,).
        aggregation: 'mean' averages over offsets; any other value takes the
            maximum over offsets.
        freq_scale_sq: Per-frequency scaling weights, shape (freq_count,).
        disable_trig: If True, use only the additive (MLR) term.

    Returns:
        Importance scores, shape (num_keys,). Higher = more important.
        An empty float32 tensor is returned when there are no keys.
    """
    if key_indices.numel() == 0:
        return torch.empty(0, device=amp.device, dtype=torch.float32)

    freq_scale_sq = freq_scale_sq.to(device=amp.device, dtype=torch.float32)

    # Position-independent term: identical for every offset, hence one column.
    additive = (extra * freq_scale_sq.view(1, -1)).sum(dim=1, keepdim=True)

    if disable_trig:
        # Skip the (num_keys, num_offsets, freq_count) cosine grid entirely;
        # it would be discarded anyway.
        combined = additive
    else:
        # Distance from each key to the current position, shifted by each offset.
        base_delta = round_start - key_indices.to(device=amp.device, dtype=torch.float32)
        delta_grid = base_delta.unsqueeze(1) + offsets.unsqueeze(0)  # (num_keys, num_offsets)

        phase = delta_grid.unsqueeze(2) * omega.view(1, 1, -1) + phi.unsqueeze(1)
        scale = freq_scale_sq.view(1, 1, -1)
        base_scores = (amp.unsqueeze(1) * scale * torch.cos(phase)).sum(dim=2)
        combined = base_scores + additive

    if aggregation == "mean":
        return combined.mean(dim=1)
    return combined.max(dim=1).values


def select_keys_to_keep(
    scores: torch.Tensor,
    *,
    kv_budget: int | None = None,
) -> torch.Tensor:
    """Select which keys to retain based on importance scores.

    Args:
        scores: Importance scores, shape (num_keys,). Higher = more important.
        kv_budget: Absolute number of tokens to retain. Keeps top-K.
            If budget >= num_keys, keeps all.

    Returns:
        Boolean mask, shape (num_keys,). True = keep, False = evict.

    Raises:
        ValueError: If ``kv_budget`` is None or not positive.
    """
    if kv_budget is None:
        raise ValueError("select_keys_to_keep requires kv_budget")
    if kv_budget <= 0:
        raise ValueError(f"kv_budget must be positive, got {kv_budget}")

    num_keys = scores.shape[0]
    if num_keys == 0:
        return torch.zeros(0, dtype=torch.bool, device=scores.device)

    k = min(kv_budget, num_keys)

    if k >= num_keys:
        # Budget covers every key; nothing to evict.
        return torch.ones(num_keys, dtype=torch.bool, device=scores.device)

    top_indices = torch.topk(scores, k=k, largest=True).indices
    mask = torch.zeros(num_keys, dtype=torch.bool, device=scores.device)
    mask[top_indices] = True
    return mask
def test_compute_head_stats_shapes():
    """Stats for a (tokens, head_dim) Q tensor have head_dim // 2 frequency bins."""
    num_tokens, width = 64, 16
    expected_shape = (width // 2,)

    result = compute_head_stats_from_q(torch.randn(num_tokens, width))

    assert result.q_mean_complex.shape == expected_shape
    assert result.q_mean_complex.dtype == torch.complex64
    assert result.q_abs_mean.shape == expected_shape
    assert result.q_abs_mean.dtype == torch.float32


def test_compute_head_stats_mean_abs_ge_abs_mean():
    """The mean magnitude is never below the magnitude of the mean."""
    result = compute_head_stats_from_q(torch.randn(128, 32))

    magnitude_of_mean = torch.abs(result.q_mean_complex)
    # Small tolerance for floating-point rounding.
    assert (result.q_abs_mean >= magnitude_of_mean - 1e-6).all()


def test_compute_head_stats_single_token():
    """With one token the mean magnitude equals the magnitude of the mean."""
    width = 8
    result = compute_head_stats_from_q(torch.randn(1, width))

    torch.testing.assert_close(result.q_abs_mean, torch.abs(result.q_mean_complex))


def test_calibration_data_state_dict_roundtrip():
    """state_dict() followed by from_state_dict() reproduces the data."""

    def _random_stats():
        return HeadFrequencyStats(
            q_mean_complex=torch.randn(8, dtype=torch.complex64),
            q_abs_mean=torch.rand(8),
        )

    original_stats = {head: _random_stats() for head in [(0, 0), (0, 1), (1, 0)]}
    source = CalibrationData(
        head_stats=original_stats,
        head_dim=16,
        rope_style="half",
        num_layers=2,
        num_kv_heads=2,
    )

    rebuilt = CalibrationData.from_state_dict(source.state_dict())

    assert rebuilt.head_dim == 16
    assert rebuilt.rope_style == "half"
    assert rebuilt.num_layers == 2
    assert rebuilt.num_kv_heads == 2
    assert len(rebuilt.head_stats) == 3

    for head in original_stats:
        after, before = rebuilt.head_stats[head], source.head_stats[head]
        torch.testing.assert_close(after.q_abs_mean, before.q_abs_mean)
        torch.testing.assert_close(after.q_mean_complex, before.q_mean_complex)


def test_calibration_data_state_dict_keys():
    """The state dict exposes 'metadata' and 'stats' with the expected entries."""
    single_head = HeadFrequencyStats(
        q_mean_complex=torch.randn(4, dtype=torch.complex64),
        q_abs_mean=torch.rand(4),
    )
    source = CalibrationData(
        head_stats={(2, 3): single_head},
        head_dim=8,
        rope_style="half",
        num_layers=4,
        num_kv_heads=8,
    )

    payload = source.state_dict()

    assert "metadata" in payload
    assert "stats" in payload
    assert "layer02_head03" in payload["stats"]
    assert payload["metadata"]["head_dim"] == 8
    assert payload["metadata"]["sampled_heads"] == [[2, 3]]
def test_budget_only():
    """A config built with just a budget is accepted."""
    cfg = TriAttentionConfig(budget=2048)
    assert cfg.budget == 2048


def test_target_sparsity_ratio_is_not_supported():
    """Passing target_sparsity_ratio is rejected by validation."""
    with pytest.raises(ValidationError):
        TriAttentionConfig(budget=2048, target_sparsity_ratio=0.7)


def test_missing_budget_raises():
    """Omitting the budget fails validation."""
    with pytest.raises(ValidationError, match="requires 'budget'"):
        TriAttentionConfig()


def test_non_positive_budget_raises():
    """A zero budget fails validation."""
    with pytest.raises(ValidationError, match="budget must be positive"):
        TriAttentionConfig(budget=0)


def test_config_custom_values():
    """Extra fields given next to the budget are stored as passed."""
    cfg = TriAttentionConfig(budget=4096, prune_interval=64, window_size=256)
    stored = (cfg.budget, cfg.prune_interval, cfg.window_size)
    assert stored[0] == 4096
    assert stored[1] == 64
    assert stored[2] == 256


def test_config_invalid_pruning_mode():
    """An unknown pruning_mode fails validation."""
    with pytest.raises(ValidationError):
        TriAttentionConfig(budget=2048, pruning_mode="invalid")


def test_config_invalid_aggregation():
    """An unknown score_aggregation fails validation."""
    with pytest.raises(ValidationError):
        TriAttentionConfig(budget=2048, score_aggregation="invalid")


def test_config_serialization_roundtrip_budget():
    """model_dump() output can be fed back into the constructor."""
    dumped = TriAttentionConfig(budget=1024, prune_interval=64).model_dump()
    rebuilt = TriAttentionConfig(**dumped)
    assert rebuilt.budget == 1024
    assert rebuilt.prune_interval == 64


def test_config_per_layer_per_head_mode():
    """'per_layer_per_head' is accepted as pruning_mode."""
    cfg = TriAttentionConfig(budget=2048, pruning_mode="per_layer_per_head")
    assert cfg.pruning_mode == "per_layer_per_head"


def test_config_max_aggregation():
    """'max' is accepted as score_aggregation."""
    cfg = TriAttentionConfig(budget=2048, score_aggregation="max")
    assert cfg.score_aggregation == "max"
def test_mode_registered():
    """The registry contains a 'triattention' entry."""
    assert "triattention" in KVCacheSparsityRegistry


def test_mode_descriptor_properties():
    """The registered descriptor reports its name and config class."""
    mode = KVCacheSparsityRegistry["triattention"]
    assert mode.name == "triattention"
    assert mode.config_class is TriAttentionConfig


def test_mode_discoverable_globally():
    """The mode can be looked up through the global registry helper."""
    from modelopt.torch.opt.mode import _ModeRegistryCls

    mode = _ModeRegistryCls.get_from_any("triattention")
    assert mode is not None
    assert mode.name == "triattention"


def test_convert_returns_model_and_metadata():
    """convert_triattention leaves weights untouched and returns metadata."""
    layer = nn.Linear(16, 16)
    weight_before = layer.weight.data.clone()

    converted, meta = convert_triattention(layer, TriAttentionConfig(budget=2048))

    torch.testing.assert_close(converted.weight.data, weight_before)
    assert "triattention_config" in meta


def test_convert_metadata_contains_config():
    """The metadata carries the config values that were passed in."""
    layer = nn.Linear(16, 16)
    cfg = TriAttentionConfig(budget=1024, prune_interval=64)

    _, meta = convert_triattention(layer, cfg)
    stored = meta["triattention_config"]

    assert stored["budget"] == 1024
    assert stored["prune_interval"] == 64


def test_restore_returns_model():
    """restore_triattention hands back the same model object."""
    layer = nn.Linear(16, 16)
    cfg = TriAttentionConfig(budget=2048)
    _, meta = convert_triattention(layer, cfg)

    assert restore_triattention(layer, cfg, meta) is layer


def test_update_metadata():
    """update_triattention_metadata writes the config into an empty dict."""
    layer = nn.Linear(16, 16)
    cfg = TriAttentionConfig(budget=512)
    meta = {}

    update_triattention_metadata(layer, cfg, meta)

    assert meta["triattention_config"]["budget"] == 512


def test_convert_metadata_with_budget():
    """The serialized config in the metadata includes the budget."""
    layer = nn.Linear(16, 16)

    _, meta = convert_triattention(layer, TriAttentionConfig(budget=1024))

    assert meta["triattention_config"]["budget"] == 1024
def test_sparsify_returns_model():
    """sparsify() hands back the same module object it was given."""
    layer = nn.Linear(16, 16)
    assert sparsify(layer, TriAttentionConfig(budget=2048)) is layer


def test_sparsify_accepts_dict_config():
    """A plain dict is accepted in place of a config object."""
    layer = nn.Linear(16, 16)
    assert sparsify(layer, {"budget": 1024}) is layer


def test_sparsify_preserves_weights():
    """Weights are unchanged after sparsify()."""
    layer = nn.Linear(16, 16)
    weight_before = layer.weight.data.clone()

    sparsify(layer, TriAttentionConfig(budget=2048))

    torch.testing.assert_close(layer.weight.data, weight_before)


def test_calibrate_returns_model():
    """calibrate() hands back the same module object it was given."""
    layer = nn.Linear(16, 16)
    sparsify(layer, TriAttentionConfig(budget=2048))
    assert calibrate(layer) is layer


def test_sparsify_then_calibrate():
    """Chaining sparsify() and calibrate() yields a module."""
    layer = calibrate(sparsify(nn.Linear(16, 16), TriAttentionConfig(budget=512)))
    assert isinstance(layer, nn.Module)
def test_rotate_half_roundtrip():
    """Two applications of rotate_half negate the input."""
    sample = torch.randn(2, 4, 8)
    twice = rotate_half(rotate_half(sample))
    torch.testing.assert_close(twice, -sample)


def test_invert_rope_recovers_original():
    """invert_rope undoes a half-style rotation built from cos/sin tables."""
    width, length = 16, 8
    original = torch.randn(1, length, width)

    inv_freq = torch.arange(width // 2, dtype=torch.float32) / width
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(-1)
    theta = pos * inv_freq
    # Duplicate the table so both halves of head_dim share the same angles.
    cos_table = torch.cos(theta).repeat(1, 2).unsqueeze(0)
    sin_table = torch.sin(theta).repeat(1, 2).unsqueeze(0)
    factor = 1.0

    rotated = original * cos_table + rotate_half(original) * sin_table
    restored = invert_rope(rotated * factor, cos_table, sin_table, factor)
    torch.testing.assert_close(restored, original, atol=1e-5, rtol=1e-5)


def test_invert_rope_zero_scale_raises():
    """A scale of zero is rejected."""
    rotated = torch.randn(1, 4, 8)
    with pytest.raises(ValueError, match="non-zero"):
        invert_rope(rotated, torch.ones(1, 4, 8), torch.ones(1, 4, 8), 0.0)


def test_to_complex_pairs_shape():
    """The complex view has half as many entries in the last dimension."""
    paired = to_complex_pairs(torch.randn(4, 16))
    assert paired.shape == (4, 8)
    assert paired.dtype == torch.complex64


def test_to_complex_pairs_values():
    """Default style pairs entry i with entry i + head_dim / 2."""
    row = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
    paired = to_complex_pairs(row)
    first, last = paired[0, 0], paired[0, 3]
    assert first.real == 1.0
    assert first.imag == 5.0
    assert last.real == 4.0
    assert last.imag == 8.0


def test_to_complex_pairs_odd_dim_raises():
    """An odd last dimension is rejected."""
    with pytest.raises(ValueError, match="even"):
        to_complex_pairs(torch.randn(4, 7))


def test_build_geometric_offsets():
    """A max_length of 16 gives 1, 2, 4, 8, 16."""
    cpu = torch.device("cpu")
    torch.testing.assert_close(
        build_geometric_offsets(16, cpu),
        torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0]),
    )


def test_build_geometric_offsets_single():
    """A max_length of 1 gives the single offset 1."""
    result = build_geometric_offsets(1, torch.device("cpu"))
    torch.testing.assert_close(result, torch.tensor([1.0]))


def test_build_geometric_offsets_zero_raises():
    """A max_length of 0 is rejected."""
    with pytest.raises(ValueError, match="must be >= 1"):
        build_geometric_offsets(0, torch.device("cpu"))
def test_compute_frequency_statistics_shapes():
    """All three outputs are shaped (num_keys, freq_count)."""
    n_freq = 8
    n_keys = 16
    width = n_freq * 2

    q_mean = torch.randn(n_freq, dtype=torch.complex64)
    q_abs = torch.rand(n_freq)
    keys = torch.randn(n_keys, width)

    outputs = compute_frequency_statistics_from_means(q_mean, q_abs, keys)
    amp, phi, extra = outputs
    assert amp.shape == (n_keys, n_freq)
    assert phi.shape == (n_keys, n_freq)
    assert extra.shape == (n_keys, n_freq)


def test_compute_frequency_statistics_amplitude_positive():
    """The amplitude output contains no negative entries."""
    n_freq = 4
    q_mean = torch.randn(n_freq, dtype=torch.complex64)
    q_abs = torch.rand(n_freq).abs() + 0.1
    keys = torch.randn(8, n_freq * 2)

    amp = compute_frequency_statistics_from_means(q_mean, q_abs, keys)[0]
    assert (amp >= 0).all()


def test_compute_frequency_statistics_disable_mlr():
    """Toggling disable_mlr changes the extra term."""
    n_freq = 4
    q_mean = torch.randn(n_freq, dtype=torch.complex64)
    q_abs = torch.rand(n_freq) + 1.0
    keys = torch.randn(8, n_freq * 2)

    extra_with_mlr = compute_frequency_statistics_from_means(
        q_mean, q_abs, keys, disable_mlr=False
    )[2]
    extra_without_mlr = compute_frequency_statistics_from_means(
        q_mean, q_abs, keys, disable_mlr=True
    )[2]
    assert not torch.allclose(extra_with_mlr, extra_without_mlr)


def test_score_keys_for_round_shape():
    """One score is produced per key."""
    n_keys = 32
    n_freq = 8
    inputs = {
        "amp": torch.rand(n_keys, n_freq),
        "phi": torch.rand(n_keys, n_freq),
        "omega": torch.rand(n_freq, dtype=torch.float64),
        "extra": torch.rand(n_keys, n_freq),
        "offsets": build_geometric_offsets(16, torch.device("cpu")),
        "freq_scale_sq": torch.ones(n_freq),
    }

    result = score_keys_for_round(
        torch.arange(n_keys),
        round_start=64,
        aggregation="mean",
        **inputs,
    )
    assert result.shape == (n_keys,)


def test_score_keys_empty():
    """No keys in, no scores out."""
    n_freq = 4
    result = score_keys_for_round(
        key_indices=torch.tensor([], dtype=torch.long),
        round_start=100,
        amp=torch.empty(0, n_freq),
        phi=torch.empty(0, n_freq),
        omega=torch.rand(n_freq, dtype=torch.float64),
        extra=torch.empty(0, n_freq),
        offsets=build_geometric_offsets(16, torch.device("cpu")),
        aggregation="mean",
        freq_scale_sq=torch.ones(n_freq),
    )
    assert result.numel() == 0


def test_score_aggregation_mean_vs_max():
    """'mean' and 'max' aggregation give different scores on random input."""
    n_keys = 10
    n_freq = 4
    positions = torch.arange(n_keys)
    amp = torch.rand(n_keys, n_freq)
    phi = torch.rand(n_keys, n_freq)
    omega = torch.rand(n_freq, dtype=torch.float64)
    extra = torch.rand(n_keys, n_freq)
    offsets = build_geometric_offsets(16, torch.device("cpu"))
    weights = torch.ones(n_freq)

    def _score(aggregation):
        return score_keys_for_round(
            positions, 50, amp, phi, omega, extra, offsets, aggregation, weights
        )

    assert not torch.allclose(_score("mean"), _score("max"))


def test_score_keys_disable_trig():
    """With disable_trig=True two keys at different positions score the same."""
    n_freq = 4
    # Positions far apart, with a large amplitude that would dominate if the
    # trigonometric term were active.
    positions = torch.tensor([0, 99])
    amp = torch.ones(2, n_freq) * 10.0
    phi = torch.zeros(2, n_freq)
    omega = torch.tensor([0.1, 0.5, 1.0, 2.0], dtype=torch.float64)
    extra = torch.ones(2, n_freq)
    offsets = build_geometric_offsets(16, torch.device("cpu"))
    weights = torch.ones(n_freq)

    result = score_keys_for_round(
        positions,
        100,
        amp,
        phi,
        omega,
        extra,
        offsets,
        "mean",
        weights,
        disable_trig=True,
    )
    # Only the additive term remains, and it does not depend on position.
    torch.testing.assert_close(result[0], result[1])


def test_head_frequency_stats_dataclass():
    """HeadFrequencyStats stores both tensors unchanged in shape."""
    record = HeadFrequencyStats(
        q_mean_complex=torch.randn(8, dtype=torch.complex64),
        q_abs_mean=torch.rand(8),
    )
    assert record.q_mean_complex.shape == (8,)
    assert record.q_abs_mean.shape == (8,)


def test_select_keys_top_k_basic():
    """A budget of 2 keeps the two highest scores (indices 1 and 3)."""
    keep = select_keys_to_keep(torch.tensor([0.1, 0.9, 0.3, 0.8, 0.2]), kv_budget=2)
    assert keep.dtype == torch.bool
    assert keep.sum().item() == 2
    assert keep[1].item() is True
    assert keep[3].item() is True


def test_select_keys_top_k_exceeds_size():
    """A budget above the number of keys keeps everything."""
    values = torch.tensor([0.1, 0.9, 0.3])
    keep = select_keys_to_keep(values, kv_budget=10)
    assert keep.all()
    assert keep.shape == values.shape


def test_select_keys_missing_budget_raises():
    """Calling without kv_budget raises ValueError."""
    values = torch.rand(10)
    with pytest.raises(ValueError, match="requires kv_budget"):
        select_keys_to_keep(values)
test_select_keys_non_positive_budget_raises(): + """Budget must be positive.""" + scores = torch.rand(10) + with pytest.raises(ValueError, match="must be positive"): + select_keys_to_keep(scores, kv_budget=0) + + +def test_select_keys_empty_scores(): + """Empty score tensor returns empty mask.""" + scores = torch.tensor([], dtype=torch.float32) + mask = select_keys_to_keep(scores, kv_budget=5) + assert mask.numel() == 0 + assert mask.dtype == torch.bool