alpha-factory / alpha_factory /personas /expression_compiler.py
gaurv007's picture
Upload alpha_factory/personas/expression_compiler.py
8cf7a55 verified
"""
Expression Compiler v6 — Fixed per-component sign logic,
correct weight semantics (weights inside rank don't linearly combine),
removed double-negation bugs, WIRED field validation.
"""
import re
from jinja2 import Environment, BaseLoader
from ..infra.llm_client import LLMClient
from ..schemas import Blueprint, Expression
from ..data.brain_fields import FIELD_INDEX, get_backfill_days, get_sign_multiplier
from ..data.brain_groups import get_group_for_expression, PRODUCTION_GROUPS
# Jinja2 templates keyed by Blueprint.archetype. compile_expression renders
# them with ``bp`` (the Blueprint), ``group_key`` and ``backfill_days``.
# Conventions shared by the templates:
#   * Only ``fields[0]`` of each component is used; ``weight`` is not
#     referenced by any template.
#   * Where ``sign_direction == "long_low"`` is honoured, single-signal
#     templates prefix the whole expression with "-" (sign of the first
#     component); the two blend templates negate that component's field
#     ("-field"), matching rule 5 of COMPILER_SYSTEM_PROMPT.
#   * ``{%- ... -%}`` strips surrounding whitespace so the render is one line.
#   * Every template must render balanced parentheses.
TEMPLATES = {
    # Sum of per-component group z-scored ranks of a rolling mean.
    "value_quality_blend": """
{%- set comps = [] -%}
{%- for c in bp.components -%}
{%- set csign = "-" if c.sign_direction == "long_low" else "" -%}
{%- set _ = comps.append("group_zscore(rank(ts_mean(" ~ csign ~ c.fields[0] ~ ", " ~ c.horizon_days ~ ")), " ~ group_key ~ ")") -%}
{%- endfor -%}
{%- if comps | length > 1 -%}
ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
{%- else -%}
ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
{%- endif -%}
""",
    # Sign comes from the first component and is applied to the whole sum.
    "multi_horizon_mr": """
{%- set main = bp.components[0] -%}
{%- set sign = "-" if main.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }})){%- for c in bp.components[1:] %} + zscore(ts_rank({{ c.fields[0] }}, {{ c.horizon_days }})){%- endfor %}, {{ group_key }}), {{ bp.decay }})
""",
    # Change over the horizon scaled by the std over 4x the horizon.
    # sign_direction is not applied. (One surplus ")" removed: the old text
    # closed group_neutralize before group_key and left a dangling paren.)
    "vol_scaled_shock": """
{%- set c = bp.components[0] -%}
ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
""",
    "intraday_mr_decay": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(rank({{ c.fields[0] }})), {{ group_key }}), {{ bp.decay }})
""",
    "pead_revisions": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    # Sum of per-component z-scored cross-sectional ranks.
    "fundamental_yield_composite": """
{%- set comps = [] -%}
{%- for c in bp.components -%}
{%- set csign = "-" if c.sign_direction == "long_low" else "" -%}
{%- set _ = comps.append("zscore(rank(" ~ csign ~ c.fields[0] ~ "))") -%}
{%- endfor -%}
{%- if comps | length > 1 -%}
ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
{%- else -%}
ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
{%- endif -%}
""",
    "sue_drift": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_backfill({{ c.fields[0] }}, {{ bf }})), {{ group_key }}), {{ bp.decay }})
""",
    "supply_chain_lead_lag": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_mean(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "analyst_guidance_yield": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
""",
    # The inner ts_delta is always negated; a long_low sign negates again.
    "pcr_contrarian": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(rank(-ts_delta(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    "model_score_momentum": """
{%- set c = bp.components[0] -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
""",
    # Price term ((high + low) / 2 - close) plus the backfilled first field.
    "alpha15_hybrid": """
{%- set c = bp.components[0] -%}
{%- set bf = backfill_days -%}
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank((high + low) / 2 - close, 252)) + zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
""",
}
# Single shared Jinja2 environment used to compile the TEMPLATES strings via
# from_string(); BaseLoader suffices because no template is loaded by name.
# autoescape is spelled out (it is Jinja2's default) because the rendered text
# is a BRAIN expression, not HTML, and must never be escaped.
_env = Environment(loader=BaseLoader(), autoescape=False)
# System prompt for the LLM fallback in compile_expression (used when the
# blueprint's archetype has no entry in TEMPLATES). It lists operator arities
# and formatting rules; rules 1 and 2 are also enforced after the fact by
# _apply_post_compile_rules (adds ts_decay_linear) and _sanitize_expression
# (strips quotes around field names). The text is sent verbatim to the model.
COMPILER_SYSTEM_PROMPT = """You are a BRAIN expression compiler. Output ONLY the expression string.
OPERATOR ARITY (MUST have exact number of args):
1-arg: rank(x), zscore(x), abs(x), log(x), sign(x), sqrt(x), vec_avg(x), vec_sum(x)
2-arg: ts_mean(field,days), ts_std(field,days), ts_sum(field,days), ts_rank(field,days),
ts_zscore(field,days), ts_delta(field,days), ts_delay(field,days),
ts_backfill(field,days), ts_decay_linear(field,days), ts_argmax(field,days),
group_neutralize(score,group), group_rank(score,group), group_zscore(score,group),
winsorize(field,std), power(x,n), max(x,y), min(x,y)
3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)
RULES:
1. Do NOT include ts_decay_linear — added automatically by the pipeline.
2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
3. Every additive operand MUST be wrapped in zscore() or rank().
4. Use group_neutralize(score, group_key) with 2 args always.
5. For long_low sign_direction, prefix the field access with a minus sign: -field_name
WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
"""
def _sanitize_expression(expr: str) -> str:
    """Clean up common formatting mistakes in a compiled BRAIN expression.

    Steps, in order:

    1. Strip single/double quotes around identifier-like tokens
       (``'close'`` -> ``close``).
    2. Remove Markdown code fences, including an opening fence that carries
       a language tag such as ``python``.
    3. Collapse the text to one line: drop blank lines, strip each line, and
       join the rest with single spaces.
    4. Collapse a unary double negation ``-(-(X))`` into ``((X))``. Both
       parentheses are kept so the expression stays balanced, and the rewrite
       only fires where the leading ``-`` is unary (start of the expression,
       or after ``(``, ``,``, ``+``, ``*``, ``/``), never on a binary minus
       such as ``a-(-(b))``.

    Args:
        expr: Raw expression text from a template render or the LLM.

    Returns:
        The sanitized single-line expression.
    """
    # 1. Strip quotes around field names.
    expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr)
    # 2. Remove markdown fences. An opening fence may carry a language tag
    #    ("```python\n"); removing only the backticks would leave the tag
    #    behind as a stray token in the expression.
    expr = re.sub(r"```[\w+-]*[ \t]*\n", "\n", expr)
    expr = expr.replace("```", "").strip()
    # 3. Join the non-blank lines with single spaces.
    expr = " ".join(line.strip() for line in expr.split("\n") if line.strip())
    # 4. -(-(X)) == ((X)). The old rewrite produced "(X))", dropping an opening
    #    paren but not its closer; keep both parens instead. The lookbehind
    #    restricts this to a unary minus so "a-(-(b))" is left alone.
    expr = re.sub(r"(?:^|(?<=[(,+*/]))(\s*)-\(-\(", r"\1((", expr)
    return expr
def _get_novel_group() -> str:
    """Choose the neutralization group key for a new expression.

    When the ``PRODUCTION_GROUPS`` registry is empty (falsy), the fixed key
    ``"subindustry"`` is used; otherwise the choice is delegated to
    ``get_group_for_expression(prefer_novel=True)``.
    """
    if not PRODUCTION_GROUPS:
        return "subindustry"
    return get_group_for_expression(prefer_novel=True)
def _get_backfill(blueprint: Blueprint) -> int:
    """Return the backfill window (days) for ``blueprint``.

    Components are scanned in order, and each component's fields in order;
    the first field id present in ``FIELD_INDEX`` decides the window via
    ``get_backfill_days``. When no field is registered, 10 is returned.
    """
    registered = (
        field_id
        for component in blueprint.components
        for field_id in component.fields
        if field_id in FIELD_INDEX
    )
    for field_id in registered:
        return get_backfill_days(FIELD_INDEX[field_id])
    return 10
def _validate_expression(expr: str) -> list[str]:
    """Heuristically lint a compiled expression.

    Two checks are performed:

    * if any single- or double-quoted lowercase word appears, the message
      ``"Contains quoted field names"`` is reported once;
    * every lowercase identifier of 11+ characters is looked up in
      ``FIELD_INDEX`` unless it is an operator (``ts_``/``group_``/``vec_``
      prefix), a built-in group or price name, or an exempt field family
      (``pv13_``/``mdl``/``snt``/``scl`` prefix). Each miss is reported, once
      per occurrence. Shorter identifiers are not checked.

    Args:
        expr: Expression text to inspect.

    Returns:
        Issue messages; an empty list when nothing was flagged.
    """
    problems: list[str] = []
    if re.search(r"['\"][a-z_]+['\"]", expr):
        problems.append("Contains quoted field names")

    operator_prefixes = ("ts_", "group_", "vec_")
    builtin_names = (
        "subindustry", "industry", "sector", "market",
        "close", "high", "low", "open", "volume", "vwap",
    )
    exempt_prefixes = ("pv13_", "mdl", "snt", "scl")

    for token in re.findall(r"\b([a-z][a-z0-9_]{10,})\b", expr):
        exempt = (
            token.startswith(operator_prefixes)
            or token in builtin_names
            or token.startswith(exempt_prefixes)
        )
        if not exempt and token not in FIELD_INDEX:
            problems.append(f"Unknown field: '{token}' (not in FIELD_INDEX)")
    return problems
def _has_outer_decay(expr: str) -> bool:
    """Return True if ``expr`` is a single ``ts_decay_linear(...)`` call, optionally negated.

    Leading ``-`` characters (and whitespace after them) are ignored. The
    call must span the rest of the string, so ``ts_decay_linear(a, 5) + b``
    returns False because its ``ts_decay_linear`` closes before the end.
    """
    body = expr.lstrip("-").lstrip()
    if not body.startswith("ts_decay_linear("):
        return False
    depth = 0
    for pos, char in enumerate(body):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return pos == len(body) - 1
    # Parentheses never balance. Wrapping would not repair that, so keep the
    # earlier prefix-based answer and leave the expression as it is.
    return True


def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
    """Apply the mandatory post-compilation transforms to ``expression``.

    The stripped expression is wrapped as ``ts_decay_linear(<expr>, <decay>)``
    unless its outermost operation already is ``ts_decay_linear``, including
    the negated form ``-ts_decay_linear(...)`` that the long_low templates
    emit. The added window is ``blueprint.decay`` clamped to [5, 10]; an
    existing outer ``ts_decay_linear`` keeps its own window.

    The result is checked with ``_validate_expression``. Issues are only
    logged at DEBUG level, because a field may be legitimate but not yet in
    the local registry.

    Args:
        expression: Sanitized expression text.
        blueprint: Source blueprint; only ``decay`` is read.

    Returns:
        The expression with a guaranteed outer ``ts_decay_linear``.
    """
    import logging

    expr = expression.strip()
    decay = min(max(blueprint.decay, 5), 10)
    # A plain startswith("ts_decay_linear") check missed "-ts_decay_linear(...)"
    # and stacked a second, redundant decay layer on every long_low template.
    if not _has_outer_decay(expr):
        expr = f"ts_decay_linear({expr}, {decay})"
    issues = _validate_expression(expr)
    if issues:
        logging.getLogger(__name__).debug("Expression validation issues: %s", issues)
    return expr
async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
    """Convert a Blueprint into a BRAIN :class:`Expression`.

    A neutralization group key is chosen with ``_get_novel_group`` and a
    backfill window with ``_get_backfill``. If ``blueprint.archetype`` has a
    template in ``TEMPLATES``, it is rendered directly. Otherwise the LLM is
    asked (with ``COMPILER_SYSTEM_PROMPT``) to produce the expression. Either
    way the text goes through ``_sanitize_expression`` and then
    ``_apply_post_compile_rules``.

    Args:
        blueprint: Blueprint to compile.
        llm: Client used only when the archetype has no template.

    Returns:
        The compiled expression. On the template path, ``fields_used`` and
        ``operators_used`` are taken from the blueprint components,
        de-duplicated in first-seen order, with ``ts_decay_linear`` and
        ``group_neutralize`` appended when not already present. On the LLM
        path, the model's object is returned with its expression sanitized
        and ``ts_decay_linear`` added to ``operators_used`` when missing.
    """
    group_key = _get_novel_group()
    backfill_days = _get_backfill(blueprint)

    if blueprint.archetype in TEMPLATES:
        template = _env.from_string(TEMPLATES[blueprint.archetype])
        expr_text = template.render(
            bp=blueprint,
            group_key=group_key,
            backfill_days=backfill_days,
        ).strip()
        expr_text = _sanitize_expression(expr_text)
        expr_text = _apply_post_compile_rules(expr_text, blueprint)

        fields_used = [fid for c in blueprint.components for fid in c.fields]
        ops_used = [op for c in blueprint.components for op in c.operators]
        # dict.fromkeys de-duplicates in first-seen order: the result is
        # deterministic (list(set(...)) was not), and the two pipeline
        # operators are not listed twice when a component already names them.
        return Expression(
            expression=expr_text,
            fields_used=list(dict.fromkeys(fields_used)),
            operators_used=list(
                dict.fromkeys([*ops_used, "ts_decay_linear", "group_neutralize"])
            ),
            archetype_used=blueprint.archetype,
        )

    # LLM fallback for archetypes without a template.
    user_prompt = f"""Convert to BRAIN expression:
Theme: {blueprint.theme}
Components:
{_format_components(blueprint)}
Neutralization group: {group_key}
REMEMBER:
- Do NOT quote field names. Write field_name not 'field_name'.
- ts_* operators need exactly 2 args: ts_rank(field, days)
- group_neutralize needs 2 args: group_neutralize(score, {group_key})
- Do NOT include ts_decay_linear (added automatically).
- For long_low sign_direction, use -field_name.
Output ONLY the expression."""
    result = await llm.generate_json(
        prompt=user_prompt,
        schema=Expression,
        tier="tinyfish",
        temperature=0.1,
        system_prompt=COMPILER_SYSTEM_PROMPT,
    )

    # Sanitize the LLM output and enforce the pipeline's decay wrapper.
    result.expression = _sanitize_expression(result.expression)
    result.expression = _apply_post_compile_rules(result.expression, blueprint)
    if "ts_decay_linear" not in result.operators_used:
        result.operators_used.append("ts_decay_linear")

    # Surface validation problems at WARNING level for LLM output, since
    # nothing template-checked vouches for it.
    issues = _validate_expression(result.expression)
    if issues:
        import logging

        logging.getLogger(__name__).warning("LLM expression has issues: %s", issues)
    return result
def _format_components(bp: Blueprint) -> str:
    """Render ``bp.components`` as a 1-based numbered list for the LLM prompt.

    Each line shows the component name, fields, horizon (in days), weight and
    sign direction; lines are separated by newlines.
    """
    return "\n".join(
        f" {idx}. {comp.name}: fields={comp.fields}, "
        f"horizon={comp.horizon_days}d, weight={comp.weight}, sign={comp.sign_direction}"
        for idx, comp in enumerate(bp.components, start=1)
    )