Compare commits

...

2 Commits

Author SHA1 Message Date
Ahmed-Tawfik94
2b2ef12e25 #1156: Refactor completion function calls to use asynchronous version 2025-05-27 15:10:34 +08:00
Ahmed-Tawfik94
d9b3db925a Refactor extraction and completion functions to support asynchronous execution 2025-05-26 16:01:38 +08:00
6 changed files with 47 additions and 72 deletions

View File

@@ -607,7 +607,7 @@ class AsyncWebCrawler:
else config.chunking_strategy else config.chunking_strategy
) )
sections = chunking.chunk(content) sections = chunking.chunk(content)
extracted_content = config.extraction_strategy.run(url, sections) extracted_content = await config.extraction_strategy.run(url, sections)
extracted_content = json.dumps( extracted_content = json.dumps(
extracted_content, indent=4, default=str, ensure_ascii=False extracted_content, indent=4, default=str, ensure_ascii=False
) )

View File

@@ -9,7 +9,7 @@ from bs4 import NavigableString, Comment
from .utils import ( from .utils import (
clean_tokens, clean_tokens,
perform_completion_with_backoff, aperform_completion_with_backoff,
escape_json_string, escape_json_string,
sanitize_html, sanitize_html,
get_home_folder, get_home_folder,
@@ -953,7 +953,7 @@ class LLMContentFilter(RelevantContentFilter):
for var, value in prompt_variables.items(): for var, value in prompt_variables.items():
prompt = prompt.replace("{" + var + "}", value) prompt = prompt.replace("{" + var + "}", value)
def _proceed_with_chunk( async def _proceed_with_chunk(
provider: str, provider: str,
prompt: str, prompt: str,
api_token: str, api_token: str,
@@ -966,7 +966,7 @@ class LLMContentFilter(RelevantContentFilter):
tag="CHUNK", tag="CHUNK",
params={"chunk_num": i + 1}, params={"chunk_num": i + 1},
) )
return perform_completion_with_backoff( return await aperform_completion_with_backoff(
provider, provider,
prompt, prompt,
api_token, api_token,

View File

@@ -3,6 +3,7 @@ import inspect
from typing import Any, List, Dict, Optional, Tuple, Pattern, Union from typing import Any, List, Dict, Optional, Tuple, Pattern, Union
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import json import json
import asyncio
import time import time
from enum import IntFlag, auto from enum import IntFlag, auto
@@ -19,7 +20,7 @@ from .utils import * # noqa: F403
from .utils import ( from .utils import (
sanitize_html, sanitize_html,
escape_json_string, escape_json_string,
perform_completion_with_backoff, aperform_completion_with_backoff,
extract_xml_data, extract_xml_data,
split_and_parse_json_objects, split_and_parse_json_objects,
sanitize_input_encode, sanitize_input_encode,
@@ -66,7 +67,7 @@ class ExtractionStrategy(ABC):
self.verbose = kwargs.get("verbose", False) self.verbose = kwargs.get("verbose", False)
@abstractmethod @abstractmethod
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Extract meaningful blocks or chunks from the given HTML. Extract meaningful blocks or chunks from the given HTML.
@@ -76,7 +77,7 @@ class ExtractionStrategy(ABC):
""" """
pass pass
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Process sections of text in parallel by default. Process sections of text in parallel by default.
@@ -85,13 +86,13 @@ class ExtractionStrategy(ABC):
:return: A list of processed JSON blocks. :return: A list of processed JSON blocks.
""" """
extracted_content = [] extracted_content = []
with ThreadPoolExecutor() as executor: tasks = [
futures = [ asyncio.create_task(self.extract(url, section, **kwargs))
executor.submit(self.extract, url, section, **kwargs) for section in sections
for section in sections ]
] results = await asyncio.gather(*tasks)
for future in as_completed(futures): for result in results:
extracted_content.extend(future.result()) extracted_content.extend(result)
return extracted_content return extracted_content
@@ -100,19 +101,18 @@ class NoExtractionStrategy(ExtractionStrategy):
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block. A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
""" """
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Extract meaningful blocks or chunks from the given HTML. Extract meaningful blocks or chunks from the given HTML.
""" """
return [{"index": 0, "content": html}] return [{"index": 0, "content": html}]
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
return [ return [
{"index": i, "tags": [], "content": section} {"index": i, "tags": [], "content": section}
for i, section in enumerate(sections) for i, section in enumerate(sections)
] ]
####################################################### #######################################################
# Strategies using clustering for text data extraction # # Strategies using clustering for text data extraction #
####################################################### #######################################################
@@ -386,7 +386,7 @@ class CosineStrategy(ExtractionStrategy):
return filtered_clusters return filtered_clusters
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Extract clusters from HTML content using hierarchical clustering. Extract clusters from HTML content using hierarchical clustering.
@@ -458,7 +458,7 @@ class CosineStrategy(ExtractionStrategy):
return cluster_list return cluster_list
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Process sections using hierarchical clustering. Process sections using hierarchical clustering.
@@ -584,7 +584,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
super().__setattr__(name, value) super().__setattr__(name, value)
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]: async def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
""" """
Extract meaningful blocks or chunks from the given HTML using an LLM. Extract meaningful blocks or chunks from the given HTML using an LLM.
@@ -628,7 +628,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
) )
try: try:
response = perform_completion_with_backoff( response = await aperform_completion_with_backoff(
self.llm_config.provider, self.llm_config.provider,
prompt_with_variables, prompt_with_variables,
self.llm_config.api_token, self.llm_config.api_token,
@@ -723,7 +723,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
) )
return sections return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]: async def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
""" """
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy. Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
@@ -748,35 +748,11 @@ class LLMExtractionStrategy(ExtractionStrategy):
extracted_content.extend( extracted_content.extend(
extract_func(ix, sanitize_input_encode(section)) extract_func(ix, sanitize_input_encode(section))
) )
time.sleep(0.5) # 500 ms delay between each processing await asyncio.sleep(0.5) # 500 ms delay between each processing
else: else:
# Parallel processing using ThreadPoolExecutor # Parallel processing using ThreadPoolExecutor
# extract_func = partial(self.extract, url) extract_func = partial(self.extract, url)
# for ix, section in enumerate(merged_sections): extracted_content = await asyncio.gather(*[extract_func(ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)])
# extracted_content.append(extract_func(ix, section))
with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url)
futures = [
executor.submit(extract_func, ix, sanitize_input_encode(section))
for ix, section in enumerate(merged_sections)
]
for future in as_completed(futures):
try:
extracted_content.extend(future.result())
except Exception as e:
if self.verbose:
print(f"Error in thread execution: {e}")
# Add error information to extracted_content
extracted_content.append(
{
"index": 0,
"error": True,
"tags": ["error"],
"content": str(e),
}
)
return extracted_content return extracted_content
@@ -797,7 +773,6 @@ class LLMExtractionStrategy(ExtractionStrategy):
f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}" f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
) )
####################################################### #######################################################
# New extraction strategies for JSON-based extraction # # New extraction strategies for JSON-based extraction #
####################################################### #######################################################
@@ -846,7 +821,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
self.schema = schema self.schema = schema
self.verbose = kwargs.get("verbose", False) self.verbose = kwargs.get("verbose", False)
def extract( async def extract(
self, url: str, html_content: str, *q, **kwargs self, url: str, html_content: str, *q, **kwargs
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -1044,7 +1019,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
print(f"Error computing field {field['name']}: {str(e)}") print(f"Error computing field {field['name']}: {str(e)}")
return field.get("default") return field.get("default")
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Run the extraction strategy on a combined HTML content. Run the extraction strategy on a combined HTML content.
@@ -1063,7 +1038,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
""" """
combined_html = self.DEL.join(sections) combined_html = self.DEL.join(sections)
return self.extract(url, combined_html, **kwargs) return await self.extract(url, combined_html, **kwargs)
@abstractmethod @abstractmethod
def _get_element_text(self, element) -> str: def _get_element_text(self, element) -> str:
@@ -1086,7 +1061,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
} }
@staticmethod @staticmethod
def generate_schema( async def generate_schema(
html: str, html: str,
schema_type: str = "CSS", # or XPATH schema_type: str = "CSS", # or XPATH
query: str = None, query: str = None,
@@ -1112,7 +1087,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
dict: Generated schema following the JsonElementExtractionStrategy format dict: Generated schema following the JsonElementExtractionStrategy format
""" """
from .prompts import JSON_SCHEMA_BUILDER from .prompts import JSON_SCHEMA_BUILDER
from .utils import perform_completion_with_backoff from .utils import aperform_completion_with_backoff
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items(): for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
if locals()[name] is not None: if locals()[name] is not None:
raise AttributeError(f"Setting '{name}' is deprecated. {message}") raise AttributeError(f"Setting '{name}' is deprecated. {message}")
@@ -1179,7 +1154,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
try: try:
# Call LLM with backoff handling # Call LLM with backoff handling
response = perform_completion_with_backoff( response = await aperform_completion_with_backoff(
provider=llm_config.provider, provider=llm_config.provider,
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]), prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
json_response = True, json_response = True,
@@ -1858,7 +1833,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Extraction # Extraction
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]: async def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
# text = self._plain_text(html) # text = self._plain_text(html)
out: List[Dict[str, Any]] = [] out: List[Dict[str, Any]] = []
@@ -1889,7 +1864,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
# LLM-assisted one-off pattern builder # LLM-assisted one-off pattern builder
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
@staticmethod @staticmethod
def generate_pattern( async def generate_pattern(
label: str, label: str,
html: str, html: str,
*, *,
@@ -1946,7 +1921,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
user_msg = "\n\n".join(user_parts) user_msg = "\n\n".join(user_parts)
# ── LLM call (with retry/backoff) # ── LLM call (with retry/backoff)
resp = perform_completion_with_backoff( resp = await aperform_completion_with_backoff(
provider=llm_config.provider, provider=llm_config.provider,
prompt_with_variables="\n\n".join([system_msg, user_msg]), prompt_with_variables="\n\n".join([system_msg, user_msg]),
json_response=True, json_response=True,

View File

@@ -1672,7 +1672,7 @@ def extract_xml_data(tags, string):
return data return data
def perform_completion_with_backoff( async def aperform_completion_with_backoff(
provider, provider,
prompt_with_variables, prompt_with_variables,
api_token, api_token,
@@ -1700,7 +1700,7 @@ def perform_completion_with_backoff(
dict: The API response or an error message after all retries. dict: The API response or an error message after all retries.
""" """
from litellm import completion from litellm import acompletion
from litellm.exceptions import RateLimitError from litellm.exceptions import RateLimitError
max_attempts = 3 max_attempts = 3
@@ -1715,7 +1715,7 @@ def perform_completion_with_backoff(
for attempt in range(max_attempts): for attempt in range(max_attempts):
try: try:
response = completion( response = await acompletion(
model=provider, model=provider,
messages=[{"role": "user", "content": prompt_with_variables}], messages=[{"role": "user", "content": prompt_with_variables}],
**extra_args, **extra_args,
@@ -1754,7 +1754,7 @@ def perform_completion_with_backoff(
# ] # ]
def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None): async def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
""" """
Extract content blocks from website HTML using an AI provider. Extract content blocks from website HTML using an AI provider.
@@ -1788,7 +1788,7 @@ def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_ur
"{" + variable + "}", variable_values[variable] "{" + variable + "}", variable_values[variable]
) )
response = perform_completion_with_backoff( response = await aperform_completion_with_backoff(
provider, prompt_with_variables, api_token, base_url=base_url provider, prompt_with_variables, api_token, base_url=base_url
) )

View File

@@ -24,7 +24,7 @@ from crawl4ai import (
RateLimiter, RateLimiter,
LLMConfig LLMConfig
) )
from crawl4ai.utils import perform_completion_with_backoff from crawl4ai.utils import aperform_completion_with_backoff
from crawl4ai.content_filter_strategy import ( from crawl4ai.content_filter_strategy import (
PruningContentFilter, PruningContentFilter,
BM25ContentFilter, BM25ContentFilter,
@@ -88,7 +88,7 @@ async def handle_llm_qa(
Answer:""" Answer:"""
response = perform_completion_with_backoff( response = await aperform_completion_with_backoff(
provider=config["llm"]["provider"], provider=config["llm"]["provider"],
prompt_with_variables=prompt, prompt_with_variables=prompt,
api_token=os.environ.get(config["llm"].get("api_key_env", "")) api_token=os.environ.get(config["llm"].get("api_key_env", ""))

View File

@@ -3553,7 +3553,7 @@ from .utils import * # noqa: F403
from .utils import ( from .utils import (
sanitize_html, sanitize_html,
escape_json_string, escape_json_string,
perform_completion_with_backoff, aperform_completion_with_backoff,
extract_xml_data, extract_xml_data,
split_and_parse_json_objects, split_and_parse_json_objects,
sanitize_input_encode, sanitize_input_encode,
@@ -4162,7 +4162,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
) )
try: try:
response = perform_completion_with_backoff( response = await aperform_completion_with_backoff(
self.llm_config.provider, self.llm_config.provider,
prompt_with_variables, prompt_with_variables,
self.llm_config.api_token, self.llm_config.api_token,
@@ -4646,7 +4646,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
dict: Generated schema following the JsonElementExtractionStrategy format dict: Generated schema following the JsonElementExtractionStrategy format
""" """
from .prompts import JSON_SCHEMA_BUILDER from .prompts import JSON_SCHEMA_BUILDER
from .utils import perform_completion_with_backoff from .utils import aperform_completion_with_backoff
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items(): for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
if locals()[name] is not None: if locals()[name] is not None:
raise AttributeError(f"Setting '{name}' is deprecated. {message}") raise AttributeError(f"Setting '{name}' is deprecated. {message}")
@@ -4709,7 +4709,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
try: try:
# Call LLM with backoff handling # Call LLM with backoff handling
response = perform_completion_with_backoff( response = await aperform_completion_with_backoff(
provider=llm_config.provider, provider=llm_config.provider,
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]), prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
json_response = True, json_response = True,
@@ -5597,7 +5597,7 @@ from bs4 import NavigableString, Comment
from .utils import ( from .utils import (
clean_tokens, clean_tokens,
perform_completion_with_backoff, aperform_completion_with_backoff,
escape_json_string, escape_json_string,
sanitize_html, sanitize_html,
get_home_folder, get_home_folder,
@@ -6556,7 +6556,7 @@ class LLMContentFilter(RelevantContentFilter):
tag="CHUNK", tag="CHUNK",
params={"chunk_num": i + 1}, params={"chunk_num": i + 1},
) )
return perform_completion_with_backoff( return await aperform_completion_with_backoff(
provider, provider,
prompt, prompt,
api_token, api_token,