| """ |
| RadFig VQA Image Filtering Model - Inference Script |
| Classifies medical images as suitable/unsuitable for VQA tasks. |
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import timm |
| import cv2 |
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader |
| from albumentations import Compose, Resize, Normalize |
| from albumentations.pytorch import ToTensorV2 |
| from tqdm import tqdm |
|
|
|
|
class Config:
    """Static settings shared by every stage of the inference pipeline."""

    # timm model identifier; also the prefix of the checkpoint filenames.
    model_name = "tf_efficientnetv2_s.in21k_ft_in1k"

    # Side length (in pixels) that input images are resized to.
    size = 512

    # DataLoader settings.
    batch_size = 32
    num_workers = 4

    # Number of outputs produced by the network head.
    target_size = 1

    # Number of fold checkpoints that make up the ensemble.
    n_fold = 5

    # Run on the GPU when one is available, otherwise on the CPU.
    device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
|
|
|
|
class TestDataset(Dataset):
    """Map-style dataset that reads images from disk for inference.

    Each item is the image stored at the corresponding path, decoded with
    OpenCV and converted from BGR to RGB. When a transform is set, the item
    is the ``'image'`` entry of the transform's result instead.
    """

    def __init__(self, image_paths, transform=None):
        self.transform = transform
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]

        bgr = cv2.imread(path)
        # cv2.imread signals an unreadable file by returning None.
        if bgr is None:
            raise ValueError(f"Could not load image: {path}")

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        if not self.transform:
            return rgb
        return self.transform(image=rgb)['image']
|
|
|
|
def get_transforms():
    """Build the preprocessing pipeline applied to every image at inference.

    Returns:
        The composed pipeline: resize to ``Config.size`` x ``Config.size``,
        normalize with fixed per-channel mean/std, then ``ToTensorV2``.
    """
    side = Config.size
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        Resize(side, side),
        Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return Compose(steps)
|
|
|
|
class RadFigClassifier:
    """RadFig VQA image filtering classifier backed by a fold ensemble.

    One checkpoint per fold is loaded at construction time. At prediction
    time a single network instance is reused: each fold's weights are loaded
    into it in turn and the per-fold sigmoid outputs are averaged.
    """

    # Probability above which an image is labelled suitable for VQA. The same
    # cut-off is used for binary predictions and for the ``suitable_for_vqa``
    # column so both views of a result always agree.
    DECISION_THRESHOLD = 0.9

    # Lower-case filename suffixes that predict_directory treats as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')

    def __init__(self, model_dir="models"):
        """Load every fold checkpoint found in ``model_dir``.

        Args:
            model_dir (str): Directory containing the fold checkpoint files.

        Raises:
            FileNotFoundError: If any expected checkpoint file is missing.
        """
        self.config = Config()
        self.model_dir = model_dir
        self.device = self.config.device
        self.model = None  # built lazily on the first prediction
        self.states = []

        self._load_model_states()

    def _load_model_states(self):
        """Load all fold model states into ``self.states``.

        Raises:
            FileNotFoundError: If a fold checkpoint does not exist.
        """
        states = []
        for fold in range(self.config.n_fold):
            model_path = os.path.join(
                self.model_dir,
                f"{self.config.model_name}_fold{fold}_best_loss.pth"
            )

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            # NOTE(security): torch.load unpickles the file; only point
            # model_dir at checkpoints from a trusted source.
            states.append(torch.load(model_path, map_location=self.device))

        self.states = states
        print(f"Loaded {len(self.states)} model states from {self.model_dir}")

    def _create_model(self):
        """Create the (untrained) model architecture on the target device."""
        model = timm.create_model(
            model_name=self.config.model_name,
            num_classes=self.config.target_size,
            pretrained=False
        )
        return model.to(self.device)

    def predict_batch(self, image_paths, return_probabilities=True, threshold=None):
        """
        Predict on a batch of images

        Args:
            image_paths (list): List of image file paths
            return_probabilities (bool): If True, return probabilities. If False, return binary predictions.
            threshold (float, optional): Cut-off used for binary predictions.
                Defaults to ``DECISION_THRESHOLD``.

        Returns:
            numpy.ndarray: 1-D array of predictions (probabilities or binary),
            in the same order as ``image_paths``. Empty input yields an empty
            array.
        """
        if threshold is None:
            threshold = self.DECISION_THRESHOLD

        image_paths = list(image_paths)
        if not image_paths:
            # np.concatenate below cannot handle zero batches.
            empty = np.empty(0, dtype=np.float32)
            return empty if return_probabilities else empty.astype(int)

        dataset = TestDataset(image_paths, transform=get_transforms())
        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=self.device.type == "cuda"
        )

        # Build the network once and reuse it across calls; its weights are
        # overwritten from the fold states before every forward pass anyway.
        if self.model is None:
            self.model = self._create_model()
        model = self.model

        all_predictions = []

        with torch.no_grad():
            for images in tqdm(dataloader, desc="Predicting"):
                images = images.to(self.device)

                fold_predictions = []
                for state in self.states:
                    model.load_state_dict(state['model'])
                    model.eval()

                    outputs = model(images)
                    fold_predictions.append(torch.sigmoid(outputs).cpu().numpy())

                # Ensemble: average the per-fold probabilities.
                all_predictions.append(np.mean(fold_predictions, axis=0))

        predictions = np.concatenate(all_predictions, axis=0).flatten()

        if return_probabilities:
            return predictions
        return (predictions > threshold).astype(int)

    def predict_single(self, image_path, return_probability=True, threshold=None):
        """
        Predict on a single image

        Args:
            image_path (str): Path to image file
            return_probability (bool): If True, return probability. If False, return binary prediction.
            threshold (float, optional): Cut-off used for the binary
                prediction. Defaults to ``DECISION_THRESHOLD``.

        Returns:
            float or int: Prediction
        """
        predictions = self.predict_batch(
            [image_path],
            return_probabilities=return_probability,
            threshold=threshold,
        )
        return predictions[0]

    def predict_directory(self, directory_path, output_csv=None,
                          return_probabilities=True, threshold=None):
        """
        Predict on all images in a directory

        Args:
            directory_path (str): Path to directory containing images
            output_csv (str, optional): Path to save results as CSV
            return_probabilities (bool): If True, return probabilities. If False, return binary predictions.
            threshold (float, optional): Cut-off for the binary prediction and
                the ``suitable_for_vqa`` column. Defaults to
                ``DECISION_THRESHOLD``.

        Returns:
            pandas.DataFrame: One row per image with columns ``image_path``,
            ``filename``, ``prediction`` and ``suitable_for_vqa``, sorted by
            filename.

        Raises:
            ValueError: If the directory contains no image files.
        """
        if threshold is None:
            threshold = self.DECISION_THRESHOLD

        image_paths = []
        for filename in sorted(os.listdir(directory_path)):
            if not filename.lower().endswith(self.IMAGE_EXTENSIONS):
                continue
            path = os.path.join(directory_path, filename)
            # Skip sub-directories that merely have an image-like name.
            if os.path.isfile(path):
                image_paths.append(path)

        if not image_paths:
            raise ValueError(f"No image files found in {directory_path}")

        print(f"Found {len(image_paths)} images in {directory_path}")

        # Always start from probabilities so the binary prediction and the
        # suitability flag come from the same threshold.
        probabilities = self.predict_batch(image_paths, return_probabilities=True)
        suitable = probabilities > threshold
        predictions = probabilities if return_probabilities else suitable.astype(int)

        results = pd.DataFrame({
            'image_path': image_paths,
            'filename': [os.path.basename(path) for path in image_paths],
            'prediction': predictions,
            'suitable_for_vqa': suitable
        })

        results = results.sort_values('filename').reset_index(drop=True)

        if output_csv:
            results.to_csv(output_csv, index=False)
            print(f"Results saved to {output_csv}")

        return results
|
|
|
|
def main():
    """Command-line entry point: classify one image or a directory of images.

    For a file, prints the prediction for that image. For a directory, runs
    ``predict_directory`` (optionally writing a CSV) and prints a summary.
    An ``--input`` that is neither a file nor a directory is reported through
    ``parser.error``, which writes to stderr and exits with status 2.
    """
    import argparse

    parser = argparse.ArgumentParser(description="RadFig VQA Image Filtering Inference")
    parser.add_argument("--input", required=True, help="Input image file or directory")
    parser.add_argument("--models", default="models", help="Directory containing model files")
    parser.add_argument("--output", help="Output CSV file (for directory input)")
    parser.add_argument("--binary", action="store_true", help="Return binary predictions instead of probabilities")

    args = parser.parse_args()

    # Validate the input before loading any checkpoints so a bad path fails
    # fast and with a non-zero exit status.
    input_is_file = os.path.isfile(args.input)
    if not input_is_file and not os.path.isdir(args.input):
        parser.error(f"{args.input} is not a valid file or directory")

    # Probability above which a single image is reported as suitable.
    threshold = 0.9

    classifier = RadFigClassifier(model_dir=args.models)

    if input_is_file:
        prediction = classifier.predict_single(
            args.input,
            return_probability=not args.binary
        )

        print(f"Image: {args.input}")
        if args.binary:
            result = "suitable" if prediction else "not suitable"
            print(f"Prediction: {result} for VQA")
        else:
            print(f"Probability suitable for VQA: {prediction:.4f}")
            print(f"Classification: {'suitable' if prediction > threshold else 'not suitable'}")
        return

    results = classifier.predict_directory(
        args.input,
        output_csv=args.output,
        return_probabilities=not args.binary
    )

    # predict_directory already computes the suitability flag in both modes.
    suitable_count = results['suitable_for_vqa'].sum()
    total_count = len(results)

    print("\nSummary:")
    print(f"Total images: {total_count}")
    print(f"Suitable for VQA: {suitable_count}")
    print(f"Not suitable for VQA: {total_count - suitable_count}")
    print(f"Percentage suitable: {suitable_count/total_count*100:.1f}%")
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()