Upload tests/test_pipeline.py
Browse files- tests/test_pipeline.py +294 -0
tests/test_pipeline.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests for the AlphaFactory pipeline.
|
| 3 |
+
Tests both proven template mode and LLM mode (with mocked LLM).
|
| 4 |
+
"""
|
| 5 |
+
import pytest
|
| 6 |
+
import asyncio
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import tempfile
|
| 9 |
+
import shutil
|
| 10 |
+
|
| 11 |
+
from alpha_factory.config import Config, load_config
|
| 12 |
+
from alpha_factory.schemas import Blueprint, Component, Neutralization, AnomalyTag, BrainMetrics, Verdict
|
| 13 |
+
from alpha_factory.data.brain_fields import FIELD_INDEX, BrainField, SignConvention, DatasetTier
|
| 14 |
+
from alpha_factory.deterministic.proven_templates import generate_batch_from_proven_templates, generate_alpha15_variant
|
| 15 |
+
from alpha_factory.deterministic.expression_mutator import generate_mutations, mutate_decay, mutate_horizon, mutate_neutralization
|
| 16 |
+
from alpha_factory.deterministic.lint import lint, quick_dedup_hash
|
| 17 |
+
from alpha_factory.deterministic.fitness import compute_fitness, would_pass_brain
|
| 18 |
+
from alpha_factory.infra.winner_memory import WinnerMemory
|
| 19 |
+
from alpha_factory.data.brain_groups import PRODUCTION_GROUPS, get_group_for_expression
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TestProvenTemplates:
    """Tests for the proven-template expression generators."""

    @staticmethod
    def _make_field(sign):
        # Minimal BrainField fixture with the requested sign convention.
        return BrainField(
            "test_field", "test", 1.0, 0, "Test field", "Test",
            sign, DatasetTier.TIER1,
        )

    def test_alpha15_generates_valid_expression(self):
        expr = generate_alpha15_variant(
            self._make_field(SignConvention.LONG_HIGH), group_key="subindustry"
        )
        # Expected alpha-15 shape: decay(neutralize(rank(field))).
        for fragment in ("ts_decay_linear(", "group_neutralize(", "ts_rank(", "test_field"):
            assert fragment in expr
        assert expr.startswith("ts_decay_linear(")

    def test_alpha15_long_low_sign(self):
        expr = generate_alpha15_variant(
            self._make_field(SignConvention.LONG_LOW), group_key="subindustry"
        )
        # LONG_LOW should prefix the field with minus
        assert "-zscore" in expr

    def test_batch_generation(self):
        batch = generate_batch_from_proven_templates(count=3)
        assert len(batch) <= 3
        for alpha in batch:
            assert "expression" in alpha
            assert "field_id" in alpha

    def test_all_generated_expressions_pass_lint(self):
        for alpha in generate_batch_from_proven_templates(count=10):
            result = lint(alpha["expression"])
            assert result.passed, f"Lint failed for {alpha['template']}: {result.errors}"

    def test_batch_no_duplicate_fields(self):
        batch = generate_batch_from_proven_templates(count=20)
        field_ids = [entry["field_id"] for entry in batch]
        assert len(field_ids) == len(set(field_ids)), "Duplicate fields in batch"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TestExpressionMutator:
    """Tests for the decay/horizon/neutralization mutation helpers."""

    def test_mutate_decay(self):
        expr = "ts_decay_linear(group_neutralize(rank(close), subindustry), 5)"
        variants = mutate_decay(expr, 5)
        assert len(variants) > 0
        for variant in variants:
            assert "ts_decay_linear(" in variant["expression"]
            # The original decay value must never be re-emitted.
            assert variant["decay"] != 5

    def test_mutate_horizon(self):
        expr = "ts_decay_linear(group_neutralize(zscore(ts_rank(volume, 252)), subindustry), 5)"
        assert len(mutate_horizon(expr)) > 0

    def test_mutate_neutralization(self):
        expr = "ts_decay_linear(group_neutralize(rank(close), subindustry), 5)"
        variants = mutate_neutralization(expr)
        if variants:
            # Every variant should have moved away from the original group.
            assert not any("subindustry" in v["expression"] for v in variants)

    def test_generate_mutations_comprehensive(self):
        expr = "ts_decay_linear(group_neutralize(zscore(ts_rank(volume, 252)), subindustry), 5)"
        variants = generate_mutations(expr, decay=5)
        assert len(variants) >= 3  # Should get decay, horizon, neutralization at minimum
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TestLint:
    """Lint checks plus field-registry sanity tests."""

    def test_alpha15_expression_passes_lint(self):
        field = BrainField(
            "standardized_unexpected_earnings_2", "model77", 0.92, 0,
            "SUE", "Model", SignConvention.LONG_HIGH, DatasetTier.TIER1,
        )
        result = lint(generate_alpha15_variant(field))
        assert result.passed, f"Lint errors: {result.errors}"
        assert len(result.warnings) <= 2

    def test_pv13_field_id_in_expression(self):
        """Test that corrected pv13_customer field IDs are accepted."""
        from alpha_factory.data.brain_fields import FIELD_INDEX
        expected = (
            "pv13_customergraphrank_auth_rank",
            "pv13_customergraphrank_page_rank",
        )
        for field_id in expected:
            assert field_id in FIELD_INDEX

    def test_no_typos_in_field_registry(self):
        """Verify no 'ustomer' typos remain."""
        for fid in FIELD_INDEX:
            assert "_ustomergraphrank" not in fid, f"Typo found in field: {fid}"

    def test_valid_brain_fields_not_in_fake_list(self):
        """Verify common BRAIN fields are not in FAKE_FIELDS."""
        from alpha_factory.cleanup import FAKE_FIELDS
        valid_fields = {"close", "high", "low", "volume", "vwap", "open",
                        "returns", "ts_returns", "bid_ask_spread", "volatility"}
        for f in valid_fields:
            assert f not in FAKE_FIELDS, f"Valid BRAIN field '{f}' is in FAKE_FIELDS"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TestConfig:
    """Sanity checks on configuration loading and defaults."""

    def test_config_loads(self):
        config = load_config()
        assert config.batch_size >= 1
        assert config.kill.daily_llm_token_budget > 0

    def test_config_paths_created(self):
        config = load_config()
        # Directories may be created lazily, so accept either the path itself
        # or its parent existing.
        assert config.paths.data.exists() or config.paths.data.parent.exists()
        assert config.paths.factor_store.exists() or config.paths.factor_store.parent.exists()

    def test_config_proven_templates(self):
        # The original test assigned True and then asserted the value it had
        # just assigned, which could never fail. Instead verify the attribute
        # exists on the config object and round-trips both boolean values.
        config = load_config()
        assert hasattr(config, "use_proven_templates")
        config.use_proven_templates = True
        assert config.use_proven_templates is True
        config.use_proven_templates = False
        assert config.use_proven_templates is False
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class TestWinnerMemory:
    """Tests for the DuckDB-backed winner/failure memory store."""

    def test_winner_memory_basic(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            wm = WinnerMemory(Path(tmpdir) / "test.duckdb")
            try:
                wm.record_winner("test_field", "alpha15", "subindustry", 5, 1.5, "momentum")
                winners = wm.get_winning_fields(min_sharpe=1.0)
                assert "test_field" in winners

                # Need 3+ failures with no wins for a field to be in failed_fields
                wm.record_failure("bad_field", "alpha6", "low_sharpe", "hash1")
                wm.record_failure("bad_field", "alpha6", "high_turnover", "hash2")
                wm.record_failure("bad_field", "alpha6", "flat_line", "hash3")
                failed = wm.get_failed_fields()
                assert "bad_field" in failed
            finally:
                # Close even if an assertion fails; otherwise the open database
                # handle leaks and TemporaryDirectory cleanup can fail on
                # platforms that forbid deleting open files (e.g. Windows).
                wm.close()

    def test_iteration_queue(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            wm = WinnerMemory(Path(tmpdir) / "test.duckdb")
            try:
                wm.queue_for_iteration("alpha_1", "rank(close)", 1.0, 0.3, "increase_decay")
                queue = wm.get_iteration_queue(limit=10)
                assert len(queue) == 1
                assert queue[0]["alpha_id"] == "alpha_1"

                wm.mark_iterated("alpha_1")
                queue = wm.get_iteration_queue(limit=10)
                assert len(queue) == 0
            finally:
                wm.close()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class TestBrainGroups:
    """Checks over the production grouping definitions."""

    def test_novel_groups_exist(self):
        assert len(PRODUCTION_GROUPS) > 0

    def test_group_coverage(self):
        for group in PRODUCTION_GROUPS:
            # Each production group needs broad coverage and a low alpha count.
            assert group.coverage >= 0.90, f"Group {group.id} coverage too low: {group.coverage}"
            assert group.alpha_count <= 30, f"Group {group.id} AC too high: {group.alpha_count}"

    def test_get_group_returns_string(self):
        group = get_group_for_expression(prefer_novel=True)
        assert isinstance(group, str)
        assert len(group) > 0
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class TestFitness:
    """Tests for fitness scoring and the BRAIN pass predictor."""

    @staticmethod
    def _metrics(**overrides):
        # Baseline "good" metrics; tests override fields to build bad cases.
        # Deduplicates the BrainMetrics boilerplate the original repeated.
        base = dict(
            alpha_id="test", sharpe_full=1.5, sharpe_is=1.6, sharpe_os=1.4,
            fitness=1.2, turnover=0.25, returns=0.1, max_drawdown=0.04,
            yearly_sharpe=[1.2, 1.5, 1.3, 1.4, 1.6], yearly_returns=[0.02] * 5,
        )
        base.update(overrides)
        return BrainMetrics(**base)

    def test_fitness_penalizes_high_correlation(self):
        metrics = self._metrics()
        low_corr = compute_fitness(metrics, max_corr_to_library=0.1, theme_novelty_score=0.5)
        high_corr = compute_fitness(metrics, max_corr_to_library=0.8, theme_novelty_score=0.5)
        assert low_corr > high_corr, "Lower correlation should give higher fitness"

    def test_would_pass_brain_comprehensive(self):
        # Use truthiness instead of `== True` / `== False` (PEP 8 / flake8 E712).
        result = would_pass_brain(self._metrics())
        assert result["overall_pass"]

        bad = self._metrics(
            sharpe_full=0.5, sharpe_is=0.6, sharpe_os=0.4,
            fitness=0.3, turnover=0.8, returns=-0.05, max_drawdown=0.15,
            yearly_sharpe=[-0.5, 0.1, -0.3, 0.2, 0.0], yearly_returns=[-0.01] * 5,
        )
        result = would_pass_brain(bad)
        assert not result["overall_pass"]
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class TestFieldRegistry:
    """Integrity checks over the BRAIN field registry."""

    def test_goldmine_fields_have_zero_ac(self):
        from alpha_factory.data.brain_fields import GOLDMINE_FIELDS
        for field in GOLDMINE_FIELDS:
            # Goldmine = untouched fields, so alpha_count must be zero.
            assert field.alpha_count == 0, f"Goldmine field {field.id} has AC={field.alpha_count}"

    def test_all_fields_unique(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS
        seen = set()
        for field in ALL_FIELDS:
            assert field.id not in seen, "Duplicate field IDs"
            seen.add(field.id)

    def test_field_index_complete(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS, FIELD_INDEX
        assert len(FIELD_INDEX) == len(ALL_FIELDS)

    def test_coverage_in_range(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS
        for field in ALL_FIELDS:
            assert 0.0 <= field.coverage <= 1.0, f"Field {field.id} coverage out of range"
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class TestOperatorsCSV:
    """Cross-checks data/operators.csv against the lint operator table."""

    def test_operator_arity_consistency(self):
        """Verify operators.csv arity matches lint.py OPERATOR_ARITY."""
        from alpha_factory.deterministic.lint import OPERATOR_ARITY
        import csv
        # NOTE: original redundantly re-imported Path here; the module-level
        # pathlib import already provides it.

        csv_path = Path("data/operators.csv")
        if not csv_path.exists():
            pytest.skip("operators.csv not found")

        # newline="" is the documented requirement for files passed to the csv
        # module; pin the encoding so the test does not depend on the locale.
        with csv_path.open(newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                op_name = row["name"].strip().lower()
                # NOTE(review): arity is read from the "level" column -- this
                # looks intentional but confirm the CSV schema.
                csv_arity = int(row["level"])

                lint_arity = OPERATOR_ARITY.get(op_name)
                if lint_arity is not None:
                    assert csv_arity == lint_arity, (
                        f"Operator {op_name}: CSV arity={csv_arity} != lint arity={lint_arity}"
                    )
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@pytest.mark.asyncio
class TestAsyncPipeline:
    """End-to-end pipeline run with external services disabled."""

    async def test_proven_template_pipeline(self):
        """Test running a batch in proven template mode."""
        from alpha_factory.orchestration.pipeline import AlphaPipeline
        config = load_config()
        config.batch_size = 3
        config.use_proven_templates = True
        config.enable_brain_client = False

        pipeline = AlphaPipeline(config)
        try:
            result = await pipeline.run_batch(3)

            assert "promoted" in result or "iterated" in result or "killed" in result
            total = result.get("promoted", 0) + result.get("iterated", 0) + result.get("killed", 0)
            assert total > 0, "Pipeline produced no results"
        finally:
            # The original skipped close() whenever an assertion failed,
            # leaking pipeline resources into subsequent tests.
            pipeline.close()
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class TestDedupHash:
    """The dedup hash must be stable and sensitive to each input."""

    # Shared baseline arguments: (expression, neutralization, decay).
    BASE = ("rank(close)", "sector", 5)

    def test_same_expression_same_hash(self):
        assert quick_dedup_hash(*self.BASE) == quick_dedup_hash(*self.BASE)

    def test_different_neutralization_different_hash(self):
        other = quick_dedup_hash("rank(close)", "industry", 5)
        assert quick_dedup_hash(*self.BASE) != other

    def test_different_decay_different_hash(self):
        other = quick_dedup_hash("rank(close)", "sector", 10)
        assert quick_dedup_hash(*self.BASE) != other
|