hirann commited on
Commit
5ab5338
·
verified ·
1 Parent(s): 961d0eb

Upload server/overview_environment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/overview_environment.py +99 -0
server/overview_environment.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openenv.core.env_server.interfaces import Environment
2
+ from openenv.core.env_server.types import Action, Observation, State
3
+
4
+ import random
5
+ from typing import Any, Dict, Optional
6
+ from dataclasses import dataclass, field
7
+ from uuid import uuid4
8
+
9
+ from overview_env.models import OverviewObservation, OverviewAction, AnalysisResult, TaskType
10
+ from overview_env.tasks.definitions import (
11
+ get_all_tasks,
12
+ get_task_by_id,
13
+ OverviewTask,
14
+ OverviewTaskEvaluator,
15
+ GradingResult,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class _EpisodeState:
21
+ task: OverviewTask
22
+ episode_id: str
23
+ current_step: int = 0
24
+ cumulative_reward: float = 0.0
25
+ submitted_analysis: Dict[str, Any] = field(default_factory=dict)
26
+ episode_complete: bool = False
27
+
28
+
29
+ DIFFICULTY_WEIGHTS = {"easy": 0.15, "medium": 0.12, "hard": 0.08}
30
+
31
+
32
+ class OverviewEnvironment(Environment):
33
+ SUPPORTS_CONCURRENT_SESSIONS = True
34
+
35
+ def __init__(self, task_id: Optional[str] = None, seed: Optional[int] = None, max_steps: int = 10):
36
+ self._task_id = task_id
37
+ self._seed = seed
38
+ self._max_steps = max_steps
39
+ self._ep: Optional[_EpisodeState] = None
40
+ self._state = State(episode_id=str(uuid4()), step_count=0)
41
+ if seed is not None:
42
+ random.seed(seed)
43
+
44
+ def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: Optional[str] = None, **kwargs: Any) -> Observation:
45
+ if seed is not None:
46
+ random.seed(seed)
47
+ target_task_id = task_id or self._task_id
48
+ if target_task_id:
49
+ task = get_task_by_id(target_task_id)
50
+ else:
51
+ task = random.choice(get_all_tasks())
52
+ eid = episode_id or str(uuid4())
53
+ self._ep = _EpisodeState(task=task, episode_id=eid, current_step=0, cumulative_reward=0.0, submitted_analysis={}, episode_complete=False)
54
+ self._state = State(episode_id=eid, step_count=0)
55
+ obs = self._build_observation()
56
+ return Observation(done=False, reward=0.0, metadata=obs.model_dump())
57
+
58
+ def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
59
+ if self._ep is None:
60
+ return Observation(done=True, reward=0.0, metadata={"error": "Environment not reset. Call reset() first."})
61
+ if self._ep.episode_complete:
62
+ return Observation(done=True, reward=0.0, metadata={"error": "Episode already finished."})
63
+ action_data = {}
64
+ if hasattr(action, "data") and isinstance(action.data, dict):
65
+ action_data = action.data
66
+ elif isinstance(action, dict):
67
+ action_data = action
68
+ elif hasattr(action, "__dict__"):
69
+ action_data = vars(action)
70
+ try:
71
+ env_action = OverviewAction.model_validate(action_data)
72
+ except Exception:
73
+ env_action = OverviewAction(analysis=AnalysisResult())
74
+ self._ep.current_step += 1
75
+ self._state.step_count = self._ep.current_step
76
+ analysis_dict = env_action.analysis.model_dump()
77
+ self._ep.submitted_analysis = analysis_dict
78
+ evaluator = OverviewTaskEvaluator(self._ep.task)
79
+ result = evaluator.grade(analysis_dict)
80
+ reward_value = result.score * DIFFICULTY_WEIGHTS.get(self._ep.task.difficulty, 0.1)
81
+ terminal = env_action.submit or self._ep.current_step >= self._ep.task.max_steps
82
+ if terminal:
83
+ final_score = result.score
84
+ if self._ep.current_step <= self._ep.task.max_steps * 0.5:
85
+ final_score += 0.1
86
+ reward_value = min(1.0, final_score)
87
+ self._ep.episode_complete = True
88
+ self._ep.cumulative_reward += reward_value
89
+ obs = self._build_observation()
90
+ step_result = {"observation": obs.model_dump(), "reward": {"value": reward_value, "total": result.score, "feedback": result.feedback}, "done": self._ep.episode_complete, "info": {"step": self._ep.current_step}}
91
+ return Observation(done=self._ep.episode_complete, reward=reward_value, metadata=step_result)
92
+
93
+ @property
94
+ def state(self) -> State:
95
+ return self._state
96
+
97
+ def _build_observation(self) -> OverviewObservation:
98
+ ep = self._ep
99
+ return OverviewObservation(task_id=ep.task.task_id, task_type=TaskType(ep.task.task_type), task_name=ep.task.name, task_description=ep.task.description, difficulty=ep.task.difficulty, input_text=ep.task.input_text, question=ep.task.question, expected_output=ep.task.expected_output, max_steps=ep.task.max_steps)