File size: 7,427 Bytes
46ac2bf
b042e9d
a6a7d9d
46ac2bf
 
 
 
fa79bf1
46ac2bf
 
b042e9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d773d2
b042e9d
 
 
 
 
 
 
 
 
46ac2bf
b042e9d
46ac2bf
b042e9d
46ac2bf
 
 
 
 
 
 
 
 
 
 
 
fa79bf1
 
 
 
46ac2bf
 
 
 
b042e9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa79bf1
46ac2bf
 
fa79bf1
46ac2bf
 
 
 
 
 
 
fa79bf1
46ac2bf
 
 
 
b042e9d
46ac2bf
 
 
a6a7d9d
46ac2bf
fa79bf1
46ac2bf
 
 
b042e9d
46ac2bf
fa79bf1
46ac2bf
 
 
 
 
 
 
 
 
 
 
 
b042e9d
 
 
 
 
 
 
 
 
 
 
 
 
 
46ac2bf
 
 
 
 
 
 
 
b042e9d
46ac2bf
b042e9d
fa79bf1
b042e9d
 
 
 
 
46ac2bf
 
fa79bf1
46ac2bf
b042e9d
46ac2bf
 
 
b042e9d
46ac2bf
b042e9d
46ac2bf
 
 
 
 
 
 
b042e9d
fa79bf1
b042e9d
fa79bf1
46ac2bf
 
 
 
 
 
 
 
 
 
 
fa79bf1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Static Lint v3 — Layer 2: Deterministic pre-flight checks.
Validates operator ARITY (argument count) to catch 'Invalid number of inputs' errors.
"""
import re
from pathlib import Path
from ..schemas import LintResult
from ..data.brain_fields import FIELD_INDEX


# Minimum number of arguments each known operator must receive.
# The table is assembled from arity groups; insertion order is kept stable.
OPERATOR_ARITY: dict[str, int] = {
    # Operators that need at least one argument.
    **dict.fromkeys(
        (
            "rank", "zscore", "quantile", "abs", "log",
            "sign", "sqrt", "sigmoid", "tanh", "relu",
            "pasteurize", "truncate", "fraction",
            "vec_avg", "vec_sum", "vec_norm", "vec_count",
        ),
        1,
    ),
    # Operators that need at least two arguments.
    **dict.fromkeys(
        (
            "ts_mean", "ts_std", "ts_sum", "ts_min", "ts_max",
            "ts_rank", "ts_zscore", "ts_delta", "ts_delay",
            "ts_decay_linear", "ts_argmax", "ts_argmin",
            "ts_skewness", "ts_kurtosis", "ts_entropy",
            "ts_product", "ts_moment", "ts_av_diff",
            "ts_hump", "ts_scale", "ts_step",
            "ts_decay_exp_window", "ts_backfill",
            "group_neutralize", "group_rank", "group_zscore",
            "group_mean", "group_sum", "group_median",
            "group_max", "group_min", "group_count",
            "indneutralize", "market_neutralize",
            "power", "max", "min", "winsorize",
        ),
        2,
    ),
    # NOTE(review): these comparison operators are registered with a minimum
    # of 1 although they are listed among the two-argument operators —
    # confirm that 1 is intended.
    **dict.fromkeys(("less", "greater", "equal"), 1),
    **dict.fromkeys(("bucket", "tail", "mask", "filter"), 2),
    # Operators that need at least three arguments; the value is a minimum,
    # so ts_regression may be called with more.
    **dict.fromkeys(
        (
            "ts_correlation", "ts_covariance",
            "if_else", "trade_when",
            "ts_regression",
        ),
        3,
    ),
}


def _load_operators(path: Path = Path("data/operators.csv")) -> set[str]:
    """Load valid operator names from operators.csv.

    The first column of every row is taken as an operator name, stripped of
    surrounding whitespace and lower-cased. Empty names and the literal
    header value ``name`` are skipped.

    Args:
        path: Location of the CSV catalog.

    Returns:
        The set of operator names found in the file. When ``path`` does not
        exist, the keys of ``OPERATOR_ARITY`` are returned instead.
    """
    # Local import keeps this block self-contained.
    import csv

    if not path.exists():
        return set(OPERATOR_ARITY.keys())

    ops: set[str] = set()
    # utf-8-sig drops a leading BOM so the header is still recognised as
    # "name"; errors="replace" keeps a stray byte in a later column from
    # aborting the load, since only the first column is consumed.
    with open(path, newline="", encoding="utf-8-sig", errors="replace") as f:
        # csv.reader handles quoted fields and commas inside quotes, which a
        # plain split(",") would mangle.
        for row in csv.reader(f):
            if not row:
                continue  # blank line
            name = row[0].strip().lower()
            if name and name != "name":
                ops.add(name)
    return ops


ALLOWED_OPS: set[str] | None = None

LOOKAHEAD_PATTERNS = [
    r"ts_delay\([^,]+,\s*-\d",
    r"\bfuture_",
    r"\bforward_return",
    r"ts_backfill\([^,]+,\s*\d{3,}",
]

UNIT_SAFE_WRAPPERS = {"zscore", "rank", "quantile", "group_zscore", "group_rank", "group_neutralize", "indneutralize"}


def _count_args(expression: str, start_pos: int) -> int:
    """Count arguments of a function call starting at the opening paren.

    Scans ``expression`` from ``start_pos`` (expected to be the index of the
    call's ``(``) up to the matching ``)``. Arguments are separated by commas
    at nesting depth 1; commas inside nested parentheses are ignored.

    Args:
        expression: The full expression text.
        start_pos: Index of the opening parenthesis of the call.

    Returns:
        The number of arguments, or 0 when the call has no content other
        than whitespace and commas. If the closing parenthesis is never
        found, the count accumulated up to the end of the string is returned.
    """
    depth = 0
    separators = 0
    has_content = False
    for ch in expression[start_pos:]:
        if ch == "(":
            # A nested "(" is argument content too; without this a
            # parenthesised argument such as f((a + b)) counted as empty.
            if depth >= 1:
                has_content = True
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                break
        elif depth == 1:
            if ch == ",":
                separators += 1
            elif not ch.isspace():
                has_content = True
    return separators + 1 if has_content else 0


def lint(expression: str, operators_path: Path = Path("data/operators.csv")) -> LintResult:
    """Run all deterministic pre-flight checks on a BRAIN expression.

    Args:
        expression: The expression text to validate.
        operators_path: CSV catalog of operator names. It is only read when
            the module-level ``ALLOWED_OPS`` cache is still ``None``; later
            calls reuse the cached set regardless of this argument.

    Returns:
        A ``LintResult`` whose ``passed`` flag is True when no errors were
        recorded. Warnings never affect ``passed``.
    """
    global ALLOWED_OPS
    if ALLOWED_OPS is None:
        ALLOWED_OPS = _load_operators(operators_path)

    errors: list[str] = []
    warnings: list[str] = []

    # CHECK 1: Empty or trivially short
    if not expression or len(expression.strip()) < 5:
        errors.append("Expression is empty or trivially short")
        return LintResult(passed=False, errors=errors, warnings=warnings)

    # Operator and field matching is case-insensitive, so lower-case once.
    lowered = expression.lower()

    # CHECK 2: Operator validity
    for op in re.findall(r"\b([a-z_]+)\s*\(", lowered):
        if op not in ALLOWED_OPS and op not in {"if", "and", "or", "not"}:
            errors.append(f"Unknown operator: '{op}' — not in catalog")

    # CHECK 3: Look-ahead detection
    for pattern in LOOKAHEAD_PATTERNS:
        match = re.search(pattern, expression, re.IGNORECASE)
        if match:
            errors.append(f"Look-ahead detected: '{match.group()}'")

    # CHECK 4: Balanced parentheses
    depth = 0
    for ch in expression:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        if depth < 0:
            errors.append("Unbalanced parentheses: extra closing ')'")
            break
    if depth > 0:
        errors.append(f"Unbalanced parentheses: {depth} unclosed '('")

    # CHECK 5: Operator ARITY validation (catches "Invalid number of inputs")
    for match in re.finditer(r"\b([a-z_]+)\s*\(", lowered):
        op_name = match.group(1)
        expected_min = OPERATOR_ARITY.get(op_name)
        if expected_min is None:
            continue
        # match.end() - 1 is the position of the call's '('.
        actual_args = _count_args(lowered, match.end() - 1)
        if actual_args < expected_min:
            # Choose the example first, then build one message. Applying a
            # conditional expression to adjacent f-strings would replace the
            # whole message, not just the example.
            example_args = {1: "x", 2: "field, days"}.get(expected_min, "x, y, z")
            errors.append(
                f"Arity error: '{op_name}' requires {expected_min} args, got {actual_args}. "
                f"Example: {op_name}({example_args})"
            )

    # CHECK 6: Unit-safety for additive expressions
    additive_parts = re.split(r"\s*\+\s*", expression.strip())
    if len(additive_parts) > 1:
        for part in additive_parts:
            part_clean = part.strip()
            first_func = re.match(r"[\d.\-]*\s*\*?\s*([a-z_]+)\s*\(", part_clean.lower())
            if first_func and first_func.group(1) not in UNIT_SAFE_WRAPPERS:
                warnings.append(f"Additive operand not unit-safe: '{part_clean[:60]}...' — wrap in zscore/rank")

    # CHECK 7: Field coverage validation
    for fid in re.findall(r"\b([a-z][a-z0-9_]+)\b", lowered):
        if fid in FIELD_INDEX and FIELD_INDEX[fid].coverage < 0.60:
            warnings.append(f"Field '{fid}' has low coverage ({FIELD_INDEX[fid].coverage:.0%}) — use ts_backfill({fid}, 30)")

    # CHECK 8: Decay sanity (only the first ts_decay_linear call is inspected).
    # Matched on the lower-cased text so it agrees with the other checks.
    decay_match = re.search(r"ts_decay_linear\([^,]+,\s*(\d+)", lowered)
    if decay_match and int(decay_match.group(1)) > 20:
        warnings.append(f"Decay of {decay_match.group(1)} days is high — typical range 3-15")

    # CHECK 9: Very long lookback (case-insensitive, like the other checks)
    for match in re.finditer(r"ts_\w+\([^,]+,\s*(\d+)", lowered):
        window = int(match.group(1))
        if window > 252:
            warnings.append(f"Window of {window} days > 1 year — may reduce coverage")

    # CHECK 10: Empty arguments
    empty_args = re.findall(r"\b[a-z_]+\(\s*,", lowered)
    if empty_args:
        errors.append(f"Empty first argument detected: {empty_args}")
    empty_func = re.findall(r"\b[a-z_]+\(\s*\)", lowered)
    if empty_func:
        errors.append(f"Function called with no arguments: {empty_func}")

    # CHECK 11: Turnover guard
    if "ts_decay_linear" not in lowered:
        warnings.append("No ts_decay_linear found — turnover may exceed 70%")

    return LintResult(
        passed=not errors,
        errors=errors,
        warnings=warnings,
    )


def quick_dedup_hash(expression: str, neutralization: str, decay: int) -> str:
    """Return a 16-character hex fingerprint of an expression and its settings.

    The fingerprint is the first 16 hex digits of the SHA-256 digest of
    ``"<expression without surrounding whitespace>|<neutralization>|<decay>"``.
    """
    import hashlib

    canonical = expression.strip()
    payload = f"{canonical}|{neutralization}|{decay}".encode()
    digest = hashlib.sha256(payload).hexdigest()
    return digest[:16]


def validate_field_exists(field_id: str) -> bool:
    """Report whether ``field_id`` is present in ``FIELD_INDEX``."""
    is_known = field_id in FIELD_INDEX
    return is_known