"""
Expression Compiler v6 — Fixed per-component sign logic, correct weight
semantics (weights inside rank don't linearly combine), removed
double-negation bugs, WIRED field validation.

Blueprints whose archetype has a Jinja2 template in ``TEMPLATES`` are rendered
deterministically; any other archetype falls back to an LLM compile. Both paths
go through ``_sanitize_expression`` and ``_apply_post_compile_rules``.
"""

import logging
import re

from jinja2 import Environment, BaseLoader

from ..infra.llm_client import LLMClient
from ..schemas import Blueprint, Expression
from ..data.brain_fields import FIELD_INDEX, get_backfill_days, get_sign_multiplier
from ..data.brain_groups import get_group_for_expression, PRODUCTION_GROUPS

logger = logging.getLogger(__name__)


# Jinja2 templates keyed by blueprint archetype. Each one renders a complete
# expression INCLUDING its outer ts_decay_linear(...) wrapper. Render context:
# ``bp`` (the Blueprint), ``group_key`` and ``backfill_days``.
# A component whose sign_direction is "long_low" is negated: per term in the
# multi-component blends, on the whole expression for single-signal templates
# (multi_horizon_mr uses the sign of its first component).
# NOTE(review): templates emit bp.decay unclamped; the [5, 10] clamp in
# _apply_post_compile_rules only applies when that function adds the wrapper
# itself — confirm this asymmetry is intended.
TEMPLATES = {
    "value_quality_blend": """
{%- set comps = [] -%}
{%- for c in bp.components -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{%- set _ = comps.append(sign ~ "group_zscore(rank(ts_mean(" ~ c.fields[0] ~ ", " ~ c.horizon_days ~ ")), " ~ group_key ~ ")") -%}
{%- endfor -%}
{%- if comps | length > 1 -%}
ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
{%- else -%}
ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
{%- endif -%}
""",
    "multi_horizon_mr": """
{%- set main = bp.components[0] -%}
{%- set sign = "-" if main.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }})){%- for c in bp.components[1:] %} + zscore(ts_rank({{ c.fields[0] }}, {{ c.horizon_days }})){%- endfor %}, {{ group_key }}), {{ bp.decay }})
""",
    # Fixed: the original closed one paren too many, leaving group_neutralize
    # with a single argument and the expression unbalanced.
    "vol_scaled_shock": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
""",
    "intraday_mr_decay": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(rank({{ c.fields[0] }})), {{ group_key }}), {{ bp.decay }})
""",
    "pead_revisions": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "fundamental_yield_composite": """
{%- set comps = [] -%}
{%- for c in bp.components -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{%- set _ = comps.append(sign ~ "zscore(rank(" ~ c.fields[0] ~ "))") -%}
{%- endfor -%}
{%- if comps | length > 1 -%}
ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
{%- else -%}
ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
{%- endif -%}
""",
    "sue_drift": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_backfill({{ c.fields[0] }}, {{ bf }})), {{ group_key }}), {{ bp.decay }})
""",
    "supply_chain_lead_lag": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_mean(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "analyst_guidance_yield": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
""",
    "pcr_contrarian": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(-ts_delta(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "model_score_momentum": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "alpha15_hybrid": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank((high + low) / 2 - close, 252)) + zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
""",
}

_env = Environment(loader=BaseLoader())

COMPILER_SYSTEM_PROMPT = """You are a BRAIN expression compiler. Output ONLY the expression string.

OPERATOR ARITY (MUST have exact number of args):
1-arg: rank(x), zscore(x), abs(x), log(x), sign(x), sqrt(x), vec_avg(x), vec_sum(x)
2-arg: ts_mean(field,days), ts_std(field,days), ts_sum(field,days), ts_rank(field,days), ts_zscore(field,days), ts_delta(field,days), ts_delay(field,days), ts_backfill(field,days), ts_decay_linear(field,days), ts_argmax(field,days), group_neutralize(score,group), group_rank(score,group), group_zscore(score,group), winsorize(field,std), power(x,n), max(x,y), min(x,y)
3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)

RULES:
1. Do NOT include ts_decay_linear — added automatically by the pipeline.
2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
3. Every additive operand MUST be wrapped in zscore() or rank().
4. Use group_neutralize(score, group_key) with 2 args always.
5. For long_low sign_direction, prefix the field access with a minus sign: -field_name

WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
"""

# Names that _validate_expression never reports as unknown fields.
_OPERATOR_PREFIXES = ("ts_", "group_", "vec_")
_BUILTIN_NAMES = frozenset(
    {"subindustry", "industry", "sector", "market", "close", "high", "low",
     "open", "volume", "vwap"}
)
# Field families accepted without a FIELD_INDEX lookup.
_UNREGISTERED_FIELD_PREFIXES = ("pv13_", "mdl", "snt", "scl")


def _sanitize_expression(expr: str) -> str:
    """Post-compile sanitization — fix known LLM mistakes.

    1. Strip quotes around lowercase identifiers (``'close'`` -> ``close``).
    2. Remove markdown code fences, including a language tag such as
       ```` ```python ```` on the opening fence.
    3. Strip every line and join the non-empty ones with single spaces.
    4. Collapse a *unary* double negation ``-(-(x))`` into ``((x))``.
    """
    # 1. Strip quotes around field names
    expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr)
    # 2. Remove markdown fences; an opening fence may carry a language tag
    #    that would otherwise be left glued to the expression.
    expr = re.sub(r"```[\w+-]*[ \t\r]*\n", "", expr)
    expr = expr.replace("```", "").strip()
    # 3. Collapse multi-line output into one line
    lines = [line.strip() for line in expr.split("\n") if line.strip()]
    expr = " ".join(lines)
    # 4. Double negation. Keep both opening parens so the expression stays
    #    balanced, and only rewrite when the first minus is unary (start of
    #    string, or after '(' ',' '+' '*' '/'); a binary minus such as
    #    "a -(-(b))" must be left alone.
    expr = re.sub(r"(^|[(,+*/]\s*)-\s*\(\s*-\s*\(", r"\1((", expr)
    return expr


def _get_novel_group() -> str:
    """Pick a novel group key for this expression ("subindustry" if no production groups)."""
    if PRODUCTION_GROUPS:
        return get_group_for_expression(prefer_novel=True)
    return "subindustry"


def _get_backfill(blueprint: Blueprint) -> int:
    """Return backfill days for the first component field found in FIELD_INDEX.

    Fields are scanned in component order; 10 is returned when none of them
    is registered.
    """
    for c in blueprint.components:
        for fid in c.fields:
            if fid in FIELD_INDEX:
                return get_backfill_days(FIELD_INDEX[fid])
    return 10


def _validate_expression(expr: str) -> list[str]:
    """Validate expression has no obvious errors. Returns list of issues.

    Checks for quoted identifiers and, as a sample check, for identifiers of
    11+ characters that are not operators, built-in names, known unregistered
    field families, or entries of FIELD_INDEX.
    """
    issues = []
    if re.search(r"['\"][a-z_]+['\"]", expr):
        issues.append("Contains quoted field names")
    for f in re.findall(r"\b([a-z][a-z0-9_]{10,})\b", expr):
        if f.startswith(_OPERATOR_PREFIXES) or f in _BUILTIN_NAMES:
            continue
        if f.startswith(_UNREGISTERED_FIELD_PREFIXES):
            continue
        if f not in FIELD_INDEX:
            issues.append(f"Unknown field: '{f}' (not in FIELD_INDEX)")
    return issues


def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
    """Mandatory post-compilation transforms.

    Wraps the expression in ``ts_decay_linear(expr, decay)`` unless it already
    starts with that operator, optionally preceded by a unary minus, as
    rendered by long_low templates. The added wrapper uses blueprint.decay
    clamped to [5, 10]. Validation issues are logged at DEBUG level only;
    they never raise.
    """
    expr = expression.strip()
    decay = min(max(blueprint.decay, 5), 10)
    # A leading sign flip must not hide an existing wrapper, or the signal
    # would be decayed twice.
    if not expr.lstrip("-").lstrip().startswith("ts_decay_linear"):
        expr = f"ts_decay_linear({expr}, {decay})"
    issues = _validate_expression(expr)
    if issues:
        # Log but don't crash — some fields may be legitimate but not in our
        # registry yet.
        logger.debug("Expression validation issues: %s", issues)
    return expr


async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
    """Convert a Blueprint to a BRAIN expression.

    Archetypes present in TEMPLATES are rendered locally; any other archetype
    is compiled by the LLM. Both results are sanitized and passed through
    _apply_post_compile_rules.
    """
    group_key = _get_novel_group()
    backfill_days = _get_backfill(blueprint)

    if blueprint.archetype in TEMPLATES:
        template = _env.from_string(TEMPLATES[blueprint.archetype])
        expr_text = template.render(
            bp=blueprint,
            group_key=group_key,
            backfill_days=backfill_days,
        ).strip()
        expr_text = _sanitize_expression(expr_text)
        expr_text = _apply_post_compile_rules(expr_text, blueprint)

        fields_used = [f for c in blueprint.components for f in c.fields]
        ops_used = [op for c in blueprint.components for op in c.operators]
        # dict.fromkeys de-duplicates while keeping a deterministic order and
        # avoids listing the always-present wrapper operators twice.
        return Expression(
            expression=expr_text,
            fields_used=list(dict.fromkeys(fields_used)),
            operators_used=list(
                dict.fromkeys([*ops_used, "ts_decay_linear", "group_neutralize"])
            ),
            archetype_used=blueprint.archetype,
        )

    # LLM fallback for unknown archetypes
    user_prompt = f"""Convert to BRAIN expression:
Theme: {blueprint.theme}
Components:
{_format_components(blueprint)}
Neutralization group: {group_key}

REMEMBER:
- Do NOT quote field names. Write field_name not 'field_name'.
- ts_* operators need exactly 2 args: ts_rank(field, days)
- group_neutralize needs 2 args: group_neutralize(score, {group_key})
- Do NOT include ts_decay_linear (added automatically).
- For long_low sign_direction, use -field_name.

Output ONLY the expression."""

    result = await llm.generate_json(
        prompt=user_prompt,
        schema=Expression,
        tier="tinyfish",
        temperature=0.1,
        system_prompt=COMPILER_SYSTEM_PROMPT,
    )

    result.expression = _sanitize_expression(result.expression)
    result.expression = _apply_post_compile_rules(result.expression, blueprint)
    if "ts_decay_linear" not in result.operators_used:
        result.operators_used.append("ts_decay_linear")

    # Surface field-validation problems on LLM output at WARNING level.
    issues = _validate_expression(result.expression)
    if issues:
        logger.warning("LLM expression has issues: %s", issues)
    return result


def _format_components(bp: Blueprint) -> str:
    """Render one numbered summary line per component for the LLM prompt."""
    return "\n".join(
        f" {i}. {c.name}: fields={c.fields}, "
        f"horizon={c.horizon_days}d, weight={c.weight}, sign={c.sign_direction}"
        for i, c in enumerate(bp.components, start=1)
    )