# Self-Healing Training System (SHTS)

> **Fully autonomous debugging and error recovery for Hugging Face TRL trainers. Add one callback, wrap with `SelfHealingTrainer`, and cut debugging costs to near zero.**

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![HF Hub](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)](https://huggingface.co/ScottzillaSystems/self-healing-training)

---

## The Problem

ML training fails constantly:
- **CUDA OOM** kills jobs at step 847/1000 — restart from scratch
- **NaN loss** silently corrupts models — discovered hours later
- **Loss spikes** cascade into divergence — manual intervention required
- **DPO plateau** at 0.693 loss (= random chance) — wasted GPU hours
- **No postmortem** — "what step did it die on?"

Each failure costs **developer time + GPU credits + schedule delay**. At scale, this is millions in wasted compute.

## The Solution

SHTS wraps any Hugging Face TRL trainer with four autonomous layers:

```
┌─────────────────────────────────────────┐
│  LAYER 4: ORCHESTRATION                 │
│  SelfHealingTrainer retry loop          │
│  while not converged: try → recover     │
├─────────────────────────────────────────┤
│  LAYER 3: RECOVERY                      │
│  HealingActions: rollback, halve LR,    │
│  halve batch, reclip, clear cache       │
├─────────────────────────────────────────┤
│  LAYER 2: DIAGNOSIS                     │
│  Root-cause classifier: NaN/divergence/ │
│  OOM/data/API — with literature refs    │
├─────────────────────────────────────────┤
│  LAYER 1: DETECTION                     │
│  SelfHealingCallback: loss, gradients,  │
│  memory, ZClip adaptive clipping        │
└─────────────────────────────────────────┘
```
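Layer 1's adaptive clipping can be sketched as a z-score filter over an exponential moving average of gradient norms, in the spirit of ZClip (arxiv:2504.02507). This is a hedged illustration, not the library's implementation: the class name, EMA coefficient, and warm-up behavior here are our assumptions.

```python
class ZClipSketch:
    """Sketch of z-score adaptive gradient clipping (after ZClip).

    Keeps an EMA of the gradient-norm mean and variance; a norm whose
    z-score exceeds the threshold is clipped back to mean + z * std.
    """

    def __init__(self, z_threshold=3.0, alpha=0.97):
        self.z = z_threshold
        self.alpha = alpha      # EMA smoothing coefficient (assumed)
        self.mean = None        # running mean of gradient norms
        self.var = 0.0          # running variance of gradient norms
        self.total_clips = 0

    def clip_value(self, grad_norm):
        if self.mean is None:   # warm-up: first observation seeds the mean
            self.mean = grad_norm
            return grad_norm
        std = self.var ** 0.5
        z = (grad_norm - self.mean) / std if std > 0 else 0.0
        clipped = grad_norm
        if z > self.z:          # positive outlier: clip to the adaptive ceiling
            clipped = self.mean + self.z * std
            self.total_clips += 1
        # update EMA statistics with the clipped value, so one spike
        # does not inflate the running mean
        self.mean = self.alpha * self.mean + (1 - self.alpha) * clipped
        self.var = self.alpha * self.var + (1 - self.alpha) * (clipped - self.mean) ** 2
        return clipped
```

Updating the statistics with the *clipped* norm is the design choice that keeps a single spike from raising the ceiling for subsequent steps.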

## Quick Start

```bash
pip install git+https://huggingface.co/ScottzillaSystems/self-healing-training
```

```python
from self_healing import SelfHealingTrainer, HealingConfig
from trl import SFTTrainer, SFTConfig

# Your normal training setup
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./output",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,
)

# Wrap with self-healing — that's it!
sh = SelfHealingTrainer(
    trainer,
    HealingConfig(
        max_recovery_attempts=5,
        zclip_enabled=True,
    ),
)

# Optional: dry-run to catch config errors before full training
sh.dry_run(num_steps=2)

# Train with full autonomy
result = sh.train()
```

## What Handles What

| Failure | Detection | Recovery | Paper |
|---------|-----------|----------|-------|
| **NaN loss** | `math.isnan(loss)` after each step | Rollback → halve LR → enable grad clip | ZClip arxiv:2504.02507 |
| **CUDA OOM** | `on_exception` catches `OutOfMemoryError` | Halve batch (effective batch preserved via gradient accumulation) → gradient checkpointing → clear cache | Unicron arxiv:2401.00134 |
| **Loss spike** | Loss > 5× running mean over window | ZClip adaptive gradient clipping → emergency checkpoint | ZClip arxiv:2504.02507 |
| **Divergence** | Loss increasing for N consecutive steps | Rollback → halve LR | Pioneer Agent arxiv:2604.09791 |
| **Gradient explosion** | `grad_norm > 100` | ZClip → enable max_grad_norm=1.0 | AdaGC arxiv:2502.11034 |
| **DPO plateau** | `loss ≈ 0.693` (random chance) | Increase LR 2-5× → check data quality | Rafailov et al. (2023) |
| **Overfitting** | `eval_loss - train_loss > 2.0` | Alert with actionable recommendation | Standard practice |
| **API errors** | Exception with "api/network/timeout" | Exponential backoff (30s → 60s → 120s → ...) | Standard pattern |
| **Data errors** | Exception with "shape/dimension/index" | Skip batch → log bad sample | Deep Researcher arxiv:2604.05854 |
| **Crash postmortem** | Always | `postmortem.json` with exit reason, last step, metrics, recovery history | PTT pattern |
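The detection rules in the table can be sketched as a small root-cause classifier. This is a hedged illustration of the logic, not the library's API: the `Failure` enum, the `diagnose` function, and its parameter names are ours.

```python
import math
from enum import Enum

class Failure(Enum):
    NAN_LOSS = "nan_loss"
    OOM = "oom"
    LOSS_SPIKE = "loss_spike"
    DIVERGENCE = "divergence"
    GRAD_EXPLOSION = "grad_explosion"
    DPO_PLATEAU = "dpo_plateau"
    UNKNOWN = "unknown"

def diagnose(loss, grad_norm, loss_history, spike_factor=5.0,
             divergence_patience=50, is_dpo=False, exc=None):
    """Classify a training failure using the thresholds from the table."""
    if exc is not None and "out of memory" in str(exc).lower():
        return Failure.OOM
    if math.isnan(loss):
        return Failure.NAN_LOSS
    if grad_norm is not None and grad_norm > 100:
        return Failure.GRAD_EXPLOSION
    # loss spike: current loss exceeds spike_factor times the running mean
    if loss_history and loss > spike_factor * (sum(loss_history) / len(loss_history)):
        return Failure.LOSS_SPIKE
    # divergence: loss strictly increasing for N consecutive steps
    recent = loss_history[-divergence_patience:]
    if len(recent) == divergence_patience and all(b > a for a, b in zip(recent, recent[1:])):
        return Failure.DIVERGENCE
    if is_dpo and abs(loss - 0.693) < 0.005:  # ln(2): random-chance DPO loss
        return Failure.DPO_PLATEAU
    return Failure.UNKNOWN
```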

## Crash Postmortem

Every training interruption produces a `postmortem.json`:

```json
{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "timestamp": "2026-04-30T15:26:04Z",
  "final_metrics": {"loss": 2.15, "grad_norm": 42.3},
  "recovery_actions": [
    {
      "failure": "oom",
      "diagnosis": "CUDA Out of Memory. Batch size exceeds GPU capacity.",
      "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]
    }
  ],
  "running_time_seconds": 1847.3
}
```
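A consumer of the postmortem file only needs the schema shown above. As a minimal sketch (the `summarize_postmortem` helper is ours, not part of SHTS):

```python
import json
from pathlib import Path

def summarize_postmortem(path="postmortem.json"):
    """Return a short text summary of the last run's exit and recovery actions."""
    report = json.loads(Path(path).read_text())
    lines = [f"Exited at step {report['last_step']} ({report['exit_reason']})"]
    for event in report.get("recovery_actions", []):
        lines.append(f"  {event['failure']}: {', '.join(event['actions'])}")
    return "\n".join(lines)
```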

## Trackio Integration

Set `report_to="trackio"` in your training args. SHTS emits:

- **Alerts** at every decision point (INFO/WARN/ERROR)
- **Metrics**: `healing/recovery_attempts`, `healing/nan_count`, `healing/loss_spike_ratio`, `healing/eval_gap`
- **ZClip metrics**: `zclip/raw_grad_norm`, `zclip/clipped_grad_norm`, `zclip/z_score`, `zclip/total_clips`

Dashboard URL: `https://huggingface.co/spaces/<username>/<trackio-space>`

## HealingConfig Presets

```python
# Aggressive — for unstable training, low tolerance
config = HealingConfig.aggressive()
# nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10

# Conservative — only intervene on clear failures
config = HealingConfig.conservative()
# nan_patience=10, loss_spike_factor=10.0, zclip_z_threshold=4.0, max_recovery_attempts=2

# Custom
config = HealingConfig(
    nan_patience=5,
    loss_spike_factor=8.0,
    divergence_patience=100,
    max_recovery_attempts=3,
    zclip_enabled=True,
    zclip_z_threshold=3.0,
)
```
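The presets above could be implemented as classmethod constructors on a dataclass. A minimal sketch, assuming the field names shown in this README; the default values for fields the presets do not override are our assumptions:

```python
from dataclasses import dataclass

@dataclass
class HealingConfig:
    nan_patience: int = 3
    loss_spike_factor: float = 5.0
    divergence_patience: int = 50
    max_recovery_attempts: int = 5
    zclip_enabled: bool = True
    zclip_z_threshold: float = 3.0

    @classmethod
    def aggressive(cls):
        # Low tolerance: intervene on the first NaN, clip early, retry often.
        return cls(nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10)

    @classmethod
    def conservative(cls):
        # Only intervene on clear failures; tolerate noise.
        return cls(nan_patience=10, loss_spike_factor=10.0,
                   zclip_z_threshold=4.0, max_recovery_attempts=2)
```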

## Compatibility

| Trainer | Status | Notes |
|---------|--------|-------|
| `SFTTrainer` (TRL) | ✅ Full | All metrics captured |
| `DPOTrainer` (TRL) | ✅ Full | DPO plateau detection (loss ≈ 0.693) |
| `GRPOTrainer` (TRL) | ✅ Full | Group reward monitoring |
| `PPOTrainer` (TRL) | ✅ Full | KL divergence tracking |
| `ORPOTrainer` (TRL) | ✅ Full | Odds ratio monitoring |
| `KTOTrainer` (TRL) | ✅ Full | Desirable/undesirable logps |
| `CPOTrainer` (TRL) | ✅ Full | Contrastive preference |
| `Trainer` (Transformers) | ✅ Full | Standard ML training |

## Architecture

```
SelfHealingTrainer.train()
  │
  ├── dry_run()              ← Validate setup first
  │
  └── while not converged:
      │
      ├── trainer.train()    ← Run training
      │     │
      │     ├── on_step_end  ← Detect NaN, spikes, divergence
      │     ├── on_log       ← Monitor gradients (ZClip)
      │     ├── on_evaluate  ← Check overfitting
      │     └── on_exception ← Catch OOM, API, data errors
      │
      ├── [recovery needed?]
      │     ├── diagnose     ← Classify failure type
      │     ├── heal         ← Apply recovery actions
      │     └── retry        ← resume_from_checkpoint=True
      │
      └── [converged]        ← Done!
```
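The orchestration loop above can be sketched in a few lines; this is an illustration of the control flow under stated assumptions, not the library's code (the `SelfHealingLoop` class and the three injected callables are ours):

```python
class SelfHealingLoop:
    """Sketch of the orchestration layer: run, diagnose, heal, resume."""

    def __init__(self, train_fn, diagnose_fn, heal_fn, max_attempts=5):
        self.train_fn = train_fn        # runs one training attempt
        self.diagnose_fn = diagnose_fn  # classifies a caught exception
        self.heal_fn = heal_fn          # applies recovery actions
        self.max_attempts = max_attempts
        self.history = []               # recovery actions taken so far

    def run(self):
        resume = False
        for _ in range(self.max_attempts + 1):
            try:
                return self.train_fn(resume_from_checkpoint=resume)
            except Exception as exc:
                failure = self.diagnose_fn(exc)
                self.heal_fn(failure)
                self.history.append(failure)
                resume = True  # retry from the last good checkpoint
        raise RuntimeError(f"gave up after {self.max_attempts} recovery attempts")
```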

## References

| Paper | ID | Contribution |
|-------|-----|-------------|
| Unicron | arxiv:2401.00134 | Cost-aware self-healing at cluster scale, error taxonomy (4 types), elastic scaling |
| ZClip | arxiv:2504.02507 | Z-score adaptive gradient clipping, eliminates catastrophic loss spikes |
| AdaGC | arxiv:2502.11034 | Per-tensor adaptive gradient clipping, optimizer-agnostic |
| Pioneer Agent | arxiv:2604.09791 | Structured decision tree by score buckets for autonomous iteration |
| Deep Researcher | arxiv:2604.05854 | Dry-run validation, zero-cost monitoring, constant-size memory |
| CheckFree | arxiv:2506.15461 | Pipeline-parallel recovery via neighbor averaging |
| DPO | Rafailov et al. (2023) | DPO plateau at 0.693 = random chance (Section 4.2) |
| PTT | [post-training-toolkit](https://github.com/microsoft/post-training-toolkit) | DiagnosticsCallback + postmortem pattern |

## License

MIT — use freely, attribution appreciated.

---

Built autonomously by ML Intern. Questions? Open an issue on the Hub.