Spaces:
Sleeping
Sleeping
| """ | |
| graders.py β Reward shaping logic for MediRoute OpenEnv. | |
| Each action is evaluated against the ground-truth task expectations. | |
| Rewards are incremental per-step values; the environment accumulates and | |
| clamps the episode total to [0.0, 1.0]. | |
| Reward table | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| Correct severity classification (analyze_symptoms) +0.30 | |
| Correct specialist recommendation +0.30 | |
| Correct hospital selection +0.20 | |
| Successful appointment booking (non-emergency) +0.20 | |
| Correct emergency escalation (call_ambulance) +0.50 | |
| Wrong department / specialist -0.20 | |
| Unnecessary loop / duplicate action -0.30 | |
| Calling ambulance on non-emergency -0.30 | |
| Booking appointment in emergency case -0.30 | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List | |
| from models import Action | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Internal helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| def _is_duplicate(action: Action, previous_actions: List[str]) -> bool: | |
| return action.as_key() in previous_actions | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Public API | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| def grade_step( | |
| task: Dict[str, Any], | |
| action: Action, | |
| previous_actions: List[str], | |
| ) -> float: | |
| """ | |
| Compute the incremental reward for a single action taken in *task*. | |
| Args: | |
| task: The full task dict as returned by tasks.get_task(). | |
| action: The Action the agent wants to execute. | |
| previous_actions: Actions already taken this episode (as 'type:target' strings). | |
| Returns: | |
| A float reward value (can be negative; clamping is done in the environment). | |
| """ | |
| # ββ Duplicate penalty ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if _is_duplicate(action, previous_actions): | |
| return -0.30 | |
| action_type = action.action_type | |
| target = (action.target or "").strip() | |
| # ββ analyze_symptoms βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if action_type == "analyze_symptoms": | |
| if target.lower() == task["expected_severity"].lower(): | |
| return 0.30 | |
| else: | |
| return -0.10 # Incorrect severity assessment | |
| # ββ request_more_info ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| elif action_type == "request_more_info": | |
| # Neutral in most cases; mild reward only if no prior analysis done | |
| analyzed = any(a.startswith("analyze_symptoms") for a in previous_actions) | |
| return 0.05 if not analyzed else -0.05 | |
| # ββ recommend_specialist βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| elif action_type == "recommend_specialist": | |
| if target == task["expected_specialist"]: | |
| return 0.30 | |
| else: | |
| return -0.20 # Wrong department | |
| # ββ select_hospital ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| elif action_type == "select_hospital": | |
| if target == task["expected_hospital"]: | |
| return 0.20 | |
| elif target in task["nearby_hospitals"]: | |
| return 0.05 # Nearby but not optimal | |
| else: | |
| return -0.10 # Unknown / unreachable hospital | |
| # ββ book_appointment βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| elif action_type == "book_appointment": | |
| if task["requires_ambulance"]: | |
| # Trying to book appointment in a life-threatening emergency is wrong | |
| return -0.30 | |
| return 0.20 | |
| # ββ call_ambulance βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| elif action_type == "call_ambulance": | |
| if task["requires_ambulance"]: | |
| return 0.50 # Correct emergency escalation | |
| else: | |
| return -0.30 # Unnecessary ambulance dispatch | |
| # ββ provide_temp_guidance βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| elif action_type == "provide_temp_guidance": | |
| # Acceptable as a closing action for non-emergencies | |
| if not task["requires_ambulance"]: | |
| return 0.10 | |
| else: | |
| return -0.10 # Not enough for a critical patient | |
| # ββ Unknown action ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| return -0.10 | |
| def grade_episode( | |
| task: Dict[str, Any], | |
| all_actions: List[str], | |
| final_total_reward: float, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Produce a final episode summary / score report. | |
| Args: | |
| task: Task dict. | |
| all_actions: Full list of action keys taken during the episode. | |
| final_total_reward: Accumulated clamped reward from the environment. | |
| Returns: | |
| A dict with score, pass/fail, and diagnostic breakdown. | |
| """ | |
| score = round(final_total_reward, 4) | |
| passed = score >= 0.5 | |
| breakdown = { | |
| "severity_classified": any( | |
| a.startswith(f"analyze_symptoms:{task['expected_severity']}") | |
| for a in all_actions | |
| ), | |
| "correct_specialist": any( | |
| a.startswith(f"recommend_specialist:{task['expected_specialist']}") | |
| for a in all_actions | |
| ), | |
| "correct_hospital": any( | |
| a.startswith(f"select_hospital:{task['expected_hospital']}") | |
| for a in all_actions | |
| ), | |
| "ambulance_called": any(a.startswith("call_ambulance") for a in all_actions), | |
| "appointment_booked": any(a.startswith("book_appointment") for a in all_actions), | |
| } | |
| return { | |
| "score": score, | |
| "passed": passed, | |
| "difficulty": task["difficulty"], | |
| "total_steps": len(all_actions), | |
| "breakdown": breakdown, | |
| } | |