Add proxy support for the Selenium crawler and a configurable base_url for LLM completion calls
This commit is contained in:
@@ -82,6 +82,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
|
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
|
||||||
self.options = Options()
|
self.options = Options()
|
||||||
self.options.headless = True
|
self.options.headless = True
|
||||||
|
if kwargs.get("proxy"):
|
||||||
|
self.options.add_argument("--proxy-server={}".format(kwargs.get("proxy")))
|
||||||
if kwargs.get("user_agent"):
|
if kwargs.get("user_agent"):
|
||||||
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
|
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
|
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
|
||||||
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
|
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
|
||||||
self.apply_chunking = kwargs.get("apply_chunking", True)
|
self.apply_chunking = kwargs.get("apply_chunking", True)
|
||||||
|
self.base_url = kwargs.get("base_url", None)
|
||||||
if not self.apply_chunking:
|
if not self.apply_chunking:
|
||||||
self.chunk_token_threshold = 1e9
|
self.chunk_token_threshold = 1e9
|
||||||
|
|
||||||
@@ -110,7 +111,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
"{" + variable + "}", variable_values[variable]
|
"{" + variable + "}", variable_values[variable]
|
||||||
)
|
)
|
||||||
|
|
||||||
response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token) # , json_response=self.extract_type == "schema")
|
response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token, base_url=self.base_url) # , json_response=self.extract_type == "schema")
|
||||||
try:
|
try:
|
||||||
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
|
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
|
||||||
blocks = json.loads(blocks)
|
blocks = json.loads(blocks)
|
||||||
|
|||||||
@@ -716,7 +716,7 @@ def extract_xml_data(tags, string):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
# Function to perform the completion with exponential backoff
|
# Function to perform the completion with exponential backoff
|
||||||
def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False):
|
def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False, base_url=None):
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
from litellm.exceptions import RateLimitError
|
from litellm.exceptions import RateLimitError
|
||||||
max_attempts = 3
|
max_attempts = 3
|
||||||
@@ -735,6 +735,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
|
|||||||
],
|
],
|
||||||
temperature=0.01,
|
temperature=0.01,
|
||||||
api_key=api_token,
|
api_key=api_token,
|
||||||
|
base_url=base_url,
|
||||||
**extra_args
|
**extra_args
|
||||||
)
|
)
|
||||||
return response # Return the successful response
|
return response # Return the successful response
|
||||||
@@ -755,7 +756,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
|
|||||||
"content": ["Rate limit error. Please try again later."]
|
"content": ["Rate limit error. Please try again later."]
|
||||||
}]
|
}]
|
||||||
|
|
||||||
def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
|
def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, base_url = None):
|
||||||
# api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
|
# api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
|
||||||
api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
|
api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
|
||||||
|
|
||||||
@@ -770,7 +771,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
|
|||||||
"{" + variable + "}", variable_values[variable]
|
"{" + variable + "}", variable_values[variable]
|
||||||
)
|
)
|
||||||
|
|
||||||
response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
|
response = perform_completion_with_backoff(provider, prompt_with_variables, api_token, base_url=base_url)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
|
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
|
||||||
@@ -864,17 +865,17 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold):
|
|||||||
|
|
||||||
return merged_sections
|
return merged_sections
|
||||||
|
|
||||||
def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
|
def process_sections(url: str, sections: list, provider: str, api_token: str, base_url=None) -> list:
|
||||||
extracted_content = []
|
extracted_content = []
|
||||||
if provider.startswith("groq/"):
|
if provider.startswith("groq/"):
|
||||||
# Sequential processing with a delay
|
# Sequential processing with a delay
|
||||||
for section in sections:
|
for section in sections:
|
||||||
extracted_content.extend(extract_blocks(url, section, provider, api_token))
|
extracted_content.extend(extract_blocks(url, section, provider, api_token, base_url=base_url))
|
||||||
time.sleep(0.5) # 500 ms delay between each processing
|
time.sleep(0.5) # 500 ms delay between each processing
|
||||||
else:
|
else:
|
||||||
# Parallel processing using ThreadPoolExecutor
|
# Parallel processing using ThreadPoolExecutor
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
|
futures = [executor.submit(extract_blocks, url, section, provider, api_token, base_url=base_url) for section in sections]
|
||||||
for future in as_completed(futures):
|
for future in as_completed(futures):
|
||||||
extracted_content.extend(future.result())
|
extracted_content.extend(future.result())
|
||||||
|
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ class WebCrawler:
|
|||||||
crawler_strategy: CrawlerStrategy = None,
|
crawler_strategy: CrawlerStrategy = None,
|
||||||
always_by_pass_cache: bool = False,
|
always_by_pass_cache: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
proxy: str = None,
|
||||||
):
|
):
|
||||||
# self.db_path = db_path
|
# self.db_path = db_path
|
||||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
|
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose, proxy=proxy)
|
||||||
self.always_by_pass_cache = always_by_pass_cache
|
self.always_by_pass_cache = always_by_pass_cache
|
||||||
|
|
||||||
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
||||||
|
|||||||
Reference in New Issue
Block a user