# MEDPAI / app.py — Hugging Face Space entry point
# (page-header residue from the Space listing kept as a comment so the file
#  parses as Python: author "alibidaran", "Update app.py", commit 243f470 verified)
import gradio as gr
from ultralytics import YOLO
from PIL import Image
from ultralytics.utils.plotting import Annotator, colors
import glob
import os
import re
import base64
from io import BytesIO
from transformers import AutoProcessor, AutoModelForCausalLM
import tempfile # Add this import at the top
import spaces
# Vision-language model used to write the clinical report.
# NOTE(review): "google/gemma-4-31B-it" looks wrong — no such checkpoint is known
# on the Hub as of review; "google/gemma-3-27b-it" was likely intended. Confirm.
MODEL_ID = "google/gemma-4-31B-it"
# Example images offered in the UI (re-assigned again further down with a
# different pattern — the later assignment wins).
image_paths = glob.glob("*.jpg")
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",       # let transformers pick the checkpoint's native dtype
    device_map="cuda"   # place all weights on the GPU
)
@spaces.GPU
def generate_report(message):
    """Run a chat-formatted message list through the VLM and return the parsed reply.

    `message` is a transformers chat-template message list (may contain
    base64-image parts). Uses the module-level `processor` and `model`.
    Returns whatever `processor.parse_response` produces (a dict with a
    'content' field, as consumed by `analyze`).
    """
    templated = processor.apply_chat_template(
        message,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to('cuda')
    prompt_len = templated["input_ids"].shape[-1]

    # Generate, then decode only the newly produced tokens. Special tokens are
    # kept so parse_response can see the model's structural markers.
    generated = model.generate(**templated, max_new_tokens=512)
    decoded = processor.decode(generated[0][prompt_len:], skip_special_tokens=False)
    return processor.parse_response(decoded)
# -----------------------------
# Configuration & Model Loading
# -----------------------------
# Example images for the UI Examples widget.
# Bug fix: the original pattern r".\*.jpg" uses a Windows backslash separator,
# so glob matched nothing on the Linux hosts HF Spaces run on. Use the
# portable pattern (same as the earlier assignment at the top of the file).
image_paths = glob.glob("*.jpg")
# Detection models, keyed by the routing label chosen in `analyze`.
models = {
    # Fine-tuned YOLO weights for dental radiographs
    "dental": YOLO("Dental_model.pt"),
    # ONNX-exported YOLO for bone/fracture scans; inference passes imgsz=512
    # (adjust if the ONNX export used a different input size)
    "bone": YOLO("best.onnx")
}
# Regex Patterns — whole-word keyword matchers used by `analyze` to route the
# user's query to the dental or bone model (dental is checked first and is the
# fallback when neither matches).
DENTAL_KEYWORDS = r"\b(tooth|teeth|dental|molar|gum|cavity|xray|ortho|mouth)\b"
BONE_KEYWORDS = r"\b(bone|fracture|skeleton|arm|leg|rib|spine|joint|break)\b"
# -----------------------------
# Helper Functions
# -----------------------------
def encode_image_to_base64(pil_image):
    """Serialize a PIL image as a base64-encoded JPEG string.

    The image is forced to RGB first because JPEG cannot store alpha or
    palette modes (e.g. RGBA screenshots would otherwise fail to save).
    """
    buffer = BytesIO()
    rgb = pil_image.convert("RGB")
    rgb.save(buffer, format="JPEG", quality=85)
    raw = buffer.getvalue()
    return base64.b64encode(raw).decode("utf-8")
@spaces.GPU
def detect_objects(image, model_type="dental"):
    """Run the YOLO model selected by `model_type` over `image`.

    Returns a (annotated PIL image, summary string) pair. The ONNX bone
    model is invoked with an explicit imgsz=512; the dental model uses the
    default inference size.
    """
    if image is None:
        return None, "No image provided."

    yolo = models[model_type]
    frame = image.copy()

    # Only the bone (ONNX) model needs an explicit input size.
    if model_type == "bone":
        result = yolo.predict(frame, imgsz=512)[0]
    else:
        result = yolo.predict(frame)[0]
    class_names = yolo.names

    class_ids = result.boxes.cls.cpu().tolist()
    detected = [class_names[int(cid)] for cid in class_ids]

    # results.plot() returns a numpy array; convert for Gradio display.
    annotated = Image.fromarray(result.plot())
    if detected:
        summary = f"Detected in {model_type} scan: {', '.join(set(detected))}"
    else:
        summary = "No structures detected."
    return annotated, summary
# -----------------------------
# Chat Wrapper
# -----------------------------
@spaces.GPU
def analyze(image, user_text, history):
    """Chat handler: route the scan to a YOLO model, then ask the VLM for a report.

    Appends the user turn, the annotated scan, and the model's textual
    report to `history` (dict-style messages) and returns it for the
    Chatbot component.
    """
    if image is None:
        history.append({"role": "assistant", "content": "Please upload an image first."})
        return history

    # Keyword routing: dental keywords win, then bone; default to dental.
    query = user_text.lower()
    if re.search(DENTAL_KEYWORDS, query):
        target_model = "dental"
    elif re.search(BONE_KEYWORDS, query):
        target_model = "bone"
    else:
        target_model = "dental"

    # 1-2. Local YOLO pass, then ship the annotated frame to the VLM as base64.
    annotated_image, explanation = detect_objects(image, target_model)
    base64_image = encode_image_to_base64(annotated_image)

    # 3. Build the multimodal chat message and query the vision LLM.
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": f"You are a medical imagery assistant. Answer to the user's query based and answer based on the given detections on the visual evidence provided. ### User Query: {user_text} ### Detection Summary: {explanation}"
            },
            {
                "type": "image",
                "url": f"data:image/jpeg;base64,{base64_image}"
            },
        ],
    }]
    ai_report = generate_report(messages)

    # 4. Update Gradio history: user turn, annotated image, then the report.
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": gr.Image(value=annotated_image, format="png")})
    history.append({"role": "assistant", "content": ai_report['content']})
    return history
# -----------------------------
# UI Layout
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 Multi-Purpose Medical AI Assistant")
    gr.Markdown("Upload a dental or bone X-ray and ask a question. The system will route the image to the correct YOLO model and generate a report.")
    # Conversation pane. NOTE(review): `analyze` appends openai-style dicts;
    # current Gradio expects gr.Chatbot(type="messages") for that format — confirm.
    chatbot = gr.Chatbot(label="Clinical Analysis", height=500)
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Scan (X-Ray/CT)")
            # Sample scans discovered at startup via glob
            gr.Examples(examples=image_paths, inputs=[image_input], label="Sample Scans")
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                placeholder="e.g., 'Analyze this dental xray for cavities' or 'Is the bone fractured?'",
                label="Your Query",
                lines=2
            )
            run_button = gr.Button("Run Medical Analysis", variant="primary")
    # Wire the button: analyze(image, text, history) -> updated history
    run_button.click(
        analyze,
        inputs=[image_input, text_input, chatbot],
        outputs=[chatbot]
    )
if __name__ == "__main__":
    demo.launch()