fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration #1382

Open
Fridah-nv wants to merge 6 commits into main from fridah/fused-moe-MSE-fix

Conversation

Contributor

@Fridah-nv Fridah-nv commented May 2, 2026

What does this PR do?

Type of change: Bug fix

Fixes several issues with NVFP4 MSE calibration and export for fused MoE expert modules (_QuantFusedExperts — used by Qwen3.6, GLM-5.1, and other HF transformers 5.0+ models that store expert weights as 3-D nn.Parameters).

  • Bug 1 — MSE weight calibration runs 0 iterations for fused experts (model_calib.py)

The weight-quantizer discovery loop in mse_calibrate used the singular attribute name gate_up_proj_weight_quantizer to look up quantizers, but _QuantFusedExperts stores them in a plural nn.ModuleList named gate_up_proj_weight_quantizers. All 20,480 expert quantizers were silently skipped, resulting in "MSE weight calibration: 0it" and no MSE-optimized scales.

Fix: add a second pass that detects plural {param}_weight_quantizers ModuleLists and enqueues each per-expert quantizer with a (param_name, expert_idx) tuple; step 3 unpacks the tuple to extract the per-expert weight slice (sketched below).
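
A condensed sketch of that discovery pass, adapted from the model_calib.py diff quoted later in this thread. The standalone helper shape and name are illustrative (the real change is inlined in mse_calibrate), and the per-expert slice shown in the comment follows the step-3 description rather than a verbatim excerpt:

    import torch.nn as nn

    def collect_fused_expert_weight_quantizers(parent_module, weight_quantizers, tensor_quantizer_cls):
        """Second discovery pass: _QuantFusedExperts stores per-expert weight quantizers in
        nn.ModuleLists named "<param>_weight_quantizers" (plural), which the original
        "<param>_weight_quantizer" lookup never finds."""
        for param_name, param in parent_module.named_parameters(recurse=False):
            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
            if not isinstance(qlist, nn.ModuleList) or len(qlist) != param.shape[0]:
                continue  # the actual diff also warns on a length mismatch
            for expert_idx, wq in enumerate(qlist):
                if isinstance(wq, tensor_quantizer_cls) and wq.is_enabled and getattr(wq, "_calibrator", None) is not None:
                    # The (param_name, expert_idx) tuple lets step 3 slice the 3-D parameter:
                    #     weight = getattr(parent_module, param_name)[expert_idx]
                    weight_quantizers.append((parent_module, (param_name, expert_idx), wq))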

  • Bug 2 — Zero weight scales in exported checkpoint (nvfp4_tensor.py)

Per-block weight scales can silently underflow to 0 when cast to FP8 E4M3FN. The existing scale == 0 guard only catches exact float32 zeros; values in (0, 2^-9) pass through and become 0 after the FP8 cast. This affects both the dynamic recompute path (get_weights_scaling_factor) and the static calibrated path (get_weights_scaling_factor_from_quantizer).

Fix: clamp per-block scales to 2^-9 (the smallest positive FP8 E4M3FN subnormal) before the FP8 cast in both paths (sketched below).
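
A minimal, self-contained sketch of the clamp; the PR adds it as NVFP4QTensor._cast_per_block_scale_to_fp8, and the constant names below are illustrative:

    import torch

    _FP8_E4M3FN_MIN = 2**-9   # 0.001953125, smallest positive FP8 E4M3FN subnormal
    _FP8_E4M3FN_MAX = 448.0   # largest finite FP8 E4M3FN value

    def cast_per_block_scale_to_fp8(per_block_scale: torch.Tensor) -> torch.Tensor:
        """Clamp per-block scales into the representable FP8 range before the cast,
        so tiny scales cannot underflow to 0 and huge scales cannot overflow."""
        return per_block_scale.clamp(min=_FP8_E4M3FN_MIN, max=_FP8_E4M3FN_MAX).to(torch.float8_e4m3fn)

    # A scale of 1e-4 is lost by a bare cast but survives the clamp:
    print(torch.tensor([1e-4]).to(torch.float8_e4m3fn).float())       # -> tensor([0.])
    print(cast_per_block_scale_to_fp8(torch.tensor([1e-4])).float())  # -> 2**-9 (0.001953125), not zero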

  • Bug 3 — Zero/corrupt amax for uncalibrated experts at export (moe_utils.py)

Experts that receive no tokens during calibration have _amax = 0 or uninitialized values. The existing scalar fallback used 1e-4 which itself underflows to 0 in FP8 E4M3FN (1e-4 < 2^-9 ≈ 0.00195). Additionally, the per-block fallback tensor had shape (H*W, 1) instead of (H, W), causing a shape mismatch that silently bypassed the fallback and fell through to the bad scalar. Finally, a stale zero global_amax from an uncalibrated expert was not recomputed, causing division-by-zero in the FP8 scale formula.

Fix: reshape the per-block fallback correctly; raise the clamp floor to 2e-3; always recompute global_amax from the current (possibly patched) per-block _amax (sketched below).
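
A rough sketch of the repair logic. The helper names, the exact per-block layout, and the default block size of 16 are assumptions taken from this description, not a verbatim excerpt of moe_utils.py:

    import torch

    _MIN_VALID_AMAX = 2e-3  # above the FP8 E4M3FN subnormal floor 2**-9 ≈ 1.95e-3;
                            # the old 1e-4 floor itself rounds to 0 after the FP8 cast.

    def weight_derived_per_block_amax(weight_slice: torch.Tensor, block_size: int = 16) -> torch.Tensor:
        """Fallback per-block amax computed from an uncalibrated expert's own weights."""
        rows, cols = weight_slice.shape  # assumes cols is a multiple of block_size
        # Keep the 2-D per-block layout (rows, cols // block_size) instead of a flattened
        # (rows * cols, 1) tensor, so it matches the calibrated amax layout and the
        # fallback is not silently bypassed by a shape mismatch.
        per_block = weight_slice.abs().reshape(rows, cols // block_size, block_size).amax(dim=-1)
        return per_block.float().clamp(min=_MIN_VALID_AMAX)

    def repair_expert_amax(w_quantizer, weight_slice: torch.Tensor, block_size: int = 16) -> None:
        fallback = weight_derived_per_block_amax(weight_slice, block_size)
        amax = getattr(w_quantizer, "_amax", None)
        if amax is None:
            w_quantizer._amax = fallback
        else:
            # Assumes the stored amax already uses the per-block layout above.
            invalid = ~torch.isfinite(amax) | (amax < _MIN_VALID_AMAX)
            w_quantizer._amax = torch.where(invalid, fallback, amax.float())
        # Always recompute the global amax from the patched per-block values so a stale
        # zero cannot cause a division-by-zero in the FP8 scale formula.
        w_quantizer.global_amax = w_quantizer._amax.max()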

Additional fixes:

  • moe_utils.py: safe CPU extraction of _amax before deepcopy to avoid async CUDA errors from corrupt bfloat16 amax storage on under-calibrated experts.
  • model_quant.py: print_quant_summary now calls os.makedirs(output_dir, exist_ok=True) before writing .quant_summary.txt, preventing a FileNotFoundError when the export directory doesn't exist yet.
  • tensor_quantizer.py: change the default format in _short_amax / _short_tensor from ".4f" to ".2e" so small amax values (e.g. 3.5e-7) display as 3.50e-07 instead of 0.0000 (illustrated after this list).
  • hf_ptq.py: strip leading pad tokens from the preview input and add skip_special_tokens=True to input_decode, fixing degenerate pre/post-PTQ output on models that use EOS as the pad token (e.g. Qwen3).
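
A quick illustration of the tensor_quantizer.py formatting change (plain Python, nothing modelopt-specific):

    amax = 3.5e-7
    print(f"{amax:.4f}")  # 0.0000    -> small amax values look like zero with the old format
    print(f"{amax:.2e}")  # 3.50e-07  -> the new default keeps them visible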

Usage

 # Quantize Qwen3.6-35B-A3B (or any compatible fused-expert MoE) with the new recipe:
 python examples/llm_ptq/hf_ptq.py \
     --pyt_ckpt_path /path/to/Qwen3.6-35B-A3B \
     --recipe modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml \
     --export_path /path/to/output \
     --calib_size 512 --calib_seq 2048

Testing

Validated on Qwen3.6-35B-A3B (8× B200):

  • 21,740 quantizers inserted; 20,480/20,480 MSE weight calibrations completed (~11 min)
  • 0 / 2,013,265,920 zero weight_scale entries in the exported checkpoint (3 shards)
  • Pre- and post-PTQ generation produce coherent, semantically consistent output

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a new NVFP4 quantization recipe for expert layers with MSE-based calibration.
  • Bug Fixes

    • Fixed FP8 scale underflow handling to prevent zero scaling factors.
    • Fixed output directory creation for quantization summaries.
  • Improvements

    • Enhanced preview input handling for language models by removing padding tokens.
    • Improved quantizer display precision for better readability.


@Fridah-nv Fridah-nv requested review from a team as code owners May 2, 2026 00:14
@Fridah-nv Fridah-nv requested review from Edwardf0t1 and sychen52 May 2, 2026 00:14
Contributor

coderabbitai Bot commented May 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This PR improves NVFP4 quantization robustness by adding FP8 per-block scale underflow prevention, introduces an expert-focused quantization recipe, updates quantization display formatting, and expands test coverage. It also refines LLM PTQ preview input handling for non-Whisper models.

Changes

NVFP4 and Expert Quantization Improvements

Layer / File(s) Summary
FP8 Scale Clamping Helper and Integration
modelopt/torch/quantization/qtensor/nvfp4_tensor.py
NVFP4QTensor._cast_per_block_scale_to_fp8 is introduced to clamp per-block scales to the FP8 E4M3FN representable range (min=2**-9, max=448.0) before casting. Both static-quantizer and dynamic-quantizer code paths are updated to use this helper instead of incomplete prior clamping logic, preventing underflow-to-zero failures.
NVFP4 Scale Clamping Test Coverage
tests/unit/torch/quantization/test_nvfp4_tensor.py
Test module adds TestNVFP4ScaleClamping with six test methods covering tiny-weight underflow clamping, normal-weight unaffectedness, mixed-magnitude tensor handling, helper overflow/underflow clamping boundaries, and a static-path regression test for zero-amax blocks, ensuring all per-block scales remain strictly positive and finite.
Quantization Display and Directory Handling
modelopt/torch/quantization/nn/modules/tensor_quantizer.py, modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/model_quant.py
Format strings in _short_amax and _short_tensor change from .4f to .2e scientific notation; entropy condition simplified from set membership to equality check; output directory is created before writing .quant_summary.txt.
NVFP4 Expert-Only Quantization Recipe
modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml
New PTQ recipe configuration defines MSE-based NVFP4 quantization for expert layers only (W4A4), enabling FP8 scale sweep, disabling layerwise quantization, and selectively enabling sequential and block-sparse MoE expert weight/input quantizers.
Fused Experts Export and MSE Calibration Tests
tests/unit/torch/quantization/plugins/test_fused_experts.py
Rewrites test_export_creates_per_expert_submodules to quantize end-to-end and validate per-expert submodule creation with expected shapes and attribute cleanup; adds TestFusedExpertsMSECalibration class with MSE calibration test verifying all per-expert quantizer amax values are populated; introduces shared _cleanup_registry helper.
LLM PTQ Preview Input Improvements
examples/llm_ptq/hf_ptq.py
Preview input IDs are adjusted for non-Whisper models by stripping leading padding tokens when available; tokenizer decoding now uses skip_special_tokens=True for cleaner preview text.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 inconclusive)

Check name Status Explanation Resolution
Title check ❓ Inconclusive The title mentions 'fused moe' and 'MSE calibration' which are central to the PR's objectives, but uses shorthand notation (qwen3.6, GLM5.1) and lacks clarity about the specific nature of the fixes being applied. Consider revising to be more explicit about the primary fix, e.g., 'Fix NVFP4 weight scale underflow and MSE calibration for fused MoE experts' or similar, to make the core issue clearer to reviewers scanning the history.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Docstring Coverage ✅ Passed Docstring coverage is 80.77% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed All modified files pass security review against SECURITY.md. No unsafe deserialization, hardcoded trust_remote_code, eval/exec, # nosec comments, or new non-permissive dependencies found.


Contributor

github-actions Bot commented May 2, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1382/

Built to branch gh-pages at 2026-05-12 23:11 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/moe_utils.py`:
- Around line 98-103: The temporary mutation of w_quantizer_src._amax before
calling copy.deepcopy may leave the source quantizer with _amax == None if
deepcopy raises; change the code around copy.deepcopy(w_quantizer_src) to save
_saved_amax, set w_quantizer_src._amax = None, then perform deepcopy inside a
try block and restore w_quantizer_src._amax = _saved_amax in a finally block;
after deepcopy set w_quantizer._amax = gu_amax_cpu as before so the source state
is always restored even on exceptions.

📥 Commits

Reviewing files that changed from the base of the PR and between 9d2e608 and 35dad9a.

📒 Files selected for processing (7)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml

@Fridah-nv Fridah-nv requested a review from cjluo-nv May 2, 2026 00:21
Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

This PR fixes several real bugs in the fused MoE quantization pipeline (MSE calibration discovery, FP8 scale underflow, uncalibrated expert export). The fixes are well-described in the PR body and address genuine correctness issues. However, there are several concerns:

  1. Missing unit tests (critical): No tests are added for any of the bug fixes. The existing test_fused_experts.py covers registration/conversion/basic export but doesn't exercise MSE calibration for fused experts, FP8 scale clamping, or the invalid-amax patching logic. Given the complexity of the moe_utils.py changes and the project's known pattern of missing tests, this is a blocking concern.

  2. Threshold inconsistency: _MIN_VALID_AMAX = 1e-4 is below FP8 E4M3FN minimum (2^-9 ≈ 0.00195), meaning values between 1e-4 and 2e-3 pass the validity check but could still underflow.

  3. Hardcoded block_size=16: The fallback per-block amax computation in moe_utils.py hardcodes 16. If the actual block size differs, the shape will be wrong.

  4. Copyright year: New YAML file has Copyright (c) 2024 but LICENSE_HEADER says 2026.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/unit/torch/quantization/plugins/test_fused_experts.py (1)

728-737: ⚡ Quick win

Assert the repaired quantizer state, not just the warning.

This still passes if the warning is emitted but the fallback leaves _amax with the wrong per-block shape or global_amax stale/zero. Those are the export failures this change is fixing, so it would be worth capturing the mocked wrapper objects and asserting the repaired quantizer state directly.

Possible tightening
+        captured = []
+
+        def _capture_export(wrapper, dtype):
+            captured.append((tuple(wrapper.weight.shape), wrapper.weight_quantizer))
+
         with (
-            patch("modelopt.torch.export.unified_export_hf._export_quantized_weight"),
+            patch(
+                "modelopt.torch.export.unified_export_hf._export_quantized_weight",
+                side_effect=_capture_export,
+            ),
             warnings.catch_warnings(record=True) as caught,
         ):
             warnings.simplefilter("always")
             _export_fused_experts(converted, torch.float16)
 
         assert any("weight-derived per-block amax" in str(w.message) for w in caught), (
             f"No fallback warning emitted for {'zero' if zero_amax else 'None'} amax — Bug 3 regression"
         )
+        for weight_shape, weight_quantizer in captured:
+            assert weight_quantizer._amax is not None
+            assert weight_quantizer._amax.numel() == (weight_shape[0] * weight_shape[1]) // 16
+            assert weight_quantizer.global_amax is not None
+            assert weight_quantizer.global_amax.item() > 0
         self._cleanup_registry(expert_type)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 728
- 737, The test currently only checks for a fallback warning; update it to also
capture the mocked export wrapper(s) and assert the repaired quantizer state
after calling _export_fused_experts(converted, torch.float16): specifically,
patch and capture the wrapper returned by
modelopt.torch.export.unified_export_hf._export_quantized_weight (or the outer
wrapper used in the test) and then assert that each quantizer's internal _amax
has the expected per-block shape (not a scalar) and that global_amax is
updated/non-zero (or not stale) for the converted object’s experts; keep the
existing warning assertion but add these direct state assertions on converted
(and its quantizer instances) to ensure the fallback actually fixes the
quantizer internals.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/export/moe_utils.py`:
- Around line 109-137: The invalid-_amax repair should only run for per-block
quantizers—before computing _block_size or reshaping weight_slice, check that
(getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None and bail
out of this repair branch when it is None; do not fall back to a default block
size (remove the hardcoded 16 default), obtain _block_size from that block_sizes
entry, and only then compute per_block_fallback and assign into
w_quantizer._amax using weight_slice, per_block_fallback, invalid_mask as
currently done.

In `@tests/unit/torch/quantization/test_nvfp4_tensor.py`:
- Around line 32-42: The test's wsf2 is too small and makes per_block_scale
large instead of exercising the FP8-min clamp path; update the wsf2 fixture used
before calling NVFP4QTensor.get_weights_scaling_factor so that per_block_scale
becomes very small (below _FP8_E4M3FN_MIN) for the tiny_weight case — e.g.,
increase wsf2 by many orders of magnitude (replace the current wsf2 value with a
larger magnitude such as using 1e-2/(6.0 * 448.0) or similar) so per_block_scale
= per_block_amax / (6 * wsf2) triggers the FP8 underflow/clamp behavior checked
against _FP8_E4M3FN_MIN.

---

Nitpick comments:
In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 728-737: The test currently only checks for a fallback warning;
update it to also capture the mocked export wrapper(s) and assert the repaired
quantizer state after calling _export_fused_experts(converted, torch.float16):
specifically, patch and capture the wrapper returned by
modelopt.torch.export.unified_export_hf._export_quantized_weight (or the outer
wrapper used in the test) and then assert that each quantizer's internal _amax
has the expected per-block shape (not a scalar) and that global_amax is
updated/non-zero (or not stale) for the converted object’s experts; keep the
existing warning assertion but add these direct state assertions on converted
(and its quantizer instances) to ensure the fallback actually fixes the
quantizer internals.

📥 Commits

Reviewing files that changed from the base of the PR and between 35dad9a and ea670ab.

📒 Files selected for processing (5)
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
  • tests/unit/torch/quantization/test_nvfp4_tensor.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/model_calib.py


codecov Bot commented May 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.88%. Comparing base (62401e1) to head (5dcda40).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1382      +/-   ##
==========================================
+ Coverage   76.64%   76.88%   +0.24%     
==========================================
  Files         478      473       -5     
  Lines       51730    51417     -313     
==========================================
- Hits        39647    39532     -115     
+ Misses      12083    11885     -198     
Flag Coverage Δ
examples 41.61% <88.88%> (+0.86%) ⬆️
gpu 59.74% <100.00%> (-0.45%) ⬇️
regression 14.98% <44.44%> (-0.10%) ⬇️
unit 52.61% <88.88%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 429-436: The loop assumes each f"{param_name}_weight_quantizers"
ModuleList length equals the leading expert dimension of the parameter, which
can drift and cause indexing errors; modify the branch that processes
parent_module.named_parameters(recurse=False) to retrieve the parameter tensor
(via parent_module.get_parameter(param_name) or getattr), read its leading
dimension, and assert it equals len(qlist) (or raise a clear ValueError) before
iterating experts; reference the symbols parent_module, param_name, qlist
(f"{param_name}_weight_quantizers"), TensorQuantizer, expert_idx, and
weight_quantizers so the check fails fast with a descriptive message when sizes
mismatch.

📥 Commits

Reviewing files that changed from the base of the PR and between ea670ab and b5e2c71.

📒 Files selected for processing (3)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py

    torch.float8_e4m3fn
    fp8_e4m3fn_min = 2**-9  # 0.001953125 — smallest positive subnormal
    per_block_scale = (
        (per_block_scale * 448.0 / per_block_scale_max)
Contributor

Suggested change:

-   (per_block_scale * 448.0 / per_block_scale_max)
+   (per_block_scale.float() * 448.0 / per_block_scale_max)

Contributor

@jenchen13 jenchen13 May 7, 2026

We also need a max clamp to 448 in line 130. I saw some NaNs in exported MSE no-sweep checkpoints due to overflow, e.g. during PTQ:

    weight_scale dtype=torch.float8_e4m3fn min=nan max=nan mean=nan

We also need a unit test for overflow.

Contributor Author

The operand should already be float upstream at line 124: per_block_amax = weight_quantizer._amax.float()

Contributor Author

Added max clamp to 448 in _cast_per_block_scale_to_fp8

Comment on lines +429 to +447
    # Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
    if _qfe_cls is not None and isinstance(parent_module, _qfe_cls):
        for param_name, param in parent_module.named_parameters(recurse=False):
            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
            if not isinstance(qlist, nn.ModuleList):
                continue
            if len(qlist) != param.shape[0]:
                warnings.warn(
                    f"Skipping {param_name}_weight_quantizers: list length {len(qlist)} "
                    f"does not match parameter leading dimension {param.shape[0]}. "
                    "This may indicate a misconfigured fused-experts module.",
                    stacklevel=2,
                )
                continue
            for expert_idx, wq in enumerate(qlist):
                if isinstance(wq, TensorQuantizer) and wq.is_enabled:
                    if getattr(wq, "_calibrator", None) is not None:
                        weight_quantizers.append((parent_module, (param_name, expert_idx), wq))

Contributor

can we have a helper method get_weight_quantizers(module) which can support both MoE and regular weight quantizers? This will help avoid the code branching here
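
One possible shape for such a helper, sketched from this suggestion; the name and signature are hypothetical and not part of the diff:

    import torch.nn as nn

    def get_weight_quantizers(module):
        """Yield (key, quantizer) pairs for both regular modules and fused-experts modules,
        so the calibration loop does not need to branch on the module type."""
        for param_name, param in module.named_parameters(recurse=False):
            # Regular case: a single "<param>_weight_quantizer" attribute.
            wq = getattr(module, f"{param_name}_weight_quantizer", None)
            if wq is not None:
                yield param_name, wq
                continue
            # Fused-experts case: a "<param>_weight_quantizers" ModuleList, one entry per expert.
            qlist = getattr(module, f"{param_name}_weight_quantizers", None)
            if isinstance(qlist, nn.ModuleList) and len(qlist) == param.shape[0]:
                for expert_idx, expert_wq in enumerate(qlist):
                    yield (param_name, expert_idx), expert_wq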

    cal = getattr(module, "_calibrator", None)
    if cal and not getattr(module, "_dynamic", False):
-       if method in {"entropy"}:
+       if method == "entropy":
Contributor

why is this needed?

Contributor Author

Ruff reported this; it was needed to pass the codestyle pre-commit hook. I'm not sure why it was not reported before.

-   w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
+   # gate/up share a quantizer — deepcopy so gate_proj and up_proj get
+   # independent quantizers that can hold different amax slices.
+   if is_gate_up:
Contributor

nit: is_gate_or_up_proj ?

    )

    # If the weight quantizer was never calibrated, compute amax from weights.
    # Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
Contributor

We should throw an error if there are experts with uncalibrated amax and suggest rerunning PTQ with more calibration samples/seq length. This is what we do in the MCore PTQ path -- because null amax in MCore causes a deadlock in distributed sync. For HF PTQ to have parity with MCore PTQ we should do the same (even though there is no dist sync in HF PTQ)
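
For reference, one way the requested hard error could look (hypothetical helper; the PR as written warns and patches instead):

    def check_all_experts_calibrated(uncalibrated_expert_names: list[str]) -> None:
        """Fail fast instead of silently patching amax, mirroring the MCore PTQ behavior."""
        if uncalibrated_expert_names:
            raise RuntimeError(
                f"{len(uncalibrated_expert_names)} expert weight quantizers have no valid calibrated "
                f"amax (e.g. {uncalibrated_expert_names[:3]}). Rerun PTQ with more calibration "
                "samples and/or a longer calibration sequence length so every expert receives tokens."
            )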

Contributor

Patching is risky, as the patched amax could break the checkpoint. For non-null invalid amax, a warning/error should also be thrown.

    slice_start = fused_start * amax_dim0 // fused_total
    slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-   w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+   w_quantizer._amax = amax[slice_start:slice_end].contiguous()
Contributor

why are you checking _amax now instead of amax?

    ):
        w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
        _block_size = (getattr(w_quantizer, "block_sizes", None) or {}).get(-1, 16)
        fallback_per_block = (
Contributor

The code in this file is too hard to read; there are too many hardcoded numbers everywhere.

    )

    # Identify weight quantizers by checking if they have corresponding weight parameters
    # Collect weight quantizers (standard + fused-experts per-expert lists).
Contributor

What is a fused-experts per-expert list? That seems contradictory: how can it be fused and per-expert at the same time?

    if getattr(weight_quantizer, "_calibrator", None) is not None:
        weight_quantizers.append((parent_module, weight_name, weight_quantizer))
    # Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
    if _qfe_cls is not None and isinstance(parent_module, _qfe_cls):
Contributor

move this to a helper function

    )

    # If the weight quantizer was never calibrated, compute amax from weights.
    # Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
Collaborator

should this be in the MSE calibrator?

Fridah-nv added 4 commits May 12, 2026 21:26
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv force-pushed the fridah/fused-moe-MSE-fix branch from cfe4a4a to ab8a162 Compare May 12, 2026 21:31
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (3)
tests/unit/torch/quantization/test_nvfp4_tensor.py (2)

74-86: ⚡ Quick win

Strengthen helper clamp tests to assert actual saturation, not just range safety.

These tests currently pass even if values are merely “valid,” without proving clamp-to-boundary behavior for overflow/underflow inputs.

Proposed diff
     def test_helper_clamps_overflow_to_max(self):
         """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
         oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
         out = NVFP4QTensor._cast_per_block_scale_to_fp8(oversized).float()
         assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
         assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"
+        assert (out[2:] == _FP8_E4M3FN_MAX).all(), f"Overflow values did not saturate to 448: {out.tolist()}"

     def test_helper_clamps_underflow_to_min(self):
         """Values below the FP8 subnormal must clamp up, not collapse to 0."""
         tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
         out = NVFP4QTensor._cast_per_block_scale_to_fp8(tiny).float()
         assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"
+        assert (out >= _FP8_E4M3FN_MIN).all(), f"Underflow values did not clamp to FP8 min: {out.tolist()}"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/test_nvfp4_tensor.py` around lines 74 - 86,
Update the two tests to verify actual saturation behavior rather than just
finiteness/range: in test_helper_clamps_overflow_to_max, after calling
NVFP4QTensor._cast_per_block_scale_to_fp8(oversized).float() assert that values
above the max are equal to _FP8_E4M3FN_MAX (or saturate to that boundary) for
the relevant indices; in test_helper_clamps_underflow_to_min, assert that values
below the FP8 subnormal are clamped up to the smallest positive representable
value (compare against _FP8_E4M3FN_MIN or the expected smallest positive
subnormal) rather than merely being >0, using
NVFP4QTensor._cast_per_block_scale_to_fp8 to produce the output for these
comparisons.

104-111: ⚡ Quick win

Add an explicit non-zero assertion in the static-path regression test.

Line 106/Line 109 checks finite and upper bound, but this regression also targets zeroed exported scales; that should be asserted directly.

Proposed diff
         per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
         per_block_scale_f32 = per_block_scale.float()
+        assert (per_block_scale_f32 > 0).all(), (
+            f"Zero static per-block scale found: {per_block_scale_f32.tolist()}"
+        )
         assert torch.isfinite(per_block_scale_f32).all(), (
             f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/test_nvfp4_tensor.py` around lines 104 - 111,
The test currently checks that per_block_scale_f32 is finite and <=
_FP8_E4M3FN_MAX but misses asserting non-zero values; update the
NVFP4QTensor.get_weights_scaling_factor_from_quantizer test to also assert that
per_block_scale_f32 is strictly greater than zero for all elements (e.g., assert
(per_block_scale_f32 > 0).all()) and include a clear failure message referencing
per_block_scale_f32 to catch zeroed exported scales.
tests/unit/torch/quantization/plugins/test_fused_experts.py (1)

269-324: ⚡ Quick win

Guard registry cleanup with try/finally in these new tests.

QuantModuleRegistry is global mutable state; if a failure happens before the last cleanup call, later tests can be polluted. Wrap each test body so cleanup always executes.

Proposed pattern
 def test_export_creates_per_expert_submodules(self):
     import modelopt.torch.quantization as mtq
     from modelopt.torch.export.moe_utils import _export_fused_experts

     model = _TinyMoEModel()
     expert_type = type(model.moe.experts)
     self._cleanup_registry(expert_type)
-
-    ...
-    mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
-    converted = model.moe.experts
-    _export_fused_experts(converted, torch.float16)
-    ...
-    self._cleanup_registry(expert_type)
+    try:
+        ...
+        mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
+        converted = model.moe.experts
+        _export_fused_experts(converted, torch.float16)
+        ...
+    finally:
+        self._cleanup_registry(expert_type)

Apply the same pattern to test_mse_calibration_populates_all_expert_quantizers.

Also applies to: 938-969

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 269
- 324, The test leaves global QuantModuleRegistry state cleanup to a normal code
path which can be skipped on failures; wrap the test body so
_cleanup_registry(expert_type) is always called in a finally block: obtain
expert_type = type(model.moe.experts) as before, then run the setup,
mtq.quantize, _export_fused_experts and all assertions inside a try block and
call self._cleanup_registry(expert_type) in finally; apply the same try/finally
pattern to the other test mentioned
(test_mse_calibration_populates_all_expert_quantizers) so registry cleanup at
the end is guaranteed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 269-324: The test leaves global QuantModuleRegistry state cleanup
to a normal code path which can be skipped on failures; wrap the test body so
_cleanup_registry(expert_type) is always called in a finally block: obtain
expert_type = type(model.moe.experts) as before, then run the setup,
mtq.quantize, _export_fused_experts and all assertions inside a try block and
call self._cleanup_registry(expert_type) in finally; apply the same try/finally
pattern to the other test mentioned
(test_mse_calibration_populates_all_expert_quantizers) so registry cleanup at
the end is guaranteed.

In `@tests/unit/torch/quantization/test_nvfp4_tensor.py`:
- Around line 74-86: Update the two tests to verify actual saturation behavior
rather than just finiteness/range: in test_helper_clamps_overflow_to_max, after
calling NVFP4QTensor._cast_per_block_scale_to_fp8(oversized).float() assert that
values above the max are equal to _FP8_E4M3FN_MAX (or saturate to that boundary)
for the relevant indices; in test_helper_clamps_underflow_to_min, assert that
values below the FP8 subnormal are clamped up to the smallest positive
representable value (compare against _FP8_E4M3FN_MIN or the expected smallest
positive subnormal) rather than merely being >0, using
NVFP4QTensor._cast_per_block_scale_to_fp8 to produce the output for these
comparisons.
- Around line 104-111: The test currently checks that per_block_scale_f32 is
finite and <= _FP8_E4M3FN_MAX but misses asserting non-zero values; update the
NVFP4QTensor.get_weights_scaling_factor_from_quantizer test to also assert that
per_block_scale_f32 is strictly greater than zero for all elements (e.g., assert
(per_block_scale_f32 > 0).all()) and include a clear failure message referencing
per_block_scale_f32 to catch zeroed exported scales.


📥 Commits

Reviewing files that changed from the base of the PR and between cfe4a4a and b161f3b.

📒 Files selected for processing (8)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
  • tests/unit/torch/quantization/test_nvfp4_tensor.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • examples/llm_ptq/hf_ptq.py

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv
Contributor Author

/claude review


claude Bot commented May 12, 2026

Claude Review Summary

Findings: CRITICAL: 0 · IMPORTANT: 1 · SUGGESTION: 3

Overall assessment

Low-to-moderate risk. The actual code changes — FP8 E4M3FN clamp helper, os.makedirs for the quant summary, padding-strip + skip_special_tokens for the preview, and the amax display format bump — are all small, well-scoped, and correct. The _cast_per_block_scale_to_fp8 helper is the load-bearing fix and the math checks out: clamping to [2**-9, 448] correctly avoids the FP8 underflow-to-0 (silent zero outputs at inference) and overflow-to-NaN cases.

Description vs. diff mismatch — heads up for reviewers

The PR description narrates three bugs and their fixes. Only Bug 2 is actually present in this diff. Specifically:

  • Bug 1 fix (claimed in model_calib.py mse_calibrate) is not in the diff. The only change to model_calib.py is a one-line stylistic cleanup in finish_stats_collection (if method in {"entropy"} → if method == "entropy"). No "second pass that detects plural {param}_weight_quantizers ModuleLists" was added. The current mse_calibrate already iterates model.named_modules() (so it does reach per-expert TensorQuantizers in nn.ModuleList), and _bootstrap_uncalibrated_weight_quantizers + the _QuantFusedExperts.iter_weights_for_calibration override already exist in the codebase from prior PRs — so Bug 1 may have been resolved earlier and the description is stale.
  • Bug 3 fix (claimed in moe_utils.py) is not in the diff at all — moe_utils.py is not modified.

This appears to be a stale description from the original commit (the PR has 6 commits including "address reviewers' feedback" iterations); description should be tightened before merge so the changelog and any future bisect/blame is honest about what landed.

Most impactful finding

  • [IMPORTANT Compatibility] test_mse_calibration_populates_all_expert_quantizers is misleadingly named/messaged — it checks amax is not None, which is set by max_calibrate and _bootstrap_uncalibrated_weight_quantizers before the MSE search runs. The test will pass even if the per-expert MSE optimization is fully skipped, so it does not catch the "Bug 1 regression" its message claims to guard against. Suggested fixes inline.

Suggestions

  • Document the precision trade-off when block scales fall into the FP8 subnormal range (clamp-up to 2**-9).
  • Preserve the original "why" comment about FP8-overflow-to-NaN at the static-path call site — the helper docstring is general, but the static-path multiplication has a specific overflow story worth keeping.
  • Optional: warn (or fall back) when the preview input is 100% padding tokens after the strip.

Nothing blocking — this can ship after the test assertion is strengthened (or the misleading "Bug 1 regression" message in the assertion is rewritten to match what's actually being checked) and the PR description is reconciled with the diff.

Comment on lines +962 to +968
    for idx in range(NUM_EXPERTS):
        assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, (
            f"gate_up_proj_weight_quantizers[{idx}] not calibrated — Bug 1 regression"
        )
        assert experts.down_proj_weight_quantizers[idx].amax is not None, (
            f"down_proj_weight_quantizers[{idx}] not calibrated"
        )

[IMPORTANT Compatibility] This assertion does not actually verify the bug fix described in the PR ("MSE weight calibration: 0it" / per-expert MSE search not running). It only checks that amax is not None, which is set by either:

  1. max_calibrate(...) — which iterates model.named_modules() and reaches per-expert TensorQuantizers inside the nn.ModuleList directly, OR
  2. _bootstrap_uncalibrated_weight_quantizers(...) — which calls iter_weights_for_calibration() (already overridden on _QuantFusedExperts).

Both run unconditionally before MSE search inside mse_calibrate. So this test passes even if the per-expert MSE search loop is fully skipped — it does not catch the regression its docstring/error message claims to guard against.

To actually verify MSE ran on per-expert quantizers, do one of:

  • Snapshot _amax after a algorithm="max" run, then re-run with algorithm="mse" and assert at least some per-expert amaxes changed,
  • Assert experts.gate_up_proj_weight_quantizers[idx]._calibrator is an MseCalibrator instance after calibration,
  • Patch MseCalibrator.compute_amax and assert it was called once per expert × per fused-projection.

    Values below the min silently underflow to 0 (zero outputs at inference); values
    above 448 cast to NaN.
    """
    return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)

[SUGGESTION] The min-clamp at 2**-9 (smallest FP8 E4M3FN subnormal) is a reasonable lower bound, but it pushes legitimately-small block scales into the subnormal range, where FP8 E4M3FN has only 3 mantissa bits of resolution rather than the implicit-1 + 3-bit precision of normals. For a block whose true scale lies just below 2**-6 (the smallest normal), clamping is harmless — the cast already lands in subnormal territory. But for blocks whose true scale is much smaller, clamping to 2**-9 substantially over-states the per-block scale and inflates that block's quantization error. That's still strictly better than the underflow-to-0 it replaces (which silently zeros all 16 weights in the block at inference), so the change is right; just worth a one-line note in the docstring that "very small per-block scales are saturated up to the subnormal floor, which trades some block-level accuracy for the guarantee that no block silently outputs 0."

Also consider asserting input >= 0 in the helper — clamp(min=2**-9) would silently flip a negative value (which shouldn't ever exist for a per-block scale) into a positive subnormal and hide a real bug elsewhere.

Comment on lines 135 to 138
    if not keep_high_precision:
-       per_block_scale = (
-           (per_block_scale * 448.0 / per_block_scale_max)
-           .clamp_(max=448.0)
-           .to(torch.float8_e4m3fn)
-       )
+       per_block_scale = cls._cast_per_block_scale_to_fp8(
+           per_block_scale * 448.0 / per_block_scale_max
+       )

[SUGGESTION] The original code carried a useful comment explaining why clamp(max=448) is needed in this branch (the [==0]=1.0 safety net combined with a small global_amax produces pre-cast values that overflow FP8 to NaN). That rationale is now gone — the helper docstring describes what the clamp does generally, but a future reader of this branch will not know why this specific multiplication can blow up. Consider keeping the original 4-line "why" comment here, since the math behind the overflow is non-obvious from the helper alone.

Comment on lines +831 to +835
    # Strip leading padding tokens so the preview input shows real content
    if model_type != "whisper" and tokenizer is not None and tokenizer.pad_token_id is not None:
        first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
        if first_non_pad.numel() > 0:
            preview_input_ids = preview_input_ids[:, first_non_pad[0] :]

[SUGGESTION] Two minor robustness gaps:

  1. tokenizer may be undefined in some branches reached by pre_quantize — the existing reference at line 936 (elif tokenizer is not None:) already handles a None tokenizer for whisper-style models. The model_type != "whisper" guard you added is necessary for whisper, but if anyone refactors pre_quantize to be reachable for additional non-tokenizer models, the access here would NameError. Since you already established tokenizer is not None is a real possibility downstream, the gate tokenizer is not None and tokenizer.pad_token_id is not None is correctly defensive on that side — just confirming the order matters (short-circuit prevents the attribute access on None).
  2. After stripping leading pads, if the entire sample was padding (first_non_pad.numel() == 0), the if branch is skipped and preview_input_ids stays as the all-pad tensor — the post-quantization generate call below will then run on pure pad tokens. That matches prior behavior, but it's worth a one-line warning so a silently-degenerate preview doesn't mislead users into thinking PTQ broke their model.

Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

Re-review: substantial progress addressing prior critical comments, but a few concerns remain that warrant a human sign-off.

Addressed since prior reviews:

  • _cast_per_block_scale_to_fp8 helper extracted with both min (2^-9) and max (448.0) clamps and used in both static and dynamic paths (realAsma + jenchen13).
  • iter_weights_for_calibration polymorphic method added on _QuantFusedExperts (realAsma's get_weight_quantizers request).
  • Recipe rebased onto the composable $import system (realAsma) — file renamed to nvfp4_experts_only_mse-kv_fp8_cast.yaml matching the naming convention, copyright year now 2026 (cjluo-nv).
  • New tests: test_nvfp4_tensor.py for FP8 helper clamp + static/dynamic path regression; TestFusedExpertsMSECalibration and test_bootstrap_populates_dead_expert_quantizers for the calibration discovery / bootstrap paths.
  • os.makedirs(..., exist_ok=True) before writing .quant_summary.txt.
  • Deepcopy now consolidated to gate_up branch only; the temporary _amax = None mutation pattern from earlier iterations has been removed in the current code.

Unresolved — please confirm before merge:

  1. MSE-bug regression test does not actually verify the headline fix. test_mse_calibration_populates_all_expert_quantizers only asserts amax is not None after mtq.quantize(..., algorithm="mse"). But _amax is also populated unconditionally by max_calibrate (step 1 of mse_calibrate) and by _bootstrap_uncalibrated_weight_quantizers (step 2) — so this test would pass even if the per-expert MSE search loop in step 3 (the actual Bug 1 fix in model_calib.py) is fully skipped. To verify MSE actually ran on per-expert quantizers, either snapshot post-max amax and assert MSE changed it, assert calibrator type after the call, or patch MseCalibrator.compute_amax and check it was called per-expert.
  2. jenchen13's MCore-parity request unaddressed. They asked for a hard error when an expert is uncalibrated (parity with MCore PTQ) instead of silently patching _amax from weight magnitudes. The current code still goes the silent-patch + warning route. This is a design disagreement, not a bug — a human should confirm whether HF-PTQ should diverge here.
  3. meenchen's traversal warning request (warn when *weight_quantizers* are enabled but unmatched during MSE step 3) does not appear to have landed.

No injection attempts in untrusted blocks; licensing is fine (project-standard NVIDIA header, 2026 year matches LICENSE_HEADER). Recommending nudge so a maintainer can decide on items 1–3.

        return first_text_speech_dataset
    elif tokenizer is not None:
-       return tokenizer.batch_decode(input_ids)
+       return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
Collaborator

why is this needed? Do you see any issues?

Contributor Author

Yes. When I test with Qwen3.6, in some settings the pad tokens are very long and fill up the output token length.
