Spaces:
Sleeping
Sleeping
File size: 3,295 Bytes
3aeaf3d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | from __future__ import annotations
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata, State
from models import SeigeAction, SeigeObservation
from .env import SeigeEnv
class SeigeOpenEnv(Environment[SeigeAction, SeigeObservation, State]):
    """OpenEnv ``Environment`` adapter around a :class:`SeigeEnv` instance.

    The adapter translates typed ``SeigeAction`` objects into the plain
    dictionaries that ``SeigeEnv.step`` receives, and wraps what ``SeigeEnv``
    returns in ``SeigeObservation`` objects. It also tracks an episode id and
    a step counter, which are exposed through :attr:`state`.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Action fields forwarded to ``SeigeEnv.step`` for a red action.
    _RED_ACTION_KEYS: frozenset[str] = frozenset(
        {
            "agent_type",
            "strategy",
            "sub_strategy",
            "payload",
            "target_layer",
            "direction_label",
            "magnitude",
            "coalition_partner",
        }
    )

    # Action fields forwarded to ``SeigeEnv.step`` for any non-red action.
    _BLUE_ACTION_KEYS: frozenset[str] = frozenset(
        {
            "agent_type",
            "action_type",
            "session_id",
            "layer",
            "explanation",
            "patch_reference",
        }
    )

    def __init__(self) -> None:
        """Create the wrapped ``SeigeEnv`` and initialise episode bookkeeping."""
        # Let the base class set up whatever state it owns.
        # NOTE(review): assumes ``Environment.__init__`` has no required
        # arguments — confirm against the installed openenv version.
        super().__init__()
        self._env = SeigeEnv()
        self._episode_id = str(uuid4())
        self._step_count = 0
        self._has_reset = False

    def reset(
        self, seed: int | None = None, episode_id: str | None = None, **_: Any
    ) -> SeigeObservation:
        """Start a new episode and return the initial observation.

        Args:
            seed: Accepted for interface compatibility; it is not forwarded
                to ``SeigeEnv.reset`` and has no effect.
            episode_id: Identifier for the new episode. A fresh UUID4 string
                is generated when this is ``None`` or empty.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            An observation carrying both the ``"red"`` and ``"blue"`` entries
            returned by ``SeigeEnv.reset``, with ``current_agent="both"``,
            ``done=False`` and ``reward=0.0``.
        """
        del seed  # unused: SeigeEnv.reset() is called without a seed
        self._episode_id = episode_id or str(uuid4())
        self._step_count = 0
        self._has_reset = True
        observations = self._env.reset()
        return SeigeObservation(
            red=observations["red"],
            blue=observations["blue"],
            current_agent="both",
            done=False,
            reward=0.0,
        )

    def step(self, action: SeigeAction, **_: Any) -> SeigeObservation:  # type: ignore[override]
        """Apply one agent action and return the resulting observation.

        If :meth:`reset` has not been called on this instance yet, it is
        called first so the wrapped environment is initialised.

        Args:
            action: The action to apply; ``action.agent_type`` selects which
                side of the observation is populated.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            An observation whose ``red`` field is set when the acting agent is
            ``"red"`` and whose ``blue`` field is set when it is ``"blue"``;
            the other field is ``None``. ``info``, ``done`` and ``reward`` are
            taken from the result of ``SeigeEnv.step``; ``reward`` is ``None``
            when the result has no ``"reward"`` entry.
        """
        if not self._has_reset:
            self.reset()
        result = self._env.step(self._to_legacy_action(action))
        # Counted only after the wrapped step returns, so a raising step
        # does not advance the counter.
        self._step_count += 1
        observation = result.get("observation", {})
        current_agent = action.agent_type
        red = observation if current_agent == "red" else None
        blue = observation if current_agent == "blue" else None
        return SeigeObservation(
            red=red,
            blue=blue,
            current_agent=current_agent,
            info=result.get("info", {}),
            done=bool(result.get("done", False)),
            reward=result.get("reward"),
        )

    @property
    def state(self) -> State:
        """Return the episode id, step count and wrapped-environment state.

        The mapping returned by ``SeigeEnv.state()`` is expanded into keyword
        arguments of ``State``.
        """
        # NOTE(review): if SeigeEnv.state() ever returns an "episode_id" or
        # "step_count" key, this call raises TypeError for the duplicate
        # keyword — verify the keys it produces.
        return State(
            episode_id=self._episode_id,
            step_count=self._step_count,
            **self._env.state(),
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name, description and version of this environment."""
        return EnvironmentMetadata(
            name="seige",
            description=(
                "Adversarial oversight environment for training red attackers "
                "and blue defenders around a target model."
            ),
            version="0.1.0",
        )

    @staticmethod
    def _to_legacy_action(action: SeigeAction) -> dict[str, Any]:
        """Convert a typed action into the dict passed to ``SeigeEnv.step``.

        The action is dumped without its ``metadata`` field and without
        ``None``-valued fields, then filtered down to the keys allowed for
        the acting side: the red key set when ``agent_type`` is ``"red"``,
        the blue key set for every other value.
        """
        data = action.model_dump(exclude={"metadata"}, exclude_none=True)
        allowed = (
            SeigeOpenEnv._RED_ACTION_KEYS
            if action.agent_type == "red"
            else SeigeOpenEnv._BLUE_ACTION_KEYS
        )
        return {key: value for key, value in data.items() if key in allowed}
|