From fb8c0f2a1e35381390bff87d82cb580162cd2059 Mon Sep 17 00:00:00 2001
From: Haozhe Zhang
Date: Tue, 9 Jun 2026 10:20:07 -0700
Subject: [PATCH] Fix fp16 LoRA unscale crash after validation in remaining DreamBooth LoRA scripts

Follow-up to #13895, which fixed this for
examples/dreambooth/train_dreambooth_lora.py.

The same fp16 footgun is present in the other DreamBooth LoRA training
scripts: under `--mixed_precision="fp16"`,
`cast_training_params(..., dtype=torch.float32)` keeps the trainable LoRA
params in fp32, but `log_validation` rebuilds the in-loop validation
pipeline around the *live* training transformer
(`transformer=unwrap_model(transformer)`) and then casts it to fp16. That
downcasts the fp32 LoRA params, so the next optimizer step raises
`ValueError: Attempting to unscale FP16 gradients`.

Apply the same fix across the remaining scripts:

- flux, flux_kontext, qwen_image, hidream, and advanced flux: drop
  `dtype=torch_dtype` from `pipeline.to(accelerator.device, ...)` (keep the
  device move), matching #13895.
- z_image and the flux2 variants: the cast is
  `pipeline.to(dtype=torch_dtype)` with no device move, immediately
  followed by `enable_model_cpu_offload()`, so just drop the cast line.
  Frozen weights already use `weight_dtype` and the offload call handles
  device placement.

The final (post-training) validation in every script builds a fresh
pipeline from the saved weights, so it is unaffected either way.
--- .../train_dreambooth_lora_flux_advanced.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux2.py | 1 - examples/dreambooth/train_dreambooth_lora_flux2_img2img.py | 1 - examples/dreambooth/train_dreambooth_lora_flux2_klein.py | 1 - .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 1 - examples/dreambooth/train_dreambooth_lora_flux_kontext.py | 2 +- examples/dreambooth/train_dreambooth_lora_hidream.py | 2 +- examples/dreambooth/train_dreambooth_lora_qwen_image.py | 2 +- examples/dreambooth/train_dreambooth_lora_z_image.py | 1 - 10 files changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 005f4303c3c1..8edbb1553bcd 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -243,7 +243,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 5fb666a4d42c..4fc5182debc6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -197,7 +197,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 886e251937e6..8c14bddbfa29 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -200,7 +200,6 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(dtype=torch_dtype) pipeline.enable_model_cpu_offload() pipeline.set_progress_bar_config(disable=True) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 477697fadb64..e23a913e4ae0 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -200,7 +200,6 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(dtype=torch_dtype) pipeline.enable_model_cpu_offload() pipeline.set_progress_bar_config(disable=True) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 7eb627e4bd1d..47a758d8b1a2 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -200,7 +200,6 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) - pipeline = pipeline.to(dtype=torch_dtype) pipeline.enable_model_cpu_offload() pipeline.set_progress_bar_config(disable=True) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 63862eed9f1e..1c78882b6e8a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -200,7 +200,6 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(dtype=torch_dtype) pipeline.enable_model_cpu_offload() pipeline.set_progress_bar_config(disable=True) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 97e0414635fb..e1b90b2947f9 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -200,7 +200,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) pipeline_args_cp = pipeline_args.copy() diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index c87d96366c6d..3a5eedba1c02 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -204,7 +204,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 573e0bf53f8a..a1e8439e6d10 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -192,7 +192,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index a54c84b0798f..3d62fdaeff0e 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -199,7 +199,6 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(dtype=torch_dtype) pipeline.enable_model_cpu_offload() pipeline.set_progress_bar_config(disable=True)