from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import Action, Observation, State
import random
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
from uuid import uuid4
from overview_env.models import OverviewObservation, OverviewAction, AnalysisResult, TaskType
from overview_env.tasks.definitions import (
get_all_tasks,
get_task_by_id,
OverviewTask,
OverviewTaskEvaluator,
GradingResult,
)
@dataclass
class _EpisodeState:
    """Mutable bookkeeping for one in-flight episode of OverviewEnvironment."""

    task: OverviewTask  # the task being attempted this episode
    episode_id: str  # unique id mirrored into State and step metadata
    current_step: int = 0  # number of step() calls taken so far
    cumulative_reward: float = 0.0  # running sum of per-step reward values
    submitted_analysis: Dict[str, Any] = field(default_factory=dict)  # last analysis payload submitted
    episode_complete: bool = False  # set once terminal (submit action or step cap reached)
# Per-step reward multipliers keyed by task difficulty: the graded score is
# scaled by these each step, with 0.1 used for any unrecognized difficulty.
DIFFICULTY_WEIGHTS = {"easy": 0.15, "medium": 0.12, "hard": 0.08}
class OverviewEnvironment(Environment):
    """Text-analysis environment over a library of Overview tasks.

    Each episode presents one ``OverviewTask``; on every step the agent
    submits an ``AnalysisResult`` which is graded by ``OverviewTaskEvaluator``.
    Non-terminal steps yield a dense shaping reward (score scaled by
    difficulty); the terminal step replaces it with the final graded score
    plus an optional speed bonus.
    """

    # Safe to run sessions concurrently: every instance owns its RNG and
    # its episode record; no module-global mutable state is touched.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, task_id: Optional[str] = None, seed: Optional[int] = None, max_steps: int = 10):
        """Create an environment.

        Args:
            task_id: Fixed task to run every episode; when ``None`` a task is
                chosen at random on each ``reset()``.
            seed: Seed for the task-selection RNG.
            max_steps: Hard per-episode step cap, applied in addition to the
                task's own ``max_steps`` budget (whichever is smaller wins).
        """
        self._task_id = task_id
        self._seed = seed
        self._max_steps = max_steps
        self._ep: Optional[_EpisodeState] = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # FIX: use a private RNG instead of seeding the module-global
        # ``random`` state, which would couple concurrent sessions together.
        # ``random.Random(seed)`` produces the same stream as
        # ``random.seed(seed)``, so task selection is unchanged for a
        # given seed.
        self._rng = random.Random(seed)

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: Optional[str] = None, **kwargs: Any) -> Observation:
        """Start a new episode, optionally reseeding and/or pinning the task.

        Returns:
            An Observation with ``done=False``, zero reward, and the task
            observation payload in ``metadata``.
        """
        if seed is not None:
            self._rng.seed(seed)
        # Per-call task_id overrides the constructor default.
        target_task_id = task_id or self._task_id
        if target_task_id:
            task = get_task_by_id(target_task_id)
        else:
            task = self._rng.choice(get_all_tasks())
        eid = episode_id or str(uuid4())
        self._ep = _EpisodeState(task=task, episode_id=eid)
        self._state = State(episode_id=eid, step_count=0)
        obs = self._build_observation()
        return Observation(done=False, reward=0.0, metadata=obs.model_dump())

    def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
        """Grade one analysis submission and advance the episode.

        Returns:
            An Observation whose ``metadata`` carries the full step result:
            observation payload, reward breakdown, done flag, and step index.
        """
        if self._ep is None:
            return Observation(done=True, reward=0.0, metadata={"error": "Environment not reset. Call reset() first."})
        if self._ep.episode_complete:
            return Observation(done=True, reward=0.0, metadata={"error": "Episode already finished."})

        action_data = self._coerce_action_payload(action)
        try:
            env_action = OverviewAction.model_validate(action_data)
        except Exception:
            # Malformed payloads degrade to an empty analysis rather than
            # crashing the episode; the evaluator will score it accordingly.
            env_action = OverviewAction(analysis=AnalysisResult())

        self._ep.current_step += 1
        self._state.step_count = self._ep.current_step

        analysis_dict = env_action.analysis.model_dump()
        self._ep.submitted_analysis = analysis_dict
        evaluator = OverviewTaskEvaluator(self._ep.task)
        result = evaluator.grade(analysis_dict)

        # Dense shaping reward: graded score scaled down by difficulty.
        reward_value = result.score * DIFFICULTY_WEIGHTS.get(self._ep.task.difficulty, 0.1)

        # FIX: honor the constructor's ``max_steps`` (previously stored but
        # never read) as a hard cap alongside the task's own budget.
        step_limit = min(self._ep.task.max_steps, self._max_steps)
        terminal = env_action.submit or self._ep.current_step >= step_limit
        if terminal:
            final_score = result.score
            # Speed bonus for finishing within the first half of the task's
            # own step budget; total is clamped to 1.0.
            if self._ep.current_step <= self._ep.task.max_steps * 0.5:
                final_score += 0.1
            reward_value = min(1.0, final_score)
            self._ep.episode_complete = True

        self._ep.cumulative_reward += reward_value
        obs = self._build_observation()
        step_result = {
            "observation": obs.model_dump(),
            "reward": {"value": reward_value, "total": result.score, "feedback": result.feedback},
            "done": self._ep.episode_complete,
            "info": {"step": self._ep.current_step},
        }
        return Observation(done=self._ep.episode_complete, reward=reward_value, metadata=step_result)

    @staticmethod
    def _coerce_action_payload(action: Action) -> Dict[str, Any]:
        """Best-effort extraction of a plain dict from an incoming action."""
        if hasattr(action, "data") and isinstance(action.data, dict):
            return action.data
        if isinstance(action, dict):
            return action
        if hasattr(action, "__dict__"):
            return vars(action)
        return {}

    @property
    def state(self) -> State:
        """Current episode id and step count."""
        return self._state

    def _build_observation(self) -> OverviewObservation:
        """Build the task-facing observation from the active episode's task."""
        ep = self._ep
        return OverviewObservation(
            task_id=ep.task.task_id,
            task_type=TaskType(ep.task.task_type),
            task_name=ep.task.name,
            task_description=ep.task.description,
            difficulty=ep.task.difficulty,
            input_text=ep.task.input_text,
            question=ep.task.question,
            expected_output=ep.task.expected_output,
            max_steps=ep.task.max_steps,
        )