gaurv007 committed on
Commit
238274e
·
verified ·
1 Parent(s): 9dc18cd

Upload alpha_factory/infra/model_manager.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. alpha_factory/infra/model_manager.py +61 -37
alpha_factory/infra/model_manager.py CHANGED
@@ -6,10 +6,13 @@ User selects which to use via interactive menu or config.
6
  import asyncio
7
  import aiohttp
8
  import os
 
9
  from dataclasses import dataclass, field
10
  from typing import Optional
11
  from enum import Enum
12
 
 
 
13
 
14
  class ModelProvider(str, Enum):
15
  OLLAMA = "ollama"
@@ -68,6 +71,15 @@ HF_RECOMMENDED = [
68
  ]
69
 
70
 
 
 
 
 
 
 
 
 
 
71
  class ModelManager:
72
  """
73
  Detects and manages models from Ollama (local) and HuggingFace (cloud).
@@ -102,7 +114,10 @@ class ModelManager:
102
  self.ollama_models = []
103
  try:
104
  async with aiohttp.ClientSession() as session:
105
- async with session.get(f"{self.ollama_url}/api/tags", timeout=aiohttp.ClientTimeout(total=5)) as resp:
 
 
 
106
  if resp.status == 200:
107
  data = await resp.json()
108
  for model in data.get("models", []):
@@ -123,51 +138,60 @@ class ModelManager:
123
  size_gb=round(size_gb, 1) if size_gb else None,
124
  quantization=quant,
125
  ))
126
- except (aiohttp.ClientError, asyncio.TimeoutError):
127
- pass # Ollama not running — that's fine
 
 
 
 
 
128
 
129
  async def _discover_hf(self):
130
  """Check which HuggingFace models are available via Inference API."""
131
  self.hf_models = []
132
  if not self.hf_token:
133
- # Still list recommended models (user can add token later)
134
- for model_id in HF_RECOMMENDED:
135
- self.hf_models.append(ModelInfo(
136
- name=model_id,
137
- provider=ModelProvider.HUGGINGFACE,
138
- ))
139
  return
140
 
141
  # With token, check which models are actually accessible
142
- try:
143
- async with aiohttp.ClientSession() as session:
144
- headers = {"Authorization": f"Bearer {self.hf_token}"}
145
- for model_id in HF_RECOMMENDED:
146
- try:
147
- async with session.get(
148
- f"https://huggingface.co/api/models/{model_id}",
149
- headers=headers,
150
- timeout=aiohttp.ClientTimeout(total=5),
151
- ) as resp:
152
- if resp.status == 200:
153
- data = await resp.json()
154
- self.hf_models.append(ModelInfo(
155
- name=model_id,
156
- provider=ModelProvider.HUGGINGFACE,
157
- context_length=data.get("config", {}).get("max_position_embeddings"),
158
- ))
159
- except:
160
- # Still list it — might work
161
- self.hf_models.append(ModelInfo(
162
- name=model_id,
163
- provider=ModelProvider.HUGGINGFACE,
164
- ))
165
- except:
166
  for model_id in HF_RECOMMENDED:
167
- self.hf_models.append(ModelInfo(
168
- name=model_id,
169
- provider=ModelProvider.HUGGINGFACE,
170
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  def get_all_models(self) -> list[ModelInfo]:
173
  """Get all discovered models (local + cloud)."""
 
6
  import asyncio
7
  import aiohttp
8
  import os
9
+ import logging
10
  from dataclasses import dataclass, field
11
  from typing import Optional
12
  from enum import Enum
13
 
14
+ logger = logging.getLogger(__name__)
15
+
16
 
17
  class ModelProvider(str, Enum):
18
  OLLAMA = "ollama"
 
71
  ]
72
 
73
 
74
def _add_hf_fallbacks(target_list: list[ModelInfo]):
    """Append a fallback ``ModelInfo`` entry to *target_list* for every
    recommended HuggingFace model (no validation performed)."""
    target_list.extend(
        ModelInfo(name=model_id, provider=ModelProvider.HUGGINGFACE)
        for model_id in HF_RECOMMENDED
    )
81
+
82
+
83
  class ModelManager:
84
  """
85
  Detects and manages models from Ollama (local) and HuggingFace (cloud).
 
114
  self.ollama_models = []
115
  try:
116
  async with aiohttp.ClientSession() as session:
117
+ async with session.get(
118
+ f"{self.ollama_url}/api/tags",
119
+ timeout=aiohttp.ClientTimeout(total=5)
120
+ ) as resp:
121
  if resp.status == 200:
122
  data = await resp.json()
123
  for model in data.get("models", []):
 
138
  size_gb=round(size_gb, 1) if size_gb else None,
139
  quantization=quant,
140
  ))
141
+ logger.info(f"Discovered {len(self.ollama_models)} Ollama models")
142
+ else:
143
+ logger.warning(f"Ollama returned status {resp.status}")
144
+ except asyncio.TimeoutError:
145
+ logger.warning("Ollama discovery timed out (5s). Is Ollama running?")
146
+ except aiohttp.ClientError as e:
147
+ logger.warning(f"Ollama not reachable: {e}")
148
 
149
  async def _discover_hf(self):
150
  """Check which HuggingFace models are available via Inference API."""
151
  self.hf_models = []
152
  if not self.hf_token:
153
+ logger.info("No HF_TOKEN set using recommended model list without validation")
154
+ _add_hf_fallbacks(self.hf_models)
 
 
 
 
155
  return
156
 
157
  # With token, check which models are actually accessible
158
+ headers = {"Authorization": f"Bearer {self.hf_token}"}
159
+ async with aiohttp.ClientSession() as session:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  for model_id in HF_RECOMMENDED:
161
+ try:
162
+ async with session.get(
163
+ f"https://huggingface.co/api/models/{model_id}",
164
+ headers=headers,
165
+ timeout=aiohttp.ClientTimeout(total=5),
166
+ ) as resp:
167
+ if resp.status == 200:
168
+ data = await resp.json()
169
+ self.hf_models.append(ModelInfo(
170
+ name=model_id,
171
+ provider=ModelProvider.HUGGINGFACE,
172
+ context_length=data.get("config", {}).get("max_position_embeddings"),
173
+ ))
174
+ logger.debug(f"HF model validated: {model_id}")
175
+ elif resp.status == 401:
176
+ logger.warning(f"HF token invalid for {model_id}")
177
+ self.hf_models.append(ModelInfo(
178
+ name=model_id, provider=ModelProvider.HUGGINGFACE
179
+ ))
180
+ else:
181
+ logger.debug(f"HF model {model_id} status {resp.status}")
182
+ self.hf_models.append(ModelInfo(
183
+ name=model_id, provider=ModelProvider.HUGGINGFACE
184
+ ))
185
+ except asyncio.TimeoutError:
186
+ logger.debug(f"HF model {model_id} discovery timed out")
187
+ self.hf_models.append(ModelInfo(
188
+ name=model_id, provider=ModelProvider.HUGGINGFACE
189
+ ))
190
+ except aiohttp.ClientError as e:
191
+ logger.debug(f"HF model {model_id} discovery error: {e}")
192
+ self.hf_models.append(ModelInfo(
193
+ name=model_id, provider=ModelProvider.HUGGINGFACE
194
+ ))
195
 
196
  def get_all_models(self) -> list[ModelInfo]:
197
  """Get all discovered models (local + cloud)."""