Compare commits

...

2 Commits

Author SHA1 Message Date
Ahmed-Tawfik94
2b2ef12e25 #1156: Refactor completion function calls to use asynchronous version 2025-05-27 15:10:34 +08:00
Ahmed-Tawfik94
d9b3db925a Refactor extraction and completion functions to support asynchronous execution 2025-05-26 16:01:38 +08:00
6 changed files with 47 additions and 72 deletions

View File

@@ -607,7 +607,7 @@ class AsyncWebCrawler:
else config.chunking_strategy
)
sections = chunking.chunk(content)
extracted_content = config.extraction_strategy.run(url, sections)
extracted_content = await config.extraction_strategy.run(url, sections)
extracted_content = json.dumps(
extracted_content, indent=4, default=str, ensure_ascii=False
)

View File

@@ -9,7 +9,7 @@ from bs4 import NavigableString, Comment
from .utils import (
clean_tokens,
perform_completion_with_backoff,
aperform_completion_with_backoff,
escape_json_string,
sanitize_html,
get_home_folder,
@@ -953,7 +953,7 @@ class LLMContentFilter(RelevantContentFilter):
for var, value in prompt_variables.items():
prompt = prompt.replace("{" + var + "}", value)
def _proceed_with_chunk(
async def _proceed_with_chunk(
provider: str,
prompt: str,
api_token: str,
@@ -966,7 +966,7 @@ class LLMContentFilter(RelevantContentFilter):
tag="CHUNK",
params={"chunk_num": i + 1},
)
return perform_completion_with_backoff(
return await aperform_completion_with_backoff(
provider,
prompt,
api_token,

View File

@@ -3,6 +3,7 @@ import inspect
from typing import Any, List, Dict, Optional, Tuple, Pattern, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import asyncio
import time
from enum import IntFlag, auto
@@ -19,7 +20,7 @@ from .utils import * # noqa: F403
from .utils import (
sanitize_html,
escape_json_string,
perform_completion_with_backoff,
aperform_completion_with_backoff,
extract_xml_data,
split_and_parse_json_objects,
sanitize_input_encode,
@@ -66,7 +67,7 @@ class ExtractionStrategy(ABC):
self.verbose = kwargs.get("verbose", False)
@abstractmethod
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML.
@@ -76,7 +77,7 @@ class ExtractionStrategy(ABC):
"""
pass
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
"""
Process sections of text in parallel by default.
@@ -85,13 +86,13 @@ class ExtractionStrategy(ABC):
:return: A list of processed JSON blocks.
"""
extracted_content = []
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.extract, url, section, **kwargs)
for section in sections
]
for future in as_completed(futures):
extracted_content.extend(future.result())
tasks = [
asyncio.create_task(self.extract(url, section, **kwargs))
for section in sections
]
results = await asyncio.gather(*tasks)
for result in results:
extracted_content.extend(result)
return extracted_content
@@ -100,19 +101,18 @@ class NoExtractionStrategy(ExtractionStrategy):
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
"""
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML.
"""
return [{"index": 0, "content": html}]
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
return [
{"index": i, "tags": [], "content": section}
for i, section in enumerate(sections)
]
#######################################################
# Strategies using clustering for text data extraction #
#######################################################
@@ -386,7 +386,7 @@ class CosineStrategy(ExtractionStrategy):
return filtered_clusters
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
"""
Extract clusters from HTML content using hierarchical clustering.
@@ -458,7 +458,7 @@ class CosineStrategy(ExtractionStrategy):
return cluster_list
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
"""
Process sections using hierarchical clustering.
@@ -584,7 +584,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
super().__setattr__(name, value)
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
async def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML using an LLM.
@@ -628,7 +628,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
)
try:
response = perform_completion_with_backoff(
response = await aperform_completion_with_backoff(
self.llm_config.provider,
prompt_with_variables,
self.llm_config.api_token,
@@ -723,7 +723,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
)
return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
async def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
@@ -748,35 +748,11 @@ class LLMExtractionStrategy(ExtractionStrategy):
extracted_content.extend(
extract_func(ix, sanitize_input_encode(section))
)
time.sleep(0.5) # 500 ms delay between each processing
await asyncio.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
# extract_func = partial(self.extract, url)
# for ix, section in enumerate(merged_sections):
# extracted_content.append(extract_func(ix, section))
with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url)
futures = [
executor.submit(extract_func, ix, sanitize_input_encode(section))
for ix, section in enumerate(merged_sections)
]
for future in as_completed(futures):
try:
extracted_content.extend(future.result())
except Exception as e:
if self.verbose:
print(f"Error in thread execution: {e}")
# Add error information to extracted_content
extracted_content.append(
{
"index": 0,
"error": True,
"tags": ["error"],
"content": str(e),
}
)
extract_func = partial(self.extract, url)
extracted_content = await asyncio.gather(*[extract_func(ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)])
return extracted_content
@@ -797,7 +773,6 @@ class LLMExtractionStrategy(ExtractionStrategy):
f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
)
#######################################################
# New extraction strategies for JSON-based extraction #
#######################################################
@@ -846,7 +821,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
self.schema = schema
self.verbose = kwargs.get("verbose", False)
def extract(
async def extract(
self, url: str, html_content: str, *q, **kwargs
) -> List[Dict[str, Any]]:
"""
@@ -1044,7 +1019,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
print(f"Error computing field {field['name']}: {str(e)}")
return field.get("default")
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
"""
Run the extraction strategy on a combined HTML content.
@@ -1063,7 +1038,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""
combined_html = self.DEL.join(sections)
return self.extract(url, combined_html, **kwargs)
return await self.extract(url, combined_html, **kwargs)
@abstractmethod
def _get_element_text(self, element) -> str:
@@ -1086,7 +1061,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
}
@staticmethod
def generate_schema(
async def generate_schema(
html: str,
schema_type: str = "CSS", # or XPATH
query: str = None,
@@ -1112,7 +1087,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
dict: Generated schema following the JsonElementExtractionStrategy format
"""
from .prompts import JSON_SCHEMA_BUILDER
from .utils import perform_completion_with_backoff
from .utils import aperform_completion_with_backoff
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
if locals()[name] is not None:
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
@@ -1179,7 +1154,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
try:
# Call LLM with backoff handling
response = perform_completion_with_backoff(
response = await aperform_completion_with_backoff(
provider=llm_config.provider,
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
json_response = True,
@@ -1858,7 +1833,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
# ------------------------------------------------------------------ #
# Extraction
# ------------------------------------------------------------------ #
def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
async def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
# text = self._plain_text(html)
out: List[Dict[str, Any]] = []
@@ -1889,7 +1864,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
# LLM-assisted one-off pattern builder
# ------------------------------------------------------------------ #
@staticmethod
def generate_pattern(
async def generate_pattern(
label: str,
html: str,
*,
@@ -1946,7 +1921,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
user_msg = "\n\n".join(user_parts)
# ── LLM call (with retry/backoff)
resp = perform_completion_with_backoff(
resp = await aperform_completion_with_backoff(
provider=llm_config.provider,
prompt_with_variables="\n\n".join([system_msg, user_msg]),
json_response=True,

View File

@@ -1672,7 +1672,7 @@ def extract_xml_data(tags, string):
return data
def perform_completion_with_backoff(
async def aperform_completion_with_backoff(
provider,
prompt_with_variables,
api_token,
@@ -1700,7 +1700,7 @@ def perform_completion_with_backoff(
dict: The API response or an error message after all retries.
"""
from litellm import completion
from litellm import acompletion
from litellm.exceptions import RateLimitError
max_attempts = 3
@@ -1715,7 +1715,7 @@ def perform_completion_with_backoff(
for attempt in range(max_attempts):
try:
response = completion(
response = await acompletion(
model=provider,
messages=[{"role": "user", "content": prompt_with_variables}],
**extra_args,
@@ -1754,7 +1754,7 @@ def perform_completion_with_backoff(
# ]
def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
async def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
"""
Extract content blocks from website HTML using an AI provider.
@@ -1788,7 +1788,7 @@ def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_ur
"{" + variable + "}", variable_values[variable]
)
response = perform_completion_with_backoff(
response = await aperform_completion_with_backoff(
provider, prompt_with_variables, api_token, base_url=base_url
)

View File

@@ -24,7 +24,7 @@ from crawl4ai import (
RateLimiter,
LLMConfig
)
from crawl4ai.utils import perform_completion_with_backoff
from crawl4ai.utils import aperform_completion_with_backoff
from crawl4ai.content_filter_strategy import (
PruningContentFilter,
BM25ContentFilter,
@@ -88,7 +88,7 @@ async def handle_llm_qa(
Answer:"""
response = perform_completion_with_backoff(
response = await aperform_completion_with_backoff(
provider=config["llm"]["provider"],
prompt_with_variables=prompt,
api_token=os.environ.get(config["llm"].get("api_key_env", ""))

View File

@@ -3553,7 +3553,7 @@ from .utils import * # noqa: F403
from .utils import (
sanitize_html,
escape_json_string,
perform_completion_with_backoff,
aperform_completion_with_backoff,
extract_xml_data,
split_and_parse_json_objects,
sanitize_input_encode,
@@ -4162,7 +4162,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
)
try:
response = perform_completion_with_backoff(
response = await aperform_completion_with_backoff(
self.llm_config.provider,
prompt_with_variables,
self.llm_config.api_token,
@@ -4646,7 +4646,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
dict: Generated schema following the JsonElementExtractionStrategy format
"""
from .prompts import JSON_SCHEMA_BUILDER
from .utils import perform_completion_with_backoff
from .utils import aperform_completion_with_backoff
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
if locals()[name] is not None:
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
@@ -4709,7 +4709,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
try:
# Call LLM with backoff handling
response = perform_completion_with_backoff(
response = await aperform_completion_with_backoff(
provider=llm_config.provider,
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
json_response = True,
@@ -5597,7 +5597,7 @@ from bs4 import NavigableString, Comment
from .utils import (
clean_tokens,
perform_completion_with_backoff,
aperform_completion_with_backoff,
escape_json_string,
sanitize_html,
get_home_folder,
@@ -6556,7 +6556,7 @@ class LLMContentFilter(RelevantContentFilter):
tag="CHUNK",
params={"chunk_num": i + 1},
)
return perform_completion_with_backoff(
return await aperform_completion_with_backoff(
provider,
prompt,
api_token,