Spaces:
Running
Running
File size: 5,947 Bytes
877add7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | """Core typed models."""
from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment
from app.common.normalization import clamp_reward
class StrictBase(BaseModel):
    """Shared pydantic base for the typed models in this module.

    The ``extra="forbid"`` setting makes construction fail on any keyword that
    is not a declared field, so misspelled or stale payload keys raise a
    validation error instead of being silently ignored.
    """

    # Reject undeclared fields rather than dropping them.
    model_config = ConfigDict(extra="forbid")
class Medication(StrictBase):
    """A single medication entry, as held in ``PatientProfile.medications``.

    Only ``drug`` is required; every other attribute has a default.
    """

    drug: str
    # Falls back to the MEDIUM bucket when no dose bucket is supplied.
    dose_bucket: DoseBucket = DoseBucket.MEDIUM
    indication: Optional[str] = None
    class_name: Optional[str] = None
    requires_taper: bool = False
class LabSummary(StrictBase):
    """Optional laboratory values for a patient.

    Every value defaults to ``None``, meaning "not available"; no units or
    ranges are enforced here (presumably the conventional clinical units for
    each lab -- confirm with the data source).
    """

    egfr: Optional[float] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    inr: Optional[float] = None
    glucose: Optional[float] = None
class PatientProfile(StrictBase):
    """Patient record carried by ``PolyGuardState.patient``.

    ``patient_id``, ``age`` and ``sex`` are required; all collections default
    to fresh empty containers per instance via ``default_factory``.
    """

    # --- identity / demographics (required) ---
    patient_id: str
    age: int
    sex: str

    # --- clinical context ---
    comorbidities: list[str] = Field(default_factory=list)
    medications: list[Medication] = Field(default_factory=list)
    labs: LabSummary = Field(default_factory=LabSummary)
    vitals: dict[str, float] = Field(default_factory=dict)
    specialist_conflicts: list[str] = Field(default_factory=list)
    prior_ade_history: list[str] = Field(default_factory=list)

    # --- scalar estimates (no range validation is applied here) ---
    frailty_score: float = 0.3
    adherence_estimate: float = 0.8

    # NOTE(review): by name these look like hidden/simulator-side factors and
    # known monitoring gaps -- confirm how callers expose them.
    latent_confounders: dict[str, float] = Field(default_factory=dict)
    monitoring_gaps: list[str] = Field(default_factory=list)
class CandidateAction(StrictBase):
    """One proposed action offered to the agent in an observation.

    Instances populate ``PolyGuardObservation.candidate_action_set``. The
    action-shaped fields mirror those of ``PolyGuardAction``; the remaining
    fields are per-candidate estimates and annotations with neutral defaults.
    """

    # --- identity and action specification ---
    candidate_id: str
    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- estimates attached to the candidate (unvalidated floats) ---
    estimated_safety_delta: float = 0.0
    burden_delta: float = 0.0
    disease_stability_estimate: float = 0.0
    uncertainty_score: float = 0.5

    # --- annotations ---
    rationale_tags: list[str] = Field(default_factory=list)
    required_monitoring: list[str] = Field(default_factory=list)
    # Candidates are assumed legal unless a precheck sets this to False.
    legality_precheck: bool = True
class PolyGuardAction(StrictBase):
    """The action submitted by the agent for one environment step.

    Shares the action-specification fields of ``CandidateAction`` and adds the
    required ``candidate_id``, ``confidence`` and ``rationale_brief``.
    """

    # --- action specification ---
    mode: DecisionMode
    action_type: ActionType
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- agent-supplied metadata (all required) ---
    candidate_id: str
    confidence: float
    rationale_brief: str

    @field_validator("confidence")
    @classmethod
    def _valid_confidence(cls, value: float) -> float:
        """Pass ``confidence`` through ``clamp_reward`` after float coercion.

        Runs in pydantic's default "after" mode, so ``value`` is already a
        float. NOTE(review): presumably ``clamp_reward`` bounds the value to
        the reward range -- confirm in ``app.common.normalization``.
        """
        bounded = clamp_reward(value)
        return bounded
class RewardBreakdown(StrictBase):
    """Per-component reward scores for one step plus the aggregate total.

    No clamping or cross-field validation happens here; values are stored as
    given. The four ``primary_*`` aggregates default to 0.5, every other
    score (including ``total_reward``) is required.
    """

    # --- component scores (required) ---
    format_compliance_score: float
    candidate_alignment_score: float
    legality_score: float
    safety_delta_score: float
    burden_improvement_score: float
    disease_stability_score: float
    dosing_quality_score: float
    abstention_quality_score: float
    efficiency_score: float
    process_fidelity_score: float
    explanation_grounding_score: float
    anti_cheat_score: float
    uncertainty_calibration_score: float

    # --- primary aggregates (neutral 0.5 default) ---
    primary_safety_legality: float = 0.5
    primary_clinical_improvement: float = 0.5
    primary_dosing_quality: float = 0.5
    primary_process_integrity: float = 0.5

    # --- overall ---
    total_reward: float
class SafetyReport(StrictBase):
    """Result of a legality/safety check on an action.

    Only ``legal`` is required. ``severity`` is free-form text defaulting to
    ``"none"``; no fixed vocabulary is enforced here.
    """

    legal: bool
    violations: list[str] = Field(default_factory=list)
    severity: str = "none"
    # Optional action type to fall back to; None when no fallback is suggested.
    recommended_fallback: Optional[ActionType] = None
    uncertainty_notes: list[str] = Field(default_factory=list)
class UncertaintyReport(StrictBase):
    """Uncertainty summary attached to a step trace.

    Constructible with no arguments: uncertainty 0.5, no missing-data flags,
    abstention not recommended.
    """

    overall_uncertainty: float = 0.5
    missing_data_flags: list[str] = Field(default_factory=list)
    abstention_recommended: bool = False
class PolyGuardState(StrictBase):
    """Mutable-by-replacement state of one environment episode.

    Required: ``episode_id``, ``seed``, ``difficulty``, ``step_count``,
    ``max_steps`` and ``patient``. Everything else has a default, and the
    collection fields get a fresh empty container per instance.
    """

    # --- episode identity ---
    episode_id: str
    seed: int
    scenario_id: Optional[str] = None
    difficulty: Difficulty
    sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK

    # --- step accounting ---
    step_count: int
    max_steps: int

    # --- clinical state ---
    patient: PatientProfile
    active_mode: DecisionMode = DecisionMode.REGIMEN_OPT
    cumulative_reward: float = 0.0
    unresolved_conflicts: list[str] = Field(default_factory=list)
    risk_summary: dict[str, float] = Field(default_factory=dict)
    burden_score: float = 0.5
    precision_dosing_flags: list[str] = Field(default_factory=list)
    action_history: list[dict[str, Any]] = Field(default_factory=list)
    done: bool = False

    # Naive UTC timestamp, identical in value and type to the former
    # ``datetime.utcnow()`` default. ``utcnow`` is deprecated since Python 3.12
    # (DeprecationWarning, fatal under ``-W error``); stripping tzinfo keeps
    # callers that compare against naive datetimes working.
    # TODO: migrate to an aware datetime once all consumers accept one.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
class PolyGuardObservation(StrictBase):
    """What the agent sees at each step.

    Every field is required except ``deterministic_contract``, which defaults
    to an empty dict. The ``dict[str, Any]`` summaries have no schema enforced
    here -- their keys are defined by whatever builds the observation.
    """

    # --- patient context ---
    patient_summary: dict[str, Any]
    medication_table: list[dict[str, Any]]
    comorbidity_summary: list[str]
    organ_function_summary: dict[str, Any]
    labs_vitals_summary: dict[str, Any]

    # --- risk and burden signals ---
    graph_safety_summary: dict[str, Any]
    burden_score_summary: dict[str, Any]
    precision_dosing_flags: list[str]
    unresolved_conflicts: list[str]

    # --- decision support ---
    candidate_action_set: list[CandidateAction]
    step_budget_remaining: int
    action_history: list[dict[str, Any]]
    warning_summary: list[str]
    abstention_indicators: dict[str, Any]

    # --- environment metadata ---
    sub_environment: SubEnvironment
    deterministic_contract: dict[str, Any] = Field(default_factory=dict)
class StepTrace(StrictBase):
    """Record of a single environment step.

    Requires the step index and the observation snapshot. ``selected_action``
    is ``None`` when no action was recorded; the remaining diagnostics default
    to empty containers, a default ``UncertaintyReport`` and ``timeout=False``.
    """

    step: int
    observation_snapshot: PolyGuardObservation
    selected_action: Optional[PolyGuardAction] = None

    # --- diagnostics ---
    critic_output: dict[str, Any] = Field(default_factory=dict)
    reward_components: dict[str, float] = Field(default_factory=dict)
    transition_delta: dict[str, Any] = Field(default_factory=dict)
    uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport)
    failure_reasons: list[str] = Field(default_factory=list)
    timeout: bool = False
|