# NOTE(review): the original capture began with hosting-UI page furniture
# ("Spaces: / Running / Running") — not part of this module; preserved here
# as a comment so the file stays valid Python.
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
import uvicorn
import os

from server.models import TriageAction
from server.environment import LogTriageEnvironment

# FastAPI application exposing the LogTriage OpenEnv over HTTP.
app = FastAPI(
    title="LogTriageEnv",
    description="OpenEnv environment for SRE incident triage",
    version="1.0.0",
)

# One environment instance per server process — all routes below share it,
# so concurrent episodes are not supported by a single process.
env = LogTriageEnvironment()
# NOTE(review): the route decorators appear to have been stripped from this
# file during extraction — the handlers are otherwise never registered, and
# `reset`/`step` use `Query(...)` defaults that only make sense on FastAPI
# routes. All paths below are reconstructed; confirm against the source repo.
@app.get("/health")
def health():
    """Liveness probe: report server identity and version as a plain dict."""
    return {"status": "ok", "environment": "logtriage-env", "version": "1.0.0"}
@app.post("/reset")  # NOTE(review): decorator reconstructed — confirm path/method
def reset(
    task: str = Query(default="single_crash", description="Task ID to run"),
    # `int | None` (3.10+): the original annotated `int` with default=None,
    # an annotation/default mismatch that stricter FastAPI/pydantic rejects.
    seed: int | None = Query(default=None, description="Random seed for reproducibility"),
):
    """Start a new episode of *task*.

    Returns the initial observation serialized via ``model_dump()``; an
    unknown task id (``env.reset`` raises ValueError) yields a 400 JSON error.
    """
    try:
        obs = env.reset(task_id=task, seed=seed)
        return obs.model_dump()
    except ValueError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
@app.post("/step")  # NOTE(review): decorator reconstructed — confirm path
def step(action: TriageAction):
    """Apply one triage action to the current episode.

    Structurally invalid actions (per ``TriageAction.is_valid``) yield 422;
    stepping outside an active episode (``env.step`` raises RuntimeError)
    yields 400. On success, returns the next observation via ``model_dump()``.
    """
    valid, err = action.is_valid()
    if not valid:
        return JSONResponse(status_code=422, content={"error": err})
    try:
        obs = env.step(action)
        return obs.model_dump()
    except RuntimeError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
@app.get("/state")  # NOTE(review): decorator reconstructed — confirm path
def state():
    """Return the current episode state, or 400 when no episode is active
    (``env.state`` raises RuntimeError in that case)."""
    try:
        return env.state.model_dump()
    except RuntimeError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
@app.get("/tasks")  # NOTE(review): decorator reconstructed — confirm path
def get_tasks():
    """List the available tasks: id, name, difficulty, step budget, and the
    shared action schema accepted by ``/step``."""
    # All three tasks accept the identical action schema — build it once
    # instead of repeating the literal three times. Each entry gets its own
    # shallow copy so the response mirrors the original's distinct dicts.
    action_schema = {
        "action_type": "classify_severity | identify_root_cause | escalate | remediate | request_more_logs | resolve | ignore",
        "value": "string (depends on action_type — see README)",
        "confidence": "float [0.0, 1.0]",
        "reasoning": "string (optional)",
    }
    return {
        "tasks": [
            {
                "id": "single_crash",
                "name": "Single Service Crash",
                "difficulty": "easy",
                "max_steps": 8,
                "description": "One service crashes. Classify severity, find root cause, remediate.",
                "action_schema": dict(action_schema),
            },
            {
                "id": "cascading_failure",
                "name": "Cascading Failure",
                "difficulty": "medium",
                "max_steps": 12,
                "description": "DB slowdown cascades upstream. Find the true root cause, not symptoms.",
                "action_schema": dict(action_schema),
            },
            {
                "id": "silent_degradation",
                "name": "Silent Degradation with Noise",
                "difficulty": "hard",
                "max_steps": 15,
                "description": "Slow degradation hidden in 60% noise. Nuanced P2 severity judgment.",
                "action_schema": dict(action_schema),
            },
        ]
    }
@app.get("/grader")  # NOTE(review): decorator reconstructed — confirm path/method
def grader():
    """Score the current episode with the task's grader.

    Returns the grader result; 400 when no episode is active (RuntimeError
    from ``env.state``) or the grader rejects the task id (ValueError).
    """
    try:
        # Imported here rather than at module top — presumably to defer the
        # graders dependency until first use; confirm before hoisting.
        from server.graders import score_episode

        # Renamed from `state` to avoid shadowing the sibling /state handler.
        episode_state = env.state
        return score_episode(episode_state.task_id, episode_state)
    except (RuntimeError, ValueError) as e:
        # The two original handlers were identical — merged into one clause.
        return JSONResponse(status_code=400, content={"error": str(e)})
@app.post("/baseline")  # NOTE(review): decorator reconstructed — confirm path/method
def baseline():
    """
    Run the baseline inference script against all 3 tasks.

    Returns scores for each task produced by the LLM agent.
    Note: Requires HF_TOKEN (or GROQ_API_KEY) to be set.
    """
    import subprocess
    import sys
    import json as json_lib

    try:
        result = subprocess.run(
            [sys.executable, "inference.py"],  # list argv, shell=False — no injection risk
            capture_output=True,
            text=True,
            timeout=1200,  # 20 minute timeout (matches spec)
            # Run from the repo root: two levels up from this file's directory.
            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        )
        if result.returncode != 0:
            return JSONResponse(
                status_code=500,
                content={
                    "error": "Inference script failed",
                    # Only the stderr tail — keep the error payload small.
                    "stderr": result.stderr[-500:] if result.stderr else "",
                },
            )
        # The script prints human-readable logs, then a literal "JSON Output:"
        # marker line; everything after the marker is the JSON payload.
        output_lines = result.stdout.strip().split("\n")
        json_start = None
        for i, line in enumerate(output_lines):
            if line.strip() == "JSON Output:":
                json_start = i + 1
                break
        # Explicit `is not None`: clearer than truthiness for an optional index
        # (the original `if json_start` relied on json_start never being 0).
        if json_start is not None and json_start < len(output_lines):
            json_str = "\n".join(output_lines[json_start:])
            return json_lib.loads(json_str)
        return {"message": "Baseline completed", "output": result.stdout[-1000:]}
    except subprocess.TimeoutExpired:
        return JSONResponse(status_code=504, content={"error": "Inference timed out (20min limit)"})
    except Exception as e:
        # Broad catch is deliberate at this route boundary: any unexpected
        # failure (bad JSON, OS error) becomes a 500 JSON error, not a crash.
        return JSONResponse(status_code=500, content={"error": str(e)})
def main():
    """Entry point: serve the app on all interfaces.

    Port 7860 is the Hugging Face Spaces default; reload is off because one
    long-lived environment instance must survive for the process lifetime.
    """
    uvicorn.run("server.app:app", host="0.0.0.0", port=7860, reload=False)


if __name__ == "__main__":
    main()