From 6650b2f34a1849d01e00ca1bcce5772ebaf7cc54 Mon Sep 17 00:00:00 2001
From: Aravind Karnam
Date: Fri, 2 May 2025 16:51:15 +0530
Subject: [PATCH] fix: replace OpenAI with litellm to support multiple LLM providers

---
 .gitignore                         |  4 ++-
 docs/apps/linkdin/c4ai_insights.py | 57 +++++++++++++++++--------------
 2 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1658a987..6a118cba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -261,4 +261,6 @@ CLAUDE.md
 
 tests/**/test_site
 tests/**/reports
-tests/**/benchmark_reports
\ No newline at end of file
+tests/**/benchmark_reports
+
+docs/**/data
\ No newline at end of file

diff --git a/docs/apps/linkdin/c4ai_insights.py b/docs/apps/linkdin/c4ai_insights.py
index 8307c30d..94370258 100644
--- a/docs/apps/linkdin/c4ai_insights.py
+++ b/docs/apps/linkdin/c4ai_insights.py
@@ -43,7 +43,7 @@ import numpy as np
 import pandas as pd
 import hashlib
-from openai import OpenAI # same SDK you pre-loaded
+from litellm import acompletion  # async call; supports any LLM provider
 
 # ───────────────────────────────────────────────────────────────────────────────
 # Utils
 # ───────────────────────────────────────────────────────────────────────────────
@@ -70,11 +70,12 @@ def dev_defaults() -> SimpleNamespace:
         out_dir="./insights_debug",
         embed_model="all-MiniLM-L6-v2",
         top_k=10,
-        openai_model="gpt-4.1",
+        llm_provider="openai/gpt-4.1",
+        llm_api_key=None,
         max_llm_tokens=8000,
         llm_temperature=1.0,
-        workers=4, # parallel processing
-        stub=False, # manual
+        workers=4,
+        stub=False,
     )
 
 # ───────────────────────────────────────────────────────────────────────────────
@@ -166,7 +167,7 @@ def build_company_graph(companies, embeds:np.ndarray, top_k:int) -> Dict[str,Any
 # ───────────────────────────────────────────────────────────────────────────────
 # Org-chart via LLM
 # ───────────────────────────────────────────────────────────────────────────────
-async def infer_org_chart_llm(company, people, client:OpenAI, model_name:str, max_tokens:int, temperature:float, stub:bool):
+async def infer_org_chart_llm(company, people, llm_provider:str, api_key:str|None, max_tokens:int, temperature:float, stub:bool):
     if stub:
         # Tiny fake org-chart when debugging offline
         chief = random.choice(people)
@@ -202,15 +203,19 @@ Here is a JSON list of employees:
 Return JSON: {{ "nodes":[{{id,name,title,dept,yoe_total,yoe_current,seniority_score,decision_score,avatar_url,profile_url}}], "edges":[{{source,target,type,confidence}}] }}
 """}
     ]
-    resp = client.chat.completions.create(
-        model=model_name,
+    resp = await acompletion(
+        model=llm_provider,
         messages=prompt,
         max_tokens=max_tokens,
         temperature=temperature,
-        response_format={"type":"json_object"}
+        response_format={"type":"json_object"},
+        api_key=api_key
     )
     chart = json.loads(resp.choices[0].message.content)
-    chart["meta"] = dict(model=model_name, generated_at=datetime.now(UTC).isoformat())
+    chart["meta"] = dict(
+        model=llm_provider,
+        generated_at=datetime.now(UTC).isoformat()
+    )
     return chart
 
 # ───────────────────────────────────────────────────────────────────────────────
@@ -270,15 +275,11 @@ async def run(opts):
     logging.info(f"[bold cyan]Loaded[/] {len(companies)} companies, {len(people)} people")
 
     logging.info("[bold]⇢[/] Embedding company descriptions…")
-    # embeds = embed_descriptions(companies, opts.embed_model, opts)
+    embeds = embed_descriptions(companies, opts.embed_model, opts)
 
     logging.info("[bold]⇢[/] Building similarity graph")
-    # company_graph = build_company_graph(companies, embeds, opts.top_k)
-    # dump_json(company_graph, out_dir/"company_graph.json")
-
-    # OpenAI client (only built if not debugging)
-    stub = bool(opts.stub)
-    client = OpenAI() if not stub else None
+    company_graph = build_company_graph(companies, embeds, opts.top_k)
+    dump_json(company_graph, out_dir/"company_graph.json")
 
     # Filter companies that need processing
     to_process = []
@@ -311,14 +312,13 @@ async def run(opts):
     async def process_one(comp):
         handle = comp["handle"].strip("/").replace("/","_")
         persons = [p for p in people if p["company_handle"].strip("/") == comp["handle"].strip("/")]
-
         chart = await infer_org_chart_llm(
             comp, persons,
-            client=client if client else OpenAI(api_key="sk-debug"),
-            model_name=opts.openai_model,
+            llm_provider=opts.llm_provider,
+            api_key=getattr(opts, "llm_api_key", None),
             max_tokens=opts.max_llm_tokens,
             temperature=opts.llm_temperature,
-            stub=stub,
+            stub=bool(opts.stub),
         )
 
         chart["meta"]["company"] = comp["name"]
@@ -354,18 +354,21 @@ def build_arg_parser():
     p = argparse.ArgumentParser(description="Build graphs & visualisation from Stage-1 output")
     p.add_argument("--in", dest="in_dir", required=False, help="Stage-1 output dir", default=".")
     p.add_argument("--out", dest="out_dir", required=False, help="Destination dir", default=".")
-    p.add_argument("--embed_model", default="all-MiniLM-L6-v2")
-    p.add_argument("--top_k", type=int, default=10, help="Top-k neighbours per company")
-    p.add_argument("--openai_model", default="gpt-4.1")
-    p.add_argument("--max_llm_tokens", type=int, default=8024)
-    p.add_argument("--llm_temperature", type=float, default=1.0)
-    p.add_argument("--stub", action="store_true", help="Skip OpenAI call and generate tiny fake org charts")
+    p.add_argument("--embed-model", default="all-MiniLM-L6-v2")
+    p.add_argument("--top-k", type=int, default=10, help="Top-k neighbours per company")
+    p.add_argument("--llm-provider", default="openai/gpt-4.1",
+                   help="LLM model to use in format 'provider/model_name' (e.g., 'anthropic/claude-3')")
+    p.add_argument("--llm-api-key", help="API key for LLM provider (defaults to env vars)")
+    p.add_argument("--max-llm-tokens", type=int, default=8024)
+    p.add_argument("--llm-temperature", type=float, default=1.0)
+    p.add_argument("--stub", action="store_true", help="Skip the LLM call and generate tiny fake org charts")
     p.add_argument("--workers", type=int, default=4, help="Number of parallel workers for LLM inference")
     return p
 
 def main():
     dbg = dev_defaults()
-    opts = dbg if True else build_arg_parser().parse_args()
+    # opts = dbg if True else build_arg_parser().parse_args()
+    opts = build_arg_parser().parse_args()
     asyncio.run(run(opts))
 
 if __name__ == "__main__":
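
Reviewer note, not part of the patch: litellm selects the backend from the
model-string prefix, so the acompletion() call above can target any supported
provider without further code changes. A minimal sketch of the new call path,
using the patch's default "openai/gpt-4.1" and assuming OPENAI_API_KEY is set
in the environment:

    import asyncio
    from litellm import acompletion

    async def demo():
        # Same call shape as infer_org_chart_llm(); the model string picks the provider.
        resp = await acompletion(
            model="openai/gpt-4.1",
            messages=[{"role": "user", "content": 'Return a JSON object: {"ok": true}'}],
            max_tokens=100,
            temperature=1.0,
            response_format={"type": "json_object"},
            api_key=None,  # None -> litellm falls back to the provider's env vars
        )
        print(resp.choices[0].message.content)

    asyncio.run(demo())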
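
Example invocation with the renamed kebab-case flags (paths here are
illustrative; when --llm-api-key is omitted, litellm reads the provider's
environment variable, e.g. OPENAI_API_KEY or ANTHROPIC_API_KEY):

    python docs/apps/linkdin/c4ai_insights.py \
        --in ./stage1_out --out ./insights \
        --llm-provider openai/gpt-4.1 \
        --max-llm-tokens 8024 --workers 4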