gaurv007 commited on
Commit
6c7bc78
·
verified ·
1 Parent(s): 5453fed

Upload alpha_factory/infra/wq_client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. alpha_factory/infra/wq_client.py +143 -0
alpha_factory/infra/wq_client.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WorldQuant BRAIN API Client — async wrapper with rate limiting.
3
+ Handles submission, polling, and result harvesting.
4
+ """
5
+ import asyncio
6
+ import aiohttp
7
+ import hashlib
8
+ from typing import Any, Optional
9
+ from ..config import BrainConfig
10
+ from ..schemas import BrainMetrics
11
+
12
+
13
+ class BrainClient:
14
+ """
15
+ Async client for WorldQuant BRAIN API.
16
+ Rate-limited, retry-on-429, circuit-breaker-protected.
17
+ """
18
+
19
+ def __init__(self, session: aiohttp.ClientSession, config: BrainConfig):
20
+ self.session = session
21
+ self.config = config
22
+ self.semaphore = asyncio.Semaphore(config.max_concurrent)
23
+ self._submissions_today = 0
24
+ self._consecutive_failures = 0
25
+ self._circuit_open = False
26
+
27
+ async def submit_alpha(
28
+ self,
29
+ expression: str,
30
+ neutralization: str = "sector",
31
+ decay: int = 5,
32
+ ) -> dict[str, Any]:
33
+ """
34
+ Submit an alpha expression to BRAIN for simulation.
35
+ Returns the simulation response or raises on failure.
36
+ """
37
+ if self._circuit_open:
38
+ raise RuntimeError("Circuit breaker OPEN — too many consecutive failures")
39
+
40
+ async with self.semaphore:
41
+ # Rate limiting
42
+ await asyncio.sleep(self.config.submit_interval_sec)
43
+
44
+ payload = {
45
+ "type": "REGULAR",
46
+ "settings": {
47
+ "instrumentType": "EQUITY",
48
+ "region": self.config.region,
49
+ "universe": self.config.universe,
50
+ "delay": self.config.delay,
51
+ "neutralization": neutralization,
52
+ "decay": decay,
53
+ "truncation": self.config.truncation,
54
+ "pasteurization": self.config.pasteurization,
55
+ "nanHandling": self.config.nan_handling,
56
+ },
57
+ "regular": expression,
58
+ }
59
+
60
+ try:
61
+ async with self.session.post(
62
+ f"{self.config.api_url}/simulations",
63
+ json=payload,
64
+ ) as resp:
65
+ if resp.status == 429:
66
+ # Rate limited — back off
67
+ await asyncio.sleep(30)
68
+ return await self.submit_alpha(expression, neutralization, decay)
69
+ elif resp.status >= 500:
70
+ self._consecutive_failures += 1
71
+ if self._consecutive_failures >= 5:
72
+ self._circuit_open = True
73
+ raise aiohttp.ClientResponseError(
74
+ resp.request_info, resp.history, status=resp.status
75
+ )
76
+
77
+ resp.raise_for_status()
78
+ self._consecutive_failures = 0
79
+ self._submissions_today += 1
80
+ return await resp.json()
81
+
82
+ except aiohttp.ClientError as e:
83
+ self._consecutive_failures += 1
84
+ if self._consecutive_failures >= 5:
85
+ self._circuit_open = True
86
+ raise
87
+
88
+ async def poll_simulation(self, sim_id: str, max_wait: int = 300) -> dict[str, Any]:
89
+ """Poll a simulation until completion or timeout."""
90
+ elapsed = 0
91
+ interval = 5
92
+
93
+ while elapsed < max_wait:
94
+ async with self.session.get(
95
+ f"{self.config.api_url}/simulations/{sim_id}"
96
+ ) as resp:
97
+ if resp.status == 200:
98
+ data = await resp.json()
99
+ if data.get("status") == "DONE":
100
+ return data
101
+ elif data.get("status") == "ERROR":
102
+ raise RuntimeError(f"Simulation failed: {data.get('error', 'unknown')}")
103
+
104
+ await asyncio.sleep(interval)
105
+ elapsed += interval
106
+
107
+ raise TimeoutError(f"Simulation {sim_id} timed out after {max_wait}s")
108
+
109
+ def parse_metrics(self, sim_result: dict, alpha_id: str) -> BrainMetrics:
110
+ """Extract structured metrics from BRAIN simulation result."""
111
+ stats = sim_result.get("stats", {})
112
+ yearly = sim_result.get("yearly", [])
113
+
114
+ return BrainMetrics(
115
+ alpha_id=alpha_id,
116
+ sharpe_full=stats.get("sharpe", 0.0),
117
+ sharpe_is=stats.get("sharpe_is", stats.get("sharpe", 0.0)),
118
+ sharpe_os=stats.get("sharpe_os", stats.get("sharpe", 0.0)),
119
+ fitness=stats.get("fitness", 0.0),
120
+ turnover=stats.get("turnover", 0.0),
121
+ returns=stats.get("returns", 0.0),
122
+ max_drawdown=stats.get("max_drawdown", 0.0),
123
+ yearly_sharpe=[y.get("sharpe", 0.0) for y in yearly],
124
+ yearly_returns=[y.get("returns", 0.0) for y in yearly],
125
+ margin_pct=stats.get("margin", None),
126
+ long_count=stats.get("long_count", None),
127
+ short_count=stats.get("short_count", None),
128
+ )
129
+
130
+ @property
131
+ def submissions_today(self) -> int:
132
+ return self._submissions_today
133
+
134
+ def reset_circuit(self):
135
+ """Manually reset the circuit breaker."""
136
+ self._circuit_open = False
137
+ self._consecutive_failures = 0
138
+
139
+ @staticmethod
140
+ def alpha_hash(expression: str, neutralization: str, decay: int) -> str:
141
+ """Deterministic hash for dedup."""
142
+ key = f"{expression.strip()}|{neutralization}|{decay}"
143
+ return hashlib.sha256(key.encode()).hexdigest()[:16]