File size: 3,017 Bytes
31715b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Example: talk to the OpenSleuth env via the upstream OpenEnv client.

This script connects to the deployed Space using the canonical OpenEnv
``GenericEnvClient`` (HTTP+WebSocket) and runs one episode end-to-end.

Usage::

    pip install openenv-core==0.2.3
    python example_client.py                 # hits the deployed Space
    python example_client.py http://localhost:7860  # against a local server

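We hit the ``/openenv`` sub-app rather than the legacy bare routes, because
the OpenEnv client requires an OpenEnv-conformant ``/ws`` WebSocket. If the
URL you pass does not already end in ``/openenv``, the script appends it, so
the bare ``http://localhost:7860`` form above works. The legacy ``/reset``
and ``/step`` endpoints used by the in-flight trainer are preserved
unchanged at the root.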
"""

from __future__ import annotations

import asyncio
import sys

# Default server URL: the ``/openenv`` sub-app of the deployed Space.
DEFAULT_BASE = "https://anugrah55-opensleuth-env-gemini-cli.hf.space/openenv"


async def main(base_url: str) -> None:
    """Run one demo episode against the OpenEnv server at ``base_url``.

    The episode resets the env with the ``fibonacci`` target, sends five
    probe actions, submits a candidate implementation, and finally prints
    whatever ``env.state()`` returns. All progress is written to stdout.

    Args:
        base_url: URL passed to ``GenericEnvClient`` as ``base_url``.
    """
    # Function-scope import: openenv is only required once main() runs.
    from openenv import GenericAction, GenericEnvClient

    # Inputs sent one by one as ``probe`` actions.
    probe_inputs = ("1", "5", "10", "-1", "'oops'")

    # Source text sent with the final ``submit`` action.
    reference_impl = "".join(
        (
            "def fibonacci(n):\n",
            "    if not isinstance(n, int) or isinstance(n, bool) or n <= 0 or n > 90:\n",
            "        raise ValueError('bad')\n",
            "    a, b = 0, 1\n",
            "    for _ in range(n - 1):\n",
            "        a, b = b, a + b\n",
            "    return b\n",
        )
    )

    print(f"Connecting to {base_url} ...")
    async with GenericEnvClient(base_url=base_url) as env:
        # Start an episode on the 'fibonacci' target. The legacy OpenSleuth
        # reset kwargs (target_name / target_code / max_steps / ...) ride
        # along as extra fields, since OpenEnv's ResetRequest is declared
        # with extra='allow'.
        reset_kwargs = {"target_name": "fibonacci", "max_steps": 8, "seed": 42}
        reset_result = await env.reset(**reset_kwargs)
        first_obs = reset_result.observation
        print("\n[reset]")
        print(f"  episode_id = {first_obs['episode_id']}")
        print(
            f"  target = {first_obs['target_function_name']} "
            f"({first_obs['difficulty']})"
        )

        # Probe phase: report the newest probe_history entry after each step.
        for probe_input in probe_inputs:
            probe_action = GenericAction(
                action_type="probe", input_repr=probe_input
            )
            step = await env.step(probe_action)
            latest = step.observation["probe_history"][-1]
            print(
                f"[probe {probe_input!s:>8}] -> output={latest['output_repr']!r:>30} "
                f"reward={step.reward:+.2f} done={step.done}"
            )

        # Submit phase: send the candidate implementation and show the result.
        submit_action = GenericAction(action_type="submit", code=reference_impl)
        submission = await env.step(submit_action)
        info = submission.observation.get("info", {})
        print("\n[submit reference impl]")
        print(f"  reward = {submission.reward:.2f}")
        print(f"  done = {submission.done}")
        print(f"  info = {info}")

        # Sanity-check the state endpoint.
        env_state = await env.state()
        print(f"\n[state] {env_state}")


if __name__ == "__main__":
    # An optional first CLI argument overrides the default server URL.
    cli_args = sys.argv[1:]
    base = cli_args[0] if cli_args else DEFAULT_BASE

    # Ensure the URL targets the /openenv sub-app: when the URL, ignoring
    # trailing slashes, does not end with it, strip the slashes and append it.
    trimmed = base.rstrip("/")
    if not trimmed.endswith("/openenv"):
        base = trimmed + "/openenv"

    asyncio.run(main(base))