Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
MTP 4 API - ASISTENTE AVANZADO
|
| 4 |
-
- Modelo
|
| 5 |
-
- Temperatura 0.4
|
| 6 |
-
- Sistema anti-alucinaciones
|
| 7 |
-
- Parada inteligente avanzada
|
| 8 |
"""
|
| 9 |
|
| 10 |
import os
|
|
@@ -42,73 +41,55 @@ else:
|
|
| 42 |
|
| 43 |
torch.set_grad_enabled(False)
|
| 44 |
|
| 45 |
-
# CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
|
| 46 |
MODEL_REPO = "TeszenAI/MTP-4" # Cambia por tu repo
|
| 47 |
|
| 48 |
# ======================
|
| 49 |
-
# SISTEMA ANTI-ALUCINACIONES
|
| 50 |
# ======================
|
| 51 |
class AntiHallucination:
|
| 52 |
-
"""Sistema para prevenir alucinaciones y respuestas incoherentes"""
|
| 53 |
-
|
| 54 |
def __init__(self):
|
| 55 |
self.uncertainty_words = [
|
| 56 |
'no se', 'no lo se', 'no tengo idea', 'no estoy seguro',
|
| 57 |
'no puedo responder', 'no sé', 'desconozco'
|
| 58 |
]
|
| 59 |
-
|
| 60 |
self.empty_patterns = [
|
| 61 |
r'^[.,!?;:]+$', r'^[\s]+$', r'^[0-9]+$', r'^[a-zA-Z]{1,3}$',
|
| 62 |
]
|
| 63 |
-
|
| 64 |
self.repetition_patterns = [
|
| 65 |
r'(\b\w+\b)(?:\s+\1){5,}', r'(.)\1{10,}',
|
| 66 |
]
|
| 67 |
-
|
| 68 |
-
self.max_safe_tokens = 120
|
| 69 |
self.max_safe_chars = 500
|
| 70 |
|
| 71 |
def is_hallucinating(self, text: str) -> Tuple[bool, str]:
|
| 72 |
if not text:
|
| 73 |
return True, "Respuesta vacía"
|
| 74 |
-
|
| 75 |
-
text_lower = text.lower().strip()
|
| 76 |
-
|
| 77 |
if len(text) < 5:
|
| 78 |
return True, "Respuesta demasiado corta"
|
| 79 |
-
|
| 80 |
for pattern in self.empty_patterns:
|
| 81 |
if re.match(pattern, text):
|
| 82 |
return True, "Patrón vacío detectado"
|
| 83 |
-
|
| 84 |
for pattern in self.repetition_patterns:
|
| 85 |
if re.search(pattern, text):
|
| 86 |
return True, "Repetición excesiva"
|
| 87 |
-
|
| 88 |
-
words = text_lower.split()[:5]
|
| 89 |
for uw in self.uncertainty_words:
|
| 90 |
if uw in ' '.join(words):
|
| 91 |
return True, f"Expresa incertidumbre: '{uw}'"
|
| 92 |
-
|
| 93 |
if len(text) > self.max_safe_chars:
|
| 94 |
return True, "Respuesta demasiado larga"
|
| 95 |
-
|
| 96 |
return False, "OK"
|
| 97 |
|
| 98 |
def is_coherent(self, text: str, question: str) -> Tuple[bool, str]:
|
| 99 |
if not text or not question:
|
| 100 |
return True, "Sin datos suficientes"
|
| 101 |
-
|
| 102 |
text_lower = text.lower()
|
| 103 |
question_lower = question.lower()
|
| 104 |
question_words = set(re.findall(r'\b[a-záéíóúüñ]{3,}\b', question_lower))
|
| 105 |
-
|
| 106 |
if question_words:
|
| 107 |
matches = sum(1 for w in question_words if w in text_lower)
|
| 108 |
ratio = matches / len(question_words)
|
| 109 |
if len(question_words) >= 2 and ratio < 0.2:
|
| 110 |
-
return False, f"No responde a la pregunta
|
| 111 |
-
|
| 112 |
return True, "OK"
|
| 113 |
|
| 114 |
# ======================
|
|
@@ -121,15 +102,8 @@ class CompletionState(Enum):
|
|
| 121 |
|
| 122 |
class IntelligentStopper:
|
| 123 |
def __init__(self):
|
| 124 |
-
self.completion_patterns = [
|
| 125 |
-
|
| 126 |
-
]
|
| 127 |
-
|
| 128 |
-
self.continuation_patterns = [
|
| 129 |
-
r'[,;:]\s*$', r' y $', r' o $', r' pero $', r' porque $',
|
| 130 |
-
r' además $', r' también $', r' como $',
|
| 131 |
-
]
|
| 132 |
-
|
| 133 |
self.completion_phrases = [
|
| 134 |
'gracias', 'saludos', 'adios', 'hasta luego',
|
| 135 |
'espero haberte ayudado', 'cualquier otra pregunta',
|
|
@@ -139,33 +113,26 @@ class IntelligentStopper:
|
|
| 139 |
def analyze(self, text: str, min_length: int = 40) -> Tuple[CompletionState, str]:
|
| 140 |
if not text or len(text) < min_length:
|
| 141 |
return CompletionState.INCOMPLETE, "Demasiado corto"
|
| 142 |
-
|
| 143 |
text = text.strip()
|
| 144 |
-
|
| 145 |
for pattern in self.continuation_patterns:
|
| 146 |
if re.search(pattern, text, re.IGNORECASE):
|
| 147 |
return CompletionState.INCOMPLETE, "Indica continuación"
|
| 148 |
-
|
| 149 |
text_lower = text.lower()
|
| 150 |
for phrase in self.completion_phrases:
|
| 151 |
if phrase in text_lower[-80:]:
|
| 152 |
return CompletionState.COMPLETE, "Frase de finalización"
|
| 153 |
-
|
| 154 |
for pattern in self.completion_patterns:
|
| 155 |
if re.search(pattern, text):
|
| 156 |
if len(text) > min_length:
|
| 157 |
return CompletionState.COMPLETE, "Termina naturalmente"
|
| 158 |
-
|
| 159 |
if len(text) > 350:
|
| 160 |
return CompletionState.COMPLETE, "Longitud suficiente"
|
| 161 |
-
|
| 162 |
return CompletionState.INCOMPLETE, "Puede continuar"
|
| 163 |
|
| 164 |
# ======================
|
| 165 |
-
# ARQUITECTURA MTP 4 (
|
| 166 |
# ======================
|
| 167 |
class LayerNorm(nn.Module):
|
| 168 |
-
__slots__ = ('weight', 'bias', 'eps')
|
| 169 |
def __init__(self, d_model, eps=1e-5):
|
| 170 |
super().__init__()
|
| 171 |
self.weight = nn.Parameter(torch.ones(d_model))
|
|
@@ -175,10 +142,10 @@ class LayerNorm(nn.Module):
|
|
| 175 |
return self.weight * (x - x.mean(-1, keepdim=True)) / (x.std(-1, keepdim=True) + self.eps) + self.bias
|
| 176 |
|
| 177 |
class MultiHeadAttention(nn.Module):
|
| 178 |
-
__slots__ = ('n_heads', 'd_k', 'w_q', 'w_k', 'w_v', 'w_o', 'dropout', 'scale')
|
| 179 |
def __init__(self, d_model, n_heads, dropout=0.2):
|
| 180 |
super().__init__()
|
| 181 |
assert d_model % n_heads == 0
|
|
|
|
| 182 |
self.n_heads = n_heads
|
| 183 |
self.d_k = d_model // n_heads
|
| 184 |
self.w_q = nn.Linear(d_model, d_model)
|
|
@@ -200,7 +167,6 @@ class MultiHeadAttention(nn.Module):
|
|
| 200 |
return self.w_o(out)
|
| 201 |
|
| 202 |
class FeedForward(nn.Module):
|
| 203 |
-
__slots__ = ('linear1', 'linear2', 'dropout')
|
| 204 |
def __init__(self, d_model, d_ff, dropout=0.2):
|
| 205 |
super().__init__()
|
| 206 |
self.linear1 = nn.Linear(d_model, d_ff)
|
|
@@ -210,7 +176,6 @@ class FeedForward(nn.Module):
|
|
| 210 |
return self.linear2(self.dropout(F.gelu(self.linear1(x))))
|
| 211 |
|
| 212 |
class TransformerBlock(nn.Module):
|
| 213 |
-
__slots__ = ('attn', 'ff', 'norm1', 'norm2', 'dropout1', 'dropout2')
|
| 214 |
def __init__(self, d_model, n_heads, d_ff, dropout=0.2):
|
| 215 |
super().__init__()
|
| 216 |
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
|
|
@@ -248,6 +213,11 @@ class MTP4Model(nn.Module):
|
|
| 248 |
self.norm = LayerNorm(d_model)
|
| 249 |
self.lm_head = nn.Linear(d_model, vocab_size)
|
| 250 |
self.dropout = nn.Dropout(dropout)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
def forward(self, x):
|
| 252 |
seq_len = x.size(1)
|
| 253 |
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0).to(x.device)
|
|
@@ -261,7 +231,6 @@ class MTP4Model(nn.Module):
|
|
| 261 |
@torch.no_grad()
|
| 262 |
def generate(self, input_ids, max_new=120, temperature=0.4, top_k=30, top_p=0.85,
|
| 263 |
repetition_penalty=1.3, stopper=None):
|
| 264 |
-
"""Generación optimizada para MTP 4"""
|
| 265 |
generated = input_ids
|
| 266 |
eos_id = 3
|
| 267 |
last_tokens = []
|
|
@@ -363,6 +332,12 @@ config_path = os.path.join(repo_path, "config.json")
|
|
| 363 |
with open(config_path, "r") as f:
|
| 364 |
config = json.load(f)
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
|
| 367 |
sp = spm.SentencePieceProcessor()
|
| 368 |
sp.load(tokenizer_path)
|
|
@@ -371,16 +346,16 @@ config["vocab_size"] = VOCAB_SIZE
|
|
| 371 |
|
| 372 |
print(f"🧠 Inicializando MTP 4...")
|
| 373 |
print(f" → Vocabulario: {VOCAB_SIZE}")
|
| 374 |
-
print(f" → Dimensiones: {config.get('d_model', 384)}")
|
| 375 |
-
print(f" → Capas: {config.get('n_layers', 6)}")
|
| 376 |
print(f" → Dispositivo: {DEVICE.upper()}")
|
| 377 |
|
|
|
|
| 378 |
model = MTP4Model(**config)
|
| 379 |
model.to(DEVICE)
|
| 380 |
|
| 381 |
model_path = os.path.join(repo_path, "mtp_model.pt")
|
| 382 |
if os.path.exists(model_path):
|
| 383 |
state_dict = torch.load(model_path, map_location=DEVICE)
|
|
|
|
| 384 |
model.load_state_dict(state_dict, strict=False)
|
| 385 |
print("✅ Pesos del modelo cargados")
|
| 386 |
|
|
@@ -399,8 +374,6 @@ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], all
|
|
| 399 |
|
| 400 |
class PromptRequest(BaseModel):
|
| 401 |
text: str = Field(..., max_length=2000)
|
| 402 |
-
max_tokens: int = Field(default=120, ge=20, le=200)
|
| 403 |
-
temperature: float = Field(default=0.4, ge=0.3, le=1.0)
|
| 404 |
|
| 405 |
def build_prompt(user_input: str) -> str:
|
| 406 |
return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
|
|
@@ -430,8 +403,8 @@ async def generate(req: PromptRequest):
|
|
| 430 |
|
| 431 |
output_ids = model.generate(
|
| 432 |
input_ids,
|
| 433 |
-
max_new=
|
| 434 |
-
temperature=
|
| 435 |
top_k=30,
|
| 436 |
top_p=0.85,
|
| 437 |
repetition_penalty=1.3,
|
|
@@ -445,6 +418,7 @@ async def generate(req: PromptRequest):
|
|
| 445 |
|
| 446 |
response = sp.decode(safe_tokens).strip() if safe_tokens else ""
|
| 447 |
|
|
|
|
| 448 |
is_hallucinating, reason = anti_hallucination.is_hallucinating(response)
|
| 449 |
if is_hallucinating:
|
| 450 |
print(f"⚠️ Alucinación detectada: {reason}")
|
|
@@ -455,6 +429,7 @@ async def generate(req: PromptRequest):
|
|
| 455 |
if is_hallucinating:
|
| 456 |
response = ""
|
| 457 |
|
|
|
|
| 458 |
is_coherent, _ = anti_hallucination.is_coherent(response, user_input)
|
| 459 |
if not is_coherent and len(response) > 20:
|
| 460 |
first_sentence = response.split('.')[0] if '.' in response else response[:100]
|
|
@@ -475,6 +450,8 @@ async def generate(req: PromptRequest):
|
|
| 475 |
|
| 476 |
except Exception as e:
|
| 477 |
print(f"Error: {e}")
|
|
|
|
|
|
|
| 478 |
return {"reply": "Lo siento, ocurrió un error."}
|
| 479 |
|
| 480 |
finally:
|
|
@@ -495,13 +472,11 @@ def info():
|
|
| 495 |
"parameters": param_count,
|
| 496 |
"parameters_millions": round(param_count / 1e6, 2),
|
| 497 |
"device": DEVICE,
|
| 498 |
-
"vocab_size": VOCAB_SIZE
|
| 499 |
-
"d_model": config.get('d_model', 384),
|
| 500 |
-
"n_layers": config.get('n_layers', 6)
|
| 501 |
}
|
| 502 |
|
| 503 |
# ======================
|
| 504 |
-
# INTERFAZ WEB
|
| 505 |
# ======================
|
| 506 |
@app.get("/", response_class=HTMLResponse)
|
| 507 |
def chat_ui():
|
|
@@ -511,7 +486,7 @@ def chat_ui():
|
|
| 511 |
<head>
|
| 512 |
<meta charset="UTF-8">
|
| 513 |
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 514 |
-
<title>MTP 4 - Asistente IA
|
| 515 |
<style>
|
| 516 |
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 517 |
body {
|
|
@@ -527,16 +502,7 @@ def chat_ui():
|
|
| 527 |
backdrop-filter: blur(10px);
|
| 528 |
border-bottom: 1px solid rgba(255,255,255,0.1);
|
| 529 |
}
|
| 530 |
-
.header h1 {
|
| 531 |
-
color: white;
|
| 532 |
-
font-size: 1.2rem;
|
| 533 |
-
}
|
| 534 |
-
.header h1 span {
|
| 535 |
-
background: linear-gradient(135deg, #4a9eff, #ff6b6b);
|
| 536 |
-
-webkit-background-clip: text;
|
| 537 |
-
background-clip: text;
|
| 538 |
-
color: transparent;
|
| 539 |
-
}
|
| 540 |
.header p { color: #888; font-size: 0.7rem; margin-top: 4px; }
|
| 541 |
.messages {
|
| 542 |
flex: 1;
|
|
@@ -661,25 +627,25 @@ def chat_ui():
|
|
| 661 |
</head>
|
| 662 |
<body>
|
| 663 |
<div class="header">
|
| 664 |
-
<h1>🤖
|
| 665 |
-
<p>✨
|
| 666 |
</div>
|
| 667 |
<div class="suggestions">
|
| 668 |
-
<div class="suggestion">Hola
|
| 669 |
-
<div class="suggestion">¿Quién eres?
|
| 670 |
-
<div class="suggestion">¿Qué puedes hacer?
|
| 671 |
-
<div class="suggestion">Explícame la IA
|
| 672 |
-
<div class="suggestion">Háblame de BTS
|
| 673 |
-
<div class="suggestion">¿Qué es un agujero negro?
|
| 674 |
-
<div class="suggestion">Dime un chiste
|
| 675 |
-
<div class="suggestion">Adiós
|
| 676 |
</div>
|
| 677 |
<div class="messages" id="messages">
|
| 678 |
-
<div class="message bot">✨
|
| 679 |
</div>
|
| 680 |
<div class="input-area">
|
| 681 |
<input type="text" id="input" placeholder="Escribe tu pregunta..." autocomplete="off">
|
| 682 |
-
<button id="send">Enviar
|
| 683 |
</div>
|
| 684 |
<div class="badge">⚡ MTP 4 | 🌡️ 0.4 | 🛡️ Anti-alucinaciones</div>
|
| 685 |
<script>
|
|
@@ -757,7 +723,6 @@ if __name__ == "__main__":
|
|
| 757 |
port = int(os.environ.get("PORT", 7860))
|
| 758 |
print("\n" + "=" * 60)
|
| 759 |
print(f"🚀 MTP 4 en http://0.0.0.0:{port}")
|
| 760 |
-
print(f"📊 Parámetros: {param_count:,} ({param_count/1e6:.2f}M)")
|
| 761 |
print(f"🌡️ Temperatura: 0.4 | 🔁 Repetition penalty: 1.3")
|
| 762 |
print("=" * 60)
|
| 763 |
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
MTP 4 API - ASISTENTE AVANZADO
|
| 4 |
+
- Modelo: d_model=384, n_layers=6 (25M parámetros)
|
| 5 |
+
- Temperatura 0.4
|
| 6 |
+
- Sistema anti-alucinaciones
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
|
|
|
| 41 |
|
| 42 |
torch.set_grad_enabled(False)
|
| 43 |
|
|
|
|
| 44 |
MODEL_REPO = "TeszenAI/MTP-4" # Cambia por tu repo
|
| 45 |
|
| 46 |
# ======================
|
| 47 |
+
# SISTEMA ANTI-ALUCINACIONES
|
| 48 |
# ======================
|
| 49 |
class AntiHallucination:
    """Heuristic filter that flags degenerate or off-topic model replies.

    The checks are purely textual: empty or near-empty output, runaway
    repetition, replies that open with an "I don't know" phrase, replies
    longer than ``max_safe_chars``, and replies that share almost no
    vocabulary with the question.
    """

    def __init__(self):
        # Phrases treated as "the model is declining to answer".  They are
        # matched as whole words within the first five words of the reply.
        self.uncertainty_words = [
            'no se', 'no lo se', 'no tengo idea', 'no estoy seguro',
            'no puedo responder', 'no sé', 'desconozco'
        ]
        # Replies made only of punctuation, whitespace, digits, or 1-3 letters.
        self.empty_patterns = [
            r'^[.,!?;:]+$', r'^[\s]+$', r'^[0-9]+$', r'^[a-zA-Z]{1,3}$',
        ]
        # The same word six or more times in a row, or one character repeated
        # eleven or more times.
        self.repetition_patterns = [
            r'(\b\w+\b)(?:\s+\1){5,}', r'(.)\1{10,}',
        ]
        # Replies longer than this many characters are rejected.
        self.max_safe_chars = 500

    def is_hallucinating(self, text: str) -> Tuple[bool, str]:
        """Decide whether ``text`` looks like a degenerate reply.

        Args:
            text: Decoded model reply.

        Returns:
            ``(flagged, reason)``.  ``flagged`` is True when any check fires;
            ``reason`` is the message of the first check that fired, or
            ``"OK"`` when none did.
        """
        if not text:
            return True, "Respuesta vacía"
        if len(text) < 5:
            return True, "Respuesta demasiado corta"

        if any(re.match(pattern, text) for pattern in self.empty_patterns):
            return True, "Patrón vacío detectado"
        if any(re.search(pattern, text) for pattern in self.repetition_patterns):
            return True, "Repetición excesiva"

        # Only the opening of the reply is inspected.  Phrases are matched on
        # word boundaries so that, e.g., "casino se" or "no sería" does not
        # trip the 'no se' entry the way a raw substring test would.
        opening = ' '.join(text.lower().split()[:5])
        for uw in self.uncertainty_words:
            if re.search(r'\b' + re.escape(uw) + r'\b', opening):
                return True, f"Expresa incertidumbre: '{uw}'"

        if len(text) > self.max_safe_chars:
            return True, "Respuesta demasiado larga"
        return False, "OK"

    def is_coherent(self, text: str, question: str) -> Tuple[bool, str]:
        """Check that the reply shares some vocabulary with the question.

        Args:
            text: Decoded model reply.
            question: The user's original input.

        Returns:
            ``(coherent, reason)``.  The reply is rejected only when the
            question has at least two distinct words of three or more letters
            and fewer than 20% of them occur (as substrings) in the reply.
            Missing text or question is treated as coherent.
        """
        if not text or not question:
            return True, "Sin datos suficientes"

        text_lower = text.lower()
        question_words = set(re.findall(r'\b[a-záéíóúüñ]{3,}\b', question.lower()))

        if len(question_words) >= 2:
            # Substring containment is intentional: it lets "planeta" in the
            # question match "planetas" in the reply.
            matches = sum(1 for w in question_words if w in text_lower)
            if matches / len(question_words) < 0.2:
                return False, "No responde a la pregunta"
        return True, "OK"
|
| 94 |
|
| 95 |
# ======================
|
|
|
|
| 102 |
|
| 103 |
class IntelligentStopper:
|
| 104 |
def __init__(self):
|
| 105 |
+
self.completion_patterns = [r'\.\s*$', r'\!?\s*$', r'\?\s*$', r'\.\.\.\s*$']
|
| 106 |
+
self.continuation_patterns = [r'[,;:]\s*$', r' y $', r' o $', r' pero $', r' porque $']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
self.completion_phrases = [
|
| 108 |
'gracias', 'saludos', 'adios', 'hasta luego',
|
| 109 |
'espero haberte ayudado', 'cualquier otra pregunta',
|
|
|
|
| 113 |
def analyze(self, text: str, min_length: int = 40) -> Tuple[CompletionState, str]:
    """Classify the text generated so far as complete or still in progress.

    Args:
        text: Reply decoded so far.
        min_length: Minimum number of characters (after trimming
            surrounding whitespace) before the text can be considered
            complete.

    Returns:
        ``(state, reason)`` where ``state`` is ``CompletionState.COMPLETE``
        or ``CompletionState.INCOMPLETE`` and ``reason`` is a short
        explanation of the rule that decided it.
    """
    # Trim first so whitespace padding cannot satisfy the length requirement.
    text = text.strip() if text else ""
    if len(text) < min_length:
        return CompletionState.INCOMPLETE, "Demasiado corto"

    # Several continuation patterns (e.g. r' y $') require a trailing space,
    # which strip() has just removed.  Probing a copy with one trailing space
    # as well lets them match a reply that ends in a dangling conjunction.
    probe = text + " "
    for pattern in self.continuation_patterns:
        if re.search(pattern, text, re.IGNORECASE) or re.search(pattern, probe, re.IGNORECASE):
            return CompletionState.INCOMPLETE, "Indica continuación"

    # Closing phrases are only looked for in the last 80 characters.
    tail = text.lower()[-80:]
    if any(phrase in tail for phrase in self.completion_phrases):
        return CompletionState.COMPLETE, "Frase de finalización"

    # A terminal-punctuation match only counts once the text is strictly
    # longer than min_length; the length test does not depend on the pattern.
    if len(text) > min_length and any(
        re.search(pattern, text) for pattern in self.completion_patterns
    ):
        return CompletionState.COMPLETE, "Termina naturalmente"

    if len(text) > 350:
        return CompletionState.COMPLETE, "Longitud suficiente"
    return CompletionState.INCOMPLETE, "Puede continuar"
|
| 131 |
|
| 132 |
# ======================
|
| 133 |
+
# ARQUITECTURA MTP 4 (IDÉNTICA AL ENTRENADOR)
|
| 134 |
# ======================
|
| 135 |
class LayerNorm(nn.Module):
|
|
|
|
| 136 |
def __init__(self, d_model, eps=1e-5):
|
| 137 |
super().__init__()
|
| 138 |
self.weight = nn.Parameter(torch.ones(d_model))
|
|
|
|
| 142 |
return self.weight * (x - x.mean(-1, keepdim=True)) / (x.std(-1, keepdim=True) + self.eps) + self.bias
|
| 143 |
|
| 144 |
class MultiHeadAttention(nn.Module):
|
|
|
|
| 145 |
def __init__(self, d_model, n_heads, dropout=0.2):
|
| 146 |
super().__init__()
|
| 147 |
assert d_model % n_heads == 0
|
| 148 |
+
self.d_model = d_model
|
| 149 |
self.n_heads = n_heads
|
| 150 |
self.d_k = d_model // n_heads
|
| 151 |
self.w_q = nn.Linear(d_model, d_model)
|
|
|
|
| 167 |
return self.w_o(out)
|
| 168 |
|
| 169 |
class FeedForward(nn.Module):
|
|
|
|
| 170 |
def __init__(self, d_model, d_ff, dropout=0.2):
|
| 171 |
super().__init__()
|
| 172 |
self.linear1 = nn.Linear(d_model, d_ff)
|
|
|
|
| 176 |
return self.linear2(self.dropout(F.gelu(self.linear1(x))))
|
| 177 |
|
| 178 |
class TransformerBlock(nn.Module):
|
|
|
|
| 179 |
def __init__(self, d_model, n_heads, d_ff, dropout=0.2):
|
| 180 |
super().__init__()
|
| 181 |
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
|
|
|
|
| 213 |
self.norm = LayerNorm(d_model)
|
| 214 |
self.lm_head = nn.Linear(d_model, vocab_size)
|
| 215 |
self.dropout = nn.Dropout(dropout)
|
| 216 |
+
self._init_weights()
|
| 217 |
+
def _init_weights(self):
|
| 218 |
+
for p in self.parameters():
|
| 219 |
+
if p.dim() > 1:
|
| 220 |
+
nn.init.xavier_uniform_(p)
|
| 221 |
def forward(self, x):
|
| 222 |
seq_len = x.size(1)
|
| 223 |
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0).to(x.device)
|
|
|
|
| 231 |
@torch.no_grad()
|
| 232 |
def generate(self, input_ids, max_new=120, temperature=0.4, top_k=30, top_p=0.85,
|
| 233 |
repetition_penalty=1.3, stopper=None):
|
|
|
|
| 234 |
generated = input_ids
|
| 235 |
eos_id = 3
|
| 236 |
last_tokens = []
|
|
|
|
| 332 |
with open(config_path, "r") as f:
|
| 333 |
config = json.load(f)
|
| 334 |
|
| 335 |
+
print(f"📋 Configuración encontrada:")
|
| 336 |
+
print(f" → d_model: {config.get('d_model', 'No especificado')}")
|
| 337 |
+
print(f" → n_layers: {config.get('n_layers', 'No especificado')}")
|
| 338 |
+
print(f" → n_heads: {config.get('n_heads', 'No especificado')}")
|
| 339 |
+
print(f" → d_ff: {config.get('d_ff', 'No especificado')}")
|
| 340 |
+
|
| 341 |
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
|
| 342 |
sp = spm.SentencePieceProcessor()
|
| 343 |
sp.load(tokenizer_path)
|
|
|
|
| 346 |
|
| 347 |
print(f"🧠 Inicializando MTP 4...")
|
| 348 |
print(f" → Vocabulario: {VOCAB_SIZE}")
|
|
|
|
|
|
|
| 349 |
print(f" → Dispositivo: {DEVICE.upper()}")
|
| 350 |
|
| 351 |
+
# Crear modelo con la configuración EXACTA del archivo
|
| 352 |
model = MTP4Model(**config)
|
| 353 |
model.to(DEVICE)
|
| 354 |
|
| 355 |
model_path = os.path.join(repo_path, "mtp_model.pt")
|
| 356 |
if os.path.exists(model_path):
|
| 357 |
state_dict = torch.load(model_path, map_location=DEVICE)
|
| 358 |
+
# Usar strict=False para permitir pequeñas diferencias
|
| 359 |
model.load_state_dict(state_dict, strict=False)
|
| 360 |
print("✅ Pesos del modelo cargados")
|
| 361 |
|
|
|
|
| 374 |
|
| 375 |
class PromptRequest(BaseModel):
|
| 376 |
text: str = Field(..., max_length=2000)
|
|
|
|
|
|
|
| 377 |
|
| 378 |
def build_prompt(user_input: str) -> str:
|
| 379 |
return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
|
|
|
|
| 403 |
|
| 404 |
output_ids = model.generate(
|
| 405 |
input_ids,
|
| 406 |
+
max_new=100,
|
| 407 |
+
temperature=0.4,
|
| 408 |
top_k=30,
|
| 409 |
top_p=0.85,
|
| 410 |
repetition_penalty=1.3,
|
|
|
|
| 418 |
|
| 419 |
response = sp.decode(safe_tokens).strip() if safe_tokens else ""
|
| 420 |
|
| 421 |
+
# Anti-alucinaciones
|
| 422 |
is_hallucinating, reason = anti_hallucination.is_hallucinating(response)
|
| 423 |
if is_hallucinating:
|
| 424 |
print(f"⚠️ Alucinación detectada: {reason}")
|
|
|
|
| 429 |
if is_hallucinating:
|
| 430 |
response = ""
|
| 431 |
|
| 432 |
+
# Verificar coherencia
|
| 433 |
is_coherent, _ = anti_hallucination.is_coherent(response, user_input)
|
| 434 |
if not is_coherent and len(response) > 20:
|
| 435 |
first_sentence = response.split('.')[0] if '.' in response else response[:100]
|
|
|
|
| 450 |
|
| 451 |
except Exception as e:
|
| 452 |
print(f"Error: {e}")
|
| 453 |
+
import traceback
|
| 454 |
+
traceback.print_exc()
|
| 455 |
return {"reply": "Lo siento, ocurrió un error."}
|
| 456 |
|
| 457 |
finally:
|
|
|
|
| 472 |
"parameters": param_count,
|
| 473 |
"parameters_millions": round(param_count / 1e6, 2),
|
| 474 |
"device": DEVICE,
|
| 475 |
+
"vocab_size": VOCAB_SIZE
|
|
|
|
|
|
|
| 476 |
}
|
| 477 |
|
| 478 |
# ======================
|
| 479 |
+
# INTERFAZ WEB
|
| 480 |
# ======================
|
| 481 |
@app.get("/", response_class=HTMLResponse)
|
| 482 |
def chat_ui():
|
|
|
|
| 486 |
<head>
|
| 487 |
<meta charset="UTF-8">
|
| 488 |
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 489 |
+
<title>MTP 4 - Asistente IA</title>
|
| 490 |
<style>
|
| 491 |
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 492 |
body {
|
|
|
|
| 502 |
backdrop-filter: blur(10px);
|
| 503 |
border-bottom: 1px solid rgba(255,255,255,0.1);
|
| 504 |
}
|
| 505 |
+
.header h1 { color: white; font-size: 1.2rem; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
.header p { color: #888; font-size: 0.7rem; margin-top: 4px; }
|
| 507 |
.messages {
|
| 508 |
flex: 1;
|
|
|
|
| 627 |
</head>
|
| 628 |
<body>
|
| 629 |
<div class="header">
|
| 630 |
+
<h1>🤖 MTP 4 - Asistente IA</h1>
|
| 631 |
+
<p>✨ Temperatura 0.4 | Anti-alucinaciones | Respuestas precisas</p>
|
| 632 |
</div>
|
| 633 |
<div class="suggestions">
|
| 634 |
+
<div class="suggestion">Hola</div>
|
| 635 |
+
<div class="suggestion">¿Quién eres?</div>
|
| 636 |
+
<div class="suggestion">¿Qué puedes hacer?</div>
|
| 637 |
+
<div class="suggestion">Explícame la IA</div>
|
| 638 |
+
<div class="suggestion">Háblame de BTS</div>
|
| 639 |
+
<div class="suggestion">¿Qué es un agujero negro?</div>
|
| 640 |
+
<div class="suggestion">Dime un chiste</div>
|
| 641 |
+
<div class="suggestion">Adiós</div>
|
| 642 |
</div>
|
| 643 |
<div class="messages" id="messages">
|
| 644 |
+
<div class="message bot">✨ Hola, soy MTP 4. Estoy optimizado para dar respuestas coherentes y evitar alucinaciones. ¿En qué puedo ayudarte?</div>
|
| 645 |
</div>
|
| 646 |
<div class="input-area">
|
| 647 |
<input type="text" id="input" placeholder="Escribe tu pregunta..." autocomplete="off">
|
| 648 |
+
<button id="send">Enviar</button>
|
| 649 |
</div>
|
| 650 |
<div class="badge">⚡ MTP 4 | 🌡️ 0.4 | 🛡️ Anti-alucinaciones</div>
|
| 651 |
<script>
|
|
|
|
| 723 |
port = int(os.environ.get("PORT", 7860))
|
| 724 |
print("\n" + "=" * 60)
|
| 725 |
print(f"🚀 MTP 4 en http://0.0.0.0:{port}")
|
|
|
|
| 726 |
print(f"🌡️ Temperatura: 0.4 | 🔁 Repetition penalty: 1.3")
|
| 727 |
print("=" * 60)
|
| 728 |
|