# alpha_factory/infra/model_manager.py
"""
Model Manager β€” Unified interface for Ollama (local) + HuggingFace Inference API (cloud).
Auto-detects available models from both sources.
User selects which to use via interactive menu or config.
"""
import asyncio
import aiohttp
import os
import logging
from dataclasses import dataclass, field
from typing import Optional
from enum import Enum
logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OLLAMA = "ollama"
HUGGINGFACE = "huggingface"
@dataclass
class ModelInfo:
"""Metadata about an available model."""
name: str
provider: ModelProvider
size_gb: Optional[float] = None
quantization: Optional[str] = None
context_length: Optional[int] = None
is_default: bool = False
    def display_name(self) -> str:
        """Human-readable label, e.g. "[ollama] qwen2.5:7b (4.7GB) [q4_k_m]"."""
        size_str = f" ({self.size_gb:.1f}GB)" if self.size_gb else ""
        quant_str = f" [{self.quantization}]" if self.quantization else ""
        return f"[{self.provider.value}] {self.name}{size_str}{quant_str}"
# ─── Default model recommendations ─────────────────────────────────────────
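# Tier names map to pipeline roles (see print_status): microfish does bulk
# hypothesis generation, tinyfish expression compilation, mediumfish the crowd
# scout + surgeon critique, and bigfish the gatekeeper's final memo.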
DEFAULTS = {
"microfish": ModelInfo(
name="qwen2.5:1.5b", provider=ModelProvider.OLLAMA,
size_gb=1.0, context_length=32768, is_default=True,
),
"tinyfish": ModelInfo(
name="qwen2.5:3b", provider=ModelProvider.OLLAMA,
size_gb=2.0, context_length=32768, is_default=True,
),
"mediumfish": ModelInfo(
name="qwen2.5:7b", provider=ModelProvider.OLLAMA,
size_gb=4.7, context_length=32768, is_default=True,
),
"bigfish": ModelInfo(
name="qwen2.5:14b", provider=ModelProvider.OLLAMA,
size_gb=9.0, context_length=32768, is_default=True,
),
}
# HuggingFace models that work well for this pipeline
HF_RECOMMENDED = [
"Qwen/Qwen2.5-72B-Instruct",
"Qwen/Qwen2.5-32B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
"deepseek-ai/DeepSeek-V3",
"deepseek-ai/DeepSeek-R1",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"meta-llama/Llama-4-Maverick-17B-128E-Instruct",
"mistralai/Mistral-Small-24B-Instruct-2501",
"microsoft/phi-4",
]
def _add_hf_fallbacks(target_list: list[ModelInfo]):
"""Add all HF recommended models as fallbacks."""
for model_id in HF_RECOMMENDED:
target_list.append(ModelInfo(
name=model_id,
provider=ModelProvider.HUGGINGFACE,
))
class ModelManager:
"""
Detects and manages models from Ollama (local) and HuggingFace (cloud).
Provides unified interface for the pipeline to request models.
"""
def __init__(
self,
ollama_url: str = "http://localhost:11434",
hf_token: Optional[str] = None,
):
self.ollama_url = ollama_url
self.hf_token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
self.hf_api_url = "https://router.huggingface.co/v1"
# Discovered models
self.ollama_models: list[ModelInfo] = []
self.hf_models: list[ModelInfo] = []
# Selected models for each tier
self.selected: dict[str, ModelInfo] = {}
async def discover_all(self):
"""Discover all available models from both providers."""
await asyncio.gather(
self._discover_ollama(),
self._discover_hf(),
)
async def _discover_ollama(self):
"""Detect locally installed Ollama models."""
self.ollama_models = []
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.ollama_url}/api/tags",
timeout=aiohttp.ClientTimeout(total=5)
) as resp:
if resp.status == 200:
data = await resp.json()
for model in data.get("models", []):
name = model.get("name", "")
size_bytes = model.get("size", 0)
size_gb = size_bytes / (1024**3) if size_bytes else None
# Extract quantization from name
quant = None
for q in ["q4_0", "q4_k_m", "q5_k_m", "q8_0", "fp16"]:
if q in name.lower():
quant = q
break
self.ollama_models.append(ModelInfo(
name=name,
provider=ModelProvider.OLLAMA,
size_gb=round(size_gb, 1) if size_gb else None,
quantization=quant,
))
logger.info(f"Discovered {len(self.ollama_models)} Ollama models")
else:
logger.warning(f"Ollama returned status {resp.status}")
except asyncio.TimeoutError:
logger.warning("Ollama discovery timed out (5s). Is Ollama running?")
except aiohttp.ClientError as e:
logger.warning(f"Ollama not reachable: {e}")
async def _discover_hf(self):
"""Check which HuggingFace models are available via Inference API."""
self.hf_models = []
if not self.hf_token:
logger.info("No HF_TOKEN set β€” using recommended model list without validation")
_add_hf_fallbacks(self.hf_models)
return
# With token, check which models are actually accessible
headers = {"Authorization": f"Bearer {self.hf_token}"}
async with aiohttp.ClientSession() as session:
            for model_id in HF_RECOMMENDED:
                try:
                    async with session.get(
                        f"https://huggingface.co/api/models/{model_id}",
                        headers=headers,
                        timeout=aiohttp.ClientTimeout(total=5),
                    ) as resp:
                        if resp.status == 200:
                            data = await resp.json()
                            self.hf_models.append(ModelInfo(
                                name=model_id,
                                provider=ModelProvider.HUGGINGFACE,
                                context_length=data.get("config", {}).get("max_position_embeddings"),
                            ))
                            logger.debug(f"HF model validated: {model_id}")
                            continue
                        if resp.status == 401:
                            logger.warning(f"HF token not authorized for {model_id}")
                        else:
                            logger.debug(f"HF model {model_id} returned status {resp.status}")
                except asyncio.TimeoutError:
                    logger.debug(f"HF model {model_id} discovery timed out")
                except aiohttp.ClientError as e:
                    logger.debug(f"HF model {model_id} discovery error: {e}")
                # Any non-200 outcome: keep the model as an unvalidated fallback
                self.hf_models.append(ModelInfo(
                    name=model_id, provider=ModelProvider.HUGGINGFACE
                ))
def get_all_models(self) -> list[ModelInfo]:
"""Get all discovered models (local + cloud)."""
return self.ollama_models + self.hf_models
def get_local_models(self) -> list[ModelInfo]:
"""Get only locally installed models."""
return self.ollama_models
def get_cloud_models(self) -> list[ModelInfo]:
"""Get HuggingFace cloud models."""
return self.hf_models
def select_model(self, tier: str, model: ModelInfo):
"""Select a model for a specific tier (microfish/tinyfish/mediumfish/bigfish)."""
self.selected[tier] = model
def get_selected(self, tier: str) -> ModelInfo:
"""Get the selected model for a tier, or return default."""
return self.selected.get(tier, DEFAULTS.get(tier, DEFAULTS["mediumfish"]))
    def get_endpoint(self, tier: str) -> tuple[str, str, dict]:
        """
        Get the API endpoint info for the selected model.
        Returns: (base_url, model_name, headers) for an OpenAI-compatible API.
        """
model = self.get_selected(tier)
if model.provider == ModelProvider.OLLAMA:
return (
f"{self.ollama_url}/v1",
model.name,
{},
)
else:
# HuggingFace Inference API
headers = {}
if self.hf_token:
headers["Authorization"] = f"Bearer {self.hf_token}"
return (
self.hf_api_url,
model.name,
headers,
)
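    # Both base URLs above speak the OpenAI chat-completions protocol, so a
    # caller can POST to f"{base_url}/chat/completions". A minimal sketch
    # (the payload shape is the caller's responsibility, not enforced here):
    #
    #     base_url, model_name, headers = manager.get_endpoint("bigfish")
    #     payload = {"model": model_name,
    #                "messages": [{"role": "user", "content": "hello"}]}
    #     async with aiohttp.ClientSession() as session:
    #         async with session.post(f"{base_url}/chat/completions",
    #                                 json=payload, headers=headers) as resp:
    #             data = await resp.json()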
def auto_assign_defaults(self):
"""
Automatically assign best available models to each tier.
Prefers local (Ollama) over cloud (HF) for speed + privacy.
"""
local_names = {m.name.lower(): m for m in self.ollama_models}
for tier, default in DEFAULTS.items():
# Try to find the default model locally
if default.name.lower() in local_names:
self.selected[tier] = local_names[default.name.lower()]
            elif self.ollama_models:
                # Otherwise pick a size-appropriate local model for this tier
                sorted_local = sorted(self.ollama_models, key=lambda m: m.size_gb or 0)
                if tier == "microfish":
                    self.selected[tier] = sorted_local[0]   # smallest
                elif tier == "bigfish":
                    self.selected[tier] = sorted_local[-1]  # largest
                else:
                    self.selected[tier] = sorted_local[len(sorted_local) // 2]  # middle
elif self.hf_models:
# Fallback to HuggingFace cloud β€” pick size-appropriate model
hf_tier_map = {
"microfish": "Qwen/Qwen2.5-7B-Instruct",
"tinyfish": "Qwen/Qwen2.5-7B-Instruct",
"mediumfish": "Qwen/Qwen2.5-14B-Instruct",
"bigfish": "Qwen/Qwen2.5-72B-Instruct",
}
target = hf_tier_map.get(tier, "Qwen/Qwen2.5-7B-Instruct")
matched = [m for m in self.hf_models if m.name == target]
self.selected[tier] = matched[0] if matched else self.hf_models[0]
else:
# Use defaults (will fail at runtime if nothing available)
self.selected[tier] = default
def print_status(self):
"""Print current model configuration."""
from rich.console import Console
from rich.table import Table
console = Console()
# Discovery summary
console.print(f"\n[bold]πŸ” Model Discovery[/]")
console.print(f" Ollama (local): {len(self.ollama_models)} models")
console.print(f" HuggingFace (cloud): {len(self.hf_models)} models")
if not self.hf_token:
console.print(f" [yellow]⚠ No HF_TOKEN set β€” cloud models may have rate limits[/]")
# Available models table
if self.ollama_models:
table = Table(title="Local Models (Ollama)")
table.add_column("#", width=3)
table.add_column("Model", style="cyan")
table.add_column("Size", style="green")
table.add_column("Quant", style="yellow")
for i, m in enumerate(self.ollama_models, 1):
table.add_row(
str(i), m.name,
f"{m.size_gb:.1f}GB" if m.size_gb else "?",
m.quantization or "-",
)
console.print(table)
# Selected models
table2 = Table(title="Selected Models (Pipeline)")
table2.add_column("Tier", style="bold")
table2.add_column("Model", style="cyan")
table2.add_column("Provider", style="magenta")
table2.add_column("Use", style="dim")
tier_uses = {
"microfish": "Hypothesis generation (bulk)",
"tinyfish": "Expression compilation",
"mediumfish": "Crowd scout + surgeon",
"bigfish": "Gatekeeper (final memo)",
}
for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
model = self.get_selected(tier)
table2.add_row(
tier, model.name, model.provider.value,
tier_uses.get(tier, ""),
)
console.print(table2)
def interactive_model_select(manager: ModelManager) -> dict[str, ModelInfo]:
"""
Interactive CLI menu for model selection.
Shows all available models and lets user pick for each tier.
"""
from rich.console import Console
    from rich.prompt import Prompt
console = Console()
all_models = manager.get_all_models()
if not all_models:
console.print("[red]No models found! Install Ollama models or set HF_TOKEN.[/]")
console.print(" ollama pull qwen2.5:1.5b")
console.print(" ollama pull qwen2.5:7b")
console.print(" export HF_TOKEN=hf_your_token")
return {}
console.print("\n[bold]πŸ“‹ Available Models:[/]")
for i, m in enumerate(all_models, 1):
console.print(f" {i:2d}. {m.display_name()}")
selections = {}
for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
default = DEFAULTS[tier]
console.print(f"\n[bold]Select model for [{tier}][/] (default: {default.name}):")
tier_desc = {"microfish": "bulk generation", "tinyfish": "compilation", "mediumfish": "critique", "bigfish": "final gate"}
console.print(f" Use: {tier_desc[tier]}")
choice = Prompt.ask(
f" Enter number (1-{len(all_models)}) or press Enter for default",
default="",
)
if choice and choice.isdigit():
idx = int(choice) - 1
if 0 <= idx < len(all_models):
selections[tier] = all_models[idx]
console.print(f" β†’ Selected: {all_models[idx].display_name()}")
else:
selections[tier] = default
console.print(f" β†’ Using default: {default.name}")
else:
selections[tier] = default
console.print(f" β†’ Using default: {default.name}")
return selections
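

# Minimal smoke test (a sketch: discovery assumes Ollama at localhost:11434
# and/or an HF_TOKEN in the environment; absent both, the recommended
# HuggingFace list is used unvalidated).
if __name__ == "__main__":
    async def _main():
        manager = ModelManager()
        await manager.discover_all()
        manager.auto_assign_defaults()
        manager.print_status()

    asyncio.run(_main())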