"""Car-damage image classifiers: a ResNet-18 model, a DeiT model and a fusion of both.

Each predictor exposes a ``*_predict`` method that accepts a file path or a
``PIL.Image`` and returns a ``{class_name: probability}`` dict sorted by
descending probability.
"""
import os

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
from transformers import DeiTForImageClassification, DeiTImageProcessor


# ================================ Shared helpers ================================
def _load_rgb_image(image_input):
    """Return ``image_input`` as an RGB ``PIL.Image``.

    Args:
        image_input: A filesystem path (``str`` or ``os.PathLike``) or an
            already opened ``PIL.Image.Image``.

    Returns:
        A new RGB image that does not depend on any open file handle.

    Raises:
        TypeError: If ``image_input`` is neither a path nor a PIL image.
        PIL.UnidentifiedImageError: If the file is not a readable image.
        OSError: If the file cannot be opened.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # Open through a context manager so the file handle is released as
        # soon as the pixel data has been converted, instead of lingering
        # until garbage collection.
        with Image.open(image_input) as opened:
            return opened.convert("RGB")
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    raise TypeError("image_input must be a file path or PIL.Image")


def _rank_class_probs(class_map, probs):
    """Map a 1-D probability tensor to class names, highest probability first.

    Args:
        class_map: Indexable by ``0 .. len(class_map) - 1`` (a list, or a dict
            keyed by those integers) and yielding the class name for each index.
        probs: 1-D tensor holding one probability per class index.

    Returns:
        ``{class_name: probability}`` as plain Python floats, sorted by
        descending probability.
    """
    # One bulk transfer to Python floats instead of one ``.item()`` per class.
    values = probs.tolist()
    class_probs = {
        class_map[index]: float(values[index]) for index in range(len(class_map))
    }
    return dict(sorted(class_probs.items(), key=lambda item: item[1], reverse=True))


# ================================ ResNet-18 Classifier ================================
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 backbone with a custom two-layer classification head.

    The backbone is created with torchvision's ``"DEFAULT"`` weights. All
    backbone parameters are frozen except those of ``layer3`` and ``layer4``;
    the replacement ``fc`` head is trainable.
    """

    def __init__(self, num_classes):
        """Build the network.

        Args:
            num_classes: Number of output logits produced by the head.
        """
        super().__init__()
        self.model = models.resnet18(weights="DEFAULT")

        # Freeze everything first, then re-enable gradients for the two
        # deepest residual stages.
        for param in self.model.parameters():
            param.requires_grad = False
        for stage in (self.model.layer3, self.model.layer4):
            for param in stage.parameters():
                param.requires_grad = True

        # Replace FC head
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.model.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.model(x)


class ResnetCarDamagePredictor:
    """Inference wrapper around a fine-tuned :class:`Car_Classifier_Resnet`."""

    def __init__(self, checkpoint_path, class_map):
        """Load the model weights and prepare the preprocessing pipeline.

        Args:
            checkpoint_path: Path to a checkpoint readable by ``torch.load``
                that contains a ``"model_state_dict"`` entry.
            class_map: Indexable by class index, yielding the class name.

        Raises:
            RuntimeError: If the model cannot be built or the checkpoint
                cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ResNet model: {str(e)}") from e

    def resnet_predict(self, image_input):
        """Classify one image with the ResNet model.

        Args:
            image_input: File path or ``PIL.Image``.

        Returns:
            ``{class_name: probability}`` sorted by descending probability.

        Raises:
            ValueError: If the file is not a recognisable image.
            RuntimeError: For any other failure, including an unsupported
                ``image_input`` type.
        """
        try:
            image = _load_rgb_image(image_input)
            batch = self.test_transforms(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]

            return _rank_class_probs(self.class_map, probs)

        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {str(e)}") from e


# ================================ DeiT Classifier ================================
class DeitCarDamagePredictor:
    """Inference wrapper around a fine-tuned DeiT image classifier."""

    def __init__(self, checkpoint_path, class_map):
        """Load the processor, the base model and the fine-tuned weights.

        Args:
            checkpoint_path: Path to a checkpoint readable by ``torch.load``
                that contains a ``"model_state_dict"`` entry.
            class_map: Indexable by class index, yielding the class name.

        Raises:
            RuntimeError: If the processor, model or checkpoint cannot be
                loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        self.checkpoint_path = checkpoint_path

        self.transform = transforms.Compose([transforms.Resize((224, 224))])

        model_name = "facebook/deit-base-distilled-patch16-224"

        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_name)
            # The head is re-created with len(class_map) labels, so its shape
            # is allowed to differ from the one stored in the base checkpoint.
            self.model = DeiTForImageClassification.from_pretrained(
                model_name,
                num_labels=len(class_map),
                ignore_mismatched_sizes=True,
            )

            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load DeiT model: {str(e)}") from e

    def deit_predict(self, image_input):
        """Classify one image with the DeiT model.

        Args:
            image_input: File path or ``PIL.Image``.

        Returns:
            ``{class_name: probability}`` sorted by descending probability.

        Raises:
            ValueError: If the file is not a recognisable image.
            RuntimeError: For any other failure, including an unsupported
                ``image_input`` type.
        """
        try:
            image = self.transform(_load_rgb_image(image_input))
            inputs = self.processor(image, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]

            return _rank_class_probs(self.class_map, probs)

        except UnidentifiedImageError as e:
            raise ValueError("Invalid image file provided") from e
        except Exception as e:
            raise RuntimeError(f"DeiT prediction failed: {str(e)}") from e


# ================================ Fusion Predictor ================================
class FusionCarDamagePredictor:
    """Weighted average of the ResNet and DeiT class probabilities."""

    def __init__(self, resnet_predictor, deit_predictor, resnet_weight=0.5, deit_weight=0.5):
        """Store the predictors and normalise the weights so they sum to 1.

        Args:
            resnet_predictor: Object exposing ``resnet_predict(image_input)``.
            deit_predictor: Object exposing ``deit_predict(image_input)``.
            resnet_weight: Non-negative relative weight of the ResNet output.
            deit_weight: Non-negative relative weight of the DeiT output.

        Raises:
            ValueError: If a weight is negative or both weights are zero.
        """
        if resnet_weight < 0 or deit_weight < 0:
            raise ValueError("Weights must be non-negative")

        total = resnet_weight + deit_weight
        if total == 0:
            raise ValueError("At least one weight must be greater than 0")

        self.resnet_predictor = resnet_predictor
        self.deit_predictor = deit_predictor
        self.resnet_weight = resnet_weight / total
        self.deit_weight = deit_weight / total

    def fuse_predict(self, image_input):
        """Run both predictors and combine their probabilities.

        Args:
            image_input: File path or ``PIL.Image`` forwarded to both
                predictors.

        Returns:
            A dict with the keys ``"resnet_output"``, ``"deit_output"``,
            ``"fused_output"`` (sorted by descending probability),
            ``"final_prediction"`` and ``"final_confidence"`` (rounded to
            four decimals).

        Raises:
            RuntimeError: If either predictor fails or no class is returned.
        """
        try:
            resnet_output = self.resnet_predictor.resnet_predict(image_input)
            deit_output = self.deit_predictor.deit_predict(image_input)

            # Ordered union (ResNet ranking first, then DeiT-only classes).
            # Iterating a set here would make the winner among classes with
            # equal fused probability depend on hash order and vary between
            # runs; the stable sort below now breaks ties reproducibly.
            all_classes = dict.fromkeys([*resnet_output, *deit_output])
            if not all_classes:
                raise ValueError("both predictors returned no classes")

            fused_output = {
                cls: float(
                    self.resnet_weight * resnet_output.get(cls, 0.0)
                    + self.deit_weight * deit_output.get(cls, 0.0)
                )
                for cls in all_classes
            }
            fused_output = dict(
                sorted(fused_output.items(), key=lambda item: item[1], reverse=True)
            )

            final_class = next(iter(fused_output))
            final_confidence = fused_output[final_class]

            return {
                "resnet_output": resnet_output,
                "deit_output": deit_output,
                "fused_output": fused_output,
                "final_prediction": final_class,
                "final_confidence": round(final_confidence, 4),
            }

        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {str(e)}") from e