# Meta-Hackathon / tests.py
# Dev Shah
# feat: initial commit for email triage agent
# 96a5caf
"""
tests.py — Unit tests for the Email Triage environment.
Run with: python tests.py
"""
import math
import sys

from environment import (
    EmailTriageEnv,
    Action,
    grade_task1,
    grade_task2,
    InboxState,
    Email,
    TASK1_GROUND_TRUTH,
    TASK1_EMAILS
)
def run_test(name: str, fn):
    """Execute a single test callable and print a one-line verdict.

    Args:
        name: Human-readable label printed next to the verdict.
        fn: Zero-argument callable containing the test body.

    Returns:
        True if ``fn`` finished without raising; False if it raised an
        ``AssertionError`` (reported as a failure) or any other ``Exception``
        (reported as a crash, with the exception type name). Exceptions that
        do not derive from ``Exception`` are not caught here.
    """
    try:
        fn()
    except AssertionError as failure:
        print(f" ❌ {name}: {failure}")
        return False
    except Exception as crash:
        print(f" 💥 {name}: {type(crash).__name__}: {crash}")
        return False
    else:
        print(f" ✅ {name}")
        return True
# ---------------------------------------------------------------------------
# Task 1 tests
# ---------------------------------------------------------------------------
def test_task1_reset():
    """Resetting the Task 1 environment yields an "ok" observation with a 5-email inbox."""
    initial = EmailTriageEnv(task=1).reset()
    assert initial.status == "ok"
    assert initial.data["inbox_size"] == 5
def test_task1_list():
    """The ``list_inbox`` action on Task 1 succeeds and returns five emails."""
    triage = EmailTriageEnv(task=1)
    triage.reset()
    outcome = triage.step(Action(action="list_inbox"))
    assert outcome.observation.status == "ok"
    listed = outcome.observation.data["emails"]
    assert len(listed) == 5
def test_task1_read():
    """Reading email ``t1_001`` succeeds and exposes a non-empty subject."""
    triage = EmailTriageEnv(task=1)
    triage.reset()
    outcome = triage.step(Action(action="read", email_id="t1_001"))
    assert outcome.observation.status == "ok"
    subject = outcome.observation.data["subject"]
    assert len(subject) > 0
def test_task1_label_correct():
    """Labelling ``t1_001`` with its ground-truth priority earns a reward of 0.2."""
    env = EmailTriageEnv(task=1)
    env.reset()
    gt = TASK1_GROUND_TRUTH["t1_001"]
    result = env.step(Action(action="label", email_id="t1_001", priority=gt))
    # Compare with a tolerance: the reward is a float and may be the product
    # of arithmetic inside the environment rather than a literal 0.2.
    assert math.isclose(result.reward, 0.2, abs_tol=1e-9), (
        f"Expected 0.2, got {result.reward}"
    )
def test_task1_label_wrong():
    """Labelling ``t1_001`` with a priority other than its ground truth earns no reward."""
    env = EmailTriageEnv(task=1)
    env.reset()
    gt = TASK1_GROUND_TRUTH["t1_001"]
    # Choose a priority that cannot equal the ground truth: "low" when the
    # truth is "urgent"/"normal", otherwise "urgent".
    wrong = "low" if gt in ("urgent", "normal") else "urgent"
    result = env.step(Action(action="label", email_id="t1_001", priority=wrong))
    # Tolerant float comparison, with a message so the custom runner prints
    # the offending value instead of an empty failure line.
    assert math.isclose(result.reward, 0.0, abs_tol=1e-9), (
        f"Expected 0.0, got {result.reward}"
    )
def test_task1_full_score():
    """Labelling every Task 1 email with its ground-truth priority scores 1.0."""
    env = EmailTriageEnv(task=1)
    env.reset()
    for eid, priority in TASK1_GROUND_TRUTH.items():
        env.step(Action(action="label", email_id=eid, priority=priority))
    # Read the score once so the value checked is the value reported, and
    # compare with a tolerance because the score is a float.
    score = env.score()
    assert math.isclose(score, 1.0, abs_tol=1e-9), f"Expected 1.0, got {score}"
def test_task1_partial_score():
    """Correctly labelling only the first two Task 1 emails scores 0.4."""
    env = EmailTriageEnv(task=1)
    env.reset()
    # Label just the first two ground-truth entries and leave the rest untouched.
    for eid in list(TASK1_GROUND_TRUTH)[:2]:
        env.step(Action(action="label", email_id=eid, priority=TASK1_GROUND_TRUTH[eid]))
    score = env.score()
    # Tolerant comparison: the score is a float and exact equality is brittle.
    assert math.isclose(score, 0.4, abs_tol=1e-9), f"Expected 0.4, got {score}"
# ---------------------------------------------------------------------------
# Task 2 tests
# ---------------------------------------------------------------------------
def test_task2_reset():
    """Resetting the Task 2 environment reports an inbox of exactly one email."""
    initial = EmailTriageEnv(task=2).reset()
    inbox_size = initial.data["inbox_size"]
    assert inbox_size == 1
def test_task2_no_reply_zero():
    """A freshly reset Task 2 environment with no reply drafted scores 0.0."""
    untouched = EmailTriageEnv(task=2)
    untouched.reset()
    baseline = untouched.score()
    assert baseline == 0.0
def test_task2_good_reply():
    """Drafting a full apology-and-resolution reply to ``t2_001`` scores above 0.5."""
    # Reply text covering apology, investigation, refund and compensation.
    reply_text = (
        "Dear Jamie,\n\nThank you for reaching out. We sincerely apologize for the "
        "experience you have had with order #48291. We understand how frustrating "
        "this must be.\n\nWe are urgently investigating the status of your delivery "
        "and will provide an update within 2 hours. If we cannot confirm delivery "
        "within 48 hours we will process a full refund immediately. We will also "
        "review the service failures you experienced and follow up regarding "
        "compensation.\n\nWe truly value your business and are committed to "
        "making this right.\n\nSincerely,\nCustomer Support Team"
    )
    support_env = EmailTriageEnv(task=2)
    support_env.reset()
    draft = Action(action="draft_reply", email_id="t2_001", body=reply_text)
    support_env.step(draft)
    score = support_env.score()
    assert score > 0.5, f"Expected score > 0.5, got {score}"
def test_task2_short_reply_penalised():
    """Drafting the two-character reply "ok" to ``t2_001`` yields an "error" observation."""
    support_env = EmailTriageEnv(task=2)
    support_env.reset()
    terse_draft = Action(action="draft_reply", email_id="t2_001", body="ok")
    outcome = support_env.step(terse_draft)
    assert outcome.observation.status == "error"
# ---------------------------------------------------------------------------
# Task 3 tests
# ---------------------------------------------------------------------------
def test_task3_reset():
    """Resetting the Task 3 environment reports an inbox of ten emails."""
    initial = EmailTriageEnv(task=3).reset()
    inbox_size = initial.data["inbox_size"]
    assert inbox_size == 10
def test_task3_archive_spam_no_penalty():
    """Archiving ``t3_002`` after labelling it "low" returns an "ok" observation."""
    inbox_env = EmailTriageEnv(task=3)
    inbox_env.reset()
    # Mark the spam as low priority before archiving, so the archive step
    # does not trigger the urgent-email penalty.
    mark_low = Action(action="label", email_id="t3_002", priority="low")
    inbox_env.step(mark_low)
    outcome = inbox_env.step(Action(action="archive", email_id="t3_002"))
    assert outcome.observation.status == "ok"
def test_task3_archive_urgent_penalty():
    """Archiving ``t3_001`` after labelling it "urgent" costs -0.1 and warns."""
    env = EmailTriageEnv(task=3)
    env.reset()
    env.step(Action(action="label", email_id="t3_001", priority="urgent"))
    result = env.step(Action(action="archive", email_id="t3_001"))
    # Tolerant float comparison; messages make failures readable under the
    # custom runner, which prints only the AssertionError text.
    assert math.isclose(result.reward, -0.1, abs_tol=1e-9), (
        f"Expected -0.1, got {result.reward}"
    )
    assert result.observation.status == "warning", (
        f"Expected status 'warning', got {result.observation.status!r}"
    )
def test_task3_flag():
    """Flagging ``t3_009`` with a reason returns an "ok" observation."""
    inbox_env = EmailTriageEnv(task=3)
    inbox_env.reset()
    flag_request = Action(
        action="flag",
        email_id="t3_009",
        reason="Missing context — need sender identity",
    )
    outcome = inbox_env.step(flag_request)
    assert outcome.observation.status == "ok"
def test_task3_loop_detection():
    """Repeating the same label action three times records at least one loop penalty."""
    inbox_env = EmailTriageEnv(task=3)
    inbox_env.reset()
    # Issue the identical labelling action three times in a row, building a
    # fresh Action object for each step.
    for _attempt in range(3):
        repeated = Action(action="label", email_id="t3_006", priority="normal")
        inbox_env.step(repeated)
    loop_count = inbox_env._penalties["loop_actions"]
    assert loop_count >= 1
def test_task3_not_found():
    """Reading an email id that is not in the inbox returns an "error" observation."""
    inbox_env = EmailTriageEnv(task=3)
    inbox_env.reset()
    lookup = Action(action="read", email_id="nonexistent")
    outcome = inbox_env.step(lookup)
    assert outcome.observation.status == "error"
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Ordered registry of (display label, test callable) pairs.
    tests = [
        # Task 1
        ("Task1 reset", test_task1_reset),
        ("Task1 list inbox", test_task1_list),
        ("Task1 read email", test_task1_read),
        ("Task1 correct label reward", test_task1_label_correct),
        ("Task1 wrong label no reward", test_task1_label_wrong),
        ("Task1 full score 1.0", test_task1_full_score),
        ("Task1 partial score 0.4", test_task1_partial_score),
        # Task 2
        ("Task2 reset", test_task2_reset),
        ("Task2 no reply = 0.0", test_task2_no_reply_zero),
        ("Task2 good reply > 0.5", test_task2_good_reply),
        ("Task2 short reply error", test_task2_short_reply_penalised),
        # Task 3
        ("Task3 reset", test_task3_reset),
        ("Task3 archive spam no penalty", test_task3_archive_spam_no_penalty),
        ("Task3 archive urgent = penalty", test_task3_archive_urgent_penalty),
        ("Task3 flag ambiguous", test_task3_flag),
        ("Task3 loop detection", test_task3_loop_detection),
        ("Task3 not found error", test_task3_not_found),
    ]

    print("\nRunning Email Triage Environment Tests")
    print("=" * 45)

    # Run every registered test in order; run_test returns True on success,
    # so adding its result tallies the passes.
    passed = 0
    for name, fn in tests:
        passed += run_test(name, fn)
    total = len(tests)

    print(f"\n{passed}/{total} tests passed")
    # Exit status 0 only when every test passed, 1 otherwise.
    sys.exit(int(passed != total))