# Source: drugenv-trainer / training/training_script.py
# (Hugging Face commit ad12dda, verified — anugrahteesdollar:
#  "fix: include requirements-train.txt + tests (glob bug)")
"""Train a drug-target-validation planner with TRL GRPO and OpenEnv rewards."""
from __future__ import annotations
import argparse
import json
import random
import re
from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
from client import DrugTargetEnv
from models import (
ActionType,
DrugTargetAction,
ValidationObservation,
build_agent_observation_context,
build_agent_system_prompt,
)
from server.hackathon_environment import DrugTargetEnvironment
from server.tasks.scenarios import SCENARIO_LIBRARY
# ── Defaults for CLI arguments ──────────────────────────────────────────
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
DEFAULT_OUTPUT_DIR = "training/grpo-output"
DEFAULT_BASE_URL = "http://localhost:8000"
DEFAULT_COMPLETION_TOKEN_BUDGET = 160
# Fixed penalties used by OpenEnvReward: unparseable completions and
# environment failures score negatively instead of crashing training.
INVALID_ACTION_PENALTY = -2.0
ENVIRONMENT_ERROR_PENALTY = -4.0
# Shared system prompt prepended to every training prompt.
SYSTEM_PROMPT = build_agent_system_prompt()
# Canonical action-type strings, plus common near-miss spellings the model
# tends to emit; normalize_action_type() consults this alias table.
ACTION_TYPES = [action.value for action in ActionType]
ACTION_TYPE_ALIASES: Dict[str, str] = {
    "expression_lookup": ActionType.QUERY_EXPRESSION.value,
    "tissue_expression": ActionType.QUERY_EXPRESSION.value,
    "gtex_query": ActionType.QUERY_EXPRESSION.value,
    "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "differential_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "pathway": ActionType.PATHWAY_ENRICHMENT.value,
    "coexpression": ActionType.COEXPRESSION_NETWORK.value,
    "structure_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "alphafold_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "binding_site": ActionType.BINDING_SITE_ANALYSIS.value,
    "ppi": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "interaction_network": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "druggability": ActionType.DRUGGABILITY_SCREEN.value,
    "trial_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "clinical_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "toxicity": ActionType.TOXICITY_PANEL.value,
    "tox_panel": ActionType.TOXICITY_PANEL.value,
    "off_target": ActionType.OFF_TARGET_SCREEN.value,
    "selectivity": ActionType.OFF_TARGET_SCREEN.value,
    "stratification": ActionType.PATIENT_STRATIFICATION.value,
    "patient_subset": ActionType.PATIENT_STRATIFICATION.value,
    "literature": ActionType.LITERATURE_SEARCH.value,
    "pubmed": ActionType.LITERATURE_SEARCH.value,
    "synthesis": ActionType.EVIDENCE_SYNTHESIS.value,
    "competitor": ActionType.COMPETITOR_LANDSCAPE.value,
    "in_vitro": ActionType.IN_VITRO_ASSAY.value,
    "cell_assay": ActionType.IN_VITRO_ASSAY.value,
    "in_vivo": ActionType.IN_VIVO_MODEL.value,
    "mouse_model": ActionType.IN_VIVO_MODEL.value,
    "crispr": ActionType.CRISPR_KNOCKOUT.value,
    "biomarker": ActionType.BIOMARKER_CORRELATION.value,
    "red_flag": ActionType.FLAG_RED_FLAG.value,
    "expert_review": ActionType.REQUEST_EXPERT_REVIEW.value,
    "submit_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "final_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "submit": ActionType.SUBMIT_VALIDATION_REPORT.value,
}
# Heuristic teacher policy used to seed GRPO prompts. Mirrors the default
# pipeline used by the inference script in ``run_agent.py``.
HEURISTIC_SEQUENCE: List[ActionType] = [
    ActionType.QUERY_EXPRESSION,
    ActionType.DRUGGABILITY_SCREEN,
    ActionType.OFF_TARGET_SCREEN,
    ActionType.TOXICITY_PANEL,
    ActionType.CLINICAL_TRIAL_LOOKUP,
    ActionType.LITERATURE_SEARCH,
    ActionType.SUBMIT_VALIDATION_REPORT,
]
# Set view of ACTION_TYPES for O(1) membership tests.
VALID_ACTION_TYPES = set(ACTION_TYPES)
def compact_preview(value: Any, max_chars: int = 160) -> str:
    """Render *value* as a single-line string capped at *max_chars* chars."""
    try:
        rendered = json.dumps(value, ensure_ascii=True, sort_keys=True)
    except TypeError:
        # Non-JSON-serialisable values fall back to their str() form.
        rendered = str(value)
    collapsed = re.sub(r"\s+", " ", rendered).strip()
    if len(collapsed) > max_chars:
        collapsed = collapsed[: max_chars - 3] + "..."
    return collapsed
def _edit_distance(a: str, b: str) -> int:
if len(a) < len(b):
return _edit_distance(b, a)
if not b:
return len(a)
prev = list(range(len(b) + 1))
for i, ca in enumerate(a):
curr = [i + 1]
for j, cb in enumerate(b):
curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb)))
prev = curr
return prev[-1]
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
    """Fetch the first matching key from *payload*.

    Lookup order: exact key, case-insensitive key, then fuzzy match via
    edit distance (to tolerate slightly misspelled keys). Returns ``None``
    when nothing matches.
    """
    # Pass 1: exact key hit.
    for candidate in names:
        if candidate in payload:
            return payload[candidate]
    # Pass 2: case-insensitive hit.
    lowered = {str(k).lower(): v for k, v in payload.items()}
    for candidate in names:
        folded = candidate.lower()
        if folded in lowered:
            return lowered[folded]
    # Pass 3: fuzzy hit — accept keys within a small edit distance.
    for key, value in lowered.items():
        for candidate in names:
            tolerance = max(2, len(candidate) // 3)
            if _edit_distance(key, candidate.lower()) <= tolerance:
                return value
    return None
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the CLI parser shared by ``parse_args`` and ``make_training_args``."""
    parser = argparse.ArgumentParser(
        description=(
            "Train a GRPO policy against the OpenEnv drug-target-validation "
            "environment."
        )
    )
    # ── Model, output, and data-collection knobs.
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--dataset-episodes", type=int, default=8)
    parser.add_argument("--rollout-steps", type=int, default=6)
    parser.add_argument(
        "--collection-policy",
        choices=["random", "heuristic"],
        default="heuristic",
        help="Policy used to build prompt states for GRPO training.",
    )
    # ── Reward-scoring backend selection.
    parser.add_argument(
        "--reward-backend",
        choices=["local", "remote"],
        default="local",
        help="Use local in-process scoring or a live OpenEnv server.",
    )
    parser.add_argument(
        "--base-url",
        default=DEFAULT_BASE_URL,
        help="Base URL for the OpenEnv server when reward-backend=remote.",
    )
    parser.add_argument(
        "--scenario-name",
        action="append",
        default=None,
        help="Repeatable scenario selector. Defaults to all curated scenarios.",
    )
    parser.add_argument(
        "--domain-randomise",
        action="store_true",
        help="Enable domain randomisation while building prompts and local rewards.",
    )
    # ── GRPO / optimisation hyper-parameters.
    parser.add_argument("--num-generations", type=int, default=2)
    parser.add_argument(
        "--max-completion-length",
        type=int,
        default=DEFAULT_COMPLETION_TOKEN_BUDGET,
    )
    parser.add_argument("--max-prompt-length", type=int, default=768)
    parser.add_argument("--per-device-train-batch-size", type=int, default=2)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--logging-steps", type=int, default=1)
    parser.add_argument("--save-steps", type=int, default=50)
    parser.add_argument(
        "--plot-metric-key",
        default=None,
        help="Optional extra metric key from trainer log history to plot.",
    )
    parser.add_argument("--seed", type=int, default=0)
    # ── Workflow shortcuts.
    parser.add_argument(
        "--load-model-only",
        action="store_true",
        help="Download and load the selected model and tokenizer, then exit.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True to model/tokenizer loading.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Build the prompt dataset and smoke-test the reward function without training.",
    )
    parser.add_argument(
        "--push-to-hub",
        type=str,
        default=None,
        help="HuggingFace Hub repo id to push the trained model to (e.g. 'myuser/my-model').",
    )
    # ── Live evidence hooks (drug-domain ``LiveTrainingCallback``).
    # All three knobs are best-effort: a missing ``transformers`` /
    # read-only ``evidence/`` dir reduces the callback to a no-op so
    # training itself still proceeds.
    parser.add_argument(
        "--evidence-dir",
        default="evidence",
        help="Directory for training_log.csv + checkpoint_evals.csv.",
    )
    parser.add_argument(
        "--checkpoint-eval-steps",
        type=int,
        default=50,
        help="Run a held-out heuristic eval every N GRPO updates.",
    )
    parser.add_argument(
        "--checkpoint-eval-episodes",
        type=int,
        default=4,
        help="Number of held-out episodes per mid-training eval.",
    )
    return parser
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments (``sys.argv`` is used when *argv* is None)."""
    parser = build_argument_parser()
    return parser.parse_args(argv)
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build an argparse-style namespace for notebooks and scripts.

    Raises ``ValueError`` when an override does not correspond to a known
    CLI flag, so typos fail loudly instead of being silently ignored.
    """
    defaults = vars(build_argument_parser().parse_args([]))
    unknown = sorted(key for key in overrides if key not in defaults)
    if unknown:
        raise ValueError(f"Unknown training args: {', '.join(unknown)}")
    merged = {**defaults, **overrides}
    return argparse.Namespace(**merged)
def format_observation(obs: ValidationObservation) -> str:
    """Render *obs* into the textual state block shown to the policy model.

    Sections, in order: target header, tool context, recent history,
    completed/suggested actions, dossier summaries, latest tool output,
    rule violations, and the strict JSON output instruction.
    """
    parts = [
        f"TARGET: {obs.target_gene} | INDICATION: {obs.indication}",
        f"DISEASE CONTEXT: {obs.disease_context}",
        f"Step: {obs.step_index} | Credits: "
        f"{obs.credits_remaining}/{obs.credits_total}",
    ]
    context = build_agent_observation_context(obs, max_tools=5)
    if context:
        parts.append(context)
    if obs.pipeline_history:
        # Only the five most recent steps are shown to bound prompt length.
        last5 = obs.pipeline_history[-5:]
        parts.append("Recent history:")
        for step in last5:
            tag = "OK" if step.get("success") else "FAIL"
            line = (
                f" [{tag}] {step.get('action_type')}: "
                f"{str(step.get('output_summary', ''))[:80]}"
            )
            parts.append(line)
    # Successful past actions steer the model away from repeats and toward
    # the remaining steps of the heuristic pipeline.
    completed = {
        step.get("action_type") for step in obs.pipeline_history if step.get("success")
    }
    if completed:
        parts.append(
            "Completed actions (try not to repeat): "
            + ", ".join(sorted(map(str, completed)))
        )
    remaining = [
        action.value for action in HEURISTIC_SEQUENCE
        if action.value not in completed
    ]
    if remaining:
        parts.append(f"Suggested next actions (pick one): {', '.join(remaining)}")
    # Dossier summaries, truncated to keep the prompt compact.
    dossier = obs.dossier
    if dossier.expression_findings:
        parts.append(
            f"Expression findings: {compact_preview(dossier.expression_findings, 200)}"
        )
    if dossier.protein_findings:
        parts.append(
            f"Protein findings: {compact_preview(dossier.protein_findings, 200)}"
        )
    if dossier.safety_findings:
        parts.append(
            f"Safety findings: {compact_preview(dossier.safety_findings, 200)}"
        )
    if dossier.clinical_findings:
        parts.append(
            f"Clinical findings: {compact_preview(dossier.clinical_findings, 200)}"
        )
    if dossier.experimental_results:
        parts.append(
            f"Experimental results: {compact_preview(dossier.experimental_results, 200)}"
        )
    if obs.latest_output and obs.latest_output.data:
        parts.append(
            f"Latest data: {compact_preview(obs.latest_output.data, 200)}"
        )
    if obs.rule_violations:
        parts.append(f"VIOLATIONS: {obs.rule_violations}")
    # Strict output contract: a single JSON object, nothing else.
    parts.append(
        'Output ONLY a single JSON object with these exact keys, no '
        'comments, no extra text:\n'
        '{"action_type": "<one of the available actions>", '
        '"parameters": {}, "reasoning": "<why>", '
        '"final_decision": null, "confidence": null}\n'
        'When picking submit_validation_report, set "final_decision" to '
        '"go" or "no_go" and "confidence" to a number in [0, 1].'
    )
    return "\n".join(parts)
def build_training_prompt(obs: ValidationObservation) -> str:
    """Prepend the shared system prompt to the formatted observation."""
    observation_text = format_observation(obs)
    return "\n\n".join([SYSTEM_PROMPT, observation_text])
def heuristic_next_action(
    history: Sequence[ActionType],
    step_index: int,
) -> ActionType:
    """Return the first pipeline action not yet taken, else submit."""
    taken = set(history)
    pending = (step for step in HEURISTIC_SEQUENCE if step not in taken)
    return next(pending, ActionType.SUBMIT_VALIDATION_REPORT)
def pick_action(
    policy: str,
    step_index: int,
    history: Sequence[ActionType],
) -> ActionType:
    """Choose the next data-collection action under *policy*."""
    if policy != "random":
        # Any non-"random" policy falls through to the heuristic teacher.
        return heuristic_next_action(history, step_index)
    return random.choice(list(ActionType))
def _heuristic_decision(obs: ValidationObservation) -> Tuple[str, float]:
    """Best-effort go / no-go from the running dossier.

    Looks at the most-recent druggability_score and clinical_precedent in
    the dossier, falls back to ``no_go`` with mid confidence when the
    evidence is empty.
    """
    decision, confidence = "no_go", 0.55
    protein = obs.dossier.protein_findings or {}
    drug_entry = protein.get(ActionType.DRUGGABILITY_SCREEN.value, {})
    score = drug_entry.get("druggability_score") if isinstance(drug_entry, dict) else None
    if isinstance(score, (int, float)):
        # Confidence scales with distance from the 0.55 decision threshold,
        # capped at 0.9.
        if score >= 0.55:
            decision = "go"
            confidence = float(min(0.9, 0.55 + score * 0.4))
        else:
            decision = "no_go"
            confidence = float(min(0.9, 0.55 + (1.0 - score) * 0.4))
    clinical_block = obs.dossier.clinical_findings or {}
    clinical = clinical_block.get(ActionType.CLINICAL_TRIAL_LOOKUP.value, {})
    precedent = clinical.get("clinical_precedent") if isinstance(clinical, dict) else None
    # Clinical precedent overrides the druggability-based call.
    if precedent == "negative":
        decision, confidence = "no_go", max(confidence, 0.7)
    elif precedent == "positive":
        decision, confidence = "go", max(confidence, 0.7)
    return decision, round(confidence, 3)
def build_drug_target_action(
    action_type: ActionType,
    obs: ValidationObservation,
) -> DrugTargetAction:
    """Construct a heuristically-reasonable action for the given type.

    Each action type gets canned parameters and a short reasoning string;
    SUBMIT_VALIDATION_REPORT additionally derives a go/no-go decision and
    confidence from the current dossier via ``_heuristic_decision``.
    """
    parameters: Dict[str, Any] = {}
    reasoning = f"Advance the investigation with {action_type.value}."
    final_decision: Optional[str] = None
    confidence: Optional[float] = None
    if action_type == ActionType.QUERY_EXPRESSION:
        parameters = {"database": "GTEx"}
        reasoning = "Establish baseline tissue expression for the target."
    elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
        parameters = {"cohort": "TCGA"}
        reasoning = "Confirm disease-driven dysregulation of the target."
    elif action_type == ActionType.PATHWAY_ENRICHMENT:
        parameters = {"library": "Reactome"}
        reasoning = "Map the target into known pathway context."
    elif action_type == ActionType.COEXPRESSION_NETWORK:
        parameters = {"source": "ARCHS4"}
        reasoning = "Find functionally-related genes for mechanism reasoning."
    elif action_type == ActionType.PROTEIN_STRUCTURE_LOOKUP:
        parameters = {"method": "AlphaFold"}
        reasoning = "Pull the target's predicted 3D structure."
    elif action_type == ActionType.BINDING_SITE_ANALYSIS:
        parameters = {"include_allosteric": True}
        reasoning = "Detect classic and allosteric pockets."
    elif action_type == ActionType.PROTEIN_INTERACTION_NETWORK:
        parameters = {"source": "STRING"}
        reasoning = "Map first-degree interactors for off-target reasoning."
    elif action_type == ActionType.DRUGGABILITY_SCREEN:
        parameters = {"source": "ChEMBL"}
        reasoning = "Score overall druggability and known ligand chemistry."
    elif action_type == ActionType.CLINICAL_TRIAL_LOOKUP:
        parameters = {"source": "ClinicalTrials_gov"}
        reasoning = "Check clinical precedent for this target / indication."
    elif action_type == ActionType.TOXICITY_PANEL:
        parameters = {"source": "ToxCast"}
        reasoning = "Probe target-mediated toxicity in light of expression."
    elif action_type == ActionType.OFF_TARGET_SCREEN:
        parameters = {"source": "SafetyPanel"}
        reasoning = "Quantify selectivity and paralog liabilities."
    elif action_type == ActionType.PATIENT_STRATIFICATION:
        parameters = {"source": "ClinVar"}
        reasoning = "Identify responder subpopulations."
    elif action_type == ActionType.LITERATURE_SEARCH:
        # The only action whose parameters depend on the observation.
        parameters = {
            "query": f"{obs.target_gene} {obs.indication}".strip()
        }
        reasoning = "Scan recent literature for precedent that updates priors."
    elif action_type == ActionType.EVIDENCE_SYNTHESIS:
        reasoning = "Aggregate prior dossier into a coherent picture."
    elif action_type == ActionType.COMPETITOR_LANDSCAPE:
        parameters = {"source": "DrugBank"}
        reasoning = "Survey competing programs against the same target."
    elif action_type == ActionType.IN_VITRO_ASSAY:
        parameters = {"panel": "InVitroPanel"}
        reasoning = "Confirm computational evidence with cell-line activity."
    elif action_type == ActionType.IN_VIVO_MODEL:
        parameters = {"model": "MouseModel"}
        reasoning = "Confirm in-vitro signal with disease-relevant in-vivo data."
    elif action_type == ActionType.CRISPR_KNOCKOUT:
        parameters = {"panel": "DepMap"}
        reasoning = "Test functional dependency of the target."
    elif action_type == ActionType.BIOMARKER_CORRELATION:
        parameters = {"source": "BiomarkerPanel"}
        reasoning = "Correlate target activity with patient biomarkers."
    elif action_type == ActionType.FLAG_RED_FLAG:
        parameters = {"note": "potential concern noted by the agent"}
        reasoning = "Record a concern without spending credits."
    elif action_type == ActionType.REQUEST_EXPERT_REVIEW:
        parameters = {"focus": "current_dossier"}
        reasoning = "Get a lightweight expert critique of the dossier."
    elif action_type == ActionType.SUBMIT_VALIDATION_REPORT:
        # Terminal action: derive decision + confidence from the dossier.
        decision, conf = _heuristic_decision(obs)
        final_decision = decision
        confidence = conf
        reasoning = (
            f"Submit {decision} with confidence {conf:.2f} based on the "
            "evidence gathered so far."
        )
    return DrugTargetAction(
        action_type=action_type,
        parameters=parameters,
        reasoning=reasoning,
        final_decision=final_decision,
        confidence=confidence,
    )
def selected_scenarios(requested: Optional[Sequence[str]]) -> List[str]:
    """Resolve the scenario list, validating any explicit selection.

    Raises ``ValueError`` when *requested* contains unknown names.
    """
    # Imported lazily so merely importing this module stays cheap.
    from server.tasks.procedural_generator import generate_procedural_scenarios

    procedural = generate_procedural_scenarios(n=20, seed=42)
    catalogue = list(SCENARIO_LIBRARY) + procedural
    available = [scenario.name for scenario in catalogue]
    if not requested:
        return available
    unknown = sorted(set(requested) - set(available))
    if unknown:
        raise ValueError(f"Unknown scenarios requested: {', '.join(unknown)}")
    return list(requested)
def action_completion_json(action: DrugTargetAction) -> str:
    """Serialise *action* into the JSON completion format used for training."""
    return json.dumps({
        "action_type": action.action_type.value,
        "parameters": action.parameters,
        "reasoning": action.reasoning,
        "final_decision": action.final_decision,
        "confidence": action.confidence,
    })
def build_prompt_examples(
    *,
    dataset_episodes: int,
    rollout_steps: int,
    collection_policy: str,
    scenario_names: Sequence[str],
    seed: int,
    domain_randomise: bool,
) -> List[Dict[str, str]]:
    """Roll out the teacher policy to collect GRPO prompt rows.

    Each environment step contributes one example with the rendered
    prompt, scenario name, JSON-encoded action history (used later to
    replay the episode when scoring rewards), the env rng seed, and the
    teacher's reference action.
    """
    rng = random.Random(seed)
    examples: List[Dict[str, str]] = []
    # Shuffle once, then cycle deterministically through scenarios.
    scenario_cycle = list(scenario_names)
    rng.shuffle(scenario_cycle)
    for episode_idx in range(dataset_episodes):
        scenario_name = scenario_cycle[episode_idx % len(scenario_cycle)]
        env = DrugTargetEnvironment(
            scenario_name=scenario_name,
            domain_randomise=domain_randomise,
        )
        obs = env.reset()
        history_actions: List[DrugTargetAction] = []
        for step_idx in range(rollout_steps):
            if obs.done:
                break
            next_action = build_drug_target_action(
                action_type=pick_action(
                    collection_policy,
                    step_idx,
                    [action.action_type for action in history_actions],
                ),
                obs=obs,
            )
            examples.append({
                "prompt": build_training_prompt(obs),
                "scenario_name": scenario_name,
                "history_actions": json.dumps(
                    [action.model_dump() for action in history_actions]
                ),
                # NOTE(review): reaches into the env's private ``_latent``
                # state for the rng seed — confirm this attribute is stable
                # across environment versions.
                "rng_seed": str(env._latent.rng_seed if env._latent else 0),
                "reference_action": action_completion_json(next_action),
                "problem_statement": (
                    f"Validate {obs.target_gene} in {obs.indication}"
                ),
            })
            history_actions.append(next_action)
            obs = env.step(next_action)
    return examples
def completion_to_text(completion: Any) -> str:
    """Coerce a TRL completion (str, message dict, or chat list) to text."""
    if isinstance(completion, str):
        return completion.strip()
    if isinstance(completion, dict):
        return content_to_text(completion.get("content", ""))
    if isinstance(completion, list):
        # Walk messages newest-first and return the first non-empty text.
        for entry in reversed(completion):
            if isinstance(entry, dict) and "content" in entry:
                extracted = content_to_text(entry["content"])
                if extracted:
                    return extracted
            if isinstance(entry, str) and entry.strip():
                return entry.strip()
    return str(completion).strip()
def content_to_text(content: Any) -> str:
    """Flatten chat-message content (str or list of parts) into plain text."""
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return str(content).strip()
    pieces: List[str] = []
    for chunk in content:
        if isinstance(chunk, str):
            pieces.append(chunk)
            continue
        if isinstance(chunk, dict):
            # Multimodal-style parts: prefer "text", fall back to "content".
            text = chunk.get("text")
            if isinstance(text, str):
                pieces.append(text)
                continue
            nested = chunk.get("content")
            if isinstance(nested, str):
                pieces.append(nested)
    return "".join(pieces).strip()
def _repair_truncated_json(text: str) -> Optional[str]:
"""Try to repair JSON truncated mid-value (common with small LLMs)."""
s = text.strip()
if not s.startswith("{"):
return None
s = re.sub(r',\s*"[^"\n]*$', '', s)
s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)
in_string = False
escape = False
for ch in s:
if escape:
escape = False
continue
if ch == "\\":
escape = True
continue
if ch == '"':
in_string = not in_string
if in_string:
s += '"'
open_braces = s.count("{") - s.count("}")
open_brackets = s.count("[") - s.count("]")
s += "]" * max(0, open_brackets)
s += "}" * max(0, open_braces)
try:
obj = json.loads(s)
if isinstance(obj, dict):
return s
except json.JSONDecodeError:
pass
s = re.sub(r',\s*([}\]])', r'\1', s)
try:
obj = json.loads(s)
if isinstance(obj, dict):
return s
except json.JSONDecodeError:
pass
return None
def _normalize_jsonish_text(text: str) -> str:
    """Convert Python-literal / JS-flavoured output into stricter JSON."""
    cleaned = _strip_js_comments(text)
    # Python literals -> JSON literals, but only in value position.
    for pattern, substitute in (
        (r'(?<=:\s)\bNone\b', 'null'),
        (r'(?<=:\s)\bTrue\b', 'true'),
        (r'(?<=:\s)\bFalse\b', 'false'),
    ):
        cleaned = re.sub(pattern, substitute, cleaned)
    # Repair `"key:"` typos where the colon slipped inside the quotes.
    return re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', cleaned)
def _strip_js_comments(text: str) -> str:
text = re.sub(r'//[^\n]*', '', text)
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
return text
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from model output *text*.

    Candidates come from: the whole (normalised, fence-stripped) text,
    every balanced ``{...}`` span, and a truncation-repaired variant.
    The longest candidate that parses to a dict wins.
    """
    stripped = _normalize_jsonish_text(text).strip()
    # Strip a surrounding Markdown code fence, if present.
    fence_prefix = "```"
    if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
        lines = stripped.splitlines()
        if len(lines) >= 3:
            stripped = "\n".join(lines[1:-1]).strip()
    candidates: List[str] = [stripped]
    # Collect the balanced brace span starting at each "{".
    start = stripped.find("{")
    while start != -1:
        depth = 0
        for idx in range(start, len(stripped)):
            char = stripped[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    candidates.append(stripped[start:idx + 1])
                    break
        start = stripped.find("{", start + 1)
    # Also try repairing a truncated object starting at the first brace.
    first_brace = stripped.find("{")
    if first_brace != -1:
        repaired = _repair_truncated_json(stripped[first_brace:])
        if repaired is not None:
            candidates.append(repaired)
    # Prefer the longest candidate that parses to a dict.
    candidates.sort(key=len, reverse=True)
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
def normalize_optional_string(value: Any) -> Optional[str]:
    """Coerce *value* to a non-empty string, or ``None``.

    Booleans are rejected (they subclass int); other structures are
    summarised via ``compact_preview``.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        stripped = value.strip()
        return stripped if stripped else None
    if isinstance(value, (int, float)):
        return str(value)
    return compact_preview(value, 80)
def normalize_action_type(raw_action_type: Any) -> Optional[str]:
    """Map a model-emitted action string onto a canonical action value.

    Tries, in order: exact value, curated alias, snake_case-normalised
    value/alias, then keyword-fragment heuristics. Returns ``None`` when
    nothing matches.
    """
    if not isinstance(raw_action_type, str):
        return None
    candidate = raw_action_type.strip().lower()
    # Membership against the precomputed set (VALID_ACTION_TYPES) instead
    # of scanning the ACTION_TYPES list — same contents, O(1) lookups.
    if candidate in VALID_ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]
    # Collapse punctuation/whitespace runs into underscores and retry.
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    if candidate in VALID_ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]
    # Last resort: substring heuristics; first matching entry wins, so
    # order matters.
    heuristics = [
        (("expression",), ActionType.QUERY_EXPRESSION.value),
        (("differential",), ActionType.DIFFERENTIAL_EXPRESSION.value),
        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
        (("coexpression",), ActionType.COEXPRESSION_NETWORK.value),
        (("structure",), ActionType.PROTEIN_STRUCTURE_LOOKUP.value),
        (("binding", "site"), ActionType.BINDING_SITE_ANALYSIS.value),
        (("interaction",), ActionType.PROTEIN_INTERACTION_NETWORK.value),
        (("druggab",), ActionType.DRUGGABILITY_SCREEN.value),
        (("clinical",), ActionType.CLINICAL_TRIAL_LOOKUP.value),
        (("toxic",), ActionType.TOXICITY_PANEL.value),
        (("off",), ActionType.OFF_TARGET_SCREEN.value),
        (("strat",), ActionType.PATIENT_STRATIFICATION.value),
        (("literature",), ActionType.LITERATURE_SEARCH.value),
        (("synthes",), ActionType.EVIDENCE_SYNTHESIS.value),
        (("competitor",), ActionType.COMPETITOR_LANDSCAPE.value),
        (("vitro",), ActionType.IN_VITRO_ASSAY.value),
        (("vivo",), ActionType.IN_VIVO_MODEL.value),
        (("crispr",), ActionType.CRISPR_KNOCKOUT.value),
        (("biomarker",), ActionType.BIOMARKER_CORRELATION.value),
        (("red", "flag"), ActionType.FLAG_RED_FLAG.value),
        (("review",), ActionType.REQUEST_EXPERT_REVIEW.value),
        (("submit",), ActionType.SUBMIT_VALIDATION_REPORT.value),
        (("report",), ActionType.SUBMIT_VALIDATION_REPORT.value),
    ]
    for fragments, normalized in heuristics:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None
def ensure_terminal_payload(
    obs: Optional[ValidationObservation],
    action: DrugTargetAction,
) -> DrugTargetAction:
    """Make sure a SUBMIT_VALIDATION_REPORT carries a decision + confidence.

    If the model omits either field, fill them via ``_heuristic_decision``
    when an observation is available, otherwise fall back to a neutral
    no_go @ 0.5.
    """
    if action.action_type != ActionType.SUBMIT_VALIDATION_REPORT:
        return action
    decision = action.final_decision
    conf = action.confidence
    if decision is None or conf is None:
        fallback = _heuristic_decision(obs) if obs is not None else ("no_go", 0.5)
        if decision is None:
            decision = fallback[0]
        if conf is None:
            conf = fallback[1]
    # Clamp confidence into [0, 1] before writing it back.
    clamped = max(0.0, min(1.0, float(conf)))
    return action.model_copy(update={
        "final_decision": decision,
        "confidence": float(clamped),
    })
def parse_action_completion(text: str) -> Optional[DrugTargetAction]:
    """Parse a model completion into a ``DrugTargetAction``.

    First tries JSON extraction via ``extract_json_object``; when no JSON
    object can be recovered, falls back to regex-scraping the individual
    fields. Returns ``None`` when no recognisable action type is present.
    """
    payload = extract_json_object(text)
    if payload is not None:
        action_type = normalize_action_type(get_payload_value(payload, "action_type"))
        if action_type is None:
            return None
        parameters = get_payload_value(payload, "parameters", "params") or {}
        if not isinstance(parameters, dict):
            parameters = {}
        raw_conf = get_payload_value(payload, "confidence")
        if raw_conf is None:
            confidence: Optional[float] = None
        else:
            try:
                # Clamp to the valid [0, 1] range.
                confidence = float(raw_conf)
                confidence = max(0.0, min(1.0, confidence))
            except (TypeError, ValueError):
                confidence = None
        reasoning = get_payload_value(
            payload, "reasoning", "rationale", "justification", "reason"
        )
        if reasoning is not None and not isinstance(reasoning, str):
            # Non-string reasoning payloads get summarised to text.
            reasoning = compact_preview(reasoning, 200)
        reasoning = reasoning or ""
        final_decision = normalize_optional_string(
            get_payload_value(payload, "final_decision", "decision", "go_no_go")
        )
        if final_decision is not None:
            final_decision = final_decision.lower().replace("-", "_")
            if final_decision not in {"go", "no_go"}:
                final_decision = None
        return DrugTargetAction(
            action_type=ActionType(action_type),
            parameters=parameters,
            reasoning=reasoning,
            final_decision=final_decision,
            confidence=confidence,
        )
    # Regex fallback: no JSON object could be recovered from the text.
    action_match = re.search(
        r'["\']action_type["\']\s*:\s*["\']([^"\']+)',
        text,
        re.IGNORECASE,
    )
    if not action_match:
        return None
    action_type = normalize_action_type(action_match.group(1))
    if action_type is None:
        return None
    confidence_match = re.search(
        r'["\']confidence["\']\s*:\s*([0-9]*\.?[0-9]+)',
        text,
        re.IGNORECASE,
    )
    reasoning_match = re.search(
        r'["\'](?:reasoning|justif\w*|rationale|reason)["\']\s*:\s*"((?:[^"\\]|\\.)*)',
        text,
        re.DOTALL | re.IGNORECASE,
    )
    decision_match = re.search(
        r'["\']final_decision["\']\s*:\s*["\']?(go|no[_-]?go)["\']?',
        text,
        re.IGNORECASE,
    )
    confidence = None
    if confidence_match:
        try:
            confidence = max(0.0, min(1.0, float(confidence_match.group(1))))
        except ValueError:
            confidence = None
    reasoning: Optional[str] = None
    if reasoning_match:
        try:
            # Decode JSON escape sequences in the captured reasoning text.
            reasoning = json.loads(f'"{reasoning_match.group(1)}"')
        except json.JSONDecodeError:
            reasoning = reasoning_match.group(1)
    final_decision: Optional[str] = None
    if decision_match:
        final_decision = (
            decision_match.group(1).lower().replace("-", "_").replace(" ", "_")
        )
        if final_decision not in {"go", "no_go"}:
            final_decision = None
    return DrugTargetAction(
        action_type=ActionType(action_type),
        parameters={},
        reasoning=reasoning or "",
        final_decision=final_decision,
        confidence=confidence,
    )
def decode_history_actions(history_actions: Optional[str]) -> List[DrugTargetAction]:
    """Rehydrate a JSON-encoded action-history column into action objects."""
    if not history_actions:
        return []
    decoded = json.loads(history_actions)
    actions: List[DrugTargetAction] = []
    for payload in decoded:
        # Non-dict entries are silently skipped, matching the encoder's shape.
        if isinstance(payload, dict):
            actions.append(DrugTargetAction(**payload))
    return actions
def normalise_column(values: Any, length: int) -> List[Any]:
    """Coerce a TRL dataset column into a list of exactly *length* items."""
    if values is None:
        return [None] * length
    if not isinstance(values, list):
        # Scalar columns broadcast across the whole batch.
        return [values] * length
    if len(values) == length:
        return values
    if len(values) == 1:
        return values * length
    # Truncate overly long columns; pad short ones with None.
    padding = [None] * max(0, length - len(values))
    return values[:length] + padding
class OpenEnvReward:
    """Reward function compatible with TRL GRPOTrainer.

    Instances are callable with the batch columns TRL passes through
    (``completions`` plus extra dataset columns) and return one scalar
    reward per completion. Scoring replays the recorded action history in
    a fresh environment, then applies the parsed action.
    """

    def __init__(
        self,
        *,
        reward_backend: str,
        base_url: str,
        invalid_action_penalty: float = INVALID_ACTION_PENALTY,
        environment_error_penalty: float = ENVIRONMENT_ERROR_PENALTY,
        domain_randomise: bool = False,
    ) -> None:
        # TRL reads __name__ when logging per-reward-function metrics.
        self.__name__ = "openenv_reward"
        self.reward_backend = reward_backend
        self.base_url = base_url
        self.invalid_action_penalty = invalid_action_penalty
        self.environment_error_penalty = environment_error_penalty
        self.domain_randomise = domain_randomise

    def __call__(
        self,
        completions: List[Any],
        scenario_name: Optional[List[str]] = None,
        history_actions: Optional[List[str]] = None,
        rng_seed: Optional[List[str]] = None,
        **_: Any,
    ) -> List[float]:
        """Score a batch of completions; returns one float per completion."""
        # Broadcast/pad the extra columns so they align with completions.
        scenario_names = normalise_column(scenario_name, len(completions))
        history_columns = normalise_column(history_actions, len(completions))
        seed_columns = normalise_column(rng_seed, len(completions))
        rewards: List[float] = []
        for completion, current_scenario, current_history, current_seed in zip(
            completions,
            scenario_names,
            history_columns,
            seed_columns,
        ):
            action = parse_action_completion(completion_to_text(completion))
            if action is None:
                # Unparseable output gets a fixed penalty.
                rewards.append(self.invalid_action_penalty)
                continue
            try:
                if self.reward_backend == "remote":
                    reward = self._score_remote(action, current_scenario, current_history)
                else:
                    reward = self._score_local(
                        action, current_scenario, current_history, current_seed
                    )
            except Exception:
                # Any environment failure maps to a penalty instead of
                # crashing the training loop.
                reward = self.environment_error_penalty
            rewards.append(float(reward))
        return rewards

    def _score_local(
        self,
        action: DrugTargetAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
        rng_seed: Optional[str] = None,
    ) -> float:
        """Replay history in an in-process env, then score *action*."""
        env = DrugTargetEnvironment(
            scenario_name=scenario_name,
            domain_randomise=self.domain_randomise,
        )
        # Re-seed so the replayed episode matches the prompt-time rollout.
        seed = int(rng_seed) if rng_seed else None
        obs = env.reset(seed=seed)
        for previous_action in decode_history_actions(history_actions):
            obs = env.step(previous_action)
            if obs.done:
                # Episode already ended during replay; use its reward.
                return float(obs.reward)
        action = ensure_terminal_payload(obs, action)
        obs = env.step(action)
        return float(obs.reward)

    def _score_remote(
        self,
        action: DrugTargetAction,
        scenario_name: Optional[str],
        history_actions: Optional[str],
    ) -> float:
        """Replay history against a live OpenEnv server, then score *action*."""
        with DrugTargetEnv(base_url=self.base_url) as env:
            # NOTE: scenario_name is accepted for API parity with _score_local
            # but the OpenEnv HTTP protocol does not yet support passing it
            # through reset(). The server uses its configured default.
            result = env.reset()
            obs = result.observation
            for previous_action in decode_history_actions(history_actions):
                result = env.step(previous_action)
                obs = result.observation
                if result.done:
                    return float(result.reward or 0.0)
            action = ensure_terminal_payload(obs, action)
            result = env.step(action)
            if result.reward is not None:
                return float(result.reward)
            # Fall back to the reward carried on the observation.
            return float(result.observation.reward)
def is_numeric_log_value(value: Any) -> bool:
    """True for real numbers, excluding bools (which subclass int)."""
    if isinstance(value, bool):
        return False
    return isinstance(value, Real)
def available_numeric_log_keys(log_history: Sequence[Dict[str, Any]]) -> List[str]:
    """Collect the sorted numeric metric keys seen anywhere in *log_history*."""
    found = set()
    for record in log_history:
        if not isinstance(record, dict):
            continue
        for key, value in record.items():
            # "step" is an index, not a metric, so it is excluded.
            if key != "step" and is_numeric_log_value(value):
                found.add(key)
    return sorted(found)
def extract_log_series(
    log_history: Sequence[Dict[str, Any]],
    key: Optional[str],
) -> List[Tuple[float, float]]:
    """Pull ``(step, value)`` points for *key* from the trainer log history.

    Entries without a numeric "step" get a synthetic running counter so
    the series is still plottable.
    """
    if not key:
        return []
    points: List[Tuple[float, float]] = []
    fallback_step = 0
    for record in log_history:
        if not isinstance(record, dict) or key not in record:
            continue
        value = record.get(key)
        if not is_numeric_log_value(value):
            continue
        logged_step = record.get("step")
        if is_numeric_log_value(logged_step):
            step = float(logged_step)
        else:
            # No numeric step logged: use the synthetic counter instead.
            fallback_step += 1
            step = float(fallback_step)
        points.append((step, float(value)))
    return points
def select_reward_key(log_history: Sequence[Dict[str, Any]]) -> Optional[str]:
    """Pick the most likely reward metric key from the trainer logs."""
    reward_keys = [
        key
        for key in available_numeric_log_keys(log_history)
        if "reward" in key.lower()
    ]
    if not reward_keys:
        return None
    by_lower = {key.lower(): key for key in reward_keys}
    # Well-known names win outright, in this priority order.
    for preferred in (
        "reward",
        "mean_reward",
        "reward_mean",
        "rewards/open_env_reward",
    ):
        if preferred in by_lower:
            return by_lower[preferred]
    # Otherwise favour top-level (no slash), short keys, alphabetically.
    return min(reward_keys, key=lambda key: ("/" in key, len(key), key))
def select_metric_key(
    log_history: Sequence[Dict[str, Any]],
    *,
    reward_key: Optional[str],
    requested_key: Optional[str] = None,
) -> Optional[str]:
    """Choose a secondary diagnostic metric to plot beside loss and reward.

    An explicit *requested_key* is validated against the logged numeric keys
    and returned verbatim; a ValueError listing the available keys is raised
    when it was never logged.  Otherwise a preference list of common GRPO/HF
    diagnostics is consulted, then any non-reward metric, then
    learning_rate/epoch, then None.
    """
    numeric_keys = available_numeric_log_keys(log_history)
    if requested_key:
        if requested_key in numeric_keys:
            return requested_key
        listing = ", ".join(numeric_keys) or "none"
        raise ValueError(
            f"Requested plot metric '{requested_key}' was not logged. "
            f"Available numeric keys: {listing}"
        )
    # Bookkeeping columns that make uninteresting plots.
    skip = {
        "epoch",
        "loss",
        "learning_rate",
        "step",
        "total_flos",
        "train_loss",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
    }
    if reward_key:
        skip.add(reward_key)
    numeric_set = set(numeric_keys)
    for favourite in (
        "kl",
        "objective/kl",
        "completion_length",
        "mean_completion_length",
        "grad_norm",
        "entropy",
        "accuracy",
        "learning_rate",
        "epoch",
    ):
        if favourite in numeric_set and favourite not in skip:
            return favourite
    # First (alphabetical) non-reward, non-bookkeeping metric.
    for name in numeric_keys:
        if name not in skip and "reward" not in name.lower():
            return name
    # Last resort: plot something rather than nothing.
    for fallback in ("learning_rate", "epoch"):
        if fallback in numeric_set:
            return fallback
    return None
def save_plot(
    path: Path,
    *,
    series: Sequence[Tuple[float, float]],
    title: str,
    ylabel: str,
) -> None:
    """Render a single (step, value) line chart to *path* as a PNG.

    An empty *series* still produces a figure — annotated with a
    placeholder message — so downstream dashboards always find the file.
    """
    import matplotlib

    matplotlib.use("Agg")  # headless backend: no display server required
    import matplotlib.pyplot as plt

    figure, axis = plt.subplots(figsize=(8, 4.5))
    if not series:
        axis.text(
            0.5,
            0.5,
            "No logged data available",
            ha="center",
            va="center",
            transform=axis.transAxes,
        )
    else:
        steps, values = zip(*series)
        axis.plot(steps, values, marker="o", linewidth=1.8)
    axis.set_title(title)
    axis.set_xlabel("Step")
    axis.set_ylabel(ylabel)
    axis.grid(True, alpha=0.3)
    figure.tight_layout()
    figure.savefig(path, dpi=150)
    plt.close(figure)
def save_training_plots(
    log_history: Sequence[Dict[str, Any]],
    output_dir: str | Path,
    metric_key: Optional[str] = None,
) -> Dict[str, str]:
    """Write loss/reward/metric PNGs, a combined dashboard, and a manifest.

    Args:
        log_history: the trainer's ``state.log_history`` entries.
        output_dir: directory for the PNGs and manifest (created if absent).
        metric_key: optional explicit metric for the third panel; validated
            by ``select_metric_key`` (raises ValueError when never logged).

    Returns:
        Mapping of plot name ("loss", "reward", "metric", "dashboard")
        to the file path written.
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    reward_key = select_reward_key(log_history)
    chosen_metric = select_metric_key(
        log_history,
        reward_key=reward_key,
        requested_key=metric_key,
    )

    # One row per panel: (name, series, title, ylabel).
    panels = [
        ("loss", extract_log_series(log_history, "loss"), "Training Loss", "Loss"),
        (
            "reward",
            extract_log_series(log_history, reward_key),
            f"Training Reward ({reward_key or 'not logged'})",
            "Reward",
        ),
        (
            "metric",
            extract_log_series(log_history, chosen_metric),
            f"Training Metric ({chosen_metric or 'not logged'})",
            chosen_metric or "Metric",
        ),
    ]

    plot_files = {
        "loss": out_dir / "training_loss.png",
        "reward": out_dir / "training_reward.png",
        "metric": out_dir / "training_metric.png",
        "dashboard": out_dir / "training_dashboard.png",
    }

    # Standalone per-panel charts.
    for name, series, title, ylabel in panels:
        save_plot(plot_files[name], series=series, title=title, ylabel=ylabel)

    # Stacked three-panel dashboard sharing the same data.
    figure, axes = plt.subplots(3, 1, figsize=(10, 12))
    for axis, (_, series, title, ylabel) in zip(axes, panels):
        if series:
            steps, values = zip(*series)
            axis.plot(steps, values, marker="o", linewidth=1.8)
        else:
            axis.text(
                0.5,
                0.5,
                "No logged data available",
                ha="center",
                va="center",
                transform=axis.transAxes,
            )
        axis.set_title(title)
        axis.set_xlabel("Step")
        axis.set_ylabel(ylabel)
        axis.grid(True, alpha=0.3)
    figure.tight_layout()
    figure.savefig(plot_files["dashboard"], dpi=150)
    plt.close(figure)

    manifest = {
        "available_numeric_keys": available_numeric_log_keys(log_history),
        "reward_key": reward_key,
        "metric_key": chosen_metric,
        "plots": {name: str(path) for name, path in plot_files.items()},
    }
    manifest_file = out_dir / "training_plot_manifest.json"
    manifest_file.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return manifest["plots"]
def run_dry_run_preview(
    examples: Sequence[Dict[str, str]],
    reward_fn: OpenEnvReward,
    output_dir: str,
) -> None:
    """Print a sanity-check summary for the first generated prompt state.

    Scores the stored reference action through the live reward function so
    the operator can confirm environment wiring before a full training run.

    Raises:
        ValueError: if no prompt examples were generated.
    """
    if not examples:
        raise ValueError("No training prompts were generated for the dry run.")
    first = examples[0]
    reference_completion = [{"role": "assistant", "content": first["reference_action"]}]
    reward_value = reward_fn(
        completions=[reference_completion],
        scenario_name=[first["scenario_name"]],
        history_actions=[first["history_actions"]],
    )[0]
    print(f"Built {len(examples)} prompt states.")
    print(f"Output directory: {Path(output_dir)}")
    print(f"Sample scenario: {first['scenario_name']}")
    print(f"Sample reward for reference action: {reward_value:+.3f}")
    print("\nSample prompt:\n")
    print(first["prompt"])
def resolve_torch_runtime() -> Dict[str, Any]:
    """Probe torch for the device/dtype configuration this host supports.

    Prefers bfloat16 on CUDA devices that report support, float16 on other
    CUDA devices, and float32 on CPU-only hosts.
    """
    import torch

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # Older torch builds may lack is_bf16_supported; treat as "no".
        bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    else:
        bf16_supported = False
    if bf16_supported:
        dtype = torch.bfloat16
    elif cuda_available:
        dtype = torch.float16
    else:
        dtype = torch.float32
    return {
        "use_cuda": cuda_available,
        "device": "cuda:0" if cuda_available else "cpu",
        "dtype": dtype,
        "bf16": bf16_supported,
        "fp16": cuda_available and not bf16_supported,
        "device_name": torch.cuda.get_device_name(0) if cuda_available else "cpu",
    }
def _guard_invalid_torchao_version() -> None:
"""Treat malformed torchao installs as unavailable for HF imports."""
import functools
import importlib.metadata as importlib_metadata
import sys
from packaging.version import InvalidVersion, Version
if getattr(importlib_metadata, "_openenv_torchao_guard_installed", False):
metadata_guard_installed = True
else:
original_version = importlib_metadata.version
def guarded_version(distribution_name: str) -> str:
version = original_version(distribution_name)
if distribution_name.lower() == "torchao":
try:
Version(version)
except InvalidVersion as exc:
raise importlib_metadata.PackageNotFoundError(
f"Malformed torchao version metadata: {version!r}"
) from exc
return version
importlib_metadata.version = guarded_version
importlib_metadata._openenv_torchao_guard_installed = True
metadata_guard_installed = False
import_utils = sys.modules.get("transformers.utils.import_utils")
if import_utils is not None and not getattr(
import_utils, "_openenv_torchao_guard_installed", False
):
original_is_package_available = import_utils._is_package_available
def guarded_is_package_available(
pkg_name: str,
return_version: bool = False,
):
if pkg_name != "torchao":
return original_is_package_available(pkg_name, return_version=return_version)
is_available, package_version = original_is_package_available(
pkg_name,
return_version=True,
)
if not is_available:
return (False, package_version) if return_version else (False, None)
try:
Version(package_version)
except InvalidVersion:
return (False, "0") if return_version else (False, None)
return (True, package_version) if return_version else (True, None)
min_version = getattr(import_utils, "TORCHAO_MIN_VERSION", "0")
@functools.lru_cache
def guarded_is_torchao_available(min_version_override: str = min_version) -> bool:
is_available, package_version = guarded_is_package_available(
"torchao",
return_version=True,
)
if not is_available:
return False
try:
return Version(package_version) >= Version(min_version_override)
except InvalidVersion:
return False
if hasattr(import_utils.is_torchao_available, "cache_clear"):
import_utils.is_torchao_available.cache_clear()
import_utils._is_package_available = guarded_is_package_available
import_utils.is_torchao_available = guarded_is_torchao_available
import_utils._openenv_torchao_guard_installed = True
transformers_utils = sys.modules.get("transformers.utils")
if transformers_utils is not None:
transformers_utils.is_torchao_available = guarded_is_torchao_available
if metadata_guard_installed and import_utils is None:
return
def _guard_partial_vllm_install() -> None:
"""Treat partial vLLM installs as unavailable for TRL imports."""
import functools
import importlib
try:
import trl.import_utils as trl_import_utils
except Exception:
return
if getattr(trl_import_utils, "_openenv_vllm_guard_installed", False):
return
def _has_usable_vllm() -> bool:
try:
importlib.import_module("vllm")
importlib.import_module("vllm.distributed.device_communicators.pynccl")
importlib.import_module("vllm.distributed.utils")
except Exception:
return False
return True
@functools.lru_cache
def guarded_is_vllm_available(*args: Any, **kwargs: Any) -> bool:
return _has_usable_vllm()
if hasattr(trl_import_utils.is_vllm_available, "cache_clear"):
trl_import_utils.is_vllm_available.cache_clear()
trl_import_utils.is_vllm_available = guarded_is_vllm_available
trl_import_utils._openenv_vllm_guard_installed = True
def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
):
    """Load the tokenizer and causal-LM weights for *model_id*.

    Applies the torchao import guard first, reuses EOS as the pad token when
    the tokenizer defines none, casts the model to the resolved runtime
    dtype, and moves it onto the resolved device.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    _guard_invalid_torchao_version()
    from transformers import AutoModelForCausalLM, AutoTokenizer

    runtime = resolve_torch_runtime()

    print(f"Loading tokenizer for {model_id} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        # Many causal LMs ship without a pad token; reuse EOS for batching.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model for {model_id} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
        torch_dtype=runtime["dtype"],
    )
    target_device = runtime["device"] if runtime["use_cuda"] else "cpu"
    model = model.to(target_device)
    return tokenizer, model
def build_openenv_reward(args: argparse.Namespace) -> OpenEnvReward:
    """Return the OpenEnv-compatible reward callable used by GRPO."""
    reward = OpenEnvReward(
        reward_backend=args.reward_backend,
        base_url=args.base_url,
        domain_randomise=args.domain_randomise,
    )
    return reward
def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the OpenEnv rollout states that seed GRPO prompts.

    Returns:
        Dict with "scenario_names" (the resolved scenario list) and
        "examples" (the rollout-derived prompt records).
    """
    names = selected_scenarios(args.scenario_name)
    rollout_examples = build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=names,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return {"scenario_names": names, "examples": rollout_examples}
def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Construct a GRPOConfig, dropping fields this TRL version rejects.

    Different TRL releases expose different GRPOConfig fields, so the
    desired kwargs are filtered against the constructor signature.  When a
    release replaced max_prompt_length/max_completion_length with a single
    max_length, the combined token budget is supplied instead.
    """
    import inspect

    _guard_invalid_torchao_version()
    _guard_partial_vllm_install()
    from trl import GRPOConfig

    desired = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "max_prompt_length": args.max_prompt_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }
    supported = set(inspect.signature(GRPOConfig.__init__).parameters)
    uses_combined_length = (
        "max_length" in supported
        and "max_prompt_length" not in supported
        and "max_completion_length" not in supported
    )
    if uses_combined_length:
        desired["max_length"] = args.max_prompt_length + args.max_completion_length
    accepted = {name: value for name, value in desired.items() if name in supported}
    dropped = sorted(set(desired) - set(accepted))
    if dropped:
        print(
            "GRPOConfig compatibility: skipping unsupported fields "
            f"{', '.join(dropped)}"
        )
    return GRPOConfig(**accepted)
def build_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Assemble a GRPOTrainer wired to the OpenEnv reward function."""
    _guard_invalid_torchao_version()
    _guard_partial_vllm_install()
    from trl import GRPOTrainer

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=build_grpo_config(args, runtime),
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )
    # Attach the live evidence callback so the trainer Space dashboard gets
    # per-step CSV updates while the run is in flight.  The callback only
    # writes through DrugEnv's existing public API, so it cannot impact
    # training correctness — failures are swallowed inside the callback.
    try:
        from training.live_callback import LiveTrainingCallback

        callback = LiveTrainingCallback(
            evidence_dir=getattr(args, "evidence_dir", "evidence"),
            checkpoint_eval_steps=getattr(args, "checkpoint_eval_steps", 50),
            checkpoint_eval_episodes=getattr(args, "checkpoint_eval_episodes", 4),
        )
        trainer.add_callback(callback)
    except Exception as exc:
        print(f"[live-callback] not installed: {exc}")
    return trainer
def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | ValidationObservation,
    *,
    max_new_tokens: int = DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    """Sample one completion from *model* and parse it into an env action.

    Accepts either a ready prompt string or a ValidationObservation (which
    is rendered into the training prompt).  A successfully parsed action is
    normalised with ``ensure_terminal_payload`` before being returned.

    Returns:
        Dict with keys "prompt", "response_text", and "action" (None when
        the completion could not be parsed).
    """
    import torch

    if isinstance(prompt_or_observation, ValidationObservation):
        observation: Optional[ValidationObservation] = prompt_or_observation
        prompt = build_training_prompt(prompt_or_observation)
    else:
        observation = None
        prompt = str(prompt_or_observation)

    device = getattr(model, "device", None)
    if device is None:
        device = resolve_torch_runtime()["device"]

    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_length = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Strip the prompt tokens; decode only the newly generated tail.
    completion_ids = generated[0][prompt_length:]
    response_text = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    action = parse_action_completion(response_text)
    if action is not None:
        action = ensure_terminal_payload(observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """Top-level driver: prepare data, train with GRPO, export artifacts.

    Three modes, selected by CLI flags:
      * ``--load-model-only``: load tokenizer/model, print a summary, return.
      * ``--dry-run``: build prompts, score one reference action, return.
      * default: full GRPO training, model save, optional Hub push, plots.

    Returns:
        Dict of run artifacts; the set of keys depends on the mode taken.
    """
    random.seed(args.seed)
    runtime = resolve_torch_runtime()
    # Mode 1: smoke-test model loading without building any training data.
    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
        )
        device = getattr(model, "device", "unknown")
        print(f"Model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }
    # Build rollout-derived prompt states and the OpenEnv reward callable.
    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    reward_fn = build_openenv_reward(args)
    # Mode 2: preview one prompt/reward pair without loading the model.
    if args.dry_run:
        run_dry_run_preview(examples, reward_fn, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": reward_fn,
        }
    # Mode 3: full training run.  Import datasets lazily so the lighter
    # modes above do not require it.
    from datasets import Dataset
    train_dataset = Dataset.from_list(examples)
    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
    )
    print(
        f"Training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']}"
    )
    print(
        "OpenEnv reward: "
        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
        f"examples={len(examples)}"
    )
    trainer = build_grpo_trainer(
        model=model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        reward_func=reward_fn,
        args=args,
        runtime=runtime,
    )
    trainer.train()
    # Persist the trained model and tokenizer side by side.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    # Optionally mirror the output directory to the Hugging Face Hub.
    if args.push_to_hub:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
    # Render loss/reward/metric plots from the trainer's log history.
    plot_paths = save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")
    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": reward_fn,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
def main() -> None:
    """CLI entry point: parse arguments and launch the training run."""
    cli_args = parse_args()
    run_training(cli_args)


if __name__ == "__main__":
    main()