"""
Inference Script — PLL Cyberattack Detection OpenEnv
=====================================================

Environment variables (injected by the judging sandbox):
    API_BASE_URL   LiteLLM proxy endpoint (MUST be used for all LLM calls)
    API_KEY        LiteLLM proxy key (MUST be used — do not hardcode keys)
    MODEL_NAME     Model identifier
    ENV_URL        Environment server URL (default: http://localhost:7860)

STDOUT FORMAT (OpenEnv compliance):
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_json> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""

import os
import sys
import json
import time
import requests
from typing import List, Optional, Tuple
from openai import OpenAI

try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# ── Config — always read from environment, never hardcode ─────────────────────
# The judging sandbox injects API_BASE_URL and API_KEY via their LiteLLM proxy.
# All LLM calls MUST go through these values or the submission will be rejected.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# An empty API_KEY / HF_TOKEN is treated the same as an unset one.
API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN") or "dummy"
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")

# OpenAI client pointed at the proxy — never bypass this
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

# Persistent HTTP session for env calls — avoids TCP handshake per step
_session = requests.Session()

# ── Task metadata ─────────────────────────────────────────────────────────────
TASK_NAMES = {
    0: "Sinusoidal FDI Detection (Easy)",
    1: "Multi-Attack Classification (Medium)",
    2: "Stealthy Attack Detection (Hard)",
}

BENCHMARK = "pll-cyberattack-detection"

# Action returned whenever no better decision is available. Always hand out a
# copy so callers cannot mutate the shared template.
DEFAULT_ACTION = {
    "attack_detected": False,
    "attack_type": 0,
    "confidence": 0.5,
    "protective_action": 0,
}

# ── System prompt ─────────────────────────────────────────────────────────────
# The value ranges in the JSON template mirror the clamps applied in
# parse_llm_response().
SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
You receive time-windowed sensor readings each step and must detect cyberattacks.

vq_window: q-axis voltage error (should be ~0 when healthy)
vd_window: d-axis voltage
omega_window: estimated frequency (normalized, nominal=0)
omega_deviation_window: frequency deviation from nominal in rad/s (useful for detecting slow phase drift)
raw_voltages: [va, vb, vc] at current step
task_id: 0=detect only, 1=classify type, 2=detect stealthy attack

For task_id=0: Focus on detecting any attack (attack_detected=True/False).
For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.

Analysis tips:
- In healthy state, vq values should be near 0 and stable.
- Sinusoidal attacks cause oscillating patterns in vq.
- Ramp attacks cause steadily increasing vq magnitude.
- Pulse attacks cause sudden step changes in vq.
- Stealthy attacks cause very slow, gradual drift in omega_deviation_window.
- Look at trends across the full window, not just the latest value.

Respond ONLY with valid JSON, no explanation:
{
  "attack_detected": <true|false>,
  "attack_type": <int 0-4>,
  "confidence": <float 0.0-1.0>,
  "protective_action": <int 0-3>
}"""


# ── Logging helpers ───────────────────────────────────────────────────────────
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode on stdout."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
    """Print one ``[STEP]`` line on stdout.

    The action is serialized as compact JSON (no spaces) and a falsy
    ``error`` is rendered as the literal ``null``.
    """
    action_str = json.dumps(action, separators=(',', ':'))
    error_val = error if error else "null"
    print(
        f"[STEP] step={step} action={action_str} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line that closes an episode on stdout."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


# ── Detector Agent & Smart Blending ───────────────────────────────────────────
def detector_agent(prev_info: dict) -> Optional[dict]:
    """Return the action suggested by ``prev_info["detector"]``, if any.

    Returns ``None`` when the info dict has no detector entry or the entry
    lacks an ``attack_detected`` key. Missing fields fall back to the same
    values as ``DEFAULT_ACTION``.
    """
    det = prev_info.get("detector", {})
    if not det or "attack_detected" not in det:
        return None
    return {
        "attack_detected": det.get("attack_detected", False),
        "attack_type": det.get("attack_type", 0),
        "confidence": det.get("confidence", 0.5),
        "protective_action": det.get("protective_action", 0),
    }


def smart_blend_agent(obs: dict, prev_info: dict) -> dict:
    """Use the detector's action when it is confident, else the heuristic's.

    The heuristic is evaluated unconditionally because it accumulates
    per-episode history as a side effect; skipping it would leave gaps.
    """
    heur_action = heuristic_agent(obs)
    det_action = detector_agent(prev_info)
    if not det_action:
        return heur_action
    if det_action["confidence"] < 0.5:
        return heur_action
    return det_action


# ── Rule-Based Heuristic Agent ────────────────────────────────────────────────
class HeuristicState:
    """Tracks running state for the heuristic agent across steps."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all per-episode state."""
        self.vq_history = []          # mean |vq| of each window seen so far
        self.omega_dev_history = []   # mean |omega deviation| of each window
        self.attack_detected = False  # latched once any window looks attacked
        self.predicted_type = 0       # last attack type chosen for task 1
        self.settled_baseline = None  # omega-deviation reference for task 2
        self.peak_vq = 0.0            # largest mean |vq| seen this episode


_hstate = HeuristicState()


def _classify_attack_type(vq: List[float], vq_mean: float) -> int:
    """Pick an attack type (1=sinusoidal, 2=ramp, 3=pulse) for task 1.

    Reads the accumulated history in ``_hstate``; ``vq`` is the current raw
    window and ``vq_mean`` its mean absolute value.
    """
    n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
    if n_elevated < 5:
        # Too little evidence to separate the types yet.
        return 1

    elevated = [v for v in _hstate.vq_history if v > 0.005]
    recent = elevated[-20:]
    current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
    zero_crossings = sum(1 for a, b in zip(vq, vq[1:]) if a * b < 0)

    if len(recent) >= 6:
        # Compare equally sized head and tail slices. `-len(recent) // 3`
        # would parse as `(-len(recent)) // 3` and floor toward -inf, making
        # the tail one element longer than its divisor.
        third = len(recent) // 3
        first_third = sum(recent[:third]) / third
        last_third = sum(recent[-third:]) / third
        growth = last_third / first_third if first_third > 0.001 else 1.0
    else:
        growth = 1.0

    if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
        return 3
    if current_vs_peak < 0.4 and n_elevated > 30:
        return 3
    if zero_crossings >= 2 and growth < 1.5:
        return 1
    if growth > 1.3:
        return 2
    if zero_crossings >= 1:
        return 1

    # No sign change and no growth: a mostly-decaying window is treated as a
    # pulse. NOTE(review): the threshold of 14 presumably assumes a 20-sample
    # window (19 differences) — confirm against the environment.
    falling = sum(1 for prev, cur in zip(vq, vq[1:]) if cur - prev < 0)
    return 3 if falling > 14 else 1


def _detect_stealthy_drift(
    step: int, omega_dev_mean: float, vq_mean: float
) -> Tuple[bool, float]:
    """Return ``(drift_detected, confidence)`` for task 2.

    Detection is only attempted after step 50 and once a baseline has been
    captured; before that the result is ``(False, 0.3)``.
    """
    if step <= 50 or _hstate.settled_baseline is None:
        return False, 0.3

    baseline = _hstate.settled_baseline
    # A near-zero baseline would blow up the ratio, so scale instead.
    ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100

    history = _hstate.omega_dev_history
    if len(history) > 10:
        recent_10 = history[-10:]
        old_10 = history[-20:-10] if len(history) > 20 else history[:10]
        recent_avg = sum(recent_10) / len(recent_10)
        old_avg = sum(old_10) / len(old_10)
        rising = recent_avg > old_avg * 1.1
    else:
        rising = False

    if ratio > 2.0:
        return True, 0.9
    if ratio > 1.3 and rising:
        return True, 0.8
    if rising and vq_mean > 0.1:
        return True, 0.6
    if vq_mean > 0.2:
        return True, 0.5
    return False, 0.3


def heuristic_agent(obs: dict) -> dict:
    """Rule-based agent driven by per-episode running statistics.

    Called through ``smart_blend_agent`` on every step that is not served by
    the LLM, and on all steps once the LLM circuit breaker has tripped.

    Reads ``vq_window``, ``omega_deviation_window``, ``task_id`` and ``step``
    from ``obs``, updates the module-level ``_hstate`` and returns an action
    dict with the same keys as ``DEFAULT_ACTION``.
    """
    vq = obs["vq_window"]
    omega_dev = obs["omega_deviation_window"]
    task_id = obs["task_id"]
    step = obs["step"]

    if step == 0:
        _hstate.reset()

    # Empty windows are treated as all-zero instead of raising.
    vq_abs = [abs(v) for v in vq]
    vq_mean = sum(vq_abs) / len(vq_abs) if vq_abs else 0.0
    vq_max = max(vq_abs, default=0.0)
    omega_dev_abs = [abs(v) for v in omega_dev]
    omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs) if omega_dev_abs else 0.0

    _hstate.vq_history.append(vq_mean)
    _hstate.omega_dev_history.append(omega_dev_mean)
    _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)

    # Capture the baseline on the first call at or after step 50. Matching
    # only `step == 50` misses it whenever that step is handled by the LLM
    # (run_episode sends every 10th step there), which would leave task 2
    # detection permanently disabled.
    if _hstate.settled_baseline is None and step >= 50:
        _hstate.settled_baseline = omega_dev_mean

    # The first 25 steps are ignored so start-up transients do not latch.
    detected = False if step < 25 else (vq_mean > 0.01 or vq_max > 0.025)
    if detected:
        _hstate.attack_detected = True

    # ── Task 0: binary detection ──────────────────────────────────────────────
    if task_id == 0:
        return {
            "attack_detected": _hstate.attack_detected,
            "attack_type": 1 if _hstate.attack_detected else 0,
            "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
            "protective_action": 1 if _hstate.attack_detected else 0,
        }

    # ── Task 1: classification ────────────────────────────────────────────────
    if task_id == 1:
        if not _hstate.attack_detected:
            return {
                "attack_detected": False,
                "attack_type": 0,
                "confidence": 0.7,
                "protective_action": 0,
            }
        _hstate.predicted_type = _classify_attack_type(vq, vq_mean)
        return {
            "attack_detected": True,
            "attack_type": _hstate.predicted_type,
            "confidence": 0.8,
            "protective_action": 1,
        }

    # ── Task 2: stealthy attack ───────────────────────────────────────────────
    if task_id == 2:
        drift_detected, confidence = _detect_stealthy_drift(
            step, omega_dev_mean, vq_mean
        )
        if drift_detected:
            _hstate.attack_detected = True
        return {
            "attack_detected": drift_detected,
            "attack_type": 4 if drift_detected else 0,
            "confidence": confidence,
            "protective_action": 2 if drift_detected else 0,
        }

    return DEFAULT_ACTION.copy()


# ── LLM agent (PRIMARY — always called first) ─────────────────────────────────
def _strip_code_fence(text: str) -> str:
    """Return the contents of a leading Markdown code fence, or ``text`` as-is."""
    if not text.startswith("```"):
        return text
    body: List[str] = []
    inside = False
    for line in text.split("\n"):
        is_fence = line.strip().startswith("```")
        if is_fence and not inside:
            inside = True
        elif is_fence:
            break
        elif inside:
            body.append(line)
    return "\n".join(body)


def _load_json_object(text: str) -> Optional[dict]:
    """Extract a JSON object from ``text``; return ``None`` if there is none.

    Tries the (de-fenced) text first, then the outermost ``{...}`` span so
    that replies wrapped in prose or a single-line fence still parse.
    """
    candidates = [_strip_code_fence(text)]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _coerce_bool(value) -> bool:
    """Interpret ``value`` as a boolean, treating the string "false" as False."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    return bool(value)


def parse_llm_response(response_text: str) -> dict:
    """Turn raw LLM output into a validated action dict.

    Values are clamped to ``attack_type`` 0-4, ``confidence`` 0.0-1.0 and
    ``protective_action`` 0-3. Anything that cannot be parsed or converted
    yields a copy of ``DEFAULT_ACTION``; this function never raises, so a
    malformed reply cannot trip the circuit breaker in ``llm_agent``.
    """
    parsed = _load_json_object((response_text or "").strip())
    if parsed is None:
        return DEFAULT_ACTION.copy()
    try:
        return {
            "attack_detected": _coerce_bool(parsed.get("attack_detected", False)),
            "attack_type": max(0, min(4, int(parsed.get("attack_type", 0)))),
            "confidence": max(0.0, min(1.0, float(parsed.get("confidence", 0.5)))),
            "protective_action": max(0, min(3, int(parsed.get("protective_action", 0)))),
        }
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int() of an Infinity literal accepted by json.loads.
        return DEFAULT_ACTION.copy()


def format_observation(obs: dict) -> str:
    """Render an observation as the user message sent to the LLM."""
    def rounded(key: str) -> list:
        return [round(v, 6) for v in obs[key]]

    return "\n".join([
        f"Step: {obs['step']}",
        f"Task: {obs['task_id']}",
        f"vq_window (last 20): {rounded('vq_window')}",
        f"vd_window (last 20): {rounded('vd_window')}",
        f"omega_window (last 20): {rounded('omega_window')}",
        f"omega_deviation_window (last 20): {rounded('omega_deviation_window')}",
        f"raw_voltages: {rounded('raw_voltages')}",
    ])


_llm_disabled = False  # circuit breaker — flips True after first LLM failure


def llm_agent(obs: dict, prev_info: dict) -> dict:
    """Primary agent — calls the LLM through the injected proxy.

    Falls back to smart blending if the API call itself raises an exception.
    Uses a circuit breaker: after the first failure, all future calls skip
    the network request and go straight to blending (restoring ~10s runtime).
    """
    global _llm_disabled

    if _llm_disabled:
        return smart_blend_agent(obs, prev_info)

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": format_observation(obs)},
            ],
            temperature=0.1,
            max_tokens=200,
            timeout=10.0,
        )
        return parse_llm_response(completion.choices[0].message.content or "")
    except Exception as e:
        print(
            f"[WARN] LLM error ({type(e).__name__}: {e}), disabling LLM for remaining steps",
            file=sys.stderr,
            flush=True,
        )
        _llm_disabled = True
        return smart_blend_agent(obs, prev_info)


# ── Episode runner ────────────────────────────────────────────────────────────
def run_episode(task_id: int) -> float:
    """Run one episode against the environment server and return its score.

    Emits the ``[START]``/``[STEP]``/``[END]`` lines on stdout. Ordinary
    exceptions are logged to stderr and end the episode; the ``[END]`` line
    is printed in every case. ``BaseException`` (e.g. KeyboardInterrupt) is
    re-raised after logging.

    Raises:
        KeyError: if ``task_id`` is not a key of ``TASK_NAMES``.
    """
    task_name = TASK_NAMES[task_id]
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)

    # Reset heuristic state before every episode so stale data from a previous
    # task never bleeds into the next one (also covers the LLM fallback path).
    _hstate.reset()

    step_count = 0
    grader_score = 0.0
    rewards: List[float] = []
    success = False

    try:
        reset_resp = _session.post(
            f"{ENV_URL}/reset",
            json={"task_id": task_id},
            timeout=60,
        )
        reset_resp.raise_for_status()
        obs = reset_resp.json()

        done = False
        total_reward = 0.0
        info = {}

        while not done:
            # Frame skipping: call LLM every 10 steps, heuristic in between.
            # This caps LLM calls at ~150 total across 3 tasks, keeping runtime
            # well under the 20-min judging limit even with 3s/call latency.
            if step_count % 10 == 0:
                action = llm_agent(obs, info)
            else:
                action = smart_blend_agent(obs, info)

            step_resp = _session.post(
                f"{ENV_URL}/step",
                json=action,
                timeout=60,
            )
            step_resp.raise_for_status()
            result = step_resp.json()

            obs = result["observation"]
            reward = result["reward"]
            done = result["done"]
            info = result.get("info", {})
            error = result.get("error", None)

            # The server may send either a bare number or a breakdown dict.
            step_reward = reward["total"] if isinstance(reward, dict) else float(reward)
            total_reward += step_reward
            rewards.append(step_reward)
            step_count += 1

            log_step(step=step_count, action=action, reward=step_reward, done=done, error=error)

            if step_count % 50 == 0:
                print(
                    f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
                    f"detected={action['attack_detected']} type={action['attack_type']}",
                    file=sys.stderr,
                    flush=True,
                )

        grader_score = info.get("grader_score", 0.01)
        grader_score = max(0.01, min(0.99, grader_score))  # strict (0, 1)
        success = grader_score > 0.01

    except Exception as exc:
        print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        success = False
    except BaseException as exc:
        print(f"[DEBUG] Critical interruption: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        success = False
        raise
    finally:
        log_end(success=success, steps=step_count, score=grader_score, rewards=rewards)

    return grader_score


# ── Server Check ──────────────────────────────────────────────────────────────
def wait_for_server(env_url: str, timeout: int = 60) -> bool:
    """Poll ``{env_url}/health`` once per second until it answers HTTP 200.

    Returns ``True`` as soon as the server responds, ``False`` if ``timeout``
    seconds pass first.
    """
    print(f"[DEBUG] Waiting for environment server at {env_url} to start...", file=sys.stderr, flush=True)
    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            resp = _session.get(f"{env_url}/health", timeout=2)
            if resp.status_code == 200:
                print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
                return True
        except requests.RequestException:
            # Connection refused / timed out while the server boots — retry.
            pass
        time.sleep(1)
    print(f"[DEBUG] Environment server failed to start within {timeout}s.", file=sys.stderr, flush=True)
    return False


# ── Entry point ───────────────────────────────────────────────────────────────
def main() -> None:
    """Wait for the environment server, then run tasks 0-2 and log a summary."""
    print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", file=sys.stderr, flush=True)

    if not wait_for_server(ENV_URL):
        print("[DEBUG] Exiting due to server unavailable.", file=sys.stderr, flush=True)
        return

    start_time = time.time()
    scores: List[float] = []

    try:
        for task_id in range(3):
            try:
                score = run_episode(task_id)
            except Exception as exc:
                print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", file=sys.stderr, flush=True)
                score = 0.0
            scores.append(score)
            print(f"[DEBUG] task={task_id} score={score:.4f}", file=sys.stderr, flush=True)
    except BaseException as exc:
        # Deliberately swallowed so the summary below is still printed after
        # an interrupt.
        print(f"[DEBUG] Process interrupted: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)

    elapsed = time.time() - start_time
    avg = sum(scores) / len(scores) if scores else 0.0
    print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()