[v0.3.71] Enhance chunking strategies and improve overall performance

- Add OverlappingWindowChunking and improve SlidingWindowChunking
- Update CHUNK_TOKEN_THRESHOLD to 2048 tokens
- Optimize AsyncPlaywrightCrawlerStrategy close method
- Enhance flexibility in CosineStrategy with generic embedding model loading
- Improve JSON-based extraction strategies
- Add knowledge graph generation example
Author: UncleCode
Date:   2024-10-19 18:36:59 +08:00
parent  b309bc34e1
commit  4e2852d5ff

7 changed files with 118 additions and 18 deletions

.gitignore (vendored)
@@ -206,3 +206,4 @@ git_issues.py
 git_issues.md
 .tests/
+.issues/

@@ -134,7 +134,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
     async def close(self):
         if self.sleep_on_close:
-            await asyncio.sleep(500)
+            await asyncio.sleep(0.5)
         if self.browser:
             await self.browser.close()
             self.browser = None
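Note: asyncio.sleep takes seconds, so the old call paused shutdown for over eight minutes whenever sleep_on_close was set; the new value waits half a second. A minimal sketch of exercising the flag (the crawler_strategy parameter and the import path are assumptions based on the surrounding codebase, not shown in this diff):

import asyncio
from crawl4ai import AsyncWebCrawler
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy

async def demo():
    # With the fix, close() sleeps 0.5 s (not 500 s) before shutting the browser down.
    strategy = AsyncPlaywrightCrawlerStrategy(sleep_on_close=True)
    async with AsyncWebCrawler(crawler_strategy=strategy) as crawler:
        result = await crawler.arun(url="https://example.com")
        print(len(result.html))

asyncio.run(demo())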

@@ -84,6 +84,12 @@ class TopicSegmentationChunking(ChunkingStrategy):
 # Fixed-length word chunks
 class FixedLengthWordChunking(ChunkingStrategy):
     def __init__(self, chunk_size=100, **kwargs):
+        """
+        Initialize the fixed-length word chunking strategy with the given chunk size.
+
+        Args:
+            chunk_size (int): The size of each chunk in words.
+        """
         self.chunk_size = chunk_size

     def chunk(self, text: str) -> list:
@@ -93,14 +99,64 @@ class FixedLengthWordChunking(ChunkingStrategy):
 # Sliding window chunking
 class SlidingWindowChunking(ChunkingStrategy):
     def __init__(self, window_size=100, step=50, **kwargs):
+        """
+        Initialize the sliding window chunking strategy with the given window size and
+        step size.
+
+        Args:
+            window_size (int): The size of the sliding window in words.
+            step (int): The step size for sliding the window in words.
+        """
         self.window_size = window_size
         self.step = step

     def chunk(self, text: str) -> list:
         words = text.split()
         chunks = []
-        for i in range(0, len(words), self.step):
-            chunks.append(' '.join(words[i:i + self.window_size]))
+
+        if len(words) <= self.window_size:
+            return [text]
+
+        for i in range(0, len(words) - self.window_size + 1, self.step):
+            chunk = ' '.join(words[i:i + self.window_size])
+            chunks.append(chunk)
+
+        # Handle the last chunk if it doesn't align perfectly
+        if i + self.window_size < len(words):
+            chunks.append(' '.join(words[-self.window_size:]))
+
         return chunks
+
+class OverlappingWindowChunking(ChunkingStrategy):
+    def __init__(self, window_size=1000, overlap=100, **kwargs):
+        """
+        Initialize the overlapping window chunking strategy with the given window size and
+        overlap size.
+
+        Args:
+            window_size (int): The size of the window in words.
+            overlap (int): The size of the overlap between consecutive chunks in words.
+        """
+        self.window_size = window_size
+        self.overlap = overlap
+
+    def chunk(self, text: str) -> list:
+        words = text.split()
+        chunks = []
+
+        if len(words) <= self.window_size:
+            return [text]
+
+        start = 0
+        while start < len(words):
+            end = start + self.window_size
+            chunk = ' '.join(words[start:end])
+            chunks.append(chunk)
+
+            if end >= len(words):
+                break
+
+            start = end - self.overlap
+
+        return chunks
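Note: the two strategies parameterize overlap differently. SlidingWindowChunking advances by a fixed step (so the overlap is window_size - step), while OverlappingWindowChunking specifies the overlap directly and uses a much larger default window (1000 words vs. 100). Both now return the text as a single chunk when it fits in one window. A quick sketch of the chunk() implementations added above (small numbers chosen for illustration; the import path assumes the library's chunking_strategy module):

from crawl4ai.chunking_strategy import SlidingWindowChunking, OverlappingWindowChunking

text = " ".join(f"w{i}" for i in range(23))  # 23 dummy words

slider = SlidingWindowChunking(window_size=10, step=5)
# Windows start at words 0, 5, and 10; since the last window ends before
# word 22, the final if-branch appends the last 10 words as a tail chunk.
print(slider.chunk(text))

overlapper = OverlappingWindowChunking(window_size=10, overlap=3)
# Each chunk re-reads the last 3 words of the previous one:
# words 0-9, then 7-16, then 14-22.
print(overlapper.chunk(text))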

@@ -21,7 +21,7 @@ PROVIDER_MODELS = {

 # Chunk token threshold
-CHUNK_TOKEN_THRESHOLD = 500
+CHUNK_TOKEN_THRESHOLD = 2 ** 11  # 2048 tokens
 OVERLAP_RATE = 0.1
 WORD_TOKEN_RATE = 1.3
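Note: 2 ** 11 = 2048, roughly a 4x increase over the old 500-token threshold, so extraction works on fewer, larger chunks. A rough sketch of how these constants relate, assuming WORD_TOKEN_RATE is the estimated number of tokens per word (an inference from the name, not spelled out in this diff):

CHUNK_TOKEN_THRESHOLD = 2 ** 11  # 2048 tokens
WORD_TOKEN_RATE = 1.3
OVERLAP_RATE = 0.1

max_words_per_chunk = int(CHUNK_TOKEN_THRESHOLD / WORD_TOKEN_RATE)  # 1575 words
overlap_words = int(max_words_per_chunk * OVERLAP_RATE)             # 157 words
print(max_words_per_chunk, overlap_words)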

@@ -234,11 +234,12 @@ class CosineStrategy(ExtractionStrategy):
         """
         Initialize the strategy with clustering parameters.
-        :param semantic_filter: A keyword filter for document filtering.
-        :param word_count_threshold: Minimum number of words per cluster.
-        :param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
-        :param linkage_method: The linkage method for hierarchical clustering.
-        :param top_k: Number of top categories to extract.
+
+        Args:
+            semantic_filter (str): A keyword filter for document filtering.
+            word_count_threshold (int): Minimum number of words per cluster.
+            max_dist (float): The maximum cophenetic distance on the dendrogram to form clusters.
+            linkage_method (str): The linkage method for hierarchical clustering.
+            top_k (int): Number of top categories to extract.
         """
         super().__init__()
@@ -257,8 +258,8 @@ class CosineStrategy(ExtractionStrategy):
             self.get_embedding_method = "direct"

         self.device = get_device()
-        import torch
-        self.device = torch.device('cpu')
+        # import torch
+        # self.device = torch.device('cpu')

         self.default_batch_size = calculate_batch_size(self.device)
@@ -271,7 +272,7 @@ class CosineStrategy(ExtractionStrategy):
             # self.get_embedding_method = "direct"
             # else:
-            self.tokenizer, self.model = load_bge_small_en_v1_5()
+            self.tokenizer, self.model = load_HF_embedding_model(model_name)
             self.model.to(self.device)
             self.model.eval()
@@ -738,7 +739,6 @@ class JsonCssExtractionStrategy(ExtractionStrategy):
         combined_html = self.DEL.join(sections)
         return self.extract(url, combined_html, **kwargs)

-
 class JsonXPATHExtractionStrategy(ExtractionStrategy):
     def __init__(self, schema: Dict[str, Any], **kwargs):
         super().__init__(**kwargs)
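Note: with load_HF_embedding_model(model_name) replacing the hard-coded load_bge_small_en_v1_5(), and the device no longer forced to CPU, CosineStrategy is no longer tied to one embedding model. A hedged usage sketch (the model_name constructor argument is inferred from the call above; the other parameters come from the revised docstring):

from crawl4ai.extraction_strategy import CosineStrategy

strategy = CosineStrategy(
    semantic_filter="technology, startups",  # keyword filter for documents
    word_count_threshold=10,                 # minimum words per cluster
    max_dist=0.2,                            # dendrogram cut distance
    linkage_method="ward",                   # hierarchical clustering linkage
    top_k=3,                                 # top categories to extract
    model_name="BAAI/bge-small-en-v1.5",     # inferred parameter; any HF embedding model
)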

@@ -72,10 +72,18 @@ def load_bert_base_uncased():
     return tokenizer, model

 @lru_cache()
-def load_bge_small_en_v1_5():
+def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
+    """Load the Hugging Face model for embedding.
+
+    Args:
+        model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5".
+
+    Returns:
+        tuple: The tokenizer and model.
+    """
     from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
-    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
-    model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
+    model = AutoModel.from_pretrained(model_name, resume_download=None)
     model.eval()
     model, device = set_model_device(model)
     return tokenizer, model
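Note: the renamed loader keeps the old checkpoint as its default, so existing callers behave identically, and @lru_cache means each distinct model_name is downloaded and initialized only once per process. For example (the alternative checkpoint below is only an illustration):

from crawl4ai.model_loader import load_HF_embedding_model

tokenizer, model = load_HF_embedding_model()  # default: BAAI/bge-small-en-v1.5
tokenizer, model = load_HF_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
tokenizer, model = load_HF_embedding_model()  # cached; no reload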

@@ -10,7 +10,7 @@ import time
 import json
 import os
 import re
-from typing import Dict
+from typing import Dict, List
 from bs4 import BeautifulSoup
 from pydantic import BaseModel, Field
 from crawl4ai import AsyncWebCrawler
@@ -456,6 +456,41 @@ async def speed_comparison():
     print("If you run these tests in an environment with better network conditions,")
     print("you may observe an even more significant speed advantage for Crawl4AI.")

+async def generate_knowledge_graph():
+    class Entity(BaseModel):
+        name: str
+        description: str
+
+    class Relationship(BaseModel):
+        entity1: Entity
+        entity2: Entity
+        description: str
+        relation_type: str
+
+    class KnowledgeGraph(BaseModel):
+        entities: List[Entity]
+        relationships: List[Relationship]
+
+    extraction_strategy = LLMExtractionStrategy(
+        provider='openai/gpt-4o-mini',
+        api_token=os.getenv('OPENAI_API_KEY'),
+        schema=KnowledgeGraph.model_json_schema(),
+        extraction_type="schema",
+        instruction="""Extract entities and relationships from the given text."""
+    )
+
+    async with AsyncWebCrawler() as crawler:
+        url = "https://paulgraham.com/love.html"
+        result = await crawler.arun(
+            url=url,
+            bypass_cache=True,
+            extraction_strategy=extraction_strategy,
+            # magic=True
+        )
+        # print(result.extracted_content)
+        with open(os.path.join(__location__, "kb.json"), "w") as f:
+            f.write(result.extracted_content)
+
 async def main():
     await simple_crawl()
     await simple_example_with_running_js_code()
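Note: generate_knowledge_graph() is defined here but not shown being added to main() in this hunk. A sketch of trying it on its own from quickstart_async.py (it needs OPENAI_API_KEY set in the environment, and it relies on __location__ and LLMExtractionStrategy already being defined earlier in the script):

import asyncio

# Writes the extracted entities and relationships to kb.json next to the script.
asyncio.run(generate_knowledge_graph())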