| from pydantic import BaseModel |
| from typing import List, Optional, Dict, Any |
| import time |
| import re |
|
|
| |
class ChatChoice(BaseModel):
    """One completion choice in an OpenAI-style chat response."""

    index: int  # position of this choice in the response's `choices` list
    message: Dict[str, str]  # e.g. {"role": "assistant", "content": "..."}
    finish_reason: str  # why generation ended; this file always emits "stop"
|
|
|
|
class ChatUsage(BaseModel):
    """Token accounting for one chat completion (OpenAI `usage` object)."""

    prompt_tokens: int  # tokens in the request messages
    completion_tokens: int  # tokens in the generated reply
    total_tokens: int  # prompt_tokens + completion_tokens
|
|
|
|
class ChatResponse(BaseModel):
    """OpenAI-compatible `/v1/chat/completions` response envelope."""

    id: str  # response identifier, e.g. "chatcmpl-<timestamp>"
    object: str  # always "chat.completion" as produced by this file
    created: int  # Unix timestamp (seconds)
    model: str  # echoed from the request
    choices: List[ChatChoice]  # generated alternatives (this file emits one)
    usage: ChatUsage  # token accounting
|
|
|
|
def convert_json_format(input_data):
    """Convert HF pipeline chat output into a generations-style payload.

    Args:
        input_data: list of pipeline results; each item is a dict whose
            'generated_text' value is a list of {"role", "content"} messages.

    Returns:
        ``{"generations": [[{"text": ..., "generationInfo": {...}}], ...]}``
        with one inner single-element list per input item.
    """
    output_generations = []
    for item in input_data:
        generated_text_list = item.get('generated_text', [])

        # Take the first assistant turn; fall back to "" if none present.
        assistant_content = ""
        for message in generated_text_list:
            if message.get('role') == 'assistant':
                assistant_content = message.get('content', '')
                break

        # Strip reasoning-model "<think>...</think>" blocks (plus trailing
        # whitespace) so only the final answer is kept; DOTALL lets the
        # block span newlines.
        # BUG FIX: the previous pattern was r'\s*', which matched every
        # whitespace run and therefore deleted ALL whitespace from the
        # answer text ("hello world" -> "helloworld").
        clean_content = re.sub(
            r'<think>.*?</think>\s*', '', assistant_content, flags=re.DOTALL
        ).strip()

        output_generations.append([
            {
                "text": clean_content,
                "generationInfo": {
                    "finish_reason": "stop"
                }
            }
        ])

    return {"generations": output_generations}
|
|
|
|
def create_chat_response(request: Any, pipe=None, tokenizer=None) -> ChatResponse:
    """Build an OpenAI-compatible chat completion response.

    When `pipe` is available, runs it on `request.messages` (honoring
    `request.max_tokens` when set) and extracts the first generation's
    text; otherwise replies with a fixed "model is initializing" notice.
    Token usage is counted exactly with `tokenizer` when given, else
    estimated with a chars // 4 heuristic.

    NOTE(review): assumes `request.messages` is a list of dicts with a
    "content" key — confirm against the request model.
    """
    if pipe is None:
        # Model not loaded yet — placeholder answer, no generation attempted.
        reply_text = "模型正在初始化中,请稍后重试..."
    else:
        # None means "no explicit limit"; the pipeline applies its default.
        token_limit = request.max_tokens if request.max_tokens is not None else None
        raw_output = pipe(request.messages, max_new_tokens=token_limit)
        reply_text = convert_json_format(raw_output)["generations"][0][0]["text"]

    assistant_message = {"role": "assistant", "content": reply_text}

    # Usage accounting: exact token counts when a tokenizer is available,
    # otherwise a rough length-based estimate.
    if tokenizer:
        used_prompt = sum(
            len(tokenizer.encode(m.get("content", ""))) for m in request.messages
        )
        used_completion = len(tokenizer.encode(reply_text))
    else:
        used_prompt = sum(len(m.get("content", "")) for m in request.messages) // 4
        used_completion = len(reply_text) // 4

    return ChatResponse(
        id=f"chatcmpl-{int(time.time())}",
        object="chat.completion",
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatChoice(
                index=0,
                message=assistant_message,
                finish_reason="stop",
            )
        ],
        usage=ChatUsage(
            prompt_tokens=used_prompt,
            completion_tokens=used_completion,
            total_tokens=used_prompt + used_completion,
        ),
    )
|
|