# TheJackBright's picture
# Deploy PolyGuard OpenEnv Space
# 877add7 verified
"""Core typed models."""
from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment
from app.common.normalization import clamp_reward
class StrictBase(BaseModel):
    """Common base for the models in this module.

    Configures Pydantic so that input keys which are not declared as fields
    are rejected rather than ignored.
    """

    # extra="forbid": undeclared keys cause a validation error.
    model_config = ConfigDict(extra="forbid")
class Medication(StrictBase):
    """A single medication entry, as listed in ``PatientProfile.medications``.

    ``drug`` is the only required field; every other field has a default.
    """

    drug: str

    # Defaults to the MEDIUM bucket when the caller does not specify one.
    dose_bucket: DoseBucket = DoseBucket.MEDIUM

    # Optional descriptive metadata; both are None when not supplied.
    indication: Optional[str] = None
    class_name: Optional[str] = None

    requires_taper: bool = False
class LabSummary(StrictBase):
    """Optional numeric lab values attached to a patient.

    Every field defaults to ``None``, so an empty ``LabSummary()`` is valid;
    units and reference ranges are not defined in this module.
    """

    egfr: Optional[float] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    inr: Optional[float] = None
    glucose: Optional[float] = None
class PatientProfile(StrictBase):
    """Patient record embedded in ``PolyGuardState.patient``.

    ``patient_id``, ``age`` and ``sex`` are required. Collection fields start
    out empty (a fresh container per instance via ``default_factory``), and
    ``labs`` starts as an all-``None`` ``LabSummary``.
    """

    # --- Required identity fields -------------------------------------
    patient_id: str
    age: int
    sex: str

    # --- Clinical context (fresh empty containers by default) ---------
    comorbidities: list[str] = Field(default_factory=list)
    medications: list[Medication] = Field(default_factory=list)
    labs: LabSummary = Field(default_factory=LabSummary)
    vitals: dict[str, float] = Field(default_factory=dict)
    specialist_conflicts: list[str] = Field(default_factory=list)
    prior_ade_history: list[str] = Field(default_factory=list)

    # --- Scalar estimates with fixed defaults --------------------------
    # NOTE(review): no bounds are enforced here; presumably both values are
    # fractions in [0, 1] — confirm against the producers of this model.
    frailty_score: float = 0.3
    adherence_estimate: float = 0.8

    # --- Additional optional context -----------------------------------
    latent_confounders: dict[str, float] = Field(default_factory=dict)
    monitoring_gaps: list[str] = Field(default_factory=list)
class CandidateAction(StrictBase):
    """One entry of ``PolyGuardObservation.candidate_action_set``.

    ``candidate_id``, ``mode`` and ``action_type`` are required; the action
    parameters and the estimate/metadata fields all carry defaults.
    """

    # --- Required identification ---------------------------------------
    candidate_id: str
    mode: DecisionMode
    action_type: ActionType

    # --- Action parameters (same field set as PolyGuardAction) ---------
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- Numeric estimates attached to the candidate -------------------
    estimated_safety_delta: float = 0.0
    burden_delta: float = 0.0
    disease_stability_estimate: float = 0.0
    uncertainty_score: float = 0.5

    # --- Supporting metadata -------------------------------------------
    rationale_tags: list[str] = Field(default_factory=list)
    required_monitoring: list[str] = Field(default_factory=list)
    legality_precheck: bool = True
class PolyGuardAction(StrictBase):
    """Action payload; stored on ``StepTrace.selected_action``.

    Shares its action-parameter fields with ``CandidateAction`` and adds the
    required ``candidate_id``, ``confidence`` and ``rationale_brief``.
    """

    # --- Required action kind ------------------------------------------
    mode: DecisionMode
    action_type: ActionType

    # --- Optional action parameters ------------------------------------
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- Required selection metadata -----------------------------------
    candidate_id: str
    confidence: float
    rationale_brief: str

    @field_validator("confidence")
    @classmethod
    def _valid_confidence(cls, value: float) -> float:
        """Return ``confidence`` after passing it through ``clamp_reward``.

        The clamping range is defined by ``app.common.normalization`` and is
        not visible in this module.
        """
        return clamp_reward(value)
class RewardBreakdown(StrictBase):
    """Itemised reward: individual component scores plus ``total_reward``.

    All component scores and ``total_reward`` are required. The four
    ``primary_*`` fields are optional and default to 0.5.
    """

    # --- Required component scores -------------------------------------
    format_compliance_score: float
    candidate_alignment_score: float
    legality_score: float
    safety_delta_score: float
    burden_improvement_score: float
    disease_stability_score: float
    dosing_quality_score: float
    abstention_quality_score: float
    efficiency_score: float
    process_fidelity_score: float
    explanation_grounding_score: float
    anti_cheat_score: float
    uncertainty_calibration_score: float

    # --- Primary channels (default 0.5 when not provided) --------------
    primary_safety_legality: float = 0.5
    primary_clinical_improvement: float = 0.5
    primary_dosing_quality: float = 0.5
    primary_process_integrity: float = 0.5

    # --- Required overall value ----------------------------------------
    total_reward: float
class SafetyReport(StrictBase):
    """Result of a safety/legality check.

    ``legal`` is required; the remaining fields default to "nothing to
    report" values (empty lists, severity ``"none"``, no fallback).
    """

    legal: bool
    violations: list[str] = Field(default_factory=list)

    # Free-form string; the set of accepted values is not constrained here.
    severity: str = "none"

    recommended_fallback: Optional[ActionType] = None
    uncertainty_notes: list[str] = Field(default_factory=list)
class UncertaintyReport(StrictBase):
    """Uncertainty summary; used as the default for ``StepTrace.uncertainty_report``.

    Every field has a default, so ``UncertaintyReport()`` is a valid instance.
    """

    overall_uncertainty: float = 0.5
    missing_data_flags: list[str] = Field(default_factory=list)
    abstention_recommended: bool = False
class PolyGuardState(StrictBase):
    """Full state of one episode.

    Required fields: ``episode_id``, ``seed``, ``difficulty``, ``step_count``,
    ``max_steps`` and ``patient``. Everything else carries a default.

    ``created_at`` defaults to the current UTC time as a *naive* ``datetime``
    (no ``tzinfo``). That is the same value ``datetime.utcnow()`` produced,
    but computed without that function, which is deprecated since Python 3.12.
    """

    # --- Episode identity ----------------------------------------------
    episode_id: str
    seed: int
    scenario_id: Optional[str] = None
    difficulty: Difficulty
    sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK

    # --- Step accounting -----------------------------------------------
    step_count: int
    max_steps: int

    # --- Patient and current decision mode -----------------------------
    patient: PatientProfile
    active_mode: DecisionMode = DecisionMode.REGIMEN_OPT

    # --- Running aggregates --------------------------------------------
    cumulative_reward: float = 0.0
    unresolved_conflicts: list[str] = Field(default_factory=list)
    risk_summary: dict[str, float] = Field(default_factory=dict)
    burden_score: float = 0.5
    precision_dosing_flags: list[str] = Field(default_factory=list)
    action_history: list[dict[str, Any]] = Field(default_factory=list)
    done: bool = False

    # Take an aware UTC "now" and drop tzinfo so the stored value stays naive,
    # keeping serialised output and comparisons unchanged for existing callers.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
class PolyGuardObservation(StrictBase):
    """Observation payload; snapshotted in ``StepTrace.observation_snapshot``.

    Every field is required except ``deterministic_contract``, which defaults
    to an empty dict.
    """

    # --- Patient-level summaries ---------------------------------------
    patient_summary: dict[str, Any]
    medication_table: list[dict[str, Any]]
    comorbidity_summary: list[str]
    organ_function_summary: dict[str, Any]
    labs_vitals_summary: dict[str, Any]

    # --- Risk / burden summaries ---------------------------------------
    graph_safety_summary: dict[str, Any]
    burden_score_summary: dict[str, Any]
    precision_dosing_flags: list[str]
    unresolved_conflicts: list[str]

    # --- Decision context ----------------------------------------------
    candidate_action_set: list[CandidateAction]
    step_budget_remaining: int
    action_history: list[dict[str, Any]]
    warning_summary: list[str]
    abstention_indicators: dict[str, Any]
    sub_environment: SubEnvironment

    # --- Optional extra ------------------------------------------------
    deterministic_contract: dict[str, Any] = Field(default_factory=dict)
class StepTrace(StrictBase):
    """Record of a single step.

    ``step`` and ``observation_snapshot`` are required. The remaining fields
    default to empty values: no selected action, empty dicts/lists, a default
    ``UncertaintyReport`` and ``timeout=False``.
    """

    # --- Required --------------------------------------------------------
    step: int
    observation_snapshot: PolyGuardObservation

    # --- Optional details of what happened during the step ---------------
    selected_action: Optional[PolyGuardAction] = None
    critic_output: dict[str, Any] = Field(default_factory=dict)
    reward_components: dict[str, float] = Field(default_factory=dict)
    transition_delta: dict[str, Any] = Field(default_factory=dict)
    uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport)

    # --- Failure bookkeeping ---------------------------------------------
    failure_reasons: list[str] = Field(default_factory=list)
    timeout: bool = False