fix: correct mask shape for masked flash attention#1625
Conversation
The flash-attention branch in ggml_ext_attention_ext passed the attention mask to ggml_flash_attn_ext after a ggml_transpose, turning a [n_kv, n_q] (or query-broadcast [n_kv, 1]) mask into [n_q/1, n_kv]. ggml_flash_attn_ext expects the mask as a contiguous F16 tensor shaped [n_kv, n_q, ...] and does not broadcast the query dimension, so the transposed mask was misindexed by the kernel and produced NaN / all-blank output. ggml's mask shape assertion (ggml_can_repeat_rows) is currently disabled, so this ran silently instead of erroring. Models that need a real attention mask with flash attention enabled (e.g. Chroma, which passes a T5 padding mask of shape [n_kv, 1] broadcast over queries) therefore rendered a blank image whenever --diffusion-fa was set. Drop the transpose and, for query-broadcast masks, materialize the query dimension to L_q with ggml_repeat before the F16 cast. The change is guarded by mask != nullptr, so the common no-mask flash-attention path is unchanged. Verified on Chroma1-HD at 1024x1024: masked flash attention now produces a correct image matching the non-flash-attention reference on both the CUDA and Vulkan backends.
Seems to be working fine for me (Chroma1-HD-Flash-Q4_0, RX 7600 XT (gfx1102), ROCm 6.4.4 on Linux):
By the way, you may also want to remove the warning: stable-diffusion.cpp/src/stable-diffusion.cpp Line 580 in 19bdfe2 |
The masked-flash-attention fix on this branch makes Chroma + flash attention render correctly, so the warning telling users it is unsupported and to disable flash attention now gives the wrong advice. Drop it. (Spotted by @wbruna in review.)
|
Thanks for testing this @wbruna, and good catch on the warning — it describes exactly the broken-output case this PR fixes, so after the fix the "currently unsupported / disable it" advice is wrong. I've removed it (18cd3e8). The HIP corruption you're not seeing is a separate, ggml-side issue, independent of this mask fix. Chroma is non-GQA, so without rocWMMA flash-attention the kernel selector falls to the tile kernel, which on fast-fp16 AMD keeps Q/K in fp16 and ignores It's out of scope here (without this PR the HIP path is just blank); we'll follow up ggml-side. |


Summary
Masked flash attention (
--diffusion-fawith a model that supplies an attention mask) produced an all-blank image. The mask was passed toggml_flash_attn_extafter aggml_transpose, which yields the wrong shape; the kernel then misreads it and outputs NaN/blank.This surfaced with Chroma, which supplies a T5 padding mask of shape
[n_kv, 1](a per-key mask broadcast over queries). With--diffusion-faenabled, every Chroma generation came out blank.Root cause
In
ggml_ext_attention_ext's flash-attention path (build_kqv):ggml_flash_attn_extexpects the mask as a contiguous F16 tensor shaped[n_kv, n_q, (heads), (batch)](ne0 = key length, ne1 = query length) and does not broadcast the query dimension. The manual-attention path adds the mask withggml_add, which does broadcast a[n_kv, 1]mask over queries, so the non-flash path is correct; the flash path was not. Two problems:ggml_transposeputs the key length on the wrong axis ([n_kv, n_q]→[n_q, n_kv], and a query-broadcast[n_kv, 1]→[1, n_kv]).ne1 == 1) is never expanded to the real query count, which the flash kernel requires.This runs silently rather than asserting because ggml's mask shape check is currently disabled (
// GGML_ASSERT(ggml_can_repeat_rows(mask, qk));, still commented out as of ggml v0.14.0). Masked flash attention was effectively never exercised before (there is a// TODO: figure out if we can bend t5 to work tooright here), so the latent bug went unnoticed until Chroma needed both a mask and flash attention.Fix
Drop the transpose, and for a query-broadcast mask materialize the query dimension to
L_qwithggml_repeatbefore the F16 cast:The whole change is inside
if (mask_in != nullptr), so the common no-mask flash-attention path (e.g. ordinary Flux text-to-image) is byte-for-byte unchanged.Testing
Chroma1-HD, 1024×1024, 20 steps, euler, cfg 4.0, fixed seed — three-way check: no-FA+mask (reference) / FA+mask (this fix) / FA+no-mask.
Verified the fix is required on both the currently-pinned ggml v0.12.0 and current ggml main v0.14.0: built with and without the fix against each — masked flash attention is blank on both ggml versions without it (the mask-shape assert is disabled in both).
Note on the ROCm/HIP backend
On the ROCm/HIP backend, the same (now correctly shaped) mask does not produce a blank image but a subtly corrupted one (a doubled/ghosted subject), whereas CUDA and Vulkan consume the identical mask tensor correctly. This points to a separate issue inside ggml's HIP flash-attn kernel, independent of and not introduced by this PR (without this PR the HIP path is blank too). It reproduces byte-for-byte identically on both ggml v0.12.0 and current ggml main v0.14.0, so the recent flash-attn rework (largely CUDA/RDNA3-focused) does not cover this path. That likely warrants a separate ggml-side report and is out of scope here.
Repro (before the fix)