import logging
import os
import subprocess
import tempfile

import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
|
|
| |
# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("LID_Engine")


# --- FastAPI app -------------------------------------------------------------
app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")

# Wide-open CORS: service is presumably fronted by a trusted gateway.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Model -------------------------------------------------------------------
# Path to the ONNX model baked into the container image at build time.
MODEL_PATH = "/app/local_model/pakistani_lid_v3.onnx"

logger.info("Loading pre-baked ONNX model from: %s", MODEL_PATH)
try:
    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    logger.info("Engine is LIVE and Ready!")
except Exception:
    # Fail fast at import time: the service is useless without the model.
    logger.exception("Failed to load model")
    raise


# Index -> language label, in the order the classifier head was trained with.
labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = {i: label for i, label in enumerate(labels)}
|
|
def predict_audio(input_path):
    """Run language identification on one audio file.

    The input is transcoded to 16 kHz mono WAV with ffmpeg, peak- and
    variance-normalized, padded/truncated to a fixed 15 s window, and
    fed to the global ONNX ``session``.

    Parameters
    ----------
    input_path : str
        Path to an audio file in any container/codec ffmpeg can decode.

    Returns
    -------
    tuple[str, float]
        (predicted language label, softmax confidence in [0, 1]).

    Raises
    ------
    subprocess.CalledProcessError
        If ffmpeg fails to transcode the input.
    """
    # Unique temp file per call: the previous fixed "cleaned_audio.wav"
    # name raced between concurrent requests.
    fd, clean_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # Transcode to 16 kHz mono PCM — the format the model expects.
        subprocess.run(
            ["ffmpeg", "-y", "-i", input_path, "-ar", "16000", "-ac", "1", clean_wav_path],
            check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        )

        data, sr = sf.read(clean_wav_path)
        waveform = torch.from_numpy(data).float()
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # -> (1, samples)

        # Cap the input at 15 seconds of 16 kHz audio.
        target_frames = 16000 * 15
        if waveform.shape[1] > target_frames:
            waveform = waveform[:, :target_frames]

        # Peak-normalize, subtract the (original) mean, scale to unit variance.
        # Kept exactly as trained-with preprocessing — do not "fix" the order.
        waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
        waveform = waveform / waveform.std().clamp(min=1e-6)

        # Attention mask: 1 over real samples, 0 over right padding.
        length = waveform.shape[1]
        mask = torch.zeros(target_frames, dtype=torch.long)
        if length < target_frames:
            mask[:length] = 1
            waveform = F.pad(waveform, (0, target_frames - length))
        else:
            mask[:] = 1

        ort_inputs = {
            "input_values": waveform.numpy(),
            "attention_mask": mask.unsqueeze(0).numpy(),
        }
        logits = session.run(None, ort_inputs)[0]

        # Softmax over the label axis.
        probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
        pred_id = np.argmax(probs, axis=1)[0]
        return id2label[pred_id], float(probs[0][pred_id])
    finally:
        # Always remove the transcoded file, on success and on error.
        if os.path.exists(clean_wav_path):
            os.remove(clean_wav_path)
|
|
| @app.post("/predict") |
| async def predict(file: UploadFile = File(...)): |
| temp_path = f"temp_{file.filename}" |
| try: |
| with open(temp_path, "wb") as f: |
| f.write(await file.read()) |
| |
| lang, conf = predict_audio(temp_path) |
| |
| if os.path.exists(temp_path): os.remove(temp_path) |
| return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)} |
| |
| except Exception as e: |
| logger.error(f"Inference Error: {e}") |
| if os.path.exists(temp_path): os.remove(temp_path) |
| return {"success": False, "error": str(e)} |
|
|
| @app.get("/") |
| def health(): |
| return {"status": "online"} |