Split bypass prerequisites #1468
Conversation
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID:
📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 Walkthrough

This PR extends the pruning framework with KV-heads support across multiple model descriptors, adds LM-config helpers and sequential multi-mixin application, introduces normalized MSE loss utilities, adds a training dataloader factory with tokenizer-aware chat preprocessing, updates stitched-loss formatting and warmup resolver behavior, and adds comprehensive unit tests.

Changes:
- Pruning and model descriptor enhancements
- Training infrastructure and loss utilities
- Sewing kit infrastructure
🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches: 📝 Generate docstrings · 🧪 Generate unit tests (beta)
Force-pushed 566cb1d to 0639883 (Compare)
/claude review
Claude review — summary
Findings: CRITICAL: 1 · IMPORTANT: 2 · SUGGESTION: 2

Most impactful
Risk level: Moderate. The bulk of the PR is cleanly scoped prerequisite plumbing (descriptor mixins, dataloader, chat-template fallback, warmup-step grad-accum handling, re-exports) with good test coverage for the pure-function helpers. The blocker is the one test that presupposes function-signature changes shipping in the follow-up PR — that needs to be resolved before merge. The mixin-composition and …
Force-pushed 0639883 to a79fbae (Compare)
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## main #1468 +/- ##
==========================================
+ Coverage 76.78% 76.85% +0.06%
==========================================
Files 473 478 +5
Lines 51413 51906 +493
==========================================
+ Hits 39476 39890 +414
- Misses 11937 12016 +79
Flags with carried forward coverage won't be shown.
☔ View full report in Codecov by Sentry.
@AAnoosheh and @kevalmorabia97 ready for review (split the bypass MR into 3, this is the first one, nothing too important, just some preparations and tiny fixes)
Actionable comments posted: 5
🧹 Nitpick comments (1)
tests/unit/torch/puzzletron/test_bypass_dataloaders.py (1)
206-219: ⚡ Quick win: add a direct test for the ConstantLengthDataset chat-template fallback.
This fixture replaces ConstantLengthDataset, so the new no-chat_template preprocessing path in ConstantLengthDataset.__iter__ is not exercised. A small targeted iterator test would close that regression gap.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/puzzletron/test_bypass_dataloaders.py` around lines 206 - 219, The fixture patches out ConstantLengthDataset so ConstantLengthDataset.__iter__'s new no-chat_template fallback isn't tested; add a small unit test that imports the real ConstantLengthDataset (not _FakeConstantLengthDataset), constructs it with a tiny dataset whose items lack "chat_template", iterates it (e.g., list(ConstantLengthDataset(...)) or calling its __iter__), and asserts the output matches the expected realized items (e.g., tensors like {"input_ids": torch.tensor([0])}); ensure this test does not apply the patched_dataloader monkeypatch and references ConstantLengthDataset and ConstantLengthDataset.__iter__ (and optionally create_validation_dataloader) so the fallback path is exercised.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 479-495: The function batched_normalized_mse_loss allows silent
broadcasting when input and target shapes differ; add explicit shape validation
at the top of the function: verify input.ndim == target.ndim, confirm batch_dims
are valid indices, and ensure sizes match for every dimension (both batch dims
and non-batch dims computed via norm_dims) so that target and input are exactly
compatible; if any mismatch, raise a ValueError with a clear message that
includes the shapes of input and target and the resolved batch_dims/norm_dims to
aid debugging.
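The suggested guard can be sketched framework-independently on shape tuples. validate_mse_shapes is a hypothetical helper name; in the real batched_normalized_mse_loss the shapes would come from input.shape and target.shape.

```python
def validate_mse_shapes(input_shape, target_shape, batch_dims):
    # same ndim on both sides
    if len(input_shape) != len(target_shape):
        raise ValueError(
            f"ndim mismatch: input {input_shape} vs target {target_shape}"
        )
    ndim = len(input_shape)
    # batch_dims must be valid dimension indices
    if any(not 0 <= d < ndim for d in batch_dims):
        raise ValueError(f"invalid batch_dims {batch_dims} for ndim {ndim}")
    norm_dims = tuple(d for d in range(ndim) if d not in batch_dims)
    # require exact sizes everywhere so broadcasting can never happen silently
    for d in range(ndim):
        if input_shape[d] != target_shape[d]:
            raise ValueError(
                f"shape mismatch at dim {d}: input={input_shape}, "
                f"target={target_shape}, batch_dims={batch_dims}, "
                f"norm_dims={norm_dims}"
            )
    return norm_dims
```

For example, validate_mse_shapes((2, 3, 4), (2, 3, 4), (0,)) returns (1, 2), while a target of (2, 1, 4) raises with both shapes and the resolved dims in the message.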
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 93-95: The per-layer loop currently does full copies via
current_parent_state_dict = dict(parent_state_dict), current_new_state_dict =
dict(new_state_dict), current_keys = dict(keys) which is expensive; instead,
stop cloning entire mappings inside the loop and operate on the original dicts
(parent_state_dict, new_state_dict, keys) by reading values directly and only
materialize copies for individual tensors/entries that are actually modified
(e.g., when applying a mixin to a specific key). Locate the per-layer mixin loop
and replace the dict() copies with references to the originals, and when you
need to mutate a specific parameter, copy only that parameter (or its key->value
pair) and write back to new_state_dict; ensure any iteration over keys uses an
iterator or list(keys) outside the hot loop if necessary to avoid mutation
races.
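The suggested restructuring can be sketched as follows; init_child_state and the (key, value) mixin signature are illustrative stand-ins, not the real child_init API.

```python
def init_child_state(parent_state_dict, new_state_dict, keys, mixins):
    for key in list(keys):  # snapshot once, outside the hot loop
        value = parent_state_dict[key]  # read the original; no dict() clone
        for mixin in mixins:
            value = mixin(key, value)  # mixins return a new value only if they modify it
        new_state_dict[key] = value  # materialize just this entry
    return new_state_dict


parent = {"w": [1, 2], "b": [3]}
double_w = lambda k, v: [x * 2 for x in v] if k == "w" else v
child = init_child_state(parent, {}, parent.keys(), [double_w])
# child == {"w": [2, 4], "b": [3]}; parent is left untouched
```

Only modified entries are ever copied, so per-layer cost scales with the layer's parameters rather than the whole state dict.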
In `@modelopt/torch/puzzletron/tools/hydra_utils.py`:
- Around line 35-50: The warmup_steps function must validate and normalize
inputs before doing integer divisions: ensure tokens, block, mbs and grad_accum
are ints (or cast) and that block>0, mbs>0, grad_accum>=1, and that pct is a
float within [0.0,1.0] (or at least >=0); raise ValueError with clear messages
for invalid values. In function warmup_steps, coerce tokens, block, mbs,
grad_accum and pct to the expected types up front, check block and mbs are >0 to
avoid ZeroDivisionError, check grad_accum>=1 (existing check can be reused), and
validate pct (and tokens>=0) before computing iters/steps and returning the
rounded warmup steps.
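A sketch of the validated helper; the iters/steps arithmetic is reconstructed from the comment above and may differ from the production formula.

```python
def warmup_steps(tokens, block, mbs, grad_accum, pct):
    # coerce up front so config strings fail loudly or normalize cleanly
    tokens, block, mbs, grad_accum = int(tokens), int(block), int(mbs), int(grad_accum)
    pct = float(pct)
    if block <= 0 or mbs <= 0:
        raise ValueError(f"block ({block}) and mbs ({mbs}) must be > 0")
    if grad_accum < 1:
        raise ValueError(f"grad_accum must be >= 1, got {grad_accum}")
    if tokens < 0:
        raise ValueError(f"tokens must be >= 0, got {tokens}")
    if not 0.0 <= pct <= 1.0:
        raise ValueError(f"pct must be in [0.0, 1.0], got {pct}")
    iters = tokens // (block * mbs)  # micro-batch iterations over the token budget
    steps = iters // grad_accum      # optimizer steps after gradient accumulation
    return round(steps * pct)
```

With the checks in place, block=0 or mbs=0 raises a clear ValueError instead of a ZeroDivisionError deep in the division.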
In `@modelopt/torch/puzzletron/utils/data/dataset.py`:
- Around line 131-138: The fallback that concatenates messages when
getattr(self.tokenizer, "chat_template", None) is None assumes every
m["content"] is a str and can raise TypeError for structured payloads; update
the else branch in dataset.py where sample is built to normalize each
m["content"] to a string before joining (e.g., if m["content"] is a dict or
other structured object, extract a text field if present like
m["content"].get("text") or otherwise call str(m["content"])), so the
concatenation in the no-template path (the code around tokenizer.chat_template
and tokenizer.apply_chat_template) always receives plain text.
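One way to sketch the normalization; normalize_content is a hypothetical helper, and the "text" key for structured payloads is an assumption.

```python
def normalize_content(content):
    if isinstance(content, str):
        return content
    if isinstance(content, dict) and isinstance(content.get("text"), str):
        return content["text"]  # common structured-message shape (assumed)
    return str(content)  # last resort: never let join() see a non-str


messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": {"type": "text", "text": "hi!"}},
]
sample = "\n".join(normalize_content(m["content"]) for m in messages)
# sample == "hello\nhi!"
```

The no-template path then always concatenates plain text, regardless of how the dataset structures its message contents.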
In `@tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py`:
- Around line 137-139: The test currently checks values in received["kwargs"]
but doesn't ensure no extra kwargs are present; update the second-order test in
test_sewing_kit_function_target_kwargs (use the local variables received,
student_value, teacher_value) to assert that received["kwargs"] contains exactly
the keys "input" and "target" (e.g., compare set(received["kwargs"].keys()) to
{"input","target"}) before the existing torch.equal assertions, then keep the
existing checks for received["args"] and the tensor equality against
student_value and teacher_value.
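The strengthened assertion could look like this; plain floats stand in for the real test's tensors, and the variable names mirror the test's locals.

```python
student_value, teacher_value = 1.0, 2.0
received = {"args": (), "kwargs": {"input": student_value, "target": teacher_value}}

# new check: exactly these kwargs, nothing extra smuggled through
assert set(received["kwargs"].keys()) == {"input", "target"}

# existing checks kept as-is
assert received["args"] == ()
assert received["kwargs"]["input"] == student_value
assert received["kwargs"]["target"] == teacher_value
```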
---
Nitpick comments:
In `@tests/unit/torch/puzzletron/test_bypass_dataloaders.py`:
- Around line 206-219: The fixture patches out ConstantLengthDataset so
ConstantLengthDataset.__iter__'s new no-chat_template fallback isn't tested; add
a small unit test that imports the real ConstantLengthDataset (not
_FakeConstantLengthDataset), constructs it with a tiny dataset whose items lack
"chat_template", iterates it (e.g., list(ConstantLengthDataset(...)) or calling
its __iter__), and asserts the output matches the expected realized items (e.g.,
tensors like {"input_ids": torch.tensor([0])}); ensure this test does not apply
the patched_dataloader monkeypatch and references ConstantLengthDataset and
ConstantLengthDataset.__iter__ (and optionally create_validation_dataloader) so
the fallback path is exercised.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ddff5f0a-3633-4520-914f-dad472197cf8
📒 Files selected for processing (22)
- modelopt/torch/puzzletron/anymodel/model_descriptor/base.py
- modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py
- modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- modelopt/torch/puzzletron/sewing_kit/passage.py
- modelopt/torch/puzzletron/sewing_kit/utils.py
- modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
- modelopt/torch/puzzletron/tools/hydra_utils.py
- modelopt/torch/puzzletron/utils/data/dataloaders.py
- modelopt/torch/puzzletron/utils/data/dataset.py
- modelopt/torch/puzzletron/utils/parsing.py
- tests/unit/torch/puzzletron/test_bypass_dataloaders.py
- tests/unit/torch/puzzletron/test_bypass_losses.py
- tests/unit/torch/puzzletron/test_child_init_mixins.py
- tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py
- tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py
- tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py
- tests/unit/torch/puzzletron/test_sewing_kit_input_args.py
- tests/unit/torch/puzzletron/test_sewing_kit_needle.py
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Force-pushed a79fbae to 12086fb (Compare)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 540-542: Validate that epsilon is strictly positive before
computing den; in the function that computes num = ((input - target) **
2).sum(dim=norm_dims) and den = (target**2).sum(dim=norm_dims) + epsilon, add a
guard at the start (before the denominator math) that either raises a ValueError
with a clear message if epsilon <= 0, or clamps epsilon to a small positive
floor (e.g., max(epsilon, 1e-12)); ensure the check references the epsilon
variable and occurs before computing den to prevent any inf/nan from division.
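The epsilon guard can be sketched on a flat-list version of the loss; the real function reduces tensors over norm_dims, but the guard logic is the same.

```python
def normalized_mse(input_vals, target_vals, epsilon=1e-8):
    # guard before the denominator math so inf/nan can never escape
    if epsilon <= 0:
        raise ValueError(f"epsilon must be > 0, got {epsilon}")
    num = sum((a - b) ** 2 for a, b in zip(input_vals, target_vals))
    den = sum(b * b for b in target_vals) + epsilon
    return num / den
```

An all-zero target then yields a large but finite value instead of a division by zero, and a non-positive epsilon fails immediately with a clear message (clamping via max(epsilon, 1e-12) is the alternative the comment mentions).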
In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 113-121: The shuffle call for map-style datasets currently
hardcodes keep_in_memory=True and ignores the function argument; update the
branch that handles non-IterableDataset so that it passes the caller's
keep_in_memory parameter (the function arg named keep_in_memory) into
train_data.shuffle(seed=shuffle_seed, keep_in_memory=keep_in_memory) while
leaving IterableDataset.shuffle(seed=shuffle_seed) unchanged; reference the
symbols train_data, datasets.IterableDataset, shuffle_seed, and keep_in_memory
to locate and modify the code.
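A sketch of the forwarding fix using a recorder stand-in; in the real code the streaming check is isinstance(train_data, datasets.IterableDataset), and shuffle_train_data is an illustrative name for the affected branch.

```python
def shuffle_train_data(train_data, shuffle_seed, keep_in_memory, iterable_types=()):
    if isinstance(train_data, iterable_types):
        return train_data.shuffle(seed=shuffle_seed)  # streaming: no keep_in_memory kwarg
    # forward the caller's flag instead of hardcoding keep_in_memory=True
    return train_data.shuffle(seed=shuffle_seed, keep_in_memory=keep_in_memory)


class _Recorder:
    def shuffle(self, **kwargs):
        self.kwargs = kwargs
        return self


shuffled = shuffle_train_data(_Recorder(), shuffle_seed=42, keep_in_memory=False)
assert shuffled.kwargs == {"seed": 42, "keep_in_memory": False}
```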
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5e8b4997-90ef-408e-b03d-7bb26b85189d
📒 Files selected for processing (23)
- modelopt/torch/puzzletron/anymodel/model_descriptor/base.py
- modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py
- modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- modelopt/torch/puzzletron/sewing_kit/passage.py
- modelopt/torch/puzzletron/sewing_kit/utils.py
- modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
- modelopt/torch/puzzletron/tools/hydra_utils.py
- modelopt/torch/puzzletron/utils/data/dataloaders.py
- modelopt/torch/puzzletron/utils/data/dataset.py
- modelopt/torch/puzzletron/utils/parsing.py
- tests/unit/torch/puzzletron/test_bypass_dataloaders.py
- tests/unit/torch/puzzletron/test_bypass_losses.py
- tests/unit/torch/puzzletron/test_child_init_mixins.py
- tests/unit/torch/puzzletron/test_hydra_utils.py
- tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py
- tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py
- tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py
- tests/unit/torch/puzzletron/test_sewing_kit_input_args.py
- tests/unit/torch/puzzletron/test_sewing_kit_needle.py
✅ Files skipped from review due to trivial changes (3)
- modelopt/torch/puzzletron/sewing_kit/passage.py
- modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
- tests/unit/torch/puzzletron/test_sewing_kit_input_args.py
🚧 Files skipped from review as they are similar to previous changes (16)
- modelopt/torch/puzzletron/utils/data/dataset.py
- tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py
- modelopt/torch/puzzletron/anymodel/model_descriptor/base.py
- modelopt/torch/puzzletron/tools/hydra_utils.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
- tests/unit/torch/puzzletron/test_child_init_mixins.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py
- tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py
- modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py
- tests/unit/torch/puzzletron/test_bypass_losses.py
- modelopt/torch/puzzletron/utils/parsing.py
- tests/unit/torch/puzzletron/test_bypass_dataloaders.py
- tests/unit/torch/puzzletron/test_sewing_kit_needle.py
- modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py
/claude review
Summary
This is PR 1 of 3 in the Puzzletron bypass/local-distillation stack.
This PR contains prerequisite infrastructure only. It does not wire bypass distillation into the Puzzletron pipeline yet.
Stack:
- ssameni/puzzletron-bypass-2-core: bypass distillation core
- ssameni/puzzletron-bypass-3-integration: Puzzletron integration, configs, docs, GPU coverage

What Changed
- ModelDescriptor.pruning_mixins() so model families can expose pruning mixins needed by downstream bypass initialization.
- create_train_dataloader() and streaming-safe shuffle handling.
- Fallback preprocessing for tokenizers without tokenizer.chat_template.
The bypass distillation MR needs these reusable pieces, but they are independently reviewable and useful without adding the bypass
training stage itself.
Splitting them out keeps the bypass core PR focused on the actual local-distillation engine.
Tests
Added focused unit coverage for:
Summary by CodeRabbit
New Features
Improvements
Tests