feat(api): POST /agent/run endpoint (orchestrator + RAG, stub-injectable)
Browse files- src/api/main.py +2 -0
- src/api/routes.py +63 -0
- src/api/schemas.py +24 -0
- tests/agents/test_agent_route.py +54 -0
src/api/main.py
CHANGED
|
@@ -11,6 +11,7 @@ from src.api.routes import (
|
|
| 11 |
predict_router,
|
| 12 |
explain_router,
|
| 13 |
experiments_router,
|
|
|
|
| 14 |
)
|
| 15 |
from src.api.schemas import HealthResponse
|
| 16 |
|
|
@@ -24,6 +25,7 @@ app.include_router(pipeline_router)
|
|
| 24 |
app.include_router(predict_router)
|
| 25 |
app.include_router(explain_router)
|
| 26 |
app.include_router(experiments_router)
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
@app.get("/health", response_model=HealthResponse)
|
|
|
|
| 11 |
predict_router,
|
| 12 |
explain_router,
|
| 13 |
experiments_router,
|
| 14 |
+
agent_router,
|
| 15 |
)
|
| 16 |
from src.api.schemas import HealthResponse
|
| 17 |
|
|
|
|
| 25 |
app.include_router(predict_router)
|
| 26 |
app.include_router(explain_router)
|
| 27 |
app.include_router(experiments_router)
|
| 28 |
+
app.include_router(agent_router)
|
| 29 |
|
| 30 |
|
| 31 |
@app.get("/health", response_model=HealthResponse)
|
src/api/routes.py
CHANGED
|
@@ -18,6 +18,9 @@ import pandas as pd
|
|
| 18 |
from fastapi import APIRouter, HTTPException
|
| 19 |
|
| 20 |
from src.api.schemas import (
|
|
|
|
|
|
|
|
|
|
| 21 |
BBBExplainRequest,
|
| 22 |
BBBExplainResponse,
|
| 23 |
BBBPredictRequest,
|
|
@@ -500,3 +503,63 @@ def diff_runs(req: RunDiffRequest) -> RunDiffResponse:
|
|
| 500 |
)
|
| 501 |
)
|
| 502 |
return RunDiffResponse(rows=rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from fastapi import APIRouter, HTTPException
|
| 19 |
|
| 20 |
from src.api.schemas import (
|
| 21 |
+
AgentRunRequest,
|
| 22 |
+
AgentRunResponse,
|
| 23 |
+
AgentToolTraceItem,
|
| 24 |
BBBExplainRequest,
|
| 25 |
BBBExplainResponse,
|
| 26 |
BBBPredictRequest,
|
|
|
|
| 503 |
)
|
| 504 |
)
|
| 505 |
return RunDiffResponse(rows=rows)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
# --- Agent router ----------------------------------------------------------
|
| 509 |
+
|
| 510 |
+
agent_router = APIRouter(prefix="/agent")
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
_DEFAULT_RAG_INDEX_DIR = Path("data/processed/faiss_index")
|
| 514 |
+
_AGENT_MODEL_ENV = "NEUROBRIDGE_AGENT_MODEL"
|
| 515 |
+
_AGENT_DEFAULT_MODEL = "google/gemini-2.0-flash-exp:free"
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def _build_orchestrator():
|
| 519 |
+
"""Construct the default orchestrator. Patchable in tests."""
|
| 520 |
+
from openai import OpenAI
|
| 521 |
+
|
| 522 |
+
from src.agents.orchestrator import Orchestrator
|
| 523 |
+
from src.agents.prompts import ORCHESTRATOR_SYSTEM_PROMPT
|
| 524 |
+
from src.agents.tools import build_default_tools
|
| 525 |
+
|
| 526 |
+
api_key = os.environ.get("OPENROUTER_API_KEY")
|
| 527 |
+
if not api_key:
|
| 528 |
+
raise HTTPException(
|
| 529 |
+
status_code=503,
|
| 530 |
+
detail="OPENROUTER_API_KEY not set; agent surface unavailable.",
|
| 531 |
+
)
|
| 532 |
+
client = OpenAI(
|
| 533 |
+
base_url="https://openrouter.ai/api/v1",
|
| 534 |
+
api_key=api_key,
|
| 535 |
+
timeout=30.0,
|
| 536 |
+
)
|
| 537 |
+
rag_dir = _DEFAULT_RAG_INDEX_DIR if _DEFAULT_RAG_INDEX_DIR.exists() else None
|
| 538 |
+
tools = build_default_tools(rag_index_dir=rag_dir)
|
| 539 |
+
model = os.environ.get(_AGENT_MODEL_ENV, _AGENT_DEFAULT_MODEL)
|
| 540 |
+
return Orchestrator(
|
| 541 |
+
llm_client=client,
|
| 542 |
+
tools=tools,
|
| 543 |
+
system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
|
| 544 |
+
model=model,
|
| 545 |
+
max_steps=5,
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
@agent_router.post("/run", response_model=AgentRunResponse)
|
| 550 |
+
def run_agent(req: AgentRunRequest) -> AgentRunResponse:
|
| 551 |
+
"""Run the orchestrator on `user_input`. Picks a pipeline + grounds via RAG."""
|
| 552 |
+
orch = _build_orchestrator()
|
| 553 |
+
user_text = req.user_input
|
| 554 |
+
if req.user_question:
|
| 555 |
+
user_text = f"{req.user_input}\n\nUser question: {req.user_question}"
|
| 556 |
+
result = orch.run(user_text)
|
| 557 |
+
return AgentRunResponse(
|
| 558 |
+
text=result.text,
|
| 559 |
+
trace=[
|
| 560 |
+
AgentToolTraceItem(name=t.name, args=t.args, result=t.result, error=t.error)
|
| 561 |
+
for t in result.trace
|
| 562 |
+
],
|
| 563 |
+
model=result.model,
|
| 564 |
+
finish_reason=result.finish_reason,
|
| 565 |
+
)
|
src/api/schemas.py
CHANGED
|
@@ -228,3 +228,27 @@ class RunDiffRow(BaseModel):
|
|
| 228 |
class RunDiffResponse(BaseModel):
|
| 229 |
"""Response for POST /experiments/diff: side-by-side metric/param diff."""
|
| 230 |
rows: list[RunDiffRow]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
class RunDiffResponse(BaseModel):
|
| 229 |
"""Response for POST /experiments/diff: side-by-side metric/param diff."""
|
| 230 |
rows: list[RunDiffRow]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# --- Agent surface (orchestrator + RAG) ------------------------------------
|
| 234 |
+
|
| 235 |
+
class AgentRunRequest(BaseModel):
|
| 236 |
+
"""User input to the orchestrator."""
|
| 237 |
+
user_input: str = Field(..., min_length=1, description="SMILES, file path, or directory path")
|
| 238 |
+
user_question: str | None = Field(
|
| 239 |
+
None, description="Optional natural-language question to language-match the response"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class AgentToolTraceItem(BaseModel):
|
| 244 |
+
name: str
|
| 245 |
+
args: dict = Field(default_factory=dict)
|
| 246 |
+
result: dict | None = None
|
| 247 |
+
error: str | None = None
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class AgentRunResponse(BaseModel):
|
| 251 |
+
text: str
|
| 252 |
+
trace: list[AgentToolTraceItem] = Field(default_factory=list)
|
| 253 |
+
model: str | None = None
|
| 254 |
+
finish_reason: str = "complete"
|
tests/agents/test_agent_route.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for POST /agent/run — uses a stub orchestrator factory."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Any
|
| 5 |
+
from unittest.mock import patch
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
from fastapi.testclient import TestClient
|
| 9 |
+
|
| 10 |
+
from src.agents.schemas import AgentResult, ToolTraceItem
|
| 11 |
+
from src.api.main import app
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
client = TestClient(app)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _FakeOrchestrator:
|
| 18 |
+
"""Returns a canned AgentResult; ignores input."""
|
| 19 |
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
def run(self, user_input: str) -> AgentResult:
|
| 23 |
+
return AgentResult(
|
| 24 |
+
text=f"Synthesized answer for: {user_input}",
|
| 25 |
+
trace=[
|
| 26 |
+
ToolTraceItem(name="run_bbb_pipeline", args={"smiles": user_input},
|
| 27 |
+
result={"label": 1, "label_text": "permeable"}),
|
| 28 |
+
ToolTraceItem(name="retrieve_context", args={"query": "BBB"},
|
| 29 |
+
result={"chunks": []}),
|
| 30 |
+
],
|
| 31 |
+
model="stub-model",
|
| 32 |
+
finish_reason="complete",
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TestAgentRoute:
|
| 37 |
+
def test_post_returns_synthesized_text_and_trace(self) -> None:
|
| 38 |
+
with patch("src.api.routes._build_orchestrator", return_value=_FakeOrchestrator()):
|
| 39 |
+
r = client.post("/agent/run", json={"user_input": "CCO"})
|
| 40 |
+
assert r.status_code == 200
|
| 41 |
+
body = r.json()
|
| 42 |
+
assert "Synthesized answer for: CCO" in body["text"]
|
| 43 |
+
assert len(body["trace"]) == 2
|
| 44 |
+
assert body["trace"][0]["name"] == "run_bbb_pipeline"
|
| 45 |
+
assert body["model"] == "stub-model"
|
| 46 |
+
assert body["finish_reason"] == "complete"
|
| 47 |
+
|
| 48 |
+
def test_empty_user_input_422(self) -> None:
|
| 49 |
+
r = client.post("/agent/run", json={"user_input": ""})
|
| 50 |
+
assert r.status_code == 422
|
| 51 |
+
|
| 52 |
+
def test_missing_user_input_422(self) -> None:
|
| 53 |
+
r = client.post("/agent/run", json={})
|
| 54 |
+
assert r.status_code == 422
|