diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index ba060b3b120d..975ca9c683d9 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -481,7 +481,30 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, model.to(torch_device) fp32_modules = model._keep_in_fp32_modules or [] - model.to(dtype).save_pretrained(tmp_path) + # Cast the in-memory reference the same way `from_pretrained(torch_dtype=dtype)` + # does, so the two never diverge: parameters and persistent buffers are cast to + # `dtype` (except `_keep_in_fp32_modules`, which stay fp32), while non-persistent + # buffers such as RoPE `inv_freq` are not stored in the checkpoint and keep the + # dtype assigned in `__init__`. A blanket `model.to(dtype)` casts both of these + # cases unconditionally and produces spurious output mismatches. + non_persistent_buffers = set() + for module_name, module in model.named_modules(): + for buffer_name in module._non_persistent_buffers_set: + non_persistent_buffers.add(f"{module_name}.{buffer_name}" if module_name else buffer_name) + + def _keep_in_fp32(name): + return bool(fp32_modules) and any(m in name.split(".") for m in fp32_modules) + + for name, param in model.named_parameters(): + if param.is_floating_point() and not _keep_in_fp32(name): + param.data = param.data.to(dtype) + for name, buffer in model.named_buffers(): + if name in non_persistent_buffers: + continue + if buffer.is_floating_point() and not _keep_in_fp32(name): + buffer.data = buffer.data.to(dtype) + + model.save_pretrained(tmp_path) model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) for name, param in model_loaded.named_parameters():