File size: 5,044 Bytes
c07f15e
 
 
 
 
 
025774a
 
 
cc6473a
025774a
 
 
 
 
 
 
 
 
 
105973d
025774a
105973d
025774a
 
 
 
cc6473a
025774a
 
e74ff96
cc6473a
 
025774a
 
 
 
cc6473a
025774a
 
105973d
 
 
 
 
 
 
025774a
 
105973d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
025774a
 
cc6473a
 
 
 
 
 
 
 
 
025774a
 
 
 
105973d
 
 
 
 
 
 
 
 
025774a
 
 
 
 
 
 
 
 
 
 
 
 
cc6473a
 
 
 
 
 
 
 
 
 
025774a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
RhythmEnv Client.

Provides the WebSocket client for connecting to a RhythmEnv Life Simulator server.
"""

from __future__ import annotations

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

try:
    from .models import RhythmAction, RhythmObservation, RhythmState, StepRecord
except ImportError:
    from models import RhythmAction, RhythmObservation, RhythmState, StepRecord


class RhythmEnv(EnvClient[RhythmAction, RhythmObservation, RhythmState]):
    """
    Client for the RhythmEnv Life Simulator.

    Converts RhythmAction objects into JSON step payloads and parses server
    responses back into RhythmObservation / RhythmState objects.

    Example:
        >>> async with RhythmEnv(base_url="https://InosLihka-rhythm-env.hf.space") as client:
        ...     result = await client.reset()
        ...     result = await client.step(RhythmAction(action_type=ActionType.DEEP_WORK))

    Note: ``ActionType`` is not imported by this module; import it alongside
    ``RhythmAction`` (presumably from ``models`` — confirm) before running the
    example.
    """

    def _step_payload(self, action: RhythmAction) -> Dict[str, Any]:
        """Serialize a RhythmAction into the JSON payload sent to the server.

        Only the action type is sent, encoded as its enum ``.value``.
        """
        return {"action_type": action.action_type.value}

    @staticmethod
    def _parse_step_record(entry: Dict[str, Any]) -> StepRecord:
        """Build one StepRecord from a raw ``step_history`` entry.

        Keys missing from ``entry`` fall back to ``0`` (step), ``""`` (action)
        or ``0.0`` (reward, per-meter deltas and anomalies).
        """
        return StepRecord(
            step=entry.get("step", 0),
            action=entry.get("action", ""),
            reward=entry.get("reward", 0.0),
            vitality_delta=entry.get("vitality_delta", 0.0),
            cognition_delta=entry.get("cognition_delta", 0.0),
            progress_delta=entry.get("progress_delta", 0.0),
            serenity_delta=entry.get("serenity_delta", 0.0),
            connection_delta=entry.get("connection_delta", 0.0),
            vitality_anomaly=entry.get("vitality_anomaly", 0.0),
            cognition_anomaly=entry.get("cognition_anomaly", 0.0),
            progress_anomaly=entry.get("progress_anomaly", 0.0),
            serenity_anomaly=entry.get("serenity_anomaly", 0.0),
            connection_anomaly=entry.get("connection_anomaly", 0.0),
        )

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[RhythmObservation]:
        """Parse server response into StepResult[RhythmObservation].

        Surfaces ALL observation fields the server returns, including the
        per-meter deltas, anomalies (in step_history), last_action, and the
        full step history. Without these, an external agent connecting to the
        server can't see the meta-RL signals it needs to infer the profile.

        Missing keys fall back to the defaults below. Nested containers
        (``observation``, ``step_history``, ``reward_breakdown``,
        ``metadata``) that are present but JSON ``null`` are treated as empty
        rather than crashing or being passed through as ``None``.
        """
        # `or {}` / `or []` also covers keys that are present but explicitly null.
        obs_data = payload.get("observation") or {}
        raw_history = obs_data.get("step_history") or []

        # Reconstruct step_history with full StepRecord fidelity
        step_history = [self._parse_step_record(entry) for entry in raw_history]

        # Read once and share between the observation and the StepResult.
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = RhythmObservation(
            timestep=obs_data.get("timestep", 0),
            day=obs_data.get("day", 0),
            slot=obs_data.get("slot", 0),
            vitality=obs_data.get("vitality", 0.8),
            cognition=obs_data.get("cognition", 0.7),
            progress=obs_data.get("progress", 0.0),
            serenity=obs_data.get("serenity", 0.7),
            connection=obs_data.get("connection", 0.5),
            active_event=obs_data.get("active_event"),
            remaining_steps=obs_data.get("remaining_steps", 28),
            reward_breakdown=obs_data.get("reward_breakdown") or {},
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
            # Per-meter deltas from THIS step
            vitality_delta=obs_data.get("vitality_delta", 0.0),
            cognition_delta=obs_data.get("cognition_delta", 0.0),
            progress_delta=obs_data.get("progress_delta", 0.0),
            serenity_delta=obs_data.get("serenity_delta", 0.0),
            connection_delta=obs_data.get("connection_delta", 0.0),
            last_action=obs_data.get("last_action"),
            # Rolling history with anomalies (the meta-RL signal)
            step_history=step_history,
        )

        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> RhythmState:
        """Parse a server state payload into RhythmState.

        Keys missing from ``payload`` fall back to the defaults below; the
        meter defaults match those used in ``_parse_result``.
        """
        return RhythmState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            timestep=payload.get("timestep", 0),
            day=payload.get("day", 0),
            slot=payload.get("slot", 0),
            profile_name=payload.get("profile_name", ""),
            vitality=payload.get("vitality", 0.8),
            cognition=payload.get("cognition", 0.7),
            progress=payload.get("progress", 0.0),
            serenity=payload.get("serenity", 0.7),
            connection=payload.get("connection", 0.5),
            active_event=payload.get("active_event"),
        )