Compare commits

..

1 Commits

Author SHA1 Message Date
ntohidi
95051020f4 fix(docker): Fix LLM API key handling for multi-provider support
Previously, the system incorrectly used OPENAI_API_KEY for all LLM providers
due to a hardcoded api_key_env fallback in config.yml. This caused authentication
errors when using non-OpenAI providers like Gemini.

Changes:
- Remove api_key_env from config.yml to let litellm handle provider-specific env vars
- Simplify get_llm_api_key() to return None, allowing litellm to auto-detect keys
- Update validate_llm_provider() to trust litellm's built-in key detection
- Update documentation to reflect the new automatic key handling

The fix leverages litellm's existing capability to automatically find the correct
environment variable for each provider (OPENAI_API_KEY, GEMINI_API_TOKEN, etc.)
without manual configuration.

ref #1291
2025-08-21 14:01:04 +08:00
11 changed files with 26 additions and 272 deletions

View File

@@ -97,16 +97,13 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
if value != param.default and not ignore_default_value:
current_values[name] = to_serializable_dict(value)
# Don't serialize private __slots__ - they're internal implementation details
# not constructor parameters. This was causing URLPatternFilter to fail
# because _simple_suffixes was being serialized as 'simple_suffixes'
# if hasattr(obj, '__slots__'):
# for slot in obj.__slots__:
# if slot.startswith('_'): # Handle private slots
# attr_name = slot[1:] # Remove leading '_'
# value = getattr(obj, slot, None)
# if value is not None:
# current_values[attr_name] = to_serializable_dict(value)
if hasattr(obj, '__slots__'):
for slot in obj.__slots__:
if slot.startswith('_'): # Handle private slots
attr_name = slot[1:] # Remove leading '_'
value = getattr(obj, slot, None)
if value is not None:
current_values[attr_name] = to_serializable_dict(value)

View File

@@ -47,13 +47,7 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
self.url_scorer = url_scorer
self.include_external = include_external
self.max_pages = max_pages
# self.logger = logger or logging.getLogger(__name__)
# Ensure logger is always a Logger instance, not a dict from serialization
if isinstance(logger, logging.Logger):
self.logger = logger
else:
# Create a new logger if logger is None, dict, or any other non-Logger type
self.logger = logging.getLogger(__name__)
self.logger = logger or logging.getLogger(__name__)
self.stats = TraversalStats(start_time=datetime.now())
self._cancel_event = asyncio.Event()
self._pages_crawled = 0

View File

@@ -38,13 +38,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
self.include_external = include_external
self.score_threshold = score_threshold
self.max_pages = max_pages
# self.logger = logger or logging.getLogger(__name__)
# Ensure logger is always a Logger instance, not a dict from serialization
if isinstance(logger, logging.Logger):
self.logger = logger
else:
# Create a new logger if logger is None, dict, or any other non-Logger type
self.logger = logging.getLogger(__name__)
self.logger = logger or logging.getLogger(__name__)
self.stats = TraversalStats(start_time=datetime.now())
self._cancel_event = asyncio.Event()
self._pages_crawled = 0

View File

@@ -120,9 +120,6 @@ class URLPatternFilter(URLFilter):
"""Pattern filter balancing speed and completeness"""
__slots__ = (
"patterns", # Store original patterns for serialization
"use_glob", # Store original use_glob for serialization
"reverse", # Store original reverse for serialization
"_simple_suffixes",
"_simple_prefixes",
"_domain_patterns",
@@ -145,11 +142,6 @@ class URLPatternFilter(URLFilter):
reverse: bool = False,
):
super().__init__()
# Store original constructor params for serialization
self.patterns = patterns
self.use_glob = use_glob
self.reverse = reverse
self._reverse = reverse
patterns = [patterns] if isinstance(patterns, (str, Pattern)) else patterns

View File

@@ -253,16 +253,6 @@ class CrawlResult(BaseModel):
requirements change, this is where you would update the logic.
"""
result = super().model_dump(*args, **kwargs)
# Remove any property descriptors that might have been included
# These deprecated properties should not be in the serialized output
for key in ['fit_html', 'fit_markdown', 'markdown_v2']:
if key in result and isinstance(result[key], property):
# del result[key]
# Nasrin: I decided to convert it to string instead of removing it.
result[key] = str(result[key])
# Add the markdown field properly
if self._markdown is not None:
result["markdown"] = self._markdown.model_dump()
return result

View File

@@ -692,8 +692,7 @@ app:
# Default LLM Configuration
llm:
provider: "openai/gpt-4o-mini" # Can be overridden by LLM_PROVIDER env var
api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored
# api_key: sk-... # If you pass the API key directly (not recommended)
# Redis Configuration (Used by internal Redis server managed by supervisord)
redis:

View File

@@ -96,7 +96,7 @@ async def handle_llm_qa(
response = perform_completion_with_backoff(
provider=config["llm"]["provider"],
prompt_with_variables=prompt,
api_token=get_llm_api_key(config)
api_token=get_llm_api_key(config) # Returns None to let litellm handle it
)
return response.choices[0].message.content
@@ -127,7 +127,7 @@ async def process_llm_extraction(
"error": error_msg
})
return
api_key = get_llm_api_key(config, provider)
api_key = get_llm_api_key(config, provider) # Returns None to let litellm handle it
llm_strategy = LLMExtractionStrategy(
llm_config=LLMConfig(
provider=provider or config["llm"]["provider"],
@@ -203,7 +203,7 @@ async def handle_markdown_request(
FilterType.LLM: LLMContentFilter(
llm_config=LLMConfig(
provider=provider or config["llm"]["provider"],
api_token=get_llm_api_key(config, provider),
api_token=get_llm_api_key(config, provider), # Returns None to let litellm handle it
),
instruction=query or "Extract main content"
)

View File

@@ -11,8 +11,7 @@ app:
# Default LLM Configuration
llm:
provider: "openai/gpt-4o-mini"
api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored
# api_key: sk-... # If you pass the API key directly (not recommended)
# Redis Configuration
redis:

View File

@@ -71,7 +71,7 @@ def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
def get_llm_api_key(config: Dict, provider: Optional[str] = None) -> str:
def get_llm_api_key(config: Dict, provider: Optional[str] = None) -> Optional[str]:
"""Get the appropriate API key based on the LLM provider.
Args:
@@ -79,19 +79,14 @@ def get_llm_api_key(config: Dict, provider: Optional[str] = None) -> str:
provider: Optional provider override (e.g., "openai/gpt-4")
Returns:
The API key for the provider, or empty string if not found
The API key if directly configured, otherwise None to let litellm handle it
"""
# Use provided provider or fall back to config
if not provider:
provider = config["llm"]["provider"]
# Check if direct API key is configured
# Check if direct API key is configured (for backward compatibility)
if "api_key" in config["llm"]:
return config["llm"]["api_key"]
# Fall back to the configured api_key_env if no match
return os.environ.get(config["llm"].get("api_key_env", ""), "")
# Return None - litellm will automatically find the right environment variable
return None
def validate_llm_provider(config: Dict, provider: Optional[str] = None) -> tuple[bool, str]:
@@ -104,16 +99,12 @@ def validate_llm_provider(config: Dict, provider: Optional[str] = None) -> tuple
Returns:
Tuple of (is_valid, error_message)
"""
# Use provided provider or fall back to config
if not provider:
provider = config["llm"]["provider"]
# Get the API key for this provider
api_key = get_llm_api_key(config, provider)
if not api_key:
return False, f"No API key found for provider '{provider}'. Please set the appropriate environment variable."
# If a direct API key is configured, validation passes
if "api_key" in config["llm"]:
return True, ""
# Otherwise, trust that litellm will find the appropriate environment variable
# We can't easily validate this without reimplementing litellm's logic
return True, ""

View File

@@ -176,7 +176,7 @@ The Docker setup now supports flexible LLM provider configuration through three
3. **Config File Default**: Falls back to `config.yml` (default: `openai/gpt-4o-mini`)
The system automatically selects the appropriate API key based on the configured `api_key_env` in the config file.
The system automatically selects the appropriate API key based on the provider. LiteLLM handles finding the correct environment variable for each provider (e.g., OPENAI_API_KEY for OpenAI, GEMINI_API_TOKEN for Google Gemini, etc.).
#### 3. Build and Run with Compose
@@ -693,8 +693,7 @@ app:
# Default LLM Configuration
llm:
provider: "openai/gpt-4o-mini" # Can be overridden by LLM_PROVIDER env var
api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored
# api_key: sk-... # If you pass the API key directly (not recommended)
# Redis Configuration (Used by internal Redis server managed by supervisord)
redis:

View File

@@ -1,201 +0,0 @@
"""
Test the complete fix for both the filter serialization and JSON serialization issues.
"""
import asyncio
import httpx
from crawl4ai import BrowserConfig, CacheMode, CrawlerRunConfig
from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, FilterChain, URLPatternFilter
BASE_URL = "http://localhost:11234/" # Adjust port as needed
async def test_with_docker_client():
"""Test using the Docker client (same as 1419.py)."""
from crawl4ai.docker_client import Crawl4aiDockerClient
print("=" * 60)
print("Testing with Docker Client")
print("=" * 60)
try:
async with Crawl4aiDockerClient(
base_url=BASE_URL,
verbose=True,
) as client:
# Create filter chain - testing the serialization fix
filter_chain = [
URLPatternFilter(
# patterns=["*about*", "*privacy*", "*terms*"],
patterns=["*advanced*"],
reverse=True
),
]
crawler_config = CrawlerRunConfig(
deep_crawl_strategy=BFSDeepCrawlStrategy(
max_depth=2, # Keep it shallow for testing
# max_pages=5, # Limit pages for testing
filter_chain=FilterChain(filter_chain)
),
cache_mode=CacheMode.BYPASS,
)
print("\n1. Testing crawl with filters...")
results = await client.crawl(
["https://docs.crawl4ai.com"], # Simple test page
browser_config=BrowserConfig(headless=True),
crawler_config=crawler_config,
)
if results:
print(f"✅ Crawl succeeded! Type: {type(results)}")
if hasattr(results, 'success'):
print(f"✅ Results success: {results.success}")
# Test that we can iterate results without JSON errors
if hasattr(results, '__iter__'):
for i, result in enumerate(results):
if hasattr(result, 'url'):
print(f" Result {i}: {result.url[:50]}...")
else:
print(f" Result {i}: {str(result)[:50]}...")
else:
# Handle list of results
print(f"✅ Got {len(results)} results")
for i, result in enumerate(results[:3]): # Show first 3
print(f" Result {i}: {result.url[:50]}...")
else:
print("❌ Crawl failed - no results returned")
return False
print("\n✅ Docker client test completed successfully!")
return True
except Exception as e:
print(f"❌ Docker client test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_with_rest_api():
"""Test using REST API directly."""
print("\n" + "=" * 60)
print("Testing with REST API")
print("=" * 60)
# Create filter configuration
deep_crawl_strategy_payload = {
"type": "BFSDeepCrawlStrategy",
"params": {
"max_depth": 2,
# "max_pages": 5,
"filter_chain": {
"type": "FilterChain",
"params": {
"filters": [
{
"type": "URLPatternFilter",
"params": {
"patterns": ["*advanced*"],
"reverse": True
}
}
]
}
}
}
}
crawl_payload = {
"urls": ["https://docs.crawl4ai.com"],
"browser_config": {"type": "BrowserConfig", "params": {"headless": True}},
"crawler_config": {
"type": "CrawlerRunConfig",
"params": {
"deep_crawl_strategy": deep_crawl_strategy_payload,
"cache_mode": "bypass"
}
}
}
try:
async with httpx.AsyncClient() as client:
print("\n1. Sending crawl request to REST API...")
response = await client.post(
f"{BASE_URL}crawl",
json=crawl_payload,
timeout=30
)
if response.status_code == 200:
print(f"✅ REST API returned 200 OK")
data = response.json()
if data.get("success"):
results = data.get("results", [])
print(f"✅ Got {len(results)} results")
for i, result in enumerate(results[:3]):
print(f" Result {i}: {result.get('url', 'unknown')[:50]}...")
else:
print(f"❌ Crawl not successful: {data}")
return False
else:
print(f"❌ REST API returned {response.status_code}")
print(f" Response: {response.text[:500]}")
return False
print("\n✅ REST API test completed successfully!")
return True
except Exception as e:
print(f"❌ REST API test failed: {e}")
import traceback
traceback.print_exc()
return False
async def main():
"""Run all tests."""
print("\n🧪 TESTING COMPLETE FIX FOR DOCKER FILTER AND JSON ISSUES")
print("=" * 60)
print("Make sure the server is running with the updated code!")
print("=" * 60)
results = []
# Test 1: Docker client
docker_passed = await test_with_docker_client()
results.append(("Docker Client", docker_passed))
# Test 2: REST API
rest_passed = await test_with_rest_api()
results.append(("REST API", rest_passed))
# Summary
print("\n" + "=" * 60)
print("FINAL TEST SUMMARY")
print("=" * 60)
all_passed = True
for test_name, passed in results:
status = "✅ PASSED" if passed else "❌ FAILED"
print(f"{test_name:20} {status}")
if not passed:
all_passed = False
print("=" * 60)
if all_passed:
print("🎉 ALL TESTS PASSED! Both issues are fully resolved!")
print("\nThe fixes:")
print("1. Filter serialization: Fixed by not serializing private __slots__")
print("2. JSON serialization: Fixed by removing property descriptors from model_dump()")
else:
print("⚠️ Some tests failed. Please check the server logs for details.")
return 0 if all_passed else 1
if __name__ == "__main__":
import sys
sys.exit(asyncio.run(main()))