diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index dbef4aabab3..a4180074ba8 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -15,6 +15,7 @@ """Support quantization for Transformer Engine layers.""" +import inspect import warnings import torch @@ -74,30 +75,24 @@ def _setup(self): def te_quantized_linear_fn(package, func_name, self, *args, **kwargs): """Quantized version specifically for TE with weight first, then input.""" _assert_te_fp8_enabled() - if Version("2.0") <= _TE_VERSION: - idx = 1 if func_name == "_forward" else 0 - weight, inputs = args[idx], args[idx + 1] - remaining_args = args[idx + 2 :] - weight = self.weight_quantizer(weight) - inputs = self.input_quantizer(inputs) - new_args = (weight, inputs, *remaining_args) - new_args = (args[0], *new_args) if func_name == "_forward" else new_args - output = getattr(package, func_name)( - *new_args, - **kwargs, - ) - else: - idx = 1 if func_name == "_forward" else 0 - weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2] - remaining_args = args[idx + 3 :] - weight = self.weight_quantizer(weight) - inputs = self.input_quantizer(inputs) - new_args = (weight, weight_fp8, inputs, *remaining_args) - new_args = (args[0], *new_args) if func_name == "_forward" else new_args - output = getattr(package, func_name)( - *new_args, - **kwargs, - ) + # Locate `weight` and `inp` by parameter name in the un-patched `_Linear.forward` + # signature — robust to TE versions that insert positional args between them + # (e.g. `weight_fp8` in TE 1.x, `weight_workspace` in TE 2.15). + # NOTE: we're called from inside `replace_function`'s context, so + # `_Linear.forward` may currently point at the `functools.partial` wrapper + # (whose signature collapses to `*args, **kwargs`). The original is cached at + # `_Linear._forward` while the patch is active (when `_apply` is patched + # instead, `_forward` is absent and `forward` is itself the original). + # `_forward` path receives a leading None (placeholder ctx); `_apply` does not. + orig_forward = getattr(te_linear._Linear, "_forward", te_linear._Linear.forward) + names = list(inspect.signature(orig_forward).parameters) + ctx_offset = 0 if func_name == "_forward" else 1 + weight_pos = names.index("weight") - ctx_offset + inp_pos = names.index("inp") - ctx_offset + new_args = list(args) + new_args[weight_pos] = self.weight_quantizer(args[weight_pos]) + new_args[inp_pos] = self.input_quantizer(args[inp_pos]) + output = getattr(package, func_name)(*new_args, **kwargs) return self.output_quantizer(output) # Override the quantized linear function @@ -161,35 +156,31 @@ def iter_weights_for_calibration(self): @staticmethod def te_grouped_quantized_linear_fn(package, func_name, self, *args): _assert_te_fp8_enabled() - idx = 1 if func_name == "_forward" else 0 - inp = args[idx] - - # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10) - # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases) - # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...) - if Version("2.10") <= _TE_VERSION: - # New signature: non_tensor_args is a tuple, m_splits is the first element - num_gemms = len(args[idx + 1][0]) - else: - # Old signature: m_splits is directly args[idx + 1] - num_gemms = len(args[idx + 1]) - - weights_and_biases = args[-2 * num_gemms :] - weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] - quantized_inputs = self.input_quantizer(inp) - quantized_weights = [self.weight_quantizer(weight) for weight in weights] - - output = getattr(package, func_name)( - *( - args[0], - quantized_inputs, - ) - if func_name == "_forward" - else (quantized_inputs,), - *args[idx + 1 : -2 * num_gemms], - *quantized_weights, - *biases, + # Locate `inp` and the m_splits-bearing arg by parameter name. The second + # slot was renamed from `m_splits` (TE < 2.10) to `non_tensor_args` (TE + # 2.10+, where m_splits is now at non_tensor_args[0]). `*weights_and_biases` + # is always the trailing variadic — 2 * num_gemms tensors (weights, then biases). + # See `te_quantized_linear_fn` for why we look up `_forward` here. + # `_forward` path receives a leading None (placeholder ctx); `_apply` does not. + orig_forward = getattr( + te_grouped_linear._GroupedLinear, + "_forward", + te_grouped_linear._GroupedLinear.forward, ) + sig_params = list(inspect.signature(orig_forward).parameters) + ctx_offset = 0 if func_name == "_forward" else 1 + inp_pos = sig_params.index("inp") - ctx_offset + if "non_tensor_args" in sig_params: + num_gemms = len(args[sig_params.index("non_tensor_args") - ctx_offset][0]) + else: + num_gemms = len(args[sig_params.index("m_splits") - ctx_offset]) + weights_start = len(args) - 2 * num_gemms + + new_args = list(args) + new_args[inp_pos] = self.input_quantizer(args[inp_pos]) + for i in range(weights_start, weights_start + num_gemms): + new_args[i] = self.weight_quantizer(args[i]) + output = getattr(package, func_name)(*new_args) return self.output_quantizer(output) # Override the quantized linear function