import os
import base64
import requests
import torch
import gradio as gr
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

# =====================================================
# CONFIG
# =====================================================
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DOC_FILE = "general.md"
TTS_API_URL = "https://rahul7star-Chatterbox-Multilingual-TTS-API.hf.space/tts"
MAX_NEW_TOKENS = 200
TOP_K = 3

# =====================================================
# LOAD DOCUMENT
# =====================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH = os.path.join(BASE_DIR, DOC_FILE)

with open(DOC_PATH, "r", encoding="utf-8", errors="ignore") as f:
    DOC_TEXT = f.read()

# =====================================================
# CHUNK + EMBED
# =====================================================
def chunk_text(text, chunk_size=300, overlap=50):
    """Split text into overlapping word-based chunks."""
    words = text.split()
    chunks, i = [], 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + chunk_size]))
        i += chunk_size - overlap
    return chunks

DOC_CHUNKS = chunk_text(DOC_TEXT)
embedder = SentenceTransformer("all-MiniLM-L6-v2")
DOC_EMBEDS = embedder.encode(DOC_CHUNKS, normalize_embeddings=True)

# =====================================================
# LOAD QWEN
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
model.eval()

# =====================================================
# RETRIEVAL
# =====================================================
def retrieve_context(question):
    # Embeddings are normalized, so the dot product is cosine similarity.
    q_emb = embedder.encode([question], normalize_embeddings=True)
    scores = np.dot(DOC_EMBEDS, q_emb[0])
    top_ids = scores.argsort()[-TOP_K:][::-1]
    return "\n\n".join(DOC_CHUNKS[i] for i in top_ids)

# =====================================================
# QWEN ANSWER
# =====================================================
def answer_question(question):
    context = retrieve_context(question)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Do NOT repeat context.\n"
                "Respond in 1 sentence.\n"
                "If not found, say:\n"
                "'I could not find this information in the document.'"
            ),
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Decode only the newly generated tokens, not the echoed prompt;
    # splitting the full decoded string on newlines was fragile.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# =====================================================
# TTS (BASE64 → WAV)
# =====================================================
def generate_audio(text: str):
    payload = {
        "text": text,
        "language_id": "en",
        "mode": "Speak 🗣️",
    }
    r = requests.post(TTS_API_URL, json=payload, timeout=None)

    # ---- Case 1: API returns raw WAV bytes ----
    if r.headers.get("content-type", "").startswith("audio"):
        wav_path = "/tmp/output.wav"
        with open(wav_path, "wb") as f:
            f.write(r.content)
        return wav_path

    # ---- Case 2: API returns JSON ----
    data = r.json()

    # Try known keys safely
    audio_b64 = (
        data.get("audio")
        or data.get("audio_base64")
        or data.get("wav")
    )
    if audio_b64:
        wav_bytes = base64.b64decode(audio_b64)
        wav_path = "/tmp/output.wav"
        with open(wav_path, "wb") as f:
            f.write(wav_bytes)
        return wav_path

    # ---- Case 3: API returns a file path ----
    if "path" in data and os.path.exists(data["path"]):
        return data["path"]

    # ---- Otherwise ----
    raise RuntimeError(f"TTS API response invalid: {data}")

# =====================================================
# MAIN HANDLER
# =====================================================
def run_pipeline(question):
    if not question.strip():
        return "", None

    # 1️⃣ TEXT FIRST
    answer = answer_question(question)

    # 2️⃣ AUDIO (SLOW, NO TIMEOUT)
    audio_path = generate_audio(answer)

    return answer, audio_path

# =====================================================
# UI
# =====================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                label="Your Question",
                placeholder="Who is CEO of OhamLab?",
                lines=4,
            )
            ask_btn = gr.Button("Ask")

        with gr.Column(scale=1):
            answer_text = gr.Markdown(
                label="Assistant Answer",
                value="**Bot:** _Waiting for question..._",
            )
            answer_audio = gr.Audio(
                label="Assistant Voice",
                type="filepath",
            )

    ask_btn.click(
        fn=run_pipeline,
        inputs=user_input,
        outputs=[answer_text, answer_audio],
    )

demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
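# ---- Usage sketch (hypothetical, for local testing only) ----
# demo.launch() above blocks, so a direct call like the one below would go
# before it (or launch would be wrapped in `if __name__ == "__main__":`).
# The question string is just the placeholder example from the UI:
#
#     answer, audio_path = run_pipeline("Who is CEO of OhamLab?")
#     print(answer)        # one-sentence answer grounded in general.md
#     print(audio_path)    # e.g. /tmp/output.wav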