gaurv007 commited on
Commit
8cf7a55
·
verified ·
1 Parent(s): 8267e98

Upload alpha_factory/personas/expression_compiler.py

Browse files
alpha_factory/personas/expression_compiler.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
- Expression Compiler v5 — Fixed per-component sign logic,
3
  correct weight semantics (weights inside rank don't linearly combine),
4
- removed double-negation bugs.
5
  """
6
  import re
7
  from jinja2 import Environment, BaseLoader
@@ -30,7 +30,7 @@ ts_decay_linear(group_neutralize({{ comps[0] }}, {{ group_key }}), {{ bp.decay }
30
  """,
31
  "vol_scaled_shock": """
32
  {%- set c = bp.components[0] -%}
33
- ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001)), {{ group_key }}), {{ bp.decay }})
34
  """,
35
  "intraday_mr_decay": """
36
  {%- set c = bp.components[0] -%}
@@ -148,34 +148,6 @@ def _get_backfill(blueprint: Blueprint) -> int:
148
  return 10
149
 
150
 
151
- def _apply_component_signs(expression: str, blueprint: Blueprint) -> str:
152
- """
153
- Apply per-component sign corrections based on field sign conventions.
154
-
155
- Unlike the old code which blindly negated the entire expression,
156
- this only adjusts individual field references based on their
157
- registered sign_direction mismatch with the component's declared sign.
158
- """
159
- # For now, the templates handle sign_direction correctly via Jinja conditionals.
160
- # This function is a hook for future per-component sign sweeps.
161
- return expression
162
-
163
-
164
- def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
165
- """Mandatory post-compilation transforms."""
166
- expr = expression.strip()
167
-
168
- # Apply component-level sign corrections (now handled in templates)
169
- expr = _apply_component_signs(expr, blueprint)
170
-
171
- # ALWAYS wrap in ts_decay_linear (min 5, max 10)
172
- decay = min(max(blueprint.decay, 5), 10)
173
- if not expr.startswith("ts_decay_linear"):
174
- expr = f"ts_decay_linear({expr}, {decay})"
175
-
176
- return expr
177
-
178
-
179
  def _validate_expression(expr: str) -> list[str]:
180
  """Validate expression has no obvious errors. Returns list of issues."""
181
  issues = []
@@ -196,6 +168,26 @@ def _validate_expression(expr: str) -> list[str]:
196
  return issues
197
 
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
200
  """Convert a Blueprint to a BRAIN expression."""
201
 
@@ -263,6 +255,13 @@ Output ONLY the expression."""
263
  if "ts_decay_linear" not in result.operators_used:
264
  result.operators_used.append("ts_decay_linear")
265
 
 
 
 
 
 
 
 
266
  return result
267
 
268
 
 
1
  """
2
+ Expression Compiler v6 — Fixed per-component sign logic,
3
  correct weight semantics (weights inside rank don't linearly combine),
4
+ removed double-negation bugs, WIRED field validation.
5
  """
6
  import re
7
  from jinja2 import Environment, BaseLoader
 
30
  """,
31
  "vol_scaled_shock": """
32
  {%- set c = bp.components[0] -%}
33
+ ts_decay_linear(group_neutralize(zscore(ts_delta({{ c.fields[0] }}, {{ c.horizon_days }}) / (ts_std({{ c.fields[0] }}, {{ c.horizon_days * 4 }}) + 0.001))), {{ group_key }}), {{ bp.decay }})
34
  """,
35
  "intraday_mr_decay": """
36
  {%- set c = bp.components[0] -%}
 
148
  return 10
149
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def _validate_expression(expr: str) -> list[str]:
152
  """Validate expression has no obvious errors. Returns list of issues."""
153
  issues = []
 
168
  return issues
169
 
170
 
171
+ def _apply_post_compile_rules(expression: str, blueprint: Blueprint) -> str:
172
+ """Mandatory post-compilation transforms."""
173
+ expr = expression.strip()
174
+
175
+ # ALWAYS wrap in ts_decay_linear (min 5, max 10)
176
+ decay = min(max(blueprint.decay, 5), 10)
177
+ if not expr.startswith("ts_decay_linear"):
178
+ expr = f"ts_decay_linear({expr}, {decay})"
179
+
180
+ # Validate after all transforms
181
+ issues = _validate_expression(expr)
182
+ if issues:
183
+ # Log but don't crash — some fields may be legitimate but not in our registry yet
184
+ import logging
185
+ logger = logging.getLogger(__name__)
186
+ logger.debug(f"Expression validation issues: {issues}")
187
+
188
+ return expr
189
+
190
+
191
  async def compile_expression(blueprint: Blueprint, llm: LLMClient) -> Expression:
192
  """Convert a Blueprint to a BRAIN expression."""
193
 
 
255
  if "ts_decay_linear" not in result.operators_used:
256
  result.operators_used.append("ts_decay_linear")
257
 
258
+ # Run field validation on LLM output and surface warnings
259
+ issues = _validate_expression(result.expression)
260
+ if issues:
261
+ import logging
262
+ logger = logging.getLogger(__name__)
263
+ logger.warning(f"LLM expression has issues: {issues}")
264
+
265
  return result
266
 
267