File size: 4,518 Bytes
d597642
 
 
 
 
 
 
78575eb
d597642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78575eb
 
 
 
 
 
 
 
 
 
 
 
 
 
d597642
 
 
 
 
 
78575eb
d597642
 
 
 
78575eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d597642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Thin HTTP client for the OpenSleuth env Space."""

from __future__ import annotations

import logging
import os
import time
from typing import Any, Dict, List, Optional

import requests

# Module-level logger; EnvClient emits a warning through it for every failed request.
log = logging.getLogger("opensleuth.client")


class EnvClient:
    """Blocking HTTP client for the OpenSleuth env service.

    Most endpoint methods go through a shared retry helper that applies
    exponential backoff and raises ``RuntimeError`` once every attempt has
    failed. ``health`` and ``list_functions`` are single-shot calls and let
    ``requests`` exceptions propagate unchanged.
    """

    def __init__(self, base_url: Optional[str] = None, timeout: float = 30.0, retries: int = 3):
        """Configure the client.

        Args:
            base_url: Root URL of the env service. Falls back to the
                ``ENV_URL`` environment variable, then to
                ``http://127.0.0.1:7860``. Trailing slashes are stripped.
            timeout: Per-request timeout handed to ``requests``.
            retries: Total number of attempts made by the retrying helpers.
                Values below 1 are treated as a single attempt.
        """
        self.base_url = (base_url or os.environ.get("ENV_URL", "http://127.0.0.1:7860")).rstrip("/")
        self.timeout = timeout
        self.retries = retries

    def _request(self, verb: str, send: Any, path: str, **kwargs: Any) -> Dict[str, Any]:
        """Call ``send`` against ``path`` with retries and return the decoded JSON body.

        Args:
            verb: HTTP verb name, used only in log and error messages.
            send: Callable performing the request (``requests.post`` or
                ``requests.get``).
            path: URL path appended to ``self.base_url``.
            **kwargs: Extra keyword arguments forwarded to ``send``.

        Raises:
            RuntimeError: When every attempt failed. The last underlying
                exception is attached as ``__cause__``.
        """
        # Always try at least once, even if the client was built with retries <= 0.
        attempts = max(1, self.retries)
        url = f"{self.base_url}{path}"
        last_exc: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                r = send(url, timeout=self.timeout, **kwargs)
                r.raise_for_status()
                return r.json()
            except (requests.RequestException, ValueError) as e:  # noqa: PERF203
                # ValueError covers a response body that is not valid JSON.
                last_exc = e
                if attempt + 1 == attempts:
                    # Nothing left to retry: skip the pointless backoff sleep.
                    log.warning("env %s %s failed (%s); no attempts left", verb, path, e)
                    break
                wait = 0.5 * (2 ** attempt)  # 0.5s, 1s, 2s, ...
                log.warning("env %s %s failed (%s); retrying in %.1fs", verb, path, e, wait)
                time.sleep(wait)
        raise RuntimeError(
            f"env {verb} {path} failed after {self.retries} retries: {last_exc}"
        ) from last_exc

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``path`` with retries; return the JSON response."""
        return self._request("POST", requests.post, path, json=payload)

    def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """GET ``path`` with optional query ``params`` with retries; return the JSON response."""
        return self._request("GET", requests.get, path, params=params)

    def health(self) -> Dict[str, Any]:
        """GET ``/health`` once (no retry) and return the decoded JSON body.

        ``requests`` exceptions, including HTTP error statuses, propagate
        to the caller unchanged.
        """
        r = requests.get(f"{self.base_url}/health", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def list_functions(self) -> List[Dict[str, str]]:
        """Legacy v0.3 endpoint -- only the 9 builtin functions.

        Single-shot (no retry); returns the ``functions`` entry of the
        ``/functions`` response.
        """
        r = requests.get(f"{self.base_url}/functions", timeout=self.timeout)
        r.raise_for_status()
        return r.json()["functions"]

    def list_tasks(
        self,
        source: str = "all",
        difficulty: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """v0.4 catalog endpoint -- builtins + Hub-driven tasks.

        Each item carries: ``name``, ``signature``, ``description``,
        ``difficulty`` (``easy|medium|hard|None``), ``edge_case_count``,
        ``source`` (``builtin|hub``).

        ``difficulty`` is only sent as a query parameter when it is truthy.
        """
        params: Dict[str, Any] = {"source": source}
        if difficulty:
            params["difficulty"] = difficulty
        return self._get("/tasks", params=params)["tasks"]

    def sample_inputs(self, target_name: str, n: int = 8, seed: int = 0) -> List[str]:
        """Pull ``n`` ready-to-probe input_repr strings from the env's own
        auto-fuzzer. Encapsulates the fuzz logic on the env side so the
        trainer doesn't have to keep its own per-task input pools in sync."""
        resp = self._get(
            f"/tasks/{target_name}/sample_inputs",
            params={"n": n, "seed": seed},
        )
        return list(resp["inputs"])

    def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Dict[str, Any]:
        """POST ``/reset`` for ``target_name`` and return the JSON response.

        ``score_submission`` reads ``episode_id`` from this response.
        """
        return self._post("/reset", {"target_name": target_name, "seed": seed, "max_steps": max_steps})

    def step(self, episode_id: str, action: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``action`` for ``episode_id`` to ``/step`` and return the JSON response."""
        return self._post("/step", {"episode_id": episode_id, "action": action})

    # --- High-level helpers used by the reward function --------------------

    def submit(self, episode_id: str, code: str) -> Dict[str, Any]:
        """Send a ``submit`` action carrying ``code`` for the given episode."""
        return self.step(episode_id, {"action_type": "submit", "code": code})

    def probe(self, episode_id: str, input_repr: str) -> Dict[str, Any]:
        """Send a ``probe`` action carrying ``input_repr`` for the given episode."""
        return self.step(episode_id, {"action_type": "probe", "input_repr": input_repr})

    def score_submission(self, target_name: str, code: str, seed: int = 0) -> float:
        """One-shot: open an episode, submit the code, return total reward.

        The episode is opened with ``max_steps=2``; the result is the
        ``reward`` field of the submit response converted to ``float``.
        """
        ep = self.reset(target_name=target_name, seed=seed, max_steps=2)
        resp = self.submit(ep["episode_id"], code)
        return float(resp["reward"])