File size: 4,481 Bytes
3eb9552
4b77608
3eb9552
 
4b77608
3eb9552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b77608
 
 
3eb9552
 
 
 
 
 
 
 
 
 
 
 
 
 
4b77608
 
 
3eb9552
 
 
 
 
 
 
4b77608
3eb9552
 
4b77608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eb9552
4b77608
3eb9552
 
4b77608
3eb9552
 
 
 
 
 
 
 
4b77608
3eb9552
 
 
 
4b77608
3eb9552
4b77608
3eb9552
4b77608
 
3eb9552
 
 
 
 
 
4b77608
 
 
 
 
 
 
 
 
3eb9552
 
4b77608
3eb9552
 
4b77608
3eb9552
 
4b77608
3eb9552
 
 
 
 
 
 
 
4b77608
3eb9552
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Task Graders for OpenEnv Email Triage

Implements agent graders for three difficulty levels (easy, medium, hard)
with scoring from 0.0 to 1.0 based on criteria like accuracy and critical safety.
"""

import numpy as np
from typing import Dict, Any, List, Tuple
from dataclasses import dataclass


@dataclass
class GradingCriteria:
    """One weighted scoring dimension (e.g. 'accuracy', 'critical_safety').

    Instances are created from config by the grader and mutated in place:
    ``score`` is overwritten each time scores are recomputed.
    """
    name: str           # criterion identifier used for lookups in score dicts
    weight: float       # relative contribution to the weighted final score
    score: float = 0.0  # most recently computed score; reset to 0.0 per episode


class TaskGrader:
    """Scores an email-triage episode against weighted grading criteria.

    Usage: feed episode statistics in via :meth:`update`, then read
    per-criterion scores (:meth:`compute_scores`), a weighted final score
    in [0.0, 1.0] (:meth:`get_final_score`), or a full report
    (:meth:`get_grade_report`). :meth:`reset` prepares for a new episode.
    """

    # Score deducted from 'critical_safety' for each critical failure
    # (floored at 0.0), matching the original per-failure 0.3 penalty.
    _CRITICAL_FAILURE_PENALTY = 0.3

    def __init__(self, config: Dict[str, Any]):
        """Build a grader from *config*.

        Args:
            config: May contain a ``'criteria'`` list of dicts with
                ``'name'`` and ``'weight'`` keys, and an optional
                ``'success_threshold'`` float (defaults to 0.7 in the
                report).
        """
        self.config = config
        self.criteria = []  # populated with GradingCriteria by _initialize_criteria
        self.episode_data = self._fresh_episode_data()
        self._initialize_criteria()

    @staticmethod
    def _fresh_episode_data() -> Dict[str, int]:
        """Return a zeroed episode-statistics dict.

        Single source of truth for the episode-data shape; previously this
        literal was duplicated in ``__init__`` and ``reset``.
        """
        return {
            'steps': 0,
            'correct_actions': 0,
            'incorrect_actions': 0,
            'critical_failures': 0,
        }

    def _initialize_criteria(self) -> None:
        """Instantiate GradingCriteria objects from the config's 'criteria' list."""
        for criterion_config in self.config.get('criteria', []):
            criterion = GradingCriteria(
                name=criterion_config['name'],
                weight=criterion_config['weight']
            )
            self.criteria.append(criterion)

    def reset(self) -> None:
        """Clear episode statistics and criterion scores for a new episode."""
        self.episode_data = self._fresh_episode_data()
        for criterion in self.criteria:
            criterion.score = 0.0

    def update(self, **kwargs) -> None:
        """Overwrite episode statistics with the given keyword values.

        Unknown keys are silently ignored — only the keys already present
        in ``episode_data`` are accepted.
        """
        for key, value in kwargs.items():
            if key in self.episode_data:
                self.episode_data[key] = value

    def compute_scores(self) -> Dict[str, float]:
        """Recompute and return the 'accuracy' and 'critical_safety' scores.

        Updates the matching GradingCriteria objects in place. A criterion
        absent from the config contributes 0.0 to the returned dict.
        """
        scores = {}

        # Accuracy: fraction of correct actions among all graded actions.
        acc_criterion = next((c for c in self.criteria if c.name == 'accuracy'), None)
        if acc_criterion:
            total_actions = self.episode_data['correct_actions'] + self.episode_data['incorrect_actions']
            if total_actions > 0:
                acc_criterion.score = float(self.episode_data['correct_actions']) / total_actions
            else:
                # No actions taken yet — treat as zero accuracy.
                acc_criterion.score = 0.0
        scores['accuracy'] = acc_criterion.score if acc_criterion else 0.0

        # Critical safety: perfect score with no failures, then a flat
        # penalty per failure, floored at 0.0.
        safety_criterion = next((c for c in self.criteria if c.name == 'critical_safety'), None)
        if safety_criterion:
            failures = self.episode_data['critical_failures']
            if failures == 0:
                safety_criterion.score = 1.0
            else:
                safety_criterion.score = max(0.0, 1.0 - (failures * self._CRITICAL_FAILURE_PENALTY))
        scores['critical_safety'] = safety_criterion.score if safety_criterion else 0.0

        return scores

    def _weighted_score(self) -> float:
        """Weighted mean of the *current* criterion scores, clamped to [0, 1].

        Does not recompute scores; callers must run compute_scores() first.
        Returns 0.0 when there are no (positive-weight) criteria.
        """
        total_weight = sum(c.weight for c in self.criteria)
        if total_weight <= 0:
            return 0.0
        weighted_sum = sum(c.score * c.weight for c in self.criteria)
        # Plain min/max instead of np.clip so the result is a builtin float
        # (np.clip returned a numpy scalar, which is not JSON-serializable
        # and contradicted the declared -> float return type).
        return max(0.0, min(1.0, weighted_sum / total_weight))

    def get_final_score(self) -> float:
        """Recompute criterion scores and return the clamped weighted mean."""
        self.compute_scores()
        return self._weighted_score()

    def get_grade_report(self) -> Dict[str, Any]:
        """Return the full grading report for the current episode.

        Keys: final_score, success_threshold, passed, criteria_scores,
        episode_data (a copy), feedback.
        """
        scores = self.compute_scores()
        # Reuse the scores computed above instead of calling
        # get_final_score(), which would recompute them a second time.
        final_score = self._weighted_score()
        threshold = self.config.get('success_threshold', 0.7)

        return {
            'final_score': final_score,
            'success_threshold': threshold,
            'passed': final_score >= threshold,
            'criteria_scores': {c.name: c.score for c in self.criteria},
            'episode_data': self.episode_data.copy(),
            'feedback': self._generate_feedback(scores),
        }

    def _generate_feedback(self, scores: Dict[str, float]) -> str:
        """Build a short human-readable summary from the score dict."""
        feedback = []
        if scores.get('accuracy', 0) < 0.7:
            feedback.append("Triage Accuracy needs improvement.")
        else:
            feedback.append("Good triage accuracy.")

        if scores.get('critical_safety', 1.0) < 1.0:
            feedback.append("Critical safety failures occurred (e.g. ignored urgent email).")
        return " | ".join(feedback)

class EasyGrader(TaskGrader):
    """Grader for easy-difficulty tasks; no overrides — base scoring applies as-is."""

class MediumGrader(TaskGrader):
    """Grader for medium-difficulty tasks; no overrides — base scoring applies as-is."""

class HardGrader(TaskGrader):
    """Grader for hard-difficulty tasks; no overrides — base scoring applies as-is."""

def create_grader(task_level: str, config: Dict[str, Any]) -> TaskGrader:
    """Instantiate the grader class matching *task_level*.

    Args:
        task_level: One of 'easy', 'medium', or 'hard'.
        config: Grader configuration passed through to the constructor.

    Raises:
        ValueError: If *task_level* is not a recognized level.
    """
    level_to_grader = {
        'easy': EasyGrader,
        'medium': MediumGrader,
        'hard': HardGrader,
    }
    grader_cls = level_to_grader.get(task_level)
    if grader_cls is None:
        raise ValueError(f"Unknown task level: {task_level}")
    return grader_cls(config)