"""Pharma rule engine — hard and soft constraint checking. Hard violations block action execution entirely (the action still deducts no credits and the simulator returns a ``FailureReport``). Soft violations allow execution but degrade output quality and incur penalties. """ from __future__ import annotations from dataclasses import dataclass from enum import Enum from typing import Iterable, List, Optional from models import ActionType, DrugTargetAction from server.simulator.latent_state import FullLatentState class Severity(str, Enum): HARD = "hard" SOFT = "soft" @dataclass class RuleViolation: rule_id: str severity: Severity message: str class RuleEngine: """Evaluates drug-target-validation constraints against the current latent state before each action is applied. """ def check( self, action: DrugTargetAction, state: FullLatentState, *, evidence_dimensions_covered: Optional[Iterable[str]] = None, ) -> List[RuleViolation]: violations: List[RuleViolation] = [] violations.extend(self._check_resource_constraints(action, state)) violations.extend(self._check_submission( action, state, evidence_dimensions_covered or [], )) violations.extend(self._check_redundancy(action, state)) violations.extend(self._check_ordering(action, state)) return violations @staticmethod def hard_violations(violations: List[RuleViolation]) -> List[str]: return [v.message for v in violations if v.severity == Severity.HARD] @staticmethod def soft_violations(violations: List[RuleViolation]) -> List[str]: return [v.message for v in violations if v.severity == Severity.SOFT] # ── resource / credit constraints ─────────────────────────────────── def _check_resource_constraints( self, action: DrugTargetAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] from server.simulator.transition import compute_action_cost cost = compute_action_cost(action) if s.credits.exhausted and action.action_type != ActionType.SUBMIT_VALIDATION_REPORT: vs.append(RuleViolation( rule_id="credits_exhausted", severity=Severity.HARD, message="Credits exhausted - submit validation report or end episode", )) elif cost > s.credits.credits_remaining and cost > 0: vs.append(RuleViolation( rule_id="credits_insufficient", severity=Severity.HARD, message=( f"Action costs {cost} credits but only " f"{s.credits.credits_remaining} remain" ), )) return vs # ── submission validation ─────────────────────────────────────────── def _check_submission( self, action: DrugTargetAction, s: FullLatentState, evidence_dimensions_covered: Iterable[str], ) -> List[RuleViolation]: vs: List[RuleViolation] = [] if action.action_type != ActionType.SUBMIT_VALIDATION_REPORT: return vs # Hard: report with no evidence at all. if not list(evidence_dimensions_covered): vs.append(RuleViolation( rule_id="report_without_evidence", severity=Severity.HARD, message=( "Cannot submit validation report without gathering " "any evidence" ), )) # Hard: report missing decision or confidence. if action.final_decision is None: vs.append(RuleViolation( rule_id="report_missing_decision", severity=Severity.HARD, message=( "Submitting validation report without a final_decision " "is not allowed" ), )) elif action.final_decision.lower() not in {"go", "no_go"}: vs.append(RuleViolation( rule_id="report_invalid_decision", severity=Severity.HARD, message=( f"final_decision must be 'go' or 'no_go', got " f"{action.final_decision!r}" ), )) if action.confidence is None: vs.append(RuleViolation( rule_id="report_missing_confidence", severity=Severity.HARD, message=( "Submitting validation report without a confidence " "score is not allowed" ), )) elif action.confidence < 0.30: vs.append(RuleViolation( rule_id="report_low_confidence", severity=Severity.SOFT, message=( f"Submitting with very low confidence " f"({action.confidence:.2f}) — the agent appears " f"poorly calibrated" ), )) return vs # ── redundancy checks ─────────────────────────────────────────────── def _check_redundancy( self, action: DrugTargetAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] if action.action_type == ActionType.FLAG_RED_FLAG: return vs if action.action_type == ActionType.SUBMIT_VALIDATION_REPORT: if s.progress.report_submitted: vs.append(RuleViolation( rule_id="duplicate_report", severity=Severity.HARD, message="Validation report has already been submitted", )) return vs count = s.action_call_counts.get(action.action_type.value, 0) if count >= 2: vs.append(RuleViolation( rule_id=f"redundant_{action.action_type.value}", severity=Severity.SOFT, message=( f"Action '{action.action_type.value}' has already been " f"executed {count} time(s); further repeats are " f"redundant" ), )) return vs # ── ordering checks ───────────────────────────────────────────────── def _check_ordering( self, action: DrugTargetAction, s: FullLatentState ) -> List[RuleViolation]: vs: List[RuleViolation] = [] p = s.progress if action.action_type == ActionType.IN_VIVO_MODEL and not p.in_vitro_done: vs.append(RuleViolation( rule_id="in_vivo_before_in_vitro", severity=Severity.SOFT, message=( "Running in_vivo_model before in_vitro_assay is " "scientifically backwards" ), )) if ( action.action_type == ActionType.TOXICITY_PANEL and not p.expression_queried ): vs.append(RuleViolation( rule_id="toxicity_before_expression", severity=Severity.SOFT, message=( "Toxicity panel before any expression query — " "tissue-specific toxicity will be hard to interpret" ), )) return vs