File size: 1,578 Bytes
77eebd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Any, Dict

try:
    from openenv.core.env_client import EnvClient
    from openenv.core.client_types import StepResult
except ImportError:
    from openenv.core.env_client import EnvClient
    from openenv.core.client_types import StepResult

from overview_env.models import OverviewObservation, OverviewAction


class OverviewEnv(EnvClient[OverviewAction, OverviewObservation, Dict[str, Any]]):
    """Client for the overview environment.

    Serialises ``OverviewAction`` objects into request payloads, parses server
    responses into ``StepResult[OverviewObservation]``, and returns state
    payloads unchanged.
    """

    @staticmethod
    def _dict_or(value: Any, fallback: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``value`` if it is a dict, otherwise ``fallback``."""
        return value if isinstance(value, dict) else fallback

    def _step_payload(self, action: OverviewAction) -> Dict[str, Any]:
        """Serialise ``action`` into the dict sent to the server.

        The model's ``model_dump()`` output is used as-is.
        """
        return action.model_dump()

    def _parse_result(self, payload: Dict[str, Any]) -> "StepResult[OverviewObservation]":
        """Build a ``StepResult`` from a server response ``payload``.

        Observation fields are looked up in ``payload["metadata"]["observation"]``,
        falling back one level at a time:

        - If ``metadata`` is missing or not a dict, ``payload`` itself is used.
        - If ``observation`` is missing or not a dict, the dict from the
          previous step is used.

        ``reward`` and ``done`` are read from the top level of ``payload``
        (defaults ``0.0`` and ``False``). ``info`` is read from the same dict
        that ``observation`` was looked up in (default ``{}``).

        If the observation data fails model validation, a minimal
        ``OverviewObservation`` is built instead. It uses whichever of
        ``task_id``, ``task_type``, ``task_name``, ``task_description`` and
        ``input_text`` are present, with placeholder defaults for the rest.
        """
        # Treat a null/non-dict level like a missing one; previously
        # ``{"metadata": None}`` or ``{"observation": None}`` crashed on ``.get``.
        envelope = self._dict_or(payload.get("metadata"), payload)
        obs_data = self._dict_or(envelope.get("observation"), envelope)
        reward_value = payload.get("reward", 0.0)
        done = payload.get("done", False)

        try:
            observation = OverviewObservation.model_validate(obs_data)
        except ValueError:
            # pydantic's ValidationError subclasses ValueError, so schema
            # mismatches land here. Unrelated errors are no longer swallowed
            # behind a fabricated observation.
            observation = OverviewObservation(
                task_id=obs_data.get("task_id", "unknown"),
                task_type=obs_data.get("task_type", "summarization"),
                task_name=obs_data.get("task_name", "unknown"),
                task_description=obs_data.get("task_description", ""),
                input_text=obs_data.get("input_text", ""),
            )

        info = envelope.get("info")
        return StepResult(
            observation=observation,
            reward=reward_value,
            done=done,
            # An explicit ``"info": null`` gets the same empty default as a missing key.
            info={} if info is None else info,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return the raw state ``payload`` unchanged."""
        return payload