Overhaul env: per-episode state, ast.literal_eval probe parsing, sandboxed verifier with timeouts, 9 black-box functions, slim Dockerfile
Browse files- .gitignore +9 -0
- Dockerfile +12 -13
- README.md +45 -4
- opensleuth_env/__init__.py +28 -1
- opensleuth_env/black_box.py +249 -21
- opensleuth_env/env.py +195 -83
- opensleuth_env/models.py +65 -15
- opensleuth_env/verifier.py +220 -52
- requirements.txt +3 -3
- server.py +73 -24
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.DS_Store
|
| 4 |
+
.env
|
| 5 |
+
.venv/
|
| 6 |
+
.pytest_cache/
|
| 7 |
+
.cache/
|
| 8 |
+
verifier_log.txt
|
| 9 |
+
*.log
|
Dockerfile
CHANGED
|
@@ -1,19 +1,18 @@
|
|
| 1 |
-
|
| 2 |
-
FROM python:3.9-slim
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
COPY ./opensleuth_env /app/opensleuth_env
|
| 9 |
-
COPY ./server.py /app/
|
| 10 |
-
COPY ./requirements.txt /app/
|
| 11 |
|
| 12 |
-
|
| 13 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
#
|
| 19 |
-
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
|
|
|
| 2 |
|
| 3 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 4 |
+
PIP_NO_CACHE_DIR=1 \
|
| 5 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 6 |
|
| 7 |
+
WORKDIR /app
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
COPY requirements.txt /app/
|
| 10 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
|
| 12 |
+
COPY opensleuth_env /app/opensleuth_env
|
| 13 |
+
COPY server.py /app/
|
| 14 |
+
|
| 15 |
+
EXPOSE 7860
|
| 16 |
|
| 17 |
+
# HF Spaces require listening on $PORT (default 7860). uvicorn binds 0.0.0.0.
|
| 18 |
+
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,51 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: indigo
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: OpenSleuth Env
|
| 3 |
+
emoji: 🕵️
|
| 4 |
colorFrom: indigo
|
| 5 |
+
colorTo: pink
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
+
suggested_hardware: cpu-basic
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# OpenSleuth — Environment
|
| 13 |
+
|
| 14 |
+
FastAPI service that exposes an OpenEnv-style `/reset` + `/step` API for the
|
| 15 |
+
**Algorithmic Detective** task. An agent has to figure out an unknown Python
|
| 16 |
+
function by probing it, then submit Python source that replicates it.
|
| 17 |
+
|
| 18 |
+
## Endpoints
|
| 19 |
+
|
| 20 |
+
| Method | Path | Body | Notes |
|
| 21 |
+
|-------:|---------------|----------------------------------------|----------------------------------------|
|
| 22 |
+
| GET | `/health` | — | Liveness probe. |
|
| 23 |
+
| GET | `/functions` | — | Catalogue of available black-boxes. |
|
| 24 |
+
| POST | `/reset` | `{"target_name": "fibonacci", "seed": 0}` | Starts a new episode, returns initial obs + `episode_id`. |
|
| 25 |
+
| POST | `/step` | `{"episode_id": "...", "action": {...}}` | One agent action. |
|
| 26 |
+
| GET | `/state/{eid}`| — | Inspect the live state of an episode (debug). |
|
| 27 |
+
|
| 28 |
+
### Action shapes
|
| 29 |
+
|
| 30 |
+
```json
|
| 31 |
+
{"action_type": "probe", "input_repr": "5"} // input_repr is parsed via ast.literal_eval
|
| 32 |
+
{"action_type": "submit", "code": "def fibonacci(n):..."}
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### Reward
|
| 36 |
+
|
| 37 |
+
* **Probe:** `-1` step cost, plus `+2` per newly-seen output and `+5` per
|
| 38 |
+
newly-seen exception type, encouraging exploration of edge cases.
|
| 39 |
+
* **Submit (terminal):** `100 * matches/fuzz_count` minus a logarithmic
|
| 40 |
+
cyclomatic-complexity penalty. A perfect submission gets a `+50` bonus.
|
| 41 |
+
|
| 42 |
+
## Hardware
|
| 43 |
+
|
| 44 |
+
CPU-only — `cpu-basic` is plenty. Do **not** assign GPU to this Space.
|
| 45 |
+
|
| 46 |
+
## Running locally
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
pip install -r requirements.txt
|
| 50 |
+
uvicorn server:app --port 7860 --reload
|
| 51 |
+
```
|
opensleuth_env/__init__.py
CHANGED
|
@@ -1 +1,28 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenSleuth environment library."""
|
| 2 |
+
|
| 3 |
+
from .env import OpenSleuthEnv
|
| 4 |
+
from .models import (
|
| 5 |
+
Action,
|
| 6 |
+
ProbeAction,
|
| 7 |
+
SubmitAction,
|
| 8 |
+
Observation,
|
| 9 |
+
State,
|
| 10 |
+
StepResponse,
|
| 11 |
+
ResetRequest,
|
| 12 |
+
StepRequest,
|
| 13 |
+
)
|
| 14 |
+
from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"OpenSleuthEnv",
|
| 18 |
+
"Action",
|
| 19 |
+
"ProbeAction",
|
| 20 |
+
"SubmitAction",
|
| 21 |
+
"Observation",
|
| 22 |
+
"State",
|
| 23 |
+
"StepResponse",
|
| 24 |
+
"ResetRequest",
|
| 25 |
+
"StepRequest",
|
| 26 |
+
"BLACK_BOX_FUNCTIONS",
|
| 27 |
+
"FunctionSpec",
|
| 28 |
+
]
|
opensleuth_env/black_box.py
CHANGED
|
@@ -1,31 +1,259 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
a, b = 0, 1
|
| 12 |
for _ in range(n - 1):
|
| 13 |
a, b = b, a + b
|
| 14 |
-
return b
|
| 15 |
|
| 16 |
-
# --- Add more black-box functions for later stages ---
|
| 17 |
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
Reverses a string.
|
| 21 |
-
- Raises TypeError for non-string inputs.
|
| 22 |
-
"""
|
| 23 |
if not isinstance(s, str):
|
| 24 |
raise TypeError("Input must be a string.")
|
| 25 |
return s[::-1]
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
}
|
|
|
|
| 1 |
+
"""Catalogue of hidden 'black-box' Python functions the agent must reproduce.
|
| 2 |
+
|
| 3 |
+
Each entry pairs the reference implementation with a *typed input domain*
|
| 4 |
+
generator so the verifier can fuzz it, plus a public signature/docstring shown
|
| 5 |
+
to the agent in the prompt. The reference implementation itself is never
|
| 6 |
+
shown to the agent.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import random
|
| 12 |
+
import string
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Any, Callable, Dict, List
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ----- Reference implementations --------------------------------------------
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _fibonacci(n: int) -> int:
|
| 21 |
+
if not isinstance(n, int) or isinstance(n, bool) or n <= 0 or n > 90:
|
| 22 |
+
raise ValueError("Input must be a positive integer <= 90.")
|
| 23 |
a, b = 0, 1
|
| 24 |
for _ in range(n - 1):
|
| 25 |
a, b = b, a + b
|
| 26 |
+
return b if n > 0 else a
|
| 27 |
|
|
|
|
| 28 |
|
| 29 |
+
def _reverse_string(s: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
if not isinstance(s, str):
|
| 31 |
raise TypeError("Input must be a string.")
|
| 32 |
return s[::-1]
|
| 33 |
|
| 34 |
+
|
| 35 |
+
def _is_palindrome(s: str) -> bool:
|
| 36 |
+
if not isinstance(s, str):
|
| 37 |
+
raise TypeError("Input must be a string.")
|
| 38 |
+
cleaned = "".join(ch.lower() for ch in s if ch.isalnum())
|
| 39 |
+
return cleaned == cleaned[::-1]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _digit_sum(n: int) -> int:
|
| 43 |
+
if not isinstance(n, int) or isinstance(n, bool):
|
| 44 |
+
raise TypeError("Input must be int.")
|
| 45 |
+
if n < 0:
|
| 46 |
+
raise ValueError("Input must be non-negative.")
|
| 47 |
+
return sum(int(c) for c in str(n))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _count_vowels(s: str) -> int:
|
| 51 |
+
if not isinstance(s, str):
|
| 52 |
+
raise TypeError("Input must be a string.")
|
| 53 |
+
return sum(1 for c in s.lower() if c in "aeiou")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _gcd(pair) -> int:
|
| 57 |
+
"""Greatest common divisor of two non-negative ints, given as a 2-tuple
|
| 58 |
+
or 2-list. Hidden trick: tuple/list both accepted, ints only."""
|
| 59 |
+
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
|
| 60 |
+
raise TypeError("Input must be a 2-element list or tuple.")
|
| 61 |
+
a, b = pair
|
| 62 |
+
if not all(isinstance(x, int) and not isinstance(x, bool) for x in (a, b)):
|
| 63 |
+
raise TypeError("Both elements must be int.")
|
| 64 |
+
if a < 0 or b < 0:
|
| 65 |
+
raise ValueError("Both elements must be non-negative.")
|
| 66 |
+
while b:
|
| 67 |
+
a, b = b, a % b
|
| 68 |
+
return a
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _sort_unique(xs) -> list:
|
| 72 |
+
"""Return sorted unique elements from a list of ints."""
|
| 73 |
+
if not isinstance(xs, list):
|
| 74 |
+
raise TypeError("Input must be a list.")
|
| 75 |
+
if not all(isinstance(x, int) and not isinstance(x, bool) for x in xs):
|
| 76 |
+
raise TypeError("All elements must be int.")
|
| 77 |
+
return sorted(set(xs))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _caesar_cipher(s: str) -> str:
|
| 81 |
+
"""Caesar shift by +3 on lowercase letters; everything else unchanged."""
|
| 82 |
+
if not isinstance(s, str):
|
| 83 |
+
raise TypeError("Input must be a string.")
|
| 84 |
+
out = []
|
| 85 |
+
for ch in s:
|
| 86 |
+
if "a" <= ch <= "z":
|
| 87 |
+
out.append(chr((ord(ch) - ord("a") + 3) % 26 + ord("a")))
|
| 88 |
+
else:
|
| 89 |
+
out.append(ch)
|
| 90 |
+
return "".join(out)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _is_prime(n: int) -> bool:
|
| 94 |
+
if not isinstance(n, int) or isinstance(n, bool):
|
| 95 |
+
raise TypeError("Input must be int.")
|
| 96 |
+
if n < 2:
|
| 97 |
+
return False
|
| 98 |
+
if n < 4:
|
| 99 |
+
return True
|
| 100 |
+
if n % 2 == 0:
|
| 101 |
+
return False
|
| 102 |
+
i = 3
|
| 103 |
+
while i * i <= n:
|
| 104 |
+
if n % i == 0:
|
| 105 |
+
return False
|
| 106 |
+
i += 2
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ----- Fuzz input generators ------------------------------------------------
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _fuzz_small_pos_int(rng: random.Random, n: int) -> List[int]:
|
| 114 |
+
return [rng.randint(1, 30) for _ in range(n)]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _fuzz_fib_int(rng: random.Random, n: int) -> List[int]:
|
| 118 |
+
# Mix common values, edges, and random.
|
| 119 |
+
pool = [1, 2, 3, 10, 20, 30, 50, 89, 90]
|
| 120 |
+
return [rng.choice(pool) if rng.random() < 0.3 else rng.randint(1, 90) for _ in range(n)]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _fuzz_short_string(rng: random.Random, n: int) -> List[str]:
|
| 124 |
+
alpha = string.ascii_letters + string.digits
|
| 125 |
+
return ["".join(rng.choices(alpha, k=rng.randint(0, 12))) for _ in range(n)]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _fuzz_palindrome_string(rng: random.Random, n: int) -> List[str]:
|
| 129 |
+
out = []
|
| 130 |
+
for _ in range(n):
|
| 131 |
+
if rng.random() < 0.4:
|
| 132 |
+
base = "".join(rng.choices(string.ascii_lowercase, k=rng.randint(0, 6)))
|
| 133 |
+
out.append(base + base[::-1])
|
| 134 |
+
else:
|
| 135 |
+
out.append("".join(rng.choices(string.ascii_letters + " ", k=rng.randint(0, 12))))
|
| 136 |
+
return out
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _fuzz_nonneg_int(rng: random.Random, n: int) -> List[int]:
|
| 140 |
+
return [rng.randint(0, 10_000) for _ in range(n)]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _fuzz_int_pair(rng: random.Random, n: int):
|
| 144 |
+
return [(rng.randint(0, 1000), rng.randint(0, 1000)) for _ in range(n)]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _fuzz_int_list(rng: random.Random, n: int):
|
| 148 |
+
return [
|
| 149 |
+
[rng.randint(-50, 50) for _ in range(rng.randint(0, 8))] for _ in range(n)
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _fuzz_lower_string(rng: random.Random, n: int) -> List[str]:
|
| 154 |
+
return [
|
| 155 |
+
"".join(rng.choices(string.ascii_lowercase + " ,!", k=rng.randint(0, 16)))
|
| 156 |
+
for _ in range(n)
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _fuzz_prime_int(rng: random.Random, n: int) -> List[int]:
|
| 161 |
+
# Mix in known primes and composites to cover both branches.
|
| 162 |
+
seeded = [0, 1, 2, 3, 4, 9, 11, 15, 17, 25, 29, 97, 100]
|
| 163 |
+
return [rng.choice(seeded) if rng.random() < 0.3 else rng.randint(0, 200) for _ in range(n)]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ----- Spec ----------------------------------------------------------------
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@dataclass(frozen=True)
|
| 170 |
+
class FunctionSpec:
|
| 171 |
+
name: str
|
| 172 |
+
fn: Callable[[Any], Any]
|
| 173 |
+
signature: str
|
| 174 |
+
description: str
|
| 175 |
+
fuzzer: Callable[[random.Random, int], list]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
BLACK_BOX_FUNCTIONS: Dict[str, FunctionSpec] = {
|
| 179 |
+
spec.name: spec
|
| 180 |
+
for spec in [
|
| 181 |
+
FunctionSpec(
|
| 182 |
+
name="fibonacci",
|
| 183 |
+
fn=_fibonacci,
|
| 184 |
+
signature="fibonacci(n: int) -> int",
|
| 185 |
+
description=(
|
| 186 |
+
"Returns the n-th Fibonacci number. Raises ValueError for "
|
| 187 |
+
"invalid n (n must be a positive int <= 90)."
|
| 188 |
+
),
|
| 189 |
+
fuzzer=_fuzz_fib_int,
|
| 190 |
+
),
|
| 191 |
+
FunctionSpec(
|
| 192 |
+
name="reverse_string",
|
| 193 |
+
fn=_reverse_string,
|
| 194 |
+
signature="reverse_string(s: str) -> str",
|
| 195 |
+
description="Returns the reversed string. Raises TypeError for non-str.",
|
| 196 |
+
fuzzer=_fuzz_short_string,
|
| 197 |
+
),
|
| 198 |
+
FunctionSpec(
|
| 199 |
+
name="is_palindrome",
|
| 200 |
+
fn=_is_palindrome,
|
| 201 |
+
signature="is_palindrome(s: str) -> bool",
|
| 202 |
+
description=(
|
| 203 |
+
"Case-insensitive palindrome check, ignoring non-alphanumeric "
|
| 204 |
+
"characters. Raises TypeError for non-str."
|
| 205 |
+
),
|
| 206 |
+
fuzzer=_fuzz_palindrome_string,
|
| 207 |
+
),
|
| 208 |
+
FunctionSpec(
|
| 209 |
+
name="digit_sum",
|
| 210 |
+
fn=_digit_sum,
|
| 211 |
+
signature="digit_sum(n: int) -> int",
|
| 212 |
+
description=(
|
| 213 |
+
"Sum of the decimal digits of n. n must be a non-negative int."
|
| 214 |
+
),
|
| 215 |
+
fuzzer=_fuzz_nonneg_int,
|
| 216 |
+
),
|
| 217 |
+
FunctionSpec(
|
| 218 |
+
name="count_vowels",
|
| 219 |
+
fn=_count_vowels,
|
| 220 |
+
signature="count_vowels(s: str) -> int",
|
| 221 |
+
description="Count of vowels (a/e/i/o/u, case-insensitive) in s.",
|
| 222 |
+
fuzzer=_fuzz_lower_string,
|
| 223 |
+
),
|
| 224 |
+
FunctionSpec(
|
| 225 |
+
name="gcd",
|
| 226 |
+
fn=_gcd,
|
| 227 |
+
signature="gcd(pair: tuple[int, int] | list[int]) -> int",
|
| 228 |
+
description=(
|
| 229 |
+
"Greatest common divisor of a 2-element tuple/list of "
|
| 230 |
+
"non-negative ints."
|
| 231 |
+
),
|
| 232 |
+
fuzzer=_fuzz_int_pair,
|
| 233 |
+
),
|
| 234 |
+
FunctionSpec(
|
| 235 |
+
name="sort_unique",
|
| 236 |
+
fn=_sort_unique,
|
| 237 |
+
signature="sort_unique(xs: list[int]) -> list[int]",
|
| 238 |
+
description="Sorted, deduplicated list of ints from xs.",
|
| 239 |
+
fuzzer=_fuzz_int_list,
|
| 240 |
+
),
|
| 241 |
+
FunctionSpec(
|
| 242 |
+
name="caesar_cipher",
|
| 243 |
+
fn=_caesar_cipher,
|
| 244 |
+
signature="caesar_cipher(s: str) -> str",
|
| 245 |
+
description=(
|
| 246 |
+
"Caesar shift by +3 on lowercase letters; non-lowercase chars "
|
| 247 |
+
"are unchanged."
|
| 248 |
+
),
|
| 249 |
+
fuzzer=_fuzz_lower_string,
|
| 250 |
+
),
|
| 251 |
+
FunctionSpec(
|
| 252 |
+
name="is_prime",
|
| 253 |
+
fn=_is_prime,
|
| 254 |
+
signature="is_prime(n: int) -> bool",
|
| 255 |
+
description="True iff n is a prime int. n must be int.",
|
| 256 |
+
fuzzer=_fuzz_prime_int,
|
| 257 |
+
),
|
| 258 |
+
]
|
| 259 |
}
|
opensleuth_env/env.py
CHANGED
|
@@ -1,93 +1,205 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class OpenSleuthEnv:
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
Resets the environment to a new episode.
|
| 15 |
-
Selects a black-box function and clears the history.
|
| 16 |
-
"""
|
| 17 |
-
if target_name not in BLACK_BOX_FUNCTIONS:
|
| 18 |
-
raise ValueError(f"Unknown target function: {target_name}")
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
target_function_name=target_name,
|
| 22 |
-
|
| 23 |
-
seen_outputs=set(),
|
| 24 |
-
seen_error_types=set(),
|
| 25 |
)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
except Exception as e:
|
| 47 |
-
obs = Observation(probe_history=self.state.probe_history, last_error=f"Invalid action format: {e}")
|
| 48 |
-
return obs, -20.0, True
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if action.action_type == "probe":
|
| 52 |
-
return self._handle_probe(action)
|
| 53 |
-
elif action.action_type == "submit":
|
| 54 |
-
return self._handle_submit(action)
|
| 55 |
else:
|
| 56 |
-
obs =
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
try:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
error_type = type(e).__name__
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core OpenSleuth episodic environment.
|
| 2 |
+
|
| 3 |
+
A single OpenSleuthEnv holds a *registry of episodes* keyed by episode_id, so
|
| 4 |
+
multiple training rollouts can hit the same FastAPI server in parallel without
|
| 5 |
+
stepping on each other's state.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import ast
|
| 11 |
+
import logging
|
| 12 |
+
import uuid
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
|
| 16 |
+
from .models import (
|
| 17 |
+
Action,
|
| 18 |
+
Observation,
|
| 19 |
+
ProbeAction,
|
| 20 |
+
ProbeRecord,
|
| 21 |
+
State,
|
| 22 |
+
StepResponse,
|
| 23 |
+
SubmitAction,
|
| 24 |
+
)
|
| 25 |
+
from .verifier import generate_fuzz_inputs, verify_submission
|
| 26 |
+
|
| 27 |
+
log = logging.getLogger("opensleuth.env")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Reward shaping knobs (kept here so they're easy to tune).
|
| 31 |
+
PROBE_STEP_COST = -1.0
|
| 32 |
+
NEW_OUTPUT_BONUS = 2.0
|
| 33 |
+
NEW_ERROR_TYPE_BONUS = 5.0
|
| 34 |
+
PERFECT_SUBMISSION_BONUS = 50.0
|
| 35 |
+
MAX_PROBE_HISTORY_IN_OBS = 25
|
| 36 |
+
|
| 37 |
|
| 38 |
class OpenSleuthEnv:
|
| 39 |
+
"""Multi-episode environment registry."""
|
| 40 |
+
|
| 41 |
+
def __init__(self, fuzz_count: int = 100) -> None:
|
| 42 |
+
self._states: dict[str, State] = {}
|
| 43 |
+
self._configs: dict[str, dict] = {}
|
| 44 |
+
self.fuzz_count = fuzz_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
# --- Lifecycle ---------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Observation:
|
| 49 |
+
if target_name not in BLACK_BOX_FUNCTIONS:
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"Unknown target function: {target_name!r}. "
|
| 52 |
+
f"Available: {sorted(BLACK_BOX_FUNCTIONS)}"
|
| 53 |
+
)
|
| 54 |
+
spec = BLACK_BOX_FUNCTIONS[target_name]
|
| 55 |
+
episode_id = uuid.uuid4().hex
|
| 56 |
+
self._states[episode_id] = State(
|
| 57 |
+
episode_id=episode_id,
|
| 58 |
target_function_name=target_name,
|
| 59 |
+
seed=seed,
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
+
self._configs[episode_id] = {"max_steps": max_steps}
|
| 62 |
+
return self._build_observation(episode_id, spec, last_error="")
|
| 63 |
+
|
| 64 |
+
def step(self, episode_id: str, action: Action) -> StepResponse:
|
| 65 |
+
state = self._states.get(episode_id)
|
| 66 |
+
if state is None:
|
| 67 |
+
raise KeyError(f"Unknown episode_id {episode_id!r}. Did you /reset first?")
|
| 68 |
+
if state.done:
|
| 69 |
+
spec = BLACK_BOX_FUNCTIONS[state.target_function_name]
|
| 70 |
+
obs = self._build_observation(episode_id, spec, last_error="Episode already terminated.")
|
| 71 |
+
return StepResponse(observation=obs, reward=0.0, done=True, info={"reason": "already_done"})
|
| 72 |
+
|
| 73 |
+
spec = BLACK_BOX_FUNCTIONS[state.target_function_name]
|
| 74 |
+
state.steps_taken += 1
|
| 75 |
+
max_steps = self._configs[episode_id]["max_steps"]
|
| 76 |
+
|
| 77 |
+
if isinstance(action, ProbeAction):
|
| 78 |
+
obs, reward, done, info = self._handle_probe(state, spec, action)
|
| 79 |
+
elif isinstance(action, SubmitAction):
|
| 80 |
+
obs, reward, done, info = self._handle_submit(state, spec, action)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
+
obs = self._build_observation(
|
| 83 |
+
episode_id, spec, last_error=f"Invalid action type: {type(action).__name__}"
|
| 84 |
+
)
|
| 85 |
+
reward, done, info = -20.0, True, {"reason": "invalid_action"}
|
| 86 |
|
| 87 |
+
# Step-budget exhaustion ends the episode with no extra reward.
|
| 88 |
+
if not done and state.steps_taken >= max_steps:
|
| 89 |
+
done = True
|
| 90 |
+
info = {**info, "reason": info.get("reason", "step_limit")}
|
| 91 |
|
| 92 |
+
if done:
|
| 93 |
+
state.done = True
|
| 94 |
+
return StepResponse(observation=obs, reward=reward, done=done, info=info)
|
| 95 |
+
|
| 96 |
+
# --- Action handlers ---------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def _handle_probe(
|
| 99 |
+
self, state: State, spec: FunctionSpec, action: ProbeAction
|
| 100 |
+
) -> Tuple[Observation, float, bool, dict]:
|
| 101 |
+
# Parse the agent's input from a Python literal repr.
|
| 102 |
+
try:
|
| 103 |
+
parsed = ast.literal_eval(action.input_repr)
|
| 104 |
+
except (ValueError, SyntaxError) as e:
|
| 105 |
+
err = f"Could not parse input_repr as a Python literal: {e}"
|
| 106 |
+
state.probe_history.append(
|
| 107 |
+
ProbeRecord(
|
| 108 |
+
input_repr=action.input_repr,
|
| 109 |
+
output_repr=err,
|
| 110 |
+
is_error=True,
|
| 111 |
+
error_type="ParseError",
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
obs = self._build_observation(state.episode_id, spec, last_error=err)
|
| 115 |
+
return obs, PROBE_STEP_COST, False, {"reason": "parse_error"}
|
| 116 |
+
|
| 117 |
+
intrinsic = 0.0
|
| 118 |
+
last_error = ""
|
| 119 |
try:
|
| 120 |
+
output = spec.fn(parsed)
|
| 121 |
+
output_repr = repr(output)
|
| 122 |
+
state.probe_history.append(
|
| 123 |
+
ProbeRecord(input_repr=repr(parsed), output_repr=output_repr, is_error=False)
|
| 124 |
+
)
|
| 125 |
+
if output_repr not in state.seen_outputs:
|
| 126 |
+
intrinsic += NEW_OUTPUT_BONUS
|
| 127 |
+
state.seen_outputs.add(output_repr)
|
| 128 |
+
except Exception as e: # noqa: BLE001
|
| 129 |
error_type = type(e).__name__
|
| 130 |
+
err_repr = f"{error_type}: {e}"
|
| 131 |
+
state.probe_history.append(
|
| 132 |
+
ProbeRecord(
|
| 133 |
+
input_repr=repr(parsed),
|
| 134 |
+
output_repr=err_repr,
|
| 135 |
+
is_error=True,
|
| 136 |
+
error_type=error_type,
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
last_error = err_repr
|
| 140 |
+
if error_type not in state.seen_error_types:
|
| 141 |
+
intrinsic += NEW_ERROR_TYPE_BONUS
|
| 142 |
+
state.seen_error_types.add(error_type)
|
| 143 |
+
|
| 144 |
+
reward = intrinsic + PROBE_STEP_COST
|
| 145 |
+
obs = self._build_observation(state.episode_id, spec, last_error=last_error)
|
| 146 |
+
return obs, reward, False, {"intrinsic": intrinsic}
|
| 147 |
+
|
| 148 |
+
def _handle_submit(
|
| 149 |
+
self, state: State, spec: FunctionSpec, action: SubmitAction
|
| 150 |
+
) -> Tuple[Observation, float, bool, dict]:
|
| 151 |
+
fuzz_inputs = generate_fuzz_inputs(spec, count=self.fuzz_count, seed=state.seed)
|
| 152 |
+
result = verify_submission(action.code, spec.fn, fuzz_inputs, target_name=spec.name)
|
| 153 |
+
|
| 154 |
+
total = result.execution_reward - result.complexity_penalty
|
| 155 |
+
if result.execution_reward >= 99.999:
|
| 156 |
+
total += PERFECT_SUBMISSION_BONUS
|
| 157 |
+
|
| 158 |
+
obs = self._build_observation(
|
| 159 |
+
state.episode_id,
|
| 160 |
+
spec,
|
| 161 |
+
last_error=result.define_error or "",
|
| 162 |
+
)
|
| 163 |
+
info = {
|
| 164 |
+
"execution_reward": result.execution_reward,
|
| 165 |
+
"complexity_penalty": result.complexity_penalty,
|
| 166 |
+
"matches": result.matches,
|
| 167 |
+
"fuzz_count": result.fuzz_count,
|
| 168 |
+
"define_error": result.define_error,
|
| 169 |
+
"reason": "submission",
|
| 170 |
+
}
|
| 171 |
+
return obs, total, True, info
|
| 172 |
+
|
| 173 |
+
# --- Helpers -----------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def _build_observation(
|
| 176 |
+
self, episode_id: str, spec: FunctionSpec, last_error: str
|
| 177 |
+
) -> Observation:
|
| 178 |
+
state = self._states[episode_id]
|
| 179 |
+
max_steps = self._configs[episode_id]["max_steps"]
|
| 180 |
+
history = state.probe_history[-MAX_PROBE_HISTORY_IN_OBS:]
|
| 181 |
+
return Observation(
|
| 182 |
+
episode_id=episode_id,
|
| 183 |
+
target_function_name=state.target_function_name,
|
| 184 |
+
target_function_signature=f"{spec.signature}\n\n{spec.description}",
|
| 185 |
+
probe_history=history,
|
| 186 |
+
last_error=last_error,
|
| 187 |
+
steps_taken=state.steps_taken,
|
| 188 |
+
max_steps=max_steps,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# --- Introspection -----------------------------------------------------
|
| 192 |
+
|
| 193 |
+
def get_state(self, episode_id: str) -> dict:
|
| 194 |
+
s = self._states.get(episode_id)
|
| 195 |
+
if s is None:
|
| 196 |
+
return {}
|
| 197 |
+
return {
|
| 198 |
+
"episode_id": s.episode_id,
|
| 199 |
+
"target_function_name": s.target_function_name,
|
| 200 |
+
"steps_taken": s.steps_taken,
|
| 201 |
+
"done": s.done,
|
| 202 |
+
"seen_outputs": sorted(s.seen_outputs),
|
| 203 |
+
"seen_error_types": sorted(s.seen_error_types),
|
| 204 |
+
"probe_history": [r.model_dump() for r in s.probe_history],
|
| 205 |
+
}
|
opensleuth_env/models.py
CHANGED
|
@@ -1,29 +1,79 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class ProbeAction(BaseModel):
|
| 5 |
action_type: Literal["probe"] = "probe"
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class SubmitAction(BaseModel):
|
| 9 |
action_type: Literal["submit"] = "submit"
|
| 10 |
-
code: str
|
|
|
|
| 11 |
|
| 12 |
Action = Union[ProbeAction, SubmitAction]
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class Observation(BaseModel):
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
last_error: str = Field(
|
| 20 |
-
"",
|
| 21 |
-
description="The error message from the last action, if any."
|
| 22 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class State(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
target_function_name: str
|
| 26 |
-
probe_history: List[
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for the OpenSleuth API and core state."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, List, Literal, Optional, Tuple, Union
|
| 6 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 7 |
+
|
| 8 |
|
| 9 |
class ProbeAction(BaseModel):
|
| 10 |
action_type: Literal["probe"] = "probe"
|
| 11 |
+
# The agent submits inputs as a Python literal string (e.g. "5", "'abc'",
|
| 12 |
+
# "[1, 2, 3]"). We parse it server-side with ast.literal_eval. Keeping it
|
| 13 |
+
# as a string avoids a class of FastAPI auto-coercion bugs and matches
|
| 14 |
+
# what an LLM naturally emits.
|
| 15 |
+
input_repr: str = Field(..., description="Python literal repr of the probe input")
|
| 16 |
+
|
| 17 |
|
| 18 |
class SubmitAction(BaseModel):
|
| 19 |
action_type: Literal["submit"] = "submit"
|
| 20 |
+
code: str = Field(..., description="Python source defining the target function")
|
| 21 |
+
|
| 22 |
|
| 23 |
Action = Union[ProbeAction, SubmitAction]
|
| 24 |
|
| 25 |
+
|
| 26 |
+
class ProbeRecord(BaseModel):
|
| 27 |
+
"""One entry in the probe history. Output is either the function's return
|
| 28 |
+
value (Pythonic repr) or, if it raised, an error string."""
|
| 29 |
+
|
| 30 |
+
input_repr: str
|
| 31 |
+
output_repr: str
|
| 32 |
+
is_error: bool = False
|
| 33 |
+
error_type: Optional[str] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
class Observation(BaseModel):
|
| 37 |
+
episode_id: str
|
| 38 |
+
target_function_name: str
|
| 39 |
+
target_function_signature: str = Field(
|
| 40 |
+
"", description="Human readable signature + docstring shown to the agent"
|
|
|
|
|
|
|
|
|
|
| 41 |
)
|
| 42 |
+
probe_history: List[ProbeRecord] = Field(default_factory=list)
|
| 43 |
+
last_error: str = ""
|
| 44 |
+
steps_taken: int = 0
|
| 45 |
+
max_steps: int = 25
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class StepResponse(BaseModel):
|
| 49 |
+
observation: Observation
|
| 50 |
+
reward: float
|
| 51 |
+
done: bool
|
| 52 |
+
info: dict = Field(default_factory=dict)
|
| 53 |
+
|
| 54 |
|
| 55 |
class State(BaseModel):
|
| 56 |
+
"""Internal mutable state for one episode. Not exposed in /step responses
|
| 57 |
+
in full, but available via /state/{eid} for debugging."""
|
| 58 |
+
|
| 59 |
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 60 |
+
|
| 61 |
+
episode_id: str
|
| 62 |
target_function_name: str
|
| 63 |
+
probe_history: List[ProbeRecord] = Field(default_factory=list)
|
| 64 |
+
seen_outputs: set = Field(default_factory=set)
|
| 65 |
+
seen_error_types: set = Field(default_factory=set)
|
| 66 |
+
steps_taken: int = 0
|
| 67 |
+
done: bool = False
|
| 68 |
+
seed: int = 0
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ResetRequest(BaseModel):
|
| 72 |
+
target_name: str = "fibonacci"
|
| 73 |
+
seed: int = 0
|
| 74 |
+
max_steps: int = 25
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class StepRequest(BaseModel):
|
| 78 |
+
episode_id: str
|
| 79 |
+
action: Action
|
opensleuth_env/verifier.py
CHANGED
|
@@ -1,68 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import ast
|
| 2 |
-
import random
|
| 3 |
-
import string
|
| 4 |
import math
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
| 7 |
def __init__(self):
|
| 8 |
-
self.
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
self.
|
| 12 |
-
def visit_For(self, node):
|
| 13 |
-
self.complexity += 1
|
| 14 |
-
self.generic_visit(node)
|
| 15 |
-
def visit_While(self, node):
|
| 16 |
-
self.complexity += 1
|
| 17 |
-
self.generic_visit(node)
|
| 18 |
-
def visit_And(self, node):
|
| 19 |
-
self.complexity += 1
|
| 20 |
-
self.generic_visit(node)
|
| 21 |
-
def visit_Or(self, node):
|
| 22 |
-
self.complexity += 1
|
| 23 |
self.generic_visit(node)
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
self.generic_visit(node)
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
try:
|
| 30 |
tree = ast.parse(code)
|
| 31 |
-
visitor = ComplexityVisitor()
|
| 32 |
-
visitor.visit(tree)
|
| 33 |
-
return math.log(visitor.complexity)
|
| 34 |
except SyntaxError:
|
| 35 |
-
return 50
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
def _generate_fuzz_inputs(target_func, count=100):
|
| 38 |
-
inputs = []
|
| 39 |
-
if target_func.__name__ == "fibonacci":
|
| 40 |
-
inputs = [random.randint(1, 90) for _ in range(count)]
|
| 41 |
-
elif target_func.__name__ == "reverse_string":
|
| 42 |
-
inputs = [''.join(random.choices(string.ascii_letters + string.digits, k=random.randint(1, 20))) for _ in range(count)]
|
| 43 |
-
return inputs
|
| 44 |
|
| 45 |
-
def
|
|
|
|
| 46 |
try:
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
matches = 0
|
| 57 |
for inp in fuzz_inputs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
try:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Verifier: scores a submitted Python source against a hidden reference
|
| 2 |
+
function by domain-aware fuzzing, with sandboxed execution and a complexity
|
| 3 |
+
penalty.
|
| 4 |
+
|
| 5 |
+
Reward design:
|
| 6 |
+
execution_reward in [0, 100] = 100 * matches/fuzz_count
|
| 7 |
+
complexity_penalty in [0, 50] = log(cyclomatic) clipped, else 50 on syntax error
|
| 8 |
+
exec_failure_penalty = 25 if def-time exec raised, before fuzzing
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
import ast
|
|
|
|
|
|
|
| 14 |
import math
|
| 15 |
+
import multiprocessing as mp
|
| 16 |
+
import random
|
| 17 |
+
import signal
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, List, Optional
|
| 20 |
+
|
| 21 |
|
| 22 |
+
# ----- AST complexity ------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class _CCVisitor(ast.NodeVisitor):
|
| 26 |
def __init__(self):
|
| 27 |
+
self.cc = 1
|
| 28 |
+
|
| 29 |
+
def _bump(self, node):
|
| 30 |
+
self.cc += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
self.generic_visit(node)
|
| 32 |
+
|
| 33 |
+
visit_If = _bump
|
| 34 |
+
visit_For = _bump
|
| 35 |
+
visit_While = _bump
|
| 36 |
+
visit_AsyncFor = _bump
|
| 37 |
+
visit_ExceptHandler = _bump
|
| 38 |
+
visit_With = _bump
|
| 39 |
+
visit_IfExp = _bump
|
| 40 |
+
|
| 41 |
+
def visit_BoolOp(self, node):
|
| 42 |
+
self.cc += max(0, len(node.values) - 1)
|
| 43 |
self.generic_visit(node)
|
| 44 |
|
| 45 |
+
|
| 46 |
+
def calculate_complexity_penalty(code: str) -> float:
|
| 47 |
+
"""Bounded log-scaled cyclomatic complexity, or 50 if code won't parse."""
|
| 48 |
try:
|
| 49 |
tree = ast.parse(code)
|
|
|
|
|
|
|
|
|
|
| 50 |
except SyntaxError:
|
| 51 |
+
return 50.0
|
| 52 |
+
v = _CCVisitor()
|
| 53 |
+
v.visit(tree)
|
| 54 |
+
# log2 keeps small functions at ~0..2 and aggressive 100-branch lookups
|
| 55 |
+
# up around log2(100) ≈ 6.6, then we clamp.
|
| 56 |
+
return min(50.0, math.log2(v.cc))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ----- Sandboxed execution -------------------------------------------------
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
def _exec_target_in_sandbox(code: str, target_name: str, queue: mp.Queue) -> None:
|
| 63 |
+
"""Run inside a child process so we can hard-kill on timeout."""
|
| 64 |
try:
|
| 65 |
+
# Restricted but still useful builtins. Submitted code rarely needs
|
| 66 |
+
# imports beyond math/string/itertools/functools, so we whitelist.
|
| 67 |
+
allowed_modules = {"math", "string", "itertools", "functools", "collections", "re"}
|
| 68 |
+
safe_globals = {
|
| 69 |
+
"__builtins__": __builtins__,
|
| 70 |
+
"__name__": "__sandbox__",
|
| 71 |
+
}
|
| 72 |
+
# Pre-import the whitelisted modules so the agent can use them
|
| 73 |
+
# without needing import statements (and we keep everything inproc).
|
| 74 |
+
for mod_name in allowed_modules:
|
| 75 |
+
safe_globals[mod_name] = __import__(mod_name)
|
| 76 |
+
|
| 77 |
+
local_scope: dict = {}
|
| 78 |
+
exec(code, safe_globals, local_scope)
|
| 79 |
+
fn = local_scope.get(target_name) or safe_globals.get(target_name)
|
| 80 |
+
if not callable(fn):
|
| 81 |
+
queue.put(("err", f"No callable named {target_name!r} defined."))
|
| 82 |
+
return
|
| 83 |
+
queue.put(("ok", None))
|
| 84 |
+
except Exception as e: # noqa: BLE001
|
| 85 |
+
queue.put(("err", f"{type(e).__name__}: {e}"))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _can_define(code: str, target_name: str, timeout_s: float) -> Optional[str]:
|
| 89 |
+
"""Return None if the submitted code defines the target callable, else an
|
| 90 |
+
error string. Uses a child process with a wall-clock timeout."""
|
| 91 |
+
ctx = mp.get_context("fork") if mp.get_start_method(allow_none=True) != "spawn" else mp.get_context("spawn")
|
| 92 |
+
q: mp.Queue = ctx.Queue()
|
| 93 |
+
p = ctx.Process(target=_exec_target_in_sandbox, args=(code, target_name, q))
|
| 94 |
+
p.start()
|
| 95 |
+
p.join(timeout=timeout_s)
|
| 96 |
+
if p.is_alive():
|
| 97 |
+
p.terminate()
|
| 98 |
+
p.join(1.0)
|
| 99 |
+
if p.is_alive():
|
| 100 |
+
p.kill()
|
| 101 |
+
return f"Definition timed out after {timeout_s}s."
|
| 102 |
+
if q.empty():
|
| 103 |
+
return "Sandbox produced no result."
|
| 104 |
+
status, payload = q.get()
|
| 105 |
+
return None if status == "ok" else payload
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Per-call (per-input) sandboxing is too slow for 100 fuzz inputs, so we
|
| 109 |
+
# accept the trade-off of running the submitted callable in-process for
|
| 110 |
+
# fuzzing, but we wrap each call in a SIGALRM-based timeout and we already
|
| 111 |
+
# proved at definition-time that the import didn't blow up.
|
| 112 |
+
|
| 113 |
+
class _CallTimeout(Exception):
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _call_with_timeout(fn: Callable, arg: Any, timeout_s: float):
|
| 118 |
+
def _handler(signum, frame): # noqa: ARG001
|
| 119 |
+
raise _CallTimeout()
|
| 120 |
+
|
| 121 |
+
old = signal.signal(signal.SIGALRM, _handler)
|
| 122 |
+
signal.setitimer(signal.ITIMER_REAL, timeout_s)
|
| 123 |
+
try:
|
| 124 |
+
return fn(arg)
|
| 125 |
+
finally:
|
| 126 |
+
signal.setitimer(signal.ITIMER_REAL, 0)
|
| 127 |
+
signal.signal(signal.SIGALRM, old)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _safe_call(fn: Callable, arg: Any, timeout_s: float):
|
| 131 |
+
"""Returns (kind, value): kind in {'val', 'err', 'timeout'}."""
|
| 132 |
+
try:
|
| 133 |
+
return ("val", _call_with_timeout(fn, arg, timeout_s))
|
| 134 |
+
except _CallTimeout:
|
| 135 |
+
return ("timeout", f"timed out after {timeout_s}s")
|
| 136 |
+
except Exception as e: # noqa: BLE001
|
| 137 |
+
return ("err", f"{type(e).__name__}: {e}")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ----- Public scoring ------------------------------------------------------
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@dataclass
|
| 144 |
+
class VerificationResult:
|
| 145 |
+
execution_reward: float
|
| 146 |
+
complexity_penalty: float
|
| 147 |
+
define_error: Optional[str]
|
| 148 |
+
matches: int
|
| 149 |
+
fuzz_count: int
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def verify_submission(
|
| 153 |
+
submitted_code: str,
|
| 154 |
+
target_function: Callable[[Any], Any],
|
| 155 |
+
fuzz_inputs: List[Any],
|
| 156 |
+
*,
|
| 157 |
+
target_name: Optional[str] = None,
|
| 158 |
+
define_timeout_s: float = 5.0,
|
| 159 |
+
call_timeout_s: float = 1.0,
|
| 160 |
+
) -> VerificationResult:
|
| 161 |
+
"""Score `submitted_code` against `target_function` over the supplied
|
| 162 |
+
`fuzz_inputs`. The agent is expected to define a top-level function with
|
| 163 |
+
the same name as `target_function` (overridable via `target_name`)."""
|
| 164 |
+
name = target_name or target_function.__name__
|
| 165 |
+
|
| 166 |
+
define_err = _can_define(submitted_code, name, define_timeout_s)
|
| 167 |
+
complexity = calculate_complexity_penalty(submitted_code)
|
| 168 |
+
if define_err is not None:
|
| 169 |
+
return VerificationResult(
|
| 170 |
+
execution_reward=0.0,
|
| 171 |
+
complexity_penalty=complexity,
|
| 172 |
+
define_error=define_err,
|
| 173 |
+
matches=0,
|
| 174 |
+
fuzz_count=len(fuzz_inputs),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Re-define in-process for fast fuzzing. We just confirmed it won't blow
|
| 178 |
+
# up at import-time; we still time-bound each call.
|
| 179 |
+
safe_globals: dict = {
|
| 180 |
+
"__builtins__": __builtins__,
|
| 181 |
+
"__name__": "__opensleuth_submission__",
|
| 182 |
+
"math": __import__("math"),
|
| 183 |
+
"string": __import__("string"),
|
| 184 |
+
"itertools": __import__("itertools"),
|
| 185 |
+
"functools": __import__("functools"),
|
| 186 |
+
"collections": __import__("collections"),
|
| 187 |
+
"re": __import__("re"),
|
| 188 |
+
}
|
| 189 |
+
local_scope: dict = {}
|
| 190 |
+
exec(submitted_code, safe_globals, local_scope)
|
| 191 |
+
submitted_fn = local_scope.get(name) or safe_globals.get(name)
|
| 192 |
+
|
| 193 |
matches = 0
|
| 194 |
for inp in fuzz_inputs:
|
| 195 |
+
ref = _safe_call(target_function, inp, call_timeout_s)
|
| 196 |
+
sub = _safe_call(submitted_fn, inp, call_timeout_s)
|
| 197 |
+
if _outputs_equivalent(ref, sub):
|
| 198 |
+
matches += 1
|
| 199 |
+
|
| 200 |
+
fuzz_count = len(fuzz_inputs) or 1
|
| 201 |
+
exec_reward = 100.0 * (matches / fuzz_count)
|
| 202 |
+
return VerificationResult(
|
| 203 |
+
execution_reward=exec_reward,
|
| 204 |
+
complexity_penalty=complexity,
|
| 205 |
+
define_error=None,
|
| 206 |
+
matches=matches,
|
| 207 |
+
fuzz_count=fuzz_count,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _outputs_equivalent(ref, sub) -> bool:
|
| 212 |
+
"""Ref and sub are (kind, value) tuples from `_safe_call`. They count as
|
| 213 |
+
equivalent if both raised the same exception type, or both returned values
|
| 214 |
+
that are == equal."""
|
| 215 |
+
rkind, rval = ref
|
| 216 |
+
skind, sval = sub
|
| 217 |
+
if rkind == "val" and skind == "val":
|
| 218 |
try:
|
| 219 |
+
return rval == sval
|
| 220 |
+
except Exception: # noqa: BLE001
|
| 221 |
+
return False
|
| 222 |
+
if rkind == "err" and skind == "err":
|
| 223 |
+
# Match on exception class name.
|
| 224 |
+
return rval.split(":", 1)[0] == sval.split(":", 1)[0]
|
| 225 |
+
if rkind == "timeout" and skind == "timeout":
|
| 226 |
+
return True
|
| 227 |
+
return False
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def generate_fuzz_inputs(
|
| 231 |
+
spec, count: int = 100, seed: Optional[int] = None
|
| 232 |
+
) -> List[Any]:
|
| 233 |
+
"""Public helper: pull `count` fuzz inputs from a FunctionSpec, optionally
|
| 234 |
+
seeded for reproducibility."""
|
| 235 |
+
rng = random.Random(seed)
|
| 236 |
+
return spec.fuzzer(rng, count)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
fastapi
|
| 2 |
-
uvicorn
|
| 3 |
-
pydantic
|
|
|
|
| 1 |
+
fastapi==0.115.6
|
| 2 |
+
uvicorn[standard]==0.32.1
|
| 3 |
+
pydantic==2.10.3
|
server.py
CHANGED
|
@@ -1,27 +1,76 @@
|
|
| 1 |
-
|
| 2 |
-
from pydantic import BaseModel
|
| 3 |
-
from opensleuth_env.env import OpenSleuthEnv
|
| 4 |
-
from opensleuth_env.models import Action, Observation
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
env = OpenSleuthEnv()
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server exposing the OpenSleuth environment over HTTP."""
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI, HTTPException
|
| 8 |
+
|
| 9 |
+
from opensleuth_env import (
|
| 10 |
+
BLACK_BOX_FUNCTIONS,
|
| 11 |
+
OpenSleuthEnv,
|
| 12 |
+
ProbeAction,
|
| 13 |
+
ResetRequest,
|
| 14 |
+
StepRequest,
|
| 15 |
+
StepResponse,
|
| 16 |
+
SubmitAction,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
| 20 |
+
log = logging.getLogger("opensleuth.server")
|
| 21 |
+
|
| 22 |
+
app = FastAPI(title="OpenSleuth Env", version="0.2.0")
|
| 23 |
env = OpenSleuthEnv()
|
| 24 |
|
| 25 |
+
|
| 26 |
+
@app.get("/health")
|
| 27 |
+
def health():
|
| 28 |
+
return {"status": "ok", "episodes_tracked": len(env._states)} # noqa: SLF001
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@app.get("/functions")
|
| 32 |
+
def list_functions():
|
| 33 |
+
return {
|
| 34 |
+
"functions": [
|
| 35 |
+
{
|
| 36 |
+
"name": s.name,
|
| 37 |
+
"signature": s.signature,
|
| 38 |
+
"description": s.description,
|
| 39 |
+
}
|
| 40 |
+
for s in BLACK_BOX_FUNCTIONS.values()
|
| 41 |
+
]
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@app.post("/reset")
|
| 46 |
+
def reset(req: ResetRequest):
|
| 47 |
+
try:
|
| 48 |
+
obs = env.reset(target_name=req.target_name, seed=req.seed, max_steps=req.max_steps)
|
| 49 |
+
except ValueError as e:
|
| 50 |
+
raise HTTPException(status_code=400, detail=str(e)) from e
|
| 51 |
+
return obs
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@app.post("/step", response_model=StepResponse)
|
| 55 |
+
def step(req: StepRequest):
|
| 56 |
+
try:
|
| 57 |
+
return env.step(req.episode_id, req.action)
|
| 58 |
+
except KeyError as e:
|
| 59 |
+
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@app.get("/state/{episode_id}")
|
| 63 |
+
def get_state(episode_id: str):
|
| 64 |
+
state = env.get_state(episode_id)
|
| 65 |
+
if not state:
|
| 66 |
+
raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
|
| 67 |
+
return state
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Convenience: a flat /step that does reset+step in one call is occasionally
|
| 71 |
+
# useful for shell-style debugging.
|
| 72 |
+
@app.post("/probe_once")
|
| 73 |
+
def probe_once(target_name: str, input_repr: str):
|
| 74 |
+
obs = env.reset(target_name=target_name)
|
| 75 |
+
resp = env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
|
| 76 |
+
return resp
|