diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..63371264457d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1482,6 +1482,8 @@ def cuda(self, *args, **kwargs): def to(self, *args, **kwargs): from ..hooks.group_offloading import _is_group_offload_enabled + fp32_modules = self._keep_in_fp32_modules or [] + device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs dtype_present_in_args = "dtype" in kwargs @@ -1501,6 +1503,11 @@ def to(self, *args, **kwargs): dtype_present_in_args = True break + if dtype_present_in_args and fp32_modules is not None: + logger.warning( + f"There are modules in {self.__class__.__name__} that should be kept in float32: {fp32_modules}. A bare `to()` might lead to inconsistent results." + ) + if getattr(self, "is_quantized", False): if dtype_present_in_args: raise ValueError( diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py index 7d4ea24d5502..cc5cdfa1b738 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from diffusers import AutoencoderKLTemporalDecoder @@ -63,7 +64,12 @@ def get_dummy_inputs(self) -> dict: class TestAutoencoderKLTemporalDecoder(AutoencoderKLTemporalDecoderTesterConfig, ModelTesterMixin): - pass + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # The reference and reloaded models hold identical weights, so any output difference is + # half-precision kernel nondeterminism between the two module instances rather than a save/load + # fidelity issue. The default 1e-4 tolerance is too tight for that fp16/bf16 noise on some GPUs. + super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=3e-3) class TestAutoencoderKLTemporalDecoderTraining(AutoencoderKLTemporalDecoderTesterConfig, TrainingTesterMixin): diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 7fdab4aeb910..d969865c5a33 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -76,7 +76,13 @@ def get_dummy_inputs(self) -> dict: class TestAutoencoderTiny(AutoencoderTinyTesterConfig, ModelTesterMixin): - pass + @pytest.mark.skip( + reason="`forward` round-trips latents through a uint8 quantization (to simulate storing them as an RGBA " + "image), which upcasts them to fp32 and breaks the half-precision decoder. This is intrinsic to the model " + "and unrelated to save/load fidelity." + ) + def test_from_save_pretrained_dtype_inference(self, *args, **kwargs): + pass class TestAutoencoderTinyTraining(AutoencoderTinyTesterConfig, TrainingTesterMixin): diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 0edb713d9a1f..616487eff409 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import pytest import torch from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline @@ -87,7 +88,12 @@ def get_dummy_inputs(self) -> dict: class TestConsistencyDecoderVAE(ConsistencyDecoderVAETesterConfig, ModelTesterMixin): - pass + @pytest.mark.skip( + reason="The consistency decoder samples noise (`randn_tensor`) during `decode`, so two forward passes " + "diverge regardless of dtype. This makes a save/load output comparison non-deterministic." + ) + def test_from_save_pretrained_dtype_inference(self, *args, **kwargs): + pass class TestConsistencyDecoderVAETraining(ConsistencyDecoderVAETesterConfig, TrainingTesterMixin): diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index ce1606f0e859..5567a43ecc38 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -64,6 +64,13 @@ def get_dummy_inputs(self) -> dict: class TestVQModel(VQModelTesterConfig, ModelTesterMixin): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + # The reference and reloaded models hold identical weights, so any output difference is + # half-precision kernel nondeterminism between the two module instances rather than a save/load + # fidelity issue. The default 1e-4 tolerance is too tight for that fp16/bf16 noise on some GPUs. + super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=1e-3) + def test_from_pretrained_hub(self): model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) assert model is not None diff --git a/tests/models/controlnets/test_models_controlnet_cosmos.py b/tests/models/controlnets/test_models_controlnet_cosmos.py index 9bef488a8106..d4cbe2b91a0f 100644 --- a/tests/models/controlnets/test_models_controlnet_cosmos.py +++ b/tests/models/controlnets/test_models_controlnet_cosmos.py @@ -251,6 +251,10 @@ def test_determinism(self): def test_from_save_pretrained(self): super().test_from_save_pretrained() + @pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.") + def test_from_save_pretrained_dtype_inference(self, *args, **kwargs): + super().test_from_save_pretrained_dtype_inference(*args, **kwargs) + @pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.") def test_from_save_pretrained_variant(self): super().test_from_save_pretrained_variant() diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index ba060b3b120d..443502257e4a 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -469,6 +469,28 @@ def test_keep_in_fp32_modules(self, tmp_path): else: assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}" + def test_to_keep_in_fp32_modules_warns(self, caplog): + fp32_modules = self.model_class._keep_in_fp32_modules + if fp32_modules is None or len(fp32_modules) == 0: + pytest.skip("Model does not have _keep_in_fp32_modules defined.") + + model = self.model_class(**self.get_init_dict()) + + logger_name = "diffusers.models.modeling_utils" + logging.enable_propagation() + try: + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() + model.to(torch.float16) + finally: + logging.disable_propagation() + + expected_message = ( + f"There are modules in {model.__class__.__name__} that should be kept in float32: " + f"{fp32_modules}. A bare `to()` might lead to inconsistent results." + ) + assert expected_message in caplog.text + @require_accelerator @pytest.mark.skipif( torch_device not in ["cuda", "xpu"], @@ -481,7 +503,25 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, model.to(torch_device) fp32_modules = model._keep_in_fp32_modules or [] - model.to(dtype).save_pretrained(tmp_path) + # Build the reference model with the same mixed-precision layout that `from_pretrained` enforces, so + # the comparison reflects real save/load fidelity: + # - `_keep_in_fp32_modules` stay in fp32 while everything else is cast to `dtype`; + # - non-persistent buffers (e.g. fp32 RoPE `inv_freq`) are left untouched, because they are not part + # of the checkpoint and are regenerated by `__init__` on load. Truncating them here would make the + # reference diverge from the reloaded model for reasons unrelated to save/load. + persistent_tensor_names = {name for name, _ in named_persistent_module_tensors(model, recurse=True)} + + def keep_in_fp32(name): + return any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules) + + for name, param in model.named_parameters(): + param.data = param.data.to(torch.float32 if keep_in_fp32(name) else dtype) + for name, buf in model.named_buffers(): + if not buf.is_floating_point() or name not in persistent_tensor_names: + continue + buf.data = buf.data.to(torch.float32 if keep_in_fp32(name) else dtype) + + model.save_pretrained(tmp_path) model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) for name, param in model_loaded.named_parameters(): diff --git a/tests/models/transformers/test_models_transformer_anyflow.py b/tests/models/transformers/test_models_transformer_anyflow.py index df72567a7455..5011222f17c9 100644 --- a/tests/models/transformers/test_models_transformer_anyflow.py +++ b/tests/models/transformers/test_models_transformer_anyflow.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import AnyFlowTransformer3DModel @@ -100,12 +99,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestAnyFlowTransformer3D(AnyFlowTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for AnyFlow Transformer 3D (bidirectional variant).""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestAnyFlowTransformer3DMemory(AnyFlowTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for AnyFlow Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_anyflow_far.py b/tests/models/transformers/test_models_transformer_anyflow_far.py index d7ed471fa875..b1b9d155b752 100644 --- a/tests/models/transformers/test_models_transformer_anyflow_far.py +++ b/tests/models/transformers/test_models_transformer_anyflow_far.py @@ -113,12 +113,6 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]: class TestAnyFlowFARTransformer3D(AnyFlowFARTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for AnyFlow FAR causal Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestAnyFlowFARTransformer3DMemory(AnyFlowFARTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for AnyFlow FAR Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_helios.py b/tests/models/transformers/test_models_transformer_helios.py index c365c258e596..927581b095e8 100644 --- a/tests/models/transformers/test_models_transformer_helios.py +++ b/tests/models/transformers/test_models_transformer_helios.py @@ -135,12 +135,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Helios Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for Helios Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_ideogram4.py b/tests/models/transformers/test_models_transformer_ideogram4.py index 31592ada64bc..9f32a4d04505 100644 --- a/tests/models/transformers/test_models_transformer_ideogram4.py +++ b/tests/models/transformers/test_models_transformer_ideogram4.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import Ideogram4Transformer2DModel @@ -141,14 +140,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestIdeogram4Transformer(Ideogram4TransformerTesterConfig, ModelTesterMixin): """Core model tests for Ideogram 4 Transformer.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory - # .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any - # meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype - # and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestIdeogram4TransformerMemory(Ideogram4TransformerTesterConfig, MemoryTesterMixin): """Memory optimization tests for Ideogram 4 Transformer.""" diff --git a/tests/models/transformers/test_models_transformer_joyimage.py b/tests/models/transformers/test_models_transformer_joyimage.py index c464a44c29b5..45d15b2d470a 100644 --- a/tests/models/transformers/test_models_transformer_joyimage.py +++ b/tests/models/transformers/test_models_transformer_joyimage.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import JoyImageEditTransformer3DModel @@ -86,9 +85,7 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestJoyImageEditTransformer(JoyImageEditTransformerTesterConfig, ModelTesterMixin): - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - pytest.skip("Tolerance requirements too high for meaningful test") + pass class TestJoyImageEditTransformerMemory(JoyImageEditTransformerTesterConfig, MemoryTesterMixin): diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 60bba9dfbe18..aacbf542b548 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from diffusers import WanTransformer3DModel @@ -106,12 +105,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Wan Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for Wan Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py index 30f78ca1c3de..228d11d0ea83 100644 --- a/tests/models/transformers/test_models_transformer_wan_animate.py +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -152,12 +152,6 @@ def test_output(self): expected_output_shape = (1, 4, 21, 16, 16) super().test_output(expected_output_shape=expected_output_shape) - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin): """Memory optimization tests for Wan Animate Transformer 3D.""" diff --git a/tests/models/transformers/test_models_transformer_wan_vace.py b/tests/models/transformers/test_models_transformer_wan_vace.py index 1cc829f88b9d..503569662b14 100644 --- a/tests/models/transformers/test_models_transformer_wan_vace.py +++ b/tests/models/transformers/test_models_transformer_wan_vace.py @@ -117,12 +117,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Wan VACE Transformer 3D.""" - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) - def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): - # Skip: fp16/bf16 require very high atol to pass, providing little signal. - # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules. - pytest.skip("Tolerance requirements too high for meaningful test") - def test_model_parallelism(self, tmp_path): # Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow pytest.skip("Model parallelism not yet supported for WanVACE")