"""Core typed models.""" from __future__ import annotations from datetime import datetime from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field, field_validator from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment from app.common.normalization import clamp_reward class StrictBase(BaseModel): model_config = ConfigDict(extra="forbid") class Medication(StrictBase): drug: str dose_bucket: DoseBucket = DoseBucket.MEDIUM indication: Optional[str] = None class_name: Optional[str] = None requires_taper: bool = False class LabSummary(StrictBase): egfr: Optional[float] = None ast: Optional[float] = None alt: Optional[float] = None inr: Optional[float] = None glucose: Optional[float] = None class PatientProfile(StrictBase): patient_id: str age: int sex: str comorbidities: list[str] = Field(default_factory=list) medications: list[Medication] = Field(default_factory=list) labs: LabSummary = Field(default_factory=LabSummary) vitals: dict[str, float] = Field(default_factory=dict) specialist_conflicts: list[str] = Field(default_factory=list) prior_ade_history: list[str] = Field(default_factory=list) frailty_score: float = 0.3 adherence_estimate: float = 0.8 latent_confounders: dict[str, float] = Field(default_factory=dict) monitoring_gaps: list[str] = Field(default_factory=list) class CandidateAction(StrictBase): candidate_id: str mode: DecisionMode action_type: ActionType target_drug: Optional[str] = None replacement_drug: Optional[str] = None dose_bucket: DoseBucket = DoseBucket.NA taper_days: Optional[int] = None monitoring_plan: Optional[str] = None evidence_query: Optional[str] = None new_drug_name: Optional[str] = None candidate_components: list[str] = Field(default_factory=list) estimated_safety_delta: float = 0.0 burden_delta: float = 0.0 disease_stability_estimate: float = 0.0 uncertainty_score: float = 0.5 rationale_tags: list[str] = Field(default_factory=list) required_monitoring: list[str] = Field(default_factory=list) legality_precheck: bool = True class PolyGuardAction(StrictBase): mode: DecisionMode action_type: ActionType target_drug: Optional[str] = None replacement_drug: Optional[str] = None dose_bucket: DoseBucket = DoseBucket.NA taper_days: Optional[int] = None monitoring_plan: Optional[str] = None evidence_query: Optional[str] = None new_drug_name: Optional[str] = None candidate_components: list[str] = Field(default_factory=list) candidate_id: str confidence: float rationale_brief: str @field_validator("confidence") @classmethod def _valid_confidence(cls, value: float) -> float: return clamp_reward(value) class RewardBreakdown(StrictBase): format_compliance_score: float candidate_alignment_score: float legality_score: float safety_delta_score: float burden_improvement_score: float disease_stability_score: float dosing_quality_score: float abstention_quality_score: float efficiency_score: float process_fidelity_score: float explanation_grounding_score: float anti_cheat_score: float uncertainty_calibration_score: float primary_safety_legality: float = 0.5 primary_clinical_improvement: float = 0.5 primary_dosing_quality: float = 0.5 primary_process_integrity: float = 0.5 total_reward: float class SafetyReport(StrictBase): legal: bool violations: list[str] = Field(default_factory=list) severity: str = "none" recommended_fallback: Optional[ActionType] = None uncertainty_notes: list[str] = Field(default_factory=list) class UncertaintyReport(StrictBase): overall_uncertainty: float = 0.5 missing_data_flags: list[str] = Field(default_factory=list) abstention_recommended: bool = False class PolyGuardState(StrictBase): episode_id: str seed: int scenario_id: Optional[str] = None difficulty: Difficulty sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK step_count: int max_steps: int patient: PatientProfile active_mode: DecisionMode = DecisionMode.REGIMEN_OPT cumulative_reward: float = 0.0 unresolved_conflicts: list[str] = Field(default_factory=list) risk_summary: dict[str, float] = Field(default_factory=dict) burden_score: float = 0.5 precision_dosing_flags: list[str] = Field(default_factory=list) action_history: list[dict[str, Any]] = Field(default_factory=list) done: bool = False created_at: datetime = Field(default_factory=datetime.utcnow) class PolyGuardObservation(StrictBase): patient_summary: dict[str, Any] medication_table: list[dict[str, Any]] comorbidity_summary: list[str] organ_function_summary: dict[str, Any] labs_vitals_summary: dict[str, Any] graph_safety_summary: dict[str, Any] burden_score_summary: dict[str, Any] precision_dosing_flags: list[str] unresolved_conflicts: list[str] candidate_action_set: list[CandidateAction] step_budget_remaining: int action_history: list[dict[str, Any]] warning_summary: list[str] abstention_indicators: dict[str, Any] sub_environment: SubEnvironment deterministic_contract: dict[str, Any] = Field(default_factory=dict) class StepTrace(StrictBase): step: int observation_snapshot: PolyGuardObservation selected_action: Optional[PolyGuardAction] = None critic_output: dict[str, Any] = Field(default_factory=dict) reward_components: dict[str, float] = Field(default_factory=dict) transition_delta: dict[str, Any] = Field(default_factory=dict) uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport) failure_reasons: list[str] = Field(default_factory=list) timeout: bool = False