File size: 2,337 Bytes
53dbcc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0594d27
53dbcc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""OpenEnv environment wrapper for Flatmate RL."""

from __future__ import annotations

import os
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment

try:
    from ..models import FlatmateRlAction, FlatmateRlObservation, FlatmateRlState
    from .episode import FlatmateEpisode
except ImportError:
    from models import FlatmateRlAction, FlatmateRlObservation, FlatmateRlState
    from server.episode import FlatmateEpisode


class FlatmateRlEnvironment(Environment):
    """Deterministic RL environment for Flatmate visit scheduling.

    Episode logic is delegated to a ``FlatmateEpisode``; after every
    ``reset`` and ``step`` the returned observation is mirrored into a
    ``FlatmateRlState`` that is exposed through the ``state`` property.
    """

    # NOTE(review): presumably consumed by the OpenEnv server to decide whether
    # several sessions may run side by side — confirm against the framework.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """Create the wrapped episode and an initial, empty state.

        The ``STRICT_EVAL_MODE`` environment variable is read once here; any
        of ``1``, ``true``, ``yes`` or ``on`` (case-insensitive) enables the
        flag that is handed to ``FlatmateEpisode``.
        """
        super().__init__()
        truthy_values = {"1", "true", "yes", "on"}
        raw_flag = os.getenv("STRICT_EVAL_MODE", "")
        strict = raw_flag.lower() in truthy_values
        self._state = FlatmateRlState(episode_id=str(uuid4()), step_count=0)
        self._episode = FlatmateEpisode(strict_eval_mode=strict)

    def reset(self, scenario_id: str | None = None, seed: int | None = None) -> FlatmateRlObservation:
        """Start a new episode and return its first observation.

        Args:
            scenario_id: Forwarded unchanged to ``FlatmateEpisode.reset``.
            seed: Forwarded unchanged to ``FlatmateEpisode.reset``.

        Returns:
            The observation produced by the episode's ``reset``.
        """
        # A brand-new state object gives the episode a fresh id and zeroes the
        # step counter before the observation is mirrored into it.
        self._state = FlatmateRlState(episode_id=str(uuid4()), step_count=0)
        first_observation = self._episode.reset(scenario_id=scenario_id, seed=seed)
        self._sync_state(first_observation)
        return first_observation

    def step(self, action: FlatmateRlAction) -> FlatmateRlObservation:  # type: ignore[override]
        """Apply *action* to the episode and return the resulting observation.

        The step counter is incremented before the episode is advanced.
        """
        self._state.step_count += 1
        next_observation = self._episode.step(action)
        self._sync_state(next_observation)
        return next_observation

    @property
    def state(self) -> FlatmateRlState:
        """Return the state object maintained by this environment."""
        return self._state

    def _sync_state(self, observation: FlatmateRlObservation) -> None:
        """Copy the tracked fields of *observation* onto ``self._state``.

        List-valued fields are copied into new lists, so the state does not
        share those list objects with the observation.
        """
        # (attribute name, copy into a new list?) — listed in assignment order.
        mirrored_fields = (
            ("scenario_id", False),
            ("phase", False),
            ("status", False),
            ("gathered_fields", True),
            ("selected_posts", True),
            ("booked_visits", True),
            ("buyer_profile_stored", False),
            ("seller_profile_stored", False),
            ("tool_trace", True),
            ("total_reward", False),
            ("done", False),
        )
        target = self._state
        for attr_name, copy_as_list in mirrored_fields:
            value = getattr(observation, attr_name)
            setattr(target, attr_name, list(value) if copy_as_list else value)