gaurv007 commited on
Commit
344a111
·
verified ·
1 Parent(s): f212716

Upload tests/test_pipeline.py

Browse files
Files changed (1) hide show
  1. tests/test_pipeline.py +45 -6
tests/test_pipeline.py CHANGED
@@ -25,7 +25,7 @@ class TestProvenTemplates:
25
  "test_field", "test", 1.0, 0, "Test field", "Test",
26
  SignConvention.LONG_HIGH, DatasetTier.TIER1
27
  )
28
- expr = generate_alpha15_variant(field, group_key="subindustry")
29
  assert "ts_decay_linear(" in expr
30
  assert "group_neutralize(" in expr
31
  assert "ts_rank(" in expr
@@ -37,10 +37,20 @@ class TestProvenTemplates:
37
  "test_field", "test", 1.0, 0, "Test field", "Test",
38
  SignConvention.LONG_LOW, DatasetTier.TIER1
39
  )
40
- expr = generate_alpha15_variant(field, group_key="subindustry")
41
  # LONG_LOW should prefix the field with minus
42
  assert "-zscore" in expr
43
 
 
 
 
 
 
 
 
 
 
 
44
  def test_batch_generation(self):
45
  batch = generate_batch_from_proven_templates(count=3)
46
  assert len(batch) <= 3
@@ -101,6 +111,16 @@ class TestLint:
101
  assert "pv13_customergraphrank_auth_rank" in FIELD_INDEX
102
  assert "pv13_customergraphrank_page_rank" in FIELD_INDEX
103
 
 
 
 
 
 
 
 
 
 
 
104
  def test_no_typos_in_field_registry(self):
105
  """Verify no 'ustomer' typos remain (missing 'c' in customer)."""
106
  for fid in FIELD_INDEX:
@@ -256,10 +276,10 @@ class TestOperatorsCSV:
256
  assert op_name in csv_ops, f"Operator '{op_name}' in lint.py but not in operators.csv"
257
 
258
 
259
- @pytest.mark.asyncio
260
  class TestAsyncPipeline:
261
- async def test_proven_template_pipeline(self):
262
  """Test running a batch in proven template mode."""
 
263
  from alpha_factory.orchestration.pipeline import AlphaPipeline
264
  config = load_config()
265
  config.batch_size = 3
@@ -267,7 +287,11 @@ class TestAsyncPipeline:
267
  config.enable_brain_client = False
268
 
269
  pipeline = AlphaPipeline(config)
270
- result = await pipeline.run_batch(3)
 
 
 
 
271
 
272
  # Pipeline now has local sim + checklist gates — results may be killed
273
  assert "promoted" in result or "iterated" in result or "killed" in result
@@ -304,7 +328,22 @@ class TestThemeSampler:
304
  from alpha_factory.deterministic.theme_sampler import pick_theme
305
  existing = ["momentum", "value", "quality"]
306
  theme = pick_theme(existing, [], [])
307
- assert theme not in existing or len(existing) >= 20 # May repeat if exhausted
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
 
310
  class TestWQClient:
 
25
  "test_field", "test", 1.0, 0, "Test field", "Test",
26
  SignConvention.LONG_HIGH, DatasetTier.TIER1
27
  )
28
+ expr = generate_alpha15_variant(field, group_key="subindustry", decay=5)
29
  assert "ts_decay_linear(" in expr
30
  assert "group_neutralize(" in expr
31
  assert "ts_rank(" in expr
 
37
  "test_field", "test", 1.0, 0, "Test field", "Test",
38
  SignConvention.LONG_LOW, DatasetTier.TIER1
39
  )
40
+ expr = generate_alpha15_variant(field, group_key="subindustry", decay=5)
41
  # LONG_LOW should prefix the field with minus
42
  assert "-zscore" in expr
43
 
44
+ def test_alpha15_custom_decay(self):
45
+ field = BrainField(
46
+ "test_field", "test", 1.0, 0, "Test field", "Test",
47
+ SignConvention.LONG_HIGH, DatasetTier.TIER1
48
+ )
49
+ expr = generate_alpha15_variant(field, group_key="subindustry", decay=10)
50
+ assert "ts_decay_linear(" in expr
51
+ assert ", 10)" in expr
52
+ assert ", 5)" not in expr
53
+
54
  def test_batch_generation(self):
55
  batch = generate_batch_from_proven_templates(count=3)
56
  assert len(batch) <= 3
 
111
  assert "pv13_customergraphrank_auth_rank" in FIELD_INDEX
112
  assert "pv13_customergraphrank_page_rank" in FIELD_INDEX
113
 
114
+ def test_mutate_vol_scale_with_uppercase_fields(self):
115
+ """Test vol_scale mutation handles uppercase field IDs (e.g., mdl77_2GlobalDev...)."""
116
+ # Use a realistic expression with uppercase in the field name
117
+ expr = "ts_decay_linear(group_neutralize(rank(ts_rank(mdl77_2GlobalDevField, 252)), subindustry), 5)"
118
+ from alpha_factory.deterministic.expression_mutator import mutate_vol_scale
119
+ variants = mutate_vol_scale(expr)
120
+ # Should produce at least one variant since the regex now handles [a-zA-Z]
121
+ assert len(variants) > 0, "vol_scale should match uppercase field IDs"
122
+ assert "ts_std(" in variants[0]["expression"], "vol_scale should wrap field with ts_std"
123
+
124
  def test_no_typos_in_field_registry(self):
125
  """Verify no 'ustomer' typos remain (missing 'c' in customer)."""
126
  for fid in FIELD_INDEX:
 
276
  assert op_name in csv_ops, f"Operator '{op_name}' in lint.py but not in operators.csv"
277
 
278
 
 
279
  class TestAsyncPipeline:
280
+ def test_proven_template_pipeline(self):
281
  """Test running a batch in proven template mode."""
282
+ import asyncio
283
  from alpha_factory.orchestration.pipeline import AlphaPipeline
284
  config = load_config()
285
  config.batch_size = 3
 
287
  config.enable_brain_client = False
288
 
289
  pipeline = AlphaPipeline(config)
290
+
291
+ async def _run():
292
+ return await pipeline.run_batch(3)
293
+
294
+ result = asyncio.run(_run())
295
 
296
  # Pipeline now has local sim + checklist gates — results may be killed
297
  assert "promoted" in result or "iterated" in result or "killed" in result
 
328
  from alpha_factory.deterministic.theme_sampler import pick_theme
329
  existing = ["momentum", "value", "quality"]
330
  theme = pick_theme(existing, [], [])
331
+ # Theme should be from the defined THEME_FIELDS
332
+ assert theme in [
333
+ "earnings_surprise_momentum", "earnings_quality_signaling",
334
+ "asset_growth_anomaly", "forward_value_composite",
335
+ "liquidity_risk_premium", "multi_factor_momentum",
336
+ "news_reaction_drift", "analyst_guidance_revision",
337
+ "options_sentiment_pcr", "supply_chain_network",
338
+ "social_contrarian", "geographic_exposure",
339
+ ]
340
+
341
+ def test_pick_theme_all_dead_returns_alive(self):
342
+ """If all themes are dead, pick_theme must still return a valid theme."""
343
+ from alpha_factory.deterministic.theme_sampler import pick_theme, THEME_FIELDS
344
+ all_themes = list(THEME_FIELDS.keys())
345
+ theme = pick_theme([], [], dead_themes=all_themes)
346
+ assert theme in THEME_FIELDS, "Should return a valid theme even when all are dead"
347
 
348
 
349
  class TestWQClient: