"""Task generator — produces (ValidationTaskSpec, FullLatentState) pairs for drug-target-validation episodes. Supports two modes: 1. Select from the curated ``SCENARIO_LIBRARY``. 2. Add procedurally-generated scenarios on top. """ from __future__ import annotations from typing import List, Optional, Tuple import numpy as np from models import ActionType, ValidationTaskSpec from server.simulator.latent_state import ( CreditState, DataQualityState, FullLatentState, TargetProfile, ValidationProgress, ) from .scenarios import SCENARIO_LIBRARY, Scenario from .procedural_generator import generate_procedural_scenarios class TaskGenerator: """Generates task + latent-state pairs for environment episodes.""" def __init__( self, scenarios: Optional[List[Scenario]] = None, domain_randomise: bool = True, ): if scenarios is not None: self.scenarios = scenarios else: self.scenarios = list(SCENARIO_LIBRARY) + generate_procedural_scenarios( n=20, seed=42, ) self.domain_randomise = domain_randomise def generate( self, *, seed: Optional[int] = None, scenario_name: Optional[str] = None, ) -> Tuple[ValidationTaskSpec, FullLatentState]: rng = np.random.default_rng(seed) if scenario_name: scenario = self._find_scenario(scenario_name) else: idx = int(rng.integers(0, len(self.scenarios))) scenario = self.scenarios[idx] task = scenario.task.model_copy(deep=True) target = scenario.target.model_copy(deep=True) data_quality = scenario.data_quality.model_copy(deep=True) if self.domain_randomise: self._randomise(rng, task, target, data_quality) if not task.available_actions: task.available_actions = [a.value for a in ActionType] latent = FullLatentState( target=target, data_quality=data_quality, progress=ValidationProgress(), credits=CreditState(credits_total=task.credits_limit), rng_seed=seed or 0, ) return task, latent def list_scenarios(self) -> List[str]: return [s.name for s in self.scenarios] # ── internals ─────────────────────────────────────────────────────── def _find_scenario(self, name: str) -> Scenario: for s in self.scenarios: if s.name == name: return s available = ", ".join(self.list_scenarios()) raise ValueError(f"Unknown scenario '{name}'. Available: {available}") @staticmethod def _randomise( rng: np.random.Generator, task: ValidationTaskSpec, target: TargetProfile, data_quality: DataQualityState, ) -> None: """Light domain randomisation that nudges noise / numerics without flipping ``correct_decision`` or ``key_evidence_dimensions``.""" # Credit budget jitter task.credits_limit = int( max(15, round(task.credits_limit * float(rng.uniform(0.9, 1.1)))) ) # Data-quality jitter data_quality.noise_level = float(np.clip( data_quality.noise_level + rng.normal(0, 0.02), 0.02, 0.4 )) data_quality.false_positive_rate = float(np.clip( data_quality.false_positive_rate + rng.normal(0, 0.01), 0.0, 0.3 )) data_quality.false_negative_rate = float(np.clip( data_quality.false_negative_rate + rng.normal(0, 0.01), 0.0, 0.3 )) data_quality.database_coverage = float(np.clip( data_quality.database_coverage + rng.normal(0, 0.03), 0.5, 1.0 )) # Target profile numerics — keep categorical fields fixed. target.tissue_specificity = float(np.clip( target.tissue_specificity * float(rng.uniform(0.9, 1.1)), 0.0, 1.0 )) target.disease_overexpression = float(max( 0.1, target.disease_overexpression * float(rng.uniform(0.85, 1.15)) )) target.druggability_score = float(np.clip( target.druggability_score * float(rng.uniform(0.9, 1.1)), 0.0, 1.0 )) target.selectivity_ratio = float(max( 0.0, target.selectivity_ratio * float(rng.uniform(0.85, 1.15)) )) target.in_vitro_ic50_nM = float(max( 0.5, target.in_vitro_ic50_nM * float(rng.uniform(0.7, 1.3)) ))