# IntentDrive/backend/app/ml/inference.py
# sajith-0701
# Deploy FastAPI backend to HF Spaces (Docker SDK)
# 98075af
from pathlib import Path
import torch
from .model import TrajectoryTransformer
from .model_fusion import TrajectoryTransformerFusion
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# parents[3] walks up from ml/ -> app/ -> backend/ -> repository root
# (assuming this file sits at <repo>/backend/app/ml/inference.py); the
# checkpoints are looked up in <repo>/models/.
REPO_ROOT = Path(__file__).resolve().parents[3]
MODEL_DIR = REPO_ROOT / "models"
FUSION_CKPT = MODEL_DIR / "best_social_model_fusion.pth"
BASE_CKPT = MODEL_DIR / "best_social_model.pth"

# ----------------------------
# LOAD MODEL
# ----------------------------


def _load_base_model(failure_message):
    """Build a ``TrajectoryTransformer`` on ``device`` and try to load BASE_CKPT.

    Loading is best-effort: if reading or applying the checkpoint fails, the
    warning ``failure_message`` (a ``str.format`` template with one ``{}``
    slot for the exception) is printed and the freshly initialised model is
    returned anyway, so the service can still start.
    """
    base_model = TrajectoryTransformer().to(device)
    try:
        base_model.load_state_dict(torch.load(BASE_CKPT, map_location=device))
    except Exception as exc:  # deliberate best-effort: keep serving
        print(failure_message.format(exc))
    else:
        print("Inference: Loaded base checkpoint (best_social_model.pth).")
    return base_model


# True only when the Phase 2 fusion checkpoint loaded successfully; predict()
# reads it to decide whether to pass fusion features to the model.
USING_FUSION_MODEL = False
if FUSION_CKPT.exists():
    model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
    try:
        model.load_state_dict(torch.load(FUSION_CKPT, map_location=device))
    except Exception as e:
        # A fusion checkpoint that cannot be read or does not match the
        # architecture is not fatal: fall back to the base model.
        print(f"Warning: could not load fusion checkpoint ({e}). Falling back to base model.")
        model = _load_base_model("Warning: could not load base checkpoint ({}).")
    else:
        USING_FUSION_MODEL = True
        print("Inference: Loaded Phase 2 fusion checkpoint (best_social_model_fusion.pth).")
else:
    model = _load_base_model("Warning: could not load model weights ({}), starting fresh.")
model.eval()
# ----------------------------
# PREPROCESS INPUT
# ----------------------------
def prepare_input(points):
    """Turn the first four (x, y) points into observation feature rows.

    Every position is re-expressed relative to ``points[3]``. Each of the four
    returned rows holds seven features: relative x, relative y, step dx,
    step dy, step length, and the sine and cosine of the step heading.
    The first row's motion features are all zero, and the heading terms are
    zero for steps shorter than 1e-5.

    Returns:
        ``(obs, origin)``: ``obs`` is a list of four 7-element lists and
        ``origin`` is the ``(x, y)`` of ``points[3]``, used later to shift
        predictions back to the caller's coordinates.

    Raises:
        IndexError: if fewer than four points are supplied.
    """
    import math

    ox, oy = points[3]
    rel = [[x - ox, y - oy] for x, y in points]

    # Motion features between consecutive samples; the first sample has no
    # predecessor, so its motion row is all zeros.
    motion = [[0, 0, 0, 0, 0]]
    for (prev_x, prev_y), (cur_x, cur_y) in zip(rel, rel[1:]):
        step_x = cur_x - prev_x
        step_y = cur_y - prev_y
        step_len = math.hypot(step_x, step_y)
        if step_len > 1e-5:
            sin_h, cos_h = step_y / step_len, step_x / step_len
        else:
            # Heading is undefined for a (near-)stationary step.
            sin_h, cos_h = 0.0, 0.0
        motion.append([step_x, step_y, step_len, sin_h, cos_h])

    obs = [rel[k] + motion[k] for k in range(4)]
    return obs, (ox, oy)
# ----------------------------
# PREDICTION FUNCTION
# ----------------------------
def _neighbor_features(np_points, origin):
    """Build the four 7-feature rows for one neighbour track, relative to ``origin``.

    Uses the same feature layout as ``prepare_input`` (relative x, relative y,
    step dx, step dy, step length, sin/cos of heading), but shifts positions
    by the *primary* agent's origin so all agents share one frame.
    Raises IndexError if the track has fewer than four points.
    """
    import math

    ox, oy = origin
    window = [[x - ox, y - oy] for x, y in np_points]
    vel = [[0, 0, 0, 0, 0]]
    for (px, py), (cx, cy) in zip(window, window[1:]):
        dx = cx - px
        dy = cy - py
        speed = math.hypot(dx, dy)
        if speed > 1e-5:
            sin_t, cos_t = dy / speed, dx / speed
        else:
            # Heading is undefined for a (near-)stationary step.
            sin_t, cos_t = 0.0, 0.0
        vel.append([dx, dy, speed, sin_t, cos_t])
    return [window[j] + vel[j] for j in range(4)]


def predict(points, neighbor_points_list=None, fusion_feats=None):
    """Run the loaded trajectory model on one agent and its neighbours.

    Args:
        points: sequence of (x, y) positions; the first four form the
            observation window (see ``prepare_input``).
        neighbor_points_list: optional list of neighbour tracks, each a
            sequence of at least four (x, y) positions.
        fusion_feats: optional extra features, only used when the fusion
            checkpoint was loaded. Presumably shaped (4, 3) to match the zero
            default used below. TODO confirm against the fusion model.

    Returns:
        ``(pred_real, probs, attn_weights)``: the predicted trajectories
        shifted back into the caller's coordinate frame (a rank-3 tensor whose
        last axis holds x, y; presumably (modes, steps, 2)), the model's
        per-mode probabilities on CPU, and the attention weights (moved to CPU
        when present).
    """
    if neighbor_points_list is None:
        neighbor_points_list = []

    obs, origin = prepare_input(points)
    obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)  # (1,4,7)

    # Neighbours must live in the same frame as the primary agent, i.e. be
    # shifted by the origin prepare_input used (points[3]). Using points[-1]
    # here only agreed with that when exactly four points were passed.
    neighbors = [_neighbor_features(np_points, origin) for np_points in neighbor_points_list]
    neighbors_batch = [neighbors]  # batch size = 1

    with torch.no_grad():
        if USING_FUSION_MODEL:
            if fusion_feats is None:
                fusion_tensor = torch.zeros((1, 4, 3), dtype=torch.float32, device=device)
            else:
                # as_tensor avoids an extra copy (and torch's UserWarning)
                # when the caller already passes a tensor.
                fusion_tensor = torch.as_tensor(
                    fusion_feats, dtype=torch.float32, device=device
                ).unsqueeze(0)
            pred, _goals, probs, attn_weights = model(obs, neighbors_batch, fusion_tensor)
        else:
            pred, _goals, probs, attn_weights = model(obs, neighbors_batch)

    pred = pred.squeeze(0).cpu()
    probs = probs.squeeze(0).cpu()
    if attn_weights and attn_weights[0] is not None:
        attn_weights = [w.cpu() for w in attn_weights]

    # Predictions are relative to the origin; translate back to real coordinates.
    x0, y0 = origin
    pred_real = pred.clone()
    pred_real[:, :, 0] += x0
    pred_real[:, :, 1] += y0
    return pred_real, probs, attn_weights
# ----------------------------
# DEMO RUN
# ----------------------------
if __name__ == "__main__":
    # Smoke test: an agent moving in a straight line along +x in steps of 10.
    demo_points = [(0, 0), (10, 0), (20, 0), (30, 0)]

    trajectories, mode_probs, _ = predict(demo_points)

    print("\nInput Points:")
    print(demo_points)

    print("\nPredicted Trajectories (Real Coordinates):")
    # Iterating the tensor walks its first axis: one trajectory per mode.
    for idx, trajectory in enumerate(trajectories):
        print(f"\nTrajectory {idx + 1} (prob={mode_probs[idx].item():.2f}):")
        print(trajectory)