From a5ec1a1d0877d92e9aa45a630ee15f3dd8beddc7 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Tue, 9 Jun 2026 09:51:29 -0700 Subject: [PATCH] refactor unet_2d tests --- tests/models/unets/test_models_unet_2d.py | 296 +++++++++------------- 1 file changed, 113 insertions(+), 183 deletions(-) diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index e289f44303f2..a5cd8abd873a 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -15,12 +15,12 @@ import gc import math -import unittest +import pytest import torch from diffusers import UNet2DModel -from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import ( backend_empty_cache, @@ -31,39 +31,31 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin -logger = logging.get_logger(__name__) - enable_full_determinism() -class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" - +class Unet2DModelTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) + def model_class(self): + return UNet2DModel - return {"sample": noise, "timestep": time_step} + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (3, 32, 32) @property - def output_shape(self): - return (3, 32, 32) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "block_out_channels": (4, 8), "norm_num_groups": 2, "down_block_types": ("DownBlock2D", "AttnDownBlock2D"), @@ -74,110 +66,77 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 2, "sample_size": 32, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - def test_mid_block_attn_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + def get_dummy_inputs(self) -> dict: + noise = randn_tensor((4, 3, 32, 32), generator=self.generator, device=torch_device) + timestep = torch.tensor([10], device=torch_device) + return {"sample": noise, "timestep": timestep} + +class TestUnet2DModel(Unet2DModelTesterConfig, ModelTesterMixin): + def test_mid_block_attn_groups(self): + init_dict = self.get_init_dict() init_dict["add_attention"] = True init_dict["attn_norm_num_groups"] = 4 + model = self.model_class(**init_dict).to(torch_device).eval() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - self.assertIsNotNone( - model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not." + assert model.mid_block.attentions[0].group_norm is not None, ( + "Mid block Attention group norm should exist but does not." ) - self.assertEqual( - model.mid_block.attentions[0].group_norm.num_groups, - init_dict["attn_norm_num_groups"], - "Mid block Attention group norm does not have the expected number of groups.", + assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], ( + "Mid block Attention group norm does not have the expected number of groups." ) with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**self.get_dummy_inputs()).sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" def test_mid_block_none(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + mid_none_init_dict = self.get_init_dict() mid_none_init_dict["mid_block_type"] = None - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - mid_none_model = self.model_class(**mid_none_init_dict) - mid_none_model.to(torch_device) - mid_none_model.eval() - - self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.") - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + model = self.model_class(**init_dict).to(torch_device).eval() + mid_none_model = self.model_class(**mid_none_init_dict).to(torch_device).eval() + assert mid_none_model.mid_block is None, "Mid block should not exist." with torch.no_grad(): - mid_none_output = mid_none_model(**mid_none_inputs_dict) + output = model(**self.get_dummy_inputs()).sample + mid_none_output = mid_none_model(**self.get_dummy_inputs()).sample - if isinstance(mid_none_output, dict): - mid_none_output = mid_none_output.to_tuple()[0] + assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different." - self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.") +class TestUnet2DModelTraining(Unet2DModelTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): - expected_set = { - "AttnUpBlock2D", - "AttnDownBlock2D", - "UNetMidBlock2D", - "UpBlock2D", - "DownBlock2D", - } - - # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` - attention_head_dim = 8 - block_out_channels = (16, 32) + expected_set = {"AttnUpBlock2D", "AttnDownBlock2D", "UNetMidBlock2D", "UpBlock2D", "DownBlock2D"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) +class TestUnet2DModelMemory(Unet2DModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNet2DModel.""" -class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" +class UNetLDMModelTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) + def model_class(self): + return UNet2DModel - return {"sample": noise, "timestep": time_step} + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 32, 32) @property - def output_shape(self): - return (4, 32, 32) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "sample_size": 32, "in_channels": 4, "out_channels": 4, @@ -187,26 +146,28 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ("DownBlock2D", "DownBlock2D"), "up_block_types": ("UpBlock2D", "UpBlock2D"), } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict: + noise = randn_tensor((4, 4, 32, 32), generator=self.generator, device=torch_device) + timestep = torch.tensor([10], device=torch_device) + return {"sample": noise, "timestep": timestep} + + +class TestUNetLDMModel(UNetLDMModelTesterConfig, ModelTesterMixin): def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) - - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - image = model(**self.dummy_input).sample - + image = model(**self.get_dummy_inputs()).sample assert image is not None, "Make sure output is not None" @require_torch_accelerator def test_from_pretrained_accelerate(self): model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model.to(torch_device) - image = model(**self.dummy_input).sample - + image = model(**self.get_dummy_inputs()).sample assert image is not None, "Make sure output is not None" @require_torch_accelerator @@ -264,45 +225,38 @@ def test_output_pretrained(self): # fmt: off expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) # fmt: on + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3) - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3)) +class TestUNetLDMModelTraining(UNetLDMModelTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim` - attention_head_dim = 32 - block_out_channels = (32, 64) - - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels - ) +class TestUNetLDMModelMemory(UNetLDMModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for the LDM UNet2DModel config.""" -class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet2DModel - main_input_name = "sample" +class NCSNppModelTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self, sizes=(32, 32)): - batch_size = 4 - num_channels = 3 - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device) + def model_class(self): + return UNet2DModel - return {"sample": noise, "timestep": time_step} + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (3, 32, 32) @property - def output_shape(self): - return (3, 32, 32) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "block_out_channels": [32, 64, 64, 64], "in_channels": 3, "layers_per_block": 1, @@ -311,34 +265,27 @@ def prepare_init_args_and_inputs_for_common(self): "norm_eps": 1e-6, "mid_block_scale_factor": math.sqrt(2.0), "norm_num_groups": None, - "down_block_types": [ - "SkipDownBlock2D", - "AttnSkipDownBlock2D", - "SkipDownBlock2D", - "SkipDownBlock2D", - ], - "up_block_types": [ - "SkipUpBlock2D", - "SkipUpBlock2D", - "AttnSkipUpBlock2D", - "SkipUpBlock2D", - ], + "down_block_types": ["SkipDownBlock2D", "AttnSkipDownBlock2D", "SkipDownBlock2D", "SkipDownBlock2D"], + "up_block_types": ["SkipUpBlock2D", "SkipUpBlock2D", "AttnSkipUpBlock2D", "SkipUpBlock2D"], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict: + noise = randn_tensor((4, 3, 32, 32), generator=self.generator, device=torch_device) + timestep = torch.tensor(4 * [10], dtype=torch.int32, device=torch_device) + return {"sample": noise, "timestep": timestep} + + +class TestNCSNppModel(NCSNppModelTesterConfig, ModelTesterMixin): @slow def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - inputs = self.dummy_input - noise = floats_tensor((4, 3) + (256, 256)).to(torch_device) - inputs["sample"] = noise + inputs = self.get_dummy_inputs() + inputs["sample"] = floats_tensor((4, 3) + (256, 256)).to(torch_device) image = model(**inputs) - assert image is not None, "Make sure output is not None" @slow @@ -346,12 +293,8 @@ def test_output_pretrained_ve_mid(self): model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256") model.to(torch_device) - batch_size = 4 - num_channels = 3 - sizes = (256, 256) - - noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) + noise = torch.ones((4, 3) + (256, 256)).to(torch_device) + time_step = torch.tensor(4 * [1e-4]).to(torch_device) with torch.no_grad(): output = model(noise, time_step).sample @@ -360,19 +303,14 @@ def test_output_pretrained_ve_mid(self): # fmt: off expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056]) # fmt: on - - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) def test_output_pretrained_ve_large(self): model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") model.to(torch_device) - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) + noise = torch.ones((4, 3) + (32, 32)).to(torch_device) + time_step = torch.tensor(4 * [1e-4]).to(torch_device) with torch.no_grad(): output = model(noise, time_step).sample @@ -381,36 +319,28 @@ def test_output_pretrained_ve_large(self): # fmt: off expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) # fmt: on + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) - - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # not required for this model - pass +class TestNCSNppModelTraining(NCSNppModelTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): - expected_set = { - "UNetMidBlock2D", - } + expected_set = {"UNetMidBlock2D"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - block_out_channels = (32, 64, 64, 64) + def test_gradient_checkpointing_equivalence(self): + super().test_gradient_checkpointing_equivalence(skip={"time_proj.weight"}) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, block_out_channels=block_out_channels - ) - def test_effective_gradient_checkpointing(self): - super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) +class TestNCSNppModelMemory(NCSNppModelTesterConfig, MemoryTesterMixin): + # Layerwise casting is not supported for this model. + @pytest.mark.skip("Layerwise casting is not supported for this model.") + def test_layerwise_casting_memory(self): + pass - @unittest.skip( - "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." - ) - def test_layerwise_casting_inference(self): + @pytest.mark.skip("Layerwise casting is not supported for this model.") + def test_layerwise_casting_training(self): pass - @unittest.skip( - "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." - ) - def test_layerwise_casting_memory(self): + @pytest.mark.skip("Layerwise casting is not supported for this model.") + def test_group_offloading_with_layerwise_casting(self, *args, **kwargs): pass