Spaces:
Sleeping
Sleeping
Commit ·
56a4099
1
Parent(s): eda1886
Fix Dockerfile curl Healthcheck and obey stdout protocols
Browse files- Dockerfile +2 -0
- inference.py +14 -13
Dockerfile
CHANGED
|
@@ -2,6 +2,8 @@ FROM python:3.10-slim
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
|
|
|
|
|
| 5 |
COPY requirements.txt .
|
| 6 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
|
|
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
+
RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
|
| 6 |
+
|
| 7 |
COPY requirements.txt .
|
| 8 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 9 |
|
inference.py
CHANGED
|
@@ -14,6 +14,7 @@ STDOUT FORMAT (OpenEnv compliance):
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import os
|
|
|
|
| 17 |
import json
|
| 18 |
import time
|
| 19 |
import requests
|
|
@@ -312,7 +313,7 @@ def llm_agent(obs: dict) -> dict:
|
|
| 312 |
)
|
| 313 |
return parse_llm_response(completion.choices[0].message.content or "")
|
| 314 |
except Exception as e:
|
| 315 |
-
print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), falling back to heuristic", flush=True)
|
| 316 |
return heuristic_agent(obs)
|
| 317 |
|
| 318 |
# ── Episode runner ────────────────────────────────────────────────────────────
|
|
@@ -373,17 +374,17 @@ def run_episode(task_id: int) -> float:
|
|
| 373 |
print(
|
| 374 |
f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
|
| 375 |
f"detected={action['attack_detected']} type={action['attack_type']}",
|
| 376 |
-
flush=True,
|
| 377 |
)
|
| 378 |
|
| 379 |
grader_score = info.get("grader_score", 0.0)
|
| 380 |
success = grader_score > 0.0
|
| 381 |
|
| 382 |
except Exception as exc:
|
| 383 |
-
print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", flush=True)
|
| 384 |
success = False
|
| 385 |
except BaseException as exc:
|
| 386 |
-
print(f"[DEBUG] Critical interruption: {type(exc).__name__}: {exc}", flush=True)
|
| 387 |
success = False
|
| 388 |
raise
|
| 389 |
|
|
@@ -395,27 +396,27 @@ def run_episode(task_id: int) -> float:
|
|
| 395 |
# ── Server Check ──────────────────────────────────────────────────────────────
|
| 396 |
|
| 397 |
def wait_for_server(env_url: str, timeout: int = 60) -> bool:
|
| 398 |
-
print(f"[DEBUG] Waiting for environment server at {env_url} to start...", flush=True)
|
| 399 |
start_t = time.time()
|
| 400 |
while time.time() - start_t < timeout:
|
| 401 |
try:
|
| 402 |
resp = requests.get(f"{env_url}/health", timeout=2)
|
| 403 |
if resp.status_code == 200:
|
| 404 |
-
print("[DEBUG] Environment server is up!", flush=True)
|
| 405 |
return True
|
| 406 |
except Exception:
|
| 407 |
pass
|
| 408 |
time.sleep(1)
|
| 409 |
-
print(f"[DEBUG] Environment server failed to start within {timeout}s.", flush=True)
|
| 410 |
return False
|
| 411 |
|
| 412 |
# ── Entry point ───────────────────────────────────────────────────────────────
|
| 413 |
|
| 414 |
def main() -> None:
|
| 415 |
-
print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", flush=True)
|
| 416 |
|
| 417 |
if not wait_for_server(ENV_URL):
|
| 418 |
-
print("[DEBUG] Exiting due to server unavailable.", flush=True)
|
| 419 |
return
|
| 420 |
|
| 421 |
start_time = time.time()
|
|
@@ -426,16 +427,16 @@ def main() -> None:
|
|
| 426 |
try:
|
| 427 |
score = run_episode(task_id)
|
| 428 |
except Exception as exc:
|
| 429 |
-
print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", flush=True)
|
| 430 |
score = 0.0
|
| 431 |
scores.append(score)
|
| 432 |
-
print(f"[DEBUG] task={task_id} score={score:.4f}", flush=True)
|
| 433 |
except BaseException as exc:
|
| 434 |
-
print(f"[DEBUG] Process interrupted: {type(exc).__name__}: {exc}", flush=True)
|
| 435 |
|
| 436 |
elapsed = time.time() - start_time
|
| 437 |
avg = sum(scores) / len(scores) if scores else 0.0
|
| 438 |
-
print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", flush=True)
|
| 439 |
|
| 440 |
|
| 441 |
if __name__ == "__main__":
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import os
|
| 17 |
+
import sys
|
| 18 |
import json
|
| 19 |
import time
|
| 20 |
import requests
|
|
|
|
| 313 |
)
|
| 314 |
return parse_llm_response(completion.choices[0].message.content or "")
|
| 315 |
except Exception as e:
|
| 316 |
+
print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), falling back to heuristic", file=sys.stderr, flush=True)
|
| 317 |
return heuristic_agent(obs)
|
| 318 |
|
| 319 |
# ── Episode runner ────────────────────────────────────────────────────────────
|
|
|
|
| 374 |
print(
|
| 375 |
f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
|
| 376 |
f"detected={action['attack_detected']} type={action['attack_type']}",
|
| 377 |
+
file=sys.stderr, flush=True,
|
| 378 |
)
|
| 379 |
|
| 380 |
grader_score = info.get("grader_score", 0.0)
|
| 381 |
success = grader_score > 0.0
|
| 382 |
|
| 383 |
except Exception as exc:
|
| 384 |
+
print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
| 385 |
success = False
|
| 386 |
except BaseException as exc:
|
| 387 |
+
print(f"[DEBUG] Critical interruption: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
| 388 |
success = False
|
| 389 |
raise
|
| 390 |
|
|
|
|
| 396 |
# ── Server Check ──────────────────────────────────────────────────────────────
|
| 397 |
|
| 398 |
def wait_for_server(env_url: str, timeout: int = 60) -> bool:
|
| 399 |
+
print(f"[DEBUG] Waiting for environment server at {env_url} to start...", file=sys.stderr, flush=True)
|
| 400 |
start_t = time.time()
|
| 401 |
while time.time() - start_t < timeout:
|
| 402 |
try:
|
| 403 |
resp = requests.get(f"{env_url}/health", timeout=2)
|
| 404 |
if resp.status_code == 200:
|
| 405 |
+
print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
|
| 406 |
return True
|
| 407 |
except Exception:
|
| 408 |
pass
|
| 409 |
time.sleep(1)
|
| 410 |
+
print(f"[DEBUG] Environment server failed to start within {timeout}s.", file=sys.stderr, flush=True)
|
| 411 |
return False
|
| 412 |
|
| 413 |
# ── Entry point ───────────────────────────────────────────────────────────────
|
| 414 |
|
| 415 |
def main() -> None:
|
| 416 |
+
print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", file=sys.stderr, flush=True)
|
| 417 |
|
| 418 |
if not wait_for_server(ENV_URL):
|
| 419 |
+
print("[DEBUG] Exiting due to server unavailable.", file=sys.stderr, flush=True)
|
| 420 |
return
|
| 421 |
|
| 422 |
start_time = time.time()
|
|
|
|
| 427 |
try:
|
| 428 |
score = run_episode(task_id)
|
| 429 |
except Exception as exc:
|
| 430 |
+
print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", file=sys.stderr, flush=True)
|
| 431 |
score = 0.0
|
| 432 |
scores.append(score)
|
| 433 |
+
print(f"[DEBUG] task={task_id} score={score:.4f}", file=sys.stderr, flush=True)
|
| 434 |
except BaseException as exc:
|
| 435 |
+
print(f"[DEBUG] Process interrupted: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
| 436 |
|
| 437 |
elapsed = time.time() - start_time
|
| 438 |
avg = sum(scores) / len(scores) if scores else 0.0
|
| 439 |
+
print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", file=sys.stderr, flush=True)
|
| 440 |
|
| 441 |
|
| 442 |
if __name__ == "__main__":
|