Spaces:
Running
Running
| import math | |
| import torch | |
| import torch.nn as nn | |
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to batch-first sequences.

    Even feature columns hold ``sin(pos * div_term)`` and odd columns hold
    ``cos(pos * div_term)``. The table is precomputed for ``max_len``
    positions and stored as a (non-trainable) buffer so it follows the
    module across ``.to()`` / ``.double()`` calls.

    Args:
        d_model: Feature dimension of the input. Odd values are supported.
        max_len: Number of positions the precomputed table covers.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # columns; slicing keeps the shapes aligned (no-op when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading singleton axis so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
                ``d_model``, i.e. ``(B, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was
                built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :]
class TrajectoryTransformerFusion(nn.Module):
    """Multi-modal trajectory predictor with optional per-timestep sensor fusion.

    A shared Transformer encoder summarises every observed track (the target
    and each of its neighbours) by its hidden state at the last timestep.
    The target state attends over its neighbours' states (social pooling).
    The concatenation ``[target, social]`` feeds three heads:

    - ``goal_head``: ``K`` 2-D goal points.
    - ``traj_head``: one ``future_len`` x 2 trajectory per goal, conditioned
      on that goal.
    - ``prob_head``: a softmax probability per mode.

    Args:
        fusion_dim: Number of per-timestep fusion features accepted through
            ``forward(..., fusion_feats=...)``.
    """

    def __init__(self, fusion_dim=3):
        super().__init__()
        self.d_model = 64
        # Base kinematic embedding from original model features.
        self.base_embed = nn.Linear(7, self.d_model)
        # Fusion branch: LiDAR/Radar strength features per timestep.
        self.fusion_embed = nn.Linear(fusion_dim, self.d_model)
        self.fusion_ln = nn.LayerNorm(self.d_model)
        self.pos_enc = PositionalEncoding(self.d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=4,
            dim_feedforward=256,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.social_attn = nn.MultiheadAttention(
            embed_dim=self.d_model, num_heads=4, batch_first=True
        )
        self.K = 3  # number of predicted modes
        # Heads consume [target state (d_model), social context (d_model)].
        self.hidden_dim = 2 * self.d_model
        self.future_len = 12
        self.goal_head = nn.Sequential(
            nn.Linear(self.hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.K * 2),
        )
        # Input is the shared context plus one 2-D goal.
        self.traj_head = nn.Sequential(
            nn.Linear(self.hidden_dim + 2, 128),
            nn.ReLU(),
            nn.Linear(128, self.future_len * 2),
        )
        self.prob_head = nn.Linear(self.hidden_dim, self.K)

    def load_from_base_checkpoint(self, ckpt_path, map_location='cpu'):
        """Load weights saved by the base (non-fusion) model.

        Keys starting with ``embed.`` are renamed to ``base_embed.``. Every
        other key passes through unchanged. Loading is non-strict, so the
        fusion branch, which the base checkpoint presumably lacks, keeps its
        fresh initialisation.

        Args:
            ckpt_path: Path to a file written by ``torch.save``. It is assumed
                to hold a flat state dict, since the code iterates ``.items()``.
            map_location: Forwarded to ``torch.load``.

        Returns:
            ``(missing_keys, unexpected_keys)`` as reported by ``load_state_dict``.
        """
        # NOTE(review): torch.load unpickles the file unless weights_only is in
        # effect for the installed torch version -- only load trusted files.
        state = torch.load(ckpt_path, map_location=map_location)
        prefix = 'embed.'
        remap = {
            ('base_embed.' + key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }
        missing, unexpected = self.load_state_dict(remap, strict=False)
        return missing, unexpected

    def _encode_last(self, emb):
        """Add positional encoding, run the shared encoder, return last-step states."""
        return self.transformer_encoder(self.pos_enc(emb))[:, -1, :]

    def _encode_neighbors(self, neighbor_seqs, dtype, device):
        """Encode each neighbour track on its own; return one ``(d_model,)`` state each."""
        return [
            self._encode_last(
                self.base_embed(torch.as_tensor(seq, dtype=dtype, device=device).unsqueeze(0))
            )[0]
            for seq in neighbor_seqs
        ]

    def social_pool(self, h_target, neighbor_h_list, device):
        """Attend from the target state over its neighbours' states.

        Args:
            h_target: ``(d_model,)`` target state.
            neighbor_h_list: Sequence of ``(d_model,)`` neighbour states. It may
                be empty.
            device: Device for the zero vector returned when there are no
                neighbours.

        Returns:
            ``(context, weights)``. ``context`` is a ``(d_model,)`` tensor and
            ``weights`` has shape ``(1, N)``. With no neighbours the result is
            ``(zeros, None)``.
        """
        if len(neighbor_h_list) == 0:
            # Match the target's dtype so the caller's torch.cat never mixes dtypes.
            return torch.zeros(self.d_model, device=device, dtype=h_target.dtype), None
        query = h_target.unsqueeze(0).unsqueeze(0)  # (1, 1, d_model)
        keys = torch.stack(list(neighbor_h_list)).unsqueeze(0)  # (1, N, d_model)
        attn_output, attn_weights = self.social_attn(query, keys, keys)
        return attn_output.squeeze(0).squeeze(0), attn_weights.squeeze(0)

    def forward(self, x, neighbors, fusion_feats=None):
        """Predict ``K`` candidate futures for each target in the batch.

        Args:
            x: ``(B, T, 7)`` target features. The original docs state T=4, and
                T must not exceed the positional-encoding ``max_len``.
            neighbors: List of length B. Each element is a list of neighbour
                tracks of shape ``(T_n, 7)`` (tensor or array-like). They are
                converted to ``x``'s dtype and device.
            fusion_feats: Optional ``(B, T, F)`` fusion features, presumably
                ``[lidar_pts_norm, radar_pts_norm, sensor_strength]`` with F=3.
                They are converted to ``x``'s dtype and device.

        Returns:
            ``(traj, goals, probs, attn)``:

            - ``traj``: ``(B, K, future_len, 2)``.
            - ``goals``: ``(B, K, 2)``.
            - ``probs``: ``(B, K)``, softmax over modes.
            - ``attn``: list of length B holding ``(1, N_i)`` attention
              weights, or ``None`` for a target with no neighbours.
        """
        B = x.size(0)
        device = x.device
        x_emb = self.base_embed(x)
        if fusion_feats is not None:
            fusion_feats = torch.as_tensor(fusion_feats, dtype=x.dtype, device=device)
            x_emb = self.fusion_ln(x_emb + self.fusion_embed(fusion_feats))
        h = self._encode_last(x_emb)  # (B, d_model)

        final_h = []
        batch_attn_weights = []
        for i in range(B):
            h_target = h[i]
            # NOTE(review): neighbours are encoded without fusion features or
            # fusion_ln, so when fusion_feats is given the query and keys come
            # from differently-normalised embeddings -- confirm this is intended.
            neighbor_h_list = self._encode_neighbors(neighbors[i], x.dtype, device)
            h_social, attn_weights = self.social_pool(h_target, neighbor_h_list, device)
            batch_attn_weights.append(attn_weights)
            final_h.append(torch.cat([h_target, h_social], dim=0))
        h_final = torch.stack(final_h)  # (B, hidden_dim)

        goals = self.goal_head(h_final).view(B, self.K, 2)
        # Decode every mode in a single batched call: pair the shared context
        # with each mode's own goal along a new K axis.
        context = torch.cat([h_final.unsqueeze(1).expand(-1, self.K, -1), goals], dim=-1)
        traj = self.traj_head(context).view(B, self.K, self.future_len, 2)
        probs = torch.softmax(self.prob_head(h_final), dim=1)
        return traj, goals, probs, batch_attn_weights