email-triage / models.py
rishi38's picture
Email Triage OpenEnv - hackathon submission
63e46ae
"""
OpenEnv-compliant typed Pydantic models for the Email Triage environment.
Three core model types required by the OpenEnv spec:
- EmailAction : what the agent sends each step
- EmailObservation : what the agent receives each step
- EmailState : episode-level metadata
Plus helpers:
- EmailReward : detailed reward breakdown per step
- StepResult : full result returned by step()
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
# ─────────────────────────────────────────────────────────────────────────────
# ACTION — what the agent sends on every step
# ─────────────────────────────────────────────────────────────────────────────
# Closed vocabularies accepted by the EmailAction validators. Inputs are
# lower-cased and stripped before the membership test, so every entry here
# is stored in lower case.

# Filing decisions for an email.
VALID_LABELS = {
    "spam",
    "inbox",
    "urgent",
    "archive",
    "delete",
}

# Time-sensitivity levels.
VALID_PRIORITIES = {
    "high",
    "medium",
    "low",
}

# Email type buckets.
VALID_CATEGORIES = {
    "spam",
    "work",
    "personal",
    "newsletter",
    "notification",
    "social",
}
def _normalize_choice(field_name: str, value: str, allowed: set) -> str:
    """
    Canonicalise an enum-like string and check it against ``allowed``.

    Args:
        field_name: Field name used as the prefix of the error message.
        value: Raw string supplied for the field.
        allowed: Set of accepted (lower-case) values.

    Returns:
        ``value`` lower-cased with surrounding whitespace removed.

    Raises:
        ValueError: If the normalised value is not in ``allowed``.
    """
    normalized = value.lower().strip()
    if normalized not in allowed:
        # Iteration order of a set of strings can differ between interpreter
        # runs (hash randomisation), so render the choices sorted to keep the
        # message stable for logs and tests. The `{'a', 'b'}` shape of the
        # message is retained.
        choices = "{" + ", ".join(repr(item) for item in sorted(allowed)) + "}"
        raise ValueError(f"{field_name} must be one of {choices}, got '{normalized}'")
    return normalized


class EmailAction(BaseModel):
    """
    The action an agent takes after reading one email.

    Task 1 (easy)   → only `label` is scored (spam | inbox)
    Task 2 (medium) → `label` + `priority` are scored
    Task 3 (hard)   → all three fields are scored

    All three fields are normalised (lower-cased, whitespace-stripped) by
    their validators before being stored, so "  SPAM " is accepted as "spam".
    """

    # Required: the filing decision.
    label: str = Field(
        ...,
        description=(
            "How to file this email. "
            "spam=junk/malicious | inbox=needs attention | "
            "urgent=act immediately | archive=save but no action | delete=discard"
        ),
        examples=["spam", "inbox", "urgent", "archive", "delete"],
    )
    # Optional: defaults to "medium" when the agent omits it.
    priority: str = Field(
        default="medium",
        description="How time-sensitive is this email. high=act today | medium=act this week | low=whenever",
        examples=["high", "medium", "low"],
    )
    # Optional: defaults to "work" when the agent omits it.
    category: str = Field(
        default="work",
        description=(
            "What type of email this is. "
            "spam | work | personal | newsletter | notification | social"
        ),
        examples=["work", "personal", "newsletter", "notification", "social", "spam"],
    )

    @field_validator("label")
    @classmethod
    def validate_label(cls, v: str) -> str:
        """Normalise `label` and reject values outside VALID_LABELS."""
        return _normalize_choice("label", v, VALID_LABELS)

    @field_validator("priority")
    @classmethod
    def validate_priority(cls, v: str) -> str:
        """Normalise `priority` and reject values outside VALID_PRIORITIES."""
        return _normalize_choice("priority", v, VALID_PRIORITIES)

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Normalise `category` and reject values outside VALID_CATEGORIES."""
        return _normalize_choice("category", v, VALID_CATEGORIES)
# ─────────────────────────────────────────────────────────────────────────────
# OBSERVATION — what the agent receives each step
# ─────────────────────────────────────────────────────────────────────────────
class EmailObservation(BaseModel):
    """
    Agent-visible snapshot produced by reset() or step().

    Bundles the email currently awaiting classification with episode
    progress counters, the reward signal, and the terminal flag.
    """

    # ── Email under review ──────────────────────────────────────────────
    email_id: str = Field(
        ...,
        description="Unique identifier for this email",
    )
    subject: str = Field(
        ...,
        description="Email subject line",
    )
    sender: str = Field(
        ...,
        description="Sender email address",
    )
    body: str = Field(
        ...,
        description="Full email body text",
    )
    timestamp: str = Field(
        ...,
        description="When the email arrived (YYYY-MM-DD HH:MM:SS)",
    )

    # ── Progress through the episode ────────────────────────────────────
    step: int = Field(
        ...,
        description="Current step number (0-indexed)",
    )
    total_emails: int = Field(
        ...,
        description="Total number of emails in this episode",
    )
    emails_remaining: int = Field(
        ...,
        description="Emails still to be processed (including current)",
    )

    # ── Reward signal (both default to 0.0) ─────────────────────────────
    reward: float = Field(
        default=0.0,
        description="Reward for the LAST action (0.0 on first obs)",
    )
    cumulative_reward: float = Field(
        default=0.0,
        description="Total reward accumulated this episode",
    )

    # ── Termination ─────────────────────────────────────────────────────
    done: bool = Field(
        default=False,
        description="True when all emails have been processed",
    )

    # ── Free-form extras; each instance gets its own fresh dict ─────────
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Extra debugging info",
    )
# ─────────────────────────────────────────────────────────────────────────────
# REWARD — detailed breakdown of how a step was scored
# ─────────────────────────────────────────────────────────────────────────────
class EmailReward(BaseModel):
    """
    Per-step reward breakdown.

    Intended to travel inside StepResult.info so the agent gets a richer
    learning signal than the scalar reward alone.
    """

    # Only `value` is required; every component score defaults to 0.0.
    value: float = Field(
        ...,
        description="Total reward for this step [0.0, 1.0]",
    )
    label_score: float = Field(
        default=0.0,
        description="Score for the label field [0.0, 1.0]",
    )
    priority_score: float = Field(
        default=0.0,
        description="Score for the priority field [0.0, 1.0]",
    )
    category_score: float = Field(
        default=0.0,
        description="Score for the category field [0.0, 1.0]",
    )

    # Explanatory text for humans / logs.
    feedback: str = Field(
        default="",
        description="Human-readable explanation of scoring",
    )
    # default_factory gives each instance its own list.
    penalties: List[str] = Field(
        default_factory=list,
        description="Any penalties applied and why",
    )
# ─────────────────────────────────────────────────────────────────────────────
# STATE — episode-level metadata
# ─────────────────────────────────────────────────────────────────────────────
class EmailState(BaseModel):
    """
    Episode-level summary returned by state().

    Reports which task is running and how far the episode has progressed;
    it carries no email content. Every field is required.
    """

    # ── Episode / task identity ─────────────────────────────────────────
    episode_id: str = Field(
        ...,
        description="Unique ID for this episode (UUID)",
    )
    task_id: int = Field(
        ...,
        description="Which task is running (1, 2, or 3)",
    )
    task_name: str = Field(
        ...,
        description="Human-readable task name",
    )
    task_difficulty: str = Field(
        ...,
        description="easy | medium | hard",
    )

    # ── Progress counters ───────────────────────────────────────────────
    step_count: int = Field(
        ...,
        description="Steps taken so far",
    )
    total_emails: int = Field(
        ...,
        description="Total emails in this episode",
    )

    # ── Scoring ─────────────────────────────────────────────────────────
    cumulative_reward: float = Field(
        ...,
        description="Total reward accumulated so far",
    )
    score: float = Field(
        ...,
        description="Normalized score so far [0.0, 1.0]",
    )

    # ── Termination ─────────────────────────────────────────────────────
    done: bool = Field(
        ...,
        description="True when episode is complete",
    )
# ─────────────────────────────────────────────────────────────────────────────
# STEP RESULT — full response from step()
# ─────────────────────────────────────────────────────────────────────────────
class StepResult(BaseModel):
    """
    Payload returned by the /step endpoint.

    Follows the OpenEnv (observation, reward, done, info) shape.
    """

    # Next observation shown to the agent.
    observation: EmailObservation

    # Scalar reward for the step.
    reward: float

    # Terminal flag for the episode.
    done: bool

    # Free-form extras; default_factory gives each result its own dict.
    info: Dict[str, Any] = Field(
        default_factory=dict,
    )
# ─────────────────────────────────────────────────────────────────────────────
# RESET REQUEST
# ─────────────────────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Body accepted by the /reset endpoint."""

    # Task selector: defaults to 1 and is constrained to the range 1..3.
    task_id: int = Field(
        default=1,
        description="Which task to run. 1=Spam Detection (easy), 2=Priority Triage (medium), 3=Full Triage (hard)",
        ge=1,
        le=3,
    )