# NOTE: Hugging Face Spaces status banner ("Spaces: / Paused") — scrape artifact, not part of the program.
| import os | |
| import base64 | |
| import uuid | |
| import requests | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| from functools import lru_cache | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from sentence_transformers import SentenceTransformer | |
# =====================================================
# CONFIG
# =====================================================
# Hugging Face model id for the answer generator (small instruct model, run on CPU below).
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
# Knowledge-base documents; resolved relative to this file in the loading section.
DOC_FILE_EN = "general.md"
DOC_FILE_HI = "general-hi.md"
# External text-to-speech endpoint. NOTE(review): os.getenv returns None when the
# env var is unset, in which case generate_audio's POST will fail — verify the
# Space always sets TTS_API_URL.
TTS_API_URL = os.getenv(
    "TTS_API_URL"
)
# Cap on tokens generated per answer.
MAX_NEW_TOKENS = 128
# Number of document chunks retrieved per question.
TOP_K = 3
# Shared HTTP session for all TTS calls (reuses connections).
SESSION = requests.Session()
# =====================================================
# LOAD DOCUMENTS
# =====================================================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DOC_PATH_EN = os.path.join(BASE_DIR, DOC_FILE_EN)
DOC_PATH_HI = os.path.join(BASE_DIR, DOC_FILE_HI)

# Fail fast at startup if either source document is absent.
for _doc_path, _doc_name in ((DOC_PATH_EN, DOC_FILE_EN), (DOC_PATH_HI, DOC_FILE_HI)):
    if not os.path.exists(_doc_path):
        raise RuntimeError(f"{_doc_name} not found")


def _read_doc(path):
    """Return the full contents of a UTF-8 text file, skipping undecodable bytes."""
    with open(path, "r", encoding="utf-8", errors="ignore") as fh:
        return fh.read()


DOC_TEXT_EN = _read_doc(DOC_PATH_EN)
DOC_TEXT_HI = _read_doc(DOC_PATH_HI)
| # ===================================================== | |
| # CHUNK + EMBED | |
| # ===================================================== | |
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into word-window chunks of up to *chunk_size* words.

    Consecutive chunks share *overlap* words so retrieval does not lose
    sentences cut at a chunk boundary.

    Args:
        text: Source text; whitespace-split into words. Empty text -> [].
        chunk_size: Maximum words per chunk; must be positive.
        overlap: Words shared between consecutive chunks; must be < chunk_size.

    Returns:
        List of space-joined chunk strings.

    Raises:
        ValueError: On non-positive chunk_size or overlap >= chunk_size
            (the original stepped by chunk_size - overlap and would loop
            forever when that step was <= 0).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    # Same windows as the original while-loop: start indices 0, step, 2*step, ...
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
# Pre-chunk both documents once at startup.
DOC_CHUNKS_EN = chunk_text(DOC_TEXT_EN)
DOC_CHUNKS_HI = chunk_text(DOC_TEXT_HI)
# Sentence encoder pinned to CPU.
# NOTE(review): all-MiniLM-L6-v2 is primarily an English model — verify
# retrieval quality on the Hindi document.
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
# Unit-normalized embeddings, so a plain dot product equals cosine similarity
# (retrieve_context relies on this).
DOC_EMBEDS_EN = embedder.encode(DOC_CHUNKS_EN, normalize_embeddings=True, batch_size=32)
DOC_EMBEDS_HI = embedder.encode(DOC_CHUNKS_HI, normalize_embeddings=True, batch_size=32)
# =====================================================
# LOAD QWEN MODEL (CPU only)
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",           # force CPU placement (no GPU in this Space)
    torch_dtype=torch.float32,  # full precision for CPU inference
    trust_remote_code=True
)
model.eval()  # inference mode (disables dropout etc.)
| # ===================================================== | |
| # RETRIEVAL WITH CACHE | |
| # ===================================================== | |
@lru_cache(maxsize=256)
def retrieve_context(question: str, lang: str) -> str:
    """Return the TOP_K most similar document chunks for *question*.

    The section is labelled "RETRIEVAL WITH CACHE" and lru_cache is imported
    at the top of the file but was never applied; a bounded cache now skips
    re-embedding repeated questions (args are hashable strings).

    Args:
        question: User question to embed.
        lang: "hi" selects the Hindi corpus; anything else selects English.

    Returns:
        The TOP_K chunks, highest similarity first, joined by blank lines.
    """
    q_emb = embedder.encode([question], normalize_embeddings=True)[0]
    # Both corpora share identical scoring logic; pick the corpus once.
    if lang == "hi":
        embeds, chunks = DOC_EMBEDS_HI, DOC_CHUNKS_HI
    else:
        embeds, chunks = DOC_EMBEDS_EN, DOC_CHUNKS_EN
    # Embeddings are unit-normalized, so dot product == cosine similarity.
    scores = np.dot(embeds, q_emb)
    top_ids = scores.argsort()[-TOP_K:][::-1]
    return "\n\n".join(chunks[i] for i in top_ids)
| # ===================================================== | |
| # QWEN ANSWER (CPU optimized) | |
| # ===================================================== | |
def answer_question(question: str, lang: str = "en") -> str:
    """Generate a short, document-grounded answer to *question*.

    Args:
        question: User question.
        lang: Corpus selector passed to retrieve_context ("en" or "hi").

    Returns:
        The model's generated answer, stripped of surrounding whitespace.
    """
    context = retrieve_context(question, lang)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict document-based Q&A assistant.\n"
                "Answer ONLY the question.\n"
                "Respond in 1 short sentence.\n"
                "If not found, say:\n"
                "'I could not find this information in the document.'"
            )
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}"}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy decoding: deterministic answers
            use_cache=True
        )
    # BUG FIX: the previous code decoded the entire sequence (prompt included)
    # and kept only the last line, which returns prompt text whenever the model
    # emits an empty or multi-line answer. Decode only the newly generated
    # tokens after the prompt instead.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
| # ===================================================== | |
| # TTS (CPU safe, flexible language) | |
| # ===================================================== | |
def generate_audio(text: str, language_id: str = "en") -> str:
    """Synthesize *text* to a WAV file via the external TTS API.

    Handles two response shapes: raw audio bytes (audio/* content type) or
    JSON carrying base64 audio under "audio", "audio_base64", or "wav".

    Args:
        text: Text to speak.
        language_id: TTS language code (e.g. "en", "hi").

    Returns:
        Path of the written WAV file under /tmp.

    Raises:
        RuntimeError: If TTS_API_URL is unset, the API returns no audio,
            or the written file is implausibly small.
        requests.HTTPError: On a non-2xx API response.
    """
    # Fail with a clear message instead of SESSION.post(None) blowing up.
    if not TTS_API_URL:
        raise RuntimeError("TTS_API_URL environment variable is not set")
    payload = {"text": text, "language_id": language_id, "mode": "Speak 🗣️"}
    # BUG FIX: timeout=None could hang the worker forever on a stuck TTS
    # backend; allow a generous but finite (connect, read) timeout.
    r = SESSION.post(TTS_API_URL, json=payload, timeout=(10, 300))
    r.raise_for_status()
    wav_path = f"/tmp/tts_{uuid.uuid4().hex}.wav"
    if r.headers.get("content-type", "").startswith("audio"):
        # Raw audio bytes in the response body.
        audio_bytes = r.content
    else:
        # JSON with base64-encoded audio.
        data = r.json()
        audio_b64 = data.get("audio") or data.get("audio_base64") or data.get("wav")
        if not audio_b64:
            raise RuntimeError(f"TTS API returned no audio: {data}")
        audio_bytes = base64.b64decode(audio_b64)
    with open(wav_path, "wb") as f:
        f.write(audio_bytes)
    # Sanity check now covers BOTH branches (previously only the JSON path).
    if os.path.getsize(wav_path) < 1000:
        raise RuntimeError("Generated audio file too small")
    return wav_path
| # ===================================================== | |
| # MAIN PIPELINE | |
| # ===================================================== | |
def run_pipeline(question: str, language_id: str):
    """Produce the bot's markdown answer and (best-effort) spoken audio.

    Returns a (markdown_text, audio_path_or_None) pair; blank questions
    yield ("", None) without touching the model or TTS service.
    """
    if not question.strip():
        return "", None

    # Text answer first — this is the part that must succeed.
    reply = answer_question(question, language_id)

    # Speech is optional: any TTS failure is logged and the answer is
    # returned without audio rather than failing the whole request.
    audio_path = None
    try:
        audio_path = generate_audio(reply, language_id)
    except Exception as exc:
        print("TTS failed:", exc)

    return f"**Bot:** {reply}", audio_path
| # ===================================================== | |
| # GRADIO UI | |
| # ===================================================== | |
# Layout: left column = inputs + canned example questions; right column =
# answer text and audio player. Context-manager nesting defines placement.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                label="Your Question",
                placeholder="Who is CEO of OhamLab?",
                lines=3
            )
            language_dropdown = gr.Dropdown(
                label="TTS Language",
                choices=["en", "hi"],
                value="en"
            )
            ask_btn = gr.Button("Ask")
            # ===== Example questions =====
            # Each example button only pre-fills the textbox and language
            # dropdown; the user still has to press "Ask" to run the pipeline.
            gr.Markdown("### Example Questions")
            with gr.Tabs():
                with gr.Tab("English"):
                    gr.Button("Who is the CEO of OhamLab?").click(
                        lambda: ("Who is the CEO of OhamLab?", "en"),
                        outputs=[user_input, language_dropdown]
                    )
                    gr.Button("What does OhamLab AI do?").click(
                        lambda: ("What does OhamLab AI do?", "en"),
                        outputs=[user_input, language_dropdown]
                    )
                with gr.Tab("Hindi"):
                    gr.Button("ओहमलैब के सीईओ कौन हैं?").click(
                        lambda: ("ओहमलैब के सीईओ कौन हैं?", "hi"),
                        outputs=[user_input, language_dropdown]
                    )
                    gr.Button("ओहमलैब एआई क्या करता है?").click(
                        lambda: ("ओहमलैब एआई क्या करता है?", "hi"),
                        outputs=[user_input, language_dropdown]
                    )
        with gr.Column(scale=1):
            answer_text = gr.Markdown()
            answer_audio = gr.Audio(type="filepath")
    # ===== Ask button click =====
    ask_btn.click(
        fn=run_pipeline,
        inputs=[user_input, language_dropdown],
        outputs=[answer_text, answer_audio]
    )
demo.queue()  # enable long-running jobs for TTS
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)