File size: 3,086 Bytes
3aeaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any


class AttackStrategy(str, Enum):
    """Closed set of attack-strategy labels.

    In this module it is the type of ``RedAction.strategy``,
    ``SessionState.attack_strategy`` and the values of
    ``TargetSystemState.ground_truth_attacks``.

    Because the enum mixes in ``str``, every member compares equal to its
    string value (e.g. ``AttackStrategy.PERSONA_MANIPULATION ==
    "persona_manipulation"``). ``AttackStrategy("<value>")`` looks a member
    up from its value.
    """

    PERSONA_MANIPULATION = "persona_manipulation"
    INDIRECT_INJECTION = "indirect_injection"
    MULTI_TURN_ESCALATION = "multi_turn_escalation"
    STEERING_VECTOR = "steering_vector"
    SOFT_PROMPT_GCG = "soft_prompt_gcg"
    REPE_BELIEF_FLIP = "repe_belief_flip"
    CIRCUIT_BYPASS = "circuit_bypass"
    # NOTE(review): the coalition_* strategies presumably pair with
    # RedAction.coalition_partner — confirm against the attack implementation.
    COALITION_DISTRACTION = "coalition_distraction"
    COALITION_POISONING = "coalition_poisoning"


class BlueActionType(str, Enum):
    """Closed set of defender (blue-team) action kinds.

    In this module it is the type of ``BlueAction.action_type``. Like
    ``AttackStrategy``, it mixes in ``str``, so each member compares equal to
    its lowercase string value.
    """

    MONITOR = "monitor"
    PROBE = "probe"
    FLAG = "flag"
    BLOCK = "block"
    PATCH = "patch"
    EXPLAIN = "explain"


@dataclass
class Explanation:
    """Structured justification that can be attached to a ``BlueAction``.

    Only ``threat_level`` and ``detection_method`` are required. No field is
    validated here: string fields are free-form and numeric ranges are not
    enforced.
    """

    threat_level: str
    detection_method: str
    # Optional integer; presumably a model layer index — confirm with producers.
    layer_implicated: int | None = None
    direction_match: str | None = None
    # Each instance gets its own list (default_factory). Presumably holds
    # Turn.turn_number values — verify against the code that fills it.
    evidence_turns: list[int] = field(default_factory=list)
    # Defaults to 0.0; scale/range is not defined in this class.
    anomaly_score: float = 0.0
    # Free-form string defaulting to "warn"; allowed values are not enforced here.
    recommended_action: str = "warn"
    circuit_hypothesis: str | None = None


@dataclass
class RedAction:
    """One attacker (red-team) move.

    ``strategy`` selects an ``AttackStrategy``. ``sub_strategy`` and
    ``payload`` are free-form strings. The remaining fields are optional
    parameters whose use depends on the chosen strategy (not visible in this
    module).
    """

    strategy: AttackStrategy
    sub_strategy: str
    payload: str
    # Optional integer; presumably a model layer index — confirm with consumers.
    target_layer: int | None = None
    direction_label: str | None = None
    # Defaults to 0.5; no range is enforced in this class.
    magnitude: float = 0.5
    coalition_partner: str | None = None


@dataclass
class BlueAction:
    """One defender (blue-team) move.

    ``action_type`` selects a ``BlueActionType``. ``explanation`` optionally
    carries a structured ``Explanation``.
    """

    action_type: BlueActionType
    # Presumably a key of TargetSystemState.sessions ("sess_<i>") — verify with callers.
    session_id: str
    # Optional integer; presumably a model layer index — confirm with consumers.
    layer: int | None = None
    explanation: Explanation | None = None
    # Defaults to the string "clean"; the set of valid references is not
    # defined in this module.
    patch_reference: str = "clean"


@dataclass
class Turn:
    """A single user/assistant exchange recorded in a ``SessionState``."""

    turn_number: int
    user_input: str
    assistant_output: str
    # Optional mapping keyed by int with untyped values. Presumably layer index
    # -> captured activations — TODO confirm the value type with the producer.
    activation_snapshot: dict[int, Any] | None = None


@dataclass
class SessionState:
    """Per-session bookkeeping: recorded turns plus attack/defence status flags.

    Only ``session_id`` is required. Every flag starts False and every
    optional field starts None. This class does not enforce consistency
    between related fields (e.g. ``detected`` vs ``detection_turn``); callers
    are responsible for keeping them in sync.
    """

    session_id: str
    # Each instance gets its own list (default_factory).
    turns: list[Turn] = field(default_factory=list)
    attack_active: bool = False
    attack_strategy: AttackStrategy | None = None
    # Presumably a Turn.turn_number — verify against the code that sets it.
    attack_payload_turn: int | None = None
    extraction_succeeded: bool = False
    detected: bool = False
    # Presumably a Turn.turn_number — verify against the code that sets it.
    detection_turn: int | None = None
    patched: bool = False


@dataclass
class BlueResources:
    """Resource limits and usage counters for the defender.

    All fields have defaults, so ``BlueResources()`` is valid (this is how
    ``TargetSystemState`` creates its default). Nothing in this class stops
    ``probes_used`` from exceeding ``probe_budget``; enforcement, if any,
    happens elsewhere.
    """

    probe_budget: int = 10
    probes_used: int = 0
    # Defaults to 0.5; presumably a fraction in [0, 1] — not validated here.
    log_sample_rate: float = 0.5
    # Defaults to 1; presumably measured in TargetSystemState.tick units — confirm.
    latency_ticks: int = 1


@dataclass
class TargetSystemState:
    """Container for the state of the simulated target system.

    Fields:
        hidden_secret: The value drawn from ``secrets_bank.sample()`` by
            :meth:`sample`.
        policy_rules: The value drawn from ``rules_bank.sample(n=...)`` by
            :meth:`sample`.
        sessions: ``SessionState`` objects keyed by their ``session_id``.
        activation_baseline: Stored exactly as supplied (not copied).
            Presumably maps a layer index to baseline activation values —
            TODO confirm against whoever builds the baseline.
        tick: Integer counter starting at 0. Nothing in this class advances it.
        stage: Integer starting at 1. Its meaning is defined by callers.
        seen_strategy_embeddings: Starts empty. Each instance gets its own list.
        blue_resources: A fresh ``BlueResources()`` per instance.
        ground_truth_attacks: Starts empty. Presumably keyed by session_id —
            verify against the code that populates it.
    """

    hidden_secret: str
    policy_rules: list[str]
    sessions: dict[str, SessionState]
    activation_baseline: dict[int, list[float]]
    tick: int = 0
    stage: int = 1
    seen_strategy_embeddings: list[list[float]] = field(default_factory=list)
    blue_resources: BlueResources = field(default_factory=BlueResources)
    ground_truth_attacks: dict[str, AttackStrategy] = field(default_factory=dict)

    @classmethod
    def sample(
        cls,
        secrets_bank: Any,
        rules_bank: Any,
        baseline: dict[int, list[float]],
        num_sessions: int = 8,
        num_rules: int = 5,
    ) -> TargetSystemState:
        """Build a fresh state with a sampled secret, sampled rules and empty sessions.

        Args:
            secrets_bank: Any object with a ``sample()`` method. Its return
                value becomes ``hidden_secret``.
            rules_bank: Any object with a ``sample(n=...)`` method. Its return
                value becomes ``policy_rules``.
            baseline: Stored as ``activation_baseline`` by reference, so
                mutating it through the state also mutates the caller's dict.
            num_sessions: Number of empty ``SessionState`` entries to create,
                keyed ``"sess_0"`` .. ``"sess_{num_sessions - 1}"``.
            num_rules: Passed as ``n`` to ``rules_bank.sample``. Defaults to 5.

        Returns:
            A new instance. Every field not listed above keeps its default.

        Raises:
            ValueError: If ``num_sessions`` is negative.
        """
        if num_sessions < 0:
            raise ValueError(f"num_sessions must be non-negative, got {num_sessions}")

        session_ids = [f"sess_{i}" for i in range(num_sessions)]
        sessions = {sid: SessionState(session_id=sid) for sid in session_ids}

        # Sampling order (secret first, then rules) is kept as before in case
        # both banks draw from a shared random source.
        return cls(
            hidden_secret=secrets_bank.sample(),
            policy_rules=rules_bank.sample(n=num_rules),
            sessions=sessions,
            activation_baseline=baseline,
        )