Spaces:
Running
Running
| import argparse | |
| import datetime | |
| import os | |
| import random | |
| from pathlib import Path | |
| import torch | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from backend.app.legacy.data_loader import ( | |
| load_json, | |
| extract_pedestrian_instances, | |
| build_trajectories_with_sensor, | |
| create_windows_with_sensor, | |
| ) | |
| from backend.app.legacy.dataset_fusion import FusionTrajectoryDataset | |
| from backend.app.ml.model_fusion import TrajectoryTransformerFusion | |
| REPO_ROOT = Path(__file__).resolve().parents[3] | |
def collate_fn_fusion(batch):
    """Collate a list of fusion samples into one batch.

    Every element of ``batch`` is unpacked as a 4-tuple
    ``(obs, neighbors, fusion_obs, future)``.

    Returns:
        A 4-tuple ``(obs, neighbors, fusion_obs, future)`` where the three
        tensor fields are stacked along a new leading dimension and
        ``neighbors`` is a plain ``list`` holding each sample's neighbor
        entry unchanged (it is not stacked).
    """
    obs_items, neighbor_items, fusion_items, future_items = zip(*batch)
    return (
        torch.stack(obs_items),
        list(neighbor_items),  # kept per-sample; never stacked
        torch.stack(fusion_items),
        torch.stack(future_items),
    )
def compute_ade(pred, gt):
    """Return the average displacement error between ``pred`` and ``gt``.

    The L2 norm of ``pred - gt`` is taken over dimension 2 and the result is
    averaged over every remaining entry, giving a scalar tensor.
    Presumably both inputs are laid out as (batch, time, coords) - confirm
    against the caller.
    """
    step_errors = (pred - gt).norm(dim=2)
    return step_errors.mean()
def compute_fde(pred, gt):
    """Return the final displacement error between ``pred`` and ``gt``.

    Only the last index along dimension 1 of each input is compared: the L2
    norm of that difference is taken over dimension 1 of the resulting slice
    and averaged into a scalar tensor.
    """
    final_gap = pred[:, -1] - gt[:, -1]
    return final_gap.norm(dim=1).mean()
def best_of_k_loss(pred, goals, gt, probs):
    """Combine best-of-K trajectory, goal, probability and diversity terms.

    For every batch row the candidate in ``pred`` (indexed along dimension 1)
    with the lowest mean L2 error against ``gt`` is selected. The loss is

        traj + 0.5 * goal + 0.5 * prob + 0.1 * diversity

    where
      * traj      - mean error of the selected candidate,
      * goal      - mean L2 distance between the selected row of ``goals``
                    and the last step of ``gt``,
      * prob      - NLL of ``log(probs + 1e-8)`` with the selected index as
                    the target,
      * diversity - mean of ``exp(-distance)`` over all candidate pairs
                    (the float ``0.0`` when there is a single candidate).

    Shapes are not declared here; the indexing implies ``pred`` is 4-D,
    ``goals`` 3-D, ``gt`` 3-D and ``probs`` 2-D - confirm against the model.
    """
    num_modes = pred.size(1)
    batch_rows = torch.arange(goals.size(0), device=goals.device)

    # Mean per-step error of every candidate, then the winner per batch row.
    per_mode_err = torch.norm(pred - gt.unsqueeze(1), dim=3).mean(dim=2)
    closest_err, winner = torch.min(per_mode_err, dim=1)
    traj_term = torch.mean(closest_err)

    goal_gap = goals[batch_rows, winner] - gt[:, -1, :]
    goal_term = torch.norm(goal_gap, dim=1).mean()

    # Epsilon keeps the log finite when a probability is exactly zero.
    log_probs = torch.log(probs + 1e-8)
    prob_term = torch.nn.functional.nll_loss(log_probs, winner)

    diversity_term = 0.0
    if num_modes > 1:
        # Unordered candidate pairs, enumerated in (a, b) order with a < b.
        mode_pairs = [
            (a, b) for a in range(num_modes) for b in range(a + 1, num_modes)
        ]
        penalty = 0.0
        for a, b in mode_pairs:
            gap = torch.norm(pred[:, a] - pred[:, b], dim=2).mean(dim=1)
            # Close candidates give exp(-gap) near 1, i.e. a large penalty.
            penalty = penalty + torch.exp(-gap).mean()
        diversity_term = penalty / max(1, len(mode_pairs))

    return traj_term + 0.5 * goal_term + 0.5 * prob_term + 0.1 * diversity_term
def get_fusion_samples():
    """Load the annotation tables and return windowed fusion samples.

    Loads the ``sample_annotation``, ``instance`` and ``category`` tables via
    ``load_json`` (in that order), passes them through
    ``extract_pedestrian_instances``, builds trajectories with
    ``build_trajectories_with_sensor`` and returns whatever
    ``create_windows_with_sensor`` produces from them.
    """
    annotations = load_json("sample_annotation")
    pedestrian_instances = extract_pedestrian_instances(
        annotations,
        load_json("instance"),
        load_json("category"),
    )
    return create_windows_with_sensor(
        build_trajectories_with_sensor(annotations, pedestrian_instances)
    )
def _resolve_repo_path(raw_path):
    """Return ``raw_path`` as a ``Path``, anchoring relative paths at ``REPO_ROOT``."""
    path = Path(raw_path)
    return path if path.is_absolute() else REPO_ROOT / path


def _run_train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    loss_sum = 0.0
    batches = 0
    for obs, neighbors, fusion_obs, future in loader:
        obs = obs.to(device)
        fusion_obs = fusion_obs.to(device)
        future = future.to(device)
        pred, goals, probs, _ = model(obs, neighbors, fusion_obs)
        loss = best_of_k_loss(pred, goals, future, probs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        batches += 1
    return loss_sum / max(1, batches)


def _evaluate_best_of_k(model, loader, device):
    """Return ``(ADE, FDE)`` of the closest candidate, averaged over batches."""
    model.eval()
    ade_sum = 0.0
    fde_sum = 0.0
    batches = 0
    with torch.no_grad():
        for obs, neighbors, fusion_obs, future in loader:
            obs = obs.to(device)
            fusion_obs = fusion_obs.to(device)
            future = future.to(device)
            pred, _goals, _probs, _ = model(obs, neighbors, fusion_obs)
            # Score every candidate against the ground truth and keep the
            # one with the lowest mean error for each batch row.
            err = torch.norm(pred - future.unsqueeze(1), dim=3).mean(dim=2)
            best_idx = torch.argmin(err, dim=1)
            best_pred = pred[torch.arange(pred.size(0), device=device), best_idx]
            ade_sum += compute_ade(best_pred, future).item()
            fde_sum += compute_fde(best_pred, future).item()
            batches += 1
    return ade_sum / max(1, batches), fde_sum / max(1, batches)


def train_phase2(args):
    """Fine-tune the fusion trajectory model starting from a base checkpoint.

    Args:
        args: Namespace-like object providing ``base_checkpoint``,
            ``output_checkpoint``, ``max_samples``, ``batch_size``,
            ``base_lr``, ``fusion_lr``, ``epochs`` and ``patience``.
            Relative checkpoint paths are resolved against ``REPO_ROOT``.

    Raises:
        ValueError: If the training split is empty (fewer than two samples
            are available after applying ``max_samples``).

    Side effects:
        * Creates the output checkpoint's parent directory and a ``log``
          directory relative to the current working directory.
        * Appends every progress message to a timestamped file in ``log``
          and echoes it to stdout.
        * Seeds the global ``random`` module with 42 before shuffling.
        * Saves ``model.state_dict()`` to the output checkpoint whenever the
          validation ADE improves.

    The samples are shuffled and split 80/20 into train/validation sets.
    Parameters whose names start with ``fusion_embed`` or ``fusion_ln`` are
    optimised with ``fusion_lr``; all others use ``base_lr``. Training stops
    early once the validation ADE has not improved for ``args.patience``
    consecutive epochs. The logged "Train Loss" is the mean loss per batch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    base_checkpoint = _resolve_repo_path(args.base_checkpoint)
    output_checkpoint = _resolve_repo_path(args.output_checkpoint)
    output_checkpoint.parent.mkdir(parents=True, exist_ok=True)

    os.makedirs("log", exist_ok=True)
    log_filename = os.path.join(
        "log",
        f"phase2_fusion_train_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    )

    def log_print(msg):
        # Mirror every message to stdout and to the run's log file.
        print(msg)
        with open(log_filename, "a", encoding="utf-8") as f:
            f.write(msg + "\n")

    log_print(f"Starting Phase 2 fusion transfer-learning on {device}...")

    samples = get_fusion_samples()
    if args.max_samples > 0:
        samples = samples[: args.max_samples]

    # Fixed seed so the train/validation split is reproducible across runs.
    random.seed(42)
    random.shuffle(samples)
    train_size = int(0.8 * len(samples))
    train_samples = samples[:train_size]
    val_samples = samples[train_size:]

    # int(0.8 * n) is 0 for n < 2; fail with a clear message instead of
    # building a loader over an empty dataset.
    if not train_samples:
        raise ValueError(
            f"Training split is empty ({len(samples)} sample(s) available); "
            "at least 2 samples are required."
        )

    train_dataset = FusionTrajectoryDataset(train_samples, augment=True)
    val_dataset = FusionTrajectoryDataset(val_samples, augment=False)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn_fusion,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn_fusion,
    )

    model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
    if base_checkpoint.exists():
        missing, unexpected = model.load_from_base_checkpoint(str(base_checkpoint), map_location=device)
        log_print(f"Loaded base checkpoint: {base_checkpoint}")
        log_print(f"Missing keys count: {len(missing)}")
        log_print(f"Unexpected keys count: {len(unexpected)}")
    else:
        # Training continues from the model's fresh initialisation.
        log_print(f"Base checkpoint not found: {base_checkpoint}")

    # Two parameter groups so the fusion layers can use their own learning rate.
    base_params = []
    fusion_params = []
    for name, param in model.named_parameters():
        if name.startswith(("fusion_embed", "fusion_ln")):
            fusion_params.append(param)
        else:
            base_params.append(param)

    optimizer = optim.Adam(
        [
            {"params": base_params, "lr": args.base_lr},
            {"params": fusion_params, "lr": args.fusion_lr},
        ]
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=4,
    )

    best_val_ade = float("inf")
    patience_counter = 0
    for epoch in range(args.epochs):
        train_loss = _run_train_epoch(model, train_loader, optimizer, device)
        val_ade, val_fde = _evaluate_best_of_k(model, val_loader, device)
        scheduler.step(val_ade)

        curr_lr_base = optimizer.param_groups[0]['lr']
        curr_lr_fusion = optimizer.param_groups[1]['lr']
        log_print(f"Epoch {epoch + 1}/{args.epochs}")
        log_print(f"Train Loss: {train_loss:.4f}")
        log_print(f"Val ADE: {val_ade:.4f} | Val FDE: {val_fde:.4f}")
        log_print(f"LR base={curr_lr_base:.6f} | fusion={curr_lr_fusion:.6f}")
        log_print("-" * 44)

        if val_ade < best_val_ade:
            best_val_ade = val_ade
            patience_counter = 0
            torch.save(model.state_dict(), output_checkpoint)
            log_print(f"New best fusion model saved: {output_checkpoint}")
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                log_print(f"Early stopping at epoch {epoch + 1} (patience reached).")
                break

    log_print("Phase 2 fusion transfer-learning complete.")
if __name__ == "__main__":
    # Command-line entry point: collect hyper-parameters and checkpoint
    # locations, then hand them to train_phase2.
    parser = argparse.ArgumentParser(description="Phase 2: LiDAR/Radar Fusion Transfer-Learning")

    # (flag, type, default) for the training-schedule options.
    for flag, kind, default in (
        ("--epochs", int, 20),
        ("--batch-size", int, 64),
        ("--base-lr", float, 2e-4),
        ("--fusion-lr", float, 8e-4),
        ("--patience", int, 8),
    ):
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument(
        "--max-samples",
        type=int,
        default=0,
        help="Use first N samples for quick debug run. 0 = full data.",
    )

    # Checkpoint locations; train_phase2 resolves relative paths itself.
    for flag, default in (
        ("--base-checkpoint", "models/best_social_model.pth"),
        ("--output-checkpoint", "models/best_social_model_fusion.pth"),
    ):
        parser.add_argument(flag, type=str, default=default)

    train_phase2(parser.parse_args())