Spaces:
Sleeping
Sleeping
| """ | |
| OpenEnv-compliant typed Pydantic models for the Email Triage environment. | |
| Three core model types required by the OpenEnv spec: | |
| - EmailAction : what the agent sends each step | |
| - EmailObservation : what the agent receives each step | |
| - EmailState : episode-level metadata | |
| Plus helpers: | |
| - EmailReward : detailed reward breakdown per step | |
| - StepResult : full result returned by step() | |
| - ResetRequest : request body for the /reset endpoint | |
| """ | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field, field_validator | |
# ─────────────────────────────────────────────────────────────────────────────
# ACTION — what the agent sends on every step
# ─────────────────────────────────────────────────────────────────────────────

# Accepted values for EmailAction.label (checked by EmailAction.validate_label).
VALID_LABELS = {
    "spam",
    "inbox",
    "urgent",
    "archive",
    "delete",
}

# Accepted values for EmailAction.priority (checked by EmailAction.validate_priority).
VALID_PRIORITIES = {
    "high",
    "medium",
    "low",
}

# Accepted values for EmailAction.category (checked by EmailAction.validate_category).
VALID_CATEGORIES = {
    "spam",
    "work",
    "personal",
    "newsletter",
    "notification",
    "social",
}
def _normalize_choice(value: str, allowed: set, field_name: str) -> str:
    """
    Lower-case and strip ``value``, then check it against ``allowed``.

    Args:
        value: Raw string supplied for the field.
        allowed: The accepted values, compared after normalization.
        field_name: Field name used in the error message.

    Returns:
        The normalized (lower-cased, stripped) value.

    Raises:
        ValueError: If the normalized value is not in ``allowed``. Raised
            inside a pydantic validator, this surfaces as a ValidationError.
    """
    normalized = value.lower().strip()
    if normalized not in allowed:
        raise ValueError(f"{field_name} must be one of {allowed}, got '{normalized}'")
    return normalized


class EmailAction(BaseModel):
    """
    The action an agent takes after reading one email.

    Task 1 (easy)   — only `label` is scored (spam | inbox)
    Task 2 (medium) — `label` + `priority` are scored
    Task 3 (hard)   — all three fields are scored

    Each field is lower-cased and stripped of surrounding whitespace, then
    checked against VALID_LABELS / VALID_PRIORITIES / VALID_CATEGORIES.
    Values outside those sets make construction fail with a ValidationError.
    """

    label: str = Field(
        ...,
        description=(
            "How to file this email. "
            "spam=junk/malicious | inbox=needs attention | "
            "urgent=act immediately | archive=save but no action | delete=discard"
        ),
        examples=["spam", "inbox", "urgent", "archive", "delete"],
    )
    priority: str = Field(
        default="medium",
        description="How time-sensitive is this email. high=act today | medium=act this week | low=whenever",
        examples=["high", "medium", "low"],
    )
    category: str = Field(
        default="work",
        description=(
            "What type of email this is. "
            "spam | work | personal | newsletter | notification | social"
        ),
        examples=["work", "personal", "newsletter", "notification", "social", "spam"],
    )

    # These validators must be registered with @field_validator. Without the
    # decorator, pydantic treats them as ordinary methods and never calls
    # them, so unnormalized or invalid values would be accepted silently.
    @field_validator("label")
    @classmethod
    def validate_label(cls, v: str) -> str:
        """Normalize ``label`` and reject values outside VALID_LABELS."""
        return _normalize_choice(v, VALID_LABELS, "label")

    @field_validator("priority")
    @classmethod
    def validate_priority(cls, v: str) -> str:
        """Normalize ``priority`` and reject values outside VALID_PRIORITIES."""
        return _normalize_choice(v, VALID_PRIORITIES, "priority")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Normalize ``category`` and reject values outside VALID_CATEGORIES."""
        return _normalize_choice(v, VALID_CATEGORIES, "category")
# ─────────────────────────────────────────────────────────────────────────────
# OBSERVATION — what the agent receives each step
# ─────────────────────────────────────────────────────────────────────────────
class EmailObservation(BaseModel):
    """
    Snapshot handed to the agent by reset() and by every step().

    Bundles the email currently awaiting classification with counters that
    describe how far the episode has progressed.
    """

    # ---- The email awaiting classification (all required) ----
    email_id: str = Field(default=..., description="Unique identifier for this email")
    subject: str = Field(default=..., description="Email subject line")
    sender: str = Field(default=..., description="Sender email address")
    body: str = Field(default=..., description="Full email body text")
    timestamp: str = Field(
        default=...,
        description="When the email arrived (YYYY-MM-DD HH:MM:SS)",
    )

    # ---- Progress counters (all required) ----
    step: int = Field(default=..., description="Current step number (0-indexed)")
    total_emails: int = Field(
        default=...,
        description="Total number of emails in this episode",
    )
    emails_remaining: int = Field(
        default=...,
        description="Emails still to be processed (including current)",
    )

    # ---- Reward bookkeeping (both default to 0.0) ----
    reward: float = Field(
        default=0.0,
        description="Reward for the LAST action (0.0 on first obs)",
    )
    cumulative_reward: float = Field(
        default=0.0,
        description="Total reward accumulated this episode",
    )

    # ---- Episode termination ----
    done: bool = Field(
        default=False,
        description="True when all emails have been processed",
    )

    # ---- Free-form extras; default_factory gives each instance its own dict ----
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra debugging info",
    )
# ─────────────────────────────────────────────────────────────────────────────
# REWARD — detailed breakdown of how a step was scored
# ─────────────────────────────────────────────────────────────────────────────
class EmailReward(BaseModel):
    """
    Per-step scoring breakdown.

    Returned in StepResult.info — splits the step's total reward into
    per-field scores plus human-readable feedback and any penalties.
    """

    # Overall reward for the step (required).
    value: float = Field(default=..., description="Total reward for this step [0.0, 1.0]")

    # Per-field components; each defaults to 0.0.
    label_score: float = Field(
        default=0.0,
        description="Score for the label field [0.0, 1.0]",
    )
    priority_score: float = Field(
        default=0.0,
        description="Score for the priority field [0.0, 1.0]",
    )
    category_score: float = Field(
        default=0.0,
        description="Score for the category field [0.0, 1.0]",
    )

    # Explanatory text; empty string / fresh empty list by default.
    feedback: str = Field(default="", description="Human-readable explanation of scoring")
    penalties: List[str] = Field(
        default_factory=list,
        description="Any penalties applied and why",
    )
# ─────────────────────────────────────────────────────────────────────────────
# STATE — episode-level metadata
# ─────────────────────────────────────────────────────────────────────────────
class EmailState(BaseModel):
    """
    Episode summary returned by state().

    Reports progress and scoring for the running episode; it carries no
    email content. Every field is required.
    """

    # ---- Identity of the episode and the task being run ----
    episode_id: str = Field(default=..., description="Unique ID for this episode (UUID)")
    task_id: int = Field(default=..., description="Which task is running (1, 2, or 3)")
    task_name: str = Field(default=..., description="Human-readable task name")
    task_difficulty: str = Field(default=..., description="easy | medium | hard")

    # ---- Progress ----
    step_count: int = Field(default=..., description="Steps taken so far")
    total_emails: int = Field(default=..., description="Total emails in this episode")

    # ---- Scoring ----
    cumulative_reward: float = Field(default=..., description="Total reward accumulated so far")
    score: float = Field(default=..., description="Normalized score so far [0.0, 1.0]")

    # ---- Termination ----
    done: bool = Field(default=..., description="True when episode is complete")
# ─────────────────────────────────────────────────────────────────────────────
# STEP RESULT — full response from step()
# ─────────────────────────────────────────────────────────────────────────────
class StepResult(BaseModel):
    """
    Response payload of the /step endpoint.

    Mirrors the OpenEnv (observation, reward, done, info) tuple as named fields.
    """

    # What the agent sees next.
    observation: EmailObservation
    # Scalar reward for the action just taken.
    reward: float
    # Whether the episode has ended.
    done: bool
    # Auxiliary data; each instance gets its own empty dict by default.
    info: Dict[str, Any] = Field(default_factory=dict)
# ─────────────────────────────────────────────────────────────────────────────
# RESET REQUEST
# ─────────────────────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body for the /reset endpoint."""

    # Selects the task for the new episode; omitted -> task 1, and values
    # outside 1..3 are rejected by the ge/le bounds.
    task_id: int = Field(
        default=1,
        description=(
            "Which task to run. "
            "1=Spam Detection (easy), "
            "2=Priority Triage (medium), "
            "3=Full Triage (hard)"
        ),
        ge=1,
        le=3,
    )