# NOTE: removed stray Hugging Face page header that was captured with the file
# ("gaurv007 — Upload alpha_factory/deterministic/lint.py with huggingface_hub,
# commit 5d773d2 verified"); it was not Python and broke the module.
"""
Static Lint v3 — Layer 2: Deterministic pre-flight checks.
Validates operator ARITY (argument count) to catch 'Invalid number of inputs' errors.
"""
import re
from pathlib import Path
from ..schemas import LintResult
from ..data.brain_fields import FIELD_INDEX
# Operator → required number of arguments (minimum).
# Used by CHECK 5 in lint() to catch BRAIN "Invalid number of inputs" errors
# before simulation.  Values are minimums: operators with optional trailing
# parameters list only their mandatory argument count.
OPERATOR_ARITY: dict[str, int] = {
    # 1-argument operators
    "rank": 1, "zscore": 1, "quantile": 1, "abs": 1, "log": 1,
    "sign": 1, "sqrt": 1, "sigmoid": 1, "tanh": 1, "relu": 1,
    "pasteurize": 1, "truncate": 1, "fraction": 1,
    "vec_avg": 1, "vec_sum": 1, "vec_norm": 1, "vec_count": 1,
    # 2-argument operators
    "ts_mean": 2, "ts_std": 2, "ts_sum": 2, "ts_min": 2, "ts_max": 2,
    "ts_rank": 2, "ts_zscore": 2, "ts_delta": 2, "ts_delay": 2,
    "ts_decay_linear": 2, "ts_argmax": 2, "ts_argmin": 2,
    "ts_skewness": 2, "ts_kurtosis": 2, "ts_entropy": 2,
    "ts_product": 2, "ts_moment": 2, "ts_av_diff": 2,
    "ts_hump": 2, "ts_scale": 2, "ts_step": 2,
    "ts_decay_exp_window": 2, "ts_backfill": 2,
    "group_neutralize": 2, "group_rank": 2, "group_zscore": 2,
    "group_mean": 2, "group_sum": 2, "group_median": 2,
    "group_max": 2, "group_min": 2, "group_count": 2,
    "indneutralize": 2, "market_neutralize": 2,
    "power": 2, "max": 2, "min": 2, "winsorize": 2,
    # BUGFIX: the comparison operators compare two inputs — less(x, y),
    # greater(x, y), equal(x, y) — but were listed with arity 1, which let
    # e.g. less(close) slip past the arity check.
    "less": 2, "greater": 2, "equal": 2,
    "bucket": 2, "tail": 2, "mask": 2, "filter": 2,
    # 3-argument operators
    "ts_correlation": 3, "ts_covariance": 3,
    "if_else": 3, "trade_when": 3,
    # Variable (3+)
    "ts_regression": 3,
}
def _load_operators(path: Path = Path("data/operators.csv")) -> set[str]:
    """Load valid operator names from the first CSV column of *path*.

    Falls back to the keys of OPERATOR_ARITY when the catalog file is
    missing, so linting still works without the data directory.

    Args:
        path: CSV catalog whose first column holds operator names.

    Returns:
        Lower-cased set of operator names.
    """
    if not path.exists():
        return set(OPERATOR_ARITY.keys())
    ops: set[str] = set()
    # BUGFIX: explicit encoding — without it the platform default is used,
    # which can mis-decode the catalog on non-UTF-8 locales (e.g. Windows).
    with open(path, encoding="utf-8") as f:
        for line in f:
            name = line.strip().split(",")[0].lower()
            # Skip blank lines and the header row (first cell "name").
            if name and name != "name":
                ops.add(name)
    return ops
# Lazily-populated cache of valid operator names; stays None until the first
# lint() call loads the catalog (note: later calls reuse this cached set).
ALLOWED_OPS: set[str] | None = None
# Regex patterns whose presence suggests the expression reads future data.
LOOKAHEAD_PATTERNS = [
    r"ts_delay\([^,]+,\s*-\d",  # negative delay shifts values in from the future
    r"\bfuture_",  # fields explicitly named future_*
    r"\bforward_return",  # forward returns are look-ahead by construction
    r"ts_backfill\([^,]+,\s*\d{3,}",  # 3+-digit (>=100 day) backfill window — flagged as risky
]
# Operators whose output is normalized/dimensionless, so operands wrapped in
# them can safely be added together (see CHECK 6 in lint()).
UNIT_SAFE_WRAPPERS = {"zscore", "rank", "quantile", "group_zscore", "group_rank", "group_neutralize", "indneutralize"}
def _count_args(expression: str, start_pos: int) -> int:
    """Count the arguments of a call whose opening '(' is at *start_pos*.

    Top-level commas (paren depth 1) separate arguments; an empty argument
    list yields 0.  BUGFIX: content living only inside nested parentheses
    (e.g. ``abs((x))``) now counts as content — previously only non-space
    characters at depth 1 were noticed, so such calls were reported as
    having 0 arguments and raised false arity errors.

    Args:
        expression: Full expression text.
        start_pos: Index of the opening parenthesis of the call.

    Returns:
        Number of arguments (0 for an empty call).
    """
    depth = 0
    arg_count = 1  # at least one argument once any content is seen
    has_content = False
    for i in range(start_pos, len(expression)):
        ch = expression[i]
        if ch == '(':
            depth += 1
            if depth > 1:
                has_content = True  # a nested call/group is real content
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return arg_count if has_content else 0
        elif ch == ',' and depth == 1:
            arg_count += 1  # top-level comma = argument separator
        elif depth >= 1 and not ch.isspace():
            has_content = True
    # Unbalanced parentheses: report what we saw (CHECK 4 flags the imbalance).
    return arg_count if has_content else 0
def lint(expression: str, operators_path: Path = Path("data/operators.csv")) -> LintResult:
    """Run all deterministic pre-flight checks on a BRAIN expression.

    Args:
        expression: Alpha expression to validate.
        operators_path: CSV catalog of valid operator names.  NOTE: only
            honoured on the first call; later calls reuse the cached set.

    Returns:
        LintResult whose ``passed`` is True iff no errors were found;
        warnings never fail the lint.
    """
    global ALLOWED_OPS
    # Lazily load the operator catalog once per process.
    if ALLOWED_OPS is None:
        ALLOWED_OPS = _load_operators(operators_path)
    errors: list[str] = []
    warnings: list[str] = []
    # CHECK 1: Empty or trivially short — nothing else is worth checking.
    if not expression or len(expression.strip()) < 5:
        errors.append("Expression is empty or trivially short")
        return LintResult(passed=False, errors=errors, warnings=warnings)
    # CHECK 2: Operator validity — every name used as a call must be known.
    found_ops = re.findall(r"\b([a-z_]+)\s*\(", expression.lower())
    for op in found_ops:
        if op not in ALLOWED_OPS and op not in {"if", "and", "or", "not"}:
            errors.append(f"Unknown operator: '{op}' — not in catalog")
    # CHECK 3: Look-ahead detection.
    for pattern in LOOKAHEAD_PATTERNS:
        match = re.search(pattern, expression, re.IGNORECASE)
        if match:
            errors.append(f"Look-ahead detected: '{match.group()}'")
    # CHECK 4: Balanced parentheses.
    depth = 0
    for ch in expression:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        if depth < 0:
            errors.append("Unbalanced parentheses: extra closing ')'")
            break
    if depth > 0:
        errors.append(f"Unbalanced parentheses: {depth} unclosed '('")
    # CHECK 5: Operator ARITY validation (catches "Invalid number of inputs").
    for match in re.finditer(r"\b([a-z_]+)\s*\(", expression.lower()):
        op_name = match.group(1)
        if op_name in OPERATOR_ARITY:
            expected_min = OPERATOR_ARITY[op_name]
            paren_start = match.end() - 1  # position of '('
            actual_args = _count_args(expression.lower(), paren_start)
            if actual_args < expected_min:
                # BUGFIX: the original ternary bound over the implicitly
                # concatenated f-strings, so for any operator with
                # expected_min != 2 the message lost its "Arity error: ..."
                # prefix and became just the example text.
                example = (
                    f"Example: {op_name}(field, days)"
                    if expected_min == 2
                    else f"Example: {op_name}(x, y, z)"
                )
                errors.append(
                    f"Arity error: '{op_name}' requires {expected_min} args, "
                    f"got {actual_args}. {example}"
                )
    # CHECK 6: Unit-safety for additive expressions — each '+' operand should
    # be wrapped in a normalizing operator.  (Naive split: a '+' inside a
    # call's arguments also splits here; warnings only, so tolerable.)
    additive_parts = re.split(r"\s*\+\s*", expression.strip())
    if len(additive_parts) > 1:
        for part in additive_parts:
            part_clean = part.strip()
            first_func = re.match(r"[\d.\-]*\s*\*?\s*([a-z_]+)\s*\(", part_clean.lower())
            if first_func:
                func_name = first_func.group(1)
                if func_name not in UNIT_SAFE_WRAPPERS:
                    warnings.append(f"Additive operand not unit-safe: '{part_clean[:60]}...' — wrap in zscore/rank")
    # CHECK 7: Field coverage validation — warn on sparse fields.
    all_words = re.findall(r"\b([a-z][a-z0-9_]+)\b", expression.lower())
    for fid in all_words:
        if fid in FIELD_INDEX and FIELD_INDEX[fid].coverage < 0.60:
            warnings.append(f"Field '{fid}' has low coverage ({FIELD_INDEX[fid].coverage:.0%}) — use ts_backfill({fid}, 30)")
    # CHECK 8: Decay sanity — only the first ts_decay_linear is inspected.
    decay_match = re.search(r"ts_decay_linear\([^,]+,\s*(\d+)", expression)
    if decay_match and int(decay_match.group(1)) > 20:
        warnings.append(f"Decay of {decay_match.group(1)} days is high — typical range 3-15")
    # CHECK 9: Very long lookback windows reduce usable history.
    for match in re.finditer(r"ts_\w+\([^,]+,\s*(\d+)", expression):
        window = int(match.group(1))
        if window > 252:
            warnings.append(f"Window of {window} days > 1 year — may reduce coverage")
    # CHECK 10: Empty arguments — e.g. "rank(, 5)" or "rank()".
    empty_args = re.findall(r"\b[a-z_]+\(\s*,", expression.lower())
    if empty_args:
        errors.append(f"Empty first argument detected: {empty_args}")
    empty_func = re.findall(r"\b[a-z_]+\(\s*\)", expression.lower())
    if empty_func:
        errors.append(f"Function called with no arguments: {empty_func}")
    # CHECK 11: Turnover guard — decay smoothing keeps turnover in check.
    if "ts_decay_linear" not in expression.lower():
        warnings.append("No ts_decay_linear found — turnover may exceed 70%")
    return LintResult(
        passed=len(errors) == 0,
        errors=errors,
        warnings=warnings,
    )
def quick_dedup_hash(expression: str, neutralization: str, decay: int) -> str:
    """Return a short stable hash identifying an (expression, neutralization, decay) triple.

    The expression is stripped of surrounding whitespace first, so cosmetic
    padding does not defeat deduplication.
    """
    import hashlib
    fingerprint = "|".join((expression.strip(), neutralization, str(decay)))
    return hashlib.sha256(fingerprint.encode()).hexdigest()[:16]
def validate_field_exists(field_id: str) -> bool:
    """Report whether *field_id* is a known field in the loaded catalog."""
    if field_id in FIELD_INDEX:
        return True
    return False