"""Live (mid-training) evidence callback for the trainer Space dashboard.
Bolt-on `transformers.TrainerCallback` that streams two CSVs while a
GRPO run is in flight:
* ``evidence/training_log.csv`` β€” one row per logging step, with all
numeric values surfaced by TRL's log-history dict (``loss``,
``reward``, ``reward_std``, ``kl``, ``grad_norm``, ``learning_rate``,
``epoch``, ``frac_reward_zero_std``, ``completions/mean_length``, …).
The columns are discovered dynamically from ``state.log_history`` β€”
we never hard-code TRL's exact key set because it has shifted across
releases.
* ``evidence/checkpoint_evals.csv`` β€” every ``checkpoint_eval_steps``
GRPO updates we run a short held-out rollout against
``DrugTargetEnvironment`` (using the same heuristic policy
``training/training_script.py::heuristic_next_action`` would pick if
GRPO were silent), score it through DrugEnv's existing
``RewardComputer`` / ``RuleEngine`` / ``TransitionEngine`` so the
reward column names match ``RewardBreakdown.to_dict()`` 1-to-1, and
append a row.
On train-end, the existing post-hoc ``save_training_plots`` from
``training/training_script.py`` is invoked again so the plots
``training_dashboard.png`` / ``training_loss.png`` /
``training_reward.png`` / ``training_metric.png`` are refreshed for the
final state too. We deliberately do NOT replace that post-hoc plotter
β€” both paths coexist.
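
Example (a minimal sketch of how a dashboard might tail the stream;
pandas here is purely illustrative and not a dependency of this
module)::

    import pandas as pd

    log = pd.read_csv("evidence/training_log.csv")
    log.plot(x="step", y="loss")    # loss curve so far
    evals = pd.read_csv("evidence/checkpoint_evals.csv")
    evals.tail(1)                   # latest held-out checkpoint row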
"""
from __future__ import annotations
import csv
import logging
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
from models import ActionType, DrugTargetAction, ValidationObservation
logger = logging.getLogger(__name__)
# Reward component column names mirror ``server.rewards.reward
# .RewardBreakdown.to_dict()`` so a reviewer can pivot the CSV against
# the dashboard cards without renaming columns. ``term_*`` prefixed
# duplicates are emitted by ``DrugTargetEnvironment.step`` to
# disambiguate per-step vs terminal contributions; we keep both shapes.
_REWARD_COLUMNS = (
"evidence_coverage",
"decision_accuracy",
"credit_efficiency",
"reasoning_coherence",
"novelty",
"penalty",
"shaping",
"terminal",
"total",
)
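# Hypothetical shape of a terminal step's ``metadata["reward_breakdown"]``
# (values invented; note the ``term_*`` duplicates next to the plain keys):
#   {"decision_accuracy": 1.0, "term_decision_accuracy": 1.0,
#    "evidence_coverage": 0.8, "term_evidence_coverage": 0.8, "total": 2.3}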
def _ensure_dir(p: Path) -> None:
p.parent.mkdir(parents=True, exist_ok=True)
class _AppendOnlyCsv:
"""Minimal append-only CSV with a thread-safe header rewrite when
new fields show up after the file was first opened.
GRPO's ``state.log_history`` adds keys as different code paths fire
(e.g. ``frac_reward_zero_std`` only appears once a saturated batch
is logged), so we cannot fix the header at construction time the
way ``csv.DictWriter`` would prefer. Instead we accept the union
of seen-so-far keys, rewrite the header in place when it grows,
and pad missing fields with empty strings.
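
    Example (hypothetical file ``run.csv``; shows the header growing)::

        log = _AppendOnlyCsv(Path("run.csv"), base_fields=("step",))
        log.append({"step": 1, "loss": 0.50})  # header rewritten: step,loss
        log.append({"step": 2, "kl": 0.02})    # header grows: step,loss,kl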
"""
def __init__(self, path: Path, base_fields: Sequence[str]) -> None:
self.path = Path(path)
_ensure_dir(self.path)
self._lock = threading.Lock()
self._fields: List[str] = list(base_fields)
if not self.path.exists():
with open(self.path, "w", newline="") as f:
csv.DictWriter(f, fieldnames=self._fields).writeheader()
def append(self, row: Dict[str, Any]) -> None:
with self._lock:
new_keys = [k for k in row if k not in self._fields]
if new_keys:
                # Header grew; rewrite the whole file in place.
self._fields.extend(new_keys)
existing: List[Dict[str, str]] = []
if self.path.exists():
with open(self.path, newline="") as f:
existing = list(csv.DictReader(f))
with open(self.path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=self._fields)
w.writeheader()
for r in existing:
w.writerow({k: r.get(k, "") for k in self._fields})
with open(self.path, "a", newline="") as f:
w = csv.DictWriter(f, fieldnames=self._fields)
w.writerow({k: row.get(k, "") for k in self._fields})
def _try_callback_class():
try:
from transformers import TrainerCallback # type: ignore
return TrainerCallback
except Exception as exc: # pragma: no cover - tested only inside Spaces
logger.warning(
"transformers.TrainerCallback unavailable (%s); LiveTrainingCallback "
"will degrade to a no-op.",
exc,
)
return object # plain stand-in; trainer will simply not call our hooks
_TrainerCallback = _try_callback_class()
class LiveTrainingCallback(_TrainerCallback): # type: ignore[misc]
"""Streams ``training_log.csv`` + ``checkpoint_evals.csv`` during GRPO.
Wire it into a TRL ``GRPOTrainer`` via ``callbacks=[LiveTrainingCallback(...)]``
after building the trainer; see ``training/training_script.py`` for the
    canonical wiring. The callback sticks to DrugEnv's existing surface
    (``DrugTargetEnvironment``, ``RewardBreakdown.to_dict``,
    ``training/training_script.py``'s heuristic policy) so future tweaks
    to the env or reward shape do not require touching this file; the one
    private read is ``env._latent``, used only for ground-truth flags
    during checkpoint evals.
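
    Example (sketch; trainer construction elided, mirroring the wiring
    in ``training/training_script.py``)::

        callback = LiveTrainingCallback(
            evidence_dir="evidence",
            checkpoint_eval_steps=50,
            checkpoint_eval_episodes=4,
        )
        trainer = GRPOTrainer(..., callbacks=[callback])
        trainer.train()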
"""
def __init__(
self,
*,
evidence_dir: str = "evidence",
checkpoint_eval_steps: int = 50,
checkpoint_eval_episodes: int = 4,
difficulty: Optional[str] = None,
scenario_name: Optional[str] = None,
) -> None:
self.evidence_dir = Path(evidence_dir)
self.checkpoint_eval_steps = max(1, int(checkpoint_eval_steps))
self.checkpoint_eval_episodes = max(1, int(checkpoint_eval_episodes))
self.difficulty = difficulty
self.scenario_name = scenario_name
self.evidence_dir.mkdir(parents=True, exist_ok=True)
self._training_log = _AppendOnlyCsv(
self.evidence_dir / "training_log.csv",
base_fields=("step", "wall_time_s"),
)
self._checkpoint_log = _AppendOnlyCsv(
self.evidence_dir / "checkpoint_evals.csv",
base_fields=(
"step", "fraction_done", "episodes",
"mean_reward", "success_rate",
"decision_accuracy_rate", "evidence_coverage_rate",
"report_submitted_rate",
),
)
self._t0 = time.time()
self._last_eval_step = -1
self._latest_log_keys: List[str] = []
# ── transformers TrainerCallback hooks ──────────────────────────
def on_log(self, _args, state, control, logs=None, **kw):
logs = logs or {}
row: Dict[str, Any] = {
"step": getattr(state, "global_step", None),
"wall_time_s": round(time.time() - self._t0, 2),
}
for k, v in logs.items():
if isinstance(v, (int, float)) and not isinstance(v, bool):
# CSV-friendly: replace '/' in TRL keys (e.g. 'rewards/mean')
# with a flat dotted name so spreadsheet tools don't choke.
row[k.replace("/", ".")] = v
self._latest_log_keys = list(row.keys())
self._training_log.append(row)
return control
def on_step_end(self, _args, state, control, **kw):
step = getattr(state, "global_step", 0)
if step <= 0 or step == self._last_eval_step:
return control
if step % self.checkpoint_eval_steps != 0:
return control
self._last_eval_step = step
try:
self._run_checkpoint_eval(step, state)
except Exception as exc:
logger.warning("checkpoint eval failed at step %d: %s", step, exc)
return control
def on_train_end(self, _args, state, control, **kw):
# Best-effort: ask DrugEnv's existing post-hoc plotter to refresh
# the static PNG dashboard. Imported lazily so import failures do
# not break the in-flight run.
try:
from training.training_script import save_training_plots
save_training_plots(
getattr(state, "log_history", []) or [],
str(self.evidence_dir),
)
except Exception as exc: # pragma: no cover - plotting is optional
logger.warning("post-hoc plot refresh failed: %s", exc)
return control
# ── checkpoint-eval rollout ─────────────────────────────────────
def _run_checkpoint_eval(self, step: int, state) -> None:
"""Run a tiny held-out rollout using the heuristic policy and
score the final episode through DrugEnv's RewardBreakdown."""
from server.hackathon_environment import DrugTargetEnvironment, MAX_STEPS
from training.training_script import (
HEURISTIC_SEQUENCE,
build_drug_target_action,
heuristic_next_action,
)
episodes_summary: List[Dict[str, Any]] = []
held_out_seed_base = 900_000
for i in range(self.checkpoint_eval_episodes):
env = DrugTargetEnvironment(
scenario_name=self.scenario_name,
domain_randomise=False,
)
obs: ValidationObservation = env.reset(seed=held_out_seed_base + i)
history: List[ActionType] = []
cumulative = 0.0
steps = 0
while not obs.done and steps < MAX_STEPS:
next_at = heuristic_next_action(history, steps)
action: DrugTargetAction = build_drug_target_action(next_at, obs)
obs = env.step(action)
history.append(action.action_type)
cumulative += float(obs.reward or 0.0)
steps += 1
            # Pull ground-truth flags off the latent; the heuristic
# only "succeeds" if it ended in a submitted report whose
# decision matched the hidden ``correct_decision``.
latent = env._latent
correct = (
latent is not None
and latent.progress.report_submitted
and any(
h == ActionType.SUBMIT_VALIDATION_REPORT for h in history
)
)
terminal_breakdown = (obs.metadata or {}).get("reward_breakdown", {}) or {}
            # Prefer the ``term_*`` key whenever it is present (even at 0.0);
            # plain ``or``-chaining would let a legitimate zero fall through
            # to the per-step column.
            evidence_cov = float(
                terminal_breakdown.get(
                    "term_evidence_coverage",
                    terminal_breakdown.get("evidence_coverage", 0.0),
                )
                or 0.0
            )
            decision_acc = float(
                terminal_breakdown.get(
                    "term_decision_accuracy",
                    terminal_breakdown.get("decision_accuracy", 0.0),
                )
                or 0.0
            )
episodes_summary.append({
"reward": cumulative,
"decision_accuracy": decision_acc,
"evidence_coverage": evidence_cov,
"submitted": bool(latent and latent.progress.report_submitted),
"correct": bool(correct and decision_acc >= 0.5),
})
n = len(episodes_summary)
if n == 0:
return
mean_reward = sum(e["reward"] for e in episodes_summary) / n
success_rate = sum(1 for e in episodes_summary if e["correct"]) / n
decision_rate = sum(e["decision_accuracy"] for e in episodes_summary) / n
coverage_rate = sum(e["evidence_coverage"] for e in episodes_summary) / n
submitted_rate = sum(1 for e in episodes_summary if e["submitted"]) / n
max_steps = getattr(state, "max_steps", None) or step
self._checkpoint_log.append({
"step": step,
"fraction_done": round(step / max(max_steps, 1), 4),
"episodes": n,
"mean_reward": round(mean_reward, 4),
"success_rate": round(success_rate, 4),
"decision_accuracy_rate": round(decision_rate, 4),
"evidence_coverage_rate": round(coverage_rate, 4),
"report_submitted_rate": round(submitted_rate, 4),
})
logger.info(
"[checkpoint-eval step=%d] reward=%.3f success=%.2f decision=%.2f "
"coverage=%.2f submitted=%.2f (heuristic policy, n=%d)",
step, mean_reward, success_rate, decision_rate,
coverage_rate, submitted_rate, n,
)
__all__ = ["LiveTrainingCallback"]