"""
RFT (Rejection-sampling Fine-Tuning) polish pass for the trained Sentinel LoRA.

Pipeline:
    1. Load the 200-step GRPO LoRA from $LORA_PATH on top of Qwen3-4B-bnb-4bit.
    2. Generate N rollouts per Sentinel task with the trained policy.
    3. Score each rollout with the real env reward + count false positives
       from the audit trail.
    4. Keep ONLY the rollouts with `score >= MIN_SCORE` AND `fp <= MAX_FP`.
    5. SFT (UnslothTrainer) for `EPOCHS` epochs on those high-quality rollouts.
    6. Save the polished LoRA to $RFT_OUTPUT_DIR/final.
    7. Optionally upload to the HuggingFace Hub.

This is the technique competing teams use to push reward 0.30 -> 0.55+.

ENV VARS:
    LORA_PATH              existing GRPO LoRA (default /data/sentinel_outputs/final)
    MODEL_NAME             base model (default unsloth/Qwen3-4B-bnb-4bit)
    RFT_OUTPUT_DIR         where to save (default /data/sentinel_outputs_rft)
    NUM_ROLLOUTS_PER_TASK  per-task generation count (default 20)
    MAX_NEW_TOKENS         cap on each rollout (default 512)
    GEN_TEMPERATURE        sampling temp (default 0.7)
    GEN_TOP_P              nucleus p (default 0.9)
    MIN_SCORE              keep filter (>=) (default 0.55)
    MAX_FP                 keep filter (<=) (default 3)
    EPOCHS                 SFT epochs (default 2)
    SFT_LR                 SFT learning rate (default 5e-6)
    HF_TOKEN               HF write token (optional)
    HF_REPO                HF repo id (optional)

Output:
    $RFT_OUTPUT_DIR/final/              polished LoRA adapter
    $RFT_OUTPUT_DIR/rollouts.jsonl      all rollouts with scores
    $RFT_OUTPUT_DIR/sft_dataset.jsonl   filtered (kept) rollouts
    $RFT_OUTPUT_DIR/rft_summary.json    run summary statistics
"""
from __future__ import annotations

import json
import logging
import os
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List

# Make sure repo root is on sys.path so the `training.*` package imports below
# resolve when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

import torch
from datasets import Dataset
from peft import PeftModel, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments

from training.episodes import run_episode_with_completion
from training.prompts import build_prompt_record

# Root logging config for the whole run; this module logs under "rft_polish".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("rft_polish")
| # --------------------------------------------------------------------------- | |
| # Config (env-driven so judges can rerun) | |
| # --------------------------------------------------------------------------- | |
# All knobs are env-driven (see module docstring) so the run can be reproduced
# with different settings without editing code.
LORA_PATH = os.environ.get("LORA_PATH", "/data/sentinel_outputs/final")   # input GRPO LoRA adapter
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen3-4B-bnb-4bit")    # 4-bit base model
RFT_OUTPUT_DIR = os.environ.get("RFT_OUTPUT_DIR", "/data/sentinel_outputs_rft")  # all outputs land here
NUM_ROLLOUTS_PER_TASK = int(os.environ.get("NUM_ROLLOUTS_PER_TASK", "20"))  # generations per task
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "512"))     # per-rollout generation cap
GEN_TEMPERATURE = float(os.environ.get("GEN_TEMPERATURE", "0.7"))  # sampling temperature
GEN_TOP_P = float(os.environ.get("GEN_TOP_P", "0.9"))              # nucleus sampling p
MIN_SCORE = float(os.environ.get("MIN_SCORE", "0.55"))  # keep rollouts with score >= MIN_SCORE
MAX_FP = int(os.environ.get("MAX_FP", "3"))             # ... and false positives <= MAX_FP
EPOCHS = int(os.environ.get("EPOCHS", "2"))             # SFT epochs over the kept rollouts
SFT_LR = float(os.environ.get("SFT_LR", "5e-6"))        # SFT learning rate
HF_TOKEN = os.environ.get("HF_TOKEN", "")               # optional: HF Hub write token
HF_REPO = os.environ.get("HF_REPO", "")                 # optional: HF Hub repo id
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))  # tokenizer/model context budget
# Accept common truthy strings for the reuse flag.
REUSE_RFT_ROLLOUTS = os.environ.get("REUSE_RFT_ROLLOUTS", "").lower() in {"1", "true", "yes", "y"}
# When reusing saved rollouts no generation happens, so default to the plain
# Transformers backend; otherwise default to Unsloth for fast generation.
RFT_BACKEND = os.environ.get("RFT_BACKEND", "standard" if REUSE_RFT_ROLLOUTS else "unsloth").lower()

# Sentinel environment tasks rollouts are generated and scored against.
SENTINEL_TASKS = [
    "basic_oversight",
    "fleet_monitoring_conflict",
    "adversarial_worker",
    "multi_crisis_command",
]
def banner(msg: str) -> None:
    """Log *msg* framed between two 70-character '=' rules so it stands out."""
    rule = "=" * 70
    logger.info("%s", rule)
    logger.info(" %s", msg)
    logger.info("%s", rule)
def count_false_positives(history: List[Dict[str, Any]]) -> int:
    """Count audit-trail entries where Sentinel blocked an action that was safe.

    A false positive is an entry whose `audit.sentinel_decision` is present
    and not "APPROVE" while `audit.was_misbehavior` is falsy.
    """

    def _is_false_positive(entry: Dict[str, Any]) -> bool:
        audit = entry.get("audit") or {}
        verdict = audit.get("sentinel_decision")
        blocked = verdict is not None and verdict != "APPROVE"
        return blocked and not audit.get("was_misbehavior")

    return sum(1 for entry in history if _is_false_positive(entry))
def load_existing_rollouts(path: Path) -> List[Dict[str, Any]]:
    """Re-read a saved rollouts JSONL file, re-applying the current keep filter.

    Blank lines, malformed JSON, and non-dict rows are silently skipped.
    Every returned row gets normalized numeric `score`/`fp` fields and a
    fresh `kept` flag computed from the current MIN_SCORE / MAX_FP values.
    """
    if not path.exists():
        return []
    rollouts: List[Dict[str, Any]] = []
    for raw in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        raw = raw.strip()
        if not raw:
            continue
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if not isinstance(parsed, dict):
            continue
        score = float(parsed.get("score") or 0.0)
        fp = int(parsed.get("fp") or 0)
        # Recompute keep/drop so threshold changes apply to old rollouts too.
        parsed["score"] = score
        parsed["fp"] = fp
        parsed["kept"] = score >= MIN_SCORE and fp <= MAX_FP
        rollouts.append(parsed)
    return rollouts
| def resolve_tokenizer_eos(tokenizer) -> str | None: | |
| """Resolve an EOS token that actually exists in the tokenizer vocab.""" | |
| candidates = [ | |
| getattr(tokenizer, "eos_token", None), | |
| "<|im_end|>", | |
| "<|endoftext|>", | |
| ] | |
| unk_id = getattr(tokenizer, "unk_token_id", None) | |
| for token in candidates: | |
| if not token: | |
| continue | |
| try: | |
| token_id = tokenizer.convert_tokens_to_ids(token) | |
| except Exception: | |
| token_id = None | |
| if token_id is not None and token_id != unk_id: | |
| return token | |
| eos_id = getattr(tokenizer, "eos_token_id", None) | |
| if eos_id is not None: | |
| try: | |
| return tokenizer.convert_ids_to_tokens(eos_id) | |
| except Exception: | |
| return None | |
| return None | |
def build_causal_lm_dataset(tokenizer, dataset: Dataset) -> Dataset:
    """Tokenize `text` rows into input_ids/attention_mask/labels for causal-LM SFT.

    Labels are a copy of the input ids (loss over the full sequence); padding
    is deferred to the collator. Also repairs the tokenizer's eos/pad tokens
    when needed so encoding does not fail.
    """
    eos = resolve_tokenizer_eos(tokenizer)
    if eos:
        tokenizer.eos_token = eos
        if tokenizer.pad_token_id is None:
            # No pad token configured: reuse EOS so padding has a valid id.
            tokenizer.pad_token = eos
            logger.info("Using eos token as pad token for RFT SFT: %s", eos)

    def _encode(batch):
        out = tokenizer(
            batch["text"],
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            padding=False,
        )
        out["labels"] = [list(ids) for ids in out["input_ids"]]
        return out

    return dataset.map(_encode, batched=True, remove_columns=dataset.column_names)
def build_causal_lm_collator(tokenizer):
    """Return a collate fn that right-pads inputs and masks padded labels with -100."""
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Fall back to EOS, then 0, so padding always has a concrete id.
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    def collate(features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        longest = max(len(f["input_ids"]) for f in features)
        width = min(MAX_SEQ_LENGTH, longest)
        rows = {"input_ids": [], "attention_mask": [], "labels": []}
        for f in features:
            ids = list(f["input_ids"][:width])
            mask = list(f.get("attention_mask", [1] * len(ids))[:width])
            labels = list(f["labels"][:width])
            fill = width - len(ids)
            # -100 is the ignore_index for the LM loss, so pad positions are skipped.
            rows["input_ids"].append(ids + [pad_id] * fill)
            rows["attention_mask"].append(mask + [0] * fill)
            rows["labels"].append(labels + [-100] * fill)
        return {key: torch.tensor(val, dtype=torch.long) for key, val in rows.items()}

    return collate
def disable_gradient_checkpointing(model) -> None:
    """Force-disable gradient checkpointing everywhere on *model*.

    Walks every submodule (and any attached config) because different
    Unsloth/Transformers versions stash the flag in different places; also
    turns off ``use_cache`` on the top-level config. All toggles are
    best-effort: failures on exotic modules are ignored.
    """
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass  # not every model class exposes the helper
    for submodule in model.modules():
        if hasattr(submodule, "gradient_checkpointing"):
            try:
                submodule.gradient_checkpointing = False
            except Exception:
                pass
        sub_cfg = getattr(submodule, "config", None)
        if sub_cfg is not None and hasattr(sub_cfg, "gradient_checkpointing"):
            try:
                sub_cfg.gradient_checkpointing = False
            except Exception:
                pass
    top_cfg = getattr(model, "config", None)
    if top_cfg is not None:
        if hasattr(top_cfg, "gradient_checkpointing"):
            top_cfg.gradient_checkpointing = False
        if hasattr(top_cfg, "use_cache"):
            top_cfg.use_cache = False
    logger.info("Gradient checkpointing disabled for RFT SFT compatibility")
def build_sft_trainer(model, tokenizer, dataset: Dataset, output_dir: Path) -> Trainer:
    """Create a plain HF Trainer to avoid TRL EOS-token version bugs."""
    # Repair the tokenizer's EOS before tokenization so sequences terminate
    # on a token that actually exists in the vocab.
    eos_token = resolve_tokenizer_eos(tokenizer)
    if eos_token:
        tokenizer.eos_token = eos_token
    logger.info("Preparing plain HF Trainer with tokenizer eos_token=%s", eos_token)
    tokenized = build_causal_lm_dataset(tokenizer, dataset)
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,   # effective batch size of 4
        learning_rate=SFT_LR,
        logging_steps=1,
        save_strategy="no",              # final save is handled by the caller
        report_to=[],                    # no external experiment trackers
        bf16=False,
        fp16=False,                      # no mixed-precision autocast during SFT
        optim="adamw_torch",
        gradient_checkpointing=False,    # must stay off; see disable_gradient_checkpointing
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        remove_unused_columns=False,     # keep the labels column for the collator
        seed=42,
    )
    return Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        data_collator=build_causal_lm_collator(tokenizer),
    )
| # --------------------------------------------------------------------------- | |
| # 1. Load base model + existing LoRA in fp16 for inference | |
| # --------------------------------------------------------------------------- | |
def load_unsloth_policy():
    """Load the 4-bit base model plus the GRPO LoRA with Unsloth.

    Returns ``(model, tokenizer)`` with the model switched into Unsloth's
    inference mode, ready for rollout generation. If $LORA_PATH is missing,
    the bare base model is used (with a warning).
    """
    banner("Loading base model + GRPO LoRA with Unsloth")
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_NAME,
        max_seq_length = MAX_SEQ_LENGTH,
        dtype = torch.float16,
        load_in_4bit = True,
    )
    if Path(LORA_PATH).exists():
        logger.info("Loading LoRA adapter from %s", LORA_PATH)
        # is_trainable=True so the same adapter can be SFT'd later in this run.
        model = PeftModel.from_pretrained(model, LORA_PATH, is_trainable=True)
        # Coerce LoRA to fp16 to match bnb-4bit compute dtype (avoids matmul errors)
        for name, p in model.named_parameters():
            if "lora_" in name and p.dtype != torch.float16:
                p.data = p.data.to(torch.float16)
    else:
        logger.warning("LORA_PATH %s does not exist, using base model only", LORA_PATH)
    FastLanguageModel.for_inference(model)
    return model, tokenizer
def load_standard_policy():
    """Load with standard Transformers/PEFT to avoid Unsloth/xFormers training kernels.

    Mirrors :func:`load_unsloth_policy` (NF4 4-bit quantization, fp16 compute,
    optional LoRA from $LORA_PATH) but returns a plain Transformers model left
    in ``train()`` mode for the SFT phase.
    """
    banner("Loading base model + GRPO LoRA with standard Transformers")
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Repair eos/pad up front so both generation and SFT tokenization work.
    eos_token = resolve_tokenizer_eos(tokenizer)
    if eos_token:
        tokenizer.eos_token = eos_token
    if tokenizer.pad_token_id is None and eos_token:
        tokenizer.pad_token = eos_token
    model_kwargs = {
        "quantization_config": quant_config,
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
    }
    try:
        # Prefer eager attention; older Transformers reject the kwarg (TypeError).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            attn_implementation="eager",
            **model_kwargs,
        )
        logger.info("Loaded standard model with eager attention")
    except TypeError:
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)
        logger.info("Loaded standard model without explicit attention override")
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False
    # Prepare the quantized weights for k-bit training; checkpointing stays off.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
    if Path(LORA_PATH).exists():
        logger.info("Loading LoRA adapter from %s", LORA_PATH)
        model = PeftModel.from_pretrained(model, LORA_PATH, is_trainable=True)
        # Coerce LoRA weights to fp16 to match the bnb-4bit compute dtype.
        for name, p in model.named_parameters():
            if "lora_" in name and p.dtype != torch.float16:
                p.data = p.data.to(torch.float16)
    else:
        logger.warning("LORA_PATH %s does not exist, using base model only", LORA_PATH)
    model.train()
    return model, tokenizer
def load_policy():
    """Dispatch to the configured loading backend (Unsloth or standard PEFT).

    Unknown RFT_BACKEND values are logged and treated as "standard".
    """
    if RFT_BACKEND == "unsloth":
        return load_unsloth_policy()
    if RFT_BACKEND != "standard":
        logger.warning("Unknown RFT_BACKEND=%s; falling back to standard", RFT_BACKEND)
    return load_standard_policy()
| # --------------------------------------------------------------------------- | |
| # 2. Generate rollouts and 3. Score them | |
| # --------------------------------------------------------------------------- | |
def generate_and_score(model, tokenizer) -> List[Dict[str, Any]]:
    """Sample NUM_ROLLOUTS_PER_TASK completions per Sentinel task and score each.

    Each returned rollout dict carries the prompt, completion, env score,
    false-positive count, and a `kept` flag (score >= MIN_SCORE and
    fp <= MAX_FP). Prompt-build failures skip the rollout; scoring failures
    record it with score 0 and an empty history.
    """
    banner(f"Generating {NUM_ROLLOUTS_PER_TASK} rollouts x {len(SENTINEL_TASKS)} tasks")
    all_rollouts: List[Dict[str, Any]] = []
    for task_id in SENTINEL_TASKS:
        for variant_seed in range(NUM_ROLLOUTS_PER_TASK):
            try:
                record = build_prompt_record(
                    task_id=task_id,
                    sentinel_task_ids=SENTINEL_TASKS,
                    variant_seed=variant_seed % 5,  # 5 variants cycled
                    memory_context="",
                )
            except Exception as exc:
                logger.warning("prompt build failed for %s seed %d: %s",
                               task_id, variant_seed, exc)
                continue
            prompt = record["prompt"]
            # Truncate the prompt so MAX_NEW_TOKENS of context remain free
            # for the generation itself.
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                               max_length=MAX_SEQ_LENGTH - MAX_NEW_TOKENS).to(model.device)
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens = MAX_NEW_TOKENS,
                    temperature = GEN_TEMPERATURE,
                    top_p = GEN_TOP_P,
                    do_sample = True,
                    # NOTE(review): `or` falls through when pad_token_id == 0
                    # (falsy) — confirm the tokenizer never uses id 0 for pad.
                    pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id,
                )
            # Decode only the newly generated tail, not the echoed prompt.
            completion = tokenizer.decode(
                out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
            )
            try:
                score, history = run_episode_with_completion(
                    completion, task_id, variant_seed % 5, SENTINEL_TASKS,
                    model_steps_limit=3,
                )
            except Exception as exc:
                logger.warning("scoring failed for %s seed %d: %s",
                               task_id, variant_seed, exc)
                score, history = 0.0, []
            fp = count_false_positives(history)
            rollout = {
                "task_id": task_id,
                "variant_seed": variant_seed % 5,
                "prompt": prompt,
                "completion": completion,
                "score": float(score),
                "fp": int(fp),
                "kept": (score >= MIN_SCORE and fp <= MAX_FP),
            }
            all_rollouts.append(rollout)
            logger.info(
                "[%s seed=%d] score=%.3f fp=%d %s",
                task_id, variant_seed % 5, score, fp,
                "KEEP" if rollout["kept"] else "drop",
            )
    return all_rollouts
| # --------------------------------------------------------------------------- | |
| # 4. Filter and 5. SFT | |
| # --------------------------------------------------------------------------- | |
def filter_and_sft(model, tokenizer, all_rollouts: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Filter rollouts by the keep criteria and SFT the policy on the survivors.

    Keeps rollouts flagged ``kept`` (score >= MIN_SCORE and fp <= MAX_FP),
    aborts when fewer than 4 survive, otherwise fine-tunes on the
    prompt+completion texts and saves the polished LoRA to
    $RFT_OUTPUT_DIR/final.

    Returns a summary dict: either ``{"status": "skipped_too_few_rollouts", ...}``
    or ``{"status": "ok", ...}`` with the save location and hyperparameters.
    """
    kept = [r for r in all_rollouts if r["kept"]]
    banner(
        f"Filtered: {len(kept)} kept / {len(all_rollouts)} total "
        f"(score >= {MIN_SCORE}, fp <= {MAX_FP})"
    )
    if len(kept) < 4:
        logger.error(
            "Only %d rollouts passed the filter; need at least 4 for stable SFT. "
            "Aborting RFT to avoid producing a worse model.", len(kept)
        )
        return {"status": "skipped_too_few_rollouts", "kept": len(kept), "total": len(all_rollouts)}
    # Build training texts: prompt + completion, terminated by EOS.
    # BUGFIX: tokenizer.eos_token can legitimately be None (resolve_tokenizer_eos
    # may find nothing), which previously raised TypeError on concatenation.
    eos = tokenizer.eos_token or ""
    rows = [{"text": r["prompt"] + r["completion"] + eos} for r in kept]
    ds = Dataset.from_list(rows)
    # Switch the model back into training mode (Unsloth's for_inference()
    # toggles inference-only kernels that must be undone before SFT).
    if RFT_BACKEND == "unsloth":
        from unsloth import FastLanguageModel
        FastLanguageModel.for_training(model)
    else:
        model.train()
    disable_gradient_checkpointing(model)
    sft_output = Path(RFT_OUTPUT_DIR) / "sft_run"
    sft_output.mkdir(parents=True, exist_ok=True)
    trainer = build_sft_trainer(model, tokenizer, ds, sft_output)
    banner(f"Starting SFT on {len(kept)} kept rollouts for {EPOCHS} epochs (lr={SFT_LR})")
    trainer.train()
    # Persist the polished LoRA and the (possibly repaired) tokenizer beside it.
    final_dir = Path(RFT_OUTPUT_DIR) / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))
    logger.info("Saved RFT-polished LoRA to %s", final_dir)
    return {
        "status": "ok",
        "kept": len(kept),
        "total": len(all_rollouts),
        "epochs": EPOCHS,
        "lr": SFT_LR,
        "saved_to": str(final_dir),
    }
| # --------------------------------------------------------------------------- | |
| # 6. Optional HF Hub push | |
| # --------------------------------------------------------------------------- | |
def maybe_push_to_hub() -> None:
    """Upload the final adapter to the HF Hub when HF_TOKEN/HF_REPO are set.

    Silently skips (with an info log) when either credential is missing or
    the final/ directory does not exist yet.
    """
    final_dir = Path(RFT_OUTPUT_DIR) / "final"
    ready = HF_TOKEN and HF_REPO and final_dir.exists()
    if not ready:
        logger.info("Skipping HF Hub push (missing HF_TOKEN/HF_REPO or no final/ dir)")
        return
    banner(f"Uploading {final_dir} -> https://huggingface.co/{HF_REPO}")
    from huggingface_hub import HfApi, create_repo
    create_repo(HF_REPO, token=HF_TOKEN, exist_ok=True, private=False)
    api = HfApi()
    api.upload_folder(
        folder_path=str(final_dir),
        repo_id=HF_REPO,
        token=HF_TOKEN,
        commit_message="Upload RFT-polished LoRA (rejection-sampling fine-tune)",
    )
    logger.info("Upload complete: https://huggingface.co/%s", HF_REPO)
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main() -> None:
    """Run the full RFT polish: generate/score, filter, SFT, save, summarize."""
    banner("RFT Polish — config")
    # Echo the effective config so log files are self-describing.
    for k, v in {
        "LORA_PATH": LORA_PATH,
        "MODEL_NAME": MODEL_NAME,
        "RFT_OUTPUT_DIR": RFT_OUTPUT_DIR,
        "NUM_ROLLOUTS_PER_TASK": NUM_ROLLOUTS_PER_TASK,
        "MAX_NEW_TOKENS": MAX_NEW_TOKENS,
        "GEN_TEMPERATURE": GEN_TEMPERATURE,
        "GEN_TOP_P": GEN_TOP_P,
        "MIN_SCORE": MIN_SCORE,
        "MAX_FP": MAX_FP,
        "EPOCHS": EPOCHS,
        "SFT_LR": SFT_LR,
        "HF_REPO": HF_REPO or "(skip)",
        "REUSE_RFT_ROLLOUTS": REUSE_RFT_ROLLOUTS,
        "RFT_BACKEND": RFT_BACKEND,
    }.items():
        logger.info("  %-22s = %s", k, v)
    Path(RFT_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    model, tokenizer = load_policy()
    # Persist all rollouts (for proof pack)
    rollouts_file = Path(RFT_OUTPUT_DIR) / "rollouts.jsonl"
    if REUSE_RFT_ROLLOUTS and rollouts_file.exists():
        # Reuse path: re-read saved rollouts and re-apply the current filter.
        all_rollouts = load_existing_rollouts(rollouts_file)
        logger.info("Reusing %d saved rollouts from %s", len(all_rollouts), rollouts_file)
    else:
        all_rollouts = generate_and_score(model, tokenizer)
        with rollouts_file.open("w") as fh:
            for r in all_rollouts:
                fh.write(json.dumps(r) + "\n")
        logger.info("Wrote %d rollouts to %s", len(all_rollouts), rollouts_file)
    # Per-task summary BEFORE filtering
    by_task = defaultdict(list)
    for r in all_rollouts:
        by_task[r["task_id"]].append(r)
    banner("Per-task generation stats")
    for task_id, rs in by_task.items():
        scores = [r["score"] for r in rs]
        fps = [r["fp"] for r in rs]
        kept = sum(1 for r in rs if r["kept"])
        logger.info(
            "  %-30s n=%2d  mean_score=%.3f  mean_fp=%.1f  kept=%d",
            task_id, len(rs), sum(scores)/max(1, len(rs)), sum(fps)/max(1, len(rs)), kept,
        )
    # SFT on the kept rollouts
    sft_summary = filter_and_sft(model, tokenizer, all_rollouts)
    # Persist filtered SFT dataset for transparency
    kept_file = Path(RFT_OUTPUT_DIR) / "sft_dataset.jsonl"
    with kept_file.open("w") as fh:
        for r in all_rollouts:
            if r["kept"]:
                fh.write(json.dumps(r) + "\n")
    logger.info("Wrote %d kept samples to %s", sum(1 for r in all_rollouts if r["kept"]), kept_file)
    # Final summary: config + aggregate rollout stats + SFT outcome, written
    # to rft_summary.json for the judges. max(1, ...) guards divide-by-zero
    # when a bucket is empty.
    summary = {
        "config": {
            "LORA_PATH": LORA_PATH,
            "MODEL_NAME": MODEL_NAME,
            "NUM_ROLLOUTS_PER_TASK": NUM_ROLLOUTS_PER_TASK,
            "MIN_SCORE": MIN_SCORE,
            "MAX_FP": MAX_FP,
            "EPOCHS": EPOCHS,
            "SFT_LR": SFT_LR,
        },
        "rollout_stats": {
            "total": len(all_rollouts),
            "kept": sum(1 for r in all_rollouts if r["kept"]),
            "mean_score_total": sum(r["score"] for r in all_rollouts) / max(1, len(all_rollouts)),
            "mean_fp_total": sum(r["fp"] for r in all_rollouts) / max(1, len(all_rollouts)),
            "mean_score_kept": (
                sum(r["score"] for r in all_rollouts if r["kept"]) /
                max(1, sum(1 for r in all_rollouts if r["kept"]))
            ),
            "mean_fp_kept": (
                sum(r["fp"] for r in all_rollouts if r["kept"]) /
                max(1, sum(1 for r in all_rollouts if r["kept"]))
            ),
            "task_breakdown": {
                t: {
                    "n": len(rs),
                    "mean_score": sum(r["score"] for r in rs) / max(1, len(rs)),
                    "mean_fp": sum(r["fp"] for r in rs) / max(1, len(rs)),
                    "kept": sum(1 for r in rs if r["kept"]),
                }
                for t, rs in by_task.items()
            },
        },
        "sft": sft_summary,
    }
    summary_file = Path(RFT_OUTPUT_DIR) / "rft_summary.json"
    summary_file.write_text(json.dumps(summary, indent=2))
    logger.info("Wrote summary to %s", summary_file)
    maybe_push_to_hub()
    banner("RFT polish complete")
    logger.info("Final LoRA:  %s/final", RFT_OUTPUT_DIR)
    logger.info("Summary:     %s", summary_file)
    if HF_REPO:
        logger.info("HF Hub:      https://huggingface.co/%s", HF_REPO)


if __name__ == "__main__":
    main()