gaurv007 committed on
Commit 7be83bf · verified · 1 Parent(s): e8b127a

Upload alpha_factory/infra/model_manager.py with huggingface_hub

Files changed (1)
  1. alpha_factory/infra/model_manager.py +350 -0
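
Per the commit message, this file was uploaded with `huggingface_hub`. A minimal sketch of such an upload using the library's `HfApi.upload_file`; the exact call the author ran is an assumption, and the repo name is a placeholder since it is not shown on this page:

```python
# Sketch of an upload like this commit's, via huggingface_hub's HfApi.
# The precise invocation the author used is an assumption.
from huggingface_hub import HfApi

api = HfApi()  # picks up HF_TOKEN from the environment or the local HF cache
api.upload_file(
    path_or_fileobj="alpha_factory/infra/model_manager.py",
    path_in_repo="alpha_factory/infra/model_manager.py",
    repo_id="gaurv007/<repo-name>",  # placeholder: repo name not shown on this page
    commit_message="Upload alpha_factory/infra/model_manager.py with huggingface_hub",
)
```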
alpha_factory/infra/model_manager.py ADDED
@@ -0,0 +1,350 @@
+ """
+ Model Manager — Unified interface for Ollama (local) + HuggingFace Inference API (cloud).
+ Auto-detects available models from both sources.
+ User selects which to use via interactive menu or config.
+ """
+ import asyncio
+ import os
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Optional
+
+ import aiohttp
+
+
+ class ModelProvider(str, Enum):
+     OLLAMA = "ollama"
+     HUGGINGFACE = "huggingface"
+
+
+ @dataclass
+ class ModelInfo:
+     """Metadata about an available model."""
+     name: str
+     provider: ModelProvider
+     size_gb: Optional[float] = None
+     quantization: Optional[str] = None
+     context_length: Optional[int] = None
+     is_default: bool = False
+
+     def display_name(self) -> str:
+         size_str = f" ({self.size_gb:.1f}GB)" if self.size_gb else ""
+         quant_str = f" [{self.quantization}]" if self.quantization else ""
+         return f"[{self.provider.value}] {self.name}{size_str}{quant_str}"
+
+
+ # ─── Default model recommendations ─────────────────────────────────────────
+ DEFAULTS = {
+     "microfish": ModelInfo(
+         name="qwen2.5:1.5b", provider=ModelProvider.OLLAMA,
+         size_gb=1.0, context_length=32768, is_default=True,
+     ),
+     "tinyfish": ModelInfo(
+         name="qwen2.5:3b", provider=ModelProvider.OLLAMA,
+         size_gb=2.0, context_length=32768, is_default=True,
+     ),
+     "mediumfish": ModelInfo(
+         name="qwen2.5:7b", provider=ModelProvider.OLLAMA,
+         size_gb=4.7, context_length=32768, is_default=True,
+     ),
+     "bigfish": ModelInfo(
+         name="qwen2.5:14b", provider=ModelProvider.OLLAMA,
+         size_gb=9.0, context_length=32768, is_default=True,
+     ),
+ }
+
+ # HuggingFace models that work well for this pipeline
+ HF_RECOMMENDED = [
+     "Qwen/Qwen2.5-72B-Instruct",
+     "Qwen/Qwen2.5-32B-Instruct",
+     "Qwen/Qwen2.5-14B-Instruct",
+     "Qwen/Qwen2.5-7B-Instruct",
+     "Qwen/Qwen2.5-Coder-7B-Instruct",
+     "deepseek-ai/DeepSeek-V3",
+     "deepseek-ai/DeepSeek-R1",
+     "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+     "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+     "mistralai/Mistral-Small-24B-Instruct-2501",
+     "microsoft/phi-4",
+ ]
+
+
+ class ModelManager:
+     """
+     Detects and manages models from Ollama (local) and HuggingFace (cloud).
+     Provides a unified interface for the pipeline to request models.
+     """
+
+     def __init__(
+         self,
+         ollama_url: str = "http://localhost:11434",
+         hf_token: Optional[str] = None,
+     ):
+         self.ollama_url = ollama_url
+         self.hf_token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
+         self.hf_api_url = "https://router.huggingface.co/v1"
+
+         # Discovered models
+         self.ollama_models: list[ModelInfo] = []
+         self.hf_models: list[ModelInfo] = []
+
+         # Selected models for each tier
+         self.selected: dict[str, ModelInfo] = {}
+
+     async def discover_all(self):
+         """Discover all available models from both providers."""
+         await asyncio.gather(
+             self._discover_ollama(),
+             self._discover_hf(),
+         )
+
+     async def _discover_ollama(self):
+         """Detect locally installed Ollama models."""
+         self.ollama_models = []
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(
+                     f"{self.ollama_url}/api/tags",
+                     timeout=aiohttp.ClientTimeout(total=5),
+                 ) as resp:
+                     if resp.status == 200:
+                         data = await resp.json()
+                         for model in data.get("models", []):
+                             name = model.get("name", "")
+                             size_bytes = model.get("size", 0)
+                             size_gb = size_bytes / (1024**3) if size_bytes else None
+
+                             # Extract quantization from name
+                             quant = None
+                             for q in ["q4_0", "q4_k_m", "q5_k_m", "q8_0", "fp16"]:
+                                 if q in name.lower():
+                                     quant = q
+                                     break
+
+                             self.ollama_models.append(ModelInfo(
+                                 name=name,
+                                 provider=ModelProvider.OLLAMA,
+                                 size_gb=round(size_gb, 1) if size_gb else None,
+                                 quantization=quant,
+                             ))
+         except (aiohttp.ClientError, asyncio.TimeoutError):
+             pass  # Ollama not running — that's fine
+
+     async def _discover_hf(self):
+         """Check which HuggingFace models are available via Inference API."""
+         self.hf_models = []
+         if not self.hf_token:
+             # Still list recommended models (user can add token later)
+             for model_id in HF_RECOMMENDED:
+                 self.hf_models.append(ModelInfo(
+                     name=model_id,
+                     provider=ModelProvider.HUGGINGFACE,
+                 ))
+             return
+
+         # With token, check which models are actually accessible
+         try:
+             async with aiohttp.ClientSession() as session:
+                 headers = {"Authorization": f"Bearer {self.hf_token}"}
+                 for model_id in HF_RECOMMENDED:
+                     try:
+                         async with session.get(
+                             f"https://huggingface.co/api/models/{model_id}",
+                             headers=headers,
+                             timeout=aiohttp.ClientTimeout(total=5),
+                         ) as resp:
+                             if resp.status == 200:
+                                 data = await resp.json()
+                                 self.hf_models.append(ModelInfo(
+                                     name=model_id,
+                                     provider=ModelProvider.HUGGINGFACE,
+                                     context_length=data.get("config", {}).get("max_position_embeddings"),
+                                 ))
+                     except Exception:
+                         # Still list it — might work
+                         self.hf_models.append(ModelInfo(
+                             name=model_id,
+                             provider=ModelProvider.HUGGINGFACE,
+                         ))
+         except Exception:
+             # Session setup failed entirely: fall back to the full recommended list
+             for model_id in HF_RECOMMENDED:
+                 self.hf_models.append(ModelInfo(
+                     name=model_id,
+                     provider=ModelProvider.HUGGINGFACE,
+                 ))
+
+     def get_all_models(self) -> list[ModelInfo]:
+         """Get all discovered models (local + cloud)."""
+         return self.ollama_models + self.hf_models
+
+     def get_local_models(self) -> list[ModelInfo]:
+         """Get only locally installed models."""
+         return self.ollama_models
+
+     def get_cloud_models(self) -> list[ModelInfo]:
+         """Get HuggingFace cloud models."""
+         return self.hf_models
+
+     def select_model(self, tier: str, model: ModelInfo):
+         """Select a model for a specific tier (microfish/tinyfish/mediumfish/bigfish)."""
+         self.selected[tier] = model
+
+     def get_selected(self, tier: str) -> ModelInfo:
+         """Get the selected model for a tier, or return default."""
+         return self.selected.get(tier, DEFAULTS.get(tier, DEFAULTS["mediumfish"]))
+
+     def get_endpoint(self, tier: str) -> tuple[str, str, dict]:
+         """
+         Get the API endpoint info for the selected model.
+         Returns: (base_url, model_name, headers)
+         """
+         model = self.get_selected(tier)
+
+         if model.provider == ModelProvider.OLLAMA:
+             return (
+                 f"{self.ollama_url}/v1",
+                 model.name,
+                 {},
+             )
+         else:
+             # HuggingFace Inference API
+             headers = {}
+             if self.hf_token:
+                 headers["Authorization"] = f"Bearer {self.hf_token}"
+             return (
+                 self.hf_api_url,
+                 model.name,
+                 headers,
+             )
+
+     def auto_assign_defaults(self):
+         """
+         Automatically assign best available models to each tier.
+         Prefers local (Ollama) over cloud (HF) for speed + privacy.
+         """
+         local_names = {m.name.lower(): m for m in self.ollama_models}
+
+         for tier, default in DEFAULTS.items():
+             # Try to find the default model locally
+             if default.name.lower() in local_names:
+                 self.selected[tier] = local_names[default.name.lower()]
+             elif self.ollama_models:
+                 # Use the best available local model for this tier
+                 sorted_local = sorted(self.ollama_models, key=lambda m: m.size_gb or 0)
+                 if tier == "microfish":
+                     self.selected[tier] = sorted_local[0]  # smallest
+                 elif tier == "bigfish":
+                     self.selected[tier] = sorted_local[-1]  # largest
+                 else:
+                     mid = len(sorted_local) // 2
+                     self.selected[tier] = sorted_local[mid]  # middle
+             elif self.hf_models:
+                 # Fallback to HuggingFace cloud — pick size-appropriate model
+                 hf_tier_map = {
+                     "microfish": "Qwen/Qwen2.5-7B-Instruct",
+                     "tinyfish": "Qwen/Qwen2.5-7B-Instruct",
+                     "mediumfish": "Qwen/Qwen2.5-14B-Instruct",
+                     "bigfish": "Qwen/Qwen2.5-72B-Instruct",
+                 }
+                 target = hf_tier_map.get(tier, "Qwen/Qwen2.5-7B-Instruct")
+                 matched = [m for m in self.hf_models if m.name == target]
+                 self.selected[tier] = matched[0] if matched else self.hf_models[0]
+             else:
+                 # Use defaults (will fail at runtime if nothing available)
+                 self.selected[tier] = default
+
+     def print_status(self):
+         """Print current model configuration."""
+         from rich.console import Console
+         from rich.table import Table
+
+         console = Console()
+
+         # Discovery summary
+         console.print("\n[bold]🔍 Model Discovery[/]")
+         console.print(f" Ollama (local): {len(self.ollama_models)} models")
+         console.print(f" HuggingFace (cloud): {len(self.hf_models)} models")
+         if not self.hf_token:
+             console.print(" [yellow]⚠ No HF_TOKEN set — cloud models may have rate limits[/]")
+
+         # Available models table
+         if self.ollama_models:
+             table = Table(title="Local Models (Ollama)")
+             table.add_column("#", width=3)
+             table.add_column("Model", style="cyan")
+             table.add_column("Size", style="green")
+             table.add_column("Quant", style="yellow")
+             for i, m in enumerate(self.ollama_models, 1):
+                 table.add_row(
+                     str(i), m.name,
+                     f"{m.size_gb:.1f}GB" if m.size_gb else "?",
+                     m.quantization or "-",
+                 )
+             console.print(table)
+
+         # Selected models
+         table2 = Table(title="Selected Models (Pipeline)")
+         table2.add_column("Tier", style="bold")
+         table2.add_column("Model", style="cyan")
+         table2.add_column("Provider", style="magenta")
+         table2.add_column("Use", style="dim")
+
+         tier_uses = {
+             "microfish": "Hypothesis generation (bulk)",
+             "tinyfish": "Expression compilation",
+             "mediumfish": "Crowd scout + surgeon",
+             "bigfish": "Gatekeeper (final memo)",
+         }
+
+         for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
+             model = self.get_selected(tier)
+             table2.add_row(
+                 tier, model.name, model.provider.value,
+                 tier_uses.get(tier, ""),
+             )
+         console.print(table2)
+
+
+ def interactive_model_select(manager: ModelManager) -> dict[str, ModelInfo]:
+     """
+     Interactive CLI menu for model selection.
+     Shows all available models and lets user pick for each tier.
+     """
+     from rich.console import Console
+     from rich.prompt import Prompt
+
+     console = Console()
+     all_models = manager.get_all_models()
+
+     if not all_models:
+         console.print("[red]No models found! Install Ollama models or set HF_TOKEN.[/]")
+         console.print(" ollama pull qwen2.5:1.5b")
+         console.print(" ollama pull qwen2.5:7b")
+         console.print(" export HF_TOKEN=hf_your_token")
+         return {}
+
+     console.print("\n[bold]📋 Available Models:[/]")
+     for i, m in enumerate(all_models, 1):
+         console.print(f" {i:2d}. {m.display_name()}")
+
+     selections = {}
+     tier_desc = {
+         "microfish": "bulk generation",
+         "tinyfish": "compilation",
+         "mediumfish": "critique",
+         "bigfish": "final gate",
+     }
+     for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
+         default = DEFAULTS[tier]
+         console.print(f"\n[bold]Select model for [{tier}][/] (default: {default.name}):")
+         console.print(f" Use: {tier_desc[tier]}")
+
+         choice = Prompt.ask(
+             f" Enter number (1-{len(all_models)}) or press Enter for default",
+             default="",
+         )
+
+         if choice and choice.isdigit():
+             idx = int(choice) - 1
+             if 0 <= idx < len(all_models):
+                 selections[tier] = all_models[idx]
+                 console.print(f" → Selected: {all_models[idx].display_name()}")
+             else:
+                 selections[tier] = default
+                 console.print(f" → Using default: {default.name}")
+         else:
+             selections[tier] = default
+             console.print(f" → Using default: {default.name}")
+
+     return selections
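
For reference, a short driver showing how the module above fits together; this sketch is not part of the commit, and it assumes a local Ollama daemon on port 11434 and an optional `HF_TOKEN` in the environment:

```python
# Minimal usage sketch for ModelManager (grounded in the API defined above).
import asyncio

from alpha_factory.infra.model_manager import ModelManager

async def main():
    manager = ModelManager()        # defaults to http://localhost:11434, HF_TOKEN from env
    await manager.discover_all()    # probes Ollama tags + HF model pages concurrently
    manager.auto_assign_defaults()  # local models first, cloud fallback
    manager.print_status()

    # get_endpoint returns OpenAI-compatible wiring for either provider,
    # e.g. ("http://localhost:11434/v1", "qwen2.5:7b", {}) for Ollama.
    base_url, model_name, headers = manager.get_endpoint("mediumfish")
    print(base_url, model_name, headers)

asyncio.run(main())
```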