diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 0583765cb77..899c816e859 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -243,6 +243,16 @@ Each `done` event reports (`new`/`exact_prefix`/`dirty`/`mismatch`/`equal`) for measuring the hit rate. `--no-warm-resume` forces a full prefill every request (for A/B comparison). +**Tool-call turns (token-ID continuation):** an assistant turn re-rendered from +its parsed tool call rarely re-tokenizes to the tokens the model actually +generated, so plain warm resume misses on agent loops. The server stores the +exact generated token ids per session and, on the next turn, sends the prompt as +segments (`{"text"}` / `{"ids"}`) that splice those ids back in for prior +assistant turns instead of re-rendering them — so the resident state stays an +exact token prefix and resume hits. Tool *results* remain text (re-tokenized +deterministically). The worker's exact-token check still backstops everything, so +a mismatch just falls back to a full prefill. + This is **isolation + warm resume, not concurrency**: execution is still synchronous (one in-flight request; `--num-runners > 1` is rejected since more workers would duplicate the weights). Fair interleaving across in-flight requests diff --git a/extension/llm/server/cpp/CMakeLists.txt b/extension/llm/server/cpp/CMakeLists.txt index 18f62cfcd5f..641cb55b0f8 100644 --- a/extension/llm/server/cpp/CMakeLists.txt +++ b/extension/llm/server/cpp/CMakeLists.txt @@ -95,3 +95,14 @@ target_include_directories( test_worker_prefill_plan PUBLIC ${_common_include_directories} ) add_test(NAME worker_prefill_plan COMMAND test_worker_prefill_plan) + +# Worker-loop harness (worker_handle_request + WorkerSessions) driven by a +# scriptable fake LLMSession/Tokenizer/LLMEngine -- no model/GPU. It includes +# the full worker_loop.h, so it needs the JSON include + the runtime/tokenizer +# libs. +add_executable(test_worker_loop test_worker_loop.cpp) +target_include_directories( + test_worker_loop PUBLIC ${_common_include_directories} ${_json_include} +) +target_link_libraries(test_worker_loop PUBLIC ${link_libraries}) +add_test(NAME worker_loop COMMAND test_worker_loop) diff --git a/extension/llm/server/cpp/test_worker_loop.cpp b/extension/llm/server/cpp/test_worker_loop.cpp new file mode 100644 index 00000000000..a14d6518681 --- /dev/null +++ b/extension/llm/server/cpp/test_worker_loop.cpp @@ -0,0 +1,444 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Hermetic tests for worker_loop.h (worker_handle_request + WorkerSessions), +// the highest-risk serving logic. A scriptable fake LLMSession / Tokenizer / +// LLMEngine drives the real loop with NO model, tokenizer, or GPU. worker_emit +// writes to std::cout, so each test captures stdout and parses the JSON events. +// Self-contained assertions (no gtest) to match test_worker_prefill_plan. + +#include + +#include +#include +#include +#include +#include +#include + +using executorch::extension::llm::DecodeResult; +using executorch::extension::llm::LLMEngine; +using executorch::extension::llm::LLMServingCapacity; +using executorch::extension::llm::LLMSession; +using executorch::extension::llm::SamplingConfig; +using executorch::extension::llm::worker_handle_request; +using executorch::extension::llm::WorkerSessions; +using executorch::extension::llm::WorkerSessionState; +using ETError = ::executorch::runtime::Error; +template +using ETResult = ::executorch::runtime::Result; + +namespace { +int g_failures = 0; + +void check(const char* name, bool ok) { + printf(" [%s] %s\n", ok ? "PASS" : "FAIL", name); + if (!ok) { + ++g_failures; + } +} + +// ---- Fake LLMSession: scriptable decode stream + injectable failures -------- +class FakeSession : public LLMSession { + public: + struct Step { + uint64_t id; + std::string piece; + bool is_eos; + bool is_terminal; + }; + std::vector steps; + size_t step_i = 0; + int64_t pos = 0; // models the session's KV position + + int prefill_calls = 0; + std::vector prefill_sizes; // size of each prefill_tokens() call + int fail_prefill_on = -1; // 0-based call index to fail (-1 = never) + int decode_calls = 0; + int fail_decode_on = -1; + int reset_calls = 0; + bool fail_reset = false; + + ETError prefill_tokens( + std::vector tokens, + const SamplingConfig* /*initial_sampling*/ = nullptr) override { + prefill_sizes.push_back(tokens.size()); + if (prefill_calls++ == fail_prefill_on) { + return ETError::Internal; // failed AFTER (notionally) mutating state + } + pos += static_cast(tokens.size()); + return ETError::Ok; + } + + ETResult decode_one(const SamplingConfig& /*s*/) override { + if (decode_calls++ == fail_decode_on) { + return ETError::Internal; + } + if (step_i >= steps.size()) { + return DecodeResult{0, "", true, true}; // default: EOS/terminal + } + const Step s = steps[step_i++]; + if (!s.is_terminal) { + pos += 1; // a forwarded token advances the cache position + } + return DecodeResult{s.id, s.piece, s.is_eos, s.is_terminal}; + } + + ETError seek(int64_t /*pos*/) override { + return ETError::NotSupported; + } + int64_t position() const override { + return pos; + } + ETError reset() override { + ++reset_calls; + if (fail_reset) { + return ETError::Internal; + } + pos = 0; + step_i = 0; + return ETError::Ok; + } + void stop() override {} +}; + +// ---- Fake Tokenizer: only needed to satisfy the signature; tests use {ids} +// segments so encode() is not exercised on the hot paths. ------------------- +class FakeTokenizer : public ::tokenizers::Tokenizer { + public: + ::tokenizers::Error load(const std::string&) override { + initialized_ = true; + return ::tokenizers::Error::Ok; + } + ::tokenizers::Result> encode( + const std::string& input, + int8_t /*bos*/ = 0, + int8_t /*eos*/ = 0) const override { + std::vector + out; // 1 id per byte (deterministic; unused by ids tests) + for (unsigned char c : input) { + out.push_back(static_cast(c)); + } + return out; + } + ::tokenizers::Result decode( + uint64_t /*prev*/, + uint64_t /*token*/, + bool /*skip_special_tokens*/ = false) const override { + return std::string(""); + } + ::tokenizers::Result id_to_piece(uint64_t /*t*/) const override { + return std::string(""); + } + ::tokenizers::Result piece_to_id( + const std::string& /*t*/) const override { + return static_cast(0); + } + bool is_loaded() const override { + return true; + } +}; + +class FakeEngine : public LLMEngine { + public: + int32_t capacity = 4; + ETResult> create_session() override { + return std::unique_ptr(new FakeSession()); + } + LLMServingCapacity serving_capacity() const override { + return LLMServingCapacity{capacity, 0}; + } + const std::unordered_map& metadata() const override { + return md_; + } + + private: + std::unordered_map md_; +}; + +// ---- stdout-capturing driver ------------------------------------------------ +struct Emitted { + std::string text; // concatenated {"token": ...} pieces + nlohmann::json done; // the {"done": true, ...} event + int token_events = 0; + bool threw = false; +}; + +Emitted run( + WorkerSessionState& st, + bool warm, + const nlohmann::json& req, + const std::unordered_map& md = {}) { + static FakeTokenizer tok; + std::ostringstream cap; + std::streambuf* old = std::cout.rdbuf(cap.rdbuf()); + Emitted em; + try { + worker_handle_request(st, warm, tok, md, req); + } catch (const std::exception&) { + em.threw = true; + } + std::cout.rdbuf(old); + std::istringstream iss(cap.str()); + std::string line; + while (std::getline(iss, line)) { + if (line.empty()) { + continue; + } + auto j = nlohmann::json::parse(line); + if (j.contains("token")) { + em.text += j["token"].get(); + ++em.token_events; + } + if (j.contains("done")) { + em.done = j; + } + } + return em; +} + +WorkerSessionState makeState() { + WorkerSessionState st; + st.session.reset(new FakeSession()); + return st; +} +FakeSession& fake(WorkerSessionState& st) { + return *static_cast(st.session.get()); +} +nlohmann::json idsReq(std::vector ids, int64_t max_new = 8) { + return {{"max_new_tokens", max_new}, {"prompt_segments", {{{"ids", ids}}}}}; +} + +void test_new_full_prefill() { + auto st = makeState(); + fake(st).steps = {{10, "a", false, false}, {0, "", true, true}}; + auto em = run(st, /*warm=*/true, idsReq({1, 2, 3})); + check("new: reason=new", em.done["session_reset_reason"] == "new"); + check("new: reset called once", fake(st).reset_calls == 1); + check( + "new: full prefill (3)", + fake(st).prefill_sizes == std::vector{3}); + check( + "new: reused=0 prefilled=3", + em.done["reused_prompt_tokens"] == 0 && + em.done["prefilled_prompt_tokens"] == 3); + check( + "new: resident.size()==position()", + st.resident_token_ids.size() == (size_t)st.session->position()); +} + +void test_exact_prefix_warm_suffix() { + auto st = makeState(); + // First turn establishes resident [1,2]. + fake(st).steps = {{0, "", true, true}}; + run(st, true, idsReq({1, 2})); + size_t resets_after_first = fake(st).reset_calls; + fake(st).steps = {{0, "", true, true}}; + fake(st).prefill_sizes.clear(); + // Second turn extends to [1,2,3] -> warm suffix prefill of just [3]. + auto em = run(st, true, idsReq({1, 2, 3})); + check( + "warm: reason=exact_prefix", + em.done["session_reset_reason"] == "exact_prefix"); + check( + "warm: prefill suffix only ([3])", + fake(st).prefill_sizes == std::vector{1}); + check( + "warm: reused=2 prefilled=1", + em.done["reused_prompt_tokens"] == 2 && + em.done["prefilled_prompt_tokens"] == 1); + check("warm: no extra reset", fake(st).reset_calls == resets_after_first); + check( + "warm: resident.size()==position()", + st.resident_token_ids.size() == (size_t)st.session->position()); +} + +void test_mismatch_full_reset() { + auto st = makeState(); + fake(st).steps = {{0, "", true, true}}; + run(st, true, idsReq({1, 2})); + fake(st).steps = {{0, "", true, true}}; + fake(st).prefill_sizes.clear(); + auto em = run(st, true, idsReq({1, 9})); // divergent token + check( + "mismatch: reason=mismatch", + em.done["session_reset_reason"] == "mismatch"); + check( + "mismatch: full prefill (2)", + fake(st).prefill_sizes == std::vector{2}); +} + +void test_equal_prompt_no_empty_prefill() { + auto st = makeState(); + fake(st).steps = {{0, "", true, true}}; + run(st, true, idsReq({1, 2, 3})); + fake(st).steps = {{0, "", true, true}}; + fake(st).prefill_sizes.clear(); + auto em = run(st, true, idsReq({1, 2, 3})); // identical prompt + check("equal: reason=equal", em.done["session_reset_reason"] == "equal"); + bool any_empty = false; + for (size_t s : fake(st).prefill_sizes) { + any_empty = any_empty || (s == 0); + } + check("equal: prefill_tokens never called with []", !any_empty); + check( + "equal: full reprefill (3)", + fake(st).prefill_sizes == std::vector{3}); +} + +void test_anonymous_never_warm() { + auto st = makeState(); + fake(st).steps = {{0, "", true, true}}; + run(st, /*warm=*/false, idsReq({1, 2})); + fake(st).steps = {{0, "", true, true}}; + fake(st).prefill_sizes.clear(); + // Even though resident now matches a prefix, warm=false forces a full reset. + auto em = run(st, /*warm=*/false, idsReq({1, 2, 3})); + check( + "scratch: reason=new (warm disabled)", + em.done["session_reset_reason"] == "new"); + check( + "scratch: full prefill (3)", + fake(st).prefill_sizes == std::vector{3}); +} + +void test_generated_token_ids_excludes_terminal() { + auto st = makeState(); + fake(st).steps = { + {10, "a", false, false}, {11, "b", false, false}, {0, "", true, true}}; + auto em = run(st, true, idsReq({1, 2})); + check("genids: text=ab", em.text == "ab"); + check("genids: completion_tokens=2", em.done["completion_tokens"] == 2); + std::vector ids = + em.done["generated_token_ids"].get>(); + check( + "genids: ==[10,11] (terminal EOS excluded)", + ids == std::vector{10, 11}); + check("genids: finish=stop (EOS)", em.done["finish_reason"] == "stop"); + check( + "genids: resident.size()==position()", + st.resident_token_ids.size() == (size_t)st.session->position()); +} + +void test_stop_string_marks_dirty_and_omits_ids() { + auto st = makeState(); + fake(st).steps = { + {10, "a", false, false}, + {11, "b", false, false}, + {12, "X", false, false}}; + nlohmann::json req = idsReq({1, 2}); + req["stop"] = {"X"}; + auto em = run(st, true, req); + check("stop: text=ab (stop trimmed)", em.text == "ab"); + check("stop: finish=stop", em.done["finish_reason"] == "stop"); + check( + "stop: no generated_token_ids", !em.done.contains("generated_token_ids")); + check("stop: session marked dirty", st.dirty); +} + +void test_prefill_failure_marks_dirty() { + auto st = makeState(); + fake(st).fail_prefill_on = 0; + auto em = run(st, true, idsReq({1, 2, 3})); + check("prefill-fail: threw", em.threw); + check("prefill-fail: dirty", st.dirty); +} + +void test_decode_failure_marks_dirty() { + auto st = makeState(); + fake(st).fail_decode_on = 0; + auto em = run(st, true, idsReq({1, 2, 3})); + check("decode-fail: threw", em.threw); + check("decode-fail: dirty", st.dirty); +} + +void test_utf8_split_across_pieces_emits_once_intact() { + auto st = makeState(); + // "é" = 0xC3 0xA9, split across two decode pieces; must emit once, intact. + fake(st).steps = { + {10, std::string("\xC3"), false, false}, + {11, std::string("\xA9"), false, false}, + {0, "", true, true}}; + auto em = run(st, true, idsReq({1})); + check( + "utf8: emitted bytes == C3 A9 intact", + em.text == std::string("\xC3\xA9")); + check("utf8: not emitted as a partial first byte", em.token_events == 1); +} + +void test_stop_straddles_pieces() { + auto st = makeState(); + // stop "ab" arrives across two pieces "a","b": nothing should be emitted. + fake(st).steps = { + {10, "a", false, false}, + {11, "b", false, false}, + {12, "c", false, false}}; + nlohmann::json req = idsReq({1}); + req["stop"] = {"ab"}; + auto em = run(st, true, req); + check("stop-straddle: nothing emitted", em.text.empty()); + check("stop-straddle: finish=stop", em.done["finish_reason"] == "stop"); + check("stop-straddle: dirty", st.dirty); +} + +void test_reset_named_only_clears_on_success() { + FakeEngine engine; + WorkerSessions sessions(engine); + std::string code; + WorkerSessionState* st = sessions.open_named("s", code); + check("reset_named: session opened", st != nullptr); + if (st == nullptr) { + return; + } + st->resident_token_ids = {1, 2, 3}; + auto& s = *static_cast(st->session.get()); + + // Failed reset: must report error AND leave resident state intact (lockstep). + s.fail_reset = true; + ETError err = sessions.reset_named("s"); + check("reset_named: failed reset reports error", err != ETError::Ok); + check( + "reset_named: resident intact after failed reset", + st->resident_token_ids.size() == 3); + + // Successful reset: clears resident state. + s.fail_reset = false; + err = sessions.reset_named("s"); + check("reset_named: success reports Ok", err == ETError::Ok); + check( + "reset_named: resident cleared after success", + st->resident_token_ids.empty()); + + // Absent session is an idempotent no-op (Ok). + check( + "reset_named: absent id is Ok", + sessions.reset_named("nope") == ETError::Ok); +} + +} // namespace + +int main() { + printf("worker_loop.h harness:\n"); + test_new_full_prefill(); + test_exact_prefix_warm_suffix(); + test_mismatch_full_reset(); + test_equal_prompt_no_empty_prefill(); + test_anonymous_never_warm(); + test_generated_token_ids_excludes_terminal(); + test_stop_string_marks_dirty_and_omits_ids(); + test_prefill_failure_marks_dirty(); + test_decode_failure_marks_dirty(); + test_utf8_split_across_pieces_emits_once_intact(); + test_stop_straddles_pieces(); + test_reset_named_only_clears_on_success(); + printf( + "\n%s (%d failure(s))\n", + g_failures ? "FAILURES" : "ALL PASS", + g_failures); + return g_failures ? 1 : 0; +} diff --git a/extension/llm/server/cpp/worker_loop.h b/extension/llm/server/cpp/worker_loop.h index 7f92e60371e..f580d21d356 100644 --- a/extension/llm/server/cpp/worker_loop.h +++ b/extension/llm/server/cpp/worker_loop.h @@ -42,8 +42,11 @@ // worker -> stdout, once: {"ready": true, "max_sessions": int, // "max_named_sessions": int} // client -> stdin: -// generate: {"prompt": str, "max_new_tokens": int, "temperature": float, -// "stop": [str, ...], "session_id"?: str} +// generate: {"max_new_tokens": int, "temperature": float, +// "stop": [str, ...], "session_id"?: str, +// and exactly one prompt form: +// "prompt": str +// "prompt_segments": [{"text": str} | {"ids": [int, ...]}]} // open: {"op": "open", "session_id": str} // close: {"op": "close", "session_id": str} // reset: {"op": "reset", "session_id": str} // clear context, keep @@ -55,7 +58,9 @@ // "finish_reason": "stop"|"length", // "reused_prompt_tokens": int, "prefilled_prompt_tokens": int, // "session_reset_reason": "new"|"exact_prefix"|"dirty"| -// "mismatch"|"equal"} +// "mismatch"|"equal", +// "generated_token_ids"?: [int, ...]} // omitted if +// stop-trimmed // open: {"opened": true, "session_id": str} // close: {"closed": true, "session_id": str} // reset: {"reset": true, "session_id": str} @@ -122,7 +127,6 @@ inline void worker_handle_request( const std::unordered_map& metadata, const nlohmann::json& req) { LLMSession& session = *st.session; - const std::string prompt = req.at("prompt").get(); int64_t max_new = req.value("max_new_tokens", static_cast(-1)); const float temperature = req.value("temperature", 0.0f); // Stop strings (the request's `stop` sequences): terminate at the token @@ -131,13 +135,43 @@ inline void worker_handle_request( const std::vector stops = req.value("stop", std::vector{}); - // No special tokens: the prompt is already rendered (the control plane - // applied the chat template), matching the runner's own encode path. - auto encode_result = tokenizer.encode(prompt, /*bos=*/0, /*eos=*/0); - if (!encode_result.ok()) { - throw std::runtime_error("prompt encode failed"); + // The prompt is either a single rendered string ("prompt") or an ordered list + // of segments ("prompt_segments"), each a {"text": ...} chunk to tokenize or + // a + // {"ids": [...]} run of literal token ids. Segments let the control plane + // splice the exact generated token ids of prior assistant turns back in, + // instead of re-tokenizing the chat template's lossy re-rendering of them (so + // warm resume can hit on tool-use turns). Text is encoded with no special + // tokens (already rendered), matching the runner's own encode path. + const bool has_prompt = req.contains("prompt"); + const bool has_segments = req.contains("prompt_segments"); + if (has_prompt == has_segments) { + throw std::runtime_error( + "exactly one of prompt / prompt_segments is required"); + } + std::vector ids; + auto encode_text = [&](const std::string& text) { + auto enc = tokenizer.encode(text, /*bos=*/0, /*eos=*/0); + if (!enc.ok()) { + throw std::runtime_error("prompt encode failed"); + } + ids.insert(ids.end(), enc->begin(), enc->end()); + }; + if (has_segments) { + for (const auto& seg : req.at("prompt_segments")) { + if (seg.contains("ids")) { + for (const auto& id : seg.at("ids")) { + ids.push_back(id.get()); + } + } else if (seg.contains("text")) { + encode_text(seg.at("text").get()); + } else { + throw std::runtime_error("prompt_segment needs `text` or `ids`"); + } + } + } else { + encode_text(req.at("prompt").get()); } - std::vector ids = std::move(*encode_result); if (ids.empty()) { throw std::runtime_error("empty prompt"); } @@ -249,14 +283,27 @@ inline void worker_handle_request( // "length" -- it ran to max_new (possibly clamped to the context window). // reused/prefilled sum to prompt_tokens; session_reset_reason explains the // prefill plan (for measuring warm-resume hit rate). - worker_emit( - {{"done", true}, - {"prompt_tokens", num_prompt}, - {"completion_tokens", num_generated}, - {"finish_reason", finish}, - {"reused_prompt_tokens", reused}, - {"prefilled_prompt_tokens", prefilled}, - {"session_reset_reason", plan.reason}}); + nlohmann::json done = { + {"done", true}, + {"prompt_tokens", num_prompt}, + {"completion_tokens", num_generated}, + {"finish_reason", finish}, + {"reused_prompt_tokens", reused}, + {"prefilled_prompt_tokens", prefilled}, + {"session_reset_reason", plan.reason}}; + // generated_token_ids = the (non-terminal) tokens made resident this turn, + // for the control plane to splice back as an `ids` segment. Only emit them + // when they faithfully decode to the emitted text: a stop-string trim kept + // the post-stop tokens resident but dropped them from the output, so splicing + // them would inject text the client never saw. Omitting them makes the + // control plane record this turn as not resumable (falls back to a text + // re-render). + if (!stop_string) { + done["generated_token_ids"] = std::vector( + st.resident_token_ids.end() - num_generated, + st.resident_token_ids.end()); + } + worker_emit(done); } // Owns the engine's sessions for one worker: named sessions keyed by id plus a diff --git a/extension/llm/server/python/README.md b/extension/llm/server/python/README.md index 7e6e3bd468e..f0c1003d009 100644 --- a/extension/llm/server/python/README.md +++ b/extension/llm/server/python/README.md @@ -137,6 +137,7 @@ does blocking pipe I/O on its executor thread. | `chat_template.py` | messages (+tools) → prompt string | | `worker_client.py` | spawn a worker process + drive it over JSONL (raw transport) | | `session_runtime.py` | stateful runtime over one worker: open/generate/reset/close + streaming bridge | +| `openai_transcript.py` | OpenAI token-ID warm-resume state (fingerprints + sentinel splicing) | | `serving_chat.py` | `/v1/chat/completions` OpenAI adapter (streaming + non-streaming, stop, tools) | | `tool_parsers/` | Hermes/Qwen `` parser only | | `cpp/text_llm_worker.cpp` | the generic C++ worker binary | diff --git a/extension/llm/server/python/chat_template.py b/extension/llm/server/python/chat_template.py index cbb3eff80bf..1235f6fcf2c 100644 --- a/extension/llm/server/python/chat_template.py +++ b/extension/llm/server/python/chat_template.py @@ -91,6 +91,9 @@ def __init__( # Server-level defaults (e.g. {"enable_thinking": False}); per-request # chat_template_kwargs override these. self._defaults = default_template_kwargs or {} + # Cache of the (deterministic) generation scaffold per resolved mode, so + # warm-resume bookkeeping doesn't re-render a probe prompt every request. + self._preamble_cache: dict[tuple, str] = {} self._hf = None if hf_tokenizer_path: from transformers import AutoTokenizer @@ -136,6 +139,50 @@ def render( ) return self._fallback(messages) + def generation_preamble( + self, + template_kwargs: Optional[dict[str, Any]] = None, + tools: Optional[list[dict[str, Any]]] = None, + ) -> str: + """The deterministic text the generation prompt appends after the final + ``<|im_start|>assistant\\n`` for this mode (Qwen3 no-think: + ``\\n\\n\\n\\n``; think: ``\\n``; ``""`` for + templates that add no scaffold). The worker prefills this into resident + KV, so warm-resume splicing must reproduce it ahead of a turn's generated + ids. Computed by rendering a trivial prompt with the same mode resolution + AND ``tools`` as :meth:`render`, taking the text after the final assistant + header. ``tools`` is threaded (and keyed) so that if a template ever makes + the post-header scaffold tool-dependent, the stored preamble still matches + the resident one (for Qwen3 the scaffold is tool-independent -> same key). + Returns ``""`` for the fallback / no-scaffold templates (fix is a no-op). + """ + if self._hf is None: + return "" + merged = {**self._defaults, **(template_kwargs or {})} + if tools: + try: + tools_key = json.dumps( + tools, sort_keys=True, ensure_ascii=False, default=str + ) + except (TypeError, ValueError): + tools_key = repr(tools) + else: + tools_key = None + key = (tuple(sorted((k, repr(v)) for k, v in merged.items())), tools_key) + cached = self._preamble_cache.get(key) + if cached is not None: + return cached + rendered = self.render( + [ChatMessage(role="user", content="")], + tools=tools, + template_kwargs=template_kwargs, + ) + marker = "<|im_start|>assistant\n" + idx = rendered.rfind(marker) + preamble = rendered[idx + len(marker) :] if idx != -1 else "" + self._preamble_cache[key] = preamble + return preamble + def chat_template_str(self) -> Optional[str]: """Raw chat-template string (for tool-format auto-detection), if available.""" return ( diff --git a/extension/llm/server/python/openai_transcript.py b/extension/llm/server/python/openai_transcript.py new file mode 100644 index 00000000000..2eaff5fd7f0 --- /dev/null +++ b/extension/llm/server/python/openai_transcript.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI/chat-template transcript state for token-ID warm resume (V2b.1.5). + +This is the OpenAI-adapter-specific glue that makes warm resume work across the +chat template's lossy re-render of prior assistant turns (especially tool calls, +which re-render from parsed structure and don't re-tokenize to what the model +generated). It is NOT generic runtime infrastructure: it knows ChatMessages, +tool_calls, the ChatTemplate, sentinels, and assistant fingerprints. The runtime +(session_runtime) only sees PromptInput. + +Per session we keep one record per assistant turn we produced, in order: +{"fp": fingerprint of the response we returned, "ids": exact generated token ids +| None}. On the next request each prior assistant turn is replaced with a unique +sentinel, the conversation is rendered once, and the rendered text is split on +the sentinels with the stored ids spliced back in -- but only for turns whose +fingerprint matches the incoming message (so an edited/branched history, or a +session reused for another conversation, is never substituted with stale ids) +and whose ids are present (a stop-trimmed turn has None and is left as text). +Everything is backstopped by the worker's exact-token prefix check. +""" + +import hashlib +import json +import re +import uuid +from typing import Optional + +from .chat_template import ChatTemplate +from .protocol import ChatMessage +from .session_runtime import PromptInput + +# The assistant header that precedes a turn's generation scaffold + content. +_ASSIST_HDR = "<|im_start|>assistant\n" +# A scaffold region is exactly empty (history strips it before the last user) or +# one of the Qwen3 think scaffolds (history preserves the empty block after the +# last user; the open form is the think-mode generation preamble). Anything else +# in that region is unrecognized -> the splice falls back to plain text. +_THINK_SCAFFOLD_RE = re.compile(r"\A(?:\n\n\n\n|\n)?\Z") + + +def _normalize_tool_args(args): + """OpenAI tool-call ``arguments`` are JSON strings a client may reserialize + with different whitespace or key order while preserving the same value (e.g. + a server-emitted ``{"command": "x"}`` echoed back compact as + ``{"command":"x"}``). Parse to an object so the fingerprint compares the + semantic payload, not bytes -- the outer sort_keys dump then canonicalizes + it. A non-JSON string (or already-structured args) is returned unchanged, so + it stays byte-sensitive.""" + if isinstance(args, str): + try: + return json.loads(args) + except (ValueError, TypeError): + return args + return args + + +class OpenAITranscriptState: + def __init__(self, template: ChatTemplate): + self._template = template + # session_id -> [{"fp": str, "ids": list[int] | None}, ...] (one per + # assistant turn we produced, in order). Cleared on reset/close. + self._turns: dict[str, list[dict]] = {} + + @staticmethod + def _assistant_fingerprint(content, tool_calls) -> str: + """Stable fingerprint of an assistant turn's semantic payload (content + + each tool call's name/arguments; the random call id is ignored). Used to + confirm an incoming assistant message is the one we generated before + splicing its stored token ids.""" + norm = [] + for tc in tool_calls or []: + fn = getattr(tc, "function", None) + if fn is not None: + name, args = getattr(fn, "name", None), getattr(fn, "arguments", None) + elif isinstance(tc, dict): + f = tc.get("function", {}) + name, args = f.get("name"), f.get("arguments") + else: + continue + norm.append([name, _normalize_tool_args(args)]) + blob = json.dumps([content or "", norm], sort_keys=True, ensure_ascii=False) + return hashlib.sha1(blob.encode("utf-8")).hexdigest() + + @staticmethod + def _normalize_scaffold(text_chunk: str, preamble: str) -> Optional[str]: + """Force the scaffold region -- the text between the last assistant header + in `text_chunk` and its end -- to equal `preamble`, so the worker + re-tokenizes the exact generation scaffold it made resident for this turn. + The region (the content was replaced by a sentinel) is empty when history + stripped the scaffold (insert) or a think scaffold when history preserved + it (replace, possibly with a different form than `preamble`). Returns the + adjusted text, or None if the region is not a recognized scaffold + (ambiguous -> caller falls back to plain text).""" + # No scaffold for this turn's mode/template: nothing to reproduce, so + # leave the chunk untouched -- and don't require the Qwen/ChatML header, + # so token-id splicing still works for templates with a different + # assistant header (the fix stays a true no-op for non-think models). + if not preamble: + return text_chunk + h = text_chunk.rfind(_ASSIST_HDR) + if h == -1: + return None + base = h + len(_ASSIST_HDR) + region = text_chunk[base:] + if region == preamble: + return text_chunk + if not _THINK_SCAFFOLD_RE.match(region): + return None + return text_chunk[:base] + preamble + + @staticmethod + def _split_on_sentinels( + rendered: str, sub: dict[str, dict] + ) -> Optional[list[dict]]: + """Split `rendered` on the sentinels into alternating {"text"} chunks and + {"ids"} runs (each sentinel -> sub[sentinel] = {"ids", "preamble"}). The + {text} chunk before each {ids} run has its assistant scaffold normalized + to that turn's stored preamble. Returns None if any pre-sentinel scaffold + region is ambiguous (caller falls back to plain text).""" + pattern = re.compile("|".join(re.escape(s) for s in sub)) + segments: list[dict] = [] + pos = 0 + for mobj in pattern.finditer(rendered): + norm = OpenAITranscriptState._normalize_scaffold( + rendered[pos : mobj.start()], sub[mobj.group()]["preamble"] + ) + if norm is None: + return None + if norm: + segments.append({"text": norm}) + segments.append({"ids": sub[mobj.group()]["ids"]}) + pos = mobj.end() + if pos < len(rendered): + segments.append({"text": rendered[pos:]}) + return segments + + def build_prompt_input( + self, + *, + session_id: Optional[str], + messages: list[ChatMessage], + rendered_prompt: str, + tools, + template_kwargs, + ) -> PromptInput: + """Return a PromptInput: token-ID segments when this session has faithful + stored ids for matching prior assistant turns, else the plain rendered + text. Each incoming assistant turn is matched IN ORDER against the stored + records and only spliced when (a) its fingerprint matches what we returned + (else the history diverged -> stop, splice nothing further) and (b) we + kept faithful ids for it (a stop-trimmed turn's None -> rendered as text). + Falls back to text on a sentinel collision or a render that + dropped/duplicated a sentinel.""" + stored = self._turns.get(session_id or "") + if not stored: + return PromptInput(text=rendered_prompt) + # ORDINAL ASSUMPTION: stored[k] is the k-th assistant turn WE generated + # for this session, matched positionally against the k-th assistant + # message in the request. A client-injected assistant turn we did not + # generate -- a few-shot exemplar, a pre-seeded turn, or any reused + # session -- shifts that alignment, so the fingerprint at k mismatches and + # we stop splicing from there. This is always SAFE (text fallback + + # worker exact-prefix backstop); it only lowers the warm-resume hit rate, + # silently, for such conversations. + positions = [i for i, m in enumerate(messages) if m.role == "assistant"] + splice: dict[int, dict] = {} # message index -> {"ids", "preamble"} + diverged_at = None + for k, pos in enumerate(positions): + if k >= len(stored): + break + m = messages[pos] + if self._assistant_fingerprint(m.content, m.tool_calls) != stored[k]["fp"]: + diverged_at = k # this stored turn and every later one are stale + break + if stored[k]["ids"] is not None: + splice[pos] = { + "ids": stored[k]["ids"], + "preamble": stored[k].get("preamble", ""), + } + if diverged_at is not None: + # Drop the stale tail from the first mismatch so an edited/branched + # earlier turn can't keep shadowing future requests; the matched + # prefix [:diverged_at] is untouched and still splices. We have no + # exact ids for the edited turn itself (the client authored it, we + # didn't generate it), so warm resume for that turn and the ones after + # it stays text until the session is reset/closed. Safe regardless: + # stale ids are never spliced and the worker's exact-token prefix + # check backstops correctness. + del stored[diverged_at:] + if not splice: + return PromptInput(text=rendered_prompt) + token = uuid.uuid4().hex + sentinel_at = {pos: f"<>" for j, pos in enumerate(splice)} + sub = {sentinel_at[pos]: splice[pos] for pos in splice} + # A sentinel must not already occur in the rendered output. + if any(s in rendered_prompt for s in sub): + return PromptInput(text=rendered_prompt) + modified = [ + ( + ChatMessage(role="assistant", content=sentinel_at[i]) + if i in sentinel_at + else m + ) + for i, m in enumerate(messages) + ] + rendered = self._template.render( + modified, tools=tools, template_kwargs=template_kwargs + ) + # Each sentinel must survive templating exactly once, else fall back. + if any(rendered.count(s) != 1 for s in sub): + return PromptInput(text=rendered_prompt) + # Splice ids and normalize each turn's scaffold; None => ambiguous region. + segments = self._split_on_sentinels(rendered, sub) + if segments is None: + return PromptInput(text=rendered_prompt) + return PromptInput(segments=segments) + + def record_assistant_turn( + self, + *, + session_id: Optional[str], + content, + tool_calls, + generated_token_ids: list, + prior_turns: int, + preamble: str = "", + ) -> None: + """Record this turn's {fingerprint, exact generated ids, generation + preamble} at position `prior_turns` -- the count of assistant turns in the + request this response answers. Stored records at/after that index are + dropped first, so a regenerated or branched turn under the same session_id + replaces stale records instead of leaving them to shadow future + warm-resume hits with a stale fingerprint. ids is None when the worker + omitted them (stop-trimmed turn) -- recorded as non-resumable but kept for + positional alignment. `preamble` is the generation scaffold resident ahead + of these ids (mode-specific, e.g. Qwen3 `` block), reproduced ahead + of the spliced ids on the next request so the prefix stays exact.""" + if not session_id: + return + turns = self._turns.setdefault(session_id, []) + del turns[prior_turns:] + turns.append( + { + "fp": self._assistant_fingerprint(content, tool_calls), + "ids": list(generated_token_ids) if generated_token_ids else None, + "preamble": preamble, + } + ) + + def reset(self, session_id: str) -> None: + self._turns.pop(session_id, None) + + def close(self, session_id: str) -> None: + self._turns.pop(session_id, None) diff --git a/extension/llm/server/python/serving_chat.py b/extension/llm/server/python/serving_chat.py index bd6101366f4..53d32978251 100644 --- a/extension/llm/server/python/serving_chat.py +++ b/extension/llm/server/python/serving_chat.py @@ -22,6 +22,7 @@ InvalidSessionId, SessionCapacity, ) +from .openai_transcript import OpenAITranscriptState from .protocol import ( _new_id, ChatCompletionChunk, @@ -70,16 +71,21 @@ def __init__( # Detector CLASS; a fresh instance is created per request so streaming # state is never shared across concurrent requests. self._tool_detector_cls = tool_detector_cls - # Two distinct sets (see chat_template): - # * _stops: NARROW turn terminators (e.g. <|im_end|>) used as generation - # stops AND for pre-parse truncation (_options/_collect_until_stop/ - # _truncate_raw/_clean). Excludes structural/tool delimiters so a - # is never halted or cut before _extract_tools sees it. - # * _content_specials: BROAD all-special-tokens set, used ONLY by - # _strip_specials for final cleanup of the already-parsed visible - # content, so a stray special token can't leak to the user. + # Two distinct sets (see chat_template); create() combines them per path: + # * _stops: NARROW turn terminators (e.g. <|im_end|>). The ONLY stop set + # for tool turns -- excludes structural/tool delimiters so a + # is never halted or cut before _extract_tools sees it. Also used by + # _truncate_raw (pre-parse truncation on the tool path). + # * _content_specials: BROAD all-special-tokens set. For PLAIN chat it is + # added to the worker/clean stop set (create() -> gen_stops) so a leaked + # special halts the worker and never reaches the client, AND it backs + # _strip_specials for final cleanup of already-parsed visible content. self._stops = template.turn_stop_sequences() self._content_specials = template.special_tokens() + # OpenAI/chat-template token-ID warm-resume state (V2b.1.5). Adapter-side, + # not runtime; kept in lockstep with the worker's session state by + # clearing both on reset/close. + self._transcript = OpenAITranscriptState(template) @staticmethod def _tool_schemas(req: ChatCompletionRequest) -> dict[str, dict]: @@ -194,17 +200,20 @@ async def _clean( if buf: yield buf - def _options(self, req: ChatCompletionRequest) -> GenerationOptions: + def _options( + self, req: ChatCompletionRequest, stops: list[str] + ) -> GenerationOptions: return GenerationOptions( max_new_tokens=req.resolved_max_tokens(), temperature=req.temperature if req.temperature is not None else 0.0, - # Let the worker terminate at the same boundary the control plane - # would cut: the model's special tokens (e.g. <|im_end|>) AND request - # stop sequences. This stops generation at end-of-turn even when the - # worker's EOS-by-token-id check misses it, instead of running to - # max_new (or erroring) past the turn. The server's - # _clean/_collect_until_stop still re-apply these as a backstop. - stop=self._stops + self._request_stops(req), + # Worker stop set, decided per path in create(): narrow turn + # terminators (+ request stops) for tool turns so a structural/tool + # delimiter is never cut before the parser sees it; plus the broad + # content specials for plain chat so a leaked special halts the worker + # -- which then marks the turn dirty and omits its ids -- instead of + # streaming a token the client should not see. The server re-applies + # the same set in _clean/_collect_until_stop as a backstop. + stop=stops, ) @staticmethod @@ -226,18 +235,24 @@ async def _preflight_session(self, session_id: str) -> None: raise GenerationError(str(e)) async def close_session(self, session_id: str) -> None: + # Lockstep: do the fallible worker op FIRST, then clear the (best-effort, + # can't-fail) transcript. If the worker op fails both retain old state, + # so they never drift. self._validate_session_id(session_id) try: await self._runtime.close(session_id) except WorkerError as e: raise GenerationError(str(e)) + self._transcript.close(session_id) async def reset_session(self, session_id: str) -> None: + # Lockstep: worker op first (fallible), then clear the transcript. self._validate_session_id(session_id) try: await self._runtime.reset(session_id) except WorkerError as e: raise GenerationError(str(e)) + self._transcript.reset(session_id) def _finish_reason( self, @@ -330,6 +345,24 @@ def _reject_unsupported_params(req: ChatCompletionRequest) -> None: "unsupported_parameter", ) + def _count_prompt_tokens(self, prompt: PromptInput) -> Optional[int]: + """Token count of what the worker will actually assemble: the rendered + text, or for token-ID segments sum(len(ids)) for {ids} runs + the + tokenized length of {text} chunks. None when no tokenizer is available to + count text (the worker still enforces the real context limit).""" + if prompt.text is not None: + return self._template.count_tokens(prompt.text) + total = 0 + for seg in prompt.segments: + if "ids" in seg: + total += len(seg["ids"]) + else: + c = self._template.count_tokens(seg["text"]) + if c is None: + return None + total += c + return total + async def create(self, req: ChatCompletionRequest): self._reject_invalid_values(req) self._reject_unsupported_params(req) @@ -342,10 +375,24 @@ async def create(self, req: ChatCompletionRequest): prompt = self._template.render( req.messages, tools=template_tools, template_kwargs=req.chat_template_kwargs ) - # Pre-flight context check: reject cleanly instead of failing mid-generation - # (only possible when a tokenizer is available to count, e.g. --hf-tokenizer). + # Build the prompt input first: token-ID segments (V2b.1.5) splice this + # session's prior assistant turns' exact ids so warm resume stays exact + # across the chat template's lossy re-render of tool-call turns; plain + # rendered text when there's nothing to splice / on any ambiguity (the + # worker verifies the exact-token prefix regardless). + prompt_input = self._transcript.build_prompt_input( + session_id=req.session_id, + messages=req.messages, + rendered_prompt=prompt, + tools=template_tools, + template_kwargs=req.chat_template_kwargs, + ) + # Pre-flight context check against the tokens the worker will actually + # assemble: for segments that is sum(len(ids)) + tokenized text, not the + # rendered string, so a near-limit prompt agrees with the worker rather + # than false-400ing or failing mid-decode. Only when a tokenizer exists. if self._max_context: - count = self._template.count_tokens(prompt) + count = self._count_prompt_tokens(prompt_input) if count is not None: if count >= self._max_context: raise ContextLengthExceeded(count, self._max_context) @@ -355,35 +402,74 @@ async def create(self, req: ChatCompletionRequest): requested = req.resolved_max_tokens() if requested > 0 and count + requested > self._max_context: raise ContextLengthExceeded(count, self._max_context, requested) - options = self._options(req) - prompt_input = PromptInput(text=prompt) + # Stop-set split by path. Tool turns use only the narrow turn terminators + # (+ request stops) so a structural/tool delimiter is never halted before + # the parser sees it. Plain chat adds the broad content specials so a + # leaked special (one non-streaming would strip) halts the worker -- which + # marks the turn dirty and omits its ids -- instead of reaching the client + # or being recorded as resumable ids for text never shown. Both paths + # reuse this exact set in the control-plane cut (_clean/_collect_until_stop). + if self._tools_active(req): + gen_stops = self._stops + self._request_stops(req) + else: + gen_stops = self._stops + self._content_specials + self._request_stops(req) + options = self._options(req, gen_stops) + # The generation scaffold the worker will prefill ahead of this turn's + # tokens (e.g. Qwen3 block), resolved with the same per-request + # mode AND tools as the render; recorded per turn so warm-resume splicing + # reproduces the exact resident scaffold even if the mode changes between + # requests. + preamble = self._template.generation_preamble( + req.chat_template_kwargs, tools=template_tools + ) # Admit the session up front (before the stream's first chunk) so a # capacity refusal is an HTTP status, not a mid-stream error event. if req.session_id is not None: await self._preflight_session(req.session_id) if req.stream: - return self._stream(req, prompt_input, options) - return await self._complete(req, prompt_input, options) + return self._stream(req, prompt_input, options, preamble, gen_stops) + return await self._complete(req, prompt_input, options, preamble, gen_stops) async def _complete( self, req: ChatCompletionRequest, prompt: PromptInput, options: GenerationOptions, + preamble: str = "", + gen_stops: Optional[list[str]] = None, ) -> ChatCompletionResponse: + # Same stop set the worker was given (per-path: narrow for tools, broad + # content specials added for plain chat); falls back to narrow if a caller + # didn't supply it. + stops = ( + gen_stops + if gen_stops is not None + else self._stops + self._request_stops(req) + ) stats = GenStats() try: # Collect raw text (markers intact for tool parsing), halting early # at a stop boundary (special token or request stop). text, stopped = await self._collect_until_stop( self._runtime.generate_stream(req.session_id, prompt, options, stats), - self._stops + self._request_stops(req), + stops, ) except Exception as e: # noqa: BLE001 - surface as a structured API error raise GenerationError(str(e)) # Bound the raw output at the first stop/special token BEFORE tool # parsing, so a call after the stop boundary is not parsed/emitted. tool_calls, content = self._extract_tools(req, self._truncate_raw(text, req)) + # Record after the response is finalized: the fingerprint is of exactly + # what we return (content + tool_calls), so the next turn can confirm the + # client echoed this turn before splicing its ids. + self._transcript.record_assistant_turn( + session_id=req.session_id, + content=content, + tool_calls=tool_calls, + generated_token_ids=stats.generated_token_ids, + prior_turns=sum(1 for m in req.messages if m.role == "assistant"), + preamble=preamble, + ) finish = self._finish_reason( req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason ) @@ -407,6 +493,8 @@ async def _stream( req: ChatCompletionRequest, prompt: PromptInput, options: GenerationOptions, + preamble: str = "", + gen_stops: Optional[list[str]] = None, ) -> AsyncIterator[str]: cid = _new_id("chatcmpl") @@ -426,7 +514,14 @@ def chunk(delta: DeltaMessage, finish=None) -> str: stats = GenStats() stop_hit = [False] # set when a stop boundary is reached (forces finish="stop") - stops = self._stops + self._request_stops(req) + # Per-path stop set from create(): for plain chat this includes the broad + # content specials, so _clean cuts a leaked special out of the stream (and + # the worker, given the same set, halts + omits ids -> non-resumable turn). + stops = ( + gen_stops + if gen_stops is not None + else self._stops + self._request_stops(req) + ) try: if use_tools: # v1: buffer the (usually short) tool response, parse once. @@ -448,6 +543,7 @@ def on_stop(): stop_hit[0] = True self._runtime.stop() + streamed: list[str] = [] async for token in self._clean( self._runtime.generate_stream( req.session_id, prompt, options, stats @@ -455,7 +551,9 @@ def on_stop(): stops, on_stop=on_stop, ): + streamed.append(token) yield chunk(DeltaMessage(content=token)) + content = "".join(streamed) # for the session fingerprint except ( Exception ) as e: # noqa: BLE001 - emit a structured error event, never drop the socket @@ -470,6 +568,14 @@ def on_stop(): yield f"data: {json.dumps({'error': err})}\n\n" yield "data: [DONE]\n\n" return + self._transcript.record_assistant_turn( + session_id=req.session_id, + content=content, + tool_calls=tool_calls, + generated_token_ids=stats.generated_token_ids, + prior_turns=sum(1 for m in req.messages if m.role == "assistant"), + preamble=preamble, + ) if use_tools: if content: diff --git a/extension/llm/server/python/tests/test_sessions.py b/extension/llm/server/python/tests/test_sessions.py index a237f47e865..38bf62630fe 100644 --- a/extension/llm/server/python/tests/test_sessions.py +++ b/extension/llm/server/python/tests/test_sessions.py @@ -11,6 +11,15 @@ These assert the HTTP/wire contract only. """ +import asyncio + +import pytest + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.errors import GenerationError +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.worker_client import WorkerError + def _chat(client, *, session_id=None, headers=None): body = {"model": "test-model", "messages": [{"role": "user", "content": "hi"}]} @@ -130,6 +139,96 @@ def test_session_header_precedence(make_client): assert fake.opened_log == ["xet"] +def _chat_msgs(client, messages, session_id): + return client.post( + "/v1/chat/completions", + json={"model": "test-model", "session_id": session_id, "messages": messages}, + ) + + +# The fake worker streams tokens ("Hello", ", ", "world"), so the assistant +# content we return (and the client must echo back to match the fingerprint) is: +_FAKE_REPLY = "Hello, world" + + +def test_token_id_segments_splice_prior_assistant_turn(make_client): + # V2b.1.5: the server stores turn-1's generated ids and, on turn 2, sends + # prompt_segments that splice them back as an exact {ids} run (not text) -- + # but only because the client echoes back the assistant turn we generated. + client, fake = make_client(max_named_sessions=2, gen_ids=[7, 8, 9]) + assert ( + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s").status_code == 200 + ) + # First turn has no prior assistant turn -> plain text prompt. + assert fake.captured_config.prompt_segments is None + + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": _FAKE_REPLY}, # matches what we returned + {"role": "user", "content": "more"}, + ] + assert _chat_msgs(client, turn2, "s").status_code == 200 + segs = fake.captured_config.prompt_segments + assert segs is not None, "expected token-ID segments on the second turn" + # The stored generated ids are spliced in as an exact id run... + assert any(s.get("ids") == [7, 8, 9] for s in segs) + # ...bracketed by text segments (template glue + the new user turn). + assert any("text" in s for s in segs) + + +def test_edited_assistant_turn_not_spliced(make_client): + # P1 guard: if the client edits a prior assistant turn (or reuses the session + # for a different conversation), the stale ids must NOT be spliced -- the + # fingerprint mismatches and we fall back to text. + client, fake = make_client(max_named_sessions=2, gen_ids=[7, 8, 9]) + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s") + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "EDITED - not what the model generated"}, + {"role": "user", "content": "more"}, + ] + assert _chat_msgs(client, turn2, "s").status_code == 200 + assert fake.captured_config.prompt_segments is None + + +def test_stop_trimmed_turn_not_spliced(make_client): + # P1/P2 guard: a stop-trimmed turn (worker omits generated_token_ids -> + # recorded ids=None) is never spliced, even when the turn fingerprint matches, + # so unseen post-stop tokens can't be injected into a later prompt. + client, fake = make_client(max_named_sessions=2, gen_ids=[]) # [] => ids None + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s") + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": _FAKE_REPLY}, # fingerprint matches + {"role": "user", "content": "more"}, + ] + assert _chat_msgs(client, turn2, "s").status_code == 200 + assert fake.captured_config.prompt_segments is None + + +def test_no_segments_for_anonymous_requests(make_client): + client, fake = make_client(max_named_sessions=2, gen_ids=[1, 2]) + client.post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert fake.captured_config.prompt_segments is None + + +def test_reset_clears_stored_ids(make_client): + # After reset, the next turn has no stored ids to splice -> plain text again. + client, fake = make_client(max_named_sessions=2, gen_ids=[5, 6]) + _chat_msgs(client, [{"role": "user", "content": "hi"}], "s") + client.post("/v1/sessions/s/reset") + turn2 = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "reply"}, + {"role": "user", "content": "more"}, + ] + _chat_msgs(client, turn2, "s") + assert fake.captured_config.prompt_segments is None + + def test_reset_endpoint_clears_context_but_keeps_slot(make_client): # max_named=1: open "a", reset it, then a *different* id must still 429 — # proving reset cleared context without freeing the slot (unlike DELETE). @@ -147,3 +246,182 @@ def test_reset_invalid_session_id_rejected(make_client): r = client.post("/v1/sessions/has%20space/reset") assert r.status_code == 400 assert r.json()["error"]["code"] == "invalid_session_id" + + +class _RaisingRuntime: + """Runtime whose worker ops fail, to exercise the lockstep invariant.""" + + async def open(self, sid): + pass + + async def reset(self, sid): + raise WorkerError("worker down") + + async def close(self, sid): + raise WorkerError("worker down") + + +@pytest.mark.parametrize("op", ["reset_session", "close_session"]) +def test_worker_op_failure_keeps_transcript(op): + # Lockstep invariant: if the worker reset/close fails, the adapter transcript + # must NOT be cleared -- both retain old state so they never drift. + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + serving = ServingChat(_RaisingRuntime(), template, "test-model") + serving._transcript.record_assistant_turn( + session_id="s", + content="hi", + tool_calls=None, + generated_token_ids=[1, 2], + prior_turns=0, + ) + + async def go(): + with pytest.raises(GenerationError): + await getattr(serving, op)("s") + + asyncio.run(go()) + assert serving._transcript._turns.get( + "s" + ), "transcript cleared despite worker failure" + + +def test_record_assistant_turn_replaces_stale_at_position(): + # A regenerated/branched turn under the same session_id must REPLACE the + # record at its position (prior_turns), not append, so a later turn can still + # splice the regenerated ids instead of breaking on a stale fingerprint. + from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, + ) + + t = OpenAITranscriptState(ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)) + t.record_assistant_turn( + session_id="s", + content="a0", + tool_calls=None, + generated_token_ids=[1], + prior_turns=0, + ) + t.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[2], + prior_turns=1, + ) + assert [r["ids"] for r in t._turns["s"]] == [[1], [2]] + # regenerate turn 2 (same prior_turns) -> replaces stale [2], no stale tail + t.record_assistant_turn( + session_id="s", + content="a1b", + tool_calls=None, + generated_token_ids=[3], + prior_turns=1, + ) + assert [r["ids"] for r in t._turns["s"]] == [[1], [3]] + + +def test_divergence_truncates_stale_tail(): + # Editing an EARLIER assistant turn (divergence at k) prunes the stale tail + # from k so it can't keep shadowing future requests; nothing is spliced and + # the matched prefix is kept. (Restoring hits for the edited turn isn't + # possible -- we never generated its ids -- but staleness is bounded.) + from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, + ) + from executorch.extension.llm.server.python.protocol import ChatMessage + + t = OpenAITranscriptState(ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)) + t.record_assistant_turn( + session_id="s", + content="a0", + tool_calls=None, + generated_token_ids=[1], + prior_turns=0, + ) + t.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[2], + prior_turns=1, + ) + msgs = [ + ChatMessage(role="user", content="u0"), + ChatMessage(role="assistant", content="a0-EDITED"), + ChatMessage(role="user", content="u1"), + ] + out = t.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt="X", + tools=None, + template_kwargs=None, + ) + assert out.text == "X" # diverged -> plain text fallback + assert t._turns["s"] == [] # stale tail pruned from the first mismatch + + +class _HFToolSpecials: + # Tokenizer that marks a turn terminator AND tool/structural delimiters special. + all_special_tokens = ["<|im_end|>", "", "", "<|box_start|>"] + eos_token = "<|im_end|>" + + +def test_stop_set_narrow_but_strip_set_broad(): + # Two-set split (work item 1): the generation/pre-parse-truncation set is + # NARROW (turn terminators only) so a is never halted or cut + # before the parser sees it; the final content-strip set stays BROAD so stray + # specials can't leak into visible content. + from types import SimpleNamespace + + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + template._hf = _HFToolSpecials() + serving = ServingChat(_RaisingRuntime(), template, "test-model") + + assert "<|im_end|>" in serving._stops # terminator kept + assert "" not in serving._stops # delimiter excluded + assert "" not in serving._stops + assert "<|box_start|>" not in serving._stops + + assert "" in serving._content_specials # broad strip keeps it + assert "<|box_start|>" in serving._content_specials + + # The insidious site: _truncate_raw must NOT cut at (it uses the + # narrow set), so the full tool-call markup survives to the parser. + raw = ( + "sure\n\n\n1\n" + "\n\n" + ) + assert serving._truncate_raw(raw, SimpleNamespace(stop=None)) == raw + + +def test_injected_assistant_exemplar_falls_back_to_text(): + # 5d: a client-injected assistant turn we never generated (few-shot exemplar / + # pre-seeded turn) shifts the ordinal alignment -> fingerprint mismatch -> + # safe text fallback (no stale ids spliced). + from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, + ) + from executorch.extension.llm.server.python.protocol import ChatMessage + + t = OpenAITranscriptState(ChatTemplate(hf_tokenizer_path=None, allow_fallback=True)) + t.record_assistant_turn( + session_id="s", + content="a0", + tool_calls=None, + generated_token_ids=[1, 2], + prior_turns=0, + ) + msgs = [ + ChatMessage(role="user", content="u0"), + ChatMessage(role="assistant", content="INJECTED EXEMPLAR"), # not ours + ChatMessage(role="user", content="u1"), + ] + out = t.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt="X", + tools=None, + template_kwargs=None, + ) + assert out.text == "X" and out.segments is None # safe text fallback diff --git a/extension/llm/server/python/tests/test_streaming_stops.py b/extension/llm/server/python/tests/test_streaming_stops.py new file mode 100644 index 00000000000..f295422c7db --- /dev/null +++ b/extension/llm/server/python/tests/test_streaming_stops.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Plain-chat streaming special-token cleanup (WI1). + +Non-streaming scrubs broad content specials (the full all_special_tokens set) from +visible content via _strip_specials. Plain-chat streaming must be consistent: a +broad special that is NOT a turn terminator (e.g. <|fim_pad|>) must not reach the +client, and -- since trimming it makes the turn's visible text != generated text +-- the worker halts (omitting ids), so the turn is recorded non-resumable. Tool +turns keep the narrow terminator set so a is never cut before parsing. +""" + +import json + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.server import build_app +from executorch.extension.llm.server.python.serving_chat import ServingChat +from executorch.extension.llm.server.python.session_runtime import SessionRuntime +from executorch.extension.llm.server.python.tool_parsers import HermesDetector +from executorch.extension.llm.server.python.worker_client import WorkerError +from fastapi.testclient import TestClient + +FIM = "<|fim_pad|>" # a broad content special that is NOT a turn terminator +WEATHER_TOOLS = [ + {"type": "function", "function": {"name": "get_weather", "parameters": {}}} +] + + +class _SpecialTok: + """Fake HF tokenizer whose special set is broader than the turn terminators: + eos=<|im_end|> (a terminator) plus <|fim_pad|> (broad-only).""" + + eos_token = "<|im_end|>" + all_special_tokens = ["<|im_end|>", FIM] + + def encode(self, text, add_special_tokens=False): + return [0] * 5 + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kw + ): + return "PROMPT" + + +class _Runner: + """Fake worker. With honor_stops it models the real worker's stop-trim: a + stop string halts generation, the stop and everything after is dropped, and + the turn is non-resumable (generated_token_ids omitted).""" + + def __init__(self, tokens, gen_ids=None, honor_stops=False, max_named=4): + self._tokens = list(tokens) + self._gen_ids = list(gen_ids or []) + self._honor = honor_stops + self.max_named_sessions = max_named + self.open_named = set() + self.captured_config = None + + def reset(self): + pass + + def stop(self): + pass + + def open_session(self, sid): + if sid in self.open_named: + return + if self.max_named_sessions == 0: + raise WorkerError("no named sessions", code="unsupported_session") + if len(self.open_named) >= self.max_named_sessions: + raise WorkerError("capacity", code="capacity_exhausted") + self.open_named.add(sid) + + def close_session(self, sid): + self.open_named.discard(sid) + + def reset_session(self, sid): + pass + + def generate(self, prompt, config, token_callback=None, stats_callback=None): + self.captured_config = config + stops = list(getattr(config, "stop", []) or []) if self._honor else [] + emitted, trimmed = 0, False + for tok in self._tokens: + if any(s and s in tok for s in stops): + trimmed = True + break + if token_callback: + token_callback(tok) + emitted += 1 + if stats_callback: + stats = type("S", (), {})() + stats.num_prompt_tokens = 5 + stats.num_generated_tokens = emitted + stats.finish_reason = "stop" if trimmed else None + stats.generated_token_ids = [] if trimmed else list(self._gen_ids) + stats_callback(stats) + + +def _serving(tokens, honor_stops=False, gen_ids=None): + runner = _Runner(tokens, gen_ids=gen_ids, honor_stops=honor_stops) + template = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + template._hf = _SpecialTok() + serving = ServingChat( + SessionRuntime(runner), template, "test-model", tool_detector_cls=HermesDetector + ) + return serving, runner + + +def _client(serving): + return TestClient(build_app(serving, "test-model")) + + +def _sse_content(text): + content, finish = "", None + for line in text.splitlines(): + if line.startswith("data:") and "[DONE]" not in line: + d = json.loads(line[5:]) + for ch in d.get("choices", []): + content += (ch.get("delta", {}) or {}).get("content") or "" + if ch.get("finish_reason"): + finish = ch["finish_reason"] + return content, finish + + +def test_plain_chat_worker_stops_include_broad_specials(): + serving, runner = _serving(["hi"]) + _client(serving).post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert FIM in (runner.captured_config.stop or []) + + +def test_tool_path_worker_stops_exclude_broad_specials(): + serving, runner = _serving(["hi"]) + _client(serving).post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "tools": WEATHER_TOOLS, + }, + ) + # Tool turns keep the narrow set so a structural/tool delimiter isn't cut. + assert FIM not in (runner.captured_config.stop or []) + + +def test_plain_chat_streaming_does_not_leak_broad_special(): + serving, _ = _serving(["Hi", FIM, "leak"]) + r = _client(serving).post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + content, finish = _sse_content(r.text) + assert content == "Hi" # the special and everything after is cut + assert FIM not in content + assert finish == "stop" + + +def test_plain_chat_nonstreaming_matches_streaming_visible(): + serving, _ = _serving(["Hi", FIM, "leak"]) + r = _client(serving).post( + "/v1/chat/completions", + json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]}, + ) + choice = r.json()["choices"][0] + assert choice["message"]["content"] == "Hi" + assert choice["finish_reason"] == "stop" + + +def test_tool_streaming_not_broken_by_broad_special(): + tc = '\n{"name": "get_weather", "arguments": {}}\n' + serving, _ = _serving([tc]) + r = _client(serving).post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "weather?"}], + "tools": WEATHER_TOOLS, + "stream": True, + }, + ) + chunks = [ + json.loads(line[5:]) + for line in r.text.splitlines() + if line.startswith("data:") and "[DONE]" not in line + ] + finishes = [ + c["choices"][0]["finish_reason"] + for c in chunks + if c["choices"][0].get("finish_reason") + ] + has_tool = any( + (c["choices"][0].get("delta") or {}).get("tool_calls") for c in chunks + ) + assert "tool_calls" in finishes and has_tool + + +def test_plain_chat_broad_stop_marks_turn_nonresumable(): + # honor_stops: the worker trims at the broad special and omits ids; the + # transcript must record the turn as non-resumable (ids=None), not splice ids + # for text the client never saw. + serving, _ = _serving(["Hi", FIM, "leak"], honor_stops=True, gen_ids=[1, 2, 3]) + r = _client(serving).post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "session_id": "s", + }, + ) + assert r.json()["choices"][0]["message"]["content"] == "Hi" + assert serving._transcript._turns["s"][0]["ids"] is None + + +def test_user_request_stop_trims_and_nonresumable(): + serving, _ = _serving(["keep", "STOPHERE", "drop"], honor_stops=True, gen_ids=[9]) + r = _client(serving).post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "session_id": "s", + "stop": "STOPHERE", + }, + ) + assert r.json()["choices"][0]["message"]["content"] == "keep" + assert serving._transcript._turns["s"][0]["ids"] is None diff --git a/extension/llm/server/python/tests/test_warm_resume_scaffold.py b/extension/llm/server/python/tests/test_warm_resume_scaffold.py new file mode 100644 index 00000000000..fff89db5c94 --- /dev/null +++ b/extension/llm/server/python/tests/test_warm_resume_scaffold.py @@ -0,0 +1,541 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Warm-resume generation-scaffold reproduction (V2b.1.5). + +Qwen3's template prefills a deterministic ```` scaffold into the +generation prompt (so it lands in resident KV) but strips it when re-rendering a +turn as history *before* the last user message, while *preserving* it (as the +empty block) for turns after. The token-ID splice must reproduce each turn's +exact resident scaffold ahead of its generated ids, normalizing whatever the +history render put there -- inserting when stripped, replacing when a different +form was preserved -- so the worker's exact-token prefix check lands. +""" + +import os + +import pytest + +from executorch.extension.llm.server.python.chat_template import ChatTemplate +from executorch.extension.llm.server.python.openai_transcript import ( + OpenAITranscriptState, +) +from executorch.extension.llm.server.python.protocol import ( + ChatMessage, + FunctionCall, + ToolCall, +) + +HDR = "<|im_start|>assistant\n" +NOTHINK = "\n\n\n\n" # no-think generation preamble / preserved block +THINK = "\n" # think-mode generation preamble + + +def _msgs(*pairs): + return [ChatMessage(role=r, content=c) for r, c in pairs] + + +class _FakeQwen: + """Mimics Qwen3 scaffold behavior in render(): the generation prompt appends + the mode scaffold after the assistant header; history strips the scaffold for + assistant turns before the last user message and preserves the empty block + for turns after it (true in both modes -- the case that needs normalize).""" + + def __init__(self, default_thinking=False): + self._default_thinking = default_thinking + + def _gen(self, kw): + thinking = (kw or {}).get("enable_thinking", self._default_thinking) + return THINK if thinking else NOTHINK + + def render(self, messages, tools=None, template_kwargs=None): + last_user = max( + (i for i, m in enumerate(messages) if m.role == "user"), default=-1 + ) + out = [] + for i, m in enumerate(messages): + c = m.content if isinstance(m.content, str) else "" + if m.role == "assistant" and i > last_user: + out.append(f"{HDR}{NOTHINK}{c}<|im_end|>\n") # preserved empty block + else: + out.append(f"<|im_start|>{m.role}\n{c}<|im_end|>\n") + out.append(HDR + self._gen(template_kwargs)) + return "".join(out) + + +class _FakePlain: + """No-scaffold ChatML template (preamble '').""" + + def render(self, messages, tools=None, template_kwargs=None): + out = [ + f"<|im_start|>{m.role}\n" + f"{m.content if isinstance(m.content, str) else ''}<|im_end|>\n" + for m in messages + ] + out.append(HDR) + return "".join(out) + + +class _FakeOtherHeader: + """No-scaffold template whose assistant header is NOT the Qwen/ChatML one + (Llama-style), to prove token-id splicing isn't disabled for templates that + don't use ``<|im_start|>assistant\\n`` when the preamble is ''.""" + + OHDR = "<|start_header_id|>assistant<|end_header_id|>\n\n" + + def render(self, messages, tools=None, template_kwargs=None): + out = [] + for m in messages: + c = m.content if isinstance(m.content, str) else "" + if m.role == "assistant": + out.append(f"{self.OHDR}{c}<|eot_id|>") + else: + out.append( + f"<|start_header_id|>{m.role}<|end_header_id|>\n\n{c}<|eot_id|>" + ) + out.append(self.OHDR) + return "".join(out) + + +def _ids_index(segs, ids): + for i, s in enumerate(segs): + if s.get("ids") == ids: + return i + return -1 + + +def _text_before_ids(segs, ids): + i = _ids_index(segs, ids) + assert i > 0 and "text" in segs[i - 1], "expected a {text} segment before {ids}" + return segs[i - 1]["text"] + + +def _scaffold_before(segs, ids): + """The scaffold region: text after the last assistant header preceding ids.""" + return _text_before_ids(segs, ids).rsplit(HDR, 1)[-1] + + +# --- 5a. Hermetic unit tests (no model) ------------------------------------- + + +def test_nothink_ordinary_append_inserts_scaffold(): + st = OpenAITranscriptState(_FakeQwen(default_thinking=False)) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[10, 11, 12], + prior_turns=0, + preamble=NOTHINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + kw = {"enable_thinking": False} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + # History stripped the scaffold; the fix inserts exactly one copy. + assert _scaffold_before(pi.segments, [10, 11, 12]) == NOTHINK + assert _text_before_ids(pi.segments, [10, 11, 12]).count(NOTHINK) == 1 + + +def test_think_ordinary_append_inserts_open_scaffold(): + st = OpenAITranscriptState(_FakeQwen(default_thinking=True)) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[1, 2], + prior_turns=0, + preamble=THINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + kw = {"enable_thinking": True} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [1, 2]) == THINK + + +def test_think_toolloop_normalizes_preserved_scaffold(): + # Turn generated in THINK mode (preamble open-think) but rendered as a + # post-last-user turn, where history preserves the *empty* block. The fix + # must REPLACE that block with the stored open-think preamble -- not keep it + # (wrong scaffold) and not append a second one (double-insert). + st = OpenAITranscriptState(_FakeQwen(default_thinking=True)) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[7, 8, 9], + prior_turns=0, + preamble=THINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1")) # a1 AFTER last user + kw = {"enable_thinking": True} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [7, 8, 9]) == THINK + # the preserved empty block was replaced, not kept and not doubled + assert NOTHINK not in _text_before_ids(pi.segments, [7, 8, 9]) + + +def test_no_scaffold_template_is_unchanged(): + st = OpenAITranscriptState(_FakePlain()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[5], + prior_turns=0, + preamble="", + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs), + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [5]) == "" # nothing inserted, no regression + + +def test_non_qwen_header_no_scaffold_still_splices(): + # Regression: a no-scaffold template whose assistant header isn't the + # Qwen/ChatML one must still get token-id splicing (the normalization is a + # no-op when preamble == "", not a hard requirement for the Qwen header). + st = OpenAITranscriptState(_FakeOtherHeader()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[9, 9], + prior_turns=0, + preamble="", + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs), + tools=None, + template_kwargs=None, + ) + assert pi.segments is not None # splicing NOT disabled by the missing header + assert any(s.get("ids") == [9, 9] for s in pi.segments) # ids actually spliced + + +def test_stop_trimmed_turn_falls_back_to_text(): + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[], # stop-trimmed -> ids None -> not resumable + prior_turns=0, + preamble=NOTHINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "a1"), ("user", "u2")) + kw = {"enable_thinking": False} + rendered = st._template.render(msgs, template_kwargs=kw) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=rendered, + tools=None, + template_kwargs=kw, + ) + assert pi.segments is None and pi.text == rendered + + +def test_fingerprint_mismatch_falls_back_to_text(): + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[10], + prior_turns=0, + preamble=NOTHINK, + ) + msgs = _msgs(("user", "u1"), ("assistant", "EDITED"), ("user", "u2")) + kw = {"enable_thinking": False} + rendered = st._template.render(msgs, template_kwargs=kw) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=rendered, + tools=None, + template_kwargs=kw, + ) + assert pi.segments is None and pi.text == rendered + + +def test_mode_switch_uses_per_turn_scaffold(): + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content="a1", + tool_calls=None, + generated_token_ids=[1], + prior_turns=0, + preamble=NOTHINK, # turn 1 generated no-think + ) + st.record_assistant_turn( + session_id="s", + content="a2", + tool_calls=None, + generated_token_ids=[2], + prior_turns=1, + preamble=THINK, # turn 2 generated think + ) + msgs = _msgs( + ("user", "u1"), + ("assistant", "a1"), + ("user", "u2"), + ("assistant", "a2"), + ("user", "u3"), + ) + kw = {"enable_thinking": True} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assert _scaffold_before(pi.segments, [1]) == NOTHINK + assert _scaffold_before(pi.segments, [2]) == THINK + + +# --- Tool-call argument fingerprint canonicalization ------------------------ + + +def _fp(content, tool_calls): + return OpenAITranscriptState._assistant_fingerprint(content, tool_calls) + + +def _dtc(name, args): + return {"function": {"name": name, "arguments": args}} + + +def test_fingerprint_ignores_tool_arg_whitespace(): + assert _fp(None, [_dtc("bash", '{"command": "echo hi"}')]) == _fp( + None, [_dtc("bash", '{"command":"echo hi"}')] + ) + + +def test_fingerprint_ignores_tool_arg_key_order(): + assert _fp(None, [_dtc("f", '{"x": 1, "y": 2}')]) == _fp( + None, [_dtc("f", '{"y": 2, "x": 1}')] + ) + + +def test_fingerprint_invalid_json_args_stay_byte_sensitive(): + # Non-JSON arguments can't be canonicalized, so they stay literal: a + # genuinely different string remains a different turn. + assert _fp(None, [_dtc("f", "not json {")]) != _fp( + None, [_dtc("f", "not json { ")] + ) + + +def test_fingerprint_non_string_args_match_equivalent_json_string(): + # Already-structured args hash stably and match the equivalent JSON string. + assert _fp(None, [_dtc("f", {"x": 1})]) == _fp(None, [_dtc("f", '{"x": 1}')]) + + +def test_tool_turn_splices_despite_reserialized_args(): + # End-to-end: the server recorded a spaced arguments string; the client echoes + # the same call back compact (the real pi behavior). The turn must still + # fingerprint-match and splice -- not prune to a text fallback. + st = OpenAITranscriptState(_FakeQwen()) + st.record_assistant_turn( + session_id="s", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall(name="bash", arguments='{"command": "echo hi"}'), + ) + ], + generated_token_ids=[1, 2, 3], + prior_turns=0, + preamble=NOTHINK, + ) + echoed = ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall(name="bash", arguments='{"command":"echo hi"}'), + ) + ], + ) + msgs = [ + ChatMessage(role="user", content="u1"), + echoed, + ChatMessage(role="user", content="u2"), + ] + kw = {"enable_thinking": False} + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=st._template.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None # matched + spliced, not pruned to text + assert any(s.get("ids") == [1, 2, 3] for s in pi.segments) + + +# --- 5b. Token-level fidelity against the real tokenizer (gated/skipped) ----- + +_MODEL = os.environ.get( + "QWEN_HF_DIR", "/home/mnachin/local/scripts/models/Qwen3.5-35B-A3B-HQQ-INT4" +) +_HAVE_MODEL = os.path.isdir(_MODEL) +_skip = pytest.mark.skipif( + not _HAVE_MODEL, reason=f"real Qwen tokenizer dir not present: {_MODEL}" +) + + +def _real_template_and_enc(): + pytest.importorskip("transformers") + from executorch.extension.llm.server.python.chat_template import ChatTemplate + from transformers import AutoTokenizer + + tmpl = ChatTemplate(hf_tokenizer_path=_MODEL) + tok = AutoTokenizer.from_pretrained(_MODEL) + # Encode the way the worker does: no extra special tokens (the rendered text + # already contains the literal <|im_*|> / control strings). + return tmpl, (lambda s: tok.encode(s, add_special_tokens=False)) + + +def _assemble(segs, enc): + out = [] + for seg in segs: + out += seg["ids"] if "ids" in seg else enc(seg["text"]) + return out + + +@_skip +@pytest.mark.parametrize("thinking", [False, True]) +def test_token_level_exact_prefix_ordinary(thinking): + tmpl, enc = _real_template_and_enc() + kw = {"enable_thinking": thinking} + st = OpenAITranscriptState(tmpl) + content = "Mercury, Venus, Earth." + gen_ids = enc(content) # stand-in for the worker's generated_token_ids + gen_prompt1 = tmpl.render(_msgs(("user", "u1")), template_kwargs=kw) + resident = enc(gen_prompt1) + gen_ids + st.record_assistant_turn( + session_id="s", + content=content, + tool_calls=None, + generated_token_ids=gen_ids, + prior_turns=0, + preamble=tmpl.generation_preamble(kw), + ) + msgs = _msgs(("user", "u1"), ("assistant", content), ("user", "u2")) + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=tmpl.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assembled = _assemble(pi.segments, enc) + # resident is an exact token prefix => plan_prefill returns exact_prefix and + # reuses exactly len(resident) tokens. + assert assembled[: len(resident)] == resident + + +@_skip +def test_token_level_exact_prefix_toolloop_think(): + # Mandatory: post-last-user turn where the template preserves a think block + # before the sentinel; the fix must normalize it to the stored open-think + # preamble so the token prefix still lands. + tmpl, enc = _real_template_and_enc() + kw = {"enable_thinking": True} + st = OpenAITranscriptState(tmpl) + content = "result is 42" + gen_ids = enc(content) + gen_prompt1 = tmpl.render(_msgs(("user", "u1")), template_kwargs=kw) + resident = enc(gen_prompt1) + gen_ids + st.record_assistant_turn( + session_id="s", + content=content, + tool_calls=None, + generated_token_ids=gen_ids, + prior_turns=0, + preamble=tmpl.generation_preamble(kw), + ) + msgs = _msgs(("user", "u1"), ("assistant", content)) # a1 AFTER last user + pi = st.build_prompt_input( + session_id="s", + messages=msgs, + rendered_prompt=tmpl.render(msgs, template_kwargs=kw), + tools=None, + template_kwargs=kw, + ) + assert pi.segments is not None + assembled = _assemble(pi.segments, enc) + assert assembled[: len(resident)] == resident + + +# --- WI4a: generation_preamble threads tools -------------------------------- + + +def test_generation_preamble_threads_tools(): + # generation_preamble must pass `tools` to the render probe and key the cache + # on them, so a template whose post-header scaffold depends on tools gets the + # right preamble for each (and never serves a stale cached one). + class _ToolScaffoldTok: + eos_token = "<|im_end|>" + all_special_tokens = ["<|im_end|>"] + + def encode(self, text, add_special_tokens=False): + return [0] + + def apply_chat_template( + self, messages, tools, add_generation_prompt, tokenize, **kwargs + ): + scaffold = "" if tools else "" + return "<|im_start|>assistant\n" + scaffold + + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + t._hf = _ToolScaffoldTok() + assert t.generation_preamble(tools=None) == "" + assert ( + t.generation_preamble(tools=[{"type": "function", "function": {"name": "f"}}]) + == "" + ) + # cached separately -> the no-tool value is not shadowed by the tool one + assert t.generation_preamble(tools=None) == ""