| """Episode termination logic.""" |
|
|
| from __future__ import annotations |
|
|
| from app.common.types import PolyGuardAction, PolyGuardState |
|
|
|
|
def check_termination(state: PolyGuardState, action: PolyGuardAction, exploit_detected: bool = False) -> tuple[bool, str]:
    """Decide whether the episode should end after the latest step.

    Rules are evaluated in priority order; the first one that matches wins:

    1. ``"exploit_detection"`` -- the caller flagged an exploit.
    2. ``"step_budget_exhausted"`` -- ``state.step_count`` reached ``state.max_steps``.
    3. ``"repeated_invalid_actions"`` -- the last three ``state.action_history``
       entries all have ``"applied"`` set to ``False``.
    4. ``"safety_veto_threshold"`` -- ``risk_summary["severe_pair_count"] >= 2``
       once at least half of the step budget (minimum 2 steps) has been used.
    5. ``"patient_destabilization"`` -- burden score above 0.92 from step 2 on.
       The score is read from ``risk_summary["burden_score"]`` and falls back
       to ``state.burden_score`` when the summary has no such entry.
    6. ``"safe_resolution"`` -- ``state.burden_score < 0.25`` and no
       unresolved conflicts remain.

    Args:
        state: Current environment state.
        action: The action just taken. It is not consulted by the current
            rules; it is kept for interface compatibility.
        exploit_detected: Whether an exploit was detected for this step.

    Returns:
        A ``(done, reason)`` tuple; ``reason`` is ``"ongoing"`` when
        ``done`` is False.
    """
    if exploit_detected:
        return True, "exploit_detection"

    if state.step_count >= state.max_steps:
        return True, "step_budget_exhausted"

    # Only the three most recent actions are inspected, so a history shorter
    # than three entries can never trigger this rule.
    recent_invalid = [h for h in state.action_history[-3:] if h.get("applied") is False]
    if len(recent_invalid) >= 3:
        return True, "repeated_invalid_actions"

    risk = state.risk_summary

    # The safety veto only applies after half the budget (and at least 2 steps).
    veto_step = max(2, state.max_steps // 2)
    if risk.get("severe_pair_count", 0.0) >= 2.0 and state.step_count >= veto_step:
        return True, "safety_veto_threshold"

    # Fall back to the state's own burden score when the summary lacks one.
    # A hard-coded default of 1.0 would end every episode at step 2 whenever
    # the key is missing, regardless of the patient's actual burden.
    burden = risk.get("burden_score", state.burden_score)
    if burden > 0.92 and state.step_count >= 2:
        return True, "patient_destabilization"

    if state.burden_score < 0.25 and not state.unresolved_conflicts:
        return True, "safe_resolution"

    return False, "ongoing"
|
|
|
|
def check_termination_with_timeout(
    state: PolyGuardState,
    action: PolyGuardAction,
    exploit_detected: bool = False,
    elapsed_seconds: float | None = None,
    wall_clock_limit_seconds: float | None = None,
) -> tuple[bool, str]:
    """Run ``check_termination`` and additionally enforce a wall-clock limit.

    The episode rules from ``check_termination`` take precedence. The timeout
    is considered only when they report the episode as ongoing and both
    timing arguments are supplied. The limit is floored at 0.1 seconds, so a
    zero or negative limit does not trip immediately.

    Args:
        state: Current environment state, forwarded unchanged.
        action: The action just taken, forwarded unchanged.
        exploit_detected: Forwarded to ``check_termination``.
        elapsed_seconds: Wall-clock time spent so far, or None to skip the timeout.
        wall_clock_limit_seconds: Allowed wall-clock time, or None to skip the timeout.

    Returns:
        A ``(done, reason)`` tuple. ``reason`` is ``"wall_clock_timeout"`` when
        the limit was hit, and ``"ongoing"`` when nothing ended the episode.
    """
    finished, why = check_termination(state=state, action=action, exploit_detected=exploit_detected)
    if finished:
        return finished, why

    timing_known = elapsed_seconds is not None and wall_clock_limit_seconds is not None
    # Short-circuit keeps max() from ever seeing a None limit.
    if timing_known and elapsed_seconds >= max(0.1, wall_clock_limit_seconds):
        return True, "wall_clock_timeout"

    return False, "ongoing"
|
|