| import os |
| import io |
| import base64 |
| import traceback |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from flask import Flask, request, render_template, flash, redirect, url_for, jsonify |
| from dotenv import load_dotenv |
|
|
| |
| from transformers import ( |
| AutoModel, |
| AutoImageProcessor, |
| T5ForConditionalGeneration, |
| T5Tokenizer, |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| ) |
| from transformers.modeling_outputs import BaseModelOutput |
|
|
# Read variables from a .env file (if present) into the process environment
# before any of the settings below are looked up.
load_dotenv()

# --- Model / path configuration ---------------------------------------------
# Fine-tuned Swin-T5 checkpoint. Can be overridden with the SWIN_T5_MODEL_PATH
# environment variable; the default is the original hard-coded location.
MODEL_PATH = os.getenv(
    "SWIN_T5_MODEL_PATH",
    '/cluster/home/ammaa/Downloads/Projects/CheXpert-Report-Generation/swin-t5-model.pth',
)
SWIN_MODEL_NAME = "microsoft/swin-base-patch4-window7-224"
T5_MODEL_NAME = "t5-base"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Hugging Face access token used to download the Llama weights.
# HUGGING_FACE_HUB_TOKEN is the legacy variable name; HF_TOKEN is the name
# current huggingface_hub releases document, so accept either.
HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HF_TOKEN")

if not HF_TOKEN:
    print("Warning: neither HUGGING_FACE_HUB_TOKEN nor HF_TOKEN environment variable is set. "
          "Llama model download might fail.")

# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
UPLOAD_FOLDER = 'uploads'  # NOTE(review): not referenced elsewhere in the visible code
# Lower-case extensions accepted for uploads (compared case-insensitively).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
|
|
| |
class ImageCaptioningModel(nn.Module):
    """Image-to-text model: a Swin vision encoder driving a T5 decoder.

    The Swin backbone's last hidden states are mapped by one linear layer from
    Swin's hidden size to T5's ``d_model`` and handed to T5 as precomputed
    ``encoder_outputs``, so T5 generates text conditioned on the image.

    NOTE: the submodule attribute names (``swin``, ``t5``, ``img_proj``) fix
    the state-dict keys of saved checkpoints; do not rename them.
    """

    def __init__(self,
                 swin_model_name=SWIN_MODEL_NAME,
                 t5_model_name=T5_MODEL_NAME):
        super().__init__()
        self.swin = AutoModel.from_pretrained(swin_model_name)
        self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)
        swin_width = self.swin.config.hidden_size
        t5_width = self.t5.config.d_model
        # Bridges the vision feature width to the text model width.
        self.img_proj = nn.Linear(swin_width, t5_width)

    def forward(self, images, labels=None):
        """Encode *images* with Swin and run T5 on the projected features.

        Args:
            images: pixel tensor for the Swin model — presumably a batch of
                normalised 224x224 RGB images; TODO confirm shape.
            labels: optional target token ids, forwarded to T5 only when given.

        Returns:
            Whatever the T5 model returns for these inputs.
        """
        patch_states = self.swin(images, return_dict=True).last_hidden_state
        projected = self.img_proj(patch_states)
        t5_kwargs = {"encoder_outputs": BaseModelOutput(last_hidden_state=projected)}
        if labels is not None:
            t5_kwargs["labels"] = labels
        return self.t5(**t5_kwargs)
|
|
| |
# Process-wide model singletons, populated by the loader functions.
# A value of None means "not loaded (yet)".
swin_t5_model = swin_t5_tokenizer = transform = None
llama_model = llama_tokenizer = None
|
|
def load_swin_t5_model_components():
    """Load the Swin-T5 model, its tokenizer and the image transformation pipeline.

    Everything is built in local variables first; the module globals
    ``swin_t5_model``, ``swin_t5_tokenizer`` and ``transform`` are assigned
    together only after every step succeeded. A failure part-way through
    therefore never publishes a half-initialised (untrained) model.

    Raises:
        FileNotFoundError: if no checkpoint exists at ``MODEL_PATH``.
        Exception: any other loading error is logged and re-raised.
    """
    global swin_t5_model, swin_t5_tokenizer, transform
    try:
        print(f"Loading Swin-T5 model components on device: {DEVICE}")

        # Check for the fine-tuned weights before constructing the model:
        # construction loads the pretrained Swin and T5 backbones, which is
        # wasted work if the checkpoint is missing.
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Swin-T5 Model file not found at {MODEL_PATH}.")

        model = ImageCaptioningModel(swin_model_name=SWIN_MODEL_NAME, t5_model_name=T5_MODEL_NAME)

        # Load onto the CPU and move the finished model afterwards so the GPU
        # never holds a second copy of the weights. NOTE: torch.load unpickles
        # the file, so MODEL_PATH must point at a trusted checkpoint.
        state = torch.load(MODEL_PATH, map_location="cpu")
        # Accept both a bare state dict and a training checkpoint that wraps it
        # under "model_state_dict" (with or without other entries alongside).
        if isinstance(state, dict) and "model_state_dict" in state:
            state = state["model_state_dict"]
        model.load_state_dict(state)
        model.to(DEVICE)
        model.eval()
        print("Swin-T5 Model loaded successfully.")

        tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)
        print("Swin-T5 Tokenizer loaded successfully.")

        # 224x224 resize plus normalisation with the standard ImageNet mean/std.
        image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        print("Transforms defined.")
    except Exception as e:
        print(f"Error loading Swin-T5 model components: {e}")
        print(traceback.format_exc())
        raise

    # Publish all three components at once, only after a fully successful load.
    swin_t5_model, swin_t5_tokenizer, transform = model, tokenizer, image_transform
|
|
def load_llama_model_components():
    """Load the Llama chat model and tokenizer into the module globals.

    Does nothing (besides a message) when no Hugging Face token is configured.
    Never raises: on any loading error it logs the traceback, leaves
    ``llama_model`` / ``llama_tokenizer`` set to None and the chatbot disabled.
    """
    global llama_model, llama_tokenizer
    if not HF_TOKEN:
        print("Skipping Llama model load: Hugging Face token not found.")
        return

    try:
        print(f"Loading Llama model ({LLAMA_MODEL_NAME}) components...")

        # Half precision on GPU (bfloat16 where supported), full precision on CPU.
        if torch.cuda.is_available():
            try:
                torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            except Exception:
                torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        # `token=` replaces the deprecated `use_auth_token=` argument.
        tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, token=HF_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(
            LLAMA_MODEL_NAME,
            torch_dtype=torch_dtype,
            device_map="auto",
            token=HF_TOKEN,
        )
        model.eval()
    except Exception as e:
        print(f"Error loading Llama model components: {e}")
        print(traceback.format_exc())
        llama_model = None
        llama_tokenizer = None
        print("WARNING: Chatbot functionality will be disabled due to loading error.")
        return

    # Publish both together so the chatbot never sees a model without a tokenizer.
    llama_model, llama_tokenizer = model, tokenizer
    print("Llama Model and Tokenizer loaded successfully.")
|
|
| |
def generate_report(image_bytes, selected_vlm, max_length=100):
    """Generate a report/caption for an uploaded image using Swin-T5.

    Args:
        image_bytes: raw encoded image file contents (anything PIL can open).
        selected_vlm: identifier of the vision-language model to use; only
            "swin_t5_chexpert" is supported.
        max_length: ``max_length`` passed to T5 generation (token budget).

    Returns:
        The decoded report text. Generation failures and unsupported model
        choices are returned as strings starting with "Error" rather than
        raised (the /predict handler tests for that prefix).

    Raises:
        Any error from lazily loading the model components when they are not
        loaded yet (e.g. FileNotFoundError, RuntimeError).
    """
    # Validate the cheap argument first so an unsupported choice never
    # triggers a slow (and possibly failing) model load.
    if selected_vlm != "swin_t5_chexpert":
        return "Error: Selected VLM is not supported."

    if swin_t5_model is None or swin_t5_tokenizer is None or transform is None:
        load_swin_t5_model_components()
        if swin_t5_model is None or swin_t5_tokenizer is None or transform is None:
            raise RuntimeError("Swin-T5 model components failed to load.")

    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Add a batch dimension of 1 for the single image.
        input_image = transform(image).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            # Same encoder path as ImageCaptioningModel.forward, but T5 is
            # driven through generate() with beam search instead of a forward pass.
            img_feats = swin_t5_model.swin(input_image, return_dict=True).last_hidden_state
            encoder_outputs = BaseModelOutput(last_hidden_state=swin_t5_model.img_proj(img_feats))

            generated_ids = swin_t5_model.t5.generate(
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        return swin_t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during Swin-T5 report generation: {e}")
        print(traceback.format_exc())
        return f"Error generating report: {e}"
|
|
| |
def _build_chat_prompt(tokenizer, question, report_context):
    """Return ``(prompt_text, add_special_tokens)`` for one report Q&A turn.

    Instruction-tuned models such as Llama 3.1 Instruct are trained on their
    chat template, so it is used whenever the tokenizer provides one. Rendered
    templates already contain the special tokens they need, so the caller must
    not add them again. Without a template, the plain-text prompt is used and
    special tokens are added as usual.
    """
    system_prompt = "You are a helpful medical assistant. I'm a medical student, your task is to help me understand the following report."
    user_prompt = (f"Based on the following report:\n\n---\n{report_context}\n---\n\n"
                   f"Please answer this question: {question}\n")
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return text, False
    return f"{system_prompt}\n\n{user_prompt}", True


def _llama_stop_token_ids(tokenizer):
    """Return the EOS id plus Llama 3's end-of-turn id ``<|eot_id|>`` if the vocab has it."""
    stop_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    # A tokenizer lacking the token maps it to its unk id (possibly None); skip that.
    if isinstance(eot_id, int) and eot_id != tokenizer.unk_token_id and eot_id not in stop_ids:
        stop_ids.append(eot_id)
    return stop_ids


def generate_chat_response(question, report_context, max_new_tokens=250):
    """Answer *question* about *report_context* with the Llama chat model.

    Returns:
        The generated answer text with the prompt stripped. Errors are not
        raised: "Chatbot is currently unavailable." is returned when the model
        is not loaded, and "Error generating chat response: ..." if generation
        fails.
    """
    if llama_model is None or llama_tokenizer is None:
        return "Chatbot is currently unavailable."

    try:
        prompt_text, add_special_tokens = _build_chat_prompt(llama_tokenizer, question, report_context)
        inputs = llama_tokenizer(prompt_text, return_tensors="pt", truncation=True,
                                 add_special_tokens=add_special_tokens)
        # With device_map="auto" inputs go to the device of the first parameter.
        device = next(llama_model.parameters()).device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        with torch.no_grad():
            outputs = llama_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                # Stop at end-of-text or end-of-turn, whichever comes first.
                eos_token_id=_llama_stop_token_ids(llama_tokenizer),
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
                pad_token_id=llama_tokenizer.eos_token_id
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        response_ids = outputs[0][input_ids.shape[-1]:]
        return llama_tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during Llama chat generation: {e}")
        print(traceback.format_exc())
        return f"Error generating chat response: {e}"
|
|
| |
app = Flask(__name__)
# Flask signs the session cookie (which also carries flash messages) with this
# key. Take it from FLASK_SECRET_KEY when provided so it stays stable across
# restarts and is shared by multiple worker processes; otherwise fall back to
# a random per-process key as before.
app.secret_key = os.getenv("FLASK_SECRET_KEY") or os.urandom(24)
|
|
| |
# Load models eagerly at import time. The two loaders are attempted
# independently: a Swin-T5 failure (generate_report() retries that load lazily)
# must not prevent the chatbot model from loading, since nothing else loads it.
print("Loading models on application startup...")
_models_ready = True
try:
    load_swin_t5_model_components()
except Exception as e:
    _models_ready = False
    print(f"FATAL ERROR during model loading: {e}")
    print(traceback.format_exc())
try:
    # Handles its own errors internally; guarded here only defensively.
    load_llama_model_components()
except Exception as e:
    _models_ready = False
    print(f"FATAL ERROR during model loading: {e}")
    print(traceback.format_exc())
if _models_ready:
    print("Model loading complete.")
| |
| |
|
|
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot counts and the comparison ignores case,
    so 'scan.final.PNG' is accepted while 'scan' (no dot) is rejected.
    """
    if '.' not in filename:
        return False
    _, extension = filename.rsplit('.', 1)
    return extension.lower() in ALLOWED_EXTENSIONS
|
|
def parse_patient_info(filename):
    """
    Parse patient metadata from a dash-separated image filename.

    Example: '00069-34-Frontal-AP-63.0-Male-White.png'. After dropping the
    extension, the last three fields are age, gender and ethnicity. Everything
    from the third field up to the age is the view, re-joined with '-', or
    "Unknown" if empty. The first two fields are ignored.

    Returns a dict with keys 'view', 'age' (int, fractional part truncated),
    'gender' and 'ethnicity' (both capitalized), or None if parsing fails.
    """
    try:
        stem = os.path.splitext(filename)[0]
        parts = stem.split('-')
        if len(parts) < 5:
            print(f"Warning: Filename '{filename}' has fewer parts than expected.")
            return None

        # len(parts) >= 5, so these indices are always valid.
        age_str, gender, ethnicity = parts[-3:]
        try:
            age = int(float(age_str))
        except (ValueError, OverflowError):
            # ValueError: not numeric or NaN; OverflowError: 'inf'.
            print(f"Warning: Could not parse age '{age_str}' from filename '{filename}'.")
            return None

        view_parts = parts[2:-3]
        view = '-'.join(view_parts) if view_parts else "Unknown"

        # Unusual values are kept; only a warning is printed.
        if gender.lower() not in ['male', 'female', 'other', 'unknown']:
            print(f"Warning: Unusual gender '{gender}' found in filename '{filename}'.")

        return {
            'view': view,
            'age': age,
            'gender': gender.capitalize(),
            'ethnicity': ethnicity.capitalize()
        }
    except Exception as e:
        # Safety net, e.g. a non-string filename.
        print(f"Error parsing filename '{filename}': {e}")
        print(traceback.format_exc())
        return None
|
|
| |
@app.route('/', methods=['GET'])
def index():
    """Render the upload page, telling the template whether chat is enabled."""
    chat_enabled = bool(llama_model) and bool(llama_tokenizer)
    return render_template('index.html', chatbot_available=chat_enabled)
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Handle an image upload and render the generated report.

    Validates the form (file present, max_length in [10, 512], allowed
    extension) and parses patient metadata from the filename. It then
    generates a report and re-renders index.html with the image
    (base64-encoded) and results. Problems are reported via ``flash``. Errors
    before rendering redirect back to the index page.
    """
    chatbot_available = bool(llama_model) and bool(llama_tokenizer)

    if 'image' not in request.files:
        flash('No image file part in the request.', 'danger')
        return redirect(url_for('index'))

    upload = request.files['image']
    vlm_choice = request.form.get('vlm_choice', 'swin_t5_chexpert')

    try:
        max_length = int(request.form.get('max_length', 100))
        if max_length < 10 or max_length > 512:
            raise ValueError("Max length must be between 10 and 512.")
    except ValueError as err:
        flash(f'Invalid Max Length value: {err}', 'danger')
        return redirect(url_for('index'))

    if upload.filename == '':
        flash('No image selected for uploading.', 'warning')
        return redirect(url_for('index'))

    if not (upload and allowed_file(upload.filename)):
        flash('Invalid image file type. Allowed types: png, jpg, jpeg.', 'danger')
        return redirect(url_for('index'))

    try:
        image_bytes = upload.read()

        original_filename = upload.filename
        patient_info = parse_patient_info(original_filename)
        if patient_info:
            print(f"Parsed Patient Info: {patient_info}")
        else:
            print(f"Could not parse patient info from filename: {original_filename}")

        report = generate_report(image_bytes, vlm_choice, max_length)

        # generate_report signals failure with an "Error..." string.
        generation_failed = isinstance(report, str) and report.startswith("Error")
        if generation_failed:
            flash(f'Report Generation Failed: {report}', 'danger')

        # The uploaded image is echoed back inline so the page can display it.
        image_data = base64.b64encode(image_bytes).decode('utf-8')
        return render_template('index.html',
                               report=None if generation_failed else report,
                               image_data=image_data,
                               patient_info=patient_info,
                               chatbot_available=chatbot_available)

    except FileNotFoundError as fnf_error:
        flash(f'Model file not found: {fnf_error}. Please check server configuration.', 'danger')
        print(f"Model file error: {fnf_error}\n{traceback.format_exc()}")
        return redirect(url_for('index'))
    except RuntimeError as rt_error:
        flash(f'Model loading error: {rt_error}. Please check server logs.', 'danger')
        print(f"Runtime error during prediction: {rt_error}\n{traceback.format_exc()}")
        return redirect(url_for('index'))
    except Exception as e:
        flash(f'An unexpected error occurred during prediction: {e}', 'danger')
        print(f"Error during prediction: {e}\n{traceback.format_exc()}")
        return redirect(url_for('index'))
|
|
@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question about a generated report.

    Expects a JSON object body with 'question' and 'report_context' keys.
    Responds with ``{"answer": ...}``. Returns 503 if the chatbot is not
    loaded, and a JSON 400 error for a missing, malformed or non-object body.
    """
    if not llama_model or not llama_tokenizer:
        return jsonify({"answer": "Chatbot is not available."}), 503

    # silent=True: a non-JSON or malformed body yields None instead of an HTML
    # 400/415 error page, so the client always receives a JSON response.
    data = request.get_json(silent=True)
    # The dict check rejects JSON arrays/strings, which would otherwise pass
    # the membership test (or raise) and crash on data['question'].
    if not isinstance(data, dict) or 'question' not in data or 'report_context' not in data:
        return jsonify({"error": "Missing question or report context"}), 400

    question = data['question']
    report_context = data['report_context']

    try:
        answer = generate_chat_response(question, report_context)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error in /chat endpoint: {e}")
        print(traceback.format_exc())
        return jsonify({"error": "Failed to generate chat response"}), 500
|
|
if __name__ == '__main__':
    # Built-in development server: all interfaces, port 5000, debug disabled.
    app.run(debug=False, port=5000, host='0.0.0.0')
|
|