Spaces:
Running
Running
| """Step-level reward computation for the IRT environment. | |
| Provides dense reward signal over the full trajectory: | |
| - Positive for relevant investigations, correct classifications, | |
| accurate diagnoses, and appropriate remediations. | |
| - Negative for irrelevant actions, wrong classifications, | |
| destructive remediations, and wasted steps. | |
| - Temporal degradation penalty for delayed response. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| from src.models import ( | |
| Action, | |
| ActionType, | |
| IncidentSeverity, | |
| Reward, | |
| ) | |
| from src.scenarios import Scenario | |
| def _normalize(value: float) -> float: | |
| """Clamp reward to [-1.0, 1.0].""" | |
| return max(-1.0, min(1.0, value)) | |
def _params(action: Action) -> Dict[str, Any]:
    """Return ``action.parameters``, or an empty dict when it is unset."""
    return action.parameters or {}


def _as_text(value: Any) -> str:
    """Return *value* if it is a string, otherwise an empty string.

    Free-text parameters come from the acting agent; a non-string value
    (``None``, a number, ...) is scored like missing text instead of
    crashing the reward computation.
    """
    return value if isinstance(value, str) else ""


def _score_investigate(
    action: Action,
    scenario: Scenario,
    already_investigated: List[str],
) -> Dict[str, float]:
    """Score an INVESTIGATE action by how useful its target is."""
    target = (action.target or "").strip()
    if target in already_investigated:
        return {"duplicate_investigation": -0.03}
    if target in scenario.relevant_services:
        return {"relevant_investigation": 0.06}
    if target in scenario.available_services:
        return {"irrelevant_investigation": -0.02}
    return {"invalid_target": -0.05}


def _score_classify(
    action: Action,
    scenario: Scenario,
    already_classified: Optional[IncidentSeverity],
) -> Dict[str, float]:
    """Score a CLASSIFY action; wrong answers cost more the further off they are."""
    if already_classified is not None:
        return {"duplicate_classify": -0.03}
    raw_severity = _params(action).get("severity", "")
    try:
        given = IncidentSeverity(raw_severity)
        if given == scenario.correct_severity:
            return {"correct_classification": 0.15}
        # Distance is measured in positions within the enum's definition order.
        levels = list(IncidentSeverity)
        distance = abs(
            levels.index(given) - levels.index(scenario.correct_severity)
        )
    except ValueError:
        # Raised for a value that is not a known severity (or, as before,
        # when a severity cannot be located in the enum listing).
        return {"invalid_severity": -0.08}
    return {"wrong_classification": -0.05 * distance}


def _score_diagnose(
    action: Action,
    scenario: Scenario,
    already_diagnosed: Optional[str],
) -> Dict[str, float]:
    """Score a DIAGNOSE action on two independent axes: service and root cause."""
    if already_diagnosed is not None:
        return {"duplicate_diagnosis": -0.03}

    scored: Dict[str, float] = {}

    # Service match: case-insensitive and whitespace-insensitive, in line with
    # how the other action types strip their targets.
    target_svc = (action.target or "").strip().lower()
    if target_svc == scenario.correct_root_cause_service.strip().lower():
        scored["correct_service"] = 0.10
    elif target_svc:
        scored["wrong_service"] = -0.05

    # Root cause match: any expected keyword appearing in the free text.
    root_cause_text = _as_text(_params(action).get("root_cause", "")).lower()
    keyword_hit = any(
        kw.lower() in root_cause_text
        for kw in scenario.correct_root_cause_keywords
    )
    if keyword_hit:
        scored["correct_root_cause"] = 0.15
    elif root_cause_text:
        scored["wrong_root_cause"] = -0.05
    return scored


def _score_remediate(
    action: Action,
    scenario: Scenario,
    already_remediated: List[str],
) -> Dict[str, float]:
    """Score a REMEDIATE action against the scenario's valid (action, service) pairs."""
    rem_action = _params(action).get("action", "")
    rem_service = (action.target or "").strip()
    # Earlier remediations are recorded as "<action>:<service>" keys.
    if f"{rem_action}:{rem_service}" in already_remediated:
        return {"duplicate_remediation": -0.03}
    is_valid = any(
        va.get("action") == rem_action and va.get("service") == rem_service
        for va in scenario.valid_remediation_actions
    )
    if is_valid:
        return {"correct_remediation": 0.12}
    return {"wrong_remediation": -0.08}


def _score_escalate(
    action: Action,
    scenario: Scenario,
    already_escalated: List[str],
) -> Dict[str, float]:
    """Score an ESCALATE action; team names are compared case-insensitively."""
    team = (action.target or "").strip().lower()
    if any(team == t.lower() for t in already_escalated):
        return {"duplicate_escalation": -0.02}
    if any(team == t.lower() for t in scenario.expected_escalation_teams):
        return {"correct_escalation": 0.05}
    return {"unnecessary_escalation": -0.02}


def _score_communicate(
    action: Action,
    already_communicated: List[str],
) -> Dict[str, float]:
    """Score a COMMUNICATE action by message length and how many were already sent."""
    message = _as_text(_params(action).get("message", ""))
    if len(message) < 10:
        return {"low_quality_communication": -0.02}
    # ``or ()`` keeps the original tolerance for a missing/None history.
    if len(already_communicated or ()) > 3:
        return {"excessive_communication": -0.01}
    return {"status_communication": 0.04}


def _score_reasoning(action: Action, scenario: Scenario) -> Dict[str, float]:
    """Return the small bonus for supplying reasoning longer than 20 characters.

    The bonus is larger when the reasoning mentions a relevant service or a
    root-cause keyword.
    """
    reasoning = action.reasoning
    if not reasoning or len(reasoning) <= 20:
        return {}
    reasoning_lower = reasoning.lower()
    mentions_relevant = any(
        svc.lower() in reasoning_lower for svc in scenario.relevant_services
    ) or any(
        kw.lower() in reasoning_lower
        for kw in scenario.correct_root_cause_keywords
    )
    if mentions_relevant:
        return {"reasoning_relevant": 0.02}
    return {"reasoning_provided": 0.005}


def compute_step_reward(
    action: Action,
    scenario: Scenario,
    step_number: int,
    already_investigated: List[str],
    already_classified: Optional[IncidentSeverity],
    already_diagnosed: Optional[str],
    already_remediated: List[str],
    already_escalated: List[str],
    already_communicated: List[str],
    actions_history: List[Dict[str, Any]],
) -> Reward:
    """Compute the reward for a single step.

    The reward is the sum of a temporal degradation penalty, an
    action-specific component (or components), and an optional reasoning
    bonus, clamped via ``_normalize``.

    Args:
        action: The action taken on this step.
        scenario: The scenario supplying the expected answers.
        step_number: Index of the current step; scales the degradation penalty.
        already_investigated: Targets investigated on earlier steps.
        already_classified: Severity chosen earlier, or ``None`` if none yet.
        already_diagnosed: Earlier diagnosis, or ``None`` if none yet.
        already_remediated: Earlier remediations as ``"<action>:<service>"``.
        already_escalated: Teams escalated to on earlier steps.
        already_communicated: Earlier communications.
        actions_history: Accepted for interface compatibility; not read here.

    Returns:
        A ``Reward`` holding the clamped total (rounded to 4 places), the
        per-component values (rounded to 4 places), and a ``"; "``-joined
        summary message.
    """
    # -- Temporal degradation: penalty grows linearly with the step number --
    components: Dict[str, float] = {
        "temporal_degradation": -scenario.degradation_per_step * step_number,
    }

    # -- Action-specific rewards --------------------------------------------
    kind = action.action_type
    if kind == ActionType.INVESTIGATE:
        components.update(
            _score_investigate(action, scenario, already_investigated)
        )
    elif kind == ActionType.CLASSIFY:
        components.update(_score_classify(action, scenario, already_classified))
    elif kind == ActionType.DIAGNOSE:
        components.update(_score_diagnose(action, scenario, already_diagnosed))
    elif kind == ActionType.REMEDIATE:
        components.update(_score_remediate(action, scenario, already_remediated))
    elif kind == ActionType.ESCALATE:
        components.update(_score_escalate(action, scenario, already_escalated))
    elif kind == ActionType.COMMUNICATE:
        components.update(_score_communicate(action, already_communicated))

    # -- Reasoning bonus (content-aware) ------------------------------------
    components.update(_score_reasoning(action, scenario))

    # Accumulate sequentially in insertion order so the floating-point result
    # matches a running total updated component by component.
    total = 0.0
    for value in components.values():
        total += value
    total = _normalize(total)

    message_parts = [f"{k}: {v:+.3f}" for k, v in components.items()]
    return Reward(
        value=round(total, 4),
        components={k: round(v, 4) for k, v in components.items()},
        message="; ".join(message_parts),
    )