# NOTE: removed page-scrape artifact ("Spaces: / Runtime error" header) that
# preceded the module docstring and was not part of the source file.
| """Train a drug-target-validation planner with TRL GRPO and OpenEnv rewards.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import re | |
| from numbers import Real | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| from client import DrugTargetEnv | |
| from models import ( | |
| ActionType, | |
| DrugTargetAction, | |
| ValidationObservation, | |
| build_agent_observation_context, | |
| build_agent_system_prompt, | |
| ) | |
| from server.hackathon_environment import DrugTargetEnvironment | |
| from server.tasks.scenarios import SCENARIO_LIBRARY | |
# ── Tunable defaults for the CLI (see ``build_argument_parser``). ────────────
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
DEFAULT_OUTPUT_DIR = "training/grpo-output"
DEFAULT_BASE_URL = "http://localhost:8000"
DEFAULT_COMPLETION_TOKEN_BUDGET = 160
# Rewards returned when a completion cannot be scored: unparseable action vs.
# an exception raised while scoring against the environment.
INVALID_ACTION_PENALTY = -2.0
ENVIRONMENT_ERROR_PENALTY = -4.0
# Shared system prompt and the canonical action-type vocabulary.
SYSTEM_PROMPT = build_agent_system_prompt()
ACTION_TYPES = [action.value for action in ActionType]
# Lenient aliases: shorthand the model commonly emits, mapped to canonical
# ActionType values (consumed by ``normalize_action_type``).
ACTION_TYPE_ALIASES: Dict[str, str] = {
    "expression_lookup": ActionType.QUERY_EXPRESSION.value,
    "tissue_expression": ActionType.QUERY_EXPRESSION.value,
    "gtex_query": ActionType.QUERY_EXPRESSION.value,
    "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "differential_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "pathway": ActionType.PATHWAY_ENRICHMENT.value,
    "coexpression": ActionType.COEXPRESSION_NETWORK.value,
    "structure_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "alphafold_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "binding_site": ActionType.BINDING_SITE_ANALYSIS.value,
    "ppi": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "interaction_network": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "druggability": ActionType.DRUGGABILITY_SCREEN.value,
    "trial_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "clinical_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "toxicity": ActionType.TOXICITY_PANEL.value,
    "tox_panel": ActionType.TOXICITY_PANEL.value,
    "off_target": ActionType.OFF_TARGET_SCREEN.value,
    "selectivity": ActionType.OFF_TARGET_SCREEN.value,
    "stratification": ActionType.PATIENT_STRATIFICATION.value,
    "patient_subset": ActionType.PATIENT_STRATIFICATION.value,
    "literature": ActionType.LITERATURE_SEARCH.value,
    "pubmed": ActionType.LITERATURE_SEARCH.value,
    "synthesis": ActionType.EVIDENCE_SYNTHESIS.value,
    "competitor": ActionType.COMPETITOR_LANDSCAPE.value,
    "in_vitro": ActionType.IN_VITRO_ASSAY.value,
    "cell_assay": ActionType.IN_VITRO_ASSAY.value,
    "in_vivo": ActionType.IN_VIVO_MODEL.value,
    "mouse_model": ActionType.IN_VIVO_MODEL.value,
    "crispr": ActionType.CRISPR_KNOCKOUT.value,
    "biomarker": ActionType.BIOMARKER_CORRELATION.value,
    "red_flag": ActionType.FLAG_RED_FLAG.value,
    "expert_review": ActionType.REQUEST_EXPERT_REVIEW.value,
    "submit_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "final_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "submit": ActionType.SUBMIT_VALIDATION_REPORT.value,
}
# Heuristic teacher policy used to seed GRPO prompts. Mirrors the default
# pipeline used by the inference script in ``run_agent.py``.
HEURISTIC_SEQUENCE: List[ActionType] = [
    ActionType.QUERY_EXPRESSION,
    ActionType.DRUGGABILITY_SCREEN,
    ActionType.OFF_TARGET_SCREEN,
    ActionType.TOXICITY_PANEL,
    ActionType.CLINICAL_TRIAL_LOOKUP,
    ActionType.LITERATURE_SEARCH,
    ActionType.SUBMIT_VALIDATION_REPORT,
]
VALID_ACTION_TYPES = set(ACTION_TYPES)  # set form for O(1) membership checks
def compact_preview(value: Any, max_chars: int = 160) -> str:
    """Render *value* as a single-line ASCII preview capped at *max_chars*.

    Falls back to ``str(value)`` when the value is not JSON-serialisable;
    previews longer than the cap are truncated with a ``...`` suffix.
    """
    try:
        rendered = json.dumps(value, ensure_ascii=True, sort_keys=True)
    except TypeError:
        rendered = str(value)
    collapsed = " ".join(rendered.split())
    if len(collapsed) > max_chars:
        collapsed = collapsed[: max_chars - 3] + "..."
    return collapsed
| def _edit_distance(a: str, b: str) -> int: | |
| if len(a) < len(b): | |
| return _edit_distance(b, a) | |
| if not b: | |
| return len(a) | |
| prev = list(range(len(b) + 1)) | |
| for i, ca in enumerate(a): | |
| curr = [i + 1] | |
| for j, cb in enumerate(b): | |
| curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb))) | |
| prev = curr | |
| return prev[-1] | |
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
    """Fetch the first matching key: exact, then case-insensitive, then fuzzy.

    The fuzzy pass tolerates small typos in the model's key names by
    accepting keys within an edit distance of ``max(2, len(name) // 3)``.
    Returns ``None`` when nothing matches.
    """
    for name in names:
        if name in payload:
            return payload[name]
    lowered = {str(key).lower(): value for key, value in payload.items()}
    wanted = [name.lower() for name in names]
    for target in wanted:
        if target in lowered:
            return lowered[target]
    for key, value in lowered.items():
        for target in wanted:
            if _edit_distance(key, target) <= max(2, len(target) // 3):
                return value
    return None
def build_argument_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser shared by scripts and notebook helpers."""
    parser = argparse.ArgumentParser(
        description=(
            "Train a GRPO policy against the OpenEnv drug-target-validation "
            "environment."
        )
    )
    add = parser.add_argument
    # Model / output basics.
    add("--model-id", default=DEFAULT_MODEL_ID)
    add("--output-dir", default=DEFAULT_OUTPUT_DIR)
    # Prompt-dataset collection.
    add("--dataset-episodes", type=int, default=8)
    add("--rollout-steps", type=int, default=6)
    add(
        "--collection-policy",
        choices=["random", "heuristic"],
        default="heuristic",
        help="Policy used to build prompt states for GRPO training.",
    )
    # Reward scoring backend.
    add(
        "--reward-backend",
        choices=["local", "remote"],
        default="local",
        help="Use local in-process scoring or a live OpenEnv server.",
    )
    add(
        "--base-url",
        default=DEFAULT_BASE_URL,
        help="Base URL for the OpenEnv server when reward-backend=remote.",
    )
    add(
        "--scenario-name",
        action="append",
        default=None,
        help="Repeatable scenario selector. Defaults to all curated scenarios.",
    )
    add(
        "--domain-randomise",
        action="store_true",
        help="Enable domain randomisation while building prompts and local rewards.",
    )
    # GRPO / trainer hyper-parameters.
    add("--num-generations", type=int, default=2)
    add(
        "--max-completion-length",
        type=int,
        default=DEFAULT_COMPLETION_TOKEN_BUDGET,
    )
    add("--max-prompt-length", type=int, default=768)
    add("--per-device-train-batch-size", type=int, default=2)
    add("--gradient-accumulation-steps", type=int, default=1)
    add("--learning-rate", type=float, default=5e-6)
    add("--num-train-epochs", type=float, default=1.0)
    add("--logging-steps", type=int, default=1)
    add("--save-steps", type=int, default=50)
    add(
        "--plot-metric-key",
        default=None,
        help="Optional extra metric key from trainer log history to plot.",
    )
    add("--seed", type=int, default=0)
    add(
        "--load-model-only",
        action="store_true",
        help="Download and load the selected model and tokenizer, then exit.",
    )
    add(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True to model/tokenizer loading.",
    )
    add(
        "--dry-run",
        action="store_true",
        help="Build the prompt dataset and smoke-test the reward function without training.",
    )
    add(
        "--push-to-hub",
        type=str,
        default=None,
        help="HuggingFace Hub repo id to push the trained model to (e.g. 'myuser/my-model').",
    )
    # ── Live evidence hooks (drug-domain ``LiveTrainingCallback``).
    # All three knobs are best-effort: a missing ``transformers`` /
    # read-only ``evidence/`` dir reduces the callback to a no-op so
    # training itself still proceeds.
    add(
        "--evidence-dir",
        default="evidence",
        help="Directory for training_log.csv + checkpoint_evals.csv.",
    )
    add(
        "--checkpoint-eval-steps",
        type=int,
        default=50,
        help="Run a held-out heuristic eval every N GRPO updates.",
    )
    add(
        "--checkpoint-eval-episodes",
        type=int,
        default=4,
        help="Number of held-out episodes per mid-training eval.",
    )
    return parser
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments; *argv* of None reads ``sys.argv`` as usual."""
    parser = build_argument_parser()
    return parser.parse_args(argv)
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build an argparse-style namespace for notebooks and scripts.

    Starts from the parser defaults and applies keyword *overrides*;
    raises ``ValueError`` on any key the parser does not define.
    """
    base = vars(build_argument_parser().parse_args([]))
    unknown = sorted(key for key in overrides if key not in base)
    if unknown:
        raise ValueError(f"Unknown training args: {', '.join(unknown)}")
    return argparse.Namespace(**{**base, **overrides})
def format_observation(obs: ValidationObservation) -> str:
    """Render *obs* into the textual state block fed to the policy model.

    Sections, in order: target header, shared tool context, recent
    pipeline history, completed/suggested actions, dossier findings,
    latest tool output, rule violations, and the strict JSON output
    instructions the reward parser expects.
    """
    parts = [
        f"TARGET: {obs.target_gene} | INDICATION: {obs.indication}",
        f"DISEASE CONTEXT: {obs.disease_context}",
        f"Step: {obs.step_index} | Credits: "
        f"{obs.credits_remaining}/{obs.credits_total}",
    ]
    # Shared observation context (available tools etc.) from ``models``.
    context = build_agent_observation_context(obs, max_tools=5)
    if context:
        parts.append(context)
    if obs.pipeline_history:
        # Only the five most recent steps, summaries truncated to 80 chars,
        # to keep the prompt within the token budget.
        last5 = obs.pipeline_history[-5:]
        parts.append("Recent history:")
        for step in last5:
            tag = "OK" if step.get("success") else "FAIL"
            line = (
                f" [{tag}] {step.get('action_type')}: "
                f"{str(step.get('output_summary', ''))[:80]}"
            )
            parts.append(line)
    # Successful actions so far; used both to discourage repeats and to
    # suggest the next step of the heuristic teacher sequence.
    completed = {
        step.get("action_type") for step in obs.pipeline_history if step.get("success")
    }
    if completed:
        parts.append(
            "Completed actions (try not to repeat): "
            + ", ".join(sorted(map(str, completed)))
        )
    remaining = [
        action.value for action in HEURISTIC_SEQUENCE
        if action.value not in completed
    ]
    if remaining:
        parts.append(f"Suggested next actions (pick one): {', '.join(remaining)}")
    # Compact previews of each populated dossier section.
    dossier = obs.dossier
    if dossier.expression_findings:
        parts.append(
            f"Expression findings: {compact_preview(dossier.expression_findings, 200)}"
        )
    if dossier.protein_findings:
        parts.append(
            f"Protein findings: {compact_preview(dossier.protein_findings, 200)}"
        )
    if dossier.safety_findings:
        parts.append(
            f"Safety findings: {compact_preview(dossier.safety_findings, 200)}"
        )
    if dossier.clinical_findings:
        parts.append(
            f"Clinical findings: {compact_preview(dossier.clinical_findings, 200)}"
        )
    if dossier.experimental_results:
        parts.append(
            f"Experimental results: {compact_preview(dossier.experimental_results, 200)}"
        )
    if obs.latest_output and obs.latest_output.data:
        parts.append(
            f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
        )
    if obs.rule_violations:
        parts.append(f"VIOLATIONS: {obs.rule_violations}")
    # Strict output contract — ``parse_action_completion`` expects this JSON.
    parts.append(
        'Output ONLY a single JSON object with these exact keys, no '
        'comments, no extra text:\n'
        '{"action_type": "<one of the available actions>", '
        '"parameters": {}, "reasoning": "<why>", '
        '"final_decision": null, "confidence": null}\n'
        'When picking submit_validation_report, set "final_decision" to '
        '"go" or "no_go" and "confidence" to a number in [0, 1].'
    )
    return "\n".join(parts)
def build_training_prompt(obs: ValidationObservation) -> str:
    """Prepend the shared system prompt to the formatted observation."""
    return "\n\n".join([SYSTEM_PROMPT, format_observation(obs)])
def heuristic_next_action(
    history: Sequence[ActionType],
    step_index: int,
) -> ActionType:
    """Return the first teacher-sequence action not yet present in *history*.

    Falls back to submitting the validation report once the sequence is
    exhausted. ``step_index`` is accepted for interface parity but unused.
    """
    already_taken = set(history)
    pending = (action for action in HEURISTIC_SEQUENCE if action not in already_taken)
    return next(pending, ActionType.SUBMIT_VALIDATION_REPORT)
def pick_action(
    policy: str,
    step_index: int,
    history: Sequence[ActionType],
) -> ActionType:
    """Choose the next action under the named collection policy.

    ``"random"`` samples uniformly over all action types; anything else
    follows the heuristic teacher sequence.
    """
    if policy != "random":
        return heuristic_next_action(history, step_index)
    return random.choice(list(ActionType))
def _heuristic_decision(obs: ValidationObservation) -> Tuple[str, float]:
    """Best-effort go / no-go from the running dossier.

    Looks at the most-recent druggability_score and clinical_precedent in
    the dossier, falls back to ``no_go`` with mid confidence when the
    evidence is empty.
    """
    decision, confidence = "no_go", 0.55
    protein_map = obs.dossier.protein_findings or {}
    drug_entry = protein_map.get(ActionType.DRUGGABILITY_SCREEN.value, {})
    score = drug_entry.get("druggability_score") if isinstance(drug_entry, dict) else None
    if isinstance(score, (int, float)):
        is_go = score >= 0.55
        decision = "go" if is_go else "no_go"
        # Confidence scales with how far the score sits from the threshold.
        margin = score if is_go else (1.0 - score)
        confidence = float(min(0.9, 0.55 + margin * 0.4))
    clinical_map = obs.dossier.clinical_findings or {}
    clinical_entry = clinical_map.get(ActionType.CLINICAL_TRIAL_LOOKUP.value, {})
    precedent = None
    if isinstance(clinical_entry, dict):
        precedent = clinical_entry.get("clinical_precedent")
    # Clinical precedent, when present, overrides the druggability call.
    if precedent in ("negative", "positive"):
        decision = "no_go" if precedent == "negative" else "go"
        confidence = max(confidence, 0.7)
    return decision, round(confidence, 3)
def build_drug_target_action(
    action_type: ActionType,
    obs: ValidationObservation,
) -> DrugTargetAction:
    """Construct a heuristically-reasonable action for the given type."""
    # Static (parameters, reasoning) recipes for the simple action types.
    recipes: Dict[ActionType, Tuple[Dict[str, Any], str]] = {
        ActionType.QUERY_EXPRESSION: (
            {"database": "GTEx"},
            "Establish baseline tissue expression for the target.",
        ),
        ActionType.DIFFERENTIAL_EXPRESSION: (
            {"cohort": "TCGA"},
            "Confirm disease-driven dysregulation of the target.",
        ),
        ActionType.PATHWAY_ENRICHMENT: (
            {"library": "Reactome"},
            "Map the target into known pathway context.",
        ),
        ActionType.COEXPRESSION_NETWORK: (
            {"source": "ARCHS4"},
            "Find functionally-related genes for mechanism reasoning.",
        ),
        ActionType.PROTEIN_STRUCTURE_LOOKUP: (
            {"method": "AlphaFold"},
            "Pull the target's predicted 3D structure.",
        ),
        ActionType.BINDING_SITE_ANALYSIS: (
            {"include_allosteric": True},
            "Detect classic and allosteric pockets.",
        ),
        ActionType.PROTEIN_INTERACTION_NETWORK: (
            {"source": "STRING"},
            "Map first-degree interactors for off-target reasoning.",
        ),
        ActionType.DRUGGABILITY_SCREEN: (
            {"source": "ChEMBL"},
            "Score overall druggability and known ligand chemistry.",
        ),
        ActionType.CLINICAL_TRIAL_LOOKUP: (
            {"source": "ClinicalTrials_gov"},
            "Check clinical precedent for this target / indication.",
        ),
        ActionType.TOXICITY_PANEL: (
            {"source": "ToxCast"},
            "Probe target-mediated toxicity in light of expression.",
        ),
        ActionType.OFF_TARGET_SCREEN: (
            {"source": "SafetyPanel"},
            "Quantify selectivity and paralog liabilities.",
        ),
        ActionType.PATIENT_STRATIFICATION: (
            {"source": "ClinVar"},
            "Identify responder subpopulations.",
        ),
        ActionType.EVIDENCE_SYNTHESIS: (
            {},
            "Aggregate prior dossier into a coherent picture.",
        ),
        ActionType.COMPETITOR_LANDSCAPE: (
            {"source": "DrugBank"},
            "Survey competing programs against the same target.",
        ),
        ActionType.IN_VITRO_ASSAY: (
            {"panel": "InVitroPanel"},
            "Confirm computational evidence with cell-line activity.",
        ),
        ActionType.IN_VIVO_MODEL: (
            {"model": "MouseModel"},
            "Confirm in-vitro signal with disease-relevant in-vivo data.",
        ),
        ActionType.CRISPR_KNOCKOUT: (
            {"panel": "DepMap"},
            "Test functional dependency of the target.",
        ),
        ActionType.BIOMARKER_CORRELATION: (
            {"source": "BiomarkerPanel"},
            "Correlate target activity with patient biomarkers.",
        ),
        ActionType.FLAG_RED_FLAG: (
            {"note": "potential concern noted by the agent"},
            "Record a concern without spending credits.",
        ),
        ActionType.REQUEST_EXPERT_REVIEW: (
            {"focus": "current_dossier"},
            "Get a lightweight expert critique of the dossier.",
        ),
    }
    final_decision: Optional[str] = None
    confidence: Optional[float] = None
    if action_type == ActionType.LITERATURE_SEARCH:
        # Query depends on the live observation, so it cannot live in the table.
        parameters: Dict[str, Any] = {
            "query": f"{obs.target_gene} {obs.indication}".strip()
        }
        reasoning = "Scan recent literature for precedent that updates priors."
    elif action_type == ActionType.SUBMIT_VALIDATION_REPORT:
        parameters = {}
        final_decision, confidence = _heuristic_decision(obs)
        reasoning = (
            f"Submit {final_decision} with confidence {confidence:.2f} based on the "
            "evidence gathered so far."
        )
    elif action_type in recipes:
        template, reasoning = recipes[action_type]
        parameters = dict(template)  # fresh dict per call
    else:
        parameters = {}
        reasoning = f"Advance the investigation with {action_type.value}."
    return DrugTargetAction(
        action_type=action_type,
        parameters=parameters,
        reasoning=reasoning,
        final_decision=final_decision,
        confidence=confidence,
    )
def selected_scenarios(requested: Optional[Sequence[str]]) -> List[str]:
    """Resolve scenario names, validating any explicit request.

    The catalogue is the curated library plus 20 deterministic procedural
    scenarios; ``None``/empty *requested* selects everything.
    """
    from server.tasks.procedural_generator import generate_procedural_scenarios

    catalogue = list(SCENARIO_LIBRARY)
    catalogue.extend(generate_procedural_scenarios(n=20, seed=42))
    names = [scenario.name for scenario in catalogue]
    if not requested:
        return names
    missing = sorted(set(requested) - set(names))
    if missing:
        raise ValueError(f"Unknown scenarios requested: {', '.join(missing)}")
    return list(requested)
| def action_completion_json(action: DrugTargetAction) -> str: | |
| payload = { | |
| "action_type": action.action_type.value, | |
| "parameters": action.parameters, | |
| "reasoning": action.reasoning, | |
| "final_decision": action.final_decision, | |
| "confidence": action.confidence, | |
| } | |
| return json.dumps(payload) | |
def build_prompt_examples(
    *,
    dataset_episodes: int,
    rollout_steps: int,
    collection_policy: str,
    scenario_names: Sequence[str],
    seed: int,
    domain_randomise: bool,
) -> List[Dict[str, str]]:
    """Roll out the collection policy to build GRPO prompt rows.

    Each row carries the rendered prompt plus the replay metadata the
    reward function needs (scenario name, JSON-encoded action history,
    episode rng seed), a reference action, and a short problem statement.
    Scenarios are cycled in an order shuffled deterministically by *seed*.
    """
    rng = random.Random(seed)
    examples: List[Dict[str, str]] = []
    scenario_cycle = list(scenario_names)
    rng.shuffle(scenario_cycle)
    for episode_idx in range(dataset_episodes):
        scenario_name = scenario_cycle[episode_idx % len(scenario_cycle)]
        env = DrugTargetEnvironment(
            scenario_name=scenario_name,
            domain_randomise=domain_randomise,
        )
        obs = env.reset()
        history_actions: List[DrugTargetAction] = []
        for step_idx in range(rollout_steps):
            if obs.done:
                break
            # Teacher policy picks the next action from the history so far.
            next_action = build_drug_target_action(
                action_type=pick_action(
                    collection_policy,
                    step_idx,
                    [action.action_type for action in history_actions],
                ),
                obs=obs,
            )
            examples.append({
                "prompt": build_training_prompt(obs),
                "scenario_name": scenario_name,
                "history_actions": json.dumps(
                    [action.model_dump() for action in history_actions]
                ),
                # NOTE(review): reaches into the private ``env._latent`` for
                # the episode seed — confirm no public accessor exists.
                "rng_seed": str(env._latent.rng_seed if env._latent else 0),
                "reference_action": action_completion_json(next_action),
                "problem_statement": (
                    f"Validate {obs.target_gene} in {obs.indication}"
                ),
            })
            history_actions.append(next_action)
            obs = env.step(next_action)
    return examples
def completion_to_text(completion: Any) -> str:
    """Flatten a TRL completion (string, chat dict, or message list) to text.

    For message lists, the newest entry with non-empty content wins.
    """
    if isinstance(completion, str):
        return completion.strip()
    if isinstance(completion, dict):
        return content_to_text(completion.get("content", ""))
    if isinstance(completion, list):
        for entry in reversed(completion):
            if isinstance(entry, dict) and "content" in entry:
                extracted = content_to_text(entry["content"])
                if extracted:
                    return extracted
            elif isinstance(entry, str) and entry.strip():
                return entry.strip()
    return str(completion).strip()
def content_to_text(content: Any) -> str:
    """Collapse a message ``content`` field (str or list of parts) to text."""
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return str(content).strip()
    pieces: List[str] = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict):
            # Multimodal-style parts: prefer "text", fall back to "content".
            text_field = part.get("text")
            if isinstance(text_field, str):
                pieces.append(text_field)
            elif isinstance(part.get("content"), str):
                pieces.append(part["content"])
    return "".join(pieces).strip()
| def _repair_truncated_json(text: str) -> Optional[str]: | |
| """Try to repair JSON truncated mid-value (common with small LLMs).""" | |
| s = text.strip() | |
| if not s.startswith("{"): | |
| return None | |
| s = re.sub(r',\s*"[^"\n]*$', '', s) | |
| s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s) | |
| in_string = False | |
| escape = False | |
| for ch in s: | |
| if escape: | |
| escape = False | |
| continue | |
| if ch == "\\": | |
| escape = True | |
| continue | |
| if ch == '"': | |
| in_string = not in_string | |
| if in_string: | |
| s += '"' | |
| open_braces = s.count("{") - s.count("}") | |
| open_brackets = s.count("[") - s.count("]") | |
| s += "]" * max(0, open_brackets) | |
| s += "}" * max(0, open_braces) | |
| try: | |
| obj = json.loads(s) | |
| if isinstance(obj, dict): | |
| return s | |
| except json.JSONDecodeError: | |
| pass | |
| s = re.sub(r',\s*([}\]])', r'\1', s) | |
| try: | |
| obj = json.loads(s) | |
| if isinstance(obj, dict): | |
| return s | |
| except json.JSONDecodeError: | |
| pass | |
| return None | |
def _normalize_jsonish_text(text: str) -> str:
    """Normalise Python-literal / JS-flavoured model output toward JSON.

    Strips comments, rewrites ``None``/``True``/``False`` values to their
    JSON spellings, and gives an accidentally-quoted ``"key:"`` an empty
    string value.
    """
    text = _strip_js_comments(text)
    # Capture the colon and whitespace instead of a fixed-width lookbehind:
    # the old ``(?<=:\s)`` matched exactly one space, so values written as
    # ``:None`` or ``:  True`` were never normalised.
    text = re.sub(r'(:\s*)\bNone\b', r'\1null', text)
    text = re.sub(r'(:\s*)\bTrue\b', r'\1true', text)
    text = re.sub(r'(:\s*)\bFalse\b', r'\1false', text)
    text = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', text)
    return text
| def _strip_js_comments(text: str) -> str: | |
| text = re.sub(r'//[^\n]*', '', text) | |
| text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL) | |
| return text | |
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Extract the largest parseable JSON object embedded in *text*.

    Handles fenced code blocks, multiple candidate brace spans, and a
    truncation-repair fallback; returns ``None`` when nothing parses.
    """
    cleaned = _normalize_jsonish_text(text).strip()
    # Unwrap a ``` fenced block when the whole text is fenced.
    if cleaned.startswith("```") and cleaned.endswith("```"):
        fenced_lines = cleaned.splitlines()
        if len(fenced_lines) >= 3:
            cleaned = "\n".join(fenced_lines[1:-1]).strip()
    candidates: List[str] = [cleaned]
    # Collect every balanced {...} span starting at each '{'.
    pos = cleaned.find("{")
    while pos != -1:
        depth = 0
        for end in range(pos, len(cleaned)):
            ch = cleaned[end]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    candidates.append(cleaned[pos:end + 1])
                    break
        pos = cleaned.find("{", pos + 1)
    # Truncation-repair fallback from the first opening brace.
    opening = cleaned.find("{")
    if opening != -1:
        repaired = _repair_truncated_json(cleaned[opening:])
        if repaired is not None:
            candidates.append(repaired)
    # Prefer the longest candidate that parses to a dict.
    for candidate in sorted(candidates, key=len, reverse=True):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
def normalize_optional_string(value: Any) -> Optional[str]:
    """Coerce *value* to a stripped string, or ``None`` when empty/absent.

    Booleans are treated as absent; numbers are stringified; anything else
    falls back to a compact preview.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        stripped = value.strip()
        return stripped if stripped else None
    if isinstance(value, (int, float)):
        return str(value)
    return compact_preview(value, 80)
def normalize_action_type(raw_action_type: Any) -> Optional[str]:
    """Map raw model output to a canonical ActionType value, or ``None``.

    Resolution order: exact value / alias on the lowercased input, the
    same after squashing punctuation to underscores, then keyword
    fragments as a last resort.
    """
    if not isinstance(raw_action_type, str):
        return None

    def _lookup(name: str) -> Optional[str]:
        # Canonical values first, then the alias table.
        if name in ACTION_TYPES:
            return name
        return ACTION_TYPE_ALIASES.get(name)

    candidate = raw_action_type.strip().lower()
    resolved = _lookup(candidate)
    if resolved is not None:
        return resolved
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    resolved = _lookup(candidate)
    if resolved is not None:
        return resolved
    # Ordered fragment heuristics: every fragment must appear in the input.
    fragment_rules = [
        (("expression",), ActionType.QUERY_EXPRESSION.value),
        (("differential",), ActionType.DIFFERENTIAL_EXPRESSION.value),
        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
        (("coexpression",), ActionType.COEXPRESSION_NETWORK.value),
        (("structure",), ActionType.PROTEIN_STRUCTURE_LOOKUP.value),
        (("binding", "site"), ActionType.BINDING_SITE_ANALYSIS.value),
        (("interaction",), ActionType.PROTEIN_INTERACTION_NETWORK.value),
        (("druggab",), ActionType.DRUGGABILITY_SCREEN.value),
        (("clinical",), ActionType.CLINICAL_TRIAL_LOOKUP.value),
        (("toxic",), ActionType.TOXICITY_PANEL.value),
        (("off",), ActionType.OFF_TARGET_SCREEN.value),
        (("strat",), ActionType.PATIENT_STRATIFICATION.value),
        (("literature",), ActionType.LITERATURE_SEARCH.value),
        (("synthes",), ActionType.EVIDENCE_SYNTHESIS.value),
        (("competitor",), ActionType.COMPETITOR_LANDSCAPE.value),
        (("vitro",), ActionType.IN_VITRO_ASSAY.value),
        (("vivo",), ActionType.IN_VIVO_MODEL.value),
        (("crispr",), ActionType.CRISPR_KNOCKOUT.value),
        (("biomarker",), ActionType.BIOMARKER_CORRELATION.value),
        (("red", "flag"), ActionType.FLAG_RED_FLAG.value),
        (("review",), ActionType.REQUEST_EXPERT_REVIEW.value),
        (("submit",), ActionType.SUBMIT_VALIDATION_REPORT.value),
        (("report",), ActionType.SUBMIT_VALIDATION_REPORT.value),
    ]
    for fragments, normalized in fragment_rules:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None
def ensure_terminal_payload(
    obs: Optional[ValidationObservation],
    action: DrugTargetAction,
) -> DrugTargetAction:
    """Make sure a SUBMIT_VALIDATION_REPORT carries a decision + confidence.

    If the model omits either field, fill them via ``_heuristic_decision``
    when an observation is available, otherwise fall back to a neutral
    no_go @ 0.5. Confidence is always clamped to [0, 1].
    """
    if action.action_type != ActionType.SUBMIT_VALIDATION_REPORT:
        return action
    decision = action.final_decision
    conf = action.confidence
    if decision is None or conf is None:
        fallback = _heuristic_decision(obs) if obs is not None else ("no_go", 0.5)
        if decision is None:
            decision = fallback[0]
        if conf is None:
            conf = fallback[1]
    clamped = max(0.0, min(1.0, float(conf)))
    return action.model_copy(update={
        "final_decision": decision,
        "confidence": float(clamped),
    })
def parse_action_completion(text: str) -> Optional[DrugTargetAction]:
    """Parse a model completion into a ``DrugTargetAction``.

    Two stages: a tolerant JSON extraction first; when that fails, a regex
    fallback that scrapes individual fields out of free text. Returns
    ``None`` when no recognisable action_type can be found either way.
    """
    payload = extract_json_object(text)
    if payload is not None:
        action_type = normalize_action_type(get_payload_value(payload, "action_type"))
        if action_type is None:
            return None
        parameters = get_payload_value(payload, "parameters", "params") or {}
        if not isinstance(parameters, dict):
            parameters = {}
        raw_conf = get_payload_value(payload, "confidence")
        if raw_conf is None:
            confidence: Optional[float] = None
        else:
            try:
                confidence = float(raw_conf)
                # Clamp into the valid probability range.
                confidence = max(0.0, min(1.0, confidence))
            except (TypeError, ValueError):
                confidence = None
        reasoning = get_payload_value(
            payload, "reasoning", "rationale", "justification", "reason"
        )
        # Non-string reasoning (lists/dicts) is collapsed to a preview.
        if reasoning is not None and not isinstance(reasoning, str):
            reasoning = compact_preview(reasoning, 200)
        reasoning = reasoning or ""
        final_decision = normalize_optional_string(
            get_payload_value(payload, "final_decision", "decision", "go_no_go")
        )
        if final_decision is not None:
            final_decision = final_decision.lower().replace("-", "_")
            # Only the two canonical decisions are accepted.
            if final_decision not in {"go", "no_go"}:
                final_decision = None
        return DrugTargetAction(
            action_type=ActionType(action_type),
            parameters=parameters,
            reasoning=reasoning,
            final_decision=final_decision,
            confidence=confidence,
        )
    # Regex fallback for completions that are not valid JSON at all.
    action_match = re.search(
        r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
        text,
        re.IGNORECASE,
    )
    if not action_match:
        return None
    action_type = normalize_action_type(action_match.group(1))
    if action_type is None:
        return None
    confidence_match = re.search(
        r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
        text,
        re.IGNORECASE,
    )
    reasoning_match = re.search(
        r'["\'](?:reasoning|justif\w*|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
        text,
        re.DOTALL | re.IGNORECASE,
    )
    decision_match = re.search(
        r'["\']final_decision["\']\s*:\s*["\']?(go|no[_-]?go)["\']?',
        text,
        re.IGNORECASE,
    )
    confidence = None
    if confidence_match:
        try:
            confidence = max(0.0, min(1.0, float(confidence_match.group(1))))
        except ValueError:
            confidence = None
    reasoning: Optional[str] = None
    if reasoning_match:
        try:
            # Decode JSON string escapes when possible; keep raw text otherwise.
            reasoning = json.loads(f'"{reasoning_match.group(1)}"')
        except json.JSONDecodeError:
            reasoning = reasoning_match.group(1)
    final_decision: Optional[str] = None
    if decision_match:
        final_decision = (
            decision_match.group(1).lower().replace("-", "_").replace(" ", "_")
        )
        if final_decision not in {"go", "no_go"}:
            final_decision = None
    return DrugTargetAction(
        action_type=ActionType(action_type),
        parameters={},
        reasoning=reasoning or "",
        final_decision=final_decision,
        confidence=confidence,
    )
def decode_history_actions(history_actions: Optional[str]) -> List[DrugTargetAction]:
    """Rehydrate the JSON-encoded action-history column into model objects."""
    if not history_actions:
        return []
    decoded = json.loads(history_actions)
    actions: List[DrugTargetAction] = []
    for payload in decoded:
        # Skip anything that is not a dict of constructor kwargs.
        if isinstance(payload, dict):
            actions.append(DrugTargetAction(**payload))
    return actions
def normalise_column(values: Any, length: int) -> List[Any]:
    """Coerce a reward-kwarg column to exactly *length* entries.

    Scalars and singleton lists are broadcast; longer lists are trimmed
    and shorter ones padded with ``None``.
    """
    if values is None:
        return [None] * length
    if not isinstance(values, list):
        return [values] * length
    if len(values) == length:
        return values
    if len(values) == 1:
        return values * length
    trimmed = values[:length]
    trimmed += [None] * (length - len(trimmed))
    return trimmed
class OpenEnvReward:
    """Reward function compatible with TRL GRPOTrainer.

    Each completion is parsed into a ``DrugTargetAction`` and scored by
    replaying its episode history and stepping a drug-target environment:
    either an in-process ``DrugTargetEnvironment`` (local backend) or a
    remote OpenEnv HTTP server (``reward_backend == "remote"``).
    Unparseable completions receive ``invalid_action_penalty``; any backend
    failure receives ``environment_error_penalty``.
    """
    def __init__(
        self,
        *,
        reward_backend: str,
        base_url: str,
        invalid_action_penalty: float = INVALID_ACTION_PENALTY,
        environment_error_penalty: float = ENVIRONMENT_ERROR_PENALTY,
        domain_randomise: bool = False,
    ) -> None:
        # Plain instances have no __name__; TRL uses the reward callable's
        # __name__ for logging, so set one explicitly.
        self.__name__ = "openenv_reward"
        self.reward_backend = reward_backend
        self.base_url = base_url
        self.invalid_action_penalty = invalid_action_penalty
        self.environment_error_penalty = environment_error_penalty
        self.domain_randomise = domain_randomise
    def __call__(
        self,
        completions: List[Any],
        scenario_name: Optional[List[str]] = None,
        history_actions: Optional[List[str]] = None,
        rng_seed: Optional[List[str]] = None,
        **_: Any,
    ) -> List[float]:
        """Score a batch of completions, returning one float reward each."""
        # Dataset columns may arrive as scalars or short lists; broadcast
        # each to the completion batch size.
        scenario_names = normalise_column(scenario_name, len(completions))
        history_columns = normalise_column(history_actions, len(completions))
        seed_columns = normalise_column(rng_seed, len(completions))
        rewards: List[float] = []
        for completion, current_scenario, current_history, current_seed in zip(
            completions,
            scenario_names,
            history_columns,
            seed_columns,
        ):
            action = parse_action_completion(completion_to_text(completion))
            if action is None:
                # Unparseable model output: fixed penalty, keep training.
                rewards.append(self.invalid_action_penalty)
                continue
            try:
                if self.reward_backend == "remote":
                    reward = self._score_remote(action, current_scenario, current_history)
                else:
                    reward = self._score_local(
                        action, current_scenario, current_history, current_seed
                    )
            except Exception:
                # Any backend failure (env crash, network error, bad replay)
                # maps to a fixed penalty instead of aborting the step.
                reward = self.environment_error_penalty
            rewards.append(float(reward))
        return rewards
    def _score_local(
        self,
        action: DrugTargetAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
        rng_seed: Optional[str] = None,
    ) -> float:
        """Replay the episode history in-process, then score ``action``."""
        env = DrugTargetEnvironment(
            scenario_name=scenario_name,
            domain_randomise=self.domain_randomise,
        )
        seed = int(rng_seed) if rng_seed else None
        obs = env.reset(seed=seed)
        for previous_action in decode_history_actions(history_actions):
            obs = env.step(previous_action)
            if obs.done:
                # Episode already ended during replay; that reward stands.
                return float(obs.reward)
        action = ensure_terminal_payload(obs, action)
        obs = env.step(action)
        return float(obs.reward)
    def _score_remote(
        self,
        action: DrugTargetAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
    ) -> float:
        """Replay the episode history over HTTP, then score ``action``."""
        with DrugTargetEnv(base_url=self.base_url) as env:
            # NOTE: scenario_name is accepted for API parity with _score_local
            # but the OpenEnv HTTP protocol does not yet support passing it
            # through reset(). The server uses its configured default.
            result = env.reset()
            obs = result.observation
            for previous_action in decode_history_actions(history_actions):
                result = env.step(previous_action)
                obs = result.observation
                if result.done:
                    # Episode already ended during replay; that reward stands.
                    return float(result.reward or 0.0)
            action = ensure_terminal_payload(obs, action)
            result = env.step(action)
            if result.reward is not None:
                return float(result.reward)
            # Fall back to the observation-side reward when the step result
            # itself carries none.
            return float(result.observation.reward)
def is_numeric_log_value(value: Any) -> bool:
    """Return True for real-number log values, rejecting booleans.

    ``bool`` is a subclass of ``Real`` in Python, so it must be excluded
    explicitly to keep flags out of plotted metric series.
    """
    if isinstance(value, bool):
        return False
    return isinstance(value, Real)
def available_numeric_log_keys(log_history: Sequence[Dict[str, Any]]) -> List[str]:
    """List, sorted, every numeric metric key present in the trainer logs.

    Non-dict log entries are skipped, as is the ``step`` key (it is the
    x-axis, never a metric in its own right).
    """
    seen: set = set()
    for entry in log_history:
        if not isinstance(entry, dict):
            continue
        for log_key, log_value in entry.items():
            if log_key != "step" and is_numeric_log_value(log_value):
                seen.add(log_key)
    return sorted(seen)
def extract_log_series(
    log_history: Sequence[Dict[str, Any]],
    key: Optional[str],
) -> List[Tuple[float, float]]:
    """Collect ``(step, value)`` points for *key* from the trainer logs.

    Entries lacking a numeric ``step`` receive a synthetic, monotonically
    increasing one so the series remains plottable.
    """
    if not key:
        return []
    points: List[Tuple[float, float]] = []
    fallback_step = 0
    for entry in log_history:
        if not isinstance(entry, dict) or key not in entry:
            continue
        value = entry.get(key)
        if not is_numeric_log_value(value):
            continue
        raw_step = entry.get("step")
        if is_numeric_log_value(raw_step):
            current_step = float(raw_step)
        else:
            fallback_step += 1
            current_step = float(fallback_step)
        points.append((current_step, float(value)))
    return points
def select_reward_key(log_history: Sequence[Dict[str, Any]]) -> Optional[str]:
    """Pick the logged metric key that best represents the training reward."""
    reward_keys = [
        key
        for key in available_numeric_log_keys(log_history)
        if "reward" in key.lower()
    ]
    if not reward_keys:
        return None
    # Canonical TRL reward names win outright (matched case-insensitively).
    lowered = {key.lower(): key for key in reward_keys}
    for candidate in ("reward", "mean_reward", "reward_mean", "rewards/open_env_reward"):
        if candidate in lowered:
            return lowered[candidate]
    # Otherwise favour un-namespaced, shorter, alphabetically-first keys.
    return min(reward_keys, key=lambda key: ("/" in key, len(key), key))
def select_metric_key(
    log_history: Sequence[Dict[str, Any]],
    *,
    reward_key: Optional[str],
    requested_key: Optional[str] = None,
) -> Optional[str]:
    """Choose the secondary metric plotted alongside loss and reward.

    Raises:
        ValueError: when *requested_key* was supplied but never logged.
    """
    numeric_keys = available_numeric_log_keys(log_history)
    if requested_key:
        if requested_key not in numeric_keys:
            available = ", ".join(numeric_keys) or "none"
            raise ValueError(
                f"Requested plot metric '{requested_key}' was not logged. "
                f"Available numeric keys: {available}"
            )
        return requested_key
    # Keys already shown elsewhere, or pure bookkeeping, are never chosen
    # automatically; neither is the reward key (it has its own plot).
    excluded = {
        "epoch",
        "loss",
        "learning_rate",
        "step",
        "total_flos",
        "train_loss",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
    }
    if reward_key:
        excluded.add(reward_key)
    numeric_set = set(numeric_keys)
    for preferred in (
        "kl",
        "objective/kl",
        "completion_length",
        "mean_completion_length",
        "grad_norm",
        "entropy",
        "accuracy",
        "learning_rate",
        "epoch",
    ):
        if preferred in numeric_set and preferred not in excluded:
            return preferred
    # First non-excluded, non-reward key in sorted order.
    for key in numeric_keys:
        if key not in excluded and "reward" not in key.lower():
            return key
    # Last resort: bookkeeping values beat an empty plot.
    for fallback in ("learning_rate", "epoch"):
        if fallback in numeric_set:
            return fallback
    return None
def save_plot(
    path: Path,
    *,
    series: Sequence[Tuple[float, float]],
    title: str,
    ylabel: str,
) -> None:
    """Render one metric series to *path* as a PNG line chart.

    An empty series produces a placeholder chart rather than failing.
    """
    # Force the non-interactive backend so plotting works on headless hosts.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(8, 4.5))
    if not series:
        ax.text(
            0.5,
            0.5,
            "No logged data available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    else:
        xs, ys = zip(*series)
        ax.plot(xs, ys, marker="o", linewidth=1.8)
    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def save_training_plots(
    log_history: Sequence[Dict[str, Any]],
    output_dir: str | Path,
    metric_key: Optional[str] = None,
) -> Dict[str, str]:
    """Render loss/reward/metric PNGs plus a combined dashboard.

    Writes four PNGs and a JSON manifest into ``output_dir`` and returns the
    manifest's plot-name -> file-path mapping. ``metric_key`` forces which
    secondary metric is plotted; otherwise one is auto-selected.
    """
    # Agg keeps matplotlib headless-safe on training hosts.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    reward_key = select_reward_key(log_history)
    selected_metric_key = select_metric_key(
        log_history,
        reward_key=reward_key,
        requested_key=metric_key,
    )
    loss_series = extract_log_series(log_history, "loss")
    reward_series = extract_log_series(log_history, reward_key)
    metric_series = extract_log_series(log_history, selected_metric_key)
    loss_path = output_path / "training_loss.png"
    reward_path = output_path / "training_reward.png"
    metric_path = output_path / "training_metric.png"
    dashboard_path = output_path / "training_dashboard.png"
    manifest_path = output_path / "training_plot_manifest.json"
    save_plot(loss_path, series=loss_series, title="Training Loss", ylabel="Loss")
    save_plot(
        reward_path,
        series=reward_series,
        title=f"Training Reward ({reward_key or 'not logged'})",
        ylabel="Reward",
    )
    save_plot(
        metric_path,
        series=metric_series,
        title=f"Training Metric ({selected_metric_key or 'not logged'})",
        ylabel=selected_metric_key or "Metric",
    )
    # Three-panel dashboard mirroring the individual plots above.
    fig, axes = plt.subplots(3, 1, figsize=(10, 12))
    plot_specs = [
        (axes[0], loss_series, "Training Loss", "Loss"),
        (axes[1], reward_series, f"Training Reward ({reward_key or 'not logged'})", "Reward"),
        (
            axes[2],
            metric_series,
            f"Training Metric ({selected_metric_key or 'not logged'})",
            selected_metric_key or "Metric",
        ),
    ]
    for axis, series, title, ylabel in plot_specs:
        if series:
            x_values, y_values = zip(*series)
            axis.plot(x_values, y_values, marker="o", linewidth=1.8)
        else:
            # Placeholder panel so the dashboard layout stays stable.
            axis.text(
                0.5,
                0.5,
                "No logged data available",
                ha="center",
                va="center",
                transform=axis.transAxes,
            )
        axis.set_title(title)
        axis.set_xlabel("Step")
        axis.set_ylabel(ylabel)
        axis.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(dashboard_path, dpi=150)
    plt.close(fig)
    # Manifest records which keys were plotted and where the files landed.
    manifest = {
        "available_numeric_keys": available_numeric_log_keys(log_history),
        "reward_key": reward_key,
        "metric_key": selected_metric_key,
        "plots": {
            "loss": str(loss_path),
            "reward": str(reward_path),
            "metric": str(metric_path),
            "dashboard": str(dashboard_path),
        },
    }
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return manifest["plots"]
def run_dry_run_preview(
    examples: Sequence[Dict[str, str]],
    reward_fn: OpenEnvReward,
    output_dir: str,
) -> None:
    """Print a dry-run summary: prompt count and one scored reference action.

    Raises:
        ValueError: when no prompt examples were built.
    """
    if not examples:
        raise ValueError("No training prompts were generated for the dry run.")
    first = examples[0]
    # Score the dataset's reference action through the real reward path.
    completion = [{"role": "assistant", "content": first["reference_action"]}]
    rewards = reward_fn(
        completions=[completion],
        scenario_name=[first["scenario_name"]],
        history_actions=[first["history_actions"]],
    )
    print(f"Built {len(examples)} prompt states.")
    print(f"Output directory: {Path(output_dir)}")
    print(f"Sample scenario: {first['scenario_name']}")
    print(f"Sample reward for reference action: {rewards[0]:+.3f}")
    print("\nSample prompt:\n")
    print(first["prompt"])
def resolve_torch_runtime() -> Dict[str, Any]:
    """Probe torch for the device/dtype settings used across training."""
    import torch
    cuda_available = torch.cuda.is_available()
    # Older torch builds may lack is_bf16_supported; treat that as "no bf16".
    supports_bf16 = (
        bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
        if cuda_available
        else False
    )
    if supports_bf16:
        dtype = torch.bfloat16
    elif cuda_available:
        dtype = torch.float16
    else:
        dtype = torch.float32
    return {
        "use_cuda": cuda_available,
        "device": "cuda:0" if cuda_available else "cpu",
        "dtype": dtype,
        "bf16": supports_bf16,
        "fp16": cuda_available and not supports_bf16,
        "device_name": torch.cuda.get_device_name(0) if cuda_available else "cpu",
    }
| def _guard_invalid_torchao_version() -> None: | |
| """Treat malformed torchao installs as unavailable for HF imports.""" | |
| import functools | |
| import importlib.metadata as importlib_metadata | |
| import sys | |
| from packaging.version import InvalidVersion, Version | |
| if getattr(importlib_metadata, "_openenv_torchao_guard_installed", False): | |
| metadata_guard_installed = True | |
| else: | |
| original_version = importlib_metadata.version | |
| def guarded_version(distribution_name: str) -> str: | |
| version = original_version(distribution_name) | |
| if distribution_name.lower() == "torchao": | |
| try: | |
| Version(version) | |
| except InvalidVersion as exc: | |
| raise importlib_metadata.PackageNotFoundError( | |
| f"Malformed torchao version metadata: {version!r}" | |
| ) from exc | |
| return version | |
| importlib_metadata.version = guarded_version | |
| importlib_metadata._openenv_torchao_guard_installed = True | |
| metadata_guard_installed = False | |
| import_utils = sys.modules.get("transformers.utils.import_utils") | |
| if import_utils is not None and not getattr( | |
| import_utils, "_openenv_torchao_guard_installed", False | |
| ): | |
| original_is_package_available = import_utils._is_package_available | |
| def guarded_is_package_available( | |
| pkg_name: str, | |
| return_version: bool = False, | |
| ): | |
| if pkg_name != "torchao": | |
| return original_is_package_available(pkg_name, return_version=return_version) | |
| is_available, package_version = original_is_package_available( | |
| pkg_name, | |
| return_version=True, | |
| ) | |
| if not is_available: | |
| return (False, package_version) if return_version else (False, None) | |
| try: | |
| Version(package_version) | |
| except InvalidVersion: | |
| return (False, "0") if return_version else (False, None) | |
| return (True, package_version) if return_version else (True, None) | |
| min_version = getattr(import_utils, "TORCHAO_MIN_VERSION", "0") | |
| def guarded_is_torchao_available(min_version_override: str = min_version) -> bool: | |
| is_available, package_version = guarded_is_package_available( | |
| "torchao", | |
| return_version=True, | |
| ) | |
| if not is_available: | |
| return False | |
| try: | |
| return Version(package_version) >= Version(min_version_override) | |
| except InvalidVersion: | |
| return False | |
| if hasattr(import_utils.is_torchao_available, "cache_clear"): | |
| import_utils.is_torchao_available.cache_clear() | |
| import_utils._is_package_available = guarded_is_package_available | |
| import_utils.is_torchao_available = guarded_is_torchao_available | |
| import_utils._openenv_torchao_guard_installed = True | |
| transformers_utils = sys.modules.get("transformers.utils") | |
| if transformers_utils is not None: | |
| transformers_utils.is_torchao_available = guarded_is_torchao_available | |
| if metadata_guard_installed and import_utils is None: | |
| return | |
| def _guard_partial_vllm_install() -> None: | |
| """Treat partial vLLM installs as unavailable for TRL imports.""" | |
| import functools | |
| import importlib | |
| try: | |
| import trl.import_utils as trl_import_utils | |
| except Exception: | |
| return | |
| if getattr(trl_import_utils, "_openenv_vllm_guard_installed", False): | |
| return | |
| def _has_usable_vllm() -> bool: | |
| try: | |
| importlib.import_module("vllm") | |
| importlib.import_module("vllm.distributed.device_communicators.pynccl") | |
| importlib.import_module("vllm.distributed.utils") | |
| except Exception: | |
| return False | |
| return True | |
| def guarded_is_vllm_available(*args: Any, **kwargs: Any) -> bool: | |
| return _has_usable_vllm() | |
| if hasattr(trl_import_utils.is_vllm_available, "cache_clear"): | |
| trl_import_utils.is_vllm_available.cache_clear() | |
| trl_import_utils.is_vllm_available = guarded_is_vllm_available | |
| trl_import_utils._openenv_vllm_guard_installed = True | |
def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
):
    """Load the tokenizer/model pair for *model_id* on the detected device.

    Returns:
        ``(tokenizer, model)`` with the model moved to CUDA when available,
        otherwise CPU.
    """
    _guard_invalid_torchao_version()
    from transformers import AutoModelForCausalLM, AutoTokenizer
    runtime = resolve_torch_runtime()
    print(f"Loading tokenizer for {model_id} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )
    # Decoder-only checkpoints often ship without a pad token; reuse EOS.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading model for {model_id} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
        torch_dtype=runtime["dtype"],
    )
    # runtime["device"] is already "cuda:0" or "cpu"; one move covers both.
    model = model.to(runtime["device"])
    return tokenizer, model
def build_openenv_reward(args: argparse.Namespace) -> OpenEnvReward:
    """Return the OpenEnv-compatible reward callable used by GRPO."""
    reward = OpenEnvReward(
        reward_backend=args.reward_backend,
        base_url=args.base_url,
        domain_randomise=args.domain_randomise,
    )
    return reward
def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the OpenEnv rollout states that seed GRPO prompts.

    Returns:
        Dict with ``scenario_names`` (the resolved scenario list) and
        ``examples`` (the generated prompt states).
    """
    names = selected_scenarios(args.scenario_name)
    prompt_examples = build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=names,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return {"scenario_names": names, "examples": prompt_examples}
def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Construct a GRPOConfig, dropping fields this TRL version lacks.

    The installed TRL's ``GRPOConfig.__init__`` signature is inspected and
    any unsupported keyword is skipped (with a notice) so one script works
    across TRL releases.
    """
    import inspect
    _guard_invalid_torchao_version()
    _guard_partial_vllm_install()
    from trl import GRPOConfig
    desired = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "max_prompt_length": args.max_prompt_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }
    supported = set(inspect.signature(GRPOConfig.__init__).parameters)
    # Very old TRL releases expose a single max_length instead of split
    # prompt/completion budgets; emulate the split with their sum.
    legacy_only_max_length = (
        "max_length" in supported
        and "max_prompt_length" not in supported
        and "max_completion_length" not in supported
    )
    if legacy_only_max_length:
        desired["max_length"] = args.max_prompt_length + args.max_completion_length
    accepted = {key: value for key, value in desired.items() if key in supported}
    dropped = sorted(set(desired) - set(accepted))
    if dropped:
        print(
            "GRPOConfig compatibility: skipping unsupported fields "
            f"{', '.join(dropped)}"
        )
    return GRPOConfig(**accepted)
def build_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Assemble a GRPOTrainer wired to the OpenEnv reward callable.

    Builds the version-compatible GRPOConfig first, then attaches the
    optional live-evidence callback on a best-effort basis.
    """
    _guard_invalid_torchao_version()
    _guard_partial_vllm_install()
    from trl import GRPOTrainer
    config = build_grpo_config(args, runtime)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    # Bolt on the live evidence callback so the trainer Space dashboard
    # gets per-step CSV updates while the run is in flight. The callback
    # only writes through DrugEnv's existing public API, so it cannot
    # impact training correctness — failures are swallowed inside the
    # callback itself.
    try:
        from training.live_callback import LiveTrainingCallback
        trainer.add_callback(LiveTrainingCallback(
            evidence_dir=getattr(args, "evidence_dir", "evidence"),
            checkpoint_eval_steps=getattr(args, "checkpoint_eval_steps", 50),
            checkpoint_eval_episodes=getattr(args, "checkpoint_eval_episodes", 4),
        ))
    except Exception as exc:
        # Missing optional module: training proceeds without the dashboard.
        print(f"[live-callback] not installed: {exc}")
    return trainer
def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | ValidationObservation,
    *,
    max_new_tokens: int = DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    """Sample one completion and parse it into an action.

    Accepts either a ready-made prompt string or a raw observation (which is
    first rendered into a training prompt).

    Returns:
        Dict with ``prompt``, ``response_text`` and ``action`` (None when
        the completion could not be parsed).
    """
    import torch
    if isinstance(prompt_or_observation, ValidationObservation):
        observation: Optional[ValidationObservation] = prompt_or_observation
        prompt = build_training_prompt(prompt_or_observation)
    else:
        observation = None
        prompt = str(prompt_or_observation)
    device = getattr(model, "device", None)
    if device is None:
        device = resolve_torch_runtime()["device"]
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_length = encoded["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Strip the prompt tokens so only the fresh completion is decoded.
    completion_ids = generated[0][prompt_length:]
    response_text = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    action = parse_action_completion(response_text)
    if action is not None:
        # Terminal decisions may need defaults filled from the observation.
        action = ensure_terminal_payload(observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """Drive the full GRPO pipeline according to parsed CLI flags.

    Three modes:
      * ``--load-model-only``: load tokenizer/model, print diagnostics, stop.
      * ``--dry-run``: build prompts, score one reference action, print a
        preview, stop.
      * default: run GRPO training, save model/tokenizer/plots, and
        optionally push the output folder to the Hugging Face Hub.

    Returns:
        Dict of the artefacts built in the chosen mode.
    """
    random.seed(args.seed)
    runtime = resolve_torch_runtime()
    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
        )
        device = getattr(model, "device", "unknown")
        print(f"Model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }
    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    reward_fn = build_openenv_reward(args)
    if args.dry_run:
        run_dry_run_preview(examples, reward_fn, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": reward_fn,
        }
    # Lazy import: datasets is only required for a real training run.
    from datasets import Dataset
    train_dataset = Dataset.from_list(examples)
    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
    )
    print(
        f"Training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']}"
    )
    print(
        "OpenEnv reward: "
        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
        f"examples={len(examples)}"
    )
    trainer = build_grpo_trainer(
        model=model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        reward_func=reward_fn,
        args=args,
        runtime=runtime,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    if args.push_to_hub:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
    plot_paths = save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")
    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": reward_fn,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
def main() -> None:
    """CLI entry point: parse arguments and run the training pipeline."""
    cli_args = parse_args()
    run_training(cli_args)


if __name__ == "__main__":
    main()