diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index 5a232b3a1bd..fadef569d2e 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -156,7 +156,8 @@ void build_dispatch( uint64_t uniform_size, uint32_t workgroup_count_x, uint32_t wg_size, - bool retain_uniform = false) { + bool retain_uniform = false, + const char* kernel_name = "") { WGPUDevice device = graph.device(); WGPUShaderSourceWGSL wgsl_desc = {}; @@ -227,7 +228,7 @@ void build_dispatch( bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count_x}); + graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name}); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); @@ -269,7 +270,8 @@ static WGPUBuffer record_update_cache_dispatch( sizeof(uc), wgc, uc_wg, - dynamic_pos); + dynamic_pos, + "update_cache"); return ubuf; } @@ -473,7 +475,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { sizeof(p), wgc, qk_wg, - dynamic_pos); + dynamic_pos, + "sdpa_compute_attn_weights"); qk_buf = ubuf; qk_idx = graph.num_dispatches() - 1; } @@ -496,7 +499,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { sizeof(p), wgc, 0, - dynamic_pos); + dynamic_pos, + "sdpa_softmax"); softmax_buf = ubuf; } @@ -521,7 +525,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { sizeof(p), wgc, av_wg, - dynamic_pos); + dynamic_pos, + "sdpa_compute_out"); av_buf = ubuf; }