| import os |
| import re |
| import gc |
| import sys |
| import time |
| import json |
| import queue |
| import random |
| import asyncio |
| import threading |
| import requests |
| import collections |
| import torch |
| import numpy as np |
| from typing import List, Optional, Dict, Any, Literal, Union |
| from pydantic import BaseModel, Field, model_validator |
| from pydantic_settings import BaseSettings |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.responses import StreamingResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.middleware.gzip import GZipMiddleware |
| from huggingface_hub import hf_hub_download |
| from snowflake import SnowflakeGenerator |
|
|
# On a ModelScope "studio" deployment, call modelscope's patch_hub() before
# anything is downloaded (presumably this redirects hub traffic -- confirm).
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub

    patch_hub()


# Process-wide switches exported through the environment. RWKV_CUDA_ON is
# assigned again later in this file once the final strategy is known.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:64",
        "RWKV_V7_ON": "1",
        "RWKV_JIT_ON": "1",
        "RWKV_CUDA_ON": "1",
    }
)


# Held by the streaming handler around prefill/generation so that only one
# request drives the model at a time.
GPU_LOCK = asyncio.Lock()
|
|
class ChatMessage(BaseModel):
    """A single conversation turn (request input or a turn appended by the server)."""

    # Speaker label; rendered as the line prefix when the prompt is built.
    role: str
    # Text of the turn.
    content: str
    # Optional name; used in the "Tool Output (...)" prefix for tool turns.
    name: Optional[str] = None
    # Optional identifier of the tool call this message answers.
    tool_call_id: Optional[str] = None
|
|
class Logprob(BaseModel):
    """Log-probability record for one token."""

    token: str
    logprob: float
    # Optional list of alternative-token records (free-form dicts).
    top_logprobs: Optional[List[Dict[str, Any]]] = Field(default=None)
|
|
class LogprobsContent(BaseModel):
    """Container for per-token log-probabilities attached to a choice."""

    # Both lists are optional and default to None.
    content: Optional[List[Logprob]] = Field(default=None)
    refusal: Optional[List[Logprob]] = Field(default=None)
|
|
class ChatCompletionMessage(BaseModel):
    """Message payload of a choice; used both as a full message and as a streaming delta.

    Every field is optional so that a delta can carry only what changed.
    """

    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
|
|
class PromptTokensDetails(BaseModel):
    """Breakdown of prompt token usage."""

    # Required count of prompt tokens reported as cached.
    cached_tokens: int
|
|
class Usage(BaseModel):
    """Token accounting block of a completion response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Optional finer-grained prompt statistics.
    prompt_tokens_details: Optional[PromptTokensDetails] = Field(default=None)
|
|
class ChatCompletionChoice(BaseModel):
    """One choice of a completion; carries either `message` or a streaming `delta`."""

    index: int
    message: Optional[ChatCompletionMessage] = Field(default=None)
    delta: Optional[ChatCompletionMessage] = Field(default=None)
    logprobs: Optional[LogprobsContent] = Field(default=None)
    # Must always be supplied by the caller, but None is an accepted value
    # (the streaming code passes None for intermediate chunks).
    finish_reason: Optional[str] = Field(...)
|
|
class ChatCompletionChunk(BaseModel):
    """Envelope serialised into each `data:` line of the event stream."""

    # Completion identifier; required.
    id: str
    # Constant discriminator for this envelope type.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Creation time; the streaming code passes int(time.time()).
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[Usage] = Field(default=None)
|
|
class ToolFunction(BaseModel):
    """Function description inside a client-declared tool; all fields are required."""

    name: str
    description: str
    # Free-form parameter description supplied by the client.
    parameters: Dict[str, Any]
|
|
class Tool(BaseModel):
    """A client-declared tool; only the "function" type is accepted."""

    type: Literal["function"] = Field(default="function")
    function: ToolFunction
|
|
def remove_nested_think_tags_stack(text):
    """Strip every ``<think>...</think>`` region (nesting included) from *text*.

    Rules, evaluated left to right:
      * text inside at least one open ``<think>`` is dropped, tags included;
      * a ``</think>`` with no matching opener is kept verbatim;
      * an unclosed ``<think>`` swallows everything up to the end of the text.

    Args:
        text: The string to clean. Empty or ``None`` yields ``""``.

    Returns:
        The text with all thinking regions removed.
    """
    if not text:
        return ""

    depth = 0  # number of currently open <think> tags
    kept = []
    # The capturing group makes re.split return the tags themselves as
    # separate items, so the input is tokenised in a single pass instead of
    # being sliced at every character position.
    for piece in re.split(r"(</?think>)", text):
        if piece == "<think>":
            depth += 1
        elif piece == "</think>":
            if depth:
                depth -= 1
            else:
                # Stray closing tag: not part of any thinking region.
                kept.append(piece)
        elif not depth:
            kept.append(piece)
    return "".join(kept)
|
|
def cleanMessages(messages: List[ChatMessage], removeThinkingContent: bool = False):
    """Render a message list as the plain-text transcript fed to the model.

    Each message becomes ``"<Label>: <content>"`` and the entries are joined
    with a blank line. Content is stripped and runs of newlines are collapsed
    to a single newline (presumably so that a blank line only ever appears
    between turns -- confirm).

    Args:
        messages: Conversation turns; an empty or missing list yields ``""``.
        removeThinkingContent: When true, ``<think>`` regions are removed from
            assistant turns.

    Returns:
        The transcript string.
    """
    if not messages:
        return ""

    rendered = []
    for msg in messages:
        body = re.sub(r"\n+", "\n", msg.content.strip())
        label = msg.role.strip().lower().capitalize()

        if removeThinkingContent and label == "Assistant":
            body = remove_nested_think_tags_stack(body)

        # Only the exact lowercase "tool" role receives the special prefix.
        # For every other role (including "system", "user" and "assistant")
        # the normalised role name is itself the label.
        if msg.role == "tool":
            label = f"Tool Output ({msg.name})"

        rendered.append(f"{label}: {body}")

    return "\n\n".join(rendered)
|
|
class SamplerConfig(BaseModel):
    """Per-model sampling defaults; copied into a request for any field it leaves unset."""

    max_tokens: int = 4096
    temperature: float = 1.0
    top_p: float = 0.3
    presence_penalty: float = 0.5
    count_penalty: float = 0.5
    penalty_decay: float = 0.996
    # Text sequences that end generation.
    stop: List[str] = Field(default_factory=lambda: ["\n\n"])
    # Token ids that end generation.
    stop_tokens: List[int] = Field(default_factory=lambda: [0])
|
|
class ModelConfig(BaseModel):
    """Static description of one servable model."""

    # Name under which the model is exposed through the API.
    SERVICE_NAME: str
    # File name and repository passed to hf_hub_download when no local path is set.
    DOWNLOAD_MODEL_FILE_NAME: str
    DOWNLOAD_MODEL_REPO_ID: str
    # Local directory the download is written to.
    DOWNLOAD_MODEL_DIR: str = "models"
    # Explicit weights path; when None the file is downloaded at start-up.
    MODEL_FILE_PATH: Optional[str] = None
    # Marks the model served for the "rwkv-latest" alias (chat / thinking).
    DEFAULT_CHAT: bool = False
    DEFAULT_REASONING: bool = False
    REASONING: bool = False
    # Vocabulary name handed to the RWKV pipeline.
    VOCAB: str = "rwkv_vocab_v20230424"
    # Context window used when pruning conversation history.
    CTX_LEN: int = 4096
    DEFAULT_SAMPLER: SamplerConfig = Field(default_factory=SamplerConfig)
|
|
_RWKV7_G1_REPO_ID = "BlinkDL/rwkv7-g1"


def _rwkv7_g1_model(service_name, ctx_len, **overrides):
    """Build the ModelConfig of a reasoning-capable rwkv7-g1 checkpoint.

    The checkpoint file is always named ``<service_name>.pth``.
    """
    return ModelConfig(
        SERVICE_NAME=service_name,
        DOWNLOAD_MODEL_FILE_NAME=f"{service_name}.pth",
        DOWNLOAD_MODEL_REPO_ID=_RWKV7_G1_REPO_ID,
        REASONING=True,
        CTX_LEN=ctx_len,
        **overrides,
    )


class Config(BaseSettings):
    """Server-wide settings (pydantic-settings model)."""

    HOST: str = "0.0.0.0"
    PORT: int = 7860
    # RWKV strategy string; downgraded at start-up when CUDA is unavailable.
    STRATEGY: str = "cuda fp16"
    RWKV_CUDA_ON: bool = True
    # Number of tokens fed to the model per prefill step.
    CHUNK_LEN: int = 256
    # Models loaded at start-up; the last one backs the "rwkv-latest" aliases.
    MODELS: List[ModelConfig] = [
        _rwkv7_g1_model("rwkv7-g1a4-2.9b-20251118-ctx8192", 8192),
        _rwkv7_g1_model("rwkv7-g1a3-1.5b-20251015-ctx8192", 8192),
        _rwkv7_g1_model("rwkv7-g1a-0.4b-20250905-ctx4096", 4096),
        _rwkv7_g1_model(
            "rwkv7-g1a-0.1b-20250728-ctx4096",
            4096,
            DEFAULT_CHAT=True,
            DEFAULT_REASONING=True,
        ),
    ]
|
|
CONFIG = Config()


# Optional dependency: DuckDuckGo client used by the search tool fallback.
try:
    from duckduckgo_search import DDGS
except ImportError:
    HAS_DDG = False
else:
    HAS_DDG = True


# Optional dependency: Faker, used by the middleware that replaces the
# client address on incoming requests.
try:
    from faker import Faker
    fake = Faker()
except ImportError:
    HAS_FAKER = False
else:
    HAS_FAKER = True


# Source of completion ids (one id is drawn per request).
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)


# A CUDA strategy was requested but no CUDA device is usable: fall back to CPU.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    CONFIG.STRATEGY, CONFIG.RWKV_CUDA_ON = "cpu fp16", False


# RWKV_CUDA_ON must hold its final value before rwkv.model is imported below.
if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import *
    nvmlInit()
    os.environ["RWKV_CUDA_ON"] = "1"
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    os.environ["RWKV_CUDA_ON"] = "0"


# Imported only after the environment switches above have been settled.
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
|
class ModelStorage:
    """Holder for everything belonging to one loaded model.

    Instances start with all three slots empty (the class-level ``None``
    defaults) and are filled in by the start-up loading loop.
    """

    # Static configuration of the model.
    MODEL_CONFIG: Optional[ModelConfig] = None
    # Loaded RWKV network.
    model: Optional[RWKV] = None
    # Pipeline wrapping the network together with its vocabulary.
    pipeline: Optional[PIPELINE] = None
|
|
# Registry of loaded models, keyed by ModelConfig.SERVICE_NAME.
MODEL_STORAGE: Dict[str, ModelStorage] = {}
# Service names backing the "rwkv-latest" / "rwkv-latest:thinking" aliases.
# (The spelling "DEFALUT" is kept because other code refers to this name.)
DEFALUT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None


# Download (when needed) and load every configured model at import time.
for model_config in CONFIG.MODELS:
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    if model_config.DEFAULT_CHAT:
        DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
    if model_config.DEFAULT_REASONING:
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    # RWKV is given the weights path without its extension. Only a trailing
    # ".pth" is removed; a ".pth" elsewhere in the path (for example in a
    # directory name) must be left untouched.
    weights_path = model_config.MODEL_FILE_PATH
    if weights_path.endswith(".pth"):
        weights_path = weights_path[: -len(".pth")]

    storage = ModelStorage()
    storage.MODEL_CONFIG = model_config
    storage.model = RWKV(model=weights_path, strategy=CONFIG.STRATEGY)
    storage.pipeline = PIPELINE(storage.model, model_config.VOCAB)
    MODEL_STORAGE[model_config.SERVICE_NAME] = storage

    # Same case-insensitive strategy test as the CUDA fallback earlier in
    # this file, so a strategy such as "CUDA fp16" is treated consistently.
    if "cuda" in CONFIG.STRATEGY.lower():
        torch.cuda.empty_cache()
        gc.collect()
|
|
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the completion endpoints.

    Sampling fields left as ``None`` are later filled from the target model's
    default sampler. ``messages`` and ``prompt`` are mutually exclusive.
    """

    model: str = "rwkv-latest"
    messages: Optional[List[ChatMessage]] = None
    prompt: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    count_penalty: Optional[float] = None
    penalty_decay: Optional[float] = None
    stream: Optional[bool] = False
    stop: Optional[List[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[List[int]] = Field(default=[0])
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = "auto"

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Reject raw payloads that carry both a non-empty `messages` and `prompt`."""
        # Non-dict input is passed through untouched.
        if isinstance(data, dict) and data.get("messages") and data.get("prompt"):
            raise ValueError("messages and prompt cannot coexist.")
        return data
|
|
class ToolEngine:
    """Server-side tools the model can trigger with ``<call>name("arg")</call>``.

    All methods are static and return a plain string that is fed back to the
    model as the tool result; failures are reported as text, never raised.
    """

    # Instructions injected as a system message when tools are enabled.
    TOOL_SYSTEM_PROMPT = """
CAPABILITY: You have access to real-time tools.
INSTRUCTION: To use a tool, output exactly: <call>tool_name("argument")</call>
Do not describe the tool, just call it. After the System provides the result, synthesize the answer.

AVAILABLE TOOLS:
1. google_search(query): Searches Google and DuckDuckGo for real-time information.
2. visit_page(url): Accesses a specific link, reads the text, and finds sub-links.
""".strip()

    @staticmethod
    def google_search_request(query: str) -> str:
        """Scrape the Google result page for *query*.

        Returns up to five "title - Link" rows. Falls back to
        :meth:`duckduckgo_fallback` when the request fails, is rejected, or
        yields no usable (title, link) pair.
        """
        try:
            headers = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"}
            resp = requests.get(
                "https://www.google.com/search",
                params={"q": query, "gl": "us", "hl": "en"},
                headers=headers,
                timeout=6,
            )
            if resp.status_code != 200:
                raise RuntimeError("Google blocked request")

            clean_text = re.sub(r'<script.*?>.*?</script>', '', resp.text, flags=re.DOTALL)
            clean_text = re.sub(r'<style.*?>.*?</style>', '', clean_text, flags=re.DOTALL)

            headings = re.findall(r'<h3.*?>(.*?)</h3>', clean_text)
            links = re.findall(r'<a href="/url\?q=(.*?)&', clean_text)

            limit = min(len(headings), len(links), 5)
            if limit == 0:
                # Nothing that can be shown (no headings, or headings without
                # links): an empty "Google Results:" block would be useless.
                return ToolEngine.duckduckgo_fallback(query)

            rows = []
            for i in range(limit):
                title = re.sub(r'<.*?>', '', headings[i])
                rows.append(f"{i+1}. {title} - Link: {links[i]}\n")
            return "Google Results:\n" + "".join(rows)
        except Exception:
            # Best effort: any failure on the Google path switches provider.
            # (Exception, not a bare except, so KeyboardInterrupt/SystemExit
            # still propagate.)
            return ToolEngine.duckduckgo_fallback(query)

    @staticmethod
    def duckduckgo_fallback(query: str) -> str:
        """Search DuckDuckGo for *query*.

        Uses the ``duckduckgo_search`` client when it is installed, otherwise
        scrapes the HTML endpoint. Returns ``"Search failed: ..."`` on error.
        """
        try:
            if HAS_DDG:
                res = DDGS().text(query, max_results=5)
                return "\n".join([f"- {r['title']}: {r['body']} ({r['href']})" for r in res])

            resp = requests.get(
                "https://html.duckduckgo.com/html/",
                params={"q": query},
                headers={"User-Agent": "Mozilla/5.0"},
                timeout=5,
            )
            titles = re.findall(r'<a class="result__a"[^>]*>(.*?)</a>', resp.text)
            snippets = re.findall(r'<a class="result__snippet"[^>]*>(.*?)</a>', resp.text)

            limit = min(len(titles), len(snippets), 4)
            out = "DuckDuckGo HTML Results:\n"
            for i in range(limit):
                t = re.sub(r'<.*?>', '', titles[i]).strip()
                s = re.sub(r'<.*?>', '', snippets[i]).strip()
                out += f"{i+1}. {t}: {s}\n"
            return out
        except Exception as e:
            return f"Search failed: {e}"

    @staticmethod
    def visit_page(url: str) -> str:
        """Fetch *url* and return a text preview plus up to five absolute links.

        The preview is the page with head/script/style/comments/tags removed,
        whitespace collapsed, truncated to 3000 characters.
        """
        # SECURITY NOTE(review): the URL comes from model output, i.e. it is
        # attacker-influenced, and is fetched from inside the server's network
        # without any host filtering (SSRF). Consider an allow-list or a
        # private-address check before exposing this publicly.
        try:
            headers = {"User-Agent": "Mozilla/5.0 (compatible; RWKV-Bot/1.0)"}
            resp = requests.get(url, headers=headers, timeout=8)
            resp.encoding = resp.apparent_encoding

            text = re.sub(r'<head.*?>.*?</head>', '', resp.text, flags=re.DOTALL)
            text = re.sub(r'<script.*?>.*?</script>', '', text, flags=re.DOTALL)
            text = re.sub(r'<style.*?>.*?</style>', '', text, flags=re.DOTALL)
            text = re.sub(r'<!--.*?-->', '', text, flags=re.DOTALL)
            text = re.sub(r'<[^>]+>', ' ', text)
            text = re.sub(r'\s+', ' ', text).strip()

            links = re.findall(r'href=["\'](http[s]?://[^"\']+)["\']', resp.text)
            # dict.fromkeys de-duplicates while keeping document order, so the
            # same page always yields the same five links (a set would not).
            unique_links = list(dict.fromkeys(links))[:5]

            content_preview = text[:3000] + ("..." if len(text) > 3000 else "")

            return f"PAGE CONTENT ({url}):\n{content_preview}\n\nFOUND SUB-LINKS:\n" + "\n".join(unique_links)
        except Exception as e:
            return f"Error visiting page: {e}"

    @staticmethod
    def execute(call_str: str) -> str:
        """Parse ``name("argument")`` and dispatch to the matching tool.

        Args:
            call_str: Text found between ``<call>`` and ``</call>``;
                surrounding whitespace is ignored.

        Returns:
            The tool output, or a short diagnostic string when the syntax is
            invalid, the tool is unknown, or the tool raised.
        """
        try:
            match = re.match(r'(\w+)\(["\'](.*?)["\']\)', call_str.strip())
            if not match:
                return "Invalid tool call syntax."

            func, arg = match.groups()

            if func == "google_search":
                return ToolEngine.google_search_request(arg)
            if func == "visit_page":
                return ToolEngine.visit_page(arg)
            return f"Unknown tool: {func}"
        except Exception as e:
            return f"Tool execution error: {e}"
|
|
app = FastAPI(title="RWKV Ultimate Agent Server")


# Fully permissive CORS policy.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very open configuration -- confirm this is intended for the deployment.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Compress responses of at least 1000 bytes at compression level 5.
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
|
|
@app.middleware("http")
async def privacy_middleware(request: Request, call_next):
    """Replace the client address recorded on the request with a random IPv4.

    Only active when Faker is installed. The original port is kept, or 80 is
    used when the request carries no client information.
    """
    if HAS_FAKER:
        port = request.client.port if request.client else 80
        request.scope["client"] = (fake.ipv4(), port)
    response = await call_next(request)
    return response
|
|
def prune_context(messages: List[ChatMessage], model_name: str, max_gen_tokens: int):
    """Drop the oldest non-system messages until the prompt fits the context window.

    The conversation fits when ``prompt_tokens + max_gen_tokens`` is strictly
    below the model's ``CTX_LEN``. System messages are always kept and end up
    at the front of the returned list; at least one other message is kept
    even if the budget is still exceeded.

    Args:
        messages: Full conversation.
        model_name: Key into ``MODEL_STORAGE``.
        max_gen_tokens: Number of tokens reserved for the reply.

    Returns:
        ``messages`` itself when it already fits, otherwise a new list.

    NOTE(review): the endpoint fills ``max_tokens`` from the model's default
    sampler (4096) before calling this, and some models have ``CTX_LEN`` 4096.
    In that case the reservation alone reaches the limit and every
    conversation is trimmed to the minimum -- confirm whether the reservation
    should be capped.
    """
    storage = MODEL_STORAGE[model_name]
    limit = storage.MODEL_CONFIG.CTX_LEN
    pipeline = storage.pipeline

    def fits(candidate):
        prompt_tokens = len(pipeline.encode(cleanMessages(candidate)))
        return prompt_tokens + max_gen_tokens < limit

    if fits(messages):
        return messages

    # Clients send the role as lowercase "system" while the server-injected
    # tool prompt uses "System"; both must survive pruning, so compare
    # case-insensitively.
    system_msgs = []
    other_msgs = []
    for m in messages:
        is_system = m.role.strip().lower() == "system"
        (system_msgs if is_system else other_msgs).append(m)

    # Advance a start index instead of popping from the front of the list.
    start = 0
    while len(other_msgs) - start > 1 and not fits(system_msgs + other_msgs[start:]):
        start += 1

    return system_msgs + other_msgs[start:]
|
|
async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
    """Feed *ctx* through the model in chunks and return the resulting logits/state.

    Args:
        request: Only ``request.model`` is read, to select the model.
        ctx: Prompt text; ``\\r\\n`` is normalised to ``\\n`` before encoding.
        model_tokens: Token history; the encoded prompt is appended in place.
        model_state: State to continue from (``None`` to start fresh).

    Returns:
        ``(out, model_tokens, model_state)`` where ``out`` is the output of
        the last forward call.

    Raises:
        ValueError: If *ctx* encodes to zero tokens. No forward pass can run
            in that case, so there are no logits to return (previously this
            surfaced as an UnboundLocalError on ``out``).
    """
    storage = MODEL_STORAGE[request.model]
    ctx = ctx.replace("\r\n", "\n")
    tokens = storage.pipeline.encode(ctx)
    if len(tokens) == 0:
        raise ValueError("runPrefill needs a non-empty context: nothing to feed to the model.")

    model_tokens.extend(int(x) for x in tokens)

    out = None
    chunk_len = CONFIG.CHUNK_LEN
    for start in range(0, len(tokens), chunk_len):
        out, model_state = storage.model.forward(tokens[start : start + chunk_len], model_state)
        # Give other coroutines a chance to run between chunks.
        await asyncio.sleep(0)
    return out, model_tokens, model_state
|
|
def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model_state, max_tokens=2048):
    """Sample up to *max_tokens* tokens and yield the decoded text piece by piece.

    Every yielded item is a dict with ``"content"`` and ``"finish_reason"``;
    terminal items produced inside the loop also carry ``"state"``.
    ``finish_reason`` is:

      * ``None``         -- an intermediate text piece;
      * ``"stop"``       -- a stop token was sampled or a stop sequence appeared;
      * ``"tool_start"`` -- the model began a ``<call>`` block (content is the
        text before it);
      * ``"length"``     -- *max_tokens* was exhausted.

    Args:
        request: Supplies the model name and the sampling parameters.
        out: Logits from the prefill; penalties are applied to it in place.
        model_tokens: Token history; sampled tokens are appended in place.
        model_state: Model state after the prefill.
        max_tokens: Upper bound on the number of sampled tokens.
    """
    storage = MODEL_STORAGE[request.model]
    pipeline = storage.pipeline
    model = storage.model

    # Token 0 always ends generation; request.stop_tokens (previously
    # ignored) adds to it.
    stop_tokens = {0, *(request.stop_tokens or [])}

    args = PIPELINE_ARGS(
        temperature=request.temperature,
        top_p=request.top_p,
        alpha_frequency=request.count_penalty,
        alpha_presence=request.presence_penalty,
        token_ban=[],
        token_stop=sorted(stop_tokens),
    )
    occurrence = {}
    out_tokens = []
    out_last = 0
    cache_word_list = []

    # Work on a private list: request.stop must not be mutated, because this
    # function runs once per agent step on the same request. "<call>" is
    # handled separately below, and an empty stop string would match
    # everywhere and make str.split("") raise.
    stop_sequences = [s for s in (request.stop or []) if s and s != "<call>"]

    for i in range(max_tokens):
        # Presence / frequency penalty on every token seen so far.
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)

        if token in stop_tokens:
            yield {"content": "".join(cache_word_list), "finish_reason": "stop", "state": model_state}
            del out
            gc.collect()
            return

        out, model_state = model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        # Older occurrences count for less over time.
        for seen in occurrence:
            occurrence[seen] *= request.penalty_decay
        occurrence[token] = 1 + occurrence.get(token, 0)

        # Decode everything since the last clean boundary; U+FFFD means the
        # bytes of a character are still incomplete, so wait for more tokens.
        tmp = pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp:
            continue
        cache_word_list.append(tmp)
        out_last = i + 1

        current_buffer = "".join(cache_word_list)

        if "<call>" in current_buffer:
            pre_call = current_buffer.split("<call>")[0]
            yield {"content": pre_call, "finish_reason": "tool_start", "state": model_state}
            del out
            gc.collect()
            return

        for s in stop_sequences:
            if s in current_buffer:
                final_content = current_buffer.split(s)[0]
                yield {"content": final_content, "finish_reason": "stop", "state": model_state}
                del out
                gc.collect()
                return

        # Keep the two most recent pieces back so that a stop sequence or
        # "<call>" spanning several pieces is caught before it is emitted.
        if len(cache_word_list) > 2:
            yield {"content": cache_word_list.pop(0), "finish_reason": None}

    yield {"content": "".join(cache_word_list), "finish_reason": "length"}
|
|
async def chatResponseStream(request: ChatCompletionRequest, model_state: Any, completionId: str, enableReasoning: bool):
    """Stream a completion as server-sent events, running up to four agent steps.

    Each step renders the whole conversation into a prompt, generates text and
    streams it. If the model opens a ``<call>`` block, the call is completed,
    executed through ``ToolEngine``, appended to the conversation (which
    mutates ``request.messages``) and the next step starts. The stream always
    ends with ``data: [DONE]``.

    Args:
        request: Resolved request (model name and sampling fields filled in).
        model_state: Initial model state for the prompt, normally ``None``.
        completionId: Identifier placed in every chunk.
        enableReasoning: Start the reply with ``<think`` and strip thinking
            content from earlier assistant turns.

    Yields:
        ``data: ...`` lines, each terminated by a blank line.
    """
    import copy  # local import: only needed to snapshot the initial state

    def sse_chunk(delta, finish_reason=None):
        payload = ChatCompletionChunk(
            id=completionId,
            created=int(time.time()),
            model=request.model,
            choices=[ChatCompletionChoice(index=0, delta=delta, finish_reason=finish_reason)],
        )
        return f"data: {payload.model_dump_json()}\n\n"

    storage = MODEL_STORAGE[request.model]
    current_messages = request.messages

    for step in range(4):
        clean_msg = cleanMessages(current_messages, enableReasoning)
        prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"

        tool_call_mode = False
        generated_parts = []  # text streamed during this step

        async with GPU_LOCK:
            # The prompt is the complete conversation, so every step must
            # start from the initial state, not from the state left behind by
            # the previous step (that would feed the history twice). A copy is
            # used because forward() may update the state in place.
            out, model_tokens, step_state = await runPrefill(
                request, prompt, [0], copy.deepcopy(model_state)
            )

            if step == 0:
                yield sse_chunk(ChatCompletionMessage(role='Assistant', content=''))

            for chunk in generate(request, out, model_tokens, step_state, max_tokens=request.max_tokens or 4096):
                content = chunk.get("content", "")
                finish = chunk.get("finish_reason", None)

                if content:
                    generated_parts.append(content)
                    yield sse_chunk(ChatCompletionMessage(content=content))

                if finish == "tool_start":
                    tool_call_mode = True
                    break

                if finish:
                    yield sse_chunk(ChatCompletionMessage(content=''), finish)
                    # Terminate the stream explicitly on normal completion too.
                    yield "data: [DONE]\n\n"
                    return

        if not tool_call_mode:
            break

        full_tool_call = ""
        async with GPU_LOCK:
            # generate() does not hand back the logits that follow "<call>",
            # and an empty prefill cannot produce any, so the context is
            # rebuilt from text: prompt + what was streamed + the opening tag.
            tool_prompt = prompt + "".join(generated_parts) + "<call>"
            tool_out, _, tool_state = await runPrefill(
                request, tool_prompt, [0], copy.deepcopy(model_state)
            )

            tool_tokens = []
            for _ in range(200):
                tool_token = storage.pipeline.sample_logits(tool_out, temperature=0.1, top_p=0.1)
                if tool_token == 0:
                    break
                tool_tokens.append(tool_token)
                # Decode the whole sequence: decoding token by token would
                # break characters whose bytes span several tokens.
                current_gen = storage.pipeline.decode(tool_tokens)
                if "</call>" in current_gen:
                    full_tool_call = current_gen.split("</call>")[0]
                    break
                tool_out, tool_state = storage.model.forward([tool_token], tool_state)

        if not full_tool_call:
            break

        # The tools perform blocking HTTP requests; run them in a worker
        # thread so the event loop keeps serving other connections.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, ToolEngine.execute, full_tool_call)
        current_messages.append(ChatMessage(role="assistant", content=f"<call>{full_tool_call}</call>"))
        current_messages.append(ChatMessage(role="tool", content=result, name="system"))

    yield "data: [DONE]\n\n"
|
|
@app.post("/v1/chat/completions")
@app.post("/v1/chat/")
@app.post("/v1/completions")
@app.post("/v1/responses")
@app.post("/responses")
@app.post("/api/generate")
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Resolve the model, fill sampling defaults and stream the completion.

    The model name may carry ``:thinking`` (reasoning mode) and ``:online``
    (tools) suffixes; ``rwkv-latest`` is an alias for the configured default
    models. The response is always a ``text/event-stream``.

    Raises:
        HTTPException: 404 when the resolved model is not loaded.
    """
    completionId = str(next(CompletionIdGenerator))
    raw_model = request.model
    # Everything after the first ":" is a suffix.
    model_key = raw_model.split(":")[0]
    is_reasoning = ":thinking" in raw_model

    target_model = model_key
    if "rwkv-latest" in model_key:
        if is_reasoning and DEFAULT_REASONING_MODEL_NAME:
            target_model = DEFAULT_REASONING_MODEL_NAME
        elif DEFALUT_MODEL_NAME:
            target_model = DEFALUT_MODEL_NAME

    if target_model not in MODEL_STORAGE:
        raise HTTPException(404, f"Model {target_model} not loaded.")
    request.model = target_model

    # Any sampling field the client left unset takes the model's default.
    default_sampler = MODEL_STORAGE[target_model].MODEL_CONFIG.DEFAULT_SAMPLER
    req_data = request.model_dump()
    for k, v in default_sampler.model_dump().items():
        if req_data.get(k) is None:
            req_data[k] = v
    realRequest = ChatCompletionRequest(**req_data)

    # `not messages` also covers an explicit empty list, so a prompt sent
    # together with `"messages": []` is no longer silently discarded.
    if not realRequest.messages:
        if realRequest.prompt:
            realRequest.messages = [ChatMessage(role="user", content=realRequest.prompt)]
        else:
            realRequest.messages = []

    enable_tools = ":online" in raw_model or realRequest.tools is not None

    if enable_tools:
        first = realRequest.messages[0] if realRequest.messages else None
        # Clients send "system" in lowercase; matching case-insensitively
        # merges the tool instructions into their system message instead of
        # inserting a second one in front of it.
        if first is not None and first.role.strip().lower() == "system":
            first.content += f"\n\n{ToolEngine.TOOL_SYSTEM_PROMPT}"
        else:
            realRequest.messages.insert(
                0, ChatMessage(role="System", content=ToolEngine.TOOL_SYSTEM_PROMPT)
            )

    realRequest.messages = prune_context(realRequest.messages, target_model, realRequest.max_tokens or 1024)

    return StreamingResponse(
        chatResponseStream(realRequest, None, completionId, is_reasoning),
        media_type="text/event-stream",
    )
|
|
@app.get("/api/v1/models")
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List every loaded model plus its ``:online`` variant and the ``rwkv-latest`` aliases."""
    created = int(time.time())

    def entry(model_id, owner):
        return {"id": model_id, "object": "model", "created": created, "owned_by": owner}

    data = []
    for name in MODEL_STORAGE:
        data.append(entry(name, "rwkv-server"))
        data.append(entry(f"{name}:online", "rwkv-server"))

    # Aliases are advertised only when a default model is configured for them.
    if DEFALUT_MODEL_NAME:
        data.extend(
            entry(alias, "rwkv-system")
            for alias in ("rwkv-latest", "rwkv-latest:online")
        )
    if DEFAULT_REASONING_MODEL_NAME:
        data.extend(
            entry(alias, "rwkv-system")
            for alias in ("rwkv-latest:thinking", "rwkv-latest:thinking:online")
        )

    return {"object": "list", "data": data}
|
|
# Serve the bundled front-end from "dist-frontend" at the site root. This
# mount is registered after all API routes above.
_frontend_files = StaticFiles(directory="dist-frontend", html=True)
app.mount("/", _frontend_files, name="static")


# Script entry point: start the ASGI server on the configured host and port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)