Spaces:
Sleeping
Sleeping
| """Reward shaping and regimen-risk computation.""" | |
| from __future__ import annotations | |
| from itertools import combinations | |
| from typing import Dict, List, Optional, Tuple | |
| from .config import ( | |
| INTERVENTION_COST, | |
| INVALID_ACTION_PENALTY, | |
| MODERATE_DDI_DISCOVERY_BONUS, | |
| QUERY_COST, | |
| SEVERE_DDI_DISCOVERY_BONUS, | |
| TIMEOUT_PENALTY, | |
| ) | |
| from .data_loader import BeersCriterion, DDIRule, DrugMeta | |
def compute_regimen_risk(
    current_drug_ids: List[str],
    patient_conditions: List[str],
    ddi_rules: Dict[Tuple[str, str], DDIRule],
    beers_criteria: List[BeersCriterion],
    drug_metadata: Dict[str, DrugMeta],
) -> float:
    """Score the overall risk of a medication regimen.

    Three contributions are accumulated over the de-duplicated drug ids:

    1. the ``base_risk_score`` of every DDI rule found for a pair of drugs
       in the regimen (rules are looked up under the ``(smaller, larger)``
       ordering of the two ids);
    2. a fixed weight for every Beers criterion whose drug is in the regimen
       and whose condition is either ``None`` or one of the patient's
       conditions (criterion types without a weight count as 0.05);
    3. 0.05 for every drug whose metadata flags it as high-risk in the
       elderly.

    The sum is divided by the number of distinct drugs and clipped.

    Args:
        current_drug_ids: Drug ids in the regimen; duplicates are ignored.
        patient_conditions: Conditions used to match conditional Beers
            criteria.
        ddi_rules: Interaction rules keyed by an ordered pair of drug ids.
        beers_criteria: Beers criteria to check against the regimen.
        drug_metadata: Per-drug metadata; drugs without an entry are skipped
            in step 3.

    Returns:
        The risk score clipped to ``[0.0, 1.0]``; ``0.0`` for an empty
        regimen.
    """
    if not current_drug_ids:
        return 0.0

    regimen = set(current_drug_ids)
    total = 0.0

    # 1. Pairwise interactions. combinations() over the sorted ids already
    #    yields each pair as (smaller, larger), which is the lookup key.
    for pair in combinations(sorted(regimen), 2):
        matched_rule = ddi_rules.get(pair)
        if matched_rule is not None:
            total += matched_rule.base_risk_score

    # 2. Beers criteria, weighted by clinical severity.
    severity_weights = {
        "avoid": 0.30,
        "caution": 0.12,
        "dose_adjust": 0.10,
        "avoid_in_condition": 0.25,
    }
    for criterion in beers_criteria:
        if criterion.drug_id not in regimen:
            continue
        # An unconditional criterion always applies; a conditional one only
        # when the patient actually has that condition.
        applies = (
            criterion.condition is None
            or criterion.condition in patient_conditions
        )
        if applies:
            total += severity_weights.get(criterion.criterion_type, 0.05)

    # 3. Flat surcharge per drug flagged as high-risk for elderly patients.
    for drug_id in regimen:
        meta = drug_metadata.get(drug_id)
        if meta and meta.is_high_risk_elderly:
            total += 0.05

    # Normalise by regimen size so scores stay comparable across difficulties.
    total /= max(len(regimen), 1)

    # Clip to the documented [0.0, 1.0] range.
    if total < 0.0:
        return 0.0
    if total > 1.0:
        return 1.0
    return total
def compute_shaped_reward(
    previous_risk: float,
    new_risk: float,
    action_type: str,
    *,
    is_invalid: bool = False,
    is_timeout: bool = False,
    discovered_severe: bool = False,
    discovered_moderate: bool = False,
) -> float:
    """Return the shaped reward for a single environment step.

    The raw reward is chosen by the first matching case:

    * invalid action: ``-INVALID_ACTION_PENALTY``
    * timeout: ``-TIMEOUT_PENALTY``
    * ``"query_ddi"``: ``-QUERY_COST`` plus the severe discovery bonus, or
      else the moderate one, when the corresponding flag is set
    * ``"propose_intervention"``: the risk reduction
      ``previous_risk - new_risk`` minus ``INTERVENTION_COST``
    * anything else (e.g. ``finish_review``): ``0.0``; the terminal bonus
      is added by the caller after grading.

    The raw value is then clamped to ``[0.001, 0.999]``.

    NOTE(review): because of the clamp, every raw reward at or below zero is
    returned as ``0.001``, so the penalty branches are indistinguishable
    from a neutral step if the penalty constants are positive. Confirm this
    is intended.

    Args:
        previous_risk: Regimen risk before the action.
        new_risk: Regimen risk after the action.
        action_type: Name of the action taken this step.
        is_invalid: The action was rejected as invalid.
        is_timeout: The episode ran out of steps.
        discovered_severe: A severe DDI was found by this query.
        discovered_moderate: A moderate DDI was found by this query.

    Returns:
        The clamped step reward.
    """
    lower_bound, upper_bound = 0.001, 0.999

    if is_invalid:
        raw = -INVALID_ACTION_PENALTY
    elif is_timeout:
        raw = -TIMEOUT_PENALTY
    elif action_type == "query_ddi":
        # Every query costs something; a discovery offsets it. The severe
        # bonus takes precedence when both flags are set.
        raw = 0.0 - QUERY_COST
        if discovered_severe:
            raw += SEVERE_DDI_DISCOVERY_BONUS
        elif discovered_moderate:
            raw += MODERATE_DDI_DISCOVERY_BONUS
    elif action_type == "propose_intervention":
        risk_drop = previous_risk - new_risk
        raw = risk_drop - INTERVENTION_COST
    else:
        # finish_review terminal bonus is added by the caller after grading.
        raw = 0.0

    # Clamp all rewards to the (0.001, 0.999) range.
    capped = min(upper_bound, raw)
    return max(lower_bound, capped)