"""Baseline inference script for the Incident Response Triage environment.

Supports two modes:
  1. LLM-based: queries an OpenAI-compatible chat API. The API key is read
     from HF_TOKEN, API_KEY or OPENAI_API_KEY (first one set wins);
     API_BASE_URL and MODEL_NAME optionally override the endpoint and model.
  2. Rule-based: replays a fixed, scripted action sequence per task
     (deterministic, no API key needed; the default).

Usage:
    # LLM baseline (requires an API key, see above)
    python -m baseline.inference --mode llm --base-url http://localhost:7860

    # Rule-based baseline (no API key needed)
    python -m baseline.inference --mode rules --base-url http://localhost:7860

    # Direct mode (in-process, no server needed; rule-based only)
    python -m baseline.inference --mode rules --direct
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from typing import Any, Dict, List, Optional

import httpx

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

TASK_IDS = ["severity_classification", "root_cause_analysis", "full_incident_management"]

SYSTEM_PROMPT = """\
You are an expert on-call Site Reliability Engineer handling a production incident.

You interact with an Incident Response environment by taking ONE action per step.
Each action must be a JSON object with these fields:
  - action_type: one of "classify", "investigate", "diagnose", "remediate", "escalate", "communicate"
  - target: service name, team name, or channel (depends on action_type)
  - parameters: dict with action-specific params
  - reasoning: brief explanation of why you chose this action

Action details:
  INVESTIGATE: target=service_name -> reveals that service's logs and metrics
  CLASSIFY: parameters.severity = "P1"|"P2"|"P3"|"P4"
  DIAGNOSE: target=service_name, parameters.root_cause = "description"
  REMEDIATE: target=service_name, parameters.action = "restart"|"rollback"|"scale"|"config_change"
  ESCALATE: target=team_name, parameters.priority = "urgent"|"high"|"medium", parameters.message = "..."
  COMMUNICATE: target="status_page"|"slack"|"email", parameters.message = "status update text"

Strategy:
1. Read alerts carefully
2. Investigate the most suspicious services first
3. Classify severity based on evidence
4. Diagnose root cause after investigation
5. Apply targeted remediation
6. Escalate and communicate as needed

Return ONLY valid JSON -- no markdown, no explanation outside the JSON.
"""
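

# Sketch (hypothetical, not called by the baseline): a client-side check that
# an action dict matches the schema spelled out in SYSTEM_PROMPT before it is
# sent to /step. The allowed values are copied from the prompt text; the
# server's own validation (src.models.Action, not shown in this file) may be
# stricter or looser, so treat the rules below as assumptions.
from typing import Any, Dict, List

# For each action_type: required parameter -> allowed values
# (None means "any non-empty string").
_EXAMPLE_REQUIRED_PARAMS: Dict[str, Dict[str, Any]] = {
    "investigate": {},
    "classify": {"severity": {"P1", "P2", "P3", "P4"}},
    "diagnose": {"root_cause": None},
    "remediate": {"action": {"restart", "rollback", "scale", "config_change"}},
    "escalate": {"priority": {"urgent", "high", "medium"}, "message": None},
    "communicate": {"message": None},
}
# Action types whose target must name a service or team (per the prompt).
_EXAMPLE_NEEDS_TARGET = {"investigate", "diagnose", "remediate", "escalate"}
_EXAMPLE_CHANNELS = {"status_page", "slack", "email"}


def _example_validate_action(action: Dict[str, Any]) -> List[str]:
    """Return the problems found in ``action``; an empty list means it looks valid."""
    action_type = action.get("action_type")
    if not isinstance(action_type, str) or action_type not in _EXAMPLE_REQUIRED_PARAMS:
        return [f"unknown action_type: {action_type!r}"]
    params = action.get("parameters") or {}
    if not isinstance(params, dict):
        return ["parameters must be a JSON object"]

    problems: List[str] = []
    target = action.get("target")
    if action_type in _EXAMPLE_NEEDS_TARGET and not (isinstance(target, str) and target.strip()):
        problems.append(f"{action_type}: target must be a non-empty string")
    if action_type == "communicate" and not (isinstance(target, str) and target in _EXAMPLE_CHANNELS):
        problems.append(f"communicate: target must be one of {sorted(_EXAMPLE_CHANNELS)}")

    for name, allowed in _EXAMPLE_REQUIRED_PARAMS[action_type].items():
        value = params.get(name)
        if allowed is None:
            ok = isinstance(value, str) and bool(value.strip())
        else:
            ok = isinstance(value, str) and value in allowed
        if not ok:
            problems.append(f"{action_type}: parameters.{name} is missing or invalid ({value!r})")
    return problems

# Example: {"action_type": "classify", "target": "", "parameters": {"severity": "P5"}}
# yields one problem, because "P5" is not one of P1-P4.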


# ---------------------------------------------------------------------------
# Rule-based baseline
# ---------------------------------------------------------------------------

def _rule_based_easy() -> List[Dict[str, Any]]:
    """Deterministic optimal actions for the easy task."""
    return [
        {
            "action_type": "investigate",
            "target": "postgres-primary",
            "parameters": {},
            "reasoning": "Alert shows connection pool at 98% on postgres-primary. Investigating first.",
        },
        {
            "action_type": "investigate",
            "target": "user-service",
            "parameters": {},
            "reasoning": "user-service has high latency. Checking logs for connection pool errors.",
        },
        {
            "action_type": "classify",
            "target": "",
            "parameters": {"severity": "P2"},
            "reasoning": "Service degraded but not fully down. Connection pool issue is P2.",
        },
    ]


def _rule_based_medium() -> List[Dict[str, Any]]:
    """Deterministic optimal actions for the medium task."""
    return [
        {
            "action_type": "investigate",
            "target": "payment-gateway",
            "parameters": {},
            "reasoning": "Payment success rate is critically low. Starting with the payment gateway.",
        },
        {
            "action_type": "investigate",
            "target": "redis-session",
            "parameters": {},
            "reasoning": "Eviction spike on redis-session could explain missing payment tokens.",
        },
        {
            "action_type": "classify",
            "target": "",
            "parameters": {"severity": "P1"},
            "reasoning": "Payment processing at 45% success is a P1 revenue-impacting incident.",
        },
        {
            "action_type": "diagnose",
            "target": "redis-session",
            "parameters": {"root_cause": "Redis session store hit maxmemory limit causing eviction of payment session tokens. Sessions evicted before payment completion."},
            "reasoning": "Logs show redis-session at 100% memory with aggressive evictions of active sessions.",
        },
        {
            "action_type": "remediate",
            "target": "redis-session",
            "parameters": {"action": "scale"},
            "reasoning": "Scaling redis-session memory to stop evictions and restore payment flow.",
        },
    ]


def _rule_based_hard() -> List[Dict[str, Any]]:
    """Deterministic optimal actions for the hard task.

    Optimal action order: investigate (4x) -> classify -> diagnose ->
    remediate (2x) -> escalate (2x) -> communicate (triggers done).
    This maximises the score: 2 remediations (0.18) + 2 escalations (0.15)
    + 1 communication (0.06) = 0.39.  Any other ordering yields <= 0.37.

    Because the first communicate ends the episode, the runners stop before
    the trailing Slack update, so only the status-page update is sent.
    """
    return [
        {
            "action_type": "investigate",
            "target": "auth-service",
            "parameters": {},
            "reasoning": "Auth-service has critical latency. Multiple services depend on auth. Investigating first.",
        },
        {
            "action_type": "investigate",
            "target": "api-gateway",
            "parameters": {},
            "reasoning": "API gateway returning 503s. Checking if it's auth-related.",
        },
        {
            "action_type": "investigate",
            "target": "redis-auth-cache",
            "parameters": {},
            "reasoning": "Checking auth cache - may explain why auth is slow.",
        },
        {
            "action_type": "investigate",
            "target": "order-service",
            "parameters": {},
            "reasoning": "Order queue depth at 15000+. Checking downstream impact and queue status.",
        },
        {
            "action_type": "classify",
            "target": "",
            "parameters": {"severity": "P1"},
            "reasoning": "Cascading multi-service outage affecting all authenticated endpoints. P1.",
        },
        {
            "action_type": "diagnose",
            "target": "auth-service",
            "parameters": {"root_cause": "Bad deployment v3.1.0 introduced memory leak via unbounded in-memory token cache. Auth-service OOMKill causes cascading failures to all dependent services."},
            "reasoning": "Auth-service logs show v3.1.0 deployment, memory climbing from 45% to 97%, GC pauses causing timeouts.",
        },
        {
            "action_type": "remediate",
            "target": "auth-service",
            "parameters": {"action": "rollback"},
            "reasoning": "Rolling back auth-service to v3.0.9 to fix the memory leak.",
        },
        {
            "action_type": "remediate",
            "target": "order-service",
            "parameters": {"action": "scale"},
            "reasoning": "Queue depth at 15000+. Scaling order-service to drain the backlog.",
        },
        {
            "action_type": "escalate",
            "target": "platform-team",
            "parameters": {"priority": "urgent", "message": "Cascading outage caused by auth-service v3.1.0 memory leak. Rolling back. Need platform support for queue recovery."},
            "reasoning": "Platform team needs to be aware of the cascading impact.",
        },
        {
            "action_type": "escalate",
            "target": "auth-team",
            "parameters": {"priority": "urgent", "message": "auth-service v3.1.0 has unbounded memory growth in token cache. Rolled back to v3.0.9. Please investigate before re-deploying."},
            "reasoning": "Auth team owns the service and needs to fix the root cause code.",
        },
        {
            "action_type": "communicate",
            "target": "status_page",
            "parameters": {"message": "INCIDENT UPDATE: Root cause identified - auth-service v3.1.0 memory leak. Rollback complete. Platform and auth teams engaged. ETA for full recovery: 15 minutes."},
            "reasoning": "External stakeholders need comprehensive status update with root cause and ETA.",
        },
        {
            "action_type": "communicate",
            "target": "slack",
            "parameters": {"message": "Incident update: auth-service v3.1.0 rolled back. Memory leak in token cache was root cause. Order queue draining. Monitoring recovery."},
            "reasoning": "Internal team needs current status.",
        },
    ]
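

# Sketch (hypothetical helpers, not called by the baseline): the
# _rule_based_hard docstring says the first "communicate" ends the episode.
# Under that assumption only the actions up to and including it are executed,
# and the score components quoted in the docstring add up to 0.39.
from collections import Counter
from typing import Any, Dict, List


def _example_executed_prefix(
    actions: List[Dict[str, Any]], terminal_type: str = "communicate"
) -> List[Dict[str, Any]]:
    """Return the actions up to and including the first ``terminal_type`` action."""
    for i, act in enumerate(actions):
        if act["action_type"] == terminal_type:
            return actions[: i + 1]
    return list(actions)


def _example_action_counts(actions: List[Dict[str, Any]]) -> Counter:
    """Count how many actions of each type a plan contains."""
    return Counter(act["action_type"] for act in actions)


def _example_hard_bonus() -> float:
    """Add up the per-category contributions quoted in the docstring."""
    # Numbers copied from the docstring; the grader's real weights live in
    # the environment and may differ.
    contributions = {"remediations": 0.18, "escalations": 0.15, "communication": 0.06}
    return round(sum(contributions.values()), 2)

# Usage (kept as comments so importing this module has no side effects):
#   _example_action_counts(_example_executed_prefix(_rule_based_hard()))
#   counts 4 investigate, 1 classify, 1 diagnose, 2 remediate,
#   2 escalate and 1 communicate; the trailing Slack update is not sent.
#   _example_hard_bonus() returns 0.39.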


RULE_BASED_ACTIONS = {
    "severity_classification": _rule_based_easy,
    "root_cause_analysis": _rule_based_medium,
    "full_incident_management": _rule_based_hard,
}
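

# Sketch (hypothetical helper, not used by run_episode_rules): the mapping
# above stores factory functions, so each lookup builds a fresh list of
# action dicts and one run cannot mutate another run's plan. A lookup like
# this turns an unknown task_id into an error listing the valid ids instead
# of a bare KeyError.
from typing import Any, Callable, Dict, List, Mapping

_ExampleActionFactory = Callable[[], List[Dict[str, Any]]]


def _example_get_plan(
    task_id: str, registry: Mapping[str, _ExampleActionFactory]
) -> List[Dict[str, Any]]:
    """Build the scripted plan for ``task_id`` from ``registry``."""
    try:
        factory = registry[task_id]
    except KeyError:
        valid = ", ".join(sorted(registry))
        raise ValueError(f"unknown task_id {task_id!r}; expected one of: {valid}") from None
    return factory()

# Usage: _example_get_plan("root_cause_analysis", RULE_BASED_ACTIONS)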


# ---------------------------------------------------------------------------
# Episode runners
# ---------------------------------------------------------------------------

def run_episode_rules(
    task_id: str,
    *,
    base_url: Optional[str] = None,
    env_instance: Any = None,
) -> Dict[str, Any]:
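    """Run one episode with the rule-based baseline.

    Uses ``env_instance`` in-process when given, otherwise talks to the
    HTTP server at ``base_url``.
    """
    actions = RULE_BASED_ACTIONS[task_id]()

    if env_instance is not None:
        return _run_direct(task_id, actions, env_instance)
    if base_url is None:
        raise ValueError("Provide either env_instance or base_url.")
    return _run_http(task_id, actions, base_url)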


def _run_direct(task_id: str, actions: List[Dict], env_instance: Any) -> Dict[str, Any]:
    """Run episode directly against an env instance (in-process)."""
    from src.models import Action

    env_instance.reset(task_id)
    total_reward = 0.0
    steps = 0

    for act_dict in actions:
        action = Action(**act_dict)
        result = env_instance.step(action)
        total_reward += result.reward.value
        steps += 1
        if result.done:
            break

    grader_result = env_instance.grade()
    return {
        "task_id": task_id,
        "score": grader_result.score,
        "steps_taken": steps,
        "cumulative_reward": round(total_reward, 4),
        "grader_breakdown": grader_result.breakdown,
        "grader_feedback": grader_result.feedback,
    }
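

# Sketch (hypothetical, for offline testing): _run_direct only relies on an
# env with reset(task_id), step(action) returning an object with
# .reward.value and .done, and grade() returning .score, .breakdown and
# .feedback. The stub below fakes that interface with fixed rewards; its
# classes and its "done on communicate" rule are assumptions, not the real
# IncidentResponseEnv. The loop mirrors _run_direct but passes plain dicts,
# so it does not need src.models.Action.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class _ExampleReward:
    value: float


@dataclass
class _ExampleStepResult:
    reward: _ExampleReward
    done: bool


@dataclass
class _ExampleGrade:
    score: float
    breakdown: Dict[str, float] = field(default_factory=dict)
    feedback: str = ""


class _ExampleStubEnv:
    """Fake env: fixed reward per step, episode ends on the first communicate."""

    def __init__(self, reward_per_step: float = 0.1) -> None:
        self.reward_per_step = reward_per_step
        self.task_id = ""
        self.history: List[Dict[str, Any]] = []

    def reset(self, task_id: str) -> None:
        self.task_id = task_id
        self.history = []

    def step(self, action: Dict[str, Any]) -> _ExampleStepResult:
        self.history.append(action)
        done = action["action_type"] == "communicate"
        return _ExampleStepResult(_ExampleReward(self.reward_per_step), done)

    def grade(self) -> _ExampleGrade:
        return _ExampleGrade(score=min(1.0, 0.25 * len(self.history)))


def _example_run_loop(env: Any, task_id: str, actions: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Same control flow as _run_direct: step until the plan or the episode ends."""
    env.reset(task_id)
    total, steps = 0.0, 0
    for act in actions:
        result = env.step(act)
        total += result.reward.value
        steps += 1
        if result.done:  # stop as soon as the env reports the episode is over
            break
    grade = env.grade()
    return {"task_id": task_id, "score": grade.score, "steps_taken": steps,
            "cumulative_reward": round(total, 4)}

# Usage: with the plan [investigate, communicate, communicate] the loop stops
# after two steps (steps_taken=2, cumulative_reward=0.2, score=0.5).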


def _run_http(
    task_id: str,
    actions: List[Dict],
    base_url: str,
) -> Dict[str, Any]:
    """Run episode against the HTTP API."""
    # Use the client as a context manager so its connections are closed.
    with httpx.Client(base_url=base_url, timeout=30.0) as client:
        # Reset -- capture session_id for all subsequent calls
        resp = client.post("/reset", json={"task_id": task_id})
        resp.raise_for_status()
        session_id = resp.json()["session_id"]
        headers = {"X-Session-ID": session_id}

        total_reward = 0.0
        steps = 0
        done = False

        for act_dict in actions:
            if done:
                break
            resp = client.post("/step", json=act_dict, headers=headers)
            resp.raise_for_status()
            result = resp.json()
            total_reward += result["reward"]["value"]
            steps += 1
            done = result["done"]

        # Get grader score
        resp = client.post("/grader", headers=headers)
        resp.raise_for_status()
        grader = resp.json()

    return {
        "task_id": task_id,
        "score": grader["score"],
        "steps_taken": steps,
        "cumulative_reward": round(total_reward, 4),
        "grader_breakdown": grader["breakdown"],
        "grader_feedback": grader.get("feedback", ""),
    }
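

# Sketch (hypothetical, for illustration): the session protocol used by
# _run_http, with the transport injected as a plain function so the session
# handling can be exercised without a server. ``post(path, body, headers)``
# is an assumed signature returning the decoded JSON response; with httpx it
# would wrap ``client.post(path, json=body, headers=headers).json()``.
from typing import Any, Callable, Dict, List, Optional

_ExamplePost = Callable[[str, Optional[Dict[str, Any]], Dict[str, str]], Dict[str, Any]]


def _example_run_session(
    post: _ExamplePost, task_id: str, actions: List[Dict[str, Any]]
) -> Dict[str, Any]:
    # 1. /reset returns the first observation plus a session_id.
    session_id = post("/reset", {"task_id": task_id}, {})["session_id"]
    # 2. Every later call carries the session in the X-Session-ID header.
    headers = {"X-Session-ID": session_id}
    total, steps = 0.0, 0
    for act in actions:
        result = post("/step", act, headers)
        total += result["reward"]["value"]
        steps += 1
        if result["done"]:
            break
    # 3. /grader scores the episode for the same session (no request body).
    grader = post("/grader", None, headers)
    return {"task_id": task_id, "score": grader["score"], "steps_taken": steps,
            "cumulative_reward": round(total, 4)}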


def run_episode_llm(
    task_id: str,
    base_url: str,
    model: str = "gpt-4o-mini",
) -> Dict[str, Any]:
    """Run one episode with an LLM agent via the OpenAI API."""
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise RuntimeError("openai package required for LLM baseline. pip install openai") from exc

    # Support competition env vars (API_BASE_URL, HF_TOKEN, MODEL_NAME) as
    # well as the generic API_KEY and the standard OPENAI_API_KEY.
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Set HF_TOKEN, API_KEY or OPENAI_API_KEY environment variable.")

    api_base = os.environ.get("API_BASE_URL")
    effective_model = os.environ.get("MODEL_NAME", model)

    llm_kwargs: Dict[str, Any] = {"api_key": api_key}
    if api_base:
        llm_kwargs["base_url"] = api_base

    llm = OpenAI(**llm_kwargs)
    client = httpx.Client(base_url=base_url, timeout=30.0)

    # Reset environment
    resp = client.post("/reset", json={"task_id": task_id})
    resp.raise_for_status()
    obs = resp.json()
    session_id = obs["session_id"]
    headers = {"X-Session-ID": session_id}

    total_reward = 0.0
    steps = 0
    done = False
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    while not done and steps < obs.get("max_steps", 20):
        # Build user message with current observation
        user_msg = (
            f"Current observation (step {obs.get('step_number', steps)}/{obs.get('max_steps', 20)}):\n"
            f"{json.dumps(obs, indent=2, default=str)}\n\n"
            "What is your next action? Return ONLY a JSON object."
        )
        messages.append({"role": "user", "content": user_msg})

        # Query LLM
        completion = llm.chat.completions.create(
            model=effective_model,
            messages=messages,
            temperature=0.0,
            max_tokens=500,
            response_format={"type": "json_object"},
        )
        assistant_msg = completion.choices[0].message.content or "{}"
        messages.append({"role": "assistant", "content": assistant_msg})

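        # Parse the action, then step. If the reply is not a JSON object,
        # send a placeholder status message instead; in tasks where
        # communicate ends the episode, this fallback ends it early.
        try:
            action_dict = json.loads(assistant_msg)
        except json.JSONDecodeError:
            action_dict = None
        if not isinstance(action_dict, dict):
            action_dict = {
                "action_type": "communicate",
                "target": "slack",
                "parameters": {"message": "Error parsing response"},
                "reasoning": "JSON parse error fallback",
            }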

        resp = client.post("/step", json=action_dict, headers=headers)
        resp.raise_for_status()
        result = resp.json()

        obs = result["observation"]
        total_reward += result["reward"]["value"]
        steps += 1
        done = result["done"]

    # Final grader
    resp = client.post("/grader", headers=headers)
    resp.raise_for_status()
    grader = resp.json()

    return {
        "task_id": task_id,
        "score": grader["score"],
        "steps_taken": steps,
        "cumulative_reward": round(total_reward, 4),
        "grader_breakdown": grader["breakdown"],
        "grader_feedback": grader.get("feedback", ""),
    }
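

# Sketch (hypothetical helper, not called by run_episode_llm): the
# environment-variable precedence used above, written as a pure function so
# it can be checked without touching os.environ. HF_TOKEN wins over API_KEY,
# which wins over OPENAI_API_KEY; API_BASE_URL is optional; MODEL_NAME
# overrides the --model default. Empty values count as unset for the key,
# matching the ``or`` chain above.
from typing import Any, Dict, Mapping


def _example_resolve_llm_settings(environ: Mapping[str, str], default_model: str) -> Dict[str, Any]:
    api_key = environ.get("HF_TOKEN") or environ.get("API_KEY") or environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Set HF_TOKEN, API_KEY or OPENAI_API_KEY environment variable.")
    settings: Dict[str, Any] = {"api_key": api_key, "model": environ.get("MODEL_NAME", default_model)}
    if environ.get("API_BASE_URL"):
        settings["base_url"] = environ["API_BASE_URL"]
    return settings

# Usage: _example_resolve_llm_settings(os.environ, "gpt-4o-mini")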


# ---------------------------------------------------------------------------
# Main entry points
# ---------------------------------------------------------------------------

def run_all_tasks(
    base_url: Optional[str] = None,
    env_instance: Any = None,
    mode: str = "rules",
    model: str = "gpt-4o-mini",
) -> List[Dict[str, Any]]:
    """Run baseline inference on all 3 tasks and return results.

    ``mode="llm"`` needs ``base_url``; without it the rule-based baseline
    runs instead.
    """
    results = []
    for task_id in TASK_IDS:
        if mode == "llm" and base_url:
            result = run_episode_llm(task_id, base_url, model=model)
        else:
            result = run_episode_rules(task_id, base_url=base_url, env_instance=env_instance)
        results.append(result)
        print(f"  Task: {task_id:30s}  Score: {result['score']:.4f}  Steps: {result['steps_taken']}")
    return results


def main():
    parser = argparse.ArgumentParser(description="Incident Response Triage - Baseline Inference")
    parser.add_argument("--mode", choices=["rules", "llm"], default="rules",
                        help="Baseline mode: rule-based or LLM-based")
    parser.add_argument("--base-url", default="http://localhost:7860",
                        help="Base URL of the running environment server")
    parser.add_argument("--model", default="gpt-4o-mini",
                        help="OpenAI model to use for LLM baseline")
    parser.add_argument("--direct", action="store_true",
                        help="Run in-process (no HTTP server needed)")
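    args = parser.parse_args()
    if args.direct and args.mode == "llm":
        # run_all_tasks would silently fall back to rules without a base URL.
        parser.error("--mode llm needs the HTTP server and cannot be combined with --direct")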

    print("=" * 60)
    print("Incident Response Triage - Baseline Inference")
    print(f"Mode: {args.mode}")
    print("=" * 60)

    if args.direct:
        from src.environment import IncidentResponseEnv
        env = IncidentResponseEnv()
        results = run_all_tasks(env_instance=env, mode=args.mode)
    else:
        results = run_all_tasks(base_url=args.base_url, mode=args.mode, model=args.model)

    print("=" * 60)
    mean_score = sum(r["score"] for r in results) / len(results)
    print(f"Mean score: {mean_score:.4f}")
    print("=" * 60)

    # Print detailed breakdown
    for r in results:
        print(f"\n--- {r['task_id']} ---")
        print(f"  Score: {r['score']:.4f}")
        print(f"  Steps: {r['steps_taken']}")
        print(f"  Cumulative reward: {r['cumulative_reward']:.4f}")
        print(f"  Feedback: {r.get('grader_feedback', 'N/A')}")
        if r.get("grader_breakdown"):
            for k, v in r["grader_breakdown"].items():
                print(f"    {k}: {v:.4f}")


if __name__ == "__main__":
    main()