boffire committed on
Commit
f88d1f9
·
verified ·
1 Parent(s): 61a5f48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -12,7 +12,7 @@ import os
12
  import requests
13
  import torch
14
  from flask import Flask, request, render_template_string, jsonify
15
- from transformers import MarianMTModel, MarianTokenizer
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
 
18
  # Configuration
@@ -35,17 +35,25 @@ tokenizer = None
35
  device = None
36
 
37
  def load_model():
38
- """Load MarianMT model once and cache it"""
39
  global model, tokenizer, device
40
 
41
  if model is None:
42
  print("Loading MarianMT model...")
43
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
 
45
- tokenizer = MarianTokenizer.from_pretrained(MODEL_ID)
46
- model = MarianMTModel.from_pretrained(MODEL_ID).to(device).eval()
47
-
48
- print(f"Model loaded on {device}")
 
 
 
 
 
 
 
 
49
 
50
  return model, tokenizer, device
51
 
@@ -57,6 +65,7 @@ def translate_marian(text):
57
  try:
58
  model, tokenizer, device = load_model()
59
 
 
60
  inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
61
  inputs = {k: v.to(device) for k, v in inputs.items()}
62
 
@@ -74,7 +83,7 @@ def translate_marian(text):
74
  translations = []
75
  for output in outputs:
76
  trans = tokenizer.decode(output, skip_special_tokens=True)
77
- if trans not in translations: # Avoid duplicates
78
  translations.append(trans)
79
 
80
  return translations if translations else ["[Error: No translation generated]"]
@@ -83,7 +92,7 @@ def translate_marian(text):
83
  print(f"MarianMT translation error: {e}")
84
  import traceback
85
  traceback.print_exc()
86
- return [f"[Error: {str(e)}"]
87
 
88
  def translate_libre_variant(text, variant_code):
89
  """Translate using a specific LibreTranslate variant"""
@@ -597,28 +606,20 @@ HTML_TEMPLATE = """
597
  let selectedText = '';
598
 
599
  function selectTranslation(element, text) {
600
- // Remove selected class from all items in this list
601
  element.parentElement.querySelectorAll('.translation-item').forEach(item => {
602
  item.classList.remove('selected');
603
  });
604
- // Add selected class to clicked item
605
  element.classList.add('selected');
606
  selectedText = text;
607
-
608
- // Copy automatically on selection
609
  copyText(text, false);
610
  }
611
 
612
  function selectVariant(element, text) {
613
- // Remove selected class from all variants
614
  document.querySelectorAll('.variant-item').forEach(item => {
615
  item.classList.remove('selected');
616
  });
617
- // Add selected class to clicked item
618
  element.classList.add('selected');
619
  selectedText = text;
620
-
621
- // Copy automatically on selection
622
  copyText(text, false);
623
  }
624
 
@@ -629,7 +630,6 @@ HTML_TEMPLATE = """
629
  showToastMessage('Copied to clipboard!');
630
  }
631
  } catch (err) {
632
- // Fallback
633
  const textArea = document.createElement('textarea');
634
  textArea.value = text;
635
  document.body.appendChild(textArea);
@@ -670,10 +670,7 @@ def index():
670
  if request.method == "POST":
671
  source_text = request.form.get("text", "").strip()
672
  if source_text:
673
- # Get Marian translations
674
  marian = translate_marian(source_text)
675
-
676
- # Get Libre translations
677
  libre_results = translate_libre_all_variants(source_text)
678
  for name, data in libre_results.items():
679
  data['code'] = KABYLE_VARIANTS[name]
@@ -688,13 +685,11 @@ def index():
688
 
689
  @app.route("/health")
690
  def health():
691
- """Health check endpoint"""
692
  return jsonify({
693
  "status": "healthy",
694
  "model_loaded": model is not None
695
  })
696
 
697
  if __name__ == "__main__":
698
- # Hugging Face Spaces expects port 7860
699
  port = int(os.environ.get("PORT", 7860))
700
  app.run(host="0.0.0.0", port=port, debug=False)
 
12
  import requests
13
  import torch
14
  from flask import Flask, request, render_template_string, jsonify
15
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
 
18
  # Configuration
 
35
  device = None
36
 
37
  def load_model():
38
+ """Load MarianMT model once and cache it using Auto classes"""
39
  global model, tokenizer, device
40
 
41
  if model is None:
42
  print("Loading MarianMT model...")
43
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
 
45
+ try:
46
+ # Try Auto classes first (more flexible)
47
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
48
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device).eval()
49
+ print(f"Model loaded successfully on {device} using Auto classes")
50
+ except Exception as e:
51
+ print(f"Auto classes failed: {e}")
52
+ print("Trying legacy Marian classes...")
53
+ from transformers import MarianMTModel, MarianTokenizer
54
+ tokenizer = MarianTokenizer.from_pretrained(MODEL_ID, use_fast=False)
55
+ model = MarianMTModel.from_pretrained(MODEL_ID).to(device).eval()
56
+ print(f"Model loaded on {device} using legacy classes")
57
 
58
  return model, tokenizer, device
59
 
 
65
  try:
66
  model, tokenizer, device = load_model()
67
 
68
+ # Prepare inputs
69
  inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
70
  inputs = {k: v.to(device) for k, v in inputs.items()}
71
 
 
83
  translations = []
84
  for output in outputs:
85
  trans = tokenizer.decode(output, skip_special_tokens=True)
86
+ if trans and trans not in translations: # Avoid duplicates and empty
87
  translations.append(trans)
88
 
89
  return translations if translations else ["[Error: No translation generated]"]
 
92
  print(f"MarianMT translation error: {e}")
93
  import traceback
94
  traceback.print_exc()
95
+ return [f"[Error: {str(e)}]"]
96
 
97
  def translate_libre_variant(text, variant_code):
98
  """Translate using a specific LibreTranslate variant"""
 
606
  let selectedText = '';
607
 
608
  function selectTranslation(element, text) {
 
609
  element.parentElement.querySelectorAll('.translation-item').forEach(item => {
610
  item.classList.remove('selected');
611
  });
 
612
  element.classList.add('selected');
613
  selectedText = text;
 
 
614
  copyText(text, false);
615
  }
616
 
617
  function selectVariant(element, text) {
 
618
  document.querySelectorAll('.variant-item').forEach(item => {
619
  item.classList.remove('selected');
620
  });
 
621
  element.classList.add('selected');
622
  selectedText = text;
 
 
623
  copyText(text, false);
624
  }
625
 
 
630
  showToastMessage('Copied to clipboard!');
631
  }
632
  } catch (err) {
 
633
  const textArea = document.createElement('textarea');
634
  textArea.value = text;
635
  document.body.appendChild(textArea);
 
670
  if request.method == "POST":
671
  source_text = request.form.get("text", "").strip()
672
  if source_text:
 
673
  marian = translate_marian(source_text)
 
 
674
  libre_results = translate_libre_all_variants(source_text)
675
  for name, data in libre_results.items():
676
  data['code'] = KABYLE_VARIANTS[name]
 
685
 
686
  @app.route("/health")
687
  def health():
 
688
  return jsonify({
689
  "status": "healthy",
690
  "model_loaded": model is not None
691
  })
692
 
693
  if __name__ == "__main__":
 
694
  port = int(os.environ.get("PORT", 7860))
695
  app.run(host="0.0.0.0", port=port, debug=False)