Skip to content
3 changes: 3 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ def _load_shard_file(
disable_mmap=False,
):
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
if hf_quantizer is not None:
state_dict = hf_quantizer.maybe_update_state_dict(state_dict)

mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
Expand Down
31 changes: 28 additions & 3 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ def save_pretrained(
and isinstance(hf_quantizer, DiffusersQuantizer)
and hf_quantizer.is_serializable
)
if safe_serialization and quantization_serializable:
quantization_serializable = (
quantization_serializable and hf_quantizer.supports_safetensors_serialization
)
if not quantization_serializable:
raise ValueError(
f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
Expand Down Expand Up @@ -759,6 +763,11 @@ def save_pretrained(

# Save the model
state_dict = model_to_save.state_dict()
quantization_metadata = {}
if hf_quantizer is not None:
state_dict, quantization_metadata = hf_quantizer.get_state_dict_and_metadata(
state_dict, safe_serialization=safe_serialization
)

if use_flashpack:
if is_flashpack_available():
Expand Down Expand Up @@ -803,15 +812,22 @@ def save_pretrained(
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
metadata = {"format": "pt"}
if quantization_metadata:
metadata.update(quantization_metadata)
metadata = {k: str(v) if not isinstance(v, str) else v for k, v in metadata.items()}
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
safetensors.torch.save_file(shard, filepath, metadata=metadata)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're saving all metadata, then I'd restrict it to torchao only.

else:
torch.save(shard, filepath)

if state_dict_split.is_sharded:
metadata = dict(state_dict_split.metadata)
if quantization_metadata:
metadata.update(quantization_metadata)
index = {
"metadata": state_dict_split.metadata,
"metadata": metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
Expand Down Expand Up @@ -1367,11 +1383,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None
else:
loaded_keys = list(state_dict.keys())

checkpoint_files = resolved_model_file
if hf_quantizer is not None:
loaded_keys = hf_quantizer.maybe_update_loaded_keys(loaded_keys, checkpoint_files)

if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
model=model,
device_map=device_map,
keep_in_fp32_modules=keep_in_fp32_modules,
)

if hf_quantizer is not None and not hf_quantizer.supports_parallel_loading:
is_parallel_loading_enabled = False
Comment thread
sayakpaul marked this conversation as resolved.

# Now that the model is loaded, we can determine the device_map
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
Expand Down
19 changes: 19 additions & 0 deletions src/diffusers/quantizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,25 @@ def validate_environment(self, *args, **kwargs):
"""
return

def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]:
    """
    Hook that lets a quantizer rewrite the list of checkpoint keys before weights are loaded.

    The base implementation does not look at `checkpoint_files` and hands `loaded_keys` back unchanged.

    Args:
        loaded_keys (`list[str]`): Keys collected from the checkpoint.
        checkpoint_files (`list[str]`): Checkpoint file paths; unused here.

    Returns:
        `list[str]`: The same `loaded_keys` object that was passed in.
    """
    return loaded_keys

def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Hook that lets a quantizer transform a freshly loaded state dict.

    The base implementation is a pass-through.

    Args:
        state_dict (`dict[str, Any]`): State dict as read from a checkpoint file.

    Returns:
        `dict[str, Any]`: The same `state_dict` object that was passed in.
    """
    return state_dict

@property
def supports_parallel_loading(self) -> bool:
    """
    Flag reporting whether this quantizer supports parallel loading. The base implementation always answers
    `True`.
    """
    return True

def get_state_dict_and_metadata(
    self, state_dict: dict[str, Any], safe_serialization: bool = False
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Hook that lets a quantizer convert the state dict for saving and attach extra metadata.

    The base implementation ignores `safe_serialization` and returns the state dict untouched, together with an
    empty metadata dict.

    Args:
        state_dict (`dict[str, Any]`): State dict of the model being saved.
        safe_serialization (`bool`, defaults to `False`): Unused here.

    Returns:
        `tuple[dict[str, Any], dict[str, Any]]`: `(state_dict, {})`.
    """
    no_metadata: dict[str, Any] = {}
    return state_dict, no_metadata

@property
def supports_safetensors_serialization(self) -> bool:
    """
    Flag reporting whether this quantizer supports safetensors serialization. The base implementation always
    answers `True`.
    """
    return True

def preprocess_model(self, model: "ModelMixin", **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point the model should be
Expand Down
94 changes: 90 additions & 4 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import importlib
import json
import re
import types
from typing import TYPE_CHECKING, Any
Expand All @@ -26,6 +27,7 @@

from ...utils import (
get_module_from_name,
is_safetensors_available,
is_torch_available,
is_torch_version,
is_torchao_available,
Expand All @@ -41,6 +43,9 @@
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin

if is_safetensors_available():
from safetensors import safe_open


if is_torch_available():
import torch
Expand Down Expand Up @@ -72,6 +77,13 @@
if is_torchao_available():
from torchao.quantization import quantize_

if is_torchao_version(">=", "0.16.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao


def _update_torch_safe_globals():
safe_globals = [
Expand Down Expand Up @@ -154,6 +166,9 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def __init__(self, quantization_config, **kwargs):
    """
    Forward the configuration to the parent quantizer and set up the bookkeeping used when reading checkpoints.
    """
    super().__init__(quantization_config, **kwargs)

    # Flattened tensor pieces carried over between successive `maybe_update_state_dict` calls.
    self._pending_flattened_state_dict = {}
    # Metadata collected from checkpoint files by `set_metadata`; empty until then.
    self._metadata = {}

def validate_environment(self, *args, **kwargs):
if not is_torchao_available():
raise ImportError(
Expand Down Expand Up @@ -236,6 +251,72 @@ def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int |
max_memory = {key: val * 0.9 for key, val in max_memory.items()}
return max_memory

def get_state_dict_and_metadata(self, state_dict: dict[str, Any], safe_serialization: bool = False):
    """
    Flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.

    Flattening happens only when `safe_serialization` is requested, torchao is installed and its version is at
    least 0.16.0. Otherwise the state dict is returned as-is with empty metadata.

    Args:
        state_dict (`dict[str, Any]`): State dict of the model being saved.
        safe_serialization (`bool`, defaults to `False`): Whether the checkpoint is written as safetensors.

    Returns:
        A `(state_dict, metadata)` pair. If `flatten_tensor_state_dict` already returns a tuple it is passed
        through unchanged; any other result is paired with an empty metadata dict.
    """
    can_flatten = safe_serialization and is_torchao_available() and is_torchao_version(">=", "0.16.0")
    if not can_flatten:
        return state_dict, {}

    flattened = flatten_tensor_state_dict(state_dict)
    return flattened if isinstance(flattened, tuple) else (flattened, {})

def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]:
    """
    Refresh the cached checkpoint metadata and, when it is present, substitute the weight names it records for
    the raw checkpoint keys.

    Args:
        loaded_keys (`list[str]`): Keys collected from the checkpoint.
        checkpoint_files (`list[str]`): Checkpoint file paths handed to `set_metadata`.

    Returns:
        `list[str]`: The names from `get_weight_names()` when metadata was found, otherwise `loaded_keys`.
    """
    self.set_metadata(checkpoint_files)

    if not self._metadata:
        return loaded_keys

    weight_names = self.get_weight_names()
    return list(weight_names)

def set_metadata(self, checkpoint_files: list[str]):
    """
    Read the safetensors header metadata of `checkpoint_files` and cache it on the quantizer.

    The cached metadata and the pending flattened tensors are always reset first. Metadata is then kept only
    when every checkpoint is a `.safetensors` path string, safetensors is installed, torchao is at least
    0.16.0, and `is_metadata_torchao` accepts the merged metadata. In every other case `self._metadata` stays
    empty.

    Args:
        checkpoint_files (`list[str]`):
            Checkpoint file paths. `None` and an empty list are treated as "nothing to read"; a single path
            string is treated as a one-element list.
    """
    # Start from a clean slate so state from a previous load never leaks into this one.
    self._metadata = {}
    self._pending_flattened_state_dict = {}

    # Covers both `None` and empty containers; `len(None)` would raise a `TypeError`.
    if not checkpoint_files:
        return

    # A bare string would otherwise be iterated character by character and silently skipped.
    if isinstance(checkpoint_files, str):
        checkpoint_files = [checkpoint_files]

    if not is_safetensors_available() or not is_torchao_version(">=", "0.16.0"):
        return

    if not all(
        isinstance(checkpoint, str) and checkpoint.endswith(".safetensors") for checkpoint in checkpoint_files
    ):
        return

    # Merge the header metadata of all shards; later files overwrite duplicate keys.
    metadata = {}
    for checkpoint in checkpoint_files:
        with safe_open(checkpoint, framework="pt") as f:
            metadata.update(f.metadata() or {})

    self._metadata = metadata if is_metadata_torchao(metadata) else {}

@property
def metadata(self):
    """
    Read-only access to the metadata cached in `self._metadata`.
    """
    return self._metadata

def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Rebuild tensors from a flattened state dict using the cached checkpoint metadata.

    When no metadata is cached, torchao is older than 0.16.0, or the metadata is not recognised by
    `is_metadata_torchao`, the state dict is returned untouched.

    Args:
        state_dict (`dict[str, Any]`): State dict loaded from one checkpoint file.

    Returns:
        `dict[str, Any]`: The first element returned by `unflatten_tensor_state_dict`, or `state_dict` itself
        when no reconstruction applies.
    """
    if not (self._metadata and is_torchao_version(">=", "0.16.0") and is_metadata_torchao(self._metadata)):
        return state_dict

    # Entries left over from earlier calls go first so that the current state dict wins on duplicate keys.
    combined = dict(self._pending_flattened_state_dict)
    combined.update(state_dict)

    rebuilt, leftover = unflatten_tensor_state_dict(combined, self._metadata)
    # Keep whatever the helper handed back as its second result for the next call.
    self._pending_flattened_state_dict = leftover

    return rebuilt

@property
def supports_parallel_loading(self) -> bool:
    """
    Parallel loading is reported as supported only while no checkpoint metadata is cached.
    """
    # Safetensors reconstruction can carry leftover flattened tensor pieces from one shard to the next.
    return not self._metadata

def get_weight_names(self):
    """
    Return the set of names stored as a JSON list under the `"tensor_names"` key of the cached metadata, or an
    empty set when no metadata is cached.
    """
    cached = self._metadata
    if cached:
        names = json.loads(cached["tensor_names"])
        return set(names)
    return set()

def check_if_quantized_param(
self,
model: "ModelMixin",
Expand Down Expand Up @@ -337,14 +418,19 @@ def _process_model_before_weight_loading(
def _process_model_after_weight_loading(self, model: "ModelMixin"):
    """
    Post-loading hook; this quantizer returns `model` as-is without any further processing.
    """
    return model

def is_serializable(self, safe_serialization=None):
# TODO(aryan): needs to be tested
if safe_serialization:
@property
def supports_safetensors_serialization(self):
if not is_torchao_version(">=", "0.16.0"):
logger.warning(
"torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
"TorchAO quantized model is not serializable with safe serialization without safetensors support "
"from the installed torchao version."
)
return False

return True

@property
def is_serializable(self):
_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
"0.25.0"
)
Expand Down
56 changes: 49 additions & 7 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)

from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
Expand Down Expand Up @@ -807,11 +808,12 @@ class TorchAoConfigMixin:
@staticmethod
def _get_quant_config(config_name):
config_cls = getattr(_torchao_quantization, config_name)
config_kwargs = {"version": 2}
# TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
config_kwargs.setdefault("int4_packing_format", "plain_int32")

return TorchAoConfig(config_cls())
return TorchAoConfig(config_cls(**config_kwargs))

def _create_quantized_model(self, config_name, **extra_kwargs):
config = self._get_quant_config(config_name)
Expand Down Expand Up @@ -906,18 +908,58 @@ def test_torchao_quantization_lora_inference(self, quant_type):
self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

@pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
@require_torchao_version_greater_or_equal("0.16.0")
def test_torchao_quantization_serialization(self, quant_type, tmp_path):
"""Override to use safe_serialization=False for TorchAO (safetensors not supported)."""
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]
model = self._create_quantized_model(config_kwargs)
inputs = self.get_dummy_inputs()

with torch.no_grad():
expected_output = model(**inputs, return_dict=False)[0].detach().cpu()

model.save_pretrained(str(tmp_path), safe_serialization=True)
del model
gc.collect()
backend_empty_cache(torch_device)

model_loaded = self.model_class.from_pretrained(
str(tmp_path), device_map=str(torch_device), use_safetensors=True
)

model.save_pretrained(str(tmp_path), safe_serialization=False)
with torch.no_grad():
output = model_loaded(**inputs, return_dict=False)[0].detach().cpu()

model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device))
assert_tensors_close(output, expected_output, rtol=1e-3, atol=1e-3)

@pytest.mark.parametrize("quant_type", ["int8dq"], ids=["int8dq"])
@require_torchao_version_greater_or_equal("0.16.0")
def test_torchao_quantization_sharded_serialization(self, quant_type, tmp_path):
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]
model = self._create_quantized_model(config_kwargs)
inputs = self.get_dummy_inputs()
output = model_loaded(**inputs, return_dict=False)[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"

with torch.no_grad():
expected_output = model(**inputs, return_dict=False)[0].detach().cpu()

model.save_pretrained(str(tmp_path), safe_serialization=True, max_shard_size="16KB")
del model
gc.collect()
backend_empty_cache(torch_device)

shard_files = list(tmp_path.glob("*.safetensors"))
assert len(shard_files) > 1, "Expected a sharded safe-serialization checkpoint."
assert any(path.name.endswith(".index.json") for path in tmp_path.iterdir()), (
"Expected an index file for sharded safe checkpoint."
)

model_loaded = self.model_class.from_pretrained(
str(tmp_path), device_map=str(torch_device), use_safetensors=True
)

with torch.no_grad():
output = model_loaded(**inputs, return_dict=False)[0].detach().cpu()

assert_tensors_close(output, expected_output, rtol=1e-3, atol=1e-3)

def test_torchao_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
Expand Down
Loading