# alpha-factory / tests/test_pipeline.py
# Uploaded by gaurv007 ("Upload tests/test_pipeline.py", commit 344a111, verified).
"""
Integration tests for the AlphaFactory pipeline.
Tests both proven template mode and LLM mode (with mocked LLM).
"""
import pytest
import asyncio
from pathlib import Path
import tempfile
import shutil
from alpha_factory.config import Config, load_config
from alpha_factory.schemas import Blueprint, Component, Neutralization, AnomalyTag, BrainMetrics, Verdict
from alpha_factory.data.brain_fields import FIELD_INDEX, BrainField, SignConvention, DatasetTier
from alpha_factory.deterministic.proven_templates import generate_batch_from_proven_templates, generate_alpha15_variant
from alpha_factory.deterministic.expression_mutator import generate_mutations, mutate_decay, mutate_horizon, mutate_neutralization
from alpha_factory.deterministic.lint import lint, quick_dedup_hash
from alpha_factory.deterministic.fitness import compute_fitness, would_pass_brain
from alpha_factory.infra.winner_memory import WinnerMemory
from alpha_factory.data.brain_groups import PRODUCTION_GROUPS, get_group_for_expression
class TestProvenTemplates:
    """Tests for the proven-template (alpha15) expression generator."""

    @staticmethod
    def _make_field(sign):
        """Build the minimal TIER1 ``test_field`` BrainField used by these tests."""
        return BrainField(
            "test_field", "test", 1.0, 0, "Test field", "Test",
            sign, DatasetTier.TIER1,
        )

    def test_alpha15_generates_valid_expression(self):
        field = self._make_field(SignConvention.LONG_HIGH)
        expr = generate_alpha15_variant(field, group_key="subindustry", decay=5)
        assert "ts_decay_linear(" in expr
        assert "group_neutralize(" in expr
        assert "ts_rank(" in expr
        assert "test_field" in expr
        assert expr.startswith("ts_decay_linear(")

    def test_alpha15_long_low_sign(self):
        field = self._make_field(SignConvention.LONG_LOW)
        expr = generate_alpha15_variant(field, group_key="subindustry", decay=5)
        # LONG_LOW should prefix the field with minus
        assert "-zscore" in expr

    def test_alpha15_custom_decay(self):
        field = self._make_field(SignConvention.LONG_HIGH)
        expr = generate_alpha15_variant(field, group_key="subindustry", decay=10)
        assert "ts_decay_linear(" in expr
        assert ", 10)" in expr
        assert ", 5)" not in expr

    def test_batch_generation(self):
        batch = generate_batch_from_proven_templates(count=3)
        # An empty batch would make the ``all(...)`` checks below pass vacuously.
        assert 0 < len(batch) <= 3
        assert all("expression" in b for b in batch)
        assert all("field_id" in b for b in batch)

    def test_all_generated_expressions_pass_lint(self):
        batch = generate_batch_from_proven_templates(count=10)
        for alpha in batch:
            result = lint(alpha["expression"])
            # ``.get`` keeps the failure message from raising KeyError and
            # hiding the real lint error if an entry has no "template" key.
            assert result.passed, (
                f"Lint failed for {alpha.get('template')}: {result.errors}"
            )

    def test_batch_no_duplicate_fields(self):
        batch = generate_batch_from_proven_templates(count=20)
        field_ids = [b["field_id"] for b in batch]
        assert len(field_ids) == len(set(field_ids)), "Duplicate fields in batch"
class TestExpressionMutator:
    """Tests for the single-axis mutators and the combined mutation generator."""

    def test_mutate_decay(self):
        base = "ts_decay_linear(group_neutralize(rank(close), subindustry), 5)"
        mutated = mutate_decay(base, 5)
        assert len(mutated) > 0
        # Each variant keeps the decay wrapper but picks a different window.
        assert all("ts_decay_linear(" in cand["expression"] for cand in mutated)
        assert all(cand["decay"] != 5 for cand in mutated)

    def test_mutate_horizon(self):
        base = (
            "ts_decay_linear(group_neutralize(zscore(ts_rank(volume, 252)), "
            "subindustry), 5)"
        )
        assert len(mutate_horizon(base)) > 0

    def test_mutate_neutralization(self):
        base = "ts_decay_linear(group_neutralize(rank(close), subindustry), 5)"
        mutated = mutate_neutralization(base)
        if not mutated:
            # The mutator is allowed to produce nothing for this input.
            return
        assert all("subindustry" not in cand["expression"] for cand in mutated)

    def test_generate_mutations_comprehensive(self):
        base = (
            "ts_decay_linear(group_neutralize(zscore(ts_rank(volume, 252)), "
            "subindustry), 5)"
        )
        mutated = generate_mutations(base, decay=5)
        # Expect at least one decay, one horizon and one neutralization variant.
        assert len(mutated) >= 3
class TestLint:
    """Lint checks plus registry/mutator regression checks grouped with them."""

    def test_alpha15_expression_passes_lint(self):
        sue_field = BrainField(
            "standardized_unexpected_earnings_2", "model77", 0.92, 0,
            "SUE", "Model", SignConvention.LONG_HIGH, DatasetTier.TIER1
        )
        outcome = lint(generate_alpha15_variant(sue_field))
        assert outcome.passed, f"Lint errors: {outcome.errors}"
        assert len(outcome.warnings) <= 2

    def test_pv13_field_id_in_expression(self):
        """Test that corrected pv13_customer field IDs are accepted."""
        for expected_id in (
            "pv13_customergraphrank_auth_rank",
            "pv13_customergraphrank_page_rank",
        ):
            assert expected_id in FIELD_INDEX

    def test_mutate_vol_scale_with_uppercase_fields(self):
        """Test vol_scale mutation handles uppercase field IDs (e.g., mdl77_2GlobalDev...)."""
        from alpha_factory.deterministic.expression_mutator import mutate_vol_scale

        # Realistic expression whose field name contains uppercase letters.
        base = (
            "ts_decay_linear(group_neutralize(rank(ts_rank("
            "mdl77_2GlobalDevField, 252)), subindustry), 5)"
        )
        produced = mutate_vol_scale(base)
        # The field regex is expected to accept [a-zA-Z], so something must match.
        assert len(produced) > 0, "vol_scale should match uppercase field IDs"
        first = produced[0]["expression"]
        assert "ts_std(" in first, "vol_scale should wrap field with ts_std"

    def test_no_typos_in_field_registry(self):
        """Verify no 'ustomer' typos remain (missing 'c' in customer)."""
        for field_id in FIELD_INDEX:
            assert "_ustomergraphrank" not in field_id, f"Typo found in field: {field_id}"

    def test_valid_brain_fields_not_in_fake_list(self):
        """Verify common BRAIN fields are not in FAKE_FIELDS."""
        from alpha_factory.cleanup import FAKE_FIELDS

        real_fields = {
            "close", "high", "low", "volume", "vwap", "open",
            "returns", "ts_returns", "bid_ask_spread", "volatility",
        }
        for name in real_fields:
            assert name not in FAKE_FIELDS, f"Valid BRAIN field '{name}' is in FAKE_FIELDS"
class TestConfig:
    """Smoke tests for ``load_config``."""

    def test_config_loads(self):
        config = load_config()
        assert config.batch_size >= 1
        assert config.kill.daily_llm_token_budget > 0

    def test_config_paths_created(self):
        config = load_config()
        # Accept either the directory itself or its parent existing.
        assert config.paths.data.exists() or config.paths.data.parent.exists()
        assert (
            config.paths.factor_store.exists()
            or config.paths.factor_store.parent.exists()
        )

    def test_config_proven_templates(self):
        config = load_config()
        # Without this check the test is a tautology: assigning to a missing
        # attribute would just create it, and the assert below would pass.
        assert hasattr(config, "use_proven_templates"), (
            "Config is missing the use_proven_templates flag"
        )
        config.use_proven_templates = True
        assert config.use_proven_templates
class TestWinnerMemory:
    """Tests for the WinnerMemory store backed by a temporary database file."""

    def test_winner_memory_basic(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            wm = WinnerMemory(Path(tmpdir) / "test.duckdb")
            try:
                wm.record_winner("test_field", "alpha15", "subindustry", 5, 1.5, "momentum")
                winners = wm.get_winning_fields(min_sharpe=1.0)
                assert "test_field" in winners

                # Need 3+ failures with no wins for a field to be in failed_fields
                for reason, alpha_hash in (
                    ("low_sharpe", "hash1"),
                    ("high_turnover", "hash2"),
                    ("flat_line", "hash3"),
                ):
                    wm.record_failure("bad_field", "alpha6", reason, alpha_hash)
                failed = wm.get_failed_fields()
                assert "bad_field" in failed
            finally:
                # Close even when an assertion fails. Otherwise the open DB file
                # is left behind, and on platforms that lock open files
                # (e.g. Windows) directory cleanup can fail and mask the real error.
                wm.close()

    def test_iteration_queue(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            wm = WinnerMemory(Path(tmpdir) / "test.duckdb")
            try:
                wm.queue_for_iteration("alpha_1", "rank(close)", 1.0, 0.3, "increase_decay")
                queue = wm.get_iteration_queue(limit=10)
                assert len(queue) == 1
                assert queue[0]["alpha_id"] == "alpha_1"

                # Once marked as iterated, the alpha leaves the queue.
                wm.mark_iterated("alpha_1")
                queue = wm.get_iteration_queue(limit=10)
                assert len(queue) == 0
            finally:
                wm.close()
class TestBrainGroups:
    """Sanity checks on the production neutralization groups."""

    def test_novel_groups_exist(self):
        assert len(PRODUCTION_GROUPS) > 0

    def test_group_coverage(self):
        for grp in PRODUCTION_GROUPS:
            cov, ac = grp.coverage, grp.alpha_count
            assert cov >= 0.90, f"Group {grp.id} coverage too low: {cov}"
            assert ac <= 30, f"Group {grp.id} AC too high: {ac}"

    def test_get_group_returns_string(self):
        chosen = get_group_for_expression(prefer_novel=True)
        assert isinstance(chosen, str)
        assert len(chosen) > 0
class TestFitness:
    """Tests for the local fitness score and the BRAIN pass/fail predictor."""

    @staticmethod
    def _good_metrics():
        """Return metrics for a strong, stable alpha shared by several tests."""
        return BrainMetrics(
            alpha_id="test", sharpe_full=1.5, sharpe_is=1.6, sharpe_os=1.4,
            fitness=1.2, turnover=0.25, returns=0.1, max_drawdown=0.04,
            yearly_sharpe=[1.2, 1.5, 1.3, 1.4, 1.6], yearly_returns=[0.02] * 5,
        )

    def test_fitness_penalizes_high_correlation(self):
        metrics = self._good_metrics()
        low_corr = compute_fitness(metrics, max_corr_to_library=0.1, theme_novelty_score=0.5)
        high_corr = compute_fitness(metrics, max_corr_to_library=0.8, theme_novelty_score=0.5)
        assert low_corr > high_corr, "Lower correlation should give higher fitness"

    def test_would_pass_brain_comprehensive(self):
        result = would_pass_brain(self._good_metrics())
        # Truthiness, not ``== True`` (E712): also accepts numpy bools.
        assert result["overall_pass"]

        bad = BrainMetrics(
            alpha_id="test", sharpe_full=0.5, sharpe_is=0.6, sharpe_os=0.4,
            fitness=0.3, turnover=0.8, returns=-0.05, max_drawdown=0.15,
            yearly_sharpe=[-0.5, 0.1, -0.3, 0.2, 0.0], yearly_returns=[-0.01] * 5,
        )
        result = would_pass_brain(bad)
        assert not result["overall_pass"]
class TestFieldRegistry:
    """Integrity checks on the static BRAIN field registry."""

    def test_goldmine_fields_have_zero_ac(self):
        from alpha_factory.data.brain_fields import GOLDMINE_FIELDS

        for gm in GOLDMINE_FIELDS:
            assert gm.alpha_count == 0, f"Goldmine field {gm.id} has AC={gm.alpha_count}"

    def test_all_fields_unique(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS

        all_ids = [entry.id for entry in ALL_FIELDS]
        assert len(set(all_ids)) == len(all_ids), "Duplicate field IDs"

    def test_field_index_complete(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS

        # FIELD_INDEX is the module-level import of the same registry index.
        assert len(FIELD_INDEX) == len(ALL_FIELDS)

    def test_coverage_in_range(self):
        from alpha_factory.data.brain_fields import ALL_FIELDS

        for entry in ALL_FIELDS:
            assert 0.0 <= entry.coverage <= 1.0, f"Field {entry.id} coverage out of range"
class TestOperatorsCSV:
    """Consistency between lint.py's operator table and data/operators.csv."""

    def test_operator_arity_consistency(self):
        """Verify operators.csv contains all operators in lint.py OPERATOR_ARITY."""
        import csv

        from alpha_factory.deterministic.lint import OPERATOR_ARITY

        # Resolve relative to the repo root (this file lives in tests/) first,
        # then the CWD. Otherwise the test silently skips when pytest is
        # launched from another directory.
        candidates = (
            Path(__file__).resolve().parents[1] / "data" / "operators.csv",
            Path("data/operators.csv"),
        )
        csv_path = next((p for p in candidates if p.exists()), None)
        if csv_path is None:
            pytest.skip("operators.csv not found")

        csv_ops = set()
        # newline="" is what the csv module expects for files it reads.
        with open(csv_path, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                # Skip blank or short rows instead of crashing on a None cell.
                op_name = (row.get("name") or "").strip().lower()
                if op_name:
                    csv_ops.add(op_name)

        # Every operator in lint.py should exist in the CSV; report all at once.
        missing = sorted(op for op in OPERATOR_ARITY if op not in csv_ops)
        assert not missing, f"Operators in lint.py but not in operators.csv: {missing}"
class TestAsyncPipeline:
    """End-to-end run of the pipeline in proven-template mode (no BRAIN client)."""

    def test_proven_template_pipeline(self):
        """Test running a batch in proven template mode."""
        from alpha_factory.orchestration.pipeline import AlphaPipeline

        config = load_config()
        config.batch_size = 3
        config.use_proven_templates = True
        config.enable_brain_client = False
        pipeline = AlphaPipeline(config)

        async def _run():
            # Wrapper so asyncio.run works whatever awaitable run_batch returns.
            return await pipeline.run_batch(3)

        try:
            result = asyncio.run(_run())
        finally:
            # Release pipeline resources even if the batch raises.
            pipeline.close()

        # Pipeline now has local sim + checklist gates — results may be killed
        assert "promoted" in result or "iterated" in result or "killed" in result
        total = sum(result.get(key, 0) for key in ("promoted", "iterated", "killed"))
        assert total == 3, f"Expected 3 results, got {total}"
class TestModelManager:
    """Tests for ModelManager defaults and endpoint resolution."""

    def test_model_manager_defaults(self):
        from alpha_factory.infra.model_manager import ModelManager, DEFAULTS

        # Constructing the manager must not raise.
        manager = ModelManager()
        for tier in ("microfish", "mediumfish"):
            assert tier in DEFAULTS

    def test_model_manager_get_endpoint(self):
        from alpha_factory.infra.model_manager import ModelManager, ModelProvider

        manager = ModelManager()
        tier = "mediumfish"
        manager.selected[tier] = manager.get_selected(tier)
        base_url, model_name, _headers = manager.get_endpoint(tier)
        assert base_url.endswith("/v1")
        assert len(model_name) > 0
class TestThemeSampler:
    """Tests for ``pick_theme`` theme selection."""

    def test_pick_theme_returns_string(self):
        from alpha_factory.deterministic.theme_sampler import pick_theme, THEME_FIELDS

        theme = pick_theme([], [], [])
        assert isinstance(theme, str)
        assert len(theme) > 0
        assert theme in THEME_FIELDS

    def test_pick_theme_penalizes_existing(self):
        from alpha_factory.deterministic.theme_sampler import pick_theme, THEME_FIELDS

        existing = ["momentum", "value", "quality"]
        theme = pick_theme(existing, [], [])
        # Theme should be from the defined THEME_FIELDS. Use the registry itself
        # rather than a hand-copied list that drifts when themes change.
        assert theme in THEME_FIELDS, f"Unknown theme returned: {theme!r}"

    def test_pick_theme_all_dead_returns_alive(self):
        """If all themes are dead, pick_theme must still return a valid theme."""
        from alpha_factory.deterministic.theme_sampler import pick_theme, THEME_FIELDS

        all_themes = list(THEME_FIELDS.keys())
        theme = pick_theme([], [], dead_themes=all_themes)
        assert theme in THEME_FIELDS, "Should return a valid theme even when all are dead"
class TestWQClient:
    """Tests for the WorldQuant BRAIN client helpers."""

    def test_brain_client_error_hierarchy(self):
        from alpha_factory.infra.wq_client import BrainClientError, BrainAuthError, BrainRateLimitError

        for subclass in (BrainAuthError, BrainRateLimitError):
            assert issubclass(subclass, BrainClientError)

    def test_alpha_hash_deterministic(self):
        from alpha_factory.infra.wq_client import BrainClient

        make_hash = BrainClient.alpha_hash
        first = make_hash("rank(close)", "sector", 5)
        repeat = make_hash("rank(close)", "sector", 5)
        other_neut = make_hash("rank(close)", "industry", 5)
        assert first == repeat
        assert first != other_neut
        # Expected to be a hex digest truncated to 16 chars.
        assert len(first) == 16
class TestBrainSim:
    """Tests for the local BRAIN simulator helpers."""

    def test_rank_normalize_basic(self):
        import numpy as np
        from alpha_factory.local.brain_sim import _rank_normalize

        arr = np.array([3.0, 1.0, 2.0])
        result = _rank_normalize(arr)
        assert len(result) == 3
        assert np.all(result >= -1) and np.all(result <= 1)
        # Highest value (3.0) is at index 0, middle (2.0) at index 2, and
        # lowest (1.0) at index 1: the output must keep that full order.
        assert result[0] > result[2] > result[1]

    def test_simulate_alpha_local_with_random(self):
        import numpy as np
        from alpha_factory.local.brain_sim import simulate_alpha_local

        # A local RandomState(42) produces the same stream as
        # np.random.seed(42) + np.random.randn, without mutating global
        # RNG state that other tests may rely on.
        rng = np.random.RandomState(42)
        signal = rng.randn(100, 50)
        returns = rng.randn(100, 50) * 0.02
        result = simulate_alpha_local(signal, returns, min_sharpe=0.0, min_fitness=0.0)
        assert isinstance(result.sharpe, float)
        assert isinstance(result.turnover, float)
        assert 0 <= result.turnover <= 2.0  # Turnover should be in reasonable range
class TestCleanup:
    """Smoke tests for the cleanup module."""

    def test_cleanup_orphans_skips_common_words(self):
        """Verify cleanup_orphans doesn't delete alphas with common words like 'backfill'."""
        import re

        # Importing at least proves cleanup_orphans is still exposed.
        from alpha_factory.cleanup import cleanup_orphans

        # NOTE(review): this checks a local copy of the skip set, not the one
        # cleanup_orphans actually uses, so it cannot catch a regression in
        # the cleanup module itself. A real check needs DB setup.
        skip_words = {
            "subindustry", "industry", "sector", "market",
            "close", "high", "low", "open", "volume", "vwap",
            "backfill", "neutralize", "expression",
        }
        sample = "ts_decay_linear(group_neutralize(rank(ts_backfill(close, 30)), subindustry), 5)"
        long_tokens = re.findall(r"\b([a-z][a-z0-9_]{10,})\b", sample.lower())
        for tok in long_tokens:
            allowed = tok in skip_words or tok.startswith("ts_") or tok.startswith("group_")
            assert allowed, f"Unexpected token: {tok}"
class TestDedupHash:
    """``quick_dedup_hash`` must be stable and sensitive to every input."""

    BASE = ("rank(close)", "sector", 5)

    def test_same_expression_same_hash(self):
        assert quick_dedup_hash(*self.BASE) == quick_dedup_hash(*self.BASE)

    def test_different_neutralization_different_hash(self):
        reference = quick_dedup_hash(*self.BASE)
        changed = quick_dedup_hash("rank(close)", "industry", 5)
        assert reference != changed

    def test_different_decay_different_hash(self):
        reference = quick_dedup_hash(*self.BASE)
        changed = quick_dedup_hash("rank(close)", "sector", 10)
        assert reference != changed