Spaces:
Sleeping
Sleeping
| """OpenEnv environment wrapper for Flatmate RL.""" | |
| from __future__ import annotations | |
| import os | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| try: | |
| from ..models import FlatmateRlAction, FlatmateRlObservation, FlatmateRlState | |
| from .episode import FlatmateEpisode | |
| except ImportError: | |
| from models import FlatmateRlAction, FlatmateRlObservation, FlatmateRlState | |
| from server.episode import FlatmateEpisode | |
class FlatmateRlEnvironment(Environment):
    """Deterministic RL environment for Flatmate visit scheduling.

    Thin OpenEnv adapter around a single ``FlatmateEpisode``: ``reset`` and
    ``step`` are forwarded to the episode, and a ``FlatmateRlState`` record is
    kept in step with the most recent observation the episode returned.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """Build the wrapped episode and an initial state record.

        Strict evaluation is enabled when the ``STRICT_EVAL_MODE`` environment
        variable equals ``1``, ``true``, ``yes`` or ``on`` (case-insensitive);
        an unset variable or any other value leaves it disabled.
        """
        super().__init__()
        raw_flag = os.getenv("STRICT_EVAL_MODE", "")
        strict = raw_flag.lower() in {"1", "true", "yes", "on"}
        self._state = self._new_state()
        self._episode = FlatmateEpisode(strict_eval_mode=strict)

    @staticmethod
    def _new_state() -> FlatmateRlState:
        """Return a state record with a fresh random episode id and zero steps."""
        return FlatmateRlState(episode_id=str(uuid4()), step_count=0)

    def reset(self, scenario_id: str | None = None, seed: int | None = None) -> FlatmateRlObservation:
        """Start a new episode and return its first observation.

        Args:
            scenario_id: Passed through unchanged to ``FlatmateEpisode.reset``.
            seed: Passed through unchanged to ``FlatmateEpisode.reset``.

        Returns:
            The observation produced by the episode's ``reset``.
        """
        # The state record is replaced (new episode id, step counter back to
        # zero) before the episode itself is reset.
        self._state = self._new_state()
        first_observation = self._episode.reset(scenario_id=scenario_id, seed=seed)
        self._sync_state(first_observation)
        return first_observation

    def step(self, action: FlatmateRlAction) -> FlatmateRlObservation:  # type: ignore[override]
        """Apply ``action`` to the episode and return the resulting observation.

        The step counter is incremented before the episode is stepped, so it
        also counts a call in which ``FlatmateEpisode.step`` raises.
        """
        self._state.step_count += 1
        result = self._episode.step(action)
        self._sync_state(result)
        return result

    # NOTE(review): the OpenEnv ``Environment`` base class may declare ``state``
    # as a property rather than a method -- confirm against the installed
    # openenv version before relying on ``env.state`` without a call.
    def state(self) -> FlatmateRlState:
        """Return the current state record (the live object, not a copy)."""
        return self._state

    def _sync_state(self, observation: FlatmateRlObservation) -> None:
        """Copy the mirrored fields of ``observation`` onto the state record.

        List-valued fields are shallow-copied with ``list(...)`` so the state
        holds its own list objects; every other field is assigned as-is.
        ``episode_id`` and ``step_count`` are not touched here.
        """
        # (field name, stored as a fresh list?) in assignment order.
        mirrored = (
            ("scenario_id", False),
            ("phase", False),
            ("status", False),
            ("gathered_fields", True),
            ("selected_posts", True),
            ("booked_visits", True),
            ("buyer_profile_stored", False),
            ("seller_profile_stored", False),
            ("tool_trace", True),
            ("total_reward", False),
            ("done", False),
        )
        for field_name, as_list in mirrored:
            value = getattr(observation, field_name)
            setattr(self._state, field_name, list(value) if as_list else value)