Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 8 additions & 25 deletions py/src/braintrust/integrations/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,24 +617,6 @@ def _get_model_name_from_response(response: LLMResult) -> str | None:
return model_name


def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool:
# LangChain provider packages use different cache-token conventions:
# - OpenAI-style responses report cache reads as a subset of input_tokens.
# - Anthropic-style responses report cache reads/creation separately from input_tokens.
#
# Avoid provider-name checks here so any LangChain integration using the same
# "separate cache tokens" schema gets normalized, while providers that only
# expose cache_read as input-token detail do not get double-counted.
return any(
key in input_token_details
for key in (
"cache_creation",
"ephemeral_5m_input_tokens",
"ephemeral_1h_input_tokens",
)
)


def _get_metrics_from_response(response: LLMResult):
metrics = {}

Expand Down Expand Up @@ -685,15 +667,16 @@ def _get_metrics_from_response(response: LLMResult):
completion_tokens = metrics.get("completion_tokens")
total_tokens = metrics.get("total_tokens")
if prompt_tokens is not None and completion_tokens is not None:
if (
cache_tokens
and total_tokens == prompt_tokens + completion_tokens
and _cache_tokens_are_separate_from_input_tokens(input_token_details)
):
# LangChain's UsageMetadata contract makes input_token_details a
# breakdown of input_tokens, so cache tokens already count toward
# the prompt total (langchain-anthropic >= 0.2.3, langchain-aws,
# langchain-openai all comply). Cache tokens exceeding the prompt
# total means the integration reported uncached input only — fold
# cache tokens back in so prompt/total stay internally consistent.
if cache_tokens > prompt_tokens and total_tokens == prompt_tokens + completion_tokens:
prompt_tokens += cache_tokens
metrics["prompt_tokens"] = prompt_tokens
if total_tokens is not None:
metrics["total_tokens"] = total_tokens + cache_tokens
metrics["total_tokens"] = total_tokens + cache_tokens
metrics["tokens"] = prompt_tokens + completion_tokens

if not metrics or not any(metrics.values()):
Expand Down
100 changes: 100 additions & 0 deletions py/src/braintrust/integrations/langchain/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
from braintrust import logger
from braintrust.integrations.langchain import BraintrustCallbackHandler
from braintrust.integrations.langchain.callbacks import _get_metrics_from_response
from braintrust.logger import flush
from braintrust.test_helpers import init_test_logger
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnableMap, RunnableSerializable
Expand Down Expand Up @@ -906,6 +908,94 @@ def test_streaming_ttft(logger_memory_logger):
)


def _single_generation_response(usage_metadata: dict, model_name: str) -> LLMResult:
    """Build an ``LLMResult`` holding exactly one chat generation.

    The generation wraps an ``AIMessage`` whose content is ``"Done"``, whose
    response metadata carries ``model_name``, and whose usage metadata is the
    given ``usage_metadata`` mapping.
    """
    ai_message = AIMessage(
        content="Done",
        response_metadata={"model_name": model_name},
        usage_metadata=cast(dict, usage_metadata),
    )
    only_generation = ChatGeneration(message=ai_message)
    # One prompt, one candidate: generations is a list of lists.
    return LLMResult(generations=[[only_generation]])


def test_folded_cache_tokens_are_not_double_counted():
    """Cache tokens already included in ``input_tokens`` are not added again."""
    # langchain-anthropic >= 0.2.3 folds cache read/creation tokens into
    # input_tokens, exposing them via input_token_details as a breakdown.
    cache_breakdown = {
        "cache_read": 0,
        "cache_creation": 0,
        "ephemeral_5m_input_tokens": 1075,
        "ephemeral_1h_input_tokens": 0,
    }
    usage = {
        "input_tokens": 1095,
        "output_tokens": 40,
        "total_tokens": 1135,
        "input_token_details": cache_breakdown,
    }
    llm_result = _single_generation_response(usage, model_name="claude-sonnet-4-5-20250929")

    # Prompt and total counts must match the reported usage unchanged.
    expected_metrics = {
        "prompt_tokens": 1095,
        "completion_tokens": 40,
        "total_tokens": 1135,
        "tokens": 1135,
        "prompt_cached_tokens": 0,
        "prompt_cache_creation_5m_tokens": 1075,
        "prompt_cache_creation_1h_tokens": 0,
    }
    assert _get_metrics_from_response(llm_result) == expected_metrics


def test_openai_cached_tokens_are_not_folded_into_prompt_tokens():
    """A ``cache_read`` detail smaller than ``input_tokens`` leaves the totals untouched."""
    usage = {
        "input_tokens": 1000,
        "output_tokens": 200,
        "total_tokens": 1200,
        "input_token_details": {"cache_read": 500},
    }
    llm_result = _single_generation_response(usage, model_name="gpt-4o-mini-2024-07-18")

    # The cached count is surfaced as its own metric; prompt/total stay as reported.
    expected_metrics = {
        "prompt_tokens": 1000,
        "completion_tokens": 200,
        "total_tokens": 1200,
        "tokens": 1200,
        "prompt_cached_tokens": 500,
    }
    assert _get_metrics_from_response(llm_result) == expected_metrics


def test_separately_reported_cache_tokens_are_folded_into_prompt_tokens():
    """Cache tokens exceeding ``input_tokens`` are added to the prompt and total counts."""
    # Integrations that report uncached input only make cache tokens exceed
    # the prompt total; normalize so prompt/total include cache tokens.
    usage = {
        "input_tokens": 20,
        "output_tokens": 40,
        "total_tokens": 60,
        "input_token_details": {"cache_read": 1000, "cache_creation": 500},
    }
    llm_result = _single_generation_response(usage, model_name="claude-3-5-sonnet-20240620")

    # 20 uncached + 1000 cache read + 500 cache creation = 1520 prompt tokens.
    expected_metrics = {
        "prompt_tokens": 1520,
        "completion_tokens": 40,
        "total_tokens": 1560,
        "tokens": 1560,
        "prompt_cached_tokens": 1000,
        "prompt_cache_creation_tokens": 500,
    }
    assert _get_metrics_from_response(llm_result) == expected_metrics


@pytest.mark.vcr
def test_prompt_caching_tokens(logger_memory_logger):
from langchain_anthropic import ChatAnthropic
Expand Down Expand Up @@ -1098,6 +1188,12 @@ def test_prompt_caching_tokens(logger_memory_logger):
assert first_metrics["prompt_tokens"] >= first_cache_creation_tokens
assert first_metrics["tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]

# langchain-anthropic already folds cache read/creation tokens into
# usage_metadata input_tokens; the callback must not add them again.
assert res.usage_metadata is not None
assert first_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"]
assert first_metrics["total_tokens"] == res.usage_metadata["total_tokens"]

second_metrics = None
for attempt in range(3):
res = model.invoke(
Expand Down Expand Up @@ -1134,6 +1230,10 @@ def test_prompt_caching_tokens(logger_memory_logger):
assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"]
assert second_metrics["tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]

assert res.usage_metadata is not None
assert second_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"]
assert second_metrics["total_tokens"] == res.usage_metadata["total_tokens"]


@pytest.mark.vcr
def test_image_input(logger_memory_logger):
Expand Down