Upload alpha_factory/personas/expression_compiler.py with huggingface_hub
Browse files
alpha_factory/personas/expression_compiler.py
CHANGED
|
@@ -1,10 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
-
Expression Compiler
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
3. Uses novel GROUP keys (pv13_rcsed_6l etc) instead of always sector
|
| 6 |
-
4. Explicit arity in LLM prompt
|
| 7 |
-
5. Post-compile sanitization
|
| 8 |
"""
|
| 9 |
import re
|
| 10 |
from jinja2 import Environment, BaseLoader
|
|
@@ -18,61 +15,76 @@ TEMPLATES = {
|
|
| 18 |
"value_quality_blend": """
|
| 19 |
{%- set comps = [] -%}
|
| 20 |
{%- for c in bp.components -%}
|
| 21 |
-
{%- set _ = comps.append(
|
| 22 |
{%- endfor -%}
|
| 23 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
""",
|
| 25 |
"multi_horizon_mr": """
|
| 26 |
{%- set main = bp.components[0] -%}
|
| 27 |
-
{%- set sign = "-" if
|
| 28 |
-
{{ sign }}(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }}))
|
| 29 |
""",
|
| 30 |
"vol_scaled_shock": """
|
| 31 |
{%- set c = bp.components[0] -%}
|
| 32 |
-
zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001))
|
| 33 |
""",
|
| 34 |
"intraday_mr_decay": """
|
| 35 |
{%- set c = bp.components[0] -%}
|
| 36 |
-
|
|
|
|
| 37 |
""",
|
| 38 |
"pead_revisions": """
|
| 39 |
{%- set c = bp.components[0] -%}
|
| 40 |
-
|
|
|
|
| 41 |
""",
|
| 42 |
"fundamental_yield_composite": """
|
| 43 |
{%- set comps = [] -%}
|
| 44 |
{%- for c in bp.components -%}
|
| 45 |
-
{%- set _ = comps.append(
|
| 46 |
{%- endfor -%}
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
""",
|
| 49 |
"sue_drift": """
|
| 50 |
{%- set c = bp.components[0] -%}
|
| 51 |
{%- set bf = backfill_days -%}
|
| 52 |
-
|
|
|
|
| 53 |
""",
|
| 54 |
"supply_chain_lead_lag": """
|
| 55 |
{%- set c = bp.components[0] -%}
|
| 56 |
-
|
|
|
|
| 57 |
""",
|
| 58 |
"analyst_guidance_yield": """
|
| 59 |
{%- set c = bp.components[0] -%}
|
| 60 |
{%- set bf = backfill_days -%}
|
| 61 |
-
|
|
|
|
| 62 |
""",
|
| 63 |
"pcr_contrarian": """
|
| 64 |
{%- set c = bp.components[0] -%}
|
| 65 |
-
|
|
|
|
| 66 |
""",
|
| 67 |
"model_score_momentum": """
|
| 68 |
{%- set c = bp.components[0] -%}
|
| 69 |
-
|
|
|
|
| 70 |
""",
|
| 71 |
"alpha15_hybrid": """
|
| 72 |
{%- set c = bp.components[0] -%}
|
| 73 |
{%- set bf = backfill_days -%}
|
| 74 |
-
{%- set
|
| 75 |
-
|
| 76 |
""",
|
| 77 |
}
|
| 78 |
|
|
@@ -91,10 +103,11 @@ OPERATOR ARITY (MUST have exact number of args):
|
|
| 91 |
3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)
|
| 92 |
|
| 93 |
RULES:
|
| 94 |
-
1. Do NOT include ts_decay_linear — added automatically.
|
| 95 |
2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
|
| 96 |
3. Every additive operand MUST be wrapped in zscore() or rank().
|
| 97 |
4. Use group_neutralize(score, group_key) with 2 args always.
|
|
|
|
| 98 |
|
| 99 |
WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
|
| 100 |
RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
|
|
@@ -103,7 +116,7 @@ RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindus
|
|
| 103 |
|
| 104 |
def _sanitize_expression(expr: str) -> str:
|
| 105 |
"""Post-compile sanitization — fix all known LLM mistakes."""
|
| 106 |
-
# 1. Strip quotes around field names
|
| 107 |
expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr)
|
| 108 |
|
| 109 |
# 2. Remove any markdown artifacts
|
|
@@ -135,30 +148,29 @@ def _get_backfill(blueprint: Blueprint) -> int:
|
|
| 135 |
return 10
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
|
| 139 |
"""Mandatory post-compilation transforms."""
|
| 140 |
expr = expression.strip()
|
| 141 |
|
| 142 |
-
#
|
| 143 |
-
|
| 144 |
-
if any(f in ["returns", "close", "ts_returns"] for f in c.fields):
|
| 145 |
-
if c.horizon_days <= 20 and c.sign_direction == "long_high":
|
| 146 |
-
expr = f"-({expr})"
|
| 147 |
-
break
|
| 148 |
-
|
| 149 |
-
# Apply field-level sign from registry
|
| 150 |
-
for c in blueprint.components:
|
| 151 |
-
for fid in c.fields:
|
| 152 |
-
if fid in FIELD_INDEX:
|
| 153 |
-
sign = get_sign_multiplier(FIELD_INDEX[fid])
|
| 154 |
-
if sign == -1 and c.sign_direction != "long_low":
|
| 155 |
-
if not expr.startswith("-"):
|
| 156 |
-
expr = f"-({expr})"
|
| 157 |
-
break
|
| 158 |
|
| 159 |
# ALWAYS wrap in ts_decay_linear (min 5, max 10)
|
| 160 |
decay = min(max(blueprint.decay, 5), 10)
|
| 161 |
-
if
|
| 162 |
expr = f"ts_decay_linear({expr}, {decay})"
|
| 163 |
|
| 164 |
return expr
|
|
@@ -175,12 +187,12 @@ def _validate_expression(expr: str) -> list[str]:
|
|
| 175 |
for f in found_fields:
|
| 176 |
if f.startswith("ts_") or f.startswith("group_") or f.startswith("vec_"):
|
| 177 |
continue
|
| 178 |
-
if f in ("subindustry", "industry", "sector", "market"):
|
| 179 |
continue
|
| 180 |
-
if f.startswith("pv13_"):
|
| 181 |
continue
|
| 182 |
-
|
| 183 |
-
|
| 184 |
return issues
|
| 185 |
|
| 186 |
|
|
@@ -217,7 +229,7 @@ async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression
|
|
| 217 |
archetype_used=blueprint.archetype,
|
| 218 |
)
|
| 219 |
|
| 220 |
-
# LLM fallback
|
| 221 |
fields_list = []
|
| 222 |
for c in blueprint.components:
|
| 223 |
fields_list.extend(c.fields)
|
|
@@ -232,7 +244,8 @@ REMEMBER:
|
|
| 232 |
- Do NOT quote field names. Write field_name not 'field_name'.
|
| 233 |
- ts_* operators need exactly 2 args: ts_rank(field, days)
|
| 234 |
- group_neutralize needs 2 args: group_neutralize(score, {group_key})
|
| 235 |
-
- Do NOT include ts_decay_linear.
|
|
|
|
| 236 |
|
| 237 |
Output ONLY the expression."""
|
| 238 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
Expression Compiler v5 — Fixed per-component sign logic,
|
| 3 |
+
correct weight semantics (weights inside rank don't linearly combine),
|
| 4 |
+
removed double-negation bugs.
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
import re
|
| 7 |
from jinja2 import Environment, BaseLoader
|
|
|
|
| 15 |
"value_quality_blend": """
|
| 16 |
{%- set comps = [] -%}
|
| 17 |
{%- for c in bp.components -%}
|
| 18 |
+
{%- set _ = comps.append("group_zscore(rank(ts_mean(" ~ c.fields[0] ~ ", " ~ c.horizon_days ~ ")), " ~ group_key ~ ")") -%}
|
| 19 |
{%- endfor -%}
|
| 20 |
+
{%- if comps | length > 1 -%}
|
| 21 |
+
ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
|
| 22 |
+
{%- else -%}
|
| 23 |
+
ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
|
| 24 |
+
{%- endif -%}
|
| 25 |
""",
|
| 26 |
"multi_horizon_mr": """
|
| 27 |
{%- set main = bp.components[0] -%}
|
| 28 |
+
{%- set sign = "-" if main.sign_direction == "long_low" else "" -%}
|
| 29 |
+
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }})){%- for c in bp.components[1:] %} + zscore(ts_rank({{ c.fields[0] }}, {{ c.horizon_days }})){%- endfor %}, {{ group_key }}), {{ bp.decay }})
|
| 30 |
""",
|
| 31 |
"vol_scaled_shock": """
|
| 32 |
{%- set c = bp.components[0] -%}
|
| 33 |
+
ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
|
| 34 |
""",
|
| 35 |
"intraday_mr_decay": """
|
| 36 |
{%- set c = bp.components[0] -%}
|
| 37 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 38 |
+
{{ sign }}ts_decay_linear(group_neutralize(zscore(rank({{ c.fields[0] }})), {{ group_key }}), {{ bp.decay }})
|
| 39 |
""",
|
| 40 |
"pead_revisions": """
|
| 41 |
{%- set c = bp.components[0] -%}
|
| 42 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 43 |
+
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
|
| 44 |
""",
|
| 45 |
"fundamental_yield_composite": """
|
| 46 |
{%- set comps = [] -%}
|
| 47 |
{%- for c in bp.components -%}
|
| 48 |
+
{%- set _ = comps.append("zscore(rank(" ~ c.fields[0] ~ "))") -%}
|
| 49 |
{%- endfor -%}
|
| 50 |
+
{%- if comps | length > 1 -%}
|
| 51 |
+
ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
|
| 52 |
+
{%- else -%}
|
| 53 |
+
ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
|
| 54 |
+
{%- endif -%}
|
| 55 |
""",
|
| 56 |
"sue_drift": """
|
| 57 |
{%- set c = bp.components[0] -%}
|
| 58 |
{%- set bf = backfill_days -%}
|
| 59 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 60 |
+
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_backfill({{ c.fields[0] }}, {{ bf }})), {{ group_key }}), {{ bp.decay }})
|
| 61 |
""",
|
| 62 |
"supply_chain_lead_lag": """
|
| 63 |
{%- set c = bp.components[0] -%}
|
| 64 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 65 |
+
{{ sign }}ts_decay_linear(group_neutralize(rank(ts_mean(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
|
| 66 |
""",
|
| 67 |
"analyst_guidance_yield": """
|
| 68 |
{%- set c = bp.components[0] -%}
|
| 69 |
{%- set bf = backfill_days -%}
|
| 70 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 71 |
+
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
|
| 72 |
""",
|
| 73 |
"pcr_contrarian": """
|
| 74 |
{%- set c = bp.components[0] -%}
|
| 75 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 76 |
+
{{ sign }}ts_decay_linear(group_neutralize(rank(-ts_delta(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
|
| 77 |
""",
|
| 78 |
"model_score_momentum": """
|
| 79 |
{%- set c = bp.components[0] -%}
|
| 80 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 81 |
+
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
|
| 82 |
""",
|
| 83 |
"alpha15_hybrid": """
|
| 84 |
{%- set c = bp.components[0] -%}
|
| 85 |
{%- set bf = backfill_days -%}
|
| 86 |
+
{%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
|
| 87 |
+
{{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank((high + low) / 2 - close, 252)) + zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
|
| 88 |
""",
|
| 89 |
}
|
| 90 |
|
|
|
|
| 103 |
3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)
|
| 104 |
|
| 105 |
RULES:
|
| 106 |
+
1. Do NOT include ts_decay_linear — added automatically by the pipeline.
|
| 107 |
2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
|
| 108 |
3. Every additive operand MUST be wrapped in zscore() or rank().
|
| 109 |
4. Use group_neutralize(score, group_key) with 2 args always.
|
| 110 |
+
5. For long_low sign_direction, prefix the field access with a minus sign: -field_name
|
| 111 |
|
| 112 |
WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
|
| 113 |
RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
|
|
|
|
| 116 |
|
| 117 |
def _sanitize_expression(expr: str) -> str:
|
| 118 |
"""Post-compile sanitization — fix all known LLM mistakes."""
|
| 119 |
+
# 1. Strip quotes around field names
|
| 120 |
expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr)
|
| 121 |
|
| 122 |
# 2. Remove any markdown artifacts
|
|
|
|
| 148 |
return 10
|
| 149 |
|
| 150 |
|
| 151 |
+
def _apply_component_signs(expression: str, blueprint: Blueprint) -> str:
    """Return ``expression`` unchanged; placeholder for per-component signs.

    No sign correction is performed here. This is an identity pass-through
    that ``_apply_post_compile_rules`` calls at the point where a
    per-component sign sweep could later be added. It replaces the earlier
    logic that negated the whole expression.

    Args:
        expression: Compiled expression string.
        blueprint: Blueprint the expression was compiled from. Currently
            unused; kept so the hook's signature is stable.

    Returns:
        The ``expression`` argument, unmodified.
    """
    # Intentionally a no-op. Archetype templates that honour sign_direction do
    # so themselves through a Jinja conditional that prefixes "-" for
    # "long_low".
    # NOTE(review): templates without that conditional (for example
    # "vol_scaled_shock") receive no sign adjustment from this hook either.
    return expression
|
| 162 |
+
|
| 163 |
+
|
| 164 |
def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
    """Apply the mandatory post-compilation transforms to an expression.

    The expression is stripped of surrounding whitespace, passed through
    ``_apply_component_signs``, and then wrapped in ``ts_decay_linear``
    unless it already starts with such a call.

    A leading unary sign is ignored when checking for an existing wrapper.
    The archetype templates emit ``-ts_decay_linear(...)`` for ``long_low``
    components, so an expression of that form reaching this function would
    otherwise be wrapped a second time.

    Args:
        expression: Expression text to post-process.
        blueprint: Blueprint the expression was compiled from;
            ``blueprint.decay`` supplies the requested decay window.

    Returns:
        The post-processed expression string.
    """
    expr = _apply_component_signs(expression.strip(), blueprint)

    # Look past any leading unary sign so "-ts_decay_linear(...)" is
    # recognised as already wrapped.
    unsigned = expr.lstrip("+- \t")
    if unsigned.startswith("ts_decay_linear"):
        # NOTE(review): an existing wrapper keeps the window it was rendered
        # with; the 5-10 clamp below only applies to the wrapper added here.
        return expr

    # Clamp the requested window into [5, 10] before wrapping.
    decay = min(max(blueprint.decay, 5), 10)
    return f"ts_decay_linear({expr}, {decay})"
|
|
|
|
| 187 |
for f in found_fields:
|
| 188 |
if f.startswith("ts_") or f.startswith("group_") or f.startswith("vec_"):
|
| 189 |
continue
|
| 190 |
+
if f in ("subindustry", "industry", "sector", "market", "close", "high", "low", "open", "volume", "vwap"):
|
| 191 |
continue
|
| 192 |
+
if f.startswith("pv13_") or f.startswith("mdl") or f.startswith("snt") or f.startswith("scl"):
|
| 193 |
continue
|
| 194 |
+
if f not in FIELD_INDEX:
|
| 195 |
+
issues.append(f"Unknown field: '{f}' (not in FIELD_INDEX)")
|
| 196 |
return issues
|
| 197 |
|
| 198 |
|
|
|
|
| 229 |
archetype_used=blueprint.archetype,
|
| 230 |
)
|
| 231 |
|
| 232 |
+
# LLM fallback for unknown archetypes
|
| 233 |
fields_list = []
|
| 234 |
for c in blueprint.components:
|
| 235 |
fields_list.extend(c.fields)
|
|
|
|
| 244 |
- Do NOT quote field names. Write field_name not 'field_name'.
|
| 245 |
- ts_* operators need exactly 2 args: ts_rank(field, days)
|
| 246 |
- group_neutralize needs 2 args: group_neutralize(score, {group_key})
|
| 247 |
+
- Do NOT include ts_decay_linear (added automatically).
|
| 248 |
+
- For long_low sign_direction, use -field_name.
|
| 249 |
|
| 250 |
Output ONLY the expression."""
|
| 251 |
|