SouravNath commited on
Commit
e63f982
Β·
1 Parent(s): 6b8d880

fix: route LLM to Groq (deepseek-r1) instead of hardcoded gpt-4o/openai

Browse files
Files changed (3) hide show
  1. agent/reflection_agent.py +31 -9
  2. api/tasks.py +2 -1
  3. configs/settings.py +2 -0
agent/reflection_agent.py CHANGED
@@ -455,28 +455,50 @@ def _call_llm(
455
  client=None,
456
  model: str = "gpt-4o",
457
  ) -> tuple[str, dict]:
458
- """Call OpenAI chat completion. Returns (patch_text, usage_dict)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  if client is None:
460
  try:
461
  from openai import OpenAI
462
- client = OpenAI()
463
  except ImportError as e:
464
- raise ImportError("Install openai: pip install openai") from e
 
 
 
465
 
466
  response = client.chat.completions.create(
467
- model=model,
468
  messages=[
469
  {"role": "system", "content": SYSTEM_PROMPT},
470
- {"role": "user", "content": user_prompt},
471
  ],
472
- max_tokens=4096,
473
- temperature=0.2,
474
  )
475
  patch_text = response.choices[0].message.content or ""
476
  usage = {
477
- "prompt_tokens": response.usage.prompt_tokens,
478
  "completion_tokens": response.usage.completion_tokens,
479
- "total_tokens": response.usage.total_tokens,
480
  }
481
  return patch_text, usage
482
 
 
455
  client=None,
456
  model: str = "gpt-4o",
457
  ) -> tuple[str, dict]:
458
+ """
459
+ Call the configured LLM provider (Groq, OpenAI, etc.).
460
+ Auto-detects provider from settings when client is None.
461
+ Returns (patch_text, usage_dict).
462
+ """
463
+ from configs.settings import settings
464
+
465
+ provider = settings.llm_provider.lower()
466
+ effective_model = model
467
+
468
+ # ── Groq (free, recommended) ───────────────────────────────────────────
469
+ if client is None and provider == "groq":
470
+ try:
471
+ from groq import Groq
472
+ client = Groq(api_key=settings.groq_api_key)
473
+ effective_model = settings.llm_model # use configured Groq model
474
+ except ImportError as e:
475
+ raise ImportError("Install groq: pip install groq") from e
476
+
477
+ # ── OpenAI (fallback) ─────────────────────────────────────────────────
478
  if client is None:
479
  try:
480
  from openai import OpenAI
481
+ client = OpenAI(api_key=settings.openai_api_key or None)
482
  except ImportError as e:
483
+ raise ImportError(
484
+ "No LLM client available. Set LLM_PROVIDER=groq and GROQ_API_KEY, "
485
+ "or install openai: pip install openai"
486
+ ) from e
487
 
488
  response = client.chat.completions.create(
489
+ model=effective_model,
490
  messages=[
491
  {"role": "system", "content": SYSTEM_PROMPT},
492
+ {"role": "user", "content": user_prompt},
493
  ],
494
+ max_tokens=settings.llm_max_tokens,
495
+ temperature=settings.llm_temperature,
496
  )
497
  patch_text = response.choices[0].message.content or ""
498
  usage = {
499
+ "prompt_tokens": response.usage.prompt_tokens,
500
  "completion_tokens": response.usage.completion_tokens,
501
+ "total_tokens": response.usage.total_tokens,
502
  }
503
  return patch_text, usage
504
 
api/tasks.py CHANGED
@@ -165,9 +165,10 @@ async def run_agent_task_async(
165
  traj_path = Path(f"results/trajectories/{task_id}.jsonl")
166
  traj_logger = TrajectoryLogger(traj_path)
167
 
 
168
  from agent.reflection_agent import ReflectionAgent
169
  agent = ReflectionAgent(
170
- model="gpt-4o",
171
  max_attempts=max_attempts,
172
  sandbox=sandbox,
173
  trajectory_logger=traj_logger,
 
165
  traj_path = Path(f"results/trajectories/{task_id}.jsonl")
166
  traj_logger = TrajectoryLogger(traj_path)
167
 
168
+ from configs.settings import settings
169
  from agent.reflection_agent import ReflectionAgent
170
  agent = ReflectionAgent(
171
+ model=settings.llm_model, # reads LLM_MODEL from env (e.g. deepseek-r1-distill-llama-70b)
172
  max_attempts=max_attempts,
173
  sandbox=sandbox,
174
  trajectory_logger=traj_logger,
configs/settings.py CHANGED
@@ -18,6 +18,8 @@ class Settings(BaseSettings):
18
 
19
  # ── LLM ─────────────────────────────────────────────────────────────────
20
  openai_api_key: str = Field(default="", alias="OPENAI_API_KEY")
 
 
21
  llm_model: str = Field(default="gpt-4o", alias="LLM_MODEL")
22
  llm_max_tokens: int = Field(default=4096, alias="LLM_MAX_TOKENS")
23
  llm_temperature: float = Field(default=0.2, alias="LLM_TEMPERATURE")
 
18
 
19
  # ── LLM ─────────────────────────────────────────────────────────────────
20
  openai_api_key: str = Field(default="", alias="OPENAI_API_KEY")
21
+ groq_api_key: str = Field(default="", alias="GROQ_API_KEY")
22
+ llm_provider: str = Field(default="openai", alias="LLM_PROVIDER") # openai | groq | gemini | ollama
23
  llm_model: str = Field(default="gpt-4o", alias="LLM_MODEL")
24
  llm_max_tokens: int = Field(default=4096, alias="LLM_MAX_TOKENS")
25
  llm_temperature: float = Field(default=0.2, alias="LLM_TEMPERATURE")