import os
import sys
import torch
import json
import time
import gc
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm

# ======================
# DEVICE CONFIGURATION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ NVIDIA GPU detected. Using CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ No GPU detected. Using CPU (may be slower).")

if DEVICE == "cpu":
    torch.set_num_threads(max(1, os.cpu_count() // 2))

torch.set_grad_enabled(False)

MODEL_REPO = "TeszenAI/MTP-2"

# ======================
# MODEL ARCHITECTURE (MTP-1.1)
# ======================
class LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.weight * (x - mean) / (std + self.eps) + self.bias


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # Project and split into heads: (batch, n_heads, seq_len, d_k)
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention with optional (causal) mask
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Post-norm residual blocks: self-attention, then feed-forward
        attn_output = self.attention(x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        ff_output = self.feed_forward(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Fixed sinusoidal positional encodings, registered as a non-trainable buffer
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]
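
# Illustrative sketch only (not called anywhere by this server): a minimal forward
# pass through one TransformerBlock to show the tensor shapes it expects. The batch
# size, sequence length, and dimensions below are arbitrary assumptions for the demo.
def _demo_transformer_block_shapes():
    block = TransformerBlock(d_model=128, n_heads=4, d_ff=512)
    block.eval()
    x = torch.randn(1, 16, 128)  # (batch, seq_len, d_model)
    causal_mask = torch.tril(torch.ones(16, 16)).unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)
    with torch.no_grad():
        out = block(x, causal_mask)
    return out.shape  # torch.Size([1, 16, 128]) — same shape as the input
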
class MTPModel(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
                 n_layers: int = 4, d_ff: int = 512, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        if mask is None:
            # Causal mask: each position may only attend to itself and earlier positions
            mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

    def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50,
                 top_p=0.9, repetition_penalty=1.1):
        """Generation method compatible with the serving interface."""
        generated = input_ids
        for _ in range(max_new_tokens):
            # Get logits for the last position
            with torch.no_grad():
                logits = self(generated)
            next_logits = logits[0, -1, :] / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for token_id in set(generated[0].tolist()):
                    next_logits[token_id] /= repetition_penalty

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
                next_logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_logits[indices_to_remove] = float('-inf')

            # Sampling
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Stop at EOS
            if next_token == 3:  # EOS id for SentencePiece
                break

            generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
        return generated


# ======================
# MODEL DOWNLOAD AND LOADING
# ======================
print(f"📦 Downloading model from {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtp_repo"
)

# Load configuration
config_path = os.path.join(repo_path, "config.json")
if os.path.exists(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)
else:
    config = {
        "vocab_size": 5000,
        "d_model": 128,
        "n_heads": 4,
        "n_layers": 4,
        "d_ff": 512,
        "dropout": 0.1,
        "max_len": 256
    }

# Load tokenizer
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)
VOCAB_SIZE = sp.get_piece_size()

# Update vocab_size in the config from the tokenizer
config["vocab_size"] = VOCAB_SIZE

print("🧠 Initializing MTP-1.1 model...")
print(f"   → Vocabulary: {VOCAB_SIZE}")
print(f"   → Dimension: {config['d_model']}")
print(f"   → Layers: {config['n_layers']}")
print(f"   → Heads: {config['n_heads']}")
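
# Optional sanity check of the tokenizer (illustrative sketch only; not called by the
# server, and the default sample text is an assumption, not part of the original pipeline):
def _demo_tokenizer_roundtrip(sample_text: str = "hola mundo"):
    ids = sp.encode(sample_text)   # list[int] of SentencePiece ids
    return sp.decode(ids)          # decodes back to (approximately) the original text
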
model = MTPModel(**config)
model.to(DEVICE)

# Load model weights
model_path = os.path.join(repo_path, "mtp_model.pt")
if os.path.exists(model_path):
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    print("✅ Model weights loaded")
else:
    print("⚠️ mtp_model.pt not found, using random weights")

model.eval()

# Dynamic quantization for CPU
if DEVICE == "cpu":
    print("⚡ Applying dynamic quantization for CPU...")
    model = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )

param_count = sum(p.numel() for p in model.parameters())
print(f"✅ Model loaded: {param_count:,} parameters ({param_count/1e6:.1f}M)")

# ======================
# API CONFIG
# ======================
app = FastAPI(
    title="MTP-2 API",
    description="API for the MTP-1.1 language model",
    version="1.1"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class PromptRequest(BaseModel):
    text: str = Field(..., max_length=2000, description="Input text")
    max_tokens: int = Field(default=150, ge=10, le=300, description="Maximum tokens to generate")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Sampling temperature")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Repetition penalty")


def build_prompt(user_input: str) -> str:
    """Builds the prompt in the format the model was trained on."""
    return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"


# ======================
# LOAD MANAGEMENT
# ======================
ACTIVE_REQUESTS = 0


class MTPTokenizer:
    """Wrapper around the SentencePiece tokenizer."""

    def __init__(self, sp_model):
        self.sp = sp_model

    def encode(self, text):
        return self.sp.encode(text)

    def decode(self, tokens):
        return self.sp.decode(tokens)

    def bos_id(self):
        return self.sp.bos_id()

    def eos_id(self):
        return self.sp.eos_id()


tokenizer_wrapper = MTPTokenizer(sp)
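
# Example client call for the endpoint defined below (illustrative sketch; the host,
# port, and prompt text are assumptions — adjust them to your deployment):
#
#   import requests
#   r = requests.post(
#       "http://localhost:8000/generate",
#       json={"text": "Hola, ¿cómo estás?", "max_tokens": 100, "temperature": 0.7},
#   )
#   print(r.json()["reply"])
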
@app.post("/generate")
async def generate(req: PromptRequest):
    """Main text-generation endpoint."""
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1

    # Under high concurrency, reduce generation length and temperature
    dyn_max_tokens = req.max_tokens
    dyn_temperature = req.temperature
    if ACTIVE_REQUESTS > 2:
        print(f"⚠️ High load ({ACTIVE_REQUESTS} requests). Adjusting parameters.")
        dyn_max_tokens = min(dyn_max_tokens, 120)
        dyn_temperature = max(0.5, dyn_temperature * 0.9)

    user_input = req.text.strip()
    if not user_input:
        ACTIVE_REQUESTS -= 1
        return {"reply": "", "tokens_generated": 0}

    full_prompt = build_prompt(user_input)
    tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
    input_ids = torch.tensor([tokens], device=DEVICE)

    try:
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=dyn_max_tokens,
                temperature=dyn_temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty
            )

        # Keep only the newly generated tokens, dropping EOS and out-of-range ids
        gen_tokens = output_ids[0, len(tokens):].tolist()
        safe_tokens = [
            t for t in gen_tokens
            if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
        ]
        response = tokenizer_wrapper.decode(safe_tokens).strip()

        # Cut the reply at the next section marker, if the model emitted one
        if "###" in response:
            response = response.split("###")[0].strip()

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "model": "MTP-1.1"
        }
    except Exception as e:
        print(f"❌ Error during generation: {e}")
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e)
        }
    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()


# ======================
# INFORMATION ENDPOINTS
# ======================
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "model": "MTP-1.1",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "vocab_size": VOCAB_SIZE
    }


@app.get("/info")
def model_info():
    return {
        "model_name": "MTP-1.1",
        "version": "1.1",
        "architecture": config,
        "parameters": sum(p.numel() for p in model.parameters()),
        "device": DEVICE
    }


# ======================
# WEB INTERFACE (MODERN UI FROM MTP-3)
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    return """