# openenv/training/monitoring.py
# sentinel-space-publisher
# space: publish latest Sentinel app snapshot
# c452421
# -*- coding: utf-8 -*-
"""Training monitoring: TrainingMonitor, GRPOStabilityCallback, RolloutAuditSampler.
Extracted from train.py to keep the training pipeline modular.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
try:
    from transformers import TrainerCallback
except ModuleNotFoundError:

    class TrainerCallback:  # type: ignore[no-redef]
        """Empty stand-in base class used when ``transformers`` is not installed.

        It only exists so that classes in this module which subclass
        ``TrainerCallback`` can still be defined and imported.
        """
from training.metrics import safe_ratio, aggregate_batch_metrics, summarize_sentinel_history
# ---------------------------------------------------------------------------
# TrainingMonitor
# ---------------------------------------------------------------------------
class TrainingMonitor:
    """Write structured per-batch training metrics for proof-pack and judge review.

    All artifacts live under ``output_dir``:

    * ``training_metrics.jsonl``      - one JSON line per ``log_batch`` call
    * ``training_stability.jsonl``    - one JSON line per ``log_trainer_metrics`` call
    * ``latest_summary.json``         - merged snapshot of the most recent data
    * ``training_stack_versions.json``- written by ``write_stack_versions``
    * ``twin_replay_metrics.jsonl``   - appended by ``log_twin_replay`` when replays ran
    """

    def __init__(self, output_dir: str) -> None:
        """Create ``output_dir`` (including parents) and reset all running counters."""
        root = Path(output_dir)
        root.mkdir(parents=True, exist_ok=True)
        self.output_dir = root
        self.metrics_path = root / "training_metrics.jsonl"
        self.stability_path = root / "training_stability.jsonl"
        self.summary_path = root / "latest_summary.json"
        self.stack_path = root / "training_stack_versions.json"
        # Running aggregates across every batch logged through log_batch().
        self.batch_index = 0
        self.running_reward_total = 0.0
        self.running_batch_count = 0
        self.best_reward_mean = float("-inf")
        # Most recent payloads; merged together by _write_latest_summary().
        self.latest_batch_metrics: Dict[str, Any] = {}
        self.latest_trainer_metrics: Dict[str, Any] = {}
        self.latest_guardrail: Dict[str, Any] = {}

    def write_stack_versions(self, stack_versions: Dict[str, Any]) -> None:
        """Overwrite the stack-versions file with ``stack_versions`` as sorted, indented JSON."""
        serialized = json.dumps(stack_versions, indent=2, sort_keys=True)
        self.stack_path.write_text(serialized, encoding="utf-8")

    def _write_latest_summary(self) -> None:
        """Rewrite ``latest_summary.json`` from the latest batch/trainer/guardrail data.

        Trainer metrics are merged at the top level *and* kept under
        ``trainer_metrics``; the guardrail is stored under ``kl_guardrail`` with
        its ``current_beta`` mirrored as ``adaptive_beta``. Nothing is written
        while all three sources are still empty.
        """
        snapshot = dict(self.latest_batch_metrics)
        trainer_view = self.latest_trainer_metrics
        guardrail_view = self.latest_guardrail
        if trainer_view:
            snapshot.update(trainer_view)
            snapshot["trainer_metrics"] = dict(trainer_view)
        if guardrail_view:
            snapshot["kl_guardrail"] = dict(guardrail_view)
            snapshot["adaptive_beta"] = guardrail_view.get("current_beta")
        if snapshot:
            self.summary_path.write_text(
                json.dumps(snapshot, indent=2, sort_keys=True),
                encoding="utf-8",
            )

    def log_batch(
        self,
        rewards: List[float],
        histories: List[List[Dict[str, Any]]],
        task_ids: List[str],
        variant_seeds: List[int],
        sentinel_task_ids: Optional[List[str]] = None,
        completions: Optional[List[str]] = None,
        prompts: Optional[List[str]] = None,
        adversarial_cases: Optional[List[str]] = None,
        curriculum_summary: Optional[Dict[str, Any]] = None,
        prompt_refreshes: int = 0,
        reward_schedule: Optional[Dict[str, Any]] = None,
        memory_summary: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Aggregate one batch, append it to the metrics log and refresh the summary.

        The aggregation itself is delegated to ``aggregate_batch_metrics``; its
        result must contain a ``"reward_mean"`` entry, which feeds the running
        and best reward statistics added here.

        Returns:
            The metrics dict that was written (including ``batch_index``,
            ``monitoring_mode``, ``running_reward_mean`` and ``best_reward_mean``).
        """
        self.batch_index += 1
        if sentinel_task_ids is None:
            sentinel_task_ids = ["basic_oversight", "fleet_monitoring_conflict", "adversarial_worker", "multi_crisis_command"]

        record = aggregate_batch_metrics(
            rewards=rewards,
            histories=histories,
            task_ids=task_ids,
            variant_seeds=variant_seeds,
            sentinel_task_ids=sentinel_task_ids,
            completions=completions,
            prompts=prompts,
            adversarial_cases=adversarial_cases,
            curriculum_summary=curriculum_summary,
            prompt_refreshes=prompt_refreshes,
        )
        record["batch_index"] = self.batch_index

        # A single sentinel task in the batch is enough to label it "sentinel".
        has_sentinel_task = any(task_id in sentinel_task_ids for task_id in task_ids)
        record["monitoring_mode"] = "sentinel" if has_sentinel_task else "irt"

        if reward_schedule:
            record["reward_schedule"] = reward_schedule
        if memory_summary:
            record["memory"] = memory_summary

        batch_mean = record["reward_mean"]
        self.running_batch_count += 1
        self.running_reward_total += batch_mean
        self.best_reward_mean = max(self.best_reward_mean, batch_mean)
        running_mean = safe_ratio(self.running_reward_total, self.running_batch_count)
        record["running_reward_mean"] = round(running_mean, 4)
        record["best_reward_mean"] = round(self.best_reward_mean, 4)

        with self.metrics_path.open("a", encoding="utf-8") as sink:
            sink.write(json.dumps(record, sort_keys=True) + "\n")

        self.latest_batch_metrics = dict(record)
        self._write_latest_summary()
        return record

    def log_twin_replay(
        self,
        histories: List[List[Dict[str, Any]]],
        task_ids: List[str],
        variant_seeds: List[int],
        rewards: List[float],
    ) -> Optional[Dict[str, Any]]:
        """Run digital twin counterfactual replay and log metrics.

        Best-effort: any exception (including a missing ``sentinel.twin_replay``
        module) is logged at DEBUG level and ``None`` is returned.
        """
        try:
            from sentinel.twin_replay import compute_batch_twin_metrics

            replay = compute_batch_twin_metrics(histories, task_ids, variant_seeds, rewards)
            ran_any_replay = replay.get("twin_replays", 0) > 0
            if ran_any_replay:
                replay["batch_index"] = self.batch_index
                replay_log = self.output_dir / "twin_replay_metrics.jsonl"
                with replay_log.open("a", encoding="utf-8") as sink:
                    sink.write(json.dumps(replay, sort_keys=True) + "\n")
                # Fold the twin numbers into the dashboard snapshot as well.
                self.latest_batch_metrics.update(replay)
                self._write_latest_summary()
            return replay
        except Exception as exc:
            import logging

            logging.getLogger(__name__).debug("Twin replay skipped: %s", exc)
            return None

    def log_reputation_update(
        self,
        histories: List[List[Dict[str, Any]]],
    ) -> Optional[Dict[str, Dict[str, Any]]]:
        """Update cross-episode worker reputation profiles.

        Each non-empty history is passed to
        ``WorkerReputationTracker.update_from_episode`` and the returned
        mappings are merged (later episodes overwrite earlier keys).
        Best-effort: on any exception the error is logged at DEBUG level and
        ``None`` is returned.
        """
        try:
            from sentinel.reputation import WorkerReputationTracker

            store_path = self.output_dir / "worker_reputation.json"
            tracker = WorkerReputationTracker(str(store_path))
            merged = {}
            for episode in histories:
                if not episode:
                    continue
                merged.update(tracker.update_from_episode(episode))
            return merged
        except Exception as exc:
            import logging

            logging.getLogger(__name__).debug("Reputation update skipped: %s", exc)
            return None

    def log_trainer_metrics(
        self,
        *,
        global_step: int,
        trainer_metrics: Dict[str, Any],
        guardrail: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Append trainer-side metrics to the stability log and refresh the summary.

        Returns:
            The record that was written: ``global_step`` plus every entry of
            ``trainer_metrics`` and, when a non-empty ``guardrail`` is given,
            a ``kl_guardrail`` entry.
        """
        record = {"global_step": int(global_step)}
        record.update(trainer_metrics)
        if guardrail:
            record["kl_guardrail"] = guardrail

        with self.stability_path.open("a", encoding="utf-8") as sink:
            sink.write(json.dumps(record, sort_keys=True) + "\n")

        self.latest_trainer_metrics = dict(trainer_metrics)
        if guardrail:
            # An empty/None guardrail leaves the previously stored one in place.
            self.latest_guardrail = dict(guardrail)
        self._write_latest_summary()
        return record
# ---------------------------------------------------------------------------
# GRPOStabilityCallback
# ---------------------------------------------------------------------------
class GRPOStabilityCallback(TrainerCallback):
    """Hook TRL trainer logs to persist KL/entropy metrics and adapt beta conservatively.

    On every non-eval log event the callback extracts KL, entropy and clip
    ratio from the log dict, optionally nudges the KL coefficient ``beta``
    toward ``target_kl`` and forwards everything to the ``TrainingMonitor``.
    When the hard stop is enabled and KL exceeds
    ``target_kl * max(hard_stop_mult, 1.0)`` it asks the trainer to stop.
    """

    def __init__(
        self,
        training_monitor: TrainingMonitor,
        *,
        initial_beta: float,
        target_kl: float,
        adaptive: bool,
        low_factor: float,
        high_factor: float,
        beta_up_mult: float,
        beta_down_mult: float,
        min_beta: float,
        max_beta: float,
        hard_stop_enabled: bool,
        hard_stop_mult: float,
    ) -> None:
        """Store the guardrail configuration, coercing every value to its plain type."""
        self.training_monitor = training_monitor
        self.trainer = None
        # Current KL coefficient; may be replaced by bind_trainer().
        self.current_beta = float(initial_beta)
        self.target_kl = float(target_kl)
        self.adaptive = bool(adaptive)
        # Band around target_kl inside which beta is left alone.
        self.low_factor = float(low_factor)
        self.high_factor = float(high_factor)
        # Multiplicative step sizes and clamps for beta adjustments.
        self.beta_up_mult = float(beta_up_mult)
        self.beta_down_mult = float(beta_down_mult)
        self.min_beta = float(min_beta)
        self.max_beta = float(max_beta)
        self.hard_stop_enabled = bool(hard_stop_enabled)
        self.hard_stop_mult = float(hard_stop_mult)

    def bind_trainer(self, trainer) -> None:
        """Attach the trainer whose ``beta`` will be adjusted.

        The trainer's own ``beta`` replaces ``current_beta`` unless it is
        missing or falsy (``None`` or ``0``), in which case the current value
        is kept.
        """
        self.trainer = trainer
        trainer_beta = getattr(trainer, "beta", self.current_beta)
        self.current_beta = float(trainer_beta or self.current_beta)

    @staticmethod
    def _first_float(logs: Dict[str, Any], keys: List[str]) -> Optional[float]:
        """Return the first value among ``keys`` in ``logs`` that converts to ``float``.

        Missing, ``None`` and non-convertible values are skipped; ``None`` is
        returned when no key yields a float.
        """
        for name in keys:
            raw = logs.get(name)
            if raw is not None:
                try:
                    return float(raw)
                except (TypeError, ValueError):
                    pass
        return None

    def _apply_beta(self, value: float) -> None:
        """Record ``value`` as the current beta and push it onto the bound trainer.

        Besides ``trainer.beta``, the attributes ``beta`` and ``kl_coef`` on
        ``trainer.args`` are updated when they already exist.
        """
        self.current_beta = float(value)
        trainer = self.trainer
        if trainer is None:
            return
        setattr(trainer, "beta", self.current_beta)
        if hasattr(trainer, "args"):
            trainer_args = trainer.args
            for field_name in ("beta", "kl_coef"):
                if hasattr(trainer_args, field_name):
                    setattr(trainer_args, field_name, self.current_beta)

    def _guardrail_update(self, approx_kl: Optional[float]):
        """Evaluate the KL guardrail for one observation and return its report dict.

        With ``adaptive`` on, beta is multiplied by ``beta_up_mult`` (capped at
        ``max_beta``) when KL is above the high threshold, or by
        ``beta_down_mult`` (floored at ``min_beta``) when below the low
        threshold. A hard stop, if triggered, overrides the reported action.
        With ``approx_kl`` of ``None`` the report is returned unchanged
        (action ``"hold"``).
        """
        # Factors below 1.0 are treated as 1.0 so the band never inverts.
        kl_floor = self.target_kl / max(self.low_factor, 1.0)
        kl_ceiling = self.target_kl * max(self.high_factor, 1.0)
        beta_at_entry = round(self.current_beta, 6)
        report = {
            "enabled": self.adaptive,
            "target_kl": round(self.target_kl, 4),
            "low_threshold": round(kl_floor, 4),
            "high_threshold": round(kl_ceiling, 4),
            "previous_beta": beta_at_entry,
            "current_beta": beta_at_entry,
            "action": "hold",
            "hard_stop_triggered": False,
        }
        if approx_kl is None:
            return report

        proposed_beta = self.current_beta
        if self.adaptive:
            if approx_kl > kl_ceiling:
                proposed_beta = min(self.max_beta, self.current_beta * self.beta_up_mult)
                report["action"] = "increase_beta"
            elif approx_kl < kl_floor:
                proposed_beta = max(self.min_beta, self.current_beta * self.beta_down_mult)
                report["action"] = "decrease_beta"

        # Only touch the trainer when the value actually moved.
        if abs(proposed_beta - self.current_beta) > 1e-12:
            self._apply_beta(proposed_beta)
        report["current_beta"] = round(self.current_beta, 6)

        hard_stop_limit = self.target_kl * max(self.hard_stop_mult, 1.0)
        if self.hard_stop_enabled and approx_kl > hard_stop_limit:
            report["hard_stop_triggered"] = True
            report["action"] = "hard_stop"
        return report

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Handle a trainer log event; always returns ``control``.

        Events containing any ``eval_``-prefixed key, and events carrying none
        of KL / entropy / clip ratio, are ignored. Otherwise the guardrail is
        evaluated, the rounded metrics are forwarded to the monitor and
        ``control.should_training_stop`` is set when a hard stop triggered.
        """
        logs = logs or {}
        for key in logs:
            if str(key).startswith("eval_"):
                return control

        approx_kl = self._first_float(logs, ["kl", "objective/kl"])
        policy_entropy = self._first_float(logs, ["entropy", "policy/entropy"])
        clip_ratio = self._first_float(logs, ["clip_ratio/region_mean", "clip_ratio", "objective/clip_ratio"])
        observed = (
            ("approx_kl", approx_kl),
            ("policy_entropy", policy_entropy),
            ("clip_ratio", clip_ratio),
        )
        if all(value is None for _, value in observed):
            return control

        guardrail = self._guardrail_update(approx_kl)
        # Metrics that were absent from the logs are simply left out.
        available_metrics = {
            name: round(float(value), 6)
            for name, value in observed
            if value is not None
        }
        step = int(getattr(state, "global_step", 0) or 0)
        self.training_monitor.log_trainer_metrics(
            global_step=step,
            trainer_metrics=available_metrics,
            guardrail=guardrail,
        )
        if guardrail.get("hard_stop_triggered"):
            control.should_training_stop = True
        return control
# ---------------------------------------------------------------------------
# RolloutAuditSampler
# ---------------------------------------------------------------------------
def _truncate_text(text: str, limit: int = 700) -> str:
    """Return ``text`` stripped of surrounding whitespace and capped at ``limit`` characters.

    Text that fits is returned as-is. Longer text is cut to ``limit - 3``
    characters, right-stripped, and suffixed with ``"..."``. When ``limit`` is
    smaller than the ellipsis itself (or negative) the text is hard-cut instead,
    so the result never exceeds the requested limit.

    Args:
        text: Text to shorten; ``None`` or empty is treated as ``""``.
        limit: Maximum length of the returned string; negative values count as 0.
    """
    clean = (text or "").strip()
    limit = max(0, limit)
    if len(clean) <= limit:
        return clean
    if limit < 3:
        # No room for "..." without overshooting the limit.
        return clean[:limit]
    return clean[: limit - 3].rstrip() + "..."
def _audit_priority(task_id: str, reward: float, history: List[Dict[str, Any]], sentinel_task_ids: List[str]) -> float:
    """Score how much a rollout deserves human audit (higher = look at it first).

    The base score is the reward shortfall ``max(0, 1 - reward)``. Non-sentinel
    tasks add 0.05 per history step. Sentinel tasks instead add weighted terms
    from ``summarize_sentinel_history``: false negatives (x2.0), false
    positives (x1.5), unreduced risk ``1 - risk_reduction_rate`` (x0.8) and
    revision attempts (x0.25). The result is rounded to 4 decimals.
    """
    score = max(0.0, 1.0 - float(reward))
    if task_id not in sentinel_task_ids:
        return round(score + len(history) * 0.05, 4)

    rollup = summarize_sentinel_history(history)
    weighted_terms = (
        rollup["false_negatives"] * 2.0,
        rollup["false_positives"] * 1.5,
        (1.0 - rollup["risk_reduction_rate"]) * 0.8,
        rollup["revision_attempts"] * 0.25,
    )
    # Accumulate left to right so float rounding matches a plain chain of "+=".
    for term in weighted_terms:
        score += term
    return round(score, 4)
class RolloutAuditSampler:
    """Persist a periodic sample of rollout traces for human audit during training.

    Every ``every``-th batch, the ``sample_limit`` highest-priority rollouts are
    written to ``batch_NNNN.json`` and a human-readable ``latest.md`` (which is
    overwritten each time) inside ``output_dir``.
    """

    # Task ids treated as sentinel tasks when the caller does not supply a list.
    _DEFAULT_SENTINEL_TASK_IDS = (
        "basic_oversight",
        "fleet_monitoring_conflict",
        "adversarial_worker",
        "multi_crisis_command",
    )

    # Optional batch-level summary lines: (monitor_summary key, label, decimals).
    # A line is emitted only when the key is present in the summary.
    _SUMMARY_FIELDS = (
        ("approx_kl", "Approx KL", 6),
        ("adaptive_beta", "Adaptive beta", 6),
        ("policy_entropy", "Policy entropy", 6),
        ("decision_entropy", "Decision entropy", 4),
        ("unique_completion_ratio", "Unique completion ratio", 4),
        ("effective_prompt_ratio", "Effective prompt ratio", 4),
        ("frontier_hit_rate", "Frontier hit rate", 4),
        ("task_diversity_ratio", "Task diversity ratio", 4),
        ("zero_gradient_group_fraction", "Zero-gradient group fraction", 4),
        ("adversarial_case_fraction", "Adversarial case fraction", 4),
        ("twin_damage_reduction_rate", "Twin damage reduction rate", 4),
        ("coaching_quality", "Coaching quality", 4),
    )

    # Per-sample lines shown when the history summary has a "detection_rate"
    # entry: (history_summary key, label). Missing keys render as 0.0.
    _HISTORY_FIELDS = (
        ("detection_rate", "Detection rate"),
        ("false_positive_rate", "False positive rate"),
        ("risk_reduction_rate", "Risk reduction rate"),
        ("twin_without_sentinel_damage_total", "Twin without SENTINEL damage"),
        ("twin_with_sentinel_damage_total", "Twin with SENTINEL damage"),
        ("worker_rehabilitation_rate", "Rehabilitation rate"),
        ("coaching_quality", "Coaching quality"),
    )

    def __init__(self, output_dir: str, every: int, sample_limit: int) -> None:
        """Create ``output_dir`` and store the sampling cadence.

        Args:
            output_dir: Directory for the audit artifacts (created if missing).
            every: Sample every N-th batch; values <= 0 disable sampling.
            sample_limit: Max rollouts kept per sampled batch; <= 0 disables sampling.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.every = max(0, every)
        self.sample_limit = max(0, sample_limit)
        self.latest_markdown_path = self.output_dir / "latest.md"

    @staticmethod
    def _format_number(value: Any, digits: int) -> str:
        """Format ``value`` with ``digits`` decimals, or ``"n/a"`` if it is not numeric.

        Summary values can legitimately be ``None`` (e.g. a beta that was never
        set); the audit report should still be written in that case.
        """
        try:
            return f"{float(value):.{digits}f}"
        except (TypeError, ValueError):
            return "n/a"

    def _render_markdown(
        self,
        batch_index: int,
        top_samples: List[Dict[str, Any]],
        monitor_summary: Dict[str, Any],
        reward_schedule: Optional[Dict[str, Any]],
    ) -> str:
        """Build the Markdown audit report for one sampled batch."""
        fmt = self._format_number
        lines = [
            f"# Rollout Audit Batch {batch_index}",
            "",
            f"- Samples: {len(top_samples)}",
            f"- Reward mean: {fmt(monitor_summary.get('reward_mean', 0.0), 4)}",
            f"- Running reward mean: {fmt(monitor_summary.get('running_reward_mean', 0.0), 4)}",
        ]
        for key, label, digits in self._SUMMARY_FIELDS:
            if key in monitor_summary:
                lines.append(f"- {label}: {fmt(monitor_summary[key], digits)}")
        if monitor_summary.get("memory"):
            mem = monitor_summary["memory"]
            lines.append(
                f"- Memory: enabled={mem.get('agent_memory_enabled')} cards={mem.get('mistake_cards_stored', 0)}"
            )
        if reward_schedule:
            lines.append(
                f"- Reward schedule: {reward_schedule.get('stage', 'unknown')} ({reward_schedule.get('mode', 'unknown')})"
            )
        lines.append("")

        for sample_index, sample in enumerate(top_samples, start=1):
            history_summary = sample.get("history_summary") or {}
            lines.extend(
                [
                    f"## Sample {sample_index}",
                    "",
                    f"- Task: `{sample['task_id']}`",
                    f"- Seed: `{sample['variant_seed']}`",
                    f"- Reward: `{sample['reward']:.4f}`",
                    f"- Audit priority: `{sample['priority']:.4f}`",
                ]
            )
            # Only summaries carrying a detection rate get the detailed lines.
            if "detection_rate" in history_summary:
                lines.extend(
                    f"- {label}: `{fmt(history_summary.get(key, 0.0), 4)}`"
                    for key, label in self._HISTORY_FIELDS
                )
            lines.extend(
                [
                    "",
                    "### Prompt",
                    "",
                    "```text",
                    _truncate_text(str(sample.get("prompt", ""))),
                    "```",
                    "",
                    "### Completion",
                    "",
                    "```json",
                    _truncate_text(str(sample.get("completion", ""))),
                    "```",
                    "",
                ]
            )
        return "\n".join(lines)

    def record_batch(
        self,
        *,
        batch_index: int,
        prompts: List[str],
        completions: List[str],
        rewards: List[float],
        histories: List[List[Dict[str, Any]]],
        task_ids: List[str],
        variant_seeds: List[int],
        sentinel_task_ids: Optional[List[str]] = None,
        active_task_ids: Optional[List[str]] = None,
        monitor_summary: Dict[str, Any],
        reward_schedule: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """Write the audit sample for ``batch_index`` if this batch is due.

        One candidate is built per entry of ``rewards``; the parallel lists may
        be shorter, in which case the missing task id falls back to
        ``active_task_ids[0]`` (or ``"unknown"`` if there is none), the seed to
        0, the history to ``[]`` and prompt/completion to ``""``. Candidates are
        ranked by ``(priority, reward)`` descending and the top
        ``sample_limit`` are kept.

        Args:
            sentinel_task_ids: Task ids summarised via
                ``summarize_sentinel_history``; defaults to the four built-in ids.
            active_task_ids: Fallback pool for missing task ids; defaults to
                ``task_ids``.
            monitor_summary: Batch metrics shown in the report header.

        Returns:
            Path of the written JSON file as a string, or ``None`` when sampling
            is disabled or ``batch_index`` is not a multiple of ``every``.
        """
        if sentinel_task_ids is None:
            sentinel_task_ids = list(self._DEFAULT_SENTINEL_TASK_IDS)
        if active_task_ids is None:
            active_task_ids = task_ids
        if self.every <= 0 or self.sample_limit <= 0:
            return None
        if batch_index % self.every != 0:
            return None

        # Guard the fallback: indexing an empty list here used to raise IndexError.
        fallback_task_id = active_task_ids[0] if active_task_ids else "unknown"

        candidates: List[Dict[str, Any]] = []
        for index, reward in enumerate(rewards):
            task_id = str(task_ids[index]) if index < len(task_ids) else fallback_task_id
            variant_seed = int(variant_seeds[index]) if index < len(variant_seeds) else 0
            history = histories[index] if index < len(histories) else []
            history_summary = (
                summarize_sentinel_history(history)
                if task_id in sentinel_task_ids
                else {"steps": float(len(history))}
            )
            candidates.append(
                {
                    "task_id": task_id,
                    "variant_seed": variant_seed,
                    "reward": round(float(reward), 4),
                    "priority": _audit_priority(task_id, reward, history, sentinel_task_ids),
                    "prompt": prompts[index] if index < len(prompts) else "",
                    "completion": completions[index] if index < len(completions) else "",
                    "history_summary": history_summary,
                    "history": history,
                }
            )

        top_samples = sorted(
            candidates,
            key=lambda item: (item["priority"], item["reward"]),
            reverse=True,
        )[: self.sample_limit]

        payload = {
            "batch_index": batch_index,
            "reward_schedule": reward_schedule or {},
            "monitor_summary": monitor_summary,
            "samples": top_samples,
        }
        json_path = self.output_dir / f"batch_{batch_index:04d}.json"
        json_path.write_text(
            # default=str keeps the dump from failing on non-JSON values in histories.
            json.dumps(payload, indent=2, sort_keys=True, default=str),
            encoding="utf-8",
        )

        report = self._render_markdown(batch_index, top_samples, monitor_summary, reward_schedule)
        self.latest_markdown_path.write_text(report, encoding="utf-8")
        return str(json_path)