import os
import queue
import threading

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, TextIteratorStreamer
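

# Currently loaded model and tokenizer (a single model is active at a time).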
current_model = None
current_tokenizer = None
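

# Models offered in the dropdown; any Hugging Face model id can also be entered manually.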
model_choices = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "google/gemma-7b-it",
    "mistralai/Mistral-Nemo-Instruct-FP8-2407"
]
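

# Demo patient database, keyed by "<ID> - <Name>" as displayed in the dropdown.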
patient_db = {
    "001 - John Doe": {
        "name": "John Doe",
        "age": "45",
        "id": "001",
        "notes": "History of chest pain and hypertension. No prior surgeries."
    },
    "002 - Maria Sanchez": {
        "name": "Maria Sanchez",
        "age": "62",
        "id": "002",
        "notes": "Suspected pulmonary embolism. Shortness of breath, tachycardia."
    },
    "003 - Ahmed Al-Farsi": {
        "name": "Ahmed Al-Farsi",
        "age": "29",
        "id": "003",
        "notes": "Persistent migraines. MRI scheduled for brain imaging."
    },
    "004 - Lin Wei": {
        "name": "Lin Wei",
        "age": "51",
        "id": "004",
        "notes": "Annual screening. Family history of breast cancer."
    }
}
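

# Per-patient chat histories, keyed by patient ID.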
patient_conversations = {}
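

# RichTextStreamer yields one dict per generated token (id, decoded text,
# special-token flag) instead of plain text chunks, so the chat loop below can
# inspect tokens (e.g. <think> tags, stop markers) as they arrive.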
class RichTextStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, prompt_len=0, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.token_queue = queue.Queue()
        self.prompt_len = prompt_len
        self.count = 0
        # Set by end() when generation finishes; checked by __iter__.
        self.end_of_generation = threading.Event()

    def put(self, value):
        if isinstance(value, torch.Tensor):
            token_ids = value.view(-1).tolist()
        elif isinstance(value, list):
            token_ids = value
        else:
            token_ids = [value]

        for token_id in token_ids:
            self.count += 1
            # generate() first feeds the prompt back through the streamer; skip it.
            if self.count <= self.prompt_len:
                continue
            token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
            is_special = token_id in self.tokenizer.all_special_ids
            self.token_queue.put({
                "token_id": token_id,
                "token": token_str,
                "is_special": is_special
            })

    def end(self):
        # Called by generate() when generation completes; unblocks __iter__.
        self.end_of_generation.set()

    def __iter__(self):
        while True:
            try:
                token_info = self.token_queue.get(timeout=self.timeout)
                yield token_info
            except queue.Empty:
                if self.end_of_generation.is_set():
                    break
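

# Stream a chat completion for the selected patient; @spaces.GPU requests a GPU
# for the duration of the call on ZeroGPU Spaces.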
@spaces.GPU
def chat_with_model(messages, pid):
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
        return

    current_id = pid
    if not current_id:
        yield messages
        return

    max_new_tokens = 1024
    output_text = ""
    in_think = False
    generated_tokens = 0

    # Explicit None checks: `or` chaining would skip a legitimate pad_token_id of 0.
    pad_id = current_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = current_tokenizer.unk_token_id
    if pad_id is None:
        pad_id = 0
    eos_id = current_tokenizer.eos_token_id

    prompt = format_prompt(messages)

    device = torch.device("cuda")
    current_model.to(device).half()

    inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    print(prompt)

    streamer = RichTextStreamer(
        tokenizer=current_tokenizer,
        prompt_len=prompt_len,
        skip_special_tokens=False,
        timeout=5.0,  # a finite timeout lets __iter__ notice end_of_generation
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        streamer=streamer,
        eos_token_id=eos_id,
        pad_token_id=pad_id
    )

    # Run generation in a background thread; tokens arrive via the streamer.
    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
    thread.start()

    updated_messages = messages.copy()
    updated_messages.append({"role": "assistant", "content": ""})

    print(updated_messages)

    for token_info in streamer:
        token_str = token_info["token"]
        token_id = token_info["token_id"]

        if token_id == eos_id:
            break

        # Render <think>...</think> spans as markdown italics.
        if "<think>" in token_str:
            in_think = True
            token_str = token_str.replace("<think>", "")
            output_text += "*"

        if "</think>" in token_str:
            in_think = False
            token_str = token_str.replace("</think>", "")
            output_text += token_str + "*"
        else:
            output_text += token_str

        # Stop if the model starts fabricating another conversation turn.
        stop_marker_hit = False
        for stop_marker in ("\nUser", "\nSystem", "\nAssistant"):
            if stop_marker in output_text:
                output_text = output_text.split(stop_marker)[0].rstrip()
                stop_marker_hit = True
                break
        if stop_marker_hit:
            updated_messages[-1]["content"] = output_text
            break

        generated_tokens += 1
        if generated_tokens >= max_new_tokens:
            break

        updated_messages[-1]["content"] = output_text
        patient_conversations[current_id] = updated_messages
        yield updated_messages

    # Wait for the generation thread before freeing CUDA memory.
    thread.join()

    if in_think:
        # Close an unterminated italic span opened by <think>.
        output_text += "*"

    updated_messages[-1]["content"] = output_text
    patient_conversations[current_id] = updated_messages
    torch.cuda.empty_cache()
    # A generator's return value is discarded, so yield the final state instead.
    yield updated_messages
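

# Load tokenizer and model on CPU; chat_with_model() moves the model to the GPU
# at request time.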
def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    # Fetching the config first validates the model id and auth before the large download.
    progress(0, desc="Loading config...")
    config = AutoConfig.from_pretrained(model_name, token=token)

    progress(0.2, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        token=token
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"


def format_prompt(messages):
    prompt = ""
    for msg in messages:
        role = msg["role"]
        if role == "user":
            prompt += f"User: {msg['content'].strip()}\n"
        elif role == "assistant":
            prompt += f"Assistant: {msg['content'].strip()}\n"
        elif role == "system":
            prompt += f"System: {msg['content'].strip()}\n"
    prompt += "Assistant:"
    return prompt
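
# Note: the plain-text format above works for any model, but chat-tuned models
# usually respond better to their own template. A minimal sketch, assuming the
# loaded tokenizer defines a chat template:
#
#     prompt = current_tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )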


def add_user_message(user_input, history, pid):
    if not pid:
        return "", []
    history.append({"role": "user", "content": user_input})
    patient_conversations[pid] = history
    return "", history


def autofill_patient(patient_key):
    if patient_key in patient_db:
        info = patient_db[patient_key]

        # Start an empty conversation the first time a patient is selected.
        if info["id"] not in patient_conversations:
            patient_conversations[info["id"]] = []

        return info["name"], info["age"], info["id"], info["notes"]
    return "", "", "", ""


def resolve_model_choice(mode, dropdown_value, textbox_value):
    return textbox_value.strip() if mode == "Enter custom model" else dropdown_value


def load_patient_conversation(patient_key):
    if patient_key in patient_db:
        # Read patient details from the database rather than from the Gradio
        # components: Component.value only holds a component's initial value.
        info = patient_db[patient_key]
        history = patient_conversations.get(info["id"], [])
        if not history:
            system_message = [
                {
                    "role": "system",
                    "content": (
                        "You are a radiologist's companion, here to answer questions about the patient and assist in the diagnosis if asked to do so. "
                        "You are able to call specialized tools. "
                        "At the moment, you have one tool available: an organ segmentation algorithm for abdominal CTs.\n\n"
                        "If the user requests an organ segmentation, output a JSON object in this structure:\n"
                        "{\n"
                        "  \"function\": \"segment_organ\",\n"
                        "  \"arguments\": {\n"
                        "    \"scan_path\": \"<path_to_ct_scan>\",\n"
                        "    \"organ\": \"<organ_name>\"\n"
                        "  }\n"
                        "}\n\n"
                        "Once you call the function, the app will execute it and return the result."
                    )
                },
                {
                    "role": "system",
                    "content": f"Patient Information:\nName: {info['name']}\nAge: {info['age']}\nID: {info['id']}\nNotes: {info['notes']}"
                }
            ]
            welcome_message = [
                {
                    "role": "assistant",
                    "content": (
                        "Welcome to the Radiologist's Companion!\n\n"
                        "You can ask me about the patient's medical history or available imaging data.\n"
                        "- I can summarize key details from the EHR.\n"
                        "- I can tell you which medical images are available.\n"
                        "- If you'd like an organ segmentation (e.g. spleen, liver, kidney_left, colon, femur_right) on an abdominal CT scan, just ask!\n\n"
                        "Example Requests:\n"
                        "- \"What do we know about this patient?\"\n"
                        "- \"Which images are available for this patient?\"\n"
                        "- \"Can you segment the spleen from the CT scan?\"\n"
                    )
                }
            ]
            history = system_message + welcome_message
        return history
    return []


def get_patient_conversation():
    # Unused helper; note that patient_id.value is only the component's initial value.
    current_id = patient_id.value
    if not current_id:
        return []
    return patient_conversations.get(current_id, [])


css = """
.equal-height > .gr-column {
    height: 100% !important;
    display: flex;
    flex-direction: column;
}
"""
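

# Three-column layout: patient info | chat | model settings.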
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h2 style='text-align: center;'>Radiologist's Companion</h2>")
    default_model = gr.State(model_choices[0])

    with gr.Row(elem_classes="equal-height"):
        with gr.Column(scale=1):
            gr.Markdown("### Patient Information")
            patient_selector = gr.Dropdown(
                choices=list(patient_db.keys()),
                value=list(patient_db.keys())[0],
                label="Select Patient",
                allow_custom_value=False
            )
            patient_name = gr.Textbox(label="Name", placeholder="e.g., John Doe", interactive=False)
            patient_age = gr.Textbox(label="Age", placeholder="e.g., 45", interactive=False)
            patient_id = gr.Textbox(label="Patient ID", placeholder="e.g., 123456", interactive=False)
            patient_notes = gr.Textbox(label="Clinical Notes", lines=10, interactive=False)

        with gr.Column(scale=3):
            gr.Markdown("### Chat")
            chatbot = gr.Chatbot(label="Chat", type="messages", height=450)
            msg = gr.Textbox(label="Your message", placeholder="Enter your chat message...", show_label=False)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=1):
            gr.Markdown("### Model Settings")
            mode = gr.Radio(["Choose from list", "Enter custom model"], value="Choose from list", label="Model Input Mode")
            model_selector = gr.Dropdown(choices=model_choices, label="Select Predefined Model")
            model_textbox = gr.Textbox(label="Or Enter HF Model Name")
            model_status = gr.Textbox(label="Model Status", interactive=False)

    # On startup: fill in the first patient, load their conversation, then load the default model.
    demo.load(
        lambda: autofill_patient(list(patient_db.keys())[0]),
        inputs=None,
        outputs=[patient_name, patient_age, patient_id, patient_notes]
    ).then(
        lambda: load_patient_conversation(list(patient_db.keys())[0]),
        inputs=None,
        outputs=chatbot
    ).then(
        load_model_on_selection,
        inputs=default_model,
        outputs=model_status
    )

    patient_selector.change(
        autofill_patient,
        inputs=[patient_selector],
        outputs=[patient_name, patient_age, patient_id, patient_notes]
    ).then(
        load_patient_conversation,
        inputs=[patient_selector],
        outputs=[chatbot]
    )

    # Each model-selection control resolves to a model id, then (re)loads the model.
    mode.select(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
        load_model_on_selection, inputs=default_model, outputs=model_status
    )
    model_selector.change(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
        load_model_on_selection, inputs=default_model, outputs=model_status
    )
    model_textbox.submit(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
        load_model_on_selection, inputs=default_model, outputs=model_status
    )

    msg.submit(
        add_user_message,
        [msg, chatbot, patient_id],
        [msg, chatbot],
        queue=False,
    ).then(
        chat_with_model,
        [chatbot, patient_id],
        chatbot,
    )

    submit_btn.click(
        add_user_message,
        [msg, chatbot, patient_id],
        [msg, chatbot],
        queue=False,
    ).then(
        chat_with_model,
        [chatbot, patient_id],
        chatbot,
    )

    # Clear the visible chat (the stored per-patient history is kept).
    clear_btn.click(lambda: [], None, chatbot, queue=False)


demo.launch()