"""Live (mid-training) evidence callback for the trainer Space dashboard. Bolt-on `transformers.TrainerCallback` that streams two CSVs while a GRPO run is in flight: * ``evidence/training_log.csv`` — one row per logging step, with all numeric values surfaced by TRL's log-history dict (``loss``, ``reward``, ``reward_std``, ``kl``, ``grad_norm``, ``learning_rate``, ``epoch``, ``frac_reward_zero_std``, ``completions/mean_length``, …). The columns are discovered dynamically from ``state.log_history`` — we never hard-code TRL's exact key set because it has shifted across releases. * ``evidence/checkpoint_evals.csv`` — every ``checkpoint_eval_steps`` GRPO updates we run a short held-out rollout against ``DrugTargetEnvironment`` (using the same heuristic policy ``training/training_script.py::heuristic_next_action`` would pick if GRPO were silent), score it through DrugEnv's existing ``RewardComputer`` / ``RuleEngine`` / ``TransitionEngine`` so the reward column names match ``RewardBreakdown.to_dict()`` 1-to-1, and append a row. On train-end, the existing post-hoc ``save_training_plots`` from ``training/training_script.py`` is invoked again so the plots ``training_dashboard.png`` / ``training_loss.png`` / ``training_reward.png`` / ``training_metric.png`` are refreshed for the final state too. We deliberately do NOT replace that post-hoc plotter — both paths coexist. """ from __future__ import annotations import csv import logging import threading import time from pathlib import Path from typing import Any, Dict, List, Optional, Sequence from models import ActionType, DrugTargetAction, ValidationObservation logger = logging.getLogger(__name__) # Reward component column names mirror ``server.rewards.reward # .RewardBreakdown.to_dict()`` so a reviewer can pivot the CSV against # the dashboard cards without renaming columns. ``term_*`` prefixed # duplicates are emitted by ``DrugTargetEnvironment.step`` to # disambiguate per-step vs terminal contributions; we keep both shapes. _REWARD_COLUMNS = ( "evidence_coverage", "decision_accuracy", "credit_efficiency", "reasoning_coherence", "novelty", "penalty", "shaping", "terminal", "total", ) def _ensure_dir(p: Path) -> None: p.parent.mkdir(parents=True, exist_ok=True) class _AppendOnlyCsv: """Minimal append-only CSV with a thread-safe header rewrite when new fields show up after the file was first opened. GRPO's ``state.log_history`` adds keys as different code paths fire (e.g. ``frac_reward_zero_std`` only appears once a saturated batch is logged), so we cannot fix the header at construction time the way ``csv.DictWriter`` would prefer. Instead we accept the union of seen-so-far keys, rewrite the header in place when it grows, and pad missing fields with empty strings. """ def __init__(self, path: Path, base_fields: Sequence[str]) -> None: self.path = Path(path) _ensure_dir(self.path) self._lock = threading.Lock() self._fields: List[str] = list(base_fields) if not self.path.exists(): with open(self.path, "w", newline="") as f: csv.DictWriter(f, fieldnames=self._fields).writeheader() def append(self, row: Dict[str, Any]) -> None: with self._lock: new_keys = [k for k in row if k not in self._fields] if new_keys: # Header grew — rewrite the whole file in place. 

def _try_callback_class():
    try:
        from transformers import TrainerCallback  # type: ignore

        return TrainerCallback
    except Exception as exc:  # pragma: no cover - tested only inside Spaces
        logger.warning(
            "transformers.TrainerCallback unavailable (%s); LiveTrainingCallback "
            "will degrade to a no-op.",
            exc,
        )
        return object  # plain stand-in; the trainer will simply not call our hooks


_TrainerCallback = _try_callback_class()


class LiveTrainingCallback(_TrainerCallback):  # type: ignore[misc]
    """Streams ``training_log.csv`` + ``checkpoint_evals.csv`` during GRPO.

    Wire it into a TRL ``GRPOTrainer`` via
    ``callbacks=[LiveTrainingCallback(...)]`` after building the trainer;
    see ``training/training_script.py`` for the canonical wiring. The
    callback only writes through DrugEnv's existing public API
    (``DrugTargetEnvironment``, ``RewardBreakdown.to_dict``,
    ``training/training_script.py``'s heuristic policy), so future tweaks
    to the env or reward shape do not require touching this file.
    """

    def __init__(
        self,
        *,
        evidence_dir: str = "evidence",
        checkpoint_eval_steps: int = 50,
        checkpoint_eval_episodes: int = 4,
        difficulty: Optional[str] = None,
        scenario_name: Optional[str] = None,
    ) -> None:
        self.evidence_dir = Path(evidence_dir)
        self.checkpoint_eval_steps = max(1, int(checkpoint_eval_steps))
        self.checkpoint_eval_episodes = max(1, int(checkpoint_eval_episodes))
        self.difficulty = difficulty
        self.scenario_name = scenario_name
        self.evidence_dir.mkdir(parents=True, exist_ok=True)
        self._training_log = _AppendOnlyCsv(
            self.evidence_dir / "training_log.csv",
            base_fields=("step", "wall_time_s"),
        )
        self._checkpoint_log = _AppendOnlyCsv(
            self.evidence_dir / "checkpoint_evals.csv",
            base_fields=(
                "step",
                "fraction_done",
                "episodes",
                "mean_reward",
                "success_rate",
                "decision_accuracy_rate",
                "evidence_coverage_rate",
                "report_submitted_rate",
            ),
        )
        self._t0 = time.time()
        self._last_eval_step = -1
        self._latest_log_keys: List[str] = []

    # ── transformers TrainerCallback hooks ──────────────────────────

    def on_log(self, _args, state, control, logs=None, **kw):
        logs = logs or {}
        row: Dict[str, Any] = {
            "step": getattr(state, "global_step", None),
            "wall_time_s": round(time.time() - self._t0, 2),
        }
        for k, v in logs.items():
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                # CSV-friendly: replace '/' in TRL keys (e.g. 'rewards/mean')
                # with a flat dotted name so spreadsheet tools don't choke.
                row[k.replace("/", ".")] = v
        self._latest_log_keys = list(row.keys())
        self._training_log.append(row)
        return control

    def on_step_end(self, _args, state, control, **kw):
        step = getattr(state, "global_step", 0)
        if step <= 0 or step == self._last_eval_step:
            return control
        if step % self.checkpoint_eval_steps != 0:
            return control
        self._last_eval_step = step
        try:
            self._run_checkpoint_eval(step, state)
        except Exception as exc:
            logger.warning("checkpoint eval failed at step %d: %s", step, exc)
        return control

    def on_train_end(self, _args, state, control, **kw):
        # Best-effort: ask DrugEnv's existing post-hoc plotter to refresh
        # the static PNG dashboard. Imported lazily so import failures do
        # not break the in-flight run.
        try:
            from training.training_script import save_training_plots

            save_training_plots(
                getattr(state, "log_history", []) or [],
                str(self.evidence_dir),
            )
        except Exception as exc:  # pragma: no cover - plotting is optional
            logger.warning("post-hoc plot refresh failed: %s", exc)
        return control

    # ── checkpoint-eval rollout ─────────────────────────────────────

    def _run_checkpoint_eval(self, step: int, state) -> None:
        """Run a few short held-out rollouts with the heuristic policy
        and score each episode through DrugEnv's ``RewardBreakdown``."""
        from server.hackathon_environment import DrugTargetEnvironment, MAX_STEPS
        from training.training_script import (
            build_drug_target_action,
            heuristic_next_action,
        )

        episodes_summary: List[Dict[str, Any]] = []
        held_out_seed_base = 900_000
        for i in range(self.checkpoint_eval_episodes):
            env = DrugTargetEnvironment(
                scenario_name=self.scenario_name,
                domain_randomise=False,
            )
            obs: ValidationObservation = env.reset(seed=held_out_seed_base + i)
            history: List[ActionType] = []
            cumulative = 0.0
            steps = 0
            while not obs.done and steps < MAX_STEPS:
                next_at = heuristic_next_action(history, steps)
                action: DrugTargetAction = build_drug_target_action(next_at, obs)
                obs = env.step(action)
                history.append(action.action_type)
                cumulative += float(obs.reward or 0.0)
                steps += 1

            # Pull ground-truth flags off the latent — the heuristic
            # only "succeeds" if it ended in a submitted report whose
            # decision matched the hidden ``correct_decision``.
            latent = env._latent
            correct = (
                latent is not None
                and latent.progress.report_submitted
                and any(
                    h == ActionType.SUBMIT_VALIDATION_REPORT for h in history
                )
            )
            terminal_breakdown = (obs.metadata or {}).get("reward_breakdown", {}) or {}
            evidence_cov = float(
                terminal_breakdown.get("term_evidence_coverage")
                or terminal_breakdown.get("evidence_coverage")
                or 0.0
            )
            decision_acc = float(
                terminal_breakdown.get("term_decision_accuracy")
                or terminal_breakdown.get("decision_accuracy")
                or 0.0
            )
            episodes_summary.append({
                "reward": cumulative,
                "decision_accuracy": decision_acc,
                "evidence_coverage": evidence_cov,
                "submitted": bool(latent and latent.progress.report_submitted),
                "correct": bool(correct and decision_acc >= 0.5),
            })

        n = len(episodes_summary)
        if n == 0:
            return
        mean_reward = sum(e["reward"] for e in episodes_summary) / n
        success_rate = sum(1 for e in episodes_summary if e["correct"]) / n
        decision_rate = sum(e["decision_accuracy"] for e in episodes_summary) / n
        coverage_rate = sum(e["evidence_coverage"] for e in episodes_summary) / n
        submitted_rate = sum(1 for e in episodes_summary if e["submitted"]) / n
        max_steps = getattr(state, "max_steps", None) or step
        self._checkpoint_log.append({
            "step": step,
            "fraction_done": round(step / max(max_steps, 1), 4),
            "episodes": n,
            "mean_reward": round(mean_reward, 4),
            "success_rate": round(success_rate, 4),
            "decision_accuracy_rate": round(decision_rate, 4),
            "evidence_coverage_rate": round(coverage_rate, 4),
            "report_submitted_rate": round(submitted_rate, 4),
        })
        logger.info(
            "[checkpoint-eval step=%d] reward=%.3f success=%.2f decision=%.2f "
            "coverage=%.2f submitted=%.2f (heuristic policy, n=%d)",
            step,
            mean_reward,
            success_rate,
            decision_rate,
            coverage_rate,
            submitted_rate,
            n,
        )
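
# A minimal wiring sketch matching the class docstring, kept as a comment.
# The ``GRPOConfig`` / ``GRPOTrainer`` arguments below (``model``,
# ``dataset``, ``reward_fn``) are assumptions about the caller's setup;
# training/training_script.py remains the canonical wiring.
#
#     from trl import GRPOConfig, GRPOTrainer
#
#     trainer = GRPOTrainer(
#         model=model,
#         reward_funcs=[reward_fn],
#         args=GRPOConfig(output_dir="runs/grpo", logging_steps=10),
#         train_dataset=dataset,
#         callbacks=[LiveTrainingCallback(checkpoint_eval_steps=50)],
#     )
#     trainer.train()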

__all__ = ["LiveTrainingCallback"]
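
# Reading the streamed evidence from a notebook while the run is still in
# flight: a sketch assuming pandas is available in the Space. Columns
# beyond the base fields depend on which TRL keys have been logged so far
# (slashes are flattened to dots by ``on_log``).
#
#     import pandas as pd
#
#     live = pd.read_csv("evidence/training_log.csv")
#     print(live[["step", "wall_time_s"]].tail())
#
#     evals = pd.read_csv("evidence/checkpoint_evals.csv")
#     print(evals[["step", "mean_reward", "success_rate"]].tail())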