#!/usr/bin/env python3
"""
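Stress-test: Catastrophic Failure Injection
===========================================
Intentionally triggers failures to verify self-healing recovery.

Failure scenarios and the expected recovery:
1. NaN injection in loss -> should trigger rollback + halve LR
2. Simulated OOM -> should trigger batch halving + grad checkpointing
3. API error -> should trigger exponential backoff

Tests run by this file: ZClip spike clipping, recovery-limit enforcement,
postmortem generation for a simulated OOM (no GPU needed), and end-to-end
NaN recovery (needs a GPU; skipped when none is available).

Run with:
    python tests/stress_test_recovery.py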
"""
import os, sys, json, time, math, gc
import torch
import torch.nn as nn
from transformers import (
AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments,
TrainerCallback, TrainerControl, TrainerState,
)
from datasets import Dataset
from self_healing import (
SelfHealingTrainer, HealingConfig, SelfHealingCallback,
HealingActions, FailureType, FAILURE_RECIPES,
)
class NaNInjectorCallback(TrainerCallback):
    """Intentionally inject NaN into the training loss at one specific step.

    When ``state.global_step`` equals ``inject_at_step`` at the start of a
    step, the model's ``forward`` is wrapped so the loss it returns becomes
    NaN. The original ``forward`` is put back at the end of that same step.
    The injection fires at most once per callback instance, so replaying the
    step after a rollback does not re-inject.

    Args:
        inject_at_step: value of ``state.global_step`` at which to inject.
    """

    def __init__(self, inject_at_step: int = 10):
        self.inject_at_step = inject_at_step
        # The model's forward as it was before patching (None until injected).
        self.original_forward = None
        # Becomes True once the injection has fired; it never fires again.
        self._injected = False
        # Model currently carrying the NaN wrapper, so on_step_end can undo it.
        self._patched_model = None

    def on_step_begin(self, args, state, control, **kwargs):
        """Install the NaN-producing forward wrapper at the target step."""
        if self._injected or state.global_step != self.inject_at_step:
            return
        self._injected = True
        print(f"\n [INJECT] Forcing NaN at step {state.global_step}\n")
        # Override the model's forward to return NaN
        model = kwargs.get("model")
        if model is None:
            return
        original_forward = model.forward
        self.original_forward = original_forward

        def nan_forward(*a, **kw):
            result = original_forward(*a, **kw)
            # When the model computed no loss there is nothing to poison.
            if result.loss is not None:
                # Scale the real loss by NaN instead of replacing it with a
                # fresh CPU tensor. This keeps its device, dtype and autograd
                # graph, so backward() still runs (and yields NaN gradients)
                # rather than failing on a tensor that does not require grad.
                result.loss = result.loss * float("nan")
            return result

        model.forward = nan_forward
        self._patched_model = model

    def on_step_end(self, args, state, control, **kwargs):
        """Restore the original forward once the poisoned step has run.

        Without this, every later forward would also return NaN, and a
        rollback could never lead to a healthy step.
        """
        if self._patched_model is not None:
            self._patched_model.forward = self.original_forward
            self._patched_model = None
def test_nan_recovery():
    """
    Test: Inject NaN -> verify SelfHealingTrainer detects and recovers.

    Trains a tiny causal LM for 30 steps on a repeated sentence, with
    ``NaNInjectorCallback`` poisoning the loss at step 10. It then asserts
    that at least one ``nan_loss`` recovery was recorded and that the
    learning rate was reduced at least once.
    """
    # Local import: only this GPU test needs the LM collator.
    from transformers import DataCollatorForLanguageModeling

    print("\n" + "=" * 60)
    print(" STRESS TEST 1: NaN Recovery")
    print("=" * 60)
    # Tiny model
    model_id = "HuggingFaceTB/SmolLM2-135M"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # float32 for NaN safety
        device_map="auto" if torch.cuda.is_available() else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Dummy dataset: tokenize each text once and derive the mask from it.
    texts = ["The quick brown fox jumps over the lazy dog."] * 100
    encoded = [tokenizer.encode(t, truncation=True, max_length=32) for t in texts]
    ds = Dataset.from_dict({
        "text": texts,
        "input_ids": encoded,
        "attention_mask": [[1] * len(ids) for ids in encoded],
    })

    training_args = TrainingArguments(
        output_dir="./stress-nan-output",
        per_device_train_batch_size=2,
        learning_rate=1e-4,
        max_steps=30,
        logging_steps=1,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=100,
        report_to="none",
        disable_tqdm=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        tokenizer=tokenizer,
        # The dataset has no "labels" column. The causal-LM collator
        # (mlm=False) builds labels from input_ids. Without them the model
        # returns no loss, and Trainer rejects the output at the first step.
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        callbacks=[NaNInjectorCallback(inject_at_step=10)],
    )
    healing_config = HealingConfig(
        nan_patience=1,  # React immediately
        max_recovery_attempts=3,
        max_lr_reductions=3,
        zclip_enabled=False,
        postmortem_path="./stress-nan-postmortem.json",
    )
    sh = SelfHealingTrainer(trainer, healing_config)
    print("Training with NaN injection at step 10...")
    sh.train()

    print("\nResults:")
    print(f" Converged: {sh.converged}")
    print(f" Attempts: {sh.attempt}")
    print(f" Recoveries: {len(sh.recovery_history)}")
    for rec in sh.recovery_history:
        print(f" β {rec['failure']}: {rec['actions']}")

    # Verify: should have at least one recovery for NaN
    assert len(sh.recovery_history) >= 1, "Expected NaN recovery!"
    assert any(r["failure"] == "nan_loss" for r in sh.recovery_history), \
        "Expected nan_loss failure type!"
    # Verify LR was reduced
    assert sh.healing_callback.lr_reductions >= 1, \
        "Expected LR to be reduced!"
    print(" β NaN recovery test PASSED")

    if os.path.exists(healing_config.postmortem_path):
        with open(healing_config.postmortem_path) as f:
            pm = json.load(f)
        print(f" Postmortem: {pm.get('exit_reason')} at step {pm.get('last_step')}")
def test_zclip_spike_detection():
    """
    Test: feed ZClip a long run of identical gradient norms, then one huge
    outlier, and check that the outlier comes back clipped and that the
    clip counter moved.
    """
    banner = "=" * 60
    print("\n" + banner)
    print(" STRESS TEST 2: ZClip Spike Detection")
    print(banner)

    from self_healing import ZClip

    clipper = ZClip(z_threshold=2.5, ema_decay=0.9)
    steady_norm = 10.0
    spike_norm = 500.0

    # Stabilize at norm=10.0 by repeating the same value many times.
    for _ in range(100):
        clipper.update_and_clip(steady_norm)

    # A single sudden spike.
    clipped = clipper.update_and_clip(spike_norm)
    print(f" Raw: 500.0, Clipped: {clipped:.1f}, Clips: {clipper.clip_count}")

    assert clipped < spike_norm, "Expected spike to be clipped!"
    assert clipper.clip_count >= 1, "Expected clip counter to increment!"
    print(" β ZClip spike detection PASSED")
def test_healing_config_limits():
    """
    Test: with ``max_lr_reductions=2``, HealingActions accepts two
    ``halve_learning_rate`` requests and refuses the third. The refusal
    returns a message containing "MAX" and leaves ``lr_reductions`` at 2.
    (``max_batch_reductions`` is configured but not exercised here.)
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" STRESS TEST 3: Recovery Limits")
    print(rule)

    # TrainingArguments / HealingActions / SelfHealingCallback / HealingConfig
    # come from the module-level imports.
    config = HealingConfig(max_lr_reductions=2, max_batch_reductions=2)

    # Test LR limit
    args = TrainingArguments(
        output_dir="/tmp",
        learning_rate=1e-4,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
    )
    cb = SelfHealingCallback(config)
    actions = HealingActions(config, cb)

    # Use up the whole reduction budget (a fresh empty dict each call).
    for _ in range(2):
        actions._apply_single("halve_learning_rate", args, {})
    assert cb.lr_reductions == 2

    # Budget exhausted: the next request must be refused and not counted.
    outcome = actions._apply_single("halve_learning_rate", args, {})
    assert "MAX" in outcome
    assert cb.lr_reductions == 2  # Should not increment

    print(f" LR after 2 reductions: {args.learning_rate:.2e}")
    print(f" Third attempt: {outcome}")
    print(" β Recovery limits test PASSED")
def test_postmortem_written():
    """
    Test: Verify postmortem.json is written on crash.

    Calls ``SelfHealingCallback.on_exception`` directly with mocked
    ``args``/``state``/``control`` objects and a ``torch.cuda.OutOfMemoryError``
    instance. It then reads back the JSON that the callback wrote into a
    temporary directory and checks exception type, last step and final metrics.
    """
    import tempfile
    # Imported locally so this test works even when the script's __main__
    # block has not run (e.g. when collected by pytest or imported).
    from unittest.mock import MagicMock

    print("\n" + "=" * 60)
    print(" STRESS TEST 4: Postmortem Generation")
    print("=" * 60)
    with tempfile.TemporaryDirectory() as tmpdir:
        config = HealingConfig(
            postmortem_path=os.path.join(tmpdir, "postmortem.json"),
        )
        cb = SelfHealingCallback(config)
        # Simulate exception
        cb.on_exception(
            MagicMock(),  # args
            MagicMock(global_step=42, log_history=[{"loss": 1.5}]),  # state
            MagicMock(),  # control
            torch.cuda.OutOfMemoryError("CUDA out of memory. Tried to allocate 2.00 GiB"),  # exception
        )
        # Check postmortem exists
        assert os.path.exists(config.postmortem_path), \
            f"Postmortem not written to {config.postmortem_path}"
        with open(config.postmortem_path) as f:
            pm = json.load(f)
        assert pm["exception_type"] == "OutOfMemoryError", pm["exception_type"]
        assert pm["last_step"] == 42, pm["last_step"]
        assert "loss" in pm["final_metrics"], pm["final_metrics"]
        assert pm["final_metrics"]["loss"] == 1.5, pm["final_metrics"]
        print(f" Postmortem path: {config.postmortem_path}")
        print(f" Content: {json.dumps(pm, indent=2)}")
    print(" β Postmortem generation PASSED")
if __name__ == "__main__":
    # MagicMock is referenced by test_postmortem_written; importing it here
    # also makes it a module-level name when run as a script.
    from unittest.mock import MagicMock  # noqa: F401

    print("β" + "β" * 58 + "β")
    print("β SELF-HEALING TRAINING SYSTEM β STRESS TEST SUITE β")
    print("β" + "β" * 58 + "β")

    # These three tests need no GPU.
    test_zclip_spike_detection()
    test_healing_config_limits()
    test_postmortem_written()

    # NaN recovery test (needs model loading); only run when CUDA is present.
    if torch.cuda.is_available():
        test_nan_recovery()
    else:
        print("\n" + "=" * 60)
        print(" STRESS TEST 1: NaN Recovery")
        print("=" * 60)
        print(" β Skipped: No GPU available")

    print("\n" + "=" * 60)
    print(" ALL STRESS TESTS PASSED β")
    print("=" * 60)