# Spaces: Sleeping
# File size: 3,029 Bytes
# Revision: 3aeaf3d
from __future__ import annotations
import random
from dataclasses import asdict, dataclass
from typing import Any
from .state import TargetSystemState
from .constants import MAX_TURNS
@dataclass
class RedObservation:
    """Snapshot handed to the red side, convertible to a plain dict."""

    # Simulation clock and stage at the time the snapshot was built.
    tick: int
    stage: int
    # Policy rules as exposed to the red side.
    policy_summary: list[str]
    # Outcome record of an attack session; empty dict when there is none.
    last_attack_outcome: dict[str, Any]
    # Strategy names the red side may choose from.
    available_strategies: list[str]
    # Strategy values already used, as recorded by the builder.
    seen_strategy_history: list[str]
    turns_remaining: int
    num_layers: int

    def to_dict(self) -> dict:
        """Return the observation as a dict built by ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class BlueObservation:
    """Snapshot handed to the blue side, convertible to a plain dict."""

    # Simulation clock and stage at the time the snapshot was built.
    tick: int
    stage: int
    # Ids of the sessions whose logs are included in ``log_sample``.
    sessions_visible: list[str]
    probe_budget_remaining: int
    # Session id -> list of per-turn log records.
    log_sample: dict[str, list[dict]]
    # One record per session already flagged as detected.
    prior_alerts: list[dict]
    activation_baseline_available: bool
    num_sessions_total: int

    def to_dict(self) -> dict:
        """Return the observation as a dict built by ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping
def build_red_observation(state: TargetSystemState, config: dict, num_layers: int) -> RedObservation:
    """Build the ``RedObservation`` for the current environment state.

    Args:
        state: Environment state. Reads ``tick``, ``stage``, ``policy_rules``
            and ``sessions`` (a mapping of session id to session object).
        config: Must contain ``"allowed_red_strategies"``.
        num_layers: Passed through unchanged as ``RedObservation.num_layers``.

    Returns:
        A ``RedObservation`` whose list fields are fresh copies, so that
        mutating the observation cannot alter ``state`` or ``config``.

    Raises:
        KeyError: If ``config`` has no ``"allowed_red_strategies"`` entry.
    """
    # Report the first session (in ``state.sessions`` iteration order) that
    # has an active attack and at least one recorded turn.
    reported = next(
        (s for s in state.sessions.values() if s.attack_active and s.turns),
        None,
    )
    last_outcome: dict[str, Any] = {}
    if reported is not None:
        last_outcome = {
            "success": reported.extraction_succeeded,
            "detected": reported.detected,
            "turns_taken": len(reported.turns),
        }

    strategy_history = [
        session.attack_strategy.value
        for session in state.sessions.values()
        if session.attack_strategy is not None
    ]

    return RedObservation(
        tick=state.tick,
        stage=state.stage,
        # Copy: handing out the live list would let a consumer edit the
        # environment's policy rules through its observation.
        policy_summary=list(state.policy_rules),
        last_attack_outcome=last_outcome,
        # Copy for the same reason: the config list is shared by every call.
        available_strategies=list(config["allowed_red_strategies"]),
        # Only the five most recent entries are exposed.
        seen_strategy_history=strategy_history[-5:],
        # Never report a negative number of turns once tick passes MAX_TURNS.
        turns_remaining=max(0, MAX_TURNS - state.tick),
        num_layers=num_layers,
    )
def build_blue_observation(
    state: TargetSystemState,
    config: dict,
    rng: random.Random | None = None,
) -> BlueObservation:
    """Build the ``BlueObservation`` for the current environment state.

    A random subset of sessions is made visible, and for each visible session
    only the turns whose ``turn_number`` is at most
    ``state.tick - config["latency_ticks"]`` are included in the log sample.

    Args:
        state: Environment state. Reads ``tick``, ``stage``, ``sessions`` and
            ``blue_resources`` (``probe_budget`` and ``probes_used``).
        config: Must contain ``"log_sample_rate"`` and ``"latency_ticks"``.
        rng: Optional ``random.Random`` used to pick the visible sessions.
            Pass a seeded instance for reproducible observations; when
            omitted, the module-level ``random`` generator is used.

    Returns:
        A ``BlueObservation``. ``prior_alerts`` lists every detected session
        in ``state``, not only the sampled ones.

    Raises:
        KeyError: If ``config`` lacks a required key.
    """
    sample = random.sample if rng is None else rng.sample

    all_ids = list(state.sessions)
    # At least one session is requested, but never more than exist, so an
    # empty session table yields an empty sample instead of raising.
    wanted = max(1, int(len(all_ids) * config["log_sample_rate"]))
    visible_ids = sample(all_ids, min(wanted, len(all_ids)))

    # Turns newer than this turn number are withheld from the log sample.
    cutoff = state.tick - config["latency_ticks"]
    log_sample: dict[str, list[dict]] = {
        session_id: [
            {"turn": turn.turn_number, "user": turn.user_input, "assistant": turn.assistant_output}
            for turn in state.sessions[session_id].turns
            if turn.turn_number <= cutoff
        ]
        for session_id in visible_ids
    }

    prior_alerts = [
        {"session_id": session_id, "detection_turn": session.detection_turn}
        for session_id, session in state.sessions.items()
        if session.detected
    ]

    resources = state.blue_resources
    return BlueObservation(
        tick=state.tick,
        stage=state.stage,
        sessions_visible=visible_ids,
        # Never report a negative budget if more probes were used than allotted.
        probe_budget_remaining=max(0, resources.probe_budget - resources.probes_used),
        log_sample=log_sample,
        prior_alerts=prior_alerts,
        # Hard-coded: this builder always reports the baseline as available.
        activation_baseline_available=True,
        num_sessions_total=len(state.sessions),
    )
|