sajith-0701
Deploy FastAPI backend to HF Spaces (Docker SDK)
98075af
import torch
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import os
import datetime
from pathlib import Path
from backend.app.legacy.dataset import TrajectoryDataset
from backend.app.ml.model import TrajectoryTransformer
from backend.app.legacy.data_loader import (
load_json, extract_pedestrian_instances,
build_trajectories, create_windows
)
# Repository root: the fourth ancestor directory of this file (parents[3]).
# Note parents[3] raises IndexError at import time if the file sits fewer
# than four directory levels deep.
REPO_ROOT = Path(__file__).resolve().parents[3]
# Directory where train() writes the best model checkpoint.
MODEL_DIR = REPO_ROOT / "models"
# ----------------------------
# CUSTOM COLLATE (IMPORTANT)
# ----------------------------
def collate_fn(batch):
    """Collate ``(obs, neighbors, future)`` samples into one batch.

    ``obs`` and ``future`` are stacked along a new leading batch dimension
    with ``torch.stack``. ``neighbors`` is returned as a plain Python list,
    one entry per sample, without stacking (presumably because the neighbour
    set differs in size between samples -- TODO confirm against the dataset).

    Returns:
        ``(stacked_obs, neighbors_list, stacked_future)``
    """
    observed, neighbor_sets, targets = zip(*batch)
    stacked_obs = torch.stack(observed)
    stacked_future = torch.stack(targets)
    return stacked_obs, list(neighbor_sets), stacked_future
# ----------------------------
# LOAD DATA
# ----------------------------
def get_data():
    """Build the list of training samples from the raw annotation tables.

    Loads the ``sample_annotation``, ``instance`` and ``category`` tables via
    ``load_json`` (in that order). It then selects pedestrian instances with
    ``extract_pedestrian_instances``, assembles trajectories with
    ``build_trajectories``, and slices them into samples with
    ``create_windows``. The window definition lives in ``create_windows``.

    Returns:
        Whatever ``create_windows`` returns (used by ``train`` as a
        shuffleable, sliceable sequence of samples).
    """
    annotation_table, instance_table, category_table = (
        load_json(table_name)
        for table_name in ("sample_annotation", "instance", "category")
    )
    pedestrians = extract_pedestrian_instances(
        annotation_table, instance_table, category_table
    )
    return create_windows(build_trajectories(annotation_table, pedestrians))
# ----------------------------
# METRICS
# ----------------------------
def compute_ade(pred, gt):
    """Average Displacement Error.

    Takes the Euclidean norm of ``pred - gt`` over dim 2 (the coordinate
    axis for ``(B, T, D)`` inputs), then averages over every remaining
    element.

    Returns:
        0-d tensor.
    """
    displacement = (pred - gt).norm(dim=2)
    return displacement.mean()
def compute_fde(pred, gt):
    """Final Displacement Error.

    Compares only the last timestep (index ``-1`` along dim 1). It takes the
    Euclidean norm of the offset over the coordinate axis and averages over
    the batch.

    Returns:
        0-d tensor.
    """
    final_offset = pred[:, -1] - gt[:, -1]
    return final_offset.norm(dim=1).mean()
# ----------------------------
# LOSS
# ----------------------------
def best_of_k_loss(pred, goals, gt, probs, *, goal_weight=0.5,
                   prob_weight=0.5, diversity_weight=0.1):
    """Best-of-K (winner-takes-all) loss for multi-modal trajectory prediction.

    Only the mode closest to the ground truth (lowest mean displacement) is
    supervised for the trajectory and goal terms. The mode scores are trained
    to pick that mode, and a diversity term pushes the K modes apart.

    Args:
        pred: Predicted trajectories, indexed ``(B, K, T, D)``
            (batch, mode, timestep, coordinate).
        goals: Predicted endpoint per mode, ``(B, K, D)``.
        gt: Ground-truth future trajectory, ``(B, T, D)``.
        probs: Per-mode scores, ``(B, K)``. They are fed to
            ``cross_entropy``, which applies log-softmax itself, so they
            should be unnormalised logits. NOTE(review): confirm the model
            does not already apply softmax.
        goal_weight: Weight of the endpoint loss of the winning mode.
        prob_weight: Weight of the mode-classification loss.
        diversity_weight: Weight of the inter-mode diversity penalty.
            The defaults of the three weights reproduce the previously
            hard-coded values.

    Returns:
        Scalar tensor ``traj + goal_weight*goal + prob_weight*prob
        + diversity_weight*diversity``.
    """
    batch_size, num_modes = pred.shape[0], pred.shape[1]

    # Mean displacement of every mode against the ground truth: (B, K).
    error = torch.norm(pred - gt.unsqueeze(1), dim=3).mean(dim=2)
    min_error, best_idx = torch.min(error, dim=1)
    traj_loss = torch.mean(min_error)

    # Goal loss on the winning mode's predicted endpoint only.
    batch_idx = torch.arange(batch_size, device=goals.device)
    best_goals = goals[batch_idx, best_idx]  # (B, D)
    goal_loss = torch.norm(best_goals - gt[:, -1, :], dim=1).mean()

    # Train the mode scorer to select the winning mode.
    prob_loss = torch.nn.functional.cross_entropy(probs, best_idx)

    # Diversity regularisation: exp(-d) is 1 for identical modes and decays
    # as modes separate. Averaged over all unordered mode pairs (i < j) and
    # the batch, which equals the per-pair batch mean averaged over pairs.
    diversity_loss = 0
    if num_modes > 1:
        i_idx, j_idx = torch.triu_indices(
            num_modes, num_modes, offset=1, device=pred.device
        )
        # (B, P, T, D) -> (B, P): mean displacement between each mode pair.
        pair_dist = torch.norm(pred[:, i_idx] - pred[:, j_idx], dim=3).mean(dim=2)
        diversity_loss = torch.exp(-pair_dist).mean()

    return (traj_loss + goal_weight * goal_loss + prob_weight * prob_loss
            + diversity_weight * diversity_loss)
# ----------------------------
# TRAIN
# ----------------------------
def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over *loader*; return the mean per-batch loss."""
    model.train()
    total_loss, num_batches = 0.0, 0
    for obs, neighbors, future in loader:
        obs, future = obs.to(device), future.to(device)
        pred, goals, probs, _ = model(obs, neighbors)
        loss = best_of_k_loss(pred, goals, future, probs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, device):
    """Return the sample-weighted mean ``(ADE, FDE)`` of the best mode over *loader*.

    For each sample, the best mode is the one with the lowest mean
    displacement. FDE is reported for that same mode, not minimised
    independently.

    Raises:
        ValueError: If *loader* yields no samples.
    """
    model.eval()
    ade_sum, fde_sum, count = 0.0, 0.0, 0
    with torch.no_grad():
        for obs, neighbors, future in loader:
            obs, future = obs.to(device), future.to(device)
            pred, _goals, _probs, _ = model(obs, neighbors)
            error = torch.norm(pred - future.unsqueeze(1), dim=3).mean(dim=2)
            best_idx = torch.argmin(error, dim=1)
            batch_idx = torch.arange(pred.size(0), device=pred.device)
            best_pred = pred[batch_idx, best_idx]
            n = future.size(0)
            # compute_ade/compute_fde return batch means; re-weight by batch
            # size so the final figure is a true per-sample mean (the last
            # batch may be smaller).
            ade_sum += compute_ade(best_pred, future).item() * n
            fde_sum += compute_fde(best_pred, future).item() * n
            count += n
    if count == 0:
        raise ValueError("Validation loader produced no samples")
    return ade_sum / count, fde_sum / count


def train(*, max_epochs=100, batch_size=64, lr=0.001, max_patience=15):
    """Train ``TrajectoryTransformer`` with best-of-K loss and early stopping.

    Samples from ``get_data()`` are shuffled with a fixed seed and split 80/20
    into train/validation. After every epoch the mean validation ADE/FDE are
    computed. Whenever ADE improves, the model ``state_dict`` is saved to
    ``MODEL_DIR / "best_social_model.pth"``. Progress is printed and appended
    to ``log/train_log_<timestamp>.txt``, relative to the current working
    directory.

    Args:
        max_epochs: Upper bound on epochs; early stopping may end sooner.
        batch_size: Batch size for both the train and validation loaders.
        lr: Initial Adam learning rate. ReduceLROnPlateau halves it after
            5 epochs without validation-ADE improvement.
        max_patience: Epochs without ADE improvement before stopping early.

    Raises:
        ValueError: If the train or validation split is empty.
    """
    import random

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.makedirs("log", exist_ok=True)
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    log_filename = os.path.join(
        "log", f"train_log_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
    )
    best_model_path = MODEL_DIR / "best_social_model.pth"

    def log_print(msg):
        print(msg)
        with open(log_filename, "a", encoding="utf-8") as f:
            f.write(msg + "\n")

    log_print(f"Starting training on {device}...")

    samples = get_data()
    # Fixed-seed shuffle: the split is reproducible as long as get_data()
    # returns samples in a stable order. This seeds Python's *global* RNG,
    # so any later use of `random` (possibly dataset augmentation --
    # unverified) is seeded as well.
    random.seed(42)
    random.shuffle(samples)
    train_size = int(0.8 * len(samples))
    train_samples = samples[:train_size]
    val_samples = samples[train_size:]
    if not train_samples or not val_samples:
        raise ValueError(
            f"Not enough data to split: {len(train_samples)} train / "
            f"{len(val_samples)} validation samples"
        )

    train_loader = DataLoader(
        TrajectoryDataset(train_samples, augment=True),
        batch_size=batch_size, shuffle=True, collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        TrajectoryDataset(val_samples, augment=False),
        batch_size=batch_size, collate_fn=collate_fn,
    )

    model = TrajectoryTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_ade = float("inf")
    patience_counter = 0
    for epoch in range(max_epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        ade, fde = _evaluate(model, val_loader, device)

        log_print(f"Epoch {epoch+1}")
        log_print(f"Train Loss: {train_loss:.4f}")
        log_print(f"ADE: {ade:.4f}, FDE: {fde:.4f}")
        log_print("-" * 40)

        # Checkpoint on validation-ADE improvement; otherwise count toward
        # early stopping.
        if ade < best_ade:
            log_print(f"New best model found! ADE improved from {best_ade:.4f} to {ade:.4f}")
            best_ade = ade
            torch.save(model.state_dict(), best_model_path)
            patience_counter = 0
        else:
            patience_counter += 1

        scheduler.step(ade)
        current_lr = optimizer.param_groups[0]['lr']
        log_print(f"Current Learning Rate: {current_lr}")

        if patience_counter >= max_patience:
            log_print(f"Early stopping triggered! No improvement for {max_patience} epochs.")
            break

    log_print("Training complete!")
# Script entry point: running this module directly starts training via train().
if __name__ == "__main__":
    train()