hirann commited on
Commit
77eebd8
·
verified ·
1 Parent(s): 97c03e5

Upload client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. client.py +42 -0
client.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ try:
4
+ from openenv.core.env_client import EnvClient
5
+ from openenv.core.client_types import StepResult
6
+ except ImportError:
7
+ from openenv.core.env_client import EnvClient
8
+ from openenv.core.client_types import StepResult
9
+
10
+ from overview_env.models import OverviewObservation, OverviewAction
11
+
12
+
13
+ class OverviewEnv(EnvClient[OverviewAction, OverviewObservation, Dict[str, Any]]):
14
+ def _step_payload(self, action: OverviewAction) -> Dict[str, Any]:
15
+ return action.model_dump()
16
+
17
+ def _parse_result(self, payload: Dict[str, Any]) -> "StepResult[OverviewObservation]":
18
+ metadata = payload.get("metadata", payload)
19
+ obs_data = metadata.get("observation", metadata)
20
+ reward_value = payload.get("reward", 0.0)
21
+ done = payload.get("done", False)
22
+
23
+ try:
24
+ observation = OverviewObservation.model_validate(obs_data)
25
+ except Exception:
26
+ observation = OverviewObservation(
27
+ task_id=obs_data.get("task_id", "unknown"),
28
+ task_type=obs_data.get("task_type", "summarization"),
29
+ task_name=obs_data.get("task_name", "unknown"),
30
+ task_description=obs_data.get("task_description", ""),
31
+ input_text=obs_data.get("input_text", ""),
32
+ )
33
+
34
+ return StepResult(
35
+ observation=observation,
36
+ reward=reward_value,
37
+ done=done,
38
+ info=metadata.get("info", {}),
39
+ )
40
+
41
+ def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]:
42
+ return payload