krishuggingface commited on
Commit
56a4099
·
1 Parent(s): eda1886

Fix Dockerfile curl Healthcheck and obey stdout protocols

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -0
  2. inference.py +14 -13
Dockerfile CHANGED
@@ -2,6 +2,8 @@ FROM python:3.10-slim
2
 
3
  WORKDIR /app
4
 
 
 
5
  COPY requirements.txt .
6
  RUN pip install --no-cache-dir -r requirements.txt
7
 
 
2
 
3
  WORKDIR /app
4
 
5
+ RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
6
+
7
  COPY requirements.txt .
8
  RUN pip install --no-cache-dir -r requirements.txt
9
 
inference.py CHANGED
@@ -14,6 +14,7 @@ STDOUT FORMAT (OpenEnv compliance):
14
  """
15
 
16
  import os
 
17
  import json
18
  import time
19
  import requests
@@ -312,7 +313,7 @@ def llm_agent(obs: dict) -> dict:
312
  )
313
  return parse_llm_response(completion.choices[0].message.content or "")
314
  except Exception as e:
315
- print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), falling back to heuristic", flush=True)
316
  return heuristic_agent(obs)
317
 
318
  # ── Episode runner ────────────────────────────────────────────────────────────
@@ -373,17 +374,17 @@ def run_episode(task_id: int) -> float:
373
  print(
374
  f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
375
  f"detected={action['attack_detected']} type={action['attack_type']}",
376
- flush=True,
377
  )
378
 
379
  grader_score = info.get("grader_score", 0.0)
380
  success = grader_score > 0.0
381
 
382
  except Exception as exc:
383
- print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", flush=True)
384
  success = False
385
  except BaseException as exc:
386
- print(f"[DEBUG] Critical interruption: {type(exc).__name__}: {exc}", flush=True)
387
  success = False
388
  raise
389
 
@@ -395,27 +396,27 @@ def run_episode(task_id: int) -> float:
395
  # ── Server Check ──────────────────────────────────────────────────────────────
396
 
397
  def wait_for_server(env_url: str, timeout: int = 60) -> bool:
398
- print(f"[DEBUG] Waiting for environment server at {env_url} to start...", flush=True)
399
  start_t = time.time()
400
  while time.time() - start_t < timeout:
401
  try:
402
  resp = requests.get(f"{env_url}/health", timeout=2)
403
  if resp.status_code == 200:
404
- print("[DEBUG] Environment server is up!", flush=True)
405
  return True
406
  except Exception:
407
  pass
408
  time.sleep(1)
409
- print(f"[DEBUG] Environment server failed to start within {timeout}s.", flush=True)
410
  return False
411
 
412
  # ── Entry point ───────────────────────────────────────────────────────────────
413
 
414
  def main() -> None:
415
- print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", flush=True)
416
 
417
  if not wait_for_server(ENV_URL):
418
- print("[DEBUG] Exiting due to server unavailable.", flush=True)
419
  return
420
 
421
  start_time = time.time()
@@ -426,16 +427,16 @@ def main() -> None:
426
  try:
427
  score = run_episode(task_id)
428
  except Exception as exc:
429
- print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", flush=True)
430
  score = 0.0
431
  scores.append(score)
432
- print(f"[DEBUG] task={task_id} score={score:.4f}", flush=True)
433
  except BaseException as exc:
434
- print(f"[DEBUG] Process interrupted: {type(exc).__name__}: {exc}", flush=True)
435
 
436
  elapsed = time.time() - start_time
437
  avg = sum(scores) / len(scores) if scores else 0.0
438
- print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", flush=True)
439
 
440
 
441
  if __name__ == "__main__":
 
14
  """
15
 
16
  import os
17
+ import sys
18
  import json
19
  import time
20
  import requests
 
313
  )
314
  return parse_llm_response(completion.choices[0].message.content or "")
315
  except Exception as e:
316
+ print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), falling back to heuristic", file=sys.stderr, flush=True)
317
  return heuristic_agent(obs)
318
 
319
  # ── Episode runner ────────────────────────────────────────────────────────────
 
374
  print(
375
  f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
376
  f"detected={action['attack_detected']} type={action['attack_type']}",
377
+ file=sys.stderr, flush=True,
378
  )
379
 
380
  grader_score = info.get("grader_score", 0.0)
381
  success = grader_score > 0.0
382
 
383
  except Exception as exc:
384
+ print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
385
  success = False
386
  except BaseException as exc:
387
+ print(f"[DEBUG] Critical interruption: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
388
  success = False
389
  raise
390
 
 
396
  # ── Server Check ──────────────────────────────────────────────────────────────
397
 
398
  def wait_for_server(env_url: str, timeout: int = 60) -> bool:
399
+ print(f"[DEBUG] Waiting for environment server at {env_url} to start...", file=sys.stderr, flush=True)
400
  start_t = time.time()
401
  while time.time() - start_t < timeout:
402
  try:
403
  resp = requests.get(f"{env_url}/health", timeout=2)
404
  if resp.status_code == 200:
405
+ print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
406
  return True
407
  except Exception:
408
  pass
409
  time.sleep(1)
410
+ print(f"[DEBUG] Environment server failed to start within {timeout}s.", file=sys.stderr, flush=True)
411
  return False
412
 
413
  # ── Entry point ───────────────────────────────────────────────────────────────
414
 
415
  def main() -> None:
416
+ print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", file=sys.stderr, flush=True)
417
 
418
  if not wait_for_server(ENV_URL):
419
+ print("[DEBUG] Exiting due to server unavailable.", file=sys.stderr, flush=True)
420
  return
421
 
422
  start_time = time.time()
 
427
  try:
428
  score = run_episode(task_id)
429
  except Exception as exc:
430
+ print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", file=sys.stderr, flush=True)
431
  score = 0.0
432
  scores.append(score)
433
+ print(f"[DEBUG] task={task_id} score={score:.4f}", file=sys.stderr, flush=True)
434
  except BaseException as exc:
435
+ print(f"[DEBUG] Process interrupted: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
436
 
437
  elapsed = time.time() - start_time
438
  avg = sum(scores) / len(scores) if scores else 0.0
439
+ print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", file=sys.stderr, flush=True)
440
 
441
 
442
  if __name__ == "__main__":