diff --git a/extension/llm/server/__init__.py b/extension/llm/server/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/extension/llm/server/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/extension/llm/server/python/__init__.py b/extension/llm/server/python/__init__.py new file mode 100644 index 00000000000..00b6274c01f --- /dev/null +++ b/extension/llm/server/python/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible server for ExecuTorch LLMs (Python implementation).""" diff --git a/extension/llm/server/python/chat_template.py b/extension/llm/server/python/chat_template.py new file mode 100644 index 00000000000..cbb3eff80bf --- /dev/null +++ b/extension/llm/server/python/chat_template.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Render OpenAI chat messages into a single prompt string. + +The ExecuTorch runner tokenizes a plain prompt; chat formatting is the server's +job (control plane). We require the model's own Hugging Face ``chat_template`` +(via ``--hf-tokenizer``) for correct, tool-aware, reasoning-aware formatting. +The generic ChatML fallback is opt-in only (``allow_fallback``): it is +approximate and cannot reproduce model-specific controls (e.g. enable_thinking), +so it must be a deliberate choice rather than a silent default. +""" + +import json +import logging +from typing import Any, Optional + +from .protocol import ChatMessage + +logger = logging.getLogger(__name__) + + +_DEFAULT_SPECIAL_TOKENS = ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "<|end|>"] + +# Chat turn terminators eligible to be used as generation stop strings. This is a +# deliberate allowlist of end-of-turn / end-of-text tokens -- NOT the tokenizer's +# full special-token set. Structural/tool delimiters (e.g. ) must reach +# the tool parser, so they are intentionally excluded: using them as hard stops +# would truncate a tool call before it is ever parsed. +_TURN_TERMINATORS = ( + "<|im_end|>", + "<|endoftext|>", + "<|eot_id|>", + "<|end|>", + "<|end_of_text|>", + "", + "", +) + + +def _content_text(content) -> str: + """Best-effort text for the ChatML fallback: a str as-is, or the concatenated + text parts of an OpenAI list-content message (non-text parts dropped). Avoids + rendering a Python repr of structured content. None -> empty string.""" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + out.append(str(part.get("text", ""))) + elif isinstance(part, str): + out.append(part) + return "".join(out) + return str(content or "") + + +def _decode_tool_call_arguments(messages: list[dict[str, Any]]) -> None: + """In-place: parse each tool call's ``function.arguments`` from a JSON string + into an object. + + OpenAI sends assistant tool-call arguments as a JSON-encoded string, but HF + chat templates expect a mapping (e.g. Qwen renders ``arguments|items`` into + ```` tags). Without this, a multi-turn tool conversation makes + the template raise "Can only get item pairs from a mapping". Left as-is if + the value isn't valid JSON, so a template that wants the raw string still works. + """ + for m in messages: + for tc in m.get("tool_calls") or []: + fn = tc.get("function") + if not isinstance(fn, dict): + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + fn["arguments"] = json.loads(args) + except (ValueError, TypeError): + pass + + +class ChatTemplate: + def __init__( + self, + hf_tokenizer_path: Optional[str] = None, + default_template_kwargs: Optional[dict[str, Any]] = None, + allow_fallback: bool = False, + ): + # Server-level defaults (e.g. {"enable_thinking": False}); per-request + # chat_template_kwargs override these. + self._defaults = default_template_kwargs or {} + self._hf = None + if hf_tokenizer_path: + from transformers import AutoTokenizer + + self._hf = AutoTokenizer.from_pretrained(hf_tokenizer_path) + if self._hf.chat_template is None: + self._hf = None + if not allow_fallback: + raise ValueError( + f"HF tokenizer at {hf_tokenizer_path} has no chat_template; " + "pass an explicit fallback flag to use approximate ChatML." + ) + logger.warning( + "No chat_template at %s; using approximate ChatML.", + hf_tokenizer_path, + ) + elif not allow_fallback: + raise ValueError( + "A chat template is required: pass --hf-tokenizer for the model's own " + "template, or opt into approximate ChatML with --allow-chatml-fallback." + ) + else: + logger.warning( + "No --hf-tokenizer; using approximate ChatML (no thinking control)." + ) + + def render( + self, + messages: list[ChatMessage], + tools: Optional[list[dict[str, Any]]] = None, + template_kwargs: Optional[dict[str, Any]] = None, + ) -> str: + kwargs = {**self._defaults, **(template_kwargs or {})} + if self._hf is not None: + dumped = [m.model_dump(exclude_none=True) for m in messages] + _decode_tool_call_arguments(dumped) + return self._hf.apply_chat_template( + dumped, + tools=tools, + add_generation_prompt=True, + tokenize=False, + **kwargs, + ) + return self._fallback(messages) + + def chat_template_str(self) -> Optional[str]: + """Raw chat-template string (for tool-format auto-detection), if available.""" + return ( + getattr(self._hf, "chat_template", None) if self._hf is not None else None + ) + + def count_tokens(self, prompt: str) -> Optional[int]: + """Token count for the rendered prompt, or None if no tokenizer is available.""" + if self._hf is not None: + # The prompt is already rendered (apply_chat_template includes the + # control tokens), so encode without re-adding BOS/EOS — matching the + # session/prefix-cache paths, so the count isn't inflated and + # near-limit requests aren't falsely rejected under --max-context. + return len(self._hf.encode(prompt, add_special_tokens=False)) + return None + + def turn_stop_sequences(self) -> list[str]: + """Generation stop strings: model/template-specific *turn terminators* + only -- the tokenizer's EOS plus known chat turn-end tokens -- NOT the + full special-token set. + + Structural/tool delimiters (e.g. ) are deliberately excluded: + if a tokenizer registers them as special, using the whole special set as + hard stops would halt generation at the delimiter and truncate the tool + call before the parser ever sees it. Whitespace-only tokens are dropped. + User-supplied request `stop` strings are handled separately and are not + affected by this set. + + May return [] if the tokenizer has no eos_token and registers none of the + known terminators as special; in that case end-of-turn detection relies + entirely on the worker's EOS-by-token-id check (e.g. the Qwen engine adds + <|im_end|> to eos_ids), so the string set here is only a backstop. + """ + if self._hf is None: + return list(_DEFAULT_SPECIAL_TOKENS) + specials = { + t + for t in (getattr(self._hf, "all_special_tokens", []) or []) + if isinstance(t, str) and t.strip() + } + out: list[str] = [] + eos = getattr(self._hf, "eos_token", None) + if isinstance(eos, str) and eos.strip(): + out.append(eos) + for t in _TURN_TERMINATORS: + if t in specials and t not in out: + out.append(t) + return out + + def special_tokens(self) -> list[str]: + """ALL special-token strings, for final content cleanup -- stripping any + special token that leaked into visible output. Deliberately broad, and + distinct from turn_stop_sequences(): this set must NOT be used as + generation stops or pre-parse truncation (that would halt/cut a tool call + at a structural delimiter), only to scrub trailing specials from the + already-parsed visible content. Whitespace-only tokens are dropped so a + stray ' ' token can't truncate content at the first double space. + """ + if self._hf is not None: + toks = list(getattr(self._hf, "all_special_tokens", []) or []) + return [t for t in toks if isinstance(t, str) and t.strip()] + return list(_DEFAULT_SPECIAL_TOKENS) + + @staticmethod + def _fallback(messages: list[ChatMessage]) -> str: + # Approximate ChatML, TEXT-ONLY. Provide --hf-tokenizer for model-correct + # formatting (reasoning controls like enable_thinking, and structured + # tool/multimodal turns, which this fallback cannot reproduce). This path + # renders only text content: assistant `tool_calls` and a tool-role + # `tool_call_id` are dropped, so it is NOT a correctness path for tool or + # multimodal conversations -- use a real --hf-tokenizer for those. + parts = [] + for m in messages: + content = _content_text(m.content) + parts.append(f"<|im_start|>{m.role}\n{content}<|im_end|>") + parts.append("<|im_start|>assistant\n") + return "\n".join(parts) diff --git a/extension/llm/server/python/errors.py b/extension/llm/server/python/errors.py new file mode 100644 index 00000000000..f24df43f2e8 --- /dev/null +++ b/extension/llm/server/python/errors.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-shaped API errors. + +Raising these lets the server return a structured `{"error": {...}}` body with +the right HTTP status instead of dropping the connection. +""" + +from typing import Optional + + +class APIError(Exception): + def __init__( + self, status: int, message: str, err_type: str, code: Optional[str] = None + ): + super().__init__(message) + self.status = status + self.message = message + self.err_type = err_type + self.code = code + + def body(self) -> dict: + return { + "error": {"message": self.message, "type": self.err_type, "code": self.code} + } + + +class ContextLengthExceeded(APIError): + def __init__(self, num_tokens: int, max_context: int, completion_tokens: int = 0): + # completion_tokens > 0: the prompt fits but prompt + requested + # max_tokens would run past the window — reject up front rather than + # fail (or truncate) mid-generation. + if completion_tokens > 0: + message = ( + f"This model's maximum context length is {max_context} tokens. " + f"However, you requested {num_tokens + completion_tokens} tokens " + f"({num_tokens} in the messages, {completion_tokens} in the " + f"completion). Please reduce the length of the messages or " + f"completion." + ) + else: + message = ( + f"This model's maximum context length is {max_context} tokens, " + f"but the request has {num_tokens} prompt tokens." + ) + super().__init__( + status=400, + message=message, + err_type="invalid_request_error", + code="context_length_exceeded", + ) + + +class GenerationError(APIError): + def __init__(self, detail: str): + super().__init__( + status=500, message=f"Generation failed: {detail}", err_type="server_error" + ) diff --git a/extension/llm/server/python/protocol.py b/extension/llm/server/python/protocol.py new file mode 100644 index 00000000000..2d73d2d7f64 --- /dev/null +++ b/extension/llm/server/python/protocol.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenAI-compatible request/response schemas for the ExecuTorch LLM server. + +This is the Python view of the contract defined in ``extension/llm/server/spec``. +Any language server must serialize to the same shapes; the conformance suite in +``extension/llm/server/conformance`` validates them. +""" + +import time +import uuid +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, Field + + +def _new_id(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex}" + + +class FunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + index: Optional[int] = None + id: Optional[str] = None + type: Literal["function"] = "function" + function: FunctionCall + + +class ChatMessage(BaseModel): + role: str + content: Optional[Union[str, list[dict[str, Any]]]] = None + name: Optional[str] = None + tool_calls: Optional[list[ToolCall]] = None + tool_call_id: Optional[str] = None + + +class StreamOptions(BaseModel): + include_usage: bool = False + + +class ChatCompletionRequest(BaseModel): + model: Optional[str] = None + messages: list[ChatMessage] + stream: bool = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + max_tokens: Optional[int] = None + max_completion_tokens: Optional[int] = None + stop: Optional[Union[str, list[str]]] = None + n: int = 1 + seed: Optional[int] = None + # Sampling knobs that change generation output. We don't plumb these, so they + # are modeled (not dropped) in order to be rejected with a clear error rather + # than silently ignored — see serving_chat's unsupported-parameter check. + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + top_k: Optional[int] = None + logit_bias: Optional[dict[str, float]] = None + # Output-contract fields: modeled (not dropped) so we reject the ones we + # can't honor rather than returning an output that violates what was asked. + response_format: Optional[dict[str, Any]] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[int] = None + parallel_tool_calls: Optional[bool] = None + # Per-request chat-template controls, e.g. {"enable_thinking": false} for Qwen3. + chat_template_kwargs: Optional[dict[str, Any]] = None + # Accepted now so the contract is stable; parsing/enforcement land in M2/M5. + tools: Optional[list[dict[str, Any]]] = None + tool_choice: Optional[Union[str, dict[str, Any]]] = None + reasoning_effort: Optional[str] = None + + def resolved_max_tokens(self) -> int: + # `is not None` (not `or`): an explicit 0 must not be treated as unset. + # Callers validate positivity; -1 means "unset / auto". + if self.max_completion_tokens is not None: + return self.max_completion_tokens + if self.max_tokens is not None: + return self.max_tokens + return -1 + + +class Usage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class ResponseMessage(BaseModel): + role: str = "assistant" + content: Optional[str] = None + tool_calls: Optional[list[ToolCall]] = None + + +class Choice(BaseModel): + index: int = 0 + message: ResponseMessage + finish_reason: Optional[str] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: _new_id("chatcmpl")) + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[Choice] + usage: Usage = Field(default_factory=Usage) + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + tool_calls: Optional[list[ToolCall]] = None + + +class ChunkChoice(BaseModel): + index: int = 0 + delta: DeltaMessage + finish_reason: Optional[str] = None + + +class ChatCompletionChunk(BaseModel): + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChunkChoice] + usage: Optional[Usage] = None + + +class ModelCard(BaseModel): + id: str + object: Literal["model"] = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "executorch" + + +class ModelList(BaseModel): + object: Literal["list"] = "list" + data: list[ModelCard] diff --git a/extension/llm/server/python/requirements.txt b/extension/llm/server/python/requirements.txt new file mode 100644 index 00000000000..70ad7ccb4dd --- /dev/null +++ b/extension/llm/server/python/requirements.txt @@ -0,0 +1,5 @@ +fastapi>=0.110 +uvicorn[standard]>=0.27 +pydantic>=2.0 +# Optional but recommended for model-correct chat templating (--hf-tokenizer): +# transformers>=4.40 diff --git a/extension/llm/server/python/tests/test_hermes_tool_parser.py b/extension/llm/server/python/tests/test_hermes_tool_parser.py new file mode 100644 index 00000000000..3bfebabdb9f --- /dev/null +++ b/extension/llm/server/python/tests/test_hermes_tool_parser.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for HermesDetector (Hermes/Qwen JSON format). + +Covers the explicit all-or-nothing malformed-call policy and the no-markup-leak +guarantee: an undefined/malformed/truncated call degrades to the leading text +with the markup stripped, never surfaced to the client. +""" + +import json + +from executorch.extension.llm.server.python.tool_parsers import HermesDetector + +_TOOLS = { + "get_weather": {"type": "object", "properties": {"city": {"type": "string"}}}, + "echo": {"type": "object", "properties": {"text": {"type": "string"}}}, +} + + +def _parse(text, tools=_TOOLS): + return HermesDetector().detect_and_parse(text, tools) + + +def test_basic_call(): + text = ( + '{"name": "get_weather", "arguments": {"city": "Paris"}}' + ) + r = _parse(text) + assert len(r.calls) == 1 and r.calls[0].name == "get_weather" + assert json.loads(r.calls[0].arguments) == {"city": "Paris"} + + +def test_multiple_calls_still_parse(): + text = ( + '{"name": "echo", "arguments": {"text": "a"}}' + '{"name": "echo", "arguments": {"text": "b"}}' + ) + r = _parse(text) + assert [json.loads(c.arguments)["text"] for c in r.calls] == ["a", "b"] + + +def test_no_tool_call_is_passthrough(): + r = _parse("just some text") + assert not r.calls and r.normal_text == "just some text" + + +def test_malformed_block_with_valid_sibling_degrades_no_leak(): + # All-or-nothing: one malformed block degrades the WHOLE response (the valid + # sibling is NOT emitted in isolation), and no markup leaks. + text = ( + 'lead{"name": "echo", "arguments": {"text": "ok"}}' + "{bad json}" + ) + r = _parse(text) + assert not r.calls + assert "" not in r.normal_text + assert r.normal_text == "lead" + + +def test_unclosed_marker_degrades_no_leak(): + text = 'lead{"name": "echo", "arguments": {"text": "x"}}' + r = _parse(text) + assert not r.calls + assert "" not in r.normal_text + assert r.normal_text == "lead" + + +def test_string_value_containing_close_marker_not_truncated(): + # A JSON string value containing literal must not truncate the + # captured JSON (raw_decode parses the whole object regardless). + text = ( + '{"name": "echo", "arguments": ' + '{"text": "a b"}}' + ) + r = _parse(text) + assert len(r.calls) == 1 + assert json.loads(r.calls[0].arguments) == {"text": "a b"} + + +def test_arguments_null_falls_back_to_parameters(): + text = ( + '{"name": "echo", "arguments": null, ' + '"parameters": {"text": "p"}}' + ) + r = _parse(text) + assert json.loads(r.calls[0].arguments) == {"text": "p"} + + +def test_undefined_tool_degrades_to_full_text(): + # A WELL-FORMED call to an undefined tool degrades the whole response to + # visible text (unchanged policy: surface the model's intent, never a partial + # set). This differs from the malformed/truncated case, which strips markup. + text = 'hi{"name": "nope", "arguments": {}}' + r = _parse(text) + assert not r.calls + assert "" in r.normal_text # full text, markup visible diff --git a/extension/llm/server/python/tests/test_qwen_tool_parser.py b/extension/llm/server/python/tests/test_qwen_tool_parser.py new file mode 100644 index 00000000000..3c54539dfd3 --- /dev/null +++ b/extension/llm/server/python/tests/test_qwen_tool_parser.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for QwenFunctionCallDetector (Qwen XML tool format).""" + +import json + +from executorch.extension.llm.server.python.tool_parsers import QwenFunctionCallDetector + +# name -> JSON-schema `parameters` (as the server passes it to the detector). +_TOOLS = { + "get_weather": {"type": "object", "properties": {"city": {"type": "string"}}}, + "add": { + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + }, +} + + +def _parse(text, tools=_TOOLS): + return QwenFunctionCallDetector().detect_and_parse(text, tools) + + +def test_basic_call(): + text = ( + "\n\n\nParis\n" + "\n\n" + ) + r = _parse(text) + assert len(r.calls) == 1 + assert r.calls[0].name == "get_weather" + assert json.loads(r.calls[0].arguments) == {"city": "Paris"} + assert r.normal_text == "" + + +def test_observed_model_output(): + # The exact shape seen from Qwen3.5-MoE during the live smoke. + text = ( + "\n\n\nParis\n" + "\n\n" + ) + r = _parse(text) + assert [c.name for c in r.calls] == ["get_weather"] + + +def test_numeric_and_multi_param_coercion(): + text = ( + "2" + "3" + ) + r = _parse(text) + assert json.loads(r.calls[0].arguments) == {"a": 2, "b": 3} + + +def test_multiple_calls(): + text = ( + "Paris" + "1" + ) + r = _parse(text) + assert [c.name for c in r.calls] == ["get_weather", "add"] + assert [c.tool_index for c in r.calls] == [0, 1] + + +def test_leading_text_preserved(): + text = "Let me check.Paris" + r = _parse(text) + assert r.normal_text == "Let me check." + assert len(r.calls) == 1 + + +def test_no_tool_call_is_plain_text(): + text = "The capital of France is Paris." + r = _parse(text) + assert r.calls == [] + assert r.normal_text == text + + +def test_undefined_tool_degrades_to_text(): + # A call to a tool not in the request -> whole response kept as visible text. + text = "1" + r = _parse(text) + assert r.calls == [] + assert r.normal_text == text + + +def test_missing_tool_call_wrapper_still_parses(): + # Tolerate a truncated/absent wrapper as long as the function + # block is complete. + text = "Paris" + r = _parse(text) + assert len(r.calls) == 1 + assert json.loads(r.calls[0].arguments) == {"city": "Paris"} + + +# Schema-aware coercion: the XML format is stringly-typed, so values must be cast +# to the declared schema type (the cause of several BFCL function-calling misses). +def test_boolean_value_coerced_by_schema(): + tools = {"f": {"properties": {"flag": {"type": "boolean"}}}} + # The model writes a non-JSON capitalized "True"; the schema says boolean. + text = "True" + r = _parse(text, tools) + assert json.loads(r.calls[0].arguments) == {"flag": True} + + +def test_string_schema_keeps_numeric_literal_as_string(): + tools = {"f": {"properties": {"id": {"type": "string"}}}} + # A numeric-looking value the schema declares as a string must NOT become int. + text = "1234" + r = _parse(text, tools) + args = json.loads(r.calls[0].arguments) + assert args == {"id": "1234"} and isinstance(args["id"], str) + + +def test_untyped_param_falls_back_to_json_guess(): + # No declared type -> best-effort JSON guess (so loosely-typed tools still work). + tools = {"f": {"properties": {}}} + text = ( + "42" + "[1, 2]" + ) + r = _parse(text, tools) + assert json.loads(r.calls[0].arguments) == {"n": 42, "items": [1, 2]} + + +_TYPED = { + "code_tool": {"type": "object", "properties": {"code": {"type": "string"}}}, + "calc": { + "type": "object", + "properties": { + "n": {"type": "integer"}, + "x": {"type": "number"}, + "flag": {"type": "boolean"}, + }, + }, +} + + +def test_param_value_with_literal_parameter_close(): + # A value containing literal must be preserved, not truncated. + text = "a b" + r = _parse(text, _TYPED) + assert json.loads(r.calls[0].arguments) == {"code": "a b"} + + +def test_param_value_with_function_markup(): + # A value containing markup must stay in the value, not split. + text = ( + "x = " + ) + r = _parse(text, _TYPED) + assert len(r.calls) == 1 + assert json.loads(r.calls[0].arguments) == {"code": "x = "} + + +def test_declared_integer_with_float_string_kept_raw(): + text = "10.0" + val = json.loads(_parse(text, _TYPED).calls[0].arguments)["n"] + assert val == "10.0" and isinstance(val, str) # not float 10.0 + + +def test_declared_boolean_with_one_kept_raw(): + text = "1" + val = json.loads(_parse(text, _TYPED).calls[0].arguments)["flag"] + assert val == "1" and isinstance(val, str) # not int 1 + + +def test_declared_integer_with_underscores_kept_raw(): + text = "1_000" + val = json.loads(_parse(text, _TYPED).calls[0].arguments)["n"] + assert val == "1_000" and isinstance(val, str) # not int 1000 + + +def _reject_bare_constant(c): + # json.loads parse_constant hook: fires only for bare NaN/Infinity/-Infinity. + raise AssertionError(f"emitted bare non-finite constant: {c}") + + +def test_declared_number_non_finite_never_emitted(): + for bad in ("NaN", "Infinity", "-Infinity", "1e999"): + text = f"{bad}" + args = _parse(text, _TYPED).calls[0].arguments + # Strict-client safe: no bare NaN/Infinity constant in the emitted JSON. + json.loads(args, parse_constant=_reject_bare_constant) + assert json.loads(args)["x"] == bad # kept as the raw string + + +def test_multiple_valid_calls_still_parse(): + text = ( + "12" + "34" + ) + r = _parse(text) + assert [json.loads(c.arguments) for c in r.calls] == [ + {"a": 1, "b": 2}, + {"a": 3, "b": 4}, + ] + + +def test_truncated_call_degrades_without_leaking_markup(): + # A call cut off by max_tokens (no closing tags) must NOT leak the partial + # markup -- only the leading text survives (mirrors Hermes). + text = "Sure! Paris" + r = _parse(text, _TYPED) + assert not r.calls + assert "" not in r.normal_text and "user" in out and out.endswith("<|im_start|>assistant\n") + + +# (5) Chat-template behaviors: multi-turn ordering, system message, roles. +def test_multi_turn_order_preserved(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="user", content="first"), + ChatMessage(role="assistant", content="second"), + ChatMessage(role="user", content="third"), + ] + ) + assert out.index("first") < out.index("second") < out.index("third") + assert out.endswith("<|im_start|>assistant\n") # generation prompt appended + + +def test_system_message_rendered(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="system", content="You are terse."), + ChatMessage(role="user", content="hi"), + ] + ) + assert "<|im_start|>system\nYou are terse." in out + + +def test_each_role_labeled(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + out = t.render( + [ + ChatMessage(role="system", content="s"), + ChatMessage(role="user", content="u"), + ChatMessage(role="assistant", content="a"), + ] + ) + for role in ("system", "user", "assistant"): + assert f"<|im_start|>{role}" in out + + +# Tool round-trip: a turn-2 request (assistant tool_call + tool result) must +# serialize into the shape any HF chat template consumes — the multi-turn loop +# breaks at turn 2 otherwise. OpenAI sends tool-call arguments as a JSON string; +# HF templates expect a mapping (Qwen renders `arguments|items`), so the server +# decodes it before templating. +def test_tool_call_arguments_decoded_for_template(): + t, fake = _template_with_fake() + t.render( + [ + ChatMessage(role="user", content="weather?"), + ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall( + name="get_weather", arguments='{"city": "Paris"}' + ), + ) + ], + ), + ChatMessage(role="tool", tool_call_id="c1", content='{"temp_c": 18}'), + ] + ) + msgs = fake.seen_messages + asst = next(m for m in msgs if m["role"] == "assistant") + assert asst["tool_calls"][0]["function"]["name"] == "get_weather" + # Decoded from the JSON string into a mapping the template can iterate. + assert asst["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"} + tool = next(m for m in msgs if m["role"] == "tool") + assert tool["tool_call_id"] == "c1" and "temp_c" in tool["content"] + + +def test_tool_call_non_json_arguments_left_as_string(): + # A non-JSON arguments value must not crash; it passes through unchanged. + t, fake = _template_with_fake() + t.render( + [ + ChatMessage( + role="assistant", + content=None, + tool_calls=[ + ToolCall( + index=0, + id="c1", + type="function", + function=FunctionCall(name="f", arguments="not json"), + ) + ], + ) + ] + ) + asst = next(m for m in fake.seen_messages if m["role"] == "assistant") + assert asst["tool_calls"][0]["function"]["arguments"] == "not json" + + +class _HFSpecials: + """Minimal fake HF tokenizer exposing all_special_tokens / eos_token.""" + + def __init__(self, all_special_tokens, eos_token="<|im_end|>"): + self.all_special_tokens = list(all_special_tokens) + self.eos_token = eos_token + + +def _template_with_specials(all_special_tokens, eos_token="<|im_end|>"): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) + t._hf = _HFSpecials(all_special_tokens, eos_token) + return t + + +def test_turn_stop_excludes_tool_delimiters(): + # A tokenizer that marks BOTH a turn terminator and tool/structural delimiters + # as special: the stop set must keep the terminator and drop the delimiters, + # else generation halts at before the parser sees the call. + t = _template_with_specials( + ["<|im_end|>", "", "", "<|box_start|>"], + eos_token="<|im_end|>", + ) + stops = t.turn_stop_sequences() + assert "<|im_end|>" in stops + assert "" not in stops + assert "" not in stops + assert "<|box_start|>" not in stops + + +def test_turn_stop_includes_eos_and_known_terminators(): + t = _template_with_specials( + ["<|endoftext|>", "<|eot_id|>", ""], eos_token="<|endoftext|>" + ) + stops = t.turn_stop_sequences() + assert "<|endoftext|>" in stops # the tokenizer EOS + assert "<|eot_id|>" in stops # allowlisted terminator registered as special + assert "" not in stops + + +def test_turn_stop_drops_whitespace_only_specials(): + t = _template_with_specials(["<|im_end|>", " ", "\n", ""], eos_token="<|im_end|>") + stops = t.turn_stop_sequences() + assert all(s.strip() for s in stops) + assert " " not in stops and "\n" not in stops + + +def test_turn_stop_fallback_without_hf_is_narrow(): + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) # no _hf + stops = t.turn_stop_sequences() + assert "<|im_end|>" in stops + assert "" not in stops + + +def test_fallback_extracts_text_parts_not_repr(): + # 5e: the ChatML fallback renders list-content text parts, not a Python repr. + t = ChatTemplate(hf_tokenizer_path=None, allow_fallback=True) # no _hf + msg = ChatMessage( + role="user", + content=[ + {"type": "text", "text": "hello"}, + {"type": "image_url", "image_url": {}}, + ], + ) + out = t.render([msg]) + assert "hello" in out + assert "image_url" not in out and "{'type'" not in out # no repr leak diff --git a/extension/llm/server/python/tool_parsers/__init__.py b/extension/llm/server/python/tool_parsers/__init__.py new file mode 100644 index 00000000000..c890dec3888 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tool-call parsing. Two formats, pick the one matching your model: + +- HermesDetector: JSON inside (Qwen2.5/3, Hermes). +- QwenFunctionCallDetector: Qwen XML (Qwen3.5-MoE / + Qwen3-Coder). + +The server buffers the model's full output and parses it once into complete +OpenAI tool_calls; parse failures degrade to visible text. +""" + +from .hermes import HermesDetector +from .qwen import QwenFunctionCallDetector +from .types import ParseResult, ToolCallItem + +__all__ = [ + "HermesDetector", + "QwenFunctionCallDetector", + "ParseResult", + "ToolCallItem", +] diff --git a/extension/llm/server/python/tool_parsers/hermes.py b/extension/llm/server/python/tool_parsers/hermes.py new file mode 100644 index 00000000000..6ba19f89407 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/hermes.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Hermes-style tool calls: {"name": ..., "arguments": {...}}. + +Used by Qwen2.5/Qwen3 (and Hermes models); the Qwen XML format is handled +separately by QwenFunctionCallDetector. The server buffers a model's full output +and parses it once into complete OpenAI tool_calls (no partial-fragment +streaming). + +Malformed-call policy (explicit): ALL-OR-NOTHING. If any block is +malformed, names an undefined tool, or is truncated/unclosed, the WHOLE response +degrades -- no partial call set is emitted. When it degrades, the raw +... markup is NOT leaked into visible content: only the +leading text before the first marker is returned. Never crashes. +""" + +import json +import logging +from typing import Any, Optional + +from .types import ParseResult, ToolCallItem + +logger = logging.getLogger(__name__) + +_BOT = "" +_EOT = "" +_DECODER = json.JSONDecoder() + + +class _UndefinedToolCall(Exception): + """A named a tool not in the request's `tools`. v1 degrades the + WHOLE response to visible text rather than emitting a partial set — never + silently drop an undefined call while keeping its siblings (spec).""" + + +class HermesDetector: + """Parses Hermes/Qwen tool calls. Create a fresh instance per request (it + holds the per-request tool-call index); never share across requests.""" + + bot_token = "" + + def __init__(self): + self._next_index = 0 + + def detect_and_parse(self, text: str, tools: dict[str, dict]) -> ParseResult: + """Return leading text + any complete tool calls. On no/undefined/ + malformed call, degrade to the leading text BEFORE the first marker -- + never leak the raw markup into visible content.""" + if _BOT not in text: + return ParseResult(normal_text=text) + # Leading text before the first marker; this is what we show if parsing + # degrades, so the structural markup is never surfaced to the client. + normal = text[: text.find(_BOT)].strip() + try: + calls = self._parse_calls(text, tools) + except _UndefinedToolCall as e: + # Well-formed call to an undefined tool: degrade the WHOLE response to + # visible text (surface the model's intent; never emit a partial set). + logger.debug("undefined tool %s; returning raw text (no partial calls)", e) + return ParseResult(normal_text=text) + except Exception as e: # noqa: BLE001 - never crash + # Genuinely malformed / truncated / unclosed markup: degrade to the + # leading text so the partial garbage is NOT surfaced. + logger.debug("malformed tool call (%s); degrading to leading text", e) + return ParseResult(normal_text=normal) + if not calls: + return ParseResult(normal_text=text) + return ParseResult(normal_text=normal, calls=calls) + + def _parse_calls(self, text: str, tools: dict[str, dict]) -> list[ToolCallItem]: + """All-or-nothing: any malformed/unclosed block raises (caller degrades). + + Each block's JSON is parsed with raw_decode rather than a non-greedy + regex, so a string value that itself contains '' does not + truncate the captured JSON. The block must be closed by ''. + """ + calls = [] + pos = 0 + while True: + start = text.find(_BOT, pos) + if start == -1: + break + s = start + len(_BOT) + while s < len(text) and text[s].isspace(): + s += 1 + obj, end = _DECODER.raw_decode(text, s) # JSONDecodeError -> degrade + close = text.find(_EOT, end) + if close == -1 or text[end:close].strip(): + raise ValueError("unclosed or trailing-garbage block") + pos = close + len(_EOT) + for entry in obj if isinstance(obj, list) else [obj]: + if not isinstance(entry, dict): + raise ValueError("tool call entry is not an object") + # `parameters` is the fallback ONLY when `arguments` is absent or + # explicitly null (get-with-default misses the explicit-null case). + args = entry.get("arguments") + if args is None: + args = entry.get("parameters") + calls.append(self._make_item(entry.get("name"), args, tools)) + return calls + + def _make_item( + self, name: Optional[str], arguments: Any, tools: dict[str, dict] + ) -> ToolCallItem: + if not name or name not in tools: + raise _UndefinedToolCall(repr(name)) + item = ToolCallItem( + tool_index=self._next_index, + name=name, + arguments=json.dumps( + arguments if arguments is not None else {}, ensure_ascii=False + ), + ) + self._next_index += 1 + return item diff --git a/extension/llm/server/python/tool_parsers/qwen.py b/extension/llm/server/python/tool_parsers/qwen.py new file mode 100644 index 00000000000..8b72f890d64 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/qwen.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Qwen XML-style tool calls: V. + +Emitted by Qwen3.5-MoE / Qwen3-Coder (typically wrapped in … +), e.g.: + + + + + Paris + + + + +This is a DIFFERENT format from HermesDetector (JSON inside ); pick the +detector that matches your model. Detection triggers only on the unambiguous +`` marker so ordinary prose is not misclassified. Parse failures fall +back to visible text — never a crash or a silent drop. +""" + +import json +import logging +import math +import re +from typing import Any, Optional + +from .types import ParseResult, ToolCallItem + +logger = logging.getLogger(__name__) + +# Structural matching, NOT first-close-wins. A parameter value runs to the +# that is followed by the next that is followed by the +# next , , or markup +# instead of silently truncating the call at the first delimiter (Qwen3-Coder +# emitting code/markup is a realistic trigger). +# Bound: a value containing the exact sequence "\s]+)\s*>(.*?)\s*(?=||\Z)", + re.DOTALL, +) +_PARAMETER_RE = re.compile( + r"\s]+)\s*>(.*?)\s*(?= Any: + """Cast a raw XML parameter string to the type declared in the tool's JSON + schema. + + The Qwen XML format is stringly-typed (`v`), so + without the schema we'd have to guess. Coercing to the declared type keeps the + emitted OpenAI tool_call schema-valid. + + On a DECLARED type whose strict cast fails, keep the raw string rather than + falling through to a JSON guess that would emit a *different* JSON type (the + bug this guards): `integer` + "10.0" must not become float 10.0; `boolean` + + "1" must not become int 1; underscore numerics ("1_000") are not accepted for + numeric types. Non-finite floats (NaN/Infinity/1e999) are never emitted -- they + are kept as the raw string -- so `arguments` is always valid JSON. Only when + the type is unknown do we make a JSON guess (then raw string), so + untyped/loosely-typed params keep working. + """ + v = value.strip() + if declared_type == "string": + return value + if declared_type == "boolean": + low = v.lower() + if low == "true": + return True + if low == "false": + return False + return value # not a valid bool literal -> keep raw, don't mistype + if declared_type == "integer": + # strict: digits only (no float, no underscores) + return int(v) if _INT_RE.match(v) else value + if declared_type == "number": + if _NUM_RE.match(v): + f = float(v) + if math.isfinite(f): + return f + return value # non-numeric / non-finite -> keep raw, never emit NaN/Inf + # Unknown declared type: a JSON guess, but reject non-finite (json.loads + # parses NaN/Infinity by default, which json.dumps would then re-emit). + try: + guess = json.loads(value) + except (ValueError, TypeError): + return value + if isinstance(guess, float) and not math.isfinite(guess): + return value + return guess + + +class QwenFunctionCallDetector: + """Parses Qwen's XML tool-call format. Create a fresh instance per request + (it holds the per-request tool-call index); never share across requests.""" + + bot_token = "" + + def __init__(self): + self._next_index = 0 + + def detect_and_parse(self, text: str, tools: dict[str, dict]) -> ParseResult: + """Return leading text + any complete tool calls. + + Degrade policy (mirrors HermesDetector): + * No tool marker at all -> return the text unchanged. + * A WELL-FORMED call to an undefined tool -> degrade the whole response + to the full visible text (surface the model's intent; never a partial + set). + * A TRUNCATED/partial call (a marker present but + no complete , e.g. cut by max_tokens) or other + malformed markup -> degrade to the LEADING text before the first + marker, so the raw markup is never leaked to the client as content. + + `tools` maps each defined tool name to its JSON-schema ``parameters`` + object; the schema is used to coerce stringly-typed XML values to their + declared types (and the key set validates names).""" + first = _FUNCTION_RE.search(text) + if first is None: + # No complete call. If a tool marker is present the call was + # truncated/partial -> strip it; otherwise there is no tool intent. + markers = [ + i + for i in (text.find(self.bot_token), text.find(" wrapper if present, else at the + # first tag. + cut = text.find(self.bot_token) + if cut == -1 or cut > first.start(): + cut = first.start() + normal = text[:cut].strip() + try: + calls = self._parse_calls(text, tools) + except _UndefinedToolCall as e: + # well-formed call to an undefined tool: surface full text (no partial set) + logger.debug("undefined tool %s; returning raw text (no partial calls)", e) + return ParseResult(normal_text=text) + except Exception as e: # noqa: BLE001 - never crash + # malformed markup: degrade to leading text (don't leak partial markup) + logger.debug("malformed tool call (%s); degrading to leading text", e) + return ParseResult(normal_text=normal) + if not calls: + return ParseResult(normal_text=text) + return ParseResult(normal_text=normal, calls=calls) + + def _parse_calls(self, text: str, tools: dict[str, dict]) -> list[ToolCallItem]: + calls = [] + for fm in _FUNCTION_RE.finditer(text): + name, body = fm.group(1), fm.group(2) + props = (tools.get(name) or {}).get("properties", {}) + args = {} + for pm in _PARAMETER_RE.finditer(body): + key = pm.group(1) + args[key] = _coerce(pm.group(2).strip(), props.get(key, {}).get("type")) + calls.append(self._make_item(name, args, tools)) + return calls + + def _make_item( + self, name: Optional[str], arguments: dict, tools: dict[str, dict] + ) -> ToolCallItem: + if not name or name not in tools: + raise _UndefinedToolCall(repr(name)) + item = ToolCallItem( + tool_index=self._next_index, + name=name, + # allow_nan=False: belt-and-suspenders so a non-finite that escaped + # _coerce raises (-> caught -> text fallback) rather than emitting + # invalid JSON (NaN/Infinity) a strict client would reject. + arguments=json.dumps(arguments, ensure_ascii=False, allow_nan=False), + ) + self._next_index += 1 + return item diff --git a/extension/llm/server/python/tool_parsers/types.py b/extension/llm/server/python/tool_parsers/types.py new file mode 100644 index 00000000000..2dae5c79458 --- /dev/null +++ b/extension/llm/server/python/tool_parsers/types.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Protocol-agnostic tool-parsing types. + +Kept independent of the OpenAI wire schema so the parser package is reusable; +serving_chat translates these into OpenAI tool_calls / deltas at the edge. +Design adapted from SGLang's core_types, with explicit per-request state. +""" + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ToolCallItem: + """A parsed tool call. `arguments` is a JSON string (the full arguments — + this server emits complete calls, not fragments).""" + + tool_index: int + name: Optional[str] = None + arguments: str = "" + + +@dataclass +class ParseResult: + """Outcome of a parse: free text plus any tool calls found.""" + + normal_text: str = "" + calls: list[ToolCallItem] = field(default_factory=list) diff --git a/extension/llm/server/spec/README.md b/extension/llm/server/spec/README.md new file mode 100644 index 00000000000..58e0e46ef57 --- /dev/null +++ b/extension/llm/server/spec/README.md @@ -0,0 +1,73 @@ +# ExecuTorch LLM Server — Contract Spec + +The language-neutral contract every ExecuTorch LLM server (Python today, C++ +later) implements. The conformance suite in `../conformance` validates an +implementation against this spec by hitting a live server, so it is independent +of language and engine. + +## Supported endpoints + +| Endpoint | Status | +|----------|--------| +| `GET /v1/models` | implemented | +| `POST /v1/chat/completions` (stream + non-stream) | implemented | +| `GET /health` | implemented | +| `POST /v1/completions` | planned | + +## `POST /v1/chat/completions` + +OpenAI Chat Completions subset. **Honored** request fields: `model`, `messages`, +`stream`, `temperature`, `max_tokens` / `max_completion_tokens`, `stop`, `tools`, +`tool_choice` (only `"none"` to disable tools, or `"auto"`/unset for default +parsing), `stream_options.include_usage`, and `chat_template_kwargs` (e.g. +`enable_thinking`). + +**Rejected** with `400 invalid_request_error` (`code: "unsupported_parameter"`) +rather than silently ignored — a client relying on them would otherwise get +wrong behavior: `top_p` (anything other than `1.0`), `seed`, `n` (> 1), +`reasoning_effort`, `frequency_penalty`/`presence_penalty` (nonzero), `top_k`, +`logit_bias`, `tool_choice` = `"required"` or a specific-function choice +(forcing/restricting a call needs constrained decoding, which v1 lacks), +`response_format` other than `{"type": "text"}` (no constrained JSON), +`logprobs`/`top_logprobs` (not returned), and `parallel_tool_calls: false` +(single-call can't be guaranteed without constraining). Unknown fields that +don't affect the output (e.g. `user`, `store`, `metadata`) are accepted and +ignored. + +Non-streaming response: `chat.completion` with one `choice` +(`message.role = "assistant"`, string `content` or `tool_calls`, `finish_reason` +∈ `stop` | `length` | `tool_calls`) and a `usage` block. + +Streaming response: `text/event-stream` of `chat.completion.chunk` objects — +first chunk carries `delta.role = "assistant"`, subsequent chunks carry +`delta.content` (or buffered `delta.tool_calls`), a final chunk carries +`finish_reason`, optionally a usage-only chunk (with +`stream_options.include_usage`), terminated by `data: [DONE]`. + +### Tool calling + +Two output formats are accepted: Hermes-style JSON +(`{"name":...,"arguments":{...}}`, used by Qwen2.5/Qwen3) +and Qwen XML-style (`V`, +typically wrapped in ``, used by Qwen3.5-MoE / Qwen3-Coder). The +server buffers the model's full output and emits **complete** OpenAI +`tool_calls` (no partial-argument fragments). Calls to tools absent from the +request, and malformed tool calls, degrade to visible text — never a crash or +silent drop. `tool_choice="none"` disables tool parsing. + +### Errors & cancellation + +Errors return `{"error": {"message", "type", "code"}}` with an appropriate +status (e.g. `400 context_length_exceeded` when `--max-context` is set and the +prompt exceeds it). A mid-stream failure emits an `error` SSE event then +`[DONE]` rather than dropping the socket. Cancellation is best-effort: on a +client disconnect the control plane stops consuming the stream (`stop()`), but +the worker runs the in-flight request to completion — V1 has no mid-generation +interrupt protocol. + +### Prefix cache + +Not in V1 serving. The control plane holds no KV state and does no prefix-reuse +routing; each request is an independent prompt to the worker. If turn-to-turn KV +prefix reuse returns, it will live inside the worker/session (where the KV cache +is), not in the control plane.