From 9d16ad93bfd02922f6658225be9175a6609359d3 Mon Sep 17 00:00:00 2001
From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Date: Tue, 12 May 2026 11:27:55 -0700
Subject: [PATCH 1/2] fix(te-plugin): make _Linear arg indexing robust to TE
 signature changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

ModelOpt's te_quantized_linear_fn and te_grouped_quantized_linear_fn
read weight/inp from hard-coded positions in args. TE 2.15 broke this
by inserting weight_workspace between weight and inp at the
_Linear.forward call site, so None was passed into
self.input_quantizer() and calibration crashed with "AttributeError:
'NoneType' object has no attribute 'numel'" (observed in
Megatron-Bridge after bumping to modelopt 0.44.0rc3).

Replace the version-gated indexing with parameter-name introspection of
the live _Linear.forward / _GroupedLinear.forward signature. The param
names weight and inp have been stable across TE 1.x, 2.x, and 2.15+
(for grouped linear, the second slot is m_splits before TE 2.10 and
non_tensor_args from 2.10 on); only the positional gap between the
named params changes from version to version.

This also subsumes the old TE < 2.0 branch (`weight, weight_fp8,
inputs`): the same name-based lookup handles it uniformly, so the
dual-branch code is replaced with one path.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
---
 .../plugins/transformer_engine.py | 83 +++++++------------
 1 file changed, 31 insertions(+), 52 deletions(-)

diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py
index dbef4aabab3..bee9e1fbe5d 100644
--- a/modelopt/torch/quantization/plugins/transformer_engine.py
+++ b/modelopt/torch/quantization/plugins/transformer_engine.py
@@ -15,6 +15,7 @@
 """Support quantization for Transformer Engine layers."""
 
+import inspect
 import warnings
 
 import torch
 
@@ -74,30 +75,18 @@ def _setup(self):
     def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
         """Quantized version specifically for TE with weight first, then input."""
         _assert_te_fp8_enabled()
-        if Version("2.0") <= _TE_VERSION:
-            idx = 1 if func_name == "_forward" else 0
-            weight, inputs = args[idx], args[idx + 1]
-            remaining_args = args[idx + 2 :]
-            weight = self.weight_quantizer(weight)
-            inputs = self.input_quantizer(inputs)
-            new_args = (weight, inputs, *remaining_args)
-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
-            output = getattr(package, func_name)(
-                *new_args,
-                **kwargs,
-            )
-        else:
-            idx = 1 if func_name == "_forward" else 0
-            weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
-            remaining_args = args[idx + 3 :]
-            weight = self.weight_quantizer(weight)
-            inputs = self.input_quantizer(inputs)
-            new_args = (weight, weight_fp8, inputs, *remaining_args)
-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
-            output = getattr(package, func_name)(
-                *new_args,
-                **kwargs,
-            )
+        # Locate `weight` and `inp` by parameter name in the live `_Linear.forward`
+        # signature — robust to TE versions that insert positional args between them
+        # (e.g. `weight_fp8` in TE 1.x, `weight_workspace` in TE 2.15).
+        # `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
+        names = list(inspect.signature(te_linear._Linear.forward).parameters)
+        ctx_offset = 0 if func_name == "_forward" else 1
+        weight_pos = names.index("weight") - ctx_offset
+        inp_pos = names.index("inp") - ctx_offset
+        new_args = list(args)
+        new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
+        new_args[inp_pos] = self.input_quantizer(args[inp_pos])
+        output = getattr(package, func_name)(*new_args, **kwargs)
         return self.output_quantizer(output)
 
     # Override the quantized linear function
@@ -161,35 +150,25 @@ def iter_weights_for_calibration(self):
     @staticmethod
     def te_grouped_quantized_linear_fn(package, func_name, self, *args):
         _assert_te_fp8_enabled()
-        idx = 1 if func_name == "_forward" else 0
-        inp = args[idx]
-
-        # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
-        # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
-        # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
-        if Version("2.10") <= _TE_VERSION:
-            # New signature: non_tensor_args is a tuple, m_splits is the first element
-            num_gemms = len(args[idx + 1][0])
+        # Locate `inp` and the m_splits-bearing arg by parameter name. The second
+        # slot was renamed from `m_splits` (TE < 2.10) to `non_tensor_args` (TE
+        # 2.10+, where m_splits is now at non_tensor_args[0]). `*weights_and_biases`
+        # is always the trailing variadic — 2 * num_gemms tensors (weights, then biases).
+        # `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
+        sig_params = list(inspect.signature(te_grouped_linear._GroupedLinear.forward).parameters)
+        ctx_offset = 0 if func_name == "_forward" else 1
+        inp_pos = sig_params.index("inp") - ctx_offset
+        if "non_tensor_args" in sig_params:
+            num_gemms = len(args[sig_params.index("non_tensor_args") - ctx_offset][0])
         else:
-            # Old signature: m_splits is directly args[idx + 1]
-            num_gemms = len(args[idx + 1])
-
-        weights_and_biases = args[-2 * num_gemms :]
-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
-        quantized_inputs = self.input_quantizer(inp)
-        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
-
-        output = getattr(package, func_name)(
-            *(
-                args[0],
-                quantized_inputs,
-            )
-            if func_name == "_forward"
-            else (quantized_inputs,),
-            *args[idx + 1 : -2 * num_gemms],
-            *quantized_weights,
-            *biases,
-        )
+            num_gemms = len(args[sig_params.index("m_splits") - ctx_offset])
+        weights_start = len(args) - 2 * num_gemms
+
+        new_args = list(args)
+        new_args[inp_pos] = self.input_quantizer(args[inp_pos])
+        for i in range(weights_start, weights_start + num_gemms):
+            new_args[i] = self.weight_quantizer(args[i])
+        output = getattr(package, func_name)(*new_args)
         return self.output_quantizer(output)
 
     # Override the quantized linear function

From eeb9b9d38a705ba4ded694dc5d815cfdfdc30c87 Mon Sep 17 00:00:00 2001
From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Date: Tue, 12 May 2026 11:45:59 -0700
Subject: [PATCH 2/2] fix(te-plugin): inspect un-patched _Linear.forward via
 cached `_forward`

Inside `te_quantized_linear_fn` / `te_grouped_quantized_linear_fn` we
run *while* `replace_function`'s context is active, so `_Linear.forward`
and `_GroupedLinear.forward` point at the `functools.partial` wrapper
that replaced them. Because the wrapper is itself variadic, the patched
forward's signature collapses to `(*args, **kwargs)` under
`inspect.signature`, so `names.index("weight")` would raise
`ValueError: 'weight' is not in list` on the first calibration forward.
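
A minimal repro of the collapse (stand-in functions, not the real TE
classes or ModelOpt's wrapper; names are illustrative):

    import functools
    import inspect

    def forward(ctx, weight, inp, bias):  # stands in for _Linear.forward
        ...

    def quantized_fn(package, func_name, self, *args, **kwargs):
        ...  # stands in for the replacement installed on `forward`

    patched = functools.partial(quantized_fn, "package", "_forward")
    print(list(inspect.signature(forward).parameters))
    # ['ctx', 'weight', 'inp', 'bias']
    print(list(inspect.signature(patched).parameters))
    # ['self', 'args', 'kwargs'] -- "weight" is gone, so .index("weight") raises
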
`replace_function` saves the original at `_Linear._forward` /
`_GroupedLinear._forward` while the patch is active. Look that up,
falling back to the still-untouched `.forward` when only `_apply` was
patched. Caught by Claude review on PR #1473.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
---
 .../quantization/plugins/transformer_engine.py | 18 +++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py
index bee9e1fbe5d..a4180074ba8 100644
--- a/modelopt/torch/quantization/plugins/transformer_engine.py
+++ b/modelopt/torch/quantization/plugins/transformer_engine.py
@@ -75,11 +75,17 @@ def _setup(self):
     def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
         """Quantized version specifically for TE with weight first, then input."""
         _assert_te_fp8_enabled()
-        # Locate `weight` and `inp` by parameter name in the live `_Linear.forward`
+        # Locate `weight` and `inp` by parameter name in the un-patched `_Linear.forward`
         # signature — robust to TE versions that insert positional args between them
         # (e.g. `weight_fp8` in TE 1.x, `weight_workspace` in TE 2.15).
+        # NOTE: we're called from inside `replace_function`'s context, so
+        # `_Linear.forward` may currently point at the `functools.partial` wrapper
+        # (whose signature collapses to `*args, **kwargs`). The original is cached at
+        # `_Linear._forward` while the patch is active (when `_apply` is patched
+        # instead, `_forward` is absent and `forward` is itself the original).
         # `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
-        names = list(inspect.signature(te_linear._Linear.forward).parameters)
+        orig_forward = getattr(te_linear._Linear, "_forward", te_linear._Linear.forward)
+        names = list(inspect.signature(orig_forward).parameters)
         ctx_offset = 0 if func_name == "_forward" else 1
         weight_pos = names.index("weight") - ctx_offset
         inp_pos = names.index("inp") - ctx_offset
@@ -154,8 +160,14 @@
         # slot was renamed from `m_splits` (TE < 2.10) to `non_tensor_args` (TE
         # 2.10+, where m_splits is now at non_tensor_args[0]). `*weights_and_biases`
         # is always the trailing variadic — 2 * num_gemms tensors (weights, then biases).
+        # See `te_quantized_linear_fn` for why we look up `_forward` here.
         # `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
-        sig_params = list(inspect.signature(te_grouped_linear._GroupedLinear.forward).parameters)
+        orig_forward = getattr(
+            te_grouped_linear._GroupedLinear,
+            "_forward",
+            te_grouped_linear._GroupedLinear.forward,
+        )
+        sig_params = list(inspect.signature(orig_forward).parameters)
         ctx_offset = 0 if func_name == "_forward" else 1
         inp_pos = sig_params.index("inp") - ctx_offset
         if "non_tensor_args" in sig_params:
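
A quick way to sanity-check the fallback outside TE (mock class and
patching, illustrative only -- not ModelOpt's actual replace_function):

    import functools
    import inspect

    class _Linear:  # mock of te_linear._Linear
        @staticmethod
        def forward(ctx, weight, inp):
            ...

    _Linear._forward = _Linear.forward  # replace_function caches the original here
    _Linear.forward = functools.partial(
        lambda pkg, fn, self, *a, **kw: None, "package", "_forward"
    )

    orig = getattr(_Linear, "_forward", _Linear.forward)
    print(list(inspect.signature(orig).parameters))  # ['ctx', 'weight', 'inp']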