"""Transition dynamics engine for the drug-target-validation simulator. Orchestrates latent-state updates, output generation, credit accounting, and constraint propagation for every agent action. """ from __future__ import annotations from copy import deepcopy from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple from models import ( ActionType, DrugTargetAction, IntermediateOutput, OutputType, ) from .latent_state import FullLatentState from .noise import NoiseModel from .output_generator import OutputGenerator # Credit costs per ActionType. _BASE_ACTION_COSTS: Dict[ActionType, int] = { ActionType.QUERY_EXPRESSION: 2, ActionType.DIFFERENTIAL_EXPRESSION: 2, ActionType.PATHWAY_ENRICHMENT: 2, ActionType.COEXPRESSION_NETWORK: 2, ActionType.PROTEIN_STRUCTURE_LOOKUP: 3, ActionType.BINDING_SITE_ANALYSIS: 3, ActionType.PROTEIN_INTERACTION_NETWORK: 2, ActionType.DRUGGABILITY_SCREEN: 3, ActionType.CLINICAL_TRIAL_LOOKUP: 3, ActionType.TOXICITY_PANEL: 3, ActionType.OFF_TARGET_SCREEN: 3, ActionType.PATIENT_STRATIFICATION: 3, ActionType.LITERATURE_SEARCH: 1, ActionType.EVIDENCE_SYNTHESIS: 1, ActionType.COMPETITOR_LANDSCAPE: 1, ActionType.IN_VITRO_ASSAY: 5, ActionType.IN_VIVO_MODEL: 8, ActionType.CRISPR_KNOCKOUT: 4, ActionType.BIOMARKER_CORRELATION: 3, ActionType.FLAG_RED_FLAG: 0, ActionType.REQUEST_EXPERT_REVIEW: 1, ActionType.SUBMIT_VALIDATION_REPORT: 0, } # Public alias kept for callers that historically imported ACTION_COSTS. ACTION_COSTS = _BASE_ACTION_COSTS def compute_action_cost(action: DrugTargetAction) -> int: """Return the credit cost for a single action.""" return _BASE_ACTION_COSTS.get(action.action_type, 0) # Map action type → progress flag that should be set when it succeeds. 
_PROGRESS_MAP: Dict[ActionType, str] = {
    ActionType.QUERY_EXPRESSION: "expression_queried",
    ActionType.DIFFERENTIAL_EXPRESSION: "expression_queried",
    ActionType.PATHWAY_ENRICHMENT: "pathway_analysed",
    ActionType.COEXPRESSION_NETWORK: "interactions_mapped",
    ActionType.PROTEIN_STRUCTURE_LOOKUP: "structure_resolved",
    ActionType.BINDING_SITE_ANALYSIS: "druggability_assessed",
    ActionType.PROTEIN_INTERACTION_NETWORK: "interactions_mapped",
    ActionType.DRUGGABILITY_SCREEN: "druggability_assessed",
    ActionType.CLINICAL_TRIAL_LOOKUP: "clinical_checked",
    ActionType.TOXICITY_PANEL: "toxicity_assessed",
    ActionType.OFF_TARGET_SCREEN: "selectivity_checked",
    ActionType.PATIENT_STRATIFICATION: "patient_stratification_done",
    ActionType.LITERATURE_SEARCH: "literature_reviewed",
    ActionType.EVIDENCE_SYNTHESIS: "evidence_synthesised",
    ActionType.COMPETITOR_LANDSCAPE: "literature_reviewed",
    ActionType.IN_VITRO_ASSAY: "in_vitro_done",
    ActionType.IN_VIVO_MODEL: "in_vivo_done",
    ActionType.CRISPR_KNOCKOUT: "crispr_done",
    ActionType.BIOMARKER_CORRELATION: "biomarker_correlated",
    ActionType.REQUEST_EXPERT_REVIEW: "expert_reviewed",
    ActionType.SUBMIT_VALIDATION_REPORT: "report_submitted",
}


@dataclass
class TransitionResult:
    """Everything produced by a single ``TransitionEngine.step`` call."""

    next_state: FullLatentState
    output: IntermediateOutput
    reward_components: Dict[str, float] = field(default_factory=dict)
    hard_violations: List[str] = field(default_factory=list)
    soft_violations: List[str] = field(default_factory=list)
    done: bool = False


class TransitionEngine:
    """Turn ``(state, action)`` into ``(next_state, output)``.

    Output synthesis is delegated to ``OutputGenerator``; this class owns
    the bookkeeping around it: call counting, credit deduction, progress
    flags, violation propagation, and episode termination.
    """

    def __init__(self, noise: NoiseModel):
        self.noise = noise
        self.output_gen = OutputGenerator(noise)

    def step(
        self,
        state: FullLatentState,
        action: DrugTargetAction,
        *,
        hard_violations: Optional[List[str]] = None,
        soft_violations: Optional[List[str]] = None,
    ) -> TransitionResult:
        """Apply *action* to a deep copy of *state* and bundle the result.

        ``hard_violations`` block the action outright (failure report, no
        credit charge, no call counted); ``soft_violations`` let it run but
        degrade the output's quality score and add warnings.
        """
        next_state = deepcopy(state)
        step_number = sum(next_state.action_call_counts.values()) + 1
        blocked = hard_violations or []
        warned = soft_violations or []

        # Hard violations short-circuit everything else.
        if blocked:
            failure = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=step_number,
                success=False,
                summary=f"Action blocked: {'; '.join(blocked)}",
            )
            return TransitionResult(
                next_state=next_state,
                output=failure,
                hard_violations=blocked,
                soft_violations=warned,
                done=action.action_type == ActionType.SUBMIT_VALIDATION_REPORT,
            )

        # Count the call before charging credits so the rule engine can
        # reason about redundancy on the next step.
        count_key = action.action_type.value
        next_state.action_call_counts[count_key] = (
            next_state.action_call_counts.get(count_key, 0) + 1
        )

        # Charge credits; exhaustion is sampled immediately after the
        # charge, before output generation (order is deliberate).
        next_state.credits.credits_used += compute_action_cost(action)
        out_of_credits = next_state.credits.exhausted

        # Simulate the action's output.
        output = self.output_gen.generate(action, next_state, step_number)

        # Soft violations dampen quality and surface as warnings.
        if warned:
            output.quality_score = float(max(0.0, output.quality_score * 0.7))
            output.warnings = list(output.warnings) + list(warned)

        # A successful action ticks its mapped progress flag, if any.
        flag_name = _PROGRESS_MAP.get(action.action_type)
        if flag_name and output.success:
            setattr(next_state.progress, flag_name, True)

        # Episode ends on a submitted report or when credits run dry.
        submitted = action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
        return TransitionResult(
            next_state=next_state,
            output=output,
            soft_violations=warned,
            done=submitted or out_of_credits,
        )

    @staticmethod
    def covered_evidence_dimensions(s: FullLatentState) -> List[str]:
        """Return the evidence dimensions whose progress flag is set.

        Names mirror the keys used in
        ``TargetProfile.key_evidence_dimensions`` so the reward computer
        can compute coverage directly.
        """
        progress = s.progress
        dimension_flags: List[Tuple[str, bool]] = [
            ("expression", progress.expression_queried),
            ("druggability", progress.druggability_assessed),
            ("off_target", progress.selectivity_checked),
            ("toxicity", progress.toxicity_assessed),
            ("clinical", progress.clinical_checked),
            ("literature", progress.literature_reviewed),
            ("in_vitro", progress.in_vitro_done),
            ("in_vivo", progress.in_vivo_done),
            ("patient_stratification", progress.patient_stratification_done),
            ("pathway", progress.pathway_analysed),
            ("structure", progress.structure_resolved),
            ("interactions", progress.interactions_mapped),
            ("crispr", progress.crispr_done),
            ("biomarker", progress.biomarker_correlated),
        ]
        return [name for name, hit in dimension_flags if hit]