From c9cb0160cf621e77a8ec988b08fdd85a9da25dad Mon Sep 17 00:00:00 2001 From: unclecode Date: Wed, 18 Feb 2026 06:44:17 +0000 Subject: [PATCH] Add token usage tracking to generate_schema / agenerate_schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit generate_schema can make up to 5 internal LLM calls (field inference, schema generation, validation retries) with no way to track token consumption. Add an optional `usage: TokenUsage = None` parameter that accumulates prompt/completion/total tokens across all calls in-place. - _infer_target_json: accept and populate usage accumulator - agenerate_schema: track usage after every aperform_completion call in the retry loop, forward usage to _infer_target_json - generate_schema (sync): forward usage to agenerate_schema Fully backward-compatible — omitting usage changes nothing. --- crawl4ai/extraction_strategy.py | 25 +- docs/md_v2/complete-sdk-reference.md | 14 + docs/md_v2/extraction/llm-strategies.md | 4 +- docs/md_v2/extraction/no-llm-strategies.md | 32 + tests/general/test_generate_schema_usage.py | 654 ++++++++++++++++++++ 5 files changed, 726 insertions(+), 3 deletions(-) create mode 100644 tests/general/test_generate_schema_usage.py diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index ad024b32..bea176e6 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -1599,9 +1599,13 @@ class JsonElementExtractionStrategy(ExtractionStrategy): return "\n".join(parts) @staticmethod - async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None) -> Optional[dict]: + async def _infer_target_json(query: str, html_snippet: str, llm_config, url: str = None, usage: 'TokenUsage' = None) -> Optional[dict]: """Infer a target JSON example from a query and HTML snippet via a quick LLM call. + Args: + usage: Optional TokenUsage accumulator. 
If provided, token counts from + this LLM call are added to it in-place. + Returns the parsed dict, or None if inference fails. """ from .utils import aperform_completion_with_backoff @@ -1633,6 +1637,10 @@ class JsonElementExtractionStrategy(ExtractionStrategy): api_token=llm_config.api_token, base_url=llm_config.base_url, ) + if usage is not None: + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens raw = response.choices[0].message.content if not raw or not raw.strip(): return None @@ -1726,6 +1734,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa url: Union[str, List[str]] = None, validate: bool = True, max_refinements: int = 3, + usage: 'TokenUsage' = None, **kwargs ) -> dict: """ @@ -1744,6 +1753,9 @@ In this scenario, use your best judgment to generate the schema. You need to exa validate (bool): If True, validate the schema against the HTML and refine via LLM feedback loop. Defaults to False (zero overhead). max_refinements (int): Max refinement rounds when validate=True. Defaults to 3. + usage (TokenUsage, optional): Token usage accumulator. If provided, + token counts from all LLM calls (including inference and + validation retries) are added to it in-place. **kwargs: Additional args passed to LLM processor. Returns: @@ -1770,6 +1782,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa url=url, validate=validate, max_refinements=max_refinements, + usage=usage, **kwargs ) @@ -1793,6 +1806,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa url: Union[str, List[str]] = None, validate: bool = True, max_refinements: int = 3, + usage: 'TokenUsage' = None, **kwargs ) -> dict: """ @@ -1815,6 +1829,9 @@ In this scenario, use your best judgment to generate the schema. 
You need to exa validate (bool): If True, validate the schema against the HTML and refine via LLM feedback loop. Defaults to False (zero overhead). max_refinements (int): Max refinement rounds when validate=True. Defaults to 3. + usage (TokenUsage, optional): Token usage accumulator. If provided, + token counts from all LLM calls (including inference and + validation retries) are added to it in-place. **kwargs: Additional args passed to LLM processor. Returns: @@ -1913,7 +1930,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa if url is not None: first_url = url if isinstance(url, str) else url[0] inferred = await JsonElementExtractionStrategy._infer_target_json( - query=query, html_snippet=html, llm_config=llm_config, url=first_url + query=query, html_snippet=html, llm_config=llm_config, url=first_url, usage=usage ) if inferred: expected_fields = JsonElementExtractionStrategy._extract_expected_fields(inferred) @@ -1939,6 +1956,10 @@ In this scenario, use your best judgment to generate the schema. You need to exa messages=messages, extra_args=kwargs, ) + if usage is not None: + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens raw = response.choices[0].message.content if not raw or not raw.strip(): raise ValueError("LLM returned an empty response") diff --git a/docs/md_v2/complete-sdk-reference.md b/docs/md_v2/complete-sdk-reference.md index 61fa4fca..3d639edc 100644 --- a/docs/md_v2/complete-sdk-reference.md +++ b/docs/md_v2/complete-sdk-reference.md @@ -4537,6 +4537,20 @@ xpath_schema = JsonXPathExtractionStrategy.generate_schema( # Use the generated schema for fast, repeated extractions strategy = JsonCssExtractionStrategy(css_schema) ``` +### Token Usage Tracking +`generate_schema` may make multiple LLM calls internally (field inference, generation, validation retries). 
Track total token consumption by passing a `TokenUsage` accumulator: +```python +from crawl4ai.models import TokenUsage + +usage = TokenUsage() +schema = JsonCssExtractionStrategy.generate_schema( + url="https://example.com/products", + query="extract product name and price", + usage=usage, +) +print(f"Total tokens: {usage.total_tokens}") +``` +The `usage` parameter is optional and fully backward-compatible. Both `generate_schema` (sync) and `agenerate_schema` (async) support it. ### LLM Provider Options 1. **OpenAI GPT-4 (`openai/gpt4o`)** - Default provider diff --git a/docs/md_v2/extraction/llm-strategies.md b/docs/md_v2/extraction/llm-strategies.md index df948a9e..cba4d6e4 100644 --- a/docs/md_v2/extraction/llm-strategies.md +++ b/docs/md_v2/extraction/llm-strategies.md @@ -204,7 +204,9 @@ llm_strategy.show_usage() # e.g. “Total usage: 1241 tokens across 2 chunk calls” ``` -If your model provider doesn’t return usage info, these fields might be partial or empty. +If your model provider doesn't return usage info, these fields might be partial or empty. + +> **Tip:** `JsonCssExtractionStrategy.generate_schema()` also supports token usage tracking via an optional `usage` parameter. See [Token Usage Tracking in Schema Generation](./no-llm-strategies.md#token-usage-tracking) for details. --- diff --git a/docs/md_v2/extraction/no-llm-strategies.md b/docs/md_v2/extraction/no-llm-strategies.md index eb56a749..ab3c59a7 100644 --- a/docs/md_v2/extraction/no-llm-strategies.md +++ b/docs/md_v2/extraction/no-llm-strategies.md @@ -761,6 +761,38 @@ schema = JsonCssExtractionStrategy.generate_schema( The generator also understands sibling layouts — for sites like Hacker News where data is split across sibling elements, it will automatically use the [`source` field](#sibling-data) to reach sibling data. +### Token Usage Tracking + +`generate_schema` may make multiple LLM calls internally (field inference, schema generation, validation retries). 
To track the total token consumption across all of these calls, pass a `TokenUsage` accumulator: + +```python +from crawl4ai import JsonCssExtractionStrategy +from crawl4ai.models import TokenUsage + +usage = TokenUsage() + +schema = JsonCssExtractionStrategy.generate_schema( + url="https://news.ycombinator.com", + query="Extract each story: title, url, score, author", + usage=usage, +) + +print(f"Prompt tokens: {usage.prompt_tokens}") +print(f"Completion tokens: {usage.completion_tokens}") +print(f"Total tokens: {usage.total_tokens}") +``` + +The `usage` parameter is optional — omitting it changes nothing (fully backward-compatible). You can also reuse the same accumulator across multiple calls to get a grand total: + +```python +usage = TokenUsage() +schema1 = JsonCssExtractionStrategy.generate_schema(url=url1, query=q1, usage=usage) +schema2 = JsonCssExtractionStrategy.generate_schema(url=url2, query=q2, usage=usage) +print(f"Grand total: {usage.total_tokens} tokens") +``` + +Both `generate_schema` (sync) and `agenerate_schema` (async) support the `usage` parameter. + ### LLM Provider Options 1. **OpenAI GPT-4 (`openai/gpt4o`)** diff --git a/tests/general/test_generate_schema_usage.py b/tests/general/test_generate_schema_usage.py new file mode 100644 index 00000000..6f7227e9 --- /dev/null +++ b/tests/general/test_generate_schema_usage.py @@ -0,0 +1,654 @@ +"""Tests for TokenUsage accumulation in generate_schema / agenerate_schema. 
+ +Covers: +- Backward compatibility (usage=None, the default) +- Single-shot schema generation accumulates usage +- Validation retry loop accumulates across all LLM calls +- _infer_target_json accumulates its own LLM call +- Sync wrapper forwards usage correctly +- JSON parse failure retry also accumulates usage +- usage object receives correct cumulative totals +""" + +import asyncio +import json +import pytest +from dataclasses import dataclass +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch, MagicMock + +from crawl4ai.extraction_strategy import JsonElementExtractionStrategy, JsonCssExtractionStrategy +from crawl4ai.models import TokenUsage + +# The functions are imported lazily inside method bodies via `from .utils import ...` +# so we must patch at the source module. +PATCH_TARGET = "crawl4ai.utils.aperform_completion_with_backoff" + + +# --------------------------------------------------------------------------- +# Helpers: fake LLM response builder +# --------------------------------------------------------------------------- + +def _make_llm_response(content: str, prompt_tokens: int = 100, completion_tokens: int = 50): + """Build a fake litellm-style response with .usage and .choices.""" + return SimpleNamespace( + usage=SimpleNamespace( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + completion_tokens_details=None, + prompt_tokens_details=None, + ), + choices=[ + SimpleNamespace( + message=SimpleNamespace(content=content) + ) + ], + ) + + +# A valid CSS schema that will pass validation against SAMPLE_HTML +VALID_SCHEMA = { + "name": "products", + "baseSelector": ".product", + "fields": [ + {"name": "title", "selector": ".title", "type": "text"}, + {"name": "price", "selector": ".price", "type": "text"}, + ], +} + +SAMPLE_HTML = """ +
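+  <div class="product">
+    <span class="title">Widget</span>
+    <span class="price">$10</span>
+  </div>
+  <div class="product">
+    <span class="title">Gadget</span>
+    <span class="price">$20</span>
+  </div>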
+
+""" + +# A schema with a bad baseSelector — will fail validation and trigger retry +BAD_SCHEMA = { + "name": "products", + "baseSelector": ".nonexistent-selector", + "fields": [ + {"name": "title", "selector": ".title", "type": "text"}, + {"name": "price", "selector": ".price", "type": "text"}, + ], +} + +# Fake LLMConfig +@dataclass +class FakeLLMConfig: + provider: str = "fake/model" + api_token: str = "fake-token" + base_url: str = None + backoff_base_delay: float = 0 + backoff_max_attempts: int = 1 + backoff_exponential_factor: int = 2 + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestGenerateSchemaUsage: + """Test suite for usage tracking in generate_schema / agenerate_schema.""" + + @pytest.mark.asyncio + async def test_backward_compat_usage_none(self): + """When usage is not passed (default None), everything works as before.""" + mock_response = _make_llm_response(json.dumps(VALID_SCHEMA)) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=False, + ) + + assert isinstance(result, dict) + assert result["name"] == "products" + + @pytest.mark.asyncio + async def test_single_shot_no_validate(self): + """Single LLM call with validate=False populates usage correctly.""" + usage = TokenUsage() + mock_response = _make_llm_response( + json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=False, + usage=usage, + ) + + assert result["name"] == "products" + assert usage.prompt_tokens == 200 + assert usage.completion_tokens == 80 + assert 
usage.total_tokens == 280 + + @pytest.mark.asyncio + async def test_validation_success_first_try(self): + """With validate=True and schema passes validation on first try, usage reflects 1 call.""" + usage = TokenUsage() + mock_response = _make_llm_response( + json.dumps(VALID_SCHEMA), prompt_tokens=300, completion_tokens=120 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=3, + usage=usage, + # Provide target_json_example to skip _infer_target_json + target_json_example='{"title": "x", "price": "y"}', + ) + + assert result["name"] == "products" + # Only 1 LLM call since validation passed + assert usage.prompt_tokens == 300 + assert usage.completion_tokens == 120 + assert usage.total_tokens == 420 + + @pytest.mark.asyncio + async def test_validation_retries_accumulate_usage(self): + """When validation fails, retry calls accumulate into the same usage object.""" + usage = TokenUsage() + + # First call returns bad schema (fails validation), second returns good schema + responses = [ + _make_llm_response(json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100), + _make_llm_response(json.dumps(VALID_SCHEMA), prompt_tokens=350, completion_tokens=120), + ] + call_count = 0 + + async def mock_completion(*args, **kwargs): + nonlocal call_count + idx = min(call_count, len(responses) - 1) + call_count += 1 + return responses[idx] + + with patch( + PATCH_TARGET, + side_effect=mock_completion, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=3, + usage=usage, + target_json_example='{"title": "x", "price": "y"}', + ) + + assert result["name"] == "products" + # Two LLM calls: 300+350=650 prompt, 100+120=220 completion + assert usage.prompt_tokens == 650 + assert 
usage.completion_tokens == 220 + assert usage.total_tokens == 870 + + @pytest.mark.asyncio + async def test_infer_target_json_accumulates_usage(self): + """When validate=True and no target_json_example, _infer_target_json makes an extra LLM call.""" + usage = TokenUsage() + + infer_response = _make_llm_response( + '{"title": "Widget", "price": "$10"}', + prompt_tokens=50, + completion_tokens=30, + ) + schema_response = _make_llm_response( + json.dumps(VALID_SCHEMA), + prompt_tokens=300, + completion_tokens=120, + ) + + call_count = 0 + + async def mock_completion(*args, **kwargs): + nonlocal call_count + call_count += 1 + # First call is _infer_target_json, second is schema generation + if call_count == 1: + return infer_response + return schema_response + + with patch( + PATCH_TARGET, + side_effect=mock_completion, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + query="extract product title and price", + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=3, + usage=usage, + # No target_json_example — triggers _infer_target_json + ) + + assert result["name"] == "products" + # _infer_target_json: 50+30 = 80 + # schema generation: 300+120 = 420 + # Total: 350 prompt, 150 completion, 500 total + assert usage.prompt_tokens == 350 + assert usage.completion_tokens == 150 + assert usage.total_tokens == 500 + + @pytest.mark.asyncio + async def test_infer_plus_retries_accumulate(self): + """Full pipeline: infer + bad schema + good schema = 3 calls accumulated.""" + usage = TokenUsage() + + infer_resp = _make_llm_response( + '{"title": "x", "price": "y"}', + prompt_tokens=50, completion_tokens=20, + ) + bad_resp = _make_llm_response( + json.dumps(BAD_SCHEMA), + prompt_tokens=300, completion_tokens=100, + ) + good_resp = _make_llm_response( + json.dumps(VALID_SCHEMA), + prompt_tokens=400, completion_tokens=150, + ) + + call_count = 0 + + async def mock_completion(*args, **kwargs): + nonlocal call_count + call_count += 1 
+ if call_count == 1: + return infer_resp + elif call_count == 2: + return bad_resp + else: + return good_resp + + with patch( + PATCH_TARGET, + side_effect=mock_completion, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + query="extract products", + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=3, + usage=usage, + ) + + # 3 calls total + assert call_count == 3 + assert usage.prompt_tokens == 750 # 50 + 300 + 400 + assert usage.completion_tokens == 270 # 20 + 100 + 150 + assert usage.total_tokens == 1020 # 70 + 400 + 550 + + @pytest.mark.asyncio + async def test_json_parse_failure_retry_accumulates(self): + """When LLM returns invalid JSON, the retry also accumulates usage.""" + usage = TokenUsage() + + # First response is not valid JSON, second is valid + bad_json_resp = _make_llm_response( + "this is not json {{{", + prompt_tokens=200, completion_tokens=60, + ) + good_resp = _make_llm_response( + json.dumps(VALID_SCHEMA), + prompt_tokens=250, completion_tokens=80, + ) + + call_count = 0 + + async def mock_completion(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return bad_json_resp + return good_resp + + with patch( + PATCH_TARGET, + side_effect=mock_completion, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=3, + usage=usage, + target_json_example='{"title": "x", "price": "y"}', + ) + + assert result["name"] == "products" + # Both calls tracked: even the one that returned bad JSON + assert usage.prompt_tokens == 450 # 200 + 250 + assert usage.completion_tokens == 140 # 60 + 80 + assert usage.total_tokens == 590 + + @pytest.mark.asyncio + async def test_usage_none_does_not_crash(self): + """Explicitly passing usage=None should not raise any errors.""" + mock_response = _make_llm_response(json.dumps(VALID_SCHEMA)) + + with patch( + PATCH_TARGET, + 
new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=False, + usage=None, + ) + + assert isinstance(result, dict) + + @pytest.mark.asyncio + async def test_preexisting_usage_values_are_added_to(self): + """If usage already has values, new tokens are ADDED, not replaced.""" + usage = TokenUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500) + + mock_response = _make_llm_response( + json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=False, + usage=usage, + ) + + assert usage.prompt_tokens == 1200 # 1000 + 200 + assert usage.completion_tokens == 580 # 500 + 80 + assert usage.total_tokens == 1780 # 1500 + 280 + + def test_sync_wrapper_passes_usage(self): + """The sync generate_schema forwards usage to agenerate_schema.""" + usage = TokenUsage() + mock_response = _make_llm_response( + json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = JsonElementExtractionStrategy.generate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=False, + usage=usage, + ) + + assert result["name"] == "products" + assert usage.prompt_tokens == 200 + assert usage.completion_tokens == 80 + assert usage.total_tokens == 280 + + def test_sync_wrapper_usage_none_backward_compat(self): + """Sync wrapper with no usage arg (default) still works.""" + mock_response = _make_llm_response(json.dumps(VALID_SCHEMA)) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = JsonElementExtractionStrategy.generate_schema( + html=SAMPLE_HTML, + 
llm_config=FakeLLMConfig(), + validate=False, + ) + + assert isinstance(result, dict) + assert result["name"] == "products" + + @pytest.mark.asyncio + async def test_max_refinements_zero_single_call(self): + """max_refinements=0 with validate=True means exactly 1 attempt, 1 usage entry.""" + usage = TokenUsage() + mock_response = _make_llm_response( + json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=0, + usage=usage, + target_json_example='{"title": "x", "price": "y"}', + ) + + # Even though validation fails, only 1 attempt (0 refinements) + assert usage.prompt_tokens == 300 + assert usage.completion_tokens == 100 + assert usage.total_tokens == 400 + + @pytest.mark.asyncio + async def test_css_subclass_inherits_usage(self): + """JsonCssExtractionStrategy.agenerate_schema also supports usage.""" + usage = TokenUsage() + mock_response = _make_llm_response( + json.dumps(VALID_SCHEMA), prompt_tokens=150, completion_tokens=60 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonCssExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=False, + usage=usage, + ) + + assert result["name"] == "products" + assert usage.total_tokens == 210 + + @pytest.mark.asyncio + async def test_infer_target_json_failure_still_tracks_nothing(self): + """If _infer_target_json raises (and catches), usage should not break. + + When the inference LLM call itself throws an exception before we get + response.usage, no tokens should be added (graceful degradation). 
+ """ + usage = TokenUsage() + + call_count = 0 + + async def mock_completion(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # _infer_target_json call — simulate exception + raise ConnectionError("LLM is down") + # Schema generation call + return _make_llm_response( + json.dumps(VALID_SCHEMA), + prompt_tokens=300, + completion_tokens=100, + ) + + with patch( + PATCH_TARGET, + side_effect=mock_completion, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + query="extract products", + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=0, + usage=usage, + ) + + # Only the schema call counted; infer call failed before tracking + assert usage.prompt_tokens == 300 + assert usage.completion_tokens == 100 + assert usage.total_tokens == 400 + + @pytest.mark.asyncio + async def test_multiple_bad_retries_then_best_effort(self): + """All retries fail validation, usage still accumulates for every attempt.""" + usage = TokenUsage() + + # Every call returns bad schema — validation will always fail + mock_response = _make_llm_response( + json.dumps(BAD_SCHEMA), prompt_tokens=200, completion_tokens=80 + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy.agenerate_schema( + html=SAMPLE_HTML, + llm_config=FakeLLMConfig(), + validate=True, + max_refinements=2, # 1 initial + 2 retries = 3 calls + usage=usage, + target_json_example='{"title": "x", "price": "y"}', + ) + + # Returns best-effort (last schema), but all 3 calls tracked + assert usage.prompt_tokens == 600 # 200 * 3 + assert usage.completion_tokens == 240 # 80 * 3 + assert usage.total_tokens == 840 # 280 * 3 + + +class TestInferTargetJsonUsage: + """Isolated tests for _infer_target_json usage tracking.""" + + @pytest.mark.asyncio + async def test_infer_tracks_usage(self): + """Direct call to _infer_target_json with usage accumulator.""" + usage = 
TokenUsage() + mock_response = _make_llm_response( + '{"name": "test", "value": "123"}', + prompt_tokens=80, + completion_tokens=25, + ) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy._infer_target_json( + query="extract names and values", + html_snippet="
test
", + llm_config=FakeLLMConfig(), + usage=usage, + ) + + assert result == {"name": "test", "value": "123"} + assert usage.prompt_tokens == 80 + assert usage.completion_tokens == 25 + assert usage.total_tokens == 105 + + @pytest.mark.asyncio + async def test_infer_usage_none_backward_compat(self): + """_infer_target_json with usage=None (default) still works.""" + mock_response = _make_llm_response('{"name": "test"}') + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy._infer_target_json( + query="extract names", + html_snippet="
test
", + llm_config=FakeLLMConfig(), + ) + + assert result == {"name": "test"} + + @pytest.mark.asyncio + async def test_infer_exception_no_usage_side_effect(self): + """When _infer_target_json fails, usage is untouched (exception before tracking).""" + usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + side_effect=RuntimeError("API down"), + ): + result = await JsonElementExtractionStrategy._infer_target_json( + query="extract names", + html_snippet="
test
", + llm_config=FakeLLMConfig(), + usage=usage, + ) + + # Returns None on failure + assert result is None + # Usage unchanged — exception happened before tracking + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + + @pytest.mark.asyncio + async def test_infer_empty_response_still_tracks(self): + """When LLM returns empty content, usage is still tracked (response was received).""" + usage = TokenUsage() + mock_response = _make_llm_response("", prompt_tokens=80, completion_tokens=5) + + with patch( + PATCH_TARGET, + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await JsonElementExtractionStrategy._infer_target_json( + query="extract names", + html_snippet="
test
", + llm_config=FakeLLMConfig(), + usage=usage, + ) + + # Returns None because content is empty + assert result is None + # But usage was tracked because we got a response + assert usage.prompt_tokens == 80 + assert usage.completion_tokens == 5 + assert usage.total_tokens == 85