# NOTE: The following two comment lines are residue from a Hugging Face Spaces
# status-page capture; the deployed Space reported "Runtime error".
import io
import time
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from scipy.io.wavfile import write
from transformers import pipeline
# --- TTS Engine ---
class FreeVoiceTTS:
    """Text-to-speech engine backed by the Silero TTS model from torch.hub.

    The Silero model is lazily loaded on first use and cached on
    ``self.silero_model``; all synthesis runs on CPU.
    """

    def __init__(self):
        self.model = None            # reserved; the loaded model lives on self.silero_model
        self.device = "cpu"
        self.sample_rate = 24000     # output rate passed to Silero's apply_tts

    def load_silero_tts(self) -> bool:
        """Load Silero TTS - lightweight and reliable.

        Returns:
            True if the model was loaded and cached, False on any failure
            (the error is printed, not raised).
        """
        try:
            torch.set_num_threads(4)
            # torch.hub.load returns (model, example_text); the example is unused.
            model, _ = torch.hub.load(
                repo_or_dir='snakers4/silero-models',
                model='silero_tts',
                language='en',
                speaker='v3_en'
            )
            self.silero_model = model
            return True
        except Exception as e:
            print(f"Silero TTS loading failed: {e}")
            return False

    def text_to_speech(self, text: str) -> Optional[Tuple[int, np.ndarray]]:
        """Convert text to speech.

        Args:
            text: The sentence(s) to synthesize.

        Returns:
            ``(sample_rate, audio_numpy)`` suitable for a Gradio Audio
            component, or ``None`` if loading or synthesis failed.
        """
        try:
            # Lazy-load the model the first time synthesis is requested.
            if not hasattr(self, 'silero_model'):
                if not self.load_silero_tts():
                    return None
            # Generate audio using Silero
            audio = self.silero_model.apply_tts(
                text=text,
                speaker='en_0',  # English female voice
                sample_rate=self.sample_rate
            )
            # Silero returns a torch tensor; Gradio wants a numpy array.
            return (self.sample_rate, audio.numpy())
        except Exception as e:
            print(f"Silero TTS failed: {e}")
            return None
# --- STT Engine ---
class SpeechToText:
    """Speech-to-text wrapper around the Whisper-tiny ASR pipeline."""

    def __init__(self):
        self.transcriber = None  # HF pipeline, lazily created by load_model()

    def load_model(self) -> bool:
        """Create the ASR pipeline. Returns True on success, False on failure."""
        try:
            self.transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
            return True
        except Exception as e:
            print(f"STT loading failed: {e}")
            return False

    def transcribe(self, audio_path: str) -> str:
        """Transcribe the audio file at ``audio_path``.

        Returns the recognized text, or "" when the path is empty, the
        model cannot be loaded, or transcription fails.
        """
        # Check the cheap condition first: don't load a model for no input.
        if not audio_path:
            return ""
        # Lazy-load; bail out explicitly instead of calling a None transcriber.
        if not self.transcriber and not self.load_model():
            return ""
        try:
            result = self.transcriber(audio_path)
            return result["text"]
        except Exception as e:
            print(f"Transcription failed: {e}")
            return ""
# --- Application Logic ---
# Initialize Engines (module-level singletons shared by all Gradio sessions)
tts_engine = FreeVoiceTTS()
stt_engine = SpeechToText()
# Pre-load models at import time so the first question isn't delayed by
# model downloads; failures are printed by the loaders, not raised.
print("Loading AI Models...")
tts_engine.load_silero_tts()
stt_engine.load_model()
print("Models Loaded.")
# Static question bank keyed by topic id (matches the topic button wiring).
# Each entry: "question" (spoken/shown prompt), "key_points" (keywords used
# by evaluate_answer for scoring), "follow_up" (appended to feedback),
# "difficulty" (informational only; not used by the scoring logic).
QUESTION_BANK = {
    "upper_limb": [
        {
            "question": "Describe the course and distribution of the median nerve from its origin to the hand.",
            "key_points": ["brachial plexus roots C5-T1", "medial and lateral cords", "carpal tunnel", "LOAF muscles"],
            "follow_up": "What clinical condition results from median nerve compression at the wrist?",
            "difficulty": "medium"
        },
        {
            "question": "Explain the brachial plexus in detail, including its major branches.",
            "key_points": ["roots, trunks, divisions, cords, branches", "mnemonic: Real Texans Drink Cold Beer", "musculocutaneous, axillary, radial, median, ulnar nerves"],
            "follow_up": "Which cord of the brachial plexus is most vulnerable in shoulder dislocations?",
            "difficulty": "hard"
        },
        {
            "question": "What are the muscles of the rotator cuff and their functions?",
            "key_points": ["supraspinatus", "infraspinatus", "teres minor", "subscapularis", "SITS mnemonic"],
            "follow_up": "Which rotator cuff muscle is most commonly injured?",
            "difficulty": "medium"
        }
    ],
    "lower_limb": [
        {
            "question": "Trace the course of the sciatic nerve from its origin to its terminal branches.",
            "key_points": ["L4-S3 roots", "passes through greater sciatic foramen", "divides into tibial and common fibular nerves", "innervates hamstrings"],
            "follow_up": "What are the clinical manifestations of sciatic nerve injury?",
            "difficulty": "medium"
        },
        {
            "question": "Describe the boundaries and contents of the femoral triangle.",
            "key_points": ["inguinal ligament", "sartorius", "adductor longus", "femoral nerve, artery, vein", "NAVY arrangement"],
            "follow_up": "Why is the femoral triangle important clinically?",
            "difficulty": "medium"
        }
    ],
    "cardiology": [
        {
            "question": "Describe the blood supply to the heart and the coronary circulation.",
            "key_points": ["left and right coronary arteries", "circumflex artery", "left anterior descending", "coronary sinus"],
            "follow_up": "Which coronary artery is most commonly involved in myocardial infarction?",
            "difficulty": "medium"
        },
        {
            "question": "Explain the conduction system of the heart.",
            "key_points": ["SA node", "AV node", "bundle of His", "bundle branches", "Purkinje fibers"],
            "follow_up": "What is the clinical significance of the AV node?",
            "difficulty": "hard"
        }
    ],
    "neuroanatomy": [
        {
            "question": "Describe the blood supply of the brain.",
            "key_points": ["internal carotid arteries", "vertebral arteries", "circle of Willis", "anterior, middle, posterior cerebral arteries"],
            "follow_up": "What is the clinical consequence of middle cerebral artery occlusion?",
            "difficulty": "hard"
        },
        {
            "question": "Name the twelve cranial nerves and their basic functions.",
            "key_points": ["olfactory, optic, oculomotor, trochlear, trigeminal, abducens, facial, vestibulocochlear, glossopharyngeal, vagus, accessory, hypoglossal"],
            "follow_up": "Which cranial nerve has the longest intracranial course?",
            "difficulty": "medium"
        }
    ]
}
def start_session(topic):
    """Start a viva session for ``topic``.

    Returns one value per wired output component, in order:
    (session_state, chat_history, info_markdown, session_view update,
    topic_view update, professor audio).
    """
    if not topic:
        # BUG FIX: this early return previously yielded only 5 values while
        # the click handler has 6 outputs, which errors at runtime in Gradio.
        return (
            None,
            [],
            "Please select a topic first.",
            gr.update(visible=False),
            gr.update(visible=True),
            None  # nothing to play
        )
    session_state = {
        "topic": topic,
        "question_index": 0,
        "score": 0,
        "history": [],
        "current_question_data": QUESTION_BANK[topic][0]
    }
    first_question = session_state["current_question_data"]["question"]
    # Generate audio for first question (may be None if TTS failed).
    audio = tts_engine.text_to_speech(first_question)
    return (
        session_state,
        [(None, first_question)],  # Chat history
        f"Topic: {topic.replace('_', ' ').title()}",
        gr.update(visible=True),   # Show session
        gr.update(visible=False),  # Hide topic selection
        audio                      # Auto-play question
    )
def process_response(audio_input, text_input, session_state, history):
    """Handle one submitted answer and advance the session.

    Args:
        audio_input: filepath of a recorded answer, or None.
        text_input: typed answer, or "".
        session_state: dict created by start_session, or None if no session.
        history: chatbot history as a list of (user, bot) tuples.

    Returns a 5-tuple matching the wired outputs:
    (session_state, history, text-input value, audio-input value,
    professor audio for the next question or end-of-session message).
    """
    if not session_state:
        return session_state, history, "Error: No active session", None, None
    # Determine user answer (Audio takes precedence over text)
    user_answer = ""
    if audio_input:
        user_answer = stt_engine.transcribe(audio_input)
    elif text_input:
        user_answer = text_input
    if not user_answer:
        return session_state, history, "", None, None  # No input
    # Evaluate Answer
    question_data = session_state["current_question_data"]
    score, feedback = evaluate_answer(user_answer, question_data)
    # Update State (mutated in place; Gradio passes the same dict back)
    session_state["score"] += score
    session_state["history"].append({
        "question": question_data["question"],
        "answer": user_answer,
        "feedback": feedback,
        "score": score
    })
    # Update Chat History
    history.append((user_answer, feedback))
    # Prepare Next Question
    session_state["question_index"] += 1
    topic_questions = QUESTION_BANK[session_state["topic"]]
    next_audio = None
    if session_state["question_index"] < len(topic_questions):
        next_question_data = topic_questions[session_state["question_index"]]
        session_state["current_question_data"] = next_question_data
        next_q_text = next_question_data["question"]
        history.append((None, next_q_text))
        # Generate audio for next question (None if TTS fails)
        next_audio = tts_engine.text_to_speech(next_q_text)
    else:
        # End of session: summarize the score and clear the state so the
        # next submit reports "No active session".
        final_score = session_state["score"]
        count = len(topic_questions)
        avg = final_score / count if count > 0 else 0
        end_msg = f"Session Complete! Final Score: {final_score:.1f}/{count*10} (Avg: {avg:.1f})"
        history.append((None, end_msg))
        next_audio = tts_engine.text_to_speech(end_msg)
        session_state = None  # Reset state
    return (
        session_state,
        history,
        "",    # Clear text input
        None,  # Clear audio input
        next_audio
    )
def evaluate_answer(answer: str, question_data: Dict) -> Tuple[float, str]:
    """Score an answer by keyword matching against the question's key points.

    Args:
        answer: The student's answer text.
        question_data: Entry from QUESTION_BANK with "key_points" and
            optionally "follow_up".

    Returns:
        (score, feedback) where score is 0-10 and feedback includes the
        follow-up question when present.
    """
    answer_lower = answer.lower()
    key_points = question_data["key_points"]
    # Guard: no key points means nothing to score (also avoids ZeroDivisionError).
    if not key_points:
        return 0.0, question_data.get('follow_up', '')

    # BUG FIX: the original matched *any* word of a key point as a substring,
    # so stopwords like "and"/"of" made almost any answer count as covering
    # "medial and lateral cords". Ignore stopwords and very short tokens.
    stopwords = {"and", "or", "of", "the", "a", "an", "to", "in", "for", "with", "from"}

    def _covered(point: str) -> bool:
        # Fall back to all words if filtering removed everything.
        words = [w for w in point.lower().split() if len(w) > 2 and w not in stopwords]
        return any(w in answer_lower for w in (words or point.lower().split()))

    covered_points = sum(1 for point in key_points if _covered(point))
    score = min(10, (covered_points / len(key_points)) * 10)
    if score >= 8:
        feedback = f"Excellent! {question_data.get('follow_up', '')}"
    elif score >= 5:
        feedback = f"Good. You missed some details. {question_data.get('follow_up', '')}"
    else:
        missed = [p for p in key_points if not _covered(p)]
        feedback = f"Key points missed: {', '.join(missed[:2])}. {question_data.get('follow_up', '')}"
    return score, feedback
# --- Gradio UI ---
with gr.Blocks(title="Anatomy Viva Voce", theme=gr.themes.Soft()) as demo:
    state = gr.State(None)  # Per-session dict created by start_session; None when idle
    gr.Markdown("# π§ Anatomy Viva Voce Simulator")
    gr.Markdown("Practice medical anatomy with an AI Professor. Speak or type your answers!")
    # Topic Selection View (shown at start; hidden once a session begins)
    with gr.Group(visible=True) as topic_view:
        gr.Markdown("### Select a Topic to Begin")
        with gr.Row():
            btn_upper = gr.Button("Upper Limb", variant="primary")
            btn_lower = gr.Button("Lower Limb", variant="primary")
            btn_cardio = gr.Button("Cardiology", variant="primary")
            btn_neuro = gr.Button("Neuroanatomy", variant="primary")
    # Session View
    with gr.Group(visible=False) as session_view:
        session_info = gr.Markdown("Topic: ...")
        # NOTE(review): tuple-style Chatbot history is the Gradio 3.x/4.x
        # default; Gradio 5 prefers type="messages" — verify installed version.
        chatbot = gr.Chatbot(label="Viva Session", height=400)
        # Professor Audio Output (hidden player, auto-played via return value)
        professor_audio = gr.Audio(label="Professor's Voice", autoplay=True, visible=False)
        with gr.Row():
            with gr.Column(scale=4):
                txt_input = gr.Textbox(
                    show_label=False,
                    placeholder="Type your answer here...",
                    lines=2
                )
            with gr.Column(scale=1):
                # NOTE(review): `source=` is Gradio 3.x only; Gradio 4+
                # renamed it to `sources=["microphone"]`. This is the most
                # likely cause of the Space's "Runtime error" — confirm the
                # installed gradio version before changing.
                audio_input = gr.Audio(
                    source="microphone",
                    type="filepath",
                    label="Voice Answer",
                    show_label=False
                )
        with gr.Row():
            submit_btn = gr.Button("Submit Answer", variant="primary")
            end_btn = gr.Button("End Session", variant="stop")
    # Event Handlers: one click handler per topic button, each baking its
    # topic id in via a constant gr.State input.
    topic_buttons = [btn_upper, btn_lower, btn_cardio, btn_neuro]
    topics = ["upper_limb", "lower_limb", "cardiology", "neuroanatomy"]
    for btn, topic in zip(topic_buttons, topics):
        btn.click(
            fn=start_session,
            inputs=[gr.State(topic)],
            outputs=[state, chatbot, session_info, session_view, topic_view, professor_audio]
        )
    # Submit via button, Enter in the textbox, or when a recording finishes.
    submit_inputs = [audio_input, txt_input, state, chatbot]
    submit_outputs = [state, chatbot, txt_input, audio_input, professor_audio]
    submit_btn.click(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    txt_input.submit(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    # NOTE(review): `change` fires when the recording stops, auto-submitting
    # the audio; users cannot re-record before submission. Consider removing
    # this listener and relying on the Submit button only.
    audio_input.change(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)

    def reset_ui():
        """Clear the session state and return to the topic-selection view."""
        return None, [], gr.update(visible=False), gr.update(visible=True)

    end_btn.click(
        fn=reset_ui,
        inputs=None,
        outputs=[state, chatbot, session_view, topic_view]
    )

if __name__ == "__main__":
    # 0.0.0.0 / 7860 is the standard binding for a Hugging Face Space.
    demo.launch(server_name="0.0.0.0", server_port=7860)