Spaces:
Sleeping
Sleeping
| """Central Weights & Biases integration for Qubit-Medic. | |
| Design goals | |
| ------------ | |
| 1. **Single source of truth** for the W&B project name, default tags, and | |
| the ``config`` dump that every run logs. Trainers, eval scripts, and | |
| notebooks all funnel through :func:`init_run` so dashboards always | |
| line up. | |
| 2. **Safe to import without wandb installed.** The package's training | |
| deps (``wandb``) live in ``requirements-train.txt`` and are absent on | |
| the lean Spaces image. Anything in this module degrades gracefully | |
| when the import fails - the rest of the project keeps working. | |
| 3. **Disable-by-env-var.** Set ``WANDB_DISABLED=1`` (or | |
| ``QUBIT_MEDIC_WANDB=0``) and every helper here becomes a no-op, | |
| regardless of whether the package is installed. CI runs and offline | |
| testing rely on this. | |
| 4. **Rich first-class logging.** We expose dedicated helpers for the | |
| things this project cares about: | |
| * Per-reward component scalars (5 lines per step, not just total) | |
| * Curriculum-level moving averages (one line per level) | |
| * Parse success / partial / failure rates | |
| * Generation sample tables (prompt / completion / per-reward) | |
| * Eval summary tables (one row per (policy, level)) | |
| The trainers and eval script only have to call these helpers; the | |
| Pythonic context manager handles init, summary, and finish. | |
| """ | |
| from __future__ import annotations | |
| import contextlib | |
| import dataclasses | |
| import os | |
| import socket | |
| import sys | |
| import time | |
| from typing import Any, Iterable, Mapping, Optional, Sequence | |
| from qubit_medic.config import ( | |
| CURRICULUM, GRPO_GEN_PER_PROMPT, GRPO_KL_COEF, GRPO_LR, GRPO_MAX_COMPLETION_LEN, | |
| GRPO_MAX_PROMPT_LEN, GRPO_STEPS, LORA_ALPHA, LORA_R, LORA_TARGET_MODULES, | |
| MODEL_ID, PRIMARY_SEED, REWARD_WEIGHTS, SFT_BATCH_SIZE, SFT_EPOCHS, | |
| SFT_GRAD_ACCUM, SFT_LR, SFT_MAX_SEQ_LEN, WANDB_DEFAULT_TAGS, WANDB_ENTITY, | |
| WANDB_LOG_GENERATIONS_EVERY, WANDB_PROJECT, WANDB_SAMPLE_GENERATIONS, | |
| ) | |
| # --------------------------------------------------------------------------- # | |
| # Lazy import + on/off toggle # | |
| # --------------------------------------------------------------------------- # | |
| _WANDB_MODULE = None | |
| _RUN: Any = None | |
| def _import_wandb(): | |
| """Import wandb on first use. Returns ``None`` if it isn't installed.""" | |
| global _WANDB_MODULE | |
| if _WANDB_MODULE is None: | |
| try: | |
| import wandb # type: ignore[import-not-found] | |
| _WANDB_MODULE = wandb | |
| except ImportError: | |
| _WANDB_MODULE = False # sentinel: "tried, failed" | |
| return _WANDB_MODULE if _WANDB_MODULE is not False else None | |
def is_disabled() -> bool:
    """Whether logging is switched off via ``WANDB_DISABLED`` or ``QUBIT_MEDIC_WANDB=0``."""
    hard_off = os.environ.get("WANDB_DISABLED", "").lower()
    opt_in = os.environ.get("QUBIT_MEDIC_WANDB", "1").lower()
    return hard_off in {"1", "true", "yes"} or opt_in in {"0", "false", "no"}
def is_available() -> bool:
    """Whether logging can actually happen: wandb importable and not env-disabled."""
    wandb_mod = _import_wandb()
    if wandb_mod is None:
        return False
    return not is_disabled()
def get_run():
    """Active W&B run object if :func:`init_run` succeeded, else ``None``."""
    return _RUN
| # --------------------------------------------------------------------------- # | |
| # Init / finish # | |
| # --------------------------------------------------------------------------- # | |
| def _system_metadata() -> dict: | |
| """Static metadata that's helpful on the dashboard but isn't a hyperparam.""" | |
| info = { | |
| "python_version": sys.version.split()[0], | |
| "hostname": socket.gethostname(), | |
| "argv": " ".join(sys.argv), | |
| "pid": os.getpid(), | |
| } | |
| try: | |
| import torch | |
| info["torch_version"] = torch.__version__ | |
| info["cuda_available"] = bool(torch.cuda.is_available()) | |
| if torch.cuda.is_available(): | |
| info["cuda_device"] = torch.cuda.get_device_name(0) | |
| except Exception: | |
| pass | |
| try: | |
| import stim | |
| info["stim_version"] = stim.__version__ | |
| except Exception: | |
| pass | |
| try: | |
| import pymatching | |
| info["pymatching_version"] = pymatching.__version__ | |
| except Exception: | |
| pass | |
| try: | |
| import trl, transformers, peft | |
| info["trl_version"] = trl.__version__ | |
| info["transformers_version"] = transformers.__version__ | |
| info["peft_version"] = peft.__version__ | |
| except Exception: | |
| pass | |
| return info | |
def _build_default_config(extra: Optional[Mapping[str, Any]] = None) -> dict:
    """Assemble the config dump that every run logs.

    Covers model/LoRA hyperparameters, SFT and GRPO settings, reward
    weights, the curriculum schedule, and :func:`_system_metadata`.
    ``extra`` entries are merged last and may override the defaults.
    """
    sft_settings = {
        "epochs": SFT_EPOCHS,
        "batch_size": SFT_BATCH_SIZE,
        "grad_accum": SFT_GRAD_ACCUM,
        "lr": SFT_LR,
        "max_seq_len": SFT_MAX_SEQ_LEN,
    }
    grpo_settings = {
        "steps": GRPO_STEPS,
        "gen_per_prompt": GRPO_GEN_PER_PROMPT,
        "lr": GRPO_LR,
        "kl_coef": GRPO_KL_COEF,
        "max_prompt_len": GRPO_MAX_PROMPT_LEN,
        "max_completion_len": GRPO_MAX_COMPLETION_LEN,
    }
    curriculum_rows = [
        {
            "name": level.name,
            "distance": level.distance,
            "rounds": level.rounds,
            "p": level.p,
            "promotion_threshold": level.promotion_threshold,
        }
        for level in CURRICULUM
    ]
    cfg: dict[str, Any] = {
        "model_id": MODEL_ID,
        "primary_seed": PRIMARY_SEED,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        "lora_target_modules": list(LORA_TARGET_MODULES),
        "sft": sft_settings,
        "grpo": grpo_settings,
        "reward_weights": dict(REWARD_WEIGHTS),
        "curriculum": curriculum_rows,
        "system": _system_metadata(),
    }
    if extra:
        cfg.update(extra)
    return cfg
def init_run(
    run_name: str,
    job_type: str,
    *,
    tags: Optional[Sequence[str]] = None,
    extra_config: Optional[Mapping[str, Any]] = None,
    notes: Optional[str] = None,
    group: Optional[str] = None,
):
    """Start a W&B run, or quietly no-op when W&B is unavailable/disabled.

    Parameters
    ----------
    run_name:
        Human-readable run name (e.g. ``"sft-warmup-2026-04-25"``).
    job_type:
        One of ``"sft"``, ``"grpo"``, ``"eval"``, ``"format-test"``,
        ``"baseline"``; used to group runs on the dashboard.
    tags:
        Extra tags appended to :data:`qubit_medic.config.WANDB_DEFAULT_TAGS`.
    extra_config:
        Run-specific hyperparameter overrides merged into the config dump.
    notes:
        Free-text notes shown on the dashboard.
    group:
        Optional W&B group (bundles SFT + GRPO + eval runs of one experiment).

    Returns
    -------
    The wandb Run object, or ``None`` if W&B is unavailable / disabled.
    """
    global _RUN
    wandb = _import_wandb()
    missing = wandb is None
    if missing or is_disabled():
        # Explain *why* nothing will be logged, then make get_run() return None.
        reason = (
            "[wandb] not installed; skipping init "
            "(install with `pip install wandb` to enable logging)"
            if missing
            else "[wandb] disabled by env var; skipping init"
        )
        print(reason, file=sys.stderr)
        _RUN = None
        return None
    config = _build_default_config(extra=extra_config)
    config["job_type"] = job_type
    _RUN = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        name=run_name,
        job_type=job_type,
        tags=list(WANDB_DEFAULT_TAGS) + list(tags or []),
        config=config,
        notes=notes,
        group=group,
        reinit=True,
    )
    print(f"[wandb] run live at {_RUN.url}", file=sys.stderr)
    return _RUN
def finish_run() -> None:
    """Flush and close the current W&B run, if one exists.

    Always clears the module-level run handle, even when ``wandb.finish``
    raises, so a broken shutdown can't leave a stale run behind.
    """
    global _RUN
    wandb = _import_wandb()
    try:
        if wandb is not None and _RUN is not None:
            wandb.finish()
    finally:
        _RUN = None
@contextlib.contextmanager  # FIX: without this, `with run_context(...)` fails -
# a bare generator has no __enter__/__exit__, so callers got an AttributeError.
def run_context(run_name: str, job_type: str, **kwargs):
    """Context-manager wrapper around :func:`init_run` / :func:`finish_run`.

    Usage::

        with run_context("sft-warmup", "sft") as run:
            ...  # run is the wandb Run, or None when W&B is unavailable

    :func:`finish_run` is guaranteed on exit, even if the body raises.
    ``kwargs`` are forwarded to :func:`init_run` (tags, notes, group, ...).
    """
    init_run(run_name, job_type, **kwargs)
    try:
        yield get_run()
    finally:
        finish_run()
# --------------------------------------------------------------------------- #
# Generic logging helpers                                                     #
# --------------------------------------------------------------------------- #
def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
        commit: bool = True) -> None:
    """No-op-safe wrapper over ``wandb.log``.

    Training-step alignment is recorded as an explicit scalar
    ``train/global_step`` instead of W&B's reserved ``step=`` argument:
    HuggingFace/TRL may advance W&B's internal step before our callback
    logs, and a stale ``step=`` makes W&B complain ("Tried to log to step
    N that is less than the current step N+1") and drop eval metrics.
    """
    wandb = _import_wandb()
    if wandb is None or _RUN is None:
        return
    try:
        payload = {**metrics}
        if step is not None:
            payload.setdefault("train/global_step", int(step))
        wandb.log(payload, commit=commit)
    except Exception as exc:  # pragma: no cover - defensive
        print(f"[wandb] log failed: {exc}", file=sys.stderr)
def update_summary(values: Mapping[str, Any]) -> None:
    """Push headline numbers into ``run.summary`` (best-effort, never raises)."""
    run = _RUN
    if run is None:
        return
    try:
        for key, value in values.items():
            run.summary[key] = value
    except Exception as exc:  # pragma: no cover
        print(f"[wandb] summary update failed: {exc}", file=sys.stderr)
# --------------------------------------------------------------------------- #
# Project-specific helpers                                                    #
# --------------------------------------------------------------------------- #
# Per-generation reward components tracked by log_reward_breakdown(), plus
# "total" — presumably the combined/weighted score; confirm against the
# reward function that produces these dicts.
_REWARD_KEYS = (
    "logical_correction",
    "syndrome_consistency",
    "hamming_overlap",
    "format_compliance",
    "pymatching_beat",
    "total",
)
def log_reward_breakdown(
    breakdowns: Sequence[Mapping[str, float]],
    *,
    step: Optional[int] = None,
    prefix: str = "rl",
) -> None:
    """Log mean / std / max / min for each key in ``_REWARD_KEYS``.

    ``breakdowns`` holds one dict per generation from the most-recent GRPO
    step (length = ``GRPO_GEN_PER_PROMPT * batch``).  Missing components
    default to 0.0.  Logging both mean and (population) std gives the
    dashboard signal and noise in one shot.
    """
    if not breakdowns or _RUN is None:
        return
    metrics: dict[str, float] = {}
    for key in _REWARD_KEYS:
        values = [float(entry.get(key, 0.0)) for entry in breakdowns]
        count = max(1, len(values))
        mean = sum(values) / count
        variance = sum((v - mean) ** 2 for v in values) / count
        base = f"{prefix}/reward/{key}"
        metrics[f"{base}_mean"] = mean
        metrics[f"{base}_std"] = variance ** 0.5
        metrics[f"{base}_max"] = max(values)
        metrics[f"{base}_min"] = min(values)
    log(metrics, step=step)
def log_parse_stats(
    parse_results: Iterable,  # iterable of qubit_medic.prompts.ParseResult
    *,
    step: Optional[int] = None,
    prefix: str = "rl",
) -> None:
    """Log parse-success / partial / failure rates for the most-recent batch.

    Each result falls into exactly one bucket: ``parse_success`` wins,
    then ``parse_partial``, else failure.  Rates sum to 1.0.
    """
    if _RUN is None:
        return
    results = list(parse_results)
    if not results:
        # FIX: an empty batch previously logged sample_count=1 and
        # failure_rate=1.0 (via n = max(1, len(...))); skip it instead,
        # consistent with log_reward_breakdown / log_curriculum.
        return
    n = len(results)
    success = sum(1 for r in results if getattr(r, "parse_success", False))
    partial = sum(
        1 for r in results
        if not getattr(r, "parse_success", False)
        and getattr(r, "parse_partial", False)
    )
    log({
        f"{prefix}/parse/success_rate": success / n,
        f"{prefix}/parse/partial_rate": partial / n,
        f"{prefix}/parse/failure_rate": (n - success - partial) / n,
        f"{prefix}/parse/sample_count": n,
    }, step=step)
def log_curriculum(
    curriculum_stats: Mapping[str, Mapping[str, float]],
    *,
    step: Optional[int] = None,
    prefix: str = "rl",
) -> None:
    """Log per-level moving averages from the env health endpoint.

    ``curriculum_stats`` is the payload of
    :meth:`qubit_medic.server.curriculum.CurriculumScheduler.snapshot`:
    one inner dict per level with keys ``moving_mean`` / ``samples``.
    """
    if _RUN is None or not curriculum_stats:
        return
    metrics = {
        f"{prefix}/curriculum/{level_name}_{label}": float(stats.get(source_key, 0.0))
        for level_name, stats in curriculum_stats.items()
        for label, source_key in (("mean", "moving_mean"), ("samples", "samples"))
    }
    log(metrics, step=step)
def log_generation_table(
    rows: Sequence[Mapping[str, Any]],
    *,
    step: Optional[int],
    table_name: str = "rl/generations",
    columns: Optional[Sequence[str]] = None,
) -> None:
    """Log a ``wandb.Table`` of qualitative generation samples.

    Each row is a flat dict (prompt / completion / per-reward ...).  When
    ``columns`` is omitted, the column set is the sorted union of all row
    keys; cells missing from a row become ``None``.
    """
    wandb = _import_wandb()
    if wandb is None or _RUN is None or not rows:
        return
    if columns is not None:
        cols = list(columns)
    else:
        key_union: set = set()
        for row in rows:
            key_union.update(row.keys())
        cols = sorted(key_union)
    try:
        table = wandb.Table(columns=cols)
        for row in rows:
            table.add_data(*(row.get(col) for col in cols))
        log({table_name: table}, step=step)
    except Exception as exc:  # pragma: no cover
        print(f"[wandb] table log failed: {exc}", file=sys.stderr)
def log_eval_summary(
    summary: Mapping[str, Any],
    *,
    step: Optional[int] = None,
    prefix: str = "eval",
) -> None:
    """Log the dict from ``scripts/eval._summary``.

    Numeric values are logged as time-series scalars; numeric, string,
    and boolean values also land in the run summary (headline numbers).
    """
    if _RUN is None:
        return
    scalars = {
        f"{prefix}/{key}": value
        for key, value in summary.items()
        if isinstance(value, (int, float))
    }
    log(scalars, step=step)
    update_summary({
        f"{prefix}/{key}": value
        for key, value in summary.items()
        if isinstance(value, (int, float, str, bool))
    })
def log_artifact(
    path: str, *, name: str, artifact_type: str,
    description: Optional[str] = None,
) -> None:
    """Upload a file or directory to W&B as an artifact (best-effort)."""
    wandb = _import_wandb()
    if wandb is None or _RUN is None:
        return
    try:
        artifact = wandb.Artifact(name, type=artifact_type, description=description)
        # Directories and single files use different add methods.
        attach = artifact.add_dir if os.path.isdir(path) else artifact.add_file
        attach(path)
        _RUN.log_artifact(artifact)
    except Exception as exc:  # pragma: no cover
        print(f"[wandb] artifact log failed: {exc}", file=sys.stderr)
# --------------------------------------------------------------------------- #
# CLI integration helpers                                                     #
# --------------------------------------------------------------------------- #
def derive_report_to(report_to: str) -> str:
    """Translate the user-facing ``--report-to`` flag.

    ``"wandb"`` degrades to ``"none"`` (with a stderr note) when wandb is
    missing or disabled, so the same script runs on Colab (with wandb)
    and CI (without) instead of crashing the trainer.
    """
    if report_to != "wandb" or is_available():
        return report_to
    print("[wandb] requested via --report-to but unavailable; falling back to 'none'",
          file=sys.stderr)
    return "none"
def make_run_name(prefix: str, suffix: Optional[str] = None) -> str:
    """Build a default run name like ``sft-warmup-20260425-210503``.

    The timestamp uses ``%Y%m%d-%H%M%S``; a truthy ``suffix`` is appended
    as a final dash-separated segment.
    """
    parts = [prefix, time.strftime("%Y%m%d-%H%M%S")]
    if suffix:
        parts.append(suffix)
    return "-".join(parts)
| __all__ = [ | |
| "derive_report_to", | |
| "finish_run", | |
| "get_run", | |
| "init_run", | |
| "is_available", | |
| "is_disabled", | |
| "log", | |
| "log_artifact", | |
| "log_curriculum", | |
| "log_eval_summary", | |
| "log_generation_table", | |
| "log_parse_stats", | |
| "log_reward_breakdown", | |
| "make_run_name", | |
| "run_context", | |
| "update_summary", | |
| "WANDB_LOG_GENERATIONS_EVERY", | |
| "WANDB_SAMPLE_GENERATIONS", | |
| ] | |