"""Reward shaping and regimen-risk computation.""" from __future__ import annotations from itertools import combinations from typing import Dict, List, Optional, Tuple from .config import ( INTERVENTION_COST, INVALID_ACTION_PENALTY, QUERY_COST, SEVERE_DDI_DISCOVERY_BONUS, TIMEOUT_PENALTY, ) from .data_loader import BeersCriterion, DDIRule, DrugMeta def compute_regimen_risk( current_drug_ids: List[str], patient_conditions: List[str], ddi_rules: Dict[Tuple[str, str], DDIRule], beers_criteria: List[BeersCriterion], drug_metadata: Dict[str, DrugMeta], ) -> float: """Compute an aggregate risk score for the current medication regimen. Returns a float clipped to [0.0, 1.0]. """ if not current_drug_ids: return 0.0 risk = 0.0 drug_set = set(current_drug_ids) # 1. DDI pairwise risk for a, b in combinations(sorted(drug_set), 2): key = (a, b) if a < b else (b, a) rule = ddi_rules.get(key) if rule is not None: risk += rule.base_risk_score # 2. Beers violations beers_weight = {"avoid": 0.25, "caution": 0.10, "dose_adjust": 0.08, "avoid_in_condition": 0.20} for bc in beers_criteria: if bc.drug_id not in drug_set: continue if bc.condition is None: risk += beers_weight.get(bc.criterion_type, 0.05) elif bc.condition in patient_conditions: risk += beers_weight.get(bc.criterion_type, 0.05) # 3. High-risk elderly drugs for did in drug_set: dm = drug_metadata.get(did) if dm and dm.is_high_risk_elderly: risk += 0.05 # Normalise by regimen size to keep score comparable across difficulties risk /= max(len(drug_set), 1) return min(max(risk, 0.0), 1.0) def compute_shaped_reward( previous_risk: float, new_risk: float, action_type: str, *, is_invalid: bool = False, is_timeout: bool = False, discovered_severe: bool = False, ) -> float: """Compute the step-level shaped reward.""" reward = 0.0 if is_invalid: return -INVALID_ACTION_PENALTY if is_timeout: return -TIMEOUT_PENALTY if action_type == "query_ddi": reward -= QUERY_COST if discovered_severe: reward += SEVERE_DDI_DISCOVERY_BONUS elif action_type == "propose_intervention": reward += (previous_risk - new_risk) reward -= INTERVENTION_COST # finish_review terminal bonus is added by the caller after grading return reward