Hammad712 committed on
Commit
ab82fde
Β·
verified Β·
1 Parent(s): 5599463

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -45
app.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn.functional as F
6
  import numpy as np
7
  import onnxruntime as ort
8
  import soundfile as sf
 
9
  from fastapi import FastAPI, UploadFile, File, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
@@ -24,62 +25,69 @@ app.add_middleware(
24
  allow_headers=["*"],
25
  )
26
 
27
- # 🚨 Use Absolute Path
28
  MODEL_PATH = "/app/local_model/pakistani_lid_v3.onnx"
29
 
30
- logger.info(f"πŸš€ Attempting to load model from: {MODEL_PATH}")
31
-
32
  try:
33
- if not os.path.exists(MODEL_PATH):
34
- # List files for debugging in logs if it fails
35
- logger.error(f"Files in /app/local_model: {os.listdir('/app/local_model') if os.path.exists('/app/local_model') else 'Dir not found'}")
36
- raise FileNotFoundError(f"Model file missing at {MODEL_PATH}")
37
-
38
- # Load with mmap to save RAM
39
- session_options = ort.SessionOptions()
40
- session = ort.InferenceSession(MODEL_PATH, sess_options=session_options, providers=['CPUExecutionProvider'])
41
  logger.info("βœ… Engine is LIVE and Ready!")
42
  except Exception as e:
43
- logger.error(f"❌ Initialization Error: {e}")
44
  raise e
45
 
46
  labels = ("balochi", "english", "pashto", "sindhi", "urdu")
47
  id2label = {i: label for i, label in enumerate(labels)}
48
 
49
- def predict_audio(audio_path):
50
- data, sr = sf.read(audio_path)
51
- waveform = torch.from_numpy(data).float()
52
- if waveform.ndim == 2:
53
- waveform = waveform.T.mean(dim=0, keepdim=True)
54
- else:
55
- waveform = waveform.unsqueeze(0)
56
-
57
- if sr != 16000:
58
- waveform = torchaudio.functional.resample(waveform, sr, 16000)
59
-
60
- target_frames = 16000 * 15
61
- waveform = waveform[:, :target_frames]
62
- waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
63
- waveform = waveform / waveform.std().clamp(min=1e-6)
64
-
65
- length = waveform.shape[1]
66
- mask = torch.zeros(target_frames, dtype=torch.long)
67
- if length < target_frames:
68
- mask[:length] = 1
69
- waveform = F.pad(waveform, (0, target_frames - length))
70
- else:
71
- mask[:] = 1
72
 
73
- ort_inputs = {
74
- "input_values": waveform.numpy(),
75
- "attention_mask": mask.unsqueeze(0).numpy()
76
- }
77
-
78
- logits = session.run(None, ort_inputs)[0]
79
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
80
- pred_id = np.argmax(probs, axis=1)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- return id2label[pred_id], float(probs[0][pred_id])
 
 
83
 
84
  @app.post("/predict")
85
  async def predict(file: UploadFile = File(...)):
@@ -87,9 +95,12 @@ async def predict(file: UploadFile = File(...)):
87
  try:
88
  with open(temp_path, "wb") as f:
89
  f.write(await file.read())
 
90
  lang, conf = predict_audio(temp_path)
91
- os.remove(temp_path)
 
92
  return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
 
93
  except Exception as e:
94
  logger.error(f"Inference Error: {e}")
95
  if os.path.exists(temp_path): os.remove(temp_path)
 
6
  import numpy as np
7
  import onnxruntime as ort
8
  import soundfile as sf
9
+ import subprocess
10
  from fastapi import FastAPI, UploadFile, File, HTTPException
11
  from fastapi.middleware.cors import CORSMiddleware
12
 
 
25
  allow_headers=["*"],
26
  )
27
 
28
# Absolute path to the ONNX model baked into the container image.
MODEL_PATH = "/app/local_model/pakistani_lid_v3.onnx"

logger.info(f"🚀 Loading pre-baked ONNX model from: {MODEL_PATH}")
try:
    # CPU-only inference session; the model file must exist at build time.
    session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
    logger.info("✅ Engine is LIVE and Ready!")
except Exception as e:
    # Fail fast at startup — a half-initialized service is worse than a crash.
    logger.error(f"❌ Failed to load model: {e}")
    raise e

# Index → language-name mapping; order must match the training label order.
labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = dict(enumerate(labels))
41
 
42
def predict_audio(input_path):
    """Identify the spoken language in an uploaded audio file.

    Args:
        input_path: Path to an audio file in any FFmpeg-readable format
            (WebM, OGG, WAV, ...).

    Returns:
        Tuple of (language_label, confidence) where language_label comes
        from ``id2label`` and confidence is the softmax probability (float).

    Raises:
        subprocess.CalledProcessError: if FFmpeg cannot decode the input.
        Exception: any error from reading, preprocessing, or inference.
    """
    import tempfile  # local import keeps this fix self-contained

    # Unique temp file per call — the previous fixed name
    # ("cleaned_audio.wav") raced when concurrent requests hit the endpoint.
    fd, clean_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # Use FFmpeg to convert ANY format (WebM, OGG, etc.) to standard
        # 16 kHz mono WAV — handles the "Format not recognised" error
        # soundfile raises on browser-recorded uploads.
        subprocess.run([
            'ffmpeg', '-y', '-i', input_path,
            '-ar', '16000', '-ac', '1', clean_wav_path
        ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        # Now read the standard WAV (already mono @ 16 kHz thanks to -ac/-ar).
        data, sr = sf.read(clean_wav_path)

        waveform = torch.from_numpy(data).float()
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # -> (1, num_samples)

        # Truncate to at most 15 seconds at 16 kHz.
        target_frames = 16000 * 15
        if waveform.shape[1] > target_frames:
            waveform = waveform[:, :target_frames]

        # Peak-normalize, remove DC offset, then scale to unit variance.
        waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
        waveform = waveform / waveform.std().clamp(min=1e-6)

        # Right-pad short clips and build the matching attention mask.
        length = waveform.shape[1]
        mask = torch.zeros(target_frames, dtype=torch.long)
        if length < target_frames:
            mask[:length] = 1
            waveform = F.pad(waveform, (0, target_frames - length))
        else:
            mask[:] = 1

        # ONNX inference.
        ort_inputs = {
            "input_values": waveform.numpy(),
            "attention_mask": mask.unsqueeze(0).numpy()
        }

        logits = session.run(None, ort_inputs)[0]
        # Numerically stable softmax: subtracting the per-row max avoids
        # overflow in np.exp for large logits without changing the result.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=1, keepdims=True)
        pred_id = np.argmax(probs, axis=1)[0]

        return id2label[pred_id], float(probs[0][pred_id])
    finally:
        # Single cleanup path replaces the duplicated remove-on-success /
        # remove-on-error pair (`except Exception as e: ... raise e`).
        if os.path.exists(clean_wav_path):
            os.remove(clean_wav_path)
91
 
92
  @app.post("/predict")
93
  async def predict(file: UploadFile = File(...)):
 
95
  try:
96
  with open(temp_path, "wb") as f:
97
  f.write(await file.read())
98
+
99
  lang, conf = predict_audio(temp_path)
100
+
101
+ if os.path.exists(temp_path): os.remove(temp_path)
102
  return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
103
+
104
  except Exception as e:
105
  logger.error(f"Inference Error: {e}")
106
  if os.path.exists(temp_path): os.remove(temp_path)