add proxy and add ai base_url
This commit is contained in:
@@ -716,7 +716,7 @@ def extract_xml_data(tags, string):
     return data


 # Function to perform the completion with exponential backoff
-def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False):
+def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False, base_url=None):
     from litellm import completion
     from litellm.exceptions import RateLimitError
     max_attempts = 3
@@ -735,6 +735,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
                 ],
                 temperature=0.01,
                 api_key=api_token,
+                base_url=base_url,
                 **extra_args
             )
             return response # Return the successful response
@@ -755,7 +756,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
                 "content": ["Rate limit error. Please try again later."]
             }]

-def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
+def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, base_url = None):
     # api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
     api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token

@@ -770,7 +771,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
             "{" + variable + "}", variable_values[variable]
         )

-    response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
+    response = perform_completion_with_backoff(provider, prompt_with_variables, api_token, base_url=base_url)

     try:
         blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
@@ -864,17 +865,17 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold):

     return merged_sections

-def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
+def process_sections(url: str, sections: list, provider: str, api_token: str, base_url=None) -> list:
     extracted_content = []
     if provider.startswith("groq/"):
         # Sequential processing with a delay
         for section in sections:
-            extracted_content.extend(extract_blocks(url, section, provider, api_token))
+            extracted_content.extend(extract_blocks(url, section, provider, api_token, base_url=base_url))
             time.sleep(0.5)  # 500 ms delay between each processing
     else:
         # Parallel processing using ThreadPoolExecutor
         with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
+            futures = [executor.submit(extract_blocks, url, section, provider, api_token, base_url=base_url) for section in sections]
             for future in as_completed(futures):
                 extracted_content.extend(future.result())

Reference in New Issue
Block a user