File size: 2,323 Bytes
e721fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Data models for the OpenEnv email triage environment."""

import math
from typing import Literal

from pydantic import BaseModel, field_validator

# Margin by which out-of-range values are pulled inside the (0, 1) bounds:
# clamped results are OPEN_INTERVAL_EPSILON and 1 - OPEN_INTERVAL_EPSILON.
OPEN_INTERVAL_EPSILON = 1e-2


def _strict_open_unit_interval(raw_value: float) -> float:
    """Clamp numeric values to the strict open interval (0, 1).

    Args:
        raw_value: Any value accepted by ``float()``.

    Returns:
        ``OPEN_INTERVAL_EPSILON`` when the value is NaN or ``<= 0``,
        ``1 - OPEN_INTERVAL_EPSILON`` when it is ``>= 1``, and the value
        itself (converted to ``float``) otherwise. Values already strictly
        inside (0, 1) are returned unchanged, even when they lie closer to
        a bound than the epsilon.

    Raises:
        TypeError, ValueError: Propagated from ``float()`` when
            ``raw_value`` cannot be converted.
    """
    numeric_value = float(raw_value)
    # NaN fails every ordered comparison, so without this explicit test it
    # would slip past both bounds checks and be returned as-is.
    if math.isnan(numeric_value) or numeric_value <= 0.0:
        return OPEN_INTERVAL_EPSILON
    if numeric_value >= 1.0:
        return 1.0 - OPEN_INTERVAL_EPSILON
    return numeric_value


class EmailObservation(BaseModel):
    """Represents the email context visible to the agent at each step.

    Every field is declared without a default, so all of them must be
    supplied when the model is built. This class only declares the schema;
    how each field is populated is decided by whatever code constructs it.
    """

    email_id: str
    subject: str
    body: str
    sender: str
    # Kept as a plain string; no timestamp format is parsed or enforced here.
    timestamp: str
    # presumably earlier messages of the same thread — verify against producer
    thread_history: list[str]
    task_id: str
    # NOTE(review): whether this counter is 0- or 1-based is not visible here.
    step_number: int
    total_emails: int


class TriageAction(BaseModel):
    """Represents the action chosen by the agent for an email.

    ``label`` is restricted to four literal values; ``summary`` and
    ``route_to`` are strings with no further constraints declared here.
    """

    # Closed set of triage categories; the annotation admits no other value.
    label: Literal["urgent", "normal", "spam", "archive"]
    summary: str
    # NOTE(review): valid routing targets are not constrained by this model.
    route_to: str


class RewardResult(BaseModel):
    """Deterministic grading output, captured before any reward shaping.

    ``score`` is passed through ``_strict_open_unit_interval`` during
    validation; ``breakdown`` and ``feedback`` have no custom validator.
    """

    score: float
    breakdown: dict[str, float]
    feedback: str

    @field_validator("score")
    @classmethod
    def _validate_score(cls, score: float) -> float:
        """Return ``score`` after clamping it with the shared interval helper."""
        clamped_score = _strict_open_unit_interval(score)
        return clamped_score


class EnvironmentState(BaseModel):
    """Full internal environment state, exposed for debugging and evaluation.

    Each entry of ``reward_history`` is passed through
    ``_strict_open_unit_interval`` during validation; the remaining fields
    have no custom validator.
    """

    task_id: str
    current_step: int
    total_steps: int
    done: bool
    action_history: list[TriageAction]
    reward_history: list[float]

    @field_validator("reward_history")
    @classmethod
    def _validate_reward_history(cls, rewards: list[float]) -> list[float]:
        """Return a new list with every reward clamped, order preserved."""
        return list(map(_strict_open_unit_interval, rewards))


class StepResult(BaseModel):
    """Standardized payload returned by environment step calls.

    ``reward`` is passed through ``_strict_open_unit_interval`` during
    validation; ``info`` maps string keys to str, int, float or bool values.
    """

    observation: EmailObservation
    reward: float
    done: bool
    info: dict[str, str | int | float | bool]

    @field_validator("reward")
    @classmethod
    def _validate_reward(cls, reward: float) -> float:
        """Return ``reward`` after clamping it with the shared interval helper."""
        clamped_reward = _strict_open_unit_interval(reward)
        return clamped_reward


class ResetResult(BaseModel):
    """Represents the standardized output of environment reset calls.

    Carries an observation and an info mapping; it declares no reward or
    done fields.
    """

    observation: EmailObservation
    # Metadata with string keys and scalar values (str, int, float or bool).
    info: dict[str, str | int | float | bool]