[ExecuTorch][WebGPU] Add update_cache op (llama.update_cache)#20083
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20083
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review |
|
Claude finished @JulianCloudNTH's task in 1m 56s —— View job Code Review:
|
|
Claude finished @JulianCloudNTH's task in 1m 56s —— View job Code Review:
|
2647a36
into
gh/JulianCloudNTH/16/base
Pull Request resolved: #20083 Add `llama.update_cache.default`: an in-place KV-cache write. The shader scatters the new K/V (`[1,S,H,D]`) into the cache (`[1,Cmax,H,D]`) at `dst_offset = input_pos*n_heads*head_dim`, bounds-checked against the cache size. The handler validates shape (batch==1, matching n_heads/head_dim) and sizes the 1D dispatch from the device limit via `WebGPUUtils` before allocating. Mirrors the Vulkan `sdpa_kv_cache_update` reference. The export/delegation test is the follow-up diff stacked directly above. Authored with assistance from Claude. ghstack-source-id: 392019030 @exported-using-ghexport Differential Revision: [D107547308](https://our.internmc.facebook.com/intern/diff/D107547308/)
Stack from ghstack (oldest at bottom):
Add
llama.update_cache.default: an in-place KV-cache write. The shader scatters the new K/V ([1,S,H,D]) into the cache ([1,Cmax,H,D]) atdst_offset = input_pos*n_heads*head_dim, bounds-checked against the cache size. The handler validates shape (batch==1, matching n_heads/head_dim) and sizes the 1D dispatch from the device limit viaWebGPUUtilsbefore allocating. Mirrors the Vulkansdpa_kv_cache_updatereference. The export/delegation test is the follow-up diff stacked directly above. Authored with assistance from Claude.@exported-using-ghexport
Differential Revision: D107547308
Differential Revision: D107547308