diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py
index ad024b32..bea176e6 100644
--- a/crawl4ai/extraction_strategy.py
+++ b/crawl4ai/extraction_strategy.py
@@ -1599,9 +1599,13 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
return "\n".join(parts)
@staticmethod
- async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None) -> Optional[dict]:
+ async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None, usage: 'TokenUsage' = None) -> Optional[dict]:
"""Infer a target JSON example from a query and HTML snippet via a quick LLM call.
+ Args:
+ usage: Optional TokenUsage accumulator. If provided, token counts from
+ this LLM call are added to it in-place.
+
Returns the parsed dict, or None if inference fails.
"""
from .utils import aperform_completion_with_backoff
@@ -1633,6 +1637,10 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
api_token=llm_config.api_token,
base_url=llm_config.base_url,
)
+ if usage is not None:
+ usage.completion_tokens += response.usage.completion_tokens
+ usage.prompt_tokens += response.usage.prompt_tokens
+ usage.total_tokens += response.usage.total_tokens
raw = response.choices[0].message.content
if not raw or not raw.strip():
return None
@@ -1726,6 +1734,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
url: Union[str, List[str]] = None,
validate: bool = True,
max_refinements: int = 3,
+ usage: 'TokenUsage' = None,
**kwargs
) -> dict:
"""
@@ -1744,6 +1753,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa
validate (bool): If True, validate the schema against the HTML and
refine via LLM feedback loop. Defaults to False (zero overhead).
max_refinements (int): Max refinement rounds when validate=True. Defaults to 3.
+ usage (TokenUsage, optional): Token usage accumulator. If provided,
+ token counts from all LLM calls (including inference and
+ validation retries) are added to it in-place.
**kwargs: Additional args passed to LLM processor.
Returns:
@@ -1770,6 +1782,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
url=url,
validate=validate,
max_refinements=max_refinements,
+ usage=usage,
**kwargs
)
@@ -1793,6 +1806,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
url: Union[str, List[str]] = None,
validate: bool = True,
max_refinements: int = 3,
+ usage: 'TokenUsage' = None,
**kwargs
) -> dict:
"""
@@ -1815,6 +1829,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa
validate (bool): If True, validate the schema against the HTML and
refine via LLM feedback loop. Defaults to False (zero overhead).
max_refinements (int): Max refinement rounds when validate=True. Defaults to 3.
+ usage (TokenUsage, optional): Token usage accumulator. If provided,
+ token counts from all LLM calls (including inference and
+ validation retries) are added to it in-place.
**kwargs: Additional args passed to LLM processor.
Returns:
@@ -1913,7 +1930,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
if url is not None:
first_url = url if isinstance(url, str) else url[0]
inferred = await JsonElementExtractionStrategy._infer_target_json(
- query=query, html_snippet=html, llm_config=llm_config, url=first_url
+ query=query, html_snippet=html, llm_config=llm_config, url=first_url, usage=usage
)
if inferred:
expected_fields = JsonElementExtractionStrategy._extract_expected_fields(inferred)
@@ -1939,6 +1956,10 @@ In this scenario, use your best judgment to generate the schema. You need to exa
messages=messages,
extra_args=kwargs,
)
+ if usage is not None:
+ usage.completion_tokens += response.usage.completion_tokens
+ usage.prompt_tokens += response.usage.prompt_tokens
+ usage.total_tokens += response.usage.total_tokens
raw = response.choices[0].message.content
if not raw or not raw.strip():
raise ValueError("LLM returned an empty response")
diff --git a/docs/md_v2/complete-sdk-reference.md b/docs/md_v2/complete-sdk-reference.md
index 61fa4fca..3d639edc 100644
--- a/docs/md_v2/complete-sdk-reference.md
+++ b/docs/md_v2/complete-sdk-reference.md
@@ -4537,6 +4537,20 @@ xpath_schema = JsonXPathExtractionStrategy.generate_schema(
# Use the generated schema for fast, repeated extractions
strategy = JsonCssExtractionStrategy(css_schema)
```
+### Token Usage Tracking
+`generate_schema` may make multiple LLM calls internally (field inference, generation, validation retries). Track total token consumption by passing a `TokenUsage` accumulator:
+```python
+from crawl4ai.models import TokenUsage
+
+usage = TokenUsage()
+schema = JsonCssExtractionStrategy.generate_schema(
+ url="https://example.com/products",
+ query="extract product name and price",
+ usage=usage,
+)
+print(f"Total tokens: {usage.total_tokens}")
+```
+The `usage` parameter is optional and fully backward-compatible. Both `generate_schema` (sync) and `agenerate_schema` (async) support it.
### LLM Provider Options
1. **OpenAI GPT-4 (`openai/gpt4o`)**
- Default provider
diff --git a/docs/md_v2/extraction/llm-strategies.md b/docs/md_v2/extraction/llm-strategies.md
index df948a9e..cba4d6e4 100644
--- a/docs/md_v2/extraction/llm-strategies.md
+++ b/docs/md_v2/extraction/llm-strategies.md
@@ -204,7 +204,9 @@ llm_strategy.show_usage()
# e.g. “Total usage: 1241 tokens across 2 chunk calls”
```
-If your model provider doesn’t return usage info, these fields might be partial or empty.
+If your model provider doesn't return usage info, these fields might be partial or empty.
+
+> **Tip:** `JsonCssExtractionStrategy.generate_schema()` also supports token usage tracking via an optional `usage` parameter. See [Token Usage Tracking in Schema Generation](./no-llm-strategies.md#token-usage-tracking) for details.
---
diff --git a/docs/md_v2/extraction/no-llm-strategies.md b/docs/md_v2/extraction/no-llm-strategies.md
index eb56a749..ab3c59a7 100644
--- a/docs/md_v2/extraction/no-llm-strategies.md
+++ b/docs/md_v2/extraction/no-llm-strategies.md
@@ -761,6 +761,38 @@ schema = JsonCssExtractionStrategy.generate_schema(
The generator also understands sibling layouts — for sites like Hacker News where data is split across sibling elements, it will automatically use the [`source` field](#sibling-data) to reach sibling data.
+### Token Usage Tracking
+
+`generate_schema` may make multiple LLM calls internally (field inference, schema generation, validation retries). To track the total token consumption across all of these calls, pass a `TokenUsage` accumulator:
+
+```python
+from crawl4ai import JsonCssExtractionStrategy
+from crawl4ai.models import TokenUsage
+
+usage = TokenUsage()
+
+schema = JsonCssExtractionStrategy.generate_schema(
+ url="https://news.ycombinator.com",
+ query="Extract each story: title, url, score, author",
+ usage=usage,
+)
+
+print(f"Prompt tokens: {usage.prompt_tokens}")
+print(f"Completion tokens: {usage.completion_tokens}")
+print(f"Total tokens: {usage.total_tokens}")
+```
+
+The `usage` parameter is optional — omitting it changes nothing (fully backward-compatible). You can also reuse the same accumulator across multiple calls to get a grand total:
+
+```python
+usage = TokenUsage()
+schema1 = JsonCssExtractionStrategy.generate_schema(url=url1, query=q1, usage=usage)
+schema2 = JsonCssExtractionStrategy.generate_schema(url=url2, query=q2, usage=usage)
+print(f"Grand total: {usage.total_tokens} tokens")
+```
+
+Both `generate_schema` (sync) and `agenerate_schema` (async) support the `usage` parameter.
+
### LLM Provider Options
1. **OpenAI GPT-4 (`openai/gpt4o`)**
diff --git a/tests/general/test_generate_schema_usage.py b/tests/general/test_generate_schema_usage.py
new file mode 100644
index 00000000..6f7227e9
--- /dev/null
+++ b/tests/general/test_generate_schema_usage.py
@@ -0,0 +1,654 @@
+"""Tests for TokenUsage accumulation in generate_schema / agenerate_schema.
+
+Covers:
+- Backward compatibility (usage=None, the default)
+- Single-shot schema generation accumulates usage
+- Validation retry loop accumulates across all LLM calls
+- _infer_target_json accumulates its own LLM call
+- Sync wrapper forwards usage correctly
+- JSON parse failure retry also accumulates usage
+- usage object receives correct cumulative totals
+"""
+
+import asyncio
+import json
+import pytest
+from dataclasses import dataclass
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, patch, MagicMock
+
+from crawl4ai.extraction_strategy import JsonElementExtractionStrategy, JsonCssExtractionStrategy
+from crawl4ai.models import TokenUsage
+
+# The functions are imported lazily inside method bodies via `from .utils import ...`
+# so we must patch at the source module.
+PATCH_TARGET = "crawl4ai.utils.aperform_completion_with_backoff"
+
+
+# ---------------------------------------------------------------------------
+# Helpers: fake LLM response builder
+# ---------------------------------------------------------------------------
+
+def _make_llm_response(content: str, prompt_tokens: int = 100, completion_tokens: int = 50):
+ """Build a fake litellm-style response with .usage and .choices."""
+ return SimpleNamespace(
+ usage=SimpleNamespace(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ completion_tokens_details=None,
+ prompt_tokens_details=None,
+ ),
+ choices=[
+ SimpleNamespace(
+ message=SimpleNamespace(content=content)
+ )
+ ],
+ )
+
+
+# A valid CSS schema that will pass validation against SAMPLE_HTML
+VALID_SCHEMA = {
+ "name": "products",
+ "baseSelector": ".product",
+ "fields": [
+ {"name": "title", "selector": ".title", "type": "text"},
+ {"name": "price", "selector": ".price", "type": "text"},
+ ],
+}
+
+SAMPLE_HTML = """
+
+
+ Widget
+ $10
+
+
+ Gadget
+ $20
+
+
+"""
+
+# A schema with a bad baseSelector — will fail validation and trigger retry
+BAD_SCHEMA = {
+ "name": "products",
+ "baseSelector": ".nonexistent-selector",
+ "fields": [
+ {"name": "title", "selector": ".title", "type": "text"},
+ {"name": "price", "selector": ".price", "type": "text"},
+ ],
+}
+
+# Fake LLMConfig
+@dataclass
+class FakeLLMConfig:
+ provider: str = "fake/model"
+ api_token: str = "fake-token"
+ base_url: str = None
+ backoff_base_delay: float = 0
+ backoff_max_attempts: int = 1
+ backoff_exponential_factor: int = 2
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestGenerateSchemaUsage:
+ """Test suite for usage tracking in generate_schema / agenerate_schema."""
+
+ @pytest.mark.asyncio
+ async def test_backward_compat_usage_none(self):
+ """When usage is not passed (default None), everything works as before."""
+ mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ )
+
+ assert isinstance(result, dict)
+ assert result["name"] == "products"
+
+ @pytest.mark.asyncio
+ async def test_single_shot_no_validate(self):
+ """Single LLM call with validate=False populates usage correctly."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response(
+ json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ usage=usage,
+ )
+
+ assert result["name"] == "products"
+ assert usage.prompt_tokens == 200
+ assert usage.completion_tokens == 80
+ assert usage.total_tokens == 280
+
+ @pytest.mark.asyncio
+ async def test_validation_success_first_try(self):
+ """With validate=True and schema passes validation on first try, usage reflects 1 call."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response(
+ json.dumps(VALID_SCHEMA), prompt_tokens=300, completion_tokens=120
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=3,
+ usage=usage,
+ # Provide target_json_example to skip _infer_target_json
+ target_json_example='{"title": "x", "price": "y"}',
+ )
+
+ assert result["name"] == "products"
+ # Only 1 LLM call since validation passed
+ assert usage.prompt_tokens == 300
+ assert usage.completion_tokens == 120
+ assert usage.total_tokens == 420
+
+ @pytest.mark.asyncio
+ async def test_validation_retries_accumulate_usage(self):
+ """When validation fails, retry calls accumulate into the same usage object."""
+ usage = TokenUsage()
+
+ # First call returns bad schema (fails validation), second returns good schema
+ responses = [
+ _make_llm_response(json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100),
+ _make_llm_response(json.dumps(VALID_SCHEMA), prompt_tokens=350, completion_tokens=120),
+ ]
+ call_count = 0
+
+ async def mock_completion(*args, **kwargs):
+ nonlocal call_count
+ idx = min(call_count, len(responses) - 1)
+ call_count += 1
+ return responses[idx]
+
+ with patch(
+ PATCH_TARGET,
+ side_effect=mock_completion,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=3,
+ usage=usage,
+ target_json_example='{"title": "x", "price": "y"}',
+ )
+
+ assert result["name"] == "products"
+ # Two LLM calls: 300+350=650 prompt, 100+120=220 completion
+ assert usage.prompt_tokens == 650
+ assert usage.completion_tokens == 220
+ assert usage.total_tokens == 870
+
+ @pytest.mark.asyncio
+ async def test_infer_target_json_accumulates_usage(self):
+ """When validate=True and no target_json_example, _infer_target_json makes an extra LLM call."""
+ usage = TokenUsage()
+
+ infer_response = _make_llm_response(
+ '{"title": "Widget", "price": "$10"}',
+ prompt_tokens=50,
+ completion_tokens=30,
+ )
+ schema_response = _make_llm_response(
+ json.dumps(VALID_SCHEMA),
+ prompt_tokens=300,
+ completion_tokens=120,
+ )
+
+ call_count = 0
+
+ async def mock_completion(*args, **kwargs):
+ nonlocal call_count
+ call_count += 1
+ # First call is _infer_target_json, second is schema generation
+ if call_count == 1:
+ return infer_response
+ return schema_response
+
+ with patch(
+ PATCH_TARGET,
+ side_effect=mock_completion,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ query="extract product title and price",
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=3,
+ usage=usage,
+ # No target_json_example — triggers _infer_target_json
+ )
+
+ assert result["name"] == "products"
+ # _infer_target_json: 50+30 = 80
+ # schema generation: 300+120 = 420
+ # Total: 350 prompt, 150 completion, 500 total
+ assert usage.prompt_tokens == 350
+ assert usage.completion_tokens == 150
+ assert usage.total_tokens == 500
+
+ @pytest.mark.asyncio
+ async def test_infer_plus_retries_accumulate(self):
+ """Full pipeline: infer + bad schema + good schema = 3 calls accumulated."""
+ usage = TokenUsage()
+
+ infer_resp = _make_llm_response(
+ '{"title": "x", "price": "y"}',
+ prompt_tokens=50, completion_tokens=20,
+ )
+ bad_resp = _make_llm_response(
+ json.dumps(BAD_SCHEMA),
+ prompt_tokens=300, completion_tokens=100,
+ )
+ good_resp = _make_llm_response(
+ json.dumps(VALID_SCHEMA),
+ prompt_tokens=400, completion_tokens=150,
+ )
+
+ call_count = 0
+
+ async def mock_completion(*args, **kwargs):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ return infer_resp
+ elif call_count == 2:
+ return bad_resp
+ else:
+ return good_resp
+
+ with patch(
+ PATCH_TARGET,
+ side_effect=mock_completion,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ query="extract products",
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=3,
+ usage=usage,
+ )
+
+ # 3 calls total
+ assert call_count == 3
+ assert usage.prompt_tokens == 750 # 50 + 300 + 400
+ assert usage.completion_tokens == 270 # 20 + 100 + 150
+ assert usage.total_tokens == 1020 # 70 + 400 + 550
+
+ @pytest.mark.asyncio
+ async def test_json_parse_failure_retry_accumulates(self):
+ """When LLM returns invalid JSON, the retry also accumulates usage."""
+ usage = TokenUsage()
+
+ # First response is not valid JSON, second is valid
+ bad_json_resp = _make_llm_response(
+ "this is not json {{{",
+ prompt_tokens=200, completion_tokens=60,
+ )
+ good_resp = _make_llm_response(
+ json.dumps(VALID_SCHEMA),
+ prompt_tokens=250, completion_tokens=80,
+ )
+
+ call_count = 0
+
+ async def mock_completion(*args, **kwargs):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ return bad_json_resp
+ return good_resp
+
+ with patch(
+ PATCH_TARGET,
+ side_effect=mock_completion,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=3,
+ usage=usage,
+ target_json_example='{"title": "x", "price": "y"}',
+ )
+
+ assert result["name"] == "products"
+ # Both calls tracked: even the one that returned bad JSON
+ assert usage.prompt_tokens == 450 # 200 + 250
+ assert usage.completion_tokens == 140 # 60 + 80
+ assert usage.total_tokens == 590
+
+ @pytest.mark.asyncio
+ async def test_usage_none_does_not_crash(self):
+ """Explicitly passing usage=None should not raise any errors."""
+ mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ usage=None,
+ )
+
+ assert isinstance(result, dict)
+
+ @pytest.mark.asyncio
+ async def test_preexisting_usage_values_are_added_to(self):
+ """If usage already has values, new tokens are ADDED, not replaced."""
+ usage = TokenUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)
+
+ mock_response = _make_llm_response(
+ json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ usage=usage,
+ )
+
+ assert usage.prompt_tokens == 1200 # 1000 + 200
+ assert usage.completion_tokens == 580 # 500 + 80
+ assert usage.total_tokens == 1780 # 1500 + 280
+
+ def test_sync_wrapper_passes_usage(self):
+ """The sync generate_schema forwards usage to agenerate_schema."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response(
+ json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = JsonElementExtractionStrategy.generate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ usage=usage,
+ )
+
+ assert result["name"] == "products"
+ assert usage.prompt_tokens == 200
+ assert usage.completion_tokens == 80
+ assert usage.total_tokens == 280
+
+ def test_sync_wrapper_usage_none_backward_compat(self):
+ """Sync wrapper with no usage arg (default) still works."""
+ mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = JsonElementExtractionStrategy.generate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ )
+
+ assert isinstance(result, dict)
+ assert result["name"] == "products"
+
+ @pytest.mark.asyncio
+ async def test_max_refinements_zero_single_call(self):
+ """max_refinements=0 with validate=True means exactly 1 attempt, 1 usage entry."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response(
+ json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=0,
+ usage=usage,
+ target_json_example='{"title": "x", "price": "y"}',
+ )
+
+ # Even though validation fails, only 1 attempt (0 refinements)
+ assert usage.prompt_tokens == 300
+ assert usage.completion_tokens == 100
+ assert usage.total_tokens == 400
+
+ @pytest.mark.asyncio
+ async def test_css_subclass_inherits_usage(self):
+ """JsonCssExtractionStrategy.agenerate_schema also supports usage."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response(
+ json.dumps(VALID_SCHEMA), prompt_tokens=150, completion_tokens=60
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonCssExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=False,
+ usage=usage,
+ )
+
+ assert result["name"] == "products"
+ assert usage.total_tokens == 210
+
+ @pytest.mark.asyncio
+ async def test_infer_target_json_failure_still_tracks_nothing(self):
+ """If _infer_target_json raises (and catches), usage should not break.
+
+ When the inference LLM call itself throws an exception before we get
+ response.usage, no tokens should be added (graceful degradation).
+ """
+ usage = TokenUsage()
+
+ call_count = 0
+
+ async def mock_completion(*args, **kwargs):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ # _infer_target_json call — simulate exception
+ raise ConnectionError("LLM is down")
+ # Schema generation call
+ return _make_llm_response(
+ json.dumps(VALID_SCHEMA),
+ prompt_tokens=300,
+ completion_tokens=100,
+ )
+
+ with patch(
+ PATCH_TARGET,
+ side_effect=mock_completion,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ query="extract products",
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=0,
+ usage=usage,
+ )
+
+ # Only the schema call counted; infer call failed before tracking
+ assert usage.prompt_tokens == 300
+ assert usage.completion_tokens == 100
+ assert usage.total_tokens == 400
+
+ @pytest.mark.asyncio
+ async def test_multiple_bad_retries_then_best_effort(self):
+ """All retries fail validation, usage still accumulates for every attempt."""
+ usage = TokenUsage()
+
+ # Every call returns bad schema — validation will always fail
+ mock_response = _make_llm_response(
+ json.dumps(BAD_SCHEMA), prompt_tokens=200, completion_tokens=80
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy.agenerate_schema(
+ html=SAMPLE_HTML,
+ llm_config=FakeLLMConfig(),
+ validate=True,
+ max_refinements=2, # 1 initial + 2 retries = 3 calls
+ usage=usage,
+ target_json_example='{"title": "x", "price": "y"}',
+ )
+
+ # Returns best-effort (last schema), but all 3 calls tracked
+ assert usage.prompt_tokens == 600 # 200 * 3
+ assert usage.completion_tokens == 240 # 80 * 3
+ assert usage.total_tokens == 840 # 280 * 3
+
+
+class TestInferTargetJsonUsage:
+ """Isolated tests for _infer_target_json usage tracking."""
+
+ @pytest.mark.asyncio
+ async def test_infer_tracks_usage(self):
+ """Direct call to _infer_target_json with usage accumulator."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response(
+ '{"name": "test", "value": "123"}',
+ prompt_tokens=80,
+ completion_tokens=25,
+ )
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy._infer_target_json(
+ query="extract names and values",
+ html_snippet="test
",
+ llm_config=FakeLLMConfig(),
+ usage=usage,
+ )
+
+ assert result == {"name": "test", "value": "123"}
+ assert usage.prompt_tokens == 80
+ assert usage.completion_tokens == 25
+ assert usage.total_tokens == 105
+
+ @pytest.mark.asyncio
+ async def test_infer_usage_none_backward_compat(self):
+ """_infer_target_json with usage=None (default) still works."""
+ mock_response = _make_llm_response('{"name": "test"}')
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy._infer_target_json(
+ query="extract names",
+ html_snippet="test
",
+ llm_config=FakeLLMConfig(),
+ )
+
+ assert result == {"name": "test"}
+
+ @pytest.mark.asyncio
+ async def test_infer_exception_no_usage_side_effect(self):
+ """When _infer_target_json fails, usage is untouched (exception before tracking)."""
+ usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ side_effect=RuntimeError("API down"),
+ ):
+ result = await JsonElementExtractionStrategy._infer_target_json(
+ query="extract names",
+ html_snippet="test
",
+ llm_config=FakeLLMConfig(),
+ usage=usage,
+ )
+
+ # Returns None on failure
+ assert result is None
+ # Usage unchanged — exception happened before tracking
+ assert usage.prompt_tokens == 100
+ assert usage.completion_tokens == 50
+ assert usage.total_tokens == 150
+
+ @pytest.mark.asyncio
+ async def test_infer_empty_response_still_tracks(self):
+ """When LLM returns empty content, usage is still tracked (response was received)."""
+ usage = TokenUsage()
+ mock_response = _make_llm_response("", prompt_tokens=80, completion_tokens=5)
+
+ with patch(
+ PATCH_TARGET,
+ new_callable=AsyncMock,
+ return_value=mock_response,
+ ):
+ result = await JsonElementExtractionStrategy._infer_target_json(
+ query="extract names",
+ html_snippet="test
",
+ llm_config=FakeLLMConfig(),
+ usage=usage,
+ )
+
+ # Returns None because content is empty
+ assert result is None
+ # But usage was tracked because we got a response
+ assert usage.prompt_tokens == 80
+ assert usage.completion_tokens == 5
+ assert usage.total_tokens == 85