Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import asyncio | |
| import logging | |
| import time | |
| from model_loader import engine | |
| from deepgram import DeepgramClient, PrerecordedOptions | |
| from huggingface_hub import snapshot_download | |
# --- Setup & Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("smart_turn_gradio")

# Hub repository that hosts the model. The snapshot is resolved when this
# module is imported, and the resulting local path is what load_model()
# later hands to the engine.
MODEL_REPO_ID = "Rishi2455/smart-turn-model"
local_model_path = snapshot_download(repo_id=MODEL_REPO_ID)

# The Deepgram client is only constructed when an API key is present in the
# environment; without one, dg_client stays None.
DEEPGRAM_API_KEY = os.environ.get("DEEPGRAM_API_KEY")
if DEEPGRAM_API_KEY:
    dg_client = DeepgramClient(DEEPGRAM_API_KEY)
else:
    dg_client = None
async def load_model():
    """Make sure the shared engine has the model loaded.

    Does nothing when ``engine.is_loaded`` is already truthy; otherwise
    awaits ``engine.load_model`` with the downloaded snapshot path.
    """
    if engine.is_loaded:
        return
    await engine.load_model(local_model_path)
# --- Helper: Stylish Result Component with Latency ---
def format_result_html(is_complete, confidence, latency, extra_info=""):
    """Render a prediction as a styled HTML card.

    Args:
        is_complete: Truthy when the utterance was classified as complete;
            selects the label text and the green/amber colour scheme.
        confidence: Score formatted with ``.2%`` (so ``0.5`` renders as
            ``50.00%``).
        latency: Elapsed time shown with one decimal and an ``ms`` suffix.
        extra_info: Optional plain-text footer (e.g. a transcript). It is
            HTML-escaped before insertion; a falsy value omits the footer.

    Returns:
        The card's HTML markup as a string.
    """
    from html import escape

    # NOTE(review): the label glyphs were mis-encoded in the file (UTF-8 bytes
    # shown through a Greek code page). They are restored here to the
    # check-mark / hourglass emoji those bytes most plausibly decode to --
    # confirm against the author's original.
    if is_complete:
        label = "COMPLETE ✅"
        color = "#10b981"
        bg_color = "rgba(16, 185, 129, 0.1)"
    else:
        label = "INCOMPLETE ⏳"
        color = "#f59e0b"
        bg_color = "rgba(245, 158, 11, 0.1)"

    # The footer text can originate from transcribed speech, so it is escaped
    # to keep stray markup from being interpreted by the browser.
    footer = ""
    if extra_info:
        footer = (
            '<div style="margin-top: 12px; padding-top: 12px; border-top: 1px solid rgba(255,255,255,0.1); '
            'color: #cbd5e1; font-style: italic; font-size: 0.9em;">'
            f"{escape(str(extra_info))}</div>"
        )

    return f"""
<div style="padding: 24px; border-radius: 12px; background-color: {bg_color}; border: 2px solid {color}; transition: all 0.3s ease; font-family: sans-serif;">
<div style="display: flex; justify-content: space-between; align-items: center;">
<h1 style="margin: 0; color: white; font-size: 2.2em; font-weight: 800; letter-spacing: -0.5px;">{label}</h1>
<div style="text-align: right;">
<p style="margin: 0; color: #94a3b8; font-size: 0.7em; font-weight: 600; letter-spacing: 1px;">CONFIDENCE</p>
<b style="color: {color}; font-size: 1.6em;">{confidence:.2%}</b>
<p style="margin: 4px 0 0 0; color: #64748b; font-size: 0.75em;">Latency: <b>{latency:.1f}ms</b></p>
</div>
</div>
{footer}
</div>
"""
# --- Prediction Logic ---
async def predict_text(text):
    """Classify typed text and return the result as an HTML fragment.

    Args:
        text: Contents of the input textbox; may be ``None`` or empty.

    Returns:
        A placeholder prompt when there is nothing to analyze, otherwise the
        card produced by ``format_result_html`` from the engine's
        ``is_complete`` / ``confidence`` result and the measured latency.
    """
    # Whitespace-only input carries no utterance; treat it like an empty box
    # instead of loading the model and spending an inference call on it.
    if not text or not text.strip():
        return "<div style='text-align: center; color: #64748b;'>Please enter some text.</div>"

    await load_model()

    # Only the predict call is timed, so model loading is not counted in the
    # reported latency.
    start_time = time.perf_counter()
    result = await engine.predict(text)
    latency = (time.perf_counter() - start_time) * 1000

    return format_result_html(result["is_complete"], result["confidence"], latency)
def _transcribe_file(audio_path):
    """Send one audio file to Deepgram and return its transcript text.

    Blocking: the file read and the REST call are both synchronous, so code
    running on the event loop should call this from a worker thread.

    Returns the transcript of the first alternative of the first channel, or
    ``""`` when the response carries no usable transcript.
    """
    with open(audio_path, "rb") as audio:
        source = {"buffer": audio.read()}
    options = PrerecordedOptions(model="nova-2", smart_format=True)
    response = dg_client.listen.rest.v("1").transcribe_file(source, options)
    try:
        return response.results.channels[0].alternatives[0].transcript or ""
    except (AttributeError, IndexError):
        # A response without channels/alternatives means nothing was
        # recognised; report it as "no speech" rather than crashing.
        logger.warning("Deepgram response contained no transcript for %s", audio_path)
        return ""


async def predict_audio(audio_path):
    """Transcribe an audio file, classify the transcript, return HTML.

    Args:
        audio_path: Filesystem path of the recorded/uploaded audio, or a
            falsy value when nothing was provided.

    Returns:
        A placeholder message when there is no audio, when Deepgram is not
        configured, or when no speech was transcribed; otherwise the card
        from ``format_result_html`` with the transcript as extra info.
    """
    if not audio_path:
        return "<div style='text-align: center; color: #64748b;'>Please record or upload audio.</div>"

    # dg_client is None when DEEPGRAM_API_KEY is unset; say so explicitly
    # instead of failing with an AttributeError on ``None.listen``.
    if dg_client is None:
        return "<div style='text-align: center; color: #64748b;'>Audio transcription is unavailable: DEEPGRAM_API_KEY is not set.</div>"

    await load_model()

    # Transcription Step
    # Run the blocking read + REST call in the default executor so other
    # requests are not stalled while Deepgram responds.
    loop = asyncio.get_running_loop()
    transcript = await loop.run_in_executor(None, _transcribe_file, audio_path)
    if not transcript:
        return "<div style='text-align: center; color: #64748b;'>No speech detected.</div>"

    # Model Inference Step
    # Latency covers only the predict call, not transcription.
    start_time = time.perf_counter()
    result = await engine.predict(transcript)
    latency = (time.perf_counter() - start_time) * 1000

    return format_result_html(result["is_complete"], result["confidence"], latency, extra_info=f"Transcript: \"{transcript}\"")
# --- UI Layout ---
# Builds the Blocks app: a centred column with a title, a link to the model
# repo, one tab per input mode, a shared HTML result area and text examples.
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# below is reconstructed -- confirm that Examples and the click handlers sit
# where the author intended.
# NOTE(review): the leading glyphs in the heading/tab labels look mis-encoded
# (probably emoji); they are left untouched here -- restore from the original.
with gr.Blocks(theme=gr.themes.Default(primary_hue="indigo"), css="#container {max-width: 800px; margin: auto; padding-top: 20px;}") as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# π€ Smart Turn - EOU Detection")
        # Added target="_blank" to ensure the link opens in a new tab
        gr.Markdown(f"Classification of End-of-Utterance (EOU) using: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank' style='color: #6366f1; text-decoration: none; font-weight: 600;'>{MODEL_REPO_ID}</a>")
        with gr.Tabs():
            with gr.Tab("π Text Prediction"):
                text_input = gr.Textbox(placeholder="Type a sentence...", label="Input Text", lines=3)
                text_btn = gr.Button("Analyze Text", variant="primary")
            with gr.Tab("ποΈ Audio Prediction"):
                # type="filepath" makes the handler receive a path on disk,
                # which predict_audio opens and reads.
                audio_input = gr.Audio(type="filepath", label="Record/Upload Audio")
                audio_btn = gr.Button("Transcribe & Analyze", variant="primary")
        gr.Markdown("### π Model Result")
        # Single output area shared by both tabs; starts with a dashed placeholder.
        output_html = gr.HTML("<div style='height: 120px; display: flex; align-items: center; justify-content: center; border: 2px dashed #334155; border-radius: 12px; color: #64748b; text-align: center;'>Provide input and click Analyze</div>")
        # Sample inputs that populate the text box.
        gr.Examples(
            examples=[["i want to"], ["i want to book a flight"], ["can you help me with"]],
            inputs=text_input
        )
    # Event Handlers
    # Both buttons write their result into the same output_html component.
    text_btn.click(predict_text, inputs=text_input, outputs=output_html)
    audio_btn.click(predict_audio, inputs=audio_input, outputs=output_html)
# Launch the Blocks app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()