Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions py/src/braintrust/devserver/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ParsedEvalBody(TypedDict, total=False):
"""Type for parsed eval request body."""

name: str # Required
id: str
parameters: dict[str, Any]
data: Any
scores: list[ParsedFunctionId]
Expand Down Expand Up @@ -202,6 +203,11 @@ def parse_eval_body(request_data: str | bytes | dict) -> ParsedEvalBody:
# Build the parsed body
parsed: ParsedEvalBody = {"name": name}

if "id" in data:
if not isinstance(data["id"], str):
raise ValidationError(f"id must be a string, got {type(data['id']).__name__}")
parsed["id"] = data["id"]

# Optional fields with validation
if "parameters" in data:
if not isinstance(data["parameters"], dict):
Expand Down
80 changes: 72 additions & 8 deletions py/src/braintrust/devserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import sys
import textwrap
from collections import Counter, defaultdict
from typing import Any


Expand Down Expand Up @@ -53,6 +54,7 @@


_all_evaluators: dict[str, Evaluator[Any, Any, Any]] = {}
_evaluator_ids_by_name: dict[str, list[str]] = {}


class _ParameterOverrideHooks:
Expand Down Expand Up @@ -120,8 +122,10 @@ async def list_evaluators(request: Request) -> JSONResponse:
return JSONResponse({"error": "Unauthorized"}, status_code=401)

evaluator_list = {}
for name, evaluator in _all_evaluators.items():
evaluator_list[name] = {
for evaluator_id, evaluator in _all_evaluators.items():
ids_for_name = _evaluator_ids_by_name[evaluator.eval_name]
list_key = evaluator.eval_name if len(ids_for_name) == 1 else evaluator_id
evaluator_list[list_key] = {
"parameters": (
serialize_remote_eval_parameters_container(evaluator.parameters) if evaluator.parameters else None
),
Expand All @@ -130,6 +134,9 @@ async def list_evaluators(request: Request) -> JSONResponse:
{"name": _classifier_name(classifier, i)} for i, classifier in enumerate(evaluator.classifiers or [])
],
}
if len(ids_for_name) > 1:
evaluator_list[list_key]["id"] = evaluator_id
evaluator_list[list_key]["name"] = evaluator.eval_name

return JSONResponse(evaluator_list)

Expand Down Expand Up @@ -159,10 +166,9 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:

state = ctx.state

# Check if the evaluator exists
evaluator = _all_evaluators.get(eval_data["name"])
if not evaluator:
return JSONResponse({"error": f"Evaluator '{eval_data['name']}' not found"}, status_code=404)
evaluator, error = _resolve_evaluator(eval_data)
if error is not None:
return error

# Get the dataset if data is provided
try:
Expand Down Expand Up @@ -302,6 +308,65 @@ async def run_and_complete():
return JSONResponse({"error": f"Failed to run evaluation: {str(e)}"}, status_code=500)


def _make_evaluator_id(eval_name: str, index: int, duplicate_names: set[str]) -> str:
if eval_name not in duplicate_names:
return eval_name
return f"{eval_name}#{index}"


def _set_evaluator_registry(evaluators: list[Evaluator[Any, Any, Any]]) -> None:
    """Rebuild the module-level evaluator registry from ``evaluators``.

    Each evaluator is stored under a unique id. An evaluator whose name
    occurs once uses the name as its id; evaluators sharing a name get
    ``"<name>#<n>"`` with ``n`` counting from 1 in input order. If the
    chosen id is already taken, ``"#2"``, ``"#3"``, ... is appended to it
    until a free id is found.

    Side effects:
        Replaces the globals ``_all_evaluators`` (id -> evaluator) and
        ``_evaluator_ids_by_name`` (name -> list of ids, in input order).

    Args:
        evaluators: Evaluators to register, in the order ids are assigned.
    """
    global _all_evaluators, _evaluator_ids_by_name

    occurrences = Counter(candidate.eval_name for candidate in evaluators)
    shared_names = {name for name in occurrences if occurrences[name] > 1}

    seen_per_name: dict[str, int] = {}
    registry: dict[str, Evaluator[Any, Any, Any]] = {}
    ids_by_name: dict[str, list[str]] = {}

    for candidate in evaluators:
        name = candidate.eval_name
        position = seen_per_name.get(name, 0) + 1
        seen_per_name[name] = position

        stem = _make_evaluator_id(name, position, shared_names)
        unique_id = stem
        # Disambiguate against ids already handed out; the first extra
        # suffix tried is "#2".
        bump = 2
        while unique_id in registry:
            unique_id = f"{stem}#{bump}"
            bump += 1

        registry[unique_id] = candidate
        ids_by_name.setdefault(name, []).append(unique_id)

    _all_evaluators = registry
    _evaluator_ids_by_name = ids_by_name


def _resolve_evaluator(
    eval_data: dict[str, Any],
) -> tuple[Evaluator[Any, Any, Any] | None, JSONResponse | None]:
    """Find the registered evaluator that an eval request refers to.

    Lookup order:
        1. A truthy ``eval_data["id"]`` is looked up in ``_all_evaluators``;
           a miss is a 404 and the name is not consulted.
        2. Otherwise ``eval_data["name"]`` is tried directly as a registry key.
        3. Otherwise the name is looked up in ``_evaluator_ids_by_name``:
           no ids -> 404, exactly one id -> that evaluator, several ids -> 409
           with the candidate ids listed in the response body.

    Args:
        eval_data: Parsed request body; ``"name"`` must be present whenever
            ``"id"`` is missing or falsy.

    Returns:
        ``(evaluator, None)`` on success, or ``(None, error_response)`` where
        ``error_response`` is the ``JSONResponse`` to send back.
    """
    requested_id = eval_data.get("id")
    if requested_id:
        by_id = _all_evaluators.get(requested_id)
        if by_id is not None:
            return by_id, None
        return None, JSONResponse({"error": f"Evaluator id '{requested_id}' not found"}, status_code=404)

    requested_name = eval_data["name"]
    direct_hit = _all_evaluators.get(requested_name)
    if direct_hit is not None:
        return direct_hit, None

    matching_ids = _evaluator_ids_by_name.get(requested_name, [])
    if not matching_ids:
        return None, JSONResponse({"error": f"Evaluator '{requested_name}' not found"}, status_code=404)

    if len(matching_ids) == 1:
        (only_id,) = matching_ids
        return _all_evaluators[only_id], None

    # Several evaluators share this name: report every candidate so the
    # caller can retry with an explicit id.
    ambiguity_body = {
        "error": f"Evaluator name '{requested_name}' is ambiguous. Retry with one of the listed ids.",
        "candidates": [{"id": candidate_id, "name": requested_name} for candidate_id in matching_ids],
    }
    return None, JSONResponse(ambiguity_body, status_code=409)


def create_app(evaluators: list[Evaluator[Any, Any, Any]], org_name: str | None = None):
"""Create and configure the Starlette app for the dev server.

Expand All @@ -312,8 +377,7 @@ def create_app(evaluators: list[Evaluator[Any, Any, Any]], org_name: str | None
Returns:
Configured Starlette app
"""
global _all_evaluators
_all_evaluators = {evaluator.eval_name: evaluator for evaluator in evaluators}
_set_evaluator_registry(evaluators)

routes = [
Route("/", endpoint=index),
Expand Down
104 changes: 104 additions & 0 deletions py/src/braintrust/devserver/test_server_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import inspect
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -122,6 +123,109 @@ def test_devserver_list_evaluators(client, api_key, org_name):
assert evaluators["simple-math-eval"]["classifiers"] == [{"name": "classifier"}]


def test_devserver_keeps_duplicate_eval_names_addressable(api_key, org_name, monkeypatch):
    """Two evaluators sharing an eval name must both stay reachable.

    Checks that ``/list`` keys the duplicates by generated ids, that a bare
    name is rejected with 409 plus the candidate ids, and that either an
    explicit ``id`` or the list key passed as ``name`` selects the right task.
    """
    if not has_devserver_installed():
        pytest.skip("Devserver dependencies not installed (requires .[cli])")

    from braintrust import Evaluator
    from braintrust.devserver import server as devserver_module
    from braintrust.devserver.server import create_app
    from braintrust.logger import BraintrustState
    from starlette.testclient import TestClient

    shared_name = "shared-project"
    first_id = "shared-project#1"
    second_id = "shared-project#2"

    def ping_rows(expected: str) -> list[dict[str, str]]:
        return [{"input": "ping", "expected": expected}]

    def task_a(input: str, _hooks) -> str:
        return f"a:{input}"

    def task_b(input: str, _hooks) -> str:
        return f"b:{input}"

    def build_evaluator(task, expected: str):
        return Evaluator(
            project_name=shared_name,
            eval_name=shared_name,
            data=lambda: ping_rows(expected),
            task=task,
            scores=[],
            experiment_name=None,
            metadata=None,
        )

    duplicates = [build_evaluator(task_a, "a:ping"), build_evaluator(task_b, "b:ping")]

    async def fake_cached_login(**_kwargs):
        return BraintrustState()

    class FakeSummary:
        def as_dict(self):
            return {"experiment_name": "shared-project", "project_name": "shared-project", "scores": {}}

    class FakeResult:
        summary = FakeSummary()

    class Hooks:
        parameters = None

        def report_progress(self, _progress):
            return None

    # Records which task the server actually invoked.
    captured: dict[str, Any] = {}

    async def fake_eval_async(*, task, **_kwargs):
        result = task("ping", Hooks())
        if inspect.isawaitable(result):
            result = await result
        captured["output"] = result
        return FakeResult()

    monkeypatch.setattr(devserver_module, "cached_login", fake_cached_login)
    monkeypatch.setattr(devserver_module, "EvalAsync", fake_eval_async)

    test_client = TestClient(create_app(duplicates))
    headers = {"x-bt-auth-token": api_key, "x-bt-org-name": org_name, "Content-Type": "application/json"}

    def post_eval(body: dict[str, Any]):
        return test_client.post("/eval", headers=headers, json=body)

    # Listing: both duplicates appear, keyed by their generated ids.
    list_response = test_client.get("/list", headers=headers)
    assert list_response.status_code == 200
    listed = list_response.json()
    assert list(listed) == [first_id, second_id]
    assert listed[first_id]["id"] == first_id
    assert listed[first_id]["name"] == shared_name
    assert listed[second_id]["id"] == second_id

    # A bare shared name is ambiguous and must list the candidates.
    ambiguous_response = post_eval({"name": "shared-project", "stream": False, "data": ping_rows("b:ping")})
    assert ambiguous_response.status_code == 409
    assert ambiguous_response.json()["candidates"] == [
        {"id": first_id, "name": shared_name},
        {"id": second_id, "name": shared_name},
    ]

    # An explicit id picks the second evaluator.
    selected_response = post_eval(
        {
            "name": "shared-project",
            "id": "shared-project#2",
            "stream": False,
            "data": ping_rows("b:ping"),
        }
    )
    assert selected_response.status_code == 200
    assert captured["output"] == "b:ping"

    # The list key supplied as the name picks the first evaluator.
    captured.clear()
    list_key_response = post_eval({"name": "shared-project#1", "stream": False, "data": ping_rows("a:ping")})
    assert list_key_response.status_code == 200
    assert captured["output"] == "a:ping"


def parse_sse_events(response_text: str) -> list[dict[str, Any]]:
"""Parse SSE events from response text."""
events = []
Expand Down