feat: add comprehensive type definitions and improve test coverage
Add new type definitions file with extensive Union type aliases for all core components including AsyncUrlSeeder, SeedingConfig, and various crawler strategies. Enhance test coverage with improved bot detection tests, Docker-based testing, and extended features validation. The changes provide better type safety and more robust testing infrastructure for the crawling framework.
@@ -779,6 +779,144 @@ async def test_stream_crawl(token: str = None): # Made token optional
# asyncio.run(test_stream_crawl())
```

#### LLM Job with Chunking Strategy

```python
import requests
import time

# Example: LLM extraction with RegexChunking strategy
# This breaks large documents into smaller chunks before LLM processing

llm_job_payload = {
    "url": "https://example.com/long-article",
    "q": "Extract all key points and main ideas from this article",
    "chunking_strategy": {
        "type": "RegexChunking",
        "params": {
            "patterns": ["\\n\\n"],  # Split on double newlines (paragraphs)
            "overlap": 50
        }
    }
}

# Submit LLM job
response = requests.post(
    "http://localhost:11235/llm/job",
    json=llm_job_payload
)

if response.ok:
    job_data = response.json()
    job_id = job_data["task_id"]
    print(f"Job submitted successfully. Job ID: {job_id}")

    # Poll for completion
    while True:
        status_response = requests.get(f"http://localhost:11235/llm/job/{job_id}")
        if status_response.ok:
            status_data = status_response.json()
            if status_data["status"] == "completed":
                print("Job completed!")
                print("Extracted content:", status_data["result"])
                break
            elif status_data["status"] == "failed":
                print("Job failed:", status_data.get("error"))
                break
            else:
                print(f"Job status: {status_data['status']}")
                time.sleep(2)  # Wait 2 seconds before checking again
        else:
            print(f"Error checking job status: {status_response.text}")
            break
else:
    print(f"Error submitting job: {response.text}")
```

**Available Chunking Strategies:**

- **IdentityChunking**: Returns the entire content as a single chunk (no splitting)
```json
{
  "type": "IdentityChunking",
  "params": {}
}
```

- **RegexChunking**: Split content using regular expression patterns
```json
{
  "type": "RegexChunking",
  "params": {
    "patterns": ["\\n\\n"]
  }
}
```

- **NlpSentenceChunking**: Split content into sentences using NLP (requires NLTK)
```json
{
  "type": "NlpSentenceChunking",
  "params": {}
}
```

- **TopicSegmentationChunking**: Segment content into topics using TextTiling (requires NLTK)
```json
{
  "type": "TopicSegmentationChunking",
  "params": {
    "num_keywords": 3
  }
}
```

- **FixedLengthWordChunking**: Split into fixed-length word chunks
```json
{
  "type": "FixedLengthWordChunking",
  "params": {
    "chunk_size": 100
  }
}
```

- **SlidingWindowChunking**: Overlapping word chunks with configurable step size
```json
{
  "type": "SlidingWindowChunking",
  "params": {
    "window_size": 100,
    "step": 50
  }
}
```

- **OverlappingWindowChunking**: Fixed-size chunks with word overlap
```json
{
  "type": "OverlappingWindowChunking",
  "params": {
    "window_size": 1000,
    "overlap": 100
  }
}
```
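
If you want to preview how a given configuration will split your content before submitting a job, the same strategy classes can be instantiated locally from `crawl4ai.chunking_strategy` (a minimal sketch using sample text rather than a real crawl result; it assumes `crawl4ai` is installed in your client environment, and the constructor arguments mirror the `params` objects above):

```python
from crawl4ai.chunking_strategy import FixedLengthWordChunking, RegexChunking

# Synthetic document with repeated paragraphs, just for a quick preview
sample_text = "First paragraph about the topic.\n\nSecond paragraph with more detail.\n\n" * 20

# Same parameter names you would send in the job payload's "params" object
for strategy in (RegexChunking(patterns=["\\n\\n"]), FixedLengthWordChunking(chunk_size=100)):
    chunks = strategy.chunk(sample_text)
    print(type(strategy).__name__, "->", len(chunks), "chunks")
```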
**Notes:**
- `chunking_strategy` is optional - if omitted, default token-based chunking is used
- Chunking is applied at the API level without modifying the core SDK
- Results from all chunks are merged into a single response
- Each chunk is processed independently with the same LLM instruction
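
Because each chunk is extracted independently and the results are merged server-side, a completed job's `result` field can be consumed as one collection regardless of chunk count. A minimal sketch (assuming a job submitted as in the example above has already finished; whether `result` arrives as a JSON string or an already-decoded list may vary, so both cases are handled):

```python
import json
import requests

job_id = "YOUR_TASK_ID"  # task_id returned by the POST /llm/job call above

status_data = requests.get(f"http://localhost:11235/llm/job/{job_id}").json()
if status_data["status"] == "completed":
    result = status_data["result"]
    # Merged chunk extractions arrive as a single JSON list
    merged = json.loads(result) if isinstance(result, str) else result
    for i, item in enumerate(merged if isinstance(merged, list) else [merged]):
        print(f"--- extraction {i + 1} ---")
        print(item)
```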

---

## Metrics & Monitoring

@@ -60,7 +60,7 @@ try:
    from utils import (
        FilterType, TaskStatus, get_base_url, is_task_id,
        get_llm_api_key, get_llm_temperature, get_llm_base_url,
        validate_llm_provider
        validate_llm_provider, create_chunking_strategy
    )
except ImportError:
    # Fallback definitions for development/testing
@@ -249,6 +249,7 @@ async def process_llm_extraction(
    provider: Optional[str] = None,
    temperature: Optional[float] = None,
    base_url: Optional[str] = None,
    chunking_strategy_config: Optional[dict] = None,
) -> None:
    """Process LLM extraction in background."""
    try:
@@ -263,44 +264,145 @@ async def process_llm_extraction(
        api_key = get_llm_api_key(
            config, provider
        )  # Returns None to let litellm handle it
        llm_strategy = LLMExtractionStrategy(
            llm_config=LLMConfig(

        cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY

        if chunking_strategy_config:
            # API-level chunking approach: crawl first, then chunk, then extract
            try:
                chunking_strategy = create_chunking_strategy(chunking_strategy_config)
            except ValueError as e:
                await redis.hset(
                    f"task:{task_id}",
                    mapping={"status": TaskStatus.FAILED, "error": f"Invalid chunking strategy: {str(e)}"},
                )
                return

            # Step 1: Crawl the URL to get raw content
            async with AsyncWebCrawler() as crawler:
                crawl_result = await crawler.arun(
                    url=url,
                    config=CrawlerRunConfig(
                        extraction_strategy=NoExtractionStrategy(),
                        scraping_strategy=LXMLWebScrapingStrategy(),
                        cache_mode=cache_mode,
                    ),
                )

            if not crawl_result.success:
                await redis.hset(
                    f"task:{task_id}",
                    mapping={"status": TaskStatus.FAILED, "error": crawl_result.error_message},
                )
                return

            # Step 2: Apply chunking to the raw content
            raw_content = crawl_result.markdown_v2.raw_markdown if hasattr(crawl_result, 'markdown_v2') else crawl_result.markdown
            if not raw_content:
                await redis.hset(
                    f"task:{task_id}",
                    mapping={"status": TaskStatus.FAILED, "error": "No content extracted from URL"},
                )
                return

            chunks = chunking_strategy.chunk(raw_content)
            # Filter out empty chunks
            chunks = [chunk for chunk in chunks if chunk.strip()]

            if not chunks:
                await redis.hset(
                    f"task:{task_id}",
                    mapping={"status": TaskStatus.FAILED, "error": "No valid chunks after applying chunking strategy"},
                )
                return

            # Step 3: Process each chunk with LLM extraction
            llm_config = LLMConfig(
                provider=provider or config["llm"]["provider"],
                api_token=api_key,
                temperature=temperature or get_llm_temperature(config, provider),
                base_url=base_url or get_llm_base_url(config, provider),
            ),
                instruction=instruction,
                schema=json.loads(schema) if schema else None,
            )

            cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY

            async with AsyncWebCrawler() as crawler:
                result = await crawler.arun(
                    url=url,
                    config=CrawlerRunConfig(
                        extraction_strategy=llm_strategy,
                        scraping_strategy=LXMLWebScrapingStrategy(),
                        cache_mode=cache_mode,
                    ),
                )

            if not result.success:
            all_results = []
            for i, chunk in enumerate(chunks):
                try:
                    # Create LLM strategy for this chunk
                    chunk_instruction = f"{instruction}\n\nContent chunk {i+1}/{len(chunks)}:\n{chunk}"
                    llm_strategy = LLMExtractionStrategy(
                        llm_config=llm_config,
                        instruction=chunk_instruction,
                        schema=json.loads(schema) if schema else None,
                    )

                    # Extract from this chunk
                    async with AsyncWebCrawler() as crawler:
                        chunk_result = await crawler.arun(
                            url=url,
                            config=CrawlerRunConfig(
                                extraction_strategy=llm_strategy,
                                scraping_strategy=LXMLWebScrapingStrategy(),
                                cache_mode=cache_mode,
                            ),
                        )

                    if chunk_result.success:
                        try:
                            chunk_content = json.loads(chunk_result.extracted_content)
                            all_results.extend(chunk_content if isinstance(chunk_content, list) else [chunk_content])
                        except json.JSONDecodeError:
                            all_results.append(chunk_result.extracted_content)
                    # Continue with other chunks even if one fails

                except Exception as chunk_error:
                    # Log chunk error but continue with other chunks
                    print(f"Error processing chunk {i+1}: {chunk_error}")
                    continue

            # Step 4: Store merged results
            await redis.hset(
                f"task:{task_id}",
                mapping={"status": TaskStatus.FAILED, "error": result.error_message},
                mapping={"status": TaskStatus.COMPLETED, "result": json.dumps(all_results)},
            )
            return

            try:
                content = json.loads(result.extracted_content)
            except json.JSONDecodeError:
                content = result.extracted_content
            await redis.hset(
                f"task:{task_id}",
                mapping={"status": TaskStatus.COMPLETED, "result": json.dumps(content)},
            )
        else:
            # Original approach: direct LLM extraction without chunking
            llm_strategy = LLMExtractionStrategy(
                llm_config=LLMConfig(
                    provider=provider or config["llm"]["provider"],
                    api_token=api_key,
                    temperature=temperature or get_llm_temperature(config, provider),
                    base_url=base_url or get_llm_base_url(config, provider),
                ),
                instruction=instruction,
                schema=json.loads(schema) if schema else None,
            )

            async with AsyncWebCrawler() as crawler:
                result = await crawler.arun(
                    url=url,
                    config=CrawlerRunConfig(
                        extraction_strategy=llm_strategy,
                        scraping_strategy=LXMLWebScrapingStrategy(),
                        cache_mode=cache_mode,
                    ),
                )

            if not result.success:
                await redis.hset(
                    f"task:{task_id}",
                    mapping={"status": TaskStatus.FAILED, "error": result.error_message},
                )
                return

            try:
                content = json.loads(result.extracted_content)
            except json.JSONDecodeError:
                content = result.extracted_content
            await redis.hset(
                f"task:{task_id}",
                mapping={"status": TaskStatus.COMPLETED, "result": json.dumps(content)},
            )

    except Exception as e:
        logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
@@ -398,6 +500,7 @@ async def handle_llm_request(
    provider: Optional[str] = None,
    temperature: Optional[float] = None,
    api_base_url: Optional[str] = None,
    chunking_strategy_config: Optional[dict] = None,
) -> JSONResponse:
    """Handle LLM extraction requests."""
    base_url = get_base_url(request)
@@ -431,6 +534,7 @@ async def handle_llm_request(
        provider,
        temperature,
        api_base_url,
        chunking_strategy_config,
    )

    except Exception as e:
@@ -473,6 +577,7 @@ async def create_new_task(
    provider: Optional[str] = None,
    temperature: Optional[float] = None,
    api_base_url: Optional[str] = None,
    chunking_strategy_config: Optional[dict] = None,
) -> JSONResponse:
    """Create and initialize a new task."""
    decoded_url = unquote(input_path)
@@ -506,6 +611,7 @@ async def create_new_task(
        provider,
        temperature,
        api_base_url,
        chunking_strategy_config,
    )

    return JSONResponse(
@@ -982,3 +1088,26 @@ async def handle_seed(url, cfg):
            "count": 0,
            "message": "No URLs found for the given domain and configuration.",
        }


async def handle_url_discovery(domain, seeding_config):
    """
    Handle URL discovery using AsyncUrlSeeder functionality.

    Args:
        domain (str): Domain to discover URLs from
        seeding_config (dict): Configuration for URL discovery

    Returns:
        List[Dict[str, Any]]: Discovered URL objects with metadata
    """
    try:
        config = SeedingConfig(**seeding_config)

        # Use an async context manager for the seeder
        async with AsyncUrlSeeder() as seeder:
            # The seeder's 'urls' method expects a domain
            urls = await seeder.urls(domain, config)
            return urls
    except Exception as e:
        return []

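For comparison, the discovery that this handler wraps can be sketched directly against the crawl4ai seeder API (a minimal, hypothetical script; it assumes `AsyncUrlSeeder` and `SeedingConfig` are importable from the installed `crawl4ai` package and that each discovered entry exposes `url`/`status` keys, as in the endpoint's example response further below):

```python
import asyncio

from crawl4ai import AsyncUrlSeeder, SeedingConfig  # assumed import path


async def main():
    # Mirrors the seeding_config dict accepted by the /urls/discover endpoint
    config = SeedingConfig(source="sitemap", pattern="*/docs/*", max_urls=10)
    async with AsyncUrlSeeder() as seeder:
        urls = await seeder.urls("docs.crawl4ai.com", config)
    for entry in urls:
        print(entry.get("url"), entry.get("status"))


asyncio.run(main())
```
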
@@ -39,6 +39,7 @@ class LlmJobPayload(BaseModel):
    provider: Optional[str] = None
    temperature: Optional[float] = None
    base_url: Optional[str] = None
    chunking_strategy: Optional[Dict] = None


class CrawlJobPayload(BaseModel):
@@ -67,6 +68,7 @@ async def llm_job_enqueue(
        provider=payload.provider,
        temperature=payload.temperature,
        api_base_url=payload.base_url,
        chunking_strategy_config=payload.chunking_strategy,
    )

@@ -174,6 +174,31 @@ class SeedRequest(BaseModel):
    config: Dict[str, Any] = Field(default_factory=dict)


class URLDiscoveryRequest(BaseModel):
    """Request model for URL discovery endpoint."""

    domain: str = Field(..., example="docs.crawl4ai.com", description="Domain to discover URLs from")
    seeding_config: Dict[str, Any] = Field(
        default_factory=dict,
        description="Configuration for URL discovery using AsyncUrlSeeder",
        example={
            "source": "sitemap+cc",
            "pattern": "*",
            "live_check": False,
            "extract_head": False,
            "max_urls": -1,
            "concurrency": 1000,
            "hits_per_sec": 5,
            "force": False,
            "verbose": False,
            "query": None,
            "score_threshold": None,
            "scoring_method": "bm25",
            "filter_nonsense_urls": True
        }
    )


# --- C4A Script Schemas ---

@@ -26,6 +26,7 @@ from api import (
    handle_markdown_request,
    handle_seed,
    handle_stream_crawl_request,
    handle_url_discovery,
    stream_results,
)
from auth import TokenRequest, create_access_token, get_token_dependency
@@ -58,6 +59,7 @@ from schemas import (
    RawCode,
    ScreenshotRequest,
    SeedRequest,
    URLDiscoveryRequest,
)
from slowapi import Limiter
from slowapi.util import get_remote_address
@@ -437,6 +439,97 @@ async def seed_url(request: SeedRequest):
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/urls/discover",
    summary="URL Discovery and Seeding",
    description="Discover and extract crawlable URLs from a domain using AsyncUrlSeeder functionality.",
    response_description="List of discovered URL objects with metadata",
    tags=["Core Crawling"]
)
async def discover_urls(request: URLDiscoveryRequest):
    """
    Discover URLs from a domain using AsyncUrlSeeder functionality.

    This endpoint allows users to find relevant URLs from a domain before
    committing to a full crawl. It supports various discovery sources like
    sitemaps and Common Crawl, with filtering and scoring capabilities.

    **Parameters:**
    - **domain**: Domain to discover URLs from (e.g., "example.com")
    - **seeding_config**: Configuration object mirroring SeedingConfig parameters
      - **source**: Discovery source(s) - "sitemap", "cc", or "sitemap+cc" (default: "sitemap+cc")
      - **pattern**: URL pattern filter using glob-style wildcards (default: "*")
      - **live_check**: Whether to verify URL liveness with HEAD requests (default: false)
      - **extract_head**: Whether to fetch and parse <head> metadata (default: false)
      - **max_urls**: Maximum URLs to discover, -1 for no limit (default: -1)
      - **concurrency**: Maximum concurrent requests (default: 1000)
      - **hits_per_sec**: Rate limit in requests per second (default: 5)
      - **force**: Bypass internal cache and re-fetch URLs (default: false)
      - **query**: Search query for BM25 relevance scoring (optional)
      - **scoring_method**: Scoring method when query provided (default: "bm25")
      - **score_threshold**: Minimum score threshold for filtering (optional)
      - **filter_nonsense_urls**: Filter out nonsense URLs (default: true)

    **Example Request:**
    ```json
    {
        "domain": "docs.crawl4ai.com",
        "seeding_config": {
            "source": "sitemap",
            "pattern": "*/docs/*",
            "extract_head": true,
            "max_urls": 50,
            "query": "API documentation"
        }
    }
    ```

    **Example Response:**
    ```json
    [
        {
            "url": "https://docs.crawl4ai.com/api/getting-started",
            "status": "valid",
            "head_data": {
                "title": "Getting Started - Crawl4AI API",
                "description": "Learn how to get started with Crawl4AI API"
            },
            "score": 0.85
        }
    ]
    ```

    **Usage:**
    ```python
    response = requests.post(
        "http://localhost:11235/urls/discover",
        headers={"Authorization": f"Bearer {token}"},
        json={
            "domain": "docs.crawl4ai.com",
            "seeding_config": {
                "source": "sitemap+cc",
                "extract_head": True,
                "max_urls": 100
            }
        }
    )
    urls = response.json()
    ```

    **Notes:**
    - Returns direct list of URL objects with metadata if requested
    - Empty list returned if no URLs found
    - Supports BM25 relevance scoring when query is provided
    - Can combine multiple sources for maximum coverage
    """
    try:
        res = await handle_url_discovery(request.domain, request.seeding_config)
        return JSONResponse(res)

    except Exception as e:
        print(f"❌ Error in discover_urls: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/md",
    summary="Extract Markdown",
    description="Extract clean markdown content from a URL or raw HTML.",

@@ -6,7 +6,26 @@ from datetime import datetime
from enum import Enum
from pathlib import Path
from fastapi import Request
from typing import Dict, Optional
from typing import Dict, Optional, Any

# Import dispatchers from crawl4ai
from crawl4ai.async_dispatcher import (
    BaseDispatcher,
    MemoryAdaptiveDispatcher,
    SemaphoreDispatcher,
)

# Import chunking strategies from crawl4ai
from crawl4ai.chunking_strategy import (
    ChunkingStrategy,
    IdentityChunking,
    RegexChunking,
    NlpSentenceChunking,
    TopicSegmentationChunking,
    FixedLengthWordChunking,
    SlidingWindowChunking,
    OverlappingWindowChunking,
)

# Import dispatchers from crawl4ai
from crawl4ai.async_dispatcher import (
@@ -303,4 +322,55 @@ def verify_email_domain(email: str) -> bool:
        records = dns.resolver.resolve(domain, 'MX')
        return True if records else False
    except Exception as e:
        return False
        return False


def create_chunking_strategy(config: Optional[Dict[str, Any]] = None) -> Optional[ChunkingStrategy]:
    """
    Factory function to create chunking strategy instances from configuration.

    Args:
        config: Dictionary containing 'type' and 'params' keys
            Example: {"type": "RegexChunking", "params": {"patterns": ["\\n\\n+"]}}

    Returns:
        ChunkingStrategy instance or None if config is None

    Raises:
        ValueError: If chunking strategy type is unknown or config is invalid
    """
    if config is None:
        return None

    if not isinstance(config, dict):
        raise ValueError(f"Chunking strategy config must be a dictionary, got {type(config)}")

    if "type" not in config:
        raise ValueError("Chunking strategy config must contain 'type' field")

    strategy_type = config["type"]
    params = config.get("params", {})

    # Validate params is a dict
    if not isinstance(params, dict):
        raise ValueError(f"Chunking strategy params must be a dictionary, got {type(params)}")

    # Strategy factory mapping
    strategies = {
        "IdentityChunking": IdentityChunking,
        "RegexChunking": RegexChunking,
        "NlpSentenceChunking": NlpSentenceChunking,
        "TopicSegmentationChunking": TopicSegmentationChunking,
        "FixedLengthWordChunking": FixedLengthWordChunking,
        "SlidingWindowChunking": SlidingWindowChunking,
        "OverlappingWindowChunking": OverlappingWindowChunking,
    }

    if strategy_type not in strategies:
        available = ", ".join(strategies.keys())
        raise ValueError(f"Unknown chunking strategy type: {strategy_type}. Available: {available}")

    try:
        return strategies[strategy_type](**params)
    except Exception as e:
        raise ValueError(f"Failed to create {strategy_type} with params {params}: {str(e)}")
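
As a rough sanity check of the factory above, the following sketch is one way to exercise it (hypothetical usage; it assumes the script runs where this `utils` module and `crawl4ai` are importable, and uses synthetic text rather than crawled content):

```python
from utils import create_chunking_strategy

# Same {"type": ..., "params": ...} shape accepted in the job payload
strategy = create_chunking_strategy(
    {"type": "FixedLengthWordChunking", "params": {"chunk_size": 100}}
)

sample = "word " * 250
chunks = strategy.chunk(sample)
print(len(chunks), "chunks")  # roughly three 100-word chunks

# Unknown strategy types or bad params surface as ValueError
try:
    create_chunking_strategy({"type": "NotARealStrategy"})
except ValueError as err:
    print(err)
```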