sh4shv4t committed on
Commit
ca72cb2
·
0 Parent(s):

Initial env deployment

Browse files
Dockerfile ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Container image for the StateStrike OpenEnv environment server.
FROM python:3.11-slim
WORKDIR /app
COPY . .
# statestrike_env/__init__.py imports websockets.sync.client, and uvicorn needs a
# WebSocket implementation to serve /ws, so both are installed explicitly.
RUN pip install --no-cache-dir openenv-core fastapi "uvicorn[standard]" websockets httpx python-dotenv
ENV HONEYPOT_URL="https://sh4shv4t-statestrike-honeypot.hf.space"
# server.py takes its listen port from STATESTRIKE_ENV_PORT (default 8001), not
# from argv, so pin it to the port exposed below.
ENV STATESTRIKE_ENV_PORT=7860
EXPOSE 7860
CMD ["python", "-m", "statestrike_env.server", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: StateStrike Environment
3
+ emoji: 🎯
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: true
8
+ ---
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openenv-core==0.1.0
2
+ fastapi==0.111.0
3
+ uvicorn[standard]==0.30.1
4
+ sqlalchemy==2.0.30
5
+ httpx==0.27.0
6
+ pydantic==2.7.1
7
+ streamlit==1.35.0
8
+ plotly==5.22.0
9
+ python-dotenv==1.0.1
10
+ pytest==8.2.0
11
+ pytest-asyncio==0.23.7
12
+ rich==13.7.1
13
+ websockets==12.0
14
+ portalocker==2.8.2
statestrike_env/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: StateStrike Environment
3
+ emoji: 🎯
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: true
8
+ ---
statestrike_env/__init__.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """StateStrike OpenEnv-compatible client exports."""
4
+
5
+ import json
6
+ from contextlib import AbstractContextManager
7
+ from typing import Any
8
+
9
+ from websockets.sync.client import ClientConnection, connect
10
+
11
+ from statestrike_env.models import StateStrikeAction, StateStrikeObservation, StateStrikeState
12
+
13
+
14
+ class _SyncStateStrikeClient(AbstractContextManager["_SyncStateStrikeClient"]):
15
+ """Synchronous WebSocket client wrapper for reset/step/state calls."""
16
+
17
+ def __init__(self, base_url: str) -> None:
18
+ """Initialize client.
19
+
20
+ Args:
21
+ base_url: WebSocket URL including `/ws` path.
22
+ """
23
+
24
+ normalized = base_url.rstrip("/")
25
+ self.base_url = normalized if normalized.endswith("/ws") else f"{normalized}/ws"
26
+ self._conn: ClientConnection | None = None
27
+
28
+ def __enter__(self) -> "_SyncStateStrikeClient":
29
+ """Open WebSocket connection for environment operations.
30
+
31
+ Returns:
32
+ Connected client instance.
33
+ """
34
+
35
+ self._conn = connect(self.base_url)
36
+ return self
37
+
38
+ def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
39
+ """Close WebSocket connection.
40
+
41
+ Args:
42
+ exc_type: Exception type if raised in context block.
43
+ exc: Exception value if raised in context block.
44
+ tb: Traceback object if raised in context block.
45
+ """
46
+
47
+ if self._conn is not None:
48
+ self._conn.close()
49
+ self._conn = None
50
+
51
+ def reset(self) -> StateStrikeObservation:
52
+ """Request environment reset.
53
+
54
+ Returns:
55
+ Initial observation.
56
+
57
+ Raises:
58
+ RuntimeError: If the server response is malformed or unsuccessful.
59
+ """
60
+
61
+ frame = self._request({"method": "reset"})
62
+ return StateStrikeObservation.model_validate(frame["observation"])
63
+
64
+ def step(self, action: StateStrikeAction) -> StateStrikeObservation:
65
+ """Execute one environment step.
66
+
67
+ Args:
68
+ action: Action payload.
69
+
70
+ Returns:
71
+ Updated observation.
72
+
73
+ Raises:
74
+ RuntimeError: If the server response is malformed or unsuccessful.
75
+ """
76
+
77
+ frame = self._request({"method": "step", "action": action.model_dump()})
78
+ return StateStrikeObservation.model_validate(frame["observation"])
79
+
80
+ def state(self) -> StateStrikeState:
81
+ """Retrieve current environment state.
82
+
83
+ Returns:
84
+ Current state model.
85
+
86
+ Raises:
87
+ RuntimeError: If the server response is malformed or unsuccessful.
88
+ """
89
+
90
+ frame = self._request({"method": "state"})
91
+ return StateStrikeState.model_validate(frame["state"])
92
+
93
+ def _request(self, payload: dict[str, Any]) -> dict[str, Any]:
94
+ """Send request frame and parse server response.
95
+
96
+ Args:
97
+ payload: JSON-serializable request payload.
98
+
99
+ Returns:
100
+ Parsed response object.
101
+
102
+ Raises:
103
+ RuntimeError: If connection is closed or server reports failure.
104
+ """
105
+
106
+ if self._conn is None:
107
+ raise RuntimeError("WebSocket connection is not open")
108
+
109
+ self._conn.send(json.dumps(payload))
110
+ raw = self._conn.recv()
111
+ frame = json.loads(raw)
112
+ if not frame.get("ok"):
113
+ raise RuntimeError(frame.get("error", "Unknown server error"))
114
+ return frame
115
+
116
+
117
class StateStrikeEnv:
    """Entry point that hands out clients for a StateStrike environment server.

    The object only remembers the endpoint; no connection is made until a
    client obtained from :meth:`sync` is entered as a context manager.
    """

    def __init__(self, base_url: str = "ws://localhost:8001/ws") -> None:
        """Remember the endpoint used by clients created later.

        Args:
            base_url: Environment WebSocket endpoint.
        """

        self.base_url = base_url

    def sync(self) -> _SyncStateStrikeClient:
        """Build a blocking client bound to the stored endpoint.

        Returns:
            A synchronous environment client implementing reset/step/state.
        """

        client = _SyncStateStrikeClient(self.base_url)
        return client
137
+
138
+
139
+ __all__ = ["StateStrikeEnv", "StateStrikeAction", "StateStrikeObservation", "StateStrikeState"]
statestrike_env/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (6.4 kB). View file
 
statestrike_env/__pycache__/constants.cpython-313.pyc ADDED
Binary file (1.85 kB). View file
 
statestrike_env/__pycache__/models.cpython-311.pyc ADDED
Binary file (4.42 kB). View file
 
statestrike_env/__pycache__/server.cpython-313.pyc ADDED
Binary file (24.3 kB). View file
 
statestrike_env/constants.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Centralized constants for StateStrike environment and reward grading.
4
+
5
+ Theory:
6
+ Consolidating reward and episode hyperparameters avoids hidden magic numbers,
7
+ supports reproducibility, and aligns with RL experiment hygiene guidance from
8
+ Sutton & Barto (2018).
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+
13
+
14
# Reward weights for R_t = ALPHA*log(L_t/L_base) + BETA*S_t + GAMMA*E_t - DELTA*P_t.
ALPHA = 1.0
BETA = 10.0
GAMMA = 500.0
DELTA = 1.0

# Latency thresholds (milliseconds) used to recognise exploitation.
REDOS_LATENCY_THRESHOLD = 1500.0
DB_TIMEOUT_THRESHOLD = 3000.0

# State-chain bonus gating.
CHAIN_REQUIRED_ORDERS = 20
CHAIN_COOLDOWN_STEPS = 10

# Baseline latency tracking.
BASELINE_WINDOW = 10
DEFAULT_BASELINE_LATENCY_MS = 50.0

# Episode and request limits.
EPISODE_LENGTH = 200
EARLY_TERMINATION_REWARD = -200.0
ACTION_TIMEOUT_SECONDS = 8.0
MAX_ACTION_HISTORY = 20


@dataclass(frozen=True)
class RewardConstants:
    """Immutable bundle of the reward hyperparameters handed to the grader.

    Every field defaults to the module-level constant of the same name, so
    ``RewardConstants()`` reproduces the stock configuration while individual
    values can be overridden per instance.

    Attributes:
        ALPHA: Latency reward weight.
        BETA: State-chain completion bonus.
        GAMMA: Exploitation bounty for severe degradation/failure.
        DELTA: Penalty magnitude for low-value fuzzing requests.
        REDOS_LATENCY_THRESHOLD: Latency threshold used to infer ReDoS impact.
        DB_TIMEOUT_THRESHOLD: Latency threshold used for DB timeout exploitation.
        CHAIN_REQUIRED_ORDERS: Minimum order count before GET /orders chain bonus.
        CHAIN_COOLDOWN_STEPS: Minimum steps between chain bonus awards.
        EARLY_TERMINATION_REWARD: Episode early-stop reward floor.
        BASELINE_WINDOW: EMA window used for baseline latency updates.
    """

    # The right-hand sides resolve to the module-level constants above because
    # the class namespace does not contain these names yet when they are read.
    ALPHA: float = ALPHA
    BETA: float = BETA
    GAMMA: float = GAMMA
    DELTA: float = DELTA
    REDOS_LATENCY_THRESHOLD: float = REDOS_LATENCY_THRESHOLD
    DB_TIMEOUT_THRESHOLD: float = DB_TIMEOUT_THRESHOLD
    CHAIN_REQUIRED_ORDERS: int = CHAIN_REQUIRED_ORDERS
    CHAIN_COOLDOWN_STEPS: int = CHAIN_COOLDOWN_STEPS
    EARLY_TERMINATION_REWARD: float = EARLY_TERMINATION_REWARD
    BASELINE_WINDOW: int = BASELINE_WINDOW
statestrike_env/grader.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Reward grading logic for StateStrike.
4
+
5
+ Theory:
6
+ The reward function follows standard MDP shaping principles from Sutton &
7
+ Barto (2018): combine dense shaping signals (latency ratio), sparse goal
8
+ rewards (exploit bounty), and penalties (invalid spam suppression). It also
9
+ borrows stateful-sequence ideas from RESTler (Atlidakis et al., ICSE 2019)
10
+ while rewarding infrastructure effects (e.g., ReDoS latency spikes) inspired
11
+ by Davis et al. (USENIX Security 2018).
12
+ """
13
+
14
+ import logging
15
+ import math
16
+ from typing import TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ from statestrike_env.constants import RewardConstants
20
+ from statestrike_env.models import StateStrikeObservation
21
+ from statestrike_env.session import StateStrikeSession
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
def compute_reward(
    obs: "StateStrikeObservation",
    session: "StateStrikeSession",
    constants: "RewardConstants",
) -> tuple[float, dict[str, float]]:
    """
    Compute R_t = α·log(L_t/L_base) + β·S_t + γ·E_t − δ·P_t

    Theory (Sutton & Barto, 2018, Ch. 3 — Finite MDPs):
        The reward signal must be designed so the ONLY way to maximize cumulative
        reward is to achieve the TRUE objective. Each term is chosen to prevent a
        specific reward-hacking strategy:

    TERM 1 — α·log(L_t/L_base): Logarithmic latency reward.
        Why log? Linear reward incentivizes the agent to find ONE massive spike
        and repeat it. Logarithmic reward gives diminishing returns per repeated
        exploitation, pushing the agent to discover NEW vulnerabilities.
        Why ratio? Prevents baseline-anchoring attacks where agent engineers a
        low baseline then makes normal requests look like spikes.
        L_base is the rolling baseline as it stood BEFORE this step; the current
        sample is folded into the baseline only after the ratio is taken.
        Otherwise a spike would inflate its own baseline and the ratio could
        never exceed (BASELINE_WINDOW + 1) / 2.
        Anti-hack: baseline ONLY updates from successful (latency>0) steps.

    TERM 2 — β·S_t: State-chain bonus.
        Fires at most once per CHAIN_COOLDOWN_STEPS steps, and only if
        order_count has increased since the last award. This prevents the
        POST→GET farming loop that would otherwise yield a free bonus.
        Anti-hack: last_chain_bonus_step and post_count_at_last_chain guards.

    TERM 3 — γ·E_t: Exploitation bounty.
        Fires EXACTLY ONCE per vulnerability type per episode (one-time flag).
        Without this, an agent discovering db_degradation would spam GET /orders
        for the full bounty every step. The one-time award correctly signals
        "you found it" without incentivizing repeated triggering.
        Anti-hack: redos_bounty_awarded and db_degradation_bounty_awarded flags.

    TERM 4 — δ·P_t: Fuzzing penalty.
        Applied only to genuinely fast 400s (latency < 100ms), not to slow 400s
        (which may indicate actual CPU burn from ReDoS parsing).
        Threshold tightened from 200ms to 100ms to avoid penalizing legitimate
        slow-failing payloads.
        Anti-hack: latency threshold ensures ReDoS probes are not penalized.

    Reference:
        - Sutton & Barto (2018): reward shaping and sparse reward design
        - Atlidakis et al. (ICSE 2019): stateful API exploration objectives
        - Davis et al. (USENIX 2018): ReDoS computational complexity

    Args:
        obs: The observation from the current step.
        session: The mutable session state (modified in-place for flags).
        constants: Reward weight constants from constants.py.

    Returns:
        Tuple of (scalar_reward, breakdown_dict) where breakdown_dict
        contains each term's contribution for telemetry and dashboard display.
    """

    reward = 0.0
    breakdown: dict[str, float] = {
        "latency_reward": 0.0,
        "chain_bonus": 0.0,
        "exploit_bounty": 0.0,
        "fuzz_penalty": 0.0,
        "total": 0.0,
    }

    # Guard: connection failure -> neutral observation, no reward signal.
    if obs.http_status == 0 or obs.latency_ms == 0.0:
        breakdown["error"] = 1.0
        logger.debug("Step %d: connection failure, returning zero reward", obs.step)
        return 0.0, breakdown

    # TERM 1: Logarithmic latency reward, measured against the baseline that
    # existed before this sample was observed.
    prior_baseline = session.baseline_latency
    if prior_baseline > 0:
        latency_ratio = obs.latency_ms / prior_baseline
        # Clamp so a single outlier cannot dominate in either direction.
        latency_ratio = max(0.01, min(latency_ratio, 100.0))
        latency_reward = constants.ALPHA * math.log(latency_ratio)
        reward += latency_reward
        breakdown["latency_reward"] = round(latency_reward, 4)
        logger.debug(
            "Step %d: latency=%.1fms baseline=%.1fms ratio=%.2f reward=%.3f",
            obs.step,
            obs.latency_ms,
            prior_baseline,
            latency_ratio,
            latency_reward,
        )

    # Update rolling baseline only from successful steps, after scoring.
    _update_baseline(session, obs.latency_ms, constants.BASELINE_WINDOW)

    # TERM 2: State-chain bonus (anti-farming guards).
    chain_bonus = 0.0
    if _should_award_chain_bonus(obs, session, constants):
        chain_bonus = constants.BETA
        session.last_chain_bonus_step = obs.step
        session.post_count_at_last_chain = session.order_count
        logger.info(
            "Step %d: Chain bonus awarded (+%.1f). order_count=%d",
            obs.step,
            chain_bonus,
            session.order_count,
        )
    reward += chain_bonus
    breakdown["chain_bonus"] = chain_bonus

    # TERM 3: Exploitation bounties (one-time per episode).
    exploit_bounty = 0.0

    if (
        not session.db_degradation_bounty_awarded
        and (obs.http_status >= 500 or obs.latency_ms > constants.DB_TIMEOUT_THRESHOLD)
        and obs.action_taken.action_type.value == "get_orders"
    ):
        exploit_bounty += constants.GAMMA
        session.db_degradation_bounty_awarded = True
        session.triggered_vulns.add("db_degradation")
        logger.info(
            "Step %d: DB_DEGRADATION bounty awarded (+%.1f). latency=%.1fms",
            obs.step,
            constants.GAMMA,
            obs.latency_ms,
        )

    if (
        not session.redos_bounty_awarded
        and obs.latency_ms > constants.REDOS_LATENCY_THRESHOLD
        and obs.http_status == 400
        and obs.action_taken.action_type.value == "post_user"
    ):
        # The ReDoS bounty is worth 80% of the DB degradation bounty.
        redos_bounty = constants.GAMMA * 0.8
        exploit_bounty += redos_bounty
        session.redos_bounty_awarded = True
        session.triggered_vulns.add("redos")
        logger.info(
            "Step %d: REDOS bounty awarded (+%.1f). latency=%.1fms",
            obs.step,
            redos_bounty,
            obs.latency_ms,
        )

    reward += exploit_bounty
    breakdown["exploit_bounty"] = round(exploit_bounty, 4)

    # TERM 4: Fuzzing penalty (only genuine fast-fail syntax errors).
    fuzz_penalty = 0.0
    if obs.http_status == 400 and obs.latency_ms < 100.0:
        fuzz_penalty = -constants.DELTA
        logger.debug("Step %d: Fuzz penalty applied (fast 400, %.1fms)", obs.step, obs.latency_ms)
    reward += fuzz_penalty
    breakdown["fuzz_penalty"] = round(fuzz_penalty, 4)

    breakdown["total"] = round(reward, 4)
    return reward, breakdown
179
+
180
+
181
def _update_baseline(session: "StateStrikeSession", latency_ms: float, window: int) -> None:
    """Fold one latency sample into the session's rolling baseline.

    The very first sample (``baseline_sample_count == 0``) replaces the
    baseline outright. Every later sample is blended in with an exponential
    moving average whose smoothing factor is ``2 / (window + 1)``.

    Args:
        session: Session whose ``baseline_latency`` and
            ``baseline_sample_count`` are updated in place.
        latency_ms: Latency of the sample being folded in, in milliseconds.
        window: EMA window length used to derive the smoothing factor.
    """

    smoothing = 2.0 / (window + 1)
    is_first_sample = session.baseline_sample_count == 0
    session.baseline_latency = (
        latency_ms
        if is_first_sample
        else smoothing * latency_ms + (1 - smoothing) * session.baseline_latency
    )
    session.baseline_sample_count += 1
190
+
191
+
192
def _should_award_chain_bonus(
    obs: "StateStrikeObservation",
    session: "StateStrikeSession",
    constants: "RewardConstants",
) -> bool:
    """Decide whether this step earns the state-chain bonus.

    All four conditions must hold, checked in this order:
      1. the action was ``get_orders``;
      2. the session has at least ``CHAIN_REQUIRED_ORDERS`` orders;
      3. at least ``CHAIN_COOLDOWN_STEPS`` steps have passed since the last award;
      4. the order count has grown since the last award.

    Args:
        obs: Observation for the step being graded.
        session: Session holding the order count and last-award bookkeeping.
        constants: Source of the order and cooldown thresholds.

    Returns:
        True when the bonus should be awarded, False otherwise.
    """

    # Short-circuit evaluation keeps the same check order as a guard-clause chain.
    return (
        obs.action_taken.action_type.value == "get_orders"
        and session.order_count >= constants.CHAIN_REQUIRED_ORDERS
        and obs.step - session.last_chain_bonus_step >= constants.CHAIN_COOLDOWN_STEPS
        and session.order_count > session.post_count_at_last_chain
    )
statestrike_env/models.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Typed action, observation, and state models for StateStrike.
4
+
5
+ Theory:
6
+ Explicit state/action schemas reduce ambiguity in RL interfaces and improve
7
+ reproducibility when evaluating policies across different backends.
8
+ """
9
+
10
+ from enum import Enum
11
+ from typing import Any, Optional
12
+
13
+ from pydantic import BaseModel, Field
14
+
15
+
16
class ActionType(str, Enum):
    """Closed set of endpoint operations the agent may choose from.

    Members are also ``str`` instances, so they compare equal to and serialize
    as their wire values.
    """

    # User endpoints.
    POST_USER = "post_user"
    GET_USER = "get_user"
    # Order endpoints.
    POST_ORDER = "post_order"
    GET_ORDERS = "get_orders"
    # Liveness probe.
    HEALTH_CHECK = "health_check"
24
+
25
+
26
class PayloadStrategy(str, Enum):
    """Closed set of payload shapes the fuzzing policy can request.

    Members are also ``str`` instances; note that ``REDOS_ATTACK`` has the
    wire value ``"redos"``, which differs from its member name.
    """

    # Well-formed input.
    VALID = "valid"
    # Adversarial inputs.
    REDOS_ATTACK = "redos"
    OVERSIZED = "oversized"
    MALFORMED = "malformed"
33
+
34
+
35
class StateStrikeAction(BaseModel):
    """One agent decision: which endpoint to call and how to shape the payload.

    Attributes:
        action_type: Target endpoint operation.
        payload_strategy: Payload mutation strategy.
        target_user_id: Optional user identifier override; ``None`` when omitted.
    """

    # Required: the endpoint operation to perform.
    action_type: ActionType
    # Required: how the request payload should be generated.
    payload_strategy: PayloadStrategy
    # Optional: explicit user id to target.
    target_user_id: Optional[int] = None
47
+
48
+
49
class StateStrikeObservation(BaseModel):
    """Per-step feedback the environment returns to the agent.

    Attributes:
        step: Current step index within the episode.
        action_taken: Action executed during the step.
        http_status: HTTP status code from honeypot response.
        latency_ms: End-to-end processing latency in milliseconds.
        reward: Scalar reward at this step.
        cumulative_reward: Running reward sum for the episode.
        baseline_latency_ms: Rolling latency baseline used for normalization.
        order_count: Number of POST /orders calls in this episode.
        triggered_vulns: Vulnerability labels discovered so far.
        done: Terminal signal for episode completion.
        info: Arbitrary metadata, including reward breakdown.
    """

    # What happened on this step.
    step: int
    action_taken: StateStrikeAction
    http_status: int
    latency_ms: float

    # Reward bookkeeping.
    reward: float
    cumulative_reward: float
    baseline_latency_ms: float

    # Episode progress.
    order_count: int
    triggered_vulns: list[str]
    done: bool

    # Free-form metadata; each instance gets its own empty dict by default.
    info: dict[str, Any] = Field(default_factory=dict)
77
+
78
+
79
class StateStrikeState(BaseModel):
    """Serializable snapshot of a session, as exposed by ``state()``.

    Attributes:
        session_id: Unique identifier for current environment episode.
        step_count: Number of actions executed in current session.
        cumulative_reward: Running reward sum for current session.
        order_count: Number of POST /orders calls in session.
        baseline_latency_ms: Rolling baseline latency in milliseconds.
        action_history: Most recent action history window.
        triggered_vulns: Vulnerabilities discovered in this session.
    """

    # Identity.
    session_id: str

    # Counters and running totals.
    step_count: int
    cumulative_reward: float
    order_count: int
    baseline_latency_ms: float

    # History and findings.
    action_history: list[StateStrikeAction]
    triggered_vulns: list[str]
statestrike_env/server.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """OpenEnv-style WebSocket environment server for StateStrike."""
4
+
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+ from contextlib import asynccontextmanager
11
+ from typing import Any
12
+
13
+ import httpx
14
+ from dotenv import load_dotenv
15
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
16
+ from fastapi.responses import JSONResponse
17
+
18
+ try:
19
+ import openenv_core # noqa: F401
20
+ except ImportError: # pragma: no cover - optional import for compatibility signaling.
21
+ openenv_core = None
22
+
23
+ from statestrike_env.constants import (
24
+ ACTION_TIMEOUT_SECONDS,
25
+ DEFAULT_BASELINE_LATENCY_MS,
26
+ EPISODE_LENGTH,
27
+ RewardConstants,
28
+ )
29
+ from statestrike_env.grader import compute_reward
30
+ from statestrike_env.models import ActionType, PayloadStrategy, StateStrikeAction, StateStrikeObservation, StateStrikeState
31
+ from statestrike_env.session import StateStrikeSession
32
+
33
+ load_dotenv()
34
+
35
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+ HONEYPOT_URL = os.getenv("HONEYPOT_URL", "http://localhost:8000")
39
+ HOST = os.getenv("STATESTRIKE_ENV_HOST", "0.0.0.0")
40
+ PORT = int(os.getenv("STATESTRIKE_ENV_PORT", "8001"))
41
+
42
+
43
async def wait_for_honeypot(url: str, max_wait: int = 30) -> None:
    """Block until honeypot is reachable or raise RuntimeError.

    Polls ``GET {url}/health`` with a back-off that starts at 1s, grows by 1.5x
    per attempt and is capped at 5s. The sleep between attempts never runs past
    the deadline, so failure is reported close to ``max_wait`` seconds.

    Args:
        url: Honeypot base URL; a trailing slash is tolerated.
        max_wait: Maximum wait time in seconds.

    Raises:
        RuntimeError: If honeypot is not reachable before timeout.
    """

    # get_running_loop() is the supported way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    health_url = f"{url.rstrip('/')}/health"
    delay = 1.0
    async with httpx.AsyncClient() as client:
        while loop.time() < deadline:
            try:
                response = await client.get(health_url, timeout=3.0)
            except Exception as exc:  # noqa: BLE001 - any failure just means "not ready yet".
                LOGGER.warning("Honeypot not ready (%s), retrying in %.1fs...", exc, delay)
            else:
                if response.status_code == 200:
                    LOGGER.info("Honeypot is ready at %s", url)
                    return
                LOGGER.warning(
                    "Honeypot health returned status=%s, retrying in %.1fs...",
                    response.status_code,
                    delay,
                )

            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            # Never sleep beyond the deadline.
            await asyncio.sleep(min(delay, remaining))
            delay = min(delay * 1.5, 5.0)

    raise RuntimeError(f"Honeypot at {url} did not become ready within {max_wait}s")
75
+
76
+
77
class StateStrikeEnvironment:
    """Core reset/step/state implementation.

    Theory:
        OpenEnv training loops benefit from persistent transport: WebSocket-based
        sessions amortize handshake overhead and preserve episode-local state,
        which aligns with OpenEnv architecture guidance (Burtenshaw, 2025).

    The environment object holds no episode state of its own; every method
    receives the ``StateStrikeSession`` it should read and mutate.
    """

    def __init__(self, honeypot_url: str, constants: RewardConstants | None = None) -> None:
        """Initialize environment service.

        Args:
            honeypot_url: Base URL for vulnerable honeypot API.
            constants: Optional reward constants override.
        """

        self.honeypot_url = honeypot_url.rstrip("/")
        self.constants = constants or RewardConstants()

    async def reset(self, session: StateStrikeSession) -> StateStrikeObservation:
        """Reset session and return initial observation.

        Probes ``GET /health`` and seeds the session baseline with the measured
        latency, or with ``DEFAULT_BASELINE_LATENCY_MS`` when the probe failed.

        Args:
            session: Session object tied to one client connection.

        Returns:
            Initial observation with zero reward.
        """

        status, latency_ms, _ = await self._request_honeypot("GET", "/health")
        baseline = latency_ms if latency_ms > 0 else DEFAULT_BASELINE_LATENCY_MS
        session.reset(baseline_latency=baseline)

        action = StateStrikeAction(action_type=ActionType.HEALTH_CHECK, payload_strategy=PayloadStrategy.VALID)
        obs = StateStrikeObservation(
            step=0,
            action_taken=action,
            http_status=status,
            latency_ms=latency_ms,
            reward=0.0,
            cumulative_reward=0.0,
            baseline_latency_ms=session.baseline_latency,
            order_count=0,
            triggered_vulns=[],
            done=False,
            info={"event": "reset"},
        )
        return obs

    async def step(self, session: StateStrikeSession, action: StateStrikeAction) -> StateStrikeObservation:
        """Execute one environment transition.

        Args:
            session: Session object tied to one client connection.
            action: Agent action.

        Returns:
            Updated observation with reward and terminal signal.
        """

        request_method, request_path, params, payload = self._translate_action(action, session)
        status, latency_ms, body = await self._request_honeypot(request_method, request_path, params=params, payload=payload)

        session.step_count += 1
        # order_count counts POST /orders attempts, regardless of the response.
        if action.action_type == ActionType.POST_ORDER:
            session.order_count += 1
        session.append_action(action)

        # The grader needs an observation to score, so build a pre-reward
        # snapshot first; the final observation is assembled once the grader
        # has updated the session.
        provisional = StateStrikeObservation(
            step=session.step_count,
            action_taken=action,
            http_status=status,
            latency_ms=latency_ms,
            reward=0.0,
            cumulative_reward=session.cumulative_reward,
            baseline_latency_ms=session.baseline_latency,
            order_count=session.order_count,
            triggered_vulns=sorted(session.triggered_vulns),
            done=False,
            info={"response": body},
        )

        reward, breakdown = compute_reward(provisional, session, self.constants)
        session.cumulative_reward += reward

        # The episode ends on the step budget or when reward falls below the floor.
        done = (
            session.step_count >= EPISODE_LENGTH
            or session.cumulative_reward < self.constants.EARLY_TERMINATION_REWARD
        )
        obs = StateStrikeObservation(
            step=session.step_count,
            action_taken=action,
            http_status=status,
            latency_ms=latency_ms,
            reward=reward,
            cumulative_reward=session.cumulative_reward,
            baseline_latency_ms=session.baseline_latency,
            order_count=session.order_count,
            triggered_vulns=sorted(session.triggered_vulns),
            done=done,
            info={"reward_breakdown": breakdown, "response": body},
        )
        return obs

    async def state(self, session: StateStrikeSession) -> StateStrikeState:
        """Return serializable state snapshot.

        Args:
            session: Session object tied to one client connection.

        Returns:
            Current state model.
        """

        return session.as_state()

    def _translate_action(
        self,
        action: StateStrikeAction,
        session: StateStrikeSession,
    ) -> tuple[str, str, dict[str, Any] | None, dict[str, Any] | None]:
        """Translate action schema into honeypot HTTP request details.

        Args:
            action: Agent action.
            session: Session used for contextual defaults (currently unused).

        Returns:
            Tuple of method, path, query params, and JSON payload.
        """

        # Default to user 1 only when no override was supplied; an explicit 0
        # is a legitimate probe value and must not be rewritten.
        target_user_id = 1 if action.target_user_id is None else action.target_user_id

        if action.action_type == ActionType.POST_USER:
            email = self._payload_email(action.payload_strategy)
            return "POST", "/users", None, {"email": email}
        if action.action_type == ActionType.GET_USER:
            return "GET", f"/users/{target_user_id}", None, None
        if action.action_type == ActionType.POST_ORDER:
            item = self._payload_item(action.payload_strategy)
            return "POST", "/orders", None, {"user_id": target_user_id, "item": item}
        if action.action_type == ActionType.GET_ORDERS:
            return "GET", "/orders", {"user_id": target_user_id}, None
        # Anything else (HEALTH_CHECK) maps to the liveness probe.
        return "GET", "/health", None, None

    @staticmethod
    def _payload_email(strategy: PayloadStrategy) -> str:
        """Build email-like payload for POST /users action.

        Args:
            strategy: Payload strategy enum.

        Returns:
            Strategy-specific string payload.
        """

        if strategy == PayloadStrategy.REDOS_ATTACK:
            # A long run of one character followed by a non-matching character.
            return "a" * 39 + "!"
        if strategy == PayloadStrategy.OVERSIZED:
            return "A" * 4096
        if strategy == PayloadStrategy.MALFORMED:
            return "@@@"
        return "validuser123"

    @staticmethod
    def _payload_item(strategy: PayloadStrategy) -> str:
        """Build order item payload.

        Args:
            strategy: Payload strategy enum.

        Returns:
            Strategy-specific order item string.
        """

        if strategy == PayloadStrategy.OVERSIZED:
            return "item_" + ("X" * 2048)
        if strategy == PayloadStrategy.MALFORMED:
            return ""
        return "standard_item"

    async def _request_honeypot(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        payload: dict[str, Any] | None = None,
    ) -> tuple[int, float, dict[str, Any]]:
        """Execute honeypot request and normalize response metadata.

        Transport failures (including timeouts) are reported as status 0 and
        latency 0.0 with a synthetic error body. Once a response has been
        received its real status code is always preserved: an unparsable body
        or latency header degrades to a fallback instead of turning the
        response into a connection failure.

        Args:
            method: HTTP method.
            path: Relative path.
            params: Optional query parameters.
            payload: Optional JSON body.

        Returns:
            Tuple of status code, latency milliseconds, and parsed response body.
            The body is whatever JSON value the honeypot returned; ``{}`` for an
            empty body, or a synthetic ``{"raw": ...}`` dict for non-JSON bodies.
        """

        url = f"{self.honeypot_url}{path}"
        started = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=ACTION_TIMEOUT_SECONDS) as client:
                response = await client.request(method, url, params=params, json=payload)
                elapsed_ms = (time.perf_counter() - started) * 1000.0
        except httpx.RequestError as exc:
            # NOTE(review): timeouts land here too, so a request slower than
            # ACTION_TIMEOUT_SECONDS is graded as a connection failure.
            LOGGER.warning("Honeypot request failed method=%s path=%s error=%s", method, path, exc)
            return 0, 0.0, {"error": str(exc), "synthetic": True}

        # Prefer the server-reported processing time; fall back to wall-clock
        # time when the header is absent, unparsable, non-finite or non-positive.
        latency_ms = elapsed_ms
        header_latency = response.headers.get("X-Process-Time-Ms")
        if header_latency:
            try:
                reported_ms = float(header_latency)
            except ValueError:
                reported_ms = 0.0
            if 0.0 < reported_ms < float("inf"):
                latency_ms = reported_ms

        if not response.content:
            body: Any = {}
        else:
            try:
                body = response.json()
            except ValueError:
                # Plain-text or HTML error pages are not JSON; keep a short
                # excerpt so the status code still reaches the grader.
                body = {"raw": response.text[:512], "synthetic": True}
        return response.status_code, latency_ms, body
292
+
293
+
294
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Gate application startup on the honeypot answering its health check.

    Waits up to 30 seconds for ``HONEYPOT_URL`` before handing control to the
    application. If the wait fails, the error raised by ``wait_for_honeypot``
    propagates. No shutdown work is performed.
    """

    startup_timeout_s = 30
    await wait_for_honeypot(HONEYPOT_URL, max_wait=startup_timeout_s)
    yield
300
+
301
+
302
+ app = FastAPI(title="StateStrike OpenEnv Server", version="1.0.0", lifespan=lifespan)
303
+ env_service = StateStrikeEnvironment(HONEYPOT_URL)
304
+ http_debug_session = StateStrikeSession.new_session()
305
+
306
+
307
+ # OpenEnv uses WebSocket (/ws) for persistent sessions rather than
308
+ # stateless HTTP. Each step() is a lightweight frame over an existing
309
+ # connection (~0.1ms overhead vs ~10-50ms TCP handshake per HTTP call).
310
+ # Reference: openenv-course module-5, burtenshaw/openenv-scaling
311
+ # This architecture enables high-frequency RL training loops.
312
@app.websocket("/ws")
async def websocket_env(websocket: WebSocket) -> None:
    """Run one isolated environment loop per WebSocket client.

    Each text frame must be a JSON object with a ``method`` of ``reset``,
    ``step`` (with an ``action`` object) or ``state``. Every request gets
    exactly one reply: ``{"ok": true, ...}`` on success or
    ``{"ok": false, "error": ...}`` for bad JSON, a non-object frame, an
    invalid action or an unknown method. Bad requests do not end the session.

    Args:
        websocket: Connected client transport.
    """

    await websocket.accept()
    session = StateStrikeSession.new_session()
    LOGGER.info("WebSocket session started session_id=%s", session.session_id)

    try:
        while True:
            frame = await websocket.receive_text()
            try:
                request = json.loads(frame)
            except json.JSONDecodeError as exc:
                await websocket.send_json({"ok": False, "error": f"Invalid JSON frame: {exc}"})
                continue
            if not isinstance(request, dict):
                await websocket.send_json({"ok": False, "error": "Request frame must be a JSON object"})
                continue

            method = request.get("method")

            if method == "reset":
                obs = await env_service.reset(session)
                await websocket.send_json({"ok": True, "observation": obs.model_dump()})
                continue

            if method == "step":
                action_payload = request.get("action", {})
                try:
                    action = StateStrikeAction.model_validate(action_payload)
                except ValueError as exc:
                    # pydantic's ValidationError is a ValueError subclass.
                    await websocket.send_json({"ok": False, "error": f"Invalid action: {exc}"})
                    continue
                obs = await env_service.step(session, action)
                await websocket.send_json({"ok": True, "observation": obs.model_dump()})
                continue

            if method == "state":
                state = await env_service.state(session)
                await websocket.send_json({"ok": True, "state": state.model_dump()})
                continue

            await websocket.send_json({"ok": False, "error": f"Unknown method: {method}"})
    except WebSocketDisconnect:
        LOGGER.info("WebSocket session ended session_id=%s", session.session_id)
350
+
351
+
352
@app.get("/reset")
async def reset_http() -> JSONResponse:
    """Reset the shared HTTP debug session and report the first observation.

    Returns:
        JSON response containing reset observation.
    """

    observation = await env_service.reset(http_debug_session)
    payload = observation.model_dump()
    return JSONResponse(payload)
362
+
363
+
364
@app.post("/step")
async def step_http(action: StateStrikeAction) -> JSONResponse:
    """Advance the shared HTTP debug session by one action.

    Args:
        action: Action payload.

    Returns:
        JSON response containing post-step observation.
    """

    observation = await env_service.step(http_debug_session, action)
    payload = observation.model_dump()
    return JSONResponse(payload)
377
+
378
+
379
@app.get("/state")
async def state_http() -> JSONResponse:
    """Report the current snapshot of the shared HTTP debug session.

    Returns:
        JSON response containing current session state.
    """

    snapshot = await env_service.state(http_debug_session)
    payload = snapshot.model_dump()
    return JSONResponse(payload)
389
+
390
+
391
def main() -> None:
    """Entrypoint for running environment server via python -m.

    The bind address defaults to ``HOST``/``PORT`` (taken from the
    ``STATESTRIKE_ENV_HOST`` / ``STATESTRIKE_ENV_PORT`` environment variables)
    and can be overridden with ``--host`` / ``--port``. The container CMD
    passes ``--port 7860``, which is now honored. Unrecognized arguments are
    ignored.
    """

    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Run the StateStrike environment server.")
    parser.add_argument("--host", default=HOST, help="Interface to bind (default: %(default)s).")
    parser.add_argument("--port", type=int, default=PORT, help="Port to bind (default: %(default)s).")
    # parse_known_args tolerates extra flags, which were silently ignored before.
    args, _unknown = parser.parse_known_args()

    uvicorn.run("statestrike_env.server:app", host=args.host, port=args.port, reload=False)
397
+
398
+
399
+ if __name__ == "__main__":
400
+ main()
statestrike_env/session.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Session state manager for per-agent environment isolation."""
4
+
5
+ from dataclasses import dataclass, field
6
+ from uuid import uuid4
7
+
8
+ from statestrike_env.constants import DEFAULT_BASELINE_LATENCY_MS, MAX_ACTION_HISTORY
9
+ from statestrike_env.models import StateStrikeAction, StateStrikeState
10
+
11
+
12
@dataclass
class StateStrikeSession:
    """Mutable per-WebSocket environment session.

    Attributes:
        session_id: Current episode UUID.
        step_count: Number of steps taken in current episode.
        cumulative_reward: Running reward total.
        order_count: Number of POST /orders actions issued.
        baseline_latency: Rolling average latency used in reward normalization.
        action_history: Most recent action history window.
        triggered_vulns: Vulnerabilities discovered in current episode.
        redos_bounty_awarded: One-time ReDoS bounty guard.
        db_degradation_bounty_awarded: One-time DB degradation bounty guard.
        last_chain_bonus_step: Last step where chain bonus was awarded.
        post_count_at_last_chain: Order count snapshot at last chain award.
        baseline_sample_count: Number of successful baseline samples seen.
    """

    session_id: str
    step_count: int = 0
    cumulative_reward: float = 0.0
    order_count: int = 0
    baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS
    action_history: list[StateStrikeAction] = field(default_factory=list)
    triggered_vulns: set[str] = field(default_factory=set)
    # Anti-hacking: one-time flags so each bounty fires exactly once per episode.
    redos_bounty_awarded: bool = False
    db_degradation_bounty_awarded: bool = False
    # Anti-hacking: chain bonus can only fire once between meaningful progress windows.
    # NOTE(review): -10 looks like a "never awarded yet" sentinel; confirm against
    # the reward code that consumes it.
    last_chain_bonus_step: int = -10
    post_count_at_last_chain: int = 0
    # Baseline integrity: updated only on successful (non-zero latency) steps.
    baseline_sample_count: int = 0

    @classmethod
    def new_session(cls) -> StateStrikeSession:
        """Create a new initialized session.

        Returns:
            Newly initialized StateStrikeSession instance with a fresh UUID
            and every other field at its default.
        """

        return cls(session_id=str(uuid4()))

    def reset(self, baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS) -> None:
        """Reset session in-place for a new episode.

        A new ``session_id`` is generated and all per-episode counters, flags
        and collections are returned to their initial values.

        Args:
            baseline_latency: Fresh baseline latency in milliseconds. A
                positive value counts as the first baseline sample; otherwise
                the next recorded latency replaces the baseline outright.
        """

        self.session_id = str(uuid4())
        self.step_count = 0
        self.cumulative_reward = 0.0
        self.order_count = 0
        self.baseline_latency = baseline_latency
        self.action_history.clear()
        self.triggered_vulns.clear()
        self.redos_bounty_awarded = False
        self.db_degradation_bounty_awarded = False
        self.last_chain_bonus_step = -10
        self.post_count_at_last_chain = 0
        self.baseline_sample_count = 1 if baseline_latency > 0 else 0

    def record_latency(self, latency_ms: float) -> float:
        """Update baseline latency using EMA from successful samples.

        Samples that are zero, negative, NaN or infinite are not successful
        measurements; they are ignored so they can neither drag the baseline
        toward the 1 ms floor nor poison it with a non-finite value.

        Args:
            latency_ms: Observed latency for the current step.

        Returns:
            Updated baseline latency (unchanged when the sample is ignored).
        """

        # The chained comparison is False for NaN, so NaN is rejected as well.
        if not (0.0 < latency_ms < float("inf")):
            return self.baseline_latency

        # Floor sub-millisecond readings at 1 ms.
        sample = max(latency_ms, 1.0)
        # Smoothing factor in the 2 / (N + 1) form with N = 10.
        alpha_ema = 2.0 / (10 + 1)
        if self.baseline_sample_count == 0:
            # No prior sample: seed the baseline directly.
            self.baseline_latency = sample
        else:
            self.baseline_latency = alpha_ema * sample + (1 - alpha_ema) * self.baseline_latency
        self.baseline_sample_count += 1
        return self.baseline_latency

    def append_action(self, action: StateStrikeAction) -> None:
        """Append action while enforcing history length constraints.

        Args:
            action: Action to append.
        """

        self.action_history.append(action)
        # Drop every surplus entry from the front, not just one, so a history
        # that was already over the cap is brought back within it.
        overflow = len(self.action_history) - MAX_ACTION_HISTORY
        if overflow > 0:
            del self.action_history[:overflow]

    def as_state(self) -> StateStrikeState:
        """Convert mutable session internals to external state model.

        Returns:
            State model holding a copy of the action history and the
            triggered vulnerabilities as a sorted list.
        """

        return StateStrikeState(
            session_id=self.session_id,
            step_count=self.step_count,
            cumulative_reward=self.cumulative_reward,
            order_count=self.order_count,
            baseline_latency_ms=self.baseline_latency,
            action_history=list(self.action_history),
            triggered_vulns=sorted(self.triggered_vulns),
        )