junaid17 commited on
Commit
1782395
·
verified ·
1 Parent(s): f80a3ac

Upload 3 files

Browse files
Files changed (3) hide show
  1. scripts/gradcam.py +96 -0
  2. scripts/prediction_helper.py +172 -0
  3. scripts/yolo.py +47 -0
scripts/gradcam.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+
def get_resnet_gradcam(image_path, predictor, output_path):
    """Generate and save a Grad-CAM overlay for a ResNet-based predictor.

    Hooks the last block of ``layer4`` to capture activations and gradients,
    backpropagates the top-class score, and writes a JET-colormapped heatmap
    blended over the original image.

    Args:
        image_path: Path to the input image file.
        predictor: Object exposing ``model`` (wrapper with ``model.model.layer4``),
            ``device`` and ``test_transforms`` (see ResnetCarDamagePredictor).
        output_path: Destination path for the overlay image (written via cv2, BGR).

    Returns:
        True on success.
    """
    model = predictor.model
    device = predictor.device
    model.eval()

    features, gradients = [], []

    def forward_hook(module, inputs, output):
        features.append(output)

    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0])

    target_layer = model.model.layer4[-1]
    handle_fw = target_layer.register_forward_hook(forward_hook)
    handle_bw = target_layer.register_full_backward_hook(backward_hook)

    try:
        original_img = Image.open(image_path).convert("RGB")
        input_tensor = predictor.test_transforms(original_img).unsqueeze(0).to(device)

        model.zero_grad()
        output = model(input_tensor)
        pred_class_idx = output.argmax(dim=1).item()

        # Backprop only the predicted class score so gradients reflect it.
        score = output[0, pred_class_idx]
        score.backward()
    finally:
        # Always detach the hooks, even if inference/backprop fails, so
        # repeated calls do not accumulate stale hooks on the shared model.
        handle_fw.remove()
        handle_bw.remove()

    acts = features[0].cpu().data.numpy()[0]    # (C, H, W) activations
    grads = gradients[0].cpu().data.numpy()[0]  # (C, H, W) gradients
    weights = np.mean(grads, axis=(1, 2))       # per-channel importance

    # Weighted sum over channels (vectorized; equivalent to a per-channel loop).
    cam = np.tensordot(weights, acts, axes=1).astype(np.float32)

    cam = np.maximum(cam, 0)  # ReLU: keep only positive class influence
    cam = cv2.resize(cam, (original_img.width, original_img.height))
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)  # normalize to [0, 1]

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    original_np = np.array(original_img)

    # Overlay logic (OpenCV style): 60% original + 40% heatmap, both BGR.
    overlay = cv2.addWeighted(cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0)
    cv2.imwrite(output_path, overlay)
    return True
def get_deit_gradcam(image_path, predictor, output_path):
    """Generate and save a Grad-CAM-style overlay for a DeiT predictor.

    Hooks the pre-attention LayerNorm of the final encoder block, weights
    token activations by their gradients, drops the CLS and distillation
    tokens, reshapes the remaining patch tokens into a grid, and writes a
    JET-colormapped heatmap blended over the original image.

    Args:
        image_path: Path to the input image file.
        predictor: Object exposing ``model`` (DeiTForImageClassification),
            ``processor`` and ``device`` (see DeitCarDamagePredictor).
        output_path: Destination path for the overlay image (written via cv2, BGR).

    Returns:
        True on success.
    """
    model = predictor.model
    processor = predictor.processor
    device = predictor.device
    model.eval()

    features, gradients = [], []

    def forward_hook(module, inputs, output):
        features.append(output)

    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0])

    target_layer = model.deit.encoder.layer[-1].layernorm_before
    handle_fw = target_layer.register_forward_hook(forward_hook)
    handle_bw = target_layer.register_full_backward_hook(backward_hook)

    try:
        original_img = Image.open(image_path).convert("RGB")
        inputs = processor(images=original_img, return_tensors="pt").to(device)

        model.zero_grad()
        outputs = model(**inputs)
        pred_class_idx = outputs.logits.argmax(dim=1).item()

        # Backprop only the predicted class logit.
        score = outputs.logits[0, pred_class_idx]
        score.backward()
    finally:
        # Remove hooks even on failure so the model is left in a clean state.
        handle_fw.remove()
        handle_bw.remove()

    acts = features[0].cpu().data.numpy()[0]    # (tokens, hidden)
    grads = gradients[0].cpu().data.numpy()[0]  # (tokens, hidden)
    cam = np.sum(grads * acts, axis=-1)         # gradient-weighted token relevance
    cam = cam[2:]  # Remove CLS and Distillation tokens

    # Patch tokens form a square grid (e.g. 224/16 = 14 for this checkpoint).
    grid_size = int(np.sqrt(cam.shape[0]))
    cam = cam.reshape(grid_size, grid_size)
    cam = np.maximum(cam, 0)
    # float32 cast keeps cv2.resize on its well-supported dtype path.
    cam = cv2.resize(cam.astype(np.float32), (original_img.width, original_img.height))
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    original_np = np.array(original_img)

    # 60% original + 40% heatmap, both BGR for cv2.imwrite.
    overlay = cv2.addWeighted(cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0)
    cv2.imwrite(output_path, overlay)
    return True
scripts/prediction_helper.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ from PIL import Image, UnidentifiedImageError
6
+ from transformers import DeiTForImageClassification, DeiTImageProcessor
7
+
# ================================ ResNet-18 Classifier ================================
class Car_Classifier_Resnet(nn.Module):
    """ResNet-18 backbone with a custom MLP classification head.

    The pretrained backbone is frozen except for its two deepest stages
    (``layer3`` and ``layer4``), and the stock fully-connected layer is
    swapped for a dropout-regularized two-layer classifier.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = models.resnet18(weights="DEFAULT")

        # Freeze everything first, then selectively re-enable fine-tuning
        # on the deepest two stages only.
        for param in self.model.parameters():
            param.requires_grad = False
        for stage in (self.model.layer3, self.model.layer4):
            for param in stage.parameters():
                param.requires_grad = True

        # Replace the FC head with a small regularized classifier.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Run a batch of images through the (partially fine-tuned) backbone."""
        return self.model(x)
class ResnetCarDamagePredictor:
    """Loads a fine-tuned ResNet-18 checkpoint and classifies damage images.

    Attributes:
        device: CUDA if available, otherwise CPU.
        class_map: Index -> class-name mapping used to label probabilities.
        test_transforms: Evaluation-time preprocessing pipeline.
        model: The loaded Car_Classifier_Resnet in eval mode.
    """

    def __init__(self, checkpoint_path, class_map):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map

        # 128x128 evaluation resolution with standard ImageNet normalization.
        self.test_transforms = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

        try:
            self.model = Car_Classifier_Resnet(num_classes=len(class_map))
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load ResNet model: {str(e)}")

    def resnet_predict(self, image_input):
        """Classify one image.

        Args:
            image_input: File path or PIL.Image.

        Returns:
            {class_name: probability} sorted by descending probability.

        Raises:
            ValueError: If the file is not a readable image.
            RuntimeError: On any other prediction failure.
        """
        try:
            if isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            elif isinstance(image_input, str):
                image = Image.open(image_input).convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")

            batch = self.test_transforms(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                logits = self.model(batch)

            probs = torch.nn.functional.softmax(logits, dim=1)[0]
            scores = {
                self.class_map[i]: float(probs[i].item())
                for i in range(len(self.class_map))
            }
            return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

        except UnidentifiedImageError:
            raise ValueError("Invalid image file provided")
        except Exception as e:
            raise RuntimeError(f"ResNet prediction failed: {str(e)}")
# ================================ DeiT Classifier ================================
class DeitCarDamagePredictor:
    """Loads a fine-tuned distilled DeiT checkpoint and classifies damage images.

    Attributes:
        device: CUDA if available, otherwise CPU.
        class_map: Index -> class-name mapping used to label probabilities.
        processor: Hugging Face image processor for the base model.
        model: DeiTForImageClassification with a replaced head, in eval mode.
    """

    def __init__(self, checkpoint_path, class_map):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_map = class_map
        self.checkpoint_path = checkpoint_path

        # Pre-resize to the model's native 224x224 before the processor runs.
        self.transform = transforms.Compose([transforms.Resize((224, 224))])
        model_name = "facebook/deit-base-distilled-patch16-224"

        try:
            self.processor = DeiTImageProcessor.from_pretrained(model_name)
            # ignore_mismatched_sizes lets us swap the pretrained head for
            # one sized to this task's label count.
            self.model = DeiTForImageClassification.from_pretrained(
                model_name,
                num_labels=len(class_map),
                ignore_mismatched_sizes=True
            )
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load DeiT model: {str(e)}")

    def deit_predict(self, image_input):
        """Classify one image.

        Args:
            image_input: File path or PIL.Image.

        Returns:
            {class_name: probability} sorted by descending probability.

        Raises:
            ValueError: If the file is not a readable image.
            RuntimeError: On any other prediction failure.
        """
        try:
            if isinstance(image_input, Image.Image):
                image = image_input.convert("RGB")
            elif isinstance(image_input, str):
                image = Image.open(image_input).convert("RGB")
            else:
                raise TypeError("image_input must be a file path or PIL.Image")

            resized = self.transform(image)
            inputs = self.processor(resized, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)

            probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
            scores = {
                self.class_map[i]: float(probs[i].item())
                for i in range(len(self.class_map))
            }
            return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

        except UnidentifiedImageError:
            raise ValueError("Invalid image file provided")
        except Exception as e:
            raise RuntimeError(f"DeiT prediction failed: {str(e)}")
# ================================ Fusion Predictor ================================
class FusionCarDamagePredictor:
    """Weighted late-fusion ensemble over a ResNet and a DeiT predictor.

    The two model outputs (per-class probability dicts) are combined as a
    convex combination; weights are normalized at construction time so they
    always sum to 1.
    """

    def __init__(self, resnet_predictor, deit_predictor, resnet_weight=0.5, deit_weight=0.5):
        if resnet_weight < 0 or deit_weight < 0:
            raise ValueError("Weights must be non-negative")
        total = resnet_weight + deit_weight
        if total == 0:
            raise ValueError("At least one weight must be greater than 0")

        self.resnet_predictor = resnet_predictor
        self.deit_predictor = deit_predictor
        # Normalize so downstream fusion is a proper weighted average.
        self.resnet_weight = resnet_weight / total
        self.deit_weight = deit_weight / total

    def fuse_predict(self, image_input):
        """Run both predictors and fuse their probability distributions.

        Args:
            image_input: File path or PIL.Image, forwarded to both predictors.

        Returns:
            dict with the individual outputs, the fused (sorted) distribution,
            the top fused class and its confidence rounded to 4 decimals.

        Raises:
            RuntimeError: If either underlying prediction or the fusion fails.
        """
        try:
            resnet_output = self.resnet_predictor.resnet_predict(image_input)
            deit_output = self.deit_predictor.deit_predict(image_input)

            # Union of labels guards against the two models disagreeing on
            # the label set; missing entries contribute probability 0.
            labels = set(resnet_output.keys()) | set(deit_output.keys())
            fused_output = {
                cls: float(self.resnet_weight * resnet_output.get(cls, 0.0)
                           + self.deit_weight * deit_output.get(cls, 0.0))
                for cls in labels
            }

            fused_output = dict(sorted(fused_output.items(), key=lambda kv: kv[1], reverse=True))
            final_class, final_confidence = next(iter(fused_output.items()))

            return {
                "resnet_output": resnet_output,
                "deit_output": deit_output,
                "fused_output": fused_output,
                "final_prediction": final_class,
                "final_confidence": round(final_confidence, 4)
            }
        except Exception as e:
            raise RuntimeError(f"Fusion prediction failed: {str(e)}")
scripts/yolo.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ from ultralytics import YOLO
5
+
# Module-level YOLO damage detector, loaded once at import time.
# NOTE(review): checkpoint path is relative to the process working
# directory — confirm the app is always launched from the project root.
yolo_model = YOLO("checkpoints/damage_detector.pt")
def get_yolo_damage_boxes(image_path, output_path):
    """Run the YOLO damage detector on one image and save an annotated copy.

    Args:
        image_path: Path to the input image file.
        output_path: Where the annotated image (boxes drawn) is written.

    Returns:
        dict with "detections" (list of {label, confidence, box}),
        "total_detections", and a human-readable "message".

    Raises:
        RuntimeError: If loading, inference, or writing fails.
    """
    try:
        image = Image.open(image_path).convert("RGB")
        results = yolo_model.predict(
            source=image,
            conf=0.05,  # deliberately low threshold to surface faint damage
            imgsz=640,
            verbose=False
        )

        result = results[0]
        boxes = result.boxes
        detections = []

        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                conf = float(box.conf[0])
                cls_id = int(box.cls[0])
                label = yolo_model.names[cls_id]
                x1, y1, x2, y2 = map(int, box.xyxy[0])

                detections.append({
                    "label": label,
                    "confidence": round(conf, 4),
                    "box": [x1, y1, x2, y2]
                })

        # BUG FIX: result.plot() already returns a BGR array, which is the
        # channel order cv2.imwrite expects. The previous flip to RGB
        # ([..., ::-1]) made the saved file's red/blue channels swapped.
        plotted_bgr = result.plot()
        cv2.imwrite(output_path, plotted_bgr)

        return {
            "detections": detections,
            "total_detections": len(detections),
            "message": "No damage detected" if len(detections) == 0 else "Detections found"
        }

    except Exception as e:
        # Chain the original exception for debuggability.
        raise RuntimeError(f"YOLO failed: {str(e)}") from e