File size: 2,709 Bytes
d1612a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a64df3
d1612a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Tests for the static lint module.
"""
import pytest
from alpha_factory.deterministic.lint import lint, quick_dedup_hash


class TestLint:
    """Checks on the pass/fail verdict, errors and warnings returned by ``lint``."""

    def test_valid_simple_expression(self):
        """A single nested call made of known operators passes with no errors."""
        result = lint("rank(ts_mean(close, 20))")
        assert result.passed
        assert result.errors == []

    def test_valid_composite_expression(self):
        """A weighted sum of two z-scored terms passes."""
        result = lint(
            "0.5 * zscore(ts_mean(close, 20)) + 0.5 * zscore(rank(volume))"
        )
        assert result.passed

    def test_unknown_operator(self):
        """An unrecognised operator name fails with an "Unknown operator" error."""
        result = lint("fake_operator(close, 20)")
        assert not result.passed
        assert any("Unknown operator" in message for message in result.errors)

    def test_lookahead_negative_delay(self):
        """A negative ``ts_delay`` argument fails with a "Look-ahead" error."""
        result = lint("ts_delay(close, -1)")
        assert not result.passed
        assert any("Look-ahead" in message for message in result.errors)

    def test_lookahead_future_field(self):
        """Referencing ``future_returns`` fails with a "Look-ahead" error."""
        result = lint("rank(future_returns)")
        assert not result.passed
        assert any("Look-ahead" in message for message in result.errors)

    def test_unbalanced_parens_extra_close(self):
        """A surplus closing parenthesis fails with an "Unbalanced" error."""
        result = lint("rank(close))")
        assert not result.passed
        assert any("Unbalanced" in message for message in result.errors)

    def test_unbalanced_parens_unclosed(self):
        """A missing closing parenthesis fails with an "Unbalanced" error."""
        result = lint("rank(ts_mean(close, 20)")
        assert not result.passed
        assert any("Unbalanced" in message for message in result.errors)

    def test_empty_expression(self):
        """The empty string does not pass."""
        result = lint("")
        assert not result.passed

    def test_unit_safety_warning(self):
        """Adding a price-based term to raw volume yields at least one warning."""
        result = lint("0.5 * ts_mean(close, 20) + 0.3 * volume")
        # Unit-safety concerns are expected to show up as warnings.
        assert len(result.warnings) > 0

    def test_high_decay_warning(self):
        """A decay window of 50 still passes but produces an "is high" warning."""
        result = lint("ts_decay_linear(rank(close), 50)")
        # The high window is reported as a warning, so the verdict stays passing.
        assert result.passed
        assert any(
            "is high" in w for w in result.warnings
        ), f"Warnings: {result.warnings}"


class TestDedup:
    """Checks on equality and inequality of ``quick_dedup_hash`` results."""

    def test_same_input_same_hash(self):
        """Two calls with identical arguments return equal hashes."""
        first = quick_dedup_hash("rank(close)", "sector", 5)
        second = quick_dedup_hash("rank(close)", "sector", 5)
        assert first == second

    def test_different_expression_different_hash(self):
        """Changing only the expression changes the hash."""
        close_hash = quick_dedup_hash("rank(close)", "sector", 5)
        volume_hash = quick_dedup_hash("rank(volume)", "sector", 5)
        assert close_hash != volume_hash

    def test_different_neutralization_different_hash(self):
        """Changing only the second argument ("sector" vs "industry") changes the hash."""
        sector_hash = quick_dedup_hash("rank(close)", "sector", 5)
        industry_hash = quick_dedup_hash("rank(close)", "industry", 5)
        assert sector_hash != industry_hash