Spaces:
Running
Running
| from pathlib import Path | |
| import torch | |
| from .model import TrajectoryTransformer | |
| from .model_fusion import TrajectoryTransformerFusion | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| REPO_ROOT = Path(__file__).resolve().parents[3] | |
| MODEL_DIR = REPO_ROOT / "models" | |
| FUSION_CKPT = MODEL_DIR / "best_social_model_fusion.pth" | |
| BASE_CKPT = MODEL_DIR / "best_social_model.pth" | |
| # ---------------------------- | |
| # LOAD MODEL | |
| # ---------------------------- | |
| USING_FUSION_MODEL = False | |
| if FUSION_CKPT.exists(): | |
| model = TrajectoryTransformerFusion(fusion_dim=3).to(device) | |
| try: | |
| model.load_state_dict(torch.load(FUSION_CKPT, map_location=device)) | |
| USING_FUSION_MODEL = True | |
| print("Inference: Loaded Phase 2 fusion checkpoint (best_social_model_fusion.pth).") | |
| except Exception as e: | |
| print(f"Warning: could not load fusion checkpoint ({e}). Falling back to base model.") | |
| model = TrajectoryTransformer().to(device) | |
| try: | |
| model.load_state_dict(torch.load(BASE_CKPT, map_location=device)) | |
| print("Inference: Loaded base checkpoint (best_social_model.pth).") | |
| except Exception as e2: | |
| print(f"Warning: could not load base checkpoint ({e2}).") | |
| else: | |
| model = TrajectoryTransformer().to(device) | |
| try: | |
| model.load_state_dict(torch.load(BASE_CKPT, map_location=device)) | |
| print("Inference: Loaded base checkpoint (best_social_model.pth).") | |
| except Exception as e: | |
| print(f"Warning: could not load model weights ({e}), starting fresh.") | |
| model.eval() | |
| # ---------------------------- | |
| # PREPROCESS INPUT | |
| # ---------------------------- | |
| def prepare_input(points): | |
| import math | |
| x3, y3 = points[3] | |
| window = [[x - x3, y - y3] for x, y in points] | |
| vel = [] | |
| for j in range(len(window)): | |
| if j == 0: | |
| vel.append([0, 0, 0, 0, 0]) | |
| else: | |
| dx = window[j][0] - window[j-1][0] | |
| dy = window[j][1] - window[j-1][1] | |
| speed = math.hypot(dx, dy) | |
| if speed > 1e-5: | |
| sin_t = dy / speed | |
| cos_t = dx / speed | |
| else: | |
| sin_t = 0.0 | |
| cos_t = 0.0 | |
| vel.append([dx, dy, speed, sin_t, cos_t]) | |
| obs = [] | |
| for j in range(4): | |
| obs.append([ | |
| window[j][0], | |
| window[j][1], | |
| vel[j][0], | |
| vel[j][1], | |
| vel[j][2], | |
| vel[j][3], | |
| vel[j][4] | |
| ]) | |
| return obs, (x3, y3) | |
| # ---------------------------- | |
| # PREDICTION FUNCTION | |
| # ---------------------------- | |
| def predict(points, neighbor_points_list=None, fusion_feats=None): | |
| if neighbor_points_list is None: | |
| neighbor_points_list = [] | |
| obs, origin = prepare_input(points) | |
| obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device) # (1,4,7) | |
| # Prepare neighbors exactly as the main trajectory | |
| import math | |
| x1, y1 = points[-1] | |
| neighbors = [] | |
| for np_points in neighbor_points_list: | |
| n_window = [[x - x1, y - y1] for x, y in np_points] | |
| vel_n = [] | |
| for j in range(len(n_window)): | |
| if j == 0: | |
| vel_n.append([0, 0, 0, 0, 0]) | |
| else: | |
| dx = n_window[j][0] - n_window[j-1][0] | |
| dy = n_window[j][1] - n_window[j-1][1] | |
| speed = math.hypot(dx, dy) | |
| if speed > 1e-5: | |
| sin_t = dy / speed | |
| cos_t = dx / speed | |
| else: | |
| sin_t = 0.0 | |
| cos_t = 0.0 | |
| vel_n.append([dx, dy, speed, sin_t, cos_t]) | |
| n_obs = [] | |
| for j in range(4): | |
| n_obs.append([ | |
| n_window[j][0], n_window[j][1], | |
| vel_n[j][0], vel_n[j][1], vel_n[j][2], vel_n[j][3], vel_n[j][4] | |
| ]) | |
| neighbors.append(n_obs) | |
| neighbors_batch = [neighbors] # batch size = 1 | |
| with torch.no_grad(): | |
| if USING_FUSION_MODEL: | |
| if fusion_feats is None: | |
| fusion_tensor = torch.zeros((1, 4, 3), dtype=torch.float32, device=device) | |
| else: | |
| fusion_tensor = torch.tensor(fusion_feats, dtype=torch.float32).unsqueeze(0).to(device) | |
| pred, goals, probs, attn_weights = model(obs, neighbors_batch, fusion_tensor) | |
| else: | |
| pred, goals, probs, attn_weights = model(obs, neighbors_batch) | |
| pred = pred.squeeze(0).cpu() | |
| probs = probs.squeeze(0).cpu() | |
| if attn_weights and attn_weights[0] is not None: | |
| attn_weights = [w.cpu() for w in attn_weights] | |
| # convert back to real coordinates | |
| x0, y0 = origin | |
| pred_real = pred.clone() | |
| pred_real[:, :, 0] += x0 | |
| pred_real[:, :, 1] += y0 | |
| return pred_real, probs, attn_weights | |
| # ---------------------------- | |
| # DEMO RUN | |
| # ---------------------------- | |
| if __name__ == "__main__": | |
| points = [ | |
| (0, 0), | |
| (10, 0), | |
| (20, 0), | |
| (30, 0) | |
| ] | |
| pred, probs, _ = predict(points) | |
| print("\nInput Points:") | |
| print(points) | |
| print("\nPredicted Trajectories (Real Coordinates):") | |
| for i in range(pred.shape[0]): | |
| print(f"\nTrajectory {i+1} (prob={probs[i].item():.2f}):") | |
| print(pred[i]) |