"""Deterministic graders and reward helpers for SupportDesk.""" from __future__ import annotations import re from dataclasses import dataclass from models import SupportCaseProgress from tasks import SupportTaskSpec, get_task STRICT_SCORE_EPSILON = 0.01 @dataclass(frozen=True) class GradeBreakdown: """A scored view of how close a case is to the gold solution.""" total_score: float queue_score: float priority_score: float issue_type_score: float requested_fields_score: float reply_score: float note_score: float status_score: float resolution_score: float completed_milestones: tuple[str, ...] def _normalize(text: str | None) -> str: if not text: return "" normalized = text.lower().replace("-", " ") return re.sub(r"[^a-z0-9\s]", " ", normalized) def _marker_group_score(text: str | None, marker_groups: tuple[tuple[str, ...], ...]) -> float: if not marker_groups: return 1.0 normalized = _normalize(text) if not normalized: return 0.0 matches = 0 for group in marker_groups: if any(_normalize(marker) in normalized for marker in group): matches += 1 return matches / len(marker_groups) def _requested_fields_score(case: SupportCaseProgress, task: SupportTaskSpec) -> float: required = set(task.required_requested_fields) requested = set(case.requested_fields) if not required: return 1.0 if not requested else 0.0 if not requested: return 0.0 matched = len(required.intersection(requested)) extras = len(requested.difference(required)) raw = matched / len(required) penalty = min(0.25, extras * 0.05) return max(0.0, raw - penalty) def _reply_penalty(case: SupportCaseProgress, task: SupportTaskSpec) -> float: text = _normalize(case.reply) if not text: return 0.0 return 0.0 if not any(_normalize(marker) in text for marker in task.forbidden_reply_markers) else 0.5 def _strict_open_unit_interval(score: float) -> float: """Keep final task scores strictly within (0, 1) for evaluator compatibility.""" return min(1.0 - STRICT_SCORE_EPSILON, max(STRICT_SCORE_EPSILON, score)) def grade_case(task: SupportTaskSpec, case: SupportCaseProgress) -> GradeBreakdown: """Score a case deterministically with total_score strictly inside (0, 1).""" queue_score = 1.0 if case.queue == task.gold_queue else 0.0 priority_score = 1.0 if case.priority == task.gold_priority else 0.0 issue_type_score = 1.0 if case.issue_type == task.gold_issue_type else 0.0 requested_fields_score = _requested_fields_score(case, task) reply_score = max(0.0, _marker_group_score(case.reply, task.required_reply_markers) - _reply_penalty(case, task)) note_score = _marker_group_score(case.internal_note, task.required_note_markers) status_score = 1.0 if case.status == task.gold_status else 0.0 resolution_score = 1.0 if case.resolution_code == task.gold_resolution_code else 0.0 weighted_total = ( queue_score * 0.15 + priority_score * 0.10 + issue_type_score * 0.10 + requested_fields_score * 0.15 + reply_score * 0.25 + note_score * 0.10 + status_score * 0.10 + resolution_score * 0.05 ) milestones: list[str] = [] if queue_score: milestones.append("queue") if priority_score: milestones.append("priority") if issue_type_score: milestones.append("issue_type") if requested_fields_score >= 0.99: milestones.append("requested_fields") if reply_score >= 0.99: milestones.append("reply") if note_score >= 0.99: milestones.append("internal_note") if status_score: milestones.append("status") if resolution_score: milestones.append("resolution_code") return GradeBreakdown( total_score=round(_strict_open_unit_interval(weighted_total), 4), queue_score=queue_score, priority_score=priority_score, issue_type_score=issue_type_score, requested_fields_score=round(requested_fields_score, 4), reply_score=round(reply_score, 4), note_score=round(note_score, 4), status_score=status_score, resolution_score=resolution_score, completed_milestones=tuple(milestones), ) def grade_task_id(task_id: str, case: SupportCaseProgress) -> GradeBreakdown: """Convenience wrapper used by tests and evaluation scripts.""" return grade_case(get_task(task_id), case) class _TaskSpecificGrader: """Importable task-specific grader wrapper for validator task discovery.""" task_id: str = "" def grade(self, case: SupportCaseProgress) -> float: return grade_task_id(self.task_id, case).total_score def __call__(self, case: SupportCaseProgress) -> float: return self.grade(case) class BillingRefundEasyGrader(_TaskSpecificGrader): task_id = "billing_refund_easy" class AccountTakeoverMediumGrader(_TaskSpecificGrader): task_id = "account_takeover_medium" class ApiIncidentHardGrader(_TaskSpecificGrader): task_id = "api_incident_hard" class RegulatedExportExceptionHardGrader(_TaskSpecificGrader): task_id = "regulated_export_exception_hard" __all__ = [ "AccountTakeoverMediumGrader", "ApiIncidentHardGrader", "BillingRefundEasyGrader", "GradeBreakdown", "RegulatedExportExceptionHardGrader", "grade_case", "grade_task_id", ]