# Self-Healing Training System (SHTS)
> **Fully autonomous debugging and error recovery for Hugging Face TRL trainers. Add one callback, wrap with `SelfHealingTrainer`, and cut debugging costs to near zero.**
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![HF Hub](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)](https://huggingface.co/ScottzillaSystems/self-healing-training)
---
## The Problem
ML training fails constantly:
- **CUDA OOM** kills jobs at step 847/1000, forcing a restart from scratch
- **NaN loss** silently corrupts models and is only discovered hours later
- **Loss spikes** cascade into divergence, requiring manual intervention
- **DPO plateau** at 0.693 loss (= random chance) wastes GPU hours
- **No postmortem**, leaving you asking "what step did it die on?"

Each failure costs **developer time + GPU credits + schedule delay**. At scale, this adds up to millions in wasted compute.
## The Solution
SHTS wraps any Hugging Face TRL trainer with four autonomous layers:
```
┌────────────────────────────────────────────┐
│ LAYER 4: ORCHESTRATION                     │
│   SelfHealingTrainer retry loop            │
│   while not converged: try → recover       │
├────────────────────────────────────────────┤
│ LAYER 3: RECOVERY                          │
│   HealingActions: rollback, halve LR,      │
│   halve batch, reclip, clear cache         │
├────────────────────────────────────────────┤
│ LAYER 2: DIAGNOSIS                         │
│   Root-cause classifier: NaN/divergence/   │
│   OOM/data/API, with literature refs       │
├────────────────────────────────────────────┤
│ LAYER 1: DETECTION                         │
│   SelfHealingCallback: loss, gradients,    │
│   memory, ZClip adaptive clipping          │
└────────────────────────────────────────────┘
```
## Quick Start
```bash
pip install git+https://huggingface.co/ScottzillaSystems/self-healing-training
```
```python
from self_healing import SelfHealingTrainer, HealingConfig
from trl import SFTTrainer, SFTConfig

# Your normal training setup
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./output",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,
)

# Wrap with self-healing: that's it!
sh = SelfHealingTrainer(
    trainer,
    HealingConfig(
        max_recovery_attempts=5,
        zclip_enabled=True,
    ),
)

# Optional: dry-run to catch config errors before full training
sh.dry_run(num_steps=2)

# Train with full autonomy
result = sh.train()
```
## What Handles What
| Failure | Detection | Recovery | Paper |
|---------|-----------|----------|-------|
| **NaN loss** | `math.isnan(loss)` after each step | Rollback → halve LR → enable grad clip | ZClip arxiv:2504.02507 |
| **CUDA OOM** | `on_exception` catches `OutOfMemoryError` | Halve batch (preserve effective via GA) → gradient checkpointing → clear cache | Unicron arxiv:2401.00134 |
| **Loss spike** | Loss > 5× running mean over window | ZClip adaptive gradient clipping → emergency checkpoint | ZClip arxiv:2504.02507 |
| **Divergence** | Loss increasing for N consecutive steps | Rollback → halve LR | Pioneer Agent arxiv:2604.09791 |
| **Gradient explosion** | `grad_norm > 100` | ZClip → enable max_grad_norm=1.0 | AdaGC arxiv:2502.11034 |
| **DPO plateau** | `loss ≈ 0.693` (random chance) | Increase LR 2-5× → check data quality | Rafailov et al. (2023) |
| **Overfitting** | `eval_loss - train_loss > 2.0` | Alert with actionable recommendation | Standard practice |
| **API errors** | Exception with "api/network/timeout" | Exponential backoff (30s → 60s → 120s → ...) | Standard pattern |
| **Data errors** | Exception with "shape/dimension/index" | Skip batch → log bad sample | Deep Researcher arxiv:2604.05854 |
| **Crash postmortem** | Always | `postmortem.json` with exit reason, last step, metrics, recovery history | PTT pattern |
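As a concrete illustration of the OOM row above, here is a minimal sketch of halving the per-device batch while doubling gradient accumulation so the effective batch size is preserved. The function name is hypothetical, not part of the package's API.

```python
# Hypothetical sketch of the OOM recovery step: halve the per-device batch
# and double gradient accumulation, keeping the effective batch
# (per_device * grad_accum) unchanged.

def halve_batch_preserve_effective(per_device_batch: int, grad_accum: int) -> tuple[int, int]:
    """Return a new (per_device_batch, grad_accum) pair with the same effective batch."""
    if per_device_batch <= 1:
        # Cannot shrink further; the next lever is gradient checkpointing.
        raise RuntimeError("Batch size already 1; try gradient checkpointing instead.")
    return per_device_batch // 2, grad_accum * 2

new_bs, new_ga = halve_batch_preserve_effective(per_device_batch=4, grad_accum=2)
# effective batch stays 4 * 2 == 2 * 4 == 8
```

This is why the table says "preserve effective via GA": optimizer updates see the same number of samples, only the memory footprint per forward pass shrinks.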
## Crash Postmortem
Every training interruption produces a `postmortem.json`:
```json
{
  "exit_reason": "exception",
  "exception_type": "OutOfMemoryError",
  "last_step": 847,
  "timestamp": "2026-04-30T15:26:04Z",
  "final_metrics": {"loss": 2.15, "grad_norm": 42.3},
  "recovery_actions": [
    {
      "failure": "oom",
      "diagnosis": "CUDA Out of Memory. Batch size exceeds GPU capacity.",
      "actions": ["halve_batch_size", "enable_gradient_checkpointing", "clear_cache"]
    }
  ],
  "running_time_seconds": 1847.3
}
```
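A sketch of consuming this file after a crashed run, using the field names from the example above (the helper function is illustrative, not part of the package):

```python
import json

def summarize_postmortem(path: str) -> str:
    """Produce a one-line summary from a postmortem.json file."""
    with open(path) as f:
        pm = json.load(f)
    # Flatten the recovery actions across all recorded recovery attempts.
    actions = [a for rec in pm.get("recovery_actions", []) for a in rec["actions"]]
    return (f"exited with {pm['exception_type']} at step {pm['last_step']}; "
            f"recovery actions tried: {', '.join(actions) or 'none'}")
```

Handy for CI jobs that should fail loudly with the last known state instead of a bare non-zero exit code.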
## Trackio Integration
Set `report_to="trackio"` in your training args. SHTS emits:
- **Alerts** at every decision point (INFO/WARN/ERROR)
- **Metrics**: `healing/recovery_attempts`, `healing/nan_count`, `healing/loss_spike_ratio`, `healing/eval_gap`
- **ZClip metrics**: `zclip/raw_grad_norm`, `zclip/clipped_grad_norm`, `zclip/z_score`, `zclip/total_clips`
Dashboard URL: `https://huggingface.co/spaces/<username>/<trackio-space>`
## HealingConfig Presets
```python
# Aggressive β€” for unstable training, low tolerance
config = HealingConfig.aggressive()
# nan_patience=1, zclip_z_threshold=2.0, max_recovery_attempts=10
# Conservative β€” only intervene on clear failures
config = HealingConfig.conservative()
# nan_patience=10, loss_spike_factor=10.0, zclip_z_threshold=4.0, max_recovery_attempts=2
# Custom
config = HealingConfig(
    nan_patience=5,
    loss_spike_factor=8.0,
    divergence_patience=100,
    max_recovery_attempts=3,
    zclip_enabled=True,
    zclip_z_threshold=3.0,
)
```
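To make the two key knobs concrete, here is an illustrative sketch (not the library's internals) of what `loss_spike_factor` and `zclip_z_threshold` control: a spike fires when the loss exceeds the factor times the running mean, and ZClip-style clipping scales gradients whose norm is a z-score outlier.

```python
from collections import deque
import statistics

class SpikeDetector:
    """Flag a loss spike when loss > spike_factor * running mean (illustrative)."""
    def __init__(self, window: int = 50, spike_factor: float = 8.0):
        self.window = deque(maxlen=window)
        self.spike_factor = spike_factor

    def update(self, loss: float) -> bool:
        # Require a few observations before judging, to avoid cold-start noise.
        is_spike = (len(self.window) >= 5
                    and loss > self.spike_factor * statistics.fmean(self.window))
        self.window.append(loss)
        return is_spike

def zclip_scale(grad_norm: float, mean: float, std: float, z_threshold: float = 3.0) -> float:
    """Scale factor for gradients whose norm is a z-score outlier (illustrative)."""
    if std == 0:
        return 1.0
    z = (grad_norm - mean) / std
    if z <= z_threshold:
        return 1.0  # within normal range: leave gradients untouched
    clipped_norm = mean + z_threshold * std  # clip back to the threshold boundary
    return clipped_norm / grad_norm
```

Tightening `zclip_z_threshold` (as in `aggressive()`) clips more often; loosening it (as in `conservative()`) only intervenes on extreme outliers.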
## Compatibility
| Trainer | Status | Notes |
|---------|--------|-------|
| `SFTTrainer` (TRL) | ✅ Full | All metrics captured |
| `DPOTrainer` (TRL) | ✅ Full | DPO plateau detection (loss ≈ 0.693) |
| `GRPOTrainer` (TRL) | ✅ Full | Group reward monitoring |
| `PPOTrainer` (TRL) | ✅ Full | KL divergence tracking |
| `ORPOTrainer` (TRL) | ✅ Full | Odds ratio monitoring |
| `KTOTrainer` (TRL) | ✅ Full | Desirable/undesirable logps |
| `CPOTrainer` (TRL) | ✅ Full | Contrastive preference |
| `Trainer` (Transformers) | ✅ Full | Standard ML training |
## Architecture
```
SelfHealingTrainer.train()
│
├── dry_run()                 ← Validate setup first
│
└── while not converged:
    │
    ├── trainer.train()       ← Run training
    │     │
    │     ├── on_step_end     ← Detect NaN, spikes, divergence
    │     ├── on_log          ← Monitor gradients (ZClip)
    │     ├── on_evaluate     ← Check overfitting
    │     └── on_exception    ← Catch OOM, API, data errors
    │
    ├── [recovery needed?]
    │     ├── diagnose        ← Classify failure type
    │     ├── heal            ← Apply recovery actions
    │     └── retry           ← resume_from_checkpoint=True
    │
    └── [converged]           ← Done!
```
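The orchestration loop above can be sketched in a few lines. This is a simplified model with hypothetical function names, not the package's actual internals:

```python
def train_with_healing(train_once, diagnose, heal, max_attempts: int = 5):
    """Run training; on failure, diagnose and heal, then retry up to max_attempts times."""
    for attempt in range(max_attempts + 1):
        try:
            # After a failure we resume from the last checkpoint rather than restart.
            return train_once(resume=attempt > 0)
        except Exception as exc:
            if attempt == max_attempts:
                raise  # recovery budget exhausted: surface the original error
            failure = diagnose(exc)  # classify: oom / nan / data / api ...
            heal(failure)            # e.g. halve batch, halve LR, clear cache
```

The key design point is that recovery is bounded (`max_recovery_attempts` in `HealingConfig`), so a genuinely unfixable failure still terminates with the real exception instead of looping forever.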
## References
| Paper | ID | Contribution |
|-------|-----|-------------|
| Unicron | arxiv:2401.00134 | Cost-aware self-healing at cluster scale, error taxonomy (4 types), elastic scaling |
| ZClip | arxiv:2504.02507 | Z-score adaptive gradient clipping, eliminates catastrophic loss spikes |
| AdaGC | arxiv:2502.11034 | Per-tensor adaptive gradient clipping, optimizer-agnostic |
| Pioneer Agent | arxiv:2604.09791 | Structured decision tree by score buckets for autonomous iteration |
| Deep Researcher | arxiv:2604.05854 | Dry-run validation, zero-cost monitoring, constant-size memory |
| CheckFree | arxiv:2506.15461 | Pipeline-parallel recovery via neighbor averaging |
| DPO | Rafailov et al. (2023) | DPO plateau at 0.693 = random chance (Section 4.2) |
| PTT | [post-training-toolkit](https://github.com/microsoft/post-training-toolkit) | DiagnosticsCallback + postmortem pattern |
## License
MIT. Use freely; attribution appreciated.
---
Built autonomously by ML Intern. Questions? Open an issue on the Hub.