Fix fp16 LoRA unscale crash after validation in remaining DreamBooth LoRA scripts #13899
Open
HaozheZhang6 wants to merge 1 commit into
…LoRA scripts

Follow-up to huggingface#13895, which fixed this for examples/dreambooth/train_dreambooth_lora.py. The same fp16 footgun is present in the other DreamBooth LoRA training scripts: under `--mixed_precision="fp16"`, `cast_training_params(..., dtype=torch.float32)` keeps the trainable LoRA params in fp32, but `log_validation` rebuilds the in-loop validation pipeline around the *live* training transformer (`transformer=unwrap_model(transformer)`) and then casts it to fp16. That downcasts the fp32 LoRA params, so the next optimizer step raises `ValueError: Attempting to unscale FP16 gradients`.

Apply the same fix across the remaining scripts:

- flux, flux_kontext, qwen_image, hidream, and advanced flux: drop `dtype=torch_dtype` from `pipeline.to(accelerator.device, ...)` (keep the device move), matching huggingface#13895.
- z_image and the flux2 variants: the cast is `pipeline.to(dtype=torch_dtype)` with no device move, immediately followed by `enable_model_cpu_offload()`, so just drop the cast line. Frozen weights already use `weight_dtype` and the offload call handles device placement.

The final (post-training) validation in every script builds a fresh pipeline from the saved weights, so it is unaffected either way.
What does this PR do?
Follow-up to #13895 (merged), which fixed the fp16 LoRA "unscale" crash in `examples/dreambooth/train_dreambooth_lora.py`. The same bug is present in the other DreamBooth LoRA training scripts; this PR applies the same fix to all of them.

Root cause (same as #13895)

Under `--mixed_precision="fp16"`, `cast_training_params(models, dtype=torch.float32)` keeps the trainable LoRA params in fp32. The in-loop validation pipeline is rebuilt around the live training transformer (`transformer=unwrap_model(transformer)`). `log_validation` then casts that pipeline to `torch_dtype` (= fp16), which downcasts the shared transformer's fp32 LoRA params back to fp16. The next optimizer step then raises `ValueError: Attempting to unscale FP16 gradients`.

The final (post-training) validation builds a fresh pipeline from the saved weights, so it is unaffected either way.
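The aliasing behind this can be reproduced on CPU without a `GradScaler`. The sketch below is only an illustration under stated assumptions: `TinyLoRALayer` and `FakePipeline` are hypothetical stand-ins (not diffusers classes) for a LoRA-wrapped transformer and for a pipeline that shares that live module. It shows that a dtype cast through the pipeline rewrites the trainable params the optimizer already holds, while a device-only move leaves them in fp32.

```python
import torch
from torch import nn


class TinyLoRALayer(nn.Module):
    """Hypothetical stand-in: a frozen base weight plus trainable LoRA adapters."""

    def __init__(self, dim: int = 4, rank: int = 2) -> None:
        super().__init__()
        self.base = nn.Linear(dim, dim, bias=False)
        self.lora_A = nn.Linear(dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, dim, bias=False)
        self.base.requires_grad_(False)  # only the adapters are trained


class FakePipeline:
    """Hypothetical stand-in for a pipeline built around the *live* module."""

    def __init__(self, transformer: nn.Module) -> None:
        self.transformer = transformer  # same object as the training model

    def to(self, device, dtype=None):
        if dtype is None:
            self.transformer.to(device)  # device move only
        else:
            self.transformer.to(device, dtype=dtype)  # also casts every param
        return self


def build_training_model() -> TinyLoRALayer:
    model = TinyLoRALayer()
    # Frozen weights use the low-precision dtype; the adapters stay fp32,
    # which is the state the training scripts set up before the loop.
    model.base.to(dtype=torch.float16)
    return model


def trainable_dtypes(model: nn.Module) -> set:
    return {p.dtype for p in model.parameters() if p.requires_grad}


# Buggy pattern: validation casts the shared module to fp16.
model = build_training_model()
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1
)
FakePipeline(model).to("cpu", dtype=torch.float16)
dtypes_after_cast = trainable_dtypes(model)
# The optimizer holds the same Parameter objects, so it sees the cast too.
optimizer_dtypes_after_cast = {
    p.dtype for group in optimizer.param_groups for p in group["params"]
}

# Fixed pattern: move the device only and leave dtypes untouched.
fixed_model = build_training_model()
FakePipeline(fixed_model).to("cpu")
dtypes_after_device_move = trainable_dtypes(fixed_model)

print("after dtype cast:", dtypes_after_cast)
print("after device-only move:", dtypes_after_device_move)
```

With fp16 trainable params, the gradients produced by the next backward pass are fp16 as well, which is the condition the fp16 `GradScaler` rejects when it unscales.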
Affected scripts & fix
I verified each script has all three conditions (fp16 `cast_training_params` guard + in-loop validation sharing the live transformer + the fp16 downcast in `log_validation`). None of these were covered by #13895, and none have an open issue/PR.

Group A: `pipeline.to(accelerator.device, dtype=torch_dtype)` → drop `dtype` (keep the device move), exactly like #13895:

- `train_dreambooth_lora_flux.py`
- `train_dreambooth_lora_flux_kontext.py`
- `train_dreambooth_lora_qwen_image.py`
- `train_dreambooth_lora_hidream.py`
- `advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py`

Group B: the cast is `pipeline.to(dtype=torch_dtype)` (no device move) immediately followed by `pipeline.enable_model_cpu_offload()`, so the cast line is simply removed. Frozen weights already use `weight_dtype` (from `from_pretrained(torch_dtype=weight_dtype)`) and the offload call handles device placement:

- `train_dreambooth_lora_z_image.py`
- `train_dreambooth_lora_flux2.py`
- `train_dreambooth_lora_flux2_img2img.py`
- `train_dreambooth_lora_flux2_klein.py`
- `train_dreambooth_lora_flux2_klein_img2img.py`

(The `torch_dtype` parameter of `log_validation` is left in place for signature consistency, same as the merged base script.)

Testing
Same situation as #13895: this is a GPU-only codepath. On CPU, `accelerator.scaler is None` (the fp16 `GradScaler` is CUDA-only), so the example CI cannot reproduce it, and the example suite has no `--mixed_precision fp16` tests for this reason. The fix is mechanical and identical to the one already merged in #13895; the mechanism is verified there.

`ruff check` and `ruff format --check` pass on all 10 changed files.

Before submitting
Who can review?
@sayakpaul (reviewed and merged #13895)