chore: Update configuration values for chunk token threshold, overlap rate, and minimum word threshold. Create a new example for LLMExtraction Strategy, update Dockerfile, and README

This commit is contained in:
unclecode
2024-06-19 18:32:20 +08:00
parent 3f0e265baf
commit 539263a8ba
11 changed files with 212 additions and 130 deletions

View File

@@ -21,7 +21,9 @@ PROVIDER_MODELS = {
# Chunk token threshold
CHUNK_TOKEN_THRESHOLD = 1000
CHUNK_TOKEN_THRESHOLD = 500
OVERLAP_RATE = 0.1
WORD_TOKEN_RATE = 1.3
# Threshold for the minimum number of word in a HTML tag to be considered
MIN_WORD_THRESHOLD = 5
MIN_WORD_THRESHOLD = 1

View File

@@ -79,8 +79,15 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.headless = True
if kwargs.get("user_agent"):
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
else:
# Set user agent
user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
self.options.add_argument(f"--user-agent={user_agent}")
self.options.add_argument("--no-sandbox")
self.options.add_argument("--headless")
self.options.headless = kwargs.get("headless", True)
if self.options.headless:
self.options.add_argument("--headless")
# self.options.add_argument("--disable-dev-shm-usage")
self.options.add_argument("--disable-gpu")
# self.options.add_argument("--disable-extensions")
@@ -112,10 +119,19 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# chromedriver_autoinstaller.install()
import chromedriver_autoinstaller
self.service = Service(chromedriver_autoinstaller.install())
crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver(crawl4ai_folder, False)
# self.service = Service(chromedriver_autoinstaller.install())
self.service = Service(chromedriver_path)
self.service.log_path = "NUL"
self.driver = webdriver.Chrome(service=self.service, options=self.options)
self.driver = self.execute_hook('on_driver_created', self.driver)
if kwargs.get("cookies"):
for cookie in kwargs.get("cookies"):
self.driver.add_cookie(cookie)
def set_hook(self, hook_type: str, hook: Callable):
if hook_type in self.hooks:

View File

@@ -3,12 +3,12 @@ from typing import Any, List, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import json, time
# from optimum.intel import IPEXModel
from .prompts import PROMPT_EXTRACT_BLOCKS, PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
from .prompts import *
from .config import *
from .utils import *
from functools import partial
from .model_loader import *
import math
import numpy as np
class ExtractionStrategy(ABC):
@@ -55,7 +55,9 @@ class NoExtractionStrategy(ExtractionStrategy):
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
class LLMExtractionStrategy(ExtractionStrategy):
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, instruction:str = None, **kwargs):
def __init__(self,
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None,
instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs):
"""
Initialize the strategy with clustering parameters.
@@ -67,6 +69,13 @@ class LLMExtractionStrategy(ExtractionStrategy):
self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
self.instruction = instruction
self.extract_type = extraction_type
self.schema = schema
self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD)
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
self.verbose = kwargs.get("verbose", False)
if not self.api_token:
@@ -81,10 +90,15 @@ class LLMExtractionStrategy(ExtractionStrategy):
"HTML": escape_json_string(sanitize_html(html)),
}
prompt_with_variables = PROMPT_EXTRACT_BLOCKS
if self.instruction:
variable_values["REQUEST"] = self.instruction
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
if self.extract_type == "schema":
variable_values["SCHEMA"] = json.dumps(self.schema)
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
prompt_with_variables = PROMPT_EXTRACT_BLOCKS if not self.instruction else PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
for variable in variable_values:
prompt_with_variables = prompt_with_variables.replace(
"{" + variable + "}", variable_values[variable]
@@ -112,32 +126,62 @@ class LLMExtractionStrategy(ExtractionStrategy):
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix)
return blocks
def _merge(self, documents):
def _merge(self, documents, chunk_token_threshold, overlap):
chunks = []
sections = []
total_tokens = 0
# Calculate the total tokens across all documents
for document in documents:
total_tokens += len(document.split(' ')) * self.word_token_rate
# Calculate the number of sections needed
num_sections = math.floor(total_tokens / chunk_token_threshold)
if num_sections < 1:
num_sections = 1 # Ensure there is at least one section
adjusted_chunk_threshold = total_tokens / num_sections
total_token_so_far = 0
current_chunk = []
for document in documents:
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
chunk = document.split(' ')
total_token_so_far += len(chunk) * 1.3
chunks.append(document)
else:
sections.append('\n\n'.join(chunks))
chunks = [document]
total_token_so_far = len(document.split(' ')) * 1.3
if chunks:
sections.append('\n\n'.join(chunks))
tokens = document.split(' ')
token_count = len(tokens) * self.word_token_rate
return sections
if total_token_so_far + token_count <= adjusted_chunk_threshold:
current_chunk.extend(tokens)
total_token_so_far += token_count
else:
# Ensure to handle the last section properly
if len(sections) == num_sections - 1:
current_chunk.extend(tokens)
continue
# Add overlap if specified
if overlap > 0 and current_chunk:
overlap_tokens = current_chunk[-overlap:]
current_chunk.extend(overlap_tokens)
sections.append(' '.join(current_chunk))
current_chunk = tokens
total_token_so_far = token_count
# Add the last chunk
if current_chunk:
sections.append(' '.join(current_chunk))
return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
"""
merged_sections = self._merge(sections)
merged_sections = self._merge(
sections, self.chunk_token_threshold,
overlap= int(self.chunk_token_threshold * self.overlap_rate)
)
extracted_content = []
if self.provider.startswith("groq/"):
# Sequential processing with a delay

View File

@@ -164,4 +164,35 @@ Please provide your output within <blocks> tags, like this:
**Make sure to follow the user instruction to extract blocks aligin with the instruction.**
Remember, the output should be a complete, parsable JSON wrapped in <blocks> tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""
Remember, the output should be a complete, parsable JSON wrapped in <blocks> tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""
PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION = """Here is the content from the URL:
<url>{URL}</url>
<url_content>
{HTML}
</url_content>
The user has made the following request for what information to extract from the above content:
<user_request>
{REQUEST}
</user_request>
<schema_block>
{SCHEMA}
</schema_block>
Please carefully read the URL content and the user's request. If the user provided a desired JSON schema in the <schema_block> above, extract the requested information from the URL content according to that schema. If no schema was provided, infer an appropriate JSON schema based on the user's request that will best capture the key information they are looking for.
Extraction instructions:
Return the extracted information as a list of JSON objects, with each object in the list corresponding to a block of content from the URL, in the same order as it appears on the page. Wrap the entire JSON list in <blocks> tags.
Quality Reflection:
Before outputting your final answer, double check that the JSON you are returning is complete, containing all the information requested by the user, and is valid JSON that could be parsed by json.loads() with no errors or omissions. The outputted JSON objects should fully match the schema, either provided or inferred.
Quality Score:
After reflecting, score the quality and completeness of the JSON data you are about to return on a scale of 1 to 5. Write the score inside <score> tags.
Result
Output the final list of JSON objects, wrapped in <blocks> tags."""

View File

@@ -42,7 +42,7 @@ class WebCrawler:
def warmup(self):
print("[LOG] 🌤️ Warming up the WebCrawler")
result = self.run(
url='https://crawl4ai.uccode.io/',
url='https://google.com/',
word_count_threshold=5,
extraction_strategy= NoExtractionStrategy(),
bypass_cache=False,