# MTP3.2 / app.py
# teszenofficial's picture
# Update app.py
# 65869a7 verified
import os
import sys
import torch
import json
import time
import gc
import re
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
# ======================
# DEVICE CONFIGURATION
# ======================
# Pick CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

if DEVICE == "cpu":
    # os.cpu_count() may return None when the core count cannot be determined;
    # treat that as a single core instead of crashing on `None // 2`.
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))

# This process only runs inference, so autograd is switched off.
torch.set_grad_enabled(False)

# Change this to the name of your Hugging Face repo.
MODEL_REPO = "TeszenAI/MTP-3.2.1"
# ======================
# FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
# ======================
def truncate_greeting_response(text: str) -> str:
    """Shorten a greeting reply to its first sentence.

    The reply is cut right after the first period (``.``); exclamation and
    question marks are deliberately not treated as sentence ends. When the
    text contains no period, anything longer than 80 characters is clipped to
    80 characters followed by ``"..."`` and shorter text is returned as-is.
    Falsy input (``""`` or ``None``) is returned unchanged.
    """
    if not text:
        return text

    first_dot = text.find(".")
    if first_dot >= 0:
        # Keep everything up to and including the period.
        return text[:first_dot + 1].strip()

    # No period at all: fall back to a hard 80-character cap.
    return text[:80] + "..." if len(text) > 80 else text
def clean_response(text: str, user_input: str = "") -> str:
    """Post-process raw model output into a presentable reply.

    Steps, in order:
      1. Drop a word once it has already appeared three times in a row.
      2. Collapse any character repeated five or more times down to two.
      3. If ``user_input`` is one of the known greetings, keep only the text
         up to and including the first period; when there is no period, cap
         the reply at 60 characters.
      4. Replace replies shorter than five characters with a canned message.
      5. Collapse runs of whitespace into single spaces.

    Args:
        text: Decoded model output.
        user_input: The user's message; only used to detect greetings.

    Returns:
        The cleaned reply, a canned fallback when the result is too short, or
        ``""`` when ``text`` is falsy.
    """
    if not text:
        return ""

    # Remove excessive consecutive word repetitions (at most three in a row).
    kept_words = []
    previous = ""
    run_length = 0
    for word in text.split():
        if word == previous:
            run_length += 1
            if run_length > 2:
                continue
        else:
            previous = word
            run_length = 0
        kept_words.append(word)
    text = " ".join(kept_words)

    # Collapse long runs of a single character.
    text = re.sub(r'(.)\1{4,}', r'\1\1', text)

    greetings = ("hola", "hola!", "hola.", "buenas", "saludos", "hola?")
    is_greeting = user_input.lower().strip() in greetings

    if is_greeting and text:
        first_dot = text.find(".")
        if first_dot != -1:
            # Greetings are truncated ONLY at the first period.
            text = text[:first_dot + 1].strip()
        elif len(text) > 60:
            # No period: enforce the 60-character cap. The previous code
            # tested this in an `elif` that could never be reached, so long
            # period-less greeting replies were returned in full.
            text = text[:60]

    # Too short (or empty) to be a useful answer.
    if len(text.strip()) < 5:
        if is_greeting:
            return "¡Hola! ¿En qué puedo ayudarte?"
        return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"

    return re.sub(r'\s+', ' ', text).strip()
# ======================
# DEFINIR ARQUITECTURA DEL MODELO (MTP)
# ======================
class LayerNorm(nn.Module):
    """Hand-written layer normalisation over the last dimension.

    Unlike ``torch.nn.LayerNorm`` this divides by ``std + eps`` (epsilon is
    added to the standard deviation, not the variance) and uses ``Tensor.std``
    with its default settings. The learnable ``weight`` and ``bias`` vectors
    have ``d_model`` entries each.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis, then scale and shift."""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.weight * centered / spread + self.bias
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input ``x``.
    ``d_model`` must be divisible by ``n_heads``; each head works on
    ``d_model // n_heads`` features.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Attention scores are divided by sqrt(d_k).
        self.scale = math.sqrt(self.d_k)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq_len, d_model).

        Positions where ``mask == 0`` receive ``-inf`` scores before the
        softmax. The output has the same shape as ``x``.
        """
        n_batch, n_tokens, _ = x.shape

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return projected.view(n_batch, n_tokens, self.n_heads, self.d_k).transpose(1, 2)

        queries = split_heads(self.w_q(x))
        keys = split_heads(self.w_k(x))
        values = split_heads(self.w_v(x))

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, values)
        # (batch, heads, seq, d_k) -> (batch, seq, d_model)
        merged = context.transpose(1, 2).contiguous().view(n_batch, n_tokens, self.d_model)
        return self.w_o(merged)
class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff`` features, apply GELU and dropout, project back."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention followed by a feed-forward net.

    Each sub-layer output goes through dropout, is added back to its input,
    and the sum is then normalised (normalisation comes after the residual
    addition).
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run attention (optionally masked) and the feed-forward sub-layer."""
        attended = self.dropout1(self.attention(x, mask))
        x = self.norm1(x + attended)
        transformed = self.dropout2(self.feed_forward(x))
        return self.norm2(x + transformed)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    A table of shape (1, max_len, d_model) is precomputed and stored in the
    ``pe`` buffer: even feature indices hold sines, odd indices hold cosines.

    Args:
        d_model: Embedding width. Both even and odd values are supported.
        max_len: Number of positions in the table; inputs passed to
            ``forward`` must not be longer than this.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns; slicing keeps the assignment shapes in agreement (the
        # original code raised a shape error in that case). For even d_model
        # the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``."""
        return x + self.pe[:, :x.size(1), :]
class MTPModel(nn.Module):
    """Transformer language model used by the MTP API.

    Pipeline: token embedding (scaled by ``sqrt(d_model)``) -> sinusoidal
    positional encoding -> ``n_layers`` transformer blocks -> final layer
    norm -> linear projection to vocabulary logits.
    """

    def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
                 n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Return logits of shape (batch, seq_len, vocab_size) for token ids ``x``.

        When ``mask`` is omitted a lower-triangular (causal) mask is used, so
        each position only attends to itself and earlier positions.
        """
        if mask is None:
            seq_len = x.size(1)
            # Build the mask directly on the input's device (no host round trip).
            mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm(x)
        return self.lm_head(x)

    def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9,
                 repetition_penalty=1.1, eos_token_ids=(2, 3)):
        """Sample a continuation of ``input_ids`` one token at a time.

        Only the first sequence of the batch is sampled (``logits[0]``), so
        ``input_ids`` is expected to have shape (1, prompt_len).

        Args:
            input_ids: Integer tensor holding the prompt token ids.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Logit divisor; must be positive.
            top_k: Keep only the ``top_k`` highest logits (0 disables it).
                Values above the vocabulary size are clamped.
            top_p: Nucleus sampling threshold (1.0 disables it).
            repetition_penalty: Penalty applied to every token id already in
                the sequence (1.0 disables it).
            eos_token_ids: Token ids that stop generation; the stop token is
                not appended. Defaults to the ids the original code
                hard-coded (2 and 3).

        Returns:
            The prompt followed by the sampled tokens, shape (1, total_len).

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0")

        stop_ids = set(eos_token_ids or ())
        generated = input_ids
        for _ in range(max_new_tokens):
            # The positional table only covers max_len positions; feed the
            # most recent max_len tokens so long sequences do not crash.
            context = generated[:, -self.max_len:]
            with torch.no_grad():
                logits = self(context)
            next_logits = logits[0, -1, :] / temperature

            if repetition_penalty != 1.0:
                seen = torch.unique(generated[0])
                seen_logits = next_logits[seen]
                # Dividing a negative logit would make it MORE likely, so
                # negative logits are multiplied by the penalty instead.
                next_logits[seen] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

            if top_k > 0:
                # torch.topk raises when k exceeds the number of logits.
                k = min(top_k, next_logits.size(-1))
                kth_value = torch.topk(next_logits, k)[0][..., -1, None]
                next_logits[next_logits < kth_value] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is
                # kept, and always keep the single most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False
                next_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            if next_token in stop_ids:
                break
            generated = torch.cat(
                [generated, torch.tensor([[next_token]], device=generated.device)], dim=1
            )
        return generated
# ======================
# DESCARGA Y CARGA DEL MODELO
# ======================
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtp_repo"
)

# Architecture defaults. The keys are exactly the keyword arguments accepted
# by MTPModel.__init__, so they double as the whitelist of values forwarded
# to the constructor.
_DEFAULT_CONFIG = {
    "vocab_size": 5000,
    "d_model": 256,
    "n_heads": 8,
    "n_layers": 6,
    "d_ff": 1024,
    "dropout": 0.1,
    "max_len": 512
}

# Load the configuration; values from config.json override the defaults, and
# keys missing from the file keep their default instead of raising KeyError.
config_path = os.path.join(repo_path, "config.json")
config = dict(_DEFAULT_CONFIG)
if os.path.exists(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        config.update(json.load(f))

# Load the tokenizer
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
if not os.path.exists(tokenizer_path):
    print(f"❌ Tokenizador no encontrado en {tokenizer_path}")
    sys.exit(1)
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)
VOCAB_SIZE = sp.get_piece_size()

# The tokenizer is the source of truth for the vocabulary size.
config["vocab_size"] = VOCAB_SIZE

print("🧠 Inicializando modelo MTP...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {config['d_model']}")
print(f" → Capas: {config['n_layers']}")
print(f" → Heads: {config['n_heads']}")

# Forward only the constructor arguments: unrelated metadata keys present in
# config.json would otherwise make MTPModel(**config) raise TypeError.
model = MTPModel(**{key: config[key] for key in _DEFAULT_CONFIG})
model.to(DEVICE)

# Load the model weights
model_path = os.path.join(repo_path, "mtp_model.pt")
if os.path.exists(model_path):
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code if the checkpoint is untrusted. If the checkpoint is a
    # plain state dict, consider passing weights_only=True — TODO confirm.
    state_dict = torch.load(model_path, map_location=DEVICE)
    # strict=False tolerates key mismatches; report them so a partially (or
    # not at all) initialised model does not go unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"⚠️ Claves ausentes en el checkpoint (quedan con pesos aleatorios): {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"⚠️ Claves inesperadas en el checkpoint (ignoradas): {load_result.unexpected_keys}")
    print("✅ Pesos del modelo cargados")
else:
    print(f"⚠️ No se encontró {model_path}, usando pesos aleatorios")

model.eval()
param_count = sum(p.numel() for p in model.parameters())
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
# ======================
# API CONFIG
# ======================
# FastAPI application object; title, description and version are passed as
# the application's metadata.
app = FastAPI(title="MTP API", description="API para modelo de lenguaje MTP", version="1.0")

# CORS is left fully open: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Request body accepted by POST /generate. Every field except `text` has a
# default, and the numeric fields are range-checked by pydantic.
class PromptRequest(BaseModel):
    # User message; at most 2000 characters.
    text: str = Field(..., description="Texto de entrada", max_length=2000)
    # Upper bound on newly generated tokens (10-250).
    max_tokens: int = Field(150, description="Tokens máximos a generar", ge=10, le=250)
    # Sampling temperature (0.1-2.0).
    temperature: float = Field(0.7, description="Temperatura de muestreo", ge=0.1, le=2.0)
    # Number of highest-scoring candidates kept at each step (1-100).
    top_k: int = Field(50, description="Top-k sampling", ge=1, le=100)
    # Nucleus sampling threshold (0.1-1.0).
    top_p: float = Field(0.9, description="Top-p (nucleus) sampling", ge=0.1, le=1.0)
    # Penalty applied to already generated tokens (1.0-2.0).
    repetition_penalty: float = Field(1.1, description="Penalización por repetición", ge=1.0, le=2.0)
def build_prompt(user_input: str) -> str:
    """Wrap the user's message in the instruction/response prompt template."""
    template = "### Instrucción:\n{}\n\n### Respuesta:\n"
    return template.format(user_input)
# ======================
# LOAD MANAGEMENT
# ======================
# Module-level counter of /generate calls currently in progress: the generate
# endpoint increments it on entry and decrements it when done, and the
# /health endpoint reports its current value.
ACTIVE_REQUESTS = 0
class MTPTokenizer:
    """Thin wrapper that forwards every call to a SentencePiece processor.

    The wrapped processor is kept in ``self.sp``.
    """

    def __init__(self, sp_model):
        self.sp = sp_model

    def encode(self, text):
        """Return whatever the processor's ``encode`` yields for ``text``."""
        processor = self.sp
        return processor.encode(text)

    def decode(self, tokens):
        """Return whatever the processor's ``decode`` yields for ``tokens``."""
        processor = self.sp
        return processor.decode(tokens)

    def bos_id(self):
        """Return the processor's ``bos_id()``."""
        processor = self.sp
        return processor.bos_id()

    def eos_id(self):
        """Return the processor's ``eos_id()``."""
        processor = self.sp
        return processor.eos_id()

    def pad_id(self):
        """Return the processor's ``pad_id()``."""
        processor = self.sp
        return processor.pad_id()
tokenizer_wrapper = MTPTokenizer(sp)


@app.post("/generate")
async def generate(req: PromptRequest):
    """Main text-generation endpoint.

    Builds the prompt from ``req.text``, samples a continuation with the
    model, cleans it up and returns it.

    Returns:
        On success ``{"reply", "tokens_generated", "model"}``; for blank
        input ``{"reply": "", "tokens_generated": 0}``; on failure
        ``{"reply": <fallback text>, "error": <exception text>}``.
    """
    # NOTE(review): model.generate is a synchronous call inside an
    # `async def`, so the event loop is blocked while a reply is generated.
    global ACTIVE_REQUESTS

    user_input = req.text.strip()
    if not user_input:
        return {"reply": "", "tokens_generated": 0}

    # Detect whether the message is a plain greeting.
    is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
    # Greetings get a much smaller token budget.
    max_tokens = 30 if is_greeting else req.max_tokens

    ACTIVE_REQUESTS += 1
    try:
        # Tokenisation and tensor creation now happen inside the try block:
        # previously an exception here skipped both the fallback reply and
        # the counter decrement, leaving ACTIVE_REQUESTS permanently too high.
        full_prompt = build_prompt(user_input)
        tokens = tokenizer_wrapper.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=req.temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty
            )

        # Keep only the newly generated part and drop out-of-range ids.
        gen_tokens = output_ids[0, len(tokens):].tolist()
        safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE]
        response = tokenizer_wrapper.decode(safe_tokens).strip() if safe_tokens else ""

        response = clean_response(response, user_input)

        # If the reply is still empty or very short, use a default one.
        if len(response) < 3:
            if is_greeting:
                response = "¡Hola! ¿En qué puedo ayudarte?"
            else:
                response = "Lo siento, no pude generar una respuesta. ¿Podrías reformular tu pregunta?"

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "model": "MTP"
        }
    except Exception as e:
        # Top-level boundary: log the failure and answer with a fallback.
        print(f"❌ Error durante generación: {e}")
        if is_greeting:
            fallback = "¡Hola! ¿En qué puedo ayudarte?"
        else:
            fallback = "Lo siento, ocurrió un error al procesar tu solicitud."
        return {
            "reply": fallback,
            "error": str(e)
        }
    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
# ======================
# ENDPOINTS DE INFORMACIÓN
# ======================
# Liveness endpoint: reports a fixed "healthy" status together with the
# device in use, the number of in-flight /generate calls and the tokenizer's
# vocabulary size.
@app.get("/health")
def health_check():
    status_report = dict(
        status="healthy",
        model="MTP",
        device=DEVICE,
        active_requests=ACTIVE_REQUESTS,
        vocab_size=VOCAB_SIZE,
    )
    return status_report
# Model metadata endpoint: returns the architecture config, the total number
# of parameters (recounted on every call) and the device in use.
@app.get("/info")
def model_info():
    total_params = 0
    for tensor in model.parameters():
        total_params += tensor.numel()

    info = {"model_name": "MTP", "version": "1.0"}
    info["architecture"] = config
    info["parameters"] = total_params
    info["device"] = DEVICE
    return info
# ======================
# INTERFAZ WEB
# ======================
# Serves the single-page chat UI. The page posts the typed message to
# /generate and shows the returned `reply`.
#
# Security: message bubbles are filled in with `textContent` instead of being
# interpolated into `innerHTML`, so user input and model output are always
# rendered as plain text and cannot inject HTML or scripts into the page.
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MTP - Asistente IA</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
background: #131314;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
height: 100vh;
display: flex;
flex-direction: column;
}
.chat-header {
padding: 16px 20px;
background: #1E1F20;
border-bottom: 1px solid #2a2b2e;
}
.chat-header h1 {
color: white;
font-size: 1.2rem;
font-weight: 500;
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 20px;
display: flex;
flex-direction: column;
gap: 16px;
}
.message {
display: flex;
gap: 12px;
max-width: 80%;
}
.message.user {
align-self: flex-end;
flex-direction: row-reverse;
}
.message-content {
padding: 10px 16px;
border-radius: 18px;
font-size: 0.95rem;
line-height: 1.4;
}
.user .message-content {
background: #4a9eff;
color: white;
border-radius: 18px 4px 18px 18px;
}
.bot .message-content {
background: #1E1F20;
color: #e3e3e3;
border-radius: 4px 18px 18px 18px;
}
.chat-input-container {
padding: 16px 20px;
background: #1E1F20;
border-top: 1px solid #2a2b2e;
}
.input-wrapper {
display: flex;
gap: 12px;
max-width: 800px;
margin: 0 auto;
}
#messageInput {
flex: 1;
padding: 12px 16px;
background: #2a2b2e;
border: none;
border-radius: 24px;
color: white;
font-size: 0.95rem;
outline: none;
}
#messageInput::placeholder {
color: #888;
}
#sendBtn {
padding: 12px 24px;
background: #4a9eff;
border: none;
border-radius: 24px;
color: white;
font-weight: 500;
cursor: pointer;
transition: opacity 0.2s;
}
#sendBtn:hover { opacity: 0.9; }
#sendBtn:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.typing {
display: flex;
gap: 4px;
padding: 10px 16px;
}
.typing span {
width: 8px;
height: 8px;
background: #888;
border-radius: 50%;
animation: bounce 1.4s infinite ease-in-out;
}
.typing span:nth-child(1) { animation-delay: -0.32s; }
.typing span:nth-child(2) { animation-delay: -0.16s; }
@keyframes bounce {
0%, 80%, 100% { transform: scale(0); }
40% { transform: scale(1); }
}
</style>
</head>
<body>
<div class="chat-header">
<h1>🤖 MTP - Asistente IA</h1>
</div>
<div class="chat-messages" id="chatMessages">
<div class="message bot">
<div class="message-content">¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?</div>
</div>
</div>
<div class="chat-input-container">
<div class="input-wrapper">
<input type="text" id="messageInput" placeholder="Escribe tu mensaje..." autocomplete="off">
<button id="sendBtn">Enviar</button>
</div>
</div>
<script>
const chatMessages = document.getElementById('chatMessages');
const messageInput = document.getElementById('messageInput');
const sendBtn = document.getElementById('sendBtn');
let isLoading = false;
function addMessage(text, isUser) {
const div = document.createElement('div');
div.className = `message ${isUser ? 'user' : 'bot'}`;
const content = document.createElement('div');
content.className = 'message-content';
content.textContent = text;
div.appendChild(content);
chatMessages.appendChild(div);
chatMessages.scrollTop = chatMessages.scrollHeight;
return div;
}
function addTypingIndicator() {
const div = document.createElement('div');
div.className = 'message bot';
div.id = 'typingIndicator';
div.innerHTML = `<div class="typing"><span></span><span></span><span></span></div>`;
chatMessages.appendChild(div);
chatMessages.scrollTop = chatMessages.scrollHeight;
}
function removeTypingIndicator() {
const indicator = document.getElementById('typingIndicator');
if (indicator) indicator.remove();
}
async function sendMessage() {
const text = messageInput.value.trim();
if (!text || isLoading) return;
messageInput.value = '';
addMessage(text, true);
isLoading = true;
sendBtn.disabled = true;
addTypingIndicator();
try {
const response = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text: text })
});
const data = await response.json();
removeTypingIndicator();
addMessage(data.reply, false);
} catch (error) {
removeTypingIndicator();
addMessage('Error de conexión. Intenta de nuevo.', false);
} finally {
isLoading = false;
sendBtn.disabled = false;
messageInput.focus();
}
}
messageInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter') sendMessage();
});
sendBtn.addEventListener('click', sendMessage);
messageInput.focus();
</script>
</body>
</html>
"""
# Script entry point: start uvicorn on all interfaces. The port is read from
# the PORT environment variable and defaults to 7860.
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    startup_banner = (
        f"\n🚀 Iniciando servidor MTP en puerto {port}...",
        f"🌐 Interfaz web: http://0.0.0.0:{port}",
        f"📡 API docs: http://0.0.0.0:{port}/docs",
    )
    for banner_line in startup_banner:
        print(banner_line)
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")