Spaces:
Running
Running
| import torch | |
| from torch.utils.data import DataLoader, random_split | |
| import torch.optim as optim | |
| import os | |
| import datetime | |
| from pathlib import Path | |
| from backend.app.legacy.dataset import TrajectoryDataset | |
| from backend.app.ml.model import TrajectoryTransformer | |
| from backend.app.legacy.data_loader import ( | |
| load_json, extract_pedestrian_instances, | |
| build_trajectories, create_windows | |
| ) | |
# Repository root, taken as the fourth ancestor directory of this file
# (parents[0] is the containing directory, so parents[3] is three levels above
# it). Must be adjusted if this module is moved to a different depth.
REPO_ROOT = Path(__file__).resolve().parents[3]
# Directory under the repository root where train() stores model weights.
MODEL_DIR = REPO_ROOT / "models"
| # ---------------------------- | |
| # CUSTOM COLLATE (IMPORTANT) | |
| # ---------------------------- | |
def collate_fn(batch):
    """Merge a list of ``(obs, neighbors, future)`` samples into one batch.

    The ``obs`` and ``future`` entries are stacked along a new leading
    dimension. The ``neighbors`` entries are not stacked; they are returned
    as a plain list in sample order (presumably because their size varies
    per sample -- confirm against ``TrajectoryDataset``).

    Returns:
        A ``(stacked_obs, neighbors_list, stacked_future)`` tuple.
    """
    # Transpose the list of triples into three parallel columns.
    obs_items, neighbor_items, future_items = zip(*batch)
    return (
        torch.stack(obs_items),
        list(neighbor_items),
        torch.stack(future_items),
    )
| # ---------------------------- | |
| # LOAD DATA | |
| # ---------------------------- | |
def get_data():
    """Load the annotation tables and convert them into windowed samples.

    Reads the ``sample_annotation``, ``instance`` and ``category`` tables
    (in that order), narrows them to pedestrian instances, builds the
    trajectories for those instances and returns the result of
    ``create_windows`` on them.
    """
    annotations = load_json("sample_annotation")
    pedestrians = extract_pedestrian_instances(
        annotations,
        load_json("instance"),
        load_json("category"),
    )
    return create_windows(build_trajectories(annotations, pedestrians))
| # ---------------------------- | |
| # METRICS | |
| # ---------------------------- | |
def compute_ade(pred, gt):
    """Average displacement error between ``pred`` and ``gt``.

    Takes the norm of ``pred - gt`` along dimension 2 and averages it over
    all remaining entries, returning a scalar tensor. Callers in this file
    pass tensors laid out as (batch, steps, coords).
    """
    displacement = pred - gt
    per_step_error = displacement.norm(dim=2)
    return per_step_error.mean()
def compute_fde(pred, gt):
    """Final displacement error between ``pred`` and ``gt``.

    Only the last index along dimension 1 of each input is compared: the
    norm of that difference is taken along dimension 1 of the sliced
    tensors and averaged over the batch, returning a scalar tensor.
    """
    endpoint_gap = pred[:, -1] - gt[:, -1]
    return endpoint_gap.norm(dim=1).mean()
| # ---------------------------- | |
| # LOSS | |
| # ---------------------------- | |
def best_of_k_loss(pred, goals, gt, probs, *, goal_weight=0.5,
                   prob_weight=0.5, diversity_weight=0.1):
    """Best-of-K trajectory loss with goal, mode and diversity terms.

    For every sample the mode whose trajectory is closest to the ground
    truth (lowest mean per-step L2 error) is selected, and four terms are
    combined::

        traj_loss + goal_weight * goal_loss
                  + prob_weight * prob_loss
                  + diversity_weight * diversity_loss

    Args:
        pred: Predicted trajectories, (B, K, T, 2).
        goals: Predicted endpoints per mode, (B, K, 2).
        gt: Ground-truth future trajectory, (B, T, 2).
        probs: Per-mode scores, (B, K). They are passed straight to
            ``cross_entropy``, which applies ``log_softmax`` itself.
            NOTE(review): confirm the model emits raw logits here rather
            than already-normalised probabilities.
        goal_weight: Weight of the endpoint term (default 0.5).
        prob_weight: Weight of the mode-classification term (default 0.5).
        diversity_weight: Weight of the diversity regulariser (default 0.1).

    Returns:
        Scalar tensor holding the combined loss.
    """
    num_modes = pred.size(1)

    # Mean per-step displacement of every mode from the ground truth: (B, K).
    error = torch.norm(pred - gt.unsqueeze(1), dim=3).mean(dim=2)
    min_error, best_idx = torch.min(error, dim=1)
    traj_loss = min_error.mean()

    # Goal loss: the endpoint predicted for the winning mode must match the
    # true final position.
    batch_index = torch.arange(goals.size(0), device=best_idx.device)
    best_goals = goals[batch_index, best_idx]  # (B, 2)
    goal_loss = torch.norm(best_goals - gt[:, -1, :], dim=1).mean()

    # Mode loss: teach the scoring head to rank the winning mode first.
    prob_loss = torch.nn.functional.cross_entropy(probs, best_idx)

    # Diversity regularisation: penalise pairs of modes that lie close
    # together. All K*(K-1)/2 unordered pairs are evaluated in one batched
    # operation; a single mode has no pairs, so the term is skipped.
    diversity_loss = 0.0
    if num_modes > 1:
        first, second = torch.triu_indices(
            num_modes, num_modes, offset=1, device=pred.device
        )
        pair_dist = torch.norm(
            pred[:, first] - pred[:, second], dim=3
        ).mean(dim=2)  # (B, num_pairs)
        diversity_loss = torch.exp(-pair_dist).mean()

    return (
        traj_loss
        + goal_weight * goal_loss
        + prob_weight * prob_loss
        + diversity_weight * diversity_loss
    )
| # ---------------------------- | |
| # TRAIN | |
| # ---------------------------- | |
def _evaluate_best_of_k(model, loader, device):
    """Return sample-weighted best-of-K ``(ADE, FDE)`` of *model* on *loader*.

    For each sample the mode with the lowest mean per-step error is scored.
    Per-batch means are weighted by batch size so a shorter final batch does
    not skew the result. An empty loader yields ``(0.0, 0.0)``.
    """
    ade_sum = 0.0
    fde_sum = 0.0
    num_samples = 0

    model.eval()
    with torch.no_grad():
        for obs, neighbors, future in loader:
            obs, future = obs.to(device), future.to(device)
            # Only the trajectories are needed for the metrics.
            pred, _goals, _probs, _ = model(obs, neighbors)

            error = torch.norm(pred - future.unsqueeze(1), dim=3).mean(dim=2)
            best_idx = torch.argmin(error, dim=1)
            rows = torch.arange(pred.size(0), device=best_idx.device)
            best_pred = pred[rows, best_idx]

            batch_size = future.size(0)
            ade_sum += compute_ade(best_pred, future).item() * batch_size
            fde_sum += compute_fde(best_pred, future).item() * batch_size
            num_samples += batch_size

    denom = max(num_samples, 1)
    return ade_sum / denom, fde_sum / denom


def train():
    """Train ``TrajectoryTransformer`` and keep the best checkpoint.

    The samples from ``get_data`` are shuffled with a fixed seed and split
    80/20 into training and validation sets. Training runs for at most 100
    epochs with Adam (lr 0.001); the learning rate is halved after 5 epochs
    without validation-ADE improvement, and training stops early after 15
    such epochs. Whenever the validation ADE improves, the model weights
    are written to ``MODEL_DIR / "best_social_model.pth"``.

    Progress is printed and appended to a timestamped file under ``log/``
    (relative to the current working directory). The logged train loss is
    the mean over batches; ADE and FDE are averages over validation
    samples.
    """
    import random

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs("log", exist_ok=True)
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = os.path.join("log", f"train_log_{timestamp}.txt")
    best_model_path = MODEL_DIR / "best_social_model.pth"

    def log_print(msg):
        """Echo *msg* to stdout and append it to the run's log file."""
        print(msg)
        with open(log_filename, "a", encoding="utf-8") as f:
            f.write(msg + "\n")

    log_print(f"Starting training on {device}...")
    samples = get_data()

    # Deterministic split: a fixed seed gives the same train/val partition
    # on every run.
    random.seed(42)
    random.shuffle(samples)
    train_size = int(0.8 * len(samples))
    train_samples = samples[:train_size]
    val_samples = samples[train_size:]

    train_dataset = TrajectoryDataset(train_samples, augment=True)
    val_dataset = TrajectoryDataset(val_samples, augment=False)
    train_loader = DataLoader(
        train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=64, collate_fn=collate_fn
    )

    model = TrajectoryTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    best_ade = float("inf")
    patience_counter = 0
    max_patience = 15

    for epoch in range(100):  # at most 100 epochs; early stopping may end sooner
        model.train()
        total_loss = 0.0
        num_batches = 0
        for obs, neighbors, future in train_loader:
            obs, future = obs.to(device), future.to(device)
            pred, goals, probs, _ = model(obs, neighbors)
            loss = best_of_k_loss(pred, goals, future, probs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
        # Report a per-batch mean so the figure does not grow with the
        # number of batches.
        mean_loss = total_loss / max(num_batches, 1)

        # ---------------- VALIDATION ----------------
        ade, fde = _evaluate_best_of_k(model, val_loader, device)

        log_print(f"Epoch {epoch+1}")
        log_print(f"Train Loss: {mean_loss:.4f}")
        log_print(f"ADE: {ade:.4f}, FDE: {fde:.4f}")
        log_print("-" * 40)

        # Save best model
        if ade < best_ade:
            log_print(f"New best model found! ADE improved from {best_ade:.4f} to {ade:.4f}")
            best_ade = ade
            torch.save(model.state_dict(), best_model_path)
            patience_counter = 0
        else:
            patience_counter += 1

        # Update learning rate based on the validation ADE.
        scheduler.step(ade)
        current_lr = optimizer.param_groups[0]['lr']
        log_print(f"Current Learning Rate: {current_lr}")

        if patience_counter >= max_patience:
            log_print(f"Early stopping triggered! No improvement for {max_patience} epochs.")
            break

    log_print("Training complete!")
# Run a full training session only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    train()