""" RFT (Rejection-sampling Fine-Tuning) polish pass for the trained Sentinel LoRA. Pipeline: 1. Load the 200-step GRPO LoRA from $LORA_PATH on top of Qwen3-4B-bnb-4bit. 2. Generate N rollouts per Sentinel task with the trained policy. 3. Score each rollout with the real env reward + count false positives from the audit trail. 4. Keep ONLY the rollouts with `score >= MIN_SCORE` AND `fp <= MAX_FP`. 5. SFT (UnslothTrainer) for `EPOCHS` epochs on those high-quality rollouts. 6. Save the polished LoRA to $RFT_OUTPUT_DIR/final. 7. Optionally upload to the HuggingFace Hub. This is the technique competing teams use to push reward 0.30 -> 0.55+. ENV VARS: LORA_PATH existing GRPO LoRA (default /data/sentinel_outputs/final) MODEL_NAME base model (default unsloth/Qwen3-4B-bnb-4bit) RFT_OUTPUT_DIR where to save (default /data/sentinel_outputs_rft) NUM_ROLLOUTS_PER_TASK per-task generation count (default 20) MAX_NEW_TOKENS cap on each rollout (default 512) GEN_TEMPERATURE sampling temp (default 0.7) GEN_TOP_P nucleus p (default 0.9) MIN_SCORE keep filter (>=) (default 0.55) MAX_FP keep filter (<=) (default 3) EPOCHS SFT epochs (default 2) SFT_LR SFT learning rate (default 5e-6) HF_TOKEN HF write token (optional) HF_REPO HF repo id (optional) Output: $RFT_OUTPUT_DIR/final/ polished LoRA adapter $RFT_OUTPUT_DIR/rollouts.jsonl all rollouts with scores $RFT_OUTPUT_DIR/sft_dataset.jsonl filtered (kept) rollouts $RFT_OUTPUT_DIR/rft_summary.json run summary statistics """ from __future__ import annotations import json import logging import os import sys from collections import Counter, defaultdict from pathlib import Path from typing import Any, Dict, List # Make sure repo root is on sys.path ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT)) import torch from datasets import Dataset from peft import PeftModel, prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments from training.episodes import run_episode_with_completion from training.prompts import build_prompt_record logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) logger = logging.getLogger("rft_polish") # --------------------------------------------------------------------------- # Config (env-driven so judges can rerun) # --------------------------------------------------------------------------- LORA_PATH = os.environ.get("LORA_PATH", "/data/sentinel_outputs/final") MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen3-4B-bnb-4bit") RFT_OUTPUT_DIR = os.environ.get("RFT_OUTPUT_DIR", "/data/sentinel_outputs_rft") NUM_ROLLOUTS_PER_TASK = int(os.environ.get("NUM_ROLLOUTS_PER_TASK", "20")) MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "512")) GEN_TEMPERATURE = float(os.environ.get("GEN_TEMPERATURE", "0.7")) GEN_TOP_P = float(os.environ.get("GEN_TOP_P", "0.9")) MIN_SCORE = float(os.environ.get("MIN_SCORE", "0.55")) MAX_FP = int(os.environ.get("MAX_FP", "3")) EPOCHS = int(os.environ.get("EPOCHS", "2")) SFT_LR = float(os.environ.get("SFT_LR", "5e-6")) HF_TOKEN = os.environ.get("HF_TOKEN", "") HF_REPO = os.environ.get("HF_REPO", "") MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096")) REUSE_RFT_ROLLOUTS = os.environ.get("REUSE_RFT_ROLLOUTS", "").lower() in {"1", "true", "yes", "y"} RFT_BACKEND = os.environ.get("RFT_BACKEND", "standard" if REUSE_RFT_ROLLOUTS else "unsloth").lower() SENTINEL_TASKS = [ "basic_oversight", "fleet_monitoring_conflict", "adversarial_worker", "multi_crisis_command", ] def banner(msg: str) -> None: bar = "=" * 70 logger.info(bar) logger.info(" %s", msg) logger.info(bar) def count_false_positives(history: List[Dict[str, Any]]) -> int: """Count audit entries where Sentinel blocked a SAFE action.""" fp = 0 for entry in history: audit = entry.get("audit") or {} decision = audit.get("sentinel_decision") was_misb = audit.get("was_misbehavior") if (decision is not None and decision != "APPROVE") and not was_misb: fp += 1 return fp def load_existing_rollouts(path: Path) -> List[Dict[str, Any]]: """Load saved rollout JSONL and recompute keep/drop with current thresholds.""" rows: List[Dict[str, Any]] = [] if not path.exists(): return rows for line in path.read_text(encoding="utf-8", errors="ignore").splitlines(): line = line.strip() if not line: continue try: row = json.loads(line) except json.JSONDecodeError: continue if not isinstance(row, dict): continue score = float(row.get("score") or 0.0) fp = int(row.get("fp") or 0) row["score"] = score row["fp"] = fp row["kept"] = score >= MIN_SCORE and fp <= MAX_FP rows.append(row) return rows def resolve_tokenizer_eos(tokenizer) -> str | None: """Resolve an EOS token that actually exists in the tokenizer vocab.""" candidates = [ getattr(tokenizer, "eos_token", None), "<|im_end|>", "<|endoftext|>", ] unk_id = getattr(tokenizer, "unk_token_id", None) for token in candidates: if not token: continue try: token_id = tokenizer.convert_tokens_to_ids(token) except Exception: token_id = None if token_id is not None and token_id != unk_id: return token eos_id = getattr(tokenizer, "eos_token_id", None) if eos_id is not None: try: return tokenizer.convert_ids_to_tokens(eos_id) except Exception: return None return None def build_causal_lm_dataset(tokenizer, dataset: Dataset) -> Dataset: """Tokenize text rows for plain HF Trainer causal-LM fine-tuning.""" eos_token = resolve_tokenizer_eos(tokenizer) if eos_token: tokenizer.eos_token = eos_token if tokenizer.pad_token_id is None and eos_token: tokenizer.pad_token = eos_token logger.info("Using eos token as pad token for RFT SFT: %s", eos_token) def tokenize_batch(batch): encoded = tokenizer( batch["text"], truncation=True, max_length=MAX_SEQ_LENGTH, padding=False, ) encoded["labels"] = [ids.copy() for ids in encoded["input_ids"]] return encoded return dataset.map(tokenize_batch, batched=True, remove_columns=dataset.column_names) def build_causal_lm_collator(tokenizer): """Pad inputs and mask padded labels for causal-LM SFT.""" pad_id = tokenizer.pad_token_id if pad_id is None: pad_id = tokenizer.eos_token_id if pad_id is None: pad_id = 0 def collate(features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: max_len = min(MAX_SEQ_LENGTH, max(len(feature["input_ids"]) for feature in features)) batch = {"input_ids": [], "attention_mask": [], "labels": []} for feature in features: input_ids = list(feature["input_ids"][:max_len]) attention_mask = list(feature.get("attention_mask", [1] * len(input_ids))[:max_len]) labels = list(feature["labels"][:max_len]) pad_len = max_len - len(input_ids) if pad_len > 0: input_ids.extend([pad_id] * pad_len) attention_mask.extend([0] * pad_len) labels.extend([-100] * pad_len) batch["input_ids"].append(input_ids) batch["attention_mask"].append(attention_mask) batch["labels"].append(labels) return {key: torch.tensor(value, dtype=torch.long) for key, value in batch.items()} return collate def disable_gradient_checkpointing(model) -> None: """Disable checkpointing paths that can mismatch across Unsloth/Transformers versions.""" try: model.gradient_checkpointing_disable() except Exception: pass for module in model.modules(): if hasattr(module, "gradient_checkpointing"): try: module.gradient_checkpointing = False except Exception: pass config = getattr(module, "config", None) if config is not None and hasattr(config, "gradient_checkpointing"): try: config.gradient_checkpointing = False except Exception: pass config = getattr(model, "config", None) if config is not None: if hasattr(config, "gradient_checkpointing"): config.gradient_checkpointing = False if hasattr(config, "use_cache"): config.use_cache = False logger.info("Gradient checkpointing disabled for RFT SFT compatibility") def build_sft_trainer(model, tokenizer, dataset: Dataset, output_dir: Path) -> Trainer: """Create a plain HF Trainer to avoid TRL EOS-token version bugs.""" eos_token = resolve_tokenizer_eos(tokenizer) if eos_token: tokenizer.eos_token = eos_token logger.info("Preparing plain HF Trainer with tokenizer eos_token=%s", eos_token) tokenized = build_causal_lm_dataset(tokenizer, dataset) training_args = TrainingArguments( output_dir=str(output_dir), num_train_epochs=EPOCHS, per_device_train_batch_size=2, gradient_accumulation_steps=2, learning_rate=SFT_LR, logging_steps=1, save_strategy="no", report_to=[], bf16=False, fp16=False, optim="adamw_torch", gradient_checkpointing=False, warmup_ratio=0.1, lr_scheduler_type="cosine", remove_unused_columns=False, seed=42, ) return Trainer( model=model, args=training_args, train_dataset=tokenized, data_collator=build_causal_lm_collator(tokenizer), ) # --------------------------------------------------------------------------- # 1. Load base model + existing LoRA in fp16 for inference # --------------------------------------------------------------------------- def load_unsloth_policy(): banner("Loading base model + GRPO LoRA with Unsloth") from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name = MODEL_NAME, max_seq_length = MAX_SEQ_LENGTH, dtype = torch.float16, load_in_4bit = True, ) if Path(LORA_PATH).exists(): logger.info("Loading LoRA adapter from %s", LORA_PATH) model = PeftModel.from_pretrained(model, LORA_PATH, is_trainable=True) # Coerce LoRA to fp16 to match bnb-4bit compute dtype (avoids matmul errors) for name, p in model.named_parameters(): if "lora_" in name and p.dtype != torch.float16: p.data = p.data.to(torch.float16) else: logger.warning("LORA_PATH %s does not exist, using base model only", LORA_PATH) FastLanguageModel.for_inference(model) return model, tokenizer def load_standard_policy(): """Load with standard Transformers/PEFT to avoid Unsloth/xFormers training kernels.""" banner("Loading base model + GRPO LoRA with standard Transformers") quant_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) eos_token = resolve_tokenizer_eos(tokenizer) if eos_token: tokenizer.eos_token = eos_token if tokenizer.pad_token_id is None and eos_token: tokenizer.pad_token = eos_token model_kwargs = { "quantization_config": quant_config, "device_map": "auto", "torch_dtype": torch.float16, "trust_remote_code": True, } try: model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, attn_implementation="eager", **model_kwargs, ) logger.info("Loaded standard model with eager attention") except TypeError: model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs) logger.info("Loaded standard model without explicit attention override") if hasattr(model.config, "use_cache"): model.config.use_cache = False model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) if Path(LORA_PATH).exists(): logger.info("Loading LoRA adapter from %s", LORA_PATH) model = PeftModel.from_pretrained(model, LORA_PATH, is_trainable=True) for name, p in model.named_parameters(): if "lora_" in name and p.dtype != torch.float16: p.data = p.data.to(torch.float16) else: logger.warning("LORA_PATH %s does not exist, using base model only", LORA_PATH) model.train() return model, tokenizer def load_policy(): if RFT_BACKEND == "standard": return load_standard_policy() if RFT_BACKEND != "unsloth": logger.warning("Unknown RFT_BACKEND=%s; falling back to standard", RFT_BACKEND) return load_standard_policy() return load_unsloth_policy() # --------------------------------------------------------------------------- # 2. Generate rollouts and 3. Score them # --------------------------------------------------------------------------- def generate_and_score(model, tokenizer) -> List[Dict[str, Any]]: banner(f"Generating {NUM_ROLLOUTS_PER_TASK} rollouts x {len(SENTINEL_TASKS)} tasks") all_rollouts: List[Dict[str, Any]] = [] for task_id in SENTINEL_TASKS: for variant_seed in range(NUM_ROLLOUTS_PER_TASK): try: record = build_prompt_record( task_id=task_id, sentinel_task_ids=SENTINEL_TASKS, variant_seed=variant_seed % 5, # 5 variants cycled memory_context="", ) except Exception as exc: logger.warning("prompt build failed for %s seed %d: %s", task_id, variant_seed, exc) continue prompt = record["prompt"] inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH - MAX_NEW_TOKENS).to(model.device) with torch.no_grad(): out = model.generate( **inputs, max_new_tokens = MAX_NEW_TOKENS, temperature = GEN_TEMPERATURE, top_p = GEN_TOP_P, do_sample = True, pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id, ) completion = tokenizer.decode( out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True ) try: score, history = run_episode_with_completion( completion, task_id, variant_seed % 5, SENTINEL_TASKS, model_steps_limit=3, ) except Exception as exc: logger.warning("scoring failed for %s seed %d: %s", task_id, variant_seed, exc) score, history = 0.0, [] fp = count_false_positives(history) rollout = { "task_id": task_id, "variant_seed": variant_seed % 5, "prompt": prompt, "completion": completion, "score": float(score), "fp": int(fp), "kept": (score >= MIN_SCORE and fp <= MAX_FP), } all_rollouts.append(rollout) logger.info( "[%s seed=%d] score=%.3f fp=%d %s", task_id, variant_seed % 5, score, fp, "KEEP" if rollout["kept"] else "drop", ) return all_rollouts # --------------------------------------------------------------------------- # 4. Filter and 5. SFT # --------------------------------------------------------------------------- def filter_and_sft(model, tokenizer, all_rollouts: List[Dict[str, Any]]) -> Dict[str, Any]: kept = [r for r in all_rollouts if r["kept"]] banner( f"Filtered: {len(kept)} kept / {len(all_rollouts)} total " f"(score >= {MIN_SCORE}, fp <= {MAX_FP})" ) if len(kept) < 4: logger.error( "Only %d rollouts passed the filter; need at least 4 for stable SFT. " "Aborting RFT to avoid producing a worse model.", len(kept) ) return {"status": "skipped_too_few_rollouts", "kept": len(kept), "total": len(all_rollouts)} # Build chat-style training texts: prompt + completion rows = [] for r in kept: full_text = r["prompt"] + r["completion"] + tokenizer.eos_token rows.append({"text": full_text}) ds = Dataset.from_list(rows) # Switch model back to training mode (Unsloth toggles this on for_inference) if RFT_BACKEND == "unsloth": from unsloth import FastLanguageModel FastLanguageModel.for_training(model) else: model.train() disable_gradient_checkpointing(model) sft_output = Path(RFT_OUTPUT_DIR) / "sft_run" sft_output.mkdir(parents=True, exist_ok=True) trainer = build_sft_trainer(model, tokenizer, ds, sft_output) banner(f"Starting SFT on {len(kept)} kept rollouts for {EPOCHS} epochs (lr={SFT_LR})") trainer.train() # Save final polished LoRA final_dir = Path(RFT_OUTPUT_DIR) / "final" final_dir.mkdir(parents=True, exist_ok=True) trainer.model.save_pretrained(str(final_dir)) tokenizer.save_pretrained(str(final_dir)) logger.info("Saved RFT-polished LoRA to %s", final_dir) return { "status": "ok", "kept": len(kept), "total": len(all_rollouts), "epochs": EPOCHS, "lr": SFT_LR, "saved_to": str(final_dir), } # --------------------------------------------------------------------------- # 6. Optional HF Hub push # --------------------------------------------------------------------------- def maybe_push_to_hub() -> None: final_dir = Path(RFT_OUTPUT_DIR) / "final" if not (HF_TOKEN and HF_REPO and final_dir.exists()): logger.info("Skipping HF Hub push (missing HF_TOKEN/HF_REPO or no final/ dir)") return banner(f"Uploading {final_dir} -> https://huggingface.co/{HF_REPO}") from huggingface_hub import HfApi, create_repo create_repo(HF_REPO, token=HF_TOKEN, exist_ok=True, private=False) HfApi().upload_folder( folder_path = str(final_dir), repo_id = HF_REPO, token = HF_TOKEN, commit_message = "Upload RFT-polished LoRA (rejection-sampling fine-tune)", ) logger.info("Upload complete: https://huggingface.co/%s", HF_REPO) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: banner("RFT Polish — config") for k, v in { "LORA_PATH": LORA_PATH, "MODEL_NAME": MODEL_NAME, "RFT_OUTPUT_DIR": RFT_OUTPUT_DIR, "NUM_ROLLOUTS_PER_TASK": NUM_ROLLOUTS_PER_TASK, "MAX_NEW_TOKENS": MAX_NEW_TOKENS, "GEN_TEMPERATURE": GEN_TEMPERATURE, "GEN_TOP_P": GEN_TOP_P, "MIN_SCORE": MIN_SCORE, "MAX_FP": MAX_FP, "EPOCHS": EPOCHS, "SFT_LR": SFT_LR, "HF_REPO": HF_REPO or "(skip)", "REUSE_RFT_ROLLOUTS": REUSE_RFT_ROLLOUTS, "RFT_BACKEND": RFT_BACKEND, }.items(): logger.info(" %-22s = %s", k, v) Path(RFT_OUTPUT_DIR).mkdir(parents=True, exist_ok=True) model, tokenizer = load_policy() # Persist all rollouts (for proof pack) rollouts_file = Path(RFT_OUTPUT_DIR) / "rollouts.jsonl" if REUSE_RFT_ROLLOUTS and rollouts_file.exists(): all_rollouts = load_existing_rollouts(rollouts_file) logger.info("Reusing %d saved rollouts from %s", len(all_rollouts), rollouts_file) else: all_rollouts = generate_and_score(model, tokenizer) with rollouts_file.open("w") as fh: for r in all_rollouts: fh.write(json.dumps(r) + "\n") logger.info("Wrote %d rollouts to %s", len(all_rollouts), rollouts_file) # Per-task summary BEFORE filtering by_task = defaultdict(list) for r in all_rollouts: by_task[r["task_id"]].append(r) banner("Per-task generation stats") for task_id, rs in by_task.items(): scores = [r["score"] for r in rs] fps = [r["fp"] for r in rs] kept = sum(1 for r in rs if r["kept"]) logger.info( " %-30s n=%2d mean_score=%.3f mean_fp=%.1f kept=%d", task_id, len(rs), sum(scores)/max(1, len(rs)), sum(fps)/max(1, len(rs)), kept, ) # SFT on the kept rollouts sft_summary = filter_and_sft(model, tokenizer, all_rollouts) # Persist filtered SFT dataset for transparency kept_file = Path(RFT_OUTPUT_DIR) / "sft_dataset.jsonl" with kept_file.open("w") as fh: for r in all_rollouts: if r["kept"]: fh.write(json.dumps(r) + "\n") logger.info("Wrote %d kept samples to %s", sum(1 for r in all_rollouts if r["kept"]), kept_file) # Final summary summary = { "config": { "LORA_PATH": LORA_PATH, "MODEL_NAME": MODEL_NAME, "NUM_ROLLOUTS_PER_TASK": NUM_ROLLOUTS_PER_TASK, "MIN_SCORE": MIN_SCORE, "MAX_FP": MAX_FP, "EPOCHS": EPOCHS, "SFT_LR": SFT_LR, }, "rollout_stats": { "total": len(all_rollouts), "kept": sum(1 for r in all_rollouts if r["kept"]), "mean_score_total": sum(r["score"] for r in all_rollouts) / max(1, len(all_rollouts)), "mean_fp_total": sum(r["fp"] for r in all_rollouts) / max(1, len(all_rollouts)), "mean_score_kept": ( sum(r["score"] for r in all_rollouts if r["kept"]) / max(1, sum(1 for r in all_rollouts if r["kept"])) ), "mean_fp_kept": ( sum(r["fp"] for r in all_rollouts if r["kept"]) / max(1, sum(1 for r in all_rollouts if r["kept"])) ), "task_breakdown": { t: { "n": len(rs), "mean_score": sum(r["score"] for r in rs) / max(1, len(rs)), "mean_fp": sum(r["fp"] for r in rs) / max(1, len(rs)), "kept": sum(1 for r in rs if r["kept"]), } for t, rs in by_task.items() }, }, "sft": sft_summary, } summary_file = Path(RFT_OUTPUT_DIR) / "rft_summary.json" summary_file.write_text(json.dumps(summary, indent=2)) logger.info("Wrote summary to %s", summary_file) maybe_push_to_hub() banner("RFT polish complete") logger.info("Final LoRA: %s/final", RFT_OUTPUT_DIR) logger.info("Summary: %s", summary_file) if HF_REPO: logger.info("HF Hub: https://huggingface.co/%s", HF_REPO) if __name__ == "__main__": main()