# Hugging Face Space — runs on ZeroGPU hardware ("Running on Zero").
| import gradio as gr | |
| from ultralytics import YOLO | |
| from PIL import Image | |
| from ultralytics.utils.plotting import Annotator, colors | |
| import glob | |
| import os | |
| import re | |
| import base64 | |
| from io import BytesIO | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| import tempfile # Add this import at the top | |
| import spaces | |
# Vision-language model used to turn YOLO detections into a written report.
# NOTE(review): "google/gemma-4-31B-it" does not look like a published model id
# (Gemma checkpoints are e.g. "google/gemma-3-27b-it") — confirm before deploying.
MODEL_ID = "google/gemma-4-31B-it"
# Example images for the UI.
# NOTE(review): this value is overwritten by a second glob.glob(...) call further
# down the file; one of the two assignments is redundant.
image_paths = glob.glob("*.jpg")
# Load model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",       # let transformers pick the checkpoint's native dtype
    device_map="cuda"   # place the model on the GPU
)
def generate_report(message):
    """Generate a clinical report from a chat-formatted message list.

    Args:
        message: OpenAI-style list of chat turns (role/content dicts, where
            content may mix text parts and data-URL image parts) accepted by
            ``processor.apply_chat_template``.

    Returns:
        The parsed model response as produced by ``processor.parse_response``
        (a dict-like object with at least a ``'content'`` key, per the caller).
    """
    # NOTE(review): on a ZeroGPU Space this function should be decorated with
    # @spaces.GPU — `spaces` is imported at the top of the file but never used.
    model_inputs = processor.apply_chat_template(
        message,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to('cuda')
    # Length of the prompt, used to slice off everything but the new tokens.
    prompt_length = model_inputs["input_ids"].shape[-1]
    generated = model.generate(**model_inputs, max_new_tokens=512)
    # Keep special tokens: parse_response relies on them to structure the output.
    decoded = processor.decode(generated[0][prompt_length:], skip_special_tokens=False)
    return processor.parse_response(decoded)
# -----------------------------
# Configuration & Model Loading
# -----------------------------
# Example images for the Gradio UI.
# Fix: the original pattern r".\*.jpg" uses a Windows backslash separator, which
# glob treats as a literal character on the Linux hosts Hugging Face Spaces run
# on — it matched nothing, leaving the Examples panel empty. A plain relative
# pattern is portable across platforms.
image_paths = glob.glob("*.jpg")
# Load models
# Detection backends keyed by scan type; analyze() picks one via keyword routing.
models = {
    "dental": YOLO("Dental_model.pt"),
    "bone": YOLO("best.onnx")  # Adjust img_size if your ONNX model requires it
}
# Regex Patterns
# Case-sensitive word-boundary patterns; callers lowercase the query first.
DENTAL_KEYWORDS = r"\b(tooth|teeth|dental|molar|gum|cavity|xray|ortho|mouth)\b"
BONE_KEYWORDS = r"\b(bone|fracture|skeleton|arm|leg|rib|spine|joint|break)\b"
# -----------------------------
# Helper Functions
# -----------------------------
def encode_image_to_base64(pil_image):
    """Converts a PIL image to a base64 string for API transmission."""
    buffer = BytesIO()
    # JPEG cannot store alpha/palette modes, so normalise to RGB before saving.
    rgb_image = pil_image.convert("RGB")
    rgb_image.save(buffer, format="JPEG", quality=85)
    raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
def detect_objects(image, model_type="dental"):
    """Run the YOLO model selected by *model_type* over *image*.

    Args:
        image: PIL image of the scan, or None.
        model_type: Key into the module-level ``models`` dict ("dental"/"bone").

    Returns:
        Tuple of (annotated PIL image or None, human-readable summary string).
    """
    if image is None:
        return None, "No image provided."
    selected_model = models[model_type]
    working_copy = image.copy()
    # The ONNX bone model needs an explicit inference size; the dental .pt
    # model uses its own default.
    if model_type == "bone":
        results = selected_model.predict(working_copy, imgsz=512)[0]
    else:
        results = selected_model.predict(working_copy)[0]
    names = selected_model.names
    class_ids = results.boxes.cls.cpu().tolist()
    detected_labels = [names[int(class_id)] for class_id in class_ids]
    # results.plot() returns a numpy array; wrap it for Gradio display.
    annotated_image = Image.fromarray(results.plot())
    if detected_labels:
        label_summary = f"Detected in {model_type} scan: {', '.join(set(detected_labels))}"
    else:
        label_summary = "No structures detected."
    return annotated_image, label_summary
# -----------------------------
# Chat Wrapper
# -----------------------------
def analyze(image, user_text, history):
    """Route the scan to a YOLO model, ask the VLM for a report, update chat.

    Args:
        image: Uploaded PIL image, or None.
        user_text: The user's free-text query.
        history: Gradio chat history (list of role/content dicts); mutated
            in place and also returned.

    Returns:
        The updated history list.
    """
    if image is None:
        history.append({"role": "assistant", "content": "Please upload an image first."})
        return history
    # Keyword routing: dental keywords win when both match, and dental is
    # also the fallback when neither pattern matches.
    lowered = user_text.lower()
    if re.search(BONE_KEYWORDS, lowered) and not re.search(DENTAL_KEYWORDS, lowered):
        target_model = "bone"
    else:
        target_model = "dental"
    # 1. Local YOLO Detection
    annotated_image, explanation = detect_objects(image, target_model)
    # 2. Prepare (annotated) image for the VLM as a data URL.
    base64_image = encode_image_to_base64(annotated_image)
    # 3. Call Vision LLM
    prompt = f"You are a medical imagery assistant. Answer to the user's query based and answer based on the given detections on the visual evidence provided. ### User Query: {user_text} ### Detection Summary: {explanation}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "url": f"data:image/jpeg;base64,{base64_image}"},
            ],
        }
    ]
    ai_report = generate_report(messages)
    # 4. Update Gradio History
    history.append({"role": "user", "content": user_text})
    # Passing the PIL image directly to the chatbot history
    history.append({"role": "assistant", "content": gr.Image(value=annotated_image, format="png")})
    history.append({"role": "assistant", "content": ai_report['content']})
    return history
# -----------------------------
# UI Layout
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 Multi-Purpose Medical AI Assistant")
    gr.Markdown("Upload a dental or bone X-ray and ask a question. The system will route the image to the correct YOLO model and generate a report.")
    # Fix: analyze() appends openai-style {"role": ..., "content": ...} dicts,
    # so the chatbot must be created in "messages" mode — the legacy default
    # tuple format does not accept role/content dicts.
    chatbot = gr.Chatbot(label="Clinical Analysis", height=500, type="messages")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Scan (X-Ray/CT)")
            gr.Examples(examples=image_paths, inputs=[image_input], label="Sample Scans")
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                placeholder="e.g., 'Analyze this dental xray for cavities' or 'Is the bone fractured?'",
                label="Your Query",
                lines=2
            )
            run_button = gr.Button("Run Medical Analysis", variant="primary")
    # Wire the button: current image + query + chat history in, history out.
    run_button.click(
        analyze,
        inputs=[image_input, text_input, chatbot],
        outputs=[chatbot]
    )
if __name__ == "__main__":
    # Launch the Gradio server when the file is run as a script.
    demo.launch()