"""
Demo: Self-Healing SFT Training
===============================
Loads Qwen2.5-0.5B, trains on Capybara dataset with full self-healing.

Usage:
    python demo_sft_self_healing.py

Requirements:
    pip install transformers trl datasets torch self-healing-training
"""
# Standard library
import json
import os
import sys
import time  # NOTE(review): unused in this file — kept in case a later chunk uses it

# Third-party
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local
from self_healing import SelfHealingTrainer, HealingConfig
|
|
def main():
    """Run the self-healing SFT demo end to end.

    Loads Qwen2.5-0.5B in bf16, fine-tunes it on a 2000-sample slice of
    the trl-lib/Capybara dataset using TRL's ``SFTTrainer``, wraps the
    trainer in ``SelfHealingTrainer`` for automatic failure recovery,
    and prints a summary report plus any on-disk postmortem.

    Exits the process with status 1 if the pre-flight dry-run fails.
    """
    print("\n" + "=" * 60)
    print(" SELF-HEALING SFT TRAINING DEMO")
    print("=" * 60 + "\n")

    # --- Model & tokenizer -------------------------------------------------
    model_id = "Qwen/Qwen2.5-0.5B"
    print(f"[1/4] Loading model: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Qwen tokenizers may ship without a pad token; SFT batching needs one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- Dataset -----------------------------------------------------------
    print("[2/4] Loading dataset: trl-lib/Capybara")
    dataset = load_dataset("trl-lib/Capybara", split="train[:2000]")

    # Imported lazily so the banner/model steps run even if trl is missing
    # until this point.
    from trl import SFTConfig, SFTTrainer

    training_args = SFTConfig(
        output_dir="./sft-output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        max_steps=200,  # short demo run; raise for a real fine-tune
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=500,  # > max_steps, so no mid-run checkpoints are written
        bf16=True,
        report_to="none",
        run_name="selfheal-sft-demo",
        disable_tqdm=True,
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    # --- Self-healing wrapper ----------------------------------------------
    print("[3/4] Wrapping with SelfHealingTrainer...")
    healing_config = HealingConfig(
        nan_patience=3,
        loss_spike_factor=5.0,
        divergence_patience=50,
        max_recovery_attempts=5,
        max_lr_reductions=3,
        max_batch_reductions=2,
        zclip_enabled=True,
        zclip_z_threshold=3.0,
        postmortem_path="./sft-postmortem.json",
    )

    sh_trainer = SelfHealingTrainer(trainer, healing_config)

    # Pre-flight: run a couple of steps to surface config/OOM errors early.
    # Broad catch is deliberate — this is the script's top-level boundary.
    try:
        sh_trainer.dry_run(num_steps=2)
        print(" ✓ Dry-run passed!\n")
    except Exception as e:
        print(f" ✗ Dry-run failed: {e}")
        sys.exit(1)

    # --- Training ----------------------------------------------------------
    print("[4/4] Training with self-healing...\n")
    sh_trainer.train()  # return value unused; the report below has the details

    # --- Report ------------------------------------------------------------
    print("\n" + "=" * 60)
    print(" DEMO COMPLETE")
    print("=" * 60)
    report = sh_trainer.get_report()
    print(f" Converged: {report['converged']}")
    print(f" Attempts: {report['attempts']}")
    print(f" Recoveries: {report['total_recoveries']}")
    print(f" ZClip clips: {report['zclip_total_clips']}")
    print(f" NaN count: {report['nan_count']}")
    print(f" LR reductions: {report['lr_reductions']}")

    if report["recovery_history"]:
        print("\n Recovery log:")
        for i, rec in enumerate(report["recovery_history"]):
            print(f" [{i+1}] {rec['failure']}: {rec['actions']}")

    # Postmortem file is only written when training hit a terminal failure.
    if os.path.exists(healing_config.postmortem_path):
        with open(healing_config.postmortem_path) as f:
            pm = json.load(f)
        print(f"\n Postmortem: {pm.get('exit_reason', 'unknown')} "
              f"at step {pm.get('last_step', '?')}")
|
|
|
|
# Script entry point: only run the demo when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()