teszenofficial commited on
Commit
e332d47
·
verified ·
1 Parent(s): 56f9c09

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +823 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import json
5
+ import time
6
+ import gc
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.responses import HTMLResponse, StreamingResponse
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel, Field
11
+ from huggingface_hub import snapshot_download
12
+ import uvicorn
13
+ import math
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import sentencepiece as spm
17
+
18
+ # ======================
19
+ # CONFIGURACIÓN DE DISPOSITIVO
20
+ # ======================
21
+ if torch.cuda.is_available():
22
+ DEVICE = "cuda"
23
+ print("✅ GPU NVIDIA detectada. Usando CUDA.")
24
+ else:
25
+ DEVICE = "cpu"
26
+ print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")
27
+
28
+ if DEVICE == "cpu":
29
+ torch.set_num_threads(max(1, os.cpu_count() // 2))
30
+
31
+ torch.set_grad_enabled(False)
32
+
33
+ MODEL_REPO = "TeszenAI/MTP-1.1"
34
+
35
+ # ======================
36
+ # DEFINIR ARQUITECTURA DEL MODELO (MTP-1.1)
37
+ # ======================
38
+ class LayerNorm(nn.Module):
39
+ def __init__(self, d_model: int, eps: float = 1e-5):
40
+ super().__init__()
41
+ self.weight = nn.Parameter(torch.ones(d_model))
42
+ self.bias = nn.Parameter(torch.zeros(d_model))
43
+ self.eps = eps
44
+
45
+ def forward(self, x):
46
+ mean = x.mean(-1, keepdim=True)
47
+ std = x.std(-1, keepdim=True)
48
+ return self.weight * (x - mean) / (std + self.eps) + self.bias
49
+
50
+ class MultiHeadAttention(nn.Module):
51
+ def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
52
+ super().__init__()
53
+ assert d_model % n_heads == 0
54
+ self.d_model = d_model
55
+ self.n_heads = n_heads
56
+ self.d_k = d_model // n_heads
57
+ self.w_q = nn.Linear(d_model, d_model)
58
+ self.w_k = nn.Linear(d_model, d_model)
59
+ self.w_v = nn.Linear(d_model, d_model)
60
+ self.w_o = nn.Linear(d_model, d_model)
61
+ self.dropout = nn.Dropout(dropout)
62
+ self.scale = math.sqrt(self.d_k)
63
+
64
+ def forward(self, x, mask=None):
65
+ batch_size, seq_len, _ = x.shape
66
+ Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
67
+ K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
68
+ V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
69
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
70
+ if mask is not None:
71
+ scores = scores.masked_fill(mask == 0, float('-inf'))
72
+ attn_weights = F.softmax(scores, dim=-1)
73
+ attn_weights = self.dropout(attn_weights)
74
+ attn_output = torch.matmul(attn_weights, V)
75
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
76
+ return self.w_o(attn_output)
77
+
78
+ class FeedForward(nn.Module):
79
+ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
80
+ super().__init__()
81
+ self.linear1 = nn.Linear(d_model, d_ff)
82
+ self.linear2 = nn.Linear(d_ff, d_model)
83
+ self.dropout = nn.Dropout(dropout)
84
+
85
+ def forward(self, x):
86
+ return self.linear2(self.dropout(F.gelu(self.linear1(x))))
87
+
88
+ class TransformerBlock(nn.Module):
89
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
90
+ super().__init__()
91
+ self.attention = MultiHeadAttention(d_model, n_heads, dropout)
92
+ self.feed_forward = FeedForward(d_model, d_ff, dropout)
93
+ self.norm1 = LayerNorm(d_model)
94
+ self.norm2 = LayerNorm(d_model)
95
+ self.dropout1 = nn.Dropout(dropout)
96
+ self.dropout2 = nn.Dropout(dropout)
97
+
98
+ def forward(self, x, mask=None):
99
+ attn_output = self.attention(x, mask)
100
+ x = x + self.dropout1(attn_output)
101
+ x = self.norm1(x)
102
+ ff_output = self.feed_forward(x)
103
+ x = x + self.dropout2(ff_output)
104
+ x = self.norm2(x)
105
+ return x
106
+
107
+ class PositionalEncoding(nn.Module):
108
+ def __init__(self, d_model: int, max_len: int = 5000):
109
+ super().__init__()
110
+ pe = torch.zeros(max_len, d_model)
111
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
112
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
113
+ pe[:, 0::2] = torch.sin(position * div_term)
114
+ pe[:, 1::2] = torch.cos(position * div_term)
115
+ self.register_buffer('pe', pe.unsqueeze(0))
116
+
117
+ def forward(self, x):
118
+ return x + self.pe[:, :x.size(1), :]
119
+
120
+ class MTPModel(nn.Module):
121
+ def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
122
+ n_layers: int = 4, d_ff: int = 512, dropout: float = 0.1, max_len: int = 256):
123
+ super().__init__()
124
+ self.vocab_size = vocab_size
125
+ self.d_model = d_model
126
+ self.max_len = max_len
127
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
128
+ self.pos_encoding = PositionalEncoding(d_model, max_len)
129
+ self.blocks = nn.ModuleList([
130
+ TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
131
+ ])
132
+ self.norm = LayerNorm(d_model)
133
+ self.lm_head = nn.Linear(d_model, vocab_size)
134
+
135
+ def forward(self, x, mask=None):
136
+ if mask is None:
137
+ mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device)
138
+ x = self.token_embedding(x) * math.sqrt(self.d_model)
139
+ x = self.pos_encoding(x)
140
+ for block in self.blocks:
141
+ x = block(x, mask)
142
+ x = self.norm(x)
143
+ logits = self.lm_head(x)
144
+ return logits
145
+
146
+ def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
147
+ """Método de generación compatible con la interfaz"""
148
+ generated = input_ids
149
+
150
+ for _ in range(max_new_tokens):
151
+ # Obtener logits para el último token
152
+ with torch.no_grad():
153
+ logits = self(generated)
154
+ next_logits = logits[0, -1, :] / temperature
155
+
156
+ # Aplicar repetition penalty
157
+ if repetition_penalty != 1.0:
158
+ for token_id in set(generated[0].tolist()):
159
+ next_logits[token_id] /= repetition_penalty
160
+
161
+ # Top-k filtering
162
+ if top_k > 0:
163
+ indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
164
+ next_logits[indices_to_remove] = float('-inf')
165
+
166
+ # Top-p filtering
167
+ if top_p < 1.0:
168
+ sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
169
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
170
+ sorted_indices_to_remove = cumulative_probs > top_p
171
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
172
+ sorted_indices_to_remove[..., 0] = 0
173
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
174
+ next_logits[indices_to_remove] = float('-inf')
175
+
176
+ # Sampling
177
+ probs = F.softmax(next_logits, dim=-1)
178
+ next_token = torch.multinomial(probs, num_samples=1).item()
179
+
180
+ # Parar en EOS
181
+ if next_token == 3: # EOS ID para SentencePiece
182
+ break
183
+
184
+ generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
185
+
186
+ return generated
187
+
188
+ # ======================
189
+ # DESCARGA Y CARGA DEL MODELO
190
+ # ======================
191
+ print(f"📦 Descargando modelo desde {MODEL_REPO}...")
192
+ repo_path = snapshot_download(
193
+ repo_id=MODEL_REPO,
194
+ repo_type="model",
195
+ local_dir="mtp_repo"
196
+ )
197
+
198
+ # Cargar configuración
199
+ config_path = os.path.join(repo_path, "config.json")
200
+ if os.path.exists(config_path):
201
+ with open(config_path, "r") as f:
202
+ config = json.load(f)
203
+ else:
204
+ config = {
205
+ "vocab_size": 5000,
206
+ "d_model": 128,
207
+ "n_heads": 4,
208
+ "n_layers": 4,
209
+ "d_ff": 512,
210
+ "dropout": 0.1,
211
+ "max_len": 256
212
+ }
213
+
214
+ # Cargar tokenizador
215
+ tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
216
+ sp = spm.SentencePieceProcessor()
217
+ sp.load(tokenizer_path)
218
+ VOCAB_SIZE = sp.get_piece_size()
219
+
220
+ # Actualizar vocab_size en config
221
+ config["vocab_size"] = VOCAB_SIZE
222
+
223
+ print(f"🧠 Inicializando modelo MTP-1.1...")
224
+ print(f" → Vocabulario: {VOCAB_SIZE}")
225
+ print(f" → Dimensión: {config['d_model']}")
226
+ print(f" → Capas: {config['n_layers']}")
227
+ print(f" → Heads: {config['n_heads']}")
228
+
229
+ model = MTPModel(**config)
230
+ model.to(DEVICE)
231
+
232
+ # Cargar pesos del modelo
233
+ model_path = os.path.join(repo_path, "mtp_model.pt")
234
+ if os.path.exists(model_path):
235
+ state_dict = torch.load(model_path, map_location=DEVICE)
236
+ model.load_state_dict(state_dict)
237
+ print("✅ Pesos del modelo cargados")
238
+ else:
239
+ print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios")
240
+
241
+ model.eval()
242
+
243
+ # Cuantización para CPU
244
+ if DEVICE == "cpu":
245
+ print("⚡ Aplicando cuantización dinámica para CPU...")
246
+ model = torch.quantization.quantize_dynamic(
247
+ model,
248
+ {nn.Linear},
249
+ dtype=torch.qint8
250
+ )
251
+
252
+ param_count = sum(p.numel() for p in model.parameters())
253
+ print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
254
+
255
+ # ======================
256
+ # API CONFIG
257
+ # ======================
258
+ app = FastAPI(
259
+ title="MTP-1.1 API",
260
+ description="API para modelo de lenguaje MTP-1.1",
261
+ version="1.1"
262
+ )
263
+
264
+ app.add_middleware(
265
+ CORSMiddleware,
266
+ allow_origins=["*"],
267
+ allow_methods=["*"],
268
+ allow_headers=["*"],
269
+ )
270
+
271
+ class PromptRequest(BaseModel):
272
+ text: str = Field(..., max_length=2000, description="Texto de entrada")
273
+ max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
274
+ temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
275
+ top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
276
+ top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
277
+ repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
278
+
279
+ def build_prompt(user_input: str) -> str:
280
+ """Construye el prompt en el formato del modelo"""
281
+ return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
282
+
283
+ # ======================
284
+ # GESTIÓN DE CARGA
285
+ # ======================
286
+ ACTIVE_REQUESTS = 0
287
+
288
+ class MTPTokenizer:
289
+ """Wrapper para el tokenizador de SentencePiece"""
290
+ def __init__(self, sp_model):
291
+ self.sp = sp_model
292
+
293
+ def encode(self, text):
294
+ return self.sp.encode(text)
295
+
296
+ def decode(self, tokens):
297
+ return self.sp.decode(tokens)
298
+
299
+ def bos_id(self):
300
+ return self.sp.bos_id()
301
+
302
+ def eos_id(self):
303
+ return self.sp.eos_id()
304
+
305
+ tokenizer_wrapper = MTPTokenizer(sp)
306
+
307
+ @app.post("/generate")
308
+ async def generate(req: PromptRequest):
309
+ """Endpoint principal de generación de texto"""
310
+ global ACTIVE_REQUESTS
311
+ ACTIVE_REQUESTS += 1
312
+
313
+ dyn_max_tokens = req.max_tokens
314
+ dyn_temperature = req.temperature
315
+
316
+ if ACTIVE_REQUESTS > 2:
317
+ print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
318
+ dyn_max_tokens = min(dyn_max_tokens, 120)
319
+ dyn_temperature = max(0.5, dyn_temperature * 0.9)
320
+
321
+ user_input = req.text.strip()
322
+ if not user_input:
323
+ ACTIVE_REQUESTS -= 1
324
+ return {"reply": "", "tokens_generated": 0}
325
+
326
+ full_prompt = build_prompt(user_input)
327
+ tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
328
+ input_ids = torch.tensor([tokens], device=DEVICE)
329
+
330
+ try:
331
+ with torch.no_grad():
332
+ output_ids = model.generate(
333
+ input_ids,
334
+ max_new_tokens=dyn_max_tokens,
335
+ temperature=dyn_temperature,
336
+ top_k=req.top_k,
337
+ top_p=req.top_p,
338
+ repetition_penalty=req.repetition_penalty
339
+ )
340
+
341
+ gen_tokens = output_ids[0, len(tokens):].tolist()
342
+
343
+ safe_tokens = [
344
+ t for t in gen_tokens
345
+ if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
346
+ ]
347
+
348
+ response = tokenizer_wrapper.decode(safe_tokens).strip()
349
+
350
+ if "###" in response:
351
+ response = response.split("###")[0].strip()
352
+
353
+ return {
354
+ "reply": response,
355
+ "tokens_generated": len(safe_tokens),
356
+ "model": "MTP-1.1"
357
+ }
358
+
359
+ except Exception as e:
360
+ print(f"❌ Error durante generación: {e}")
361
+ return {
362
+ "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
363
+ "error": str(e)
364
+ }
365
+
366
+ finally:
367
+ ACTIVE_REQUESTS -= 1
368
+ if DEVICE == "cuda":
369
+ torch.cuda.empty_cache()
370
+ gc.collect()
371
+
372
+ # ======================
373
+ # ENDPOINTS DE INFORMACIÓN
374
+ # ======================
375
+ @app.get("/health")
376
+ def health_check():
377
+ return {
378
+ "status": "healthy",
379
+ "model": "MTP-1.1",
380
+ "device": DEVICE,
381
+ "active_requests": ACTIVE_REQUESTS,
382
+ "vocab_size": VOCAB_SIZE
383
+ }
384
+
385
+ @app.get("/info")
386
+ def model_info():
387
+ return {
388
+ "model_name": "MTP-1.1",
389
+ "version": "1.1",
390
+ "architecture": config,
391
+ "parameters": sum(p.numel() for p in model.parameters()),
392
+ "device": DEVICE
393
+ }
394
+
395
+ # ======================
396
+ # INTERFAZ WEB (MODERNA DE MTP-3)
397
+ # ======================
398
+ @app.get("/", response_class=HTMLResponse)
399
+ def chat_ui():
400
+ return """
401
+ <!DOCTYPE html>
402
+ <html lang="es">
403
+ <head>
404
+ <meta charset="UTF-8">
405
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
406
+ <title>MTP 1.1</title>
407
+ <link rel="preconnect" href="https://fonts.googleapis.com">
408
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
409
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
410
+ <style>
411
+ :root {
412
+ --bg-color: #131314;
413
+ --surface-color: #1E1F20;
414
+ --accent-color: #4a9eff;
415
+ --text-primary: #e3e3e3;
416
+ --text-secondary: #9aa0a6;
417
+ --user-bubble: #282a2c;
418
+ --bot-actions-color: #c4c7c5;
419
+ --logo-url: url('https://i.postimg.cc/yxS54PF3/IMG-3082.jpg');
420
+ }
421
+ * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
422
+ body {
423
+ margin: 0;
424
+ background-color: var(--bg-color);
425
+ font-family: 'Inter', sans-serif;
426
+ color: var(--text-primary);
427
+ height: 100dvh;
428
+ display: flex;
429
+ flex-direction: column;
430
+ overflow: hidden;
431
+ }
432
+ header {
433
+ padding: 12px 20px;
434
+ display: flex;
435
+ align-items: center;
436
+ justify-content: space-between;
437
+ background: rgba(19, 19, 20, 0.85);
438
+ backdrop-filter: blur(12px);
439
+ position: fixed;
440
+ top: 0;
441
+ width: 100%;
442
+ z-index: 50;
443
+ border-bottom: 1px solid rgba(255,255,255,0.05);
444
+ }
445
+ .brand-wrapper {
446
+ display: flex;
447
+ align-items: center;
448
+ gap: 12px;
449
+ cursor: pointer;
450
+ }
451
+ .brand-logo {
452
+ width: 32px;
453
+ height: 32px;
454
+ border-radius: 50%;
455
+ background-image: var(--logo-url);
456
+ background-size: cover;
457
+ background-position: center;
458
+ border: 1px solid rgba(255,255,255,0.1);
459
+ }
460
+ .brand-text {
461
+ font-weight: 500;
462
+ font-size: 1.05rem;
463
+ display: flex;
464
+ align-items: center;
465
+ gap: 8px;
466
+ }
467
+ .version-badge {
468
+ font-size: 0.75rem;
469
+ background: rgba(74, 158, 255, 0.15);
470
+ color: #8ab4f8;
471
+ padding: 2px 8px;
472
+ border-radius: 12px;
473
+ font-weight: 600;
474
+ }
475
+ .chat-scroll {
476
+ flex: 1;
477
+ overflow-y: auto;
478
+ padding: 80px 20px 40px 20px;
479
+ display: flex;
480
+ flex-direction: column;
481
+ gap: 30px;
482
+ max-width: 850px;
483
+ margin: 0 auto;
484
+ width: 100%;
485
+ scroll-behavior: smooth;
486
+ }
487
+ .msg-row {
488
+ display: flex;
489
+ gap: 16px;
490
+ width: 100%;
491
+ opacity: 0;
492
+ transform: translateY(10px);
493
+ animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
494
+ }
495
+ .msg-row.user { justify-content: flex-end; }
496
+ .msg-row.bot { justify-content: flex-start; align-items: flex-start; }
497
+ .msg-content {
498
+ line-height: 1.6;
499
+ font-size: 1rem;
500
+ word-wrap: break-word;
501
+ max-width: 85%;
502
+ }
503
+ .user .msg-content {
504
+ background-color: var(--user-bubble);
505
+ padding: 10px 18px;
506
+ border-radius: 18px;
507
+ border-top-right-radius: 4px;
508
+ color: #fff;
509
+ }
510
+ .bot .msg-content-wrapper {
511
+ display: flex;
512
+ flex-direction: column;
513
+ gap: 8px;
514
+ width: 100%;
515
+ }
516
+ .bot .msg-text {
517
+ padding-top: 6px;
518
+ color: var(--text-primary);
519
+ }
520
+ .bot-avatar {
521
+ width: 34px;
522
+ height: 34px;
523
+ min-width: 34px;
524
+ border-radius: 50%;
525
+ background-image: var(--logo-url);
526
+ background-size: cover;
527
+ box-shadow: 0 2px 6px rgba(0,0,0,0.2);
528
+ }
529
+ .bot-actions {
530
+ display: flex;
531
+ gap: 10px;
532
+ opacity: 0;
533
+ transition: opacity 0.3s;
534
+ margin-top: 5px;
535
+ }
536
+ .action-btn {
537
+ background: transparent;
538
+ border: none;
539
+ color: var(--text-secondary);
540
+ cursor: pointer;
541
+ padding: 4px;
542
+ border-radius: 4px;
543
+ display: flex;
544
+ align-items: center;
545
+ transition: color 0.2s, background 0.2s;
546
+ }
547
+ .action-btn:hover {
548
+ color: var(--text-primary);
549
+ background: rgba(255,255,255,0.08);
550
+ }
551
+ .action-btn svg { width: 16px; height: 16px; fill: currentColor; }
552
+ .typing-cursor::after {
553
+ content: '';
554
+ display: inline-block;
555
+ width: 10px;
556
+ height: 10px;
557
+ background: var(--accent-color);
558
+ border-radius: 50%;
559
+ margin-left: 5px;
560
+ vertical-align: middle;
561
+ animation: blink 1s infinite;
562
+ }
563
+ .footer-container {
564
+ padding: 0 20px 20px 20px;
565
+ background: linear-gradient(to top, var(--bg-color) 85%, transparent);
566
+ position: relative;
567
+ z-index: 60;
568
+ }
569
+ .input-box {
570
+ max-width: 850px;
571
+ margin: 0 auto;
572
+ background: var(--surface-color);
573
+ border-radius: 28px;
574
+ padding: 8px 10px 8px 20px;
575
+ display: flex;
576
+ align-items: center;
577
+ border: 1px solid rgba(255,255,255,0.1);
578
+ transition: border-color 0.2s, box-shadow 0.2s;
579
+ }
580
+ .input-box:focus-within {
581
+ border-color: rgba(74, 158, 255, 0.5);
582
+ box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
583
+ }
584
+ #userInput {
585
+ flex: 1;
586
+ background: transparent;
587
+ border: none;
588
+ color: white;
589
+ font-size: 1rem;
590
+ font-family: inherit;
591
+ padding: 10px 0;
592
+ }
593
+ #mainBtn {
594
+ background: white;
595
+ color: black;
596
+ border: none;
597
+ width: 36px;
598
+ height: 36px;
599
+ border-radius: 50%;
600
+ display: flex;
601
+ align-items: center;
602
+ justify-content: center;
603
+ cursor: pointer;
604
+ margin-left: 8px;
605
+ transition: transform 0.2s;
606
+ }
607
+ #mainBtn:hover { transform: scale(1.05); }
608
+ .disclaimer {
609
+ text-align: center;
610
+ font-size: 0.75rem;
611
+ color: #666;
612
+ margin-top: 12px;
613
+ }
614
+ @keyframes slideUpFade {
615
+ from { opacity: 0; transform: translateY(15px); }
616
+ to { opacity: 1; transform: translateY(0); }
617
+ }
618
+ @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
619
+ @keyframes pulseAvatar {
620
+ 0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
621
+ 70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
622
+ 100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
623
+ }
624
+ .pulsing { animation: pulseAvatar 1.5s infinite; }
625
+ ::-webkit-scrollbar { width: 8px; }
626
+ ::-webkit-scrollbar-track { background: transparent; }
627
+ ::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
628
+ </style>
629
+ </head>
630
+ <body>
631
+ <header>
632
+ <div class="brand-wrapper" onclick="location.reload()">
633
+ <div class="brand-logo"></div>
634
+ <div class="brand-text">
635
+ MTP <span class="version-badge">1.1</span>
636
+ </div>
637
+ </div>
638
+ </header>
639
+ <div id="chatScroll" class="chat-scroll">
640
+ <div class="msg-row bot" style="animation-delay: 0.1s;">
641
+ <div class="bot-avatar"></div>
642
+ <div class="msg-content-wrapper">
643
+ <div class="msg-text">
644
+ ¡Hola! Soy MTP 1.1. ¿En qué puedo ayudarte hoy?
645
+ </div>
646
+ </div>
647
+ </div>
648
+ </div>
649
+ <div class="footer-container">
650
+ <div class="input-box">
651
+ <input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
652
+ <button id="mainBtn" onclick="handleBtnClick()"></button>
653
+ </div>
654
+ <div class="disclaimer">
655
+ MTP puede cometer errores. Considera verificar la información importante.
656
+ </div>
657
+ </div>
658
+ <script>
659
+ const chatScroll = document.getElementById('chatScroll');
660
+ const userInput = document.getElementById('userInput');
661
+ const mainBtn = document.getElementById('mainBtn');
662
+ let isGenerating = false;
663
+ let abortController = null;
664
+ let typingTimeout = null;
665
+ let lastUserPrompt = "";
666
+ const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
667
+ const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
668
+ mainBtn.innerHTML = ICON_SEND;
669
+ function scrollToBottom() {
670
+ chatScroll.scrollTop = chatScroll.scrollHeight;
671
+ }
672
+ function setBtnState(state) {
673
+ if (state === 'sending') {
674
+ mainBtn.innerHTML = ICON_STOP;
675
+ isGenerating = true;
676
+ } else {
677
+ mainBtn.innerHTML = ICON_SEND;
678
+ isGenerating = false;
679
+ abortController = null;
680
+ }
681
+ }
682
+ function handleBtnClick() {
683
+ if (isGenerating) {
684
+ stopGeneration();
685
+ } else {
686
+ sendMessage();
687
+ }
688
+ }
689
+ function stopGeneration() {
690
+ if (abortController) abortController.abort();
691
+ if (typingTimeout) clearTimeout(typingTimeout);
692
+ const activeCursor = document.querySelector('.typing-cursor');
693
+ if (activeCursor) activeCursor.classList.remove('typing-cursor');
694
+ const activeAvatar = document.querySelector('.pulsing');
695
+ if (activeAvatar) activeAvatar.classList.remove('pulsing');
696
+ setBtnState('idle');
697
+ userInput.focus();
698
+ }
699
+ async function sendMessage(textOverride = null) {
700
+ const text = textOverride || userInput.value.trim();
701
+ if (!text) return;
702
+ lastUserPrompt = text;
703
+ if (!textOverride) {
704
+ userInput.value = '';
705
+ addMessage(text, 'user');
706
+ }
707
+ setBtnState('sending');
708
+ abortController = new AbortController();
709
+ const botRow = document.createElement('div');
710
+ botRow.className = 'msg-row bot';
711
+ const avatar = document.createElement('div');
712
+ avatar.className = 'bot-avatar pulsing';
713
+ const wrapper = document.createElement('div');
714
+ wrapper.className = 'msg-content-wrapper';
715
+ const msgText = document.createElement('div');
716
+ msgText.className = 'msg-text';
717
+ wrapper.appendChild(msgText);
718
+ botRow.appendChild(avatar);
719
+ botRow.appendChild(wrapper);
720
+ chatScroll.appendChild(botRow);
721
+ scrollToBottom();
722
+ try {
723
+ const response = await fetch('/generate', {
724
+ method: 'POST',
725
+ headers: { 'Content-Type': 'application/json' },
726
+ body: JSON.stringify({ text: text }),
727
+ signal: abortController.signal
728
+ });
729
+ const data = await response.json();
730
+ if (!isGenerating) return;
731
+ avatar.classList.remove('pulsing');
732
+ const reply = data.reply || "No entendí eso.";
733
+ await typeWriter(msgText, reply);
734
+ if (isGenerating) {
735
+ addActions(wrapper, reply);
736
+ setBtnState('idle');
737
+ }
738
+ } catch (error) {
739
+ if (error.name === 'AbortError') {
740
+ msgText.textContent += " [Detenido]";
741
+ } else {
742
+ avatar.classList.remove('pulsing');
743
+ msgText.textContent = "Error de conexión.";
744
+ msgText.style.color = "#ff8b8b";
745
+ setBtnState('idle');
746
+ }
747
+ }
748
+ }
749
+ function addMessage(text, sender) {
750
+ const row = document.createElement('div');
751
+ row.className = `msg-row ${sender}`;
752
+ const content = document.createElement('div');
753
+ content.className = 'msg-content';
754
+ content.textContent = text;
755
+ row.appendChild(content);
756
+ chatScroll.appendChild(row);
757
+ scrollToBottom();
758
+ }
759
+ function typeWriter(element, text, speed = 12) {
760
+ return new Promise(resolve => {
761
+ let i = 0;
762
+ element.classList.add('typing-cursor');
763
+ function type() {
764
+ if (!isGenerating) {
765
+ element.classList.remove('typing-cursor');
766
+ resolve();
767
+ return;
768
+ }
769
+ if (i < text.length) {
770
+ element.textContent += text.charAt(i);
771
+ i++;
772
+ scrollToBottom();
773
+ typingTimeout = setTimeout(type, speed + Math.random() * 5);
774
+ } else {
775
+ element.classList.remove('typing-cursor');
776
+ resolve();
777
+ }
778
+ }
779
+ type();
780
+ });
781
+ }
782
+ function addActions(wrapperElement, textToCopy) {
783
+ const actionsDiv = document.createElement('div');
784
+ actionsDiv.className = 'bot-actions';
785
+ const copyBtn = document.createElement('button');
786
+ copyBtn.className = 'action-btn';
787
+ copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
788
+ copyBtn.onclick = () => {
789
+ navigator.clipboard.writeText(textToCopy);
790
+ };
791
+ const regenBtn = document.createElement('button');
792
+ regenBtn.className = 'action-btn';
793
+ regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
794
+ regenBtn.onclick = () => {
795
+ sendMessage(lastUserPrompt);
796
+ };
797
+ actionsDiv.appendChild(copyBtn);
798
+ actionsDiv.appendChild(regenBtn);
799
+ wrapperElement.appendChild(actionsDiv);
800
+ requestAnimationFrame(() => actionsDiv.style.opacity = "1");
801
+ scrollToBottom();
802
+ }
803
+ userInput.addEventListener('keydown', (e) => {
804
+ if (e.key === 'Enter') handleBtnClick();
805
+ });
806
+ window.onload = () => userInput.focus();
807
+ </script>
808
+ </body>
809
+ </html>
810
+ """
811
+
812
+ if __name__ == "__main__":
813
+ port = int(os.environ.get("PORT", 7860))
814
+ print(f"\n🚀 Iniciando servidor MTP-1.1 en puerto {port}...")
815
+ print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
816
+ print(f"📡 API docs: http://0.0.0.0:{port}/docs")
817
+
818
+ uvicorn.run(
819
+ app,
820
+ host="0.0.0.0",
821
+ port=port,
822
+ log_level="info"
823
+ )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ sentencepiece>=0.1.99
3
+ fastapi>=0.100.0
4
+ uvicorn>=0.23.0
5
+ numpy>=1.24.0
6
+ huggingface_hub>=0.20.0