gaurv007 committed
Commit 10c2948 · verified · 1 Parent(s): 7be83bf

Upload alpha_factory/infra/llm_client.py with huggingface_hub
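For reference, this is the kind of call that produces such a commit: a minimal sketch using `huggingface_hub.HfApi.upload_file`, assuming a hypothetical repo id and a token available in the environment (the actual repo and token handling are not shown on this page):

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up HF_TOKEN from the environment or the local token cache
api.upload_file(
    path_or_fileobj="alpha_factory/infra/llm_client.py",  # local file to push
    path_in_repo="alpha_factory/infra/llm_client.py",     # destination path in the repo
    repo_id="gaurv007/alpha-factory",                     # hypothetical repo id
    commit_message="Upload alpha_factory/infra/llm_client.py with huggingface_hub",
)
```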

Files changed (1)
  1. alpha_factory/infra/llm_client.py +114 -31
alpha_factory/infra/llm_client.py CHANGED
@@ -1,6 +1,7 @@
 """
-LLM Client — unified interface to vLLM / Ollama with guided JSON generation.
-All outputs are schema-constrained. No free-text alpha generation.
+LLM Client — unified interface supporting both Ollama (local) and HuggingFace (cloud).
+Auto-switches between providers based on ModelManager selection.
+All outputs are schema-constrained via guided JSON generation.
 """
 import asyncio
 import json
@@ -14,52 +15,120 @@ T = TypeVar("T", bound=BaseModel)
 
 class LLMClient:
     """
-    Async LLM client with structured JSON output.
-    Connects to vLLM or Ollama (both expose OpenAI-compatible API).
+    Async LLM client supporting:
+    - Ollama (local, http://localhost:11434/v1)
+    - HuggingFace Inference API (cloud, https://router.huggingface.co/v1)
+    - vLLM (local/remote, any OpenAI-compatible endpoint)
+
+    All outputs are JSON-schema-constrained for reliability.
     """
 
-    def __init__(self, config: LLMConfig):
+    def __init__(self, config: LLMConfig, model_manager=None):
         self.config = config
-        self.client = AsyncOpenAI(
-            base_url=config.base_url,
-            api_key=config.api_key,
-        )
+        self.model_manager = model_manager
+        self._clients: dict[str, AsyncOpenAI] = {}
         self._token_count = 0
 
+    def _get_client(self, base_url: str, api_key: str = "dummy", **headers) -> AsyncOpenAI:
+        """Get or create an AsyncOpenAI client for the given endpoint."""
+        key = f"{base_url}|{api_key}"
+        if key not in self._clients:
+            self._clients[key] = AsyncOpenAI(
+                base_url=base_url,
+                api_key=api_key,
+                default_headers=headers if headers else None,
+            )
+        return self._clients[key]
+
+    def _resolve_model(self, tier: str = "mediumfish", model_override: str | None = None) -> tuple[AsyncOpenAI, str]:
+        """
+        Resolve which client + model to use for a given tier.
+        Priority: model_override > ModelManager selection > config default
+        """
+        if model_override:
+            # Direct model name — use default endpoint
+            client = self._get_client(self.config.base_url, self.config.api_key)
+            return client, model_override
+
+        if self.model_manager:
+            base_url, model_name, headers = self.model_manager.get_endpoint(tier)
+            api_key = headers.get("Authorization", "").replace("Bearer ", "") or "dummy"
+            client = self._get_client(base_url, api_key)
+            return client, model_name
+
+        # Fallback: use config defaults
+        tier_to_model = {
+            "microfish": self.config.microfish_model,
+            "tinyfish": self.config.tinyfish_model,
+            "mediumfish": self.config.mediumfish_model,
+            "bigfish": self.config.bigfish_model,
+        }
+        model = tier_to_model.get(tier, self.config.mediumfish_model)
+        client = self._get_client(self.config.base_url, self.config.api_key)
+        return client, model
+
     async def generate_json(
         self,
         prompt: str,
         schema: type[T],
+        tier: str = "mediumfish",
         model: str | None = None,
         temperature: float | None = None,
         system_prompt: str = "You are a quantitative finance expert.",
     ) -> T:
         """
         Generate a structured JSON response conforming to the given Pydantic schema.
-        Uses guided decoding via response_format (vLLM supports this natively).
+        Uses guided decoding via response_format.
+
+        Args:
+            prompt: The user prompt
+            schema: Pydantic model class for output validation
+            tier: Model tier (microfish/tinyfish/mediumfish/bigfish)
+            model: Override model name (optional)
+            temperature: Override temperature (optional)
+            system_prompt: System message
         """
-        model = model or self.config.mediumfish_model
+        client, model_name = self._resolve_model(tier, model)
         temp = temperature or self.config.temperature_generation
-
-        # Build JSON schema for guided generation
         json_schema = schema.model_json_schema()
 
-        response = await self.client.chat.completions.create(
-            model=model,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": prompt},
-            ],
-            temperature=temp,
-            max_tokens=self.config.max_tokens,
-            response_format={
-                "type": "json_schema",
-                "json_schema": {
-                    "name": schema.__name__,
-                    "schema": json_schema,
-                },
-            },
-        )
+        try:
+            response = await client.chat.completions.create(
+                model=model_name,
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": prompt},
+                ],
+                temperature=temp,
+                max_tokens=self.config.max_tokens,
+                response_format={
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": schema.__name__,
+                        "schema": json_schema,
+                    },
+                },
+            )
+        except Exception:
+            # Fallback: some providers don't support json_schema format
+            # Try with json_object format + schema instruction in prompt
+            schema_str = json.dumps(json_schema, indent=2)
+            augmented_prompt = (
+                f"{prompt}\n\n"
+                f"IMPORTANT: Output ONLY valid JSON matching this schema:\n"
+                f"```json\n{schema_str}\n```\n"
+                f"No other text. Just the JSON."
+            )
+            response = await client.chat.completions.create(
+                model=model_name,
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": augmented_prompt},
+                ],
+                temperature=temp,
+                max_tokens=self.config.max_tokens,
+                response_format={"type": "json_object"},
+            )
 
         content = response.choices[0].message.content
         self._token_count += response.usage.total_tokens if response.usage else 0
@@ -71,17 +140,18 @@ class LLMClient:
     async def generate_text(
         self,
         prompt: str,
+        tier: str = "mediumfish",
         model: str | None = None,
        temperature: float | None = None,
         system_prompt: str = "You are a quantitative finance expert.",
         max_tokens: int = 2048,
     ) -> str:
         """Generate free-text response (for memos/reports only, never for expressions)."""
-        model = model or self.config.mediumfish_model
+        client, model_name = self._resolve_model(tier, model)
         temp = temperature or self.config.temperature_critique
 
-        response = await self.client.chat.completions.create(
-            model=model,
+        response = await client.chat.completions.create(
+            model=model_name,
             messages=[
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": prompt},
@@ -94,6 +164,19 @@ class LLMClient:
         self._token_count += response.usage.total_tokens if response.usage else 0
         return content
 
+    async def health_check(self, tier: str = "mediumfish") -> bool:
+        """Check if the model endpoint is reachable."""
+        try:
+            client, model_name = self._resolve_model(tier)
+            response = await client.chat.completions.create(
+                model=model_name,
+                messages=[{"role": "user", "content": "Say 'ok'"}],
+                max_tokens=5,
+            )
+            return True
+        except Exception:
+            return False
+
     @property
     def tokens_used(self) -> int:
         return self._token_count
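A note on the guided-decoding payload above: `schema.model_json_schema()` converts the Pydantic class into plain JSON Schema, which is what gets embedded under `response_format`. A minimal sketch with a made-up schema (Pydantic v2; output abbreviated and key order may vary):

```python
from pydantic import BaseModel

class Signal(BaseModel):  # hypothetical schema, for illustration only
    name: str
    weight: float

print(Signal.model_json_schema())
# -> {'properties': {'name': {'title': 'Name', 'type': 'string'},
#     'weight': {'title': 'Weight', 'type': 'number'}},
#     'required': ['name', 'weight'], 'title': 'Signal', 'type': 'object'}
```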
 
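`_resolve_model` only shows the call site for `ModelManager`; the manager itself is not part of this commit. A minimal stub satisfying the contract the code implies (`get_endpoint(tier)` returning `(base_url, model_name, headers)`) could look like the following; the tier table and model names are placeholders, not the project's actual mapping:

```python
import os


class ModelManager:
    """Illustrative stub of the interface LLMClient._resolve_model expects."""

    # Hypothetical tier table; the endpoints match the providers named in the
    # class docstring, but the model names are made up.
    _TIERS: dict[str, tuple[str, str, dict[str, str]]] = {
        "tinyfish": ("http://localhost:11434/v1", "llama3.2:3b", {}),
        "mediumfish": ("http://localhost:11434/v1", "qwen2.5:14b", {}),
        "bigfish": (
            "https://router.huggingface.co/v1",
            "meta-llama/Llama-3.3-70B-Instruct",
            {"Authorization": f"Bearer {os.environ.get('HF_TOKEN', '')}"},
        ),
    }

    def get_endpoint(self, tier: str) -> tuple[str, str, dict[str, str]]:
        # Fall back to the medium tier for unknown names, mirroring the
        # tier_to_model fallback in LLMClient._resolve_model.
        return self._TIERS.get(tier, self._TIERS["mediumfish"])
```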
 
 
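One gap worth flagging: both visible `generate_json` hunks cut off right after `content` is read, so the parsing and validation step falls outside this diff. Given the `-> T` return annotation and the schema-constrained design, something along these lines presumably follows; this helper is a sketch, not the file's actual code:

```python
from typing import TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)


def parse_schema_response(content: str, schema: type[T]) -> T:
    """Validate raw model output against the caller's Pydantic schema."""
    try:
        return schema.model_validate_json(content)
    except ValidationError:
        # The json_object fallback path can produce fenced output; strip a
        # markdown code fence once and retry before giving up.
        stripped = content.strip().removeprefix("```json").removesuffix("```").strip()
        return schema.model_validate_json(stripped)
```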