"""Evaluate the Phase 2 fusion trajectory model against a constant-velocity baseline."""

import random
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from backend.app.legacy.data_loader import (
    load_json,
    extract_pedestrian_instances,
    build_trajectories_with_sensor,
    create_windows_with_sensor,
)
from backend.app.legacy.dataset_fusion import FusionTrajectoryDataset
from backend.app.ml.model_fusion import TrajectoryTransformerFusion

# Three directory levels above the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[3]
DEFAULT_FUSION_CKPT = REPO_ROOT / "models" / "best_social_model_fusion.pth"


def collate_fn_fusion(batch):
    """Collate ``(obs, neighbors, fusion_obs, future)`` samples into a batch.

    ``obs``, ``fusion_obs`` and ``future`` are stacked along a new leading
    batch dimension. ``neighbors`` is returned as a plain list with one entry
    per sample (presumably because its size varies per sample -- confirm
    against ``FusionTrajectoryDataset``).
    """
    obs, neighbors, fusion_obs, future = zip(*batch)
    return (
        torch.stack(obs),
        list(neighbors),
        torch.stack(fusion_obs),
        torch.stack(future),
    )


def compute_ade(pred, gt):
    """Return the mean L2 distance between ``pred`` and ``gt`` over all steps.

    The norm is taken over ``dim=2``, so both tensors are indexed as
    ``(batch, time, coord)``.
    """
    return torch.mean(torch.norm(pred - gt, dim=2))


def compute_fde(pred, gt):
    """Return the mean L2 distance between the final steps of ``pred`` and ``gt``."""
    return torch.mean(torch.norm(pred[:, -1] - gt[:, -1], dim=1))


def load_fusion_samples():
    """Load the annotation tables and return the sensor-augmented windows."""
    sample_annotations = load_json("sample_annotation")
    instances = load_json("instance")
    categories = load_json("category")
    ped_instances = extract_pedestrian_instances(
        sample_annotations, instances, categories
    )
    trajectories = build_trajectories_with_sensor(sample_annotations, ped_instances)
    return create_windows_with_sensor(trajectories)


def _select_best_mode(pred, future):
    """Pick, per sample, the predicted mode with the lowest average error.

    ``pred`` is indexed as ``(batch, mode, time, coord)`` and ``future`` as
    ``(batch, time, coord)``; the result has the same layout as ``future``.
    """
    per_mode_ade = torch.norm(pred - future.unsqueeze(1), dim=3).mean(dim=2)
    best_idx = torch.argmin(per_mode_ade, dim=1)
    batch_idx = torch.arange(pred.size(0), device=pred.device)
    return pred[batch_idx, best_idx]


def _constant_velocity_prediction(obs, horizon, last_idx=3):
    """Extrapolate ``horizon`` steps from observation step ``last_idx``.

    Feature columns 0/1 are used as the position and columns 2/3 as the
    per-step displacement.
    NOTE(review): ``last_idx=3`` is presumably the final observed frame of a
    4-step history -- confirm against the dataset before changing the
    observation length.
    """
    anchor = obs[:, last_idx]
    steps = torch.arange(1, horizon + 1, device=obs.device).unsqueeze(0).float()
    cv_x = anchor[:, 0].unsqueeze(1) + anchor[:, 2].unsqueeze(1) * steps
    cv_y = anchor[:, 1].unsqueeze(1) + anchor[:, 3].unsqueeze(1) * steps
    return torch.stack([cv_x, cv_y], dim=-1)


def _batch_totals(pred, future, miss_threshold):
    """Return ``(ade_sum, fde_sum, miss_count)`` for one batch."""
    batch_size = future.size(0)
    # The metric helpers return batch means; scale back to sums so the final
    # average is weighted correctly when the last batch is smaller.
    ade_sum = compute_ade(pred, future).item() * batch_size
    fde_sum = compute_fde(pred, future).item() * batch_size
    final_disp = torch.norm(pred[:, -1] - future[:, -1], dim=1)
    misses = (final_disp > miss_threshold).sum().item()
    return ade_sum, fde_sum, misses


def evaluate_fusion(ckpt=DEFAULT_FUSION_CKPT):
    """Evaluate the fusion model on the held-out 20% split and print a report.

    Args:
        ckpt: Path to the model ``state_dict`` checkpoint. A relative path is
            resolved against ``REPO_ROOT``.

    Returns:
        A dict with the averaged metrics for the fusion model (``ade``,
        ``fde``, ``miss_rate``), the constant-velocity baseline (``cv_ade``,
        ``cv_fde``, ``cv_miss_rate``) and the number of evaluated
        trajectories (``n``). Miss rates are percentages.

    Raises:
        ValueError: If the validation split contains no trajectories.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running Phase 2 Fusion Evaluation on {device}...")

    ckpt_path = Path(ckpt)
    if not ckpt_path.is_absolute():
        ckpt_path = REPO_ROOT / ckpt_path

    samples = load_fusion_samples()
    # Fixed seed so the shuffle (and therefore the held-out 20%) is
    # reproducible; presumably it mirrors the split used at training time.
    random.seed(42)
    random.shuffle(samples)
    train_size = int(0.8 * len(samples))
    val_samples = samples[train_size:]
    if len(val_samples) == 0:
        # Fail before loading the checkpoint instead of dividing by zero later.
        raise ValueError("Validation split is empty; nothing to evaluate.")

    dataset = FusionTrajectoryDataset(val_samples, augment=False)
    loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn_fusion)

    model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints, or pass weights_only=True where the installed torch
    # version supports it.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    total_ade = 0.0
    total_fde = 0.0
    miss_count = 0
    cv_total_ade = 0.0
    cv_total_fde = 0.0
    cv_miss_count = 0
    total_n = 0
    miss_threshold = 2.0

    with torch.no_grad():
        for obs, neighbors, fusion_obs, future in loader:
            obs = obs.to(device)
            fusion_obs = fusion_obs.to(device)
            future = future.to(device)

            pred, _, _, _ = model(obs, neighbors, fusion_obs)

            # NOTE(review): FDE and miss rate are measured on the mode with
            # the lowest ADE, not minimised independently over modes.
            best_pred = _select_best_mode(pred, future)
            ade_sum, fde_sum, misses = _batch_totals(
                best_pred, future, miss_threshold
            )
            total_ade += ade_sum
            total_fde += fde_sum
            miss_count += misses

            # Constant velocity baseline for comparison.
            cv_pred = _constant_velocity_prediction(obs, future.size(1))
            cv_ade_sum, cv_fde_sum, cv_misses = _batch_totals(
                cv_pred, future, miss_threshold
            )
            cv_total_ade += cv_ade_sum
            cv_total_fde += cv_fde_sum
            cv_miss_count += cv_misses

            total_n += obs.size(0)

    if total_n == 0:
        # The dataset may yield fewer items than it was given samples.
        raise ValueError("Validation dataset produced no trajectories.")

    avg_ade = total_ade / total_n
    avg_fde = total_fde / total_n
    avg_miss = 100.0 * miss_count / total_n
    cv_avg_ade = cv_total_ade / total_n
    cv_avg_fde = cv_total_fde / total_n
    cv_avg_miss = 100.0 * cv_miss_count / total_n

    miss_label = f"Miss Rate (>{miss_threshold:.1f}m)"

    print("\n========================================================")
    print(" PHASE 2 FUSION METRICS REPORT ")
    print("========================================================")
    print(f"Total Trajectories Evaluated: {total_n}")
    print("--------------------------------------------------------")
    # Labels are padded to 24 columns so the bars line up with the
    # separator row printed below the header.
    print(f"{'METRIC':<24}| {'BASELINE (CV)':<13} | {'FUSION MODEL':<14}")
    print("------------------------|---------------|----------------")
    print(f"{'minADE@3 (meters)':<24}| {cv_avg_ade:13.2f} | {avg_ade:14.2f}")
    print(f"{'minFDE@3 (meters)':<24}| {cv_avg_fde:13.2f} | {avg_fde:14.2f}")
    print(f"{miss_label:<24}| {cv_avg_miss:12.1f}% | {avg_miss:13.1f}%")
    print("========================================================\n")

    return {
        "n": total_n,
        "ade": avg_ade,
        "fde": avg_fde,
        "miss_rate": avg_miss,
        "cv_ade": cv_avg_ade,
        "cv_fde": cv_avg_fde,
        "cv_miss_rate": cv_avg_miss,
    }


if __name__ == '__main__':
    evaluate_fusion()