Spaces:
Sleeping
Sleeping
Commit ·
4c2a495
1
Parent(s): 82e7138
Restore 500 steps, LLM every step with circuit breaker, use requests.Session for speed
Browse files- inference.py +18 -10
- openenv.yaml +4 -4
- src/env.py +1 -1
inference.py
CHANGED
|
@@ -32,6 +32,9 @@ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
|
|
| 32 |
# OpenAI client pointed at the proxy — never bypass this
|
| 33 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
| 35 |
# ── Task metadata ─────────────────────────────────────────────────────────────
|
| 36 |
TASK_NAMES = {
|
| 37 |
0: "Sinusoidal FDI Detection (Easy)",
|
|
@@ -295,11 +298,19 @@ def format_observation(obs: dict) -> str:
|
|
| 295 |
f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
|
| 296 |
])
|
| 297 |
|
|
|
|
|
|
|
| 298 |
|
| 299 |
def llm_agent(obs: dict) -> dict:
|
| 300 |
"""Primary agent — calls the LLM through the injected proxy.
|
| 301 |
Falls back to heuristic only if the API call itself raises an exception.
|
|
|
|
|
|
|
| 302 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
try:
|
| 304 |
completion = client.chat.completions.create(
|
| 305 |
model=MODEL_NAME,
|
|
@@ -313,7 +324,8 @@ def llm_agent(obs: dict) -> dict:
|
|
| 313 |
)
|
| 314 |
return parse_llm_response(completion.choices[0].message.content or "")
|
| 315 |
except Exception as e:
|
| 316 |
-
print(f"[DEBUG] LLM error ({type(e).__name__}: {e}),
|
|
|
|
| 317 |
return heuristic_agent(obs)
|
| 318 |
|
| 319 |
# ── Episode runner ────────────────────────────────────────────────────────────
|
|
@@ -333,7 +345,7 @@ def run_episode(task_id: int) -> float:
|
|
| 333 |
success = False
|
| 334 |
|
| 335 |
try:
|
| 336 |
-
reset_resp =
|
| 337 |
f"{ENV_URL}/reset",
|
| 338 |
json={"task_id": task_id},
|
| 339 |
timeout=60,
|
|
@@ -346,14 +358,10 @@ def run_episode(task_id: int) -> float:
|
|
| 346 |
info = {}
|
| 347 |
|
| 348 |
while not done:
|
| 349 |
-
#
|
| 350 |
-
|
| 351 |
-
if step_count % 10 == 0:
|
| 352 |
-
action = llm_agent(obs)
|
| 353 |
-
else:
|
| 354 |
-
action = heuristic_agent(obs)
|
| 355 |
|
| 356 |
-
step_resp =
|
| 357 |
f"{ENV_URL}/step",
|
| 358 |
json=action,
|
| 359 |
timeout=60,
|
|
@@ -404,7 +412,7 @@ def wait_for_server(env_url: str, timeout: int = 60) -> bool:
|
|
| 404 |
start_t = time.time()
|
| 405 |
while time.time() - start_t < timeout:
|
| 406 |
try:
|
| 407 |
-
resp =
|
| 408 |
if resp.status_code == 200:
|
| 409 |
print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
|
| 410 |
return True
|
|
|
|
| 32 |
# OpenAI client pointed at the proxy — never bypass this
|
| 33 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 34 |
|
| 35 |
+
# Persistent HTTP session for env calls — avoids TCP handshake per step
|
| 36 |
+
_session = requests.Session()
|
| 37 |
+
|
| 38 |
# ── Task metadata ─────────────────────────────────────────────────────────────
|
| 39 |
TASK_NAMES = {
|
| 40 |
0: "Sinusoidal FDI Detection (Easy)",
|
|
|
|
| 298 |
f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
|
| 299 |
])
|
| 300 |
|
| 301 |
+
_llm_disabled = False # circuit breaker — flips True after first LLM failure
|
| 302 |
+
|
| 303 |
|
| 304 |
def llm_agent(obs: dict) -> dict:
|
| 305 |
"""Primary agent — calls the LLM through the injected proxy.
|
| 306 |
Falls back to heuristic only if the API call itself raises an exception.
|
| 307 |
+
Uses a circuit breaker: after the first failure, all future calls skip the
|
| 308 |
+
network request and go straight to heuristic (restoring ~10s runtime).
|
| 309 |
"""
|
| 310 |
+
global _llm_disabled
|
| 311 |
+
if _llm_disabled:
|
| 312 |
+
return heuristic_agent(obs)
|
| 313 |
+
|
| 314 |
try:
|
| 315 |
completion = client.chat.completions.create(
|
| 316 |
model=MODEL_NAME,
|
|
|
|
| 324 |
)
|
| 325 |
return parse_llm_response(completion.choices[0].message.content or "")
|
| 326 |
except Exception as e:
|
| 327 |
+
print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), disabling LLM for remaining steps", file=sys.stderr, flush=True)
|
| 328 |
+
_llm_disabled = True
|
| 329 |
return heuristic_agent(obs)
|
| 330 |
|
| 331 |
# ── Episode runner ────────────────────────────────────────────────────────────
|
|
|
|
| 345 |
success = False
|
| 346 |
|
| 347 |
try:
|
| 348 |
+
reset_resp = _session.post(
|
| 349 |
f"{ENV_URL}/reset",
|
| 350 |
json={"task_id": task_id},
|
| 351 |
timeout=60,
|
|
|
|
| 358 |
info = {}
|
| 359 |
|
| 360 |
while not done:
|
| 361 |
+
# LLM is primary; circuit breaker auto-disables after first failure
|
| 362 |
+
action = llm_agent(obs)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
+
step_resp = _session.post(
|
| 365 |
f"{ENV_URL}/step",
|
| 366 |
json=action,
|
| 367 |
timeout=60,
|
|
|
|
| 412 |
start_t = time.time()
|
| 413 |
while time.time() - start_t < timeout:
|
| 414 |
try:
|
| 415 |
+
resp = _session.get(f"{env_url}/health", timeout=2)
|
| 416 |
if resp.status_code == 200:
|
| 417 |
print("[DEBUG] Environment server is up!", file=sys.stderr, flush=True)
|
| 418 |
return True
|
openenv.yaml
CHANGED
|
@@ -46,20 +46,20 @@ tasks:
|
|
| 46 |
numeric_id: 0
|
| 47 |
grader: time_to_detection
|
| 48 |
max_score: 1.0
|
| 49 |
-
episode_length:
|
| 50 |
description: Detect sinusoidal FDI attack within 100 steps of attack start
|
| 51 |
- id: multi_attack_classification
|
| 52 |
difficulty: medium
|
| 53 |
numeric_id: 1
|
| 54 |
grader: classification_accuracy
|
| 55 |
max_score: 1.0
|
| 56 |
-
episode_length:
|
| 57 |
description: Classify attack type from observation window
|
| 58 |
- id: stealthy_attack_detection
|
| 59 |
difficulty: hard
|
| 60 |
numeric_id: 2
|
| 61 |
grader: pre_lock_loss_detection
|
| 62 |
max_score: 1.0
|
| 63 |
-
episode_length:
|
| 64 |
description: Detect stealthy low-amplitude attack before PLL loss-of-lock
|
| 65 |
-
episode_length:
|
|
|
|
| 46 |
numeric_id: 0
|
| 47 |
grader: time_to_detection
|
| 48 |
max_score: 1.0
|
| 49 |
+
episode_length: 500
|
| 50 |
description: Detect sinusoidal FDI attack within 100 steps of attack start
|
| 51 |
- id: multi_attack_classification
|
| 52 |
difficulty: medium
|
| 53 |
numeric_id: 1
|
| 54 |
grader: classification_accuracy
|
| 55 |
max_score: 1.0
|
| 56 |
+
episode_length: 500
|
| 57 |
description: Classify attack type from observation window
|
| 58 |
- id: stealthy_attack_detection
|
| 59 |
difficulty: hard
|
| 60 |
numeric_id: 2
|
| 61 |
grader: pre_lock_loss_detection
|
| 62 |
max_score: 1.0
|
| 63 |
+
episode_length: 500
|
| 64 |
description: Detect stealthy low-amplitude attack before PLL loss-of-lock
|
| 65 |
+
episode_length: 500
|
src/env.py
CHANGED
|
@@ -27,7 +27,7 @@ from src.detector import AdaptiveDetector
|
|
| 27 |
|
| 28 |
|
| 29 |
WINDOW_SIZE = 20
|
| 30 |
-
MAX_STEPS =
|
| 31 |
LOCK_LOSS_THRESHOLD = 0.0873 # 5 degrees in radians
|
| 32 |
|
| 33 |
DETECTION_THRESHOLD = 2.0
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
WINDOW_SIZE = 20
|
| 30 |
+
MAX_STEPS = 500
|
| 31 |
LOCK_LOSS_THRESHOLD = 0.0873 # 5 degrees in radians
|
| 32 |
|
| 33 |
DETECTION_THRESHOLD = 2.0
|