gaurv007 committed on
Commit
b59d9a6
·
verified ·
1 Parent(s): 68817c3

Upload alpha_factory/infra/wq_client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. alpha_factory/infra/wq_client.py +110 -24
alpha_factory/infra/wq_client.py CHANGED
@@ -1,19 +1,47 @@
1
  """
2
- WorldQuant BRAIN API Client — async wrapper with rate limiting.
3
- Handles submission, polling, and result harvesting.
 
 
 
 
4
  """
5
  import asyncio
6
  import aiohttp
7
  import hashlib
 
8
  from typing import Any, Optional
9
  from ..config import BrainConfig
10
  from ..schemas import BrainMetrics
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class BrainClient:
14
  """
15
  Async client for WorldQuant BRAIN API.
16
  Rate-limited, retry-on-429, circuit-breaker-protected.
 
 
 
17
  """
18
 
19
  def __init__(self, session: aiohttp.ClientSession, config: BrainConfig):
@@ -23,6 +51,46 @@ class BrainClient:
23
  self._submissions_today = 0
24
  self._consecutive_failures = 0
25
  self._circuit_open = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  async def submit_alpha(
28
  self,
@@ -35,7 +103,9 @@ class BrainClient:
35
  Returns the simulation response or raises on failure.
36
  """
37
  if self._circuit_open:
38
- raise RuntimeError("Circuit breaker OPEN — too many consecutive failures")
 
 
39
 
40
  async with self.semaphore:
41
  # Rate limiting
@@ -61,56 +131,72 @@ class BrainClient:
61
  async with self.session.post(
62
  f"{self.config.api_url}/simulations",
63
  json=payload,
 
64
  ) as resp:
65
  if resp.status == 429:
66
- # Rate limited — back off
67
- await asyncio.sleep(30)
 
68
  return await self.submit_alpha(expression, neutralization, decay)
69
- elif resp.status >= 500:
70
- self._consecutive_failures += 1
71
- if self._consecutive_failures >= 5:
72
- self._circuit_open = True
73
- raise aiohttp.ClientResponseError(
74
- resp.request_info, resp.history, status=resp.status
75
- )
76
 
77
  resp.raise_for_status()
78
  self._consecutive_failures = 0
79
  self._submissions_today += 1
80
  return await resp.json()
81
 
82
- except aiohttp.ClientError as e:
83
  self._consecutive_failures += 1
84
  if self._consecutive_failures >= 5:
85
  self._circuit_open = True
86
- raise
87
 
88
  async def poll_simulation(self, sim_id: str, max_wait: int = 300) -> dict[str, Any]:
89
  """Poll a simulation until completion or timeout."""
90
  elapsed = 0
91
  interval = 5
 
92
 
93
  while elapsed < max_wait:
94
- async with self.session.get(
95
- f"{self.config.api_url}/simulations/{sim_id}"
96
- ) as resp:
97
- if resp.status == 200:
98
- data = await resp.json()
99
- if data.get("status") == "DONE":
100
- return data
101
- elif data.get("status") == "ERROR":
102
- raise RuntimeError(f"Simulation failed: {data.get('error', 'unknown')}")
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  await asyncio.sleep(interval)
105
  elapsed += interval
106
 
107
- raise TimeoutError(f"Simulation {sim_id} timed out after {max_wait}s")
 
 
108
 
109
  def parse_metrics(self, sim_result: dict, alpha_id: str) -> BrainMetrics:
110
  """Extract structured metrics from BRAIN simulation result."""
111
  stats = sim_result.get("stats", {})
112
  yearly = sim_result.get("yearly", [])
113
 
 
 
 
 
 
114
  return BrainMetrics(
115
  alpha_id=alpha_id,
116
  sharpe_full=stats.get("sharpe", 0.0),
 
1
  """
2
+ WorldQuant BRAIN API Client v3 — async wrapper with proper session auth,
3
+ rate limiting, circuit breaker, and comprehensive exception handling.
4
+
5
+ NOTE: BRAIN uses session-based auth (cookies), not API keys.
6
+ Session tokens expire quickly. The client reads BRAIN_SESSION_TOKEN
7
+ from env and includes it as a Cookie header.
8
  """
9
  import asyncio
10
  import aiohttp
11
  import hashlib
12
+ import os
13
  from typing import Any, Optional
14
  from ..config import BrainConfig
15
  from ..schemas import BrainMetrics
16
 
17
 
18
class BrainClientError(Exception):
    """Base exception for BRAIN client errors.

    Raised directly for an open circuit breaker, a failed transport-level
    request, a simulation that reports an error, or a polling timeout.
    Catch this type to handle every failure the client signals itself.
    """


class BrainAuthError(BrainClientError):
    """Raised when BRAIN authentication fails (401) or no session token is set."""


class BrainRateLimitError(BrainClientError):
    """Raised when BRAIN rate limits the request (429)."""


class BrainServerError(BrainClientError):
    """Raised when BRAIN returns a 5xx error."""
36
+
37
+
38
  class BrainClient:
39
  """
40
  Async client for WorldQuant BRAIN API.
41
  Rate-limited, retry-on-429, circuit-breaker-protected.
42
+
43
+ Auth: reads BRAIN_SESSION_TOKEN from environment.
44
+ This is a browser session cookie that must be refreshed manually.
45
  """
46
 
47
  def __init__(self, session: aiohttp.ClientSession, config: BrainConfig):
 
51
  self._submissions_today = 0
52
  self._consecutive_failures = 0
53
  self._circuit_open = False
54
+
55
+ # Auth: session token from env
56
+ self._session_token = os.getenv("BRAIN_SESSION_TOKEN", "")
57
+ if not self._session_token:
58
+ # Also try legacy name
59
+ self._session_token = os.getenv("BRAIN_AUTH", "")
60
+
61
+ self._headers = {
62
+ "Content-Type": "application/json",
63
+ "Accept": "application/json",
64
+ }
65
+ if self._session_token:
66
+ # BRAIN uses session cookies, NOT Bearer tokens
67
+ self._headers["Cookie"] = f"session={self._session_token}"
68
+
69
def _check_auth(self) -> None:
    """Verify that a usable session token is configured.

    Raises:
        BrainAuthError: If ``self._session_token`` is empty or contains
            only whitespace.
    """
    token = self._session_token
    if token and token.strip():
        return

    # The constructor builds the header as ``Cookie: session=<token>``, so
    # the user must supply only the cookie's value. Pasting the whole Cookie
    # header would produce ``session=session=...`` and fail authentication.
    raise BrainAuthError(
        "No BRAIN_SESSION_TOKEN set. "
        "Get it from browser devtools after logging into brain.worldquant.com "
        "(Network tab → any API request → Cookie header → copy only the value "
        "of the 'session' cookie, without the 'session=' prefix). "
        "Set: export BRAIN_SESSION_TOKEN=your_token"
    )
78
+
79
def _handle_response_error(self, resp: aiohttp.ClientResponse) -> None:
    """Translate an HTTP error status into a typed BRAIN exception.

    Only 401, 429 and 5xx are mapped. Any other status makes this method
    return ``None`` without raising, so the caller has to deal with it
    (``submit_alpha`` follows up with ``resp.raise_for_status()``).

    Args:
        resp: The response whose ``status`` is inspected. The body is not read.

    Raises:
        BrainAuthError: On 401.
        BrainRateLimitError: On 429. Does not count toward the circuit breaker.
        BrainServerError: On any status >= 500.
    """
    status = resp.status

    # 401 and 5xx both count as failures. Apply the breaker threshold in one
    # place so that a run of auth failures opens the circuit exactly like a
    # run of server errors, instead of only bumping the counter.
    if status == 401 or status >= 500:
        self._consecutive_failures += 1
        if self._consecutive_failures >= 5:
            self._circuit_open = True

    if status == 401:
        raise BrainAuthError(
            "BRAIN auth failed (401). Session token expired. "
            "Get a new token from browser and set BRAIN_SESSION_TOKEN."
        )
    if status == 429:
        raise BrainRateLimitError("BRAIN rate limit hit (429). Backing off.")
    if status >= 500:
        raise BrainServerError(f"BRAIN server error: {status}")
94
 
95
  async def submit_alpha(
96
  self,
 
103
  Returns the simulation response or raises on failure.
104
  """
105
  if self._circuit_open:
106
+ raise BrainClientError("Circuit breaker OPEN — too many consecutive failures")
107
+
108
+ self._check_auth()
109
 
110
  async with self.semaphore:
111
  # Rate limiting
 
131
  async with self.session.post(
132
  f"{self.config.api_url}/simulations",
133
  json=payload,
134
+ headers=self._headers,
135
  ) as resp:
136
  if resp.status == 429:
137
+ # Rate limited — back off linearly (30s + 10s per consecutive failure); no jitter is applied
138
+ wait = 30 + (self._consecutive_failures * 10)
139
+ await asyncio.sleep(wait)
140
  return await self.submit_alpha(expression, neutralization, decay)
141
+ elif resp.status >= 400:
142
+ self._handle_response_error(resp)
 
 
 
 
 
143
 
144
  resp.raise_for_status()
145
  self._consecutive_failures = 0
146
  self._submissions_today += 1
147
  return await resp.json()
148
 
149
+ except (aiohttp.ClientError, asyncio.TimeoutError, ConnectionError) as e:
150
  self._consecutive_failures += 1
151
  if self._consecutive_failures >= 5:
152
  self._circuit_open = True
153
+ raise BrainClientError(f"BRAIN API request failed: {e}") from e
154
 
155
  async def poll_simulation(self, sim_id: str, max_wait: int = 300) -> dict[str, Any]:
156
  """Poll a simulation until completion or timeout."""
157
  elapsed = 0
158
  interval = 5
159
+ last_error = None
160
 
161
  while elapsed < max_wait:
162
+ try:
163
+ async with self.session.get(
164
+ f"{self.config.api_url}/simulations/{sim_id}",
165
+ headers=self._headers,
166
+ ) as resp:
167
+ if resp.status == 200:
168
+ data = await resp.json()
169
+ if data.get("status") == "DONE":
170
+ return data
171
+ elif data.get("status") == "ERROR":
172
+ raise BrainClientError(
173
+ f"Simulation failed: {data.get('error', 'unknown')}"
174
+ )
175
+ elif resp.status == 401:
176
+ raise BrainAuthError("Auth expired during polling. Refresh BRAIN_SESSION_TOKEN.")
177
+ elif resp.status >= 500:
178
+ last_error = f"Server error {resp.status}"
179
+ # Non-fatal errors: log and retry
180
+ except (aiohttp.ClientError, asyncio.TimeoutError, ConnectionError) as e:
181
+ last_error = str(e)
182
 
183
  await asyncio.sleep(interval)
184
  elapsed += interval
185
 
186
+ raise BrainClientError(
187
+ f"Simulation {sim_id} timed out after {max_wait}s. Last error: {last_error}"
188
+ )
189
 
190
  def parse_metrics(self, sim_result: dict, alpha_id: str) -> BrainMetrics:
191
  """Extract structured metrics from BRAIN simulation result."""
192
  stats = sim_result.get("stats", {})
193
  yearly = sim_result.get("yearly", [])
194
 
195
+ # BRAIN sometimes nests stats differently
196
+ if not stats and "result" in sim_result:
197
+ stats = sim_result["result"].get("stats", {})
198
+ yearly = sim_result["result"].get("yearly", [])
199
+
200
  return BrainMetrics(
201
  alpha_id=alpha_id,
202
  sharpe_full=stats.get("sharpe", 0.0),