| |
| """final1.2.ipynb |
| |
| Automatically generated by Colab. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1v6-6x7lqt6gr9VIauNVHIwjvIkewk8eT |
| """ |
|
|
|
|
|
|
| """## FINAL 1.2""" |
|
|
|
|
|
|
| pip install torchmetrics lpips |
|
|
| |
| import torch |
| from torch import nn |
| from torchvision.transforms import ToPILImage, ToTensor |
| from torchvision.utils import make_grid |
| from torchvision.io import write_video |
|
|
| |
| from pathlib import Path |
| from PIL import Image |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import random |
| import json |
| from IPython.display import Video |
|
|
| |
| tensor_to_image = ToPILImage() |
| image_to_tensor = ToTensor() |
|
|
def get_img_dict(img_dir):
    """Group image files in *img_dir* by the prefix before the first underscore.

    Only ``.png`` and ``.tiff`` files are considered.  Returns a dict mapping
    prefix (e.g. ``'rgba'``, ``'segmentation'``) to a sorted list of Paths.
    """
    frames = sorted(
        entry for entry in img_dir.iterdir()
        if entry.name.endswith('.png') or entry.name.endswith('.tiff')
    )

    grouped = {}
    for frame in frames:
        key = frame.name.split('_')[0]
        grouped.setdefault(key, []).append(frame)
    return grouped
|
|
|
|
def get_sample_dict(sample_dir):
    """Index one sample directory.

    Returns ``{camera_name: {'scene': img_dict, 'obj_<id>': img_dict, ...}}``
    where every img_dict comes from :func:`get_img_dict`.
    """
    sample_dict = {}

    camera_dirs = sorted(d for d in sample_dir.iterdir() if 'camera' in d.name)
    for cam_dir in camera_dirs:
        # Scene-level images live directly under the camera directory.
        cam_entry = {'scene': get_img_dict(cam_dir)}

        # Per-object subdirectories get their own image dicts.
        for obj_dir in sorted(d for d in cam_dir.iterdir() if 'obj_' in d.name):
            cam_entry[obj_dir.name] = get_img_dict(obj_dir)

        sample_dict[cam_dir.name] = cam_entry

    return sample_dict
|
|
# --- Colab cell: fetch dataset metadata / auxiliary files from the HF hub ---
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/test_obj_descriptors.json

!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/train_obj_descriptors.json
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/ex_vis.mp4
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/README.md
!wget "https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/Notice%201%20-%20Unlimited_datasets.pdf"
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/.gitattributes

# List the repo's files and keep only the per-scene archives (skip .json metadata).
from huggingface_hub import HfApi, hf_hub_download
import random, os
api = HfApi()
repo_id = "Amar-S/MOVi-MC-AC"

files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

train_files = [f for f in files if f.startswith("train/") and not f.endswith(".json")]
test_files = [f for f in files if f.startswith("test/") and not f.endswith(".json")]
print(f"Found {len(train_files)} train files and {len(test_files)} test files.")

import os
import random
import shutil
from huggingface_hub import hf_hub_download
os.makedirs("/content/data/train", exist_ok=True)
os.makedirs("/content/data/test", exist_ok=True)

# Download a 0.5% random subset of the archives to keep Colab disk usage low.
# NOTE(review): random.sample is unseeded here, so the subset differs per run.
subset_train = random.sample(train_files, int(len(train_files) * 0.005))
subset_test = random.sample(test_files, int(len(test_files) * 0.005))

# Copy each archive out of the hub cache into the working data directories.
for file in subset_train:
    out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    dest_path = f"/content/data/train/{os.path.basename(file)}"
    shutil.copyfile(out_path, dest_path)

for file in subset_test:
    out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
    dest_path = f"/content/data/test/{os.path.basename(file)}"
    shutil.copyfile(out_path, dest_path)
|
|
import os

# --- Colab cell: extract every downloaded .tar.gz archive in place ---
train_dir = "data/train"
for file in os.listdir(train_dir):
    if file.endswith(".tar.gz"):
        filepath = os.path.join(train_dir, file)
        !tar -xzf {filepath} -C {train_dir}

test_dir = "data/test"
for file in os.listdir(test_dir):
    if file.endswith(".tar.gz"):
        filepath = os.path.join(test_dir, file)
        !tar -xzf {filepath} -C {test_dir}

# --- Delete the archives after extraction to reclaim disk space ---
import os
from pathlib import Path
root = Path('/content/data')
deleted = 0
for archive in root.rglob('*.tar.gz'):
    try:
        archive.unlink()
        print(f"Deleted {archive}")
        deleted += 1
    except Exception as e:
        # Best-effort cleanup: report the failure and keep going.
        print(f"Error deleting {archive}: {e}")
print(f"Total deleted: {deleted}")
|
|
| pip install torchmetrics lpips |
|
|
| import matplotlib.pyplot as plt |
| from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| import lpips |
| import matplotlib.pyplot as plt |
| import torch |
|
|
def visualize_results(model, dataloader, device, num_samples=8):
    """Plot qualitative results: scene, masks, GT/pred amodal RGB, error map.

    Shows up to *num_samples* figures, one row of 6 panels per sample.
    Background pixels are zeroed via the amodal mask before display.
    """
    model.eval()
    shown = 0

    with torch.no_grad():
        for batch in dataloader:
            if shown >= num_samples:
                break

            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))

            # Remove background so only the object region is compared.
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(rgb.shape[0]):
                if shown >= num_samples:
                    break

                fig, axes = plt.subplots(1, 6, figsize=(24, 4))

                panels = [
                    (rgb[i].cpu().permute(1, 2, 0), 'Scene RGB', None),
                    (amodal_mask[i, 0].cpu(), 'Amodal Mask', 'gray'),
                    (modal_mask[i, 0].cpu(), 'Modal Mask', 'gray'),
                    (gt_masked[i].cpu().permute(1, 2, 0), 'GT Amodal RGB', None),
                    (pred_masked[i].cpu().permute(1, 2, 0), 'Predicted Amodal RGB', None),
                ]
                for ax, (img, title, cmap) in zip(axes, panels):
                    ax.imshow(img, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')

                # Per-pixel mean absolute error across channels, as a heat map.
                diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
                im = axes[5].imshow(diff.cpu(), cmap='hot')
                axes[5].set_title('Prediction Error')
                axes[5].axis('off')
                plt.colorbar(im, ax=axes[5])

                plt.tight_layout()
                plt.show()

                shown += 1
|
|
|
|
|
| |
def evaluate_metrics(model, dataloader, device):
    """Compute MSE restricted to object regions.

    Returns a dict with:
      total_mse    -- MSE over the whole amodal object region
      occluded_mse -- MSE over the occluded (hidden) part of the object
      visible_mse  -- MSE over the visible (modal) part of the object

    The squared-error sums run over all prediction channels while the
    divisors count single-channel mask pixels, so each value is a
    per-mask-pixel error summed over channels — consistent across regions.
    """
    # FIX: this cell referenced a module-level `F` that is only imported by a
    # later cell (`import torch.nn.functional as F`); import it locally so
    # the function works regardless of notebook cell execution order.
    import torch.nn.functional as F

    model.eval()
    total_mse = 0
    occluded_mse = 0
    visible_mse = 0
    total_pixels = 0
    occluded_pixels = 0
    visible_pixels = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            # Zero out everything outside the amodal object region.
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Whole-object MSE.
            object_pixels = amodal_mask.sum()
            if object_pixels > 0:
                mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
                total_mse += mse.item()
                total_pixels += object_pixels.item()

            # Occluded part of the object (hidden behind other objects).
            occluded_region = occluded_mask * amodal_mask
            occ_pixels = occluded_region.sum()
            if occ_pixels > 0:
                occ_mse = F.mse_loss(pred_masked * occluded_region,
                                     gt_masked * occluded_region, reduction='sum')
                occluded_mse += occ_mse.item()
                occluded_pixels += occ_pixels.item()

            # Visible (modal) part of the object.
            visible_region = modal_mask * amodal_mask
            vis_pixels = visible_region.sum()
            if vis_pixels > 0:
                vis_mse = F.mse_loss(pred_masked * visible_region,
                                     gt_masked * visible_region, reduction='sum')
                visible_mse += vis_mse.item()
                visible_pixels += vis_pixels.item()

    return {
        'total_mse': total_mse / total_pixels if total_pixels > 0 else 0,
        'occluded_mse': occluded_mse / occluded_pixels if occluded_pixels > 0 else 0,
        'visible_mse': visible_mse / visible_pixels if visible_pixels > 0 else 0,
    }
|
|
|
|
|
|
def calculate_metrics(model, dataloader, device):
    """Computes PSNR, SSIM, LPIPS, and IoU between predictions and GT amodal RGBs.

    Averages are taken over individual samples (not batches); samples smaller
    than 64x64 are skipped entirely and excluded from every average.
    """

    model.eval()
    psnr = PeakSignalNoiseRatio().to(device)
    ssim = StructuralSimilarityIndexMeasure().to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            # Compare only inside the amodal object region.
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Per-sample metrics: unsqueeze restores a batch dim of 1.
            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Skip images too small for the metrics' windows.
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                total_lpips += lpips_loss(pred_i, gt_i).item()

                # NOTE(review): this "IoU" thresholds the predicted RGB image
                # at 0.5 and broadcasts it against the single-channel amodal
                # mask — it is not a mask-vs-mask IoU; confirm this is the
                # intended definition before reporting it.
                intersection = (amodal_mask[i] * (pred[i] > 0.5)).sum()
                union = ((amodal_mask[i] + (pred[i] > 0.5)) > 0).sum()
                if union > 0:
                    iou = intersection.float() / union.float()
                    total_iou += iou.item()

                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
|
|
| pip install torchmetrics lpips |
|
|
| import matplotlib.pyplot as plt |
| from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| import lpips |
| import matplotlib.pyplot as plt |
| import torch |
|
|
def visualize_results(model, dataloader, device, num_samples=8):
    """Visualize results with properly masked output (no background).

    NOTE(review): duplicate of the earlier visualize_results definition;
    running this cell simply shadows the previous one.
    """
    model.eval()
    samples_shown = 0

    with torch.no_grad():
        for batch in dataloader:
            if samples_shown >= num_samples:
                break

            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            # Zero out everything outside the amodal object region.
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(rgb.shape[0]):
                if samples_shown >= num_samples:
                    break

                fig, axes = plt.subplots(1, 6, figsize=(24, 4))

                axes[0].imshow(rgb[i].cpu().permute(1, 2, 0))
                axes[0].set_title('Scene RGB')
                axes[0].axis('off')

                axes[1].imshow(amodal_mask[i, 0].cpu(), cmap='gray')
                axes[1].set_title('Amodal Mask')
                axes[1].axis('off')

                axes[2].imshow(modal_mask[i, 0].cpu(), cmap='gray')
                axes[2].set_title('Modal Mask')
                axes[2].axis('off')

                axes[3].imshow(gt_masked[i].cpu().permute(1, 2, 0))
                axes[3].set_title('GT Amodal RGB')
                axes[3].axis('off')

                axes[4].imshow(pred_masked[i].cpu().permute(1, 2, 0))
                axes[4].set_title('Predicted Amodal RGB')
                axes[4].axis('off')

                # Per-pixel mean absolute error across channels, as a heat map.
                diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
                im = axes[5].imshow(diff.cpu(), cmap='hot')
                axes[5].set_title('Prediction Error')
                axes[5].axis('off')
                plt.colorbar(im, ax=axes[5])

                plt.tight_layout()
                plt.show()

                samples_shown += 1
|
|
|
|
def evaluate_metrics(model, dataloader, device):
    """Compute evaluation metrics only within object regions.

    NOTE(review): duplicate of the earlier evaluate_metrics definition; this
    redefinition shadows it when the cell is run.
    NOTE(review): relies on a module-level ``F`` (torch.nn.functional) that is
    only imported by a later cell — confirm cell execution order before use.
    """
    model.eval()
    total_mse = 0
    occluded_mse = 0
    visible_mse = 0
    total_pixels = 0
    occluded_pixels = 0
    visible_pixels = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            # Zero out everything outside the amodal object region.
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Whole-object MSE: channel-summed error over mask pixel count.
            object_pixels = amodal_mask.sum()
            if object_pixels > 0:
                mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
                total_mse += mse.item()
                total_pixels += object_pixels.item()

            # Occluded (hidden) part of the object.
            occluded_region = occluded_mask * amodal_mask
            occ_pixels = occluded_region.sum()
            if occ_pixels > 0:
                occ_mse = F.mse_loss(pred_masked * occluded_region,
                                     gt_masked * occluded_region, reduction='sum')
                occluded_mse += occ_mse.item()
                occluded_pixels += occ_pixels.item()

            # Visible (modal) part of the object.
            visible_region = modal_mask * amodal_mask
            vis_pixels = visible_region.sum()
            if vis_pixels > 0:
                vis_mse = F.mse_loss(pred_masked * visible_region,
                                     gt_masked * visible_region, reduction='sum')
                visible_mse += vis_mse.item()
                visible_pixels += vis_pixels.item()

    return {
        'total_mse': total_mse / total_pixels if total_pixels > 0 else 0,
        'occluded_mse': occluded_mse / occluded_pixels if occluded_pixels > 0 else 0,
        'visible_mse': visible_mse / visible_pixels if visible_pixels > 0 else 0,
    }
|
|
|
|
|
|
def calculate_metrics(model, dataloader, device):
    """Computes PSNR, SSIM, LPIPS, and IoU between predictions and GT amodal RGBs.

    NOTE(review): duplicate of the earlier calculate_metrics definition; this
    redefinition shadows it when the cell is run.
    """

    model.eval()
    psnr = PeakSignalNoiseRatio().to(device)
    ssim = StructuralSimilarityIndexMeasure().to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            # Compare only inside the amodal object region.
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Per-sample metrics: unsqueeze restores a batch dim of 1.
            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Skip images too small for the metrics' windows; skipped
                # samples are excluded from every average.
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                total_lpips += lpips_loss(pred_i, gt_i).item()

                # NOTE(review): thresholds the predicted RGB at 0.5 and
                # broadcasts against the 1-channel mask — not a mask-vs-mask
                # IoU; confirm intent.
                intersection = (amodal_mask[i] * (pred[i] > 0.5)).sum()
                union = ((amodal_mask[i] + (pred[i] > 0.5)) > 0).sum()
                if union > 0:
                    iou = intersection.float() / union.float()
                    total_iou += iou.item()

                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
|
|
|
|
|
|
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from pathlib import Path |
| from PIL import Image, ImageChops |
| import numpy as np |
|
|
class ModalAmodalDataset(Dataset):
    """Per-frame, per-object dataset pairing scene RGB + modal mask with the
    object's amodal mask and amodal RGB ground truth.

    Each item corresponds to one (scene, camera, object, frame) tuple found
    under ``root_dir/<split>``.
    """

    def __init__(self, root_dir, split, img_size=(128, 128), max_samples=None, val_split=0.2, use_val_from_train=False):
        # root_dir: dataset root containing 'train'/'test' subdirectories.
        # split: 'train', 'val' or 'test'; 'val' may be carved out of 'train'.
        # max_samples: optional cap on the index size (seeded random subset).
        # val_split: fraction of the train pool used for validation.
        self.root_dir = Path(root_dir)
        self.img_size = img_size
        self.max_samples = max_samples
        self.val_split = val_split
        self.use_val_from_train = use_val_from_train
        self.split = split

        if split == 'val' and use_val_from_train:
            # Validation reuses the train directory; the carve-out happens
            # in _build_sample_index.
            self.root_dir = self.root_dir / 'train'
        else:
            self.root_dir = self.root_dir / split

        self.samples = self._build_sample_index()

        # Both transforms are identical (resize + [0,1] tensor); kept
        # separate so image/mask pipelines can diverge later.
        self.rgb_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self):
        """Walk scene/camera/object directories and collect aligned file paths.

        A (camera, object) pair is skipped entirely when the four per-frame
        file lists are not all the same length (misaligned data).
        """
        samples = []
        for scene_dir in self.root_dir.iterdir():
            if not scene_dir.is_dir():
                continue
            for camera_dir in scene_dir.iterdir():
                if not camera_dir.name.startswith('camera_'):
                    continue

                # Scene-level RGB frames and segmentation maps.
                rgba_paths = sorted(camera_dir.glob('rgba_*.png'))
                seg_paths = sorted(camera_dir.glob('segmentation_*.png'))

                for obj_dir in camera_dir.iterdir():
                    if not obj_dir.name.startswith('obj_'):
                        continue

                    # Per-object amodal mask / amodal RGB frames.
                    amodal_paths = sorted(obj_dir.glob('segmentation_*.png'))
                    amodal_rgb_paths = sorted(obj_dir.glob('rgba_*.png'))

                    if not (len(rgba_paths) == len(seg_paths) == len(amodal_paths) == len(amodal_rgb_paths)):
                        continue

                    for rgba_path, seg_path, amodal_path, amodal_rgb_path in zip(
                        rgba_paths, seg_paths, amodal_paths, amodal_rgb_paths
                    ):
                        samples.append({
                            'rgb_path': rgba_path,
                            'seg_path': seg_path,
                            'amodal_path': amodal_path,
                            'amodal_rgb_path': amodal_rgb_path,
                            'object_id': int(obj_dir.name.split('_')[1]),
                            'scene': scene_dir.name,
                            'camera': camera_dir.name
                        })

        # Optionally cap the index size with a fixed-seed random subset so
        # runs are reproducible.
        if self.max_samples is not None and len(samples) > self.max_samples:
            import random
            random.seed(42)
            samples = random.sample(samples, self.max_samples)
            print(f"Dataset limited to {len(samples)} samples")

        # Deterministically shuffle and carve train/val out of the same pool;
        # the shared seed keeps the two splits disjoint across instances.
        if self.use_val_from_train:
            import random
            random.seed(42)
            random.shuffle(samples)

            val_size = int(len(samples) * self.val_split)
            if self.split == 'train':
                samples = samples[val_size:]
                print(f"Train split: {len(samples)} samples")
            elif self.split == 'val':
                samples = samples[:val_size]
                print(f"Validation split: {len(samples)} samples")

        return samples

    def __len__(self):
        # Number of (scene, camera, object, frame) examples.
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors.

        Keys: 'rgb', 'amodal_rgb' (3-channel) and 'modal_mask',
        'amodal_mask', 'occluded_mask' (single-channel), all resized to
        ``img_size`` with values in [0, 1].
        """
        sample = self.samples[idx]

        # Load scene RGB, scene segmentation map, and this object's
        # amodal mask / amodal RGB ground truth.
        rgb = Image.open(sample['rgb_path']).convert('RGB')
        seg_map = np.array(Image.open(sample['seg_path']))
        amodal_mask_img = Image.open(sample['amodal_path']).convert('L')
        amodal_rgb = Image.open(sample['amodal_rgb_path']).convert('RGB')

        # Modal (visible) mask = scene-segmentation pixels matching this
        # object's id.
        modal_mask_np = (seg_map == sample['object_id']).astype(np.uint8) * 255
        modal_mask_img = Image.fromarray(modal_mask_np, mode='L')

        # Resize + convert everything to tensors.
        rgb = self.rgb_transform(rgb)
        modal_mask = self.mask_transform(modal_mask_img)
        amodal_mask = self.mask_transform(amodal_mask_img)
        amodal_rgb = self.rgb_transform(amodal_rgb)

        # Occluded region = amodal minus modal, clamped to [0, 1] to guard
        # against resize-interpolation artifacts.
        occluded_mask = amodal_mask - modal_mask
        occluded_mask = torch.clamp(occluded_mask, 0, 1)

        return {
            'rgb': rgb,
            'modal_mask': modal_mask,
            'amodal_mask': amodal_mask,
            'occluded_mask': occluded_mask,
            'amodal_rgb': amodal_rgb,
        }
|
|
|
|
class ImprovedUNet(nn.Module):
    """4-level U-Net producing a sigmoid image.

    Input: (B, in_channels, H, W) with H and W divisible by 16 (four 2x
    poolings).  Output: (B, out_channels, H, W) with values in [0, 1].
    """

    def __init__(self, in_channels=5, out_channels=3):
        super().__init__()

        def double_conv(c_in, c_out, dropout=0.1):
            # Two 3x3 convs, each followed by BatchNorm + ReLU, with spatial
            # dropout between them for regularization.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout2d(dropout),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        # Encoder (attribute names preserved for checkpoint compatibility).
        self.down1 = double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = double_conv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck with stronger dropout.
        self.middle = double_conv(512, 1024, dropout=0.2)

        # Decoder: transpose-conv upsampling, then a double conv on the
        # concatenation with the mirrored encoder skip tensor.
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_block1 = double_conv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_block2 = double_conv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_block3 = double_conv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_block4 = double_conv(128, 64)

        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        # Encoder: collect each stage's activation as a skip before pooling.
        skips = []
        out = x
        for down, pool in ((self.down1, self.pool1), (self.down2, self.pool2),
                           (self.down3, self.pool3), (self.down4, self.pool4)):
            out = down(out)
            skips.append(out)
            out = pool(out)

        out = self.middle(out)

        # Decoder: upsample, concat with the mirrored skip, refine.
        for up, refine, skip in ((self.up1, self.up_block1, skips[3]),
                                 (self.up2, self.up_block2, skips[2]),
                                 (self.up3, self.up_block3, skips[1]),
                                 (self.up4, self.up_block4, skips[0])):
            out = refine(torch.cat([up(out), skip], dim=1))

        return torch.sigmoid(self.final(out))
|
|
class AmodalCompletionLoss(nn.Module):
    """Loss that only considers object regions (ignores background).

    total = visible_weight * MSE(visible region)
          + occluded_weight * MSE(occluded region)
          + 2.0 * MSE(mask-boundary band)

    forward() returns (total_loss, visible_loss, occluded_loss, boundary_loss).
    """

    def __init__(self, occluded_weight=5.0, visible_weight=1.0):
        super().__init__()
        self.occluded_weight = occluded_weight
        self.visible_weight = visible_weight
        # Kept for interface compatibility, but no longer evaluated in
        # forward(): previously an LPIPS perceptual term was computed there
        # and then discarded (never added to total_loss), costing a full
        # AlexNet forward pass per training step for nothing.
        self.lpips_model = lpips.LPIPS(net='alex')

    def forward(self, pred, target, modal_mask, occluded_mask, amodal_mask):
        # Zero out everything outside the amodal object mask.
        pred_masked = pred * amodal_mask
        target_masked = target * amodal_mask

        # MSE over the visible (unoccluded) part of the object.
        visible_region = modal_mask * amodal_mask
        if visible_region.sum() > 0:
            visible_loss = F.mse_loss(pred_masked * visible_region, target_masked * visible_region)
        else:
            visible_loss = torch.tensor(0.0).to(pred.device)

        # MSE over the occluded part — weighted highest, since this is the
        # region the network actually has to hallucinate.
        occluded_region = occluded_mask * amodal_mask
        if occluded_region.sum() > 0:
            occluded_loss = F.mse_loss(pred_masked * occluded_region, target_masked * occluded_region)
        else:
            occluded_loss = torch.tensor(0.0).to(pred.device)

        # FIX: the dead LPIPS computation that used to sit here was removed;
        # re-add it to total_loss with an explicit weight if a perceptual
        # term is actually desired.

        # Emphasize pixels on the amodal mask boundary: a 3x3 box sum that is
        # neither 0 (fully outside) nor 9 (fully inside) marks edge pixels.
        boundary_mask = F.conv2d(amodal_mask, torch.ones(1, 1, 3, 3).to(amodal_mask.device), padding=1)
        boundary_mask = ((boundary_mask > 0) & (boundary_mask < 9)).float()
        boundary_loss = F.mse_loss(pred_masked * boundary_mask, target_masked * boundary_mask)

        total_loss = (self.visible_weight * visible_loss +
                      self.occluded_weight * occluded_loss +
                      2.0 * boundary_loss)

        return total_loss, visible_loss, occluded_loss, boundary_loss
|
|
|
|
def train_improved(model, dataloader, optimizer, device, num_epochs):
    """Train *model* with AmodalCompletionLoss, logging every 16 batches."""
    model.train()
    criterion = AmodalCompletionLoss()

    for epoch in range(num_epochs):
        running = 0.0
        for step, batch in enumerate(dataloader):
            # Move all needed tensors to the target device.
            rgb, modal_mask, amodal_mask, occluded_mask, gt_amodal_rgb = (
                batch[k].to(device)
                for k in ('rgb', 'modal_mask', 'amodal_mask', 'occluded_mask', 'amodal_rgb')
            )

            # Network input: scene RGB + both masks, channel-concatenated.
            net_input = torch.cat([rgb, modal_mask, amodal_mask], dim=1)

            optimizer.zero_grad()
            pred = model(net_input)

            loss, vis_loss, occ_loss, boundary_loss = criterion(
                pred, gt_amodal_rgb, modal_mask, occluded_mask, amodal_mask
            )

            loss.backward()
            optimizer.step()
            running += loss.item()

            if step % 16 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] [{step}/{len(dataloader)}] "
                      f"Total: {loss.item():.4f}, Visible: {vis_loss.item():.4f}, "
                      f"Occluded: {occ_loss.item():.4f}, Boundary: {boundary_loss.item():.4f}")

        print(f"Epoch {epoch} Average Loss: {running/len(dataloader):.4f}")
|
|
| |
# Entry point: build datasets/loaders, load a checkpoint, evaluate on val,
# save weights, then evaluate on the test split.
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Root produced by the download/extraction cells above.
    data_root = "data"

    # Train loader: the 80% share of the (subsampled) train pool.
    train_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='train',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    # Validation loader: the held-out 20% of the same pool (shared seed
    # inside the dataset keeps the splits disjoint).
    val_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='val',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    print(f"Training on {len(train_dataset)} samples, {len(train_loader)} batches per epoch")
    print(f"Validation on {len(val_dataset)} samples, {len(val_loader)} batches")

    model = ImprovedUNet().to(device)
    # NOTE(review): assumes 'amodal_completion_model.pth' already exists from
    # a previous run; this raises FileNotFoundError on a fresh runtime —
    # confirm the checkpoint is present before executing this cell.
    model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)

    # Region-wise MSE on the validation set.
    metrics = evaluate_metrics(model, val_loader, device)
    print(f"Overall MSE: {metrics['total_mse']:.6f}")
    print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
    print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
    print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")

    print("\nGenerating visualizations...")
    visualize_results(model, val_loader, device, num_samples=8)

    # Perceptual / structural metrics on the validation set.
    image_metrics = calculate_metrics(model, val_loader, device)
    print(f"PSNR: {image_metrics['psnr']:.4f}")
    print(f"SSIM: {image_metrics['ssim']:.4f}")
    print(f"LPIPS: {image_metrics['lpips']:.4f}")
    print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")

    # NOTE(review): the loaders below recreate the ones built above with
    # identical arguments — apparently redundant notebook repetition.
    data_root = "data"

    train_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='train',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    val_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='val',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    # Persist the current weights.
    torch.save(model.state_dict(), 'amodal_completion_model.pth')

    # Test split: no val carve-out, larger sample cap.
    test_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='test',
        img_size=(128, 128),
        max_samples=2000
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    print("EVALUATION RESULTS")
    print("="*50)

    # Region-wise MSE on the test set.
    metrics = evaluate_metrics(model, test_loader, device)
    print(f"Overall MSE: {metrics['total_mse']:.6f}")
    print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
    print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
    print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")

    print("\nGenerating visualizations...")
    visualize_results(model, test_loader, device, num_samples=16)

# Colab-only: release the hosted runtime when finished.
from google.colab import runtime
runtime.unassign()
|
|
# --- Colab cell: reload the saved model from disk and re-run val metrics ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedUNet()
# BUG FIX: previously `torch.load(...)` was called on its own and its return
# value (the state dict) was discarded, so evaluation silently ran on a
# randomly initialized model.  Load the checkpoint into the model instead.
model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)

# Region-wise MSE on the validation loader.
metrics = evaluate_metrics(model, val_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
print(f"Occluded/Visible MSE Ratio: {metrics['occluded_mse']/metrics['visible_mse']:.2f}")

print("\nGenerating visualizations...")
visualize_results(model, val_loader, device, num_samples=8)

# Perceptual / structural image metrics.
image_metrics = calculate_metrics(model, val_loader, device)
print(f"PSNR: {image_metrics['psnr']:.4f}")
print(f"SSIM: {image_metrics['ssim']:.4f}")
print(f"LPIPS: {image_metrics['lpips']:.4f}")
print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")

# NOTE(review): this trailing re-instantiation replaces the loaded model with
# an untrained one — it looks like leftover scratch code; confirm nothing
# downstream relies on `model` after this point.
model = ImprovedUNet()
model.eval()