feat(docker): add flexible LLM provider configuration

- Support LLM_PROVIDER env var to override default provider (openai/gpt-4o-mini)
- Add optional 'provider' parameter to API endpoints for per-request overrides
- Implement provider validation to ensure API keys exist
- Update documentation and examples with new configuration options

Removes the need to hardcode providers in config.yml
This commit is contained in:
ntohidi
2025-08-05 14:09:54 +08:00
parent 31a435fb0e
commit ff6ea41ac3
11 changed files with 290 additions and 23 deletions

View File

@@ -21,6 +21,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Added
- **Flexible LLM Provider Configuration** (Docker):
- Support for `LLM_PROVIDER` environment variable to override default provider
- Per-request provider override via optional `provider` parameter in API endpoints
- Automatic provider validation with clear error messages
- Updated Docker documentation and examples
### Changed ### Changed
- **WebScrapingStrategy Refactoring**: Simplified content scraping architecture - **WebScrapingStrategy Refactoring**: Simplified content scraping architecture
- `WebScrapingStrategy` is now an alias for `LXMLWebScrapingStrategy` for backward compatibility - `WebScrapingStrategy` is now an alias for `LXMLWebScrapingStrategy` for backward compatibility

View File

@@ -6,3 +6,8 @@ GROQ_API_KEY=your_groq_key_here
TOGETHER_API_KEY=your_together_key_here TOGETHER_API_KEY=your_together_key_here
MISTRAL_API_KEY=your_mistral_key_here MISTRAL_API_KEY=your_mistral_key_here
GEMINI_API_TOKEN=your_gemini_key_here GEMINI_API_TOKEN=your_gemini_key_here
# Optional: Override the default LLM provider
# Examples: "openai/gpt-4", "anthropic/claude-3-opus", "deepseek/chat", etc.
# If not set, uses the provider specified in config.yml (default: openai/gpt-4o-mini)
# LLM_PROVIDER=anthropic/claude-3-opus

View File

@@ -154,6 +154,29 @@ cp deploy/docker/.llm.env.example .llm.env
# Now edit .llm.env and add your API keys # Now edit .llm.env and add your API keys
``` ```
**Flexible LLM Provider Configuration:**
The Docker setup now supports flexible LLM provider configuration through three methods:
1. **Environment Variable**: Set `LLM_PROVIDER` to override the `config.yml` default (a per-request `provider` parameter, when given, still takes precedence)
```bash
export LLM_PROVIDER="anthropic/claude-3-opus"
# Or in your .llm.env file:
# LLM_PROVIDER=anthropic/claude-3-opus
```
2. **API Request Parameter**: Specify provider per request
```json
{
"url": "https://example.com",
"provider": "groq/mixtral-8x7b"
}
```
3. **Config File Default**: Falls back to `config.yml` (default: `openai/gpt-4o-mini`)
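The API key is taken from `api_key` or from the environment variable named by `api_key_env` in `config.yml`; it is not switched automatically per provider, so make sure the configured key matches the provider you select.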
#### 3. Build and Run with Compose #### 3. Build and Run with Compose
The `docker-compose.yml` file in the project root provides a simplified approach that automatically handles architecture detection using buildx. The `docker-compose.yml` file in the project root provides a simplified approach that automatically handles architecture detection using buildx.
@@ -668,7 +691,7 @@ app:
# Default LLM Configuration # Default LLM Configuration
llm: llm:
provider: "openai/gpt-4o-mini" provider: "openai/gpt-4o-mini" # Can be overridden by LLM_PROVIDER env var
api_key_env: "OPENAI_API_KEY" api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored # api_key: sk-... # If you pass the API key directly then api_key_env will be ignored

View File

@@ -40,7 +40,9 @@ from utils import (
get_base_url, get_base_url,
is_task_id, is_task_id,
should_cleanup_task, should_cleanup_task,
decode_redis_hash decode_redis_hash,
get_llm_api_key,
validate_llm_provider
) )
import psutil, time import psutil, time
@@ -89,10 +91,12 @@ async def handle_llm_qa(
Answer:""" Answer:"""
# api_token=os.environ.get(config["llm"].get("api_key_env", ""))
response = perform_completion_with_backoff( response = perform_completion_with_backoff(
provider=config["llm"]["provider"], provider=config["llm"]["provider"],
prompt_with_variables=prompt, prompt_with_variables=prompt,
api_token=os.environ.get(config["llm"].get("api_key_env", "")) api_token=get_llm_api_key(config)
) )
return response.choices[0].message.content return response.choices[0].message.content
@@ -110,19 +114,23 @@ async def process_llm_extraction(
url: str, url: str,
instruction: str, instruction: str,
schema: Optional[str] = None, schema: Optional[str] = None,
cache: str = "0" cache: str = "0",
provider: Optional[str] = None
) -> None: ) -> None:
"""Process LLM extraction in background.""" """Process LLM extraction in background."""
try: try:
# If config['llm'] has api_key then ignore the api_key_env # Validate provider
api_key = "" is_valid, error_msg = validate_llm_provider(config, provider)
if "api_key" in config["llm"]: if not is_valid:
api_key = config["llm"]["api_key"] await redis.hset(f"task:{task_id}", mapping={
else: "status": TaskStatus.FAILED,
api_key = os.environ.get(config["llm"].get("api_key_env", None), "") "error": error_msg
})
return
api_key = get_llm_api_key(config, provider)
llm_strategy = LLMExtractionStrategy( llm_strategy = LLMExtractionStrategy(
llm_config=LLMConfig( llm_config=LLMConfig(
provider=config["llm"]["provider"], provider=provider or config["llm"]["provider"],
api_token=api_key api_token=api_key
), ),
instruction=instruction, instruction=instruction,
@@ -169,10 +177,19 @@ async def handle_markdown_request(
filter_type: FilterType, filter_type: FilterType,
query: Optional[str] = None, query: Optional[str] = None,
cache: str = "0", cache: str = "0",
config: Optional[dict] = None config: Optional[dict] = None,
provider: Optional[str] = None
) -> str: ) -> str:
"""Handle markdown generation requests.""" """Handle markdown generation requests."""
try: try:
# Validate provider if using LLM filter
if filter_type == FilterType.LLM:
is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg
)
decoded_url = unquote(url) decoded_url = unquote(url)
if not decoded_url.startswith(('http://', 'https://')): if not decoded_url.startswith(('http://', 'https://')):
decoded_url = 'https://' + decoded_url decoded_url = 'https://' + decoded_url
@@ -185,8 +202,8 @@ async def handle_markdown_request(
FilterType.BM25: BM25ContentFilter(user_query=query or ""), FilterType.BM25: BM25ContentFilter(user_query=query or ""),
FilterType.LLM: LLMContentFilter( FilterType.LLM: LLMContentFilter(
llm_config=LLMConfig( llm_config=LLMConfig(
provider=config["llm"]["provider"], provider=provider or config["llm"]["provider"],
api_token=os.environ.get(config["llm"].get("api_key_env", None), ""), api_token=get_llm_api_key(config, provider),
), ),
instruction=query or "Extract main content" instruction=query or "Extract main content"
) )
@@ -230,7 +247,8 @@ async def handle_llm_request(
query: Optional[str] = None, query: Optional[str] = None,
schema: Optional[str] = None, schema: Optional[str] = None,
cache: str = "0", cache: str = "0",
config: Optional[dict] = None config: Optional[dict] = None,
provider: Optional[str] = None
) -> JSONResponse: ) -> JSONResponse:
"""Handle LLM extraction requests.""" """Handle LLM extraction requests."""
base_url = get_base_url(request) base_url = get_base_url(request)
@@ -260,7 +278,8 @@ async def handle_llm_request(
schema, schema,
cache, cache,
base_url, base_url,
config config,
provider
) )
except Exception as e: except Exception as e:
@@ -304,7 +323,8 @@ async def create_new_task(
schema: Optional[str], schema: Optional[str],
cache: str, cache: str,
base_url: str, base_url: str,
config: dict config: dict,
provider: Optional[str] = None
) -> JSONResponse: ) -> JSONResponse:
"""Create and initialize a new task.""" """Create and initialize a new task."""
decoded_url = unquote(input_path) decoded_url = unquote(input_path)
@@ -328,7 +348,8 @@ async def create_new_task(
decoded_url, decoded_url,
query, query,
schema, schema,
cache cache,
provider
) )
return JSONResponse({ return JSONResponse({

View File

@@ -36,6 +36,7 @@ class LlmJobPayload(BaseModel):
q: str q: str
schema: Optional[str] = None schema: Optional[str] = None
cache: bool = False cache: bool = False
provider: Optional[str] = None
class CrawlJobPayload(BaseModel): class CrawlJobPayload(BaseModel):
@@ -61,6 +62,7 @@ async def llm_job_enqueue(
schema=payload.schema, schema=payload.schema,
cache=payload.cache, cache=payload.cache,
config=_config, config=_config,
provider=payload.provider,
) )

View File

@@ -15,6 +15,7 @@ class MarkdownRequest(BaseModel):
f: FilterType = Field(FilterType.FIT, description="Contentfilter strategy: fit, raw, bm25, or llm") f: FilterType = Field(FilterType.FIT, description="Contentfilter strategy: fit, raw, bm25, or llm")
q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters") q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters")
c: Optional[str] = Field("0", description="Cachebust / revision counter") c: Optional[str] = Field("0", description="Cachebust / revision counter")
provider: Optional[str] = Field(None, description="LLM provider override (e.g., 'anthropic/claude-3-opus')")
class RawCode(BaseModel): class RawCode(BaseModel):

View File

@@ -241,7 +241,7 @@ async def get_markdown(
raise HTTPException( raise HTTPException(
400, "URL must be absolute and start with http/https") 400, "URL must be absolute and start with http/https")
markdown = await handle_markdown_request( markdown = await handle_markdown_request(
body.url, body.f, body.q, body.c, config body.url, body.f, body.q, body.c, config, body.provider
) )
return JSONResponse({ return JSONResponse({
"url": body.url, "url": body.url,

View File

@@ -1,6 +1,7 @@
import dns.resolver import dns.resolver
import logging import logging
import yaml import yaml
import os
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@@ -19,10 +20,24 @@ class FilterType(str, Enum):
LLM = "llm" LLM = "llm"
def load_config() -> Dict:
    """Load ``config.yml`` and apply environment-variable overrides.

    The file is read from the directory containing this module. Two
    optional environment variables are then applied to the ``llm`` section:

    * ``LLM_PROVIDER`` replaces ``llm.provider`` when set to a non-empty value.
    * ``LLM_API_KEY`` is stored as ``llm.api_key`` when set to a non-empty
      value and the config file does not already define ``api_key``.

    Returns:
        The parsed configuration dictionary (an empty dict if the YAML file
        is empty), with any overrides applied.
    """
    config_path = Path(__file__).parent / "config.yml"
    # Explicit encoding keeps parsing independent of the platform locale.
    with open(config_path, "r", encoding="utf-8") as config_file:
        # safe_load returns None for an empty document; normalise to a dict.
        config = yaml.safe_load(config_file) or {}

    llm_provider = os.environ.get("LLM_PROVIDER")
    llm_api_key = os.environ.get("LLM_API_KEY")
    if not (llm_provider or llm_api_key):
        return config

    # Make sure there is an ``llm`` mapping to write into, so an override does
    # not crash when the section is missing or empty in the YAML file.
    llm_config = config.get("llm")
    if not isinstance(llm_config, dict):
        llm_config = {}
        config["llm"] = llm_config

    # Override LLM provider from environment if set
    if llm_provider:
        llm_config["provider"] = llm_provider
        logging.info(f"LLM provider overridden from environment: {llm_provider}")

    # Also support a direct API key from the environment when the config file
    # does not define one itself.
    if llm_api_key and "api_key" not in llm_config:
        llm_config["api_key"] = llm_api_key
        logging.info("LLM API key loaded from LLM_API_KEY environment variable")

    return config
def setup_logging(config: Dict) -> None: def setup_logging(config: Dict) -> None:
"""Configure application logging.""" """Configure application logging."""
@@ -56,6 +71,52 @@ def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
def get_llm_api_key(config: Dict, provider: Optional[str] = None) -> str:
    """Return the API key configured for LLM calls.

    Lookup order:

    1. ``config["llm"]["api_key"]`` when it is present and non-empty.
    2. The environment variable named by ``config["llm"]["api_key_env"]``.

    Args:
        config: The application configuration dictionary.
        provider: Optional provider override (e.g., "openai/gpt-4"). Accepted
            so callers can pass the per-request provider, but the lookup does
            not depend on it: the same configured key is returned for every
            provider.

    Returns:
        The API key, or an empty string if none is configured.
    """
    llm_config = config["llm"]

    # A directly configured key wins; a null/empty value (e.g. a bare
    # ``api_key:`` line in YAML) is treated as "not configured".
    direct_key = llm_config.get("api_key")
    if direct_key:
        return direct_key

    # os.environ.get() raises TypeError for a non-string name, so guard
    # against a missing or null ``api_key_env`` before looking it up.
    env_name = llm_config.get("api_key_env")
    if not env_name:
        return ""
    return os.environ.get(env_name, "")
def validate_llm_provider(config: Dict, provider: Optional[str] = None) -> tuple[bool, str]:
    """Check that an API key can be resolved for the requested LLM provider.

    Args:
        config: The application configuration dictionary.
        provider: Optional provider override (e.g., "openai/gpt-4"); when
            omitted or empty, ``config["llm"]["provider"]`` is used.

    Returns:
        ``(True, "")`` when a non-empty API key is found, otherwise
        ``(False, message)`` where ``message`` names the provider.
    """
    # An explicit provider wins; otherwise use the configured default.
    effective_provider = provider or config["llm"]["provider"]

    if get_llm_api_key(config, effective_provider):
        return True, ""

    return False, f"No API key found for provider '{effective_provider}'. Please set the appropriate environment variable."
def verify_email_domain(email: str) -> bool: def verify_email_domain(email: str) -> bool:
try: try:
domain = email.split('@')[1] domain = email.split('@')[1]

View File

@@ -14,6 +14,7 @@ x-base-config: &base-config
- TOGETHER_API_KEY=${TOGETHER_API_KEY:-} - TOGETHER_API_KEY=${TOGETHER_API_KEY:-}
- MISTRAL_API_KEY=${MISTRAL_API_KEY:-} - MISTRAL_API_KEY=${MISTRAL_API_KEY:-}
- GEMINI_API_TOKEN=${GEMINI_API_TOKEN:-} - GEMINI_API_TOKEN=${GEMINI_API_TOKEN:-}
- LLM_PROVIDER=${LLM_PROVIDER:-} # Optional: Override default provider (e.g., "anthropic/claude-3-opus")
volumes: volumes:
- /dev/shm:/dev/shm # Chromium performance - /dev/shm:/dev/shm # Chromium performance
deploy: deploy:

View File

@@ -154,6 +154,30 @@ cp deploy/docker/.llm.env.example .llm.env
# Now edit .llm.env and add your API keys # Now edit .llm.env and add your API keys
``` ```
**Flexible LLM Provider Configuration:**
The Docker setup now supports flexible LLM provider configuration through three methods:
1. **Environment Variable**: Set `LLM_PROVIDER` to override the `config.yml` default (a per-request `provider` parameter, when given, still takes precedence)
```bash
export LLM_PROVIDER="anthropic/claude-3-opus"
# Or in your .llm.env file:
# LLM_PROVIDER=anthropic/claude-3-opus
```
2. **API Request Parameter**: Specify provider per request
```json
{
"url": "https://example.com",
"f": "llm",
"provider": "groq/mixtral-8x7b"
}
```
3. **Config File Default**: Falls back to `config.yml` (default: `openai/gpt-4o-mini`)
The system automatically selects the appropriate API key based on the configured `api_key_env` in the config file.
#### 3. Build and Run with Compose #### 3. Build and Run with Compose
The `docker-compose.yml` file in the project root provides a simplified approach that automatically handles architecture detection using buildx. The `docker-compose.yml` file in the project root provides a simplified approach that automatically handles architecture detection using buildx.
@@ -668,7 +692,7 @@ app:
# Default LLM Configuration # Default LLM Configuration
llm: llm:
provider: "openai/gpt-4o-mini" provider: "openai/gpt-4o-mini" # Can be overridden by LLM_PROVIDER env var
api_key_env: "OPENAI_API_KEY" api_key_env: "OPENAI_API_KEY"
# api_key: sk-... # If you pass the API key directly then api_key_env will be ignored # api_key: sk-... # If you pass the API key directly then api_key_env will be ignored

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env python3
"""Test script to verify Docker API with LLM provider configuration."""
import requests
import json
import time
BASE_URL = "http://localhost:11235"
def test_health():
    """Request the ``/health`` endpoint and print its status and JSON body."""
    print("1. Testing health endpoint...")
    # requests has no default timeout; bound the wait so an unreachable
    # server fails fast instead of hanging the script forever.
    response = requests.get(f"{BASE_URL}/health", timeout=10)
    print(f" Status: {response.status_code}")
    print(f" Response: {response.json()}")
    print()
def test_schema():
    """Request the ``/schema`` endpoint and print a sample of its browser keys."""
    print("2. Testing schema endpoint...")
    # requests has no default timeout; bound the wait so an unreachable
    # server fails fast instead of hanging the script forever.
    response = requests.get(f"{BASE_URL}/schema", timeout=10)
    print(f" Status: {response.status_code}")
    # Print only browser config to keep output concise
    print(f" Browser config keys: {list(response.json().get('browser', {}).keys())[:5]}...")
    print()
def test_markdown_with_llm_filter():
    """POST to ``/md`` with the LLM filter and no explicit provider.

    No ``provider`` field is sent, so the server falls back to its configured
    default. Prints the status code, then either the start of the error body
    or the length of the returned markdown.
    """
    print("3. Testing markdown endpoint with LLM filter...")
    print(" This should use the Groq provider from LLM_PROVIDER env var")
    # Note: This will fail with dummy API keys, but we can see if it tries to use Groq
    payload = {
        "url": "https://httpbin.org/html",
        "f": "llm",
        "q": "Extract the main content"
    }
    # requests has no default timeout; allow a generous but finite wait since
    # the request involves a crawl plus an LLM call.
    response = requests.post(f"{BASE_URL}/md", json=payload, timeout=120)
    print(f" Status: {response.status_code}")
    if response.status_code != 200:
        print(f" Error: {response.text[:200]}...")
    else:
        print(f" Success! Markdown length: {len(response.json().get('markdown', ''))} chars")
    print()
def test_markdown_with_provider_override():
    """POST to ``/md`` with the LLM filter and a per-request ``provider``.

    Prints the status code, then either the start of the error body or the
    length of the returned markdown.
    """
    print("4. Testing markdown endpoint with provider override...")
    print(" This should use OpenAI provider from request parameter")
    payload = {
        "url": "https://httpbin.org/html",
        "f": "llm",
        "q": "Extract the main content",
        "provider": "openai/gpt-4"  # Override to use OpenAI
    }
    # requests has no default timeout; allow a generous but finite wait since
    # the request involves a crawl plus an LLM call.
    response = requests.post(f"{BASE_URL}/md", json=payload, timeout=120)
    print(f" Status: {response.status_code}")
    if response.status_code != 200:
        print(f" Error: {response.text[:200]}...")
    else:
        print(f" Success! Markdown length: {len(response.json().get('markdown', ''))} chars")
    print()
def test_simple_crawl():
    """POST a single-URL job to ``/crawl`` (no LLM involved) and print the outcome.

    On HTTP 200, prints the overall success flag, the number of results and
    the first result's success flag; otherwise prints the start of the error
    body.
    """
    print("5. Testing simple crawl (no LLM required)...")
    payload = {
        "urls": ["https://httpbin.org/html"],
        "browser_config": {
            "type": "BrowserConfig",
            "params": {"headless": True}
        },
        "crawler_config": {
            "type": "CrawlerRunConfig",
            "params": {"cache_mode": "bypass"}
        }
    }
    # requests has no default timeout; bound the wait so a stuck crawl
    # cannot hang the script forever.
    response = requests.post(f"{BASE_URL}/crawl", json=payload, timeout=60)
    print(f" Status: {response.status_code}")
    if response.status_code == 200:
        result = response.json()
        print(f" Success: {result.get('success')}")
        print(f" Results count: {len(result.get('results', []))}")
        if result.get('results'):
            print(f" First result success: {result['results'][0].get('success')}")
    else:
        print(f" Error: {response.text[:200]}...")
    print()
def test_playground():
    """Request ``/playground`` and print its status code and content type."""
    print("6. Testing playground interface...")
    # requests has no default timeout; bound the wait so an unreachable
    # server fails fast instead of hanging the script forever.
    response = requests.get(f"{BASE_URL}/playground", timeout=10)
    print(f" Status: {response.status_code}")
    print(f" Content-Type: {response.headers.get('content-type')}")
    print()
if __name__ == "__main__":
    # Script entry point: run the non-LLM checks first, then the LLM ones.
    print("=== Crawl4AI Docker API Tests ===\n")
    print(f"Testing API at {BASE_URL}\n")

    # Wait a bit for server to be fully ready
    time.sleep(2)

    for basic_check in (test_health, test_schema, test_simple_crawl, test_playground):
        basic_check()

    print("\nTesting LLM functionality (these may fail with dummy API keys):\n")
    for llm_check in (test_markdown_with_llm_filter, test_markdown_with_provider_override):
        llm_check()

    print("\nTests completed!")