Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 43 additions & 52 deletions modelopt/torch/quantization/plugins/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Support quantization for Transformer Engine layers."""

import inspect
import warnings

import torch
Expand Down Expand Up @@ -74,30 +75,24 @@ def _setup(self):
def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
"""Quantized version specifically for TE with weight first, then input."""
_assert_te_fp8_enabled()
if Version("2.0") <= _TE_VERSION:
idx = 1 if func_name == "_forward" else 0
weight, inputs = args[idx], args[idx + 1]
remaining_args = args[idx + 2 :]
weight = self.weight_quantizer(weight)
inputs = self.input_quantizer(inputs)
new_args = (weight, inputs, *remaining_args)
new_args = (args[0], *new_args) if func_name == "_forward" else new_args
output = getattr(package, func_name)(
*new_args,
**kwargs,
)
else:
idx = 1 if func_name == "_forward" else 0
weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
remaining_args = args[idx + 3 :]
weight = self.weight_quantizer(weight)
inputs = self.input_quantizer(inputs)
new_args = (weight, weight_fp8, inputs, *remaining_args)
new_args = (args[0], *new_args) if func_name == "_forward" else new_args
output = getattr(package, func_name)(
*new_args,
**kwargs,
)
# Locate `weight` and `inp` by parameter name in the un-patched `_Linear.forward`
# signature — robust to TE versions that insert positional args between them
# (e.g. `weight_fp8` in TE 1.x, `weight_workspace` in TE 2.15).
# NOTE: we're called from inside `replace_function`'s context, so
# `_Linear.forward` may currently point at the `functools.partial` wrapper
# (whose signature collapses to `*args, **kwargs`). The original is cached at
# `_Linear._forward` while the patch is active (when `_apply` is patched
# instead, `_forward` is absent and `forward` is itself the original).
# `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
orig_forward = getattr(te_linear._Linear, "_forward", te_linear._Linear.forward)
names = list(inspect.signature(orig_forward).parameters)
ctx_offset = 0 if func_name == "_forward" else 1
weight_pos = names.index("weight") - ctx_offset
inp_pos = names.index("inp") - ctx_offset
new_args = list(args)
new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
output = getattr(package, func_name)(*new_args, **kwargs)
return self.output_quantizer(output)

# Override the quantized linear function
Expand Down Expand Up @@ -161,35 +156,31 @@ def iter_weights_for_calibration(self):
@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
_assert_te_fp8_enabled()
idx = 1 if func_name == "_forward" else 0
inp = args[idx]

# Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
# New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
# Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
if Version("2.10") <= _TE_VERSION:
# New signature: non_tensor_args is a tuple, m_splits is the first element
num_gemms = len(args[idx + 1][0])
else:
# Old signature: m_splits is directly args[idx + 1]
num_gemms = len(args[idx + 1])

weights_and_biases = args[-2 * num_gemms :]
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
quantized_inputs = self.input_quantizer(inp)
quantized_weights = [self.weight_quantizer(weight) for weight in weights]

output = getattr(package, func_name)(
*(
args[0],
quantized_inputs,
)
if func_name == "_forward"
else (quantized_inputs,),
*args[idx + 1 : -2 * num_gemms],
*quantized_weights,
*biases,
# Locate `inp` and the m_splits-bearing arg by parameter name. The second
# slot was renamed from `m_splits` (TE < 2.10) to `non_tensor_args` (TE
# 2.10+, where m_splits is now at non_tensor_args[0]). `*weights_and_biases`
# is always the trailing variadic — 2 * num_gemms tensors (weights, then biases).
# See `te_quantized_linear_fn` for why we look up `_forward` here.
# `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
orig_forward = getattr(
te_grouped_linear._GroupedLinear,
"_forward",
te_grouped_linear._GroupedLinear.forward,
)
sig_params = list(inspect.signature(orig_forward).parameters)
ctx_offset = 0 if func_name == "_forward" else 1
inp_pos = sig_params.index("inp") - ctx_offset
if "non_tensor_args" in sig_params:
num_gemms = len(args[sig_params.index("non_tensor_args") - ctx_offset][0])
else:
num_gemms = len(args[sig_params.index("m_splits") - ctx_offset])
Comment thread
kevalmorabia97 marked this conversation as resolved.
weights_start = len(args) - 2 * num_gemms

new_args = list(args)
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
for i in range(weights_start, weights_start + num_gemms):
new_args[i] = self.weight_quantizer(args[i])
output = getattr(package, func_name)(*new_args)
return self.output_quantizer(output)

# Override the quantized linear function
Expand Down
Loading