sh4shv4t commited on
Commit
21cee38
·
1 Parent(s): ca72cb2

feat: unified single-container deployment with all 3 tasks + inference.py

Browse files
Dockerfile CHANGED
@@ -1,7 +1,22 @@
1
- FROM python:3.11-slim
 
2
  WORKDIR /app
 
 
 
 
 
 
3
  COPY . .
4
- RUN pip install openenv-core fastapi uvicorn httpx python-dotenv
5
- ENV HONEYPOT_URL="https://sh4shv4t-statestrike-honeypot.hf.space"
 
 
 
 
6
  EXPOSE 7860
7
- CMD ["python", "-m", "statestrike_env.server", "--port", "7860"]
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
  WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
6
+
7
+ COPY requirements.txt .
8
+ RUN pip install --no-cache-dir -r requirements.txt
9
+
10
  COPY . .
11
+
12
+ RUN touch /app/statestrike.db
13
+
14
+ COPY scripts/start.sh /start.sh
15
+ RUN chmod +x /start.sh
16
+
17
  EXPOSE 7860
18
+
19
+ HEALTHCHECK --interval=10s --timeout=5s --start-period=20s --retries=5 \
20
+ CMD curl -f http://localhost:7860/health || exit 1
21
+
22
+ CMD ["/start.sh"]
honeypot/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: StateStrike Environment
3
+ emoji: 🎯
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: true
8
+ tags: ["openenv", "hackathon", "security", "rl"]
9
+ ---
honeypot/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """StateStrike vulnerable honeypot API package."""
honeypot/app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """FastAPI honeypot target for stateful API fuzzing experiments."""
4
+
5
+ import logging
6
+ import re
7
+ import time
8
+ from datetime import datetime, timezone
9
+
10
+ from fastapi import Depends, FastAPI, HTTPException, Query
11
+ from pydantic import BaseModel, Field
12
+ from sqlalchemy.orm import Session
13
+
14
+ from honeypot.database import get_db, init_db
15
+ from honeypot.middleware import TelemetryMiddleware, create_telemetry_router
16
+ from honeypot.models import Order, User
17
+
18
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
19
+ LOGGER = logging.getLogger(__name__)
20
+
21
+ app = FastAPI(title="StateStrike Honeypot", version="1.0.0")
22
+ app.add_middleware(TelemetryMiddleware)
23
+ app.include_router(create_telemetry_router())
24
+
25
+
26
+ class UserCreate(BaseModel):
27
+ """Payload for POST /users."""
28
+
29
+ email: str = Field(min_length=1, max_length=256)
30
+
31
+
32
+ class OrderCreate(BaseModel):
33
+ """Payload for POST /orders."""
34
+
35
+ user_id: int
36
+ item: str = Field(min_length=1, max_length=256)
37
+
38
+
39
+ @app.on_event("startup")
40
+ async def on_startup() -> None:
41
+ """Initialize database tables at service startup."""
42
+
43
+ init_db()
44
+
45
+
46
+ @app.get("/health")
47
+ def health_check() -> dict[str, object]:
48
+ """Return liveness and timestamp information.
49
+
50
+ Returns:
51
+ Dictionary with service status and UNIX timestamp.
52
+ """
53
+
54
+ return {"status": "ok", "ts": int(time.time())}
55
+
56
+
57
+ @app.post("/users")
58
+ def create_user(payload: UserCreate, db: Session = Depends(get_db)) -> dict[str, object]:
59
+ """Create a user with intentionally vulnerable regex validation.
60
+
61
+ Args:
62
+ payload: User creation body.
63
+ db: SQLAlchemy session.
64
+
65
+ Returns:
66
+ Created user dictionary.
67
+
68
+ Raises:
69
+ HTTPException: If email validation fails.
70
+ """
71
+
72
+ pattern = r"^([a-zA-Z0-9]+\s?)*[a-zA-Z0-9]+$"
73
+
74
+ # VULNERABILITY: ReDoS via catastrophic backtracking
75
+ # Reference: Davis et al., "ReDoS in the Wild" (USENIX Security 2018)
76
+ # This pattern exhibits O(2^n) backtracking on input "aaa...a!"
77
+ # A production-hardened alternative would use: re2 or a finite automaton
78
+ if not re.fullmatch(pattern, payload.email, flags=re.DOTALL):
79
+ raise HTTPException(status_code=400, detail="Invalid email format")
80
+
81
+ user = User(email=payload.email)
82
+ db.add(user)
83
+ db.commit()
84
+ db.refresh(user)
85
+
86
+ return {"id": user.id, "email": user.email, "created_at": user.created_at.isoformat()}
87
+
88
+
89
+ @app.get("/users/{user_id}")
90
+ def get_user(user_id: int, db: Session = Depends(get_db)) -> dict[str, object]:
91
+ """Fetch user by identifier.
92
+
93
+ Args:
94
+ user_id: User identifier.
95
+ db: SQLAlchemy session.
96
+
97
+ Returns:
98
+ User dictionary.
99
+
100
+ Raises:
101
+ HTTPException: If user does not exist.
102
+ """
103
+
104
+ user = db.query(User).filter(User.id == user_id).first()
105
+ if user is None:
106
+ raise HTTPException(status_code=404, detail="User not found")
107
+ return {"id": user.id, "email": user.email, "created_at": user.created_at.isoformat()}
108
+
109
+
110
+ @app.post("/orders")
111
+ def create_order(payload: OrderCreate, db: Session = Depends(get_db)) -> dict[str, object]:
112
+ """Create an order for an existing user.
113
+
114
+ Args:
115
+ payload: Order creation body.
116
+ db: SQLAlchemy session.
117
+
118
+ Returns:
119
+ Created order dictionary.
120
+
121
+ Raises:
122
+ HTTPException: If user does not exist.
123
+ """
124
+
125
+ user = db.query(User).filter(User.id == payload.user_id).first()
126
+ if user is None:
127
+ raise HTTPException(status_code=404, detail="User not found")
128
+
129
+ order = Order(user_id=payload.user_id, item=payload.item)
130
+ db.add(order)
131
+ db.commit()
132
+ db.refresh(order)
133
+
134
+ return {
135
+ "id": order.id,
136
+ "user_id": order.user_id,
137
+ "item": order.item,
138
+ "created_at": order.created_at.isoformat(),
139
+ }
140
+
141
+
142
+ @app.get("/orders")
143
+ def list_orders(
144
+ user_id: int | None = Query(default=None),
145
+ db: Session = Depends(get_db),
146
+ ) -> dict[str, object]:
147
+ """List orders and expose intentional stateful degradation path.
148
+
149
+ Args:
150
+ user_id: Optional user filter and degradation trigger key.
151
+ db: SQLAlchemy session.
152
+
153
+ Returns:
154
+ Order list payload with count metadata.
155
+ """
156
+
157
+ query = db.query(Order)
158
+ if user_id is not None:
159
+ query = query.filter(Order.user_id == user_id)
160
+
161
+ orders = query.all()
162
+
163
+ if user_id is not None and len(orders) > 20:
164
+ # VULNERABILITY: Unindexed aggregate query degradation
165
+ # Only reachable after stateful chain: 21x POST /orders -> GET /orders
166
+ # An RL agent can discover this; a stateless fuzzer cannot.
167
+ # Reference: RESTler (Atlidakis et al., ICSE 2019) pioneered stateful
168
+ # REST fuzzing but used grammar-based, not RL-based exploration.
169
+ all_orders = db.query(Order).all()
170
+ expensive_aggregate: dict[int, int] = {}
171
+ for left in all_orders:
172
+ total = 0
173
+ for right in all_orders:
174
+ if left.user_id == right.user_id:
175
+ total += 1
176
+ expensive_aggregate[left.user_id] = total
177
+ LOGGER.info(
178
+ "Triggered synthetic O(n^2) aggregate for user_id=%s with %s total rows",
179
+ user_id,
180
+ len(all_orders),
181
+ )
182
+ time.sleep(0.8)
183
+
184
+ return {
185
+ "count": len(orders),
186
+ "orders": [
187
+ {
188
+ "id": order.id,
189
+ "user_id": order.user_id,
190
+ "item": order.item,
191
+ "created_at": order.created_at.isoformat()
192
+ if isinstance(order.created_at, datetime)
193
+ else datetime.now(timezone.utc).isoformat(),
194
+ }
195
+ for order in orders
196
+ ],
197
+ }
honeypot/database.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Database setup for the StateStrike honeypot.
4
+
5
+ Theory:
6
+ A local SQLite backend keeps the demo deterministic and lightweight while
7
+ preserving enough statefulness for multi-step fuzzing trajectories.
8
+ """
9
+
10
+ import logging
11
+ import os
12
+ from collections.abc import Generator
13
+
14
+ from dotenv import load_dotenv
15
+ from sqlalchemy import create_engine
16
+ from sqlalchemy.orm import Session, declarative_base, sessionmaker
17
+
18
+ load_dotenv()
19
+
20
+ LOGGER = logging.getLogger(__name__)
21
+
22
+ DATABASE_FILE = os.getenv("DATABASE_FILE", "statestrike.db")
23
+ DATABASE_URL = f"sqlite:///{DATABASE_FILE}"
24
+
25
+ engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
26
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
27
+ Base = declarative_base()
28
+
29
+
30
+ def get_db() -> Generator[Session, None, None]:
31
+ """Yield a SQLAlchemy session for request-scoped DB access.
32
+
33
+ Yields:
34
+ An open SQLAlchemy Session object.
35
+
36
+ Raises:
37
+ RuntimeError: If session creation fails.
38
+ """
39
+
40
+ db = SessionLocal()
41
+ try:
42
+ yield db
43
+ finally:
44
+ db.close()
45
+
46
+
47
+ def init_db() -> None:
48
+ """Create database schema if tables do not yet exist.
49
+
50
+ Raises:
51
+ Exception: Propagates SQLAlchemy creation errors.
52
+ """
53
+
54
+ from honeypot import models # Local import avoids circular import at module load.
55
+
56
+ Base.metadata.create_all(bind=engine)
57
+ LOGGER.info("Initialized SQLite schema at %s", DATABASE_URL)
honeypot/middleware.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Telemetry middleware and stream endpoint for honeypot observations."""
4
+
5
+ import json
6
+ import time
7
+ from collections import deque
8
+ from collections.abc import AsyncIterator
9
+ from datetime import datetime, timezone
10
+
11
+ from fastapi import APIRouter
12
+ from fastapi.responses import StreamingResponse
13
+ from starlette.middleware.base import BaseHTTPMiddleware
14
+ from starlette.requests import Request
15
+ from starlette.responses import Response
16
+
17
+ TELEMETRY_BUFFER: deque[dict[str, object]] = deque(maxlen=500)
18
+
19
+
20
+ class TelemetryMiddleware(BaseHTTPMiddleware):
21
+ """Capture request latency and expose response timing metadata.
22
+
23
+ Args:
24
+ app: Wrapped ASGI application.
25
+
26
+ Returns:
27
+ None. Middleware mutates response headers and side effects telemetry buffer.
28
+ """
29
+
30
+ async def dispatch(self, request: Request, call_next) -> Response:
31
+ """Process an incoming request and append telemetry entry.
32
+
33
+ Args:
34
+ request: Starlette request object.
35
+ call_next: Next middleware/app callable.
36
+
37
+ Returns:
38
+ The downstream response with X-Process-Time-Ms header attached.
39
+ """
40
+
41
+ start = time.perf_counter()
42
+ response = await call_next(request)
43
+ elapsed_ms = (time.perf_counter() - start) * 1000.0
44
+ response.headers["X-Process-Time-Ms"] = f"{elapsed_ms:.3f}"
45
+
46
+ TELEMETRY_BUFFER.append(
47
+ {
48
+ "ts": datetime.now(timezone.utc).isoformat(),
49
+ "path": request.url.path,
50
+ "method": request.method,
51
+ "status": response.status_code,
52
+ "latency_ms": round(elapsed_ms, 3),
53
+ }
54
+ )
55
+ return response
56
+
57
+
58
+ def create_telemetry_router() -> APIRouter:
59
+ """Create telemetry SSE routes.
60
+
61
+ Returns:
62
+ A FastAPI router exposing telemetry streaming endpoint.
63
+ """
64
+
65
+ router = APIRouter(prefix="/telemetry", tags=["telemetry"])
66
+
67
+ @router.get("/stream")
68
+ async def stream_recent_entries() -> StreamingResponse:
69
+ """Emit the latest telemetry entries over Server-Sent Events.
70
+
71
+ Returns:
72
+ StreamingResponse configured with text/event-stream media type.
73
+ """
74
+
75
+ async def event_source() -> AsyncIterator[str]:
76
+ payload = json.dumps(list(TELEMETRY_BUFFER)[-100:])
77
+ yield f"data: {payload}\n\n"
78
+
79
+ return StreamingResponse(event_source(), media_type="text/event-stream")
80
+
81
+ return router
honeypot/models.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """ORM models used by the honeypot API."""
4
+
5
+ from datetime import datetime, timezone
6
+
7
+ from sqlalchemy import DateTime, ForeignKey, Integer, String
8
+ from sqlalchemy.orm import Mapped, mapped_column, relationship
9
+
10
+ from honeypot.database import Base
11
+
12
+
13
+ class User(Base):
14
+ """User entity created by POST /users."""
15
+
16
+ __tablename__ = "users"
17
+
18
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
19
+ email: Mapped[str] = mapped_column(String(256), unique=False, nullable=False)
20
+ created_at: Mapped[datetime] = mapped_column(
21
+ DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
22
+ )
23
+
24
+ orders: Mapped[list[Order]] = relationship("Order", back_populates="user")
25
+
26
+
27
+ class Order(Base):
28
+ """Order entity created by POST /orders.
29
+
30
+ Note:
31
+ The `user_id` field intentionally has no explicit DB index to preserve the
32
+ degradation path used by the challenge.
33
+ """
34
+
35
+ __tablename__ = "orders"
36
+
37
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
38
+ user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
39
+ item: Mapped[str] = mapped_column(String(256), nullable=False)
40
+ created_at: Mapped[datetime] = mapped_column(
41
+ DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
42
+ )
43
+
44
+ user: Mapped[User] = relationship("User", back_populates="orders")
inference.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ StateStrike Inference Script
3
+ ============================
4
+ Runs an LLM agent against all 3 StateStrike tasks and emits
5
+ structured [START]/[STEP]/[END] logs for automated scoring.
6
+
7
+ Environment variables:
8
+ API_BASE_URL LLM endpoint (default: https://router.huggingface.co/v1)
9
+ MODEL_NAME Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
10
+ HF_TOKEN Hugging Face API token (required, no default)
11
+ LOCAL_IMAGE_NAME Docker image name if using from_docker_image()
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import json
18
+ import os
19
+ import textwrap
20
+ from typing import List, Optional
21
+
22
+ from openai import OpenAI
23
+
24
+ from statestrike_env.environment import StateStrikeEnv
25
+ from statestrike_env.models import EndpointChoice, PayloadStrategy, StateStrikeAction
26
+
27
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
28
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
29
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
30
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
31
+ BENCHMARK = "statestrike"
32
+
33
+
34
+ def log_start(task: str, env: str, model: str) -> None:
35
+ print(f"[START] task={task} env={env} model={model}", flush=True)
36
+
37
+
38
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
39
+ error_val = error if error else "null"
40
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
41
+
42
+
43
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
44
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
45
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
46
+
47
+
48
+ TASK_CONFIGS = {
49
+ "endpoint_discovery": {"max_steps": 10, "success_threshold": 0.6},
50
+ "vulnerability_probe": {"max_steps": 15, "success_threshold": 0.5},
51
+ "exploit_chain": {"max_steps": 30, "success_threshold": 0.75},
52
+ }
53
+
54
+
55
+ SYSTEM_PROMPTS = {
56
+ "endpoint_discovery": textwrap.dedent(
57
+ """
58
+ You are a security agent probing an API to discover all available endpoints.
59
+ Known endpoints to try: POST /users, GET /users/{id}, POST /orders, GET /orders, GET /health
60
+ At each step, choose ONE endpoint to call and a payload strategy.
61
+ Reply with a JSON object: {"endpoint": "POST /users", "payload_strategy": "valid", "target_user_id": null}
62
+ Try each endpoint at least once. Vary payload strategies to confirm reachability.
63
+ """
64
+ ).strip(),
65
+ "vulnerability_probe": textwrap.dedent(
66
+ """
67
+ You are a security agent looking for vulnerabilities in an API.
68
+ Two vulnerabilities exist:
69
+ 1. ReDoS: POST /users with payload_strategy "redos" causes CPU spike (latency > 1500ms, returns 400)
70
+ 2. DB Degradation: GET /orders after 20+ orders causes slow query (latency > 800ms)
71
+ Reply with JSON: {"endpoint": "POST /users", "payload_strategy": "redos", "target_user_id": null}
72
+ Valid endpoint values: "POST /users", "GET /users/{id}", "POST /orders", "GET /orders", "GET /health"
73
+ Valid strategies: "valid", "redos", "oversized", "malformed"
74
+ """
75
+ ).strip(),
76
+ "exploit_chain": textwrap.dedent(
77
+ """
78
+ You are a security agent executing a full exploit chain. Complete these in order:
79
+ 1. Create a user: POST /users with payload_strategy "valid"
80
+ 2. Create 21 orders: POST /orders with payload_strategy "valid", target_user_id=<user_id from step 1>
81
+ 3. Trigger slow query: GET /orders with target_user_id set
82
+ 4. Trigger ReDoS: POST /users with payload_strategy "redos"
83
+ Track your progress. Each step builds on the previous.
84
+ Reply with JSON: {"endpoint": "POST /orders", "payload_strategy": "valid", "target_user_id": 1}
85
+ """
86
+ ).strip(),
87
+ }
88
+
89
+
90
+ def get_agent_action(
91
+ client: OpenAI,
92
+ task_name: str,
93
+ step: int,
94
+ last_obs: dict,
95
+ history: List[str],
96
+ ) -> StateStrikeAction:
97
+ system = SYSTEM_PROMPTS[task_name]
98
+ history_block = "\n".join(history[-5:]) if history else "None"
99
+ user_msg = textwrap.dedent(
100
+ f"""
101
+ Step: {step}
102
+ Last observation: {json.dumps(last_obs, indent=2)}
103
+ Recent history:
104
+ {history_block}
105
+ What is your next action? Reply with JSON only.
106
+ """
107
+ ).strip()
108
+
109
+ fallback = StateStrikeAction(
110
+ endpoint=EndpointChoice.HEALTH,
111
+ payload_strategy=PayloadStrategy.VALID,
112
+ )
113
+
114
+ try:
115
+ completion = client.chat.completions.create(
116
+ model=MODEL_NAME,
117
+ messages=[
118
+ {"role": "system", "content": system},
119
+ {"role": "user", "content": user_msg},
120
+ ],
121
+ temperature=0.7,
122
+ max_tokens=100,
123
+ )
124
+ text = (completion.choices[0].message.content or "").strip()
125
+ text = text.removeprefix("```json").removesuffix("```").strip()
126
+ data = json.loads(text)
127
+ return StateStrikeAction(**data)
128
+ except Exception as exc:
129
+ print(f"[DEBUG] Action parse failed: {exc}", flush=True)
130
+ return fallback
131
+
132
+
133
+ async def run_task(
134
+ env: StateStrikeEnv,
135
+ client: OpenAI,
136
+ task_name: str,
137
+ ) -> float:
138
+ config = TASK_CONFIGS[task_name]
139
+ max_steps = config["max_steps"]
140
+ success_threshold = config["success_threshold"]
141
+
142
+ rewards: List[float] = []
143
+ steps_taken = 0
144
+ score = 0.0
145
+ success = False
146
+ history: List[str] = []
147
+
148
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
149
+
150
+ try:
151
+ result = await env.reset(task_name=task_name)
152
+ obs = result.observation
153
+ last_obs_dict = obs.model_dump()
154
+
155
+ for step in range(1, max_steps + 1):
156
+ if result.done:
157
+ break
158
+
159
+ action = get_agent_action(client, task_name, step, last_obs_dict, history)
160
+ action_str = f"{action.endpoint}+{action.payload_strategy}"
161
+
162
+ result = await env.step(action)
163
+ obs = result.observation
164
+ reward = result.reward or 0.0
165
+ done = result.done
166
+ error = result.info.get("error") if isinstance(result.info, dict) else None
167
+
168
+ rewards.append(reward)
169
+ steps_taken = step
170
+ last_obs_dict = obs.model_dump()
171
+
172
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error)
173
+ history.append(
174
+ f"Step {step}: {action_str} -> status={obs.http_status} "
175
+ f"latency={obs.latency_ms:.0f}ms reward={reward:.2f}"
176
+ )
177
+
178
+ if done:
179
+ break
180
+
181
+ score = min(max(obs.task_progress, 0.0), 1.0)
182
+ success = score >= success_threshold
183
+
184
+ except Exception as exc:
185
+ print(f"[DEBUG] Task {task_name} failed: {exc}", flush=True)
186
+ finally:
187
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
188
+
189
+ return score
190
+
191
+
192
+ async def main() -> None:
193
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
194
+
195
+ if LOCAL_IMAGE_NAME:
196
+ env = await StateStrikeEnv.from_docker_image(LOCAL_IMAGE_NAME)
197
+ else:
198
+ env = StateStrikeEnv()
199
+
200
+ scores = {}
201
+ for task_name in ["endpoint_discovery", "vulnerability_probe", "exploit_chain"]:
202
+ score = await run_task(env, client, task_name)
203
+ scores[task_name] = score
204
+
205
+ await env.close()
206
+
207
+ print(f"\n[DEBUG] Final scores: {scores}", flush=True)
208
+ avg = sum(scores.values()) / len(scores)
209
+ print(f"[DEBUG] Average score: {avg:.3f}", flush=True)
210
+
211
+
212
+ if __name__ == "__main__":
213
+ asyncio.run(main())
openenv.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: statestrike
2
+ version: "1.0.0"
3
+ description: >
4
+ A stateful API security audit environment where an agent learns to discover
5
+ real infrastructure vulnerabilities through systematic endpoint exploration
6
+ and stateful exploit chaining.
7
+ author: StateStrike Team
8
+ license: MIT
9
+ tags:
10
+ - security
11
+ - api-testing
12
+ - stateful
13
+ - openenv
14
+ tasks:
15
+ - id: endpoint_discovery
16
+ description: >
17
+ Identify all reachable API endpoints. Agent receives the base URL
18
+ and must probe systematically to discover which endpoints exist.
19
+ difficulty: easy
20
+ max_steps: 20
21
+ reward_range: [0.0, 1.0]
22
+ - id: vulnerability_probe
23
+ description: >
24
+ Identify and correctly classify at least one vulnerability in the
25
+ target API (redos or db_degradation).
26
+ difficulty: medium
27
+ max_steps: 30
28
+ reward_range: [0.0, 1.0]
29
+ - id: exploit_chain
30
+ description: >
31
+ Execute the full stateful exploit chain: create user, build order
32
+ history, trigger DB degradation slow path, and trigger ReDoS.
33
+ difficulty: hard
34
+ max_steps: 60
35
+ reward_range: [0.0, 1.0]
36
+ observation_space:
37
+ type: object
38
+ description: HTTP response details including status, latency, body, and session state
39
+ action_space:
40
+ type: object
41
+ description: HTTP action with endpoint choice, method, and payload strategy
requirements.txt CHANGED
@@ -10,5 +10,5 @@ python-dotenv==1.0.1
10
  pytest==8.2.0
11
  pytest-asyncio==0.23.7
12
  rich==13.7.1
13
- websockets==12.0
14
- portalocker==2.8.2
 
10
  pytest==8.2.0
11
  pytest-asyncio==0.23.7
12
  rich==13.7.1
13
+ websockets>=15.0
14
+ openai>=1.0.0
scripts/run_demo.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ # StateStrike Demo Launch Script
5
+ # Starts all services and runs 200-step agent demo
6
+
7
+ echo "🎯 StateStrike — OpenEnv Hackathon Demo"
8
+ echo "Starting honeypot API..."
9
+ uvicorn honeypot.app:app --port 8000 &
10
+ HONEY_PID=$!
11
+
12
+ echo "Starting OpenEnv environment server..."
13
+ python -m statestrike_env.server &
14
+ ENV_PID=$!
15
+
16
+ echo "Starting dashboard..."
17
+ streamlit run dashboard/app.py --server.port 8501 &
18
+ DASH_PID=$!
19
+
20
+ cleanup() {
21
+ kill "$HONEY_PID" "$ENV_PID" "$DASH_PID" 2>/dev/null || true
22
+ }
23
+ trap cleanup EXIT
24
+
25
+ sleep 3
26
+ echo "All services up. Running agent..."
27
+ echo "Dashboard: http://localhost:8501"
28
+ python -m agent.runner --steps 200
scripts/setup_env.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ # StateStrike bootstrap script
5
+ # Creates local env file and installs pinned dependencies.
6
+
7
+ ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
8
+ cd "$ROOT_DIR"
9
+
10
+ if [ ! -f .env ]; then
11
+ cp .env.example .env
12
+ echo "Created .env from .env.example"
13
+ fi
14
+
15
+ python -m pip install --upgrade pip
16
+ python -m pip install -r requirements.txt
17
+
18
+ # Ensure Docker bind-mount file targets exist as files.
19
+ touch statestrike.db
20
+ touch telemetry.json
21
+
22
+ echo "StateStrike environment setup complete."
scripts/start.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ echo "[StateStrike] Starting honeypot on port 8000..."
5
+ uvicorn honeypot.app:app --host 0.0.0.0 --port 8000 &
6
+ HONEYPOT_PID=$!
7
+
8
+ echo "[StateStrike] Waiting for honeypot..."
9
+ for i in $(seq 1 30); do
10
+ if curl -sf http://localhost:8000/health > /dev/null 2>&1; then
11
+ echo "[StateStrike] Honeypot ready."
12
+ break
13
+ fi
14
+ sleep 1
15
+ done
16
+
17
+ echo "[StateStrike] Starting environment server on port 7860..."
18
+ export HONEYPOT_URL="http://localhost:8000"
19
+ uvicorn statestrike_env.environment:app --host 0.0.0.0 --port 7860
20
+
21
+ wait $HONEYPOT_PID
statestrike_env/README.md CHANGED
@@ -5,4 +5,5 @@ colorFrom: red
5
  colorTo: purple
6
  sdk: docker
7
  pinned: true
 
8
  ---
 
5
  colorTo: purple
6
  sdk: docker
7
  pinned: true
8
+ tags: ["openenv", "hackathon", "security", "rl"]
9
  ---
statestrike_env/__init__.py CHANGED
@@ -1,139 +1,21 @@
1
  from __future__ import annotations
2
 
3
- """StateStrike OpenEnv-compatible client exports."""
4
-
5
- import json
6
- from contextlib import AbstractContextManager
7
- from typing import Any
8
-
9
- from websockets.sync.client import ClientConnection, connect
10
-
11
- from statestrike_env.models import StateStrikeAction, StateStrikeObservation, StateStrikeState
12
-
13
-
14
- class _SyncStateStrikeClient(AbstractContextManager["_SyncStateStrikeClient"]):
15
- """Synchronous WebSocket client wrapper for reset/step/state calls."""
16
-
17
- def __init__(self, base_url: str) -> None:
18
- """Initialize client.
19
-
20
- Args:
21
- base_url: WebSocket URL including `/ws` path.
22
- """
23
-
24
- normalized = base_url.rstrip("/")
25
- self.base_url = normalized if normalized.endswith("/ws") else f"{normalized}/ws"
26
- self._conn: ClientConnection | None = None
27
-
28
- def __enter__(self) -> "_SyncStateStrikeClient":
29
- """Open WebSocket connection for environment operations.
30
-
31
- Returns:
32
- Connected client instance.
33
- """
34
-
35
- self._conn = connect(self.base_url)
36
- return self
37
-
38
- def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
39
- """Close WebSocket connection.
40
-
41
- Args:
42
- exc_type: Exception type if raised in context block.
43
- exc: Exception value if raised in context block.
44
- tb: Traceback object if raised in context block.
45
- """
46
-
47
- if self._conn is not None:
48
- self._conn.close()
49
- self._conn = None
50
-
51
- def reset(self) -> StateStrikeObservation:
52
- """Request environment reset.
53
-
54
- Returns:
55
- Initial observation.
56
-
57
- Raises:
58
- RuntimeError: If the server response is malformed or unsuccessful.
59
- """
60
-
61
- frame = self._request({"method": "reset"})
62
- return StateStrikeObservation.model_validate(frame["observation"])
63
-
64
- def step(self, action: StateStrikeAction) -> StateStrikeObservation:
65
- """Execute one environment step.
66
-
67
- Args:
68
- action: Action payload.
69
-
70
- Returns:
71
- Updated observation.
72
-
73
- Raises:
74
- RuntimeError: If the server response is malformed or unsuccessful.
75
- """
76
-
77
- frame = self._request({"method": "step", "action": action.model_dump()})
78
- return StateStrikeObservation.model_validate(frame["observation"])
79
-
80
- def state(self) -> StateStrikeState:
81
- """Retrieve current environment state.
82
-
83
- Returns:
84
- Current state model.
85
-
86
- Raises:
87
- RuntimeError: If the server response is malformed or unsuccessful.
88
- """
89
-
90
- frame = self._request({"method": "state"})
91
- return StateStrikeState.model_validate(frame["state"])
92
-
93
- def _request(self, payload: dict[str, Any]) -> dict[str, Any]:
94
- """Send request frame and parse server response.
95
-
96
- Args:
97
- payload: JSON-serializable request payload.
98
-
99
- Returns:
100
- Parsed response object.
101
-
102
- Raises:
103
- RuntimeError: If connection is closed or server reports failure.
104
- """
105
-
106
- if self._conn is None:
107
- raise RuntimeError("WebSocket connection is not open")
108
-
109
- self._conn.send(json.dumps(payload))
110
- raw = self._conn.recv()
111
- frame = json.loads(raw)
112
- if not frame.get("ok"):
113
- raise RuntimeError(frame.get("error", "Unknown server error"))
114
- return frame
115
-
116
-
117
- class StateStrikeEnv:
118
- """Environment client namespace matching OpenEnv SDK usage patterns."""
119
-
120
- def __init__(self, base_url: str = "ws://localhost:8001/ws") -> None:
121
- """Store base URL for later sync client creation.
122
-
123
- Args:
124
- base_url: Environment WebSocket endpoint.
125
- """
126
-
127
- self.base_url = base_url
128
-
129
- def sync(self) -> _SyncStateStrikeClient:
130
- """Create synchronous context-managed client.
131
-
132
- Returns:
133
- A synchronous environment client implementing reset/step/state.
134
- """
135
-
136
- return _SyncStateStrikeClient(self.base_url)
137
-
138
-
139
- __all__ = ["StateStrikeEnv", "StateStrikeAction", "StateStrikeObservation", "StateStrikeState"]
 
1
  from __future__ import annotations
2
 
3
+ from statestrike_env.environment import StateStrikeEnv
4
+ from statestrike_env.models import (
5
+ EndpointChoice,
6
+ PayloadStrategy,
7
+ StateStrikeAction,
8
+ StateStrikeObservation,
9
+ StateStrikeState,
10
+ StepResult,
11
+ )
12
+
13
+ __all__ = [
14
+ "StateStrikeEnv",
15
+ "EndpointChoice",
16
+ "PayloadStrategy",
17
+ "StateStrikeAction",
18
+ "StateStrikeObservation",
19
+ "StateStrikeState",
20
+ "StepResult",
21
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
statestrike_env/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (6.4 kB)
 
statestrike_env/__pycache__/constants.cpython-313.pyc DELETED
Binary file (1.85 kB)
 
statestrike_env/__pycache__/models.cpython-311.pyc DELETED
Binary file (4.42 kB)
 
statestrike_env/__pycache__/server.cpython-313.pyc DELETED
Binary file (24.3 kB)
 
statestrike_env/constants.py CHANGED
@@ -1,56 +1,39 @@
1
  from __future__ import annotations
2
 
3
- """Centralized constants for StateStrike environment and reward grading.
4
-
5
- Theory:
6
- Consolidating reward and episode hyperparameters avoids hidden magic numbers,
7
- supports reproducibility, and aligns with RL experiment hygiene guidance from
8
- Sutton & Barto (2018).
9
- """
10
-
11
  from dataclasses import dataclass
12
 
13
 
14
- ALPHA = 1.0
15
- BETA = 10.0
16
- GAMMA = 500.0
17
- DELTA = 1.0
18
  BASELINE_WINDOW = 10
19
- EPISODE_LENGTH = 200
20
- REDOS_LATENCY_THRESHOLD = 1500.0
21
- DB_TIMEOUT_THRESHOLD = 3000.0
22
- CHAIN_REQUIRED_ORDERS = 20
23
- CHAIN_COOLDOWN_STEPS = 10
24
- MAX_ACTION_HISTORY = 20
25
  ACTION_TIMEOUT_SECONDS = 8.0
26
- DEFAULT_BASELINE_LATENCY_MS = 50.0
 
 
 
 
 
 
 
 
 
 
 
 
27
  EARLY_TERMINATION_REWARD = -200.0
28
 
29
 
30
  @dataclass(frozen=True)
31
  class RewardConstants:
32
- """Typed reward constants passed into the reward grader.
33
-
34
- Attributes:
35
- ALPHA: Latency reward weight.
36
- BETA: State-chain completion bonus.
37
- GAMMA: Exploitation bounty for severe degradation/failure.
38
- DELTA: Penalty magnitude for low-value fuzzing requests.
39
- REDOS_LATENCY_THRESHOLD: Latency threshold used to infer ReDoS impact.
40
- DB_TIMEOUT_THRESHOLD: Latency threshold used for DB timeout exploitation.
41
- CHAIN_REQUIRED_ORDERS: Minimum order count before GET /orders chain bonus.
42
- CHAIN_COOLDOWN_STEPS: Minimum steps between chain bonus awards.
43
- EARLY_TERMINATION_REWARD: Episode early-stop reward floor.
44
- BASELINE_WINDOW: EMA window used for baseline latency updates.
45
- """
46
-
47
- ALPHA: float = ALPHA
48
- BETA: float = BETA
49
- GAMMA: float = GAMMA
50
- DELTA: float = DELTA
51
- REDOS_LATENCY_THRESHOLD: float = REDOS_LATENCY_THRESHOLD
52
- DB_TIMEOUT_THRESHOLD: float = DB_TIMEOUT_THRESHOLD
53
  CHAIN_REQUIRED_ORDERS: int = CHAIN_REQUIRED_ORDERS
54
  CHAIN_COOLDOWN_STEPS: int = CHAIN_COOLDOWN_STEPS
 
 
 
 
 
55
  EARLY_TERMINATION_REWARD: float = EARLY_TERMINATION_REWARD
56
- BASELINE_WINDOW: int = BASELINE_WINDOW
 
1
  from __future__ import annotations
2
 
 
 
 
 
 
 
 
 
3
  from dataclasses import dataclass
4
 
5
 
6
+ DEFAULT_BASELINE_LATENCY_MS = 50.0
 
 
 
7
  BASELINE_WINDOW = 10
 
 
 
 
 
 
8
  ACTION_TIMEOUT_SECONDS = 8.0
9
+
10
+ REDOS_LATENCY_THRESHOLD_MS = 1500.0
11
+ DB_TIMEOUT_THRESHOLD_MS = 800.0
12
+
13
+ CHAIN_REQUIRED_ORDERS = 21
14
+ CHAIN_COOLDOWN_STEPS = 10
15
+
16
+ STEP_DELTA_MAX = 0.30
17
+ NEW_ENDPOINT_BONUS = 0.05
18
+ NEW_VULNERABILITY_BONUS = 0.10
19
+ REPEATED_ACTION_PENALTY = 0.02
20
+ TERMINAL_BONUS = 0.20
21
+
22
  EARLY_TERMINATION_REWARD = -200.0
23
 
24
 
25
  @dataclass(frozen=True)
26
  class RewardConstants:
27
+ DEFAULT_BASELINE_LATENCY_MS: float = DEFAULT_BASELINE_LATENCY_MS
28
+ BASELINE_WINDOW: int = BASELINE_WINDOW
29
+ ACTION_TIMEOUT_SECONDS: float = ACTION_TIMEOUT_SECONDS
30
+ REDOS_LATENCY_THRESHOLD_MS: float = REDOS_LATENCY_THRESHOLD_MS
31
+ DB_TIMEOUT_THRESHOLD_MS: float = DB_TIMEOUT_THRESHOLD_MS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  CHAIN_REQUIRED_ORDERS: int = CHAIN_REQUIRED_ORDERS
33
  CHAIN_COOLDOWN_STEPS: int = CHAIN_COOLDOWN_STEPS
34
+ STEP_DELTA_MAX: float = STEP_DELTA_MAX
35
+ NEW_ENDPOINT_BONUS: float = NEW_ENDPOINT_BONUS
36
+ NEW_VULNERABILITY_BONUS: float = NEW_VULNERABILITY_BONUS
37
+ REPEATED_ACTION_PENALTY: float = REPEATED_ACTION_PENALTY
38
+ TERMINAL_BONUS: float = TERMINAL_BONUS
39
  EARLY_TERMINATION_REWARD: float = EARLY_TERMINATION_REWARD
 
statestrike_env/environment.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import os
5
+ import subprocess
6
+ import time
7
+ from contextlib import asynccontextmanager
8
+ from typing import Any
9
+
10
+ import httpx
11
+ from fastapi import Body, FastAPI
12
+
13
+ from statestrike_env.constants import RewardConstants
14
+ from statestrike_env.grader import compute_task_reward, compute_task_score
15
+ from statestrike_env.models import (
16
+ EndpointChoice,
17
+ PayloadStrategy,
18
+ StateStrikeAction,
19
+ StateStrikeObservation,
20
+ StateStrikeState,
21
+ StepResult,
22
+ )
23
+ from statestrike_env.session import StateStrikeSession
24
+ from statestrike_env.tasks import TASK_REGISTRY
25
+
26
+
27
+ class StateStrikeEnv:
28
+ """Unified OpenEnv-compatible runtime for StateStrike."""
29
+
30
+ def __init__(
31
+ self,
32
+ honeypot_url: str | None = None,
33
+ constants: RewardConstants | None = None,
34
+ ) -> None:
35
+ self.honeypot_url = (honeypot_url or os.getenv("HONEYPOT_URL", "http://localhost:8000")).rstrip("/")
36
+ self.constants = constants or RewardConstants()
37
+ self.session = StateStrikeSession.new_session("endpoint_discovery")
38
+ self._managed_container_id: str | None = None
39
+
40
+ @classmethod
41
+ async def from_docker_image(cls, image_name: str) -> StateStrikeEnv:
42
+ env = cls(honeypot_url=os.getenv("HONEYPOT_URL", "http://localhost:8000"))
43
+ proc = await asyncio.create_subprocess_exec(
44
+ "docker",
45
+ "run",
46
+ "-d",
47
+ "-p",
48
+ "8000:8000",
49
+ image_name,
50
+ stdout=asyncio.subprocess.PIPE,
51
+ stderr=asyncio.subprocess.PIPE,
52
+ )
53
+ stdout, stderr = await proc.communicate()
54
+ if proc.returncode != 0:
55
+ raise RuntimeError(f"Failed to start docker image: {stderr.decode().strip()}")
56
+ env._managed_container_id = stdout.decode().strip()
57
+
58
+ for _ in range(30):
59
+ try:
60
+ async with httpx.AsyncClient(timeout=2.0) as client:
61
+ response = await client.get(f"{env.honeypot_url}/health")
62
+ if response.status_code == 200:
63
+ return env
64
+ except Exception:
65
+ pass
66
+ await asyncio.sleep(1)
67
+
68
+ await env.close()
69
+ raise RuntimeError("Timed out waiting for honeypot container to become ready")
70
+
71
+ async def close(self) -> None:
72
+ if self._managed_container_id:
73
+ process = await asyncio.create_subprocess_exec(
74
+ "docker",
75
+ "rm",
76
+ "-f",
77
+ self._managed_container_id,
78
+ stdout=asyncio.subprocess.PIPE,
79
+ stderr=asyncio.subprocess.PIPE,
80
+ )
81
+ await process.communicate()
82
+ self._managed_container_id = None
83
+
84
+ async def reset(self, task_name: str = "endpoint_discovery") -> StepResult:
85
+ if task_name not in TASK_REGISTRY:
86
+ task_name = "endpoint_discovery"
87
+
88
+ status, latency_ms, _ = await self._request_honeypot("GET", "/health")
89
+ baseline = latency_ms if latency_ms > 0 else self.constants.DEFAULT_BASELINE_LATENCY_MS
90
+ self.session.reset(task_name=task_name, baseline_latency=baseline)
91
+
92
+ observation = StateStrikeObservation(
93
+ step=0,
94
+ endpoint_called=EndpointChoice.HEALTH.value,
95
+ http_status=status,
96
+ latency_ms=latency_ms,
97
+ response_body={"status": "reset"},
98
+ session_order_count=0,
99
+ endpoints_discovered=[],
100
+ vulnerabilities_found=[],
101
+ task_progress=0.0,
102
+ )
103
+ return StepResult(observation=observation, reward=0.0, done=False, info={"task": task_name})
104
+
105
+ async def step(self, action: StateStrikeAction) -> StepResult:
106
+ method, path, params, payload = self._translate_action(action)
107
+ status, latency_ms, body = await self._request_honeypot(method, path, params=params, payload=payload)
108
+
109
+ endpoint_value = action.endpoint if isinstance(action.endpoint, str) else action.endpoint.value
110
+ strategy_value = (
111
+ action.payload_strategy if isinstance(action.payload_strategy, str) else action.payload_strategy.value
112
+ )
113
+
114
+ self.session.step_count += 1
115
+ if endpoint_value == EndpointChoice.POST_ORDERS.value and status in (200, 201):
116
+ self.session.order_count += 1
117
+
118
+ endpoint_name = endpoint_value
119
+ new_endpoint = False
120
+ if status > 0 and endpoint_name not in self.session.endpoints_discovered:
121
+ self.session.endpoints_discovered.add(endpoint_name)
122
+ new_endpoint = True
123
+
124
+ signature = f"{endpoint_value}|{strategy_value}|{action.target_user_id}"
125
+ repeated_action = signature == self.session.last_action_signature
126
+ self.session.last_action_signature = signature
127
+
128
+ new_vulnerability = False
129
+
130
+ if (
131
+ not self.session.redos_bounty_awarded
132
+ and endpoint_value == EndpointChoice.POST_USERS.value
133
+ and strategy_value == PayloadStrategy.REDOS_ATTACK.value
134
+ and status == 400
135
+ and latency_ms > self.constants.REDOS_LATENCY_THRESHOLD_MS
136
+ ):
137
+ self.session.redos_bounty_awarded = True
138
+ self.session.vulnerabilities_found.add("redos")
139
+ new_vulnerability = True
140
+
141
+ chain_cooldown_ready = (
142
+ self.session.step_count - self.session.last_chain_bonus_step
143
+ ) >= self.constants.CHAIN_COOLDOWN_STEPS
144
+ chain_progressed = self.session.order_count > self.session.post_count_at_last_chain
145
+ if (
146
+ not self.session.db_degradation_bounty_awarded
147
+ and endpoint_value == EndpointChoice.GET_ORDERS.value
148
+ and self.session.order_count >= self.constants.CHAIN_REQUIRED_ORDERS
149
+ and latency_ms > self.constants.DB_TIMEOUT_THRESHOLD_MS
150
+ and chain_cooldown_ready
151
+ and chain_progressed
152
+ ):
153
+ self.session.db_degradation_bounty_awarded = True
154
+ self.session.vulnerabilities_found.add("db_degradation")
155
+ self.session.last_chain_bonus_step = self.session.step_count
156
+ self.session.post_count_at_last_chain = self.session.order_count
157
+ new_vulnerability = True
158
+
159
+ if (
160
+ endpoint_value == EndpointChoice.POST_USERS.value
161
+ and strategy_value == PayloadStrategy.VALID.value
162
+ and status in (200, 201)
163
+ ):
164
+ self.session.user_created = True
165
+
166
+ self.session.steps_history.append(
167
+ {
168
+ "endpoint": endpoint_value,
169
+ "payload_strategy": strategy_value,
170
+ "target_user_id": action.target_user_id,
171
+ "http_status": status,
172
+ "latency_ms": latency_ms,
173
+ }
174
+ )
175
+ if len(self.session.steps_history) > 200:
176
+ self.session.steps_history.pop(0)
177
+
178
+ self.session.task_specific_state["new_endpoint_discovered"] = new_endpoint
179
+ self.session.task_specific_state["new_vulnerability_found"] = new_vulnerability
180
+ self.session.task_specific_state["repeated_action"] = repeated_action
181
+
182
+ task_score = compute_task_score(self.session, self.session.task_name)
183
+ observation = StateStrikeObservation(
184
+ step=self.session.step_count,
185
+ endpoint_called=endpoint_value,
186
+ http_status=status,
187
+ latency_ms=latency_ms,
188
+ response_body=body,
189
+ session_order_count=self.session.order_count,
190
+ endpoints_discovered=sorted(self.session.endpoints_discovered),
191
+ vulnerabilities_found=sorted(self.session.vulnerabilities_found),
192
+ task_progress=task_score,
193
+ )
194
+
195
+ reward, breakdown = compute_task_reward(
196
+ observation,
197
+ self.session,
198
+ self.session.task_name,
199
+ self.constants,
200
+ )
201
+ self.session.cumulative_reward += reward
202
+
203
+ task_cfg, _ = TASK_REGISTRY[self.session.task_name]
204
+ done = self.session.step_count >= task_cfg.max_steps or task_score >= task_cfg.success_threshold
205
+
206
+ return StepResult(
207
+ observation=observation,
208
+ reward=reward,
209
+ done=done,
210
+ info={
211
+ "reward_breakdown": breakdown,
212
+ "task": self.session.task_name,
213
+ },
214
+ )
215
+
216
+ async def state(self) -> StateStrikeState:
217
+ return self.session.as_state()
218
+
219
+ def reset_sync(self, task_name: str = "endpoint_discovery") -> StepResult:
220
+ return asyncio.run(self.reset(task_name=task_name))
221
+
222
+ def step_sync(self, action: StateStrikeAction) -> StepResult:
223
+ return asyncio.run(self.step(action))
224
+
225
+ def state_sync(self) -> StateStrikeState:
226
+ return asyncio.run(self.state())
227
+
228
+ def _translate_action(
229
+ self,
230
+ action: StateStrikeAction,
231
+ ) -> tuple[str, str, dict[str, Any] | None, dict[str, Any] | None]:
232
+ endpoint_value = action.endpoint if isinstance(action.endpoint, str) else action.endpoint.value
233
+ strategy_value = (
234
+ action.payload_strategy if isinstance(action.payload_strategy, str) else action.payload_strategy.value
235
+ )
236
+ target_user_id = action.target_user_id or 1
237
+
238
+ if endpoint_value == EndpointChoice.POST_USERS.value:
239
+ return "POST", "/users", None, {"email": self._email_for_strategy(strategy_value)}
240
+ if endpoint_value == EndpointChoice.GET_USER.value:
241
+ return "GET", f"/users/{target_user_id}", None, None
242
+ if endpoint_value == EndpointChoice.POST_ORDERS.value:
243
+ return "POST", "/orders", None, {
244
+ "user_id": target_user_id,
245
+ "item": self._item_for_strategy(strategy_value),
246
+ }
247
+ if endpoint_value == EndpointChoice.GET_ORDERS.value:
248
+ return "GET", "/orders", {"user_id": target_user_id}, None
249
+ return "GET", "/health", None, None
250
+
251
+ @staticmethod
252
+ def _email_for_strategy(strategy: str) -> str:
253
+ if strategy == PayloadStrategy.REDOS_ATTACK.value:
254
+ return "a" * 39 + "!"
255
+ if strategy == PayloadStrategy.OVERSIZED.value:
256
+ return "A" * 4096
257
+ if strategy == PayloadStrategy.MALFORMED.value:
258
+ return "@@@"
259
+ return "validuser123"
260
+
261
+ @staticmethod
262
+ def _item_for_strategy(strategy: str) -> str:
263
+ if strategy == PayloadStrategy.OVERSIZED.value:
264
+ return "item_" + ("X" * 2048)
265
+ if strategy == PayloadStrategy.MALFORMED.value:
266
+ return ""
267
+ return "standard_item"
268
+
269
+ async def _request_honeypot(
270
+ self,
271
+ method: str,
272
+ path: str,
273
+ *,
274
+ params: dict[str, Any] | None = None,
275
+ payload: dict[str, Any] | None = None,
276
+ ) -> tuple[int, float, dict[str, Any]]:
277
+ url = f"{self.honeypot_url}{path}"
278
+ started = time.perf_counter()
279
+ try:
280
+ async with httpx.AsyncClient(timeout=self.constants.ACTION_TIMEOUT_SECONDS) as client:
281
+ response = await client.request(method, url, params=params, json=payload)
282
+ elapsed_ms = (time.perf_counter() - started) * 1000.0
283
+ header_latency = response.headers.get("X-Process-Time-Ms")
284
+ latency_ms = float(header_latency) if header_latency else elapsed_ms
285
+ body = response.json() if response.content else {}
286
+ return response.status_code, latency_ms, body
287
+ except Exception as exc:
288
+ return 0, 0.0, {"error": str(exc), "synthetic": True}
289
+
290
+
291
+ @asynccontextmanager
292
+ async def lifespan(_: FastAPI):
293
+ yield
294
+
295
+
296
+ app = FastAPI(title="StateStrike", lifespan=lifespan)
297
+ _global_env = StateStrikeEnv()
298
+
299
+
300
+ @app.post("/reset")
301
+ async def reset_endpoint(body: dict = Body(default={})):
302
+ task = body.get("task", "endpoint_discovery")
303
+ result = await _global_env.reset(task_name=task)
304
+ return result.model_dump()
305
+
306
+
307
+ @app.post("/step")
308
+ async def step_endpoint(action: StateStrikeAction):
309
+ result = await _global_env.step(action)
310
+ return result.model_dump()
311
+
312
+
313
+ @app.get("/state")
314
+ async def state_endpoint():
315
+ return (await _global_env.state()).model_dump()
316
+
317
+
318
+ @app.get("/health")
319
+ async def health():
320
+ return {"status": "ok"}
321
+
322
+
323
+ def main() -> None:
324
+ import uvicorn
325
+
326
+ host = os.getenv("STATESTRIKE_ENV_HOST", "0.0.0.0")
327
+ port = int(os.getenv("STATESTRIKE_ENV_PORT", "7860"))
328
+ uvicorn.run("statestrike_env.environment:app", host=host, port=port, reload=False)
329
+
330
+
331
+ if __name__ == "__main__":
332
+ main()
statestrike_env/grader.py CHANGED
@@ -1,208 +1,99 @@
1
  from __future__ import annotations
2
 
3
- """Reward grading logic for StateStrike.
4
-
5
- Theory:
6
- The reward function follows standard MDP shaping principles from Sutton &
7
- Barto (2018): combine dense shaping signals (latency ratio), sparse goal
8
- rewards (exploit bounty), and penalties (invalid spam suppression). It also
9
- borrows stateful-sequence ideas from RESTler (Atlidakis et al., ICSE 2019)
10
- while rewarding infrastructure effects (e.g., ReDoS latency spikes) inspired
11
- by Davis et al. (USENIX Security 2018).
12
- """
13
-
14
  import logging
15
- import math
16
- from typing import TYPE_CHECKING
17
 
18
- if TYPE_CHECKING:
19
- from statestrike_env.constants import RewardConstants
20
- from statestrike_env.models import StateStrikeObservation
21
- from statestrike_env.session import StateStrikeSession
22
 
23
  logger = logging.getLogger(__name__)
24
 
25
 
26
- def compute_reward(
27
- obs: "StateStrikeObservation",
28
- session: "StateStrikeSession",
29
- constants: "RewardConstants",
30
- ) -> tuple[float, dict[str, float]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
- Compute R_t = α·log(L_t/L_base) + β·S_t + γ·E_t δ·P_t
33
-
34
- Theory (Sutton & Barto, 2018, Ch. 3 — Finite MDPs):
35
- The reward signal must be designed so the ONLY way to maximize cumulative
36
- reward is to achieve the TRUE objective. Each term is chosen to prevent a
37
- specific reward-hacking strategy:
38
-
39
- TERM 1 α·log(L_t/L_base): Logarithmic latency reward.
40
- Why log? Linear reward incentivizes the agent to find ONE massive spike
41
- and repeat it. Logarithmic reward gives diminishing returns per repeated
42
- exploitation, pushing the agent to discover NEW vulnerabilities.
43
- Why ratio? Prevents baseline-anchoring attacks where agent engineers a
44
- low baseline then makes normal requests look like spikes.
45
- Anti-hack: baseline ONLY updates from successful (latency>0) steps.
46
-
47
- TERM 2 — β·S_t: State-chain bonus.
48
- Fires at most once per CHAIN_COOLDOWN_STEPS steps, and only if
49
- order_count has increased since the last award. This prevents the
50
- POST→GET farming loop that would yield +5 reward/step for free.
51
- Anti-hack: last_chain_bonus_step and post_count_at_last_chain guards.
52
-
53
- TERM 3 — γ·E_t: Exploitation bounty.
54
- Fires EXACTLY ONCE per vulnerability type per episode (one-time flag).
55
- Without this, an agent discovering db_degradation would spam GET /orders
56
- for +500/step indefinitely. The one-time award correctly signals
57
- "you found it" without incentivizing repeated triggering.
58
- Anti-hack: redos_bounty_awarded and db_degradation_bounty_awarded flags.
59
-
60
- TERM 4 — δ·P_t: Fuzzing penalty.
61
- Applied only to genuinely fast 400s (latency < 100ms), not to slow 400s
62
- (which may indicate actual CPU burn from ReDoS parsing).
63
- Threshold tightened from 200ms to 100ms to avoid penalizing legitimate
64
- slow-failing payloads.
65
- Anti-hack: latency threshold ensures ReDoS probes are not penalized.
66
-
67
- Reference:
68
- - Sutton & Barto (2018): reward shaping and sparse reward design
69
- - Atlidakis et al. (ICSE 2019): stateful API exploration objectives
70
- - Davis et al. (USENIX 2018): ReDoS computational complexity
71
-
72
- Args:
73
- obs: The observation from the current step.
74
- session: The mutable session state (modified in-place for flags).
75
- constants: Reward weight constants from constants.py.
76
-
77
- Returns:
78
- Tuple of (scalar_reward, breakdown_dict) where breakdown_dict
79
- contains each term's contribution for telemetry and dashboard display.
80
  """
81
 
82
- reward = 0.0
83
- breakdown: dict[str, float] = {
84
- "latency_reward": 0.0,
85
- "chain_bonus": 0.0,
86
- "exploit_bounty": 0.0,
87
- "fuzz_penalty": 0.0,
 
88
  "total": 0.0,
89
  }
90
 
91
- # Guard: connection failure -> neutral observation, no reward signal.
92
  if obs.http_status == 0 or obs.latency_ms == 0.0:
93
- breakdown["error"] = 1.0
94
- logger.debug("Step %d: connection failure, returning zero reward", obs.step)
95
  return 0.0, breakdown
96
 
97
- # Update rolling baseline only from successful steps.
98
- _update_baseline(session, obs.latency_ms, constants.BASELINE_WINDOW)
99
-
100
- # TERM 1: Logarithmic latency reward.
101
- if session.baseline_latency > 0:
102
- latency_ratio = obs.latency_ms / session.baseline_latency
103
- latency_ratio = max(0.01, min(latency_ratio, 100.0))
104
- latency_reward = constants.ALPHA * math.log(latency_ratio)
105
- reward += latency_reward
106
- breakdown["latency_reward"] = round(latency_reward, 4)
107
- logger.debug(
108
- "Step %d: latency=%.1fms baseline=%.1fms ratio=%.2f reward=%.3f",
109
- obs.step,
110
- obs.latency_ms,
111
- session.baseline_latency,
112
- latency_ratio,
113
- latency_reward,
114
- )
115
-
116
- # TERM 2: State-chain bonus (anti-farming guards).
117
- chain_bonus = 0.0
118
- if _should_award_chain_bonus(obs, session, constants):
119
- chain_bonus = constants.BETA
120
- session.last_chain_bonus_step = obs.step
121
- session.post_count_at_last_chain = session.order_count
122
- logger.info(
123
- "Step %d: Chain bonus awarded (+%.1f). order_count=%d",
124
- obs.step,
125
- chain_bonus,
126
- session.order_count,
127
- )
128
- reward += chain_bonus
129
- breakdown["chain_bonus"] = chain_bonus
130
-
131
- # TERM 3: Exploitation bounties (one-time per episode).
132
- exploit_bounty = 0.0
133
-
134
- if (
135
- not session.db_degradation_bounty_awarded
136
- and (obs.http_status >= 500 or obs.latency_ms > constants.DB_TIMEOUT_THRESHOLD)
137
- and obs.action_taken.action_type.value == "get_orders"
138
- ):
139
- exploit_bounty += constants.GAMMA
140
- session.db_degradation_bounty_awarded = True
141
- session.triggered_vulns.add("db_degradation")
142
- logger.info(
143
- "Step %d: DB_DEGRADATION bounty awarded (+%.1f). latency=%.1fms",
144
- obs.step,
145
- constants.GAMMA,
146
- obs.latency_ms,
147
- )
148
-
149
- if (
150
- not session.redos_bounty_awarded
151
- and obs.latency_ms > constants.REDOS_LATENCY_THRESHOLD
152
- and obs.http_status == 400
153
- and obs.action_taken.action_type.value == "post_user"
154
- ):
155
- redos_bounty = constants.GAMMA * 0.8
156
- exploit_bounty += redos_bounty
157
- session.redos_bounty_awarded = True
158
- session.triggered_vulns.add("redos")
159
- logger.info(
160
- "Step %d: REDOS bounty awarded (+%.1f). latency=%.1fms",
161
- obs.step,
162
- redos_bounty,
163
- obs.latency_ms,
164
- )
165
-
166
- reward += exploit_bounty
167
- breakdown["exploit_bounty"] = round(exploit_bounty, 4)
168
-
169
- # TERM 4: Fuzzing penalty (only genuine fast-fail syntax errors).
170
- fuzz_penalty = 0.0
171
- if obs.http_status == 400 and obs.latency_ms < 100.0:
172
- fuzz_penalty = -constants.DELTA
173
- logger.debug("Step %d: Fuzz penalty applied (fast 400, %.1fms)", obs.step, obs.latency_ms)
174
- reward += fuzz_penalty
175
- breakdown["fuzz_penalty"] = round(fuzz_penalty, 4)
176
 
177
- breakdown["total"] = round(reward, 4)
178
- return reward, breakdown
 
 
 
179
 
 
180
 
181
- def _update_baseline(session: "StateStrikeSession", latency_ms: float, window: int) -> None:
182
- """Update rolling baseline latency using exponential moving average."""
 
183
 
184
- alpha_ema = 2.0 / (window + 1)
185
- if session.baseline_sample_count == 0:
186
- session.baseline_latency = latency_ms
187
- else:
188
- session.baseline_latency = alpha_ema * latency_ms + (1 - alpha_ema) * session.baseline_latency
189
- session.baseline_sample_count += 1
190
 
 
 
 
191
 
192
- def _should_award_chain_bonus(
193
- obs: "StateStrikeObservation",
194
- session: "StateStrikeSession",
195
- constants: "RewardConstants",
196
- ) -> bool:
197
- """Determine if the state-chain bonus should be awarded this step."""
198
-
199
- if obs.action_taken.action_type.value != "get_orders":
200
- return False
201
- if session.order_count < constants.CHAIN_REQUIRED_ORDERS:
202
- return False
203
- steps_since_last = obs.step - session.last_chain_bonus_step
204
- if steps_since_last < constants.CHAIN_COOLDOWN_STEPS:
205
- return False
206
- if session.order_count <= session.post_count_at_last_chain:
207
- return False
208
- return True
 
 
1
  from __future__ import annotations
2
 
 
 
 
 
 
 
 
 
 
 
 
3
  import logging
4
+ from typing import Any
 
5
 
6
+ from statestrike_env.constants import RewardConstants
7
+ from statestrike_env.models import StateStrikeObservation
8
+ from statestrike_env.session import StateStrikeSession
9
+ from statestrike_env.tasks import TASK_REGISTRY
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
+ def compute_task_score(session: StateStrikeSession, task_name: str) -> float:
15
+ task_config, grader = TASK_REGISTRY[task_name]
16
+ del task_config
17
+ return float(grader.score(session.as_grader_state()))
18
+
19
+
20
+ def _update_baseline_ema(session: StateStrikeSession, latency_ms: float, window: int) -> None:
21
+ alpha_ema = 2.0 / (window + 1)
22
+ if session.baseline_sample_count == 0:
23
+ session.baseline_latency = latency_ms
24
+ else:
25
+ session.baseline_latency = alpha_ema * latency_ms + (1 - alpha_ema) * session.baseline_latency
26
+ session.baseline_sample_count += 1
27
+
28
+
29
+ def compute_task_reward(
30
+ obs: StateStrikeObservation,
31
+ session: StateStrikeSession,
32
+ task_name: str,
33
+ constants: RewardConstants,
34
+ ) -> tuple[float, dict[str, Any]]:
35
  """
36
+ Compute step reward in [0.0, 1.0] based on task progress delta.
37
+
38
+ Theory (reward shaping, Ng et al. 1999):
39
+ R_shaped(s, a, s') = R(s, a, s') + gamma*Phi(s') - Phi(s)
40
+ where Phi(s) = task_score(s) is the potential function.
41
+
42
+ The terminal bonus (+0.20) is a sparse goal reward layered on top of
43
+ the shaped reward, following the hybrid approach in Sutton & Barto (2018).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  """
45
 
46
+ task_config, _ = TASK_REGISTRY[task_name]
47
+ breakdown: dict[str, Any] = {
48
+ "score_delta": 0.0,
49
+ "new_endpoint_bonus": 0.0,
50
+ "new_vulnerability_bonus": 0.0,
51
+ "repeat_penalty": 0.0,
52
+ "terminal_bonus": 0.0,
53
  "total": 0.0,
54
  }
55
 
 
56
  if obs.http_status == 0 or obs.latency_ms == 0.0:
57
+ breakdown["error"] = "connection_failed"
 
58
  return 0.0, breakdown
59
 
60
+ _update_baseline_ema(session, obs.latency_ms, constants.BASELINE_WINDOW)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ current_task_score = compute_task_score(session, task_name)
63
+ previous_task_score = session.previous_task_score
64
+ score_delta = max(0.0, current_task_score - previous_task_score)
65
+ score_delta = min(score_delta, constants.STEP_DELTA_MAX)
66
+ breakdown["score_delta"] = round(score_delta, 4)
67
 
68
+ reward = score_delta
69
 
70
+ if bool(session.task_specific_state.get("new_endpoint_discovered", False)):
71
+ reward += constants.NEW_ENDPOINT_BONUS
72
+ breakdown["new_endpoint_bonus"] = constants.NEW_ENDPOINT_BONUS
73
 
74
+ if bool(session.task_specific_state.get("new_vulnerability_found", False)):
75
+ reward += constants.NEW_VULNERABILITY_BONUS
76
+ breakdown["new_vulnerability_bonus"] = constants.NEW_VULNERABILITY_BONUS
 
 
 
77
 
78
+ if bool(session.task_specific_state.get("repeated_action", False)):
79
+ reward -= constants.REPEATED_ACTION_PENALTY
80
+ breakdown["repeat_penalty"] = -constants.REPEATED_ACTION_PENALTY
81
 
82
+ terminal = session.step_count >= task_config.max_steps or current_task_score >= task_config.success_threshold
83
+ if terminal and current_task_score >= task_config.success_threshold:
84
+ reward += constants.TERMINAL_BONUS
85
+ breakdown["terminal_bonus"] = constants.TERMINAL_BONUS
86
+
87
+ reward = max(0.0, min(1.0, reward))
88
+ breakdown["total"] = round(reward, 4)
89
+ breakdown["task_score"] = round(current_task_score, 4)
90
+
91
+ session.previous_task_score = max(previous_task_score, current_task_score)
92
+ logger.debug(
93
+ "task=%s step=%d score=%.3f reward=%.3f",
94
+ task_name,
95
+ obs.step,
96
+ current_task_score,
97
+ reward,
98
+ )
99
+ return reward, breakdown
statestrike_env/models.py CHANGED
@@ -1,31 +1,20 @@
1
  from __future__ import annotations
2
 
3
- """Typed action, observation, and state models for StateStrike.
4
-
5
- Theory:
6
- Explicit state/action schemas reduce ambiguity in RL interfaces and improve
7
- reproducibility when evaluating policies across different backends.
8
- """
9
-
10
  from enum import Enum
11
  from typing import Any, Optional
12
 
13
  from pydantic import BaseModel, Field
14
 
15
 
16
- class ActionType(str, Enum):
17
- """Discrete actions available to the StateStrike agent."""
18
-
19
- POST_USER = "post_user"
20
- GET_USER = "get_user"
21
- POST_ORDER = "post_order"
22
- GET_ORDERS = "get_orders"
23
- HEALTH_CHECK = "health_check"
24
 
25
 
26
  class PayloadStrategy(str, Enum):
27
- """Payload generation strategies used by the fuzzing policy."""
28
-
29
  VALID = "valid"
30
  REDOS_ATTACK = "redos"
31
  OVERSIZED = "oversized"
@@ -33,66 +22,48 @@ class PayloadStrategy(str, Enum):
33
 
34
 
35
  class StateStrikeAction(BaseModel):
36
- """Action frame sent by the RL agent.
37
-
38
- Args:
39
- action_type: Target endpoint operation.
40
- payload_strategy: Payload mutation strategy.
41
- target_user_id: Optional user identifier override.
42
- """
43
 
44
- action_type: ActionType
45
- payload_strategy: PayloadStrategy
46
  target_user_id: Optional[int] = None
47
 
 
 
 
48
 
49
  class StateStrikeObservation(BaseModel):
50
- """Step-level feedback returned by the environment.
51
-
52
- Args:
53
- step: Current step index within the episode.
54
- action_taken: Action executed during the step.
55
- http_status: HTTP status code from honeypot response.
56
- latency_ms: End-to-end processing latency in milliseconds.
57
- reward: Scalar reward at this step.
58
- cumulative_reward: Running reward sum for the episode.
59
- baseline_latency_ms: Rolling latency baseline used for normalization.
60
- order_count: Number of POST /orders calls in this episode.
61
- triggered_vulns: Vulnerability labels discovered so far.
62
- done: Terminal signal for episode completion.
63
- info: Arbitrary metadata, including reward breakdown.
64
- """
65
 
66
  step: int
67
- action_taken: StateStrikeAction
68
  http_status: int
69
  latency_ms: float
 
 
 
 
 
 
 
 
 
 
 
70
  reward: float
71
- cumulative_reward: float
72
- baseline_latency_ms: float
73
- order_count: int
74
- triggered_vulns: list[str]
75
  done: bool
76
  info: dict[str, Any] = Field(default_factory=dict)
77
 
78
 
79
  class StateStrikeState(BaseModel):
80
- """Persistent session state exposed by state().
81
-
82
- Args:
83
- session_id: Unique identifier for current environment episode.
84
- step_count: Number of actions executed in current session.
85
- cumulative_reward: Running reward sum for current session.
86
- order_count: Number of POST /orders calls in session.
87
- baseline_latency_ms: Rolling baseline latency in milliseconds.
88
- action_history: Most recent action history window.
89
- triggered_vulns: Vulnerabilities discovered in this session.
90
- """
91
 
92
  session_id: str
 
93
  step_count: int
94
  cumulative_reward: float
95
  order_count: int
96
  baseline_latency_ms: float
97
- action_history: list[StateStrikeAction]
98
- triggered_vulns: list[str]
 
 
1
  from __future__ import annotations
2
 
 
 
 
 
 
 
 
3
  from enum import Enum
4
  from typing import Any, Optional
5
 
6
  from pydantic import BaseModel, Field
7
 
8
 
9
+ class EndpointChoice(str, Enum):
10
+ POST_USERS = "POST /users"
11
+ GET_USER = "GET /users/{id}"
12
+ POST_ORDERS = "POST /orders"
13
+ GET_ORDERS = "GET /orders"
14
+ HEALTH = "GET /health"
 
 
15
 
16
 
17
  class PayloadStrategy(str, Enum):
 
 
18
  VALID = "valid"
19
  REDOS_ATTACK = "redos"
20
  OVERSIZED = "oversized"
 
22
 
23
 
24
  class StateStrikeAction(BaseModel):
25
+ """Action space for StateStrike environment."""
 
 
 
 
 
 
26
 
27
+ endpoint: EndpointChoice
28
+ payload_strategy: PayloadStrategy = PayloadStrategy.VALID
29
  target_user_id: Optional[int] = None
30
 
31
+ class Config:
32
+ use_enum_values = True
33
+
34
 
35
  class StateStrikeObservation(BaseModel):
36
+ """Observation returned after each step."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  step: int
39
+ endpoint_called: str
40
  http_status: int
41
  latency_ms: float
42
+ response_body: dict[str, Any] = Field(default_factory=dict)
43
+ session_order_count: int = 0
44
+ endpoints_discovered: list[str] = Field(default_factory=list)
45
+ vulnerabilities_found: list[str] = Field(default_factory=list)
46
+ task_progress: float = 0.0
47
+
48
+
49
+ class StepResult(BaseModel):
50
+ """Top-level return from step()."""
51
+
52
+ observation: StateStrikeObservation
53
  reward: float
 
 
 
 
54
  done: bool
55
  info: dict[str, Any] = Field(default_factory=dict)
56
 
57
 
58
  class StateStrikeState(BaseModel):
59
+ """Full session state, returned by state()."""
 
 
 
 
 
 
 
 
 
 
60
 
61
  session_id: str
62
+ task_name: str
63
  step_count: int
64
  cumulative_reward: float
65
  order_count: int
66
  baseline_latency_ms: float
67
+ endpoints_discovered: list[str]
68
+ vulnerabilities_found: list[str]
69
+ task_specific_state: dict[str, Any] = Field(default_factory=dict)
statestrike_env/server.py CHANGED
@@ -1,400 +1,7 @@
1
  from __future__ import annotations
2
 
3
- """OpenEnv-style WebSocket environment server for StateStrike."""
4
 
5
- import asyncio
6
- import json
7
- import logging
8
- import os
9
- import time
10
- from contextlib import asynccontextmanager
11
- from typing import Any
12
 
13
- import httpx
14
- from dotenv import load_dotenv
15
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
16
- from fastapi.responses import JSONResponse
17
-
18
- try:
19
- import openenv_core # noqa: F401
20
- except ImportError: # pragma: no cover - optional import for compatibility signaling.
21
- openenv_core = None
22
-
23
- from statestrike_env.constants import (
24
- ACTION_TIMEOUT_SECONDS,
25
- DEFAULT_BASELINE_LATENCY_MS,
26
- EPISODE_LENGTH,
27
- RewardConstants,
28
- )
29
- from statestrike_env.grader import compute_reward
30
- from statestrike_env.models import ActionType, PayloadStrategy, StateStrikeAction, StateStrikeObservation, StateStrikeState
31
- from statestrike_env.session import StateStrikeSession
32
-
33
- load_dotenv()
34
-
35
- logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
36
- LOGGER = logging.getLogger(__name__)
37
-
38
- HONEYPOT_URL = os.getenv("HONEYPOT_URL", "http://localhost:8000")
39
- HOST = os.getenv("STATESTRIKE_ENV_HOST", "0.0.0.0")
40
- PORT = int(os.getenv("STATESTRIKE_ENV_PORT", "8001"))
41
-
42
-
43
- async def wait_for_honeypot(url: str, max_wait: int = 30) -> None:
44
- """Block until honeypot is reachable or raise RuntimeError.
45
-
46
- Args:
47
- url: Honeypot base URL.
48
- max_wait: Maximum wait time in seconds.
49
-
50
- Raises:
51
- RuntimeError: If honeypot is not reachable before timeout.
52
- """
53
-
54
- deadline = asyncio.get_event_loop().time() + max_wait
55
- delay = 1.0
56
- async with httpx.AsyncClient() as client:
57
- while asyncio.get_event_loop().time() < deadline:
58
- try:
59
- response = await client.get(f"{url}/health", timeout=3.0)
60
- if response.status_code == 200:
61
- LOGGER.info("Honeypot is ready at %s", url)
62
- return
63
- LOGGER.warning(
64
- "Honeypot health returned status=%s, retrying in %.1fs...",
65
- response.status_code,
66
- delay,
67
- )
68
- except Exception as exc: # noqa: BLE001
69
- LOGGER.warning("Honeypot not ready (%s), retrying in %.1fs...", exc, delay)
70
-
71
- await asyncio.sleep(delay)
72
- delay = min(delay * 1.5, 5.0)
73
-
74
- raise RuntimeError(f"Honeypot at {url} did not become ready within {max_wait}s")
75
-
76
-
77
- class StateStrikeEnvironment:
78
- """Core reset/step/state implementation.
79
-
80
- Theory:
81
- OpenEnv training loops benefit from persistent transport: WebSocket-based
82
- sessions amortize handshake overhead and preserve episode-local state,
83
- which aligns with OpenEnv architecture guidance (Burtenshaw, 2025).
84
- """
85
-
86
- def __init__(self, honeypot_url: str, constants: RewardConstants | None = None) -> None:
87
- """Initialize environment service.
88
-
89
- Args:
90
- honeypot_url: Base URL for vulnerable honeypot API.
91
- constants: Optional reward constants override.
92
- """
93
-
94
- self.honeypot_url = honeypot_url.rstrip("/")
95
- self.constants = constants or RewardConstants()
96
-
97
- async def reset(self, session: StateStrikeSession) -> StateStrikeObservation:
98
- """Reset session and return initial observation.
99
-
100
- Args:
101
- session: Session object tied to one client connection.
102
-
103
- Returns:
104
- Initial observation with zero reward.
105
- """
106
-
107
- status, latency_ms, _ = await self._request_honeypot("GET", "/health")
108
- baseline = latency_ms if latency_ms > 0 else DEFAULT_BASELINE_LATENCY_MS
109
- session.reset(baseline_latency=baseline)
110
-
111
- action = StateStrikeAction(action_type=ActionType.HEALTH_CHECK, payload_strategy=PayloadStrategy.VALID)
112
- obs = StateStrikeObservation(
113
- step=0,
114
- action_taken=action,
115
- http_status=status,
116
- latency_ms=latency_ms,
117
- reward=0.0,
118
- cumulative_reward=0.0,
119
- baseline_latency_ms=session.baseline_latency,
120
- order_count=0,
121
- triggered_vulns=[],
122
- done=False,
123
- info={"event": "reset"},
124
- )
125
- return obs
126
-
127
- async def step(self, session: StateStrikeSession, action: StateStrikeAction) -> StateStrikeObservation:
128
- """Execute one environment transition.
129
-
130
- Args:
131
- session: Session object tied to one client connection.
132
- action: Agent action.
133
-
134
- Returns:
135
- Updated observation with reward and terminal signal.
136
- """
137
-
138
- request_method, request_path, params, payload = self._translate_action(action, session)
139
- status, latency_ms, body = await self._request_honeypot(request_method, request_path, params=params, payload=payload)
140
-
141
- session.step_count += 1
142
- if action.action_type == ActionType.POST_ORDER:
143
- session.order_count += 1
144
- session.append_action(action)
145
-
146
- provisional = StateStrikeObservation(
147
- step=session.step_count,
148
- action_taken=action,
149
- http_status=status,
150
- latency_ms=latency_ms,
151
- reward=0.0,
152
- cumulative_reward=session.cumulative_reward,
153
- baseline_latency_ms=session.baseline_latency,
154
- order_count=session.order_count,
155
- triggered_vulns=sorted(session.triggered_vulns),
156
- done=False,
157
- info={"response": body},
158
- )
159
-
160
- reward, breakdown = compute_reward(provisional, session, self.constants)
161
- session.cumulative_reward += reward
162
-
163
- done = (
164
- session.step_count >= EPISODE_LENGTH
165
- or session.cumulative_reward < self.constants.EARLY_TERMINATION_REWARD
166
- )
167
- obs = StateStrikeObservation(
168
- step=session.step_count,
169
- action_taken=action,
170
- http_status=status,
171
- latency_ms=latency_ms,
172
- reward=reward,
173
- cumulative_reward=session.cumulative_reward,
174
- baseline_latency_ms=session.baseline_latency,
175
- order_count=session.order_count,
176
- triggered_vulns=sorted(session.triggered_vulns),
177
- done=done,
178
- info={"reward_breakdown": breakdown, "response": body},
179
- )
180
- return obs
181
-
182
- async def state(self, session: StateStrikeSession) -> StateStrikeState:
183
- """Return serializable state snapshot.
184
-
185
- Args:
186
- session: Session object tied to one client connection.
187
-
188
- Returns:
189
- Current state model.
190
- """
191
-
192
- return session.as_state()
193
-
194
- def _translate_action(
195
- self,
196
- action: StateStrikeAction,
197
- session: StateStrikeSession,
198
- ) -> tuple[str, str, dict[str, Any] | None, dict[str, Any] | None]:
199
- """Translate action schema into honeypot HTTP request details.
200
-
201
- Args:
202
- action: Agent action.
203
- session: Session used for contextual defaults.
204
-
205
- Returns:
206
- Tuple of method, path, query params, and JSON payload.
207
- """
208
-
209
- target_user_id = action.target_user_id or 1
210
-
211
- if action.action_type == ActionType.POST_USER:
212
- email = self._payload_email(action.payload_strategy)
213
- return "POST", "/users", None, {"email": email}
214
- if action.action_type == ActionType.GET_USER:
215
- return "GET", f"/users/{target_user_id}", None, None
216
- if action.action_type == ActionType.POST_ORDER:
217
- item = self._payload_item(action.payload_strategy)
218
- return "POST", "/orders", None, {"user_id": target_user_id, "item": item}
219
- if action.action_type == ActionType.GET_ORDERS:
220
- return "GET", "/orders", {"user_id": target_user_id}, None
221
- return "GET", "/health", None, None
222
-
223
- @staticmethod
224
- def _payload_email(strategy: PayloadStrategy) -> str:
225
- """Build email-like payload for POST /users action.
226
-
227
- Args:
228
- strategy: Payload strategy enum.
229
-
230
- Returns:
231
- Strategy-specific string payload.
232
- """
233
-
234
- if strategy == PayloadStrategy.REDOS_ATTACK:
235
- return "a" * 39 + "!"
236
- if strategy == PayloadStrategy.OVERSIZED:
237
- return "A" * 4096
238
- if strategy == PayloadStrategy.MALFORMED:
239
- return "@@@"
240
- return "validuser123"
241
-
242
- @staticmethod
243
- def _payload_item(strategy: PayloadStrategy) -> str:
244
- """Build order item payload.
245
-
246
- Args:
247
- strategy: Payload strategy enum.
248
-
249
- Returns:
250
- Strategy-specific order item string.
251
- """
252
-
253
- if strategy == PayloadStrategy.OVERSIZED:
254
- return "item_" + ("X" * 2048)
255
- if strategy == PayloadStrategy.MALFORMED:
256
- return ""
257
- return "standard_item"
258
-
259
- async def _request_honeypot(
260
- self,
261
- method: str,
262
- path: str,
263
- *,
264
- params: dict[str, Any] | None = None,
265
- payload: dict[str, Any] | None = None,
266
- ) -> tuple[int, float, dict[str, Any]]:
267
- """Execute honeypot request and normalize response metadata.
268
-
269
- Args:
270
- method: HTTP method.
271
- path: Relative path.
272
- params: Optional query parameters.
273
- payload: Optional JSON body.
274
-
275
- Returns:
276
- Tuple of status code, latency milliseconds, and parsed response body.
277
- """
278
-
279
- url = f"{self.honeypot_url}{path}"
280
- started = time.perf_counter()
281
- try:
282
- async with httpx.AsyncClient(timeout=ACTION_TIMEOUT_SECONDS) as client:
283
- response = await client.request(method, url, params=params, json=payload)
284
- elapsed_ms = (time.perf_counter() - started) * 1000.0
285
- header_latency = response.headers.get("X-Process-Time-Ms")
286
- latency_ms = float(header_latency) if header_latency else elapsed_ms
287
- body = response.json() if response.content else {}
288
- return response.status_code, latency_ms, body
289
- except (httpx.RequestError, ValueError) as exc:
290
- LOGGER.warning("Honeypot request failed method=%s path=%s error=%s", method, path, exc)
291
- return 0, 0.0, {"error": str(exc), "synthetic": True}
292
-
293
-
294
- @asynccontextmanager
295
- async def lifespan(_: FastAPI):
296
- """Block API startup until honeypot health endpoint is reachable."""
297
-
298
- await wait_for_honeypot(HONEYPOT_URL, max_wait=30)
299
- yield
300
-
301
-
302
- app = FastAPI(title="StateStrike OpenEnv Server", version="1.0.0", lifespan=lifespan)
303
- env_service = StateStrikeEnvironment(HONEYPOT_URL)
304
- http_debug_session = StateStrikeSession.new_session()
305
-
306
-
307
- # OpenEnv uses WebSocket (/ws) for persistent sessions rather than
308
- # stateless HTTP. Each step() is a lightweight frame over an existing
309
- # connection (~0.1ms overhead vs ~10-50ms TCP handshake per HTTP call).
310
- # Reference: openenv-course module-5, burtenshaw/openenv-scaling
311
- # This architecture enables high-frequency RL training loops.
312
- @app.websocket("/ws")
313
- async def websocket_env(websocket: WebSocket) -> None:
314
- """Run one isolated environment loop per WebSocket client.
315
-
316
- Args:
317
- websocket: Connected client transport.
318
- """
319
-
320
- await websocket.accept()
321
- session = StateStrikeSession.new_session()
322
- LOGGER.info("WebSocket session started session_id=%s", session.session_id)
323
-
324
- try:
325
- while True:
326
- frame = await websocket.receive_text()
327
- request = json.loads(frame)
328
- method = request.get("method")
329
-
330
- if method == "reset":
331
- obs = await env_service.reset(session)
332
- await websocket.send_json({"ok": True, "observation": obs.model_dump()})
333
- continue
334
-
335
- if method == "step":
336
- action_payload = request.get("action", {})
337
- action = StateStrikeAction.model_validate(action_payload)
338
- obs = await env_service.step(session, action)
339
- await websocket.send_json({"ok": True, "observation": obs.model_dump()})
340
- continue
341
-
342
- if method == "state":
343
- state = await env_service.state(session)
344
- await websocket.send_json({"ok": True, "state": state.model_dump()})
345
- continue
346
-
347
- await websocket.send_json({"ok": False, "error": f"Unknown method: {method}"})
348
- except (WebSocketDisconnect, json.JSONDecodeError):
349
- LOGGER.info("WebSocket session ended session_id=%s", session.session_id)
350
-
351
-
352
- @app.get("/reset")
353
- async def reset_http() -> JSONResponse:
354
- """HTTP debug endpoint for reset semantics.
355
-
356
- Returns:
357
- JSON response containing reset observation.
358
- """
359
-
360
- obs = await env_service.reset(http_debug_session)
361
- return JSONResponse(obs.model_dump())
362
-
363
-
364
- @app.post("/step")
365
- async def step_http(action: StateStrikeAction) -> JSONResponse:
366
- """HTTP debug endpoint for step semantics.
367
-
368
- Args:
369
- action: Action payload.
370
-
371
- Returns:
372
- JSON response containing post-step observation.
373
- """
374
-
375
- obs = await env_service.step(http_debug_session, action)
376
- return JSONResponse(obs.model_dump())
377
-
378
-
379
- @app.get("/state")
380
- async def state_http() -> JSONResponse:
381
- """HTTP debug endpoint for state semantics.
382
-
383
- Returns:
384
- JSON response containing current session state.
385
- """
386
-
387
- state = await env_service.state(http_debug_session)
388
- return JSONResponse(state.model_dump())
389
-
390
-
391
- def main() -> None:
392
- """Entrypoint for running environment server via python -m."""
393
-
394
- import uvicorn
395
-
396
- uvicorn.run("statestrike_env.server:app", host=HOST, port=PORT, reload=False)
397
-
398
-
399
- if __name__ == "__main__":
400
- main()
 
1
  from __future__ import annotations
2
 
3
+ from statestrike_env.environment import StateStrikeEnv, app, main
4
 
5
+ StateStrikeEnvironment = StateStrikeEnv
 
 
 
 
 
 
6
 
7
+ __all__ = ["StateStrikeEnvironment", "StateStrikeEnv", "app", "main"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
statestrike_env/session.py CHANGED
@@ -1,122 +1,86 @@
1
  from __future__ import annotations
2
 
3
- """Session state manager for per-agent environment isolation."""
4
-
5
  from dataclasses import dataclass, field
 
6
  from uuid import uuid4
7
 
8
- from statestrike_env.constants import DEFAULT_BASELINE_LATENCY_MS, MAX_ACTION_HISTORY
9
- from statestrike_env.models import StateStrikeAction, StateStrikeState
10
 
11
 
12
  @dataclass
13
  class StateStrikeSession:
14
- """Mutable per-WebSocket environment session.
15
-
16
- Attributes:
17
- session_id: Current episode UUID.
18
- step_count: Number of steps taken in current episode.
19
- cumulative_reward: Running reward total.
20
- order_count: Number of POST /orders actions issued.
21
- baseline_latency: Rolling average latency used in reward normalization.
22
- action_history: Most recent action history window.
23
- triggered_vulns: Vulnerabilities discovered in current episode.
24
- redos_bounty_awarded: One-time ReDoS bounty guard.
25
- db_degradation_bounty_awarded: One-time DB degradation bounty guard.
26
- last_chain_bonus_step: Last step where chain bonus was awarded.
27
- post_count_at_last_chain: Order count snapshot at last chain award.
28
- baseline_sample_count: Number of successful baseline samples seen.
29
- """
30
-
31
  session_id: str
 
32
  step_count: int = 0
33
  cumulative_reward: float = 0.0
34
  order_count: int = 0
35
  baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS
36
- action_history: list[StateStrikeAction] = field(default_factory=list)
37
- triggered_vulns: set[str] = field(default_factory=set)
38
- # Anti-hacking: one-time flags so each bounty fires exactly once per episode.
 
 
 
 
 
 
 
39
  redos_bounty_awarded: bool = False
40
  db_degradation_bounty_awarded: bool = False
41
- # Anti-hacking: chain bonus can only fire once between meaningful progress windows.
42
  last_chain_bonus_step: int = -10
43
  post_count_at_last_chain: int = 0
44
- # Baseline integrity: updated only on successful (non-zero latency) steps.
45
  baseline_sample_count: int = 0
46
 
47
  @classmethod
48
- def new_session(cls) -> StateStrikeSession:
49
- """Create a new initialized session.
50
-
51
- Returns:
52
- Newly initialized StateStrikeSession instance.
53
- """
54
-
55
- return cls(session_id=str(uuid4()))
56
-
57
- def reset(self, baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS) -> None:
58
- """Reset session in-place for a new episode.
59
-
60
- Args:
61
- baseline_latency: Fresh baseline latency in milliseconds.
62
- """
63
-
64
  self.session_id = str(uuid4())
 
65
  self.step_count = 0
66
  self.cumulative_reward = 0.0
67
  self.order_count = 0
68
  self.baseline_latency = baseline_latency
69
- self.action_history.clear()
70
- self.triggered_vulns.clear()
 
 
 
 
 
 
 
 
71
  self.redos_bounty_awarded = False
72
  self.db_degradation_bounty_awarded = False
73
  self.last_chain_bonus_step = -10
74
  self.post_count_at_last_chain = 0
75
  self.baseline_sample_count = 1 if baseline_latency > 0 else 0
76
 
77
- def record_latency(self, latency_ms: float) -> float:
78
- """Update baseline latency using EMA from successful samples.
79
-
80
- Args:
81
- latency_ms: Observed latency for the current step.
82
-
83
- Returns:
84
- Updated baseline latency.
85
- """
86
-
87
- sample = max(latency_ms, 1.0)
88
- alpha_ema = 2.0 / (10 + 1)
89
- if self.baseline_sample_count == 0:
90
- self.baseline_latency = sample
91
- else:
92
- self.baseline_latency = alpha_ema * sample + (1 - alpha_ema) * self.baseline_latency
93
- self.baseline_sample_count += 1
94
- return self.baseline_latency
95
-
96
- def append_action(self, action: StateStrikeAction) -> None:
97
- """Append action while enforcing history length constraints.
98
-
99
- Args:
100
- action: Action to append.
101
- """
102
-
103
- self.action_history.append(action)
104
- if len(self.action_history) > MAX_ACTION_HISTORY:
105
- self.action_history.pop(0)
106
-
107
  def as_state(self) -> StateStrikeState:
108
- """Convert mutable session internals to external state model.
109
-
110
- Returns:
111
- Immutable API-safe state representation.
112
- """
113
-
114
  return StateStrikeState(
115
  session_id=self.session_id,
 
116
  step_count=self.step_count,
117
  cumulative_reward=self.cumulative_reward,
118
  order_count=self.order_count,
119
  baseline_latency_ms=self.baseline_latency,
120
- action_history=list(self.action_history),
121
- triggered_vulns=sorted(self.triggered_vulns),
 
122
  )
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
 
 
3
  from dataclasses import dataclass, field
4
+ from typing import Any
5
  from uuid import uuid4
6
 
7
+ from statestrike_env.constants import DEFAULT_BASELINE_LATENCY_MS
8
+ from statestrike_env.models import StateStrikeState
9
 
10
 
11
  @dataclass
12
  class StateStrikeSession:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  session_id: str
14
+ task_name: str = "endpoint_discovery"
15
  step_count: int = 0
16
  cumulative_reward: float = 0.0
17
  order_count: int = 0
18
  baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS
19
+
20
+ endpoints_discovered: set[str] = field(default_factory=set)
21
+ vulnerabilities_found: set[str] = field(default_factory=set)
22
+ task_specific_state: dict[str, Any] = field(default_factory=dict)
23
+ steps_history: list[dict[str, Any]] = field(default_factory=list)
24
+
25
+ user_created: bool = False
26
+ previous_task_score: float = 0.0
27
+ last_action_signature: str | None = None
28
+
29
  redos_bounty_awarded: bool = False
30
  db_degradation_bounty_awarded: bool = False
 
31
  last_chain_bonus_step: int = -10
32
  post_count_at_last_chain: int = 0
 
33
  baseline_sample_count: int = 0
34
 
35
  @classmethod
36
+ def new_session(cls, task_name: str = "endpoint_discovery") -> StateStrikeSession:
37
+ return cls(session_id=str(uuid4()), task_name=task_name)
38
+
39
+ def reset(
40
+ self,
41
+ task_name: str,
42
+ baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS,
43
+ ) -> None:
 
 
 
 
 
 
 
 
44
  self.session_id = str(uuid4())
45
+ self.task_name = task_name
46
  self.step_count = 0
47
  self.cumulative_reward = 0.0
48
  self.order_count = 0
49
  self.baseline_latency = baseline_latency
50
+
51
+ self.endpoints_discovered.clear()
52
+ self.vulnerabilities_found.clear()
53
+ self.task_specific_state = {}
54
+ self.steps_history.clear()
55
+
56
+ self.user_created = False
57
+ self.previous_task_score = 0.0
58
+ self.last_action_signature = None
59
+
60
  self.redos_bounty_awarded = False
61
  self.db_degradation_bounty_awarded = False
62
  self.last_chain_bonus_step = -10
63
  self.post_count_at_last_chain = 0
64
  self.baseline_sample_count = 1 if baseline_latency > 0 else 0
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def as_state(self) -> StateStrikeState:
 
 
 
 
 
 
67
  return StateStrikeState(
68
  session_id=self.session_id,
69
+ task_name=self.task_name,
70
  step_count=self.step_count,
71
  cumulative_reward=self.cumulative_reward,
72
  order_count=self.order_count,
73
  baseline_latency_ms=self.baseline_latency,
74
+ endpoints_discovered=sorted(self.endpoints_discovered),
75
+ vulnerabilities_found=sorted(self.vulnerabilities_found),
76
+ task_specific_state=dict(self.task_specific_state),
77
  )
78
+
79
+ def as_grader_state(self) -> dict[str, Any]:
80
+ return {
81
+ "endpoints_discovered": sorted(self.endpoints_discovered),
82
+ "vulnerabilities_found": sorted(self.vulnerabilities_found),
83
+ "steps_history": list(self.steps_history),
84
+ "order_count": self.order_count,
85
+ "user_created": self.user_created,
86
+ }
statestrike_env/tasks.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+
7
+ @dataclass
8
+ class TaskConfig:
9
+ task_id: str
10
+ description: str
11
+ max_steps: int
12
+ difficulty: str
13
+ success_threshold: float
14
+
15
+
16
+ TASK_ENDPOINT_DISCOVERY = TaskConfig(
17
+ task_id="endpoint_discovery",
18
+ description="Discover all 5 reachable API endpoints by probing them.",
19
+ max_steps=20,
20
+ difficulty="easy",
21
+ success_threshold=0.6,
22
+ )
23
+
24
+ ALL_ENDPOINTS = {
25
+ "POST /users",
26
+ "GET /users/{id}",
27
+ "POST /orders",
28
+ "GET /orders",
29
+ "GET /health",
30
+ }
31
+
32
+
33
+ class EndpointDiscoveryGrader:
34
+ def score(self, session_state: dict[str, Any]) -> float:
35
+ found = set(session_state.get("endpoints_discovered", []))
36
+ return round(len(found & ALL_ENDPOINTS) / len(ALL_ENDPOINTS), 4)
37
+
38
+
39
+ TASK_VULNERABILITY_PROBE = TaskConfig(
40
+ task_id="vulnerability_probe",
41
+ description=(
42
+ "Identify and correctly classify at least one vulnerability. "
43
+ "Vulnerabilities: 'redos' (POST /users with attack payload, latency > 1500ms), "
44
+ "'db_degradation' (GET /orders after 20+ orders, latency > 800ms)."
45
+ ),
46
+ max_steps=30,
47
+ difficulty="medium",
48
+ success_threshold=0.5,
49
+ )
50
+
51
+
52
+ class VulnerabilityProbeGrader:
53
+ REDOS_LATENCY_THRESHOLD_MS = 1500.0
54
+ DB_LATENCY_THRESHOLD_MS = 800.0
55
+
56
+ def score(self, session_state: dict[str, Any]) -> float:
57
+ vulns = set(session_state.get("vulnerabilities_found", []))
58
+ steps_history = session_state.get("steps_history", [])
59
+
60
+ score = 0.0
61
+
62
+ if any(s.get("endpoint") != "GET /health" for s in steps_history):
63
+ score = max(score, 0.2)
64
+
65
+ if any(s.get("payload_strategy") != "valid" for s in steps_history):
66
+ score = max(score, 0.4)
67
+
68
+ if "redos" in vulns:
69
+ score = max(score, 0.6)
70
+
71
+ if "db_degradation" in vulns:
72
+ score = max(score, 0.8)
73
+
74
+ if "redos" in vulns and "db_degradation" in vulns:
75
+ score = 1.0
76
+
77
+ return round(score, 4)
78
+
79
+
80
+ TASK_EXPLOIT_CHAIN = TaskConfig(
81
+ task_id="exploit_chain",
82
+ description=(
83
+ "Execute the full stateful exploit chain in order: "
84
+ "1. Create a user (POST /users, valid payload). "
85
+ "2. Build order history (POST /orders x 21 for same user). "
86
+ "3. Trigger DB degradation (GET /orders with 21 orders, latency > 800ms). "
87
+ "4. Trigger ReDoS (POST /users, redos payload, latency > 1500ms). "
88
+ "All four sub-objectives must be completed for full score."
89
+ ),
90
+ max_steps=60,
91
+ difficulty="hard",
92
+ success_threshold=0.75,
93
+ )
94
+
95
+
96
+ class ExploitChainGrader:
97
+ def score(self, session_state: dict[str, Any]) -> float:
98
+ score = 0.0
99
+
100
+ if session_state.get("user_created", False):
101
+ score += 0.25
102
+ if session_state.get("order_count", 0) >= 21:
103
+ score += 0.25
104
+ if "db_degradation" in session_state.get("vulnerabilities_found", []):
105
+ score += 0.25
106
+ if "redos" in session_state.get("vulnerabilities_found", []):
107
+ score += 0.25
108
+
109
+ return round(score, 4)
110
+
111
+
112
+ TASK_REGISTRY = {
113
+ "endpoint_discovery": (TASK_ENDPOINT_DISCOVERY, EndpointDiscoveryGrader()),
114
+ "vulnerability_probe": (TASK_VULNERABILITY_PROBE, VulnerabilityProbeGrader()),
115
+ "exploit_chain": (TASK_EXPLOIT_CHAIN, ExploitChainGrader()),
116
+ }