[v0.3.71] Enhance chunking strategies and improve overall performance

- Add OverlappingWindowChunking and improve SlidingWindowChunking
- Update CHUNK_TOKEN_THRESHOLD to 2048 tokens
- Optimize AsyncPlaywrightCrawlerStrategy close method
- Enhance flexibility in CosineStrategy with generic embedding model loading
- Improve JSON-based extraction strategies
- Add knowledge graph generation example
This commit is contained in:
UncleCode
2024-10-19 18:36:59 +08:00
parent b309bc34e1
commit 4e2852d5ff
7 changed files with 118 additions and 18 deletions

3
.gitignore vendored
View File

@@ -205,4 +205,5 @@ pypi_build.sh
git_issues.py
git_issues.md
.tests/
.tests/
.issues/

View File

@@ -134,7 +134,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
async def close(self):
if self.sleep_on_close:
await asyncio.sleep(500)
await asyncio.sleep(0.5)
if self.browser:
await self.browser.close()
self.browser = None

View File

@@ -84,6 +84,12 @@ class TopicSegmentationChunking(ChunkingStrategy):
# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
def __init__(self, chunk_size=100, **kwargs):
"""
Initialize the fixed-length word chunking strategy with the given chunk size.
Args:
chunk_size (int): The size of each chunk in words.
"""
self.chunk_size = chunk_size
def chunk(self, text: str) -> list:
@@ -93,14 +99,64 @@ class FixedLengthWordChunking(ChunkingStrategy):
# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
def __init__(self, window_size=100, step=50, **kwargs):
"""
Initialize the sliding window chunking strategy with the given window size and
step size.
Args:
window_size (int): The size of the sliding window in words.
step (int): The step size for sliding the window in words.
"""
self.window_size = window_size
self.step = step
def chunk(self, text: str) -> list:
words = text.split()
chunks = []
for i in range(0, len(words), self.step):
chunks.append(' '.join(words[i:i + self.window_size]))
if len(words) <= self.window_size:
return [text]
for i in range(0, len(words) - self.window_size + 1, self.step):
chunk = ' '.join(words[i:i + self.window_size])
chunks.append(chunk)
# Handle the last chunk if it doesn't align perfectly
if i + self.window_size < len(words):
chunks.append(' '.join(words[-self.window_size:]))
return chunks
class OverlappingWindowChunking(ChunkingStrategy):
def __init__(self, window_size=1000, overlap=100, **kwargs):
"""
Initialize the overlapping window chunking strategy with the given window size and
overlap size.
Args:
window_size (int): The size of the window in words.
overlap (int): The size of the overlap between consecutive chunks in words.
"""
self.window_size = window_size
self.overlap = overlap
def chunk(self, text: str) -> list:
words = text.split()
chunks = []
if len(words) <= self.window_size:
return [text]
start = 0
while start < len(words):
end = start + self.window_size
chunk = ' '.join(words[start:end])
chunks.append(chunk)
if end >= len(words):
break
start = end - self.overlap
return chunks

View File

@@ -21,7 +21,7 @@ PROVIDER_MODELS = {
# Chunk token threshold
CHUNK_TOKEN_THRESHOLD = 500
CHUNK_TOKEN_THRESHOLD = 2 ** 11 # 2048 tokens
OVERLAP_RATE = 0.1
WORD_TOKEN_RATE = 1.3

View File

@@ -234,11 +234,12 @@ class CosineStrategy(ExtractionStrategy):
"""
Initialize the strategy with clustering parameters.
:param semantic_filter: A keyword filter for document filtering.
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
:param top_k: Number of top categories to extract.
Args:
semantic_filter (str): A keyword filter for document filtering.
word_count_threshold (int): Minimum number of words per cluster.
max_dist (float): The maximum cophenetic distance on the dendrogram to form clusters.
linkage_method (str): The linkage method for hierarchical clustering.
top_k (int): Number of top categories to extract.
"""
super().__init__()
@@ -257,8 +258,8 @@ class CosineStrategy(ExtractionStrategy):
self.get_embedding_method = "direct"
self.device = get_device()
import torch
self.device = torch.device('cpu')
# import torch
# self.device = torch.device('cpu')
self.default_batch_size = calculate_batch_size(self.device)
@@ -271,7 +272,7 @@ class CosineStrategy(ExtractionStrategy):
# self.get_embedding_method = "direct"
# else:
self.tokenizer, self.model = load_bge_small_en_v1_5()
self.tokenizer, self.model = load_HF_embedding_model(model_name)
self.model.to(self.device)
self.model.eval()
@@ -738,7 +739,6 @@ class JsonCssExtractionStrategy(ExtractionStrategy):
combined_html = self.DEL.join(sections)
return self.extract(url, combined_html, **kwargs)
class JsonXPATHExtractionStrategy(ExtractionStrategy):
def __init__(self, schema: Dict[str, Any], **kwargs):
super().__init__(**kwargs)

View File

@@ -72,10 +72,18 @@ def load_bert_base_uncased():
return tokenizer, model
@lru_cache()
def load_bge_small_en_v1_5():
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
"""Load the Hugging Face model for embedding.
Args:
model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5".
Returns:
tuple: The tokenizer and model.
"""
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
model = AutoModel.from_pretrained(model_name, resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model

View File

@@ -10,7 +10,7 @@ import time
import json
import os
import re
from typing import Dict
from typing import Dict, List
from bs4 import BeautifulSoup
from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler
@@ -456,6 +456,41 @@ async def speed_comparison():
print("If you run these tests in an environment with better network conditions,")
print("you may observe an even more significant speed advantage for Crawl4AI.")
async def generate_knowledge_graph():
class Entity(BaseModel):
name: str
description: str
class Relationship(BaseModel):
entity1: Entity
entity2: Entity
description: str
relation_type: str
class KnowledgeGraph(BaseModel):
entities: List[Entity]
relationships: List[Relationship]
extraction_strategy = LLMExtractionStrategy(
provider='openai/gpt-4o-mini',
api_token=os.getenv('OPENAI_API_KEY'),
schema=KnowledgeGraph.model_json_schema(),
extraction_type="schema",
instruction="""Extract entities and relationships from the given text."""
)
async with AsyncWebCrawler() as crawler:
url = "https://paulgraham.com/love.html"
result = await crawler.arun(
url=url,
bypass_cache=True,
extraction_strategy=extraction_strategy,
# magic=True
)
# print(result.extracted_content)
with open(os.path.join(__location__, "kb.json"), "w") as f:
f.write(result.extracted_content)
async def main():
await simple_crawl()
await simple_example_with_running_js_code()