diff --git a/docs/examples/docker_example.py b/docs/examples/docker_example.py new file mode 100644 index 00000000..c22acd55 --- /dev/null +++ b/docs/examples/docker_example.py @@ -0,0 +1,300 @@ +import requests +import json +import time +import sys +import base64 +import os +from typing import Dict, Any + +class Crawl4AiTester: + def __init__(self, base_url: str = "http://localhost:11235"): + self.base_url = base_url + + def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: + # Submit crawl job + response = requests.post(f"{self.base_url}/crawl", json=request_data) + task_id = response.json()["task_id"] + print(f"Task ID: {task_id}") + + # Poll for result + start_time = time.time() + while True: + if time.time() - start_time > timeout: + raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") + + result = requests.get(f"{self.base_url}/task/{task_id}") + status = result.json() + + if status["status"] == "failed": + print("Task failed:", status.get("error")) + raise Exception(f"Task failed: {status.get('error')}") + + if status["status"] == "completed": + return status + + time.sleep(2) + +def test_docker_deployment(version="basic"): + tester = Crawl4AiTester() + print(f"Testing Crawl4AI Docker {version} version") + + # Health check with timeout and retry + max_retries = 5 + for i in range(max_retries): + try: + health = requests.get(f"{tester.base_url}/health", timeout=10) + print("Health check:", health.json()) + break + except requests.exceptions.RequestException as e: + if i == max_retries - 1: + print(f"Failed to connect after {max_retries} attempts") + sys.exit(1) + print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") + time.sleep(5) + + # Test cases based on version + test_basic_crawl(tester) + + # if version in ["full", "transformer"]: + # test_cosine_extraction(tester) + + # test_js_execution(tester) + # test_css_selector(tester) + # test_structured_extraction(tester) + # test_llm_extraction(tester) + # test_llm_with_ollama(tester) + # test_screenshot(tester) + + +def test_basic_crawl(tester: Crawl4AiTester): + print("\n=== Testing Basic Crawl ===") + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 10 + } + + result = tester.submit_and_wait(request) + print(f"Basic crawl result length: {len(result['result']['markdown'])}") + assert result["result"]["success"] + assert len(result["result"]["markdown"]) > 0 + +def test_js_execution(tester: Crawl4AiTester): + print("\n=== Testing JS Execution ===") + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 8, + "js_code": [ + "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" + ], + "wait_for": "article.tease-card:nth-child(10)", + "crawler_params": { + "headless": True + } + } + + result = tester.submit_and_wait(request) + print(f"JS execution result length: {len(result['result']['markdown'])}") + assert result["result"]["success"] + +def test_css_selector(tester: Crawl4AiTester): + print("\n=== Testing CSS Selector ===") + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 7, + "css_selector": ".wide-tease-item__description", + "crawler_params": { + "headless": True + }, + "extra": {"word_count_threshold": 10} + + } + + result = tester.submit_and_wait(request) + print(f"CSS selector result length: {len(result['result']['markdown'])}") + assert result["result"]["success"] + +def test_structured_extraction(tester: Crawl4AiTester): + print("\n=== Testing Structured Extraction ===") + schema = { + "name": "Coinbase Crypto Prices", + "baseSelector": ".cds-tableRow-t45thuk", + "fields": [ + { + "name": "crypto", + "selector": "td:nth-child(1) h2", + "type": "text", + }, + { + "name": "symbol", + "selector": "td:nth-child(1) p", + "type": "text", + }, + { + "name": "price", + "selector": "td:nth-child(2)", + "type": "text", + } + ], + } + + request = { + "urls": "https://www.coinbase.com/explore", + "priority": 9, + "extraction_config": { + "type": "json_css", + "params": { + "schema": schema + } + } + } + + result = tester.submit_and_wait(request) + extracted = json.loads(result["result"]["extracted_content"]) + print(f"Extracted {len(extracted)} items") + print("Sample item:", json.dumps(extracted[0], indent=2)) + assert result["result"]["success"] + assert len(extracted) > 0 + +def test_llm_extraction(tester: Crawl4AiTester): + print("\n=== Testing LLM Extraction ===") + schema = { + "type": "object", + "properties": { + "model_name": { + "type": "string", + "description": "Name of the OpenAI model." + }, + "input_fee": { + "type": "string", + "description": "Fee for input token for the OpenAI model." + }, + "output_fee": { + "type": "string", + "description": "Fee for output token for the OpenAI model." + } + }, + "required": ["model_name", "input_fee", "output_fee"] + } + + request = { + "urls": "https://openai.com/api/pricing", + "priority": 8, + "extraction_config": { + "type": "llm", + "params": { + "provider": "openai/gpt-4o-mini", + "api_token": os.getenv("OPENAI_API_KEY"), + "schema": schema, + "extraction_type": "schema", + "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" + } + }, + "crawler_params": {"word_count_threshold": 1} + } + + try: + result = tester.submit_and_wait(request) + extracted = json.loads(result["result"]["extracted_content"]) + print(f"Extracted {len(extracted)} model pricing entries") + print("Sample entry:", json.dumps(extracted[0], indent=2)) + assert result["result"]["success"] + except Exception as e: + print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") + +def test_llm_with_ollama(tester: Crawl4AiTester): + print("\n=== Testing LLM with Ollama ===") + schema = { + "type": "object", + "properties": { + "article_title": { + "type": "string", + "description": "The main title of the news article" + }, + "summary": { + "type": "string", + "description": "A brief summary of the article content" + }, + "main_topics": { + "type": "array", + "items": {"type": "string"}, + "description": "Main topics or themes discussed in the article" + } + } + } + + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 8, + "extraction_config": { + "type": "llm", + "params": { + "provider": "ollama/llama2", + "schema": schema, + "extraction_type": "schema", + "instruction": "Extract the main article information including title, summary, and main topics." + } + }, + "extra": {"word_count_threshold": 1}, + "crawler_params": {"verbose": True} + } + + try: + result = tester.submit_and_wait(request) + extracted = json.loads(result["result"]["extracted_content"]) + print("Extracted content:", json.dumps(extracted, indent=2)) + assert result["result"]["success"] + except Exception as e: + print(f"Ollama extraction test failed: {str(e)}") + +def test_cosine_extraction(tester: Crawl4AiTester): + print("\n=== Testing Cosine Extraction ===") + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 8, + "extraction_config": { + "type": "cosine", + "params": { + "semantic_filter": "business finance economy", + "word_count_threshold": 10, + "max_dist": 0.2, + "top_k": 3 + } + } + } + + try: + result = tester.submit_and_wait(request) + extracted = json.loads(result["result"]["extracted_content"]) + print(f"Extracted {len(extracted)} text clusters") + print("First cluster tags:", extracted[0]["tags"]) + assert result["result"]["success"] + except Exception as e: + print(f"Cosine extraction test failed: {str(e)}") + +def test_screenshot(tester: Crawl4AiTester): + print("\n=== Testing Screenshot ===") + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 5, + "screenshot": True, + "crawler_params": { + "headless": True + } + } + + result = tester.submit_and_wait(request) + print("Screenshot captured:", bool(result["result"]["screenshot"])) + + if result["result"]["screenshot"]: + # Save screenshot + screenshot_data = base64.b64decode(result["result"]["screenshot"]) + with open("test_screenshot.jpg", "wb") as f: + f.write(screenshot_data) + print("Screenshot saved as test_screenshot.jpg") + + assert result["result"]["success"] + +if __name__ == "__main__": + version = sys.argv[1] if len(sys.argv) > 1 else "basic" + # version = "full" + test_docker_deployment(version) \ No newline at end of file