import os
#os.system("pip install faker duckduckgo_search")
import copy
import types
import gc
import sys
import re
import time
import collections
import asyncio
import random
from typing import List, Optional, Union, Any, Dict


# --- ENVIRONMENT CONFIGURATION ---
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub
    patch_hub()


os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"

# --- IMPORTS ---
from config import CONFIG, ModelConfig
from utils import (
    cleanMessages,
    parse_think_response,
    remove_nested_think_tags_stack,
    format_bytes,
    log,
)
from huggingface_hub import hf_hub_download
from loguru import logger
from snowflake import SnowflakeGenerator
import numpy as np
import torch
import requests

# Optional dependencies
try:
    from duckduckgo_search import DDGS
    HAS_DDG = True
except ImportError:
    HAS_DDG = False

try:
    from faker import Faker
    fake = Faker()
    HAS_FAKER = True
except ImportError:
    HAS_FAKER = False

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel, Field, model_validator

# --- INITIAL SETUP ---
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

# Fall back to CPU when the configured strategy requests CUDA but no GPU is available.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    CONFIG.STRATEGY = "cpu fp16"

if "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex
    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ["RWKV_CUDA_ON"] = "1" if CONFIG.RWKV_CUDA_ON else "0"
else:
    os.environ["RWKV_CUDA_ON"] = "0"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from api_types import (
    ChatMessage, ChatCompletion, ChatCompletionChunk, Usage,
    ChatCompletionChoice, ChatCompletionMessage
)

# --- MODEL STORAGE ---
class ModelStorage:
    MODEL_CONFIG: Optional[ModelConfig] = None
    model: Optional[RWKV] = None
    pipeline: Optional[PIPELINE] = None

MODEL_STORAGE: Dict[str, ModelStorage] = {}
DEFAULT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None

for model_config in CONFIG.MODELS:
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    if model_config.DEFAULT_CHAT: DEFAULT_MODEL_NAME = model_config.SERVICE_NAME
    if model_config.DEFAULT_REASONING: DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=CONFIG.STRATEGY,
    )
    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
    )
    if "cuda" in CONFIG.STRATEGY.lower():
        torch.cuda.empty_cache()
    gc.collect()
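# Illustrative sketch only: a CONFIG.MODELS entry roughly matching the fields this
# loader reads (SERVICE_NAME, MODEL_FILE_PATH or the DOWNLOAD_* fields, VOCAB,
# DEFAULT_CHAT/DEFAULT_REASONING). Repo id and file names below are hypothetical.
#
#   ModelConfig(
#       SERVICE_NAME="rwkv7-example",
#       DOWNLOAD_MODEL_REPO_ID="BlinkDL/some-rwkv-repo",    # hypothetical repo id
#       DOWNLOAD_MODEL_FILE_NAME="rwkv7-example.pth",        # hypothetical file name
#       DOWNLOAD_MODEL_DIR="./models",
#       VOCAB="rwkv_vocab_v20230424",
#       DEFAULT_CHAT=True,
#   )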

#
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="rwkv-latest")
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        if not isinstance(data, dict): return data
        if "messages" in data and "prompt" in data and data["messages"] and data["prompt"]:
            raise ValueError("messages and prompt cannot coexist.")
        return data
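# A minimal sketch of a request body this schema accepts (values are illustrative,
# not server defaults):
#
#   {
#     "model": "rwkv-latest:thinking",
#     "messages": [{"role": "user", "content": "Hola"}],
#     "stream": true,
#     "max_tokens": 512
#   }
#
# Sending both "messages" and "prompt" in the same body is rejected by
# validate_mutual_exclusivity above.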

#
class TruthAndFlowProtocol:
    """
    Manages factual coherence and avoids robotic repetition.
    """

    SYSTEM_INSTRUCTION = """
PROTOCOL: FACTUAL_AND_CONCISE
1. TRUTH: Say ONLY what is verified in the context or internal knowledge.
2. NO REPETITION: Do not repeat facts. Do not repeat sentence structures.
3. CONCISENESS: Get to the point directly.
4. LABELS: Use [VERIFICADO] for confirmed data, [INCIERTO] for contradictions.
5. NO FILLER: Avoid "As an AI", "I think", "Basically".
""".strip()

    @staticmethod
    def optimize_params(request: ChatCompletionRequest):
        """
        Fine-tunes sampling to avoid loops without losing factuality.
        """
        # Low temperature (0.15), but not zero.
        # At 0.0 the model is almost guaranteed to loop; 0.15 leaves just enough
        # room to vary word choice.
        request.temperature = 0.15

        # Strict top-p (0.1): only the most plausible tokens are kept.
        request.top_p = 0.1

        # Frequency penalty (1.2):
        # HIGH penalty for reusing the exact same token many times.
        # Avoids "y y y y" or "es es es".
        request.count_penalty = 1.2

        # Presence penalty (0.7):
        # MEDIUM penalty for repeating the same concept,
        # i.e. restating the same thing in other words right away.
        request.presence_penalty = 0.7

        # Penalty decay (0.996):
        # "Forgives" token usage after a while, so frequent function words
        # ("el", "de", "que") can be reused without getting blocked.
        request.penalty_decay = 0.996
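        # Rough arithmetic behind these numbers (a sketch, based on the decay loop
        # in generate() below): each generated token multiplies every stored penalty
        # by 0.996, so a penalty takes ln(2)/ln(1/0.996) ~= 173 tokens to fall to
        # half strength. A token already counted 3 times is pushed down by
        # presence + 3 * frequency = 0.7 + 3 * 1.2 = 4.3 logits at that step.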

    @staticmethod
    def search_verify(query: str) -> str:
        """Web search and corroboration."""
        if not HAS_DDG: return ""
        try:
            # Regular search
            ddgs = DDGS()
            results = ddgs.text(query, max_results=3)

            # Additional fact-check search when the query itself questions veracity
            is_suspicious = any(w in query.lower() for w in ["verdad", "fake", "bulo", "cierto"])
            if is_suspicious:
                check_res = ddgs.text(f"{query} fact check", max_results=2)
                if check_res: results.extend(check_res)

            if not results: return ""

            context = "VERIFIED CONTEXT (Use strict labels [VERIFICADO]/[INCIERTO]):\n"
            for r in results:
                context += f"- {r['body']} (Source: {r['title']})\n"

            return context
        except Exception:
            return ""
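# Illustrative shape of the context block search_verify() returns (snippet bodies
# and titles here are invented placeholders):
#
#   VERIFIED CONTEXT (Use strict labels [VERIFICADO]/[INCIERTO]):
#   - <snippet body 1> (Source: <page title 1>)
#   - <snippet body 2> (Source: <page title 2>)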

#
app = FastAPI(title="RWKV High-Fidelity Server")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)

@app.middleware("http")
async def privacy_middleware(request: Request, call_next):
    if HAS_FAKER:
        request.scope["client"] = (fake.ipv4(), request.client.port if request.client else 80)
    return await call_next(request)
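# Note: when Faker is available, the middleware above overwrites the client address
# with a randomly generated IPv4 (the original port is kept), e.g.
#   request.scope["client"] == ("203.0.113.87", <original port>)   # random address each time
# so route handlers and access logs never observe the caller's real IP.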

#
search_cache = collections.OrderedDict()

def get_context(query: str) -> str:
    if query in search_cache: return search_cache[query]
    ctx = TruthAndFlowProtocol.search_verify(query)
    if len(search_cache) > 50: search_cache.popitem(last=False)
    search_cache[query] = ctx
    return ctx

def needs_search(msg: str, model: str) -> bool:
    if ":online" in model: return True
    return any(k in msg.lower() for k in ["quien", "cuando", "donde", "precio", "es verdad", "dato"])
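# Examples of the trigger logic (derived from the keyword list above):
#   needs_search("quien gano el mundial?", "rwkv-latest")    -> True   (keyword "quien")
#   needs_search("hola, como estas?", "rwkv-latest")         -> False
#   needs_search("hola, como estas?", "rwkv-latest:online")  -> True   (":online" forces search)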

#
async def runPrefill(request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state):
    ctx = ctx.replace("\r\n", "\n")
    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
    model_tokens.extend([int(x) for x in tokens])
    while len(tokens) > 0:
        out, model_state = MODEL_STORAGE[request.model].model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
        tokens = tokens[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)
    return out, model_tokens, model_state
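# Prefill feeds the prompt in CONFIG.CHUNK_LEN-sized slices and yields to the event
# loop between slices. As a sketch: if CONFIG.CHUNK_LEN were 256, a 600-token prompt
# would be forwarded as chunks of 256, 256 and 88 tokens, with the recurrent state
# carried across chunks.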

def generate(request: ChatCompletionRequest, out, model_tokens: List[int], model_state, max_tokens=2048):
    # Map the request penalties onto PIPELINE_ARGS.
    # Note: alpha_frequency corresponds to count_penalty (OpenAI-style frequency penalty).
    args = PIPELINE_ARGS(
        temperature=request.temperature,
        top_p=request.top_p,
        alpha_frequency=request.count_penalty,   # penalty for exact token repetition
        alpha_presence=request.presence_penalty, # penalty for a token having appeared at all
        token_ban=[],
        token_stop=[0]
    )

    occurrence = {}
    out_tokens = []
    out_last = 0
    cache_word_list = []

    for i in range(max_tokens):
        # Apply presence/frequency penalties directly to the logits vector 'out'
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency

        token = MODEL_STORAGE[request.model].pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)

        if token == 0:
            yield {"content": "".join(cache_word_list), "finish_reason": "stop", "state": model_state}
            del out; gc.collect(); return

        out, model_state = MODEL_STORAGE[request.model].model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        # Decay: the repetition memory fades slowly over time
        for xxx in occurrence: occurrence[xxx] *= request.penalty_decay
        occurrence[token] = 1 + (occurrence.get(token, 0))

        tmp = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp: continue
        cache_word_list.append(tmp)
        out_last = i + 1

        if len(cache_word_list) > 1:
            yield {"content": cache_word_list.pop(0), "finish_reason": None}

    yield {"content": "".join(cache_word_list), "finish_reason": "length"}
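# Each yielded chunk is a plain dict, e.g. (content values illustrative):
#   {"content": "Hola", "finish_reason": None}         # intermediate piece
#   {"content": " mundo", "finish_reason": "stop"}      # end token (0) was sampled
# The "\ufffd" check holds back partially decoded multi-byte UTF-8 sequences:
# text is appended only once the pending tokens decode to valid characters.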

#
async def chatResponseStream(request: ChatCompletionRequest, model_state: Any, completionId: str, enableReasoning: bool):
    clean_msg = cleanMessages(request.messages, enableReasoning)
    prompt = f"{clean_msg}\n\nAssistant:{' <think' if enableReasoning else ''}"

    out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)

    yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(role='Assistant', content=''), finish_reason=None)]).model_dump_json()}\n\n"

    for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 4096):
        content = chunk["content"]
        if content:
            yield f"data: {ChatCompletionChunk(id=completionId, created=int(time.time()), model=request.model, choices=[ChatCompletionChoice(index=0, delta=ChatCompletionMessage(content=content), finish_reason=None)]).model_dump_json()}\n\n"
        if chunk.get("finish_reason"): break
        await asyncio.sleep(0)

    yield "data: [DONE]\n\n"
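# The response is an OpenAI-style SSE stream: one "data: <ChatCompletionChunk JSON>"
# line per event, separated by blank lines, closed by a literal "data: [DONE]"
# sentinel. Sketch (chunk JSON abbreviated to the delta field):
#   data: {"choices": [{"delta": {"content": "Hola"}}]}
#
#   data: [DONE]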

@app.post("/v1/chat/completions")
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completionId = str(next(CompletionIdGenerator))

    raw_model = request.model
    model_key = request.model.split(":")[0]  # split already drops ":thinking"/":online" suffixes
    is_reasoning = ":thinking" in request.model

    target_model = model_key
    if "rwkv-latest" in model_key:
        if is_reasoning and DEFAULT_REASONING_MODEL_NAME: target_model = DEFAULT_REASONING_MODEL_NAME
        elif DEFAULT_MODEL_NAME: target_model = DEFAULT_MODEL_NAME

    if target_model not in MODEL_STORAGE: raise HTTPException(404, "Model not found")
    request.model = target_model

    default_sampler = MODEL_STORAGE[target_model].MODEL_CONFIG.DEFAULT_SAMPLER
    req_data = request.model_dump()
    for k, v in default_sampler.model_dump().items():
        if req_data.get(k) is None: req_data[k] = v
    realRequest = ChatCompletionRequest(**req_data)

    # 1. Anti-repetition system prompt
    sys_msg = ChatMessage(role="System", content=TruthAndFlowProtocol.SYSTEM_INSTRUCTION)
    if realRequest.messages:
        if realRequest.messages[0].role == "System":
            realRequest.messages[0].content = f"{TruthAndFlowProtocol.SYSTEM_INSTRUCTION}\n\n{realRequest.messages[0].content}"
        else:
            realRequest.messages.insert(0, sys_msg)

        # 2. Context injection (when applicable)
        last_msg = realRequest.messages[-1]
        if last_msg.role.lower() == "user" and needs_search(last_msg.content, raw_model):
            ctx = get_context(last_msg.content)
            if ctx: realRequest.messages.insert(-1, ChatMessage(role="System", content=ctx))

    # 3. Parameter fine-tuning (the anti-repetition core)
    TruthAndFlowProtocol.optimize_params(realRequest)

    logger.info(f"[REQ] {completionId} | Params: T={realRequest.temperature} Freq={realRequest.count_penalty} Pres={realRequest.presence_penalty}")

    return StreamingResponse(chatResponseStream(realRequest, None, completionId, is_reasoning), media_type="text/event-stream")
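# Illustrative call (host/port depend on CONFIG.HOST/CONFIG.PORT; shown as localhost:8000):
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "rwkv-latest:online", "messages": [{"role": "user", "content": "quien gano?"}], "stream": true}'
# The ":online" suffix forces a web search before generation; ":thinking" routes to the
# default reasoning model when one is configured.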

@app.get("/api/v1/models")
@app.get("/v1/models")
async def list_models():
    return {"object": "list", "data": [{"id": "rwkv-latest", "object": "model"}]}

app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)