diff --git a/Makefile b/Makefile
index 097536fd0a6..459f2f50c11 100644
--- a/Makefile
+++ b/Makefile
@@ -127,7 +127,7 @@ help:
@echo " llava-cpu - Build Llava runner with CPU backend"
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
- @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
+ @echo " gemma4_31b-cuda - Build Gemma 4 31B runner + OpenAI serving worker with CUDA backend"
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend"
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner + OpenAI serving worker (CUDA)"
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
@@ -444,11 +444,13 @@ qwen3_5_moe-cuda:
gemma4_31b-cuda:
@echo "==> Building and installing ExecuTorch with CUDA..."
cmake --workflow --preset llm-release-cuda
- @echo "==> Building Gemma 4 31B runner with CUDA..."
+ @echo "==> Building Gemma 4 31B runner + serving worker with CUDA..."
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
+ @echo " Serving worker: cmake-out/examples/models/gemma4_31b/gemma4_31b_worker"
+ @echo " Launch: see examples/models/gemma4_31b/README.md (Serving)"
gemma4_31b-mlx:
@echo "==> Building and installing ExecuTorch with MLX..."
diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt
index 52419eb95bc..8021ef47d58 100644
--- a/examples/models/gemma4_31b/CMakeLists.txt
+++ b/examples/models/gemma4_31b/CMakeLists.txt
@@ -15,6 +15,9 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+set(_json_include
+ ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include
+)
# gflags
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
@@ -58,9 +61,13 @@ endif()
# Tokenizer (HuggingFace tokenizer.json)
list(APPEND link_libraries tokenizers::tokenizers)
-add_executable(gemma4_31b_runner main.cpp)
+if(EXECUTORCH_BUILD_CUDA)
+ add_executable(gemma4_31b_runner main.cpp gemma4_31b_engine.cpp)
+else()
+ add_executable(gemma4_31b_runner main.cpp)
+endif()
target_include_directories(
- gemma4_31b_runner PUBLIC ${_common_include_directories}
+ gemma4_31b_runner PUBLIC ${_common_include_directories} ${_json_include}
)
target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries})
@@ -71,6 +78,23 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
endif()
endif()
+if(EXECUTORCH_BUILD_CUDA)
+ add_executable(
+ gemma4_31b_worker gemma4_31b_worker.cpp gemma4_31b_engine.cpp
+ )
+ target_include_directories(
+ gemma4_31b_worker PUBLIC ${_common_include_directories} ${_json_include}
+ )
+ target_link_libraries(gemma4_31b_worker PUBLIC ${link_libraries})
+
+ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
+ target_link_options_gc_sections(gemma4_31b_worker)
+ if(NOT APPLE AND NOT MSVC)
+ target_link_options(gemma4_31b_worker PRIVATE "LINKER:-s")
+ endif()
+ endif()
+endif()
+
if(TARGET mlxdelegate)
executorch_target_copy_mlx_metallib(gemma4_31b_runner)
endif()
diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json
index 23a7d42e035..5d6019f1911 100644
--- a/examples/models/gemma4_31b/CMakePresets.json
+++ b/examples/models/gemma4_31b/CMakePresets.json
@@ -13,7 +13,7 @@
},
{
"name": "gemma4-31b-cuda",
- "displayName": "Gemma 4 31B runner (CUDA)",
+ "displayName": "Gemma 4 31B runner + serving worker (CUDA)",
"inherits": ["gemma4-31b-base"],
"cacheVariables": {
"EXECUTORCH_BUILD_CUDA": "ON"
@@ -39,9 +39,9 @@
"buildPresets": [
{
"name": "gemma4-31b-cuda",
- "displayName": "Build Gemma 4 31B runner (CUDA)",
+ "displayName": "Build Gemma 4 31B runner + serving worker (CUDA)",
"configurePreset": "gemma4-31b-cuda",
- "targets": ["gemma4_31b_runner"]
+ "targets": ["gemma4_31b_runner", "gemma4_31b_worker"]
},
{
"name": "gemma4-31b-mlx",
@@ -53,7 +53,7 @@
"workflowPresets": [
{
"name": "gemma4-31b-cuda",
- "displayName": "Configure and build Gemma 4 31B runner (CUDA)",
+ "displayName": "Configure and build Gemma 4 31B runner + serving worker (CUDA)",
"steps": [
{
"type": "configure",
diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md
index ae3bcb24c19..3961174808d 100644
--- a/examples/models/gemma4_31b/README.md
+++ b/examples/models/gemma4_31b/README.md
@@ -139,11 +139,12 @@ model produces sensible text.
## Build the runner
```bash
-make gemma4_31b-cuda # Linux — CUDA backend
-make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon)
+make gemma4_31b-cuda # Linux — CUDA runner + serving worker
+make gemma4_31b-mlx # macOS — MLX runner (serving not yet supported)
```
-The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.
+The runner binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.
+The CUDA build also produces `cmake-out/examples/models/gemma4_31b/gemma4_31b_worker`.
## Run the .pte
@@ -162,3 +163,29 @@ Pass `--raw_prompt` to skip template wrapping for pre-formatted input.
For benchmarking, add `--cuda_graph` to capture the decode method in a CUDA
graph (decode is fully static — `T=1`).
+
+## Serving
+
+The CUDA OpenAI-compatible server is a Python control plane plus a C++ model worker.
+The worker owns the ExecuTorch model and speaks the shared JSONL protocol used by
+the generic LLM server.
+
+```bash
+LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH \
+python -m executorch.examples.models.gemma4_31b.serve \
+ --model-path ./gemma4_31b_exports/model.pte \
+ --data-path ./gemma4_31b_exports/aoti_cuda_blob.ptd \
+ --tokenizer-path ./gemma4_31b_int4/tokenizer.json \
+ --hf-tokenizer ./gemma4_31b_int4 \
+ --model-id gemma4-31b \
+ --max-sessions 1
+```
+
+The launcher defaults to the Hermes `<tool_call>{...}</tool_call>` parser. Use
+`--tool-parser qwen` or `--tool-parser none` if the model/template you are
+testing emits a different tool-call format.
+
+Named sessions and warm resume require a worker capacity greater than one. CUDA
+exports that include the `get_mutable_buffer_metadata` method support per-session
+mutable-buffer rebinding, so the worker can honor `--max-sessions` values above 1;
+older exports fail closed to a single scratch session. MLX serving is intentionally left for a later change.
diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py
index d84e2c03a7f..0b4baa88fb1 100644
--- a/examples/models/gemma4_31b/export.py
+++ b/examples/models/gemma4_31b/export.py
@@ -24,6 +24,7 @@
"""
import argparse
+import json
import os
import torch
@@ -135,6 +136,11 @@ def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None:
# Export + lower
+def _mutable_buffer_metadata(model: nn.Module) -> str:
+ mutable = [name for name, _ in model.named_buffers() if ".kv_cache." in name]
+ return json.dumps({"version": 1, "mutable_buffers": mutable})
+
+
def export_and_lower(
model: Gemma4_31B,
config: Gemma4_31BConfig,
@@ -181,6 +187,7 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
import executorch.backends.cuda.quantize_op_dispatch # noqa: F401
materialize_runtime_buffers(model, dtype=torch.bfloat16)
+ mutable_buffer_metadata = _mutable_buffer_metadata(model)
# Int4Tensor weights are used directly — no format conversion.
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).
@@ -248,6 +255,8 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
"get_vocab_size": config.vocab_size,
"get_n_layers": config.num_hidden_layers,
"get_max_prefill_chunk": max_prefill,
+ "get_min_prefill_chunk": 5,
+ "get_mutable_buffer_metadata": mutable_buffer_metadata,
"use_kv_cache": True,
"use_sdpa_with_kv_cache": False,
"enable_dynamic_shape": True,
diff --git a/examples/models/gemma4_31b/gemma4_31b_engine.cpp b/examples/models/gemma4_31b/gemma4_31b_engine.cpp
new file mode 100644
index 00000000000..fb16f3c27af
--- /dev/null
+++ b/examples/models/gemma4_31b/gemma4_31b_engine.cpp
@@ -0,0 +1,599 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#ifdef EXECUTORCH_BUILD_CUDA
+#include
+#include
+#include
+#endif
+
+namespace executorch::extension::llm {
+
+using ::executorch::extension::from_blob;
+using ::executorch::extension::Module;
+using ::executorch::extension::TensorPtr;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::EValue;
+using ::executorch::runtime::Result;
+using SizesType = executorch::aten::SizesType;
+
+namespace {
+
+Result read_sampled_token(
+ const executorch::aten::Tensor& output,
+ float temperature) {
+#ifdef EXECUTORCH_BUILD_CUDA
+ (void)temperature;
+ const void* ptr = output.const_data_ptr();
+ cudaPointerAttributes attrs{};
+ const bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess &&
+ attrs.type == cudaMemoryTypeDevice;
+ float val = 0.0f;
+ if (on_device) {
+ if (cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost) !=
+ cudaSuccess) {
+ ET_LOG(Error, "read_sampled_token: cudaMemcpy D2H failed");
+ return Error::Internal;
+ }
+ } else {
+ std::memcpy(&val, ptr, sizeof(float));
+ }
+ return static_cast(llrintf(val));
+#else
+ (void)output;
+ (void)temperature;
+ return Error::NotSupported;
+#endif
+}
+
+Result> build_gemma_module(
+ const Gemma4_31BConfig& config) {
+ std::vector data_files;
+ if (!config.data_path.empty()) {
+ data_files.push_back(config.data_path);
+ }
+ auto module = std::make_unique(
+ config.model_path,
+ data_files,
+ Module::LoadMode::MmapUseMlockIgnoreErrors,
+ /*event_tracer=*/nullptr,
+ /*memory_allocator=*/nullptr,
+ /*temp_allocator=*/nullptr,
+ /*share_memory_arenas=*/true);
+
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (config.enable_cuda_graph) {
+ executorch::runtime::BackendOptions<2> cuda_opts;
+ ET_CHECK_OK_OR_RETURN_ERROR(
+ cuda_opts.set_option("enable_cuda_graph_for_method", "decode"));
+ ET_CHECK_OK_OR_RETURN_ERROR(
+ executorch::runtime::set_option("CudaBackend", cuda_opts.view()));
+ }
+ {
+ executorch::runtime::BackendOptions<1> backend_options;
+ ET_CHECK_OK_OR_RETURN_ERROR(
+ backend_options.set_option("weight_sharing_across_methods", true));
+ ET_CHECK_OK_OR_RETURN_ERROR(
+ executorch::runtime::set_option("CudaBackend", backend_options.view()));
+ }
+ ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("prefill"));
+ ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("decode"));
+#else
+ (void)module;
+ ET_LOG(Error, "Gemma4_31BEngine is implemented for CUDA only");
+ return Error::NotSupported;
+#endif
+ return module;
+}
+
+void add_token_piece(
+ ::tokenizers::Tokenizer* tokenizer,
+ std::unordered_set& ids,
+ const char* piece) {
+ if (auto id = tokenizer->piece_to_id(piece); id.ok()) {
+ ids.insert(*id);
+ }
+}
+
+#ifdef EXECUTORCH_BUILD_CUDA
+Error register_mutable_fqns(Module* module, int mutable_ctx) {
+ auto res = module->execute("get_mutable_buffer_metadata");
+ if (res.error() != Error::Ok) {
+ ET_LOG(
+ Info, "Gemma4_31BEngine: no mutable-buffer metadata; capacity stays 1");
+ return res.error();
+ }
+ const auto& outs = res.get();
+ if (outs.empty() || !outs[0].isString()) {
+ ET_LOG(Error, "get_mutable_buffer_metadata did not return a string");
+ return Error::InvalidProgram;
+ }
+ std::string json_str(outs[0].toString());
+ auto j = nlohmann::json::parse(json_str, nullptr, /*allow_exceptions=*/false);
+ if (j.is_discarded() || !j.is_object() || j.value("version", 0) != 1 ||
+ !j.contains("mutable_buffers") || !j["mutable_buffers"].is_array()) {
+ ET_LOG(Error, "get_mutable_buffer_metadata has invalid schema");
+ return Error::InvalidProgram;
+ }
+ std::vector fqns;
+ for (const auto& f : j["mutable_buffers"]) {
+ if (!f.is_string() || f.get().empty()) {
+ ET_LOG(Error, "mutable_buffers entries must be non-empty strings");
+ return Error::InvalidProgram;
+ }
+ fqns.push_back(f.get());
+ }
+ if (fqns.empty()) {
+ ET_LOG(Error, "mutable_buffers must be non-empty for multi-session");
+ return Error::InvalidProgram;
+ }
+ ::executorch::backends::cuda::mutable_state_register_fqns(mutable_ctx, fqns);
+ return Error::Ok;
+}
+#endif
+
+class Gemma4_31BSession : public LLMSession {
+ public:
+ Gemma4_31BSession(
+ Module* module,
+ std::mutex* exec_mutex,
+ int mutable_ctx,
+ int session_token,
+ std::atomic* live_sessions,
+ ::tokenizers::Tokenizer* tokenizer,
+ std::unordered_map metadata,
+ std::unordered_set eos_ids,
+ int64_t max_prefill_chunk,
+ int64_t min_prefill_chunk)
+ : module_(module),
+ exec_mutex_(exec_mutex),
+ mutable_ctx_(mutable_ctx),
+ session_token_(session_token),
+ live_sessions_(live_sessions),
+ tokenizer_(tokenizer),
+ metadata_(std::move(metadata)),
+ eos_ids_(std::move(eos_ids)),
+ max_prefill_chunk_(max_prefill_chunk),
+ min_prefill_chunk_(min_prefill_chunk) {
+ decode_tokens_ = from_blob(
+ decode_token_data_, {1, 1}, executorch::aten::ScalarType::Long);
+ decode_pos_ =
+ from_blob(decode_pos_data_, {1}, executorch::aten::ScalarType::Long);
+#ifdef EXECUTORCH_BUILD_CUDA
+ temp_tensor_ =
+ from_blob(&temp_val_, {1}, executorch::aten::ScalarType::Float);
+#endif
+ }
+
+ ~Gemma4_31BSession() override {
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (session_token_ != ::executorch::backends::cuda::kNoMutableSession) {
+ ::executorch::backends::cuda::mutable_state_destroy_session(
+ mutable_ctx_, session_token_);
+ }
+#endif
+ if (live_sessions_ != nullptr) {
+ live_sessions_->fetch_sub(1);
+ }
+ }
+
+ Error prefill_tokens(
+ std::vector tokens,
+ const SamplingConfig* initial_sampling) override {
+ if (tokens.empty()) {
+ return Error::InvalidArgument;
+ }
+ float first_token_temp = temperature_;
+ if (initial_sampling != nullptr) {
+ if (initial_sampling->top_p != 1.0f || initial_sampling->top_k != 0 ||
+ initial_sampling->seed != 0) {
+ ET_LOG(
+ Error,
+ "Gemma4_31BSession: only temperature is supported; top_p/top_k/seed "
+ "are not implemented");
+ return Error::NotSupported;
+ }
+ first_token_temp = initial_sampling->temperature;
+ }
+ const int64_t T = static_cast(tokens.size());
+ const auto ctx_it = metadata_.find(kMaxContextLen);
+ if (ctx_it != metadata_.end() && pos_ + T >= ctx_it->second) {
+ ET_LOG(Error, "prefill_tokens would leave no room to generate");
+ return Error::InvalidArgument;
+ }
+
+ stop_.store(false, std::memory_order_relaxed);
+ int64_t offset = 0;
+ while (offset < T) {
+ int64_t chunk = T - offset;
+ if (max_prefill_chunk_ > 0) {
+ chunk = std::min(chunk, max_prefill_chunk_);
+ }
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (chunk > 1 && chunk < min_prefill_chunk_) {
+ chunk = 1;
+ }
+#endif
+ auto sampled =
+ run_prefill_chunk(tokens.data() + offset, chunk, first_token_temp);
+ ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
+ pending_ = sampled.get();
+ pos_ += chunk;
+ offset += chunk;
+ }
+ prev_decode_token_ = tokens.back();
+ return Error::Ok;
+ }
+
+ Result decode_one(const SamplingConfig& sampling) override {
+ if (sampling.top_p != 1.0f || sampling.top_k != 0 || sampling.seed != 0) {
+ ET_LOG(
+ Error,
+ "Gemma4_31BSession: only temperature is supported; top_p/top_k/seed "
+ "are not implemented");
+ return Error::NotSupported;
+ }
+ ET_CHECK_OR_RETURN_ERROR(
+ pending_.has_value(),
+ InvalidState,
+ "decode_one requires a pending token; call prefill_tokens() first");
+ temperature_ = sampling.temperature;
+
+ const uint64_t token = pending_.value();
+ const bool is_eos = eos_ids_.find(token) != eos_ids_.end();
+ const uint64_t prev = prev_decode_token_.value_or(token);
+ auto dec = tokenizer_->decode(prev, token);
+ if (!dec.ok()) {
+ ET_LOG(
+ Error,
+ "Tokenizers error code %d",
+ static_cast(dec.error()));
+ return Error::InvalidArgument;
+ }
+ std::string text_piece = std::move(*dec);
+
+ if (is_eos || stop_.load(std::memory_order_relaxed)) {
+ pending_.reset();
+ return DecodeResult{
+ token, std::move(text_piece), is_eos, /*is_terminal=*/true};
+ }
+
+ const auto ctx_it = metadata_.find(kMaxContextLen);
+ if (ctx_it != metadata_.end()) {
+ ET_CHECK_OR_RETURN_ERROR(
+ pos_ < ctx_it->second,
+ InvalidArgument,
+ "decode_one would exceed context capacity");
+ }
+
+ decode_token_data_[0] = static_cast(token);
+ decode_pos_data_[0] = pos_;
+ std::vector inputs;
+ inputs.push_back(EValue(decode_tokens_));
+ inputs.push_back(EValue(decode_pos_));
+#ifdef EXECUTORCH_BUILD_CUDA
+ set_temp(temperature_);
+ inputs.push_back(EValue(temp_tensor_));
+ const char* method = "decode";
+#else
+ (void)inputs;
+ return Error::NotSupported;
+#endif
+ auto sampled =
+ run_locked(method, inputs, temperature_, /*sync_after=*/false);
+ ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
+ pending_ = sampled.get();
+ prev_decode_token_ = token;
+ pos_ += 1;
+ return DecodeResult{
+ token, std::move(text_piece), /*is_eos=*/false, /*is_terminal=*/false};
+ }
+
+ Error seek(int64_t pos) override {
+ (void)pos;
+ return Error::NotSupported;
+ }
+
+ int64_t position() const override {
+ return pos_;
+ }
+
+ Error reset() override {
+ pos_ = 0;
+ pending_.reset();
+ prev_decode_token_.reset();
+ stop_.store(false, std::memory_order_relaxed);
+ return Error::Ok;
+ }
+
+ void stop() override {
+ stop_.store(true, std::memory_order_relaxed);
+ }
+
+ private:
+#ifdef EXECUTORCH_BUILD_CUDA
+ void set_temp(float t) {
+ temp_val_ = (t <= 0.0f) ? 1e-6f : t;
+ }
+#endif
+
+ Result
+ run_prefill_chunk(const uint64_t* tokens, int64_t T, float temperature) {
+ std::vector token_data(tokens, tokens + T);
+ std::vector pos_data(T);
+ for (int64_t i = 0; i < T; ++i) {
+ pos_data[i] = pos_ + i;
+ }
+ auto tokens_tensor = from_blob(
+ token_data.data(),
+ {1, static_cast(T)},
+ executorch::aten::ScalarType::Long);
+ auto pos_tensor = from_blob(
+ pos_data.data(),
+ {static_cast(T)},
+ executorch::aten::ScalarType::Long);
+ std::vector inputs;
+ inputs.push_back(EValue(tokens_tensor));
+ inputs.push_back(EValue(pos_tensor));
+#ifdef EXECUTORCH_BUILD_CUDA
+ set_temp(temperature);
+ inputs.push_back(EValue(temp_tensor_));
+ const char* method = (T >= min_prefill_chunk_) ? "prefill" : "decode";
+#else
+ (void)inputs;
+ (void)temperature;
+ return Error::NotSupported;
+#endif
+ return run_locked(method, inputs, temperature, /*sync_after=*/true);
+ }
+
+ Result run_locked(
+ const char* method,
+ std::vector& inputs,
+ float temperature,
+ bool sync_after) {
+ std::lock_guard guard(*exec_mutex_);
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (mutable_ctx_ != 0) {
+ ::executorch::backends::cuda::mutable_state_set_active(
+ mutable_ctx_, session_token_);
+ }
+#endif
+ auto res = module_->execute(method, inputs);
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (mutable_ctx_ != 0) {
+ ::executorch::backends::cuda::mutable_state_set_active(
+ mutable_ctx_, ::executorch::backends::cuda::kNoMutableSession);
+ }
+#endif
+ ET_CHECK_OK_OR_RETURN_ERROR(res.error());
+ auto sampled = read_sampled_token(res.get()[0].toTensor(), temperature);
+ ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (sync_after && cudaDeviceSynchronize() != cudaSuccess) {
+ ET_LOG(Error, "run_locked: cudaDeviceSynchronize failed");
+ return Error::Internal;
+ }
+#else
+ (void)sync_after;
+#endif
+ return sampled.get();
+ }
+
+ Module* module_;
+ std::mutex* exec_mutex_;
+ int mutable_ctx_;
+ int session_token_;
+ std::atomic* live_sessions_;
+ ::tokenizers::Tokenizer* tokenizer_;
+ std::unordered_map metadata_;
+ std::unordered_set eos_ids_;
+ int64_t max_prefill_chunk_;
+ int64_t min_prefill_chunk_;
+
+ int64_t pos_ = 0;
+ std::optional pending_;
+ std::optional prev_decode_token_;
+ float temperature_ = -1.0f;
+ std::atomic stop_{false};
+
+ int64_t decode_token_data_[1] = {0};
+ int64_t decode_pos_data_[1] = {0};
+ TensorPtr decode_tokens_;
+ TensorPtr decode_pos_;
+#ifdef EXECUTORCH_BUILD_CUDA
+ float temp_val_ = 1e-6f;
+ TensorPtr temp_tensor_;
+#endif
+};
+
+} // namespace
+
+Result> Gemma4_31BEngine::create(
+ const Gemma4_31BConfig& config) {
+ if (config.model_path.empty() || config.tokenizer_path.empty()) {
+ ET_LOG(
+ Error, "Gemma4_31BEngine: model_path and tokenizer_path are required");
+ return Error::InvalidArgument;
+ }
+
+ auto tokenizer = std::make_unique<::tokenizers::HFTokenizer>();
+ if (tokenizer->load(config.tokenizer_path) != ::tokenizers::Error::Ok) {
+ ET_LOG(Error, "Gemma4_31BEngine: failed to load tokenizer");
+ return Error::InvalidArgument;
+ }
+
+ std::vector data_files;
+ if (!config.data_path.empty()) {
+ data_files.push_back(config.data_path);
+ }
+ auto meta_module = std::make_unique(
+ config.model_path, data_files, Module::LoadMode::File);
+ auto metadata_result = get_llm_metadata(tokenizer.get(), meta_module.get());
+ if (metadata_result.error() != Error::Ok) {
+ ET_LOG(Error, "Gemma4_31BEngine: failed to read metadata");
+ return metadata_result.error();
+ }
+
+ auto eos_ids = get_eos_ids(tokenizer.get(), meta_module.get());
+ eos_ids.insert(static_cast(config.eos_id));
+ add_token_piece(tokenizer.get(), eos_ids, "");
+ add_token_piece(tokenizer.get(), eos_ids, "");
+
+ const auto& metadata = metadata_result.get();
+ int64_t max_prefill_chunk = 1;
+ auto max_ctx_it = metadata.find(kMaxContextLen);
+ if (max_ctx_it != metadata.end() && max_ctx_it->second > 1) {
+ max_prefill_chunk = max_ctx_it->second - 1;
+ }
+ if (auto get_result = meta_module->get("get_max_prefill_chunk");
+ get_result.ok()) {
+ max_prefill_chunk = get_result->toScalar().to();
+ }
+ int64_t min_prefill_chunk = 1;
+#ifdef EXECUTORCH_BUILD_CUDA
+ min_prefill_chunk = 5;
+ if (auto get_result = meta_module->get("get_min_prefill_chunk");
+ get_result.ok()) {
+ min_prefill_chunk = get_result->toScalar().to();
+ }
+#endif
+
+ bool registered_mutable = false;
+ int mutable_ctx = 0;
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (!config.enable_cuda_graph) {
+ mutable_ctx = ::executorch::backends::cuda::mutable_state_create_context();
+ if (register_mutable_fqns(meta_module.get(), mutable_ctx) == Error::Ok) {
+ registered_mutable = true;
+ ::executorch::backends::cuda::mutable_state_begin_load(mutable_ctx);
+ } else {
+ ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx);
+ mutable_ctx = 0;
+ }
+ }
+#endif
+
+ auto module_res = build_gemma_module(config);
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (registered_mutable) {
+ ::executorch::backends::cuda::mutable_state_end_load();
+ }
+#endif
+ if (module_res.error() != Error::Ok) {
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (mutable_ctx != 0) {
+ ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx);
+ }
+#endif
+ return module_res.error();
+ }
+
+ bool rebind_available = false;
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (mutable_ctx != 0) {
+ rebind_available =
+ ::executorch::backends::cuda::mutable_state_available(mutable_ctx);
+ if (rebind_available &&
+ ::executorch::backends::cuda::mutable_state_validate_coverage(
+ mutable_ctx) != Error::Ok) {
+ ET_LOG(
+ Error,
+ "Gemma4_31BEngine: mutable-buffer coverage check failed; disabling "
+ "multi-session");
+ rebind_available = false;
+ }
+ }
+#endif
+
+ return std::unique_ptr(new Gemma4_31BEngine(
+ config,
+ std::move(tokenizer),
+ metadata,
+ std::move(eos_ids),
+ std::move(module_res.get()),
+ max_prefill_chunk,
+ min_prefill_chunk,
+ rebind_available,
+ mutable_ctx));
+}
+
+Gemma4_31BEngine::~Gemma4_31BEngine() {
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (mutable_ctx_ != 0) {
+ ::executorch::backends::cuda::mutable_state_destroy_context(mutable_ctx_);
+ }
+#endif
+}
+
+Result> Gemma4_31BEngine::create_session() {
+ const int cap =
+ serving_capacity().max_physical_sessions_without_weight_duplication;
+ {
+ std::lock_guard g(exec_mutex_);
+ if (live_sessions_.load() >= cap) {
+ return Error::InvalidState;
+ }
+ live_sessions_.fetch_add(1);
+ }
+
+ int token = -1;
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (rebind_available_) {
+ auto t = ::executorch::backends::cuda::mutable_state_create_session(
+ mutable_ctx_);
+ if (t.error() != Error::Ok) {
+ live_sessions_.fetch_sub(1);
+ return t.error();
+ }
+ token = t.get();
+ }
+#endif
+ return std::unique_ptr(new Gemma4_31BSession(
+ shared_module_.get(),
+ &exec_mutex_,
+ mutable_ctx_,
+ token,
+ &live_sessions_,
+ tokenizer_.get(),
+ metadata_,
+ eos_ids_,
+ max_prefill_chunk_,
+ min_prefill_chunk_));
+}
+
+LLMServingCapacity Gemma4_31BEngine::serving_capacity() const {
+ LLMServingCapacity cap;
+#ifdef EXECUTORCH_BUILD_CUDA
+ if (rebind_available_) {
+ cap.max_physical_sessions_without_weight_duplication =
+ config_.max_sessions > 1 ? config_.max_sessions : 1;
+ cap.estimated_bytes_per_session =
+ ::executorch::backends::cuda::mutable_state_bytes_per_session(
+ mutable_ctx_);
+ }
+#endif
+ return cap;
+}
+
+} // namespace executorch::extension::llm
diff --git a/examples/models/gemma4_31b/gemma4_31b_engine.h b/examples/models/gemma4_31b/gemma4_31b_engine.h
new file mode 100644
index 00000000000..92eaf1b02da
--- /dev/null
+++ b/examples/models/gemma4_31b/gemma4_31b_engine.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace executorch::extension::llm {
+
+struct Gemma4_31BConfig {
+ std::string model_path;
+ std::string data_path;
+ std::string tokenizer_path;
+ int32_t max_sessions = 1;
+ int64_t eos_id = 1;
+ bool enable_cuda_graph = false;
+};
+
+class ET_EXPERIMENTAL Gemma4_31BEngine : public LLMEngine {
+ public:
+ static ::executorch::runtime::Result>
+ create(const Gemma4_31BConfig& config);
+
+ ~Gemma4_31BEngine() override;
+
+ ::executorch::runtime::Result> create_session()
+ override;
+
+ LLMServingCapacity serving_capacity() const override;
+
+ const std::unordered_map& metadata() const override {
+ return metadata_;
+ }
+
+ ::tokenizers::Tokenizer* tokenizer() const {
+ return tokenizer_.get();
+ }
+
+ Gemma4_31BEngine(const Gemma4_31BEngine&) = delete;
+ Gemma4_31BEngine& operator=(const Gemma4_31BEngine&) = delete;
+
+ private:
+ Gemma4_31BEngine(
+ Gemma4_31BConfig config,
+ std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+ std::unordered_map metadata,
+ std::unordered_set eos_ids,
+ std::unique_ptr shared_module,
+ int64_t max_prefill_chunk,
+ int64_t min_prefill_chunk,
+ bool rebind_available,
+ int mutable_ctx)
+ : config_(std::move(config)),
+ tokenizer_(std::move(tokenizer)),
+ metadata_(std::move(metadata)),
+ eos_ids_(std::move(eos_ids)),
+ shared_module_(std::move(shared_module)),
+ max_prefill_chunk_(max_prefill_chunk),
+ min_prefill_chunk_(min_prefill_chunk),
+ rebind_available_(rebind_available),
+ mutable_ctx_(mutable_ctx) {}
+
+ Gemma4_31BConfig config_;
+ std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
+ std::unordered_map metadata_;
+ std::unordered_set eos_ids_;
+ std::unique_ptr shared_module_;
+ std::mutex exec_mutex_;
+ int64_t max_prefill_chunk_ = 0;
+ int64_t min_prefill_chunk_ = 1;
+ bool rebind_available_ = false;
+ int mutable_ctx_ = 0;
+ std::atomic live_sessions_{0};
+};
+
+} // namespace executorch::extension::llm
diff --git a/examples/models/gemma4_31b/gemma4_31b_worker.cpp b/examples/models/gemma4_31b/gemma4_31b_worker.cpp
new file mode 100644
index 00000000000..197f1571269
--- /dev/null
+++ b/examples/models/gemma4_31b/gemma4_31b_worker.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+#include
+
+DEFINE_string(model_path, "", "Model .pte file path.");
+DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
+DEFINE_string(data_path, "", "Data file (.ptd) for delegated weights.");
+DEFINE_int32(
+ max_sessions,
+ 1,
+ "Max physical sessions to host on one weight allocation. CUDA may raise "
+ "this when per-session mutable rebinding is available.");
+DEFINE_bool(
+ warm_resume,
+ true,
+ "Warm append-only resume for named sessions when the engine supports them.");
+DEFINE_int32(bos_id, 2, "BOS token id to prepend to every Gemma prompt.");
+DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
+
+namespace {
+namespace llm = ::executorch::extension::llm;
+using ::executorch::runtime::Error;
+} // namespace
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ if (FLAGS_model_path.empty() || FLAGS_tokenizer_path.empty()) {
+ ET_LOG(
+ Error, "gemma4_31b_worker: --model_path and --tokenizer_path required");
+ return 1;
+ }
+
+ llm::Gemma4_31BConfig config;
+ config.model_path = FLAGS_model_path;
+ config.data_path = FLAGS_data_path;
+ config.tokenizer_path = FLAGS_tokenizer_path;
+ config.max_sessions = FLAGS_max_sessions;
+ config.eos_id = FLAGS_eos_id;
+
+ auto engine_result = llm::Gemma4_31BEngine::create(config);
+ if (engine_result.error() != Error::Ok) {
+ ET_LOG(Error, "gemma4_31b_worker: failed to create engine");
+ return 1;
+ }
+ auto engine = std::move(engine_result.get());
+
+ return llm::run_worker_stdio_loop(
+ *engine,
+ *engine->tokenizer(),
+ engine->metadata(),
+ FLAGS_warm_resume,
+ {static_cast(FLAGS_bos_id)});
+}
diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp
index 6cf65cc8246..7a3dfcf89ba 100644
--- a/examples/models/gemma4_31b/main.cpp
+++ b/examples/models/gemma4_31b/main.cpp
@@ -6,6 +6,174 @@
* LICENSE file in the root directory of this source tree.
*/
+#ifdef EXECUTORCH_BUILD_CUDA
+
+// Thin CUDA CLI over Gemma4_31BEngine / LLMSession. The non-CUDA legacy runner
+// remains below for the existing MLX target; serving is CUDA-only for now.
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+// Command-line flags for the CUDA runner; parsed by gflags at the top of
+// main().
+//
+// --model_path and --tokenizer_path are required: main() logs an error and
+// returns 1 when either is empty. --data_path is copied into the engine config
+// as-is.
+DEFINE_string(model_path, "", "Model .pte file path.");
+DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend.");
+DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
+// Prompt source: --prompt_file, when non-empty, replaces --prompt entirely.
+DEFINE_string(prompt, "Hello", "Prompt text.");
+DEFINE_string(
+ prompt_file,
+ "",
+ "Path to file containing prompt text (overrides --prompt).");
+// Generation controls. The temperature is narrowed by a cast in main() before
+// being stored in the sampling config.
+DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
+DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
+// bos_id is inserted at the front of the encoded prompt; eos_id is copied into
+// the engine config.
+DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
+DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
+// When false (the default), main() wraps the prompt in turn markers.
+DEFINE_bool(
+ raw_prompt,
+ false,
+ "Skip chat-template wrapping (use if the prompt is already formatted).");
+// Copied into the engine config's enable_cuda_graph field.
+DEFINE_bool(
+ cuda_graph,
+ false,
+ "Enable CUDA graph capture for the decode method. CUDA only.");
+
+// Shorthands for the ExecuTorch LLM extension namespace and the runtime Error
+// type used by main().
+namespace llm = ::executorch::extension::llm;
+using ::executorch::runtime::Error;
+
+// Entry point of the CUDA command-line runner.
+//
+// Steps: validate flags, create the engine and one session, build and encode
+// the prompt, prefill, decode up to --max_new_tokens tokens while streaming the
+// text to stdout, then print a timing / GPU-memory report.
+//
+// Returns 0 on success and 1 on any failure (missing flag, engine or session
+// creation, unreadable prompt file, tokenizer, prefill or decode error). On
+// failure the function returns early, so no report is printed.
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ if (FLAGS_model_path.empty()) {
+ ET_LOG(Error, "Must specify --model_path");
+ return 1;
+ }
+ if (FLAGS_tokenizer_path.empty()) {
+ ET_LOG(Error, "Must specify --tokenizer_path");
+ return 1;
+ }
+
+ // GPU memory snapshot before anything is loaded. The stats fields are only
+ // written when the CUDA query succeeds.
+ llm::Stats stats;
+ size_t gpu_free_bytes = 0, gpu_total_bytes = 0;
+ if (cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes) == cudaSuccess) {
+ stats.gpu_total_bytes = gpu_total_bytes;
+ stats.gpu_free_before_load_bytes = gpu_free_bytes;
+ }
+
+ stats.model_load_start_ms = llm::time_in_ms();
+
+ llm::Gemma4_31BConfig config;
+ config.model_path = FLAGS_model_path;
+ config.data_path = FLAGS_data_path;
+ config.tokenizer_path = FLAGS_tokenizer_path;
+ config.eos_id = FLAGS_eos_id;
+ config.enable_cuda_graph = FLAGS_cuda_graph;
+
+ printf("Loading methods...\n");
+ auto engine_result = llm::Gemma4_31BEngine::create(config);
+ if (engine_result.error() != Error::Ok) {
+ ET_LOG(Error, "Failed to create Gemma 4 31B engine");
+ return 1;
+ }
+ auto engine = std::move(engine_result.get());
+
+ auto session_result = engine->create_session();
+ if (session_result.error() != Error::Ok) {
+ ET_LOG(Error, "Failed to create session");
+ return 1;
+ }
+ auto session = std::move(session_result.get());
+
+ // The load interval ends here, so it covers both engine and session
+ // creation.
+ stats.model_load_end_ms = llm::time_in_ms();
+ if (cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes) == cudaSuccess) {
+ stats.gpu_free_after_load_bytes = gpu_free_bytes;
+ }
+
+ // Prompt text: --prompt by default, replaced by the full contents of
+ // --prompt_file when that flag is set.
+ std::string prompt_text = FLAGS_prompt;
+ if (!FLAGS_prompt_file.empty()) {
+ std::ifstream f(FLAGS_prompt_file);
+ if (!f.is_open()) {
+ ET_LOG(
+ Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str());
+ return 1;
+ }
+ prompt_text = std::string(
+ (std::istreambuf_iterator(f)), std::istreambuf_iterator());
+ }
+
+ // Unless --raw_prompt is given, surround the text with the literal turn
+ // markers below.
+ if (!FLAGS_raw_prompt) {
+ prompt_text = "<|turn>user\n" + prompt_text +
+ "\n<|turn>model\n<|channel>thought\n";
+ }
+
+ auto encode_result = engine->tokenizer()->encode(prompt_text);
+ if (!encode_result.ok()) {
+ ET_LOG(Error, "Failed to encode prompt");
+ return 1;
+ }
+ // Prepend the BOS id; the reported prompt token count includes it.
+ auto prompt_tokens = std::move(*encode_result);
+ prompt_tokens.insert(
+ prompt_tokens.begin(), static_cast(FLAGS_bos_id));
+ const int64_t num_prompt_tokens = static_cast(prompt_tokens.size());
+ printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens);
+ stats.num_prompt_tokens = num_prompt_tokens;
+
+ llm::SamplingConfig sampling;
+ sampling.temperature = static_cast(FLAGS_temperature);
+ stats.inference_start_ms = llm::time_in_ms();
+ if (session->prefill_tokens(prompt_tokens, &sampling) != Error::Ok) {
+ ET_LOG(Error, "Prefill failed");
+ return 1;
+ }
+ // first_token_ms starts at the prefill end time and is overwritten below
+ // only if the first decode step is non-terminal.
+ stats.prompt_eval_end_ms = llm::time_in_ms();
+ stats.first_token_ms = stats.prompt_eval_end_ms;
+
+ // Decode loop: at most --max_new_tokens steps. A terminal step ends the
+ // loop before it is counted or its text is printed.
+ int64_t num_generated = 0;
+ for (int32_t step = 0; step < FLAGS_max_new_tokens; ++step) {
+ auto step_result = session->decode_one(sampling);
+ if (step_result.error() != Error::Ok) {
+ ET_LOG(Error, "Decode step %d failed", step);
+ return 1;
+ }
+ const auto& d = step_result.get();
+ if (d.is_terminal) {
+ break;
+ }
+ if (step == 0) {
+ stats.first_token_ms = llm::time_in_ms();
+ }
+ ++num_generated;
+ // Stream each piece immediately so output appears as it is generated.
+ if (!d.text_piece.empty()) {
+ fwrite(d.text_piece.data(), 1, d.text_piece.size(), stdout);
+ fflush(stdout);
+ }
+ }
+ printf("\n");
+
+ stats.inference_end_ms = llm::time_in_ms();
+ stats.num_generated_tokens = num_generated;
+ // NOTE(review): the "peak" figure is total minus free memory sampled once
+ // after generation, not a tracked maximum. It also reads
+ // stats.gpu_total_bytes, which is only assigned when the first
+ // cudaMemGetInfo call at the top of main() succeeded.
+ if (cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes) == cudaSuccess) {
+ stats.gpu_free_after_generate_bytes = gpu_free_bytes;
+ stats.gpu_peak_usage_mb =
+ (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0;
+ }
+
+ llm::print_report(stats);
+ return 0;
+}
+
+#else
+
// Gemma 4 31B-IT runner for ExecuTorch. Supports two backends:
// CUDA — exports ``prefill`` (T>=2, dynamic) + ``decode`` (T=1, static)
// methods sharing KV-cache buffers; on-device Gumbel-max sampling
@@ -416,3 +584,5 @@ int main(int argc, char** argv) {
llm::print_report(stats);
return 0;
}
+
+#endif // EXECUTORCH_BUILD_CUDA
diff --git a/examples/models/gemma4_31b/serve.py b/examples/models/gemma4_31b/serve.py
new file mode 100644
index 00000000000..549759ce8a9
--- /dev/null
+++ b/examples/models/gemma4_31b/serve.py
@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""OpenAI-compatible HTTP server for Gemma 4 31B on CUDA."""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+from executorch.extension.llm.server.python.chat_template import ChatTemplate
+from executorch.extension.llm.server.python.serving_chat import ServingChat
+from executorch.extension.llm.server.python.session_runtime import SessionRuntime
+from executorch.extension.llm.server.python.tool_parsers import (
+ HermesDetector,
+ QwenFunctionCallDetector,
+)
+from executorch.extension.llm.server.python.worker_client import spawn_worker
+
+logger = logging.getLogger(__name__)
+
+
+def _default_worker_bin() -> str:
+ repo_root = Path(__file__).resolve().parents[3]
+ return str(
+ repo_root
+ / "cmake-out"
+ / "examples"
+ / "models"
+ / "gemma4_31b"
+ / "gemma4_31b_worker"
+ )
+
+
+def _spawn(args):
+ env = dict(os.environ)
+ conda = os.environ.get("CONDA_PREFIX")
+ if conda:
+ env["LD_LIBRARY_PATH"] = f"{conda}/lib:" + env.get("LD_LIBRARY_PATH", "")
+ worker_bin = args.worker_bin or _default_worker_bin()
+ cmd = [
+ worker_bin,
+ "--model_path",
+ args.model_path,
+ "--tokenizer_path",
+ args.tokenizer_path,
+ "--max_sessions",
+ str(args.max_sessions),
+ f"--warm_resume={'true' if args.warm_resume else 'false'}",
+ "--bos_id",
+ str(args.bos_id),
+ "--eos_id",
+ str(args.eos_id),
+ ]
+ if args.data_path:
+ cmd += ["--data_path", args.data_path]
+ logger.info("Starting Gemma4 31B worker subprocess...")
+ return spawn_worker(cmd, env=env)
+
+
+def _tool_detector(name: str):
+ if name == "hermes":
+ return HermesDetector
+ if name == "qwen":
+ return QwenFunctionCallDetector
+ if name == "none":
+ return None
+ raise ValueError(f"unknown tool parser: {name}")
+
+
+def build_app_from_args(args):
+ template = ChatTemplate(args.hf_tokenizer)
+ worker = _spawn(args)
+ runtime = SessionRuntime(worker)
+ serving = ServingChat(
+ runtime,
+ template,
+ args.model_id,
+ max_context=args.max_context,
+ tool_detector_cls=_tool_detector(args.tool_parser),
+ prompt_token_offset=1,
+ )
+
+ from executorch.extension.llm.server.python.server import build_app
+
+ app = build_app(serving, args.model_id)
+
+ @app.on_event("shutdown")
+ def _stop_worker():
+ runtime.close_worker()
+
+ return app, args.model_id
+
+
+def main() -> None:
+ p = argparse.ArgumentParser(
+ description="OpenAI-compatible CUDA LLM server for Gemma 4 31B"
+ )
+ p.add_argument("--model-path", required=True, help="Path to the .pte model")
+ p.add_argument("--data-path", default=None, help="Path to the .ptd delegate blob")
+ p.add_argument("--tokenizer-path", required=True, help="Path to the tokenizer.json")
+ p.add_argument(
+ "--hf-tokenizer",
+ required=True,
+ help="HF tokenizer id/dir for the model's chat template",
+ )
+ p.add_argument("--model-id", default="gemma4-31b")
+ p.add_argument("--host", default="127.0.0.1")
+ p.add_argument("--port", type=int, default=8000)
+ p.add_argument("--max-context", type=int, default=None)
+ p.add_argument(
+ "--num-runners",
+ type=int,
+ default=1,
+ help="Worker processes. 1 only; more would duplicate the weights.",
+ )
+ p.add_argument(
+ "--max-sessions",
+ type=int,
+ default=1,
+ help="Isolated sessions the CUDA worker may host when the export has "
+ "mutable-buffer metadata.",
+ )
+ p.add_argument(
+ "--warm-resume",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Warm append-only resume for named sessions when available.",
+ )
+ p.add_argument(
+ "--tool-parser",
+ choices=("hermes", "qwen", "none"),
+ default="hermes",
+ help="Tool-call format parser to apply to model output.",
+ )
+ p.add_argument("--bos-id", type=int, default=2)
+ p.add_argument("--eos-id", type=int, default=1)
+ p.add_argument(
+ "--worker-bin",
+ default=None,
+ help="Path to the gemma4_31b_worker binary.",
+ )
+ args = p.parse_args()
+ logging.basicConfig(level=logging.INFO)
+
+ if args.num_runners != 1:
+ p.error("Only 1 worker process is supported; more would duplicate weights.")
+
+ app, _ = build_app_from_args(args)
+
+ import uvicorn
+
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+# Run the server only when this file is executed as a script, not on import.
+if __name__ == "__main__":
+ main()
diff --git a/examples/models/gemma4_31b/test_serve.py b/examples/models/gemma4_31b/test_serve.py
new file mode 100644
index 00000000000..6ff8e1306e5
--- /dev/null
+++ b/examples/models/gemma4_31b/test_serve.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pathlib
+from types import SimpleNamespace
+
+import pytest
+
+from executorch.examples.models.gemma4_31b import serve
+
+# Directory that contains the imported serve module.
+_HERE = pathlib.Path(serve.__file__).resolve().parent
+# Three levels above _HERE; the tests below treat it as the repository root.
+_REPO_ROOT = _HERE.parents[2]
+
+
+def test_generic_server_does_not_reference_gemma4_31b():
+ server_dir = _REPO_ROOT / "extension/llm/server"
+ offenders = [p for p in server_dir.rglob("*.py") if "gemma4_31b" in p.read_text()]
+ assert offenders == []
+
+
+def test_control_plane_runs_no_model_code():
+ serve_src = (_HERE / "serve.py").read_text()
+ assert "Gemma4_31BEngine" not in serve_src
+ worker_src = (_HERE / "gemma4_31b_worker.cpp").read_text()
+ assert "Gemma4_31BEngine" in worker_src
+
+
+def test_spawn_builds_worker_command(monkeypatch):
+ captured = {}
+
+ def fake_spawn(cmd, env=None):
+ captured["cmd"] = cmd
+ return object()
+
+ monkeypatch.setattr(serve, "spawn_worker", fake_spawn)
+ serve._spawn(
+ SimpleNamespace(
+ worker_bin="/bin/gemma_worker",
+ model_path="m.pte",
+ tokenizer_path="t.json",
+ data_path="d.ptd",
+ max_sessions=4,
+ warm_resume=True,
+ bos_id=2,
+ eos_id=1,
+ )
+ )
+ assert captured["cmd"] == [
+ "/bin/gemma_worker",
+ "--model_path",
+ "m.pte",
+ "--tokenizer_path",
+ "t.json",
+ "--max_sessions",
+ "4",
+ "--warm_resume=true",
+ "--bos_id",
+ "2",
+ "--eos_id",
+ "1",
+ "--data_path",
+ "d.ptd",
+ ]
+
+
+def test_spawn_defaults_worker_bin_and_omits_empty_data_path(monkeypatch):
+ captured = {}
+ monkeypatch.setattr(
+ serve, "spawn_worker", lambda cmd, env=None: captured.update(cmd=cmd)
+ )
+ serve._spawn(
+ SimpleNamespace(
+ worker_bin=None,
+ model_path="m.pte",
+ tokenizer_path="t.json",
+ data_path=None,
+ max_sessions=1,
+ warm_resume=False,
+ bos_id=2,
+ eos_id=1,
+ )
+ )
+ cmd = captured["cmd"]
+ assert cmd[0].endswith("gemma4_31b_worker")
+ assert "--data_path" not in cmd
+ assert "--warm_resume=false" in cmd
+
+
+def test_rejects_multiple_runners(monkeypatch):
+ import sys
+
+ monkeypatch.setattr(
+ sys,
+ "argv",
+ [
+ "serve.py",
+ "--model-path",
+ "m.pte",
+ "--tokenizer-path",
+ "t.json",
+ "--hf-tokenizer",
+ "hf",
+ "--num-runners",
+ "2",
+ ],
+ )
+ with pytest.raises(SystemExit):
+ serve.main()
diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h
index 3cf4541a4e2..47be1320029 100644
--- a/extension/llm/server/cpp/worker_loop.h
+++ b/extension/llm/server/cpp/worker_loop.h
@@ -105,7 +105,8 @@ inline void worker_handle_request(
bool warm,
::tokenizers::Tokenizer& tokenizer,
const std::unordered_map& metadata,
- const nlohmann::json& req) {
+ const nlohmann::json& req,
+ const std::vector& prompt_prefix_ids = {}) {
LLMSession& session = *st.session;
int64_t max_new = req.value("max_new_tokens", static_cast(-1));
const float temperature = req.value("temperature", 0.0f);
@@ -129,7 +130,7 @@ inline void worker_handle_request(
throw std::runtime_error(
"exactly one of prompt / prompt_segments is required");
}
- std::vector ids;
+ std::vector ids = prompt_prefix_ids;
auto encode_text = [&](const std::string& text) {
auto enc = tokenizer.encode(text, /*bos=*/0, /*eos=*/0);
if (!enc.ok()) {
@@ -397,7 +398,8 @@ inline int run_worker_stdio_loop(
LLMEngine& engine,
::tokenizers::Tokenizer& tokenizer,
const std::unordered_map& metadata,
- bool enable_warm_resume = true) {
+ bool enable_warm_resume = true,
+ const std::vector& prompt_prefix_ids = {}) {
WorkerSessions sessions(engine);
worker_emit(
{{"ready", true},
@@ -465,7 +467,8 @@ inline int run_worker_stdio_loop(
}
warm = enable_warm_resume;
}
- worker_handle_request(*st, warm, tokenizer, metadata, req);
+ worker_handle_request(
+ *st, warm, tokenizer, metadata, req, prompt_prefix_ids);
} catch (const std::exception& e) { // report and keep serving
worker_emit({{"error", std::string(e.what())}});
}
diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py
index 1b85f8fba3d..4d9f9a11f60 100644
--- a/extension/llm/server/python/serving_chat.py
+++ b/extension/llm/server/python/serving_chat.py
@@ -63,11 +63,13 @@ def __init__(
model_id: str,
max_context: Optional[int] = None,
tool_detector_cls: Optional[type[HermesDetector]] = None,
+ prompt_token_offset: int = 0,
):
self._runtime = runtime
self._template = template
self._model_id = model_id
self._max_context = max_context
+ self._prompt_token_offset = prompt_token_offset
# Detector CLASS; a fresh instance is created per request so streaming
# state is never shared across concurrent requests.
self._tool_detector_cls = tool_detector_cls
@@ -347,8 +349,9 @@ def _count_prompt_tokens(self, prompt: PromptInput) -> Optional[int]:
tokenized length of {text} chunks. None when no tokenizer is available to
count text (the worker still enforces the real context limit)."""
if prompt.text is not None:
- return self._template.count_tokens(prompt.text)
- total = 0
+ count = self._template.count_tokens(prompt.text)
+ return None if count is None else count + self._prompt_token_offset
+ total = self._prompt_token_offset
for seg in prompt.segments:
if "ids" in seg:
total += len(seg["ids"])