Elliot89 committed on
Commit
3dc9d8d
·
verified ·
1 Parent(s): 07a26d7

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +182 -202
inference.py CHANGED
@@ -3,6 +3,11 @@ inference.py — Cloud Incident Response OpenEnv baseline inference script.
3
 
4
  The LLM reasons from evidence. Fallback is a dumb safety net that scores low.
5
  Override only blocks clearly invalid actions (wrong task submission, bad params).
 
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
@@ -14,11 +19,11 @@ import time
14
 
15
  import requests
16
  import time as _time
17
- _START = _time.time()
18
  _MAX_RUNTIME = 1080
19
 
20
  def _check_timeout():
21
- if _time.time() - _START > _MAX_RUNTIME:
22
  raise RuntimeError("Approaching 20min limit — stopping early")
23
  try:
24
  from dotenv import load_dotenv
@@ -31,6 +36,7 @@ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
31
  MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.1-8b-instant")
32
  API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or ""
33
  ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
 
34
 
35
  if not API_KEY:
36
  print("[WARN] No API key set — LLM calls will fail.", file=sys.stderr)
@@ -404,83 +410,163 @@ def _llm_call_with_retry(messages: list, max_retries: int = 1) -> str:
404
  return ""
405
 
406
 
407
- def _run_episode(task_id: str, scenario_index: int) -> float:
408
- if _time.time() - _START > _MAX_RUNTIME:
409
- print(f" [TIMEOUT] Approaching 20min limit — skipping {task_id} s{scenario_index}",
410
- file=sys.stderr)
411
- return 0.0
412
- _check_timeout()
413
 
414
- r = _session.post(
415
- f"{ENV_BASE_URL}/reset",
416
- params={"task_id": task_id, "scenario_index": scenario_index},
417
- timeout=30,
418
- )
419
- r.raise_for_status()
420
- obs = r.json()
 
 
 
 
421
 
422
- messages = [
423
- {"role": "system", "content": SYSTEM_PROMPT},
424
- {"role": "user", "content": _first_obs_msg(obs)},
425
- ]
426
 
427
- prev_queried: dict = {}
428
- max_steps = obs.get("max_steps", 10)
 
 
 
 
429
 
430
- for step_i in range(max_steps):
431
- current_step = step_i + 1
432
 
433
- raw = _llm_call_with_retry(messages)
434
- messages.append({"role": "assistant", "content": raw or "{}"})
435
 
436
- action = None
437
- try:
438
- if raw.strip():
439
- action = _parse(raw)
440
- except Exception:
441
- pass
442
-
443
- if action is None:
444
- action = _smart_fallback(task_id, obs, current_step, max_steps)
445
- print(f" [FALLBACK] step {current_step}: "
446
- f"{action.get('action_type')}", file=sys.stderr)
447
- elif _should_override(task_id, action, obs, current_step, max_steps):
448
- old_at = action.get("action_type")
449
- action = _smart_fallback(task_id, obs, current_step, max_steps)
450
- print(f" [OVERRIDE] step {current_step}: "
451
- f"{old_at} -> {action.get('action_type')}", file=sys.stderr)
452
-
453
- sr = _session.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
454
- sr.raise_for_status()
455
- result = sr.json()
456
- new_obs = result["observation"]
457
 
458
- print(
459
- f" step {current_step:>2}: {action.get('action_type'):<28} "
460
- f"reward={result['reward']['value']:+.3f} "
461
- f"done={result['done']}",
462
- file=sys.stderr,
 
 
 
 
 
 
463
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
- if result.get("done"):
466
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
- step_msg = _step_msg(new_obs, prev_queried)
469
- messages.append({"role": "user", "content": step_msg})
470
- prev_queried = {
471
- k: dict(v)
472
- for k, v in new_obs.get("queried_data", {}).items()
473
- if isinstance(v, dict)
474
- }
475
- obs = new_obs
476
 
477
- if len(messages) > 20:
478
- messages = messages[:2] + messages[-16:]
479
 
480
- g = _session.get(f"{ENV_BASE_URL}/grader", timeout=30)
481
- g.raise_for_status()
482
- return g.json().get("total", 0.0)
483
 
 
484
 
485
  def main():
486
  runs = [
@@ -501,60 +587,43 @@ def main():
501
  "remediation_planning": "🔴 Hard",
502
  }
503
 
504
- _MAX_STEPS = {
505
- "alert_classification": 3,
506
- "root_cause_analysis": 10,
507
- "remediation_planning": 15,
508
- }
509
-
510
  results: dict[str, list[dict]] = {}
511
 
512
- print()
513
- print("=" * 100)
514
- print(" ☁️ CLOUD INCIDENT RESPONSE — BASELINE INFERENCE")
515
- print("=" * 100)
516
- print(f" Model: {MODEL_NAME}")
517
- print(f" Endpoint: {API_BASE_URL}")
518
- print("=" * 100)
519
- print()
520
-
521
- # Table header
522
- print(f"{'Task':<24} {'Difficulty':<12} {'Scenario':>8} {'Steps':>10} {'Actions':>10} {'Reward':>10} {'Score':>10}")
523
- print("─" * 100)
524
 
525
  for task_id, scenario_index in runs:
526
- try:
527
- score, steps_used, actions_taken, cumulative_reward = _run_episode_detailed(task_id, scenario_index)
528
- except Exception as e:
529
- print(f" [ERROR] {task_id} scenario {scenario_index}: {e}", file=sys.stderr)
530
- score, steps_used, actions_taken, cumulative_reward = 0.0, 0, 0, 0.0
531
 
532
  difficulty = _DIFFICULTY.get(task_id, "?")
533
- max_steps = _MAX_STEPS.get(task_id, "?")
534
- steps_display = f"{steps_used}/{max_steps}"
535
 
 
536
  print(
537
- f"{task_id:<24} {difficulty:<12} {scenario_index:>8} "
538
- f"{steps_display:>10} {actions_taken:>10} {cumulative_reward:>+10.4f} {score:>10.4f}"
 
539
  )
540
 
541
  results.setdefault(task_id, []).append({
542
  "scenario": scenario_index,
543
  "score": score,
544
  "steps": steps_used,
545
- "actions": actions_taken,
546
  "reward": cumulative_reward,
547
  })
548
 
549
- print("─" * 100)
550
- print()
551
-
552
- # Summary table
553
- print("=" * 100)
554
- print(" 📊 SUMMARY BY TASK")
555
- print("=" * 100)
556
- print(f"{'Task':<24} {'Difficulty':<12} {'Avg Score':>10} {'Avg Steps':>10} {'Scenarios':>20}")
557
- print("─" * 100)
558
 
559
  summary = {}
560
  for task_id in ["alert_classification", "root_cause_analysis", "remediation_planning"]:
@@ -562,113 +631,24 @@ def main():
562
  continue
563
  data = results[task_id]
564
  avg_score = sum(d["score"] for d in data) / len(data)
565
- avg_steps = sum(d["steps"] for d in data) / len(data)
566
  scenario_scores = " | ".join(f'{d["score"]:.2f}' for d in data)
567
  difficulty = _DIFFICULTY.get(task_id, "?")
568
 
569
- print(f"{task_id:<24} {difficulty:<12} {avg_score:>10.4f} {avg_steps:>10.1f} {scenario_scores:>20}")
 
570
  summary[task_id] = round(avg_score, 4)
571
 
572
- summary["overall"] = round(sum(summary.values()) / len(summary), 4)
573
-
574
- print("─" * 100)
575
- print(f"{'OVERALL':<24} {'':12} {summary['overall']:>10.4f}")
576
- print("=" * 100)
577
- print()
578
-
579
- # Difficulty progression check
580
- easy = summary.get("alert_classification", 0)
581
- med = summary.get("root_cause_analysis", 0)
582
- hard = summary.get("remediation_planning", 0)
583
-
584
- if easy > med > hard:
585
- print(" ✅ Difficulty Progression: Easy (%.2f) > Medium (%.2f) > Hard (%.2f)" % (easy, med, hard))
586
- elif easy > med and easy > hard:
587
- print(" ⚠️ Difficulty Progression: Easy highest, Medium ≈ Hard")
588
  else:
589
- print(" Difficulty Progression: Unexpected order")
590
-
591
- print()
592
- print(json.dumps(summary))
593
-
594
-
595
- def _run_episode_detailed(task_id: str, scenario_index: int) -> tuple[float, int, int, float]:
596
- """Run episode and return (score, steps_used, actions_taken, cumulative_reward)."""
597
- r = _session.post(
598
- f"{ENV_BASE_URL}/reset",
599
- params={"task_id": task_id, "scenario_index": scenario_index},
600
- timeout=30,
601
- )
602
- r.raise_for_status()
603
- obs = r.json()
604
 
605
- messages = [
606
- {"role": "system", "content": SYSTEM_PROMPT},
607
- {"role": "user", "content": _first_obs_msg(obs)},
608
- ]
609
 
610
- prev_queried: dict = {}
611
- max_steps = obs.get("max_steps", 10)
612
- actions_taken = 0
613
- cumulative_reward = 0.0
614
-
615
- for step_i in range(max_steps):
616
- current_step = step_i + 1
617
-
618
- raw = _llm_call_with_retry(messages)
619
- messages.append({"role": "assistant", "content": raw or "{}"})
620
-
621
- action = None
622
- try:
623
- if raw.strip():
624
- action = _parse(raw)
625
- except Exception:
626
- pass
627
-
628
- if action is None:
629
- action = _smart_fallback(task_id, obs, current_step, max_steps)
630
- print(f" [FALLBACK] step {current_step}: {action.get('action_type')}", file=sys.stderr)
631
- elif _should_override(task_id, action, obs, current_step, max_steps):
632
- old_at = action.get("action_type")
633
- action = _smart_fallback(task_id, obs, current_step, max_steps)
634
- print(f" [OVERRIDE] step {current_step}: {old_at} -> {action.get('action_type')}", file=sys.stderr)
635
-
636
- sr = _session.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
637
- sr.raise_for_status()
638
- result = sr.json()
639
- new_obs = result["observation"]
640
-
641
- actions_taken += 1
642
- step_reward = result['reward']['value']
643
- cumulative_reward = result['reward'].get('cumulative', cumulative_reward + step_reward)
644
-
645
- # Step detail output
646
- print(
647
- f" step {current_step:>2}: {action.get('action_type'):<28} "
648
- f"reward={step_reward:+.3f} done={result['done']}"
649
- )
650
-
651
- if result.get("done"):
652
- break
653
-
654
- step_msg = _step_msg(new_obs, prev_queried)
655
- messages.append({"role": "user", "content": step_msg})
656
- prev_queried = {
657
- k: dict(v)
658
- for k, v in new_obs.get("queried_data", {}).items()
659
- if isinstance(v, dict)
660
- }
661
- obs = new_obs
662
-
663
- if len(messages) > 20:
664
- messages = messages[:2] + messages[-16:]
665
-
666
- g = _session.get(f"{ENV_BASE_URL}/grader", timeout=30)
667
- g.raise_for_status()
668
- score = g.json().get("total", 0.0)
669
-
670
- return score, current_step, actions_taken, cumulative_reward
671
 
672
 
673
  if __name__ == "__main__":
674
- main()
 
3
 
4
  The LLM reasons from evidence. Fallback is a dumb safety net that scores low.
5
  Override only blocks clearly invalid actions (wrong task submission, bad params).
6
+
7
+ STRUCTURED OUTPUT:
8
+ [START] task=<task_name> env=cloud-incident-response model=<model_name>
9
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
10
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
11
  """
12
 
13
  from __future__ import annotations
 
19
 
20
  import requests
21
  import time as _time
22
+ _START_TIME = _time.time()
23
  _MAX_RUNTIME = 1080
24
 
25
  def _check_timeout():
26
+ if _time.time() - _START_TIME > _MAX_RUNTIME:
27
  raise RuntimeError("Approaching 20min limit — stopping early")
28
  try:
29
  from dotenv import load_dotenv
 
36
  MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.1-8b-instant")
37
  API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or ""
38
  ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
39
+ ENV_NAME = "cloud-incident-response"
40
 
41
  if not API_KEY:
42
  print("[WARN] No API key set — LLM calls will fail.", file=sys.stderr)
 
410
  return ""
411
 
412
 
413
+ # ── Structured Output Helpers ───────────────────────────────────────────────
 
 
 
 
 
414
 
415
+ def _fmt_action(action: dict) -> str:
416
+ """Format action as a compact string for [STEP] output."""
417
+ at = action.get("action_type", "unknown")
418
+ params = action.get("parameters", {})
419
+ parts = []
420
+ for k, v in params.items():
421
+ if v is not None and v != "":
422
+ parts.append(f"{k}={v}")
423
+ if parts:
424
+ return f"{at}({', '.join(parts)})"
425
+ return at
426
 
 
 
 
 
427
 
428
+ def _fmt_error(error_val) -> str:
429
+ """Format error for [STEP] output — return 'null' if no error."""
430
+ if error_val is None or error_val == "" or error_val == "null":
431
+ return "null"
432
+ # Sanitize: remove newlines to keep [STEP] on a single line
433
+ return str(error_val).replace("\n", " ").replace("\r", "")
434
 
 
 
435
 
436
+ # ── Episode Runner with Structured Output ───────────────────────────────────
 
437
 
438
def _run_episode_structured(task_id: str, scenario_index: int) -> tuple[float, int, list[float]]:
    """Run a single episode and emit [START]/[STEP]/[END] records on stdout.

    The environment is reset for ``task_id``/``scenario_index``, then the LLM
    is asked for one action per step until the environment reports ``done``
    or ``max_steps`` (taken from the reset observation, default 10) is
    reached. Finally the ``/grader`` endpoint is queried for the score.
    Diagnostics go to stderr; stdout carries only the structured records.

    Any exception raised while running the episode is caught and logged to
    stderr; the score then stays at 0.0 and [END] is still emitted.

    Args:
        task_id: Task identifier sent to the environment's ``/reset`` endpoint.
        scenario_index: Scenario index sent to ``/reset``.

    Returns:
        ``(score, steps_used, rewards_list)``: the grader's ``total`` (0.0 on
        failure), the number of the last step emitted, and the per-step rewards.
    """
    rewards_list: list[float] = []
    steps_used = 0
    score = 0.0

    # ── [START] ──
    print(f"[START] task={task_id} env={ENV_NAME} model={MODEL_NAME}", flush=True)

    try:
        _check_timeout()

        # Reset environment
        r = _session.post(
            f"{ENV_BASE_URL}/reset",
            params={"task_id": task_id, "scenario_index": scenario_index},
            timeout=30,
        )
        r.raise_for_status()
        obs = r.json()

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": _first_obs_msg(obs)},
        ]

        prev_queried: dict = {}
        max_steps = obs.get("max_steps", 10)

        for step_i in range(max_steps):
            current_step = step_i + 1

            # Get LLM action
            raw = _llm_call_with_retry(messages)
            messages.append({"role": "assistant", "content": raw or "{}"})

            action = None
            try:
                if raw and raw.strip():
                    action = _parse(raw)
            except Exception as parse_err:
                # Best-effort: an unparseable reply falls through to the
                # fallback policy below; record why instead of hiding it.
                print(f" [PARSE] step {current_step}: {parse_err}", file=sys.stderr)

            # Every later use calls action.get(...), so anything that is not a
            # dict is treated like an unparseable reply.
            if action is not None and not isinstance(action, dict):
                action = None

            if action is None:
                action = _smart_fallback(task_id, obs, current_step, max_steps)
                print(f" [FALLBACK] step {current_step}: "
                      f"{action.get('action_type')}", file=sys.stderr)
            elif _should_override(task_id, action, obs, current_step, max_steps):
                old_at = action.get("action_type")
                action = _smart_fallback(task_id, obs, current_step, max_steps)
                print(f" [OVERRIDE] step {current_step}: "
                      f"{old_at} -> {action.get('action_type')}", file=sys.stderr)

            # Execute step
            sr = _session.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
            sr.raise_for_status()
            result = sr.json()
            new_obs = result["observation"]

            step_reward = result["reward"]["value"]
            done = result["done"]
            error_raw = new_obs.get("last_action_error")

            rewards_list.append(step_reward)
            steps_used = current_step

            # ── [STEP] ──
            done_str = "true" if done else "false"
            error_str = _fmt_error(error_raw)
            action_str = _fmt_action(action)
            print(
                f"[STEP] step={current_step} action={action_str} "
                f"reward={step_reward:.2f} done={done_str} error={error_str}",
                flush=True,
            )

            # Debug to stderr. str() is required: a missing action_type is
            # None, and None rejects a non-empty format spec such as `<28`.
            print(
                f" step {current_step:>2}: {str(action.get('action_type')):<28} "
                f"reward={step_reward:+.3f} done={done}",
                file=sys.stderr,
            )

            if done:
                break

            step_msg = _step_msg(new_obs, prev_queried)
            messages.append({"role": "user", "content": step_msg})
            prev_queried = {
                k: dict(v)
                for k, v in new_obs.get("queried_data", {}).items()
                if isinstance(v, dict)
            }
            obs = new_obs

            # Bound the prompt: keep the system + first observation messages
            # and the 16 most recent turns.
            if len(messages) > 20:
                messages = messages[:2] + messages[-16:]

        # Grade
        g = _session.get(f"{ENV_BASE_URL}/grader", timeout=30)
        g.raise_for_status()
        score = g.json().get("total", 0.0)

    except Exception as e:
        print(f" [ERROR] {task_id} scenario {scenario_index}: {e}", file=sys.stderr)
        # If we haven't emitted any steps yet, emit a failure step
        if steps_used == 0:
            steps_used = 1
            rewards_list.append(0.0)
            print(
                f"[STEP] step=1 action=error reward=0.00 done=true "
                f"error={_fmt_error(str(e))}",
                flush=True,
            )

    # ── [END] ── (always emitted, even on exception)
    success_str = "true" if score > 0 else "false"
    rewards_str = ",".join(f"{rw:.2f}" for rw in rewards_list)
    print(
        f"[END] success={success_str} steps={steps_used} "
        f"score={score:.2f} rewards={rewards_str}",
        flush=True,
    )

    return score, steps_used, rewards_list
 
567
 
 
 
 
568
 
569
+ # ── Main ────────────────────────────────────────────────────────────────────
570
 
571
  def main():
572
  runs = [
 
587
  "remediation_planning": "🔴 Hard",
588
  }
589
 
 
 
 
 
 
 
590
  results: dict[str, list[dict]] = {}
591
 
592
+ # Banner to stderr (not stdout — structured output only on stdout)
593
+ print("", file=sys.stderr)
594
+ print("=" * 100, file=sys.stderr)
595
+ print(" ☁️ CLOUD INCIDENT RESPONSE — BASELINE INFERENCE", file=sys.stderr)
596
+ print("=" * 100, file=sys.stderr)
597
+ print(f" Model: {MODEL_NAME}", file=sys.stderr)
598
+ print(f" Endpoint: {API_BASE_URL}", file=sys.stderr)
599
+ print("=" * 100, file=sys.stderr)
600
+ print("", file=sys.stderr)
 
 
 
601
 
602
  for task_id, scenario_index in runs:
603
+ score, steps_used, rewards_list = _run_episode_structured(task_id, scenario_index)
 
 
 
 
604
 
605
  difficulty = _DIFFICULTY.get(task_id, "?")
606
+ cumulative_reward = sum(rewards_list)
 
607
 
608
+ # Summary per episode to stderr
609
  print(
610
+ f" {task_id:<24} {difficulty:<12} scenario={scenario_index} "
611
+ f"steps={steps_used} reward={cumulative_reward:+.4f} score={score:.4f}",
612
+ file=sys.stderr,
613
  )
614
 
615
  results.setdefault(task_id, []).append({
616
  "scenario": scenario_index,
617
  "score": score,
618
  "steps": steps_used,
 
619
  "reward": cumulative_reward,
620
  })
621
 
622
+ # Summary to stderr
623
+ print("", file=sys.stderr)
624
+ print("=" * 100, file=sys.stderr)
625
+ print(" 📊 SUMMARY BY TASK", file=sys.stderr)
626
+ print("=" * 100, file=sys.stderr)
 
 
 
 
627
 
628
  summary = {}
629
  for task_id in ["alert_classification", "root_cause_analysis", "remediation_planning"]:
 
631
  continue
632
  data = results[task_id]
633
  avg_score = sum(d["score"] for d in data) / len(data)
 
634
  scenario_scores = " | ".join(f'{d["score"]:.2f}' for d in data)
635
  difficulty = _DIFFICULTY.get(task_id, "?")
636
 
637
+ print(f" {task_id:<24} {difficulty:<12} avg={avg_score:.4f} [{scenario_scores}]",
638
+ file=sys.stderr)
639
  summary[task_id] = round(avg_score, 4)
640
 
641
+ if summary:
642
+ summary["overall"] = round(sum(summary.values()) / len(summary), 4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  else:
644
+ summary["overall"] = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
 
646
+ print(f" {'OVERALL':<24} {'':12} avg={summary['overall']:.4f}", file=sys.stderr)
647
+ print("=" * 100, file=sys.stderr)
 
 
648
 
649
+ # JSON summary as the LAST line of stdout (for /baseline endpoint compatibility)
650
+ print(json.dumps(summary), flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
 
652
 
653
  if __name__ == "__main__":
654
+ main()