File size: 7,589 Bytes
df97e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
graders.py — Gov Workflow OpenEnv: Deterministic Episode Graders

Rules:
  - All graders read ONLY from EpisodeStateModel flat fields.
  - No access to env internals, EpisodeMetrics, or reward breakdown proxies.
  - GraderResult uses the aligned schema (score, grader_name, named metric fields).
  - grade_episode() dispatches by task_id.

Grader weights:
  Easy   — completion(0.45) + SLA(0.35) + idle_efficiency(0.20)          = 1.00
  Medium — completion(0.35) + SLA(0.30) + doc_rework(0.20) + urgent(0.15) = 1.00
  Hard   — completion(0.28) + SLA(0.24) + doc_rework(0.16)
           + fairness(0.16) + escalation_discipline(0.16)                 = 1.00
"""
from __future__ import annotations
from app.models import EpisodeStateModel, GraderResult


# ─────────────────────────────────────────────────────────────────────────────
# INTERNAL HELPERS
# ─────────────────────────────────────────────────────────────────────────────

def _safe_ratio(num: float, den: float, default: float = 1.0) -> float:
    """
    Divide `num` by `den` and clamp the outcome to [0.0, 1.0].

    When `den` is zero or negative the division is skipped and `default`
    (clamped to the same range) is returned in its place.
    """
    unclamped = default if den <= 0 else num / den
    # Keep the upper bound innermost: the nesting order decides how the
    # built-in min/max treat non-comparable values such as NaN.
    return max(0.0, min(1.0, unclamped))


def _b(value: float) -> float:
    """Convert `value` to float and bound it to [0.0, 1.0]."""
    capped_above = min(1.0, float(value))
    return max(0.0, capped_above)


def _extract(state: EpisodeStateModel) -> dict[str, float]:
    """
    Derive every grader input metric from the flat fields of `state`.

    Returns a dict holding seven scores, each clamped to [0.0, 1.0]
    (completion_rate, sla_compliance, document_rework_quality,
    urgent_served_rate, fairness_score, escalation_discipline,
    idle_efficiency), plus the raw fairness_gap rounded to 4 decimals.

    Design note:
      - total_arrived   : populated by env.state() from metrics.total_arrived
      - fairness_gap    : computed by completion_fairness_gap() in env.state()
      - All other fields are direct EpisodeStateModel attributes.
    """
    # Arrivals are floored at 1 so the per-case ratios below always have a
    # positive denominator.
    arrived = max(1, state.total_arrived)
    completed = float(state.total_completed)
    breaches = float(state.total_sla_breaches)
    docs_requested = float(state.total_docs_requested)
    docs_cleared = float(state.total_docs_cleared)
    urgent_arrived = float(state.total_urgent_arrived)
    urgent_completed = float(state.total_urgent_completed)
    idle_days = float(state.total_idle_officer_days)
    capacity_days = float(state.total_capacity_days)
    escalations_used = float(state.total_escalations_used)
    escalations_wasted = float(state.total_wasted_escalations)
    gap = float(state.fairness_gap)

    completed_share = _safe_ratio(completed, arrived, 0.0)
    breached_share = _safe_ratio(breaches, arrived, 0.0)
    # With nothing requested / no urgent arrivals the ratio defaults to 1.0.
    cleared_share = _safe_ratio(docs_cleared, docs_requested, 1.0)
    urgent_share = _safe_ratio(urgent_completed, urgent_arrived, 1.0)
    # Denominators floored at 1.0: no escalations (or no capacity) yields a
    # share of 0 and therefore a score of 1.0.
    wasted_share = _safe_ratio(
        escalations_wasted, max(1.0, escalations_used), 0.0
    )
    idle_share = _safe_ratio(idle_days, max(1.0, capacity_days), 0.0)

    # Penalty-style shares are turned into scores by taking the complement.
    metrics: dict[str, float] = {}
    metrics["completion_rate"] = _b(completed_share)
    metrics["sla_compliance"] = _b(1.0 - breached_share)
    metrics["document_rework_quality"] = _b(cleared_share)
    metrics["urgent_served_rate"] = _b(urgent_share)
    metrics["fairness_score"] = _b(1.0 - gap)
    metrics["escalation_discipline"] = _b(1.0 - wasted_share)
    metrics["idle_efficiency"] = _b(1.0 - idle_share)
    metrics["fairness_gap"] = round(gap, 4)
    return metrics


def _build_result(
    state: EpisodeStateModel,
    score: float,
    grader_name: str,
    m: dict[str, float],
) -> GraderResult:
    """
    Combine the metric dict `m` and raw counters from `state` into a
    GraderResult.

    `score` is clamped to [0.0, 1.0] and the reported arrival count is
    floored at 0; every other value is passed through unchanged.
    """
    identity = {
        "task_id": state.task_id,
        "episode_id": state.episode_id,
        "grader_name": grader_name,
        "score": _b(score),
    }
    # Note the key rename: the metric dict says "sla_compliance" while the
    # result field is "sla_compliance_rate".
    rates = {
        "completion_rate": m["completion_rate"],
        "sla_compliance_rate": m["sla_compliance"],
        "idle_efficiency": m["idle_efficiency"],
        "document_rework_quality": m["document_rework_quality"],
        "urgent_served_rate": m["urgent_served_rate"],
        "fairness_score": m["fairness_score"],
        "escalation_discipline": m["escalation_discipline"],
        "fairness_gap": m["fairness_gap"],
    }
    counters = {
        "total_cases_arrived": max(0, state.total_arrived),
        "total_completed": state.total_completed,
        "total_sla_breached": state.total_sla_breaches,
        "total_rejected": state.total_rejected,
        "avg_waiting_days": state.avg_waiting_days,
    }
    return GraderResult(**identity, **rates, **counters)


# ─────────────────────────────────────────────────────────────────────────────
# TASK GRADERS
# ─────────────────────────────────────────────────────────────────────────────

def grade_easy(state: EpisodeStateModel) -> GraderResult:
    """
    Grader for district_backlog_easy.

    Rewards raw throughput and SLA hygiene under simple single-service load.

    Weights: completion(0.45) + SLA(0.35) + idle_efficiency(0.20)
    """
    metrics = _extract(state)
    completion_part = 0.45 * metrics["completion_rate"]
    sla_part = 0.35 * metrics["sla_compliance"]
    idle_part = 0.20 * metrics["idle_efficiency"]
    # Terms are added left to right so float rounding stays reproducible.
    weighted_total = completion_part + sla_part + idle_part
    return _build_result(state, weighted_total, "easy", metrics)


def grade_medium(state: EpisodeStateModel) -> GraderResult:
    """
    Grader for mixed_urgency_medium.

    Rewards throughput, SLA compliance, document quality and the
    prioritization of urgent cases.

    Weights: completion(0.35) + SLA(0.30) + doc_rework(0.20) + urgent(0.15)
    """
    metrics = _extract(state)
    completion_part = 0.35 * metrics["completion_rate"]
    sla_part = 0.30 * metrics["sla_compliance"]
    rework_part = 0.20 * metrics["document_rework_quality"]
    urgent_part = 0.15 * metrics["urgent_served_rate"]
    # Terms are added left to right so float rounding stays reproducible.
    weighted_total = completion_part + sla_part + rework_part + urgent_part
    return _build_result(state, weighted_total, "medium", metrics)


def grade_hard(state: EpisodeStateModel) -> GraderResult:
    """
    Grader for cross_department_hard.

    Rewards all-round excellence, including cross-service fairness and
    restrained escalation use under crisis conditions.

    Weights: completion(0.28) + SLA(0.24) + doc_rework(0.16)
             + fairness(0.16) + escalation_discipline(0.16)
    """
    metrics = _extract(state)
    completion_part = 0.28 * metrics["completion_rate"]
    sla_part = 0.24 * metrics["sla_compliance"]
    rework_part = 0.16 * metrics["document_rework_quality"]
    fairness_part = 0.16 * metrics["fairness_score"]
    escalation_part = 0.16 * metrics["escalation_discipline"]
    # Terms are added left to right so float rounding stays reproducible.
    weighted_total = (
        completion_part
        + sla_part
        + rework_part
        + fairness_part
        + escalation_part
    )
    return _build_result(state, weighted_total, "hard", metrics)


# ─────────────────────────────────────────────────────────────────────────────
# DISPATCHER
# ─────────────────────────────────────────────────────────────────────────────

# Maps each known task_id to the grader that scores it. Several task ids may
# share one grader: the "extreme" easy variant is scored by grade_easy.
_GRADER_MAP = {
    "district_backlog_easy": grade_easy,
    "district_backlog_easy_extreme": grade_easy,
    "mixed_urgency_medium": grade_medium,
    "cross_department_hard": grade_hard,
}


def grade_episode(state: EpisodeStateModel) -> GraderResult:
    """
    Score an episode with the grader registered for `state.task_id`.

    Task ids missing from _GRADER_MAP are scored by grade_hard (safe default
    for new tasks).
    """
    chosen_grader = _GRADER_MAP.get(state.task_id)
    if chosen_grader is None:
        # Unknown task id: fall back to the hard grader.
        chosen_grader = grade_hard
    return chosen_grader(state)