File size: 5,733 Bytes
5884d9c
e270f30
 
3358379
e270f30
5884d9c
 
e270f30
 
 
 
 
 
 
5884d9c
 
 
e270f30
 
 
 
 
 
 
5884d9c
 
 
 
 
 
 
 
 
e270f30
 
 
 
 
 
 
5884d9c
 
 
 
 
e270f30
 
 
 
5884d9c
 
 
 
e270f30
 
 
 
 
 
 
 
 
 
 
 
 
 
5884d9c
e270f30
 
 
 
 
 
 
 
 
5884d9c
e270f30
 
5884d9c
e270f30
 
 
 
 
 
 
 
 
5884d9c
e270f30
 
5884d9c
e270f30
 
 
 
 
 
 
 
 
 
fbb0927
 
 
 
 
 
 
 
 
e270f30
 
 
 
3358379
 
 
4c76730
3358379
 
 
 
 
 
 
4c76730
3358379
 
4c76730
3358379
 
 
 
 
 
 
4c76730
3358379
 
 
 
 
 
 
 
4c76730
3358379
 
 
 
 
 
 
 
 
 
4c76730
3358379
 
e270f30
 
4c76730
 
 
 
e270f30
4c76730
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
from typing import Optional

from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
import uvicorn

from server.models import TriageAction
from server.environment import LogTriageEnvironment

# FastAPI application object; every route below is registered on it.
app = FastAPI(
    title="LogTriageEnv",
    version="1.0.0",
    description="OpenEnv environment for SRE incident triage",
)

# Single environment instance shared by every request this process handles:
# /reset, /step, /state and /grader all read or mutate this same object.
# NOTE(review): no locking is visible here — confirm that concurrent
# requests interleaving on one episode is acceptable.
env = LogTriageEnvironment()


@app.get("/health")
def health():
    """Liveness probe reporting server status, environment name and version."""
    payload = {
        "status": "ok",
        "environment": "logtriage-env",
        "version": "1.0.0",
    }
    return payload


@app.post("/reset")
def reset(
    task: str = Query(default="single_crash", description="Task ID to run"),
    seed: Optional[int] = Query(
        default=None, description="Random seed for reproducibility"
    ),
):
    """Start a new episode of the given task on the shared environment.

    Args:
        task: ID of the task to load (the IDs listed by ``GET /tasks``).
        seed: Optional random seed, passed straight through to
            ``env.reset``; ``None`` when the client omits it.

    Returns:
        The initial observation as a dict, or a 400 JSON error
        ``{"error": ...}`` when ``env.reset`` raises ``ValueError``
        (presumably for an unknown task ID — confirm in the environment).
    """
    try:
        obs = env.reset(task_id=task, seed=seed)
        return obs.model_dump()
    except ValueError as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


@app.post("/step")
def step(action: TriageAction):
    """Apply one agent action to the current episode.

    The action's own ``is_valid()`` check (which yields an ``(ok, message)``
    pair) runs first; a rejected action is answered with 422. A
    ``RuntimeError`` from the environment becomes a 400.

    Returns:
        The resulting observation as a dict, or a JSON error response.
    """
    is_ok, problem = action.is_valid()
    if not is_ok:
        return JSONResponse(status_code=422, content={"error": problem})

    try:
        return env.step(action).model_dump()
    except RuntimeError as exc:
        return JSONResponse(status_code=400, content={"error": str(exc)})


@app.get("/state")
def state():
    """Return a snapshot of the environment's current state as a dict.

    A ``RuntimeError`` raised while reading ``env.state`` (presumably when
    no episode has been started yet — confirm) is reported as a 400.
    """
    try:
        snapshot = env.state
        return snapshot.model_dump()
    except RuntimeError as exc:
        error_body = {"error": str(exc)}
        return JSONResponse(status_code=400, content=error_body)


# Action schema advertised for every task. Kept in a single place so the
# per-task listings cannot drift apart.
_ACTION_SCHEMA = {
    "action_type": "classify_severity | identify_root_cause | escalate | remediate | request_more_logs | resolve | ignore",
    "value": "string (depends on action_type — see README)",
    "confidence": "float [0.0, 1.0]",
    "reasoning": "string (optional)",
}

# (id, name, difficulty, max_steps, description) for each task, in listing order.
_TASK_SPECS = (
    (
        "single_crash",
        "Single Service Crash",
        "easy",
        8,
        "One service crashes. Classify severity, find root cause, remediate.",
    ),
    (
        "cascading_failure",
        "Cascading Failure",
        "medium",
        12,
        "DB slowdown cascades upstream. Find the true root cause, not symptoms.",
    ),
    (
        "silent_degradation",
        "Silent Degradation with Noise",
        "hard",
        15,
        "Slow degradation hidden in 60% noise. Nuanced P2 severity judgment.",
    ),
)


@app.get("/tasks")
def get_tasks():
    """List the available tasks with their metadata and action schema.

    Returns:
        ``{"tasks": [...]}``, one entry per task with keys ``id``, ``name``,
        ``difficulty``, ``max_steps``, ``description`` and ``action_schema``.
    """
    return {
        "tasks": [
            {
                "id": task_id,
                "name": name,
                "difficulty": difficulty,
                "max_steps": max_steps,
                "description": description,
                # Copy so a caller mutating one entry cannot alter the shared schema.
                "action_schema": dict(_ACTION_SCHEMA),
            }
            for task_id, name, difficulty, max_steps, description in _TASK_SPECS
        ]
    }


@app.post("/grader")
def grader():
    """Score the current episode using the grader for its task.

    Returns:
        Whatever ``server.graders.score_episode`` returns for the current
        task and state, or a 400 JSON error ``{"error": ...}`` when reading
        the state or scoring raises ``RuntimeError`` or ``ValueError``.
    """
    try:
        # Imported lazily at request time, as in the original handler.
        from server.graders import score_episode

        # Named to avoid shadowing the module-level ``state`` route function.
        current_state = env.state
        return score_episode(current_state.task_id, current_state)
    except (RuntimeError, ValueError) as e:
        return JSONResponse(status_code=400, content={"error": str(e)})


def _extract_json_section(stdout: str) -> Optional[str]:
    """Return the text after the first ``JSON Output:`` line of *stdout*.

    The marker must be a line on its own (surrounding whitespace ignored).
    Returns ``None`` when the marker is missing or is the last line.
    """
    lines = stdout.strip().split("\n")
    for index, line in enumerate(lines):
        if line.strip() == "JSON Output:":
            remainder = lines[index + 1:]
            return "\n".join(remainder) if remainder else None
    return None


@app.post("/baseline")
def baseline():
    """
    Run the baseline inference script against all 3 tasks.
    Returns scores for each task produced by the LLM agent.
    Note: Requires HF_TOKEN (or GROQ_API_KEY) to be set.

    ``inference.py`` is run with the current interpreter from the parent
    directory of this file's directory. If its stdout contains a
    ``JSON Output:`` line, everything after it is parsed as JSON and
    returned. Otherwise the last 1000 characters of stdout are returned
    under ``"output"``.

    Error responses:
        500 -- non-zero exit (tail of stderr included), unparsable JSON after
               the marker (tail of stdout included), or any other failure.
        504 -- the script ran longer than the 20 minute timeout.
    """
    import subprocess
    import sys
    import json as json_lib

    timeout_seconds = 1200  # 20 minute timeout (matches spec)
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    try:
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            timeout=timeout_seconds,
            cwd=project_root,
        )

        if result.returncode != 0:
            return JSONResponse(
                status_code=500,
                content={
                    "error": "Inference script failed",
                    "stderr": result.stderr[-500:] if result.stderr else "",
                },
            )

        json_str = _extract_json_section(result.stdout)
        if json_str is None:
            return {"message": "Baseline completed", "output": result.stdout[-1000:]}

        try:
            return json_lib.loads(json_str)
        except json_lib.JSONDecodeError as e:
            # Report the parse failure explicitly and include the raw output,
            # instead of a bare decoder message with no context.
            return JSONResponse(
                status_code=500,
                content={
                    "error": f"Could not parse inference JSON output: {e}",
                    "output": result.stdout[-1000:],
                },
            )

    except subprocess.TimeoutExpired:
        return JSONResponse(
            status_code=504,
            content={"error": "Inference timed out (20min limit)"},
        )
    except Exception as e:
        # Top-level HTTP boundary: report any other failure as a 500.
        return JSONResponse(status_code=500, content={"error": str(e)})


def main(host: str = "0.0.0.0", port: int = 7860):
    """Serve the application with uvicorn (auto-reload disabled).

    Args:
        host: Interface to bind; the default binds all interfaces.
        port: TCP port to listen on; defaults to 7860.
    """
    # The app is referenced by import path, so uvicorn loads ``server.app:app``.
    uvicorn.run("server.app:app", host=host, port=port, reload=False)


# Start the server when this module is executed directly as a script.
if __name__ == "__main__":
    main()