Fix fp16 LoRA unscale crash after validation in remaining DreamBooth LoRA scripts #13899

Open
HaozheZhang6 wants to merge 1 commit into huggingface:main from HaozheZhang6:fix-fp16-lora-validation-remaining

@HaozheZhang6

Contributor

What does this PR do?

Follow-up to #13895 (merged), which fixed the fp16 LoRA "unscale" crash in examples/dreambooth/train_dreambooth_lora.py. The same bug is present in the other DreamBooth LoRA training scripts — this PR applies the same fix to all of them.

Root cause (same as #13895)

Under --mixed_precision="fp16", cast_training_params(models, dtype=torch.float32) keeps the trainable LoRA params in fp32. The in-loop validation pipeline is rebuilt around the live training transformer:

pipeline = SomePipeline.from_pretrained(..., transformer=unwrap_model(transformer), torch_dtype=weight_dtype)

log_validation then casts that pipeline to torch_dtype (= fp16), which downcasts the shared transformer's fp32 LoRA params back to fp16. The next optimizer step then raises:

ValueError: Attempting to unscale FP16 gradients.

The final (post-training) validation builds a fresh pipeline from the saved weights, so it is unaffected either way.
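
The aliasing described above can be sketched with a dependency-free toy model. Everything in it (`ToyParam`, `ToyTransformer`, `ToyPipeline`, `toy_unscale`, and the string dtypes) is a hypothetical stand-in, not a diffusers or PyTorch class. It only illustrates that casting a pipeline which holds a reference to the live transformer also casts the trainable LoRA params. For contrast, it also shows that a device-only move leaves them alone.

```python
from dataclasses import dataclass, field


@dataclass
class ToyParam:
    """Hypothetical stand-in for a parameter: tracks only dtype and device."""
    dtype: str
    device: str = "cpu"
    requires_grad: bool = False


@dataclass
class ToyTransformer:
    """Hypothetical stand-in for the training transformer with a LoRA adapter."""
    params: dict = field(default_factory=dict)

    def to(self, device=None, dtype=None):
        # Modeled as an in-place cast: mutates the parameters and returns self.
        for p in self.params.values():
            if device is not None:
                p.device = device
            if dtype is not None:
                p.dtype = dtype
        return self


class ToyPipeline:
    """Hypothetical pipeline that keeps a reference to the transformer, not a copy."""

    def __init__(self, transformer):
        self.transformer = transformer

    def to(self, device=None, dtype=None):
        self.transformer.to(device=device, dtype=dtype)
        return self


def toy_unscale(transformer):
    """Toy version of the scaler check: reject fp16 trainable params.

    The real check looks at gradient dtypes; here the param dtype stands in for it.
    """
    for p in transformer.params.values():
        if p.requires_grad and p.dtype == "fp16":
            raise ValueError("Attempting to unscale FP16 gradients.")


def make_transformer():
    # Frozen base weight in fp16, trainable LoRA weight kept in fp32.
    return ToyTransformer({
        "base.weight": ToyParam("fp16"),
        "lora_A.weight": ToyParam("fp32", requires_grad=True),
    })


# Validation that casts the pipeline: the shared LoRA param is downcast.
buggy = make_transformer()
ToyPipeline(buggy).to("cuda", dtype="fp16")
try:
    toy_unscale(buggy)
    buggy_error = None
except ValueError as exc:
    buggy_error = str(exc)

# Validation that only moves the pipeline: the LoRA param stays fp32.
fixed = make_transformer()
ToyPipeline(fixed).to("cuda")
toy_unscale(fixed)  # does not raise

print(buggy.params["lora_A.weight"].dtype, fixed.params["lora_A.weight"].dtype)
```

In the toy, the frozen base weight is already fp16 before validation, so dropping the cast changes nothing for it. Only the trainable param is affected.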

Affected scripts & fix

I verified each script has all three conditions (fp16 cast_training_params guard + in-loop validation sharing the live transformer + the fp16 downcast in log_validation). None of these were covered by #13895, and none have an open issue/PR.
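
One way to spot the third condition (the downcast call) in a script's source is a small text scan. This is only a sketch and not how the scripts were verified for this PR. The regex, the function name `classify_validation_casts`, and the assumption that the call sits on one line are all illustrative, and the scan does not check the other two conditions.

```python
import re

# Assumed shape of the risky call; real scripts may format it differently.
CAST_CALL = re.compile(r"pipeline\.to\((?P<args>[^)]*\bdtype\s*=[^)]*)\)")


def classify_validation_casts(source: str):
    """Return (line_number, group) for each pipeline.to(...) call that passes dtype.

    "A" means the call also has a positional argument (assumed to be the device);
    "B" means dtype is the only argument.
    """
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        match = CAST_CALL.search(line)
        if match is None:
            continue
        args = [a.strip() for a in match.group("args").split(",") if a.strip()]
        has_positional = any("=" not in a for a in args)
        hits.append((lineno, "A" if has_positional else "B"))
    return hits


sample = """\
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.to(dtype=torch_dtype)
pipeline.to(accelerator.device)
"""
print(classify_validation_casts(sample))
```

The third sample line has no dtype argument, so it is not reported.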

Group A: pipeline.to(accelerator.device, dtype=torch_dtype) → drop dtype (keep the device move), exactly like #13895:

  • train_dreambooth_lora_flux.py
  • train_dreambooth_lora_flux_kontext.py
  • train_dreambooth_lora_qwen_image.py
  • train_dreambooth_lora_hidream.py
  • advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Group B: the cast is pipeline.to(dtype=torch_dtype) (no device move) immediately followed by pipeline.enable_model_cpu_offload(), so the cast line is simply removed. Frozen weights are already in weight_dtype (set by from_pretrained(torch_dtype=weight_dtype)), and the offload call handles device placement:

  • train_dreambooth_lora_z_image.py
  • train_dreambooth_lora_flux2.py
  • train_dreambooth_lora_flux2_img2img.py
  • train_dreambooth_lora_flux2_klein.py
  • train_dreambooth_lora_flux2_klein_img2img.py

(The torch_dtype parameter of log_validation is left in place for signature consistency, same as the merged base script.)

Testing

Same situation as #13895: this is a GPU-only codepath. On CPU, accelerator.scaler is None (the fp16 GradScaler is CUDA-only), so the example CI cannot reproduce the crash, which is also why the example suite has no --mixed_precision fp16 tests. The fix is mechanical and identical to the one already merged in #13895, where the mechanism was verified. ruff check and ruff format --check pass on all 10 changed files.
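
Since the crash itself needs the fp16 scaler, a possible proxy that does not depend on it is to check that trainable params are still fp32 after in-loop validation. The helper below is a hypothetical sketch and not part of this PR. `downcast_trainable_params` is an illustrative name, and the stand-in objects only mimic the `.requires_grad` and `.dtype` attributes that real parameters expose. Whether such a check is practical in the example CI is not established here.

```python
from types import SimpleNamespace


def downcast_trainable_params(named_params, expected_dtype):
    """Return names of trainable params whose dtype differs from expected_dtype.

    named_params is any iterable of (name, param) pairs where param exposes
    .requires_grad and .dtype, such as what Module.named_parameters() yields.
    """
    return [
        name
        for name, param in named_params
        if param.requires_grad and param.dtype != expected_dtype
    ]


# Stand-in objects so the sketch runs without torch. With torch, the call would
# look like downcast_trainable_params(transformer.named_parameters(), torch.float32).
params = [
    ("base.weight", SimpleNamespace(requires_grad=False, dtype="float16")),
    ("lora_A.weight", SimpleNamespace(requires_grad=True, dtype="float16")),
    ("lora_B.weight", SimpleNamespace(requires_grad=True, dtype="float32")),
]
offenders = downcast_trainable_params(params, "float32")
print(offenders)
```

An empty list after validation would mean no trainable param was downcast. Frozen fp16 weights are ignored on purpose.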

Who can review?

@sayakpaul (reviewed and merged #13895)

github-actions bot added the labels examples and size/S (PR with diff < 50 LOC) on Jun 9, 2026
…LoRA scripts

Follow-up to huggingface#13895, which fixed this for examples/dreambooth/train_dreambooth_lora.py.
The same fp16 footgun is present in the other DreamBooth LoRA training scripts: under
`--mixed_precision="fp16"`, `cast_training_params(..., dtype=torch.float32)` keeps the
trainable LoRA params in fp32, but `log_validation` rebuilds the in-loop validation
pipeline around the *live* training transformer (`transformer=unwrap_model(transformer)`)
and then casts it to fp16. That downcasts the fp32 LoRA params, so the next optimizer
step raises `ValueError: Attempting to unscale FP16 gradients`.

Apply the same fix across the remaining scripts:
- flux, flux_kontext, qwen_image, hidream, and advanced flux: drop `dtype=torch_dtype`
  from `pipeline.to(accelerator.device, ...)` (keep the device move), matching huggingface#13895.
- z_image and the flux2 variants: the cast is `pipeline.to(dtype=torch_dtype)` with no
  device move, immediately followed by `enable_model_cpu_offload()`, so just drop the
  cast line. Frozen weights already use `weight_dtype` and the offload call handles
  device placement.

The final (post-training) validation in every script builds a fresh pipeline from the
saved weights, so it is unaffected either way.

HaozheZhang6 force-pushed the fix-fp16-lora-validation-remaining branch from f39ffba to fb8c0f2 on June 10, 2026 03:15