Add proxy support and an AI base_url option

Author: datehoer
Date: 2024-08-26 16:12:49 +08:00
parent eba831ca30
commit fe9ff498ce
4 changed files with 13 additions and 8 deletions


@@ -82,6 +82,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
         print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
         self.options = Options()
         self.options.headless = True
+        if kwargs.get("proxy"):
+            self.options.add_argument("--proxy-server={}".format(kwargs.get("proxy")))
         if kwargs.get("user_agent"):
             self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
         else:
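
For context, a minimal usage sketch of the new proxy keyword. The import path and the proxy address are assumptions for illustration, not part of this commit:

from crawl4ai.crawler_strategy import LocalSeleniumCrawlerStrategy  # assumed import path

# Hypothetical proxy address; it is forwarded to Chrome as --proxy-server.
strategy = LocalSeleniumCrawlerStrategy(proxy="http://127.0.0.1:8080")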


@@ -79,6 +79,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
         self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
         self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
         self.apply_chunking = kwargs.get("apply_chunking", True)
+        self.base_url = kwargs.get("base_url", None)
         if not self.apply_chunking:
             self.chunk_token_threshold = 1e9
@@ -110,7 +111,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
                 "{" + variable + "}", variable_values[variable]
             )
-        response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token) # , json_response=self.extract_type == "schema")
+        response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token, base_url=self.base_url) # , json_response=self.extract_type == "schema")
         try:
             blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
             blocks = json.loads(blocks)
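
A minimal sketch of pointing extraction at an OpenAI-compatible endpoint via the new base_url kwarg. The import path, provider string, token, and URL are placeholders, not taken from this commit:

from crawl4ai.extraction_strategy import LLMExtractionStrategy  # assumed import path

strategy = LLMExtractionStrategy(
    provider="openai/gpt-4o-mini",       # placeholder provider
    api_token="sk-...",                  # placeholder token
    base_url="http://localhost:8000/v1"  # hypothetical OpenAI-compatible endpoint
)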


@@ -716,7 +716,7 @@ def extract_xml_data(tags, string):
     return data

 # Function to perform the completion with exponential backoff
-def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False):
+def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False, base_url=None):
     from litellm import completion
     from litellm.exceptions import RateLimitError
     max_attempts = 3
@@ -735,6 +735,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
                 ],
                 temperature=0.01,
                 api_key=api_token,
+                base_url=base_url,
                 **extra_args
             )
             return response # Return the successful response
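
The new parameter is simply forwarded to litellm's completion() call above. A direct call might look like the following sketch; the import path, provider, prompt, token, and URL are placeholders:

from crawl4ai.utils import perform_completion_with_backoff  # assumed import path

response = perform_completion_with_backoff(
    "openai/gpt-4o-mini",                        # placeholder provider
    "Extract the main content as JSON blocks.",  # placeholder prompt
    "sk-...",                                    # placeholder token
    base_url="http://localhost:8000/v1",         # hypothetical self-hosted endpoint
)
print(response.choices[0].message.content)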
@@ -755,7 +756,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
         "content": ["Rate limit error. Please try again later."]
     }]

-def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
+def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, base_url = None):
     # api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
     api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
@@ -770,7 +771,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
             "{" + variable + "}", variable_values[variable]
         )
-    response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
+    response = perform_completion_with_backoff(provider, prompt_with_variables, api_token, base_url=base_url)
     try:
         blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
@@ -864,17 +865,17 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold):
     return merged_sections

-def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
+def process_sections(url: str, sections: list, provider: str, api_token: str, base_url=None) -> list:
     extracted_content = []
     if provider.startswith("groq/"):
         # Sequential processing with a delay
         for section in sections:
-            extracted_content.extend(extract_blocks(url, section, provider, api_token))
+            extracted_content.extend(extract_blocks(url, section, provider, api_token, base_url=base_url))
             time.sleep(0.5) # 500 ms delay between each processing
     else:
         # Parallel processing using ThreadPoolExecutor
         with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
+            futures = [executor.submit(extract_blocks, url, section, provider, api_token, base_url=base_url) for section in sections]
             for future in as_completed(futures):
                 extracted_content.extend(future.result())
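
Both the sequential groq/ branch and the ThreadPoolExecutor branch now thread base_url through to extract_blocks. A hedged sketch of calling the helper directly; the import path, URL, sections, and credentials are placeholders:

from crawl4ai.utils import process_sections  # assumed import path

results = process_sections(
    url="https://example.com",
    sections=["<p>chunk one</p>", "<p>chunk two</p>"],  # pre-chunked HTML
    provider="openai/gpt-4o-mini",                      # placeholder provider
    api_token="sk-...",                                 # placeholder token
    base_url="http://localhost:8000/v1",                # hypothetical endpoint
)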


@@ -22,9 +22,10 @@ class WebCrawler:
         crawler_strategy: CrawlerStrategy = None,
         always_by_pass_cache: bool = False,
         verbose: bool = False,
+        proxy: str = None,
     ):
         # self.db_path = db_path
-        self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
+        self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose, proxy=proxy)
         self.always_by_pass_cache = always_by_pass_cache

         # Create the .crawl4ai folder in the user's home directory if it doesn't exist
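
End to end, the new constructor argument reaches the default Selenium strategy. A minimal usage sketch, assuming the synchronous WebCrawler API with warmup()/run() and a placeholder proxy URL:

from crawl4ai import WebCrawler  # assumed import path

# Hypothetical proxy; forwarded to LocalSeleniumCrawlerStrategy as --proxy-server.
crawler = WebCrawler(proxy="http://127.0.0.1:8080", verbose=True)
crawler.warmup()
result = crawler.run(url="https://example.com")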