| """ |
| Static Lint v3 — Layer 2: Deterministic pre-flight checks. |
| Validates operator ARITY (argument count) to catch 'Invalid number of inputs' errors. |
| """ |
| import re |
| from pathlib import Path |
| from ..schemas import LintResult |
| from ..data.brain_fields import FIELD_INDEX |
|
|
|
|
| |
# Minimum number of positional arguments each operator must receive.
# lint() reports an arity error when a call supplies FEWER than this; extra
# (e.g. optional) arguments are not checked. Operators absent from this table
# are not arity-checked at all.
OPERATOR_ARITY: dict[str, int] = {
    # One input.
    "rank": 1, "zscore": 1, "quantile": 1, "abs": 1, "log": 1,
    "sign": 1, "sqrt": 1, "sigmoid": 1, "tanh": 1, "relu": 1,
    "pasteurize": 1, "truncate": 1, "fraction": 1,
    "vec_avg": 1, "vec_sum": 1, "vec_norm": 1, "vec_count": 1,

    # Two inputs: time-series (input, window), group (input, group) and
    # element-wise binary operators.
    "ts_mean": 2, "ts_std": 2, "ts_sum": 2, "ts_min": 2, "ts_max": 2,
    "ts_rank": 2, "ts_zscore": 2, "ts_delta": 2, "ts_delay": 2,
    "ts_decay_linear": 2, "ts_argmax": 2, "ts_argmin": 2,
    "ts_skewness": 2, "ts_kurtosis": 2, "ts_entropy": 2,
    "ts_product": 2, "ts_moment": 2, "ts_av_diff": 2,
    "ts_hump": 2, "ts_scale": 2, "ts_step": 2,
    "ts_decay_exp_window": 2, "ts_backfill": 2,
    "group_neutralize": 2, "group_rank": 2, "group_zscore": 2,
    "group_mean": 2, "group_sum": 2, "group_median": 2,
    "group_max": 2, "group_min": 2, "group_count": 2,
    "indneutralize": 2, "market_neutralize": 2,
    "power": 2, "max": 2, "min": 2, "winsorize": 2,
    # Comparisons compare two inputs; arity 1 let e.g. less(x) slip through.
    "less": 2, "greater": 2, "equal": 2,
    "bucket": 2, "tail": 2, "mask": 2, "filter": 2,

    # Three inputs.
    "ts_correlation": 3, "ts_covariance": 3,
    "if_else": 3, "trade_when": 3,
    "ts_regression": 3,
}
|
|
|
|
def _load_operators(path: Path = Path("data/operators.csv")) -> set[str]:
    """Load valid operator names from operators.csv.

    The first column of every row is taken as an operator name. Names are
    stripped of surrounding whitespace and lower-cased. Blank rows and a header
    row whose first cell is ``name`` are skipped.

    Args:
        path: location of the operators CSV (relative paths resolve against
            the current working directory).

    Returns:
        Set of lower-case operator names. If *path* does not exist, the names
        known to OPERATOR_ARITY are returned instead.
    """
    import csv  # local stdlib import, same style as quick_dedup_hash

    if not path.exists():
        return set(OPERATOR_ARITY.keys())

    ops: set[str] = set()
    # csv.reader unquotes cells ("ts_mean") and keeps multi-line quoted cells
    # (e.g. descriptions) inside one row instead of leaking them as names.
    # utf-8-sig drops a leading BOM so the header cell still equals "name";
    # errors="replace" keeps odd bytes in other columns from aborting the load.
    with open(path, newline="", encoding="utf-8-sig", errors="replace") as f:
        for row in csv.reader(f):
            if not row:
                continue
            name = row[0].strip().lower()
            if name and name != "name":
                ops.add(name)
    return ops
|
|
|
|
| ALLOWED_OPS: set[str] | None = None |
|
|
| LOOKAHEAD_PATTERNS = [ |
| r"ts_delay\([^,]+,\s*-\d", |
| r"\bfuture_", |
| r"\bforward_return", |
| r"ts_backfill\([^,]+,\s*\d{3,}", |
| ] |
|
|
| UNIT_SAFE_WRAPPERS = {"zscore", "rank", "quantile", "group_zscore", "group_rank", "group_neutralize", "indneutralize"} |
|
|
|
|
def _count_args(expression: str, start_pos: int) -> int:
    """Count arguments of a function call starting at the opening paren.

    Scans forward from *start_pos* (the call's '(') to the matching ')' and
    counts commas at nesting depth 1. Commas and parentheses inside single- or
    double-quoted string literals (e.g. ``range="0,1,0.1"``) are ignored.

    Args:
        expression: full expression text.
        start_pos: index of the call's opening parenthesis.

    Returns:
        Number of comma-separated arguments, or 0 when only whitespace sits
        between the parentheses. If the call is never closed, the count seen
        up to the end of *expression* is returned.
    """
    depth = 0
    arg_count = 1
    has_content = False
    quote = ""  # active string delimiter, or "" when not inside a literal
    for ch in expression[start_pos:]:
        if quote:
            # Inside a string literal: only its closing delimiter matters.
            if ch == quote:
                quote = ""
            continue
        if depth >= 1 and not ch.isspace() and ch not in "),":
            # Any non-blank token inside the call -- including a nested '('
            # as in rank((x)) -- means at least one argument is present.
            has_content = True
        if ch in "\"'":
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return arg_count if has_content else 0
        elif ch == "," and depth == 1:
            arg_count += 1
    return arg_count if has_content else 0
|
|
|
|
def _split_top_level(expression: str, sep: str = "+") -> list[str]:
    """Split *expression* on *sep* characters that are outside all parentheses.

    A separator nested inside a call (the ``+`` in ``rank(a + b)``) does not
    split, so each part is a whole top-level operand. Parts are not stripped.
    """
    parts: list[str] = []
    depth = 0
    start = 0
    for i, ch in enumerate(expression):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == sep and depth <= 0:
            parts.append(expression[start:i])
            start = i + 1
    parts.append(expression[start:])
    return parts


def lint(expression: str, operators_path: Path = Path("data/operators.csv")) -> LintResult:
    """Run all deterministic pre-flight checks on a BRAIN expression.

    Errors (make ``passed`` False): empty/trivially short input, operators not
    in the catalogue, look-ahead patterns, unbalanced parentheses, calls with
    fewer arguments than OPERATOR_ARITY requires, and empty argument slots.
    Warnings (informational): additive operands not wrapped in a unit-safe
    operator, low-coverage fields, long decay/lookback windows, and a missing
    ts_decay_linear.

    Args:
        expression: the expression text to check.
        operators_path: CSV of valid operator names. Only read on the first
            call; later calls reuse the cached ALLOWED_OPS.

    Returns:
        LintResult with ``passed``, ``errors`` and ``warnings``.
    """
    global ALLOWED_OPS
    # NOTE(review): this cache ignores operators_path once populated.
    if ALLOWED_OPS is None:
        ALLOWED_OPS = _load_operators(operators_path)

    errors: list[str] = []
    warnings: list[str] = []

    # Trivial input: nothing else is meaningful, so stop here.
    if not expression or len(expression.strip()) < 5:
        errors.append("Expression is empty or trivially short")
        return LintResult(passed=False, errors=errors, warnings=warnings)

    expr_lower = expression.lower()
    call_pattern = re.compile(r"\b([a-z_]+)\s*\(")
    keywords = {"if", "and", "or", "not"}

    # Operators must exist in the catalogue; report each unknown name once.
    reported_ops: set[str] = set()
    for op in call_pattern.findall(expr_lower):
        if op in ALLOWED_OPS or op in keywords or op in reported_ops:
            continue
        reported_ops.add(op)
        errors.append(f"Unknown operator: '{op}' — not in catalog")

    # Look-ahead constructs.
    for pattern in LOOKAHEAD_PATTERNS:
        match = re.search(pattern, expression, re.IGNORECASE)
        if match:
            errors.append(f"Look-ahead detected: '{match.group()}'")

    # Parenthesis balance.
    depth = 0
    for ch in expression:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                errors.append("Unbalanced parentheses: extra closing ')'")
                break
    if depth > 0:
        errors.append(f"Unbalanced parentheses: {depth} unclosed '('")

    # Arity: each call to an operator in OPERATOR_ARITY needs at least that
    # many arguments.
    example_args = {1: "x", 2: "field, days", 3: "x, y, z"}
    for match in call_pattern.finditer(expr_lower):
        op_name = match.group(1)
        if op_name not in OPERATOR_ARITY:
            continue
        expected_min = OPERATOR_ARITY[op_name]
        paren_start = match.end() - 1  # regex ends on the call's '('
        actual_args = _count_args(expr_lower, paren_start)
        if actual_args < expected_min:
            # Build the message explicitly: the conditional must only pick the
            # example, never drop the "Arity error" text.
            example = f"{op_name}({example_args.get(expected_min, 'x, y, z')})"
            errors.append(
                f"Arity error: '{op_name}' requires {expected_min} args, got {actual_args}. "
                f"Example: {example}"
            )

    # Summed operands should start with a unit-safe wrapper. Only '+' at the
    # top level separates operands; a '+' nested inside a call does not.
    additive_parts = _split_top_level(expression.strip(), "+")
    if len(additive_parts) > 1:
        for part in additive_parts:
            part_clean = part.strip()
            first_func = re.match(r"[\d.\-]*\s*\*?\s*([a-z_]+)\s*\(", part_clean.lower())
            if first_func and first_func.group(1) not in UNIT_SAFE_WRAPPERS:
                warnings.append(f"Additive operand not unit-safe: '{part_clean[:60]}...' — wrap in zscore/rank")

    # Low-coverage fields, one warning per distinct field.
    seen_fields: set[str] = set()
    for fid in re.findall(r"\b([a-z][a-z0-9_]+)\b", expr_lower):
        if fid in seen_fields:
            continue
        seen_fields.add(fid)
        if fid in FIELD_INDEX and FIELD_INDEX[fid].coverage < 0.60:
            warnings.append(f"Field '{fid}' has low coverage ({FIELD_INDEX[fid].coverage:.0%}) — use ts_backfill({fid}, 30)")

    # Decay length of every ts_decay_linear call.
    for decay_match in re.finditer(r"ts_decay_linear\([^,]+,\s*(\d+)", expr_lower):
        if int(decay_match.group(1)) > 20:
            warnings.append(f"Decay of {decay_match.group(1)} days is high — typical range 3-15")

    # Lookback windows longer than a trading year.
    for match in re.finditer(r"ts_\w+\([^,]+,\s*(\d+)", expr_lower):
        window = int(match.group(1))
        if window > 252:
            warnings.append(f"Window of {window} days > 1 year — may reduce coverage")

    # Empty argument slots.
    empty_args = re.findall(r"\b[a-z_]+\(\s*,", expr_lower)
    if empty_args:
        errors.append(f"Empty first argument detected: {empty_args}")
    empty_func = re.findall(r"\b[a-z_]+\(\s*\)", expr_lower)
    if empty_func:
        errors.append(f"Function called with no arguments: {empty_func}")

    if "ts_decay_linear" not in expr_lower:
        warnings.append("No ts_decay_linear found — turnover may exceed 70%")

    return LintResult(
        passed=not errors,
        errors=errors,
        warnings=warnings,
    )
|
|
|
|
def quick_dedup_hash(expression: str, neutralization: str, decay: int) -> str:
    """Return a 16-hex-digit fingerprint for deduplicating expression runs.

    The whitespace-stripped expression, *neutralization* and *decay* are
    joined with ``|``, and the first 16 hex digits of the SHA-256 digest of
    that text are returned.
    """
    from hashlib import sha256

    components = (expression.strip(), neutralization, decay)
    fingerprint_source = "|".join(f"{component}" for component in components)
    digest = sha256(fingerprint_source.encode())
    return digest.hexdigest()[:16]
|
|
|
|
def validate_field_exists(field_id: str) -> bool:
    """Return True when *field_id* is a key of FIELD_INDEX, else False."""
    is_known = field_id in FIELD_INDEX
    return is_known
|
|