Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ void build_dispatch(
uint64_t uniform_size,
uint32_t workgroup_count_x,
uint32_t wg_size,
bool retain_uniform = false) {
bool retain_uniform = false,
const char* kernel_name = "") {
WGPUDevice device = graph.device();

WGPUShaderSourceWGSL wgsl_desc = {};
Expand Down Expand Up @@ -227,7 +228,7 @@ void build_dispatch(
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count_x});
graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
Expand Down Expand Up @@ -269,7 +270,8 @@ static WGPUBuffer record_update_cache_dispatch(
sizeof(uc),
wgc,
uc_wg,
dynamic_pos);
dynamic_pos,
"update_cache");
return ubuf;
}

Expand Down Expand Up @@ -473,7 +475,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
qk_wg,
dynamic_pos);
dynamic_pos,
"sdpa_compute_attn_weights");
qk_buf = ubuf;
qk_idx = graph.num_dispatches() - 1;
}
Expand All @@ -496,7 +499,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
0,
dynamic_pos);
dynamic_pos,
"sdpa_softmax");
softmax_buf = ubuf;
}

Expand All @@ -521,7 +525,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
av_wg,
dynamic_pos);
dynamic_pos,
"sdpa_compute_out");
av_buf = ubuf;
}

Expand Down
127 changes: 126 additions & 1 deletion backends/webgpu/test/test_webgpu_native.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -1122,6 +1122,129 @@
return true;
}

// Capacity-overrun must throw; runs without a device or TimestampQuery.
static bool test_query_pool_overrun_throws() {
printf("\n--- Test: WebGPUQueryPool capacity-overrun guard ---\n");
WebGPUQueryPool qp;
try {
qp.reset(1);
} catch (const std::exception&) {
printf("PASS: reset beyond capacity throws\n");
return true;
}
printf("FAIL: reset beyond capacity did not throw\n");
return false;
}

// WebGPUQueryPool roundtrip: time a probe pass; assert non-zero GPU duration.
static bool test_query_pool_roundtrip(const WebGPUContext& ctx) {
printf("\n--- Test: WebGPUQueryPool roundtrip ---\n");
if (!ctx.timestamp_supported) {
printf("SKIP: adapter lacks TimestampQuery feature\n");
return true;
}
WGPUDevice device = ctx.device;

// Probe loop iterates enough to burn a measurable, non-zero GPU duration.
const char* kProbeWGSL =
"@group(0) @binding(0) var<storage, read_write> out: array<f32>;\n"
"@compute @workgroup_size(64)\n"
"fn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n"
" var acc = 0.0;\n"
" for (var i = 0u; i < 8192u; i = i + 1u) {\n"
" acc = acc + f32(i) * 1.000001;\n"
" }\n"
" out[gid.x] = acc;\n"
"}\n";

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kProbeWGSL, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

WGPUBindGroupLayoutEntry bgl_entry = {};
bgl_entry.binding = 0;
bgl_entry.visibility = WGPUShaderStage_Compute;
bgl_entry.buffer.type = WGPUBufferBindingType_Storage;
WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 1;
bgl_desc.entries = &bgl_entry;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pl = wgpuDeviceCreatePipelineLayout(device, &pl_desc);

WGPUComputePipelineDescriptor pipe_desc = {};
pipe_desc.layout = pl;
pipe_desc.compute.module = shader;
pipe_desc.compute.entryPoint = {"main", WGPU_STRLEN};
WGPUComputePipeline pipe =
wgpuDeviceCreateComputePipeline(device, &pipe_desc);

WGPUBufferDescriptor obd = {};
obd.size = 64 * sizeof(float);
obd.usage = WGPUBufferUsage_Storage;
WGPUBuffer out_buf = wgpuDeviceCreateBuffer(device, &obd);

WGPUBindGroupEntry bg_entry = {};
bg_entry.binding = 0;
bg_entry.buffer = out_buf;
bg_entry.size = obd.size;
WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 1;
bg_desc.entries = &bg_entry;
WGPUBindGroup bg = wgpuDeviceCreateBindGroup(device, &bg_desc);

WebGPUQueryPool qp;
qp.initialize(device, 1);
qp.reset(1);

WGPUCommandEncoder enc = wgpuDeviceCreateCommandEncoder(device, nullptr);
WGPUPassTimestampWrites tw = qp.writes_for(0);
WGPUComputePassDescriptor pass_desc = {};
pass_desc.timestampWrites = &tw;
WGPUComputePassEncoder pass =
wgpuCommandEncoderBeginComputePass(enc, &pass_desc);
wgpuComputePassEncoderSetPipeline(pass, pipe);
wgpuComputePassEncoderSetBindGroup(pass, 0, bg, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(pass, 1, 1, 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
qp.record(0, "probe", {1, 1, 1}, {64, 1, 1});
qp.resolve(enc);
WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(enc, nullptr);
wgpuQueueSubmit(ctx.queue, 1, &cmd);
wgpuCommandBufferRelease(cmd);
wgpuCommandEncoderRelease(enc);

qp.extract_results(ctx.instance);

wgpuBufferRelease(out_buf);
wgpuComputePipelineRelease(pipe);
wgpuPipelineLayoutRelease(pl);
wgpuBindGroupLayoutRelease(bgl);
wgpuBindGroupRelease(bg);
wgpuShaderModuleRelease(shader);

if (qp.results().size() != 1) {
printf("FAIL: expected 1 duration, got %zu\n", qp.results().size());
return false;
}
const uint64_t dur = qp.results()[0].execution_duration_ns;
printf(" probe duration: %llu ns\n", (unsigned long long)dur);
if (dur == 0) {
printf("FAIL: probe duration is zero (expected monotonic non-zero)\n");
return false;
}
printf("PASS: WebGPUQueryPool roundtrip -- non-zero GPU kernel duration\n");
return true;
}

int main(int argc, char** argv) {
std::string model_path = "webgpu_add_test.pte";
if (argc > 1) {
Expand Down Expand Up @@ -1163,7 +1286,9 @@
set_default_webgpu_context(&ctx);
printf("WebGPU device acquired (native)\n");

bool ok = test_single_add(model_path);
bool ok = test_query_pool_overrun_throws();
ok = test_query_pool_roundtrip(ctx) && ok;
ok = test_single_add(model_path) && ok;

if (!chained_model_path.empty()) {
ok = test_chained_add(chained_model_path) && ok;
Expand Down
Loading