# drugenv-trainer/server/simulator/transition.py
# anugrahteesdollar's picture
# initial: drugenv trainer control panel
# e681925 verified
"""Transition dynamics engine for the drug-target-validation simulator.
Orchestrates latent-state updates, output generation, credit accounting,
and constraint propagation for every agent action.
"""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from models import (
ActionType,
DrugTargetAction,
IntermediateOutput,
OutputType,
)
from .latent_state import FullLatentState
from .noise import NoiseModel
from .output_generator import OutputGenerator
# Credit costs per ActionType.
# This is the table compute_action_cost() looks prices up in.  Values here
# range from 0 (FLAG_RED_FLAG, SUBMIT_VALIDATION_REPORT) to 8 (IN_VIVO_MODEL).
_BASE_ACTION_COSTS: Dict[ActionType, int] = {
    ActionType.QUERY_EXPRESSION: 2,
    ActionType.DIFFERENTIAL_EXPRESSION: 2,
    ActionType.PATHWAY_ENRICHMENT: 2,
    ActionType.COEXPRESSION_NETWORK: 2,
    ActionType.PROTEIN_STRUCTURE_LOOKUP: 3,
    ActionType.BINDING_SITE_ANALYSIS: 3,
    ActionType.PROTEIN_INTERACTION_NETWORK: 2,
    ActionType.DRUGGABILITY_SCREEN: 3,
    ActionType.CLINICAL_TRIAL_LOOKUP: 3,
    ActionType.TOXICITY_PANEL: 3,
    ActionType.OFF_TARGET_SCREEN: 3,
    ActionType.PATIENT_STRATIFICATION: 3,
    ActionType.LITERATURE_SEARCH: 1,
    ActionType.EVIDENCE_SYNTHESIS: 1,
    ActionType.COMPETITOR_LANDSCAPE: 1,
    ActionType.IN_VITRO_ASSAY: 5,
    ActionType.IN_VIVO_MODEL: 8,
    ActionType.CRISPR_KNOCKOUT: 4,
    ActionType.BIOMARKER_CORRELATION: 3,
    ActionType.FLAG_RED_FLAG: 0,
    ActionType.REQUEST_EXPERT_REVIEW: 1,
    ActionType.SUBMIT_VALIDATION_REPORT: 0,
}
# Public alias kept for callers that historically imported ACTION_COSTS.
# NOTE: the alias binds the very same dict object (no copy is made), so a
# mutation through either name is visible through both.
ACTION_COSTS = _BASE_ACTION_COSTS
def compute_action_cost(action: DrugTargetAction) -> int:
    """Look up how many credits *action* consumes.

    The price is keyed on ``action.action_type`` in ``_BASE_ACTION_COSTS``.
    An action type that is missing from that table is treated as free and
    yields ``0``.
    """
    kind = action.action_type
    if kind in _BASE_ACTION_COSTS:
        return _BASE_ACTION_COSTS[kind]
    # Unlisted action types cost nothing.
    return 0
# Map action type → progress flag that should be set when it succeeds.
# The values are attribute names on the state's ``progress`` object; the
# engine sets the named attribute to True via setattr().  Several action
# types share one flag (e.g. QUERY_EXPRESSION and DIFFERENTIAL_EXPRESSION
# both map to "expression_queried").  FLAG_RED_FLAG has no entry here, so it
# never sets a progress flag.
_PROGRESS_MAP: Dict[ActionType, str] = {
    ActionType.QUERY_EXPRESSION: "expression_queried",
    ActionType.DIFFERENTIAL_EXPRESSION: "expression_queried",
    ActionType.PATHWAY_ENRICHMENT: "pathway_analysed",
    ActionType.COEXPRESSION_NETWORK: "interactions_mapped",
    ActionType.PROTEIN_STRUCTURE_LOOKUP: "structure_resolved",
    ActionType.BINDING_SITE_ANALYSIS: "druggability_assessed",
    ActionType.PROTEIN_INTERACTION_NETWORK: "interactions_mapped",
    ActionType.DRUGGABILITY_SCREEN: "druggability_assessed",
    ActionType.CLINICAL_TRIAL_LOOKUP: "clinical_checked",
    ActionType.TOXICITY_PANEL: "toxicity_assessed",
    ActionType.OFF_TARGET_SCREEN: "selectivity_checked",
    ActionType.PATIENT_STRATIFICATION: "patient_stratification_done",
    ActionType.LITERATURE_SEARCH: "literature_reviewed",
    ActionType.EVIDENCE_SYNTHESIS: "evidence_synthesised",
    ActionType.COMPETITOR_LANDSCAPE: "literature_reviewed",
    ActionType.IN_VITRO_ASSAY: "in_vitro_done",
    ActionType.IN_VIVO_MODEL: "in_vivo_done",
    ActionType.CRISPR_KNOCKOUT: "crispr_done",
    ActionType.BIOMARKER_CORRELATION: "biomarker_correlated",
    ActionType.REQUEST_EXPERT_REVIEW: "expert_reviewed",
    ActionType.SUBMIT_VALIDATION_REPORT: "report_submitted",
}
@dataclass
class TransitionResult:
    """Outcome of pushing a single action through the transition engine.

    Only ``next_state`` and ``output`` are required.  The remaining fields
    default to empty containers (created per instance through
    ``default_factory``) and ``done`` defaults to ``False``.
    """

    # Latent state produced by the step.
    next_state: FullLatentState
    # Intermediate output generated for the action.
    output: IntermediateOutput
    # Named reward terms for this step; empty unless the creator fills it.
    reward_components: Dict[str, float] = field(default_factory=dict)
    # Messages for violations that blocked the action.
    hard_violations: List[str] = field(default_factory=list)
    # Messages for violations that did not block the action.
    soft_violations: List[str] = field(default_factory=list)
    # True once the episode should end.
    done: bool = False
class TransitionEngine:
    """Advance the simulator by one agent action.

    ``step`` operates on a deep copy of the incoming latent state, so the
    caller's state object is never modified.  Building the simulated
    intermediate output is delegated to an ``OutputGenerator`` constructed
    with this engine's noise model.
    """

    def __init__(self, noise: NoiseModel):
        """Keep *noise* and create the ``OutputGenerator`` that uses it."""
        self.noise = noise
        self.output_gen = OutputGenerator(noise)

    def step(
        self,
        state: FullLatentState,
        action: DrugTargetAction,
        *,
        hard_violations: Optional[List[str]] = None,
        soft_violations: Optional[List[str]] = None,
    ) -> TransitionResult:
        """Apply *action* to a copy of *state* and report what happened.

        Args:
            state: Current latent state; it is deep-copied, not mutated.
            action: The agent action to apply.
            hard_violations: Messages for rules that block the action.
                When non-empty the action is rejected: a FAILURE_REPORT
                output is returned and call counts, credits and progress
                flags are left as they were in *state*.
            soft_violations: Messages for non-blocking issues.  When
                non-empty and the action is executed, the output's quality
                score is multiplied by 0.7 (floored at 0.0) and the
                messages are appended to the output's warnings.

        Returns:
            A ``TransitionResult`` holding the copied (and, for executed
            actions, updated) state, the output and the violation lists.
            ``done`` is true for a SUBMIT_VALIDATION_REPORT action, blocked
            or not, and for any executed action after which
            ``credits.exhausted`` is truthy.
        """
        next_state = deepcopy(state)
        # 1-based index of this step, derived from the calls recorded so far.
        step_number = 1 + sum(next_state.action_call_counts.values())
        blocking = hard_violations or []
        advisories = soft_violations or []

        # Blocked action: report the failure and leave the copy untouched.
        if blocking:
            reasons = "; ".join(blocking)
            return TransitionResult(
                next_state=next_state,
                output=IntermediateOutput(
                    output_type=OutputType.FAILURE_REPORT,
                    step_index=step_number,
                    success=False,
                    summary=f"Action blocked: {reasons}",
                ),
                hard_violations=blocking,
                soft_violations=advisories,
                done=action.action_type == ActionType.SUBMIT_VALIDATION_REPORT,
            )

        # Record the call first; the counts are stored on the returned
        # state, keyed by the action type's value.
        call_counts = next_state.action_call_counts
        call_key = action.action_type.value
        call_counts[call_key] = call_counts.get(call_key, 0) + 1

        # Charge the action's credit cost.
        charge = compute_action_cost(action)
        next_state.credits.credits_used += charge

        # Read the exhaustion flag right after charging.  It only feeds the
        # ``done`` decision below; the generated output is not altered by it.
        out_of_credits = next_state.credits.exhausted

        # Produce the simulated output from the already-charged state.
        produced = self.output_gen.generate(action, next_state, step_number)

        # Soft violations degrade quality and are surfaced as warnings.
        if advisories:
            scaled = produced.quality_score * 0.7
            produced.quality_score = float(max(0.0, scaled))
            produced.warnings = [*produced.warnings, *advisories]

        # A successful action flips its mapped progress flag, if it has one.
        progress_flag = _PROGRESS_MAP.get(action.action_type)
        if progress_flag and produced.success:
            setattr(next_state.progress, progress_flag, True)

        # The episode ends on a submitted report or once credits run out.
        finished = (
            action.action_type == ActionType.SUBMIT_VALIDATION_REPORT
            or out_of_credits
        )
        return TransitionResult(
            next_state=next_state,
            output=produced,
            soft_violations=advisories,
            done=finished,
        )

    @staticmethod
    def covered_evidence_dimensions(s: FullLatentState) -> List[str]:
        """Return the list of *evidence dimensions* the agent has touched.

        A dimension is included when its progress flag on ``s.progress`` is
        truthy; the result keeps the fixed order of the table below.  The
        names are meant to line up with the keys of
        ``TargetProfile.key_evidence_dimensions`` so coverage can be
        computed directly (that type is not visible from this module).
        """
        # (dimension name, attribute on ``s.progress``) in reporting order.
        dimension_flags: Tuple[Tuple[str, str], ...] = (
            ("expression", "expression_queried"),
            ("druggability", "druggability_assessed"),
            ("off_target", "selectivity_checked"),
            ("toxicity", "toxicity_assessed"),
            ("clinical", "clinical_checked"),
            ("literature", "literature_reviewed"),
            ("in_vitro", "in_vitro_done"),
            ("in_vivo", "in_vivo_done"),
            ("patient_stratification", "patient_stratification_done"),
            ("pathway", "pathway_analysed"),
            ("structure", "structure_resolved"),
            ("interactions", "interactions_mapped"),
            ("crispr", "crispr_done"),
            ("biomarker", "biomarker_correlated"),
        )
        progress = s.progress
        return [
            dimension
            for dimension, attr in dimension_flags
            if getattr(progress, attr)
        ]