# alpha-factory/tests/test_lint.py
# Uploaded by gaurv007 ("Upload tests/test_lint.py"), commit 6a64df3 (verified).
"""
Tests for the static lint module.
"""
import pytest
from alpha_factory.deterministic.lint import lint, quick_dedup_hash
class TestLint:
    """Behavioral tests for ``lint``.

    Valid expressions must pass with no errors. Malformed or look-ahead
    expressions must fail with a descriptive error. Risky but legal
    constructs must produce warnings without failing.
    """

    @staticmethod
    def _assert_rejected(expr, fragment):
        """Lint *expr* and assert it fails with an error containing *fragment*.

        The linter's actual errors are included in the assertion message, so a
        failing test shows what was reported instead of the expected error.
        """
        result = lint(expr)
        assert not result.passed, f"Expected {expr!r} to fail lint; errors: {result.errors}"
        assert any(fragment in e for e in result.errors), f"Errors: {result.errors}"

    def test_valid_simple_expression(self):
        result = lint("rank(ts_mean(close, 20))")
        assert result.passed, f"Errors: {result.errors}"
        assert result.errors == []

    def test_valid_composite_expression(self):
        expr = "0.5 * zscore(ts_mean(close, 20)) + 0.5 * zscore(rank(volume))"
        result = lint(expr)
        assert result.passed, f"Errors: {result.errors}"
        # A passing expression must report no errors at all. This is the same
        # contract that test_valid_simple_expression checks.
        assert result.errors == []

    def test_unknown_operator(self):
        self._assert_rejected("fake_operator(close, 20)", "Unknown operator")

    def test_lookahead_negative_delay(self):
        # A negative delay would reference future values, so it must be
        # reported as look-ahead.
        self._assert_rejected("ts_delay(close, -1)", "Look-ahead")

    def test_lookahead_future_field(self):
        self._assert_rejected("rank(future_returns)", "Look-ahead")

    def test_unbalanced_parens_extra_close(self):
        self._assert_rejected("rank(close))", "Unbalanced")

    def test_unbalanced_parens_unclosed(self):
        self._assert_rejected("rank(ts_mean(close, 20)", "Unbalanced")

    def test_empty_expression(self):
        result = lint("")
        assert not result.passed
        # A rejection should always come with at least one reported reason.
        assert result.errors, "Empty expression was rejected without any error message"

    def test_unit_safety_warning(self):
        # Adds a close-based term to a raw volume term; the linter is expected
        # to warn about unit safety for this mix.
        expr = "0.5 * ts_mean(close, 20) + 0.3 * volume"
        result = lint(expr)
        assert len(result.warnings) > 0, f"Warnings: {result.warnings}"

    def test_high_decay_warning(self):
        result = lint("ts_decay_linear(rank(close), 50)")
        # A large decay window is legal: it must warn, not fail.
        assert result.passed, f"Errors: {result.errors}"
        assert any("is high" in w for w in result.warnings), f"Warnings: {result.warnings}"
class TestDedup:
    """Tests for ``quick_dedup_hash``.

    The hash must be deterministic for identical inputs. It must change when
    either the expression or the neutralization argument changes.
    """

    # Reference arguments: (expression, neutralization, third positional arg).
    # Each test below varies at most one of them.
    _BASELINE = ("rank(close)", "sector", 5)

    def _baseline_hash(self):
        """Hash of the reference arguments."""
        return quick_dedup_hash(*self._BASELINE)

    def test_same_input_same_hash(self):
        first, second = self._baseline_hash(), self._baseline_hash()
        assert first == second

    def test_different_expression_different_hash(self):
        _, neutralization, third = self._BASELINE
        assert self._baseline_hash() != quick_dedup_hash("rank(volume)", neutralization, third)

    def test_different_neutralization_different_hash(self):
        expr, _, third = self._BASELINE
        assert self._baseline_hash() != quick_dedup_hash(expr, "industry", third)