Spaces:
Running
Running
Commit Β·
e63f982
1
Parent(s): 6b8d880
fix: route LLM to Groq (deepseek-r1) instead of hardcoded gpt-4o/openai
Browse files- agent/reflection_agent.py +31 -9
- api/tasks.py +2 -1
- configs/settings.py +2 -0
agent/reflection_agent.py
CHANGED
|
@@ -455,28 +455,50 @@ def _call_llm(
|
|
| 455 |
client=None,
|
| 456 |
model: str = "gpt-4o",
|
| 457 |
) -> tuple[str, dict]:
|
| 458 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
if client is None:
|
| 460 |
try:
|
| 461 |
from openai import OpenAI
|
| 462 |
-
client = OpenAI()
|
| 463 |
except ImportError as e:
|
| 464 |
-
raise ImportError(
|
|
|
|
|
|
|
|
|
|
| 465 |
|
| 466 |
response = client.chat.completions.create(
|
| 467 |
-
model=
|
| 468 |
messages=[
|
| 469 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 470 |
-
{"role": "user",
|
| 471 |
],
|
| 472 |
-
max_tokens=
|
| 473 |
-
temperature=
|
| 474 |
)
|
| 475 |
patch_text = response.choices[0].message.content or ""
|
| 476 |
usage = {
|
| 477 |
-
"prompt_tokens":
|
| 478 |
"completion_tokens": response.usage.completion_tokens,
|
| 479 |
-
"total_tokens":
|
| 480 |
}
|
| 481 |
return patch_text, usage
|
| 482 |
|
|
|
|
| 455 |
client=None,
|
| 456 |
model: str = "gpt-4o",
|
| 457 |
) -> tuple[str, dict]:
|
| 458 |
+
"""
|
| 459 |
+
Call the configured LLM provider (Groq, OpenAI, etc.).
|
| 460 |
+
Auto-detects provider from settings when client is None.
|
| 461 |
+
Returns (patch_text, usage_dict).
|
| 462 |
+
"""
|
| 463 |
+
from configs.settings import settings
|
| 464 |
+
|
| 465 |
+
provider = settings.llm_provider.lower()
|
| 466 |
+
effective_model = model
|
| 467 |
+
|
| 468 |
+
# ββ Groq (free, recommended) βββββββββββββββββββββββββββββββββββββββββββ
|
| 469 |
+
if client is None and provider == "groq":
|
| 470 |
+
try:
|
| 471 |
+
from groq import Groq
|
| 472 |
+
client = Groq(api_key=settings.groq_api_key)
|
| 473 |
+
effective_model = settings.llm_model # use configured Groq model
|
| 474 |
+
except ImportError as e:
|
| 475 |
+
raise ImportError("Install groq: pip install groq") from e
|
| 476 |
+
|
| 477 |
+
# ββ OpenAI (fallback) βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 478 |
if client is None:
|
| 479 |
try:
|
| 480 |
from openai import OpenAI
|
| 481 |
+
client = OpenAI(api_key=settings.openai_api_key or None)
|
| 482 |
except ImportError as e:
|
| 483 |
+
raise ImportError(
|
| 484 |
+
"No LLM client available. Set LLM_PROVIDER=groq and GROQ_API_KEY, "
|
| 485 |
+
"or install openai: pip install openai"
|
| 486 |
+
) from e
|
| 487 |
|
| 488 |
response = client.chat.completions.create(
|
| 489 |
+
model=effective_model,
|
| 490 |
messages=[
|
| 491 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 492 |
+
{"role": "user", "content": user_prompt},
|
| 493 |
],
|
| 494 |
+
max_tokens=settings.llm_max_tokens,
|
| 495 |
+
temperature=settings.llm_temperature,
|
| 496 |
)
|
| 497 |
patch_text = response.choices[0].message.content or ""
|
| 498 |
usage = {
|
| 499 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
| 500 |
"completion_tokens": response.usage.completion_tokens,
|
| 501 |
+
"total_tokens": response.usage.total_tokens,
|
| 502 |
}
|
| 503 |
return patch_text, usage
|
| 504 |
|
api/tasks.py
CHANGED
|
@@ -165,9 +165,10 @@ async def run_agent_task_async(
|
|
| 165 |
traj_path = Path(f"results/trajectories/{task_id}.jsonl")
|
| 166 |
traj_logger = TrajectoryLogger(traj_path)
|
| 167 |
|
|
|
|
| 168 |
from agent.reflection_agent import ReflectionAgent
|
| 169 |
agent = ReflectionAgent(
|
| 170 |
-
model=
|
| 171 |
max_attempts=max_attempts,
|
| 172 |
sandbox=sandbox,
|
| 173 |
trajectory_logger=traj_logger,
|
|
|
|
| 165 |
traj_path = Path(f"results/trajectories/{task_id}.jsonl")
|
| 166 |
traj_logger = TrajectoryLogger(traj_path)
|
| 167 |
|
| 168 |
+
from configs.settings import settings
|
| 169 |
from agent.reflection_agent import ReflectionAgent
|
| 170 |
agent = ReflectionAgent(
|
| 171 |
+
model=settings.llm_model, # reads LLM_MODEL from env (e.g. deepseek-r1-distill-llama-70b)
|
| 172 |
max_attempts=max_attempts,
|
| 173 |
sandbox=sandbox,
|
| 174 |
trajectory_logger=traj_logger,
|
configs/settings.py
CHANGED
|
@@ -18,6 +18,8 @@ class Settings(BaseSettings):
|
|
| 18 |
|
| 19 |
# ββ LLM βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
openai_api_key: str = Field(default="", alias="OPENAI_API_KEY")
|
|
|
|
|
|
|
| 21 |
llm_model: str = Field(default="gpt-4o", alias="LLM_MODEL")
|
| 22 |
llm_max_tokens: int = Field(default=4096, alias="LLM_MAX_TOKENS")
|
| 23 |
llm_temperature: float = Field(default=0.2, alias="LLM_TEMPERATURE")
|
|
|
|
| 18 |
|
| 19 |
# ββ LLM βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
openai_api_key: str = Field(default="", alias="OPENAI_API_KEY")
|
| 21 |
+
groq_api_key: str = Field(default="", alias="GROQ_API_KEY")
|
| 22 |
+
llm_provider: str = Field(default="openai", alias="LLM_PROVIDER") # openai | groq | gemini | ollama
|
| 23 |
llm_model: str = Field(default="gpt-4o", alias="LLM_MODEL")
|
| 24 |
llm_max_tokens: int = Field(default=4096, alias="LLM_MAX_TOKENS")
|
| 25 |
llm_temperature: float = Field(default=0.2, alias="LLM_TEMPERATURE")
|