From 63f99a51cd3d95ee24424f992aaddf4d9516d944 Mon Sep 17 00:00:00 2001 From: Bhavesh Pareek Date: Tue, 9 Jun 2026 18:16:08 -0400 Subject: [PATCH] fix(langchain): stop double-counting anthropic cache tokens in prompt totals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit langchain-anthropic has folded cache read/creation tokens into usage_metadata input_tokens since 0.2.3 (versions before that don't emit input_token_details at all), and langchain-aws does the same — per the langchain-core UsageMetadata contract, input_token_details is a breakdown of input_tokens, not an addition to it. The cache normalization from #411/#445 detected "separate cache token accounting" by the presence of cache_creation/ephemeral_* detail keys, which langchain-anthropic always emits, so every cached Anthropic call had cache tokens added to prompt_tokens a second time. With a warm cache this roughly doubles reported prompt tokens (e.g. a real trace reported 75,387 prompt tokens for a 37,694-token request with 37,324 cache reads and 369 cache writes). Detect separate accounting arithmetically instead: only fold cache tokens into prompt/total when they exceed the reported prompt total, which is impossible under the UsageMetadata contract but is exactly the inconsistency the original normalization (BT-5150) was added to repair. Strengthen the VCR prompt-caching test to assert span prompt/total tokens equal the usage_metadata the model reported, and add unit coverage for the folded (Anthropic), subset (OpenAI), and separate (legacy) conventions. Co-Authored-By: Claude Fable 5 --- .../integrations/langchain/callbacks.py | 33 ++---- .../integrations/langchain/test_callbacks.py | 100 ++++++++++++++++++ 2 files changed, 108 insertions(+), 25 deletions(-) diff --git a/py/src/braintrust/integrations/langchain/callbacks.py b/py/src/braintrust/integrations/langchain/callbacks.py index 96646604..2cb9b37d 100644 --- a/py/src/braintrust/integrations/langchain/callbacks.py +++ b/py/src/braintrust/integrations/langchain/callbacks.py @@ -617,24 +617,6 @@ def _get_model_name_from_response(response: LLMResult) -> str | None: return model_name -def _cache_tokens_are_separate_from_input_tokens(input_token_details: dict[str, Any]) -> bool: - # LangChain provider packages use different cache-token conventions: - # - OpenAI-style responses report cache reads as a subset of input_tokens. - # - Anthropic-style responses report cache reads/creation separately from input_tokens. - # - # Avoid provider-name checks here so any LangChain integration using the same - # "separate cache tokens" schema gets normalized, while providers that only - # expose cache_read as input-token detail do not get double-counted. - return any( - key in input_token_details - for key in ( - "cache_creation", - "ephemeral_5m_input_tokens", - "ephemeral_1h_input_tokens", - ) - ) - - def _get_metrics_from_response(response: LLMResult): metrics = {} @@ -685,15 +667,16 @@ def _get_metrics_from_response(response: LLMResult): completion_tokens = metrics.get("completion_tokens") total_tokens = metrics.get("total_tokens") if prompt_tokens is not None and completion_tokens is not None: - if ( - cache_tokens - and total_tokens == prompt_tokens + completion_tokens - and _cache_tokens_are_separate_from_input_tokens(input_token_details) - ): + # LangChain's UsageMetadata contract makes input_token_details a + # breakdown of input_tokens, so cache tokens already count toward + # the prompt total (langchain-anthropic >= 0.2.3, langchain-aws, + # langchain-openai all comply). Cache tokens exceeding the prompt + # total means the integration reported uncached input only — fold + # cache tokens back in so prompt/total stay internally consistent. + if cache_tokens > prompt_tokens and total_tokens == prompt_tokens + completion_tokens: prompt_tokens += cache_tokens metrics["prompt_tokens"] = prompt_tokens - if total_tokens is not None: - metrics["total_tokens"] = total_tokens + cache_tokens + metrics["total_tokens"] = total_tokens + cache_tokens metrics["tokens"] = prompt_tokens + completion_tokens if not metrics or not any(metrics.values()): diff --git a/py/src/braintrust/integrations/langchain/test_callbacks.py b/py/src/braintrust/integrations/langchain/test_callbacks.py index e05a5775..4cb1f23d 100644 --- a/py/src/braintrust/integrations/langchain/test_callbacks.py +++ b/py/src/braintrust/integrations/langchain/test_callbacks.py @@ -8,10 +8,12 @@ import pytest from braintrust import logger from braintrust.integrations.langchain import BraintrustCallbackHandler +from braintrust.integrations.langchain.callbacks import _get_metrics_from_response from braintrust.logger import flush from braintrust.test_helpers import init_test_logger from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables import RunnableMap, RunnableSerializable @@ -906,6 +908,94 @@ def test_streaming_ttft(logger_memory_logger): ) +def _single_generation_response(usage_metadata: dict, model_name: str) -> LLMResult: + return LLMResult( + generations=[ + [ + ChatGeneration( + message=AIMessage( + content="Done", + response_metadata={"model_name": model_name}, + usage_metadata=cast(dict, usage_metadata), + ) + ) + ] + ] + ) + + +def test_folded_cache_tokens_are_not_double_counted(): + # langchain-anthropic >= 0.2.3 folds cache read/creation tokens into + # input_tokens, exposing them via input_token_details as a breakdown. + response = _single_generation_response( + { + "input_tokens": 1095, + "output_tokens": 40, + "total_tokens": 1135, + "input_token_details": { + "cache_read": 0, + "cache_creation": 0, + "ephemeral_5m_input_tokens": 1075, + "ephemeral_1h_input_tokens": 0, + }, + }, + model_name="claude-sonnet-4-5-20250929", + ) + + assert _get_metrics_from_response(response) == { + "prompt_tokens": 1095, + "completion_tokens": 40, + "total_tokens": 1135, + "tokens": 1135, + "prompt_cached_tokens": 0, + "prompt_cache_creation_5m_tokens": 1075, + "prompt_cache_creation_1h_tokens": 0, + } + + +def test_openai_cached_tokens_are_not_folded_into_prompt_tokens(): + response = _single_generation_response( + { + "input_tokens": 1000, + "output_tokens": 200, + "total_tokens": 1200, + "input_token_details": {"cache_read": 500}, + }, + model_name="gpt-4o-mini-2024-07-18", + ) + + assert _get_metrics_from_response(response) == { + "prompt_tokens": 1000, + "completion_tokens": 200, + "total_tokens": 1200, + "tokens": 1200, + "prompt_cached_tokens": 500, + } + + +def test_separately_reported_cache_tokens_are_folded_into_prompt_tokens(): + # Integrations that report uncached input only make cache tokens exceed + # the prompt total; normalize so prompt/total include cache tokens. + response = _single_generation_response( + { + "input_tokens": 20, + "output_tokens": 40, + "total_tokens": 60, + "input_token_details": {"cache_read": 1000, "cache_creation": 500}, + }, + model_name="claude-3-5-sonnet-20240620", + ) + + assert _get_metrics_from_response(response) == { + "prompt_tokens": 1520, + "completion_tokens": 40, + "total_tokens": 1560, + "tokens": 1560, + "prompt_cached_tokens": 1000, + "prompt_cache_creation_tokens": 500, + } + + @pytest.mark.vcr def test_prompt_caching_tokens(logger_memory_logger): from langchain_anthropic import ChatAnthropic @@ -1098,6 +1188,12 @@ def test_prompt_caching_tokens(logger_memory_logger): assert first_metrics["prompt_tokens"] >= first_cache_creation_tokens assert first_metrics["tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"] + # langchain-anthropic already folds cache read/creation tokens into + # usage_metadata input_tokens; the callback must not add them again. + assert res.usage_metadata is not None + assert first_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"] + assert first_metrics["total_tokens"] == res.usage_metadata["total_tokens"] + second_metrics = None for attempt in range(3): res = model.invoke( @@ -1134,6 +1230,10 @@ def test_prompt_caching_tokens(logger_memory_logger): assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"] assert second_metrics["tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"] + assert res.usage_metadata is not None + assert second_metrics["prompt_tokens"] == res.usage_metadata["input_tokens"] + assert second_metrics["total_tokens"] == res.usage_metadata["total_tokens"] + @pytest.mark.vcr def test_image_input(logger_memory_logger):