krishuggingface committed on
Commit
4c2a495
·
1 Parent(s): 82e7138

Restore 500 steps, LLM every step with circuit breaker, use requests.Session for speed

Browse files
Files changed (3) hide show
  1. inference.py +18 -10
  2. openenv.yaml +4 -4
  3. src/env.py +1 -1
inference.py CHANGED
@@ -32,6 +32,9 @@ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
32
  # OpenAI client pointed at the proxy — never bypass this
33
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
34
 
 
 
 
35
  # ── Task metadata ─────────────────────────────────────────────────────────────
36
  TASK_NAMES = {
37
  0: "Sinusoidal FDI Detection (Easy)",
@@ -295,11 +298,19 @@ def format_observation(obs: dict) -> str:
295
  f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
296
  ])
297
 
 
 
298
 
299
  def llm_agent(obs: dict) -> dict:
300
  """Primary agent — calls the LLM through the injected proxy.
301
  Falls back to heuristic only if the API call itself raises an exception.
 
 
302
  """
 
 
 
 
303
  try:
304
  completion = client.chat.completions.create(
305
  model=MODEL_NAME,
@@ -313,7 +324,8 @@ def llm_agent(obs: dict) -> dict:
313
  )
314
  return parse_llm_response(completion.choices[0].message.content or "")
315
  except Exception as e:
316
- print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), falling back to heuristic", file=sys.stderr, flush=True)
 
317
  return heuristic_agent(obs)
318
 
319
  # ── Episode runner ────────────────────────────────────────────────────────────
@@ -333,7 +345,7 @@ def run_episode(task_id: int) -> float:
333
  success = False
334
 
335
  try:
336
- reset_resp = requests.post(
337
  f"{ENV_URL}/reset",
338
  json={"task_id": task_id},
339
  timeout=60,
@@ -346,14 +358,10 @@ def run_episode(task_id: int) -> float:
346
  info = {}
347
 
348
  while not done:
349
- # Frame skipping: only invoke the LLM every 10 steps to prevent 30-min evaluation timeouts.
350
- # Step skips use the heuristics to keep episode run-time blazing fast.
351
- if step_count % 10 == 0:
352
- action = llm_agent(obs)
353
- else:
354
- action = heuristic_agent(obs)
355
 
356
- step_resp = requests.post(
357
  f"{ENV_URL}/step",
358
  json=action,
359
  timeout=60,
@@ -404,7 +412,7 @@ def wait_for_server(env_url: str, timeout: int = 60) -> bool:
404
  start_t = time.time()
405
  while time.time() - start_t < timeout:
406
  try:
407
- resp = requests.get(f"{env_url}/health", timeout=2)
408
  if resp.status_code == 200:
409
  print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
410
  return True
 
32
  # OpenAI client pointed at the proxy — never bypass this
33
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
34
 
35
+ # Persistent HTTP session for env calls — avoids TCP handshake per step
36
+ _session = requests.Session()
37
+
38
  # ── Task metadata ─────────────────────────────────────────────────────────────
39
  TASK_NAMES = {
40
  0: "Sinusoidal FDI Detection (Easy)",
 
298
  f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
299
  ])
300
 
301
+ _llm_disabled = False # circuit breaker — flips True after first LLM failure
302
+
303
 
304
  def llm_agent(obs: dict) -> dict:
305
  """Primary agent — calls the LLM through the injected proxy.
306
  Falls back to heuristic only if the API call itself raises an exception.
307
+ Uses a circuit breaker: after the first failure, all future calls skip the
308
+ network request and go straight to heuristic (restoring ~10s runtime).
309
  """
310
+ global _llm_disabled
311
+ if _llm_disabled:
312
+ return heuristic_agent(obs)
313
+
314
  try:
315
  completion = client.chat.completions.create(
316
  model=MODEL_NAME,
 
324
  )
325
  return parse_llm_response(completion.choices[0].message.content or "")
326
  except Exception as e:
327
+ print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), disabling LLM for remaining steps", file=sys.stderr, flush=True)
328
+ _llm_disabled = True
329
  return heuristic_agent(obs)
330
 
331
  # ── Episode runner ────────────────────────────────────────────────────────────
 
345
  success = False
346
 
347
  try:
348
+ reset_resp = _session.post(
349
  f"{ENV_URL}/reset",
350
  json={"task_id": task_id},
351
  timeout=60,
 
358
  info = {}
359
 
360
  while not done:
361
+ # LLM is primary; circuit breaker auto-disables after first failure
362
+ action = llm_agent(obs)
 
 
 
 
363
 
364
+ step_resp = _session.post(
365
  f"{ENV_URL}/step",
366
  json=action,
367
  timeout=60,
 
412
  start_t = time.time()
413
  while time.time() - start_t < timeout:
414
  try:
415
+ resp = _session.get(f"{env_url}/health", timeout=2)
416
  if resp.status_code == 200:
417
  print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
418
  return True
openenv.yaml CHANGED
@@ -46,20 +46,20 @@ tasks:
46
  numeric_id: 0
47
  grader: time_to_detection
48
  max_score: 1.0
49
- episode_length: 250
50
  description: Detect sinusoidal FDI attack within 100 steps of attack start
51
  - id: multi_attack_classification
52
  difficulty: medium
53
  numeric_id: 1
54
  grader: classification_accuracy
55
  max_score: 1.0
56
- episode_length: 250
57
  description: Classify attack type from observation window
58
  - id: stealthy_attack_detection
59
  difficulty: hard
60
  numeric_id: 2
61
  grader: pre_lock_loss_detection
62
  max_score: 1.0
63
- episode_length: 250
64
  description: Detect stealthy low-amplitude attack before PLL loss-of-lock
65
- episode_length: 250
 
46
  numeric_id: 0
47
  grader: time_to_detection
48
  max_score: 1.0
49
+ episode_length: 500
50
  description: Detect sinusoidal FDI attack within 100 steps of attack start
51
  - id: multi_attack_classification
52
  difficulty: medium
53
  numeric_id: 1
54
  grader: classification_accuracy
55
  max_score: 1.0
56
+ episode_length: 500
57
  description: Classify attack type from observation window
58
  - id: stealthy_attack_detection
59
  difficulty: hard
60
  numeric_id: 2
61
  grader: pre_lock_loss_detection
62
  max_score: 1.0
63
+ episode_length: 500
64
  description: Detect stealthy low-amplitude attack before PLL loss-of-lock
65
+ episode_length: 500
src/env.py CHANGED
@@ -27,7 +27,7 @@ from src.detector import AdaptiveDetector
27
 
28
 
29
  WINDOW_SIZE = 20
30
- MAX_STEPS = 250
31
  LOCK_LOSS_THRESHOLD = 0.0873 # 5 degrees in radians
32
 
33
  DETECTION_THRESHOLD = 2.0
 
27
 
28
 
29
  WINDOW_SIZE = 20
30
+ MAX_STEPS = 500
31
  LOCK_LOSS_THRESHOLD = 0.0873 # 5 degrees in radians
32
 
33
  DETECTION_THRESHOLD = 2.0