File size: 4,495 Bytes
5ab5338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import Action, Observation, State

import random
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
from uuid import uuid4

from overview_env.models import OverviewObservation, OverviewAction, AnalysisResult, TaskType
from overview_env.tasks.definitions import (
    get_all_tasks,
    get_task_by_id,
    OverviewTask,
    OverviewTaskEvaluator,
    GradingResult,
)


@dataclass
class _EpisodeState:
    """Mutable per-episode bookkeeping for OverviewEnvironment."""

    # The task being solved this episode (chosen or sampled in reset()).
    task: OverviewTask
    # Unique identifier for this episode (caller-supplied or a fresh uuid4).
    episode_id: str
    # Number of step() calls taken so far this episode.
    current_step: int = 0
    # Running sum of all rewards granted this episode.
    cumulative_reward: float = 0.0
    # model_dump() of the most recently submitted AnalysisResult.
    submitted_analysis: Dict[str, Any] = field(default_factory=dict)
    # Set once a submit action or the task's step budget ends the episode.
    episode_complete: bool = False


DIFFICULTY_WEIGHTS = {"easy": 0.15, "medium": 0.12, "hard": 0.08}


class OverviewEnvironment(Environment):
    """OpenEnv environment that serves text-analysis (overview) tasks.

    Each episode presents one ``OverviewTask``. On every ``step()`` the agent
    submits an ``OverviewAction`` carrying an ``AnalysisResult``; the analysis
    is graded immediately and rewarded with a difficulty-weighted fraction of
    the score. The episode ends when the agent sets ``submit`` or the task's
    step budget is exhausted, at which point the reward is the raw graded
    score (plus an efficiency bonus, capped at 1.0).
    """

    # One server process may host several sessions of this environment at
    # once, so no module-global mutable state may be touched (see _rng below).
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, task_id: Optional[str] = None, seed: Optional[int] = None, max_steps: int = 10):
        """Create an environment.

        Args:
            task_id: If given, reset() always loads this task by id;
                otherwise reset() samples uniformly from get_all_tasks().
            seed: Seeds this environment's private RNG for reproducible
                task sampling.
            max_steps: Currently stored but not consulted; see NOTE below.
        """
        self._task_id = task_id
        self._seed = seed
        # NOTE(review): _max_steps is never read — episode termination in
        # step() is governed solely by task.max_steps. Confirm whether this
        # constructor argument was meant to cap episode length.
        self._max_steps = max_steps
        self._ep: Optional[_EpisodeState] = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Fix: use a per-instance RNG instead of random.seed() on the
        # module-global generator. With SUPPORTS_CONCURRENT_SESSIONS the
        # global RNG is shared state: concurrent sessions would couple their
        # task-sampling streams and per-session reproducibility would break.
        self._rng = random.Random(seed)

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: Optional[str] = None, **kwargs: Any) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, re-seeds this environment's private RNG.
            episode_id: Identifier to use for the episode; a fresh uuid4
                string is generated when omitted.
            task_id: Overrides the constructor task_id for this episode.

        Returns:
            An Observation with done=False, reward=0.0, and the serialized
            OverviewObservation in ``metadata``.
        """
        if seed is not None:
            self._rng.seed(seed)
        # Per-call task_id wins over the constructor default.
        target_task_id = task_id or self._task_id
        if target_task_id:
            task = get_task_by_id(target_task_id)
        else:
            task = self._rng.choice(get_all_tasks())
        eid = episode_id or str(uuid4())
        # Dataclass defaults cover the zeroed counters and empty analysis.
        self._ep = _EpisodeState(task=task, episode_id=eid)
        self._state = State(episode_id=eid, step_count=0)
        obs = self._build_observation()
        return Observation(done=False, reward=0.0, metadata=obs.model_dump())

    def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs: Any) -> Observation:
        """Grade one submitted analysis and return the resulting observation.

        Args:
            action: The agent's action; may be an object with a ``data``
                dict, a plain dict, or any object whose ``__dict__`` carries
                the OverviewAction payload.
            timeout_s: Unused; accepted for interface compatibility.

        Returns:
            An Observation whose reward is the difficulty-weighted partial
            score (non-terminal) or the capped final score (terminal), and
            whose metadata bundles observation, reward details, and step info.
        """
        if self._ep is None:
            return Observation(done=True, reward=0.0, metadata={"error": "Environment not reset. Call reset() first."})
        if self._ep.episode_complete:
            return Observation(done=True, reward=0.0, metadata={"error": "Episode already finished."})
        # Accept several transport shapes for the action payload.
        action_data: Dict[str, Any] = {}
        if hasattr(action, "data") and isinstance(action.data, dict):
            action_data = action.data
        elif isinstance(action, dict):
            action_data = action
        elif hasattr(action, "__dict__"):
            action_data = vars(action)
        try:
            env_action = OverviewAction.model_validate(action_data)
        except Exception:
            # Deliberate best-effort fallback: an unparseable action counts
            # as an empty analysis instead of crashing the session.
            env_action = OverviewAction(analysis=AnalysisResult())
        self._ep.current_step += 1
        self._state.step_count = self._ep.current_step
        analysis_dict = env_action.analysis.model_dump()
        self._ep.submitted_analysis = analysis_dict
        evaluator = OverviewTaskEvaluator(self._ep.task)
        result = evaluator.grade(analysis_dict)
        # Non-terminal steps pay a difficulty-scaled fraction of the score;
        # unrecognized difficulties fall back to a 0.1 weight.
        reward_value = result.score * DIFFICULTY_WEIGHTS.get(self._ep.task.difficulty, 0.1)
        terminal = env_action.submit or self._ep.current_step >= self._ep.task.max_steps
        if terminal:
            # Terminal reward is the raw graded score, with a +0.1 bonus for
            # finishing within half the step budget, capped at 1.0.
            final_score = result.score
            if self._ep.current_step <= self._ep.task.max_steps * 0.5:
                final_score += 0.1
            reward_value = min(1.0, final_score)
            self._ep.episode_complete = True
        self._ep.cumulative_reward += reward_value
        obs = self._build_observation()
        step_result = {"observation": obs.model_dump(), "reward": {"value": reward_value, "total": result.score, "feedback": result.feedback}, "done": self._ep.episode_complete, "info": {"step": self._ep.current_step}}
        return Observation(done=self._ep.episode_complete, reward=reward_value, metadata=step_result)

    @property
    def state(self) -> State:
        """Current server-side State (episode id and step count)."""
        return self._state

    def _build_observation(self) -> OverviewObservation:
        """Serialize the active task into an OverviewObservation.

        Only called after reset()/the step() guards, so self._ep is set.
        """
        ep = self._ep
        return OverviewObservation(task_id=ep.task.task_id, task_type=TaskType(ep.task.task_type), task_name=ep.task.name, task_description=ep.task.description, difficulty=ep.task.difficulty, input_text=ep.task.input_text, question=ep.task.question, expected_output=ep.task.expected_output, max_steps=ep.task.max_steps)