# SimpleViva — app.py
# Gradio migration (commit 2464a55, by gladguy)
import io
import time
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from scipy.io.wavfile import write
from transformers import pipeline
# --- TTS Engine ---
class FreeVoiceTTS:
    """Text-to-speech engine backed by the Silero TTS model, loaded lazily via torch.hub."""

    def __init__(self):
        # Populated by load_silero_tts() on first use; None until then.
        self.model = None
        self.device = "cpu"
        # Sample rate passed to Silero's apply_tts and returned to Gradio.
        self.sample_rate = 24000

    def load_silero_tts(self) -> bool:
        """Load Silero TTS - lightweight and reliable. Returns True on success."""
        try:
            # Cap CPU threads so inference stays responsive without oversubscription.
            torch.set_num_threads(4)
            model, _example_text = torch.hub.load(
                repo_or_dir='snakers4/silero-models',
                model='silero_tts',
                language='en',
                speaker='v3_en'
            )
            self.model = model
            # Backward-compatible alias: earlier revisions checked `silero_model`.
            self.silero_model = model
            return True
        except Exception as e:
            # Best-effort: the app degrades to text-only if TTS cannot load.
            print(f"Silero TTS loading failed: {e}")
            return False

    def text_to_speech(self, text: str) -> Optional[Tuple[int, np.ndarray]]:
        """Convert text to speech.

        Returns (sample_rate, audio_numpy) suitable for a Gradio Audio
        output, or None when loading or synthesis fails.
        """
        try:
            # Lazy-load on first call; bail out (None) if the model is unavailable.
            if self.model is None and not self.load_silero_tts():
                return None
            audio = self.model.apply_tts(
                text=text,
                speaker='en_0',  # English female voice
                sample_rate=self.sample_rate
            )
            # Silero returns a torch tensor; Gradio expects a numpy array.
            return (self.sample_rate, audio.numpy())
        except Exception as e:
            print(f"Silero TTS failed: {e}")
            return None
# --- STT Engine ---
class SpeechToText:
    """Speech-to-text engine backed by a Whisper ASR pipeline, loaded lazily."""

    def __init__(self):
        # Populated by load_model() on first use; None until then.
        self.transcriber = None

    def load_model(self) -> bool:
        """Load the whisper-tiny ASR pipeline. Returns True on success."""
        try:
            self.transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
            return True
        except Exception as e:
            print(f"STT loading failed: {e}")
            return False

    def transcribe(self, audio_path: str) -> str:
        """Transcribe the audio file at `audio_path`; returns "" on any failure."""
        # Check the cheap precondition first — don't pay for a model load
        # when there is nothing to transcribe.
        if not audio_path:
            return ""
        # If loading fails, return "" explicitly instead of calling None
        # and relying on the broad except below to swallow the TypeError.
        if self.transcriber is None and not self.load_model():
            return ""
        try:
            result = self.transcriber(audio_path)
            return result["text"]
        except Exception as e:
            print(f"Transcription failed: {e}")
            return ""
# --- Application Logic ---
# Initialize Engines (module-level singletons shared by all handlers)
tts_engine = FreeVoiceTTS()
stt_engine = SpeechToText()
# Pre-load models at import time so the first question isn't delayed by
# model downloads (cold start can take a while; failures are logged and
# the app degrades gracefully inside the engines).
print("Loading AI Models...")
tts_engine.load_silero_tts()
stt_engine.load_model()
print("Models Loaded.")
# Static viva question bank, keyed by topic id (ids match the UI topic buttons).
# Each entry has:
#   "question"   - asked aloud (TTS) and shown in the chat
#   "key_points" - keyword phrases used for scoring in evaluate_answer
#   "follow_up"  - appended to the feedback message
#   "difficulty" - informational only; not currently read by the app logic
QUESTION_BANK = {
    "upper_limb": [
        {
            "question": "Describe the course and distribution of the median nerve from its origin to the hand.",
            "key_points": ["brachial plexus roots C5-T1", "medial and lateral cords", "carpal tunnel", "LOAF muscles"],
            "follow_up": "What clinical condition results from median nerve compression at the wrist?",
            "difficulty": "medium"
        },
        {
            "question": "Explain the brachial plexus in detail, including its major branches.",
            "key_points": ["roots, trunks, divisions, cords, branches", "mnemonic: Real Texans Drink Cold Beer", "musculocutaneous, axillary, radial, median, ulnar nerves"],
            "follow_up": "Which cord of the brachial plexus is most vulnerable in shoulder dislocations?",
            "difficulty": "hard"
        },
        {
            "question": "What are the muscles of the rotator cuff and their functions?",
            "key_points": ["supraspinatus", "infraspinatus", "teres minor", "subscapularis", "SITS mnemonic"],
            "follow_up": "Which rotator cuff muscle is most commonly injured?",
            "difficulty": "medium"
        }
    ],
    "lower_limb": [
        {
            "question": "Trace the course of the sciatic nerve from its origin to its terminal branches.",
            "key_points": ["L4-S3 roots", "passes through greater sciatic foramen", "divides into tibial and common fibular nerves", "innervates hamstrings"],
            "follow_up": "What are the clinical manifestations of sciatic nerve injury?",
            "difficulty": "medium"
        },
        {
            "question": "Describe the boundaries and contents of the femoral triangle.",
            "key_points": ["inguinal ligament", "sartorius", "adductor longus", "femoral nerve, artery, vein", "NAVY arrangement"],
            "follow_up": "Why is the femoral triangle important clinically?",
            "difficulty": "medium"
        }
    ],
    "cardiology": [
        {
            "question": "Describe the blood supply to the heart and the coronary circulation.",
            "key_points": ["left and right coronary arteries", "circumflex artery", "left anterior descending", "coronary sinus"],
            "follow_up": "Which coronary artery is most commonly involved in myocardial infarction?",
            "difficulty": "medium"
        },
        {
            "question": "Explain the conduction system of the heart.",
            "key_points": ["SA node", "AV node", "bundle of His", "bundle branches", "Purkinje fibers"],
            "follow_up": "What is the clinical significance of the AV node?",
            "difficulty": "hard"
        }
    ],
    "neuroanatomy": [
        {
            "question": "Describe the blood supply of the brain.",
            "key_points": ["internal carotid arteries", "vertebral arteries", "circle of Willis", "anterior, middle, posterior cerebral arteries"],
            "follow_up": "What is the clinical consequence of middle cerebral artery occlusion?",
            "difficulty": "hard"
        },
        {
            "question": "Name the twelve cranial nerves and their basic functions.",
            "key_points": ["olfactory, optic, oculomotor, trochlear, trigeminal, abducens, facial, vestibulocochlear, glossopharyngeal, vagus, accessory, hypoglossal"],
            "follow_up": "Which cranial nerve has the longest intracranial course?",
            "difficulty": "medium"
        }
    ]
}
def start_session(topic):
    """Begin a viva session for `topic`.

    Returns a 6-tuple matching the click outputs:
    (session_state, chat_history, info_markdown, session_view update,
     topic_view update, professor_audio).
    """
    if not topic:
        # BUG FIX: this branch previously returned only 5 values while the
        # wired outputs (and the success path) expect 6, which makes Gradio
        # raise a return-value mismatch. Keep the arity identical.
        return (
            None,
            [],
            "Please select a topic first.",
            gr.update(visible=False),
            gr.update(visible=True),
            None
        )
    session_state = {
        "topic": topic,
        "question_index": 0,
        "score": 0,
        "history": [],
        "current_question_data": QUESTION_BANK[topic][0]
    }
    first_question = session_state["current_question_data"]["question"]
    # Speak the first question (None if TTS failed; the Audio output tolerates None).
    audio = tts_engine.text_to_speech(first_question)
    return (
        session_state,
        [(None, first_question)],  # chatbot history as (user, bot) tuples
        f"Topic: {topic.replace('_', ' ').title()}",
        gr.update(visible=True),   # show session view
        gr.update(visible=False),  # hide topic selection
        audio                      # auto-played by professor_audio
    )
def process_response(audio_input, text_input, session_state, history):
    """Handle one submitted answer (voice and/or text).

    Returns a 5-tuple matching the submit outputs:
    (session_state, chat_history, text-box value, audio-input value,
     professor_audio). Mutates `history` in place.
    """
    if not session_state:
        # No active session (e.g. a submit arriving after the session ended).
        return session_state, history, "Error: No active session", None, None

    # Voice answer takes precedence. BUG FIX: previously `elif text_input`
    # discarded the typed answer whenever a recording existed but its
    # transcription failed or came back empty; now text is the fallback.
    user_answer = ""
    if audio_input:
        user_answer = stt_engine.transcribe(audio_input)
    if not user_answer and text_input:
        user_answer = text_input
    if not user_answer:
        return session_state, history, "", None, None  # nothing to grade

    # Grade the answer against the current question's key points.
    question_data = session_state["current_question_data"]
    score, feedback = evaluate_answer(user_answer, question_data)

    # Record the exchange in the session transcript.
    session_state["score"] += score
    session_state["history"].append({
        "question": question_data["question"],
        "answer": user_answer,
        "feedback": feedback,
        "score": score
    })
    history.append((user_answer, feedback))

    # Advance to the next question, or wrap up the session.
    session_state["question_index"] += 1
    topic_questions = QUESTION_BANK[session_state["topic"]]
    next_audio = None
    if session_state["question_index"] < len(topic_questions):
        next_question_data = topic_questions[session_state["question_index"]]
        session_state["current_question_data"] = next_question_data
        next_q_text = next_question_data["question"]
        history.append((None, next_q_text))
        # Speak the next question.
        next_audio = tts_engine.text_to_speech(next_q_text)
    else:
        # Session complete: each question is worth up to 10 points.
        final_score = session_state["score"]
        count = len(topic_questions)
        avg = final_score / count if count > 0 else 0
        end_msg = f"Session Complete! Final Score: {final_score:.1f}/{count*10} (Avg: {avg:.1f})"
        history.append((None, end_msg))
        next_audio = tts_engine.text_to_speech(end_msg)
        session_state = None  # reset so a new topic can be chosen

    return (
        session_state,
        history,
        "",    # clear the text input
        None,  # clear the recorded audio
        next_audio
    )
def evaluate_answer(answer: str, question_data: Dict) -> Tuple[float, str]:
    """Score `answer` against the question's key points by keyword matching.

    A key point counts as covered when any of its meaningful words occurs
    as a substring of the (lowercased) answer. BUG FIX: stopwords such as
    "and"/"of" are now ignored, so an answer containing only filler words
    can no longer match multi-word key points and inflate the score.
    Returns (score out of 10, feedback ending with the follow-up, if any).
    """
    stopwords = {"a", "an", "and", "in", "of", "or", "the", "to"}
    answer_lower = answer.lower()
    key_points = question_data["key_points"]
    follow_up = question_data.get('follow_up', '')
    if not key_points:
        # Degenerate question with nothing to match: avoid division by zero.
        return 0.0, follow_up

    def _covered(point: str) -> bool:
        # A point is covered if any non-stopword of it appears in the answer.
        words = [w for w in point.lower().split() if w not in stopwords]
        return any(w in answer_lower for w in words)

    covered_points = sum(1 for point in key_points if _covered(point))
    score = min(10, (covered_points / len(key_points)) * 10)
    if score >= 8:
        feedback = f"Excellent! {follow_up}"
    elif score >= 5:
        feedback = f"Good. You missed some details. {follow_up}"
    else:
        missed = [p for p in key_points if not _covered(p)]
        feedback = f"Key points missed: {', '.join(missed[:2])}. {follow_up}"
    return score, feedback
# --- Gradio UI ---
with gr.Blocks(title="Anatomy Viva Voce", theme=gr.themes.Soft()) as demo:
    # Per-browser-session viva state dict (None when no session is active).
    state = gr.State(None)
    gr.Markdown("# 🧠 Anatomy Viva Voce Simulator")
    gr.Markdown("Practice medical anatomy with an AI Professor. Speak or type your answers!")
    # Topic Selection View (shown first; hidden once a session starts)
    with gr.Group(visible=True) as topic_view:
        gr.Markdown("### Select a Topic to Begin")
        with gr.Row():
            btn_upper = gr.Button("Upper Limb", variant="primary")
            btn_lower = gr.Button("Lower Limb", variant="primary")
            btn_cardio = gr.Button("Cardiology", variant="primary")
            btn_neuro = gr.Button("Neuroanatomy", variant="primary")
    # Session View (hidden until a topic is chosen)
    with gr.Group(visible=False) as session_view:
        session_info = gr.Markdown("Topic: ...")
        chatbot = gr.Chatbot(label="Viva Session", height=400)
        # Professor audio output: hidden player; auto-plays whatever audio
        # the handlers return for this component (question/feedback TTS).
        professor_audio = gr.Audio(label="Professor's Voice", autoplay=True, visible=False)
        with gr.Row():
            with gr.Column(scale=4):
                txt_input = gr.Textbox(
                    show_label=False,
                    placeholder="Type your answer here...",
                    lines=2
                )
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    source="microphone",
                    type="filepath",  # handlers receive a temp-file path for STT
                    label="Voice Answer",
                    show_label=False
                )
        with gr.Row():
            submit_btn = gr.Button("Submit Answer", variant="primary")
            end_btn = gr.Button("End Session", variant="stop")
    # Event Handlers
    topic_buttons = [btn_upper, btn_lower, btn_cardio, btn_neuro]
    topics = ["upper_limb", "lower_limb", "cardiology", "neuroanatomy"]
    for btn, topic in zip(topic_buttons, topics):
        # gr.State(topic) bakes each button's topic id into its click payload.
        btn.click(
            fn=start_session,
            inputs=[gr.State(topic)],
            outputs=[state, chatbot, session_info, session_view, topic_view, professor_audio]
        )
    # Submit via Text or Audio
    submit_inputs = [audio_input, txt_input, state, chatbot]
    submit_outputs = [state, chatbot, txt_input, audio_input, professor_audio]
    submit_btn.click(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    txt_input.submit(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    # NOTE(review): `change` fires when a recording finishes, so this line
    # auto-submits voice answers without a button press (no chance to
    # re-record). Re-entry looks harmless — the handler clears audio_input
    # and empty submissions return early — but confirm this is the intended
    # UX versus requiring the Submit button for voice as well.
    audio_input.change(fn=process_response, inputs=submit_inputs, outputs=submit_outputs)
    def reset_ui():
        # End the session: clear state and chat, swap back to topic selection.
        return None, [], gr.update(visible=False), gr.update(visible=True)
    end_btn.click(
        fn=reset_ui,
        inputs=None,
        outputs=[state, chatbot, session_view, topic_view]
    )
if __name__ == "__main__":
    # Bind to all interfaces on 7860 (standard for Hugging Face Spaces / Docker).
    demo.launch(server_name="0.0.0.0", server_port=7860)