| """ | |
| Task Graders for OpenEnv Email Triage | |
| Implements agent graders for three difficulty levels (easy, medium, hard) | |
| with scoring from 0.0 to 1.0 based on criteria like accuracy and critical safety. | |
| """ | |
import numpy as np
from typing import Dict, Any
from dataclasses import dataclass

@dataclass
class GradingCriteria:
    """A single weighted grading criterion; score is filled in per episode."""
    name: str
    weight: float
    score: float = 0.0

class TaskGrader:
    """Scores an episode against weighted criteria defined in the config."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.criteria = []
        self._initialize_criteria()
        self.episode_data = {}
        self.reset()  # populates episode_data and zeroes criterion scores

    def _initialize_criteria(self) -> None:
        for criterion_config in self.config.get('criteria', []):
            criterion = GradingCriteria(
                name=criterion_config['name'],
                weight=criterion_config['weight'],
            )
            self.criteria.append(criterion)

    def reset(self) -> None:
        """Clear episode statistics and criterion scores for a new episode."""
        self.episode_data = {
            'steps': 0,
            'correct_actions': 0,
            'incorrect_actions': 0,
            'critical_failures': 0,
        }
        for criterion in self.criteria:
            criterion.score = 0.0

    def update(self, **kwargs) -> None:
        """Overwrite known episode statistics; unrecognized keys are ignored."""
        for key, value in kwargs.items():
            if key in self.episode_data:
                self.episode_data[key] = value

    def compute_scores(self) -> Dict[str, float]:
        """Recompute per-criterion scores; only configured criteria are reported."""
        scores = {}
        # Accuracy: fraction of correct actions among all graded actions.
        acc_criterion = next((c for c in self.criteria if c.name == 'accuracy'), None)
        if acc_criterion:
            total_actions = (self.episode_data['correct_actions']
                             + self.episode_data['incorrect_actions'])
            acc_criterion.score = (
                self.episode_data['correct_actions'] / total_actions
                if total_actions > 0 else 0.0
            )
            scores['accuracy'] = acc_criterion.score
        # Critical safety: perfect only with zero failures; each failure
        # deducts 0.3, floored at 0.0.
        safety_criterion = next((c for c in self.criteria if c.name == 'critical_safety'), None)
        if safety_criterion:
            failures = self.episode_data['critical_failures']
            safety_criterion.score = max(0.0, 1.0 - failures * 0.3)
            scores['critical_safety'] = safety_criterion.score
        return scores

    def get_final_score(self) -> float:
        """Weighted average of criterion scores, clipped to [0.0, 1.0]."""
        self.compute_scores()  # refresh criterion scores before aggregating
        total_weight = sum(c.weight for c in self.criteria)
        weighted_sum = sum(c.score * c.weight for c in self.criteria)
        final_score = weighted_sum / total_weight if total_weight > 0 else 0.0
        return float(np.clip(final_score, 0.0, 1.0))
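
    # Worked example of the weighted average above (illustrative numbers, not
    # taken from any real task config): criteria accuracy (weight 0.6, score
    # 0.8) and critical_safety (weight 0.4, score 1.0) give
    # (0.8 * 0.6 + 1.0 * 0.4) / (0.6 + 0.4) = 0.88.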

    def get_grade_report(self) -> Dict[str, Any]:
        scores = self.compute_scores()
        final_score = self.get_final_score()
        threshold = self.config.get('success_threshold', 0.7)
        return {
            'final_score': final_score,
            'success_threshold': threshold,
            'passed': final_score >= threshold,
            'criteria_scores': {c.name: c.score for c in self.criteria},
            'episode_data': self.episode_data.copy(),
            'feedback': self._generate_feedback(scores),
        }

    def _generate_feedback(self, scores: Dict[str, float]) -> str:
        feedback = []
        if scores.get('accuracy', 0.0) < 0.7:
            feedback.append("Triage accuracy needs improvement.")
        else:
            feedback.append("Good triage accuracy.")
        if scores.get('critical_safety', 1.0) < 1.0:
            feedback.append("Critical safety failures occurred (e.g., an ignored urgent email).")
        return " | ".join(feedback)

class EasyGrader(TaskGrader):
    """Grader for easy tasks; criteria and weights come from the config."""


class MediumGrader(TaskGrader):
    """Grader for medium tasks; criteria and weights come from the config."""


class HardGrader(TaskGrader):
    """Grader for hard tasks; criteria and weights come from the config."""

def create_grader(task_level: str, config: Dict[str, Any]) -> TaskGrader:
    """Factory mapping a difficulty level to the corresponding grader class."""
    graders = {
        'easy': EasyGrader,
        'medium': MediumGrader,
        'hard': HardGrader,
    }
    if task_level not in graders:
        raise ValueError(f"Unknown task level: {task_level}")
    return graders[task_level](config)
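
# Minimal usage sketch (illustrative only, not part of the module API): the
# criteria names 'accuracy' and 'critical_safety' are the ones compute_scores
# understands; the weights and episode counts below are made-up values.
if __name__ == "__main__":
    config = {
        'criteria': [
            {'name': 'accuracy', 'weight': 0.6},
            {'name': 'critical_safety', 'weight': 0.4},
        ],
        'success_threshold': 0.7,
    }
    grader = create_grader('easy', config)
    grader.update(steps=10, correct_actions=8, incorrect_actions=2,
                  critical_failures=1)
    report = grader.get_grade_report()
    # accuracy = 8/10 = 0.8; critical_safety = 1.0 - 1 * 0.3 = 0.7
    # final = 0.8 * 0.6 + 0.7 * 0.4 = 0.76 -> passes the 0.7 threshold
    print(f"score={report['final_score']:.2f} passed={report['passed']}")
    print(report['feedback'])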