# -*- coding: utf-8 -*-
"""Training monitoring: TrainingMonitor, GRPOStabilityCallback, RolloutAuditSampler.

Extracted from train.py to keep the training pipeline modular.
"""
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

try:
    from transformers import TrainerCallback
except ModuleNotFoundError:
    class TrainerCallback:  # type: ignore[no-redef]
        """Lightweight fallback so monitoring helpers are importable in lean CI."""

from training.metrics import safe_ratio, aggregate_batch_metrics, summarize_sentinel_history

logger = logging.getLogger(__name__)

# Default sentinel task ids, shared by TrainingMonitor.log_batch and
# RolloutAuditSampler.record_batch so the two stay in sync.
DEFAULT_SENTINEL_TASK_IDS = [
    "basic_oversight",
    "fleet_monitoring_conflict",
    "adversarial_worker",
    "multi_crisis_command",
]


# ---------------------------------------------------------------------------
# TrainingMonitor
# ---------------------------------------------------------------------------
class TrainingMonitor:
    """Write structured per-batch training metrics for proof-pack and judge review."""

    def __init__(self, output_dir: str) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.metrics_path = self.output_dir / "training_metrics.jsonl"
        self.stability_path = self.output_dir / "training_stability.jsonl"
        self.summary_path = self.output_dir / "latest_summary.json"
        self.stack_path = self.output_dir / "training_stack_versions.json"
        self.batch_index = 0
        self.running_reward_total = 0.0
        self.running_batch_count = 0
        self.best_reward_mean = float("-inf")
        self.latest_batch_metrics: Dict[str, Any] = {}
        self.latest_trainer_metrics: Dict[str, Any] = {}
        self.latest_guardrail: Dict[str, Any] = {}

    def write_stack_versions(self, stack_versions: Dict[str, Any]) -> None:
        self.stack_path.write_text(
            json.dumps(stack_versions, indent=2, sort_keys=True),
            encoding="utf-8",
        )
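
    # Example (hypothetical payload; the real dict is assembled by the caller
    # in train.py, and these version strings are illustrative only):
    #
    #     monitor.write_stack_versions({"torch": "2.3.0", "trl": "0.9.6"})
    #     # -> training_stack_versions.json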

    def _write_latest_summary(self) -> None:
        payload = dict(self.latest_batch_metrics)
        if self.latest_trainer_metrics:
            payload.update(self.latest_trainer_metrics)
            payload["trainer_metrics"] = dict(self.latest_trainer_metrics)
        if self.latest_guardrail:
            payload["kl_guardrail"] = dict(self.latest_guardrail)
            payload["adaptive_beta"] = self.latest_guardrail.get("current_beta")
        if not payload:
            return
        self.summary_path.write_text(
            json.dumps(payload, indent=2, sort_keys=True),
            encoding="utf-8",
        )

    def log_batch(
        self,
        rewards: List[float],
        histories: List[List[Dict[str, Any]]],
        task_ids: List[str],
        variant_seeds: List[int],
        sentinel_task_ids: Optional[List[str]] = None,
        completions: Optional[List[str]] = None,
        prompts: Optional[List[str]] = None,
        adversarial_cases: Optional[List[str]] = None,
        curriculum_summary: Optional[Dict[str, Any]] = None,
        prompt_refreshes: int = 0,
        reward_schedule: Optional[Dict[str, Any]] = None,
        memory_summary: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        self.batch_index += 1
        if sentinel_task_ids is None:
            sentinel_task_ids = list(DEFAULT_SENTINEL_TASK_IDS)
        metrics = aggregate_batch_metrics(
            rewards=rewards,
            histories=histories,
            task_ids=task_ids,
            variant_seeds=variant_seeds,
            sentinel_task_ids=sentinel_task_ids,
            completions=completions,
            prompts=prompts,
            adversarial_cases=adversarial_cases,
            curriculum_summary=curriculum_summary,
            prompt_refreshes=prompt_refreshes,
        )
        metrics["batch_index"] = self.batch_index
        metrics["monitoring_mode"] = (
            "sentinel"
            if any(task_id in sentinel_task_ids for task_id in task_ids)
            else "irt"
        )
        if reward_schedule:
            metrics["reward_schedule"] = reward_schedule
        if memory_summary:
            metrics["memory"] = memory_summary
        self.running_batch_count += 1
        self.running_reward_total += metrics["reward_mean"]
        self.best_reward_mean = max(self.best_reward_mean, metrics["reward_mean"])
        metrics["running_reward_mean"] = round(
            safe_ratio(self.running_reward_total, self.running_batch_count),
            4,
        )
        metrics["best_reward_mean"] = round(self.best_reward_mean, 4)
        with self.metrics_path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(metrics, sort_keys=True))
            handle.write("\n")
        self.latest_batch_metrics = dict(metrics)
        self._write_latest_summary()
        return metrics
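
    # Minimal usage sketch (illustrative; the metric names come from
    # training.metrics.aggregate_batch_metrics, and the directory, task ids,
    # and values below are hypothetical):
    #
    #     monitor = TrainingMonitor("artifacts/monitoring")
    #     metrics = monitor.log_batch(
    #         rewards=[0.8, 0.3],
    #         histories=[[{"action": "approve"}], []],
    #         task_ids=["basic_oversight", "irt_variant_12"],
    #         variant_seeds=[7, 11],
    #     )
    #     print(metrics["running_reward_mean"])  # row appended to training_metrics.jsonl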

    def log_twin_replay(
        self,
        histories: List[List[Dict[str, Any]]],
        task_ids: List[str],
        variant_seeds: List[int],
        rewards: List[float],
    ) -> Optional[Dict[str, Any]]:
        """Run digital twin counterfactual replay and log metrics."""
        try:
            from sentinel.twin_replay import compute_batch_twin_metrics

            twin_metrics = compute_batch_twin_metrics(histories, task_ids, variant_seeds, rewards)
            if twin_metrics.get("twin_replays", 0) > 0:
                twin_path = self.output_dir / "twin_replay_metrics.jsonl"
                twin_metrics["batch_index"] = self.batch_index
                with twin_path.open("a", encoding="utf-8") as handle:
                    handle.write(json.dumps(twin_metrics, sort_keys=True))
                    handle.write("\n")
                # Merge into latest batch metrics for the dashboard.
                self.latest_batch_metrics.update(twin_metrics)
                self._write_latest_summary()
            return twin_metrics
        except Exception as exc:
            logger.debug("Twin replay skipped: %s", exc)
            return None

    def log_reputation_update(
        self,
        histories: List[List[Dict[str, Any]]],
    ) -> Optional[Dict[str, Dict[str, Any]]]:
        """Update cross-episode worker reputation profiles."""
        try:
            from sentinel.reputation import WorkerReputationTracker

            tracker = WorkerReputationTracker(str(self.output_dir / "worker_reputation.json"))
            all_updated: Dict[str, Dict[str, Any]] = {}
            for history in histories:
                if history:
                    updated = tracker.update_from_episode(history)
                    all_updated.update(updated)
            return all_updated
        except Exception as exc:
            logger.debug("Reputation update skipped: %s", exc)
            return None
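
    # Note: log_twin_replay and log_reputation_update degrade gracefully. If
    # the optional `sentinel` package is unavailable, the import error is
    # caught, logged at DEBUG level, and the method returns None rather than
    # interrupting training.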

    def log_trainer_metrics(
        self,
        *,
        global_step: int,
        trainer_metrics: Dict[str, Any],
        guardrail: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        payload = {
            "global_step": int(global_step),
            **trainer_metrics,
        }
        if guardrail:
            payload["kl_guardrail"] = guardrail
        with self.stability_path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(payload, sort_keys=True))
            handle.write("\n")
        self.latest_trainer_metrics = dict(trainer_metrics)
        if guardrail:
            self.latest_guardrail = dict(guardrail)
        self._write_latest_summary()
        return payload
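
    # Example of the record this appends to training_stability.jsonl
    # (a sketch with hypothetical values):
    #
    #     monitor.log_trainer_metrics(
    #         global_step=120,
    #         trainer_metrics={"approx_kl": 0.031, "policy_entropy": 2.4},
    #         guardrail={"action": "hold", "current_beta": 0.05},
    #     )
    #     # -> {"approx_kl": 0.031, "global_step": 120,
    #     #     "kl_guardrail": {"action": "hold", "current_beta": 0.05},
    #     #     "policy_entropy": 2.4}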


# ---------------------------------------------------------------------------
# GRPOStabilityCallback
# ---------------------------------------------------------------------------
class GRPOStabilityCallback(TrainerCallback):
    """Hook TRL trainer logs to persist KL/entropy metrics and adapt beta conservatively."""

    def __init__(
        self,
        training_monitor: TrainingMonitor,
        *,
        initial_beta: float,
        target_kl: float,
        adaptive: bool,
        low_factor: float,
        high_factor: float,
        beta_up_mult: float,
        beta_down_mult: float,
        min_beta: float,
        max_beta: float,
        hard_stop_enabled: bool,
        hard_stop_mult: float,
    ) -> None:
        self.training_monitor = training_monitor
        self.current_beta = float(initial_beta)
        self.target_kl = float(target_kl)
        self.adaptive = bool(adaptive)
        self.low_factor = float(low_factor)
        self.high_factor = float(high_factor)
        self.beta_up_mult = float(beta_up_mult)
        self.beta_down_mult = float(beta_down_mult)
        self.min_beta = float(min_beta)
        self.max_beta = float(max_beta)
        self.hard_stop_enabled = bool(hard_stop_enabled)
        self.hard_stop_mult = float(hard_stop_mult)
        self.trainer = None

    def bind_trainer(self, trainer) -> None:
        self.trainer = trainer
        self.current_beta = float(getattr(trainer, "beta", self.current_beta) or self.current_beta)

    @staticmethod
    def _first_float(logs: Dict[str, Any], keys: List[str]) -> Optional[float]:
        """Return the first log value among ``keys`` that parses as a float."""
        for key in keys:
            value = logs.get(key)
            if value is None:
                continue
            try:
                return float(value)
            except (TypeError, ValueError):
                continue
        return None

    def _apply_beta(self, value: float) -> None:
        self.current_beta = float(value)
        if self.trainer is None:
            return
        setattr(self.trainer, "beta", self.current_beta)
        if hasattr(self.trainer, "args"):
            if hasattr(self.trainer.args, "beta"):
                setattr(self.trainer.args, "beta", self.current_beta)
            if hasattr(self.trainer.args, "kl_coef"):
                setattr(self.trainer.args, "kl_coef", self.current_beta)

    def _guardrail_update(self, approx_kl: Optional[float]) -> Dict[str, Any]:
        low_threshold = self.target_kl / max(self.low_factor, 1.0)
        high_threshold = self.target_kl * max(self.high_factor, 1.0)
        guardrail = {
            "enabled": self.adaptive,
            "target_kl": round(self.target_kl, 4),
            "low_threshold": round(low_threshold, 4),
            "high_threshold": round(high_threshold, 4),
            "previous_beta": round(self.current_beta, 6),
            "current_beta": round(self.current_beta, 6),
            "action": "hold",
            "hard_stop_triggered": False,
        }
        if approx_kl is None:
            return guardrail
        new_beta = self.current_beta
        if self.adaptive and approx_kl > high_threshold:
            new_beta = min(self.max_beta, self.current_beta * self.beta_up_mult)
            guardrail["action"] = "increase_beta"
        elif self.adaptive and approx_kl < low_threshold:
            new_beta = max(self.min_beta, self.current_beta * self.beta_down_mult)
            guardrail["action"] = "decrease_beta"
        if abs(new_beta - self.current_beta) > 1e-12:
            self._apply_beta(new_beta)
            guardrail["current_beta"] = round(self.current_beta, 6)
        if self.hard_stop_enabled and approx_kl > self.target_kl * max(self.hard_stop_mult, 1.0):
            guardrail["hard_stop_triggered"] = True
            guardrail["action"] = "hard_stop"
        return guardrail
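
    # Worked example of the thresholds (hypothetical numbers): with
    # target_kl=0.05, low_factor=2.0, high_factor=2.0, the hold band is
    # [0.025, 0.1]. An approx_kl of 0.15 exceeds 0.1, so beta is multiplied
    # by beta_up_mult (capped at max_beta); an approx_kl of 0.01 falls below
    # 0.025, so beta is multiplied by beta_down_mult (floored at min_beta).
    # With hard_stop_mult=4.0, any approx_kl above 0.2 also flags a hard stop.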

    def on_log(self, args, state, control, logs=None, **kwargs):
        logs = logs or {}
        if any(str(key).startswith("eval_") for key in logs):
            return control
        approx_kl = self._first_float(logs, ["kl", "objective/kl"])
        policy_entropy = self._first_float(logs, ["entropy", "policy/entropy"])
        clip_ratio = self._first_float(logs, ["clip_ratio/region_mean", "clip_ratio", "objective/clip_ratio"])
        if approx_kl is None and policy_entropy is None and clip_ratio is None:
            return control
        guardrail = self._guardrail_update(approx_kl)
        trainer_metrics = {
            "approx_kl": round(float(approx_kl), 6) if approx_kl is not None else None,
            "policy_entropy": round(float(policy_entropy), 6) if policy_entropy is not None else None,
            "clip_ratio": round(float(clip_ratio), 6) if clip_ratio is not None else None,
        }
        self.training_monitor.log_trainer_metrics(
            global_step=int(getattr(state, "global_step", 0) or 0),
            trainer_metrics={key: value for key, value in trainer_metrics.items() if value is not None},
            guardrail=guardrail,
        )
        if guardrail.get("hard_stop_triggered"):
            control.should_training_stop = True
        return control
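
    # Wiring sketch (assumes a transformers/TRL-style trainer that exposes
    # `add_callback` and a `beta` attribute; all numbers are hypothetical):
    #
    #     callback = GRPOStabilityCallback(
    #         training_monitor=monitor,
    #         initial_beta=0.04, target_kl=0.05, adaptive=True,
    #         low_factor=2.0, high_factor=2.0,
    #         beta_up_mult=1.5, beta_down_mult=0.9,
    #         min_beta=1e-4, max_beta=1.0,
    #         hard_stop_enabled=True, hard_stop_mult=4.0,
    #     )
    #     trainer.add_callback(callback)
    #     callback.bind_trainer(trainer)  # lets _apply_beta push beta updates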


# ---------------------------------------------------------------------------
# RolloutAuditSampler
# ---------------------------------------------------------------------------
def _truncate_text(text: str, limit: int = 700) -> str:
    clean = (text or "").strip()
    if len(clean) <= limit:
        return clean
    return clean[: max(0, limit - 3)].rstrip() + "..."


def _audit_priority(task_id: str, reward: float, history: List[Dict[str, Any]], sentinel_task_ids: List[str]) -> float:
    priority = max(0.0, 1.0 - float(reward))
    if task_id in sentinel_task_ids:
        rollup = summarize_sentinel_history(history)
        priority += rollup["false_negatives"] * 2.0
        priority += rollup["false_positives"] * 1.5
        priority += (1.0 - rollup["risk_reduction_rate"]) * 0.8
        priority += rollup["revision_attempts"] * 0.25
    else:
        priority += len(history) * 0.05
    return round(priority, 4)
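

# Worked example of the priority score (hypothetical rollup values): a
# sentinel episode with reward 0.4 whose rollup reports 1 false negative,
# 0 false positives, risk_reduction_rate 0.5, and 2 revision attempts scores
# 0.6 + 2.0 + 0.0 + 0.4 + 0.5 = 3.5, so missed threats dominate the audit
# queue ahead of merely low-reward rollouts.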


class RolloutAuditSampler:
    """Persist a periodic sample of rollout traces for human audit during training."""

    def __init__(self, output_dir: str, every: int, sample_limit: int) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.every = max(0, every)
        self.sample_limit = max(0, sample_limit)
        self.latest_markdown_path = self.output_dir / "latest.md"

    def record_batch(
        self,
        *,
        batch_index: int,
        prompts: List[str],
        completions: List[str],
        rewards: List[float],
        histories: List[List[Dict[str, Any]]],
        task_ids: List[str],
        variant_seeds: List[int],
        monitor_summary: Dict[str, Any],
        sentinel_task_ids: Optional[List[str]] = None,
        active_task_ids: Optional[List[str]] = None,
        reward_schedule: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        if self.every <= 0 or self.sample_limit <= 0:
            return None
        if batch_index % self.every != 0:
            return None
        if sentinel_task_ids is None:
            sentinel_task_ids = list(DEFAULT_SENTINEL_TASK_IDS)
        if active_task_ids is None:
            active_task_ids = task_ids
        candidates: List[Dict[str, Any]] = []
        for index, reward in enumerate(rewards):
            task_id = str(task_ids[index]) if index < len(task_ids) else active_task_ids[0]
            variant_seed = int(variant_seeds[index]) if index < len(variant_seeds) else 0
            history = histories[index] if index < len(histories) else []
            history_summary = (
                summarize_sentinel_history(history)
                if task_id in sentinel_task_ids
                else {"steps": float(len(history))}
            )
            candidates.append(
                {
                    "task_id": task_id,
                    "variant_seed": variant_seed,
                    "reward": round(float(reward), 4),
                    "priority": _audit_priority(task_id, reward, history, sentinel_task_ids),
                    "prompt": prompts[index] if index < len(prompts) else "",
                    "completion": completions[index] if index < len(completions) else "",
                    "history_summary": history_summary,
                    "history": history,
                }
            )
        top_samples = sorted(
            candidates,
            key=lambda item: (item["priority"], item["reward"]),
            reverse=True,
        )[: self.sample_limit]
        payload = {
            "batch_index": batch_index,
            "reward_schedule": reward_schedule or {},
            "monitor_summary": monitor_summary,
            "samples": top_samples,
        }
        json_path = self.output_dir / f"batch_{batch_index:04d}.json"
        json_path.write_text(
            json.dumps(payload, indent=2, sort_keys=True, default=str),
            encoding="utf-8",
        )
        lines = [
            f"# Rollout Audit Batch {batch_index}",
            "",
            f"- Samples: {len(top_samples)}",
            f"- Reward mean: {monitor_summary.get('reward_mean', 0.0):.4f}",
            f"- Running reward mean: {monitor_summary.get('running_reward_mean', 0.0):.4f}",
        ]
        # Optional summary metrics, rendered as (key, label, decimal places).
        optional_metrics = [
            ("approx_kl", "Approx KL", 6),
            ("adaptive_beta", "Adaptive beta", 6),
            ("policy_entropy", "Policy entropy", 6),
            ("decision_entropy", "Decision entropy", 4),
            ("unique_completion_ratio", "Unique completion ratio", 4),
            ("effective_prompt_ratio", "Effective prompt ratio", 4),
            ("frontier_hit_rate", "Frontier hit rate", 4),
            ("task_diversity_ratio", "Task diversity ratio", 4),
            ("zero_gradient_group_fraction", "Zero-gradient group fraction", 4),
            ("adversarial_case_fraction", "Adversarial case fraction", 4),
            ("twin_damage_reduction_rate", "Twin damage reduction rate", 4),
            ("coaching_quality", "Coaching quality", 4),
        ]
        for key, label, digits in optional_metrics:
            if key in monitor_summary:
                value = float(monitor_summary.get(key) or 0.0)
                lines.append(f"- {label}: {value:.{digits}f}")
        if monitor_summary.get("memory"):
            mem = monitor_summary["memory"]
            lines.append(
                f"- Memory: enabled={mem.get('agent_memory_enabled')} cards={mem.get('mistake_cards_stored', 0)}"
            )
        if reward_schedule:
            lines.append(
                f"- Reward schedule: {reward_schedule.get('stage', 'unknown')} ({reward_schedule.get('mode', 'unknown')})"
            )
        lines.append("")
        for sample_index, sample in enumerate(top_samples, start=1):
            history_summary = sample.get("history_summary") or {}
            lines.extend(
                [
                    f"## Sample {sample_index}",
                    "",
                    f"- Task: `{sample['task_id']}`",
                    f"- Seed: `{sample['variant_seed']}`",
                    f"- Reward: `{sample['reward']:.4f}`",
                    f"- Audit priority: `{sample['priority']:.4f}`",
                ]
            )
            if "detection_rate" in history_summary:
                lines.extend(
                    [
                        f"- Detection rate: `{history_summary.get('detection_rate', 0.0):.4f}`",
                        f"- False positive rate: `{history_summary.get('false_positive_rate', 0.0):.4f}`",
                        f"- Risk reduction rate: `{history_summary.get('risk_reduction_rate', 0.0):.4f}`",
                        f"- Twin without SENTINEL damage: `{history_summary.get('twin_without_sentinel_damage_total', 0.0):.4f}`",
                        f"- Twin with SENTINEL damage: `{history_summary.get('twin_with_sentinel_damage_total', 0.0):.4f}`",
                        f"- Rehabilitation rate: `{history_summary.get('worker_rehabilitation_rate', 0.0):.4f}`",
                        f"- Coaching quality: `{history_summary.get('coaching_quality', 0.0):.4f}`",
                    ]
                )
            lines.extend(
                [
                    "",
                    "### Prompt",
                    "",
                    "```text",
                    _truncate_text(str(sample.get("prompt", ""))),
                    "```",
                    "",
                    "### Completion",
                    "",
                    "```json",
                    _truncate_text(str(sample.get("completion", ""))),
                    "```",
                    "",
                ]
            )
        self.latest_markdown_path.write_text("\n".join(lines), encoding="utf-8")
        return str(json_path)
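

# Usage sketch (illustrative; `monitor_summary` would normally be the dict
# returned by TrainingMonitor.log_batch, and every value below is
# hypothetical):
#
#     sampler = RolloutAuditSampler("artifacts/audit", every=10, sample_limit=2)
#     sampler.record_batch(
#         batch_index=10,
#         prompts=["..."],
#         completions=["..."],
#         rewards=[0.4],
#         histories=[[{"action": "approve"}]],
#         task_ids=["basic_oversight"],
#         variant_seeds=[7],
#         monitor_summary={"reward_mean": 0.4, "running_reward_mean": 0.45},
#     )
#     # writes batch_0010.json plus a human-readable latest.md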