anugrah55 committed on
Commit
9030acd
·
verified ·
1 Parent(s): 77e65fb

env: add GET /tasks/{name}/sample_inputs for trainer-side fuzz delegation

Browse files
Files changed (2) hide show
  1. server.py +38 -1
  2. tests/test_open_env.py +500 -0
server.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import logging
 
6
  from typing import Optional
7
 
8
  from fastapi import FastAPI, HTTPException, Query
@@ -17,11 +18,12 @@ from opensleuth_env import (
17
  SubmitAction,
18
  TaskCatalog,
19
  )
 
20
 
21
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
22
  log = logging.getLogger("opensleuth.server")
23
 
24
- app = FastAPI(title="OpenSleuth Env", version="0.4.0")
25
  env = OpenSleuthEnv()
26
 
27
 
@@ -141,3 +143,38 @@ def probe_once(target_name: str, input_repr: str):
141
  obs = env.reset(target_name=target_name)
142
  resp = env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
143
  return resp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from __future__ import annotations
4
 
5
  import logging
6
+ import random
7
  from typing import Optional
8
 
9
  from fastapi import FastAPI, HTTPException, Query
 
18
  SubmitAction,
19
  TaskCatalog,
20
  )
21
+ from opensleuth_env.task_catalog import TaskResolutionError
22
 
23
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
24
  log = logging.getLogger("opensleuth.server")
25
 
26
+ app = FastAPI(title="OpenSleuth Env", version="0.4.1")
27
  env = OpenSleuthEnv()
28
 
29
 
 
143
  obs = env.reset(target_name=target_name)
144
  resp = env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
145
  return resp
146
+
147
+
148
@app.get("/tasks/{name}/sample_inputs")
def sample_inputs(
    name: str,
    n: int = Query(8, ge=1, le=64, description="How many inputs to draw."),
    seed: int = Query(0, description="Deterministic seed for the fuzzer."),
):
    """Draw ``n`` sample inputs from the task's fuzzer as ``repr`` strings.

    The trainer calls this to build in-context probe pools without
    re-implementing the auto-fuzzer on its side. Every string in ``inputs``
    is ``ast.literal_eval``-safe and can be POSTed back verbatim to ``/step``
    as a ``ProbeAction.input_repr``.
    """
    # Unknown task name -> 404 instead of an opaque 500 from inside the env.
    try:
        spec = env.catalog.resolve(target_name=name)
    except TaskResolutionError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # A dedicated RNG keyed on `seed` makes the draw reproducible per call.
    rng = random.Random(seed)
    try:
        drawn = spec.fuzzer(rng, n)
    except Exception as e:  # noqa: BLE001
        # Surface fuzzer crashes as a 500 with the original error attached.
        raise HTTPException(
            status_code=500,
            detail=f"fuzzer for {name!r} failed: {type(e).__name__}: {e}",
        ) from e

    response = {
        "name": name,
        "n": n,
        "seed": seed,
        # Tells the caller whether each input is an args-tuple to unpack.
        "unpack_args": bool(getattr(spec, "unpack_args", False)),
        "inputs": [repr(sample) for sample in drawn],
    }
    return response
tests/test_open_env.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for OpenSleuth Level 2: auto-fuzzer + TaskCatalog + open /reset.
2
+
3
+ These tests do *not* require Hub network access. The Hub-availability test
4
+ is opportunistic: it asserts ``>=15`` total tasks if the dataset loads, but
5
+ silently passes (with a marker) if the Hub is offline / the env is sandboxed.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ import random
12
+ import typing
13
+ from typing import Optional, Literal
14
+
15
+ import pytest
16
+ from fastapi.testclient import TestClient
17
+
18
+ from opensleuth_env import (
19
+ BLACK_BOX_FUNCTIONS,
20
+ OpenSleuthEnv,
21
+ ProbeAction,
22
+ SubmitAction,
23
+ TaskCatalog,
24
+ TaskResolutionError,
25
+ auto_fuzz,
26
+ )
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Auto-fuzzer
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
class TestAutoFuzzerTypes:
    """``auto_fuzz`` must derive value generators from parameter annotations.

    Each test declares a tiny annotated function, draws a batch of inputs,
    and checks every draw matches the annotated type. Draws come back as
    full argument tuples (one tuple per call), even for unary functions.
    """

    def _rng(self, seed: int = 0) -> random.Random:
        # Fresh deterministically-seeded RNG per test for reproducible draws.
        return random.Random(seed)

    def test_int_inputs_are_ints(self):
        def f(n: int) -> int:
            return n

        rng = self._rng()
        outs = auto_fuzz(f, 50, rng)
        assert len(outs) == 50
        assert all(isinstance(t, tuple) and len(t) == 1 for t in outs)
        # bool is a subclass of int, so it must be excluded explicitly.
        assert all(isinstance(t[0], int) and not isinstance(t[0], bool) for t in outs)

    def test_str_inputs_are_strs(self):
        def f(s: str) -> int:
            return len(s)

        outs = auto_fuzz(f, 30, self._rng())
        assert all(isinstance(t[0], str) for t in outs)

    def test_list_int_inputs_are_lists_of_ints(self):
        def f(xs: list[int]) -> int:
            return sum(xs)

        outs = auto_fuzz(f, 30, self._rng())
        for (xs,) in outs:
            assert isinstance(xs, list)
            assert all(isinstance(x, int) for x in xs)

    def test_homogeneous_tuple_inputs(self):
        # tuple[int, ...] -> variable-length tuple of ints.
        def f(xs: tuple[int, ...]) -> int:
            return sum(xs)

        outs = auto_fuzz(f, 30, self._rng())
        for (xs,) in outs:
            assert isinstance(xs, tuple)
            assert all(isinstance(x, int) for x in xs)

    def test_heterogeneous_tuple_inputs(self):
        # tuple[int, str] -> fixed-length tuple with per-slot types.
        def f(t: tuple[int, str]) -> int:
            return len(t[1])

        outs = auto_fuzz(f, 30, self._rng())
        for (t,) in outs:
            assert isinstance(t, tuple) and len(t) == 2
            assert isinstance(t[0], int)
            assert isinstance(t[1], str)

    def test_optional_inputs_sometimes_None(self):
        def f(x: Optional[int]) -> int:
            return 0

        # 200 draws so both the None branch and the int branch show up.
        outs = auto_fuzz(f, 200, self._rng(seed=42))
        seen_none = any(t[0] is None for t in outs)
        seen_int = any(isinstance(t[0], int) and not isinstance(t[0], bool) for t in outs)
        assert seen_none, "Optional[int] should occasionally yield None"
        assert seen_int, "Optional[int] should also yield ints"

    def test_literal_inputs_only_pick_listed_values(self):
        def f(mode: Literal["a", "b", "c"]) -> int:
            return 0

        outs = auto_fuzz(f, 50, self._rng())
        for (m,) in outs:
            assert m in ("a", "b", "c")

    def test_dict_str_int_inputs(self):
        def f(d: dict[str, int]) -> int:
            return len(d)

        outs = auto_fuzz(f, 20, self._rng())
        for (d,) in outs:
            assert isinstance(d, dict)
            for k, v in d.items():
                assert isinstance(k, str)
                assert isinstance(v, int)

    def test_multi_arg_returns_full_tuples(self):
        # Multi-parameter functions get one tuple entry per parameter.
        def f(a: int, b: str) -> int:
            return 0

        outs = auto_fuzz(f, 20, self._rng())
        for t in outs:
            assert isinstance(t, tuple)
            assert len(t) == 2
            assert isinstance(t[0], int)
            assert isinstance(t[1], str)

    def test_unannotated_param_falls_back_to_int(self):
        def f(x): # no annotation
            return x

        outs = auto_fuzz(f, 30, self._rng())
        for (x,) in outs:
            assert isinstance(x, int)
130
+
131
+
132
class TestAutoFuzzerSpecOverride:
    """A caller-supplied ``fuzz_spec`` must override the annotation-derived
    defaults, parameter by parameter (range, alphabet, length, element spec).
    """

    def test_int_min_max_overrides_default_range(self):
        def f(n: int) -> int:
            return n

        outs = auto_fuzz(f, 100, random.Random(0), fuzz_spec={"n": {"type": "int", "min": 1, "max": 5}})
        for (n,) in outs:
            assert 1 <= n <= 5, f"expected n in [1, 5], got {n}"

    def test_str_alphabet_override(self):
        def f(s: str) -> int:
            return len(s)

        outs = auto_fuzz(
            f, 100, random.Random(0),
            fuzz_spec={"s": {"type": "str", "alphabet": "ab", "max_len": 4}},
        )
        for (s,) in outs:
            assert len(s) <= 4
            for ch in s:
                assert ch in "ab", f"unexpected char {ch!r} in {s!r}"

    def test_list_elem_override(self):
        def f(xs: list[int]) -> int:
            return sum(xs)

        # Both the element spec and the container max_len must be honored.
        outs = auto_fuzz(
            f, 80, random.Random(0),
            fuzz_spec={"xs": {"type": "list", "elem": {"type": "int", "min": 0, "max": 3}, "max_len": 4}},
        )
        for (xs,) in outs:
            assert len(xs) <= 4
            for v in xs:
                assert 0 <= v <= 3

    def test_tuple_elems_override(self):
        # Unannotated parameter: the spec alone drives generation here.
        def f(t):
            return t

        outs = auto_fuzz(
            f, 30, random.Random(0),
            fuzz_spec={"t": {"type": "tuple", "elems": [
                {"type": "int", "min": 0, "max": 1},
                {"type": "str", "alphabet": "x", "max_len": 2},
            ]}},
        )
        for (t,) in outs:
            assert isinstance(t, tuple) and len(t) == 2
            assert 0 <= t[0] <= 1
            for ch in t[1]:
                assert ch == "x"
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # TaskCatalog
187
+ # ---------------------------------------------------------------------------
188
+
189
+
190
class TestTaskCatalog:
    """TaskCatalog resolution: builtins by name, caller-supplied code,
    error cases, and the sandboxing of user-supplied target code.

    All tests run with ``enable_hub=False`` so no network is needed.
    """

    def test_resolves_builtin_by_name(self):
        cat = TaskCatalog(enable_hub=False)
        spec = cat.resolve(target_name="fibonacci")
        assert spec.name == "fibonacci"
        # Builtins resolve to the shared registry object, not a copy.
        assert spec is BLACK_BOX_FUNCTIONS["fibonacci"]
        assert spec.unpack_args is False
        assert spec.source == "builtin"

    def test_resolves_caller_supplied_target_code(self):
        cat = TaskCatalog(enable_hub=False)
        code = "def add(a: int, b: int) -> int:\n return a + b\n"
        spec = cat.resolve(target_code=code, target_function_name="add")
        assert spec.name == "add"
        assert spec.unpack_args is True # 2-arg
        assert spec.source == "user"
        # The wrapped fuzzer must produce calls that succeed end-to-end.
        rng = random.Random(0)
        inputs = spec.fuzzer(rng, 10)
        for args in inputs:
            assert isinstance(args, tuple) and len(args) == 2
            assert spec.fn(*args) == args[0] + args[1]

    def test_caller_supplied_unary_uses_unwrapped_call(self):
        cat = TaskCatalog(enable_hub=False)
        code = "def square(n: int) -> int:\n return n * n\n"
        spec = cat.resolve(target_code=code, target_function_name="square")
        # Unary targets take bare values, not 1-tuples.
        assert spec.unpack_args is False
        rng = random.Random(0)
        inputs = spec.fuzzer(rng, 5)
        for x in inputs:
            assert isinstance(x, int)
            assert spec.fn(x) == x * x

    def test_resolve_with_no_source_raises(self):
        cat = TaskCatalog(enable_hub=False)
        with pytest.raises(TaskResolutionError):
            cat.resolve()

    def test_resolve_unknown_name_raises(self):
        cat = TaskCatalog(enable_hub=False)
        with pytest.raises(TaskResolutionError):
            cat.resolve(target_name="this_does_not_exist")

    def test_target_code_without_function_name_raises(self):
        cat = TaskCatalog(enable_hub=False)
        with pytest.raises(TaskResolutionError):
            cat.resolve(target_code="def foo(): return 1\n")

    def test_rejects_oracle_import(self):
        # User code must not be able to import the env package (the oracle).
        cat = TaskCatalog(enable_hub=False)
        bad = (
            "import opensleuth_env\n"
            "def f(x): return x\n"
        )
        with pytest.raises(TaskResolutionError):
            cat.resolve(target_code=bad, target_function_name="f")

        bad2 = (
            "from opensleuth_env.black_box import _fibonacci\n"
            "def f(x): return _fibonacci(x)\n"
        )
        with pytest.raises(TaskResolutionError):
            cat.resolve(target_code=bad2, target_function_name="f")

    def test_target_code_using_open_is_blocked_at_call_time(self):
        """`open` is not in the safe-builtins whitelist. The catalog will
        compile the function (since `open` is only resolved at call-time
        via NameError), but invoking it must fail safely."""
        cat = TaskCatalog(enable_hub=False)
        code = (
            "def f(x):\n"
            " open('/tmp/x', 'w')\n"
            " return 0\n"
        )
        spec = cat.resolve(target_code=code, target_function_name="f")
        with pytest.raises(NameError):
            spec.fn(0)

    def test_caller_supplied_edge_cases_are_parsed(self):
        cat = TaskCatalog(enable_hub=False)
        spec = cat.resolve(
            target_code="def neg(n: int) -> int:\n return -n\n",
            target_function_name="neg",
            edge_cases=["0", "1", "-1", "100"],
        )
        # Edge cases arrive as repr strings and are parsed to real values.
        assert spec.edge_cases == [0, 1, -1, 100]

    def test_caller_supplied_fuzz_spec_is_used(self):
        cat = TaskCatalog(enable_hub=False)
        spec = cat.resolve(
            target_code="def f(n: int) -> int:\n return n\n",
            target_function_name="f",
            fuzz_spec={"n": {"type": "int", "min": 7, "max": 9}},
        )
        rng = random.Random(0)
        inputs = spec.fuzzer(rng, 50)
        for x in inputs:
            assert 7 <= x <= 9

    def test_list_builtin_returns_nine_entries(self):
        cat = TaskCatalog(enable_hub=False)
        builtins_list = cat.list_builtin()
        assert len(builtins_list) == 9
        for entry in builtins_list:
            assert entry["source"] == "builtin"
            assert "name" in entry
            assert "signature" in entry
            assert "difficulty" in entry
299
+
300
+
301
+ # ---------------------------------------------------------------------------
302
+ # End-to-end via OpenSleuthEnv
303
+ # ---------------------------------------------------------------------------
304
+
305
+
306
class TestEnvOpenEnded:
    """End-to-end episodes through ``OpenSleuthEnv``: legacy named targets
    plus the new caller-supplied ``target_code`` path (probe -> submit).
    """

    def test_legacy_reset_by_target_name_unchanged(self):
        env = OpenSleuthEnv(fuzz_count=10)
        obs = env.reset(target_name="fibonacci")
        assert obs.target_function_name == "fibonacci"
        assert obs.difficulty == "easy"
        assert obs.steps_taken == 0

        # Probe via the same path as before.
        resp = env.step(obs.episode_id, ProbeAction(input_repr="10"))
        assert resp.observation.probe_history[-1].output_repr == "55"

    def test_env_caller_supplied_unary_full_loop(self):
        env = OpenSleuthEnv(fuzz_count=10)
        obs = env.reset(
            target_code="def square(n: int) -> int:\n return n * n\n",
            target_function_name="square",
        )
        assert obs.target_function_name == "square"

        # Probe.
        resp = env.step(obs.episode_id, ProbeAction(input_repr="5"))
        assert resp.observation.probe_history[-1].output_repr == "25"

        # Submit a perfect implementation.
        code = "def square(n):\n return n * n\n"
        resp = env.step(obs.episode_id, SubmitAction(code=code))
        assert resp.done is True
        assert resp.info["execution_reward"] == pytest.approx(100.0)
        assert resp.reward > 140.0

    def test_env_caller_supplied_multi_arg_full_loop(self):
        env = OpenSleuthEnv(fuzz_count=10)
        obs = env.reset(
            target_code="def add(a: int, b: int) -> int:\n return a + b\n",
            target_function_name="add",
            edge_cases=["(0, 0)", "(1, -1)", "(100, 0)"],
        )
        assert obs.target_function_name == "add"

        # Probe with a 2-tuple.
        resp = env.step(obs.episode_id, ProbeAction(input_repr="(2, 3)"))
        assert resp.observation.probe_history[-1].output_repr == "5"

        # Submit a perfect implementation.
        code = "def add(a, b):\n return a + b\n"
        resp = env.step(obs.episode_id, SubmitAction(code=code))
        assert resp.done is True
        assert resp.info["execution_reward"] == pytest.approx(100.0)
        assert resp.reward > 140.0

    def test_env_caller_supplied_buggy_submission_scored_negative(self):
        env = OpenSleuthEnv(fuzz_count=10)
        obs = env.reset(
            target_code="def add(a: int, b: int) -> int:\n return a + b\n",
            target_function_name="add",
        )
        # Subtraction instead of addition: should score well below passing.
        bad = "def add(a, b):\n return a - b\n"
        resp = env.step(obs.episode_id, SubmitAction(code=bad))
        assert resp.done is True
        assert resp.info["execution_reward"] < 50.0
        assert resp.reward < 0.0

    def test_env_caller_supplied_oracle_import_rejected(self):
        # The env surfaces catalog rejections as ValueError at reset time.
        env = OpenSleuthEnv()
        with pytest.raises(ValueError):
            env.reset(
                target_code="import opensleuth_env\ndef f(x): return x\n",
                target_function_name="f",
            )
376
+
377
+
378
+ # ---------------------------------------------------------------------------
379
+ # HTTP layer
380
+ # ---------------------------------------------------------------------------
381
+
382
+
383
@pytest.fixture(scope="module")
def http_client():
    """Module-scoped FastAPI test client for the live server app."""
    # Imported lazily so merely collecting this module never builds the app.
    from server import app

    with TestClient(app) as c:
        yield c
389
+
390
+
391
class TestHttpLayer:
    """HTTP surface tests: /tasks, /tasks/{name}/sample_inputs, /reset,
    /step, and the legacy /functions endpoint, all via TestClient.
    """

    def test_tasks_endpoint_lists_at_least_nine_builtin(self, http_client):
        r = http_client.get("/tasks?source=builtin")
        assert r.status_code == 200
        body = r.json()
        assert body["count"] >= 9
        names = [t["name"] for t in body["tasks"]]
        for name in BLACK_BOX_FUNCTIONS:
            assert name in names

    def test_tasks_all_includes_at_least_builtins(self, http_client):
        r = http_client.get("/tasks?source=all")
        assert r.status_code == 200
        body = r.json()
        # The builtins are always present. If the Hub is reachable we'd
        # expect 15+, but the test must pass even if Hub is unavailable
        # (e.g. CI sandboxes block egress).
        assert body["count"] >= 9
        if not body["hub"].get("enabled", False) or body["hub"].get("error"):
            pytest.skip(f"hub not reachable: {body['hub']}")
        # Hub reachable -> dataset should have 15+ rows after bootstrap.
        assert body["count"] >= 15

    def test_sample_inputs_returns_n_repr_strings_for_builtin(self, http_client):
        r = http_client.get("/tasks/fibonacci/sample_inputs?n=5&seed=7")
        assert r.status_code == 200, r.text
        body = r.json()
        assert body["name"] == "fibonacci"
        assert body["n"] == 5
        assert body["seed"] == 7
        assert isinstance(body["inputs"], list)
        assert len(body["inputs"]) == 5
        # Every returned string must be ast.literal_eval-safe so the trainer
        # can post it straight back to /step as a probe input_repr.
        import ast
        for s in body["inputs"]:
            assert isinstance(s, str)
            ast.literal_eval(s)
        # Determinism: same seed -> identical inputs.
        r2 = http_client.get("/tasks/fibonacci/sample_inputs?n=5&seed=7")
        assert r2.json()["inputs"] == body["inputs"]

    def test_sample_inputs_unknown_target_404s(self, http_client):
        r = http_client.get("/tasks/__nope__/sample_inputs?n=2&seed=0")
        assert r.status_code == 404

    def test_reset_legacy_target_name_still_works(self, http_client):
        r = http_client.post("/reset", json={
            "target_name": "fibonacci", "seed": 0, "max_steps": 10,
        })
        assert r.status_code == 200
        body = r.json()
        assert body["target_function_name"] == "fibonacci"
        assert "fibonacci" in body["target_function_signature"]

    def test_reset_caller_supplied_target_code(self, http_client):
        payload = {
            "target_code": "def add(a: int, b: int) -> int:\n return a + b\n",
            "target_function_name": "add",
            "edge_cases": ["(0, 0)", "(1, -1)"],
            "max_steps": 5,
        }
        r = http_client.post("/reset", json=payload)
        assert r.status_code == 200, r.text
        body = r.json()
        assert body["target_function_name"] == "add"
        eid = body["episode_id"]

        # Probe -> verify wrapping.
        r = http_client.post("/step", json={
            "episode_id": eid,
            "action": {"action_type": "probe", "input_repr": "(7, 8)"},
        })
        assert r.status_code == 200, r.text
        body = r.json()
        assert body["observation"]["probe_history"][-1]["output_repr"] == "15"

        # Submit perfect.
        r = http_client.post("/step", json={
            "episode_id": eid,
            "action": {"action_type": "submit", "code": "def add(a, b):\n return a + b\n"},
        })
        assert r.status_code == 200, r.text
        body = r.json()
        assert body["done"] is True
        assert body["info"]["execution_reward"] == pytest.approx(100.0)
        assert body["reward"] > 140.0

    def test_reset_with_neither_target_returns_400(self, http_client):
        r = http_client.post("/reset", json={"seed": 0})
        assert r.status_code == 400

    def test_reset_with_target_code_only_no_function_name_returns_400(self, http_client):
        r = http_client.post("/reset", json={
            "target_code": "def f(): return 1\n",
        })
        assert r.status_code == 400

    def test_functions_endpoint_unchanged_for_trainer(self, http_client):
        r = http_client.get("/functions")
        assert r.status_code == 200
        body = r.json()
        assert "functions" in body
        names = [f["name"] for f in body["functions"]]
        for name in BLACK_BOX_FUNCTIONS:
            assert name in names
        # The original v0.3 fields must all be present.
        for entry in body["functions"]:
            for k in ("name", "signature", "description", "difficulty", "edge_case_count"):
                assert k in entry