| """Live (mid-training) evidence callback for the trainer Space dashboard. | |
| Bolt-on `transformers.TrainerCallback` that streams two CSVs while a | |
| GRPO run is in flight: | |
| * ``evidence/training_log.csv`` β one row per logging step, with all | |
| numeric values surfaced by TRL's log-history dict (``loss``, | |
| ``reward``, ``reward_std``, ``kl``, ``grad_norm``, ``learning_rate``, | |
| ``epoch``, ``frac_reward_zero_std``, ``completions/mean_length``, β¦). | |
| The columns are discovered dynamically from ``state.log_history`` β | |
| we never hard-code TRL's exact key set because it has shifted across | |
| releases. | |
| * ``evidence/checkpoint_evals.csv`` β every ``checkpoint_eval_steps`` | |
| GRPO updates we run a short held-out rollout against | |
| ``DrugTargetEnvironment`` (using the same heuristic policy | |
| ``training/training_script.py::heuristic_next_action`` would pick if | |
| GRPO were silent), score it through DrugEnv's existing | |
| ``RewardComputer`` / ``RuleEngine`` / ``TransitionEngine`` so the | |
| reward column names match ``RewardBreakdown.to_dict()`` 1-to-1, and | |
| append a row. | |
| On train-end, the existing post-hoc ``save_training_plots`` from | |
| ``training/training_script.py`` is invoked again so the plots | |
| ``training_dashboard.png`` / ``training_loss.png`` / | |
| ``training_reward.png`` / ``training_metric.png`` are refreshed for the | |
| final state too. We deliberately do NOT replace that post-hoc plotter | |
| β both paths coexist. | |
| """ | |

from __future__ import annotations

import csv
import logging
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

from models import ActionType, DrugTargetAction, ValidationObservation

logger = logging.getLogger(__name__)

# Reward component column names mirror ``server.rewards.reward
# .RewardBreakdown.to_dict()`` so a reviewer can pivot the CSV against
# the dashboard cards without renaming columns. ``term_*`` prefixed
# duplicates are emitted by ``DrugTargetEnvironment.step`` to
# disambiguate per-step vs terminal contributions; we keep both shapes.
_REWARD_COLUMNS = (
    "evidence_coverage",
    "decision_accuracy",
    "credit_efficiency",
    "reasoning_coherence",
    "novelty",
    "penalty",
    "shaping",
    "terminal",
    "total",
)
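
# Illustrative only: the two shapes of a ``reward_breakdown`` dict this
# module reads back from ``obs.metadata`` (keys are real, values are
# invented for the example, not captured from a run):
#   {"evidence_coverage": 0.6, "decision_accuracy": 1.0, ..., "total": 2.1,
#    "term_evidence_coverage": 0.6, "term_decision_accuracy": 1.0}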


def _ensure_dir(p: Path) -> None:
    p.parent.mkdir(parents=True, exist_ok=True)


class _AppendOnlyCsv:
    """Minimal append-only CSV with a thread-safe header rewrite when
    new fields show up after the file was first opened.

    GRPO's ``state.log_history`` adds keys as different code paths fire
    (e.g. ``frac_reward_zero_std`` only appears once a saturated batch
    is logged), so we cannot fix the header at construction time the
    way ``csv.DictWriter`` would prefer. Instead we accept the union
    of seen-so-far keys, rewrite the header in place when it grows,
    and pad missing fields with empty strings.
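
    A minimal usage sketch (illustrative; in the Space only the callback
    below constructs instances)::

        log = _AppendOnlyCsv(Path("evidence/demo.csv"), base_fields=("step",))
        log.append({"step": 1})               # header unchanged: step
        log.append({"step": 2, "loss": 0.4})  # header rewritten: step,loss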
| """ | |

    def __init__(self, path: Path, base_fields: Sequence[str]) -> None:
        self.path = Path(path)
        _ensure_dir(self.path)
        self._lock = threading.Lock()
        self._fields: List[str] = list(base_fields)
        if not self.path.exists():
            with open(self.path, "w", newline="") as f:
                csv.DictWriter(f, fieldnames=self._fields).writeheader()
        else:
            # Resumed run: adopt any columns already on disk so the next
            # header rewrite does not silently drop them.
            with open(self.path, newline="") as f:
                header = next(csv.reader(f), None)
            if header:
                self._fields = list(dict.fromkeys([*header, *self._fields]))

    def append(self, row: Dict[str, Any]) -> None:
        with self._lock:
            new_keys = [k for k in row if k not in self._fields]
            if new_keys:
                # Header grew: rewrite the whole file in place.
                self._fields.extend(new_keys)
                existing: List[Dict[str, str]] = []
                if self.path.exists():
                    with open(self.path, newline="") as f:
                        existing = list(csv.DictReader(f))
                with open(self.path, "w", newline="") as f:
                    w = csv.DictWriter(f, fieldnames=self._fields)
                    w.writeheader()
                    for r in existing:
                        w.writerow({k: r.get(k, "") for k in self._fields})
            with open(self.path, "a", newline="") as f:
                w = csv.DictWriter(f, fieldnames=self._fields)
                w.writerow({k: row.get(k, "") for k in self._fields})


def _try_callback_class():
    try:
        from transformers import TrainerCallback  # type: ignore

        return TrainerCallback
    except Exception as exc:  # pragma: no cover - tested only inside Spaces
        logger.warning(
            "transformers.TrainerCallback unavailable (%s); LiveTrainingCallback "
            "will degrade to a no-op.",
            exc,
        )
        return object  # plain stand-in; trainer will simply not call our hooks


_TrainerCallback = _try_callback_class()


class LiveTrainingCallback(_TrainerCallback):  # type: ignore[misc]
    """Streams ``training_log.csv`` + ``checkpoint_evals.csv`` during GRPO.

    Wire it into a TRL ``GRPOTrainer`` via ``callbacks=[LiveTrainingCallback(...)]``
    after building the trainer; see ``training/training_script.py`` for the
    canonical wiring. The callback only writes through DrugEnv's existing
    public API (``DrugTargetEnvironment``, ``RewardBreakdown.to_dict``,
    ``training/training_script.py``'s heuristic policy) so future tweaks
    to the env or reward shape do not require touching this file.
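
    A wiring sketch (illustrative; ``model``/``grpo_config``/``train_dataset``/
    ``reward_fn`` are placeholders and TRL's ``GRPOTrainer`` signature may
    differ across releases)::

        from trl import GRPOTrainer

        trainer = GRPOTrainer(
            model=model,
            args=grpo_config,
            train_dataset=train_dataset,
            reward_funcs=reward_fn,
            callbacks=[LiveTrainingCallback(checkpoint_eval_steps=50)],
        )
        trainer.train()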
| """ | |

    def __init__(
        self,
        *,
        evidence_dir: str = "evidence",
        checkpoint_eval_steps: int = 50,
        checkpoint_eval_episodes: int = 4,
        difficulty: Optional[str] = None,
        scenario_name: Optional[str] = None,
    ) -> None:
        self.evidence_dir = Path(evidence_dir)
        self.checkpoint_eval_steps = max(1, int(checkpoint_eval_steps))
        self.checkpoint_eval_episodes = max(1, int(checkpoint_eval_episodes))
        self.difficulty = difficulty
        self.scenario_name = scenario_name
        self.evidence_dir.mkdir(parents=True, exist_ok=True)
        self._training_log = _AppendOnlyCsv(
            self.evidence_dir / "training_log.csv",
            base_fields=("step", "wall_time_s"),
        )
        self._checkpoint_log = _AppendOnlyCsv(
            self.evidence_dir / "checkpoint_evals.csv",
            base_fields=(
                "step", "fraction_done", "episodes",
                "mean_reward", "success_rate",
                "decision_accuracy_rate", "evidence_coverage_rate",
                "report_submitted_rate",
            ),
        )
        self._t0 = time.time()
        self._last_eval_step = -1
        self._latest_log_keys: List[str] = []

    # ── transformers TrainerCallback hooks ──────────────────────────

    def on_log(self, _args, state, control, logs=None, **kw):
        logs = logs or {}
        row: Dict[str, Any] = {
            "step": getattr(state, "global_step", None),
            "wall_time_s": round(time.time() - self._t0, 2),
        }
        for k, v in logs.items():
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                # CSV-friendly: replace '/' in TRL keys (e.g. 'rewards/mean')
                # with a flat dotted name so spreadsheet tools don't choke.
                row[k.replace("/", ".")] = v
        self._latest_log_keys = list(row.keys())
        self._training_log.append(row)
        return control

    def on_step_end(self, _args, state, control, **kw):
        step = getattr(state, "global_step", 0)
        if step <= 0 or step == self._last_eval_step:
            return control
        if step % self.checkpoint_eval_steps != 0:
            return control
        self._last_eval_step = step
        try:
            self._run_checkpoint_eval(step, state)
        except Exception as exc:
            logger.warning("checkpoint eval failed at step %d: %s", step, exc)
        return control

    def on_train_end(self, _args, state, control, **kw):
        # Best-effort: ask DrugEnv's existing post-hoc plotter to refresh
        # the static PNG dashboard. Imported lazily so import failures do
        # not break the in-flight run.
        try:
            from training.training_script import save_training_plots

            save_training_plots(
                getattr(state, "log_history", []) or [],
                str(self.evidence_dir),
            )
        except Exception as exc:  # pragma: no cover - plotting is optional
            logger.warning("post-hoc plot refresh failed: %s", exc)
        return control

    # ── checkpoint-eval rollout ─────────────────────────────────────

    def _run_checkpoint_eval(self, step: int, state) -> None:
        """Run a tiny held-out rollout using the heuristic policy and
        score the final episode through DrugEnv's RewardBreakdown."""
        from server.hackathon_environment import DrugTargetEnvironment, MAX_STEPS
        from training.training_script import (
            build_drug_target_action,
            heuristic_next_action,
        )

        episodes_summary: List[Dict[str, Any]] = []
        held_out_seed_base = 900_000
        for i in range(self.checkpoint_eval_episodes):
            env = DrugTargetEnvironment(
                scenario_name=self.scenario_name,
                domain_randomise=False,
            )
            obs: ValidationObservation = env.reset(seed=held_out_seed_base + i)
            history: List[ActionType] = []
            cumulative = 0.0
            steps = 0
            while not obs.done and steps < MAX_STEPS:
                next_at = heuristic_next_action(history, steps)
                action: DrugTargetAction = build_drug_target_action(next_at, obs)
                obs = env.step(action)
                history.append(action.action_type)
                cumulative += float(obs.reward or 0.0)
                steps += 1
            # Pull ground-truth flags off the latent: the heuristic
            # only "succeeds" if it ended in a submitted report whose
            # decision matched the hidden ``correct_decision``.
            latent = env._latent
            correct = (
                latent is not None
                and latent.progress.report_submitted
                and any(
                    h == ActionType.SUBMIT_VALIDATION_REPORT for h in history
                )
            )
            terminal_breakdown = (obs.metadata or {}).get("reward_breakdown", {}) or {}
            evidence_cov = float(
                terminal_breakdown.get("term_evidence_coverage")
                or terminal_breakdown.get("evidence_coverage")
                or 0.0
            )
            decision_acc = float(
                terminal_breakdown.get("term_decision_accuracy")
                or terminal_breakdown.get("decision_accuracy")
                or 0.0
            )
            episodes_summary.append({
                "reward": cumulative,
                "decision_accuracy": decision_acc,
                "evidence_coverage": evidence_cov,
                "submitted": bool(latent and latent.progress.report_submitted),
                "correct": bool(correct and decision_acc >= 0.5),
            })

        n = len(episodes_summary)
        if n == 0:
            return
        mean_reward = sum(e["reward"] for e in episodes_summary) / n
        success_rate = sum(1 for e in episodes_summary if e["correct"]) / n
        decision_rate = sum(e["decision_accuracy"] for e in episodes_summary) / n
        coverage_rate = sum(e["evidence_coverage"] for e in episodes_summary) / n
        submitted_rate = sum(1 for e in episodes_summary if e["submitted"]) / n
        max_steps = getattr(state, "max_steps", None) or step
        self._checkpoint_log.append({
            "step": step,
            "fraction_done": round(step / max(max_steps, 1), 4),
            "episodes": n,
            "mean_reward": round(mean_reward, 4),
            "success_rate": round(success_rate, 4),
            "decision_accuracy_rate": round(decision_rate, 4),
            "evidence_coverage_rate": round(coverage_rate, 4),
            "report_submitted_rate": round(submitted_rate, 4),
        })
        logger.info(
            "[checkpoint-eval step=%d] reward=%.3f success=%.2f decision=%.2f "
            "coverage=%.2f submitted=%.2f (heuristic policy, n=%d)",
            step, mean_reward, success_rate, decision_rate,
            coverage_rate, submitted_rate, n,
        )
| __all__ = ["LiveTrainingCallback"] | |