Spaces:
Sleeping
Sleeping
feat: unified single-container deployment with all 3 tasks + inference.py
Browse files- Dockerfile +19 -4
- honeypot/README.md +9 -0
- honeypot/__init__.py +3 -0
- honeypot/app.py +197 -0
- honeypot/database.py +57 -0
- honeypot/middleware.py +81 -0
- honeypot/models.py +44 -0
- inference.py +213 -0
- openenv.yaml +41 -0
- requirements.txt +2 -2
- scripts/run_demo.sh +28 -0
- scripts/setup_env.sh +22 -0
- scripts/start.sh +21 -0
- statestrike_env/README.md +1 -0
- statestrike_env/__init__.py +19 -137
- statestrike_env/__pycache__/__init__.cpython-311.pyc +0 -0
- statestrike_env/__pycache__/constants.cpython-313.pyc +0 -0
- statestrike_env/__pycache__/models.cpython-311.pyc +0 -0
- statestrike_env/__pycache__/server.cpython-313.pyc +0 -0
- statestrike_env/constants.py +24 -41
- statestrike_env/environment.py +332 -0
- statestrike_env/grader.py +76 -185
- statestrike_env/models.py +30 -59
- statestrike_env/server.py +3 -396
- statestrike_env/session.py +46 -82
- statestrike_env/tasks.py +116 -0
Dockerfile
CHANGED
|
@@ -1,7 +1,22 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
WORKDIR /app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
COPY . .
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
EXPOSE 7860
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
|
| 6 |
+
|
| 7 |
+
COPY requirements.txt .
|
| 8 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 9 |
+
|
| 10 |
COPY . .
|
| 11 |
+
|
| 12 |
+
RUN touch /app/statestrike.db
|
| 13 |
+
|
| 14 |
+
COPY scripts/start.sh /start.sh
|
| 15 |
+
RUN chmod +x /start.sh
|
| 16 |
+
|
| 17 |
EXPOSE 7860
|
| 18 |
+
|
| 19 |
+
HEALTHCHECK --interval=10s --timeout=5s --start-period=20s --retries=5 \
|
| 20 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 21 |
+
|
| 22 |
+
CMD ["/start.sh"]
|
honeypot/README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: StateStrike Environment
|
| 3 |
+
emoji: 🎯
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: true
|
| 8 |
+
tags: ["openenv", "hackathon", "security", "rl"]
|
| 9 |
+
---
|
honeypot/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
"""StateStrike vulnerable honeypot API package."""
|
honeypot/app.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
"""FastAPI honeypot target for stateful API fuzzing experiments."""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
from datetime import datetime, timezone
|
| 9 |
+
|
| 10 |
+
from fastapi import Depends, FastAPI, HTTPException, Query
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
from sqlalchemy.orm import Session
|
| 13 |
+
|
| 14 |
+
from honeypot.database import get_db, init_db
|
| 15 |
+
from honeypot.middleware import TelemetryMiddleware, create_telemetry_router
|
| 16 |
+
from honeypot.models import Order, User
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 19 |
+
LOGGER = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
app = FastAPI(title="StateStrike Honeypot", version="1.0.0")
|
| 22 |
+
app.add_middleware(TelemetryMiddleware)
|
| 23 |
+
app.include_router(create_telemetry_router())
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class UserCreate(BaseModel):
|
| 27 |
+
"""Payload for POST /users."""
|
| 28 |
+
|
| 29 |
+
email: str = Field(min_length=1, max_length=256)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class OrderCreate(BaseModel):
|
| 33 |
+
"""Payload for POST /orders."""
|
| 34 |
+
|
| 35 |
+
user_id: int
|
| 36 |
+
item: str = Field(min_length=1, max_length=256)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@app.on_event("startup")
|
| 40 |
+
async def on_startup() -> None:
|
| 41 |
+
"""Initialize database tables at service startup."""
|
| 42 |
+
|
| 43 |
+
init_db()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@app.get("/health")
|
| 47 |
+
def health_check() -> dict[str, object]:
|
| 48 |
+
"""Return liveness and timestamp information.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Dictionary with service status and UNIX timestamp.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
return {"status": "ok", "ts": int(time.time())}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.post("/users")
|
| 58 |
+
def create_user(payload: UserCreate, db: Session = Depends(get_db)) -> dict[str, object]:
|
| 59 |
+
"""Create a user with intentionally vulnerable regex validation.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
payload: User creation body.
|
| 63 |
+
db: SQLAlchemy session.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Created user dictionary.
|
| 67 |
+
|
| 68 |
+
Raises:
|
| 69 |
+
HTTPException: If email validation fails.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
pattern = r"^([a-zA-Z0-9]+\s?)*[a-zA-Z0-9]+$"
|
| 73 |
+
|
| 74 |
+
# VULNERABILITY: ReDoS via catastrophic backtracking
|
| 75 |
+
# Reference: Davis et al., "ReDoS in the Wild" (USENIX Security 2018)
|
| 76 |
+
# This pattern exhibits O(2^n) backtracking on input "aaa...a!"
|
| 77 |
+
# A production-hardened alternative would use: re2 or a finite automaton
|
| 78 |
+
if not re.fullmatch(pattern, payload.email, flags=re.DOTALL):
|
| 79 |
+
raise HTTPException(status_code=400, detail="Invalid email format")
|
| 80 |
+
|
| 81 |
+
user = User(email=payload.email)
|
| 82 |
+
db.add(user)
|
| 83 |
+
db.commit()
|
| 84 |
+
db.refresh(user)
|
| 85 |
+
|
| 86 |
+
return {"id": user.id, "email": user.email, "created_at": user.created_at.isoformat()}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.get("/users/{user_id}")
|
| 90 |
+
def get_user(user_id: int, db: Session = Depends(get_db)) -> dict[str, object]:
|
| 91 |
+
"""Fetch user by identifier.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
user_id: User identifier.
|
| 95 |
+
db: SQLAlchemy session.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
User dictionary.
|
| 99 |
+
|
| 100 |
+
Raises:
|
| 101 |
+
HTTPException: If user does not exist.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
user = db.query(User).filter(User.id == user_id).first()
|
| 105 |
+
if user is None:
|
| 106 |
+
raise HTTPException(status_code=404, detail="User not found")
|
| 107 |
+
return {"id": user.id, "email": user.email, "created_at": user.created_at.isoformat()}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@app.post("/orders")
|
| 111 |
+
def create_order(payload: OrderCreate, db: Session = Depends(get_db)) -> dict[str, object]:
|
| 112 |
+
"""Create an order for an existing user.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
payload: Order creation body.
|
| 116 |
+
db: SQLAlchemy session.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
Created order dictionary.
|
| 120 |
+
|
| 121 |
+
Raises:
|
| 122 |
+
HTTPException: If user does not exist.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
user = db.query(User).filter(User.id == payload.user_id).first()
|
| 126 |
+
if user is None:
|
| 127 |
+
raise HTTPException(status_code=404, detail="User not found")
|
| 128 |
+
|
| 129 |
+
order = Order(user_id=payload.user_id, item=payload.item)
|
| 130 |
+
db.add(order)
|
| 131 |
+
db.commit()
|
| 132 |
+
db.refresh(order)
|
| 133 |
+
|
| 134 |
+
return {
|
| 135 |
+
"id": order.id,
|
| 136 |
+
"user_id": order.user_id,
|
| 137 |
+
"item": order.item,
|
| 138 |
+
"created_at": order.created_at.isoformat(),
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@app.get("/orders")
|
| 143 |
+
def list_orders(
|
| 144 |
+
user_id: int | None = Query(default=None),
|
| 145 |
+
db: Session = Depends(get_db),
|
| 146 |
+
) -> dict[str, object]:
|
| 147 |
+
"""List orders and expose intentional stateful degradation path.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
user_id: Optional user filter and degradation trigger key.
|
| 151 |
+
db: SQLAlchemy session.
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Order list payload with count metadata.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
query = db.query(Order)
|
| 158 |
+
if user_id is not None:
|
| 159 |
+
query = query.filter(Order.user_id == user_id)
|
| 160 |
+
|
| 161 |
+
orders = query.all()
|
| 162 |
+
|
| 163 |
+
if user_id is not None and len(orders) > 20:
|
| 164 |
+
# VULNERABILITY: Unindexed aggregate query degradation
|
| 165 |
+
# Only reachable after stateful chain: 21x POST /orders -> GET /orders
|
| 166 |
+
# An RL agent can discover this; a stateless fuzzer cannot.
|
| 167 |
+
# Reference: RESTler (Atlidakis et al., ICSE 2019) pioneered stateful
|
| 168 |
+
# REST fuzzing but used grammar-based, not RL-based exploration.
|
| 169 |
+
all_orders = db.query(Order).all()
|
| 170 |
+
expensive_aggregate: dict[int, int] = {}
|
| 171 |
+
for left in all_orders:
|
| 172 |
+
total = 0
|
| 173 |
+
for right in all_orders:
|
| 174 |
+
if left.user_id == right.user_id:
|
| 175 |
+
total += 1
|
| 176 |
+
expensive_aggregate[left.user_id] = total
|
| 177 |
+
LOGGER.info(
|
| 178 |
+
"Triggered synthetic O(n^2) aggregate for user_id=%s with %s total rows",
|
| 179 |
+
user_id,
|
| 180 |
+
len(all_orders),
|
| 181 |
+
)
|
| 182 |
+
time.sleep(0.8)
|
| 183 |
+
|
| 184 |
+
return {
|
| 185 |
+
"count": len(orders),
|
| 186 |
+
"orders": [
|
| 187 |
+
{
|
| 188 |
+
"id": order.id,
|
| 189 |
+
"user_id": order.user_id,
|
| 190 |
+
"item": order.item,
|
| 191 |
+
"created_at": order.created_at.isoformat()
|
| 192 |
+
if isinstance(order.created_at, datetime)
|
| 193 |
+
else datetime.now(timezone.utc).isoformat(),
|
| 194 |
+
}
|
| 195 |
+
for order in orders
|
| 196 |
+
],
|
| 197 |
+
}
|
honeypot/database.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
"""Database setup for the StateStrike honeypot.
|
| 4 |
+
|
| 5 |
+
Theory:
|
| 6 |
+
A local SQLite backend keeps the demo deterministic and lightweight while
|
| 7 |
+
preserving enough statefulness for multi-step fuzzing trajectories.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from collections.abc import Generator
|
| 13 |
+
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
from sqlalchemy import create_engine
|
| 16 |
+
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
| 17 |
+
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
LOGGER = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
DATABASE_FILE = os.getenv("DATABASE_FILE", "statestrike.db")
|
| 23 |
+
DATABASE_URL = f"sqlite:///{DATABASE_FILE}"
|
| 24 |
+
|
| 25 |
+
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
| 26 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 27 |
+
Base = declarative_base()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_db() -> Generator[Session, None, None]:
|
| 31 |
+
"""Yield a SQLAlchemy session for request-scoped DB access.
|
| 32 |
+
|
| 33 |
+
Yields:
|
| 34 |
+
An open SQLAlchemy Session object.
|
| 35 |
+
|
| 36 |
+
Raises:
|
| 37 |
+
RuntimeError: If session creation fails.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
db = SessionLocal()
|
| 41 |
+
try:
|
| 42 |
+
yield db
|
| 43 |
+
finally:
|
| 44 |
+
db.close()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def init_db() -> None:
|
| 48 |
+
"""Create database schema if tables do not yet exist.
|
| 49 |
+
|
| 50 |
+
Raises:
|
| 51 |
+
Exception: Propagates SQLAlchemy creation errors.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
from honeypot import models # Local import avoids circular import at module load.
|
| 55 |
+
|
| 56 |
+
Base.metadata.create_all(bind=engine)
|
| 57 |
+
LOGGER.info("Initialized SQLite schema at %s", DATABASE_URL)
|
honeypot/middleware.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
"""Telemetry middleware and stream endpoint for honeypot observations."""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
from collections import deque
|
| 8 |
+
from collections.abc import AsyncIterator
|
| 9 |
+
from datetime import datetime, timezone
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter
|
| 12 |
+
from fastapi.responses import StreamingResponse
|
| 13 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 14 |
+
from starlette.requests import Request
|
| 15 |
+
from starlette.responses import Response
|
| 16 |
+
|
| 17 |
+
TELEMETRY_BUFFER: deque[dict[str, object]] = deque(maxlen=500)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TelemetryMiddleware(BaseHTTPMiddleware):
|
| 21 |
+
"""Capture request latency and expose response timing metadata.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
app: Wrapped ASGI application.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
None. Middleware mutates response headers and side effects telemetry buffer.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
async def dispatch(self, request: Request, call_next) -> Response:
|
| 31 |
+
"""Process an incoming request and append telemetry entry.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
request: Starlette request object.
|
| 35 |
+
call_next: Next middleware/app callable.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
The downstream response with X-Process-Time-Ms header attached.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
start = time.perf_counter()
|
| 42 |
+
response = await call_next(request)
|
| 43 |
+
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
| 44 |
+
response.headers["X-Process-Time-Ms"] = f"{elapsed_ms:.3f}"
|
| 45 |
+
|
| 46 |
+
TELEMETRY_BUFFER.append(
|
| 47 |
+
{
|
| 48 |
+
"ts": datetime.now(timezone.utc).isoformat(),
|
| 49 |
+
"path": request.url.path,
|
| 50 |
+
"method": request.method,
|
| 51 |
+
"status": response.status_code,
|
| 52 |
+
"latency_ms": round(elapsed_ms, 3),
|
| 53 |
+
}
|
| 54 |
+
)
|
| 55 |
+
return response
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def create_telemetry_router() -> APIRouter:
|
| 59 |
+
"""Create telemetry SSE routes.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
A FastAPI router exposing telemetry streaming endpoint.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
|
| 66 |
+
|
| 67 |
+
@router.get("/stream")
|
| 68 |
+
async def stream_recent_entries() -> StreamingResponse:
|
| 69 |
+
"""Emit the latest telemetry entries over Server-Sent Events.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
StreamingResponse configured with text/event-stream media type.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
async def event_source() -> AsyncIterator[str]:
|
| 76 |
+
payload = json.dumps(list(TELEMETRY_BUFFER)[-100:])
|
| 77 |
+
yield f"data: {payload}\n\n"
|
| 78 |
+
|
| 79 |
+
return StreamingResponse(event_source(), media_type="text/event-stream")
|
| 80 |
+
|
| 81 |
+
return router
|
honeypot/models.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
"""ORM models used by the honeypot API."""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime, timezone
|
| 6 |
+
|
| 7 |
+
from sqlalchemy import DateTime, ForeignKey, Integer, String
|
| 8 |
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
| 9 |
+
|
| 10 |
+
from honeypot.database import Base
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class User(Base):
|
| 14 |
+
"""User entity created by POST /users."""
|
| 15 |
+
|
| 16 |
+
__tablename__ = "users"
|
| 17 |
+
|
| 18 |
+
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
| 19 |
+
email: Mapped[str] = mapped_column(String(256), unique=False, nullable=False)
|
| 20 |
+
created_at: Mapped[datetime] = mapped_column(
|
| 21 |
+
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
orders: Mapped[list[Order]] = relationship("Order", back_populates="user")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Order(Base):
|
| 28 |
+
"""Order entity created by POST /orders.
|
| 29 |
+
|
| 30 |
+
Note:
|
| 31 |
+
The `user_id` field intentionally has no explicit DB index to preserve the
|
| 32 |
+
degradation path used by the challenge.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
__tablename__ = "orders"
|
| 36 |
+
|
| 37 |
+
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
| 38 |
+
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
|
| 39 |
+
item: Mapped[str] = mapped_column(String(256), nullable=False)
|
| 40 |
+
created_at: Mapped[datetime] = mapped_column(
|
| 41 |
+
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
user: Mapped[User] = relationship("User", back_populates="orders")
|
inference.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
StateStrike Inference Script
|
| 3 |
+
============================
|
| 4 |
+
Runs an LLM agent against all 3 StateStrike tasks and emits
|
| 5 |
+
structured [START]/[STEP]/[END] logs for automated scoring.
|
| 6 |
+
|
| 7 |
+
Environment variables:
|
| 8 |
+
API_BASE_URL LLM endpoint (default: https://router.huggingface.co/v1)
|
| 9 |
+
MODEL_NAME Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
|
| 10 |
+
HF_TOKEN Hugging Face API token (required, no default)
|
| 11 |
+
LOCAL_IMAGE_NAME Docker image name if using from_docker_image()
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import textwrap
|
| 20 |
+
from typing import List, Optional
|
| 21 |
+
|
| 22 |
+
from openai import OpenAI
|
| 23 |
+
|
| 24 |
+
from statestrike_env.environment import StateStrikeEnv
|
| 25 |
+
from statestrike_env.models import EndpointChoice, PayloadStrategy, StateStrikeAction
|
| 26 |
+
|
| 27 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 28 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 29 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 30 |
+
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 31 |
+
BENCHMARK = "statestrike"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 35 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 39 |
+
error_val = error if error else "null"
|
| 40 |
+
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 44 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 45 |
+
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
TASK_CONFIGS = {
|
| 49 |
+
"endpoint_discovery": {"max_steps": 10, "success_threshold": 0.6},
|
| 50 |
+
"vulnerability_probe": {"max_steps": 15, "success_threshold": 0.5},
|
| 51 |
+
"exploit_chain": {"max_steps": 30, "success_threshold": 0.75},
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
SYSTEM_PROMPTS = {
|
| 56 |
+
"endpoint_discovery": textwrap.dedent(
|
| 57 |
+
"""
|
| 58 |
+
You are a security agent probing an API to discover all available endpoints.
|
| 59 |
+
Known endpoints to try: POST /users, GET /users/{id}, POST /orders, GET /orders, GET /health
|
| 60 |
+
At each step, choose ONE endpoint to call and a payload strategy.
|
| 61 |
+
Reply with a JSON object: {"endpoint": "POST /users", "payload_strategy": "valid", "target_user_id": null}
|
| 62 |
+
Try each endpoint at least once. Vary payload strategies to confirm reachability.
|
| 63 |
+
"""
|
| 64 |
+
).strip(),
|
| 65 |
+
"vulnerability_probe": textwrap.dedent(
|
| 66 |
+
"""
|
| 67 |
+
You are a security agent looking for vulnerabilities in an API.
|
| 68 |
+
Two vulnerabilities exist:
|
| 69 |
+
1. ReDoS: POST /users with payload_strategy "redos" causes CPU spike (latency > 1500ms, returns 400)
|
| 70 |
+
2. DB Degradation: GET /orders after 20+ orders causes slow query (latency > 800ms)
|
| 71 |
+
Reply with JSON: {"endpoint": "POST /users", "payload_strategy": "redos", "target_user_id": null}
|
| 72 |
+
Valid endpoint values: "POST /users", "GET /users/{id}", "POST /orders", "GET /orders", "GET /health"
|
| 73 |
+
Valid strategies: "valid", "redos", "oversized", "malformed"
|
| 74 |
+
"""
|
| 75 |
+
).strip(),
|
| 76 |
+
"exploit_chain": textwrap.dedent(
|
| 77 |
+
"""
|
| 78 |
+
You are a security agent executing a full exploit chain. Complete these in order:
|
| 79 |
+
1. Create a user: POST /users with payload_strategy "valid"
|
| 80 |
+
2. Create 21 orders: POST /orders with payload_strategy "valid", target_user_id=<user_id from step 1>
|
| 81 |
+
3. Trigger slow query: GET /orders with target_user_id set
|
| 82 |
+
4. Trigger ReDoS: POST /users with payload_strategy "redos"
|
| 83 |
+
Track your progress. Each step builds on the previous.
|
| 84 |
+
Reply with JSON: {"endpoint": "POST /orders", "payload_strategy": "valid", "target_user_id": 1}
|
| 85 |
+
"""
|
| 86 |
+
).strip(),
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_agent_action(
|
| 91 |
+
client: OpenAI,
|
| 92 |
+
task_name: str,
|
| 93 |
+
step: int,
|
| 94 |
+
last_obs: dict,
|
| 95 |
+
history: List[str],
|
| 96 |
+
) -> StateStrikeAction:
|
| 97 |
+
system = SYSTEM_PROMPTS[task_name]
|
| 98 |
+
history_block = "\n".join(history[-5:]) if history else "None"
|
| 99 |
+
user_msg = textwrap.dedent(
|
| 100 |
+
f"""
|
| 101 |
+
Step: {step}
|
| 102 |
+
Last observation: {json.dumps(last_obs, indent=2)}
|
| 103 |
+
Recent history:
|
| 104 |
+
{history_block}
|
| 105 |
+
What is your next action? Reply with JSON only.
|
| 106 |
+
"""
|
| 107 |
+
).strip()
|
| 108 |
+
|
| 109 |
+
fallback = StateStrikeAction(
|
| 110 |
+
endpoint=EndpointChoice.HEALTH,
|
| 111 |
+
payload_strategy=PayloadStrategy.VALID,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
completion = client.chat.completions.create(
|
| 116 |
+
model=MODEL_NAME,
|
| 117 |
+
messages=[
|
| 118 |
+
{"role": "system", "content": system},
|
| 119 |
+
{"role": "user", "content": user_msg},
|
| 120 |
+
],
|
| 121 |
+
temperature=0.7,
|
| 122 |
+
max_tokens=100,
|
| 123 |
+
)
|
| 124 |
+
text = (completion.choices[0].message.content or "").strip()
|
| 125 |
+
text = text.removeprefix("```json").removesuffix("```").strip()
|
| 126 |
+
data = json.loads(text)
|
| 127 |
+
return StateStrikeAction(**data)
|
| 128 |
+
except Exception as exc:
|
| 129 |
+
print(f"[DEBUG] Action parse failed: {exc}", flush=True)
|
| 130 |
+
return fallback
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
async def run_task(
|
| 134 |
+
env: StateStrikeEnv,
|
| 135 |
+
client: OpenAI,
|
| 136 |
+
task_name: str,
|
| 137 |
+
) -> float:
|
| 138 |
+
config = TASK_CONFIGS[task_name]
|
| 139 |
+
max_steps = config["max_steps"]
|
| 140 |
+
success_threshold = config["success_threshold"]
|
| 141 |
+
|
| 142 |
+
rewards: List[float] = []
|
| 143 |
+
steps_taken = 0
|
| 144 |
+
score = 0.0
|
| 145 |
+
success = False
|
| 146 |
+
history: List[str] = []
|
| 147 |
+
|
| 148 |
+
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
|
| 149 |
+
|
| 150 |
+
try:
|
| 151 |
+
result = await env.reset(task_name=task_name)
|
| 152 |
+
obs = result.observation
|
| 153 |
+
last_obs_dict = obs.model_dump()
|
| 154 |
+
|
| 155 |
+
for step in range(1, max_steps + 1):
|
| 156 |
+
if result.done:
|
| 157 |
+
break
|
| 158 |
+
|
| 159 |
+
action = get_agent_action(client, task_name, step, last_obs_dict, history)
|
| 160 |
+
action_str = f"{action.endpoint}+{action.payload_strategy}"
|
| 161 |
+
|
| 162 |
+
result = await env.step(action)
|
| 163 |
+
obs = result.observation
|
| 164 |
+
reward = result.reward or 0.0
|
| 165 |
+
done = result.done
|
| 166 |
+
error = result.info.get("error") if isinstance(result.info, dict) else None
|
| 167 |
+
|
| 168 |
+
rewards.append(reward)
|
| 169 |
+
steps_taken = step
|
| 170 |
+
last_obs_dict = obs.model_dump()
|
| 171 |
+
|
| 172 |
+
log_step(step=step, action=action_str, reward=reward, done=done, error=error)
|
| 173 |
+
history.append(
|
| 174 |
+
f"Step {step}: {action_str} -> status={obs.http_status} "
|
| 175 |
+
f"latency={obs.latency_ms:.0f}ms reward={reward:.2f}"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if done:
|
| 179 |
+
break
|
| 180 |
+
|
| 181 |
+
score = min(max(obs.task_progress, 0.0), 1.0)
|
| 182 |
+
success = score >= success_threshold
|
| 183 |
+
|
| 184 |
+
except Exception as exc:
|
| 185 |
+
print(f"[DEBUG] Task {task_name} failed: {exc}", flush=True)
|
| 186 |
+
finally:
|
| 187 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 188 |
+
|
| 189 |
+
return score
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
async def main() -> None:
|
| 193 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 194 |
+
|
| 195 |
+
if LOCAL_IMAGE_NAME:
|
| 196 |
+
env = await StateStrikeEnv.from_docker_image(LOCAL_IMAGE_NAME)
|
| 197 |
+
else:
|
| 198 |
+
env = StateStrikeEnv()
|
| 199 |
+
|
| 200 |
+
scores = {}
|
| 201 |
+
for task_name in ["endpoint_discovery", "vulnerability_probe", "exploit_chain"]:
|
| 202 |
+
score = await run_task(env, client, task_name)
|
| 203 |
+
scores[task_name] = score
|
| 204 |
+
|
| 205 |
+
await env.close()
|
| 206 |
+
|
| 207 |
+
print(f"\n[DEBUG] Final scores: {scores}", flush=True)
|
| 208 |
+
avg = sum(scores.values()) / len(scores)
|
| 209 |
+
print(f"[DEBUG] Average score: {avg:.3f}", flush=True)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
asyncio.run(main())
|
openenv.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: statestrike
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A stateful API security audit environment where an agent learns to discover
|
| 5 |
+
real infrastructure vulnerabilities through systematic endpoint exploration
|
| 6 |
+
and stateful exploit chaining.
|
| 7 |
+
author: StateStrike Team
|
| 8 |
+
license: MIT
|
| 9 |
+
tags:
|
| 10 |
+
- security
|
| 11 |
+
- api-testing
|
| 12 |
+
- stateful
|
| 13 |
+
- openenv
|
| 14 |
+
tasks:
|
| 15 |
+
- id: endpoint_discovery
|
| 16 |
+
description: >
|
| 17 |
+
Identify all reachable API endpoints. Agent receives the base URL
|
| 18 |
+
and must probe systematically to discover which endpoints exist.
|
| 19 |
+
difficulty: easy
|
| 20 |
+
max_steps: 20
|
| 21 |
+
reward_range: [0.0, 1.0]
|
| 22 |
+
- id: vulnerability_probe
|
| 23 |
+
description: >
|
| 24 |
+
Identify and correctly classify at least one vulnerability in the
|
| 25 |
+
target API (redos or db_degradation).
|
| 26 |
+
difficulty: medium
|
| 27 |
+
max_steps: 30
|
| 28 |
+
reward_range: [0.0, 1.0]
|
| 29 |
+
- id: exploit_chain
|
| 30 |
+
description: >
|
| 31 |
+
Execute the full stateful exploit chain: create user, build order
|
| 32 |
+
history, trigger DB degradation slow path, and trigger ReDoS.
|
| 33 |
+
difficulty: hard
|
| 34 |
+
max_steps: 60
|
| 35 |
+
reward_range: [0.0, 1.0]
|
| 36 |
+
observation_space:
|
| 37 |
+
type: object
|
| 38 |
+
description: HTTP response details including status, latency, body, and session state
|
| 39 |
+
action_space:
|
| 40 |
+
type: object
|
| 41 |
+
description: HTTP action with endpoint choice, method, and payload strategy
|
requirements.txt
CHANGED
|
@@ -10,5 +10,5 @@ python-dotenv==1.0.1
|
|
| 10 |
pytest==8.2.0
|
| 11 |
pytest-asyncio==0.23.7
|
| 12 |
rich==13.7.1
|
| 13 |
-
websockets=
|
| 14 |
-
|
|
|
|
| 10 |
pytest==8.2.0
|
| 11 |
pytest-asyncio==0.23.7
|
| 12 |
rich==13.7.1
|
| 13 |
+
websockets>=15.0
|
| 14 |
+
openai>=1.0.0
|
scripts/run_demo.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
# StateStrike Demo Launch Script
|
| 5 |
+
# Starts all services and runs 200-step agent demo
|
| 6 |
+
|
| 7 |
+
echo "🎯 StateStrike — OpenEnv Hackathon Demo"
|
| 8 |
+
echo "Starting honeypot API..."
|
| 9 |
+
uvicorn honeypot.app:app --port 8000 &
|
| 10 |
+
HONEY_PID=$!
|
| 11 |
+
|
| 12 |
+
echo "Starting OpenEnv environment server..."
|
| 13 |
+
python -m statestrike_env.server &
|
| 14 |
+
ENV_PID=$!
|
| 15 |
+
|
| 16 |
+
echo "Starting dashboard..."
|
| 17 |
+
streamlit run dashboard/app.py --server.port 8501 &
|
| 18 |
+
DASH_PID=$!
|
| 19 |
+
|
| 20 |
+
cleanup() {
|
| 21 |
+
kill "$HONEY_PID" "$ENV_PID" "$DASH_PID" 2>/dev/null || true
|
| 22 |
+
}
|
| 23 |
+
trap cleanup EXIT
|
| 24 |
+
|
| 25 |
+
sleep 3
|
| 26 |
+
echo "All services up. Running agent..."
|
| 27 |
+
echo "Dashboard: http://localhost:8501"
|
| 28 |
+
python -m agent.runner --steps 200
|
scripts/setup_env.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
# StateStrike bootstrap script
|
| 5 |
+
# Creates local env file and installs pinned dependencies.
|
| 6 |
+
|
| 7 |
+
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 8 |
+
cd "$ROOT_DIR"
|
| 9 |
+
|
| 10 |
+
if [ ! -f .env ]; then
|
| 11 |
+
cp .env.example .env
|
| 12 |
+
echo "Created .env from .env.example"
|
| 13 |
+
fi
|
| 14 |
+
|
| 15 |
+
python -m pip install --upgrade pip
|
| 16 |
+
python -m pip install -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Ensure Docker bind-mount file targets exist as files.
|
| 19 |
+
touch statestrike.db
|
| 20 |
+
touch telemetry.json
|
| 21 |
+
|
| 22 |
+
echo "StateStrike environment setup complete."
|
scripts/start.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
echo "[StateStrike] Starting honeypot on port 8000..."
|
| 5 |
+
uvicorn honeypot.app:app --host 0.0.0.0 --port 8000 &
|
| 6 |
+
HONEYPOT_PID=$!
|
| 7 |
+
|
| 8 |
+
echo "[StateStrike] Waiting for honeypot..."
|
| 9 |
+
for i in $(seq 1 30); do
|
| 10 |
+
if curl -sf http://localhost:8000/health > /dev/null 2>&1; then
|
| 11 |
+
echo "[StateStrike] Honeypot ready."
|
| 12 |
+
break
|
| 13 |
+
fi
|
| 14 |
+
sleep 1
|
| 15 |
+
done
|
| 16 |
+
|
| 17 |
+
echo "[StateStrike] Starting environment server on port 7860..."
|
| 18 |
+
export HONEYPOT_URL="http://localhost:8000"
|
| 19 |
+
uvicorn statestrike_env.environment:app --host 0.0.0.0 --port 7860
|
| 20 |
+
|
| 21 |
+
wait $HONEYPOT_PID
|
statestrike_env/README.md
CHANGED
|
@@ -5,4 +5,5 @@ colorFrom: red
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: true
|
|
|
|
| 8 |
---
|
|
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: true
|
| 8 |
+
tags: ["openenv", "hackathon", "security", "rl"]
|
| 9 |
---
|
statestrike_env/__init__.py
CHANGED
|
@@ -1,139 +1,21 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
""
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
normalized = base_url.rstrip("/")
|
| 25 |
-
self.base_url = normalized if normalized.endswith("/ws") else f"{normalized}/ws"
|
| 26 |
-
self._conn: ClientConnection | None = None
|
| 27 |
-
|
| 28 |
-
def __enter__(self) -> "_SyncStateStrikeClient":
|
| 29 |
-
"""Open WebSocket connection for environment operations.
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
Connected client instance.
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
self._conn = connect(self.base_url)
|
| 36 |
-
return self
|
| 37 |
-
|
| 38 |
-
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
| 39 |
-
"""Close WebSocket connection.
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
exc_type: Exception type if raised in context block.
|
| 43 |
-
exc: Exception value if raised in context block.
|
| 44 |
-
tb: Traceback object if raised in context block.
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
if self._conn is not None:
|
| 48 |
-
self._conn.close()
|
| 49 |
-
self._conn = None
|
| 50 |
-
|
| 51 |
-
def reset(self) -> StateStrikeObservation:
|
| 52 |
-
"""Request environment reset.
|
| 53 |
-
|
| 54 |
-
Returns:
|
| 55 |
-
Initial observation.
|
| 56 |
-
|
| 57 |
-
Raises:
|
| 58 |
-
RuntimeError: If the server response is malformed or unsuccessful.
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
frame = self._request({"method": "reset"})
|
| 62 |
-
return StateStrikeObservation.model_validate(frame["observation"])
|
| 63 |
-
|
| 64 |
-
def step(self, action: StateStrikeAction) -> StateStrikeObservation:
|
| 65 |
-
"""Execute one environment step.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
action: Action payload.
|
| 69 |
-
|
| 70 |
-
Returns:
|
| 71 |
-
Updated observation.
|
| 72 |
-
|
| 73 |
-
Raises:
|
| 74 |
-
RuntimeError: If the server response is malformed or unsuccessful.
|
| 75 |
-
"""
|
| 76 |
-
|
| 77 |
-
frame = self._request({"method": "step", "action": action.model_dump()})
|
| 78 |
-
return StateStrikeObservation.model_validate(frame["observation"])
|
| 79 |
-
|
| 80 |
-
def state(self) -> StateStrikeState:
|
| 81 |
-
"""Retrieve current environment state.
|
| 82 |
-
|
| 83 |
-
Returns:
|
| 84 |
-
Current state model.
|
| 85 |
-
|
| 86 |
-
Raises:
|
| 87 |
-
RuntimeError: If the server response is malformed or unsuccessful.
|
| 88 |
-
"""
|
| 89 |
-
|
| 90 |
-
frame = self._request({"method": "state"})
|
| 91 |
-
return StateStrikeState.model_validate(frame["state"])
|
| 92 |
-
|
| 93 |
-
def _request(self, payload: dict[str, Any]) -> dict[str, Any]:
|
| 94 |
-
"""Send request frame and parse server response.
|
| 95 |
-
|
| 96 |
-
Args:
|
| 97 |
-
payload: JSON-serializable request payload.
|
| 98 |
-
|
| 99 |
-
Returns:
|
| 100 |
-
Parsed response object.
|
| 101 |
-
|
| 102 |
-
Raises:
|
| 103 |
-
RuntimeError: If connection is closed or server reports failure.
|
| 104 |
-
"""
|
| 105 |
-
|
| 106 |
-
if self._conn is None:
|
| 107 |
-
raise RuntimeError("WebSocket connection is not open")
|
| 108 |
-
|
| 109 |
-
self._conn.send(json.dumps(payload))
|
| 110 |
-
raw = self._conn.recv()
|
| 111 |
-
frame = json.loads(raw)
|
| 112 |
-
if not frame.get("ok"):
|
| 113 |
-
raise RuntimeError(frame.get("error", "Unknown server error"))
|
| 114 |
-
return frame
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
class StateStrikeEnv:
|
| 118 |
-
"""Environment client namespace matching OpenEnv SDK usage patterns."""
|
| 119 |
-
|
| 120 |
-
def __init__(self, base_url: str = "ws://localhost:8001/ws") -> None:
|
| 121 |
-
"""Store base URL for later sync client creation.
|
| 122 |
-
|
| 123 |
-
Args:
|
| 124 |
-
base_url: Environment WebSocket endpoint.
|
| 125 |
-
"""
|
| 126 |
-
|
| 127 |
-
self.base_url = base_url
|
| 128 |
-
|
| 129 |
-
def sync(self) -> _SyncStateStrikeClient:
|
| 130 |
-
"""Create synchronous context-managed client.
|
| 131 |
-
|
| 132 |
-
Returns:
|
| 133 |
-
A synchronous environment client implementing reset/step/state.
|
| 134 |
-
"""
|
| 135 |
-
|
| 136 |
-
return _SyncStateStrikeClient(self.base_url)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
__all__ = ["StateStrikeEnv", "StateStrikeAction", "StateStrikeObservation", "StateStrikeState"]
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from statestrike_env.environment import StateStrikeEnv
|
| 4 |
+
from statestrike_env.models import (
|
| 5 |
+
EndpointChoice,
|
| 6 |
+
PayloadStrategy,
|
| 7 |
+
StateStrikeAction,
|
| 8 |
+
StateStrikeObservation,
|
| 9 |
+
StateStrikeState,
|
| 10 |
+
StepResult,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"StateStrikeEnv",
|
| 15 |
+
"EndpointChoice",
|
| 16 |
+
"PayloadStrategy",
|
| 17 |
+
"StateStrikeAction",
|
| 18 |
+
"StateStrikeObservation",
|
| 19 |
+
"StateStrikeState",
|
| 20 |
+
"StepResult",
|
| 21 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
statestrike_env/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (6.4 kB)
|
|
|
statestrike_env/__pycache__/constants.cpython-313.pyc
DELETED
|
Binary file (1.85 kB)
|
|
|
statestrike_env/__pycache__/models.cpython-311.pyc
DELETED
|
Binary file (4.42 kB)
|
|
|
statestrike_env/__pycache__/server.cpython-313.pyc
DELETED
|
Binary file (24.3 kB)
|
|
|
statestrike_env/constants.py
CHANGED
|
@@ -1,56 +1,39 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
"""Centralized constants for StateStrike environment and reward grading.
|
| 4 |
-
|
| 5 |
-
Theory:
|
| 6 |
-
Consolidating reward and episode hyperparameters avoids hidden magic numbers,
|
| 7 |
-
supports reproducibility, and aligns with RL experiment hygiene guidance from
|
| 8 |
-
Sutton & Barto (2018).
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
from dataclasses import dataclass
|
| 12 |
|
| 13 |
|
| 14 |
-
|
| 15 |
-
BETA = 10.0
|
| 16 |
-
GAMMA = 500.0
|
| 17 |
-
DELTA = 1.0
|
| 18 |
BASELINE_WINDOW = 10
|
| 19 |
-
EPISODE_LENGTH = 200
|
| 20 |
-
REDOS_LATENCY_THRESHOLD = 1500.0
|
| 21 |
-
DB_TIMEOUT_THRESHOLD = 3000.0
|
| 22 |
-
CHAIN_REQUIRED_ORDERS = 20
|
| 23 |
-
CHAIN_COOLDOWN_STEPS = 10
|
| 24 |
-
MAX_ACTION_HISTORY = 20
|
| 25 |
ACTION_TIMEOUT_SECONDS = 8.0
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
EARLY_TERMINATION_REWARD = -200.0
|
| 28 |
|
| 29 |
|
| 30 |
@dataclass(frozen=True)
|
| 31 |
class RewardConstants:
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
GAMMA: Exploitation bounty for severe degradation/failure.
|
| 38 |
-
DELTA: Penalty magnitude for low-value fuzzing requests.
|
| 39 |
-
REDOS_LATENCY_THRESHOLD: Latency threshold used to infer ReDoS impact.
|
| 40 |
-
DB_TIMEOUT_THRESHOLD: Latency threshold used for DB timeout exploitation.
|
| 41 |
-
CHAIN_REQUIRED_ORDERS: Minimum order count before GET /orders chain bonus.
|
| 42 |
-
CHAIN_COOLDOWN_STEPS: Minimum steps between chain bonus awards.
|
| 43 |
-
EARLY_TERMINATION_REWARD: Episode early-stop reward floor.
|
| 44 |
-
BASELINE_WINDOW: EMA window used for baseline latency updates.
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
ALPHA: float = ALPHA
|
| 48 |
-
BETA: float = BETA
|
| 49 |
-
GAMMA: float = GAMMA
|
| 50 |
-
DELTA: float = DELTA
|
| 51 |
-
REDOS_LATENCY_THRESHOLD: float = REDOS_LATENCY_THRESHOLD
|
| 52 |
-
DB_TIMEOUT_THRESHOLD: float = DB_TIMEOUT_THRESHOLD
|
| 53 |
CHAIN_REQUIRED_ORDERS: int = CHAIN_REQUIRED_ORDERS
|
| 54 |
CHAIN_COOLDOWN_STEPS: int = CHAIN_COOLDOWN_STEPS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
EARLY_TERMINATION_REWARD: float = EARLY_TERMINATION_REWARD
|
| 56 |
-
BASELINE_WINDOW: int = BASELINE_WINDOW
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
|
| 5 |
|
| 6 |
+
DEFAULT_BASELINE_LATENCY_MS = 50.0
|
|
|
|
|
|
|
|
|
|
| 7 |
BASELINE_WINDOW = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
ACTION_TIMEOUT_SECONDS = 8.0
|
| 9 |
+
|
| 10 |
+
REDOS_LATENCY_THRESHOLD_MS = 1500.0
|
| 11 |
+
DB_TIMEOUT_THRESHOLD_MS = 800.0
|
| 12 |
+
|
| 13 |
+
CHAIN_REQUIRED_ORDERS = 21
|
| 14 |
+
CHAIN_COOLDOWN_STEPS = 10
|
| 15 |
+
|
| 16 |
+
STEP_DELTA_MAX = 0.30
|
| 17 |
+
NEW_ENDPOINT_BONUS = 0.05
|
| 18 |
+
NEW_VULNERABILITY_BONUS = 0.10
|
| 19 |
+
REPEATED_ACTION_PENALTY = 0.02
|
| 20 |
+
TERMINAL_BONUS = 0.20
|
| 21 |
+
|
| 22 |
EARLY_TERMINATION_REWARD = -200.0
|
| 23 |
|
| 24 |
|
| 25 |
@dataclass(frozen=True)
|
| 26 |
class RewardConstants:
|
| 27 |
+
DEFAULT_BASELINE_LATENCY_MS: float = DEFAULT_BASELINE_LATENCY_MS
|
| 28 |
+
BASELINE_WINDOW: int = BASELINE_WINDOW
|
| 29 |
+
ACTION_TIMEOUT_SECONDS: float = ACTION_TIMEOUT_SECONDS
|
| 30 |
+
REDOS_LATENCY_THRESHOLD_MS: float = REDOS_LATENCY_THRESHOLD_MS
|
| 31 |
+
DB_TIMEOUT_THRESHOLD_MS: float = DB_TIMEOUT_THRESHOLD_MS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
CHAIN_REQUIRED_ORDERS: int = CHAIN_REQUIRED_ORDERS
|
| 33 |
CHAIN_COOLDOWN_STEPS: int = CHAIN_COOLDOWN_STEPS
|
| 34 |
+
STEP_DELTA_MAX: float = STEP_DELTA_MAX
|
| 35 |
+
NEW_ENDPOINT_BONUS: float = NEW_ENDPOINT_BONUS
|
| 36 |
+
NEW_VULNERABILITY_BONUS: float = NEW_VULNERABILITY_BONUS
|
| 37 |
+
REPEATED_ACTION_PENALTY: float = REPEATED_ACTION_PENALTY
|
| 38 |
+
TERMINAL_BONUS: float = TERMINAL_BONUS
|
| 39 |
EARLY_TERMINATION_REWARD: float = EARLY_TERMINATION_REWARD
|
|
|
statestrike_env/environment.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import time
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
from fastapi import Body, FastAPI
|
| 12 |
+
|
| 13 |
+
from statestrike_env.constants import RewardConstants
|
| 14 |
+
from statestrike_env.grader import compute_task_reward, compute_task_score
|
| 15 |
+
from statestrike_env.models import (
|
| 16 |
+
EndpointChoice,
|
| 17 |
+
PayloadStrategy,
|
| 18 |
+
StateStrikeAction,
|
| 19 |
+
StateStrikeObservation,
|
| 20 |
+
StateStrikeState,
|
| 21 |
+
StepResult,
|
| 22 |
+
)
|
| 23 |
+
from statestrike_env.session import StateStrikeSession
|
| 24 |
+
from statestrike_env.tasks import TASK_REGISTRY
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class StateStrikeEnv:
|
| 28 |
+
"""Unified OpenEnv-compatible runtime for StateStrike."""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
honeypot_url: str | None = None,
|
| 33 |
+
constants: RewardConstants | None = None,
|
| 34 |
+
) -> None:
|
| 35 |
+
self.honeypot_url = (honeypot_url or os.getenv("HONEYPOT_URL", "http://localhost:8000")).rstrip("/")
|
| 36 |
+
self.constants = constants or RewardConstants()
|
| 37 |
+
self.session = StateStrikeSession.new_session("endpoint_discovery")
|
| 38 |
+
self._managed_container_id: str | None = None
|
| 39 |
+
|
| 40 |
+
@classmethod
|
| 41 |
+
async def from_docker_image(cls, image_name: str) -> StateStrikeEnv:
|
| 42 |
+
env = cls(honeypot_url=os.getenv("HONEYPOT_URL", "http://localhost:8000"))
|
| 43 |
+
proc = await asyncio.create_subprocess_exec(
|
| 44 |
+
"docker",
|
| 45 |
+
"run",
|
| 46 |
+
"-d",
|
| 47 |
+
"-p",
|
| 48 |
+
"8000:8000",
|
| 49 |
+
image_name,
|
| 50 |
+
stdout=asyncio.subprocess.PIPE,
|
| 51 |
+
stderr=asyncio.subprocess.PIPE,
|
| 52 |
+
)
|
| 53 |
+
stdout, stderr = await proc.communicate()
|
| 54 |
+
if proc.returncode != 0:
|
| 55 |
+
raise RuntimeError(f"Failed to start docker image: {stderr.decode().strip()}")
|
| 56 |
+
env._managed_container_id = stdout.decode().strip()
|
| 57 |
+
|
| 58 |
+
for _ in range(30):
|
| 59 |
+
try:
|
| 60 |
+
async with httpx.AsyncClient(timeout=2.0) as client:
|
| 61 |
+
response = await client.get(f"{env.honeypot_url}/health")
|
| 62 |
+
if response.status_code == 200:
|
| 63 |
+
return env
|
| 64 |
+
except Exception:
|
| 65 |
+
pass
|
| 66 |
+
await asyncio.sleep(1)
|
| 67 |
+
|
| 68 |
+
await env.close()
|
| 69 |
+
raise RuntimeError("Timed out waiting for honeypot container to become ready")
|
| 70 |
+
|
| 71 |
+
async def close(self) -> None:
|
| 72 |
+
if self._managed_container_id:
|
| 73 |
+
process = await asyncio.create_subprocess_exec(
|
| 74 |
+
"docker",
|
| 75 |
+
"rm",
|
| 76 |
+
"-f",
|
| 77 |
+
self._managed_container_id,
|
| 78 |
+
stdout=asyncio.subprocess.PIPE,
|
| 79 |
+
stderr=asyncio.subprocess.PIPE,
|
| 80 |
+
)
|
| 81 |
+
await process.communicate()
|
| 82 |
+
self._managed_container_id = None
|
| 83 |
+
|
| 84 |
+
async def reset(self, task_name: str = "endpoint_discovery") -> StepResult:
|
| 85 |
+
if task_name not in TASK_REGISTRY:
|
| 86 |
+
task_name = "endpoint_discovery"
|
| 87 |
+
|
| 88 |
+
status, latency_ms, _ = await self._request_honeypot("GET", "/health")
|
| 89 |
+
baseline = latency_ms if latency_ms > 0 else self.constants.DEFAULT_BASELINE_LATENCY_MS
|
| 90 |
+
self.session.reset(task_name=task_name, baseline_latency=baseline)
|
| 91 |
+
|
| 92 |
+
observation = StateStrikeObservation(
|
| 93 |
+
step=0,
|
| 94 |
+
endpoint_called=EndpointChoice.HEALTH.value,
|
| 95 |
+
http_status=status,
|
| 96 |
+
latency_ms=latency_ms,
|
| 97 |
+
response_body={"status": "reset"},
|
| 98 |
+
session_order_count=0,
|
| 99 |
+
endpoints_discovered=[],
|
| 100 |
+
vulnerabilities_found=[],
|
| 101 |
+
task_progress=0.0,
|
| 102 |
+
)
|
| 103 |
+
return StepResult(observation=observation, reward=0.0, done=False, info={"task": task_name})
|
| 104 |
+
|
| 105 |
+
async def step(self, action: StateStrikeAction) -> StepResult:
|
| 106 |
+
method, path, params, payload = self._translate_action(action)
|
| 107 |
+
status, latency_ms, body = await self._request_honeypot(method, path, params=params, payload=payload)
|
| 108 |
+
|
| 109 |
+
endpoint_value = action.endpoint if isinstance(action.endpoint, str) else action.endpoint.value
|
| 110 |
+
strategy_value = (
|
| 111 |
+
action.payload_strategy if isinstance(action.payload_strategy, str) else action.payload_strategy.value
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.session.step_count += 1
|
| 115 |
+
if endpoint_value == EndpointChoice.POST_ORDERS.value and status in (200, 201):
|
| 116 |
+
self.session.order_count += 1
|
| 117 |
+
|
| 118 |
+
endpoint_name = endpoint_value
|
| 119 |
+
new_endpoint = False
|
| 120 |
+
if status > 0 and endpoint_name not in self.session.endpoints_discovered:
|
| 121 |
+
self.session.endpoints_discovered.add(endpoint_name)
|
| 122 |
+
new_endpoint = True
|
| 123 |
+
|
| 124 |
+
signature = f"{endpoint_value}|{strategy_value}|{action.target_user_id}"
|
| 125 |
+
repeated_action = signature == self.session.last_action_signature
|
| 126 |
+
self.session.last_action_signature = signature
|
| 127 |
+
|
| 128 |
+
new_vulnerability = False
|
| 129 |
+
|
| 130 |
+
if (
|
| 131 |
+
not self.session.redos_bounty_awarded
|
| 132 |
+
and endpoint_value == EndpointChoice.POST_USERS.value
|
| 133 |
+
and strategy_value == PayloadStrategy.REDOS_ATTACK.value
|
| 134 |
+
and status == 400
|
| 135 |
+
and latency_ms > self.constants.REDOS_LATENCY_THRESHOLD_MS
|
| 136 |
+
):
|
| 137 |
+
self.session.redos_bounty_awarded = True
|
| 138 |
+
self.session.vulnerabilities_found.add("redos")
|
| 139 |
+
new_vulnerability = True
|
| 140 |
+
|
| 141 |
+
chain_cooldown_ready = (
|
| 142 |
+
self.session.step_count - self.session.last_chain_bonus_step
|
| 143 |
+
) >= self.constants.CHAIN_COOLDOWN_STEPS
|
| 144 |
+
chain_progressed = self.session.order_count > self.session.post_count_at_last_chain
|
| 145 |
+
if (
|
| 146 |
+
not self.session.db_degradation_bounty_awarded
|
| 147 |
+
and endpoint_value == EndpointChoice.GET_ORDERS.value
|
| 148 |
+
and self.session.order_count >= self.constants.CHAIN_REQUIRED_ORDERS
|
| 149 |
+
and latency_ms > self.constants.DB_TIMEOUT_THRESHOLD_MS
|
| 150 |
+
and chain_cooldown_ready
|
| 151 |
+
and chain_progressed
|
| 152 |
+
):
|
| 153 |
+
self.session.db_degradation_bounty_awarded = True
|
| 154 |
+
self.session.vulnerabilities_found.add("db_degradation")
|
| 155 |
+
self.session.last_chain_bonus_step = self.session.step_count
|
| 156 |
+
self.session.post_count_at_last_chain = self.session.order_count
|
| 157 |
+
new_vulnerability = True
|
| 158 |
+
|
| 159 |
+
if (
|
| 160 |
+
endpoint_value == EndpointChoice.POST_USERS.value
|
| 161 |
+
and strategy_value == PayloadStrategy.VALID.value
|
| 162 |
+
and status in (200, 201)
|
| 163 |
+
):
|
| 164 |
+
self.session.user_created = True
|
| 165 |
+
|
| 166 |
+
self.session.steps_history.append(
|
| 167 |
+
{
|
| 168 |
+
"endpoint": endpoint_value,
|
| 169 |
+
"payload_strategy": strategy_value,
|
| 170 |
+
"target_user_id": action.target_user_id,
|
| 171 |
+
"http_status": status,
|
| 172 |
+
"latency_ms": latency_ms,
|
| 173 |
+
}
|
| 174 |
+
)
|
| 175 |
+
if len(self.session.steps_history) > 200:
|
| 176 |
+
self.session.steps_history.pop(0)
|
| 177 |
+
|
| 178 |
+
self.session.task_specific_state["new_endpoint_discovered"] = new_endpoint
|
| 179 |
+
self.session.task_specific_state["new_vulnerability_found"] = new_vulnerability
|
| 180 |
+
self.session.task_specific_state["repeated_action"] = repeated_action
|
| 181 |
+
|
| 182 |
+
task_score = compute_task_score(self.session, self.session.task_name)
|
| 183 |
+
observation = StateStrikeObservation(
|
| 184 |
+
step=self.session.step_count,
|
| 185 |
+
endpoint_called=endpoint_value,
|
| 186 |
+
http_status=status,
|
| 187 |
+
latency_ms=latency_ms,
|
| 188 |
+
response_body=body,
|
| 189 |
+
session_order_count=self.session.order_count,
|
| 190 |
+
endpoints_discovered=sorted(self.session.endpoints_discovered),
|
| 191 |
+
vulnerabilities_found=sorted(self.session.vulnerabilities_found),
|
| 192 |
+
task_progress=task_score,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
reward, breakdown = compute_task_reward(
|
| 196 |
+
observation,
|
| 197 |
+
self.session,
|
| 198 |
+
self.session.task_name,
|
| 199 |
+
self.constants,
|
| 200 |
+
)
|
| 201 |
+
self.session.cumulative_reward += reward
|
| 202 |
+
|
| 203 |
+
task_cfg, _ = TASK_REGISTRY[self.session.task_name]
|
| 204 |
+
done = self.session.step_count >= task_cfg.max_steps or task_score >= task_cfg.success_threshold
|
| 205 |
+
|
| 206 |
+
return StepResult(
|
| 207 |
+
observation=observation,
|
| 208 |
+
reward=reward,
|
| 209 |
+
done=done,
|
| 210 |
+
info={
|
| 211 |
+
"reward_breakdown": breakdown,
|
| 212 |
+
"task": self.session.task_name,
|
| 213 |
+
},
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
async def state(self) -> StateStrikeState:
|
| 217 |
+
return self.session.as_state()
|
| 218 |
+
|
| 219 |
+
def reset_sync(self, task_name: str = "endpoint_discovery") -> StepResult:
|
| 220 |
+
return asyncio.run(self.reset(task_name=task_name))
|
| 221 |
+
|
| 222 |
+
def step_sync(self, action: StateStrikeAction) -> StepResult:
|
| 223 |
+
return asyncio.run(self.step(action))
|
| 224 |
+
|
| 225 |
+
def state_sync(self) -> StateStrikeState:
|
| 226 |
+
return asyncio.run(self.state())
|
| 227 |
+
|
| 228 |
+
def _translate_action(
|
| 229 |
+
self,
|
| 230 |
+
action: StateStrikeAction,
|
| 231 |
+
) -> tuple[str, str, dict[str, Any] | None, dict[str, Any] | None]:
|
| 232 |
+
endpoint_value = action.endpoint if isinstance(action.endpoint, str) else action.endpoint.value
|
| 233 |
+
strategy_value = (
|
| 234 |
+
action.payload_strategy if isinstance(action.payload_strategy, str) else action.payload_strategy.value
|
| 235 |
+
)
|
| 236 |
+
target_user_id = action.target_user_id or 1
|
| 237 |
+
|
| 238 |
+
if endpoint_value == EndpointChoice.POST_USERS.value:
|
| 239 |
+
return "POST", "/users", None, {"email": self._email_for_strategy(strategy_value)}
|
| 240 |
+
if endpoint_value == EndpointChoice.GET_USER.value:
|
| 241 |
+
return "GET", f"/users/{target_user_id}", None, None
|
| 242 |
+
if endpoint_value == EndpointChoice.POST_ORDERS.value:
|
| 243 |
+
return "POST", "/orders", None, {
|
| 244 |
+
"user_id": target_user_id,
|
| 245 |
+
"item": self._item_for_strategy(strategy_value),
|
| 246 |
+
}
|
| 247 |
+
if endpoint_value == EndpointChoice.GET_ORDERS.value:
|
| 248 |
+
return "GET", "/orders", {"user_id": target_user_id}, None
|
| 249 |
+
return "GET", "/health", None, None
|
| 250 |
+
|
| 251 |
+
@staticmethod
|
| 252 |
+
def _email_for_strategy(strategy: str) -> str:
|
| 253 |
+
if strategy == PayloadStrategy.REDOS_ATTACK.value:
|
| 254 |
+
return "a" * 39 + "!"
|
| 255 |
+
if strategy == PayloadStrategy.OVERSIZED.value:
|
| 256 |
+
return "A" * 4096
|
| 257 |
+
if strategy == PayloadStrategy.MALFORMED.value:
|
| 258 |
+
return "@@@"
|
| 259 |
+
return "validuser123"
|
| 260 |
+
|
| 261 |
+
@staticmethod
|
| 262 |
+
def _item_for_strategy(strategy: str) -> str:
|
| 263 |
+
if strategy == PayloadStrategy.OVERSIZED.value:
|
| 264 |
+
return "item_" + ("X" * 2048)
|
| 265 |
+
if strategy == PayloadStrategy.MALFORMED.value:
|
| 266 |
+
return ""
|
| 267 |
+
return "standard_item"
|
| 268 |
+
|
| 269 |
+
async def _request_honeypot(
|
| 270 |
+
self,
|
| 271 |
+
method: str,
|
| 272 |
+
path: str,
|
| 273 |
+
*,
|
| 274 |
+
params: dict[str, Any] | None = None,
|
| 275 |
+
payload: dict[str, Any] | None = None,
|
| 276 |
+
) -> tuple[int, float, dict[str, Any]]:
|
| 277 |
+
url = f"{self.honeypot_url}{path}"
|
| 278 |
+
started = time.perf_counter()
|
| 279 |
+
try:
|
| 280 |
+
async with httpx.AsyncClient(timeout=self.constants.ACTION_TIMEOUT_SECONDS) as client:
|
| 281 |
+
response = await client.request(method, url, params=params, json=payload)
|
| 282 |
+
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
| 283 |
+
header_latency = response.headers.get("X-Process-Time-Ms")
|
| 284 |
+
latency_ms = float(header_latency) if header_latency else elapsed_ms
|
| 285 |
+
body = response.json() if response.content else {}
|
| 286 |
+
return response.status_code, latency_ms, body
|
| 287 |
+
except Exception as exc:
|
| 288 |
+
return 0, 0.0, {"error": str(exc), "synthetic": True}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@asynccontextmanager
|
| 292 |
+
async def lifespan(_: FastAPI):
|
| 293 |
+
yield
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
app = FastAPI(title="StateStrike", lifespan=lifespan)
|
| 297 |
+
_global_env = StateStrikeEnv()
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@app.post("/reset")
|
| 301 |
+
async def reset_endpoint(body: dict = Body(default={})):
|
| 302 |
+
task = body.get("task", "endpoint_discovery")
|
| 303 |
+
result = await _global_env.reset(task_name=task)
|
| 304 |
+
return result.model_dump()
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@app.post("/step")
|
| 308 |
+
async def step_endpoint(action: StateStrikeAction):
|
| 309 |
+
result = await _global_env.step(action)
|
| 310 |
+
return result.model_dump()
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@app.get("/state")
|
| 314 |
+
async def state_endpoint():
|
| 315 |
+
return (await _global_env.state()).model_dump()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@app.get("/health")
|
| 319 |
+
async def health():
|
| 320 |
+
return {"status": "ok"}
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def main() -> None:
|
| 324 |
+
import uvicorn
|
| 325 |
+
|
| 326 |
+
host = os.getenv("STATESTRIKE_ENV_HOST", "0.0.0.0")
|
| 327 |
+
port = int(os.getenv("STATESTRIKE_ENV_PORT", "7860"))
|
| 328 |
+
uvicorn.run("statestrike_env.environment:app", host=host, port=port, reload=False)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
if __name__ == "__main__":
|
| 332 |
+
main()
|
statestrike_env/grader.py
CHANGED
|
@@ -1,208 +1,99 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
"""Reward grading logic for StateStrike.
|
| 4 |
-
|
| 5 |
-
Theory:
|
| 6 |
-
The reward function follows standard MDP shaping principles from Sutton &
|
| 7 |
-
Barto (2018): combine dense shaping signals (latency ratio), sparse goal
|
| 8 |
-
rewards (exploit bounty), and penalties (invalid spam suppression). It also
|
| 9 |
-
borrows stateful-sequence ideas from RESTler (Atlidakis et al., ICSE 2019)
|
| 10 |
-
while rewarding infrastructure effects (e.g., ReDoS latency spikes) inspired
|
| 11 |
-
by Davis et al. (USENIX Security 2018).
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
import logging
|
| 15 |
-
import
|
| 16 |
-
from typing import TYPE_CHECKING
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"""
|
| 32 |
-
Compute
|
| 33 |
-
|
| 34 |
-
Theory (
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
Why log? Linear reward incentivizes the agent to find ONE massive spike
|
| 41 |
-
and repeat it. Logarithmic reward gives diminishing returns per repeated
|
| 42 |
-
exploitation, pushing the agent to discover NEW vulnerabilities.
|
| 43 |
-
Why ratio? Prevents baseline-anchoring attacks where agent engineers a
|
| 44 |
-
low baseline then makes normal requests look like spikes.
|
| 45 |
-
Anti-hack: baseline ONLY updates from successful (latency>0) steps.
|
| 46 |
-
|
| 47 |
-
TERM 2 — β·S_t: State-chain bonus.
|
| 48 |
-
Fires at most once per CHAIN_COOLDOWN_STEPS steps, and only if
|
| 49 |
-
order_count has increased since the last award. This prevents the
|
| 50 |
-
POST→GET farming loop that would yield +5 reward/step for free.
|
| 51 |
-
Anti-hack: last_chain_bonus_step and post_count_at_last_chain guards.
|
| 52 |
-
|
| 53 |
-
TERM 3 — γ·E_t: Exploitation bounty.
|
| 54 |
-
Fires EXACTLY ONCE per vulnerability type per episode (one-time flag).
|
| 55 |
-
Without this, an agent discovering db_degradation would spam GET /orders
|
| 56 |
-
for +500/step indefinitely. The one-time award correctly signals
|
| 57 |
-
"you found it" without incentivizing repeated triggering.
|
| 58 |
-
Anti-hack: redos_bounty_awarded and db_degradation_bounty_awarded flags.
|
| 59 |
-
|
| 60 |
-
TERM 4 — δ·P_t: Fuzzing penalty.
|
| 61 |
-
Applied only to genuinely fast 400s (latency < 100ms), not to slow 400s
|
| 62 |
-
(which may indicate actual CPU burn from ReDoS parsing).
|
| 63 |
-
Threshold tightened from 200ms to 100ms to avoid penalizing legitimate
|
| 64 |
-
slow-failing payloads.
|
| 65 |
-
Anti-hack: latency threshold ensures ReDoS probes are not penalized.
|
| 66 |
-
|
| 67 |
-
Reference:
|
| 68 |
-
- Sutton & Barto (2018): reward shaping and sparse reward design
|
| 69 |
-
- Atlidakis et al. (ICSE 2019): stateful API exploration objectives
|
| 70 |
-
- Davis et al. (USENIX 2018): ReDoS computational complexity
|
| 71 |
-
|
| 72 |
-
Args:
|
| 73 |
-
obs: The observation from the current step.
|
| 74 |
-
session: The mutable session state (modified in-place for flags).
|
| 75 |
-
constants: Reward weight constants from constants.py.
|
| 76 |
-
|
| 77 |
-
Returns:
|
| 78 |
-
Tuple of (scalar_reward, breakdown_dict) where breakdown_dict
|
| 79 |
-
contains each term's contribution for telemetry and dashboard display.
|
| 80 |
"""
|
| 81 |
|
| 82 |
-
|
| 83 |
-
breakdown: dict[str,
|
| 84 |
-
"
|
| 85 |
-
"
|
| 86 |
-
"
|
| 87 |
-
"
|
|
|
|
| 88 |
"total": 0.0,
|
| 89 |
}
|
| 90 |
|
| 91 |
-
# Guard: connection failure -> neutral observation, no reward signal.
|
| 92 |
if obs.http_status == 0 or obs.latency_ms == 0.0:
|
| 93 |
-
breakdown["error"] =
|
| 94 |
-
logger.debug("Step %d: connection failure, returning zero reward", obs.step)
|
| 95 |
return 0.0, breakdown
|
| 96 |
|
| 97 |
-
|
| 98 |
-
_update_baseline(session, obs.latency_ms, constants.BASELINE_WINDOW)
|
| 99 |
-
|
| 100 |
-
# TERM 1: Logarithmic latency reward.
|
| 101 |
-
if session.baseline_latency > 0:
|
| 102 |
-
latency_ratio = obs.latency_ms / session.baseline_latency
|
| 103 |
-
latency_ratio = max(0.01, min(latency_ratio, 100.0))
|
| 104 |
-
latency_reward = constants.ALPHA * math.log(latency_ratio)
|
| 105 |
-
reward += latency_reward
|
| 106 |
-
breakdown["latency_reward"] = round(latency_reward, 4)
|
| 107 |
-
logger.debug(
|
| 108 |
-
"Step %d: latency=%.1fms baseline=%.1fms ratio=%.2f reward=%.3f",
|
| 109 |
-
obs.step,
|
| 110 |
-
obs.latency_ms,
|
| 111 |
-
session.baseline_latency,
|
| 112 |
-
latency_ratio,
|
| 113 |
-
latency_reward,
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
# TERM 2: State-chain bonus (anti-farming guards).
|
| 117 |
-
chain_bonus = 0.0
|
| 118 |
-
if _should_award_chain_bonus(obs, session, constants):
|
| 119 |
-
chain_bonus = constants.BETA
|
| 120 |
-
session.last_chain_bonus_step = obs.step
|
| 121 |
-
session.post_count_at_last_chain = session.order_count
|
| 122 |
-
logger.info(
|
| 123 |
-
"Step %d: Chain bonus awarded (+%.1f). order_count=%d",
|
| 124 |
-
obs.step,
|
| 125 |
-
chain_bonus,
|
| 126 |
-
session.order_count,
|
| 127 |
-
)
|
| 128 |
-
reward += chain_bonus
|
| 129 |
-
breakdown["chain_bonus"] = chain_bonus
|
| 130 |
-
|
| 131 |
-
# TERM 3: Exploitation bounties (one-time per episode).
|
| 132 |
-
exploit_bounty = 0.0
|
| 133 |
-
|
| 134 |
-
if (
|
| 135 |
-
not session.db_degradation_bounty_awarded
|
| 136 |
-
and (obs.http_status >= 500 or obs.latency_ms > constants.DB_TIMEOUT_THRESHOLD)
|
| 137 |
-
and obs.action_taken.action_type.value == "get_orders"
|
| 138 |
-
):
|
| 139 |
-
exploit_bounty += constants.GAMMA
|
| 140 |
-
session.db_degradation_bounty_awarded = True
|
| 141 |
-
session.triggered_vulns.add("db_degradation")
|
| 142 |
-
logger.info(
|
| 143 |
-
"Step %d: DB_DEGRADATION bounty awarded (+%.1f). latency=%.1fms",
|
| 144 |
-
obs.step,
|
| 145 |
-
constants.GAMMA,
|
| 146 |
-
obs.latency_ms,
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
if (
|
| 150 |
-
not session.redos_bounty_awarded
|
| 151 |
-
and obs.latency_ms > constants.REDOS_LATENCY_THRESHOLD
|
| 152 |
-
and obs.http_status == 400
|
| 153 |
-
and obs.action_taken.action_type.value == "post_user"
|
| 154 |
-
):
|
| 155 |
-
redos_bounty = constants.GAMMA * 0.8
|
| 156 |
-
exploit_bounty += redos_bounty
|
| 157 |
-
session.redos_bounty_awarded = True
|
| 158 |
-
session.triggered_vulns.add("redos")
|
| 159 |
-
logger.info(
|
| 160 |
-
"Step %d: REDOS bounty awarded (+%.1f). latency=%.1fms",
|
| 161 |
-
obs.step,
|
| 162 |
-
redos_bounty,
|
| 163 |
-
obs.latency_ms,
|
| 164 |
-
)
|
| 165 |
-
|
| 166 |
-
reward += exploit_bounty
|
| 167 |
-
breakdown["exploit_bounty"] = round(exploit_bounty, 4)
|
| 168 |
-
|
| 169 |
-
# TERM 4: Fuzzing penalty (only genuine fast-fail syntax errors).
|
| 170 |
-
fuzz_penalty = 0.0
|
| 171 |
-
if obs.http_status == 400 and obs.latency_ms < 100.0:
|
| 172 |
-
fuzz_penalty = -constants.DELTA
|
| 173 |
-
logger.debug("Step %d: Fuzz penalty applied (fast 400, %.1fms)", obs.step, obs.latency_ms)
|
| 174 |
-
reward += fuzz_penalty
|
| 175 |
-
breakdown["fuzz_penalty"] = round(fuzz_penalty, 4)
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
| 179 |
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
|
|
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
else:
|
| 188 |
-
session.baseline_latency = alpha_ema * latency_ms + (1 - alpha_ema) * session.baseline_latency
|
| 189 |
-
session.baseline_sample_count += 1
|
| 190 |
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import logging
|
| 4 |
+
from typing import Any
|
|
|
|
| 5 |
|
| 6 |
+
from statestrike_env.constants import RewardConstants
|
| 7 |
+
from statestrike_env.models import StateStrikeObservation
|
| 8 |
+
from statestrike_env.session import StateStrikeSession
|
| 9 |
+
from statestrike_env.tasks import TASK_REGISTRY
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
+
def compute_task_score(session: StateStrikeSession, task_name: str) -> float:
|
| 15 |
+
task_config, grader = TASK_REGISTRY[task_name]
|
| 16 |
+
del task_config
|
| 17 |
+
return float(grader.score(session.as_grader_state()))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _update_baseline_ema(session: StateStrikeSession, latency_ms: float, window: int) -> None:
|
| 21 |
+
alpha_ema = 2.0 / (window + 1)
|
| 22 |
+
if session.baseline_sample_count == 0:
|
| 23 |
+
session.baseline_latency = latency_ms
|
| 24 |
+
else:
|
| 25 |
+
session.baseline_latency = alpha_ema * latency_ms + (1 - alpha_ema) * session.baseline_latency
|
| 26 |
+
session.baseline_sample_count += 1
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def compute_task_reward(
|
| 30 |
+
obs: StateStrikeObservation,
|
| 31 |
+
session: StateStrikeSession,
|
| 32 |
+
task_name: str,
|
| 33 |
+
constants: RewardConstants,
|
| 34 |
+
) -> tuple[float, dict[str, Any]]:
|
| 35 |
"""
|
| 36 |
+
Compute step reward in [0.0, 1.0] based on task progress delta.
|
| 37 |
+
|
| 38 |
+
Theory (reward shaping, Ng et al. 1999):
|
| 39 |
+
R_shaped(s, a, s') = R(s, a, s') + gamma*Phi(s') - Phi(s)
|
| 40 |
+
where Phi(s) = task_score(s) is the potential function.
|
| 41 |
+
|
| 42 |
+
The terminal bonus (+0.20) is a sparse goal reward layered on top of
|
| 43 |
+
the shaped reward, following the hybrid approach in Sutton & Barto (2018).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
"""
|
| 45 |
|
| 46 |
+
task_config, _ = TASK_REGISTRY[task_name]
|
| 47 |
+
breakdown: dict[str, Any] = {
|
| 48 |
+
"score_delta": 0.0,
|
| 49 |
+
"new_endpoint_bonus": 0.0,
|
| 50 |
+
"new_vulnerability_bonus": 0.0,
|
| 51 |
+
"repeat_penalty": 0.0,
|
| 52 |
+
"terminal_bonus": 0.0,
|
| 53 |
"total": 0.0,
|
| 54 |
}
|
| 55 |
|
|
|
|
| 56 |
if obs.http_status == 0 or obs.latency_ms == 0.0:
|
| 57 |
+
breakdown["error"] = "connection_failed"
|
|
|
|
| 58 |
return 0.0, breakdown
|
| 59 |
|
| 60 |
+
_update_baseline_ema(session, obs.latency_ms, constants.BASELINE_WINDOW)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
current_task_score = compute_task_score(session, task_name)
|
| 63 |
+
previous_task_score = session.previous_task_score
|
| 64 |
+
score_delta = max(0.0, current_task_score - previous_task_score)
|
| 65 |
+
score_delta = min(score_delta, constants.STEP_DELTA_MAX)
|
| 66 |
+
breakdown["score_delta"] = round(score_delta, 4)
|
| 67 |
|
| 68 |
+
reward = score_delta
|
| 69 |
|
| 70 |
+
if bool(session.task_specific_state.get("new_endpoint_discovered", False)):
|
| 71 |
+
reward += constants.NEW_ENDPOINT_BONUS
|
| 72 |
+
breakdown["new_endpoint_bonus"] = constants.NEW_ENDPOINT_BONUS
|
| 73 |
|
| 74 |
+
if bool(session.task_specific_state.get("new_vulnerability_found", False)):
|
| 75 |
+
reward += constants.NEW_VULNERABILITY_BONUS
|
| 76 |
+
breakdown["new_vulnerability_bonus"] = constants.NEW_VULNERABILITY_BONUS
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
if bool(session.task_specific_state.get("repeated_action", False)):
|
| 79 |
+
reward -= constants.REPEATED_ACTION_PENALTY
|
| 80 |
+
breakdown["repeat_penalty"] = -constants.REPEATED_ACTION_PENALTY
|
| 81 |
|
| 82 |
+
terminal = session.step_count >= task_config.max_steps or current_task_score >= task_config.success_threshold
|
| 83 |
+
if terminal and current_task_score >= task_config.success_threshold:
|
| 84 |
+
reward += constants.TERMINAL_BONUS
|
| 85 |
+
breakdown["terminal_bonus"] = constants.TERMINAL_BONUS
|
| 86 |
+
|
| 87 |
+
reward = max(0.0, min(1.0, reward))
|
| 88 |
+
breakdown["total"] = round(reward, 4)
|
| 89 |
+
breakdown["task_score"] = round(current_task_score, 4)
|
| 90 |
+
|
| 91 |
+
session.previous_task_score = max(previous_task_score, current_task_score)
|
| 92 |
+
logger.debug(
|
| 93 |
+
"task=%s step=%d score=%.3f reward=%.3f",
|
| 94 |
+
task_name,
|
| 95 |
+
obs.step,
|
| 96 |
+
current_task_score,
|
| 97 |
+
reward,
|
| 98 |
+
)
|
| 99 |
+
return reward, breakdown
|
statestrike_env/models.py
CHANGED
|
@@ -1,31 +1,20 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
"""Typed action, observation, and state models for StateStrike.
|
| 4 |
-
|
| 5 |
-
Theory:
|
| 6 |
-
Explicit state/action schemas reduce ambiguity in RL interfaces and improve
|
| 7 |
-
reproducibility when evaluating policies across different backends.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
from enum import Enum
|
| 11 |
from typing import Any, Optional
|
| 12 |
|
| 13 |
from pydantic import BaseModel, Field
|
| 14 |
|
| 15 |
|
| 16 |
-
class
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
GET_ORDERS = "get_orders"
|
| 23 |
-
HEALTH_CHECK = "health_check"
|
| 24 |
|
| 25 |
|
| 26 |
class PayloadStrategy(str, Enum):
|
| 27 |
-
"""Payload generation strategies used by the fuzzing policy."""
|
| 28 |
-
|
| 29 |
VALID = "valid"
|
| 30 |
REDOS_ATTACK = "redos"
|
| 31 |
OVERSIZED = "oversized"
|
|
@@ -33,66 +22,48 @@ class PayloadStrategy(str, Enum):
|
|
| 33 |
|
| 34 |
|
| 35 |
class StateStrikeAction(BaseModel):
|
| 36 |
-
"""Action
|
| 37 |
-
|
| 38 |
-
Args:
|
| 39 |
-
action_type: Target endpoint operation.
|
| 40 |
-
payload_strategy: Payload mutation strategy.
|
| 41 |
-
target_user_id: Optional user identifier override.
|
| 42 |
-
"""
|
| 43 |
|
| 44 |
-
|
| 45 |
-
payload_strategy: PayloadStrategy
|
| 46 |
target_user_id: Optional[int] = None
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
class StateStrikeObservation(BaseModel):
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
step: Current step index within the episode.
|
| 54 |
-
action_taken: Action executed during the step.
|
| 55 |
-
http_status: HTTP status code from honeypot response.
|
| 56 |
-
latency_ms: End-to-end processing latency in milliseconds.
|
| 57 |
-
reward: Scalar reward at this step.
|
| 58 |
-
cumulative_reward: Running reward sum for the episode.
|
| 59 |
-
baseline_latency_ms: Rolling latency baseline used for normalization.
|
| 60 |
-
order_count: Number of POST /orders calls in this episode.
|
| 61 |
-
triggered_vulns: Vulnerability labels discovered so far.
|
| 62 |
-
done: Terminal signal for episode completion.
|
| 63 |
-
info: Arbitrary metadata, including reward breakdown.
|
| 64 |
-
"""
|
| 65 |
|
| 66 |
step: int
|
| 67 |
-
|
| 68 |
http_status: int
|
| 69 |
latency_ms: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
reward: float
|
| 71 |
-
cumulative_reward: float
|
| 72 |
-
baseline_latency_ms: float
|
| 73 |
-
order_count: int
|
| 74 |
-
triggered_vulns: list[str]
|
| 75 |
done: bool
|
| 76 |
info: dict[str, Any] = Field(default_factory=dict)
|
| 77 |
|
| 78 |
|
| 79 |
class StateStrikeState(BaseModel):
|
| 80 |
-
"""
|
| 81 |
-
|
| 82 |
-
Args:
|
| 83 |
-
session_id: Unique identifier for current environment episode.
|
| 84 |
-
step_count: Number of actions executed in current session.
|
| 85 |
-
cumulative_reward: Running reward sum for current session.
|
| 86 |
-
order_count: Number of POST /orders calls in session.
|
| 87 |
-
baseline_latency_ms: Rolling baseline latency in milliseconds.
|
| 88 |
-
action_history: Most recent action history window.
|
| 89 |
-
triggered_vulns: Vulnerabilities discovered in this session.
|
| 90 |
-
"""
|
| 91 |
|
| 92 |
session_id: str
|
|
|
|
| 93 |
step_count: int
|
| 94 |
cumulative_reward: float
|
| 95 |
order_count: int
|
| 96 |
baseline_latency_ms: float
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from enum import Enum
|
| 4 |
from typing import Any, Optional
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
| 8 |
|
| 9 |
+
class EndpointChoice(str, Enum):
|
| 10 |
+
POST_USERS = "POST /users"
|
| 11 |
+
GET_USER = "GET /users/{id}"
|
| 12 |
+
POST_ORDERS = "POST /orders"
|
| 13 |
+
GET_ORDERS = "GET /orders"
|
| 14 |
+
HEALTH = "GET /health"
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class PayloadStrategy(str, Enum):
|
|
|
|
|
|
|
| 18 |
VALID = "valid"
|
| 19 |
REDOS_ATTACK = "redos"
|
| 20 |
OVERSIZED = "oversized"
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class StateStrikeAction(BaseModel):
|
| 25 |
+
"""Action space for StateStrike environment."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
endpoint: EndpointChoice
|
| 28 |
+
payload_strategy: PayloadStrategy = PayloadStrategy.VALID
|
| 29 |
target_user_id: Optional[int] = None
|
| 30 |
|
| 31 |
+
class Config:
|
| 32 |
+
use_enum_values = True
|
| 33 |
+
|
| 34 |
|
| 35 |
class StateStrikeObservation(BaseModel):
|
| 36 |
+
"""Observation returned after each step."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
step: int
|
| 39 |
+
endpoint_called: str
|
| 40 |
http_status: int
|
| 41 |
latency_ms: float
|
| 42 |
+
response_body: dict[str, Any] = Field(default_factory=dict)
|
| 43 |
+
session_order_count: int = 0
|
| 44 |
+
endpoints_discovered: list[str] = Field(default_factory=list)
|
| 45 |
+
vulnerabilities_found: list[str] = Field(default_factory=list)
|
| 46 |
+
task_progress: float = 0.0
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class StepResult(BaseModel):
|
| 50 |
+
"""Top-level return from step()."""
|
| 51 |
+
|
| 52 |
+
observation: StateStrikeObservation
|
| 53 |
reward: float
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
done: bool
|
| 55 |
info: dict[str, Any] = Field(default_factory=dict)
|
| 56 |
|
| 57 |
|
| 58 |
class StateStrikeState(BaseModel):
|
| 59 |
+
"""Full session state, returned by state()."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
session_id: str
|
| 62 |
+
task_name: str
|
| 63 |
step_count: int
|
| 64 |
cumulative_reward: float
|
| 65 |
order_count: int
|
| 66 |
baseline_latency_ms: float
|
| 67 |
+
endpoints_discovered: list[str]
|
| 68 |
+
vulnerabilities_found: list[str]
|
| 69 |
+
task_specific_state: dict[str, Any] = Field(default_factory=dict)
|
statestrike_env/server.py
CHANGED
|
@@ -1,400 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
| 6 |
-
import json
|
| 7 |
-
import logging
|
| 8 |
-
import os
|
| 9 |
-
import time
|
| 10 |
-
from contextlib import asynccontextmanager
|
| 11 |
-
from typing import Any
|
| 12 |
|
| 13 |
-
|
| 14 |
-
from dotenv import load_dotenv
|
| 15 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 16 |
-
from fastapi.responses import JSONResponse
|
| 17 |
-
|
| 18 |
-
try:
|
| 19 |
-
import openenv_core # noqa: F401
|
| 20 |
-
except ImportError: # pragma: no cover - optional import for compatibility signaling.
|
| 21 |
-
openenv_core = None
|
| 22 |
-
|
| 23 |
-
from statestrike_env.constants import (
|
| 24 |
-
ACTION_TIMEOUT_SECONDS,
|
| 25 |
-
DEFAULT_BASELINE_LATENCY_MS,
|
| 26 |
-
EPISODE_LENGTH,
|
| 27 |
-
RewardConstants,
|
| 28 |
-
)
|
| 29 |
-
from statestrike_env.grader import compute_reward
|
| 30 |
-
from statestrike_env.models import ActionType, PayloadStrategy, StateStrikeAction, StateStrikeObservation, StateStrikeState
|
| 31 |
-
from statestrike_env.session import StateStrikeSession
|
| 32 |
-
|
| 33 |
-
load_dotenv()
|
| 34 |
-
|
| 35 |
-
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 36 |
-
LOGGER = logging.getLogger(__name__)
|
| 37 |
-
|
| 38 |
-
HONEYPOT_URL = os.getenv("HONEYPOT_URL", "http://localhost:8000")
|
| 39 |
-
HOST = os.getenv("STATESTRIKE_ENV_HOST", "0.0.0.0")
|
| 40 |
-
PORT = int(os.getenv("STATESTRIKE_ENV_PORT", "8001"))
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
async def wait_for_honeypot(url: str, max_wait: int = 30) -> None:
|
| 44 |
-
"""Block until honeypot is reachable or raise RuntimeError.
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
url: Honeypot base URL.
|
| 48 |
-
max_wait: Maximum wait time in seconds.
|
| 49 |
-
|
| 50 |
-
Raises:
|
| 51 |
-
RuntimeError: If honeypot is not reachable before timeout.
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
deadline = asyncio.get_event_loop().time() + max_wait
|
| 55 |
-
delay = 1.0
|
| 56 |
-
async with httpx.AsyncClient() as client:
|
| 57 |
-
while asyncio.get_event_loop().time() < deadline:
|
| 58 |
-
try:
|
| 59 |
-
response = await client.get(f"{url}/health", timeout=3.0)
|
| 60 |
-
if response.status_code == 200:
|
| 61 |
-
LOGGER.info("Honeypot is ready at %s", url)
|
| 62 |
-
return
|
| 63 |
-
LOGGER.warning(
|
| 64 |
-
"Honeypot health returned status=%s, retrying in %.1fs...",
|
| 65 |
-
response.status_code,
|
| 66 |
-
delay,
|
| 67 |
-
)
|
| 68 |
-
except Exception as exc: # noqa: BLE001
|
| 69 |
-
LOGGER.warning("Honeypot not ready (%s), retrying in %.1fs...", exc, delay)
|
| 70 |
-
|
| 71 |
-
await asyncio.sleep(delay)
|
| 72 |
-
delay = min(delay * 1.5, 5.0)
|
| 73 |
-
|
| 74 |
-
raise RuntimeError(f"Honeypot at {url} did not become ready within {max_wait}s")
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class StateStrikeEnvironment:
|
| 78 |
-
"""Core reset/step/state implementation.
|
| 79 |
-
|
| 80 |
-
Theory:
|
| 81 |
-
OpenEnv training loops benefit from persistent transport: WebSocket-based
|
| 82 |
-
sessions amortize handshake overhead and preserve episode-local state,
|
| 83 |
-
which aligns with OpenEnv architecture guidance (Burtenshaw, 2025).
|
| 84 |
-
"""
|
| 85 |
-
|
| 86 |
-
def __init__(self, honeypot_url: str, constants: RewardConstants | None = None) -> None:
|
| 87 |
-
"""Initialize environment service.
|
| 88 |
-
|
| 89 |
-
Args:
|
| 90 |
-
honeypot_url: Base URL for vulnerable honeypot API.
|
| 91 |
-
constants: Optional reward constants override.
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
self.honeypot_url = honeypot_url.rstrip("/")
|
| 95 |
-
self.constants = constants or RewardConstants()
|
| 96 |
-
|
| 97 |
-
async def reset(self, session: StateStrikeSession) -> StateStrikeObservation:
|
| 98 |
-
"""Reset session and return initial observation.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
session: Session object tied to one client connection.
|
| 102 |
-
|
| 103 |
-
Returns:
|
| 104 |
-
Initial observation with zero reward.
|
| 105 |
-
"""
|
| 106 |
-
|
| 107 |
-
status, latency_ms, _ = await self._request_honeypot("GET", "/health")
|
| 108 |
-
baseline = latency_ms if latency_ms > 0 else DEFAULT_BASELINE_LATENCY_MS
|
| 109 |
-
session.reset(baseline_latency=baseline)
|
| 110 |
-
|
| 111 |
-
action = StateStrikeAction(action_type=ActionType.HEALTH_CHECK, payload_strategy=PayloadStrategy.VALID)
|
| 112 |
-
obs = StateStrikeObservation(
|
| 113 |
-
step=0,
|
| 114 |
-
action_taken=action,
|
| 115 |
-
http_status=status,
|
| 116 |
-
latency_ms=latency_ms,
|
| 117 |
-
reward=0.0,
|
| 118 |
-
cumulative_reward=0.0,
|
| 119 |
-
baseline_latency_ms=session.baseline_latency,
|
| 120 |
-
order_count=0,
|
| 121 |
-
triggered_vulns=[],
|
| 122 |
-
done=False,
|
| 123 |
-
info={"event": "reset"},
|
| 124 |
-
)
|
| 125 |
-
return obs
|
| 126 |
-
|
| 127 |
-
async def step(self, session: StateStrikeSession, action: StateStrikeAction) -> StateStrikeObservation:
|
| 128 |
-
"""Execute one environment transition.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
session: Session object tied to one client connection.
|
| 132 |
-
action: Agent action.
|
| 133 |
-
|
| 134 |
-
Returns:
|
| 135 |
-
Updated observation with reward and terminal signal.
|
| 136 |
-
"""
|
| 137 |
-
|
| 138 |
-
request_method, request_path, params, payload = self._translate_action(action, session)
|
| 139 |
-
status, latency_ms, body = await self._request_honeypot(request_method, request_path, params=params, payload=payload)
|
| 140 |
-
|
| 141 |
-
session.step_count += 1
|
| 142 |
-
if action.action_type == ActionType.POST_ORDER:
|
| 143 |
-
session.order_count += 1
|
| 144 |
-
session.append_action(action)
|
| 145 |
-
|
| 146 |
-
provisional = StateStrikeObservation(
|
| 147 |
-
step=session.step_count,
|
| 148 |
-
action_taken=action,
|
| 149 |
-
http_status=status,
|
| 150 |
-
latency_ms=latency_ms,
|
| 151 |
-
reward=0.0,
|
| 152 |
-
cumulative_reward=session.cumulative_reward,
|
| 153 |
-
baseline_latency_ms=session.baseline_latency,
|
| 154 |
-
order_count=session.order_count,
|
| 155 |
-
triggered_vulns=sorted(session.triggered_vulns),
|
| 156 |
-
done=False,
|
| 157 |
-
info={"response": body},
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
reward, breakdown = compute_reward(provisional, session, self.constants)
|
| 161 |
-
session.cumulative_reward += reward
|
| 162 |
-
|
| 163 |
-
done = (
|
| 164 |
-
session.step_count >= EPISODE_LENGTH
|
| 165 |
-
or session.cumulative_reward < self.constants.EARLY_TERMINATION_REWARD
|
| 166 |
-
)
|
| 167 |
-
obs = StateStrikeObservation(
|
| 168 |
-
step=session.step_count,
|
| 169 |
-
action_taken=action,
|
| 170 |
-
http_status=status,
|
| 171 |
-
latency_ms=latency_ms,
|
| 172 |
-
reward=reward,
|
| 173 |
-
cumulative_reward=session.cumulative_reward,
|
| 174 |
-
baseline_latency_ms=session.baseline_latency,
|
| 175 |
-
order_count=session.order_count,
|
| 176 |
-
triggered_vulns=sorted(session.triggered_vulns),
|
| 177 |
-
done=done,
|
| 178 |
-
info={"reward_breakdown": breakdown, "response": body},
|
| 179 |
-
)
|
| 180 |
-
return obs
|
| 181 |
-
|
| 182 |
-
async def state(self, session: StateStrikeSession) -> StateStrikeState:
|
| 183 |
-
"""Return serializable state snapshot.
|
| 184 |
-
|
| 185 |
-
Args:
|
| 186 |
-
session: Session object tied to one client connection.
|
| 187 |
-
|
| 188 |
-
Returns:
|
| 189 |
-
Current state model.
|
| 190 |
-
"""
|
| 191 |
-
|
| 192 |
-
return session.as_state()
|
| 193 |
-
|
| 194 |
-
def _translate_action(
|
| 195 |
-
self,
|
| 196 |
-
action: StateStrikeAction,
|
| 197 |
-
session: StateStrikeSession,
|
| 198 |
-
) -> tuple[str, str, dict[str, Any] | None, dict[str, Any] | None]:
|
| 199 |
-
"""Translate action schema into honeypot HTTP request details.
|
| 200 |
-
|
| 201 |
-
Args:
|
| 202 |
-
action: Agent action.
|
| 203 |
-
session: Session used for contextual defaults.
|
| 204 |
-
|
| 205 |
-
Returns:
|
| 206 |
-
Tuple of method, path, query params, and JSON payload.
|
| 207 |
-
"""
|
| 208 |
-
|
| 209 |
-
target_user_id = action.target_user_id or 1
|
| 210 |
-
|
| 211 |
-
if action.action_type == ActionType.POST_USER:
|
| 212 |
-
email = self._payload_email(action.payload_strategy)
|
| 213 |
-
return "POST", "/users", None, {"email": email}
|
| 214 |
-
if action.action_type == ActionType.GET_USER:
|
| 215 |
-
return "GET", f"/users/{target_user_id}", None, None
|
| 216 |
-
if action.action_type == ActionType.POST_ORDER:
|
| 217 |
-
item = self._payload_item(action.payload_strategy)
|
| 218 |
-
return "POST", "/orders", None, {"user_id": target_user_id, "item": item}
|
| 219 |
-
if action.action_type == ActionType.GET_ORDERS:
|
| 220 |
-
return "GET", "/orders", {"user_id": target_user_id}, None
|
| 221 |
-
return "GET", "/health", None, None
|
| 222 |
-
|
| 223 |
-
@staticmethod
|
| 224 |
-
def _payload_email(strategy: PayloadStrategy) -> str:
|
| 225 |
-
"""Build email-like payload for POST /users action.
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
strategy: Payload strategy enum.
|
| 229 |
-
|
| 230 |
-
Returns:
|
| 231 |
-
Strategy-specific string payload.
|
| 232 |
-
"""
|
| 233 |
-
|
| 234 |
-
if strategy == PayloadStrategy.REDOS_ATTACK:
|
| 235 |
-
return "a" * 39 + "!"
|
| 236 |
-
if strategy == PayloadStrategy.OVERSIZED:
|
| 237 |
-
return "A" * 4096
|
| 238 |
-
if strategy == PayloadStrategy.MALFORMED:
|
| 239 |
-
return "@@@"
|
| 240 |
-
return "validuser123"
|
| 241 |
-
|
| 242 |
-
@staticmethod
|
| 243 |
-
def _payload_item(strategy: PayloadStrategy) -> str:
|
| 244 |
-
"""Build order item payload.
|
| 245 |
-
|
| 246 |
-
Args:
|
| 247 |
-
strategy: Payload strategy enum.
|
| 248 |
-
|
| 249 |
-
Returns:
|
| 250 |
-
Strategy-specific order item string.
|
| 251 |
-
"""
|
| 252 |
-
|
| 253 |
-
if strategy == PayloadStrategy.OVERSIZED:
|
| 254 |
-
return "item_" + ("X" * 2048)
|
| 255 |
-
if strategy == PayloadStrategy.MALFORMED:
|
| 256 |
-
return ""
|
| 257 |
-
return "standard_item"
|
| 258 |
-
|
| 259 |
-
async def _request_honeypot(
|
| 260 |
-
self,
|
| 261 |
-
method: str,
|
| 262 |
-
path: str,
|
| 263 |
-
*,
|
| 264 |
-
params: dict[str, Any] | None = None,
|
| 265 |
-
payload: dict[str, Any] | None = None,
|
| 266 |
-
) -> tuple[int, float, dict[str, Any]]:
|
| 267 |
-
"""Execute honeypot request and normalize response metadata.
|
| 268 |
-
|
| 269 |
-
Args:
|
| 270 |
-
method: HTTP method.
|
| 271 |
-
path: Relative path.
|
| 272 |
-
params: Optional query parameters.
|
| 273 |
-
payload: Optional JSON body.
|
| 274 |
-
|
| 275 |
-
Returns:
|
| 276 |
-
Tuple of status code, latency milliseconds, and parsed response body.
|
| 277 |
-
"""
|
| 278 |
-
|
| 279 |
-
url = f"{self.honeypot_url}{path}"
|
| 280 |
-
started = time.perf_counter()
|
| 281 |
-
try:
|
| 282 |
-
async with httpx.AsyncClient(timeout=ACTION_TIMEOUT_SECONDS) as client:
|
| 283 |
-
response = await client.request(method, url, params=params, json=payload)
|
| 284 |
-
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
| 285 |
-
header_latency = response.headers.get("X-Process-Time-Ms")
|
| 286 |
-
latency_ms = float(header_latency) if header_latency else elapsed_ms
|
| 287 |
-
body = response.json() if response.content else {}
|
| 288 |
-
return response.status_code, latency_ms, body
|
| 289 |
-
except (httpx.RequestError, ValueError) as exc:
|
| 290 |
-
LOGGER.warning("Honeypot request failed method=%s path=%s error=%s", method, path, exc)
|
| 291 |
-
return 0, 0.0, {"error": str(exc), "synthetic": True}
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
@asynccontextmanager
|
| 295 |
-
async def lifespan(_: FastAPI):
|
| 296 |
-
"""Block API startup until honeypot health endpoint is reachable."""
|
| 297 |
-
|
| 298 |
-
await wait_for_honeypot(HONEYPOT_URL, max_wait=30)
|
| 299 |
-
yield
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
app = FastAPI(title="StateStrike OpenEnv Server", version="1.0.0", lifespan=lifespan)
|
| 303 |
-
env_service = StateStrikeEnvironment(HONEYPOT_URL)
|
| 304 |
-
http_debug_session = StateStrikeSession.new_session()
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
# OpenEnv uses WebSocket (/ws) for persistent sessions rather than
|
| 308 |
-
# stateless HTTP. Each step() is a lightweight frame over an existing
|
| 309 |
-
# connection (~0.1ms overhead vs ~10-50ms TCP handshake per HTTP call).
|
| 310 |
-
# Reference: openenv-course module-5, burtenshaw/openenv-scaling
|
| 311 |
-
# This architecture enables high-frequency RL training loops.
|
| 312 |
-
@app.websocket("/ws")
|
| 313 |
-
async def websocket_env(websocket: WebSocket) -> None:
|
| 314 |
-
"""Run one isolated environment loop per WebSocket client.
|
| 315 |
-
|
| 316 |
-
Args:
|
| 317 |
-
websocket: Connected client transport.
|
| 318 |
-
"""
|
| 319 |
-
|
| 320 |
-
await websocket.accept()
|
| 321 |
-
session = StateStrikeSession.new_session()
|
| 322 |
-
LOGGER.info("WebSocket session started session_id=%s", session.session_id)
|
| 323 |
-
|
| 324 |
-
try:
|
| 325 |
-
while True:
|
| 326 |
-
frame = await websocket.receive_text()
|
| 327 |
-
request = json.loads(frame)
|
| 328 |
-
method = request.get("method")
|
| 329 |
-
|
| 330 |
-
if method == "reset":
|
| 331 |
-
obs = await env_service.reset(session)
|
| 332 |
-
await websocket.send_json({"ok": True, "observation": obs.model_dump()})
|
| 333 |
-
continue
|
| 334 |
-
|
| 335 |
-
if method == "step":
|
| 336 |
-
action_payload = request.get("action", {})
|
| 337 |
-
action = StateStrikeAction.model_validate(action_payload)
|
| 338 |
-
obs = await env_service.step(session, action)
|
| 339 |
-
await websocket.send_json({"ok": True, "observation": obs.model_dump()})
|
| 340 |
-
continue
|
| 341 |
-
|
| 342 |
-
if method == "state":
|
| 343 |
-
state = await env_service.state(session)
|
| 344 |
-
await websocket.send_json({"ok": True, "state": state.model_dump()})
|
| 345 |
-
continue
|
| 346 |
-
|
| 347 |
-
await websocket.send_json({"ok": False, "error": f"Unknown method: {method}"})
|
| 348 |
-
except (WebSocketDisconnect, json.JSONDecodeError):
|
| 349 |
-
LOGGER.info("WebSocket session ended session_id=%s", session.session_id)
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
@app.get("/reset")
|
| 353 |
-
async def reset_http() -> JSONResponse:
|
| 354 |
-
"""HTTP debug endpoint for reset semantics.
|
| 355 |
-
|
| 356 |
-
Returns:
|
| 357 |
-
JSON response containing reset observation.
|
| 358 |
-
"""
|
| 359 |
-
|
| 360 |
-
obs = await env_service.reset(http_debug_session)
|
| 361 |
-
return JSONResponse(obs.model_dump())
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
@app.post("/step")
|
| 365 |
-
async def step_http(action: StateStrikeAction) -> JSONResponse:
|
| 366 |
-
"""HTTP debug endpoint for step semantics.
|
| 367 |
-
|
| 368 |
-
Args:
|
| 369 |
-
action: Action payload.
|
| 370 |
-
|
| 371 |
-
Returns:
|
| 372 |
-
JSON response containing post-step observation.
|
| 373 |
-
"""
|
| 374 |
-
|
| 375 |
-
obs = await env_service.step(http_debug_session, action)
|
| 376 |
-
return JSONResponse(obs.model_dump())
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
@app.get("/state")
|
| 380 |
-
async def state_http() -> JSONResponse:
|
| 381 |
-
"""HTTP debug endpoint for state semantics.
|
| 382 |
-
|
| 383 |
-
Returns:
|
| 384 |
-
JSON response containing current session state.
|
| 385 |
-
"""
|
| 386 |
-
|
| 387 |
-
state = await env_service.state(http_debug_session)
|
| 388 |
-
return JSONResponse(state.model_dump())
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
def main() -> None:
|
| 392 |
-
"""Entrypoint for running environment server via python -m."""
|
| 393 |
-
|
| 394 |
-
import uvicorn
|
| 395 |
-
|
| 396 |
-
uvicorn.run("statestrike_env.server:app", host=HOST, port=PORT, reload=False)
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
if __name__ == "__main__":
|
| 400 |
-
main()
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from statestrike_env.environment import StateStrikeEnv, app, main
|
| 4 |
|
| 5 |
+
StateStrikeEnvironment = StateStrikeEnv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
__all__ = ["StateStrikeEnvironment", "StateStrikeEnv", "app", "main"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
statestrike_env/session.py
CHANGED
|
@@ -1,122 +1,86 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
"""Session state manager for per-agent environment isolation."""
|
| 4 |
-
|
| 5 |
from dataclasses import dataclass, field
|
|
|
|
| 6 |
from uuid import uuid4
|
| 7 |
|
| 8 |
-
from statestrike_env.constants import DEFAULT_BASELINE_LATENCY_MS
|
| 9 |
-
from statestrike_env.models import
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class StateStrikeSession:
|
| 14 |
-
"""Mutable per-WebSocket environment session.
|
| 15 |
-
|
| 16 |
-
Attributes:
|
| 17 |
-
session_id: Current episode UUID.
|
| 18 |
-
step_count: Number of steps taken in current episode.
|
| 19 |
-
cumulative_reward: Running reward total.
|
| 20 |
-
order_count: Number of POST /orders actions issued.
|
| 21 |
-
baseline_latency: Rolling average latency used in reward normalization.
|
| 22 |
-
action_history: Most recent action history window.
|
| 23 |
-
triggered_vulns: Vulnerabilities discovered in current episode.
|
| 24 |
-
redos_bounty_awarded: One-time ReDoS bounty guard.
|
| 25 |
-
db_degradation_bounty_awarded: One-time DB degradation bounty guard.
|
| 26 |
-
last_chain_bonus_step: Last step where chain bonus was awarded.
|
| 27 |
-
post_count_at_last_chain: Order count snapshot at last chain award.
|
| 28 |
-
baseline_sample_count: Number of successful baseline samples seen.
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
session_id: str
|
|
|
|
| 32 |
step_count: int = 0
|
| 33 |
cumulative_reward: float = 0.0
|
| 34 |
order_count: int = 0
|
| 35 |
baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
redos_bounty_awarded: bool = False
|
| 40 |
db_degradation_bounty_awarded: bool = False
|
| 41 |
-
# Anti-hacking: chain bonus can only fire once between meaningful progress windows.
|
| 42 |
last_chain_bonus_step: int = -10
|
| 43 |
post_count_at_last_chain: int = 0
|
| 44 |
-
# Baseline integrity: updated only on successful (non-zero latency) steps.
|
| 45 |
baseline_sample_count: int = 0
|
| 46 |
|
| 47 |
@classmethod
|
| 48 |
-
def new_session(cls) -> StateStrikeSession:
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def reset(self, baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS) -> None:
|
| 58 |
-
"""Reset session in-place for a new episode.
|
| 59 |
-
|
| 60 |
-
Args:
|
| 61 |
-
baseline_latency: Fresh baseline latency in milliseconds.
|
| 62 |
-
"""
|
| 63 |
-
|
| 64 |
self.session_id = str(uuid4())
|
|
|
|
| 65 |
self.step_count = 0
|
| 66 |
self.cumulative_reward = 0.0
|
| 67 |
self.order_count = 0
|
| 68 |
self.baseline_latency = baseline_latency
|
| 69 |
-
|
| 70 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
self.redos_bounty_awarded = False
|
| 72 |
self.db_degradation_bounty_awarded = False
|
| 73 |
self.last_chain_bonus_step = -10
|
| 74 |
self.post_count_at_last_chain = 0
|
| 75 |
self.baseline_sample_count = 1 if baseline_latency > 0 else 0
|
| 76 |
|
| 77 |
-
def record_latency(self, latency_ms: float) -> float:
|
| 78 |
-
"""Update baseline latency using EMA from successful samples.
|
| 79 |
-
|
| 80 |
-
Args:
|
| 81 |
-
latency_ms: Observed latency for the current step.
|
| 82 |
-
|
| 83 |
-
Returns:
|
| 84 |
-
Updated baseline latency.
|
| 85 |
-
"""
|
| 86 |
-
|
| 87 |
-
sample = max(latency_ms, 1.0)
|
| 88 |
-
alpha_ema = 2.0 / (10 + 1)
|
| 89 |
-
if self.baseline_sample_count == 0:
|
| 90 |
-
self.baseline_latency = sample
|
| 91 |
-
else:
|
| 92 |
-
self.baseline_latency = alpha_ema * sample + (1 - alpha_ema) * self.baseline_latency
|
| 93 |
-
self.baseline_sample_count += 1
|
| 94 |
-
return self.baseline_latency
|
| 95 |
-
|
| 96 |
-
def append_action(self, action: StateStrikeAction) -> None:
|
| 97 |
-
"""Append action while enforcing history length constraints.
|
| 98 |
-
|
| 99 |
-
Args:
|
| 100 |
-
action: Action to append.
|
| 101 |
-
"""
|
| 102 |
-
|
| 103 |
-
self.action_history.append(action)
|
| 104 |
-
if len(self.action_history) > MAX_ACTION_HISTORY:
|
| 105 |
-
self.action_history.pop(0)
|
| 106 |
-
|
| 107 |
def as_state(self) -> StateStrikeState:
|
| 108 |
-
"""Convert mutable session internals to external state model.
|
| 109 |
-
|
| 110 |
-
Returns:
|
| 111 |
-
Immutable API-safe state representation.
|
| 112 |
-
"""
|
| 113 |
-
|
| 114 |
return StateStrikeState(
|
| 115 |
session_id=self.session_id,
|
|
|
|
| 116 |
step_count=self.step_count,
|
| 117 |
cumulative_reward=self.cumulative_reward,
|
| 118 |
order_count=self.order_count,
|
| 119 |
baseline_latency_ms=self.baseline_latency,
|
| 120 |
-
|
| 121 |
-
|
|
|
|
| 122 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any
|
| 5 |
from uuid import uuid4
|
| 6 |
|
| 7 |
+
from statestrike_env.constants import DEFAULT_BASELINE_LATENCY_MS
|
| 8 |
+
from statestrike_env.models import StateStrikeState
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class StateStrikeSession:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
session_id: str
|
| 14 |
+
task_name: str = "endpoint_discovery"
|
| 15 |
step_count: int = 0
|
| 16 |
cumulative_reward: float = 0.0
|
| 17 |
order_count: int = 0
|
| 18 |
baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS
|
| 19 |
+
|
| 20 |
+
endpoints_discovered: set[str] = field(default_factory=set)
|
| 21 |
+
vulnerabilities_found: set[str] = field(default_factory=set)
|
| 22 |
+
task_specific_state: dict[str, Any] = field(default_factory=dict)
|
| 23 |
+
steps_history: list[dict[str, Any]] = field(default_factory=list)
|
| 24 |
+
|
| 25 |
+
user_created: bool = False
|
| 26 |
+
previous_task_score: float = 0.0
|
| 27 |
+
last_action_signature: str | None = None
|
| 28 |
+
|
| 29 |
redos_bounty_awarded: bool = False
|
| 30 |
db_degradation_bounty_awarded: bool = False
|
|
|
|
| 31 |
last_chain_bonus_step: int = -10
|
| 32 |
post_count_at_last_chain: int = 0
|
|
|
|
| 33 |
baseline_sample_count: int = 0
|
| 34 |
|
| 35 |
@classmethod
|
| 36 |
+
def new_session(cls, task_name: str = "endpoint_discovery") -> StateStrikeSession:
|
| 37 |
+
return cls(session_id=str(uuid4()), task_name=task_name)
|
| 38 |
+
|
| 39 |
+
def reset(
|
| 40 |
+
self,
|
| 41 |
+
task_name: str,
|
| 42 |
+
baseline_latency: float = DEFAULT_BASELINE_LATENCY_MS,
|
| 43 |
+
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
self.session_id = str(uuid4())
|
| 45 |
+
self.task_name = task_name
|
| 46 |
self.step_count = 0
|
| 47 |
self.cumulative_reward = 0.0
|
| 48 |
self.order_count = 0
|
| 49 |
self.baseline_latency = baseline_latency
|
| 50 |
+
|
| 51 |
+
self.endpoints_discovered.clear()
|
| 52 |
+
self.vulnerabilities_found.clear()
|
| 53 |
+
self.task_specific_state = {}
|
| 54 |
+
self.steps_history.clear()
|
| 55 |
+
|
| 56 |
+
self.user_created = False
|
| 57 |
+
self.previous_task_score = 0.0
|
| 58 |
+
self.last_action_signature = None
|
| 59 |
+
|
| 60 |
self.redos_bounty_awarded = False
|
| 61 |
self.db_degradation_bounty_awarded = False
|
| 62 |
self.last_chain_bonus_step = -10
|
| 63 |
self.post_count_at_last_chain = 0
|
| 64 |
self.baseline_sample_count = 1 if baseline_latency > 0 else 0
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def as_state(self) -> StateStrikeState:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
return StateStrikeState(
|
| 68 |
session_id=self.session_id,
|
| 69 |
+
task_name=self.task_name,
|
| 70 |
step_count=self.step_count,
|
| 71 |
cumulative_reward=self.cumulative_reward,
|
| 72 |
order_count=self.order_count,
|
| 73 |
baseline_latency_ms=self.baseline_latency,
|
| 74 |
+
endpoints_discovered=sorted(self.endpoints_discovered),
|
| 75 |
+
vulnerabilities_found=sorted(self.vulnerabilities_found),
|
| 76 |
+
task_specific_state=dict(self.task_specific_state),
|
| 77 |
)
|
| 78 |
+
|
| 79 |
+
def as_grader_state(self) -> dict[str, Any]:
|
| 80 |
+
return {
|
| 81 |
+
"endpoints_discovered": sorted(self.endpoints_discovered),
|
| 82 |
+
"vulnerabilities_found": sorted(self.vulnerabilities_found),
|
| 83 |
+
"steps_history": list(self.steps_history),
|
| 84 |
+
"order_count": self.order_count,
|
| 85 |
+
"user_created": self.user_created,
|
| 86 |
+
}
|
statestrike_env/tasks.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class TaskConfig:
|
| 9 |
+
task_id: str
|
| 10 |
+
description: str
|
| 11 |
+
max_steps: int
|
| 12 |
+
difficulty: str
|
| 13 |
+
success_threshold: float
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
TASK_ENDPOINT_DISCOVERY = TaskConfig(
|
| 17 |
+
task_id="endpoint_discovery",
|
| 18 |
+
description="Discover all 5 reachable API endpoints by probing them.",
|
| 19 |
+
max_steps=20,
|
| 20 |
+
difficulty="easy",
|
| 21 |
+
success_threshold=0.6,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
ALL_ENDPOINTS = {
|
| 25 |
+
"POST /users",
|
| 26 |
+
"GET /users/{id}",
|
| 27 |
+
"POST /orders",
|
| 28 |
+
"GET /orders",
|
| 29 |
+
"GET /health",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class EndpointDiscoveryGrader:
|
| 34 |
+
def score(self, session_state: dict[str, Any]) -> float:
|
| 35 |
+
found = set(session_state.get("endpoints_discovered", []))
|
| 36 |
+
return round(len(found & ALL_ENDPOINTS) / len(ALL_ENDPOINTS), 4)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
TASK_VULNERABILITY_PROBE = TaskConfig(
|
| 40 |
+
task_id="vulnerability_probe",
|
| 41 |
+
description=(
|
| 42 |
+
"Identify and correctly classify at least one vulnerability. "
|
| 43 |
+
"Vulnerabilities: 'redos' (POST /users with attack payload, latency > 1500ms), "
|
| 44 |
+
"'db_degradation' (GET /orders after 20+ orders, latency > 800ms)."
|
| 45 |
+
),
|
| 46 |
+
max_steps=30,
|
| 47 |
+
difficulty="medium",
|
| 48 |
+
success_threshold=0.5,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class VulnerabilityProbeGrader:
|
| 53 |
+
REDOS_LATENCY_THRESHOLD_MS = 1500.0
|
| 54 |
+
DB_LATENCY_THRESHOLD_MS = 800.0
|
| 55 |
+
|
| 56 |
+
def score(self, session_state: dict[str, Any]) -> float:
|
| 57 |
+
vulns = set(session_state.get("vulnerabilities_found", []))
|
| 58 |
+
steps_history = session_state.get("steps_history", [])
|
| 59 |
+
|
| 60 |
+
score = 0.0
|
| 61 |
+
|
| 62 |
+
if any(s.get("endpoint") != "GET /health" for s in steps_history):
|
| 63 |
+
score = max(score, 0.2)
|
| 64 |
+
|
| 65 |
+
if any(s.get("payload_strategy") != "valid" for s in steps_history):
|
| 66 |
+
score = max(score, 0.4)
|
| 67 |
+
|
| 68 |
+
if "redos" in vulns:
|
| 69 |
+
score = max(score, 0.6)
|
| 70 |
+
|
| 71 |
+
if "db_degradation" in vulns:
|
| 72 |
+
score = max(score, 0.8)
|
| 73 |
+
|
| 74 |
+
if "redos" in vulns and "db_degradation" in vulns:
|
| 75 |
+
score = 1.0
|
| 76 |
+
|
| 77 |
+
return round(score, 4)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
TASK_EXPLOIT_CHAIN = TaskConfig(
|
| 81 |
+
task_id="exploit_chain",
|
| 82 |
+
description=(
|
| 83 |
+
"Execute the full stateful exploit chain in order: "
|
| 84 |
+
"1. Create a user (POST /users, valid payload). "
|
| 85 |
+
"2. Build order history (POST /orders x 21 for same user). "
|
| 86 |
+
"3. Trigger DB degradation (GET /orders with 21 orders, latency > 800ms). "
|
| 87 |
+
"4. Trigger ReDoS (POST /users, redos payload, latency > 1500ms). "
|
| 88 |
+
"All four sub-objectives must be completed for full score."
|
| 89 |
+
),
|
| 90 |
+
max_steps=60,
|
| 91 |
+
difficulty="hard",
|
| 92 |
+
success_threshold=0.75,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ExploitChainGrader:
|
| 97 |
+
def score(self, session_state: dict[str, Any]) -> float:
|
| 98 |
+
score = 0.0
|
| 99 |
+
|
| 100 |
+
if session_state.get("user_created", False):
|
| 101 |
+
score += 0.25
|
| 102 |
+
if session_state.get("order_count", 0) >= 21:
|
| 103 |
+
score += 0.25
|
| 104 |
+
if "db_degradation" in session_state.get("vulnerabilities_found", []):
|
| 105 |
+
score += 0.25
|
| 106 |
+
if "redos" in session_state.get("vulnerabilities_found", []):
|
| 107 |
+
score += 0.25
|
| 108 |
+
|
| 109 |
+
return round(score, 4)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
TASK_REGISTRY = {
|
| 113 |
+
"endpoint_discovery": (TASK_ENDPOINT_DISCOVERY, EndpointDiscoveryGrader()),
|
| 114 |
+
"vulnerability_probe": (TASK_VULNERABILITY_PROBE, VulnerabilityProbeGrader()),
|
| 115 |
+
"exploit_chain": (TASK_EXPLOIT_CHAIN, ExploitChainGrader()),
|
| 116 |
+
}
|