# OhamLab-AI / app_qwen_tts_fast.py
import os
import base64
import uuid
import requests
import torch
import gradio as gr
import numpy as np
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
# =====================================================
# CONFIG
# =====================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DOC_FILE_EN = "general.md"
DOC_FILE_HI = "general-hi.md"
TTS_API_URL = os.getenv("TTS_API_URL")
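# TTS_API_URL must be supplied via the environment; a placeholder example
# (hypothetical endpoint, not the real service):
#   export TTS_API_URL="https://example.com/api/tts"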
MAX_NEW_TOKENS = 128
TOP_K = 3
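# A single shared Session reuses pooled TCP connections across TTS requests.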
SESSION = requests.Session()
# =====================================================
# LOAD DOCUMENTS
# =====================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH_EN = os.path.join(BASE_DIR, DOC_FILE_EN)
DOC_PATH_HI = os.path.join(BASE_DIR, DOC_FILE_HI)
for path, name in [(DOC_PATH_EN, DOC_FILE_EN), (DOC_PATH_HI, DOC_FILE_HI)]:
    if not os.path.exists(path):
        raise RuntimeError(f"{name} not found")
with open(DOC_PATH_EN, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT_EN = f.read()
with open(DOC_PATH_HI, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT_HI = f.read()
# =====================================================
# CHUNK + EMBED
# =====================================================
def chunk_text(text, chunk_size=300, overlap=50):
    """Split text into overlapping word windows for embedding."""
    words = text.split()
    chunks, i = [], 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + chunk_size]))
        i += chunk_size - overlap
    return chunks
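# Example: with chunk_size=300 and overlap=50 the window advances 250 words
# per step, so a 600-word document yields chunks at word offsets 0, 250, 500.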
DOC_CHUNKS_EN = chunk_text(DOC_TEXT_EN)
DOC_CHUNKS_HI = chunk_text(DOC_TEXT_HI)
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
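# Chunk embeddings are precomputed once at startup; each query later embeds
# only the question itself.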
DOC_EMBEDS_EN = embedder.encode(DOC_CHUNKS_EN, normalize_embeddings=True, batch_size=32)
DOC_EMBEDS_HI = embedder.encode(DOC_CHUNKS_HI, normalize_embeddings=True, batch_size=32)
# =====================================================
# LOAD QWEN MODEL (CPU only)
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True
)
model.eval()
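# Full-precision float32 weights take roughly 2 GB of RAM for this
# 0.5B-parameter model (0.5B params x 4 bytes each).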
# =====================================================
# RETRIEVAL WITH CACHE
# =====================================================
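# Results are memoized per exact (question, lang) pair, so a repeated
# question skips re-embedding and re-scoring entirely.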
@lru_cache(maxsize=256)
def retrieve_context(question: str, lang: str) -> str:
    q_emb = embedder.encode([question], normalize_embeddings=True)
    if lang == "hi":
        chunks, embeds = DOC_CHUNKS_HI, DOC_EMBEDS_HI
    else:
        chunks, embeds = DOC_CHUNKS_EN, DOC_EMBEDS_EN
    # With normalized embeddings, a plain dot product is cosine similarity.
    scores = np.dot(embeds, q_emb[0])
    top_ids = scores.argsort()[-TOP_K:][::-1]
    return "\n\n".join(chunks[i] for i in top_ids)
# =====================================================
# QWEN ANSWER (CPU optimized)
# =====================================================
def answer_question(question: str, lang: str = "en") -> str:
    context = retrieve_context(question, lang)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Respond in 1 short sentence.\n"
                "If not found, say:\n"
                "'I could not find this information in the document.'"
            )
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            use_cache=True
        )
    # Decode only the newly generated tokens so the echoed prompt never
    # leaks into the answer (more robust than splitting the full decode
    # on newlines and keeping the last line).
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
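# Usage sketch (assumes general.md actually answers the question):
#   answer_question("Who is the CEO of OhamLab?", "en")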
# =====================================================
# TTS (CPU safe, flexible language)
# =====================================================
@lru_cache(maxsize=128)
def generate_audio(text: str, language_id: str = "en") -> str:
    if not TTS_API_URL:
        raise RuntimeError("TTS_API_URL environment variable is not set")
    payload = {"text": text, "language_id": language_id, "mode": "Speak 🗣️"}
    r = SESSION.post(TTS_API_URL, json=payload, timeout=None)
    r.raise_for_status()
    wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
    if r.headers.get("content-type", "").startswith("audio"):
        # Raw audio bytes in the response body
        audio_bytes = r.content
    else:
        # JSON response carrying base64-encoded audio
        data = r.json()
        audio_b64 = data.get("audio") or data.get("audio_base64") or data.get("wav")
        if not audio_b64:
            raise RuntimeError(f"TTS API returned no audio: {data}")
        audio_bytes = base64.b64decode(audio_b64)
    with open(wav_path, "wb") as f:
        f.write(audio_bytes)
    if os.path.getsize(wav_path) < 1000:
        raise RuntimeError("Generated audio file too small")
    return wav_path
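# Caveat: lru_cache memoizes the returned /tmp path; if the OS later cleans
# that file, a cache hit hands back a stale path.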
# =====================================================
# MAIN PIPELINE
# =====================================================
def run_pipeline(question: str, language_id: str):
    if not question.strip():
        return "", None
    # 1️⃣ Answer text
    answer = answer_question(question, language_id)
    # 2️⃣ TTS (fall back to text-only output on failure)
    try:
        audio_path = generate_audio(answer, language_id)
    except Exception as e:
        print("TTS failed:", e)
        audio_path = None
    return f"**Bot:** {answer}", audio_path
# =====================================================
# GRADIO UI
# =====================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                label="Your Question",
                placeholder="Who is the CEO of OhamLab?",
                lines=3
            )
            language_dropdown = gr.Dropdown(
                label="TTS Language",
                choices=["en", "hi"],
                value="en"
            )
            ask_btn = gr.Button("Ask")

            # ===== Example questions =====
            gr.Markdown("### Example Questions")
            with gr.Tabs():
                with gr.Tab("English"):
                    gr.Button("Who is the CEO of OhamLab?").click(
                        lambda: ("Who is the CEO of OhamLab?", "en"),
                        outputs=[user_input, language_dropdown]
                    )
                    gr.Button("What does OhamLab AI do?").click(
                        lambda: ("What does OhamLab AI do?", "en"),
                        outputs=[user_input, language_dropdown]
                    )
                with gr.Tab("Hindi"):
                    gr.Button("ओहमलैब के सीईओ कौन हैं?").click(
                        lambda: ("ओहमलैब के सीईओ कौन हैं?", "hi"),
                        outputs=[user_input, language_dropdown]
                    )
                    gr.Button("ओहमलैब एआई क्या करता है?").click(
                        lambda: ("ओहमलैब एआई क्या करता है?", "hi"),
                        outputs=[user_input, language_dropdown]
                    )

        with gr.Column(scale=1):
            answer_text = gr.Markdown()
            answer_audio = gr.Audio(type="filepath")

    # ===== Ask button click =====
    ask_btn.click(
        fn=run_pipeline,
        inputs=[user_input, language_dropdown],
        outputs=[answer_text, answer_audio]
    )

demo.queue()  # enable long-running jobs for TTS
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)