Add token usage tracking to generate_schema / agenerate_schema

generate_schema can make up to 5 internal LLM calls (field inference,
schema generation, validation retries) with no way to track token
consumption. Add an optional `usage: TokenUsage = None` parameter that
accumulates prompt/completion/total tokens across all calls in-place.

- _infer_target_json: accept and populate usage accumulator
- agenerate_schema: track usage after every aperform_completion call
  in the retry loop, forward usage to _infer_target_json
- generate_schema (sync): forward usage to agenerate_schema

Fully backward-compatible — omitting usage changes nothing.
This commit is contained in:
unclecode
2026-02-18 06:44:17 +00:00
parent 8576331d4e
commit c9cb0160cf
5 changed files with 726 additions and 3 deletions

View File

@@ -1599,9 +1599,13 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
return "\n".join(parts)
@staticmethod
async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None) -> Optional[dict]:
async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None, usage: 'TokenUsage' = None) -> Optional[dict]:
"""Infer a target JSON example from a query and HTML snippet via a quick LLM call.
Args:
usage: Optional TokenUsage accumulator. If provided, token counts from
this LLM call are added to it in-place.
Returns the parsed dict, or None if inference fails.
"""
from .utils import aperform_completion_with_backoff
@@ -1633,6 +1637,10 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
api_token=llm_config.api_token,
base_url=llm_config.base_url,
)
if usage is not None:
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
raw = response.choices[0].message.content
if not raw or not raw.strip():
return None
@@ -1726,6 +1734,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
url: Union[str, List[str]] = None,
validate: bool = True,
max_refinements: int = 3,
usage: 'TokenUsage' = None,
**kwargs
) -> dict:
"""
@@ -1744,6 +1753,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa
validate (bool): If True, validate the schema against the HTML and
refine via LLM feedback loop. Defaults to False (zero overhead).
max_refinements (int): Max refinement rounds when validate=True. Defaults to 3.
usage (TokenUsage, optional): Token usage accumulator. If provided,
token counts from all LLM calls (including inference and
validation retries) are added to it in-place.
**kwargs: Additional args passed to LLM processor.
Returns:
@@ -1770,6 +1782,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
url=url,
validate=validate,
max_refinements=max_refinements,
usage=usage,
**kwargs
)
@@ -1793,6 +1806,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
url: Union[str, List[str]] = None,
validate: bool = True,
max_refinements: int = 3,
usage: 'TokenUsage' = None,
**kwargs
) -> dict:
"""
@@ -1815,6 +1829,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa
validate (bool): If True, validate the schema against the HTML and
refine via LLM feedback loop. Defaults to False (zero overhead).
max_refinements (int): Max refinement rounds when validate=True. Defaults to 3.
usage (TokenUsage, optional): Token usage accumulator. If provided,
token counts from all LLM calls (including inference and
validation retries) are added to it in-place.
**kwargs: Additional args passed to LLM processor.
Returns:
@@ -1913,7 +1930,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
if url is not None:
first_url = url if isinstance(url, str) else url[0]
inferred = await JsonElementExtractionStrategy._infer_target_json(
query=query, html_snippet=html, llm_config=llm_config, url=first_url
query=query, html_snippet=html, llm_config=llm_config, url=first_url, usage=usage
)
if inferred:
expected_fields = JsonElementExtractionStrategy._extract_expected_fields(inferred)
@@ -1939,6 +1956,10 @@ In this scenario, use your best judgment to generate the schema. You need to exa
messages=messages,
extra_args=kwargs,
)
if usage is not None:
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
raw = response.choices[0].message.content
if not raw or not raw.strip():
raise ValueError("LLM returned an empty response")

View File

@@ -4537,6 +4537,20 @@ xpath_schema = JsonXPathExtractionStrategy.generate_schema(
# Use the generated schema for fast, repeated extractions
strategy = JsonCssExtractionStrategy(css_schema)
```
### Token Usage Tracking
`generate_schema` may make multiple LLM calls internally (field inference, generation, validation retries). Track total token consumption by passing a `TokenUsage` accumulator:
```python
from crawl4ai.models import TokenUsage
usage = TokenUsage()
schema = JsonCssExtractionStrategy.generate_schema(
url="https://example.com/products",
query="extract product name and price",
usage=usage,
)
print(f"Total tokens: {usage.total_tokens}")
```
The `usage` parameter is optional and fully backward-compatible. Both `generate_schema` (sync) and `agenerate_schema` (async) support it.
### LLM Provider Options
1. **OpenAI GPT-4 (`openai/gpt4o`)**
- Default provider

View File

@@ -204,7 +204,9 @@ llm_strategy.show_usage()
# e.g. “Total usage: 1241 tokens across 2 chunk calls”
```
If your model provider doesnt return usage info, these fields might be partial or empty.
If your model provider doesn't return usage info, these fields might be partial or empty.
> **Tip:** `JsonCssExtractionStrategy.generate_schema()` also supports token usage tracking via an optional `usage` parameter. See [Token Usage Tracking in Schema Generation](./no-llm-strategies.md#token-usage-tracking) for details.
---

View File

@@ -761,6 +761,38 @@ schema = JsonCssExtractionStrategy.generate_schema(
The generator also understands sibling layouts — for sites like Hacker News where data is split across sibling elements, it will automatically use the [`source` field](#sibling-data) to reach sibling data.
### Token Usage Tracking
`generate_schema` may make multiple LLM calls internally (field inference, schema generation, validation retries). To track the total token consumption across all of these calls, pass a `TokenUsage` accumulator:
```python
from crawl4ai import JsonCssExtractionStrategy
from crawl4ai.models import TokenUsage
usage = TokenUsage()
schema = JsonCssExtractionStrategy.generate_schema(
url="https://news.ycombinator.com",
query="Extract each story: title, url, score, author",
usage=usage,
)
print(f"Prompt tokens: {usage.prompt_tokens}")
print(f"Completion tokens: {usage.completion_tokens}")
print(f"Total tokens: {usage.total_tokens}")
```
The `usage` parameter is optional — omitting it changes nothing (fully backward-compatible). You can also reuse the same accumulator across multiple calls to get a grand total:
```python
usage = TokenUsage()
schema1 = JsonCssExtractionStrategy.generate_schema(url=url1, query=q1, usage=usage)
schema2 = JsonCssExtractionStrategy.generate_schema(url=url2, query=q2, usage=usage)
print(f"Grand total: {usage.total_tokens} tokens")
```
Both `generate_schema` (sync) and `agenerate_schema` (async) support the `usage` parameter.
### LLM Provider Options
1. **OpenAI GPT-4 (`openai/gpt4o`)**

View File

@@ -0,0 +1,654 @@
"""Tests for TokenUsage accumulation in generate_schema / agenerate_schema.
Covers:
- Backward compatibility (usage=None, the default)
- Single-shot schema generation accumulates usage
- Validation retry loop accumulates across all LLM calls
- _infer_target_json accumulates its own LLM call
- Sync wrapper forwards usage correctly
- JSON parse failure retry also accumulates usage
- usage object receives correct cumulative totals
"""
import asyncio
import json
import pytest
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch, MagicMock
from crawl4ai.extraction_strategy import JsonElementExtractionStrategy, JsonCssExtractionStrategy
from crawl4ai.models import TokenUsage
# The functions are imported lazily inside method bodies via `from .utils import ...`
# so we must patch at the source module.
PATCH_TARGET = "crawl4ai.utils.aperform_completion_with_backoff"
# ---------------------------------------------------------------------------
# Helpers: fake LLM response builder
# ---------------------------------------------------------------------------
def _make_llm_response(content: str, prompt_tokens: int = 100, completion_tokens: int = 50):
"""Build a fake litellm-style response with .usage and .choices."""
return SimpleNamespace(
usage=SimpleNamespace(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
completion_tokens_details=None,
prompt_tokens_details=None,
),
choices=[
SimpleNamespace(
message=SimpleNamespace(content=content)
)
],
)
# A valid CSS schema that will pass validation against SAMPLE_HTML
VALID_SCHEMA = {
"name": "products",
"baseSelector": ".product",
"fields": [
{"name": "title", "selector": ".title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"},
],
}
SAMPLE_HTML = """
<div class="products">
<div class="product">
<span class="title">Widget</span>
<span class="price">$10</span>
</div>
<div class="product">
<span class="title">Gadget</span>
<span class="price">$20</span>
</div>
</div>
"""
# A schema with a bad baseSelector — will fail validation and trigger retry
BAD_SCHEMA = {
"name": "products",
"baseSelector": ".nonexistent-selector",
"fields": [
{"name": "title", "selector": ".title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"},
],
}
# Fake LLMConfig
@dataclass
class FakeLLMConfig:
provider: str = "fake/model"
api_token: str = "fake-token"
base_url: str = None
backoff_base_delay: float = 0
backoff_max_attempts: int = 1
backoff_exponential_factor: int = 2
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestGenerateSchemaUsage:
"""Test suite for usage tracking in generate_schema / agenerate_schema."""
@pytest.mark.asyncio
async def test_backward_compat_usage_none(self):
"""When usage is not passed (default None), everything works as before."""
mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
)
assert isinstance(result, dict)
assert result["name"] == "products"
@pytest.mark.asyncio
async def test_single_shot_no_validate(self):
"""Single LLM call with validate=False populates usage correctly."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert result["name"] == "products"
assert usage.prompt_tokens == 200
assert usage.completion_tokens == 80
assert usage.total_tokens == 280
@pytest.mark.asyncio
async def test_validation_success_first_try(self):
"""With validate=True and schema passes validation on first try, usage reflects 1 call."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=300, completion_tokens=120
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
# Provide target_json_example to skip _infer_target_json
target_json_example='{"title": "x", "price": "y"}',
)
assert result["name"] == "products"
# Only 1 LLM call since validation passed
assert usage.prompt_tokens == 300
assert usage.completion_tokens == 120
assert usage.total_tokens == 420
@pytest.mark.asyncio
async def test_validation_retries_accumulate_usage(self):
"""When validation fails, retry calls accumulate into the same usage object."""
usage = TokenUsage()
# First call returns bad schema (fails validation), second returns good schema
responses = [
_make_llm_response(json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100),
_make_llm_response(json.dumps(VALID_SCHEMA), prompt_tokens=350, completion_tokens=120),
]
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
idx = min(call_count, len(responses) - 1)
call_count += 1
return responses[idx]
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
assert result["name"] == "products"
# Two LLM calls: 300+350=650 prompt, 100+120=220 completion
assert usage.prompt_tokens == 650
assert usage.completion_tokens == 220
assert usage.total_tokens == 870
@pytest.mark.asyncio
async def test_infer_target_json_accumulates_usage(self):
"""When validate=True and no target_json_example, _infer_target_json makes an extra LLM call."""
usage = TokenUsage()
infer_response = _make_llm_response(
'{"title": "Widget", "price": "$10"}',
prompt_tokens=50,
completion_tokens=30,
)
schema_response = _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=300,
completion_tokens=120,
)
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
# First call is _infer_target_json, second is schema generation
if call_count == 1:
return infer_response
return schema_response
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
query="extract product title and price",
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
# No target_json_example — triggers _infer_target_json
)
assert result["name"] == "products"
# _infer_target_json: 50+30 = 80
# schema generation: 300+120 = 420
# Total: 350 prompt, 150 completion, 500 total
assert usage.prompt_tokens == 350
assert usage.completion_tokens == 150
assert usage.total_tokens == 500
@pytest.mark.asyncio
async def test_infer_plus_retries_accumulate(self):
"""Full pipeline: infer + bad schema + good schema = 3 calls accumulated."""
usage = TokenUsage()
infer_resp = _make_llm_response(
'{"title": "x", "price": "y"}',
prompt_tokens=50, completion_tokens=20,
)
bad_resp = _make_llm_response(
json.dumps(BAD_SCHEMA),
prompt_tokens=300, completion_tokens=100,
)
good_resp = _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=400, completion_tokens=150,
)
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return infer_resp
elif call_count == 2:
return bad_resp
else:
return good_resp
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
query="extract products",
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
)
# 3 calls total
assert call_count == 3
assert usage.prompt_tokens == 750 # 50 + 300 + 400
assert usage.completion_tokens == 270 # 20 + 100 + 150
assert usage.total_tokens == 1020 # 70 + 400 + 550
@pytest.mark.asyncio
async def test_json_parse_failure_retry_accumulates(self):
"""When LLM returns invalid JSON, the retry also accumulates usage."""
usage = TokenUsage()
# First response is not valid JSON, second is valid
bad_json_resp = _make_llm_response(
"this is not json {{{",
prompt_tokens=200, completion_tokens=60,
)
good_resp = _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=250, completion_tokens=80,
)
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return bad_json_resp
return good_resp
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
assert result["name"] == "products"
# Both calls tracked: even the one that returned bad JSON
assert usage.prompt_tokens == 450 # 200 + 250
assert usage.completion_tokens == 140 # 60 + 80
assert usage.total_tokens == 590
@pytest.mark.asyncio
async def test_usage_none_does_not_crash(self):
"""Explicitly passing usage=None should not raise any errors."""
mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=None,
)
assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_preexisting_usage_values_are_added_to(self):
"""If usage already has values, new tokens are ADDED, not replaced."""
usage = TokenUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert usage.prompt_tokens == 1200 # 1000 + 200
assert usage.completion_tokens == 580 # 500 + 80
assert usage.total_tokens == 1780 # 1500 + 280
def test_sync_wrapper_passes_usage(self):
"""The sync generate_schema forwards usage to agenerate_schema."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = JsonElementExtractionStrategy.generate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert result["name"] == "products"
assert usage.prompt_tokens == 200
assert usage.completion_tokens == 80
assert usage.total_tokens == 280
def test_sync_wrapper_usage_none_backward_compat(self):
"""Sync wrapper with no usage arg (default) still works."""
mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = JsonElementExtractionStrategy.generate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
)
assert isinstance(result, dict)
assert result["name"] == "products"
@pytest.mark.asyncio
async def test_max_refinements_zero_single_call(self):
"""max_refinements=0 with validate=True means exactly 1 attempt, 1 usage entry."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=0,
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
# Even though validation fails, only 1 attempt (0 refinements)
assert usage.prompt_tokens == 300
assert usage.completion_tokens == 100
assert usage.total_tokens == 400
@pytest.mark.asyncio
async def test_css_subclass_inherits_usage(self):
"""JsonCssExtractionStrategy.agenerate_schema also supports usage."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=150, completion_tokens=60
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonCssExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert result["name"] == "products"
assert usage.total_tokens == 210
@pytest.mark.asyncio
async def test_infer_target_json_failure_still_tracks_nothing(self):
"""If _infer_target_json raises (and catches), usage should not break.
When the inference LLM call itself throws an exception before we get
response.usage, no tokens should be added (graceful degradation).
"""
usage = TokenUsage()
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# _infer_target_json call — simulate exception
raise ConnectionError("LLM is down")
# Schema generation call
return _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=300,
completion_tokens=100,
)
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
query="extract products",
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=0,
usage=usage,
)
# Only the schema call counted; infer call failed before tracking
assert usage.prompt_tokens == 300
assert usage.completion_tokens == 100
assert usage.total_tokens == 400
@pytest.mark.asyncio
async def test_multiple_bad_retries_then_best_effort(self):
"""All retries fail validation, usage still accumulates for every attempt."""
usage = TokenUsage()
# Every call returns bad schema — validation will always fail
mock_response = _make_llm_response(
json.dumps(BAD_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=2, # 1 initial + 2 retries = 3 calls
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
# Returns best-effort (last schema), but all 3 calls tracked
assert usage.prompt_tokens == 600 # 200 * 3
assert usage.completion_tokens == 240 # 80 * 3
assert usage.total_tokens == 840 # 280 * 3
class TestInferTargetJsonUsage:
"""Isolated tests for _infer_target_json usage tracking."""
@pytest.mark.asyncio
async def test_infer_tracks_usage(self):
"""Direct call to _infer_target_json with usage accumulator."""
usage = TokenUsage()
mock_response = _make_llm_response(
'{"name": "test", "value": "123"}',
prompt_tokens=80,
completion_tokens=25,
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names and values",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
usage=usage,
)
assert result == {"name": "test", "value": "123"}
assert usage.prompt_tokens == 80
assert usage.completion_tokens == 25
assert usage.total_tokens == 105
@pytest.mark.asyncio
async def test_infer_usage_none_backward_compat(self):
"""_infer_target_json with usage=None (default) still works."""
mock_response = _make_llm_response('{"name": "test"}')
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
)
assert result == {"name": "test"}
@pytest.mark.asyncio
async def test_infer_exception_no_usage_side_effect(self):
"""When _infer_target_json fails, usage is untouched (exception before tracking)."""
usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
side_effect=RuntimeError("API down"),
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
usage=usage,
)
# Returns None on failure
assert result is None
# Usage unchanged — exception happened before tracking
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
@pytest.mark.asyncio
async def test_infer_empty_response_still_tracks(self):
"""When LLM returns empty content, usage is still tracked (response was received)."""
usage = TokenUsage()
mock_response = _make_llm_response("", prompt_tokens=80, completion_tokens=5)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
usage=usage,
)
# Returns None because content is empty
assert result is None
# But usage was tracked because we got a response
assert usage.prompt_tokens == 80
assert usage.completion_tokens == 5
assert usage.total_tokens == 85