| """Anti reward-hacking guards.""" |
|
|
| from __future__ import annotations |
|
|
| from collections import Counter |
| from dataclasses import dataclass |
| from typing import Iterable |
|
|
| from app.common.constants import MAX_KEEP_REGIMEN_RATIO, MAX_REPEATED_ACTIONS, MAX_REVIEW_RATIO |
| from app.common.enums import ActionType |
| from app.common.types import PolyGuardAction, PolyGuardState |
|
|
|
|
| @dataclass(slots=True) |
| class AntiCheatResult: |
| exploit_detected: bool |
| reasons: list[str] |
|
|
|
|
def detect_repeated_action_loop(actions: Iterable[PolyGuardAction], threshold: int = 3) -> bool:
    """Report whether the most recent actions all target the same candidate.

    Args:
        actions: Actions in chronological order. Only each action's
            ``candidate_id`` attribute is read.
        threshold: Size of the trailing window that must share a single
            ``candidate_id``. Values below 1 disable the check.

    Returns:
        ``True`` when at least ``threshold`` actions were supplied and the
        last ``threshold`` of them carry the same ``candidate_id``;
        ``False`` otherwise.
    """
    # A non-positive window is meaningless: ``ids[-0:]`` selects the whole
    # history rather than an empty tail, and a negative value slices off a
    # prefix instead of a suffix. Treat both as "check disabled".
    if threshold < 1:
        return False

    ids = [a.candidate_id for a in actions]
    if len(ids) < threshold:
        return False
    # One distinct id across the trailing window means a repeat loop.
    return len(set(ids[-threshold:])) == 1
|
|
|
|
def evaluate_anti_cheat(
    state: PolyGuardState,
    action: PolyGuardAction,
    legal_candidate_ids: set[str] | None = None,
) -> AntiCheatResult:
    """Run every anti-cheat guard against a proposed action.

    Args:
        state: Current state. ``state.action_history`` is read as a sequence
            of mappings with optional ``"action"`` and ``"applied"`` keys, and
            ``state.patient.medications`` is read for each entry's ``drug``.
        action: The action being proposed now.
        legal_candidate_ids: When given, the candidate ids the action is
            allowed to reference. ``None`` skips the membership check.

    Returns:
        An ``AntiCheatResult`` whose ``reasons`` lists the guards that fired,
        in the order they are checked, and whose ``exploit_detected`` is true
        exactly when that list is non-empty.
    """
    flagged: list[str] = []

    # Rebuild typed actions from the raw history; entries whose "action"
    # payload is not a dict are skipped.
    past_actions = []
    for entry in state.action_history:
        if not isinstance(entry.get("action"), dict):
            continue
        parsed = PolyGuardAction.model_validate(entry["action"])
        if parsed is not None:
            past_actions.append(parsed)

    # Guard 1: the same candidate chosen over and over.
    if detect_repeated_action_loop([*past_actions, action], threshold=MAX_REPEATED_ACTIONS):
        flagged.append("repeated_action_loop")

    past_types = [past.action_type for past in past_actions]
    current_type = action.action_type
    steps_seen = len(past_actions) + 1  # history plus the proposed action
    # Ratio guards only apply once at least three steps exist.
    enough_steps = steps_seen >= 3

    # Guard 2: share of KEEP_REGIMEN actions above the configured ratio.
    keep_total = Counter(past_types)[ActionType.KEEP_REGIMEN]
    if current_type == ActionType.KEEP_REGIMEN:
        keep_total += 1
    if enough_steps and keep_total / steps_seen > MAX_KEEP_REGIMEN_RATIO:
        flagged.append("keep_regimen_abuse")

    # Guard 3: share of review-request actions above the configured ratio.
    review_types = {
        ActionType.REQUEST_SPECIALIST_REVIEW,
        ActionType.REQUEST_PHARMACIST_REVIEW,
    }
    review_total = sum(kind in review_types for kind in [*past_types, current_type])
    if enough_steps and review_total / steps_seen > MAX_REVIEW_RATIO:
        flagged.append("review_abuse")

    # Guard 4: candidate id must carry the expected prefix and, when a legal
    # set is supplied, belong to it.
    candidate = action.candidate_id
    if not candidate.startswith("cand_"):
        flagged.append("candidate_id_mismatch")
    if legal_candidate_ids is not None and candidate not in legal_candidate_ids:
        flagged.append("candidate_not_in_legal_set")

    # Guard 5: keeping the regimen again while both "warfarin_like" and
    # "nsaid_like" are among the current medications.
    active_drugs = {med.drug for med in state.patient.medications}
    kept_before = any(kind == ActionType.KEEP_REGIMEN for kind in past_types)
    keeps_risky_pair = (
        {"warfarin_like", "nsaid_like"} <= active_drugs
        and current_type == ActionType.KEEP_REGIMEN
    )
    if keeps_risky_pair and kept_before:
        flagged.append("holdout_ddi_not_addressed")

    # Guard 6: markup-like characters in the rationale text.
    rationale = action.rationale_brief
    if any(marker in rationale for marker in ("<", "{")):
        flagged.append("parser_exploit_pattern")

    # Guard 7: resubmitting the exact action whose last attempt was recorded
    # with applied == False.
    if state.action_history:
        previous = state.action_history[-1]
        previous_action = previous.get("action", {})
        same_request = (
            isinstance(previous_action, dict)
            and previous_action.get("candidate_id") == candidate
            and previous_action.get("action_type") == current_type.value
        )
        if same_request and previous.get("applied") is False:
            flagged.append("no_op_retry_loop")

    return AntiCheatResult(exploit_detected=bool(flagged), reasons=flagged)
|
|