feat(docker): add flexible LLM provider configuration

- Support LLM_PROVIDER env var to override default provider (openai/gpt-4o-mini)
- Add optional 'provider' parameter to API endpoints for per-request overrides
- Implement provider validation to ensure API keys exist
- Update documentation and examples with new configuration options

Closes the need to hardcode providers in config.yml
This commit is contained in:
ntohidi
2025-08-05 14:09:54 +08:00
parent 31a435fb0e
commit ff6ea41ac3
11 changed files with 290 additions and 23 deletions

View File

@@ -21,6 +21,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- **Flexible LLM Provider Configuration** (Docker):
- Support for `LLM_PROVIDER` environment variable to override default provider
- Per-request provider override via optional `provider` parameter in API endpoints
- Automatic provider validation with clear error messages
- Updated Docker documentation and examples
### Changed
- **WebScrapingStrategy Refactoring**: Simplified content scraping architecture
- `WebScrapingStrategy` is now an alias for `LXMLWebScrapingStrategy` for backward compatibility

View File

@@ -5,4 +5,9 @@ ANTHROPIC_API_KEY=your_anthropic_key_here
GROQ_API_KEY=your_groq_key_here
TOGETHER_API_KEY=your_together_key_here
MISTRAL_API_KEY=your_mistral_key_here
GEMINI_API_TOKEN=your_gemini_key_here
GEMINI_API_TOKEN=your_gemini_key_here
# Optional: Override the default LLM provider
# Examples: "openai/gpt-4", "anthropic/claude-3-opus", "deepseek/chat", etc.
# If not set, uses the provider specified in config.yml (default: openai/gpt-4o-mini)
# LLM_PROVIDER=anthropic/claude-3-opus

View File

@@ -154,6 +154,29 @@ cp deploy/docker/.llm.env.example .llm.env
# Now edit .llm.env and add your API keys
```
**Flexible LLM Provider Configuration:**
The Docker setup now supports flexible LLM provider configuration through three methods:
1. **Environment Variable** (Highest Priority): Set `LLM_PROVIDER` to override the default
```bash
export LLM_PROVIDER="anthropic/claude-3-opus"
# Or in your .llm.env file:
# LLM_PROVIDER=anthropic/claude-3-opus
```
2. **API Request Parameter**: Specify provider per request
```json
{
"url": "https://example.com",
"provider": "groq/mixtral-8x7b"
}
```
3. **Config File Default**: Falls back to `config.yml` (default: `openai/gpt-4o-mini`)
The system automatically selects the appropriate API key based on the provider.
#### 3. Build and Run with Compose
The `docker-compose.yml` file in the project root provides a simplified approach that automatically handles architecture detection using buildx.
@@ -668,7 +691,7 @@ app:
# Default LLM Configuration
llm:
provider: "openai/gpt-4o-mini"
provider: "openai/gpt-4o-mini" # Can be overridden by LLM_PROVIDER env var
api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored

View File

@@ -40,7 +40,9 @@ from utils import (
get_base_url,
is_task_id,
should_cleanup_task,
decode_redis_hash
decode_redis_hash,
get_llm_api_key,
validate_llm_provider
)
import psutil, time
@@ -89,10 +91,12 @@ async def handle_llm_qa(
Answer:"""
# api_token=os.environ.get(config["llm"].get("api_key_env", ""))
response = perform_completion_with_backoff(
provider=config["llm"]["provider"],
prompt_with_variables=prompt,
api_token=os.environ.get(config["llm"].get("api_key_env", ""))
api_token=get_llm_api_key(config)
)
return response.choices[0].message.content
@@ -110,19 +114,23 @@ async def process_llm_extraction(
url: str,
instruction: str,
schema: Optional[str] = None,
cache: str = "0"
cache: str = "0",
provider: Optional[str] = None
) -> None:
"""Process LLM extraction in background."""
try:
# If config['llm'] has api_key then ignore the api_key_env
api_key = ""
if "api_key" in config["llm"]:
api_key = config["llm"]["api_key"]
else:
api_key = os.environ.get(config["llm"].get("api_key_env", None), "")
# Validate provider
is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid:
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": error_msg
})
return
api_key = get_llm_api_key(config, provider)
llm_strategy = LLMExtractionStrategy(
llm_config=LLMConfig(
provider=config["llm"]["provider"],
provider=provider or config["llm"]["provider"],
api_token=api_key
),
instruction=instruction,
@@ -169,10 +177,19 @@ async def handle_markdown_request(
filter_type: FilterType,
query: Optional[str] = None,
cache: str = "0",
config: Optional[dict] = None
config: Optional[dict] = None,
provider: Optional[str] = None
) -> str:
"""Handle markdown generation requests."""
try:
# Validate provider if using LLM filter
if filter_type == FilterType.LLM:
is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg
)
decoded_url = unquote(url)
if not decoded_url.startswith(('http://', 'https://')):
decoded_url = 'https://' + decoded_url
@@ -185,8 +202,8 @@ async def handle_markdown_request(
FilterType.BM25: BM25ContentFilter(user_query=query or ""),
FilterType.LLM: LLMContentFilter(
llm_config=LLMConfig(
provider=config["llm"]["provider"],
api_token=os.environ.get(config["llm"].get("api_key_env", None), ""),
provider=provider or config["llm"]["provider"],
api_token=get_llm_api_key(config, provider),
),
instruction=query or "Extract main content"
)
@@ -230,7 +247,8 @@ async def handle_llm_request(
query: Optional[str] = None,
schema: Optional[str] = None,
cache: str = "0",
config: Optional[dict] = None
config: Optional[dict] = None,
provider: Optional[str] = None
) -> JSONResponse:
"""Handle LLM extraction requests."""
base_url = get_base_url(request)
@@ -260,7 +278,8 @@ async def handle_llm_request(
schema,
cache,
base_url,
config
config,
provider
)
except Exception as e:
@@ -304,7 +323,8 @@ async def create_new_task(
schema: Optional[str],
cache: str,
base_url: str,
config: dict
config: dict,
provider: Optional[str] = None
) -> JSONResponse:
"""Create and initialize a new task."""
decoded_url = unquote(input_path)
@@ -328,7 +348,8 @@ async def create_new_task(
decoded_url,
query,
schema,
cache
cache,
provider
)
return JSONResponse({

View File

@@ -36,6 +36,7 @@ class LlmJobPayload(BaseModel):
q: str
schema: Optional[str] = None
cache: bool = False
provider: Optional[str] = None
class CrawlJobPayload(BaseModel):
@@ -61,6 +62,7 @@ async def llm_job_enqueue(
schema=payload.schema,
cache=payload.cache,
config=_config,
provider=payload.provider,
)

View File

@@ -15,6 +15,7 @@ class MarkdownRequest(BaseModel):
f: FilterType = Field(FilterType.FIT, description="Contentfilter strategy: fit, raw, bm25, or llm")
q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters")
c: Optional[str] = Field("0", description="Cachebust / revision counter")
provider: Optional[str] = Field(None, description="LLM provider override (e.g., 'anthropic/claude-3-opus')")
class RawCode(BaseModel):

View File

@@ -241,7 +241,7 @@ async def get_markdown(
raise HTTPException(
400, "URL must be absolute and start with http/https")
markdown = await handle_markdown_request(
body.url, body.f, body.q, body.c, config
body.url, body.f, body.q, body.c, config, body.provider
)
return JSONResponse({
"url": body.url,

View File

@@ -1,6 +1,7 @@
import dns.resolver
import logging
import yaml
import os
from datetime import datetime
from enum import Enum
from pathlib import Path
@@ -19,10 +20,24 @@ class FilterType(str, Enum):
LLM = "llm"
def load_config() -> Dict:
"""Load and return application configuration."""
"""Load and return application configuration with environment variable overrides."""
config_path = Path(__file__).parent / "config.yml"
with open(config_path, "r") as config_file:
return yaml.safe_load(config_file)
config = yaml.safe_load(config_file)
# Override LLM provider from environment if set
llm_provider = os.environ.get("LLM_PROVIDER")
if llm_provider:
config["llm"]["provider"] = llm_provider
logging.info(f"LLM provider overridden from environment: {llm_provider}")
# Also support direct API key from environment if the provider-specific key isn't set
llm_api_key = os.environ.get("LLM_API_KEY")
if llm_api_key and "api_key" not in config["llm"]:
config["llm"]["api_key"] = llm_api_key
logging.info("LLM API key loaded from LLM_API_KEY environment variable")
return config
def setup_logging(config: Dict) -> None:
"""Configure application logging."""
@@ -56,6 +71,52 @@ def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
def get_llm_api_key(config: Dict, provider: Optional[str] = None) -> str:
"""Get the appropriate API key based on the LLM provider.
Args:
config: The application configuration dictionary
provider: Optional provider override (e.g., "openai/gpt-4")
Returns:
The API key for the provider, or empty string if not found
"""
# Use provided provider or fall back to config
if not provider:
provider = config["llm"]["provider"]
# Check if direct API key is configured
if "api_key" in config["llm"]:
return config["llm"]["api_key"]
# Fall back to the configured api_key_env if no match
return os.environ.get(config["llm"].get("api_key_env", ""), "")
def validate_llm_provider(config: Dict, provider: Optional[str] = None) -> tuple[bool, str]:
"""Validate that the LLM provider has an associated API key.
Args:
config: The application configuration dictionary
provider: Optional provider override (e.g., "openai/gpt-4")
Returns:
Tuple of (is_valid, error_message)
"""
# Use provided provider or fall back to config
if not provider:
provider = config["llm"]["provider"]
# Get the API key for this provider
api_key = get_llm_api_key(config, provider)
if not api_key:
return False, f"No API key found for provider '{provider}'. Please set the appropriate environment variable."
return True, ""
def verify_email_domain(email: str) -> bool:
try:
domain = email.split('@')[1]

View File

@@ -14,6 +14,7 @@ x-base-config: &base-config
- TOGETHER_API_KEY=${TOGETHER_API_KEY:-}
- MISTRAL_API_KEY=${MISTRAL_API_KEY:-}
- GEMINI_API_TOKEN=${GEMINI_API_TOKEN:-}
- LLM_PROVIDER=${LLM_PROVIDER:-} # Optional: Override default provider (e.g., "anthropic/claude-3-opus")
volumes:
- /dev/shm:/dev/shm # Chromium performance
deploy:

View File

@@ -154,6 +154,30 @@ cp deploy/docker/.llm.env.example .llm.env
# Now edit .llm.env and add your API keys
```
**Flexible LLM Provider Configuration:**
The Docker setup now supports flexible LLM provider configuration through three methods:
1. **Environment Variable** (Highest Priority): Set `LLM_PROVIDER` to override the default
```bash
export LLM_PROVIDER="anthropic/claude-3-opus"
# Or in your .llm.env file:
# LLM_PROVIDER=anthropic/claude-3-opus
```
2. **API Request Parameter**: Specify provider per request
```json
{
"url": "https://example.com",
"f": "llm",
"provider": "groq/mixtral-8x7b"
}
```
3. **Config File Default**: Falls back to `config.yml` (default: `openai/gpt-4o-mini`)
The system automatically selects the appropriate API key based on the configured `api_key_env` in the config file.
#### 3. Build and Run with Compose
The `docker-compose.yml` file in the project root provides a simplified approach that automatically handles architecture detection using buildx.
@@ -668,7 +692,7 @@ app:
# Default LLM Configuration
llm:
provider: "openai/gpt-4o-mini"
provider: "openai/gpt-4o-mini" # Can be overridden by LLM_PROVIDER env var
api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env python3
"""Test script to verify Docker API with LLM provider configuration."""
import requests
import json
import time
BASE_URL = "http://localhost:11235"
def test_health():
"""Test health endpoint."""
print("1. Testing health endpoint...")
response = requests.get(f"{BASE_URL}/health")
print(f" Status: {response.status_code}")
print(f" Response: {response.json()}")
print()
def test_schema():
"""Test schema endpoint to see configuration."""
print("2. Testing schema endpoint...")
response = requests.get(f"{BASE_URL}/schema")
print(f" Status: {response.status_code}")
# Print only browser config to keep output concise
print(f" Browser config keys: {list(response.json().get('browser', {}).keys())[:5]}...")
print()
def test_markdown_with_llm_filter():
"""Test markdown endpoint with LLM filter (should use configured provider)."""
print("3. Testing markdown endpoint with LLM filter...")
print(" This should use the Groq provider from LLM_PROVIDER env var")
# Note: This will fail with dummy API keys, but we can see if it tries to use Groq
payload = {
"url": "https://httpbin.org/html",
"f": "llm",
"q": "Extract the main content"
}
response = requests.post(f"{BASE_URL}/md", json=payload)
print(f" Status: {response.status_code}")
if response.status_code != 200:
print(f" Error: {response.text[:200]}...")
else:
print(f" Success! Markdown length: {len(response.json().get('markdown', ''))} chars")
print()
def test_markdown_with_provider_override():
"""Test markdown endpoint with provider override in request."""
print("4. Testing markdown endpoint with provider override...")
print(" This should use OpenAI provider from request parameter")
payload = {
"url": "https://httpbin.org/html",
"f": "llm",
"q": "Extract the main content",
"provider": "openai/gpt-4" # Override to use OpenAI
}
response = requests.post(f"{BASE_URL}/md", json=payload)
print(f" Status: {response.status_code}")
if response.status_code != 200:
print(f" Error: {response.text[:200]}...")
else:
print(f" Success! Markdown length: {len(response.json().get('markdown', ''))} chars")
print()
def test_simple_crawl():
"""Test simple crawl without LLM."""
print("5. Testing simple crawl (no LLM required)...")
payload = {
"urls": ["https://httpbin.org/html"],
"browser_config": {
"type": "BrowserConfig",
"params": {"headless": True}
},
"crawler_config": {
"type": "CrawlerRunConfig",
"params": {"cache_mode": "bypass"}
}
}
response = requests.post(f"{BASE_URL}/crawl", json=payload)
print(f" Status: {response.status_code}")
if response.status_code == 200:
result = response.json()
print(f" Success: {result.get('success')}")
print(f" Results count: {len(result.get('results', []))}")
if result.get('results'):
print(f" First result success: {result['results'][0].get('success')}")
else:
print(f" Error: {response.text[:200]}...")
print()
def test_playground():
"""Test if playground is accessible."""
print("6. Testing playground interface...")
response = requests.get(f"{BASE_URL}/playground")
print(f" Status: {response.status_code}")
print(f" Content-Type: {response.headers.get('content-type')}")
print()
if __name__ == "__main__":
print("=== Crawl4AI Docker API Tests ===\n")
print(f"Testing API at {BASE_URL}\n")
# Wait a bit for server to be fully ready
time.sleep(2)
test_health()
test_schema()
test_simple_crawl()
test_playground()
print("\nTesting LLM functionality (these may fail with dummy API keys):\n")
test_markdown_with_llm_filter()
test_markdown_with_provider_override()
print("\nTests completed!")