Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Local in-process REPL helper. | |
| This module is intentionally separate from `client.py` so the remote client | |
| module does not import anything from `server/`. | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import Callable | |
| from typing import Optional | |
| try: | |
| from openenv.core.client_types import StepResult | |
| from .models import REPLAction, REPLObservation, REPLState | |
| from .rubrics import REPLRubric | |
| from .server.repl_environment import REPLEnvironment | |
| except ImportError: | |
| from models import REPLAction, REPLObservation, REPLState | |
| from openenv.core.client_types import StepResult | |
| from rubrics import REPLRubric | |
| from server.repl_environment import REPLEnvironment | |
| class LocalREPLEnv: | |
| """Explicit in-process REPL helper for local experimentation.""" | |
| def __init__( | |
| self, | |
| *, | |
| llm_query_fn: Optional[Callable[[str], str]] = None, | |
| llm_batch_fn: Optional[Callable[[list[str]], list[str]]] = None, | |
| subcall_fn: Optional[Callable[[str, Optional[str]], str]] = None, | |
| subcall_batch_fn: Optional[ | |
| Callable[[list[str], Optional[str]], list[str]] | |
| ] = None, | |
| max_output_length: int = 8192, | |
| context_preview_length: int = 500, | |
| rubric: Optional[REPLRubric] = None, | |
| rlm_max_depth: int = 1, | |
| rlm_max_iterations: int | None = None, | |
| ): | |
| self._env = REPLEnvironment( | |
| max_output_length=max_output_length, | |
| context_preview_length=context_preview_length, | |
| rubric=rubric, | |
| llm_query_fn=llm_query_fn, | |
| llm_batch_fn=llm_batch_fn, | |
| subcall_fn=subcall_fn, | |
| subcall_batch_fn=subcall_batch_fn, | |
| rlm_max_depth=rlm_max_depth, | |
| rlm_max_iterations=rlm_max_iterations, | |
| ) | |
| def reset( | |
| self, | |
| *, | |
| context: str = "", | |
| task_prompt: str = "", | |
| max_iterations: int = 30, | |
| seed: Optional[int] = None, | |
| episode_id: Optional[str] = None, | |
| hf_token: Optional[str] = None, | |
| llm_model: Optional[str] = None, | |
| expected_answer: Optional[str] = None, | |
| rlm_max_depth: Optional[int] = None, | |
| rlm_max_iterations: Optional[int] = None, | |
| ) -> StepResult[REPLObservation]: | |
| self._env.max_iterations = max_iterations | |
| reset_kwargs = {} | |
| if rlm_max_depth is not None: | |
| reset_kwargs["rlm_max_depth"] = rlm_max_depth | |
| if rlm_max_iterations is not None: | |
| reset_kwargs["rlm_max_iterations"] = rlm_max_iterations | |
| if expected_answer is not None: | |
| reset_kwargs["expected_answer"] = expected_answer | |
| obs = self._env.reset( | |
| seed=seed, | |
| episode_id=episode_id, | |
| context=context, | |
| task_prompt=task_prompt, | |
| hf_token=hf_token, | |
| llm_model=llm_model, | |
| **reset_kwargs, | |
| ) | |
| return self._wrap_observation(obs) | |
| def step(self, action: REPLAction) -> StepResult[REPLObservation]: | |
| return self._wrap_observation(self._env.step(action)) | |
| def execute(self, code: str) -> StepResult[REPLObservation]: | |
| return self.step(REPLAction(code=code)) | |
| def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]: | |
| return self.step(REPLAction(code="", is_final=True, final_answer=answer)) | |
| def get_variable(self, name: str) -> StepResult[REPLObservation]: | |
| return self.execute(f"print(repr({name}))") | |
| def state(self) -> REPLState: | |
| return self._env.state | |
| def list_variables(self) -> list[str]: | |
| return self.state().namespace_keys | |
| def close(self) -> None: | |
| self._env.close() | |
| def __enter__(self) -> "LocalREPLEnv": | |
| return self | |
| def __exit__(self, exc_type, exc_val, exc_tb) -> None: | |
| self.close() | |
| def _wrap_observation(obs: REPLObservation) -> StepResult[REPLObservation]: | |
| return StepResult( | |
| observation=obs, | |
| reward=obs.reward, | |
| done=obs.done, | |
| ) | |