Meta-Hackathon / environment.py
Dev Shah
fix: use Pydantic model_validate instead of unpacking with aliases to resolve Post Reset failure
4ed3421
"""
Email Triage & Response Environment
OpenEnv-compatible environment for agent evaluation.
"""
import json
from typing import Optional, Literal
from pydantic import BaseModel, Field
class Email(BaseModel):
    """A single inbox message together with the triage state applied to it.

    The sender is stored in ``from_`` because ``from`` is a reserved word in
    Python. Input data may use the ``from`` alias, and ``populate_by_name``
    additionally allows the field to be set by its Python name ``from_``.
    """

    # Pydantic v2 configuration. The inner ``class Config`` form is deprecated
    # in v2 (which this file targets via model_validate / model_dump_json), so
    # the setting is declared through ``model_config`` instead.
    model_config = {"populate_by_name": True}

    id: str
    from_: str = Field(..., alias="from")
    subject: str
    body: str
    # Labels attached to the message (the label action writes priorities here).
    labels: list[str] = []
    # Triage flags; all start cleared.
    replied: bool = False
    archived: bool = False
    flagged: bool = False
    # Filled in when the message is flagged / replied to.
    flag_reason: Optional[str] = None
    reply_body: Optional[str] = None
class InboxState(BaseModel):
    """Mutable state of one episode: the inbox, sent replies and a step counter."""

    # Messages currently in the inbox.
    inbox: list[Email]
    # One dict per drafted reply, with "to", "subject" and "body" keys.
    sent: list[dict] = []
    # Number of actions processed so far in this episode.
    step_count: int = 0
class Observation(BaseModel):
    """What the environment reports back to the agent after a reset or a step."""

    # Outcome tag; this file emits "ok", "error", "warning" and "done".
    status: str
    # Human-readable description of what happened.
    message: str
    # Optional structured payload (inbox summaries, a full email, ...).
    data: Optional[dict] = None
    # Step counter of the episode at the time of the observation.
    step_count: int = 0
class Action(BaseModel):
    """A single agent request; only ``action`` is mandatory.

    Which optional fields matter depends on the action: ``email_id`` selects
    the target message, ``priority`` is used by ``label``, ``body`` by
    ``draft_reply`` and ``reason`` by ``flag``.
    """

    # Name of the operation to perform.
    action: Literal["label", "draft_reply", "archive", "flag", "read", "list_inbox"]
    # Target message; not needed for "list_inbox".
    email_id: Optional[str] = None
    # Priority to assign when labelling.
    priority: Optional[Literal["urgent", "normal", "low"]] = None
    # Reply text when drafting a reply.
    body: Optional[str] = None
    # Free-text justification when flagging.
    reason: Optional[str] = None
class StepResult(BaseModel):
    """Return value of a single environment step."""

    # What the agent observes after the action.
    observation: Observation
    # Immediate reward for the action.
    reward: float
    # True once the episode is finished.
    done: bool
    # Free-form extra information; empty by default.
    info: dict = {}
import os

# Load dataset (generated by curate_dataset.py)
DATASET_PATH = os.path.join(os.path.dirname(__file__), "data", "emails.json")

try:
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        _dataset = json.load(f)
except FileNotFoundError:
    # Fallback to empty if not curated yet, though curate_dataset.py should be run first
    _dataset = {
        "task1": {"emails": [], "ground_truth": {}},
        "task2": {"emails": []},
        "task3": {
            "emails": [],
            "ground_truth": {},
            "urgent_ids": [],
            "archive_ids": [],
            "flag_ids": [],
        },
    }


def _task_section(name: str) -> dict:
    """Return the dataset section called *name*, or ``{}`` if it is absent.

    A dataset file that exists but is only partially populated (a task missing
    or set to null) must not make the whole module fail to import.
    """
    section = _dataset.get(name) if isinstance(_dataset, dict) else None
    return section if isinstance(section, dict) else {}


_task1 = _task_section("task1")
_task2_emails = _task_section("task2").get("emails") or []
_task3 = _task_section("task3")

# ``or`` (rather than a .get default) also covers keys explicitly set to null.
TASK1_EMAILS = _task1.get("emails") or []
TASK1_GROUND_TRUTH = _task1.get("ground_truth") or {}
# Task 2 works on a single email: the first one listed, or {} when there is none.
TASK2_EMAIL = _task2_emails[0] if _task2_emails else {}
TASK3_EMAILS = _task3.get("emails") or []
TASK3_GROUND_TRUTH = _task3.get("ground_truth") or {}
TASK3_URGENT_IDS = set(_task3.get("urgent_ids") or [])
TASK3_ARCHIVE_IDS = set(_task3.get("archive_ids") or [])
TASK3_FLAG_IDS = set(_task3.get("flag_ids") or [])
def grade_task1(state: InboxState) -> float:
    """Score task 1 from the priority labels currently on the inbox.

    Each email whose ground-truth priority (urgent / normal / low) appears
    among its labels earns 0.2. The total is capped at 1.0 and rounded to two
    decimals. Emails without a ground-truth entry never earn credit.
    """
    known_priorities = ("urgent", "normal", "low")
    total = 0.0
    for message in state.inbox:
        expected = TASK1_GROUND_TRUTH.get(message.id)
        # No ground truth, or a value outside the three priorities: no credit.
        if expected not in known_priorities:
            continue
        if expected in message.labels:
            total += 0.2
    return round(min(total, 1.0), 2)
def grade_task2(state: InboxState) -> float:
    """Heuristically score the reply drafted for email ``t2_001``.

    Returns 0.0 when that email is not in the inbox, has not been replied to,
    or has an empty reply. Otherwise the score is built from keyword checks on
    the lower-cased reply:

    * up to 0.3 -- 0.1 for each topic addressed (the order, delivery/refund,
      apology/compensation);
    * up to 0.3 -- 0.05 per professional phrase found, minus 0.1 per rude
      phrase found;
    * 0.2 -- the reply is longer than 50 characters;
    * 0.2 -- none of the fabrication markers appears.

    The result is clamped to [0.0, 1.0] and rounded to two decimals.
    """
    email = next((e for e in state.inbox if e.id == "t2_001"), None)
    if email is None or not email.replied or not email.reply_body:
        return 0.0

    reply = email.reply_body.lower()

    # Topic coverage, 0.1 each. Mentioning the order is sufficient for the
    # first topic; quoting the order number is not required.
    topics_covered = (
        "order" in reply,
        any(w in reply for w in ("refund", "deliver", "shipment", "track")),
        any(w in reply for w in ("compensat", "apologi", "sorry", "inconvenien")),
    )
    score = 0.1 * sum(topics_covered)  # up to 0.3

    # Tone: professional phrases add up to 0.3, each rude phrase subtracts 0.1.
    professional_signals = ("dear", "sincerely", "regards", "thank you", "we apologize",
                            "we understand", "please", "we will")
    rude_signals = ("whatever", "not our fault", "calm down")
    tone_score = sum(1 for w in professional_signals if w in reply)
    rude_penalty = sum(1 for w in rude_signals if w in reply)
    score += min(0.3, tone_score * 0.05) - (rude_penalty * 0.1)

    # +0.2 for a reply of substantive length (more than 50 characters).
    if len(email.reply_body) > 50:
        score += 0.2

    # +0.2 when no fabricated-fact marker is present (heuristic: no invented
    # amounts, dates or tracking numbers).
    fabrication_signals = ("$", "€", "refund amount", "exact date", "tracking number is")
    if not any(w in reply for w in fabrication_signals):
        score += 0.2

    return round(max(0.0, min(score, 1.0)), 2)
def grade_task3(state: InboxState, penalties: dict) -> float:
    """Score task 3 (full triage) from the inbox state and accumulated penalties.

    Components:

    * priority labels -- 0.2 per correct label, capped at 2.0 and scaled by
      0.25, so at most 0.5;
    * replies to urgent emails -- up to 0.1 each, proportional to reply length
      (full credit at 200 characters), at most 0.4 in total;
    * archiving the designated emails -- 0.05 each, at most 0.1 in total;
    * flagging the designated emails -- 0.05 each;
    * penalties -- 0.1 per ``destructive_actions`` and 0.05 per
      ``loop_actions`` entry in *penalties* (missing keys count as 0).

    The result is clamped to [0.0, 1.0] and rounded to two decimals.
    """
    email_map = {e.id: e for e in state.inbox}

    # Priority labels (0.2 per correct, 10 emails = max 2.0 -> normalise to 0.5 weight)
    label_score = 0.0
    for eid, gt in TASK3_GROUND_TRUTH.items():
        email = email_map.get(eid)
        if email is not None and gt in email.labels:
            label_score += 0.2
    score = min(label_score, 2.0) * 0.25

    # Replies for urgent emails (max 0.4). The cap is enforced explicitly so
    # the component cannot exceed its stated weight if more urgent ids exist.
    reply_score = 0.0
    for eid in TASK3_URGENT_IDS:
        email = email_map.get(eid)
        if email is not None and email.replied and email.reply_body:
            reply_score += min(len(email.reply_body) / 200, 1.0) * 0.1
    score += min(reply_score, 0.4)

    # Archive spam (0.05 each, max 0.1), capped for the same reason.
    archive_score = 0.0
    for eid in TASK3_ARCHIVE_IDS:
        email = email_map.get(eid)
        if email is not None and email.archived:
            archive_score += 0.05
    score += min(archive_score, 0.1)

    # Flag ambiguous (0.05 each)
    for eid in TASK3_FLAG_IDS:
        email = email_map.get(eid)
        if email is not None and email.flagged:
            score += 0.05

    # Penalties
    score -= penalties.get("destructive_actions", 0) * 0.1
    score -= penalties.get("loop_actions", 0) * 0.05

    return round(max(0.0, min(score, 1.0)), 2)
# ---------------------------------------------------------------------------
# Environment Class
# ---------------------------------------------------------------------------
class EmailTriageEnv:
    """OpenEnv-compatible Email Triage environment.

    An instance serves one task (1, 2 or 3). ``reset()`` loads that task's
    inbox, ``step()`` applies one ``Action`` and returns a ``StepResult``,
    ``state()`` exposes the inbox state as plain JSON-compatible data and
    ``score()`` grades the current state with the matching ``grade_task*``
    function.
    """

    TASKS = {1, 2, 3}

    def __init__(self, task: int = 1):
        """Create an environment for *task*.

        Raises:
            AssertionError: if *task* is not one of ``TASKS``.
        """
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type is kept for existing callers.
        if task not in self.TASKS:
            raise AssertionError(f"task must be one of {self.TASKS}")
        self.task = task
        self._state: Optional[InboxState] = None
        self._penalties = {"destructive_actions": 0, "loop_actions": 0}
        self._action_history: list[str] = []
        self._done = False

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a fresh episode and return the initial observation.

        Clears penalties, action history and the done flag, then rebuilds the
        inbox from the dataset for this instance's task.
        """
        self._penalties = {"destructive_actions": 0, "loop_actions": 0}
        self._action_history = []
        self._done = False

        if self.task == 1:
            raw_emails = TASK1_EMAILS
        elif self.task == 2:
            # TASK2_EMAIL is an empty dict when no task-2 email is available;
            # validating that would raise, so fall back to an empty inbox just
            # like tasks 1 and 3 do with an empty dataset.
            raw_emails = [TASK2_EMAIL] if TASK2_EMAIL else []
        else:
            raw_emails = TASK3_EMAILS
        emails = [Email.model_validate(e) for e in raw_emails]

        self._state = InboxState(inbox=emails)
        return Observation(
            status="ok",
            message=f"Task {self.task} environment reset. Inbox contains {len(emails)} email(s).",
            data={"task": self.task, "inbox_size": len(emails)},
            step_count=0,
        )

    def state(self) -> dict:
        """Return the current inbox state as JSON-compatible data (aliases applied).

        Raises:
            AssertionError: if ``reset()`` has not been called yet.
        """
        self._require_reset()
        return json.loads(self._state.model_dump_json(by_alias=True))

    def step(self, action: Action) -> StepResult:
        """Apply *action* and return the resulting observation, reward and done flag.

        Raises:
            AssertionError: if ``reset()`` has not been called yet.
        """
        self._require_reset()
        # NOTE(review): nothing in this class sets ``_done`` to True, so this
        # branch is only reached if something external flips the flag --
        # confirm how episodes are meant to terminate.
        if self._done:
            return StepResult(
                observation=Observation(status="done", message="Episode already finished.", step_count=self._state.step_count),
                reward=0.0,
                done=True,
            )

        self._state.step_count += 1
        action_key = f"{action.action}:{action.email_id}"
        # Loop detection: the third and every later repeat of the same
        # action/email pair counts as a loop.
        if self._action_history.count(action_key) >= 2:
            self._penalties["loop_actions"] += 1
        self._action_history.append(action_key)

        obs, reward = self._dispatch(action)
        obs.step_count = self._state.step_count
        return StepResult(observation=obs, reward=reward, done=self._done)

    def score(self) -> float:
        """Return current cumulative score (0-1).

        Raises:
            AssertionError: if ``reset()`` has not been called yet.
        """
        self._require_reset()
        if self.task == 1:
            return grade_task1(self._state)
        elif self.task == 2:
            return grade_task2(self._state)
        else:
            return grade_task3(self._state, self._penalties)

    # ------------------------------------------------------------------
    # Action dispatch
    # ------------------------------------------------------------------
    def _dispatch(self, action: Action):
        """Route *action* to its handler; returns ``(Observation, reward)``."""
        handlers = {
            "list_inbox": self._act_list_inbox,
            "read": self._act_read,
            "label": self._act_label,
            "draft_reply": self._act_draft_reply,
            "archive": self._act_archive,
            "flag": self._act_flag,
        }
        handler = handlers.get(action.action)
        if handler is None:
            return Observation(status="error", message=f"Unknown action: {action.action}"), 0.0
        return handler(action)

    def _act_list_inbox(self, action: Action):
        """Return a summary (no body) of every email in the inbox."""
        summaries = [
            {"id": e.id, "from": e.from_, "subject": e.subject,
             "labels": e.labels, "replied": e.replied, "archived": e.archived, "flagged": e.flagged}
            for e in self._state.inbox
        ]
        return Observation(status="ok", message="Inbox listed.", data={"emails": summaries}), 0.0

    def _act_read(self, action: Action):
        """Return the full content of the targeted email."""
        email = self._find(action.email_id)
        if email is None:
            return self._not_found(action.email_id)
        return Observation(
            status="ok",
            message=f"Read email {action.email_id}.",
            data=json.loads(email.model_dump_json(by_alias=True)),
        ), 0.0

    def _act_label(self, action: Action):
        """Replace the email's priority label with ``action.priority``."""
        email = self._find(action.email_id)
        if email is None:
            return self._not_found(action.email_id)
        if action.priority not in ("urgent", "normal", "low"):
            return Observation(status="error", message="priority must be urgent | normal | low"), 0.0
        # Remove existing priority labels then add new
        email.labels = [label for label in email.labels if label not in ("urgent", "normal", "low")]
        email.labels.append(action.priority)
        incremental = self._incremental_label_reward(email.id, action.priority)
        return Observation(
            status="ok",
            message=f"Labelled {action.email_id} as {action.priority}.",
            data={"email_id": action.email_id, "priority": action.priority},
        ), incremental

    def _act_draft_reply(self, action: Action):
        """Record a reply on the email and in the sent list.

        Bodies shorter than 10 characters after stripping are rejected.
        """
        email = self._find(action.email_id)
        if email is None:
            return self._not_found(action.email_id)
        if not action.body or len(action.body.strip()) < 10:
            return Observation(status="error", message="Reply body too short."), 0.0
        email.replied = True
        email.reply_body = action.body
        self._state.sent.append({"to": email.from_, "subject": f"Re: {email.subject}", "body": action.body})
        return Observation(status="ok", message=f"Reply drafted for {action.email_id}."), 0.0

    def _act_archive(self, action: Action):
        """Archive the email, or penalise the attempt if it is labelled urgent."""
        email = self._find(action.email_id)
        if email is None:
            return self._not_found(action.email_id)
        # Penalty if archiving urgent email
        if "urgent" in email.labels:
            self._penalties["destructive_actions"] += 1
            # NOTE(review): the message says "Archived", but ``email.archived``
            # is not set on this path -- confirm whether the email should
            # actually be archived or the wording should change.
            return Observation(
                status="warning",
                message=f"Archived urgent email {action.email_id} — penalty applied.",
            ), -0.1
        email.archived = True
        return Observation(status="ok", message=f"Email {action.email_id} archived."), 0.0

    def _act_flag(self, action: Action):
        """Flag the email for human review, recording the reason if given."""
        email = self._find(action.email_id)
        if email is None:
            return self._not_found(action.email_id)
        email.flagged = True
        email.flag_reason = action.reason or "unspecified"
        return Observation(status="ok", message=f"Email {action.email_id} flagged for human review."), 0.0

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _require_reset(self) -> None:
        """Raise AssertionError unless ``reset()`` has initialised the state.

        An explicit raise is used so the guard survives ``python -O``, which
        strips ``assert`` statements.
        """
        if self._state is None:
            raise AssertionError("Call reset() first.")

    def _not_found(self, email_id: Optional[str]):
        """Build the zero-reward error response for an unknown email id."""
        return Observation(status="error", message=f"Email {email_id} not found."), 0.0

    def _find(self, email_id: Optional[str]) -> Optional[Email]:
        """Return the inbox email with the given id, or None if there is none."""
        if email_id is None:
            return None
        return next((e for e in self._state.inbox if e.id == email_id), None)

    def _incremental_label_reward(self, email_id: str, priority: str) -> float:
        """Return +0.2 if label matches ground truth for task 1.

        Other tasks always get 0.0.

        NOTE(review): the reward is granted on every matching label action,
        including re-labelling an email that is already correct -- confirm
        that repeated rewards are intended.
        """
        if self.task == 1:
            gt = TASK1_GROUND_TRUTH.get(email_id)
            return 0.2 if gt == priority else 0.0
        return 0.0