boffire commited on
Commit
3252379
·
verified ·
1 Parent(s): f88d1f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -17
app.py CHANGED
@@ -35,25 +35,17 @@ tokenizer = None
35
  device = None
36
 
37
  def load_model():
38
- """Load MarianMT model once and cache it using Auto classes"""
39
  global model, tokenizer, device
40
 
41
  if model is None:
42
  print("Loading MarianMT model...")
43
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
 
45
- try:
46
- # Try Auto classes first (more flexible)
47
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
48
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device).eval()
49
- print(f"Model loaded successfully on {device} using Auto classes")
50
- except Exception as e:
51
- print(f"Auto classes failed: {e}")
52
- print("Trying legacy Marian classes...")
53
- from transformers import MarianMTModel, MarianTokenizer
54
- tokenizer = MarianTokenizer.from_pretrained(MODEL_ID, use_fast=False)
55
- model = MarianMTModel.from_pretrained(MODEL_ID).to(device).eval()
56
- print(f"Model loaded on {device} using legacy classes")
57
 
58
  return model, tokenizer, device
59
 
@@ -70,20 +62,20 @@ def translate_marian(text):
70
  inputs = {k: v.to(device) for k, v in inputs.items()}
71
 
72
  with torch.no_grad():
 
73
  outputs = model.generate(
74
  **inputs,
75
- num_beams=6,
76
- num_beam_groups=3,
77
- diversity_penalty=1.2,
78
  num_return_sequences=3,
79
  max_length=128,
80
  early_stopping=True,
 
81
  )
82
 
83
  translations = []
84
  for output in outputs:
85
  trans = tokenizer.decode(output, skip_special_tokens=True)
86
- if trans and trans not in translations: # Avoid duplicates and empty
87
  translations.append(trans)
88
 
89
  return translations if translations else ["[Error: No translation generated]"]
 
35
  device = None
36
 
37
  def load_model():
38
+ """Load MarianMT model once and cache it"""
39
  global model, tokenizer, device
40
 
41
  if model is None:
42
  print("Loading MarianMT model...")
43
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
 
45
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
46
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device).eval()
47
+
48
+ print(f"Model loaded successfully on {device}")
 
 
 
 
 
 
 
 
49
 
50
  return model, tokenizer, device
51
 
 
62
  inputs = {k: v.to(device) for k, v in inputs.items()}
63
 
64
  with torch.no_grad():
65
+ # Simple beam search without group beam search
66
  outputs = model.generate(
67
  **inputs,
68
+ num_beams=4,
 
 
69
  num_return_sequences=3,
70
  max_length=128,
71
  early_stopping=True,
72
+ do_sample=False,
73
  )
74
 
75
  translations = []
76
  for output in outputs:
77
  trans = tokenizer.decode(output, skip_special_tokens=True)
78
+ if trans and trans not in translations:
79
  translations.append(trans)
80
 
81
  return translations if translations else ["[Error: No translation generated]"]