ScottzillaSystems commited on
Commit
880bd2d
·
verified ·
1 Parent(s): 5e55ab0

Upload tests/test_core.py

Browse files
Files changed (1) hide show
  1. tests/test_core.py +255 -0
tests/test_core.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Unit tests for Self-Healing Training System.
4
+
5
+ Run: pytest tests/ -v
6
+ """
7
+ import pytest
8
+ import torch
9
+ import math
10
+ from dataclasses import asdict
11
+ from unittest.mock import MagicMock, patch
12
+
13
+ # Import the system (these don't need GPU)
14
+ import sys
15
+ sys.path.insert(0, "..")
16
+ from self_healing.core import (
17
+ HealingConfig,
18
+ HealingActions,
19
+ SelfHealingCallback,
20
+ SelfHealingTrainer,
21
+ ZClip,
22
+ FailureType,
23
+ FAILURE_RECIPES,
24
+ )
25
+
26
+
27
+ class TestHealingConfig:
28
+ """Tests for HealingConfig."""
29
+
30
+ def test_default_values(self):
31
+ config = HealingConfig()
32
+ assert config.nan_patience == 3
33
+ assert config.loss_spike_factor == 5.0
34
+ assert config.zclip_enabled is True
35
+ assert config.max_recovery_attempts == 5
36
+
37
+ def test_serialization_roundtrip(self):
38
+ config = HealingConfig(nan_patience=10, zclip_z_threshold=2.5)
39
+ d = config.to_dict()
40
+ config2 = HealingConfig.from_dict(d)
41
+ assert config2.nan_patience == 10
42
+ assert config2.zclip_z_threshold == 2.5
43
+
44
+ def test_aggressive_preset(self):
45
+ config = HealingConfig.aggressive()
46
+ assert config.nan_patience == 1
47
+ assert config.loss_spike_factor == 3.0
48
+ assert config.max_recovery_attempts == 10
49
+
50
+ def test_conservative_preset(self):
51
+ config = HealingConfig.conservative()
52
+ assert config.nan_patience == 10
53
+ assert config.max_recovery_attempts == 2
54
+
55
+
56
+ class TestZClip:
57
+ """Tests for ZClip adaptive gradient clipping."""
58
+
59
+ def test_initial_state(self):
60
+ zclip = ZClip(z_threshold=3.0, ema_decay=0.99)
61
+ assert zclip.mean is None
62
+ assert zclip.std is None
63
+ assert zclip.clip_count == 0
64
+
65
+ def test_first_update(self):
66
+ zclip = ZClip()
67
+ result = zclip.update_and_clip(5.0)
68
+ assert result == 5.0
69
+ assert zclip.mean == 5.0
70
+ assert zclip.std == 0.0
71
+
72
+ def test_no_clip_within_threshold(self):
73
+ zclip = ZClip(z_threshold=3.0, ema_decay=0.5)
74
+ # Stabilize at 5.0
75
+ for _ in range(20):
76
+ zclip.update_and_clip(5.0)
77
+ # Small perturbation
78
+ result = zclip.update_and_clip(6.0)
79
+ assert result == 6.0 # No clip
80
+ assert zclip.clip_count == 0
81
+
82
+ def test_clip_on_spike(self):
83
+ zclip = ZClip(z_threshold=2.0, ema_decay=0.9)
84
+ # Stabilize
85
+ for _ in range(50):
86
+ zclip.update_and_clip(5.0)
87
+ # Massive spike
88
+ result = zclip.update_and_clip(100.0)
89
+ assert result < 100.0 # Was clipped
90
+ assert zclip.clip_count == 1
91
+
92
+ def test_state_serialization(self):
93
+ zclip = ZClip()
94
+ zclip.update_and_clip(5.0)
95
+ zclip.update_and_clip(10.0)
96
+ state = zclip.state_dict()
97
+ assert "mean" in state
98
+ assert "std" in state
99
+ assert "clip_count" in state
100
+
101
+ zclip2 = ZClip()
102
+ zclip2.load_state_dict(state)
103
+ assert zclip2.mean == zclip.mean
104
+ assert zclip2.clip_count == zclip.clip_count
105
+
106
+
107
+ class TestFailureTaxonomy:
108
+ """Tests for failure taxonomy."""
109
+
110
+ def test_all_failures_have_recipes(self):
111
+ for failure in FailureType:
112
+ assert failure in FAILURE_RECIPES
113
+ recipe = FAILURE_RECIPES[failure]
114
+ assert "diagnosis" in recipe
115
+ assert "actions" in recipe
116
+ assert "severity" in recipe
117
+ assert recipe["severity"] in ("error", "warn")
118
+
119
+ def test_nan_loss_actions(self):
120
+ recipe = FAILURE_RECIPES[FailureType.NAN_LOSS]
121
+ assert "rollback_checkpoint" in recipe["actions"]
122
+ assert "halve_learning_rate" in recipe["actions"]
123
+
124
+ def test_oom_actions(self):
125
+ recipe = FAILURE_RECIPES[FailureType.OOM]
126
+ assert "halve_batch_size" in recipe["actions"]
127
+ assert "enable_gradient_checkpointing" in recipe["actions"]
128
+ assert "clear_cache" in recipe["actions"]
129
+
130
+
131
+ class TestSelfHealingCallback:
132
+ """Tests for SelfHealingCallback detection logic."""
133
+
134
+ def setup_method(self):
135
+ self.config = HealingConfig(
136
+ nan_patience=3,
137
+ loss_spike_factor=5.0,
138
+ divergence_patience=10,
139
+ zclip_enabled=False, # Disable for simpler tests
140
+ )
141
+
142
+ def test_initial_state(self):
143
+ cb = SelfHealingCallback(self.config)
144
+ assert cb.nan_count == 0
145
+ assert cb.recovery_attempts == 0
146
+ assert cb.lr_reductions == 0
147
+ assert len(cb.loss_history) == 0
148
+
149
+ def test_callbacks_have_required_methods(self):
150
+ """All TrainerCallback methods should be present."""
151
+ cb = SelfHealingCallback(self.config)
152
+ for method in [
153
+ "on_train_begin", "on_step_end", "on_log",
154
+ "on_evaluate", "on_exception", "on_train_end",
155
+ ]:
156
+ assert hasattr(cb, method)
157
+
158
+ def test_state_serialization(self):
159
+ cb = SelfHealingCallback(self.config)
160
+ cb.nan_count = 5
161
+ cb.increasing_loss_count = 20
162
+ cb.recovery_attempts = 2
163
+ state = cb.get_state()
164
+ assert state["nan_count"] == 5
165
+ assert state["recovery_attempts"] == 2
166
+
167
+ cb2 = SelfHealingCallback(self.config)
168
+ cb2.load_state(state)
169
+ assert cb2.nan_count == 5
170
+ assert cb2.recovery_attempts == 2
171
+
172
+
173
+ class TestHealingActions:
174
+ """Tests for HealingActions recovery logic."""
175
+
176
+ def setup_method(self):
177
+ self.config = HealingConfig(
178
+ lr_reduce_factor=0.5,
179
+ batch_reduce_factor=0.5,
180
+ max_lr_reductions=4,
181
+ max_batch_reductions=3,
182
+ )
183
+
184
+ def test_halve_learning_rate(self):
185
+ from transformers import TrainingArguments
186
+ args = TrainingArguments(
187
+ output_dir="/tmp",
188
+ learning_rate=1e-4,
189
+ per_device_train_batch_size=4,
190
+ )
191
+ cb = SelfHealingCallback(self.config)
192
+ actions = HealingActions(self.config, cb)
193
+ result = actions._apply_single("halve_learning_rate", args, {})
194
+ assert args.learning_rate == 5e-5
195
+ assert cb.lr_reductions == 1
196
+ assert "5.00e-05" in result
197
+
198
+ def test_lr_reduction_limit(self):
199
+ from transformers import TrainingArguments
200
+ args = TrainingArguments(
201
+ output_dir="/tmp",
202
+ learning_rate=1e-4,
203
+ per_device_train_batch_size=4,
204
+ )
205
+ cb = SelfHealingCallback(self.config)
206
+ cb.lr_reductions = 4 # Already at max
207
+ actions = HealingActions(self.config, cb)
208
+ result = actions._apply_single("halve_learning_rate", args, {})
209
+ assert "MAX" in result
210
+
211
+ def test_halve_batch_size_preserves_effective(self):
212
+ from transformers import TrainingArguments
213
+ args = TrainingArguments(
214
+ output_dir="/tmp",
215
+ per_device_train_batch_size=8,
216
+ gradient_accumulation_steps=1,
217
+ learning_rate=1e-4,
218
+ )
219
+ cb = SelfHealingCallback(self.config)
220
+ actions = HealingActions(self.config, cb)
221
+ result = actions._apply_single("halve_batch_size", args, {})
222
+ assert args.per_device_train_batch_size == 4
223
+ assert args.gradient_accumulation_steps == 2 # Effective batch preserved
224
+
225
+ def test_enable_gradient_checkpointing(self):
226
+ from transformers import TrainingArguments
227
+ args = TrainingArguments(
228
+ output_dir="/tmp",
229
+ learning_rate=1e-4,
230
+ per_device_train_batch_size=4,
231
+ )
232
+ args.gradient_checkpointing = False
233
+ cb = SelfHealingCallback(self.config)
234
+ actions = HealingActions(self.config, cb)
235
+ result = actions._apply_single("enable_gradient_checkpointing", args, {})
236
+ assert args.gradient_checkpointing is True
237
+ assert "Enabled" in result
238
+
239
+ def test_exponential_backoff(self):
240
+ from transformers import TrainingArguments
241
+ args = TrainingArguments(
242
+ output_dir="/tmp",
243
+ learning_rate=1e-4,
244
+ per_device_train_batch_size=4,
245
+ )
246
+ self.config.api_retry_base_delay = 0.01 # Fast for tests
247
+ cb = SelfHealingCallback(self.config)
248
+ cb.recovery_attempts = 1
249
+ actions = HealingActions(self.config, cb)
250
+ result = actions._apply_single("exponential_backoff", args, {})
251
+ assert "Waited" in result
252
+
253
+
254
+ if __name__ == "__main__":
255
+ pytest.main([__file__, "-v"])