adithya9903's picture
Deploy PolyGuard HF training Space
fd0c71a verified
"""Episode termination logic."""
from __future__ import annotations
from app.common.enums import ActionType
from app.common.types import PolyGuardAction, PolyGuardState
def check_termination(
    state: PolyGuardState,
    action: PolyGuardAction,
    exploit_detected: bool = False,
) -> tuple[bool, str]:
    """Decide whether the episode should stop after the current step.

    The rules are tried in a fixed priority order; the first one that
    matches supplies the reason string:

    1. ``exploit_detection`` -- ``exploit_detected`` is true.
    2. ``step_budget_exhausted`` -- ``step_count`` reached ``max_steps``.
    3. ``repeated_invalid_actions`` -- each of the last three history
       entries has ``"applied"`` set to exactly ``False``.
    4. ``justified_review_escalation`` -- the action requests a specialist
       or pharmacist review.
    5. ``safety_veto_threshold`` -- at least two severe pairs are reported
       and the episode is at least ``max(2, max_steps // 2)`` steps in.
    6. ``patient_destabilization`` -- the risk summary's burden score is
       above 0.92 and at least two steps have been taken.
    7. ``safe_resolution`` -- ``state.burden_score`` is below 0.25 and no
       unresolved conflicts remain.

    Args:
        state: Episode state; only the attributes named above are read.
        action: The action just taken; only ``action_type`` is read.
        exploit_detected: Forces immediate termination when true.

    Returns:
        ``(done, reason)``. ``reason`` is ``"ongoing"`` exactly when
        ``done`` is ``False``.
    """
    if exploit_detected:
        verdict = "exploit_detection"
    elif state.step_count >= state.max_steps:
        verdict = "step_budget_exhausted"
    elif sum(
        1
        for entry in state.action_history[-3:]
        # Identity check: a missing or None "applied" flag is not a failure.
        if entry.get("applied") is False
    ) >= 3:
        verdict = "repeated_invalid_actions"
    elif action.action_type in {
        ActionType.REQUEST_SPECIALIST_REVIEW,
        ActionType.REQUEST_PHARMACIST_REVIEW,
    }:
        verdict = "justified_review_escalation"
    elif (
        state.risk_summary.get("severe_pair_count", 0.0) >= 2.0
        and state.step_count >= max(2, state.max_steps // 2)
    ):
        verdict = "safety_veto_threshold"
    elif (
        # A missing "burden_score" key falls back to 1.0, which itself
        # satisfies the > 0.92 test.
        state.risk_summary.get("burden_score", 1.0) > 0.92
        and state.step_count >= 2
    ):
        verdict = "patient_destabilization"
    elif state.burden_score < 0.25 and not state.unresolved_conflicts:
        # Reads the state attribute, not the risk_summary entry used above.
        verdict = "safe_resolution"
    else:
        verdict = "ongoing"

    return verdict != "ongoing", verdict
def check_termination_with_timeout(
    state: PolyGuardState,
    action: PolyGuardAction,
    exploit_detected: bool = False,
    elapsed_seconds: float | None = None,
    wall_clock_limit_seconds: float | None = None,
) -> tuple[bool, str]:
    """Apply the standard termination rules, then a wall-clock limit.

    The result of ``check_termination`` takes precedence: if it reports
    the episode as done, that result is returned as-is. Otherwise the
    elapsed time is compared with the limit, but only when both values
    are supplied.

    Args:
        state: Episode state, forwarded to ``check_termination``.
        action: The action just taken, forwarded to ``check_termination``.
        exploit_detected: Forwarded to ``check_termination``.
        elapsed_seconds: Time spent so far; ``None`` disables the timeout.
        wall_clock_limit_seconds: Time limit; ``None`` disables the
            timeout. Values below 0.1 are treated as 0.1.

    Returns:
        ``(done, reason)``; ``reason`` is ``"wall_clock_timeout"`` when the
        time limit is what ended the episode, ``"ongoing"`` when nothing did.
    """
    finished, why = check_termination(
        state=state,
        action=action,
        exploit_detected=exploit_detected,
    )
    if finished:
        return finished, why

    # Without both timing values there is nothing to compare.
    if elapsed_seconds is None or wall_clock_limit_seconds is None:
        return False, "ongoing"

    # Floor the limit at 0.1 so a zero or negative limit is not used as-is.
    effective_limit = max(0.1, wall_clock_limit_seconds)
    if elapsed_seconds >= effective_limit:
        return True, "wall_clock_timeout"
    return False, "ongoing"