#!/usr/bin/env python3
"""
Demo: Self-Healing SFT Training
===============================

Loads Qwen2.5-0.5B, trains on Capybara dataset with full self-healing.

Usage:
    python demo_sft_self_healing.py

Requirements:
    pip install transformers trl datasets torch self-healing-training
"""

import json
import os
import sys
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from self_healing import SelfHealingTrainer, HealingConfig

# Hub identifiers used by the demo; change these to try another model/dataset.
MODEL_ID = "Qwen/Qwen2.5-0.5B"
DATASET_NAME = "trl-lib/Capybara"
DATASET_SPLIT = "train[:2000]"

_BANNER_WIDTH = 60


def _load_model_and_tokenizer(model_id):
    """Load the causal LM and its tokenizer for ``model_id``.

    The model is loaded in bfloat16 with ``device_map="auto"``. If the
    tokenizer defines no pad token, the EOS token is reused as pad token.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _build_sft_trainer(model, tokenizer, dataset):
    """Create the TRL ``SFTTrainer`` with the demo's training arguments."""
    # Local import: trl is only needed once the trainer is constructed.
    from trl import SFTConfig, SFTTrainer

    training_args = SFTConfig(
        output_dir="./sft-output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        max_steps=200,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=500,
        bf16=True,
        report_to="none",  # Set to "trackio" for live monitoring
        run_name="selfheal-sft-demo",
        disable_tqdm=True,
    )
    # NOTE(review): recent TRL releases appear to take the tokenizer as
    # `processing_class=` rather than `tokenizer=` -- confirm against the
    # installed trl version before upgrading.
    return SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )


def _build_healing_config():
    """Return the ``HealingConfig`` used by the demo.

    The exact meaning of each threshold is defined by the ``self_healing``
    package; only the values are fixed here.
    """
    return HealingConfig(
        nan_patience=3,
        loss_spike_factor=5.0,
        divergence_patience=50,
        max_recovery_attempts=5,
        max_lr_reductions=3,
        max_batch_reductions=2,
        zclip_enabled=True,
        zclip_z_threshold=3.0,
        postmortem_path="./sft-postmortem.json",
    )


def _print_report(report):
    """Print the summary fields and recovery log from a healing report.

    Args:
        report: Mapping returned by ``SelfHealingTrainer.get_report()``; the
            keys read below must be present.
    """
    print(f" Converged: {report['converged']}")
    print(f" Attempts: {report['attempts']}")
    print(f" Recoveries: {report['total_recoveries']}")
    print(f" ZClip clips: {report['zclip_total_clips']}")
    print(f" NaN count: {report['nan_count']}")
    print(f" LR reductions: {report['lr_reductions']}")

    if report["recovery_history"]:
        print("\n Recovery log:")
        for index, rec in enumerate(report["recovery_history"], start=1):
            print(f" [{index}] {rec['failure']}: {rec['actions']}")


def _print_postmortem(path):
    """Print exit reason and last step from the postmortem JSON, if present.

    Does nothing when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return
    # JSON is UTF-8 by specification; do not rely on the locale default.
    with open(path, encoding="utf-8") as f:
        pm = json.load(f)
    print(f"\n Postmortem: {pm.get('exit_reason', 'unknown')} "
          f"at step {pm.get('last_step', '?')}")


def main() -> None:
    """Run the demo: load model and data, wrap the trainer, train, report.

    Exits the process with status 1 if the dry-run validation raises.
    """
    print("\n" + "=" * _BANNER_WIDTH)
    print(" SELF-HEALING SFT TRAINING DEMO")
    print("=" * _BANNER_WIDTH + "\n")

    # Model
    print(f"[1/4] Loading model: {MODEL_ID}")
    model, tokenizer = _load_model_and_tokenizer(MODEL_ID)

    # Dataset
    print(f"[2/4] Loading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)

    # Config
    trainer = _build_sft_trainer(model, tokenizer, dataset)

    # Self-healing wrapper
    print("[3/4] Wrapping with SelfHealingTrainer...")
    healing_config = _build_healing_config()
    sh_trainer = SelfHealingTrainer(trainer, healing_config)

    # Dry-run validation. Only the dry run itself is guarded, so a failure
    # while printing the success line is not misreported as a dry-run failure.
    try:
        sh_trainer.dry_run(num_steps=2)
    except Exception as exc:  # top-level boundary: report and abort the demo
        print(f" ✗ Dry-run failed: {exc}")
        sys.exit(1)
    else:
        print(" ✓ Dry-run passed!\n")

    # Train (the return value is not used; results come from get_report()).
    print("[4/4] Training with self-healing...\n")
    sh_trainer.train()

    # Report
    print("\n" + "=" * _BANNER_WIDTH)
    print(" DEMO COMPLETE")
    print("=" * _BANNER_WIDTH)
    _print_report(sh_trainer.get_report())

    # Show the postmortem summary, if one was written.
    _print_postmortem(healing_config.postmortem_path)


if __name__ == "__main__":
    main()