# alpha-factory/alpha_factory/infra/factor_store.py
# Uploaded by gaurv007 -- "Upload alpha_factory/infra/factor_store.py" (commit 8a49c31, verified)
"""
Factor Store — DuckDB persistence for all alphas.
Single source of truth for every alpha ever submitted.
"""
import duckdb
from pathlib import Path
from datetime import datetime
from typing import Optional
from ..schemas import BrainMetrics, Verdict
# DDL executed by FactorStore.__init__ each time a store is opened. Both
# statements use CREATE TABLE IF NOT EXISTS, so re-running them against an
# existing database leaves existing tables in place.
#
#   alphas      - one row per alpha, keyed by alpha_id: the expression and its
#                 simulation settings, descriptive tags, BRAIN metrics, the
#                 pipeline's fitness score / library correlation, and the
#                 final verdict with its memo.
#   dead_themes - themes that were killed, with the time they were killed and
#                 the timestamp until which they stay on cooldown.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS alphas (
alpha_id VARCHAR PRIMARY KEY,
submitted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expression TEXT NOT NULL,
neutralization VARCHAR NOT NULL,
decay INTEGER NOT NULL,
universe VARCHAR DEFAULT 'TOP3000',
region VARCHAR DEFAULT 'USA',
delay_days INTEGER DEFAULT 1,
fields_used VARCHAR[],
operators_used VARCHAR[],
archetype VARCHAR,
theme VARCHAR,
anomaly_tag VARCHAR,
academic_anchor VARCHAR,
sharpe_full DOUBLE,
sharpe_is DOUBLE,
sharpe_os DOUBLE,
fitness_brain DOUBLE,
yearly_sharpe DOUBLE[],
yearly_returns DOUBLE[],
turnover DOUBLE,
max_drawdown DOUBLE,
returns_total DOUBLE,
margin_pct DOUBLE,
fitness_score DOUBLE,
max_corr_to_library DOUBLE,
verdict VARCHAR,
gatekeeper_memo TEXT,
iteration INTEGER DEFAULT 1,
family_id VARCHAR,
created_by VARCHAR DEFAULT 'pipeline'
);
CREATE TABLE IF NOT EXISTS dead_themes (
theme VARCHAR,
universe VARCHAR,
date_killed TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_sharpe DOUBLE,
reason TEXT,
cooldown_until TIMESTAMP
);
"""
class FactorStore:
    """DuckDB-backed store holding the history of every alpha.

    A single connection is opened in ``__init__`` and reused by every method
    until ``close`` is called. All values reach SQL through ``?`` placeholders;
    no statement is assembled by string concatenation.

    The store can be used as a context manager, in which case the connection
    is closed when the ``with`` block ends.
    """

    def __init__(self, db_path: Path):
        """Open the database at *db_path* and apply the schema.

        Args:
            db_path: Location of the DuckDB file. Missing parent directories
                are created before connecting.
        """
        self.db_path = db_path
        db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = duckdb.connect(str(db_path))
        self.conn.execute(SCHEMA_SQL)

    def __enter__(self):
        """Return the store itself so it can be bound with ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the connection on leaving the block; exceptions propagate."""
        self.close()

    def insert_alpha(
        self,
        alpha_id: str,
        expression: str,
        neutralization: str,
        decay: int,
        fields_used: list[str],
        operators_used: list[str],
        archetype: str,
        theme: str,
        anomaly_tag: str,
        academic_anchor: Optional[str] = None,
        family_id: Optional[str] = None,
        iteration: int = 1,
    ):
        """Insert a new alpha candidate (before BRAIN results arrive).

        The statement is ``INSERT OR REPLACE``, so submitting an ``alpha_id``
        that already exists overwrites that row instead of raising.

        Args:
            alpha_id: Primary key of the alpha.
            expression: Alpha expression text.
            neutralization: Neutralization setting.
            decay: Decay setting.
            fields_used: Data fields referenced by the expression.
            operators_used: Operators referenced by the expression.
            archetype: Archetype label.
            theme: Theme label.
            anomaly_tag: Anomaly tag label.
            academic_anchor: Optional academic reference.
            family_id: Optional identifier grouping related alphas.
            iteration: Iteration number stored with the alpha.
        """
        # NOTE: the expression is produced upstream (LLM output) and is only
        # ever passed as a bound parameter, never interpolated into the SQL.
        self.conn.execute(
            """
            INSERT OR REPLACE INTO alphas (alpha_id, expression, neutralization, decay,
                fields_used, operators_used, archetype, theme, anomaly_tag,
                academic_anchor, family_id, iteration)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                alpha_id, expression, neutralization, decay,
                fields_used, operators_used, archetype, theme, anomaly_tag,
                academic_anchor, family_id, iteration,
            ],
        )

    def update_metrics(self, alpha_id: str, metrics: "BrainMetrics", fitness_score: float):
        """Store BRAIN simulation results on an existing alpha row.

        The write is one UPDATE statement, so it is applied atomically without
        an explicit transaction. The connection is deliberately NOT used as a
        context manager: leaving a ``with conn:`` block closes a DuckDB
        connection, which would break every later call on this store.

        Args:
            alpha_id: Row to update.
            metrics: Simulation results; the attributes read are listed in the
                parameter list below.
            fitness_score: Value written to the ``fitness_score`` column.
        """
        self.conn.execute(
            """
            UPDATE alphas SET
                sharpe_full = ?, sharpe_is = ?, sharpe_os = ?,
                fitness_brain = ?, turnover = ?, returns_total = ?,
                max_drawdown = ?, yearly_sharpe = ?, yearly_returns = ?,
                margin_pct = ?, fitness_score = ?
            WHERE alpha_id = ?
            """,
            [
                metrics.sharpe_full, metrics.sharpe_is, metrics.sharpe_os,
                metrics.fitness, metrics.turnover, metrics.returns,
                metrics.max_drawdown, metrics.yearly_sharpe, metrics.yearly_returns,
                metrics.margin_pct, fitness_score, alpha_id,
            ],
        )

    def update_verdict(self, alpha_id: str, verdict: "Verdict", memo: str = ""):
        """Set the final verdict (stored as ``verdict.value``) and memo for an alpha."""
        self.conn.execute(
            "UPDATE alphas SET verdict = ?, gatekeeper_memo = ? WHERE alpha_id = ?",
            [verdict.value, memo, alpha_id],
        )

    def update_correlation(self, alpha_id: str, max_corr: float):
        """Record the alpha's maximum correlation to the library."""
        self.conn.execute(
            "UPDATE alphas SET max_corr_to_library = ? WHERE alpha_id = ?",
            [max_corr, alpha_id],
        )

    def get_all_themes(self) -> list[str]:
        """Return the theme of every alpha that has one (duplicates included)."""
        rows = self.conn.execute(
            "SELECT theme FROM alphas WHERE theme IS NOT NULL"
        ).fetchall()
        return [row[0] for row in rows]

    def get_all_anomaly_tags(self) -> list[str]:
        """Return the anomaly tag of every alpha that has one (duplicates included)."""
        rows = self.conn.execute(
            "SELECT anomaly_tag FROM alphas WHERE anomaly_tag IS NOT NULL"
        ).fetchall()
        return [row[0] for row in rows]

    def get_dead_themes(self) -> list[str]:
        """Return themes whose ``cooldown_until`` is still in the future."""
        rows = self.conn.execute(
            "SELECT theme FROM dead_themes WHERE cooldown_until > CURRENT_TIMESTAMP"
        ).fetchall()
        return [row[0] for row in rows]

    def exists(self, alpha_id: str) -> bool:
        """Return True when a row with this ``alpha_id`` is already stored (dedup)."""
        row = self.conn.execute(
            "SELECT 1 FROM alphas WHERE alpha_id = ?", [alpha_id]
        ).fetchone()
        return row is not None

    def get_expression_hashes(self) -> set[str]:
        """Return every stored ``alpha_id`` as a set, for dedup checks."""
        rows = self.conn.execute("SELECT alpha_id FROM alphas").fetchall()
        return {row[0] for row in rows}

    def count_consecutive_kills(self) -> int:
        """Count the unbroken run of ``kill`` verdicts among the newest alphas.

        Rows are read newest first and counting stops at the first row whose
        verdict is anything other than ``"kill"``; a row with no verdict yet
        (NULL) therefore ends the run as well. Only the 50 most recent rows are
        inspected, so the result never exceeds 50.
        """
        # rowid is a tiebreaker so ordering stays deterministic when several
        # rows share the same submitted_at timestamp.
        rows = self.conn.execute(
            """
            SELECT verdict FROM alphas
            ORDER BY submitted_at DESC, rowid DESC
            LIMIT 50
            """
        ).fetchall()
        streak = 0
        for (verdict,) in rows:
            if verdict != "kill":
                break
            streak += 1
        return streak

    def kill_theme(self, theme: str, last_sharpe: float, reason: str, cooldown_days: int = 180):
        """Add a theme to the dead list with a cooldown.

        Args:
            theme: Theme being retired.
            last_sharpe: Sharpe value recorded alongside the kill.
            reason: Free-text explanation.
            cooldown_days: Number of days from now until the cooldown ends.
        """
        # The day count goes through to_days() because a bare placeholder is
        # not accepted directly after the INTERVAL keyword; this keeps the
        # value a bound parameter instead of formatting it into the SQL.
        self.conn.execute(
            """
            INSERT INTO dead_themes (theme, universe, last_sharpe, reason, cooldown_until)
            VALUES (?, 'TOP3000', ?, ?, CURRENT_TIMESTAMP + to_days(CAST(? AS INTEGER)))
            """,
            [theme, last_sharpe, reason, cooldown_days],
        )

    def get_library_stats(self) -> dict:
        """Return summary counts and the mean out-of-sample Sharpe.

        Returns:
            A dict with ``total_alphas``, ``promoted``, ``killed``, ``pending``
            (total minus promoted minus killed) and ``avg_sharpe_os`` rounded
            to three decimals (0 when no row has ``sharpe_os``).
        """
        def scalar(sql: str):
            # Each query returns exactly one row with one column.
            return self.conn.execute(sql).fetchone()[0]

        total = scalar("SELECT COUNT(*) FROM alphas")
        promoted = scalar("SELECT COUNT(*) FROM alphas WHERE verdict = 'promote'")
        killed = scalar("SELECT COUNT(*) FROM alphas WHERE verdict = 'kill'")
        avg_sharpe = scalar("SELECT AVG(sharpe_os) FROM alphas WHERE sharpe_os IS NOT NULL")
        return {
            "total_alphas": total,
            "promoted": promoted,
            "killed": killed,
            "pending": total - promoted - killed,
            "avg_sharpe_os": round(avg_sharpe or 0, 3),
        }

    def close(self):
        """Close the underlying DuckDB connection."""
        self.conn.close()