diff --git a/py/src/braintrust/devserver/schemas.py b/py/src/braintrust/devserver/schemas.py index a359a93d..a48403a0 100644 --- a/py/src/braintrust/devserver/schemas.py +++ b/py/src/braintrust/devserver/schemas.py @@ -35,6 +35,7 @@ class ParsedEvalBody(TypedDict, total=False): """Type for parsed eval request body.""" name: str # Required + id: str parameters: dict[str, Any] data: Any scores: list[ParsedFunctionId] @@ -202,6 +203,11 @@ def parse_eval_body(request_data: str | bytes | dict) -> ParsedEvalBody: # Build the parsed body parsed: ParsedEvalBody = {"name": name} + if "id" in data: + if not isinstance(data["id"], str): + raise ValidationError(f"id must be a string, got {type(data['id']).__name__}") + parsed["id"] = data["id"] + # Optional fields with validation if "parameters" in data: if not isinstance(data["parameters"], dict): diff --git a/py/src/braintrust/devserver/server.py b/py/src/braintrust/devserver/server.py index a141fc79..0e96cba6 100644 --- a/py/src/braintrust/devserver/server.py +++ b/py/src/braintrust/devserver/server.py @@ -2,6 +2,7 @@ import json import sys import textwrap +from collections import Counter, defaultdict from typing import Any @@ -53,6 +54,7 @@ _all_evaluators: dict[str, Evaluator[Any, Any, Any]] = {} +_evaluator_ids_by_name: dict[str, list[str]] = {} class _ParameterOverrideHooks: @@ -120,8 +122,10 @@ async def list_evaluators(request: Request) -> JSONResponse: return JSONResponse({"error": "Unauthorized"}, status_code=401) evaluator_list = {} - for name, evaluator in _all_evaluators.items(): - evaluator_list[name] = { + for evaluator_id, evaluator in _all_evaluators.items(): + ids_for_name = _evaluator_ids_by_name[evaluator.eval_name] + list_key = evaluator.eval_name if len(ids_for_name) == 1 else evaluator_id + evaluator_list[list_key] = { "parameters": ( serialize_remote_eval_parameters_container(evaluator.parameters) if evaluator.parameters else None ), @@ -130,6 +134,9 @@ async def list_evaluators(request: Request) -> JSONResponse: {"name": _classifier_name(classifier, i)} for i, classifier in enumerate(evaluator.classifiers or []) ], } + if len(ids_for_name) > 1: + evaluator_list[list_key]["id"] = evaluator_id + evaluator_list[list_key]["name"] = evaluator.eval_name return JSONResponse(evaluator_list) @@ -159,10 +166,9 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse: state = ctx.state - # Check if the evaluator exists - evaluator = _all_evaluators.get(eval_data["name"]) - if not evaluator: - return JSONResponse({"error": f"Evaluator '{eval_data['name']}' not found"}, status_code=404) + evaluator, error = _resolve_evaluator(eval_data) + if error is not None: + return error # Get the dataset if data is provided try: @@ -302,6 +308,65 @@ async def run_and_complete(): return JSONResponse({"error": f"Failed to run evaluation: {str(e)}"}, status_code=500) +def _make_evaluator_id(eval_name: str, index: int, duplicate_names: set[str]) -> str: + if eval_name not in duplicate_names: + return eval_name + return f"{eval_name}#{index}" + + +def _set_evaluator_registry(evaluators: list[Evaluator[Any, Any, Any]]) -> None: + global _all_evaluators, _evaluator_ids_by_name + name_counts = Counter(evaluator.eval_name for evaluator in evaluators) + duplicate_names = {name for name, count in name_counts.items() if count > 1} + name_indexes: dict[str, int] = defaultdict(int) + all_evaluators: dict[str, Evaluator[Any, Any, Any]] = {} + evaluator_ids_by_name: dict[str, list[str]] = defaultdict(list) + + for evaluator in evaluators: + name_indexes[evaluator.eval_name] += 1 + evaluator_id = _make_evaluator_id(evaluator.eval_name, name_indexes[evaluator.eval_name], duplicate_names) + base_evaluator_id = evaluator_id + suffix = 1 + while evaluator_id in all_evaluators: + suffix += 1 + evaluator_id = f"{base_evaluator_id}#{suffix}" + all_evaluators[evaluator_id] = evaluator + evaluator_ids_by_name[evaluator.eval_name].append(evaluator_id) + + _all_evaluators = all_evaluators + _evaluator_ids_by_name = dict(evaluator_ids_by_name) + + +def _resolve_evaluator( + eval_data: dict[str, Any], +) -> tuple[Evaluator[Any, Any, Any] | None, JSONResponse | None]: + evaluator_id = eval_data.get("id") + if evaluator_id: + evaluator = _all_evaluators.get(evaluator_id) + if evaluator is None: + return None, JSONResponse({"error": f"Evaluator id '{evaluator_id}' not found"}, status_code=404) + return evaluator, None + + eval_name = eval_data["name"] + evaluator = _all_evaluators.get(eval_name) + if evaluator is not None: + return evaluator, None + + evaluator_ids = _evaluator_ids_by_name.get(eval_name, []) + if len(evaluator_ids) == 1: + return _all_evaluators[evaluator_ids[0]], None + if len(evaluator_ids) > 1: + candidates = [{"id": candidate_id, "name": eval_name} for candidate_id in evaluator_ids] + return None, JSONResponse( + { + "error": f"Evaluator name '{eval_name}' is ambiguous. Retry with one of the listed ids.", + "candidates": candidates, + }, + status_code=409, + ) + return None, JSONResponse({"error": f"Evaluator '{eval_name}' not found"}, status_code=404) + + def create_app(evaluators: list[Evaluator[Any, Any, Any]], org_name: str | None = None): """Create and configure the Starlette app for the dev server. @@ -312,8 +377,7 @@ def create_app(evaluators: list[Evaluator[Any, Any, Any]], org_name: str | None Returns: Configured Starlette app """ - global _all_evaluators - _all_evaluators = {evaluator.eval_name: evaluator for evaluator in evaluators} + _set_evaluator_registry(evaluators) routes = [ Route("/", endpoint=index), diff --git a/py/src/braintrust/devserver/test_server_integration.py b/py/src/braintrust/devserver/test_server_integration.py index 4a013684..e6dcb6fc 100644 --- a/py/src/braintrust/devserver/test_server_integration.py +++ b/py/src/braintrust/devserver/test_server_integration.py @@ -1,5 +1,6 @@ import json import os +import inspect from pathlib import Path from typing import Any @@ -122,6 +123,109 @@ def test_devserver_list_evaluators(client, api_key, org_name): assert evaluators["simple-math-eval"]["classifiers"] == [{"name": "classifier"}] +def test_devserver_keeps_duplicate_eval_names_addressable(api_key, org_name, monkeypatch): + if not has_devserver_installed(): + pytest.skip("Devserver dependencies not installed (requires .[cli])") + + from braintrust import Evaluator + from braintrust.devserver import server as devserver_module + from braintrust.devserver.server import create_app + from braintrust.logger import BraintrustState + from starlette.testclient import TestClient + + def task_a(input: str, _hooks) -> str: + return f"a:{input}" + + def task_b(input: str, _hooks) -> str: + return f"b:{input}" + + evaluator_a = Evaluator( + project_name="shared-project", + eval_name="shared-project", + data=lambda: [{"input": "ping", "expected": "a:ping"}], + task=task_a, + scores=[], + experiment_name=None, + metadata=None, + ) + evaluator_b = Evaluator( + project_name="shared-project", + eval_name="shared-project", + data=lambda: [{"input": "ping", "expected": "b:ping"}], + task=task_b, + scores=[], + experiment_name=None, + metadata=None, + ) + + async def fake_cached_login(**_kwargs): + return BraintrustState() + + class FakeSummary: + def as_dict(self): + return {"experiment_name": "shared-project", "project_name": "shared-project", "scores": {}} + + class FakeResult: + summary = FakeSummary() + + captured: dict[str, Any] = {} + + async def fake_eval_async(*, task, **_kwargs): + hooks = type("Hooks", (), {"parameters": None, "report_progress": lambda self, _progress: None})() + output = task("ping", hooks) + if inspect.isawaitable(output): + output = await output + captured["output"] = output + return FakeResult() + + monkeypatch.setattr(devserver_module, "cached_login", fake_cached_login) + monkeypatch.setattr(devserver_module, "EvalAsync", fake_eval_async) + + test_client = TestClient(create_app([evaluator_a, evaluator_b])) + headers = {"x-bt-auth-token": api_key, "x-bt-org-name": org_name, "Content-Type": "application/json"} + + list_response = test_client.get("/list", headers=headers) + assert list_response.status_code == 200 + evaluators = list_response.json() + assert list(evaluators) == ["shared-project#1", "shared-project#2"] + assert evaluators["shared-project#1"]["id"] == "shared-project#1" + assert evaluators["shared-project#1"]["name"] == "shared-project" + assert evaluators["shared-project#2"]["id"] == "shared-project#2" + + ambiguous_response = test_client.post( + "/eval", + headers=headers, + json={"name": "shared-project", "stream": False, "data": [{"input": "ping", "expected": "b:ping"}]}, + ) + assert ambiguous_response.status_code == 409 + assert ambiguous_response.json()["candidates"] == [ + {"id": "shared-project#1", "name": "shared-project"}, + {"id": "shared-project#2", "name": "shared-project"}, + ] + + selected_response = test_client.post( + "/eval", + headers=headers, + json={ + "name": "shared-project", + "id": "shared-project#2", + "stream": False, + "data": [{"input": "ping", "expected": "b:ping"}], + }, + ) + assert selected_response.status_code == 200 + assert captured["output"] == "b:ping" + + captured.clear() + list_key_response = test_client.post( + "/eval", + headers=headers, + json={"name": "shared-project#1", "stream": False, "data": [{"input": "ping", "expected": "a:ping"}]}, + ) + assert list_key_response.status_code == 200 + assert captured["output"] == "a:ping" + + def parse_sse_events(response_text: str) -> list[dict[str, Any]]: """Parse SSE events from response text.""" events = []