diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index c8aaec9c850..d13961637d5 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -173,8 +173,8 @@ serve.py (control plane: FastAPI/asyncio, OpenAI protocol, chat templating, tool parsing, validation — NO CUDA, NO pybind) │ JSONL over stdin/stdout ▼ -qwen3_5_moe_worker (C++ binary: one Qwen35MoEEngine + one session, synchronous - loop — the CUDA model; NO asyncio server) +qwen3_5_moe_worker (C++ binary: one Qwen35MoEEngine, many isolated sessions, + synchronous loop — the CUDA model; NO asyncio server) ``` The model runs in a **separate worker process** because executing the AOTI CUDA @@ -196,13 +196,42 @@ is safe under asyncio. | `--host` / `--port` | `127.0.0.1` / `8000` | Bind address | | `--max-context` | (none) | Reject prompts that exceed it with 400 | | `--no-think` | off | Default reasoning off (`enable_thinking=False`) | +| `--max-sessions` | `1` | Isolated sessions on one weight load (see Sessions) | -### V1 limitations +### Sessions + +One worker loads the weights once (~18 GB) and hosts multiple **isolated** +sessions on that single allocation — each with its own KV/recurrent state, via +CUDA per-session mutable rebinding. Set `--max-sessions N` (clamped to 1 if the +backend cannot rebind); one slot is reserved for anonymous requests, so up to +`N - 1` named `session_id`s are addressable. + +Route a request to a persistent session with the `session_id` body field or, as +aliases, the `X-ExecuTorch-Session-ID` / `session_id` / `x-session-affinity` +headers (body wins, then that header order). The header aliases let a client that +already emits a stable per-conversation affinity id (e.g. pi's +`sendSessionAffinityHeaders`) route with no extra config. Requests without any +share a transient scratch session. Free a session with `DELETE /v1/sessions/{id}`. + +```bash +curl http://127.0.0.1:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model":"qwen3.5-moe","session_id":"alice", + "messages":[{"role":"user","content":"hi"}]}' +``` + +Admission is up front: an explicit `session_id` on a single-session server +returns **400** (`unsupported_session`); past capacity it returns **429** +(`capacity_exhausted`) before any response bytes. + +This is **isolation, not concurrency or warm resume**: execution is still +synchronous (one in-flight request; `--num-runners > 1` is rejected since more +workers would duplicate the weights), and each request resets its session — the +recurrent/conv state cannot be rewound by position (`seek()` is NotSupported), so +turn-to-turn KV reuse (append-only warm resume) is a follow-up. + +### Other limitations -- **Single-slot** (`serving_capacity=1`): one worker, one session, one model - load. `--num-runners > 1` is rejected; concurrent requests queue on the worker. -- **No prefix cache**: the recurrent/conv state cannot be rewound by position - (`seek()` is NotSupported), so turn-to-turn KV reuse is off. - Supports the chat-completions contract of the generic server; `top_p != 1`, `seed`, `top_k`, `logprobs`, etc. are rejected (only temperature is plumbed). diff --git a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp index aa94a704bc2..c5018031716 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_worker.cpp @@ -18,23 +18,29 @@ // process segfaults in the int4 matmul (validated). Here the model runs in a // plain synchronous loop in its own process, which is reliable. // -// Single-slot serving: this worker creates one session and the control plane -// queues concurrent requests on it. (The engine itself can host multiple -// sessions on the one ~18GB weight allocation; exposing that over the worker -// protocol is a follow-up.) +// Multi-session (isolation): the engine loads weights once and hosts multiple +// isolated sessions on that one ~18GB allocation; the shared worker loop +// (worker_loop.h) routes requests to per-session_id state, up to +// --max_sessions. Execution is still synchronous (one in-flight request); warm +// context reuse across requests is a follow-up. #include #include -#include #include #include -#include +#include DEFINE_string(model_path, "", "Model .pte file path."); DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); DEFINE_string(data_path, "", "Data file (.ptd) for the CUDA backend."); +DEFINE_int32( + max_sessions, + 1, + "Max physical sessions to host on the one weight allocation (CUDA " + "per-session mutable rebinding). Clamped to 1 if the backend cannot " + "rebind."); namespace { namespace llm = ::executorch::extension::llm; @@ -54,6 +60,7 @@ int main(int argc, char** argv) { config.model_path = FLAGS_model_path; config.data_path = FLAGS_data_path; config.tokenizer_path = FLAGS_tokenizer_path; + config.max_sessions = FLAGS_max_sessions; auto engine_result = llm::Qwen35MoEEngine::create(config); if (engine_result.error() != Error::Ok) { @@ -62,16 +69,9 @@ int main(int argc, char** argv) { } auto engine = std::move(engine_result.get()); - auto session_result = engine->create_session(); - if (session_result.error() != Error::Ok) { - ET_LOG(Error, "qwen35_moe_worker: failed to create session"); - return 1; - } - auto session = std::move(session_result.get()); - - // The engine's tokenizer encodes the rendered prompt to ids; the session - // decodes ids back to text internally. + // The engine's tokenizer encodes the rendered prompt to ids; sessions decode + // ids back to text internally. The shared loop owns per-session_id state. ::tokenizers::Tokenizer* tokenizer = engine->tokenizer(); - return llm::run_worker_stdio_loop(*session, *tokenizer, engine->metadata()); + return llm::run_worker_stdio_loop(*engine, *tokenizer, engine->metadata()); } diff --git a/examples/models/qwen3_5_moe/serve.py b/examples/models/qwen3_5_moe/serve.py index 229a84425fb..e58ab23516b 100644 --- a/examples/models/qwen3_5_moe/serve.py +++ b/examples/models/qwen3_5_moe/serve.py @@ -17,9 +17,14 @@ CUDA execution while a live asyncio loop is resident). Isolating CUDA in a plain (no-asyncio) C++ worker process is the reliable shape, and it loads weights once. -V1 constraints: - * single-slot: one worker, one session; concurrent HTTP requests queue. - * prefix cache off (Qwen seek() is NotSupported). +Sessions and constraints: + * One worker hosts many isolated sessions on a single ~18GB weight load (CUDA + per-session mutable rebinding); requests route by session_id (anonymous + requests share a scratch session). See --max-sessions. + * Execution is synchronous: one in-flight request at a time, concurrent HTTP + requests queue. Sessions provide isolation, not concurrent throughput. + * No warm context reuse yet: each request resets its session (Qwen seek() is + NotSupported; append-only reuse is a follow-up). * The control plane only does blocking pipe I/O on its executor thread (no CUDA), which is safe under asyncio. @@ -77,6 +82,7 @@ def _spawn(args): ] if args.data_path: cmd += ["--data_path", args.data_path] + cmd += ["--max_sessions", str(args.max_sessions)] logger.info("Starting Qwen worker subprocess (loads the model once)...") return spawn_worker(cmd, env=env) @@ -88,7 +94,7 @@ def build_app_from_args(args): args.hf_tokenizer, default_template_kwargs=default_template_kwargs ) - worker = _spawn(args) # one worker == one session (single-slot V1) + worker = _spawn(args) # one worker, weights once, many isolated sessions runtime = SessionRuntime(worker) serving = ServingChat( runtime, @@ -144,7 +150,17 @@ def main() -> None: "--num-runners", type=int, default=1, - help="V1 supports 1 only (single-slot).", + help="Workers (processes). 1 only: a worker hosts many isolated sessions " + "on one weight load; more workers would duplicate the ~18GB weights.", + ) + p.add_argument( + "--max-sessions", + type=int, + default=1, + help="Isolated sessions the one worker hosts on a single weight load " + "(CUDA per-session mutable rebinding); clamped to 1 if the backend " + "cannot rebind. One slot is reserved for anonymous requests, so the " + "number of addressable session_ids is max-sessions - 1.", ) p.add_argument( "--worker-bin", @@ -157,8 +173,8 @@ def main() -> None: if args.num_runners != 1: p.error( - "Qwen3.5 MoE V1 is single-slot: one worker serves one session; " - "concurrent requests queue." + "Only 1 worker process is supported (it hosts many isolated sessions " + "on one ~18GB weight load); more workers would duplicate the weights." ) app, _ = build_app_from_args(args) diff --git a/examples/models/qwen3_5_moe/test_serve.py b/examples/models/qwen3_5_moe/test_serve.py index fdaa6a1ea62..f8768ef39ce 100644 --- a/examples/models/qwen3_5_moe/test_serve.py +++ b/examples/models/qwen3_5_moe/test_serve.py @@ -8,7 +8,7 @@ Hermetic: no model, GPU, or worker subprocess. Covers layering (Qwen stays an example; the control plane runs no CUDA and imports no model pybind), the worker -spawn command, and the single-slot CLI guard. The generic JSONL protocol is +spawn command, and the single-worker CLI guard. The generic JSONL protocol is covered by extension/llm/server/python/tests/test_worker_client.py; the live HTTP smoke test is documented in README.md and run on a CUDA box. """ @@ -75,6 +75,7 @@ def fake_spawn(cmd, env=None): model_path="m.pte", tokenizer_path="t.json", data_path="d.ptd", + max_sessions=4, ) ) assert captured["cmd"] == [ @@ -85,6 +86,8 @@ def fake_spawn(cmd, env=None): "t.json", "--data_path", "d.ptd", + "--max_sessions", + "4", ] @@ -95,7 +98,11 @@ def test_spawn_defaults_worker_bin_and_omits_empty_data_path(monkeypatch): ) serve._spawn( SimpleNamespace( - worker_bin=None, model_path="m.pte", tokenizer_path="t.json", data_path=None + worker_bin=None, + model_path="m.pte", + tokenizer_path="t.json", + data_path=None, + max_sessions=4, ) ) cmd = captured["cmd"] diff --git a/extension/llm/server/cpp/text_llm_worker.cpp b/extension/llm/server/cpp/text_llm_worker.cpp index f7bb9d69915..950ac91c4bd 100644 --- a/extension/llm/server/cpp/text_llm_worker.cpp +++ b/extension/llm/server/cpp/text_llm_worker.cpp @@ -12,25 +12,24 @@ // the stable serving abstraction) — no Python model code, no pybind, no // in-process Python serving. The OpenAI control plane (Python) spawns this // process and drives it over JSONL on stdin/stdout (see worker_client.py). The -// JSONL protocol and the decode loop are shared across all workers in -// worker_loop.h; this file only constructs the engine/session/tokenizer. +// JSONL protocol, session management, and the decode loop are shared across all +// workers in worker_loop.h; this file only constructs the engine/tokenizer. +// TextLLMEngine hosts a single session, so the worker serves anonymous requests +// via the shared loop's scratch session and reports no named sessions. #include #include -#include #include #include #include -#include DEFINE_string(model_path, "", "Self-contained model .pte file path."); DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); namespace { namespace llm = ::executorch::extension::llm; -using ::executorch::runtime::Error; } // namespace int main(int argc, char** argv) { @@ -50,12 +49,6 @@ int main(int argc, char** argv) { ET_LOG(Error, "text_llm_worker: failed to create engine"); return 1; } - auto session_result = engine->create_session(); - if (session_result.error() != Error::Ok) { - ET_LOG(Error, "text_llm_worker: failed to create session"); - return 1; - } - auto session = std::move(session_result.get()); // The session decodes token ids to text internally; this tokenizer encodes // the rendered prompt to ids. Same tokenizer.json -> same vocabulary. @@ -65,5 +58,5 @@ int main(int argc, char** argv) { return 1; } - return llm::run_worker_stdio_loop(*session, *tokenizer, engine->metadata()); + return llm::run_worker_stdio_loop(*engine, *tokenizer, engine->metadata()); } diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h index 883bcac69cd..a7ec92b08f2 100644 --- a/extension/llm/server/cpp/worker_loop.h +++ b/extension/llm/server/cpp/worker_loop.h @@ -10,23 +10,45 @@ // Shared model-worker generation loop + JSONL protocol, used by every model // worker (the generic text_llm_worker and model-specific workers like -// qwen3_5_moe_worker). A worker only constructs its engine/session/tokenizer -// and calls run_worker_stdio_loop(); the protocol and the decode loop live here -// once, so protocol changes (e.g. multi-session) land in a single place. +// qwen3_5_moe_worker). A worker only constructs its engine/tokenizer and calls +// run_worker_stdio_loop(); the protocol, session management, and the decode +// loop live here once, so protocol changes land in a single place. +// +// V2a (isolation): the worker owns one LLMEngine (weights loaded once) and +// hands out multiple isolated LLMSessions keyed by session_id, each with its +// own KV/recurrent state, up to the engine's serving capacity. Execution is +// still synchronous -- one in-flight request at a time, the control plane +// serializes -- so this proves "one model, many isolated contexts without +// duplicating weights", NOT concurrent streaming. It also does NOT yet reuse +// context across requests: worker_handle_request() resets the session at the +// top of every request (warm append-only resume is a follow-up). +// +// Sessions: +// - Named: an explicit session_id -> LLMSession, created on first use (or via +// an `open` op), capped at max_named_sessions = capacity - 1 (the scratch +// slot is reserved). 0 when the backend can host only one session. +// - Scratch: one session for anonymous requests (no session_id), reset each +// request -- preserves the original single-session behavior. // // Protocol (one JSON object per line; matches worker_client.py): -// worker -> stdout, once: {"ready": true} -// client -> stdin, per request: {"prompt": str, "max_new_tokens": int, -// "temperature": float, "stop": [str, ...]} -// worker -> stdout, per request: {"token": str} * (streamed) -// {"done": true, "prompt_tokens": int, -// "completion_tokens": int, -// "finish_reason": "stop" | "length"} -// or {"error": str} +// worker -> stdout, once: {"ready": true, "max_sessions": int, +// "max_named_sessions": int} +// client -> stdin: +// generate: {"prompt": str, "max_new_tokens": int, "temperature": float, +// "stop": [str, ...], "session_id"?: str} +// open: {"op": "open", "session_id": str} +// close: {"op": "close", "session_id": str} +// worker -> stdout: +// generate: {"token": str} * (streamed) +// {"done": true, "prompt_tokens": int, +// "completion_tokens": int, "finish_reason": "stop"|"length"} +// open: {"opened": true, "session_id": str} +// close: {"closed": true, "session_id": str} +// error: {"error": str, "code"?: str} // code: "capacity_exhausted", +// // "unsupported_session" // // stdout carries ONLY protocol JSON; all logs go to stderr (ET_LOG). One -// request at a time (the control plane serializes; V1 is one worker == one -// session). +// request at a time (the control plane serializes). #include @@ -35,8 +57,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -156,7 +180,7 @@ inline void worker_handle_request( } } // finish_reason: "stop" if the model emitted EOS or hit a stop string, else - // "length" — it ran to max_new (possibly clamped to the context window). + // "length" -- it ran to max_new (possibly clamped to the context window). worker_emit( {{"done", true}, {"prompt_tokens", num_prompt}, @@ -164,22 +188,141 @@ inline void worker_handle_request( {"finish_reason", finish}}); } -// Emit {"ready": true}, then read JSONL requests from stdin and dispatch each -// to worker_handle_request, reporting exceptions as {"error": ...} and +// Owns the engine's sessions for one worker: named sessions keyed by id plus a +// single scratch session for anonymous requests. Single-threaded (driven by the +// stdio loop), so no internal locking. +class WorkerSessions { + public: + explicit WorkerSessions(LLMEngine& engine) + : engine_(engine), + // Reserve one capacity slot for the scratch (anonymous) session when + // the backend can host more than one; a single-session backend hosts + // only the scratch and reports 0 named sessions. + max_named_(std::max( + 0, + engine.serving_capacity() + .max_physical_sessions_without_weight_duplication - + 1)) {} + + int32_t max_named() const { + return max_named_; + } + + // Resolve (and admit, creating on first use) a named session. Returns nullptr + // and sets code on failure: "unsupported_session" when the backend hosts no + // named sessions, "capacity_exhausted" when all named slots are taken. + LLMSession* open_named(const std::string& id, std::string& code) { + auto it = named_.find(id); + if (it != named_.end()) { + return it->second.get(); // idempotent open / reuse across requests + } + if (max_named_ == 0) { + code = "unsupported_session"; + return nullptr; + } + if (static_cast(named_.size()) >= max_named_) { + code = "capacity_exhausted"; + return nullptr; + } + auto result = engine_.create_session(); + if (result.error() != ::executorch::runtime::Error::Ok) { + code = "capacity_exhausted"; // engine-side capacity backstop + return nullptr; + } + auto* session = result.get().get(); + named_.emplace(id, std::move(result.get())); + return session; + } + + // Destroy a named session (freeing its per-session state); idempotent. + void close_named(const std::string& id) { + named_.erase(id); + } + + // The scratch session for anonymous requests, created on first use. Throws if + // the engine cannot create it. + LLMSession* scratch() { + if (!scratch_) { + auto result = engine_.create_session(); + if (result.error() != ::executorch::runtime::Error::Ok) { + throw std::runtime_error("failed to create scratch session"); + } + scratch_ = std::move(result.get()); + } + return scratch_.get(); + } + + private: + LLMEngine& engine_; + int32_t max_named_; + std::unordered_map> named_; + std::unique_ptr scratch_; +}; + +// Emit {"ready": true, ...}, then read JSONL requests from stdin and dispatch +// each (generate / open / close), reporting exceptions as {"error": ...} and // continuing to serve. Returns 0 when stdin closes. inline int run_worker_stdio_loop( - LLMSession& session, + LLMEngine& engine, ::tokenizers::Tokenizer& tokenizer, const std::unordered_map& metadata) { - worker_emit({{"ready", true}}); + WorkerSessions sessions(engine); + worker_emit( + {{"ready", true}, + {"max_sessions", + engine.serving_capacity() + .max_physical_sessions_without_weight_duplication}, + {"max_named_sessions", sessions.max_named()}}); + std::string line; while (std::getline(std::cin, line)) { if (line.empty()) { continue; } try { - worker_handle_request( - session, tokenizer, metadata, nlohmann::json::parse(line)); + const nlohmann::json req = nlohmann::json::parse(line); + const std::string op = req.value("op", std::string{}); + + if (op == "open" || op == "close") { + const std::string id = req.at("session_id").get(); + if (id.empty()) { + throw std::runtime_error("session_id required for op"); + } + if (op == "close") { + sessions.close_named(id); + worker_emit({{"closed", true}, {"session_id", id}}); + continue; + } + std::string code; + if (sessions.open_named(id, code) == nullptr) { + worker_emit( + {{"error", "cannot open session"}, + {"code", code}, + {"session_id", id}}); + } else { + worker_emit({{"opened", true}, {"session_id", id}}); + } + continue; + } + + // Generation. A session_id routes to its named session (admitted on first + // use); its absence uses the shared scratch session. + const std::string id = req.value("session_id", std::string{}); + LLMSession* session = nullptr; + if (id.empty()) { + session = sessions.scratch(); + } else { + std::string code; + session = sessions.open_named(id, code); + if (session == nullptr) { + worker_emit( + {{"error", "cannot open session"}, + {"code", code}, + {"session_id", id}}); + continue; + } + } + worker_handle_request(*session, tokenizer, metadata, req); } catch (const std::exception& e) { // report and keep serving worker_emit({{"error", std::string(e.what())}}); } diff --git a/extension/llm/server/python/errors.py b/extension/llm/server/python/errors.py index f24df43f2e8..4da6ef5a6d1 100644 --- a/extension/llm/server/python/errors.py +++ b/extension/llm/server/python/errors.py @@ -60,3 +60,36 @@ def __init__(self, detail: str): super().__init__( status=500, message=f"Generation failed: {detail}", err_type="server_error" ) + + +class InvalidSessionId(APIError): + def __init__(self, detail: str): + super().__init__( + status=400, + message=f"Invalid session_id: {detail}", + err_type="invalid_request_error", + code="invalid_session_id", + ) + + +class SessionCapacity(APIError): + """A worker rejected an explicit session_id: the backend hosts no named + sessions (unsupported_session -> 400) or all session slots are taken + (capacity_exhausted -> 429).""" + + def __init__(self, code: str): + if code == "unsupported_session": + super().__init__( + status=400, + message="This server hosts a single session; omit session_id.", + err_type="invalid_request_error", + code=code, + ) + else: + super().__init__( + status=429, + message="Session capacity exhausted; reuse an existing session_id " + "or retry later.", + err_type="capacity_error", + code="capacity_exhausted", + ) diff --git a/extension/llm/server/python/protocol.py b/extension/llm/server/python/protocol.py index 2d73d2d7f64..e383b06c2af 100644 --- a/extension/llm/server/python/protocol.py +++ b/extension/llm/server/python/protocol.py @@ -73,6 +73,11 @@ class ChatCompletionRequest(BaseModel): parallel_tool_calls: Optional[bool] = None # Per-request chat-template controls, e.g. {"enable_thinking": false} for Qwen3. chat_template_kwargs: Optional[dict[str, Any]] = None + # Vendor extension: route this request to a persistent, isolated session (its + # own KV/recurrent context) on a multi-session worker; requests sharing a + # session_id continue the same context. Anonymous (a transient scratch + # session) when unset. Also accepted via the X-ExecuTorch-Session-ID header. + session_id: Optional[str] = None # Accepted now so the contract is stable; parsing/enforcement land in M2/M5. tools: Optional[list[dict[str, Any]]] = None tool_choice: Optional[Union[str, dict[str, Any]]] = None diff --git a/extension/llm/server/python/server.py b/extension/llm/server/python/server.py index 94c55479275..454b7d87d44 100644 --- a/extension/llm/server/python/server.py +++ b/extension/llm/server/python/server.py @@ -30,7 +30,9 @@ import os from pathlib import Path -from fastapi import FastAPI +from typing import Optional + +from fastapi import FastAPI, Header from fastapi.responses import JSONResponse, StreamingResponse from .chat_template import ChatTemplate @@ -75,6 +77,21 @@ def _spawn(args): return spawn_worker(cmd, env=env) +def _resolve_session_id( + req: ChatCompletionRequest, + x_executorch_session_id: Optional[str], + session_id_header: Optional[str], + x_session_affinity: Optional[str], +) -> Optional[str]: + # Session id precedence: body field wins, else the X-ExecuTorch-Session-ID / + # session_id / x-session-affinity headers (in that order). Aliases let clients + # that already emit a stable per-conversation id for cache affinity (e.g. pi's + # sendSessionAffinityHeaders) route to a session with no extra config. + if req.session_id is not None: + return req.session_id + return x_executorch_session_id or session_id_header or x_session_affinity + + def build_app(serving: ServingChat, model_id: str) -> FastAPI: app = FastAPI(title="ExecuTorch LLM Server") @@ -87,10 +104,22 @@ async def list_models() -> ModelList: return ModelList(data=[ModelCard(id=model_id)]) @app.post("/v1/chat/completions") - async def chat_completions(req: ChatCompletionRequest): + async def chat_completions( + req: ChatCompletionRequest, + # FastAPI dependency: the Header() call in the default is required. + # `session_id` is matched verbatim (underscore). + x_executorch_session_id: Optional[str] = Header(default=None), # noqa: B008 + session_id_header: Optional[str] = Header( # noqa: B008 + default=None, alias="session_id" + ), + x_session_affinity: Optional[str] = Header(default=None), # noqa: B008 + ): # Typed param → FastAPI validates the body and returns 422 on bad input. # APIError (e.g. context_length_exceeded) → structured 4xx/5xx, never a # dropped connection. Mid-stream failures are handled inside the stream. + req.session_id = _resolve_session_id( + req, x_executorch_session_id, session_id_header, x_session_affinity + ) try: result = await serving.create(req) except APIError as e: @@ -99,6 +128,15 @@ async def chat_completions(req: ChatCompletionRequest): return StreamingResponse(result, media_type="text/event-stream") return JSONResponse(result.model_dump(exclude_none=True)) + @app.delete("/v1/sessions/{session_id}") + async def close_session(session_id: str): + # Free a named session's state + capacity slot (vendor extension; idempotent). + try: + await serving.close_session(session_id) + except APIError as e: + return JSONResponse(e.body(), status_code=e.status) + return JSONResponse({"closed": True, "session_id": session_id}) + return app diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index deea41085e2..711f7ed5460 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -15,7 +15,13 @@ from typing import AsyncIterator, Optional from .chat_template import ChatTemplate -from .errors import APIError, ContextLengthExceeded, GenerationError +from .errors import ( + APIError, + ContextLengthExceeded, + GenerationError, + InvalidSessionId, + SessionCapacity, +) from .protocol import ( _new_id, ChatCompletionChunk, @@ -31,9 +37,12 @@ ) from .session_runtime import GenerationOptions, GenStats, PromptInput, SessionRuntime from .tool_parsers import HermesDetector, ToolCallItem +from .worker_client import WorkerError logger = logging.getLogger(__name__) +_SESSION_ID_MAX_LEN = 128 + def _earliest_stop(text: str, stops: list[str]) -> Optional[int]: """Index of the earliest special-token occurrence in `text`, or None.""" @@ -198,6 +207,31 @@ def _options(self, req: ChatCompletionRequest) -> GenerationOptions: stop=self._stops + self._request_stops(req), ) + @staticmethod + def _validate_session_id(session_id: str) -> None: + # Keep it boring: non-empty printable ASCII (no spaces/control), <=128. + if not session_id or len(session_id) > _SESSION_ID_MAX_LEN: + raise InvalidSessionId(f"must be 1-{_SESSION_ID_MAX_LEN} characters") + if not all(0x21 <= ord(c) <= 0x7E for c in session_id): + raise InvalidSessionId("must be printable ASCII with no spaces") + + async def _preflight_session(self, session_id: str) -> None: + """Reserve the session before any response bytes are emitted so a + capacity refusal becomes an HTTP status, not an SSE error event.""" + try: + await self._runtime.open(session_id) + except WorkerError as e: + if e.code in ("capacity_exhausted", "unsupported_session"): + raise SessionCapacity(e.code) + raise GenerationError(str(e)) + + async def close_session(self, session_id: str) -> None: + self._validate_session_id(session_id) + try: + await self._runtime.close(session_id) + except WorkerError as e: + raise GenerationError(str(e)) + def _finish_reason( self, req: ChatCompletionRequest, @@ -292,6 +326,8 @@ def _reject_unsupported_params(req: ChatCompletionRequest) -> None: async def create(self, req: ChatCompletionRequest): self._reject_invalid_values(req) self._reject_unsupported_params(req) + if req.session_id is not None: + self._validate_session_id(req.session_id) # tool_choice="none" must hide tools from the model: if we still render # the tool schemas, the model can emit a that we'd surface as # plain text (parsing is disabled), instead of a normal answer. @@ -314,6 +350,10 @@ async def create(self, req: ChatCompletionRequest): raise ContextLengthExceeded(count, self._max_context, requested) options = self._options(req) prompt_input = PromptInput(text=prompt) + # Admit the session up front (before the stream's first chunk) so a + # capacity refusal is an HTTP status, not a mid-stream error event. + if req.session_id is not None: + await self._preflight_session(req.session_id) if req.stream: return self._stream(req, prompt_input, options) return await self._complete(req, prompt_input, options) @@ -329,7 +369,7 @@ async def _complete( # Collect raw text (markers intact for tool parsing), halting early # at a stop boundary (special token or request stop). text, stopped = await self._collect_until_stop( - self._runtime.generate_stream(None, prompt, options, stats), + self._runtime.generate_stream(req.session_id, prompt, options, stats), self._stops + self._request_stops(req), ) except Exception as e: # noqa: BLE001 - surface as a structured API error @@ -386,7 +426,9 @@ def chunk(delta: DeltaMessage, finish=None) -> str: # Halt early at a stop boundary, and bound the raw output # BEFORE parsing so post-stop tool calls / text don't leak. raw, stop_hit[0] = await self._collect_until_stop( - self._runtime.generate_stream(None, prompt, options, stats), + self._runtime.generate_stream( + req.session_id, prompt, options, stats + ), stops, ) tool_calls, content = self._extract_tools( @@ -400,7 +442,9 @@ def on_stop(): self._runtime.stop() async for token in self._clean( - self._runtime.generate_stream(None, prompt, options, stats), + self._runtime.generate_stream( + req.session_id, prompt, options, stats + ), stops, on_stop=on_stop, ): diff --git a/extension/llm/server/python/tests/test_sessions.py b/extension/llm/server/python/tests/test_sessions.py new file mode 100644 index 00000000000..e5697abd7d2 --- /dev/null +++ b/extension/llm/server/python/tests/test_sessions.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Session-routing contract tests (fake worker, no model/GPU). + +V2a: one worker hosts multiple isolated sessions, routed by session_id, admitted +up front so capacity refusals are HTTP statuses rather than mid-stream errors. +These assert the HTTP/wire contract only. +""" + + +def _chat(client, *, session_id=None, headers=None): + body = {"model": "test-model", "messages": [{"role": "user", "content": "hi"}]} + if session_id is not None: + body["session_id"] = session_id + return client.post("/v1/chat/completions", json=body, headers=headers or {}) + + +def test_session_id_routed_and_opened(make_client): + client, fake = make_client(max_named_sessions=2) + resp = _chat(client, session_id="abc") + assert resp.status_code == 200 + # Admitted before generation, and forwarded to the worker's generate(). + assert fake.opened_log == ["abc"] + assert fake.captured_config.session_id == "abc" + + +def test_reusing_session_id_is_idempotent(make_client): + client, fake = make_client(max_named_sessions=1) + assert _chat(client, session_id="s").status_code == 200 + assert _chat(client, session_id="s").status_code == 200 + # Same id reused, not re-admitted into a second slot. + assert fake.opened_log == ["s"] + + +def test_anonymous_request_does_not_open_named(make_client): + client, fake = make_client(max_named_sessions=2) + assert _chat(client).status_code == 200 + assert fake.opened_log == [] + assert fake.captured_config.session_id is None + + +def test_explicit_session_unsupported_when_single_session(make_client): + # max_named_sessions=0: backend hosts only the scratch session. + client, _ = make_client(max_named_sessions=0) + resp = _chat(client, session_id="abc") + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "unsupported_session" + + +def test_capacity_exhausted_returns_429(make_client): + client, _ = make_client(max_named_sessions=1) + assert _chat(client, session_id="a").status_code == 200 + resp = _chat(client, session_id="b") # second distinct id, no free slot + assert resp.status_code == 429 + assert resp.json()["error"]["code"] == "capacity_exhausted" + + +def test_close_frees_a_slot(make_client): + client, _ = make_client(max_named_sessions=1) + assert _chat(client, session_id="a").status_code == 200 + deleted = client.delete("/v1/sessions/a") + assert deleted.status_code == 200 + assert deleted.json() == {"closed": True, "session_id": "a"} + # The freed slot now admits a different session. + assert _chat(client, session_id="b").status_code == 200 + + +def test_invalid_session_id_rejected_before_worker(make_client): + client, fake = make_client(max_named_sessions=2) + resp = _chat(client, session_id="has space") + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "invalid_session_id" + assert fake.opened_log == [] # never reached the worker + + +def test_session_id_too_long_rejected(make_client): + client, _ = make_client(max_named_sessions=2) + resp = _chat(client, session_id="x" * 129) + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "invalid_session_id" + + +def test_session_id_header_alias(make_client): + client, fake = make_client(max_named_sessions=2) + resp = _chat(client, headers={"X-ExecuTorch-Session-ID": "from-header"}) + assert resp.status_code == 200 + assert fake.opened_log == ["from-header"] + + +def test_body_session_id_wins_over_header(make_client): + client, fake = make_client(max_named_sessions=2) + resp = _chat( + client, session_id="from-body", headers={"X-ExecuTorch-Session-ID": "from-hdr"} + ) + assert resp.status_code == 200 + assert fake.opened_log == ["from-body"] + + +def test_session_id_underscore_header_alias(make_client): + # pi's sendSessionAffinityHeaders emits a verbatim `session_id` header. + client, fake = make_client(max_named_sessions=2) + resp = _chat(client, headers={"session_id": "from-session-id"}) + assert resp.status_code == 200 + assert fake.opened_log == ["from-session-id"] + + +def test_x_session_affinity_header_alias(make_client): + client, fake = make_client(max_named_sessions=2) + resp = _chat(client, headers={"x-session-affinity": "from-affinity"}) + assert resp.status_code == 200 + assert fake.opened_log == ["from-affinity"] + + +def test_session_header_precedence(make_client): + # X-ExecuTorch-Session-ID > session_id > x-session-affinity (no body field). + client, fake = make_client(max_named_sessions=3) + resp = _chat( + client, + headers={ + "X-ExecuTorch-Session-ID": "xet", + "session_id": "sid", + "x-session-affinity": "aff", + }, + ) + assert resp.status_code == 200 + assert fake.opened_log == ["xet"] diff --git a/extension/llm/server/python/tests/test_worker_client.py b/extension/llm/server/python/tests/test_worker_client.py index dbed8d396f3..b461785036f 100644 --- a/extension/llm/server/python/tests/test_worker_client.py +++ b/extension/llm/server/python/tests/test_worker_client.py @@ -63,6 +63,7 @@ class _Cfg: max_new_tokens: int = 64 temperature: float = 0.0 stop: list = field(default_factory=list) + session_id: str = None def _lines(*objs): @@ -143,12 +144,44 @@ def test_generate_on_dead_worker_raises(): WorkerClient(proc).generate("hi", _Cfg()) +def test_generate_includes_session_id_when_set(): + proc = _FakeProc(_lines({"done": True, "prompt_tokens": 1, "completion_tokens": 0})) + WorkerClient(proc).generate("hi", _Cfg(session_id="abc")) + assert json.loads(proc.stdin.written[0])["session_id"] == "abc" + + +def test_generate_omits_session_id_when_unset(): + proc = _FakeProc(_lines({"done": True, "prompt_tokens": 1, "completion_tokens": 0})) + WorkerClient(proc).generate("hi", _Cfg()) + assert "session_id" not in json.loads(proc.stdin.written[0]) + + +def test_open_session_sends_op_and_acks(): + proc = _FakeProc(_lines({"opened": True, "session_id": "abc"})) + WorkerClient(proc).open_session("abc") + assert json.loads(proc.stdin.written[0]) == {"op": "open", "session_id": "abc"} + + +def test_open_session_capacity_error_carries_code(): + proc = _FakeProc(_lines({"error": "full", "code": "capacity_exhausted"})) + with pytest.raises(WorkerError) as ei: + WorkerClient(proc).open_session("abc") + assert ei.value.code == "capacity_exhausted" + + +def test_close_session_sends_op_and_acks(): + proc = _FakeProc(_lines({"closed": True, "session_id": "abc"})) + WorkerClient(proc).close_session("abc") + assert json.loads(proc.stdin.written[0]) == {"op": "close", "session_id": "abc"} + + def test_spawn_worker_waits_for_ready(): - proc = _FakeProc(_lines({"ready": True})) + proc = _FakeProc(_lines({"ready": True, "max_named_sessions": 3})) client = spawn_worker( ["/fake/worker", "--model_path", "m"], popen=lambda *a, **k: proc ) assert isinstance(client, WorkerClient) + assert client.max_named_sessions == 3 def test_spawn_worker_not_ready_raises():