From fe9ff498ce1cbb3f453473c1721dfd306e60f3ee Mon Sep 17 00:00:00 2001
From: datehoer
Date: Mon, 26 Aug 2024 16:12:49 +0800
Subject: [PATCH] add proxy support and AI base_url option

---
 crawl4ai/crawler_strategy.py    |  2 ++
 crawl4ai/extraction_strategy.py |  3 ++-
 crawl4ai/utils.py               | 13 +++++++------
 crawl4ai/web_crawler.py         |  3 ++-
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py
index fb7980d3..66a8f7dd 100644
--- a/crawl4ai/crawler_strategy.py
+++ b/crawl4ai/crawler_strategy.py
@@ -82,6 +82,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
         print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
         self.options = Options()
         self.options.headless = True
+        if kwargs.get("proxy"):
+            self.options.add_argument("--proxy-server={}".format(kwargs.get("proxy")))
         if kwargs.get("user_agent"):
             self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
         else:
diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py
index 080229f4..8096f11c 100644
--- a/crawl4ai/extraction_strategy.py
+++ b/crawl4ai/extraction_strategy.py
@@ -79,6 +79,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
         self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
         self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
         self.apply_chunking = kwargs.get("apply_chunking", True)
+        self.base_url = kwargs.get("base_url", None)
         if not self.apply_chunking:
             self.chunk_token_threshold = 1e9
 
@@ -110,7 +111,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
                 "{" + variable + "}", variable_values[variable]
             )
 
-        response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token) # , json_response=self.extract_type == "schema")
+        response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token, base_url=self.base_url) # , json_response=self.extract_type == "schema")
         try:
             blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
             blocks = json.loads(blocks)
diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py
index 07832888..64ce9f57 100644
--- a/crawl4ai/utils.py
+++ b/crawl4ai/utils.py
@@ -716,7 +716,7 @@ def extract_xml_data(tags, string):
     return data
 
 # Function to perform the completion with exponential backoff
-def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False):
+def perform_completion_with_backoff(provider, prompt_with_variables, api_token, json_response = False, base_url=None):
     from litellm import completion
     from litellm.exceptions import RateLimitError
     max_attempts = 3
@@ -735,6 +735,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
                 ],
                 temperature=0.01,
                 api_key=api_token,
+                base_url=base_url,
                 **extra_args
             )
             return response # Return the successful response
@@ -755,7 +756,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token,
         "content": ["Rate limit error. Please try again later."]
     }]
 
-def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
+def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, base_url = None):
     # api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token
     api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
 
@@ -770,7 +771,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None):
             "{" + variable + "}", variable_values[variable]
         )
 
-    response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
+    response = perform_completion_with_backoff(provider, prompt_with_variables, api_token, base_url=base_url)
 
     try:
         blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
@@ -864,17 +865,17 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold):
 
     return merged_sections
 
-def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
+def process_sections(url: str, sections: list, provider: str, api_token: str, base_url=None) -> list:
     extracted_content = []
     if provider.startswith("groq/"):
         # Sequential processing with a delay
         for section in sections:
-            extracted_content.extend(extract_blocks(url, section, provider, api_token))
+            extracted_content.extend(extract_blocks(url, section, provider, api_token, base_url=base_url))
             time.sleep(0.5) # 500 ms delay between each processing
     else:
         # Parallel processing using ThreadPoolExecutor
         with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
+            futures = [executor.submit(extract_blocks, url, section, provider, api_token, base_url=base_url) for section in sections]
             for future in as_completed(futures):
                 extracted_content.extend(future.result())
 
diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py
index db0d9856..b354b5cd 100644
--- a/crawl4ai/web_crawler.py
+++ b/crawl4ai/web_crawler.py
@@ -22,9 +22,10 @@ class WebCrawler:
         crawler_strategy: CrawlerStrategy = None,
         always_by_pass_cache: bool = False,
         verbose: bool = False,
+        proxy: str = None,
     ):
         # self.db_path = db_path
-        self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
+        self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose, proxy=proxy)
         self.always_by_pass_cache = always_by_pass_cache
 
         # Create the .crawl4ai folder in the user's home directory if it doesn't exist
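
For reviewers, a minimal usage sketch of the two keyword arguments this patch
introduces. It assumes the crawl4ai entry points shown in the diff (WebCrawler,
LLMExtractionStrategy) and the warmup()/run() API of this release line; the
proxy address, provider string, endpoint URL, and API key are illustrative
placeholders, not values taken from the patch:

    # Sketch only: placeholder proxy/endpoint/credentials throughout.
    from crawl4ai.web_crawler import WebCrawler
    from crawl4ai.extraction_strategy import LLMExtractionStrategy

    # `proxy` is handed to LocalSeleniumCrawlerStrategy, which turns it into
    # Chrome's --proxy-server flag, so all page loads go through the proxy.
    crawler = WebCrawler(proxy="http://127.0.0.1:8080")  # placeholder proxy
    crawler.warmup()

    # `base_url` is stored on the strategy and forwarded through
    # perform_completion_with_backoff() to litellm's completion(base_url=...),
    # pointing extraction at an OpenAI-compatible endpoint of your choice.
    strategy = LLMExtractionStrategy(
        provider="openai/gpt-4o-mini",        # placeholder provider/model
        api_token="sk-...",                   # placeholder key
        base_url="http://localhost:8000/v1",  # placeholder endpoint
    )

    result = crawler.run(url="https://example.com", extraction_strategy=strategy)
    print(result.extracted_content)

Both options default to None, so existing callers are unaffected: without a
proxy no --proxy-server flag is added, and litellm only overrides its
per-provider default endpoint when a base_url value is actually supplied.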