File size: 2,804 Bytes
b77d3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# salespath_env/client.py

from typing import Any, Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import (
    SalesPathAction,
    SalesPathObservation,
    SalesPathState,
)


class SalesPathEnv(EnvClient[SalesPathAction, SalesPathObservation, SalesPathState]):
    """Typed ``EnvClient`` for the SalesPath environment.

    Implements the serialisation hooks that ``EnvClient`` requires and
    overrides ``reset``/``step`` so callers receive a
    ``SalesPathObservation`` directly rather than a ``StepResult`` wrapper.
    """

    # ------------------------------------------------------------------ #
    # Abstract method implementations required by EnvClient               #
    # ------------------------------------------------------------------ #

    def _step_payload(self, action: SalesPathAction) -> Dict[str, Any]:
        """Serialise action → JSON dict for the WebSocket server.

        WSStepMessage.data IS the action dict directly (no wrapper key).
        The ``metadata`` field is excluded from the dump.
        """
        return action.model_dump(exclude={"metadata"})

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SalesPathObservation]:
        """Deserialise server JSON → StepResult[SalesPathObservation].

        Args:
            payload: Decoded server message. The observation fields are read
                from ``payload["observation"]`` when that key exists,
                otherwise from ``payload`` itself.

        Returns:
            A ``StepResult`` whose ``reward``/``done`` come from the top level
            of ``payload`` when present and not ``None``, and from the parsed
            observation otherwise.
        """
        # Server may nest obs under an 'observation' key
        obs_data = payload.get("observation", payload)
        obs = SalesPathObservation(**obs_data)

        # A top-level null must not shadow a value carried by the observation,
        # so None is treated the same as a missing key. Compare against None
        # explicitly: a reward of 0 and done=False are legitimate values.
        reward = payload.get("reward")
        if reward is None:
            reward = obs.reward
        done = payload.get("done")
        if done is None:
            done = obs.done

        return StepResult(observation=obs, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> SalesPathState:
        """Deserialise server JSON → SalesPathState.

        The state fields are read from ``payload["state"]`` when that key
        exists, otherwise from ``payload`` itself.
        """
        state_data = payload.get("state", payload)
        return SalesPathState(**state_data)

    # ------------------------------------------------------------------ #
    # Convenience wrappers that return the unwrapped observation directly  #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _with_step_fields(
        result: StepResult[SalesPathObservation],
    ) -> SalesPathObservation:
        """
        Keep observation fields in sync with StepResult wrapper fields.
        Some server payloads provide reward/done only at top-level.

        Returns a copy of ``result.observation``. Wrapper fields that are
        ``None`` are skipped so they cannot erase a value the observation
        already holds.
        """
        synced = {
            name: value
            for name, value in (("reward", result.reward), ("done", result.done))
            if value is not None
        }
        return result.observation.model_copy(update=synced)

    async def reset(
        self,
        difficulty: int = 1,
    ) -> SalesPathObservation:
        """Reset the environment and return the resulting observation.

        Args:
            difficulty: Forwarded as the ``difficulty`` keyword argument to
                ``EnvClient.reset``.

        Returns:
            The observation from the reset result, with ``reward``/``done``
            synced from the ``StepResult`` wrapper.
        """
        result = await super().reset(difficulty=difficulty)
        return self._with_step_fields(result)

    async def step(
        self,
        action_type: str,
        content: str,
        target: str = "",
    ) -> SalesPathObservation:
        """Build a ``SalesPathAction``, send it, and return the observation.

        Args:
            action_type: Value for the action's ``action_type`` field.
            content: Value for the action's ``content`` field.
            target: Value for the action's ``target`` field.

        Returns:
            The observation from the step result, with ``reward``/``done``
            synced from the ``StepResult`` wrapper.
        """
        action = SalesPathAction(
            action_type=action_type,
            content=content,
            target=target,
        )
        result = await super().step(action)
        return self._with_step_fields(result)