Add token usage tracking to generate_schema / agenerate_schema

generate_schema can make up to 5 internal LLM calls (field inference, schema generation, validation retries) with no way to track token consumption. Add an optional `usage: TokenUsage = None` parameter that accumulates prompt/completion/total tokens across all calls in-place.

- _infer_target_json: accept and populate the usage accumulator
- agenerate_schema: track usage after every aperform_completion call in the retry loop; forward usage to _infer_target_json
- generate_schema (sync): forward usage to agenerate_schema

Fully backward-compatible — omitting usage changes nothing.
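The accumulation contract, as a minimal sketch (the real `crawl4ai.models.TokenUsage` may carry additional fields; `track_usage` below is a hypothetical helper illustrating the in-place update each call site performs):

```python
from dataclasses import dataclass


@dataclass
class TokenUsage:
    # Stand-in with the same counters as crawl4ai.models.TokenUsage
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


def track_usage(usage, response):
    """Hypothetical helper: fold one LLM response's counts into the accumulator."""
    if usage is not None:  # omitting the accumulator stays a no-op
        usage.prompt_tokens += response.usage.prompt_tokens
        usage.completion_tokens += response.usage.completion_tokens
        usage.total_tokens += response.usage.total_tokens
```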
@@ -1599,9 +1599,13 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
         return "\n".join(parts)
 
     @staticmethod
-    async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None) -> Optional[dict]:
+    async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None, usage: 'TokenUsage' = None) -> Optional[dict]:
         """Infer a target JSON example from a query and HTML snippet via a quick LLM call.
 
+        Args:
+            usage: Optional TokenUsage accumulator. If provided, token counts from
+                this LLM call are added to it in-place.
+
         Returns the parsed dict, or None if inference fails.
         """
         from .utils import aperform_completion_with_backoff
@@ -1633,6 +1637,10 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
             api_token=llm_config.api_token,
             base_url=llm_config.base_url,
         )
+        if usage is not None:
+            usage.completion_tokens += response.usage.completion_tokens
+            usage.prompt_tokens += response.usage.prompt_tokens
+            usage.total_tokens += response.usage.total_tokens
         raw = response.choices[0].message.content
         if not raw or not raw.strip():
             return None
@@ -1726,6 +1734,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
         url: Union[str, List[str]] = None,
         validate: bool = True,
         max_refinements: int = 3,
+        usage: 'TokenUsage' = None,
         **kwargs
     ) -> dict:
         """
@@ -1744,6 +1753,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa
             validate (bool): If True, validate the schema against the HTML and
                 refine via LLM feedback loop. Defaults to False (zero overhead).
             max_refinements (int): Max refinement rounds when validate=True. Defaults to 3.
+            usage (TokenUsage, optional): Token usage accumulator. If provided,
+                token counts from all LLM calls (including inference and
+                validation retries) are added to it in-place.
             **kwargs: Additional args passed to LLM processor.
 
         Returns:
@@ -1770,6 +1782,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
             url=url,
             validate=validate,
             max_refinements=max_refinements,
+            usage=usage,
             **kwargs
         )
 
@@ -1793,6 +1806,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
         url: Union[str, List[str]] = None,
         validate: bool = True,
         max_refinements: int = 3,
+        usage: 'TokenUsage' = None,
         **kwargs
     ) -> dict:
         """
@@ -1815,6 +1829,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa
             validate (bool): If True, validate the schema against the HTML and
                 refine via LLM feedback loop. Defaults to False (zero overhead).
             max_refinements (int): Max refinement rounds when validate=True. Defaults to 3.
+            usage (TokenUsage, optional): Token usage accumulator. If provided,
+                token counts from all LLM calls (including inference and
+                validation retries) are added to it in-place.
             **kwargs: Additional args passed to LLM processor.
 
         Returns:
@@ -1913,7 +1930,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
         if url is not None:
             first_url = url if isinstance(url, str) else url[0]
             inferred = await JsonElementExtractionStrategy._infer_target_json(
-                query=query, html_snippet=html, llm_config=llm_config, url=first_url
+                query=query, html_snippet=html, llm_config=llm_config, url=first_url, usage=usage
             )
             if inferred:
                 expected_fields = JsonElementExtractionStrategy._extract_expected_fields(inferred)
@@ -1939,6 +1956,10 @@ In this scenario, use your best judgment to generate the schema. You need to exa
                 messages=messages,
                 extra_args=kwargs,
             )
+            if usage is not None:
+                usage.completion_tokens += response.usage.completion_tokens
+                usage.prompt_tokens += response.usage.prompt_tokens
+                usage.total_tokens += response.usage.total_tokens
             raw = response.choices[0].message.content
             if not raw or not raw.strip():
                 raise ValueError("LLM returned an empty response")
 
@@ -4537,6 +4537,20 @@ xpath_schema = JsonXPathExtractionStrategy.generate_schema(
 # Use the generated schema for fast, repeated extractions
 strategy = JsonCssExtractionStrategy(css_schema)
 ```
 
+### Token Usage Tracking
+
+`generate_schema` may make multiple LLM calls internally (field inference, generation, validation retries). Track total token consumption by passing a `TokenUsage` accumulator:
+
+```python
+from crawl4ai.models import TokenUsage
+
+usage = TokenUsage()
+schema = JsonCssExtractionStrategy.generate_schema(
+    url="https://example.com/products",
+    query="extract product name and price",
+    usage=usage,
+)
+print(f"Total tokens: {usage.total_tokens}")
+```
+
+The `usage` parameter is optional and fully backward-compatible. Both `generate_schema` (sync) and `agenerate_schema` (async) support it.
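The async variant takes the same accumulator. A minimal sketch of calling `agenerate_schema` directly (assumes an event loop via `asyncio.run`; the URL and query are illustrative):

```python
import asyncio

from crawl4ai import JsonCssExtractionStrategy
from crawl4ai.models import TokenUsage


async def main():
    usage = TokenUsage()
    schema = await JsonCssExtractionStrategy.agenerate_schema(
        url="https://example.com/products",  # illustrative URL
        query="extract product name and price",
        usage=usage,  # accumulates across every internal LLM call
    )
    print(f"Total tokens: {usage.total_tokens}")


asyncio.run(main())
```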
 
 ### LLM Provider Options
 
 1. **OpenAI GPT-4 (`openai/gpt4o`)**
    - Default provider
@@ -204,7 +204,9 @@ llm_strategy.show_usage()
 # e.g. “Total usage: 1241 tokens across 2 chunk calls”
 ```
 
-If your model provider doesn’t return usage info, these fields might be partial or empty.
+If your model provider doesn't return usage info, these fields might be partial or empty.
+
+> **Tip:** `JsonCssExtractionStrategy.generate_schema()` also supports token usage tracking via an optional `usage` parameter. See [Token Usage Tracking in Schema Generation](./no-llm-strategies.md#token-usage-tracking) for details.
 
 ---
@@ -761,6 +761,38 @@ schema = JsonCssExtractionStrategy.generate_schema(
 
 The generator also understands sibling layouts — for sites like Hacker News where data is split across sibling elements, it will automatically use the [`source` field](#sibling-data) to reach sibling data.
 
+### Token Usage Tracking
+
+`generate_schema` may make multiple LLM calls internally (field inference, schema generation, validation retries). To track the total token consumption across all of these calls, pass a `TokenUsage` accumulator:
+
+```python
+from crawl4ai import JsonCssExtractionStrategy
+from crawl4ai.models import TokenUsage
+
+usage = TokenUsage()
+
+schema = JsonCssExtractionStrategy.generate_schema(
+    url="https://news.ycombinator.com",
+    query="Extract each story: title, url, score, author",
+    usage=usage,
+)
+
+print(f"Prompt tokens: {usage.prompt_tokens}")
+print(f"Completion tokens: {usage.completion_tokens}")
+print(f"Total tokens: {usage.total_tokens}")
+```
+
+The `usage` parameter is optional — omitting it changes nothing (fully backward-compatible). You can also reuse the same accumulator across multiple calls to get a grand total:
+
+```python
+usage = TokenUsage()
+schema1 = JsonCssExtractionStrategy.generate_schema(url=url1, query=q1, usage=usage)
+schema2 = JsonCssExtractionStrategy.generate_schema(url=url2, query=q2, usage=usage)
+print(f"Grand total: {usage.total_tokens} tokens")
+```
+
+Both `generate_schema` (sync) and `agenerate_schema` (async) support the `usage` parameter.
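Because the accumulator is updated with plain in-place addition on a single event loop, one `TokenUsage` can also be shared across concurrent `agenerate_schema` calls. A minimal sketch, assuming the illustrative URLs and queries below and a default `llm_config`:

```python
import asyncio

from crawl4ai import JsonCssExtractionStrategy
from crawl4ai.models import TokenUsage


async def main():
    # One shared accumulator; each call adds its own tokens in-place.
    usage = TokenUsage()
    schemas = await asyncio.gather(
        JsonCssExtractionStrategy.agenerate_schema(
            url="https://news.ycombinator.com",  # illustrative URL
            query="Extract each story: title, url, score, author",
            usage=usage,
        ),
        JsonCssExtractionStrategy.agenerate_schema(
            url="https://example.com/products",  # illustrative URL
            query="extract product name and price",
            usage=usage,
        ),
    )
    print(f"{len(schemas)} schemas, grand total: {usage.total_tokens} tokens")


asyncio.run(main())
```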
 
 ### LLM Provider Options
 
 1. **OpenAI GPT-4 (`openai/gpt4o`)**
tests/general/test_generate_schema_usage.py (new file, 654 lines)
@@ -0,0 +1,654 @@
"""Tests for TokenUsage accumulation in generate_schema / agenerate_schema.
|
||||
|
||||
Covers:
|
||||
- Backward compatibility (usage=None, the default)
|
||||
- Single-shot schema generation accumulates usage
|
||||
- Validation retry loop accumulates across all LLM calls
|
||||
- _infer_target_json accumulates its own LLM call
|
||||
- Sync wrapper forwards usage correctly
|
||||
- JSON parse failure retry also accumulates usage
|
||||
- usage object receives correct cumulative totals
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from crawl4ai.extraction_strategy import JsonElementExtractionStrategy, JsonCssExtractionStrategy
|
||||
from crawl4ai.models import TokenUsage
|
||||
|
||||
# The functions are imported lazily inside method bodies via `from .utils import ...`
|
||||
# so we must patch at the source module.
|
||||
PATCH_TARGET = "crawl4ai.utils.aperform_completion_with_backoff"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Helpers: fake LLM response builder
# ---------------------------------------------------------------------------

def _make_llm_response(content: str, prompt_tokens: int = 100, completion_tokens: int = 50):
    """Build a fake litellm-style response with .usage and .choices."""
    return SimpleNamespace(
        usage=SimpleNamespace(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            completion_tokens_details=None,
            prompt_tokens_details=None,
        ),
        choices=[
            SimpleNamespace(
                message=SimpleNamespace(content=content)
            )
        ],
    )


# A valid CSS schema that will pass validation against SAMPLE_HTML
VALID_SCHEMA = {
    "name": "products",
    "baseSelector": ".product",
    "fields": [
        {"name": "title", "selector": ".title", "type": "text"},
        {"name": "price", "selector": ".price", "type": "text"},
    ],
}

SAMPLE_HTML = """
<div class="products">
  <div class="product">
    <span class="title">Widget</span>
    <span class="price">$10</span>
  </div>
  <div class="product">
    <span class="title">Gadget</span>
    <span class="price">$20</span>
  </div>
</div>
"""

# A schema with a bad baseSelector — will fail validation and trigger retry
BAD_SCHEMA = {
    "name": "products",
    "baseSelector": ".nonexistent-selector",
    "fields": [
        {"name": "title", "selector": ".title", "type": "text"},
        {"name": "price", "selector": ".price", "type": "text"},
    ],
}

# Fake LLMConfig
@dataclass
class FakeLLMConfig:
    provider: str = "fake/model"
    api_token: str = "fake-token"
    base_url: str = None
    backoff_base_delay: float = 0
    backoff_max_attempts: int = 1
    backoff_exponential_factor: int = 2


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestGenerateSchemaUsage:
    """Test suite for usage tracking in generate_schema / agenerate_schema."""

    @pytest.mark.asyncio
    async def test_backward_compat_usage_none(self):
        """When usage is not passed (default None), everything works as before."""
        mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
            )

        assert isinstance(result, dict)
        assert result["name"] == "products"

    @pytest.mark.asyncio
    async def test_single_shot_no_validate(self):
        """Single LLM call with validate=False populates usage correctly."""
        usage = TokenUsage()
        mock_response = _make_llm_response(
            json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
                usage=usage,
            )

        assert result["name"] == "products"
        assert usage.prompt_tokens == 200
        assert usage.completion_tokens == 80
        assert usage.total_tokens == 280
    @pytest.mark.asyncio
    async def test_validation_success_first_try(self):
        """With validate=True and schema passes validation on first try, usage reflects 1 call."""
        usage = TokenUsage()
        mock_response = _make_llm_response(
            json.dumps(VALID_SCHEMA), prompt_tokens=300, completion_tokens=120
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=3,
                usage=usage,
                # Provide target_json_example to skip _infer_target_json
                target_json_example='{"title": "x", "price": "y"}',
            )

        assert result["name"] == "products"
        # Only 1 LLM call since validation passed
        assert usage.prompt_tokens == 300
        assert usage.completion_tokens == 120
        assert usage.total_tokens == 420

    @pytest.mark.asyncio
    async def test_validation_retries_accumulate_usage(self):
        """When validation fails, retry calls accumulate into the same usage object."""
        usage = TokenUsage()

        # First call returns bad schema (fails validation), second returns good schema
        responses = [
            _make_llm_response(json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100),
            _make_llm_response(json.dumps(VALID_SCHEMA), prompt_tokens=350, completion_tokens=120),
        ]
        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            idx = min(call_count, len(responses) - 1)
            call_count += 1
            return responses[idx]

        with patch(
            PATCH_TARGET,
            side_effect=mock_completion,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=3,
                usage=usage,
                target_json_example='{"title": "x", "price": "y"}',
            )

        assert result["name"] == "products"
        # Two LLM calls: 300+350=650 prompt, 100+120=220 completion
        assert usage.prompt_tokens == 650
        assert usage.completion_tokens == 220
        assert usage.total_tokens == 870
    @pytest.mark.asyncio
    async def test_infer_target_json_accumulates_usage(self):
        """When validate=True and no target_json_example, _infer_target_json makes an extra LLM call."""
        usage = TokenUsage()

        infer_response = _make_llm_response(
            '{"title": "Widget", "price": "$10"}',
            prompt_tokens=50,
            completion_tokens=30,
        )
        schema_response = _make_llm_response(
            json.dumps(VALID_SCHEMA),
            prompt_tokens=300,
            completion_tokens=120,
        )

        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            # First call is _infer_target_json, second is schema generation
            if call_count == 1:
                return infer_response
            return schema_response

        with patch(
            PATCH_TARGET,
            side_effect=mock_completion,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                query="extract product title and price",
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=3,
                usage=usage,
                # No target_json_example — triggers _infer_target_json
            )

        assert result["name"] == "products"
        # _infer_target_json: 50+30 = 80
        # schema generation: 300+120 = 420
        # Total: 350 prompt, 150 completion, 500 total
        assert usage.prompt_tokens == 350
        assert usage.completion_tokens == 150
        assert usage.total_tokens == 500

    @pytest.mark.asyncio
    async def test_infer_plus_retries_accumulate(self):
        """Full pipeline: infer + bad schema + good schema = 3 calls accumulated."""
        usage = TokenUsage()

        infer_resp = _make_llm_response(
            '{"title": "x", "price": "y"}',
            prompt_tokens=50, completion_tokens=20,
        )
        bad_resp = _make_llm_response(
            json.dumps(BAD_SCHEMA),
            prompt_tokens=300, completion_tokens=100,
        )
        good_resp = _make_llm_response(
            json.dumps(VALID_SCHEMA),
            prompt_tokens=400, completion_tokens=150,
        )

        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return infer_resp
            elif call_count == 2:
                return bad_resp
            else:
                return good_resp

        with patch(
            PATCH_TARGET,
            side_effect=mock_completion,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                query="extract products",
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=3,
                usage=usage,
            )

        # 3 calls total
        assert call_count == 3
        assert usage.prompt_tokens == 750  # 50 + 300 + 400
        assert usage.completion_tokens == 270  # 20 + 100 + 150
        assert usage.total_tokens == 1020  # 70 + 400 + 550
    @pytest.mark.asyncio
    async def test_json_parse_failure_retry_accumulates(self):
        """When LLM returns invalid JSON, the retry also accumulates usage."""
        usage = TokenUsage()

        # First response is not valid JSON, second is valid
        bad_json_resp = _make_llm_response(
            "this is not json {{{",
            prompt_tokens=200, completion_tokens=60,
        )
        good_resp = _make_llm_response(
            json.dumps(VALID_SCHEMA),
            prompt_tokens=250, completion_tokens=80,
        )

        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return bad_json_resp
            return good_resp

        with patch(
            PATCH_TARGET,
            side_effect=mock_completion,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=3,
                usage=usage,
                target_json_example='{"title": "x", "price": "y"}',
            )

        assert result["name"] == "products"
        # Both calls tracked: even the one that returned bad JSON
        assert usage.prompt_tokens == 450  # 200 + 250
        assert usage.completion_tokens == 140  # 60 + 80
        assert usage.total_tokens == 590

    @pytest.mark.asyncio
    async def test_usage_none_does_not_crash(self):
        """Explicitly passing usage=None should not raise any errors."""
        mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
                usage=None,
            )

        assert isinstance(result, dict)
    @pytest.mark.asyncio
    async def test_preexisting_usage_values_are_added_to(self):
        """If usage already has values, new tokens are ADDED, not replaced."""
        usage = TokenUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)

        mock_response = _make_llm_response(
            json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
                usage=usage,
            )

        assert usage.prompt_tokens == 1200  # 1000 + 200
        assert usage.completion_tokens == 580  # 500 + 80
        assert usage.total_tokens == 1780  # 1500 + 280
    def test_sync_wrapper_passes_usage(self):
        """The sync generate_schema forwards usage to agenerate_schema."""
        usage = TokenUsage()
        mock_response = _make_llm_response(
            json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = JsonElementExtractionStrategy.generate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
                usage=usage,
            )

        assert result["name"] == "products"
        assert usage.prompt_tokens == 200
        assert usage.completion_tokens == 80
        assert usage.total_tokens == 280

    def test_sync_wrapper_usage_none_backward_compat(self):
        """Sync wrapper with no usage arg (default) still works."""
        mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = JsonElementExtractionStrategy.generate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
            )

        assert isinstance(result, dict)
        assert result["name"] == "products"
    @pytest.mark.asyncio
    async def test_max_refinements_zero_single_call(self):
        """max_refinements=0 with validate=True means exactly 1 attempt, 1 usage entry."""
        usage = TokenUsage()
        mock_response = _make_llm_response(
            json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=0,
                usage=usage,
                target_json_example='{"title": "x", "price": "y"}',
            )

        # Even though validation fails, only 1 attempt (0 refinements)
        assert usage.prompt_tokens == 300
        assert usage.completion_tokens == 100
        assert usage.total_tokens == 400

    @pytest.mark.asyncio
    async def test_css_subclass_inherits_usage(self):
        """JsonCssExtractionStrategy.agenerate_schema also supports usage."""
        usage = TokenUsage()
        mock_response = _make_llm_response(
            json.dumps(VALID_SCHEMA), prompt_tokens=150, completion_tokens=60
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonCssExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=False,
                usage=usage,
            )

        assert result["name"] == "products"
        assert usage.total_tokens == 210
    @pytest.mark.asyncio
    async def test_infer_target_json_failure_still_tracks_nothing(self):
        """If _infer_target_json raises (and catches), usage should not break.

        When the inference LLM call itself throws an exception before we get
        response.usage, no tokens should be added (graceful degradation).
        """
        usage = TokenUsage()

        call_count = 0

        async def mock_completion(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _infer_target_json call — simulate exception
                raise ConnectionError("LLM is down")
            # Schema generation call
            return _make_llm_response(
                json.dumps(VALID_SCHEMA),
                prompt_tokens=300,
                completion_tokens=100,
            )

        with patch(
            PATCH_TARGET,
            side_effect=mock_completion,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                query="extract products",
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=0,
                usage=usage,
            )

        # Only the schema call counted; infer call failed before tracking
        assert usage.prompt_tokens == 300
        assert usage.completion_tokens == 100
        assert usage.total_tokens == 400

    @pytest.mark.asyncio
    async def test_multiple_bad_retries_then_best_effort(self):
        """All retries fail validation, usage still accumulates for every attempt."""
        usage = TokenUsage()

        # Every call returns bad schema — validation will always fail
        mock_response = _make_llm_response(
            json.dumps(BAD_SCHEMA), prompt_tokens=200, completion_tokens=80
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy.agenerate_schema(
                html=SAMPLE_HTML,
                llm_config=FakeLLMConfig(),
                validate=True,
                max_refinements=2,  # 1 initial + 2 retries = 3 calls
                usage=usage,
                target_json_example='{"title": "x", "price": "y"}',
            )

        # Returns best-effort (last schema), but all 3 calls tracked
        assert usage.prompt_tokens == 600  # 200 * 3
        assert usage.completion_tokens == 240  # 80 * 3
        assert usage.total_tokens == 840  # 280 * 3


class TestInferTargetJsonUsage:
    """Isolated tests for _infer_target_json usage tracking."""

    @pytest.mark.asyncio
    async def test_infer_tracks_usage(self):
        """Direct call to _infer_target_json with usage accumulator."""
        usage = TokenUsage()
        mock_response = _make_llm_response(
            '{"name": "test", "value": "123"}',
            prompt_tokens=80,
            completion_tokens=25,
        )

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy._infer_target_json(
                query="extract names and values",
                html_snippet="<div>test</div>",
                llm_config=FakeLLMConfig(),
                usage=usage,
            )

        assert result == {"name": "test", "value": "123"}
        assert usage.prompt_tokens == 80
        assert usage.completion_tokens == 25
        assert usage.total_tokens == 105

    @pytest.mark.asyncio
    async def test_infer_usage_none_backward_compat(self):
        """_infer_target_json with usage=None (default) still works."""
        mock_response = _make_llm_response('{"name": "test"}')

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy._infer_target_json(
                query="extract names",
                html_snippet="<div>test</div>",
                llm_config=FakeLLMConfig(),
            )

        assert result == {"name": "test"}

    @pytest.mark.asyncio
    async def test_infer_exception_no_usage_side_effect(self):
        """When _infer_target_json fails, usage is untouched (exception before tracking)."""
        usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            side_effect=RuntimeError("API down"),
        ):
            result = await JsonElementExtractionStrategy._infer_target_json(
                query="extract names",
                html_snippet="<div>test</div>",
                llm_config=FakeLLMConfig(),
                usage=usage,
            )

        # Returns None on failure
        assert result is None
        # Usage unchanged — exception happened before tracking
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150

    @pytest.mark.asyncio
    async def test_infer_empty_response_still_tracks(self):
        """When LLM returns empty content, usage is still tracked (response was received)."""
        usage = TokenUsage()
        mock_response = _make_llm_response("", prompt_tokens=80, completion_tokens=5)

        with patch(
            PATCH_TARGET,
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await JsonElementExtractionStrategy._infer_target_json(
                query="extract names",
                html_snippet="<div>test</div>",
                llm_config=FakeLLMConfig(),
                usage=usage,
            )

        # Returns None because content is empty
        assert result is None
        # But usage was tracked because we got a response
        assert usage.prompt_tokens == 80
        assert usage.completion_tokens == 5
        assert usage.total_tokens == 85