""" Static Lint v3 — Layer 2: Deterministic pre-flight checks. Validates operator ARITY (argument count) to catch 'Invalid number of inputs' errors. """ import re from pathlib import Path from ..schemas import LintResult from ..data.brain_fields import FIELD_INDEX # Operator → required number of arguments (minimum) OPERATOR_ARITY: dict[str, int] = { # 1-argument operators "rank": 1, "zscore": 1, "quantile": 1, "abs": 1, "log": 1, "sign": 1, "sqrt": 1, "sigmoid": 1, "tanh": 1, "relu": 1, "pasteurize": 1, "truncate": 1, "fraction": 1, "vec_avg": 1, "vec_sum": 1, "vec_norm": 1, "vec_count": 1, # 2-argument operators "ts_mean": 2, "ts_std": 2, "ts_sum": 2, "ts_min": 2, "ts_max": 2, "ts_rank": 2, "ts_zscore": 2, "ts_delta": 2, "ts_delay": 2, "ts_decay_linear": 2, "ts_argmax": 2, "ts_argmin": 2, "ts_skewness": 2, "ts_kurtosis": 2, "ts_entropy": 2, "ts_product": 2, "ts_moment": 2, "ts_av_diff": 2, "ts_hump": 2, "ts_scale": 2, "ts_step": 2, "ts_decay_exp_window": 2, "ts_backfill": 2, "group_neutralize": 2, "group_rank": 2, "group_zscore": 2, "group_mean": 2, "group_sum": 2, "group_median": 2, "group_max": 2, "group_min": 2, "group_count": 2, "indneutralize": 2, "market_neutralize": 2, "power": 2, "max": 2, "min": 2, "winsorize": 2, "less": 1, "greater": 1, "equal": 1, "bucket": 2, "tail": 2, "mask": 2, "filter": 2, # 3-argument operators "ts_correlation": 3, "ts_covariance": 3, "if_else": 3, "trade_when": 3, # Variable (3+) "ts_regression": 3, } def _load_operators(path: Path = Path("data/operators.csv")) -> set[str]: """Load valid operator names from operators.csv.""" if not path.exists(): return set(OPERATOR_ARITY.keys()) ops = set() with open(path) as f: for line in f: name = line.strip().split(",")[0].lower() if name and name != "name": ops.add(name) return ops ALLOWED_OPS: set[str] | None = None LOOKAHEAD_PATTERNS = [ r"ts_delay\([^,]+,\s*-\d", r"\bfuture_", r"\bforward_return", r"ts_backfill\([^,]+,\s*\d{3,}", ] UNIT_SAFE_WRAPPERS = {"zscore", "rank", "quantile", "group_zscore", "group_rank", "group_neutralize", "indneutralize"} def _count_args(expression: str, start_pos: int) -> int: """Count arguments of a function call starting at the opening paren.""" depth = 0 arg_count = 1 # At least 1 if there's any content i = start_pos has_content = False while i < len(expression): ch = expression[i] if ch == '(': depth += 1 elif ch == ')': depth -= 1 if depth == 0: return arg_count if has_content else 0 elif ch == ',' and depth == 1: arg_count += 1 elif ch not in ' \t\n' and depth == 1: has_content = True i += 1 return arg_count if has_content else 0 def lint(expression: str, operators_path: Path = Path("data/operators.csv")) -> LintResult: """Run all deterministic pre-flight checks on a BRAIN expression.""" global ALLOWED_OPS if ALLOWED_OPS is None: ALLOWED_OPS = _load_operators(operators_path) errors: list[str] = [] warnings: list[str] = [] # CHECK 1: Empty or trivially short if not expression or len(expression.strip()) < 5: errors.append("Expression is empty or trivially short") return LintResult(passed=False, errors=errors, warnings=warnings) # CHECK 2: Operator validity found_ops = re.findall(r"\b([a-z_]+)\s*\(", expression.lower()) for op in found_ops: if op not in ALLOWED_OPS and op not in {"if", "and", "or", "not"}: errors.append(f"Unknown operator: '{op}' — not in catalog") # CHECK 3: Look-ahead detection for pattern in LOOKAHEAD_PATTERNS: match = re.search(pattern, expression, re.IGNORECASE) if match: errors.append(f"Look-ahead detected: '{match.group()}'") # CHECK 4: Balanced parentheses depth = 0 for ch in expression: if ch == "(": depth += 1 elif ch == ")": depth -= 1 if depth < 0: errors.append("Unbalanced parentheses: extra closing ')'") break if depth > 0: errors.append(f"Unbalanced parentheses: {depth} unclosed '('") # CHECK 5: Operator ARITY validation (catches "Invalid number of inputs") for match in re.finditer(r"\b([a-z_]+)\s*\(", expression.lower()): op_name = match.group(1) if op_name in OPERATOR_ARITY: expected_min = OPERATOR_ARITY[op_name] paren_start = match.end() - 1 # position of '(' actual_args = _count_args(expression.lower(), paren_start) if actual_args < expected_min: errors.append( f"Arity error: '{op_name}' requires {expected_min} args, got {actual_args}. " f"Example: {op_name}(field, days)" if expected_min == 2 else f"Example: {op_name}(x, y, z)" ) # CHECK 6: Unit-safety for additive expressions additive_parts = re.split(r"\s*\+\s*", expression.strip()) if len(additive_parts) > 1: for part in additive_parts: part_clean = part.strip() first_func = re.match(r"[\d.\-]*\s*\*?\s*([a-z_]+)\s*\(", part_clean.lower()) if first_func: func_name = first_func.group(1) if func_name not in UNIT_SAFE_WRAPPERS: warnings.append(f"Additive operand not unit-safe: '{part_clean[:60]}...' — wrap in zscore/rank") # CHECK 7: Field coverage validation all_words = re.findall(r"\b([a-z][a-z0-9_]+)\b", expression.lower()) for fid in all_words: if fid in FIELD_INDEX and FIELD_INDEX[fid].coverage < 0.60: warnings.append(f"Field '{fid}' has low coverage ({FIELD_INDEX[fid].coverage:.0%}) — use ts_backfill({fid}, 30)") # CHECK 8: Decay sanity decay_match = re.search(r"ts_decay_linear\([^,]+,\s*(\d+)", expression) if decay_match and int(decay_match.group(1)) > 20: warnings.append(f"Decay of {decay_match.group(1)} days is high — typical range 3-15") # CHECK 9: Very long lookback for match in re.finditer(r"ts_\w+\([^,]+,\s*(\d+)", expression): window = int(match.group(1)) if window > 252: warnings.append(f"Window of {window} days > 1 year — may reduce coverage") # CHECK 10: Empty arguments empty_args = re.findall(r"\b[a-z_]+\(\s*,", expression.lower()) if empty_args: errors.append(f"Empty first argument detected: {empty_args}") empty_func = re.findall(r"\b[a-z_]+\(\s*\)", expression.lower()) if empty_func: errors.append(f"Function called with no arguments: {empty_func}") # CHECK 11: Turnover guard if "ts_decay_linear" not in expression.lower(): warnings.append("No ts_decay_linear found — turnover may exceed 70%") return LintResult( passed=len(errors) == 0, errors=errors, warnings=warnings, ) def quick_dedup_hash(expression: str, neutralization: str, decay: int) -> str: import hashlib key = f"{expression.strip()}|{neutralization}|{decay}" return hashlib.sha256(key.encode()).hexdigest()[:16] def validate_field_exists(field_id: str) -> bool: return field_id in FIELD_INDEX