diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 4dbb8ca7c075..6d8462c96d88 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -16,14 +16,13 @@ import copy import gc import os -import tempfile import unittest from collections import OrderedDict +import pytest import torch from huggingface_hub import snapshot_download from parameterized import parameterized -from pytest import mark from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import ( @@ -33,7 +32,7 @@ ) from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection from diffusers.utils import logging -from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import ( backend_empty_cache, @@ -52,11 +51,14 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ( +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, LoraHotSwappingForModelTesterMixin, + MemoryTesterMixin, ModelTesterMixin, TorchCompileTesterMixin, - UNetTesterMixin, + TrainingTesterMixin, ) @@ -354,34 +356,28 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - main_input_name = "sample" +class UNet2DConditionModelTesterConfig(BaseModelTesterConfig): # We override the items here because the unet under consideration is small. model_split_percents = [0.5, 0.34, 0.4] @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (16, 16) + def model_class(self): + return UNet2DConditionModel - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 16, 16) @property - def output_shape(self): - return (4, 16, 16) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "block_out_channels": (4, 8), "norm_num_groups": 4, "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), @@ -393,190 +389,92 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.enable_xformers_memory_efficient_attention() + def get_dummy_inputs(self) -> dict: + noise = randn_tensor((4, 4, 16, 16), generator=self.generator, device=torch_device) + timestep = torch.tensor([10], device=torch_device) + encoder_hidden_states = randn_tensor((4, 4, 8), generator=self.generator, device=torch_device) + return {"sample": noise, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states} - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" +class TestUNet2DConditionModel(UNet2DConditionModelTesterConfig, ModelTesterMixin): def test_model_with_attention_head_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample + output = model(**self.get_dummy_inputs()).sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" def test_model_with_use_linear_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["use_linear_projection"] = True - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) + output = model(**self.get_dummy_inputs()).sample - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" def test_model_with_cross_attention_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["cross_attention_dim"] = (8, 8) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample + output = model(**self.get_dummy_inputs()).sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" def test_model_with_simple_projection(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() batch_size, _, _, sample_size = inputs_dict["sample"].shape init_dict["class_embed_type"] = "simple_projection" init_dict["projection_class_embeddings_input_dim"] = sample_size - inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample + output = model(**inputs_dict).sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == inputs_dict["sample"].shape, "Input and output shapes do not match" def test_model_with_class_embeddings_concat(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() batch_size, _, _, sample_size = inputs_dict["sample"].shape init_dict["class_embed_type"] = "simple_projection" init_dict["projection_class_embeddings_input_dim"] = sample_size init_dict["class_embeddings_concat"] = True - inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) + output = model(**inputs_dict).sample - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - model.set_attention_slice("auto") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice("max") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - - model.set_attention_slice(2) - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None + assert output.shape == inputs_dict["sample"].shape, "Input and output shapes do not match" def test_model_sliceable_head_dim(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) - model = self.model_class(**init_dict) def check_sliceable_dim_attr(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): assert isinstance(module.sliceable_head_dim, int) - for child in module.children(): check_sliceable_dim_attr(child) - # retrieve number of attention layers for module in model.children(): check_sliceable_dim_attr(module) - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "CrossAttnUpBlock2D", - "CrossAttnDownBlock2D", - "UNetMidBlock2DCrossAttn", - "UpBlock2D", - "Transformer2DModel", - "DownBlock2D", - } - attention_head_dim = (8, 16) - block_out_channels = (16, 32) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) - def test_special_attn_proc(self): class AttnEasyProc(torch.nn.Module): def __init__(self, num): @@ -617,40 +515,27 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma return hidden_states - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) + model = self.model_class(**init_dict).to(torch_device) processor = AttnEasyProc(5.0) - model.set_attn_processor(processor) - model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample + model(**self.get_dummy_inputs(), cross_attention_kwargs={"number": 123}).sample assert processor.counter == 8 assert processor.is_run assert processor.number == 123 - @parameterized.expand( - [ - # fmt: off - [torch.bool], - [torch.long], - [torch.float], - # fmt: on - ] - ) + @pytest.mark.parametrize("mask_dtype", [torch.bool, torch.long, torch.float]) def test_model_xattn_mask(self, mask_dtype): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)}) model.to(torch_device) model.eval() + inputs_dict = self.get_dummy_inputs() cond = inputs_dict["encoder_hidden_states"] with torch.no_grad(): full_cond_out = model(**inputs_dict).sample @@ -679,16 +564,16 @@ def test_model_xattn_mask(self, mask_dtype): # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. # maybe it's fine that this only works for the unclip use-case. - @mark.skip( + @pytest.mark.skip( reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." ) def test_model_xattn_padding(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) model.to(torch_device) model.eval() + inputs_dict = self.get_dummy_inputs() cond = inputs_dict["encoder_hidden_states"] with torch.no_grad(): full_cond_out = model(**inputs_dict).sample @@ -706,15 +591,12 @@ def test_model_xattn_padding(self): ) def test_custom_diffusion_processors(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) + model = self.model_class(**init_dict).to(torch_device) - model = self.model_class(**init_dict) - model.to(torch_device) - + inputs_dict = self.get_dummy_inputs() with torch.no_grad(): sample1 = model(**inputs_dict).sample @@ -732,17 +614,15 @@ def test_custom_diffusion_processors(self): assert (sample1 - sample2).abs().max() < 3e-3 - def test_custom_diffusion_save_load(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + def test_custom_diffusion_save_load(self, tmp_path): + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) + model = self.model_class(**init_dict).to(torch_device) + inputs_dict = self.get_dummy_inputs() with torch.no_grad(): old_sample = model(**inputs_dict).sample @@ -752,13 +632,12 @@ def test_custom_diffusion_save_load(self): with torch.no_grad(): sample = model(**inputs_dict).sample - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) - torch.manual_seed(0) - new_model = self.model_class(**init_dict) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") - new_model.to(torch_device) + model.save_attn_procs(tmp_path, safe_serialization=False) + assert os.path.isfile(os.path.join(tmp_path, "pytorch_custom_diffusion_weights.bin")) + torch.manual_seed(0) + new_model = self.model_class(**init_dict) + new_model.load_attn_procs(tmp_path, weight_name="pytorch_custom_diffusion_weights.bin") + new_model.to(torch_device) with torch.no_grad(): new_sample = new_model(**inputs_dict).sample @@ -768,78 +647,38 @@ def test_custom_diffusion_save_load(self): # custom diffusion and no custom diffusion should be the same assert (sample - old_sample).abs().max() < 3e-3 - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_custom_diffusion_xformers_on_off(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) - model.set_attn_processor(custom_diffusion_attn_procs) - - # default - with torch.no_grad(): - sample = model(**inputs_dict).sample - - model.enable_xformers_memory_efficient_attention() - on_sample = model(**inputs_dict).sample - - model.disable_xformers_memory_efficient_attention() - off_sample = model(**inputs_dict).sample - - assert (sample - on_sample).abs().max() < 1e-4 - assert (sample - off_sample).abs().max() < 1e-4 - def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) + model = self.model_class(**init_dict).to(torch_device) with torch.no_grad(): - sample = model(**inputs_dict).sample + sample = model(**self.get_dummy_inputs()).sample sample_copy = copy.copy(sample) - assert (sample - sample_copy).abs().max() < 1e-4 def test_asymmetrical_unet(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() # Add asymmetry to configs init_dict["transformer_layers_per_block"] = [[3, 2], 1] init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) + model = self.model_class(**init_dict).to(torch_device) + inputs_dict = self.get_dummy_inputs() output = model(**inputs_dict).sample - expected_shape = inputs_dict["sample"].shape - - # Check if input and output shapes are the same - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == inputs_dict["sample"].shape, "Input and output shapes do not match" def test_ip_adapter(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) + model = self.model_class(**init_dict).to(torch_device) - model = self.model_class(**init_dict) - model.to(torch_device) - + inputs_dict = self.get_dummy_inputs() # forward pass without ip-adapter with torch.no_grad(): sample1 = model(**inputs_dict).sample @@ -905,14 +744,12 @@ def test_ip_adapter(self): assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) def test_ip_adapter_plus(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (16, 32) init_dict["attention_head_dim"] = (8, 16) + model = self.model_class(**init_dict).to(torch_device) - model = self.model_class(**init_dict) - model.to(torch_device) - + inputs_dict = self.get_dummy_inputs() # forward pass without ip-adapter with torch.no_grad(): sample1 = model(**inputs_dict).sample @@ -977,118 +814,114 @@ def test_ip_adapter_plus(self): assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) - @parameterized.expand( + @pytest.mark.parametrize( + "repo_id, variant", [ ("hf-internal-testing/unet2d-sharded-dummy", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), - ] + ], ) @require_torch_accelerator def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) - @parameterized.expand( + @pytest.mark.parametrize( + "repo_id, variant", [ ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), - ] + ], ) @require_torch_accelerator def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_accelerator def test_load_sharded_checkpoint_from_hub_local(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_accelerator def test_load_sharded_checkpoint_from_hub_local_subfolder(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True) loaded_model = loaded_model.to(torch_device) - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_accelerator - @parameterized.expand( + @pytest.mark.parametrize( + "repo_id, variant", [ ("hf-internal-testing/unet2d-sharded-dummy", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), - ] + ], ) def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_accelerator - @parameterized.expand( + @pytest.mark.parametrize( + "repo_id, variant", [ ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), - ] + ], ) def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_accelerator def test_load_sharded_checkpoint_device_map_from_hub_local(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto") - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_accelerator def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") loaded_model = self.model_class.from_pretrained( ckpt_path, local_files_only=True, subfolder="unet", device_map="auto" ) - new_output = loaded_model(**inputs_dict) + new_output = loaded_model(**self.get_dummy_inputs()) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_peft_backend - def test_load_attn_procs_raise_warning(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) + def test_load_attn_procs_raise_warning(self, tmp_path): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + inputs_dict = self.get_dummy_inputs() # forward pass without LoRA with torch.no_grad(): non_lora_sample = model(**inputs_dict).sample @@ -1102,21 +935,20 @@ def test_load_attn_procs_raise_warning(self): with torch.no_grad(): lora_sample_1 = model(**inputs_dict).sample - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - model.unload_lora() + model.save_attn_procs(tmp_path) + model.unload_lora() - with self.assertWarns(FutureWarning) as warning: - model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + with pytest.warns(FutureWarning) as warning: + model.load_attn_procs(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")) - warning_message = str(warning.warnings[0].message) - assert "Using the `load_attn_procs()` method has been deprecated" in warning_message + warning_message = str(warning[0].message) + assert "Using the `load_attn_procs()` method has been deprecated" in warning_message - # import to still check for the rest of the stuff. - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." + # import to still check for the rest of the stuff. + assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - with torch.no_grad(): - lora_sample_2 = model(**inputs_dict).sample + with torch.no_grad(): + lora_sample_2 = model(**inputs_dict).sample assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( "LoRA injected UNet should produce different results." @@ -1126,36 +958,49 @@ def test_load_attn_procs_raise_warning(self): ) @require_peft_backend - def test_save_attn_procs_raise_warning(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) + def test_save_attn_procs_raise_warning(self, tmp_path): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) unet_lora_config = get_unet_lora_config() model.add_adapter(unet_lora_config) assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - with tempfile.TemporaryDirectory() as tmpdirname: - with self.assertWarns(FutureWarning) as warning: - model.save_attn_procs(tmpdirname) + with pytest.warns(FutureWarning) as warning: + model.save_attn_procs(tmp_path) - warning_message = str(warning.warnings[0].message) + warning_message = str(warning[0].message) assert "Using the `save_attn_procs()` method has been deprecated" in warning_message -class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel +class TestUNet2DConditionModelTraining(UNet2DConditionModelTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "CrossAttnUpBlock2D", + "CrossAttnDownBlock2D", + "UNetMidBlock2DCrossAttn", + "UpBlock2D", + "Transformer2DModel", + "DownBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNet2DConditionModelMemory(UNet2DConditionModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNet2DConditionModel.""" + + +class TestUNet2DConditionModelAttention(UNet2DConditionModelTesterConfig, AttentionTesterMixin): + """Attention processor tests for UNet2DConditionModel.""" - def prepare_init_args_and_inputs_for_common(self): - return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() +class TestUNet2DConditionModelCompile(UNet2DConditionModelTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for UNet2DConditionModel.""" -class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = UNet2DConditionModel - def prepare_init_args_and_inputs_for_common(self): - return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() +class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionModelTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for UNet2DConditionModel.""" @slow