"""
train.py - GRPO Fine-tuning for OpenEnv (IRT / SENTINEL)
==============================================================
Runnable training script. Uses TRL GRPOTrainer + Unsloth (optional) + curriculum.
HOW TO RUN:
# Default (Unsloth enabled, A100 / H100, ~2x faster):
python train.py
# Without Unsloth (e.g. T4 / A10G):
USE_UNSLOTH=0 python train.py
# Override model and steps:
MODEL_NAME=unsloth/Qwen3-30B-A3B-bnb-4bit TRAIN_STEPS=200 python train.py
# Resume from checkpoint:
RESUME_FROM=outputs/checkpoints/checkpoint-100 python train.py
ENV VARS:
MODEL_NAME HuggingFace model ID (default: unsloth/Qwen3-30B-A3B-bnb-4bit)
HF_TOKEN HuggingFace token (for gated models)
GROQ_API_KEY Groq API key (for LLM judge panel, optional)
WANDB_PROJECT W&B project name (optional, set to "" to disable)
TRAIN_STEPS         Number of GRPO training steps (default: 100)
NUM_GENERATIONS     G rollouts per prompt (default: 2)
USE_UNSLOTH         Set to "0" to disable Unsloth (default: 1; requires unsloth installed)
RESUME_FROM Path to checkpoint to resume from
OUTPUT_DIR Where to save checkpoints (default: outputs/checkpoints)
LR Learning rate (default: 5e-6)
KL_COEF KL penalty coefficient (default: 0.04)
LORA_R LoRA rank (default: 16)
TRAIN_MONITOR_DIR Structured metrics output dir (default: outputs/monitoring)
WARM_START_STEPS Optional small warm-start steps before GRPO (default: 0)
WARM_START_LR Learning rate for warm-start stage (default: 2e-5)
WARM_START_ONLY Set to "1" to stop after warm-start
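USE_SENTINEL        Set to "1" to train on the SENTINEL oversight tasks instead of IRT (default: 0)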
"""
from __future__ import annotations
import json
import logging
import math
import os
import platform
import sys
import time
from dataclasses import dataclass, field
from importlib import metadata as importlib_metadata
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
try:
import torch
from torch.utils.data import Dataset as TorchDataset
except ModuleNotFoundError:
torch = None
class TorchDataset: # type: ignore[no-redef]
"""Fallback base for tests that import train.py without training deps."""
pass
# bnb-4bit pre-quantized models have compute_dtype=float16 baked in, so LoRA
# adapter parameters and their gradients are FP16. PyTorch 2.10 added a strict
# check in GradScaler._unscale_grads_ that rejects FP16 gradients (intended for
# full-precision training where FP16 grads indicate a misconfiguration). For
# bnb-4bit + LoRA this check is a false positive — patch it out.
if torch is not None:
import torch.amp.grad_scaler as _gs
_orig_unscale_grads = _gs.GradScaler._unscale_grads_
def _allow_fp16_unscale(self, optimizer, inv_scale, found_inf, allow_fp16):
return _orig_unscale_grads(self, optimizer, inv_scale, found_inf, True)
_gs.GradScaler._unscale_grads_ = _allow_fp16_unscale
# Re-export from extracted modules for backward compatibility
from training.metrics import (
safe_ratio as _safe_ratio,
_increment_counter,
_normalize_completion_text,
_extract_completion_choice,
_shannon_entropy_from_labels,
summarize_sentinel_history as _summarize_sentinel_history,
aggregate_batch_metrics as _aggregate_batch_metrics,
completion_diversity_metrics as _completion_diversity_metrics,
productive_signal_metrics as _productive_signal_metrics,
training_coverage_metrics as _training_coverage_metrics,
zero_gradient_group_metrics as _zero_gradient_group_metrics,
frontier_scenario_keys as _frontier_scenario_keys,
set_thresholds as _set_metric_thresholds,
)
from training.monitoring import (
TrainingMonitor,
GRPOStabilityCallback,
RolloutAuditSampler,
_truncate_text,
_audit_priority,
)
from training.prompts import (
build_system_prompt,
scenario_to_prompt,
sentinel_obs_to_prompt,
sentinel_adversarial_case_to_prompt,
build_prompt_record as _build_prompt_record_impl,
memory_context_for_task as _memory_context_for_task,
load_or_create_sentinel_adversarial_cases as _load_or_create_sentinel_adversarial_cases_impl,
AdaptivePromptState as _AdaptivePromptStateBase,
AdaptivePromptDataset,
WarmStartDataset,
build_grpo_dataset as _build_grpo_dataset_impl,
)
from training.episodes import (
parse_action as _parse_action,
greedy_fallback_action as _greedy_fallback_action,
greedy_fallback_sentinel_decision as _greedy_fallback_sentinel_decision,
run_episode_with_completion as _run_episode_with_completion_impl,
_run_irt_episode,
_run_sentinel_episode,
run_sentinel_adversarial_case as _run_sentinel_adversarial_case,
grpo_reward_fn as _grpo_reward_fn_impl,
trajectory_summary_from_history as _trajectory_summary_from_history,
mistakes_from_history as _mistakes_from_history,
mistake_cards_from_history as _mistake_cards_from_history,
successes_from_history as _successes_from_history,
)
from training.curriculum import CURRICULUM_FRONTIER_FAILURE_RATE
MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen3-30B-A3B-bnb-4bit")
HF_TOKEN = os.getenv("HF_TOKEN", "")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "").strip()
TRAIN_STEPS = int(os.getenv("TRAIN_STEPS", "100"))
NUM_GENERATIONS = int(os.getenv("NUM_GENERATIONS", "2"))
USE_UNSLOTH = os.getenv("USE_UNSLOTH", "1") == "1"
RESUME_FROM = os.getenv("RESUME_FROM", "")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/checkpoints")
LR = float(os.getenv("LR", "5e-6"))
KL_COEF = float(os.getenv("KL_COEF", "0.04"))
LORA_R = int(os.getenv("LORA_R", "16"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
PROMPT_DATASET_SIZE = int(os.getenv("PROMPT_DATASET_SIZE", str(max(512, TRAIN_STEPS * 8))))
USE_LLM_PANEL = bool(GROQ_API_KEY) # auto-enable if key available
USE_CURRICULUM = os.getenv("USE_CURRICULUM", "1") == "1"
GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.7"))
GEN_TOP_P = float(os.getenv("GEN_TOP_P", "1.0"))
USE_SENTINEL = os.getenv("USE_SENTINEL", "0") == "1" # Enable SENTINEL training
USE_AGENT_MEMORY = os.getenv("USE_AGENT_MEMORY", "1") == "1"
USE_FEEDBACK_MEMORY = os.getenv("USE_FEEDBACK_MEMORY", "1") == "1" and USE_AGENT_MEMORY
USE_SENTINEL_ADVERSARIAL = os.getenv("USE_SENTINEL_ADVERSARIAL", "1") == "1"
SENTINEL_ADVERSARIAL_PATH = os.getenv(
"SENTINEL_ADVERSARIAL_PATH",
"outputs/sentinel_adversarial_cases.json",
)
SENTINEL_FEEDBACK_MEMORY_PATH = os.getenv(
"SENTINEL_FEEDBACK_MEMORY_PATH",
"outputs/sentinel_feedback_memory.json",
)
TRAIN_MONITOR_DIR = os.getenv("TRAIN_MONITOR_DIR", "outputs/monitoring")
WARM_START_STEPS = int(os.getenv("WARM_START_STEPS", "0"))
WARM_START_LR = float(os.getenv("WARM_START_LR", "2e-5"))
WARM_START_DATASET_SIZE = int(os.getenv("WARM_START_DATASET_SIZE", "24"))
WARM_START_OUTPUT_DIR = os.getenv("WARM_START_OUTPUT_DIR", "outputs/warm_start")
WARM_START_ONLY = os.getenv("WARM_START_ONLY", "0") == "1"
ROLLOUT_AUDIT_DIR = os.getenv("ROLLOUT_AUDIT_DIR", os.path.join(TRAIN_MONITOR_DIR, "rollout_audits"))
ROLLOUT_AUDIT_EVERY = int(os.getenv("ROLLOUT_AUDIT_EVERY", "10"))
ROLLOUT_AUDIT_SAMPLES = int(os.getenv("ROLLOUT_AUDIT_SAMPLES", "2"))
REWARD_SCHEDULE_MODE = os.getenv("REWARD_SCHEDULE_MODE", os.getenv("REWARD_PROFILE", "dynamic"))
MODEL_STEPS_LIMIT = int(os.getenv("MODEL_STEPS_LIMIT", "1"))
KL_TARGET = float(os.getenv("KL_TARGET", "0.08"))
KL_ADAPTIVE = os.getenv("KL_ADAPTIVE", "1") == "1"
KL_LOW_FACTOR = float(os.getenv("KL_LOW_FACTOR", "1.5"))
KL_HIGH_FACTOR = float(os.getenv("KL_HIGH_FACTOR", "1.5"))
KL_BETA_UP_MULT = float(os.getenv("KL_BETA_UP_MULT", "2.0"))
KL_BETA_DOWN_MULT = float(os.getenv("KL_BETA_DOWN_MULT", "0.5"))
KL_MIN_BETA = float(os.getenv("KL_MIN_BETA", "0.005"))
KL_MAX_BETA = float(os.getenv("KL_MAX_BETA", "0.5"))
KL_HARD_STOP_ENABLED = os.getenv("KL_HARD_STOP_ENABLED", "0") == "1"
KL_HARD_STOP_MULT = float(os.getenv("KL_HARD_STOP_MULT", "3.0"))
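# The KL_* knobs above parameterize GRPOStabilityCallback (training/monitoring.py).
# Rough sketch of the assumed control rule (the authoritative logic lives in the callback):
#   if kl > KL_TARGET * KL_HIGH_FACTOR:  beta = min(KL_MAX_BETA, beta * KL_BETA_UP_MULT)
#   elif kl < KL_TARGET / KL_LOW_FACTOR: beta = max(KL_MIN_BETA, beta * KL_BETA_DOWN_MULT)
#   if KL_HARD_STOP_ENABLED and kl > KL_TARGET * KL_HARD_STOP_MULT: stop training early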
ZERO_SIGNAL_REWARD_THRESHOLD = float(os.getenv("ZERO_SIGNAL_REWARD_THRESHOLD", "0.05"))
TRIVIAL_REWARD_THRESHOLD = float(os.getenv("TRIVIAL_REWARD_THRESHOLD", "0.95"))
TASK_IDS = [
"severity_classification",
"root_cause_analysis",
"full_incident_management",
]
SENTINEL_TASK_IDS = [
"basic_oversight",
"fleet_monitoring_conflict",
"adversarial_worker",
"multi_crisis_command",
]
def _parse_task_filter(env_name: str, allowed: List[str]) -> List[str]:
raw = os.getenv(env_name, "").strip()
if not raw:
return list(allowed)
selected = [part.strip() for part in raw.split(",") if part.strip()]
unknown = [task_id for task_id in selected if task_id not in allowed]
if unknown:
raise ValueError(
f"{env_name} contains unknown task id(s): {unknown}. "
f"Allowed: {allowed}"
)
return selected or list(allowed)
TASK_IDS = _parse_task_filter("IRT_TASKS", TASK_IDS)
SENTINEL_TASK_IDS = _parse_task_filter("SENTINEL_TASKS", SENTINEL_TASK_IDS)
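# Example: IRT_TASKS="severity_classification,root_cause_analysis" python train.py
# restricts IRT training to those two tasks; SENTINEL_TASKS filters the SENTINEL set
# the same way. Unknown task ids raise a ValueError (see _parse_task_filter above).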
# Select task set based on USE_SENTINEL flag
ACTIVE_TASK_IDS = SENTINEL_TASK_IDS if USE_SENTINEL else TASK_IDS
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs("outputs/reward_curves", exist_ok=True)
os.makedirs(TRAIN_MONITOR_DIR, exist_ok=True)
logging.basicConfig(
level = logging.INFO,
format = "%(asctime)s %(levelname)s %(name)s: %(message)s",
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler(os.path.join(OUTPUT_DIR, "train.log")),
],
)
logger = logging.getLogger(__name__)
def _package_version(name: str) -> str:
try:
return importlib_metadata.version(name)
except importlib_metadata.PackageNotFoundError:
return "missing"
def collect_training_stack_versions() -> Dict[str, Any]:
cuda_available = bool(torch is not None and torch.cuda.is_available())
return {
"python": platform.python_version(),
"platform": platform.platform(),
"model_name": MODEL_NAME,
"use_unsloth": USE_UNSLOTH,
"cuda_available": cuda_available,
"bf16_available": bool(cuda_available and torch.cuda.is_bf16_supported()),
"train_steps": TRAIN_STEPS,
"warm_start_steps": WARM_START_STEPS,
"reward_schedule_mode": REWARD_SCHEDULE_MODE,
"memory": {
"agent_memory_enabled": USE_AGENT_MEMORY,
"feedback_memory_enabled": USE_FEEDBACK_MEMORY,
},
"productive_signal_thresholds": {
"zero_signal_reward_threshold": ZERO_SIGNAL_REWARD_THRESHOLD,
"trivial_reward_threshold": TRIVIAL_REWARD_THRESHOLD,
},
"adaptive_curriculum": {
"frontier_failure_rate": CURRICULUM_FRONTIER_FAILURE_RATE,
},
"kl_control": {
"initial_beta": KL_COEF,
"target": KL_TARGET,
"adaptive": KL_ADAPTIVE,
"low_factor": KL_LOW_FACTOR,
"high_factor": KL_HIGH_FACTOR,
"beta_up_mult": KL_BETA_UP_MULT,
"beta_down_mult": KL_BETA_DOWN_MULT,
"min_beta": KL_MIN_BETA,
"max_beta": KL_MAX_BETA,
"hard_stop_enabled": KL_HARD_STOP_ENABLED,
"hard_stop_mult": KL_HARD_STOP_MULT,
},
"packages": {
"torch": getattr(torch, "__version__", "missing") if torch is not None else "missing",
"bitsandbytes": _package_version("bitsandbytes"),
"transformers": _package_version("transformers"),
"peft": _package_version("peft"),
"trl": _package_version("trl"),
"datasets": _package_version("datasets"),
"matplotlib": _package_version("matplotlib"),
"wandb": _package_version("wandb"),
"openenv-core": _package_version("openenv-core"),
"unsloth": _package_version("unsloth"),
},
}
# ---------------------------------------------------------------------------
# W&B setup (optional)
# ---------------------------------------------------------------------------
wandb_enabled = bool(WANDB_PROJECT) and WANDB_PROJECT.lower() not in {"0", "false", "none", "disabled"}
if wandb_enabled:
try:
import wandb
wandb.init(project=WANDB_PROJECT, config={
"model": MODEL_NAME,
"train_steps": TRAIN_STEPS,
"num_generations": NUM_GENERATIONS,
"lr": LR,
"kl_coef": KL_COEF,
"lora_r": LORA_R,
"use_llm_panel": USE_LLM_PANEL,
})
logger.info("W&B enabled: project=%s", WANDB_PROJECT)
except ImportError:
wandb_enabled = False
logger.warning("wandb not installed -- logging disabled")
except Exception as exc:
wandb_enabled = False
logger.warning("wandb init skipped: %s", exc)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model_and_tokenizer():
"""Load model + tokenizer. Uses Unsloth if USE_UNSLOTH=1, else standard HF.
When Unsloth is enabled:
- 12x faster MoE training via Triton kernels (torch._grouped_mm)
- 3x faster inference via fused attention (FastLanguageModel.for_inference)
- >35% less VRAM via 4-bit quantization + gradient checkpointing
"""
if torch is None:
raise ImportError(
"Training requires torch. Install the training extras before running train.py."
)
if USE_UNSLOTH:
logger.info("Loading model with Unsloth: %s", MODEL_NAME)
from unsloth import FastLanguageModel
# IMPORTANT: keep dtype=float16 for bnb-4bit. The pre-quantized
# unsloth/*-bnb-4bit models have compute_dtype=float16 baked into their
# quantization config. Unsloth's fast_lora kernels use X.dtype as the
# target dtype for LoRA ops; if X is BF16 but bnb dequant output is FP16
# the addmm_ inside matmul_lora crashes with "same dtype" error.
_unsloth_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = MODEL_NAME,
max_seq_length = 4096,
dtype = _unsloth_dtype,
load_in_4bit = True,
token = HF_TOKEN or None,
)
model = FastLanguageModel.get_peft_model(
model,
r = LORA_R,
target_modules = ["q_proj","k_proj","v_proj","o_proj",
"gate_proj","up_proj","down_proj"],
lora_alpha = LORA_R,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 42,
)
# Enable Unsloth fast inference (2-3x speedup for generation)
# GRPOTrainer internally handles train/eval mode toggling, but
# setting this up front ensures optimized attention kernels are
# compiled and ready for the first rollout batch.
try:
FastLanguageModel.for_inference(model)
logger.info("Unsloth fast inference enabled (fused attention kernels)")
except Exception as exc:
logger.warning("Unsloth fast inference setup skipped: %s", exc)
else:
logger.info("Loading model with standard HF: %s", MODEL_NAME)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
cuda_available = torch.cuda.is_available()
bf16_available = cuda_available and torch.cuda.is_bf16_supported()
load_kwargs: Dict[str, Any] = {
"torch_dtype": torch.bfloat16 if bf16_available else (torch.float16 if cuda_available else torch.float32),
"device_map" : "auto" if cuda_available else None,
}
if "bnb-4bit" in MODEL_NAME or "4bit" in MODEL_NAME:
from transformers import BitsAndBytesConfig
load_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = torch.bfloat16,
)
load_kwargs.pop("torch_dtype", None)
if HF_TOKEN:
load_kwargs["token"] = HF_TOKEN
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN or None)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
lora_config = LoraConfig(
r = LORA_R,
lora_alpha = LORA_R,
target_modules = ["q_proj","k_proj","v_proj","o_proj"],
lora_dropout = 0.05,
bias = "none",
task_type = "CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
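    # Causal LMs are padded on the left so every prompt in a batch ends immediately
    # before the first generated token; this applies to both the Unsloth and plain-HF paths.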
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
if RESUME_FROM:
logger.info("Resuming from checkpoint: %s", RESUME_FROM)
from peft import PeftModel
model = PeftModel.from_pretrained(model, RESUME_FROM)
return model, tokenizer
# ---------------------------------------------------------------------------
# Backward-compatible re-exports for tests (dataset and prompt construction now
# live in training.prompts; episode logic in training.episodes)
# ---------------------------------------------------------------------------
AdaptivePromptState = _AdaptivePromptStateBase
build_prompt_record = _build_prompt_record_impl
build_grpo_dataset = _build_grpo_dataset_impl
_load_or_create_sentinel_adversarial_cases = _load_or_create_sentinel_adversarial_cases_impl
_aggregate_batch_metrics = _aggregate_batch_metrics
def _sentinel_history_entry(decision, result):
    """Backward-compatible wrapper; delegates to training.episodes (imported lazily)."""
    from training.episodes import _sentinel_history_entry as _she
    return _she(decision, result)
# ---------------------------------------------------------------------------
# Thin wrappers delegating to extracted modules
# ---------------------------------------------------------------------------
# Prompt construction
def _build_system_prompt(task_id, memory_context=""):
return build_system_prompt(task_id, SENTINEL_TASK_IDS, memory_context)
def _scenario_to_prompt(scenario, task_id, memory_context=""):
return scenario_to_prompt(scenario, task_id, SENTINEL_TASK_IDS, memory_context)
def _sentinel_obs_to_prompt(obs, task_id, memory_context=""):
return sentinel_obs_to_prompt(obs, task_id, SENTINEL_TASK_IDS, memory_context)
# Episode execution
def run_episode_with_completion(completion_text, task_id, variant_seed):
return _run_episode_with_completion_impl(
completion_text, task_id, variant_seed, SENTINEL_TASK_IDS,
model_steps_limit=MODEL_STEPS_LIMIT,
)
def grpo_reward_fn(prompts, completions, **kwargs):
return _grpo_reward_fn_impl(
prompts, completions,
sentinel_task_ids=SENTINEL_TASK_IDS,
active_task_ids=list(ACTIVE_TASK_IDS),
use_llm_panel=USE_LLM_PANEL,
groq_api_key=GROQ_API_KEY,
wandb_enabled=wandb_enabled,
model_steps_limit=MODEL_STEPS_LIMIT,
**kwargs,
)
# Warm-start helpers
def _warm_start_counterfactual(obs, decision):
constitution = obs.constitutional_assessment or {}
violations = list((constitution.get("constitutional_violations") or {}).keys())
if decision.get("decision") == "APPROVE":
return "If approved, the proposal stays within the current safe operating envelope."
if violations:
return "If approved unchanged, this proposal would likely violate " + ", ".join(violations) + " and increase operational risk."
return "If approved unchanged, this proposal could bypass oversight without sufficient justification."
def _warm_start_sentinel_decision_for_observation(obs):
decision = dict(_greedy_fallback_sentinel_decision(obs, []))
violations = sorted((obs.constitutional_assessment.get("constitutional_violations") or {}).keys())
decision.setdefault("worker_message", "Approved. Execute carefully and preserve evidence.")
for key in ["required_evidence", "redirect_action", "reassign_to", "flag_severity", "suggested_action_type", "suggested_target"]:
decision.setdefault(key, [] if key == "required_evidence" else None)
decision.setdefault("suggested_parameters", {})
decision["constitutional_violations"] = violations
decision["counterfactual"] = _warm_start_counterfactual(obs, decision)
return decision
def _build_warm_start_examples(task_ids, memory_context="", memory=None, feedback_memory=None, max_examples=None, max_seeds=3):
    if max_examples is None:
        max_examples = WARM_START_DATASET_SIZE
records = []
for task_id in task_ids:
for seed in range(max_seeds):
task_memory = _memory_context_for_task(memory, feedback_memory, task_id, memory_context)
if task_id in SENTINEL_TASK_IDS:
from sentinel.environment import SentinelEnv
env = SentinelEnv()
obs = env.reset(task_id=task_id, variant_seed=seed)
prompt = _sentinel_obs_to_prompt(obs, task_id, task_memory)
response = _warm_start_sentinel_decision_for_observation(obs)
else:
from src.environment import IncidentResponseEnv
env = IncidentResponseEnv()
obs = env.reset(task_id=task_id, variant_seed=seed)
prompt = _scenario_to_prompt(env._scenario, task_id, task_memory)
response = _greedy_fallback_action(env, obs, [])
records.append({"task_id": task_id, "variant_seed": seed, "text": prompt + json.dumps(response, sort_keys=True)})
if len(records) >= max_examples: return records
if records and len(records) < max_examples:
cycled = []
idx = 0
while len(records) + len(cycled) < max_examples:
cycled.append(dict(records[idx % len(records)]))
idx += 1
records.extend(cycled)
return records[:max_examples]
def _run_small_warm_start(model, tokenizer, prompt_state):
from transformers import Trainer, TrainingArguments
output_dir = Path(WARM_START_OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
examples = _build_warm_start_examples(task_ids=list(ACTIVE_TASK_IDS), memory_context=prompt_state.memory_context, memory=prompt_state.memory, feedback_memory=prompt_state.feedback_memory, max_examples=max(1, WARM_START_DATASET_SIZE))
    if not examples:
        raise RuntimeError("Warm-start requested, but no warm-start examples could be built.")
preview = [{"task_id": r["task_id"], "variant_seed": r["variant_seed"], "text_preview": str(r["text"])[:240]} for r in examples[:5]]
(output_dir / "dataset_preview.json").write_text(json.dumps(preview, indent=2), encoding="utf-8")
dataset = WarmStartDataset([r["text"] for r in examples], tokenizer)
args = TrainingArguments(
output_dir=str(output_dir),
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
learning_rate=WARM_START_LR,
max_steps=max(1, WARM_START_STEPS),
num_train_epochs=1,
logging_steps=1,
save_strategy="no",
remove_unused_columns=False,
bf16=False,
fp16=torch.cuda.is_available(),
report_to="wandb" if wandb_enabled else "none",
)
trainer = Trainer(model=model, args=args, train_dataset=dataset)
trainer.train()
final_dir = output_dir / "final"
trainer.save_model(str(final_dir))
tokenizer.save_pretrained(str(final_dir))
summary = {"enabled": True, "steps": max(1, WARM_START_STEPS), "learning_rate": WARM_START_LR, "dataset_size": len(examples), "output_dir": str(output_dir), "saved_model_path": str(final_dir), "task_ids": list(ACTIVE_TASK_IDS)}
(output_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
logger.info("Warm-start complete: steps=%d dataset=%d saved=%s", summary["steps"], summary["dataset_size"], final_dir)
return summary
def train():
logger.info("=" * 60)
logger.info("OpenEnv GRPO Training")
logger.info("Model: %s", MODEL_NAME)
logger.info("Steps: %d", TRAIN_STEPS)
logger.info("G: %d rollouts/prompt", NUM_GENERATIONS)
logger.info("LR: %g", LR)
logger.info("KL coef: %g", KL_COEF)
logger.info("LoRA r: %d", LORA_R)
logger.info("LLM panel: %s", USE_LLM_PANEL)
logger.info("Curriculum: %s", USE_CURRICULUM)
logger.info("Sampling: temperature=%.2f top_p=%.2f", GEN_TEMPERATURE, GEN_TOP_P)
logger.info("Episode: MODEL_STEPS_LIMIT=%d MAX_NEW_TOKENS=%d", MODEL_STEPS_LIMIT, MAX_NEW_TOKENS)
logger.info("EvalMinDif: %s", os.getenv("EVAL_MIN_DIFFICULTY", "0.0"))
logger.info("Warm start: %s", WARM_START_STEPS if WARM_START_STEPS > 0 else "disabled")
logger.info("Reward schedule: %s", REWARD_SCHEDULE_MODE if USE_SENTINEL else "n/a")
logger.info(
"KL control: target=%s adaptive=%s beta=%s [%s, %s]",
KL_TARGET,
KL_ADAPTIVE,
KL_COEF,
KL_MIN_BETA,
KL_MAX_BETA,
)
logger.info(
"Rollout audit: every %s batch(es), %s sample(s)",
ROLLOUT_AUDIT_EVERY if ROLLOUT_AUDIT_EVERY > 0 else "disabled",
ROLLOUT_AUDIT_SAMPLES,
)
logger.info("Output: %s", OUTPUT_DIR)
logger.info("=" * 60)
# Load model
model, tokenizer = load_model_and_tokenizer()
# Load curriculum and agent memory
from training.curriculum import get_curriculum
from training.memory import (
load_agent_memory, build_memory_context, maybe_consolidate_memory,
record_episode as mem_record_episode, save_agent_memory,
memory_summary as summarize_agent_memory, new_agent_memory,
)
from sentinel.feedback import (
load_feedback_memory,
empty_feedback_memory,
record_episode_feedback,
save_feedback_memory,
)
from sentinel.rewards import reset_reward_weights, scheduled_reward_weights, set_reward_weights
curriculum = get_curriculum(active_task_ids=ACTIVE_TASK_IDS) if USE_CURRICULUM else None
memory = load_agent_memory() if USE_AGENT_MEMORY else new_agent_memory()
feedback_memory = (
load_feedback_memory(SENTINEL_FEEDBACK_MEMORY_PATH)
if USE_FEEDBACK_MEMORY
else empty_feedback_memory()
)
memory_ctx = build_memory_context(memory) if USE_AGENT_MEMORY else ""
prompt_state = _AdaptivePromptStateBase(
task_ids=list(ACTIVE_TASK_IDS),
sentinel_task_ids=list(SENTINEL_TASK_IDS),
curriculum=curriculum,
memory=memory,
feedback_memory=feedback_memory,
memory_context=memory_ctx,
memory_enabled=USE_AGENT_MEMORY,
max_seeds=5,
use_sentinel=USE_SENTINEL,
use_feedback_memory=USE_FEEDBACK_MEMORY,
use_llm_panel=USE_LLM_PANEL,
groq_api_key=GROQ_API_KEY,
sentinel_adversarial_path=SENTINEL_ADVERSARIAL_PATH,
sentinel_feedback_memory_path=SENTINEL_FEEDBACK_MEMORY_PATH,
use_sentinel_adversarial=USE_SENTINEL_ADVERSARIAL,
)
if USE_SENTINEL and USE_SENTINEL_ADVERSARIAL:
prompt_state.refresh_adversarial_cases()
train_dataset = AdaptivePromptDataset(
state=prompt_state,
total_samples=PROMPT_DATASET_SIZE,
)
training_monitor = TrainingMonitor(TRAIN_MONITOR_DIR)
training_monitor.write_stack_versions(collect_training_stack_versions())
rollout_auditor = RolloutAuditSampler(
output_dir=ROLLOUT_AUDIT_DIR,
every=ROLLOUT_AUDIT_EVERY,
sample_limit=ROLLOUT_AUDIT_SAMPLES,
)
warm_start_summary: Optional[Dict[str, Any]] = None
warm_start_path = os.path.join(WARM_START_OUTPUT_DIR, "final")
if WARM_START_STEPS > 0 and os.path.isdir(warm_start_path):
logger.info("Warm-start checkpoint found at %s — SKIPPING (saves ~20 min)", warm_start_path)
# Reload the warm-start LoRA weights
try:
from peft import PeftModel
if not hasattr(model, "peft_config"):
model = PeftModel.from_pretrained(model, warm_start_path)
# Coerce LoRA adapter dtype to fp16 to match the bnb-4bit base
# compute dtype. bnb-4bit base weights are unaffected by .to();
# only the (small) LoRA adapters get cast. Prevents the
# "self and mat2 must have the same dtype" crash inside Unsloth's
# fast_lora kernels (which derive target dtype from X.dtype = fp16).
if torch.cuda.is_available():
for name, param in model.named_parameters():
if "lora_" in name and param.dtype != torch.float16:
param.data = param.data.to(torch.float16)
logger.info("Loaded warm-start LoRA from %s", warm_start_path)
except Exception as exc:
logger.warning("Could not reload warm-start LoRA: %s (continuing with base model)", exc)
warm_start_summary = {"saved_model_path": warm_start_path, "skipped": True}
elif WARM_START_STEPS > 0:
warm_start_summary = _run_small_warm_start(model, tokenizer, prompt_state)
if WARM_START_ONLY:
return warm_start_summary["saved_model_path"]
# GRPO config
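    # Note: TRL's GRPOTrainer expects the effective batch size to be divisible by
    # num_generations (the exact constraint varies by TRL version). With
    # per_device_train_batch_size = NUM_GENERATIONS and gradient_accumulation_steps = 1,
    # each optimizer step on a single GPU covers exactly one prompt's group of G rollouts.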
from trl import GRPOConfig, GRPOTrainer
grpo_config = GRPOConfig(
output_dir = OUTPUT_DIR,
num_train_epochs = 1,
per_device_train_batch_size = NUM_GENERATIONS,
gradient_accumulation_steps = 1,
num_generations = NUM_GENERATIONS,
max_completion_length = MAX_NEW_TOKENS,
learning_rate = LR,
beta = KL_COEF,
temperature = GEN_TEMPERATURE,
top_p = GEN_TOP_P,
logging_steps = 1,
save_steps = 25,
save_total_limit = 4,
dataloader_num_workers = 0,
bf16 = False,
fp16 = torch.cuda.is_available(),
report_to = "wandb" if wandb_enabled else "none",
max_steps = TRAIN_STEPS,
)
# Wrap reward fn to inject curriculum-selected task_ids and seeds
def reward_fn_with_curriculum(prompts, completions, **kwargs):
# Extract task_id and variant_seed from dataset columns if available
t_ids = kwargs.get("task_id", [ACTIVE_TASK_IDS[0]] * len(prompts))
v_seeds = kwargs.get("variant_seed", [0] * len(prompts))
adv_cases = kwargs.get("adversarial_case", [""] * len(prompts))
curriculum_snapshot = curriculum.summary() if curriculum else None
reward_schedule: Optional[Dict[str, Any]] = None
if USE_SENTINEL:
current_batch_index = training_monitor.batch_index + 1
progress = min(1.0, current_batch_index / max(1, TRAIN_STEPS))
reward_schedule = scheduled_reward_weights(
progress=progress,
mode=REWARD_SCHEDULE_MODE,
)
set_reward_weights(reward_schedule["weights"])
rewards, histories = grpo_reward_fn(
prompts = prompts,
completions = completions,
task_id = t_ids,
variant_seed = v_seeds,
adversarial_case = adv_cases,
return_histories = True,
**{k: v for k, v in kwargs.items() if k not in ("task_id", "variant_seed", "adversarial_case")},
)
for i, r in enumerate(rewards):
t_id = t_ids[i] if i < len(t_ids) else ACTIVE_TASK_IDS[0]
seed = v_seeds[i] if i < len(v_seeds) else 0
history = histories[i] if i < len(histories) else []
prompt_state.update_after_episode(
task_id=t_id,
variant_seed=seed,
reward=r,
history=history,
mem_record_episode=mem_record_episode,
record_episode_feedback=record_episode_feedback,
save_agent_memory=save_agent_memory,
save_feedback_memory=save_feedback_memory,
maybe_consolidate_memory=maybe_consolidate_memory,
)
nonlocal memory
memory = prompt_state.memory
nonlocal feedback_memory
feedback_memory = prompt_state.feedback_memory
monitor_summary = training_monitor.log_batch(
sentinel_task_ids=list(SENTINEL_TASK_IDS),
rewards=rewards,
histories=histories,
task_ids=[str(task_id) for task_id in t_ids],
variant_seeds=[int(seed) for seed in v_seeds],
completions=[str(completion) for completion in completions],
prompts=[str(prompt) for prompt in prompts],
adversarial_cases=[str(case) for case in adv_cases],
curriculum_summary=curriculum_snapshot,
prompt_refreshes=prompt_state.prompt_refreshes,
reward_schedule=reward_schedule,
memory_summary={
"agent_memory_enabled": USE_AGENT_MEMORY,
"feedback_memory_enabled": USE_FEEDBACK_MEMORY,
**summarize_agent_memory(memory),
},
)
audit_path = rollout_auditor.record_batch(
sentinel_task_ids=list(SENTINEL_TASK_IDS),
active_task_ids=list(ACTIVE_TASK_IDS),
batch_index=training_monitor.batch_index,
prompts=[str(prompt) for prompt in prompts],
completions=[str(completion) for completion in completions],
rewards=rewards,
histories=histories,
task_ids=[str(task_id) for task_id in t_ids],
variant_seeds=[int(seed) for seed in v_seeds],
monitor_summary=monitor_summary,
reward_schedule=reward_schedule,
)
if curriculum and curriculum.should_use_adversarial():
logger.info(
"Adversarial trigger: tier=%d mean=%.2f",
curriculum.tier_index,
curriculum.summary()["recent_mean_score"],
)
try:
weak_spots = curriculum.weak_spots(top_n=2)
if USE_SENTINEL and USE_SENTINEL_ADVERSARIAL:
from training.adversarial import (
generate_sentinel_adversarial_cases,
save_sentinel_adversarial_cases,
)
cases = generate_sentinel_adversarial_cases(weak_spots, n=4)
save_sentinel_adversarial_cases(cases, SENTINEL_ADVERSARIAL_PATH)
prompt_state.sentinel_adversarial_cases = cases
logger.info("Generated %d SENTINEL adversarial worker cases", len(cases))
elif GROQ_API_KEY:
from training.adversarial import AdversarialDesigner
designer = AdversarialDesigner(api_key=GROQ_API_KEY)
new_scenarios = designer.generate(weak_spots, n=3)
designer.save_generated("outputs/adversarial_scenarios.json")
logger.info("Generated %d adversarial scenarios", len(new_scenarios))
except Exception as e:
logger.debug("Adversarial generation failed: %s", e)
if wandb_enabled:
import wandb
wandb_payload = {
"monitor/reward_mean": monitor_summary["reward_mean"],
"monitor/avg_steps": monitor_summary["avg_steps"],
"monitor/running_reward_mean": monitor_summary["running_reward_mean"],
"monitor/best_reward_mean": monitor_summary["best_reward_mean"],
"monitor/unique_completion_ratio": monitor_summary.get("unique_completion_ratio", 0.0),
"monitor/decision_entropy": monitor_summary.get("decision_entropy", 0.0),
"monitor/decision_variety": monitor_summary.get("decision_variety", 0),
"monitor/zero_reward_fraction": monitor_summary.get("zero_reward_fraction", 0.0),
"monitor/trivially_solved_fraction": monitor_summary.get("trivially_solved_fraction", 0.0),
"monitor/productive_fraction": monitor_summary.get("productive_fraction", 0.0),
"monitor/effective_prompt_ratio": monitor_summary.get("effective_prompt_ratio", 0.0),
"monitor/frontier_hit_rate": monitor_summary.get("frontier_hit_rate", 0.0),
"monitor/task_diversity_ratio": monitor_summary.get("task_diversity_ratio", 0.0),
"monitor/zero_gradient_group_fraction": monitor_summary.get("zero_gradient_group_fraction", 0.0),
"monitor/adversarial_case_fraction": monitor_summary.get("adversarial_case_fraction", 0.0),
}
if monitor_summary.get("memory"):
wandb_payload["monitor/memory_total_episodes"] = monitor_summary["memory"].get("total_episodes", 0)
wandb_payload["monitor/memory_mistake_cards"] = monitor_summary["memory"].get("mistake_cards_stored", 0)
if USE_SENTINEL:
wandb_payload.update(
{
"monitor/detection_rate": monitor_summary.get("detection_rate", 0.0),
"monitor/false_positive_rate": monitor_summary.get("false_positive_rate", 0.0),
"monitor/risk_reduction_rate": monitor_summary.get("risk_reduction_rate", 0.0),
"monitor/twin_damage_reduction_rate": monitor_summary.get("twin_damage_reduction_rate", 0.0),
"monitor/twin_without_sentinel_damage_total": monitor_summary.get("twin_without_sentinel_damage_total", 0.0),
"monitor/twin_with_sentinel_damage_total": monitor_summary.get("twin_with_sentinel_damage_total", 0.0),
"monitor/worker_rehabilitation_rate": monitor_summary.get("worker_rehabilitation_rate", 0.0),
"monitor/coaching_quality": monitor_summary.get("coaching_quality", 0.0),
}
)
if reward_schedule:
wandb_payload.update(
{
"monitor/reward_schedule_progress": reward_schedule.get("progress", 0.0),
"monitor/reward_schedule_stage": reward_schedule.get("stage", "unknown"),
}
)
if audit_path:
wandb_payload["monitor/rollout_audit_saved"] = 1
wandb.log(wandb_payload)
return rewards
# Create trainer
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
args = grpo_config,
train_dataset = train_dataset,
reward_funcs = [reward_fn_with_curriculum],
)
stability_callback = GRPOStabilityCallback(
training_monitor=training_monitor,
initial_beta=KL_COEF,
target_kl=KL_TARGET,
adaptive=KL_ADAPTIVE,
low_factor=KL_LOW_FACTOR,
high_factor=KL_HIGH_FACTOR,
beta_up_mult=KL_BETA_UP_MULT,
beta_down_mult=KL_BETA_DOWN_MULT,
min_beta=KL_MIN_BETA,
max_beta=KL_MAX_BETA,
hard_stop_enabled=KL_HARD_STOP_ENABLED,
hard_stop_mult=KL_HARD_STOP_MULT,
)
trainer.add_callback(stability_callback)
stability_callback.bind_trainer(trainer)
# Train
logger.info("Starting training...")
start_time = time.time()
trainer.train()
elapsed = time.time() - start_time
logger.info("Training complete in %.1f minutes", elapsed / 60)
# Save final model
final_path = os.path.join(OUTPUT_DIR, "final")
trainer.save_model(final_path)
tokenizer.save_pretrained(final_path)
logger.info("Saved final model to %s", final_path)
# Save curriculum state
if curriculum:
logger.info("Curriculum summary: %s", curriculum.summary())
if USE_AGENT_MEMORY:
save_agent_memory(memory)
if USE_SENTINEL and USE_FEEDBACK_MEMORY:
save_feedback_memory(feedback_memory, SENTINEL_FEEDBACK_MEMORY_PATH)
if warm_start_summary:
logger.info("Warm-start summary: %s", warm_start_summary)
if USE_SENTINEL:
reset_reward_weights()
# Plot reward curve
_plot_reward_curve()
try:
from scripts.render_training_dashboard import render_dashboard
render_dashboard(
monitor_dir=TRAIN_MONITOR_DIR,
output_dir="outputs/reward_curves",
)
except Exception as exc:
logger.warning("Training dashboard render skipped: %s", exc)
# Push to Hub (if HF_TOKEN set)
hf_repo = os.getenv("HF_REPO")
if hf_repo and HF_TOKEN:
logger.info("Pushing to HuggingFace Hub: %s", hf_repo)
trainer.model.push_to_hub(hf_repo, token=HF_TOKEN)
tokenizer.push_to_hub(hf_repo, token=HF_TOKEN)
logger.info("Done! Update openenv.yaml model: %s", hf_repo)
if wandb_enabled:
import wandb
wandb.finish()
return final_path
# ---------------------------------------------------------------------------
# Reward curve plot
# ---------------------------------------------------------------------------
def _plot_reward_curve():
"""Plot reward/mean over steps from wandb run or log file."""
try:
import matplotlib.pyplot as plt
steps, rewards = [], []
monitor_path = Path(TRAIN_MONITOR_DIR) / "training_metrics.jsonl"
if monitor_path.exists():
with monitor_path.open("r", encoding="utf-8") as handle:
for line in handle:
line = line.strip()
if not line:
continue
try:
payload = json.loads(line)
except json.JSONDecodeError:
continue
steps.append(int(payload.get("batch_index", len(steps) + 1)))
rewards.append(float(payload.get("reward_mean", 0.0)))
else:
log_path = os.path.join(OUTPUT_DIR, "train.log")
if not os.path.exists(log_path):
return
with open(log_path, encoding="utf-8", errors="ignore") as f:
for line in f:
if "Batch rewards: mean=" in line:
try:
mean_str = line.split("mean=")[1].split(" ")[0]
steps.append(len(steps) + 1)
rewards.append(float(mean_str))
except Exception:
pass
if not steps:
return
plt.figure(figsize=(10, 5))
plt.plot(steps, rewards, linewidth=2, color="royalblue")
plt.xlabel("Training Step")
plt.ylabel("Mean Reward")
plt.title("GRPO Training Reward Curve")
plt.grid(True, alpha=0.3)
# Smoothed line
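        # Simple moving average: convolve with a uniform kernel; mode="valid" trims
        # the edges, so the smoothed series is shorter than the raw reward list.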
if len(rewards) > 10:
window = min(10, len(rewards) // 5)
smoothed = np.convolve(rewards, np.ones(window)/window, mode="valid")
smooth_steps = steps[:len(smoothed)]
plt.plot(smooth_steps, smoothed, linewidth=2, color="red",
linestyle="--", label=f"Smoothed (w={window})")
plt.legend()
plot_path = "outputs/reward_curves/training_curve.png"
plt.savefig(plot_path, dpi=120, bbox_inches="tight")
plt.close()
logger.info("Saved reward curve to %s", plot_path)
except ImportError:
logger.info("matplotlib not installed - skipping reward plot")
except Exception as e:
logger.warning("Could not plot reward curve: %s", e)
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="GRPO training for OpenEnv")
parser.add_argument("--steps", type=int, default=TRAIN_STEPS, help="Training steps")
parser.add_argument("--model", type=str, default=MODEL_NAME, help="Model name/path")
parser.add_argument("--lr", type=float, default=LR, help="Learning rate")
parser.add_argument("--output", type=str, default=OUTPUT_DIR, help="Output directory")
parser.add_argument("--resume", type=str, default=RESUME_FROM, help="Checkpoint to resume from")
parser.add_argument("--warm-start-steps", type=int, default=WARM_START_STEPS, help="Optional small SFT-style warm-start steps before GRPO")
parser.add_argument("--warm-start-only", action="store_true", help="Run only the warm-start stage and stop before GRPO")
parser.add_argument("--dry-run", action="store_true", help="Validate setup without training")
args = parser.parse_args()
# Override from CLI
TRAIN_STEPS = args.steps
MODEL_NAME = args.model
LR = args.lr
OUTPUT_DIR = args.output
RESUME_FROM = args.resume
WARM_START_STEPS = args.warm_start_steps
WARM_START_ONLY = args.warm_start_only or WARM_START_ONLY
if args.dry_run:
logger.info("DRY RUN: Validating environment and reward function...")
if USE_SENTINEL:
from sentinel.environment import SentinelEnv
env = SentinelEnv()
for task_id in SENTINEL_TASK_IDS:
obs = env.reset(task_id=task_id, variant_seed=0)
grade = env.grade()
score = float(grade.score) if hasattr(grade, "score") else float(grade.get("score", 0.0))
logger.info(" task=%s initial_grade=%.3f", task_id, score)
else:
from src.environment import IncidentResponseEnv
env = IncidentResponseEnv()
for task_id in TASK_IDS:
obs = env.reset(task_id=task_id, variant_seed=0)
grade = env.grade()
score = float(grade.score) if hasattr(grade, "score") else float(grade.get("score", 0.0))
logger.info(" task=%s initial_grade=%.3f", task_id, score)
if WARM_START_STEPS > 0:
from training.memory import load_agent_memory
from sentinel.feedback import load_feedback_memory
warm_start_records = _build_warm_start_examples(
task_ids=list(ACTIVE_TASK_IDS),
memory=load_agent_memory(),
feedback_memory=load_feedback_memory(SENTINEL_FEEDBACK_MEMORY_PATH),
max_examples=max(1, min(WARM_START_DATASET_SIZE, 8)),
)
logger.info(" warm_start_examples=%d", len(warm_start_records))
logger.info("DRY RUN PASSED. Environment is working.")
sys.exit(0)
final_path = train()
logger.info("Training finished. Final model: %s", final_path)
logger.info("Next steps:")
logger.info(" 1. python validate.py")
logger.info(" 2. Update openenv.yaml: model: <HF_REPO>")
logger.info(" 3. Submit!")