# File: drugenv/server/tasks/procedural_generator.py
# Repository-viewer header (author: anugrahteesdollar)
# Commit 77e1e28 (verified): "initial: drugenv FastAPI + gradio demo"
"""Procedural drug-target-validation scenario generator.
Composes coherent ``Scenario`` objects by sampling from a pool of real
cancer targets and disease contexts and bundling them with an internally
consistent ``TargetProfile`` (viable vs non-viable bundles).
"""
from __future__ import annotations
import logging
from typing import List, Optional
import numpy as np
from models import ValidationTaskSpec
from server.simulator.latent_state import (
DataQualityState,
TargetProfile,
)
from .scenarios import Scenario
logger = logging.getLogger(__name__)
# Gene symbols from which generate_scenario picks the target under validation.
# The pool is sampled with a seeded rng.choice, so the order of entries is
# significant: reordering, inserting or removing a gene changes which target
# a given seed produces.
_TARGET_POOL: List[str] = [
"BRAF", "MET", "FGFR1", "PIK3CA", "AKT1", "CDK4", "MDM2", "BCL2",
"PARP1", "IDH1", "IDH2", "FLT3", "JAK2", "BTK", "MTOR", "ALK",
"ROS1", "KIT", "ERBB2", "ABL1",
]
# Disease contexts from which generate_scenario picks one per scenario.
# The chosen string is used verbatim in the problem statement and indication,
# and (with spaces replaced by underscores) as a scenario tag.
# Sampled with a seeded rng.choice, so entry order affects seeded output.
_DISEASE_POOL: List[str] = [
"non-small cell lung cancer",
"colorectal cancer",
"melanoma",
"acute myeloid leukemia",
"chronic myeloid leukemia",
"glioblastoma",
"breast cancer",
"ovarian cancer",
]
# Per-difficulty sampling parameters read by generate_scenario.
#
# Two-element tuples are (low, high) bounds that generate_scenario unpacks
# straight into an RNG call:
#   noise_level, false_positive_rate, false_negative_rate, database_coverage
#       -> rng.uniform, used to fill DataQualityState
#   credits_limit, n_key_evidence
#       -> rng.integers
# Scalar entries are probabilities compared against one rng.random() draw:
#   viable_prob     -> chance of building a viable ("go") target profile
#   misleading_prob -> chance of attaching a misleading signal to the target
_DIFFICULTY_PARAMS = {
# Least noise, highest database coverage, never any misleading signal.
"easy": {
"noise_level": (0.05, 0.10),
"false_positive_rate": (0.02, 0.05),
"false_negative_rate": (0.02, 0.05),
"database_coverage": (0.90, 1.0),
"credits_limit": (45, 60),
"viable_prob": 0.65,
"n_key_evidence": (1, 2),
"misleading_prob": 0.0,
},
# Intermediate noise; viable and non-viable targets are equally likely.
"medium": {
"noise_level": (0.08, 0.15),
"false_positive_rate": (0.04, 0.08),
"false_negative_rate": (0.04, 0.08),
"database_coverage": (0.80, 0.95),
"credits_limit": (40, 55),
"viable_prob": 0.50,
"n_key_evidence": (2, 3),
"misleading_prob": 0.20,
},
# Most noise, lowest coverage and credit budget, misleading signal half the time.
"hard": {
"noise_level": (0.12, 0.22),
"false_positive_rate": (0.06, 0.12),
"false_negative_rate": (0.06, 0.12),
"database_coverage": (0.65, 0.90),
"credits_limit": (35, 50),
"viable_prob": 0.45,
"n_key_evidence": (3, 4),
"misleading_prob": 0.50,
},
}
def _build_viable_target(rng: np.random.Generator) -> TargetProfile:
    """Sample a ``TargetProfile`` for a target that should be approved.

    Every field is drawn from the favourable end of its range (known ligands,
    nanomolar IC50 below 100, negative CRISPR essentiality score) and the
    profile is labelled ``correct_decision="go"`` with a
    ``true_viability_score`` drawn from ``uniform(0.65, 0.90)``.

    Args:
        rng: Generator supplying all randomness. Values are drawn in field
            order, so the sequence of draws below must not be rearranged if
            seeded output is to stay stable.

    Returns:
        A freshly constructed ``TargetProfile``.
    """

    def pick(options: List[str]) -> str:
        # Plain ``str`` instead of the numpy scalar returned by ``choice``.
        return str(rng.choice(options))

    def between(low: float, high: float) -> float:
        return float(rng.uniform(low, high))

    def coin() -> bool:
        return bool(rng.choice([True, False]))

    # Keyword arguments are evaluated top to bottom, which fixes the order in
    # which the generator is consumed.
    return TargetProfile(
        # Expression
        expression_level=pick(["high_specific", "moderate"]),
        tissue_specificity=between(0.55, 0.90),
        disease_overexpression=between(2.0, 5.0),
        # Druggability
        druggability_score=between(0.55, 0.90),
        binding_pocket_quality=pick(["excellent", "good"]),
        has_known_ligands=True,
        allosteric_site_available=coin(),
        # Selectivity / off-target
        selectivity_ratio=between(5.0, 20.0),
        off_target_count=int(rng.integers(0, 4)),
        off_target_genes=[],
        # Safety
        toxicity_profile=pick(["clean", "mild", "moderate"]),
        toxicity_tissues=[],
        # Clinical landscape
        clinical_precedent=pick(["positive", "mixed"]),
        clinical_stage_reached=pick(["phase1", "phase2", "phase3"]),
        competitor_programs=[],
        requires_patient_stratification=coin(),
        responder_biomarker=None,
        # Efficacy
        in_vitro_ic50_nM=between(2.0, 100.0),
        in_vivo_efficacy=pick(["strong", "moderate"]),
        crispr_essentiality=between(-1.5, -0.5),
        # Ground truth
        true_viability_score=between(0.65, 0.90),
        correct_decision="go",
    )
def _build_nonviable_target(rng: np.random.Generator) -> TargetProfile:
    """Sample a ``TargetProfile`` for a target that should be rejected.

    Fields are drawn from the unfavourable end of their ranges (no known
    ligands, no allosteric site, IC50 of at least 500 nM, one affected
    toxicity tissue) and the profile is labelled
    ``correct_decision="no_go"`` with a ``true_viability_score`` drawn from
    ``uniform(0.05, 0.35)``.

    Args:
        rng: Generator supplying all randomness. The entries of the field
            mapping below are evaluated in order, and that order determines
            how the generator is consumed.

    Returns:
        A freshly constructed ``TargetProfile``.
    """

    def choose(options: List[str]) -> str:
        return str(rng.choice(options))

    def uniform(low: float, high: float) -> float:
        return float(rng.uniform(low, high))

    fields = {
        "expression_level": choose(["high_nonspecific", "low", "moderate"]),
        "tissue_specificity": uniform(0.10, 0.45),
        "disease_overexpression": uniform(0.5, 1.8),
        "druggability_score": uniform(0.05, 0.40),
        "binding_pocket_quality": choose(["poor", "undruggable"]),
        "has_known_ligands": False,
        "allosteric_site_available": False,
        "selectivity_ratio": uniform(0.5, 3.0),
        "off_target_count": int(rng.integers(5, 12)),
        # Placeholder gene names; their number is drawn independently of
        # ``off_target_count``.
        "off_target_genes": [
            f"OFF_{i}" for i in range(int(rng.integers(2, 6)))
        ],
        "toxicity_profile": choose(["moderate", "severe"]),
        "toxicity_tissues": [
            choose(["liver", "kidney", "cardiac", "CNS", "GI"]),
        ],
        "clinical_precedent": choose(["negative", "none", "mixed"]),
        "clinical_stage_reached": None,
        "competitor_programs": [],
        "requires_patient_stratification": False,
        "responder_biomarker": None,
        "in_vitro_ic50_nM": uniform(500.0, 10_000.0),
        "in_vivo_efficacy": choose(["weak", "none"]),
        "crispr_essentiality": uniform(-0.3, 0.3),
        "true_viability_score": uniform(0.05, 0.35),
        "correct_decision": "no_go",
    }
    return TargetProfile(**fields)
# Evidence dimensions from which generate_scenario samples (without
# replacement) the target's key_evidence_dimensions.
# Sampled with a seeded rng.choice, so entry order affects seeded output.
_DIMENSION_POOL: List[str] = [
"expression",
"druggability",
"off_target",
"toxicity",
"clinical",
"literature",
"in_vitro",
"in_vivo",
"patient_stratification",
]
def generate_scenario(
    seed: int,
    difficulty: str = "medium",
) -> Scenario:
    """Generate a single procedural scenario with complete latent state.

    All randomness comes from one ``np.random.default_rng(seed)`` stream, so
    a given ``(seed, difficulty)`` pair always produces the same scenario.
    Values are drawn in a fixed order (target gene, disease, viability,
    target profile, key evidence, misleading signal, data quality, credits);
    rearranging the draws changes what every seed yields.

    Args:
        seed: Seed for the scenario's random generator; also embedded in the
            scenario name.
        difficulty: Key into ``_DIFFICULTY_PARAMS`` selecting the sampling
            ranges.

    Returns:
        A ``Scenario`` named ``proc_<gene>_<difficulty>_<seed>`` bundling the
        task specification, the latent target profile and the data-quality
        state.

    Raises:
        KeyError: If ``difficulty`` is not a key of ``_DIFFICULTY_PARAMS``.
    """
    try:
        params = _DIFFICULTY_PARAMS[difficulty]
    except KeyError:
        # Same exception type as the bare lookup, but with an actionable
        # message instead of just the offending key.
        raise KeyError(
            f"Unknown difficulty {difficulty!r}; expected one of "
            f"{sorted(_DIFFICULTY_PARAMS)}"
        ) from None

    rng = np.random.default_rng(seed)

    target_gene = str(rng.choice(_TARGET_POOL))
    disease = str(rng.choice(_DISEASE_POOL))

    if rng.random() < params["viable_prob"]:
        target = _build_viable_target(rng)
    else:
        target = _build_nonviable_target(rng)

    # Integer (low, high) ranges are treated as inclusive. With the default
    # exclusive upper bound a range such as (1, 2) could only ever yield 1,
    # making the configured upper value unreachable.
    n_key = int(rng.integers(*params["n_key_evidence"], endpoint=True))
    # .tolist() converts the sampled array into plain ``str`` items rather
    # than leaving numpy string scalars on the profile.
    target.key_evidence_dimensions = rng.choice(
        _DIMENSION_POOL,
        size=min(n_key, len(_DIMENSION_POOL)),
        replace=False,
    ).tolist()

    if rng.random() < params["misleading_prob"]:
        # The signal always points away from the correct decision.
        target.misleading_signals = [
            "high_expression_looks_positive"
            if target.correct_decision == "no_go"
            else "historical_undruggability"
        ]

    data_quality = DataQualityState(
        noise_level=round(float(rng.uniform(*params["noise_level"])), 3),
        false_positive_rate=round(
            float(rng.uniform(*params["false_positive_rate"])), 3
        ),
        false_negative_rate=round(
            float(rng.uniform(*params["false_negative_rate"])), 3
        ),
        database_coverage=round(
            float(rng.uniform(*params["database_coverage"])), 3
        ),
    )

    # Inclusive upper bound, consistent with n_key_evidence above.
    credits_limit = int(rng.integers(*params["credits_limit"], endpoint=True))

    task = ValidationTaskSpec(
        problem_statement=(
            f"Validate {target_gene} as a drug target in {disease}."
        ),
        target_gene=target_gene,
        disease_context=disease,
        indication=f"{target_gene}-driven {disease}",
        credits_limit=credits_limit,
        success_criteria=[
            f"Investigate the key evidence for {target_gene}",
            "Submit a calibrated go / no_go validation report",
        ],
    )

    name = f"proc_{target_gene}_{difficulty}_{seed}"
    tags = [difficulty, target_gene, disease.replace(" ", "_")]

    return Scenario(
        name=name,
        task=task,
        target=target,
        data_quality=data_quality,
        difficulty=difficulty,
        tags=tags,
    )
def generate_procedural_scenarios(
    n: int = 20,
    seed: int = 42,
) -> List[Scenario]:
    """Build a reproducible pool of ``n`` scenarios, cycling difficulties.

    Scenario ``i`` uses difficulty ``("easy", "medium", "hard")[i % 3]`` and
    its own child seed drawn from a master generator seeded with ``seed``.

    Args:
        n: Number of scenarios to generate.
        seed: Seed for the master generator that hands out child seeds.

    Returns:
        The generated scenarios, in generation order.
    """
    levels = ("easy", "medium", "hard")
    seed_source = np.random.default_rng(seed)

    # One child seed is drawn per scenario, in index order.
    pool: List[Scenario] = [
        generate_scenario(
            seed=int(seed_source.integers(0, 2**31)),
            difficulty=levels[index % len(levels)],
        )
        for index in range(n)
    ]

    logger.info("Generated %d procedural scenarios.", len(pool))
    return pool