# Self-Healing Training System (SHTS)
> **Fully autonomous debugging and error recovery for Hugging Face TRL trainers. Add one callback, wrap with `SelfHealingTrainer`, and cut debugging costs to near zero.**
[License: MIT](https://opensource.org/licenses/MIT) · [Model on Hugging Face](https://huggingface.co/ScottzillaSystems/self-healing-training)
---
## The Problem
ML training fails constantly:
- **CUDA OOM** kills jobs at step 847/1000 → restart from scratch
- **NaN loss** silently corrupts models → discovered hours later
- **Loss spikes** cascade into divergence → manual intervention required
- **DPO plateau** at 0.693 loss (= random chance) → wasted GPU hours
- **No postmortem** → "what step did it die on?"
Each failure costs **developer time + GPU credits + schedule delay**. At scale, this is millions in wasted compute.
## The Solution
SHTS wraps any Hugging Face TRL trainer with four autonomous layers:
```
┌───────────────────────────────────────────┐
│ LAYER 4: ORCHESTRATION                    │
│   SelfHealingTrainer retry loop           │
│   while not converged: try → recover      │
├───────────────────────────────────────────┤
│ LAYER 3: RECOVERY                         │
│   HealingActions: rollback, halve LR,     │
│   halve batch, reclip, clear cache        │
├───────────────────────────────────────────┤
│ LAYER 2: DIAGNOSIS                        │
│   Root-cause classifier: NaN/divergence/  │
│   OOM/data/API (with literature refs)     │
├───────────────────────────────────────────┤
│ LAYER 1: DETECTION                        │
│   SelfHealingCallback: loss, gradients,   │
│   memory, ZClip adaptive clipping         │
└───────────────────────────────────────────┘
```
## Quick Start
```bash
pip install git+https://huggingface.co/ScottzillaSystems/self-healing-training
```
```python
from self_healing import SelfHealingTrainer, HealingConfig
from trl import SFTTrainer, SFTConfig
# Your normal training setup
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./output",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,
)

# Wrap with self-healing - that's it!
sh = SelfHealingTrainer(
    trainer,
    HealingConfig(
        max_recovery_attempts=5,
        zclip_enabled=True,
    ),
)
# Optional: dry-run to catch config errors before full training
sh.dry_run(num_steps=2)
# Train with full autonomy
result = sh.train()
```
## What Handles What
| Failure | Detection | Recovery | Paper |
|---------|-----------|----------|-------|
| **NaN loss** | `math.isnan(loss)` after each step | Rollback → halve LR → enable grad clip | ZClip arxiv:2504.02507 |
| **CUDA OOM** | `on_exception` catches `OutOfMemoryError` | Halve batch (preserve effective size via gradient accumulation) → gradient checkpointing → clear cache | Unicron arxiv:2401.00134 |
| **Loss spike** | Loss > 5× running mean over window | ZClip adaptive gradient clipping → emergency checkpoint | ZClip arxiv:2504.02507 |
| **Divergence** | Loss increasing for N consecutive steps | Rollback → halve LR | Pioneer Agent arxiv:2604.09791 |
| **Gradient explosion** | `grad_norm > 100` | ZClip → enable max_grad_norm=1.0 | AdaGC arxiv:2502.11034 |
| **DPO plateau** | `loss ≈ 0.693` (random chance) | Increase LR 2-5× → check data quality | Rafailov et al. (2023) |
| **Overfitting** | `eval_loss - train_loss > 2.0` | Alert with actionable recommendation | Standard practice |
| **API errors** | Exception with "api/network/timeout" | Exponential backoff (30s → 60s → 120s → ...) | Standard pattern |
| **Data errors** | Exception with "shape/dimension/index" | Skip batch → log bad sample | Deep Researcher arxiv:2604.05854 |
| **Crash postmortem** | Always | `postmortem.json` with exit reason, last step, metrics, recovery history | PTT pattern |
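The detection rules in the table reduce to a few scalar checks per step. Here is a sketch under those thresholds; the function names are illustrative, not the SHTS API:
```python
import math
from collections import deque

# Illustrative versions of the detection rules above; names and
# signatures are sketches, not part of the SHTS API.
def is_nan_loss(loss: float) -> bool:
    return math.isnan(loss)

def is_loss_spike(loss: float, window: deque, factor: float = 5.0) -> bool:
    """Spike if loss exceeds `factor` x the running mean over a full window."""
    return len(window) == window.maxlen and loss > factor * (sum(window) / len(window))

def is_dpo_plateau(recent_losses: list, tol: float = 0.01) -> bool:
    """DPO loss pinned near -ln(0.5) ~ 0.693 means the policy isn't learning."""
    return all(abs(l - 0.693) < tol for l in recent_losses)

def is_gradient_explosion(grad_norm: float, threshold: float = 100.0) -> bool:
    return grad_norm > threshold

# Usage: feed each logged loss into a fixed-size window.
window = deque(maxlen=50)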
## Crash Postmortem
Every training interruption produces a `postmortem.json`:
```json
{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "timestamp": "2026-04-30T15:26:04Z",
  "final_metrics": {"loss": 2.15, "grad_norm": 42.3},
  "recovery_actions": [
    {
      "failure": "oom",
      "diagnosis": "CUDA Out of Memory. Batch size exceeds GPU capacity.",
      "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]
    }
  ],
  "running_time_seconds": 1847.3
}
```
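Because the postmortem is plain JSON, inspecting a crash takes a few lines. A usage sketch (the `./output` path is an assumption; use your run's output directory):
```python
import json
from pathlib import Path

# Load the report written by the crashed run and summarize it.
report = json.loads(Path("./output/postmortem.json").read_text())
print(f"exited at step {report['last_step']} ({report['exit_reason']})")
for action in report["recovery_actions"]:
    print(f"  {action['failure']}: {', '.join(action['actions'])}")
```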
## Trackio Integration
Set `report_to="trackio"` in your training args. SHTS emits:
- **Alerts** at every decision point (INFO/WARN/ERROR)
- **Metrics**: `healing/recovery_attempts`, `healing/nan_count`, `healing/loss_spike_ratio`, `healing/eval_gap`
- **ZClip metrics**: `zclip/raw_grad_norm`, `zclip/clipped_grad_norm`, `zclip/z_score`, `zclip/total_clips`
Dashboard URL: `https://huggingface.co/spaces/<username>/<trackio-space>`
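For example, enabling it in the Quick Start config is one extra field (a minimal sketch; `report_to` is the standard `transformers`/TRL training-argument field):
```python
from trl import SFTConfig

# Route trainer logs plus the healing/* and zclip/* metrics to Trackio.
args = SFTConfig(
    output_dir="./output",
    report_to="trackio",
)
```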
## HealingConfig Presets
```python
# Aggressive: for unstable training, low tolerance
config = HealingConfig.aggressive()
# nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10

# Conservative: only intervene on clear failures
config = HealingConfig.conservative()
# nan_patience=10, loss_spike_factor=10.0, zclip_z_threshold=4.0, max_recovery_attempts=2

# Custom
config = HealingConfig(
    nan_patience=5,
    loss_spike_factor=8.0,
    divergence_patience=100,
    max_recovery_attempts=3,
    zclip_enabled=True,
    zclip_z_threshold=3.0,
)
```
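For intuition on `zclip_z_threshold`: ZClip clips a step's gradient only when its norm is a statistical outlier relative to recent history. A simplified sketch of the idea (the paper uses EMA statistics; this version keeps a plain window for clarity and is not the SHTS implementation):
```python
import torch

def zclip_step(model, history: list, z_threshold: float = 3.0, warmup: int = 32) -> float:
    """Z-score gradient clipping in the spirit of ZClip (arxiv:2504.02507).
    Simplified: a plain history list instead of the paper's EMA statistics."""
    # max_norm=inf measures the total gradient norm without clipping anything.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=float("inf")
    ).item()
    if len(history) >= warmup:
        mean = sum(history) / len(history)
        std = (sum((g - mean) ** 2 for g in history) / len(history)) ** 0.5
        if (grad_norm - mean) / (std + 1e-8) > z_threshold:
            # Outlier step: rescale gradients down to the historical mean norm.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=mean)
    history.append(grad_norm)
    return grad_norm
```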
## Compatibility
| Trainer | Status | Notes |
|---------|--------|-------|
| `SFTTrainer` (TRL) | ✅ Full | All metrics captured |
| `DPOTrainer` (TRL) | ✅ Full | DPO plateau detection (loss ≈ 0.693) |
| `GRPOTrainer` (TRL) | ✅ Full | Group reward monitoring |
| `PPOTrainer` (TRL) | ✅ Full | KL divergence tracking |
| `ORPOTrainer` (TRL) | ✅ Full | Odds ratio monitoring |
| `KTOTrainer` (TRL) | ✅ Full | Desirable/undesirable logps |
| `CPOTrainer` (TRL) | ✅ Full | Contrastive preference |
| `Trainer` (Transformers) | ✅ Full | Standard ML training |
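The wrap pattern from the Quick Start is identical for every row above. For example, with `DPOTrainer` (a sketch; `model`, `ref_model`, `dataset`, and `tokenizer` are assumed to exist already):
```python
from trl import DPOTrainer, DPOConfig
from self_healing import SelfHealingTrainer, HealingConfig

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=DPOConfig(output_dir="./dpo-output", beta=0.1),
    train_dataset=dataset,
    tokenizer=tokenizer,
)

# Plateau detection (loss pinned near 0.693) is active automatically.
sh = SelfHealingTrainer(trainer, HealingConfig(zclip_enabled=True))
result = sh.train()
```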
## Architecture
```
SelfHealingTrainer.train()
  │
  ├── dry_run() → Validate setup first
  │
  └── while not converged:
        │
        ├── trainer.train() → Run training
        │     │
        │     ├── on_step_end  → Detect NaN, spikes, divergence
        │     ├── on_log       → Monitor gradients (ZClip)
        │     ├── on_evaluate  → Check overfitting
        │     └── on_exception → Catch OOM, API, data errors
        │
        ├── [recovery needed?]
        │     ├── diagnose → Classify failure type
        │     ├── heal     → Apply recovery actions
        │     └── retry    → resume_from_checkpoint=True
        │
        └── [converged] → Done!
```
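In code, the outer loop boils down to something like the sketch below. The `diagnose` and `heal` helpers here are simplified stand-ins for Layers 2 and 3, not the actual SHTS internals:
```python
# Minimal sketch of the outer retry loop (illustrative only; the real
# diagnosis and recovery logic lives inside SelfHealingTrainer).
def diagnose(exc: Exception) -> str:
    """Layer 2 stand-in: crude root-cause classification by message."""
    msg = str(exc).lower()
    if "out of memory" in msg:
        return "oom"
    if "nan" in msg:
        return "nan_loss"
    return "unknown"

def heal(trainer, failure: str) -> None:
    """Layer 3 stand-in: adjust the config before the next attempt."""
    if failure == "oom":
        args = trainer.args
        args.per_device_train_batch_size = max(1, args.per_device_train_batch_size // 2)
        args.gradient_accumulation_steps *= 2  # preserve effective batch size
    elif failure == "nan_loss":
        trainer.args.learning_rate /= 2

def train_with_healing(trainer, max_attempts: int = 5):
    """Layer 4 stand-in: retry from the last checkpoint after each recovery."""
    for attempt in range(max_attempts):
        try:
            # First attempt starts fresh; retries resume from the last checkpoint.
            return trainer.train(resume_from_checkpoint=attempt > 0)
        except Exception as exc:
            heal(trainer, diagnose(exc))
    raise RuntimeError("max recovery attempts exhausted")
```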
## References
| Paper | ID | Contribution |
|-------|-----|-------------|
| Unicron | arxiv:2401.00134 | Cost-aware self-healing at cluster scale, error taxonomy (4 types), elastic scaling |
| ZClip | arxiv:2504.02507 | Z-score adaptive gradient clipping, eliminates catastrophic loss spikes |
| AdaGC | arxiv:2502.11034 | Per-tensor adaptive gradient clipping, optimizer-agnostic |
| Pioneer Agent | arxiv:2604.09791 | Structured decision tree by score buckets for autonomous iteration |
| Deep Researcher | arxiv:2604.05854 | Dry-run validation, zero-cost monitoring, constant-size memory |
| CheckFree | arxiv:2506.15461 | Pipeline-parallel recovery via neighbor averaging |
| DPO | Rafailov et al. (2023) | DPO plateau at 0.693 = random chance (Section 4.2) |
| PTT | [post-training-toolkit](https://github.com/microsoft/post-training-toolkit) | DiagnosticsCallback + postmortem pattern |
## License
MIT. Use freely; attribution appreciated.
---
Built autonomously by ML Intern. Questions? Open an issue on the Hub.