Spaces:
Running
Running
| """Core typed models.""" | |
| from __future__ import annotations | |
| from datetime import datetime | |
| from typing import Any, Optional | |
| from pydantic import BaseModel, ConfigDict, Field, field_validator | |
| from app.common.enums import ActionType, DecisionMode, Difficulty, DoseBucket, SubEnvironment | |
| from app.common.normalization import clamp_reward | |
class StrictBase(BaseModel):
    """Common base for the models in this module.

    The model configuration sets ``extra="forbid"``, so validation fails when
    input data contains a key that is not a declared field.
    """

    model_config = ConfigDict(extra="forbid")
class Medication(StrictBase):
    """A single medication entry.

    Only ``drug`` is required; every other field has a default.
    """

    # Required identifier of the drug.
    drug: str

    # Defaults to the MEDIUM bucket when not supplied.
    dose_bucket: DoseBucket = DoseBucket.MEDIUM

    # Optional free-text descriptors; both default to None.
    indication: Optional[str] = None
    class_name: Optional[str] = None

    # Defaults to False.
    # NOTE(review): presumably consumed by taper-related action logic elsewhere.
    requires_taper: bool = False
class LabSummary(StrictBase):
    """Optional laboratory values for a patient.

    Every field is an optional float that defaults to ``None``, so an empty
    ``LabSummary()`` is valid. Units are not specified in this module.
    """

    # NOTE(review): field names look like standard lab abbreviations (kidney
    # function, liver enzymes, coagulation, glucose) -- confirm units upstream.
    egfr: Optional[float] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    inr: Optional[float] = None
    glucose: Optional[float] = None
class PatientProfile(StrictBase):
    """Patient record used by the environment state.

    ``patient_id``, ``age`` and ``sex`` are required. All collection fields
    default to a fresh empty container per instance, and ``labs`` defaults to
    an empty ``LabSummary``.
    """

    # --- required identity / demographics ---
    patient_id: str
    age: int
    sex: str

    # --- clinical picture (all default to empty) ---
    comorbidities: list[str] = Field(default_factory=list)
    medications: list[Medication] = Field(default_factory=list)
    labs: LabSummary = Field(default_factory=LabSummary)
    vitals: dict[str, float] = Field(default_factory=dict)
    specialist_conflicts: list[str] = Field(default_factory=list)
    prior_ade_history: list[str] = Field(default_factory=list)

    # --- scalar estimates ---
    # NOTE(review): no range validation is declared here; the defaults suggest
    # a 0..1 scale, but confirm against the producers of these values.
    frailty_score: float = 0.3
    adherence_estimate: float = 0.8

    # --- auxiliary information (defaults to empty) ---
    latent_confounders: dict[str, float] = Field(default_factory=dict)
    monitoring_gaps: list[str] = Field(default_factory=list)
class CandidateAction(StrictBase):
    """One entry of a candidate action set.

    ``candidate_id``, ``mode`` and ``action_type`` are required. The remaining
    fields are optional action parameters and numeric estimates, each with a
    default.
    """

    # --- required identification ---
    candidate_id: str
    mode: DecisionMode
    action_type: ActionType

    # --- optional action parameters (same names as on PolyGuardAction) ---
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- numeric estimates attached to the candidate ---
    # NOTE(review): scales and sign conventions are not defined in this module.
    estimated_safety_delta: float = 0.0
    burden_delta: float = 0.0
    disease_stability_estimate: float = 0.0
    uncertainty_score: float = 0.5

    # --- annotations ---
    rationale_tags: list[str] = Field(default_factory=list)
    required_monitoring: list[str] = Field(default_factory=list)

    # Defaults to True.
    legality_precheck: bool = True
class PolyGuardAction(StrictBase):
    """Action selected for a step.

    ``mode``, ``action_type``, ``candidate_id``, ``confidence`` and
    ``rationale_brief`` are required; the action parameters in between are
    optional and mirror the same-named fields on ``CandidateAction``.

    ``confidence`` is passed through ``clamp_reward`` during validation.
    """

    # --- required action kind ---
    mode: DecisionMode
    action_type: ActionType

    # --- optional action parameters ---
    target_drug: Optional[str] = None
    replacement_drug: Optional[str] = None
    dose_bucket: DoseBucket = DoseBucket.NA
    taper_days: Optional[int] = None
    monitoring_plan: Optional[str] = None
    evidence_query: Optional[str] = None
    new_drug_name: Optional[str] = None
    candidate_components: list[str] = Field(default_factory=list)

    # --- required selection metadata ---
    candidate_id: str
    confidence: float
    rationale_brief: str

    # Without these decorators pydantic treats the function as an ordinary
    # (uncalled) method and ``confidence`` is stored unclamped.
    @field_validator("confidence")
    @classmethod
    def _valid_confidence(cls, value: float) -> float:
        """Normalize ``confidence`` via ``clamp_reward``.

        Args:
            value: The already type-validated confidence value.

        Returns:
            Whatever ``clamp_reward`` returns for ``value``. The clamping
            bounds live in ``app.common.normalization`` and are not visible
            from this module.
        """
        return clamp_reward(value)
class RewardBreakdown(StrictBase):
    """Per-component reward scores plus the aggregated total.

    The thirteen ``*_score`` fields and ``total_reward`` are required. The
    four ``primary_*`` fields default to ``0.5``. No range validation is
    declared on any field here.
    """

    # --- required component scores ---
    format_compliance_score: float
    candidate_alignment_score: float
    legality_score: float
    safety_delta_score: float
    burden_improvement_score: float
    disease_stability_score: float
    dosing_quality_score: float
    abstention_quality_score: float
    efficiency_score: float
    process_fidelity_score: float
    explanation_grounding_score: float
    anti_cheat_score: float
    uncertainty_calibration_score: float

    # --- primary aggregates (optional, default 0.5) ---
    # NOTE(review): how these are derived from the component scores is not
    # shown in this module.
    primary_safety_legality: float = 0.5
    primary_clinical_improvement: float = 0.5
    primary_dosing_quality: float = 0.5
    primary_process_integrity: float = 0.5

    # --- required overall value ---
    total_reward: float
class SafetyReport(StrictBase):
    """Result of a safety/legality check.

    Only ``legal`` is required.
    """

    # Required verdict.
    legal: bool

    # Violation descriptions; empty list by default.
    violations: list[str] = Field(default_factory=list)

    # Free-form string, "none" by default.
    # NOTE(review): the set of allowed severity values is not constrained here.
    severity: str = "none"

    # Optional fallback action type; None when nothing is suggested.
    recommended_fallback: Optional[ActionType] = None

    # Additional notes; empty list by default.
    uncertainty_notes: list[str] = Field(default_factory=list)
class UncertaintyReport(StrictBase):
    """Uncertainty summary; every field has a default.

    Because nothing is required, ``UncertaintyReport()`` is a valid instance,
    which lets the class itself serve as a ``default_factory``.
    """

    # Defaults to 0.5; no range validation is declared.
    overall_uncertainty: float = 0.5

    # Names of missing inputs; empty list by default.
    missing_data_flags: list[str] = Field(default_factory=list)

    # Defaults to False.
    abstention_recommended: bool = False
def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same kind of value as ``datetime.utcnow()`` (UTC wall clock,
    ``tzinfo`` is ``None``) without calling that function, which is deprecated
    since Python 3.12.
    """
    # Local import: the module imports only the ``datetime`` class.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class PolyGuardState(StrictBase):
    """Full state of one episode.

    ``episode_id``, ``seed``, ``difficulty``, ``step_count``, ``max_steps``
    and ``patient`` are required; everything else has a default.
    ``created_at`` is filled with the current naive-UTC timestamp when the
    instance is built without an explicit value.
    """

    # --- episode identity ---
    episode_id: str
    seed: int
    scenario_id: Optional[str] = None
    difficulty: Difficulty
    sub_environment: SubEnvironment = SubEnvironment.REGIMEN_RISK

    # --- step accounting ---
    step_count: int
    max_steps: int

    # --- patient and decision context ---
    patient: PatientProfile
    active_mode: DecisionMode = DecisionMode.REGIMEN_OPT

    # --- running summaries ---
    cumulative_reward: float = 0.0
    unresolved_conflicts: list[str] = Field(default_factory=list)
    risk_summary: dict[str, float] = Field(default_factory=dict)
    burden_score: float = 0.5
    precision_dosing_flags: list[str] = Field(default_factory=list)

    # History entries are untyped dicts; their schema is not defined here.
    action_history: list[dict[str, Any]] = Field(default_factory=list)

    # --- lifecycle ---
    done: bool = False
    # Naive UTC timestamp, kept naive for compatibility with existing data.
    created_at: datetime = Field(default_factory=_utc_now_naive)
class PolyGuardObservation(StrictBase):
    """Observation payload.

    Every field except ``deterministic_contract`` is required. Most summaries
    are loosely typed ``dict[str, Any]`` values whose inner schema is not
    defined in this module.
    """

    # --- patient-level summaries ---
    patient_summary: dict[str, Any]
    medication_table: list[dict[str, Any]]
    comorbidity_summary: list[str]
    organ_function_summary: dict[str, Any]
    labs_vitals_summary: dict[str, Any]

    # --- risk / burden summaries ---
    graph_safety_summary: dict[str, Any]
    burden_score_summary: dict[str, Any]
    precision_dosing_flags: list[str]
    unresolved_conflicts: list[str]

    # --- decision context ---
    candidate_action_set: list[CandidateAction]
    step_budget_remaining: int
    action_history: list[dict[str, Any]]
    warning_summary: list[str]
    abstention_indicators: dict[str, Any]
    sub_environment: SubEnvironment

    # Optional; empty dict by default.
    deterministic_contract: dict[str, Any] = Field(default_factory=dict)
class StepTrace(StrictBase):
    """Record of a single step.

    ``step`` and ``observation_snapshot`` are required; all remaining fields
    default to ``None``, an empty container, a default ``UncertaintyReport``
    or ``False``.
    """

    # --- required ---
    step: int
    observation_snapshot: PolyGuardObservation

    # --- action (None when no action was recorded) ---
    selected_action: Optional[PolyGuardAction] = None

    # --- per-step outputs, all empty by default ---
    critic_output: dict[str, Any] = Field(default_factory=dict)
    reward_components: dict[str, float] = Field(default_factory=dict)
    transition_delta: dict[str, Any] = Field(default_factory=dict)

    # A fresh all-defaults report is built per instance.
    uncertainty_report: UncertaintyReport = Field(default_factory=UncertaintyReport)

    # --- failure bookkeeping ---
    failure_reasons: list[str] = Field(default_factory=list)
    timeout: bool = False