Enhance crawler capabilities and documentation
- Add llm.txt generator - Added SSL certificate extraction in AsyncWebCrawler. - Introduced new content filters and chunking strategies for more robust data extraction. - Updated documentation.
This commit is contained in:
@@ -207,6 +207,8 @@ class CrawlerRunConfig:
|
||||
Default: None.
|
||||
excluded_tags (list of str or None): List of HTML tags to exclude from processing.
|
||||
Default: None.
|
||||
excluded_selector (str or None): CSS selector to exclude from processing.
|
||||
Default: None.
|
||||
keep_data_attributes (bool): If True, retain `data-*` attributes while removing unwanted attributes.
|
||||
Default: False.
|
||||
remove_forms (bool): If True, remove all `<form>` elements from the HTML.
|
||||
@@ -316,10 +318,14 @@ class CrawlerRunConfig:
|
||||
only_text: bool = False,
|
||||
css_selector: str = None,
|
||||
excluded_tags: list = None,
|
||||
excluded_selector: str = None,
|
||||
keep_data_attributes: bool = False,
|
||||
remove_forms: bool = False,
|
||||
prettiify: bool = False,
|
||||
|
||||
# SSL Parameters
|
||||
fetch_ssl_certificate: bool = False,
|
||||
|
||||
# Caching Parameters
|
||||
cache_mode=None,
|
||||
session_id: str = None,
|
||||
@@ -383,10 +389,14 @@ class CrawlerRunConfig:
|
||||
self.only_text = only_text
|
||||
self.css_selector = css_selector
|
||||
self.excluded_tags = excluded_tags or []
|
||||
self.excluded_selector = excluded_selector or ""
|
||||
self.keep_data_attributes = keep_data_attributes
|
||||
self.remove_forms = remove_forms
|
||||
self.prettiify = prettiify
|
||||
|
||||
# SSL Parameters
|
||||
self.fetch_ssl_certificate = fetch_ssl_certificate
|
||||
|
||||
# Caching Parameters
|
||||
self.cache_mode = cache_mode
|
||||
self.session_id = session_id
|
||||
@@ -464,10 +474,14 @@ class CrawlerRunConfig:
|
||||
only_text=kwargs.get("only_text", False),
|
||||
css_selector=kwargs.get("css_selector"),
|
||||
excluded_tags=kwargs.get("excluded_tags", []),
|
||||
excluded_selector=kwargs.get("excluded_selector", ""),
|
||||
keep_data_attributes=kwargs.get("keep_data_attributes", False),
|
||||
remove_forms=kwargs.get("remove_forms", False),
|
||||
prettiify=kwargs.get("prettiify", False),
|
||||
|
||||
# SSL Parameters
|
||||
fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
|
||||
|
||||
# Caching Parameters
|
||||
cache_mode=kwargs.get("cache_mode"),
|
||||
session_id=kwargs.get("session_id"),
|
||||
@@ -521,70 +535,59 @@ class CrawlerRunConfig:
|
||||
url=kwargs.get("url"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
# @staticmethod
|
||||
# def from_kwargs(kwargs: dict) -> "CrawlerRunConfig":
|
||||
# return CrawlerRunConfig(
|
||||
# word_count_threshold=kwargs.get("word_count_threshold", 200),
|
||||
# extraction_strategy=kwargs.get("extraction_strategy"),
|
||||
# chunking_strategy=kwargs.get("chunking_strategy"),
|
||||
# markdown_generator=kwargs.get("markdown_generator"),
|
||||
# content_filter=kwargs.get("content_filter"),
|
||||
# cache_mode=kwargs.get("cache_mode"),
|
||||
# session_id=kwargs.get("session_id"),
|
||||
# bypass_cache=kwargs.get("bypass_cache", False),
|
||||
# disable_cache=kwargs.get("disable_cache", False),
|
||||
# no_cache_read=kwargs.get("no_cache_read", False),
|
||||
# no_cache_write=kwargs.get("no_cache_write", False),
|
||||
# css_selector=kwargs.get("css_selector"),
|
||||
# screenshot=kwargs.get("screenshot", False),
|
||||
# pdf=kwargs.get("pdf", False),
|
||||
# verbose=kwargs.get("verbose", True),
|
||||
# only_text=kwargs.get("only_text", False),
|
||||
# image_description_min_word_threshold=kwargs.get(
|
||||
# "image_description_min_word_threshold",
|
||||
# IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||
# ),
|
||||
# prettiify=kwargs.get("prettiify", False),
|
||||
# js_code=kwargs.get(
|
||||
# "js_code"
|
||||
# ), # If not provided here, will default inside constructor
|
||||
# wait_for=kwargs.get("wait_for"),
|
||||
# js_only=kwargs.get("js_only", False),
|
||||
# wait_until=kwargs.get("wait_until", "domcontentloaded"),
|
||||
# page_timeout=kwargs.get("page_timeout", 60000),
|
||||
# ignore_body_visibility=kwargs.get("ignore_body_visibility", True),
|
||||
# adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
|
||||
# scan_full_page=kwargs.get("scan_full_page", False),
|
||||
# scroll_delay=kwargs.get("scroll_delay", 0.2),
|
||||
# process_iframes=kwargs.get("process_iframes", False),
|
||||
# remove_overlay_elements=kwargs.get("remove_overlay_elements", False),
|
||||
# delay_before_return_html=kwargs.get("delay_before_return_html", 0.1),
|
||||
# log_console=kwargs.get("log_console", False),
|
||||
# simulate_user=kwargs.get("simulate_user", False),
|
||||
# override_navigator=kwargs.get("override_navigator", False),
|
||||
# magic=kwargs.get("magic", False),
|
||||
# screenshot_wait_for=kwargs.get("screenshot_wait_for"),
|
||||
# screenshot_height_threshold=kwargs.get(
|
||||
# "screenshot_height_threshold", 20000
|
||||
# ),
|
||||
# mean_delay=kwargs.get("mean_delay", 0.1),
|
||||
# max_range=kwargs.get("max_range", 0.3),
|
||||
# semaphore_count=kwargs.get("semaphore_count", 5),
|
||||
# image_score_threshold=kwargs.get(
|
||||
# "image_score_threshold", IMAGE_SCORE_THRESHOLD
|
||||
# ),
|
||||
# exclude_social_media_domains=kwargs.get(
|
||||
# "exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS
|
||||
# ),
|
||||
# exclude_external_links=kwargs.get("exclude_external_links", False),
|
||||
# exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
|
||||
# exclude_domains=kwargs.get("exclude_domains", []),
|
||||
# exclude_external_images=kwargs.get("exclude_external_images", False),
|
||||
# remove_forms=kwargs.get("remove_forms", False),
|
||||
# keep_data_attributes=kwargs.get("keep_data_attributes", False),
|
||||
# excluded_tags=kwargs.get("excluded_tags", []),
|
||||
# )
|
||||
|
||||
# Create a funciton returns dict of the object
|
||||
def to_dict(self):
|
||||
return {
|
||||
"word_count_threshold": self.word_count_threshold,
|
||||
"extraction_strategy": self.extraction_strategy,
|
||||
"chunking_strategy": self.chunking_strategy,
|
||||
"markdown_generator": self.markdown_generator,
|
||||
"content_filter": self.content_filter,
|
||||
"only_text": self.only_text,
|
||||
"css_selector": self.css_selector,
|
||||
"excluded_tags": self.excluded_tags,
|
||||
"excluded_selector": self.excluded_selector,
|
||||
"keep_data_attributes": self.keep_data_attributes,
|
||||
"remove_forms": self.remove_forms,
|
||||
"prettiify": self.prettiify,
|
||||
"fetch_ssl_certificate": self.fetch_ssl_certificate,
|
||||
"cache_mode": self.cache_mode,
|
||||
"session_id": self.session_id,
|
||||
"bypass_cache": self.bypass_cache,
|
||||
"disable_cache": self.disable_cache,
|
||||
"no_cache_read": self.no_cache_read,
|
||||
"no_cache_write": self.no_cache_write,
|
||||
"wait_until": self.wait_until,
|
||||
"page_timeout": self.page_timeout,
|
||||
"wait_for": self.wait_for,
|
||||
"wait_for_images": self.wait_for_images,
|
||||
"delay_before_return_html": self.delay_before_return_html,
|
||||
"mean_delay": self.mean_delay,
|
||||
"max_range": self.max_range,
|
||||
"semaphore_count": self.semaphore_count,
|
||||
"js_code": self.js_code,
|
||||
"js_only": self.js_only,
|
||||
"ignore_body_visibility": self.ignore_body_visibility,
|
||||
"scan_full_page": self.scan_full_page,
|
||||
"scroll_delay": self.scroll_delay,
|
||||
"process_iframes": self.process_iframes,
|
||||
"remove_overlay_elements": self.remove_overlay_elements,
|
||||
"simulate_user": self.simulate_user,
|
||||
"override_navigator": self.override_navigator,
|
||||
"magic": self.magic,
|
||||
"adjust_viewport_to_content": self.adjust_viewport_to_content,
|
||||
"screenshot": self.screenshot,
|
||||
"screenshot_wait_for": self.screenshot_wait_for,
|
||||
"screenshot_height_threshold": self.screenshot_height_threshold,
|
||||
"pdf": self.pdf,
|
||||
"image_description_min_word_threshold": self.image_description_min_word_threshold,
|
||||
"image_score_threshold": self.image_score_threshold,
|
||||
"exclude_external_images": self.exclude_external_images,
|
||||
"exclude_social_media_domains": self.exclude_social_media_domains,
|
||||
"exclude_external_links": self.exclude_external_links,
|
||||
"exclude_social_media_links": self.exclude_social_media_links,
|
||||
"exclude_domains": self.exclude_domains,
|
||||
"verbose": self.verbose,
|
||||
"log_console": self.log_console,
|
||||
"url": self.url,
|
||||
}
|
||||
|
||||
@@ -23,11 +23,7 @@ from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT
|
||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from .async_logger import AsyncLogger
|
||||
from playwright_stealth import StealthConfig, stealth_async
|
||||
|
||||
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from .utilities.ssl_utils import get_ssl_certificate
|
||||
|
||||
stealth_config = StealthConfig(
|
||||
webdriver=True,
|
||||
@@ -566,18 +562,6 @@ class AsyncCrawlerStrategy(ABC):
|
||||
async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def take_screenshot(self, **kwargs) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_user_agent(self, user_agent: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_hook(self, hook_type: str, hook: Callable):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
||||
def __init__(
|
||||
@@ -928,6 +912,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
||||
page.on("pageerror", lambda e: log_consol(e, "error"))
|
||||
|
||||
try:
|
||||
# Get SSL certificate information if requested and URL is HTTPS
|
||||
ssl_certificate = None
|
||||
if config.fetch_ssl_certificate and url.startswith('https://'):
|
||||
ssl_certificate = get_ssl_certificate(url)
|
||||
|
||||
# Set up download handling
|
||||
if self.browser_config.accept_downloads:
|
||||
page.on(
|
||||
@@ -1155,6 +1144,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
||||
screenshot=screenshot_data,
|
||||
pdf_data=pdf_data,
|
||||
get_delayed_content=get_delayed_content,
|
||||
ssl_certificate=ssl_certificate,
|
||||
downloaded_files=(
|
||||
self._downloaded_files if self._downloaded_files else None
|
||||
),
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Any, List, Optional, Awaitable
|
||||
import os, sys, shutil
|
||||
import tempfile, subprocess
|
||||
from playwright.async_api import async_playwright, Page, Browser, Error
|
||||
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||
from io import BytesIO
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from pathlib import Path
|
||||
from playwright.async_api import ProxySettings
|
||||
from pydantic import BaseModel
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from .models import AsyncCrawlResponse
|
||||
from .utils import create_box_message
|
||||
from .user_agent_generator import UserAgentGenerator
|
||||
from playwright_stealth import StealthConfig, stealth_async
|
||||
|
||||
|
||||
class ManagedBrowser:
|
||||
def __init__(self, browser_type: str = "chromium", user_data_dir: Optional[str] = None, headless: bool = False, logger = None, host: str = "localhost", debugging_port: int = 9222):
|
||||
self.browser_type = browser_type
|
||||
self.user_data_dir = user_data_dir
|
||||
self.headless = headless
|
||||
self.browser_process = None
|
||||
self.temp_dir = None
|
||||
self.debugging_port = debugging_port
|
||||
self.host = host
|
||||
self.logger = logger
|
||||
self.shutting_down = False
|
||||
|
||||
async def start(self) -> str:
|
||||
"""
|
||||
Starts the browser process and returns the CDP endpoint URL.
|
||||
If user_data_dir is not provided, creates a temporary directory.
|
||||
"""
|
||||
|
||||
# Create temp dir if needed
|
||||
if not self.user_data_dir:
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="browser-profile-")
|
||||
self.user_data_dir = self.temp_dir
|
||||
|
||||
# Get browser path and args based on OS and browser type
|
||||
browser_path = self._get_browser_path()
|
||||
args = self._get_browser_args()
|
||||
|
||||
# Start browser process
|
||||
try:
|
||||
self.browser_process = subprocess.Popen(
|
||||
args,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE
|
||||
)
|
||||
# Monitor browser process output for errors
|
||||
asyncio.create_task(self._monitor_browser_process())
|
||||
await asyncio.sleep(2) # Give browser time to start
|
||||
return f"http://{self.host}:{self.debugging_port}"
|
||||
except Exception as e:
|
||||
await self.cleanup()
|
||||
raise Exception(f"Failed to start browser: {e}")
|
||||
|
||||
async def _monitor_browser_process(self):
|
||||
"""Monitor the browser process for unexpected termination."""
|
||||
if self.browser_process:
|
||||
try:
|
||||
stdout, stderr = await asyncio.gather(
|
||||
asyncio.to_thread(self.browser_process.stdout.read),
|
||||
asyncio.to_thread(self.browser_process.stderr.read)
|
||||
)
|
||||
|
||||
# Check shutting_down flag BEFORE logging anything
|
||||
if self.browser_process.poll() is not None:
|
||||
if not self.shutting_down:
|
||||
self.logger.error(
|
||||
message="Browser process terminated unexpectedly | Code: {code} | STDOUT: {stdout} | STDERR: {stderr}",
|
||||
tag="ERROR",
|
||||
params={
|
||||
"code": self.browser_process.returncode,
|
||||
"stdout": stdout.decode(),
|
||||
"stderr": stderr.decode()
|
||||
}
|
||||
)
|
||||
await self.cleanup()
|
||||
else:
|
||||
self.logger.info(
|
||||
message="Browser process terminated normally | Code: {code}",
|
||||
tag="INFO",
|
||||
params={"code": self.browser_process.returncode}
|
||||
)
|
||||
except Exception as e:
|
||||
if not self.shutting_down:
|
||||
self.logger.error(
|
||||
message="Error monitoring browser process: {error}",
|
||||
tag="ERROR",
|
||||
params={"error": str(e)}
|
||||
)
|
||||
|
||||
def _get_browser_path(self) -> str:
|
||||
"""Returns the browser executable path based on OS and browser type"""
|
||||
if sys.platform == "darwin": # macOS
|
||||
paths = {
|
||||
"chromium": "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome",
|
||||
"firefox": "/Applications/Firefox.app/Contents/MacOS/firefox",
|
||||
"webkit": "/Applications/Safari.app/Contents/MacOS/Safari"
|
||||
}
|
||||
elif sys.platform == "win32": # Windows
|
||||
paths = {
|
||||
"chromium": "C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe",
|
||||
"firefox": "C:\\Program Files\\Mozilla Firefox\\firefox.exe",
|
||||
"webkit": None # WebKit not supported on Windows
|
||||
}
|
||||
else: # Linux
|
||||
paths = {
|
||||
"chromium": "google-chrome",
|
||||
"firefox": "firefox",
|
||||
"webkit": None # WebKit not supported on Linux
|
||||
}
|
||||
|
||||
return paths.get(self.browser_type)
|
||||
|
||||
def _get_browser_args(self) -> List[str]:
|
||||
"""Returns browser-specific command line arguments"""
|
||||
base_args = [self._get_browser_path()]
|
||||
|
||||
if self.browser_type == "chromium":
|
||||
args = [
|
||||
f"--remote-debugging-port={self.debugging_port}",
|
||||
f"--user-data-dir={self.user_data_dir}",
|
||||
]
|
||||
if self.headless:
|
||||
args.append("--headless=new")
|
||||
elif self.browser_type == "firefox":
|
||||
args = [
|
||||
"--remote-debugging-port", str(self.debugging_port),
|
||||
"--profile", self.user_data_dir,
|
||||
]
|
||||
if self.headless:
|
||||
args.append("--headless")
|
||||
else:
|
||||
raise NotImplementedError(f"Browser type {self.browser_type} not supported")
|
||||
|
||||
return base_args + args
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup browser process and temporary directory"""
|
||||
# Set shutting_down flag BEFORE any termination actions
|
||||
self.shutting_down = True
|
||||
|
||||
if self.browser_process:
|
||||
try:
|
||||
self.browser_process.terminate()
|
||||
# Wait for process to end gracefully
|
||||
for _ in range(10): # 10 attempts, 100ms each
|
||||
if self.browser_process.poll() is not None:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Force kill if still running
|
||||
if self.browser_process.poll() is None:
|
||||
self.browser_process.kill()
|
||||
await asyncio.sleep(0.1) # Brief wait for kill to take effect
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
message="Error terminating browser: {error}",
|
||||
tag="ERROR",
|
||||
params={"error": str(e)}
|
||||
)
|
||||
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
try:
|
||||
shutil.rmtree(self.temp_dir)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
message="Error removing temporary directory: {error}",
|
||||
tag="ERROR",
|
||||
params={"error": str(e)}
|
||||
)
|
||||
|
||||
@@ -42,6 +42,26 @@ class AsyncWebCrawler:
|
||||
"""
|
||||
Asynchronous web crawler with flexible caching capabilities.
|
||||
|
||||
There are two ways to use the crawler:
|
||||
|
||||
1. Using context manager (recommended for simple cases):
|
||||
```python
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(url="https://example.com")
|
||||
```
|
||||
|
||||
2. Using explicit lifecycle management (recommended for long-running applications):
|
||||
```python
|
||||
crawler = AsyncWebCrawler()
|
||||
await crawler.start()
|
||||
|
||||
# Use the crawler multiple times
|
||||
result1 = await crawler.arun(url="https://example.com")
|
||||
result2 = await crawler.arun(url="https://another.com")
|
||||
|
||||
await crawler.close()
|
||||
```
|
||||
|
||||
Migration Guide:
|
||||
Old way (deprecated):
|
||||
crawler = AsyncWebCrawler(always_by_pass_cache=True, browser_type="chromium", headless=True)
|
||||
@@ -127,16 +147,49 @@ class AsyncWebCrawler:
|
||||
|
||||
self.ready = False
|
||||
|
||||
async def __aenter__(self):
|
||||
async def start(self):
|
||||
"""
|
||||
Start the crawler explicitly without using context manager.
|
||||
This is equivalent to using 'async with' but gives more control over the lifecycle.
|
||||
|
||||
This method will:
|
||||
1. Initialize the browser and context
|
||||
2. Perform warmup sequence
|
||||
3. Return the crawler instance for method chaining
|
||||
|
||||
Returns:
|
||||
AsyncWebCrawler: The initialized crawler instance
|
||||
"""
|
||||
await self.crawler_strategy.__aenter__()
|
||||
await self.awarmup()
|
||||
return self
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close the crawler explicitly without using context manager.
|
||||
This should be called when you're done with the crawler if you used start().
|
||||
|
||||
This method will:
|
||||
1. Clean up browser resources
|
||||
2. Close any open pages and contexts
|
||||
"""
|
||||
await self.crawler_strategy.__aexit__(None, None, None)
|
||||
|
||||
async def __aenter__(self):
|
||||
return await self.start()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.crawler_strategy.__aexit__(exc_type, exc_val, exc_tb)
|
||||
await self.close()
|
||||
|
||||
async def awarmup(self):
|
||||
"""Initialize the crawler with warm-up sequence."""
|
||||
"""
|
||||
Initialize the crawler with warm-up sequence.
|
||||
|
||||
This method:
|
||||
1. Logs initialization info
|
||||
2. Sets up browser configuration
|
||||
3. Marks the crawler as ready
|
||||
"""
|
||||
self.logger.info(f"Crawl4AI {crawl4ai_version}", tag="INIT")
|
||||
self.ready = True
|
||||
|
||||
@@ -144,7 +197,7 @@ class AsyncWebCrawler:
|
||||
async def nullcontext(self):
|
||||
"""异步空上下文管理器"""
|
||||
yield
|
||||
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
url: str,
|
||||
@@ -204,14 +257,14 @@ class AsyncWebCrawler:
|
||||
try:
|
||||
# Handle configuration
|
||||
if crawler_config is not None:
|
||||
if any(param is not None for param in [
|
||||
word_count_threshold, extraction_strategy, chunking_strategy,
|
||||
content_filter, cache_mode, css_selector, screenshot, pdf
|
||||
]):
|
||||
self.logger.warning(
|
||||
message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.",
|
||||
tag="WARNING"
|
||||
)
|
||||
# if any(param is not None for param in [
|
||||
# word_count_threshold, extraction_strategy, chunking_strategy,
|
||||
# content_filter, cache_mode, css_selector, screenshot, pdf
|
||||
# ]):
|
||||
# self.logger.warning(
|
||||
# message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.",
|
||||
# tag="WARNING"
|
||||
# )
|
||||
config = crawler_config
|
||||
else:
|
||||
# Merge all parameters into a single kwargs dict for config creation
|
||||
@@ -322,6 +375,7 @@ class AsyncWebCrawler:
|
||||
screenshot=screenshot_data,
|
||||
pdf_data=pdf_data,
|
||||
verbose=config.verbose,
|
||||
is_raw_html = True if url.startswith("raw:") else False,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -330,9 +384,11 @@ class AsyncWebCrawler:
|
||||
crawl_result.status_code = async_response.status_code
|
||||
crawl_result.response_headers = async_response.response_headers
|
||||
crawl_result.downloaded_files = async_response.downloaded_files
|
||||
crawl_result.ssl_certificate = async_response.ssl_certificate # Add SSL certificate
|
||||
else:
|
||||
crawl_result.status_code = 200
|
||||
crawl_result.response_headers = cached_result.response_headers if cached_result else {}
|
||||
crawl_result.ssl_certificate = cached_result.ssl_certificate if cached_result else None # Add SSL certificate from cache
|
||||
|
||||
crawl_result.success = bool(html)
|
||||
crawl_result.session_id = getattr(config, 'session_id', None)
|
||||
@@ -416,15 +472,20 @@ class AsyncWebCrawler:
|
||||
scrapping_strategy = WebScrapingStrategy(logger=self.logger)
|
||||
|
||||
# Process HTML content
|
||||
params = {k:v for k, v in config.to_dict().items() if k not in ["url"]}
|
||||
# add keys from kwargs to params that doesn't exist in params
|
||||
params.update({k:v for k, v in kwargs.items() if k not in params.keys()})
|
||||
|
||||
result = scrapping_strategy.scrap(
|
||||
url,
|
||||
html,
|
||||
word_count_threshold=config.word_count_threshold,
|
||||
css_selector=config.css_selector,
|
||||
only_text=config.only_text,
|
||||
image_description_min_word_threshold=config.image_description_min_word_threshold,
|
||||
content_filter=config.content_filter,
|
||||
**kwargs
|
||||
**params,
|
||||
# word_count_threshold=config.word_count_threshold,
|
||||
# css_selector=config.css_selector,
|
||||
# only_text=config.only_text,
|
||||
# image_description_min_word_threshold=config.image_description_min_word_threshold,
|
||||
# content_filter=config.content_filter,
|
||||
# **kwargs
|
||||
)
|
||||
|
||||
if result is None:
|
||||
@@ -476,15 +537,27 @@ class AsyncWebCrawler:
|
||||
|
||||
t1 = time.perf_counter()
|
||||
|
||||
# Handle different extraction strategy types
|
||||
if isinstance(config.extraction_strategy, (JsonCssExtractionStrategy, JsonXPathExtractionStrategy)):
|
||||
config.extraction_strategy.verbose = verbose
|
||||
extracted_content = config.extraction_strategy.run(url, [html])
|
||||
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
|
||||
else:
|
||||
sections = config.chunking_strategy.chunk(markdown)
|
||||
extracted_content = config.extraction_strategy.run(url, sections)
|
||||
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
|
||||
# Choose content based on input_format
|
||||
content_format = config.extraction_strategy.input_format
|
||||
if content_format == "fit_markdown" and not markdown_result.fit_markdown:
|
||||
self.logger.warning(
|
||||
message="Fit markdown requested but not available. Falling back to raw markdown.",
|
||||
tag="EXTRACT",
|
||||
params={"url": _url}
|
||||
)
|
||||
content_format = "markdown"
|
||||
|
||||
content = {
|
||||
"markdown": markdown,
|
||||
"html": html,
|
||||
"fit_markdown": markdown_result.raw_markdown
|
||||
}.get(content_format, markdown)
|
||||
|
||||
# Use IdentityChunking for HTML input, otherwise use provided chunking strategy
|
||||
chunking = IdentityChunking() if content_format == "html" else config.chunking_strategy
|
||||
sections = chunking.chunk(content)
|
||||
extracted_content = config.extraction_strategy.run(url, sections)
|
||||
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
|
||||
|
||||
# Log extraction completion
|
||||
self.logger.info(
|
||||
@@ -683,5 +756,3 @@ class AsyncWebCrawler:
|
||||
async def aget_cache_size(self):
|
||||
"""Get the total number of cached items."""
|
||||
return await async_db_manager.aget_total_count()
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,12 @@ class ChunkingStrategy(ABC):
|
||||
Abstract method to chunk the given text.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Create an identity chunking strategy f(x) = [x]
|
||||
class IdentityChunking(ChunkingStrategy):
|
||||
def chunk(self, text: str) -> list:
|
||||
return [text]
|
||||
|
||||
# Regex-based chunking
|
||||
class RegexChunking(ChunkingStrategy):
|
||||
def __init__(self, patterns=None, **kwargs):
|
||||
@@ -127,7 +132,6 @@ class SlidingWindowChunking(ChunkingStrategy):
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class OverlappingWindowChunking(ChunkingStrategy):
|
||||
def __init__(self, window_size=1000, overlap=100, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import click
|
||||
import sys
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from .docs_manager import DocsManager
|
||||
from .async_logger import AsyncLogger
|
||||
|
||||
@@ -10,20 +9,19 @@ logger = AsyncLogger(verbose=True)
|
||||
docs_manager = DocsManager(logger)
|
||||
|
||||
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
|
||||
"""Helper function to print formatted tables"""
|
||||
col_widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
|
||||
border = '+' + '+'.join('-' * (width + 2 * padding) for width in col_widths) + '+'
|
||||
"""Print formatted table with headers and rows"""
|
||||
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
|
||||
border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+'
|
||||
|
||||
def format_row(row):
|
||||
return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
|
||||
for cell, w in zip(row, widths)) + '|'
|
||||
|
||||
def print_row(row):
|
||||
return '|' + '|'.join(
|
||||
f"{str(cell):{' '}<{width}}" for cell, width in zip(row, col_widths)
|
||||
) + '|'
|
||||
|
||||
click.echo(border)
|
||||
click.echo(print_row(headers))
|
||||
click.echo(format_row(headers))
|
||||
click.echo(border)
|
||||
for row in rows:
|
||||
click.echo(print_row(row))
|
||||
click.echo(format_row(row))
|
||||
click.echo(border)
|
||||
|
||||
@click.group()
|
||||
@@ -33,63 +31,75 @@ def cli():
|
||||
|
||||
@cli.group()
|
||||
def docs():
|
||||
"""Documentation and LLM text operations"""
|
||||
"""Documentation operations"""
|
||||
pass
|
||||
|
||||
@docs.command()
|
||||
@click.argument('sections', nargs=-1)
|
||||
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended',
|
||||
help='Documentation detail level')
|
||||
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended')
|
||||
def combine(sections: tuple, mode: str):
|
||||
"""Combine documentation sections.
|
||||
|
||||
If no sections are specified, combines all available sections.
|
||||
"""
|
||||
"""Combine documentation sections"""
|
||||
try:
|
||||
asyncio.run(docs_manager.ensure_docs_exist())
|
||||
result = docs_manager.concatenate_docs(sections, mode)
|
||||
click.echo(result)
|
||||
click.echo(docs_manager.generate(sections, mode))
|
||||
except Exception as e:
|
||||
logger.error(str(e), tag="ERROR")
|
||||
sys.exit(1)
|
||||
|
||||
@docs.command()
|
||||
@click.argument('query')
|
||||
@click.option('--top-k', '-k', default=5, help='Number of top results to return')
|
||||
def search(query: str, top_k: int):
|
||||
"""Search through documentation questions"""
|
||||
@click.option('--top-k', '-k', default=5)
|
||||
@click.option('--build-index', is_flag=True, help='Build index if missing')
|
||||
def search(query: str, top_k: int, build_index: bool):
|
||||
"""Search documentation"""
|
||||
try:
|
||||
results = docs_manager.search_questions(query, top_k)
|
||||
click.echo(results)
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
@docs.command()
|
||||
def list():
|
||||
"""List available documentation sections"""
|
||||
try:
|
||||
file_map = docs_manager.get_file_map()
|
||||
rows = [[num, name] for name, num in file_map.items()]
|
||||
rows.sort(key=lambda x: int(x[0]))
|
||||
print_table(['Number', 'Section Name'], rows)
|
||||
result = docs_manager.search(query, top_k)
|
||||
if result == "No search index available. Call build_search_index() first.":
|
||||
if build_index or click.confirm('No search index found. Build it now?'):
|
||||
asyncio.run(docs_manager.llm_text.generate_index_files())
|
||||
result = docs_manager.search(query, top_k)
|
||||
click.echo(result)
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
@docs.command()
|
||||
def update():
|
||||
"""Update local documentation cache from GitHub"""
|
||||
"""Update docs from GitHub"""
|
||||
try:
|
||||
docs_manager = DocsManager()
|
||||
docs_manager.update_docs()
|
||||
asyncio.run(docs_manager.fetch_docs())
|
||||
click.echo("Documentation updated successfully")
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
@docs.command()
|
||||
@click.option('--force-facts', is_flag=True, help='Force regenerate fact files')
|
||||
@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache')
|
||||
def index(force_facts: bool, clear_cache: bool):
|
||||
"""Build or rebuild search indexes"""
|
||||
try:
|
||||
asyncio.run(docs_manager.ensure_docs_exist())
|
||||
asyncio.run(docs_manager.llm_text.generate_index_files(
|
||||
force_generate_facts=force_facts,
|
||||
clear_bm25_cache=clear_cache
|
||||
))
|
||||
click.echo("Search indexes built successfully")
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
# Add docs list command
|
||||
@docs.command()
|
||||
def list():
|
||||
"""List available documentation sections"""
|
||||
try:
|
||||
sections = docs_manager.list()
|
||||
print_table(["Sections"], [[section] for section in sections])
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
cli()
|
||||
@@ -1,4 +1,5 @@
|
||||
import re # Point 1: Pre-Compile Regular Expressions
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from bs4 import BeautifulSoup
|
||||
@@ -16,7 +17,8 @@ from .models import MarkdownGenerationResult
|
||||
from .utils import (
|
||||
extract_metadata,
|
||||
normalize_url,
|
||||
is_external_url
|
||||
is_external_url,
|
||||
get_base_domain,
|
||||
)
|
||||
|
||||
|
||||
@@ -341,6 +343,7 @@ class WebScrapingStrategy(ContentScrapingStrategy):
|
||||
# if element.name == 'img':
|
||||
# process_image(element, url, 0, 1)
|
||||
# return True
|
||||
base_domain = kwargs.get("base_domain", get_base_domain(url))
|
||||
|
||||
if element.name in ['script', 'style', 'link', 'meta', 'noscript']:
|
||||
element.decompose()
|
||||
@@ -348,8 +351,10 @@ class WebScrapingStrategy(ContentScrapingStrategy):
|
||||
|
||||
keep_element = False
|
||||
|
||||
exclude_social_media_domains = SOCIAL_MEDIA_DOMAINS + kwargs.get('exclude_social_media_domains', [])
|
||||
exclude_social_media_domains = list(set(exclude_social_media_domains))
|
||||
exclude_domains = kwargs.get('exclude_domains', [])
|
||||
# exclude_social_media_domains = kwargs.get('exclude_social_media_domains', set(SOCIAL_MEDIA_DOMAINS))
|
||||
# exclude_social_media_domains = SOCIAL_MEDIA_DOMAINS + kwargs.get('exclude_social_media_domains', [])
|
||||
# exclude_social_media_domains = list(set(exclude_social_media_domains))
|
||||
|
||||
try:
|
||||
if element.name == 'a' and element.get('href'):
|
||||
@@ -369,33 +374,43 @@ class WebScrapingStrategy(ContentScrapingStrategy):
|
||||
link_data = {
|
||||
'href': normalized_href,
|
||||
'text': element.get_text().strip(),
|
||||
'title': element.get('title', '').strip()
|
||||
'title': element.get('title', '').strip(),
|
||||
'base_domain': base_domain
|
||||
}
|
||||
|
||||
is_external = is_external_url(normalized_href, base_domain)
|
||||
|
||||
keep_element = True
|
||||
|
||||
# Check for duplicates and add to appropriate dictionary
|
||||
is_external = is_external_url(normalized_href, url_base)
|
||||
# Handle external link exclusions
|
||||
if is_external:
|
||||
link_base_domain = get_base_domain(normalized_href)
|
||||
link_data['base_domain'] = link_base_domain
|
||||
if kwargs.get('exclude_external_links', False):
|
||||
element.decompose()
|
||||
return False
|
||||
# elif kwargs.get('exclude_social_media_links', False):
|
||||
# if link_base_domain in exclude_social_media_domains:
|
||||
# element.decompose()
|
||||
# return False
|
||||
# if any(domain in normalized_href.lower() for domain in exclude_social_media_domains):
|
||||
# element.decompose()
|
||||
# return False
|
||||
elif exclude_domains:
|
||||
if link_base_domain in exclude_domains:
|
||||
element.decompose()
|
||||
return False
|
||||
# if any(domain in normalized_href.lower() for domain in kwargs.get('exclude_domains', [])):
|
||||
# element.decompose()
|
||||
# return False
|
||||
|
||||
if is_external:
|
||||
if normalized_href not in external_links_dict:
|
||||
external_links_dict[normalized_href] = link_data
|
||||
else:
|
||||
if normalized_href not in internal_links_dict:
|
||||
internal_links_dict[normalized_href] = link_data
|
||||
|
||||
keep_element = True
|
||||
|
||||
# Handle external link exclusions
|
||||
if is_external:
|
||||
if kwargs.get('exclude_external_links', False):
|
||||
element.decompose()
|
||||
return False
|
||||
elif kwargs.get('exclude_social_media_links', False):
|
||||
if any(domain in normalized_href.lower() for domain in exclude_social_media_domains):
|
||||
element.decompose()
|
||||
return False
|
||||
elif kwargs.get('exclude_domains', []):
|
||||
if any(domain in normalized_href.lower() for domain in kwargs.get('exclude_domains', [])):
|
||||
element.decompose()
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error processing links: {str(e)}")
|
||||
@@ -414,26 +429,40 @@ class WebScrapingStrategy(ContentScrapingStrategy):
|
||||
if 'srcset' in element.attrs:
|
||||
src = element.attrs['srcset'].split(',')[0].split(' ')[0]
|
||||
|
||||
# If image src is internal, then skip
|
||||
if not is_external_url(src, base_domain):
|
||||
return True
|
||||
|
||||
image_src_base_domain = get_base_domain(src)
|
||||
|
||||
# Check flag if we should remove external images
|
||||
if kwargs.get('exclude_external_images', False):
|
||||
src_url_base = src.split('/')[2]
|
||||
url_base = url.split('/')[2]
|
||||
if url_base not in src_url_base:
|
||||
element.decompose()
|
||||
return False
|
||||
element.decompose()
|
||||
return False
|
||||
# src_url_base = src.split('/')[2]
|
||||
# url_base = url.split('/')[2]
|
||||
# if url_base not in src_url_base:
|
||||
# element.decompose()
|
||||
# return False
|
||||
|
||||
if not kwargs.get('exclude_external_images', False) and kwargs.get('exclude_social_media_links', False):
|
||||
src_url_base = src.split('/')[2]
|
||||
url_base = url.split('/')[2]
|
||||
if any(domain in src for domain in exclude_social_media_domains):
|
||||
element.decompose()
|
||||
return False
|
||||
# if kwargs.get('exclude_social_media_links', False):
|
||||
# if image_src_base_domain in exclude_social_media_domains:
|
||||
# element.decompose()
|
||||
# return False
|
||||
# src_url_base = src.split('/')[2]
|
||||
# url_base = url.split('/')[2]
|
||||
# if any(domain in src for domain in exclude_social_media_domains):
|
||||
# element.decompose()
|
||||
# return False
|
||||
|
||||
# Handle exclude domains
|
||||
if kwargs.get('exclude_domains', []):
|
||||
if any(domain in src for domain in kwargs.get('exclude_domains', [])):
|
||||
if exclude_domains:
|
||||
if image_src_base_domain in exclude_domains:
|
||||
element.decompose()
|
||||
return False
|
||||
# if any(domain in src for domain in kwargs.get('exclude_domains', [])):
|
||||
# element.decompose()
|
||||
# return False
|
||||
|
||||
return True # Always keep image elements
|
||||
except Exception as e:
|
||||
@@ -511,6 +540,7 @@ class WebScrapingStrategy(ContentScrapingStrategy):
|
||||
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
body = soup.body
|
||||
base_domain = get_base_domain(url)
|
||||
|
||||
try:
|
||||
meta = extract_metadata("", soup)
|
||||
@@ -556,10 +586,16 @@ class WebScrapingStrategy(ContentScrapingStrategy):
|
||||
for el in selected_elements:
|
||||
body.append(el)
|
||||
|
||||
kwargs['exclude_social_media_domains'] = set(kwargs.get('exclude_social_media_domains', []) + SOCIAL_MEDIA_DOMAINS)
|
||||
kwargs['exclude_domains'] = set(kwargs.get('exclude_domains', []))
|
||||
if kwargs.get('exclude_social_media_links', False):
|
||||
kwargs['exclude_domains'] = kwargs['exclude_domains'].union(kwargs['exclude_social_media_domains'])
|
||||
|
||||
result_obj = self.process_element(
|
||||
url,
|
||||
body,
|
||||
word_count_threshold = word_count_threshold,
|
||||
base_domain=base_domain,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
@@ -1,59 +1,67 @@
|
||||
import os
|
||||
import requests
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from .async_logger import AsyncLogger
|
||||
from .llmtxt import LLMTextManager
|
||||
from crawl4ai.async_logger import AsyncLogger
|
||||
from crawl4ai.llmtxt import AsyncLLMTextManager
|
||||
|
||||
class DocsManager:
|
||||
BASE_URL = "https://raw.githubusercontent.com/unclecode/crawl4ai/main/docs/llm.txt"
|
||||
|
||||
def __init__(self, logger: Optional[AsyncLogger] = None):
|
||||
def __init__(self, logger=None):
|
||||
self.docs_dir = Path.home() / ".crawl4ai" / "docs"
|
||||
self.local_docs = Path(__file__).parent.parent / "docs" / "llm.txt"
|
||||
self.docs_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.logger = logger or AsyncLogger(verbose=True)
|
||||
self.llm_text = LLMTextManager(self.docs_dir, self.logger)
|
||||
|
||||
self.llm_text = AsyncLLMTextManager(self.docs_dir, self.logger)
|
||||
|
||||
async def ensure_docs_exist(self):
|
||||
"""Ensure docs are downloaded, fetch if not present"""
|
||||
"""Fetch docs if not present"""
|
||||
if not any(self.docs_dir.iterdir()):
|
||||
self.logger.info("Documentation not found, downloading...", tag="DOCS")
|
||||
await self.update_docs()
|
||||
|
||||
async def update_docs(self) -> bool:
|
||||
"""Always fetch latest docs"""
|
||||
await self.fetch_docs()
|
||||
|
||||
async def fetch_docs(self) -> bool:
|
||||
"""Copy from local docs or download from GitHub"""
|
||||
try:
|
||||
self.logger.info("Fetching documentation files...", tag="DOCS")
|
||||
|
||||
# Get file list
|
||||
response = requests.get(f"{self.BASE_URL}/files.json")
|
||||
# Try local first
|
||||
if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))):
|
||||
# Empty the local docs directory
|
||||
for file_path in self.docs_dir.glob("*.md"):
|
||||
file_path.unlink()
|
||||
# for file_path in self.docs_dir.glob("*.tokens"):
|
||||
# file_path.unlink()
|
||||
for file_path in self.local_docs.glob("*.md"):
|
||||
shutil.copy2(file_path, self.docs_dir / file_path.name)
|
||||
# for file_path in self.local_docs.glob("*.tokens"):
|
||||
# shutil.copy2(file_path, self.docs_dir / file_path.name)
|
||||
return True
|
||||
|
||||
# Fallback to GitHub
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
|
||||
headers={'Accept': 'application/vnd.github.v3+json'}
|
||||
)
|
||||
response.raise_for_status()
|
||||
files = response.json()["files"]
|
||||
|
||||
# Download each file
|
||||
for file in files:
|
||||
response = requests.get(f"{self.BASE_URL}/{file}")
|
||||
response.raise_for_status()
|
||||
|
||||
file_path = self.docs_dir / file
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
f.write(response.text)
|
||||
|
||||
self.logger.debug(f"Downloaded {file}", tag="DOCS")
|
||||
|
||||
self.logger.success("Documentation updated successfully", tag="DOCS")
|
||||
for item in response.json():
|
||||
if item['type'] == 'file' and item['name'].endswith('.md'):
|
||||
content = requests.get(item['download_url']).text
|
||||
with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to update documentation: {str(e)}", tag="ERROR")
|
||||
self.logger.error(f"Failed to fetch docs: {str(e)}")
|
||||
raise
|
||||
|
||||
def list(self) -> list[str]:
|
||||
"""List available topics"""
|
||||
names = [file_path.stem for file_path in self.docs_dir.glob("*.md")]
|
||||
# Remove [0-9]+_ prefix
|
||||
names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
|
||||
# Exclude those end with .xs.md and .q.md
|
||||
names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")]
|
||||
return names
|
||||
|
||||
# Delegate LLM text operations to LLMTextManager
|
||||
def get_file_map(self) -> dict:
|
||||
return self.llm_text.get_file_map()
|
||||
def generate(self, sections, mode="extended"):
|
||||
return self.llm_text.generate(sections, mode)
|
||||
|
||||
def concatenate_docs(self, sections: List[str], mode: str) -> str:
|
||||
return self.llm_text.concatenate_docs(sections, mode)
|
||||
|
||||
def search_questions(self, query: str, top_k: int = 5) -> str:
|
||||
return self.llm_text.search_questions(query, top_k)
|
||||
def search(self, query: str, top_k: int = 5):
|
||||
return self.llm_text.search(query, top_k)
|
||||
@@ -6,6 +6,7 @@ import json, time
|
||||
from .prompts import *
|
||||
from .config import *
|
||||
from .utils import *
|
||||
from .models import *
|
||||
from functools import partial
|
||||
from .model_loader import *
|
||||
import math
|
||||
@@ -13,13 +14,23 @@ import numpy as np
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
from lxml import html, etree
|
||||
from dataclasses import dataclass
|
||||
|
||||
class ExtractionStrategy(ABC):
|
||||
"""
|
||||
Abstract base class for all extraction strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, input_format: str = "markdown", **kwargs):
|
||||
"""
|
||||
Initialize the extraction strategy.
|
||||
|
||||
Args:
|
||||
input_format: Content format to use for extraction.
|
||||
Options: "markdown" (default), "html", "fit_markdown"
|
||||
**kwargs: Additional keyword arguments
|
||||
"""
|
||||
self.input_format = input_format
|
||||
self.DEL = "<|DEL|>"
|
||||
self.name = self.__class__.__name__
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
@@ -62,6 +73,8 @@ class NoExtractionStrategy(ExtractionStrategy):
|
||||
# Strategies using LLM-based extraction for text data #
|
||||
#######################################################
|
||||
|
||||
|
||||
|
||||
class LLMExtractionStrategy(ExtractionStrategy):
|
||||
def __init__(self,
|
||||
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None,
|
||||
@@ -73,7 +86,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
:param api_token: The API token for the provider.
|
||||
:param instruction: The instruction to use for the LLM model.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
self.api_token = api_token or PROVIDER_MODELS.get(provider, "no-token") or os.getenv("OPENAI_API_KEY")
|
||||
self.instruction = instruction
|
||||
@@ -93,6 +106,8 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
self.chunk_token_threshold = 1e9
|
||||
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
self.usages = [] # Store individual usages
|
||||
self.total_usage = TokenUsage() # Accumulated usage
|
||||
|
||||
if not self.api_token:
|
||||
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.")
|
||||
@@ -129,6 +144,21 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
base_url=self.api_base or self.base_url,
|
||||
extra_args = self.extra_args
|
||||
) # , json_response=self.extract_type == "schema")
|
||||
# Track usage
|
||||
usage = TokenUsage(
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
completion_tokens_details=response.usage.completion_tokens_details.__dict__ if response.usage.completion_tokens_details else {},
|
||||
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__ if response.usage.prompt_tokens_details else {}
|
||||
)
|
||||
self.usages.append(usage)
|
||||
|
||||
# Update totals
|
||||
self.total_usage.completion_tokens += usage.completion_tokens
|
||||
self.total_usage.prompt_tokens += usage.prompt_tokens
|
||||
self.total_usage.total_tokens += usage.total_tokens
|
||||
|
||||
try:
|
||||
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
|
||||
blocks = json.loads(blocks)
|
||||
@@ -238,6 +268,22 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
|
||||
|
||||
return extracted_content
|
||||
|
||||
|
||||
def show_usage(self) -> None:
|
||||
"""Print a detailed token usage report showing total and per-request usage."""
|
||||
print("\n=== Token Usage Summary ===")
|
||||
print(f"{'Type':<15} {'Count':>12}")
|
||||
print("-" * 30)
|
||||
print(f"{'Completion':<15} {self.total_usage.completion_tokens:>12,}")
|
||||
print(f"{'Prompt':<15} {self.total_usage.prompt_tokens:>12,}")
|
||||
print(f"{'Total':<15} {self.total_usage.total_tokens:>12,}")
|
||||
|
||||
print("\n=== Usage History ===")
|
||||
print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}")
|
||||
print("-" * 48)
|
||||
for i, usage in enumerate(self.usages, 1):
|
||||
print(f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}")
|
||||
|
||||
|
||||
#######################################################
|
||||
@@ -256,7 +302,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
linkage_method (str): The linkage method for hierarchical clustering.
|
||||
top_k (int): Number of top categories to extract.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -537,7 +583,7 @@ class TopicExtractionStrategy(ExtractionStrategy):
|
||||
:param num_keywords: Number of keywords to represent each topic segment.
|
||||
"""
|
||||
import nltk
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.num_keywords = num_keywords
|
||||
self.tokenizer = nltk.TextTilingTokenizer()
|
||||
|
||||
@@ -604,6 +650,7 @@ class ContentSummarizationStrategy(ExtractionStrategy):
|
||||
|
||||
:param model_name: The model to use for summarization.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
from transformers import pipeline
|
||||
self.summarizer = pipeline("summarization", model=model_name)
|
||||
|
||||
@@ -809,6 +856,10 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
pass
|
||||
|
||||
class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
|
||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||
kwargs['input_format'] = 'html' # Force HTML input
|
||||
super().__init__(schema, **kwargs)
|
||||
|
||||
def _parse_html(self, html_content: str):
|
||||
return BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
@@ -829,6 +880,10 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
|
||||
return element.get(attribute)
|
||||
|
||||
class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
|
||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||
kwargs['input_format'] = 'html' # Force HTML input
|
||||
super().__init__(schema, **kwargs)
|
||||
|
||||
def _parse_html(self, html_content: str):
|
||||
return html.fromstring(html_content)
|
||||
|
||||
@@ -869,6 +924,7 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
|
||||
|
||||
class _JsonCssExtractionStrategy(ExtractionStrategy):
|
||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||
kwargs['input_format'] = 'html' # Force HTML input
|
||||
super().__init__(**kwargs)
|
||||
self.schema = schema
|
||||
|
||||
@@ -983,6 +1039,7 @@ class _JsonCssExtractionStrategy(ExtractionStrategy):
|
||||
return self.extract(url, combined_html, **kwargs)
|
||||
class _JsonXPathExtractionStrategy(ExtractionStrategy):
|
||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||
kwargs['input_format'] = 'html' # Force HTML input
|
||||
super().__init__(**kwargs)
|
||||
self.schema = schema
|
||||
|
||||
|
||||
@@ -1,196 +1,498 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from rank_bm25 import BM25Okapi
|
||||
import re
|
||||
from typing import List, Literal
|
||||
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import psutil
|
||||
import numpy as np
|
||||
from rank_bm25 import BM25Okapi
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
import nltk
|
||||
from litellm import completion, batch_completion
|
||||
from .async_logger import AsyncLogger
|
||||
import litellm
|
||||
import pickle
|
||||
import hashlib # <--- ADDED for file-hash
|
||||
from fnmatch import fnmatch
|
||||
import glob
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
BASE_PATH = Path(__file__).resolve().parent
|
||||
def _compute_file_hash(file_path: Path) -> str:
|
||||
"""Compute MD5 hash for the file's entire content."""
|
||||
hash_md5 = hashlib.md5()
|
||||
with file_path.open("rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
class LLMTextManager:
|
||||
"""Manages LLM text operations and caching"""
|
||||
|
||||
def __init__(self, docs_dir: Path, logger: Optional['AsyncLogger'] = None):
|
||||
class AsyncLLMTextManager:
|
||||
def __init__(
|
||||
self,
|
||||
docs_dir: Path,
|
||||
logger: Optional[AsyncLogger] = None,
|
||||
max_concurrent_calls: int = 5,
|
||||
batch_size: int = 3
|
||||
) -> None:
|
||||
self.docs_dir = docs_dir
|
||||
self.logger = logger
|
||||
|
||||
def get_file_map(self) -> dict:
|
||||
"""Cache file mappings to avoid repeated directory scans"""
|
||||
files = os.listdir(self.docs_dir)
|
||||
file_map = {}
|
||||
|
||||
for file in files:
|
||||
if file.endswith('.md'):
|
||||
# Extract number and name: "6_chunking_strategies.md" -> ("chunking_strategies", "6")
|
||||
match = re.match(r'(\d+)_(.+?)(?:\.(?:ex|xs|sm|q)?\.md)?$', file)
|
||||
if match:
|
||||
num, name = match.groups()
|
||||
if name not in file_map:
|
||||
file_map[name] = num
|
||||
return file_map
|
||||
self.max_concurrent_calls = max_concurrent_calls
|
||||
self.batch_size = batch_size
|
||||
self.bm25_index = None
|
||||
self.document_map: Dict[str, Any] = {}
|
||||
self.tokenized_facts: List[str] = []
|
||||
self.bm25_index_file = self.docs_dir / "bm25_index.pkl"
|
||||
|
||||
def concatenate_docs(self, file_names: List[str], mode: str) -> str:
|
||||
"""Concatenate documentation files based on names and mode."""
|
||||
file_map = self.get_file_map()
|
||||
result = []
|
||||
suffix_map = {
|
||||
"extended": ".ex.md",
|
||||
"condensed": [".xs.md", ".sm.md"]
|
||||
}
|
||||
|
||||
for name in file_names:
|
||||
if name not in file_map:
|
||||
continue
|
||||
|
||||
num = file_map[name]
|
||||
base_path = self.docs_dir
|
||||
|
||||
if mode == "extended":
|
||||
file_path = base_path / f"{num}_{name}{suffix_map[mode]}"
|
||||
if not file_path.exists():
|
||||
file_path = base_path / f"{num}_{name}.md"
|
||||
else:
|
||||
file_path = None
|
||||
for suffix in suffix_map["condensed"]:
|
||||
temp_path = base_path / f"{num}_{name}{suffix}"
|
||||
if temp_path.exists():
|
||||
file_path = temp_path
|
||||
break
|
||||
if not file_path:
|
||||
file_path = base_path / f"{num}_{name}.md"
|
||||
|
||||
if file_path.exists():
|
||||
async def _process_document_batch(self, doc_batch: List[Path]) -> None:
|
||||
"""Process a batch of documents in parallel"""
|
||||
contents = []
|
||||
for file_path in doc_batch:
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
result.append(f.read())
|
||||
|
||||
return "\n\n---\n\n".join(result)
|
||||
contents.append(f.read())
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error reading {file_path}: {str(e)}")
|
||||
contents.append("") # Add empty content to maintain batch alignment
|
||||
|
||||
def search_questions(self, query: str, top_k: int = 5) -> str:
|
||||
"""Search through Q files using BM25 ranking and return top K matches."""
|
||||
q_files = [f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")]
|
||||
# Prepare base path for file reading
|
||||
q_files = [self.docs_dir / f for f in q_files] # Convert to full path
|
||||
|
||||
documents = []
|
||||
file_contents = {}
|
||||
|
||||
for file in q_files:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
questions = extract_questions(content)
|
||||
for category, question, full_section in questions:
|
||||
documents.append(question)
|
||||
file_contents[question] = (file, category, full_section)
|
||||
prompt = """Given a documentation file, generate a list of atomic facts where each fact:
|
||||
1. Represents a single piece of knowledge
|
||||
2. Contains variations in terminology for the same concept
|
||||
3. References relevant code patterns if they exist
|
||||
4. Is written in a way that would match natural language queries
|
||||
|
||||
if not documents:
|
||||
return "No questions found in documentation."
|
||||
Each fact should follow this format:
|
||||
<main_concept>: <fact_statement> | <related_terms> | <code_reference>
|
||||
|
||||
tokenized_docs = [preprocess_text(doc) for doc in documents]
|
||||
tokenized_query = preprocess_text(query)
|
||||
Example Facts:
|
||||
browser_config: Configure headless mode and browser type for AsyncWebCrawler | headless, browser_type, chromium, firefox | BrowserConfig(browser_type="chromium", headless=True)
|
||||
redis_connection: Redis client connection requires host and port configuration | redis setup, redis client, connection params | Redis(host='localhost', port=6379, db=0)
|
||||
pandas_filtering: Filter DataFrame rows using boolean conditions | dataframe filter, query, boolean indexing | df[df['column'] > 5]
|
||||
|
||||
Wrap your response in <index>...</index> tags.
|
||||
"""
|
||||
|
||||
# Prepare messages for batch processing
|
||||
messages_list = [
|
||||
[
|
||||
{"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"}
|
||||
]
|
||||
for content in contents if content
|
||||
]
|
||||
|
||||
try:
|
||||
responses = batch_completion(
|
||||
model="anthropic/claude-3-5-sonnet-latest",
|
||||
messages=messages_list,
|
||||
logger_fn=None
|
||||
)
|
||||
|
||||
# Process responses and save index files
|
||||
for response, file_path in zip(responses, doc_batch):
|
||||
try:
|
||||
index_content_match = re.search(
|
||||
r'<index>(.*?)</index>',
|
||||
response.choices[0].message.content,
|
||||
re.DOTALL
|
||||
)
|
||||
if not index_content_match:
|
||||
self.logger.warning(f"No <index>...</index> content found for {file_path}")
|
||||
continue
|
||||
|
||||
index_content = re.sub(
|
||||
r"\n\s*\n", "\n", index_content_match.group(1)
|
||||
).strip()
|
||||
if index_content:
|
||||
index_file = file_path.with_suffix('.q.md')
|
||||
with open(index_file, 'w', encoding='utf-8') as f:
|
||||
f.write(index_content)
|
||||
self.logger.info(f"Created index file: {index_file}")
|
||||
else:
|
||||
self.logger.warning(f"No index content found in response for {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing response for {file_path}: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in batch completion: {str(e)}")
|
||||
|
||||
def _validate_fact_line(self, line: str) -> Tuple[bool, Optional[str]]:
|
||||
if "|" not in line:
|
||||
return False, "Missing separator '|'"
|
||||
|
||||
parts = [p.strip() for p in line.split("|")]
|
||||
if len(parts) != 3:
|
||||
return False, f"Expected 3 parts, got {len(parts)}"
|
||||
|
||||
concept_part = parts[0]
|
||||
if ":" not in concept_part:
|
||||
return False, "Missing ':' in concept definition"
|
||||
|
||||
return True, None
|
||||
|
||||
def _load_or_create_token_cache(self, fact_file: Path) -> Dict:
|
||||
"""
|
||||
Load token cache from .q.tokens if present and matching file hash.
|
||||
Otherwise return a new structure with updated file-hash.
|
||||
"""
|
||||
cache_file = fact_file.with_suffix(".q.tokens")
|
||||
current_hash = _compute_file_hash(fact_file)
|
||||
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
cache = json.load(f)
|
||||
# If the hash matches, return it directly
|
||||
if cache.get("content_hash") == current_hash:
|
||||
return cache
|
||||
# Otherwise, we signal that it's changed
|
||||
self.logger.info(f"Hash changed for {fact_file}, reindex needed.")
|
||||
except json.JSONDecodeError:
|
||||
self.logger.warning(f"Corrupt token cache for {fact_file}, rebuilding.")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error reading cache for {fact_file}: {str(e)}")
|
||||
|
||||
# Return a fresh cache
|
||||
return {"facts": {}, "content_hash": current_hash}
|
||||
|
||||
def _save_token_cache(self, fact_file: Path, cache: Dict) -> None:
|
||||
cache_file = fact_file.with_suffix(".q.tokens")
|
||||
# Always ensure we're saving the correct file-hash
|
||||
cache["content_hash"] = _compute_file_hash(fact_file)
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(cache, f)
|
||||
|
||||
def preprocess_text(self, text: str) -> List[str]:
|
||||
parts = [x.strip() for x in text.split("|")] if "|" in text else [text]
|
||||
# Remove : after the first word of parts[0]
|
||||
parts[0] = re.sub(r"^(.*?):", r"\1", parts[0])
|
||||
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
stop_words = set(stopwords.words("english")) - {
|
||||
"how", "what", "when", "where", "why", "which",
|
||||
}
|
||||
|
||||
tokens = []
|
||||
for part in parts:
|
||||
if "(" in part and ")" in part:
|
||||
code_tokens = re.findall(
|
||||
r'[\w_]+(?=\()|[\w_]+(?==[\'"]{1}[\w_]+[\'"]{1})', part
|
||||
)
|
||||
tokens.extend(code_tokens)
|
||||
|
||||
words = word_tokenize(part.lower())
|
||||
tokens.extend(
|
||||
[
|
||||
lemmatizer.lemmatize(token)
|
||||
for token in words
|
||||
if token not in stop_words
|
||||
]
|
||||
)
|
||||
|
||||
return tokens
|
||||
|
||||
def maybe_load_bm25_index(self, clear_cache=False) -> bool:
|
||||
"""
|
||||
Load existing BM25 index from disk, if present and clear_cache=False.
|
||||
"""
|
||||
if not clear_cache and os.path.exists(self.bm25_index_file):
|
||||
self.logger.info("Loading existing BM25 index from disk.")
|
||||
with open(self.bm25_index_file, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
self.tokenized_facts = data["tokenized_facts"]
|
||||
self.bm25_index = data["bm25_index"]
|
||||
return True
|
||||
return False
|
||||
|
||||
def build_search_index(self, clear_cache=False) -> None:
|
||||
"""
|
||||
Checks for new or modified .q.md files by comparing file-hash.
|
||||
If none need reindexing and clear_cache is False, loads existing index if available.
|
||||
Otherwise, reindexes only changed/new files and merges or creates a new index.
|
||||
"""
|
||||
# If clear_cache is True, we skip partial logic: rebuild everything from scratch
|
||||
if clear_cache:
|
||||
self.logger.info("Clearing cache and rebuilding full search index.")
|
||||
if self.bm25_index_file.exists():
|
||||
self.bm25_index_file.unlink()
|
||||
|
||||
process = psutil.Process()
|
||||
self.logger.info("Checking which .q.md files need (re)indexing...")
|
||||
|
||||
# Gather all .q.md files
|
||||
q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")]
|
||||
|
||||
# We'll store known (unchanged) facts in these lists
|
||||
existing_facts: List[str] = []
|
||||
existing_tokens: List[List[str]] = []
|
||||
|
||||
# Keep track of invalid lines for logging
|
||||
invalid_lines = []
|
||||
needSet = [] # files that must be (re)indexed
|
||||
|
||||
for qf in q_files:
|
||||
token_cache_file = qf.with_suffix(".q.tokens")
|
||||
|
||||
# If no .q.tokens or clear_cache is True → definitely reindex
|
||||
if clear_cache or not token_cache_file.exists():
|
||||
needSet.append(qf)
|
||||
continue
|
||||
|
||||
# Otherwise, load the existing cache and compare hash
|
||||
cache = self._load_or_create_token_cache(qf)
|
||||
# If the .q.tokens was out of date (i.e. changed hash), we reindex
|
||||
if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf):
|
||||
needSet.append(qf)
|
||||
else:
|
||||
# File is unchanged → retrieve cached token data
|
||||
for line, cache_data in cache["facts"].items():
|
||||
existing_facts.append(line)
|
||||
existing_tokens.append(cache_data["tokens"])
|
||||
self.document_map[line] = qf # track the doc for that fact
|
||||
|
||||
if not needSet and not clear_cache:
|
||||
# If no file needs reindexing, try loading existing index
|
||||
if self.maybe_load_bm25_index(clear_cache=False):
|
||||
self.logger.info("No new/changed .q.md files found. Using existing BM25 index.")
|
||||
return
|
||||
else:
|
||||
# If there's no existing index, we must build a fresh index from the old caches
|
||||
self.logger.info("No existing BM25 index found. Building from cached facts.")
|
||||
if existing_facts:
|
||||
self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.")
|
||||
self.bm25_index = BM25Okapi(existing_tokens)
|
||||
self.tokenized_facts = existing_facts
|
||||
with open(self.bm25_index_file, "wb") as f:
|
||||
pickle.dump({
|
||||
"bm25_index": self.bm25_index,
|
||||
"tokenized_facts": self.tokenized_facts
|
||||
}, f)
|
||||
else:
|
||||
self.logger.warning("No facts found at all. Index remains empty.")
|
||||
return
|
||||
|
||||
# ----------------------------------------------------- /Users/unclecode/.crawl4ai/docs/14_proxy_security.q.q.tokens '/Users/unclecode/.crawl4ai/docs/14_proxy_security.q.md'
|
||||
# If we reach here, we have new or changed .q.md files
|
||||
# We'll parse them, reindex them, and then combine with existing_facts
|
||||
# -----------------------------------------------------
|
||||
|
||||
self.logger.info(f"{len(needSet)} file(s) need reindexing. Parsing now...")
|
||||
|
||||
# 1) Parse the new or changed .q.md files
|
||||
new_facts = []
|
||||
new_tokens = []
|
||||
with tqdm(total=len(needSet), desc="Indexing changed files") as file_pbar:
|
||||
for file in needSet:
|
||||
# We'll build up a fresh cache
|
||||
fresh_cache = {"facts": {}, "content_hash": _compute_file_hash(file)}
|
||||
try:
|
||||
with open(file, "r", encoding="utf-8") as f_obj:
|
||||
content = f_obj.read().strip()
|
||||
lines = [l.strip() for l in content.split("\n") if l.strip()]
|
||||
|
||||
for line in lines:
|
||||
is_valid, error = self._validate_fact_line(line)
|
||||
if not is_valid:
|
||||
invalid_lines.append((file, line, error))
|
||||
continue
|
||||
|
||||
tokens = self.preprocess_text(line)
|
||||
fresh_cache["facts"][line] = {
|
||||
"tokens": tokens,
|
||||
"added": time.time(),
|
||||
}
|
||||
new_facts.append(line)
|
||||
new_tokens.append(tokens)
|
||||
self.document_map[line] = file
|
||||
|
||||
# Save the new .q.tokens with updated hash
|
||||
self._save_token_cache(file, fresh_cache)
|
||||
|
||||
mem_usage = process.memory_info().rss / 1024 / 1024
|
||||
self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {file}: {str(e)}")
|
||||
|
||||
file_pbar.update(1)
|
||||
|
||||
if invalid_lines:
|
||||
self.logger.warning(f"Found {len(invalid_lines)} invalid fact lines:")
|
||||
for file, line, error in invalid_lines:
|
||||
self.logger.warning(f"{file}: {error} in line: {line[:50]}...")
|
||||
|
||||
# 2) Merge newly tokenized facts with the existing ones
|
||||
all_facts = existing_facts + new_facts
|
||||
all_tokens = existing_tokens + new_tokens
|
||||
|
||||
# 3) Build BM25 index from combined facts
|
||||
self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).")
|
||||
self.bm25_index = BM25Okapi(all_tokens)
|
||||
self.tokenized_facts = all_facts
|
||||
|
||||
# 4) Save the updated BM25 index to disk
|
||||
with open(self.bm25_index_file, "wb") as f:
|
||||
pickle.dump({
|
||||
"bm25_index": self.bm25_index,
|
||||
"tokenized_facts": self.tokenized_facts
|
||||
}, f)
|
||||
|
||||
final_mem = process.memory_info().rss / 1024 / 1024
|
||||
self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
|
||||
|
||||
async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None:
|
||||
"""
|
||||
Generate index files for all documents in parallel batches
|
||||
|
||||
bm25 = BM25Okapi(tokenized_docs)
|
||||
doc_scores = bm25.get_scores(tokenized_query)
|
||||
Args:
|
||||
force_generate_facts (bool): If True, regenerate indexes even if they exist
|
||||
clear_bm25_cache (bool): If True, clear existing BM25 index cache
|
||||
"""
|
||||
self.logger.info("Starting index generation for documentation files.")
|
||||
|
||||
score_threshold = max(doc_scores) * 0.4
|
||||
md_files = [
|
||||
self.docs_dir / f for f in os.listdir(self.docs_dir)
|
||||
if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md'])
|
||||
]
|
||||
|
||||
# Filter out files that already have .q files unless force=True
|
||||
if not force_generate_facts:
|
||||
md_files = [
|
||||
f for f in md_files
|
||||
if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists()
|
||||
]
|
||||
|
||||
if not md_files:
|
||||
self.logger.info("All index files exist. Use force=True to regenerate.")
|
||||
else:
|
||||
# Process documents in batches
|
||||
for i in range(0, len(md_files), self.batch_size):
|
||||
batch = md_files[i:i + self.batch_size]
|
||||
self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}")
|
||||
await self._process_document_batch(batch)
|
||||
|
||||
self.logger.info("Index generation complete, building/updating search index.")
|
||||
self.build_search_index(clear_cache=clear_bm25_cache)
|
||||
|
||||
def generate(self, sections: List[str], mode: str = "extended") -> str:
|
||||
# Get all markdown files
|
||||
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \
|
||||
glob.glob(str(self.docs_dir / "[0-9]*.xs.md"))
|
||||
|
||||
# Aggregate scores by file
|
||||
file_data = {}
|
||||
for idx, score in enumerate(doc_scores):
|
||||
if score > score_threshold:
|
||||
question = documents[idx]
|
||||
file, category, _ = file_contents[question]
|
||||
|
||||
if file not in file_data:
|
||||
file_data[file] = {
|
||||
'total_score': 0,
|
||||
'match_count': 0,
|
||||
'questions': []
|
||||
}
|
||||
|
||||
file_data[file]['total_score'] += score
|
||||
file_data[file]['match_count'] += 1
|
||||
file_data[file]['questions'].append({
|
||||
'category': category,
|
||||
'question': question,
|
||||
'score': score
|
||||
})
|
||||
# Extract base names without extensions
|
||||
base_docs = {Path(f).name.split('.')[0] for f in all_files
|
||||
if not Path(f).name.endswith('.q.md')}
|
||||
|
||||
# Sort files by match count and total score
|
||||
# Filter by sections if provided
|
||||
if sections:
|
||||
base_docs = {doc for doc in base_docs
|
||||
if any(section.lower() in doc.lower() for section in sections)}
|
||||
|
||||
# Get file paths based on mode
|
||||
files = []
|
||||
for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999):
|
||||
if mode == "condensed":
|
||||
xs_file = self.docs_dir / f"{doc}.xs.md"
|
||||
regular_file = self.docs_dir / f"{doc}.md"
|
||||
files.append(str(xs_file if xs_file.exists() else regular_file))
|
||||
else:
|
||||
files.append(str(self.docs_dir / f"{doc}.md"))
|
||||
|
||||
# Read and format content
|
||||
content = []
|
||||
for file in files:
|
||||
try:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
fname = Path(file).name
|
||||
content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error reading {file}: {str(e)}")
|
||||
|
||||
return "\n\n---\n\n".join(content) if content else ""
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> str:
|
||||
if not self.bm25_index:
|
||||
return "No search index available. Call build_search_index() first."
|
||||
|
||||
query_tokens = self.preprocess_text(query)
|
||||
doc_scores = self.bm25_index.get_scores(query_tokens)
|
||||
|
||||
mean_score = np.mean(doc_scores)
|
||||
std_score = np.std(doc_scores)
|
||||
score_threshold = mean_score + (0.25 * std_score)
|
||||
|
||||
file_data = self._aggregate_search_scores(
|
||||
doc_scores=doc_scores,
|
||||
score_threshold=score_threshold,
|
||||
query_tokens=query_tokens,
|
||||
)
|
||||
|
||||
ranked_files = sorted(
|
||||
file_data.items(),
|
||||
key=lambda x: (x[1]['match_count'], x[1]['total_score']),
|
||||
reverse=True
|
||||
key=lambda x: (
|
||||
x[1]["code_match_score"] * 2.0
|
||||
+ x[1]["match_count"] * 1.5
|
||||
+ x[1]["total_score"]
|
||||
),
|
||||
reverse=True,
|
||||
)[:top_k]
|
||||
|
||||
# Format results by file
|
||||
|
||||
results = []
|
||||
for file, data in ranked_files:
|
||||
questions_summary = "\n".join(
|
||||
f"- [{q['category']}] {q['question']} (score: {q['score']:.2f})"
|
||||
for q in sorted(data['questions'], key=lambda x: x['score'], reverse=True)
|
||||
for file, _ in ranked_files:
|
||||
main_doc = str(file).replace(".q.md", ".md")
|
||||
if os.path.exists(self.docs_dir / main_doc):
|
||||
with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f:
|
||||
only_file_name = main_doc.split("/")[-1]
|
||||
content = [
|
||||
"#" * 20,
|
||||
f"# {only_file_name}",
|
||||
"#" * 20,
|
||||
"",
|
||||
f.read()
|
||||
]
|
||||
results.append("\n".join(content))
|
||||
|
||||
return "\n\n---\n\n".join(results)
|
||||
|
||||
def _aggregate_search_scores(
|
||||
self, doc_scores: List[float], score_threshold: float, query_tokens: List[str]
|
||||
) -> Dict:
|
||||
file_data = {}
|
||||
|
||||
for idx, score in enumerate(doc_scores):
|
||||
if score <= score_threshold:
|
||||
continue
|
||||
|
||||
fact = self.tokenized_facts[idx]
|
||||
file_path = self.document_map[fact]
|
||||
|
||||
if file_path not in file_data:
|
||||
file_data[file_path] = {
|
||||
"total_score": 0,
|
||||
"match_count": 0,
|
||||
"code_match_score": 0,
|
||||
"matched_facts": [],
|
||||
}
|
||||
|
||||
components = fact.split("|") if "|" in fact else [fact]
|
||||
|
||||
code_match_score = 0
|
||||
if len(components) == 3:
|
||||
code_ref = components[2].strip()
|
||||
code_tokens = self.preprocess_text(code_ref)
|
||||
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens)
|
||||
|
||||
file_data[file_path]["total_score"] += score
|
||||
file_data[file_path]["match_count"] += 1
|
||||
file_data[file_path]["code_match_score"] = max(
|
||||
file_data[file_path]["code_match_score"], code_match_score
|
||||
)
|
||||
|
||||
results.append(
|
||||
f"File: {file}\n"
|
||||
f"Match Count: {data['match_count']}\n"
|
||||
f"Total Score: {data['total_score']:.2f}\n\n"
|
||||
f"Matching Questions:\n{questions_summary}"
|
||||
)
|
||||
|
||||
return "\n\n---\n\n".join(results) if results else "No relevant matches found."
|
||||
file_data[file_path]["matched_facts"].append(fact)
|
||||
|
||||
def extract_questions(content: str) -> List[tuple[str, str, str]]:
|
||||
"""
|
||||
Extract questions from Q files, returning list of (category, question, full_section).
|
||||
"""
|
||||
# Split into main sections (### Questions or ### Hypothetical Questions)
|
||||
sections = re.split(r'^###\s+.*Questions\s*$', content, flags=re.MULTILINE)[1:]
|
||||
|
||||
results = []
|
||||
for section in sections:
|
||||
# Find all numbered categories (1. **Category Name**)
|
||||
categories = re.split(r'^\d+\.\s+\*\*([^*]+)\*\*\s*$', section.strip(), flags=re.MULTILINE)
|
||||
|
||||
# Process each category
|
||||
for i in range(1, len(categories), 2):
|
||||
category = categories[i].strip()
|
||||
category_content = categories[i+1].strip()
|
||||
|
||||
# Extract questions (lines starting with dash and wrapped in italics)
|
||||
questions = re.findall(r'^\s*-\s*\*"([^"]+)"\*\s*$', category_content, flags=re.MULTILINE)
|
||||
|
||||
# Add each question with its category and full context
|
||||
for q in questions:
|
||||
results.append((category, q, f"Category: {category}\nQuestion: {q}"))
|
||||
|
||||
return results
|
||||
return file_data
|
||||
|
||||
def preprocess_text(text: str) -> List[str]:
|
||||
"""Preprocess text for better semantic matching"""
|
||||
# Lowercase and tokenize
|
||||
tokens = word_tokenize(text.lower())
|
||||
|
||||
# Remove stopwords but keep question words
|
||||
stop_words = set(stopwords.words('english')) - {'how', 'what', 'when', 'where', 'why', 'which'}
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
|
||||
# Lemmatize but preserve original form for technical terms
|
||||
tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
|
||||
|
||||
return tokens
|
||||
|
||||
if __name__ == "__main__":
|
||||
llm_manager = LLMTextManager(BASE_PATH)
|
||||
|
||||
# Example 1: Concatenate docs
|
||||
docs = llm_manager.concatenate_docs(["chunking_strategies", "content_selection"], "extended")
|
||||
print("Concatenated docs:", docs[:200], "...\n")
|
||||
|
||||
# Example 2: Search questions
|
||||
results = llm_manager.search_questions("How do I execute JS script on the page?", 3)
|
||||
print("Search results:", results[:200], "...")
|
||||
def refresh_index(self) -> None:
|
||||
"""Convenience method for a full rebuild."""
|
||||
self.build_search_index(clear_cache=True)
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from typing import List, Dict, Optional, Callable, Awaitable, Union
|
||||
|
||||
|
||||
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
|
||||
from dataclasses import dataclass
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens_details: Optional[dict] = None
|
||||
prompt_tokens_details: Optional[dict] = None
|
||||
|
||||
|
||||
class UrlModel(BaseModel):
|
||||
url: HttpUrl
|
||||
@@ -34,7 +41,8 @@ class CrawlResult(BaseModel):
|
||||
session_id: Optional[str] = None
|
||||
response_headers: Optional[dict] = None
|
||||
status_code: Optional[int] = None
|
||||
|
||||
ssl_certificate: Optional[Dict[str, Any]] = None
|
||||
|
||||
class AsyncCrawlResponse(BaseModel):
|
||||
html: str
|
||||
response_headers: Dict[str, str]
|
||||
@@ -43,8 +51,7 @@ class AsyncCrawlResponse(BaseModel):
|
||||
pdf_data: Optional[bytes] = None
|
||||
get_delayed_content: Optional[Callable[[Optional[float]], Awaitable[str]]] = None
|
||||
downloaded_files: Optional[List[str]] = None
|
||||
ssl_certificate: Optional[Dict[str, Any]] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
|
||||
156
crawl4ai/utilities/cert_exporter.py
Normal file
156
crawl4ai/utilities/cert_exporter.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Utility functions for exporting SSL certificates in various formats."""
|
||||
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
import OpenSSL.crypto
|
||||
from datetime import datetime
|
||||
|
||||
class CertificateExporter:
|
||||
"""
|
||||
Handles exporting SSL certificates in various formats:
|
||||
1. JSON - Human-readable format with all certificate details
|
||||
2. PEM - Standard text format for certificates
|
||||
3. DER - Binary format
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _decode_cert_data(data: Any) -> Any:
|
||||
"""Helper method to decode bytes in certificate data."""
|
||||
if isinstance(data, bytes):
|
||||
return data.decode('utf-8')
|
||||
elif isinstance(data, dict):
|
||||
return {
|
||||
(k.decode('utf-8') if isinstance(k, bytes) else k): CertificateExporter._decode_cert_data(v)
|
||||
for k, v in data.items()
|
||||
}
|
||||
elif isinstance(data, list):
|
||||
return [CertificateExporter._decode_cert_data(item) for item in data]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def to_json(cert_info: Dict[str, Any], filepath: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Export certificate information to JSON format.
|
||||
|
||||
Args:
|
||||
cert_info: Dictionary containing certificate information
|
||||
filepath: Optional path to save the JSON file
|
||||
|
||||
Returns:
|
||||
str: JSON string if filepath is None, otherwise None
|
||||
"""
|
||||
if not cert_info:
|
||||
return None
|
||||
|
||||
# Decode any bytes in the certificate data
|
||||
cert_data = CertificateExporter._decode_cert_data(cert_info)
|
||||
|
||||
# Convert datetime objects to ISO format strings
|
||||
for key, value in cert_data.items():
|
||||
if isinstance(value, datetime):
|
||||
cert_data[key] = value.isoformat()
|
||||
|
||||
json_str = json.dumps(cert_data, indent=2, ensure_ascii=False)
|
||||
|
||||
if filepath:
|
||||
Path(filepath).write_text(json_str, encoding='utf-8')
|
||||
return None
|
||||
return json_str
|
||||
|
||||
@staticmethod
|
||||
def to_pem(cert_info: Dict[str, Any], filepath: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Export certificate to PEM format.
|
||||
This is the most common format, used for Apache/Nginx configs.
|
||||
|
||||
Args:
|
||||
cert_info: Dictionary containing certificate information
|
||||
filepath: Optional path to save the PEM file
|
||||
|
||||
Returns:
|
||||
str: PEM string if filepath is None, otherwise None
|
||||
"""
|
||||
if not cert_info or 'raw_cert' not in cert_info:
|
||||
return None
|
||||
|
||||
try:
|
||||
x509 = OpenSSL.crypto.load_certificate(
|
||||
OpenSSL.crypto.FILETYPE_ASN1,
|
||||
base64.b64decode(cert_info['raw_cert'])
|
||||
)
|
||||
pem_data = OpenSSL.crypto.dump_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
x509
|
||||
).decode('utf-8')
|
||||
|
||||
if filepath:
|
||||
Path(filepath).write_text(pem_data, encoding='utf-8')
|
||||
return None
|
||||
return pem_data
|
||||
|
||||
except Exception as e:
|
||||
return f"Error converting to PEM: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
def to_der(cert_info: Dict[str, Any], filepath: Optional[str] = None) -> Optional[bytes]:
|
||||
"""
|
||||
Export certificate to DER format (binary).
|
||||
This format is commonly used in Java environments.
|
||||
|
||||
Args:
|
||||
cert_info: Dictionary containing certificate information
|
||||
filepath: Optional path to save the DER file
|
||||
|
||||
Returns:
|
||||
bytes: DER bytes if filepath is None, otherwise None
|
||||
"""
|
||||
if not cert_info or 'raw_cert' not in cert_info:
|
||||
return None
|
||||
|
||||
try:
|
||||
der_data = base64.b64decode(cert_info['raw_cert'])
|
||||
|
||||
if filepath:
|
||||
Path(filepath).write_bytes(der_data)
|
||||
return None
|
||||
return der_data
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def export_all(cert_info: Dict[str, Any], base_path: str, filename: str) -> Dict[str, str]:
|
||||
"""
|
||||
Export certificate in all supported formats.
|
||||
|
||||
Args:
|
||||
cert_info: Dictionary containing certificate information
|
||||
base_path: Base directory to save the files
|
||||
filename: Base filename without extension
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Dictionary mapping format to filepath
|
||||
"""
|
||||
base_path = Path(base_path)
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
paths = {}
|
||||
|
||||
# Export JSON
|
||||
json_path = base_path / f"{filename}.json"
|
||||
CertificateExporter.to_json(cert_info, str(json_path))
|
||||
paths['json'] = str(json_path)
|
||||
|
||||
# Export PEM
|
||||
pem_path = base_path / f"{filename}.pem"
|
||||
CertificateExporter.to_pem(cert_info, str(pem_path))
|
||||
paths['pem'] = str(pem_path)
|
||||
|
||||
# Export DER
|
||||
der_path = base_path / f"{filename}.der"
|
||||
CertificateExporter.to_der(cert_info, str(der_path))
|
||||
paths['der'] = str(der_path)
|
||||
|
||||
return paths
|
||||
83
crawl4ai/utilities/ssl_utils.py
Normal file
83
crawl4ai/utilities/ssl_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Utility functions for SSL certificate handling."""
|
||||
|
||||
import ssl
|
||||
import socket
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
import OpenSSL.crypto
|
||||
import datetime
|
||||
import base64
|
||||
|
||||
|
||||
def get_ssl_certificate(url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve SSL certificate information from a given URL.
|
||||
|
||||
Args:
|
||||
url (str): The URL to get SSL certificate from
|
||||
timeout (int): Socket timeout in seconds
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Dictionary containing certificate information or None if not available
|
||||
|
||||
The returned dictionary includes:
|
||||
- subject: Certificate subject information
|
||||
- issuer: Certificate issuer information
|
||||
- version: SSL version
|
||||
- serial_number: Certificate serial number
|
||||
- not_before: Certificate validity start date
|
||||
- not_after: Certificate validity end date
|
||||
- fingerprint: Certificate fingerprint
|
||||
- raw_cert: Base64 encoded raw certificate data
|
||||
"""
|
||||
try:
|
||||
hostname = urlparse(url).netloc
|
||||
if ':' in hostname:
|
||||
hostname = hostname.split(':')[0]
|
||||
|
||||
context = ssl.create_default_context()
|
||||
with socket.create_connection((hostname, 443), timeout=timeout) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
|
||||
cert_binary = ssock.getpeercert(binary_form=True)
|
||||
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert_binary)
|
||||
|
||||
cert_info = {
|
||||
"subject": {
|
||||
key: value.decode() if isinstance(value, bytes) else value
|
||||
for key, value in dict(x509.get_subject().get_components()).items()
|
||||
},
|
||||
"issuer": {
|
||||
key: value.decode() if isinstance(value, bytes) else value
|
||||
for key, value in dict(x509.get_issuer().get_components()).items()
|
||||
},
|
||||
"version": x509.get_version(),
|
||||
"serial_number": hex(x509.get_serial_number()),
|
||||
"not_before": x509.get_notBefore().decode(),
|
||||
"not_after": x509.get_notAfter().decode(),
|
||||
"fingerprint": x509.digest("sha256").hex(),
|
||||
"signature_algorithm": x509.get_signature_algorithm().decode(),
|
||||
"raw_cert": base64.b64encode(cert_binary).decode('utf-8')
|
||||
}
|
||||
|
||||
# Add extensions
|
||||
extensions = []
|
||||
for i in range(x509.get_extension_count()):
|
||||
ext = x509.get_extension(i)
|
||||
extensions.append({
|
||||
"name": ext.get_short_name().decode(),
|
||||
"value": str(ext)
|
||||
})
|
||||
cert_info["extensions"] = extensions
|
||||
|
||||
return cert_info
|
||||
|
||||
except (socket.gaierror, socket.timeout, ssl.SSLError, ValueError) as e:
|
||||
return {
|
||||
"error": str(e),
|
||||
"status": "failed"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"error": f"Unexpected error: {str(e)}",
|
||||
"status": "failed"
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString
|
||||
import json
|
||||
@@ -6,7 +7,6 @@ import html
|
||||
import re
|
||||
import os
|
||||
import platform
|
||||
from .html2text import HTML2Text
|
||||
from .prompts import PROMPT_EXTRACT_BLOCKS
|
||||
from .config import *
|
||||
from pathlib import Path
|
||||
@@ -14,7 +14,6 @@ from typing import Dict, Any
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
from requests.exceptions import InvalidSchema
|
||||
import hashlib
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
import xxhash
|
||||
from colorama import Fore, Style, init
|
||||
@@ -1110,21 +1109,52 @@ def normalize_url_tmp(href, base_url):
|
||||
|
||||
return href.strip()
|
||||
|
||||
def is_external_url(url, base_domain):
|
||||
"""Determine if a URL is external"""
|
||||
special_protocols = {'mailto:', 'tel:', 'ftp:', 'file:', 'data:', 'javascript:'}
|
||||
if any(url.lower().startswith(proto) for proto in special_protocols):
|
||||
def get_base_domain(url: str) -> str:
|
||||
"""Extract base domain from URL, handling various edge cases."""
|
||||
try:
|
||||
# Get domain from URL
|
||||
domain = urlparse(url).netloc.lower()
|
||||
if not domain:
|
||||
return ""
|
||||
|
||||
# Remove port if present
|
||||
domain = domain.split(':')[0]
|
||||
|
||||
# Remove www
|
||||
domain = re.sub(r'^www\.', '', domain)
|
||||
|
||||
# Extract last two parts of domain (handles co.uk etc)
|
||||
parts = domain.split('.')
|
||||
if len(parts) > 2 and parts[-2] in {
|
||||
'co', 'com', 'org', 'gov', 'edu', 'net',
|
||||
'mil', 'int', 'ac', 'ad', 'ae', 'af', 'ag'
|
||||
}:
|
||||
return '.'.join(parts[-3:])
|
||||
|
||||
return '.'.join(parts[-2:])
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def is_external_url(url: str, base_domain: str) -> bool:
|
||||
"""Check if URL is external to base domain."""
|
||||
special = {'mailto:', 'tel:', 'ftp:', 'file:', 'data:', 'javascript:'}
|
||||
if any(url.lower().startswith(p) for p in special):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Handle URLs with protocol
|
||||
if url.startswith(('http://', 'https://')):
|
||||
url_domain = url.split('/')[2]
|
||||
return base_domain.lower() not in url_domain.lower()
|
||||
except IndexError:
|
||||
return False
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc: # Relative URL
|
||||
return False
|
||||
|
||||
# Strip 'www.' from both domains for comparison
|
||||
url_domain = parsed.netloc.lower().replace('www.', '')
|
||||
base = base_domain.lower().replace('www.', '')
|
||||
|
||||
return False
|
||||
# Check if URL domain ends with base domain
|
||||
return not url_domain.endswith(base)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def clean_tokens(tokens: list[str]) -> list[str]:
|
||||
# Set of tokens to remove
|
||||
@@ -1289,4 +1319,7 @@ def get_error_context(exc_info, context_lines: int = 5):
|
||||
"line_no": line_no,
|
||||
"function": func_name,
|
||||
"code_context": code_context
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user