File size: 24,402 Bytes
77eb356
 
 
 
 
3dc9d8d
 
 
 
 
77eb356
 
 
 
 
 
 
a50dd28
77eb356
 
8151d99
3dc9d8d
8151d99
77eb356
8151d99
3dc9d8d
8151d99
77eb356
 
 
 
 
 
 
 
 
8151d99
77eb356
3dc9d8d
77eb356
8151d99
77eb356
 
 
 
 
 
 
 
 
 
8151d99
77eb356
 
 
a50dd28
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a50dd28
77eb356
a50dd28
 
 
 
 
 
 
 
77eb356
a50dd28
77eb356
 
a50dd28
 
77eb356
 
a50dd28
 
 
 
77eb356
 
a50dd28
 
 
 
77eb356
 
 
 
a50dd28
77eb356
 
 
a50dd28
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a50dd28
 
 
 
77eb356
a50dd28
77eb356
a50dd28
 
 
77eb356
 
 
 
a50dd28
 
77eb356
a50dd28
 
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
a50dd28
77eb356
 
 
 
 
 
a50dd28
77eb356
 
a50dd28
77eb356
 
 
a50dd28
 
 
 
 
 
77eb356
a50dd28
77eb356
 
 
 
a50dd28
77eb356
a50dd28
 
 
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a50dd28
77eb356
 
 
 
a50dd28
77eb356
 
 
 
a50dd28
77eb356
a50dd28
77eb356
a50dd28
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a50dd28
 
77eb356
a50dd28
 
77eb356
 
 
a50dd28
 
 
77eb356
a50dd28
 
 
 
 
 
77eb356
 
a50dd28
 
77eb356
a50dd28
 
77eb356
 
 
 
 
 
 
a50dd28
77eb356
 
a50dd28
 
77eb356
a50dd28
77eb356
 
 
 
a50dd28
 
77eb356
 
 
 
 
 
 
a50dd28
77eb356
a50dd28
 
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a50dd28
77eb356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8151d99
a50dd28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8151d99
a50dd28
 
 
 
 
 
 
 
 
 
3dc9d8d
8151d99
3dc9d8d
 
 
 
 
 
 
 
 
 
 
77eb356
 
3dc9d8d
 
 
 
 
 
77eb356
 
3dc9d8d
77eb356
3dc9d8d
 
 
 
 
 
 
 
 
77eb356
3dc9d8d
 
 
 
 
 
 
 
 
 
 
77eb356
3dc9d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77eb356
3dc9d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77eb356
4591acd
 
 
 
3dc9d8d
 
 
 
 
 
77eb356
4591acd
3dc9d8d
77eb356
 
3dc9d8d
77eb356
 
 
8151d99
 
 
 
 
 
 
 
 
77eb356
 
8151d99
 
 
 
 
 
 
77eb356
3dc9d8d
 
 
 
 
 
 
 
 
77eb356
 
3dc9d8d
77eb356
8151d99
3dc9d8d
77eb356
3dc9d8d
8151d99
3dc9d8d
 
 
8151d99
77eb356
8151d99
 
 
 
 
 
 
3dc9d8d
 
 
 
 
8151d99
 
 
 
 
 
 
 
 
77eb356
3dc9d8d
 
8151d99
 
3dc9d8d
 
8151d99
3dc9d8d
8151d99
3dc9d8d
 
8151d99
3dc9d8d
 
8151d99
 
77eb356
3dc9d8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
"""
inference.py — Cloud Incident Response OpenEnv baseline inference script.

The LLM reasons from evidence. Fallback is a dumb safety net that scores low.
Override only blocks clearly invalid actions (wrong task submission, bad params).

STRUCTURED OUTPUT:
  [START] task=<task_name> env=cloud-incident-response model=<model_name>
  [STEP]  step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
  [END]   success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""

from __future__ import annotations

import json
import os
import sys
import time

import requests
import time as _time
_START_TIME = _time.time()
_MAX_RUNTIME = 1080

def _check_timeout():
    if _time.time() - _START_TIME > _MAX_RUNTIME:
        raise RuntimeError("Approaching 20min limit β€” stopping early")
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# ── Config ──────────────────────────────────────────────────────────────────
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME   = os.environ.get("MODEL_NAME",   "llama-3.1-8b-instant")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or ""
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
ENV_NAME = "cloud-incident-response"

if not API_KEY:
    print("[WARN] No API key set β€” LLM calls will fail.", file=sys.stderr)

_session = requests.Session()
_client = None


def _get_client():
    """Return the shared OpenAI client, constructing it on the first call."""
    global _client
    if _client is not None:
        return _client
    # Imported here rather than at module top, so the package is only
    # required once an LLM call is actually made.
    from openai import OpenAI
    _client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    return _client


# ── Constants ───────────────────────────────────────────────────────────────
# The single submission action that is correct for each task.
_TASK_SUBMIT = dict(
    alert_classification="submit_severity",
    root_cause_analysis="submit_root_cause",
    remediation_planning="submit_resolution",
)

# Read-only investigation actions.
_DIAG_TYPES = frozenset((
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_recent_deploys",
    "check_service_status",
))

# Actions that submit an answer.
_SUBMIT_TYPES = frozenset((
    "submit_severity",
    "submit_root_cause",
    "submit_resolution",
))

# Remediation actions.
_REM_TYPES = frozenset((
    "restart_service",
    "rollback_deploy",
    "scale_service",
    "disable_feature_flag",
    "clear_cache",
    "execute_runbook_step",
))

# Every action type this script will send to the environment.
_ALL_VALID = frozenset().union(_DIAG_TYPES, _SUBMIT_TYPES, _REM_TYPES)


# System message sent at the start of every episode. It lists the JSON action
# formats, the severity rubric and a per-task strategy. The text is sent to
# the model verbatim, so it must contain only correctly encoded characters.
SYSTEM_PROMPT = """\
You are an expert Site Reliability Engineer responding to a production incident.
Reply with exactly ONE JSON action object. No markdown, no explanation, no extra text.

VALID ACTIONS:
{"action_type":"query_logs","parameters":{"service":"<name>"}}
{"action_type":"check_metrics","parameters":{"service":"<name>"}}
{"action_type":"check_dependencies","parameters":{"service":"<name>"}}
{"action_type":"check_recent_deploys","parameters":{"service":"<name>"}}
{"action_type":"check_service_status","parameters":{"service":"<name>"}}
{"action_type":"restart_service","parameters":{"service":"<name>"}}
{"action_type":"rollback_deploy","parameters":{"service":"<name>","target_version":"previous"}}
{"action_type":"disable_feature_flag","parameters":{"flag":"<flag_name>"}}
{"action_type":"execute_runbook_step","parameters":{"runbook_action":"<action>"}}
{"action_type":"submit_severity","parameters":{"severity":"P1|P2|P3|P4","service":"<name>"}}
{"action_type":"submit_root_cause","parameters":{"service":"<name>","failure_mode":"<description>"}}
{"action_type":"submit_resolution","parameters":{"summary":"<3+ sentence summary>"}}

RULES:
- Service names MUST exactly match the KNOWN_SERVICES list.
- P1 = complete outage OR revenue > $1,000/min.  P2 = major degradation.
  P3 = minor/partial issue with graceful fallback.  P4 = informational.
- IMPORTANT: check_recent_deploys and check_dependencies require prior
  investigation. You MUST query_logs or check_metrics on a service BEFORE
  checking its deploys or dependencies. Otherwise you get limited data.
- Root cause = the upstream service that TRIGGERED the cascade. Often NOT
  in the alert's affected_services list.
- submit_resolution summary: 3+ sentences about what failed, what you did, status.
- Submit as soon as evidence is clear — do NOT waste steps.

STRATEGY:

alert_classification (max 3 steps):
  Query 1-2 services with logs/metrics, then submit_severity.
  Check revenue_impact and error_rate carefully. Not all high error rates are P1.

root_cause_analysis (max 10 steps):
  1. query_logs or check_metrics on 2-3 services to understand the blast radius
  2. THEN check_recent_deploys on services that look suspicious
  3. Look for the service whose deploy/change CAUSED the cascade
  4. Submit submit_root_cause with service and failure_mode

remediation_planning (max 15 steps):
  1. query_logs on affected services to confirm root cause
  2. Execute remediation actions in logical order
  3. Verify recovery with check_service_status
  4. Submit submit_resolution with detailed summary

CRITICAL: Each task has ONE correct submission action:
  alert_classification  -> submit_severity
  root_cause_analysis   -> submit_root_cause
  remediation_planning  -> submit_resolution"""


# ── Helpers ─────────────────────────────────────────────────────────────────

def _queried_svcs(queried_data: dict) -> set[str]:
    """Collect the names of all services with at least one diagnostic result.

    Only entries keyed by a diagnostic action type (``_DIAG_TYPES``) whose
    value is a dict are considered; the dict keys are taken as service names.
    """
    names: set[str] = set()
    for kind, per_service in queried_data.items():
        if kind in _DIAG_TYPES and isinstance(per_service, dict):
            names.update(per_service)
    return names


def _extract_signals(queried_data: dict) -> list[str]:
    """Scan gathered diagnostics for known failure keywords.

    Each per-service payload is stringified, lower-cased and searched for
    keyword patterns (OOM, BGP withdrawal, pool exhaustion, cache purge,
    unbounded query, credential problems, gated responses, recent deploys).

    Args:
        queried_data: mapping of action type -> {service: payload}. Values
            that are not dicts are skipped.

    Returns:
        Human-readable signal strings in first-seen order, without duplicates.
    """
    deploy_hints = ("ago", "change", "update", "added", "deploy")
    # A dict is used as an insertion-ordered set of messages.
    ordered: dict[str, None] = {}

    for kind, per_service in queried_data.items():
        if not isinstance(per_service, dict):
            continue
        for name, payload in per_service.items():
            raw = str(payload)
            low = raw.lower()

            def mentions(*needles: str) -> bool:
                return any(needle in low for needle in needles)

            candidates = [
                (mentions("out of memory", "oom"),
                 f"OOM detected in {name}"),
                (mentions("bgp") and mentions("withdrawal", "withdrawn"),
                 f"BGP route issue in {name}"),
                (mentions("pool") and mentions("exhaust", "too many clients"),
                 f"Connection pool issue in {name}"),
                (mentions("cache") and mentions("purge", "invalidat"),
                 f"Cache purge in {name}"),
                (mentions("unbounded", "no limit"),
                 f"Unbounded query in {name}"),
                (mentions("credential", "password", "authentication failed"),
                 f"Credential/auth issue in {name}"),
                (mentions("requires deeper investigation",
                          "requires initial investigation"),
                 f"GATED: {name} needs logs/metrics first before checking deploys"),
            ]

            # Deploy results count as a signal only when they are not a
            # gated ("requires ...") response.
            open_deploy_info = (
                kind == "check_recent_deploys"
                and mentions(*deploy_hints)
                and "requires" not in low
            )
            if open_deploy_info:
                snippet = raw[:120].replace("\n", " ")
                candidates.append((True, f"Recent change in {name}: {snippet}"))

            for matched, message in candidates:
                if matched:
                    ordered.setdefault(message, None)

    return list(ordered)


def _first_obs_msg(obs: dict) -> str:
    """Render the opening user message for a freshly reset incident.

    Args:
        obs: observation dict returned by the environment's reset call.

    Returns:
        A newline-joined briefing containing the task, alert details, known
        services, available actions and the required submission action.
    """
    alert = obs.get("alert", {})
    known = obs.get("known_services", [])
    affected = alert.get("affected_services", [])
    task_id = obs.get("task_id", "")
    deep_task = task_id in ("root_cause_analysis", "remediation_planning")
    outside_alert = [name for name in known if name not in affected]

    out = ["=== NEW INCIDENT ==="]
    out.append(f"Task: {task_id}  |  Max steps: {obs.get('max_steps')}")
    out.append(f"Scenario: {obs.get('scenario_id', '')}")
    out.append(f"INCIDENT: {obs.get('incident_summary', '')}")

    if alert:
        out.append("ALERT DETAILS:")
        title = alert.get("title")
        if title:
            out.append(f"  Title: {title}")
        if affected:
            out.append("  Directly affected: " + ", ".join(affected))
        out.extend(f"  - {symptom}" for symptom in alert.get("symptoms", []))
        for field in ("error_rate", "duration_minutes", "revenue_impact_per_min"):
            value = alert.get(field)
            if value is not None:
                out.append(f"  {field}: {value}")

    out.append(f"KNOWN_SERVICES: {json.dumps(known)}")

    # For the two investigative tasks, list the services the alert does not
    # mention so the model considers them as well.
    if outside_alert and deep_task:
        out.append(
            f"  Services NOT in alert (investigate these too): {json.dumps(outside_alert)}"
        )

    out.append(f"AVAILABLE ACTIONS: {obs.get('available_actions', [])}")
    out.append(f"REQUIRED SUBMISSION: {_TASK_SUBMIT.get(task_id, 'unknown')}")

    if deep_task:
        out += [
            "",
            "NOTE: check_recent_deploys requires prior investigation.",
            "You MUST query_logs or check_metrics on a service FIRST.",
        ]

    out += ["", "Respond with your first action (JSON only):"]
    return "\n".join(out)


def _step_msg(obs: dict, prev_queried: dict) -> str:
    """Build the follow-up user message sent to the model after a step.

    Args:
        obs: the latest observation returned by the environment.
        prev_queried: copy of ``queried_data`` from the previous observation,
            used to show only results that are new since the last message.

    Returns:
        A newline-joined message with a status line, any new diagnostic data
        (each entry truncated to 500 characters), keyword signals extracted
        from all data gathered so far, and urgency hints when few steps remain.
    """
    step = obs.get("step_count", 0)
    max_steps = obs.get("max_steps", 10)
    left = max_steps - step
    queried = obs.get("queried_data", {})
    task_id = obs.get("task_id", "")

    lines = [
        f"Step {step}/{max_steps} ({left} remaining) | "
        f"reward={obs.get('cumulative_reward', 0.0):.3f} | "
        f"feedback: {obs.get('feedback', '')}",
    ]

    # Only show (action_type, service) pairs that were absent last time.
    new_data = []
    for action_type, services in queried.items():
        prev = prev_queried.get(action_type, {})
        if isinstance(services, dict):
            for svc, data in services.items():
                if svc not in prev:
                    d = str(data)
                    if len(d) > 500:
                        d = d[:500] + "..."
                    new_data.append(f"  [{action_type}][{svc}]: {d}")
    if new_data:
        lines.append("NEW DATA:")
        lines.extend(new_data)

    signals = _extract_signals(queried)
    if signals:
        lines.append("SIGNALS:")
        lines.extend(f"  *** {sig} ***" for sig in signals)

    # Urgency hints. These strings go to the model, so they must be clean
    # text: the previous literals contained a mis-encoded em dash.
    if left <= 3:
        lines.append(f"*** {left} steps left — submit {_TASK_SUBMIT.get(task_id, '')} soon ***")
    if left <= 1:
        lines.append(f"!!! LAST STEP — MUST {_TASK_SUBMIT.get(task_id, 'SUBMIT')} NOW !!!")

    lines.append("Next action (JSON only):")
    return "\n".join(lines)


def _parse(text: str) -> dict:
    """Parse a model reply into a JSON object (dict).

    Accepts bare JSON, JSON wrapped in backtick code fences, and JSON
    embedded in surrounding prose (the span from the first ``{`` to the last
    ``}`` is tried as a fallback).

    Args:
        text: raw reply text from the model.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: if no valid JSON can be extracted.
        ValueError: if the decoded JSON is not an object (a list, string,
            number, boolean or null). ``json.JSONDecodeError`` is itself a
            ``ValueError`` subclass, so one ``except ValueError`` covers both.
    """
    text = text.strip()
    if text.startswith("`"):
        # Drop the fence lines and keep what is between them.
        text = "\n".join(
            ln for ln in text.splitlines() if not ln.startswith("`")
        ).strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}") + 1
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end])
    # Valid JSON that is not an object would break every caller that treats
    # the result as a dict; reject it so the caller can fall back.
    if not isinstance(parsed, dict):
        raise ValueError(
            f"expected a JSON object, got {type(parsed).__name__}"
        )
    return parsed


def _fallback_submit(task_id: str, obs: dict) -> dict:
    """Build a last-resort submission action for the given task.

    Severity is derived from simple thresholds on the alert's revenue impact
    and error rate. The root-cause guess is the first known service. Any other
    task id receives a canned resolution summary.

    Args:
        task_id: identifier of the current task.
        obs: latest observation; ``alert`` and ``known_services`` are read.

    Returns:
        An action dict with ``action_type`` and ``parameters``.
    """
    alert = obs.get("alert", {})
    known = obs.get("known_services", [])

    if task_id == "alert_classification":
        revenue = alert.get("revenue_impact_per_min", 0) or 0
        error_rate = alert.get("error_rate", 0) or 0
        if revenue > 1000 or error_rate > 0.9:
            severity = "P1"
        elif revenue > 100 or error_rate > 0.3:
            severity = "P2"
        else:
            severity = "P3"
        # Prefer an affected service, then any known one, then a placeholder.
        pool = alert.get("affected_services") or known or ["unknown"]
        return {
            "action_type": "submit_severity",
            "parameters": {"severity": severity, "service": pool[0]},
        }

    if task_id == "root_cause_analysis":
        suspect = known[0] if known else "unknown"
        return {
            "action_type": "submit_root_cause",
            "parameters": {
                "service": suspect,
                "failure_mode": "service failure causing cascade",
            },
        }

    summary = (
        "The incident was investigated through log and metric analysis. "
        "Remediation actions were applied to restore service health. "
        "Systems are being monitored for recovery confirmation."
    )
    return {"action_type": "submit_resolution", "parameters": {"summary": summary}}


def _smart_fallback(task_id: str, obs: dict, step: int, max_steps: int) -> dict:
    """Choose a safe default action when the model's reply cannot be used.

    Order of preference:
      1. Submit if at most one step remains, or if the task is
         alert_classification and at least one service has been queried.
      2. ``query_logs`` on the first known service not yet investigated.
      3. For the two investigative tasks, ``check_recent_deploys`` on the
         first known service without a deploy result.
      4. Otherwise submit.
    """
    services = obs.get("known_services", [])
    history = obs.get("queried_data", {})
    visited = _queried_svcs(history)
    remaining = max_steps - step

    out_of_time = remaining <= 1
    enough_for_severity = task_id == "alert_classification" and bool(visited)
    if out_of_time or enough_for_severity:
        return _fallback_submit(task_id, obs)

    nothing = object()

    # Cover every known service with logs before anything else.
    unvisited = next((s for s in services if s not in visited), nothing)
    if unvisited is not nothing:
        return {"action_type": "query_logs", "parameters": {"service": unvisited}}

    # Logs exist for every service at this point, so deploy history is next.
    if task_id in ("root_cause_analysis", "remediation_planning"):
        with_deploys = set(history.get("check_recent_deploys", {}).keys())
        pending = next((s for s in services if s not in with_deploys), nothing)
        if pending is not nothing:
            return {
                "action_type": "check_recent_deploys",
                "parameters": {"service": pending},
            }

    return _fallback_submit(task_id, obs)


def _should_override(
    task_id: str, action: dict, obs: dict, step: int, max_steps: int
) -> bool:
    """Decide whether the model's action must be replaced by a fallback.

    Args:
        task_id: identifier of the current task.
        action: action decoded from the model's reply.
        obs: latest observation; ``known_services`` is read.
        step: 1-based index of the step about to be executed.
        max_steps: step budget for the episode.

    Returns:
        True when the action is malformed or clearly invalid: wrong shape,
        unknown action type, non-submission with no steps left, submission
        type that does not match the task, unknown service name, bad
        severity, too-short root cause or summary, or a remediation action
        during alert classification. False otherwise.
    """
    # Shape checks first: the model may emit a JSON value of any type, and a
    # malformed action should cause an override rather than an exception.
    if not isinstance(action, dict):
        return True
    at = action.get("action_type", "")
    params = action.get("parameters", {})
    if not isinstance(at, str) or not isinstance(params, dict):
        return True

    def _text(key: str):
        """Return the stripped string at ``key`` ("" if falsy), or None if not a string."""
        value = params.get(key) or ""
        return value.strip() if isinstance(value, str) else None

    left = max_steps - step
    known = obs.get("known_services", [])

    if at not in _ALL_VALID:
        return True
    if left <= 0 and at not in _SUBMIT_TYPES:
        return True

    correct_submit = _TASK_SUBMIT.get(task_id)
    if at in _SUBMIT_TYPES and at != correct_submit:
        return True

    svc = _text("service")
    if svc is None:
        return True
    if (svc and known
            and at not in ("disable_feature_flag", "execute_runbook_step")
            and svc not in known):
        return True

    if at == "submit_severity":
        sev = _text("severity")
        if sev is None or sev.upper() not in ("P1", "P2", "P3", "P4"):
            return True

    if at == "submit_root_cause":
        mode = _text("failure_mode")
        if not svc or mode is None or len(mode) < 5:
            return True

    if at == "submit_resolution":
        summary = _text("summary")
        if summary is None or len(summary) < 30:
            return True

    if task_id == "alert_classification" and at in _REM_TYPES:
        return True

    return False


def _llm_call_with_retry(messages: list, max_retries: int = 1) -> str:
    """Request one chat completion, retrying when the call fails.

    A failure whose text mentions ``rate_limit`` or ``429`` is followed by a
    linearly growing pause (5s, 10s, ...) before the next attempt; any other
    failure is retried immediately. Once the attempts are used up, a warning
    is printed to stderr.

    Args:
        messages: chat history in OpenAI message format.
        max_retries: number of additional attempts after the first one.

    Returns:
        The reply text, or "" if every attempt failed or the reply was empty.
    """
    attempt = 0
    while attempt <= max_retries:
        try:
            resp = _get_client().chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=0.0,
                max_tokens=300,
                stream=False,
            )
            return resp.choices[0].message.content or ""
        except Exception as exc:
            if attempt == max_retries:
                # No attempts left: report the failure and give up.
                print(f"  [WARN] LLM call failed: {exc}", file=sys.stderr)
                return ""
            reason = str(exc).lower()
            if "rate_limit" in reason or "429" in reason:
                wait = 5 * (attempt + 1)
                print(f"  [RATE LIMIT] waiting {wait}s (attempt {attempt + 1})",
                      file=sys.stderr)
                time.sleep(wait)
        attempt += 1
    return ""


# ── Structured Output Helpers ───────────────────────────────────────────────

def _fmt_action(action: dict) -> str:
    """Format an action as a compact single-line string for [STEP] output.

    Parameters whose value is None or "" are omitted. Line breaks inside
    values are flattened the same way ``_fmt_error`` does it ("\\n" becomes a
    space, "\\r" is dropped), so that a multi-line value such as a resolution
    summary cannot split the [STEP] record across lines.

    Args:
        action: action dict with ``action_type`` and optional ``parameters``.

    Returns:
        ``"<type>(k=v, ...)"`` when there are parameters to show, otherwise
        just ``"<type>"`` (``"unknown"`` if the type is missing).
    """
    at = action.get("action_type", "unknown")
    params = action.get("parameters", {})
    if not isinstance(params, dict):
        # Nothing printable; avoid failing on e.g. "parameters": null.
        params = {}
    parts = [
        f"{k}={str(v).replace(chr(10), ' ').replace(chr(13), '')}"
        for k, v in params.items()
        if v is not None and v != ""
    ]
    if parts:
        return f"{at}({', '.join(parts)})"
    return at


def _fmt_error(error_val) -> str:
    """Render an error value for [STEP] output; 'null' means no error.

    None, the empty string and the literal string "null" all map to "null".
    Anything else is stringified and flattened onto a single line.
    """
    no_error = error_val is None or error_val == "" or error_val == "null"
    if no_error:
        return "null"
    flattened = str(error_val)
    # Newlines become spaces and carriage returns vanish, so the [STEP]
    # record stays on one line.
    for old, new in (("\n", " "), ("\r", "")):
        flattened = flattened.replace(old, new)
    return flattened


# ── Episode Runner with Structured Output ───────────────────────────────────

def _run_episode_structured(task_id: str, scenario_index: int) -> tuple[float, int, list[float]]:
    """
    Run a single episode and emit [START]/[STEP]/[END] records on stdout.

    The episode resets the environment, then loops up to ``max_steps`` times:
    ask the model for an action, replace it with a fallback when it cannot be
    parsed or is rejected by ``_should_override``, send it to the environment
    and print one [STEP] record. Afterwards the grader endpoint is queried for
    the score. Any exception is logged to stderr and ends the episode; the
    [END] record is printed regardless.

    Args:
        task_id: task to run (sent to the environment's /reset endpoint).
        scenario_index: scenario to run for that task.

    Returns: (score, steps_used, rewards_list), where score is clamped to
    the range [0.01, 0.99].
    """
    rewards_list: list[float] = []
    steps_used = 0
    score = 0.0

    # ── [START] ──
    print(f"[START] task={task_id} env={ENV_NAME} model={MODEL_NAME}", flush=True)

    try:
        # Raises RuntimeError (handled below) when the runtime budget is spent.
        _check_timeout()

        # Reset environment
        r = _session.post(
            f"{ENV_BASE_URL}/reset",
            params={"task_id": task_id, "scenario_index": scenario_index},
            timeout=30,
        )
        r.raise_for_status()
        obs = r.json()

        # Conversation starts with the system prompt and the incident briefing.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": _first_obs_msg(obs)},
        ]

        # queried_data as of the previous step, so only new results are shown.
        prev_queried: dict = {}
        max_steps = obs.get("max_steps", 10)

        for step_i in range(max_steps):
            current_step = step_i + 1

            # Get LLM action
            raw = _llm_call_with_retry(messages)
            # An empty reply is recorded as "{}" so the history keeps
            # alternating assistant/user turns.
            messages.append({"role": "assistant", "content": raw or "{}"})

            action = None
            try:
                if raw.strip():
                    action = _parse(raw)
            except Exception:
                # Unparseable reply: leave action as None and fall back below.
                pass

            if action is None:
                action = _smart_fallback(task_id, obs, current_step, max_steps)
                print(f"  [FALLBACK] step {current_step}: "
                      f"{action.get('action_type')}", file=sys.stderr)
            elif _should_override(task_id, action, obs, current_step, max_steps):
                old_at = action.get("action_type")
                action = _smart_fallback(task_id, obs, current_step, max_steps)
                print(f"  [OVERRIDE] step {current_step}: "
                      f"{old_at} -> {action.get('action_type')}", file=sys.stderr)

            # Execute step
            sr = _session.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
            sr.raise_for_status()
            result = sr.json()
            new_obs = result["observation"]

            step_reward = result["reward"]["value"]
            done = result["done"]
            error_raw = new_obs.get("last_action_error")

            rewards_list.append(step_reward)
            steps_used = current_step

            # ── [STEP] ──
            done_str = "true" if done else "false"
            error_str = _fmt_error(error_raw)
            action_str = _fmt_action(action)
            print(
                f"[STEP] step={current_step} action={action_str} "
                f"reward={step_reward:.2f} done={done_str} error={error_str}",
                flush=True,
            )

            # Debug to stderr
            print(
                f"    step {current_step:>2}: {action.get('action_type'):<28} "
                f"reward={step_reward:+.3f}  done={done}",
                file=sys.stderr,
            )

            if done:
                break

            # Build the next prompt BEFORE updating prev_queried, so the
            # message can diff the new observation against the previous one.
            step_msg = _step_msg(new_obs, prev_queried)
            messages.append({"role": "user", "content": step_msg})
            prev_queried = {
                k: dict(v)
                for k, v in new_obs.get("queried_data", {}).items()
                if isinstance(v, dict)
            }
            obs = new_obs

            # Bound the history: keep the system prompt and the briefing plus
            # the 16 most recent messages.
            if len(messages) > 20:
                messages = messages[:2] + messages[-16:]

        # Grade
        g = _session.get(f"{ENV_BASE_URL}/grader", timeout=30)
        g.raise_for_status()
        score = g.json().get("total", 0.0)

    except Exception as e:
        print(f"  [ERROR] {task_id} scenario {scenario_index}: {e}", file=sys.stderr)
        # If we haven't emitted any steps yet, emit a failure step
        if steps_used == 0:
            steps_used = 1
            rewards_list.append(0.0)
            print(
                f"[STEP] step=1 action=error reward=0.00 done=true "
                f"error={_fmt_error(str(e))}",
                flush=True,
            )

    # ── [END] ── (always emitted, even on exception)
    # Clamp score to the range [0.01, 0.99] for validator compliance
    score = max(0.01, min(0.99, score))
    # A score at the lower clamp bound counts as failure.
    success_str = "true" if score > 0.01 else "false"
    rewards_str = ",".join(f"{rw:.2f}" for rw in rewards_list)
    print(
        f"[END] success={success_str} steps={steps_used} "
        f"score={score:.2f} rewards={rewards_str}",
        flush=True,
    )


    return score, steps_used, rewards_list


# ── Main ────────────────────────────────────────────────────────────────────

def main():
    """Run every baseline episode and report average scores per task.

    Three scenarios are run for each of the three tasks. Structured episode
    records go to stdout (printed by ``_run_episode_structured``); the banner,
    per-episode lines and summary table go to stderr. The last stdout line is
    a JSON object mapping each task to its average score, plus ``overall``
    (the mean of the per-task averages).
    """
    runs = [
        ("alert_classification", 0),
        ("alert_classification", 1),
        ("alert_classification", 2),
        ("root_cause_analysis", 0),
        ("root_cause_analysis", 1),
        ("root_cause_analysis", 2),
        ("remediation_planning", 0),
        ("remediation_planning", 1),
        ("remediation_planning", 2),
    ]

    # Display labels; the previous literals were mis-encoded emoji.
    _DIFFICULTY = {
        "alert_classification": "🟢 Easy",
        "root_cause_analysis": "🟡 Medium",
        "remediation_planning": "🔴 Hard",
    }

    results: dict[str, list[dict]] = {}

    # Banner to stderr (not stdout — structured output only on stdout)
    print("", file=sys.stderr)
    print("=" * 100, file=sys.stderr)
    print("  ☁️  CLOUD INCIDENT RESPONSE — BASELINE INFERENCE", file=sys.stderr)
    print("=" * 100, file=sys.stderr)
    print(f"  Model:    {MODEL_NAME}", file=sys.stderr)
    print(f"  Endpoint: {API_BASE_URL}", file=sys.stderr)
    print("=" * 100, file=sys.stderr)
    print("", file=sys.stderr)

    for task_id, scenario_index in runs:
        score, steps_used, rewards_list = _run_episode_structured(task_id, scenario_index)

        difficulty = _DIFFICULTY.get(task_id, "?")
        cumulative_reward = sum(rewards_list)

        # Summary per episode to stderr
        print(
            f"  {task_id:<24} {difficulty:<12} scenario={scenario_index} "
            f"steps={steps_used} reward={cumulative_reward:+.4f} score={score:.4f}",
            file=sys.stderr,
        )

        results.setdefault(task_id, []).append({
            "scenario": scenario_index,
            "score": score,
            "steps": steps_used,
            "reward": cumulative_reward,
        })

    # Summary to stderr
    print("", file=sys.stderr)
    print("=" * 100, file=sys.stderr)
    print("  📊 SUMMARY BY TASK", file=sys.stderr)
    print("=" * 100, file=sys.stderr)

    summary = {}
    for task_id in ["alert_classification", "root_cause_analysis", "remediation_planning"]:
        if task_id not in results:
            continue
        data = results[task_id]
        avg_score = sum(d["score"] for d in data) / len(data)
        scenario_scores = " | ".join(f'{d["score"]:.2f}' for d in data)
        difficulty = _DIFFICULTY.get(task_id, "?")

        print(f"  {task_id:<24} {difficulty:<12} avg={avg_score:.4f}  [{scenario_scores}]",
              file=sys.stderr)
        summary[task_id] = round(avg_score, 4)

    # "overall" is computed before it is inserted, so it averages only the
    # per-task entries.
    if summary:
        summary["overall"] = round(sum(summary.values()) / len(summary), 4)
    else:
        summary["overall"] = 0.0

    print(f"  {'OVERALL':<24} {'':12} avg={summary['overall']:.4f}", file=sys.stderr)
    print("=" * 100, file=sys.stderr)

    # JSON summary as the LAST line of stdout (for /baseline endpoint compatibility)
    print(json.dumps(summary), flush=True)


if __name__ == "__main__":
    main()