#!/usr/bin/env python3
"""
Unit tests for Self-Healing Training System.

Run: pytest tests/ -v
"""
import sys

import pytest

# Make the package importable when the suite is run from the tests/ directory.
sys.path.insert(0, "..")

# Import the system under test (these imports don't need a GPU).
from self_healing.core import (
    HealingConfig,
    HealingActions,
    SelfHealingCallback,
    SelfHealingTrainer,  # noqa: F401 -- imported purely as an importability smoke test
    ZClip,
    FailureType,
    FAILURE_RECIPES,
)


class TestHealingConfig:
    """Tests for HealingConfig."""

    def test_default_values(self):
        config = HealingConfig()
        assert config.nan_patience == 3
        assert config.loss_spike_factor == 5.0
        assert config.zclip_enabled is True
        assert config.max_recovery_attempts == 5

    def test_serialization_roundtrip(self):
        # to_dict -> from_dict must preserve non-default field values.
        config = HealingConfig(nan_patience=10, zclip_z_threshold=2.5)
        d = config.to_dict()
        config2 = HealingConfig.from_dict(d)
        assert config2.nan_patience == 10
        assert config2.zclip_z_threshold == 2.5

    def test_aggressive_preset(self):
        config = HealingConfig.aggressive()
        assert config.nan_patience == 1
        assert config.loss_spike_factor == 3.0
        assert config.max_recovery_attempts == 10

    def test_conservative_preset(self):
        config = HealingConfig.conservative()
        assert config.nan_patience == 10
        assert config.max_recovery_attempts == 2


class TestZClip:
    """Tests for ZClip adaptive gradient clipping."""

    def test_initial_state(self):
        zclip = ZClip(z_threshold=3.0, ema_decay=0.99)
        assert zclip.mean is None
        assert zclip.std is None
        assert zclip.clip_count == 0

    def test_first_update(self):
        # First observation seeds the EMA: no clipping possible yet.
        zclip = ZClip()
        result = zclip.update_and_clip(5.0)
        assert result == 5.0
        assert zclip.mean == 5.0
        assert zclip.std == 0.0

    def test_no_clip_within_threshold(self):
        zclip = ZClip(z_threshold=3.0, ema_decay=0.5)
        # Stabilize at 5.0
        for _ in range(20):
            zclip.update_and_clip(5.0)
        # Small perturbation
        result = zclip.update_and_clip(6.0)
        assert result == 6.0  # No clip
        assert zclip.clip_count == 0

    def test_clip_on_spike(self):
        zclip = ZClip(z_threshold=2.0, ema_decay=0.9)
        # Stabilize
        for _ in range(50):
            zclip.update_and_clip(5.0)
        # Massive spike
        result = zclip.update_and_clip(100.0)
        assert result < 100.0  # Was clipped
        assert zclip.clip_count == 1

    def test_state_serialization(self):
        zclip = ZClip()
        zclip.update_and_clip(5.0)
        zclip.update_and_clip(10.0)
        state = zclip.state_dict()
        assert "mean" in state
        assert "std" in state
        assert "clip_count" in state
        zclip2 = ZClip()
        zclip2.load_state_dict(state)
        assert zclip2.mean == zclip.mean
        assert zclip2.clip_count == zclip.clip_count


class TestFailureTaxonomy:
    """Tests for failure taxonomy."""

    def test_all_failures_have_recipes(self):
        # Every enumerated failure type needs a complete recovery recipe.
        for failure in FailureType:
            assert failure in FAILURE_RECIPES
            recipe = FAILURE_RECIPES[failure]
            assert "diagnosis" in recipe
            assert "actions" in recipe
            assert "severity" in recipe
            assert recipe["severity"] in ("error", "warn")

    def test_nan_loss_actions(self):
        recipe = FAILURE_RECIPES[FailureType.NAN_LOSS]
        assert "rollback_checkpoint" in recipe["actions"]
        assert "halve_learning_rate" in recipe["actions"]

    def test_oom_actions(self):
        recipe = FAILURE_RECIPES[FailureType.OOM]
        assert "halve_batch_size" in recipe["actions"]
        assert "enable_gradient_checkpointing" in recipe["actions"]
        assert "clear_cache" in recipe["actions"]


class TestSelfHealingCallback:
    """Tests for SelfHealingCallback detection logic."""

    def setup_method(self):
        self.config = HealingConfig(
            nan_patience=3,
            loss_spike_factor=5.0,
            divergence_patience=10,
            zclip_enabled=False,  # Disable for simpler tests
        )

    def test_initial_state(self):
        cb = SelfHealingCallback(self.config)
        assert cb.nan_count == 0
        assert cb.recovery_attempts == 0
        assert cb.lr_reductions == 0
        assert len(cb.loss_history) == 0

    def test_callbacks_have_required_methods(self):
        """All TrainerCallback methods should be present."""
        cb = SelfHealingCallback(self.config)
        for method in [
            "on_train_begin",
            "on_step_end",
            "on_log",
            "on_evaluate",
            "on_exception",
            "on_train_end",
        ]:
            assert hasattr(cb, method)

    def test_state_serialization(self):
        # get_state/load_state must round-trip the detection counters.
        cb = SelfHealingCallback(self.config)
        cb.nan_count = 5
        cb.increasing_loss_count = 20
        cb.recovery_attempts = 2
        state = cb.get_state()
        assert state["nan_count"] == 5
        assert state["recovery_attempts"] == 2
        cb2 = SelfHealingCallback(self.config)
        cb2.load_state(state)
        assert cb2.nan_count == 5
        assert cb2.recovery_attempts == 2


class TestHealingActions:
    """Tests for HealingActions recovery logic."""

    def setup_method(self):
        self.config = HealingConfig(
            lr_reduce_factor=0.5,
            batch_reduce_factor=0.5,
            max_lr_reductions=4,
            max_batch_reductions=3,
        )

    def test_halve_learning_rate(self):
        # Imported locally so collection works without transformers installed.
        from transformers import TrainingArguments

        args = TrainingArguments(
            output_dir="/tmp",
            learning_rate=1e-4,
            per_device_train_batch_size=4,
        )
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("halve_learning_rate", args, {})
        assert args.learning_rate == 5e-5
        assert cb.lr_reductions == 1
        assert "5.00e-05" in result

    def test_lr_reduction_limit(self):
        from transformers import TrainingArguments

        args = TrainingArguments(
            output_dir="/tmp",
            learning_rate=1e-4,
            per_device_train_batch_size=4,
        )
        cb = SelfHealingCallback(self.config)
        cb.lr_reductions = 4  # Already at max
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("halve_learning_rate", args, {})
        assert "MAX" in result

    def test_halve_batch_size_preserves_effective(self):
        from transformers import TrainingArguments

        args = TrainingArguments(
            output_dir="/tmp",
            per_device_train_batch_size=8,
            gradient_accumulation_steps=1,
            learning_rate=1e-4,
        )
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("halve_batch_size", args, {})
        assert args.per_device_train_batch_size == 4
        assert args.gradient_accumulation_steps == 2  # Effective batch preserved

    def test_enable_gradient_checkpointing(self):
        from transformers import TrainingArguments

        args = TrainingArguments(
            output_dir="/tmp",
            learning_rate=1e-4,
            per_device_train_batch_size=4,
        )
        args.gradient_checkpointing = False
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("enable_gradient_checkpointing", args, {})
        assert args.gradient_checkpointing is True
        assert "Enabled" in result

    def test_exponential_backoff(self):
        from transformers import TrainingArguments

        args = TrainingArguments(
            output_dir="/tmp",
            learning_rate=1e-4,
            per_device_train_batch_size=4,
        )
        self.config.api_retry_base_delay = 0.01  # Fast for tests
        cb = SelfHealingCallback(self.config)
        cb.recovery_attempts = 1
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("exponential_backoff", args, {})
        assert "Waited" in result


if __name__ == "__main__":
    pytest.main([__file__, "-v"])