From 470e582d80d7734e57c712623dd2d224bfc90b9b Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 11 May 2026 17:13:15 +0000 Subject: [PATCH 1/9] Incorporate safetensors support to TorchAO --- src/diffusers/hooks/group_offloading.py | 174 +++++++++++++----- src/diffusers/models/model_loading_utils.py | 3 + src/diffusers/models/modeling_utils.py | 39 +++- .../quantizers/torchao/torchao_quantizer.py | 96 +++++++++- tests/quantization/torchao/test_torchao.py | 75 +++++++- 5 files changed, 335 insertions(+), 52 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f3d1f3389bb7..5f0a50610804 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,6 +13,7 @@ # limitations under the License. import hashlib +import json import os from contextlib import contextmanager, nullcontext from dataclasses import dataclass, replace @@ -21,8 +22,9 @@ import safetensors.torch import torch +from safetensors import safe_open -from ..utils import get_logger, is_accelerate_available, is_torchao_available +from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from .hooks import HookRegistry, ModelHook @@ -32,6 +34,15 @@ from accelerate.utils import send_to_device +if is_torchao_available(): + if is_torchao_version(">=", "0.15.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -146,26 +157,28 @@ def __init__( self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False + all_tensors = [] + for module in self.modules: + all_tensors.extend(list(module.parameters())) + all_tensors.extend(list(module.buffers())) + all_tensors.extend(self.parameters) + all_tensors.extend(self.buffers) + all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates + + self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} + self._torchao_disk_key_remap: dict[str, str] = {} + if self.offload_to_disk_path is not None: # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. self.group_id = group_id if group_id is not None else str(id(self)) short_hash = _compute_group_hash(self.group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") - - all_tensors = [] - for module in self.modules: - all_tensors.extend(list(module.parameters())) - all_tensors.extend(list(module.buffers())) - all_tensors.extend(self.parameters) - all_tensors.extend(self.buffers) - all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates - - self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} - self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} self.cpu_param_dict = {} else: self.cpu_param_dict = self._init_cpu_param_dict() + self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) + self._torch_accelerator_module = ( getattr(torch, torch.accelerator.current_accelerator().type) if hasattr(torch, "accelerator") @@ -179,6 +192,26 @@ def _to_cpu(tensor, low_cpu_mem_usage): t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu() return t if low_cpu_mem_usage else t.pin_memory() + @staticmethod + def _get_torchao_subset_metadata_for_unflatten(metadata): + tensor_names = metadata.get("tensor_names") + if tensor_names is None: + return None + + try: + tensor_names = json.loads(tensor_names) + except (TypeError, json.JSONDecodeError): + return None + + dotted_tensor_names = [name for name in tensor_names if "." in name] + if len(dotted_tensor_names) == 0: + return None + + return { + "tensor_names": json.dumps(dotted_tensor_names), + **{name: metadata[name] for name in dotted_tensor_names if name in metadata}, + } + def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -238,19 +271,79 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, default_stream) - def _check_disk_offload_torchao(self): - all_tensors = list(self.tensor_to_key.keys()) - has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) - if has_torchao: - raise ValueError( - "Disk offloading is not supported for TorchAO quantized tensors because safetensors " - "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " - "setting `offload_to_disk_path`." + def _get_disk_state_dict(self): + tensors_to_save = { + key: ( + tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device) ) + for tensor, key in self.tensor_to_key.items() + } + + metadata = {} + if self._has_torchao_tensors and is_torchao_version(">=", "0.15.0"): + tensors_for_flatten = {} + self._torchao_disk_key_remap = {} + for key, tensor in tensors_to_save.items(): + if _is_torchao_tensor(tensor) and "." not in key: + flattened_key = f"{key}.weight" + self._torchao_disk_key_remap[key] = flattened_key + tensors_for_flatten[flattened_key] = tensor + else: + tensors_for_flatten[key] = tensor - def _onload_from_disk(self): - self._check_disk_offload_torchao() + flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) + if isinstance(flattened_state_dict, tuple): + tensors_to_save, metadata = flattened_state_dict + + return tensors_to_save, metadata + + def _load_disk_state_dict(self, device): + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) + + if not self._has_torchao_tensors or not is_torchao_version(">=", "0.15.0"): + return loaded_tensors + + with safe_open(self.safetensors_file_path, framework="pt") as f: + metadata = f.metadata() or {} + + if is_metadata_torchao(metadata): + try: + reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata) + loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} + except Exception as error: + logger.warning( + "Failed to unflatten TorchAO state dict metadata from disk; falling back to raw tensors." + ) + logger.debug(error) + + subset_metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) + if subset_metadata is not None: + try: + reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict( + loaded_tensors, subset_metadata + ) + loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} + except Exception as subset_error: + logger.debug("Failed to unflatten subset of TorchAO metadata; using raw tensors for onload.") + logger.debug(subset_error) + + # Support legacy in-memory tensor keys used by GroupOffloading when + # flattening introduced dot-based names to satisfy TorchAO's safetensors API. + for original_key, flattened_key in self._torchao_disk_key_remap.items(): + if original_key not in loaded_tensors and flattened_key in loaded_tensors: + loaded_tensors[original_key] = loaded_tensors.pop(flattened_key) + + return loaded_tensors + + def _release_onload_tensors(self): + for tensor_obj in self.tensor_to_key.keys(): + if _is_torchao_tensor(tensor_obj): + placeholder = tensor_obj.to(self.offload_device) + _swap_torchao_tensor(tensor_obj, placeholder) + else: + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + def _onload_from_disk(self): if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -259,22 +352,22 @@ def _onload_from_disk(self): current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None with context: + device = str(self.onload_device) if self.stream is None else "cpu" + loaded_tensors = self._load_disk_state_dict(device=device) + if self.stream is not None: - # Load to CPU first, pin memory, then async copy to the target device - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") - for key, tensor_obj in self.key_to_tensor.items(): - pinned_tensor = loaded_tensors[key].pin_memory() - tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - tensor_obj.data.record_stream(current_stream) + pinned_memory = { + tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory() + for tensor_obj in self.tensor_to_key + } + self._process_tensors_from_modules(pinned_memory, default_stream=current_stream) else: - # Load directly to the target device - onload_device = ( - self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device - ) - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) - for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] + for tensor_obj in self.tensor_to_key: + self._transfer_tensor_to_device( + tensor_obj, + loaded_tensors[self.tensor_to_key[tensor_obj]], + default_stream=None, + ) def _onload_from_memory(self): if self.stream is not None: @@ -292,8 +385,6 @@ def _onload_from_memory(self): self._process_tensors_from_modules(None) def _offload_to_disk(self): - self._check_disk_offload_torchao() - # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not @@ -301,15 +392,14 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + tensors_to_save, metadata = self._get_disk_state_dict() + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True # We do this to free up the RAM which is still holding the up tensor data. - for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + self._release_onload_tensors() def _offload_to_memory(self): if self.stream is not None: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 04642ad5d401..beeee1b498b0 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -357,6 +357,9 @@ def _load_shard_file( disable_mmap=False, ): state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) + if hf_quantizer is not None and hasattr(hf_quantizer, "get_reconstructed_state_dict"): + state_dict = hf_quantizer.get_reconstructed_state_dict(state_dict) + mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..30bf01da51b2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -759,6 +759,17 @@ def save_pretrained( # Save the model state_dict = model_to_save.state_dict() + quantization_metadata = {} + if safe_serialization and hf_quantizer is not None: + get_state_dict_and_metadata = getattr(hf_quantizer, "get_state_dict_and_metadata", None) + if callable(get_state_dict_and_metadata): + state_dict_and_metadata = get_state_dict_and_metadata(model_to_save) + else: + state_dict_and_metadata = model_to_save.state_dict() + if isinstance(state_dict_and_metadata, tuple): + state_dict, quantization_metadata = state_dict_and_metadata + else: + state_dict = state_dict_and_metadata if use_flashpack: if is_flashpack_available(): @@ -803,15 +814,21 @@ def save_pretrained( shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} filepath = os.path.join(save_directory, filename) if safe_serialization: + metadata = dict(state_dict_split.metadata) + metadata.update(quantization_metadata) + metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. - safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + safetensors.torch.save_file(shard, filepath, metadata=metadata) else: torch.save(shard, filepath) if state_dict_split.is_sharded: + metadata = dict(state_dict_split.metadata) + metadata.update(quantization_metadata) + metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} index = { - "metadata": state_dict_split.metadata, + "metadata": metadata, "weight_map": state_dict_split.tensor_to_filename, } save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME @@ -1367,11 +1384,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None else: loaded_keys = list(state_dict.keys()) + checkpoint_files = resolved_model_file + if hf_quantizer is not None: + if hasattr(hf_quantizer, "set_metadata"): + hf_quantizer.set_metadata(checkpoint_files) + quantized_weight_names = [] + if hasattr(hf_quantizer, "get_weight_names"): + quantized_weight_names = hf_quantizer.get_weight_names() + if quantized_weight_names: + loaded_keys = list(quantized_weight_names) + if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + model=model, + device_map=device_map, + keep_in_fp32_modules=keep_in_fp32_modules, + checkpoint_files=checkpoint_files, ) + if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO: + is_parallel_loading_enabled = False + # Now that the model is loaded, we can determine the device_map device_map = _determine_device_map( model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index b710fcd2db30..6457ed40b128 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -18,6 +18,7 @@ """ import importlib +import json import re import types from typing import TYPE_CHECKING, Any @@ -26,6 +27,7 @@ from ...utils import ( get_module_from_name, + is_safetensors_available, is_torch_available, is_torch_version, is_torchao_available, @@ -41,6 +43,9 @@ if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin +if is_safetensors_available(): + from safetensors import safe_open + if is_torch_available(): import torch @@ -72,6 +77,13 @@ if is_torchao_available(): from torchao.quantization import quantize_ + if is_torchao_version(">=", "0.15.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + def _update_torch_safe_globals(): safe_globals = [ @@ -154,6 +166,11 @@ class TorchAoHfQuantizer(DiffusersQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) + self._metadata = {} + self._pending_flattened_state_dict = {} + self._loaded_weight_names = set() + self._expected_weight_names = set() + def validate_environment(self, *args, **kwargs): if not is_torchao_available(): raise ImportError( @@ -236,6 +253,76 @@ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | max_memory = {key: val * 0.9 for key, val in max_memory.items()} return max_memory + def get_state_dict_and_metadata(self, model): + """ + We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format. + """ + if not is_torchao_available() or not is_torchao_version(">=", "0.15.0"): + return model.state_dict(), {} + return flatten_tensor_state_dict(model.state_dict()) + + def set_metadata(self, checkpoint_files: list[str]): + if not is_torchao_version(">=", "0.15.0"): + self._metadata = {} + return + + if self.metadata is None: + self.metadata = {} + self._pending_flattened_state_dict = {} + self._loaded_weight_names = set() + self._expected_weight_names = set() + + if len(checkpoint_files) == 0: + return + + if not checkpoint_files[0].endswith(".safetensors"): + self._metadata = {} + return + + metadata = {} + for checkpoint in checkpoint_files: + with safe_open(checkpoint, framework="pt") as f: + metadata.update(f.metadata() or {}) + + self._metadata = metadata if is_metadata_torchao(metadata) else {} + if is_metadata_torchao(self._metadata): + try: + self._expected_weight_names = set(json.loads(self._metadata["tensor_names"])) + except (TypeError, json.JSONDecodeError, UnicodeDecodeError): + self._metadata = {} + self._expected_weight_names = set() + + @property + def metadata(self): + return self._metadata + + @metadata.setter + def metadata(self, value: dict): + self._metadata = value + + def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + if not self._metadata or not is_torchao_version(">=", "0.15.0") or not is_metadata_torchao(self._metadata): + return state_dict + + merged_state_dict = {**self._pending_flattened_state_dict, **state_dict} + reconstructed_state_dict, self._pending_flattened_state_dict = unflatten_tensor_state_dict( + merged_state_dict, self._metadata + ) + + self._loaded_weight_names.update(reconstructed_state_dict.keys()) + return reconstructed_state_dict + + def get_weight_conversions(self): + return [] + + def get_weight_names(self): + return self._expected_weight_names if self._expected_weight_names else set() + + def get_weight_reconstruction_pending_keys(self): + if not self._expected_weight_names: + return [] + return sorted(self._expected_weight_names - self._loaded_weight_names) + def check_if_quantized_param( self, model: "ModelMixin", @@ -337,11 +424,12 @@ def _process_model_before_weight_loading( def _process_model_after_weight_loading(self, model: "ModelMixin"): return model - def is_serializable(self, safe_serialization=None): - # TODO(aryan): needs to be tested - if safe_serialization: + @property + def is_serializable(self): + if not is_torchao_version(">=", "0.15.0"): logger.warning( - "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + "TorchAO quantized model is not serializable with safe serialization without safetensors support " + "from the installed torchao version." ) return False diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 8a811cfc1c73..58913a0fc29f 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -14,6 +14,7 @@ # limitations under the License. import gc +import os import tempfile import unittest from typing import List @@ -589,13 +590,32 @@ def _test_original_model_expected_slice(self, quant_type, expected_slice): self.assertTrue(isinstance(weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - def _check_serialization_expected_slice(self, quant_type, expected_slice, device): + def _check_serialization_expected_slice( + self, quant_type, expected_slice, device, safe_serialization=False, max_shard_size=None, assert_sharded=False + ): + if safe_serialization and getattr(quant_type, "version", None) != 2: + self.skipTest("TorchAO safe serialization tests require quantization config version=2.") + quantized_model = self.get_dummy_model(quant_type, device) + save_kwargs = {"safe_serialization": safe_serialization} + if max_shard_size is not None: + save_kwargs["max_shard_size"] = max_shard_size + with tempfile.TemporaryDirectory() as tmp_dir: - quantized_model.save_pretrained(tmp_dir, safe_serialization=False) + quantized_model.save_pretrained(tmp_dir, **save_kwargs) + if assert_sharded: + shard_files = [f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")] + if max_shard_size is not None: + self.assertTrue(len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint.") + self.assertTrue( + any("index" in f and f.endswith(".json") for f in os.listdir(tmp_dir)), + "Expected an index file for sharded safe checkpoint.", + ) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( - tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + tmp_dir, + torch_dtype=torch.bfloat16, + use_safetensors=safe_serialization, ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) @@ -605,6 +625,55 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) + def test_int_a8w8_safe_cpu(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = "cpu" + self._check_serialization_expected_slice(quant_type, expected_slice, device, safe_serialization=True) + + def test_int_a8w8_safe(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = torch_device + self._check_serialization_expected_slice(quant_type, expected_slice, device, safe_serialization=True) + + def test_group_offload_to_disk(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + + quantized_model = self.get_dummy_model(quant_type, torch_device) + + with tempfile.TemporaryDirectory() as offload_to_disk_path: + quantized_model.enable_group_offload( + onload_device=torch_device, + offload_type="leaf_level", + offload_to_disk_path=offload_to_disk_path, + ) + + inputs = self.get_dummy_tensor_inputs(torch_device) + output = quantized_model(**inputs)[0] + output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) + + output = quantized_model(**inputs)[0] + output_slice_2 = output.flatten()[-9:].detach().float().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(output_slice_2, expected_slice) < 1e-3) + + def test_int_a8w8_safe_sharded(self): + quant_type = Int8DynamicActivationInt8WeightConfig(version=2) + expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) + device = torch_device + self._check_serialization_expected_slice( + quant_type, + expected_slice, + device, + safe_serialization=True, + max_shard_size="16KB", + assert_sharded=True, + ) + def test_int_a8w8_accelerator(self): quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) From ed6f2f43ccc0f5659f0644eb1e83ad365c375c73 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 28 May 2026 08:56:25 +0100 Subject: [PATCH 2/9] Address TorchAO safetensors review comments --- src/diffusers/hooks/group_offloading.py | 155 ++++++++++++------ src/diffusers/models/modeling_utils.py | 39 +++-- .../quantizers/torchao/torchao_quantizer.py | 34 ++-- tests/models/testing_utils/quantization.py | 44 +++-- tests/quantization/torchao/test_torchao.py | 50 +----- 5 files changed, 189 insertions(+), 133 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 5f0a50610804..7a30177feb18 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -35,7 +35,7 @@ if is_torchao_available(): - if is_torchao_version(">=", "0.15.0"): + if is_torchao_version(">=", "0.16.0"): from torchao.prototype.safetensors.safetensors_support import ( flatten_tensor_state_dict, unflatten_tensor_state_dict, @@ -46,6 +46,10 @@ logger = get_logger(__name__) # pylint: disable=invalid-name +def _supports_torchao_safetensors() -> bool: + return is_torchao_available() and is_torchao_version(">=", "0.16.0") + + def _is_torchao_tensor(tensor: torch.Tensor) -> bool: if not is_torchao_available(): return False @@ -157,28 +161,32 @@ def __init__( self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False - all_tensors = [] - for module in self.modules: - all_tensors.extend(list(module.parameters())) - all_tensors.extend(list(module.buffers())) - all_tensors.extend(self.parameters) - all_tensors.extend(self.buffers) - all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates - - self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} + self.tensor_to_key = {} + self.key_to_tensor = {} self._torchao_disk_key_remap: dict[str, str] = {} + self._has_torchao_tensors = False if self.offload_to_disk_path is not None: # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. self.group_id = group_id if group_id is not None else str(id(self)) short_hash = _compute_group_hash(self.group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") + + all_tensors = [] + for module in self.modules: + all_tensors.extend(list(module.parameters())) + all_tensors.extend(list(module.buffers())) + all_tensors.extend(self.parameters) + all_tensors.extend(self.buffers) + all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates + + self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} + self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} + self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) self.cpu_param_dict = {} else: self.cpu_param_dict = self._init_cpu_param_dict() - self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) - self._torch_accelerator_module = ( getattr(torch, torch.accelerator.current_accelerator().type) if hasattr(torch, "accelerator") @@ -271,7 +279,15 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, default_stream) - def _get_disk_state_dict(self): + def _check_disk_offload_torchao_support(self): + if self._has_torchao_tensors and not _supports_torchao_safetensors(): + raise ValueError( + "Disk offloading TorchAO quantized tensors requires torchao >= 0.16.0 because older torchao " + "versions cannot serialize tensor subclasses with safetensors. Use memory offloading instead by " + "not setting `offload_to_disk_path`." + ) + + def _get_torchao_disk_state_dict(self): tensors_to_save = { key: ( tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device) @@ -280,33 +296,32 @@ def _get_disk_state_dict(self): } metadata = {} - if self._has_torchao_tensors and is_torchao_version(">=", "0.15.0"): - tensors_for_flatten = {} - self._torchao_disk_key_remap = {} - for key, tensor in tensors_to_save.items(): - if _is_torchao_tensor(tensor) and "." not in key: - flattened_key = f"{key}.weight" - self._torchao_disk_key_remap[key] = flattened_key - tensors_for_flatten[flattened_key] = tensor - else: - tensors_for_flatten[key] = tensor + tensors_for_flatten = {} + self._torchao_disk_key_remap = {} + for key, tensor in tensors_to_save.items(): + if _is_torchao_tensor(tensor) and "." not in key: + flattened_key = f"{key}.weight" + self._torchao_disk_key_remap[key] = flattened_key + tensors_for_flatten[flattened_key] = tensor + else: + tensors_for_flatten[key] = tensor - flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) - if isinstance(flattened_state_dict, tuple): - tensors_to_save, metadata = flattened_state_dict + flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) + if isinstance(flattened_state_dict, tuple): + tensors_to_save, metadata = flattened_state_dict + else: + tensors_to_save = flattened_state_dict return tensors_to_save, metadata - def _load_disk_state_dict(self, device): + def _load_torchao_disk_state_dict(self, device): loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) - if not self._has_torchao_tensors or not is_torchao_version(">=", "0.15.0"): - return loaded_tensors - with safe_open(self.safetensors_file_path, framework="pt") as f: metadata = f.metadata() or {} if is_metadata_torchao(metadata): + metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) or metadata try: reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata) loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} @@ -316,17 +331,6 @@ def _load_disk_state_dict(self, device): ) logger.debug(error) - subset_metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) - if subset_metadata is not None: - try: - reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict( - loaded_tensors, subset_metadata - ) - loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} - except Exception as subset_error: - logger.debug("Failed to unflatten subset of TorchAO metadata; using raw tensors for onload.") - logger.debug(subset_error) - # Support legacy in-memory tensor keys used by GroupOffloading when # flattening introduced dot-based names to satisfy TorchAO's safetensors API. for original_key, flattened_key in self._torchao_disk_key_remap.items(): @@ -335,7 +339,7 @@ def _load_disk_state_dict(self, device): return loaded_tensors - def _release_onload_tensors(self): + def _release_torchao_onload_tensors(self): for tensor_obj in self.tensor_to_key.keys(): if _is_torchao_tensor(tensor_obj): placeholder = tensor_obj.to(self.offload_device) @@ -343,7 +347,7 @@ def _release_onload_tensors(self): else: tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) - def _onload_from_disk(self): + def _onload_torchao_from_disk(self): if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -353,7 +357,7 @@ def _onload_from_disk(self): with context: device = str(self.onload_device) if self.stream is None else "cpu" - loaded_tensors = self._load_disk_state_dict(device=device) + loaded_tensors = self._load_torchao_disk_state_dict(device=device) if self.stream is not None: pinned_memory = { @@ -369,6 +373,39 @@ def _onload_from_disk(self): default_stream=None, ) + def _onload_from_disk(self): + self._check_disk_offload_torchao_support() + + if self._has_torchao_tensors: + self._onload_torchao_from_disk() + return + + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream) + current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None + + with context: + # Load to CPU (if using streams) or directly to target device, pin, and async copy to device + device = str(self.onload_device) if self.stream is None else "cpu" + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) + + if self.stream is not None: + for key, tensor_obj in self.key_to_tensor.items(): + pinned_tensor = loaded_tensors[key].pin_memory() + tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + tensor_obj.data.record_stream(current_stream) + else: + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device + ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) + for key, tensor_obj in self.key_to_tensor.items(): + tensor_obj.data = loaded_tensors[key] + def _onload_from_memory(self): if self.stream is not None: # Wait for previous Host->Device transfer to complete @@ -384,7 +421,7 @@ def _onload_from_memory(self): else: self._process_tensors_from_modules(None) - def _offload_to_disk(self): + def _offload_torchao_to_disk(self): # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not @@ -392,14 +429,38 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save, metadata = self._get_disk_state_dict() + tensors_to_save, metadata = self._get_torchao_disk_state_dict() safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True # We do this to free up the RAM which is still holding the up tensor data. - self._release_onload_tensors() + self._release_torchao_onload_tensors() + + def _offload_to_disk(self): + self._check_disk_offload_torchao_support() + + if self._has_torchao_tensors: + self._offload_torchao_to_disk() + return + + # TODO: we can potentially optimize this code path by checking if the _all_ the desired + # safetensor files exist on the disk and if so, skip this step entirely, reducing IO + # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not + # we perform a write. + # Check if the file has been saved in this session or if it already exists on disk. + if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): + os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) + tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + + # The group is now considered offloaded to disk for the rest of the session. + self._is_offloaded_to_disk = True + + # We do this to free up the RAM which is still holding the up tensor data. + for tensor_obj in self.tensor_to_key.keys(): + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) def _offload_to_memory(self): if self.stream is not None: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 30bf01da51b2..4f821fd2798c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -716,12 +716,17 @@ def save_pretrained( return hf_quantizer = getattr(self, "hf_quantizer", None) + is_torchao_quantized = ( + hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + ) if hf_quantizer is not None: quantization_serializable = ( hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable ) + if safe_serialization and is_torchao_quantized: + quantization_serializable = quantization_serializable and hf_quantizer.is_safetensors_serializable if not quantization_serializable: raise ValueError( f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" @@ -760,16 +765,8 @@ def save_pretrained( # Save the model state_dict = model_to_save.state_dict() quantization_metadata = {} - if safe_serialization and hf_quantizer is not None: - get_state_dict_and_metadata = getattr(hf_quantizer, "get_state_dict_and_metadata", None) - if callable(get_state_dict_and_metadata): - state_dict_and_metadata = get_state_dict_and_metadata(model_to_save) - else: - state_dict_and_metadata = model_to_save.state_dict() - if isinstance(state_dict_and_metadata, tuple): - state_dict, quantization_metadata = state_dict_and_metadata - else: - state_dict = state_dict_and_metadata + if safe_serialization and is_torchao_quantized: + state_dict, quantization_metadata = hf_quantizer.get_state_dict_and_metadata(state_dict) if use_flashpack: if is_flashpack_available(): @@ -814,8 +811,9 @@ def save_pretrained( shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} filepath = os.path.join(save_directory, filename) if safe_serialization: - metadata = dict(state_dict_split.metadata) - metadata.update(quantization_metadata) + metadata = {"format": "pt"} + if is_torchao_quantized: + metadata.update(quantization_metadata) metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. @@ -825,8 +823,8 @@ def save_pretrained( if state_dict_split.is_sharded: metadata = dict(state_dict_split.metadata) - metadata.update(quantization_metadata) - metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} + if is_torchao_quantized: + metadata.update(quantization_metadata) index = { "metadata": metadata, "weight_map": state_dict_split.tensor_to_filename, @@ -1384,12 +1382,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None else: loaded_keys = list(state_dict.keys()) + is_torchao_quantized = ( + hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + ) + has_torchao_safetensors_metadata = False checkpoint_files = resolved_model_file if hf_quantizer is not None: - if hasattr(hf_quantizer, "set_metadata"): + if is_torchao_quantized and hasattr(hf_quantizer, "set_metadata"): hf_quantizer.set_metadata(checkpoint_files) + has_torchao_safetensors_metadata = bool(getattr(hf_quantizer, "metadata", None)) quantized_weight_names = [] - if hasattr(hf_quantizer, "get_weight_names"): + if is_torchao_quantized and hasattr(hf_quantizer, "get_weight_names"): quantized_weight_names = hf_quantizer.get_weight_names() if quantized_weight_names: loaded_keys = list(quantized_weight_names) @@ -1402,7 +1405,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None checkpoint_files=checkpoint_files, ) - if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO: + if has_torchao_safetensors_metadata: + # TorchAO safetensors reconstruction carries incomplete tensor subclass pieces from one shard to the next. + # Loading shards concurrently would make that pending state nondeterministic. is_parallel_loading_enabled = False # Now that the model is loaded, we can determine the device_map diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 6457ed40b128..1386adaa9fdf 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -77,7 +77,7 @@ if is_torchao_available(): from torchao.quantization import quantize_ - if is_torchao_version(">=", "0.15.0"): + if is_torchao_version(">=", "0.16.0"): from torchao.prototype.safetensors.safetensors_support import ( flatten_tensor_state_dict, unflatten_tensor_state_dict, @@ -253,16 +253,21 @@ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | max_memory = {key: val * 0.9 for key, val in max_memory.items()} return max_memory - def get_state_dict_and_metadata(self, model): + def get_state_dict_and_metadata(self, state_dict): """ We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format. """ - if not is_torchao_available() or not is_torchao_version(">=", "0.15.0"): - return model.state_dict(), {} - return flatten_tensor_state_dict(model.state_dict()) + if not is_torchao_available() or not is_torchao_version(">=", "0.16.0"): + return state_dict, {} + + flattened_state_dict = flatten_tensor_state_dict(state_dict) + if isinstance(flattened_state_dict, tuple): + return flattened_state_dict + + return flattened_state_dict, {} def set_metadata(self, checkpoint_files: list[str]): - if not is_torchao_version(">=", "0.15.0"): + if not is_safetensors_available() or not is_torchao_version(">=", "0.16.0"): self._metadata = {} return @@ -275,7 +280,9 @@ def set_metadata(self, checkpoint_files: list[str]): if len(checkpoint_files) == 0: return - if not checkpoint_files[0].endswith(".safetensors"): + if not all( + isinstance(checkpoint, str) and checkpoint.endswith(".safetensors") for checkpoint in checkpoint_files + ): self._metadata = {} return @@ -301,7 +308,7 @@ def metadata(self, value: dict): self._metadata = value def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: - if not self._metadata or not is_torchao_version(">=", "0.15.0") or not is_metadata_torchao(self._metadata): + if not self._metadata or not is_torchao_version(">=", "0.16.0") or not is_metadata_torchao(self._metadata): return state_dict merged_state_dict = {**self._pending_flattened_state_dict, **state_dict} @@ -312,9 +319,6 @@ def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, self._loaded_weight_names.update(reconstructed_state_dict.keys()) return reconstructed_state_dict - def get_weight_conversions(self): - return [] - def get_weight_names(self): return self._expected_weight_names if self._expected_weight_names else set() @@ -425,14 +429,18 @@ def _process_model_after_weight_loading(self, model: "ModelMixin"): return model @property - def is_serializable(self): - if not is_torchao_version(">=", "0.15.0"): + def is_safetensors_serializable(self): + if not is_torchao_version(">=", "0.16.0"): logger.warning( "TorchAO quantized model is not serializable with safe serialization without safetensors support " "from the installed torchao version." ) return False + return True + + @property + def is_serializable(self): _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( "0.25.0" ) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index ded5cab52268..d03d0ea0794f 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -805,16 +805,16 @@ class TorchAoConfigMixin: } @staticmethod - def _get_quant_config(config_name): + def _get_quant_config(config_name, **config_kwargs): config_cls = getattr(_torchao_quantization, config_name) # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu": - return TorchAoConfig(config_cls(int4_packing_format="plain_int32")) + config_kwargs.setdefault("int4_packing_format", "plain_int32") - return TorchAoConfig(config_cls()) + return TorchAoConfig(config_cls(**config_kwargs)) - def _create_quantized_model(self, config_name, **extra_kwargs): - config = self._get_quant_config(config_name) + def _create_quantized_model(self, config_name, quant_config_kwargs=None, **extra_kwargs): + config = self._get_quant_config(config_name, **(quant_config_kwargs or {})) kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() kwargs["quantization_config"] = config kwargs["device_map"] = str(torch_device) @@ -905,20 +905,44 @@ def test_torchao_quantized_layers(self, quant_type): def test_torchao_quantization_lora_inference(self, quant_type): self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + @pytest.mark.parametrize("quant_type", ["int8wo", "int8dq"], ids=["int8wo", "int8dq"]) + @require_torchao_version_greater_or_equal("0.16.0") def test_torchao_quantization_serialization(self, quant_type, tmp_path): - """Override to use safe_serialization=False for TorchAO (safetensors not supported).""" config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] - model = self._create_quantized_model(config_kwargs) + model = self._create_quantized_model(config_kwargs, quant_config_kwargs={"version": 2}) - model.save_pretrained(str(tmp_path), safe_serialization=False) + model.save_pretrained(str(tmp_path), safe_serialization=True) - model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device)) + model_loaded = self.model_class.from_pretrained( + str(tmp_path), device_map=str(torch_device), use_safetensors=True + ) inputs = self.get_dummy_inputs() output = model_loaded(**inputs, return_dict=False)[0] assert not torch.isnan(output).any(), "Loaded model output contains NaN" + @pytest.mark.parametrize("quant_type", ["int8dq"], ids=["int8dq"]) + @require_torchao_version_greater_or_equal("0.16.0") + def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path): + config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] + model = self._create_quantized_model(config_kwargs, quant_config_kwargs={"version": 2}) + + model.save_pretrained(str(tmp_path), safe_serialization=True, max_shard_size="16KB") + + shard_files = list(tmp_path.glob("*.safetensors")) + assert len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint." + assert any(path.name.endswith(".index.json") for path in tmp_path.iterdir()), ( + "Expected an index file for sharded safe checkpoint." + ) + + model_loaded = self.model_class.from_pretrained( + str(tmp_path), device_map=str(torch_device), use_safetensors=True + ) + + inputs = self.get_dummy_inputs() + output = model_loaded(**inputs, return_dict=False)[0] + assert not torch.isnan(output).any(), "Loaded sharded model output contains NaN" + def test_torchao_modules_to_not_convert(self): """Test that modules_to_not_convert parameter works correctly.""" modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 58913a0fc29f..bb6580ce3ae4 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -14,7 +14,6 @@ # limitations under the License. import gc -import os import tempfile import unittest from typing import List @@ -590,32 +589,15 @@ def _test_original_model_expected_slice(self, quant_type, expected_slice): self.assertTrue(isinstance(weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - def _check_serialization_expected_slice( - self, quant_type, expected_slice, device, safe_serialization=False, max_shard_size=None, assert_sharded=False - ): - if safe_serialization and getattr(quant_type, "version", None) != 2: - self.skipTest("TorchAO safe serialization tests require quantization config version=2.") - + def _check_serialization_expected_slice(self, quant_type, expected_slice, device): quantized_model = self.get_dummy_model(quant_type, device) - save_kwargs = {"safe_serialization": safe_serialization} - if max_shard_size is not None: - save_kwargs["max_shard_size"] = max_shard_size - with tempfile.TemporaryDirectory() as tmp_dir: - quantized_model.save_pretrained(tmp_dir, **save_kwargs) - if assert_sharded: - shard_files = [f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")] - if max_shard_size is not None: - self.assertTrue(len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint.") - self.assertTrue( - any("index" in f and f.endswith(".json") for f in os.listdir(tmp_dir)), - "Expected an index file for sharded safe checkpoint.", - ) + quantized_model.save_pretrained(tmp_dir, safe_serialization=False) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( tmp_dir, torch_dtype=torch.bfloat16, - use_safetensors=safe_serialization, + use_safetensors=False, ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) @@ -625,18 +607,7 @@ def _check_serialization_expected_slice( self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - def test_int_a8w8_safe_cpu(self): - quant_type = Int8DynamicActivationInt8WeightConfig(version=2) - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - device = "cpu" - self._check_serialization_expected_slice(quant_type, expected_slice, device, safe_serialization=True) - - def test_int_a8w8_safe(self): - quant_type = Int8DynamicActivationInt8WeightConfig(version=2) - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - device = torch_device - self._check_serialization_expected_slice(quant_type, expected_slice, device, safe_serialization=True) - + @require_torchao_version_greater_or_equal("0.16.0") def test_group_offload_to_disk(self): quant_type = Int8DynamicActivationInt8WeightConfig(version=2) expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) @@ -661,19 +632,6 @@ def test_group_offload_to_disk(self): self.assertTrue(numpy_cosine_similarity_distance(output_slice_2, expected_slice) < 1e-3) - def test_int_a8w8_safe_sharded(self): - quant_type = Int8DynamicActivationInt8WeightConfig(version=2) - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - device = torch_device - self._check_serialization_expected_slice( - quant_type, - expected_slice, - device, - safe_serialization=True, - max_shard_size="16KB", - assert_sharded=True, - ) - def test_int_a8w8_accelerator(self): quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) From a1220c31faa0831cb83d882ab85979b8db7be9fb Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 28 May 2026 13:30:47 +0100 Subject: [PATCH 3/9] Address TorchAO safetensors review feedback --- src/diffusers/hooks/group_offloading.py | 22 ++++++------- .../quantizers/torchao/torchao_quantizer.py | 33 ++++--------------- tests/models/testing_utils/quantization.py | 28 ++++++++++++---- tests/quantization/torchao/test_torchao.py | 16 +++++++++ 4 files changed, 55 insertions(+), 44 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 7a30177feb18..dbcb135e80c2 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -297,14 +297,9 @@ def _get_torchao_disk_state_dict(self): metadata = {} tensors_for_flatten = {} - self._torchao_disk_key_remap = {} + self._torchao_disk_key_remap = self._get_torchao_disk_key_remap() for key, tensor in tensors_to_save.items(): - if _is_torchao_tensor(tensor) and "." not in key: - flattened_key = f"{key}.weight" - self._torchao_disk_key_remap[key] = flattened_key - tensors_for_flatten[flattened_key] = tensor - else: - tensors_for_flatten[key] = tensor + tensors_for_flatten[self._torchao_disk_key_remap.get(key, key)] = tensor flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) if isinstance(flattened_state_dict, tuple): @@ -314,6 +309,13 @@ def _get_torchao_disk_state_dict(self): return tensors_to_save, metadata + def _get_torchao_disk_key_remap(self): + return { + key: f"{key}.weight" + for tensor, key in self.tensor_to_key.items() + if _is_torchao_tensor(tensor) and "." not in key + } + def _load_torchao_disk_state_dict(self, device): loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) @@ -326,13 +328,11 @@ def _load_torchao_disk_state_dict(self, device): reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata) loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} except Exception as error: - logger.warning( - "Failed to unflatten TorchAO state dict metadata from disk; falling back to raw tensors." - ) - logger.debug(error) + raise RuntimeError("Failed to reconstruct TorchAO tensors from disk offload safetensors.") from error # Support legacy in-memory tensor keys used by GroupOffloading when # flattening introduced dot-based names to satisfy TorchAO's safetensors API. + self._torchao_disk_key_remap = self._get_torchao_disk_key_remap() for original_key, flattened_key in self._torchao_disk_key_remap.items(): if original_key not in loaded_tensors and flattened_key in loaded_tensors: loaded_tensors[original_key] = loaded_tensors.pop(flattened_key) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 1386adaa9fdf..cfa1f344c3f3 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -168,8 +168,6 @@ def __init__(self, quantization_config, **kwargs): self._metadata = {} self._pending_flattened_state_dict = {} - self._loaded_weight_names = set() - self._expected_weight_names = set() def validate_environment(self, *args, **kwargs): if not is_torchao_available(): @@ -267,23 +265,18 @@ def get_state_dict_and_metadata(self, state_dict): return flattened_state_dict, {} def set_metadata(self, checkpoint_files: list[str]): + self._metadata = {} + self._pending_flattened_state_dict = {} + if not is_safetensors_available() or not is_torchao_version(">=", "0.16.0"): - self._metadata = {} return - if self.metadata is None: - self.metadata = {} - self._pending_flattened_state_dict = {} - self._loaded_weight_names = set() - self._expected_weight_names = set() - if len(checkpoint_files) == 0: return if not all( isinstance(checkpoint, str) and checkpoint.endswith(".safetensors") for checkpoint in checkpoint_files ): - self._metadata = {} return metadata = {} @@ -292,21 +285,11 @@ def set_metadata(self, checkpoint_files: list[str]): metadata.update(f.metadata() or {}) self._metadata = metadata if is_metadata_torchao(metadata) else {} - if is_metadata_torchao(self._metadata): - try: - self._expected_weight_names = set(json.loads(self._metadata["tensor_names"])) - except (TypeError, json.JSONDecodeError, UnicodeDecodeError): - self._metadata = {} - self._expected_weight_names = set() @property def metadata(self): return self._metadata - @metadata.setter - def metadata(self, value: dict): - self._metadata = value - def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: if not self._metadata or not is_torchao_version(">=", "0.16.0") or not is_metadata_torchao(self._metadata): return state_dict @@ -316,16 +299,12 @@ def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, merged_state_dict, self._metadata ) - self._loaded_weight_names.update(reconstructed_state_dict.keys()) return reconstructed_state_dict def get_weight_names(self): - return self._expected_weight_names if self._expected_weight_names else set() - - def get_weight_reconstruction_pending_keys(self): - if not self._expected_weight_names: - return [] - return sorted(self._expected_weight_names - self._loaded_weight_names) + if not self._metadata: + return set() + return set(json.loads(self._metadata["tensor_names"])) def check_if_quantized_param( self, diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index d03d0ea0794f..512bfc8d7fbb 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -910,24 +910,39 @@ def test_torchao_quantization_lora_inference(self, quant_type): def test_torchao_quantization_serialization(self, quant_type, tmp_path): config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] model = self._create_quantized_model(config_kwargs, quant_config_kwargs={"version": 2}) + inputs = self.get_dummy_inputs() + + with torch.no_grad(): + expected_output = model(**inputs, return_dict=False)[0].detach().cpu() model.save_pretrained(str(tmp_path), safe_serialization=True) + del model + gc.collect() + backend_empty_cache(torch_device) model_loaded = self.model_class.from_pretrained( str(tmp_path), device_map=str(torch_device), use_safetensors=True ) - inputs = self.get_dummy_inputs() - output = model_loaded(**inputs, return_dict=False)[0] - assert not torch.isnan(output).any(), "Loaded model output contains NaN" + with torch.no_grad(): + output = model_loaded(**inputs, return_dict=False)[0].detach().cpu() + + torch.testing.assert_close(output, expected_output, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("quant_type", ["int8dq"], ids=["int8dq"]) @require_torchao_version_greater_or_equal("0.16.0") def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path): config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] model = self._create_quantized_model(config_kwargs, quant_config_kwargs={"version": 2}) + inputs = self.get_dummy_inputs() + + with torch.no_grad(): + expected_output = model(**inputs, return_dict=False)[0].detach().cpu() model.save_pretrained(str(tmp_path), safe_serialization=True, max_shard_size="16KB") + del model + gc.collect() + backend_empty_cache(torch_device) shard_files = list(tmp_path.glob("*.safetensors")) assert len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint." @@ -939,9 +954,10 @@ def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path): str(tmp_path), device_map=str(torch_device), use_safetensors=True ) - inputs = self.get_dummy_inputs() - output = model_loaded(**inputs, return_dict=False)[0] - assert not torch.isnan(output).any(), "Loaded sharded model output contains NaN" + with torch.no_grad(): + output = model_loaded(**inputs, return_dict=False)[0].detach().cpu() + + torch.testing.assert_close(output, expected_output, rtol=1e-3, atol=1e-3) def test_torchao_modules_to_not_convert(self): """Test that modules_to_not_convert parameter works correctly.""" diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index bb6580ce3ae4..250df2d2170b 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -632,6 +632,22 @@ def test_group_offload_to_disk(self): self.assertTrue(numpy_cosine_similarity_distance(output_slice_2, expected_slice) < 1e-3) + del quantized_model + gc.collect() + backend_empty_cache(torch_device) + + quantized_model = self.get_dummy_model(quant_type, torch_device) + quantized_model.enable_group_offload( + onload_device=torch_device, + offload_type="leaf_level", + offload_to_disk_path=offload_to_disk_path, + ) + + output = quantized_model(**inputs)[0] + output_slice_3 = output.flatten()[-9:].detach().float().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(output_slice_3, expected_slice) < 1e-3) + def test_int_a8w8_accelerator(self): quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) From 768f243b8a30cdeb3675f9971bcd6b539ec21275 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 29 May 2026 09:09:44 +0100 Subject: [PATCH 4/9] Address follow-up TorchAO review feedback --- src/diffusers/hooks/group_offloading.py | 93 +++++++--------------- src/diffusers/models/modeling_utils.py | 16 ++-- tests/quantization/torchao/test_torchao.py | 4 +- 3 files changed, 37 insertions(+), 76 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index dbcb135e80c2..ee7fdee24494 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -161,11 +161,6 @@ def __init__( self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False - self.tensor_to_key = {} - self.key_to_tensor = {} - self._torchao_disk_key_remap: dict[str, str] = {} - self._has_torchao_tensors = False - if self.offload_to_disk_path is not None: # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. self.group_id = group_id if group_id is not None else str(id(self)) @@ -182,6 +177,7 @@ def __init__( self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} + self._torchao_disk_key_remap: dict[str, str] = {} self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) self.cpu_param_dict = {} else: @@ -209,6 +205,7 @@ def _get_torchao_subset_metadata_for_unflatten(metadata): try: tensor_names = json.loads(tensor_names) except (TypeError, json.JSONDecodeError): + logger.warning("Could not parse TorchAO safetensors metadata for disk offloading; using full metadata.") return None dotted_tensor_names = [name for name in tensor_names if "." in name] @@ -295,6 +292,8 @@ def _get_torchao_disk_state_dict(self): for tensor, key in self.tensor_to_key.items() } + # TorchAO safetensors support expects logical parameter names and stores + # tensor subclass internals plus reconstruction metadata separately. metadata = {} tensors_for_flatten = {} self._torchao_disk_key_remap = self._get_torchao_disk_key_remap() @@ -347,7 +346,9 @@ def _release_torchao_onload_tensors(self): else: tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) - def _onload_torchao_from_disk(self): + def _onload_from_disk(self): + self._check_disk_offload_torchao_support() + if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -357,14 +358,19 @@ def _onload_torchao_from_disk(self): with context: device = str(self.onload_device) if self.stream is None else "cpu" - loaded_tensors = self._load_torchao_disk_state_dict(device=device) + loaded_tensors = ( + self._load_torchao_disk_state_dict(device=device) + if self._has_torchao_tensors + else safetensors.torch.load_file(self.safetensors_file_path, device=device) + ) if self.stream is not None: pinned_memory = { tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory() for tensor_obj in self.tensor_to_key } - self._process_tensors_from_modules(pinned_memory, default_stream=current_stream) + for tensor_obj, pinned_tensor in pinned_memory.items(): + self._transfer_tensor_to_device(tensor_obj, pinned_tensor, current_stream) else: for tensor_obj in self.tensor_to_key: self._transfer_tensor_to_device( @@ -373,39 +379,6 @@ def _onload_torchao_from_disk(self): default_stream=None, ) - def _onload_from_disk(self): - self._check_disk_offload_torchao_support() - - if self._has_torchao_tensors: - self._onload_torchao_from_disk() - return - - if self.stream is not None: - # Wait for previous Host->Device transfer to complete - self.stream.synchronize() - - context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream) - current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None - - with context: - # Load to CPU (if using streams) or directly to target device, pin, and async copy to device - device = str(self.onload_device) if self.stream is None else "cpu" - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) - - if self.stream is not None: - for key, tensor_obj in self.key_to_tensor.items(): - pinned_tensor = loaded_tensors[key].pin_memory() - tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - tensor_obj.data.record_stream(current_stream) - else: - onload_device = ( - self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device - ) - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) - for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] - def _onload_from_memory(self): if self.stream is not None: # Wait for previous Host->Device transfer to complete @@ -421,30 +394,9 @@ def _onload_from_memory(self): else: self._process_tensors_from_modules(None) - def _offload_torchao_to_disk(self): - # TODO: we can potentially optimize this code path by checking if the _all_ the desired - # safetensor files exist on the disk and if so, skip this step entirely, reducing IO - # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not - # we perform a write. - # Check if the file has been saved in this session or if it already exists on disk. - if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): - os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save, metadata = self._get_torchao_disk_state_dict() - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) - - # The group is now considered offloaded to disk for the rest of the session. - self._is_offloaded_to_disk = True - - # We do this to free up the RAM which is still holding the up tensor data. - self._release_torchao_onload_tensors() - def _offload_to_disk(self): self._check_disk_offload_torchao_support() - if self._has_torchao_tensors: - self._offload_torchao_to_disk() - return - # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not @@ -452,15 +404,24 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + if self._has_torchao_tensors: + tensors_to_save, metadata = self._get_torchao_disk_state_dict() + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) + else: + tensors_to_save = { + key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() + } + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True # We do this to free up the RAM which is still holding the up tensor data. - for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + if self._has_torchao_tensors: + self._release_torchao_onload_tensors() + else: + for tensor_obj in self.tensor_to_key.keys(): + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) def _offload_to_memory(self): if self.stream is not None: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4f821fd2798c..00369338e5b5 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -725,7 +725,7 @@ def save_pretrained( and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable ) - if safe_serialization and is_torchao_quantized: + if safe_serialization and hasattr(hf_quantizer, "is_safetensors_serializable"): quantization_serializable = quantization_serializable and hf_quantizer.is_safetensors_serializable if not quantization_serializable: raise ValueError( @@ -1398,12 +1398,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None loaded_keys = list(quantized_weight_names) if hf_quantizer is not None: - hf_quantizer.preprocess_model( - model=model, - device_map=device_map, - keep_in_fp32_modules=keep_in_fp32_modules, - checkpoint_files=checkpoint_files, - ) + preprocess_kwargs = { + "model": model, + "device_map": device_map, + "keep_in_fp32_modules": keep_in_fp32_modules, + } + if is_torchao_quantized: + preprocess_kwargs["checkpoint_files"] = checkpoint_files + hf_quantizer.preprocess_model(**preprocess_kwargs) if has_torchao_safetensors_metadata: # TorchAO safetensors reconstruction carries incomplete tensor subclass pieces from one shard to the next. diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 250df2d2170b..c0c1a8fd4fc9 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -595,9 +595,7 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( - tmp_dir, - torch_dtype=torch.bfloat16, - use_safetensors=False, + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) From 3cfc59239cbb58ae72a3368a0a1dcabc895b58a8 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 29 May 2026 10:08:11 +0100 Subject: [PATCH 5/9] Simplify TorchAO metadata loading guard --- src/diffusers/models/modeling_utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 00369338e5b5..a3790316f73c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1387,15 +1387,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None ) has_torchao_safetensors_metadata = False checkpoint_files = resolved_model_file - if hf_quantizer is not None: - if is_torchao_quantized and hasattr(hf_quantizer, "set_metadata"): - hf_quantizer.set_metadata(checkpoint_files) - has_torchao_safetensors_metadata = bool(getattr(hf_quantizer, "metadata", None)) - quantized_weight_names = [] - if is_torchao_quantized and hasattr(hf_quantizer, "get_weight_names"): - quantized_weight_names = hf_quantizer.get_weight_names() - if quantized_weight_names: - loaded_keys = list(quantized_weight_names) + if is_torchao_quantized: + hf_quantizer.set_metadata(checkpoint_files) + has_torchao_safetensors_metadata = bool(hf_quantizer.metadata) + if has_torchao_safetensors_metadata: + loaded_keys = list(hf_quantizer.get_weight_names()) if hf_quantizer is not None: preprocess_kwargs = { From ba21459f79650e67d93c36afe6bba6951af949ce Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Jun 2026 07:43:10 +0100 Subject: [PATCH 6/9] Address TorchAO test utility review feedback --- src/diffusers/models/modeling_utils.py | 13 +++++-------- tests/models/testing_utils/quantization.py | 13 +++++++------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a3790316f73c..2a384468e595 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1394,14 +1394,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None loaded_keys = list(hf_quantizer.get_weight_names()) if hf_quantizer is not None: - preprocess_kwargs = { - "model": model, - "device_map": device_map, - "keep_in_fp32_modules": keep_in_fp32_modules, - } - if is_torchao_quantized: - preprocess_kwargs["checkpoint_files"] = checkpoint_files - hf_quantizer.preprocess_model(**preprocess_kwargs) + hf_quantizer.preprocess_model( + model=model, + device_map=device_map, + keep_in_fp32_modules=keep_in_fp32_modules, + ) if has_torchao_safetensors_metadata: # TorchAO safetensors reconstruction carries incomplete tensor subclass pieces from one shard to the next. diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 512bfc8d7fbb..1c45ce7dafb5 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -805,16 +805,17 @@ class TorchAoConfigMixin: } @staticmethod - def _get_quant_config(config_name, **config_kwargs): + def _get_quant_config(config_name): config_cls = getattr(_torchao_quantization, config_name) + config_kwargs = {"version": 2} # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu": config_kwargs.setdefault("int4_packing_format", "plain_int32") return TorchAoConfig(config_cls(**config_kwargs)) - def _create_quantized_model(self, config_name, quant_config_kwargs=None, **extra_kwargs): - config = self._get_quant_config(config_name, **(quant_config_kwargs or {})) + def _create_quantized_model(self, config_name, **extra_kwargs): + config = self._get_quant_config(config_name) kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() kwargs["quantization_config"] = config kwargs["device_map"] = str(torch_device) @@ -905,11 +906,11 @@ def test_torchao_quantized_layers(self, quant_type): def test_torchao_quantization_lora_inference(self, quant_type): self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", ["int8wo", "int8dq"], ids=["int8wo", "int8dq"]) + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) @require_torchao_version_greater_or_equal("0.16.0") def test_torchao_quantization_serialization(self, quant_type, tmp_path): config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] - model = self._create_quantized_model(config_kwargs, quant_config_kwargs={"version": 2}) + model = self._create_quantized_model(config_kwargs) inputs = self.get_dummy_inputs() with torch.no_grad(): @@ -933,7 +934,7 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path): @require_torchao_version_greater_or_equal("0.16.0") def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path): config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] - model = self._create_quantized_model(config_kwargs, quant_config_kwargs={"version": 2}) + model = self._create_quantized_model(config_kwargs) inputs = self.get_dummy_inputs() with torch.no_grad(): From e3ecbac482b8bedff4aa1651d70c700603ea06e0 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Jun 2026 17:38:23 +0100 Subject: [PATCH 7/9] Scope TorchAO safetensors to quantizer hooks --- src/diffusers/hooks/group_offloading.py | 168 +++--------------- src/diffusers/models/model_loading_utils.py | 4 +- src/diffusers/models/modeling_utils.py | 32 ++-- src/diffusers/quantizers/base.py | 19 ++ .../quantizers/torchao/torchao_quantizer.py | 19 +- tests/quantization/torchao/test_torchao.py | 41 ----- 6 files changed, 75 insertions(+), 208 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index ee7fdee24494..f3d1f3389bb7 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,7 +13,6 @@ # limitations under the License. import hashlib -import json import os from contextlib import contextmanager, nullcontext from dataclasses import dataclass, replace @@ -22,9 +21,8 @@ import safetensors.torch import torch -from safetensors import safe_open -from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version +from ..utils import get_logger, is_accelerate_available, is_torchao_available from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from .hooks import HookRegistry, ModelHook @@ -34,22 +32,9 @@ from accelerate.utils import send_to_device -if is_torchao_available(): - if is_torchao_version(">=", "0.16.0"): - from torchao.prototype.safetensors.safetensors_support import ( - flatten_tensor_state_dict, - unflatten_tensor_state_dict, - ) - from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao - - logger = get_logger(__name__) # pylint: disable=invalid-name -def _supports_torchao_safetensors() -> bool: - return is_torchao_available() and is_torchao_version(">=", "0.16.0") - - def _is_torchao_tensor(tensor: torch.Tensor) -> bool: if not is_torchao_available(): return False @@ -177,8 +162,6 @@ def __init__( self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} - self._torchao_disk_key_remap: dict[str, str] = {} - self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key) self.cpu_param_dict = {} else: self.cpu_param_dict = self._init_cpu_param_dict() @@ -196,27 +179,6 @@ def _to_cpu(tensor, low_cpu_mem_usage): t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu() return t if low_cpu_mem_usage else t.pin_memory() - @staticmethod - def _get_torchao_subset_metadata_for_unflatten(metadata): - tensor_names = metadata.get("tensor_names") - if tensor_names is None: - return None - - try: - tensor_names = json.loads(tensor_names) - except (TypeError, json.JSONDecodeError): - logger.warning("Could not parse TorchAO safetensors metadata for disk offloading; using full metadata.") - return None - - dotted_tensor_names = [name for name in tensor_names if "." in name] - if len(dotted_tensor_names) == 0: - return None - - return { - "tensor_names": json.dumps(dotted_tensor_names), - **{name: metadata[name] for name in dotted_tensor_names if name in metadata}, - } - def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -276,78 +238,18 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, default_stream) - def _check_disk_offload_torchao_support(self): - if self._has_torchao_tensors and not _supports_torchao_safetensors(): + def _check_disk_offload_torchao(self): + all_tensors = list(self.tensor_to_key.keys()) + has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) + if has_torchao: raise ValueError( - "Disk offloading TorchAO quantized tensors requires torchao >= 0.16.0 because older torchao " - "versions cannot serialize tensor subclasses with safetensors. Use memory offloading instead by " - "not setting `offload_to_disk_path`." + "Disk offloading is not supported for TorchAO quantized tensors because safetensors " + "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " + "setting `offload_to_disk_path`." ) - def _get_torchao_disk_state_dict(self): - tensors_to_save = { - key: ( - tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device) - ) - for tensor, key in self.tensor_to_key.items() - } - - # TorchAO safetensors support expects logical parameter names and stores - # tensor subclass internals plus reconstruction metadata separately. - metadata = {} - tensors_for_flatten = {} - self._torchao_disk_key_remap = self._get_torchao_disk_key_remap() - for key, tensor in tensors_to_save.items(): - tensors_for_flatten[self._torchao_disk_key_remap.get(key, key)] = tensor - - flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten) - if isinstance(flattened_state_dict, tuple): - tensors_to_save, metadata = flattened_state_dict - else: - tensors_to_save = flattened_state_dict - - return tensors_to_save, metadata - - def _get_torchao_disk_key_remap(self): - return { - key: f"{key}.weight" - for tensor, key in self.tensor_to_key.items() - if _is_torchao_tensor(tensor) and "." not in key - } - - def _load_torchao_disk_state_dict(self, device): - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device) - - with safe_open(self.safetensors_file_path, framework="pt") as f: - metadata = f.metadata() or {} - - if is_metadata_torchao(metadata): - metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) or metadata - try: - reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata) - loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict} - except Exception as error: - raise RuntimeError("Failed to reconstruct TorchAO tensors from disk offload safetensors.") from error - - # Support legacy in-memory tensor keys used by GroupOffloading when - # flattening introduced dot-based names to satisfy TorchAO's safetensors API. - self._torchao_disk_key_remap = self._get_torchao_disk_key_remap() - for original_key, flattened_key in self._torchao_disk_key_remap.items(): - if original_key not in loaded_tensors and flattened_key in loaded_tensors: - loaded_tensors[original_key] = loaded_tensors.pop(flattened_key) - - return loaded_tensors - - def _release_torchao_onload_tensors(self): - for tensor_obj in self.tensor_to_key.keys(): - if _is_torchao_tensor(tensor_obj): - placeholder = tensor_obj.to(self.offload_device) - _swap_torchao_tensor(tensor_obj, placeholder) - else: - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) - def _onload_from_disk(self): - self._check_disk_offload_torchao_support() + self._check_disk_offload_torchao() if self.stream is not None: # Wait for previous Host->Device transfer to complete @@ -357,27 +259,22 @@ def _onload_from_disk(self): current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None with context: - device = str(self.onload_device) if self.stream is None else "cpu" - loaded_tensors = ( - self._load_torchao_disk_state_dict(device=device) - if self._has_torchao_tensors - else safetensors.torch.load_file(self.safetensors_file_path, device=device) - ) - if self.stream is not None: - pinned_memory = { - tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory() - for tensor_obj in self.tensor_to_key - } - for tensor_obj, pinned_tensor in pinned_memory.items(): - self._transfer_tensor_to_device(tensor_obj, pinned_tensor, current_stream) + # Load to CPU first, pin memory, then async copy to the target device + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") + for key, tensor_obj in self.key_to_tensor.items(): + pinned_tensor = loaded_tensors[key].pin_memory() + tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + tensor_obj.data.record_stream(current_stream) else: - for tensor_obj in self.tensor_to_key: - self._transfer_tensor_to_device( - tensor_obj, - loaded_tensors[self.tensor_to_key[tensor_obj]], - default_stream=None, - ) + # Load directly to the target device + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device + ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) + for key, tensor_obj in self.key_to_tensor.items(): + tensor_obj.data = loaded_tensors[key] def _onload_from_memory(self): if self.stream is not None: @@ -395,7 +292,7 @@ def _onload_from_memory(self): self._process_tensors_from_modules(None) def _offload_to_disk(self): - self._check_disk_offload_torchao_support() + self._check_disk_offload_torchao() # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO @@ -404,24 +301,15 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - if self._has_torchao_tensors: - tensors_to_save, metadata = self._get_torchao_disk_state_dict() - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) - else: - tensors_to_save = { - key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() - } - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True # We do this to free up the RAM which is still holding the up tensor data. - if self._has_torchao_tensors: - self._release_torchao_onload_tensors() - else: - for tensor_obj in self.tensor_to_key.keys(): - tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + for tensor_obj in self.tensor_to_key.keys(): + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) def _offload_to_memory(self): if self.stream is not None: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index beeee1b498b0..abbde8082bb5 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -357,8 +357,8 @@ def _load_shard_file( disable_mmap=False, ): state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) - if hf_quantizer is not None and hasattr(hf_quantizer, "get_reconstructed_state_dict"): - state_dict = hf_quantizer.get_reconstructed_state_dict(state_dict) + if hf_quantizer is not None: + state_dict = hf_quantizer.maybe_update_state_dict(state_dict) mismatched_keys = _find_mismatched_keys( state_dict, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 2a384468e595..3a22b47a17a3 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -716,17 +716,14 @@ def save_pretrained( return hf_quantizer = getattr(self, "hf_quantizer", None) - is_torchao_quantized = ( - hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - ) if hf_quantizer is not None: quantization_serializable = ( hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable ) - if safe_serialization and hasattr(hf_quantizer, "is_safetensors_serializable"): - quantization_serializable = quantization_serializable and hf_quantizer.is_safetensors_serializable + if safe_serialization and quantization_serializable: + quantization_serializable = quantization_serializable and hf_quantizer.supports_safetensors_serialization if not quantization_serializable: raise ValueError( f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" @@ -765,8 +762,10 @@ def save_pretrained( # Save the model state_dict = model_to_save.state_dict() quantization_metadata = {} - if safe_serialization and is_torchao_quantized: - state_dict, quantization_metadata = hf_quantizer.get_state_dict_and_metadata(state_dict) + if hf_quantizer is not None: + state_dict, quantization_metadata = hf_quantizer.get_state_dict_and_metadata( + state_dict, safe_serialization=safe_serialization + ) if use_flashpack: if is_flashpack_available(): @@ -812,7 +811,7 @@ def save_pretrained( filepath = os.path.join(save_directory, filename) if safe_serialization: metadata = {"format": "pt"} - if is_torchao_quantized: + if quantization_metadata: metadata.update(quantization_metadata) metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} # At some point we will need to deal better with save_function (used for TPU and other distributed @@ -823,7 +822,7 @@ def save_pretrained( if state_dict_split.is_sharded: metadata = dict(state_dict_split.metadata) - if is_torchao_quantized: + if quantization_metadata: metadata.update(quantization_metadata) index = { "metadata": metadata, @@ -1382,16 +1381,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None else: loaded_keys = list(state_dict.keys()) - is_torchao_quantized = ( - hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - ) - has_torchao_safetensors_metadata = False checkpoint_files = resolved_model_file - if is_torchao_quantized: - hf_quantizer.set_metadata(checkpoint_files) - has_torchao_safetensors_metadata = bool(hf_quantizer.metadata) - if has_torchao_safetensors_metadata: - loaded_keys = list(hf_quantizer.get_weight_names()) + if hf_quantizer is not None: + loaded_keys = hf_quantizer.maybe_update_loaded_keys(loaded_keys, checkpoint_files) if hf_quantizer is not None: hf_quantizer.preprocess_model( @@ -1400,9 +1392,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None keep_in_fp32_modules=keep_in_fp32_modules, ) - if has_torchao_safetensors_metadata: - # TorchAO safetensors reconstruction carries incomplete tensor subclass pieces from one shard to the next. - # Loading shards concurrently would make that pending state nondeterministic. + if hf_quantizer is not None and not hf_quantizer.supports_parallel_loading: is_parallel_loading_enabled = False # Now that the model is loaded, we can determine the device_map diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 5dc20fa2f7e7..29dd2ba7f9ab 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -168,6 +168,25 @@ def validate_environment(self, *args, **kwargs): """ return + def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]: + return loaded_keys + + def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + return state_dict + + @property + def supports_parallel_loading(self) -> bool: + return True + + def get_state_dict_and_metadata( + self, state_dict: dict[str, Any], safe_serialization: bool = False + ) -> tuple[dict[str, Any], dict[str, Any]]: + return state_dict, {} + + @property + def supports_safetensors_serialization(self) -> bool: + return True + def preprocess_model(self, model: "ModelMixin", **kwargs): """ Setting model attributes and/or converting model before weights loading. At this point the model should be diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index cfa1f344c3f3..b33a18cd142c 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -251,11 +251,11 @@ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | max_memory = {key: val * 0.9 for key, val in max_memory.items()} return max_memory - def get_state_dict_and_metadata(self, state_dict): + def get_state_dict_and_metadata(self, state_dict: dict[str, Any], safe_serialization: bool = False): """ We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format. """ - if not is_torchao_available() or not is_torchao_version(">=", "0.16.0"): + if not safe_serialization or not is_torchao_available() or not is_torchao_version(">=", "0.16.0"): return state_dict, {} flattened_state_dict = flatten_tensor_state_dict(state_dict) @@ -264,6 +264,12 @@ def get_state_dict_and_metadata(self, state_dict): return flattened_state_dict, {} + def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]: + self.set_metadata(checkpoint_files) + if self._metadata: + return list(self.get_weight_names()) + return loaded_keys + def set_metadata(self, checkpoint_files: list[str]): self._metadata = {} self._pending_flattened_state_dict = {} @@ -290,7 +296,7 @@ def set_metadata(self, checkpoint_files: list[str]): def metadata(self): return self._metadata - def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: if not self._metadata or not is_torchao_version(">=", "0.16.0") or not is_metadata_torchao(self._metadata): return state_dict @@ -301,6 +307,11 @@ def get_reconstructed_state_dict(self, state_dict: dict[str, Any]) -> dict[str, return reconstructed_state_dict + @property + def supports_parallel_loading(self) -> bool: + # Safetensors reconstruction can carry leftover flattened tensor pieces from one shard to the next. + return not self._metadata + def get_weight_names(self): if not self._metadata: return set() @@ -408,7 +419,7 @@ def _process_model_after_weight_loading(self, model: "ModelMixin"): return model @property - def is_safetensors_serializable(self): + def supports_safetensors_serialization(self): if not is_torchao_version(">=", "0.16.0"): logger.warning( "TorchAO quantized model is not serializable with safe serialization without safetensors support " diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index c0c1a8fd4fc9..8a811cfc1c73 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -605,47 +605,6 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor)) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - @require_torchao_version_greater_or_equal("0.16.0") - def test_group_offload_to_disk(self): - quant_type = Int8DynamicActivationInt8WeightConfig(version=2) - expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - - quantized_model = self.get_dummy_model(quant_type, torch_device) - - with tempfile.TemporaryDirectory() as offload_to_disk_path: - quantized_model.enable_group_offload( - onload_device=torch_device, - offload_type="leaf_level", - offload_to_disk_path=offload_to_disk_path, - ) - - inputs = self.get_dummy_tensor_inputs(torch_device) - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - - self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - - output = quantized_model(**inputs)[0] - output_slice_2 = output.flatten()[-9:].detach().float().cpu().numpy() - - self.assertTrue(numpy_cosine_similarity_distance(output_slice_2, expected_slice) < 1e-3) - - del quantized_model - gc.collect() - backend_empty_cache(torch_device) - - quantized_model = self.get_dummy_model(quant_type, torch_device) - quantized_model.enable_group_offload( - onload_device=torch_device, - offload_type="leaf_level", - offload_to_disk_path=offload_to_disk_path, - ) - - output = quantized_model(**inputs)[0] - output_slice_3 = output.flatten()[-9:].detach().float().cpu().numpy() - - self.assertTrue(numpy_cosine_similarity_distance(output_slice_3, expected_slice) < 1e-3) - def test_int_a8w8_accelerator(self): quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) From 79ea6bbe456158514a904dcbb02db1af3d2b3c23 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 8 Jun 2026 16:04:00 +0100 Subject: [PATCH 8/9] Use tensor assertion helper in TorchAO tests --- tests/models/testing_utils/quantization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 1c45ce7dafb5..9a6c1ac6a790 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -28,6 +28,7 @@ ) from ...testing_utils import ( + assert_tensors_close, backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, @@ -928,7 +929,7 @@ def test_torchao_quantization_serialization(self, quant_type, tmp_path): with torch.no_grad(): output = model_loaded(**inputs, return_dict=False)[0].detach().cpu() - torch.testing.assert_close(output, expected_output, rtol=1e-3, atol=1e-3) + assert_tensors_close(output, expected_output, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("quant_type", ["int8dq"], ids=["int8dq"]) @require_torchao_version_greater_or_equal("0.16.0") @@ -958,7 +959,7 @@ def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path): with torch.no_grad(): output = model_loaded(**inputs, return_dict=False)[0].detach().cpu() - torch.testing.assert_close(output, expected_output, rtol=1e-3, atol=1e-3) + assert_tensors_close(output, expected_output, rtol=1e-3, atol=1e-3) def test_torchao_modules_to_not_convert(self): """Test that modules_to_not_convert parameter works correctly.""" From 7aa45e5e87d3ad7f06f9ee10b20b08e4fb7ce7ea Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 9 Jun 2026 15:06:29 +0100 Subject: [PATCH 9/9] ruff --- src/diffusers/models/modeling_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 3a22b47a17a3..9faae86ce8af 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -723,7 +723,9 @@ def save_pretrained( and hf_quantizer.is_serializable ) if safe_serialization and quantization_serializable: - quantization_serializable = quantization_serializable and hf_quantizer.supports_safetensors_serialization + quantization_serializable = ( + quantization_serializable and hf_quantizer.supports_safetensors_serialization + ) if not quantization_serializable: raise ValueError( f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"