File size: 3,960 Bytes
1da48e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
"""
Production Job Runner: Self-Healing Training
============================================
Complete self-healing training job with Trackio monitoring.

Pre-flight:
  ✓ Dataset: trl-lib/Capybara (messages format, SFT-compatible)
  ✓ Model: Qwen2.5-0.5B
  ✓ Trackio monitoring
  ✓ Timeout: 2h for 0.5B model on a10g-large
"""
import os, sys, json, time, math, gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from self_healing import SelfHealingTrainer, HealingConfig

def main():
    """Run the self-healing SFT training job end to end.

    Pipeline: load model + tokenizer -> load dataset -> build SFTTrainer ->
    wrap it in SelfHealingTrainer -> dry-run validation -> full training ->
    print recovery/postmortem report.

    Exits with status 1 if the dry-run validation fails; otherwise returns
    None after printing the final report and dashboard URL.
    """
    print("\n" + "=" * 70)
    print("  SELF-HEALING TRAINING JOB")
    print("=" * 70 + "\n")

    # [1/4] Model + tokenizer. bf16 with device_map="auto" spreads the
    # 0.5B model across available accelerators.
    model_id = "Qwen/Qwen2.5-0.5B"
    print(f"[1/4] Loading {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Qwen tokenizers ship without a pad token; reuse EOS so batched
    # padding in the data collator works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # [2/4] Dataset: first 5k rows, already in SFT-compatible messages format.
    print("[2/4] Loading trl-lib/Capybara")
    dataset = load_dataset("trl-lib/Capybara", split="train[:5000]")
    print(f"  {len(dataset)} samples")

    # trl imported lazily so the module can be imported without it installed.
    from trl import SFTConfig, SFTTrainer
    training_args = SFTConfig(
        output_dir="./output",
        # Effective batch size: 4 * 4 = 16 sequences per optimizer step.
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        max_steps=500,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=200, save_total_limit=3,
        bf16=True, gradient_checkpointing=True,
        warmup_steps=50, lr_scheduler_type="cosine",
        report_to="trackio",
        run_name="selfheal-qwen0.5b-sft",
        project="self-healing-system",
        trackio_space_id=os.environ.get("TRACKIO_SPACE_ID", ""),
        push_to_hub=False, disable_tqdm=True,
    )

    # NOTE(review): recent TRL releases renamed the `tokenizer=` kwarg to
    # `processing_class=` — confirm against the pinned trl version.
    trainer = SFTTrainer(
        model=model, args=training_args,
        train_dataset=dataset, tokenizer=tokenizer)

    # [3/4] Self-healing wrapper: NaN/spike/divergence detection plus
    # bounded automatic recovery (LR cuts, batch-size cuts, ZClip).
    print("[3/4] Configuring self-healing")
    healing_config = HealingConfig(
        nan_patience=3, loss_spike_factor=5.0, divergence_patience=50,
        grad_explosion_threshold=100.0, zclip_enabled=True,
        zclip_z_threshold=3.0, max_recovery_attempts=5,
        max_lr_reductions=3, max_batch_reductions=2,
        postmortem_path="./postmortem.json",
    )
    sh_trainer = SelfHealingTrainer(trainer, healing_config)

    # [4/4] Cheap 2-step dry run to fail fast on config/OOM/shape errors
    # before committing to the full (potentially hours-long) run.
    print("[4/4] Dry-run validation")
    try:
        sh_trainer.dry_run(num_steps=2)
        print("  ✓ Dry-run passed!\n")
    except Exception as e:
        print(f"  ✗ Dry-run FAILED: {e}")
        sys.exit(1)

    print("Training with autonomous self-healing...\n")
    start = time.time()
    sh_trainer.train()  # return value was previously bound but never used
    elapsed = time.time() - start

    print("\n" + "=" * 70)
    print("  TRAINING COMPLETE")
    print("=" * 70)
    report = sh_trainer.get_report()
    print(f"  Converged: {sh_trainer.converged} | Attempts: {sh_trainer.attempt}")
    print(f"  Recoveries: {report['total_recoveries']} | ZClip: {report['zclip_total_clips']}")
    print(f"  NaN: {report['nan_count']} | LR cuts: {report['lr_reductions']}")
    print(f"  Elapsed: {elapsed:.0f}s ({elapsed/60:.1f}min)")

    if report["recovery_history"]:
        print(f"\n  Recovery log ({len(report['recovery_history'])}):")
        for i, rec in enumerate(report["recovery_history"]):
            print(f"    [{i+1}] {rec['failure']}: {rec['actions']}")

    # Postmortem file is written by SelfHealingTrainer on abnormal exits;
    # surface its summary when present.
    if os.path.exists(healing_config.postmortem_path):
        with open(healing_config.postmortem_path) as f:
            pm = json.load(f)
        print(f"\n  Postmortem: {pm.get('exit_reason')} at step {pm.get('last_step')}")

    space = os.environ.get("TRACKIO_SPACE_ID", "mlintern/selfheal-demo")
    print(f"\n  Dashboard: https://huggingface.co/spaces/{space}")

# Entry point guard: run the job only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()