Spaces:
Sleeping
Sleeping
| """ | |
| tests.py — Unit tests for the Email Triage environment. | |
| Run with: python tests.py | |
| """ | |
| import sys | |
| from environment import ( | |
| EmailTriageEnv, | |
| Action, | |
| grade_task1, | |
| grade_task2, | |
| InboxState, | |
| Email, | |
| TASK1_GROUND_TRUTH, | |
| TASK1_EMAILS | |
| ) | |
def run_test(name: str, fn) -> bool:
    """Run one zero-argument test callable and print a one-line verdict.

    A check mark is printed when ``fn`` returns normally. A cross plus the
    assertion message is printed on ``AssertionError``. An explosion marker
    plus the exception type and message is printed on any other exception.

    Returns True when ``fn`` (and the success print) raised nothing, else False.
    """
    try:
        fn()
        print(f" ✅ {name}")
    except AssertionError as failure:
        print(f" ❌ {name}: {failure}")
        return False
    except Exception as crash:
        crash_kind = type(crash).__name__
        print(f" 💥 {name}: {crash_kind}: {crash}")
        return False
    return True
| # --------------------------------------------------------------------------- | |
| # Task 1 tests | |
| # --------------------------------------------------------------------------- | |
def test_task1_reset():
    """Resetting task 1 should report status 'ok' and a five-email inbox."""
    env = EmailTriageEnv(task=1)
    obs = env.reset()
    # Explicit messages matter: run_test prints str(e), and a bare assert has none.
    assert obs.status == "ok", f"Expected status 'ok', got {obs.status!r}"
    inbox_size = obs.data["inbox_size"]
    assert inbox_size == 5, f"Expected inbox_size 5, got {inbox_size}"
def test_task1_list():
    """The ``list_inbox`` action on task 1 should succeed and return five emails."""
    env = EmailTriageEnv(task=1)
    env.reset()
    obs = env.step(Action(action="list_inbox")).observation
    assert obs.status == "ok", f"Expected status 'ok', got {obs.status!r}"
    emails = obs.data["emails"]
    assert len(emails) == 5, f"Expected 5 emails, got {len(emails)}"
def test_task1_read():
    """Reading email t1_001 should succeed and return a non-empty subject."""
    env = EmailTriageEnv(task=1)
    env.reset()
    obs = env.step(Action(action="read", email_id="t1_001")).observation
    assert obs.status == "ok", f"Expected status 'ok', got {obs.status!r}"
    subject = obs.data["subject"]
    assert len(subject) > 0, f"Expected a non-empty subject, got {subject!r}"
def test_task1_label_correct():
    """Labelling t1_001 with its ground-truth priority should yield a 0.2 reward."""
    env = EmailTriageEnv(task=1)
    env.reset()
    gt = TASK1_GROUND_TRUTH["t1_001"]
    result = env.step(Action(action="label", email_id="t1_001", priority=gt))
    # The reward is a float, so compare with a tolerance instead of exact ==.
    assert abs(result.reward - 0.2) < 1e-9, f"Expected 0.2, got {result.reward}"
def test_task1_label_wrong():
    """Labelling t1_001 with a priority other than its ground truth earns no reward."""
    env = EmailTriageEnv(task=1)
    env.reset()
    gt = TASK1_GROUND_TRUTH["t1_001"]
    # Pick a priority guaranteed to differ from the ground truth.
    wrong = "low" if gt in ("urgent", "normal") else "urgent"
    result = env.step(Action(action="label", email_id="t1_001", priority=wrong))
    assert abs(result.reward) < 1e-9, (
        f"Expected 0.0 for wrong label {wrong!r} (truth {gt!r}), got {result.reward}"
    )
def test_task1_full_score():
    """Labelling every task-1 email correctly should give a total score of 1.0."""
    env = EmailTriageEnv(task=1)
    env.reset()
    for eid, priority in TASK1_GROUND_TRUTH.items():
        env.step(Action(action="label", email_id=eid, priority=priority))
    # Read the score once, so the assertion and its message see the same value.
    score = env.score()
    # The score may be a sum of float rewards, so use a tolerance.
    assert abs(score - 1.0) < 1e-9, f"Expected 1.0, got {score}"
def test_task1_partial_score():
    """Correctly labelling two task-1 emails should give a partial score of 0.4."""
    env = EmailTriageEnv(task=1)
    env.reset()
    for eid in list(TASK1_GROUND_TRUTH)[:2]:
        env.step(Action(action="label", email_id=eid, priority=TASK1_GROUND_TRUTH[eid]))
    score = env.score()
    # The score may be a sum of float rewards, so use a tolerance.
    assert abs(score - 0.4) < 1e-9, f"Expected 0.4, got {score}"
| # --------------------------------------------------------------------------- | |
| # Task 2 tests | |
| # --------------------------------------------------------------------------- | |
def test_task2_reset():
    """Resetting task 2 should expose a single-email inbox."""
    env = EmailTriageEnv(task=2)
    obs = env.reset()
    inbox_size = obs.data["inbox_size"]
    assert inbox_size == 1, f"Expected inbox_size 1, got {inbox_size}"
def test_task2_no_reply_zero():
    """Before any reply is drafted, the task-2 score should be 0.0."""
    env = EmailTriageEnv(task=2)
    env.reset()
    score = env.score()
    assert abs(score) < 1e-9, f"Expected 0.0 with no reply drafted, got {score}"
def test_task2_good_reply():
    """A thorough, apologetic, action-oriented reply to t2_001 should score above 0.5."""
    # The reply addresses the sender, apologises, and commits to concrete timelines
    # (2-hour update, 48-hour refund) and a compensation follow-up.
    reply_body = (
        "Dear Jamie,\n\nThank you for reaching out. We sincerely apologize for the "
        "experience you have had with order #48291. We understand how frustrating "
        "this must be.\n\nWe are urgently investigating the status of your delivery "
        "and will provide an update within 2 hours. If we cannot confirm delivery "
        "within 48 hours we will process a full refund immediately. We will also "
        "review the service failures you experienced and follow up regarding "
        "compensation.\n\nWe truly value your business and are committed to "
        "making this right.\n\nSincerely,\nCustomer Support Team"
    )
    env = EmailTriageEnv(task=2)
    env.reset()
    draft = Action(action="draft_reply", email_id="t2_001", body=reply_body)
    env.step(draft)
    score = env.score()
    assert score > 0.5, f"Expected score > 0.5, got {score}"
def test_task2_short_reply_penalised():
    """Drafting a two-character reply ("ok") should be rejected with status 'error'."""
    env = EmailTriageEnv(task=2)
    env.reset()
    result = env.step(Action(action="draft_reply", email_id="t2_001", body="ok"))
    status = result.observation.status
    assert status == "error", f"Expected status 'error' for a too-short reply, got {status!r}"
| # --------------------------------------------------------------------------- | |
| # Task 3 tests | |
| # --------------------------------------------------------------------------- | |
def test_task3_reset():
    """Resetting task 3 should expose a ten-email inbox."""
    env = EmailTriageEnv(task=3)
    obs = env.reset()
    inbox_size = obs.data["inbox_size"]
    assert inbox_size == 10, f"Expected inbox_size 10, got {inbox_size}"
def test_task3_archive_spam_no_penalty():
    """Archiving t3_002 after labelling it 'low' should return status 'ok' (no warning)."""
    env = EmailTriageEnv(task=3)
    env.reset()
    # Label spam as low first (so archiving doesn't trigger urgent penalty)
    env.step(Action(action="label", email_id="t3_002", priority="low"))
    result = env.step(Action(action="archive", email_id="t3_002"))
    status = result.observation.status
    assert status == "ok", f"Expected status 'ok', got {status!r}"
def test_task3_archive_urgent_penalty():
    """Archiving an email labelled 'urgent' should cost -0.1 and return a warning."""
    env = EmailTriageEnv(task=3)
    env.reset()
    env.step(Action(action="label", email_id="t3_001", priority="urgent"))
    result = env.step(Action(action="archive", email_id="t3_001"))
    # The reward is a float, so compare with a tolerance instead of exact ==.
    assert abs(result.reward + 0.1) < 1e-9, f"Expected reward -0.1, got {result.reward}"
    status = result.observation.status
    assert status == "warning", f"Expected status 'warning', got {status!r}"
def test_task3_flag():
    """Flagging t3_009 with a reason should be accepted with status 'ok'."""
    env = EmailTriageEnv(task=3)
    env.reset()
    result = env.step(Action(action="flag", email_id="t3_009", reason="Missing context — need sender identity"))
    status = result.observation.status
    assert status == "ok", f"Expected status 'ok', got {status!r}"
def test_task3_loop_detection():
    """Repeating an identical label action should record at least one loop penalty."""
    env = EmailTriageEnv(task=3)
    env.reset()
    # Issue the exact same action three times in a row.
    for _ in range(3):
        env.step(Action(action="label", email_id="t3_006", priority="normal"))
    # NOTE(review): relies on the private _penalties counter of the environment.
    loop_penalties = env._penalties["loop_actions"]
    assert loop_penalties >= 1, f"Expected at least 1 loop action penalty, got {loop_penalties}"
def test_task3_not_found():
    """Reading an unknown email id should return status 'error'."""
    env = EmailTriageEnv(task=3)
    env.reset()
    result = env.step(Action(action="read", email_id="nonexistent"))
    status = result.observation.status
    assert status == "error", f"Expected status 'error' for unknown id, got {status!r}"
| # --------------------------------------------------------------------------- | |
| # Runner | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Each entry pairs the label printed by run_test with the test callable;
    # tests run in list order.
    task1_cases = [
        ("Task1 reset", test_task1_reset),
        ("Task1 list inbox", test_task1_list),
        ("Task1 read email", test_task1_read),
        ("Task1 correct label reward", test_task1_label_correct),
        ("Task1 wrong label no reward", test_task1_label_wrong),
        ("Task1 full score 1.0", test_task1_full_score),
        ("Task1 partial score 0.4", test_task1_partial_score),
    ]
    task2_cases = [
        ("Task2 reset", test_task2_reset),
        ("Task2 no reply = 0.0", test_task2_no_reply_zero),
        ("Task2 good reply > 0.5", test_task2_good_reply),
        ("Task2 short reply error", test_task2_short_reply_penalised),
    ]
    task3_cases = [
        ("Task3 reset", test_task3_reset),
        ("Task3 archive spam no penalty", test_task3_archive_spam_no_penalty),
        ("Task3 archive urgent = penalty", test_task3_archive_urgent_penalty),
        ("Task3 flag ambiguous", test_task3_flag),
        ("Task3 loop detection", test_task3_loop_detection),
        ("Task3 not found error", test_task3_not_found),
    ]
    tests = task1_cases + task2_cases + task3_cases

    print("\nRunning Email Triage Environment Tests")
    print("=" * 45)
    outcomes = [run_test(label, case) for label, case in tests]
    passed = sum(outcomes)
    total = len(outcomes)
    print(f"\n{passed}/{total} tests passed")
    # Exit status 0 only when every test passed; 1 otherwise.
    exit_code = 0 if passed == total else 1
    sys.exit(exit_code)