sajith-0701
Deploy FastAPI backend to HF Spaces (Docker SDK)
98075af
import torch
from torch.utils.data import DataLoader
from pathlib import Path
from backend.app.legacy.dataset import TrajectoryDataset
from backend.app.ml.model import TrajectoryTransformer
from backend.scripts.training.train import get_data, collate_fn, compute_ade, compute_fde
import numpy as np
import random
REPO_ROOT = Path(__file__).resolve().parents[3]
BASE_CKPT = REPO_ROOT / "models" / "best_social_model.pth"
def evaluate():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running Evaluation on {device}...")
samples = get_data()
# Use the same deterministic split as train.py to evaluate on validation set
random.seed(42)
random.shuffle(samples)
train_size = int(0.8 * len(samples))
val_samples = samples[train_size:]
dataset = TrajectoryDataset(val_samples, augment=False)
eval_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn)
# Load Model
model = TrajectoryTransformer().to(device)
try:
model.load_state_dict(torch.load(BASE_CKPT, map_location=device, weights_only=True))
print("Successfully loaded 'best_social_model.pth' from models folder")
except Exception as e:
print(f"Could not load model weights: {e}")
return
model.eval()
total_ade = 0
total_fde = 0
miss_rate = 0
cv_total_ade = 0
cv_total_fde = 0
cv_miss_rate = 0
total_samples = 0
# Miss rate threshold: if best path's endpoint is off by more than 2.0 meters
MISS_THRESHOLD = 2.0
print("\n--- Starting Deep Evaluation ---")
with torch.no_grad():
for obs, neighbors, future in eval_loader:
obs, future = obs.to(device), future.to(device)
# --- MODEL PREDICTION ---
pred, goals, probs, _ = model(obs, neighbors)
# Find the best prediction out of K=3 for each item in the batch
gt = future.unsqueeze(1)
error = torch.norm(pred - gt, dim=3).mean(dim=2)
best_idx = torch.argmin(error, dim=1)
best_pred = pred[torch.arange(pred.size(0)), best_idx]
# Metrics Model
batch_ade = compute_ade(best_pred, future).item()
batch_fde = compute_fde(best_pred, future).item()
total_ade += batch_ade * obs.size(0)
total_fde += batch_fde * obs.size(0)
final_displacement = torch.norm(best_pred[:, -1] - future[:, -1], dim=1)
misses = (final_displacement > MISS_THRESHOLD).sum().item()
miss_rate += misses
# --- CONSTANT VELOCITY BASELINE ---
vx = obs[:, 3, 2].unsqueeze(1) # dx at last observed step
vy = obs[:, 3, 3].unsqueeze(1) # dy at last observed step
t = torch.arange(1, 13, device=device).unsqueeze(0).float() # Horizon is 12 steps
x_last = obs[:, 3, 0].unsqueeze(1) # x at last step
y_last = obs[:, 3, 1].unsqueeze(1) # y at last step
cv_pred_x = x_last + vx * t
cv_pred_y = y_last + vy * t
cv_pred = torch.stack([cv_pred_x, cv_pred_y], dim=-1)
# Metrics CV Baseline
cv_batch_ade = compute_ade(cv_pred, future).item()
cv_batch_fde = compute_fde(cv_pred, future).item()
cv_total_ade += cv_batch_ade * obs.size(0)
cv_total_fde += cv_batch_fde * obs.size(0)
cv_final_displacement = torch.norm(cv_pred[:, -1] - future[:, -1], dim=1)
cv_misses = (cv_final_displacement > MISS_THRESHOLD).sum().item()
cv_miss_rate += cv_misses
total_samples += obs.size(0)
# Average metrics
avg_ade = total_ade / total_samples
avg_fde = total_fde / total_samples
avg_miss_rate = (miss_rate / total_samples) * 100
cv_avg_ade = cv_total_ade / total_samples
cv_avg_fde = cv_total_fde / total_samples
cv_avg_miss_rate = (cv_miss_rate / total_samples) * 100
print("\n========================================================")
print(" HACKATHON FINAL METRICS REPORT ")
print("========================================================")
print(f"Total Trajectories Evaluated (Val Set): {total_samples}")
print(f"Prediction Horizon: 6 Seconds (12 steps)")
print(f"Social Context Radius: 50 Meters")
print("--------------------------------------------------------")
print("METRIC | BASELINE (CV) | OUR TRANSFORMER ")
print("------------------------|---------------|-----------------")
print(f"minADE@3 (meters) | {cv_avg_ade:13.2f} | {avg_ade:15.2f}")
print(f"minFDE@3 (meters) | {cv_avg_fde:13.2f} | {avg_fde:15.2f}")
print(f"Miss Rate (>2.0m) | {cv_avg_miss_rate:12.1f}% | {avg_miss_rate:14.1f}%")
print("========================================================\n")
if __name__ == '__main__':
evaluate()