#!/usr/bin/env python3
"""
Demo: Self-Healing DPO Training
===============================

Loads a pretrained model and runs DPO with full self-healing.

DPO-specific: detects a plateau at loss ≈ 0.693 (= ln 2, random chance).
That is the DPO loss when the implicit reward margin between the chosen and
rejected responses is zero, i.e. the policy has learned no preference at all.

Usage:
    python demo_dpo_self_healing.py
"""

import inspect
import json
import os
import sys
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from self_healing import SelfHealingTrainer, HealingConfig


def _tokenizer_kwargs(trainer_cls, tokenizer):
    """Return the keyword argument that hands *tokenizer* to *trainer_cls*.

    TRL 0.12 renamed the trainer's ``tokenizer`` argument to
    ``processing_class``, and later releases no longer accept the old name.
    Older releases, in turn, do not know the new name. This picks whichever
    name the installed trainer's constructor declares.

    Args:
        trainer_cls: Trainer class to inspect (e.g. ``trl.DPOTrainer``).
        tokenizer: Tokenizer instance to pass through.

    Returns:
        A one-item dict suitable for ``**`` expansion into the constructor.
    """
    params = inspect.signature(trainer_cls.__init__).parameters
    if "processing_class" in params:
        return {"processing_class": tokenizer}
    return {"tokenizer": tokenizer}


def main():
    """Run the self-healing DPO demo end to end.

    Steps:
        1. Load the model and tokenizer.
        2. Load the first 2,000 training examples of the preference dataset.
        3. Build a TRL ``DPOTrainer`` and wrap it in ``SelfHealingTrainer``.
        4. Dry-run it, train it, then print the recovery report and, if one
           was written, the postmortem summary.

    Exits the process with status 1 if the dry-run raises.
    """
    print("\n" + "=" * 60)
    print(" SELF-HEALING DPO TRAINING DEMO")
    print("=" * 60 + "\n")

    model_id = "Qwen/Qwen2.5-0.5B"
    print(f"[1/4] Loading model: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Batches need a pad token; reuse EOS when the tokenizer has none.
        tokenizer.pad_token = tokenizer.eos_token

    # DPO dataset: needs "prompt", "chosen", "rejected"
    print("[2/4] Loading DPO dataset: trl-lib/ultrafeedback_binarized")
    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:2000]")

    from trl import DPOConfig, DPOTrainer

    training_args = DPOConfig(
        output_dir="./dpo-output",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # effective batch: 8 pairs per device
        learning_rate=5e-6,
        max_steps=200,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=500,  # exceeds max_steps: no mid-run checkpoint at a save_steps boundary
        bf16=True,
        beta=0.1,  # DPO temperature
        report_to="none",
        run_name="selfheal-dpo-demo",
        disable_tqdm=True,
    )

    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        # `tokenizer=` vs `processing_class=` depends on the installed TRL.
        **_tokenizer_kwargs(DPOTrainer, tokenizer),
    )

    # Self-healing: more aggressive for DPO (plateau detection is critical)
    # NOTE(review): no plateau-specific option is set below; presumably plateau
    # handling is covered by divergence_patience — confirm against HealingConfig.
    print("[3/4] Wrapping with SelfHealingTrainer...")
    healing_config = HealingConfig(
        nan_patience=3,
        loss_spike_factor=5.0,
        divergence_patience=50,
        max_recovery_attempts=5,
        max_lr_reductions=3,
        max_batch_reductions=2,
        zclip_enabled=True,
        zclip_z_threshold=3.0,
        postmortem_path="./dpo-postmortem.json",
    )
    sh_trainer = SelfHealingTrainer(trainer, healing_config)

    # Dry-run a couple of steps before starting the full run.
    try:
        sh_trainer.dry_run(num_steps=2)
    except Exception as e:  # top-level boundary: report and abort the demo
        print(f" ✗ Dry-run failed: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print(" ✓ Dry-run passed!\n")

    # Train
    print("[4/4] Training DPO with self-healing...\n")
    sh_trainer.train()

    # Report
    print("\n" + "=" * 60)
    print(" DPO DEMO COMPLETE")
    print("=" * 60)
    report = sh_trainer.get_report()
    print(f" Converged: {report['converged']}")
    print(f" Attempts: {report['attempts']}")
    print(f" Recoveries: {report['total_recoveries']}")
    if report["recovery_history"]:
        print("\n Recovery log:")
        for i, rec in enumerate(report["recovery_history"]):
            print(f" [{i+1}] {rec['failure']}: {rec['actions']}")

    # The postmortem file is optional. A corrupt one should not crash the demo
    # after training has already finished.
    try:
        with open(healing_config.postmortem_path, encoding="utf-8") as f:
            pm = json.load(f)
    except FileNotFoundError:
        return  # no postmortem was written for this run
    except (OSError, json.JSONDecodeError) as e:
        print(f"\n Postmortem: unreadable ({e})", file=sys.stderr)
        return
    print(f"\n Postmortem: {pm.get('exit_reason', 'unknown')} "
          f"at step {pm.get('last_step', '?')}")


if __name__ == "__main__":
    main()