diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 04642ad5d401..abbde8082bb5 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -357,6 +357,9 @@ def _load_shard_file( disable_mmap=False, ): state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) + if hf_quantizer is not None: + state_dict = hf_quantizer.maybe_update_state_dict(state_dict) + mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..9faae86ce8af 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -722,6 +722,10 @@ def save_pretrained( and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable ) + if safe_serialization and quantization_serializable: + quantization_serializable = ( + quantization_serializable and hf_quantizer.supports_safetensors_serialization + ) if not quantization_serializable: raise ValueError( f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" @@ -759,6 +763,11 @@ def save_pretrained( # Save the model state_dict = model_to_save.state_dict() + quantization_metadata = {} + if hf_quantizer is not None: + state_dict, quantization_metadata = hf_quantizer.get_state_dict_and_metadata( + state_dict, safe_serialization=safe_serialization + ) if use_flashpack: if is_flashpack_available(): @@ -803,15 +812,22 @@ def save_pretrained( shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} filepath = os.path.join(save_directory, filename) if safe_serialization: + metadata = {"format": "pt"} + if quantization_metadata: + metadata.update(quantization_metadata) + metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()} # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. - safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + safetensors.torch.save_file(shard, filepath, metadata=metadata) else: torch.save(shard, filepath) if state_dict_split.is_sharded: + metadata = dict(state_dict_split.metadata) + if quantization_metadata: + metadata.update(quantization_metadata) index = { - "metadata": state_dict_split.metadata, + "metadata": metadata, "weight_map": state_dict_split.tensor_to_filename, } save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME @@ -1367,11 +1383,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None else: loaded_keys = list(state_dict.keys()) + checkpoint_files = resolved_model_file + if hf_quantizer is not None: + loaded_keys = hf_quantizer.maybe_update_loaded_keys(loaded_keys, checkpoint_files) + if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + model=model, + device_map=device_map, + keep_in_fp32_modules=keep_in_fp32_modules, ) + if hf_quantizer is not None and not hf_quantizer.supports_parallel_loading: + is_parallel_loading_enabled = False + # Now that the model is loaded, we can determine the device_map device_map = _determine_device_map( model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 5dc20fa2f7e7..29dd2ba7f9ab 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -168,6 +168,25 @@ def validate_environment(self, *args, **kwargs): """ return + def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]: + return loaded_keys + + def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + return state_dict + + @property + def supports_parallel_loading(self) -> bool: + return True + + def get_state_dict_and_metadata( + self, state_dict: dict[str, Any], safe_serialization: bool = False + ) -> tuple[dict[str, Any], dict[str, Any]]: + return state_dict, {} + + @property + def supports_safetensors_serialization(self) -> bool: + return True + def preprocess_model(self, model: "ModelMixin", **kwargs): """ Setting model attributes and/or converting model before weights loading. At this point the model should be diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index b710fcd2db30..b33a18cd142c 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -18,6 +18,7 @@ """ import importlib +import json import re import types from typing import TYPE_CHECKING, Any @@ -26,6 +27,7 @@ from ...utils import ( get_module_from_name, + is_safetensors_available, is_torch_available, is_torch_version, is_torchao_available, @@ -41,6 +43,9 @@ if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin +if is_safetensors_available(): + from safetensors import safe_open + if is_torch_available(): import torch @@ -72,6 +77,13 @@ if is_torchao_available(): from torchao.quantization import quantize_ + if is_torchao_version(">=", "0.16.0"): + from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, + ) + from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao + def _update_torch_safe_globals(): safe_globals = [ @@ -154,6 +166,9 @@ class TorchAoHfQuantizer(DiffusersQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) + self._metadata = {} + self._pending_flattened_state_dict = {} + def validate_environment(self, *args, **kwargs): if not is_torchao_available(): raise ImportError( @@ -236,6 +251,72 @@ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | max_memory = {key: val * 0.9 for key, val in max_memory.items()} return max_memory + def get_state_dict_and_metadata(self, state_dict: dict[str, Any], safe_serialization: bool = False): + """ + We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format. + """ + if not safe_serialization or not is_torchao_available() or not is_torchao_version(">=", "0.16.0"): + return state_dict, {} + + flattened_state_dict = flatten_tensor_state_dict(state_dict) + if isinstance(flattened_state_dict, tuple): + return flattened_state_dict + + return flattened_state_dict, {} + + def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]: + self.set_metadata(checkpoint_files) + if self._metadata: + return list(self.get_weight_names()) + return loaded_keys + + def set_metadata(self, checkpoint_files: list[str]): + self._metadata = {} + self._pending_flattened_state_dict = {} + + if not is_safetensors_available() or not is_torchao_version(">=", "0.16.0"): + return + + if len(checkpoint_files) == 0: + return + + if not all( + isinstance(checkpoint, str) and checkpoint.endswith(".safetensors") for checkpoint in checkpoint_files + ): + return + + metadata = {} + for checkpoint in checkpoint_files: + with safe_open(checkpoint, framework="pt") as f: + metadata.update(f.metadata() or {}) + + self._metadata = metadata if is_metadata_torchao(metadata) else {} + + @property + def metadata(self): + return self._metadata + + def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + if not self._metadata or not is_torchao_version(">=", "0.16.0") or not is_metadata_torchao(self._metadata): + return state_dict + + merged_state_dict = {**self._pending_flattened_state_dict, **state_dict} + reconstructed_state_dict, self._pending_flattened_state_dict = unflatten_tensor_state_dict( + merged_state_dict, self._metadata + ) + + return reconstructed_state_dict + + @property + def supports_parallel_loading(self) -> bool: + # Safetensors reconstruction can carry leftover flattened tensor pieces from one shard to the next. + return not self._metadata + + def get_weight_names(self): + if not self._metadata: + return set() + return set(json.loads(self._metadata["tensor_names"])) + def check_if_quantized_param( self, model: "ModelMixin", @@ -337,14 +418,19 @@ def _process_model_before_weight_loading( def _process_model_after_weight_loading(self, model: "ModelMixin"): return model - def is_serializable(self, safe_serialization=None): - # TODO(aryan): needs to be tested - if safe_serialization: + @property + def supports_safetensors_serialization(self): + if not is_torchao_version(">=", "0.16.0"): logger.warning( - "torchao quantized model does not support safe serialization, please set `safe_serialization` to False." + "TorchAO quantized model is not serializable with safe serialization without safetensors support " + "from the installed torchao version." ) return False + return True + + @property + def is_serializable(self): _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse( "0.25.0" ) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index ded5cab52268..9a6c1ac6a790 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -28,6 +28,7 @@ ) from ...testing_utils import ( + assert_tensors_close, backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, @@ -807,11 +808,12 @@ class TorchAoConfigMixin: @staticmethod def _get_quant_config(config_name): config_cls = getattr(_torchao_quantization, config_name) + config_kwargs = {"version": 2} # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu": - return TorchAoConfig(config_cls(int4_packing_format="plain_int32")) + config_kwargs.setdefault("int4_packing_format", "plain_int32") - return TorchAoConfig(config_cls()) + return TorchAoConfig(config_cls(**config_kwargs)) def _create_quantized_model(self, config_name, **extra_kwargs): config = self._get_quant_config(config_name) @@ -906,18 +908,58 @@ def test_torchao_quantization_lora_inference(self, quant_type): self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + @require_torchao_version_greater_or_equal("0.16.0") def test_torchao_quantization_serialization(self, quant_type, tmp_path): - """Override to use safe_serialization=False for TorchAO (safetensors not supported).""" config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] model = self._create_quantized_model(config_kwargs) + inputs = self.get_dummy_inputs() + + with torch.no_grad(): + expected_output = model(**inputs, return_dict=False)[0].detach().cpu() + + model.save_pretrained(str(tmp_path), safe_serialization=True) + del model + gc.collect() + backend_empty_cache(torch_device) + + model_loaded = self.model_class.from_pretrained( + str(tmp_path), device_map=str(torch_device), use_safetensors=True + ) - model.save_pretrained(str(tmp_path), safe_serialization=False) + with torch.no_grad(): + output = model_loaded(**inputs, return_dict=False)[0].detach().cpu() - model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device)) + assert_tensors_close(output, expected_output, rtol=1e-3, atol=1e-3) + @pytest.mark.parametrize("quant_type", ["int8dq"], ids=["int8dq"]) + @require_torchao_version_greater_or_equal("0.16.0") + def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path): + config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] + model = self._create_quantized_model(config_kwargs) inputs = self.get_dummy_inputs() - output = model_loaded(**inputs, return_dict=False)[0] - assert not torch.isnan(output).any(), "Loaded model output contains NaN" + + with torch.no_grad(): + expected_output = model(**inputs, return_dict=False)[0].detach().cpu() + + model.save_pretrained(str(tmp_path), safe_serialization=True, max_shard_size="16KB") + del model + gc.collect() + backend_empty_cache(torch_device) + + shard_files = list(tmp_path.glob("*.safetensors")) + assert len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint." + assert any(path.name.endswith(".index.json") for path in tmp_path.iterdir()), ( + "Expected an index file for sharded safe checkpoint." + ) + + model_loaded = self.model_class.from_pretrained( + str(tmp_path), device_map=str(torch_device), use_safetensors=True + ) + + with torch.no_grad(): + output = model_loaded(**inputs, return_dict=False)[0].detach().cpu() + + assert_tensors_close(output, expected_output, rtol=1e-3, atol=1e-3) def test_torchao_modules_to_not_convert(self): """Test that modules_to_not_convert parameter works correctly."""