File size: 4,287 Bytes
b74674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Local in-process REPL helper.

This module is intentionally separate from `client.py` so the remote client
module does not import anything from `server/`.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Optional

try:
    from openenv.core.client_types import StepResult

    from .models import REPLAction, REPLObservation, REPLState
    from .rubrics import REPLRubric
    from .server.repl_environment import REPLEnvironment
except ImportError:
    from models import REPLAction, REPLObservation, REPLState
    from openenv.core.client_types import StepResult
    from rubrics import REPLRubric
    from server.repl_environment import REPLEnvironment


class LocalREPLEnv:
    """Explicit in-process REPL helper for local experimentation.

    Wraps a ``REPLEnvironment`` instance directly (no remote client involved)
    and exposes a client-like API: every stepping method returns a
    ``StepResult[REPLObservation]`` built from the environment's observation.
    The helper is also a context manager; leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(
        self,
        *,
        llm_query_fn: Optional[Callable[[str], str]] = None,
        llm_batch_fn: Optional[Callable[[list[str]], list[str]]] = None,
        subcall_fn: Optional[Callable[[str, Optional[str]], str]] = None,
        subcall_batch_fn: Optional[
            Callable[[list[str], Optional[str]], list[str]]
        ] = None,
        max_output_length: int = 8192,
        context_preview_length: int = 500,
        rubric: Optional[REPLRubric] = None,
        rlm_max_depth: int = 1,
        rlm_max_iterations: Optional[int] = None,
    ) -> None:
        """Create the wrapped ``REPLEnvironment``.

        Every keyword argument is forwarded unchanged to the
        ``REPLEnvironment`` constructor; refer to that class for the exact
        meaning of each option (LLM callbacks, output/preview limits,
        rubric, and recursion limits).
        """
        self._env = REPLEnvironment(
            max_output_length=max_output_length,
            context_preview_length=context_preview_length,
            rubric=rubric,
            llm_query_fn=llm_query_fn,
            llm_batch_fn=llm_batch_fn,
            subcall_fn=subcall_fn,
            subcall_batch_fn=subcall_batch_fn,
            rlm_max_depth=rlm_max_depth,
            rlm_max_iterations=rlm_max_iterations,
        )

    def reset(
        self,
        *,
        context: str = "",
        task_prompt: str = "",
        max_iterations: int = 30,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        llm_model: Optional[str] = None,
        expected_answer: Optional[str] = None,
        rlm_max_depth: Optional[int] = None,
        rlm_max_iterations: Optional[int] = None,
    ) -> StepResult[REPLObservation]:
        """Start a new episode and return its initial observation.

        Args:
            context: Forwarded to ``REPLEnvironment.reset``.
            task_prompt: Forwarded to ``REPLEnvironment.reset``.
            max_iterations: Assigned to the environment's ``max_iterations``
                attribute before the reset is performed.
            seed: Forwarded to ``REPLEnvironment.reset``.
            episode_id: Forwarded to ``REPLEnvironment.reset``.
            hf_token: Forwarded to ``REPLEnvironment.reset``.
            llm_model: Forwarded to ``REPLEnvironment.reset``.
            expected_answer: Forwarded only when not ``None``.
            rlm_max_depth: Forwarded only when not ``None``.
            rlm_max_iterations: Forwarded only when not ``None``.

        Returns:
            The reset observation wrapped in a ``StepResult``.
        """
        self._env.max_iterations = max_iterations
        # These are passed only when explicitly set, presumably so that the
        # environment's own defaults apply otherwise (TODO confirm against
        # REPLEnvironment.reset).
        optional_kwargs = {
            "rlm_max_depth": rlm_max_depth,
            "rlm_max_iterations": rlm_max_iterations,
            "expected_answer": expected_answer,
        }
        reset_kwargs = {
            key: value for key, value in optional_kwargs.items() if value is not None
        }
        obs = self._env.reset(
            seed=seed,
            episode_id=episode_id,
            context=context,
            task_prompt=task_prompt,
            hf_token=hf_token,
            llm_model=llm_model,
            **reset_kwargs,
        )
        return self._wrap_observation(obs)

    def step(self, action: REPLAction) -> StepResult[REPLObservation]:
        """Apply ``action`` to the environment and wrap the resulting observation."""
        return self._wrap_observation(self._env.step(action))

    def execute(self, code: str) -> StepResult[REPLObservation]:
        """Run ``code`` in the REPL as a single (non-final) step."""
        return self.step(REPLAction(code=code))

    def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]:
        """Submit ``answer`` as the final answer with an empty code payload."""
        return self.step(REPLAction(code="", is_final=True, final_answer=answer))

    def get_variable(self, name: str) -> StepResult[REPLObservation]:
        """Print ``repr(name)`` inside the REPL and return that step's result.

        NOTE: ``name`` is interpolated verbatim into executed source code, so
        any Python expression is accepted (e.g. ``"obj.attr"``). Never pass
        untrusted input here.
        """
        return self.execute(f"print(repr({name}))")

    def state(self) -> REPLState:
        """Return the wrapped environment's current ``state``."""
        return self._env.state

    def list_variables(self) -> list[str]:
        """Return the names in the REPL namespace as a fresh list.

        A copy is returned so callers cannot mutate the environment's state
        through the result.
        """
        return list(self.state().namespace_keys)

    def close(self) -> None:
        """Close the wrapped environment."""
        self._env.close()

    def __enter__(self) -> "LocalREPLEnv":
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: object,
    ) -> None:
        # Always close; returning None means exceptions are not suppressed.
        self.close()

    @staticmethod
    def _wrap_observation(obs: REPLObservation) -> StepResult[REPLObservation]:
        """Build a ``StepResult`` mirroring ``obs.reward`` and ``obs.done``."""
        return StepResult(
            observation=obs,
            reward=obs.reward,
            done=obs.done,
        )