qa1145 commited on
Commit
9188527
·
verified ·
1 Parent(s): 98263a0

Upload 9 files

Browse files
Files changed (1) hide show
  1. src/model_tester.py +64 -40
src/model_tester.py CHANGED
@@ -198,12 +198,18 @@ class ModelTester:
198
  available_free = self.get_all_free_models()
199
  print(f"[try_best] Found {len(available_free)} free models")
200
 
201
- # 第二步:用关键词匹配模型
202
  candidates = []
203
 
204
  if keyword and available_free:
205
- matched = [m for m in available_free if keyword.lower() in m.lower()]
206
- print(f"[try_best] Keyword '{keyword}' matched: {matched}")
 
 
 
 
 
 
207
  if matched:
208
  candidates.extend([(m, "matched") for m in matched[:10]])
209
 
@@ -264,6 +270,23 @@ class ModelTester:
264
  "method": "list_fallback"
265
  }
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  async def chat_completion(self, prompt: str, model_hint: Optional[str] = None) -> Dict[str, Any]:
268
  api_keys = config.get_api_keys()
269
  api_key = random.choice(api_keys)
@@ -271,54 +294,55 @@ class ModelTester:
271
  async with aiohttp.ClientSession() as session:
272
  tasks = []
273
 
 
274
  if model_hint:
275
- full_model = f"{model_hint}:free" if ":free" not in model_hint else model_hint
276
- tasks.append(asyncio.create_task(
277
- self.try_model_direct(session, full_model, api_key)
278
- ))
 
 
 
 
 
 
 
 
 
 
279
 
280
  tasks.append(asyncio.create_task(
281
  self.try_best_available_model(session, model_hint or "", api_key)
282
  ))
283
 
284
- done, pending = await asyncio.wait(
285
- tasks,
286
- return_when=asyncio.FIRST_COMPLETED
287
- )
288
 
289
- for task in pending:
290
- task.cancel()
 
 
 
 
 
 
 
291
 
292
- if tasks[0] in done:
293
- result = tasks[0].result()
294
- if result and result.get("success"):
295
- return {
296
- "success": True,
297
- "response": result.get("response"),
298
- "method": result.get("method"),
299
- "model": result.get("model")
300
- }
301
-
302
- if tasks[1] in done:
303
- result = tasks[1].result()
304
- if result and result.get("success"):
305
- return {
306
- "success": True,
307
- "response": result.get("response"),
308
- "method": result.get("method"),
309
- "model": result.get("model")
310
- }
311
- else:
312
- return {
313
- "success": False,
314
- "error": result.get("error", "Unknown error"),
315
- "method": result.get("method")
316
- }
317
 
 
318
  return {
319
  "success": False,
320
- "error": "Both methods failed",
321
- "method": "both_failed"
322
  }
323
 
324
  def chat_completion_sync(self, prompt: str, model_hint: Optional[str] = None) -> Dict[str, Any]:
 
198
  available_free = self.get_all_free_models()
199
  print(f"[try_best] Found {len(available_free)} free models")
200
 
201
+ # 第二步:用关键词匹配模型(避免匹配到不完整的ID)
202
  candidates = []
203
 
204
  if keyword and available_free:
205
+ # 只匹配模型名部分,不匹配作者前缀
206
+ matched = []
207
+ for m in available_free:
208
+ model_name = m.replace(":free", "").split("/")[-1]
209
+ if keyword.lower() in model_name.lower():
210
+ matched.append(m)
211
+
212
+ print(f"[try_best] Keyword '{keyword}' matched: {matched[:5]}")
213
  if matched:
214
  candidates.extend([(m, "matched") for m in matched[:10]])
215
 
 
270
  "method": "list_fallback"
271
  }
272
 
273
+ def find_model_in_list(self, keyword: str) -> Optional[str]:
274
+ """Find full model ID from keyword"""
275
+ available_free = self.get_all_free_models()
276
+
277
+ # 先精确匹配
278
+ for model in available_free:
279
+ model_name = model.replace(":free", "").split("/")[-1]
280
+ if model_name.lower() == keyword.lower():
281
+ return model
282
+
283
+ # 然后模糊匹配
284
+ for model in available_free:
285
+ if keyword.lower() in model.lower():
286
+ return model
287
+
288
+ return None
289
+
290
  async def chat_completion(self, prompt: str, model_hint: Optional[str] = None) -> Dict[str, Any]:
291
  api_keys = config.get_api_keys()
292
  api_key = random.choice(api_keys)
 
294
  async with aiohttp.ClientSession() as session:
295
  tasks = []
296
 
297
+ # 方案1:用户指定模型,需要先找到完整的模型ID
298
  if model_hint:
299
+ # 尝试在模型列表中找到匹配的完整模型ID
300
+ full_model_id = self.find_model_in_list(model_hint)
301
+
302
+ if full_model_id:
303
+ # 找到完整ID,直接使用
304
+ tasks.append(asyncio.create_task(
305
+ self.try_model_direct(session, full_model_id, api_key)
306
+ ))
307
+ else:
308
+ # 没找到,尝试用原始输入(可能是完整ID)
309
+ full_model = f"{model_hint}:free" if ":free" not in model_hint else model_hint
310
+ tasks.append(asyncio.create_task(
311
+ self.try_model_direct(session, full_model, api_key)
312
+ ))
313
 
314
  tasks.append(asyncio.create_task(
315
  self.try_best_available_model(session, model_hint or "", api_key)
316
  ))
317
 
318
+ # 等待所有任务完成
319
+ results = await asyncio.gather(*tasks, return_exceptions=True)
 
 
320
 
321
+ # 先检查方案1
322
+ result1 = results[0]
323
+ if isinstance(result1, dict) and result1.get("success"):
324
+ return {
325
+ "success": True,
326
+ "response": result1.get("response"),
327
+ "method": result1.get("method"),
328
+ "model": result1.get("model")
329
+ }
330
 
331
+ # 方案1失败,检查方案2
332
+ result2 = results[1]
333
+ if isinstance(result2, dict) and result2.get("success"):
334
+ return {
335
+ "success": True,
336
+ "response": result2.get("response"),
337
+ "method": result2.get("method"),
338
+ "model": result2.get("model")
339
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
+ # 都失败了,返回方案2的错误(更详细)
342
  return {
343
  "success": False,
344
+ "error": result2.get("error", "Unknown error") if isinstance(result2, dict) else "Request failed",
345
+ "method": result2.get("method", "both_failed") if isinstance(result2, dict) else "both_failed"
346
  }
347
 
348
  def chat_completion_sync(self, prompt: str, model_hint: Optional[str] = None) -> Dict[str, Any]: