repl / local.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
b74674a verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Local in-process REPL helper.
This module is intentionally separate from `client.py` so the remote client
module does not import anything from `server/`.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Optional
try:
from openenv.core.client_types import StepResult
from .models import REPLAction, REPLObservation, REPLState
from .rubrics import REPLRubric
from .server.repl_environment import REPLEnvironment
except ImportError:
from models import REPLAction, REPLObservation, REPLState
from openenv.core.client_types import StepResult
from rubrics import REPLRubric
from server.repl_environment import REPLEnvironment
class LocalREPLEnv:
"""Explicit in-process REPL helper for local experimentation."""
def __init__(
self,
*,
llm_query_fn: Optional[Callable[[str], str]] = None,
llm_batch_fn: Optional[Callable[[list[str]], list[str]]] = None,
subcall_fn: Optional[Callable[[str, Optional[str]], str]] = None,
subcall_batch_fn: Optional[
Callable[[list[str], Optional[str]], list[str]]
] = None,
max_output_length: int = 8192,
context_preview_length: int = 500,
rubric: Optional[REPLRubric] = None,
rlm_max_depth: int = 1,
rlm_max_iterations: int | None = None,
):
self._env = REPLEnvironment(
max_output_length=max_output_length,
context_preview_length=context_preview_length,
rubric=rubric,
llm_query_fn=llm_query_fn,
llm_batch_fn=llm_batch_fn,
subcall_fn=subcall_fn,
subcall_batch_fn=subcall_batch_fn,
rlm_max_depth=rlm_max_depth,
rlm_max_iterations=rlm_max_iterations,
)
def reset(
self,
*,
context: str = "",
task_prompt: str = "",
max_iterations: int = 30,
seed: Optional[int] = None,
episode_id: Optional[str] = None,
hf_token: Optional[str] = None,
llm_model: Optional[str] = None,
expected_answer: Optional[str] = None,
rlm_max_depth: Optional[int] = None,
rlm_max_iterations: Optional[int] = None,
) -> StepResult[REPLObservation]:
self._env.max_iterations = max_iterations
reset_kwargs = {}
if rlm_max_depth is not None:
reset_kwargs["rlm_max_depth"] = rlm_max_depth
if rlm_max_iterations is not None:
reset_kwargs["rlm_max_iterations"] = rlm_max_iterations
if expected_answer is not None:
reset_kwargs["expected_answer"] = expected_answer
obs = self._env.reset(
seed=seed,
episode_id=episode_id,
context=context,
task_prompt=task_prompt,
hf_token=hf_token,
llm_model=llm_model,
**reset_kwargs,
)
return self._wrap_observation(obs)
def step(self, action: REPLAction) -> StepResult[REPLObservation]:
return self._wrap_observation(self._env.step(action))
def execute(self, code: str) -> StepResult[REPLObservation]:
return self.step(REPLAction(code=code))
def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]:
return self.step(REPLAction(code="", is_final=True, final_answer=answer))
def get_variable(self, name: str) -> StepResult[REPLObservation]:
return self.execute(f"print(repr({name}))")
def state(self) -> REPLState:
return self._env.state
def list_variables(self) -> list[str]:
return self.state().namespace_keys
def close(self) -> None:
self._env.close()
def __enter__(self) -> "LocalREPLEnv":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
@staticmethod
def _wrap_observation(obs: REPLObservation) -> StepResult[REPLObservation]:
return StepResult(
observation=obs,
reward=obs.reward,
done=obs.done,
)