"""Anti reward-hacking guards.""" from __future__ import annotations from collections import Counter from dataclasses import dataclass from typing import Iterable from app.common.constants import MAX_KEEP_REGIMEN_RATIO, MAX_REPEATED_ACTIONS, MAX_REVIEW_RATIO from app.common.enums import ActionType from app.common.types import PolyGuardAction, PolyGuardState @dataclass(slots=True) class AntiCheatResult: exploit_detected: bool reasons: list[str] def detect_repeated_action_loop(actions: Iterable[PolyGuardAction], threshold: int = 3) -> bool: ids = [a.candidate_id for a in actions] if len(ids) < threshold: return False return len(set(ids[-threshold:])) == 1 def evaluate_anti_cheat( state: PolyGuardState, action: PolyGuardAction, legal_candidate_ids: set[str] | None = None, ) -> AntiCheatResult: reasons: list[str] = [] history = [ PolyGuardAction.model_validate(item["action"]) if isinstance(item.get("action"), dict) else None for item in state.action_history ] history = [x for x in history if x is not None] if detect_repeated_action_loop(history + [action], threshold=MAX_REPEATED_ACTIONS): reasons.append("repeated_action_loop") action_types = [a.action_type for a in history] type_count = Counter(action_types) keep_count = type_count.get(ActionType.KEEP_REGIMEN, 0) + (1 if action.action_type == ActionType.KEEP_REGIMEN else 0) total_count = len(history) + 1 if total_count >= 3 and (keep_count / total_count) > MAX_KEEP_REGIMEN_RATIO: reasons.append("keep_regimen_abuse") review_actions = { ActionType.REQUEST_SPECIALIST_REVIEW, ActionType.REQUEST_PHARMACIST_REVIEW, } review_count = sum(1 for t in action_types if t in review_actions) + (1 if action.action_type in review_actions else 0) if total_count >= 3 and (review_count / total_count) > MAX_REVIEW_RATIO: reasons.append("review_abuse") if not action.candidate_id.startswith("cand_"): reasons.append("candidate_id_mismatch") if legal_candidate_ids is not None and action.candidate_id not in legal_candidate_ids: reasons.append("candidate_not_in_legal_set") # Hidden holdout rule: known high-risk pair should not be repeatedly ignored. risky_pair_key = {"warfarin_like", "nsaid_like"} current_drugs = {m.drug for m in state.patient.medications} if risky_pair_key.issubset(current_drugs) and action.action_type == ActionType.KEEP_REGIMEN: reasons.append("holdout_ddi_not_addressed") if "<" in action.rationale_brief or "{" in action.rationale_brief: reasons.append("parser_exploit_pattern") if state.action_history: last = state.action_history[-1] last_action = last.get("action", {}) if ( isinstance(last_action, dict) and last_action.get("candidate_id") == action.candidate_id and last_action.get("action_type") == action.action_type.value and last.get("applied") is False ): reasons.append("no_op_retry_loop") return AntiCheatResult(exploit_detected=bool(reasons), reasons=reasons)