teszenofficial commited on
Commit
6517f5f
·
verified ·
1 Parent(s): d163150

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -453
app.py CHANGED
@@ -32,99 +32,87 @@ if DEVICE == "cpu":
32
  torch.set_grad_enabled(False)
33
 
34
  # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
35
- MODEL_REPO = "TeszenAI/MTP-3.1.1" # <-- CAMBIA A TU REPO
36
 
37
  # ======================
38
  # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
  # ======================
40
 
41
- def clean_response(text: str) -> str:
42
  """
43
- Limpia la respuesta eliminando repeticiones, frases sin sentido y
44
- asegurando que termine correctamente.
45
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if not text:
47
  return ""
48
 
49
- # 1. Eliminar repeticiones excesivas de palabras o frases cortas
50
  words = text.split()
51
  cleaned_words = []
52
- last_phrase = ""
53
  repeat_count = 0
54
 
55
  for word in words:
56
- if word == last_phrase:
57
  repeat_count += 1
58
- if repeat_count > 2: # Si repite más de 2 veces seguidas
59
  continue
60
  else:
61
- last_phrase = word
62
  repeat_count = 0
63
  cleaned_words.append(word)
64
 
65
  text = " ".join(cleaned_words)
66
 
67
- # 2. Eliminar patrones sin sentido (repeticiones de letras, caracteres raros)
68
- text = re.sub(r'(.)\1{4,}', r'\1\1', text) # aaa... -> aa
69
- text = re.sub(r'[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ0-9\s.,;:!?¿¡()\-"]+', '', text)
70
 
71
- # 3. Cortar en la primera frase que parezca final coherente
72
- stop_patterns = [
73
- r'(\.\s*)$', # Punto final
74
- r'[.!?](\s+)?$', # Fin de oración
75
- r'(gracias|hasta luego|adiós|saludos|fin|fin del mensaje)$',
76
- r'(¿algo más\?|¿necesitas algo más\?|¿en qué más puedo ayudarte\?)'
77
- ]
78
 
79
- for pattern in stop_patterns:
80
- match = re.search(pattern, text, re.IGNORECASE)
81
- if match:
82
- # Cortar justo después del patrón de finalización
83
- end_pos = match.end()
84
- text = text[:end_pos]
85
- break
 
 
 
 
 
86
 
87
- # 4. Si la respuesta es muy corta o vacía, devolver mensaje por defecto
88
- if len(text.strip()) < 10:
 
 
89
  return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
90
 
91
- # 5. Eliminar espacios múltiples y saltos de línea excesivos
92
  text = re.sub(r'\s+', ' ', text).strip()
93
 
94
  return text
95
 
96
- def should_stop_generation(generated_text: str, min_length: int = 30, max_length: int = 300) -> bool:
97
- """
98
- Determina si debemos detener la generación basado en el texto generado.
99
- """
100
- # Si ya superamos la longitud máxima
101
- if len(generated_text) > max_length:
102
- return True
103
-
104
- # Si es muy corto y no hay puntuación final
105
- if len(generated_text) < min_length and not re.search(r'[.!?]$', generated_text):
106
- return False
107
-
108
- # Señales de que ya terminó la respuesta
109
- stop_signals = [
110
- r'(gracias por tu pregunta|espero haberte ayudado|¿necesitas algo más\?)',
111
- r'(hasta luego|adiós|quedo atento|saludos cordiales)',
112
- r'(fin del mensaje|fin de la conversación)'
113
- ]
114
-
115
- for signal in stop_signals:
116
- if re.search(signal, generated_text, re.IGNORECASE):
117
- return True
118
-
119
- # Si la última frase parece completa
120
- last_sentence = generated_text.split('.')[-1].strip()
121
- if len(last_sentence) > 5 and re.search(r'[.!?]$', last_sentence):
122
- # Y ya hemos generado suficiente contenido
123
- if len(generated_text) > min_length:
124
- return True
125
-
126
- return False
127
-
128
  # ======================
129
  # DEFINIR ARQUITECTURA DEL MODELO (MTP)
130
  # ======================
@@ -237,11 +225,8 @@ class MTPModel(nn.Module):
237
  return logits
238
 
239
  def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
240
- """Método de generación mejorado con detección inteligente de fin"""
241
  generated = input_ids
242
- generated_text = ""
243
- min_response_length = 30
244
- max_response_length = max_new_tokens * 2
245
 
246
  for step in range(max_new_tokens):
247
  with torch.no_grad():
@@ -268,17 +253,11 @@ class MTPModel(nn.Module):
268
  probs = F.softmax(next_logits, dim=-1)
269
  next_token = torch.multinomial(probs, num_samples=1).item()
270
 
271
- if next_token == 3: # EOS ID para SentencePiece
 
272
  break
273
 
274
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
275
-
276
- # Decodificar parcialmente para verificar si debemos parar (solo cada 10 pasos para eficiencia)
277
- if step > 10 and step % 10 == 0:
278
- # Intentar decodificar tokens generados (esto es aproximado, el tokenizador real está fuera)
279
- if len(generated[0]) > 10:
280
- if should_stop_generation(str(generated[0].tolist()), min_response_length, max_response_length):
281
- break
282
 
283
  return generated
284
 
@@ -310,6 +289,10 @@ else:
310
 
311
  # Cargar tokenizador
312
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
 
 
 
 
313
  sp = spm.SentencePieceProcessor()
314
  sp.load(tokenizer_path)
315
  VOCAB_SIZE = sp.get_piece_size()
@@ -330,25 +313,13 @@ model.to(DEVICE)
330
  model_path = os.path.join(repo_path, "mtp_model.pt")
331
  if os.path.exists(model_path):
332
  state_dict = torch.load(model_path, map_location=DEVICE)
333
- model.load_state_dict(state_dict)
334
  print("✅ Pesos del modelo cargados")
335
  else:
336
- print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios")
337
 
338
  model.eval()
339
 
340
- # Cuantización para CPU
341
- if DEVICE == "cpu":
342
- print("⚡ Aplicando cuantización dinámica para CPU...")
343
- try:
344
- model = torch.quantization.quantize_dynamic(
345
- model,
346
- {nn.Linear},
347
- dtype=torch.qint8
348
- )
349
- except Exception as e:
350
- print(f"⚠️ No se pudo aplicar cuantización: {e}")
351
-
352
  param_count = sum(p.numel() for p in model.parameters())
353
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
354
 
@@ -370,7 +341,7 @@ app.add_middleware(
370
 
371
  class PromptRequest(BaseModel):
372
  text: str = Field(..., max_length=2000, description="Texto de entrada")
373
- max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
374
  temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
375
  top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
376
  top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
@@ -413,29 +384,27 @@ async def generate(req: PromptRequest):
413
  global ACTIVE_REQUESTS
414
  ACTIVE_REQUESTS += 1
415
 
416
- dyn_max_tokens = req.max_tokens
417
- dyn_temperature = req.temperature
418
-
419
- if ACTIVE_REQUESTS > 2:
420
- print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
421
- dyn_max_tokens = min(dyn_max_tokens, 120)
422
- dyn_temperature = max(0.5, dyn_temperature * 0.9)
423
-
424
  user_input = req.text.strip()
425
  if not user_input:
426
  ACTIVE_REQUESTS -= 1
427
  return {"reply": "", "tokens_generated": 0}
 
 
 
 
 
 
428
 
429
  full_prompt = build_prompt(user_input)
430
- tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
431
  input_ids = torch.tensor([tokens], device=DEVICE)
432
 
433
  try:
434
  with torch.no_grad():
435
  output_ids = model.generate(
436
  input_ids,
437
- max_new_tokens=dyn_max_tokens,
438
- temperature=dyn_temperature,
439
  top_k=req.top_k,
440
  top_p=req.top_p,
441
  repetition_penalty=req.repetition_penalty
@@ -443,18 +412,23 @@ async def generate(req: PromptRequest):
443
 
444
  gen_tokens = output_ids[0, len(tokens):].tolist()
445
 
446
- safe_tokens = [
447
- t for t in gen_tokens
448
- if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
449
- ]
450
 
451
- response = tokenizer_wrapper.decode(safe_tokens).strip()
 
 
 
452
 
453
- if "###" in response:
454
- response = response.split("###")[0].strip()
455
 
456
- # Aplicar limpieza inteligente a la respuesta
457
- response = clean_response(response)
 
 
 
 
458
 
459
  return {
460
  "reply": response,
@@ -464,8 +438,12 @@ async def generate(req: PromptRequest):
464
 
465
  except Exception as e:
466
  print(f"❌ Error durante generación: {e}")
 
 
 
 
467
  return {
468
- "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
469
  "error": str(e)
470
  }
471
 
@@ -499,7 +477,7 @@ def model_info():
499
  }
500
 
501
  # ======================
502
- # INTERFAZ WEB (MODERNA CON LOGO INTEGRADO)
503
  # ======================
504
  @app.get("/", response_class=HTMLResponse)
505
  def chat_ui():
@@ -508,410 +486,197 @@ def chat_ui():
508
  <html lang="es">
509
  <head>
510
  <meta charset="UTF-8">
511
- <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
512
  <title>MTP - Asistente IA</title>
513
- <link rel="preconnect" href="https://fonts.googleapis.com">
514
- <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
515
- <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
516
  <style>
517
- :root {
518
- --bg-color: #131314;
519
- --surface-color: #1E1F20;
520
- --accent-color: #4a9eff;
521
- --text-primary: #e3e3e3;
522
- --text-secondary: #9aa0a6;
523
- --user-bubble: #282a2c;
524
- }
525
- * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
526
  body {
527
- margin: 0;
528
- background-color: var(--bg-color);
529
- font-family: 'Inter', sans-serif;
530
- color: var(--text-primary);
531
- height: 100dvh;
532
  display: flex;
533
  flex-direction: column;
534
- overflow: hidden;
535
  }
536
- header {
537
- padding: 12px 20px;
538
- display: flex;
539
- align-items: center;
540
- justify-content: space-between;
541
- background: rgba(19, 19, 20, 0.85);
542
- backdrop-filter: blur(12px);
543
- position: fixed;
544
- top: 0;
545
- width: 100%;
546
- z-index: 50;
547
- border-bottom: 1px solid rgba(255,255,255,0.05);
548
  }
549
- .brand-wrapper {
550
- display: flex;
551
- align-items: center;
552
- gap: 12px;
553
- cursor: pointer;
554
- }
555
- .brand-logo {
556
- width: 32px;
557
- height: 32px;
558
- border-radius: 50%;
559
- background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
560
- background-size: cover;
561
- background-position: center;
562
- background-repeat: no-repeat;
563
- border: 1px solid rgba(255,255,255,0.1);
564
- }
565
- .brand-text {
566
  font-weight: 500;
567
- font-size: 1.05rem;
568
- display: flex;
569
- align-items: center;
570
- gap: 8px;
571
  }
572
- .version-badge {
573
- font-size: 0.75rem;
574
- background: rgba(74, 158, 255, 0.15);
575
- color: #8ab4f8;
576
- padding: 2px 8px;
577
- border-radius: 12px;
578
- font-weight: 600;
579
- }
580
- .chat-scroll {
581
  flex: 1;
582
  overflow-y: auto;
583
- padding: 80px 20px 40px 20px;
584
  display: flex;
585
  flex-direction: column;
586
- gap: 30px;
587
- max-width: 850px;
588
- margin: 0 auto;
589
- width: 100%;
590
- scroll-behavior: smooth;
591
  }
592
- .msg-row {
593
  display: flex;
594
- gap: 16px;
595
- width: 100%;
596
- opacity: 0;
597
- transform: translateY(10px);
598
- animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
599
  }
600
- .msg-row.user { justify-content: flex-end; }
601
- .msg-row.bot { justify-content: flex-start; align-items: flex-start; }
602
- .msg-content {
603
- line-height: 1.6;
604
- font-size: 1rem;
605
- word-wrap: break-word;
606
- max-width: 85%;
607
  }
608
- .user .msg-content {
609
- background-color: var(--user-bubble);
610
- padding: 10px 18px;
611
  border-radius: 18px;
612
- border-top-right-radius: 4px;
613
- color: #fff;
614
  }
615
- .bot .msg-content-wrapper {
616
- display: flex;
617
- flex-direction: column;
618
- gap: 8px;
619
- width: 100%;
620
  }
621
- .bot .msg-text {
622
- padding-top: 6px;
623
- color: var(--text-primary);
 
624
  }
625
- .bot-avatar {
626
- width: 34px;
627
- height: 34px;
628
- min-width: 34px;
629
- border-radius: 50%;
630
- background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
631
- background-size: cover;
632
- background-position: center;
633
- background-repeat: no-repeat;
634
- box-shadow: 0 2px 6px rgba(0,0,0,0.2);
635
  }
636
- .bot-actions {
637
  display: flex;
638
- gap: 10px;
639
- opacity: 0;
640
- transition: opacity 0.3s;
641
- margin-top: 5px;
642
- }
643
- .action-btn {
644
- background: transparent;
645
- border: none;
646
- color: var(--text-secondary);
647
- cursor: pointer;
648
- padding: 4px;
649
- border-radius: 4px;
650
- display: flex;
651
- align-items: center;
652
- transition: color 0.2s, background 0.2s;
653
- }
654
- .action-btn:hover {
655
- color: var(--text-primary);
656
- background: rgba(255,255,255,0.08);
657
- }
658
- .action-btn svg { width: 16px; height: 16px; fill: currentColor; }
659
- .typing-cursor::after {
660
- content: '▊';
661
- display: inline-block;
662
- margin-left: 2px;
663
- animation: blink 1s infinite;
664
- }
665
- .footer-container {
666
- padding: 0 20px 20px 20px;
667
- background: linear-gradient(to top, var(--bg-color) 85%, transparent);
668
- position: relative;
669
- z-index: 60;
670
- }
671
- .input-box {
672
- max-width: 850px;
673
  margin: 0 auto;
674
- background: var(--surface-color);
675
- border-radius: 28px;
676
- padding: 8px 10px 8px 20px;
677
- display: flex;
678
- align-items: center;
679
- border: 1px solid rgba(255,255,255,0.1);
680
- transition: border-color 0.2s, box-shadow 0.2s;
681
  }
682
- .input-box:focus-within {
683
- border-color: rgba(74, 158, 255, 0.5);
684
- box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
685
- }
686
- #userInput {
687
  flex: 1;
688
- background: transparent;
 
689
  border: none;
 
690
  color: white;
691
- font-size: 1rem;
692
- font-family: inherit;
693
- padding: 10px 0;
 
 
694
  }
695
- #mainBtn {
696
- background: white;
697
- color: black;
698
  border: none;
699
- width: 36px;
700
- height: 36px;
701
- border-radius: 50%;
702
- display: flex;
703
- align-items: center;
704
- justify-content: center;
705
  cursor: pointer;
706
- margin-left: 8px;
707
- transition: transform 0.2s;
708
  }
709
- #mainBtn:hover { transform: scale(1.05); }
710
- .disclaimer {
711
- text-align: center;
712
- font-size: 0.75rem;
713
- color: #666;
714
- margin-top: 12px;
715
  }
716
- @keyframes slideUpFade {
717
- from { opacity: 0; transform: translateY(15px); }
718
- to { opacity: 1; transform: translateY(0); }
 
719
  }
720
- @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
721
- @keyframes pulseAvatar {
722
- 0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
723
- 70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
724
- 100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
 
 
 
 
 
 
 
725
  }
726
- .pulsing { animation: pulseAvatar 1.5s infinite; }
727
- ::-webkit-scrollbar { width: 8px; }
728
- ::-webkit-scrollbar-track { background: transparent; }
729
- ::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
730
  </style>
731
  </head>
732
  <body>
733
- <header>
734
- <div class="brand-wrapper" onclick="location.reload()">
735
- <div class="brand-logo"></div>
736
- <div class="brand-text">
737
- MTP <span class="version-badge">v1</span>
738
- </div>
739
- </div>
740
- </header>
741
- <div id="chatScroll" class="chat-scroll">
742
- <div class="msg-row bot" style="animation-delay: 0.1s;">
743
- <div class="bot-avatar"></div>
744
- <div class="msg-content-wrapper">
745
- <div class="msg-text">
746
- ¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
747
- </div>
748
- </div>
749
- </div>
750
  </div>
751
- <div class="footer-container">
752
- <div class="input-box">
753
- <input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
754
- <button id="mainBtn" onclick="handleBtnClick()">➤</button>
755
  </div>
756
- <div class="disclaimer">
757
- MTP puede cometer errores. Considera verificar la información importante.
 
 
 
758
  </div>
759
  </div>
760
  <script>
761
- const chatScroll = document.getElementById('chatScroll');
762
- const userInput = document.getElementById('userInput');
763
- const mainBtn = document.getElementById('mainBtn');
764
- let isGenerating = false;
765
- let abortController = null;
766
- let typingTimeout = null;
767
- let lastUserPrompt = "";
768
-
769
- function scrollToBottom() {
770
- chatScroll.scrollTop = chatScroll.scrollHeight;
771
- }
772
 
773
- function setBtnState(state) {
774
- if (state === 'sending') {
775
- mainBtn.innerHTML = '';
776
- isGenerating = true;
777
- } else {
778
- mainBtn.innerHTML = '➤';
779
- isGenerating = false;
780
- abortController = null;
781
- }
782
  }
783
 
784
- function handleBtnClick() {
785
- if (isGenerating) {
786
- stopGeneration();
787
- } else {
788
- sendMessage();
789
- }
 
790
  }
791
 
792
- function stopGeneration() {
793
- if (abortController) abortController.abort();
794
- if (typingTimeout) clearTimeout(typingTimeout);
795
- const activeCursor = document.querySelector('.typing-cursor');
796
- if (activeCursor) activeCursor.classList.remove('typing-cursor');
797
- const activeAvatar = document.querySelector('.pulsing');
798
- if (activeAvatar) activeAvatar.classList.remove('pulsing');
799
- setBtnState('idle');
800
- userInput.focus();
801
  }
802
 
803
- async function sendMessage(textOverride = null) {
804
- const text = textOverride || userInput.value.trim();
805
- if (!text) return;
806
- lastUserPrompt = text;
807
- if (!textOverride) {
808
- userInput.value = '';
809
- addMessage(text, 'user');
810
- }
811
- setBtnState('sending');
812
- abortController = new AbortController();
813
- const botRow = document.createElement('div');
814
- botRow.className = 'msg-row bot';
815
- const avatar = document.createElement('div');
816
- avatar.className = 'bot-avatar pulsing';
817
- const wrapper = document.createElement('div');
818
- wrapper.className = 'msg-content-wrapper';
819
- const msgText = document.createElement('div');
820
- msgText.className = 'msg-text';
821
- wrapper.appendChild(msgText);
822
- botRow.appendChild(avatar);
823
- botRow.appendChild(wrapper);
824
- chatScroll.appendChild(botRow);
825
- scrollToBottom();
826
  try {
827
  const response = await fetch('/generate', {
828
  method: 'POST',
829
  headers: { 'Content-Type': 'application/json' },
830
- body: JSON.stringify({ text: text }),
831
- signal: abortController.signal
832
  });
833
  const data = await response.json();
834
- if (!isGenerating) return;
835
- avatar.classList.remove('pulsing');
836
- const reply = data.reply || "No entendí eso.";
837
- await typeWriter(msgText, reply);
838
- if (isGenerating) {
839
- addActions(wrapper, reply);
840
- setBtnState('idle');
841
- }
842
  } catch (error) {
843
- if (error.name === 'AbortError') {
844
- msgText.textContent += " [Detenido]";
845
- } else {
846
- avatar.classList.remove('pulsing');
847
- msgText.textContent = "Error de conexión.";
848
- msgText.style.color = "#ff8b8b";
849
- setBtnState('idle');
850
- }
851
  }
852
  }
853
 
854
- function addMessage(text, sender) {
855
- const row = document.createElement('div');
856
- row.className = `msg-row ${sender}`;
857
- const content = document.createElement('div');
858
- content.className = 'msg-content';
859
- content.textContent = text;
860
- row.appendChild(content);
861
- chatScroll.appendChild(row);
862
- scrollToBottom();
863
- }
864
-
865
- function typeWriter(element, text, speed = 12) {
866
- return new Promise(resolve => {
867
- let i = 0;
868
- element.classList.add('typing-cursor');
869
- function type() {
870
- if (!isGenerating) {
871
- element.classList.remove('typing-cursor');
872
- resolve();
873
- return;
874
- }
875
- if (i < text.length) {
876
- element.textContent += text.charAt(i);
877
- i++;
878
- scrollToBottom();
879
- typingTimeout = setTimeout(type, speed + Math.random() * 5);
880
- } else {
881
- element.classList.remove('typing-cursor');
882
- resolve();
883
- }
884
- }
885
- type();
886
- });
887
- }
888
-
889
- function addActions(wrapperElement, textToCopy) {
890
- const actionsDiv = document.createElement('div');
891
- actionsDiv.className = 'bot-actions';
892
- const copyBtn = document.createElement('button');
893
- copyBtn.className = 'action-btn';
894
- copyBtn.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
895
- copyBtn.onclick = () => {
896
- navigator.clipboard.writeText(textToCopy);
897
- };
898
- const regenBtn = document.createElement('button');
899
- regenBtn.className = 'action-btn';
900
- regenBtn.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
901
- regenBtn.onclick = () => {
902
- sendMessage(lastUserPrompt);
903
- };
904
- actionsDiv.appendChild(copyBtn);
905
- actionsDiv.appendChild(regenBtn);
906
- wrapperElement.appendChild(actionsDiv);
907
- requestAnimationFrame(() => actionsDiv.style.opacity = "1");
908
- scrollToBottom();
909
- }
910
-
911
- userInput.addEventListener('keydown', (e) => {
912
- if (e.key === 'Enter') handleBtnClick();
913
  });
914
- window.onload = () => userInput.focus();
 
915
  </script>
916
  </body>
917
  </html>
 
32
  torch.set_grad_enabled(False)
33
 
34
  # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
35
+ MODEL_REPO = "TeszenAI/MTP-3.1.1"
36
 
37
  # ======================
38
  # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
  # ======================
40
 
41
+ def truncate_greeting_response(text: str) -> str:
42
  """
43
+ Para respuestas de saludo, trunca SOLO en el primer PUNTO (.)
44
+ No usa signos de exclamación o interrogación.
45
  """
46
+ if not text:
47
+ return text
48
+
49
+ # Buscar el primer PUNTO (.)
50
+ end_match = re.search(r'\.', text)
51
+
52
+ if end_match:
53
+ # Cortar justo después del punto
54
+ end_pos = end_match.end()
55
+ truncated = text[:end_pos].strip()
56
+ return truncated
57
+
58
+ # Si no hay punto, devolver solo primeras 80 caracteres
59
+ if len(text) > 80:
60
+ return text[:80] + "..."
61
+ return text
62
+
63
+ def clean_response(text: str, user_input: str = "") -> str:
64
+ """Limpia la respuesta del modelo"""
65
  if not text:
66
  return ""
67
 
68
+ # Eliminar repeticiones excesivas
69
  words = text.split()
70
  cleaned_words = []
71
+ last_word = ""
72
  repeat_count = 0
73
 
74
  for word in words:
75
+ if word == last_word:
76
  repeat_count += 1
77
+ if repeat_count > 2:
78
  continue
79
  else:
80
+ last_word = word
81
  repeat_count = 0
82
  cleaned_words.append(word)
83
 
84
  text = " ".join(cleaned_words)
85
 
86
+ # Eliminar caracteres raros
87
+ text = re.sub(r'(.)\1{4,}', r'\1\1', text)
 
88
 
89
+ # Detectar si es un saludo
90
+ is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
 
 
 
 
 
91
 
92
+ if is_greeting and text:
93
+ # Para saludos, truncar SOLO en el primer PUNTO (.)
94
+ punct_match = re.search(r'\.', text)
95
+ if punct_match:
96
+ text = text[:punct_match.end()].strip()
97
+ else:
98
+ # Si no hay punto, tomar solo la primera oración o 60 caracteres
99
+ first_sentence = text.split('.')[0].strip()
100
+ if len(first_sentence) > 5:
101
+ text = first_sentence
102
+ elif len(text) > 60:
103
+ text = text[:60]
104
 
105
+ # Si la respuesta es muy corta o vacía
106
+ if len(text.strip()) < 5:
107
+ if is_greeting:
108
+ return "¡Hola! ¿En qué puedo ayudarte?"
109
  return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
110
 
111
+ # Eliminar espacios múltiples
112
  text = re.sub(r'\s+', ' ', text).strip()
113
 
114
  return text
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # ======================
117
  # DEFINIR ARQUITECTURA DEL MODELO (MTP)
118
  # ======================
 
225
  return logits
226
 
227
  def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
228
+ """Genera texto token por token"""
229
  generated = input_ids
 
 
 
230
 
231
  for step in range(max_new_tokens):
232
  with torch.no_grad():
 
253
  probs = F.softmax(next_logits, dim=-1)
254
  next_token = torch.multinomial(probs, num_samples=1).item()
255
 
256
+ # EOS ID común para SentencePiece
257
+ if next_token == 2 or next_token == 3:
258
  break
259
 
260
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
 
 
 
 
 
 
 
261
 
262
  return generated
263
 
 
289
 
290
  # Cargar tokenizador
291
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
292
+ if not os.path.exists(tokenizer_path):
293
+ print(f"❌ Tokenizador no encontrado en {tokenizer_path}")
294
+ sys.exit(1)
295
+
296
  sp = spm.SentencePieceProcessor()
297
  sp.load(tokenizer_path)
298
  VOCAB_SIZE = sp.get_piece_size()
 
313
  model_path = os.path.join(repo_path, "mtp_model.pt")
314
  if os.path.exists(model_path):
315
  state_dict = torch.load(model_path, map_location=DEVICE)
316
+ model.load_state_dict(state_dict, strict=False)
317
  print("✅ Pesos del modelo cargados")
318
  else:
319
+ print(f"⚠️ No se encontró {model_path}, usando pesos aleatorios")
320
 
321
  model.eval()
322
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  param_count = sum(p.numel() for p in model.parameters())
324
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
325
 
 
341
 
342
  class PromptRequest(BaseModel):
343
  text: str = Field(..., max_length=2000, description="Texto de entrada")
344
+ max_tokens: int = Field(default=100, ge=10, le=200, description="Tokens máximos a generar")
345
  temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
346
  top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
347
  top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
 
384
  global ACTIVE_REQUESTS
385
  ACTIVE_REQUESTS += 1
386
 
 
 
 
 
 
 
 
 
387
  user_input = req.text.strip()
388
  if not user_input:
389
  ACTIVE_REQUESTS -= 1
390
  return {"reply": "", "tokens_generated": 0}
391
+
392
+ # Detectar si es un saludo
393
+ is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
394
+
395
+ # Si es saludo, usar menos tokens
396
+ max_tokens = 30 if is_greeting else req.max_tokens
397
 
398
  full_prompt = build_prompt(user_input)
399
+ tokens = tokenizer_wrapper.encode(full_prompt)
400
  input_ids = torch.tensor([tokens], device=DEVICE)
401
 
402
  try:
403
  with torch.no_grad():
404
  output_ids = model.generate(
405
  input_ids,
406
+ max_new_tokens=max_tokens,
407
+ temperature=req.temperature,
408
  top_k=req.top_k,
409
  top_p=req.top_p,
410
  repetition_penalty=req.repetition_penalty
 
412
 
413
  gen_tokens = output_ids[0, len(tokens):].tolist()
414
 
415
+ # Filtrar tokens inválidos
416
+ safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE]
 
 
417
 
418
+ if safe_tokens:
419
+ response = tokenizer_wrapper.decode(safe_tokens).strip()
420
+ else:
421
+ response = ""
422
 
423
+ # Limpiar respuesta
424
+ response = clean_response(response, user_input)
425
 
426
+ # Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto
427
+ if len(response) < 3:
428
+ if is_greeting:
429
+ response = "¡Hola! ¿En qué puedo ayudarte?"
430
+ else:
431
+ response = "Lo siento, no pude generar una respuesta. ¿Podrías reformular tu pregunta?"
432
 
433
  return {
434
  "reply": response,
 
438
 
439
  except Exception as e:
440
  print(f"❌ Error durante generación: {e}")
441
+ if is_greeting:
442
+ fallback = "¡Hola! ¿En qué puedo ayudarte?"
443
+ else:
444
+ fallback = "Lo siento, ocurrió un error al procesar tu solicitud."
445
  return {
446
+ "reply": fallback,
447
  "error": str(e)
448
  }
449
 
 
477
  }
478
 
479
  # ======================
480
+ # INTERFAZ WEB
481
  # ======================
482
  @app.get("/", response_class=HTMLResponse)
483
  def chat_ui():
 
486
  <html lang="es">
487
  <head>
488
  <meta charset="UTF-8">
489
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
490
  <title>MTP - Asistente IA</title>
 
 
 
491
  <style>
492
+ * { margin: 0; padding: 0; box-sizing: border-box; }
 
 
 
 
 
 
 
 
493
  body {
494
+ background: #131314;
495
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
496
+ height: 100vh;
 
 
497
  display: flex;
498
  flex-direction: column;
 
499
  }
500
+ .chat-header {
501
+ padding: 16px 20px;
502
+ background: #1E1F20;
503
+ border-bottom: 1px solid #2a2b2e;
 
 
 
 
 
 
 
 
504
  }
505
+ .chat-header h1 {
506
+ color: white;
507
+ font-size: 1.2rem;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  font-weight: 500;
 
 
 
 
509
  }
510
+ .chat-messages {
 
 
 
 
 
 
 
 
511
  flex: 1;
512
  overflow-y: auto;
513
+ padding: 20px;
514
  display: flex;
515
  flex-direction: column;
516
+ gap: 16px;
 
 
 
 
517
  }
518
+ .message {
519
  display: flex;
520
+ gap: 12px;
521
+ max-width: 80%;
 
 
 
522
  }
523
+ .message.user {
524
+ align-self: flex-end;
525
+ flex-direction: row-reverse;
 
 
 
 
526
  }
527
+ .message-content {
528
+ padding: 10px 16px;
 
529
  border-radius: 18px;
530
+ font-size: 0.95rem;
531
+ line-height: 1.4;
532
  }
533
+ .user .message-content {
534
+ background: #4a9eff;
535
+ color: white;
536
+ border-radius: 18px 4px 18px 18px;
 
537
  }
538
+ .bot .message-content {
539
+ background: #1E1F20;
540
+ color: #e3e3e3;
541
+ border-radius: 4px 18px 18px 18px;
542
  }
543
+ .chat-input-container {
544
+ padding: 16px 20px;
545
+ background: #1E1F20;
546
+ border-top: 1px solid #2a2b2e;
 
 
 
 
 
 
547
  }
548
+ .input-wrapper {
549
  display: flex;
550
+ gap: 12px;
551
+ max-width: 800px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  margin: 0 auto;
 
 
 
 
 
 
 
553
  }
554
+ #messageInput {
 
 
 
 
555
  flex: 1;
556
+ padding: 12px 16px;
557
+ background: #2a2b2e;
558
  border: none;
559
+ border-radius: 24px;
560
  color: white;
561
+ font-size: 0.95rem;
562
+ outline: none;
563
+ }
564
+ #messageInput::placeholder {
565
+ color: #888;
566
  }
567
+ #sendBtn {
568
+ padding: 12px 24px;
569
+ background: #4a9eff;
570
  border: none;
571
+ border-radius: 24px;
572
+ color: white;
573
+ font-weight: 500;
 
 
 
574
  cursor: pointer;
575
+ transition: opacity 0.2s;
 
576
  }
577
+ #sendBtn:hover { opacity: 0.9; }
578
+ #sendBtn:disabled {
579
+ opacity: 0.5;
580
+ cursor: not-allowed;
 
 
581
  }
582
+ .typing {
583
+ display: flex;
584
+ gap: 4px;
585
+ padding: 10px 16px;
586
  }
587
+ .typing span {
588
+ width: 8px;
589
+ height: 8px;
590
+ background: #888;
591
+ border-radius: 50%;
592
+ animation: bounce 1.4s infinite ease-in-out;
593
+ }
594
+ .typing span:nth-child(1) { animation-delay: -0.32s; }
595
+ .typing span:nth-child(2) { animation-delay: -0.16s; }
596
+ @keyframes bounce {
597
+ 0%, 80%, 100% { transform: scale(0); }
598
+ 40% { transform: scale(1); }
599
  }
 
 
 
 
600
  </style>
601
  </head>
602
  <body>
603
+ <div class="chat-header">
604
+ <h1>🤖 MTP - Asistente IA</h1>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  </div>
606
+ <div class="chat-messages" id="chatMessages">
607
+ <div class="message bot">
608
+ <div class="message-content">¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?</div>
 
609
  </div>
610
+ </div>
611
+ <div class="chat-input-container">
612
+ <div class="input-wrapper">
613
+ <input type="text" id="messageInput" placeholder="Escribe tu mensaje..." autocomplete="off">
614
+ <button id="sendBtn">Enviar</button>
615
  </div>
616
  </div>
617
  <script>
618
+ const chatMessages = document.getElementById('chatMessages');
619
+ const messageInput = document.getElementById('messageInput');
620
+ const sendBtn = document.getElementById('sendBtn');
621
+ let isLoading = false;
 
 
 
 
 
 
 
622
 
623
+ function addMessage(text, isUser) {
624
+ const div = document.createElement('div');
625
+ div.className = `message ${isUser ? 'user' : 'bot'}`;
626
+ div.innerHTML = `<div class="message-content">${text}</div>`;
627
+ chatMessages.appendChild(div);
628
+ chatMessages.scrollTop = chatMessages.scrollHeight;
629
+ return div;
 
 
630
  }
631
 
632
+ function addTypingIndicator() {
633
+ const div = document.createElement('div');
634
+ div.className = 'message bot';
635
+ div.id = 'typingIndicator';
636
+ div.innerHTML = `<div class="typing"><span></span><span></span><span></span></div>`;
637
+ chatMessages.appendChild(div);
638
+ chatMessages.scrollTop = chatMessages.scrollHeight;
639
  }
640
 
641
+ function removeTypingIndicator() {
642
+ const indicator = document.getElementById('typingIndicator');
643
+ if (indicator) indicator.remove();
 
 
 
 
 
 
644
  }
645
 
646
+ async function sendMessage() {
647
+ const text = messageInput.value.trim();
648
+ if (!text || isLoading) return;
649
+
650
+ messageInput.value = '';
651
+ addMessage(text, true);
652
+ isLoading = true;
653
+ sendBtn.disabled = true;
654
+ addTypingIndicator();
655
+
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  try {
657
  const response = await fetch('/generate', {
658
  method: 'POST',
659
  headers: { 'Content-Type': 'application/json' },
660
+ body: JSON.stringify({ text: text })
 
661
  });
662
  const data = await response.json();
663
+ removeTypingIndicator();
664
+ addMessage(data.reply, false);
 
 
 
 
 
 
665
  } catch (error) {
666
+ removeTypingIndicator();
667
+ addMessage('Error de conexión. Intenta de nuevo.', false);
668
+ } finally {
669
+ isLoading = false;
670
+ sendBtn.disabled = false;
671
+ messageInput.focus();
 
 
672
  }
673
  }
674
 
675
+ messageInput.addEventListener('keypress', (e) => {
676
+ if (e.key === 'Enter') sendMessage();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
  });
678
+ sendBtn.addEventListener('click', sendMessage);
679
+ messageInput.focus();
680
  </script>
681
  </body>
682
  </html>