| """ |
| Model Manager β Unified interface for Ollama (local) + HuggingFace Inference API (cloud). |
| Auto-detects available models from both sources. |
| User selects which to use via interactive menu or config. |
| """ |
import asyncio
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import aiohttp

logger = logging.getLogger(__name__)


class ModelProvider(str, Enum):
    OLLAMA = "ollama"
    HUGGINGFACE = "huggingface"


@dataclass
class ModelInfo:
    """Metadata about an available model."""

    name: str
    provider: ModelProvider
    size_gb: Optional[float] = None
    quantization: Optional[str] = None
    context_length: Optional[int] = None
    is_default: bool = False

    def display_name(self) -> str:
        """Human-readable label, e.g. '[ollama] qwen2.5:7b (4.7GB) [q4_k_m]'."""
        size_str = f" ({self.size_gb:.1f}GB)" if self.size_gb else ""
        quant_str = f" [{self.quantization}]" if self.quantization else ""
        return f"[{self.provider.value}] {self.name}{size_str}{quant_str}"


# Default tier -> model assignments (local Ollama tags); used when nothing
# better has been discovered or selected.
DEFAULTS = {
    "microfish": ModelInfo(
        name="qwen2.5:1.5b", provider=ModelProvider.OLLAMA,
        size_gb=1.0, context_length=32768, is_default=True,
    ),
    "tinyfish": ModelInfo(
        name="qwen2.5:3b", provider=ModelProvider.OLLAMA,
        size_gb=2.0, context_length=32768, is_default=True,
    ),
    "mediumfish": ModelInfo(
        name="qwen2.5:7b", provider=ModelProvider.OLLAMA,
        size_gb=4.7, context_length=32768, is_default=True,
    ),
    "bigfish": ModelInfo(
        name="qwen2.5:14b", provider=ModelProvider.OLLAMA,
        size_gb=9.0, context_length=32768, is_default=True,
    ),
}

# HuggingFace models to probe via the Inference API.
HF_RECOMMENDED = [
    "Qwen/Qwen2.5-72B-Instruct",
    "Qwen/Qwen2.5-32B-Instruct",
    "Qwen/Qwen2.5-14B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
    "deepseek-ai/DeepSeek-V3",
    "deepseek-ai/DeepSeek-R1",
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
    "mistralai/Mistral-Small-24B-Instruct-2501",
    "microsoft/phi-4",
]


def _add_hf_fallbacks(target_list: list[ModelInfo]) -> None:
    """Append every recommended HF model as an unvalidated fallback."""
    for model_id in HF_RECOMMENDED:
        target_list.append(ModelInfo(
            name=model_id,
            provider=ModelProvider.HUGGINGFACE,
        ))


class ModelManager:
    """
    Detects and manages models from Ollama (local) and HuggingFace (cloud).
    Provides a unified interface for the pipeline to request models.
    """

    def __init__(
        self,
        ollama_url: str = "http://localhost:11434",
        hf_token: Optional[str] = None,
    ):
        self.ollama_url = ollama_url
        self.hf_token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
        self.hf_api_url = "https://router.huggingface.co/v1"

        # Populated by discover_all().
        self.ollama_models: list[ModelInfo] = []
        self.hf_models: list[ModelInfo] = []

        # tier name -> chosen model
        self.selected: dict[str, ModelInfo] = {}

    async def discover_all(self) -> None:
        """Discover available models from both providers concurrently."""
        await asyncio.gather(
            self._discover_ollama(),
            self._discover_hf(),
        )

    async def _discover_ollama(self) -> None:
        """Detect locally installed Ollama models via /api/tags."""
        self.ollama_models = []
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.ollama_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        for model in data.get("models", []):
                            name = model.get("name", "")
                            size_bytes = model.get("size", 0)
                            size_gb = size_bytes / (1024**3) if size_bytes else None

                            # Infer quantization from the tag name, if present.
                            quant = None
                            for q in ["q4_0", "q4_k_m", "q5_k_m", "q8_0", "fp16"]:
                                if q in name.lower():
                                    quant = q
                                    break

                            self.ollama_models.append(ModelInfo(
                                name=name,
                                provider=ModelProvider.OLLAMA,
                                size_gb=round(size_gb, 1) if size_gb else None,
                                quantization=quant,
                            ))
                        logger.info(f"Discovered {len(self.ollama_models)} Ollama models")
                    else:
                        logger.warning(f"Ollama returned status {resp.status}")
        except asyncio.TimeoutError:
            logger.warning("Ollama discovery timed out (5s). Is Ollama running?")
        except aiohttp.ClientError as e:
            logger.warning(f"Ollama not reachable: {e}")

    async def _discover_hf(self) -> None:
        """Check which HuggingFace models are reachable via the Inference API."""
        self.hf_models = []
        if not self.hf_token:
            logger.info("No HF_TOKEN set; using recommended model list without validation")
            _add_hf_fallbacks(self.hf_models)
            return

        # With a token, validate each model against the Hub API; on any
        # failure, keep the model anyway as an unvalidated fallback.
        headers = {"Authorization": f"Bearer {self.hf_token}"}
        async with aiohttp.ClientSession() as session:
            for model_id in HF_RECOMMENDED:
                fallback = ModelInfo(name=model_id, provider=ModelProvider.HUGGINGFACE)
                try:
                    async with session.get(
                        f"https://huggingface.co/api/models/{model_id}",
                        headers=headers,
                        timeout=aiohttp.ClientTimeout(total=5),
                    ) as resp:
                        if resp.status == 200:
                            data = await resp.json()
                            self.hf_models.append(ModelInfo(
                                name=model_id,
                                provider=ModelProvider.HUGGINGFACE,
                                context_length=data.get("config", {}).get("max_position_embeddings"),
                            ))
                            logger.debug(f"HF model validated: {model_id}")
                        elif resp.status == 401:
                            logger.warning(f"HF token invalid for {model_id}")
                            self.hf_models.append(fallback)
                        else:
                            logger.debug(f"HF model {model_id} status {resp.status}")
                            self.hf_models.append(fallback)
                except asyncio.TimeoutError:
                    logger.debug(f"HF model {model_id} discovery timed out")
                    self.hf_models.append(fallback)
                except aiohttp.ClientError as e:
                    logger.debug(f"HF model {model_id} discovery error: {e}")
                    self.hf_models.append(fallback)

    def get_all_models(self) -> list[ModelInfo]:
        """Get all discovered models (local + cloud)."""
        return self.ollama_models + self.hf_models

    def get_local_models(self) -> list[ModelInfo]:
        """Get only locally installed models."""
        return self.ollama_models

    def get_cloud_models(self) -> list[ModelInfo]:
        """Get HuggingFace cloud models."""
        return self.hf_models

    def select_model(self, tier: str, model: ModelInfo) -> None:
        """Select a model for a tier (microfish/tinyfish/mediumfish/bigfish)."""
        self.selected[tier] = model

    def get_selected(self, tier: str) -> ModelInfo:
        """Get the selected model for a tier, falling back to its default."""
        return self.selected.get(tier, DEFAULTS.get(tier, DEFAULTS["mediumfish"]))

    def get_endpoint(self, tier: str) -> tuple[str, str, dict]:
        """
        Get the API endpoint info for the selected model.
        Returns: (base_url, model_name, headers)
        """
        model = self.get_selected(tier)

        if model.provider == ModelProvider.OLLAMA:
            # Ollama serves an OpenAI-compatible API under /v1; no auth needed.
            return (
                f"{self.ollama_url}/v1",
                model.name,
                {},
            )
        else:
            # HuggingFace router endpoint; attach the token when available.
            headers = {}
            if self.hf_token:
                headers["Authorization"] = f"Bearer {self.hf_token}"
            return (
                self.hf_api_url,
                model.name,
                headers,
            )
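
    # Usage sketch (illustrative, not called anywhere in this module): the
    # returned tuple plugs into any OpenAI-compatible chat-completions call.
    # The "bigfish" tier and prompt below are placeholders.
    #
    #   base_url, model_name, headers = manager.get_endpoint("bigfish")
    #   payload = {
    #       "model": model_name,
    #       "messages": [{"role": "user", "content": "Say hello."}],
    #   }
    #   async with aiohttp.ClientSession(headers=headers) as session:
    #       async with session.post(f"{base_url}/chat/completions", json=payload) as resp:
    #           reply = (await resp.json())["choices"][0]["message"]["content"]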

    def auto_assign_defaults(self) -> None:
        """
        Automatically assign the best available model to each tier.
        Prefers local (Ollama) over cloud (HF) for speed and privacy.
        """
        local_names = {m.name.lower(): m for m in self.ollama_models}

        for tier, default in DEFAULTS.items():
            if default.name.lower() in local_names:
                # The tier's default model is installed locally: use it.
                self.selected[tier] = local_names[default.name.lower()]
            elif self.ollama_models:
                # Otherwise pick a local model by size: smallest for microfish,
                # largest for bigfish, the median for the middle tiers.
                sorted_local = sorted(self.ollama_models, key=lambda m: m.size_gb or 0)
                if tier == "microfish":
                    self.selected[tier] = sorted_local[0]
                elif tier == "bigfish":
                    self.selected[tier] = sorted_local[-1]
                else:
                    mid = len(sorted_local) // 2
                    self.selected[tier] = sorted_local[mid]
            elif self.hf_models:
                # No local models at all: map each tier to a recommended HF model.
                hf_tier_map = {
                    "microfish": "Qwen/Qwen2.5-7B-Instruct",
                    "tinyfish": "Qwen/Qwen2.5-7B-Instruct",
                    "mediumfish": "Qwen/Qwen2.5-14B-Instruct",
                    "bigfish": "Qwen/Qwen2.5-72B-Instruct",
                }
                target = hf_tier_map.get(tier, "Qwen/Qwen2.5-7B-Instruct")
                matched = [m for m in self.hf_models if m.name == target]
                self.selected[tier] = matched[0] if matched else self.hf_models[0]
            else:
                # Nothing discovered anywhere: keep the static default.
                self.selected[tier] = default

    def print_status(self) -> None:
        """Print the current model configuration."""
        from rich.console import Console
        from rich.table import Table

        console = Console()

        # Discovery summary.
        console.print("\n[bold]Model Discovery[/]")
        console.print(f"  Ollama (local): {len(self.ollama_models)} models")
        console.print(f"  HuggingFace (cloud): {len(self.hf_models)} models")
        if not self.hf_token:
            console.print("  [yellow]No HF_TOKEN set; cloud models may have rate limits[/]")

        # Local model inventory.
        if self.ollama_models:
            table = Table(title="Local Models (Ollama)")
            table.add_column("#", width=3)
            table.add_column("Model", style="cyan")
            table.add_column("Size", style="green")
            table.add_column("Quant", style="yellow")
            for i, m in enumerate(self.ollama_models, 1):
                table.add_row(
                    str(i), m.name,
                    f"{m.size_gb:.1f}GB" if m.size_gb else "?",
                    m.quantization or "-",
                )
            console.print(table)

        # Per-tier selections.
        table2 = Table(title="Selected Models (Pipeline)")
        table2.add_column("Tier", style="bold")
        table2.add_column("Model", style="cyan")
        table2.add_column("Provider", style="magenta")
        table2.add_column("Use", style="dim")

        tier_uses = {
            "microfish": "Hypothesis generation (bulk)",
            "tinyfish": "Expression compilation",
            "mediumfish": "Crowd scout + surgeon",
            "bigfish": "Gatekeeper (final memo)",
        }

        for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
            model = self.get_selected(tier)
            table2.add_row(
                tier, model.name, model.provider.value,
                tier_uses.get(tier, ""),
            )
        console.print(table2)


def interactive_model_select(manager: ModelManager) -> dict[str, ModelInfo]:
    """
    Interactive CLI menu for model selection.
    Shows all available models and lets the user pick one per tier.
    """
    from rich.console import Console
    from rich.prompt import Prompt

    console = Console()
    all_models = manager.get_all_models()

    if not all_models:
        console.print("[red]No models found! Install Ollama models or set HF_TOKEN.[/]")
        console.print("  ollama pull qwen2.5:1.5b")
        console.print("  ollama pull qwen2.5:7b")
        console.print("  export HF_TOKEN=hf_your_token")
        return {}

    console.print("\n[bold]Available Models:[/]")
    for i, m in enumerate(all_models, 1):
        console.print(f"  {i:2d}. {m.display_name()}")

    tier_desc = {
        "microfish": "bulk generation",
        "tinyfish": "compilation",
        "mediumfish": "critique",
        "bigfish": "final gate",
    }

    selections = {}
    for tier in ["microfish", "tinyfish", "mediumfish", "bigfish"]:
        default = DEFAULTS[tier]
        console.print(f"\n[bold]Select model for [{tier}][/] (default: {default.name}):")
        console.print(f"  Use: {tier_desc[tier]}")

        choice = Prompt.ask(
            f"  Enter number (1-{len(all_models)}) or press Enter for default",
            default="",
        )

        # A valid in-range number selects that model; anything else keeps the default.
        if choice.isdigit() and 0 <= int(choice) - 1 < len(all_models):
            selections[tier] = all_models[int(choice) - 1]
            console.print(f"  Selected: {selections[tier].display_name()}")
        else:
            selections[tier] = default
            console.print(f"  Using default: {default.name}")

    return selections
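

# Minimal self-test sketch: discover models, auto-assign tiers, and print the
# resulting configuration. Assumes Ollama (if installed) listens on its default
# port; HuggingFace validation runs only when HF_TOKEN is set.
if __name__ == "__main__":
    async def _main() -> None:
        manager = ModelManager()
        await manager.discover_all()
        manager.auto_assign_defaults()
        manager.print_status()

    asyncio.run(_main())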