Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ import os
|
|
| 12 |
import requests
|
| 13 |
import torch
|
| 14 |
from flask import Flask, request, render_template_string, jsonify
|
| 15 |
-
from transformers import
|
| 16 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 17 |
|
| 18 |
# Configuration
|
|
@@ -35,17 +35,25 @@ tokenizer = None
|
|
| 35 |
device = None
|
| 36 |
|
| 37 |
def load_model():
|
| 38 |
-
"""Load MarianMT model once and cache it"""
|
| 39 |
global model, tokenizer, device
|
| 40 |
|
| 41 |
if model is None:
|
| 42 |
print("Loading MarianMT model...")
|
| 43 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
return model, tokenizer, device
|
| 51 |
|
|
@@ -57,6 +65,7 @@ def translate_marian(text):
|
|
| 57 |
try:
|
| 58 |
model, tokenizer, device = load_model()
|
| 59 |
|
|
|
|
| 60 |
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
| 61 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 62 |
|
|
@@ -74,7 +83,7 @@ def translate_marian(text):
|
|
| 74 |
translations = []
|
| 75 |
for output in outputs:
|
| 76 |
trans = tokenizer.decode(output, skip_special_tokens=True)
|
| 77 |
-
if trans not in translations: # Avoid duplicates
|
| 78 |
translations.append(trans)
|
| 79 |
|
| 80 |
return translations if translations else ["[Error: No translation generated]"]
|
|
@@ -83,7 +92,7 @@ def translate_marian(text):
|
|
| 83 |
print(f"MarianMT translation error: {e}")
|
| 84 |
import traceback
|
| 85 |
traceback.print_exc()
|
| 86 |
-
return [f"[Error: {str(e)}"]
|
| 87 |
|
| 88 |
def translate_libre_variant(text, variant_code):
|
| 89 |
"""Translate using a specific LibreTranslate variant"""
|
|
@@ -597,28 +606,20 @@ HTML_TEMPLATE = """
|
|
| 597 |
let selectedText = '';
|
| 598 |
|
| 599 |
function selectTranslation(element, text) {
|
| 600 |
-
// Remove selected class from all items in this list
|
| 601 |
element.parentElement.querySelectorAll('.translation-item').forEach(item => {
|
| 602 |
item.classList.remove('selected');
|
| 603 |
});
|
| 604 |
-
// Add selected class to clicked item
|
| 605 |
element.classList.add('selected');
|
| 606 |
selectedText = text;
|
| 607 |
-
|
| 608 |
-
// Copy automatically on selection
|
| 609 |
copyText(text, false);
|
| 610 |
}
|
| 611 |
|
| 612 |
function selectVariant(element, text) {
|
| 613 |
-
// Remove selected class from all variants
|
| 614 |
document.querySelectorAll('.variant-item').forEach(item => {
|
| 615 |
item.classList.remove('selected');
|
| 616 |
});
|
| 617 |
-
// Add selected class to clicked item
|
| 618 |
element.classList.add('selected');
|
| 619 |
selectedText = text;
|
| 620 |
-
|
| 621 |
-
// Copy automatically on selection
|
| 622 |
copyText(text, false);
|
| 623 |
}
|
| 624 |
|
|
@@ -629,7 +630,6 @@ HTML_TEMPLATE = """
|
|
| 629 |
showToastMessage('Copied to clipboard!');
|
| 630 |
}
|
| 631 |
} catch (err) {
|
| 632 |
-
// Fallback
|
| 633 |
const textArea = document.createElement('textarea');
|
| 634 |
textArea.value = text;
|
| 635 |
document.body.appendChild(textArea);
|
|
@@ -670,10 +670,7 @@ def index():
|
|
| 670 |
if request.method == "POST":
|
| 671 |
source_text = request.form.get("text", "").strip()
|
| 672 |
if source_text:
|
| 673 |
-
# Get Marian translations
|
| 674 |
marian = translate_marian(source_text)
|
| 675 |
-
|
| 676 |
-
# Get Libre translations
|
| 677 |
libre_results = translate_libre_all_variants(source_text)
|
| 678 |
for name, data in libre_results.items():
|
| 679 |
data['code'] = KABYLE_VARIANTS[name]
|
|
@@ -688,13 +685,11 @@ def index():
|
|
| 688 |
|
| 689 |
@app.route("/health")
|
| 690 |
def health():
|
| 691 |
-
"""Health check endpoint"""
|
| 692 |
return jsonify({
|
| 693 |
"status": "healthy",
|
| 694 |
"model_loaded": model is not None
|
| 695 |
})
|
| 696 |
|
| 697 |
if __name__ == "__main__":
|
| 698 |
-
# Hugging Face Spaces expects port 7860
|
| 699 |
port = int(os.environ.get("PORT", 7860))
|
| 700 |
app.run(host="0.0.0.0", port=port, debug=False)
|
|
|
|
| 12 |
import requests
|
| 13 |
import torch
|
| 14 |
from flask import Flask, request, render_template_string, jsonify
|
| 15 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 16 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 17 |
|
| 18 |
# Configuration
|
|
|
|
| 35 |
device = None
|
| 36 |
|
| 37 |
def load_model():
|
| 38 |
+
"""Load MarianMT model once and cache it using Auto classes"""
|
| 39 |
global model, tokenizer, device
|
| 40 |
|
| 41 |
if model is None:
|
| 42 |
print("Loading MarianMT model...")
|
| 43 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 44 |
|
| 45 |
+
try:
|
| 46 |
+
# Try Auto classes first (more flexible)
|
| 47 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
|
| 48 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device).eval()
|
| 49 |
+
print(f"Model loaded successfully on {device} using Auto classes")
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Auto classes failed: {e}")
|
| 52 |
+
print("Trying legacy Marian classes...")
|
| 53 |
+
from transformers import MarianMTModel, MarianTokenizer
|
| 54 |
+
tokenizer = MarianTokenizer.from_pretrained(MODEL_ID, use_fast=False)
|
| 55 |
+
model = MarianMTModel.from_pretrained(MODEL_ID).to(device).eval()
|
| 56 |
+
print(f"Model loaded on {device} using legacy classes")
|
| 57 |
|
| 58 |
return model, tokenizer, device
|
| 59 |
|
|
|
|
| 65 |
try:
|
| 66 |
model, tokenizer, device = load_model()
|
| 67 |
|
| 68 |
+
# Prepare inputs
|
| 69 |
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
| 70 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 71 |
|
|
|
|
| 83 |
translations = []
|
| 84 |
for output in outputs:
|
| 85 |
trans = tokenizer.decode(output, skip_special_tokens=True)
|
| 86 |
+
if trans and trans not in translations: # Avoid duplicates and empty
|
| 87 |
translations.append(trans)
|
| 88 |
|
| 89 |
return translations if translations else ["[Error: No translation generated]"]
|
|
|
|
| 92 |
print(f"MarianMT translation error: {e}")
|
| 93 |
import traceback
|
| 94 |
traceback.print_exc()
|
| 95 |
+
return [f"[Error: {str(e)}]"]
|
| 96 |
|
| 97 |
def translate_libre_variant(text, variant_code):
|
| 98 |
"""Translate using a specific LibreTranslate variant"""
|
|
|
|
| 606 |
let selectedText = '';
|
| 607 |
|
| 608 |
function selectTranslation(element, text) {
|
|
|
|
| 609 |
element.parentElement.querySelectorAll('.translation-item').forEach(item => {
|
| 610 |
item.classList.remove('selected');
|
| 611 |
});
|
|
|
|
| 612 |
element.classList.add('selected');
|
| 613 |
selectedText = text;
|
|
|
|
|
|
|
| 614 |
copyText(text, false);
|
| 615 |
}
|
| 616 |
|
| 617 |
function selectVariant(element, text) {
|
|
|
|
| 618 |
document.querySelectorAll('.variant-item').forEach(item => {
|
| 619 |
item.classList.remove('selected');
|
| 620 |
});
|
|
|
|
| 621 |
element.classList.add('selected');
|
| 622 |
selectedText = text;
|
|
|
|
|
|
|
| 623 |
copyText(text, false);
|
| 624 |
}
|
| 625 |
|
|
|
|
| 630 |
showToastMessage('Copied to clipboard!');
|
| 631 |
}
|
| 632 |
} catch (err) {
|
|
|
|
| 633 |
const textArea = document.createElement('textarea');
|
| 634 |
textArea.value = text;
|
| 635 |
document.body.appendChild(textArea);
|
|
|
|
| 670 |
if request.method == "POST":
|
| 671 |
source_text = request.form.get("text", "").strip()
|
| 672 |
if source_text:
|
|
|
|
| 673 |
marian = translate_marian(source_text)
|
|
|
|
|
|
|
| 674 |
libre_results = translate_libre_all_variants(source_text)
|
| 675 |
for name, data in libre_results.items():
|
| 676 |
data['code'] = KABYLE_VARIANTS[name]
|
|
|
|
| 685 |
|
| 686 |
@app.route("/health")
|
| 687 |
def health():
|
|
|
|
| 688 |
return jsonify({
|
| 689 |
"status": "healthy",
|
| 690 |
"model_loaded": model is not None
|
| 691 |
})
|
| 692 |
|
| 693 |
if __name__ == "__main__":
|
|
|
|
| 694 |
port = int(os.environ.get("PORT", 7860))
|
| 695 |
app.run(host="0.0.0.0", port=port, debug=False)
|