drugenv / server /rules /engine.py
anugrahteesdollar's picture
initial: drugenv FastAPI + gradio demo
77e1e28 verified
"""Pharma rule engine — hard and soft constraint checking.

Hard violations block action execution entirely (the blocked action deducts
no credits, and the simulator returns a ``FailureReport``).
Soft violations allow execution but degrade output quality and incur
penalties.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Iterable, List, Optional
from models import ActionType, DrugTargetAction
from server.simulator.latent_state import FullLatentState
class Severity(str, Enum):
    """How strongly a rule violation affects the offending action.

    Mixes in ``str`` so each member compares equal to its plain string value
    (``Severity.HARD == "hard"``).
    """

    # Blocks execution of the action entirely.
    HARD = "hard"
    # Lets the action execute, but degrades output quality and incurs a penalty.
    SOFT = "soft"
@dataclass
class RuleViolation:
    """One rule failure detected while checking an action."""

    # Stable identifier of the rule that fired, e.g. "credits_exhausted".
    rule_id: str
    # HARD blocks the action; SOFT lets it run with a penalty.
    severity: Severity
    # Human-readable explanation of why the rule fired.
    message: str
class RuleEngine:
    """Evaluates drug-target-validation constraints against the current
    latent state before each action is applied.

    Each rule group returns zero or more :class:`RuleViolation` objects.
    ``Severity.HARD`` violations block the action; ``Severity.SOFT`` ones let
    it execute but are penalized.

    The numeric thresholds are class attributes, so a subclass can tune them
    without touching the rule logic. The defaults preserve the original
    behavior.
    """

    # A submitted report whose confidence is strictly below this value gets a
    # soft "poorly calibrated" warning.
    LOW_CONFIDENCE_THRESHOLD: float = 0.30
    # A non-report, non-red-flag action is flagged as redundant once it has
    # already been executed at least this many times (with the default of 2,
    # the warning first fires on the third attempt).
    REDUNDANCY_THRESHOLD: int = 2
    # Accepted (case-insensitive) values for ``final_decision``; kept private
    # because the error message below spells these values out literally.
    _VALID_DECISIONS = frozenset({"go", "no_go"})

    def check(
        self,
        action: DrugTargetAction,
        state: FullLatentState,
        *,
        evidence_dimensions_covered: Optional[Iterable[str]] = None,
    ) -> List[RuleViolation]:
        """Run every rule group against *action* in *state*.

        Args:
            action: The action the agent wants to take.
            state: The current full latent simulator state.
            evidence_dimensions_covered: Evidence dimensions gathered so far.
                Only consulted when *action* submits the validation report.
                ``None`` is treated as "no evidence gathered".

        Returns:
            All violations found, in rule-group order: resource, submission,
            redundancy, ordering. Use :meth:`hard_violations` or
            :meth:`soft_violations` to extract messages by severity.
        """
        violations: List[RuleViolation] = []
        violations.extend(self._check_resource_constraints(action, state))
        violations.extend(self._check_submission(
            action, state, evidence_dimensions_covered or [],
        ))
        violations.extend(self._check_redundancy(action, state))
        violations.extend(self._check_ordering(action, state))
        return violations

    @staticmethod
    def hard_violations(violations: List[RuleViolation]) -> List[str]:
        """Return the messages of all ``HARD`` violations, in input order."""
        return [v.message for v in violations if v.severity == Severity.HARD]

    @staticmethod
    def soft_violations(violations: List[RuleViolation]) -> List[str]:
        """Return the messages of all ``SOFT`` violations, in input order."""
        return [v.message for v in violations if v.severity == Severity.SOFT]

    # ── resource / credit constraints ───────────────────────────────────
    def _check_resource_constraints(
        self, action: DrugTargetAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Hard-block actions that the remaining credit budget cannot cover.

        Two cases are blocked:

        * Once credits are exhausted, every action except
          ``SUBMIT_VALIDATION_REPORT`` gets ``credits_exhausted``.
        * Any other action (including a submission after exhaustion) gets
          ``credits_insufficient`` when its positive cost exceeds the
          remaining credits.
        """
        vs: List[RuleViolation] = []
        # Local import; presumably avoids an import cycle with the simulator
        # package. TODO confirm.
        from server.simulator.transition import compute_action_cost
        cost = compute_action_cost(action)
        if s.credits.exhausted and action.action_type != ActionType.SUBMIT_VALIDATION_REPORT:
            vs.append(RuleViolation(
                rule_id="credits_exhausted",
                severity=Severity.HARD,
                message="Credits exhausted - submit validation report or end episode",
            ))
        elif cost > s.credits.credits_remaining and cost > 0:
            # NOTE(review): a report submitted after exhaustion also reaches
            # this branch. If report submission has a positive cost, it is
            # blocked despite the "submit validation report" advice above.
            # Confirm that compute_action_cost(SUBMIT_VALIDATION_REPORT) == 0.
            vs.append(RuleViolation(
                rule_id="credits_insufficient",
                severity=Severity.HARD,
                message=(
                    f"Action costs {cost} credits but only "
                    f"{s.credits.credits_remaining} remain"
                ),
            ))
        return vs

    # ── submission validation ───────────────────────────────────────────
    def _check_submission(
        self,
        action: DrugTargetAction,
        s: FullLatentState,
        evidence_dimensions_covered: Iterable[str],
    ) -> List[RuleViolation]:
        """Validate the payload of a ``SUBMIT_VALIDATION_REPORT`` action.

        Non-submission actions always pass. A submission is hard-blocked if:

        * no evidence was gathered,
        * ``final_decision`` is missing or not ``go`` / ``no_go``
          (case-insensitive), or
        * ``confidence`` is missing.

        A confidence below :attr:`LOW_CONFIDENCE_THRESHOLD` is a soft
        violation.
        """
        vs: List[RuleViolation] = []
        if action.action_type != ActionType.SUBMIT_VALIDATION_REPORT:
            return vs

        # Hard: report with no evidence at all. The iterable is consumed here
        # and not used again.
        if not list(evidence_dimensions_covered):
            vs.append(RuleViolation(
                rule_id="report_without_evidence",
                severity=Severity.HARD,
                message=(
                    "Cannot submit validation report without gathering "
                    "any evidence"
                ),
            ))

        # Hard: missing or unrecognized decision.
        if action.final_decision is None:
            vs.append(RuleViolation(
                rule_id="report_missing_decision",
                severity=Severity.HARD,
                message=(
                    "Submitting validation report without a final_decision "
                    "is not allowed"
                ),
            ))
        elif action.final_decision.lower() not in self._VALID_DECISIONS:
            vs.append(RuleViolation(
                rule_id="report_invalid_decision",
                severity=Severity.HARD,
                message=(
                    f"final_decision must be 'go' or 'no_go', got "
                    f"{action.final_decision!r}"
                ),
            ))

        # Hard: missing confidence. Soft: suspiciously low confidence.
        if action.confidence is None:
            vs.append(RuleViolation(
                rule_id="report_missing_confidence",
                severity=Severity.HARD,
                message=(
                    "Submitting validation report without a confidence "
                    "score is not allowed"
                ),
            ))
        elif action.confidence < self.LOW_CONFIDENCE_THRESHOLD:
            vs.append(RuleViolation(
                rule_id="report_low_confidence",
                severity=Severity.SOFT,
                message=(
                    f"Submitting with very low confidence "
                    f"({action.confidence:.2f}) — the agent appears "
                    f"poorly calibrated"
                ),
            ))
        return vs

    # ── redundancy checks ───────────────────────────────────────────────
    def _check_redundancy(
        self, action: DrugTargetAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Flag repeated actions.

        ``FLAG_RED_FLAG`` may repeat freely. A second report submission is a
        hard violation. Any other action whose recorded call count in
        ``s.action_call_counts`` (keyed by the action-type value) has reached
        :attr:`REDUNDANCY_THRESHOLD` gets a soft ``redundant_<type>``
        violation.
        """
        vs: List[RuleViolation] = []
        if action.action_type == ActionType.FLAG_RED_FLAG:
            return vs
        if action.action_type == ActionType.SUBMIT_VALIDATION_REPORT:
            if s.progress.report_submitted:
                vs.append(RuleViolation(
                    rule_id="duplicate_report",
                    severity=Severity.HARD,
                    message="Validation report has already been submitted",
                ))
            return vs
        count = s.action_call_counts.get(action.action_type.value, 0)
        if count >= self.REDUNDANCY_THRESHOLD:
            vs.append(RuleViolation(
                rule_id=f"redundant_{action.action_type.value}",
                severity=Severity.SOFT,
                message=(
                    f"Action '{action.action_type.value}' has already been "
                    f"executed {count} time(s); further repeats are "
                    f"redundant"
                ),
            ))
        return vs

    # ── ordering checks ─────────────────────────────────────────────────
    def _check_ordering(
        self, action: DrugTargetAction, s: FullLatentState
    ) -> List[RuleViolation]:
        """Softly penalize scientifically out-of-order experiments.

        * ``IN_VIVO_MODEL`` before the in-vitro assay is done.
        * ``TOXICITY_PANEL`` before any expression query.
        """
        vs: List[RuleViolation] = []
        p = s.progress
        if action.action_type == ActionType.IN_VIVO_MODEL and not p.in_vitro_done:
            vs.append(RuleViolation(
                rule_id="in_vivo_before_in_vitro",
                severity=Severity.SOFT,
                message=(
                    "Running in_vivo_model before in_vitro_assay is "
                    "scientifically backwards"
                ),
            ))
        if (
            action.action_type == ActionType.TOXICITY_PANEL
            and not p.expression_queried
        ):
            vs.append(RuleViolation(
                rule_id="toxicity_before_expression",
                severity=Severity.SOFT,
                message=(
                    "Toxicity panel before any expression query — "
                    "tissue-specific toxicity will be hard to interpret"
                ),
            ))
        return vs