| |
| """2.2.2.2.2.ipynb |
| |
| Automatically generated by Colab. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1igY4MKIJJTPHgEkdLFI_T5H6sLUoTaLr |
| """ |
|
|
| |
|
|
| """## CODE""" |
|
|
# `pip install ...` is IPython/Colab shell syntax and a SyntaxError in a plain
# .py file. Installing through the running interpreter works in both settings;
# packages that are already importable are skipped (pip would no-op on them).
import importlib.util
import subprocess
import sys

_missing_packages = [pkg for pkg in ("torchmetrics", "lpips") if importlib.util.find_spec(pkg) is None]
if _missing_packages:
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_missing_packages])
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from pathlib import Path |
| from PIL import Image |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| from torchmetrics.image.fid import FrechetInceptionDistance |
| import lpips |
| import os |
| import random |
| import shutil |
| from huggingface_hub import HfApi, hf_hub_download |
| import tarfile |
| import json |
| import cv2 |
| from tqdm import tqdm |
|
|
def download_sequential_data(repo_id="Amar-S/MOVi-MC-AC", sample_ratio=0.01, base_dir="/content/data",
                             splits=("test",), seed=None):
    """
    Download a random subset of scene archives while preserving video sequences.

    Every ``.tar.gz`` file under ``<split>/`` in the Hugging Face dataset repo is
    fetched whole, so the frame sequence inside it stays intact. For each
    requested split, ``sample_ratio`` of its archives (at least one, never more
    than exist) are copied to ``<base_dir>/<split>/`` and then extracted in place
    by ``extract_archives`` (which deletes each archive after extracting it).

    Args:
        repo_id: Hugging Face dataset repository to read from.
        sample_ratio: Fraction of each split's archives to download.
        base_dir: Local root directory; ``train/`` and ``test/`` are always created.
        splits: Sequence of repository splits to download. The default
            ``("test",)`` matches the previous behaviour (only the test split).
        seed: Optional seed for choosing the subset; ``None`` uses the global
            ``random`` state.
    """
    api = HfApi()
    rng = random if seed is None else random.Random(seed)

    # train/ and test/ are always created; any extra requested split is added too.
    local_splits = list(dict.fromkeys(("train", "test", *splits)))
    for split in local_splits:
        os.makedirs(f"{base_dir}/{split}", exist_ok=True)

    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

    for split in splits:
        split_files = [f for f in files if f.startswith(f"{split}/") and f.endswith(".tar.gz")]
        if not split_files:
            # random.sample(..., 1) would raise "Sample larger than population" here.
            print(f"No .tar.gz archives found under '{split}/' in {repo_id}; skipping.")
            continue

        # At least one archive, but never more than the split contains.
        count = min(len(split_files), max(1, int(len(split_files) * sample_ratio)))
        for file in rng.sample(split_files, count):
            print(f"Downloading {file}...")
            out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
            dest_path = f"{base_dir}/{split}/{os.path.basename(file)}"
            shutil.copyfile(out_path, dest_path)

    for split in local_splits:
        extract_archives(f"{base_dir}/{split}")

    print("Download and extraction complete!")
|
|
def extract_archives(directory):
    """Extract every ``.tar.gz`` file found directly in *directory*, then delete it.

    Archives are unpacked into *directory* itself, in sorted name order. An
    archive is removed only after its extraction finished, so a failure leaves
    it in place for a retry.

    The archives come from a remote dataset repository, so extraction uses
    tarfile's ``"data"`` filter when the interpreter supports it. That filter
    rejects members with absolute paths, ``..`` components, or links that would
    escape *directory*; it raises a ``tarfile.TarError`` subclass instead.
    """
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".tar.gz"):
            continue
        filepath = os.path.join(directory, name)
        print(f"Extracting {filepath}...")
        with tarfile.open(filepath, 'r:gz') as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=directory, filter="data")
            else:
                # Interpreter predates extraction filters: only safe for trusted archives.
                tar.extractall(path=directory)
        os.remove(filepath)
|
|
# Network download only when run as a script/notebook, not when this file is imported.
if __name__ == "__main__":
    download_sequential_data()
    # download_sequential_data() already extracts (and deletes) the archives it
    # fetched; this second pass only picks up archives placed there by other means.
    extract_archives('/content/data/test')
|
|
def extract_archives(directory):
    """Extract every ``.tar.gz`` file found directly in *directory*, then delete it.

    Archives are unpacked into *directory* itself, in sorted name order. An
    archive is removed only after its extraction finished, so a failure leaves
    it in place for a retry.

    The archives come from a remote dataset repository, so extraction uses
    tarfile's ``"data"`` filter when the interpreter supports it. That filter
    rejects members with absolute paths, ``..`` components, or links that would
    escape *directory*; it raises a ``tarfile.TarError`` subclass instead.
    """
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".tar.gz"):
            continue
        filepath = os.path.join(directory, name)
        print(f"Extracting {filepath}...")
        with tarfile.open(filepath, 'r:gz') as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=directory, filter="data")
            else:
                # Interpreter predates extraction filters: only safe for trusted archives.
                tar.extractall(path=directory)
        os.remove(filepath)
|
|
| |
# Only touch the filesystem when run as a script/notebook, not on import.
if __name__ == "__main__":
    # Extract any .tar.gz archives still present in the test folder.
    extract_archives('/content/data/test')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoAmodalDataset(Dataset):
    """Fixed-length clips of a single object from an extracted MOVi-MC-AC split.

    Directory layout read by this class (``<fid>`` is the frame number, which is
    zero-padded to 5 digits when files are opened)::

        <root_dir>/<split>/scene_*/camera_*/rgba_<fid>.png                   scene frame
        <root_dir>/<split>/scene_*/camera_*/segmentation_<fid>.png           scene label map
        <root_dir>/<split>/scene_*/camera_*/obj_<id>/rgba_<fid>.png          object amodal RGB
        <root_dir>/<split>/scene_*/camera_*/obj_<id>/segmentation_<fid>.png  object amodal mask

    Each item is a dict with ``rgb_sequence`` and ``amodal_rgb_sequence`` of
    shape ``(seq_len, 3, H, W)``, ``modal_masks`` and ``amodal_masks`` of shape
    ``(seq_len, 1, H, W)`` (values in ``[0, 1]``), plus ``scene``, ``camera`` and
    ``object_id``.

    Args:
        root_dir: Dataset root containing the split folders.
        split: Split folder name (e.g. ``'train'`` or ``'test'``).
        seq_len: Frames per clip; clips are non-overlapping windows.
        img_size: Output ``(H, W)`` passed to ``transforms.Resize``.
        max_scenes: Only the first ``max_scenes`` scenes (sorted by name) are used.
        samples_per_scene: Number of objects sampled at random per camera.
        max_samples: Stop indexing once this many clips exist (``None``/0 = no limit).
    """

    def __init__(self, root_dir, split='train', seq_len=8, img_size=(256,256),
                 max_scenes=4, samples_per_scene=3, max_samples=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.seq_len = seq_len
        self.img_size = img_size
        self.max_scenes = max_scenes
        self.samples_per_scene = samples_per_scene

        self.samples = self._build_sample_index(max_samples)

        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        # Masks are resized with nearest-neighbour so they stay binary; the
        # default bilinear filter would blur boundaries into fractional values.
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self, max_samples):
        """Enumerate clips as dicts of scene/camera/object folder names and frame ids.

        Objects are chosen with ``random.sample`` (so the index depends on the
        global random state). Indexing stops as soon as ``max_samples`` clips
        exist when ``max_samples`` is truthy.
        """
        samples = []
        scene_paths = sorted((self.root_dir / self.split).glob('scene_*'))[:self.max_scenes]

        for scene_path in scene_paths:
            for camera_path in sorted(scene_path.glob('camera_*')):
                # Frame ids depend only on the camera folder, so list them once per camera.
                frame_ids = [int(p.stem.split('_')[1]) for p in sorted(camera_path.glob('rgba_*.png'))]

                obj_paths = sorted(camera_path.glob('obj_*'))
                selected_objs = random.sample(obj_paths, min(self.samples_per_scene, len(obj_paths)))

                for obj_path in selected_objs:
                    # Non-overlapping windows of seq_len consecutive listed frames.
                    for i in range(0, len(frame_ids) - self.seq_len + 1, self.seq_len):
                        samples.append({
                            'scene': scene_path.name,
                            'camera': camera_path.name,
                            'obj_folder': obj_path.name,
                            'frame_ids': frame_ids[i:i+self.seq_len],
                            'obj_id': int(obj_path.name.split('_')[1])
                        })
                        # Checked per clip so the cap is never overshot.
                        if max_samples and len(samples) >= max_samples:
                            return samples

        return samples

    def _load_frame(self, base_path, obj_path, fid_str, obj_id):
        """Load one frame and return ``(rgb, modal_mask, amodal_mask, amodal_rgb)`` tensors."""
        with Image.open(base_path / f'rgba_{fid_str}.png') as img:
            rgb = self.transform(img.convert('RGB'))

        with Image.open(base_path / f'segmentation_{fid_str}.png') as img:
            seg_map = np.array(img)
        # Visible (modal) mask: scene pixels whose label equals this object's id.
        modal_mask_np = (seg_map == obj_id).astype(np.uint8) * 255
        # A 2-D uint8 array is read as mode 'L'; the explicit mode= argument is deprecated in Pillow.
        modal_mask = self.mask_transform(Image.fromarray(modal_mask_np))

        with Image.open(obj_path / f'segmentation_{fid_str}.png') as img:
            amodal_mask = self.mask_transform(img.convert('L'))

        with Image.open(obj_path / f'rgba_{fid_str}.png') as img:
            amodal_rgb = self.transform(img.convert('RGB'))

        return rgb, modal_mask, amodal_mask, amodal_rgb

    def _empty_sample(self, sample):
        """All-zero clip used in place of a clip that failed to load."""
        empty_rgb = torch.zeros(3, self.img_size[0], self.img_size[1])
        empty_mask = torch.zeros(1, self.img_size[0], self.img_size[1])
        return {
            'rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'modal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'amodal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'amodal_rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id']
        }

    def __getitem__(self, idx):
        sample = self.samples[idx]
        base_path = self.root_dir / self.split / sample['scene'] / sample['camera']
        obj_path = base_path / sample['obj_folder']

        rgb_frames = []
        modal_mask_frames = []
        amodal_mask_frames = []
        amodal_rgb_frames = []

        for fid in sample['frame_ids']:
            fid_str = f"{fid:05d}"
            try:
                rgb, modal_mask, amodal_mask, amodal_rgb = self._load_frame(
                    base_path, obj_path, fid_str, sample['obj_id'])
            except Exception as e:
                # Best effort: any unreadable frame turns the whole clip into zeros
                # so a single bad file does not abort an epoch.
                print(f"Error loading frame {fid_str} of {obj_path}: {e}")
                return self._empty_sample(sample)

            rgb_frames.append(rgb)
            modal_mask_frames.append(modal_mask)
            amodal_mask_frames.append(amodal_mask)
            amodal_rgb_frames.append(amodal_rgb)

        return {
            'rgb_sequence': torch.stack(rgb_frames),
            'modal_masks': torch.stack(modal_mask_frames),
            'amodal_masks': torch.stack(amodal_mask_frames),
            'amodal_rgb_sequence': torch.stack(amodal_rgb_frames),
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id']
        }

    def __len__(self):
        return len(self.samples)
|
|
| import wandb |
|
|
# Interactive W&B login only when run as a script/notebook, not when this file is imported.
if __name__ == "__main__":
    wandb.login()
|
|
| |
| import numpy as np |
| from skimage.metrics import structural_similarity as ssim |
| from skimage.metrics import peak_signal_noise_ratio as psnr |
| import torch.nn.functional as F |
| from scipy import linalg |
| import matplotlib.pyplot as plt |
| import matplotlib.cm as cm |
| from torchvision.models import inception_v3 |
| from torchvision.transforms import Resize, Normalize |
| import lpips |
|
|
| |
class VideoAmodalMetrics:
    """Compute image-quality and mask metrics for video amodal completion.

    Tensors are either 4-D ``(N, C, H, W)`` or 5-D ``(B, C, T, H, W)`` (frames
    are sliced as ``x[i, :, j]`` in the 5-D case), with values in ``[0, 1]``
    (``data_range=1.0``). Masks use the same layout; where a single channel is
    needed, mask channel 0 is used.
    """

    def __init__(self, device='cuda'):
        self.device = device
        # LPIPS perceptual distance with the AlexNet backbone.
        self.lpips_model = lpips.LPIPS(net='alex').to(device)

        # Pretrained Inception-v3 as the FID feature extractor. torchvision's
        # weights expect ImageNet-normalised input together with
        # transform_input=True; with transform_input=False the network sees
        # wrongly scaled pixels. Replacing the classifier with Identity makes
        # forward() return the 2048-d pooled features FID is defined on,
        # instead of 1000-d class logits.
        self.inception_model = inception_v3(pretrained=True, transform_input=True).to(device)
        self.inception_model.fc = torch.nn.Identity()
        self.inception_model.eval()

        # Resize to Inception's 299x299 input, then ImageNet normalisation.
        self.inception_transform = torch.nn.Sequential(
            Resize((299, 299)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    @staticmethod
    def _prep(x):
        """Detach from autograd and upcast (autocast outputs may be float16)."""
        return x.detach().float()

    @staticmethod
    def _frame_triples(pred, target, mask):
        """Yield ``(pred_frame, target_frame, mask_frame)`` ``(C, H, W)`` slices.

        5-D inputs yield one triple per (sample, time step), sample-major; 4-D
        inputs one per sample. ``mask_frame`` is ``None`` when *mask* is ``None``.
        """
        if pred.dim() == 5:
            index = [(i, slice(None), j) for i in range(pred.shape[0]) for j in range(pred.shape[2])]
        else:
            index = [(i,) for i in range(pred.shape[0])]
        for idx in index:
            yield pred[idx], target[idx], (None if mask is None else mask[idx])

    def calculate_psnr(self, pred, target, mask=None):
        """PSNR in dB (peak value 1.0).

        Without a mask: a single PSNR over the whole tensor.
        With a mask: per frame, the MSE is averaged over masked pixels only
        (mask-weighted), converted to PSNR, then averaged over frames. Frames
        with an empty mask are skipped; returns 0.0 if every frame is skipped.
        """
        pred = self._prep(pred)
        target = self._prep(target)

        if mask is None:
            mse = F.mse_loss(pred, target)
            return (20 * torch.log10(1.0 / torch.sqrt(mse))).item()

        mask = self._prep(mask)
        psnr_values = []
        for p, t, m in self._frame_triples(pred, target, mask):
            m = m.expand_as(p)
            weight = m.sum()
            if weight <= 0:
                continue
            # Averaging over the whole frame would count unmasked pixels as
            # perfect matches and inflate the score.
            mse = (((p - t) ** 2) * m).sum() / weight
            psnr_values.append(float(10 * torch.log10(1.0 / mse)))

        return float(np.mean(psnr_values)) if psnr_values else 0.0

    def calculate_ssim(self, pred, target, mask=None):
        """Mean SSIM over frames (skimage, ``data_range=1.0``, channels last).

        SSIM is computed on the full frame; *mask* (channel 0) is only used to
        skip frames whose mask is empty. Returns 0.0 if no frame is scored.
        """
        pred = self._prep(pred)
        target = self._prep(target)
        ssim_values = []

        for p, t, m in self._frame_triples(pred, target, mask):
            if m is not None and m[0].sum() == 0:
                continue
            p_np = p.permute(1, 2, 0).cpu().numpy()
            t_np = t.permute(1, 2, 0).cpu().numpy()
            ssim_values.append(ssim(t_np, p_np, data_range=1.0, channel_axis=2))

        return np.mean(ssim_values) if ssim_values else 0.0

    def calculate_lpips(self, pred, target, mask=None):
        """Mean LPIPS distance over frames; lower is better.

        *mask* is accepted for interface symmetry but not used: LPIPS is
        computed on full frames. Returns 0.0 for empty input.
        """
        # Rescale [0, 1] -> [-1, 1] for the LPIPS network.
        pred_norm = self._prep(pred) * 2.0 - 1.0
        target_norm = self._prep(target) * 2.0 - 1.0

        if pred_norm.dim() == 5:
            # (B, C, T, H, W) -> (B*T, C, H, W): one LPIPS call for all frames.
            b, c, t, h, w = pred_norm.shape
            pred_norm = pred_norm.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            target_norm = target_norm.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

        if pred_norm.shape[0] == 0:
            return 0.0

        with torch.no_grad():
            lpips_values = self.lpips_model(pred_norm, target_norm)
        return float(lpips_values.mean().item())

    def calculate_iou(self, pred_mask, target_mask, threshold=0.5):
        """IoU of the two masks after binarising each at ``> threshold``."""
        pred_binary = (pred_mask > threshold).float()
        target_binary = (target_mask > threshold).float()

        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection

        iou = intersection / (union + 1e-8)
        return iou.item()

    def get_inception_features(self, images):
        """Return Inception-v3 pooled features (NumPy ``(N, 2048)``) for ``(N, 3, H, W)`` images."""
        with torch.no_grad():
            images_preprocessed = self.inception_transform(self._prep(images))
            features = self.inception_model(images_preprocessed)
        return features.cpu().numpy()

    @staticmethod
    def _frechet_distance(features1, features2):
        """Fréchet distance between Gaussians fitted to two ``(n, d)`` feature sets.

        ``||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr(sqrt(S1 S2))`` with unbiased
        covariances. ``Tr(sqrt(S1 S2))`` is evaluated exactly as the nuclear norm
        of ``A @ B.T / sqrt((n1-1)(n2-1))`` (A, B the centred features): the
        non-zero eigenvalues of ``S1 S2`` are the squared singular values of that
        matrix. This avoids a ``d x d`` matrix square root, which is slow and
        numerically unstable when ``n << d``.

        Raises:
            ValueError: if either set has fewer than two samples.
        """
        f1 = np.asarray(features1, dtype=np.float64)
        f2 = np.asarray(features2, dtype=np.float64)
        n1, n2 = f1.shape[0], f2.shape[0]
        if n1 < 2 or n2 < 2:
            raise ValueError("FID needs at least two samples in each set")

        mu1, mu2 = f1.mean(axis=0), f2.mean(axis=0)
        a = f1 - mu1
        b = f2 - mu2
        diff = mu1 - mu2

        trace1 = np.sum(a * a) / (n1 - 1)
        trace2 = np.sum(b * b) / (n2 - 1)
        trace_covmean = np.linalg.norm(a @ b.T, ord='nuc') / np.sqrt((n1 - 1) * (n2 - 1))

        return float(diff.dot(diff) + trace1 + trace2 - 2.0 * trace_covmean)

    def calculate_fid(self, pred, target):
        """Fréchet Inception Distance between the frames of *pred* and *target*.

        5-D inputs are flattened to one image per frame. With only a handful of
        frames this is a rough, high-variance estimate. Raises ValueError when
        either side has fewer than two frames.
        """
        pred = self._prep(pred)
        target = self._prep(target)
        if pred.dim() == 5:
            pred = pred.permute(0, 2, 1, 3, 4).reshape(-1, pred.shape[1], pred.shape[3], pred.shape[4])
            target = target.permute(0, 2, 1, 3, 4).reshape(-1, target.shape[1], target.shape[3], target.shape[4])

        pred_features = self.get_inception_features(pred)
        target_features = self.get_inception_features(target)
        return self._frechet_distance(pred_features, target_features)

    def calculate_all_metrics(self, pred, target, amodal_mask=None):
        """Return a dict with 'psnr', 'ssim', 'lpips', 'fid' and, with a mask, 'iou'.

        'fid' is NaN when it cannot be computed (e.g. fewer than two frames);
        NaN, unlike 0.0, cannot be mistaken for a perfect score.
        """
        pred = self._prep(pred)
        target = self._prep(target)
        if amodal_mask is not None:
            amodal_mask = self._prep(amodal_mask)

        metrics = {
            'psnr': self.calculate_psnr(pred, target, amodal_mask),
            'ssim': self.calculate_ssim(pred, target, amodal_mask),
            'lpips': self.calculate_lpips(pred, target, amodal_mask),
        }

        try:
            metrics['fid'] = self.calculate_fid(pred, target)
        except Exception:
            # Best effort: FID is ill-defined for tiny/degenerate batches.
            metrics['fid'] = float('nan')

        if amodal_mask is not None:
            # Heuristic: mean RGB intensity of the prediction, thresholded in
            # calculate_iou, stands in for a predicted mask. Only mask channel 0
            # is compared, so a mask expanded to 3 channels is not triple-counted.
            pred_intensity = pred.mean(dim=1, keepdim=True)
            metrics['iou'] = self.calculate_iou(pred_intensity, amodal_mask[:, :1])

        return metrics
|
|
| |
def create_error_heatmap(pred, target, mask=None):
    """Per-pixel absolute error between *pred* and *target*, averaged over colour channels.

    Args:
        pred, target: ``(C, H, W)`` or ``(N, C, H, W)`` tensors; the channel axis
            is the third from last.
        mask: Optional mask; it is squeezed and multiplied into the error map,
            zeroing unmasked pixels.

    Returns:
        NumPy array of shape ``(H, W)`` or ``(N, H, W)``.
    """
    # Reduce over the channel axis (dim -3). Reducing over dim 0 -- the batch
    # axis of an (N, C, H, W) input -- would leave a per-channel error map.
    error = torch.abs(pred - target).mean(dim=-3)

    if mask is not None:
        error = error * mask.squeeze()

    return error.detach().cpu().numpy()
|
|
| |
def train_video_amodal_with_metrics():
    """Train ``Video3DUNet`` for video amodal completion, logging to Weights & Biases.

    Hyper-parameters come from the ``wandb.init`` config below. Each epoch:
    - One pass over the training subset with gradient accumulation. On CUDA,
      this uses float16 autocast with gradient scaling.
    - Metrics on every 10th training batch.
    - A full validation pass.
    - Epoch averages logged to W&B.
    - A checkpoint written to ``epoch_<epoch>.pth`` in the working directory.

    ``Video3DUNet``, ``VideoAmodalCompletionLoss``, ``prepare_model_input`` and
    ``prepare_model_target`` are not defined in this part of the file. They must
    exist in the module namespace before this function is called.
    """
    wandb.init(
        project="video-amodal-completion",
        config={
            'batch_size': 2,
            'seq_len': 6,
            'img_size': (256, 256),
            'num_epochs': 30,
            'learning_rate': 5e-5,
            'max_scenes': 2,
            'samples_per_scene': 2,
            'num_workers': 2,
            'grad_accum_steps': 4
        }
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = device.type == 'cuda'
    torch.cuda.empty_cache()

    config = wandb.config

    metrics_calculator = VideoAmodalMetrics(device)

    train_dataset = VideoAmodalDataset(
        root_dir='data',
        split='train',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=config.max_scenes,
        samples_per_scene=config.samples_per_scene,
        max_samples=100
    )

    val_dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=1,
        samples_per_scene=1,
        max_samples=10
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=on_cuda  # pinned host memory only helps host-to-GPU copies
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1
    )

    model = Video3DUNet(
        in_channels=5,
        out_channels=3,
        sequence_length=config.seq_len
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = VideoAmodalCompletionLoss()

    # float16 autocast without loss scaling lets small gradients underflow to
    # zero. Both autocast and the scaler are disabled (pass-through) on CPU.
    if hasattr(torch.amp, "GradScaler"):
        scaler = torch.amp.GradScaler("cuda", enabled=on_cuda)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=on_cuda)
    num_batches = len(train_loader)

    for epoch in range(config.num_epochs):
        model.train()
        optimizer.zero_grad()
        epoch_losses = []
        epoch_metrics = {
            'train_psnr': [],
            'train_ssim': [],
            'train_lpips': [],
            'train_fid': [],
            'train_iou': []
        }

        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            inputs = prepare_model_input(batch).to(device, non_blocking=True)
            targets = prepare_model_target(batch).to(device, non_blocking=True)
            modal_masks = batch['modal_masks'].to(device, non_blocking=True)
            amodal_masks = batch['amodal_masks'].to(device, non_blocking=True)

            with torch.autocast(device_type=device.type, enabled=on_cuda):
                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                loss = loss / config.grad_accum_steps

            scaler.scale(loss).backward()

            if i % 10 == 0:
                with torch.no_grad():
                    # Collated dataset masks are (B, T, 1, H, W); metrics take (B, 1, T, H, W).
                    amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                    batch_metrics = metrics_calculator.calculate_all_metrics(
                        outputs.detach().float(), targets, amodal_masks_3d
                    )

                for key, value in batch_metrics.items():
                    if f'train_{key}' in epoch_metrics:
                        epoch_metrics[f'train_{key}'].append(value)

            # Step every grad_accum_steps batches and on the final batch, so a
            # partial accumulation window is not carried into the next epoch.
            if (i + 1) % config.grad_accum_steps == 0 or (i + 1) == num_batches:
                scaler.unscale_(optimizer)  # clip the real (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                torch.cuda.empty_cache()

            epoch_losses.append(loss_dict['total_loss'])

            if i % 20 == 0:
                log_dict = {
                    'batch': epoch * len(train_loader) + i,
                    'train_loss': loss_dict['total_loss'],
                    'train_visible_loss': loss_dict['visible_loss'],
                    'train_occluded_loss': loss_dict['occluded_loss'],
                    'train_background_loss': loss_dict['background_loss'],
                    'train_boundary_loss': loss_dict['boundary_loss']
                }

                # Attach the most recent value of each training metric.
                for key, values in epoch_metrics.items():
                    if values:
                        log_dict[key] = values[-1]

                wandb.log(log_dict)

                print(f"Batch {i}, Loss: {loss_dict['total_loss']:.4f}")
                print(f"  Visible: {loss_dict['visible_loss']:.4f}, "
                      f"Occluded: {loss_dict['occluded_loss']:.4f}, "
                      f"Background: {loss_dict['background_loss']:.4f}")

        model.eval()
        val_losses = []
        val_metrics = {
            'val_psnr': [],
            'val_ssim': [],
            'val_lpips': [],
            'val_fid': [],
            'val_iou': []
        }

        with torch.no_grad():
            for batch in val_loader:
                inputs = prepare_model_input(batch).to(device)
                targets = prepare_model_target(batch).to(device)
                modal_masks = batch['modal_masks'].to(device)
                amodal_masks = batch['amodal_masks'].to(device)

                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                val_losses.append(loss_dict['total_loss'])

                amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                batch_metrics = metrics_calculator.calculate_all_metrics(
                    outputs, targets, amodal_masks_3d
                )

                for key, value in batch_metrics.items():
                    if f'val_{key}' in val_metrics:
                        val_metrics[f'val_{key}'].append(value)

        avg_train_loss = np.mean(epoch_losses)
        avg_val_loss = np.mean(val_losses)

        epoch_log = {
            'epoch': epoch,
            'avg_train_loss': avg_train_loss,
            'avg_val_loss': avg_val_loss
        }

        for key, values in {**epoch_metrics, **val_metrics}.items():
            if values:
                epoch_log[f'avg_{key}'] = np.mean(values)

        wandb.log(epoch_log)

        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        for key, values in val_metrics.items():
            if values:
                print(f"  {key}: {np.mean(values):.4f}")

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'metrics': {key: np.mean(values) for key, values in val_metrics.items() if values},
            # Lets inference rebuild the model with the window length it was trained on.
            'seq_len': config.seq_len
        }, f"epoch_{epoch}.pth")

    wandb.finish()
|
|
| |
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=200):
    """Write an animated GIF comparing predictions with ground truth plus a 'jet' error heatmap.

    Each GIF frame is, left to right: scene frame, prediction, ground-truth
    amodal RGB, and the error map from ``create_error_heatmap``. The heatmap
    uses one colour scale for all frames (the global min..max error), and that
    range is written in the top-right corner.

    Args:
        predictions, rgb_frames, gt_amodal_frames: Per-frame ``(3, H, W)`` CPU
            tensors with values in ``[0, 1]``.
        amodal_masks: Per-frame mask tensors, or ``None``/empty to skip masking.
        output_path: GIF destination.
        duration: Display time per frame in milliseconds.

    Raises:
        ValueError: If *predictions* is empty.
    """
    from PIL import ImageDraw, ImageFont

    if len(predictions) == 0:
        raise ValueError("create_gif_with_error_heatmap needs at least one prediction frame")

    # Works for lists and tensors alike (truthiness of a tensor is ambiguous).
    have_masks = amodal_masks is not None and len(amodal_masks) > 0

    all_errors = []
    for i, pred_tensor in enumerate(predictions):
        mask_tensor = amodal_masks[i] if have_masks else None
        all_errors.append(create_error_heatmap(
            pred_tensor.unsqueeze(0),
            gt_amodal_frames[i].unsqueeze(0),
            mask_tensor.unsqueeze(0) if mask_tensor is not None else None,
        ))

    # Shared colour scale across all frames.
    max_error = max(error.max() for error in all_errors)
    min_error = min(error.min() for error in all_errors)
    span = max_error - min_error

    try:
        font = ImageFont.load_default()
    except Exception:
        font = None  # fall back to PIL's built-in default inside draw.text
    text = f"Error: {min_error:.3f} - {max_error:.3f}"  # identical on every frame

    frames = []
    for i, error in enumerate(all_errors):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        error_normalized = (error - min_error) / span if span > 0 else error

        # Reduce to a single (H, W) map for the colormap.
        error_normalized = np.squeeze(error_normalized)
        if error_normalized.ndim == 3:
            error_normalized = error_normalized[0]

        error_rgb = (cm.jet(error_normalized)[:, :, :3] * 255).astype(np.uint8)

        combined = np.concatenate([scene_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)

        img_pil = Image.fromarray(combined)
        ImageDraw.Draw(img_pil).text((combined.shape[1] - 150, 10), text, fill=(255, 255, 255), font=font)
        frames.append(img_pil)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
|
|
| |
def load_model_and_generate_video_with_metrics(checkpoint_path, dataset, device,
                                               output_path="amodal_completion.mp4", fps=8,
                                               seq_len=None):
    """Run a trained model over ``dataset[0]`` in overlapping windows and write a comparison video.

    The clip is processed in windows of ``seq_len`` frames with stride
    ``seq_len // 2``. A final window is added so the last frames are always
    covered. Frames predicted by several windows are averaged, and metrics are
    computed per window and averaged. Each video frame shows, left to right:
    scene RGB, modal mask, masked prediction, masked ground-truth amodal RGB.

    Args:
        checkpoint_path: Checkpoint written by ``train_video_amodal_with_metrics``.
            It is loaded with ``weights_only=False`` (full unpickling), so only
            load trusted files.
        dataset: A ``VideoAmodalDataset``; only item 0 is used.
        device: Torch device for the model and metrics.
        output_path: MP4 destination.
        fps: Output frame rate.
        seq_len: Window length fed to the model. ``None`` uses the checkpoint's
            stored ``'seq_len'`` if present, otherwise 8 (the previous fixed value).

    Returns:
        ``(all_predictions, all_rgb, all_gt_amodal, all_amodal_masks, avg_metrics)``:
        per-frame CPU tensors and the dict of window-averaged metrics.

    Raises:
        ValueError: If the clip is shorter than ``seq_len``.
    """
    import cv2

    metrics_calculator = VideoAmodalMetrics(device)

    # NOTE: weights_only=False unpickles arbitrary objects -- trusted checkpoints only.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if seq_len is None:
        seq_len = int(checkpoint.get('seq_len', 8))

    model = Video3DUNet(in_channels=5, out_channels=3, sequence_length=seq_len).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['train_loss']:.4f}")

    sample = dataset[0]
    total_frames = len(sample['rgb_sequence'])

    print(f"Processing {total_frames} frames in windows of {seq_len}")
    if total_frames < seq_len:
        raise ValueError(f"Clip has {total_frames} frames, fewer than the window length {seq_len}")

    # Full-length per-frame references (independent of the windowing).
    all_rgb = list(sample['rgb_sequence'])
    all_modal_masks = list(sample['modal_masks'])
    all_amodal_masks = list(sample['amodal_masks'])
    all_gt_amodal = list(sample['amodal_rgb_sequence'])

    # Half-overlapping windows (stride >= 1), plus a last window flush with the
    # end of the clip so trailing frames are never dropped.
    stride = max(1, seq_len // 2)
    starts = list(range(0, total_frames - seq_len + 1, stride))
    if starts[-1] != total_frames - seq_len:
        starts.append(total_frames - seq_len)

    # Running sums/counts per frame: overlapping windows are averaged correctly
    # for any seq_len (the old in-place blend duplicated frames for odd seq_len).
    frame_sums = [None] * total_frames
    frame_counts = [0] * total_frames
    all_metrics = []

    with torch.no_grad():
        for start_idx in starts:
            end_idx = start_idx + seq_len

            # Batch of one: 4-D per-frame tensors are sliced to the window.
            window_batch = {}
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    if value.dim() == 4:
                        window_batch[key] = value[start_idx:end_idx].unsqueeze(0)
                    else:
                        window_batch[key] = value.unsqueeze(0)
                else:
                    window_batch[key] = [value]

            inputs = prepare_model_input(window_batch).to(device)
            pred = model(inputs)

            # (1, T, 1, H, W) -> (1, 1, T, H, W); broadcasts over the prediction's channels.
            amodal_mask = window_batch['amodal_masks'].permute(0, 2, 1, 3, 4).to(device)
            pred_masked = pred * amodal_mask

            # The single-channel mask keeps the IoU metric from triple-counting.
            target = prepare_model_target(window_batch).to(device)
            all_metrics.append(metrics_calculator.calculate_all_metrics(pred, target, amodal_mask))

            # (1, C, T, H, W) -> T frames of (C, H, W)
            pred_frames = pred_masked.squeeze(0).permute(1, 0, 2, 3).cpu()
            for offset, frame in enumerate(pred_frames):
                frame_idx = start_idx + offset
                if frame_idx >= total_frames:
                    break
                frame_sums[frame_idx] = frame if frame_sums[frame_idx] is None else frame_sums[frame_idx] + frame
                frame_counts[frame_idx] += 1

    print("\nOverall Metrics:")
    avg_metrics = {}
    for key in all_metrics[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in all_metrics])
        print(f"  {key.upper()}: {avg_metrics[key]:.4f}")

    all_predictions = [frame_sums[f] / frame_counts[f] for f in range(total_frames)]
    print(f"Generated {len(all_predictions)} prediction frames")

    height, width = all_predictions[0].shape[-2:]
    video_width = width * 4
    video_height = height

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (video_width, video_height))
    try:
        for i, pred_frame in enumerate(all_predictions):
            scene_rgb = all_rgb[i].permute(1, 2, 0).numpy()
            modal_mask = all_modal_masks[i][0].numpy()
            modal_mask_rgb = np.stack([modal_mask, modal_mask, modal_mask], axis=2)

            pred_rgb = np.clip(pred_frame.permute(1, 2, 0).numpy(), 0, 1)

            try:
                gt_amodal = all_gt_amodal[i].permute(1, 2, 0).numpy()
                amodal_mask_np = all_amodal_masks[i][0].numpy()
                gt_amodal_masked = gt_amodal * amodal_mask_np[:, :, None]
            except Exception:
                # Best effort: a black panel rather than aborting the video.
                gt_amodal_masked = np.zeros_like(pred_rgb)

            combined_frame = np.concatenate([
                scene_rgb,
                modal_mask_rgb,
                pred_rgb,
                gt_amodal_masked
            ], axis=1)

            # OpenCV expects BGR uint8.
            combined_frame_bgr = cv2.cvtColor((combined_frame * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
            out.write(combined_frame_bgr)

            if i % 5 == 0:
                print(f"Processed frame {i+1}/{len(all_predictions)}")
    finally:
        # Always finalise the container and release the file handle.
        out.release()
    print(f"Video saved to {output_path}")

    return all_predictions, all_rgb, all_gt_amodal, all_amodal_masks, avg_metrics
|
|
| |
def run_enhanced_video_generation(checkpoint_path="video_amodal_model_epoch_4.pth"):
    """Render the first test clip as an MP4 and an error-heatmap GIF; return its metrics.

    Args:
        checkpoint_path: Checkpoint to load. The default is kept for backward
            compatibility. ``train_video_amodal_with_metrics`` writes checkpoints
            named ``epoch_<n>.pth``, so pass one of those for a fresh training run.

    Returns:
        Dict of metrics averaged over the inference windows.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One 24-frame clip from the first test scene.
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=1,
        samples_per_scene=1,
        max_samples=1
    )

    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video_with_metrics.mp4",
        fps=8
    )

    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )

    print("Enhanced video generation complete!")
    return metrics
|
|
# Launch the full training run only when executed as a script/notebook, not on import.
if __name__ == "__main__":
    train_video_amodal_with_metrics()
|
|
| |
|
|
| import torch |
|
|
def run_gif_generation(checkpoint_path="epoch_29.pth"):
    """Generate the comparison MP4 and error-heatmap GIF for the first test clip.

    Args:
        checkpoint_path: Checkpoint to load. The default is the last checkpoint
            of a 30-epoch ``train_video_amodal_with_metrics`` run (epochs 0-29).

    Returns:
        Dict of window-averaged metrics (also printed).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Only dataset[0] is used downstream; the larger index just widens the scan.
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=50,
        samples_per_scene=5,
        max_samples=50
    )

    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video.mp4",
        fps=6
    )

    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )

    print("GIF creation complete!")
    print(f"Metrics: {metrics}")
    return metrics
|
|
| |
# Script entry point: generate the MP4 and GIF for the first test clip.
# NOTE(review): when this file runs top to bottom, this call executes before the
# later redefinitions of create_gif_with_error_heatmap and create_error_heatmap
# further down, so the earlier definitions of those two functions are the ones
# used here. Move this block to the end of the file if the newer versions are intended.
if __name__ == "__main__":
    run_gif_generation()
|
|
| import cv2 |
|
|
def draw_amodal_boundary(rgb_image, amodal_mask, color=(255, 0, 255), threshold=0.5):
    """Return a copy of *rgb_image* with the outer contour(s) of *amodal_mask* drawn on it.

    Args:
        rgb_image: ``(H, W, 3)`` image array; it is copied, never modified.
        amodal_mask: ``(H, W)`` mask; pixels ``> threshold`` count as inside.
        color: Contour colour in the image's channel order.
        threshold: Binarisation threshold. A plain ``astype(np.uint8)`` would
            truncate soft mask values such as 0.99 to 0 and shrink the outline.

    Returns:
        The annotated copy, with 2-pixel-thick outer contours.
    """
    binary_mask = (np.asarray(amodal_mask) > threshold).astype(np.uint8)
    # RETR_EXTERNAL: outer boundaries only, ignoring holes inside the object.
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlined = rgb_image.copy()
    cv2.drawContours(outlined, contours, -1, color, thickness=2)
    return outlined
|
|
| |
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=240):
    """Write an animated GIF comparing predictions with ground truth, with an error heatmap and colour bar.

    Each GIF frame is, left to right:
    - the scene frame, with the amodal-mask outline when masks are given;
    - the prediction;
    - the ground-truth amodal RGB;
    - the per-pixel error from ``create_error_heatmap`` in the ``hot`` colormap;
    - a labelled colour bar.

    All frames share one colour scale. It spans the 5th-95th percentile of the
    non-zero (masked) errors, or the global min..max range when there are none.
    Values outside the scale are clipped.

    Args:
        predictions, rgb_frames, gt_amodal_frames: Per-frame ``(3, H, W)`` CPU
            tensors with values in ``[0, 1]``.
        amodal_masks: Per-frame mask tensors (channel 0 is used), or
            ``None``/empty to skip masking and the outline.
        output_path: GIF destination.
        duration: Display time per frame in milliseconds.

    Raises:
        ValueError: If *predictions* is empty.
    """
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize
    import io

    if len(predictions) == 0:
        raise ValueError("create_gif_with_error_heatmap needs at least one prediction frame")

    # One consistent check: None or an empty collection disables masking.
    have_masks = amodal_masks is not None and len(amodal_masks) > 0
    masks_np = [amodal_masks[i][0].cpu().numpy() for i in range(len(predictions))] if have_masks else None

    all_errors = []
    for i, pred_tensor in enumerate(predictions):
        mask_tensor = amodal_masks[i] if have_masks else None
        all_errors.append(create_error_heatmap(
            pred_tensor.unsqueeze(0),
            gt_amodal_frames[i].unsqueeze(0),
            mask_tensor.unsqueeze(0) if mask_tensor is not None else None,
        ))

    # Pool error values for a robust shared scale. Concatenating arrays avoids
    # building a huge Python list one element at a time.
    if have_masks:
        pooled_parts = []
        for error, mask in zip(all_errors, masks_np):
            masked_error = error * mask
            pooled_parts.append(masked_error[masked_error > 0])
    else:
        pooled_parts = [error.ravel() for error in all_errors]
    pooled = np.concatenate(pooled_parts)

    if pooled.size:
        # Percentiles keep a few extreme pixels from washing out the heatmap.
        min_error = np.percentile(pooled, 5)
        max_error = np.percentile(pooled, 95)
    else:
        min_error = min(error.min() for error in all_errors)
        max_error = max(error.max() for error in all_errors)

    # Avoid division by zero when all errors are (nearly) equal.
    if max_error - min_error < 1e-6:
        max_error = min_error + 1e-6

    print(f"Error range for visualization: {min_error:.4f} to {max_error:.4f}")

    def create_colorbar():
        """Render a vertical 'hot' colour bar labelled from min_error to max_error as a PIL image."""
        fig, ax = plt.subplots(figsize=(1, 4))
        try:
            fig.patch.set_facecolor('black')
            ax.set_facecolor('black')

            sm = cm.ScalarMappable(norm=Normalize(vmin=min_error, vmax=max_error), cmap='hot')
            sm.set_array([])

            cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=1.0)
            cbar.set_label('Prediction Error', color='white', fontsize=10)
            cbar.ax.tick_params(colors='white', labelsize=8)

            # Keep only the colour-bar axes in the rendered image.
            ax.remove()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight',
                        facecolor='black', edgecolor='none', dpi=100)
        finally:
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    colorbar_img = create_colorbar()
    colorbar_width = colorbar_img.width

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('hot')
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None  # fall back to PIL's built-in default inside draw.text

    frames = []
    for i in range(len(predictions)):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        error = all_errors[i]
        if have_masks:
            error = error * masks_np[i]

        # Reduce to a single (H, W) map for the colormap.
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error[0]

        error_normalized = np.clip((error - min_error) / (max_error - min_error), 0, 1)
        error_rgb = (cmap(error_normalized)[:, :, :3] * 255).astype(np.uint8)

        if have_masks:
            # Black out everything outside the mask and outline the object in the scene view.
            mask_3d = np.stack([masks_np[i]] * 3, axis=2)
            error_rgb = error_rgb * mask_3d.astype(np.uint8)
            highlighted_rgb = draw_amodal_boundary(scene_rgb, masks_np[i])
        else:
            highlighted_rgb = scene_rgb

        combined = np.concatenate([highlighted_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)
        img_pil = Image.fromarray(combined)

        # Stretch the colour bar to the frame height and place it to the right.
        colorbar_resized = colorbar_img.resize((colorbar_width, img_pil.height))
        final_width = img_pil.width + colorbar_width + 10
        final_img = Image.new('RGB', (final_width, img_pil.height), color='black')
        final_img.paste(img_pil, (0, 0))
        final_img.paste(colorbar_resized, (img_pil.width + 10, 0))

        draw = ImageDraw.Draw(final_img)
        frame_text = f"Frame {i+1}/{len(predictions)}"
        draw.text((10, 10), frame_text, fill=(0, 0, 0), font=font)

        frames.append(final_img)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with proper error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
    print("Colorbar shows errors from low (black/red) to high (yellow/white)")
|
|
| |
def create_error_heatmap(pred, target, mask=None):
    """Per-pixel Euclidean (L2) distance between *pred* and *target* across colour channels.

    Compared with a channel mean, the L2 norm reports a large error in a single
    channel at full strength instead of dividing it by the channel count.

    Args:
        pred, target: ``(C, H, W)`` or ``(N, C, H, W)`` tensors; the channel axis
            is the third from last.
        mask: Optional mask; it is squeezed and multiplied into the error map,
            zeroing unmasked pixels.

    Returns:
        NumPy array of shape ``(H, W)`` or ``(N, H, W)``.
    """
    # dim=-3 is the channel axis for both batched and unbatched input
    # (a hard-coded dim=1 would reduce over H for an unbatched frame).
    error = torch.sqrt(torch.sum((pred - target) ** 2, dim=-3))

    if mask is not None:
        error = error * mask.squeeze()

    return error.detach().cpu().numpy()
|
|
|
|