Skip to content

[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20086

Open
JulianCloudNTH wants to merge 11 commits into
gh/JulianCloudNTH/19/basefrom
gh/JulianCloudNTH/19/head
Open

[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20086
JulianCloudNTH wants to merge 11 commits into
gh/JulianCloudNTH/19/basefrom
gh/JulianCloudNTH/19/head

Conversation

@JulianCloudNTH

@JulianCloudNTH JulianCloudNTH commented Jun 6, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Adds the fused sdpa_with_kv_cache op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the update_cache op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).

input_pos is consumed dynamically via the SymInt mechanism: the op reads symint_buffer() as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing input_pos) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.

Tests live in the stacked test-suite diff above (clean op diff here).

Authored with assistance from Claude.
@exported-using-ghexport

Differential Revision: D107595125

Differential Revision: D107595125

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 6, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20086

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 08abf3b with merge base af92b60 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

github-actions Bot commented Jun 6, 2026

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude review

@claude

claude Bot commented Jun 9, 2026

Copy link
Copy Markdown

Claude finished @JulianCloudNTH's task in 3m 59s —— View job


Code Review: Fused SDPA (sdpa_with_kv_cache) for WebGPU

  • Read changed files (Sdpa.cpp, 3 WGSL kernels + headers, build, WebGPUGraph.h)
  • Analyze correctness, performance, and style
  • Post review feedback

Reviewed against origin/gh/JulianCloudNTH/19/base. Overall this is clean, well-structured work — the flat-index layouts are self-consistent across the three kernels, the GQA/causal-mask math is correct, and the dynamic input_pos resize path is handled carefully. Findings below, ordered by significance.


Correctness (verified, no blocking bugs found)

I traced the index math across all three kernels and the C++ dispatch sizing; they agree:

  • Layouts consistent. QK writes attn_weights as [Hq, S, context_len] (idx = h·S·ctx + s·ctx + c); softmax reads with row_width = context_len; AV reads aw_base = h·S·ctx + s·ctx. All three use the context_len stride, not Cmax, so rows pack contiguously into the front of the Cmax-capacity scratch and the unused tail is harmless. ✅
  • Resize hook updates exactly what varies with context_len. Only the QK dispatch's workgroup_count_x depends on context_len (= Hq·S·ctx) and it is rewritten (Sdpa.cpp:521). Softmax (Hq·S rows) and AV (S·Hq·D) dispatch counts are context_len-independent, so leaving their counts fixed while only rewriting their uniforms is correct. ✅
  • Prefill/first-run path is sound. The build placeholder input_pos = read_symint(...) means that when the first real input_pos equals the build value (e.g. prefill at 0), set_symint won't mark it dirty and the hook won't fire — but the baked build params already match, so output is correct. Nicely consistent. ✅
  • Masked-row softmax is safe. Causal guarantees ≥1 unmasked entry per row (c ≤ s + input_pos, and context_len = S + input_pos), so row_max is finite and exp(NEG_INF − row_max) → 0. The row_sum > 0 guard in sdpa_softmax.wgsl:91 is belt-and-suspenders. ✅
  • Idle-thread reduction. When context_len < 64, idle lanes seed local_max = NEG_INF / local_sum = 0 and still hit every workgroupBarrier() in uniform control flow. Correct. ✅

Suggestions (non-blocking)

1. Pipeline/shader/BGL caches go unused — Sdpa.cpp:93 build_dispatch.
The graph exposes get_or_create_shader/pipeline/bgl (WebGPUGraph.h:181-193), but build_dispatch creates a fresh shader module, BGL, pipeline layout, and pipeline on every call. The two update_cache dispatches (K and V) compile the identical kernel twice, and SDPA shaders rebuild from scratch even though other ops in the graph may use them. Build-time only, but routing through the existing caches would cut redundant compilation. Fix this →

2. Full O(Hq·S·context_len) attention matrix in scratch, ×2 — Sdpa.cpp:303-311.
Two scratch buffers are each Hq·S·Cmax·4 bytes in the dynamic case. Fine for decode (S=1), but for a long prefill this is large — e.g. Hq=32, S=2048, Cmax=2048 → ~512 MB each. This mirrors the Vulkan reference (no flash-attention tiling), so it's a known limitation rather than a defect, but it's worth a comment noting prefill memory scales as Hq·S·Cmax so future readers don't get surprised.

3. No bound check on kMaxEntriesSdpa.cpp:113.
kMaxEntries = 8 and n_storage is currently ≤ 3, so n_storage + 1 is safe today. A static_assert/runtime guard (or just a comment that callers must keep n_storage + 1 ≤ kMaxEntries) would make the silent buffer-overflow risk explicit if a future kernel adds bindings.

4. K/V sequence length is assumed equal to q's SSdpa.cpp:327.
kv_numel and the update_cache dispatch derive the new-token count from q.dims (S), never from k/v. If k/v ever had a different seq dim this would silently mis-copy. A k.dims[kn-3] == S check alongside the existing validation would harden this.

5. Softmax workgroup size isn't clamped to device limits.
QK/AV use utils::clamp_workgroup_size and pass it as an override; sdpa_softmax.wgsl hardcodes WG_SIZE = 64u with wg_size = 0 sentinel (Sdpa.cpp:428). Since the WebGPU spec floor for maxComputeInvocationsPerWorkgroup is 256, 64 is always safe — so this is fine as-is, just noting the intentional asymmetry in case it ever confuses.

Style

Matches the repo's WebGPU conventions well — minimal comments, fail-loud validation mirroring the Vulkan reference, explicit param structs with static_assert size checks. No style concerns. Note the op-level tests are in the stacked diff (#20087), so I couldn't validate runtime behavior here.


· branch gh/JulianCloudNTH/19/head

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant