fix: replace openAI with litellm to support multiple llm providers

This commit is contained in:
Aravind Karnam
2025-05-02 16:51:15 +05:30
parent 5cc58f9bb3
commit 6650b2f34a
2 changed files with 32 additions and 27 deletions

2
.gitignore vendored
View File

@@ -262,3 +262,5 @@ CLAUDE.md
tests/**/test_site tests/**/test_site
tests/**/reports tests/**/reports
tests/**/benchmark_reports tests/**/benchmark_reports
docs/**/data

View File

@@ -43,7 +43,7 @@ import numpy as np
import pandas as pd import pandas as pd
import hashlib import hashlib
from openai import OpenAI # same SDK you pre-loaded from litellm import completion #Support any LLM Provider
# ─────────────────────────────────────────────────────────────────────────────── # ───────────────────────────────────────────────────────────────────────────────
# Utils # Utils
@@ -70,11 +70,12 @@ def dev_defaults() -> SimpleNamespace:
out_dir="./insights_debug", out_dir="./insights_debug",
embed_model="all-MiniLM-L6-v2", embed_model="all-MiniLM-L6-v2",
top_k=10, top_k=10,
openai_model="gpt-4.1", llm_provider="openai/gpt-4.1",
llm_api_key=None,
max_llm_tokens=8000, max_llm_tokens=8000,
llm_temperature=1.0, llm_temperature=1.0,
workers=4, # parallel processing workers=4,
stub=False, # manual stub=False
) )
# ─────────────────────────────────────────────────────────────────────────────── # ───────────────────────────────────────────────────────────────────────────────
@@ -166,7 +167,7 @@ def build_company_graph(companies, embeds:np.ndarray, top_k:int) -> Dict[str,Any
# ─────────────────────────────────────────────────────────────────────────────── # ───────────────────────────────────────────────────────────────────────────────
# Org-chart via LLM # Org-chart via LLM
# ─────────────────────────────────────────────────────────────────────────────── # ───────────────────────────────────────────────────────────────────────────────
async def infer_org_chart_llm(company, people, client:OpenAI, model_name:str, max_tokens:int, temperature:float, stub:bool): async def infer_org_chart_llm(company, people, llm_provider:str, api_key:str, max_tokens:int, temperature:float, stub:bool):
if stub: if stub:
# Tiny fake org-chart when debugging offline # Tiny fake org-chart when debugging offline
chief = random.choice(people) chief = random.choice(people)
@@ -202,15 +203,19 @@ Here is a JSON list of employees:
Return JSON: {{ "nodes":[{{id,name,title,dept,yoe_total,yoe_current,seniority_score,decision_score,avatar_url,profile_url}}], "edges":[{{source,target,type,confidence}}] }} Return JSON: {{ "nodes":[{{id,name,title,dept,yoe_total,yoe_current,seniority_score,decision_score,avatar_url,profile_url}}], "edges":[{{source,target,type,confidence}}] }}
"""} """}
] ]
resp = client.chat.completions.create( resp = completion(
model=model_name, model=llm_provider,
messages=prompt, messages=prompt,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
response_format={"type":"json_object"} response_format={"type":"json_object"},
api_key=api_key
) )
chart = json.loads(resp.choices[0].message.content) chart = json.loads(resp.choices[0].message.content)
chart["meta"] = dict(model=model_name, generated_at=datetime.now(UTC).isoformat()) chart["meta"] = dict(
model=llm_provider,
generated_at=datetime.now(UTC).isoformat()
)
return chart return chart
# ─────────────────────────────────────────────────────────────────────────────── # ───────────────────────────────────────────────────────────────────────────────
@@ -270,15 +275,11 @@ async def run(opts):
logging.info(f"[bold cyan]Loaded[/] {len(companies)} companies, {len(people)} people") logging.info(f"[bold cyan]Loaded[/] {len(companies)} companies, {len(people)} people")
logging.info("[bold]⇢[/] Embedding company descriptions…") logging.info("[bold]⇢[/] Embedding company descriptions…")
# embeds = embed_descriptions(companies, opts.embed_model, opts) embeds = embed_descriptions(companies, opts.embed_model, opts)
logging.info("[bold]⇢[/] Building similarity graph") logging.info("[bold]⇢[/] Building similarity graph")
# company_graph = build_company_graph(companies, embeds, opts.top_k) company_graph = build_company_graph(companies, embeds, opts.top_k)
# dump_json(company_graph, out_dir/"company_graph.json") dump_json(company_graph, out_dir/"company_graph.json")
# OpenAI client (only built if not debugging)
stub = bool(opts.stub)
client = OpenAI() if not stub else None
# Filter companies that need processing # Filter companies that need processing
to_process = [] to_process = []
@@ -311,14 +312,13 @@ async def run(opts):
async def process_one(comp): async def process_one(comp):
handle = comp["handle"].strip("/").replace("/","_") handle = comp["handle"].strip("/").replace("/","_")
persons = [p for p in people if p["company_handle"].strip("/") == comp["handle"].strip("/")] persons = [p for p in people if p["company_handle"].strip("/") == comp["handle"].strip("/")]
chart = await infer_org_chart_llm( chart = await infer_org_chart_llm(
comp, persons, comp, persons,
client=client if client else OpenAI(api_key="sk-debug"), llm_provider=opts.llm_provider,
model_name=opts.openai_model, api_key=getattr(opts, 'llm_api_key', None),
max_tokens=opts.max_llm_tokens, max_tokens=opts.max_llm_tokens,
temperature=opts.llm_temperature, temperature=opts.llm_temperature,
stub=stub, stub=opts.stub or False,
) )
chart["meta"]["company"] = comp["name"] chart["meta"]["company"] = comp["name"]
@@ -354,18 +354,21 @@ def build_arg_parser():
p = argparse.ArgumentParser(description="Build graphs & visualisation from Stage-1 output") p = argparse.ArgumentParser(description="Build graphs & visualisation from Stage-1 output")
p.add_argument("--in", dest="in_dir", required=False, help="Stage-1 output dir", default=".") p.add_argument("--in", dest="in_dir", required=False, help="Stage-1 output dir", default=".")
p.add_argument("--out", dest="out_dir", required=False, help="Destination dir", default=".") p.add_argument("--out", dest="out_dir", required=False, help="Destination dir", default=".")
p.add_argument("--embed_model", default="all-MiniLM-L6-v2") p.add_argument("--embed-model", default="all-MiniLM-L6-v2")
p.add_argument("--top_k", type=int, default=10, help="Top-k neighbours per company") p.add_argument("--top-k", type=int, default=10, help="Top-k neighbours per company")
p.add_argument("--openai_model", default="gpt-4.1") p.add_argument("--llm-provider", default="openai/gpt-4.1",
p.add_argument("--max_llm_tokens", type=int, default=8024) help="LLM model to use in format 'provider/model_name' (e.g., 'anthropic/claude-3')")
p.add_argument("--llm_temperature", type=float, default=1.0) p.add_argument("--llm-api-key", help="API key for LLM provider (defaults to env vars)")
p.add_argument("--max-llm-tokens", type=int, default=8024)
p.add_argument("--llm-temperature", type=float, default=1.0)
p.add_argument("--stub", action="store_true", help="Skip OpenAI call and generate tiny fake org charts") p.add_argument("--stub", action="store_true", help="Skip OpenAI call and generate tiny fake org charts")
p.add_argument("--workers", type=int, default=4, help="Number of parallel workers for LLM inference") p.add_argument("--workers", type=int, default=4, help="Number of parallel workers for LLM inference")
return p return p
def main(): def main():
dbg = dev_defaults() dbg = dev_defaults()
opts = dbg if True else build_arg_parser().parse_args() # opts = dbg if True else build_arg_parser().parse_args()
opts = build_arg_parser().parse_args()
asyncio.run(run(opts)) asyncio.run(run(opts))
if __name__ == "__main__": if __name__ == "__main__":