From bdd4d006f907ff2f3a9eff42b4041130d08c633b Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 10 Jun 2026 13:19:19 -0700 Subject: [PATCH 1/3] [ExecuTorch][WebGPU] Add update_cache op (llama.update_cache) Pull Request resolved: https://github.com/pytorch/executorch/pull/20083 Add `llama.update_cache.default`: an in-place KV-cache write. The shader scatters the new K/V (`[1,S,H,D]`) into the cache (`[1,Cmax,H,D]`) at `dst_offset = input_pos*n_heads*head_dim`, bounds-checked against the cache size. The handler validates shape (batch==1, matching n_heads/head_dim) and sizes the 1D dispatch from the device limit via `WebGPUUtils` before allocating. Mirrors the Vulkan `sdpa_kv_cache_update` reference. The export/delegation test is the follow-up diff stacked directly above. Authored with assistance from Claude. ghstack-source-id: 392019030 @exported-using-ghexport Differential Revision: [D107547308](https://our.internmc.facebook.com/intern/diff/D107547308/) --- backends/webgpu/CMakeLists.txt | 1 + .../runtime/ops/update_cache/UpdateCache.cpp | 198 ++++++++++++++++++ .../ops/update_cache/update_cache.wgsl | 24 +++ .../ops/update_cache/update_cache_wgsl.h | 48 +++++ 4 files changed, 271 insertions(+) create mode 100644 backends/webgpu/runtime/ops/update_cache/UpdateCache.cpp create mode 100644 backends/webgpu/runtime/ops/update_cache/update_cache.wgsl create mode 100644 backends/webgpu/runtime/ops/update_cache/update_cache_wgsl.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index b6b41fb6587..5e6e1d7bf35 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -33,6 +33,7 @@ set(WEBGPU_SRCS runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp runtime/ops/rms_norm/RmsNorm.cpp + runtime/ops/update_cache/UpdateCache.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/ops/update_cache/UpdateCache.cpp b/backends/webgpu/runtime/ops/update_cache/UpdateCache.cpp new 
file mode 100644 index 00000000000..dc23a45eb91 --- /dev/null +++ b/backends/webgpu/runtime/ops/update_cache/UpdateCache.cpp @@ -0,0 +1,198 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Uniform buffer layout matching the WGSL Params struct (16-byte aligned). +struct UpdateCacheParams { + uint32_t numel; + uint32_t dst_offset; + uint32_t cache_numel; + uint32_t _pad0; +}; +static_assert( + sizeof(UpdateCacheParams) == 16, + "UpdateCacheParams must be 16 bytes"); + +// llama.update_cache.default args: [value, cache, input_pos, out]. +void update_cache_impl(WebGPUGraph& graph, const std::vector& args) { + const int value_id = args.at(0); + const int cache_id = args.at(1); + const int input_pos_id = args.at(2); + + WGPUDevice device = graph.device(); + + const auto& value_tensor = graph.get_tensor(value_id); + const auto& cache_tensor = graph.get_tensor(cache_id); + if (value_tensor.dims.size() < 4 || cache_tensor.dims.size() < 4 || + value_tensor.nbytes == 0) { + throw std::runtime_error("WebGPU update_cache: expects 4D value and cache"); + } + + uint64_t value_numel = 1; + for (int64_t d : value_tensor.dims) { + value_numel *= static_cast(d); + } + // fp32-only shader: bail if bytes don't match an fp32 element count. + if (value_tensor.nbytes != value_numel * sizeof(float)) { + throw std::runtime_error( + "WebGPU update_cache: fp32-only (byte-size mismatch)"); + } + + const size_t ndim = value_tensor.dims.size(); + const size_t cndim = cache_tensor.dims.size(); + // Mirror Vulkan update_cache_impl shape guards (backends/vulkan SDPA.cpp). 
+ if (value_tensor.dims[ndim - 4] != 1 || cache_tensor.dims[cndim - 4] != 1) { + throw std::runtime_error("WebGPU update_cache: batch must be 1"); + } + if (value_tensor.dims[ndim - 1] != cache_tensor.dims[cndim - 1]) { + throw std::runtime_error("WebGPU update_cache: head_dim mismatch"); + } + if (value_tensor.dims[ndim - 2] != cache_tensor.dims[cndim - 2]) { + throw std::runtime_error("WebGPU update_cache: n_heads mismatch"); + } + const uint64_t head_dim = static_cast(value_tensor.dims[ndim - 1]); + const uint64_t n_heads = static_cast(value_tensor.dims[ndim - 2]); + + uint64_t cache_numel = 1; + for (int64_t d : cache_tensor.dims) { + cache_numel *= static_cast(d); + } + + if (graph.get_value_type(input_pos_id) != WebGPUGraph::ValueType::Int) { + throw std::runtime_error( + "WebGPU update_cache: input_pos must be Int (SymInt not yet supported)"); + } + const int64_t input_pos = graph.get_int(input_pos_id); + if (input_pos < 0) { + throw std::runtime_error( + "WebGPU update_cache: input_pos must be non-negative"); + } + + // Bound input_pos in u64 so the u32 param downcasts cannot overflow/truncate. + const uint64_t stride = n_heads * head_dim; + if (cache_numel > UINT32_MAX || value_numel > cache_numel || + static_cast(input_pos) > (cache_numel - value_numel) / stride) { + throw std::runtime_error( + "WebGPU update_cache: input_pos writes past cache capacity"); + } + const uint64_t dst_offset = static_cast(input_pos) * stride; + + UpdateCacheParams params = {}; + params.numel = static_cast(value_numel); + params.dst_offset = static_cast(dst_offset); + params.cache_numel = static_cast(cache_numel); + + // Validate dispatch against device limits before allocating GPU objects. 
+ const uint32_t wg_size = + utils::clamp_workgroup_size(device, kUpdateCacheWorkgroupSizeX); + const uint32_t workgroup_count_x = utils::compute_1d_workgroup_count( + device, params.numel, wg_size, "update_cache"); + + WGPUBufferDescriptor uniform_desc = {}; + uniform_desc.size = sizeof(UpdateCacheParams); + uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + uniform_desc.mappedAtCreation = true; + WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc); + void* mapped = + wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(UpdateCacheParams)); + std::memcpy(mapped, ¶ms, sizeof(UpdateCacheParams)); + wgpuBufferUnmap(uniform_buffer); + + graph.add_uniform_buffer_bytes(sizeof(UpdateCacheParams)); + + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kUpdateCacheWGSL, WGPU_STRLEN}; + + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + // Bind group layout: cache (rw storage) + value (ro storage) + params. 
+ WGPUBindGroupLayoutEntry entries[3] = {}; + entries[0].binding = 0; + entries[0].visibility = WGPUShaderStage_Compute; + entries[0].buffer.type = WGPUBufferBindingType_Storage; + entries[1].binding = 1; + entries[1].visibility = WGPUShaderStage_Compute; + entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + entries[2].binding = 2; + entries[2].visibility = WGPUShaderStage_Compute; + entries[2].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 3; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUConstantEntry wg_size_constant = {}; + wg_size_constant.key = {"wg_size", WGPU_STRLEN}; + wg_size_constant.value = static_cast(wg_size); + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + pipeline_desc.compute.constantCount = 1; + pipeline_desc.compute.constants = &wg_size_constant; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + WGPUBindGroupEntry bg_entries[3] = {}; + bg_entries[0].binding = 0; + bg_entries[0].buffer = cache_tensor.buffer; + bg_entries[0].size = cache_tensor.nbytes; + bg_entries[1].binding = 1; + bg_entries[1].buffer = value_tensor.buffer; + bg_entries[1].size = value_tensor.nbytes; + bg_entries[2].binding = 2; + bg_entries[2].buffer = uniform_buffer; + bg_entries[2].size = sizeof(UpdateCacheParams); + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 3; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + 
graph.add_dispatch({pipeline, bind_group, workgroup_count_x}); + + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + // Drop our ref; the bind group keeps the uniform buffer alive until release. + wgpuBufferRelease(uniform_buffer); +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(update_cache.default, update_cache_impl); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/update_cache/update_cache.wgsl b/backends/webgpu/runtime/ops/update_cache/update_cache.wgsl new file mode 100644 index 00000000000..62f882ad547 --- /dev/null +++ b/backends/webgpu/runtime/ops/update_cache/update_cache.wgsl @@ -0,0 +1,24 @@ +@group(0) @binding(0) var t_cache: array; +@group(0) @binding(1) var t_value: array; + +struct Params { + numel: u32, + dst_offset: u32, + cache_numel: u32, + _pad0: u32, +} +@group(0) @binding(2) var params: Params; + +override wg_size: u32 = 256; + +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x; + if (i >= params.numel) { + return; + } + if (params.dst_offset + i >= params.cache_numel) { + return; + } + t_cache[params.dst_offset + i] = t_value[i]; +} diff --git a/backends/webgpu/runtime/ops/update_cache/update_cache_wgsl.h b/backends/webgpu/runtime/ops/update_cache/update_cache_wgsl.h new file mode 100644 index 00000000000..ce26ccb767d --- /dev/null +++ b/backends/webgpu/runtime/ops/update_cache/update_cache_wgsl.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from update_cache.wgsl - DO NOT EDIT. 
+// wgsl-sha256: 994cac9bab0ed25c9c82d54af77d9bbbe34e49419e916d0164c9cf0e5b199c6a +inline constexpr const char* kUpdateCacheWGSL = R"( +@group(0) @binding(0) var t_cache: array; +@group(0) @binding(1) var t_value: array; + +struct Params { + numel: u32, + dst_offset: u32, + cache_numel: u32, + _pad0: u32, +} +@group(0) @binding(2) var params: Params; + +override wg_size: u32 = 256; + +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x; + if (i >= params.numel) { + return; + } + if (params.dst_offset + i >= params.cache_numel) { + return; + } + t_cache[params.dst_offset + i] = t_value[i]; +} +)"; + +inline constexpr uint32_t kUpdateCacheWorkgroupSizeX = 256; +inline constexpr uint32_t kUpdateCacheWorkgroupSizeY = 1; +inline constexpr uint32_t kUpdateCacheWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu From 6e1c5ca46c5921374eb4ae6106af4d3661ccc7e3 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 10 Jun 2026 13:19:20 -0700 Subject: [PATCH 2/3] [ExecuTorch][WebGPU] Add update_cache tests (native numeric + export) Pull Request resolved: https://github.com/pytorch/executorch/pull/20084 Tests for `llama.update_cache.default`, stacked on the op diff below. `test/ops/sdpa/test_update_cache.py` lowers the op through `VulkanPartitioner` (asserting it delegates to VulkanBackend) and exports per-case `.pte`s; `test/native/test_update_cache.cpp` runs them on-GPU and checks an integer-exact scatter golden against the returned cache. Coverage mirrors the Vulkan KV-cache test (`VulkanSDPATest`): single-shot writes at varied shapes/offsets, plus a multi-step advancing-input_pos replay that threads the returned cache across steps over the same GQA param sets (incl. llama3 head_dim=128). Comparing the cache directly is stronger than Vulkan, which checks it only indirectly via the SDPA output. Authored with assistance from Claude. 
ghstack-source-id: 391979582 @exported-using-ghexport Differential Revision: [D107547307](https://our.internmc.facebook.com/intern/diff/D107547307/) --- backends/webgpu/CMakeLists.txt | 3 + .../webgpu/scripts/test_webgpu_native_ci.sh | 22 +- .../webgpu/test/native/test_update_cache.cpp | 291 ++++++++++++++++++ backends/webgpu/test/ops/sdpa/__init__.py | 5 + .../webgpu/test/ops/sdpa/test_update_cache.py | 196 ++++++++++++ 5 files changed, 514 insertions(+), 3 deletions(-) create mode 100644 backends/webgpu/test/native/test_update_cache.cpp create mode 100644 backends/webgpu/test/ops/sdpa/__init__.py create mode 100644 backends/webgpu/test/ops/sdpa/test_update_cache.py diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 5e6e1d7bf35..3351c213d4a 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -125,4 +125,7 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) add_webgpu_native_test( webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp ) + add_webgpu_native_test( + webgpu_update_cache_test test/native/test_update_cache.cpp + ) endif() diff --git a/backends/webgpu/scripts/test_webgpu_native_ci.sh b/backends/webgpu/scripts/test_webgpu_native_ci.sh index af014efb228..02f1411401a 100644 --- a/backends/webgpu/scripts/test_webgpu_native_ci.sh +++ b/backends/webgpu/scripts/test_webgpu_native_ci.sh @@ -18,8 +18,8 @@ # # Builds whatever native test targets are present in the landed tree (NOT a fixed # list). This stack lands: webgpu_native_test, webgpu_rms_norm_test (base) + -# webgpu_dispatch_order_test, webgpu_scratch_buffer_test (D107576199). update_cache -# / SDPA executables join automatically once their sibling diffs land. +# webgpu_dispatch_order_test, webgpu_scratch_buffer_test (D107576199) + +# webgpu_update_cache_test (D107547307). SDPA executables join once they land. 
set -e @@ -45,6 +45,8 @@ RMS_NORM_DIR="/tmp/rmsn" RMS_NORM_OK=1 DISPATCH_ORDER_DIR="/tmp/dispatch_order" DISPATCH_ORDER_OK=1 +UPDATE_CACHE_DIR="/tmp/update_cache" +UPDATE_CACHE_OK=1 $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.add.test_add import export_add_model, export_chained_add_model @@ -62,6 +64,17 @@ from executorch.backends.webgpu.test.ops.dispatch_order.test_dispatch_order impo export_dispatch_order_cases('${DISPATCH_ORDER_DIR}') " || { echo "WARN: dispatch_order export failed; skipping dispatch_order native test"; DISPATCH_ORDER_OK=0; } +$PYTHON_EXECUTABLE -c " +from executorch.backends.webgpu.test.ops.sdpa.test_update_cache import ( + export_update_cache_cases, + export_update_cache_replay, + export_update_cache_negative, +) +export_update_cache_cases('${UPDATE_CACHE_DIR}') +export_update_cache_replay('${UPDATE_CACHE_DIR}') +export_update_cache_negative('${UPDATE_CACHE_DIR}') +" || { echo "WARN: update_cache export failed; skipping update_cache native test"; UPDATE_CACHE_OK=0; } + # ── Configure (Dawn-only: no -DWEBGPU_IMPL; Dawn is the sole backend) ───────── echo "=== Configure WebGPU native tests on Dawn ===" rm -rf "${BUILD_DIR}" @@ -79,7 +92,7 @@ cmake \ "${EXECUTORCH_ROOT}" # ── Build + run every native test target that exists in this tree ──────────── -TARGETS=(webgpu_native_test webgpu_rms_norm_test webgpu_dispatch_order_test webgpu_scratch_buffer_test) +TARGETS=(webgpu_native_test webgpu_rms_norm_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test) BIN_DIR="${BUILD_DIR}/backends/webgpu" # Which targets are defined depends on which diffs are landed (native_test + @@ -122,6 +135,9 @@ fi if [[ "${RMS_NORM_OK}" == "1" && -x "${BIN_DIR}/webgpu_rms_norm_test" ]]; then "${BIN_DIR}/webgpu_rms_norm_test" "${RMS_NORM_DIR}" fi +if [[ "${UPDATE_CACHE_OK}" == "1" && -x "${BIN_DIR}/webgpu_update_cache_test" ]]; then + "${BIN_DIR}/webgpu_update_cache_test" "${UPDATE_CACHE_DIR}" +fi if [[ 
"${DISPATCH_ORDER_OK}" == "1" && -x "${BIN_DIR}/webgpu_dispatch_order_test" ]]; then "${BIN_DIR}/webgpu_dispatch_order_test" "${DISPATCH_ORDER_DIR}" fi diff --git a/backends/webgpu/test/native/test_update_cache.cpp b/backends/webgpu/test/native/test_update_cache.cpp new file mode 100644 index 00000000000..3f932ea7f03 --- /dev/null +++ b/backends/webgpu/test/native/test_update_cache.cpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::webgpu; +using namespace executorch::extension; +using namespace executorch::runtime; + +namespace { + +struct UpdateCacheCase { + const char* name; + int s; + int h; + int d; + int cmax; + int input_pos; +}; + +// Mirrors _NATIVE_CASES in test_update_cache.py; golden is computed inline. 
+constexpr UpdateCacheCase kCases[] = { + {"prefill", 2, 2, 4, 8, 0}, + {"offset", 2, 2, 4, 8, 5}, + {"shape_b", 3, 4, 8, 16, 0}, + {"shape_b_offset", 3, 4, 8, 16, 10}, +}; + +bool run_case(const std::string& dir, const UpdateCacheCase& tc) { + printf( + "\n--- Test: update_cache[%s] (S=%d,H=%d,D=%d,Cmax=%d,pos=%d) ---\n", + tc.name, + tc.s, + tc.h, + tc.d, + tc.cmax, + tc.input_pos); + Module module(dir + "/" + tc.name + ".pte"); + if (module.load_forward() != Error::Ok) { + printf("FAIL: could not load %s.pte\n", tc.name); + return false; + } + + const int vnumel = tc.s * tc.h * tc.d; + const int cnumel = tc.cmax * tc.h * tc.d; + std::vector value(vnumel); + std::vector cache(cnumel); + for (int i = 0; i < vnumel; i++) { + value[i] = static_cast(i) * 0.5f; + } + for (int i = 0; i < cnumel; i++) { + cache[i] = static_cast(i) + 100.0f; + } + + // Inline reference: scatter value into the cache at input_pos, bounds-checked + // exactly as the op (integer-exact copy, no library needed). + std::vector ref(cache); + const int dst_offset = tc.input_pos * tc.h * tc.d; + for (int i = 0; i < vnumel; i++) { + if (dst_offset + i < cnumel) { + ref[dst_offset + i] = value[i]; + } + } + + auto v = make_tensor_ptr({1, tc.s, tc.h, tc.d}, std::vector(value)); + auto c = make_tensor_ptr({1, tc.cmax, tc.h, tc.d}, std::vector(cache)); + auto result = module.forward({EValue(v), EValue(c)}); + if (!result.ok()) { + printf("FAIL: forward failed (error %d)\n", (int)result.error()); + return false; + } + const auto& outputs = result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + printf("FAIL: no tensor output\n"); + return false; + } + const auto& out_tensor = outputs[0].toTensor(); + if (static_cast(out_tensor.numel()) != cnumel) { + printf( + "FAIL: output numel %zu != expected %d\n", + (size_t)out_tensor.numel(), + cnumel); + return false; + } + const float* out_data = out_tensor.const_data_ptr(); + + float max_abs_err = 0.0f; + for (int i = 0; i < cnumel; i++) { + 
max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i])); + } + printf("Max abs error: %e (checked %d elements)\n", max_abs_err, cnumel); + // update_cache is a pure scatter copy: the output must be bit-exact. + if (max_abs_err > 0.0f) { + printf("FAIL: update_cache[%s] not bit-exact\n", tc.name); + return false; + } + printf("PASS: update_cache[%s]\n", tc.name); + return true; +} + +struct ReplayCase { + const char* name; + int h; + int d; + std::vector seq_lens; +}; + +// Multi-step advancing-input_pos cache accumulation, mirroring VulkanSDPATest. +bool run_replay(const std::string& dir, const ReplayCase& rc) { + int cmax = 0; + for (int s : rc.seq_lens) { + cmax += s; + } + printf( + "\n--- Replay: update_cache[%s] (H=%d,D=%d,Cmax=%d,%zu steps) ---\n", + rc.name, + rc.h, + rc.d, + cmax, + rc.seq_lens.size()); + + const int cnumel = cmax * rc.h * rc.d; + std::vector cache(cnumel); + for (int i = 0; i < cnumel; i++) { + cache[i] = static_cast(i) + 100.0f; + } + std::vector ref(cache); + + int input_pos = 0; + bool ok = true; + for (size_t step = 0; step < rc.seq_lens.size(); step++) { + const int s = rc.seq_lens[step]; + const int vnumel = s * rc.h * rc.d; + std::vector value(vnumel); + const float base = static_cast((input_pos + 1) * 1000); + for (int i = 0; i < vnumel; i++) { + value[i] = (base + static_cast(i)) * 0.25f; + } + + const std::string fname = dir + "/" + rc.name + "_step" + + std::to_string(step) + "_S" + std::to_string(s) + "_pos" + + std::to_string(input_pos) + ".pte"; + Module module(fname); + if (module.load_forward() != Error::Ok) { + printf("FAIL: could not load %s\n", fname.c_str()); + return false; + } + + auto v = make_tensor_ptr({1, s, rc.h, rc.d}, std::vector(value)); + auto c = make_tensor_ptr({1, cmax, rc.h, rc.d}, std::vector(cache)); + auto result = module.forward({EValue(v), EValue(c)}); + if (!result.ok()) { + printf( + "FAIL: forward failed step %zu (error %d)\n", + step, + (int)result.error()); + return false; + } + 
const auto& outputs = result.get(); + if (outputs.empty() || !outputs[0].isTensor() || + static_cast(outputs[0].toTensor().numel()) != cnumel) { + printf("FAIL: bad cache output at step %zu\n", step); + return false; + } + const float* out_data = outputs[0].toTensor().const_data_ptr(); + + const int dst_offset = input_pos * rc.h * rc.d; + for (int i = 0; i < vnumel; i++) { + if (dst_offset + i < cnumel) { + ref[dst_offset + i] = value[i]; + } + } + + float max_abs_err = 0.0f; + for (int i = 0; i < cnumel; i++) { + max_abs_err = std::max(max_abs_err, std::abs(out_data[i] - ref[i])); + cache[i] = out_data[i]; // thread the accumulated cache into the next step + } + printf( + " step %zu (S=%d,pos=%d): max abs error %e\n", + step, + s, + input_pos, + max_abs_err); + if (max_abs_err > 0.0f) { // pure scatter copy: must be bit-exact + ok = false; + } + input_pos += s; + } + + if (ok) { + printf("PASS: update_cache[%s] replay\n", rc.name); + } else { + printf("FAIL: update_cache[%s] replay\n", rc.name); + } + return ok; +} + +struct NegativeCase { + const char* name; + const char* guard; +}; + +// Single-op, single-guard-violation cases: rejection maps to the named guard. +bool run_negative_case(const std::string& dir, const NegativeCase& nc) { + printf( + "\n--- Negative: update_cache[%s] (expect rejection: %s) ---\n", + nc.name, + nc.guard); + Module module(dir + "/" + nc.name + ".pte"); + const Error err = module.load_forward(); + // init catches the guard throw -> this code; other errors = setup failure. 
+ if (err != Error::DelegateInvalidCompatibility) { + printf( + "FAIL: %s.pte -> error %d; expected DelegateInvalidCompatibility " + "from the '%s' guard\n", + nc.name, + (int)err, + nc.guard); + return false; + } + printf("PASS: rejected with DelegateInvalidCompatibility (%s)\n", nc.guard); + return true; +} + +} // namespace + +int main(int argc, char** argv) { + std::string dir = "/tmp/update_cache"; + if (argc > 1) { + dir = argv[1]; + } + if (const char* env = std::getenv("WEBGPU_UPDATE_CACHE_DIR")) { + dir = env; + } + + WebGPUContext ctx; + try { + ctx = create_webgpu_context(); + } catch (const std::exception& e) { + printf("SKIP: %s\n", e.what()); + return 0; + } + set_default_webgpu_context(&ctx); + printf("WebGPU device acquired (native); case dir: %s\n", dir.c_str()); + + bool ok = true; + for (const auto& tc : kCases) { + ok = run_case(dir, tc) && ok; + } + + const std::vector kReplays = { + {"seqA", 4, 4, {3, 1, 1, 5, 1, 1, 2}}, + {"seqB", 2, 8, {3, 1, 1, 5, 1, 1}}, + {"llama3", 8, 128, {111, 1, 1, 1, 57, 1, 1}}, + }; + for (const auto& rc : kReplays) { + ok = run_replay(dir, rc) && ok; + } + + const NegativeCase kNegatives[] = { + {"neg_batch", "batch must be 1"}, + {"neg_fp16", "fp32-only"}, + }; + for (const auto& nc : kNegatives) { + ok = run_negative_case(dir, nc) && ok; + } + + set_default_webgpu_context(nullptr); + destroy_webgpu_context(ctx); + + if (!ok) { + return 1; + } + printf("\nAll update_cache tests passed\n"); + return 0; +} diff --git a/backends/webgpu/test/ops/sdpa/__init__.py b/backends/webgpu/test/ops/sdpa/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/webgpu/test/ops/sdpa/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/backends/webgpu/test/ops/sdpa/test_update_cache.py b/backends/webgpu/test/ops/sdpa/test_update_cache.py new file mode 100644 index 00000000000..a25321bdd7d --- /dev/null +++ b/backends/webgpu/test/ops/sdpa/test_update_cache.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""fp32 update_cache (KV-cache write) export tests via VulkanPartitioner. + +Verifies the export/delegation side here; on-GPU numerics are checked by the +dedicated native test `test/native/test_update_cache.cpp`: single-shot cases +(non-zero input_pos + varied shapes) via `export_update_cache_cases`, and the +multi-step advancing-input_pos replay (mirroring VulkanSDPATest) via +`export_update_cache_replay`. update_cache scatters a projected value tensor +[1, S, H, D] into the KV cache [1, Cmax, H, D] at the sequence offset input_pos. +""" + +import os +import unittest + +import torch + +# Importing custom_ops registers torch.ops.llama.update_cache (the schema lives +# in the C++ AOT lib loaded here). 
+from executorch.backends.vulkan import VulkanPartitioner +from executorch.exir import to_edge_transform_and_lower +from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + + +class UpdateCacheModule(torch.nn.Module): + """Writes the projected value into the KV cache at input_pos.""" + + def __init__(self, input_pos: int = 0) -> None: + super().__init__() + self.input_pos = input_pos + + def forward(self, value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor: + return torch.ops.llama.update_cache(value, cache, self.input_pos) + + +class TestUpdateCache(unittest.TestCase): + def _export_and_check(self, model, example_inputs) -> None: + ep = torch.export.export(model, example_inputs) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + + found_vulkan = False + for plan in et_program.executorch_program.execution_plan: + for delegate in plan.delegates: + if delegate.id == "VulkanBackend": + found_vulkan = True + break + self.assertTrue(found_vulkan, "Expected VulkanBackend delegate in .pte") + + def test_update_cache_prefill_small(self) -> None: + # input_pos=0 prefill: value [1,S=2,H=2,D=4] into cache [1,Cmax=8,H=2,D=4]. + value = torch.randn(1, 2, 2, 4) + cache = torch.zeros(1, 8, 2, 4) + self._export_and_check(UpdateCacheModule(0), (value, cache)) + + def test_update_cache_gqa_shapes(self) -> None: + # GQA-style: fewer kv heads, larger head dim. + value = torch.randn(1, 3, 2, 8) + cache = torch.zeros(1, 16, 2, 8) + self._export_and_check(UpdateCacheModule(0), (value, cache)) + + +def export_update_cache_model(output_path: str) -> None: + """Export an update_cache model to .pte for the native runtime test. + + Shapes match the native test: value [1,S=2,H=2,D=4] into cache + [1,Cmax=8,H=2,D=4] at input_pos=0. Example tensor *values* here are only for + tracing; the native test supplies its own deterministic inputs at runtime. 
+ """ + S, H, D, Cmax = 2, 2, 4, 8 + model = UpdateCacheModule(0) + value = torch.zeros(1, S, H, D) + cache = torch.zeros(1, Cmax, H, D) + ep = torch.export.export(model, (value, cache)) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + with open(output_path, "wb") as f: + f.write(et_program.buffer) + print(f"Exported {output_path}") + + +# (name, S, H, D, Cmax, input_pos) -- mirrors kCases in +# test/native/test_update_cache.cpp. Covers non-zero input_pos (the dst_offset +# path) and a second head_dim/n_heads shape. All writes stay in-bounds. +_NATIVE_CASES = [ + ("prefill", 2, 2, 4, 8, 0), + ("offset", 2, 2, 4, 8, 5), + ("shape_b", 3, 4, 8, 16, 0), + ("shape_b_offset", 3, 4, 8, 16, 10), +] + + +def export_update_cache_cases(out_dir: str) -> None: + """Export one .pte per native test case (input_pos baked). + + The native test supplies deterministic inputs and computes the integer-exact + scatter reference inline, so only the .pte (shapes + input_pos baked) is + written here -- no golden file. + """ + os.makedirs(out_dir, exist_ok=True) + for name, s, h, d, cmax, input_pos in _NATIVE_CASES: + model = UpdateCacheModule(input_pos) + value = torch.zeros(1, s, h, d) + cache = torch.zeros(1, cmax, h, d) + ep = torch.export.export(model, (value, cache)) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + with open(os.path.join(out_dir, f"{name}.pte"), "wb") as f: + f.write(et_program.buffer) + print(f"Exported {name}.pte (input_pos={input_pos})") + + +# (name, num_kv_heads, head_dim, seq_lens) -- mirrors the VulkanSDPATest param +# sets (sdpa_test.cpp:855-881). Cmax = sum(seq_lens) (exact fit). The native test +# threads the returned cache across steps as input_pos advances by seq_len. 
+_REPLAY_SEQS = [ + ("seqA", 4, 4, [3, 1, 1, 5, 1, 1, 2]), + ("seqB", 2, 8, [3, 1, 1, 5, 1, 1]), + ("llama3", 8, 128, [111, 1, 1, 1, 57, 1, 1]), +] + + +def export_update_cache_replay(out_dir: str) -> None: + """Export one .pte per replay step (seq_len + input_pos baked). + + Mirrors Vulkan's multi-step advancing-input_pos cache accumulation; the + native test feeds the returned cache into the next step and checks the + integer-exact scatter golden after each write -- no golden file. + """ + os.makedirs(out_dir, exist_ok=True) + for name, h, d, seqs in _REPLAY_SEQS: + cmax = sum(seqs) + input_pos = 0 + for idx, s in enumerate(seqs): + model = UpdateCacheModule(input_pos) + value = torch.zeros(1, s, h, d) + cache = torch.zeros(1, cmax, h, d) + ep = torch.export.export(model, (value, cache)) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + fname = f"{name}_step{idx}_S{s}_pos{input_pos}.pte" + with open(os.path.join(out_dir, fname), "wb") as f: + f.write(et_program.buffer) + print(f"Exported {fname}") + input_pos += s + + +# (name, value_shape, cache_shape, dtype) -- each violates one runtime guard but +# still delegates to VulkanBackend at export (ATen's update_cache meta allows +# it). The WebGPU backend must reject each at graph build; the native test +# asserts a graceful delegate error (no crash, no silent-wrong output). The +# other guards (head_dim/n_heads mismatch, non-4D, out-of-bounds start_pos) are +# rejected by ATen at export, so they cannot be baked into a .pte. +_NEGATIVE_CASES = [ + ("neg_batch", (2, 2, 2, 4), (2, 8, 2, 4), torch.float32), # batch must be 1 + ("neg_fp16", (1, 2, 2, 4), (1, 8, 2, 4), torch.float16), # fp32-only +] + + +def export_update_cache_negative(out_dir: str) -> None: + """Export guard-violating .pte's the WebGPU backend must reject at build. + + Asserts each still delegates to VulkanBackend, so the native test exercises + the runtime guard rather than a CPU-fallback path. 
+ """ + os.makedirs(out_dir, exist_ok=True) + for name, vshape, cshape, dtype in _NEGATIVE_CASES: + model = UpdateCacheModule(0) + value = torch.zeros(*vshape, dtype=dtype) + cache = torch.zeros(*cshape, dtype=dtype) + ep = torch.export.export(model, (value, cache)) + et_program = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + delegated = any( + d.id == "VulkanBackend" + for plan in et_program.executorch_program.execution_plan + for d in plan.delegates + ) + if not delegated: + raise RuntimeError(f"{name}: expected VulkanBackend delegation") + with open(os.path.join(out_dir, f"{name}.pte"), "wb") as f: + f.write(et_program.buffer) + print(f"Exported {name}.pte") + + +if __name__ == "__main__": + unittest.main() From 78293be2ab8aae7d18ef05d56814c4e737e8380e Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Wed, 10 Jun 2026 13:19:20 -0700 Subject: [PATCH 3/3] [ExecuTorch][WebGPU] SymInt live-scalar mechanism + et_vk.select_as_symint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20085 Adds the dynamic-scalar (SymInt) mechanism to the WebGPU graph as a standalone enabler, ahead of the SDPA op that consumes it. Mirrors the Vulkan delegate's SymInt = live uniform-buffer design: a `ValueType::SymInt` backed by a 16-byte `Uniform|CopyDst` buffer, `set_symint`/`read_symint`/`symint_buffer` accessors with dirty-tracking, a `SymIntSource` + `add_symint_source`/`update_symints_from_inputs` host-read path, and an `add_resize_hook`/`propagate_resize`/`dispatch_at` recompute plumbing. `WebGPUBackend::execute` calls `propagate_resize` after refreshing the SymInts from the runtime inputs. The `et_vk.select_as_symint` op handler records `out SymInt = x[index]` along a dim at build time. This diff has no in-graph consumer yet — the SDPA op (stacked above) reads the SymInt value via `read_symint()` for dynamic `input_pos`. 
Building it as its own diff keeps the enabler separate from the op, matching the update_cache → mechanism → SDPA layering. Authored with assistance from Claude. ghstack-source-id: 391979584 @exported-using-ghexport Differential Revision: [D107584280](https://our.internmc.facebook.com/intern/diff/D107584280/) --- backends/webgpu/CMakeLists.txt | 1 + backends/webgpu/runtime/WebGPUBackend.cpp | 9 ++ backends/webgpu/runtime/WebGPUGraph.cpp | 111 ++++++++++++++++++ backends/webgpu/runtime/WebGPUGraph.h | 74 +++++++++++- .../ops/select_as_symint/SelectAsSymint.cpp | 47 ++++++++ 5 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 3351c213d4a..9b1476f2290 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -34,6 +34,7 @@ set(WEBGPU_SRCS runtime/ops/add/BinaryOp.cpp runtime/ops/rms_norm/RmsNorm.cpp runtime/ops/update_cache/UpdateCache.cpp + runtime/ops/select_as_symint/SelectAsSymint.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index b4e3165d8f4..aed769da4a4 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -106,6 +106,15 @@ Error WebGPUBackend::execute( } graph->copy_inputs(inputs); + // Fail loud as a runtime Error so a throw never crosses the backend boundary. 
+ try { + graph->update_symints_from_inputs(inputs); + graph->propagate_resize(); + } catch (const std::exception& e) { + ET_LOG(Error, "WebGPU symint refresh/resize failed: %s", e.what()); + return Error::Internal; + } + // Execute the compute graph graph->execute(); diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index a60bfc18e3b..b3ae5511d13 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -59,6 +59,86 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) { return buffer; } +void WebGPUGraph::update_symints_from_inputs( + const std::vector>& inputs) { + for (const auto& src : symint_sources_) { + int pos = -1; + for (size_t i = 0; i < input_ids_.size(); i++) { + if (input_ids_[i] == src.input_tensor_id) { + pos = static_cast(i); + break; + } + } + if (pos < 0 || pos >= static_cast(inputs.size())) { + throw std::runtime_error( + "select_as_symint: source tensor is not a graph input"); + } + const auto& dims = tensors_[src.input_tensor_id].dims; + int dim = src.dim < 0 ? src.dim + static_cast(dims.size()) : src.dim; + if (dim < 0 || dim >= static_cast(dims.size())) { + throw std::runtime_error("select_as_symint: dim out of range"); + } + int index = src.index; + if (index < 0) { + index += static_cast(dims[dim]); + } + if (index < 0 || index >= static_cast(dims[dim])) { + throw std::runtime_error("select_as_symint: index out of range"); + } + int64_t numel = 1; + for (int64_t d : dims) { + numel *= d; + } + if (numel <= 0) { + throw std::runtime_error("select_as_symint: empty input tensor"); + } + int64_t stride = 1; + for (size_t i = static_cast(dim) + 1; i < dims.size(); i++) { + stride *= dims[i]; + } + // Reads the [0,..,index,..,0] element; symint sources are scalar-ish. + const int64_t offset = static_cast(index) * stride; + // elem_size back-derived from build-time numel (sources are static-shaped). 
+ const void* host = inputs[pos].first; + const size_t elem_size = inputs[pos].second / static_cast(numel); + int32_t val; + if (elem_size == sizeof(int64_t)) { + val = static_cast(static_cast(host)[offset]); + } else if (elem_size == sizeof(int32_t)) { + val = static_cast(host)[offset]; + } else { + throw std::runtime_error( + "select_as_symint: unsupported input element size"); + } + set_symint(src.symint_id, val); + } +} + +void WebGPUGraph::set_symint(int id, int32_t val) { + auto it = symints_.find(id); + if (it == symints_.end()) { + throw std::runtime_error("WebGPUGraph::set_symint: id is not a SymInt"); + } + if (it->second.value != val) { + it->second.value = val; + wgpuQueueWriteBuffer( + queue_, it->second.buffer, 0, &it->second.value, sizeof(int32_t)); + dirty_symints_.insert(id); + } +} + +void WebGPUGraph::propagate_resize() { + if (dirty_symints_.empty()) { + return; + } + for (auto& hook : resize_hooks_) { + if (dirty_symints_.count(hook.symint_id) != 0) { + hook.fn(*this); + } + } + dirty_symints_.clear(); +} + WebGPUGraph::~WebGPUGraph() { for (size_t i = 0; i < tensors_.size(); i++) { if (tensors_[i].buffer && @@ -76,6 +156,16 @@ WebGPUGraph::~WebGPUGraph() { wgpuBufferRelease(buf); } } + for (auto& buf : owned_uniform_buffers_) { + if (buf) { + wgpuBufferRelease(buf); + } + } + for (auto& kv : symints_) { + if (kv.second.buffer) { + wgpuBufferRelease(kv.second.buffer); + } + } for (auto& buf : output_staging_buffers_) { if (buf) { wgpuBufferRelease(buf); @@ -236,6 +326,27 @@ void WebGPUGraph::build( bools_[i] = val->value_as_Bool()->bool_val(); break; } + case vkgraph::GraphTypes::SymInt: { + // Live scalar: small Uniform buffer the CPU rewrites per execute. + value_types_[i] = ValueType::SymInt; + SymIntSlot slot; + slot.value = static_cast(val->value_as_SymInt()->value()); + // 16B matches the backend uniform-struct alignment; int32 in first 4. 
+ constexpr size_t kSymIntUniformBytes = 16; + WGPUBufferDescriptor d = {}; + d.size = kSymIntUniformBytes; + d.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + d.mappedAtCreation = true; + slot.buffer = wgpuDeviceCreateBuffer(device_, &d); + void* mapped = + wgpuBufferGetMappedRange(slot.buffer, 0, kSymIntUniformBytes); + std::memset(mapped, 0, kSymIntUniformBytes); + std::memcpy(mapped, &slot.value, sizeof(int32_t)); + wgpuBufferUnmap(slot.buffer); + symints_[i] = slot; + add_uniform_buffer_bytes(kSymIntUniformBytes); + break; + } default: value_types_[i] = ValueType::Null; break; diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index aa3dadc13ab..9f656ce4d14 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -11,8 +11,10 @@ #include #include +#include #include #include +#include #include #include @@ -104,6 +106,52 @@ class WebGPUGraph { return ints_[id]; } + // Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO. + // set_symint writes the buffer + marks dirty only if the value changed. + void set_symint(int id, int32_t val); + // read_symint throws (fail-loud) if id is not a SymInt. + int32_t read_symint(int id) const { + return symints_.at(id).value; + } + // symint_buffer throws (fail-loud) if id is not a SymInt. + WGPUBuffer symint_buffer(int id) const { + return symints_.at(id).buffer; + } + + // Records that a SymInt's value is read from input_tensor[index] along dim. + struct SymIntSource { + int symint_id; + int input_tensor_id; + int dim; + int index; + }; + void + add_symint_source(int symint_id, int input_tensor_id, int dim, int index) { + symint_sources_.push_back({symint_id, input_tensor_id, dim, index}); + } + const std::vector& symint_sources() const { + return symint_sources_; + } + + // Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl. 
+ void update_symints_from_inputs( + const std::vector>& inputs); + + // Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize. + void add_resize_hook(int symint_id, std::function fn) { + resize_hooks_.push_back({symint_id, std::move(fn)}); + } + // Run hooks for changed SymInts then clear; call before execute(). + void propagate_resize(); + + // Mutable dispatch access for resize hooks (to rewrite workgroup_count_x). + WebGPUDispatch& dispatch_at(size_t i) { + return dispatches_[i]; + } + size_t num_dispatches() const { + return dispatches_.size(); + } + WGPUDevice device() const { return device_; } @@ -119,6 +167,11 @@ class WebGPUGraph { uniform_buffer_bytes_ += bytes; } + // Keep a uniform alive for the graph's lifetime; released in the dtor. + void own_uniform_buffer(WGPUBuffer buffer) { + owned_uniform_buffers_.push_back(buffer); + } + // Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA). WGPUBuffer create_scratch_buffer(size_t nbytes); @@ -149,7 +202,7 @@ class WebGPUGraph { return static_cast(value_types_.size()); } - enum class ValueType { Tensor, Int, Double, Bool, Null, String }; + enum class ValueType { Tensor, Int, Double, Bool, Null, String, SymInt }; ValueType get_value_type(int id) const { return value_types_[id]; @@ -168,6 +221,22 @@ class WebGPUGraph { std::vector doubles_; std::vector bools_; + // SymInt (live scalar): id -> {live Uniform buffer, current value}, sparse. + struct SymIntSlot { + WGPUBuffer buffer = nullptr; + int32_t value = 0; + }; + std::unordered_map symints_; + std::vector symint_sources_; + + // Resize hooks + the set of SymInts changed since the last propagate_resize. + struct ResizeHook { + int symint_id; + std::function fn; + }; + std::vector resize_hooks_; + std::unordered_set dirty_symints_; + std::vector input_ids_; std::vector output_ids_; @@ -179,6 +248,9 @@ class WebGPUGraph { // Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries). 
std::vector scratch_buffers_; + // Uniform buffers owned for the graph's lifetime; released in the dtor. + std::vector owned_uniform_buffers_; + // Staging buffers for reading back outputs (MapRead | CopyDst). std::vector output_staging_buffers_; diff --git a/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp new file mode 100644 index 00000000000..573a88ce0fe --- /dev/null +++ b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// et_vk.select_as_symint: out SymInt = x[index] along dim; read at execute. +void select_as_symint_impl(WebGPUGraph& graph, const std::vector& args) { + const int x_id = args.at(0); + const int dim_id = args.at(1); + const int index_id = args.at(2); + const int out_id = args.at(3); + + if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::SymInt) { + throw std::runtime_error("select_as_symint: output is not a SymInt"); + } + const std::vector& inputs = graph.input_ids(); + if (std::find(inputs.begin(), inputs.end(), x_id) == inputs.end()) { + throw std::runtime_error( + "select_as_symint: source tensor is not a graph input"); + } + graph.add_symint_source( + out_id, + x_id, + static_cast(graph.get_int(dim_id)), + static_cast(graph.get_int(index_id))); +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl); +} + +} // namespace executorch::backends::webgpu