# drugenv-trainer / server/simulator/output_generator.py
# Author: anugrahteesdollar
# Commit e681925 (verified): "initial: drugenv trainer control panel"
"""Generate simulated drug-target-validation outputs from latent state."""
from __future__ import annotations
from typing import Any, Dict, List
from models import (
ActionType,
DrugTargetAction,
IntermediateOutput,
OutputType,
)
from .latent_state import FullLatentState, TargetProfile
from .noise import NoiseModel
# Pool of plausible adverse-event tissues used to inject realistic
# false-positive toxicity hits.
_NOISE_TISSUES: List[str] = [
"liver", "kidney", "GI", "skin", "cardiac", "CNS", "lung",
]
class OutputGenerator:
    """Creates structured ``IntermediateOutput`` objects from the hidden
    ``TargetProfile`` plus a stochastic noise model.

    Every action has a dedicated handler that:

    - reads relevant fields from the ``TargetProfile``
    - applies ``DataQualityState``-driven noise (measurement noise level,
      false-positive rate, database coverage)
    - returns a typed ``IntermediateOutput`` whose ``data`` dict is the
      evidence the agent reasons over.

    Handlers are plain functions registered in the module-level
    ``_HANDLERS`` table and are called as
    ``handler(self, action, state, step_index)``; action types without an
    entry fall back to ``_default``, which returns a failure report.
    """

    def __init__(self, noise: NoiseModel):
        """Keep the noise model shared by every handler.

        Args:
            noise: Randomness source; handlers use ``noise.rng`` and the
                ``coin_flip`` / ``sample_count`` / ``sample_qc_metric``
                helpers.
        """
        self.noise = noise

    def generate(
        self,
        action: DrugTargetAction,
        state: FullLatentState,
        step_index: int,
    ) -> IntermediateOutput:
        """Simulate the output of ``action`` against the hidden ``state``.

        Args:
            action: The agent's action; ``action.action_type`` selects the
                handler.
            state: Full latent state (target profile, data quality,
                progress flags).
            step_index: Step index, copied onto the output.

        Returns:
            The handler's output. When ``database_coverage`` is below 1.0,
            its ``quality_score`` is multiplied by ``0.5 + 0.5 * coverage``
            and floored at 0.0.
        """
        # ``_HANDLERS`` stores unbound functions, so handlers are called with
        # ``self`` passed explicitly. The fallback must be unbound as well: a
        # bound ``self._default`` would receive ``self`` twice and raise
        # TypeError instead of producing a failure report.
        handler = _HANDLERS.get(action.action_type, OutputGenerator._default)
        out = handler(self, action, state, step_index)
        # Database coverage globally reduces quality_score for under-curated
        # targets.
        coverage = state.data_quality.database_coverage
        if coverage < 1.0:
            out.quality_score = float(
                max(0.0, out.quality_score * (0.5 + 0.5 * coverage))
            )
        return out

    # ── Helpers ─────────────────────────────────────────────────────────

    @staticmethod
    def _count_evidence_flags(s: FullLatentState) -> int:
        """Return how many boolean fields of ``s.progress`` are ``True``.

        Used by evidence synthesis and expert review, whose quality grows
        with the number of evidence dimensions already opened. Non-boolean
        progress fields are ignored.
        """
        return sum(
            1
            for value in s.progress.model_dump().values()
            if isinstance(value, bool) and value
        )

    # ── Expression & omics ──────────────────────────────────────────────

    def _query_expression(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Simulate an expression-database lookup (default database: GTEx).

        Observed tissue specificity and disease overexpression are the latent
        values plus Gaussian noise scaled by ``noise_level``. A false-positive
        coin flip lowers ``quality_score`` from 0.85 to 0.55.
        """
        t = s.target
        dq = s.data_quality
        database = action.parameters.get("database", "GTEx")
        flipped = self.noise.coin_flip(dq.false_positive_rate)
        observed_specificity = float(
            max(0.0, min(1.0, t.tissue_specificity
                         + self.noise.rng.normal(0, dq.noise_level)))
        )
        observed_overexpr = float(
            max(0.1, t.disease_overexpression
                + self.noise.rng.normal(0, 0.4 * dq.noise_level))
        )
        specificity_concern = t.expression_level == "high_nonspecific"
        # Soft summary that *can* mislead when expression is high but
        # non-specific: only the data fields reveal the specificity concern.
        if t.expression_level in {"high_specific", "high_nonspecific"}:
            summary = (
                f"{database}: {t.expression_level} expression "
                f"({observed_overexpr:.2f}× over normal)"
            )
        else:
            summary = f"{database}: {t.expression_level} expression"
        return IntermediateOutput(
            output_type=OutputType.EXPRESSION_RESULT,
            step_index=idx,
            quality_score=0.55 if flipped else 0.85,
            summary=summary,
            data={
                "expression_level": t.expression_level,
                "tissue_specificity": round(observed_specificity, 3),
                "disease_overexpression": round(observed_overexpr, 2),
                "specificity_concern": specificity_concern,
                "database": database,
            },
            uncertainty=0.10 + 0.5 * dq.noise_level,
            artifacts_available=["expression_table"],
        )

    def _differential_expression(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Simulate a differential-expression call (default cohort: TCGA)."""
        t = s.target
        # Targets without disease overexpression (< 1.0) centre on log2FC 0;
        # otherwise the mean scales with overexpression, floored at 0.5.
        mean_log2fc = (
            0.0 if t.disease_overexpression < 1.0
            else max(0.5, 1.5 * (t.disease_overexpression - 1.0))
        )
        log2fc = float(self.noise.rng.normal(
            mean_log2fc, 0.4 + s.data_quality.noise_level,
        ))
        # int() for consistency with the other handlers' sample_count casts.
        n_de_genes = int(self.noise.sample_count(
            40 + int(20 * t.disease_overexpression)
        ))
        cohort = action.parameters.get("cohort", "TCGA")
        target_label = getattr(t, "target", "")
        return IntermediateOutput(
            output_type=OutputType.DE_RESULT,
            step_index=idx,
            quality_score=0.80,
            summary=(
                f"DE in {cohort}: "
                f"{target_label} log2FC≈{log2fc:.2f}, "
                f"{n_de_genes} co-regulated genes"
            ),
            data={
                "target_log2fc": round(log2fc, 3),
                "n_de_genes": n_de_genes,
                "cohort": cohort,
            },
            uncertainty=0.15 + s.data_quality.noise_level,
            artifacts_available=["de_table"],
        )

    def _pathway_enrichment(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Return four fixed pathways with noisy enrichment scores."""
        # Pathway calls are largely driven by indication-level priors; the
        # target profile is not consulted.
        base_scores = (
            ("MAPK_signalling", 0.6),
            ("Cell_cycle", 0.55),
            ("Apoptosis", 0.45),
            ("DNA_damage_response", 0.40),
        )
        pathways = [
            {"pathway": name,
             "score": round(base + self.noise.rng.normal(0, 0.1), 3)}
            for name, base in base_scores
        ]
        return IntermediateOutput(
            output_type=OutputType.PATHWAY_RESULT,
            step_index=idx,
            quality_score=0.70,
            summary=f"Pathway enrichment: {len(pathways)} top pathways",
            data={"top_pathways": pathways},
            uncertainty=0.20,
            artifacts_available=["enrichment_table"],
        )

    def _coexpression_network(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """List up to five off-target genes plus two generic partners."""
        partners = list(s.target.off_target_genes[:5]) + [
            f"PARTNER_{i}" for i in range(2)
        ]
        return IntermediateOutput(
            output_type=OutputType.COEXPRESSION_RESULT,
            step_index=idx,
            quality_score=0.65,
            summary=f"{len(partners)} top coexpression partners identified",
            data={"partners": partners},
            uncertainty=0.25,
            artifacts_available=["coexpression_table"],
        )

    # ── Protein & structure ─────────────────────────────────────────────

    def _protein_structure_lookup(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Simulate a structure lookup; quality is the sampled pLDDT."""
        method = action.parameters.get("method", "AlphaFold")
        plddt = float(self.noise.sample_qc_metric(0.78, 0.08, 0.30, 1.0))
        return IntermediateOutput(
            output_type=OutputType.STRUCTURE_RESULT,
            step_index=idx,
            quality_score=plddt,
            summary=f"{method} structure resolved (pLDDT={plddt:.2f})",
            data={
                "method": method,
                "pLDDT": round(plddt, 3),
                "n_residues": int(self.noise.sample_count(420)),
            },
            uncertainty=1.0 - plddt,
            artifacts_available=["pdb_structure"],
        )

    def _binding_site_analysis(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Score the classic pocket and, if requested, an allosteric site.

        The allosteric site is only reported when the action sets
        ``include_allosteric`` AND the target has one.
        """
        t = s.target
        include_allosteric = bool(action.parameters.get("include_allosteric", False))
        classic_mean = {
            "excellent": 0.92,
            "good": 0.70,
            "poor": 0.32,
            "undruggable": 0.10,
        }[t.binding_pocket_quality]
        classic_score = float(self.noise.sample_qc_metric(
            classic_mean, 0.05, 0.0, 1.0
        ))
        allo_detected = bool(include_allosteric and t.allosteric_site_available)
        allo_score = (
            float(self.noise.sample_qc_metric(0.65, 0.08, 0.0, 1.0))
            if allo_detected else 0.0
        )
        return IntermediateOutput(
            output_type=OutputType.BINDING_SITE_RESULT,
            step_index=idx,
            quality_score=max(classic_score, allo_score),
            summary=(
                f"Binding-site analysis: classic_score={classic_score:.2f}"
                + (f", allosteric_site_score={allo_score:.2f}" if allo_detected else "")
            ),
            data={
                "binding_pocket_quality": t.binding_pocket_quality,
                "classic_score": round(classic_score, 3),
                "allosteric_site_detected": allo_detected,
                "allosteric_site_score": round(allo_score, 3) if allo_detected else None,
                "include_allosteric": include_allosteric,
            },
            uncertainty=0.12,
            artifacts_available=["pocket_table"],
        )

    def _protein_interaction_network(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Report up to six off-target genes as interactors (default: STRING)."""
        partners = list(s.target.off_target_genes[:6])
        return IntermediateOutput(
            output_type=OutputType.INTERACTION_RESULT,
            step_index=idx,
            quality_score=0.70,
            summary=f"{len(partners)} high-confidence interactors",
            data={
                "partners": partners,
                "source": action.parameters.get("source", "STRING"),
            },
            uncertainty=0.20,
            artifacts_available=["ppi_network"],
        )

    def _druggability_screen(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Report a noisy druggability score plus pocket / ligand facts."""
        t = s.target
        observed_score = float(self.noise.sample_qc_metric(
            t.druggability_score, 0.06, 0.0, 1.0
        ))
        return IntermediateOutput(
            output_type=OutputType.DRUGGABILITY_RESULT,
            step_index=idx,
            quality_score=0.85,
            summary=(
                f"Druggability score={observed_score:.2f}, "
                f"pocket={t.binding_pocket_quality}, "
                f"known_ligands={t.has_known_ligands}"
            ),
            data={
                "druggability_score": round(observed_score, 3),
                "binding_pocket_quality": t.binding_pocket_quality,
                "has_known_ligands": t.has_known_ligands,
                "n_known_ligands": int(self.noise.sample_count(
                    20 if t.has_known_ligands else 1
                )),
            },
            uncertainty=0.15,
            artifacts_available=["druggability_report"],
        )

    # ── Clinical & safety ───────────────────────────────────────────────

    def _clinical_trial_lookup(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Summarise clinical precedent as positive / negative signal lists."""
        t = s.target
        positive_signals: List[str] = []
        negative_signals: List[str] = []
        if t.clinical_precedent in {"positive", "mixed"}:
            positive_signals.append(
                f"Reached {t.clinical_stage_reached or 'preclinical'} with at "
                f"least one program"
            )
        if t.clinical_precedent in {"mixed", "negative"}:
            negative_signals.append("Prior failures or withdrawals on record")
        if t.clinical_precedent == "negative":
            negative_signals.append("No active programs progressing")
        return IntermediateOutput(
            output_type=OutputType.CLINICAL_RESULT,
            step_index=idx,
            quality_score=0.85,
            summary=(
                f"Clinical precedent: {t.clinical_precedent} "
                f"(stage={t.clinical_stage_reached})"
            ),
            data={
                "clinical_precedent": t.clinical_precedent,
                "clinical_stage_reached": t.clinical_stage_reached,
                "positive_signals": positive_signals,
                "negative_signals": negative_signals,
                "competitor_programs": list(t.competitor_programs),
            },
            uncertainty=0.10,
            artifacts_available=["trial_table"],
        )

    def _toxicity_panel(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Simulate a toxicity panel, possibly with one false-positive tissue.

        If ``progress.expression_queried`` is not yet set, quality drops
        (0.80 -> 0.55), uncertainty rises (0.15 -> 0.45) and a caution
        warning is attached.
        """
        t = s.target
        # Higher uncertainty if the agent jumps to toxicity before expression.
        prereq_met = s.progress.expression_queried
        toxicity_tissues = list(t.toxicity_tissues)
        # False-positive noise: flag one extra tissue that is not already on
        # the list, so the injected hit never just duplicates a real one.
        if self.noise.coin_flip(s.data_quality.false_positive_rate):
            candidates = [
                tissue for tissue in _NOISE_TISSUES
                if tissue not in toxicity_tissues
            ]
            if candidates:
                toxicity_tissues.append(str(self.noise.rng.choice(candidates)))
        return IntermediateOutput(
            output_type=OutputType.TOXICITY_RESULT,
            step_index=idx,
            quality_score=0.80 if prereq_met else 0.55,
            summary=(
                f"Toxicity profile: {t.toxicity_profile}, "
                f"flagged tissues: {toxicity_tissues}"
            ),
            data={
                "toxicity_profile": t.toxicity_profile,
                "toxicity_tissues": toxicity_tissues,
                "prerequisite_expression_done": prereq_met,
            },
            uncertainty=0.15 if prereq_met else 0.45,
            warnings=[] if prereq_met else [
                "Toxicity called without prior expression context — "
                "interpret with caution"
            ],
            artifacts_available=["toxicity_panel_report"],
        )

    def _off_target_screen(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Report a noisy off-target hit count, gene list and selectivity."""
        t = s.target
        # ``or 1``: a target with zero true off-targets is still sampled
        # around 1, presumably to allow spurious hits.
        observed_count = max(0, int(self.noise.sample_count(t.off_target_count or 1)))
        # List at most ``observed_count`` genes so a count of 0 never comes
        # with a non-empty gene list.
        observed_genes = list(t.off_target_genes[:observed_count])
        observed_ratio = float(self.noise.sample_qc_metric(
            t.selectivity_ratio, 0.5, 0.0, 100.0
        ))
        return IntermediateOutput(
            output_type=OutputType.OFF_TARGET_RESULT,
            step_index=idx,
            quality_score=0.80,
            summary=(
                f"Off-target screen: selectivity ratio={observed_ratio:.2f}, "
                f"{len(observed_genes)} hits"
            ),
            data={
                "selectivity_ratio": round(observed_ratio, 3),
                "off_target_count": observed_count,
                "off_target_genes": observed_genes,
            },
            uncertainty=0.15,
            artifacts_available=["off_target_table"],
        )

    def _patient_stratification(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Report stratification need, biomarker and a responder fraction.

        The responder fraction is sampled around 0.30 when stratification is
        required and around 0.65 otherwise.
        """
        t = s.target
        responder_fraction = float(self.noise.sample_qc_metric(
            0.30 if t.requires_patient_stratification else 0.65,
            0.10, 0.0, 1.0,
        ))
        return IntermediateOutput(
            output_type=OutputType.PATIENT_STRATIFICATION_RESULT,
            step_index=idx,
            quality_score=0.78,
            summary=(
                f"Patient stratification: required={t.requires_patient_stratification}, "
                f"biomarker={t.responder_biomarker}"
            ),
            data={
                "requires_stratification": t.requires_patient_stratification,
                "responder_biomarker": t.responder_biomarker,
                "estimated_responder_fraction": round(responder_fraction, 3),
            },
            uncertainty=0.20,
            artifacts_available=["stratification_report"],
        )

    # ── Literature & evidence ───────────────────────────────────────────

    def _literature_search(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Return generic abstracts, plus a precedent-changing one if relevant.

        At most five generic abstracts are produced. A positive/mixed
        precedent that reached phase 2 or 3 prepends an extra abstract.
        """
        t = s.target
        query_label = action.parameters.get("query", "target")
        n_abstracts = int(self.noise.sample_count(4)) + 3
        abstracts: List[Dict[str, Any]] = [
            {
                "title": f"Recent perspective on {query_label} ({2020 + i % 6})",
                "snippet": "...findings consistent with a viable program...",
            }
            for i in range(min(5, n_abstracts))
        ]
        # Scenario-specific recent precedent: surface a precedent-changing
        # abstract when the target has positive/mixed clinical precedent that
        # reached at least phase 2.
        if (
            t.clinical_precedent in {"positive", "mixed"}
            and t.clinical_stage_reached in {"phase2", "phase3"}
        ):
            abstracts.insert(0, {
                "title": (
                    "Clinical activity of recent inhibitors against this "
                    "target supports renewed interest"
                ),
                "snippet": (
                    "...recent programs have demonstrated clinical activity, "
                    "overturning prior assumptions of undruggability..."
                ),
            })
        return IntermediateOutput(
            output_type=OutputType.LITERATURE_RESULT,
            step_index=idx,
            quality_score=0.70,
            summary=f"{len(abstracts)} relevant abstracts retrieved",
            data={
                "abstracts": abstracts,
                "query": action.parameters.get("query", ""),
            },
            uncertainty=0.18,
            artifacts_available=["abstract_list"],
        )

    def _evidence_synthesis(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Synthesis whose quality grows with evidence dimensions covered.

        quality = min(0.85, 0.20 + 0.06 * covered);
        uncertainty = max(0.20, 0.80 - 0.06 * covered).
        """
        covered = self._count_evidence_flags(s)
        quality = float(min(0.85, 0.20 + 0.06 * covered))
        return IntermediateOutput(
            output_type=OutputType.EVIDENCE_SYNTHESIS_RESULT,
            step_index=idx,
            quality_score=quality,
            summary=f"Evidence synthesis (coverage signal={covered})",
            data={
                "evidence_signal_count": covered,
                "notes": (
                    "Synthesis is more reliable once multiple evidence "
                    "dimensions have been investigated."
                ),
            },
            uncertainty=max(0.20, 0.80 - 0.06 * covered),
            artifacts_available=["synthesis_report"],
        )

    def _competitor_landscape(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """List competitor programs alongside the clinical precedent."""
        t = s.target
        return IntermediateOutput(
            output_type=OutputType.COMPETITOR_LANDSCAPE_RESULT,
            step_index=idx,
            quality_score=0.75,
            summary=f"{len(t.competitor_programs)} competitor programs identified",
            data={
                "competitor_programs": list(t.competitor_programs),
                "clinical_precedent": t.clinical_precedent,
            },
            uncertainty=0.15,
            artifacts_available=["competitor_report"],
        )

    # ── Experimental ───────────────────────────────────────────────────

    def _in_vitro_assay(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Simulate IC50 (20% relative spread), selectivity and viability drop."""
        t = s.target
        ic50 = float(self.noise.sample_qc_metric(
            t.in_vitro_ic50_nM, 0.2 * t.in_vitro_ic50_nM, 0.5, 100_000.0
        ))
        sel_window = float(self.noise.sample_qc_metric(
            t.selectivity_ratio, 0.4, 0.0, 100.0
        ))
        viability_drop = float(self.noise.sample_qc_metric(
            0.5 if t.in_vivo_efficacy in {"strong", "moderate"} else 0.2,
            0.1, 0.0, 1.0,
        ))
        return IntermediateOutput(
            output_type=OutputType.IN_VITRO_RESULT,
            step_index=idx,
            quality_score=0.85,
            summary=(
                f"In-vitro: IC50={ic50:.1f} nM, selectivity_window={sel_window:.2f}, "
                f"viability_drop={viability_drop:.2f}"
            ),
            data={
                "IC50_nM": round(ic50, 2),
                "selectivity_window": round(sel_window, 3),
                "viability_drop": round(viability_drop, 3),
            },
            uncertainty=0.18,
            artifacts_available=["in_vitro_report"],
        )

    def _in_vivo_model(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Simulate in-vivo efficacy, tolerability and a PK/PD summary.

        Unknown efficacy / toxicity labels fall back to means of 0.5 / 0.6.
        """
        t = s.target
        efficacy_mean = {
            "strong": 0.85, "moderate": 0.55, "weak": 0.25, "none": 0.05,
        }.get(t.in_vivo_efficacy, 0.5)
        efficacy = float(self.noise.sample_qc_metric(efficacy_mean, 0.08, 0.0, 1.0))
        tolerability = float(self.noise.sample_qc_metric(
            {"clean": 0.9, "mild": 0.75, "moderate": 0.5, "severe": 0.25}
            .get(t.toxicity_profile, 0.6),
            0.08, 0.0, 1.0,
        ))
        return IntermediateOutput(
            output_type=OutputType.IN_VIVO_RESULT,
            step_index=idx,
            quality_score=0.85,
            summary=(
                f"In-vivo: efficacy={efficacy:.2f}, tolerability={tolerability:.2f}"
            ),
            data={
                "efficacy_endpoint": round(efficacy, 3),
                "tolerability": round(tolerability, 3),
                "PK_PD_summary": {
                    "halflife_hours": round(float(
                        self.noise.sample_qc_metric(8.0, 2.0, 0.5, 48.0)
                    ), 2),
                    "Cmax_nM": round(float(
                        self.noise.sample_qc_metric(500.0, 150.0, 1.0, 5000.0)
                    ), 2),
                },
            },
            uncertainty=0.20,
            artifacts_available=["in_vivo_report"],
        )

    def _crispr_knockout(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Report a noisy essentiality score; the first three off-target
        genes are reported as synthetic-lethal candidates."""
        t = s.target
        ess = float(self.noise.sample_qc_metric(
            t.crispr_essentiality, 0.15, -3.0, 1.0
        ))
        synthetic_lethal = list(t.off_target_genes[:3])
        return IntermediateOutput(
            output_type=OutputType.CRISPR_RESULT,
            step_index=idx,
            quality_score=0.80,
            summary=(
                f"CRISPR essentiality score={ess:.2f}; "
                f"{len(synthetic_lethal)} synthetic-lethal candidates"
            ),
            data={
                "essentiality_score": round(ess, 3),
                "synthetic_lethal_partners": synthetic_lethal,
            },
            uncertainty=0.18,
            artifacts_available=["crispr_report"],
        )

    def _biomarker_correlation(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Correlation sampled around 0.6 with a responder biomarker, else 0.2."""
        t = s.target
        corr = float(self.noise.sample_qc_metric(
            0.6 if t.responder_biomarker else 0.2, 0.12, -1.0, 1.0,
        ))
        return IntermediateOutput(
            output_type=OutputType.BIOMARKER_RESULT,
            step_index=idx,
            quality_score=0.78,
            summary=(
                f"Biomarker correlation r={corr:.2f} "
                f"({t.responder_biomarker or 'no_biomarker'})"
            ),
            data={
                "biomarker": t.responder_biomarker,
                "correlation": round(corr, 3),
            },
            uncertainty=0.22,
            artifacts_available=["biomarker_report"],
        )

    # ── Meta ────────────────────────────────────────────────────────────

    def _flag_red_flag(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Record the agent's red-flag note (summary truncated to 80 chars)."""
        note = str(action.parameters.get("note", "(no detail)"))
        return IntermediateOutput(
            output_type=OutputType.RED_FLAG_NOTE,
            step_index=idx,
            quality_score=1.0,
            summary=f"Red flag recorded: {note[:80]}",
            data={"note": note},
            uncertainty=0.0,
            artifacts_available=["dossier_red_flag"],
        )

    def _request_expert_review(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Expert review whose quality grows with evidence dimensions covered.

        quality = min(0.75, 0.20 + 0.05 * covered);
        uncertainty = max(0.25, 0.80 - 0.05 * covered).
        """
        covered = self._count_evidence_flags(s)
        quality = float(min(0.75, 0.20 + 0.05 * covered))
        return IntermediateOutput(
            output_type=OutputType.EXPERT_REVIEW,
            step_index=idx,
            quality_score=quality,
            summary=f"Expert review (coverage signal={covered})",
            data={
                "evidence_signal_count": covered,
                "review": (
                    "Review more meaningful when more evidence dimensions "
                    "have been opened."
                ),
            },
            uncertainty=max(0.25, 0.80 - 0.05 * covered),
            artifacts_available=["expert_review_note"],
        )

    def _submit_validation_report(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Echo the agent's final decision, confidence and reasoning.

        Missing values default to ``"no_decision"``, 0.0 and ``""``.
        """
        decision = action.final_decision or "no_decision"
        confidence = float(action.confidence) if action.confidence is not None else 0.0
        return IntermediateOutput(
            output_type=OutputType.VALIDATION_REPORT,
            step_index=idx,
            quality_score=1.0,
            summary=(
                f"Validation report submitted: decision={decision}, "
                f"confidence={confidence:.2f}"
            ),
            data={
                "decision": decision,
                "confidence": confidence,
                "reasoning": action.reasoning or "",
            },
            uncertainty=0.0,
            artifacts_available=["validation_report"],
        )

    # ── Default ────────────────────────────────────────────────────────

    def _default(
        self, action: DrugTargetAction, s: FullLatentState, idx: int
    ) -> IntermediateOutput:
        """Return a failure report for action types with no handler."""
        return IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=idx,
            success=False,
            # Explicit 0.0: a failed action carries no evidence, and
            # ``generate`` multiplies quality_score by the coverage factor.
            quality_score=0.0,
            summary=f"Unhandled action type: {action.action_type}",
            data={},
        )
# Dispatch table: ActionType -> unbound OutputGenerator handler. Entries are
# plain functions, so ``OutputGenerator.generate`` invokes them as
# ``handler(self, action, state, step_index)``.
_HANDLERS: Dict[ActionType, Any] = {
    # Expression & omics
    ActionType.QUERY_EXPRESSION: OutputGenerator._query_expression,
    ActionType.DIFFERENTIAL_EXPRESSION: OutputGenerator._differential_expression,
    ActionType.PATHWAY_ENRICHMENT: OutputGenerator._pathway_enrichment,
    ActionType.COEXPRESSION_NETWORK: OutputGenerator._coexpression_network,
    # Protein & structure
    ActionType.PROTEIN_STRUCTURE_LOOKUP: OutputGenerator._protein_structure_lookup,
    ActionType.BINDING_SITE_ANALYSIS: OutputGenerator._binding_site_analysis,
    ActionType.PROTEIN_INTERACTION_NETWORK: OutputGenerator._protein_interaction_network,
    ActionType.DRUGGABILITY_SCREEN: OutputGenerator._druggability_screen,
    # Clinical & safety
    ActionType.CLINICAL_TRIAL_LOOKUP: OutputGenerator._clinical_trial_lookup,
    ActionType.TOXICITY_PANEL: OutputGenerator._toxicity_panel,
    ActionType.OFF_TARGET_SCREEN: OutputGenerator._off_target_screen,
    ActionType.PATIENT_STRATIFICATION: OutputGenerator._patient_stratification,
    # Literature & evidence
    ActionType.LITERATURE_SEARCH: OutputGenerator._literature_search,
    ActionType.EVIDENCE_SYNTHESIS: OutputGenerator._evidence_synthesis,
    ActionType.COMPETITOR_LANDSCAPE: OutputGenerator._competitor_landscape,
    # Experimental
    ActionType.IN_VITRO_ASSAY: OutputGenerator._in_vitro_assay,
    ActionType.IN_VIVO_MODEL: OutputGenerator._in_vivo_model,
    ActionType.CRISPR_KNOCKOUT: OutputGenerator._crispr_knockout,
    ActionType.BIOMARKER_CORRELATION: OutputGenerator._biomarker_correlation,
    # Meta
    ActionType.FLAG_RED_FLAG: OutputGenerator._flag_red_flag,
    ActionType.REQUEST_EXPERT_REVIEW: OutputGenerator._request_expert_review,
    ActionType.SUBMIT_VALIDATION_REPORT: OutputGenerator._submit_validation_report,
}