Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,30 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []

model.to(dtype).save_pretrained(tmp_path)
# Cast the in-memory reference the same way `from_pretrained(torch_dtype=dtype)`
# does, so the two never diverge: parameters and persistent buffers are cast to
# `dtype` (except `_keep_in_fp32_modules`, which stay fp32), while non-persistent
# buffers such as RoPE `inv_freq` are not stored in the checkpoint and keep the
# dtype assigned in `__init__`. A blanket `model.to(dtype)` casts both of these
# cases unconditionally and produces spurious output mismatches.
non_persistent_buffers = set()
for module_name, module in model.named_modules():
for buffer_name in module._non_persistent_buffers_set:
non_persistent_buffers.add(f"{module_name}.{buffer_name}" if module_name else buffer_name)

def _keep_in_fp32(name):
return bool(fp32_modules) and any(m in name.split(".") for m in fp32_modules)

for name, param in model.named_parameters():
if param.is_floating_point() and not _keep_in_fp32(name):
param.data = param.data.to(dtype)
for name, buffer in model.named_buffers():
if name in non_persistent_buffers:
continue
if buffer.is_floating_point() and not _keep_in_fp32(name):
buffer.data = buffer.data.to(dtype)

model.save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)

for name, param in model_loaded.named_parameters():
Expand Down
Loading