gaurv007 committed on
Commit
3bf54ea
·
verified ·
1 Parent(s): dfbfd81

Upload alpha_factory/infra/model_manager.py

Browse files
Files changed (1) hide show
  1. alpha_factory/infra/model_manager.py +132 -47
alpha_factory/infra/model_manager.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
- Model Manager — Unified interface for Ollama (local) + HuggingFace Inference API (cloud).
3
- Auto-detects available models from both sources.
4
- User selects which to use via interactive menu or config.
5
  """
6
  import asyncio
7
  import aiohttp
@@ -28,11 +28,34 @@ class ModelInfo:
28
  quantization: Optional[str] = None
29
  context_length: Optional[int] = None
30
  is_default: bool = False
 
31
 
32
  def display_name(self) -> str:
33
  size_str = f" ({self.size_gb:.1f}GB)" if self.size_gb else ""
34
  quant_str = f" [{self.quantization}]" if self.quantization else ""
35
- return f"[{self.provider.value}] {self.name}{size_str}{quant_str}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  # ─── Default model recommendations ─────────────────────────────────────────
@@ -71,6 +94,20 @@ HF_RECOMMENDED = [
71
  ]
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def _add_hf_fallbacks(target_list: list[ModelInfo]):
75
  """Add all HF recommended models as fallbacks."""
76
  for model_id in HF_RECOMMENDED:
@@ -82,7 +119,9 @@ def _add_hf_fallbacks(target_list: list[ModelInfo]):
82
 
83
  class ModelManager:
84
  """
85
- Detects and manages models from Ollama (local) and HuggingFace (cloud).
 
 
86
  Provides unified interface for the pipeline to request models.
87
  """
88
 
@@ -110,8 +149,11 @@ class ModelManager:
110
  )
111
 
112
  async def _discover_ollama(self):
113
- """Detect locally installed Ollama models."""
114
  self.ollama_models = []
 
 
 
115
  try:
116
  async with aiohttp.ClientSession() as session:
117
  async with session.get(
@@ -137,8 +179,10 @@ class ModelManager:
137
  provider=ModelProvider.OLLAMA,
138
  size_gb=round(size_gb, 1) if size_gb else None,
139
  quantization=quant,
 
140
  ))
141
- logger.info(f"Discovered {len(self.ollama_models)} Ollama models")
 
142
  else:
143
  logger.warning(f"Ollama returned status {resp.status}")
144
  except asyncio.TimeoutError:
@@ -146,6 +190,20 @@ class ModelManager:
146
  except aiohttp.ClientError as e:
147
  logger.warning(f"Ollama not reachable: {e}")
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  async def _discover_hf(self):
150
  """Check which HuggingFace models are available via Inference API."""
151
  self.hf_models = []
@@ -154,7 +212,6 @@ class ModelManager:
154
  _add_hf_fallbacks(self.hf_models)
155
  return
156
 
157
- # With token, check which models are actually accessible
158
  headers = {"Authorization": f"Bearer {self.hf_token}"}
159
  async with aiohttp.ClientSession() as session:
160
  for model_id in HF_RECOMMENDED:
@@ -194,13 +251,17 @@ class ModelManager:
194
  ))
195
 
196
  def get_all_models(self) -> list[ModelInfo]:
197
- """Get all discovered models (local + cloud)."""
198
  return self.ollama_models + self.hf_models
199
 
200
  def get_local_models(self) -> list[ModelInfo]:
201
- """Get only locally installed models."""
202
  return self.ollama_models
203
 
 
 
 
 
204
  def get_cloud_models(self) -> list[ModelInfo]:
205
  """Get HuggingFace cloud models."""
206
  return self.hf_models
@@ -227,7 +288,6 @@ class ModelManager:
227
  {},
228
  )
229
  else:
230
- # HuggingFace Inference API
231
  headers = {}
232
  if self.hf_token:
233
  headers["Authorization"] = f"Bearer {self.hf_token}"
@@ -240,26 +300,33 @@ class ModelManager:
240
  def auto_assign_defaults(self):
241
  """
242
  Automatically assign best available models to each tier.
243
- Prefers local (Ollama) over cloud (HF) for speed + privacy.
244
  """
245
- local_names = {m.name.lower(): m for m in self.ollama_models}
 
246
 
247
  for tier, default in DEFAULTS.items():
248
- # Try to find the default model locally
249
- if default.name.lower() in local_names:
250
- self.selected[tier] = local_names[default.name.lower()]
251
- elif self.ollama_models:
252
- # Use the best available local model for this tier
253
- sorted_local = sorted(self.ollama_models, key=lambda m: m.size_gb or 0)
254
- if tier == "microfish" and sorted_local:
255
- self.selected[tier] = sorted_local[0] # smallest
256
- elif tier == "bigfish" and sorted_local:
257
- self.selected[tier] = sorted_local[-1] # largest
258
- elif sorted_local:
259
- mid = len(sorted_local) // 2
260
- self.selected[tier] = sorted_local[mid] # middle
 
 
 
 
 
 
 
261
  elif self.hf_models:
262
- # Fallback to HuggingFace cloud — pick size-appropriate model
263
  hf_tier_map = {
264
  "microfish": "Qwen/Qwen2.5-7B-Instruct",
265
  "tinyfish": "Qwen/Qwen2.5-7B-Instruct",
@@ -270,7 +337,6 @@ class ModelManager:
270
  matched = [m for m in self.hf_models if m.name == target]
271
  self.selected[tier] = matched[0] if matched else self.hf_models[0]
272
  else:
273
- # Use defaults (will fail at runtime if nothing available)
274
  self.selected[tier] = default
275
 
276
  def print_status(self):
@@ -284,22 +350,24 @@ class ModelManager:
284
  console = None
285
  has_rich = False
286
 
287
- # Discovery summary
 
 
288
  if has_rich:
289
  console.print(f"\n[bold]🔍 Model Discovery[/]")
290
- console.print(f" Ollama (local): {len(self.ollama_models)} models")
 
291
  console.print(f" HuggingFace (cloud): {len(self.hf_models)} models")
292
  if not self.hf_token:
293
  console.print(f" [yellow]⚠ No HF_TOKEN set — cloud models may have rate limits[/]")
294
 
295
- # Available models table
296
- if self.ollama_models:
297
- table = Table(title="Local Models (Ollama)")
298
  table.add_column("#", width=3)
299
  table.add_column("Model", style="cyan")
300
  table.add_column("Size", style="green")
301
  table.add_column("Quant", style="yellow")
302
- for i, m in enumerate(self.ollama_models, 1):
303
  table.add_row(
304
  str(i), m.name,
305
  f"{m.size_gb:.1f}GB" if m.size_gb else "?",
@@ -307,11 +375,24 @@ class ModelManager:
307
  )
308
  console.print(table)
309
 
310
- # Selected models
 
 
 
 
 
 
 
 
 
 
 
 
311
  table2 = Table(title="Selected Models (Pipeline)")
312
  table2.add_column("Tier", style="bold")
313
  table2.add_column("Model", style="cyan")
314
  table2.add_column("Provider", style="magenta")
 
315
  table2.add_column("Use", style="dim")
316
 
317
  tier_uses = {
@@ -323,18 +404,17 @@ class ModelManager:
323
 
324
  for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
325
  model = self.get_selected(tier)
 
326
  table2.add_row(
327
- tier, model.name, model.provider.value,
328
  tier_uses.get(tier, ""),
329
  )
330
  console.print(table2)
331
  else:
332
- # Plain text fallback
333
  print(f"\nModel Discovery")
334
- print(f" Ollama (local): {len(self.ollama_models)} models")
 
335
  print(f" HuggingFace (cloud): {len(self.hf_models)} models")
336
- if not self.hf_token:
337
- print(f" ! No HF_TOKEN set — cloud models may have rate limits")
338
  for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
339
  model = self.get_selected(tier)
340
  print(f" {tier}: {model.name} ({model.provider.value})")
@@ -358,7 +438,9 @@ def interactive_model_select(manager: ModelManager) -> dict[str, ModelInfo]:
358
  all_models = manager.get_all_models()
359
 
360
  if not all_models:
361
- msg = "No models found! Install Ollama models or set HF_TOKEN.\n ollama pull qwen2.5:1.5b\n ollama pull qwen2.5:7b\n export HF_TOKEN=hf_your_token"
 
 
362
  if has_rich:
363
  console.print(f"[red]{msg}[/]")
364
  else:
@@ -372,23 +454,26 @@ def interactive_model_select(manager: ModelManager) -> dict[str, ModelInfo]:
372
  else:
373
  print("\nAvailable Models:")
374
  for i, m in enumerate(all_models, 1):
375
- print(f" {i:2d}. [{m.provider.value}] {m.name}")
376
 
377
  selections = {}
378
  for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
379
  default = DEFAULTS[tier]
380
- tier_desc = {"microfish": "bulk generation", "tinyfish": "compilation", "mediumfish": "critique", "bigfish": "final gate"}
 
 
 
 
 
381
 
382
  if has_rich:
383
- console.print(f"\n[bold]Select model for [{tier}][/] (default: {default.name}):")
384
- console.print(f" Use: {tier_desc[tier]}")
385
  choice = Prompt.ask(
386
  f" Enter number (1-{len(all_models)}) or press Enter for default",
387
  default="",
388
  )
389
  else:
390
- print(f"\nSelect model for [{tier}] (default: {default.name}):")
391
- print(f" Use: {tier_desc[tier]}")
392
  choice = input(f" Enter number (1-{len(all_models)}) or press Enter for default: ")
393
 
394
  if choice and choice.isdigit():
 
1
  """
2
+ Model Manager — Unified interface for Ollama (local + pullable) + HuggingFace Inference API (cloud).
3
+ Auto-detects installed Ollama models AND shows recommended models available to pull.
4
+ User selects which to use via interactive menu, CLI flags, or Gradio dropdowns.
5
  """
6
  import asyncio
7
  import aiohttp
 
28
  quantization: Optional[str] = None
29
  context_length: Optional[int] = None
30
  is_default: bool = False
31
+ is_installed: bool = True # False = recommended but not yet pulled (Ollama only)
32
 
33
  def display_name(self) -> str:
34
  size_str = f" ({self.size_gb:.1f}GB)" if self.size_gb else ""
35
  quant_str = f" [{self.quantization}]" if self.quantization else ""
36
+ pullable = " [PULLABLE — ollama pull " + self.name + "]" if not self.is_installed else ""
37
+ return f"[{self.provider.value}] {self.name}{size_str}{quant_str}{pullable}"
38
+
39
+
40
+ # ─── Ollama models known to work well for this pipeline ────────────────────
41
+ # Includes a range of sizes so every tier has good options.
42
+ OLLAMA_RECOMMENDED = [
43
+ # Qwen 2.5 family (excellent for structured JSON / codegen)
44
+ "qwen2.5:0.5b", "qwen2.5:1.5b", "qwen2.5:3b", "qwen2.5:7b",
45
+ "qwen2.5:14b", "qwen2.5:32b", "qwen2.5:72b",
46
+ "qwen2.5-coder:1.5b", "qwen2.5-coder:7b", "qwen2.5-coder:14b",
47
+ # DeepSeek R1 (reasoning-heavy, good for gatekeeper)
48
+ "deepseek-r1:1.5b", "deepseek-r1:7b", "deepseek-r1:14b",
49
+ "deepseek-r1:32b", "deepseek-r1:70b",
50
+ # Llama family
51
+ "llama3.2:1b", "llama3.2:3b", "llama3.3:70b",
52
+ # Mistral family
53
+ "mistral:7b", "mixtral:8x7b", "mixtral:8x22b",
54
+ # Microsoft Phi
55
+ "phi4:14b", "phi3:3.8b", "phi3:medium",
56
+ # Google Gemma
57
+ "gemma2:2b", "gemma2:9b", "gemma2:27b",
58
+ ]
59
 
60
 
61
  # ─── Default model recommendations ─────────────────────────────────────────
 
94
  ]
95
 
96
 
97
+ # Approximate size mapping for Ollama models (to help tier selection)
98
+ OLLAMA_SIZE_GUESS: dict[str, float] = {
99
+ "qwen2.5:0.5b": 0.5, "qwen2.5:1.5b": 1.0, "qwen2.5:3b": 2.0,
100
+ "qwen2.5:7b": 4.7, "qwen2.5:14b": 9.0, "qwen2.5:32b": 20.0, "qwen2.5:72b": 47.0,
101
+ "qwen2.5-coder:1.5b": 1.0, "qwen2.5-coder:7b": 4.7, "qwen2.5-coder:14b": 9.0,
102
+ "deepseek-r1:1.5b": 1.0, "deepseek-r1:7b": 4.7, "deepseek-r1:14b": 9.0,
103
+ "deepseek-r1:32b": 20.0, "deepseek-r1:70b": 43.0,
104
+ "llama3.2:1b": 0.7, "llama3.2:3b": 2.0, "llama3.3:70b": 43.0,
105
+ "mistral:7b": 4.7, "mixtral:8x7b": 26.0, "mixtral:8x22b": 80.0,
106
+ "phi4:14b": 9.0, "phi3:3.8b": 2.5, "phi3:medium": 4.0,
107
+ "gemma2:2b": 1.6, "gemma2:9b": 5.5, "gemma2:27b": 18.0,
108
+ }
109
+
110
+
111
  def _add_hf_fallbacks(target_list: list[ModelInfo]):
112
  """Add all HF recommended models as fallbacks."""
113
  for model_id in HF_RECOMMENDED:
 
119
 
120
  class ModelManager:
121
  """
122
+ Detects and manages models from:
123
+ - Ollama (local, installed + recommended-to-pull)
124
+ - HuggingFace Inference API (cloud)
125
  Provides unified interface for the pipeline to request models.
126
  """
127
 
 
149
  )
150
 
151
  async def _discover_ollama(self):
152
+ """Detect locally installed Ollama models AND show recommended models to pull."""
153
  self.ollama_models = []
154
+ installed_names: set[str] = set()
155
+
156
+ # 1. Query Ollama for already-pulled models
157
  try:
158
  async with aiohttp.ClientSession() as session:
159
  async with session.get(
 
179
  provider=ModelProvider.OLLAMA,
180
  size_gb=round(size_gb, 1) if size_gb else None,
181
  quantization=quant,
182
+ is_installed=True,
183
  ))
184
+ installed_names.add(name)
185
+ logger.info(f"Discovered {len(self.ollama_models)} installed Ollama models")
186
  else:
187
  logger.warning(f"Ollama returned status {resp.status}")
188
  except asyncio.TimeoutError:
 
190
  except aiohttp.ClientError as e:
191
  logger.warning(f"Ollama not reachable: {e}")
192
 
193
+ # 2. Add recommended models that are NOT installed (pullable)
194
+ for tag in OLLAMA_RECOMMENDED:
195
+ if tag not in installed_names:
196
+ self.ollama_models.append(ModelInfo(
197
+ name=tag,
198
+ provider=ModelProvider.OLLAMA,
199
+ size_gb=OLLAMA_SIZE_GUESS.get(tag),
200
+ is_installed=False,
201
+ ))
202
+
203
+ logger.info(f"Total Ollama choices: {len(self.ollama_models)} "
204
+ f"({len(installed_names)} installed + "
205
+ f"{len(self.ollama_models) - len(installed_names)} pullable)")
206
+
207
  async def _discover_hf(self):
208
  """Check which HuggingFace models are available via Inference API."""
209
  self.hf_models = []
 
212
  _add_hf_fallbacks(self.hf_models)
213
  return
214
 
 
215
  headers = {"Authorization": f"Bearer {self.hf_token}"}
216
  async with aiohttp.ClientSession() as session:
217
  for model_id in HF_RECOMMENDED:
 
251
  ))
252
 
253
  def get_all_models(self) -> list[ModelInfo]:
254
+ """Get all discovered models (local installed + local pullable + cloud)."""
255
  return self.ollama_models + self.hf_models
256
 
257
  def get_local_models(self) -> list[ModelInfo]:
258
+ """Get Ollama models (installed + pullable)."""
259
  return self.ollama_models
260
 
261
+ def get_installed_models(self) -> list[ModelInfo]:
262
+ """Get only installed Ollama models."""
263
+ return [m for m in self.ollama_models if m.is_installed]
264
+
265
  def get_cloud_models(self) -> list[ModelInfo]:
266
  """Get HuggingFace cloud models."""
267
  return self.hf_models
 
288
  {},
289
  )
290
  else:
 
291
  headers = {}
292
  if self.hf_token:
293
  headers["Authorization"] = f"Bearer {self.hf_token}"
 
300
  def auto_assign_defaults(self):
301
  """
302
  Automatically assign best available models to each tier.
303
+ Prefers local installed (Ollama) over pullable over cloud (HF).
304
  """
305
+ installed_names = {m.name.lower(): m for m in self.ollama_models if m.is_installed}
306
+ all_ollama_names = {m.name.lower(): m for m in self.ollama_models}
307
 
308
  for tier, default in DEFAULTS.items():
309
+ # 1. Try exact match among installed
310
+ if default.name.lower() in installed_names:
311
+ self.selected[tier] = installed_names[default.name.lower()]
312
+ # 2. Any installed model, size-appropriate
313
+ elif installed_names:
314
+ sorted_installed = sorted(
315
+ installed_names.values(),
316
+ key=lambda m: m.size_gb or 0
317
+ )
318
+ if tier == "microfish" and sorted_installed:
319
+ self.selected[tier] = sorted_installed[0]
320
+ elif tier == "bigfish" and sorted_installed:
321
+ self.selected[tier] = sorted_installed[-1]
322
+ elif sorted_installed:
323
+ mid = len(sorted_installed) // 2
324
+ self.selected[tier] = sorted_installed[mid]
325
+ # 3. Fallback to pullable Ollama (same defaults)
326
+ elif default.name.lower() in all_ollama_names:
327
+ self.selected[tier] = all_ollama_names[default.name.lower()]
328
+ # 4. Fallback to HF cloud
329
  elif self.hf_models:
 
330
  hf_tier_map = {
331
  "microfish": "Qwen/Qwen2.5-7B-Instruct",
332
  "tinyfish": "Qwen/Qwen2.5-7B-Instruct",
 
337
  matched = [m for m in self.hf_models if m.name == target]
338
  self.selected[tier] = matched[0] if matched else self.hf_models[0]
339
  else:
 
340
  self.selected[tier] = default
341
 
342
  def print_status(self):
 
350
  console = None
351
  has_rich = False
352
 
353
+ installed = [m for m in self.ollama_models if m.is_installed]
354
+ pullable = [m for m in self.ollama_models if not m.is_installed]
355
+
356
  if has_rich:
357
  console.print(f"\n[bold]🔍 Model Discovery[/]")
358
+ console.print(f" Ollama (installed): {len(installed)} models")
359
+ console.print(f" Ollama (pullable): {len(pullable)} models")
360
  console.print(f" HuggingFace (cloud): {len(self.hf_models)} models")
361
  if not self.hf_token:
362
  console.print(f" [yellow]⚠ No HF_TOKEN set — cloud models may have rate limits[/]")
363
 
364
+ if installed:
365
+ table = Table(title="Installed Ollama Models")
 
366
  table.add_column("#", width=3)
367
  table.add_column("Model", style="cyan")
368
  table.add_column("Size", style="green")
369
  table.add_column("Quant", style="yellow")
370
+ for i, m in enumerate(installed, 1):
371
  table.add_row(
372
  str(i), m.name,
373
  f"{m.size_gb:.1f}GB" if m.size_gb else "?",
 
375
  )
376
  console.print(table)
377
 
378
+ if pullable:
379
+ table = Table(title="Available to Pull (Ollama)")
380
+ table.add_column("Tag", style="cyan")
381
+ table.add_column("Est. Size", style="dim")
382
+ for m in pullable[:15]: # Limit to avoid wall of text
383
+ table.add_row(
384
+ m.name,
385
+ f"~{m.size_gb:.1f}GB" if m.size_gb else "?",
386
+ )
387
+ if len(pullable) > 15:
388
+ table.add_row(f"... and {len(pullable) - 15} more", "")
389
+ console.print(table)
390
+
391
  table2 = Table(title="Selected Models (Pipeline)")
392
  table2.add_column("Tier", style="bold")
393
  table2.add_column("Model", style="cyan")
394
  table2.add_column("Provider", style="magenta")
395
+ table2.add_column("Status", style="dim")
396
  table2.add_column("Use", style="dim")
397
 
398
  tier_uses = {
 
404
 
405
  for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
406
  model = self.get_selected(tier)
407
+ status = "installed" if model.is_installed else ("pullable" if model.provider == ModelProvider.OLLAMA else "cloud")
408
  table2.add_row(
409
+ tier, model.name, model.provider.value, status,
410
  tier_uses.get(tier, ""),
411
  )
412
  console.print(table2)
413
  else:
 
414
  print(f"\nModel Discovery")
415
+ print(f" Ollama (installed): {len(installed)} models")
416
+ print(f" Ollama (pullable): {len(pullable)} models")
417
  print(f" HuggingFace (cloud): {len(self.hf_models)} models")
 
 
418
  for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
419
  model = self.get_selected(tier)
420
  print(f" {tier}: {model.name} ({model.provider.value})")
 
438
  all_models = manager.get_all_models()
439
 
440
  if not all_models:
441
+ msg = ("No models found!\n"
442
+ " Install Ollama models: ollama pull qwen2.5:1.5b\n"
443
+ " Or set HF_TOKEN for cloud models: export HF_TOKEN=hf_your_token")
444
  if has_rich:
445
  console.print(f"[red]{msg}[/]")
446
  else:
 
454
  else:
455
  print("\nAvailable Models:")
456
  for i, m in enumerate(all_models, 1):
457
+ print(f" {i:2d}. {m.display_name()}")
458
 
459
  selections = {}
460
  for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
461
  default = DEFAULTS[tier]
462
+ tier_desc = {
463
+ "microfish": "bulk generation",
464
+ "tinyfish": "compilation",
465
+ "mediumfish": "critique",
466
+ "bigfish": "final gate",
467
+ }
468
 
469
  if has_rich:
470
+ console.print(f"\n[bold]Select model for [{tier}][/] (default: {default.name}) — {tier_desc[tier]}:")
 
471
  choice = Prompt.ask(
472
  f" Enter number (1-{len(all_models)}) or press Enter for default",
473
  default="",
474
  )
475
  else:
476
+ print(f"\nSelect model for [{tier}] (default: {default.name}) — {tier_desc[tier]}:")
 
477
  choice = input(f" Enter number (1-{len(all_models)}) or press Enter for default: ")
478
 
479
  if choice and choice.isdigit():