alpha-factory / alpha_factory /infra /winner_memory.py
gaurv007's picture
feat: winner_memory.py — stores BRAIN results, identifies winning patterns, feeds back to generation
ff03a17 verified
"""
Winner Memory — Feedback loop component.
Tracks what works and what doesn't, feeds patterns back to generation.
Key principle: The system LEARNS from BRAIN results.
- Winning fields get higher priority in future generation
- Failing patterns get blacklisted
- Near-misses get auto-iterated
"""
import duckdb
from pathlib import Path
from dataclasses import dataclass
from typing import Optional
@dataclass
class WinnerPattern:
    """Record describing one pattern taken from a successful alpha.

    Attributes:
        field_id: Identifier of the field the alpha was built on.
        archetype: Name of the archetype the alpha followed.
        group_key: Grouping key associated with the alpha
            (presumably the neutralization group; confirm with callers).
        decay: Integer decay setting of the alpha.
        sharpe: Sharpe value reported for the alpha.
        theme: Theme label attached to the alpha.
    """

    field_id: str
    archetype: str
    group_key: str
    decay: int
    sharpe: float
    theme: str
@dataclass
class FailurePattern:
    """Record describing one pattern taken from a failed alpha.

    Attributes:
        field_id: Identifier of the field the alpha was built on.
        archetype: Name of the archetype the alpha followed.
        reason: Short failure tag, e.g. "low_sharpe", "high_turnover",
            "flat_line" or "unknown_op".
    """

    field_id: str
    archetype: str
    reason: str
class WinnerMemory:
    """
    Stores and retrieves patterns from BRAIN simulation results.

    Used by hypothesis hunter and template generator to prioritize what works.
    Data lives in a DuckDB database with three tables: ``winner_patterns``,
    ``failure_patterns`` and ``iteration_queue``.

    The instance owns its connection. Call :meth:`close` when finished, or
    use the instance as a context manager so the connection is always closed.
    """

    def __init__(self, db_path: Path):
        """Open the DuckDB database at *db_path* and make sure the tables exist.

        Args:
            db_path: Location of the database file; passed to
                ``duckdb.connect`` as a string.
        """
        self.db_path = db_path
        self.conn = duckdb.connect(str(db_path))
        try:
            self._ensure_tables()
        except Exception:
            # Do not leave a half-initialised instance holding an open
            # connection if the schema could not be created.
            self.conn.close()
            raise

    def __enter__(self) -> "WinnerMemory":
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection on leaving the ``with`` block.

        Returns ``None`` so any exception raised inside the block propagates.
        """
        self.close()

    def _ensure_tables(self):
        """Create the three backing tables if they do not exist yet."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS winner_patterns (
                field_id VARCHAR,
                archetype VARCHAR,
                group_key VARCHAR,
                decay INTEGER,
                sharpe DOUBLE,
                theme VARCHAR,
                discovered_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS failure_patterns (
                field_id VARCHAR,
                archetype VARCHAR,
                reason VARCHAR,
                expression_hash VARCHAR,
                recorded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS iteration_queue (
                alpha_id VARCHAR,
                expression VARCHAR,
                sharpe DOUBLE,
                turnover DOUBLE,
                suggestion VARCHAR,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                iterated BOOLEAN DEFAULT FALSE
            )
        """)

    def record_winner(self, field_id: str, archetype: str, group_key: str,
                      decay: int, sharpe: float, theme: str):
        """Record a winning pattern (Sharpe > 1.25).

        The Sharpe threshold is the caller's responsibility; this method
        inserts the row as given without checking it.
        """
        self.conn.execute("""
            INSERT INTO winner_patterns (field_id, archetype, group_key, decay, sharpe, theme)
            VALUES (?, ?, ?, ?, ?, ?)
        """, [field_id, archetype, group_key, decay, sharpe, theme])

    def record_failure(self, field_id: str, archetype: str, reason: str, expr_hash: str):
        """Record a failure pattern.

        Args:
            field_id: Field the failed alpha used.
            archetype: Archetype the failed alpha followed.
            reason: Short failure tag stored verbatim.
            expr_hash: Stored in the ``expression_hash`` column.
        """
        self.conn.execute("""
            INSERT INTO failure_patterns (field_id, archetype, reason, expression_hash)
            VALUES (?, ?, ?, ?)
        """, [field_id, archetype, reason, expr_hash])

    def queue_for_iteration(self, alpha_id: str, expression: str,
                            sharpe: float, turnover: float, suggestion: str):
        """Queue a near-miss alpha for auto-iteration.

        The row is inserted with ``iterated`` left at its default (FALSE).
        """
        self.conn.execute("""
            INSERT INTO iteration_queue (alpha_id, expression, sharpe, turnover, suggestion)
            VALUES (?, ?, ?, ?, ?)
        """, [alpha_id, expression, sharpe, turnover, suggestion])

    def get_winning_fields(self, min_sharpe: float = 1.5) -> list[str]:
        """Get field IDs that have produced winners.

        Args:
            min_sharpe: Only winner rows with ``sharpe >= min_sharpe`` count.

        Returns:
            Distinct field IDs, ordered by each field's best qualifying
            Sharpe, highest first.
        """
        # Grouping (rather than SELECT DISTINCT ... ORDER BY sharpe) gives
        # each field a single, well-defined sort key: ordering a DISTINCT
        # result by a column that is not in the select list is ambiguous.
        rows = self.conn.execute("""
            SELECT field_id
            FROM winner_patterns
            WHERE sharpe >= ?
            GROUP BY field_id
            ORDER BY MAX(sharpe) DESC
        """, [min_sharpe]).fetchall()
        return [row[0] for row in rows]

    def get_winning_archetypes(self) -> list[tuple[str, float]]:
        """Get archetypes ranked by average Sharpe.

        Returns:
            ``(archetype, average_sharpe)`` pairs, highest average first.
        """
        rows = self.conn.execute("""
            SELECT archetype, AVG(sharpe) AS avg_sharpe
            FROM winner_patterns
            GROUP BY archetype
            ORDER BY avg_sharpe DESC
        """).fetchall()
        return [(archetype, avg_sharpe) for archetype, avg_sharpe in rows]

    def get_failed_fields(self) -> set[str]:
        """Get fields that consistently fail (3+ failures, no wins).

        Returns:
            Field IDs with at least three rows in ``failure_patterns`` and
            no row at all in ``winner_patterns``.
        """
        # One query with NOT EXISTS instead of a follow-up lookup per field.
        rows = self.conn.execute("""
            SELECT f.field_id
            FROM failure_patterns AS f
            WHERE NOT EXISTS (
                SELECT 1 FROM winner_patterns AS w
                WHERE w.field_id = f.field_id
            )
            GROUP BY f.field_id
            HAVING COUNT(*) >= 3
        """).fetchall()
        return {row[0] for row in rows}

    def get_iteration_queue(self, limit: int = 10) -> list[dict]:
        """Get alphas queued for iteration.

        Args:
            limit: Maximum number of queue entries to return.

        Returns:
            Not-yet-iterated entries as dicts with keys ``alpha_id``,
            ``expression``, ``sharpe``, ``turnover`` and ``suggestion``,
            ordered by Sharpe, highest first.
        """
        rows = self.conn.execute("""
            SELECT alpha_id, expression, sharpe, turnover, suggestion
            FROM iteration_queue
            WHERE iterated = FALSE
            ORDER BY sharpe DESC
            LIMIT ?
        """, [limit]).fetchall()
        keys = ("alpha_id", "expression", "sharpe", "turnover", "suggestion")
        return [dict(zip(keys, row)) for row in rows]

    def mark_iterated(self, alpha_id: str):
        """Mark an iteration queue item as processed.

        Every queue row with this ``alpha_id`` is flagged as iterated.
        """
        self.conn.execute("""
            UPDATE iteration_queue SET iterated = TRUE WHERE alpha_id = ?
        """, [alpha_id])

    def get_best_config(self) -> Optional[dict]:
        """Get the best-performing configuration from winners.

        Returns:
            A dict with keys ``field_id``, ``archetype``, ``group_key``,
            ``decay`` and ``sharpe`` for the winner row with the highest
            Sharpe, or ``None`` when no winners have been recorded.
        """
        row = self.conn.execute("""
            SELECT field_id, archetype, group_key, decay, sharpe
            FROM winner_patterns
            ORDER BY sharpe DESC LIMIT 1
        """).fetchone()
        if row is None:
            return None
        keys = ("field_id", "archetype", "group_key", "decay", "sharpe")
        return dict(zip(keys, row))

    def suggest_iteration(self, sharpe: float, turnover: float) -> str:
        """Suggest what to change for a near-miss alpha.

        Rules are checked in order and the first match wins:

        * ``turnover > 0.70``  -> ``"increase_decay"`` (regardless of Sharpe)
        * ``sharpe < 0.5``     -> ``"sign_flip"``
        * ``sharpe < 1.0``     -> ``"change_horizon"``
        * otherwise            -> ``"change_neutralization"``
        """
        if turnover > 0.70:
            return "increase_decay"
        if sharpe < 0.5:
            return "sign_flip"
        if sharpe < 1.0:
            return "change_horizon"
        # sharpe >= 1.0; callers presumably only pass near-misses that are
        # still below the 1.25 winner bar, but that is not enforced here.
        return "change_neutralization"

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()