| |
| """ |
| Unit tests for Self-Healing Training System. |
| |
| Run: pytest tests/ -v |
| """ |
| import pytest |
| import torch |
| import math |
| from dataclasses import asdict |
| from unittest.mock import MagicMock, patch |
|
|
| |
| import sys |
| sys.path.insert(0, "..") |
| from self_healing.core import ( |
| HealingConfig, |
| HealingActions, |
| SelfHealingCallback, |
| SelfHealingTrainer, |
| ZClip, |
| FailureType, |
| FAILURE_RECIPES, |
| ) |
|
|
|
|
class TestHealingConfig:
    """Tests for HealingConfig defaults, presets, and serialization."""

    def test_default_values(self):
        """A freshly constructed config carries the documented defaults."""
        cfg = HealingConfig()
        assert cfg.nan_patience == 3
        assert cfg.loss_spike_factor == 5.0
        assert cfg.zclip_enabled is True
        assert cfg.max_recovery_attempts == 5

    def test_serialization_roundtrip(self):
        """to_dict / from_dict must preserve non-default field values."""
        original = HealingConfig(nan_patience=10, zclip_z_threshold=2.5)
        restored = HealingConfig.from_dict(original.to_dict())
        assert restored.nan_patience == 10
        assert restored.zclip_z_threshold == 2.5

    def test_aggressive_preset(self):
        """The aggressive preset reacts fast and retries many times."""
        cfg = HealingConfig.aggressive()
        assert cfg.nan_patience == 1
        assert cfg.loss_spike_factor == 3.0
        assert cfg.max_recovery_attempts == 10

    def test_conservative_preset(self):
        """The conservative preset waits longer and retries rarely."""
        cfg = HealingConfig.conservative()
        assert cfg.nan_patience == 10
        assert cfg.max_recovery_attempts == 2
|
|
|
|
class TestZClip:
    """Tests for ZClip adaptive gradient clipping."""

    def test_initial_state(self):
        """Before any update there are no statistics and no clips recorded."""
        clipper = ZClip(z_threshold=3.0, ema_decay=0.99)
        assert clipper.mean is None
        assert clipper.std is None
        assert clipper.clip_count == 0

    def test_first_update(self):
        """The very first observation seeds the EMA and is never clipped."""
        clipper = ZClip()
        assert clipper.update_and_clip(5.0) == 5.0
        assert clipper.mean == 5.0
        assert clipper.std == 0.0

    def test_no_clip_within_threshold(self):
        """A value near the running mean passes through unchanged."""
        clipper = ZClip(z_threshold=3.0, ema_decay=0.5)
        # Warm up the statistics on a constant norm.
        for _step in range(20):
            clipper.update_and_clip(5.0)
        assert clipper.update_and_clip(6.0) == 6.0
        assert clipper.clip_count == 0

    def test_clip_on_spike(self):
        """A gross outlier relative to the EMA gets clipped down."""
        clipper = ZClip(z_threshold=2.0, ema_decay=0.9)
        for _step in range(50):
            clipper.update_and_clip(5.0)
        clipped = clipper.update_and_clip(100.0)
        assert clipped < 100.0
        assert clipper.clip_count == 1

    def test_state_serialization(self):
        """state_dict / load_state_dict round-trips the running statistics."""
        clipper = ZClip()
        clipper.update_and_clip(5.0)
        clipper.update_and_clip(10.0)
        snapshot = clipper.state_dict()
        for key in ("mean", "std", "clip_count"):
            assert key in snapshot

        restored = ZClip()
        restored.load_state_dict(snapshot)
        assert restored.mean == clipper.mean
        assert restored.clip_count == clipper.clip_count
|
|
|
|
class TestFailureTaxonomy:
    """Tests for the failure taxonomy and its recovery recipes."""

    def test_all_failures_have_recipes(self):
        """Every FailureType maps to a complete, well-formed recipe."""
        for failure_kind in FailureType:
            assert failure_kind in FAILURE_RECIPES
            recipe = FAILURE_RECIPES[failure_kind]
            for required_key in ("diagnosis", "actions", "severity"):
                assert required_key in recipe
            assert recipe["severity"] in ("error", "warn")

    def test_nan_loss_actions(self):
        """NaN loss must trigger rollback plus a learning-rate cut."""
        actions = FAILURE_RECIPES[FailureType.NAN_LOSS]["actions"]
        assert "rollback_checkpoint" in actions
        assert "halve_learning_rate" in actions

    def test_oom_actions(self):
        """OOM must shrink the batch, checkpoint gradients, and free cache."""
        actions = FAILURE_RECIPES[FailureType.OOM]["actions"]
        assert "halve_batch_size" in actions
        assert "enable_gradient_checkpointing" in actions
        assert "clear_cache" in actions
|
|
|
|
class TestSelfHealingCallback:
    """Tests for SelfHealingCallback detection logic."""

    def setup_method(self):
        """Build a deterministic config with ZClip disabled for these tests."""
        self.config = HealingConfig(
            nan_patience=3,
            loss_spike_factor=5.0,
            divergence_patience=10,
            zclip_enabled=False,
        )

    def test_initial_state(self):
        """A new callback starts with zeroed counters and empty history."""
        callback = SelfHealingCallback(self.config)
        assert callback.nan_count == 0
        assert callback.recovery_attempts == 0
        assert callback.lr_reductions == 0
        assert len(callback.loss_history) == 0

    def test_callbacks_have_required_methods(self):
        """All TrainerCallback methods should be present."""
        callback = SelfHealingCallback(self.config)
        required = (
            "on_train_begin", "on_step_end", "on_log",
            "on_evaluate", "on_exception", "on_train_end",
        )
        for hook_name in required:
            assert hasattr(callback, hook_name)

    def test_state_serialization(self):
        """get_state / load_state round-trips the detection counters."""
        callback = SelfHealingCallback(self.config)
        callback.nan_count = 5
        callback.increasing_loss_count = 20
        callback.recovery_attempts = 2
        snapshot = callback.get_state()
        assert snapshot["nan_count"] == 5
        assert snapshot["recovery_attempts"] == 2

        restored = SelfHealingCallback(self.config)
        restored.load_state(snapshot)
        assert restored.nan_count == 5
        assert restored.recovery_attempts == 2
|
|
|
|
class TestHealingActions:
    """Tests for HealingActions recovery logic.

    Each test drives HealingActions._apply_single with a minimal
    transformers.TrainingArguments instance built by _make_args, which
    removes the copy-pasted construction the original tests repeated.
    """

    def setup_method(self):
        self.config = HealingConfig(
            lr_reduce_factor=0.5,
            batch_reduce_factor=0.5,
            max_lr_reductions=4,
            max_batch_reductions=3,
        )

    @staticmethod
    def _make_args(**overrides):
        """Build minimal TrainingArguments; keyword overrides replace defaults.

        The transformers import stays local (as in the original tests) so
        collection of this module does not require transformers installed.
        """
        from transformers import TrainingArguments
        params = {
            "output_dir": "/tmp",
            "learning_rate": 1e-4,
            "per_device_train_batch_size": 4,
        }
        params.update(overrides)
        return TrainingArguments(**params)

    def test_halve_learning_rate(self):
        """halve_learning_rate cuts the LR in half and bumps the counter."""
        args = self._make_args()
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("halve_learning_rate", args, {})
        assert args.learning_rate == 5e-5
        assert cb.lr_reductions == 1
        assert "5.00e-05" in result

    def test_lr_reduction_limit(self):
        """Once max_lr_reductions is reached, the action reports MAX."""
        args = self._make_args()
        cb = SelfHealingCallback(self.config)
        cb.lr_reductions = 4  # already at the configured cap
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("halve_learning_rate", args, {})
        assert "MAX" in result

    def test_halve_batch_size_preserves_effective(self):
        """Halving the batch doubles accumulation, keeping effective size."""
        args = self._make_args(
            per_device_train_batch_size=8,
            gradient_accumulation_steps=1,
        )
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        actions._apply_single("halve_batch_size", args, {})
        assert args.per_device_train_batch_size == 4
        assert args.gradient_accumulation_steps == 2

    def test_enable_gradient_checkpointing(self):
        """The action flips gradient_checkpointing on and says so."""
        args = self._make_args()
        args.gradient_checkpointing = False
        cb = SelfHealingCallback(self.config)
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("enable_gradient_checkpointing", args, {})
        assert args.gradient_checkpointing is True
        assert "Enabled" in result

    def test_exponential_backoff(self):
        """exponential_backoff sleeps (briefly here) and reports the wait."""
        args = self._make_args()
        self.config.api_retry_base_delay = 0.01  # keep the test fast
        cb = SelfHealingCallback(self.config)
        cb.recovery_attempts = 1
        actions = HealingActions(self.config, cb)
        result = actions._apply_single("exponential_backoff", args, {})
        assert "Waited" in result
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly (python this_file.py) instead of
    # invoking pytest from the command line.
    pytest.main([__file__, "-v"])