mekosotto Claude Sonnet 4.6 commited on
Commit
55d9d32
·
1 Parent(s): 6d2aa47

feat(api): POST /agent/run endpoint (orchestrator + RAG, stub-injectable)

Browse files
src/api/main.py CHANGED
@@ -11,6 +11,7 @@ from src.api.routes import (
11
  predict_router,
12
  explain_router,
13
  experiments_router,
 
14
  )
15
  from src.api.schemas import HealthResponse
16
 
@@ -24,6 +25,7 @@ app.include_router(pipeline_router)
24
  app.include_router(predict_router)
25
  app.include_router(explain_router)
26
  app.include_router(experiments_router)
 
27
 
28
 
29
  @app.get("/health", response_model=HealthResponse)
 
11
  predict_router,
12
  explain_router,
13
  experiments_router,
14
+ agent_router,
15
  )
16
  from src.api.schemas import HealthResponse
17
 
 
25
  app.include_router(predict_router)
26
  app.include_router(explain_router)
27
  app.include_router(experiments_router)
28
+ app.include_router(agent_router)
29
 
30
 
31
  @app.get("/health", response_model=HealthResponse)
src/api/routes.py CHANGED
@@ -18,6 +18,9 @@ import pandas as pd
18
  from fastapi import APIRouter, HTTPException
19
 
20
  from src.api.schemas import (
 
 
 
21
  BBBExplainRequest,
22
  BBBExplainResponse,
23
  BBBPredictRequest,
@@ -500,3 +503,63 @@ def diff_runs(req: RunDiffRequest) -> RunDiffResponse:
500
  )
501
  )
502
  return RunDiffResponse(rows=rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  from fastapi import APIRouter, HTTPException
19
 
20
  from src.api.schemas import (
21
+ AgentRunRequest,
22
+ AgentRunResponse,
23
+ AgentToolTraceItem,
24
  BBBExplainRequest,
25
  BBBExplainResponse,
26
  BBBPredictRequest,
 
503
  )
504
  )
505
  return RunDiffResponse(rows=rows)
506
+
507
+
508
+ # --- Agent router ----------------------------------------------------------
509
+
510
+ agent_router = APIRouter(prefix="/agent")
511
+
512
+
513
+ _DEFAULT_RAG_INDEX_DIR = Path("data/processed/faiss_index")
514
+ _AGENT_MODEL_ENV = "NEUROBRIDGE_AGENT_MODEL"
515
+ _AGENT_DEFAULT_MODEL = "google/gemini-2.0-flash-exp:free"
516
+
517
+
518
+ def _build_orchestrator():
519
+ """Construct the default orchestrator. Patchable in tests."""
520
+ from openai import OpenAI
521
+
522
+ from src.agents.orchestrator import Orchestrator
523
+ from src.agents.prompts import ORCHESTRATOR_SYSTEM_PROMPT
524
+ from src.agents.tools import build_default_tools
525
+
526
+ api_key = os.environ.get("OPENROUTER_API_KEY")
527
+ if not api_key:
528
+ raise HTTPException(
529
+ status_code=503,
530
+ detail="OPENROUTER_API_KEY not set; agent surface unavailable.",
531
+ )
532
+ client = OpenAI(
533
+ base_url="https://openrouter.ai/api/v1",
534
+ api_key=api_key,
535
+ timeout=30.0,
536
+ )
537
+ rag_dir = _DEFAULT_RAG_INDEX_DIR if _DEFAULT_RAG_INDEX_DIR.exists() else None
538
+ tools = build_default_tools(rag_index_dir=rag_dir)
539
+ model = os.environ.get(_AGENT_MODEL_ENV, _AGENT_DEFAULT_MODEL)
540
+ return Orchestrator(
541
+ llm_client=client,
542
+ tools=tools,
543
+ system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
544
+ model=model,
545
+ max_steps=5,
546
+ )
547
+
548
+
549
+ @agent_router.post("/run", response_model=AgentRunResponse)
550
+ def run_agent(req: AgentRunRequest) -> AgentRunResponse:
551
+ """Run the orchestrator on `user_input`. Picks a pipeline + grounds via RAG."""
552
+ orch = _build_orchestrator()
553
+ user_text = req.user_input
554
+ if req.user_question:
555
+ user_text = f"{req.user_input}\n\nUser question: {req.user_question}"
556
+ result = orch.run(user_text)
557
+ return AgentRunResponse(
558
+ text=result.text,
559
+ trace=[
560
+ AgentToolTraceItem(name=t.name, args=t.args, result=t.result, error=t.error)
561
+ for t in result.trace
562
+ ],
563
+ model=result.model,
564
+ finish_reason=result.finish_reason,
565
+ )
src/api/schemas.py CHANGED
@@ -228,3 +228,27 @@ class RunDiffRow(BaseModel):
228
  class RunDiffResponse(BaseModel):
229
  """Response for POST /experiments/diff: side-by-side metric/param diff."""
230
  rows: list[RunDiffRow]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  class RunDiffResponse(BaseModel):
229
  """Response for POST /experiments/diff: side-by-side metric/param diff."""
230
  rows: list[RunDiffRow]
231
+
232
+
233
+ # --- Agent surface (orchestrator + RAG) ------------------------------------
234
+
235
+ class AgentRunRequest(BaseModel):
236
+ """User input to the orchestrator."""
237
+ user_input: str = Field(..., min_length=1, description="SMILES, file path, or directory path")
238
+ user_question: str | None = Field(
239
+ None, description="Optional natural-language question to language-match the response"
240
+ )
241
+
242
+
243
+ class AgentToolTraceItem(BaseModel):
244
+ name: str
245
+ args: dict = Field(default_factory=dict)
246
+ result: dict | None = None
247
+ error: str | None = None
248
+
249
+
250
+ class AgentRunResponse(BaseModel):
251
+ text: str
252
+ trace: list[AgentToolTraceItem] = Field(default_factory=list)
253
+ model: str | None = None
254
+ finish_reason: str = "complete"
tests/agents/test_agent_route.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for POST /agent/run — uses a stub orchestrator factory."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+ from unittest.mock import patch
6
+
7
+ import pytest
8
+ from fastapi.testclient import TestClient
9
+
10
+ from src.agents.schemas import AgentResult, ToolTraceItem
11
+ from src.api.main import app
12
+
13
+
14
+ client = TestClient(app)
15
+
16
+
17
+ class _FakeOrchestrator:
18
+ """Returns a canned AgentResult; ignores input."""
19
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
20
+ pass
21
+
22
+ def run(self, user_input: str) -> AgentResult:
23
+ return AgentResult(
24
+ text=f"Synthesized answer for: {user_input}",
25
+ trace=[
26
+ ToolTraceItem(name="run_bbb_pipeline", args={"smiles": user_input},
27
+ result={"label": 1, "label_text": "permeable"}),
28
+ ToolTraceItem(name="retrieve_context", args={"query": "BBB"},
29
+ result={"chunks": []}),
30
+ ],
31
+ model="stub-model",
32
+ finish_reason="complete",
33
+ )
34
+
35
+
36
+ class TestAgentRoute:
37
+ def test_post_returns_synthesized_text_and_trace(self) -> None:
38
+ with patch("src.api.routes._build_orchestrator", return_value=_FakeOrchestrator()):
39
+ r = client.post("/agent/run", json={"user_input": "CCO"})
40
+ assert r.status_code == 200
41
+ body = r.json()
42
+ assert "Synthesized answer for: CCO" in body["text"]
43
+ assert len(body["trace"]) == 2
44
+ assert body["trace"][0]["name"] == "run_bbb_pipeline"
45
+ assert body["model"] == "stub-model"
46
+ assert body["finish_reason"] == "complete"
47
+
48
+ def test_empty_user_input_422(self) -> None:
49
+ r = client.post("/agent/run", json={"user_input": ""})
50
+ assert r.status_code == 422
51
+
52
+ def test_missing_user_input_422(self) -> None:
53
+ r = client.post("/agent/run", json={})
54
+ assert r.status_code == 422