| """ |
| Expression Compiler v6 — Fixed per-component sign logic, |
| correct weight semantics (weights inside rank don't linearly combine), |
| removed double-negation bugs, WIRED field validation. |
| """ |
| import re |
| from jinja2 import Environment, BaseLoader |
| from ..infra.llm_client import LLMClient |
| from ..schemas import Blueprint, Expression |
| from ..data.brain_fields import FIELD_INDEX, get_backfill_days, get_sign_multiplier |
| from ..data.brain_groups import get_group_for_expression, PRODUCTION_GROUPS |
|
|
|
|
# Jinja2 templates keyed by ``Blueprint.archetype``. Each one renders a full
# BRAIN expression, including the outer ts_decay_linear. compile_expression
# renders them with these variables:
#   bp            -- the Blueprint (bp.components, bp.decay are read)
#   group_key     -- neutralization group name
#   backfill_days -- ts_backfill window chosen by _get_backfill
# Sign handling:
#   * single-component templates prefix the whole expression with "-" when the
#     component's sign_direction is "long_low";
#   * the multi-component blends (value_quality_blend,
#     fundamental_yield_composite) sign each component's term individually, so
#     a long_low component is subtracted instead of added;
#   * multi_horizon_mr applies the first component's sign to the whole sum.
TEMPLATES = {
    "value_quality_blend": """
{%- set ns = namespace(expr="") -%}
{%- for c in bp.components -%}
{%- set neg = c.sign_direction == "long_low" -%}
{%- set term = "group_zscore(rank(ts_mean(" ~ c.fields[0] ~ ", " ~ c.horizon_days ~ ")), " ~ group_key ~ ")" -%}
{%- if loop.first -%}
{%- set ns.expr = ("-" if neg else "") ~ term -%}
{%- else -%}
{%- set ns.expr = ns.expr ~ (" - " if neg else " + ") ~ term -%}
{%- endif -%}
{%- endfor -%}
{%- if bp.components | length > 1 -%}
ts_decay_linear(group_neutralize(rank({{ ns.expr }}), {{ group_key }}), {{ bp.decay }})
{%- else -%}
ts_decay_linear(group_neutralize({{ ns.expr }}, {{ group_key }}), {{ bp.decay }})
{%- endif -%}
""",
    "multi_horizon_mr": """
{%- set main = bp.components[0] -%}
{%- set sign = "-" if main.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }})){%- for c in bp.components[1:] %} + zscore(ts_rank({{ c.fields[0] }}, {{ c.horizon_days }})){%- endfor %}, {{ group_key }}), {{ bp.decay }})
""",
    # zscore(delta / (std + eps)) is closed before ", group_key" so that
    # group_neutralize receives its two arguments and ts_decay_linear its two.
    "vol_scaled_shock": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
""",
    "intraday_mr_decay": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(rank({{ c.fields[0] }})), {{ group_key }}), {{ bp.decay }})
""",
    "pead_revisions": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "fundamental_yield_composite": """
{%- set ns = namespace(expr="") -%}
{%- for c in bp.components -%}
{%- set neg = c.sign_direction == "long_low" -%}
{%- set term = "zscore(rank(" ~ c.fields[0] ~ "))" -%}
{%- if loop.first -%}
{%- set ns.expr = ("-" if neg else "") ~ term -%}
{%- else -%}
{%- set ns.expr = ns.expr ~ (" - " if neg else " + ") ~ term -%}
{%- endif -%}
{%- endfor -%}
{%- if bp.components | length > 1 -%}
ts_decay_linear(group_neutralize(rank({{ ns.expr }}), {{ group_key }}), {{ bp.decay }})
{%- else -%}
ts_decay_linear(group_neutralize({{ ns.expr }}, {{ group_key }}), {{ bp.decay }})
{%- endif -%}
""",
    "sue_drift": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_backfill({{ c.fields[0] }}, {{ bf }})), {{ group_key }}), {{ bp.decay }})
""",
    "supply_chain_lead_lag": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_mean(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "analyst_guidance_yield": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
""",
    "pcr_contrarian": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(-ts_delta(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "model_score_momentum": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "alpha15_hybrid": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank((high + low) / 2 - close, 252)) + zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
""",
}
|
|
# Module-wide Jinja2 environment. compile_expression builds each archetype
# template from its in-memory TEMPLATES string via ``_env.from_string``.
_env = Environment(
    loader=BaseLoader(),
)
|
|
|
|
# System prompt for the LLM fallback in compile_expression, used when the
# blueprint's archetype has no entry in TEMPLATES.
# Two of the rules are also enforced after generation:
#   rule 1: _apply_post_compile_rules adds ts_decay_linear;
#   rule 2: _sanitize_expression unquotes field names.
# NOTE(review): the prompt asks for the bare expression string, while
# compile_expression calls llm.generate_json(schema=Expression). Confirm how
# LLMClient reconciles the two.
COMPILER_SYSTEM_PROMPT = """You are a BRAIN expression compiler. Output ONLY the expression string.

OPERATOR ARITY (MUST have exact number of args):
1-arg: rank(x), zscore(x), abs(x), log(x), sign(x), sqrt(x), vec_avg(x), vec_sum(x)
2-arg: ts_mean(field,days), ts_std(field,days), ts_sum(field,days), ts_rank(field,days),
ts_zscore(field,days), ts_delta(field,days), ts_delay(field,days),
ts_backfill(field,days), ts_decay_linear(field,days), ts_argmax(field,days),
group_neutralize(score,group), group_rank(score,group), group_zscore(score,group),
winsorize(field,std), power(x,n), max(x,y), min(x,y)
3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)

RULES:
1. Do NOT include ts_decay_linear — added automatically by the pipeline.
2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
3. Every additive operand MUST be wrapped in zscore() or rank().
4. Use group_neutralize(score, group_key) with 2 args always.
5. For long_low sign_direction, prefix the field access with a minus sign: -field_name

WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
"""
|
|
|
|
| def _sanitize_expression(expr: str) -> str: |
| """Post-compile sanitization — fix all known LLM mistakes.""" |
| |
| expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr) |
| |
| |
| expr = expr.replace("```", "").strip() |
| |
| |
| lines = [l.strip() for l in expr.split("\n") if l.strip()] |
| expr = " ".join(lines) |
| |
| |
| expr = re.sub(r"-\(-\(", "(", expr) |
| |
| return expr |
|
|
|
|
def _get_novel_group() -> str:
    """Return the neutralization group key to use for a new expression.

    When ``PRODUCTION_GROUPS`` is empty, fall back to ``"subindustry"``.
    Otherwise defer to ``get_group_for_expression(prefer_novel=True)``.
    """
    if not PRODUCTION_GROUPS:
        return "subindustry"
    return get_group_for_expression(prefer_novel=True)
|
|
|
|
def _get_backfill(blueprint: Blueprint) -> int:
    """Return the ts_backfill window (in days) for *blueprint*.

    Scans components in order, and each component's fields in order, for the
    first field id present in ``FIELD_INDEX``. That field's window comes from
    ``get_backfill_days``. If no field is indexed, the result is 10.
    """
    not_found = object()
    first_indexed = next(
        (
            field_id
            for component in blueprint.components
            for field_id in component.fields
            if field_id in FIELD_INDEX
        ),
        not_found,
    )
    if first_indexed is not_found:
        return 10
    return get_backfill_days(FIELD_INDEX[first_indexed])
|
|
|
|
| def _validate_expression(expr: str) -> list[str]: |
| """Validate expression has no obvious errors. Returns list of issues.""" |
| issues = [] |
| |
| if re.search(r"['\"][a-z_]+['\"]", expr): |
| issues.append("Contains quoted field names") |
| |
| found_fields = re.findall(r"\b([a-z][a-z0-9_]{10,})\b", expr) |
| for f in found_fields: |
| if f.startswith("ts_") or f.startswith("group_") or f.startswith("vec_"): |
| continue |
| if f in ("subindustry", "industry", "sector", "market", "close", "high", "low", "open", "volume", "vwap"): |
| continue |
| if f.startswith("pv13_") or f.startswith("mdl") or f.startswith("snt") or f.startswith("scl"): |
| continue |
| if f not in FIELD_INDEX: |
| issues.append(f"Unknown field: '{f}' (not in FIELD_INDEX)") |
| return issues |
|
|
|
|
def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
    """Apply the mandatory post-compilation transforms to *expression*.

    Makes sure the expression is wrapped in ``ts_decay_linear``, using
    ``blueprint.decay`` clamped to [5, 10]. If the expression already starts
    with ``ts_decay_linear``, it is left as is, including when a leading
    unary minus comes first (the long_low templates emit that form). This
    prevents decay from being applied twice. Validation problems are logged
    at DEBUG level and never block.

    Args:
        expression: Sanitized expression text.
        blueprint: Source blueprint; only ``decay`` is read.

    Returns:
        The expression, wrapped if it did not already start with a decay.
    """
    import logging

    expr = expression.strip()

    decay = min(max(blueprint.decay, 5), 10)
    # long_low templates render "-ts_decay_linear(...)". A plain
    # startswith("ts_decay_linear") misses that and adds a second decay.
    if not re.match(r"-?\s*ts_decay_linear\s*\(", expr):
        expr = f"ts_decay_linear({expr}, {decay})"

    issues = _validate_expression(expr)
    if issues:
        logging.getLogger(__name__).debug("Expression validation issues: %s", issues)

    return expr
|
|
|
|
async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
    """Convert a Blueprint to a BRAIN expression.

    If the blueprint's archetype has an entry in ``TEMPLATES``, the expression
    is rendered deterministically with Jinja2. Any other archetype falls back
    to ``llm.generate_json`` (tier "tinyfish", temperature 0.1) using
    ``COMPILER_SYSTEM_PROMPT``. Both paths go through
    ``_sanitize_expression`` and then ``_apply_post_compile_rules``.

    Args:
        blueprint: Blueprint describing the components to combine.
        llm: Client used only for archetypes without a template.

    Returns:
        The compiled ``Expression``.
    """
    import logging

    group_key = _get_novel_group()
    backfill_days = _get_backfill(blueprint)

    if blueprint.archetype in TEMPLATES:
        template = _env.from_string(TEMPLATES[blueprint.archetype])
        expr_text = template.render(
            bp=blueprint,
            group_key=group_key,
            backfill_days=backfill_days,
        ).strip()

        expr_text = _sanitize_expression(expr_text)
        expr_text = _apply_post_compile_rules(expr_text, blueprint)

        fields_used: list[str] = []
        ops_used: list[str] = []
        for c in blueprint.components:
            fields_used.extend(c.fields)
            ops_used.extend(c.operators)

        # dict.fromkeys removes duplicates but keeps first-seen order. That
        # makes the output deterministic (set order of strings varies between
        # runs) and avoids listing the two always-present operators twice.
        return Expression(
            expression=expr_text,
            fields_used=list(dict.fromkeys(fields_used)),
            operators_used=list(dict.fromkeys([*ops_used, "ts_decay_linear", "group_neutralize"])),
            archetype_used=blueprint.archetype,
        )

    # No template for this archetype: ask the LLM.
    user_prompt = f"""Convert to BRAIN expression:
Theme: {blueprint.theme}
Components:
{_format_components(blueprint)}
Neutralization group: {group_key}

REMEMBER:
- Do NOT quote field names. Write field_name not 'field_name'.
- ts_* operators need exactly 2 args: ts_rank(field, days)
- group_neutralize needs 2 args: group_neutralize(score, {group_key})
- Do NOT include ts_decay_linear (added automatically).
- For long_low sign_direction, use -field_name.

Output ONLY the expression."""

    result = await llm.generate_json(
        prompt=user_prompt,
        schema=Expression,
        tier="tinyfish",
        temperature=0.1,
        system_prompt=COMPILER_SYSTEM_PROMPT,
    )

    result.expression = _sanitize_expression(result.expression)
    result.expression = _apply_post_compile_rules(result.expression, blueprint)
    if "ts_decay_linear" not in result.operators_used:
        result.operators_used.append("ts_decay_linear")
    # The prompt never mentions the archetype, so any value the LLM filled in
    # is a guess. Record the requested archetype, as the template path does.
    result.archetype_used = blueprint.archetype

    issues = _validate_expression(result.expression)
    if issues:
        logging.getLogger(__name__).warning("LLM expression has issues: %s", issues)

    return result
|
|
|
|
def _format_components(bp: Blueprint) -> str:
    """Describe each blueprint component on its own numbered line (1-based) for the LLM prompt."""
    return "\n".join(
        f" {idx}. {comp.name}: fields={comp.fields}, "
        f"horizon={comp.horizon_days}d, weight={comp.weight}, sign={comp.sign_direction}"
        for idx, comp in enumerate(bp.components, start=1)
    )
|
|