OGrohit's picture
Day 6: inference.py (renamed from baseline.py), HF_TOKEN/API_BASE_URL/MODEL_NAME env vars, pyproject.toml for openenv validate
4c76730
import os
from typing import Optional

import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse

from server.environment import LogTriageEnvironment
from server.models import TriageAction
# Application object that every route decorator in this module attaches to.
app = FastAPI(
    title="LogTriageEnv",
    description="OpenEnv environment for SRE incident triage",
    version="1.0.0",
)

# One environment instance per server process; all request handlers in this
# module read and mutate this same object.
env = LogTriageEnvironment()
@app.get("/health")
def health():
    """Report a fixed status payload identifying this environment and its version."""
    payload = {
        "status": "ok",
        "environment": "logtriage-env",
        "version": "1.0.0",
    }
    return payload
@app.post("/reset")
def reset(
    task: str = Query(default="single_crash", description="Task ID to run"),
    seed: Optional[int] = Query(default=None, description="Random seed for reproducibility"),
):
    """Reset the shared environment and return its first observation.

    Args:
        task: Task ID forwarded to the environment as ``task_id``.
        seed: Optional random seed; ``None`` (the default) is passed through
            to the environment unchanged.

    Returns:
        The observation produced by ``env.reset`` as a plain dict, or a
        400 JSON response carrying the message of a ``ValueError`` raised
        while resetting.
    """
    # The annotation is Optional[int] because the default is None; a bare
    # ``int`` annotation would contradict the values this parameter accepts.
    try:
        obs = env.reset(task_id=task, seed=seed)
        return obs.model_dump()
    except ValueError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})
@app.post("/step")
def step(action: TriageAction):
    """Validate the submitted action and apply it to the shared environment.

    Returns the resulting observation as a plain dict. An action rejected by
    ``action.is_valid()`` yields a 422 response with the validation message;
    a ``RuntimeError`` raised while stepping yields a 400 response.
    """
    is_ok, problem = action.is_valid()
    if not is_ok:
        return JSONResponse(status_code=422, content={"error": problem})

    try:
        observation = env.step(action)
        return observation.model_dump()
    except RuntimeError as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})
@app.get("/state")
def state():
    """Return the environment's current state as a plain dict.

    A ``RuntimeError`` raised while reading or dumping the state is reported
    as a 400 response carrying the error message.
    """
    try:
        snapshot = env.state
        return snapshot.model_dump()
    except RuntimeError as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})
@app.get("/tasks")
def get_tasks():
    """List the available tasks together with the action schema each one accepts."""
    # Every task advertises the same action payload, so describe it once and
    # hand each task entry its own copy.
    action_schema = {
        "action_type": "classify_severity | identify_root_cause | escalate | remediate | request_more_logs | resolve | ignore",
        "value": "string (depends on action_type — see README)",
        "confidence": "float [0.0, 1.0]",
        "reasoning": "string (optional)",
    }

    # (id, name, difficulty, max_steps, description) for each task.
    catalog = (
        (
            "single_crash",
            "Single Service Crash",
            "easy",
            8,
            "One service crashes. Classify severity, find root cause, remediate.",
        ),
        (
            "cascading_failure",
            "Cascading Failure",
            "medium",
            12,
            "DB slowdown cascades upstream. Find the true root cause, not symptoms.",
        ),
        (
            "silent_degradation",
            "Silent Degradation with Noise",
            "hard",
            15,
            "Slow degradation hidden in 60% noise. Nuanced P2 severity judgment.",
        ),
    )

    return {
        "tasks": [
            {
                "id": task_id,
                "name": title,
                "difficulty": level,
                "max_steps": step_limit,
                "description": summary,
                "action_schema": dict(action_schema),
            }
            for task_id, title, level, step_limit, summary in catalog
        ]
    }
@app.post("/grader")
def grader():
    """Score the environment's current state with the grader for its task.

    Returns whatever ``score_episode`` produces. A ``RuntimeError`` or
    ``ValueError`` raised along the way is reported as a 400 response
    carrying the error message.
    """
    try:
        from server.graders import score_episode

        episode = env.state
        return score_episode(episode.task_id, episode)
    except (RuntimeError, ValueError) as exc:
        # Both error types map to the same client-error response.
        return JSONResponse(status_code=400, content={"error": str(exc)})
@app.post("/baseline")
def baseline():
    """
    Run the baseline inference script (inference.py) against all 3 tasks.

    Returns the scores produced by the LLM agent, parsed from the JSON the
    script prints after its "JSON Output:" marker line. If no such marker is
    found, the tail of the script's stdout is returned instead.
    Note: Requires HF_TOKEN (or GROQ_API_KEY) to be set.
    """
    import subprocess
    import sys
    import json as json_lib

    try:
        # Run from the directory one level above this module's directory.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        completed = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            timeout=1200,  # 20 minute timeout (matches spec)
            cwd=project_root,
        )

        if completed.returncode != 0:
            # Only the last 500 characters of stderr are reported.
            stderr_tail = completed.stderr[-500:] if completed.stderr else ""
            return JSONResponse(
                status_code=500,
                content={
                    "error": "Inference script failed",
                    "stderr": stderr_tail,
                },
            )

        # Locate the first marker line; the JSON document is everything after it.
        stdout_lines = completed.stdout.strip().split("\n")
        marker_at = next(
            (
                idx
                for idx, text in enumerate(stdout_lines)
                if text.strip() == "JSON Output:"
            ),
            None,
        )
        if marker_at is None or marker_at + 1 >= len(stdout_lines):
            return {"message": "Baseline completed", "output": completed.stdout[-1000:]}

        return json_lib.loads("\n".join(stdout_lines[marker_at + 1:]))
    except subprocess.TimeoutExpired:
        return JSONResponse(status_code=504, content={"error": "Inference timed out (20min limit)"})
    except Exception as exc:
        # Boundary handler: any other failure (including malformed JSON) becomes a 500.
        return JSONResponse(status_code=500, content={"error": str(exc)})
def main():
    """Serve ``server.app:app`` with uvicorn on 0.0.0.0:7860, auto-reload disabled."""
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "reload": False,
    }
    uvicorn.run("server.app:app", **server_options)


if __name__ == "__main__":
    main()