import torch import torch.nn as nn import math class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=100): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # (1, max_len, d_model) self.register_buffer('pe', pe) def forward(self, x): return x + self.pe[:, :x.size(1), :] class TrajectoryTransformer(nn.Module): def __init__(self): super().__init__() self.d_model = 64 # 1. Feature Embedding & Positional Encoding self.embed = nn.Linear(7, self.d_model) self.pos_enc = PositionalEncoding(self.d_model) # 2. Transformer Sequence Encoder (Replaces LSTM) encoder_layer = nn.TransformerEncoderLayer( d_model=self.d_model, nhead=4, dim_feedforward=256, batch_first=True ) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2) # 3. Social Attention (Target queries Neighbors) self.social_attn = nn.MultiheadAttention(embed_dim=self.d_model, num_heads=4, batch_first=True) self.K = 3 # number of future modes # 4. GOAL-CONDITIONED ARCHITECTURE # Base hidden context: Target (64) + Social (64) = 128 self.hidden_dim = 128 self.future_len = 12 # Now predicting 6 seconds into future # Step A: Predict exactly K distinct endpoints (goals) self.goal_head = nn.Sequential( nn.Linear(self.hidden_dim, 64), nn.ReLU(), nn.Linear(64, self.K * 2) # X, Y for K goals ) # Step B: Given the encoded context PLUS a specific Goal, draw the path to get there self.traj_head = nn.Sequential( nn.Linear(self.hidden_dim + 2, 128), nn.ReLU(), nn.Linear(128, self.future_len * 2) # 12 steps to reach the destination ) # 5. Probabilities of each mode self.prob_head = nn.Linear(self.hidden_dim, self.K) # ---------------------------- # SOCIAL POOLING # ---------------------------- def social_pool(self, h_target, neighbor_h_list, device): if len(neighbor_h_list) == 0: return torch.zeros(self.d_model, device=device), None # h_target: (64) -> query: (1, 1, 64) query = h_target.unsqueeze(0).unsqueeze(0) # neighbor_h_list: N x 64 -> key, value: (1, N, 64) neighbor_h_tensor = torch.stack(neighbor_h_list).unsqueeze(0) # apply attention attn_output, attn_weights = self.social_attn(query, neighbor_h_tensor, neighbor_h_tensor) return attn_output.squeeze(0).squeeze(0), attn_weights.squeeze(0) # ---------------------------- # FORWARD PASS # ---------------------------- def forward(self, x, neighbors): """ x: (B, 4, 7) neighbors: list of length B """ B = x.size(0) device = x.device # Encode main trajectory sequence with Transformer x_emb = self.embed(x) x_emb = self.pos_enc(x_emb) enc_out = self.transformer_encoder(x_emb) h = enc_out[:, -1, :] # Grab context from last timestep (B, 64) final_h = [] batch_attn_weights = [] # Loop through batch to handle variable size neighbors for i in range(B): h_target = h[i] # (64) neighbor_h_list = [] for n in neighbors[i]: n_tensor = torch.tensor(n, dtype=torch.float32, device=device).unsqueeze(0) n_emb = self.pos_enc(self.embed(n_tensor)) n_enc_out = self.transformer_encoder(n_emb) neighbor_h_list.append(n_enc_out[0, -1, :]) # (64) # Social attention pooling h_social, attn_weights = self.social_pool(h_target, neighbor_h_list, device) batch_attn_weights.append(attn_weights) # Combine Target and Social context h_combined = torch.cat([h_target, h_social], dim=0) # (128) final_h.append(h_combined) h_final = torch.stack(final_h) # (B, 128) # GOAL-CONDITIONED LOGIC # 1. Predict Goals (End-points at t=6) goals = self.goal_head(h_final) goals = goals.view(B, self.K, 2) # (B, K, 2) # 2. Condition trajectories on the predicted goals trajs = [] for k in range(self.K): goal_k = goals[:, k, :] # Get the k-th destination (B, 2) # Concat the base context array with the goal coordinate! conditioned_context = torch.cat([h_final, goal_k], dim=1) # (B, 130) # Predict the path given the condition traj_k = self.traj_head(conditioned_context).view(B, 1, self.future_len, 2) trajs.append(traj_k) traj = torch.cat(trajs, dim=1) # (B, K, 12, 2) # 3. Mode Probabilities probs = self.prob_head(h_final) probs = torch.softmax(probs, dim=1) return traj, goals, probs, batch_attn_weights