gaurv007 committed (verified)
Commit 1c0e301 · 1 Parent(s): 4ddefc8

Fix LLM client: handle HF API response format limitations + better error handling

Files changed (1):
  1. alpha_factory/infra/llm_client.py  +95 -38
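Judging from this diff, the HuggingFace Inference router does not handle the `response_format={"type": "json_schema"}` guided-decoding path the way vLLM and Ollama do, and completions can come back empty. Structured generation therefore now embeds the JSON schema in the prompt itself and falls back from json_schema to json_object to plain text plus a new `_parse_json_response` extraction step. Below is a rough usage sketch; the `generate_structured` method name, the config stand-in, and the model names are assumptions, not confirmed by the commit.

```python
# Hedged usage sketch only: the `generate_structured` method name, the config
# stand-in below, and the model names are assumptions. Only the parameter names
# (prompt, schema, tier, ...) and the config attributes read in this diff
# (base_url, api_key, microfish_model, tinyfish_model, temperature_generation,
# max_tokens) come from the commit itself.
import asyncio
from types import SimpleNamespace

from pydantic import BaseModel

from alpha_factory.infra.llm_client import LLMClient


class FactorProposal(BaseModel):  # hypothetical schema, not part of the repo
    name: str
    expression: str
    rationale: str


config = SimpleNamespace(                     # stand-in for LLMConfig (constructor not shown in this diff)
    base_url="http://localhost:11434/v1",     # local Ollama, per the class docstring
    api_key="dummy",
    microfish_model="qwen2.5:1.5b-instruct",  # placeholder model names
    tinyfish_model="qwen2.5:3b-instruct",
    mediumfish_model="qwen2.5:7b-instruct",   # medium/big tier attribute names are guessed by analogy
    bigfish_model="qwen2.5:14b-instruct",
    temperature_generation=0.3,
    max_tokens=2048,
)


async def main() -> None:
    client = LLMClient(config)
    proposal = await client.generate_structured(   # method name assumed, see note above
        prompt="Propose one mean-reversion factor on daily close prices.",
        schema=FactorProposal,
        tier="mediumfish",
    )
    print(proposal.model_dump_json(indent=2))


asyncio.run(main())
```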
alpha_factory/infra/llm_client.py CHANGED
@@ -5,6 +5,7 @@ All outputs are schema-constrained via guided JSON generation.
 """
 import asyncio
 import json
+import re
 from typing import TypeVar
 from pydantic import BaseModel
 from openai import AsyncOpenAI
@@ -19,8 +20,6 @@ class LLMClient:
     - Ollama (local, http://localhost:11434/v1)
     - HuggingFace Inference API (cloud, https://router.huggingface.co/v1)
     - vLLM (local/remote, any OpenAI-compatible endpoint)
-
-    All outputs are JSON-schema-constrained for reliability.
     """
 
     def __init__(self, config: LLMConfig, model_manager=None):
@@ -29,24 +28,19 @@ class LLMClient:
         self._clients: dict[str, AsyncOpenAI] = {}
         self._token_count = 0
 
-    def _get_client(self, base_url: str, api_key: str = "dummy", **headers) -> AsyncOpenAI:
+    def _get_client(self, base_url: str, api_key: str = "dummy", **kwargs) -> AsyncOpenAI:
         """Get or create an AsyncOpenAI client for the given endpoint."""
         key = f"{base_url}|{api_key}"
         if key not in self._clients:
             self._clients[key] = AsyncOpenAI(
                 base_url=base_url,
                 api_key=api_key,
-                default_headers=headers if headers else None,
             )
         return self._clients[key]
 
     def _resolve_model(self, tier: str = "mediumfish", model_override: str | None = None) -> tuple[AsyncOpenAI, str]:
-        """
-        Resolve which client + model to use for a given tier.
-        Priority: model_override > ModelManager selection > config default
-        """
+        """Resolve which client + model to use for a given tier."""
         if model_override:
-            # Direct model name — use default endpoint
            client = self._get_client(self.config.base_url, self.config.api_key)
            return client, model_override
 
@@ -56,7 +50,6 @@ class LLMClient:
             client = self._get_client(base_url, api_key)
             return client, model_name
 
-        # Fallback: use config defaults
         tier_to_model = {
             "microfish": self.config.microfish_model,
             "tinyfish": self.config.tinyfish_model,
@@ -78,26 +71,30 @@
     ) -> T:
         """
         Generate a structured JSON response conforming to the given Pydantic schema.
-        Uses guided decoding via response_format.
-
-        Args:
-            prompt: The user prompt
-            schema: Pydantic model class for output validation
-            tier: Model tier (microfish/tinyfish/mediumfish/bigfish)
-            model: Override model name (optional)
-            temperature: Override temperature (optional)
-            system_prompt: System message
+        Tries multiple strategies for JSON output:
+        1. response_format: json_schema (vLLM/Ollama)
+        2. response_format: json_object (some providers)
+        3. Plain text with JSON extraction (fallback for HF)
         """
         client, model_name = self._resolve_model(tier, model)
         temp = temperature or self.config.temperature_generation
         json_schema = schema.model_json_schema()
 
+        # Build the schema instruction to embed in prompt
+        schema_str = json.dumps(json_schema, indent=2)
+        json_instruction = (
+            f"\n\nYou MUST output ONLY valid JSON matching this exact schema. "
+            f"No markdown, no explanation, no ```json blocks — ONLY the raw JSON object.\n"
+            f"Schema:\n{schema_str}"
+        )
+
+        # Strategy 1: Try json_schema format (works with vLLM, newer Ollama)
         try:
             response = await client.chat.completions.create(
                 model=model_name,
                 messages=[
                     {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": prompt},
+                    {"role": "user", "content": prompt + json_instruction},
                 ],
                 temperature=temp,
                 max_tokens=self.config.max_tokens,
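The hunk above is the heart of the fix: the JSON schema is now embedded in the user prompt itself, so the later fallbacks still see it even when the server ignores `response_format`. Roughly what that prompt suffix looks like for a toy schema; `AlphaSignal` is made up for illustration and the instruction wording is abbreviated relative to the committed string:

```python
# Sketch of the schema instruction built from schema.model_json_schema().
import json

from pydantic import BaseModel


class AlphaSignal(BaseModel):  # hypothetical schema, not part of the repo
    name: str
    expression: str
    confidence: float


schema_str = json.dumps(AlphaSignal.model_json_schema(), indent=2)
json_instruction = (
    "\n\nYou MUST output ONLY valid JSON matching this exact schema. "
    "No markdown fences, no explanation - ONLY the raw JSON object.\n"
    f"Schema:\n{schema_str}"
)

# The suffix gets appended to the user prompt in all three strategies:
prompt = "Propose one momentum signal." + json_instruction
print(prompt)
```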
@@ -109,32 +106,92 @@
                     },
                 },
             )
-        except Exception:
-            # Fallback: some providers don't support json_schema format
-            # Try with json_object format + schema instruction in prompt
-            schema_str = json.dumps(json_schema, indent=2)
-            augmented_prompt = (
-                f"{prompt}\n\n"
-                f"IMPORTANT: Output ONLY valid JSON matching this schema:\n"
-                f"```json\n{schema_str}\n```\n"
-                f"No other text. Just the JSON."
-            )
+            content = response.choices[0].message.content
+            if content and content.strip():
+                self._token_count += response.usage.total_tokens if response.usage else 0
+                return self._parse_json_response(content, schema)
+        except Exception as e:
+            if "json_schema" not in str(e).lower() and "format" not in str(e).lower() and "unsupported" not in str(e).lower():
+                # Not a format issue try json_object
+                pass
+
+        # Strategy 2: Try json_object format
+        try:
             response = await client.chat.completions.create(
                 model=model_name,
                 messages=[
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": augmented_prompt},
+                    {"role": "system", "content": system_prompt + "\nAlways respond in valid JSON."},
+                    {"role": "user", "content": prompt + json_instruction},
                 ],
                 temperature=temp,
                 max_tokens=self.config.max_tokens,
                 response_format={"type": "json_object"},
             )
+            content = response.choices[0].message.content
+            if content and content.strip():
+                self._token_count += response.usage.total_tokens if response.usage else 0
+                return self._parse_json_response(content, schema)
+        except Exception:
+            pass
 
+        # Strategy 3: Plain text with JSON extraction (works everywhere including HF)
+        response = await client.chat.completions.create(
+            model=model_name,
+            messages=[
+                {"role": "system", "content": system_prompt + "\nYou always respond with valid JSON only. No other text."},
+                {"role": "user", "content": prompt + json_instruction},
+            ],
+            temperature=temp,
+            max_tokens=self.config.max_tokens,
+        )
         content = response.choices[0].message.content
         self._token_count += response.usage.total_tokens if response.usage else 0
 
-        # Parse and validate
-        data = json.loads(content)
+        if not content or not content.strip():
+            raise ValueError(f"Empty response from model {model_name}")
+
+        return self._parse_json_response(content, schema)
+
+    def _parse_json_response(self, content: str, schema: type[T]) -> T:
+        """
+        Parse JSON from LLM response, handling common issues:
+        - Markdown code blocks (```json ... ```)
+        - Leading/trailing text
+        - Thinking tags (<think>...</think>)
+        """
+        text = content.strip()
+
+        # Remove thinking tags (Qwen/DeepSeek R1 style)
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+
+        # Remove markdown code blocks
+        if "```json" in text:
+            match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
+            if match:
+                text = match.group(1)
+        elif "```" in text:
+            match = re.search(r'```\s*(.*?)\s*```', text, re.DOTALL)
+            if match:
+                text = match.group(1)
+
+        # Try to find JSON object in the text
+        if not text.startswith('{'):
+            # Look for first { and last }
+            start = text.find('{')
+            end = text.rfind('}')
+            if start != -1 and end != -1 and end > start:
+                text = text[start:end + 1]
+
+        # Parse
+        try:
+            data = json.loads(text)
+        except json.JSONDecodeError as e:
+            raise ValueError(
+                f"Failed to parse JSON from model response.\n"
+                f"Error: {e}\n"
+                f"Response (first 500 chars): {content[:500]}"
+            )
+
         return schema.model_validate(data)
 
     async def generate_text(
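The new `_parse_json_response` helper added above is what makes the plain-text fallback usable: before `json.loads` it strips `<think>` reasoning blocks, unwraps a fenced code block, and slices from the first `{` to the last `}`. A standalone walk-through of those steps on a made-up raw reply (the fence string is assembled from backtick characters only so it can be quoted inside this note):

```python
# Walk-through of the clean-up steps on an invented model reply.
import json
import re

fence = "`" * 3  # literal ``` built this way only so it can appear in this example

raw = (
    "<think>The user wants a factor, let me think...</think>\n"
    "Sure! Here is the JSON you asked for:\n"
    f"{fence}json\n"
    '{"name": "rev5", "expression": "-returns_5d", "rationale": "short-term reversal"}\n'
    f"{fence}\n"
)

text = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()   # 1. drop reasoning tags
match = re.search(rf"{fence}json\s*(.*?)\s*{fence}", text, re.DOTALL)    # 2. unwrap the fenced block
if match:
    text = match.group(1)
start, end = text.find("{"), text.rfind("}")                             # 3. keep only the {...} span
if start != -1 and end > start:
    text = text[start:end + 1]

print(json.loads(text))
```

The committed helper additionally handles bare fences without a `json` tag, and raises a `ValueError` carrying the first 500 characters of the reply when the remaining text still is not valid JSON.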
@@ -146,7 +203,7 @@
         system_prompt: str = "You are a quantitative finance expert.",
         max_tokens: int = 2048,
     ) -> str:
-        """Generate free-text response (for memos/reports only, never for expressions)."""
+        """Generate free-text response (for memos/reports only)."""
         client, model_name = self._resolve_model(tier, model)
         temp = temperature or self.config.temperature_critique
 
@@ -162,7 +219,7 @@
 
         content = response.choices[0].message.content
         self._token_count += response.usage.total_tokens if response.usage else 0
-        return content
+        return content or ""
 
     async def health_check(self, tier: str = "mediumfish") -> bool:
         """Check if the model endpoint is reachable."""
@@ -170,10 +227,10 @@
             client, model_name = self._resolve_model(tier)
             response = await client.chat.completions.create(
                 model=model_name,
-                messages=[{"role": "user", "content": "Say 'ok'"}],
+                messages=[{"role": "user", "content": "Say ok"}],
                 max_tokens=5,
             )
-            return True
+            return bool(response.choices[0].message.content)
         except Exception:
             return False
 
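`health_check` previously returned True whenever the request did not raise; after this commit it also requires a non-empty completion, which matters for hosted routers that can answer successfully but with empty content. A standalone equivalent of the probe, using the `openai` client directly; the endpoint and model below are placeholders, not values from this repo:

```python
# Hedged sketch of the same "Say ok" probe, outside the LLMClient wrapper.
import asyncio

from openai import AsyncOpenAI


async def probe(base_url: str, model: str, api_key: str = "dummy") -> bool:
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    try:
        response = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Say ok"}],
            max_tokens=5,
        )
        # Mirrors the fix: an HTTP-level success with an empty completion counts as unhealthy.
        return bool(response.choices[0].message.content)
    except Exception:
        return False


if __name__ == "__main__":
    print(asyncio.run(probe("http://localhost:11434/v1", "qwen2.5:7b-instruct")))
```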