ScottzillaSystems committed on
Commit
fa60c5e
·
verified ·
1 Parent(s): cf520c9

Upload examples/demo_sft_self_healing.py

Browse files
Files changed (1) hide show
  1. examples/demo_sft_self_healing.py +121 -0
examples/demo_sft_self_healing.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Demo: Self-Healing SFT Training
4
+ ===============================
5
+ Loads Qwen2.5-0.5B, trains on Capybara dataset with full self-healing.
6
+
7
+ Usage:
8
+ python demo_sft_self_healing.py
9
+
10
+ Requirements:
11
+ pip install transformers trl datasets torch self-healing-training
12
+ """
13
+ import os, sys, json, time
14
+ import torch
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+ from datasets import load_dataset
17
+
18
+ from self_healing import SelfHealingTrainer, HealingConfig
19
+
20
def main():
    """Run the end-to-end self-healing SFT demo.

    Steps, in order:
      1. Load the ``Qwen/Qwen2.5-0.5B`` model (bfloat16) and its tokenizer.
      2. Load the first 2000 training rows of ``trl-lib/Capybara``.
      3. Build a TRL ``SFTTrainer`` and wrap it in ``SelfHealingTrainer``.
      4. Run a 2-step dry run, then train and print the healing report.

    If the dry run raises, the error is printed and the process exits with
    status 1. After training, a one-line summary of the postmortem JSON is
    printed when that file exists.
    """
    import inspect

    from trl import SFTConfig, SFTTrainer

    print("\n" + "=" * 60)
    print(" SELF-HEALING SFT TRAINING DEMO")
    print("=" * 60 + "\n")

    # Model
    model_id = "Qwen/Qwen2.5-0.5B"
    print(f"[1/4] Loading model: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Dataset
    print("[2/4] Loading dataset: trl-lib/Capybara")
    dataset = load_dataset("trl-lib/Capybara", split="train[:2000]")

    # Config
    training_args = SFTConfig(
        output_dir="./sft-output",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        max_steps=200,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        # NOTE(review): save_steps exceeds max_steps, so no step-based
        # checkpoint is written during this run -- confirm this is intended.
        save_steps=500,
        bf16=True,
        report_to="none",  # Set to "trackio" for live monitoring
        run_name="selfheal-sft-demo",
        disable_tqdm=True,
    )

    # TRL renamed the trainer's `tokenizer` argument to `processing_class`;
    # pass the tokenizer under whichever name the installed version accepts
    # so the demo runs on both older and newer TRL releases.
    trainer_params = inspect.signature(SFTTrainer.__init__).parameters
    tokenizer_kwarg = (
        "processing_class" if "processing_class" in trainer_params else "tokenizer"
    )
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        **{tokenizer_kwarg: tokenizer},
    )

    # Self-healing wrapper
    print("[3/4] Wrapping with SelfHealingTrainer...")
    healing_config = HealingConfig(
        nan_patience=3,
        loss_spike_factor=5.0,
        divergence_patience=50,
        max_recovery_attempts=5,
        max_lr_reductions=3,
        max_batch_reductions=2,
        zclip_enabled=True,
        zclip_z_threshold=3.0,
        postmortem_path="./sft-postmortem.json",
    )

    sh_trainer = SelfHealingTrainer(trainer, healing_config)

    # Dry-run validation: abort before the real run if the setup is broken.
    # The broad catch is deliberate -- this is the script's top-level boundary
    # and any failure here should be reported and end the demo.
    try:
        sh_trainer.dry_run(num_steps=2)
        print(" ✓ Dry-run passed!\n")
    except Exception as e:
        print(f" ✗ Dry-run failed: {e}")
        sys.exit(1)

    # Train
    print("[4/4] Training with self-healing...\n")
    sh_trainer.train()

    # Report
    print("\n" + "=" * 60)
    print(" DEMO COMPLETE")
    print("=" * 60)
    report = sh_trainer.get_report()
    print(f" Converged: {report['converged']}")
    print(f" Attempts: {report['attempts']}")
    print(f" Recoveries: {report['total_recoveries']}")
    print(f" ZClip clips: {report['zclip_total_clips']}")
    print(f" NaN count: {report['nan_count']}")
    print(f" LR reductions: {report['lr_reductions']}")

    if report["recovery_history"]:
        print("\n Recovery log:")
        for i, rec in enumerate(report["recovery_history"], start=1):
            print(f" [{i}] {rec['failure']}: {rec['actions']}")

    # Summarize the postmortem file if one was written; its absence is not
    # an error, so a missing file is simply skipped.
    try:
        with open(healing_config.postmortem_path, encoding="utf-8") as f:
            pm = json.load(f)
    except FileNotFoundError:
        pass
    else:
        print(f"\n Postmortem: {pm.get('exit_reason', 'unknown')} "
              f"at step {pm.get('last_step', '?')}")
118
+
119
+
120
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()