qa1145 committed on
Commit
c4c3b2e
·
verified ·
1 Parent(s): e6c065b

Upload 9 files

Browse files
Files changed (2) hide show
  1. app.py +36 -54
  2. src/model_tester.py +77 -0
app.py CHANGED
@@ -66,6 +66,31 @@ async def list_models():
66
  }
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  @fastapi_app.post("/v1/chat/completions")
70
  async def chat_completions(request: ChatCompletionRequest):
71
  prompt = request.messages[-1].content if request.messages else ""
@@ -73,7 +98,7 @@ async def chat_completions(request: ChatCompletionRequest):
73
 
74
  if request.stream:
75
  return StreamingResponse(
76
- stream_chat(request.model, prompt, request.messages),
77
  media_type="text/event-stream"
78
  )
79
 
@@ -83,10 +108,7 @@ async def chat_completions(request: ChatCompletionRequest):
83
  raise HTTPException(status_code=400, detail=result.get("error", "Request failed"))
84
 
85
  response_data = result.get("response", {})
86
-
87
- content = ""
88
- if "choices" in response_data and response_data["choices"]:
89
- content = response_data["choices"][0].get("message", {}).get("content", "")
90
 
91
  return {
92
  "id": response_data.get("id", f"chatcmpl-{random.randint(100000, 999999)}"),
@@ -111,59 +133,19 @@ async def chat_completions(request: ChatCompletionRequest):
111
  }
112
 
113
 
114
- async def stream_chat(model_hint: Optional[str], prompt: str, messages: list):
115
- model_hint = model_hint or ""
116
-
117
- result = await model_tester.chat_completion(prompt, model_hint)
118
 
119
- if not result.get("success"):
120
- yield f'data: {{"error": "{result.get("error", "Request failed")}"}}\n\n'
121
  yield "data: [DONE]\n\n"
122
  return
123
 
124
- response_data = result.get("response", {})
125
- content = ""
126
- if "choices" in response_data and response_data["choices"]:
127
- content = response_data["choices"][0].get("message", {}).get("content", "")
128
-
129
- model_id = result.get("model", model_hint or "unknown")
130
- completion_id = f"chatcmpl-{random.randint(100000, 999999)}"
131
- created = int(datetime.now().timestamp())
132
-
133
- # 流式输出每个字
134
- for i, char in enumerate(content):
135
- chunk = {
136
- "id": completion_id,
137
- "object": "chat.completion.chunk",
138
- "created": created,
139
- "model": model_id,
140
- "choices": [
141
- {
142
- "index": 0,
143
- "delta": {
144
- "content": char
145
- },
146
- "finish_reason": None
147
- }
148
- ]
149
- }
150
- yield f"data: {json.dumps(chunk)}\n\n"
151
-
152
- # 发送完成信号
153
- final_chunk = {
154
- "id": completion_id,
155
- "object": "chat.completion.chunk",
156
- "created": created,
157
- "model": model_id,
158
- "choices": [
159
- {
160
- "index": 0,
161
- "delta": {},
162
- "finish_reason": "stop"
163
- }
164
- ]
165
- }
166
- yield f"data: {json.dumps(final_chunk)}\n\n"
167
  yield "data: [DONE]\n\n"
168
 
169
 
 
66
  }
67
 
68
 
69
def parse_openrouter_response(response_data: dict) -> str:
    """Extract the assistant message content from an OpenRouter response.

    Tries, in order: the standard OpenAI ``choices[0].message.content``
    shape, the streaming-style ``choices[0].delta.content`` shape, then
    the top-level ``message.content`` and ``content`` fallbacks.

    Always returns a string ("" when no content can be found), matching
    the annotated return type even when a key is present but null.
    """
    content = ""

    # Standard OpenAI format. `choices[0]`, `message`, or `delta` may be
    # present but null in some provider payloads, so coalesce with `or {}`
    # instead of relying on .get() defaults (which only cover missing keys).
    choices = response_data.get("choices")
    if choices:
        first = choices[0] or {}
        content = (first.get("message") or {}).get("content", "")
        if not content:
            # Possibly a delta (streaming-chunk) shaped payload.
            content = (first.get("delta") or {}).get("content", "")

    # Direct top-level fallbacks.
    if not content and "message" in response_data:
        content = (response_data.get("message") or {}).get("content", "")

    if not content and "content" in response_data:
        content = response_data.get("content", "")

    # Guarantee a str even if a key held an explicit null.
    return content or ""
92
+
93
+
94
  @fastapi_app.post("/v1/chat/completions")
95
  async def chat_completions(request: ChatCompletionRequest):
96
  prompt = request.messages[-1].content if request.messages else ""
 
98
 
99
  if request.stream:
100
  return StreamingResponse(
101
+ stream_chat(request.model, [{"role": m.role, "content": m.content} for m in request.messages]),
102
  media_type="text/event-stream"
103
  )
104
 
 
108
  raise HTTPException(status_code=400, detail=result.get("error", "Request failed"))
109
 
110
  response_data = result.get("response", {})
111
+ content = parse_openrouter_response(response_data)
 
 
 
112
 
113
  return {
114
  "id": response_data.get("id", f"chatcmpl-{random.randint(100000, 999999)}"),
 
133
  }
134
 
135
 
136
async def stream_chat(model_hint: Optional[str], messages: list):
    """Proxy OpenRouter's SSE stream straight through to the client.

    Yields raw SSE byte chunks from the upstream response. When no model
    can be reached, emits a single error event followed by ``[DONE]``.
    """
    stream, used_model = await model_tester.chat_completion_stream(model_hint, messages)

    if not stream:
        yield 'data: {"error": "No available model found"}\n\n'
        yield "data: [DONE]\n\n"
        return

    # Forward upstream bytes untouched. Decoding each chunk (as the
    # previous implementation did) can raise UnicodeDecodeError when a
    # multi-byte UTF-8 character is split across chunk boundaries;
    # StreamingResponse accepts bytes, so no decode is needed.
    done_seen = False
    async for chunk in stream:
        data = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
        if b"[DONE]" in data:
            done_seen = True
        yield data

    # OpenRouter normally ends its stream with its own `data: [DONE]`
    # marker; only append ours when the upstream ended without one, so
    # clients never receive a duplicated terminator.
    if not done_seen:
        yield "data: [DONE]\n\n"
150
 
151
 
src/model_tester.py CHANGED
@@ -139,6 +139,47 @@ class ModelTester:
139
  """Get all free models from API list (not tested)"""
140
  return self._free_models
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  async def try_model_direct(
143
  self,
144
  session: aiohttp.ClientSession,
@@ -356,3 +397,39 @@ class ModelTester:
356
  def test_all_models(self) -> Dict[str, Any]:
357
  """Legacy sync method - use scan_all_models instead"""
358
  return self.scan_all_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  """Get all free models from API list (not tested)"""
140
  return self._free_models
141
 
142
+ async def try_model_direct_stream(
143
+ self,
144
+ session: aiohttp.ClientSession,
145
+ model_id: str,
146
+ api_key: str,
147
+ messages: List[Dict[str, str]]
148
+ ) -> Optional[Dict[str, Any]]:
149
+ """发送流式请求到OpenRouter"""
150
+ url = "https://openrouter.ai/api/v1/chat/completions"
151
+ payload = {
152
+ "model": model_id,
153
+ "messages": messages,
154
+ "max_tokens": 2048,
155
+ "stream": True
156
+ }
157
+ headers = {
158
+ "Authorization": f"Bearer {api_key}",
159
+ "Content-Type": "application/json"
160
+ }
161
+
162
+ try:
163
+ timeout = aiohttp.ClientTimeout(total=config.get_request_timeout())
164
+ async with session.post(url, json=payload, headers=headers, timeout=timeout) as response:
165
+ if response.status == 200:
166
+ return {
167
+ "success": True,
168
+ "model": model_id,
169
+ "stream": response.content,
170
+ "method": "direct"
171
+ }
172
+ else:
173
+ body = await response.text()
174
+ return {
175
+ "success": False,
176
+ "model": model_id,
177
+ "error": f"HTTP {response.status}: {body[:100]}",
178
+ "method": "direct"
179
+ }
180
+ except Exception as e:
181
+ return {"success": False, "model": model_id, "error": str(e), "method": "direct"}
182
+
183
  async def try_model_direct(
184
  self,
185
  session: aiohttp.ClientSession,
 
397
  def test_all_models(self) -> Dict[str, Any]:
398
  """Legacy sync method - use scan_all_models instead"""
399
  return self.scan_all_models()
400
+
401
+ async def chat_completion_stream(self, model_hint: Optional[str], messages: List[Dict[str, str]]):
402
+ """流式聊天 - 返回流式响应对象"""
403
+ api_keys = config.get_api_keys()
404
+ api_key = random.choice(api_keys)
405
+
406
+ # 方案1:尝试用户指定的模型
407
+ if model_hint:
408
+ full_model_id = self.find_model_in_list(model_hint)
409
+ if full_model_id:
410
+ async with aiohttp.ClientSession() as session:
411
+ result = await self.try_model_direct_stream(session, full_model_id, api_key, messages)
412
+ if result and result.get("success"):
413
+ return result.get("stream"), result.get("model")
414
+
415
+ # 方案2:从列表中找到可用模型
416
+ self.refresh_model_list()
417
+ available_free = self.get_all_free_models()
418
+
419
+ candidates = []
420
+ if model_hint and available_free:
421
+ for m in available_free:
422
+ model_name = m.replace(":free", "").split("/")[-1]
423
+ if model_hint.lower() in model_name.lower():
424
+ candidates.append(m)
425
+
426
+ if not candidates and available_free:
427
+ candidates = available_free[:10]
428
+
429
+ async with aiohttp.ClientSession() as session:
430
+ for model_id in candidates:
431
+ result = await self.try_model_direct_stream(session, model_id, api_key, messages)
432
+ if result and result.get("success"):
433
+ return result.get("stream"), result.get("model")
434
+
435
+ return None, None