# lid / app.py — Hammad712's Hugging Face Space ("Update app.py", commit ab82fde, verified)
# Standard library
import logging
import os
import subprocess
import tempfile

# Third-party
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Setup Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("LID_Engine")
app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")
# CORS Fix
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec; browsers will reject credentialed wildcard
# responses. Fine for a public, credential-free API — confirm intent.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Use Absolute Path (Success from previous log!)
# Model is baked into the container image at build time; loaded once at
# import so every request shares the same CPU-only inference session.
MODEL_PATH = "/app/local_model/pakistani_lid_v3.onnx"
logger.info(f"🚀 Loading pre-baked ONNX model from: {MODEL_PATH}")
try:
session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
logger.info("✅ Engine is LIVE and Ready!")
except Exception as e:
logger.error(f"❌ Failed to load model: {e}")
# Re-raise so the process (and container) fails fast instead of serving
# requests with no model.
raise e
# Class index -> language name, in the order of the model's output head.
labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = dict(enumerate(labels))
def predict_audio(input_path):
    """Identify the spoken language in an audio file of any common format.

    The file is first transcoded with FFmpeg to a standard 16 kHz mono WAV
    (this handles WebM/OGG/etc. that soundfile cannot read directly), then
    normalized, truncated/padded to a fixed 15-second window, and passed to
    the module-level ONNX ``session``.

    Args:
        input_path: Path to the audio file to classify.

    Returns:
        Tuple ``(label, confidence)``: predicted language name (lowercase)
        and its softmax probability as a float in [0, 1].

    Raises:
        subprocess.CalledProcessError: if FFmpeg cannot decode the input.
    """
    # Unique temp file per call so concurrent requests cannot clobber each
    # other (the previous fixed "cleaned_audio.wav" name was race-prone).
    fd, clean_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # Transcode ANY container/codec to standard 16 kHz mono WAV.
        subprocess.run(
            ['ffmpeg', '-y', '-i', input_path,
             '-ar', '16000', '-ac', '1', clean_wav_path],
            check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        )
        data, sr = sf.read(clean_wav_path)
        waveform = torch.from_numpy(data).float()
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # -> (1, samples)

        # Keep at most 15 seconds at 16 kHz.
        target_frames = 16000 * 15
        if waveform.shape[1] > target_frames:
            waveform = waveform[:, :target_frames]

        # Peak-normalize, subtract the (pre-normalization) mean, then scale
        # to unit variance; clamps guard against divide-by-zero on silence.
        waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
        waveform = waveform / waveform.std().clamp(min=1e-6)

        # Right-pad to the fixed length; attention mask marks real frames.
        length = waveform.shape[1]
        mask = torch.zeros(target_frames, dtype=torch.long)
        if length < target_frames:
            mask[:length] = 1
            waveform = F.pad(waveform, (0, target_frames - length))
        else:
            mask[:] = 1

        # ONNX Inference
        ort_inputs = {
            "input_values": waveform.numpy(),
            "attention_mask": mask.unsqueeze(0).numpy(),
        }
        logits = session.run(None, ort_inputs)[0]

        # Numerically stable softmax: shift by the row max before exp (the
        # previous unshifted exp could overflow for large logits).
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=1, keepdims=True)
        pred_id = int(np.argmax(probs, axis=1)[0])
        return id2label[pred_id], float(probs[0][pred_id])
    finally:
        # Single cleanup path for success and failure alike.
        if os.path.exists(clean_wav_path):
            os.remove(clean_wav_path)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Accept an uploaded audio file and return the detected language.

    Response JSON (shape kept for existing clients):
      success: ``{"success": True, "language": <UPPER name>, "confidence": <pct>}``
      failure: ``{"success": False, "error": <message>}``
    """
    # Never build a path from the client-supplied filename: it is untrusted
    # and may contain separators/".." (path traversal) or collide across
    # concurrent uploads. Use a unique server-side temp file; keep only the
    # extension (splitext output cannot contain path separators).
    suffix = os.path.splitext(file.filename or "")[1]
    fd, temp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        lang, conf = predict_audio(temp_path)
        return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Inference Error")
        return {"success": False, "error": str(e)}
    finally:
        # Single cleanup path for success and failure alike.
        if os.path.exists(temp_path):
            os.remove(temp_path)
@app.get("/")
def health():
    """Liveness probe: report that the service is reachable."""
    payload = {"status": "online"}
    return payload