File size: 5,212 Bytes
4579313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b456b8
4579313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
Expression Mutator — Evolutionary alpha improvement.
Takes a base alpha expression and generates structural variants using 5 mutation types.

Mutations:
1. Decay adjustment (decay set to each of 3, 5, 7, 10 days other than the current one)
2. Horizon change (different lookback window)
3. Neutralization swap (novel group key)
4. Vol-scaling wrapper
5. Sign flip (test opposite direction)
"""
import re
from ..data.brain_groups import ALT_GROUPS, PRODUCTION_GROUPS


def _decay_window_spans(expression: str) -> list[tuple[int, int]]:
    """Locate the window argument of every outermost ``ts_decay_linear`` call.

    Returns ``(comma_index, close_index)`` pairs in left-to-right order:
    ``comma_index`` is the comma separating the window from the signal and
    ``close_index`` is the call's own closing parenthesis. Only calls whose
    last argument is an integer literal are reported. Parenthesis depth is
    tracked, so commas/``)`` that belong to nested calls (the ``21`` in
    ``ts_decay_linear(ts_rank(x, 21), 5)``) or to code after the call (the
    ``21`` in ``ts_decay_linear(a, 5) * ts_rank(b, 21)``) are never mistaken
    for the decay window.
    """
    spans = []
    resume_at = 0
    for call in re.finditer(r"ts_decay_linear\(", expression):
        if call.start() < resume_at:
            continue  # nested inside a call that was already handled
        args_start = call.end()
        close = -1
        depth = 0
        for pos, char in enumerate(expression[args_start:], start=args_start):
            if char == "(":
                depth += 1
            elif char == ")":
                if depth == 0:
                    close = pos
                    break
                depth -= 1
        if close == -1:
            continue  # unbalanced: the call never closes
        resume_at = close
        args = expression[args_start:close]
        comma = args.rfind(",")
        # An integer tail contains no ')', so this comma is at the call's own depth.
        if comma != -1 and re.fullmatch(r"\s*\d+\s*", args[comma + 1:]):
            spans.append((args_start + comma, close))
    return spans


def mutate_decay(expression: str, current_decay: int) -> list[dict]:
    """Generate variants with a different ``ts_decay_linear`` window.

    Each candidate decay in (3, 5, 7, 10) other than ``current_decay`` is
    written into the window argument of every outermost ``ts_decay_linear``
    call. A variant whose text is unchanged is dropped.

    Args:
        expression: Alpha expression to mutate.
        current_decay: Decay the caller considers current. It is skipped as a
            candidate and used in the ``mutation`` label.

    Returns:
        Dicts with ``expression``, ``mutation`` (``"decay_<old>_to_<new>"``)
        and ``decay`` (the new decay value). Empty when the expression has no
        ``ts_decay_linear(..., <int>)`` call.
    """
    spans = _decay_window_spans(expression)
    if not spans:
        return []

    variants = []
    for new_decay in [3, 5, 7, 10]:
        if new_decay == current_decay:
            continue
        new_expr = expression
        # Rewrite right-to-left so the earlier span indices stay valid.
        for comma, close in reversed(spans):
            new_expr = f"{new_expr[:comma]}, {new_decay}{new_expr[close:]}"
        if new_expr != expression:
            variants.append({
                "expression": new_expr,
                "mutation": f"decay_{current_decay}_to_{new_decay}",
                "decay": new_decay,
            })
    return variants


def mutate_horizon(expression: str) -> list[dict]:
    """Change the lookback window of the first ts_rank / ts_mean / ts_delta call.

    The first such call whose trailing argument is an integer literal is
    located. That window is then replaced by each standard horizon
    (21, 42, 63, 126, 252) other than the current one.

    The call's first argument may contain one level of parenthesised
    sub-expression, e.g. ``ts_rank(ts_decay_linear(x, 5), 63)``. The window
    rewritten is always the ts_* call's own trailing argument (``63`` here),
    never a number inside the nested call.

    Returns:
        At most two dicts with ``expression``, ``mutation``
        (``"horizon_<old>_to_<new>"``) and ``decay`` set to None. Empty when
        no matching call is found.
    """
    horizons = [21, 42, 63, 126, 252]

    # One match drives both detection and replacement, so the window reported
    # as "current" is exactly the one that gets rewritten.
    match = re.search(
        r"ts_(?:rank|mean|delta)\((?:[^,()]|\([^()]*\))+,\s*(\d+)\)",
        expression,
    )
    if not match:
        return []

    current = int(match.group(1))
    head, tail = expression[:match.start(1)], expression[match.end(1):]

    variants = []
    for h in horizons:
        if h == current:
            continue
        new_expr = f"{head}{h}{tail}"
        if new_expr != expression:
            variants.append({
                "expression": new_expr,
                "mutation": f"horizon_{current}_to_{h}",
                "decay": None,  # Keep original
            })
    return variants[:2]  # Max 2 horizon variants


def mutate_neutralization(expression: str) -> list[dict]:
    """Swap the group key of the first group_neutralize/zscore/rank call.

    Only the matched group argument is rewritten. Other text that happens to
    contain the same characters (e.g. a field name that embeds the key) is
    left untouched. The first argument may contain one level of nested
    parentheses, so ``group_rank(ts_rank(x, 21), sector)`` resolves to the
    key ``sector`` rather than the window ``21``.

    Returns:
        At most two dicts with ``expression``, ``mutation``
        (``"group_<old>_to_<new>"``) and ``decay`` set to None, built from
        the first entries of ``PRODUCTION_GROUPS`` whose ``id`` differs from
        the current key. Empty when no group call is found.
    """
    match = re.search(
        r"group_(?:neutralize|zscore|rank)\((?:[^,()]|\([^()]*\))+,\s*([a-z0-9_]+)\)",
        expression,
    )
    if not match:
        return []

    current_group = match.group(1)
    # Splice around the matched key only, instead of a global str.replace.
    head, tail = expression[:match.start(1)], expression[match.end(1):]
    # NOTE(review): the module docstring speaks of a "novel" group key, yet
    # candidates come from PRODUCTION_GROUPS (ALT_GROUPS is imported but not
    # used here). Confirm which pool is intended.
    novel_groups = [g for g in PRODUCTION_GROUPS if g.id != current_group][:3]

    variants = []
    for g in novel_groups:
        new_expr = f"{head}{g.id}{tail}"
        if new_expr != expression:
            variants.append({
                "expression": new_expr,
                "mutation": f"group_{current_group}_to_{g.id}",
                "decay": None,
            })
    return variants[:2]


def mutate_vol_scale(expression: str) -> list[dict]:
    """Volatility-normalise the plain field inside the first ts_rank/ts_mean call.

    ``ts_rank(field, w)`` becomes
    ``ts_rank(field / (ts_std(field, min(2 * w, 252)) + 0.001), w)``. Every
    identical occurrence of that call text is rewritten.

    Returns:
        A single dict with ``expression``, ``mutation="vol_scaled"`` and
        ``decay=None``, or ``[]`` when no ts_rank/ts_mean call takes a bare
        field as its first argument.
    """
    hit = re.search(r"(ts_(?:rank|mean)\()([a-z][a-z0-9_]+)(,\s*\d+\))", expression)
    if hit is None:
        return []

    opener, field, closer = hit.groups()
    # ``closer`` is ",<ws><digits>)", so its first digit run is the lookback.
    lookback = int(re.search(r"(\d+)", closer).group(1))
    std_window = min(lookback * 2, 252)

    original_call = f"{opener}{field}{closer}"
    scaled_call = f"{opener}{field} / (ts_std({field}, {std_window}) + 0.001){closer}"
    rewritten = expression.replace(original_call, scaled_call)

    if rewritten == expression:
        return []
    return [{"expression": rewritten, "mutation": "vol_scaled", "decay": None}]


def _matching_close_paren(text: str, open_index: int) -> int:
    """Return the index of the ')' closing the '(' at ``open_index``, or -1.

    ``text[open_index]`` is expected to be '('.
    """
    depth = 0
    for index, char in enumerate(text[open_index:], start=open_index):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index
    return -1


def mutate_sign_flip(expression: str) -> list[dict]:
    """Flip the sign of the signal inside a top-level ts_decay_linear wrapper.

    Applies only when the whole expression is
    ``ts_decay_linear(<signal>, <int>)``. The wrapper's opening parenthesis
    must close at the very end, so inputs such as
    ``ts_decay_linear(a, 5) + ts_decay_linear(b, 5)`` are left alone.

    A signal of the form ``-(...)`` whose parenthesis spans the whole signal
    is un-negated; anything else is wrapped as ``-(<signal>)``. A signal such
    as ``-(a) * (b)`` merely starts with ``-(`` and ends with ``)``, so it is
    negated rather than mangled.

    Returns:
        A single dict with ``expression``, ``mutation="sign_flip"`` and
        ``decay=None``, or ``[]`` when the expression is not wrapped.
    """
    wrapper = "ts_decay_linear("
    body = expression.rstrip()
    if not body.startswith(wrapper):
        return []
    if _matching_close_paren(body, len(wrapper) - 1) != len(body) - 1:
        return []

    # The decay is an integer literal, so the last comma in the argument list
    # is the wrapper's own separator. A comma inside a nested call would be
    # followed by that call's ')' and fail the integer check below.
    inner_expr, comma, decay = body[len(wrapper):-1].rpartition(",")
    decay = decay.strip()
    inner_expr = inner_expr.strip()
    if not comma or not inner_expr or not re.fullmatch(r"\d+", decay):
        return []

    is_negated = (
        inner_expr.startswith("-(")
        and _matching_close_paren(inner_expr, 1) == len(inner_expr) - 1
    )
    # Remove an existing whole-signal negation, otherwise add one.
    flipped = inner_expr[2:-1] if is_negated else f"-({inner_expr})"

    new_expr = f"ts_decay_linear({flipped}, {decay})"
    return [{"expression": new_expr, "mutation": "sign_flip", "decay": None}]


def generate_mutations(expression: str, decay: int = 5) -> list[dict]:
    """
    Build every mutation variant of ``expression``.

    The five mutators run in a fixed order and their outputs are concatenated
    in that order: decay, horizon, neutralization, vol-scaling, sign flip.
    Each variant is a dict with 'expression', 'mutation' and 'decay' keys.
    'decay' is None for every mutator except the decay mutator.

    Args:
        expression: Base alpha expression to mutate.
        decay: Passed to ``mutate_decay`` as ``current_decay``.
    """
    mutators = (
        lambda expr: mutate_decay(expr, decay),  # 1. decay
        mutate_horizon,                          # 2. horizon
        mutate_neutralization,                   # 3. neutralization
        mutate_vol_scale,                        # 4. vol-scaling
        mutate_sign_flip,                        # 5. sign flip
    )
    return [variant for mutate in mutators for variant in mutate(expression)]