Upload alpha_factory/personas/expression_compiler.py
Browse files
alpha_factory/personas/expression_compiler.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
-
Expression Compiler
|
| 3 |
correct weight semantics (weights inside rank don't linearly combine),
|
| 4 |
-
removed double-negation bugs.
|
| 5 |
"""
|
| 6 |
import re
|
| 7 |
from jinja2 import Environment, BaseLoader
|
|
@@ -30,7 +30,7 @@ ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }
|
|
| 30 |
""",
|
| 31 |
"vol_scaled_shock": """
|
| 32 |
{%- set c = bp.components[0] -%}
|
| 33 |
-
ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
|
| 34 |
""",
|
| 35 |
"intraday_mr_decay": """
|
| 36 |
{%- set c = bp.components[0] -%}
|
|
@@ -148,34 +148,6 @@ def _get_backfill(blueprint: Blueprint) -> int:
|
|
| 148 |
return 10
|
| 149 |
|
| 150 |
|
| 151 |
-
def _apply_component_signs(expression: str, blueprint: Blueprint) -> str:
|
| 152 |
-
"""
|
| 153 |
-
Apply per-component sign corrections based on field sign conventions.
|
| 154 |
-
|
| 155 |
-
Unlike the old code which blindly negated the entire expression,
|
| 156 |
-
this only adjusts individual field references based on their
|
| 157 |
-
registered sign_direction mismatch with the component's declared sign.
|
| 158 |
-
"""
|
| 159 |
-
# For now, the templates handle sign_direction correctly via Jinja conditionals.
|
| 160 |
-
# This function is a hook for future per-component sign sweeps.
|
| 161 |
-
return expression
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
|
| 165 |
-
"""Mandatory post-compilation transforms."""
|
| 166 |
-
expr = expression.strip()
|
| 167 |
-
|
| 168 |
-
# Apply component-level sign corrections (now handled in templates)
|
| 169 |
-
expr = _apply_component_signs(expr, blueprint)
|
| 170 |
-
|
| 171 |
-
# ALWAYS wrap in ts_decay_linear (min 5, max 10)
|
| 172 |
-
decay = min(max(blueprint.decay, 5), 10)
|
| 173 |
-
if not expr.startswith("ts_decay_linear"):
|
| 174 |
-
expr = f"ts_decay_linear({expr}, {decay})"
|
| 175 |
-
|
| 176 |
-
return expr
|
| 177 |
-
|
| 178 |
-
|
| 179 |
def _validate_expression(expr: str) -> list[str]:
|
| 180 |
"""Validate expression has no obvious errors. Returns list of issues."""
|
| 181 |
issues = []
|
|
@@ -196,6 +168,26 @@ def _validate_expression(expr: str) -> list[str]:
|
|
| 196 |
return issues
|
| 197 |
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
|
| 200 |
"""Convert a Blueprint to a BRAIN expression."""
|
| 201 |
|
|
@@ -263,6 +255,13 @@ Output ONLY the expression."""
|
|
| 263 |
if "ts_decay_linear" not in result.operators_used:
|
| 264 |
result.operators_used.append("ts_decay_linear")
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
return result
|
| 267 |
|
| 268 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
Expression Compiler v6 — Fixed per-component sign logic,
|
| 3 |
correct weight semantics (weights inside rank don't linearly combine),
|
| 4 |
+
removed double-negation bugs, WIRED field validation.
|
| 5 |
"""
|
| 6 |
import re
|
| 7 |
from jinja2 import Environment, BaseLoader
|
|
|
|
| 30 |
""",
|
| 31 |
"vol_scaled_shock": """
|
| 32 |
{%- set c = bp.components[0] -%}
|
| 33 |
+
{#- Shape: ts_decay_linear(group_neutralize(zscore(delta / (std + eps)), group_key), decay) -#}
ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
|
| 34 |
""",
|
| 35 |
"intraday_mr_decay": """
|
| 36 |
{%- set c = bp.components[0] -%}
|
|
|
|
| 148 |
return 10
|
| 149 |
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def _validate_expression(expr: str) -> list[str]:
|
| 152 |
"""Validate expression has no obvious errors. Returns list of issues."""
|
| 153 |
issues = []
|
|
|
|
| 168 |
return issues
|
| 169 |
|
| 170 |
|
| 171 |
+
def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
    """Apply the mandatory post-compilation transforms to an expression.

    Steps, in order:
      1. Strip surrounding whitespace from ``expression``.
      2. If the text does not already start with ``ts_decay_linear``, wrap
         it as ``ts_decay_linear(<expr>, <window>)`` where ``<window>`` is
         ``blueprint.decay`` clamped to the range 5..10.
      3. Run ``_validate_expression`` on the result and log any reported
         issues at DEBUG level; issues never abort compilation.

    NOTE(review): an expression that already starts with
    ``ts_decay_linear`` is returned with whatever window it carries, so
    the 5..10 clamp is not applied to it. The check is also a plain prefix
    test, so a compound expression that merely begins with that operator
    counts as wrapped. Confirm both are intended.

    Args:
        expression: Raw compiled expression text.
        blueprint: Source blueprint; only its ``decay`` attribute is read.

    Returns:
        The stripped, decay-wrapped expression string.
    """
    compiled = expression.strip()

    # Clamp in two explicit steps: lower bound first, then upper bound.
    lower_bounded = max(blueprint.decay, 5)
    window = min(lower_bounded, 10)

    already_wrapped = compiled.startswith("ts_decay_linear")
    if not already_wrapped:
        compiled = f"ts_decay_linear({compiled}, {window})"

    # Validation is advisory only: unknown fields may still be legitimate,
    # so problems are logged rather than raised.
    problems = _validate_expression(compiled)
    if problems:
        import logging

        logging.getLogger(__name__).debug(f"Expression validation issues: {problems}")

    return compiled
|
| 189 |
+
|
| 190 |
+
|
| 191 |
async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
|
| 192 |
"""Convert a Blueprint to a BRAIN expression."""
|
| 193 |
|
|
|
|
| 255 |
if "ts_decay_linear" not in result.operators_used:
|
| 256 |
result.operators_used.append("ts_decay_linear")
|
| 257 |
|
| 258 |
+
# Run field validation on LLM output and surface warnings
|
| 259 |
+
issues = _validate_expression(result.expression)
|
| 260 |
+
if issues:
|
| 261 |
+
import logging
|
| 262 |
+
logger = logging.getLogger(__name__)
|
| 263 |
+
logger.warning(f"LLM expression has issues: {issues}")
|
| 264 |
+
|
| 265 |
return result
|
| 266 |
|
| 267 |
|