| """ | |
| train.py - GRPO Fine-tuning for OpenEnv (IRT / SENTINEL) | |
| ============================================================== | |
| Runnable training script. Uses TRL GRPOTrainer + Unsloth (optional) + curriculum. | |
HOW TO RUN:
    # Default run (A100 / H100; Unsloth is enabled by default):
    python train.py
    # Without Unsloth (T4 / A10G):
    USE_UNSLOTH=0 python train.py
    # Override model and steps:
    MODEL_NAME=unsloth/Qwen3-30B-A3B-bnb-4bit TRAIN_STEPS=200 python train.py
    # Resume from checkpoint:
    RESUME_FROM=outputs/checkpoints/checkpoint-100 python train.py
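    # Validate environments and reward plumbing without training:
    python train.py --dry-run
    # Run only the warm-start stage, then stop before GRPO:
    WARM_START_STEPS=20 WARM_START_ONLY=1 python train.py
    # Train SENTINEL oversight tasks instead of IRT incident response:
    USE_SENTINEL=1 python train.py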
ENV VARS:
    MODEL_NAME         HuggingFace model ID (default: unsloth/Qwen3-30B-A3B-bnb-4bit)
    HF_TOKEN           HuggingFace token (for gated models)
    GROQ_API_KEY       Groq API key (for LLM judge panel, optional)
    WANDB_PROJECT      W&B project name (optional, set to "" to disable)
    TRAIN_STEPS        Number of GRPO training steps (default: 100)
    NUM_GENERATIONS    G rollouts per prompt (default: 2)
    USE_UNSLOTH        Set to "0" to disable Unsloth (default: "1"; requires unsloth)
    RESUME_FROM        Path to checkpoint to resume from
    OUTPUT_DIR         Where to save checkpoints (default: outputs/checkpoints)
    LR                 Learning rate (default: 5e-6)
    KL_COEF            KL penalty coefficient (default: 0.04)
    LORA_R             LoRA rank (default: 16)
    TRAIN_MONITOR_DIR  Structured metrics output dir (default: outputs/monitoring)
    WARM_START_STEPS   Optional small warm-start steps before GRPO (default: 0)
    WARM_START_LR      Learning rate for warm-start stage (default: 2e-5)
    WARM_START_ONLY    Set to "1" to stop after warm-start
Further flags (USE_SENTINEL, USE_CURRICULUM, GEN_TEMPERATURE, KL_* control,
rollout auditing, ...) are read via os.getenv() near the top of this file.
"""
from __future__ import annotations

import json
import logging
import math
import os
import platform
import sys
import time
from dataclasses import dataclass, field
from importlib import metadata as importlib_metadata
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

try:
    import torch
    from torch.utils.data import Dataset as TorchDataset
except ModuleNotFoundError:
    torch = None

    class TorchDataset:  # type: ignore[no-redef]
        """Fallback base for tests that import train.py without training deps."""
        pass
# bnb-4bit pre-quantized models have compute_dtype=float16 baked in, so LoRA
# adapter parameters and their gradients are FP16. PyTorch 2.10 added a strict
# check in GradScaler._unscale_grads_ that rejects FP16 gradients (intended for
# full-precision training where FP16 grads indicate a misconfiguration). For
# bnb-4bit + LoRA this check is a false positive, so patch it out.
if torch is not None:
    import torch.amp.grad_scaler as _gs

    _orig_unscale_grads = _gs.GradScaler._unscale_grads_

    def _allow_fp16_unscale(self, optimizer, inv_scale, found_inf, allow_fp16):
        return _orig_unscale_grads(self, optimizer, inv_scale, found_inf, True)

    _gs.GradScaler._unscale_grads_ = _allow_fp16_unscale
# Re-export from extracted modules for backward compatibility
from training.metrics import (
    safe_ratio as _safe_ratio,
    _increment_counter,
    _normalize_completion_text,
    _extract_completion_choice,
    _shannon_entropy_from_labels,
    summarize_sentinel_history as _summarize_sentinel_history,
    aggregate_batch_metrics as _aggregate_batch_metrics,
    completion_diversity_metrics as _completion_diversity_metrics,
    productive_signal_metrics as _productive_signal_metrics,
    training_coverage_metrics as _training_coverage_metrics,
    zero_gradient_group_metrics as _zero_gradient_group_metrics,
    frontier_scenario_keys as _frontier_scenario_keys,
    set_thresholds as _set_metric_thresholds,
)
from training.monitoring import (
    TrainingMonitor,
    GRPOStabilityCallback,
    RolloutAuditSampler,
    _truncate_text,
    _audit_priority,
)
from training.prompts import (
    build_system_prompt,
    scenario_to_prompt,
    sentinel_obs_to_prompt,
    sentinel_adversarial_case_to_prompt,
    build_prompt_record as _build_prompt_record_impl,
    memory_context_for_task as _memory_context_for_task,
    load_or_create_sentinel_adversarial_cases as _load_or_create_sentinel_adversarial_cases_impl,
    AdaptivePromptState as _AdaptivePromptStateBase,
    AdaptivePromptDataset,
    WarmStartDataset,
    build_grpo_dataset as _build_grpo_dataset_impl,
)
from training.episodes import (
    parse_action as _parse_action,
    greedy_fallback_action as _greedy_fallback_action,
    greedy_fallback_sentinel_decision as _greedy_fallback_sentinel_decision,
    run_episode_with_completion as _run_episode_with_completion_impl,
    _run_irt_episode,
    _run_sentinel_episode,
    run_sentinel_adversarial_case as _run_sentinel_adversarial_case,
    grpo_reward_fn as _grpo_reward_fn_impl,
    trajectory_summary_from_history as _trajectory_summary_from_history,
    mistakes_from_history as _mistakes_from_history,
    mistake_cards_from_history as _mistake_cards_from_history,
    successes_from_history as _successes_from_history,
)
from training.curriculum import CURRICULUM_FRONTIER_FAILURE_RATE
MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen3-30B-A3B-bnb-4bit")
HF_TOKEN = os.getenv("HF_TOKEN", "")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "").strip()
TRAIN_STEPS = int(os.getenv("TRAIN_STEPS", "100"))
NUM_GENERATIONS = int(os.getenv("NUM_GENERATIONS", "2"))
USE_UNSLOTH = os.getenv("USE_UNSLOTH", "1") == "1"
RESUME_FROM = os.getenv("RESUME_FROM", "")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/checkpoints")
LR = float(os.getenv("LR", "5e-6"))
KL_COEF = float(os.getenv("KL_COEF", "0.04"))
LORA_R = int(os.getenv("LORA_R", "16"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
PROMPT_DATASET_SIZE = int(os.getenv("PROMPT_DATASET_SIZE", str(max(512, TRAIN_STEPS * 8))))
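# Sizing note: the prompt dataset defaults to max(512, TRAIN_STEPS * 8) records
# so the dataloader cannot run dry before max_steps. With
# per_device_train_batch_size == num_generations (see GRPOConfig in train()),
# each optimizer step typically consumes only a handful of unique prompts, so
# this is a deliberately generous upper bound.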
USE_LLM_PANEL = bool(GROQ_API_KEY)  # auto-enable if key available
USE_CURRICULUM = os.getenv("USE_CURRICULUM", "1") == "1"
GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.7"))
GEN_TOP_P = float(os.getenv("GEN_TOP_P", "1.0"))
USE_SENTINEL = os.getenv("USE_SENTINEL", "0") == "1"  # Enable SENTINEL training
USE_AGENT_MEMORY = os.getenv("USE_AGENT_MEMORY", "1") == "1"
USE_FEEDBACK_MEMORY = os.getenv("USE_FEEDBACK_MEMORY", "1") == "1" and USE_AGENT_MEMORY
USE_SENTINEL_ADVERSARIAL = os.getenv("USE_SENTINEL_ADVERSARIAL", "1") == "1"
SENTINEL_ADVERSARIAL_PATH = os.getenv(
    "SENTINEL_ADVERSARIAL_PATH",
    "outputs/sentinel_adversarial_cases.json",
)
SENTINEL_FEEDBACK_MEMORY_PATH = os.getenv(
    "SENTINEL_FEEDBACK_MEMORY_PATH",
    "outputs/sentinel_feedback_memory.json",
)
TRAIN_MONITOR_DIR = os.getenv("TRAIN_MONITOR_DIR", "outputs/monitoring")
WARM_START_STEPS = int(os.getenv("WARM_START_STEPS", "0"))
WARM_START_LR = float(os.getenv("WARM_START_LR", "2e-5"))
WARM_START_DATASET_SIZE = int(os.getenv("WARM_START_DATASET_SIZE", "24"))
WARM_START_OUTPUT_DIR = os.getenv("WARM_START_OUTPUT_DIR", "outputs/warm_start")
WARM_START_ONLY = os.getenv("WARM_START_ONLY", "0") == "1"
ROLLOUT_AUDIT_DIR = os.getenv("ROLLOUT_AUDIT_DIR", os.path.join(TRAIN_MONITOR_DIR, "rollout_audits"))
ROLLOUT_AUDIT_EVERY = int(os.getenv("ROLLOUT_AUDIT_EVERY", "10"))
ROLLOUT_AUDIT_SAMPLES = int(os.getenv("ROLLOUT_AUDIT_SAMPLES", "2"))
REWARD_SCHEDULE_MODE = os.getenv("REWARD_SCHEDULE_MODE", os.getenv("REWARD_PROFILE", "dynamic"))
MODEL_STEPS_LIMIT = int(os.getenv("MODEL_STEPS_LIMIT", "1"))
KL_TARGET = float(os.getenv("KL_TARGET", "0.08"))
KL_ADAPTIVE = os.getenv("KL_ADAPTIVE", "1") == "1"
KL_LOW_FACTOR = float(os.getenv("KL_LOW_FACTOR", "1.5"))
KL_HIGH_FACTOR = float(os.getenv("KL_HIGH_FACTOR", "1.5"))
KL_BETA_UP_MULT = float(os.getenv("KL_BETA_UP_MULT", "2.0"))
KL_BETA_DOWN_MULT = float(os.getenv("KL_BETA_DOWN_MULT", "0.5"))
KL_MIN_BETA = float(os.getenv("KL_MIN_BETA", "0.005"))
KL_MAX_BETA = float(os.getenv("KL_MAX_BETA", "0.5"))
KL_HARD_STOP_ENABLED = os.getenv("KL_HARD_STOP_ENABLED", "0") == "1"
KL_HARD_STOP_MULT = float(os.getenv("KL_HARD_STOP_MULT", "3.0"))
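# Adaptive KL control: the exact update rule lives in
# training.monitoring.GRPOStabilityCallback; roughly (a sketch, inferred from
# the parameter names above, not a verbatim copy of that code):
#   if mean_kl > KL_TARGET * KL_HIGH_FACTOR:   beta *= KL_BETA_UP_MULT
#   elif mean_kl < KL_TARGET / KL_LOW_FACTOR:  beta *= KL_BETA_DOWN_MULT
#   beta = clip(beta, KL_MIN_BETA, KL_MAX_BETA)
#   optional hard stop when mean_kl > KL_TARGET * KL_HARD_STOP_MULT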
ZERO_SIGNAL_REWARD_THRESHOLD = float(os.getenv("ZERO_SIGNAL_REWARD_THRESHOLD", "0.05"))
TRIVIAL_REWARD_THRESHOLD = float(os.getenv("TRIVIAL_REWARD_THRESHOLD", "0.95"))

TASK_IDS = [
    "severity_classification",
    "root_cause_analysis",
    "full_incident_management",
]
SENTINEL_TASK_IDS = [
    "basic_oversight",
    "fleet_monitoring_conflict",
    "adversarial_worker",
    "multi_crisis_command",
]
def _parse_task_filter(env_name: str, allowed: List[str]) -> List[str]:
    raw = os.getenv(env_name, "").strip()
    if not raw:
        return list(allowed)
    selected = [part.strip() for part in raw.split(",") if part.strip()]
    unknown = [task_id for task_id in selected if task_id not in allowed]
    if unknown:
        raise ValueError(
            f"{env_name} contains unknown task id(s): {unknown}. "
            f"Allowed: {allowed}"
        )
    return selected or list(allowed)
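# Example: restrict training to a comma-separated subset of the task ids listed
# above (unknown ids raise ValueError at startup):
#   IRT_TASKS="severity_classification,root_cause_analysis" python train.py
#   SENTINEL_TASKS="adversarial_worker" USE_SENTINEL=1 python train.py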
TASK_IDS = _parse_task_filter("IRT_TASKS", TASK_IDS)
SENTINEL_TASK_IDS = _parse_task_filter("SENTINEL_TASKS", SENTINEL_TASK_IDS)
# Select task set based on USE_SENTINEL flag
ACTIVE_TASK_IDS = SENTINEL_TASK_IDS if USE_SENTINEL else TASK_IDS

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs("outputs/reward_curves", exist_ok=True)
os.makedirs(TRAIN_MONITOR_DIR, exist_ok=True)
logging.basicConfig(
    level = logging.INFO,
    format = "%(asctime)s %(levelname)s %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(os.path.join(OUTPUT_DIR, "train.log")),
    ],
)
logger = logging.getLogger(__name__)
def _package_version(name: str) -> str:
    try:
        return importlib_metadata.version(name)
    except importlib_metadata.PackageNotFoundError:
        return "missing"


def collect_training_stack_versions() -> Dict[str, Any]:
    cuda_available = bool(torch is not None and torch.cuda.is_available())
    return {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "model_name": MODEL_NAME,
        "use_unsloth": USE_UNSLOTH,
        "cuda_available": cuda_available,
        "bf16_available": bool(cuda_available and torch.cuda.is_bf16_supported()),
        "train_steps": TRAIN_STEPS,
        "warm_start_steps": WARM_START_STEPS,
        "reward_schedule_mode": REWARD_SCHEDULE_MODE,
        "memory": {
            "agent_memory_enabled": USE_AGENT_MEMORY,
            "feedback_memory_enabled": USE_FEEDBACK_MEMORY,
        },
        "productive_signal_thresholds": {
            "zero_signal_reward_threshold": ZERO_SIGNAL_REWARD_THRESHOLD,
            "trivial_reward_threshold": TRIVIAL_REWARD_THRESHOLD,
        },
        "adaptive_curriculum": {
            "frontier_failure_rate": CURRICULUM_FRONTIER_FAILURE_RATE,
        },
        "kl_control": {
            "initial_beta": KL_COEF,
            "target": KL_TARGET,
            "adaptive": KL_ADAPTIVE,
            "low_factor": KL_LOW_FACTOR,
            "high_factor": KL_HIGH_FACTOR,
            "beta_up_mult": KL_BETA_UP_MULT,
            "beta_down_mult": KL_BETA_DOWN_MULT,
            "min_beta": KL_MIN_BETA,
            "max_beta": KL_MAX_BETA,
            "hard_stop_enabled": KL_HARD_STOP_ENABLED,
            "hard_stop_mult": KL_HARD_STOP_MULT,
        },
        "packages": {
            "torch": getattr(torch, "__version__", "missing") if torch is not None else "missing",
            "bitsandbytes": _package_version("bitsandbytes"),
            "transformers": _package_version("transformers"),
            "peft": _package_version("peft"),
            "trl": _package_version("trl"),
            "datasets": _package_version("datasets"),
            "matplotlib": _package_version("matplotlib"),
            "wandb": _package_version("wandb"),
            "openenv-core": _package_version("openenv-core"),
            "unsloth": _package_version("unsloth"),
        },
    }
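# The snapshot above is written once per run via
# TrainingMonitor.write_stack_versions() inside train(), so each checkpoint can
# be traced back to the exact library stack that produced it.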
# ---------------------------------------------------------------------------
# W&B setup (optional)
# ---------------------------------------------------------------------------
wandb_enabled = bool(WANDB_PROJECT) and WANDB_PROJECT.lower() not in {"0", "false", "none", "disabled"}
if wandb_enabled:
    try:
        import wandb

        wandb.init(project=WANDB_PROJECT, config={
            "model": MODEL_NAME,
            "train_steps": TRAIN_STEPS,
            "num_generations": NUM_GENERATIONS,
            "lr": LR,
            "kl_coef": KL_COEF,
            "lora_r": LORA_R,
            "use_llm_panel": USE_LLM_PANEL,
        })
        logger.info("W&B enabled: project=%s", WANDB_PROJECT)
    except ImportError:
        wandb_enabled = False
        logger.warning("wandb not installed -- logging disabled")
    except Exception as exc:
        wandb_enabled = False
        logger.warning("wandb init skipped: %s", exc)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model_and_tokenizer():
    """Load model + tokenizer. Uses Unsloth if USE_UNSLOTH=1, else standard HF.

    When Unsloth is enabled:
    - 12x faster MoE training via Triton kernels (torch._grouped_mm)
    - 3x faster inference via fused attention (FastLanguageModel.for_inference)
    - >35% less VRAM via 4-bit quantization + gradient checkpointing
    """
    if torch is None:
        raise ImportError(
            "Training requires torch. Install the training extras before running train.py."
        )
    if USE_UNSLOTH:
        logger.info("Loading model with Unsloth: %s", MODEL_NAME)
        from unsloth import FastLanguageModel

        # IMPORTANT: keep dtype=float16 for bnb-4bit. The pre-quantized
        # unsloth/*-bnb-4bit models have compute_dtype=float16 baked into their
        # quantization config. Unsloth's fast_lora kernels use X.dtype as the
        # target dtype for LoRA ops; if X is BF16 but bnb dequant output is FP16
        # the addmm_ inside matmul_lora crashes with a "same dtype" error.
        _unsloth_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name = MODEL_NAME,
            max_seq_length = 4096,
            dtype = _unsloth_dtype,
            load_in_4bit = True,
            token = HF_TOKEN or None,
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r = LORA_R,
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                              "gate_proj", "up_proj", "down_proj"],
            lora_alpha = LORA_R,
            lora_dropout = 0,
            bias = "none",
            use_gradient_checkpointing = "unsloth",
            random_state = 42,
        )
        # Enable Unsloth fast inference (2-3x speedup for generation).
        # GRPOTrainer internally handles train/eval mode toggling, but
        # setting this up front ensures optimized attention kernels are
        # compiled and ready for the first rollout batch.
        try:
            FastLanguageModel.for_inference(model)
            logger.info("Unsloth fast inference enabled (fused attention kernels)")
        except Exception as exc:
            logger.warning("Unsloth fast inference setup skipped: %s", exc)
    else:
        logger.info("Loading model with standard HF: %s", MODEL_NAME)
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from peft import LoraConfig, get_peft_model

        cuda_available = torch.cuda.is_available()
        bf16_available = cuda_available and torch.cuda.is_bf16_supported()
        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch.bfloat16 if bf16_available else (torch.float16 if cuda_available else torch.float32),
            "device_map": "auto" if cuda_available else None,
        }
        if "4bit" in MODEL_NAME:  # matches both "4bit" and "bnb-4bit" model ids
            from transformers import BitsAndBytesConfig
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit = True,
                bnb_4bit_use_double_quant = True,
                bnb_4bit_quant_type = "nf4",
                bnb_4bit_compute_dtype = torch.bfloat16,
            )
            load_kwargs.pop("torch_dtype", None)
        if HF_TOKEN:
            load_kwargs["token"] = HF_TOKEN
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN or None)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
        lora_config = LoraConfig(
            r = LORA_R,
            lora_alpha = LORA_R,
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout = 0.05,
            bias = "none",
            task_type = "CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    if RESUME_FROM:
        logger.info("Resuming from checkpoint: %s", RESUME_FROM)
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, RESUME_FROM)
    return model, tokenizer
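# A minimal interactive smoke test (a sketch, assuming a CUDA box and the
# default 4-bit model; not executed by this script):
#   model, tokenizer = load_model_and_tokenizer()
#   ids = tokenizer("ping", return_tensors="pt").input_ids.to(model.device)
#   print(tokenizer.decode(model.generate(ids, max_new_tokens=8)[0]))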
# ---------------------------------------------------------------------------
# Dataset construction / backward-compatible re-exports for tests
# ---------------------------------------------------------------------------
AdaptivePromptState = _AdaptivePromptStateBase
build_prompt_record = _build_prompt_record_impl
build_grpo_dataset = _build_grpo_dataset_impl
_load_or_create_sentinel_adversarial_cases = _load_or_create_sentinel_adversarial_cases_impl
_aggregate_batch_metrics = _aggregate_batch_metrics  # explicit re-export, kept for test imports


def _sentinel_history_entry(decision, result):
    """Lazy wrapper around training.episodes._sentinel_history_entry."""
    from training.episodes import _sentinel_history_entry as _she
    return _she(decision, result)
# ---------------------------------------------------------------------------
# Thin wrappers delegating to extracted modules
# ---------------------------------------------------------------------------
# Prompt construction
def _build_system_prompt(task_id, memory_context=""):
    return build_system_prompt(task_id, SENTINEL_TASK_IDS, memory_context)


def _scenario_to_prompt(scenario, task_id, memory_context=""):
    return scenario_to_prompt(scenario, task_id, SENTINEL_TASK_IDS, memory_context)


def _sentinel_obs_to_prompt(obs, task_id, memory_context=""):
    return sentinel_obs_to_prompt(obs, task_id, SENTINEL_TASK_IDS, memory_context)


# Episode execution
def run_episode_with_completion(completion_text, task_id, variant_seed):
    return _run_episode_with_completion_impl(
        completion_text, task_id, variant_seed, SENTINEL_TASK_IDS,
        model_steps_limit=MODEL_STEPS_LIMIT,
    )
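# Note on the TRL contract (as of recent TRL versions): GRPOTrainer calls each
# reward function as fn(prompts, completions, **kwargs), forwarding extra
# dataset columns (here: task_id, variant_seed, adversarial_case) as
# per-sample keyword lists, and expects one float per completion in return.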
def grpo_reward_fn(prompts, completions, **kwargs):
    return _grpo_reward_fn_impl(
        prompts, completions,
        sentinel_task_ids=SENTINEL_TASK_IDS,
        active_task_ids=list(ACTIVE_TASK_IDS),
        use_llm_panel=USE_LLM_PANEL,
        groq_api_key=GROQ_API_KEY,
        wandb_enabled=wandb_enabled,
        model_steps_limit=MODEL_STEPS_LIMIT,
        **kwargs,
    )
# Warm-start helpers
def _warm_start_counterfactual(obs, decision):
    constitution = obs.constitutional_assessment or {}
    violations = list((constitution.get("constitutional_violations") or {}).keys())
    if decision.get("decision") == "APPROVE":
        return "If approved, the proposal stays within the current safe operating envelope."
    if violations:
        return "If approved unchanged, this proposal would likely violate " + ", ".join(violations) + " and increase operational risk."
    return "If approved unchanged, this proposal could bypass oversight without sufficient justification."


def _warm_start_sentinel_decision_for_observation(obs):
    decision = dict(_greedy_fallback_sentinel_decision(obs, []))
    violations = sorted((obs.constitutional_assessment.get("constitutional_violations") or {}).keys())
    decision.setdefault("worker_message", "Approved. Execute carefully and preserve evidence.")
    for key in ["required_evidence", "redirect_action", "reassign_to", "flag_severity", "suggested_action_type", "suggested_target"]:
        decision.setdefault(key, [] if key == "required_evidence" else None)
    decision.setdefault("suggested_parameters", {})
    decision["constitutional_violations"] = violations
    decision["counterfactual"] = _warm_start_counterfactual(obs, decision)
    return decision
def _build_warm_start_examples(task_ids, memory_context="", memory=None, feedback_memory=None, max_examples=None, max_seeds=3):
    if max_examples is None:
        max_examples = WARM_START_DATASET_SIZE
    records = []
    for task_id in task_ids:
        for seed in range(max_seeds):
            task_memory = _memory_context_for_task(memory, feedback_memory, task_id, memory_context)
            if task_id in SENTINEL_TASK_IDS:
                from sentinel.environment import SentinelEnv
                env = SentinelEnv()
                obs = env.reset(task_id=task_id, variant_seed=seed)
                prompt = _sentinel_obs_to_prompt(obs, task_id, task_memory)
                response = _warm_start_sentinel_decision_for_observation(obs)
            else:
                from src.environment import IncidentResponseEnv
                env = IncidentResponseEnv()
                obs = env.reset(task_id=task_id, variant_seed=seed)
                prompt = _scenario_to_prompt(env._scenario, task_id, task_memory)
                response = _greedy_fallback_action(env, obs, [])
            records.append({"task_id": task_id, "variant_seed": seed, "text": prompt + json.dumps(response, sort_keys=True)})
            if len(records) >= max_examples:
                return records
    if records and len(records) < max_examples:
        cycled = []
        idx = 0
        while len(records) + len(cycled) < max_examples:
            cycled.append(dict(records[idx % len(records)]))
            idx += 1
        records.extend(cycled)
    return records[:max_examples]
def _run_small_warm_start(model, tokenizer, prompt_state):
    from transformers import Trainer, TrainingArguments

    output_dir = Path(WARM_START_OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    examples = _build_warm_start_examples(
        task_ids=list(ACTIVE_TASK_IDS),
        memory_context=prompt_state.memory_context,
        memory=prompt_state.memory,
        feedback_memory=prompt_state.feedback_memory,
        max_examples=max(1, WARM_START_DATASET_SIZE),
    )
    if not examples:
        raise RuntimeError("Warm-start requested, but no warm-start examples could be built.")
    preview = [
        {"task_id": r["task_id"], "variant_seed": r["variant_seed"], "text_preview": str(r["text"])[:240]}
        for r in examples[:5]
    ]
    (output_dir / "dataset_preview.json").write_text(json.dumps(preview, indent=2), encoding="utf-8")
    dataset = WarmStartDataset([r["text"] for r in examples], tokenizer)
    args = TrainingArguments(
        output_dir=str(output_dir),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=WARM_START_LR,
        max_steps=max(1, WARM_START_STEPS),
        num_train_epochs=1,
        logging_steps=1,
        save_strategy="no",
        remove_unused_columns=False,
        bf16=False,
        fp16=torch.cuda.is_available(),
        report_to="wandb" if wandb_enabled else "none",
    )
    trainer = Trainer(model=model, args=args, train_dataset=dataset)
    trainer.train()
    final_dir = output_dir / "final"
    trainer.save_model(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))
    summary = {
        "enabled": True,
        "steps": max(1, WARM_START_STEPS),
        "learning_rate": WARM_START_LR,
        "dataset_size": len(examples),
        "output_dir": str(output_dir),
        "saved_model_path": str(final_dir),
        "task_ids": list(ACTIVE_TASK_IDS),
    }
    (output_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.info("Warm-start complete: steps=%d dataset=%d saved=%s", summary["steps"], summary["dataset_size"], final_dir)
    return summary
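# Design note: the warm start above is a tiny SFT-style pass over heuristic
# (greedy-fallback) responses serialized as canonical JSON, presumably so the
# policy sees the expected completion format before GRPO has to shape it
# through reward signal alone.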
def train():
    logger.info("=" * 60)
    logger.info("OpenEnv GRPO Training")
    logger.info("Model: %s", MODEL_NAME)
    logger.info("Steps: %d", TRAIN_STEPS)
    logger.info("G: %d rollouts/prompt", NUM_GENERATIONS)
    logger.info("LR: %g", LR)
    logger.info("KL coef: %g", KL_COEF)
    logger.info("LoRA r: %d", LORA_R)
    logger.info("LLM panel: %s", USE_LLM_PANEL)
    logger.info("Curriculum: %s", USE_CURRICULUM)
    logger.info("Sampling: temperature=%.2f top_p=%.2f", GEN_TEMPERATURE, GEN_TOP_P)
    logger.info("Episode: MODEL_STEPS_LIMIT=%d MAX_NEW_TOKENS=%d", MODEL_STEPS_LIMIT, MAX_NEW_TOKENS)
    logger.info("Eval min difficulty: %s", os.getenv("EVAL_MIN_DIFFICULTY", "0.0"))
    logger.info("Warm start: %s", WARM_START_STEPS if WARM_START_STEPS > 0 else "disabled")
    logger.info("Reward schedule: %s", REWARD_SCHEDULE_MODE if USE_SENTINEL else "n/a")
    logger.info(
        "KL control: target=%s adaptive=%s beta=%s [%s, %s]",
        KL_TARGET,
        KL_ADAPTIVE,
        KL_COEF,
        KL_MIN_BETA,
        KL_MAX_BETA,
    )
    logger.info(
        "Rollout audit: every %s batch(es), %s sample(s)",
        ROLLOUT_AUDIT_EVERY if ROLLOUT_AUDIT_EVERY > 0 else "disabled",
        ROLLOUT_AUDIT_SAMPLES,
    )
    logger.info("Output: %s", OUTPUT_DIR)
    logger.info("=" * 60)
    # Load model
    model, tokenizer = load_model_and_tokenizer()

    # Load curriculum and agent memory
    from training.curriculum import get_curriculum
    from training.memory import (
        load_agent_memory, build_memory_context, maybe_consolidate_memory,
        record_episode as mem_record_episode, save_agent_memory,
        memory_summary as summarize_agent_memory, new_agent_memory,
    )
    from sentinel.feedback import (
        load_feedback_memory,
        empty_feedback_memory,
        record_episode_feedback,
        save_feedback_memory,
    )
    from sentinel.rewards import reset_reward_weights, scheduled_reward_weights, set_reward_weights

    curriculum = get_curriculum(active_task_ids=ACTIVE_TASK_IDS) if USE_CURRICULUM else None
    memory = load_agent_memory() if USE_AGENT_MEMORY else new_agent_memory()
    feedback_memory = (
        load_feedback_memory(SENTINEL_FEEDBACK_MEMORY_PATH)
        if USE_FEEDBACK_MEMORY
        else empty_feedback_memory()
    )
    memory_ctx = build_memory_context(memory) if USE_AGENT_MEMORY else ""
    prompt_state = _AdaptivePromptStateBase(
        task_ids=list(ACTIVE_TASK_IDS),
        sentinel_task_ids=list(SENTINEL_TASK_IDS),
        curriculum=curriculum,
        memory=memory,
        feedback_memory=feedback_memory,
        memory_context=memory_ctx,
        memory_enabled=USE_AGENT_MEMORY,
        max_seeds=5,
        use_sentinel=USE_SENTINEL,
        use_feedback_memory=USE_FEEDBACK_MEMORY,
        use_llm_panel=USE_LLM_PANEL,
        groq_api_key=GROQ_API_KEY,
        sentinel_adversarial_path=SENTINEL_ADVERSARIAL_PATH,
        sentinel_feedback_memory_path=SENTINEL_FEEDBACK_MEMORY_PATH,
        use_sentinel_adversarial=USE_SENTINEL_ADVERSARIAL,
    )
    if USE_SENTINEL and USE_SENTINEL_ADVERSARIAL:
        prompt_state.refresh_adversarial_cases()
    train_dataset = AdaptivePromptDataset(
        state=prompt_state,
        total_samples=PROMPT_DATASET_SIZE,
    )
    training_monitor = TrainingMonitor(TRAIN_MONITOR_DIR)
    training_monitor.write_stack_versions(collect_training_stack_versions())
    rollout_auditor = RolloutAuditSampler(
        output_dir=ROLLOUT_AUDIT_DIR,
        every=ROLLOUT_AUDIT_EVERY,
        sample_limit=ROLLOUT_AUDIT_SAMPLES,
    )
    warm_start_summary: Optional[Dict[str, Any]] = None
    warm_start_path = os.path.join(WARM_START_OUTPUT_DIR, "final")
    if WARM_START_STEPS > 0 and os.path.isdir(warm_start_path):
        logger.info("Warm-start checkpoint found at %s; skipping warm-start (saves ~20 min)", warm_start_path)
        # Reload the warm-start LoRA weights
        try:
            from peft import PeftModel
            if not hasattr(model, "peft_config"):
                model = PeftModel.from_pretrained(model, warm_start_path)
            # Coerce LoRA adapter dtype to fp16 to match the bnb-4bit base
            # compute dtype. bnb-4bit base weights are unaffected by .to();
            # only the (small) LoRA adapters get cast. Prevents the
            # "self and mat2 must have the same dtype" crash inside Unsloth's
            # fast_lora kernels (which derive target dtype from X.dtype = fp16).
            if torch.cuda.is_available():
                for name, param in model.named_parameters():
                    if "lora_" in name and param.dtype != torch.float16:
                        param.data = param.data.to(torch.float16)
            logger.info("Loaded warm-start LoRA from %s", warm_start_path)
        except Exception as exc:
            logger.warning("Could not reload warm-start LoRA: %s (continuing with base model)", exc)
        warm_start_summary = {"saved_model_path": warm_start_path, "skipped": True}
    elif WARM_START_STEPS > 0:
        warm_start_summary = _run_small_warm_start(model, tokenizer, prompt_state)
    if WARM_START_ONLY:
        if warm_start_summary is None:
            raise RuntimeError("WARM_START_ONLY=1 requires WARM_START_STEPS > 0.")
        return warm_start_summary["saved_model_path"]
    # GRPO config
    from trl import GRPOConfig, GRPOTrainer

    grpo_config = GRPOConfig(
        output_dir = OUTPUT_DIR,
        num_train_epochs = 1,
        per_device_train_batch_size = NUM_GENERATIONS,
        gradient_accumulation_steps = 1,
        num_generations = NUM_GENERATIONS,
        max_completion_length = MAX_NEW_TOKENS,
        learning_rate = LR,
        beta = KL_COEF,
        temperature = GEN_TEMPERATURE,
        top_p = GEN_TOP_P,
        logging_steps = 1,
        save_steps = 25,
        save_total_limit = 4,
        dataloader_num_workers = 0,
        bf16 = False,
        fp16 = torch.cuda.is_available(),
        report_to = "wandb" if wandb_enabled else "none",
        max_steps = TRAIN_STEPS,
    )
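    # Batch math (as of recent TRL versions): the effective batch size
    # (per_device_train_batch_size * gradient_accumulation_steps * world size)
    # must be divisible by num_generations. Setting
    # per_device_train_batch_size = NUM_GENERATIONS satisfies this with exactly
    # one prompt group per optimizer step on a single GPU.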
    # Wrap reward fn to inject curriculum-selected task_ids and seeds
    def reward_fn_with_curriculum(prompts, completions, **kwargs):
        nonlocal memory, feedback_memory
        # Extract task_id and variant_seed from dataset columns if available
        t_ids = kwargs.get("task_id", [ACTIVE_TASK_IDS[0]] * len(prompts))
        v_seeds = kwargs.get("variant_seed", [0] * len(prompts))
        adv_cases = kwargs.get("adversarial_case", [""] * len(prompts))
        curriculum_snapshot = curriculum.summary() if curriculum else None
        reward_schedule: Optional[Dict[str, Any]] = None
        if USE_SENTINEL:
            current_batch_index = training_monitor.batch_index + 1
            progress = min(1.0, current_batch_index / max(1, TRAIN_STEPS))
            reward_schedule = scheduled_reward_weights(
                progress=progress,
                mode=REWARD_SCHEDULE_MODE,
            )
            set_reward_weights(reward_schedule["weights"])
        rewards, histories = grpo_reward_fn(
            prompts = prompts,
            completions = completions,
            task_id = t_ids,
            variant_seed = v_seeds,
            adversarial_case = adv_cases,
            return_histories = True,
            **{k: v for k, v in kwargs.items() if k not in ("task_id", "variant_seed", "adversarial_case")},
        )
        for i, r in enumerate(rewards):
            t_id = t_ids[i] if i < len(t_ids) else ACTIVE_TASK_IDS[0]
            seed = v_seeds[i] if i < len(v_seeds) else 0
            history = histories[i] if i < len(histories) else []
            prompt_state.update_after_episode(
                task_id=t_id,
                variant_seed=seed,
                reward=r,
                history=history,
                mem_record_episode=mem_record_episode,
                record_episode_feedback=record_episode_feedback,
                save_agent_memory=save_agent_memory,
                save_feedback_memory=save_feedback_memory,
                maybe_consolidate_memory=maybe_consolidate_memory,
            )
        # Mirror the (possibly consolidated) memories back into train()'s scope
        # so they are saved correctly at the end of the run.
        memory = prompt_state.memory
        feedback_memory = prompt_state.feedback_memory
        monitor_summary = training_monitor.log_batch(
            sentinel_task_ids=list(SENTINEL_TASK_IDS),
            rewards=rewards,
            histories=histories,
            task_ids=[str(task_id) for task_id in t_ids],
            variant_seeds=[int(seed) for seed in v_seeds],
            completions=[str(completion) for completion in completions],
            prompts=[str(prompt) for prompt in prompts],
            adversarial_cases=[str(case) for case in adv_cases],
            curriculum_summary=curriculum_snapshot,
            prompt_refreshes=prompt_state.prompt_refreshes,
            reward_schedule=reward_schedule,
            memory_summary={
                "agent_memory_enabled": USE_AGENT_MEMORY,
                "feedback_memory_enabled": USE_FEEDBACK_MEMORY,
                **summarize_agent_memory(memory),
            },
        )
        audit_path = rollout_auditor.record_batch(
            sentinel_task_ids=list(SENTINEL_TASK_IDS),
            active_task_ids=list(ACTIVE_TASK_IDS),
            batch_index=training_monitor.batch_index,
            prompts=[str(prompt) for prompt in prompts],
            completions=[str(completion) for completion in completions],
            rewards=rewards,
            histories=histories,
            task_ids=[str(task_id) for task_id in t_ids],
            variant_seeds=[int(seed) for seed in v_seeds],
            monitor_summary=monitor_summary,
            reward_schedule=reward_schedule,
        )
        if curriculum and curriculum.should_use_adversarial():
            logger.info(
                "Adversarial trigger: tier=%d mean=%.2f",
                curriculum.tier_index,
                curriculum.summary()["recent_mean_score"],
            )
            try:
                weak_spots = curriculum.weak_spots(top_n=2)
                if USE_SENTINEL and USE_SENTINEL_ADVERSARIAL:
                    from training.adversarial import (
                        generate_sentinel_adversarial_cases,
                        save_sentinel_adversarial_cases,
                    )
                    cases = generate_sentinel_adversarial_cases(weak_spots, n=4)
                    save_sentinel_adversarial_cases(cases, SENTINEL_ADVERSARIAL_PATH)
                    prompt_state.sentinel_adversarial_cases = cases
                    logger.info("Generated %d SENTINEL adversarial worker cases", len(cases))
                elif GROQ_API_KEY:
                    from training.adversarial import AdversarialDesigner
                    designer = AdversarialDesigner(api_key=GROQ_API_KEY)
                    new_scenarios = designer.generate(weak_spots, n=3)
                    designer.save_generated("outputs/adversarial_scenarios.json")
                    logger.info("Generated %d adversarial scenarios", len(new_scenarios))
            except Exception as e:
                logger.debug("Adversarial generation failed: %s", e)
        if wandb_enabled:
            import wandb

            wandb_payload = {
                "monitor/reward_mean": monitor_summary["reward_mean"],
                "monitor/avg_steps": monitor_summary["avg_steps"],
                "monitor/running_reward_mean": monitor_summary["running_reward_mean"],
                "monitor/best_reward_mean": monitor_summary["best_reward_mean"],
                "monitor/unique_completion_ratio": monitor_summary.get("unique_completion_ratio", 0.0),
                "monitor/decision_entropy": monitor_summary.get("decision_entropy", 0.0),
                "monitor/decision_variety": monitor_summary.get("decision_variety", 0),
                "monitor/zero_reward_fraction": monitor_summary.get("zero_reward_fraction", 0.0),
                "monitor/trivially_solved_fraction": monitor_summary.get("trivially_solved_fraction", 0.0),
                "monitor/productive_fraction": monitor_summary.get("productive_fraction", 0.0),
                "monitor/effective_prompt_ratio": monitor_summary.get("effective_prompt_ratio", 0.0),
                "monitor/frontier_hit_rate": monitor_summary.get("frontier_hit_rate", 0.0),
                "monitor/task_diversity_ratio": monitor_summary.get("task_diversity_ratio", 0.0),
                "monitor/zero_gradient_group_fraction": monitor_summary.get("zero_gradient_group_fraction", 0.0),
                "monitor/adversarial_case_fraction": monitor_summary.get("adversarial_case_fraction", 0.0),
            }
            if monitor_summary.get("memory"):
                wandb_payload["monitor/memory_total_episodes"] = monitor_summary["memory"].get("total_episodes", 0)
                wandb_payload["monitor/memory_mistake_cards"] = monitor_summary["memory"].get("mistake_cards_stored", 0)
            if USE_SENTINEL:
                wandb_payload.update(
                    {
                        "monitor/detection_rate": monitor_summary.get("detection_rate", 0.0),
                        "monitor/false_positive_rate": monitor_summary.get("false_positive_rate", 0.0),
                        "monitor/risk_reduction_rate": monitor_summary.get("risk_reduction_rate", 0.0),
                        "monitor/twin_damage_reduction_rate": monitor_summary.get("twin_damage_reduction_rate", 0.0),
                        "monitor/twin_without_sentinel_damage_total": monitor_summary.get("twin_without_sentinel_damage_total", 0.0),
                        "monitor/twin_with_sentinel_damage_total": monitor_summary.get("twin_with_sentinel_damage_total", 0.0),
                        "monitor/worker_rehabilitation_rate": monitor_summary.get("worker_rehabilitation_rate", 0.0),
                        "monitor/coaching_quality": monitor_summary.get("coaching_quality", 0.0),
                    }
                )
            if reward_schedule:
                wandb_payload.update(
                    {
                        "monitor/reward_schedule_progress": reward_schedule.get("progress", 0.0),
                        "monitor/reward_schedule_stage": reward_schedule.get("stage", "unknown"),
                    }
                )
            if audit_path:
                wandb_payload["monitor/rollout_audit_saved"] = 1
            wandb.log(wandb_payload)
        return rewards
    # Create trainer
    trainer = GRPOTrainer(
        model = model,
        processing_class = tokenizer,
        args = grpo_config,
        train_dataset = train_dataset,
        reward_funcs = [reward_fn_with_curriculum],
    )
    stability_callback = GRPOStabilityCallback(
        training_monitor=training_monitor,
        initial_beta=KL_COEF,
        target_kl=KL_TARGET,
        adaptive=KL_ADAPTIVE,
        low_factor=KL_LOW_FACTOR,
        high_factor=KL_HIGH_FACTOR,
        beta_up_mult=KL_BETA_UP_MULT,
        beta_down_mult=KL_BETA_DOWN_MULT,
        min_beta=KL_MIN_BETA,
        max_beta=KL_MAX_BETA,
        hard_stop_enabled=KL_HARD_STOP_ENABLED,
        hard_stop_mult=KL_HARD_STOP_MULT,
    )
    trainer.add_callback(stability_callback)
    stability_callback.bind_trainer(trainer)
    # Train
    logger.info("Starting training...")
    start_time = time.time()
    trainer.train()
    elapsed = time.time() - start_time
    logger.info("Training complete in %.1f minutes", elapsed / 60)

    # Save final model
    final_path = os.path.join(OUTPUT_DIR, "final")
    trainer.save_model(final_path)
    tokenizer.save_pretrained(final_path)
    logger.info("Saved final model to %s", final_path)

    # Save curriculum state
    if curriculum:
        logger.info("Curriculum summary: %s", curriculum.summary())
    if USE_AGENT_MEMORY:
        save_agent_memory(memory)
    if USE_SENTINEL and USE_FEEDBACK_MEMORY:
        save_feedback_memory(feedback_memory, SENTINEL_FEEDBACK_MEMORY_PATH)
    if warm_start_summary:
        logger.info("Warm-start summary: %s", warm_start_summary)
    if USE_SENTINEL:
        reset_reward_weights()

    # Plot reward curve
    _plot_reward_curve()
    try:
        from scripts.render_training_dashboard import render_dashboard
        render_dashboard(
            monitor_dir=TRAIN_MONITOR_DIR,
            output_dir="outputs/reward_curves",
        )
    except Exception as exc:
        logger.warning("Training dashboard render skipped: %s", exc)

    # Push to Hub (if HF_TOKEN set)
    hf_repo = os.getenv("HF_REPO")
    if hf_repo and HF_TOKEN:
        logger.info("Pushing to HuggingFace Hub: %s", hf_repo)
        trainer.model.push_to_hub(hf_repo, token=HF_TOKEN)
        tokenizer.push_to_hub(hf_repo, token=HF_TOKEN)
        logger.info("Done! Update openenv.yaml model: %s", hf_repo)
    if wandb_enabled:
        import wandb
        wandb.finish()
    return final_path
# ---------------------------------------------------------------------------
# Reward curve plot
# ---------------------------------------------------------------------------
def _plot_reward_curve():
    """Plot mean reward over steps from the monitoring JSONL, falling back to train.log."""
    try:
        import matplotlib.pyplot as plt

        steps, rewards = [], []
        monitor_path = Path(TRAIN_MONITOR_DIR) / "training_metrics.jsonl"
        if monitor_path.exists():
            with monitor_path.open("r", encoding="utf-8") as handle:
                for line in handle:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        payload = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    steps.append(int(payload.get("batch_index", len(steps) + 1)))
                    rewards.append(float(payload.get("reward_mean", 0.0)))
        else:
            log_path = os.path.join(OUTPUT_DIR, "train.log")
            if not os.path.exists(log_path):
                return
            with open(log_path, encoding="utf-8", errors="ignore") as f:
                for line in f:
                    if "Batch rewards: mean=" in line:
                        try:
                            mean_str = line.split("mean=")[1].split(" ")[0]
                            steps.append(len(steps) + 1)
                            rewards.append(float(mean_str))
                        except Exception:
                            pass
        if not steps:
            return
        plt.figure(figsize=(10, 5))
        plt.plot(steps, rewards, linewidth=2, color="royalblue")
        plt.xlabel("Training Step")
        plt.ylabel("Mean Reward")
        plt.title("GRPO Training Reward Curve")
        plt.grid(True, alpha=0.3)
        # Smoothed line (trailing moving average, aligned to the step at the
        # end of each window so the curve is not shifted left)
        if len(rewards) > 10:
            window = min(10, len(rewards) // 5)
            smoothed = np.convolve(rewards, np.ones(window) / window, mode="valid")
            smooth_steps = steps[window - 1:]
            plt.plot(smooth_steps, smoothed, linewidth=2, color="red",
                     linestyle="--", label=f"Smoothed (w={window})")
            plt.legend()
        plot_path = "outputs/reward_curves/training_curve.png"
        plt.savefig(plot_path, dpi=120, bbox_inches="tight")
        plt.close()
        logger.info("Saved reward curve to %s", plot_path)
    except ImportError:
        logger.info("matplotlib not installed - skipping reward plot")
    except Exception as e:
        logger.warning("Could not plot reward curve: %s", e)
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="GRPO training for OpenEnv")
    parser.add_argument("--steps", type=int, default=TRAIN_STEPS, help="Training steps")
    parser.add_argument("--model", type=str, default=MODEL_NAME, help="Model name/path")
    parser.add_argument("--lr", type=float, default=LR, help="Learning rate")
    parser.add_argument("--output", type=str, default=OUTPUT_DIR, help="Output directory")
    parser.add_argument("--resume", type=str, default=RESUME_FROM, help="Checkpoint to resume from")
    parser.add_argument("--warm-start-steps", type=int, default=WARM_START_STEPS, help="Optional small SFT-style warm-start steps before GRPO")
    parser.add_argument("--warm-start-only", action="store_true", help="Run only the warm-start stage and stop before GRPO")
    parser.add_argument("--dry-run", action="store_true", help="Validate setup without training")
    args = parser.parse_args()

    # Override from CLI
    TRAIN_STEPS = args.steps
    MODEL_NAME = args.model
    LR = args.lr
    OUTPUT_DIR = args.output
    RESUME_FROM = args.resume
    WARM_START_STEPS = args.warm_start_steps
    WARM_START_ONLY = args.warm_start_only or WARM_START_ONLY

    if args.dry_run:
        logger.info("DRY RUN: Validating environment and reward function...")
        if USE_SENTINEL:
            from sentinel.environment import SentinelEnv
            env = SentinelEnv()
            for task_id in SENTINEL_TASK_IDS:
                obs = env.reset(task_id=task_id, variant_seed=0)
                grade = env.grade()
                score = float(grade.score) if hasattr(grade, "score") else float(grade.get("score", 0.0))
                logger.info(" task=%s initial_grade=%.3f", task_id, score)
        else:
            from src.environment import IncidentResponseEnv
            env = IncidentResponseEnv()
            for task_id in TASK_IDS:
                obs = env.reset(task_id=task_id, variant_seed=0)
                grade = env.grade()
                score = float(grade.score) if hasattr(grade, "score") else float(grade.get("score", 0.0))
                logger.info(" task=%s initial_grade=%.3f", task_id, score)
        if WARM_START_STEPS > 0:
            from training.memory import load_agent_memory
            from sentinel.feedback import load_feedback_memory
            warm_start_records = _build_warm_start_examples(
                task_ids=list(ACTIVE_TASK_IDS),
                memory=load_agent_memory(),
                feedback_memory=load_feedback_memory(SENTINEL_FEEDBACK_MEMORY_PATH),
                max_examples=max(1, min(WARM_START_DATASET_SIZE, 8)),
            )
            logger.info(" warm_start_examples=%d", len(warm_start_records))
        logger.info("DRY RUN PASSED. Environment is working.")
        sys.exit(0)

    final_path = train()
    logger.info("Training finished. Final model: %s", final_path)
    logger.info("Next steps:")
    logger.info(" 1. python validate.py")
    logger.info(" 2. Update openenv.yaml: model: <HF_REPO>")
    logger.info(" 3. Submit!")