| """Core typed models.""" |
|
|
| from __future__ import annotations |
|
|
| from datetime import datetime |
| from typing import Any, Optional |
|
|
| from pydantic import BaseModel, ConfigDict, Field, field_validator |
|
|
| from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment |
| from app.common.normalization import clamp_reward |
|
|
|
|
class StrictBase(BaseModel):
    """Shared pydantic base class for the models in this module.

    Configured with ``extra="forbid"``: input that carries a key which is
    not a declared field fails validation instead of being ignored.
    """

    model_config = ConfigDict(extra="forbid")
|
|
|
|
class Medication(StrictBase):
    """A single medication entry; only ``drug`` has to be supplied."""

    drug: str

    # Falls back to the MEDIUM bucket when no dose bucket is given.
    dose_bucket: DoseBucket = DoseBucket.MEDIUM

    # Optional descriptive attributes; both stay None when unknown.
    indication: Optional[str] = None
    class_name: Optional[str] = None

    requires_taper: bool = False
|
|
|
|
class LabSummary(StrictBase):
    """Optional numeric lab values.

    Every field defaults to ``None`` so a summary can be built with no
    arguments. Units are not fixed by this model -- confirm with the
    producer of the data.
    """

    egfr: Optional[float] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    inr: Optional[float] = None
    glucose: Optional[float] = None
| glucose: Optional[float] = None |
|
|
|
|
class PatientProfile(StrictBase):
    """Patient record: identity, clinical lists, labs/vitals and estimates.

    ``patient_id``, ``age`` and ``sex`` are required. Every other field has
    a default; containers are created per instance via ``default_factory``.
    """

    # --- required identity fields ---
    patient_id: str
    age: int
    sex: str

    # --- clinical picture (each container starts empty) ---
    comorbidities: list[str] = Field(default_factory=list)
    medications: list[Medication] = Field(default_factory=list)
    labs: LabSummary = Field(default_factory=LabSummary)
    vitals: dict[str, float] = Field(default_factory=dict)
    specialist_conflicts: list[str] = Field(default_factory=list)
    prior_ade_history: list[str] = Field(default_factory=list)

    # --- scalar estimates; no range is enforced by this model ---
    frailty_score: float = 0.3
    adherence_estimate: float = 0.8

    latent_confounders: dict[str, float] = Field(default_factory=dict)
    monitoring_gaps: list[str] = Field(default_factory=list)
|
|
|
|
class CandidateAction(StrictBase):
    """A candidate action together with its numeric estimates.

    ``candidate_id``, ``mode`` and ``action_type`` are required; the
    remaining fields are optional or carry defaults.
    """

    # --- required ---
    candidate_id: str
    mode: DecisionMode
    action_type: ActionType

    # --- optional action parameters ---
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- numeric estimates; no range is enforced by this model ---
    estimated_safety_delta: float = 0.0
    burden_delta: float = 0.0
    disease_stability_estimate: float = 0.0
    uncertainty_score: float = 0.5

    # --- annotations ---
    rationale_tags: list[str] = Field(default_factory=list)
    required_monitoring: list[str] = Field(default_factory=list)
    legality_precheck: bool = True
|
|
|
|
class PolyGuardAction(StrictBase):
    """A selected action with confidence and a brief rationale.

    ``mode``, ``action_type``, ``candidate_id``, ``confidence`` and
    ``rationale_brief`` are required. ``confidence`` is passed through
    ``clamp_reward`` during validation.
    """

    # --- what to do ---
    mode: DecisionMode
    action_type: ActionType

    # --- optional action parameters ---
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- required selection metadata ---
    candidate_id: str
    confidence: float
    rationale_brief: str

    @field_validator("confidence")
    @classmethod
    def _valid_confidence(cls, raw_confidence: float) -> float:
        """Return ``raw_confidence`` after normalisation by ``clamp_reward``.

        The bounds applied are whatever ``clamp_reward`` implements; they
        are defined in ``app.common.normalization``, not in this module.
        """
        clamped = clamp_reward(raw_confidence)
        return clamped
|
|
|
|
class RewardBreakdown(StrictBase):
    """Per-component reward scores plus the overall ``total_reward``.

    The thirteen ``*_score`` fields and ``total_reward`` are required. The
    four ``primary_*`` fields default to 0.5 when omitted.
    """

    # --- required component scores ---
    format_compliance_score: float
    candidate_alignment_score: float
    legality_score: float
    safety_delta_score: float
    burden_improvement_score: float
    disease_stability_score: float
    dosing_quality_score: float
    abstention_quality_score: float
    efficiency_score: float
    process_fidelity_score: float
    explanation_grounding_score: float
    anti_cheat_score: float
    uncertainty_calibration_score: float

    # --- "primary" values; each falls back to 0.5 ---
    primary_safety_legality: float = 0.5
    primary_clinical_improvement: float = 0.5
    primary_dosing_quality: float = 0.5
    primary_process_integrity: float = 0.5

    # --- required overall value ---
    total_reward: float
|
|
|
|
class SafetyReport(StrictBase):
    """Legality verdict with violations and an optional fallback action.

    Only ``legal`` is required.
    """

    legal: bool
    violations: list[str] = Field(default_factory=list)

    # Free-form string; the model does not restrict the allowed values.
    severity: str = "none"

    recommended_fallback: Optional[ActionType] = None
    uncertainty_notes: list[str] = Field(default_factory=list)
|
|
|
|
class UncertaintyReport(StrictBase):
    """Uncertainty figure, missing-data flags and an abstention flag.

    All fields have defaults, so the report can be built with no arguments.
    """

    overall_uncertainty: float = 0.5
    missing_data_flags: list[str] = Field(default_factory=list)
    abstention_recommended: bool = False
|
|
|
|
def _naive_utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Stands in for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The tzinfo is stripped on purpose so the value has the
    same shape (naive, UTC wall-clock) that ``utcnow()`` produced.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class PolyGuardState(StrictBase):
    """Episode state: identifiers, step counters, patient and running totals.

    ``episode_id``, ``seed``, ``difficulty``, ``step_count``, ``max_steps``
    and ``patient`` are required; every other field has a default.
    """

    # --- episode identity ---
    episode_id: str
    seed: int
    scenario_id: Optional[str] = None
    difficulty: Difficulty
    sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK

    # --- step counters ---
    step_count: int
    max_steps: int

    # --- patient and running totals ---
    patient: PatientProfile
    active_mode: DecisionMode = DecisionMode.REGIMEN_OPT
    cumulative_reward: float = 0.0
    unresolved_conflicts: list[str] = Field(default_factory=list)
    risk_summary: dict[str, float] = Field(default_factory=dict)
    burden_score: float = 0.5
    precision_dosing_flags: list[str] = Field(default_factory=list)
    action_history: list[dict[str, Any]] = Field(default_factory=list)
    done: bool = False

    # Naive UTC timestamp taken when the instance is created.
    created_at: datetime = Field(default_factory=_naive_utc_now)
|
|
|
|
class PolyGuardObservation(StrictBase):
    """Observation payload: summaries, candidate actions and step budget.

    Every field is required except ``deterministic_contract``, which
    defaults to an empty dict. The ``dict[str, Any]`` summaries are
    free-form; this model does not fix their keys.
    """

    # --- summaries ---
    patient_summary: dict[str, Any]
    medication_table: list[dict[str, Any]]
    comorbidity_summary: list[str]
    organ_function_summary: dict[str, Any]
    labs_vitals_summary: dict[str, Any]
    graph_safety_summary: dict[str, Any]
    burden_score_summary: dict[str, Any]

    # --- flags and open conflicts ---
    precision_dosing_flags: list[str]
    unresolved_conflicts: list[str]

    # --- candidate actions and step bookkeeping ---
    candidate_action_set: list[CandidateAction]
    step_budget_remaining: int
    action_history: list[dict[str, Any]]

    # --- warnings and context ---
    warning_summary: list[str]
    abstention_indicators: dict[str, Any]
    sub_environment: SubEnvironment
    deterministic_contract: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
class StepTrace(StrictBase):
    """Record of one step: observation snapshot, action and diagnostics.

    ``step`` and ``observation_snapshot`` are required; all other fields
    default to empty or neutral values.
    """

    step: int
    observation_snapshot: PolyGuardObservation

    # Stays None when no action is recorded for the step.
    selected_action: Optional[PolyGuardAction] = None

    # --- diagnostics (each container is created per instance) ---
    critic_output: dict[str, Any] = Field(default_factory=dict)
    reward_components: dict[str, float] = Field(default_factory=dict)
    transition_delta: dict[str, Any] = Field(default_factory=dict)
    uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport)
    failure_reasons: list[str] = Field(default_factory=list)
    timeout: bool = False
|
|