gaurv007 committed on
Commit
82fe9b4
·
verified ·
1 Parent(s): 494e9ca

Upload tests/test_pipeline.py

Browse files
Files changed (1) hide show
  1. tests/test_pipeline.py +294 -0
tests/test_pipeline.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration tests for the AlphaFactory pipeline.
3
+ Tests both proven template mode and LLM mode (with mocked LLM).
4
+ """
5
+ import pytest
6
+ import asyncio
7
+ from pathlib import Path
8
+ import tempfile
9
+ import shutil
10
+
11
+ from alpha_factory.config import Config, load_config
12
+ from alpha_factory.schemas import Blueprint, Component, Neutralization, AnomalyTag, BrainMetrics, Verdict
13
+ from alpha_factory.data.brain_fields import FIELD_INDEX, BrainField, SignConvention, DatasetTier
14
+ from alpha_factory.deterministic.proven_templates import generate_batch_from_proven_templates, generate_alpha15_variant
15
+ from alpha_factory.deterministic.expression_mutator import generate_mutations, mutate_decay, mutate_horizon, mutate_neutralization
16
+ from alpha_factory.deterministic.lint import lint, quick_dedup_hash
17
+ from alpha_factory.deterministic.fitness import compute_fitness, would_pass_brain
18
+ from alpha_factory.infra.winner_memory import WinnerMemory
19
+ from alpha_factory.data.brain_groups import PRODUCTION_GROUPS, get_group_for_expression
20
+
21
+
22
class TestProvenTemplates:
    """Unit tests for the proven-template expression generators."""

    def test_alpha15_generates_valid_expression(self):
        # A LONG_HIGH field should be wrapped in the standard alpha15 pipeline.
        sample = BrainField(
            "test_field", "test", 1.0, 0, "Test field", "Test",
            SignConvention.LONG_HIGH, DatasetTier.TIER1
        )
        expression = generate_alpha15_variant(sample, group_key="subindustry")
        for fragment in ("ts_decay_linear(", "group_neutralize(", "ts_rank(", "test_field"):
            assert fragment in expression
        assert expression.startswith("ts_decay_linear(")

    def test_alpha15_long_low_sign(self):
        sample = BrainField(
            "test_field", "test", 1.0, 0, "Test field", "Test",
            SignConvention.LONG_LOW, DatasetTier.TIER1
        )
        expression = generate_alpha15_variant(sample, group_key="subindustry")
        # LONG_LOW should prefix the field with minus
        assert "-zscore" in expression

    def test_batch_generation(self):
        generated = generate_batch_from_proven_templates(count=3)
        assert len(generated) <= 3
        assert all("expression" in item for item in generated)
        assert all("field_id" in item for item in generated)

    def test_all_generated_expressions_pass_lint(self):
        # Every template-produced expression must survive the linter unchanged.
        for alpha in generate_batch_from_proven_templates(count=10):
            result = lint(alpha["expression"])
            assert result.passed, f"Lint failed for {alpha['template']}: {result.errors}"

    def test_batch_no_duplicate_fields(self):
        generated = generate_batch_from_proven_templates(count=20)
        seen = [item["field_id"] for item in generated]
        assert len(seen) == len(set(seen)), "Duplicate fields in batch"
60
+
61
+
62
class TestExpressionMutator:
    """Unit tests for the deterministic expression mutators."""

    def test_mutate_decay(self):
        base = "ts_decay_linear(group_neutralize(rank(close), subindustry), 5)"
        produced = mutate_decay(base, 5)
        assert len(produced) > 0
        for variant in produced:
            # Every variant keeps the decay wrapper but changes the decay value.
            assert "ts_decay_linear(" in variant["expression"]
            assert variant["decay"] != 5

    def test_mutate_horizon(self):
        base = "ts_decay_linear(group_neutralize(zscore(ts_rank(volume, 252)), subindustry), 5)"
        assert len(mutate_horizon(base)) > 0

    def test_mutate_neutralization(self):
        base = "ts_decay_linear(group_neutralize(rank(close), subindustry), 5)"
        produced = mutate_neutralization(base)
        if produced:
            # Neutralization mutants must swap away from the original group.
            for variant in produced:
                assert "subindustry" not in variant["expression"]

    def test_generate_mutations_comprehensive(self):
        base = "ts_decay_linear(group_neutralize(zscore(ts_rank(volume, 252)), subindustry), 5)"
        produced = generate_mutations(base, decay=5)
        assert len(produced) >= 3  # Should get decay, horizon, neutralization at minimum
85
+
86
+
87
class TestLint:
    """Tests for the expression linter and field-registry hygiene."""

    def test_alpha15_expression_passes_lint(self):
        sue_field = BrainField(
            "standardized_unexpected_earnings_2", "model77", 0.92, 0,
            "SUE", "Model", SignConvention.LONG_HIGH, DatasetTier.TIER1
        )
        result = lint(generate_alpha15_variant(sue_field))
        assert result.passed, f"Lint errors: {result.errors}"
        assert len(result.warnings) <= 2

    def test_pv13_field_id_in_expression(self):
        """Test that corrected pv13_customer field IDs are accepted."""
        from alpha_factory.data.brain_fields import FIELD_INDEX
        corrected_ids = (
            "pv13_customergraphrank_auth_rank",
            "pv13_customergraphrank_page_rank",
        )
        for field_id in corrected_ids:
            assert field_id in FIELD_INDEX

    def test_no_typos_in_field_registry(self):
        """Verify no 'ustomer' typos remain."""
        for fid in FIELD_INDEX:
            assert "_ustomergraphrank" not in fid, f"Typo found in field: {fid}"

    def test_valid_brain_fields_not_in_fake_list(self):
        """Verify common BRAIN fields are not in FAKE_FIELDS."""
        from alpha_factory.cleanup import FAKE_FIELDS
        for f in ("close", "high", "low", "volume", "vwap", "open",
                  "returns", "ts_returns", "bid_ask_spread", "volatility"):
            assert f not in FAKE_FIELDS, f"Valid BRAIN field '{f}' is in FAKE_FIELDS"
116
+
117
+
118
class TestConfig:
    """Smoke tests for configuration loading and mutability."""

    def test_config_loads(self):
        cfg = load_config()
        assert cfg.batch_size >= 1
        assert cfg.kill.daily_llm_token_budget > 0

    def test_config_paths_created(self):
        cfg = load_config()
        # Either the directory itself or at least its parent must exist.
        for path in (cfg.paths.data, cfg.paths.factor_store):
            assert path.exists() or path.parent.exists()

    def test_config_proven_templates(self):
        cfg = load_config()
        cfg.use_proven_templates = True
        assert cfg.use_proven_templates
133
+
134
+
135
class TestWinnerMemory:
    """Tests for the DuckDB-backed winner/failure memory.

    Each test opens a WinnerMemory against a throwaway database and closes
    it in a ``finally`` block, so a failing assertion cannot leak the DuckDB
    handle (an open handle would also block TemporaryDirectory cleanup on
    Windows).
    """

    def test_winner_memory_basic(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.duckdb"
            wm = WinnerMemory(db_path)
            try:
                wm.record_winner("test_field", "alpha15", "subindustry", 5, 1.5, "momentum")
                winners = wm.get_winning_fields(min_sharpe=1.0)
                assert "test_field" in winners

                # Need 3+ failures with no wins for a field to be in failed_fields
                wm.record_failure("bad_field", "alpha6", "low_sharpe", "hash1")
                wm.record_failure("bad_field", "alpha6", "high_turnover", "hash2")
                wm.record_failure("bad_field", "alpha6", "flat_line", "hash3")
                failed = wm.get_failed_fields()
                assert "bad_field" in failed
            finally:
                # Always release the DB handle, even if an assert above failed.
                wm.close()

    def test_iteration_queue(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.duckdb"
            wm = WinnerMemory(db_path)
            try:
                wm.queue_for_iteration("alpha_1", "rank(close)", 1.0, 0.3, "increase_decay")
                queue = wm.get_iteration_queue(limit=10)
                assert len(queue) == 1
                assert queue[0]["alpha_id"] == "alpha_1"

                # Marking an alpha as iterated must remove it from the queue.
                wm.mark_iterated("alpha_1")
                queue = wm.get_iteration_queue(limit=10)
                assert len(queue) == 0
            finally:
                wm.close()
169
+
170
+
171
class TestBrainGroups:
    """Sanity checks on the curated production grouping fields."""

    def test_novel_groups_exist(self):
        assert len(PRODUCTION_GROUPS) > 0

    def test_group_coverage(self):
        # Every production group must be broad (>=90% coverage) and
        # not over-used (alpha count capped at 30).
        for g in PRODUCTION_GROUPS:
            assert g.coverage >= 0.90, f"Group {g.id} coverage too low: {g.coverage}"
            assert g.alpha_count <= 30, f"Group {g.id} AC too high: {g.alpha_count}"

    def test_get_group_returns_string(self):
        chosen = get_group_for_expression(prefer_novel=True)
        assert isinstance(chosen, str)
        assert len(chosen) > 0
184
+
185
+
186
class TestFitness:
    """Tests for the fitness scoring and the BRAIN-submission gate."""

    def test_fitness_penalizes_high_correlation(self):
        # Identical metrics; only the correlation-to-library input differs,
        # so any fitness gap is attributable to the correlation penalty.
        metrics = BrainMetrics(
            alpha_id="test", sharpe_full=1.5, sharpe_is=1.6, sharpe_os=1.4,
            fitness=1.2, turnover=0.25, returns=0.1, max_drawdown=0.04,
            yearly_sharpe=[1.2, 1.5, 1.3, 1.4, 1.6], yearly_returns=[0.02]*5,
        )
        low_corr = compute_fitness(metrics, max_corr_to_library=0.1, theme_novelty_score=0.5)
        high_corr = compute_fitness(metrics, max_corr_to_library=0.8, theme_novelty_score=0.5)
        assert low_corr > high_corr, "Lower correlation should give higher fitness"

    def test_would_pass_brain_comprehensive(self):
        # Strong, stable metrics should clear every gate.
        good = BrainMetrics(
            alpha_id="test", sharpe_full=1.5, sharpe_is=1.6, sharpe_os=1.4,
            fitness=1.2, turnover=0.25, returns=0.1, max_drawdown=0.04,
            yearly_sharpe=[1.2, 1.5, 1.3, 1.4, 1.6], yearly_returns=[0.02]*5,
        )
        result = would_pass_brain(good)
        # Truthiness assert instead of `== True` (pycodestyle E712).
        assert result["overall_pass"]

        # Weak/unstable metrics should be rejected.
        bad = BrainMetrics(
            alpha_id="test", sharpe_full=0.5, sharpe_is=0.6, sharpe_os=0.4,
            fitness=0.3, turnover=0.8, returns=-0.05, max_drawdown=0.15,
            yearly_sharpe=[-0.5, 0.1, -0.3, 0.2, 0.0], yearly_returns=[-0.01]*5,
        )
        result = would_pass_brain(bad)
        assert not result["overall_pass"]
213
+
214
+
215
class TestFieldRegistry:
    """Consistency checks over the static BRAIN field registry."""

    def test_goldmine_fields_have_zero_ac(self):
        from alpha_factory.data.brain_fields import GOLDMINE_FIELDS
        for entry in GOLDMINE_FIELDS:
            assert entry.alpha_count == 0, f"Goldmine field {entry.id} has AC={entry.alpha_count}"

    def test_all_fields_unique(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS
        field_ids = [entry.id for entry in ALL_FIELDS]
        assert len(field_ids) == len(set(field_ids)), "Duplicate field IDs"

    def test_field_index_complete(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS, FIELD_INDEX
        # The index must contain exactly one entry per registered field.
        assert len(FIELD_INDEX) == len(ALL_FIELDS)

    def test_coverage_in_range(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS
        for entry in ALL_FIELDS:
            assert 0.0 <= entry.coverage <= 1.0, f"Field {entry.id} coverage out of range"
234
+
235
+
236
class TestOperatorsCSV:
    """Cross-checks data/operators.csv against the linter's arity table."""

    def test_operator_arity_consistency(self):
        """Verify operators.csv arity matches lint.py OPERATOR_ARITY."""
        from alpha_factory.deterministic.lint import OPERATOR_ARITY
        import csv
        from pathlib import Path

        csv_path = Path("data/operators.csv")
        if not csv_path.exists():
            pytest.skip("operators.csv not found")

        # newline="" is required by the csv module docs for correct newline
        # handling; pin the encoding so parsing does not depend on the
        # platform's locale default.
        with open(csv_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                op_name = row["name"].strip().lower()
                # NOTE(review): arity appears to be stored in the "level"
                # column of the CSV — confirm against the data file.
                csv_arity = int(row["level"])

                if op_name in OPERATOR_ARITY:
                    lint_arity = OPERATOR_ARITY[op_name]
                    assert csv_arity == lint_arity, (
                        f"Operator {op_name}: CSV arity={csv_arity} != lint arity={lint_arity}"
                    )
258
+
259
+
260
@pytest.mark.asyncio
class TestAsyncPipeline:
    """Async integration test of the full pipeline (no BRAIN client)."""

    async def test_proven_template_pipeline(self):
        """Test running a batch in proven template mode."""
        from alpha_factory.orchestration.pipeline import AlphaPipeline
        config = load_config()
        config.batch_size = 3
        config.use_proven_templates = True
        config.enable_brain_client = False

        pipeline = AlphaPipeline(config)
        try:
            result = await pipeline.run_batch(3)

            # The batch summary must account for every alpha somewhere.
            assert "promoted" in result or "iterated" in result or "killed" in result
            total = result.get("promoted", 0) + result.get("iterated", 0) + result.get("killed", 0)
            assert total > 0, "Pipeline produced no results"
        finally:
            # Close even on assertion failure so a failing test does not
            # leak pipeline resources into subsequent tests.
            pipeline.close()
278
+
279
+
280
class TestDedupHash:
    """Tests that the dedup hash keys on expression, group, and decay."""

    def test_same_expression_same_hash(self):
        # Identical inputs must always hash identically.
        first = quick_dedup_hash("rank(close)", "sector", 5)
        second = quick_dedup_hash("rank(close)", "sector", 5)
        assert first == second

    def test_different_neutralization_different_hash(self):
        sector_hash = quick_dedup_hash("rank(close)", "sector", 5)
        industry_hash = quick_dedup_hash("rank(close)", "industry", 5)
        assert sector_hash != industry_hash

    def test_different_decay_different_hash(self):
        short_decay = quick_dedup_hash("rank(close)", "sector", 5)
        long_decay = quick_dedup_hash("rank(close)", "sector", 10)
        assert short_decay != long_decay