hiren05 committed
Commit 7db50ad · verified · 1 Parent(s): c331f29

Upload 2 files

Files changed (2)
  1. amodal_completion_model.pth +3 -0
  2. final1_2.py +981 -0
amodal_completion_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4933917dbfac8f3970d150b7fce00c95c58da724e42f21d74964ea53c13c625a
3
+ size 124272930
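
The three lines above are a Git LFS pointer: the ~124 MB checkpoint itself is stored in LFS, not in the git history. Below is a minimal loading sketch, assuming the weights correspond to the ImprovedUNet defined in final1_2.py; the repo_id is a placeholder, not something stated on this page.

# Sketch only -- repo_id is hypothetical.
from huggingface_hub import hf_hub_download
import torch

ckpt_path = hf_hub_download(
    repo_id="hiren05/<model-repo>",          # placeholder, replace with the actual repo id
    filename="amodal_completion_model.pth",  # the file added in this commit
)
state_dict = torch.load(ckpt_path, map_location="cpu")
# model = ImprovedUNet(); model.load_state_dict(state_dict)   # ImprovedUNet is defined in final1_2.py
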
final1_2.py ADDED
@@ -0,0 +1,981 @@
1
+ # -*- coding: utf-8 -*-
2
+ """final1.2.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1v6-6x7lqt6gr9VIauNVHIwjvIkewk8eT
8
+ """
9
+
10
+
11
+
12
+ """## FINAL 1.2"""
13
+
14
+
15
+
16
+ !pip install torchmetrics lpips
17
+
18
+ # PyTorch, Torchvision
19
+ import torch
20
+ from torch import nn
21
+ from torchvision.transforms import ToPILImage, ToTensor
22
+ from torchvision.utils import make_grid
23
+ from torchvision.io import write_video
24
+
25
+ # Common
26
+ from pathlib import Path
27
+ from PIL import Image
28
+ import numpy as np
29
+ import matplotlib.pyplot as plt
30
+ import random
31
+ import json
32
+ from IPython.display import Video
33
+
34
+ # Utils from Torchvision
35
+ tensor_to_image = ToPILImage()
36
+ image_to_tensor = ToTensor()
37
+
38
+ def get_img_dict(img_dir):
39
+ img_files = [x for x in img_dir.iterdir() if x.name.endswith('.png') or x.name.endswith('.tiff')]
40
+ img_files.sort()
41
+
42
+ img_dict = {}
43
+ for img_file in img_files:
44
+ img_type = img_file.name.split('_')[0]
45
+ if img_type not in img_dict:
46
+ img_dict[img_type] = []
47
+ img_dict[img_type].append(img_file)
48
+ return img_dict
49
+
50
+
51
+ def get_sample_dict(sample_dir):
52
+
53
+ camera_dirs = [x for x in sample_dir.iterdir() if 'camera' in x.name]
54
+ camera_dirs.sort()
55
+
56
+ sample_dict = {}
57
+
58
+ for cam_dir in camera_dirs:
59
+ cam_dict = {}
60
+ cam_dict['scene'] = get_img_dict(cam_dir)
61
+
62
+ obj_dirs = [x for x in cam_dir.iterdir() if 'obj_' in x.name]
63
+ obj_dirs.sort()
64
+
65
+ for obj_dir in obj_dirs:
66
+ cam_dict[obj_dir.name] = get_img_dict(obj_dir)
67
+
68
+ sample_dict[cam_dir.name] = cam_dict
69
+
70
+ return sample_dict
71
+
72
+ !wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/test_obj_descriptors.json
73
+ # Download descriptors, README, etc.
74
+ !wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/train_obj_descriptors.json
75
+ !wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/ex_vis.mp4
76
+ !wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/README.md
77
+ !wget "https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/Notice%201%20-%20Unlimited_datasets.pdf"
78
+ !wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/.gitattributes
79
+ # Check that you are on the right Hugging Face repo
80
+ from huggingface_hub import HfApi, hf_hub_download
81
+ import random, os
82
+ api = HfApi()
83
+ repo_id = "Amar-S/MOVi-MC-AC"
84
+ # List all files in the repo
85
+ files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
86
+ # Separate train and test files
87
+ train_files = [f for f in files if f.startswith("train/") and not f.endswith(".json")]
88
+ test_files = [f for f in files if f.startswith("test/") and not f.endswith(".json")]
89
+ print(f"Found {len(train_files)} train files and {len(test_files)} test files.")
90
+ # Download 0.5% of the train/test files
91
+ import os
92
+ import random
93
+ import shutil
94
+ from huggingface_hub import hf_hub_download
95
+ os.makedirs("/content/data/train", exist_ok=True)
96
+ os.makedirs("/content/data/test", exist_ok=True)
97
+ # Sample 0.5% of each split
98
+ subset_train = random.sample(train_files, int(len(train_files) * 0.005))
99
+ subset_test = random.sample(test_files, int(len(test_files) * 0.005))
100
+ # Download the training files
101
+ for file in subset_train:
102
+ out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
103
+ dest_path = f"/content/data/train/{os.path.basename(file)}"
104
+ shutil.copyfile(out_path, dest_path) # COPY the actual file content instead of renaming symlink
105
+ # Download the test files
106
+ for file in subset_test:
107
+ out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
108
+ dest_path = f"/content/data/test/{os.path.basename(file)}"
109
+ shutil.copyfile(out_path, dest_path) # COPY the actual file content here as well
110
+
111
+ import os
112
+
113
+ # Untar all files in data/train
114
+ train_dir = "data/train"
115
+ for file in os.listdir(train_dir):
116
+ if file.endswith(".tar.gz"):
117
+ filepath = os.path.join(train_dir, file)
118
+ !tar -xzf {filepath} -C {train_dir}
119
+
120
+ # Untar all files in data/test
121
+ test_dir = "data/test"
122
+ for file in os.listdir(test_dir):
123
+ if file.endswith(".tar.gz"):
124
+ filepath = os.path.join(test_dir, file)
125
+ !tar -xzf {filepath} -C {test_dir}
126
+
127
+
128
+
129
+ import os
130
+ from pathlib import Path
131
+ root = Path('/content/data') # or wherever your files live
132
+ deleted = 0
133
+ for archive in root.rglob('*.tar.gz'):
134
+ try:
135
+ archive.unlink()
136
+ print(f"Deleted {archive}")
137
+ deleted += 1
138
+ except Exception as e:
139
+ print(f"Error deleting {archive}: {e}")
140
+ print(f"Total deleted: {deleted}")
141
+
142
+ !pip install torchmetrics lpips
143
+
144
+ import matplotlib.pyplot as plt
145
+ from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
146
+ import lpips
147
+ import torch.nn.functional as F  # used by evaluate_metrics below
148
+ import torch
149
+
150
+ def visualize_results(model, dataloader, device, num_samples=8):
151
+ """Visualize results with properly masked output (no background)"""
152
+ model.eval()
153
+ samples_shown = 0
154
+
155
+ with torch.no_grad():
156
+ for batch in dataloader:
157
+ if samples_shown >= num_samples:
158
+ break
159
+
160
+ rgb = batch['rgb'].to(device)
161
+ modal_mask = batch['modal_mask'].to(device)
162
+ amodal_mask = batch['amodal_mask'].to(device)
163
+ gt_amodal_rgb = batch['amodal_rgb'].to(device)
164
+
165
+ input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
166
+ pred = model(input_tensor)
167
+
168
+ pred_masked = pred * amodal_mask # Remove background from prediction
169
+ gt_masked = gt_amodal_rgb * amodal_mask # Ensure GT is also masked consistently
170
+
171
+ for i in range(rgb.shape[0]):
172
+ if samples_shown >= num_samples:
173
+ break
174
+
175
+ fig, axes = plt.subplots(1, 6, figsize=(24, 4))
176
+
177
+ # Scene RGB
178
+ axes[0].imshow(rgb[i].cpu().permute(1, 2, 0))
179
+ axes[0].set_title('Scene RGB')
180
+ axes[0].axis('off')
181
+
182
+ # Amodal Mask
183
+ axes[1].imshow(amodal_mask[i, 0].cpu(), cmap='gray')
184
+ axes[1].set_title('Amodal Mask')
185
+ axes[1].axis('off')
186
+
187
+ # Modal Mask
188
+ axes[2].imshow(modal_mask[i, 0].cpu(), cmap='gray')
189
+ axes[2].set_title('Modal Mask')
190
+ axes[2].axis('off')
191
+
192
+ # Ground Truth Amodal RGB (masked)
193
+ axes[3].imshow(gt_masked[i].cpu().permute(1, 2, 0))
194
+ axes[3].set_title('GT Amodal RGB')
195
+ axes[3].axis('off')
196
+
197
+ # Predicted Amodal RGB (masked)
198
+ axes[4].imshow(pred_masked[i].cpu().permute(1, 2, 0))
199
+ axes[4].set_title('Predicted Amodal RGB')
200
+ axes[4].axis('off')
201
+
202
+ # Difference Heatmap
203
+ diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
204
+ im = axes[5].imshow(diff.cpu(), cmap='hot')
205
+ axes[5].set_title('Prediction Error')
206
+ axes[5].axis('off')
207
+ plt.colorbar(im, ax=axes[5])
208
+
209
+ plt.tight_layout()
210
+ plt.show()
211
+
212
+ samples_shown += 1
213
+
214
+
215
+
216
+ # Region-wise evaluation: MSE restricted to visible/occluded object regions
217
+ def evaluate_metrics(model, dataloader, device):
218
+ """Compute evaluation metrics only within object regions"""
219
+ model.eval()
220
+ total_mse = 0
221
+ occluded_mse = 0
222
+ visible_mse = 0
223
+ total_pixels = 0
224
+ occluded_pixels = 0
225
+ visible_pixels = 0
226
+
227
+ with torch.no_grad():
228
+ for batch in dataloader:
229
+ rgb = batch['rgb'].to(device)
230
+ modal_mask = batch['modal_mask'].to(device)
231
+ amodal_mask = batch['amodal_mask'].to(device)
232
+ occluded_mask = batch['occluded_mask'].to(device)
233
+ gt_amodal_rgb = batch['amodal_rgb'].to(device)
234
+
235
+ input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
236
+ pred = model(input_tensor)
237
+
238
+ # Mask both prediction and ground truth to object regions only
239
+ pred_masked = pred * amodal_mask
240
+ gt_masked = gt_amodal_rgb * amodal_mask
241
+
242
+ # Overall MSE within object region
243
+ object_pixels = amodal_mask.sum()
244
+ if object_pixels > 0:
245
+ mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
246
+ total_mse += mse.item()
247
+ total_pixels += object_pixels.item()
248
+
249
+ # Occluded region MSE
250
+ occluded_region = occluded_mask * amodal_mask
251
+ occ_pixels = occluded_region.sum()
252
+ if occ_pixels > 0:
253
+ occ_mse = F.mse_loss(pred_masked * occluded_region,
254
+ gt_masked * occluded_region, reduction='sum')
255
+ occluded_mse += occ_mse.item()
256
+ occluded_pixels += occ_pixels.item()
257
+
258
+ # Visible region MSE
259
+ visible_region = modal_mask * amodal_mask
260
+ vis_pixels = visible_region.sum()
261
+ if vis_pixels > 0:
262
+ vis_mse = F.mse_loss(pred_masked * visible_region,
263
+ gt_masked * visible_region, reduction='sum')
264
+ visible_mse += vis_mse.item()
265
+ visible_pixels += vis_pixels.item()
266
+
267
+ return {
268
+ 'total_mse': total_mse / total_pixels if total_pixels > 0 else 0,
269
+ 'occluded_mse': occluded_mse / occluded_pixels if occluded_pixels > 0 else 0,
270
+ 'visible_mse': visible_mse / visible_pixels if visible_pixels > 0 else 0,
271
+ }
272
+
273
+
274
+
275
+ def calculate_metrics(model, dataloader, device):
276
+ """Computes PSNR, SSIM, LPIPS, and IoU between predictions and GT amodal RGBs."""
277
+
278
+ model.eval()
279
+ psnr = PeakSignalNoiseRatio().to(device)
280
+ ssim = StructuralSimilarityIndexMeasure().to(device)
281
+ lpips_loss = lpips.LPIPS(net='alex').to(device)
282
+
283
+ total_psnr, total_ssim, total_lpips = 0, 0, 0
284
+ total_iou = 0
285
+ count = 0
286
+
287
+ with torch.no_grad():
288
+ for batch in dataloader:
289
+ rgb = batch['rgb'].to(device)
290
+ modal_mask = batch['modal_mask'].to(device)
291
+ amodal_mask = batch['amodal_mask'].to(device)
292
+ gt_amodal_rgb = batch['amodal_rgb'].to(device)
293
+
294
+ input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
295
+ pred = model(input_tensor)
296
+
297
+ pred_masked = pred * amodal_mask
298
+ gt_masked = gt_amodal_rgb * amodal_mask
299
+
300
+ for i in range(pred.shape[0]):
301
+ pred_i = pred_masked[i].unsqueeze(0)
302
+ gt_i = gt_masked[i].unsqueeze(0)
303
+
304
+ # Skip samples smaller than 64x64 (LPIPS requires inputs of at least 64x64)
305
+ if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
306
+ continue
307
+
308
+ total_psnr += psnr(pred_i, gt_i).item()
309
+ total_ssim += ssim(pred_i, gt_i).item()
310
+ total_lpips += lpips_loss(pred_i, gt_i).item()
311
+
312
+ # Approximate IoU: threshold the RGB prediction at 0.5 and compare against the amodal mask (broadcast over channels)
313
+ intersection = (amodal_mask[i] * (pred[i] > 0.5)).sum()
314
+ union = ((amodal_mask[i] + (pred[i] > 0.5)) > 0).sum()
315
+ if union > 0:
316
+ iou = intersection.float() / union.float()
317
+ total_iou += iou.item()
318
+
319
+ count += 1
320
+
321
+ if count == 0:
322
+ return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}
323
+
324
+ return {
325
+ "psnr": total_psnr / count,
326
+ "ssim": total_ssim / count,
327
+ "lpips": total_lpips / count,
328
+ "miou": total_iou / count
329
+ }
330
+
521
+ import torch
522
+ import torch.nn as nn
523
+ import torch.nn.functional as F
524
+ from torch.utils.data import Dataset, DataLoader
525
+ from torchvision import transforms
526
+ from pathlib import Path
527
+ from PIL import Image, ImageChops
528
+ import numpy as np
529
+
530
+ class ModalAmodalDataset(Dataset):
531
+ def __init__(self, root_dir, split, img_size=(128, 128), max_samples=None, val_split=0.2, use_val_from_train=False):
532
+ self.root_dir = Path(root_dir)
533
+ self.img_size = img_size
534
+ self.max_samples = max_samples
535
+ self.val_split = val_split
536
+ self.use_val_from_train = use_val_from_train
537
+ self.split = split
538
+
539
+ if split == 'val' and use_val_from_train:
540
+ # Load from train folder but use validation subset
541
+ self.root_dir = self.root_dir / 'train'
542
+ else:
543
+ self.root_dir = self.root_dir / split
544
+
545
+ self.samples = self._build_sample_index()
546
+
547
+ self.rgb_transform = transforms.Compose([
548
+ transforms.Resize(img_size),
549
+ transforms.ToTensor(),
550
+ ])
551
+ self.mask_transform = transforms.Compose([
552
+ transforms.Resize(img_size),
553
+ transforms.ToTensor(),
554
+ ])
555
+
556
+ def _build_sample_index(self):
557
+ samples = []
558
+ for scene_dir in self.root_dir.iterdir():
559
+ if not scene_dir.is_dir():
560
+ continue
561
+ for camera_dir in scene_dir.iterdir():
562
+ if not camera_dir.name.startswith('camera_'):
563
+ continue
564
+
565
+ rgba_paths = sorted(camera_dir.glob('rgba_*.png'))
566
+ seg_paths = sorted(camera_dir.glob('segmentation_*.png'))
567
+
568
+ for obj_dir in camera_dir.iterdir():
569
+ if not obj_dir.name.startswith('obj_'):
570
+ continue
571
+
572
+ amodal_paths = sorted(obj_dir.glob('segmentation_*.png'))
573
+ amodal_rgb_paths = sorted(obj_dir.glob('rgba_*.png'))
574
+
575
+ if not (len(rgba_paths) == len(seg_paths) == len(amodal_paths) == len(amodal_rgb_paths)):
576
+ continue
577
+
578
+ for rgba_path, seg_path, amodal_path, amodal_rgb_path in zip(
579
+ rgba_paths, seg_paths, amodal_paths, amodal_rgb_paths
580
+ ):
581
+ samples.append({
582
+ 'rgb_path': rgba_path,
583
+ 'seg_path': seg_path,
584
+ 'amodal_path': amodal_path,
585
+ 'amodal_rgb_path': amodal_rgb_path,
586
+ 'object_id': int(obj_dir.name.split('_')[1]),
587
+ 'scene': scene_dir.name,
588
+ 'camera': camera_dir.name
589
+ })
590
+
591
+ # Limit dataset size if specified
592
+ if self.max_samples is not None and len(samples) > self.max_samples:
593
+ # Randomly sample to get diverse examples
594
+ import random
595
+ random.seed(42) # For reproducibility
596
+ samples = random.sample(samples, self.max_samples)
597
+ print(f"Dataset limited to {len(samples)} samples")
598
+
599
+ # Create train/val split if using validation from train
600
+ if self.use_val_from_train:
601
+ import random
602
+ random.seed(42) # Ensure reproducible splits
603
+ random.shuffle(samples)
604
+
605
+ val_size = int(len(samples) * self.val_split)
606
+ if self.split == 'train':
607
+ samples = samples[val_size:] # Use remaining samples for training
608
+ print(f"Train split: {len(samples)} samples")
609
+ elif self.split == 'val':
610
+ samples = samples[:val_size] # Use first samples for validation
611
+ print(f"Validation split: {len(samples)} samples")
612
+
613
+ return samples
614
+
615
+ def __len__(self):
616
+ return len(self.samples)
617
+
618
+ def __getitem__(self, idx):
619
+ sample = self.samples[idx]
620
+
621
+ # Load images
622
+ rgb = Image.open(sample['rgb_path']).convert('RGB')
623
+ seg_map = np.array(Image.open(sample['seg_path']))
624
+ amodal_mask_img = Image.open(sample['amodal_path']).convert('L')
625
+ amodal_rgb = Image.open(sample['amodal_rgb_path']).convert('RGB')
626
+
627
+ # Compute modal mask (visible part)
628
+ modal_mask_np = (seg_map == sample['object_id']).astype(np.uint8) * 255
629
+ modal_mask_img = Image.fromarray(modal_mask_np, mode='L')
630
+
631
+ # Transform images and masks
632
+ rgb = self.rgb_transform(rgb)
633
+ modal_mask = self.mask_transform(modal_mask_img)
634
+ amodal_mask = self.mask_transform(amodal_mask_img)
635
+ amodal_rgb = self.rgb_transform(amodal_rgb)
636
+
637
+ # Create occluded mask (parts that are hidden)
638
+ occluded_mask = amodal_mask - modal_mask
639
+ occluded_mask = torch.clamp(occluded_mask, 0, 1)
640
+
641
+ return {
642
+ 'rgb': rgb,
643
+ 'modal_mask': modal_mask,
644
+ 'amodal_mask': amodal_mask,
645
+ 'occluded_mask': occluded_mask,
646
+ 'amodal_rgb': amodal_rgb,
647
+ }
648
+
649
+
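
A small usage sketch for the dataset above (it assumes the data/ layout produced by the download cells earlier; this snippet is illustrative and not part of the committed script):

# Hypothetical quick check of the returned tensors.
ds = ModalAmodalDataset(root_dir='data', split='train', img_size=(128, 128),
                        max_samples=10, use_val_from_train=True)
item = ds[0]
for k, v in item.items():
    print(k, tuple(v.shape))   # rgb/amodal_rgb: (3, 128, 128); the three masks: (1, 128, 128)
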
650
+ class ImprovedUNet(nn.Module):
651
+
652
+ def __init__(self, in_channels=5, out_channels=3): # RGB + modal_mask + amodal_mask
653
+ super().__init__()
654
+
655
+ def conv_block(in_ch, out_ch, dropout=0.1):
656
+ return nn.Sequential(
657
+ nn.Conv2d(in_ch, out_ch, 3, padding=1),
658
+ nn.BatchNorm2d(out_ch),
659
+ nn.ReLU(inplace=True),
660
+ nn.Dropout2d(dropout),
661
+ nn.Conv2d(out_ch, out_ch, 3, padding=1),
662
+ nn.BatchNorm2d(out_ch),
663
+ nn.ReLU(inplace=True)
664
+ )
665
+
666
+ # Encoder
667
+ self.down1 = conv_block(in_channels, 64)
668
+ self.pool1 = nn.MaxPool2d(2)
669
+ self.down2 = conv_block(64, 128)
670
+ self.pool2 = nn.MaxPool2d(2)
671
+ self.down3 = conv_block(128, 256)
672
+ self.pool3 = nn.MaxPool2d(2)
673
+ self.down4 = conv_block(256, 512)
674
+ self.pool4 = nn.MaxPool2d(2)
675
+
676
+ # Bottleneck
677
+ self.middle = conv_block(512, 1024, dropout=0.2)
678
+
679
+ # Decoder
680
+ self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
681
+ self.up_block1 = conv_block(1024, 512)
682
+ self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
683
+ self.up_block2 = conv_block(512, 256)
684
+ self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
685
+ self.up_block3 = conv_block(256, 128)
686
+ self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
687
+ self.up_block4 = conv_block(128, 64)
688
+
689
+ self.final = nn.Conv2d(64, out_channels, 1)
690
+
691
+ def forward(self, x):
692
+ # Encoder
693
+ d1 = self.down1(x)
694
+ d2 = self.down2(self.pool1(d1))
695
+ d3 = self.down3(self.pool2(d2))
696
+ d4 = self.down4(self.pool3(d3))
697
+
698
+ # Bottleneck
699
+ m = self.middle(self.pool4(d4))
700
+
701
+ # Decoder with skip connections
702
+ u1 = self.up_block1(torch.cat([self.up1(m), d4], dim=1))
703
+ u2 = self.up_block2(torch.cat([self.up2(u1), d3], dim=1))
704
+ u3 = self.up_block3(torch.cat([self.up3(u2), d2], dim=1))
705
+ u4 = self.up_block4(torch.cat([self.up4(u3), d1], dim=1))
706
+
707
+ return torch.sigmoid(self.final(u4)) # Ensure output is in [0,1]
708
+
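
A quick shape sanity check for the network above (a sketch, not part of the committed script; it assumes the 128x128 img_size used later in the file):

# Hypothetical check: 5-channel input (RGB + modal mask + amodal mask) -> 3-channel RGB in [0, 1].
import torch
model = ImprovedUNet(in_channels=5, out_channels=3)
dummy = torch.randn(2, 5, 128, 128)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # torch.Size([2, 3, 128, 128])
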
709
+ class AmodalCompletionLoss(nn.Module):
710
+ """Loss that only considers object regions (ignores background)"""
711
+
712
+ def __init__(self, occluded_weight=5.0, visible_weight=1.0):
713
+ super().__init__()
714
+ self.occluded_weight = occluded_weight
715
+ self.visible_weight = visible_weight
716
+ self.lpips_model = lpips.LPIPS(net='alex')
717
+
718
+ def forward(self, pred, target, modal_mask, occluded_mask, amodal_mask):
719
+ # Only compute loss within the amodal mask (object region)
720
+ device = pred.device
721
+ self.lpips_model = self.lpips_model.to(device)
722
+
723
+ pred_masked = pred * amodal_mask
724
+ target_masked = target * amodal_mask
725
+
726
+
727
+
728
+ # Loss on visible parts (within object)
729
+ visible_region = modal_mask * amodal_mask
730
+ if visible_region.sum() > 0:
731
+ visible_loss = F.mse_loss(pred_masked * visible_region, target_masked * visible_region)
732
+ else:
733
+ visible_loss = torch.tensor(0.0).to(pred.device)
734
+
735
+ # Loss on occluded parts (within object)
736
+ occluded_region = occluded_mask * amodal_mask
737
+ if occluded_region.sum() > 0:
738
+ occluded_loss = F.mse_loss(pred_masked * occluded_region, target_masked * occluded_region)
739
+ else:
740
+ occluded_loss = torch.tensor(0.0).to(pred.device)
741
+
742
+
743
+ perceptual_loss = self.lpips_model(pred_masked, target_masked).mean()  # note: computed but not added into total_loss below
744
+
745
+ # Boundary consistency within object
746
+ boundary_mask = F.conv2d(amodal_mask, torch.ones(1,1,3,3).to(amodal_mask.device), padding=1)
747
+ boundary_mask = ((boundary_mask > 0) & (boundary_mask < 9)).float()
748
+ boundary_loss = F.mse_loss(pred_masked * boundary_mask, target_masked * boundary_mask)
749
+
750
+ total_loss = (self.visible_weight * visible_loss +
751
+ self.occluded_weight * occluded_loss +
752
+ 2.0 * boundary_loss)
753
+
754
+ return total_loss, visible_loss, occluded_loss, boundary_loss
755
+
756
+
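
Restating the objective returned above with the default constructor weights (visible_weight=1.0, occluded_weight=5.0), as a plain formula:

total_loss = 1.0 * MSE_visible + 5.0 * MSE_occluded + 2.0 * MSE_boundary

Each term is an MSE restricted to the corresponding region inside the amodal mask; the occluded region is weighted most heavily because it is the part the model must reconstruct without direct evidence.
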
757
+ def train_improved(model, dataloader, optimizer, device, num_epochs):
758
+ model.train()
759
+ criterion = AmodalCompletionLoss()
760
+
761
+ for epoch in range(num_epochs):
762
+ total_loss = 0
763
+ for i, batch in enumerate(dataloader):
764
+ rgb = batch['rgb'].to(device)
765
+ modal_mask = batch['modal_mask'].to(device)
766
+ amodal_mask = batch['amodal_mask'].to(device)
767
+ occluded_mask = batch['occluded_mask'].to(device)
768
+ gt_amodal_rgb = batch['amodal_rgb'].to(device)
769
+
770
+ input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
771
+
772
+ optimizer.zero_grad()
773
+ pred = model(input_tensor)
774
+
775
+ loss, vis_loss, occ_loss, boundary_loss = criterion(
776
+ pred, gt_amodal_rgb, modal_mask, occluded_mask, amodal_mask
777
+ )
778
+
779
+ loss.backward()
780
+ optimizer.step()
781
+ total_loss += loss.item()
782
+
783
+ if i % 16 == 0:
784
+ print(f"Epoch [{epoch}/{num_epochs}] [{i}/{len(dataloader)}] "
785
+ f"Total: {loss.item():.4f}, Visible: {vis_loss.item():.4f}, "
786
+ f"Occluded: {occ_loss.item():.4f}, Boundary: {boundary_loss.item():.4f}")
787
+
788
+ print(f"Epoch {epoch} Average Loss: {total_loss/len(dataloader):.4f}")
789
+
790
+ # Usage
791
+ if __name__ == "__main__":
792
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
793
+
794
+ # Dataset and DataLoader - REDUCED SIZE FOR FASTER TRAINING
795
+ data_root = "data"
796
+
797
+ # Create train dataset (80% of train folder)
798
+ train_dataset = ModalAmodalDataset(
799
+ root_dir=data_root,
800
+ split='train',
801
+ img_size=(128, 128),
802
+ max_samples=1000, # Only use 1000 samples total before split
803
+ val_split=0.2, # 20% for validation
804
+ use_val_from_train=True # Create val split from train folder
805
+ )
806
+ train_loader = DataLoader(
807
+ train_dataset,
808
+ batch_size=16,
809
+ shuffle=True,
810
+ num_workers=2,
811
+ pin_memory=True,
812
+ drop_last=True
813
+ )
814
+
815
+ # Create validation dataset (20% of train folder)
816
+ val_dataset = ModalAmodalDataset(
817
+ root_dir=data_root,
818
+ split='val',
819
+ img_size=(128, 128),
820
+ max_samples=1000, # Same max_samples to ensure proper split
821
+ val_split=0.2,
822
+ use_val_from_train=True # Create val split from train folder
823
+ )
824
+ val_loader = DataLoader(
825
+ val_dataset,
826
+ batch_size=4,
827
+ shuffle=True,
828
+ num_workers=2,
829
+ pin_memory=True
830
+ )
831
+
832
+ print(f"Training on {len(train_dataset)} samples, {len(train_loader)} batches per epoch")
833
+ print(f"Validation on {len(val_dataset)} samples, {len(val_loader)} batches")
834
+
835
+
836
+
837
+
838
+ model = ImprovedUNet().to(device)
839
+ model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))
840
+
841
+
842
+
843
+
844
+
845
+
846
+ # Model and optimizer
847
+ model = model.to(device)
848
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
849
+
850
+ # Training
851
+ #train_improved(model, train_loader, optimizer, device, num_epochs=10)
852
+
853
+ # Evaluation and Visualization
854
+ print("\n" + "="*50)
855
+ print("EVALUATION RESULTS")
856
+ print("="*50)
857
+
858
+ # Compute metrics
859
+ metrics = evaluate_metrics(model, val_loader, device)
860
+ print(f"Overall MSE: {metrics['total_mse']:.6f}")
861
+ print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
862
+ print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
863
+ print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
864
+
865
+ # Visualize results
866
+ print("\nGenerating visualizations...")
867
+ visualize_results(model, val_loader, device, num_samples=8)
868
+
869
+ # Compute metrics
870
+ image_metrics = calculate_metrics(model, val_loader, device)
871
+ print(f"PSNR: {image_metrics['psnr']:.4f}")
872
+ print(f"SSIM: {image_metrics['ssim']:.4f}")
873
+ print(f"LPIPS: {image_metrics['lpips']:.4f}")
874
+ print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")
875
+
914
+ # Optional: Save model
915
+ torch.save(model.state_dict(), 'amodal_completion_model.pth')
916
+
917
+ # Evaluation and Visualization
918
+
919
+ test_dataset = ModalAmodalDataset(
920
+ root_dir=data_root,
921
+ split='test',
922
+ img_size=(128, 128),
923
+ max_samples=2000 # Only use 2000 test samples
924
+ )
925
+ test_loader = DataLoader(
926
+ test_dataset,
927
+ batch_size=8,
928
+ shuffle=True,
929
+ num_workers=2,
930
+ pin_memory=True,
931
+ drop_last=True
932
+ )
933
+
934
+ print("EVALUATION RESULTS")
935
+ print("="*50)
936
+
937
+ # Compute metrics
938
+ metrics = evaluate_metrics(model, test_loader, device)
939
+ print(f"Overall MSE: {metrics['total_mse']:.6f}")
940
+ print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
941
+ print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
942
+ print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
943
+
944
+ # Visualize results
945
+ print("\nGenerating visualizations...")
946
+ visualize_results(model, test_loader, device, num_samples=16)
947
+
948
+ from google.colab import runtime
949
+ runtime.unassign()  # releases the Colab runtime; the code below must be run in a new session
950
+
951
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
952
+ model = ImprovedUNet()
953
+ model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=torch.device('cpu')))
954
+ model.to(device)
955
+ model.eval()
956
+
957
+ # Evaluation and Visualization
958
+ print("\n" + "="*50)
959
+ print("EVALUATION RESULTS")
960
+ print("="*50)
961
+
962
+ # Compute metrics
963
+ metrics = evaluate_metrics(model, val_loader, device)
964
+ print(f"Overall MSE: {metrics['total_mse']:.6f}")
965
+ print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
966
+ print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
967
+ print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")
968
+
969
+ # Visualize results
970
+ print("\nGenerating visualizations...")
971
+ visualize_results(model, val_loader, device, num_samples=8)
972
+
973
+ # Compute metrics
974
+ image_metrics = calculate_metrics(model, val_loader, device)
975
+ print(f"PSNR: {image_metrics['psnr']:.4f}")
976
+ print(f"SSIM: {image_metrics['ssim']:.4f}")
977
+ print(f"LPIPS: {image_metrics['lpips']:.4f}")
978
+ print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")
979
+
980
+ model = ImprovedUNet()  # note: this creates a fresh model with untrained weights; no checkpoint is loaded here
981
+ model.eval()
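
For completeness, a minimal single-image inference sketch following the same conventions as the script above; the file paths and the 128x128 size are assumptions for illustration, not part of the commit.

# Hypothetical usage; paths are placeholders.
import torch
from PIL import Image
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedUNet().to(device)
model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))
model.eval()

tf = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
rgb = tf(Image.open('scene_rgba.png').convert('RGB'))        # placeholder path
modal = tf(Image.open('modal_mask.png').convert('L'))        # placeholder path
amodal = tf(Image.open('amodal_mask.png').convert('L'))      # placeholder path

x = torch.cat([rgb, modal, amodal], dim=0).unsqueeze(0).to(device)   # shape [1, 5, 128, 128]
with torch.no_grad():
    pred = model(x) * amodal.unsqueeze(0).to(device)   # background masked out, as in visualize_results
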