Files
crawl4ai/tests/general/test_generate_schema_usage.py
unclecode c9cb0160cf Add token usage tracking to generate_schema / agenerate_schema
generate_schema can make up to 5 internal LLM calls (field inference,
schema generation, validation retries) with no way to track token
consumption. Add an optional `usage: TokenUsage = None` parameter that
accumulates prompt/completion/total tokens across all calls in-place.

- _infer_target_json: accept and populate usage accumulator
- agenerate_schema: track usage after every aperform_completion call
  in the retry loop, forward usage to _infer_target_json
- generate_schema (sync): forward usage to agenerate_schema

Fully backward-compatible — omitting usage changes nothing.
2026-02-18 06:44:17 +00:00

655 lines
22 KiB
Python

"""Tests for TokenUsage accumulation in generate_schema / agenerate_schema.
Covers:
- Backward compatibility (usage=None, the default)
- Single-shot schema generation accumulates usage
- Validation retry loop accumulates across all LLM calls
- _infer_target_json accumulates its own LLM call
- Sync wrapper forwards usage correctly
- JSON parse failure retry also accumulates usage
- usage object receives correct cumulative totals
"""
import asyncio
import json
import pytest
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch, MagicMock
from crawl4ai.extraction_strategy import JsonElementExtractionStrategy, JsonCssExtractionStrategy
from crawl4ai.models import TokenUsage
# The functions are imported lazily inside method bodies via `from .utils import ...`
# so we must patch at the source module.
PATCH_TARGET = "crawl4ai.utils.aperform_completion_with_backoff"
# ---------------------------------------------------------------------------
# Helpers: fake LLM response builder
# ---------------------------------------------------------------------------
def _make_llm_response(content: str, prompt_tokens: int = 100, completion_tokens: int = 50):
"""Build a fake litellm-style response with .usage and .choices."""
return SimpleNamespace(
usage=SimpleNamespace(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
completion_tokens_details=None,
prompt_tokens_details=None,
),
choices=[
SimpleNamespace(
message=SimpleNamespace(content=content)
)
],
)
# A valid CSS schema that will pass validation against SAMPLE_HTML
VALID_SCHEMA = {
"name": "products",
"baseSelector": ".product",
"fields": [
{"name": "title", "selector": ".title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"},
],
}
SAMPLE_HTML = """
<div class="products">
<div class="product">
<span class="title">Widget</span>
<span class="price">$10</span>
</div>
<div class="product">
<span class="title">Gadget</span>
<span class="price">$20</span>
</div>
</div>
"""
# A schema with a bad baseSelector — will fail validation and trigger retry
BAD_SCHEMA = {
"name": "products",
"baseSelector": ".nonexistent-selector",
"fields": [
{"name": "title", "selector": ".title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"},
],
}
# Fake LLMConfig
@dataclass
class FakeLLMConfig:
provider: str = "fake/model"
api_token: str = "fake-token"
base_url: str = None
backoff_base_delay: float = 0
backoff_max_attempts: int = 1
backoff_exponential_factor: int = 2
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestGenerateSchemaUsage:
"""Test suite for usage tracking in generate_schema / agenerate_schema."""
@pytest.mark.asyncio
async def test_backward_compat_usage_none(self):
"""When usage is not passed (default None), everything works as before."""
mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
)
assert isinstance(result, dict)
assert result["name"] == "products"
@pytest.mark.asyncio
async def test_single_shot_no_validate(self):
"""Single LLM call with validate=False populates usage correctly."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert result["name"] == "products"
assert usage.prompt_tokens == 200
assert usage.completion_tokens == 80
assert usage.total_tokens == 280
@pytest.mark.asyncio
async def test_validation_success_first_try(self):
"""With validate=True and schema passes validation on first try, usage reflects 1 call."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=300, completion_tokens=120
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
# Provide target_json_example to skip _infer_target_json
target_json_example='{"title": "x", "price": "y"}',
)
assert result["name"] == "products"
# Only 1 LLM call since validation passed
assert usage.prompt_tokens == 300
assert usage.completion_tokens == 120
assert usage.total_tokens == 420
@pytest.mark.asyncio
async def test_validation_retries_accumulate_usage(self):
"""When validation fails, retry calls accumulate into the same usage object."""
usage = TokenUsage()
# First call returns bad schema (fails validation), second returns good schema
responses = [
_make_llm_response(json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100),
_make_llm_response(json.dumps(VALID_SCHEMA), prompt_tokens=350, completion_tokens=120),
]
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
idx = min(call_count, len(responses) - 1)
call_count += 1
return responses[idx]
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
assert result["name"] == "products"
# Two LLM calls: 300+350=650 prompt, 100+120=220 completion
assert usage.prompt_tokens == 650
assert usage.completion_tokens == 220
assert usage.total_tokens == 870
@pytest.mark.asyncio
async def test_infer_target_json_accumulates_usage(self):
"""When validate=True and no target_json_example, _infer_target_json makes an extra LLM call."""
usage = TokenUsage()
infer_response = _make_llm_response(
'{"title": "Widget", "price": "$10"}',
prompt_tokens=50,
completion_tokens=30,
)
schema_response = _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=300,
completion_tokens=120,
)
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
# First call is _infer_target_json, second is schema generation
if call_count == 1:
return infer_response
return schema_response
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
query="extract product title and price",
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
# No target_json_example — triggers _infer_target_json
)
assert result["name"] == "products"
# _infer_target_json: 50+30 = 80
# schema generation: 300+120 = 420
# Total: 350 prompt, 150 completion, 500 total
assert usage.prompt_tokens == 350
assert usage.completion_tokens == 150
assert usage.total_tokens == 500
@pytest.mark.asyncio
async def test_infer_plus_retries_accumulate(self):
"""Full pipeline: infer + bad schema + good schema = 3 calls accumulated."""
usage = TokenUsage()
infer_resp = _make_llm_response(
'{"title": "x", "price": "y"}',
prompt_tokens=50, completion_tokens=20,
)
bad_resp = _make_llm_response(
json.dumps(BAD_SCHEMA),
prompt_tokens=300, completion_tokens=100,
)
good_resp = _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=400, completion_tokens=150,
)
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return infer_resp
elif call_count == 2:
return bad_resp
else:
return good_resp
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
query="extract products",
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
)
# 3 calls total
assert call_count == 3
assert usage.prompt_tokens == 750 # 50 + 300 + 400
assert usage.completion_tokens == 270 # 20 + 100 + 150
assert usage.total_tokens == 1020 # 70 + 400 + 550
@pytest.mark.asyncio
async def test_json_parse_failure_retry_accumulates(self):
"""When LLM returns invalid JSON, the retry also accumulates usage."""
usage = TokenUsage()
# First response is not valid JSON, second is valid
bad_json_resp = _make_llm_response(
"this is not json {{{",
prompt_tokens=200, completion_tokens=60,
)
good_resp = _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=250, completion_tokens=80,
)
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return bad_json_resp
return good_resp
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=3,
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
assert result["name"] == "products"
# Both calls tracked: even the one that returned bad JSON
assert usage.prompt_tokens == 450 # 200 + 250
assert usage.completion_tokens == 140 # 60 + 80
assert usage.total_tokens == 590
@pytest.mark.asyncio
async def test_usage_none_does_not_crash(self):
"""Explicitly passing usage=None should not raise any errors."""
mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=None,
)
assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_preexisting_usage_values_are_added_to(self):
"""If usage already has values, new tokens are ADDED, not replaced."""
usage = TokenUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert usage.prompt_tokens == 1200 # 1000 + 200
assert usage.completion_tokens == 580 # 500 + 80
assert usage.total_tokens == 1780 # 1500 + 280
def test_sync_wrapper_passes_usage(self):
"""The sync generate_schema forwards usage to agenerate_schema."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = JsonElementExtractionStrategy.generate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert result["name"] == "products"
assert usage.prompt_tokens == 200
assert usage.completion_tokens == 80
assert usage.total_tokens == 280
def test_sync_wrapper_usage_none_backward_compat(self):
"""Sync wrapper with no usage arg (default) still works."""
mock_response = _make_llm_response(json.dumps(VALID_SCHEMA))
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = JsonElementExtractionStrategy.generate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
)
assert isinstance(result, dict)
assert result["name"] == "products"
@pytest.mark.asyncio
async def test_max_refinements_zero_single_call(self):
"""max_refinements=0 with validate=True means exactly 1 attempt, 1 usage entry."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(BAD_SCHEMA), prompt_tokens=300, completion_tokens=100
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=0,
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
# Even though validation fails, only 1 attempt (0 refinements)
assert usage.prompt_tokens == 300
assert usage.completion_tokens == 100
assert usage.total_tokens == 400
@pytest.mark.asyncio
async def test_css_subclass_inherits_usage(self):
"""JsonCssExtractionStrategy.agenerate_schema also supports usage."""
usage = TokenUsage()
mock_response = _make_llm_response(
json.dumps(VALID_SCHEMA), prompt_tokens=150, completion_tokens=60
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonCssExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=False,
usage=usage,
)
assert result["name"] == "products"
assert usage.total_tokens == 210
@pytest.mark.asyncio
async def test_infer_target_json_failure_still_tracks_nothing(self):
"""If _infer_target_json raises (and catches), usage should not break.
When the inference LLM call itself throws an exception before we get
response.usage, no tokens should be added (graceful degradation).
"""
usage = TokenUsage()
call_count = 0
async def mock_completion(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# _infer_target_json call — simulate exception
raise ConnectionError("LLM is down")
# Schema generation call
return _make_llm_response(
json.dumps(VALID_SCHEMA),
prompt_tokens=300,
completion_tokens=100,
)
with patch(
PATCH_TARGET,
side_effect=mock_completion,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
query="extract products",
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=0,
usage=usage,
)
# Only the schema call counted; infer call failed before tracking
assert usage.prompt_tokens == 300
assert usage.completion_tokens == 100
assert usage.total_tokens == 400
@pytest.mark.asyncio
async def test_multiple_bad_retries_then_best_effort(self):
"""All retries fail validation, usage still accumulates for every attempt."""
usage = TokenUsage()
# Every call returns bad schema — validation will always fail
mock_response = _make_llm_response(
json.dumps(BAD_SCHEMA), prompt_tokens=200, completion_tokens=80
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy.agenerate_schema(
html=SAMPLE_HTML,
llm_config=FakeLLMConfig(),
validate=True,
max_refinements=2, # 1 initial + 2 retries = 3 calls
usage=usage,
target_json_example='{"title": "x", "price": "y"}',
)
# Returns best-effort (last schema), but all 3 calls tracked
assert usage.prompt_tokens == 600 # 200 * 3
assert usage.completion_tokens == 240 # 80 * 3
assert usage.total_tokens == 840 # 280 * 3
class TestInferTargetJsonUsage:
"""Isolated tests for _infer_target_json usage tracking."""
@pytest.mark.asyncio
async def test_infer_tracks_usage(self):
"""Direct call to _infer_target_json with usage accumulator."""
usage = TokenUsage()
mock_response = _make_llm_response(
'{"name": "test", "value": "123"}',
prompt_tokens=80,
completion_tokens=25,
)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names and values",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
usage=usage,
)
assert result == {"name": "test", "value": "123"}
assert usage.prompt_tokens == 80
assert usage.completion_tokens == 25
assert usage.total_tokens == 105
@pytest.mark.asyncio
async def test_infer_usage_none_backward_compat(self):
"""_infer_target_json with usage=None (default) still works."""
mock_response = _make_llm_response('{"name": "test"}')
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
)
assert result == {"name": "test"}
@pytest.mark.asyncio
async def test_infer_exception_no_usage_side_effect(self):
"""When _infer_target_json fails, usage is untouched (exception before tracking)."""
usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
side_effect=RuntimeError("API down"),
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
usage=usage,
)
# Returns None on failure
assert result is None
# Usage unchanged — exception happened before tracking
assert usage.prompt_tokens == 100
assert usage.completion_tokens == 50
assert usage.total_tokens == 150
@pytest.mark.asyncio
async def test_infer_empty_response_still_tracks(self):
"""When LLM returns empty content, usage is still tracked (response was received)."""
usage = TokenUsage()
mock_response = _make_llm_response("", prompt_tokens=80, completion_tokens=5)
with patch(
PATCH_TARGET,
new_callable=AsyncMock,
return_value=mock_response,
):
result = await JsonElementExtractionStrategy._infer_target_json(
query="extract names",
html_snippet="<div>test</div>",
llm_config=FakeLLMConfig(),
usage=usage,
)
# Returns None because content is empty
assert result is None
# But usage was tracked because we got a response
assert usage.prompt_tokens == 80
assert usage.completion_tokens == 5
assert usage.total_tokens == 85