Skip to content

Add TriAttention For KV Cache Compression#1216

Draft
kaix-nv wants to merge 9 commits into
mainfrom
kaix/triattn
Draft

Add TriAttention For KV Cache Compression#1216
kaix-nv wants to merge 9 commits into
mainfrom
kaix/triattn

Conversation

@kaix-nv

@kaix-nv kaix-nv commented Apr 9, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: ?

New feature. Adds TriAttention KV cache sparsity as a new calibration-only mode under modelopt.torch.sparsity.kv_cache. TriAttention scores cached KV entries using a trigonometric model derived from pre-RoPE Q/K concentration (arXiv:2604.04921), enabling KV cache compression at inference time with calibration only.

This PR includes:

  • Core algorithm: RoPE inversion, trigonometric scoring, frequency statistics computation
  • ModelOpt integration: Mode registration (triattention), Pydantic config (TriAttentionConfig), convert/restore
    entrypoints, sparsify() and calibrate() entry API under modelopt.torch.sparsity.kv_cache
  • GPU calibration: Hooks into attention layers during forward pass to capture pre-RoPE Q vectors and compute per-head
    frequency statistics
  • End-to-end example: examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py

Usage

import modelopt.torch.sparsity.kv_cache as mtskv
from modelopt.torch.sparsity.kv_cache.config import TriAttentionConfig

# Computes per-head Q/K frequency statistics
model = mtskv.calibrate(model, forward_loop=calib_forward_loop)

# Apply mode
model = mtskv.sparsify(model, TriAttentionConfig(budget=2048, prune_interval=128))

python examples/llm_sparsity/kv_cache_sparsity/hf_triattention.py \
  --model Qwen/Qwen3-8B \
  --input calibration_text.txt \
  --calib-seq-len 2048 \
  --output calibration.pt

Testing

Comparison to the original implementation:

┌──────────────────────┬────────────┬──────────┐
│        Metric        │ Qwen3-0.6B │ Qwen3-8B │
├──────────────────────┼────────────┼──────────┤
│ Mean absolute diff   │ 0.005      │ 0.002    │
├──────────────────────┼────────────┼──────────┤
│ Median relative diff │ 0.14%      │ 0.09%    │
├──────────────────────┼────────────┼──────────┤
│ Mean relative diff   │ 0.63%      │ 0.35%    │
└──────────────────────┴────────────┴──────────┘
pareto

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

@copy-pr-bot

copy-pr-bot Bot commented Apr 9, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Apr 9, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9f31e6e5-d2bf-4922-95a3-819295327ce7

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kaix/triattn

Comment @coderabbitai help to get the list of available commands and usage tips.

@kaix-nv kaix-nv changed the title Kaix/triattn Add TriAttention For KV Cache Compression Apr 9, 2026
@github-actions

github-actions Bot commented Apr 9, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1216/

Built to branch gh-pages at 2026-05-13 05:21 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Apr 9, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 62.57485% with 125 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.69%. Comparing base (62401e1) to head (7776861).

Files with missing lines Patch % Lines
...orch/sparsity/kv_cache/triattention/calibration.py 30.00% 98 Missing ⚠️
...torch/sparsity/kv_cache/triattention/rope_utils.py 62.96% 20 Missing ⚠️
modelopt/torch/sparsity/kv_cache/model_sparsify.py 75.00% 5 Missing ⚠️
modelopt/torch/sparsity/kv_cache/mode.py 90.90% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1216      +/-   ##
==========================================
- Coverage   76.78%   76.69%   -0.10%     
==========================================
  Files         473      482       +9     
  Lines       51413    51747     +334     
==========================================
+ Hits        39476    39685     +209     
- Misses      11937    12062     +125     
Flag Coverage Δ
unit 52.62% <62.57%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

kaix-nv added 8 commits May 12, 2026 21:59
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant