# self-healing-training / tests / test_core.py
# Uploaded by ScottzillaSystems (commit 880bd2d, verified)
#!/usr/bin/env python3
"""
Unit tests for Self-Healing Training System.
Run: pytest tests/ -v
"""
import pytest
import torch
import math
from dataclasses import asdict
from unittest.mock import MagicMock, patch
# Import the system (these don't need GPU)
import sys
sys.path.insert(0, "..")
from self_healing.core import (
HealingConfig,
HealingActions,
SelfHealingCallback,
SelfHealingTrainer,
ZClip,
FailureType,
FAILURE_RECIPES,
)
class TestHealingConfig:
    """Unit tests covering HealingConfig defaults, presets, and serialization."""

    def test_default_values(self):
        # A freshly constructed config must match the documented defaults.
        cfg = HealingConfig()
        assert cfg.nan_patience == 3
        assert cfg.loss_spike_factor == 5.0
        assert cfg.zclip_enabled is True
        assert cfg.max_recovery_attempts == 5

    def test_serialization_roundtrip(self):
        # A dict round-trip must preserve overridden fields.
        original = HealingConfig(nan_patience=10, zclip_z_threshold=2.5)
        restored = HealingConfig.from_dict(original.to_dict())
        assert restored.nan_patience == 10
        assert restored.zclip_z_threshold == 2.5

    def test_aggressive_preset(self):
        # The aggressive preset reacts faster and retries more often.
        cfg = HealingConfig.aggressive()
        assert cfg.nan_patience == 1
        assert cfg.loss_spike_factor == 3.0
        assert cfg.max_recovery_attempts == 10

    def test_conservative_preset(self):
        # The conservative preset tolerates more NaNs but retries less.
        cfg = HealingConfig.conservative()
        assert cfg.nan_patience == 10
        assert cfg.max_recovery_attempts == 2
class TestZClip:
    """Behavioural tests for the ZClip adaptive gradient-clipping helper."""

    def test_initial_state(self):
        # Before any observation the running statistics are undefined.
        clipper = ZClip(z_threshold=3.0, ema_decay=0.99)
        assert clipper.mean is None
        assert clipper.std is None
        assert clipper.clip_count == 0

    def test_first_update(self):
        # The first observation seeds the EMA and is never clipped.
        clipper = ZClip()
        assert clipper.update_and_clip(5.0) == 5.0
        assert clipper.mean == 5.0
        assert clipper.std == 0.0

    def test_no_clip_within_threshold(self):
        clipper = ZClip(z_threshold=3.0, ema_decay=0.5)
        # Stabilize at 5.0
        for _ in range(20):
            clipper.update_and_clip(5.0)
        # A mild deviation stays inside the z-threshold and passes through.
        assert clipper.update_and_clip(6.0) == 6.0  # No clip
        assert clipper.clip_count == 0

    def test_clip_on_spike(self):
        clipper = ZClip(z_threshold=2.0, ema_decay=0.9)
        # Stabilize
        for _ in range(50):
            clipper.update_and_clip(5.0)
        # An extreme outlier must be reduced and counted.
        clipped = clipper.update_and_clip(100.0)
        assert clipped < 100.0  # Was clipped
        assert clipper.clip_count == 1

    def test_state_serialization(self):
        clipper = ZClip()
        clipper.update_and_clip(5.0)
        clipper.update_and_clip(10.0)
        state = clipper.state_dict()
        # The serialized state exposes all running statistics.
        for key in ("mean", "std", "clip_count"):
            assert key in state
        restored = ZClip()
        restored.load_state_dict(state)
        assert restored.mean == clipper.mean
        assert restored.clip_count == clipper.clip_count
class TestFailureTaxonomy:
    """Every failure type must map to a complete, well-formed recovery recipe."""

    def test_all_failures_have_recipes(self):
        for failure in FailureType:
            assert failure in FAILURE_RECIPES
            recipe = FAILURE_RECIPES[failure]
            # Each recipe carries a diagnosis, an action list, and a severity.
            for field in ("diagnosis", "actions", "severity"):
                assert field in recipe
            assert recipe["severity"] in ("error", "warn")

    def test_nan_loss_actions(self):
        # NaN loss recovery rolls back and lowers the learning rate.
        nan_actions = FAILURE_RECIPES[FailureType.NAN_LOSS]["actions"]
        assert "rollback_checkpoint" in nan_actions
        assert "halve_learning_rate" in nan_actions

    def test_oom_actions(self):
        # OOM recovery shrinks memory pressure on several fronts.
        oom_actions = FAILURE_RECIPES[FailureType.OOM]["actions"]
        for expected in (
            "halve_batch_size",
            "enable_gradient_checkpointing",
            "clear_cache",
        ):
            assert expected in oom_actions
class TestSelfHealingCallback:
    """Detection-state tests for SelfHealingCallback (ZClip disabled)."""

    def setup_method(self):
        # ZClip is switched off so only the raw failure counters are exercised.
        self.config = HealingConfig(
            nan_patience=3,
            loss_spike_factor=5.0,
            divergence_patience=10,
            zclip_enabled=False,  # Disable for simpler tests
        )

    def test_initial_state(self):
        # A new callback starts with zeroed counters and no loss history.
        callback = SelfHealingCallback(self.config)
        assert callback.nan_count == 0
        assert callback.recovery_attempts == 0
        assert callback.lr_reductions == 0
        assert len(callback.loss_history) == 0

    def test_callbacks_have_required_methods(self):
        """All TrainerCallback methods should be present."""
        callback = SelfHealingCallback(self.config)
        required_hooks = (
            "on_train_begin", "on_step_end", "on_log",
            "on_evaluate", "on_exception", "on_train_end",
        )
        for hook in required_hooks:
            assert hasattr(callback, hook)

    def test_state_serialization(self):
        source = SelfHealingCallback(self.config)
        source.nan_count = 5
        source.increasing_loss_count = 20
        source.recovery_attempts = 2
        # Counters survive a get_state/load_state round-trip.
        state = source.get_state()
        assert state["nan_count"] == 5
        assert state["recovery_attempts"] == 2
        target = SelfHealingCallback(self.config)
        target.load_state(state)
        assert target.nan_count == 5
        assert target.recovery_attempts == 2
class TestHealingActions:
    """Tests for HealingActions recovery logic.

    Fixes over the original: the five copies of the TrainingArguments
    boilerplate (and the repeated local transformers import) are hoisted
    into the `_make_args` helper, and the unused `result` local in
    `test_halve_batch_size_preserves_effective` is removed.
    """

    def setup_method(self):
        self.config = HealingConfig(
            lr_reduce_factor=0.5,
            batch_reduce_factor=0.5,
            max_lr_reductions=4,
            max_batch_reductions=3,
        )

    @staticmethod
    def _make_args(**overrides):
        """Build minimal TrainingArguments for one test.

        The transformers import stays function-local (as in the original
        tests) so merely importing this module does not require
        transformers to be installed.
        """
        from transformers import TrainingArguments
        kwargs = dict(
            output_dir="/tmp",
            learning_rate=1e-4,
            per_device_train_batch_size=4,
        )
        kwargs.update(overrides)
        return TrainingArguments(**kwargs)

    def test_halve_learning_rate(self):
        args = self._make_args()
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("halve_learning_rate", args, {})
        # LR halved, reduction counted, and the message reports the new LR.
        assert args.learning_rate == 5e-5
        assert cb.lr_reductions == 1
        assert "5.00e-05" in result

    def test_lr_reduction_limit(self):
        args = self._make_args()
        cb = SelfHealingCallback(self.config)
        cb.lr_reductions = 4  # Already at max
        actions = HealingActions(self.config, cb)
        # At the cap the action refuses and says so.
        result = actions._apply_single("halve_learning_rate", args, {})
        assert "MAX" in result

    def test_halve_batch_size_preserves_effective(self):
        args = self._make_args(
            per_device_train_batch_size=8,
            gradient_accumulation_steps=1,
        )
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        # Return message not needed here; only the arg mutations matter.
        actions._apply_single("halve_batch_size", args, {})
        assert args.per_device_train_batch_size == 4
        assert args.gradient_accumulation_steps == 2  # Effective batch preserved

    def test_enable_gradient_checkpointing(self):
        args = self._make_args()
        args.gradient_checkpointing = False
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("enable_gradient_checkpointing", args, {})
        assert args.gradient_checkpointing is True
        assert "Enabled" in result

    def test_exponential_backoff(self):
        args = self._make_args()
        self.config.api_retry_base_delay = 0.01  # Fast for tests
        cb = SelfHealingCallback(self.config)
        cb.recovery_attempts = 1
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("exponential_backoff", args, {})
        assert "Waited" in result
if __name__ == "__main__":
    # Allow running this file directly: invoke pytest on itself, verbose.
    pytest.main([__file__, "-v"])