diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py
index 19b98522..d46d54ac 100644
--- a/crawl4ai/async_webcrawler.py
+++ b/crawl4ai/async_webcrawler.py
@@ -607,7 +607,7 @@ class AsyncWebCrawler:
                     else config.chunking_strategy
                 )
                 sections = chunking.chunk(content)
-                extracted_content = config.extraction_strategy.run(url, sections)
+                extracted_content = await config.extraction_strategy.run(url, sections)
                 extracted_content = json.dumps(
                     extracted_content, indent=4, default=str, ensure_ascii=False
                 )
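With run awaited inside the crawler's existing async pipeline, the public entry points keep their signatures. A minimal caller sketch (the URL is a placeholder; arun and result.extracted_content are the existing public API):

    import asyncio
    from crawl4ai import AsyncWebCrawler

    async def main():
        async with AsyncWebCrawler() as crawler:
            # extracted_content is populated when the run config
            # carries an extraction_strategy; it is None otherwise.
            result = await crawler.arun(url="https://example.com")
            print(result.extracted_content)

    asyncio.run(main())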
""" - def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: + async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: """ Extract meaningful blocks or chunks from the given HTML. """ return [{"index": 0, "content": html}] - def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: + async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: return [ {"index": i, "tags": [], "content": section} for i, section in enumerate(sections) ] - ####################################################### # Strategies using clustering for text data extraction # ####################################################### @@ -386,7 +386,7 @@ class CosineStrategy(ExtractionStrategy): return filtered_clusters - def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: + async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: """ Extract clusters from HTML content using hierarchical clustering. @@ -458,7 +458,7 @@ class CosineStrategy(ExtractionStrategy): return cluster_list - def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: + async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: """ Process sections using hierarchical clustering. @@ -584,7 +584,7 @@ class LLMExtractionStrategy(ExtractionStrategy): super().__setattr__(name, value) - def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]: + async def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]: """ Extract meaningful blocks or chunks from the given HTML using an LLM. @@ -628,7 +628,7 @@ class LLMExtractionStrategy(ExtractionStrategy): ) try: - response = perform_completion_with_backoff( + response = await aperform_completion_with_backoff( self.llm_config.provider, prompt_with_variables, self.llm_config.api_token, @@ -723,7 +723,7 @@ class LLMExtractionStrategy(ExtractionStrategy): ) return sections - def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]: + async def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]: """ Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy. 
@@ -386,7 +386,7 @@ class CosineStrategy(ExtractionStrategy):
 
         return filtered_clusters
 
-    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
+    async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
         """
         Extract clusters from HTML content using hierarchical clustering.
 
@@ -458,7 +458,7 @@ class CosineStrategy(ExtractionStrategy):
 
         return cluster_list
 
-    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
+    async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
         """
         Process sections using hierarchical clustering.
 
@@ -584,7 +584,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
 
         super().__setattr__(name, value)
 
-    def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
+    async def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
         """
         Extract meaningful blocks or chunks from the given HTML using an LLM.
 
@@ -628,7 +628,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
         )
 
         try:
-            response = perform_completion_with_backoff(
+            response = await aperform_completion_with_backoff(
                 self.llm_config.provider,
                 prompt_with_variables,
                 self.llm_config.api_token,
@@ -723,7 +723,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
             )
         return sections
 
-    def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
+    async def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
         """
         Process sections sequentially with a delay for rate limiting issues,
         specifically for LLMExtractionStrategy.
@@ -748,35 +748,19 @@ class LLMExtractionStrategy(ExtractionStrategy):
-                extracted_content.extend(
-                    extract_func(ix, sanitize_input_encode(section))
-                )
-                time.sleep(0.5)  # 500 ms delay between each processing
+                extracted_content.extend(
+                    # extract() is a coroutine now, so the sequential path must await it
+                    await extract_func(ix, sanitize_input_encode(section))
+                )
+                await asyncio.sleep(0.5)  # 500 ms delay between each processing
         else:
-            # Parallel processing using ThreadPoolExecutor
-            # extract_func = partial(self.extract, url)
-            # for ix, section in enumerate(merged_sections):
-            #     extracted_content.append(extract_func(ix, section))
-
-            with ThreadPoolExecutor(max_workers=4) as executor:
-                extract_func = partial(self.extract, url)
-                futures = [
-                    executor.submit(extract_func, ix, sanitize_input_encode(section))
-                    for ix, section in enumerate(merged_sections)
-                ]
-
-                for future in as_completed(futures):
-                    try:
-                        extracted_content.extend(future.result())
-                    except Exception as e:
-                        if self.verbose:
-                            print(f"Error in thread execution: {e}")
-                        # Add error information to extracted_content
-                        extracted_content.append(
-                            {
-                                "index": 0,
-                                "error": True,
-                                "tags": ["error"],
-                                "content": str(e),
-                            }
-                        )
+            # Parallel processing using asyncio.gather
+            extract_func = partial(self.extract, url)
+            results = await asyncio.gather(
+                *[
+                    extract_func(ix, sanitize_input_encode(section))
+                    for ix, section in enumerate(merged_sections)
+                ]
+            )
+            # extract() returns a list per section, so flatten instead of assigning
+            for result in results:
+                extracted_content.extend(result)
 
         return extracted_content
 
@@ -797,7 +773,6 @@ class LLMExtractionStrategy(ExtractionStrategy):
                 f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
             )
 
-
 #######################################################
 # New extraction strategies for JSON-based extraction #
 #######################################################
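One behavioral difference worth noting: the removed ThreadPoolExecutor path recorded a per-section error dict instead of failing the whole batch, while a bare asyncio.gather propagates the first exception. If the old fallback behavior is still wanted, a sketch using the same names as the hunk above (not part of this patch):

    results = await asyncio.gather(
        *[extract_func(ix, sanitize_input_encode(s)) for ix, s in enumerate(merged_sections)],
        return_exceptions=True,  # keep failures as values instead of raising
    )
    for result in results:
        if isinstance(result, BaseException):
            extracted_content.append(
                {"index": 0, "error": True, "tags": ["error"], "content": str(result)}
            )
        else:
            extracted_content.extend(result)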
@@ -846,7 +821,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
         self.schema = schema
         self.verbose = kwargs.get("verbose", False)
 
-    def extract(
+    async def extract(
         self, url: str, html_content: str, *q, **kwargs
     ) -> List[Dict[str, Any]]:
         """
@@ -1044,7 +1019,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
                     print(f"Error computing field {field['name']}: {str(e)}")
                 return field.get("default")
 
-    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
+    async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
         """
         Run the extraction strategy on a combined HTML content.
 
@@ -1063,7 +1038,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
         """
         combined_html = self.DEL.join(sections)
-        return self.extract(url, combined_html, **kwargs)
+        return await self.extract(url, combined_html, **kwargs)
 
     @abstractmethod
     def _get_element_text(self, element) -> str:
@@ -1086,7 +1061,7 @@
     }
 
     @staticmethod
-    def generate_schema(
+    async def generate_schema(
         html: str,
         schema_type: str = "CSS",  # or XPATH
         query: str = None,
@@ -1112,7 +1087,7 @@
             dict: Generated schema following the JsonElementExtractionStrategy format
         """
         from .prompts import JSON_SCHEMA_BUILDER
-        from .utils import perform_completion_with_backoff
+        from .utils import aperform_completion_with_backoff
         for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
             if locals()[name] is not None:
                 raise AttributeError(f"Setting '{name}' is deprecated. {message}")
@@ -1179,7 +1154,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
 
         try:
             # Call LLM with backoff handling
-            response = perform_completion_with_backoff(
+            response = await aperform_completion_with_backoff(
                 provider=llm_config.provider,
                 prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
                 json_response = True,
@@ -1858,7 +1833,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
     # ------------------------------------------------------------------ #
     # Extraction                                                         #
     # ------------------------------------------------------------------ #
-    def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
+    async def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
         # text = self._plain_text(html)
         out: List[Dict[str, Any]] = []
 
@@ -1889,7 +1864,7 @@
     # ------------------------------------------------------------------ #
     # LLM-assisted one-off pattern builder                               #
     # ------------------------------------------------------------------ #
     @staticmethod
-    def generate_pattern(
+    async def generate_pattern(
         label: str,
         html: str,
         *,
@@ -1946,7 +1921,7 @@
         user_msg = "\n\n".join(user_parts)
 
         # ── LLM call (with retry/backoff)
-        resp = perform_completion_with_backoff(
+        resp = await aperform_completion_with_backoff(
             provider=llm_config.provider,
             prompt_with_variables="\n\n".join([system_msg, user_msg]),
             json_response=True,
diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py
index d8b366d9..2c397be6 100644
--- a/crawl4ai/utils.py
+++ b/crawl4ai/utils.py
@@ -1672,7 +1672,7 @@ def extract_xml_data(tags, string):
     return data
 
 
-def perform_completion_with_backoff(
+async def aperform_completion_with_backoff(
     provider,
     prompt_with_variables,
     api_token,
@@ -1700,7 +1700,7 @@ def perform_completion_with_backoff(
         dict: The API response or an error message after all retries.
     """
 
-    from litellm import completion
+    from litellm import acompletion
     from litellm.exceptions import RateLimitError
 
     max_attempts = 3
@@ -1715,7 +1715,7 @@
 
     for attempt in range(max_attempts):
         try:
-            response = completion(
+            response = await acompletion(
                 model=provider,
                 messages=[{"role": "user", "content": prompt_with_variables}],
                 **extra_args,
@@ -1754,7 +1754,7 @@
 # ]
 
 
-def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
+async def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
     """
     Extract content blocks from website HTML using an AI provider.
 
@@ -1788,7 +1788,7 @@
             "{" + variable + "}", variable_values[variable]
         )
 
-    response = perform_completion_with_backoff(
+    response = await aperform_completion_with_backoff(
         provider, prompt_with_variables, api_token, base_url=base_url
     )
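The rename from perform_completion_with_backoff to aperform_completion_with_backoff means any stale synchronous caller now fails at import time rather than silently receiving a coroutine object. A call-site sketch (the provider id and token are placeholders; the positional signature of provider, prompt, api_token is taken from the hunks above, and the OpenAI-style response shape is litellm's documented behavior):

    from crawl4ai.utils import aperform_completion_with_backoff

    async def summarize(prompt: str, token: str) -> str:
        response = await aperform_completion_with_backoff(
            "openai/gpt-4o-mini",  # placeholder provider id
            prompt,
            token,
        )
        # litellm mirrors the OpenAI response structure
        return response.choices[0].message.content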