# IntentDrive/backend/scripts/training/train_phase2_fusion.py
# sajith-0701
# Deploy FastAPI backend to HF Spaces (Docker SDK)
# 98075af
import argparse
import datetime
import os
import random
from pathlib import Path
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from backend.app.legacy.data_loader import (
load_json,
extract_pedestrian_instances,
build_trajectories_with_sensor,
create_windows_with_sensor,
)
from backend.app.legacy.dataset_fusion import FusionTrajectoryDataset
from backend.app.ml.model_fusion import TrajectoryTransformerFusion
# Repository root: the directory three levels above this file's own directory
# (i.e. <root>/backend/scripts/training/<this file>). Relative
# --base-checkpoint / --output-checkpoint paths are resolved against it.
# NOTE: parents[3] raises IndexError if the file is moved to a shallower path.
REPO_ROOT: Path = Path(__file__).resolve().parents[3]
def collate_fn_fusion(batch):
    """Collate ``(obs, neighbors, fusion_obs, future)`` samples into one batch.

    The ``obs``, ``fusion_obs`` and ``future`` tensors are stacked along a new
    leading batch dimension. The per-sample ``neighbors`` entries are not
    stacked; they are returned as a plain list in batch order (presumably
    because they vary in size between samples - confirm in the dataset).
    """
    obs_seq, neighbor_seq, fusion_seq, future_seq = zip(*batch)
    return (
        torch.stack(obs_seq),
        list(neighbor_seq),
        torch.stack(fusion_seq),
        torch.stack(future_seq),
    )
def compute_ade(pred, gt):
    """Average displacement error between ``pred`` and ``gt``.

    The Euclidean norm is taken over dim 2 (presumably the coordinate axis of a
    (batch, time, xy) tensor - confirm), then averaged over every remaining
    element. Returns a 0-dim tensor.
    """
    displacement = (pred - gt).norm(dim=2)
    return displacement.mean()
def compute_fde(pred, gt):
    """Final displacement error: mean Euclidean gap at the last timestep.

    Selects index -1 along dim 1 of both tensors, takes the norm over the
    remaining coordinate dim, and averages over the batch. Returns a 0-dim tensor.
    """
    final_gap = pred[:, -1] - gt[:, -1]
    return final_gap.norm(dim=1).mean()
def best_of_k_loss(
    pred,
    goals,
    gt,
    probs,
    *,
    goal_weight=0.5,
    prob_weight=0.5,
    diversity_weight=0.1,
    eps=1e-8,
):
    """Winner-takes-all loss for K trajectory hypotheses.

    Args:
        pred: predicted trajectories indexed ``pred[b, k, t, :]``.
        goals: predicted endpoints indexed ``goals[b, k, :]``.
        gt: ground-truth trajectories indexed ``gt[b, t, :]``.
        probs: per-hypothesis probabilities ``probs[b, k]``; assumed
            non-negative and normalised (e.g. softmax output) - TODO confirm.
        goal_weight, prob_weight, diversity_weight: weights of the goal,
            probability and diversity terms (defaults are the original
            fixed 0.5 / 0.5 / 0.1).
        eps: added to ``probs`` before the log to avoid log(0).

    Returns:
        Scalar tensor ``traj + goal_weight * goal + prob_weight * prob
        + diversity_weight * diversity`` where:
          * traj: mean displacement error of the best hypothesis per sample,
          * goal: endpoint error of the best hypothesis's goal,
          * prob: NLL of the best hypothesis's index under ``probs``,
          * diversity: mean of exp(-distance) over all hypothesis pairs
            (0 when K == 1).
    """
    # Mean displacement error of every hypothesis: (B, K).
    error = torch.norm(pred - gt.unsqueeze(1), dim=3).mean(dim=2)
    min_error, best_idx = torch.min(error, dim=1)
    traj_loss = torch.mean(min_error)

    batch_idx = torch.arange(goals.size(0), device=goals.device)
    best_goals = goals[batch_idx, best_idx]
    goal_loss = torch.norm(best_goals - gt[:, -1, :], dim=1).mean()

    # Teach the mode probabilities to predict which hypothesis won.
    prob_loss = torch.nn.functional.nll_loss(torch.log(probs + eps), best_idx)

    num_modes = pred.size(1)
    if num_modes > 1:
        # Gather every (i, j) pair with i < j in one shot instead of a Python
        # double loop; only off-diagonal pairs are formed, so no zero-distance
        # norms enter the graph.
        first, second = torch.triu_indices(
            num_modes, num_modes, offset=1, device=pred.device
        )
        pair_dist = torch.norm(pred[:, first] - pred[:, second], dim=3).mean(dim=2)
        # Every pair covers the same batch, so the mean over (batch, pairs)
        # equals "mean over batch per pair, then average over pairs".
        diversity_loss = torch.exp(-pair_dist).mean()
    else:
        diversity_loss = 0.0

    return (
        traj_loss
        + goal_weight * goal_loss
        + prob_weight * prob_loss
        + diversity_weight * diversity_loss
    )
def get_fusion_samples():
    """Build the windowed pedestrian samples used for fusion training.

    Loads the ``sample_annotation``, ``instance`` and ``category`` tables via
    ``load_json``, keeps the pedestrian instances, assembles per-instance
    trajectories with sensor data, and slices them into windows with
    ``create_windows_with_sensor``. The exact window contents are defined by
    that helper (consumed here by ``FusionTrajectoryDataset``).
    """
    annotation_table = load_json("sample_annotation")
    instance_table = load_json("instance")
    category_table = load_json("category")

    pedestrians = extract_pedestrian_instances(
        annotation_table, instance_table, category_table
    )
    return create_windows_with_sensor(
        build_trajectories_with_sensor(annotation_table, pedestrians)
    )
def _resolve_repo_path(path_str):
    """Return ``path_str`` as a Path, anchoring relative paths at REPO_ROOT."""
    path = Path(path_str)
    return path if path.is_absolute() else REPO_ROOT / path


def _make_log_printer():
    """Create a timestamped log file under ./log; return a print-and-append function."""
    os.makedirs("log", exist_ok=True)
    log_filename = os.path.join(
        "log",
        f"phase2_fusion_train_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    )

    def log_print(msg):
        print(msg)
        with open(log_filename, "a", encoding="utf-8") as f:
            f.write(msg + "\n")

    return log_print


def _split_train_val(samples, train_fraction=0.8, seed=42):
    """Shuffle ``samples`` in place (reseeding ``random``) and split into (train, val)."""
    random.seed(seed)
    random.shuffle(samples)
    train_size = int(train_fraction * len(samples))
    return samples[:train_size], samples[train_size:]


def _build_optimizer(model, base_lr, fusion_lr):
    """Adam with one LR for pretrained (base) params and another for fusion layers."""
    base_params = []
    fusion_params = []
    for name, param in model.named_parameters():
        # Parameters of the fusion-specific layers get their own learning rate.
        if name.startswith("fusion_embed") or name.startswith("fusion_ln"):
            fusion_params.append(param)
        else:
            base_params.append(param)
    return optim.Adam(
        [
            {"params": base_params, "lr": base_lr},
            {"params": fusion_params, "lr": fusion_lr},
        ]
    )


def _train_one_epoch(model, loader, optimizer, device):
    """Run one training epoch and return the sample-weighted mean loss."""
    model.train()
    loss_sum = 0.0
    seen = 0
    for obs, neighbors, fusion_obs, future in loader:
        obs = obs.to(device)
        fusion_obs = fusion_obs.to(device)
        future = future.to(device)
        pred, goals, probs, _ = model(obs, neighbors, fusion_obs)
        loss = best_of_k_loss(pred, goals, future, probs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the mean.
        batch_size = obs.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
    return loss_sum / max(1, seen)


def _evaluate(model, loader, device):
    """Return best-of-K (ADE, FDE), each averaged over all samples in ``loader``."""
    model.eval()
    ade_sum = 0.0
    fde_sum = 0.0
    seen = 0
    with torch.no_grad():
        for obs, neighbors, fusion_obs, future in loader:
            obs = obs.to(device)
            fusion_obs = fusion_obs.to(device)
            future = future.to(device)
            pred, _, _, _ = model(obs, neighbors, fusion_obs)
            # Per sample, keep the hypothesis with the lowest mean displacement.
            err = torch.norm(pred - future.unsqueeze(1), dim=3).mean(dim=2)
            best_idx = torch.argmin(err, dim=1)
            best_pred = pred[torch.arange(pred.size(0), device=device), best_idx]
            # compute_ade/fde return batch means; scale back to sums so the
            # final average is per sample, not a mean of per-batch means.
            batch_size = best_pred.size(0)
            ade_sum += compute_ade(best_pred, future).item() * batch_size
            fde_sum += compute_fde(best_pred, future).item() * batch_size
            seen += batch_size
    count = max(1, seen)
    return ade_sum / count, fde_sum / count


def train_phase2(args):
    """Phase 2: fine-tune TrajectoryTransformerFusion starting from a base checkpoint.

    Args:
        args: namespace with ``epochs``, ``batch_size``, ``base_lr``,
            ``fusion_lr``, ``patience``, ``max_samples``, ``base_checkpoint``
            and ``output_checkpoint``. Relative checkpoint paths are resolved
            against REPO_ROOT.

    Behavior:
        * Loads samples, optionally truncates them to the first
          ``max_samples``, then shuffles with seed 42 and splits them 80/20.
        * Initialises the fusion model from ``base_checkpoint`` if it exists;
          otherwise logs that it was not found and trains from the model's
          fresh initialisation.
        * Trains with best_of_k_loss and reduces LR on a validation-ADE
          plateau. Saves the state_dict to ``output_checkpoint`` whenever
          validation ADE improves, and stops early after ``patience`` epochs
          without improvement.
        * Logs progress to stdout and to ./log/phase2_fusion_train_<ts>.txt.

    Raises:
        ValueError: if the train or validation split is empty (model
            selection on an empty validation set would be meaningless).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    base_checkpoint = _resolve_repo_path(args.base_checkpoint)
    output_checkpoint = _resolve_repo_path(args.output_checkpoint)
    output_checkpoint.parent.mkdir(parents=True, exist_ok=True)

    log_print = _make_log_printer()
    log_print(f"Starting Phase 2 fusion transfer-learning on {device}...")

    samples = get_fusion_samples()
    if args.max_samples > 0:
        samples = samples[: args.max_samples]
    train_samples, val_samples = _split_train_val(samples)
    if not train_samples or not val_samples:
        raise ValueError(
            "Not enough samples for a train/val split "
            f"(train={len(train_samples)}, val={len(val_samples)}); "
            "increase --max-samples or check the dataset."
        )

    train_loader = DataLoader(
        FusionTrajectoryDataset(train_samples, augment=True),
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn_fusion,
    )
    val_loader = DataLoader(
        FusionTrajectoryDataset(val_samples, augment=False),
        batch_size=args.batch_size,
        collate_fn=collate_fn_fusion,
    )

    model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
    if base_checkpoint.exists():
        missing, unexpected = model.load_from_base_checkpoint(str(base_checkpoint), map_location=device)
        log_print(f"Loaded base checkpoint: {base_checkpoint}")
        log_print(f"Missing keys count: {len(missing)}")
        log_print(f"Unexpected keys count: {len(unexpected)}")
    else:
        log_print(f"Base checkpoint not found: {base_checkpoint}")

    optimizer = _build_optimizer(model, args.base_lr, args.fusion_lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=4,
    )

    best_val_ade = float("inf")
    patience_counter = 0
    for epoch in range(args.epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        val_ade, val_fde = _evaluate(model, val_loader, device)
        scheduler.step(val_ade)

        curr_lr_base = optimizer.param_groups[0]['lr']
        curr_lr_fusion = optimizer.param_groups[1]['lr']
        log_print(f"Epoch {epoch + 1}/{args.epochs}")
        log_print(f"Train Loss: {train_loss:.4f}")
        log_print(f"Val ADE: {val_ade:.4f} | Val FDE: {val_fde:.4f}")
        log_print(f"LR base={curr_lr_base:.6f} | fusion={curr_lr_fusion:.6f}")
        log_print("-" * 44)

        if val_ade < best_val_ade:
            best_val_ade = val_ade
            patience_counter = 0
            torch.save(model.state_dict(), output_checkpoint)
            log_print(f"New best fusion model saved: {output_checkpoint}")
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                log_print(f"Early stopping at epoch {epoch + 1} (patience reached).")
                break

    log_print("Phase 2 fusion transfer-learning complete.")
if __name__ == "__main__":
    # CLI entry point: parse hyper-parameters and checkpoint paths, then train.
    cli = argparse.ArgumentParser(description="Phase 2: LiDAR/Radar Fusion Transfer-Learning")
    # Numeric hyper-parameters, registered in the same order as before so that
    # --help output is unchanged.
    for flag, value_type, default_value in (
        ("--epochs", int, 20),
        ("--batch-size", int, 64),
        ("--base-lr", float, 2e-4),
        ("--fusion-lr", float, 8e-4),
        ("--patience", int, 8),
    ):
        cli.add_argument(flag, type=value_type, default=default_value)
    cli.add_argument(
        "--max-samples",
        type=int,
        default=0,
        help="Use first N samples for quick debug run. 0 = full data.",
    )
    cli.add_argument("--base-checkpoint", type=str, default="models/best_social_model.pth")
    cli.add_argument("--output-checkpoint", type=str, default="models/best_social_model_fusion.pth")
    train_phase2(cli.parse_args())