alpha-factory / alpha_factory /deterministic /expression_mutator.py
gaurv007's picture
Upload alpha_factory/deterministic/expression_mutator.py
8b456b8 verified
"""
Expression Mutator — Evolutionary alpha improvement.
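Takes a base alpha expression and generates structural variants drawn from
five mutation families (each family may contribute zero or more variants).
Mutations:
1. Decay adjustment (decay window set to 3, 5, 7 or 10 days)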
2. Horizon change (different lookback window)
3. Neutralization swap (novel group key)
4. Vol-scaling wrapper
5. Sign flip (test opposite direction)
"""
import re
from ..data.brain_groups import ALT_GROUPS, PRODUCTION_GROUPS
def _set_first_decay_window(expression: str, new_decay: int) -> str:
    """Return *expression* with the window of its first ``ts_decay_linear`` call replaced.

    Parentheses are matched explicitly, so only that call's window changes even
    when other windowed calls follow it (in ``ts_decay_linear(x, 5) + ts_mean(y, 21)``
    only the ``5`` is touched). The expression is returned unchanged when there is
    no ``ts_decay_linear(`` call, its parentheses never close, it has a single
    argument, or its last argument is not an integer literal.
    """
    call = "ts_decay_linear("
    start = expression.find(call)
    if start < 0:
        return expression
    open_idx = start + len(call) - 1
    depth = 0
    last_comma = -1
    close_idx = -1
    for idx in range(open_idx, len(expression)):
        char = expression[idx]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                close_idx = idx
                break
        elif char == "," and depth == 1:
            # Only commas directly inside this call separate its own arguments.
            last_comma = idx
    if close_idx < 0 or last_comma < 0:
        return expression
    if not re.fullmatch(r"\s*\d+\s*", expression[last_comma + 1:close_idx]):
        return expression
    inner = expression[open_idx + 1:last_comma]
    return (
        f"{expression[:start]}ts_decay_linear({inner}, {new_decay})"
        f"{expression[close_idx + 1:]}"
    )


def mutate_decay(expression: str, current_decay: int) -> list[dict]:
    """Generate variants of *expression* with a different ``ts_decay_linear`` window.

    Args:
        expression: Alpha expression containing a ``ts_decay_linear(..., N)`` call.
        current_decay: The decay the caller considers current. That candidate
            is skipped, and the value is used in the mutation label.

    Returns:
        One dict per candidate decay (3, 5, 7, 10) that actually changes the
        expression. Keys are ``expression``, ``mutation``
        (``"decay_<current>_to_<new>"``) and ``decay`` (the new window).
        Empty when the expression has no ``ts_decay_linear(..., N)`` call.
    """
    variants = []
    for new_decay in [3, 5, 7, 10]:
        if new_decay == current_decay:
            continue
        new_expr = _set_first_decay_window(expression, new_decay)
        # Skip no-ops, e.g. when the window already in the expression equals
        # this candidate even though it differs from current_decay.
        if new_expr != expression:
            variants.append({
                "expression": new_expr,
                "mutation": f"decay_{current_decay}_to_{new_decay}",
                "decay": new_decay,
            })
    return variants
def mutate_horizon(expression: str) -> list[dict]:
    """Produce at most two variants that use a different lookback window.

    The first ``ts_rank`` / ``ts_mean`` / ``ts_delta`` call matched by the
    pattern below supplies the current window. That window is swapped, in
    order, for 21, 42, 63, 126 and 252 (skipping the current value). Only the
    first two expressions that actually change are returned. Each result's
    ``decay`` is None, meaning the caller keeps its original decay.
    """
    found = re.search(r"ts_(?:rank|mean|delta)\([^,]+,\s*(\d+)\)", expression)
    if found is None:
        return []
    current = int(found.group(1))

    results: list[dict] = []
    for horizon in (21, 42, 63, 126, 252):
        if len(results) == 2:
            break
        if horizon == current:
            continue
        # count=1: only the first matching call gets the new window.
        candidate = re.sub(
            r"(ts_(?:rank|mean|delta)\([^,]+,\s*)\d+(\))",
            f"\\g<1>{horizon}\\2",
            expression,
            count=1,
        )
        if candidate == expression:
            continue
        results.append({
            "expression": candidate,
            "mutation": f"horizon_{current}_to_{horizon}",
            "decay": None,
        })
    return results
def _call_last_argument(expression: str, open_idx: int):
    """Return the raw text of the last top-level argument of the call opened at *open_idx*.

    *open_idx* must index the call's ``(``. Returns None when the parentheses
    never close or the call has only one argument.
    """
    depth = 0
    last_comma = -1
    for idx in range(open_idx, len(expression)):
        char = expression[idx]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return expression[last_comma + 1:idx] if last_comma >= 0 else None
        elif char == "," and depth == 1:
            last_comma = idx
    return None


def _first_group_key(expression: str):
    """Return the bare group key of the first ``group_*`` call that ends in one, else None.

    Each ``group_neutralize(`` / ``group_zscore(`` / ``group_rank(`` call is
    inspected in order, and its last top-level argument is taken. Commas inside
    nested calls are therefore ignored: in ``group_neutralize(ts_rank(x, 21),
    industry)`` the key is ``industry``, not ``21``.
    """
    for call in re.finditer(r"group_(?:neutralize|zscore|rank)\(", expression):
        last_arg = _call_last_argument(expression, call.end() - 1)
        if last_arg is None:
            continue
        key = last_arg.strip()
        if re.fullmatch(r"[a-z0-9_]+", key):
            return key
    return None


def mutate_neutralization(expression: str, *, groups=None) -> list[dict]:
    """Swap the neutralization group key for up to two alternative groups.

    Args:
        expression: Alpha expression containing a ``group_neutralize`` /
            ``group_zscore`` / ``group_rank`` call whose last argument is a bare key.
        groups: Optional iterable of candidate group objects exposing ``.id``.
            Defaults to ``PRODUCTION_GROUPS``.

    Returns:
        Up to two dicts with ``expression``, ``mutation``
        (``"group_<old>_to_<new>"``) and ``decay`` (None). The candidates are
        the first three groups whose id differs from the current key. Empty
        when no group key is found.
    """
    current_group = _first_group_key(expression)
    if current_group is None:
        return []
    candidates = PRODUCTION_GROUPS if groups is None else groups
    novel_groups = [g for g in candidates if g.id != current_group][:3]
    # Whole-identifier match, so "industry" does not also rewrite
    # "subindustry" or "industry_momentum".
    key_pattern = re.compile(rf"\b{re.escape(current_group)}\b")
    variants = []
    for g in novel_groups:
        # A callable replacement keeps backslashes in the id from being
        # interpreted as template escapes.
        new_expr = key_pattern.sub(lambda _m, gid=g.id: gid, expression)
        if new_expr != expression:
            variants.append({
                "expression": new_expr,
                "mutation": f"group_{current_group}_to_{g.id}",
                "decay": None,
            })
    return variants[:2]
def mutate_vol_scale(expression: str) -> list[dict]:
    """Divide the first ts_rank/ts_mean input field by its own rolling volatility.

    In the first ``ts_rank(field, N)`` or ``ts_mean(field, N)`` call, ``field``
    becomes ``field / (ts_std(field, W) + 0.001)`` with ``W = min(2 * N, 252)``.
    The 0.001 guards against division by zero. Every identical copy of that
    call is rewritten. Returns a one-element list, or ``[]`` when no such call
    exists.
    """
    hit = re.search(r"(ts_(?:rank|mean)\()([a-z][a-z0-9_]+)(,\s*\d+\))", expression)
    if hit is None:
        return []
    prefix, field, suffix = hit.groups()

    digits = re.search(r"(\d+)", suffix)
    if digits is None:
        return []
    vol_window = min(int(digits.group(1)) * 2, 252)

    original_call = prefix + field + suffix
    scaled_call = prefix + f"{field} / (ts_std({field}, {vol_window}) + 0.001)" + suffix
    rewritten = expression.replace(original_call, scaled_call)
    if rewritten == expression:
        return []
    return [{"expression": rewritten, "mutation": "vol_scaled", "decay": None}]
def _is_negated_group(expr: str) -> bool:
    """Return True when *expr* is exactly one negated group ``-( ... )``.

    The parenthesis right after the minus sign must be the one closed by the
    final character. ``-(a) * (b)`` starts with ``-(`` and ends with ``)`` but
    does not qualify.
    """
    if not (expr.startswith("-(") and expr.endswith(")")):
        return False
    depth = 0
    last = len(expr) - 1
    for idx, char in enumerate(expr[1:], start=1):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return idx == last
    return False


def mutate_sign_flip(expression: str) -> list[dict]:
    """Flip the sign of the signal inside an outer ``ts_decay_linear`` wrapper.

    Only applies when the whole expression is ``ts_decay_linear(<inner>, N)``.
    An inner that is a single negated group ``-( ... )`` has its negation
    removed; any other inner is wrapped as ``-(<inner>)``. The decay window is
    kept in the expression, and the returned ``decay`` is None.

    Returns:
        A one-element list with ``expression``, ``mutation`` (``"sign_flip"``)
        and ``decay``, or ``[]`` when the expression is not decay-wrapped.
    """
    match = re.match(r"ts_decay_linear\((.+),\s*(\d+)\)$", expression)
    if not match:
        return []
    inner_expr = match.group(1).strip()
    decay = match.group(2)
    if _is_negated_group(inner_expr):
        # Remove the existing negation.
        flipped = inner_expr[2:-1]
    else:
        # Add negation; parenthesize so precedence is unambiguous.
        flipped = f"-({inner_expr})"
    new_expr = f"ts_decay_linear({flipped}, {decay})"
    return [{"expression": new_expr, "mutation": "sign_flip", "decay": None}]
def generate_mutations(expression: str, decay: int = 5) -> list[dict]:
    """Collect variants from every mutation family for one base expression.

    Families run in a fixed order: decay, horizon, neutralization,
    vol-scaling, then sign flip. Their results are concatenated in that order.
    Each item is a dict with ``expression``, ``mutation`` and ``decay``;
    ``decay`` is None when the original decay should be kept.

    Args:
        expression: Base alpha expression to mutate.
        decay: Decay the caller currently uses; passed to the decay family.
    """
    return [
        *mutate_decay(expression, decay),
        *mutate_horizon(expression),
        *mutate_neutralization(expression),
        *mutate_vol_scale(expression),
        *mutate_sign_flip(expression),
    ]