fix: correct mask shape for masked flash attention #1625

Open
RapidMark wants to merge 2 commits into leejet:master from CloudhandsAI:cloudhands/fattn-masked-attention-fix


@RapidMark

Contributor

Summary

Masked flash attention (--diffusion-fa with a model that supplies an attention mask) produced an all-blank image. The mask was passed to ggml_flash_attn_ext after a ggml_transpose, which yields the wrong shape; the kernel then misreads it and outputs NaN/blank.

This surfaced with Chroma, which supplies a T5 padding mask of shape [n_kv, 1] (a per-key mask broadcast over queries). With --diffusion-fa enabled, every Chroma generation came out blank.

Root cause

In ggml_ext_attention_ext's flash-attention path (build_kqv):

if (mask_in != nullptr) {
    mask_in = ggml_transpose(ctx, mask_in);   // [n_kv, n_q] -> [n_q, n_kv]  (wrong)
}
if (mask_in != nullptr) {
    mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
}
auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, ...);

ggml_flash_attn_ext expects the mask as a contiguous F16 tensor shaped [n_kv, n_q, (heads), (batch)] (ne0 = key length, ne1 = query length) and does not broadcast the query dimension. The manual-attention path adds the mask with ggml_add, which does broadcast a [n_kv, 1] mask over queries, so the non-flash path is correct; the flash path was not. Two problems:

  1. ggml_transpose puts the key length on the wrong axis ([n_kv, n_q] -> [n_q, n_kv], and a query-broadcast [n_kv, 1] -> [1, n_kv]).
  2. A query-broadcast mask (ne1 == 1) is never expanded to the real query count, which the flash kernel requires.
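
The two shape problems above can be checked with plain shape arithmetic. The following is a minimal sketch that does not call ggml: `Shape2D`, `transposed`, and `is_valid_flash_mask` are hypothetical helpers introduced only for illustration, and the validity rule simply restates the [n_kv, n_q] requirement described in this PR rather than reproducing any check inside ggml.

```cpp
#include <cstdint>

// Hypothetical stand-in for the first two dims of a tensor (ne[0], ne[1]).
// Not a ggml type; illustration only.
struct Shape2D {
    int64_t ne0;  // fastest-varying dimension
    int64_t ne1;
};

// What a transpose does to the shape: swap the first two dims.
inline Shape2D transposed(Shape2D s) {
    return Shape2D{s.ne1, s.ne0};
}

// Assumed rule, restating the PR text: the flash-attention mask needs
// ne0 == key length and ne1 == query length, with no query broadcast.
inline bool is_valid_flash_mask(Shape2D mask, int64_t n_kv, int64_t n_q) {
    return mask.ne0 == n_kv && mask.ne1 == n_q;
}
```

With, say, n_kv = 77 and n_q = 4096, a query-broadcast mask `Shape2D{77, 1}` fails this check both as-is (ne1 is 1, not n_q) and after `transposed` (which yields {1, 77}), while `Shape2D{77, 4096}` passes. Note that a shape check alone would not catch a transposed mask when n_kv happens to equal n_q, since the swapped shape is then identical.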

This runs silently rather than asserting because ggml's mask shape check is currently disabled (// GGML_ASSERT(ggml_can_repeat_rows(mask, qk));, still commented out as of ggml v0.14.0). Masked flash attention was effectively never exercised before (there is a // TODO: figure out if we can bend t5 to work too right here), so the latent bug went unnoticed until Chroma needed both a mask and flash attention.

Fix

Drop the transpose, and for a query-broadcast mask materialize the query dimension to L_q with ggml_repeat before the F16 cast:

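if (mask_in != nullptr) {
    // No transpose: ggml_flash_attn_ext wants ne[0] = n_kv, ne[1] = n_q.
    if (mask_in->ne[1] != L_q) {
        // Query-broadcast mask (e.g. ne[1] == 1): repeat it along ne[1] up to L_q.
        // The new tensor only serves as the target shape for ggml_repeat.
        mask_in = ggml_repeat(ctx, mask_in,
            ggml_new_tensor_4d(ctx, mask_in->type, mask_in->ne[0], L_q, mask_in->ne[2], mask_in->ne[3]));
    }
    mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
}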

The whole change is inside if (mask_in != nullptr), so the common no-mask flash-attention path (e.g. ordinary Flux text-to-image) is byte-for-byte unchanged.
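
To make "materialize the query dimension" concrete, here is a minimal sketch of the same idea on a plain buffer. It is not how ggml_repeat is implemented; `repeat_mask_over_queries` is a hypothetical helper, and the layout (key index fastest-varying, one row per query) is assumed from the [n_kv, n_q] convention described above.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: expand a per-key mask of length n_kv into a dense
// buffer laid out as n_q rows of n_kv values (key index varies fastest),
// so that every query row carries its own copy of the key mask.
std::vector<float> repeat_mask_over_queries(const std::vector<float>& key_mask,
                                            std::size_t n_q) {
    const std::size_t n_kv = key_mask.size();
    std::vector<float> dense(n_kv * n_q);
    for (std::size_t q = 0; q < n_q; ++q) {
        for (std::size_t k = 0; k < n_kv; ++k) {
            // Element (k, q) lives at offset q * n_kv + k.
            dense[q * n_kv + k] = key_mask[k];
        }
    }
    return dense;
}
```

A consumer that does not broadcast can then read row q at offset q * n_kv and always find a full set of n_kv mask values, which is the property the broadcast [n_kv, 1] mask lacked.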

Testing

Chroma1-HD, 1024×1024, 20 steps, euler, cfg 4.0, fixed seed — three-way check: no-FA+mask (reference) / FA+mask (this fix) / FA+no-mask.

  • CUDA: with the fix, FA+mask renders a correct image identical to the non-flash-attention reference (was blank before), ~30% faster than the non-flash path.
  • Vulkan (AMD RDNA4): with the fix, FA+mask renders a correct image (was blank before).
  • The no-mask flash path and the non-flash path are unaffected.

Verified the fix is required on both the currently-pinned ggml v0.12.0 and current ggml main v0.14.0: built with and without the fix against each — masked flash attention is blank on both ggml versions without it (the mask-shape assert is disabled in both).

Note on the ROCm/HIP backend

On the ROCm/HIP backend, the same (now correctly shaped) mask does not produce a blank image but a subtly corrupted one (a doubled/ghosted subject), whereas CUDA and Vulkan consume the identical mask tensor correctly. This points to a separate issue inside ggml's HIP flash-attn kernel, independent of and not introduced by this PR (without this PR the HIP path is blank too). It reproduces byte-for-byte identically on both ggml v0.12.0 and current ggml main v0.14.0, so the recent flash-attn rework (largely CUDA/RDNA3-focused) does not cover this path. That likely warrants a separate ggml-side report and is out of scope here.

Repro (before the fix)

sd-cli --diffusion-model Chroma1-HD.safetensors \
       --t5xxl t5xxl_fp16.safetensors --vae ae.safetensors \
       -p "a corgi astronaut on the moon" --cfg-scale 4.0 --steps 20 \
       -W 1024 -H 1024 --seed 42 --sampling-method euler \
       --diffusion-fa -o out.png
# before: blank white image;  after: correct image == non-FA reference

The flash-attention branch in ggml_ext_attention_ext passed the attention
mask to ggml_flash_attn_ext after a ggml_transpose, turning a [n_kv, n_q]
(or query-broadcast [n_kv, 1]) mask into [n_q, n_kv] (or [1, n_kv]).
ggml_flash_attn_ext expects the mask as a contiguous F16 tensor shaped
[n_kv, n_q, ...] and does
not broadcast the query dimension, so the transposed mask was misindexed by
the kernel and produced NaN / all-blank output. ggml's mask shape assertion
(ggml_can_repeat_rows) is currently disabled, so this ran silently instead
of erroring.

Models that need a real attention mask with flash attention enabled (e.g.
Chroma, which passes a T5 padding mask of shape [n_kv, 1] broadcast over
queries) therefore rendered a blank image whenever --diffusion-fa was set.

Drop the transpose and, for query-broadcast masks, materialize the query
dimension to L_q with ggml_repeat before the F16 cast. The change is guarded
by mask != nullptr, so the common no-mask flash-attention path is unchanged.

Verified on Chroma1-HD at 1024x1024: masked flash attention now produces a
correct image matching the non-flash-attention reference on both the CUDA
and Vulkan backends.
@wbruna

wbruna commented Jun 9, 2026

Contributor

> On the ROCm/HIP backend, the same (now correctly shaped) mask does not produce a blank image but a subtly corrupted one (a doubled/ghosted subject)

Seems to be working fine for me (Chroma1-HD-Flash-Q4_0, RX 7600 XT (gfx1102), ROCm 6.4.4 on Linux):

[Images: ROCm output (chroma_1781040282) and Vulkan output (chroma_1781040444)]

By the way, you may also want to remove the warning:

"!!!It looks like you are using Chroma with flash attention. "

The masked-flash-attention fix on this branch makes Chroma + flash
attention render correctly, so the warning telling users it is
unsupported and to disable flash attention now gives the wrong advice.
Drop it. (Spotted by @wbruna in review.)
@RapidMark

RapidMark commented Jun 10, 2026

Contributor Author

Thanks for testing this @wbruna, and good catch on the warning — it describes exactly the broken-output case this PR fixes, so after the fix the "currently unsupported / disable it" advice is wrong. I've removed it (18cd3e8).

The HIP corruption you're not seeing is a separate, ggml-side issue, independent of this mask fix. Chroma is non-GQA, so without rocWMMA flash-attention the kernel selector falls to the tile kernel, which on fast-fp16 AMD keeps Q/K in fp16 and ignores GGML_PREC_F32. It's scale-dependent: with the same rocWMMA-off HIP build I get clean output at 768² (your resolution) but a ghosted subject at 1024² on the full de-distilled fp16 model — same model/seed/prompt, only the resolution differs. So your clean 768² run is consistent with the resolution being below where it kicks in, rather than anything specific to RDNA3. (A CPU-backend reference of the 1024² job is clean, and a per-node CPU-vs-HIP compare puts the first real divergence on the FLASH_ATTN_EXT node.) Building the HIP runtime with -DGGML_HIP_ROCWMMA_FATTN=ON routes it to the fp32 WMMA kernel and clears it at any size.

It's out of scope here (without this PR the HIP path is just blank); we'll follow up ggml-side.
