[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20086
JulianCloudNTH wants to merge 11 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20086
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit 08abf3b with merge base af92b60:
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
@claude review
Claude finished @JulianCloudNTH's task in 3m 59s
Code Review: Fused SDPA (
Stack from ghstack (oldest at bottom):
Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, and attention-output sub-kernels over the KV cache), composing the three enablers below it in the stack: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).

`input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch buffers and dispatches for the max context length, and registers a resize hook so that a single delegate can run an autoregressive decode loop (feed only the new token plus an advancing `input_pos`) instead of a fixed, baked-in position. This mirrors the Vulkan design, where a SymInt is a live uniform buffer.

Tests live in the stacked test-suite diff above (clean op diff here).
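For readers who want the math spelled out, here is a minimal CPU reference sketch of the three passes (QK weights, softmax, AV) with GQA head mapping and a causal mask driven by a runtime `input_pos`. It is not the WGSL from this diff. The function name `sdpa_with_kv_cache_ref`, the batch-of-one restriction, and the contiguous `[seq, heads, head_dim]` / `[max_ctx, kv_heads, head_dim]` layouts are assumptions made for illustration, and the cache is assumed to already hold the new tokens' K/V (the job of `update_cache`).

```python
import math


def sdpa_with_kv_cache_ref(q, k_cache, v_cache, input_pos, seq_len,
                           n_heads, n_kv_heads, head_dim):
    """Hypothetical reference: flat fp lists, batch size 1.

    q:        seq_len * n_heads * head_dim values, layout [seq, head, dim]
    k_cache:  max_ctx * n_kv_heads * head_dim values, layout [pos, kv_head, dim]
    v_cache:  same layout as k_cache
    """
    group = n_heads // n_kv_heads      # GQA: query heads per KV head
    ctx = input_pos + seq_len          # cache positions visible this step
    scale = 1.0 / math.sqrt(head_dim)

    # Pass 1: QK attention weights into a scratch buffer [head, seq, ctx].
    weights = [0.0] * (n_heads * seq_len * ctx)
    for h in range(n_heads):
        kv_h = h // group
        for s in range(seq_len):
            for c in range(ctx):
                w_idx = (h * seq_len + s) * ctx + c
                if c > input_pos + s:  # causal mask: no attending to the future
                    weights[w_idx] = -math.inf
                    continue
                q_base = (s * n_heads + h) * head_dim
                k_base = (c * n_kv_heads + kv_h) * head_dim
                dot = sum(q[q_base + d] * k_cache[k_base + d]
                          for d in range(head_dim))
                weights[w_idx] = dot * scale

    # Pass 2: numerically stable softmax over each [ctx]-long row.
    for row in range(n_heads * seq_len):
        base = row * ctx
        row_max = max(weights[base:base + ctx])
        exps = [math.exp(w - row_max) for w in weights[base:base + ctx]]
        total = sum(exps)
        for c in range(ctx):
            weights[base + c] = exps[c] / total

    # Pass 3: attention output, written back in [seq, head, dim] layout.
    out = [0.0] * (seq_len * n_heads * head_dim)
    for h in range(n_heads):
        kv_h = h // group
        for s in range(seq_len):
            for d in range(head_dim):
                acc = sum(
                    weights[(h * seq_len + s) * ctx + c]
                    * v_cache[(c * n_kv_heads + kv_h) * head_dim + d]
                    for c in range(ctx))
                out[(s * n_heads + h) * head_dim + d] = acc
    return out
```

In a decode loop under these assumptions, the caller would pass `seq_len=1` and increment `input_pos` by one per generated token, while `k_cache` and `v_cache` stay allocated at the max context length. Only `ctx` changes between steps, which is the quantity the description says is read from the live SymInt uniform.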
Authored with assistance from Claude.
@exported-using-ghexport
Differential Revision: D107595125