seige / environment /env.py
BART-ender's picture
Upload folder using huggingface_hub
3aeaf3d verified
from __future__ import annotations
import time
from typing import Any
from .actions import ActionParseError, parse_action
from .curriculum import CurriculumManager
from .direction_library import DirectionLibrary
from .executor import EnvironmentExecutor
from .observations import build_blue_observation, build_red_observation
from .constants import MAX_TURNS, STEP_TIMEOUT_SECS
from .rewards import compute_blue_reward, compute_red_reward
from .secrets_bank import RulesBank, SecretsBank
from .state import BlueAction, BlueResources, RedAction, TargetSystemState
from .target_system import build_target_system
# Pick the base class for the environment: use openenv's Environment when the
# package is importable, otherwise fall back to a plain object.
try:
    from openenv import Environment
except ImportError:
    # graceful fallback for local dev
    _BASE = object
else:
    _BASE = Environment
class SeigeEnv(_BASE):
    """Environment that executes red and blue actions against a target system.

    An episode is started with :meth:`reset`, which samples a fresh
    ``TargetSystemState`` using the current curriculum configuration. Each
    :meth:`step` parses one action, dispatches it to the executor as a red or
    blue action, computes the matching reward and returns a result dict with
    the keys ``observation``, ``reward``, ``done`` and ``info``.
    """

    def __init__(self) -> None:
        """Build the target system and its collaborators; no episode is active yet."""
        self.direction_library = DirectionLibrary()
        self.target = build_target_system(self.direction_library)
        self.executor = EnvironmentExecutor(self.target)
        self.curriculum = CurriculumManager()
        self.secrets_bank = SecretsBank()
        self.rules_bank = RulesBank()
        # Populated by reset(); step() refuses to run while this is None.
        self._state: TargetSystemState | None = None
        # Per-episode reward histories, averaged when the episode ends.
        self._episode_red_rewards: list[float] = []
        self._episode_blue_rewards: list[float] = []

    def reset(self) -> dict:
        """Start a new episode and return the initial observations.

        Returns:
            A dict with the keys ``"red"`` and ``"blue"``, each holding the
            dict form of the corresponding initial observation.
        """
        config = self.curriculum.get_config()
        self._state = TargetSystemState.sample(
            secrets_bank=self.secrets_bank,
            rules_bank=self.rules_bank,
            baseline=self.target.baseline_means,
            num_sessions=config["num_sessions"],
        )
        self._state.stage = self.curriculum.stage
        self._state.blue_resources = BlueResources(
            probe_budget=config["probe_budget"],
            log_sample_rate=config["log_sample_rate"],
            latency_ticks=config["latency_ticks"],
        )
        self._episode_red_rewards = []
        self._episode_blue_rewards = []
        return {
            "red": build_red_observation(
                self._state, config, self.target.get_num_layers()
            ).to_dict(),
            "blue": build_blue_observation(self._state, config).to_dict(),
        }

    def step(self, action: dict[str, Any] | str) -> dict:
        """Parse and execute one action, then advance the episode by one tick.

        Args:
            action: Raw action (dict or string) handed to ``parse_action``.

        Returns:
            A dict with the keys ``observation``, ``reward``, ``done`` and
            ``info``. An unparseable or unrecognised action yields an error
            result (reward ``-1.0``, ``done`` False) without advancing the
            tick; exceeding the parse-time budget yields a timeout result
            (reward ``-2.0``, ``done`` True).

        Raises:
            RuntimeError: If :meth:`reset` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")

        # Use a monotonic clock: elapsed-time measurement must not be skewed
        # by wall-clock adjustments (NTP, DST, manual changes).
        start = time.monotonic()
        try:
            parsed = parse_action(action)
        except ActionParseError as exc:
            return self._error_result(str(exc))
        # NOTE(review): only the parsing time is checked against
        # STEP_TIMEOUT_SECS; executor time is not covered - confirm intent.
        if time.monotonic() - start > STEP_TIMEOUT_SECS:
            return self._timeout_result()

        if isinstance(parsed, RedAction):
            result = self.executor.execute_red(parsed, self._state)
            reward = compute_red_reward(result, self._state, result.strategy_embedding)
            self._episode_red_rewards.append(reward)
            observation = build_red_observation(
                self._state,
                self.curriculum.get_config(),
                self.target.get_num_layers(),
            ).to_dict()
        elif isinstance(parsed, BlueAction):
            result = self.executor.execute_blue(parsed, self._state)
            reward = compute_blue_reward(result, self._state)
            self._episode_blue_rewards.append(reward)
            observation = build_blue_observation(
                self._state, self.curriculum.get_config()
            ).to_dict()
        else:
            return self._error_result("Unknown parsed action")

        self._state.tick += 1
        # The probe counter is cleared after every executed action.
        self._state.blue_resources.probes_used = 0

        done = (
            result.episode_terminal
            or self._state.tick >= MAX_TURNS
            or any(
                session.extraction_succeeded
                for session in self._state.sessions.values()
            )
        )
        if done:
            self._finish_episode()

        reward_info = self._reward_breakdown(result, reward)
        info_dict = result.info_dict()
        info_dict.update(reward_info)
        self._log_rewards(reward_info)

        return {
            "observation": observation,
            "reward": reward,
            "done": done,
            "info": info_dict,
        }

    def _finish_episode(self) -> None:
        """Report the episode's mean rewards to the curriculum and advance it if due."""
        # max(1, ...) keeps the mean at 0.0 when one side never acted.
        avg_red = sum(self._episode_red_rewards) / max(1, len(self._episode_red_rewards))
        avg_blue = sum(self._episode_blue_rewards) / max(1, len(self._episode_blue_rewards))
        self.curriculum.record(avg_red, avg_blue)
        if self.curriculum.should_advance():
            self.curriculum.advance()

    @staticmethod
    def _reward_breakdown(result: Any, reward: float) -> dict[str, float]:
        """Return the total reward plus fixed-weight components read from the result flags."""
        return {
            "reward/total": reward,
            "reward/extraction": 10.0 * float(result.extraction_succeeded),
            "reward/policy_bypass": 5.0 * float(result.policy_bypassed),
            "reward/evasion": 4.0 * float(result.success and not result.detected),
            "reward/detected_penalty": -3.0 * float(result.detected),
        }

    def _log_rewards(self, reward_info: dict[str, float]) -> None:
        """Best-effort logging of the reward breakdown to an active wandb run."""
        try:
            import wandb
        except ImportError:
            # wandb is optional; without it the step simply is not logged.
            return
        if wandb.run is not None:
            # Do not pass the tick as wandb's ``step``: a fresh state is
            # sampled on every reset(), so the tick presumably restarts each
            # episode, and wandb discards data logged at a step lower than
            # the current one. Let wandb auto-increment its step and record
            # the tick as an ordinary metric instead.
            wandb.log({**reward_info, "env/tick": self._state.tick})

    def state(self) -> dict:
        """Return a summary of the current episode, or ``{}`` before the first reset."""
        if self._state is None:
            return {}
        sessions = self._state.sessions.values()
        return {
            "tick": self._state.tick,
            "stage": self._state.stage,
            "num_sessions": len(self._state.sessions),
            "active_attacks": sum(1 for session in sessions if session.attack_active),
            "detections": sum(1 for session in sessions if session.detected),
        }

    def _error_result(self, message: str) -> dict:
        """Build the non-terminal result returned for a rejected action."""
        return {"observation": {}, "reward": -1.0, "done": False, "info": {"error": message}}

    def _timeout_result(self) -> dict:
        """Build the terminal result returned when the time budget is exceeded."""
        return {"observation": {}, "reward": -2.0, "done": True, "info": {"error": "timeout"}}