Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,8 @@ def cuda(self, *args, **kwargs):
def to(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled

fp32_modules = self._keep_in_fp32_modules or []

device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs

Expand All @@ -1501,6 +1503,11 @@ def to(self, *args, **kwargs):
dtype_present_in_args = True
break

if dtype_present_in_args and fp32_modules is not None:
logger.warning(
f"There are modules in {self.__class__.__name__} that should be kept in float32: {fp32_modules}. A bare `to()` might lead to inconsistent results."
)

if getattr(self, "is_quantized", False):
if dtype_present_in_args:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import AutoencoderKLTemporalDecoder
Expand Down Expand Up @@ -63,7 +64,12 @@ def get_dummy_inputs(self) -> dict:


class TestAutoencoderKLTemporalDecoder(AutoencoderKLTemporalDecoderTesterConfig, ModelTesterMixin):
pass
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# The reference and reloaded models hold identical weights, so any output difference is
# half-precision kernel nondeterminism between the two module instances rather than a save/load
# fidelity issue. The default 1e-4 tolerance is too tight for that fp16/bf16 noise on some GPUs.
super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=3e-3)


class TestAutoencoderKLTemporalDecoderTraining(AutoencoderKLTemporalDecoderTesterConfig, TrainingTesterMixin):
Expand Down
8 changes: 7 additions & 1 deletion tests/models/autoencoders/test_models_autoencoder_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ def get_dummy_inputs(self) -> dict:


class TestAutoencoderTiny(AutoencoderTinyTesterConfig, ModelTesterMixin):
pass
@pytest.mark.skip(
reason="`forward` round-trips latents through a uint8 quantization (to simulate storing them as an RGBA "
"image), which upcasts them to fp32 and breaks the half-precision decoder. This is intrinsic to the model "
"and unrelated to save/load fidelity."
)
def test_from_save_pretrained_dtype_inference(self, *args, **kwargs):
pass


class TestAutoencoderTinyTraining(AutoencoderTinyTesterConfig, TrainingTesterMixin):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

import numpy as np
import pytest
import torch

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
Expand Down Expand Up @@ -87,7 +88,12 @@ def get_dummy_inputs(self) -> dict:


class TestConsistencyDecoderVAE(ConsistencyDecoderVAETesterConfig, ModelTesterMixin):
pass
@pytest.mark.skip(
reason="The consistency decoder samples noise (`randn_tensor`) during `decode`, so two forward passes "
"diverge regardless of dtype. This makes a save/load output comparison non-deterministic."
)
def test_from_save_pretrained_dtype_inference(self, *args, **kwargs):
pass


class TestConsistencyDecoderVAETraining(ConsistencyDecoderVAETesterConfig, TrainingTesterMixin):
Expand Down
7 changes: 7 additions & 0 deletions tests/models/autoencoders/test_models_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def get_dummy_inputs(self) -> dict:


class TestVQModel(VQModelTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# The reference and reloaded models hold identical weights, so any output difference is
# half-precision kernel nondeterminism between the two module instances rather than a save/load
# fidelity issue. The default 1e-4 tolerance is too tight for that fp16/bf16 noise on some GPUs.
super().test_from_save_pretrained_dtype_inference(tmp_path, dtype, atol=1e-3)

def test_from_pretrained_hub(self):
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
assert model is not None
Expand Down
4 changes: 4 additions & 0 deletions tests/models/controlnets/test_models_controlnet_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def test_determinism(self):
def test_from_save_pretrained(self):
super().test_from_save_pretrained()

@pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.")
def test_from_save_pretrained_dtype_inference(self, *args, **kwargs):
super().test_from_save_pretrained_dtype_inference(*args, **kwargs)

@pytest.mark.skip("Output is a list of tensors; comparison helper calls .shape on it.")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
Expand Down
42 changes: 41 additions & 1 deletion tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,28 @@ def test_keep_in_fp32_modules(self, tmp_path):
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"

def test_to_keep_in_fp32_modules_warns(self, caplog):
fp32_modules = self.model_class._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")

model = self.model_class(**self.get_init_dict())

logger_name = "diffusers.models.modeling_utils"
logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
model.to(torch.float16)
finally:
logging.disable_propagation()

expected_message = (
f"There are modules in {model.__class__.__name__} that should be kept in float32: "
f"{fp32_modules}. A bare `to()` might lead to inconsistent results."
)
assert expected_message in caplog.text

@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
Expand All @@ -481,7 +503,25 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []

model.to(dtype).save_pretrained(tmp_path)
# Build the reference model with the same mixed-precision layout that `from_pretrained` enforces, so
# the comparison reflects real save/load fidelity:
# - `_keep_in_fp32_modules` stay in fp32 while everything else is cast to `dtype`;
# - non-persistent buffers (e.g. fp32 RoPE `inv_freq`) are left untouched, because they are not part
# of the checkpoint and are regenerated by `__init__` on load. Truncating them here would make the
# reference diverge from the reloaded model for reasons unrelated to save/load.
persistent_tensor_names = {name for name, _ in named_persistent_module_tensors(model, recurse=True)}

def keep_in_fp32(name):
return any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules)

for name, param in model.named_parameters():
param.data = param.data.to(torch.float32 if keep_in_fp32(name) else dtype)
for name, buf in model.named_buffers():
if not buf.is_floating_point() or name not in persistent_tensor_names:
continue
buf.data = buf.data.to(torch.float32 if keep_in_fp32(name) else dtype)

model.save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)

for name, param in model_loaded.named_parameters():
Expand Down
7 changes: 0 additions & 7 deletions tests/models/transformers/test_models_transformer_anyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import AnyFlowTransformer3DModel
Expand Down Expand Up @@ -100,12 +99,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestAnyFlowTransformer3D(AnyFlowTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for AnyFlow Transformer 3D (bidirectional variant)."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestAnyFlowTransformer3DMemory(AnyFlowTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AnyFlow Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,6 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]:
class TestAnyFlowFARTransformer3D(AnyFlowFARTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for AnyFlow FAR causal Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestAnyFlowFARTransformer3DMemory(AnyFlowFARTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AnyFlow FAR Transformer 3D."""
Expand Down
6 changes: 0 additions & 6 deletions tests/models/transformers/test_models_transformer_helios.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Helios Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Helios Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import Ideogram4Transformer2DModel
Expand Down Expand Up @@ -141,14 +140,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestIdeogram4Transformer(Ideogram4TransformerTesterConfig, ModelTesterMixin):
"""Core model tests for Ideogram 4 Transformer."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory
# .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any
# meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype
# and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestIdeogram4TransformerMemory(Ideogram4TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Ideogram 4 Transformer."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import JoyImageEditTransformer3DModel
Expand Down Expand Up @@ -86,9 +85,7 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:


class TestJoyImageEditTransformer(JoyImageEditTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
pytest.skip("Tolerance requirements too high for meaningful test")
pass


class TestJoyImageEditTransformerMemory(JoyImageEditTransformerTesterConfig, MemoryTesterMixin):
Expand Down
7 changes: 0 additions & 7 deletions tests/models/transformers/test_models_transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import WanTransformer3DModel
Expand Down Expand Up @@ -106,12 +105,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,6 @@ def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
Expand Down
6 changes: 0 additions & 6 deletions tests/models/transformers/test_models_transformer_wan_vace.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")

def test_model_parallelism(self, tmp_path):
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
pytest.skip("Model parallelism not yet supported for WanVACE")
Expand Down
Loading