Compare commits
2 Commits
feature/ma
...
feature/as
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b2ef12e25 | ||
|
|
d9b3db925a |
@@ -607,7 +607,7 @@ class AsyncWebCrawler:
|
||||
else config.chunking_strategy
|
||||
)
|
||||
sections = chunking.chunk(content)
|
||||
extracted_content = config.extraction_strategy.run(url, sections)
|
||||
extracted_content = await config.extraction_strategy.run(url, sections)
|
||||
extracted_content = json.dumps(
|
||||
extracted_content, indent=4, default=str, ensure_ascii=False
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from bs4 import NavigableString, Comment
|
||||
|
||||
from .utils import (
|
||||
clean_tokens,
|
||||
perform_completion_with_backoff,
|
||||
aperform_completion_with_backoff,
|
||||
escape_json_string,
|
||||
sanitize_html,
|
||||
get_home_folder,
|
||||
@@ -953,7 +953,7 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
for var, value in prompt_variables.items():
|
||||
prompt = prompt.replace("{" + var + "}", value)
|
||||
|
||||
def _proceed_with_chunk(
|
||||
async def _proceed_with_chunk(
|
||||
provider: str,
|
||||
prompt: str,
|
||||
api_token: str,
|
||||
@@ -966,7 +966,7 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
tag="CHUNK",
|
||||
params={"chunk_num": i + 1},
|
||||
)
|
||||
return perform_completion_with_backoff(
|
||||
return await aperform_completion_with_backoff(
|
||||
provider,
|
||||
prompt,
|
||||
api_token,
|
||||
|
||||
@@ -3,6 +3,7 @@ import inspect
|
||||
from typing import Any, List, Dict, Optional, Tuple, Pattern, Union
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
from enum import IntFlag, auto
|
||||
|
||||
@@ -19,7 +20,7 @@ from .utils import * # noqa: F403
|
||||
from .utils import (
|
||||
sanitize_html,
|
||||
escape_json_string,
|
||||
perform_completion_with_backoff,
|
||||
aperform_completion_with_backoff,
|
||||
extract_xml_data,
|
||||
split_and_parse_json_objects,
|
||||
sanitize_input_encode,
|
||||
@@ -66,7 +67,7 @@ class ExtractionStrategy(ABC):
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract meaningful blocks or chunks from the given HTML.
|
||||
|
||||
@@ -76,7 +77,7 @@ class ExtractionStrategy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process sections of text in parallel by default.
|
||||
|
||||
@@ -85,13 +86,13 @@ class ExtractionStrategy(ABC):
|
||||
:return: A list of processed JSON blocks.
|
||||
"""
|
||||
extracted_content = []
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.extract, url, section, **kwargs)
|
||||
for section in sections
|
||||
]
|
||||
for future in as_completed(futures):
|
||||
extracted_content.extend(future.result())
|
||||
tasks = [
|
||||
asyncio.create_task(self.extract(url, section, **kwargs))
|
||||
for section in sections
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
for result in results:
|
||||
extracted_content.extend(result)
|
||||
return extracted_content
|
||||
|
||||
|
||||
@@ -100,19 +101,18 @@ class NoExtractionStrategy(ExtractionStrategy):
|
||||
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
|
||||
"""
|
||||
|
||||
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract meaningful blocks or chunks from the given HTML.
|
||||
"""
|
||||
return [{"index": 0, "content": html}]
|
||||
|
||||
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"index": i, "tags": [], "content": section}
|
||||
for i, section in enumerate(sections)
|
||||
]
|
||||
|
||||
|
||||
#######################################################
|
||||
# Strategies using clustering for text data extraction #
|
||||
#######################################################
|
||||
@@ -386,7 +386,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
|
||||
return filtered_clusters
|
||||
|
||||
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract clusters from HTML content using hierarchical clustering.
|
||||
|
||||
@@ -458,7 +458,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
|
||||
return cluster_list
|
||||
|
||||
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process sections using hierarchical clustering.
|
||||
|
||||
@@ -584,7 +584,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
|
||||
async def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract meaningful blocks or chunks from the given HTML using an LLM.
|
||||
|
||||
@@ -628,7 +628,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
)
|
||||
|
||||
try:
|
||||
response = perform_completion_with_backoff(
|
||||
response = await aperform_completion_with_backoff(
|
||||
self.llm_config.provider,
|
||||
prompt_with_variables,
|
||||
self.llm_config.api_token,
|
||||
@@ -723,7 +723,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
)
|
||||
return sections
|
||||
|
||||
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
|
||||
async def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
|
||||
|
||||
@@ -748,35 +748,11 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
extracted_content.extend(
|
||||
extract_func(ix, sanitize_input_encode(section))
|
||||
)
|
||||
time.sleep(0.5) # 500 ms delay between each processing
|
||||
await asyncio.sleep(0.5) # 500 ms delay between each processing
|
||||
else:
|
||||
# Parallel processing using ThreadPoolExecutor
|
||||
# extract_func = partial(self.extract, url)
|
||||
# for ix, section in enumerate(merged_sections):
|
||||
# extracted_content.append(extract_func(ix, section))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=4) as executor:
|
||||
extract_func = partial(self.extract, url)
|
||||
futures = [
|
||||
executor.submit(extract_func, ix, sanitize_input_encode(section))
|
||||
for ix, section in enumerate(merged_sections)
|
||||
]
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
extracted_content.extend(future.result())
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error in thread execution: {e}")
|
||||
# Add error information to extracted_content
|
||||
extracted_content.append(
|
||||
{
|
||||
"index": 0,
|
||||
"error": True,
|
||||
"tags": ["error"],
|
||||
"content": str(e),
|
||||
}
|
||||
)
|
||||
extract_func = partial(self.extract, url)
|
||||
extracted_content = await asyncio.gather(*[extract_func(ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)])
|
||||
|
||||
return extracted_content
|
||||
|
||||
@@ -797,7 +773,6 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
|
||||
)
|
||||
|
||||
|
||||
#######################################################
|
||||
# New extraction strategies for JSON-based extraction #
|
||||
#######################################################
|
||||
@@ -846,7 +821,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
self.schema = schema
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
def extract(
|
||||
async def extract(
|
||||
self, url: str, html_content: str, *q, **kwargs
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -1044,7 +1019,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
print(f"Error computing field {field['name']}: {str(e)}")
|
||||
return field.get("default")
|
||||
|
||||
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
async def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run the extraction strategy on a combined HTML content.
|
||||
|
||||
@@ -1063,7 +1038,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
"""
|
||||
|
||||
combined_html = self.DEL.join(sections)
|
||||
return self.extract(url, combined_html, **kwargs)
|
||||
return await self.extract(url, combined_html, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _get_element_text(self, element) -> str:
|
||||
@@ -1086,7 +1061,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def generate_schema(
|
||||
async def generate_schema(
|
||||
html: str,
|
||||
schema_type: str = "CSS", # or XPATH
|
||||
query: str = None,
|
||||
@@ -1112,7 +1087,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
dict: Generated schema following the JsonElementExtractionStrategy format
|
||||
"""
|
||||
from .prompts import JSON_SCHEMA_BUILDER
|
||||
from .utils import perform_completion_with_backoff
|
||||
from .utils import aperform_completion_with_backoff
|
||||
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
|
||||
if locals()[name] is not None:
|
||||
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
|
||||
@@ -1179,7 +1154,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
|
||||
|
||||
try:
|
||||
# Call LLM with backoff handling
|
||||
response = perform_completion_with_backoff(
|
||||
response = await aperform_completion_with_backoff(
|
||||
provider=llm_config.provider,
|
||||
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
|
||||
json_response = True,
|
||||
@@ -1858,7 +1833,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
|
||||
# ------------------------------------------------------------------ #
|
||||
# Extraction
|
||||
# ------------------------------------------------------------------ #
|
||||
def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
|
||||
async def extract(self, url: str, content: str, *q, **kw) -> List[Dict[str, Any]]:
|
||||
# text = self._plain_text(html)
|
||||
out: List[Dict[str, Any]] = []
|
||||
|
||||
@@ -1889,7 +1864,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
|
||||
# LLM-assisted one-off pattern builder
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def generate_pattern(
|
||||
async def generate_pattern(
|
||||
label: str,
|
||||
html: str,
|
||||
*,
|
||||
@@ -1946,7 +1921,7 @@ class RegexExtractionStrategy(ExtractionStrategy):
|
||||
user_msg = "\n\n".join(user_parts)
|
||||
|
||||
# ── LLM call (with retry/backoff)
|
||||
resp = perform_completion_with_backoff(
|
||||
resp = await aperform_completion_with_backoff(
|
||||
provider=llm_config.provider,
|
||||
prompt_with_variables="\n\n".join([system_msg, user_msg]),
|
||||
json_response=True,
|
||||
|
||||
@@ -1672,7 +1672,7 @@ def extract_xml_data(tags, string):
|
||||
return data
|
||||
|
||||
|
||||
def perform_completion_with_backoff(
|
||||
async def aperform_completion_with_backoff(
|
||||
provider,
|
||||
prompt_with_variables,
|
||||
api_token,
|
||||
@@ -1700,7 +1700,7 @@ def perform_completion_with_backoff(
|
||||
dict: The API response or an error message after all retries.
|
||||
"""
|
||||
|
||||
from litellm import completion
|
||||
from litellm import acompletion
|
||||
from litellm.exceptions import RateLimitError
|
||||
|
||||
max_attempts = 3
|
||||
@@ -1715,7 +1715,7 @@ def perform_completion_with_backoff(
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = completion(
|
||||
response = await acompletion(
|
||||
model=provider,
|
||||
messages=[{"role": "user", "content": prompt_with_variables}],
|
||||
**extra_args,
|
||||
@@ -1754,7 +1754,7 @@ def perform_completion_with_backoff(
|
||||
# ]
|
||||
|
||||
|
||||
def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
|
||||
async def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None):
|
||||
"""
|
||||
Extract content blocks from website HTML using an AI provider.
|
||||
|
||||
@@ -1788,7 +1788,7 @@ def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_ur
|
||||
"{" + variable + "}", variable_values[variable]
|
||||
)
|
||||
|
||||
response = perform_completion_with_backoff(
|
||||
response = await aperform_completion_with_backoff(
|
||||
provider, prompt_with_variables, api_token, base_url=base_url
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from crawl4ai import (
|
||||
RateLimiter,
|
||||
LLMConfig
|
||||
)
|
||||
from crawl4ai.utils import perform_completion_with_backoff
|
||||
from crawl4ai.utils import aperform_completion_with_backoff
|
||||
from crawl4ai.content_filter_strategy import (
|
||||
PruningContentFilter,
|
||||
BM25ContentFilter,
|
||||
@@ -88,7 +88,7 @@ async def handle_llm_qa(
|
||||
|
||||
Answer:"""
|
||||
|
||||
response = perform_completion_with_backoff(
|
||||
response = await aperform_completion_with_backoff(
|
||||
provider=config["llm"]["provider"],
|
||||
prompt_with_variables=prompt,
|
||||
api_token=os.environ.get(config["llm"].get("api_key_env", ""))
|
||||
|
||||
@@ -3553,7 +3553,7 @@ from .utils import * # noqa: F403
|
||||
from .utils import (
|
||||
sanitize_html,
|
||||
escape_json_string,
|
||||
perform_completion_with_backoff,
|
||||
aperform_completion_with_backoff,
|
||||
extract_xml_data,
|
||||
split_and_parse_json_objects,
|
||||
sanitize_input_encode,
|
||||
@@ -4162,7 +4162,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
)
|
||||
|
||||
try:
|
||||
response = perform_completion_with_backoff(
|
||||
response = await aperform_completion_with_backoff(
|
||||
self.llm_config.provider,
|
||||
prompt_with_variables,
|
||||
self.llm_config.api_token,
|
||||
@@ -4646,7 +4646,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
dict: Generated schema following the JsonElementExtractionStrategy format
|
||||
"""
|
||||
from .prompts import JSON_SCHEMA_BUILDER
|
||||
from .utils import perform_completion_with_backoff
|
||||
from .utils import aperform_completion_with_backoff
|
||||
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
|
||||
if locals()[name] is not None:
|
||||
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
|
||||
@@ -4709,7 +4709,7 @@ In this scenario, use your best judgment to generate the schema. You need to exa
|
||||
|
||||
try:
|
||||
# Call LLM with backoff handling
|
||||
response = perform_completion_with_backoff(
|
||||
response = await aperform_completion_with_backoff(
|
||||
provider=llm_config.provider,
|
||||
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
|
||||
json_response = True,
|
||||
@@ -5597,7 +5597,7 @@ from bs4 import NavigableString, Comment
|
||||
|
||||
from .utils import (
|
||||
clean_tokens,
|
||||
perform_completion_with_backoff,
|
||||
aperform_completion_with_backoff,
|
||||
escape_json_string,
|
||||
sanitize_html,
|
||||
get_home_folder,
|
||||
@@ -6556,7 +6556,7 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
tag="CHUNK",
|
||||
params={"chunk_num": i + 1},
|
||||
)
|
||||
return perform_completion_with_backoff(
|
||||
return await aperform_completion_with_backoff(
|
||||
provider,
|
||||
prompt,
|
||||
api_token,
|
||||
|
||||
Reference in New Issue
Block a user