"""Task-specific graders for the SupportDesk environment."""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Optional

from models import SupportCaseProgress, SupportDeskObservation


def _as_case(obj) -> SupportCaseProgress:
    """Normalize observation/state/case to SupportCaseProgress.

    Accepts a ``SupportCaseProgress`` directly, or any object exposing a
    ``case`` attribute (e.g. an observation or environment state).

    Raises:
        TypeError: if ``obj`` is neither of the above.
    """
    if isinstance(obj, SupportCaseProgress):
        return obj
    if hasattr(obj, "case"):
        return obj.case  # type: ignore[attr-defined]
    raise TypeError(f"Unsupported object for grading: {type(obj)}")


@dataclass
class GradeBreakdown:
    """Result of grading one case: final score, label, and applied penalties."""

    score: float
    message: str
    # Penalty name -> amount subtracted from the grader's base score.
    penalties: dict[str, float]
    completed_milestones: list[str] = field(default_factory=list)

    @property
    def total_score(self) -> float:
        """Alias for ``score``."""
        return self.score

    def __post_init__(self) -> None:
        # Callers may still pass ``completed_milestones=None`` explicitly;
        # normalize it to an empty list as before.
        if self.completed_milestones is None:
            self.completed_milestones = []


def _clamp(v: float) -> float:
    """Clamp ``v`` into the closed interval [0.01, 0.99]."""
    return max(0.01, min(0.99, v))


def grade_task_id(task_id: str, observation: SupportDeskObservation | SupportCaseProgress) -> GradeBreakdown:
    """Grade ``observation`` with the grader registered for ``task_id``.

    Unknown task ids yield a fixed 0.01 score with an ``unknown_task`` penalty.
    The observation is normalized first, so an unsupported object raises
    ``TypeError`` even for an unknown task id.
    """
    case = _as_case(observation)
    grader_cls: Optional[type[_PenaltyGrader]] = _GRADERS.get(task_id)
    if grader_cls is None:
        return GradeBreakdown(0.01, "Unknown task", {"unknown_task": 1.0})
    return grader_cls().score(case)


def grade_case(task_or_id, observation) -> GradeBreakdown:
    """Return a GradeBreakdown for the given task and case/observation.

    ``task_or_id`` may be an object with a ``task_id`` attribute or anything
    whose ``str()`` is the task id.
    """
    task_id = task_or_id.task_id if hasattr(task_or_id, "task_id") else str(task_or_id)
    case = _as_case(observation)
    return grade_task_id(task_id, case)


class _PenaltyGrader:
    """Shared scaffold: start at BASE_SCORE, subtract penalties, round, clamp."""

    BASE_SCORE: float = 1.0
    MESSAGE: str = ""

    def _penalties(self, case: SupportCaseProgress) -> dict[str, float]:
        """Return the penalties that apply to ``case`` (name -> amount)."""
        raise NotImplementedError

    def score(self, case: SupportCaseProgress) -> GradeBreakdown:
        """Grade ``case`` and return the full breakdown."""
        penalties = self._penalties(case)
        raw = round(self.BASE_SCORE - sum(penalties.values()), 2)
        return GradeBreakdown(_clamp(raw), self.MESSAGE, penalties)

    def grade(self, case: SupportCaseProgress) -> float:
        """Grade ``case`` and return only the clamped numeric score."""
        return self.score(case).score


class BillingRefundEasyGrader(_PenaltyGrader):
    """Checks the reply mentions a refund and the note mentions a duplicate."""

    BASE_SCORE = 1.0
    MESSAGE = "Billing refund evaluation"

    def _penalties(self, case: SupportCaseProgress) -> dict[str, float]:
        penalties: dict[str, float] = {}
        reply = (case.reply or "").lower()
        note = (case.internal_note or "").lower()
        if not reply:
            penalties["no_reply"] = 0.55
        elif "refund" not in reply:
            penalties["missing_refund"] = 0.25
        if not note:
            penalties["no_note"] = 0.2
        elif "duplicate" not in note:
            penalties["note_missing_duplicate"] = 0.2
        if case.status != "resolved":
            penalties["status_not_resolved"] = 0.1
        return penalties


class AccountTakeoverMediumGrader(_PenaltyGrader):
    """Checks the reply covers locking, verification and ownership."""

    # NOTE(review): starting from 0.2 means a reply with no penalties scores
    # only 0.2, and most single misses clamp to 0.01 (the easy grader starts
    # from 1.0). Kept unchanged; confirm this cap is intended.
    BASE_SCORE = 0.2
    MESSAGE = "Account takeover evaluation"

    def _penalties(self, case: SupportCaseProgress) -> dict[str, float]:
        penalties: dict[str, float] = {}
        reply = (case.reply or "").lower()
        if reply:
            if "lock" not in reply:
                penalties["missing_lock"] = 0.2
            if "verify" not in reply:
                penalties["missing_verify"] = 0.2
            if "ownership" not in reply:
                penalties["missing_ownership"] = 0.2
        else:
            penalties["no_reply"] = 0.4
        if case.status not in ("escalated", "waiting_on_customer"):
            penalties["wrong_status"] = 0.2
        return penalties


# A standalone "id"/"ids"/"identifier(s)" token (also inside "request_id" or
# "request-id"). The bare substring "id" matched words like "provide" or
# "did", so the request-ID check almost never fired.
_ID_TOKEN = re.compile(r"(?<![a-z])(?:ids?|identifiers?)(?![a-z])")


class ApiIncidentHardGrader(_PenaltyGrader):
    """Checks the reply cites the status page, asks for request IDs, escalates."""

    # NOTE(review): base score of 0.2 caps a perfect reply at 0.2; confirm.
    BASE_SCORE = 0.2
    MESSAGE = "API incident evaluation"

    def _penalties(self, case: SupportCaseProgress) -> dict[str, float]:
        penalties: dict[str, float] = {}
        reply = (case.reply or "").lower()
        if reply:
            if "status" not in reply:
                penalties["missing_status_page"] = 0.15
            if "request" not in reply or not _ID_TOKEN.search(reply):
                penalties["missing_request_ids"] = 0.2
            # "escalat" covers escalate / escalated / escalation / escalating.
            if "escalat" not in reply:
                penalties["missing_escalation"] = 0.2
        else:
            penalties["no_reply"] = 0.4
        if case.queue != "platform_engineering":
            penalties["wrong_queue"] = 0.15
        return penalties


# Phrases that count as explicitly declining to promise an outcome. "not
# promise" already covers "cannot promise" and "can not promise"; the
# apostrophe forms (straight and typographic) were previously penalized.
_NO_PROMISE_PHRASES = ("cannot promise", "not promise", "can't promise", "can\u2019t promise")


class RegulatedExportExceptionHardGrader(_PenaltyGrader):
    """Checks the reply routes to compliance, makes no promise, asks for recipient identity."""

    # NOTE(review): base score of 0.2 caps a perfect reply at 0.2; confirm.
    BASE_SCORE = 0.2
    MESSAGE = "Regulated export evaluation"

    def _penalties(self, case: SupportCaseProgress) -> dict[str, float]:
        penalties: dict[str, float] = {}
        reply = (case.reply or "").lower()
        if reply:
            if "compliance" not in reply:
                penalties["missing_compliance"] = 0.2
            if not any(phrase in reply for phrase in _NO_PROMISE_PHRASES):
                penalties["missing_no_promise"] = 0.2
            if "recipient" not in reply or "identity" not in reply:
                penalties["missing_recipient"] = 0.15
        else:
            penalties["no_reply"] = 0.4
        if case.status != "waiting_on_customer":
            penalties["wrong_status"] = 0.15
        return penalties


# Task id -> grader class. Defined after the classes; grade_task_id only
# looks it up at call time.
_GRADERS: dict[str, type[_PenaltyGrader]] = {
    "billing_refund_easy": BillingRefundEasyGrader,
    "account_takeover_medium": AccountTakeoverMediumGrader,
    "api_incident_hard": ApiIncidentHardGrader,
    "regulated_export_exception_hard": RegulatedExportExceptionHardGrader,
}