# drugenv/server/tasks/generator.py
# anugrahteesdollar's picture
# initial: drugenv FastAPI + gradio demo
# 77e1e28 verified
"""Task generator β€” produces (ValidationTaskSpec, FullLatentState) pairs
for drug-target-validation episodes.
Supports two modes:
1. Select from the curated ``SCENARIO_LIBRARY``.
2. Add procedurally-generated scenarios on top.
"""
from __future__ import annotations
from typing import List, Optional, Tuple
import numpy as np
from models import ActionType, ValidationTaskSpec
from server.simulator.latent_state import (
CreditState,
DataQualityState,
FullLatentState,
TargetProfile,
ValidationProgress,
)
from .scenarios import SCENARIO_LIBRARY, Scenario
from .procedural_generator import generate_procedural_scenarios
class TaskGenerator:
    """Generates task + latent-state pairs for environment episodes.

    The generator holds a list of scenarios.  Each call to :meth:`generate`
    picks one (by name or at random), deep-copies its task / target /
    data-quality models, optionally applies light numeric jitter, and wraps
    the result in a fresh ``FullLatentState``.
    """

    def __init__(
        self,
        scenarios: Optional[List[Scenario]] = None,
        domain_randomise: bool = True,
    ):
        """Create a generator.

        Args:
            scenarios: Scenarios to draw from.  When ``None``, the curated
                ``SCENARIO_LIBRARY`` is used, extended with 20 procedurally
                generated scenarios built from the fixed seed 42.
            domain_randomise: When true, :meth:`generate` jitters the numeric
                fields of every episode it produces.
        """
        if scenarios is not None:
            self.scenarios = scenarios
        else:
            procedural = generate_procedural_scenarios(n=20, seed=42)
            self.scenarios = list(SCENARIO_LIBRARY) + procedural
        self.domain_randomise = domain_randomise

    def generate(
        self,
        *,
        seed: Optional[int] = None,
        scenario_name: Optional[str] = None,
    ) -> Tuple[ValidationTaskSpec, FullLatentState]:
        """Build one ``(task, latent_state)`` pair.

        Args:
            seed: Seed for the NumPy generator that drives scenario selection
                and domain randomisation.  ``None`` gives a non-deterministic
                generator; the latent state then records ``rng_seed=0``.
            scenario_name: Name of the scenario to use.  A falsy value
                (``None`` or ``""``) selects a scenario at random.

        Returns:
            The task spec and the freshly initialised latent state.  Both are
            built from deep copies, so the stored scenario is never mutated.

        Raises:
            ValueError: If ``scenario_name`` matches no scenario, or if a
                random pick is requested while the scenario list is empty.
        """
        rng = np.random.default_rng(seed)

        if scenario_name:
            # A named lookup consumes no draw from ``rng``.
            scenario = self._find_scenario(scenario_name)
        else:
            if not self.scenarios:
                # Fail with a clear message instead of NumPy's "low >= high"
                # error from ``rng.integers(0, 0)``.
                raise ValueError(
                    "TaskGenerator has no scenarios to sample from."
                )
            idx = int(rng.integers(0, len(self.scenarios)))
            scenario = self.scenarios[idx]

        # Work on copies so jitter never leaks back into the library.
        task = scenario.task.model_copy(deep=True)
        target = scenario.target.model_copy(deep=True)
        data_quality = scenario.data_quality.model_copy(deep=True)

        if self.domain_randomise:
            self._randomise(rng, task, target, data_quality)

        # A scenario that lists no actions gets every ActionType value.
        if not task.available_actions:
            task.available_actions = [a.value for a in ActionType]

        latent = FullLatentState(
            target=target,
            data_quality=data_quality,
            progress=ValidationProgress(),
            # Uses the (possibly jittered) credit limit.
            credits=CreditState(credits_total=task.credits_limit),
            rng_seed=0 if seed is None else seed,
        )
        return task, latent

    def list_scenarios(self) -> List[str]:
        """Return the names of all scenarios, in stored order."""
        return [s.name for s in self.scenarios]

    # -- internals ---------------------------------------------------------

    def _find_scenario(self, name: str) -> Scenario:
        """Return the first scenario whose ``name`` equals *name*.

        Raises:
            ValueError: If no scenario matches; the message lists the
                available names.
        """
        for s in self.scenarios:
            if s.name == name:
                return s
        available = ", ".join(self.list_scenarios())
        raise ValueError(f"Unknown scenario '{name}'. Available: {available}")

    @staticmethod
    def _randomise(
        rng: np.random.Generator,
        task: ValidationTaskSpec,
        target: TargetProfile,
        data_quality: DataQualityState,
    ) -> None:
        """Apply light domain randomisation in place.

        Only the numeric fields assigned below are modified; no other
        attribute of the three objects is touched.  The order of the random
        draws is fixed, so a given ``rng`` state always yields the same
        jitter.

        Args:
            rng: Source of randomness; advanced by ten draws.
            task: Task spec whose ``credits_limit`` is rescaled.
            target: Target profile whose numeric fields are rescaled.
            data_quality: Data-quality state whose rates receive additive
                Gaussian noise.
        """

        def shifted(value: float, sigma: float, low: float, high: float) -> float:
            # Additive Gaussian noise, clipped to [low, high].
            return float(np.clip(value + rng.normal(0, sigma), low, high))

        def scaled(value: float, low: float, high: float) -> float:
            # Multiplicative jitter by a factor drawn uniformly from [low, high).
            return value * float(rng.uniform(low, high))

        # Credit budget jitter (never below 15 credits).
        task.credits_limit = int(
            max(15, round(scaled(task.credits_limit, 0.9, 1.1)))
        )

        # Data-quality jitter.
        data_quality.noise_level = shifted(
            data_quality.noise_level, 0.02, 0.02, 0.4
        )
        data_quality.false_positive_rate = shifted(
            data_quality.false_positive_rate, 0.01, 0.0, 0.3
        )
        data_quality.false_negative_rate = shifted(
            data_quality.false_negative_rate, 0.01, 0.0, 0.3
        )
        data_quality.database_coverage = shifted(
            data_quality.database_coverage, 0.03, 0.5, 1.0
        )

        # Target profile numerics — keep categorical fields fixed.
        target.tissue_specificity = float(np.clip(
            scaled(target.tissue_specificity, 0.9, 1.1), 0.0, 1.0
        ))
        target.disease_overexpression = float(max(
            0.1, scaled(target.disease_overexpression, 0.85, 1.15)
        ))
        target.druggability_score = float(np.clip(
            scaled(target.druggability_score, 0.9, 1.1), 0.0, 1.0
        ))
        target.selectivity_ratio = float(max(
            0.0, scaled(target.selectivity_ratio, 0.85, 1.15)
        ))
        target.in_vitro_ic50_nM = float(max(
            0.5, scaled(target.in_vitro_ic50_nM, 0.7, 1.3)
        ))