import gc
import json
import math
import os
import re
import sys
import time

import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field

# ======================
# DEVICE CONFIGURATION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

if DEVICE == "cpu":
    # os.cpu_count() may return None when the count cannot be determined,
    # so fall back to 1 instead of failing on `None // 2`.
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))

# Inference-only process: disable autograd globally.
torch.set_grad_enabled(False)

# CHANGE THIS TO THE NAME OF YOUR HUGGING FACE REPO
MODEL_REPO = "TeszenAI/MTP-3.2.1"

# ======================
# CLEAN-UP AND QUALITY-CONTROL HELPERS
# ======================

# User inputs (lower-cased, stripped) that are treated as a plain greeting.
_GREETING_INPUTS = frozenset(
    {"hola", "hola!", "hola.", "buenas", "saludos", "hola?"}
)

# Maximum reply length for a greeting when the reply contains no period.
_MAX_GREETING_CHARS = 60

# Five or more repetitions of the same character.
_CHAR_RUN_RE = re.compile(r'(.)\1{4,}')
_WHITESPACE_RE = re.compile(r'\s+')


def truncate_greeting_response(text: str) -> str:
    """Cut a greeting reply right after its first period.

    Only the period (``.``) is used as a terminator; exclamation and
    question marks are ignored.

    Args:
        text: Raw reply text.

    Returns:
        The text up to and including the first period, stripped. If there
        is no period, the first 80 characters followed by ``"..."`` when
        the text is longer than 80 characters, otherwise the text itself.
        Empty input is returned unchanged.
    """
    if not text:
        return text

    head, period, _ = text.partition(".")
    if period:
        return (head + period).strip()

    if len(text) > 80:
        return text[:80] + "..."
    return text


def clean_response(text: str, user_input: str = "") -> str:
    """Clean up a raw model reply.

    Steps, in order:
      1. Drop a word once it has already appeared three times in a row.
      2. Collapse runs of five or more identical characters to two.
      3. If ``user_input`` is a known greeting, keep only the text up to
         the first period, or the first 60 characters if there is none.
      4. Replace replies shorter than five characters with a canned
         message.
      5. Normalise whitespace.

    Args:
        text: Decoded model output.
        user_input: The user's message; only used to detect greetings.

    Returns:
        The cleaned reply, a canned fallback if the result is too short,
        or ``""`` when ``text`` is empty.
    """
    if not text:
        return ""

    # Remove excessive consecutive word repetitions.
    kept_words = []
    previous = ""
    run_length = 0
    for word in text.split():
        if word == previous:
            run_length += 1
            if run_length > 2:
                continue
        else:
            previous = word
            run_length = 0
        kept_words.append(word)
    text = " ".join(kept_words)

    # Collapse long runs of a single character.
    text = _CHAR_RUN_RE.sub(r'\1\1', text)

    is_greeting = user_input.lower().strip() in _GREETING_INPUTS

    if is_greeting and text:
        head, period, _ = text.partition(".")
        if period:
            text = (head + period).strip()
        elif len(text) > _MAX_GREETING_CHARS:
            # No period to cut at: cap the length instead. (Previously this
            # branch was unreachable and long replies were returned whole.)
            text = text[:_MAX_GREETING_CHARS]

    # Reply is empty or too short to be useful.
    if len(text.strip()) < 5:
        if is_greeting:
            return "¡Hola! ¿En qué puedo ayudarte?"
        return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"

    # Collapse multiple whitespace characters.
    return _WHITESPACE_RE.sub(' ', text).strip()


# ======================
# MODEL ARCHITECTURE (MTP)
# ======================
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/bias.

    NOTE: this is intentionally *not* numerically identical to
    ``torch.nn.LayerNorm``: it uses ``Tensor.std`` (Bessel-corrected by
    default) and adds ``eps`` to the standard deviation rather than to the
    variance. It is left as-is because changing it would change the outputs
    of existing checkpoints (presumably trained with this formulation).
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension."""
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.weight * (x - mean) / (std + self.eps) + self.bias


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, d_model) to (batch, n_heads, seq, d_k)."""
        return projected.view(
            batch_size, seq_len, self.n_heads, self.d_k
        ).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq, d_model).

        Args:
            x: Input tensor with three dimensions.
            mask: Optional tensor broadcastable to the score tensor
                (batch, n_heads, seq, seq); positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.w_q(x), batch_size, seq_len)
        K = self._split_heads(self.w_k(x), batch_size, seq_len)
        V = self._split_heads(self.w_v(x), batch_size, seq_len)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = self.dropout(F.softmax(scores, dim=-1))
        attn_output = torch.matmul(attn_weights, V)

        # Merge the heads back into a single d_model-wide vector.
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        return self.w_o(attn_output)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = F.gelu(self.linear1(x))
        return self.linear2(self.dropout(hidden))
class TransformerBlock(nn.Module): def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1): super().__init__() self.attention = MultiHeadAttention(d_model, n_heads, dropout) self.feed_forward = FeedForward(d_model, d_ff, dropout) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) def forward(self, x, mask=None): attn_output = self.attention(x, mask) x = x + self.dropout1(attn_output) x = self.norm1(x) ff_output = self.feed_forward(x) x = x + self.dropout2(ff_output) x = self.norm2(x) return x class PositionalEncoding(nn.Module): def __init__(self, d_model: int, max_len: int = 5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): return x + self.pe[:, :x.size(1), :] class MTPModel(nn.Module): def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8, n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.max_len = max_len self.token_embedding = nn.Embedding(vocab_size, d_model) self.pos_encoding = PositionalEncoding(d_model, max_len) self.blocks = nn.ModuleList([ TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers) ]) self.norm = LayerNorm(d_model) self.lm_head = nn.Linear(d_model, vocab_size) def forward(self, x, mask=None): if mask is None: mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device) x = self.token_embedding(x) * math.sqrt(self.d_model) x = self.pos_encoding(x) for block in self.blocks: x = block(x, mask) x = self.norm(x) logits = self.lm_head(x) return logits 
def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1): """Genera texto token por token""" generated = input_ids for step in range(max_new_tokens): with torch.no_grad(): logits = self(generated) next_logits = logits[0, -1, :] / temperature if repetition_penalty != 1.0: for token_id in set(generated[0].tolist()): next_logits[token_id] /= repetition_penalty if top_k > 0: indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None] next_logits[indices_to_remove] = float('-inf') if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices[sorted_indices_to_remove] next_logits[indices_to_remove] = float('-inf') probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1).item() # EOS ID común para SentencePiece if next_token == 2 or next_token == 3: break generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1) return generated # ====================== # DESCARGA Y CARGA DEL MODELO # ====================== print(f"📦 Descargando modelo desde {MODEL_REPO}...") repo_path = snapshot_download( repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo" ) # Cargar configuración config_path = os.path.join(repo_path, "config.json") if os.path.exists(config_path): with open(config_path, "r") as f: config = json.load(f) else: config = { "vocab_size": 5000, "d_model": 256, "n_heads": 8, "n_layers": 6, "d_ff": 1024, "dropout": 0.1, "max_len": 512 } # Cargar tokenizador tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model") if not os.path.exists(tokenizer_path): print(f"❌ Tokenizador no encontrado en {tokenizer_path}") 
sys.exit(1) sp = spm.SentencePieceProcessor() sp.load(tokenizer_path) VOCAB_SIZE = sp.get_piece_size() # Actualizar vocab_size en config config["vocab_size"] = VOCAB_SIZE print(f"🧠 Inicializando modelo MTP...") print(f" → Vocabulario: {VOCAB_SIZE}") print(f" → Dimensión: {config['d_model']}") print(f" → Capas: {config['n_layers']}") print(f" → Heads: {config['n_heads']}") model = MTPModel(**config) model.to(DEVICE) # Cargar pesos del modelo model_path = os.path.join(repo_path, "mtp_model.pt") if os.path.exists(model_path): state_dict = torch.load(model_path, map_location=DEVICE) model.load_state_dict(state_dict, strict=False) print("✅ Pesos del modelo cargados") else: print(f"⚠️ No se encontró {model_path}, usando pesos aleatorios") model.eval() param_count = sum(p.numel() for p in model.parameters()) print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)") # ====================== # API CONFIG # ====================== app = FastAPI( title="MTP API", description="API para modelo de lenguaje MTP", version="1.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) class PromptRequest(BaseModel): text: str = Field(..., max_length=2000, description="Texto de entrada") max_tokens: int = Field(default=150, ge=10, le=250, description="Tokens máximos a generar") temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo") top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling") top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling") repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición") def build_prompt(user_input: str) -> str: """Construye el prompt en el formato del modelo""" return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n" # ====================== # GESTIÓN DE CARGA # ====================== ACTIVE_REQUESTS = 0 class MTPTokenizer: """Wrapper para el 
tokenizador de SentencePiece""" def __init__(self, sp_model): self.sp = sp_model def encode(self, text): return self.sp.encode(text) def decode(self, tokens): return self.sp.decode(tokens) def bos_id(self): return self.sp.bos_id() def eos_id(self): return self.sp.eos_id() def pad_id(self): return self.sp.pad_id() tokenizer_wrapper = MTPTokenizer(sp) @app.post("/generate") async def generate(req: PromptRequest): """Endpoint principal de generación de texto""" global ACTIVE_REQUESTS ACTIVE_REQUESTS += 1 user_input = req.text.strip() if not user_input: ACTIVE_REQUESTS -= 1 return {"reply": "", "tokens_generated": 0} # Detectar si es un saludo is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"] # Si es saludo, usar menos tokens max_tokens = 30 if is_greeting else req.max_tokens full_prompt = build_prompt(user_input) tokens = tokenizer_wrapper.encode(full_prompt) input_ids = torch.tensor([tokens], device=DEVICE) try: with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=max_tokens, temperature=req.temperature, top_k=req.top_k, top_p=req.top_p, repetition_penalty=req.repetition_penalty ) gen_tokens = output_ids[0, len(tokens):].tolist() # Filtrar tokens inválidos safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE] if safe_tokens: response = tokenizer_wrapper.decode(safe_tokens).strip() else: response = "" # Limpiar respuesta response = clean_response(response, user_input) # Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto if len(response) < 3: if is_greeting: response = "¡Hola! ¿En qué puedo ayudarte?" else: response = "Lo siento, no pude generar una respuesta. ¿Podrías reformular tu pregunta?" return { "reply": response, "tokens_generated": len(safe_tokens), "model": "MTP" } except Exception as e: print(f"❌ Error durante generación: {e}") if is_greeting: fallback = "¡Hola! ¿En qué puedo ayudarte?" else: fallback = "Lo siento, ocurrió un error al procesar tu solicitud." 
return { "reply": fallback, "error": str(e) } finally: ACTIVE_REQUESTS -= 1 if DEVICE == "cuda": torch.cuda.empty_cache() gc.collect() # ====================== # ENDPOINTS DE INFORMACIÓN # ====================== @app.get("/health") def health_check(): return { "status": "healthy", "model": "MTP", "device": DEVICE, "active_requests": ACTIVE_REQUESTS, "vocab_size": VOCAB_SIZE } @app.get("/info") def model_info(): return { "model_name": "MTP", "version": "1.0", "architecture": config, "parameters": sum(p.numel() for p in model.parameters()), "device": DEVICE } # ====================== # INTERFAZ WEB # ====================== @app.get("/", response_class=HTMLResponse) def chat_ui(): return """