ScottzillaSystems committed on
Commit
1da48e5
·
verified ·
1 Parent(s): 354e067

Upload run_self_healing_job.py

Browse files
Files changed (1) hide show
  1. run_self_healing_job.py +107 -0
run_self_healing_job.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Production Job Runner: Self-Healing Training
4
+ ============================================
5
+ Complete self-healing training job with Trackio monitoring.
6
+
7
+ Pre-flight:
8
+ ✓ Dataset: trl-lib/Capybara (messages format, SFT-compatible)
9
+ ✓ Model: Qwen2.5-0.5B
10
+ ✓ Trackio monitoring
11
+ ✓ Timeout: 2h for 0.5B model on a10g-large
12
+ """
13
+ import os, sys, json, time, math, gc
14
+ import torch
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+ from datasets import load_dataset
17
+
18
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
19
+ from self_healing import SelfHealingTrainer, HealingConfig
20
+
21
def main():
    """Run the self-healing SFT job end to end and print a summary.

    Steps, in order:
      1. Load ``Qwen/Qwen2.5-0.5B`` in bfloat16 together with its tokenizer.
      2. Load the first 5000 rows of ``trl-lib/Capybara``.
      3. Build a TRL ``SFTTrainer`` and wrap it in ``SelfHealingTrainer``.
      4. Run a 2-step dry run; exit with status 1 if it raises.
      5. Train, then print the healing report, the postmortem file (if one
         was written) and where the Trackio metrics can be viewed.

    Environment:
        TRACKIO_SPACE_ID: optional Hugging Face Space used for the Trackio
            dashboard. When unset or empty, no Space is passed to the
            trainer and no Space URL is printed.

    Raises:
        SystemExit: with code 1 when the dry-run validation fails.
    """
    banner = "=" * 70
    print("\n" + banner)
    print(" SELF-HEALING TRAINING JOB")
    print(banner + "\n")

    # Resolve the Trackio Space exactly once so the value handed to the
    # trainer and the dashboard hint printed at the end can never disagree.
    # An unset or empty variable becomes None instead of an empty string.
    trackio_space_id = os.environ.get("TRACKIO_SPACE_ID") or None

    model_id = "Qwen/Qwen2.5-0.5B"
    print(f"[1/4] Loading {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("[2/4] Loading trl-lib/Capybara")
    dataset = load_dataset("trl-lib/Capybara", split="train[:5000]")
    print(f" {len(dataset)} samples")

    # Imported here, as in the original script, so TRL is only required once
    # the model and dataset have loaded successfully.
    from trl import SFTConfig, SFTTrainer

    training_args = SFTConfig(
        output_dir="./output",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        max_steps=500,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=200,
        save_total_limit=3,
        bf16=True,
        gradient_checkpointing=True,
        warmup_steps=50,
        lr_scheduler_type="cosine",
        report_to="trackio",
        run_name="selfheal-qwen0.5b-sft",
        project="self-healing-system",
        trackio_space_id=trackio_space_id,
        push_to_hub=False,
        disable_tqdm=True,
    )

    # `processing_class` is the current name of the former `tokenizer`
    # keyword; TRL releases recent enough for the Trackio options above no
    # longer accept `tokenizer=`.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print("[3/4] Configuring self-healing")
    healing_config = HealingConfig(
        nan_patience=3,
        loss_spike_factor=5.0,
        divergence_patience=50,
        grad_explosion_threshold=100.0,
        zclip_enabled=True,
        zclip_z_threshold=3.0,
        max_recovery_attempts=5,
        max_lr_reductions=3,
        max_batch_reductions=2,
        postmortem_path="./postmortem.json",
    )
    sh_trainer = SelfHealingTrainer(trainer, healing_config)

    print("[4/4] Dry-run validation")
    try:
        sh_trainer.dry_run(num_steps=2)
    except Exception as e:
        # Top-level boundary: any dry-run failure aborts the job before the
        # real training run starts.
        print(f" ✗ Dry-run FAILED: {e}")
        sys.exit(1)
    else:
        print(" ✓ Dry-run passed!\n")

    print("Training with autonomous self-healing...\n")
    start = time.time()
    # The return value is not needed; the summary is read from get_report().
    sh_trainer.train()
    elapsed = time.time() - start

    print("\n" + banner)
    print(" TRAINING COMPLETE")
    print(banner)
    report = sh_trainer.get_report()
    print(f" Converged: {sh_trainer.converged} | Attempts: {sh_trainer.attempt}")
    print(f" Recoveries: {report['total_recoveries']} | ZClip: {report['zclip_total_clips']}")
    print(f" NaN: {report['nan_count']} | LR cuts: {report['lr_reductions']}")
    print(f" Elapsed: {elapsed:.0f}s ({elapsed/60:.1f}min)")

    recovery_history = report["recovery_history"]
    if recovery_history:
        print(f"\n Recovery log ({len(recovery_history)}):")
        for idx, rec in enumerate(recovery_history, start=1):
            print(f" [{idx}] {rec['failure']}: {rec['actions']}")

    # The postmortem file is only reported when it exists on disk.
    if os.path.exists(healing_config.postmortem_path):
        with open(healing_config.postmortem_path, encoding="utf-8") as f:
            pm = json.load(f)
        print(f"\n Postmortem: {pm.get('exit_reason')} at step {pm.get('last_step')}")

    if trackio_space_id:
        print(f"\n Dashboard: https://huggingface.co/spaces/{trackio_space_id}")
    else:
        # No Space was configured, so do not advertise a Space URL that this
        # run never logged to.
        print("\n Dashboard: TRACKIO_SPACE_ID not set; "
              "try `trackio show --project self-healing-system`")


if __name__ == "__main__":
    main()