gaurv007 committed on
Commit
b32eb19
·
verified ·
1 Parent(s): 83341a1

Upload alpha_factory/personas/expression_compiler.py with huggingface_hub

Browse files
alpha_factory/personas/expression_compiler.py CHANGED
@@ -1,10 +1,7 @@
1
  """
2
- Expression Compiler v4 — All known issues fixed:
3
- 1. Strips quotes from field names ('field' -> field)
4
- 2. Validates all field IDs exist in registry
5
- 3. Uses novel GROUP keys (pv13_rcsed_6l etc) instead of always sector
6
- 4. Explicit arity in LLM prompt
7
- 5. Post-compile sanitization
8
  """
9
  import re
10
  from jinja2 import Environment, BaseLoader
@@ -18,61 +15,76 @@ TEMPLATES = {
18
  "value_quality_blend": """
19
  {%- set comps = [] -%}
20
  {%- for c in bp.components -%}
21
- {%- set _ = comps.append(c.weight|string ~ " * group_zscore(rank(ts_mean(" ~ c.fields[0] ~ ", " ~ c.horizon_days ~ ")), " ~ group_key ~ ")") -%}
22
  {%- endfor -%}
23
- {{ comps | join(" + ") }}
 
 
 
 
24
  """,
25
  "multi_horizon_mr": """
26
  {%- set main = bp.components[0] -%}
27
- {%- set sign = "-" if bp.components[0].sign_direction == "long_low" else "" -%}
28
- {{ sign }}(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }})) * {{ main.weight }}{% for c in bp.components[1:] %} + zscore(ts_rank({{ c.fields[0] }}, {{ c.horizon_days }})) * {{ c.weight }}{% endfor %})
29
  """,
30
  "vol_scaled_shock": """
31
  {%- set c = bp.components[0] -%}
32
- zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001))
33
  """,
34
  "intraday_mr_decay": """
35
  {%- set c = bp.components[0] -%}
36
- group_neutralize(zscore(rank({{ c.fields[0] }})), {{ group_key }})
 
37
  """,
38
  "pead_revisions": """
39
  {%- set c = bp.components[0] -%}
40
- group_zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}), {{ group_key }})
 
41
  """,
42
  "fundamental_yield_composite": """
43
  {%- set comps = [] -%}
44
  {%- for c in bp.components -%}
45
- {%- set _ = comps.append(c.weight|string ~ " * zscore(rank(" ~ c.fields[0] ~ "))") -%}
46
  {%- endfor -%}
47
- group_neutralize({{ comps | join(" + ") }}, {{ group_key }})
 
 
 
 
48
  """,
49
  "sue_drift": """
50
  {%- set c = bp.components[0] -%}
51
  {%- set bf = backfill_days -%}
52
- group_neutralize(rank(ts_backfill({{ c.fields[0] }}, {{ bf }})), {{ group_key }})
 
53
  """,
54
  "supply_chain_lead_lag": """
55
  {%- set c = bp.components[0] -%}
56
- group_neutralize(rank(ts_mean(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }})
 
57
  """,
58
  "analyst_guidance_yield": """
59
  {%- set c = bp.components[0] -%}
60
  {%- set bf = backfill_days -%}
61
- group_neutralize(zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }})
 
62
  """,
63
  "pcr_contrarian": """
64
  {%- set c = bp.components[0] -%}
65
- group_neutralize(rank(-ts_delta(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }})
 
66
  """,
67
  "model_score_momentum": """
68
  {%- set c = bp.components[0] -%}
69
- group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }})
 
70
  """,
71
  "alpha15_hybrid": """
72
  {%- set c = bp.components[0] -%}
73
  {%- set bf = backfill_days -%}
74
- {%- set sign_mult = "" if c.sign_direction == "long_high" else "-" -%}
75
- group_neutralize({{ sign_mult }}(0.60 * zscore(ts_rank((high + low) / 2 - close, 252)) + 0.40 * zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252))), {{ group_key }})
76
  """,
77
  }
78
 
@@ -91,10 +103,11 @@ OPERATOR ARITY (MUST have exact number of args):
91
  3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)
92
 
93
  RULES:
94
- 1. Do NOT include ts_decay_linear — added automatically.
95
  2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
96
  3. Every additive operand MUST be wrapped in zscore() or rank().
97
  4. Use group_neutralize(score, group_key) with 2 args always.
 
98
 
99
  WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
100
  RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
@@ -103,7 +116,7 @@ RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindus
103
 
104
  def _sanitize_expression(expr: str) -> str:
105
  """Post-compile sanitization — fix all known LLM mistakes."""
106
- # 1. Strip quotes around field names: 'field_name' or "field_name" -> field_name
107
  expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr)
108
 
109
  # 2. Remove any markdown artifacts
@@ -135,30 +148,29 @@ def _get_backfill(blueprint: Blueprint) -> int:
135
  return 10
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
139
  """Mandatory post-compilation transforms."""
140
  expr = expression.strip()
141
 
142
- # Sign flip for short-horizon reversal
143
- for c in blueprint.components:
144
- if any(f in ["returns", "close", "ts_returns"] for f in c.fields):
145
- if c.horizon_days <= 20 and c.sign_direction == "long_high":
146
- expr = f"-({expr})"
147
- break
148
-
149
- # Apply field-level sign from registry
150
- for c in blueprint.components:
151
- for fid in c.fields:
152
- if fid in FIELD_INDEX:
153
- sign = get_sign_multiplier(FIELD_INDEX[fid])
154
- if sign == -1 and c.sign_direction != "long_low":
155
- if not expr.startswith("-"):
156
- expr = f"-({expr})"
157
- break
158
 
159
  # ALWAYS wrap in ts_decay_linear (min 5, max 10)
160
  decay = min(max(blueprint.decay, 5), 10)
161
- if "ts_decay_linear" not in expr:
162
  expr = f"ts_decay_linear({expr}, {decay})"
163
 
164
  return expr
@@ -175,12 +187,12 @@ def _validate_expression(expr: str) -> list[str]:
175
  for f in found_fields:
176
  if f.startswith("ts_") or f.startswith("group_") or f.startswith("vec_"):
177
  continue
178
- if f in ("subindustry", "industry", "sector", "market"):
179
  continue
180
- if f.startswith("pv13_"):
181
  continue
182
- # It's likely a field — check if known
183
- # (skip validation for now since we have partial registry)
184
  return issues
185
 
186
 
@@ -217,7 +229,7 @@ async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression
217
  archetype_used=blueprint.archetype,
218
  )
219
 
220
- # LLM fallback
221
  fields_list = []
222
  for c in blueprint.components:
223
  fields_list.extend(c.fields)
@@ -232,7 +244,8 @@ REMEMBER:
232
  - Do NOT quote field names. Write field_name not 'field_name'.
233
  - ts_* operators need exactly 2 args: ts_rank(field, days)
234
  - group_neutralize needs 2 args: group_neutralize(score, {group_key})
235
- - Do NOT include ts_decay_linear.
 
236
 
237
  Output ONLY the expression."""
238
 
 
1
  """
2
+ Expression Compiler v5 — Fixed per-component sign logic,
3
+ correct weight semantics (weights inside rank don't linearly combine),
4
+ removed double-negation bugs.
 
 
 
5
  """
6
  import re
7
  from jinja2 import Environment, BaseLoader
 
15
  "value_quality_blend": """
16
  {%- set comps = [] -%}
17
  {%- for c in bp.components -%}
18
+ {%- set _ = comps.append("group_zscore(rank(ts_mean(" ~ c.fields[0] ~ ", " ~ c.horizon_days ~ ")), " ~ group_key ~ ")") -%}
19
  {%- endfor -%}
20
+ {%- if comps | length > 1 -%}
21
+ ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
22
+ {%- else -%}
23
+ ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
24
+ {%- endif -%}
25
  """,
26
  "multi_horizon_mr": """
27
  {%- set main = bp.components[0] -%}
28
+ {%- set sign = "-" if main.sign_direction == "long_low" else "" -%}
29
+ {{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank({{ main.fields[0] }}, {{ main.horizon_days }})){%- for c in bp.components[1:] %} + zscore(ts_rank({{ c.fields[0] }}, {{ c.horizon_days }})){%- endfor %}, {{ group_key }}), {{ bp.decay }})
30
  """,
31
  "vol_scaled_shock": """
32
  {%- set c = bp.components[0] -%}
33
+ ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
34
  """,
35
  "intraday_mr_decay": """
36
  {%- set c = bp.components[0] -%}
37
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
38
+ {{ sign }}ts_decay_linear(group_neutralize(zscore(rank({{ c.fields[0] }})), {{ group_key }}), {{ bp.decay }})
39
  """,
40
  "pead_revisions": """
41
  {%- set c = bp.components[0] -%}
42
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
43
+ {{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
44
  """,
45
  "fundamental_yield_composite": """
46
  {%- set comps = [] -%}
47
  {%- for c in bp.components -%}
48
+ {%- set _ = comps.append("zscore(rank(" ~ c.fields[0] ~ "))") -%}
49
  {%- endfor -%}
50
+ {%- if comps | length > 1 -%}
51
+ ts_decay_linear(group_neutralize(rank({{ comps | join(" + ") }}), {{ group_key }}), {{ bp.decay }})
52
+ {%- else -%}
53
+ ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }})
54
+ {%- endif -%}
55
  """,
56
  "sue_drift": """
57
  {%- set c = bp.components[0] -%}
58
  {%- set bf = backfill_days -%}
59
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
60
+ {{ sign }}ts_decay_linear(group_neutralize(rank(ts_backfill({{ c.fields[0] }}, {{ bf }})), {{ group_key }}), {{ bp.decay }})
61
  """,
62
  "supply_chain_lead_lag": """
63
  {%- set c = bp.components[0] -%}
64
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
65
+ {{ sign }}ts_decay_linear(group_neutralize(rank(ts_mean(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
66
  """,
67
  "analyst_guidance_yield": """
68
  {%- set c = bp.components[0] -%}
69
  {%- set bf = backfill_days -%}
70
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
71
+ {{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
72
  """,
73
  "pcr_contrarian": """
74
  {%- set c = bp.components[0] -%}
75
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
76
+ {{ sign }}ts_decay_linear(group_neutralize(rank(-ts_delta(ts_backfill({{ c.fields[0] }}, 30), {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
77
  """,
78
  "model_score_momentum": """
79
  {%- set c = bp.components[0] -%}
80
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
81
+ {{ sign }}ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }})), {{ group_key }}), {{ bp.decay }})
82
  """,
83
  "alpha15_hybrid": """
84
  {%- set c = bp.components[0] -%}
85
  {%- set bf = backfill_days -%}
86
+ {%- set sign = "-" if c.sign_direction == "long_low" else "" -%}
87
+ {{ sign }}ts_decay_linear(group_neutralize(zscore(ts_rank((high + low) / 2 - close, 252)) + zscore(ts_rank(ts_backfill({{ c.fields[0] }}, {{ bf }}), 252)), {{ group_key }}), {{ bp.decay }})
88
  """,
89
  }
90
 
 
103
  3-arg: ts_correlation(x,y,days), if_else(cond,true,false), trade_when(cond,signal,exit)
104
 
105
  RULES:
106
+ 1. Do NOT include ts_decay_linear — added automatically by the pipeline.
107
  2. Do NOT quote field names. Write: ts_rank(field_name, 252) NOT ts_rank('field_name', 252)
108
  3. Every additive operand MUST be wrapped in zscore() or rank().
109
  4. Use group_neutralize(score, group_key) with 2 args always.
110
+ 5. For long_low sign_direction, prefix the field access with a minus sign: -field_name
111
 
112
  WRONG: ts_mean(field) ts_rank(field) group_neutralize(rank(x))
113
  RIGHT: ts_mean(field, 20) ts_rank(field, 252) group_neutralize(rank(x), subindustry)
 
116
 
117
  def _sanitize_expression(expr: str) -> str:
118
  """Post-compile sanitization — fix all known LLM mistakes."""
119
+ # 1. Strip quotes around field names
120
  expr = re.sub(r"['\"]([a-z][a-z0-9_]+)['\"]", r"\1", expr)
121
 
122
  # 2. Remove any markdown artifacts
 
148
  return 10
149
 
150
 
151
+ def _apply_component_signs(expression: str, blueprint: Blueprint) -> str:
152
+ """
153
+ Apply per-component sign corrections based on field sign conventions.
154
+
155
+ Unlike the old code, which blindly negated the entire expression,
156
+ this hook is intended to adjust only individual field references whose
157
+ registered sign convention conflicts with the component's declared sign.
158
+ """
159
+ # For now, the templates handle sign_direction correctly via Jinja conditionals.
160
+ # This function is a hook for future per-component sign sweeps.
161
+ return expression
162
+
163
+
164
  def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
165
  """Mandatory post-compilation transforms."""
166
  expr = expression.strip()
167
 
168
+ # Apply component-level sign corrections (now handled in templates)
169
+ expr = _apply_component_signs(expr, blueprint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  # ALWAYS wrap in ts_decay_linear (min 5, max 10)
172
  decay = min(max(blueprint.decay, 5), 10)
173
+ if not expr.startswith("ts_decay_linear"):
174
  expr = f"ts_decay_linear({expr}, {decay})"
175
 
176
  return expr
 
187
  for f in found_fields:
188
  if f.startswith("ts_") or f.startswith("group_") or f.startswith("vec_"):
189
  continue
190
+ if f in ("subindustry", "industry", "sector", "market", "close", "high", "low", "open", "volume", "vwap"):
191
  continue
192
+ if f.startswith("pv13_") or f.startswith("mdl") or f.startswith("snt") or f.startswith("scl"):
193
  continue
194
+ if f not in FIELD_INDEX:
195
+ issues.append(f"Unknown field: '{f}' (not in FIELD_INDEX)")
196
  return issues
197
 
198
 
 
229
  archetype_used=blueprint.archetype,
230
  )
231
 
232
+ # LLM fallback for unknown archetypes
233
  fields_list = []
234
  for c in blueprint.components:
235
  fields_list.extend(c.fields)
 
244
  - Do NOT quote field names. Write field_name not 'field_name'.
245
  - ts_* operators need exactly 2 args: ts_rank(field, days)
246
  - group_neutralize needs 2 args: group_neutralize(score, {group_key})
247
+ - Do NOT include ts_decay_linear (added automatically).
248
+ - For long_low sign_direction, use -field_name.
249
 
250
  Output ONLY the expression."""
251