cesjavi commited on
Commit
9873daa
·
1 Parent(s): a362a22

Fix: Permanent model override for decommissioned Groq models in AgentRunnerService (Phase 9)

Browse files
backend/services/agent_runner_service.py CHANGED
@@ -63,11 +63,17 @@ class AgentRunnerService:
63
  metadata={"project_id": project_id, "run_id": run_id, "status": "running"},
64
  )
65
 
 
 
 
 
 
 
66
  agent = AgentFactory.get_agent(
67
  provider=agent_data["api_provider"],
68
  name=agent_data["name"],
69
  role=agent_data["role"],
70
- model=agent_data["model"],
71
  system_prompt=agent_data.get("system_prompt")
72
  )
73
 
@@ -214,18 +220,23 @@ class AgentRunnerService:
214
  if pattern in raw_out:
215
  logger.warning(f"SECURITY: Suspicious pattern '{pattern}' detected in agent output for task {task_id}.")
216
  result["security_warning"] = f"Output sanitized: suspicious pattern '{pattern}' detected."
217
- # We don't block yet, but we flag it.
218
 
219
  quality_review = validate_output(quality_task, result)
220
  result["quality_review"] = quality_review
221
  claims_count = await evidence_service.replace_task_claims(task, result)
222
- completion_tokens = budget_service.estimate_completion_tokens(result)
223
- estimated_cost = budget_service.estimate_cost(
 
 
 
 
 
224
  agent_data.get("api_provider"),
225
  agent_data.get("model"),
226
- prompt_tokens,
227
- completion_tokens,
228
  )
 
229
  budget_service.record_usage(
230
  project_id=project_id,
231
  task_id=task_id,
@@ -233,10 +244,10 @@ class AgentRunnerService:
233
  agent_id=agent_data.get("id"),
234
  provider=agent_data.get("api_provider"),
235
  model=agent_data.get("model"),
236
- prompt_tokens=prompt_tokens,
237
- completion_tokens=completion_tokens,
238
- estimated_cost=estimated_cost,
239
- metadata={"duration_seconds": round(duration, 2), "claims_count": claims_count},
240
  )
241
 
242
  # 6. Save to Cache
 
63
  metadata={"project_id": project_id, "run_id": run_id, "status": "running"},
64
  )
65
 
66
+ # Emergency Model Override for decommissioned Groq models
67
+ model_to_use = agent_data["model"]
68
+ if "llama3-70b-8192" in model_to_use:
69
+ model_to_use = "llama-3.3-70b-versatile"
70
+ logger.warning(f"Overriding decommissioned model {agent_data['model']} with {model_to_use}")
71
+
72
  agent = AgentFactory.get_agent(
73
  provider=agent_data["api_provider"],
74
  name=agent_data["name"],
75
  role=agent_data["role"],
76
+ model=model_to_use,
77
  system_prompt=agent_data.get("system_prompt")
78
  )
79
 
 
220
  if pattern in raw_out:
221
  logger.warning(f"SECURITY: Suspicious pattern '{pattern}' detected in agent output for task {task_id}.")
222
  result["security_warning"] = f"Output sanitized: suspicious pattern '{pattern}' detected."
 
223
 
224
  quality_review = validate_output(quality_task, result)
225
  result["quality_review"] = quality_review
226
  claims_count = await evidence_service.replace_task_claims(task, result)
227
+
228
+ # Use actual usage if provided by agent, otherwise fallback to estimation
229
+ usage = result.get("usage") or {}
230
+ actual_prompt_tokens = usage.get("prompt_tokens") or prompt_tokens
231
+ actual_completion_tokens = usage.get("completion_tokens") or budget_service.estimate_completion_tokens(result)
232
+
233
+ actual_cost = budget_service.estimate_cost(
234
  agent_data.get("api_provider"),
235
  agent_data.get("model"),
236
+ actual_prompt_tokens,
237
+ actual_completion_tokens,
238
  )
239
+
240
  budget_service.record_usage(
241
  project_id=project_id,
242
  task_id=task_id,
 
244
  agent_id=agent_data.get("id"),
245
  provider=agent_data.get("api_provider"),
246
  model=agent_data.get("model"),
247
+ prompt_tokens=actual_prompt_tokens,
248
+ completion_tokens=actual_completion_tokens,
249
+ estimated_cost=actual_cost,
250
+ metadata={"duration_seconds": round(duration, 2), "claims_count": claims_count, "usage_source": "api" if result.get("usage") else "estimation"},
251
  )
252
 
253
  # 6. Save to Cache