File size: 8,443 Bytes
880bd2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
#!/usr/bin/env python3
"""
Unit tests for Self-Healing Training System.

Run: pytest tests/ -v
"""
import pytest
import torch
import math
from dataclasses import asdict
from unittest.mock import MagicMock, patch

# Import the system (these don't need GPU)
import sys
sys.path.insert(0, "..")
from self_healing.core import (
    HealingConfig,
    HealingActions,
    SelfHealingCallback,
    SelfHealingTrainer,
    ZClip,
    FailureType,
    FAILURE_RECIPES,
)


class TestHealingConfig:
    """Covers HealingConfig defaults, dict round-tripping and presets."""

    def test_default_values(self):
        """A config built with no arguments exposes the expected defaults."""
        defaults = HealingConfig()
        expected = {
            "nan_patience": 3,
            "loss_spike_factor": 5.0,
            "max_recovery_attempts": 5,
        }
        for field_name, value in expected.items():
            assert getattr(defaults, field_name) == value
        # Identity check on purpose: the flag must be the bool True itself.
        assert defaults.zclip_enabled is True

    def test_serialization_roundtrip(self):
        """Overridden fields survive to_dict() followed by from_dict()."""
        original = HealingConfig(nan_patience=10, zclip_z_threshold=2.5)
        restored = HealingConfig.from_dict(original.to_dict())
        assert (restored.nan_patience, restored.zclip_z_threshold) == (10, 2.5)

    def test_aggressive_preset(self):
        """The aggressive() preset uses the values asserted below."""
        preset = HealingConfig.aggressive()
        observed = (
            preset.nan_patience,
            preset.loss_spike_factor,
            preset.max_recovery_attempts,
        )
        assert observed == (1, 3.0, 10)

    def test_conservative_preset(self):
        """The conservative() preset uses the values asserted below."""
        preset = HealingConfig.conservative()
        assert (preset.nan_patience, preset.max_recovery_attempts) == (10, 2)


class TestZClip:
    """Exercises the ZClip adaptive gradient clipper."""

    def test_initial_state(self):
        """Before any update there are no statistics and no recorded clips."""
        clipper = ZClip(z_threshold=3.0, ema_decay=0.99)
        assert clipper.clip_count == 0
        assert clipper.mean is None
        assert clipper.std is None

    def test_first_update(self):
        """The first value passes through and seeds mean/std."""
        clipper = ZClip()
        first_value = 5.0
        returned = clipper.update_and_clip(first_value)
        assert returned == 5.0
        assert (clipper.mean, clipper.std) == (5.0, 0.0)

    def test_no_clip_within_threshold(self):
        """A small deviation after a steady run is returned unchanged."""
        clipper = ZClip(z_threshold=3.0, ema_decay=0.5)
        steady_value = 5.0
        # Feed a constant value repeatedly so the running statistics settle.
        for _ in range(20):
            clipper.update_and_clip(steady_value)
        # Nudge slightly above the steady value.
        nudged = clipper.update_and_clip(6.0)
        assert nudged == 6.0  # returned as-is, i.e. not clipped
        assert clipper.clip_count == 0

    def test_clip_on_spike(self):
        """A huge outlier after a steady run is reduced and counted."""
        clipper = ZClip(z_threshold=2.0, ema_decay=0.9)
        steady_value = 5.0
        for _ in range(50):
            clipper.update_and_clip(steady_value)
        spike = 100.0
        clipped = clipper.update_and_clip(spike)
        assert clipped < 100.0  # the spike was pulled down
        assert clipper.clip_count == 1

    def test_state_serialization(self):
        """state_dict() exposes the expected keys and can be reloaded."""
        source = ZClip()
        for value in (5.0, 10.0):
            source.update_and_clip(value)
        snapshot = source.state_dict()
        for key in ("mean", "std", "clip_count"):
            assert key in snapshot

        target = ZClip()
        target.load_state_dict(snapshot)
        assert target.mean == source.mean
        assert target.clip_count == source.clip_count


class TestFailureTaxonomy:
    """Checks the FAILURE_RECIPES table against FailureType."""

    def test_all_failures_have_recipes(self):
        """Every FailureType has a recipe with the mandatory keys."""
        required_keys = ("diagnosis", "actions", "severity")
        allowed_severities = ("error", "warn")
        for failure_type in FailureType:
            assert failure_type in FAILURE_RECIPES
            entry = FAILURE_RECIPES[failure_type]
            for key in required_keys:
                assert key in entry
            assert entry["severity"] in allowed_severities

    def test_nan_loss_actions(self):
        """The NAN_LOSS recipe rolls back and halves the learning rate."""
        nan_actions = FAILURE_RECIPES[FailureType.NAN_LOSS]["actions"]
        for expected in ("rollback_checkpoint", "halve_learning_rate"):
            assert expected in nan_actions

    def test_oom_actions(self):
        """The OOM recipe lists the three actions asserted below."""
        oom_actions = FAILURE_RECIPES[FailureType.OOM]["actions"]
        for expected in (
            "halve_batch_size",
            "enable_gradient_checkpointing",
            "clear_cache",
        ):
            assert expected in oom_actions


class TestSelfHealingCallback:
    """Exercises SelfHealingCallback construction, interface and state I/O."""

    def setup_method(self):
        """Build the shared config used by every test in this class."""
        settings = {
            "nan_patience": 3,
            "loss_spike_factor": 5.0,
            "divergence_patience": 10,
            # ZClip is switched off to keep these tests simpler.
            "zclip_enabled": False,
        }
        self.config = HealingConfig(**settings)

    def test_initial_state(self):
        """A new callback starts with zeroed counters and no loss history."""
        callback = SelfHealingCallback(self.config)
        counters = (
            callback.nan_count,
            callback.recovery_attempts,
            callback.lr_reductions,
        )
        assert counters == (0, 0, 0)
        assert len(callback.loss_history) == 0

    def test_callbacks_have_required_methods(self):
        """All TrainerCallback methods should be present."""
        required_hooks = (
            "on_train_begin",
            "on_step_end",
            "on_log",
            "on_evaluate",
            "on_exception",
            "on_train_end",
        )
        callback = SelfHealingCallback(self.config)
        absent = [hook for hook in required_hooks if not hasattr(callback, hook)]
        assert not absent

    def test_state_serialization(self):
        """Counters exported by get_state() are restored by load_state()."""
        source = SelfHealingCallback(self.config)
        source.nan_count = 5
        source.increasing_loss_count = 20
        source.recovery_attempts = 2
        snapshot = source.get_state()
        assert snapshot["nan_count"] == 5
        assert snapshot["recovery_attempts"] == 2

        target = SelfHealingCallback(self.config)
        target.load_state(snapshot)
        assert (target.nan_count, target.recovery_attempts) == (5, 2)


class TestHealingActions:
    """Tests for HealingActions recovery logic.

    Each test drives one named action through
    ``HealingActions._apply_single`` and checks its effect on the
    ``TrainingArguments`` object, the callback counters, or the returned
    message.
    """

    def setup_method(self):
        """Create a fresh HealingConfig before every test."""
        self.config = HealingConfig(
            lr_reduce_factor=0.5,
            batch_reduce_factor=0.5,
            max_lr_reductions=4,
            max_batch_reductions=3,
        )

    @staticmethod
    def _make_args(**overrides):
        """Return TrainingArguments built from the shared baseline.

        Keyword arguments in *overrides* replace or extend the baseline
        values. ``transformers`` is imported at function scope, matching
        the function-scope imports the individual tests used.
        """
        from transformers import TrainingArguments

        settings = {
            "output_dir": "/tmp",
            "learning_rate": 1e-4,
            "per_device_train_batch_size": 4,
        }
        settings.update(overrides)
        return TrainingArguments(**settings)

    def _make_actions(self, **callback_state):
        """Return ``(callback, actions)`` wired to ``self.config``.

        Attributes given in *callback_state* are set on the callback
        before HealingActions is constructed, so the order of operations
        matches tests that pre-seed counters.
        """
        callback = SelfHealingCallback(self.config)
        for attribute, value in callback_state.items():
            setattr(callback, attribute, value)
        return callback, HealingActions(self.config, callback)

    def test_halve_learning_rate(self):
        """Halving updates the LR, bumps the counter and reports the new LR."""
        args = self._make_args()
        callback, actions = self._make_actions()
        message = actions._apply_single("halve_learning_rate", args, {})
        # approx() keeps the check independent of float rounding details.
        assert args.learning_rate == pytest.approx(5e-5)
        assert callback.lr_reductions == 1
        assert "5.00e-05" in message

    def test_lr_reduction_limit(self):
        """With the reduction budget used up, the message mentions MAX."""
        args = self._make_args()
        # lr_reductions equals max_lr_reductions from setup_method.
        _, actions = self._make_actions(lr_reductions=4)
        message = actions._apply_single("halve_learning_rate", args, {})
        assert "MAX" in message

    def test_halve_batch_size_preserves_effective(self):
        """Halving the batch doubles accumulation steps to compensate."""
        args = self._make_args(
            per_device_train_batch_size=8,
            gradient_accumulation_steps=1,
        )
        _, actions = self._make_actions()
        actions._apply_single("halve_batch_size", args, {})
        assert args.per_device_train_batch_size == 4
        assert args.gradient_accumulation_steps == 2  # Effective batch preserved

    def test_enable_gradient_checkpointing(self):
        """The action flips gradient_checkpointing on and says so."""
        args = self._make_args(gradient_checkpointing=False)
        _, actions = self._make_actions()
        message = actions._apply_single(
            "enable_gradient_checkpointing", args, {}
        )
        assert args.gradient_checkpointing is True
        assert "Enabled" in message

    def test_exponential_backoff(self):
        """The backoff action reports that it waited."""
        args = self._make_args()
        self.config.api_retry_base_delay = 0.01  # Fast for tests
        _, actions = self._make_actions(recovery_attempts=1)
        message = actions._apply_single("exponential_backoff", args, {})
        assert "Waited" in message


if __name__ == "__main__":
    # Hand pytest's exit status to the shell so that running this file
    # directly exits non-zero when a test fails.
    sys.exit(pytest.main([__file__, "-v"]))