anugrah55 committed on
Commit
d3cd20c
·
verified ·
1 Parent(s): 63bb50c

Overhaul env: per-episode state, ast.literal_eval probe parsing, sandboxed verifier with timeouts, 9 black-box functions, slim Dockerfile

Browse files
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .DS_Store
4
+ .env
5
+ .venv/
6
+ .pytest_cache/
7
+ .cache/
8
+ verifier_log.txt
9
+ *.log
Dockerfile CHANGED
@@ -1,19 +1,18 @@
1
- # Use a standard Python 3.9 image
2
- FROM python:3.9-slim
3
 
4
- # Set the working directory
5
- WORKDIR /app
 
6
 
7
- # Copy the environment files into the container
8
- COPY ./opensleuth_env /app/opensleuth_env
9
- COPY ./server.py /app/
10
- COPY ./requirements.txt /app/
11
 
12
- # Install dependencies
13
  RUN pip install --no-cache-dir -r requirements.txt
14
 
15
- # Expose the port the app runs on
16
- EXPOSE 8000
 
 
17
 
18
- # Run the application
19
- CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
 
1
+ FROM python:3.11-slim
 
2
 
3
+ ENV PYTHONUNBUFFERED=1 \
4
+ PIP_NO_CACHE_DIR=1 \
5
+ PIP_DISABLE_PIP_VERSION_CHECK=1
6
 
7
+ WORKDIR /app
 
 
 
8
 
9
+ COPY requirements.txt /app/
10
  RUN pip install --no-cache-dir -r requirements.txt
11
 
12
+ COPY opensleuth_env /app/opensleuth_env
13
+ COPY server.py /app/
14
+
15
+ EXPOSE 7860
16
 
17
+ # HF Docker Spaces route traffic to the app_port declared in README.md (7860, which is also the default), so uvicorn binds 0.0.0.0 on that fixed port; $PORT is not read here.
18
+ CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,51 @@
1
  ---
2
- title: Opensleuth Env Gemini Cli
3
- emoji: 📊
4
  colorFrom: indigo
5
- colorTo: blue
6
  sdk: docker
 
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: OpenSleuth Env
3
+ emoji: 🕵️
4
  colorFrom: indigo
5
+ colorTo: pink
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
+ suggested_hardware: cpu-basic
10
  ---
11
 
12
+ # OpenSleuth Environment
13
+
14
+ FastAPI service that exposes an OpenEnv-style `/reset` + `/step` API for the
15
+ **Algorithmic Detective** task. An agent has to figure out an unknown Python
16
+ function by probing it, then submit Python source that replicates it.
17
+
18
+ ## Endpoints
19
+
20
+ | Method | Path | Body | Notes |
21
+ |-------:|---------------|----------------------------------------|----------------------------------------|
22
+ | GET | `/health` | — | Liveness probe. |
23
+ | GET | `/functions` | — | Catalogue of available black-boxes. |
24
+ | POST | `/reset` | `{"target_name": "fibonacci", "seed": 0}` | Starts a new episode, returns initial obs + `episode_id`. |
25
+ | POST | `/step` | `{"episode_id": "...", "action": {...}}` | One agent action. |
26
+ | GET | `/state/{eid}`| — | Inspect the live state of an episode (debug). |
27
+
28
+ ### Action shapes
29
+
30
+ ```json
31
+ {"action_type": "probe", "input_repr": "5"} // input_repr is parsed via ast.literal_eval
32
+ {"action_type": "submit", "code": "def fibonacci(n):..."}
33
+ ```
34
+
35
+ ### Reward
36
+
37
+ * **Probe:** `-1` step cost, plus `+2` per newly-seen output and `+5` per
38
+ newly-seen exception type, encouraging exploration of edge cases.
39
+ * **Submit (terminal):** `100 * matches/fuzz_count` minus a logarithmic
40
+ cyclomatic-complexity penalty. A perfect submission gets a `+50` bonus.
41
+
42
+ ## Hardware
43
+
44
+ CPU-only — `cpu-basic` is plenty. Do **not** assign a GPU to this Space.
45
+
46
+ ## Running locally
47
+
48
+ ```bash
49
+ pip install -r requirements.txt
50
+ uvicorn server:app --port 7860 --reload
51
+ ```
opensleuth_env/__init__.py CHANGED
@@ -1 +1,28 @@
1
- # This file makes the 'opensleuth_env' directory a Python package.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenSleuth environment library."""
2
+
3
+ from .env import OpenSleuthEnv
4
+ from .models import (
5
+ Action,
6
+ ProbeAction,
7
+ SubmitAction,
8
+ Observation,
9
+ State,
10
+ StepResponse,
11
+ ResetRequest,
12
+ StepRequest,
13
+ )
14
+ from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
15
+
16
+ __all__ = [
17
+ "OpenSleuthEnv",
18
+ "Action",
19
+ "ProbeAction",
20
+ "SubmitAction",
21
+ "Observation",
22
+ "State",
23
+ "StepResponse",
24
+ "ResetRequest",
25
+ "StepRequest",
26
+ "BLACK_BOX_FUNCTIONS",
27
+ "FunctionSpec",
28
+ ]
opensleuth_env/black_box.py CHANGED
@@ -1,31 +1,259 @@
1
- def fibonacci(n: int) -> int:
2
- """
3
- Calculates the nth Fibonacci number.
4
- - Handles positive integers up to 90 to avoid large numbers.
5
- - Raises ValueError for non-positive inputs or large inputs.
6
- """
7
- if not isinstance(n, int) or n <= 0 or n > 90:
8
- raise ValueError("Input must be a positive integer less than or equal to 90.")
9
- if n == 1:
10
- return 1
 
 
 
 
 
 
 
 
 
 
 
 
11
  a, b = 0, 1
12
  for _ in range(n - 1):
13
  a, b = b, a + b
14
- return b
15
 
16
- # --- Add more black-box functions for later stages ---
17
 
18
- def reverse_string(s: str) -> str:
19
- """
20
- Reverses a string.
21
- - Raises TypeError for non-string inputs.
22
- """
23
  if not isinstance(s, str):
24
  raise TypeError("Input must be a string.")
25
  return s[::-1]
26
 
27
- # --- Dictionary to hold all available black-box functions ---
28
- BLACK_BOX_FUNCTIONS = {
29
- "fibonacci": fibonacci,
30
- "reverse_string": reverse_string,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  }
 
1
+ """Catalogue of hidden 'black-box' Python functions the agent must reproduce.
2
+
3
+ Each entry pairs the reference implementation with a *typed input domain*
4
+ generator so the verifier can fuzz it, plus a public signature/docstring shown
5
+ to the agent in the prompt. The reference implementation itself is never
6
+ shown to the agent.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import random
12
+ import string
13
+ from dataclasses import dataclass
14
+ from typing import Any, Callable, Dict, List
15
+
16
+
17
+ # ----- Reference implementations --------------------------------------------
18
+
19
+
20
+ def _fibonacci(n: int) -> int:
21
+ if not isinstance(n, int) or isinstance(n, bool) or n <= 0 or n > 90:
22
+ raise ValueError("Input must be a positive integer <= 90.")
23
  a, b = 0, 1
24
  for _ in range(n - 1):
25
  a, b = b, a + b
26
+ return b if n > 0 else a
27
 
 
28
 
29
+ def _reverse_string(s: str) -> str:
 
 
 
 
30
  if not isinstance(s, str):
31
  raise TypeError("Input must be a string.")
32
  return s[::-1]
33
 
34
+
35
+ def _is_palindrome(s: str) -> bool:
36
+ if not isinstance(s, str):
37
+ raise TypeError("Input must be a string.")
38
+ cleaned = "".join(ch.lower() for ch in s if ch.isalnum())
39
+ return cleaned == cleaned[::-1]
40
+
41
+
42
+ def _digit_sum(n: int) -> int:
43
+ if not isinstance(n, int) or isinstance(n, bool):
44
+ raise TypeError("Input must be int.")
45
+ if n < 0:
46
+ raise ValueError("Input must be non-negative.")
47
+ return sum(int(c) for c in str(n))
48
+
49
+
50
+ def _count_vowels(s: str) -> int:
51
+ if not isinstance(s, str):
52
+ raise TypeError("Input must be a string.")
53
+ return sum(1 for c in s.lower() if c in "aeiou")
54
+
55
+
56
+ def _gcd(pair) -> int:
57
+ """Greatest common divisor of two non-negative ints, given as a 2-tuple
58
+ or 2-list. Hidden trick: tuple/list both accepted, ints only."""
59
+ if not isinstance(pair, (list, tuple)) or len(pair) != 2:
60
+ raise TypeError("Input must be a 2-element list or tuple.")
61
+ a, b = pair
62
+ if not all(isinstance(x, int) and not isinstance(x, bool) for x in (a, b)):
63
+ raise TypeError("Both elements must be int.")
64
+ if a < 0 or b < 0:
65
+ raise ValueError("Both elements must be non-negative.")
66
+ while b:
67
+ a, b = b, a % b
68
+ return a
69
+
70
+
71
+ def _sort_unique(xs) -> list:
72
+ """Return sorted unique elements from a list of ints."""
73
+ if not isinstance(xs, list):
74
+ raise TypeError("Input must be a list.")
75
+ if not all(isinstance(x, int) and not isinstance(x, bool) for x in xs):
76
+ raise TypeError("All elements must be int.")
77
+ return sorted(set(xs))
78
+
79
+
80
+ def _caesar_cipher(s: str) -> str:
81
+ """Caesar shift by +3 on lowercase letters; everything else unchanged."""
82
+ if not isinstance(s, str):
83
+ raise TypeError("Input must be a string.")
84
+ out = []
85
+ for ch in s:
86
+ if "a" <= ch <= "z":
87
+ out.append(chr((ord(ch) - ord("a") + 3) % 26 + ord("a")))
88
+ else:
89
+ out.append(ch)
90
+ return "".join(out)
91
+
92
+
93
+ def _is_prime(n: int) -> bool:
94
+ if not isinstance(n, int) or isinstance(n, bool):
95
+ raise TypeError("Input must be int.")
96
+ if n < 2:
97
+ return False
98
+ if n < 4:
99
+ return True
100
+ if n % 2 == 0:
101
+ return False
102
+ i = 3
103
+ while i * i <= n:
104
+ if n % i == 0:
105
+ return False
106
+ i += 2
107
+ return True
108
+
109
+
110
+ # ----- Fuzz input generators ------------------------------------------------
111
+
112
+
113
+ def _fuzz_small_pos_int(rng: random.Random, n: int) -> List[int]:
114
+ return [rng.randint(1, 30) for _ in range(n)]
115
+
116
+
117
+ def _fuzz_fib_int(rng: random.Random, n: int) -> List[int]:
118
+ # Mix common values, edges, and random.
119
+ pool = [1, 2, 3, 10, 20, 30, 50, 89, 90]
120
+ return [rng.choice(pool) if rng.random() < 0.3 else rng.randint(1, 90) for _ in range(n)]
121
+
122
+
123
+ def _fuzz_short_string(rng: random.Random, n: int) -> List[str]:
124
+ alpha = string.ascii_letters + string.digits
125
+ return ["".join(rng.choices(alpha, k=rng.randint(0, 12))) for _ in range(n)]
126
+
127
+
128
+ def _fuzz_palindrome_string(rng: random.Random, n: int) -> List[str]:
129
+ out = []
130
+ for _ in range(n):
131
+ if rng.random() < 0.4:
132
+ base = "".join(rng.choices(string.ascii_lowercase, k=rng.randint(0, 6)))
133
+ out.append(base + base[::-1])
134
+ else:
135
+ out.append("".join(rng.choices(string.ascii_letters + " ", k=rng.randint(0, 12))))
136
+ return out
137
+
138
+
139
+ def _fuzz_nonneg_int(rng: random.Random, n: int) -> List[int]:
140
+ return [rng.randint(0, 10_000) for _ in range(n)]
141
+
142
+
143
+ def _fuzz_int_pair(rng: random.Random, n: int):
144
+ return [(rng.randint(0, 1000), rng.randint(0, 1000)) for _ in range(n)]
145
+
146
+
147
+ def _fuzz_int_list(rng: random.Random, n: int):
148
+ return [
149
+ [rng.randint(-50, 50) for _ in range(rng.randint(0, 8))] for _ in range(n)
150
+ ]
151
+
152
+
153
+ def _fuzz_lower_string(rng: random.Random, n: int) -> List[str]:
154
+ return [
155
+ "".join(rng.choices(string.ascii_lowercase + " ,!", k=rng.randint(0, 16)))
156
+ for _ in range(n)
157
+ ]
158
+
159
+
160
+ def _fuzz_prime_int(rng: random.Random, n: int) -> List[int]:
161
+ # Mix in known primes and composites to cover both branches.
162
+ seeded = [0, 1, 2, 3, 4, 9, 11, 15, 17, 25, 29, 97, 100]
163
+ return [rng.choice(seeded) if rng.random() < 0.3 else rng.randint(0, 200) for _ in range(n)]
164
+
165
+
166
+ # ----- Spec ----------------------------------------------------------------
167
+
168
+
169
+ @dataclass(frozen=True)
170
+ class FunctionSpec:
171
+ name: str
172
+ fn: Callable[[Any], Any]
173
+ signature: str
174
+ description: str
175
+ fuzzer: Callable[[random.Random, int], list]
176
+
177
+
178
+ BLACK_BOX_FUNCTIONS: Dict[str, FunctionSpec] = {
179
+ spec.name: spec
180
+ for spec in [
181
+ FunctionSpec(
182
+ name="fibonacci",
183
+ fn=_fibonacci,
184
+ signature="fibonacci(n: int) -> int",
185
+ description=(
186
+ "Returns the n-th Fibonacci number. Raises ValueError for "
187
+ "invalid n (n must be a positive int <= 90)."
188
+ ),
189
+ fuzzer=_fuzz_fib_int,
190
+ ),
191
+ FunctionSpec(
192
+ name="reverse_string",
193
+ fn=_reverse_string,
194
+ signature="reverse_string(s: str) -> str",
195
+ description="Returns the reversed string. Raises TypeError for non-str.",
196
+ fuzzer=_fuzz_short_string,
197
+ ),
198
+ FunctionSpec(
199
+ name="is_palindrome",
200
+ fn=_is_palindrome,
201
+ signature="is_palindrome(s: str) -> bool",
202
+ description=(
203
+ "Case-insensitive palindrome check, ignoring non-alphanumeric "
204
+ "characters. Raises TypeError for non-str."
205
+ ),
206
+ fuzzer=_fuzz_palindrome_string,
207
+ ),
208
+ FunctionSpec(
209
+ name="digit_sum",
210
+ fn=_digit_sum,
211
+ signature="digit_sum(n: int) -> int",
212
+ description=(
213
+ "Sum of the decimal digits of n. n must be a non-negative int."
214
+ ),
215
+ fuzzer=_fuzz_nonneg_int,
216
+ ),
217
+ FunctionSpec(
218
+ name="count_vowels",
219
+ fn=_count_vowels,
220
+ signature="count_vowels(s: str) -> int",
221
+ description="Count of vowels (a/e/i/o/u, case-insensitive) in s.",
222
+ fuzzer=_fuzz_lower_string,
223
+ ),
224
+ FunctionSpec(
225
+ name="gcd",
226
+ fn=_gcd,
227
+ signature="gcd(pair: tuple[int, int] | list[int]) -> int",
228
+ description=(
229
+ "Greatest common divisor of a 2-element tuple/list of "
230
+ "non-negative ints."
231
+ ),
232
+ fuzzer=_fuzz_int_pair,
233
+ ),
234
+ FunctionSpec(
235
+ name="sort_unique",
236
+ fn=_sort_unique,
237
+ signature="sort_unique(xs: list[int]) -> list[int]",
238
+ description="Sorted, deduplicated list of ints from xs.",
239
+ fuzzer=_fuzz_int_list,
240
+ ),
241
+ FunctionSpec(
242
+ name="caesar_cipher",
243
+ fn=_caesar_cipher,
244
+ signature="caesar_cipher(s: str) -> str",
245
+ description=(
246
+ "Caesar shift by +3 on lowercase letters; non-lowercase chars "
247
+ "are unchanged."
248
+ ),
249
+ fuzzer=_fuzz_lower_string,
250
+ ),
251
+ FunctionSpec(
252
+ name="is_prime",
253
+ fn=_is_prime,
254
+ signature="is_prime(n: int) -> bool",
255
+ description="True iff n is a prime int. n must be int.",
256
+ fuzzer=_fuzz_prime_int,
257
+ ),
258
+ ]
259
  }
opensleuth_env/env.py CHANGED
@@ -1,93 +1,205 @@
1
- from opensleuth_env.models import Action, Observation, State, ProbeAction, SubmitAction
2
- from opensleuth_env.black_box import BLACK_BOX_FUNCTIONS
3
- from opensleuth_env.verifier import verify_submission
4
- import random
5
- import traceback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class OpenSleuthEnv:
8
- def __init__(self):
9
- self.state = None
10
- # The verifier is now a static function, so no need to init it
11
-
12
- def reset(self, target_name: str = "fibonacci") -> Observation:
13
- """
14
- Resets the environment to a new episode.
15
- Selects a black-box function and clears the history.
16
- """
17
- if target_name not in BLACK_BOX_FUNCTIONS:
18
- raise ValueError(f"Unknown target function: {target_name}")
19
 
20
- self.state = State(
 
 
 
 
 
 
 
 
 
 
 
21
  target_function_name=target_name,
22
- probe_history=[],
23
- seen_outputs=set(),
24
- seen_error_types=set(),
25
  )
26
- return Observation(probe_history=[], last_error="")
27
-
28
- def step(self, action: Action) -> tuple[Observation, float, bool]:
29
- """
30
- Takes a step in the environment.
31
- """
32
- if self.state is None:
33
- # If reset() was not called, do it now.
34
- self.reset()
35
-
36
- # The Pydantic model binding in FastAPI should handle the conversion.
37
- # This check is for robustness.
38
- if not isinstance(action, (ProbeAction, SubmitAction)):
39
- try:
40
- if action.get("action_type") == "probe":
41
- action = ProbeAction(**action)
42
- elif action.get("action_type") == "submit":
43
- action = SubmitAction(**action)
44
- else:
45
- raise ValueError("Invalid action_type")
46
- except Exception as e:
47
- obs = Observation(probe_history=self.state.probe_history, last_error=f"Invalid action format: {e}")
48
- return obs, -20.0, True
49
-
50
-
51
- if action.action_type == "probe":
52
- return self._handle_probe(action)
53
- elif action.action_type == "submit":
54
- return self._handle_submit(action)
55
  else:
56
- obs = Observation(probe_history=self.state.probe_history, last_error=f"Invalid action type: {action.action_type}")
57
- return obs, -20.0, True
 
 
58
 
59
- def _handle_probe(self, action: ProbeAction) -> tuple[Observation, float, bool]:
60
- target_func = BLACK_BOX_FUNCTIONS[self.state.target_function_name]
61
- intrinsic_reward = 0.0
62
- last_error = ""
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  try:
65
- eval_input = action.input
66
- output = target_func(eval_input)
67
- self.state.probe_history.append((eval_input, output))
68
- if str(output) not in self.state.seen_outputs:
69
- intrinsic_reward += 2.0
70
- self.state.seen_outputs.add(str(output))
71
-
72
- except Exception as e:
 
73
  error_type = type(e).__name__
74
- error_str = traceback.format_exc()
75
- self.state.probe_history.append((action.input, error_str))
76
- last_error = error_str
77
- if error_type not in self.state.seen_error_types:
78
- intrinsic_reward += 5.0
79
- self.state.seen_error_types.add(error_type)
80
-
81
- reward = intrinsic_reward - 1.0
82
- obs = Observation(probe_history=self.state.probe_history, last_error=last_error)
83
- return obs, reward, False
84
-
85
- def _handle_submit(self, action: SubmitAction) -> tuple[Observation, float, bool]:
86
- target_func = BLACK_BOX_FUNCTIONS[self.state.target_function_name]
87
- execution_reward, complexity_penalty = verify_submission(action.code, target_func)
88
- total_reward = execution_reward - complexity_penalty
89
- if execution_reward == 100.0:
90
- total_reward += 50.0
91
-
92
- obs = Observation(probe_history=self.state.probe_history, last_error="")
93
- return obs, total_reward, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core OpenSleuth episodic environment.
2
+
3
+ A single OpenSleuthEnv holds a *registry of episodes* keyed by episode_id, so
4
+ multiple training rollouts can hit the same FastAPI server in parallel without
5
+ stepping on each other's state.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import ast
11
+ import logging
12
+ import uuid
13
+ from typing import Tuple
14
+
15
+ from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
16
+ from .models import (
17
+ Action,
18
+ Observation,
19
+ ProbeAction,
20
+ ProbeRecord,
21
+ State,
22
+ StepResponse,
23
+ SubmitAction,
24
+ )
25
+ from .verifier import generate_fuzz_inputs, verify_submission
26
+
27
+ log = logging.getLogger("opensleuth.env")
28
+
29
+
30
+ # Reward shaping knobs (kept here so they're easy to tune).
31
+ PROBE_STEP_COST = -1.0
32
+ NEW_OUTPUT_BONUS = 2.0
33
+ NEW_ERROR_TYPE_BONUS = 5.0
34
+ PERFECT_SUBMISSION_BONUS = 50.0
35
+ MAX_PROBE_HISTORY_IN_OBS = 25
36
+
37
 
38
  class OpenSleuthEnv:
39
+ """Multi-episode environment registry."""
40
+
41
+ def __init__(self, fuzz_count: int = 100) -> None:
42
+ self._states: dict[str, State] = {}
43
+ self._configs: dict[str, dict] = {}
44
+ self.fuzz_count = fuzz_count
 
 
 
 
 
45
 
46
+ # --- Lifecycle ---------------------------------------------------------
47
+
48
+ def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Observation:
49
+ if target_name not in BLACK_BOX_FUNCTIONS:
50
+ raise ValueError(
51
+ f"Unknown target function: {target_name!r}. "
52
+ f"Available: {sorted(BLACK_BOX_FUNCTIONS)}"
53
+ )
54
+ spec = BLACK_BOX_FUNCTIONS[target_name]
55
+ episode_id = uuid.uuid4().hex
56
+ self._states[episode_id] = State(
57
+ episode_id=episode_id,
58
  target_function_name=target_name,
59
+ seed=seed,
 
 
60
  )
61
+ self._configs[episode_id] = {"max_steps": max_steps}
62
+ return self._build_observation(episode_id, spec, last_error="")
63
+
64
+ def step(self, episode_id: str, action: Action) -> StepResponse:
65
+ state = self._states.get(episode_id)
66
+ if state is None:
67
+ raise KeyError(f"Unknown episode_id {episode_id!r}. Did you /reset first?")
68
+ if state.done:
69
+ spec = BLACK_BOX_FUNCTIONS[state.target_function_name]
70
+ obs = self._build_observation(episode_id, spec, last_error="Episode already terminated.")
71
+ return StepResponse(observation=obs, reward=0.0, done=True, info={"reason": "already_done"})
72
+
73
+ spec = BLACK_BOX_FUNCTIONS[state.target_function_name]
74
+ state.steps_taken += 1
75
+ max_steps = self._configs[episode_id]["max_steps"]
76
+
77
+ if isinstance(action, ProbeAction):
78
+ obs, reward, done, info = self._handle_probe(state, spec, action)
79
+ elif isinstance(action, SubmitAction):
80
+ obs, reward, done, info = self._handle_submit(state, spec, action)
 
 
 
 
 
 
 
 
 
81
  else:
82
+ obs = self._build_observation(
83
+ episode_id, spec, last_error=f"Invalid action type: {type(action).__name__}"
84
+ )
85
+ reward, done, info = -20.0, True, {"reason": "invalid_action"}
86
 
87
+ # Step-budget exhaustion ends the episode with no extra reward.
88
+ if not done and state.steps_taken >= max_steps:
89
+ done = True
90
+ info = {**info, "reason": info.get("reason", "step_limit")}
91
 
92
+ if done:
93
+ state.done = True
94
+ return StepResponse(observation=obs, reward=reward, done=done, info=info)
95
+
96
+ # --- Action handlers ---------------------------------------------------
97
+
98
+ def _handle_probe(
99
+ self, state: State, spec: FunctionSpec, action: ProbeAction
100
+ ) -> Tuple[Observation, float, bool, dict]:
101
+ # Parse the agent's input from a Python literal repr.
102
+ try:
103
+ parsed = ast.literal_eval(action.input_repr)
104
+ except (ValueError, SyntaxError) as e:
105
+ err = f"Could not parse input_repr as a Python literal: {e}"
106
+ state.probe_history.append(
107
+ ProbeRecord(
108
+ input_repr=action.input_repr,
109
+ output_repr=err,
110
+ is_error=True,
111
+ error_type="ParseError",
112
+ )
113
+ )
114
+ obs = self._build_observation(state.episode_id, spec, last_error=err)
115
+ return obs, PROBE_STEP_COST, False, {"reason": "parse_error"}
116
+
117
+ intrinsic = 0.0
118
+ last_error = ""
119
  try:
120
+ output = spec.fn(parsed)
121
+ output_repr = repr(output)
122
+ state.probe_history.append(
123
+ ProbeRecord(input_repr=repr(parsed), output_repr=output_repr, is_error=False)
124
+ )
125
+ if output_repr not in state.seen_outputs:
126
+ intrinsic += NEW_OUTPUT_BONUS
127
+ state.seen_outputs.add(output_repr)
128
+ except Exception as e: # noqa: BLE001
129
  error_type = type(e).__name__
130
+ err_repr = f"{error_type}: {e}"
131
+ state.probe_history.append(
132
+ ProbeRecord(
133
+ input_repr=repr(parsed),
134
+ output_repr=err_repr,
135
+ is_error=True,
136
+ error_type=error_type,
137
+ )
138
+ )
139
+ last_error = err_repr
140
+ if error_type not in state.seen_error_types:
141
+ intrinsic += NEW_ERROR_TYPE_BONUS
142
+ state.seen_error_types.add(error_type)
143
+
144
+ reward = intrinsic + PROBE_STEP_COST
145
+ obs = self._build_observation(state.episode_id, spec, last_error=last_error)
146
+ return obs, reward, False, {"intrinsic": intrinsic}
147
+
148
+ def _handle_submit(
149
+ self, state: State, spec: FunctionSpec, action: SubmitAction
150
+ ) -> Tuple[Observation, float, bool, dict]:
151
+ fuzz_inputs = generate_fuzz_inputs(spec, count=self.fuzz_count, seed=state.seed)
152
+ result = verify_submission(action.code, spec.fn, fuzz_inputs, target_name=spec.name)
153
+
154
+ total = result.execution_reward - result.complexity_penalty
155
+ if result.execution_reward >= 99.999:
156
+ total += PERFECT_SUBMISSION_BONUS
157
+
158
+ obs = self._build_observation(
159
+ state.episode_id,
160
+ spec,
161
+ last_error=result.define_error or "",
162
+ )
163
+ info = {
164
+ "execution_reward": result.execution_reward,
165
+ "complexity_penalty": result.complexity_penalty,
166
+ "matches": result.matches,
167
+ "fuzz_count": result.fuzz_count,
168
+ "define_error": result.define_error,
169
+ "reason": "submission",
170
+ }
171
+ return obs, total, True, info
172
+
173
+ # --- Helpers -----------------------------------------------------------
174
+
175
+ def _build_observation(
176
+ self, episode_id: str, spec: FunctionSpec, last_error: str
177
+ ) -> Observation:
178
+ state = self._states[episode_id]
179
+ max_steps = self._configs[episode_id]["max_steps"]
180
+ history = state.probe_history[-MAX_PROBE_HISTORY_IN_OBS:]
181
+ return Observation(
182
+ episode_id=episode_id,
183
+ target_function_name=state.target_function_name,
184
+ target_function_signature=f"{spec.signature}\n\n{spec.description}",
185
+ probe_history=history,
186
+ last_error=last_error,
187
+ steps_taken=state.steps_taken,
188
+ max_steps=max_steps,
189
+ )
190
+
191
+ # --- Introspection -----------------------------------------------------
192
+
193
+ def get_state(self, episode_id: str) -> dict:
194
+ s = self._states.get(episode_id)
195
+ if s is None:
196
+ return {}
197
+ return {
198
+ "episode_id": s.episode_id,
199
+ "target_function_name": s.target_function_name,
200
+ "steps_taken": s.steps_taken,
201
+ "done": s.done,
202
+ "seen_outputs": sorted(s.seen_outputs),
203
+ "seen_error_types": sorted(s.seen_error_types),
204
+ "probe_history": [r.model_dump() for r in s.probe_history],
205
+ }
opensleuth_env/models.py CHANGED
@@ -1,29 +1,79 @@
1
- from typing import Union, List, Tuple, Any, Literal
2
- from pydantic import BaseModel, Field
 
 
 
 
 
3
 
4
  class ProbeAction(BaseModel):
5
  action_type: Literal["probe"] = "probe"
6
- input: Any
 
 
 
 
 
7
 
8
  class SubmitAction(BaseModel):
9
  action_type: Literal["submit"] = "submit"
10
- code: str
 
11
 
12
  Action = Union[ProbeAction, SubmitAction]
13
 
 
 
 
 
 
 
 
 
 
 
 
14
  class Observation(BaseModel):
15
- probe_history: List[Tuple[Any, Any]] = Field(
16
- ...,
17
- description="A list of (input, output) pairs from previous probes. Output can be a value or an error string."
18
- )
19
- last_error: str = Field(
20
- "",
21
- description="The error message from the last action, if any."
22
  )
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  class State(BaseModel):
 
 
 
 
 
 
25
  target_function_name: str
26
- probe_history: List[Tuple[Any, Any]]
27
- # Store unique outputs and error types to calculate intrinsic reward
28
- seen_outputs: set
29
- seen_error_types: set
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for the OpenSleuth API and core state."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, List, Literal, Optional, Tuple, Union
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+
8
 
9
  class ProbeAction(BaseModel):
10
  action_type: Literal["probe"] = "probe"
11
+ # The agent submits inputs as a Python literal string (e.g. "5", "'abc'",
12
+ # "[1, 2, 3]"). We parse it server-side with ast.literal_eval. Keeping it
13
+ # as a string avoids a class of FastAPI auto-coercion bugs and matches
14
+ # what an LLM naturally emits.
15
+ input_repr: str = Field(..., description="Python literal repr of the probe input")
16
+
17
 
18
  class SubmitAction(BaseModel):
19
  action_type: Literal["submit"] = "submit"
20
+ code: str = Field(..., description="Python source defining the target function")
21
+
22
 
23
  Action = Union[ProbeAction, SubmitAction]
24
 
25
+
26
+ class ProbeRecord(BaseModel):
27
+ """One entry in the probe history. Output is either the function's return
28
+ value (Pythonic repr) or, if it raised, an error string."""
29
+
30
+ input_repr: str
31
+ output_repr: str
32
+ is_error: bool = False
33
+ error_type: Optional[str] = None
34
+
35
+
36
  class Observation(BaseModel):
37
+ episode_id: str
38
+ target_function_name: str
39
+ target_function_signature: str = Field(
40
+ "", description="Human readable signature + docstring shown to the agent"
 
 
 
41
  )
42
+ probe_history: List[ProbeRecord] = Field(default_factory=list)
43
+ last_error: str = ""
44
+ steps_taken: int = 0
45
+ max_steps: int = 25
46
+
47
+
48
+ class StepResponse(BaseModel):
49
+ observation: Observation
50
+ reward: float
51
+ done: bool
52
+ info: dict = Field(default_factory=dict)
53
+
54
 
55
  class State(BaseModel):
56
+ """Internal mutable state for one episode. Not exposed in /step responses
57
+ in full, but available via /state/{eid} for debugging."""
58
+
59
+ model_config = ConfigDict(arbitrary_types_allowed=True)
60
+
61
+ episode_id: str
62
  target_function_name: str
63
+ probe_history: List[ProbeRecord] = Field(default_factory=list)
64
+ seen_outputs: set = Field(default_factory=set)
65
+ seen_error_types: set = Field(default_factory=set)
66
+ steps_taken: int = 0
67
+ done: bool = False
68
+ seed: int = 0
69
+
70
+
71
+ class ResetRequest(BaseModel):
72
+ target_name: str = "fibonacci"
73
+ seed: int = 0
74
+ max_steps: int = 25
75
+
76
+
77
+ class StepRequest(BaseModel):
78
+ episode_id: str
79
+ action: Action
opensleuth_env/verifier.py CHANGED
@@ -1,68 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import ast
2
- import random
3
- import string
4
  import math
 
 
 
 
 
 
5
 
6
- class ComplexityVisitor(ast.NodeVisitor):
 
 
 
7
  def __init__(self):
8
- self.complexity = 1
9
- def visit_If(self, node):
10
- self.complexity += 1
11
- self.generic_visit(node)
12
- def visit_For(self, node):
13
- self.complexity += 1
14
- self.generic_visit(node)
15
- def visit_While(self, node):
16
- self.complexity += 1
17
- self.generic_visit(node)
18
- def visit_And(self, node):
19
- self.complexity += 1
20
- self.generic_visit(node)
21
- def visit_Or(self, node):
22
- self.complexity += 1
23
  self.generic_visit(node)
24
- def visit_ExceptHandler(self, node):
25
- self.complexity += 1
 
 
 
 
 
 
 
 
 
26
  self.generic_visit(node)
27
 
28
- def _calculate_cyclomatic_complexity(code: str) -> int:
 
 
29
  try:
30
  tree = ast.parse(code)
31
- visitor = ComplexityVisitor()
32
- visitor.visit(tree)
33
- return math.log(visitor.complexity)
34
  except SyntaxError:
35
- return 50
 
 
 
 
 
 
 
 
36
 
37
- def _generate_fuzz_inputs(target_func, count=100):
38
- inputs = []
39
- if target_func.__name__ == "fibonacci":
40
- inputs = [random.randint(1, 90) for _ in range(count)]
41
- elif target_func.__name__ == "reverse_string":
42
- inputs = [''.join(random.choices(string.ascii_letters + string.digits, k=random.randint(1, 20))) for _ in range(count)]
43
- return inputs
44
 
45
- def verify_submission(submitted_code: str, target_function: callable, fuzz_count: int = 100) -> tuple[float, float]:
 
46
  try:
47
- local_scope = {}
48
- exec(submitted_code, {}, local_scope)
49
- submitted_func = local_scope.get(target_function.__name__)
50
- if not callable(submitted_func):
51
- return 0.0, 50.0
52
- except Exception:
53
- return 0.0, 50.0
54
-
55
- fuzz_inputs = _generate_fuzz_inputs(target_function, fuzz_count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  matches = 0
57
  for inp in fuzz_inputs:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  try:
59
- target_output = target_function(inp)
60
- submitted_output = submitted_func(inp)
61
- if target_output == submitted_output:
62
- matches += 1
63
- except Exception:
64
- continue
65
-
66
- execution_reward = 100.0 * (matches / fuzz_count)
67
- complexity_penalty = _calculate_cyclomatic_complexity(submitted_code)
68
- return execution_reward, complexity_penalty
 
 
 
 
 
 
 
 
 
1
+ """Verifier: scores a submitted Python source against a hidden reference
2
+ function by domain-aware fuzzing, with sandboxed execution and a complexity
3
+ penalty.
4
+
5
+ Reward design:
6
+ execution_reward in [0, 100] = 100 * matches/fuzz_count
7
+ complexity_penalty in [0, 50] = log2(cyclomatic) clipped, else 50 on syntax error
8
+ exec_failure_penalty = 25 if def-time exec raised, before fuzzing
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
  import ast
 
 
14
  import math
15
+ import multiprocessing as mp
16
+ import random
17
+ import signal
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, List, Optional
20
+
21
 
22
+ # ----- AST complexity ------------------------------------------------------
23
+
24
+
25
class _CCVisitor(ast.NodeVisitor):
    """AST walker that accumulates an approximate cyclomatic complexity.

    The counter starts at 1 and grows by one for every `if`, `for`, `while`,
    `async for`, `except` handler, `with` and conditional expression, plus
    one for every extra operand in an `and`/`or` chain.
    """

    def __init__(self):
        self.cc = 1

    def _bump(self, node):
        # One more decision point, then keep descending into nested nodes.
        self.cc += 1
        self.generic_visit(node)

    visit_If = _bump
    visit_For = _bump
    visit_While = _bump
    visit_AsyncFor = _bump
    visit_ExceptHandler = _bump
    visit_With = _bump
    visit_IfExp = _bump

    def visit_BoolOp(self, node):
        # `a and b and c` has three operands and adds two decision points.
        self.cc += max(0, len(node.values) - 1)
        self.generic_visit(node)


def calculate_complexity_penalty(code: str) -> float:
    """Return a bounded, log-scaled cyclomatic complexity for `code`.

    Args:
        code: Python source text submitted by the agent (untrusted).

    Returns:
        `min(50.0, log2(cc))`, where `cc` is the count from `_CCVisitor`,
        or 50.0 when the source cannot be analysed at all.
    """
    try:
        tree = ast.parse(code)
        visitor = _CCVisitor()
        visitor.visit(tree)
    except (SyntaxError, ValueError, RecursionError):
        # SyntaxError: the source does not parse.
        # ValueError: e.g. NUL bytes in the source on Python 3.11.
        # RecursionError: pathologically deep nesting, in the parser or in
        # the recursive visitor. All three get the maximum penalty instead
        # of crashing the scorer.
        return 50.0
    # log2 keeps small functions at ~0..2 and aggressive 100-branch lookups
    # up around log2(100) ≈ 6.6, then we clamp.
    return min(50.0, math.log2(visitor.cc))
57
+
58
+
59
+ # ----- Sandboxed execution -------------------------------------------------
60
 
 
 
 
 
 
 
 
61
 
62
def _exec_target_in_sandbox(code: str, target_name: str, queue: mp.Queue) -> None:
    """Child-process entry point: exec `code` and report whether it defines
    a callable named `target_name`.

    Exactly one `(status, payload)` tuple is put on `queue`: `("ok", None)`
    on success, or `("err", message)` when exec raised or the name is
    missing or not callable.
    """
    try:
        # NOTE(security): the full builtins are exposed, so this is process
        # isolation (it lets the parent hard-kill a runaway definition), not
        # a capability sandbox. `code` is untrusted input.
        namespace: dict = {
            "__builtins__": __builtins__,
            "__name__": "__sandbox__",
        }
        # Pre-import the whitelisted modules so the agent can use them
        # without needing import statements.
        for mod_name in ("math", "string", "itertools", "functools", "collections", "re"):
            namespace[mod_name] = __import__(mod_name)

        # One dict serves as both globals and locals. With two separate
        # dicts the code runs as if inside a class body, and a top-level
        # function could not see itself or its sibling helpers.
        exec(code, namespace)
        fn = namespace.get(target_name)
        if not callable(fn):
            queue.put(("err", f"No callable named {target_name!r} defined."))
            return
        queue.put(("ok", None))
    except Exception as e:  # noqa: BLE001
        queue.put(("err", f"{type(e).__name__}: {e}"))


def _can_define(code: str, target_name: str, timeout_s: float) -> Optional[str]:
    """Return None if the submitted code defines the target callable, else an
    error string. Uses a child process with a wall-clock timeout.

    Args:
        code: Untrusted Python source to execute in the child.
        target_name: Name of the callable the source must define.
        timeout_s: Wall-clock budget for the definition, in seconds.
    """
    import queue as queue_mod
    import time

    ctx = mp.get_context("fork") if mp.get_start_method(allow_none=True) != "spawn" else mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=_exec_target_in_sandbox, args=(code, target_name, q))
    p.start()

    # Read the result BEFORE joining: `Queue.empty()` is unreliable, and
    # joining a process that still has queued data is discouraged by the
    # multiprocessing documentation.
    deadline = time.monotonic() + timeout_s
    result = None
    while result is None:
        try:
            result = q.get(timeout=0.05)
        except queue_mod.Empty:
            if not p.is_alive():
                # The child has exited; give the pipe one last chance to
                # deliver anything it flushed on the way out.
                try:
                    result = q.get(timeout=0.25)
                except queue_mod.Empty:
                    pass
                break
            if time.monotonic() >= deadline:
                break

    timed_out = result is None and p.is_alive()
    if not timed_out:
        p.join(1.0)
    if p.is_alive():
        p.terminate()
        p.join(1.0)
        if p.is_alive():
            p.kill()
            p.join(1.0)

    if timed_out:
        return f"Definition timed out after {timeout_s}s."
    if result is None:
        return "Sandbox produced no result."
    status, payload = result
    return None if status == "ok" else payload
106
+
107
+
108
+ # Per-call (per-input) sandboxing is too slow for 100 fuzz inputs, so we
109
+ # accept the trade-off of running the submitted callable in-process for
110
+ # fuzzing, but we wrap each call in a SIGALRM-based timeout and we already
111
+ # proved at definition-time that the import didn't blow up.
112
+
113
class _CallTimeout(Exception):
    """Raised when a guarded call exceeds its wall-clock budget."""


def _call_in_helper_thread(fn: Callable, arg: Any, timeout_s: float):
    """Run `fn(arg)` in a daemon thread and wait at most `timeout_s` seconds.

    Used where SIGALRM is unavailable. A thread cannot be killed, so a
    runaway call keeps running in the background after `_CallTimeout` is
    raised; only the caller is unblocked.
    """
    import threading

    outcome: dict = {}

    def _runner():
        try:
            outcome["value"] = fn(arg)
        except BaseException as exc:  # noqa: BLE001 - re-raised in the caller
            outcome["error"] = exc

    worker = threading.Thread(target=_runner, daemon=True)
    worker.start()
    # A non-positive budget means "no limit", matching setitimer(..., 0).
    worker.join(timeout_s if timeout_s > 0 else None)
    if worker.is_alive():
        raise _CallTimeout()
    if "error" in outcome:
        raise outcome["error"]
    return outcome["value"]


def _call_with_timeout(fn: Callable, arg: Any, timeout_s: float):
    """Call `fn(arg)` and raise `_CallTimeout` if it runs longer than
    `timeout_s` seconds.

    On the main thread this uses a SIGALRM interval timer. Off the main
    thread (for example in a web server's worker thread), or on platforms
    without SIGALRM, it falls back to a helper thread. The fallback exists
    because `signal.signal` raises ValueError outside the main thread, and
    that ValueError would otherwise be mistaken for an error raised by `fn`.

    Returns:
        Whatever `fn(arg)` returns.

    Raises:
        _CallTimeout: the call exceeded its budget.
        Exception: anything `fn` itself raises is propagated unchanged.
    """
    import threading

    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread or not hasattr(signal, "SIGALRM"):
        return _call_in_helper_thread(fn, arg, timeout_s)

    def _handler(signum, frame):  # noqa: ARG001
        raise _CallTimeout()

    previous = signal.signal(signal.SIGALRM, _handler)
    try:
        # Arm the timer inside the try so the previous handler is restored
        # even if setitimer rejects the value.
        signal.setitimer(signal.ITIMER_REAL, timeout_s)
        return fn(arg)
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous)
128
+
129
+
130
def _safe_call(fn: Callable, arg: Any, timeout_s: float):
    """Invoke `fn(arg)` under a time limit without ever raising.

    Returns a `(kind, value)` pair where kind is one of:
      'val'     - the call returned; value is the return value
      'timeout' - the call ran past `timeout_s`; value is a message
      'err'     - the call raised; value is "<ExceptionType>: <message>"
    """
    try:
        result = _call_with_timeout(fn, arg, timeout_s)
    except _CallTimeout:
        return ("timeout", f"timed out after {timeout_s}s")
    except Exception as exc:  # noqa: BLE001
        return ("err", f"{type(exc).__name__}: {exc}")
    else:
        return ("val", result)
138
+
139
+
140
+ # ----- Public scoring ------------------------------------------------------
141
+
142
+
143
@dataclass
class VerificationResult:
    """Outcome of scoring one submission."""

    # 100 * matches / fuzz_count; 0.0 when the code could not be defined.
    execution_reward: float
    # Log-scaled cyclomatic complexity of the submitted source.
    complexity_penalty: float
    # None on success, otherwise why the definition step failed.
    define_error: Optional[str]
    # Number of fuzz inputs where the submission agreed with the reference.
    matches: int
    # Number of fuzz inputs the score is based on.
    fuzz_count: int
150
+
151
+
152
def verify_submission(
    submitted_code: str,
    target_function: Callable[[Any], Any],
    fuzz_inputs: List[Any],
    *,
    target_name: Optional[str] = None,
    define_timeout_s: float = 5.0,
    call_timeout_s: float = 1.0,
) -> VerificationResult:
    """Score `submitted_code` against `target_function` over the supplied
    `fuzz_inputs`. The agent is expected to define a top-level function with
    the same name as `target_function` (overridable via `target_name`).

    Args:
        submitted_code: Untrusted Python source from the agent.
        target_function: Reference implementation, called with one argument.
        fuzz_inputs: Inputs fed to both reference and submission.
        target_name: Name to look up in the submission; defaults to
            `target_function.__name__`.
        define_timeout_s: Budget for the out-of-process definition check.
        call_timeout_s: Budget for each individual call during fuzzing.

    Returns:
        A `VerificationResult`. If the code cannot be defined,
        `execution_reward` is 0.0 and `define_error` holds the reason.
    """
    name = target_name or target_function.__name__

    define_err = _can_define(submitted_code, name, define_timeout_s)
    complexity = calculate_complexity_penalty(submitted_code)

    submitted_fn = None
    if define_err is None:
        # Re-define in-process for fast fuzzing; each call is still
        # time-bounded below.
        # NOTE(security): this exec runs untrusted code with full builtins
        # in the server process. The child-process check above only proves
        # the definition terminates and produces the callable.
        namespace: dict = {
            "__builtins__": __builtins__,
            "__name__": "__opensleuth_submission__",
            "math": __import__("math"),
            "string": __import__("string"),
            "itertools": __import__("itertools"),
            "functools": __import__("functools"),
            "collections": __import__("collections"),
            "re": __import__("re"),
        }
        try:
            # One dict serves as both globals and locals. With two separate
            # dicts the code runs like a class body, so a recursive function
            # or one calling a sibling helper fails with NameError at call
            # time, even for a correct submission.
            exec(submitted_code, namespace)
        except Exception as e:  # noqa: BLE001
            # The child run succeeded but this one did not (e.g. the code is
            # non-deterministic); report it rather than crash.
            define_err = f"{type(e).__name__}: {e}"
        else:
            submitted_fn = namespace.get(name)
            if not callable(submitted_fn):
                define_err = f"No callable named {name!r} defined."

    if define_err is not None:
        return VerificationResult(
            execution_reward=0.0,
            complexity_penalty=complexity,
            define_error=define_err,
            matches=0,
            fuzz_count=len(fuzz_inputs),
        )

    matches = 0
    for inp in fuzz_inputs:
        ref = _safe_call(target_function, inp, call_timeout_s)
        sub = _safe_call(submitted_fn, inp, call_timeout_s)
        if _outputs_equivalent(ref, sub):
            matches += 1

    # `or 1` avoids a division by zero when no fuzz inputs were supplied.
    fuzz_count = len(fuzz_inputs) or 1
    exec_reward = 100.0 * (matches / fuzz_count)
    return VerificationResult(
        execution_reward=exec_reward,
        complexity_penalty=complexity,
        define_error=None,
        matches=matches,
        fuzz_count=fuzz_count,
    )
209
+
210
+
211
def _outputs_equivalent(ref, sub) -> bool:
    """Decide whether two `_safe_call` outcomes agree.

    `ref` and `sub` are `(kind, value)` tuples. They count as equivalent if
    both returned values that compare `==`, both raised the same exception
    type, or both timed out. Any other combination is a mismatch.

    Always returns a real `bool` and never raises, even when the compared
    values have an `__eq__` that raises or returns a non-boolean.
    """
    rkind, rval = ref
    skind, sval = sub
    if rkind != skind:
        return False
    if rkind == "val":
        try:
            # Coerce inside the guard: `==` may return an arbitrary object
            # whose truth value is itself ambiguous or raises.
            return bool(rval == sval)
        except Exception:  # noqa: BLE001
            return False
    if rkind == "err":
        # Error payloads look like "<ExceptionType>: <message>"; compare
        # only the exception class name before the first colon.
        return rval.partition(":")[0] == sval.partition(":")[0]
    return rkind == "timeout"
228
+
229
+
230
def generate_fuzz_inputs(
    spec, count: int = 100, seed: Optional[int] = None
) -> List[Any]:
    """Ask `spec.fuzzer` for `count` inputs, driven by a private RNG.

    A dedicated `random.Random(seed)` is passed to the fuzzer, so a fixed
    `seed` gives a repeatable RNG stream and the global `random` state is
    left untouched. `seed=None` lets the RNG seed itself.
    """
    return spec.fuzzer(random.Random(seed), count)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- fastapi
2
- uvicorn
3
- pydantic
 
1
+ fastapi==0.115.6
2
+ uvicorn[standard]==0.32.1
3
+ pydantic==2.10.3
server.py CHANGED
@@ -1,27 +1,76 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from opensleuth_env.env import OpenSleuthEnv
4
- from opensleuth_env.models import Action, Observation
5
 
6
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  env = OpenSleuthEnv()
8
 
9
- class ResetBody(BaseModel):
10
- target_name: str = "fibonacci"
11
-
12
- @app.post("/reset", response_model=Observation)
13
- def reset_env(body: ResetBody):
14
- # Ensure the environment is reset for a new session
15
- return env.reset(target_name=body.target_name)
16
-
17
- @app.post("/step")
18
- def step_env(action: Action):
19
- # The environment now handles the case where it's not reset
20
- obs, reward, done = env.step(action)
21
- return {"observation": obs, "reward": reward, "done": done}
22
-
23
- @app.get("/state")
24
- def get_state():
25
- if env.state is None:
26
- return {}
27
- return env.get_state()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server exposing the OpenSleuth environment over HTTP."""
 
 
 
2
 
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+
9
+ from opensleuth_env import (
10
+ BLACK_BOX_FUNCTIONS,
11
+ OpenSleuthEnv,
12
+ ProbeAction,
13
+ ResetRequest,
14
+ StepRequest,
15
+ StepResponse,
16
+ SubmitAction,
17
+ )
18
+
19
# Root logging setup for the server process: INFO level, timestamped lines.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
log = logging.getLogger("opensleuth.server")

# One FastAPI app and one environment instance shared by all endpoints.
app = FastAPI(title="OpenSleuth Env", version="0.2.0")
env = OpenSleuthEnv()
24
 
25
+
26
@app.get("/health")
def health():
    """Liveness probe; also reports how many episodes the env is holding."""
    tracked = len(env._states)  # noqa: SLF001
    return {"status": "ok", "episodes_tracked": tracked}
29
+
30
+
31
@app.get("/functions")
def list_functions():
    """List name, signature and description of every black-box function."""
    catalog = []
    for spec in BLACK_BOX_FUNCTIONS.values():
        entry = {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
        }
        catalog.append(entry)
    return {"functions": catalog}
43
+
44
+
45
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its first observation.

    A ValueError raised by the environment is reported as HTTP 400.
    """
    try:
        return env.reset(
            target_name=req.target_name,
            seed=req.seed,
            max_steps=req.max_steps,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
52
+
53
+
54
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an episode.

    A KeyError raised by the environment is reported as HTTP 404.
    """
    try:
        outcome = env.step(req.episode_id, req.action)
    except KeyError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    return outcome
60
+
61
+
62
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the state stored for `episode_id`, or HTTP 404 if the
    environment has nothing (or something falsy) for it."""
    state = env.get_state(episode_id)
    if state:
        return state
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
68
+
69
+
70
+ # Convenience: a flat /step that does reset+step in one call is occasionally
71
+ # useful for shell-style debugging.
72
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """Reset a fresh episode for `target_name` and run a single probe on it.

    Debugging convenience. Every call starts a new episode through
    `env.reset`, so it should not be used in a tight loop.

    Args:
        target_name: Black-box function to probe.
        input_repr: Python literal repr of the probe input.

    Raises:
        HTTPException: 400 when `env.reset` rejects the request with a
            ValueError, 404 when `env.step` raises a KeyError. These are
            the same mappings the /reset and /step endpoints use.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        # Without this, the rejection surfaces as an opaque HTTP 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    try:
        return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
    except KeyError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e