File size: 2,594 Bytes
"""Deterministic fallbacks for the orchestrator workflow.
The LLM remains responsible for normal function-calling, but these helpers
keep the public agent route reliable when a model skips or mis-shapes a tool
call during a live demo.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from src.agents.schemas import ToolTraceItem
# File suffixes (compared lower-cased) that route_pipeline_input sends to
# the "run_eeg_pipeline" tool.
_EEG_SUFFIXES = {".fif", ".edf"}
def route_pipeline_input(
    user_input: str,
    context: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any]] | None:
    """Pick the single pipeline tool (and its arguments) for raw user input.

    Args:
        user_input: Raw text from the user; it is reduced with
            ``_primary_input`` before any routing decision is made.
        context: Optional mapping forwarded to ``_sites_csv_for`` when the
            input is routed to the MRI pipeline.

    Returns:
        A ``(tool_name, arguments)`` pair, or ``None`` when the reduced input
        is empty or looks like a path that matches neither the EEG nor the
        MRI rules.
    """
    raw = _primary_input(user_input)
    if not raw:
        return None

    candidate = Path(raw)
    lowered = raw.lower()
    suffix = candidate.suffix.lower()

    # EEG recordings are recognised purely by file suffix.
    if suffix in _EEG_SUFFIXES:
        return "run_eeg_pipeline", {"input_path": raw}

    if _looks_like_mri_input(candidate, lowered):
        # A NIfTI file stands in for the directory that contains it; any
        # other MRI-looking input is used as the directory itself.
        points_at_volume = suffix == ".nii" or lowered.endswith(".nii.gz")
        scan_dir = candidate.parent if points_at_volume else candidate
        mri_args: dict[str, Any] = {
            "input_dir": str(scan_dir),
            "sites_csv": _sites_csv_for(scan_dir, context),
        }
        return "run_mri_pipeline", mri_args

    # A path-like string with an unrecognised suffix has no matching tool.
    # NOTE(review): any text containing "/" or "\" is treated as a path and so
    # never reaches the BBB branch below -- confirm this is acceptable for
    # SMILES strings that use those characters.
    if _looks_like_path(raw):
        return None

    return "run_bbb_pipeline", {"smiles": raw, "top_k": 5}
def build_retrieval_query(
    user_input: str,
    pipeline_trace: ToolTraceItem,
    context: dict[str, Any] | None = None,
) -> str:
    """Return the fixed scientific RAG query for the pipeline tool that ran.

    Args:
        user_input: Accepted for interface compatibility; not read here.
        pipeline_trace: Trace of the completed pipeline call; only its
            ``name`` attribute is inspected.
        context: Accepted for interface compatibility; not read here.

    Returns:
        The EEG or MRI query when ``pipeline_trace.name`` matches that tool,
        otherwise the BBB permeability query.
    """
    queries_by_tool = {
        "run_eeg_pipeline": "ICA artifact removal in multi-channel EEG",
        "run_mri_pipeline": "ComBat scanner site harmonization in multi-center MRI",
    }
    # Any tool name without an entry falls back to the BBB query.
    fallback_query = "BBB permeability of small lipophilic molecules"
    return queries_by_tool.get(pipeline_trace.name, fallback_query)
def _primary_input(user_input: str) -> str:
"""Return the first non-empty input line, excluding appended user questions."""
before_question = user_input.split("\n\nUser question:", 1)[0]
return before_question.strip().strip("\"'")
def _looks_like_mri_input(path: Path, lower: str) -> bool:
if lower.endswith(".nii.gz") or path.suffix.lower() == ".nii":
return True
if path.exists() and path.is_dir():
return True
return not path.suffix and _looks_like_path(str(path))
def _looks_like_path(text: str) -> bool:
return "/" in text or "\\" in text
def _sites_csv_for(input_dir: Path, context: dict[str, Any] | None) -> str:
explicit = (context or {}).get("sites_csv")
if explicit:
return str(explicit)
return str(input_dir / "sites.csv")
|