"""
Step-level reward computation for the IRT environment.

Provides dense reward signal over the full trajectory:
  - Positive for relevant investigations, correct classifications,
    accurate diagnoses, and appropriate remediations.
  - Negative for irrelevant actions, wrong classifications,
    destructive remediations, and wasted steps.
  - Temporal degradation penalty for delayed response.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from src.models import (
    Action,
    ActionType,
    IncidentSeverity,
    Reward,
)
from src.scenarios import Scenario


def _normalize(value: float) -> float:
    """Clamp reward to [-1.0, 1.0]."""
    return max(-1.0, min(1.0, value))


def _investigate_components(
    action: Action,
    scenario: Scenario,
    already_investigated: List[str],
) -> Dict[str, float]:
    """Score an INVESTIGATE action by how useful its target service is."""
    target = (action.target or "").strip()
    if target in already_investigated:
        return {"duplicate_investigation": -0.03}
    if target in scenario.relevant_services:
        return {"relevant_investigation": 0.06}
    if target in scenario.available_services:
        return {"irrelevant_investigation": -0.02}
    return {"invalid_target": -0.05}


def _classify_components(
    params: Dict[str, Any],
    scenario: Scenario,
    already_classified: Optional[IncidentSeverity],
) -> Dict[str, float]:
    """Score a CLASSIFY action; the penalty grows with the severity distance."""
    if already_classified is not None:
        return {"duplicate_classify": -0.03}
    try:
        given = IncidentSeverity(params.get("severity", ""))
        if given == scenario.correct_severity:
            return {"correct_classification": 0.15}
        # Distance is measured in enum declaration order.
        ordered = list(IncidentSeverity)
        diff = abs(
            ordered.index(given) - ordered.index(scenario.correct_severity)
        )
        return {"wrong_classification": -0.05 * diff}
    except ValueError:
        # Raised when the supplied value is not a member of IncidentSeverity.
        return {"invalid_severity": -0.08}


def _diagnose_components(
    action: Action,
    params: Dict[str, Any],
    scenario: Scenario,
    already_diagnosed: Optional[str],
) -> Dict[str, float]:
    """Score a DIAGNOSE action on two axes: target service and root cause."""
    if already_diagnosed is not None:
        return {"duplicate_diagnosis": -0.03}

    components: Dict[str, float] = {}
    # A null root cause is treated the same as a missing one.
    root_cause_text = str(params.get("root_cause") or "").lower()
    # Strip like the other action types do, so stray whitespace around an
    # otherwise correct service name is not penalised as a wrong service.
    target_svc = (action.target or "").strip().lower()

    # Check service match
    if target_svc == scenario.correct_root_cause_service.lower():
        components["correct_service"] = 0.10
    elif target_svc:
        components["wrong_service"] = -0.05

    # Check root cause keywords
    matched = any(
        kw.lower() in root_cause_text
        for kw in scenario.correct_root_cause_keywords
    )
    if matched:
        components["correct_root_cause"] = 0.15
    elif root_cause_text:
        components["wrong_root_cause"] = -0.05
    return components


def _remediate_components(
    action: Action,
    params: Dict[str, Any],
    scenario: Scenario,
    already_remediated: List[str],
) -> Dict[str, float]:
    """Score a REMEDIATE action against the scenario's valid remediations."""
    rem_action = params.get("action", "")
    rem_service = (action.target or "").strip()
    # Previously applied remediations are tracked as "<action>:<service>".
    if f"{rem_action}:{rem_service}" in already_remediated:
        return {"duplicate_remediation": -0.03}
    valid = any(
        va.get("action") == rem_action and va.get("service") == rem_service
        for va in scenario.valid_remediation_actions
    )
    if valid:
        return {"correct_remediation": 0.12}
    return {"wrong_remediation": -0.08}


def _escalate_components(
    action: Action,
    scenario: Scenario,
    already_escalated: List[str],
) -> Dict[str, float]:
    """Score an ESCALATE action; team names are compared case-insensitively."""
    team = (action.target or "").strip().lower()
    if team in [t.lower() for t in already_escalated]:
        return {"duplicate_escalation": -0.02}
    if team in [t.lower() for t in scenario.expected_escalation_teams]:
        return {"correct_escalation": 0.05}
    return {"unnecessary_escalation": -0.02}


def _communicate_components(
    params: Dict[str, Any],
    already_communicated: List[str],
) -> Dict[str, float]:
    """Score a COMMUNICATE action by message length and prior message count."""
    # A null message is treated the same as a missing one.
    message = params.get("message") or ""
    if len(message) < 10:
        return {"low_quality_communication": -0.02}
    if already_communicated and len(already_communicated) > 3:
        return {"excessive_communication": -0.01}
    return {"status_communication": 0.04}


def _reasoning_components(
    action: Action,
    scenario: Scenario,
) -> Dict[str, float]:
    """Return the reasoning bonus, larger when the reasoning is on-topic."""
    reasoning = action.reasoning
    if not reasoning or len(reasoning) <= 20:
        return {}
    reasoning_lower = reasoning.lower()
    # Check if reasoning references any relevant service or root-cause keyword
    mentions_relevant = any(
        svc.lower() in reasoning_lower for svc in scenario.relevant_services
    ) or any(
        kw.lower() in reasoning_lower
        for kw in scenario.correct_root_cause_keywords
    )
    if mentions_relevant:
        return {"reasoning_relevant": 0.02}
    return {"reasoning_provided": 0.005}


def compute_step_reward(
    action: Action,
    scenario: Scenario,
    step_number: int,
    already_investigated: List[str],
    already_classified: Optional[IncidentSeverity],
    already_diagnosed: Optional[str],
    already_remediated: List[str],
    already_escalated: List[str],
    already_communicated: List[str],
    actions_history: List[Dict[str, Any]],
) -> Reward:
    """Compute the reward for a single step.

    The reward is the sum of named components: a temporal degradation
    penalty, the component(s) for the action's type, and an optional
    reasoning bonus. The sum is clamped to [-1.0, 1.0].

    Args:
        action: The action taken this step.
        scenario: The scenario the action is scored against.
        step_number: Multiplier applied to ``scenario.degradation_per_step``.
        already_investigated: Targets investigated on earlier steps.
        already_classified: Earlier classification, or ``None`` if none.
        already_diagnosed: Earlier diagnosis, or ``None`` if none.
        already_remediated: Earlier remediations as ``"<action>:<service>"``.
        already_escalated: Teams escalated to on earlier steps.
        already_communicated: Earlier communications; more than three
            turns a further message into a small penalty.
        actions_history: Accepted for interface compatibility; not read here.

    Returns:
        A ``Reward`` holding the clamped total and each component (all
        rounded to 4 decimals) plus a ``"name: +0.000; ..."`` summary message.
    """
    components: Dict[str, float] = {}

    # -- Temporal degradation -----------------------------------------------
    components["temporal_degradation"] = (
        -scenario.degradation_per_step * step_number
    )

    # -- Action-specific rewards --------------------------------------------
    # Tolerate an action whose parameters are None instead of a dict.
    params: Dict[str, Any] = action.parameters or {}
    kind = action.action_type
    if kind == ActionType.INVESTIGATE:
        components.update(
            _investigate_components(action, scenario, already_investigated)
        )
    elif kind == ActionType.CLASSIFY:
        components.update(
            _classify_components(params, scenario, already_classified)
        )
    elif kind == ActionType.DIAGNOSE:
        components.update(
            _diagnose_components(action, params, scenario, already_diagnosed)
        )
    elif kind == ActionType.REMEDIATE:
        components.update(
            _remediate_components(action, params, scenario, already_remediated)
        )
    elif kind == ActionType.ESCALATE:
        components.update(
            _escalate_components(action, scenario, already_escalated)
        )
    elif kind == ActionType.COMMUNICATE:
        components.update(
            _communicate_components(params, already_communicated)
        )

    # -- Reasoning bonus (content-aware: credit for mentioning relevant services)
    components.update(_reasoning_components(action, scenario))

    # Accumulate in insertion order so the float total is reproducible.
    total = 0.0
    for value in components.values():
        total += value
    total = _normalize(total)

    message_parts = [f"{k}: {v:+.3f}" for k, v in components.items()]

    return Reward(
        value=round(total, 4),
        components={k: round(v, 4) for k, v in components.items()},
        message="; ".join(message_parts),
    )