Spaces:
Runtime error
Runtime error
| """Transition dynamics engine for the drug-target-validation simulator. | |
| Orchestrates latent-state updates, output generation, credit accounting, | |
| and constraint propagation for every agent action. | |
| """ | |
| from __future__ import annotations | |
| from copy import deepcopy | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Tuple | |
| from models import ( | |
| ActionType, | |
| DrugTargetAction, | |
| IntermediateOutput, | |
| OutputType, | |
| ) | |
| from .latent_state import FullLatentState | |
| from .noise import NoiseModel | |
| from .output_generator import OutputGenerator | |
# Credit cost charged for each ActionType.  Entries are grouped by the kind
# of action; action types without an entry here are charged nothing.
_BASE_ACTION_COSTS: Dict[ActionType, int] = {
    # Expression / network analyses.
    ActionType.QUERY_EXPRESSION: 2,
    ActionType.DIFFERENTIAL_EXPRESSION: 2,
    ActionType.PATHWAY_ENRICHMENT: 2,
    ActionType.COEXPRESSION_NETWORK: 2,
    # Structure, interactions and druggability.
    ActionType.PROTEIN_STRUCTURE_LOOKUP: 3,
    ActionType.BINDING_SITE_ANALYSIS: 3,
    ActionType.PROTEIN_INTERACTION_NETWORK: 2,
    ActionType.DRUGGABILITY_SCREEN: 3,
    # Clinical and safety checks.
    ActionType.CLINICAL_TRIAL_LOOKUP: 3,
    ActionType.TOXICITY_PANEL: 3,
    ActionType.OFF_TARGET_SCREEN: 3,
    ActionType.PATIENT_STRATIFICATION: 3,
    # Desk research.
    ActionType.LITERATURE_SEARCH: 1,
    ActionType.EVIDENCE_SYNTHESIS: 1,
    ActionType.COMPETITOR_LANDSCAPE: 1,
    # Experiments.
    ActionType.IN_VITRO_ASSAY: 5,
    ActionType.IN_VIVO_MODEL: 8,
    ActionType.CRISPR_KNOCKOUT: 4,
    ActionType.BIOMARKER_CORRELATION: 3,
    # Flagging, review and reporting.
    ActionType.FLAG_RED_FLAG: 0,
    ActionType.REQUEST_EXPERT_REVIEW: 1,
    ActionType.SUBMIT_VALIDATION_REPORT: 0,
}

# Public alias kept for callers that historically imported ACTION_COSTS.
# It is the very same dict object, not a copy.
ACTION_COSTS = _BASE_ACTION_COSTS
def compute_action_cost(action: DrugTargetAction) -> int:
    """Look up how many credits *action* consumes.

    Action types that have no entry in ``_BASE_ACTION_COSTS`` are free and
    cost 0 credits.
    """
    try:
        return _BASE_ACTION_COSTS[action.action_type]
    except KeyError:
        return 0
# Name of the ``state.progress`` attribute that is switched on when an action
# of the given type succeeds.  Several action types share one flag (e.g. both
# expression actions set ``expression_queried``); ``FLAG_RED_FLAG`` has no
# entry, so it never sets a progress flag.
_PROGRESS_MAP: Dict[ActionType, str] = {
    # Expression / network analyses.
    ActionType.QUERY_EXPRESSION: "expression_queried",
    ActionType.DIFFERENTIAL_EXPRESSION: "expression_queried",
    ActionType.PATHWAY_ENRICHMENT: "pathway_analysed",
    ActionType.COEXPRESSION_NETWORK: "interactions_mapped",
    # Structure, interactions and druggability.
    ActionType.PROTEIN_STRUCTURE_LOOKUP: "structure_resolved",
    ActionType.BINDING_SITE_ANALYSIS: "druggability_assessed",
    ActionType.PROTEIN_INTERACTION_NETWORK: "interactions_mapped",
    ActionType.DRUGGABILITY_SCREEN: "druggability_assessed",
    # Clinical and safety checks.
    ActionType.CLINICAL_TRIAL_LOOKUP: "clinical_checked",
    ActionType.TOXICITY_PANEL: "toxicity_assessed",
    ActionType.OFF_TARGET_SCREEN: "selectivity_checked",
    ActionType.PATIENT_STRATIFICATION: "patient_stratification_done",
    # Desk research.
    ActionType.LITERATURE_SEARCH: "literature_reviewed",
    ActionType.EVIDENCE_SYNTHESIS: "evidence_synthesised",
    ActionType.COMPETITOR_LANDSCAPE: "literature_reviewed",
    # Experiments.
    ActionType.IN_VITRO_ASSAY: "in_vitro_done",
    ActionType.IN_VIVO_MODEL: "in_vivo_done",
    ActionType.CRISPR_KNOCKOUT: "crispr_done",
    ActionType.BIOMARKER_CORRELATION: "biomarker_correlated",
    # Review and reporting.
    ActionType.REQUEST_EXPERT_REVIEW: "expert_reviewed",
    ActionType.SUBMIT_VALIDATION_REPORT: "report_submitted",
}
@dataclass
class TransitionResult:
    """Bundle returned by the transition engine after one step.

    This must be a dataclass. The engine builds it with keyword arguments,
    and the container fields use ``field(default_factory=...)`` so every
    result gets its own fresh dict/list. Without ``@dataclass`` the class
    accepts no constructor arguments, and ``field(...)`` objects would leak
    through as class attributes.
    """

    # Latent state after the action has been applied.
    next_state: FullLatentState
    # Simulated output for the step (a failure report when the action was blocked).
    output: IntermediateOutput
    # Per-component reward contributions; left empty by the transition engine.
    reward_components: Dict[str, float] = field(default_factory=dict)
    # Violations that blocked the action (empty when it ran).
    hard_violations: List[str] = field(default_factory=list)
    # Violations that let the action run but degraded its output.
    soft_violations: List[str] = field(default_factory=list)
    # True when the episode has ended.
    done: bool = False
class TransitionEngine:
    """Apply one action to the latent state.

    Produces the next state and a simulated intermediate output, delegating
    output generation to ``OutputGenerator``. The incoming state is
    deep-copied, so the caller's state object is never mutated.
    """

    #: Multiplier applied to ``output.quality_score`` when soft violations are
    #: present (the result is clamped at 0.0). Override on a subclass or an
    #: instance to tune the penalty.
    soft_violation_quality_factor: float = 0.7

    def __init__(self, noise: NoiseModel):
        # The same noise model is shared with the output generator.
        self.noise = noise
        self.output_gen = OutputGenerator(noise)

    def step(
        self,
        state: FullLatentState,
        action: DrugTargetAction,
        *,
        hard_violations: Optional[List[str]] = None,
        soft_violations: Optional[List[str]] = None,
    ) -> TransitionResult:
        """Advance the simulation by one action.

        Args:
            state: Current latent state. It is deep-copied and never mutated.
            action: The agent's action.
            hard_violations: Messages for violations that block the action.
                When non-empty, no call count is recorded, no credits are
                charged, and a ``FAILURE_REPORT`` output is returned.
            soft_violations: Messages for violations that let the action
                run. They scale the output's ``quality_score`` by
                ``soft_violation_quality_factor`` (clamped at 0.0) and are
                appended to ``output.warnings``.

        Returns:
            A ``TransitionResult`` holding the next state, the output, fresh
            copies of the violation lists, and ``done``. ``done`` is True
            when the action is ``SUBMIT_VALIDATION_REPORT``, or when it ran
            and left the credit budget exhausted.
        """
        s = deepcopy(state)
        # 1-based step index derived from the actions recorded so far.
        # Blocked actions are not recorded, so they do not advance it.
        step_idx = sum(s.action_call_counts.values()) + 1
        # Copy so the result never aliases lists owned by the caller.
        hard_v = list(hard_violations or [])
        soft_v = list(soft_violations or [])

        if hard_v:
            output = IntermediateOutput(
                output_type=OutputType.FAILURE_REPORT,
                step_index=step_idx,
                success=False,
                summary=f"Action blocked: {'; '.join(hard_v)}",
            )
            return TransitionResult(
                next_state=s,
                output=output,
                hard_violations=hard_v,
                soft_violations=soft_v,
                done=action.action_type == ActionType.SUBMIT_VALIDATION_REPORT,
            )

        # Record the call before generating output, so the rule engine can
        # reason about redundancy on the next step.
        key = action.action_type.value
        s.action_call_counts[key] = s.action_call_counts.get(key, 0) + 1

        # Charge credits. The action still runs and produces its normal
        # output even if this charge exhausts the budget; exhaustion only
        # ends the episode (see ``done`` below).
        s.credits.credits_used += compute_action_cost(action)
        # NOTE(review): ``exhausted`` is read as an attribute/property here;
        # if it is ever a plain method this would always be truthy — confirm.
        credits_exhausted_after = s.credits.exhausted

        # The generator sees the post-charge state.
        output = self.output_gen.generate(action, s, step_idx)
        if soft_v:
            output.quality_score = float(
                max(0.0, output.quality_score * self.soft_violation_quality_factor)
            )
            output.warnings = list(output.warnings) + soft_v

        # Only successful actions mark progress.
        flag = _PROGRESS_MAP.get(action.action_type)
        if flag and output.success:
            setattr(s.progress, flag, True)

        done = (
            action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
            or credits_exhausted_after
        )
        return TransitionResult(
            next_state=s,
            output=output,
            hard_violations=hard_v,
            soft_violations=soft_v,
            done=done,
        )
def covered_evidence_dimensions(s: FullLatentState) -> List[str]:
    """Return, as a list, the evidence dimensions the agent has covered.

    A dimension counts as covered when its corresponding boolean flag on
    ``s.progress`` is truthy. The dimension names mirror the keys used in
    ``TargetProfile.key_evidence_dimensions``, so the reward computer can
    compute coverage directly. The output order is the fixed order of the
    table below.
    """
    # (dimension name, attribute on ``s.progress`` that marks it covered)
    dimension_flags: Tuple[Tuple[str, str], ...] = (
        ("expression", "expression_queried"),
        ("druggability", "druggability_assessed"),
        ("off_target", "selectivity_checked"),
        ("toxicity", "toxicity_assessed"),
        ("clinical", "clinical_checked"),
        ("literature", "literature_reviewed"),
        ("in_vitro", "in_vitro_done"),
        ("in_vivo", "in_vivo_done"),
        ("patient_stratification", "patient_stratification_done"),
        ("pathway", "pathway_analysed"),
        ("structure", "structure_resolved"),
        ("interactions", "interactions_mapped"),
        ("crispr", "crispr_done"),
        ("biomarker", "biomarker_correlated"),
    )
    progress = s.progress
    return [
        dimension
        for dimension, attr_name in dimension_flags
        if getattr(progress, attr_name)
    ]