chore: Update crawl4ai package with AsyncWebCrawler and JsonCssExtractionStrategy

This commit is contained in:
unclecode
2024-09-03 23:32:27 +08:00
parent c37614cbc8
commit 2fada16abb
4 changed files with 216 additions and 54 deletions

View File

@@ -1 +1,11 @@
from .web_crawler import WebCrawler
from .web_crawler import WebCrawler
from .async_webcrawler import AsyncWebCrawler
from .models import CrawlResult
# Version string of this package.
__version__ = "0.2.77"
# Public API: the names exported by a wildcard import of this package.
# Each entry corresponds to one of the imports above.
__all__ = [
"WebCrawler",
"AsyncWebCrawler",
"CrawlResult",
]

View File

@@ -44,7 +44,8 @@ class AsyncWebCrawler:
await self.crawler_strategy.__aexit__(exc_type, exc_val, exc_tb)
async def awarmup(self):
print("[LOG] 🌤️ Warming up the AsyncWebCrawler")
if self.verbose:
print("[LOG] 🌤️ Warming up the AsyncWebCrawler")
await async_db_manager.ainit_db()
await self.arun(
url="https://google.com/",
@@ -53,7 +54,8 @@ class AsyncWebCrawler:
verbose=False,
)
self.ready = True
print("[LOG] 🌞 AsyncWebCrawler is ready to crawl")
if self.verbose:
print("[LOG] 🌞 AsyncWebCrawler is ready to crawl")
async def arun(
self,
@@ -215,7 +217,7 @@ class AsyncWebCrawler:
)
# Check if extraction strategy is type of JsonCssExtractionStrategy
if isinstance(extraction_strategy, JsonCssExtractionStrategy) or isinstance(extraction_strategy, EnhancedJsonCssExtractionStrategy):
if isinstance(extraction_strategy, JsonCssExtractionStrategy) or isinstance(extraction_strategy, JsonCssExtractionStrategy):
extraction_strategy.verbose = verbose
extracted_content = extraction_strategy.run(url, [html])
extracted_content = json.dumps(extracted_content, indent=4, default=str)

View File

@@ -10,7 +10,7 @@ from functools import partial
from .model_loader import *
import math
import numpy as np
from lxml import etree
class ExtractionStrategy(ABC):
"""
@@ -623,60 +623,12 @@ class ContentSummarizationStrategy(ExtractionStrategy):
# Sort summaries by the original section index to maintain order
summaries.sort(key=lambda x: x[0])
return [summary for _, summary in summaries]
class JsonCssExtractionStrategy(ExtractionStrategy):
    """Extract a list of flat records from HTML using CSS selectors.

    The schema is a dict with:
      * ``baseSelector`` -- CSS selector matching one element per record.
      * ``fields`` -- list of field specs, each with ``name``, ``selector``
        and ``type`` (``text``, ``attribute``, ``html`` or ``regex``), plus
        ``attribute`` / ``pattern`` where the type requires it.
    """

    def __init__(self, schema: Dict[str, Any], **kwargs):
        """Store the extraction schema; other options go to the base class."""
        super().__init__(**kwargs)
        self.schema = schema

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """Parse ``html`` and return one dict per base element.

        Fields that extract to ``None`` are left out of a record, and
        records that end up empty are left out of the result.
        """
        document = BeautifulSoup(html, 'html.parser')
        records = []
        for base in document.select(self.schema['baseSelector']):
            record = {}
            for spec in self.schema['fields']:
                extracted = self._extract_field(base, spec)
                if extracted is None:
                    continue
                record[spec['name']] = extracted
            if not record:
                continue
            records.append(record)
        return records

    def _extract_field(self, element, field):
        """Return the value of one field inside ``element``, or ``None``.

        ``None`` is returned when the selector matches nothing, when the
        field type is not recognised, or when any error occurs (the error is
        printed only if ``self.verbose`` is set).
        """
        try:
            node = element.select_one(field['selector'])
            if not node:
                return None

            kind = field['type']
            if kind == 'text':
                return node.get_text(strip=True)
            if kind == 'attribute':
                return node.get(field['attribute'])
            if kind == 'html':
                return str(node)
            if kind == 'regex':
                text = node.get_text(strip=True)
                found = re.search(field['pattern'], text)
                # The pattern is expected to define a first capture group.
                return found.group(1) if found else None
        except Exception as e:
            if self.verbose:
                print(f"Error extracting field {field['name']}: {str(e)}")
        return None

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """Join ``sections`` with ``self.DEL`` and extract from the result."""
        return self.extract(url, self.DEL.join(sections), **kwargs)
class EnhancedJsonCssExtractionStrategy(ExtractionStrategy):
def __init__(self, schema: Dict[str, Any], **kwargs):
super().__init__(**kwargs)
self.schema = schema
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
soup = BeautifulSoup(html, 'html.parser')
base_elements = soup.select(self.schema['baseSelector'])
@@ -775,6 +727,137 @@ class EnhancedJsonCssExtractionStrategy(ExtractionStrategy):
print(f"Error computing field {field['name']}: {str(e)}")
return field.get('default')
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
combined_html = self.DEL.join(sections)
return self.extract(url, combined_html, **kwargs)
class JsonXPATHExtractionStrategy(ExtractionStrategy):
    """Extract structured records from HTML using an XPath/CSS selector schema.

    Schema keys read by this class:
      * ``baseSelector`` / ``baseXPath`` -- selector matching one element per
        record (CSS and XPath flavour respectively).
      * ``selectorType`` -- ``'css'`` (default) or ``'xpath'``; may also be
        set per field.
      * ``fields`` -- list of field specs. Each has ``name`` and ``type``
        (``text``, ``attribute``, ``html``, ``regex``, ``nested``, ``list``,
        ``nested_list`` or ``computed``) and optionally ``selector`` /
        ``xpathSelector``, ``attribute``, ``pattern``, ``transform``,
        ``default``, ``fields``, ``expression`` or ``function``.

    If the ``cssselect`` package is not importable, every selector is
    evaluated as XPath.
    """

    def __init__(self, schema: Dict[str, Any], **kwargs):
        """Store the schema and probe whether CSS selectors are available."""
        super().__init__(**kwargs)
        self.schema = schema
        self.use_cssselect = self._check_cssselect()

    def _check_cssselect(self):
        """Return True if ``cssselect`` can be imported, else warn and return False."""
        try:
            import cssselect  # noqa: F401 -- imported only to probe availability
            return True
        except ImportError:
            print("Warning: cssselect is not installed. Falling back to XPath for all selectors.")
            return False

    def _resolve_selector(self, spec, css_key='selector', xpath_key='xpathSelector'):
        """Return ``(selector, selector_type)`` for a schema or field spec.

        XPath is forced when cssselect is unavailable. In XPath mode the
        XPath-specific key is preferred, falling back to the generic key so
        that specs carrying an XPath expression under ``selector`` still work.
        """
        selector_type = 'xpath' if not self.use_cssselect else spec.get('selectorType', 'css')
        if selector_type == 'xpath':
            selector = spec.get(xpath_key, spec.get(css_key))
        else:
            selector = spec.get(css_key)
        return selector, selector_type

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """Parse ``html`` and return one dict per element matched by the base selector.

        Empty records are dropped. An empty document yields an empty list.
        """
        self.soup = BeautifulSoup(html, 'lxml')
        markup = str(self.soup)
        # lxml cannot build a tree from an empty document, so bail out early.
        self.tree = etree.HTML(markup) if markup.strip() else None
        if self.tree is None:
            return []

        base_selector, selector_type = self._resolve_selector(
            self.schema, css_key='baseSelector', xpath_key='baseXPath'
        )
        base_elements = self._select_elements(base_selector, selector_type)

        results = []
        for element in base_elements:
            item = self._extract_item(element, self.schema['fields'])
            if item:
                results.append(item)
        return results

    def _select_elements(self, selector, selector_type, element=None):
        """Evaluate ``selector`` against ``element`` (or the whole tree if None)."""
        root = self.tree if element is None else element
        if selector_type == 'xpath' or not self.use_cssselect:
            return root.xpath(selector)
        return root.cssselect(selector)

    def _extract_field(self, element, field):
        """Extract one field, dispatching on its ``type``.

        Any exception is swallowed (printed when ``self.verbose`` is set) and
        the field's ``default`` is returned instead.
        """
        try:
            kind = field['type']
            if kind in ('nested', 'list', 'nested_list'):
                selector, selector_type = self._resolve_selector(field)
                matches = self._select_elements(selector, selector_type, element)
                if kind == 'nested':
                    return self._extract_item(matches[0], field['fields']) if matches else {}
                if kind == 'list':
                    return [self._extract_list_item(el, field['fields']) for el in matches]
                return [self._extract_item(el, field['fields']) for el in matches]
            return self._extract_single_field(element, field)
        except Exception as e:
            if self.verbose:
                # .get() so a field without a name cannot raise inside the handler.
                print(f"Error extracting field {field.get('name')}: {str(e)}")
            return field.get('default')

    def _extract_list_item(self, element, fields):
        """Build a flat dict from scalar ``fields`` of one list element."""
        item = {}
        for field in fields:
            value = self._extract_single_field(element, field)
            if value is not None:
                item[field['name']] = value
        return item

    @staticmethod
    def _node_text(node):
        """Return the stripped text of an XPath/CSS match.

        Handles string results (e.g. from ``text()`` or ``@attr`` XPath
        expressions) as well as elements. Plain ``etree`` elements have no
        ``text_content()``, so XPath ``string()`` is used to collect the text
        of the element and all of its descendants.
        """
        if isinstance(node, str):
            return node.strip()
        if hasattr(node, 'text_content'):
            return node.text_content().strip()
        return node.xpath('string()').strip()

    def _extract_single_field(self, element, field):
        """Extract a scalar (``text``/``attribute``/``html``/``regex``) field.

        When the field has no selector the value is taken from ``element``
        itself. Returns the field's ``default`` when nothing matches or the
        extracted value is ``None``.
        """
        # Resolve the selector exactly as _extract_field does, so XPath-only
        # fields and the cssselect fallback are honoured here too.
        selector, selector_type = self._resolve_selector(field)
        if selector is not None:
            matches = self._select_elements(selector, selector_type, element)
            if not matches:
                return field.get('default')
            selected = matches[0]
        else:
            selected = element

        value = None
        kind = field['type']
        if kind == 'text':
            value = self._node_text(selected)
        elif kind == 'attribute':
            value = selected.get(field['attribute'])
        elif kind == 'html':
            value = etree.tostring(selected, encoding='unicode')
        elif kind == 'regex':
            match = re.search(field['pattern'], self._node_text(selected))
            # The pattern is expected to define a first capture group.
            value = match.group(1) if match else None

        # Transforms are string operations; skip them when nothing was found.
        if value is not None and 'transform' in field:
            value = self._apply_transform(value, field['transform'])

        return value if value is not None else field.get('default')

    def _extract_item(self, element, fields):
        """Build one record from ``fields``.

        ``computed`` fields are evaluated against the record built so far, so
        they can only see fields listed before them.
        """
        item = {}
        for field in fields:
            if field['type'] == 'computed':
                value = self._compute_field(item, field)
            else:
                value = self._extract_field(element, field)
            if value is not None:
                item[field['name']] = value
        return item

    def _apply_transform(self, value, transform):
        """Apply a named string transform; unknown names return ``value`` unchanged."""
        if transform == 'lowercase':
            return value.lower()
        elif transform == 'uppercase':
            return value.upper()
        elif transform == 'strip':
            return value.strip()
        return value

    def _compute_field(self, item, field):
        """Compute a field from already-extracted values in ``item``.

        Uses ``field['expression']`` (evaluated with ``item`` as its local
        namespace) or ``field['function']`` (called with ``item``). Falls
        back to the field's ``default`` on error or when neither is given.
        """
        try:
            if 'expression' in field:
                # SECURITY: eval() executes arbitrary Python. The schema must
                # come from trusted code, never from user-supplied input.
                return eval(field['expression'], {}, item)
            elif 'function' in field:
                return field['function'](item)
        except Exception as e:
            if self.verbose:
                print(f"Error computing field {field.get('name')}: {str(e)}")
        return field.get('default')

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """Join ``sections`` with ``self.DEL`` and extract from the combined HTML."""
        combined_html = self.DEL.join(sections)
        return self.extract(url, combined_html, **kwargs)

View File

@@ -0,0 +1,67 @@
import os, time
# append the path to the root of the project
import sys
import asyncio
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from firecrawl import FirecrawlApp
from crawl4ai import AsyncWebCrawler
__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data'
def _save_and_report(markdown, elapsed, filename):
    """Print timing/size stats for one crawl and save its markdown.

    Args:
        markdown: Markdown text produced by the crawler.
        elapsed: Crawl duration in seconds.
        filename: File name to write inside the ``__data__`` directory.
    """
    print(f"Time taken: {elapsed} seconds")
    print(len(markdown))
    # The output directory may not exist on a fresh checkout.
    os.makedirs(__data__, exist_ok=True)
    # Crawled pages routinely contain non-ASCII text; do not rely on the
    # platform's default encoding.
    with open(f"{__data__}/{filename}", "w", encoding="utf-8") as f:
        f.write(markdown)
    # Count how many "cldnry.s-nbcnews.com" references are in the markdown.
    print(markdown.count("cldnry.s-nbcnews.com"))


async def compare():
    """Crawl the same page with Firecrawl and crawl4ai and report on each.

    Runs three crawls of the NBC News business page: Firecrawl, crawl4ai
    without JavaScript, and crawl4ai with a "Load More" click script. For
    each one it prints the elapsed time, markdown length and image-host
    reference count, and saves the markdown under ``__data__``.

    Raises:
        KeyError: If the ``FIRECRAWL_API_KEY`` environment variable is unset.
    """
    target_url = "https://www.nbcnews.com/business"
    app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])

    # Test Firecrawl with a simple crawl.
    start = time.perf_counter()
    scrape_status = app.scrape_url(
        target_url,
        params={'formats': ['markdown', 'html']}
    )
    elapsed = time.perf_counter() - start
    _save_and_report(scrape_status['markdown'], elapsed, "firecrawl_simple.md")

    load_more_js = [
        "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
    ]
    # (output file, extra arun() arguments) for each crawl4ai run.
    crawl4ai_runs = [
        ("crawl4ai_simple.md", {}),
        ("crawl4ai_js.md", {"js_code": load_more_js}),
    ]

    async with AsyncWebCrawler() as crawler:
        for filename, extra_args in crawl4ai_runs:
            start = time.perf_counter()
            result = await crawler.arun(
                url=target_url,
                word_count_threshold=0,
                bypass_cache=True,
                verbose=False,
                **extra_args
            )
            elapsed = time.perf_counter() - start
            _save_and_report(result.markdown, elapsed, filename)


if __name__ == "__main__":
    asyncio.run(compare())