Spaces:
Running
Running
File size: 4,732 Bytes
98075af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | import math
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=100):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1), :]
class TrajectoryTransformerFusion(nn.Module):
def __init__(self, fusion_dim=3):
super().__init__()
self.d_model = 64
# Base kinematic embedding from original model features.
self.base_embed = nn.Linear(7, self.d_model)
# Fusion branch: LiDAR/Radar strength features per timestep.
self.fusion_embed = nn.Linear(fusion_dim, self.d_model)
self.fusion_ln = nn.LayerNorm(self.d_model)
self.pos_enc = PositionalEncoding(self.d_model)
encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=4,
dim_feedforward=256,
batch_first=True,
)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
self.social_attn = nn.MultiheadAttention(embed_dim=self.d_model, num_heads=4, batch_first=True)
self.K = 3
self.hidden_dim = 128
self.future_len = 12
self.goal_head = nn.Sequential(
nn.Linear(self.hidden_dim, 64),
nn.ReLU(),
nn.Linear(64, self.K * 2),
)
self.traj_head = nn.Sequential(
nn.Linear(self.hidden_dim + 2, 128),
nn.ReLU(),
nn.Linear(128, self.future_len * 2),
)
self.prob_head = nn.Linear(self.hidden_dim, self.K)
def load_from_base_checkpoint(self, ckpt_path, map_location='cpu'):
state = torch.load(ckpt_path, map_location=map_location)
remap = {}
for k, v in state.items():
if k.startswith('embed.'):
remap['base_embed.' + k[len('embed.'):]] = v
else:
remap[k] = v
missing, unexpected = self.load_state_dict(remap, strict=False)
return missing, unexpected
def social_pool(self, h_target, neighbor_h_list, device):
if len(neighbor_h_list) == 0:
return torch.zeros(self.d_model, device=device), None
query = h_target.unsqueeze(0).unsqueeze(0)
neighbor_h_tensor = torch.stack(neighbor_h_list).unsqueeze(0)
attn_output, attn_weights = self.social_attn(query, neighbor_h_tensor, neighbor_h_tensor)
return attn_output.squeeze(0).squeeze(0), attn_weights.squeeze(0)
def forward(self, x, neighbors, fusion_feats=None):
"""
x: (B, 4, 7)
neighbors: list length B, each element is list of neighbors with shape (4, 7)
fusion_feats: (B, 4, F) where F=3 [lidar_pts_norm, radar_pts_norm, sensor_strength]
"""
B = x.size(0)
device = x.device
x_emb = self.base_embed(x)
if fusion_feats is not None:
x_emb = self.fusion_ln(x_emb + self.fusion_embed(fusion_feats))
x_emb = self.pos_enc(x_emb)
enc_out = self.transformer_encoder(x_emb)
h = enc_out[:, -1, :]
final_h = []
batch_attn_weights = []
for i in range(B):
h_target = h[i]
neighbor_h_list = []
for n in neighbors[i]:
n_tensor = torch.as_tensor(n, dtype=torch.float32, device=device).unsqueeze(0)
n_emb = self.pos_enc(self.base_embed(n_tensor))
n_enc_out = self.transformer_encoder(n_emb)
neighbor_h_list.append(n_enc_out[0, -1, :])
h_social, attn_weights = self.social_pool(h_target, neighbor_h_list, device)
batch_attn_weights.append(attn_weights)
h_combined = torch.cat([h_target, h_social], dim=0)
final_h.append(h_combined)
h_final = torch.stack(final_h)
goals = self.goal_head(h_final).view(B, self.K, 2)
trajs = []
for k in range(self.K):
goal_k = goals[:, k, :]
conditioned_context = torch.cat([h_final, goal_k], dim=1)
traj_k = self.traj_head(conditioned_context).view(B, 1, self.future_len, 2)
trajs.append(traj_k)
traj = torch.cat(trajs, dim=1)
probs = self.prob_head(h_final)
probs = torch.softmax(probs, dim=1)
return traj, goals, probs, batch_attn_weights
|