Spaces:
Running
Running
sajith-0701 commited on
Commit ·
98075af
1
Parent(s): bc453f9
Deploy FastAPI backend to HF Spaces (Docker SDK)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +39 -0
- .gitignore +7 -0
- Dockerfile +39 -0
- README.md +385 -6
- backend/README.md +78 -0
- backend/__init__.py +1 -0
- backend/app/__init__.py +1 -0
- backend/app/api/__init__.py +1 -0
- backend/app/api/dependencies.py +12 -0
- backend/app/api/routes/__init__.py +1 -0
- backend/app/api/routes/health.py +19 -0
- backend/app/api/routes/live.py +51 -0
- backend/app/api/routes/predict.py +54 -0
- backend/app/core/__init__.py +1 -0
- backend/app/core/serialization.py +85 -0
- backend/app/core/uploads.py +20 -0
- backend/app/legacy/__init__.py +1 -0
- backend/app/legacy/cv_perception.py +119 -0
- backend/app/legacy/data_loader.py +347 -0
- backend/app/legacy/dataset.py +100 -0
- backend/app/legacy/dataset_fusion.py +37 -0
- backend/app/legacy/map_renderer.py +101 -0
- backend/app/legacy/visualization.py +399 -0
- backend/app/main.py +42 -0
- backend/app/ml/__init__.py +1 -0
- backend/app/ml/inference.py +172 -0
- backend/app/ml/model.py +145 -0
- backend/app/ml/model_fusion.py +138 -0
- backend/app/ml/sensor_fusion.py +396 -0
- backend/app/schemas.py +39 -0
- backend/app/services/__init__.py +1 -0
- backend/app/services/pipeline.py +1255 -0
- backend/scripts/__init__.py +1 -0
- backend/scripts/data/__init__.py +1 -0
- backend/scripts/data/build_dataset_from_images.py +119 -0
- backend/scripts/evaluation/__init__.py +1 -0
- backend/scripts/evaluation/benchmark_perf.py +109 -0
- backend/scripts/evaluation/evaluate.py +127 -0
- backend/scripts/evaluation/evaluate_phase2_fusion.py +137 -0
- backend/scripts/legacy/__init__.py +1 -0
- backend/scripts/legacy/app_streamlit.py +2533 -0
- backend/scripts/tools/__init__.py +1 -0
- backend/scripts/tools/generate_benchmark_metric_pages.py +572 -0
- backend/scripts/tools/generate_metric_pages.py +346 -0
- backend/scripts/tools/run_full_pipeline.py +170 -0
- backend/scripts/tools/smoke_verify_bev.py +56 -0
- backend/scripts/training/__init__.py +1 -0
- backend/scripts/training/finetune_cv_pipeline.py +138 -0
- backend/scripts/training/train.py +203 -0
- backend/scripts/training/train_phase2_fusion.py +244 -0
.dockerignore
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Exclude everything that is not needed inside the Docker image.
|
| 2 |
+
# This keeps the build context small and prevents large data files
|
| 3 |
+
# from being sent to the Docker daemon.
|
| 4 |
+
|
| 5 |
+
# Python virtual environment
|
| 6 |
+
venv/
|
| 7 |
+
.venv/
|
| 8 |
+
|
| 9 |
+
# nuScenes dataset (large binary data – not needed in the Space)
|
| 10 |
+
DataSet/
|
| 11 |
+
|
| 12 |
+
# Frontend (deployed separately on Vercel)
|
| 13 |
+
frontend/
|
| 14 |
+
node_modules/
|
| 15 |
+
|
| 16 |
+
# Training artefacts and legacy scripts
|
| 17 |
+
archive/
|
| 18 |
+
log/
|
| 19 |
+
|
| 20 |
+
# Byte-compiled / cache files
|
| 21 |
+
__pycache__/
|
| 22 |
+
**/__pycache__/
|
| 23 |
+
*.pyc
|
| 24 |
+
*.pyo
|
| 25 |
+
*.pyd
|
| 26 |
+
*.egg-info/
|
| 27 |
+
dist/
|
| 28 |
+
build/
|
| 29 |
+
|
| 30 |
+
# Editor and OS artefacts
|
| 31 |
+
.git/
|
| 32 |
+
.gitignore
|
| 33 |
+
.DS_Store
|
| 34 |
+
*.swp
|
| 35 |
+
*.swo
|
| 36 |
+
|
| 37 |
+
# Unused model checkpoints (only the two needed ones are copied explicitly)
|
| 38 |
+
models/best_cv_synced_model.pth
|
| 39 |
+
models/best_social_model_fusion_smoke.pth
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
.env
|
| 6 |
+
venv/
|
| 7 |
+
.venv/
|
Dockerfile
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces — Docker deployment
|
| 2 |
+
# SDK: docker | app_port: 7860
|
| 3 |
+
#
|
| 4 |
+
# Build context is the repo root.
|
| 5 |
+
# Only backend/ and models/ are copied in; DataSet/ is intentionally excluded.
|
| 6 |
+
|
| 7 |
+
FROM python:3.11-slim
|
| 8 |
+
|
| 9 |
+
# System libraries required by opencv-python-headless
|
| 10 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 11 |
+
libglib2.0-0 \
|
| 12 |
+
libgl1-mesa-glx \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# HF Spaces runs containers as uid 1000 by default
|
| 16 |
+
RUN useradd -m -u 1000 appuser
|
| 17 |
+
|
| 18 |
+
WORKDIR /app
|
| 19 |
+
|
| 20 |
+
# Install CPU-only PyTorch first so pip does not download the large CUDA wheels
|
| 21 |
+
RUN pip install --no-cache-dir \
|
| 22 |
+
torch==2.2.2+cpu \
|
| 23 |
+
torchvision==0.17.2+cpu \
|
| 24 |
+
--index-url https://download.pytorch.org/whl/cpu
|
| 25 |
+
|
| 26 |
+
# Install remaining API dependencies
|
| 27 |
+
COPY requirements-hf.txt requirements-hf.txt
|
| 28 |
+
RUN pip install --no-cache-dir -r requirements-hf.txt
|
| 29 |
+
|
| 30 |
+
# Copy only what the API needs at runtime
|
| 31 |
+
COPY --chown=appuser:appuser backend/ backend/
|
| 32 |
+
COPY --chown=appuser:appuser models/ models/
|
| 33 |
+
|
| 34 |
+
USER appuser
|
| 35 |
+
|
| 36 |
+
EXPOSE 7860
|
| 37 |
+
|
| 38 |
+
CMD ["python", "-m", "uvicorn", "backend.app.main:app", \
|
| 39 |
+
"--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,11 +1,390 @@
|
|
| 1 |
---
|
| 2 |
-
title: IntentDrive
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
colorTo: yellow
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
-
license: mit
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: IntentDrive BEV Trajectory Backend
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: green
|
|
|
|
| 5 |
sdk: docker
|
| 6 |
+
app_port: 7860
|
| 7 |
pinned: false
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# IntentDrive — Road User Trajectory Prediction
|
| 11 |
+
|
| 12 |
+
An end-to-end trajectory forecasting system for vulnerable road users (VRUs). The system connects camera-based perception, lightweight multi-agent tracking, and a transformer-based social forecasting model through a structured FastAPI backend and a React visualization dashboard.
|
| 13 |
+
|
| 14 |
+
> **Competition:** Computer Vision Challenge — AI and Computer Vision Track
|
| 15 |
+
> **Team:** 4% | **Lead:** Sajith J | **Institution:** Sri Shakthi Institute of Engineering & Technology
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## Problem Statement
|
| 20 |
+
|
| 21 |
+
In Level 4 autonomous driving, reacting to the *current* position of pedestrians and cyclists is insufficient. VRUs can behave unpredictably and may be occluded behind vehicles or other objects. This project builds a system that uses **2 seconds of past motion history** to predict the **next 6 seconds of future trajectory**, enabling safer and more proactive decisions.
|
| 22 |
+
|
| 23 |
+
> **"Math over Pixels"** — our deliberate architectural decision. Rather than relying purely on visual signals, we model the underlying kinematics and social interactions of agents, making the system robust to occlusion and poor lighting.
|
| 24 |
+
|
| 25 |
+
Real-world context: A Waymo robotaxi struck a child near Grant Elementary School in Santa Monica on January 23, 2026, causing minor injuries. Systems like IntentDrive are designed to anticipate such scenarios before they occur.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## Project Overview
|
| 30 |
+
|
| 31 |
+
This project addresses the problem of safety-critical motion forecasting for pedestrians, cyclists, and motorcyclists in autonomous driving scenarios. Given a short observed history of agent positions, the system predicts **K=3 multimodal 6-second future trajectories** (12 future steps) along with per-mode probability scores.
|
| 32 |
+
|
| 33 |
+
The full pipeline includes:
|
| 34 |
+
|
| 35 |
+
- Object detection and optional keypoint extraction from camera frames
|
| 36 |
+
- Image-to-BEV coordinate conversion using camera intrinsics and scene geometry
|
| 37 |
+
- Temporal tracking to build per-agent motion histories
|
| 38 |
+
- Social context construction from neighboring agent tracks within a **50-meter radius**
|
| 39 |
+
- Transformer-based trajectory forecasting with goal-conditioned multimodal decoding
|
| 40 |
+
- LiDAR and radar fusion for improved short-term kinematic estimation
|
| 41 |
+
- FastAPI backend serving inference, live frame access, and health endpoints
|
| 42 |
+
- React + TypeScript dashboard for BEV scene visualization, trajectory rendering, and sensor overlay
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## System Architecture
|
| 47 |
+
|
| 48 |
+
The pipeline operates across five stages:
|
| 49 |
+
|
| 50 |
+
**Stage 1 — Data Ingestion & Preprocessing**
|
| 51 |
+
Multi-sensor input (6x cameras, LiDAR_TOP, 5x radar channels) is ingested from nuScenes. Timestamps are synchronized via sample-token matching. All sensor readings are projected into a unified ego-centric BEV coordinate frame using sensor-to-ego calibration matrices and quaternion-to-yaw conversion.
|
| 52 |
+
|
| 53 |
+
**Stage 2 — Feature Extraction**
|
| 54 |
+
Three parallel branches process sensor data simultaneously:
|
| 55 |
+
- **Camera branch:** Faster R-CNN (ResNet50-FPN) for multi-class object detection + Keypoint R-CNN for 17-point human pose estimation
|
| 56 |
+
- **LiDAR branch:** Occupancy and depth geometry extraction
|
| 57 |
+
- **Radar branch:** Velocity vectors and Doppler motion cues
|
| 58 |
+
|
| 59 |
+
**Stage 3 — Fusion & Tracking**
|
| 60 |
+
Cross-sensor fusion combines semantic detections, spatial geometry, and motion dynamics into unified agent representations. Multi-object tracking maintains consistent IDs across frames using nearest-neighbor IoU matching with pixel gating. Motion encoding builds a 4-step history of (x, y, velocity_x, velocity_y, speed, heading_sin, heading_cos) per agent.
|
| 61 |
+
|
| 62 |
+
**Stage 4 — Model Inference**
|
| 63 |
+
A goal-conditioned Trajectory Transformer with social attention predicts 3 trajectory modes, each 12 steps (6 seconds) into the future. Post-processing assigns direction labels (Straight / Left / Right / Backward) and top-3 probabilities per VRU.
|
| 64 |
+
|
| 65 |
+
**Stage 5 — Deployment & Visualization**
|
| 66 |
+
Outputs include camera overlay with bounding boxes and skeleton paths, a holographic skeleton panel for explainability, and a fused BEV map with direction probabilities.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## Model Architecture
|
| 71 |
+
|
| 72 |
+
### Base Model: TrajectoryTransformer
|
| 73 |
+
|
| 74 |
+
The base model (`backend/app/ml/model.py`) is a goal-conditioned multimodal trajectory forecaster operating on 4-step observed windows with 7 features per timestep: x, y, velocity_x, velocity_y, speed, heading_sin, heading_cos.
|
| 75 |
+
|
| 76 |
+
**Components:**
|
| 77 |
+
|
| 78 |
+
| Component | Description |
|
| 79 |
+
|---|---|
|
| 80 |
+
| Feature Embedding | Linear projection from 7 input features to d_model=64 |
|
| 81 |
+
| Positional Encoding | Sinusoidal positional encoding over the observed sequence |
|
| 82 |
+
| Temporal Encoder | 2-layer TransformerEncoder, 4 attention heads, feedforward dim 256 |
|
| 83 |
+
| Social Attention | Multi-head attention pooling over encoded neighbor agent representations, 4 heads |
|
| 84 |
+
| Goal Head | MLP predicting K=3 distinct 2D endpoint goals from the combined context |
|
| 85 |
+
| Trajectory Head | MLP conditioned on context + each predicted goal; outputs a 12-step path per mode |
|
| 86 |
+
| Probability Head | Linear layer with softmax producing per-mode confidence scores |
|
| 87 |
+
|
| 88 |
+
**Forward pass summary:**
|
| 89 |
+
|
| 90 |
+
1. Each agent's 4-step observed sequence is embedded and positionally encoded.
|
| 91 |
+
2. The TransformerEncoder produces a context vector from the final timestep.
|
| 92 |
+
3. Each neighboring agent within the social radius is independently encoded and pooled into a social context vector via cross-attention.
|
| 93 |
+
4. Target and social context vectors are concatenated to form a 128-dimensional hidden state.
|
| 94 |
+
5. K=3 goal endpoints are predicted from the hidden state.
|
| 95 |
+
6. Each goal is concatenated back to the hidden state to condition the trajectory decoder, producing 3 independent 12-step trajectory modes.
|
| 96 |
+
7. Mode probabilities are produced via a linear + softmax head.
|
| 97 |
+
|
| 98 |
+
**Loss function:**
|
| 99 |
+
|
| 100 |
+
The training objective combines four terms:
|
| 101 |
+
|
| 102 |
+
- Best-of-K trajectory loss (minimum L2 error over K modes)
|
| 103 |
+
- Goal loss (L2 distance from the best-mode predicted endpoint to ground truth endpoint)
|
| 104 |
+
- Probability cross-entropy loss (supervising the mode probability head)
|
| 105 |
+
- Diversity regularization loss (penalizes mode collapse via exponential repulsion between modes)
|
| 106 |
+
|
| 107 |
+
### Fusion Model: TrajectoryTransformerFusion
|
| 108 |
+
|
| 109 |
+
The fusion variant (`backend/app/ml/model_fusion.py`) extends the base model with a sensor-aware input branch. In addition to the standard 7-feature kinematic input, per-timestep fusion features of dimension 3 are accepted: normalized LiDAR point count, normalized radar point count, and composite sensor strength. These fusion features are projected to d_model=64 via a separate linear layer, added to the base kinematic embedding, and normalized with LayerNorm before entering the shared TransformerEncoder. The fusion model supports loading weights from a base model checkpoint for initialization.
|
| 110 |
+
|
| 111 |
+
---
|
| 112 |
+
|
| 113 |
+
## Dataset
|
| 114 |
+
|
| 115 |
+
**Source:** nuScenes mini split (V1.0-mini), annotations loaded via nuScenes JSON tables. The model was trained and evaluated exclusively using the provided dataset, without incorporating any external data sources.
|
| 116 |
+
|
| 117 |
+
**Target classes:** pedestrian, bicycle, motorcycle
|
| 118 |
+
|
| 119 |
+
**Sensors used:** 6x cameras, LIDAR_TOP, 5x radar channels
|
| 120 |
+
|
| 121 |
+
**Windowing:**
|
| 122 |
+
- Takes a **2-second history** of motion as input (4 observed steps at 2 Hz)
|
| 123 |
+
- Outputs **K=3 multimodal trajectory predictions over a 6-second prediction horizon** (12 future steps at 2 Hz), each with an associated probability score
|
| 124 |
+
|
| 125 |
+
**Input features per observed step:**
|
| 126 |
+
- x, y position (BEV meters)
|
| 127 |
+
- velocity_x, velocity_y (m/s)
|
| 128 |
+
- speed (m/s)
|
| 129 |
+
- heading_sin, heading_cos (unit circle encoding)
|
| 130 |
+
|
| 131 |
+
**Social context radius:** 50 meters
|
| 132 |
+
|
| 133 |
+
**Data augmentation (training split only):** random rotation, horizontal reflection, Gaussian coordinate noise injection
|
| 134 |
+
|
| 135 |
+
**Split protocol:** deterministic 80/20 train/validation split (seed 42)
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## Performance
|
| 140 |
+
|
| 141 |
+
### Baseline: Constant-Velocity Model
|
| 142 |
+
|
| 143 |
+
| Metric | Value |
|
| 144 |
+
|---|---|
|
| 145 |
+
| minADE (K=3) | 0.65 m |
|
| 146 |
+
| minFDE (K=3) | 1.35 m |
|
| 147 |
+
| Miss Rate (>2.0 m) | 19.9 % |
|
| 148 |
+
|
| 149 |
+
### Base Model — Camera-Only Transformer (best_social_model.pth)
|
| 150 |
+
|
| 151 |
+
| Metric | Value | Improvement vs Baseline |
|
| 152 |
+
|---|---|---|
|
| 153 |
+
| Validation trajectories | 468 | — |
|
| 154 |
+
| minADE (K=3) | 0.50 m | 23.1% |
|
| 155 |
+
| minFDE (K=3) | 0.96 m | 29.6% |
|
| 156 |
+
| Miss Rate (>2.0 m) | 9.9 % | 50.8% |
|
| 157 |
+
|
| 158 |
+
### Fusion Model — LiDAR + Radar (best_social_model_fusion.pth)
|
| 159 |
+
|
| 160 |
+
| Metric | Value | Improvement vs Baseline |
|
| 161 |
+
|---|---|---|
|
| 162 |
+
| Validation trajectories | 468 | — |
|
| 163 |
+
| minADE (K=3) | **0.42 m** | **35.4%** |
|
| 164 |
+
| minFDE (K=3) | **0.78 m** | **42.2%** |
|
| 165 |
+
| Miss Rate (>2.0 m) | **7.1 %** | **64.3%** |
|
| 166 |
+
|
| 167 |
+
### Runtime Benchmark
|
| 168 |
+
|
| 169 |
+
| Stage | Latency |
|
| 170 |
+
|---|---|
|
| 171 |
+
| Detection model — Faster R-CNN (per frame) | 30.7 ms |
|
| 172 |
+
| Sensor fusion — LiDAR + Radar lookup | 12 ms |
|
| 173 |
+
| Transformer prediction head (per agent) | 14.6 ms |
|
| 174 |
+
| Full end-to-end pipeline (2-frame loop) | ~58 ms |
|
| 175 |
+
| Equivalent throughput | ~17.24 FPS |
|
| 176 |
+
|
| 177 |
+
### Model Efficiency
|
| 178 |
+
|
| 179 |
+
| Model | Parameters | Size |
|
| 180 |
+
|---|---|---|
|
| 181 |
+
| Base Transformer | ~146K | ~0.6 MB |
|
| 182 |
+
| Fusion Transformer | ~146K | ~0.6 MB |
|
| 183 |
+
|
| 184 |
+
The prediction module is compact and edge-friendly. The real-time bottleneck comes from the heavy CNN perception stack (Faster R-CNN), not the trajectory prediction head.
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## Repository Structure
|
| 189 |
+
|
| 190 |
+
```
|
| 191 |
+
bev/
|
| 192 |
+
├── backend/
|
| 193 |
+
│ ├── app/
|
| 194 |
+
│ │ ├── api/
|
| 195 |
+
│ │ │ └── routes/ # FastAPI route modules: health, live, predict
|
| 196 |
+
│ │ ├── core/ # Serialization and shared utilities
|
| 197 |
+
│ │ ├── ml/
|
| 198 |
+
│ │ │ ├── model.py # TrajectoryTransformer (base, camera-only)
|
| 199 |
+
│ │ │ ├── model_fusion.py # TrajectoryTransformerFusion (LiDAR + Radar)
|
| 200 |
+
│ │ │ ├── inference.py # Inference pipeline
|
| 201 |
+
│ │ │ └── sensor_fusion.py # LiDAR/radar feature extraction
|
| 202 |
+
│ │ ├── services/ # Business logic layer
|
| 203 |
+
│ │ └── main.py # FastAPI application factory
|
| 204 |
+
│ └── scripts/
|
| 205 |
+
│ ├── data/ # Dataset construction from nuScenes images
|
| 206 |
+
│ ├── training/
|
| 207 |
+
│ │ ├── train.py # Stage 1: Base model training
|
| 208 |
+
│ │ ├── train_phase2_fusion.py # Stage 2: Fusion model training
|
| 209 |
+
│ │ └── finetune_cv_pipeline.py # CV-synced fine-tuning
|
| 210 |
+
│ ├── evaluation/
|
| 211 |
+
│ │ ├── evaluate.py # Base model evaluation
|
| 212 |
+
│ │ ├── evaluate_phase2_fusion.py # Fusion model evaluation
|
| 213 |
+
│ │ └── benchmark_perf.py # Runtime latency benchmarking
|
| 214 |
+
│ └── tools/
|
| 215 |
+
├── frontend/
|
| 216 |
+
│ ├── src/
|
| 217 |
+
│ │ ├── App.tsx # Main dashboard component
|
| 218 |
+
│ │ ├── types.ts # TypeScript type definitions
|
| 219 |
+
│ │ ├── api/ # API client layer
|
| 220 |
+
│ │ ├── components/ # UI components
|
| 221 |
+
│ │ └── styles.css # Global styles
|
| 222 |
+
│ ├── package.json
|
| 223 |
+
│ └── vite.config.ts
|
| 224 |
+
├── models/
|
| 225 |
+
│ ├── best_social_model.pth # Trained base model checkpoint
|
| 226 |
+
│ ├── best_social_model_fusion.pth # Trained fusion model checkpoint
|
| 227 |
+
│ ├── best_cv_synced_model.pth # CV-pipeline fine-tuned checkpoint
|
| 228 |
+
│ └── best_social_model_fusion_smoke.pth
|
| 229 |
+
├── extracted_training_data.json # Preprocessed nuScenes trajectory data
|
| 230 |
+
└── log/ # Training logs
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
---
|
| 234 |
+
|
| 235 |
+
## Setup and Installation
|
| 236 |
+
|
| 237 |
+
### Prerequisites
|
| 238 |
+
|
| 239 |
+
- Python 3.10 or later
|
| 240 |
+
- Node.js 18 or later and npm
|
| 241 |
+
- nuScenes mini dataset (V1.0-mini) if retraining from scratch; pretrained checkpoints are included in `models/`
|
| 242 |
+
- GPU recommended (tested on NVIDIA RTX 5050 — 8 GB VRAM)
|
| 243 |
+
|
| 244 |
+
### Backend
|
| 245 |
+
|
| 246 |
+
```bash
|
| 247 |
+
# Create and activate a virtual environment
|
| 248 |
+
python -m venv venv
|
| 249 |
+
venv\Scripts\activate # Windows
|
| 250 |
+
# source venv/bin/activate # Linux / macOS
|
| 251 |
+
|
| 252 |
+
# Install dependencies
|
| 253 |
+
pip install -r requirements.txt
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
### Frontend
|
| 257 |
+
|
| 258 |
+
```bash
|
| 259 |
+
cd frontend
|
| 260 |
+
npm install
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
---
|
| 264 |
+
|
| 265 |
+
## How to Run
|
| 266 |
+
|
| 267 |
+
### 1. Start the Backend API Server
|
| 268 |
+
|
| 269 |
+
From the repository root with the virtual environment active:
|
| 270 |
+
|
| 271 |
+
```bash
|
| 272 |
+
uvicorn backend.app.main:app --host 0.0.0.0 --port 8000 --reload
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
The API will be available at `http://localhost:8000`.
|
| 276 |
+
Interactive API documentation is available at `http://localhost:8000/docs`.
|
| 277 |
+
|
| 278 |
+
### 2. Start the Frontend Dashboard
|
| 279 |
+
|
| 280 |
+
```bash
|
| 281 |
+
cd frontend
|
| 282 |
+
npm run dev
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
The dashboard will be available at `http://localhost:5173`.
|
| 286 |
+
|
| 287 |
+
### 3. Train the Base Model (Stage 1)
|
| 288 |
+
|
| 289 |
+
Ensure `extracted_training_data.json` is present at the repository root (or rebuild it using `backend/scripts/data/build_dataset_from_images.py`).
|
| 290 |
+
|
| 291 |
+
```bash
|
| 292 |
+
python -m backend.scripts.training.train
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
Checkpoints are saved to `models/best_social_model.pth`. Training logs are written to `log/`.
|
| 296 |
+
|
| 297 |
+
### 4. Train the Fusion Model (Stage 2)
|
| 298 |
+
|
| 299 |
+
```bash
|
| 300 |
+
python -m backend.scripts.training.train_phase2_fusion
|
| 301 |
+
```
|
| 302 |
+
|
| 303 |
+
The fusion model initializes from the base checkpoint and trains with LiDAR and radar features using differential learning rates. The output checkpoint is saved to `models/best_social_model_fusion.pth`.
|
| 304 |
+
|
| 305 |
+
### 5. Evaluate Models
|
| 306 |
+
|
| 307 |
+
```bash
|
| 308 |
+
# Base model
|
| 309 |
+
python -m backend.scripts.evaluation.evaluate
|
| 310 |
+
|
| 311 |
+
# Fusion model
|
| 312 |
+
python -m backend.scripts.evaluation.evaluate_phase2_fusion
|
| 313 |
+
|
| 314 |
+
# Runtime latency benchmark
|
| 315 |
+
python -m backend.scripts.evaluation.benchmark_perf
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
---
|
| 319 |
+
|
| 320 |
+
## API Endpoints
|
| 321 |
+
|
| 322 |
+
| Method | Path | Description |
|
| 323 |
+
|---|---|---|
|
| 324 |
+
| GET | `/api/health` | Service health check |
|
| 325 |
+
| GET | `/api/live/frame` | Retrieve the latest processed camera frame |
|
| 326 |
+
| POST | `/api/predict` | Run trajectory prediction on a submitted scene |
|
| 327 |
+
|
| 328 |
+
The prediction endpoint returns a structured payload including multimodal trajectories, per-mode probabilities, agent detections, sensor summary, and scene geometry.
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## Training Strategy
|
| 333 |
+
|
| 334 |
+
Training follows a two-stage transfer learning approach:
|
| 335 |
+
|
| 336 |
+
**Stage 1 — Social Trajectory Transformer**
|
| 337 |
+
Train the base model end-to-end using only camera-derived BEV trajectories. The model learns social interaction patterns, goal-conditioned decoding, and multimodal prediction from kinematic features alone.
|
| 338 |
+
|
| 339 |
+
**Stage 2 — Fusion Transfer Learning**
|
| 340 |
+
Initialize the fusion model from the Stage 1 checkpoint. Add the LiDAR and radar input branch and fine-tune using differential learning rates — lower rates for the pre-trained transformer backbone and higher rates for the new fusion branch. This preserves learned social behavior while adapting to richer sensor signals.
|
| 341 |
+
|
| 342 |
+
**Optimization:**
|
| 343 |
+
- Optimizer: Adam
|
| 344 |
+
- LR scheduling: ReduceLROnPlateau
|
| 345 |
+
- Early stopping with best checkpoint selection based on minADE
|
| 346 |
+
|
| 347 |
+
---
|
| 348 |
+
|
| 349 |
+
## Robustness Analysis
|
| 350 |
+
|
| 351 |
+
**Noise & Motion Stability:** Data augmentation (rotation, flip, Gaussian noise) improves generalization. Radar fusion stabilizes motion estimation. Multi-modal outputs reduce prediction failure in edge cases.
|
| 352 |
+
|
| 353 |
+
**Lighting Conditions:** Camera performance degrades in low-light conditions. LiDAR and Radar remain reliable regardless of lighting. Multi-sensor fusion reduces dependency on visual quality alone.
|
| 354 |
+
|
| 355 |
+
**Occlusion Handling:** Motion history + social context encoding allows the model to predict agent positions even when temporarily invisible. Radar supports cross-traffic awareness for agents occluded by large vehicles. Long-term occlusion remains an open challenge for future work.
|
| 356 |
+
|
| 357 |
+
---
|
| 358 |
+
|
| 359 |
+
## Sample Training Output
|
| 360 |
+
|
| 361 |
+
```
|
| 362 |
+
Train Loss: 2.1834
|
| 363 |
+
ADE: 0.5491, FDE: 1.0873
|
| 364 |
+
Current Learning Rate: 0.0005
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
---
|
| 368 |
+
|
| 369 |
+
## Output Visualizations
|
| 370 |
+
|
| 371 |
+

|
| 372 |
+

|
| 373 |
+

|
| 374 |
+

|
| 375 |
+
|
| 376 |
+
---
|
| 377 |
+
|
| 378 |
+
## References
|
| 379 |
+
|
| 380 |
+
- Attention Is All You Need — https://arxiv.org/abs/1706.03762
|
| 381 |
+
- Trajectron++ — https://arxiv.org/abs/2001.03093
|
| 382 |
+
- nuScenes Dataset Paper — https://arxiv.org/abs/1903.11027
|
| 383 |
+
- BEVFormer — https://arxiv.org/abs/2203.17270
|
| 384 |
+
- BEVFusion — https://arxiv.org/abs/2205.13542
|
| 385 |
+
|
| 386 |
+
---
|
| 387 |
+
|
| 388 |
+
## License
|
| 389 |
+
|
| 390 |
+
This project is licensed under the terms of the MIT License. See [LICENSE](LICENSE) for details.
|
backend/README.md
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend (Phase 1)
|
| 2 |
+
|
| 3 |
+
This backend exposes your existing CV + trajectory prediction pipeline through FastAPI.
|
| 4 |
+
|
| 5 |
+
## Folder Structure
|
| 6 |
+
|
| 7 |
+
```text
|
| 8 |
+
backend/
|
| 9 |
+
app/
|
| 10 |
+
api/
|
| 11 |
+
dependencies.py
|
| 12 |
+
routes/
|
| 13 |
+
health.py
|
| 14 |
+
live.py
|
| 15 |
+
predict.py
|
| 16 |
+
core/
|
| 17 |
+
serialization.py
|
| 18 |
+
uploads.py
|
| 19 |
+
ml/
|
| 20 |
+
inference.py
|
| 21 |
+
model.py
|
| 22 |
+
model_fusion.py
|
| 23 |
+
sensor_fusion.py
|
| 24 |
+
legacy/
|
| 25 |
+
dataset.py
|
| 26 |
+
dataset_fusion.py
|
| 27 |
+
data_loader.py
|
| 28 |
+
cv_perception.py
|
| 29 |
+
map_renderer.py
|
| 30 |
+
visualization.py
|
| 31 |
+
services/
|
| 32 |
+
pipeline.py
|
| 33 |
+
main.py
|
| 34 |
+
schemas.py
|
| 35 |
+
scripts/
|
| 36 |
+
training/
|
| 37 |
+
evaluation/
|
| 38 |
+
data/
|
| 39 |
+
tools/
|
| 40 |
+
legacy/
|
| 41 |
+
requirements.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Notes:
|
| 45 |
+
- Runtime model and fusion logic is now under `backend/app/ml`.
|
| 46 |
+
- Legacy helper modules were moved under `backend/app/legacy`.
|
| 47 |
+
- Training, evaluation, and data scripts were moved under `backend/scripts/*`.
|
| 48 |
+
- Root-level `inference.py`, `model.py`, `model_fusion.py`, and `sensor_fusion.py` remain as compatibility wrappers.
|
| 49 |
+
|
| 50 |
+
## Run
|
| 51 |
+
|
| 52 |
+
From the repository root:
|
| 53 |
+
|
| 54 |
+
```powershell
|
| 55 |
+
.\\venv\\Scripts\\python.exe -m pip install -r backend/requirements.txt
|
| 56 |
+
.\\venv\\Scripts\\python.exe -m uvicorn backend.app.main:app --reload --host 0.0.0.0 --port 8000
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Endpoints
|
| 60 |
+
|
| 61 |
+
- `GET /api/health`
|
| 62 |
+
- `GET /api/live/frames?channel=CAM_FRONT&limit=200`
|
| 63 |
+
- `GET /api/live/frame-image?path=<dataset_frame_path>`
|
| 64 |
+
- `POST /api/predict/two-image` (multipart form)
|
| 65 |
+
- `POST /api/predict/live-fusion` (JSON body)
|
| 66 |
+
|
| 67 |
+
## Phase 2 Scene Geometry
|
| 68 |
+
|
| 69 |
+
Prediction responses now include `scene_geometry` with image-grounded BEV primitives:
|
| 70 |
+
|
| 71 |
+
- `road_polygon`: camera-derived drivable area in BEV coordinates.
|
| 72 |
+
- `lane_lines`: lane candidates projected into BEV.
|
| 73 |
+
- `elements`: projected actor footprints from detections.
|
| 74 |
+
- `quality`: confidence score in `[0, 1]` for extracted scene structure.
|
| 75 |
+
|
| 76 |
+
Notes:
|
| 77 |
+
- The backend keeps using existing model files at repository root.
|
| 78 |
+
- Keep running from repo root so relative data paths remain stable.
|
backend/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Backend package root."""
|
backend/app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI backend package for Phase 1 migration."""
|
backend/app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API package for route modules and dependencies."""
|
backend/app/api/dependencies.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from ..services.pipeline import TrajectoryPipeline
|
| 6 |
+
|
| 7 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 8 |
+
pipeline = TrajectoryPipeline(repo_root=REPO_ROOT)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_pipeline() -> TrajectoryPipeline:
|
| 12 |
+
return pipeline
|
backend/app/api/routes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Route modules for the FastAPI application."""
|
backend/app/api/routes/health.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
from ..dependencies import pipeline
|
| 8 |
+
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@router.get("/health")
|
| 13 |
+
def health() -> dict[str, Any]:
|
| 14 |
+
return {
|
| 15 |
+
"status": "ok",
|
| 16 |
+
"using_fusion_model": pipeline.using_fusion_model,
|
| 17 |
+
"dataset_root": str(pipeline.data_root),
|
| 18 |
+
"dataset_exists": pipeline.data_root.exists(),
|
| 19 |
+
}
|
backend/app/api/routes/live.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import mimetypes
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 8 |
+
from fastapi.responses import FileResponse
|
| 9 |
+
|
| 10 |
+
from ..dependencies import pipeline
|
| 11 |
+
|
| 12 |
+
router = APIRouter()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resolve_dataset_frame_path(path_value: str) -> Path:
|
| 16 |
+
candidate = Path(path_value).expanduser()
|
| 17 |
+
if not candidate.is_absolute():
|
| 18 |
+
candidate = (pipeline.repo_root / candidate).resolve()
|
| 19 |
+
else:
|
| 20 |
+
candidate = candidate.resolve()
|
| 21 |
+
|
| 22 |
+
data_root = pipeline.data_root.resolve()
|
| 23 |
+
try:
|
| 24 |
+
candidate.relative_to(data_root)
|
| 25 |
+
except ValueError as exc:
|
| 26 |
+
raise HTTPException(status_code=403, detail="Frame path is outside DataSet root.") from exc
|
| 27 |
+
|
| 28 |
+
if not candidate.exists() or not candidate.is_file():
|
| 29 |
+
raise HTTPException(status_code=404, detail="Frame image was not found.")
|
| 30 |
+
|
| 31 |
+
if candidate.suffix.lower() not in {".jpg", ".jpeg", ".png", ".webp"}:
|
| 32 |
+
raise HTTPException(status_code=400, detail="Unsupported frame image file type.")
|
| 33 |
+
|
| 34 |
+
return candidate
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@router.get("/live/frames")
|
| 38 |
+
def list_live_frames(
|
| 39 |
+
channel: str = Query(default="CAM_FRONT"),
|
| 40 |
+
limit: int = Query(default=200, ge=1, le=2000),
|
| 41 |
+
) -> dict[str, Any]:
|
| 42 |
+
paths = pipeline.list_channel_image_paths(channel)
|
| 43 |
+
names = [p.name for p in paths[:limit]]
|
| 44 |
+
return {"channel": channel, "count": len(names), "frames": names}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@router.get("/live/frame-image")
|
| 48 |
+
def get_live_frame_image(path: str = Query(..., min_length=1)):
|
| 49 |
+
frame_path = resolve_dataset_frame_path(path)
|
| 50 |
+
media_type = mimetypes.guess_type(str(frame_path))[0] or "application/octet-stream"
|
| 51 |
+
return FileResponse(path=frame_path, media_type=media_type)
|
backend/app/api/routes/predict.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
| 4 |
+
|
| 5 |
+
from ...core.serialization import build_prediction_payload
|
| 6 |
+
from ...core.uploads import upload_to_rgb_array
|
| 7 |
+
from ...schemas import LiveFusionRequest, PredictionResponse
|
| 8 |
+
from ..dependencies import pipeline
|
| 9 |
+
|
| 10 |
+
router = APIRouter()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@router.post("/predict/two-image", response_model=PredictionResponse)
|
| 14 |
+
async def predict_two_image(
|
| 15 |
+
image_prev: UploadFile = File(...),
|
| 16 |
+
image_curr: UploadFile = File(...),
|
| 17 |
+
score_threshold: float = Form(0.35),
|
| 18 |
+
tracking_gate_px: float = Form(130.0),
|
| 19 |
+
min_motion_px: float = Form(0.0),
|
| 20 |
+
use_pose: bool = Form(False),
|
| 21 |
+
):
|
| 22 |
+
img_prev = await upload_to_rgb_array(image_prev)
|
| 23 |
+
img_curr = await upload_to_rgb_array(image_curr)
|
| 24 |
+
|
| 25 |
+
result = pipeline.build_two_image_agents_bundle(
|
| 26 |
+
img_prev=img_prev,
|
| 27 |
+
img_curr=img_curr,
|
| 28 |
+
score_threshold=float(score_threshold),
|
| 29 |
+
tracking_gate_px=float(tracking_gate_px),
|
| 30 |
+
min_motion_px=float(min_motion_px),
|
| 31 |
+
use_pose=bool(use_pose),
|
| 32 |
+
img_prev_name=image_prev.filename,
|
| 33 |
+
img_curr_name=image_curr.filename,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
if "error" in result:
|
| 37 |
+
raise HTTPException(status_code=400, detail=result["error"])
|
| 38 |
+
|
| 39 |
+
return build_prediction_payload(result)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@router.post("/predict/live-fusion", response_model=PredictionResponse)
|
| 43 |
+
def predict_live_fusion(req: LiveFusionRequest):
|
| 44 |
+
result = pipeline.build_live_agents_bundle(
|
| 45 |
+
anchor_idx=req.anchor_idx,
|
| 46 |
+
score_threshold=req.score_threshold,
|
| 47 |
+
tracking_gate_px=req.tracking_gate_px,
|
| 48 |
+
use_pose=req.use_pose,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
if "error" in result:
|
| 52 |
+
raise HTTPException(status_code=400, detail=result["error"])
|
| 53 |
+
|
| 54 |
+
return build_prediction_payload(result)
|
backend/app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Core helpers for serialization and file handling."""
|
backend/app/core/serialization.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def to_jsonable(value: Any) -> Any:
|
| 9 |
+
if isinstance(value, np.ndarray):
|
| 10 |
+
return value.tolist()
|
| 11 |
+
if isinstance(value, (np.floating, np.integer)):
|
| 12 |
+
return value.item()
|
| 13 |
+
if isinstance(value, dict):
|
| 14 |
+
return {str(k): to_jsonable(v) for k, v in value.items()}
|
| 15 |
+
if isinstance(value, tuple):
|
| 16 |
+
return [to_jsonable(v) for v in value]
|
| 17 |
+
if isinstance(value, list):
|
| 18 |
+
return [to_jsonable(v) for v in value]
|
| 19 |
+
return value
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def serialize_agents(agents: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 23 |
+
serialized = []
|
| 24 |
+
for agent in agents:
|
| 25 |
+
serialized.append(
|
| 26 |
+
{
|
| 27 |
+
"id": int(agent.get("id", 0)),
|
| 28 |
+
"type": str(agent.get("type", "unknown")),
|
| 29 |
+
"raw_label": agent.get("raw_label"),
|
| 30 |
+
"history": [
|
| 31 |
+
{"x": float(pt[0]), "y": float(pt[1])}
|
| 32 |
+
for pt in agent.get("history", [])
|
| 33 |
+
],
|
| 34 |
+
"predictions": [
|
| 35 |
+
[{"x": float(pt[0]), "y": float(pt[1])} for pt in mode]
|
| 36 |
+
for mode in agent.get("predictions", [])
|
| 37 |
+
],
|
| 38 |
+
"probabilities": [float(p) for p in agent.get("probabilities", [])],
|
| 39 |
+
"is_target": bool(agent.get("is_target", False)),
|
| 40 |
+
}
|
| 41 |
+
)
|
| 42 |
+
return serialized
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build_prediction_payload(result: dict[str, Any]) -> dict[str, Any]:
|
| 46 |
+
core_excludes = {
|
| 47 |
+
"agents",
|
| 48 |
+
"target_track_id",
|
| 49 |
+
"mode",
|
| 50 |
+
"camera_snapshots",
|
| 51 |
+
"fusion_data",
|
| 52 |
+
"scene_geometry",
|
| 53 |
+
"error",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
payload: dict[str, Any] = {
|
| 57 |
+
"mode": result.get("mode", "unknown"),
|
| 58 |
+
"target_track_id": result.get("target_track_id"),
|
| 59 |
+
"agents": serialize_agents(result.get("agents", [])),
|
| 60 |
+
"meta": to_jsonable({k: v for k, v in result.items() if k not in core_excludes}),
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
snapshots = result.get("camera_snapshots")
|
| 64 |
+
if snapshots:
|
| 65 |
+
payload["detections"] = {
|
| 66 |
+
name: {
|
| 67 |
+
"frame_path": snap.get("frame_path"),
|
| 68 |
+
"detections": to_jsonable(snap.get("detections", [])),
|
| 69 |
+
}
|
| 70 |
+
for name, snap in snapshots.items()
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
fusion_data = result.get("fusion_data")
|
| 74 |
+
if fusion_data:
|
| 75 |
+
payload["sensors"] = {
|
| 76 |
+
"sample_token": fusion_data.get("sample_token"),
|
| 77 |
+
"lidar_points": int(len(fusion_data.get("lidar_xy", []))),
|
| 78 |
+
"radar_points": int(len(fusion_data.get("radar_xy", []))),
|
| 79 |
+
"radar_channel_counts": to_jsonable(fusion_data.get("radar_channel_counts", {})),
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
if result.get("scene_geometry") is not None:
|
| 83 |
+
payload["scene_geometry"] = to_jsonable(result.get("scene_geometry"))
|
| 84 |
+
|
| 85 |
+
return payload
|
backend/app/core/uploads.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from fastapi import HTTPException, UploadFile
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
async def upload_to_rgb_array(upload: UploadFile) -> np.ndarray:
|
| 11 |
+
raw = await upload.read()
|
| 12 |
+
if not raw:
|
| 13 |
+
raise HTTPException(status_code=400, detail=f"Uploaded file '{upload.filename}' is empty.")
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
image = Image.open(io.BytesIO(raw)).convert("RGB")
|
| 17 |
+
except Exception as exc:
|
| 18 |
+
raise HTTPException(status_code=400, detail=f"Could not parse image '{upload.filename}': {exc}") from exc
|
| 19 |
+
|
| 20 |
+
return np.asarray(image)
|
backend/app/legacy/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Legacy helper modules preserved for compatibility."""
|
backend/app/legacy/cv_perception.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
|
| 4 |
+
from PIL import Image, ImageDraw
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
# Map COCO classes to our Hackathon targets
|
| 9 |
+
TARGET_CLASSES = {
|
| 10 |
+
1: 'Person',
|
| 11 |
+
2: 'Bicycle',
|
| 12 |
+
3: 'Car',
|
| 13 |
+
4: 'Motorcycle'
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
def load_perception_model():
|
| 17 |
+
print("[System] Loading Faster R-CNN (ResNet-50-FPN Backbone)...")
|
| 18 |
+
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 19 |
+
model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
|
| 20 |
+
model.eval()
|
| 21 |
+
return model, weights
|
| 22 |
+
|
| 23 |
+
def extract_features(img_path, model, weights, score_threshold=0.7):
|
| 24 |
+
image = Image.open(img_path).convert("RGB")
|
| 25 |
+
preprocess = weights.transforms()
|
| 26 |
+
input_batch = preprocess(image).unsqueeze(0)
|
| 27 |
+
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
prediction = model(input_batch)[0]
|
| 30 |
+
|
| 31 |
+
extracted = []
|
| 32 |
+
for i, box in enumerate(prediction['boxes']):
|
| 33 |
+
score = prediction['scores'][i].item()
|
| 34 |
+
label = prediction['labels'][i].item()
|
| 35 |
+
|
| 36 |
+
if score > score_threshold and label in TARGET_CLASSES:
|
| 37 |
+
box = box.tolist()
|
| 38 |
+
class_name = TARGET_CLASSES[label]
|
| 39 |
+
# Get bottom-center coordinate for BEV mapping
|
| 40 |
+
center_x = (box[0] + box[2]) / 2.0
|
| 41 |
+
bottom_y = box[3]
|
| 42 |
+
|
| 43 |
+
extracted.append({
|
| 44 |
+
'type': class_name,
|
| 45 |
+
'bbox': box,
|
| 46 |
+
'coord': (center_x, bottom_y)
|
| 47 |
+
})
|
| 48 |
+
return extracted, image
|
| 49 |
+
|
| 50 |
+
def calculate_distance(c1, c2):
|
| 51 |
+
return math.sqrt((c1[0] - c2[0])**2 + (c1[1] - c2[1])**2)
|
| 52 |
+
|
| 53 |
+
def process_frame_sequence(frame1_path, frame2_path, model, weights):
|
| 54 |
+
"""
|
| 55 |
+
Takes 2 sequential frames, detects objects, matches them to find movement,
|
| 56 |
+
and bridges the data to the AI Brain.
|
| 57 |
+
"""
|
| 58 |
+
print(f"\n[Step 1] Analyzing Frame T-1: {os.path.basename(frame1_path)}")
|
| 59 |
+
objs_f1, img1 = extract_features(frame1_path, model, weights)
|
| 60 |
+
|
| 61 |
+
print(f"[Step 2] Analyzing Frame T0: {os.path.basename(frame2_path)}")
|
| 62 |
+
objs_f2, img2 = extract_features(frame2_path, model, weights)
|
| 63 |
+
|
| 64 |
+
print("\n[Step 3] Temporal Tracking (Finding Moving Cyclists/Pedestrians)")
|
| 65 |
+
tracked_history = []
|
| 66 |
+
|
| 67 |
+
# Simple Tracking by linking nearest objects between Frame 1 and Frame 2
|
| 68 |
+
for obj2 in objs_f2:
|
| 69 |
+
best_match = None
|
| 70 |
+
min_dist = float('inf')
|
| 71 |
+
|
| 72 |
+
for obj1 in objs_f1:
|
| 73 |
+
if obj1['type'] == obj2['type']: # Must be same class
|
| 74 |
+
dist = calculate_distance(obj1['coord'], obj2['coord'])
|
| 75 |
+
if dist < 50.0: # Max pixel movement threshold between 2 frames
|
| 76 |
+
min_dist = dist
|
| 77 |
+
best_match = obj1
|
| 78 |
+
|
| 79 |
+
if best_match:
|
| 80 |
+
# Calculate movement vector (Velocity)
|
| 81 |
+
dx = obj2['coord'][0] - best_match['coord'][0]
|
| 82 |
+
dy = obj2['coord'][1] - best_match['coord'][1]
|
| 83 |
+
is_moving = abs(dx) > 1.0 or abs(dy) > 1.0
|
| 84 |
+
|
| 85 |
+
if is_moving and obj2['type'] in ['Person', 'Bicycle']:
|
| 86 |
+
print(f" -> Spotted Moving {obj2['type']}! dx: {dx:.2f}, dy: {dy:.2f}")
|
| 87 |
+
|
| 88 |
+
# Format: [(x_t-1, y_t-1), (x_t0, y_t0)]
|
| 89 |
+
# This is EXACTLY what the AI Brain needs!
|
| 90 |
+
history = [best_match['coord'], obj2['coord']]
|
| 91 |
+
|
| 92 |
+
tracked_history.append({
|
| 93 |
+
"type": obj2['type'],
|
| 94 |
+
"history": history
|
| 95 |
+
})
|
| 96 |
+
|
| 97 |
+
print(f"\n[Step 4] Handoff to AI Brain: Found {len(tracked_history)} moving VRUs.")
|
| 98 |
+
return tracked_history
|
| 99 |
+
|
| 100 |
+
if __name__ == '__main__':
|
| 101 |
+
# We will use two identical images to simulate the script architecture
|
| 102 |
+
# In reality, this would be image_001.jpg and image_002.jpg
|
| 103 |
+
import glob
|
| 104 |
+
cam_front_images = glob.glob("DataSet/samples/CAM_FRONT/*.jpg")
|
| 105 |
+
|
| 106 |
+
if len(cam_front_images) >= 2:
|
| 107 |
+
f1 = cam_front_images[0]
|
| 108 |
+
f2 = cam_front_images[1] # Next sequential frame
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
model, weights = load_perception_model()
|
| 112 |
+
vru_data_for_ai = process_frame_sequence(f1, f2, model, weights)
|
| 113 |
+
|
| 114 |
+
print("\n--- FINAL JSON PAYLOAD FOR TRANSFORMER MODEL ---")
|
| 115 |
+
for person in vru_data_for_ai:
|
| 116 |
+
print(f"Target: {person['type']}")
|
| 117 |
+
print(f"Historical Trajectory [T-1, T0]: {person['history']}")
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print("Model not loaded, but script structure is ready.")
|
backend/app/legacy/data_loader.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
DATA_ROOT = Path("DataSet/v1.0-mini")
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_json(name):
|
| 8 |
+
with open(DATA_ROOT / f"{name}.json") as f:
|
| 9 |
+
return json.load(f)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def build_lookup(table):
|
| 13 |
+
return {item['token']: item for item in table}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def extract_pedestrian_instances(sample_annotations, instances, categories):
|
| 17 |
+
cat_lookup = build_lookup(categories)
|
| 18 |
+
inst_lookup = build_lookup(instances)
|
| 19 |
+
|
| 20 |
+
pedestrian_instances = set()
|
| 21 |
+
|
| 22 |
+
for ann in sample_annotations:
|
| 23 |
+
inst = inst_lookup[ann['instance_token']]
|
| 24 |
+
category = cat_lookup[inst['category_token']]['name']
|
| 25 |
+
|
| 26 |
+
# Include pedestrians, bicycles, and motorcycles (Vulnerable Road Users)
|
| 27 |
+
if "pedestrian" in category or "bicycle" in category or "motorcycle" in category:
|
| 28 |
+
pedestrian_instances.add(ann['instance_token'])
|
| 29 |
+
|
| 30 |
+
return pedestrian_instances
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def build_trajectories(sample_annotations, pedestrian_instances):
|
| 34 |
+
ann_lookup = build_lookup(sample_annotations)
|
| 35 |
+
|
| 36 |
+
visited = set()
|
| 37 |
+
trajectories = []
|
| 38 |
+
|
| 39 |
+
for ann in sample_annotations:
|
| 40 |
+
if ann['token'] in visited:
|
| 41 |
+
continue
|
| 42 |
+
|
| 43 |
+
if ann['instance_token'] not in pedestrian_instances:
|
| 44 |
+
continue
|
| 45 |
+
|
| 46 |
+
current = ann
|
| 47 |
+
while current['prev'] != "":
|
| 48 |
+
current = ann_lookup[current['prev']]
|
| 49 |
+
|
| 50 |
+
traj = []
|
| 51 |
+
|
| 52 |
+
while current:
|
| 53 |
+
visited.add(current['token'])
|
| 54 |
+
|
| 55 |
+
x, y, _ = current['translation']
|
| 56 |
+
traj.append([x, y])
|
| 57 |
+
|
| 58 |
+
if current['next'] == "":
|
| 59 |
+
break
|
| 60 |
+
|
| 61 |
+
current = ann_lookup[current['next']]
|
| 62 |
+
|
| 63 |
+
if len(traj) >= 16: # 4 past + 12 future (6 seconds)
|
| 64 |
+
trajectories.append(traj)
|
| 65 |
+
|
| 66 |
+
return trajectories
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def create_windows(trajectories):
|
| 70 |
+
import math
|
| 71 |
+
samples = []
|
| 72 |
+
|
| 73 |
+
for traj in trajectories:
|
| 74 |
+
# Require 16 frames: 4 history + 12 future
|
| 75 |
+
for i in range(len(traj) - 15):
|
| 76 |
+
|
| 77 |
+
# ---------------- MAIN TRAJECTORY ----------------
|
| 78 |
+
window = traj[i:i+16]
|
| 79 |
+
|
| 80 |
+
x3, y3 = window[3]
|
| 81 |
+
window = [[x - x3, y - y3] for x, y in window]
|
| 82 |
+
|
| 83 |
+
vel = []
|
| 84 |
+
for j in range(len(window)):
|
| 85 |
+
if j == 0:
|
| 86 |
+
vel.append([0, 0, 0, 0, 0])
|
| 87 |
+
else:
|
| 88 |
+
dx = window[j][0] - window[j-1][0]
|
| 89 |
+
dy = window[j][1] - window[j-1][1]
|
| 90 |
+
speed = math.hypot(dx, dy)
|
| 91 |
+
if speed > 1e-5:
|
| 92 |
+
sin_t = dy / speed
|
| 93 |
+
cos_t = dx / speed
|
| 94 |
+
else:
|
| 95 |
+
sin_t = 0.0
|
| 96 |
+
cos_t = 0.0
|
| 97 |
+
vel.append([dx, dy, speed, sin_t, cos_t])
|
| 98 |
+
|
| 99 |
+
obs = []
|
| 100 |
+
for j in range(4):
|
| 101 |
+
obs.append([
|
| 102 |
+
window[j][0],
|
| 103 |
+
window[j][1],
|
| 104 |
+
vel[j][0],
|
| 105 |
+
vel[j][1],
|
| 106 |
+
vel[j][2],
|
| 107 |
+
vel[j][3],
|
| 108 |
+
vel[j][4]
|
| 109 |
+
])
|
| 110 |
+
|
| 111 |
+
# Future is now 12 steps (6 seconds)
|
| 112 |
+
future = window[4:16]
|
| 113 |
+
|
| 114 |
+
# ---------------- NEIGHBORS ----------------
|
| 115 |
+
neighbors = []
|
| 116 |
+
|
| 117 |
+
for other_traj in trajectories:
|
| 118 |
+
if other_traj is traj:
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
if len(other_traj) < i + 4:
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
x1, y1 = traj[i + 3] # Main trajectory center
|
| 125 |
+
x2, y2 = other_traj[i + 3]
|
| 126 |
+
|
| 127 |
+
dist = math.hypot(x1 - x2, y1 - y2)
|
| 128 |
+
|
| 129 |
+
# Expanded Social Radius to 50 meters to account for much longer timeframe
|
| 130 |
+
if dist < 50.0:
|
| 131 |
+
|
| 132 |
+
n_window = other_traj[i:i+4]
|
| 133 |
+
|
| 134 |
+
# Center around main trajectory's last observed timestep
|
| 135 |
+
n_window = [[x - x1, y - y1] for x, y in n_window]
|
| 136 |
+
|
| 137 |
+
vel_n = []
|
| 138 |
+
for j in range(len(n_window)):
|
| 139 |
+
if j == 0:
|
| 140 |
+
vel_n.append([0, 0, 0, 0, 0])
|
| 141 |
+
else:
|
| 142 |
+
dx = n_window[j][0] - n_window[j-1][0]
|
| 143 |
+
dy = n_window[j][1] - n_window[j-1][1]
|
| 144 |
+
speed = math.hypot(dx, dy)
|
| 145 |
+
if speed > 1e-5:
|
| 146 |
+
sin_t = dy / speed
|
| 147 |
+
cos_t = dx / speed
|
| 148 |
+
else:
|
| 149 |
+
sin_t = 0.0
|
| 150 |
+
cos_t = 0.0
|
| 151 |
+
vel_n.append([dx, dy, speed, sin_t, cos_t])
|
| 152 |
+
|
| 153 |
+
n_obs = []
|
| 154 |
+
for j in range(4):
|
| 155 |
+
n_obs.append([
|
| 156 |
+
n_window[j][0],
|
| 157 |
+
n_window[j][1],
|
| 158 |
+
vel_n[j][0],
|
| 159 |
+
vel_n[j][1],
|
| 160 |
+
vel_n[j][2],
|
| 161 |
+
vel_n[j][3],
|
| 162 |
+
vel_n[j][4]
|
| 163 |
+
])
|
| 164 |
+
|
| 165 |
+
neighbors.append(n_obs)
|
| 166 |
+
|
| 167 |
+
samples.append((obs, neighbors, future))
|
| 168 |
+
|
| 169 |
+
return samples
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def build_trajectories_with_sensor(sample_annotations, pedestrian_instances):
|
| 173 |
+
ann_lookup = build_lookup(sample_annotations)
|
| 174 |
+
|
| 175 |
+
visited = set()
|
| 176 |
+
trajectories = []
|
| 177 |
+
|
| 178 |
+
for ann in sample_annotations:
|
| 179 |
+
if ann['token'] in visited:
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
if ann['instance_token'] not in pedestrian_instances:
|
| 183 |
+
continue
|
| 184 |
+
|
| 185 |
+
current = ann
|
| 186 |
+
while current['prev'] != "":
|
| 187 |
+
current = ann_lookup[current['prev']]
|
| 188 |
+
|
| 189 |
+
traj = []
|
| 190 |
+
|
| 191 |
+
while current:
|
| 192 |
+
visited.add(current['token'])
|
| 193 |
+
|
| 194 |
+
x, y, _ = current['translation']
|
| 195 |
+
traj.append({
|
| 196 |
+
'x': x,
|
| 197 |
+
'y': y,
|
| 198 |
+
'sample_token': current['sample_token'],
|
| 199 |
+
'num_lidar_pts': float(current.get('num_lidar_pts', 0.0)),
|
| 200 |
+
'num_radar_pts': float(current.get('num_radar_pts', 0.0)),
|
| 201 |
+
})
|
| 202 |
+
|
| 203 |
+
if current['next'] == "":
|
| 204 |
+
break
|
| 205 |
+
|
| 206 |
+
current = ann_lookup[current['next']]
|
| 207 |
+
|
| 208 |
+
if len(traj) >= 16:
|
| 209 |
+
trajectories.append(traj)
|
| 210 |
+
|
| 211 |
+
return trajectories
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def create_windows_with_sensor(trajectories):
|
| 215 |
+
import math
|
| 216 |
+
samples = []
|
| 217 |
+
|
| 218 |
+
for traj in trajectories:
|
| 219 |
+
for i in range(len(traj) - 15):
|
| 220 |
+
window = traj[i:i + 16]
|
| 221 |
+
|
| 222 |
+
x3, y3 = window[3]['x'], window[3]['y']
|
| 223 |
+
centered_xy = [[p['x'] - x3, p['y'] - y3] for p in window]
|
| 224 |
+
|
| 225 |
+
vel = []
|
| 226 |
+
for j in range(len(centered_xy)):
|
| 227 |
+
if j == 0:
|
| 228 |
+
vel.append([0, 0, 0, 0, 0])
|
| 229 |
+
else:
|
| 230 |
+
dx = centered_xy[j][0] - centered_xy[j - 1][0]
|
| 231 |
+
dy = centered_xy[j][1] - centered_xy[j - 1][1]
|
| 232 |
+
speed = math.hypot(dx, dy)
|
| 233 |
+
if speed > 1e-5:
|
| 234 |
+
sin_t = dy / speed
|
| 235 |
+
cos_t = dx / speed
|
| 236 |
+
else:
|
| 237 |
+
sin_t = 0.0
|
| 238 |
+
cos_t = 0.0
|
| 239 |
+
vel.append([dx, dy, speed, sin_t, cos_t])
|
| 240 |
+
|
| 241 |
+
obs = []
|
| 242 |
+
fusion_obs = []
|
| 243 |
+
for j in range(4):
|
| 244 |
+
obs.append([
|
| 245 |
+
centered_xy[j][0],
|
| 246 |
+
centered_xy[j][1],
|
| 247 |
+
vel[j][0],
|
| 248 |
+
vel[j][1],
|
| 249 |
+
vel[j][2],
|
| 250 |
+
vel[j][3],
|
| 251 |
+
vel[j][4],
|
| 252 |
+
])
|
| 253 |
+
|
| 254 |
+
lidar_pts = min(80.0, window[j]['num_lidar_pts']) / 80.0
|
| 255 |
+
radar_pts = min(30.0, window[j]['num_radar_pts']) / 30.0
|
| 256 |
+
sensor_strength = min(1.0, (window[j]['num_lidar_pts'] + 2.0 * window[j]['num_radar_pts']) / 100.0)
|
| 257 |
+
fusion_obs.append([lidar_pts, radar_pts, sensor_strength])
|
| 258 |
+
|
| 259 |
+
future = centered_xy[4:16]
|
| 260 |
+
|
| 261 |
+
neighbors = []
|
| 262 |
+
for other_traj in trajectories:
|
| 263 |
+
if other_traj is traj:
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
if len(other_traj) < i + 4:
|
| 267 |
+
continue
|
| 268 |
+
|
| 269 |
+
x1, y1 = traj[i + 3]['x'], traj[i + 3]['y']
|
| 270 |
+
x2, y2 = other_traj[i + 3]['x'], other_traj[i + 3]['y']
|
| 271 |
+
|
| 272 |
+
dist = math.hypot(x1 - x2, y1 - y2)
|
| 273 |
+
if dist >= 50.0:
|
| 274 |
+
continue
|
| 275 |
+
|
| 276 |
+
n_window = other_traj[i:i + 4]
|
| 277 |
+
n_window_xy = [[p['x'] - x1, p['y'] - y1] for p in n_window]
|
| 278 |
+
|
| 279 |
+
vel_n = []
|
| 280 |
+
for j in range(len(n_window_xy)):
|
| 281 |
+
if j == 0:
|
| 282 |
+
vel_n.append([0, 0, 0, 0, 0])
|
| 283 |
+
else:
|
| 284 |
+
dx = n_window_xy[j][0] - n_window_xy[j - 1][0]
|
| 285 |
+
dy = n_window_xy[j][1] - n_window_xy[j - 1][1]
|
| 286 |
+
speed = math.hypot(dx, dy)
|
| 287 |
+
if speed > 1e-5:
|
| 288 |
+
sin_t = dy / speed
|
| 289 |
+
cos_t = dx / speed
|
| 290 |
+
else:
|
| 291 |
+
sin_t = 0.0
|
| 292 |
+
cos_t = 0.0
|
| 293 |
+
vel_n.append([dx, dy, speed, sin_t, cos_t])
|
| 294 |
+
|
| 295 |
+
n_obs = []
|
| 296 |
+
for j in range(4):
|
| 297 |
+
n_obs.append([
|
| 298 |
+
n_window_xy[j][0],
|
| 299 |
+
n_window_xy[j][1],
|
| 300 |
+
vel_n[j][0],
|
| 301 |
+
vel_n[j][1],
|
| 302 |
+
vel_n[j][2],
|
| 303 |
+
vel_n[j][3],
|
| 304 |
+
vel_n[j][4],
|
| 305 |
+
])
|
| 306 |
+
|
| 307 |
+
neighbors.append(n_obs)
|
| 308 |
+
|
| 309 |
+
samples.append((obs, neighbors, fusion_obs, future))
|
| 310 |
+
|
| 311 |
+
return samples
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def main():
|
| 315 |
+
print("Loading data...")
|
| 316 |
+
|
| 317 |
+
sample_annotations = load_json("sample_annotation")
|
| 318 |
+
instances = load_json("instance")
|
| 319 |
+
categories = load_json("category")
|
| 320 |
+
|
| 321 |
+
print("Filtering pedestrians...")
|
| 322 |
+
ped_instances = extract_pedestrian_instances(
|
| 323 |
+
sample_annotations, instances, categories
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
print("Building trajectories...")
|
| 327 |
+
trajectories = build_trajectories(sample_annotations, ped_instances)
|
| 328 |
+
|
| 329 |
+
print("Creating training samples...")
|
| 330 |
+
samples = create_windows(trajectories)
|
| 331 |
+
|
| 332 |
+
print("\n--- DEBUG ---")
|
| 333 |
+
obs, neighbors, future = samples[0]
|
| 334 |
+
|
| 335 |
+
print("Obs length:", len(obs))
|
| 336 |
+
print("Future length:", len(future))
|
| 337 |
+
print("Neighbors count:", len(neighbors))
|
| 338 |
+
|
| 339 |
+
if len(neighbors) > 0:
|
| 340 |
+
print("One neighbor shape:", len(neighbors[0]), len(neighbors[0][0]))
|
| 341 |
+
|
| 342 |
+
print(f"\nTotal trajectories: {len(trajectories)}")
|
| 343 |
+
print(f"Total samples: {len(samples)}")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if __name__ == "__main__":
|
| 347 |
+
main()
|
backend/app/legacy/dataset.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import random
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
def augment_data(obs, neighbors, future):
|
| 7 |
+
# obs: (4, 7) tensor
|
| 8 |
+
# neighbors: list of (4, 7) tensors
|
| 9 |
+
# future: (12, 2) tensor
|
| 10 |
+
|
| 11 |
+
# Random Scene Rotation (0-360)
|
| 12 |
+
theta = random.uniform(0, 2 * math.pi)
|
| 13 |
+
cos_t = math.cos(theta)
|
| 14 |
+
sin_t = math.sin(theta)
|
| 15 |
+
|
| 16 |
+
# Random X-axis reflection
|
| 17 |
+
flip_x = random.choice([-1.0, 1.0])
|
| 18 |
+
|
| 19 |
+
# Gaussian Coordinate Noise
|
| 20 |
+
noise_std = 0.05
|
| 21 |
+
|
| 22 |
+
def apply_transform(feat, is_obs=True):
|
| 23 |
+
new_feat = feat.clone()
|
| 24 |
+
for i in range(new_feat.size(0)):
|
| 25 |
+
x, y = new_feat[i, 0].item(), new_feat[i, 1].item()
|
| 26 |
+
|
| 27 |
+
# Apply Noise
|
| 28 |
+
x += random.gauss(0, noise_std)
|
| 29 |
+
y += random.gauss(0, noise_std)
|
| 30 |
+
|
| 31 |
+
# Apply Flip
|
| 32 |
+
x *= flip_x
|
| 33 |
+
|
| 34 |
+
# Apply Rotation
|
| 35 |
+
nx = x * cos_t - y * sin_t
|
| 36 |
+
ny = x * sin_t + y * cos_t
|
| 37 |
+
|
| 38 |
+
new_feat[i, 0] = nx
|
| 39 |
+
new_feat[i, 1] = ny
|
| 40 |
+
|
| 41 |
+
if is_obs:
|
| 42 |
+
# Transform dx, dy
|
| 43 |
+
dx, dy = new_feat[i, 2].item(), new_feat[i, 3].item()
|
| 44 |
+
dx *= flip_x
|
| 45 |
+
ndx = dx * cos_t - dy * sin_t
|
| 46 |
+
ndy = dx * sin_t + dy * cos_t
|
| 47 |
+
new_feat[i, 2] = ndx
|
| 48 |
+
new_feat[i, 3] = ndy
|
| 49 |
+
|
| 50 |
+
# Recompute sin_t, cos_t based on new dx, dy to be safe
|
| 51 |
+
speed = math.hypot(ndx, ndy)
|
| 52 |
+
if speed > 1e-5:
|
| 53 |
+
new_feat[i, 5] = ndy / speed
|
| 54 |
+
new_feat[i, 6] = ndx / speed
|
| 55 |
+
else:
|
| 56 |
+
new_feat[i, 5] = 0.0
|
| 57 |
+
new_feat[i, 6] = 0.0
|
| 58 |
+
|
| 59 |
+
return new_feat
|
| 60 |
+
|
| 61 |
+
new_obs = apply_transform(obs, is_obs=True)
|
| 62 |
+
new_future = apply_transform(future, is_obs=False)
|
| 63 |
+
|
| 64 |
+
new_neighbors = []
|
| 65 |
+
for n in neighbors: # n is (4, 7) tensor
|
| 66 |
+
if not isinstance(n, torch.Tensor):
|
| 67 |
+
n = torch.tensor(n, dtype=torch.float32)
|
| 68 |
+
new_neighbors.append(apply_transform(n, is_obs=True))
|
| 69 |
+
|
| 70 |
+
return new_obs, new_neighbors, new_future
|
| 71 |
+
|
| 72 |
+
class TrajectoryDataset(Dataset):
|
| 73 |
+
def __init__(self, samples, augment=False):
|
| 74 |
+
self.obs = []
|
| 75 |
+
self.neighbors = []
|
| 76 |
+
self.future = []
|
| 77 |
+
self.augment = augment
|
| 78 |
+
|
| 79 |
+
for obs, neighbors, future in samples:
|
| 80 |
+
self.obs.append(obs)
|
| 81 |
+
self.neighbors.append(neighbors)
|
| 82 |
+
self.future.append(future)
|
| 83 |
+
|
| 84 |
+
# Convert to tensors
|
| 85 |
+
self.obs = torch.tensor(self.obs, dtype=torch.float32)
|
| 86 |
+
self.future = torch.tensor(self.future, dtype=torch.float32)
|
| 87 |
+
# Neighbors remain lists of matrices, will convert in getitem or augment
|
| 88 |
+
|
| 89 |
+
def __len__(self):
|
| 90 |
+
return len(self.obs)
|
| 91 |
+
|
| 92 |
+
def __getitem__(self, idx):
|
| 93 |
+
obs = self.obs[idx].clone()
|
| 94 |
+
future = self.future[idx].clone()
|
| 95 |
+
neighbors = [torch.tensor(n, dtype=torch.float32) for n in self.neighbors[idx]]
|
| 96 |
+
|
| 97 |
+
if self.augment:
|
| 98 |
+
obs, neighbors, future = augment_data(obs, neighbors, future)
|
| 99 |
+
|
| 100 |
+
return obs, neighbors, future
|
backend/app/legacy/dataset_fusion.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
|
| 4 |
+
from .dataset import augment_data
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FusionTrajectoryDataset(Dataset):
|
| 8 |
+
def __init__(self, samples, augment=False):
|
| 9 |
+
self.obs = []
|
| 10 |
+
self.neighbors = []
|
| 11 |
+
self.fusion = []
|
| 12 |
+
self.future = []
|
| 13 |
+
self.augment = augment
|
| 14 |
+
|
| 15 |
+
for obs, neighbors, fusion_obs, future in samples:
|
| 16 |
+
self.obs.append(obs)
|
| 17 |
+
self.neighbors.append(neighbors)
|
| 18 |
+
self.fusion.append(fusion_obs)
|
| 19 |
+
self.future.append(future)
|
| 20 |
+
|
| 21 |
+
self.obs = torch.tensor(self.obs, dtype=torch.float32)
|
| 22 |
+
self.fusion = torch.tensor(self.fusion, dtype=torch.float32)
|
| 23 |
+
self.future = torch.tensor(self.future, dtype=torch.float32)
|
| 24 |
+
|
| 25 |
+
def __len__(self):
|
| 26 |
+
return len(self.obs)
|
| 27 |
+
|
| 28 |
+
def __getitem__(self, idx):
|
| 29 |
+
obs = self.obs[idx].clone()
|
| 30 |
+
fusion_obs = self.fusion[idx].clone()
|
| 31 |
+
future = self.future[idx].clone()
|
| 32 |
+
neighbors = [torch.tensor(n, dtype=torch.float32) for n in self.neighbors[idx]]
|
| 33 |
+
|
| 34 |
+
if self.augment:
|
| 35 |
+
obs, neighbors, future = augment_data(obs, neighbors, future)
|
| 36 |
+
|
| 37 |
+
return obs, neighbors, fusion_obs, future
|
backend/app/legacy/map_renderer.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
DATAROOT = './DataSet'
|
| 7 |
+
VERSION = 'v1.0-mini'
|
| 8 |
+
|
| 9 |
+
def get_map_mask():
|
| 10 |
+
"""
|
| 11 |
+
Since the vector map expansion (JSON API) is not included in the raw dataset,
|
| 12 |
+
we use the actual raw HD Map Raster Masks (PNGs) inherently included in the v1.0-mini dataset.
|
| 13 |
+
"""
|
| 14 |
+
map_json_path = os.path.join(DATAROOT, VERSION, 'map.json')
|
| 15 |
+
try:
|
| 16 |
+
with open(map_json_path, 'r') as f:
|
| 17 |
+
map_data = json.load(f)
|
| 18 |
+
|
| 19 |
+
# Grab the first available semantic prior map (binary mask of drivable area)
|
| 20 |
+
filename = map_data[0]['filename']
|
| 21 |
+
img_path = os.path.join(DATAROOT, filename)
|
| 22 |
+
|
| 23 |
+
if os.path.exists(img_path):
|
| 24 |
+
img = plt.imread(img_path)
|
| 25 |
+
return img
|
| 26 |
+
else:
|
| 27 |
+
print(f"Map image not found at {img_path}")
|
| 28 |
+
return None
|
| 29 |
+
except Exception as e:
|
| 30 |
+
print(f"Error loading map.json: {e}")
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
def render_map_patch(x_center, y_center, radius=50.0, ax=None):
|
| 34 |
+
"""
|
| 35 |
+
Simulates extracting an HD map patch by grabbing a corresponding
|
| 36 |
+
section of the full-scale dataset map mask and displaying it.
|
| 37 |
+
"""
|
| 38 |
+
if ax is None:
|
| 39 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
| 40 |
+
|
| 41 |
+
mask = get_map_mask()
|
| 42 |
+
if mask is None:
|
| 43 |
+
return ax
|
| 44 |
+
|
| 45 |
+
# nuScenes standard raster resolution is 10 pixels per meter (0.1m)
|
| 46 |
+
pixels_per_meter = 10
|
| 47 |
+
|
| 48 |
+
# Let's find an interesting visual patch in the massive 20000x20000 map
|
| 49 |
+
# We will offset heavily into the image so we don't just see black emptiness
|
| 50 |
+
offset_x = 8000
|
| 51 |
+
offset_y = 8500
|
| 52 |
+
|
| 53 |
+
x_min_px = int(offset_x + (x_center - radius) * pixels_per_meter)
|
| 54 |
+
x_max_px = int(offset_x + (x_center + radius) * pixels_per_meter)
|
| 55 |
+
y_min_px = int(offset_y + (y_center - radius) * pixels_per_meter)
|
| 56 |
+
y_max_px = int(offset_y + (y_center + radius) * pixels_per_meter)
|
| 57 |
+
|
| 58 |
+
# Prevent out of bounds
|
| 59 |
+
x_min_px, x_max_px = max(0, x_min_px), min(mask.shape[1], x_max_px)
|
| 60 |
+
y_min_px, y_max_px = max(0, y_min_px), min(mask.shape[0], y_max_px)
|
| 61 |
+
|
| 62 |
+
crop = mask[y_min_px:y_max_px, x_min_px:x_max_px]
|
| 63 |
+
|
| 64 |
+
# Convert grayscale mask to an RGBA mask to allow custom colors and true transparency in the visual
|
| 65 |
+
import numpy as np
|
| 66 |
+
# True means drivable area, false is background
|
| 67 |
+
colored_mask = np.zeros((crop.shape[0], crop.shape[1], 4), dtype=np.float32)
|
| 68 |
+
|
| 69 |
+
# Let's paint the drivable area road gray-blue with some opacity (e.g. 0.4)
|
| 70 |
+
# The road pixels in the original image are often 1.0 (or close to it)
|
| 71 |
+
road_pixels = crop > 0.5
|
| 72 |
+
|
| 73 |
+
# Paint road pixels (R=0.2, G=0.3, B=0.5, Alpha=0.3 for a technical blueprint look)
|
| 74 |
+
colored_mask[road_pixels] = [0.2, 0.3, 0.5, 0.3]
|
| 75 |
+
# Background remains perfectly transparent (Alpha=0)
|
| 76 |
+
|
| 77 |
+
# Use imshow with the explicit RGBA mask
|
| 78 |
+
ax.imshow(colored_mask,
|
| 79 |
+
extent=[x_center - radius, x_center + radius, y_center - radius, y_center + radius],
|
| 80 |
+
origin='lower', zorder=-1) # Z-order ensures map is behind all points
|
| 81 |
+
|
| 82 |
+
return ax
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
print("Loading Native HD Map Mask from Raw Dataset...")
|
| 86 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
| 87 |
+
|
| 88 |
+
# Test rendering a patch at center 0,0
|
| 89 |
+
render_map_patch(0, 0, radius=60.0, ax=ax)
|
| 90 |
+
|
| 91 |
+
# Draw a fake car path on top to prove it works
|
| 92 |
+
ax.plot([0, 10, 30],
|
| 93 |
+
[0, 5, 15],
|
| 94 |
+
'r*-', linewidth=3, markersize=10, label="Vehicle Trajectory")
|
| 95 |
+
|
| 96 |
+
ax.legend()
|
| 97 |
+
plt.grid(True, linestyle='--', alpha=0.5)
|
| 98 |
+
plt.title("Phase 3: Dataset-Native HD Map Raster Overlay")
|
| 99 |
+
plt.savefig("demo_raw_map.png", bbox_inches='tight')
|
| 100 |
+
plt.show()
|
| 101 |
+
print("Successfully generated 'demo_raw_map.png' using strictly internal dataset files!")
|
backend/app/legacy/visualization.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import matplotlib.patches as patches
|
| 3 |
+
import numpy as np
|
| 4 |
+
from ..ml.inference import predict
|
| 5 |
+
from .map_renderer import render_map_patch
|
| 6 |
+
|
| 7 |
+
def plot_scene(
|
| 8 |
+
points,
|
| 9 |
+
neighbor_points_list=None,
|
| 10 |
+
neighbor_types=None,
|
| 11 |
+
is_live_camera=False,
|
| 12 |
+
sensor_fusion=None,
|
| 13 |
+
presentation_mode=False,
|
| 14 |
+
max_vru_display=6,
|
| 15 |
+
):
|
| 16 |
+
if neighbor_points_list is None: sibling_pts = []
|
| 17 |
+
else: sibling_pts = neighbor_points_list
|
| 18 |
+
|
| 19 |
+
if neighbor_types is None: n_types = ['Car'] * len(sibling_pts)
|
| 20 |
+
else: n_types = neighbor_types
|
| 21 |
+
|
| 22 |
+
# Set up dark "Extreme 3D Mode" environment if it's Live Camera
|
| 23 |
+
plt.style.use('dark_background') if is_live_camera else plt.style.use('default')
|
| 24 |
+
fig = plt.figure(figsize=(14, 12))
|
| 25 |
+
ax = plt.gca()
|
| 26 |
+
|
| 27 |
+
# ---------------- EGO VEHICLE & CAMERA PERSPECTIVE ----------------
|
| 28 |
+
if is_live_camera:
|
| 29 |
+
# In live camera mode, we anchor the BEV map to the Ego car!
|
| 30 |
+
ego_x, ego_y = 0.0, -2.0
|
| 31 |
+
ax.set_facecolor('#0b0e14')
|
| 32 |
+
|
| 33 |
+
# Add Compass Directions
|
| 34 |
+
ax.text(0, 48, "N (Forward)", color="white", fontsize=14, weight="bold", ha="center")
|
| 35 |
+
ax.text(0, -8, "S (Rear)", color="white", fontsize=14, weight="bold", ha="center", alpha=0.5)
|
| 36 |
+
ax.text(32, ego_y, "E (Right)", color="white", fontsize=14, weight="bold", ha="left", alpha=0.5)
|
| 37 |
+
ax.text(-32, ego_y, "W (Left)", color="white", fontsize=14, weight="bold", ha="right", alpha=0.5)
|
| 38 |
+
|
| 39 |
+
plt.grid(True, linestyle='dotted', color='#1a2436', alpha=0.9, zorder=0)
|
| 40 |
+
|
| 41 |
+
theta = np.linspace(np.pi/3, 2 * np.pi/3, 50)
|
| 42 |
+
fov_range = 60
|
| 43 |
+
ax.fill_between(
|
| 44 |
+
[ego_x] + list(ego_x + fov_range * np.cos(theta)) + [ego_x],
|
| 45 |
+
[ego_y] + list(ego_y + fov_range * np.sin(theta)) + [ego_y],
|
| 46 |
+
color='#00ffff', alpha=0.1, zorder=1, label='360 Camera / LiDAR FOV'
|
| 47 |
+
)
|
| 48 |
+
car_rect = patches.Rectangle((ego_x - 1.2, ego_y - 2.5), 2.4, 5.0, linewidth=2, edgecolor='#00ffff', facecolor='#001a1a', zorder=7, label="Autonomous Ego Vehicle")
|
| 49 |
+
ax.add_patch(car_rect)
|
| 50 |
+
|
| 51 |
+
ax.set_xlim(-35, 35)
|
| 52 |
+
ax.set_ylim(-10, 50)
|
| 53 |
+
map_center_x, map_center_y = 0, 20
|
| 54 |
+
ego_ref = np.array([ego_x, ego_y], dtype=np.float32)
|
| 55 |
+
else:
|
| 56 |
+
map_center_x, map_center_y = points[-1][0], points[-1][1]
|
| 57 |
+
ego_x, ego_y = map_center_x - 12, map_center_y - 6
|
| 58 |
+
theta = np.linspace(-np.pi/6, np.pi/6, 50)
|
| 59 |
+
ax.fill_between(
|
| 60 |
+
[ego_x] + list(ego_x + 50 * np.cos(theta)) + [ego_x],
|
| 61 |
+
[ego_y] + list(ego_y + 50 * np.sin(theta)) + [ego_y],
|
| 62 |
+
color='cyan', alpha=0.15, zorder=2
|
| 63 |
+
)
|
| 64 |
+
car_rect = patches.Rectangle((ego_x - 2.4, ego_y - 1.0), 4.8, 2.0, linewidth=2, edgecolor='black', facecolor='cyan', zorder=7)
|
| 65 |
+
ax.add_patch(car_rect)
|
| 66 |
+
ax.set_xlim(map_center_x - 15, map_center_x + 35)
|
| 67 |
+
ax.set_ylim(map_center_y - 20, map_center_y + 20)
|
| 68 |
+
plt.grid(True, linestyle='solid', color='lightgray', alpha=0.5, zorder=1)
|
| 69 |
+
ego_ref = np.array([map_center_x, map_center_y], dtype=np.float32)
|
| 70 |
+
|
| 71 |
+
if not is_live_camera:
|
| 72 |
+
render_map_patch(map_center_x, map_center_y, radius=120.0, ax=ax)
|
| 73 |
+
|
| 74 |
+
# ---------------- Phase 1 Sensor Fusion Overlay ----------------
|
| 75 |
+
if is_live_camera and sensor_fusion is not None:
|
| 76 |
+
lidar_xy = sensor_fusion.get('lidar_xy', None)
|
| 77 |
+
radar_xy = sensor_fusion.get('radar_xy', None)
|
| 78 |
+
radar_vel = sensor_fusion.get('radar_vel', None)
|
| 79 |
+
|
| 80 |
+
if lidar_xy is not None and len(lidar_xy) > 0:
|
| 81 |
+
# Remove very-near ego returns to avoid halo clutter around the car.
|
| 82 |
+
r = np.hypot(lidar_xy[:, 0] - ego_ref[0], lidar_xy[:, 1] - ego_ref[1])
|
| 83 |
+
lidar_vis = lidar_xy[r > 6.0]
|
| 84 |
+
|
| 85 |
+
if presentation_mode:
|
| 86 |
+
step = 18 if len(lidar_vis) > 12000 else 10
|
| 87 |
+
lidar_plot = lidar_vis[::step] if len(lidar_vis) > 0 else lidar_vis
|
| 88 |
+
lidar_size = 3
|
| 89 |
+
lidar_alpha = 0.10
|
| 90 |
+
else:
|
| 91 |
+
lidar_plot = lidar_vis[::4] if len(lidar_vis) > 4000 else lidar_vis
|
| 92 |
+
lidar_size = 5
|
| 93 |
+
lidar_alpha = 0.18
|
| 94 |
+
|
| 95 |
+
ax.scatter(
|
| 96 |
+
lidar_plot[:, 0],
|
| 97 |
+
lidar_plot[:, 1],
|
| 98 |
+
s=lidar_size,
|
| 99 |
+
c='#22d3ee',
|
| 100 |
+
alpha=lidar_alpha,
|
| 101 |
+
linewidths=0,
|
| 102 |
+
label='LiDAR occupancy',
|
| 103 |
+
zorder=2,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if radar_xy is not None and len(radar_xy) > 0:
|
| 107 |
+
if presentation_mode and len(radar_xy) > 180:
|
| 108 |
+
radar_plot = radar_xy[::2]
|
| 109 |
+
else:
|
| 110 |
+
radar_plot = radar_xy
|
| 111 |
+
|
| 112 |
+
ax.scatter(
|
| 113 |
+
radar_plot[:, 0],
|
| 114 |
+
radar_plot[:, 1],
|
| 115 |
+
s=18 if presentation_mode else 24,
|
| 116 |
+
c='#facc15',
|
| 117 |
+
alpha=0.78 if presentation_mode else 0.85,
|
| 118 |
+
edgecolors='black',
|
| 119 |
+
linewidths=0.5,
|
| 120 |
+
label='Radar returns (multi-ch)',
|
| 121 |
+
zorder=6,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if radar_vel is not None and len(radar_vel) == len(radar_xy):
|
| 125 |
+
speeds = np.hypot(radar_vel[:, 0], radar_vel[:, 1])
|
| 126 |
+
if presentation_mode:
|
| 127 |
+
idx = np.where(speeds > 0.6)[0]
|
| 128 |
+
if len(idx) > 18:
|
| 129 |
+
idx = idx[np.argsort(speeds[idx])[-18:]]
|
| 130 |
+
else:
|
| 131 |
+
step = max(1, len(radar_xy) // 40)
|
| 132 |
+
idx = np.arange(0, len(radar_xy), step)
|
| 133 |
+
|
| 134 |
+
for i in idx:
|
| 135 |
+
x0, y0 = radar_xy[i, 0], radar_xy[i, 1]
|
| 136 |
+
vx, vy = radar_vel[i, 0], radar_vel[i, 1]
|
| 137 |
+
ax.arrow(
|
| 138 |
+
x0,
|
| 139 |
+
y0,
|
| 140 |
+
vx * (0.45 if presentation_mode else 0.6),
|
| 141 |
+
vy * (0.45 if presentation_mode else 0.6),
|
| 142 |
+
head_width=0.45 if presentation_mode else 0.6,
|
| 143 |
+
head_length=0.6 if presentation_mode else 0.8,
|
| 144 |
+
fc='#fde68a',
|
| 145 |
+
ec='#facc15',
|
| 146 |
+
alpha=0.65 if presentation_mode else 0.75,
|
| 147 |
+
zorder=6,
|
| 148 |
+
length_includes_head=True,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# ---------------- MULTI-AGENT PREDICTIONS ----------------
|
| 152 |
+
color_map = {'Car': '#ffff00', 'Truck': '#ffaa00', 'Bus': '#ff8800', 'Person': '#ff00ff', 'Bike': '#ff5500'}
|
| 153 |
+
|
| 154 |
+
def build_agent_fusion_features(agent_points):
|
| 155 |
+
if sensor_fusion is None:
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
lidar_xy = sensor_fusion.get('lidar_xy', None)
|
| 159 |
+
radar_xy = sensor_fusion.get('radar_xy', None)
|
| 160 |
+
|
| 161 |
+
if lidar_xy is None and radar_xy is None:
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
feats = []
|
| 165 |
+
for px, py in agent_points:
|
| 166 |
+
if lidar_xy is not None and len(lidar_xy) > 0:
|
| 167 |
+
dl = np.hypot(lidar_xy[:, 0] - px, lidar_xy[:, 1] - py)
|
| 168 |
+
lidar_cnt = int((dl < 2.0).sum())
|
| 169 |
+
else:
|
| 170 |
+
lidar_cnt = 0
|
| 171 |
+
|
| 172 |
+
if radar_xy is not None and len(radar_xy) > 0:
|
| 173 |
+
dr = np.hypot(radar_xy[:, 0] - px, radar_xy[:, 1] - py)
|
| 174 |
+
radar_cnt = int((dr < 2.5).sum())
|
| 175 |
+
else:
|
| 176 |
+
radar_cnt = 0
|
| 177 |
+
|
| 178 |
+
lidar_norm = min(80.0, float(lidar_cnt)) / 80.0
|
| 179 |
+
radar_norm = min(30.0, float(radar_cnt)) / 30.0
|
| 180 |
+
sensor_strength = min(1.0, (float(lidar_cnt) + 2.0 * float(radar_cnt)) / 100.0)
|
| 181 |
+
feats.append([lidar_norm, radar_norm, sensor_strength])
|
| 182 |
+
|
| 183 |
+
return feats
|
| 184 |
+
|
| 185 |
+
def classify_mode_direction(hist_x, hist_y, pred_x, pred_y):
|
| 186 |
+
if len(hist_x) < 2:
|
| 187 |
+
return 'Straight'
|
| 188 |
+
|
| 189 |
+
# Current motion heading from the last observed segment.
|
| 190 |
+
hx = hist_x[-1] - hist_x[-2]
|
| 191 |
+
hy = hist_y[-1] - hist_y[-2]
|
| 192 |
+
if np.hypot(hx, hy) < 1e-6:
|
| 193 |
+
hx, hy = 0.0, 1.0
|
| 194 |
+
|
| 195 |
+
# Predicted heading from current point to mode endpoint.
|
| 196 |
+
px = pred_x[-1] - hist_x[-1]
|
| 197 |
+
py = pred_y[-1] - hist_y[-1]
|
| 198 |
+
if np.hypot(px, py) < 1e-6:
|
| 199 |
+
return 'Straight'
|
| 200 |
+
|
| 201 |
+
angle_deg = np.degrees(np.arctan2(hx * py - hy * px, hx * px + hy * py))
|
| 202 |
+
|
| 203 |
+
if abs(angle_deg) <= 30:
|
| 204 |
+
return 'Straight'
|
| 205 |
+
if 30 < angle_deg < 140:
|
| 206 |
+
return 'Left'
|
| 207 |
+
if -140 < angle_deg < -30:
|
| 208 |
+
return 'Right'
|
| 209 |
+
return 'Backward'
|
| 210 |
+
|
| 211 |
+
all_agents_to_predict = [(points, 'Person (Primary)')]
|
| 212 |
+
for i, n_pts in enumerate(sibling_pts):
|
| 213 |
+
# We now run predictions for ANY vulnerable user (Person or Bicycle)
|
| 214 |
+
if is_live_camera and n_types[i] in ['Person', 'Bicycle']:
|
| 215 |
+
all_agents_to_predict.append((n_pts, f"{n_types[i]} {i}"))
|
| 216 |
+
|
| 217 |
+
# Keep the live demo readable by limiting displayed VRUs in presentation mode.
|
| 218 |
+
if is_live_camera and presentation_mode and len(all_agents_to_predict) > max_vru_display:
|
| 219 |
+
primary = all_agents_to_predict[0]
|
| 220 |
+
others = all_agents_to_predict[1:]
|
| 221 |
+
|
| 222 |
+
def _dist_to_ego(agent_entry):
|
| 223 |
+
pts = agent_entry[0]
|
| 224 |
+
if len(pts) == 0:
|
| 225 |
+
return 1e9
|
| 226 |
+
px, py = pts[-1][0], pts[-1][1]
|
| 227 |
+
return float(np.hypot(px - ego_ref[0], py - ego_ref[1]))
|
| 228 |
+
|
| 229 |
+
others = sorted(others, key=_dist_to_ego)
|
| 230 |
+
all_agents_to_predict = [primary] + others[: max(0, max_vru_display - 1)]
|
| 231 |
+
|
| 232 |
+
vru_mode_summaries = []
|
| 233 |
+
vru_counter = 1
|
| 234 |
+
|
| 235 |
+
# Predict and plot the future for all identified vulnerable users
|
| 236 |
+
for agent_pts, label in all_agents_to_predict:
|
| 237 |
+
fusion_feats = build_agent_fusion_features(agent_pts)
|
| 238 |
+
pred, probs, attn_weights = predict(agent_pts, sibling_pts, fusion_feats=fusion_feats)
|
| 239 |
+
tx, ty = [p[0] for p in agent_pts], [p[1] for p in agent_pts]
|
| 240 |
+
is_primary = 'Primary' in label
|
| 241 |
+
mode_direction_scores = {}
|
| 242 |
+
|
| 243 |
+
# Plot their history (tail)
|
| 244 |
+
plt.plot(tx, ty, color='white' if is_primary else '#ff00ff', linestyle='solid' if is_live_camera else 'dashed', linewidth=3, zorder=5)
|
| 245 |
+
if is_live_camera:
|
| 246 |
+
point_label = 'Primary VRU (t=0)' if is_primary else 'Target VRU (t=0)'
|
| 247 |
+
else:
|
| 248 |
+
point_label = f"{label} (t=0)"
|
| 249 |
+
plt.scatter(tx[-1], ty[-1], c='white' if is_primary else '#ff00ff', s=250 if is_primary else 150, edgecolors='black', linewidths=2, label=point_label, zorder=8)
|
| 250 |
+
|
| 251 |
+
# --- NEW: Add an extremely obvious Vector Arrow showing their Current Walking Direction ---
|
| 252 |
+
if len(tx) >= 2:
|
| 253 |
+
dx_dir = tx[-1] - tx[-2]
|
| 254 |
+
dy_dir = ty[-1] - ty[-2]
|
| 255 |
+
dir_mag = np.hypot(dx_dir, dy_dir)
|
| 256 |
+
if dir_mag > 0.01:
|
| 257 |
+
# The arrow dynamically scales to their movement speed and points exactly where they are headed!
|
| 258 |
+
arr_dx, arr_dy = (dx_dir/dir_mag)*3, (dy_dir/dir_mag)*3
|
| 259 |
+
ax.arrow(tx[-1], ty[-1], arr_dx, arr_dy, head_width=1.5, head_length=2.0, fc='#00ffff', ec='white', zorder=12, width=0.4, alpha=0.9)
|
| 260 |
+
|
| 261 |
+
# Plot their Future prediction paths
|
| 262 |
+
colors = ['#0088ff', '#ff8800', '#ff0044']
|
| 263 |
+
mode_curves = []
|
| 264 |
+
|
| 265 |
+
for mode_i in range(pred.shape[0]):
|
| 266 |
+
x_pred_raw = pred[mode_i][:, 0].numpy()
|
| 267 |
+
y_pred_raw = pred[mode_i][:, 1].numpy()
|
| 268 |
+
|
| 269 |
+
dx = x_pred_raw - x_pred_raw[0]
|
| 270 |
+
dy = y_pred_raw - y_pred_raw[0]
|
| 271 |
+
|
| 272 |
+
x_pred = tx[-1] + dx * (2.0 if is_live_camera else 4.0)
|
| 273 |
+
y_pred = ty[-1] + dy * (2.0 if is_live_camera else 4.0)
|
| 274 |
+
mode_curves.append((mode_i, x_pred, y_pred))
|
| 275 |
+
|
| 276 |
+
mode_direction = classify_mode_direction(tx, ty, x_pred, y_pred)
|
| 277 |
+
mode_prob = float(probs[mode_i].item())
|
| 278 |
+
mode_direction_scores[mode_direction] = mode_direction_scores.get(mode_direction, 0.0) + mode_prob
|
| 279 |
+
|
| 280 |
+
if presentation_mode and is_live_camera:
|
| 281 |
+
draw_modes = [int(np.argmax(probs.numpy()))]
|
| 282 |
+
else:
|
| 283 |
+
draw_modes = [m[0] for m in mode_curves]
|
| 284 |
+
|
| 285 |
+
for mode_i, x_pred, y_pred in mode_curves:
|
| 286 |
+
if mode_i not in draw_modes:
|
| 287 |
+
continue
|
| 288 |
+
plt.plot(
|
| 289 |
+
x_pred,
|
| 290 |
+
y_pred,
|
| 291 |
+
color=colors[mode_i],
|
| 292 |
+
linewidth=3.0 if presentation_mode else 2.5 + (0 if mode_i > 0 else 1),
|
| 293 |
+
alpha=0.9 if presentation_mode else (0.8 if mode_i == 0 else 0.4),
|
| 294 |
+
zorder=5,
|
| 295 |
+
)
|
| 296 |
+
for t in range(0, len(x_pred), 3 if presentation_mode else 2):
|
| 297 |
+
plt.scatter(
|
| 298 |
+
x_pred[t],
|
| 299 |
+
y_pred[t],
|
| 300 |
+
color=colors[mode_i],
|
| 301 |
+
alpha=max(0.35, 1.0 - (t / 12)),
|
| 302 |
+
s=28 if presentation_mode else 40,
|
| 303 |
+
zorder=6,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Per-agent Top-3 direction probabilities for live demo readability.
|
| 307 |
+
sorted_modes = sorted(mode_direction_scores.items(), key=lambda kv: kv[1], reverse=True)
|
| 308 |
+
top_modes = sorted_modes[:3]
|
| 309 |
+
vru_id = f"VRU-{vru_counter}" + ("*" if is_primary else "")
|
| 310 |
+
vru_mode_summaries.append((vru_id, top_modes))
|
| 311 |
+
|
| 312 |
+
if is_live_camera and (not presentation_mode) and len(top_modes) > 0:
|
| 313 |
+
primary_dir, primary_prob = top_modes[0]
|
| 314 |
+
ax.text(
|
| 315 |
+
tx[-1] + 0.8,
|
| 316 |
+
ty[-1] + 1.2,
|
| 317 |
+
f"{vru_id}: {primary_dir} {primary_prob*100:.0f}%",
|
| 318 |
+
fontsize=8,
|
| 319 |
+
color='white',
|
| 320 |
+
bbox=dict(facecolor='#111827', edgecolor='#60a5fa', alpha=0.8, boxstyle='round,pad=0.2'),
|
| 321 |
+
zorder=13
|
| 322 |
+
)
|
| 323 |
+
vru_counter += 1
|
| 324 |
+
|
| 325 |
+
# ---------------- PLOT NEIGHBORS (Vehicles/Trucks) ----------------
|
| 326 |
+
for i, n_pts in enumerate(sibling_pts):
|
| 327 |
+
if is_live_camera and n_types[i] in ['Person', 'Bicycle']:
|
| 328 |
+
continue # Already predicted above
|
| 329 |
+
|
| 330 |
+
n_type = n_types[i]
|
| 331 |
+
n_color = color_map.get(n_type, 'yellow')
|
| 332 |
+
n_x, n_y = [p[0] for p in n_pts], [p[1] for p in n_pts]
|
| 333 |
+
|
| 334 |
+
marker_size = 400 if n_type in ['Truck', 'Bus'] else 200
|
| 335 |
+
marker_shape = 's' if n_type in ['Truck', 'Bus'] else 'o'
|
| 336 |
+
|
| 337 |
+
plt.plot(n_x, n_y, color=n_color, linestyle=':', linewidth=2, zorder=4)
|
| 338 |
+
plt.scatter(n_x[-1], n_y[-1], c=n_color, marker=marker_shape, s=marker_size, edgecolors='white' if is_live_camera else 'black', linewidth=1.5, label=f'Moving ({n_type})', zorder=7)
|
| 339 |
+
|
| 340 |
+
# UI Embellishments
|
| 341 |
+
plt.title("Ego-Centric BEV Matrix: Multi-Agent Parallel Forecasting", color="white" if is_live_camera else "black", fontsize=20, weight='bold', pad=15)
|
| 342 |
+
plt.xlabel("X Lateral Offset (meters)", color="white" if is_live_camera else "black", weight='bold', fontsize=13)
|
| 343 |
+
plt.ylabel("Y Depth Offset (meters)", color="white" if is_live_camera else "black", weight='bold', fontsize=13)
|
| 344 |
+
|
| 345 |
+
if is_live_camera:
|
| 346 |
+
ax.tick_params(axis='both', colors='white', labelsize=11)
|
| 347 |
+
for spine in ax.spines.values():
|
| 348 |
+
spine.set_color('#94a3b8')
|
| 349 |
+
|
| 350 |
+
handles, labels = ax.get_legend_handles_labels()
|
| 351 |
+
unique_labels, unique_handles = [], []
|
| 352 |
+
for h, l in zip(handles, labels):
|
| 353 |
+
if l not in unique_labels:
|
| 354 |
+
unique_labels.append(l)
|
| 355 |
+
unique_handles.append(h)
|
| 356 |
+
|
| 357 |
+
if is_live_camera:
|
| 358 |
+
leg = ax.legend(
|
| 359 |
+
unique_handles,
|
| 360 |
+
unique_labels,
|
| 361 |
+
loc='upper right',
|
| 362 |
+
fancybox=True,
|
| 363 |
+
framealpha=0.95,
|
| 364 |
+
facecolor='#111827',
|
| 365 |
+
edgecolor='#94a3b8',
|
| 366 |
+
fontsize=10,
|
| 367 |
+
title='Legend'
|
| 368 |
+
)
|
| 369 |
+
plt.setp(leg.get_texts(), color='white')
|
| 370 |
+
plt.setp(leg.get_title(), color='white', weight='bold')
|
| 371 |
+
|
| 372 |
+
if len(vru_mode_summaries) > 0:
|
| 373 |
+
summary_lines = ["Top-3 Direction Probabilities"]
|
| 374 |
+
summary_lines.append("VRU-* = primary target")
|
| 375 |
+
for vru_id, top_modes in vru_mode_summaries[:max_vru_display]:
|
| 376 |
+
mode_text = " | ".join([f"{name}:{prob*100:.0f}%" for name, prob in top_modes])
|
| 377 |
+
summary_lines.append(f"{vru_id} -> {mode_text}")
|
| 378 |
+
|
| 379 |
+
fig.subplots_adjust(right=0.80)
|
| 380 |
+
ax.text(
|
| 381 |
+
1.02,
|
| 382 |
+
0.62,
|
| 383 |
+
"\n".join(summary_lines),
|
| 384 |
+
transform=ax.transAxes,
|
| 385 |
+
va='top',
|
| 386 |
+
ha='left',
|
| 387 |
+
fontsize=9,
|
| 388 |
+
color='white',
|
| 389 |
+
bbox=dict(facecolor='#0f172a', edgecolor='#60a5fa', alpha=0.95, boxstyle='round,pad=0.4')
|
| 390 |
+
)
|
| 391 |
+
else:
|
| 392 |
+
leg = ax.legend(unique_handles, unique_labels, loc='upper left', bbox_to_anchor=(1.02, 1.0), fancybox=True, framealpha=0.9)
|
| 393 |
+
|
| 394 |
+
ax.set_aspect('equal', adjustable='box')
|
| 395 |
+
return fig
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
|
| 398 |
+
main_pedestrian = [(0, 0), (10, 0), (20, 0), (30, 0)]
|
| 399 |
+
plot_scene(main_pedestrian, is_live_camera=True)
|
backend/app/main.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from .api.dependencies import pipeline
|
| 6 |
+
from .api.routes.health import router as health_router
|
| 7 |
+
from .api.routes.live import get_live_frame_image, resolve_dataset_frame_path, router as live_router
|
| 8 |
+
from .api.routes.predict import router as predict_router
|
| 9 |
+
from .core.serialization import build_prediction_payload
|
| 10 |
+
|
| 11 |
+
def create_app() -> FastAPI:
|
| 12 |
+
app = FastAPI(
|
| 13 |
+
title="BEV Trajectory Backend",
|
| 14 |
+
version="0.2.0",
|
| 15 |
+
description="Structured FastAPI backend for CV + trajectory prediction",
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
app.add_middleware(
|
| 19 |
+
CORSMiddleware,
|
| 20 |
+
allow_origins=["*"],
|
| 21 |
+
allow_credentials=True,
|
| 22 |
+
allow_methods=["*"],
|
| 23 |
+
allow_headers=["*"],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
app.include_router(health_router, prefix="/api", tags=["health"])
|
| 27 |
+
app.include_router(live_router, prefix="/api", tags=["live"])
|
| 28 |
+
app.include_router(predict_router, prefix="/api", tags=["predict"])
|
| 29 |
+
return app
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
app = create_app()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
__all__ = [
|
| 36 |
+
"app",
|
| 37 |
+
"create_app",
|
| 38 |
+
"pipeline",
|
| 39 |
+
"build_prediction_payload",
|
| 40 |
+
"resolve_dataset_frame_path",
|
| 41 |
+
"get_live_frame_image",
|
| 42 |
+
]
|
backend/app/ml/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Runtime ML modules used by the FastAPI pipeline."""
|
backend/app/ml/inference.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from .model import TrajectoryTransformer
|
| 5 |
+
from .model_fusion import TrajectoryTransformerFusion
|
| 6 |
+
|
| 7 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 8 |
+
|
| 9 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 10 |
+
MODEL_DIR = REPO_ROOT / "models"
|
| 11 |
+
FUSION_CKPT = MODEL_DIR / "best_social_model_fusion.pth"
|
| 12 |
+
BASE_CKPT = MODEL_DIR / "best_social_model.pth"
|
| 13 |
+
|
| 14 |
+
# ----------------------------
|
| 15 |
+
# LOAD MODEL
|
| 16 |
+
# ----------------------------
|
| 17 |
+
USING_FUSION_MODEL = False
|
| 18 |
+
|
| 19 |
+
if FUSION_CKPT.exists():
|
| 20 |
+
model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
|
| 21 |
+
try:
|
| 22 |
+
model.load_state_dict(torch.load(FUSION_CKPT, map_location=device))
|
| 23 |
+
USING_FUSION_MODEL = True
|
| 24 |
+
print("Inference: Loaded Phase 2 fusion checkpoint (best_social_model_fusion.pth).")
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"Warning: could not load fusion checkpoint ({e}). Falling back to base model.")
|
| 27 |
+
model = TrajectoryTransformer().to(device)
|
| 28 |
+
try:
|
| 29 |
+
model.load_state_dict(torch.load(BASE_CKPT, map_location=device))
|
| 30 |
+
print("Inference: Loaded base checkpoint (best_social_model.pth).")
|
| 31 |
+
except Exception as e2:
|
| 32 |
+
print(f"Warning: could not load base checkpoint ({e2}).")
|
| 33 |
+
else:
|
| 34 |
+
model = TrajectoryTransformer().to(device)
|
| 35 |
+
try:
|
| 36 |
+
model.load_state_dict(torch.load(BASE_CKPT, map_location=device))
|
| 37 |
+
print("Inference: Loaded base checkpoint (best_social_model.pth).")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"Warning: could not load model weights ({e}), starting fresh.")
|
| 40 |
+
|
| 41 |
+
model.eval()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ----------------------------
|
| 45 |
+
# PREPROCESS INPUT
|
| 46 |
+
# ----------------------------
|
| 47 |
+
def prepare_input(points):
|
| 48 |
+
import math
|
| 49 |
+
x3, y3 = points[3]
|
| 50 |
+
window = [[x - x3, y - y3] for x, y in points]
|
| 51 |
+
|
| 52 |
+
vel = []
|
| 53 |
+
for j in range(len(window)):
|
| 54 |
+
if j == 0:
|
| 55 |
+
vel.append([0, 0, 0, 0, 0])
|
| 56 |
+
else:
|
| 57 |
+
dx = window[j][0] - window[j-1][0]
|
| 58 |
+
dy = window[j][1] - window[j-1][1]
|
| 59 |
+
speed = math.hypot(dx, dy)
|
| 60 |
+
if speed > 1e-5:
|
| 61 |
+
sin_t = dy / speed
|
| 62 |
+
cos_t = dx / speed
|
| 63 |
+
else:
|
| 64 |
+
sin_t = 0.0
|
| 65 |
+
cos_t = 0.0
|
| 66 |
+
vel.append([dx, dy, speed, sin_t, cos_t])
|
| 67 |
+
|
| 68 |
+
obs = []
|
| 69 |
+
for j in range(4):
|
| 70 |
+
obs.append([
|
| 71 |
+
window[j][0],
|
| 72 |
+
window[j][1],
|
| 73 |
+
vel[j][0],
|
| 74 |
+
vel[j][1],
|
| 75 |
+
vel[j][2],
|
| 76 |
+
vel[j][3],
|
| 77 |
+
vel[j][4]
|
| 78 |
+
])
|
| 79 |
+
|
| 80 |
+
return obs, (x3, y3)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ----------------------------
|
| 84 |
+
# PREDICTION FUNCTION
|
| 85 |
+
# ----------------------------
|
| 86 |
+
def predict(points, neighbor_points_list=None, fusion_feats=None):
|
| 87 |
+
if neighbor_points_list is None:
|
| 88 |
+
neighbor_points_list = []
|
| 89 |
+
|
| 90 |
+
obs, origin = prepare_input(points)
|
| 91 |
+
|
| 92 |
+
obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device) # (1,4,7)
|
| 93 |
+
|
| 94 |
+
# Prepare neighbors exactly as the main trajectory
|
| 95 |
+
import math
|
| 96 |
+
x1, y1 = points[-1]
|
| 97 |
+
neighbors = []
|
| 98 |
+
for np_points in neighbor_points_list:
|
| 99 |
+
n_window = [[x - x1, y - y1] for x, y in np_points]
|
| 100 |
+
vel_n = []
|
| 101 |
+
for j in range(len(n_window)):
|
| 102 |
+
if j == 0:
|
| 103 |
+
vel_n.append([0, 0, 0, 0, 0])
|
| 104 |
+
else:
|
| 105 |
+
dx = n_window[j][0] - n_window[j-1][0]
|
| 106 |
+
dy = n_window[j][1] - n_window[j-1][1]
|
| 107 |
+
speed = math.hypot(dx, dy)
|
| 108 |
+
if speed > 1e-5:
|
| 109 |
+
sin_t = dy / speed
|
| 110 |
+
cos_t = dx / speed
|
| 111 |
+
else:
|
| 112 |
+
sin_t = 0.0
|
| 113 |
+
cos_t = 0.0
|
| 114 |
+
vel_n.append([dx, dy, speed, sin_t, cos_t])
|
| 115 |
+
|
| 116 |
+
n_obs = []
|
| 117 |
+
for j in range(4):
|
| 118 |
+
n_obs.append([
|
| 119 |
+
n_window[j][0], n_window[j][1],
|
| 120 |
+
vel_n[j][0], vel_n[j][1], vel_n[j][2], vel_n[j][3], vel_n[j][4]
|
| 121 |
+
])
|
| 122 |
+
neighbors.append(n_obs)
|
| 123 |
+
|
| 124 |
+
neighbors_batch = [neighbors] # batch size = 1
|
| 125 |
+
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
if USING_FUSION_MODEL:
|
| 128 |
+
if fusion_feats is None:
|
| 129 |
+
fusion_tensor = torch.zeros((1, 4, 3), dtype=torch.float32, device=device)
|
| 130 |
+
else:
|
| 131 |
+
fusion_tensor = torch.tensor(fusion_feats, dtype=torch.float32).unsqueeze(0).to(device)
|
| 132 |
+
pred, goals, probs, attn_weights = model(obs, neighbors_batch, fusion_tensor)
|
| 133 |
+
else:
|
| 134 |
+
pred, goals, probs, attn_weights = model(obs, neighbors_batch)
|
| 135 |
+
|
| 136 |
+
pred = pred.squeeze(0).cpu()
|
| 137 |
+
probs = probs.squeeze(0).cpu()
|
| 138 |
+
|
| 139 |
+
if attn_weights and attn_weights[0] is not None:
|
| 140 |
+
attn_weights = [w.cpu() for w in attn_weights]
|
| 141 |
+
|
| 142 |
+
# convert back to real coordinates
|
| 143 |
+
x0, y0 = origin
|
| 144 |
+
pred_real = pred.clone()
|
| 145 |
+
|
| 146 |
+
pred_real[:, :, 0] += x0
|
| 147 |
+
pred_real[:, :, 1] += y0
|
| 148 |
+
|
| 149 |
+
return pred_real, probs, attn_weights
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ----------------------------
|
| 153 |
+
# DEMO RUN
|
| 154 |
+
# ----------------------------
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
|
| 157 |
+
points = [
|
| 158 |
+
(0, 0),
|
| 159 |
+
(10, 0),
|
| 160 |
+
(20, 0),
|
| 161 |
+
(30, 0)
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
pred, probs, _ = predict(points)
|
| 165 |
+
|
| 166 |
+
print("\nInput Points:")
|
| 167 |
+
print(points)
|
| 168 |
+
|
| 169 |
+
print("\nPredicted Trajectories (Real Coordinates):")
|
| 170 |
+
for i in range(pred.shape[0]):
|
| 171 |
+
print(f"\nTrajectory {i+1} (prob={probs[i].item():.2f}):")
|
| 172 |
+
print(pred[i])
|
backend/app/ml/model.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
class PositionalEncoding(nn.Module):
|
| 6 |
+
def __init__(self, d_model, max_len=100):
|
| 7 |
+
super().__init__()
|
| 8 |
+
pe = torch.zeros(max_len, d_model)
|
| 9 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 10 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 11 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 12 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 13 |
+
pe = pe.unsqueeze(0) # (1, max_len, d_model)
|
| 14 |
+
self.register_buffer('pe', pe)
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
return x + self.pe[:, :x.size(1), :]
|
| 18 |
+
|
| 19 |
+
class TrajectoryTransformer(nn.Module):
|
| 20 |
+
def __init__(self):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.d_model = 64
|
| 23 |
+
|
| 24 |
+
# 1. Feature Embedding & Positional Encoding
|
| 25 |
+
self.embed = nn.Linear(7, self.d_model)
|
| 26 |
+
self.pos_enc = PositionalEncoding(self.d_model)
|
| 27 |
+
|
| 28 |
+
# 2. Transformer Sequence Encoder (Replaces LSTM)
|
| 29 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 30 |
+
d_model=self.d_model, nhead=4, dim_feedforward=256, batch_first=True
|
| 31 |
+
)
|
| 32 |
+
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
| 33 |
+
|
| 34 |
+
# 3. Social Attention (Target queries Neighbors)
|
| 35 |
+
self.social_attn = nn.MultiheadAttention(embed_dim=self.d_model, num_heads=4, batch_first=True)
|
| 36 |
+
|
| 37 |
+
self.K = 3 # number of future modes
|
| 38 |
+
|
| 39 |
+
# 4. GOAL-CONDITIONED ARCHITECTURE
|
| 40 |
+
# Base hidden context: Target (64) + Social (64) = 128
|
| 41 |
+
self.hidden_dim = 128
|
| 42 |
+
self.future_len = 12 # Now predicting 6 seconds into future
|
| 43 |
+
|
| 44 |
+
# Step A: Predict exactly K distinct endpoints (goals)
|
| 45 |
+
self.goal_head = nn.Sequential(
|
| 46 |
+
nn.Linear(self.hidden_dim, 64),
|
| 47 |
+
nn.ReLU(),
|
| 48 |
+
nn.Linear(64, self.K * 2) # X, Y for K goals
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Step B: Given the encoded context PLUS a specific Goal, draw the path to get there
|
| 52 |
+
self.traj_head = nn.Sequential(
|
| 53 |
+
nn.Linear(self.hidden_dim + 2, 128),
|
| 54 |
+
nn.ReLU(),
|
| 55 |
+
nn.Linear(128, self.future_len * 2) # 12 steps to reach the destination
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# 5. Probabilities of each mode
|
| 59 |
+
self.prob_head = nn.Linear(self.hidden_dim, self.K)
|
| 60 |
+
|
| 61 |
+
# ----------------------------
|
| 62 |
+
# SOCIAL POOLING
|
| 63 |
+
# ----------------------------
|
| 64 |
+
def social_pool(self, h_target, neighbor_h_list, device):
|
| 65 |
+
if len(neighbor_h_list) == 0:
|
| 66 |
+
return torch.zeros(self.d_model, device=device), None
|
| 67 |
+
|
| 68 |
+
# h_target: (64) -> query: (1, 1, 64)
|
| 69 |
+
query = h_target.unsqueeze(0).unsqueeze(0)
|
| 70 |
+
|
| 71 |
+
# neighbor_h_list: N x 64 -> key, value: (1, N, 64)
|
| 72 |
+
neighbor_h_tensor = torch.stack(neighbor_h_list).unsqueeze(0)
|
| 73 |
+
|
| 74 |
+
# apply attention
|
| 75 |
+
attn_output, attn_weights = self.social_attn(query, neighbor_h_tensor, neighbor_h_tensor)
|
| 76 |
+
|
| 77 |
+
return attn_output.squeeze(0).squeeze(0), attn_weights.squeeze(0)
|
| 78 |
+
|
| 79 |
+
# ----------------------------
|
| 80 |
+
# FORWARD PASS
|
| 81 |
+
# ----------------------------
|
| 82 |
+
def forward(self, x, neighbors):
|
| 83 |
+
"""
|
| 84 |
+
x: (B, 4, 7)
|
| 85 |
+
neighbors: list of length B
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
B = x.size(0)
|
| 89 |
+
device = x.device
|
| 90 |
+
|
| 91 |
+
# Encode main trajectory sequence with Transformer
|
| 92 |
+
x_emb = self.embed(x)
|
| 93 |
+
x_emb = self.pos_enc(x_emb)
|
| 94 |
+
enc_out = self.transformer_encoder(x_emb)
|
| 95 |
+
h = enc_out[:, -1, :] # Grab context from last timestep (B, 64)
|
| 96 |
+
|
| 97 |
+
final_h = []
|
| 98 |
+
batch_attn_weights = []
|
| 99 |
+
|
| 100 |
+
# Loop through batch to handle variable size neighbors
|
| 101 |
+
for i in range(B):
|
| 102 |
+
h_target = h[i] # (64)
|
| 103 |
+
|
| 104 |
+
neighbor_h_list = []
|
| 105 |
+
for n in neighbors[i]:
|
| 106 |
+
n_tensor = torch.tensor(n, dtype=torch.float32, device=device).unsqueeze(0)
|
| 107 |
+
|
| 108 |
+
n_emb = self.pos_enc(self.embed(n_tensor))
|
| 109 |
+
n_enc_out = self.transformer_encoder(n_emb)
|
| 110 |
+
|
| 111 |
+
neighbor_h_list.append(n_enc_out[0, -1, :]) # (64)
|
| 112 |
+
|
| 113 |
+
# Social attention pooling
|
| 114 |
+
h_social, attn_weights = self.social_pool(h_target, neighbor_h_list, device)
|
| 115 |
+
batch_attn_weights.append(attn_weights)
|
| 116 |
+
|
| 117 |
+
# Combine Target and Social context
|
| 118 |
+
h_combined = torch.cat([h_target, h_social], dim=0) # (128)
|
| 119 |
+
final_h.append(h_combined)
|
| 120 |
+
|
| 121 |
+
h_final = torch.stack(final_h) # (B, 128)
|
| 122 |
+
|
| 123 |
+
# GOAL-CONDITIONED LOGIC
|
| 124 |
+
# 1. Predict Goals (End-points at t=6)
|
| 125 |
+
goals = self.goal_head(h_final)
|
| 126 |
+
goals = goals.view(B, self.K, 2) # (B, K, 2)
|
| 127 |
+
|
| 128 |
+
# 2. Condition trajectories on the predicted goals
|
| 129 |
+
trajs = []
|
| 130 |
+
for k in range(self.K):
|
| 131 |
+
goal_k = goals[:, k, :] # Get the k-th destination (B, 2)
|
| 132 |
+
# Concat the base context array with the goal coordinate!
|
| 133 |
+
conditioned_context = torch.cat([h_final, goal_k], dim=1) # (B, 130)
|
| 134 |
+
|
| 135 |
+
# Predict the path given the condition
|
| 136 |
+
traj_k = self.traj_head(conditioned_context).view(B, 1, self.future_len, 2)
|
| 137 |
+
trajs.append(traj_k)
|
| 138 |
+
|
| 139 |
+
traj = torch.cat(trajs, dim=1) # (B, K, 12, 2)
|
| 140 |
+
|
| 141 |
+
# 3. Mode Probabilities
|
| 142 |
+
probs = self.prob_head(h_final)
|
| 143 |
+
probs = torch.softmax(probs, dim=1)
|
| 144 |
+
|
| 145 |
+
return traj, goals, probs, batch_attn_weights
|
backend/app/ml/model_fusion.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PositionalEncoding(nn.Module):
|
| 8 |
+
def __init__(self, d_model, max_len=100):
|
| 9 |
+
super().__init__()
|
| 10 |
+
pe = torch.zeros(max_len, d_model)
|
| 11 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 12 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 13 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 14 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 15 |
+
pe = pe.unsqueeze(0)
|
| 16 |
+
self.register_buffer('pe', pe)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return x + self.pe[:, :x.size(1), :]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TrajectoryTransformerFusion(nn.Module):
|
| 23 |
+
def __init__(self, fusion_dim=3):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.d_model = 64
|
| 26 |
+
|
| 27 |
+
# Base kinematic embedding from original model features.
|
| 28 |
+
self.base_embed = nn.Linear(7, self.d_model)
|
| 29 |
+
|
| 30 |
+
# Fusion branch: LiDAR/Radar strength features per timestep.
|
| 31 |
+
self.fusion_embed = nn.Linear(fusion_dim, self.d_model)
|
| 32 |
+
self.fusion_ln = nn.LayerNorm(self.d_model)
|
| 33 |
+
|
| 34 |
+
self.pos_enc = PositionalEncoding(self.d_model)
|
| 35 |
+
|
| 36 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 37 |
+
d_model=self.d_model,
|
| 38 |
+
nhead=4,
|
| 39 |
+
dim_feedforward=256,
|
| 40 |
+
batch_first=True,
|
| 41 |
+
)
|
| 42 |
+
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
| 43 |
+
|
| 44 |
+
self.social_attn = nn.MultiheadAttention(embed_dim=self.d_model, num_heads=4, batch_first=True)
|
| 45 |
+
|
| 46 |
+
self.K = 3
|
| 47 |
+
self.hidden_dim = 128
|
| 48 |
+
self.future_len = 12
|
| 49 |
+
|
| 50 |
+
self.goal_head = nn.Sequential(
|
| 51 |
+
nn.Linear(self.hidden_dim, 64),
|
| 52 |
+
nn.ReLU(),
|
| 53 |
+
nn.Linear(64, self.K * 2),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.traj_head = nn.Sequential(
|
| 57 |
+
nn.Linear(self.hidden_dim + 2, 128),
|
| 58 |
+
nn.ReLU(),
|
| 59 |
+
nn.Linear(128, self.future_len * 2),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
self.prob_head = nn.Linear(self.hidden_dim, self.K)
|
| 63 |
+
|
| 64 |
+
def load_from_base_checkpoint(self, ckpt_path, map_location='cpu'):
|
| 65 |
+
state = torch.load(ckpt_path, map_location=map_location)
|
| 66 |
+
|
| 67 |
+
remap = {}
|
| 68 |
+
for k, v in state.items():
|
| 69 |
+
if k.startswith('embed.'):
|
| 70 |
+
remap['base_embed.' + k[len('embed.'):]] = v
|
| 71 |
+
else:
|
| 72 |
+
remap[k] = v
|
| 73 |
+
|
| 74 |
+
missing, unexpected = self.load_state_dict(remap, strict=False)
|
| 75 |
+
return missing, unexpected
|
| 76 |
+
|
| 77 |
+
def social_pool(self, h_target, neighbor_h_list, device):
|
| 78 |
+
if len(neighbor_h_list) == 0:
|
| 79 |
+
return torch.zeros(self.d_model, device=device), None
|
| 80 |
+
|
| 81 |
+
query = h_target.unsqueeze(0).unsqueeze(0)
|
| 82 |
+
neighbor_h_tensor = torch.stack(neighbor_h_list).unsqueeze(0)
|
| 83 |
+
attn_output, attn_weights = self.social_attn(query, neighbor_h_tensor, neighbor_h_tensor)
|
| 84 |
+
return attn_output.squeeze(0).squeeze(0), attn_weights.squeeze(0)
|
| 85 |
+
|
| 86 |
+
def forward(self, x, neighbors, fusion_feats=None):
|
| 87 |
+
"""
|
| 88 |
+
x: (B, 4, 7)
|
| 89 |
+
neighbors: list length B, each element is list of neighbors with shape (4, 7)
|
| 90 |
+
fusion_feats: (B, 4, F) where F=3 [lidar_pts_norm, radar_pts_norm, sensor_strength]
|
| 91 |
+
"""
|
| 92 |
+
B = x.size(0)
|
| 93 |
+
device = x.device
|
| 94 |
+
|
| 95 |
+
x_emb = self.base_embed(x)
|
| 96 |
+
if fusion_feats is not None:
|
| 97 |
+
x_emb = self.fusion_ln(x_emb + self.fusion_embed(fusion_feats))
|
| 98 |
+
|
| 99 |
+
x_emb = self.pos_enc(x_emb)
|
| 100 |
+
enc_out = self.transformer_encoder(x_emb)
|
| 101 |
+
h = enc_out[:, -1, :]
|
| 102 |
+
|
| 103 |
+
final_h = []
|
| 104 |
+
batch_attn_weights = []
|
| 105 |
+
|
| 106 |
+
for i in range(B):
|
| 107 |
+
h_target = h[i]
|
| 108 |
+
neighbor_h_list = []
|
| 109 |
+
|
| 110 |
+
for n in neighbors[i]:
|
| 111 |
+
n_tensor = torch.as_tensor(n, dtype=torch.float32, device=device).unsqueeze(0)
|
| 112 |
+
n_emb = self.pos_enc(self.base_embed(n_tensor))
|
| 113 |
+
n_enc_out = self.transformer_encoder(n_emb)
|
| 114 |
+
neighbor_h_list.append(n_enc_out[0, -1, :])
|
| 115 |
+
|
| 116 |
+
h_social, attn_weights = self.social_pool(h_target, neighbor_h_list, device)
|
| 117 |
+
batch_attn_weights.append(attn_weights)
|
| 118 |
+
|
| 119 |
+
h_combined = torch.cat([h_target, h_social], dim=0)
|
| 120 |
+
final_h.append(h_combined)
|
| 121 |
+
|
| 122 |
+
h_final = torch.stack(final_h)
|
| 123 |
+
|
| 124 |
+
goals = self.goal_head(h_final).view(B, self.K, 2)
|
| 125 |
+
|
| 126 |
+
trajs = []
|
| 127 |
+
for k in range(self.K):
|
| 128 |
+
goal_k = goals[:, k, :]
|
| 129 |
+
conditioned_context = torch.cat([h_final, goal_k], dim=1)
|
| 130 |
+
traj_k = self.traj_head(conditioned_context).view(B, 1, self.future_len, 2)
|
| 131 |
+
trajs.append(traj_k)
|
| 132 |
+
|
| 133 |
+
traj = torch.cat(trajs, dim=1)
|
| 134 |
+
|
| 135 |
+
probs = self.prob_head(h_final)
|
| 136 |
+
probs = torch.softmax(probs, dim=1)
|
| 137 |
+
|
| 138 |
+
return traj, goals, probs, batch_attn_weights
|
backend/app/ml/sensor_fusion.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@lru_cache(maxsize=1)
|
| 10 |
+
def _load_sample_data_index(data_root: str, version: str):
|
| 11 |
+
sample_data_path = os.path.join(data_root, version, "sample_data.json")
|
| 12 |
+
with open(sample_data_path, "r", encoding="utf-8") as f:
|
| 13 |
+
records = json.load(f)
|
| 14 |
+
|
| 15 |
+
by_basename = {}
|
| 16 |
+
by_sample_token = defaultdict(list)
|
| 17 |
+
|
| 18 |
+
for rec in records:
|
| 19 |
+
basename = os.path.basename(rec.get("filename", ""))
|
| 20 |
+
if basename:
|
| 21 |
+
by_basename[basename] = rec
|
| 22 |
+
token = rec.get("sample_token")
|
| 23 |
+
if token:
|
| 24 |
+
by_sample_token[token].append(rec)
|
| 25 |
+
|
| 26 |
+
return by_basename, dict(by_sample_token)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@lru_cache(maxsize=1)
|
| 30 |
+
def _load_calibrated_sensor_index(data_root: str, version: str):
|
| 31 |
+
calib_path = os.path.join(data_root, version, "calibrated_sensor.json")
|
| 32 |
+
with open(calib_path, "r", encoding="utf-8") as f:
|
| 33 |
+
records = json.load(f)
|
| 34 |
+
return {r["token"]: r for r in records}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _channel_from_filename(rel_path: str) -> str:
|
| 38 |
+
parts = rel_path.replace("\\", "/").split("/")
|
| 39 |
+
if len(parts) >= 2:
|
| 40 |
+
return parts[1]
|
| 41 |
+
return ""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _quat_wxyz_to_rot(q):
|
| 45 |
+
# nuScenes stores quaternion as [w, x, y, z]
|
| 46 |
+
w, x, y, z = q
|
| 47 |
+
n = np.sqrt(w * w + x * x + y * y + z * z)
|
| 48 |
+
if n < 1e-12:
|
| 49 |
+
return np.eye(3, dtype=np.float32)
|
| 50 |
+
|
| 51 |
+
w, x, y, z = w / n, x / n, y / n, z / n
|
| 52 |
+
|
| 53 |
+
return np.array(
|
| 54 |
+
[
|
| 55 |
+
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
|
| 56 |
+
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
|
| 57 |
+
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
|
| 58 |
+
],
|
| 59 |
+
dtype=np.float32,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _transform_points_sensor_to_ego(points_xyz: np.ndarray, calib: dict):
|
| 64 |
+
if points_xyz.size == 0:
|
| 65 |
+
return points_xyz
|
| 66 |
+
|
| 67 |
+
if calib is None:
|
| 68 |
+
return points_xyz
|
| 69 |
+
|
| 70 |
+
rot = _quat_wxyz_to_rot(calib.get("rotation", [1.0, 0.0, 0.0, 0.0]))
|
| 71 |
+
t = np.asarray(calib.get("translation", [0.0, 0.0, 0.0]), dtype=np.float32)
|
| 72 |
+
|
| 73 |
+
# Row-vector form: p_ego = p_sensor * R^T + t
|
| 74 |
+
return points_xyz @ rot.T + t
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _transform_vel_sensor_to_ego(vel_xy: np.ndarray, calib: dict):
|
| 78 |
+
if vel_xy.size == 0:
|
| 79 |
+
return vel_xy
|
| 80 |
+
|
| 81 |
+
if calib is None:
|
| 82 |
+
return vel_xy
|
| 83 |
+
|
| 84 |
+
rot = _quat_wxyz_to_rot(calib.get("rotation", [1.0, 0.0, 0.0, 0.0]))
|
| 85 |
+
|
| 86 |
+
v_xyz = np.zeros((vel_xy.shape[0], 3), dtype=np.float32)
|
| 87 |
+
v_xyz[:, 0] = vel_xy[:, 0]
|
| 88 |
+
v_xyz[:, 1] = vel_xy[:, 1]
|
| 89 |
+
|
| 90 |
+
v_ego = v_xyz @ rot.T
|
| 91 |
+
return v_ego[:, :2].astype(np.float32)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _load_lidar_pcd_bin(file_path: str) -> np.ndarray:
|
| 95 |
+
arr = np.fromfile(file_path, dtype=np.float32)
|
| 96 |
+
if arr.size == 0:
|
| 97 |
+
return np.zeros((0, 3), dtype=np.float32)
|
| 98 |
+
|
| 99 |
+
# nuScenes lidar .pcd.bin is typically [x, y, z, intensity, ring_index]
|
| 100 |
+
if arr.size % 5 == 0:
|
| 101 |
+
pts = arr.reshape(-1, 5)[:, :3]
|
| 102 |
+
elif arr.size % 4 == 0:
|
| 103 |
+
pts = arr.reshape(-1, 4)[:, :3]
|
| 104 |
+
else:
|
| 105 |
+
usable = (arr.size // 3) * 3
|
| 106 |
+
pts = arr[:usable].reshape(-1, 3)
|
| 107 |
+
|
| 108 |
+
return pts.astype(np.float32)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _parse_pcd_binary(file_path: str):
|
| 112 |
+
# Minimal PCD parser for nuScenes radar files (DATA binary).
|
| 113 |
+
with open(file_path, "rb") as f:
|
| 114 |
+
raw = f.read()
|
| 115 |
+
|
| 116 |
+
header_end = raw.find(b"DATA binary")
|
| 117 |
+
if header_end == -1:
|
| 118 |
+
return {}
|
| 119 |
+
|
| 120 |
+
line_end = raw.find(b"\n", header_end)
|
| 121 |
+
if line_end == -1:
|
| 122 |
+
return {}
|
| 123 |
+
|
| 124 |
+
header_blob = raw[: line_end + 1].decode("utf-8", errors="ignore")
|
| 125 |
+
data_blob = raw[line_end + 1 :]
|
| 126 |
+
|
| 127 |
+
header = {}
|
| 128 |
+
for line in header_blob.splitlines():
|
| 129 |
+
line = line.strip()
|
| 130 |
+
if not line or line.startswith("#"):
|
| 131 |
+
continue
|
| 132 |
+
key, *vals = line.split()
|
| 133 |
+
header[key.upper()] = vals
|
| 134 |
+
|
| 135 |
+
fields = header.get("FIELDS", [])
|
| 136 |
+
sizes = [int(x) for x in header.get("SIZE", [])]
|
| 137 |
+
types = header.get("TYPE", [])
|
| 138 |
+
counts = [int(x) for x in header.get("COUNT", [])]
|
| 139 |
+
points = int(header.get("POINTS", ["0"])[0])
|
| 140 |
+
|
| 141 |
+
if not fields or not sizes or not types or not counts or points <= 0:
|
| 142 |
+
return {}
|
| 143 |
+
|
| 144 |
+
np_map = {
|
| 145 |
+
("F", 4): np.float32,
|
| 146 |
+
("F", 8): np.float64,
|
| 147 |
+
("I", 1): np.int8,
|
| 148 |
+
("I", 2): np.int16,
|
| 149 |
+
("I", 4): np.int32,
|
| 150 |
+
("U", 1): np.uint8,
|
| 151 |
+
("U", 2): np.uint16,
|
| 152 |
+
("U", 4): np.uint32,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
dtype_parts = []
|
| 156 |
+
expanded_fields = []
|
| 157 |
+
for field, size, typ, cnt in zip(fields, sizes, types, counts):
|
| 158 |
+
base = np_map.get((typ, size), np.float32)
|
| 159 |
+
if cnt == 1:
|
| 160 |
+
dtype_parts.append((field, base))
|
| 161 |
+
expanded_fields.append(field)
|
| 162 |
+
else:
|
| 163 |
+
for i in range(cnt):
|
| 164 |
+
name = f"{field}_{i}"
|
| 165 |
+
dtype_parts.append((name, base))
|
| 166 |
+
expanded_fields.append(name)
|
| 167 |
+
|
| 168 |
+
point_dtype = np.dtype(dtype_parts)
|
| 169 |
+
byte_need = point_dtype.itemsize * points
|
| 170 |
+
if len(data_blob) < byte_need:
|
| 171 |
+
return {}
|
| 172 |
+
|
| 173 |
+
rec = np.frombuffer(data_blob[:byte_need], dtype=point_dtype, count=points)
|
| 174 |
+
out = {}
|
| 175 |
+
for name in expanded_fields:
|
| 176 |
+
out[name] = rec[name]
|
| 177 |
+
return out
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _load_radar_pcd(file_path: str):
|
| 181 |
+
fields = _parse_pcd_binary(file_path)
|
| 182 |
+
if not fields:
|
| 183 |
+
return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 2), dtype=np.float32)
|
| 184 |
+
|
| 185 |
+
x = fields.get("x")
|
| 186 |
+
y = fields.get("y")
|
| 187 |
+
z = fields.get("z")
|
| 188 |
+
|
| 189 |
+
# Prefer compensated velocity fields when available.
|
| 190 |
+
vx = fields.get("vx_comp", fields.get("vx"))
|
| 191 |
+
vy = fields.get("vy_comp", fields.get("vy"))
|
| 192 |
+
|
| 193 |
+
if x is None or y is None or z is None:
|
| 194 |
+
return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 2), dtype=np.float32)
|
| 195 |
+
|
| 196 |
+
if vx is None:
|
| 197 |
+
vx = np.zeros_like(x)
|
| 198 |
+
if vy is None:
|
| 199 |
+
vy = np.zeros_like(y)
|
| 200 |
+
|
| 201 |
+
pts = np.stack([x, y, z], axis=1).astype(np.float32)
|
| 202 |
+
vel = np.stack([vx, vy], axis=1).astype(np.float32)
|
| 203 |
+
return pts, vel
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _ego_xyz_to_bev(points_xyz: np.ndarray):
|
| 207 |
+
# Ego frame: +x front, +y left, +z up
|
| 208 |
+
# BEV UI: +x right, +y forward
|
| 209 |
+
if points_xyz.size == 0:
|
| 210 |
+
return np.zeros((0, 2), dtype=np.float32)
|
| 211 |
+
x_bev = -points_xyz[:, 1]
|
| 212 |
+
y_bev = points_xyz[:, 0]
|
| 213 |
+
return np.stack([x_bev, y_bev], axis=1).astype(np.float32)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _ego_vel_to_bev(vxy_ego: np.ndarray):
|
| 217 |
+
if vxy_ego.size == 0:
|
| 218 |
+
return np.zeros((0, 2), dtype=np.float32)
|
| 219 |
+
vx_bev = -vxy_ego[:, 1]
|
| 220 |
+
vy_bev = vxy_ego[:, 0]
|
| 221 |
+
return np.stack([vx_bev, vy_bev], axis=1).astype(np.float32)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def load_fusion_for_cam_frame(cam_filename: str, data_root: str = "DataSet", version: str = "v1.0-mini"):
|
| 225 |
+
by_basename, by_sample = _load_sample_data_index(data_root, version)
|
| 226 |
+
calib_by_token = _load_calibrated_sensor_index(data_root, version)
|
| 227 |
+
|
| 228 |
+
basename = os.path.basename(cam_filename)
|
| 229 |
+
cam_rec = by_basename.get(basename)
|
| 230 |
+
if not cam_rec:
|
| 231 |
+
return None
|
| 232 |
+
|
| 233 |
+
sample_token = cam_rec.get("sample_token")
|
| 234 |
+
if not sample_token:
|
| 235 |
+
return None
|
| 236 |
+
|
| 237 |
+
related = by_sample.get(sample_token, [])
|
| 238 |
+
|
| 239 |
+
lidar_rec = None
|
| 240 |
+
radar_recs = {}
|
| 241 |
+
radar_channels = [
|
| 242 |
+
"RADAR_FRONT",
|
| 243 |
+
"RADAR_FRONT_LEFT",
|
| 244 |
+
"RADAR_FRONT_RIGHT",
|
| 245 |
+
"RADAR_BACK_LEFT",
|
| 246 |
+
"RADAR_BACK_RIGHT",
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
for rec in related:
|
| 250 |
+
rel = rec.get("filename", "")
|
| 251 |
+
if not rel.startswith("samples/"):
|
| 252 |
+
continue
|
| 253 |
+
|
| 254 |
+
ch = _channel_from_filename(rel)
|
| 255 |
+
if ch == "LIDAR_TOP":
|
| 256 |
+
lidar_rec = rec
|
| 257 |
+
elif ch in radar_channels:
|
| 258 |
+
radar_recs[ch] = rec
|
| 259 |
+
|
| 260 |
+
lidar_bev = np.zeros((0, 2), dtype=np.float32)
|
| 261 |
+
lidar_path = None
|
| 262 |
+
|
| 263 |
+
if lidar_rec is not None:
|
| 264 |
+
lidar_path = os.path.join(data_root, lidar_rec.get("filename", ""))
|
| 265 |
+
if os.path.exists(lidar_path):
|
| 266 |
+
lidar_xyz = _load_lidar_pcd_bin(lidar_path)
|
| 267 |
+
lidar_calib = calib_by_token.get(lidar_rec.get("calibrated_sensor_token"))
|
| 268 |
+
lidar_xyz_ego = _transform_points_sensor_to_ego(lidar_xyz, lidar_calib)
|
| 269 |
+
lidar_bev = _ego_xyz_to_bev(lidar_xyz_ego)
|
| 270 |
+
|
| 271 |
+
radar_xy_list = []
|
| 272 |
+
radar_vel_list = []
|
| 273 |
+
radar_paths = {}
|
| 274 |
+
radar_channel_counts = {}
|
| 275 |
+
|
| 276 |
+
for ch in radar_channels:
|
| 277 |
+
rec = radar_recs.get(ch)
|
| 278 |
+
if rec is None:
|
| 279 |
+
continue
|
| 280 |
+
|
| 281 |
+
p = os.path.join(data_root, rec.get("filename", ""))
|
| 282 |
+
radar_paths[ch] = p
|
| 283 |
+
if not os.path.exists(p):
|
| 284 |
+
radar_channel_counts[ch] = 0
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
radar_xyz, radar_vel_xy = _load_radar_pcd(p)
|
| 288 |
+
radar_calib = calib_by_token.get(rec.get("calibrated_sensor_token"))
|
| 289 |
+
|
| 290 |
+
radar_xyz_ego = _transform_points_sensor_to_ego(radar_xyz, radar_calib)
|
| 291 |
+
radar_vel_ego = _transform_vel_sensor_to_ego(radar_vel_xy, radar_calib)
|
| 292 |
+
|
| 293 |
+
radar_bev = _ego_xyz_to_bev(radar_xyz_ego)
|
| 294 |
+
radar_vel_bev = _ego_vel_to_bev(radar_vel_ego)
|
| 295 |
+
|
| 296 |
+
if radar_bev.size > 0:
|
| 297 |
+
m_ch = (
|
| 298 |
+
(radar_bev[:, 1] > -20.0)
|
| 299 |
+
& (radar_bev[:, 1] < 100.0)
|
| 300 |
+
& (radar_bev[:, 0] > -70.0)
|
| 301 |
+
& (radar_bev[:, 0] < 70.0)
|
| 302 |
+
)
|
| 303 |
+
radar_bev = radar_bev[m_ch]
|
| 304 |
+
radar_vel_bev = radar_vel_bev[m_ch]
|
| 305 |
+
|
| 306 |
+
radar_channel_counts[ch] = int(radar_bev.shape[0])
|
| 307 |
+
|
| 308 |
+
if radar_bev.size > 0:
|
| 309 |
+
radar_xy_list.append(radar_bev)
|
| 310 |
+
radar_vel_list.append(radar_vel_bev)
|
| 311 |
+
|
| 312 |
+
if radar_xy_list:
|
| 313 |
+
radar_bev_all = np.concatenate(radar_xy_list, axis=0).astype(np.float32)
|
| 314 |
+
radar_vel_all = np.concatenate(radar_vel_list, axis=0).astype(np.float32)
|
| 315 |
+
else:
|
| 316 |
+
radar_bev_all = np.zeros((0, 2), dtype=np.float32)
|
| 317 |
+
radar_vel_all = np.zeros((0, 2), dtype=np.float32)
|
| 318 |
+
|
| 319 |
+
# Keep interaction region for live BEV visualization.
|
| 320 |
+
if lidar_bev.size > 0:
|
| 321 |
+
m = (
|
| 322 |
+
(lidar_bev[:, 1] > -15.0)
|
| 323 |
+
& (lidar_bev[:, 1] < 85.0)
|
| 324 |
+
& (lidar_bev[:, 0] > -60.0)
|
| 325 |
+
& (lidar_bev[:, 0] < 60.0)
|
| 326 |
+
)
|
| 327 |
+
lidar_bev = lidar_bev[m]
|
| 328 |
+
|
| 329 |
+
if radar_bev_all.size > 0:
|
| 330 |
+
m = (
|
| 331 |
+
(radar_bev_all[:, 1] > -20.0)
|
| 332 |
+
& (radar_bev_all[:, 1] < 100.0)
|
| 333 |
+
& (radar_bev_all[:, 0] > -70.0)
|
| 334 |
+
& (radar_bev_all[:, 0] < 70.0)
|
| 335 |
+
)
|
| 336 |
+
radar_bev_all = radar_bev_all[m]
|
| 337 |
+
if radar_vel_all.shape[0] == m.shape[0]:
|
| 338 |
+
radar_vel_all = radar_vel_all[m]
|
| 339 |
+
|
| 340 |
+
return {
|
| 341 |
+
"sample_token": sample_token,
|
| 342 |
+
"lidar_xy": lidar_bev,
|
| 343 |
+
"radar_xy": radar_bev_all,
|
| 344 |
+
"radar_vel": radar_vel_all,
|
| 345 |
+
"lidar_path": lidar_path,
|
| 346 |
+
"radar_path": radar_paths.get("RADAR_FRONT"),
|
| 347 |
+
"radar_paths": radar_paths,
|
| 348 |
+
"radar_channel_counts": radar_channel_counts,
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def radar_stabilize_motion(tracked_agents, fusion_data, dt_seconds: float = 0.5):
|
| 353 |
+
if not fusion_data:
|
| 354 |
+
return tracked_agents
|
| 355 |
+
|
| 356 |
+
radar_xy = fusion_data.get("radar_xy")
|
| 357 |
+
radar_vel = fusion_data.get("radar_vel")
|
| 358 |
+
|
| 359 |
+
if radar_xy is None or radar_vel is None or len(radar_xy) == 0:
|
| 360 |
+
return tracked_agents
|
| 361 |
+
|
| 362 |
+
stabilized = []
|
| 363 |
+
|
| 364 |
+
for agent in tracked_agents:
|
| 365 |
+
if agent.get("type") not in ["Person", "Bicycle", "Car", "Truck", "Bus", "Motorcycle"]:
|
| 366 |
+
stabilized.append(agent)
|
| 367 |
+
continue
|
| 368 |
+
|
| 369 |
+
x_curr, y_curr = agent["history"][-1]
|
| 370 |
+
d = np.hypot(radar_xy[:, 0] - x_curr, radar_xy[:, 1] - y_curr)
|
| 371 |
+
near_idx = np.where(d < 3.0)[0]
|
| 372 |
+
|
| 373 |
+
if near_idx.size > 0:
|
| 374 |
+
rv = radar_vel[near_idx].mean(axis=0)
|
| 375 |
+
radar_dx = float(rv[0] * dt_seconds)
|
| 376 |
+
radar_dy = float(rv[1] * dt_seconds)
|
| 377 |
+
|
| 378 |
+
cam_dx = float(agent.get("dx", 0.0))
|
| 379 |
+
cam_dy = float(agent.get("dy", 0.0))
|
| 380 |
+
|
| 381 |
+
fused_dx = 0.7 * cam_dx + 0.3 * radar_dx
|
| 382 |
+
fused_dy = 0.7 * cam_dy + 0.3 * radar_dy
|
| 383 |
+
|
| 384 |
+
x4, y4 = x_curr, y_curr
|
| 385 |
+
h3 = (x4 - 3.0 * fused_dx, y4 - 3.0 * fused_dy)
|
| 386 |
+
h2 = (x4 - 2.0 * fused_dx, y4 - 2.0 * fused_dy)
|
| 387 |
+
h1 = (x4 - 1.0 * fused_dx, y4 - 1.0 * fused_dy)
|
| 388 |
+
|
| 389 |
+
agent = dict(agent)
|
| 390 |
+
agent["dx"] = fused_dx
|
| 391 |
+
agent["dy"] = fused_dy
|
| 392 |
+
agent["history"] = [h3, h2, h1, (x4, y4)]
|
| 393 |
+
|
| 394 |
+
stabilized.append(agent)
|
| 395 |
+
|
| 396 |
+
return stabilized
|
backend/app/schemas.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Point2D(BaseModel):
|
| 9 |
+
x: float
|
| 10 |
+
y: float
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AgentState(BaseModel):
|
| 14 |
+
id: int
|
| 15 |
+
type: str
|
| 16 |
+
raw_label: str | None = None
|
| 17 |
+
history: list[Point2D] = Field(default_factory=list)
|
| 18 |
+
predictions: list[list[Point2D]] = Field(default_factory=list)
|
| 19 |
+
probabilities: list[float] = Field(default_factory=list)
|
| 20 |
+
is_target: bool = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LiveFusionRequest(BaseModel):
|
| 24 |
+
anchor_idx: int = Field(default=3, ge=0)
|
| 25 |
+
score_threshold: float = Field(default=0.35, ge=0.0, le=1.0)
|
| 26 |
+
tracking_gate_px: float = Field(default=130.0, ge=1.0, le=500.0)
|
| 27 |
+
use_pose: bool = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PredictionResponse(BaseModel):
|
| 31 |
+
mode: str
|
| 32 |
+
target_track_id: int | None = None
|
| 33 |
+
agents: list[AgentState] = Field(default_factory=list)
|
| 34 |
+
meta: dict[str, Any] = Field(default_factory=dict)
|
| 35 |
+
detections: dict[str, Any] | None = None
|
| 36 |
+
sensors: dict[str, Any] | None = None
|
| 37 |
+
scene_geometry: dict[str, Any] | None = None
|
| 38 |
+
|
| 39 |
+
model_config = ConfigDict(extra="allow")
|
backend/app/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Service layer for model and perception pipelines."""
|
backend/app/services/pipeline.py
ADDED
|
@@ -0,0 +1,1255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import threading
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image
|
| 16 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import cv2
|
| 20 |
+
except Exception:
|
| 21 |
+
cv2 = None
|
| 22 |
+
|
| 23 |
+
from torchvision.models.detection import (
|
| 24 |
+
FasterRCNN_ResNet50_FPN_Weights,
|
| 25 |
+
KeypointRCNN_ResNet50_FPN_Weights,
|
| 26 |
+
fasterrcnn_resnet50_fpn,
|
| 27 |
+
keypointrcnn_resnet50_fpn,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 31 |
+
from ..ml.inference import USING_FUSION_MODEL, predict as trajectory_predict
|
| 32 |
+
from ..ml.sensor_fusion import load_fusion_for_cam_frame, radar_stabilize_motion
|
| 33 |
+
|
| 34 |
+
COCO_TO_LABEL = {
|
| 35 |
+
1: "person",
|
| 36 |
+
2: "bicycle",
|
| 37 |
+
3: "car",
|
| 38 |
+
4: "motorcycle",
|
| 39 |
+
6: "bus",
|
| 40 |
+
8: "truck",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
VRU_LABELS = {"person", "bicycle", "motorcycle"}
|
| 44 |
+
VEHICLE_LABELS = {"car", "bus", "truck"}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@lru_cache(maxsize=1)
|
| 48 |
+
def _load_hd_map_indices(data_root: str, version: str) -> dict[str, Any]:
|
| 49 |
+
base = Path(data_root) / version
|
| 50 |
+
|
| 51 |
+
with open(base / "sample.json", "r", encoding="utf-8") as f:
|
| 52 |
+
samples = json.load(f)
|
| 53 |
+
with open(base / "sample_data.json", "r", encoding="utf-8") as f:
|
| 54 |
+
sample_data = json.load(f)
|
| 55 |
+
with open(base / "scene.json", "r", encoding="utf-8") as f:
|
| 56 |
+
scenes = json.load(f)
|
| 57 |
+
with open(base / "log.json", "r", encoding="utf-8") as f:
|
| 58 |
+
logs = json.load(f)
|
| 59 |
+
with open(base / "map.json", "r", encoding="utf-8") as f:
|
| 60 |
+
maps = json.load(f)
|
| 61 |
+
with open(base / "ego_pose.json", "r", encoding="utf-8") as f:
|
| 62 |
+
ego_poses = json.load(f)
|
| 63 |
+
|
| 64 |
+
sample_by_token = {r["token"]: r for r in samples}
|
| 65 |
+
scene_by_token = {r["token"]: r for r in scenes}
|
| 66 |
+
log_by_token = {r["token"]: r for r in logs}
|
| 67 |
+
ego_pose_by_token = {r["token"]: r for r in ego_poses}
|
| 68 |
+
|
| 69 |
+
sample_data_by_sample: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
| 70 |
+
sample_data_by_basename: dict[str, dict[str, Any]] = {}
|
| 71 |
+
for rec in sample_data:
|
| 72 |
+
sample_token = rec.get("sample_token")
|
| 73 |
+
if sample_token:
|
| 74 |
+
sample_data_by_sample[str(sample_token)].append(rec)
|
| 75 |
+
|
| 76 |
+
filename = rec.get("filename")
|
| 77 |
+
if filename:
|
| 78 |
+
sample_data_by_basename[Path(str(filename)).name] = rec
|
| 79 |
+
|
| 80 |
+
map_by_log_token: dict[str, dict[str, Any]] = {}
|
| 81 |
+
for rec in maps:
|
| 82 |
+
for log_token in rec.get("log_tokens", []):
|
| 83 |
+
map_by_log_token[str(log_token)] = rec
|
| 84 |
+
|
| 85 |
+
return {
|
| 86 |
+
"sample_by_token": sample_by_token,
|
| 87 |
+
"scene_by_token": scene_by_token,
|
| 88 |
+
"log_by_token": log_by_token,
|
| 89 |
+
"map_by_log_token": map_by_log_token,
|
| 90 |
+
"sample_data_by_sample": dict(sample_data_by_sample),
|
| 91 |
+
"sample_data_by_basename": sample_data_by_basename,
|
| 92 |
+
"ego_pose_by_token": ego_pose_by_token,
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@lru_cache(maxsize=8)
|
| 97 |
+
def _get_map_size(map_path: str) -> tuple[int, int] | None:
|
| 98 |
+
p = Path(map_path)
|
| 99 |
+
if not p.exists():
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
with Image.open(p) as img:
|
| 103 |
+
w, h = img.size
|
| 104 |
+
return int(w), int(h)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _load_map_crop_gray(map_path: str, left: int, top: int, right: int, bottom: int) -> np.ndarray | None:
|
| 108 |
+
p = Path(map_path)
|
| 109 |
+
if not p.exists():
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
if right <= left or bottom <= top:
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
with Image.open(p) as img:
|
| 116 |
+
crop = img.crop((int(left), int(top), int(right), int(bottom))).convert("L")
|
| 117 |
+
return np.asarray(crop, dtype=np.uint8)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _quat_wxyz_to_yaw(q: list[float] | tuple[float, float, float, float]) -> float:
|
| 121 |
+
if len(q) != 4:
|
| 122 |
+
return 0.0
|
| 123 |
+
|
| 124 |
+
w, x, y, z = [float(v) for v in q]
|
| 125 |
+
n = math.sqrt(w * w + x * x + y * y + z * z)
|
| 126 |
+
if n < 1e-12:
|
| 127 |
+
return 0.0
|
| 128 |
+
|
| 129 |
+
w, x, y, z = w / n, x / n, y / n, z / n
|
| 130 |
+
siny_cosp = 2.0 * (w * z + x * y)
|
| 131 |
+
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
| 132 |
+
return float(math.atan2(siny_cosp, cosy_cosp))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class TrajectoryPipeline:
|
| 136 |
+
def __init__(self, repo_root: Path | None = None):
|
| 137 |
+
self.repo_root = Path(repo_root) if repo_root else REPO_ROOT
|
| 138 |
+
self.data_root = self.repo_root / "DataSet"
|
| 139 |
+
self._model_lock = threading.Lock()
|
| 140 |
+
self._models: dict[str, Any] | None = None
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def using_fusion_model(self) -> bool:
|
| 144 |
+
return bool(USING_FUSION_MODEL)
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def normalize_probs(probs: list[float] | np.ndarray) -> list[float]:
|
| 148 |
+
arr = np.asarray(probs, dtype=float)
|
| 149 |
+
arr = np.clip(arr, 1e-6, None)
|
| 150 |
+
arr = arr / arr.sum()
|
| 151 |
+
return arr.tolist()
|
| 152 |
+
|
| 153 |
+
@staticmethod
|
| 154 |
+
def coco_kind(label_name: str | None) -> str | None:
|
| 155 |
+
if label_name in VRU_LABELS:
|
| 156 |
+
return "pedestrian"
|
| 157 |
+
if label_name in VEHICLE_LABELS:
|
| 158 |
+
return "vehicle"
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
@staticmethod
|
| 162 |
+
def iou_xyxy(box_a: list[float], box_b: list[float]) -> float:
|
| 163 |
+
ax1, ay1, ax2, ay2 = box_a
|
| 164 |
+
bx1, by1, bx2, by2 = box_b
|
| 165 |
+
|
| 166 |
+
ix1 = max(ax1, bx1)
|
| 167 |
+
iy1 = max(ay1, by1)
|
| 168 |
+
ix2 = min(ax2, bx2)
|
| 169 |
+
iy2 = min(ay2, by2)
|
| 170 |
+
|
| 171 |
+
iw = max(0.0, ix2 - ix1)
|
| 172 |
+
ih = max(0.0, iy2 - iy1)
|
| 173 |
+
inter = iw * ih
|
| 174 |
+
|
| 175 |
+
area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
|
| 176 |
+
area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
|
| 177 |
+
union = area_a + area_b - inter
|
| 178 |
+
|
| 179 |
+
if union <= 1e-9:
|
| 180 |
+
return 0.0
|
| 181 |
+
return inter / union
|
| 182 |
+
|
| 183 |
+
@staticmethod
|
| 184 |
+
def pixel_to_bev(center_x: float, bottom_y: float, width: int, height: int) -> tuple[float, float]:
|
| 185 |
+
x_div = max(1.0, width / 80.0)
|
| 186 |
+
y_div = max(1.0, height / 50.0)
|
| 187 |
+
|
| 188 |
+
x_m = (center_x - 0.5 * width) / x_div
|
| 189 |
+
y_m = (bottom_y - 0.58 * height) / y_div
|
| 190 |
+
return float(x_m), float(y_m)
|
| 191 |
+
|
| 192 |
+
def list_channel_image_paths(self, channel: str) -> list[Path]:
|
| 193 |
+
base = self.data_root / "samples" / channel
|
| 194 |
+
if not base.exists():
|
| 195 |
+
return []
|
| 196 |
+
return sorted(base.glob("*.jpg"))
|
| 197 |
+
|
| 198 |
+
@staticmethod
|
| 199 |
+
def load_image_array(image_path: str | Path) -> np.ndarray:
|
| 200 |
+
return np.asarray(Image.open(image_path).convert("RGB"))
|
| 201 |
+
|
| 202 |
+
@staticmethod
|
| 203 |
+
def _clip_bev(x: float, y: float) -> tuple[float, float]:
|
| 204 |
+
return float(np.clip(x, -40.0, 40.0)), float(np.clip(y, -14.0, 62.0))
|
| 205 |
+
|
| 206 |
+
def _poly_px_to_bev_points(
|
| 207 |
+
self,
|
| 208 |
+
polygon_px: list[tuple[float, float]],
|
| 209 |
+
width: int,
|
| 210 |
+
height: int,
|
| 211 |
+
) -> list[dict[str, float]]:
|
| 212 |
+
out = []
|
| 213 |
+
for px, py in polygon_px:
|
| 214 |
+
bx, by = self.pixel_to_bev(float(px), float(py), width, height)
|
| 215 |
+
bx, by = self._clip_bev(bx, by)
|
| 216 |
+
out.append({"x": bx, "y": by})
|
| 217 |
+
return out
|
| 218 |
+
|
| 219 |
+
def _project_detection_elements(
|
| 220 |
+
self,
|
| 221 |
+
detections: list[dict[str, Any]],
|
| 222 |
+
width: int,
|
| 223 |
+
height: int,
|
| 224 |
+
) -> list[dict[str, Any]]:
|
| 225 |
+
elements = []
|
| 226 |
+
|
| 227 |
+
for det in detections:
|
| 228 |
+
box = det.get("box")
|
| 229 |
+
if box is None or len(box) != 4:
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
x1, y1, x2, y2 = [float(v) for v in box]
|
| 233 |
+
cx = 0.5 * (x1 + x2)
|
| 234 |
+
bx, by = self.pixel_to_bev(cx, y2, width, height)
|
| 235 |
+
bx, by = self._clip_bev(bx, by)
|
| 236 |
+
|
| 237 |
+
kind = str(det.get("kind", "vehicle"))
|
| 238 |
+
box_w_px = max(1.0, x2 - x1)
|
| 239 |
+
half_w = float(np.clip((box_w_px / max(1.0, width)) * 12.0, 0.25, 2.2))
|
| 240 |
+
length = 0.9 if kind == "pedestrian" else 2.1
|
| 241 |
+
|
| 242 |
+
polygon = [
|
| 243 |
+
{"x": bx - half_w, "y": by - 0.25 * length},
|
| 244 |
+
{"x": bx + half_w, "y": by - 0.25 * length},
|
| 245 |
+
{"x": bx + half_w, "y": by + length},
|
| 246 |
+
{"x": bx - half_w, "y": by + length},
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
elements.append(
|
| 250 |
+
{
|
| 251 |
+
"kind": kind,
|
| 252 |
+
"track_id": det.get("track_id"),
|
| 253 |
+
"score": float(det.get("score", 0.0)),
|
| 254 |
+
"polygon": polygon,
|
| 255 |
+
}
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
return elements[:24]
|
| 259 |
+
|
| 260 |
+
def extract_scene_geometry(
|
| 261 |
+
self,
|
| 262 |
+
image_arr: np.ndarray,
|
| 263 |
+
detections: list[dict[str, Any]] | None,
|
| 264 |
+
) -> dict[str, Any] | None:
|
| 265 |
+
if image_arr is None:
|
| 266 |
+
return None
|
| 267 |
+
|
| 268 |
+
h, w = image_arr.shape[:2]
|
| 269 |
+
if h < 20 or w < 20:
|
| 270 |
+
return None
|
| 271 |
+
|
| 272 |
+
if detections is None:
|
| 273 |
+
detections = []
|
| 274 |
+
|
| 275 |
+
roi_px = [
|
| 276 |
+
(0.08 * w, h - 1),
|
| 277 |
+
(0.42 * w, 0.56 * h),
|
| 278 |
+
(0.58 * w, 0.56 * h),
|
| 279 |
+
(0.92 * w, h - 1),
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
scene = {
|
| 283 |
+
"source": "camera-derived" if cv2 is not None else "heuristic-fallback",
|
| 284 |
+
"quality": 0.0,
|
| 285 |
+
"road_polygon": self._poly_px_to_bev_points(roi_px, w, h),
|
| 286 |
+
"lane_lines": [],
|
| 287 |
+
"elements": self._project_detection_elements(detections, w, h),
|
| 288 |
+
"image_size": {"width": int(w), "height": int(h)},
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
if cv2 is None:
|
| 292 |
+
scene["quality"] = 0.12
|
| 293 |
+
return scene
|
| 294 |
+
|
| 295 |
+
gray = cv2.cvtColor(image_arr, cv2.COLOR_RGB2GRAY)
|
| 296 |
+
blur = cv2.GaussianBlur(gray, (5, 5), 0)
|
| 297 |
+
edges = cv2.Canny(blur, 60, 160)
|
| 298 |
+
|
| 299 |
+
roi_mask = np.zeros_like(edges)
|
| 300 |
+
roi_poly = np.array([
|
| 301 |
+
[
|
| 302 |
+
(int(0.08 * w), h - 1),
|
| 303 |
+
(int(0.42 * w), int(0.56 * h)),
|
| 304 |
+
(int(0.58 * w), int(0.56 * h)),
|
| 305 |
+
(int(0.92 * w), h - 1),
|
| 306 |
+
]
|
| 307 |
+
], dtype=np.int32)
|
| 308 |
+
cv2.fillPoly(roi_mask, roi_poly, 255)
|
| 309 |
+
masked_edges = cv2.bitwise_and(edges, roi_mask)
|
| 310 |
+
|
| 311 |
+
lines = cv2.HoughLinesP(
|
| 312 |
+
masked_edges,
|
| 313 |
+
rho=1,
|
| 314 |
+
theta=np.pi / 180.0,
|
| 315 |
+
threshold=max(24, int(0.03 * w)),
|
| 316 |
+
minLineLength=max(28, int(0.05 * w)),
|
| 317 |
+
maxLineGap=max(22, int(0.03 * w)),
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
lane_candidates: list[tuple[float, list[dict[str, float]]]] = []
|
| 321 |
+
if lines is not None:
|
| 322 |
+
for line in lines:
|
| 323 |
+
x1, y1, x2, y2 = [int(v) for v in line[0]]
|
| 324 |
+
dx = float(x2 - x1)
|
| 325 |
+
dy = float(y2 - y1)
|
| 326 |
+
length = float(np.hypot(dx, dy))
|
| 327 |
+
|
| 328 |
+
if length < max(24.0, 0.04 * w):
|
| 329 |
+
continue
|
| 330 |
+
if abs(dy) < 8.0:
|
| 331 |
+
continue
|
| 332 |
+
|
| 333 |
+
slope = dy / dx if abs(dx) > 1e-6 else np.sign(dy) * 1e6
|
| 334 |
+
if abs(slope) < 0.35:
|
| 335 |
+
continue
|
| 336 |
+
|
| 337 |
+
p1x, p1y = self.pixel_to_bev(float(x1), float(y1), w, h)
|
| 338 |
+
p2x, p2y = self.pixel_to_bev(float(x2), float(y2), w, h)
|
| 339 |
+
p1x, p1y = self._clip_bev(p1x, p1y)
|
| 340 |
+
p2x, p2y = self._clip_bev(p2x, p2y)
|
| 341 |
+
|
| 342 |
+
lane_candidates.append(
|
| 343 |
+
(
|
| 344 |
+
length,
|
| 345 |
+
[
|
| 346 |
+
{"x": p1x, "y": p1y},
|
| 347 |
+
{"x": p2x, "y": p2y},
|
| 348 |
+
],
|
| 349 |
+
)
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
lane_candidates.sort(key=lambda item: item[0], reverse=True)
|
| 353 |
+
scene["lane_lines"] = [item[1] for item in lane_candidates[:10]]
|
| 354 |
+
|
| 355 |
+
edge_density = float(masked_edges.mean() / 255.0)
|
| 356 |
+
lane_quality = min(1.0, len(scene["lane_lines"]) / 6.0)
|
| 357 |
+
edge_quality = min(1.0, edge_density * 8.0)
|
| 358 |
+
scene["quality"] = float(np.clip(0.55 * lane_quality + 0.45 * edge_quality, 0.0, 1.0))
|
| 359 |
+
|
| 360 |
+
return scene
|
| 361 |
+
|
| 362 |
+
def lookup_sample_token_for_filename(self, filename: str | None) -> str | None:
|
| 363 |
+
if not filename:
|
| 364 |
+
return None
|
| 365 |
+
|
| 366 |
+
try:
|
| 367 |
+
idx = _load_hd_map_indices(str(self.data_root), "v1.0-mini")
|
| 368 |
+
except Exception:
|
| 369 |
+
return None
|
| 370 |
+
|
| 371 |
+
rec = idx["sample_data_by_basename"].get(Path(filename).name)
|
| 372 |
+
if not rec:
|
| 373 |
+
return None
|
| 374 |
+
|
| 375 |
+
sample_token = rec.get("sample_token")
|
| 376 |
+
if not sample_token:
|
| 377 |
+
return None
|
| 378 |
+
|
| 379 |
+
return str(sample_token)
|
| 380 |
+
|
| 381 |
+
def _build_hd_map_layer(
|
| 382 |
+
self,
|
| 383 |
+
sample_token: str,
|
| 384 |
+
radius_m: float = 45.0,
|
| 385 |
+
out_size: int = 480,
|
| 386 |
+
) -> dict[str, Any] | None:
|
| 387 |
+
try:
|
| 388 |
+
idx = _load_hd_map_indices(str(self.data_root), "v1.0-mini")
|
| 389 |
+
except Exception:
|
| 390 |
+
return None
|
| 391 |
+
|
| 392 |
+
sample_rec = idx["sample_by_token"].get(sample_token)
|
| 393 |
+
if sample_rec is None:
|
| 394 |
+
return None
|
| 395 |
+
|
| 396 |
+
sample_data_list = idx["sample_data_by_sample"].get(sample_token, [])
|
| 397 |
+
if len(sample_data_list) == 0:
|
| 398 |
+
return None
|
| 399 |
+
|
| 400 |
+
ref_rec = next(
|
| 401 |
+
(r for r in sample_data_list if "LIDAR_TOP" in str(r.get("filename", ""))),
|
| 402 |
+
sample_data_list[0],
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
ego_pose = idx["ego_pose_by_token"].get(str(ref_rec.get("ego_pose_token", "")))
|
| 406 |
+
if ego_pose is None:
|
| 407 |
+
return None
|
| 408 |
+
|
| 409 |
+
scene_rec = idx["scene_by_token"].get(str(sample_rec.get("scene_token", "")))
|
| 410 |
+
if scene_rec is None:
|
| 411 |
+
return None
|
| 412 |
+
|
| 413 |
+
log_token = str(scene_rec.get("log_token", ""))
|
| 414 |
+
map_rec = idx["map_by_log_token"].get(log_token)
|
| 415 |
+
if map_rec is None:
|
| 416 |
+
return None
|
| 417 |
+
|
| 418 |
+
map_rel = str(map_rec.get("filename", ""))
|
| 419 |
+
map_path = self.data_root / map_rel
|
| 420 |
+
map_size = _get_map_size(str(map_path))
|
| 421 |
+
if map_size is None:
|
| 422 |
+
return None
|
| 423 |
+
map_w, map_h = map_size
|
| 424 |
+
|
| 425 |
+
translation = ego_pose.get("translation", [0.0, 0.0, 0.0])
|
| 426 |
+
ego_x = float(translation[0])
|
| 427 |
+
ego_y = float(translation[1])
|
| 428 |
+
yaw = _quat_wxyz_to_yaw(ego_pose.get("rotation", [1.0, 0.0, 0.0, 0.0]))
|
| 429 |
+
|
| 430 |
+
# nuScenes semantic prior raster masks use 0.1m per pixel.
|
| 431 |
+
ppm = 10.0
|
| 432 |
+
x_right = np.linspace(-radius_m, radius_m, out_size, dtype=np.float32)
|
| 433 |
+
y_forward = np.linspace(radius_m, -radius_m, out_size, dtype=np.float32)
|
| 434 |
+
x_grid, y_grid = np.meshgrid(x_right, y_forward)
|
| 435 |
+
|
| 436 |
+
gx = ego_x + np.cos(yaw) * y_grid + np.sin(yaw) * x_grid
|
| 437 |
+
gy = ego_y + np.sin(yaw) * y_grid - np.cos(yaw) * x_grid
|
| 438 |
+
|
| 439 |
+
px_opts = [gx * ppm, (map_w - 1.0) - gx * ppm]
|
| 440 |
+
py_opts = [gy * ppm, (map_h - 1.0) - gy * ppm]
|
| 441 |
+
|
| 442 |
+
best_px = None
|
| 443 |
+
best_py = None
|
| 444 |
+
best_valid_ratio = -1.0
|
| 445 |
+
for px in px_opts:
|
| 446 |
+
for py in py_opts:
|
| 447 |
+
valid = (px >= 0.0) & (px <= (map_w - 1.0)) & (py >= 0.0) & (py <= (map_h - 1.0))
|
| 448 |
+
ratio = float(valid.mean())
|
| 449 |
+
if ratio > best_valid_ratio:
|
| 450 |
+
best_valid_ratio = ratio
|
| 451 |
+
best_px = px
|
| 452 |
+
best_py = py
|
| 453 |
+
|
| 454 |
+
if best_px is None or best_py is None or best_valid_ratio < 0.15:
|
| 455 |
+
return None
|
| 456 |
+
|
| 457 |
+
crop_left = int(max(0, math.floor(float(best_px.min())) - 2))
|
| 458 |
+
crop_top = int(max(0, math.floor(float(best_py.min())) - 2))
|
| 459 |
+
crop_right = int(min(map_w, math.ceil(float(best_px.max())) + 3))
|
| 460 |
+
crop_bottom = int(min(map_h, math.ceil(float(best_py.max())) + 3))
|
| 461 |
+
|
| 462 |
+
map_crop = _load_map_crop_gray(str(map_path), crop_left, crop_top, crop_right, crop_bottom)
|
| 463 |
+
if map_crop is None or map_crop.size == 0:
|
| 464 |
+
return None
|
| 465 |
+
|
| 466 |
+
remap_x = best_px - float(crop_left)
|
| 467 |
+
remap_y = best_py - float(crop_top)
|
| 468 |
+
|
| 469 |
+
if cv2 is not None:
|
| 470 |
+
patch = cv2.remap(
|
| 471 |
+
map_crop,
|
| 472 |
+
remap_x.astype(np.float32),
|
| 473 |
+
remap_y.astype(np.float32),
|
| 474 |
+
interpolation=cv2.INTER_LINEAR,
|
| 475 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 476 |
+
borderValue=0,
|
| 477 |
+
)
|
| 478 |
+
patch_u8 = patch.astype(np.uint8)
|
| 479 |
+
else:
|
| 480 |
+
crop_h, crop_w = map_crop.shape[:2]
|
| 481 |
+
xi = np.clip(np.round(remap_x).astype(np.int32), 0, crop_w - 1)
|
| 482 |
+
yi = np.clip(np.round(remap_y).astype(np.int32), 0, crop_h - 1)
|
| 483 |
+
patch_u8 = map_crop[yi, xi]
|
| 484 |
+
|
| 485 |
+
drivable = patch_u8 > 96
|
| 486 |
+
strong = patch_u8 > 170
|
| 487 |
+
if float(drivable.mean()) < 0.01:
|
| 488 |
+
return None
|
| 489 |
+
|
| 490 |
+
rgba = np.zeros((out_size, out_size, 4), dtype=np.uint8)
|
| 491 |
+
rgba[drivable] = [72, 94, 114, 130]
|
| 492 |
+
rgba[strong] = [170, 194, 216, 192]
|
| 493 |
+
|
| 494 |
+
buf = io.BytesIO()
|
| 495 |
+
Image.fromarray(rgba, mode="RGBA").save(buf, format="PNG")
|
| 496 |
+
png_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
| 497 |
+
|
| 498 |
+
return {
|
| 499 |
+
"source": "nuscenes-semantic-prior",
|
| 500 |
+
"map_token": map_rec.get("token"),
|
| 501 |
+
"valid_ratio": round(best_valid_ratio, 3),
|
| 502 |
+
"image_png_base64": png_b64,
|
| 503 |
+
"opacity": 0.62,
|
| 504 |
+
"bounds": {
|
| 505 |
+
"min_x": -float(radius_m),
|
| 506 |
+
"max_x": float(radius_m),
|
| 507 |
+
"min_y": -float(radius_m),
|
| 508 |
+
"max_y": float(radius_m),
|
| 509 |
+
},
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
def _attach_hd_map_layer(self, scene_geometry: dict[str, Any] | None, sample_token: str | None):
|
| 513 |
+
if not sample_token:
|
| 514 |
+
return scene_geometry
|
| 515 |
+
|
| 516 |
+
map_layer = self._build_hd_map_layer(sample_token)
|
| 517 |
+
if map_layer is None:
|
| 518 |
+
return scene_geometry
|
| 519 |
+
|
| 520 |
+
if scene_geometry is None:
|
| 521 |
+
bounds = map_layer["bounds"]
|
| 522 |
+
scene_geometry = {
|
| 523 |
+
"source": "hd-map",
|
| 524 |
+
"quality": 0.55,
|
| 525 |
+
"road_polygon": [
|
| 526 |
+
{"x": bounds["min_x"], "y": bounds["min_y"]},
|
| 527 |
+
{"x": bounds["max_x"], "y": bounds["min_y"]},
|
| 528 |
+
{"x": bounds["max_x"], "y": bounds["max_y"]},
|
| 529 |
+
{"x": bounds["min_x"], "y": bounds["max_y"]},
|
| 530 |
+
],
|
| 531 |
+
"lane_lines": [],
|
| 532 |
+
"elements": [],
|
| 533 |
+
}
|
| 534 |
+
else:
|
| 535 |
+
scene_geometry = dict(scene_geometry)
|
| 536 |
+
prev_source = str(scene_geometry.get("source", "")).strip()
|
| 537 |
+
if "hd-map" not in prev_source:
|
| 538 |
+
scene_geometry["source"] = f"{prev_source}+hd-map" if prev_source else "hd-map"
|
| 539 |
+
scene_geometry["quality"] = float(np.clip(max(float(scene_geometry.get("quality", 0.0)), 0.55), 0.0, 1.0))
|
| 540 |
+
|
| 541 |
+
scene_geometry["map_layer"] = map_layer
|
| 542 |
+
return scene_geometry
|
| 543 |
+
|
| 544 |
+
def load_cv_models(self) -> dict[str, Any]:
|
| 545 |
+
if self._models is not None:
|
| 546 |
+
return self._models
|
| 547 |
+
|
| 548 |
+
with self._model_lock:
|
| 549 |
+
if self._models is not None:
|
| 550 |
+
return self._models
|
| 551 |
+
|
| 552 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 553 |
+
|
| 554 |
+
try:
|
| 555 |
+
det_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 556 |
+
det_model = fasterrcnn_resnet50_fpn(weights=det_weights, progress=False)
|
| 557 |
+
det_model.to(device).eval()
|
| 558 |
+
|
| 559 |
+
pose_weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 560 |
+
pose_model = keypointrcnn_resnet50_fpn(weights=pose_weights, progress=False)
|
| 561 |
+
pose_model.to(device).eval()
|
| 562 |
+
|
| 563 |
+
self._models = {
|
| 564 |
+
"device": device,
|
| 565 |
+
"device_name": str(device),
|
| 566 |
+
"det_model": det_model,
|
| 567 |
+
"det_weights": det_weights,
|
| 568 |
+
"pose_model": pose_model,
|
| 569 |
+
"pose_weights": pose_weights,
|
| 570 |
+
}
|
| 571 |
+
except Exception as exc:
|
| 572 |
+
self._models = {
|
| 573 |
+
"error": str(exc),
|
| 574 |
+
"device": device,
|
| 575 |
+
"device_name": str(device),
|
| 576 |
+
}
|
| 577 |
+
|
| 578 |
+
return self._models
|
| 579 |
+
|
| 580 |
+
def detect_objects_and_pose(
|
| 581 |
+
self,
|
| 582 |
+
image_arr: np.ndarray,
|
| 583 |
+
models: dict[str, Any],
|
| 584 |
+
score_threshold: float = 0.55,
|
| 585 |
+
use_pose: bool = True,
|
| 586 |
+
) -> list[dict[str, Any]]:
|
| 587 |
+
if "error" in models:
|
| 588 |
+
return []
|
| 589 |
+
|
| 590 |
+
device = models["device"]
|
| 591 |
+
pil_img = Image.fromarray(image_arr)
|
| 592 |
+
|
| 593 |
+
det_input = models["det_weights"].transforms()(pil_img).unsqueeze(0).to(device)
|
| 594 |
+
with torch.no_grad():
|
| 595 |
+
det_out = models["det_model"](det_input)[0]
|
| 596 |
+
|
| 597 |
+
boxes = det_out["boxes"].detach().cpu().numpy() if len(det_out["boxes"]) > 0 else np.zeros((0, 4))
|
| 598 |
+
scores = det_out["scores"].detach().cpu().numpy() if len(det_out["scores"]) > 0 else np.zeros((0,))
|
| 599 |
+
labels = det_out["labels"].detach().cpu().numpy() if len(det_out["labels"]) > 0 else np.zeros((0,))
|
| 600 |
+
|
| 601 |
+
detections: list[dict[str, Any]] = []
|
| 602 |
+
for i in range(len(scores)):
|
| 603 |
+
score = float(scores[i])
|
| 604 |
+
label_idx = int(labels[i])
|
| 605 |
+
label_name = COCO_TO_LABEL.get(label_idx)
|
| 606 |
+
|
| 607 |
+
if label_name is None or score < score_threshold:
|
| 608 |
+
continue
|
| 609 |
+
|
| 610 |
+
kind = self.coco_kind(label_name)
|
| 611 |
+
if kind is None:
|
| 612 |
+
continue
|
| 613 |
+
|
| 614 |
+
x1, y1, x2, y2 = [float(v) for v in boxes[i]]
|
| 615 |
+
detections.append(
|
| 616 |
+
{
|
| 617 |
+
"score": score,
|
| 618 |
+
"raw_label": label_name,
|
| 619 |
+
"kind": kind,
|
| 620 |
+
"box": [x1, y1, x2, y2],
|
| 621 |
+
"center_x": 0.5 * (x1 + x2),
|
| 622 |
+
"bottom_y": y2,
|
| 623 |
+
"keypoints": None,
|
| 624 |
+
}
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
if use_pose:
|
| 628 |
+
pose_input = models["pose_weights"].transforms()(pil_img).unsqueeze(0).to(device)
|
| 629 |
+
with torch.no_grad():
|
| 630 |
+
pose_out = models["pose_model"](pose_input)[0]
|
| 631 |
+
|
| 632 |
+
p_boxes = pose_out["boxes"].detach().cpu().numpy() if len(pose_out["boxes"]) > 0 else np.zeros((0, 4))
|
| 633 |
+
p_scores = pose_out["scores"].detach().cpu().numpy() if len(pose_out["scores"]) > 0 else np.zeros((0,))
|
| 634 |
+
p_labels = pose_out["labels"].detach().cpu().numpy() if len(pose_out["labels"]) > 0 else np.zeros((0,))
|
| 635 |
+
p_keypoints = (
|
| 636 |
+
pose_out["keypoints"].detach().cpu().numpy()
|
| 637 |
+
if len(pose_out["keypoints"]) > 0
|
| 638 |
+
else np.zeros((0, 17, 3))
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
assigned = set()
|
| 642 |
+
for i in range(len(p_scores)):
|
| 643 |
+
if int(p_labels[i]) != 1:
|
| 644 |
+
continue
|
| 645 |
+
if float(p_scores[i]) < max(0.25, 0.8 * score_threshold):
|
| 646 |
+
continue
|
| 647 |
+
|
| 648 |
+
pose_box = [float(v) for v in p_boxes[i]]
|
| 649 |
+
best_idx = None
|
| 650 |
+
best_iou = 0.0
|
| 651 |
+
|
| 652 |
+
for det_idx, det in enumerate(detections):
|
| 653 |
+
if det_idx in assigned:
|
| 654 |
+
continue
|
| 655 |
+
if det["raw_label"] != "person":
|
| 656 |
+
continue
|
| 657 |
+
|
| 658 |
+
iou_val = self.iou_xyxy(det["box"], pose_box)
|
| 659 |
+
if iou_val > best_iou:
|
| 660 |
+
best_iou = iou_val
|
| 661 |
+
best_idx = det_idx
|
| 662 |
+
|
| 663 |
+
if best_idx is not None and best_iou > 0.1:
|
| 664 |
+
detections[best_idx]["keypoints"] = p_keypoints[i].tolist()
|
| 665 |
+
assigned.add(best_idx)
|
| 666 |
+
|
| 667 |
+
return detections
|
| 668 |
+
|
| 669 |
+
@staticmethod
|
| 670 |
+
def match_two_frame_tracks(
|
| 671 |
+
det_prev: list[dict[str, Any]],
|
| 672 |
+
det_curr: list[dict[str, Any]],
|
| 673 |
+
tracking_gate_px: float = 90.0,
|
| 674 |
+
) -> list[tuple[dict[str, Any], dict[str, Any], float]]:
|
| 675 |
+
used_curr = set()
|
| 676 |
+
matches = []
|
| 677 |
+
|
| 678 |
+
det_prev = sorted(det_prev, key=lambda d: d["score"], reverse=True)
|
| 679 |
+
det_curr = sorted(det_curr, key=lambda d: d["score"], reverse=True)
|
| 680 |
+
|
| 681 |
+
for d0 in det_prev:
|
| 682 |
+
best_idx = None
|
| 683 |
+
best_dist = 1e9
|
| 684 |
+
|
| 685 |
+
for j, d1 in enumerate(det_curr):
|
| 686 |
+
if j in used_curr:
|
| 687 |
+
continue
|
| 688 |
+
if d0["kind"] != d1["kind"]:
|
| 689 |
+
continue
|
| 690 |
+
|
| 691 |
+
dist = math.hypot(d1["center_x"] - d0["center_x"], d1["bottom_y"] - d0["bottom_y"])
|
| 692 |
+
if dist < tracking_gate_px and dist < best_dist:
|
| 693 |
+
best_dist = dist
|
| 694 |
+
best_idx = j
|
| 695 |
+
|
| 696 |
+
if best_idx is None:
|
| 697 |
+
continue
|
| 698 |
+
|
| 699 |
+
used_curr.add(best_idx)
|
| 700 |
+
d1 = det_curr[best_idx]
|
| 701 |
+
matches.append((d0, d1, float(best_dist)))
|
| 702 |
+
|
| 703 |
+
return matches
|
| 704 |
+
|
| 705 |
+
def build_two_image_agents_bundle(
|
| 706 |
+
self,
|
| 707 |
+
img_prev: np.ndarray,
|
| 708 |
+
img_curr: np.ndarray,
|
| 709 |
+
score_threshold: float,
|
| 710 |
+
tracking_gate_px: float,
|
| 711 |
+
min_motion_px: float,
|
| 712 |
+
use_pose: bool,
|
| 713 |
+
img_prev_name: str | None = None,
|
| 714 |
+
img_curr_name: str | None = None,
|
| 715 |
+
) -> dict[str, Any]:
|
| 716 |
+
models = self.load_cv_models()
|
| 717 |
+
if "error" in models:
|
| 718 |
+
return {
|
| 719 |
+
"error": f"Could not load CV models ({models['error']}).",
|
| 720 |
+
"device": models.get("device_name", "unknown"),
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
det_prev = self.detect_objects_and_pose(img_prev, models, score_threshold=score_threshold, use_pose=use_pose)
|
| 724 |
+
det_curr = self.detect_objects_and_pose(img_curr, models, score_threshold=score_threshold, use_pose=use_pose)
|
| 725 |
+
|
| 726 |
+
det_prev_vru = [d for d in det_prev if d.get("kind") == "pedestrian"]
|
| 727 |
+
det_curr_vru = [d for d in det_curr if d.get("kind") == "pedestrian"]
|
| 728 |
+
|
| 729 |
+
for i, d in enumerate(det_prev):
|
| 730 |
+
d["det_id"] = i + 1
|
| 731 |
+
d["track_id"] = None
|
| 732 |
+
for i, d in enumerate(det_curr):
|
| 733 |
+
d["det_id"] = i + 1
|
| 734 |
+
d["track_id"] = None
|
| 735 |
+
|
| 736 |
+
if len(det_curr_vru) == 0:
|
| 737 |
+
return {"error": "No pedestrian/cyclist detections found in image 2 (t0)."}
|
| 738 |
+
|
| 739 |
+
matches = self.match_two_frame_tracks(
|
| 740 |
+
det_prev_vru,
|
| 741 |
+
det_curr_vru,
|
| 742 |
+
tracking_gate_px=tracking_gate_px,
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
matched_curr_ids = {id(m[1]) for m in matches}
|
| 746 |
+
for d1 in det_curr_vru:
|
| 747 |
+
if id(d1) in matched_curr_ids:
|
| 748 |
+
continue
|
| 749 |
+
|
| 750 |
+
if len(det_prev_vru) == 0:
|
| 751 |
+
matches.append((None, d1, float("inf")))
|
| 752 |
+
continue
|
| 753 |
+
|
| 754 |
+
nearest_prev = min(
|
| 755 |
+
det_prev_vru,
|
| 756 |
+
key=lambda d0: math.hypot(d1["center_x"] - d0["center_x"], d1["bottom_y"] - d0["bottom_y"]),
|
| 757 |
+
)
|
| 758 |
+
dist = math.hypot(
|
| 759 |
+
d1["center_x"] - nearest_prev["center_x"],
|
| 760 |
+
d1["bottom_y"] - nearest_prev["bottom_y"],
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
if dist <= 1.5 * tracking_gate_px:
|
| 764 |
+
matches.append((nearest_prev, d1, float(dist)))
|
| 765 |
+
else:
|
| 766 |
+
matches.append((None, d1, float("inf")))
|
| 767 |
+
|
| 768 |
+
h0, w0 = img_prev.shape[:2]
|
| 769 |
+
h1, w1 = img_curr.shape[:2]
|
| 770 |
+
|
| 771 |
+
tracks = []
|
| 772 |
+
for track_id, (d0, d1, dist_px) in enumerate(matches, start=1):
|
| 773 |
+
if d0 is not None and d0.get("track_id") is None:
|
| 774 |
+
d0["track_id"] = track_id
|
| 775 |
+
d1["track_id"] = track_id
|
| 776 |
+
|
| 777 |
+
if d0 is not None:
|
| 778 |
+
p_prev = self.pixel_to_bev(d0["center_x"], d0["bottom_y"], w0, h0)
|
| 779 |
+
else:
|
| 780 |
+
p_prev = None
|
| 781 |
+
|
| 782 |
+
p_curr = self.pixel_to_bev(d1["center_x"], d1["bottom_y"], w1, h1)
|
| 783 |
+
|
| 784 |
+
if p_prev is None:
|
| 785 |
+
vx, vy = 0.0, 0.0
|
| 786 |
+
p_prev = p_curr
|
| 787 |
+
else:
|
| 788 |
+
vx = p_curr[0] - p_prev[0]
|
| 789 |
+
vy = p_curr[1] - p_prev[1]
|
| 790 |
+
|
| 791 |
+
if dist_px < float(min_motion_px):
|
| 792 |
+
vx, vy = 0.0, 0.0
|
| 793 |
+
p_prev = p_curr
|
| 794 |
+
|
| 795 |
+
hist = [
|
| 796 |
+
(p_curr[0] - 3.0 * vx, p_curr[1] - 3.0 * vy),
|
| 797 |
+
(p_curr[0] - 2.0 * vx, p_curr[1] - 2.0 * vy),
|
| 798 |
+
(p_prev[0], p_prev[1]),
|
| 799 |
+
(p_curr[0], p_curr[1]),
|
| 800 |
+
]
|
| 801 |
+
|
| 802 |
+
tracks.append(
|
| 803 |
+
{
|
| 804 |
+
"id": track_id,
|
| 805 |
+
"kind": d1["kind"],
|
| 806 |
+
"raw_label": d1["raw_label"],
|
| 807 |
+
"history_world": hist,
|
| 808 |
+
}
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
agents = []
|
| 812 |
+
for tr in tracks:
|
| 813 |
+
neighbors = [other["history_world"] for other in tracks if other["id"] != tr["id"]]
|
| 814 |
+
|
| 815 |
+
pred, probs, _ = trajectory_predict(
|
| 816 |
+
tr["history_world"],
|
| 817 |
+
neighbor_points_list=neighbors,
|
| 818 |
+
fusion_feats=None,
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
pred_np = pred.detach().cpu().numpy()
|
| 822 |
+
probs_np = probs.detach().cpu().numpy()
|
| 823 |
+
|
| 824 |
+
predictions = []
|
| 825 |
+
for mode_i in range(pred_np.shape[0]):
|
| 826 |
+
predictions.append([(float(p[0]), float(p[1])) for p in pred_np[mode_i]])
|
| 827 |
+
|
| 828 |
+
agents.append(
|
| 829 |
+
{
|
| 830 |
+
"id": int(tr["id"]),
|
| 831 |
+
"type": "pedestrian" if tr["kind"] == "pedestrian" else "vehicle",
|
| 832 |
+
"raw_label": tr["raw_label"],
|
| 833 |
+
"history": [tuple(map(float, p)) for p in tr["history_world"]],
|
| 834 |
+
"predictions": predictions,
|
| 835 |
+
"probabilities": self.normalize_probs(probs_np.tolist()),
|
| 836 |
+
"is_target": True,
|
| 837 |
+
}
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
scene_geometry = self.extract_scene_geometry(img_curr, det_curr)
|
| 841 |
+
sample_token = self.lookup_sample_token_for_filename(img_curr_name)
|
| 842 |
+
scene_geometry = self._attach_hd_map_layer(scene_geometry, sample_token)
|
| 843 |
+
|
| 844 |
+
return {
|
| 845 |
+
"mode": "two_upload",
|
| 846 |
+
"agents": agents,
|
| 847 |
+
"target_track_id": None,
|
| 848 |
+
"device": models.get("device_name", "unknown"),
|
| 849 |
+
"match_count": len(agents),
|
| 850 |
+
"scene_geometry": scene_geometry,
|
| 851 |
+
"camera_snapshots": {
|
| 852 |
+
"pair_prev": {"detections": det_prev},
|
| 853 |
+
"pair_curr": {"detections": det_curr},
|
| 854 |
+
},
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
def track_front_agents(
|
| 858 |
+
self,
|
| 859 |
+
front_paths: list[Path],
|
| 860 |
+
models: dict[str, Any],
|
| 861 |
+
score_threshold: float = 0.55,
|
| 862 |
+
tracking_gate_px: float = 90.0,
|
| 863 |
+
use_pose: bool = True,
|
| 864 |
+
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
| 865 |
+
tracks: dict[int, dict[str, Any]] = {}
|
| 866 |
+
next_track_id = 1
|
| 867 |
+
front_final_detections: list[dict[str, Any]] = []
|
| 868 |
+
|
| 869 |
+
for frame_idx, frame_path in enumerate(front_paths):
|
| 870 |
+
frame_arr = self.load_image_array(frame_path)
|
| 871 |
+
h, w = frame_arr.shape[:2]
|
| 872 |
+
|
| 873 |
+
detections = self.detect_objects_and_pose(
|
| 874 |
+
frame_arr,
|
| 875 |
+
models,
|
| 876 |
+
score_threshold=score_threshold,
|
| 877 |
+
use_pose=use_pose,
|
| 878 |
+
)
|
| 879 |
+
detections.sort(key=lambda d: d["score"], reverse=True)
|
| 880 |
+
|
| 881 |
+
matched_track_ids = set()
|
| 882 |
+
frame_dets_with_ids = []
|
| 883 |
+
|
| 884 |
+
for det in detections:
|
| 885 |
+
wx, wy = self.pixel_to_bev(det["center_x"], det["bottom_y"], w, h)
|
| 886 |
+
|
| 887 |
+
best_track_id = None
|
| 888 |
+
best_dist = 1e9
|
| 889 |
+
|
| 890 |
+
for tid, tr in tracks.items():
|
| 891 |
+
if tr["kind"] != det["kind"]:
|
| 892 |
+
continue
|
| 893 |
+
if tr["last_seen"] != frame_idx - 1:
|
| 894 |
+
continue
|
| 895 |
+
if tid in matched_track_ids:
|
| 896 |
+
continue
|
| 897 |
+
|
| 898 |
+
px_last, py_last = tr["history_pixel"][-1]
|
| 899 |
+
dist = math.hypot(det["center_x"] - px_last, det["bottom_y"] - py_last)
|
| 900 |
+
if dist < tracking_gate_px and dist < best_dist:
|
| 901 |
+
best_dist = dist
|
| 902 |
+
best_track_id = tid
|
| 903 |
+
|
| 904 |
+
if best_track_id is None:
|
| 905 |
+
best_track_id = next_track_id
|
| 906 |
+
next_track_id += 1
|
| 907 |
+
tracks[best_track_id] = {
|
| 908 |
+
"id": best_track_id,
|
| 909 |
+
"kind": det["kind"],
|
| 910 |
+
"raw_label": det["raw_label"],
|
| 911 |
+
"history_pixel": [],
|
| 912 |
+
"history_world": [],
|
| 913 |
+
"last_seen": -1,
|
| 914 |
+
"last_box": None,
|
| 915 |
+
"last_keypoints": None,
|
| 916 |
+
"misses": 0,
|
| 917 |
+
}
|
| 918 |
+
|
| 919 |
+
tr = tracks[best_track_id]
|
| 920 |
+
tr["history_pixel"].append((float(det["center_x"]), float(det["bottom_y"])))
|
| 921 |
+
tr["history_world"].append((float(wx), float(wy)))
|
| 922 |
+
tr["last_seen"] = frame_idx
|
| 923 |
+
tr["raw_label"] = det["raw_label"]
|
| 924 |
+
tr["last_box"] = det["box"]
|
| 925 |
+
tr["last_keypoints"] = det.get("keypoints")
|
| 926 |
+
tr["misses"] = 0
|
| 927 |
+
|
| 928 |
+
matched_track_ids.add(best_track_id)
|
| 929 |
+
|
| 930 |
+
det = dict(det)
|
| 931 |
+
det["track_id"] = best_track_id
|
| 932 |
+
frame_dets_with_ids.append(det)
|
| 933 |
+
|
| 934 |
+
for tid, tr in tracks.items():
|
| 935 |
+
if tr["last_seen"] == frame_idx:
|
| 936 |
+
continue
|
| 937 |
+
if tr["last_seen"] < frame_idx - 1:
|
| 938 |
+
continue
|
| 939 |
+
|
| 940 |
+
if len(tr["history_pixel"]) >= 2:
|
| 941 |
+
px_prev, py_prev = tr["history_pixel"][-2]
|
| 942 |
+
px_last, py_last = tr["history_pixel"][-1]
|
| 943 |
+
wx_prev, wy_prev = tr["history_world"][-2]
|
| 944 |
+
wx_last, wy_last = tr["history_world"][-1]
|
| 945 |
+
|
| 946 |
+
px_ex = px_last + (px_last - px_prev)
|
| 947 |
+
py_ex = py_last + (py_last - py_prev)
|
| 948 |
+
wx_ex = wx_last + (wx_last - wx_prev)
|
| 949 |
+
wy_ex = wy_last + (wy_last - wy_prev)
|
| 950 |
+
else:
|
| 951 |
+
px_ex, py_ex = tr["history_pixel"][-1]
|
| 952 |
+
wx_ex, wy_ex = tr["history_world"][-1]
|
| 953 |
+
|
| 954 |
+
tr["history_pixel"].append((float(px_ex), float(py_ex)))
|
| 955 |
+
tr["history_world"].append((float(wx_ex), float(wy_ex)))
|
| 956 |
+
tr["last_seen"] = frame_idx
|
| 957 |
+
tr["misses"] += 1
|
| 958 |
+
|
| 959 |
+
if frame_idx == len(front_paths) - 1:
|
| 960 |
+
front_final_detections = frame_dets_with_ids
|
| 961 |
+
|
| 962 |
+
valid_tracks = []
|
| 963 |
+
for tid, tr in tracks.items():
|
| 964 |
+
if len(tr["history_world"]) != len(front_paths):
|
| 965 |
+
continue
|
| 966 |
+
if tr["misses"] > 2:
|
| 967 |
+
continue
|
| 968 |
+
|
| 969 |
+
x0, y0 = tr["history_world"][0]
|
| 970 |
+
x1, y1 = tr["history_world"][-1]
|
| 971 |
+
motion = math.hypot(x1 - x0, y1 - y0)
|
| 972 |
+
if motion < 0.08:
|
| 973 |
+
continue
|
| 974 |
+
|
| 975 |
+
valid_tracks.append(
|
| 976 |
+
{
|
| 977 |
+
"id": tid,
|
| 978 |
+
"kind": tr["kind"],
|
| 979 |
+
"raw_label": tr["raw_label"],
|
| 980 |
+
"history_pixel": [tuple(p) for p in tr["history_pixel"]],
|
| 981 |
+
"history_world": [tuple(p) for p in tr["history_world"]],
|
| 982 |
+
"last_box": tr["last_box"],
|
| 983 |
+
"last_keypoints": tr["last_keypoints"],
|
| 984 |
+
}
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
valid_tracks.sort(key=lambda t: t["id"])
|
| 988 |
+
return valid_tracks, front_final_detections
|
| 989 |
+
|
| 990 |
+
@staticmethod
|
| 991 |
+
def raw_label_to_stabilizer_type(raw_label: str) -> str:
|
| 992 |
+
if raw_label == "person":
|
| 993 |
+
return "Person"
|
| 994 |
+
if raw_label == "bicycle":
|
| 995 |
+
return "Bicycle"
|
| 996 |
+
if raw_label == "motorcycle":
|
| 997 |
+
return "Motorcycle"
|
| 998 |
+
if raw_label == "bus":
|
| 999 |
+
return "Bus"
|
| 1000 |
+
if raw_label == "truck":
|
| 1001 |
+
return "Truck"
|
| 1002 |
+
return "Car"
|
| 1003 |
+
|
| 1004 |
+
@staticmethod
|
| 1005 |
+
def build_fusion_features(history_world: list[tuple[float, float]], fusion_data: dict[str, Any] | None):
|
| 1006 |
+
if not fusion_data:
|
| 1007 |
+
return None
|
| 1008 |
+
|
| 1009 |
+
lidar_xy = fusion_data.get("lidar_xy")
|
| 1010 |
+
radar_xy = fusion_data.get("radar_xy")
|
| 1011 |
+
|
| 1012 |
+
if lidar_xy is None and radar_xy is None:
|
| 1013 |
+
return None
|
| 1014 |
+
|
| 1015 |
+
feats = []
|
| 1016 |
+
for px, py in history_world:
|
| 1017 |
+
if lidar_xy is not None and len(lidar_xy) > 0:
|
| 1018 |
+
dl = np.hypot(lidar_xy[:, 0] - px, lidar_xy[:, 1] - py)
|
| 1019 |
+
lidar_cnt = int((dl < 2.0).sum())
|
| 1020 |
+
else:
|
| 1021 |
+
lidar_cnt = 0
|
| 1022 |
+
|
| 1023 |
+
if radar_xy is not None and len(radar_xy) > 0:
|
| 1024 |
+
dr = np.hypot(radar_xy[:, 0] - px, radar_xy[:, 1] - py)
|
| 1025 |
+
radar_cnt = int((dr < 2.5).sum())
|
| 1026 |
+
else:
|
| 1027 |
+
radar_cnt = 0
|
| 1028 |
+
|
| 1029 |
+
lidar_norm = min(80.0, float(lidar_cnt)) / 80.0
|
| 1030 |
+
radar_norm = min(30.0, float(radar_cnt)) / 30.0
|
| 1031 |
+
sensor_strength = min(1.0, (float(lidar_cnt) + 2.0 * float(radar_cnt)) / 100.0)
|
| 1032 |
+
feats.append([lidar_norm, radar_norm, sensor_strength])
|
| 1033 |
+
|
| 1034 |
+
return feats
|
| 1035 |
+
|
| 1036 |
+
def stabilize_tracks_with_radar(self, tracks: list[dict[str, Any]], fusion_data: dict[str, Any] | None):
|
| 1037 |
+
if not tracks:
|
| 1038 |
+
return tracks
|
| 1039 |
+
|
| 1040 |
+
packed = []
|
| 1041 |
+
for tr in tracks:
|
| 1042 |
+
hist = tr["history_world"]
|
| 1043 |
+
if len(hist) >= 2:
|
| 1044 |
+
dx = float(hist[-1][0] - hist[-2][0])
|
| 1045 |
+
dy = float(hist[-1][1] - hist[-2][1])
|
| 1046 |
+
else:
|
| 1047 |
+
dx = 0.0
|
| 1048 |
+
dy = 0.0
|
| 1049 |
+
|
| 1050 |
+
packed.append(
|
| 1051 |
+
{
|
| 1052 |
+
"type": self.raw_label_to_stabilizer_type(tr.get("raw_label", "car")),
|
| 1053 |
+
"history": [tuple(p) for p in hist],
|
| 1054 |
+
"dx": dx,
|
| 1055 |
+
"dy": dy,
|
| 1056 |
+
}
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
stabilized = radar_stabilize_motion(packed, fusion_data, dt_seconds=0.5)
|
| 1060 |
+
|
| 1061 |
+
updated = []
|
| 1062 |
+
for tr, st in zip(tracks, stabilized):
|
| 1063 |
+
t_copy = dict(tr)
|
| 1064 |
+
t_copy["history_world"] = [(float(x), float(y)) for x, y in st["history"]]
|
| 1065 |
+
updated.append(t_copy)
|
| 1066 |
+
|
| 1067 |
+
return updated
|
| 1068 |
+
|
| 1069 |
+
@staticmethod
|
| 1070 |
+
def choose_target_track_id(tracks: list[dict[str, Any]]) -> int | None:
|
| 1071 |
+
if not tracks:
|
| 1072 |
+
return None
|
| 1073 |
+
|
| 1074 |
+
peds = [t for t in tracks if t["kind"] == "pedestrian"]
|
| 1075 |
+
if peds:
|
| 1076 |
+
best = min(peds, key=lambda t: math.hypot(t["history_world"][-1][0], t["history_world"][-1][1]))
|
| 1077 |
+
return best["id"]
|
| 1078 |
+
|
| 1079 |
+
return tracks[0]["id"]
|
| 1080 |
+
|
| 1081 |
+
def build_agents_from_tracks(self, tracks: list[dict[str, Any]], fusion_data: dict[str, Any] | None):
|
| 1082 |
+
if not tracks:
|
| 1083 |
+
return [], None, []
|
| 1084 |
+
|
| 1085 |
+
tracks_work = []
|
| 1086 |
+
for tr in tracks:
|
| 1087 |
+
tracks_work.append(
|
| 1088 |
+
{
|
| 1089 |
+
"id": tr["id"],
|
| 1090 |
+
"kind": tr["kind"],
|
| 1091 |
+
"raw_label": tr["raw_label"],
|
| 1092 |
+
"history_pixel": [tuple(p) for p in tr["history_pixel"]],
|
| 1093 |
+
"history_world": [tuple(p) for p in tr["history_world"]],
|
| 1094 |
+
"last_box": tr.get("last_box"),
|
| 1095 |
+
"last_keypoints": tr.get("last_keypoints"),
|
| 1096 |
+
}
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
tracks_work = self.stabilize_tracks_with_radar(tracks_work, fusion_data)
|
| 1100 |
+
|
| 1101 |
+
target_id = self.choose_target_track_id(tracks_work)
|
| 1102 |
+
agents = []
|
| 1103 |
+
|
| 1104 |
+
for tr in tracks_work:
|
| 1105 |
+
neighbors = [other["history_world"] for other in tracks_work if other["id"] != tr["id"]]
|
| 1106 |
+
|
| 1107 |
+
if len(neighbors) > 12:
|
| 1108 |
+
x0, y0 = tr["history_world"][-1]
|
| 1109 |
+
neighbors = sorted(
|
| 1110 |
+
neighbors,
|
| 1111 |
+
key=lambda nh: math.hypot(nh[-1][0] - x0, nh[-1][1] - y0),
|
| 1112 |
+
)[:12]
|
| 1113 |
+
|
| 1114 |
+
fusion_feats = self.build_fusion_features(tr["history_world"], fusion_data)
|
| 1115 |
+
|
| 1116 |
+
pred, probs, _ = trajectory_predict(
|
| 1117 |
+
tr["history_world"],
|
| 1118 |
+
neighbor_points_list=neighbors,
|
| 1119 |
+
fusion_feats=fusion_feats,
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
pred_np = pred.detach().cpu().numpy()
|
| 1123 |
+
probs_np = probs.detach().cpu().numpy()
|
| 1124 |
+
|
| 1125 |
+
predictions = []
|
| 1126 |
+
for mode_i in range(pred_np.shape[0]):
|
| 1127 |
+
predictions.append([(float(p[0]), float(p[1])) for p in pred_np[mode_i]])
|
| 1128 |
+
|
| 1129 |
+
agents.append(
|
| 1130 |
+
{
|
| 1131 |
+
"id": int(tr["id"]),
|
| 1132 |
+
"type": "pedestrian" if tr["kind"] == "pedestrian" else "vehicle",
|
| 1133 |
+
"raw_label": tr["raw_label"],
|
| 1134 |
+
"history": [tuple(map(float, p)) for p in tr["history_world"]],
|
| 1135 |
+
"predictions": predictions,
|
| 1136 |
+
"probabilities": self.normalize_probs(probs_np.tolist()),
|
| 1137 |
+
"is_target": tr["id"] == target_id,
|
| 1138 |
+
}
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
return agents, target_id, tracks_work
|
| 1142 |
+
|
| 1143 |
+
@staticmethod
|
| 1144 |
+
def assign_track_ids_to_front_detections(
|
| 1145 |
+
detections: list[dict[str, Any]],
|
| 1146 |
+
tracks: list[dict[str, Any]],
|
| 1147 |
+
gate_px: float = 90.0,
|
| 1148 |
+
) -> list[dict[str, Any]]:
|
| 1149 |
+
if not detections:
|
| 1150 |
+
return []
|
| 1151 |
+
|
| 1152 |
+
out = []
|
| 1153 |
+
used_ids = set()
|
| 1154 |
+
|
| 1155 |
+
for det_idx, det in enumerate(detections):
|
| 1156 |
+
d = dict(det)
|
| 1157 |
+
d.setdefault("det_id", det_idx + 1)
|
| 1158 |
+
|
| 1159 |
+
if d.get("track_id") is not None:
|
| 1160 |
+
used_ids.add(d["track_id"])
|
| 1161 |
+
out.append(d)
|
| 1162 |
+
continue
|
| 1163 |
+
|
| 1164 |
+
best_id = None
|
| 1165 |
+
best_dist = 1e9
|
| 1166 |
+
|
| 1167 |
+
for tr in tracks:
|
| 1168 |
+
if tr["id"] in used_ids:
|
| 1169 |
+
continue
|
| 1170 |
+
if tr["kind"] != d["kind"]:
|
| 1171 |
+
continue
|
| 1172 |
+
|
| 1173 |
+
px, py = tr["history_pixel"][-1]
|
| 1174 |
+
dist = math.hypot(d["center_x"] - px, d["bottom_y"] - py)
|
| 1175 |
+
if dist < gate_px and dist < best_dist:
|
| 1176 |
+
best_dist = dist
|
| 1177 |
+
best_id = tr["id"]
|
| 1178 |
+
|
| 1179 |
+
d["track_id"] = best_id
|
| 1180 |
+
if best_id is not None:
|
| 1181 |
+
used_ids.add(best_id)
|
| 1182 |
+
|
| 1183 |
+
out.append(d)
|
| 1184 |
+
|
| 1185 |
+
return out
|
| 1186 |
+
|
| 1187 |
+
def build_live_agents_bundle(
|
| 1188 |
+
self,
|
| 1189 |
+
anchor_idx: int,
|
| 1190 |
+
score_threshold: float,
|
| 1191 |
+
tracking_gate_px: float,
|
| 1192 |
+
use_pose: bool,
|
| 1193 |
+
) -> dict[str, Any]:
|
| 1194 |
+
front_paths = self.list_channel_image_paths("CAM_FRONT")
|
| 1195 |
+
if len(front_paths) < 4:
|
| 1196 |
+
return {"error": "Need at least 4 CAM_FRONT frames in DataSet/samples/CAM_FRONT."}
|
| 1197 |
+
|
| 1198 |
+
if anchor_idx < 3:
|
| 1199 |
+
anchor_idx = 3
|
| 1200 |
+
if anchor_idx >= len(front_paths):
|
| 1201 |
+
anchor_idx = len(front_paths) - 1
|
| 1202 |
+
|
| 1203 |
+
models = self.load_cv_models()
|
| 1204 |
+
if "error" in models:
|
| 1205 |
+
return {
|
| 1206 |
+
"error": f"Could not load CV models ({models['error']}).",
|
| 1207 |
+
"device": models.get("device_name", "unknown"),
|
| 1208 |
+
}
|
| 1209 |
+
|
| 1210 |
+
window_paths = front_paths[anchor_idx - 3 : anchor_idx + 1]
|
| 1211 |
+
|
| 1212 |
+
tracks, front_dets = self.track_front_agents(
|
| 1213 |
+
window_paths,
|
| 1214 |
+
models,
|
| 1215 |
+
score_threshold=score_threshold,
|
| 1216 |
+
tracking_gate_px=tracking_gate_px,
|
| 1217 |
+
use_pose=use_pose,
|
| 1218 |
+
)
|
| 1219 |
+
|
| 1220 |
+
if len(tracks) == 0:
|
| 1221 |
+
return {"error": "No valid tracked moving agents found in selected frame window."}
|
| 1222 |
+
|
| 1223 |
+
front_curr = window_paths[-1]
|
| 1224 |
+
fusion_data = load_fusion_for_cam_frame(
|
| 1225 |
+
front_curr.name,
|
| 1226 |
+
data_root=str(self.data_root),
|
| 1227 |
+
version="v1.0-mini",
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
agents, target_id, tracks_stable = self.build_agents_from_tracks(tracks, fusion_data)
|
| 1231 |
+
if len(agents) == 0:
|
| 1232 |
+
return {"error": "Tracking succeeded but trajectory prediction produced no agents."}
|
| 1233 |
+
|
| 1234 |
+
front_dets = self.assign_track_ids_to_front_detections(front_dets, tracks_stable, gate_px=tracking_gate_px)
|
| 1235 |
+
front_img = self.load_image_array(front_curr)
|
| 1236 |
+
scene_geometry = self.extract_scene_geometry(front_img, front_dets)
|
| 1237 |
+
live_sample_token = str(fusion_data.get("sample_token")) if fusion_data and fusion_data.get("sample_token") else None
|
| 1238 |
+
scene_geometry = self._attach_hd_map_layer(scene_geometry, live_sample_token)
|
| 1239 |
+
|
| 1240 |
+
return {
|
| 1241 |
+
"mode": "live_fusion",
|
| 1242 |
+
"agents": agents,
|
| 1243 |
+
"target_track_id": target_id,
|
| 1244 |
+
"device": models.get("device_name", "unknown"),
|
| 1245 |
+
"front_anchor_path": str(front_curr),
|
| 1246 |
+
"track_count": len(agents),
|
| 1247 |
+
"scene_geometry": scene_geometry,
|
| 1248 |
+
"camera_snapshots": {
|
| 1249 |
+
"CAM_FRONT": {
|
| 1250 |
+
"frame_path": str(front_curr),
|
| 1251 |
+
"detections": front_dets,
|
| 1252 |
+
}
|
| 1253 |
+
},
|
| 1254 |
+
"fusion_data": fusion_data,
|
| 1255 |
+
}
|
backend/scripts/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Organized script modules for training, evaluation, and tools."""
|
backend/scripts/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Data preparation script modules."""
|
backend/scripts/data/build_dataset_from_images.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import os
|
| 6 |
+
import glob
|
| 7 |
+
import math
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
TARGET_CLASSES = {1: 'Person', 2: 'Bicycle', 3: 'Car', 4: 'Motorcycle'}
|
| 11 |
+
|
| 12 |
+
# Set up GPU acceleration
|
| 13 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 14 |
+
|
| 15 |
+
def load_perception_model():
|
| 16 |
+
print(f"[System] Loading Pre-Trained Faster R-CNN (ResNet-50-FPN) on {device.type.upper()}...")
|
| 17 |
+
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 18 |
+
model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
|
| 19 |
+
model.to(device) # Move model to GPU
|
| 20 |
+
model.eval()
|
| 21 |
+
return model, weights
|
| 22 |
+
|
| 23 |
+
def extract_features(img_path, model, weights, score_threshold=0.7):
|
| 24 |
+
image = Image.open(img_path).convert("RGB")
|
| 25 |
+
preprocess = weights.transforms()
|
| 26 |
+
# Move the image tensor to the GPU so the math runs on CUDA
|
| 27 |
+
input_batch = preprocess(image).unsqueeze(0).to(device)
|
| 28 |
+
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
prediction = model(input_batch)[0]
|
| 31 |
+
|
| 32 |
+
extracted = []
|
| 33 |
+
# prediction items are on GPU, so we use .item() to pull the raw number back out
|
| 34 |
+
for i, box in enumerate(prediction['boxes']):
|
| 35 |
+
score = prediction['scores'][i].item()
|
| 36 |
+
label = prediction['labels'][i].item()
|
| 37 |
+
if score > score_threshold and label in TARGET_CLASSES:
|
| 38 |
+
center_x = (box[0] + box[2]).item() / 2.0
|
| 39 |
+
bottom_y = box[3].item()
|
| 40 |
+
extracted.append({
|
| 41 |
+
'type': TARGET_CLASSES[label],
|
| 42 |
+
'coord': (round(center_x, 2), round(bottom_y, 2))
|
| 43 |
+
})
|
| 44 |
+
return extracted
|
| 45 |
+
|
| 46 |
+
def process_dataset_into_trajectories():
|
| 47 |
+
print("="*60)
|
| 48 |
+
print(f"| Starting Automated Dataset Pre-Processing Pipeline |")
|
| 49 |
+
print(f"| Hardware Acceleration: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'} |")
|
| 50 |
+
print("="*60)
|
| 51 |
+
|
| 52 |
+
model, weights = load_perception_model()
|
| 53 |
+
|
| 54 |
+
# Get images chronologically to simulate a video feed
|
| 55 |
+
image_paths = sorted(glob.glob("DataSet/samples/CAM_FRONT/*.jpg"))
|
| 56 |
+
if not image_paths:
|
| 57 |
+
print("[!] No images found to process.")
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
print(f"[System] Success: Found a total of {len(image_paths)} valid image frames in the folder. Processing now...")
|
| 61 |
+
|
| 62 |
+
dataset_trajectories = []
|
| 63 |
+
|
| 64 |
+
# We need 4 frames of history for our AI Model (T-3, T-2, T-1, T0)
|
| 65 |
+
for i in range(len(image_paths) - 3):
|
| 66 |
+
frames = image_paths[i:i+4]
|
| 67 |
+
frame_data = []
|
| 68 |
+
|
| 69 |
+
# Output progress every 50 frames
|
| 70 |
+
if i % 50 == 0:
|
| 71 |
+
print(f" -> Processing frame sequence {i}/{len(image_paths)}")
|
| 72 |
+
|
| 73 |
+
for f in frames:
|
| 74 |
+
objs = extract_features(f, model, weights)
|
| 75 |
+
frame_data.append(objs)
|
| 76 |
+
|
| 77 |
+
for obj_t0 in frame_data[0]:
|
| 78 |
+
target_type = obj_t0['type']
|
| 79 |
+
track_history = [obj_t0['coord']]
|
| 80 |
+
valid_track = True
|
| 81 |
+
|
| 82 |
+
last_coord = obj_t0['coord']
|
| 83 |
+
for t in range(1, 4):
|
| 84 |
+
best_dist = float('inf')
|
| 85 |
+
best_coord = None
|
| 86 |
+
for obj_t_next in frame_data[t]:
|
| 87 |
+
if obj_t_next['type'] == target_type:
|
| 88 |
+
dist = math.sqrt((last_coord[0] - obj_t_next['coord'][0])**2 +
|
| 89 |
+
(last_coord[1] - obj_t_next['coord'][1])**2)
|
| 90 |
+
if dist < 60.0 and dist < best_dist:
|
| 91 |
+
best_dist = dist
|
| 92 |
+
best_coord = obj_t_next['coord']
|
| 93 |
+
|
| 94 |
+
if best_coord:
|
| 95 |
+
track_history.append(best_coord)
|
| 96 |
+
last_coord = best_coord
|
| 97 |
+
else:
|
| 98 |
+
valid_track = False
|
| 99 |
+
break
|
| 100 |
+
|
| 101 |
+
if valid_track:
|
| 102 |
+
dataset_trajectories.append({
|
| 103 |
+
"agent_type": target_type,
|
| 104 |
+
"trajectory_pixels": track_history
|
| 105 |
+
})
|
| 106 |
+
|
| 107 |
+
output_file = "extracted_training_data.json"
|
| 108 |
+
with open(output_file, "w") as f:
|
| 109 |
+
json.dump(dataset_trajectories, f, indent=4)
|
| 110 |
+
|
| 111 |
+
print(f"\n[Success] Pipeline Complete!")
|
| 112 |
+
print(f"[+] Extracted {len(dataset_trajectories)} valid moving trajectories from raw images.")
|
| 113 |
+
print(f"[+] Saved AI Training payload to: {output_file}")
|
| 114 |
+
|
| 115 |
+
if __name__ == '__main__':
|
| 116 |
+
try:
|
| 117 |
+
process_dataset_into_trajectories()
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print(f"Error during processing: {e}")
|
backend/scripts/evaluation/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation and benchmarking script modules."""
|
backend/scripts/evaluation/benchmark_perf.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision.models.detection import (
|
| 6 |
+
fasterrcnn_resnet50_fpn,
|
| 7 |
+
FasterRCNN_ResNet50_FPN_Weights,
|
| 8 |
+
keypointrcnn_resnet50_fpn,
|
| 9 |
+
KeypointRCNN_ResNet50_FPN_Weights,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from backend.app.ml.sensor_fusion import load_fusion_for_cam_frame
|
| 13 |
+
from backend.app.ml.inference import predict, USING_FUSION_MODEL
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main():
|
| 17 |
+
img_path = r"DataSet/samples/CAM_FRONT/n008-2018-08-01-15-16-36-0400__CAM_FRONT__1533151603512404.jpg"
|
| 18 |
+
img = Image.open(img_path).convert("RGB")
|
| 19 |
+
|
| 20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
print("device", device)
|
| 22 |
+
print("using_fusion_model", USING_FUSION_MODEL)
|
| 23 |
+
|
| 24 |
+
t0 = time.perf_counter()
|
| 25 |
+
w_det = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 26 |
+
m_det = fasterrcnn_resnet50_fpn(weights=w_det, progress=False).to(device).eval()
|
| 27 |
+
if device.type == "cuda":
|
| 28 |
+
torch.cuda.synchronize()
|
| 29 |
+
load_det = (time.perf_counter() - t0) * 1000
|
| 30 |
+
|
| 31 |
+
t0 = time.perf_counter()
|
| 32 |
+
w_pose = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 33 |
+
m_pose = keypointrcnn_resnet50_fpn(weights=w_pose, progress=False).to(device).eval()
|
| 34 |
+
if device.type == "cuda":
|
| 35 |
+
torch.cuda.synchronize()
|
| 36 |
+
load_pose = (time.perf_counter() - t0) * 1000
|
| 37 |
+
|
| 38 |
+
print("load_ms_fasterrcnn", round(load_det, 2))
|
| 39 |
+
print("load_ms_keypointrcnn", round(load_pose, 2))
|
| 40 |
+
|
| 41 |
+
in_det = w_det.transforms()(img).unsqueeze(0).to(device)
|
| 42 |
+
in_pose = w_pose.transforms()(img).unsqueeze(0).to(device)
|
| 43 |
+
|
| 44 |
+
with torch.no_grad():
|
| 45 |
+
_ = m_det(in_det)
|
| 46 |
+
_ = m_pose(in_pose)
|
| 47 |
+
if device.type == "cuda":
|
| 48 |
+
torch.cuda.synchronize()
|
| 49 |
+
|
| 50 |
+
n = 5
|
| 51 |
+
|
| 52 |
+
st = time.perf_counter()
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
for _ in range(n):
|
| 55 |
+
_ = m_det(in_det)
|
| 56 |
+
if device.type == "cuda":
|
| 57 |
+
torch.cuda.synchronize()
|
| 58 |
+
det_ms = (time.perf_counter() - st) * 1000 / n
|
| 59 |
+
|
| 60 |
+
st = time.perf_counter()
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
for _ in range(n):
|
| 63 |
+
_ = m_pose(in_pose)
|
| 64 |
+
if device.type == "cuda":
|
| 65 |
+
torch.cuda.synchronize()
|
| 66 |
+
pose_ms = (time.perf_counter() - st) * 1000 / n
|
| 67 |
+
|
| 68 |
+
print("avg_ms_det_per_frame", round(det_ms, 2))
|
| 69 |
+
print("avg_ms_pose_per_frame", round(pose_ms, 2))
|
| 70 |
+
|
| 71 |
+
m = 30
|
| 72 |
+
st = time.perf_counter()
|
| 73 |
+
for _ in range(m):
|
| 74 |
+
_ = load_fusion_for_cam_frame(
|
| 75 |
+
"n008-2018-08-01-15-16-36-0400__CAM_FRONT__1533151603512404.jpg",
|
| 76 |
+
data_root="DataSet",
|
| 77 |
+
)
|
| 78 |
+
fusion_ms = (time.perf_counter() - st) * 1000 / m
|
| 79 |
+
print("avg_ms_fusion_lookup", round(fusion_ms, 2))
|
| 80 |
+
|
| 81 |
+
pts = [(0, 10), (2, 10), (4, 10), (6, 10)]
|
| 82 |
+
neigh = [
|
| 83 |
+
[(8, 12), (8.5, 12), (9, 12), (9.5, 12)],
|
| 84 |
+
[(15, 7), (15.5, 7.2), (16, 7.5), (16.4, 7.7)],
|
| 85 |
+
]
|
| 86 |
+
fusion_feats = [[0.2, 0.1, 0.25], [0.25, 0.1, 0.3], [0.3, 0.12, 0.35], [0.35, 0.15, 0.4]]
|
| 87 |
+
|
| 88 |
+
for _ in range(10):
|
| 89 |
+
_ = predict(pts, neigh, fusion_feats=fusion_feats)
|
| 90 |
+
if device.type == "cuda":
|
| 91 |
+
torch.cuda.synchronize()
|
| 92 |
+
|
| 93 |
+
k = 300
|
| 94 |
+
st = time.perf_counter()
|
| 95 |
+
for _ in range(k):
|
| 96 |
+
_ = predict(pts, neigh, fusion_feats=fusion_feats)
|
| 97 |
+
if device.type == "cuda":
|
| 98 |
+
torch.cuda.synchronize()
|
| 99 |
+
pred_ms = (time.perf_counter() - st) * 1000 / k
|
| 100 |
+
print("avg_ms_transformer_predict", round(pred_ms, 4))
|
| 101 |
+
|
| 102 |
+
approx = 2 * det_ms + pose_ms + fusion_ms + 6 * pred_ms
|
| 103 |
+
fps = 1000.0 / approx if approx > 0 else 0.0
|
| 104 |
+
print("approx_live_2frame_ms", round(approx, 2))
|
| 105 |
+
print("approx_live_equiv_fps", round(fps, 2))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
main()
|
backend/scripts/evaluation/evaluate.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from backend.app.legacy.dataset import TrajectoryDataset
|
| 5 |
+
from backend.app.ml.model import TrajectoryTransformer
|
| 6 |
+
from backend.scripts.training.train import get_data, collate_fn, compute_ade, compute_fde
|
| 7 |
+
import numpy as np
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 11 |
+
BASE_CKPT = REPO_ROOT / "models" / "best_social_model.pth"
|
| 12 |
+
|
| 13 |
+
def evaluate():
|
| 14 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
+
print(f"Running Evaluation on {device}...")
|
| 16 |
+
|
| 17 |
+
samples = get_data()
|
| 18 |
+
|
| 19 |
+
# Use the same deterministic split as train.py to evaluate on validation set
|
| 20 |
+
random.seed(42)
|
| 21 |
+
random.shuffle(samples)
|
| 22 |
+
train_size = int(0.8 * len(samples))
|
| 23 |
+
val_samples = samples[train_size:]
|
| 24 |
+
|
| 25 |
+
dataset = TrajectoryDataset(val_samples, augment=False)
|
| 26 |
+
eval_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn)
|
| 27 |
+
|
| 28 |
+
# Load Model
|
| 29 |
+
model = TrajectoryTransformer().to(device)
|
| 30 |
+
try:
|
| 31 |
+
model.load_state_dict(torch.load(BASE_CKPT, map_location=device, weights_only=True))
|
| 32 |
+
print("Successfully loaded 'best_social_model.pth' from models folder")
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"Could not load model weights: {e}")
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
model.eval()
|
| 38 |
+
|
| 39 |
+
total_ade = 0
|
| 40 |
+
total_fde = 0
|
| 41 |
+
miss_rate = 0
|
| 42 |
+
|
| 43 |
+
cv_total_ade = 0
|
| 44 |
+
cv_total_fde = 0
|
| 45 |
+
cv_miss_rate = 0
|
| 46 |
+
|
| 47 |
+
total_samples = 0
|
| 48 |
+
|
| 49 |
+
# Miss rate threshold: if best path's endpoint is off by more than 2.0 meters
|
| 50 |
+
MISS_THRESHOLD = 2.0
|
| 51 |
+
|
| 52 |
+
print("\n--- Starting Deep Evaluation ---")
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
for obs, neighbors, future in eval_loader:
|
| 55 |
+
obs, future = obs.to(device), future.to(device)
|
| 56 |
+
|
| 57 |
+
# --- MODEL PREDICTION ---
|
| 58 |
+
pred, goals, probs, _ = model(obs, neighbors)
|
| 59 |
+
|
| 60 |
+
# Find the best prediction out of K=3 for each item in the batch
|
| 61 |
+
gt = future.unsqueeze(1)
|
| 62 |
+
error = torch.norm(pred - gt, dim=3).mean(dim=2)
|
| 63 |
+
best_idx = torch.argmin(error, dim=1)
|
| 64 |
+
best_pred = pred[torch.arange(pred.size(0)), best_idx]
|
| 65 |
+
|
| 66 |
+
# Metrics Model
|
| 67 |
+
batch_ade = compute_ade(best_pred, future).item()
|
| 68 |
+
batch_fde = compute_fde(best_pred, future).item()
|
| 69 |
+
|
| 70 |
+
total_ade += batch_ade * obs.size(0)
|
| 71 |
+
total_fde += batch_fde * obs.size(0)
|
| 72 |
+
|
| 73 |
+
final_displacement = torch.norm(best_pred[:, -1] - future[:, -1], dim=1)
|
| 74 |
+
misses = (final_displacement > MISS_THRESHOLD).sum().item()
|
| 75 |
+
miss_rate += misses
|
| 76 |
+
|
| 77 |
+
# --- CONSTANT VELOCITY BASELINE ---
|
| 78 |
+
vx = obs[:, 3, 2].unsqueeze(1) # dx at last observed step
|
| 79 |
+
vy = obs[:, 3, 3].unsqueeze(1) # dy at last observed step
|
| 80 |
+
|
| 81 |
+
t = torch.arange(1, 13, device=device).unsqueeze(0).float() # Horizon is 12 steps
|
| 82 |
+
|
| 83 |
+
x_last = obs[:, 3, 0].unsqueeze(1) # x at last step
|
| 84 |
+
y_last = obs[:, 3, 1].unsqueeze(1) # y at last step
|
| 85 |
+
|
| 86 |
+
cv_pred_x = x_last + vx * t
|
| 87 |
+
cv_pred_y = y_last + vy * t
|
| 88 |
+
cv_pred = torch.stack([cv_pred_x, cv_pred_y], dim=-1)
|
| 89 |
+
|
| 90 |
+
# Metrics CV Baseline
|
| 91 |
+
cv_batch_ade = compute_ade(cv_pred, future).item()
|
| 92 |
+
cv_batch_fde = compute_fde(cv_pred, future).item()
|
| 93 |
+
|
| 94 |
+
cv_total_ade += cv_batch_ade * obs.size(0)
|
| 95 |
+
cv_total_fde += cv_batch_fde * obs.size(0)
|
| 96 |
+
|
| 97 |
+
cv_final_displacement = torch.norm(cv_pred[:, -1] - future[:, -1], dim=1)
|
| 98 |
+
cv_misses = (cv_final_displacement > MISS_THRESHOLD).sum().item()
|
| 99 |
+
cv_miss_rate += cv_misses
|
| 100 |
+
|
| 101 |
+
total_samples += obs.size(0)
|
| 102 |
+
|
| 103 |
+
# Average metrics
|
| 104 |
+
avg_ade = total_ade / total_samples
|
| 105 |
+
avg_fde = total_fde / total_samples
|
| 106 |
+
avg_miss_rate = (miss_rate / total_samples) * 100
|
| 107 |
+
|
| 108 |
+
cv_avg_ade = cv_total_ade / total_samples
|
| 109 |
+
cv_avg_fde = cv_total_fde / total_samples
|
| 110 |
+
cv_avg_miss_rate = (cv_miss_rate / total_samples) * 100
|
| 111 |
+
|
| 112 |
+
print("\n========================================================")
|
| 113 |
+
print(" HACKATHON FINAL METRICS REPORT ")
|
| 114 |
+
print("========================================================")
|
| 115 |
+
print(f"Total Trajectories Evaluated (Val Set): {total_samples}")
|
| 116 |
+
print(f"Prediction Horizon: 6 Seconds (12 steps)")
|
| 117 |
+
print(f"Social Context Radius: 50 Meters")
|
| 118 |
+
print("--------------------------------------------------------")
|
| 119 |
+
print("METRIC | BASELINE (CV) | OUR TRANSFORMER ")
|
| 120 |
+
print("------------------------|---------------|-----------------")
|
| 121 |
+
print(f"minADE@3 (meters) | {cv_avg_ade:13.2f} | {avg_ade:15.2f}")
|
| 122 |
+
print(f"minFDE@3 (meters) | {cv_avg_fde:13.2f} | {avg_fde:15.2f}")
|
| 123 |
+
print(f"Miss Rate (>2.0m) | {cv_avg_miss_rate:12.1f}% | {avg_miss_rate:14.1f}%")
|
| 124 |
+
print("========================================================\n")
|
| 125 |
+
|
| 126 |
+
if __name__ == '__main__':
|
| 127 |
+
evaluate()
|
backend/scripts/evaluation/evaluate_phase2_fusion.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
from backend.app.legacy.data_loader import (
|
| 8 |
+
load_json,
|
| 9 |
+
extract_pedestrian_instances,
|
| 10 |
+
build_trajectories_with_sensor,
|
| 11 |
+
create_windows_with_sensor,
|
| 12 |
+
)
|
| 13 |
+
from backend.app.legacy.dataset_fusion import FusionTrajectoryDataset
|
| 14 |
+
from backend.app.ml.model_fusion import TrajectoryTransformerFusion
|
| 15 |
+
|
| 16 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 17 |
+
DEFAULT_FUSION_CKPT = REPO_ROOT / "models" / "best_social_model_fusion.pth"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def collate_fn_fusion(batch):
|
| 21 |
+
obs, neighbors, fusion_obs, future = zip(*batch)
|
| 22 |
+
obs = torch.stack(obs)
|
| 23 |
+
fusion_obs = torch.stack(fusion_obs)
|
| 24 |
+
future = torch.stack(future)
|
| 25 |
+
return obs, list(neighbors), fusion_obs, future
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def compute_ade(pred, gt):
|
| 29 |
+
return torch.mean(torch.norm(pred - gt, dim=2))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_fde(pred, gt):
|
| 33 |
+
return torch.mean(torch.norm(pred[:, -1] - gt[:, -1], dim=1))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_fusion_samples():
|
| 37 |
+
sample_annotations = load_json("sample_annotation")
|
| 38 |
+
instances = load_json("instance")
|
| 39 |
+
categories = load_json("category")
|
| 40 |
+
|
| 41 |
+
ped_instances = extract_pedestrian_instances(sample_annotations, instances, categories)
|
| 42 |
+
trajectories = build_trajectories_with_sensor(sample_annotations, ped_instances)
|
| 43 |
+
return create_windows_with_sensor(trajectories)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def evaluate_fusion(ckpt=DEFAULT_FUSION_CKPT):
|
| 47 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 48 |
+
print(f"Running Phase 2 Fusion Evaluation on {device}...")
|
| 49 |
+
|
| 50 |
+
ckpt_path = Path(ckpt)
|
| 51 |
+
if not ckpt_path.is_absolute():
|
| 52 |
+
ckpt_path = REPO_ROOT / ckpt_path
|
| 53 |
+
|
| 54 |
+
samples = load_fusion_samples()
|
| 55 |
+
random.seed(42)
|
| 56 |
+
random.shuffle(samples)
|
| 57 |
+
train_size = int(0.8 * len(samples))
|
| 58 |
+
val_samples = samples[train_size:]
|
| 59 |
+
|
| 60 |
+
dataset = FusionTrajectoryDataset(val_samples, augment=False)
|
| 61 |
+
loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn_fusion)
|
| 62 |
+
|
| 63 |
+
model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
|
| 64 |
+
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
| 65 |
+
model.eval()
|
| 66 |
+
|
| 67 |
+
total_ade = 0.0
|
| 68 |
+
total_fde = 0.0
|
| 69 |
+
miss_count = 0
|
| 70 |
+
|
| 71 |
+
cv_total_ade = 0.0
|
| 72 |
+
cv_total_fde = 0.0
|
| 73 |
+
cv_miss_count = 0
|
| 74 |
+
|
| 75 |
+
total_n = 0
|
| 76 |
+
miss_threshold = 2.0
|
| 77 |
+
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
for obs, neighbors, fusion_obs, future in loader:
|
| 80 |
+
obs = obs.to(device)
|
| 81 |
+
fusion_obs = fusion_obs.to(device)
|
| 82 |
+
future = future.to(device)
|
| 83 |
+
|
| 84 |
+
pred, goals, probs, _ = model(obs, neighbors, fusion_obs)
|
| 85 |
+
|
| 86 |
+
gt = future.unsqueeze(1)
|
| 87 |
+
err = torch.norm(pred - gt, dim=3).mean(dim=2)
|
| 88 |
+
best_idx = torch.argmin(err, dim=1)
|
| 89 |
+
best_pred = pred[torch.arange(pred.size(0), device=device), best_idx]
|
| 90 |
+
|
| 91 |
+
total_ade += compute_ade(best_pred, future).item() * obs.size(0)
|
| 92 |
+
total_fde += compute_fde(best_pred, future).item() * obs.size(0)
|
| 93 |
+
|
| 94 |
+
final_disp = torch.norm(best_pred[:, -1] - future[:, -1], dim=1)
|
| 95 |
+
miss_count += (final_disp > miss_threshold).sum().item()
|
| 96 |
+
|
| 97 |
+
# Constant velocity baseline for comparison.
|
| 98 |
+
vx = obs[:, 3, 2].unsqueeze(1)
|
| 99 |
+
vy = obs[:, 3, 3].unsqueeze(1)
|
| 100 |
+
t = torch.arange(1, 13, device=device).unsqueeze(0).float()
|
| 101 |
+
x_last = obs[:, 3, 0].unsqueeze(1)
|
| 102 |
+
y_last = obs[:, 3, 1].unsqueeze(1)
|
| 103 |
+
|
| 104 |
+
cv_x = x_last + vx * t
|
| 105 |
+
cv_y = y_last + vy * t
|
| 106 |
+
cv_pred = torch.stack([cv_x, cv_y], dim=-1)
|
| 107 |
+
|
| 108 |
+
cv_total_ade += compute_ade(cv_pred, future).item() * obs.size(0)
|
| 109 |
+
cv_total_fde += compute_fde(cv_pred, future).item() * obs.size(0)
|
| 110 |
+
cv_final = torch.norm(cv_pred[:, -1] - future[:, -1], dim=1)
|
| 111 |
+
cv_miss_count += (cv_final > miss_threshold).sum().item()
|
| 112 |
+
|
| 113 |
+
total_n += obs.size(0)
|
| 114 |
+
|
| 115 |
+
avg_ade = total_ade / total_n
|
| 116 |
+
avg_fde = total_fde / total_n
|
| 117 |
+
avg_miss = 100.0 * miss_count / total_n
|
| 118 |
+
|
| 119 |
+
cv_avg_ade = cv_total_ade / total_n
|
| 120 |
+
cv_avg_fde = cv_total_fde / total_n
|
| 121 |
+
cv_avg_miss = 100.0 * cv_miss_count / total_n
|
| 122 |
+
|
| 123 |
+
print("\n========================================================")
|
| 124 |
+
print(" PHASE 2 FUSION METRICS REPORT ")
|
| 125 |
+
print("========================================================")
|
| 126 |
+
print(f"Total Trajectories Evaluated: {total_n}")
|
| 127 |
+
print("--------------------------------------------------------")
|
| 128 |
+
print("METRIC | BASELINE (CV) | FUSION MODEL ")
|
| 129 |
+
print("------------------------|---------------|----------------")
|
| 130 |
+
print(f"minADE@3 (meters) | {cv_avg_ade:13.2f} | {avg_ade:14.2f}")
|
| 131 |
+
print(f"minFDE@3 (meters) | {cv_avg_fde:13.2f} | {avg_fde:14.2f}")
|
| 132 |
+
print(f"Miss Rate (>2.0m) | {cv_avg_miss:12.1f}% | {avg_miss:13.1f}%")
|
| 133 |
+
print("========================================================\n")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == '__main__':
|
| 137 |
+
evaluate_fusion()
|
backend/scripts/legacy/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Legacy Streamlit and hackathon app modules."""
|
backend/scripts/legacy/app_streamlit.py
ADDED
|
@@ -0,0 +1,2533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import io
|
| 3 |
+
import math
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import plotly.graph_objects as go
|
| 10 |
+
import streamlit as st
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import cv2
|
| 16 |
+
except Exception:
|
| 17 |
+
cv2 = None
|
| 18 |
+
|
| 19 |
+
from torchvision.models.detection import (
|
| 20 |
+
FasterRCNN_ResNet50_FPN_Weights,
|
| 21 |
+
KeypointRCNN_ResNet50_FPN_Weights,
|
| 22 |
+
fasterrcnn_resnet50_fpn,
|
| 23 |
+
keypointrcnn_resnet50_fpn,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from backend.app.ml.inference import USING_FUSION_MODEL, predict as trajectory_predict
|
| 27 |
+
from backend.app.ml.sensor_fusion import load_fusion_for_cam_frame, radar_stabilize_motion
|
| 28 |
+
|
| 29 |
+
# ----------------------------
|
| 30 |
+
# PAGE CONFIG
|
| 31 |
+
# ----------------------------
|
| 32 |
+
st.set_page_config(page_title="Multi-Agent Trajectory Prediction Simulator", layout="wide")
|
| 33 |
+
|
| 34 |
+
BG_PRIMARY = "#05070f"
|
| 35 |
+
BG_SECONDARY = "#0b1220"
|
| 36 |
+
GRID_COLOR = "rgba(100, 116, 139, 0.22)"
|
| 37 |
+
ACCENT = "#eb6b26"
|
| 38 |
+
TARGET_PURPLE = "#a855f7"
|
| 39 |
+
VRU_GREEN = "#22c55e"
|
| 40 |
+
VEHICLE_YELLOW = "#facc15"
|
| 41 |
+
EGO_CYAN = "#22d3ee"
|
| 42 |
+
WHITE = "#e5e7eb"
|
| 43 |
+
TRAJ_MODE_COLORS = ["#22d3ee", "#a855f7", "#fb923c"]
|
| 44 |
+
|
| 45 |
+
ROAD_ASPHALT = "rgba(26, 34, 45, 0.94)"
|
| 46 |
+
ROAD_SHOULDER = "rgba(12, 18, 28, 0.90)"
|
| 47 |
+
LANE_SOLID = "rgba(226, 232, 240, 0.88)"
|
| 48 |
+
LANE_DASH = "rgba(203, 213, 225, 0.72)"
|
| 49 |
+
CENTER_DASH = "rgba(250, 204, 21, 0.82)"
|
| 50 |
+
|
| 51 |
+
CAMERA_VIEWS = [
|
| 52 |
+
("CAM_FRONT", "Front", 0.0),
|
| 53 |
+
("CAM_FRONT_LEFT", "Front-Left", 40.0),
|
| 54 |
+
("CAM_FRONT_RIGHT", "Front-Right", -40.0),
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
SYNTH_SKELETON_EDGES = [
|
| 58 |
+
(0, 1),
|
| 59 |
+
(1, 2),
|
| 60 |
+
(1, 3),
|
| 61 |
+
(2, 4),
|
| 62 |
+
(3, 5),
|
| 63 |
+
(1, 6),
|
| 64 |
+
(6, 7),
|
| 65 |
+
(6, 8),
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
COCO_SKELETON_EDGES = [
|
| 69 |
+
(0, 1),
|
| 70 |
+
(0, 2),
|
| 71 |
+
(1, 3),
|
| 72 |
+
(2, 4),
|
| 73 |
+
(5, 6),
|
| 74 |
+
(5, 7),
|
| 75 |
+
(7, 9),
|
| 76 |
+
(6, 8),
|
| 77 |
+
(8, 10),
|
| 78 |
+
(5, 11),
|
| 79 |
+
(6, 12),
|
| 80 |
+
(11, 12),
|
| 81 |
+
(11, 13),
|
| 82 |
+
(13, 15),
|
| 83 |
+
(12, 14),
|
| 84 |
+
(14, 16),
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
COCO_TO_LABEL = {
|
| 88 |
+
1: "person",
|
| 89 |
+
2: "bicycle",
|
| 90 |
+
3: "car",
|
| 91 |
+
4: "motorcycle",
|
| 92 |
+
6: "bus",
|
| 93 |
+
8: "truck",
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
VRU_LABELS = {"person", "bicycle", "motorcycle"}
|
| 97 |
+
VEHICLE_LABELS = {"car", "bus", "truck"}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def normalize_probs(probs):
|
| 101 |
+
arr = np.asarray(probs, dtype=float)
|
| 102 |
+
arr = np.clip(arr, 1e-6, None)
|
| 103 |
+
arr = arr / arr.sum()
|
| 104 |
+
return arr.tolist()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def agent_color(agent):
|
| 108 |
+
if agent.get("is_target", False):
|
| 109 |
+
return TARGET_PURPLE
|
| 110 |
+
if agent.get("type") == "pedestrian":
|
| 111 |
+
return VRU_GREEN
|
| 112 |
+
return VEHICLE_YELLOW
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def coco_kind(label_name):
|
| 116 |
+
if label_name in VRU_LABELS:
|
| 117 |
+
return "pedestrian"
|
| 118 |
+
if label_name in VEHICLE_LABELS:
|
| 119 |
+
return "vehicle"
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def iou_xyxy(box_a, box_b):
|
| 124 |
+
ax1, ay1, ax2, ay2 = box_a
|
| 125 |
+
bx1, by1, bx2, by2 = box_b
|
| 126 |
+
|
| 127 |
+
ix1 = max(ax1, bx1)
|
| 128 |
+
iy1 = max(ay1, by1)
|
| 129 |
+
ix2 = min(ax2, bx2)
|
| 130 |
+
iy2 = min(ay2, by2)
|
| 131 |
+
|
| 132 |
+
iw = max(0.0, ix2 - ix1)
|
| 133 |
+
ih = max(0.0, iy2 - iy1)
|
| 134 |
+
inter = iw * ih
|
| 135 |
+
|
| 136 |
+
area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
|
| 137 |
+
area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
|
| 138 |
+
union = area_a + area_b - inter
|
| 139 |
+
|
| 140 |
+
if union <= 1e-9:
|
| 141 |
+
return 0.0
|
| 142 |
+
return inter / union
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def pixel_to_bev(center_x, bottom_y, width, height):
|
| 146 |
+
# Dynamic scaling from current frame dimensions (no hardcoded resolution assumptions).
|
| 147 |
+
x_div = max(1.0, width / 80.0)
|
| 148 |
+
y_div = max(1.0, height / 50.0)
|
| 149 |
+
|
| 150 |
+
x_m = (center_x - 0.5 * width) / x_div
|
| 151 |
+
y_m = (bottom_y - 0.58 * height) / y_div
|
| 152 |
+
return float(x_m), float(y_m)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def fallback_canvas():
|
| 156 |
+
h, w = 540, 960
|
| 157 |
+
canvas = np.zeros((h, w, 3), dtype=np.uint8)
|
| 158 |
+
canvas[:, :, 0] = 10
|
| 159 |
+
canvas[:, :, 1] = 14
|
| 160 |
+
canvas[:, :, 2] = 28
|
| 161 |
+
return canvas
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@st.cache_data(show_spinner=False)
|
| 165 |
+
def list_channel_image_paths(channel):
|
| 166 |
+
base = Path("DataSet") / "samples" / channel
|
| 167 |
+
if not base.exists():
|
| 168 |
+
return []
|
| 169 |
+
return [str(p) for p in sorted(base.glob("*.jpg"))]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@st.cache_data(show_spinner=False)
|
| 173 |
+
def load_image_array(image_path):
|
| 174 |
+
return np.asarray(Image.open(image_path).convert("RGB"))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_camera_frame(channel, frame_idx=0):
|
| 178 |
+
image_paths = list_channel_image_paths(channel)
|
| 179 |
+
if image_paths:
|
| 180 |
+
idx = int(np.clip(frame_idx, 0, len(image_paths) - 1))
|
| 181 |
+
return load_image_array(image_paths[idx]), image_paths[idx]
|
| 182 |
+
return fallback_canvas(), None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@st.cache_resource(show_spinner=False)
|
| 186 |
+
def load_cv_models():
|
| 187 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
det_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 191 |
+
det_model = fasterrcnn_resnet50_fpn(weights=det_weights, progress=False)
|
| 192 |
+
det_model.to(device).eval()
|
| 193 |
+
|
| 194 |
+
pose_weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 195 |
+
pose_model = keypointrcnn_resnet50_fpn(weights=pose_weights, progress=False)
|
| 196 |
+
pose_model.to(device).eval()
|
| 197 |
+
|
| 198 |
+
return {
|
| 199 |
+
"device": device,
|
| 200 |
+
"device_name": str(device),
|
| 201 |
+
"det_model": det_model,
|
| 202 |
+
"det_weights": det_weights,
|
| 203 |
+
"pose_model": pose_model,
|
| 204 |
+
"pose_weights": pose_weights,
|
| 205 |
+
}
|
| 206 |
+
except Exception as exc:
|
| 207 |
+
return {
|
| 208 |
+
"error": str(exc),
|
| 209 |
+
"device": device,
|
| 210 |
+
"device_name": str(device),
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def detect_objects_and_pose(image_arr, models, score_threshold=0.55, use_pose=True):
|
| 215 |
+
if "error" in models:
|
| 216 |
+
return []
|
| 217 |
+
|
| 218 |
+
device = models["device"]
|
| 219 |
+
pil_img = Image.fromarray(image_arr)
|
| 220 |
+
|
| 221 |
+
det_input = models["det_weights"].transforms()(pil_img).unsqueeze(0).to(device)
|
| 222 |
+
with torch.no_grad():
|
| 223 |
+
det_out = models["det_model"](det_input)[0]
|
| 224 |
+
|
| 225 |
+
boxes = det_out["boxes"].detach().cpu().numpy() if len(det_out["boxes"]) > 0 else np.zeros((0, 4))
|
| 226 |
+
scores = det_out["scores"].detach().cpu().numpy() if len(det_out["scores"]) > 0 else np.zeros((0,))
|
| 227 |
+
labels = det_out["labels"].detach().cpu().numpy() if len(det_out["labels"]) > 0 else np.zeros((0,))
|
| 228 |
+
|
| 229 |
+
detections = []
|
| 230 |
+
for i in range(len(scores)):
|
| 231 |
+
score = float(scores[i])
|
| 232 |
+
label_idx = int(labels[i])
|
| 233 |
+
label_name = COCO_TO_LABEL.get(label_idx)
|
| 234 |
+
|
| 235 |
+
if label_name is None or score < score_threshold:
|
| 236 |
+
continue
|
| 237 |
+
|
| 238 |
+
kind = coco_kind(label_name)
|
| 239 |
+
if kind is None:
|
| 240 |
+
continue
|
| 241 |
+
|
| 242 |
+
x1, y1, x2, y2 = [float(v) for v in boxes[i]]
|
| 243 |
+
detections.append(
|
| 244 |
+
{
|
| 245 |
+
"score": score,
|
| 246 |
+
"raw_label": label_name,
|
| 247 |
+
"kind": kind,
|
| 248 |
+
"box": [x1, y1, x2, y2],
|
| 249 |
+
"center_x": 0.5 * (x1 + x2),
|
| 250 |
+
"bottom_y": y2,
|
| 251 |
+
"keypoints": None,
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if use_pose:
|
| 256 |
+
pose_input = models["pose_weights"].transforms()(pil_img).unsqueeze(0).to(device)
|
| 257 |
+
with torch.no_grad():
|
| 258 |
+
pose_out = models["pose_model"](pose_input)[0]
|
| 259 |
+
|
| 260 |
+
p_boxes = pose_out["boxes"].detach().cpu().numpy() if len(pose_out["boxes"]) > 0 else np.zeros((0, 4))
|
| 261 |
+
p_scores = pose_out["scores"].detach().cpu().numpy() if len(pose_out["scores"]) > 0 else np.zeros((0,))
|
| 262 |
+
p_labels = pose_out["labels"].detach().cpu().numpy() if len(pose_out["labels"]) > 0 else np.zeros((0,))
|
| 263 |
+
p_keypoints = pose_out["keypoints"].detach().cpu().numpy() if len(pose_out["keypoints"]) > 0 else np.zeros((0, 17, 3))
|
| 264 |
+
|
| 265 |
+
assigned = set()
|
| 266 |
+
for i in range(len(p_scores)):
|
| 267 |
+
if int(p_labels[i]) != 1:
|
| 268 |
+
continue
|
| 269 |
+
if float(p_scores[i]) < max(0.25, 0.8 * score_threshold):
|
| 270 |
+
continue
|
| 271 |
+
|
| 272 |
+
pose_box = [float(v) for v in p_boxes[i]]
|
| 273 |
+
best_idx = None
|
| 274 |
+
best_iou = 0.0
|
| 275 |
+
|
| 276 |
+
for det_idx, det in enumerate(detections):
|
| 277 |
+
if det_idx in assigned:
|
| 278 |
+
continue
|
| 279 |
+
if det["raw_label"] != "person":
|
| 280 |
+
continue
|
| 281 |
+
iou_val = iou_xyxy(det["box"], pose_box)
|
| 282 |
+
if iou_val > best_iou:
|
| 283 |
+
best_iou = iou_val
|
| 284 |
+
best_idx = det_idx
|
| 285 |
+
|
| 286 |
+
if best_idx is not None and best_iou > 0.1:
|
| 287 |
+
detections[best_idx]["keypoints"] = p_keypoints[i].tolist()
|
| 288 |
+
assigned.add(best_idx)
|
| 289 |
+
|
| 290 |
+
return detections
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def track_front_agents(front_paths, models, score_threshold=0.55, tracking_gate_px=90.0, use_pose=True):
|
| 294 |
+
tracks = {}
|
| 295 |
+
next_track_id = 1
|
| 296 |
+
front_final_detections = []
|
| 297 |
+
|
| 298 |
+
for frame_idx, frame_path in enumerate(front_paths):
|
| 299 |
+
frame_arr = load_image_array(frame_path)
|
| 300 |
+
h, w = frame_arr.shape[:2]
|
| 301 |
+
|
| 302 |
+
detections = detect_objects_and_pose(
|
| 303 |
+
frame_arr,
|
| 304 |
+
models,
|
| 305 |
+
score_threshold=score_threshold,
|
| 306 |
+
use_pose=use_pose,
|
| 307 |
+
)
|
| 308 |
+
detections.sort(key=lambda d: d["score"], reverse=True)
|
| 309 |
+
|
| 310 |
+
matched_track_ids = set()
|
| 311 |
+
frame_dets_with_ids = []
|
| 312 |
+
|
| 313 |
+
for det in detections:
|
| 314 |
+
wx, wy = pixel_to_bev(det["center_x"], det["bottom_y"], w, h)
|
| 315 |
+
|
| 316 |
+
best_track_id = None
|
| 317 |
+
best_dist = 1e9
|
| 318 |
+
|
| 319 |
+
for tid, tr in tracks.items():
|
| 320 |
+
if tr["kind"] != det["kind"]:
|
| 321 |
+
continue
|
| 322 |
+
if tr["last_seen"] != frame_idx - 1:
|
| 323 |
+
continue
|
| 324 |
+
if tid in matched_track_ids:
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
px_last, py_last = tr["history_pixel"][-1]
|
| 328 |
+
dist = math.hypot(det["center_x"] - px_last, det["bottom_y"] - py_last)
|
| 329 |
+
if dist < tracking_gate_px and dist < best_dist:
|
| 330 |
+
best_dist = dist
|
| 331 |
+
best_track_id = tid
|
| 332 |
+
|
| 333 |
+
if best_track_id is None:
|
| 334 |
+
best_track_id = next_track_id
|
| 335 |
+
next_track_id += 1
|
| 336 |
+
tracks[best_track_id] = {
|
| 337 |
+
"id": best_track_id,
|
| 338 |
+
"kind": det["kind"],
|
| 339 |
+
"raw_label": det["raw_label"],
|
| 340 |
+
"history_pixel": [],
|
| 341 |
+
"history_world": [],
|
| 342 |
+
"last_seen": -1,
|
| 343 |
+
"last_box": None,
|
| 344 |
+
"last_keypoints": None,
|
| 345 |
+
"misses": 0,
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
tr = tracks[best_track_id]
|
| 349 |
+
tr["history_pixel"].append((float(det["center_x"]), float(det["bottom_y"])))
|
| 350 |
+
tr["history_world"].append((float(wx), float(wy)))
|
| 351 |
+
tr["last_seen"] = frame_idx
|
| 352 |
+
tr["raw_label"] = det["raw_label"]
|
| 353 |
+
tr["last_box"] = det["box"]
|
| 354 |
+
tr["last_keypoints"] = det.get("keypoints")
|
| 355 |
+
tr["misses"] = 0
|
| 356 |
+
|
| 357 |
+
matched_track_ids.add(best_track_id)
|
| 358 |
+
|
| 359 |
+
det = dict(det)
|
| 360 |
+
det["track_id"] = best_track_id
|
| 361 |
+
frame_dets_with_ids.append(det)
|
| 362 |
+
|
| 363 |
+
# Extrapolate temporarily-lost tracks so 4-point histories can still be formed.
|
| 364 |
+
for tid, tr in tracks.items():
|
| 365 |
+
if tr["last_seen"] == frame_idx:
|
| 366 |
+
continue
|
| 367 |
+
if tr["last_seen"] < frame_idx - 1:
|
| 368 |
+
continue
|
| 369 |
+
|
| 370 |
+
if len(tr["history_pixel"]) >= 2:
|
| 371 |
+
px_prev, py_prev = tr["history_pixel"][-2]
|
| 372 |
+
px_last, py_last = tr["history_pixel"][-1]
|
| 373 |
+
wx_prev, wy_prev = tr["history_world"][-2]
|
| 374 |
+
wx_last, wy_last = tr["history_world"][-1]
|
| 375 |
+
|
| 376 |
+
px_ex = px_last + (px_last - px_prev)
|
| 377 |
+
py_ex = py_last + (py_last - py_prev)
|
| 378 |
+
wx_ex = wx_last + (wx_last - wx_prev)
|
| 379 |
+
wy_ex = wy_last + (wy_last - wy_prev)
|
| 380 |
+
else:
|
| 381 |
+
px_ex, py_ex = tr["history_pixel"][-1]
|
| 382 |
+
wx_ex, wy_ex = tr["history_world"][-1]
|
| 383 |
+
|
| 384 |
+
tr["history_pixel"].append((float(px_ex), float(py_ex)))
|
| 385 |
+
tr["history_world"].append((float(wx_ex), float(wy_ex)))
|
| 386 |
+
tr["last_seen"] = frame_idx
|
| 387 |
+
tr["misses"] += 1
|
| 388 |
+
|
| 389 |
+
if frame_idx == len(front_paths) - 1:
|
| 390 |
+
front_final_detections = frame_dets_with_ids
|
| 391 |
+
|
| 392 |
+
valid_tracks = []
|
| 393 |
+
for tid, tr in tracks.items():
|
| 394 |
+
if len(tr["history_world"]) != len(front_paths):
|
| 395 |
+
continue
|
| 396 |
+
if tr["misses"] > 2:
|
| 397 |
+
continue
|
| 398 |
+
|
| 399 |
+
x0, y0 = tr["history_world"][0]
|
| 400 |
+
x1, y1 = tr["history_world"][-1]
|
| 401 |
+
motion = math.hypot(x1 - x0, y1 - y0)
|
| 402 |
+
if motion < 0.08:
|
| 403 |
+
continue
|
| 404 |
+
|
| 405 |
+
valid_tracks.append(
|
| 406 |
+
{
|
| 407 |
+
"id": tid,
|
| 408 |
+
"kind": tr["kind"],
|
| 409 |
+
"raw_label": tr["raw_label"],
|
| 410 |
+
"history_pixel": [tuple(p) for p in tr["history_pixel"]],
|
| 411 |
+
"history_world": [tuple(p) for p in tr["history_world"]],
|
| 412 |
+
"last_box": tr["last_box"],
|
| 413 |
+
"last_keypoints": tr["last_keypoints"],
|
| 414 |
+
}
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
valid_tracks.sort(key=lambda t: t["id"])
|
| 418 |
+
return valid_tracks, front_final_detections
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def raw_label_to_stabilizer_type(raw_label):
|
| 422 |
+
if raw_label == "person":
|
| 423 |
+
return "Person"
|
| 424 |
+
if raw_label == "bicycle":
|
| 425 |
+
return "Bicycle"
|
| 426 |
+
if raw_label == "motorcycle":
|
| 427 |
+
return "Motorcycle"
|
| 428 |
+
if raw_label == "bus":
|
| 429 |
+
return "Bus"
|
| 430 |
+
if raw_label == "truck":
|
| 431 |
+
return "Truck"
|
| 432 |
+
return "Car"
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def build_fusion_features(history_world, fusion_data):
|
| 436 |
+
if not fusion_data:
|
| 437 |
+
return None
|
| 438 |
+
|
| 439 |
+
lidar_xy = fusion_data.get("lidar_xy")
|
| 440 |
+
radar_xy = fusion_data.get("radar_xy")
|
| 441 |
+
|
| 442 |
+
if lidar_xy is None and radar_xy is None:
|
| 443 |
+
return None
|
| 444 |
+
|
| 445 |
+
feats = []
|
| 446 |
+
for px, py in history_world:
|
| 447 |
+
if lidar_xy is not None and len(lidar_xy) > 0:
|
| 448 |
+
dl = np.hypot(lidar_xy[:, 0] - px, lidar_xy[:, 1] - py)
|
| 449 |
+
lidar_cnt = int((dl < 2.0).sum())
|
| 450 |
+
else:
|
| 451 |
+
lidar_cnt = 0
|
| 452 |
+
|
| 453 |
+
if radar_xy is not None and len(radar_xy) > 0:
|
| 454 |
+
dr = np.hypot(radar_xy[:, 0] - px, radar_xy[:, 1] - py)
|
| 455 |
+
radar_cnt = int((dr < 2.5).sum())
|
| 456 |
+
else:
|
| 457 |
+
radar_cnt = 0
|
| 458 |
+
|
| 459 |
+
lidar_norm = min(80.0, float(lidar_cnt)) / 80.0
|
| 460 |
+
radar_norm = min(30.0, float(radar_cnt)) / 30.0
|
| 461 |
+
sensor_strength = min(1.0, (float(lidar_cnt) + 2.0 * float(radar_cnt)) / 100.0)
|
| 462 |
+
feats.append([lidar_norm, radar_norm, sensor_strength])
|
| 463 |
+
|
| 464 |
+
return feats
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def stabilize_tracks_with_radar(tracks, fusion_data):
|
| 468 |
+
if not tracks:
|
| 469 |
+
return tracks
|
| 470 |
+
|
| 471 |
+
packed = []
|
| 472 |
+
for tr in tracks:
|
| 473 |
+
hist = tr["history_world"]
|
| 474 |
+
if len(hist) >= 2:
|
| 475 |
+
dx = float(hist[-1][0] - hist[-2][0])
|
| 476 |
+
dy = float(hist[-1][1] - hist[-2][1])
|
| 477 |
+
else:
|
| 478 |
+
dx = 0.0
|
| 479 |
+
dy = 0.0
|
| 480 |
+
|
| 481 |
+
packed.append(
|
| 482 |
+
{
|
| 483 |
+
"type": raw_label_to_stabilizer_type(tr.get("raw_label", "car")),
|
| 484 |
+
"history": [tuple(p) for p in hist],
|
| 485 |
+
"dx": dx,
|
| 486 |
+
"dy": dy,
|
| 487 |
+
}
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
stabilized = radar_stabilize_motion(packed, fusion_data, dt_seconds=0.5)
|
| 491 |
+
|
| 492 |
+
updated = []
|
| 493 |
+
for tr, st in zip(tracks, stabilized):
|
| 494 |
+
t_copy = dict(tr)
|
| 495 |
+
t_copy["history_world"] = [(float(x), float(y)) for x, y in st["history"]]
|
| 496 |
+
updated.append(t_copy)
|
| 497 |
+
|
| 498 |
+
return updated
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def choose_target_track_id(tracks):
|
| 502 |
+
if not tracks:
|
| 503 |
+
return None
|
| 504 |
+
|
| 505 |
+
peds = [t for t in tracks if t["kind"] == "pedestrian"]
|
| 506 |
+
if peds:
|
| 507 |
+
best = min(peds, key=lambda t: math.hypot(t["history_world"][-1][0], t["history_world"][-1][1]))
|
| 508 |
+
return best["id"]
|
| 509 |
+
|
| 510 |
+
return tracks[0]["id"]
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def build_agents_from_tracks(tracks, fusion_data):
|
| 514 |
+
if not tracks:
|
| 515 |
+
return [], None, []
|
| 516 |
+
|
| 517 |
+
tracks_work = []
|
| 518 |
+
for tr in tracks:
|
| 519 |
+
tracks_work.append(
|
| 520 |
+
{
|
| 521 |
+
"id": tr["id"],
|
| 522 |
+
"kind": tr["kind"],
|
| 523 |
+
"raw_label": tr["raw_label"],
|
| 524 |
+
"history_pixel": [tuple(p) for p in tr["history_pixel"]],
|
| 525 |
+
"history_world": [tuple(p) for p in tr["history_world"]],
|
| 526 |
+
"last_box": tr.get("last_box"),
|
| 527 |
+
"last_keypoints": tr.get("last_keypoints"),
|
| 528 |
+
}
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
tracks_work = stabilize_tracks_with_radar(tracks_work, fusion_data)
|
| 532 |
+
|
| 533 |
+
target_id = choose_target_track_id(tracks_work)
|
| 534 |
+
agents = []
|
| 535 |
+
|
| 536 |
+
for tr in tracks_work:
|
| 537 |
+
neighbors = []
|
| 538 |
+
for other in tracks_work:
|
| 539 |
+
if other["id"] == tr["id"]:
|
| 540 |
+
continue
|
| 541 |
+
neighbors.append(other["history_world"])
|
| 542 |
+
|
| 543 |
+
if len(neighbors) > 12:
|
| 544 |
+
x0, y0 = tr["history_world"][-1]
|
| 545 |
+
neighbors = sorted(
|
| 546 |
+
neighbors,
|
| 547 |
+
key=lambda nh: math.hypot(nh[-1][0] - x0, nh[-1][1] - y0),
|
| 548 |
+
)[:12]
|
| 549 |
+
|
| 550 |
+
fusion_feats = build_fusion_features(tr["history_world"], fusion_data)
|
| 551 |
+
|
| 552 |
+
pred, probs, _ = trajectory_predict(
|
| 553 |
+
tr["history_world"],
|
| 554 |
+
neighbor_points_list=neighbors,
|
| 555 |
+
fusion_feats=fusion_feats,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
pred_np = pred.detach().cpu().numpy()
|
| 559 |
+
probs_np = probs.detach().cpu().numpy()
|
| 560 |
+
|
| 561 |
+
predictions = []
|
| 562 |
+
for mode_i in range(pred_np.shape[0]):
|
| 563 |
+
mode_path = [(float(p[0]), float(p[1])) for p in pred_np[mode_i]]
|
| 564 |
+
predictions.append(mode_path)
|
| 565 |
+
|
| 566 |
+
agents.append(
|
| 567 |
+
{
|
| 568 |
+
"id": int(tr["id"]),
|
| 569 |
+
"type": "pedestrian" if tr["kind"] == "pedestrian" else "vehicle",
|
| 570 |
+
"raw_label": tr["raw_label"],
|
| 571 |
+
"history": [tuple(map(float, p)) for p in tr["history_world"]],
|
| 572 |
+
"predictions": predictions,
|
| 573 |
+
"probabilities": normalize_probs(probs_np.tolist()),
|
| 574 |
+
"is_target": tr["id"] == target_id,
|
| 575 |
+
}
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
return agents, target_id, tracks_work
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def assign_track_ids_to_front_detections(detections, tracks, gate_px=90.0):
|
| 582 |
+
if not detections:
|
| 583 |
+
return []
|
| 584 |
+
|
| 585 |
+
out = []
|
| 586 |
+
used_ids = set()
|
| 587 |
+
|
| 588 |
+
for det_idx, det in enumerate(detections):
|
| 589 |
+
d = dict(det)
|
| 590 |
+
d.setdefault("det_id", det_idx + 1)
|
| 591 |
+
|
| 592 |
+
if d.get("track_id") is not None:
|
| 593 |
+
used_ids.add(d["track_id"])
|
| 594 |
+
out.append(d)
|
| 595 |
+
continue
|
| 596 |
+
|
| 597 |
+
best_id = None
|
| 598 |
+
best_dist = 1e9
|
| 599 |
+
|
| 600 |
+
for tr in tracks:
|
| 601 |
+
if tr["id"] in used_ids:
|
| 602 |
+
continue
|
| 603 |
+
if tr["kind"] != d["kind"]:
|
| 604 |
+
continue
|
| 605 |
+
|
| 606 |
+
px, py = tr["history_pixel"][-1]
|
| 607 |
+
dist = math.hypot(d["center_x"] - px, d["bottom_y"] - py)
|
| 608 |
+
if dist < gate_px and dist < best_dist:
|
| 609 |
+
best_dist = dist
|
| 610 |
+
best_id = tr["id"]
|
| 611 |
+
|
| 612 |
+
d["track_id"] = best_id
|
| 613 |
+
if best_id is not None:
|
| 614 |
+
used_ids.add(best_id)
|
| 615 |
+
out.append(d)
|
| 616 |
+
|
| 617 |
+
return out
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
@st.cache_data(show_spinner=False)
|
| 621 |
+
def build_live_agents_bundle(anchor_idx, score_threshold, tracking_gate_px, use_pose):
|
| 622 |
+
front_paths = list_channel_image_paths("CAM_FRONT")
|
| 623 |
+
if len(front_paths) < 4:
|
| 624 |
+
return {"error": "Need at least 4 CAM_FRONT frames in DataSet/samples/CAM_FRONT."}
|
| 625 |
+
|
| 626 |
+
if anchor_idx < 3:
|
| 627 |
+
anchor_idx = 3
|
| 628 |
+
if anchor_idx >= len(front_paths):
|
| 629 |
+
anchor_idx = len(front_paths) - 1
|
| 630 |
+
|
| 631 |
+
models = load_cv_models()
|
| 632 |
+
if "error" in models:
|
| 633 |
+
return {
|
| 634 |
+
"error": f"Could not load CV models ({models['error']}).",
|
| 635 |
+
"device": models.get("device_name", "unknown"),
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
window_paths = front_paths[anchor_idx - 3 : anchor_idx + 1]
|
| 639 |
+
|
| 640 |
+
tracks, front_dets = track_front_agents(
|
| 641 |
+
window_paths,
|
| 642 |
+
models,
|
| 643 |
+
score_threshold=score_threshold,
|
| 644 |
+
tracking_gate_px=tracking_gate_px,
|
| 645 |
+
use_pose=use_pose,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
if len(tracks) == 0:
|
| 649 |
+
return {"error": "No valid tracked moving agents found in selected frame window."}
|
| 650 |
+
|
| 651 |
+
front_curr = window_paths[-1]
|
| 652 |
+
fusion_data = load_fusion_for_cam_frame(Path(front_curr).name)
|
| 653 |
+
|
| 654 |
+
agents, target_id, tracks_stable = build_agents_from_tracks(tracks, fusion_data)
|
| 655 |
+
if len(agents) == 0:
|
| 656 |
+
return {"error": "Tracking succeeded but trajectory prediction produced no agents."}
|
| 657 |
+
|
| 658 |
+
snapshots = {}
|
| 659 |
+
for channel, _, _ in CAMERA_VIEWS:
|
| 660 |
+
ch_paths = list_channel_image_paths(channel)
|
| 661 |
+
|
| 662 |
+
if not ch_paths:
|
| 663 |
+
snapshots[channel] = {
|
| 664 |
+
"image": fallback_canvas(),
|
| 665 |
+
"detections": [],
|
| 666 |
+
"frame_path": None,
|
| 667 |
+
}
|
| 668 |
+
continue
|
| 669 |
+
|
| 670 |
+
ch_idx = min(anchor_idx, len(ch_paths) - 1)
|
| 671 |
+
ch_path = ch_paths[ch_idx]
|
| 672 |
+
ch_arr = load_image_array(ch_path)
|
| 673 |
+
|
| 674 |
+
if channel == "CAM_FRONT" and Path(ch_path).name == Path(front_curr).name:
|
| 675 |
+
ch_dets = [dict(d) for d in front_dets]
|
| 676 |
+
else:
|
| 677 |
+
ch_dets = detect_objects_and_pose(
|
| 678 |
+
ch_arr,
|
| 679 |
+
models,
|
| 680 |
+
score_threshold=score_threshold,
|
| 681 |
+
use_pose=use_pose,
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
for i, det in enumerate(ch_dets):
|
| 685 |
+
det.setdefault("track_id", None)
|
| 686 |
+
det.setdefault("det_id", i + 1)
|
| 687 |
+
|
| 688 |
+
snapshots[channel] = {
|
| 689 |
+
"image": ch_arr,
|
| 690 |
+
"detections": ch_dets,
|
| 691 |
+
"frame_path": ch_path,
|
| 692 |
+
}
|
| 693 |
+
|
| 694 |
+
if "CAM_FRONT" in snapshots:
|
| 695 |
+
snapshots["CAM_FRONT"]["detections"] = assign_track_ids_to_front_detections(
|
| 696 |
+
snapshots["CAM_FRONT"]["detections"],
|
| 697 |
+
tracks_stable,
|
| 698 |
+
gate_px=tracking_gate_px,
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
return {
|
| 702 |
+
"agents": agents,
|
| 703 |
+
"fusion_data": fusion_data,
|
| 704 |
+
"camera_snapshots": snapshots,
|
| 705 |
+
"target_track_id": target_id,
|
| 706 |
+
"device": models.get("device_name", "unknown"),
|
| 707 |
+
"front_anchor_path": front_curr,
|
| 708 |
+
"mode": "live_fusion",
|
| 709 |
+
}
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def uploaded_file_to_array(uploaded_file):
|
| 713 |
+
if uploaded_file is None:
|
| 714 |
+
return None
|
| 715 |
+
try:
|
| 716 |
+
return np.asarray(Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB"))
|
| 717 |
+
except Exception:
|
| 718 |
+
return None
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def match_two_frame_tracks(det_prev, det_curr, tracking_gate_px=90.0, min_motion_px=0.0):
|
| 722 |
+
used_curr = set()
|
| 723 |
+
matches = []
|
| 724 |
+
|
| 725 |
+
det_prev = sorted(det_prev, key=lambda d: d["score"], reverse=True)
|
| 726 |
+
det_curr = sorted(det_curr, key=lambda d: d["score"], reverse=True)
|
| 727 |
+
|
| 728 |
+
for d0 in det_prev:
|
| 729 |
+
best_idx = None
|
| 730 |
+
best_dist = 1e9
|
| 731 |
+
|
| 732 |
+
for j, d1 in enumerate(det_curr):
|
| 733 |
+
if j in used_curr:
|
| 734 |
+
continue
|
| 735 |
+
if d0["kind"] != d1["kind"]:
|
| 736 |
+
continue
|
| 737 |
+
|
| 738 |
+
dist = math.hypot(d1["center_x"] - d0["center_x"], d1["bottom_y"] - d0["bottom_y"])
|
| 739 |
+
if dist < tracking_gate_px and dist < best_dist:
|
| 740 |
+
best_dist = dist
|
| 741 |
+
best_idx = j
|
| 742 |
+
|
| 743 |
+
if best_idx is None:
|
| 744 |
+
continue
|
| 745 |
+
|
| 746 |
+
used_curr.add(best_idx)
|
| 747 |
+
d1 = det_curr[best_idx]
|
| 748 |
+
|
| 749 |
+
matches.append((d0, d1, float(best_dist)))
|
| 750 |
+
|
| 751 |
+
return matches
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
def build_two_image_agents_bundle(img_prev, img_curr, score_threshold, tracking_gate_px, min_motion_px, use_pose):
|
| 755 |
+
models = load_cv_models()
|
| 756 |
+
if "error" in models:
|
| 757 |
+
return {
|
| 758 |
+
"error": f"Could not load CV models ({models['error']}).",
|
| 759 |
+
"device": models.get("device_name", "unknown"),
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
det_prev = detect_objects_and_pose(img_prev, models, score_threshold=score_threshold, use_pose=use_pose)
|
| 763 |
+
det_curr = detect_objects_and_pose(img_curr, models, score_threshold=score_threshold, use_pose=use_pose)
|
| 764 |
+
|
| 765 |
+
# Two-image mode focuses on VRUs (pedestrians/cyclists/motorcycles).
|
| 766 |
+
det_prev_vru = [d for d in det_prev if d.get("kind") == "pedestrian"]
|
| 767 |
+
det_curr_vru = [d for d in det_curr if d.get("kind") == "pedestrian"]
|
| 768 |
+
|
| 769 |
+
for i, d in enumerate(det_prev):
|
| 770 |
+
d["det_id"] = i + 1
|
| 771 |
+
d["track_id"] = None
|
| 772 |
+
for i, d in enumerate(det_curr):
|
| 773 |
+
d["det_id"] = i + 1
|
| 774 |
+
d["track_id"] = None
|
| 775 |
+
|
| 776 |
+
if len(det_curr_vru) == 0:
|
| 777 |
+
return {"error": "No pedestrian/cyclist detections found in image 2 (t0)."}
|
| 778 |
+
|
| 779 |
+
matches = match_two_frame_tracks(
|
| 780 |
+
det_prev_vru,
|
| 781 |
+
det_curr_vru,
|
| 782 |
+
tracking_gate_px=tracking_gate_px,
|
| 783 |
+
min_motion_px=0.0,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# Backfill unmatched current VRUs so every visible VRU at t0 gets a prediction.
|
| 787 |
+
matched_curr_ids = {id(m[1]) for m in matches}
|
| 788 |
+
for d1 in det_curr_vru:
|
| 789 |
+
if id(d1) in matched_curr_ids:
|
| 790 |
+
continue
|
| 791 |
+
|
| 792 |
+
if len(det_prev_vru) == 0:
|
| 793 |
+
matches.append((None, d1, float("inf")))
|
| 794 |
+
continue
|
| 795 |
+
|
| 796 |
+
nearest_prev = min(
|
| 797 |
+
det_prev_vru,
|
| 798 |
+
key=lambda d0: math.hypot(d1["center_x"] - d0["center_x"], d1["bottom_y"] - d0["bottom_y"]),
|
| 799 |
+
)
|
| 800 |
+
dist = math.hypot(
|
| 801 |
+
d1["center_x"] - nearest_prev["center_x"],
|
| 802 |
+
d1["bottom_y"] - nearest_prev["bottom_y"],
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
# If previous frame support is weak, still include the agent with near-static history.
|
| 806 |
+
if dist <= 1.5 * tracking_gate_px:
|
| 807 |
+
matches.append((nearest_prev, d1, float(dist)))
|
| 808 |
+
else:
|
| 809 |
+
matches.append((None, d1, float("inf")))
|
| 810 |
+
|
| 811 |
+
h0, w0 = img_prev.shape[:2]
|
| 812 |
+
h1, w1 = img_curr.shape[:2]
|
| 813 |
+
|
| 814 |
+
tracks = []
|
| 815 |
+
for track_id, (d0, d1, dist_px) in enumerate(matches, start=1):
|
| 816 |
+
if d0 is not None and d0.get("track_id") is None:
|
| 817 |
+
d0["track_id"] = track_id
|
| 818 |
+
d1["track_id"] = track_id
|
| 819 |
+
|
| 820 |
+
if d0 is not None:
|
| 821 |
+
p_prev = pixel_to_bev(d0["center_x"], d0["bottom_y"], w0, h0)
|
| 822 |
+
else:
|
| 823 |
+
p_prev = None
|
| 824 |
+
p_curr = pixel_to_bev(d1["center_x"], d1["bottom_y"], w1, h1)
|
| 825 |
+
|
| 826 |
+
if p_prev is None:
|
| 827 |
+
vx, vy = 0.0, 0.0
|
| 828 |
+
p_prev = p_curr
|
| 829 |
+
else:
|
| 830 |
+
vx = p_curr[0] - p_prev[0]
|
| 831 |
+
vy = p_curr[1] - p_prev[1]
|
| 832 |
+
|
| 833 |
+
# Keep the agent even if tiny displacement; just make observation history static.
|
| 834 |
+
if dist_px < float(min_motion_px):
|
| 835 |
+
vx, vy = 0.0, 0.0
|
| 836 |
+
p_prev = p_curr
|
| 837 |
+
|
| 838 |
+
# Reconstruct a 4-point observation history from 2 frames.
|
| 839 |
+
hist = [
|
| 840 |
+
(p_curr[0] - 3.0 * vx, p_curr[1] - 3.0 * vy),
|
| 841 |
+
(p_curr[0] - 2.0 * vx, p_curr[1] - 2.0 * vy),
|
| 842 |
+
(p_prev[0], p_prev[1]),
|
| 843 |
+
(p_curr[0], p_curr[1]),
|
| 844 |
+
]
|
| 845 |
+
|
| 846 |
+
tracks.append(
|
| 847 |
+
{
|
| 848 |
+
"id": track_id,
|
| 849 |
+
"kind": d1["kind"],
|
| 850 |
+
"raw_label": d1["raw_label"],
|
| 851 |
+
"history_world": hist,
|
| 852 |
+
}
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
# In this mode, every VRU is treated as a target for prediction display.
|
| 856 |
+
target_track_id = None
|
| 857 |
+
|
| 858 |
+
agents = []
|
| 859 |
+
for tr in tracks:
|
| 860 |
+
neighbors = [other["history_world"] for other in tracks if other["id"] != tr["id"]]
|
| 861 |
+
|
| 862 |
+
pred, probs, _ = trajectory_predict(
|
| 863 |
+
tr["history_world"],
|
| 864 |
+
neighbor_points_list=neighbors,
|
| 865 |
+
fusion_feats=None,
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
pred_np = pred.detach().cpu().numpy()
|
| 869 |
+
probs_np = probs.detach().cpu().numpy()
|
| 870 |
+
|
| 871 |
+
predictions = []
|
| 872 |
+
for mode_i in range(pred_np.shape[0]):
|
| 873 |
+
predictions.append([(float(p[0]), float(p[1])) for p in pred_np[mode_i]])
|
| 874 |
+
|
| 875 |
+
agents.append(
|
| 876 |
+
{
|
| 877 |
+
"id": int(tr["id"]),
|
| 878 |
+
"type": "pedestrian" if tr["kind"] == "pedestrian" else "vehicle",
|
| 879 |
+
"raw_label": tr["raw_label"],
|
| 880 |
+
"history": [tuple(map(float, p)) for p in tr["history_world"]],
|
| 881 |
+
"predictions": predictions,
|
| 882 |
+
"probabilities": normalize_probs(probs_np.tolist()),
|
| 883 |
+
"is_target": True,
|
| 884 |
+
}
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
return {
|
| 888 |
+
"agents": agents,
|
| 889 |
+
"target_track_id": target_track_id,
|
| 890 |
+
"camera_snapshots": {
|
| 891 |
+
"pair_prev": {"image": img_prev, "detections": det_prev},
|
| 892 |
+
"pair_curr": {"image": img_curr, "detections": det_curr},
|
| 893 |
+
},
|
| 894 |
+
"device": models.get("device_name", "unknown"),
|
| 895 |
+
"mode": "two_upload",
|
| 896 |
+
"match_count": len(agents),
|
| 897 |
+
}
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
def bev_to_pixel(x_m, y_m, width, height):
|
| 901 |
+
x_div = max(1.0, width / 80.0)
|
| 902 |
+
y_div = max(1.0, height / 50.0)
|
| 903 |
+
|
| 904 |
+
px = x_m * x_div + 0.5 * width
|
| 905 |
+
py = y_m * y_div + 0.58 * height
|
| 906 |
+
return float(px), float(py)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def create_prediction_overlay_figure(image_arr, detections, agents, step, target_track_id=None, highlight_track_ids=None):
|
| 910 |
+
fig = create_camera_figure_detections(
|
| 911 |
+
image_arr,
|
| 912 |
+
detections,
|
| 913 |
+
camera_label="Prediction Output",
|
| 914 |
+
target_track_id=target_track_id,
|
| 915 |
+
highlight_track_ids=highlight_track_ids,
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
h, w = image_arr.shape[:2]
|
| 919 |
+
|
| 920 |
+
for a in agents:
|
| 921 |
+
color = agent_color(a)
|
| 922 |
+
k = best_mode_idx(a)
|
| 923 |
+
pred = a["predictions"][k]
|
| 924 |
+
end_idx = max(1, min(step, len(pred)))
|
| 925 |
+
path_world = [a["history"][-1]] + pred[:end_idx]
|
| 926 |
+
|
| 927 |
+
px = []
|
| 928 |
+
py = []
|
| 929 |
+
for xw, yw in path_world:
|
| 930 |
+
u, v = bev_to_pixel(xw, yw, w, h)
|
| 931 |
+
px.append(u)
|
| 932 |
+
py.append(v)
|
| 933 |
+
|
| 934 |
+
# Glow trail for a cleaner, reference-style visual emphasis.
|
| 935 |
+
for lw, op in [(14, 0.12), (8, 0.20), (4, 0.95)]:
|
| 936 |
+
fig.add_trace(
|
| 937 |
+
go.Scatter(
|
| 938 |
+
x=px,
|
| 939 |
+
y=py,
|
| 940 |
+
mode="lines",
|
| 941 |
+
line={"color": color, "width": lw, "shape": "spline", "smoothing": 1.1},
|
| 942 |
+
opacity=op,
|
| 943 |
+
hoverinfo="skip",
|
| 944 |
+
showlegend=False,
|
| 945 |
+
)
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
return fig
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
def remove_vru_foreground_from_scene(scene_image, scene_detections=None):
|
| 952 |
+
if scene_image is None or cv2 is None:
|
| 953 |
+
return scene_image
|
| 954 |
+
|
| 955 |
+
if scene_detections is None or len(scene_detections) == 0:
|
| 956 |
+
return scene_image
|
| 957 |
+
|
| 958 |
+
h, w = scene_image.shape[:2]
|
| 959 |
+
mask = np.zeros((h, w), dtype=np.uint8)
|
| 960 |
+
|
| 961 |
+
for det in scene_detections:
|
| 962 |
+
if det.get("kind") != "pedestrian":
|
| 963 |
+
continue
|
| 964 |
+
|
| 965 |
+
x1, y1, x2, y2 = det.get("box", [0, 0, 0, 0])
|
| 966 |
+
padx = 0.08 * (x2 - x1)
|
| 967 |
+
pady = 0.10 * (y2 - y1)
|
| 968 |
+
|
| 969 |
+
xa = int(max(0, min(w - 1, x1 - padx)))
|
| 970 |
+
ya = int(max(0, min(h - 1, y1 - pady)))
|
| 971 |
+
xb = int(max(0, min(w - 1, x2 + padx)))
|
| 972 |
+
yb = int(max(0, min(h - 1, y2 + pady)))
|
| 973 |
+
|
| 974 |
+
if xb > xa and yb > ya:
|
| 975 |
+
cv2.rectangle(mask, (xa, ya), (xb, yb), color=255, thickness=-1)
|
| 976 |
+
|
| 977 |
+
if int(mask.sum()) == 0:
|
| 978 |
+
return scene_image
|
| 979 |
+
|
| 980 |
+
bgr = cv2.cvtColor(scene_image, cv2.COLOR_RGB2BGR)
|
| 981 |
+
inpainted = cv2.inpaint(bgr, mask, 7, cv2.INPAINT_TELEA)
|
| 982 |
+
return cv2.cvtColor(inpainted, cv2.COLOR_BGR2RGB)
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
def build_pseudo_bev_background(scene_image, x_min, x_max, y_min, y_max, scene_detections=None):
|
| 986 |
+
# Context BEV from a single front-view frame using inverse-perspective remap.
|
| 987 |
+
if scene_image is None or cv2 is None:
|
| 988 |
+
return None
|
| 989 |
+
|
| 990 |
+
cleaned = remove_vru_foreground_from_scene(scene_image, scene_detections=scene_detections)
|
| 991 |
+
h, w = cleaned.shape[:2]
|
| 992 |
+
if h < 20 or w < 20:
|
| 993 |
+
return None
|
| 994 |
+
|
| 995 |
+
out_w, out_h = 1100, 820
|
| 996 |
+
|
| 997 |
+
xs = np.linspace(x_min, x_max, out_w, dtype=np.float32)
|
| 998 |
+
ys = np.linspace(y_max, y_min, out_h, dtype=np.float32)
|
| 999 |
+
xg, yg = np.meshgrid(xs, ys)
|
| 1000 |
+
|
| 1001 |
+
cx = 0.5 * w
|
| 1002 |
+
horizon = 0.42 * h
|
| 1003 |
+
|
| 1004 |
+
depth = np.clip((yg - y_min) + 2.0, 2.0, None)
|
| 1005 |
+
|
| 1006 |
+
map_x = cx + (0.95 * w) * xg / (depth + 6.0)
|
| 1007 |
+
map_y = horizon + (5.8 * h) / depth
|
| 1008 |
+
|
| 1009 |
+
map_x = np.clip(map_x, 0, w - 1).astype(np.float32)
|
| 1010 |
+
map_y = np.clip(map_y, 0, h - 1).astype(np.float32)
|
| 1011 |
+
|
| 1012 |
+
warped = cv2.remap(cleaned, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
|
| 1013 |
+
warped = cv2.GaussianBlur(warped, (0, 0), 0.8)
|
| 1014 |
+
warped = np.clip(warped.astype(np.float32) * 0.78, 0, 255).astype(np.uint8)
|
| 1015 |
+
return warped
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
def compute_reference_bounds(agents, step, show_multimodal):
|
| 1019 |
+
xs = [0.0]
|
| 1020 |
+
ys = [0.0]
|
| 1021 |
+
|
| 1022 |
+
for a in agents:
|
| 1023 |
+
for xh, yh in a["history"]:
|
| 1024 |
+
xs.append(float(xh))
|
| 1025 |
+
ys.append(float(yh))
|
| 1026 |
+
|
| 1027 |
+
k_best = best_mode_idx(a)
|
| 1028 |
+
best_path = a["predictions"][k_best][: max(1, min(step, len(a["predictions"][k_best])))]
|
| 1029 |
+
for xp, yp in best_path:
|
| 1030 |
+
xs.append(float(xp))
|
| 1031 |
+
ys.append(float(yp))
|
| 1032 |
+
|
| 1033 |
+
if show_multimodal:
|
| 1034 |
+
for m, m_path in enumerate(a["predictions"]):
|
| 1035 |
+
if m == k_best:
|
| 1036 |
+
continue
|
| 1037 |
+
m_slice = m_path[: max(1, min(step, len(m_path)))]
|
| 1038 |
+
for xp, yp in m_slice:
|
| 1039 |
+
xs.append(float(xp))
|
| 1040 |
+
ys.append(float(yp))
|
| 1041 |
+
|
| 1042 |
+
x_min = min(xs) - 6.0
|
| 1043 |
+
x_max = max(xs) + 6.0
|
| 1044 |
+
y_min = min(ys) - 8.0
|
| 1045 |
+
y_max = max(ys) + 10.0
|
| 1046 |
+
|
| 1047 |
+
min_x_span = 44.0
|
| 1048 |
+
min_y_span = 64.0
|
| 1049 |
+
|
| 1050 |
+
x_span = x_max - x_min
|
| 1051 |
+
y_span = y_max - y_min
|
| 1052 |
+
|
| 1053 |
+
if x_span < min_x_span:
|
| 1054 |
+
xc = 0.5 * (x_min + x_max)
|
| 1055 |
+
x_min = xc - 0.5 * min_x_span
|
| 1056 |
+
x_max = xc + 0.5 * min_x_span
|
| 1057 |
+
|
| 1058 |
+
if y_span < min_y_span:
|
| 1059 |
+
yc = 0.5 * (y_min + y_max)
|
| 1060 |
+
y_min = yc - 0.5 * min_y_span
|
| 1061 |
+
y_max = yc + 0.5 * min_y_span
|
| 1062 |
+
|
| 1063 |
+
return x_min, x_max, y_min, y_max
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
def spread_agent_markers(agents, step, tol=0.45, radius=0.55):
|
| 1067 |
+
positions = [position_at_step(a, step) for a in agents]
|
| 1068 |
+
offsets = []
|
| 1069 |
+
|
| 1070 |
+
for i, (xi, yi) in enumerate(positions):
|
| 1071 |
+
near = []
|
| 1072 |
+
for j, (xj, yj) in enumerate(positions):
|
| 1073 |
+
if math.hypot(xi - xj, yi - yj) <= tol:
|
| 1074 |
+
near.append(j)
|
| 1075 |
+
|
| 1076 |
+
if len(near) <= 1:
|
| 1077 |
+
offsets.append((0.0, 0.0))
|
| 1078 |
+
continue
|
| 1079 |
+
|
| 1080 |
+
near_sorted = sorted(near)
|
| 1081 |
+
rank = near_sorted.index(i)
|
| 1082 |
+
ang = 2.0 * math.pi * rank / len(near_sorted)
|
| 1083 |
+
offsets.append((radius * math.cos(ang), radius * math.sin(ang)))
|
| 1084 |
+
|
| 1085 |
+
return positions, offsets
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
def hex_to_rgba(hex_color, alpha):
|
| 1089 |
+
alpha = float(np.clip(alpha, 0.0, 1.0))
|
| 1090 |
+
c = str(hex_color).lstrip("#")
|
| 1091 |
+
if len(c) != 6:
|
| 1092 |
+
return f"rgba(229,231,235,{alpha:.3f})"
|
| 1093 |
+
r = int(c[0:2], 16)
|
| 1094 |
+
g = int(c[2:4], 16)
|
| 1095 |
+
b = int(c[4:6], 16)
|
| 1096 |
+
return f"rgba({r},{g},{b},{alpha:.3f})"
|
| 1097 |
+
|
| 1098 |
+
|
| 1099 |
+
def summarize_agent_probabilities(agent):
|
| 1100 |
+
bins = {"Straight": 0.0, "Left": 0.0, "Right": 0.0, "Stop": 0.0}
|
| 1101 |
+
|
| 1102 |
+
classifier = globals().get("classify_direction")
|
| 1103 |
+
for mode_idx, mode_path in enumerate(agent.get("predictions", [])):
|
| 1104 |
+
if mode_idx >= len(agent.get("probabilities", [])):
|
| 1105 |
+
continue
|
| 1106 |
+
|
| 1107 |
+
if callable(classifier):
|
| 1108 |
+
direction = classifier(agent["history"], mode_path)
|
| 1109 |
+
else:
|
| 1110 |
+
direction = ["Straight", "Left", "Right"][mode_idx % 3]
|
| 1111 |
+
|
| 1112 |
+
if direction not in bins:
|
| 1113 |
+
direction = "Straight"
|
| 1114 |
+
|
| 1115 |
+
bins[direction] += float(agent["probabilities"][mode_idx])
|
| 1116 |
+
|
| 1117 |
+
ranked = sorted(bins.items(), key=lambda kv: kv[1], reverse=True)
|
| 1118 |
+
top3 = ranked[:3]
|
| 1119 |
+
summary = ", ".join([f"{name} {prob * 100:.0f}%" for name, prob in top3])
|
| 1120 |
+
return summary, bins
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
def add_structured_road_scene(fig, x_min, x_max, y_min, y_max, add_crosswalk=True):
|
| 1124 |
+
road_half = float(np.clip(0.24 * (x_max - x_min), 9.5, 15.5))
|
| 1125 |
+
shoulder_half = road_half + 3.2
|
| 1126 |
+
|
| 1127 |
+
fig.add_shape(
|
| 1128 |
+
type="rect",
|
| 1129 |
+
x0=x_min,
|
| 1130 |
+
y0=y_min,
|
| 1131 |
+
x1=x_max,
|
| 1132 |
+
y1=y_max,
|
| 1133 |
+
line={"width": 0},
|
| 1134 |
+
fillcolor=ROAD_SHOULDER,
|
| 1135 |
+
layer="below",
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
fig.add_shape(
|
| 1139 |
+
type="rect",
|
| 1140 |
+
x0=-shoulder_half,
|
| 1141 |
+
y0=y_min,
|
| 1142 |
+
x1=shoulder_half,
|
| 1143 |
+
y1=y_max,
|
| 1144 |
+
line={"width": 0},
|
| 1145 |
+
fillcolor="rgba(18, 25, 35, 0.95)",
|
| 1146 |
+
layer="below",
|
| 1147 |
+
)
|
| 1148 |
+
|
| 1149 |
+
fig.add_shape(
|
| 1150 |
+
type="rect",
|
| 1151 |
+
x0=-road_half,
|
| 1152 |
+
y0=y_min,
|
| 1153 |
+
x1=road_half,
|
| 1154 |
+
y1=y_max,
|
| 1155 |
+
line={"width": 0},
|
| 1156 |
+
fillcolor=ROAD_ASPHALT,
|
| 1157 |
+
layer="below",
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
for x_edge in (-road_half, road_half):
|
| 1161 |
+
fig.add_shape(
|
| 1162 |
+
type="line",
|
| 1163 |
+
x0=x_edge,
|
| 1164 |
+
y0=y_min,
|
| 1165 |
+
x1=x_edge,
|
| 1166 |
+
y1=y_max,
|
| 1167 |
+
line={"color": LANE_SOLID, "width": 2.5},
|
| 1168 |
+
layer="below",
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
lane_w = (2.0 * road_half) / 4.0
|
| 1172 |
+
for lane_idx in range(1, 4):
|
| 1173 |
+
x_lane = -road_half + lane_idx * lane_w
|
| 1174 |
+
line_color = CENTER_DASH if lane_idx == 2 else LANE_DASH
|
| 1175 |
+
line_width = 2.4 if lane_idx == 2 else 1.8
|
| 1176 |
+
fig.add_shape(
|
| 1177 |
+
type="line",
|
| 1178 |
+
x0=x_lane,
|
| 1179 |
+
y0=y_min,
|
| 1180 |
+
x1=x_lane,
|
| 1181 |
+
y1=y_max,
|
| 1182 |
+
line={"color": line_color, "width": line_width, "dash": "dash"},
|
| 1183 |
+
layer="below",
|
| 1184 |
+
)
|
| 1185 |
+
|
| 1186 |
+
if add_crosswalk:
|
| 1187 |
+
cross_y = float(np.clip(8.0, y_min + 5.5, y_max - 5.5))
|
| 1188 |
+
stripe_h = 0.7
|
| 1189 |
+
stripe_gap = 0.55
|
| 1190 |
+
for i in range(-4, 5):
|
| 1191 |
+
y0 = cross_y + i * (stripe_h + stripe_gap)
|
| 1192 |
+
y1 = y0 + stripe_h
|
| 1193 |
+
fig.add_shape(
|
| 1194 |
+
type="rect",
|
| 1195 |
+
x0=-road_half + 0.7,
|
| 1196 |
+
y0=y0,
|
| 1197 |
+
x1=road_half - 0.7,
|
| 1198 |
+
y1=y1,
|
| 1199 |
+
line={"width": 0},
|
| 1200 |
+
fillcolor="rgba(229, 231, 235, 0.14)",
|
| 1201 |
+
layer="below",
|
| 1202 |
+
)
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
def build_reference_bev_figure(agents, step, show_multimodal, scene_image=None, scene_detections=None):
|
| 1206 |
+
fig = go.Figure()
|
| 1207 |
+
|
| 1208 |
+
x_min, x_max, y_min, y_max = compute_reference_bounds(agents, step, show_multimodal)
|
| 1209 |
+
|
| 1210 |
+
bg = build_pseudo_bev_background(
|
| 1211 |
+
scene_image,
|
| 1212 |
+
x_min,
|
| 1213 |
+
x_max,
|
| 1214 |
+
y_min,
|
| 1215 |
+
y_max,
|
| 1216 |
+
scene_detections=scene_detections,
|
| 1217 |
+
)
|
| 1218 |
+
|
| 1219 |
+
add_structured_road_scene(fig, x_min, x_max, y_min, y_max, add_crosswalk=True)
|
| 1220 |
+
|
| 1221 |
+
if bg is not None:
|
| 1222 |
+
fig.add_layout_image(
|
| 1223 |
+
dict(
|
| 1224 |
+
source=Image.fromarray(bg),
|
| 1225 |
+
xref="x",
|
| 1226 |
+
yref="y",
|
| 1227 |
+
x=x_min,
|
| 1228 |
+
y=y_max,
|
| 1229 |
+
sizex=x_max - x_min,
|
| 1230 |
+
sizey=y_max - y_min,
|
| 1231 |
+
sizing="stretch",
|
| 1232 |
+
opacity=0.38,
|
| 1233 |
+
layer="below",
|
| 1234 |
+
)
|
| 1235 |
+
)
|
| 1236 |
+
|
| 1237 |
+
# Dark wash to keep trajectories readable on real-scene texture.
|
| 1238 |
+
fig.add_shape(
|
| 1239 |
+
type="rect",
|
| 1240 |
+
x0=x_min,
|
| 1241 |
+
y0=y_min,
|
| 1242 |
+
x1=x_max,
|
| 1243 |
+
y1=y_max,
|
| 1244 |
+
line={"width": 0},
|
| 1245 |
+
fillcolor="rgba(4, 8, 18, 0.36)",
|
| 1246 |
+
layer="below",
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
fig.add_shape(
|
| 1250 |
+
type="rect",
|
| 1251 |
+
x0=-1.1,
|
| 1252 |
+
y0=-2.2,
|
| 1253 |
+
x1=1.1,
|
| 1254 |
+
y1=2.2,
|
| 1255 |
+
line={"color": EGO_CYAN, "width": 2.2},
|
| 1256 |
+
fillcolor="rgba(34,211,238,0.20)",
|
| 1257 |
+
)
|
| 1258 |
+
fig.add_annotation(
|
| 1259 |
+
x=0.0,
|
| 1260 |
+
y=4.2,
|
| 1261 |
+
ax=0.0,
|
| 1262 |
+
ay=1.2,
|
| 1263 |
+
showarrow=True,
|
| 1264 |
+
arrowhead=3,
|
| 1265 |
+
arrowwidth=2.8,
|
| 1266 |
+
arrowcolor=EGO_CYAN,
|
| 1267 |
+
text="",
|
| 1268 |
+
)
|
| 1269 |
+
|
| 1270 |
+
fig.add_trace(
|
| 1271 |
+
go.Scatter(
|
| 1272 |
+
x=[None],
|
| 1273 |
+
y=[None],
|
| 1274 |
+
mode="markers",
|
| 1275 |
+
marker={"size": 10, "symbol": "circle", "color": VRU_GREEN},
|
| 1276 |
+
name="Pedestrian",
|
| 1277 |
+
)
|
| 1278 |
+
)
|
| 1279 |
+
fig.add_trace(
|
| 1280 |
+
go.Scatter(
|
| 1281 |
+
x=[None],
|
| 1282 |
+
y=[None],
|
| 1283 |
+
mode="markers",
|
| 1284 |
+
marker={"size": 10, "symbol": "square", "color": VEHICLE_YELLOW},
|
| 1285 |
+
name="Vehicle",
|
| 1286 |
+
)
|
| 1287 |
+
)
|
| 1288 |
+
|
| 1289 |
+
positions, marker_offsets = spread_agent_markers(agents, step)
|
| 1290 |
+
alt_legend_added = False
|
| 1291 |
+
|
| 1292 |
+
for idx, a in enumerate(agents):
|
| 1293 |
+
base_color = agent_color(a)
|
| 1294 |
+
best_idx = best_mode_idx(a)
|
| 1295 |
+
best_prob = float(a["probabilities"][best_idx]) if len(a["probabilities"]) > 0 else 0.0
|
| 1296 |
+
marker_color = hex_to_rgba(base_color, 0.48 + 0.52 * best_prob)
|
| 1297 |
+
|
| 1298 |
+
cx, cy = positions[idx]
|
| 1299 |
+
ox, oy = marker_offsets[idx]
|
| 1300 |
+
curr_x = cx + ox
|
| 1301 |
+
curr_y = cy + oy
|
| 1302 |
+
|
| 1303 |
+
summary_text, _ = summarize_agent_probabilities(a)
|
| 1304 |
+
hover_text = (
|
| 1305 |
+
f"ID {a['id']}<br>Type: {a['type'].title()}"
|
| 1306 |
+
f"<br>{summary_text}<br>Best path confidence: {best_prob * 100:.1f}%"
|
| 1307 |
+
)
|
| 1308 |
+
|
| 1309 |
+
hx, hy = smooth_path(a["history"])
|
| 1310 |
+
fig.add_trace(
|
| 1311 |
+
go.Scatter(
|
| 1312 |
+
x=hx,
|
| 1313 |
+
y=hy,
|
| 1314 |
+
mode="lines",
|
| 1315 |
+
line={"color": "rgba(226,232,240,0.55)", "width": 2.2, "dash": "dot", "shape": "spline", "smoothing": 1.0},
|
| 1316 |
+
hovertemplate=f"ID {a['id']} past trajectory<extra></extra>",
|
| 1317 |
+
name="Past trajectory" if idx == 0 else None,
|
| 1318 |
+
showlegend=(idx == 0),
|
| 1319 |
+
)
|
| 1320 |
+
)
|
| 1321 |
+
|
| 1322 |
+
fig.add_trace(
|
| 1323 |
+
go.Scatter(
|
| 1324 |
+
x=[curr_x],
|
| 1325 |
+
y=[curr_y],
|
| 1326 |
+
mode="markers+text",
|
| 1327 |
+
marker={
|
| 1328 |
+
"size": 11,
|
| 1329 |
+
"symbol": "circle" if a.get("type") == "pedestrian" else "square",
|
| 1330 |
+
"color": marker_color,
|
| 1331 |
+
"line": {"color": "rgba(5,7,15,0.95)", "width": 1.2},
|
| 1332 |
+
},
|
| 1333 |
+
text=[f"ID {a['id']}"],
|
| 1334 |
+
textposition="top center",
|
| 1335 |
+
textfont={"size": 10, "color": WHITE},
|
| 1336 |
+
hovertemplate=f"{hover_text}<extra></extra>",
|
| 1337 |
+
showlegend=False,
|
| 1338 |
+
)
|
| 1339 |
+
)
|
| 1340 |
+
|
| 1341 |
+
px, py = previous_position_for_velocity(a, step)
|
| 1342 |
+
dx, dy = cx - px, cy - py
|
| 1343 |
+
norm = math.hypot(dx, dy)
|
| 1344 |
+
if norm > 1e-3:
|
| 1345 |
+
vx, vy = (dx / norm) * 2.0, (dy / norm) * 2.0
|
| 1346 |
+
fig.add_annotation(
|
| 1347 |
+
x=curr_x + vx,
|
| 1348 |
+
y=curr_y + vy,
|
| 1349 |
+
ax=curr_x,
|
| 1350 |
+
ay=curr_y,
|
| 1351 |
+
showarrow=True,
|
| 1352 |
+
arrowhead=2,
|
| 1353 |
+
arrowsize=1,
|
| 1354 |
+
arrowwidth=2,
|
| 1355 |
+
arrowcolor=base_color,
|
| 1356 |
+
text="",
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
mode_order = [best_idx, 0, 1, 2]
|
| 1360 |
+
mode_order = list(dict.fromkeys(mode_order))
|
| 1361 |
+
|
| 1362 |
+
for rank, m in enumerate(mode_order[:3]):
|
| 1363 |
+
if (not show_multimodal) and rank > 0:
|
| 1364 |
+
continue
|
| 1365 |
+
|
| 1366 |
+
mode_prob = float(a["probabilities"][m]) if m < len(a["probabilities"]) else 0.0
|
| 1367 |
+
mode_color = TRAJ_MODE_COLORS[m % len(TRAJ_MODE_COLORS)]
|
| 1368 |
+
|
| 1369 |
+
mode_path = a["predictions"][m]
|
| 1370 |
+
mode_slice = mode_path[: max(1, min(step, len(mode_path)))]
|
| 1371 |
+
tx, ty = smooth_path([a["history"][-1]] + mode_slice)
|
| 1372 |
+
is_best = m == best_idx
|
| 1373 |
+
|
| 1374 |
+
if is_best:
|
| 1375 |
+
for lw, op in [(14, 0.08), (9, 0.16)]:
|
| 1376 |
+
fig.add_trace(
|
| 1377 |
+
go.Scatter(
|
| 1378 |
+
x=tx,
|
| 1379 |
+
y=ty,
|
| 1380 |
+
mode="lines",
|
| 1381 |
+
line={"color": mode_color, "width": lw, "shape": "spline", "smoothing": 1.15},
|
| 1382 |
+
opacity=op,
|
| 1383 |
+
hoverinfo="skip",
|
| 1384 |
+
showlegend=False,
|
| 1385 |
+
)
|
| 1386 |
+
)
|
| 1387 |
+
|
| 1388 |
+
fig.add_trace(
|
| 1389 |
+
go.Scatter(
|
| 1390 |
+
x=tx,
|
| 1391 |
+
y=ty,
|
| 1392 |
+
mode="lines",
|
| 1393 |
+
line={
|
| 1394 |
+
"color": mode_color,
|
| 1395 |
+
"width": 4.1 if is_best else 2.1,
|
| 1396 |
+
"dash": "solid" if is_best else "dash",
|
| 1397 |
+
"shape": "spline",
|
| 1398 |
+
"smoothing": 1.15,
|
| 1399 |
+
},
|
| 1400 |
+
opacity=(0.72 + 0.26 * mode_prob) if is_best else (0.36 + 0.32 * mode_prob),
|
| 1401 |
+
hovertemplate=(
|
| 1402 |
+
f"ID {a['id']}<br>Mode {m + 1}"
|
| 1403 |
+
f"<br>Probability: {mode_prob * 100:.1f}%<extra></extra>"
|
| 1404 |
+
),
|
| 1405 |
+
name=(
|
| 1406 |
+
"Best path" if (is_best and idx == 0) else
|
| 1407 |
+
"Alternative paths" if ((not is_best) and (not alt_legend_added)) else None
|
| 1408 |
+
),
|
| 1409 |
+
showlegend=(is_best and idx == 0) or ((not is_best) and (not alt_legend_added)),
|
| 1410 |
+
)
|
| 1411 |
+
)
|
| 1412 |
+
|
| 1413 |
+
if (not is_best) and (not alt_legend_added):
|
| 1414 |
+
alt_legend_added = True
|
| 1415 |
+
|
| 1416 |
+
if a.get("is_target", False):
|
| 1417 |
+
fig.add_trace(
|
| 1418 |
+
go.Scatter(
|
| 1419 |
+
x=[curr_x + 0.9],
|
| 1420 |
+
y=[curr_y + 1.1],
|
| 1421 |
+
mode="text",
|
| 1422 |
+
text=[summary_text],
|
| 1423 |
+
textfont={"size": 9, "color": "rgba(226,232,240,0.90)"},
|
| 1424 |
+
hoverinfo="skip",
|
| 1425 |
+
showlegend=False,
|
| 1426 |
+
)
|
| 1427 |
+
)
|
| 1428 |
+
|
| 1429 |
+
fig.update_layout(
|
| 1430 |
+
title={"text": "Main BEV Simulation", "x": 0.02, "font": {"size": 20, "color": WHITE}},
|
| 1431 |
+
paper_bgcolor=BG_SECONDARY,
|
| 1432 |
+
plot_bgcolor=BG_SECONDARY,
|
| 1433 |
+
legend={"orientation": "h", "y": 1.03, "x": 0.0, "font": {"color": WHITE, "size": 11}},
|
| 1434 |
+
margin={"l": 16, "r": 16, "t": 52, "b": 10},
|
| 1435 |
+
height=700,
|
| 1436 |
+
)
|
| 1437 |
+
fig.update_xaxes(
|
| 1438 |
+
title_text="X Lateral (m)",
|
| 1439 |
+
range=[x_min, x_max],
|
| 1440 |
+
color=WHITE,
|
| 1441 |
+
dtick=5,
|
| 1442 |
+
showgrid=True,
|
| 1443 |
+
gridcolor="rgba(148,163,184,0.16)",
|
| 1444 |
+
zeroline=False,
|
| 1445 |
+
)
|
| 1446 |
+
fig.update_yaxes(
|
| 1447 |
+
title_text="Y Forward (m)",
|
| 1448 |
+
range=[y_min, y_max],
|
| 1449 |
+
color=WHITE,
|
| 1450 |
+
dtick=5,
|
| 1451 |
+
showgrid=True,
|
| 1452 |
+
gridcolor="rgba(148,163,184,0.16)",
|
| 1453 |
+
scaleanchor="x",
|
| 1454 |
+
scaleratio=1,
|
| 1455 |
+
zeroline=False,
|
| 1456 |
+
)
|
| 1457 |
+
|
| 1458 |
+
return fig
|
| 1459 |
+
|
| 1460 |
+
|
| 1461 |
+
def best_mode_idx(agent):
|
| 1462 |
+
probs = np.asarray(agent["probabilities"], dtype=float)
|
| 1463 |
+
return int(np.argmax(probs))
|
| 1464 |
+
|
| 1465 |
+
|
| 1466 |
+
def position_at_step(agent, step):
|
| 1467 |
+
if step <= 0:
|
| 1468 |
+
return tuple(agent["history"][-1])
|
| 1469 |
+
|
| 1470 |
+
k = best_mode_idx(agent)
|
| 1471 |
+
pred = agent["predictions"][k]
|
| 1472 |
+
idx = min(step - 1, len(pred) - 1)
|
| 1473 |
+
return tuple(pred[idx])
|
| 1474 |
+
|
| 1475 |
+
|
| 1476 |
+
def previous_position_for_velocity(agent, step):
|
| 1477 |
+
if step <= 1:
|
| 1478 |
+
return tuple(agent["history"][-1])
|
| 1479 |
+
|
| 1480 |
+
k = best_mode_idx(agent)
|
| 1481 |
+
pred = agent["predictions"][k]
|
| 1482 |
+
idx = max(0, min(step - 2, len(pred) - 1))
|
| 1483 |
+
return tuple(pred[idx])
|
| 1484 |
+
|
| 1485 |
+
|
| 1486 |
+
def project_world_to_camera(x, y, width, height, yaw_deg):
|
| 1487 |
+
# Ego frame: x right, y forward.
|
| 1488 |
+
yaw = np.deg2rad(yaw_deg)
|
| 1489 |
+
side = x * np.cos(yaw) + y * np.sin(yaw)
|
| 1490 |
+
depth = y * np.cos(yaw) - x * np.sin(yaw)
|
| 1491 |
+
|
| 1492 |
+
if depth <= 1.2:
|
| 1493 |
+
return None
|
| 1494 |
+
|
| 1495 |
+
focal = width * 0.85
|
| 1496 |
+
u = width * 0.5 + (side / depth) * focal
|
| 1497 |
+
v = height * 0.84 - min(280.0, 460.0 / (depth + 0.6))
|
| 1498 |
+
return float(u), float(v), float(depth)
|
| 1499 |
+
|
| 1500 |
+
|
| 1501 |
+
def build_synth_skeleton_points(u, v, box_w, box_h):
|
| 1502 |
+
head = (u, v - 0.38 * box_h)
|
| 1503 |
+
neck = (u, v - 0.28 * box_h)
|
| 1504 |
+
l_sh = (u - 0.22 * box_w, v - 0.22 * box_h)
|
| 1505 |
+
r_sh = (u + 0.22 * box_w, v - 0.22 * box_h)
|
| 1506 |
+
l_hand = (u - 0.34 * box_w, v - 0.03 * box_h)
|
| 1507 |
+
r_hand = (u + 0.34 * box_w, v - 0.03 * box_h)
|
| 1508 |
+
hip = (u, v - 0.02 * box_h)
|
| 1509 |
+
l_knee = (u - 0.14 * box_w, v + 0.30 * box_h)
|
| 1510 |
+
r_knee = (u + 0.14 * box_w, v + 0.30 * box_h)
|
| 1511 |
+
return [head, neck, l_sh, r_sh, l_hand, r_hand, hip, l_knee, r_knee]
|
| 1512 |
+
|
| 1513 |
+
|
| 1514 |
+
def add_polyline_trace(fig, points, edges, color, point_size=4):
|
| 1515 |
+
xs = []
|
| 1516 |
+
ys = []
|
| 1517 |
+
for a, b in edges:
|
| 1518 |
+
if a >= len(points) or b >= len(points):
|
| 1519 |
+
continue
|
| 1520 |
+
xs.extend([points[a][0], points[b][0], None])
|
| 1521 |
+
ys.extend([points[a][1], points[b][1], None])
|
| 1522 |
+
|
| 1523 |
+
fig.add_trace(
|
| 1524 |
+
go.Scatter(
|
| 1525 |
+
x=xs,
|
| 1526 |
+
y=ys,
|
| 1527 |
+
mode="lines",
|
| 1528 |
+
line={"color": color, "width": 2},
|
| 1529 |
+
hoverinfo="skip",
|
| 1530 |
+
showlegend=False,
|
| 1531 |
+
)
|
| 1532 |
+
)
|
| 1533 |
+
|
| 1534 |
+
fig.add_trace(
|
| 1535 |
+
go.Scatter(
|
| 1536 |
+
x=[p[0] for p in points],
|
| 1537 |
+
y=[p[1] for p in points],
|
| 1538 |
+
mode="markers",
|
| 1539 |
+
marker={"size": point_size, "color": "#e2e8f0"},
|
| 1540 |
+
hoverinfo="skip",
|
| 1541 |
+
showlegend=False,
|
| 1542 |
+
)
|
| 1543 |
+
)
|
| 1544 |
+
|
| 1545 |
+
|
| 1546 |
+
def add_coco_pose_trace(fig, keypoints, color, conf_thresh=0.2):
|
| 1547 |
+
if keypoints is None:
|
| 1548 |
+
return
|
| 1549 |
+
if len(keypoints) < 17:
|
| 1550 |
+
return
|
| 1551 |
+
|
| 1552 |
+
xs = []
|
| 1553 |
+
ys = []
|
| 1554 |
+
for a, b in COCO_SKELETON_EDGES:
|
| 1555 |
+
if keypoints[a][2] < conf_thresh or keypoints[b][2] < conf_thresh:
|
| 1556 |
+
continue
|
| 1557 |
+
xs.extend([keypoints[a][0], keypoints[b][0], None])
|
| 1558 |
+
ys.extend([keypoints[a][1], keypoints[b][1], None])
|
| 1559 |
+
|
| 1560 |
+
if len(xs) > 0:
|
| 1561 |
+
fig.add_trace(
|
| 1562 |
+
go.Scatter(
|
| 1563 |
+
x=xs,
|
| 1564 |
+
y=ys,
|
| 1565 |
+
mode="lines",
|
| 1566 |
+
line={"color": color, "width": 2},
|
| 1567 |
+
hoverinfo="skip",
|
| 1568 |
+
showlegend=False,
|
| 1569 |
+
)
|
| 1570 |
+
)
|
| 1571 |
+
|
| 1572 |
+
pts = [kp for kp in keypoints if kp[2] >= conf_thresh]
|
| 1573 |
+
if len(pts) > 0:
|
| 1574 |
+
fig.add_trace(
|
| 1575 |
+
go.Scatter(
|
| 1576 |
+
x=[p[0] for p in pts],
|
| 1577 |
+
y=[p[1] for p in pts],
|
| 1578 |
+
mode="markers",
|
| 1579 |
+
marker={"size": 4, "color": "#e2e8f0"},
|
| 1580 |
+
hoverinfo="skip",
|
| 1581 |
+
showlegend=False,
|
| 1582 |
+
)
|
| 1583 |
+
)
|
| 1584 |
+
|
| 1585 |
+
|
| 1586 |
+
def create_camera_figure_projected(image_arr, agents, camera_label, yaw_deg, step):
|
| 1587 |
+
h, w = image_arr.shape[0], image_arr.shape[1]
|
| 1588 |
+
|
| 1589 |
+
fig = go.Figure()
|
| 1590 |
+
fig.add_trace(go.Image(z=image_arr))
|
| 1591 |
+
|
| 1592 |
+
for agent in agents:
|
| 1593 |
+
x, y = position_at_step(agent, step)
|
| 1594 |
+
projection = project_world_to_camera(x, y, w, h, yaw_deg)
|
| 1595 |
+
if projection is None:
|
| 1596 |
+
continue
|
| 1597 |
+
|
| 1598 |
+
u, v, depth = projection
|
| 1599 |
+
if u < -40 or u > w + 40 or v < -40 or v > h + 40:
|
| 1600 |
+
continue
|
| 1601 |
+
|
| 1602 |
+
is_ped = agent["type"] == "pedestrian"
|
| 1603 |
+
color = agent_color(agent)
|
| 1604 |
+
|
| 1605 |
+
box_h = max(22.0, min(180.0, 260.0 / (depth + 0.5)))
|
| 1606 |
+
box_w = box_h * (0.42 if is_ped else 0.90)
|
| 1607 |
+
x1, y1 = u - box_w / 2, v - box_h
|
| 1608 |
+
x2, y2 = u + box_w / 2, v
|
| 1609 |
+
|
| 1610 |
+
fig.add_shape(
|
| 1611 |
+
type="rect",
|
| 1612 |
+
x0=x1,
|
| 1613 |
+
y0=y1,
|
| 1614 |
+
x1=x2,
|
| 1615 |
+
y1=y2,
|
| 1616 |
+
line={"color": color, "width": 2},
|
| 1617 |
+
fillcolor="rgba(0,0,0,0)",
|
| 1618 |
+
)
|
| 1619 |
+
|
| 1620 |
+
fig.add_trace(
|
| 1621 |
+
go.Scatter(
|
| 1622 |
+
x=[x1],
|
| 1623 |
+
y=[max(4, y1 - 12)],
|
| 1624 |
+
mode="text",
|
| 1625 |
+
text=[f"ID {agent['id']}"],
|
| 1626 |
+
textfont={"size": 11, "color": color},
|
| 1627 |
+
hoverinfo="skip",
|
| 1628 |
+
showlegend=False,
|
| 1629 |
+
)
|
| 1630 |
+
)
|
| 1631 |
+
|
| 1632 |
+
if is_ped:
|
| 1633 |
+
kps = build_synth_skeleton_points(u, v, box_w, box_h)
|
| 1634 |
+
add_polyline_trace(fig, kps, SYNTH_SKELETON_EDGES, color, point_size=4)
|
| 1635 |
+
|
| 1636 |
+
fig.update_xaxes(visible=False, range=[0, w])
|
| 1637 |
+
fig.update_yaxes(visible=False, range=[h, 0], scaleanchor="x", scaleratio=1)
|
| 1638 |
+
fig.update_layout(
|
| 1639 |
+
title={"text": camera_label, "x": 0.02, "font": {"color": WHITE, "size": 15}},
|
| 1640 |
+
paper_bgcolor=BG_SECONDARY,
|
| 1641 |
+
plot_bgcolor=BG_SECONDARY,
|
| 1642 |
+
margin={"l": 0, "r": 0, "t": 36, "b": 0},
|
| 1643 |
+
height=300,
|
| 1644 |
+
)
|
| 1645 |
+
return fig
|
| 1646 |
+
|
| 1647 |
+
|
| 1648 |
+
def create_camera_figure_detections(image_arr, detections, camera_label, target_track_id=None, highlight_track_ids=None):
|
| 1649 |
+
h, w = image_arr.shape[0], image_arr.shape[1]
|
| 1650 |
+
|
| 1651 |
+
fig = go.Figure()
|
| 1652 |
+
fig.add_trace(go.Image(z=image_arr))
|
| 1653 |
+
|
| 1654 |
+
for i, det in enumerate(detections):
|
| 1655 |
+
x1, y1, x2, y2 = det["box"]
|
| 1656 |
+
kind = det.get("kind", "vehicle")
|
| 1657 |
+
track_id = det.get("track_id")
|
| 1658 |
+
|
| 1659 |
+
if highlight_track_ids is not None and track_id is not None and track_id in highlight_track_ids:
|
| 1660 |
+
color = TARGET_PURPLE
|
| 1661 |
+
elif track_id is not None and track_id == target_track_id:
|
| 1662 |
+
color = TARGET_PURPLE
|
| 1663 |
+
elif kind == "pedestrian":
|
| 1664 |
+
color = VRU_GREEN
|
| 1665 |
+
else:
|
| 1666 |
+
color = VEHICLE_YELLOW
|
| 1667 |
+
|
| 1668 |
+
fig.add_shape(
|
| 1669 |
+
type="rect",
|
| 1670 |
+
x0=x1,
|
| 1671 |
+
y0=y1,
|
| 1672 |
+
x1=x2,
|
| 1673 |
+
y1=y2,
|
| 1674 |
+
line={"color": color, "width": 2},
|
| 1675 |
+
fillcolor="rgba(0,0,0,0)",
|
| 1676 |
+
)
|
| 1677 |
+
|
| 1678 |
+
display_id = track_id if track_id is not None else f"D{det.get('det_id', i + 1)}"
|
| 1679 |
+
fig.add_trace(
|
| 1680 |
+
go.Scatter(
|
| 1681 |
+
x=[x1],
|
| 1682 |
+
y=[max(4.0, y1 - 12.0)],
|
| 1683 |
+
mode="text",
|
| 1684 |
+
text=[f"ID {display_id}"],
|
| 1685 |
+
textfont={"size": 11, "color": color},
|
| 1686 |
+
hoverinfo="skip",
|
| 1687 |
+
showlegend=False,
|
| 1688 |
+
)
|
| 1689 |
+
)
|
| 1690 |
+
|
| 1691 |
+
if kind == "pedestrian":
|
| 1692 |
+
add_coco_pose_trace(fig, det.get("keypoints"), color)
|
| 1693 |
+
|
| 1694 |
+
fig.update_xaxes(visible=False, range=[0, w])
|
| 1695 |
+
fig.update_yaxes(visible=False, range=[h, 0], scaleanchor="x", scaleratio=1)
|
| 1696 |
+
fig.update_layout(
|
| 1697 |
+
title={"text": camera_label, "x": 0.02, "font": {"color": WHITE, "size": 15}},
|
| 1698 |
+
paper_bgcolor=BG_SECONDARY,
|
| 1699 |
+
plot_bgcolor=BG_SECONDARY,
|
| 1700 |
+
margin={"l": 0, "r": 0, "t": 36, "b": 0},
|
| 1701 |
+
height=300,
|
| 1702 |
+
)
|
| 1703 |
+
return fig
|
| 1704 |
+
|
| 1705 |
+
|
| 1706 |
+
def smooth_path(points):
|
| 1707 |
+
return [p[0] for p in points], [p[1] for p in points]
|
| 1708 |
+
|
| 1709 |
+
|
| 1710 |
+
def simulate_lidar_points(agents, step):
|
| 1711 |
+
rng = np.random.default_rng(1234 + step)
|
| 1712 |
+
|
| 1713 |
+
bg = np.column_stack(
|
| 1714 |
+
[
|
| 1715 |
+
rng.uniform(-35, 35, 1500),
|
| 1716 |
+
rng.uniform(-8, 55, 1500),
|
| 1717 |
+
]
|
| 1718 |
+
)
|
| 1719 |
+
|
| 1720 |
+
clusters = []
|
| 1721 |
+
for a in agents:
|
| 1722 |
+
cx, cy = position_at_step(a, step)
|
| 1723 |
+
n = 110 if a["type"] == "vehicle" else 70
|
| 1724 |
+
spread = np.array([0.8, 0.8]) if a["type"] == "pedestrian" else np.array([1.3, 1.1])
|
| 1725 |
+
pts = rng.normal([cx, cy], spread, size=(n, 2))
|
| 1726 |
+
clusters.append(pts)
|
| 1727 |
+
|
| 1728 |
+
if clusters:
|
| 1729 |
+
all_pts = np.vstack([bg] + clusters)
|
| 1730 |
+
else:
|
| 1731 |
+
all_pts = bg
|
| 1732 |
+
|
| 1733 |
+
mask = (
|
| 1734 |
+
(all_pts[:, 0] > -38)
|
| 1735 |
+
& (all_pts[:, 0] < 38)
|
| 1736 |
+
& (all_pts[:, 1] > -12)
|
| 1737 |
+
& (all_pts[:, 1] < 58)
|
| 1738 |
+
)
|
| 1739 |
+
return all_pts[mask]
|
| 1740 |
+
|
| 1741 |
+
|
| 1742 |
+
def simulate_radar_vectors(agents, step):
|
| 1743 |
+
vectors = []
|
| 1744 |
+
for a in agents:
|
| 1745 |
+
p_now = np.array(position_at_step(a, step), dtype=float)
|
| 1746 |
+
p_prev = np.array(previous_position_for_velocity(a, step), dtype=float)
|
| 1747 |
+
v = p_now - p_prev
|
| 1748 |
+
|
| 1749 |
+
if np.linalg.norm(v) < 0.04:
|
| 1750 |
+
continue
|
| 1751 |
+
|
| 1752 |
+
v = v / max(1e-6, np.linalg.norm(v)) * 1.6
|
| 1753 |
+
vectors.append((p_now[0], p_now[1], v[0], v[1], a["type"]))
|
| 1754 |
+
return vectors
|
| 1755 |
+
|
| 1756 |
+
|
| 1757 |
+
def classify_direction(history, prediction):
|
| 1758 |
+
h_prev = np.array(history[-2], dtype=float)
|
| 1759 |
+
h_curr = np.array(history[-1], dtype=float)
|
| 1760 |
+
p_end = np.array(prediction[-1], dtype=float)
|
| 1761 |
+
|
| 1762 |
+
heading = h_curr - h_prev
|
| 1763 |
+
motion = p_end - h_curr
|
| 1764 |
+
|
| 1765 |
+
if np.linalg.norm(motion) < 0.7:
|
| 1766 |
+
return "Stop"
|
| 1767 |
+
|
| 1768 |
+
if np.linalg.norm(heading) < 1e-6:
|
| 1769 |
+
heading = np.array([0.0, 1.0])
|
| 1770 |
+
|
| 1771 |
+
heading = heading / np.linalg.norm(heading)
|
| 1772 |
+
motion = motion / np.linalg.norm(motion)
|
| 1773 |
+
|
| 1774 |
+
cross = heading[0] * motion[1] - heading[1] * motion[0]
|
| 1775 |
+
dot = np.clip(np.dot(heading, motion), -1.0, 1.0)
|
| 1776 |
+
angle = np.degrees(np.arctan2(cross, dot))
|
| 1777 |
+
|
| 1778 |
+
if abs(angle) <= 25:
|
| 1779 |
+
return "Straight"
|
| 1780 |
+
if angle > 25:
|
| 1781 |
+
return "Left"
|
| 1782 |
+
if angle < -25:
|
| 1783 |
+
return "Right"
|
| 1784 |
+
return "Stop"
|
| 1785 |
+
|
| 1786 |
+
|
| 1787 |
+
def build_analytics_table(agents):
|
| 1788 |
+
rows = []
|
| 1789 |
+
direction_order = ["Straight", "Left", "Right", "Stop"]
|
| 1790 |
+
|
| 1791 |
+
for a in agents:
|
| 1792 |
+
bins = {k: 0.0 for k in direction_order}
|
| 1793 |
+
|
| 1794 |
+
for mode_idx, mode_path in enumerate(a["predictions"]):
|
| 1795 |
+
lbl = classify_direction(a["history"], mode_path)
|
| 1796 |
+
bins[lbl] += float(a["probabilities"][mode_idx])
|
| 1797 |
+
|
| 1798 |
+
ranked = sorted(bins.items(), key=lambda kv: kv[1], reverse=True)
|
| 1799 |
+
top3 = ranked[:3]
|
| 1800 |
+
|
| 1801 |
+
rows.append(
|
| 1802 |
+
{
|
| 1803 |
+
"Agent": f"ID {a['id']}",
|
| 1804 |
+
"Type": "Target VRU" if a.get("is_target", False) else a["type"].title(),
|
| 1805 |
+
"Top-1": f"{top3[0][0]} ({top3[0][1] * 100:.1f}%)",
|
| 1806 |
+
"Top-2": f"{top3[1][0]} ({top3[1][1] * 100:.1f}%)",
|
| 1807 |
+
"Top-3": f"{top3[2][0]} ({top3[2][1] * 100:.1f}%)",
|
| 1808 |
+
}
|
| 1809 |
+
)
|
| 1810 |
+
|
| 1811 |
+
return pd.DataFrame(rows)
|
| 1812 |
+
|
| 1813 |
+
|
| 1814 |
+
def generate_demo_agents(num_agents=8, history_steps=4, future_steps=12):
|
| 1815 |
+
rng = np.random.default_rng(42)
|
| 1816 |
+
agents = []
|
| 1817 |
+
|
| 1818 |
+
ped_count = max(5, int(0.7 * num_agents))
|
| 1819 |
+
|
| 1820 |
+
for i in range(num_agents):
|
| 1821 |
+
is_ped = i < ped_count
|
| 1822 |
+
a_type = "pedestrian" if is_ped else "vehicle"
|
| 1823 |
+
|
| 1824 |
+
base_x = rng.uniform(-16, 16)
|
| 1825 |
+
base_y = rng.uniform(9, 45)
|
| 1826 |
+
|
| 1827 |
+
if is_ped:
|
| 1828 |
+
vx = rng.uniform(-0.45, 0.45)
|
| 1829 |
+
vy = rng.uniform(0.15, 0.95)
|
| 1830 |
+
else:
|
| 1831 |
+
vx = rng.uniform(-0.20, 0.20)
|
| 1832 |
+
vy = rng.uniform(0.7, 1.6)
|
| 1833 |
+
|
| 1834 |
+
history = []
|
| 1835 |
+
for t in range(history_steps):
|
| 1836 |
+
phase = t - (history_steps - 1)
|
| 1837 |
+
x = base_x + phase * vx + 0.06 * np.sin(0.8 * t + i)
|
| 1838 |
+
y = base_y + phase * vy + 0.05 * np.cos(0.5 * t + i)
|
| 1839 |
+
history.append((float(x), float(y)))
|
| 1840 |
+
|
| 1841 |
+
probs = normalize_probs(rng.uniform(0.15, 1.0, size=3))
|
| 1842 |
+
|
| 1843 |
+
predictions = []
|
| 1844 |
+
x0, y0 = history[-1]
|
| 1845 |
+
for mode in range(3):
|
| 1846 |
+
mode_path = []
|
| 1847 |
+
curve = (-0.12 + 0.12 * mode) * (1.4 if is_ped else 0.8)
|
| 1848 |
+
accel = 0.02 * (mode - 1)
|
| 1849 |
+
for s in range(1, future_steps + 1):
|
| 1850 |
+
x = x0 + vx * s + curve * (s ** 1.25)
|
| 1851 |
+
y = y0 + vy * s + accel * (s ** 1.12)
|
| 1852 |
+
mode_path.append((float(x), float(y)))
|
| 1853 |
+
predictions.append(mode_path)
|
| 1854 |
+
|
| 1855 |
+
agents.append(
|
| 1856 |
+
{
|
| 1857 |
+
"id": i + 1,
|
| 1858 |
+
"type": a_type,
|
| 1859 |
+
"history": history,
|
| 1860 |
+
"predictions": predictions,
|
| 1861 |
+
"probabilities": probs,
|
| 1862 |
+
"is_target": (i == 0 and is_ped),
|
| 1863 |
+
}
|
| 1864 |
+
)
|
| 1865 |
+
|
| 1866 |
+
return agents
|
| 1867 |
+
|
| 1868 |
+
|
| 1869 |
+
def sanitize_agents(raw_agents):
|
| 1870 |
+
cleaned = []
|
| 1871 |
+
for i, a in enumerate(raw_agents):
|
| 1872 |
+
aid = int(a.get("id", i + 1))
|
| 1873 |
+
a_type = str(a.get("type", "pedestrian")).lower()
|
| 1874 |
+
if a_type not in ["pedestrian", "vehicle"]:
|
| 1875 |
+
a_type = "pedestrian"
|
| 1876 |
+
|
| 1877 |
+
history = [tuple(map(float, p)) for p in a.get("history", [])]
|
| 1878 |
+
predictions = []
|
| 1879 |
+
for mode in a.get("predictions", []):
|
| 1880 |
+
predictions.append([tuple(map(float, p)) for p in mode])
|
| 1881 |
+
|
| 1882 |
+
probs = normalize_probs(a.get("probabilities", [0.6, 0.25, 0.15]))
|
| 1883 |
+
|
| 1884 |
+
if len(history) < 2 or len(predictions) < 3:
|
| 1885 |
+
continue
|
| 1886 |
+
|
| 1887 |
+
cleaned.append(
|
| 1888 |
+
{
|
| 1889 |
+
"id": aid,
|
| 1890 |
+
"type": a_type,
|
| 1891 |
+
"history": history,
|
| 1892 |
+
"predictions": predictions[:3],
|
| 1893 |
+
"probabilities": probs[:3],
|
| 1894 |
+
"is_target": bool(a.get("is_target", False)),
|
| 1895 |
+
}
|
| 1896 |
+
)
|
| 1897 |
+
|
| 1898 |
+
if not any(a.get("is_target", False) for a in cleaned):
|
| 1899 |
+
for a in cleaned:
|
| 1900 |
+
if a["type"] == "pedestrian":
|
| 1901 |
+
a["is_target"] = True
|
| 1902 |
+
break
|
| 1903 |
+
|
| 1904 |
+
return cleaned
|
| 1905 |
+
|
| 1906 |
+
|
| 1907 |
+
def build_bev_figure(
|
| 1908 |
+
agents,
|
| 1909 |
+
step,
|
| 1910 |
+
show_lidar,
|
| 1911 |
+
show_radar,
|
| 1912 |
+
show_multimodal,
|
| 1913 |
+
lidar_xy=None,
|
| 1914 |
+
radar_xy=None,
|
| 1915 |
+
radar_vel=None,
|
| 1916 |
+
):
|
| 1917 |
+
fig = go.Figure()
|
| 1918 |
+
|
| 1919 |
+
x_min, x_max = -36.0, 36.0
|
| 1920 |
+
y_min, y_max = -12.0, 58.0
|
| 1921 |
+
|
| 1922 |
+
add_structured_road_scene(fig, x_min, x_max, y_min, y_max, add_crosswalk=True)
|
| 1923 |
+
|
| 1924 |
+
fig.add_shape(
|
| 1925 |
+
type="rect",
|
| 1926 |
+
x0=-1.1,
|
| 1927 |
+
y0=-2.2,
|
| 1928 |
+
x1=1.1,
|
| 1929 |
+
y1=2.2,
|
| 1930 |
+
line={"color": EGO_CYAN, "width": 2.2},
|
| 1931 |
+
fillcolor="rgba(34,211,238,0.20)",
|
| 1932 |
+
)
|
| 1933 |
+
fig.add_annotation(
|
| 1934 |
+
x=0.0,
|
| 1935 |
+
y=4.2,
|
| 1936 |
+
ax=0.0,
|
| 1937 |
+
ay=1.2,
|
| 1938 |
+
arrowcolor=EGO_CYAN,
|
| 1939 |
+
arrowwidth=2.8,
|
| 1940 |
+
arrowhead=3,
|
| 1941 |
+
showarrow=True,
|
| 1942 |
+
text="",
|
| 1943 |
+
)
|
| 1944 |
+
|
| 1945 |
+
fig.add_trace(
|
| 1946 |
+
go.Scatter(
|
| 1947 |
+
x=[None],
|
| 1948 |
+
y=[None],
|
| 1949 |
+
mode="markers",
|
| 1950 |
+
marker={"size": 10, "symbol": "circle", "color": VRU_GREEN},
|
| 1951 |
+
name="Pedestrian",
|
| 1952 |
+
)
|
| 1953 |
+
)
|
| 1954 |
+
fig.add_trace(
|
| 1955 |
+
go.Scatter(
|
| 1956 |
+
x=[None],
|
| 1957 |
+
y=[None],
|
| 1958 |
+
mode="markers",
|
| 1959 |
+
marker={"size": 10, "symbol": "square", "color": VEHICLE_YELLOW},
|
| 1960 |
+
name="Vehicle",
|
| 1961 |
+
)
|
| 1962 |
+
)
|
| 1963 |
+
|
| 1964 |
+
if show_lidar:
|
| 1965 |
+
if lidar_xy is not None and len(lidar_xy) > 0:
|
| 1966 |
+
lidar = np.asarray(lidar_xy, dtype=float)
|
| 1967 |
+
mask = (
|
| 1968 |
+
(lidar[:, 0] > -38)
|
| 1969 |
+
& (lidar[:, 0] < 38)
|
| 1970 |
+
& (lidar[:, 1] > -12)
|
| 1971 |
+
& (lidar[:, 1] < 58)
|
| 1972 |
+
)
|
| 1973 |
+
lidar = lidar[mask]
|
| 1974 |
+
else:
|
| 1975 |
+
lidar = simulate_lidar_points(agents, step)
|
| 1976 |
+
|
| 1977 |
+
if len(lidar) > 0:
|
| 1978 |
+
lidar = lidar[::6]
|
| 1979 |
+
fig.add_trace(
|
| 1980 |
+
go.Scatter(
|
| 1981 |
+
x=lidar[:, 0],
|
| 1982 |
+
y=lidar[:, 1],
|
| 1983 |
+
mode="markers",
|
| 1984 |
+
marker={"size": 3, "color": "rgba(34,211,238,0.22)"},
|
| 1985 |
+
name="LiDAR",
|
| 1986 |
+
)
|
| 1987 |
+
)
|
| 1988 |
+
|
| 1989 |
+
if show_radar:
|
| 1990 |
+
rx = []
|
| 1991 |
+
ry = []
|
| 1992 |
+
|
| 1993 |
+
if (
|
| 1994 |
+
radar_xy is not None
|
| 1995 |
+
and radar_vel is not None
|
| 1996 |
+
and len(radar_xy) > 0
|
| 1997 |
+
and len(radar_xy) == len(radar_vel)
|
| 1998 |
+
):
|
| 1999 |
+
radar_xy = np.asarray(radar_xy, dtype=float)
|
| 2000 |
+
radar_vel = np.asarray(radar_vel, dtype=float)
|
| 2001 |
+
stride = max(1, len(radar_xy) // 90)
|
| 2002 |
+
|
| 2003 |
+
for i in range(0, len(radar_xy), stride):
|
| 2004 |
+
x0, y0 = radar_xy[i, 0], radar_xy[i, 1]
|
| 2005 |
+
vx, vy = radar_vel[i, 0], radar_vel[i, 1]
|
| 2006 |
+
rx.extend([x0, x0 + 0.55 * vx, None])
|
| 2007 |
+
ry.extend([y0, y0 + 0.55 * vy, None])
|
| 2008 |
+
else:
|
| 2009 |
+
radar_vectors = simulate_radar_vectors(agents, step)
|
| 2010 |
+
for x0, y0, vx, vy, _ in radar_vectors:
|
| 2011 |
+
rx.extend([x0, x0 + vx, None])
|
| 2012 |
+
ry.extend([y0, y0 + vy, None])
|
| 2013 |
+
|
| 2014 |
+
if len(rx) > 0:
|
| 2015 |
+
fig.add_trace(
|
| 2016 |
+
go.Scatter(
|
| 2017 |
+
x=rx,
|
| 2018 |
+
y=ry,
|
| 2019 |
+
mode="lines",
|
| 2020 |
+
line={"color": "rgba(250,204,21,0.75)", "width": 2},
|
| 2021 |
+
name="Radar velocity",
|
| 2022 |
+
)
|
| 2023 |
+
)
|
| 2024 |
+
|
| 2025 |
+
alt_legend_added = False
|
| 2026 |
+
|
| 2027 |
+
for idx, a in enumerate(agents):
|
| 2028 |
+
base_color = agent_color(a)
|
| 2029 |
+
best_idx = best_mode_idx(a)
|
| 2030 |
+
best_prob = float(a["probabilities"][best_idx]) if len(a["probabilities"]) > 0 else 0.0
|
| 2031 |
+
marker_color = hex_to_rgba(base_color, 0.48 + 0.52 * best_prob)
|
| 2032 |
+
summary_text, _ = summarize_agent_probabilities(a)
|
| 2033 |
+
|
| 2034 |
+
hx, hy = smooth_path(a["history"])
|
| 2035 |
+
fig.add_trace(
|
| 2036 |
+
go.Scatter(
|
| 2037 |
+
x=hx,
|
| 2038 |
+
y=hy,
|
| 2039 |
+
mode="lines",
|
| 2040 |
+
line={"color": "rgba(226,232,240,0.55)", "width": 2.2, "dash": "dot", "shape": "spline", "smoothing": 1.0},
|
| 2041 |
+
name="Past trajectory" if idx == 0 else None,
|
| 2042 |
+
showlegend=(idx == 0),
|
| 2043 |
+
hovertemplate=f"ID {a['id']} past trajectory<extra></extra>",
|
| 2044 |
+
)
|
| 2045 |
+
)
|
| 2046 |
+
|
| 2047 |
+
cx, cy = position_at_step(a, step)
|
| 2048 |
+
fig.add_trace(
|
| 2049 |
+
go.Scatter(
|
| 2050 |
+
x=[cx],
|
| 2051 |
+
y=[cy],
|
| 2052 |
+
mode="markers+text",
|
| 2053 |
+
marker={
|
| 2054 |
+
"size": 11,
|
| 2055 |
+
"symbol": "circle" if a.get("type") == "pedestrian" else "square",
|
| 2056 |
+
"color": marker_color,
|
| 2057 |
+
"line": {"color": "#111827", "width": 1.2},
|
| 2058 |
+
},
|
| 2059 |
+
text=[f"ID {a['id']}"],
|
| 2060 |
+
textposition="top center",
|
| 2061 |
+
textfont={"size": 10, "color": WHITE},
|
| 2062 |
+
hovertemplate=(
|
| 2063 |
+
f"ID {a['id']}<br>Type: {a['type'].title()}"
|
| 2064 |
+
f"<br>{summary_text}<br>Best path confidence: {best_prob * 100:.1f}%<extra></extra>"
|
| 2065 |
+
),
|
| 2066 |
+
showlegend=False,
|
| 2067 |
+
)
|
| 2068 |
+
)
|
| 2069 |
+
|
| 2070 |
+
px, py = previous_position_for_velocity(a, step)
|
| 2071 |
+
dx, dy = cx - px, cy - py
|
| 2072 |
+
norm = np.hypot(dx, dy)
|
| 2073 |
+
if norm > 1e-3:
|
| 2074 |
+
sx, sy = (dx / norm) * 1.8, (dy / norm) * 1.8
|
| 2075 |
+
fig.add_annotation(x=cx + sx, y=cy + sy, ax=cx, ay=cy, showarrow=True, arrowhead=2, arrowsize=1, arrowwidth=2, arrowcolor=base_color, text="")
|
| 2076 |
+
|
| 2077 |
+
mode_order = [best_idx, 0, 1, 2]
|
| 2078 |
+
mode_order = list(dict.fromkeys(mode_order))
|
| 2079 |
+
|
| 2080 |
+
for rank, m in enumerate(mode_order[:3]):
|
| 2081 |
+
if (not show_multimodal) and (rank > 0):
|
| 2082 |
+
continue
|
| 2083 |
+
|
| 2084 |
+
mode_prob = float(a["probabilities"][m]) if m < len(a["probabilities"]) else 0.0
|
| 2085 |
+
mode_color = TRAJ_MODE_COLORS[m % len(TRAJ_MODE_COLORS)]
|
| 2086 |
+
|
| 2087 |
+
mode_path = a["predictions"][m]
|
| 2088 |
+
end_idx = max(1, min(step, len(mode_path)))
|
| 2089 |
+
mode_slice = mode_path[:end_idx]
|
| 2090 |
+
mx, my = smooth_path([(cx, cy)] + mode_slice)
|
| 2091 |
+
|
| 2092 |
+
is_best = m == best_idx
|
| 2093 |
+
|
| 2094 |
+
if is_best:
|
| 2095 |
+
for lw, op in [(14, 0.08), (9, 0.16)]:
|
| 2096 |
+
fig.add_trace(
|
| 2097 |
+
go.Scatter(
|
| 2098 |
+
x=mx,
|
| 2099 |
+
y=my,
|
| 2100 |
+
mode="lines",
|
| 2101 |
+
line={"color": mode_color, "width": lw, "shape": "spline", "smoothing": 1.15},
|
| 2102 |
+
opacity=op,
|
| 2103 |
+
hoverinfo="skip",
|
| 2104 |
+
showlegend=False,
|
| 2105 |
+
)
|
| 2106 |
+
)
|
| 2107 |
+
|
| 2108 |
+
fig.add_trace(
|
| 2109 |
+
go.Scatter(
|
| 2110 |
+
x=mx,
|
| 2111 |
+
y=my,
|
| 2112 |
+
mode="lines",
|
| 2113 |
+
line={
|
| 2114 |
+
"color": mode_color,
|
| 2115 |
+
"width": 4.1 if is_best else 2.1,
|
| 2116 |
+
"dash": "solid" if is_best else "dash",
|
| 2117 |
+
"shape": "spline",
|
| 2118 |
+
"smoothing": 1.15,
|
| 2119 |
+
},
|
| 2120 |
+
opacity=(0.72 + 0.26 * mode_prob) if is_best else (0.36 + 0.32 * mode_prob),
|
| 2121 |
+
hovertemplate=(
|
| 2122 |
+
f"ID {a['id']}<br>Mode {m + 1}"
|
| 2123 |
+
f"<br>Probability: {mode_prob * 100:.1f}%<extra></extra>"
|
| 2124 |
+
),
|
| 2125 |
+
name=(
|
| 2126 |
+
"Best path" if (is_best and idx == 0) else
|
| 2127 |
+
"Alternative paths" if ((not is_best) and (not alt_legend_added)) else None
|
| 2128 |
+
),
|
| 2129 |
+
showlegend=(is_best and idx == 0) or ((not is_best) and (not alt_legend_added)),
|
| 2130 |
+
)
|
| 2131 |
+
)
|
| 2132 |
+
|
| 2133 |
+
if (not is_best) and (not alt_legend_added):
|
| 2134 |
+
alt_legend_added = True
|
| 2135 |
+
|
| 2136 |
+
if a.get("is_target", False):
|
| 2137 |
+
fig.add_trace(
|
| 2138 |
+
go.Scatter(
|
| 2139 |
+
x=[cx + 0.9],
|
| 2140 |
+
y=[cy + 1.1],
|
| 2141 |
+
mode="text",
|
| 2142 |
+
text=[summary_text],
|
| 2143 |
+
textfont={"size": 9, "color": "rgba(226,232,240,0.90)"},
|
| 2144 |
+
hoverinfo="skip",
|
| 2145 |
+
showlegend=False,
|
| 2146 |
+
)
|
| 2147 |
+
)
|
| 2148 |
+
|
| 2149 |
+
fig.update_layout(
|
| 2150 |
+
title={"text": "Main BEV Simulation", "x": 0.02, "font": {"size": 20, "color": WHITE}},
|
| 2151 |
+
paper_bgcolor=BG_SECONDARY,
|
| 2152 |
+
plot_bgcolor=BG_SECONDARY,
|
| 2153 |
+
legend={"orientation": "h", "y": 1.03, "x": 0.0, "font": {"color": WHITE, "size": 11}},
|
| 2154 |
+
margin={"l": 16, "r": 16, "t": 52, "b": 10},
|
| 2155 |
+
height=700,
|
| 2156 |
+
)
|
| 2157 |
+
|
| 2158 |
+
fig.update_xaxes(
|
| 2159 |
+
title_text="X Lateral (m)",
|
| 2160 |
+
range=[x_min, x_max],
|
| 2161 |
+
color=WHITE,
|
| 2162 |
+
dtick=5,
|
| 2163 |
+
showgrid=True,
|
| 2164 |
+
gridcolor="rgba(148,163,184,0.16)",
|
| 2165 |
+
zeroline=False,
|
| 2166 |
+
)
|
| 2167 |
+
fig.update_yaxes(
|
| 2168 |
+
title_text="Y Forward (m)",
|
| 2169 |
+
range=[y_min, y_max],
|
| 2170 |
+
color=WHITE,
|
| 2171 |
+
dtick=5,
|
| 2172 |
+
showgrid=True,
|
| 2173 |
+
gridcolor="rgba(148,163,184,0.16)",
|
| 2174 |
+
scaleanchor="x",
|
| 2175 |
+
scaleratio=1,
|
| 2176 |
+
zeroline=False,
|
| 2177 |
+
)
|
| 2178 |
+
|
| 2179 |
+
return fig
|
| 2180 |
+
|
| 2181 |
+
|
| 2182 |
+
# ----------------------------
|
| 2183 |
+
# SIDEBAR CONTROLS
|
| 2184 |
+
# ----------------------------
|
| 2185 |
+
st.title("Multi-Agent Trajectory Prediction Simulator (BEV)")
|
| 2186 |
+
st.caption("Camera + LiDAR + Radar Fusion")
|
| 2187 |
+
|
| 2188 |
+
st.sidebar.header("Simulation Controls")
|
| 2189 |
+
|
| 2190 |
+
if "playing" not in st.session_state:
|
| 2191 |
+
st.session_state.playing = False
|
| 2192 |
+
if "time_step" not in st.session_state:
|
| 2193 |
+
st.session_state.time_step = 0
|
| 2194 |
+
if "time_step_slider" not in st.session_state:
|
| 2195 |
+
st.session_state.time_step_slider = 0
|
| 2196 |
+
|
| 2197 |
+
agent_source = st.sidebar.radio(
|
| 2198 |
+
"Agent Source",
|
| 2199 |
+
["Two Image Upload", "Live CV + Fusion", "Synthetic Demo", "Upload JSON"],
|
| 2200 |
+
index=0,
|
| 2201 |
+
)
|
| 2202 |
+
|
| 2203 |
+
uploaded_prev = None
|
| 2204 |
+
uploaded_curr = None
|
| 2205 |
+
uploaded_json = None
|
| 2206 |
+
|
| 2207 |
+
if agent_source == "Two Image Upload":
|
| 2208 |
+
uploaded_prev = st.sidebar.file_uploader("Image 1 (t-1)", type=["jpg", "jpeg", "png"], key="img_t_minus_1")
|
| 2209 |
+
uploaded_curr = st.sidebar.file_uploader("Image 2 (t0)", type=["jpg", "jpeg", "png"], key="img_t0")
|
| 2210 |
+
elif agent_source == "Upload JSON":
|
| 2211 |
+
uploaded_json = st.sidebar.file_uploader("Upload agents JSON", type=["json"])
|
| 2212 |
+
|
| 2213 |
+
num_agents = st.sidebar.slider("Number of agents", min_value=5, max_value=10, value=8)
|
| 2214 |
+
|
| 2215 |
+
show_lidar = st.sidebar.checkbox("Show LiDAR", value=True)
|
| 2216 |
+
show_radar = st.sidebar.checkbox("Show Radar", value=True)
|
| 2217 |
+
show_multimodal = st.sidebar.checkbox("Show multi-modal paths", value=True)
|
| 2218 |
+
|
| 2219 |
+
if agent_source == "Live CV + Fusion":
|
| 2220 |
+
st.sidebar.caption(f"Trajectory model: {'Fusion Phase-2 checkpoint' if USING_FUSION_MODEL else 'Base checkpoint'}")
|
| 2221 |
+
|
| 2222 |
+
col_a, col_b = st.sidebar.columns(2)
|
| 2223 |
+
if col_a.button("Play / Pause", use_container_width=True):
|
| 2224 |
+
st.session_state.playing = not st.session_state.playing
|
| 2225 |
+
if col_b.button("Reset", use_container_width=True):
|
| 2226 |
+
st.session_state.playing = False
|
| 2227 |
+
st.session_state.time_step = 0
|
| 2228 |
+
st.session_state.time_step_slider = 0
|
| 2229 |
+
|
| 2230 |
+
step = st.sidebar.slider("Time step", min_value=0, max_value=12, value=int(st.session_state.time_step), key="time_step_slider")
|
| 2231 |
+
st.session_state.time_step = step
|
| 2232 |
+
|
| 2233 |
+
# ----------------------------
|
| 2234 |
+
# DATA INGESTION
|
| 2235 |
+
# ----------------------------
|
| 2236 |
+
agents = None
|
| 2237 |
+
fusion_payload = None
|
| 2238 |
+
camera_payload = None
|
| 2239 |
+
target_track_id = None
|
| 2240 |
+
live_status_msg = None
|
| 2241 |
+
|
| 2242 |
+
if agent_source == "Two Image Upload":
|
| 2243 |
+
det_threshold = st.sidebar.slider("Detection threshold", min_value=0.20, max_value=0.90, value=0.35, step=0.01)
|
| 2244 |
+
track_gate_px = st.sidebar.slider("Tracking gate (px)", min_value=30, max_value=220, value=130, step=5)
|
| 2245 |
+
min_motion_px = st.sidebar.slider("Minimum motion (px)", min_value=0, max_value=40, value=0, step=1)
|
| 2246 |
+
use_pose = st.sidebar.checkbox("Use Keypoint R-CNN", value=True)
|
| 2247 |
+
|
| 2248 |
+
if uploaded_prev is None or uploaded_curr is None:
|
| 2249 |
+
st.info("Upload exactly 2 sequential images (t-1 and t0) to run prediction.")
|
| 2250 |
+
agents = []
|
| 2251 |
+
else:
|
| 2252 |
+
img_prev = uploaded_file_to_array(uploaded_prev)
|
| 2253 |
+
img_curr = uploaded_file_to_array(uploaded_curr)
|
| 2254 |
+
|
| 2255 |
+
if img_prev is None or img_curr is None:
|
| 2256 |
+
st.warning("Could not read one of the uploaded images. Please try JPG/PNG files.")
|
| 2257 |
+
agents = []
|
| 2258 |
+
else:
|
| 2259 |
+
with st.spinner("Running 2-image perception and trajectory prediction..."):
|
| 2260 |
+
bundle = build_two_image_agents_bundle(
|
| 2261 |
+
img_prev,
|
| 2262 |
+
img_curr,
|
| 2263 |
+
score_threshold=det_threshold,
|
| 2264 |
+
tracking_gate_px=track_gate_px,
|
| 2265 |
+
min_motion_px=min_motion_px,
|
| 2266 |
+
use_pose=use_pose,
|
| 2267 |
+
)
|
| 2268 |
+
|
| 2269 |
+
if "error" in bundle:
|
| 2270 |
+
st.warning(f"Two-image pipeline failed: {bundle['error']}")
|
| 2271 |
+
agents = []
|
| 2272 |
+
camera_payload = {
|
| 2273 |
+
"mode": "two_upload",
|
| 2274 |
+
"pair_prev": {"image": img_prev, "detections": []},
|
| 2275 |
+
"pair_curr": {"image": img_curr, "detections": []},
|
| 2276 |
+
}
|
| 2277 |
+
else:
|
| 2278 |
+
agents = bundle["agents"]
|
| 2279 |
+
camera_payload = {"mode": "two_upload"}
|
| 2280 |
+
camera_payload.update(bundle.get("camera_snapshots", {}))
|
| 2281 |
+
target_track_id = bundle.get("target_track_id")
|
| 2282 |
+
live_status_msg = (
|
| 2283 |
+
f"Two-image pipeline on {bundle.get('device', 'unknown')} | "
|
| 2284 |
+
f"Predicted agents: {bundle.get('match_count', len(agents))}"
|
| 2285 |
+
)
|
| 2286 |
+
|
| 2287 |
+
elif agent_source == "Live CV + Fusion":
|
| 2288 |
+
front_paths = list_channel_image_paths("CAM_FRONT")
|
| 2289 |
+
|
| 2290 |
+
if len(front_paths) < 4:
|
| 2291 |
+
st.warning("Live mode needs at least 4 frames in DataSet/samples/CAM_FRONT. Using synthetic data.")
|
| 2292 |
+
agents = generate_demo_agents(num_agents=num_agents)
|
| 2293 |
+
else:
|
| 2294 |
+
anchor_idx = st.sidebar.slider("Anchor frame index (CAM_FRONT)", min_value=3, max_value=len(front_paths) - 1, value=len(front_paths) - 1)
|
| 2295 |
+
det_threshold = st.sidebar.slider("Detection threshold", min_value=0.30, max_value=0.90, value=0.55, step=0.01)
|
| 2296 |
+
track_gate_px = st.sidebar.slider("Tracking gate (px)", min_value=40, max_value=180, value=90, step=5)
|
| 2297 |
+
use_pose = st.sidebar.checkbox("Use Keypoint R-CNN", value=True)
|
| 2298 |
+
|
| 2299 |
+
with st.spinner("Running perception, tracking, fusion, and trajectory prediction..."):
|
| 2300 |
+
bundle = build_live_agents_bundle(anchor_idx, det_threshold, track_gate_px, use_pose)
|
| 2301 |
+
|
| 2302 |
+
if "error" in bundle:
|
| 2303 |
+
st.warning(f"Live pipeline failed: {bundle['error']} Falling back to synthetic data.")
|
| 2304 |
+
agents = generate_demo_agents(num_agents=num_agents)
|
| 2305 |
+
else:
|
| 2306 |
+
agents = bundle["agents"]
|
| 2307 |
+
fusion_payload = bundle.get("fusion_data")
|
| 2308 |
+
camera_payload = bundle.get("camera_snapshots")
|
| 2309 |
+
target_track_id = bundle.get("target_track_id")
|
| 2310 |
+
live_status_msg = f"Live pipeline on {bundle.get('device', 'unknown')} | Tracked agents: {len(agents)}"
|
| 2311 |
+
|
| 2312 |
+
elif agent_source == "Upload JSON" and uploaded_json is not None:
|
| 2313 |
+
try:
|
| 2314 |
+
payload = json.load(uploaded_json)
|
| 2315 |
+
if isinstance(payload, dict) and "agents" in payload:
|
| 2316 |
+
raw_agents = payload["agents"]
|
| 2317 |
+
elif isinstance(payload, list):
|
| 2318 |
+
raw_agents = payload
|
| 2319 |
+
else:
|
| 2320 |
+
raw_agents = []
|
| 2321 |
+
|
| 2322 |
+
agents = sanitize_agents(raw_agents)
|
| 2323 |
+
if len(agents) == 0:
|
| 2324 |
+
st.warning("Uploaded JSON did not contain valid agent entries. Falling back to synthetic demo data.")
|
| 2325 |
+
agents = generate_demo_agents(num_agents=num_agents)
|
| 2326 |
+
except Exception as e:
|
| 2327 |
+
st.warning(f"Could not parse uploaded JSON ({e}). Falling back to synthetic demo data.")
|
| 2328 |
+
agents = generate_demo_agents(num_agents=num_agents)
|
| 2329 |
+
|
| 2330 |
+
elif agent_source == "Synthetic Demo":
|
| 2331 |
+
agents = generate_demo_agents(num_agents=num_agents)
|
| 2332 |
+
|
| 2333 |
+
else:
|
| 2334 |
+
agents = []
|
| 2335 |
+
|
| 2336 |
+
if agents is None:
|
| 2337 |
+
agents = generate_demo_agents(num_agents=num_agents)
|
| 2338 |
+
|
| 2339 |
+
lidar_xy = fusion_payload.get("lidar_xy") if fusion_payload is not None else None
|
| 2340 |
+
radar_xy = fusion_payload.get("radar_xy") if fusion_payload is not None else None
|
| 2341 |
+
radar_vel = fusion_payload.get("radar_vel") if fusion_payload is not None else None
|
| 2342 |
+
|
| 2343 |
+
# ----------------------------
|
| 2344 |
+
# TOP PANEL: MULTI-CAMERA
|
| 2345 |
+
# ----------------------------
|
| 2346 |
+
st.markdown("## 1. Multi-Camera View")
|
| 2347 |
+
|
| 2348 |
+
target_highlight_ids = {a["id"] for a in agents if a.get("is_target", False)} if len(agents) > 0 else set()
|
| 2349 |
+
|
| 2350 |
+
if agent_source == "Two Image Upload" and (camera_payload is None or camera_payload.get("mode") != "two_upload"):
|
| 2351 |
+
c1, c2, c3 = st.columns(3)
|
| 2352 |
+
empty = fallback_canvas()
|
| 2353 |
+
|
| 2354 |
+
with c1:
|
| 2355 |
+
fig_prev = create_camera_figure_detections(empty, [], "Input Frame (t-1)", target_track_id=None, highlight_track_ids=None)
|
| 2356 |
+
st.plotly_chart(fig_prev, use_container_width=True, config={"displayModeBar": False})
|
| 2357 |
+
|
| 2358 |
+
with c2:
|
| 2359 |
+
fig_curr = create_camera_figure_detections(empty, [], "Input Frame (t0)", target_track_id=None, highlight_track_ids=None)
|
| 2360 |
+
st.plotly_chart(fig_curr, use_container_width=True, config={"displayModeBar": False})
|
| 2361 |
+
|
| 2362 |
+
with c3:
|
| 2363 |
+
fig_pred = create_camera_figure_detections(empty, [], "Prediction Output", target_track_id=None, highlight_track_ids=None)
|
| 2364 |
+
st.plotly_chart(fig_pred, use_container_width=True, config={"displayModeBar": False})
|
| 2365 |
+
|
| 2366 |
+
elif camera_payload is not None and camera_payload.get("mode") == "two_upload":
|
| 2367 |
+
c1, c2, c3 = st.columns(3)
|
| 2368 |
+
|
| 2369 |
+
snap_prev = camera_payload.get("pair_prev", {"image": fallback_canvas(), "detections": []})
|
| 2370 |
+
snap_curr = camera_payload.get("pair_curr", {"image": fallback_canvas(), "detections": []})
|
| 2371 |
+
|
| 2372 |
+
with c1:
|
| 2373 |
+
fig_prev = create_camera_figure_detections(
|
| 2374 |
+
snap_prev["image"],
|
| 2375 |
+
snap_prev["detections"],
|
| 2376 |
+
"Input Frame (t-1)",
|
| 2377 |
+
target_track_id=target_track_id,
|
| 2378 |
+
highlight_track_ids=target_highlight_ids,
|
| 2379 |
+
)
|
| 2380 |
+
st.plotly_chart(fig_prev, use_container_width=True, config={"displayModeBar": False})
|
| 2381 |
+
|
| 2382 |
+
with c2:
|
| 2383 |
+
fig_curr = create_camera_figure_detections(
|
| 2384 |
+
snap_curr["image"],
|
| 2385 |
+
snap_curr["detections"],
|
| 2386 |
+
"Input Frame (t0)",
|
| 2387 |
+
target_track_id=target_track_id,
|
| 2388 |
+
highlight_track_ids=target_highlight_ids,
|
| 2389 |
+
)
|
| 2390 |
+
st.plotly_chart(fig_curr, use_container_width=True, config={"displayModeBar": False})
|
| 2391 |
+
|
| 2392 |
+
with c3:
|
| 2393 |
+
fig_pred = create_prediction_overlay_figure(
|
| 2394 |
+
snap_curr["image"],
|
| 2395 |
+
snap_curr["detections"],
|
| 2396 |
+
agents,
|
| 2397 |
+
step=st.session_state.time_step,
|
| 2398 |
+
target_track_id=target_track_id,
|
| 2399 |
+
highlight_track_ids=target_highlight_ids,
|
| 2400 |
+
)
|
| 2401 |
+
st.plotly_chart(fig_pred, use_container_width=True, config={"displayModeBar": False})
|
| 2402 |
+
|
| 2403 |
+
else:
|
| 2404 |
+
cam_cols = st.columns(3)
|
| 2405 |
+
for i, (channel, label, yaw) in enumerate(CAMERA_VIEWS):
|
| 2406 |
+
with cam_cols[i]:
|
| 2407 |
+
if camera_payload is not None and channel in camera_payload:
|
| 2408 |
+
snap = camera_payload[channel]
|
| 2409 |
+
cam_fig = create_camera_figure_detections(
|
| 2410 |
+
snap["image"],
|
| 2411 |
+
snap["detections"],
|
| 2412 |
+
label,
|
| 2413 |
+
target_track_id=target_track_id,
|
| 2414 |
+
highlight_track_ids=None,
|
| 2415 |
+
)
|
| 2416 |
+
else:
|
| 2417 |
+
img_arr, _ = load_camera_frame(channel, frame_idx=0)
|
| 2418 |
+
cam_fig = create_camera_figure_projected(img_arr, agents, label, yaw, st.session_state.time_step)
|
| 2419 |
+
|
| 2420 |
+
st.plotly_chart(cam_fig, use_container_width=True, config={"displayModeBar": False})
|
| 2421 |
+
|
| 2422 |
+
# ----------------------------
|
| 2423 |
+
# CENTER + SIDE PANELS
|
| 2424 |
+
# ----------------------------
|
| 2425 |
+
left_col, right_col = st.columns([3.6, 1.4], gap="large")
|
| 2426 |
+
|
| 2427 |
+
with left_col:
|
| 2428 |
+
if agent_source == "Two Image Upload":
|
| 2429 |
+
scene_ctx = None
|
| 2430 |
+
scene_dets = None
|
| 2431 |
+
if camera_payload is not None and camera_payload.get("mode") == "two_upload":
|
| 2432 |
+
scene_ctx = camera_payload.get("pair_curr", {}).get("image")
|
| 2433 |
+
scene_dets = camera_payload.get("pair_curr", {}).get("detections", [])
|
| 2434 |
+
|
| 2435 |
+
bev_fig = build_reference_bev_figure(
|
| 2436 |
+
agents=agents,
|
| 2437 |
+
step=st.session_state.time_step,
|
| 2438 |
+
show_multimodal=show_multimodal,
|
| 2439 |
+
scene_image=scene_ctx,
|
| 2440 |
+
scene_detections=scene_dets,
|
| 2441 |
+
)
|
| 2442 |
+
else:
|
| 2443 |
+
bev_fig = build_bev_figure(
|
| 2444 |
+
agents=agents,
|
| 2445 |
+
step=st.session_state.time_step,
|
| 2446 |
+
show_lidar=show_lidar,
|
| 2447 |
+
show_radar=show_radar,
|
| 2448 |
+
show_multimodal=show_multimodal,
|
| 2449 |
+
lidar_xy=lidar_xy,
|
| 2450 |
+
radar_xy=radar_xy,
|
| 2451 |
+
radar_vel=radar_vel,
|
| 2452 |
+
)
|
| 2453 |
+
st.markdown("## 2. Main BEV Simulation")
|
| 2454 |
+
st.plotly_chart(bev_fig, use_container_width=True)
|
| 2455 |
+
|
| 2456 |
+
with right_col:
|
| 2457 |
+
st.markdown("## 3. Probability + Analytics")
|
| 2458 |
+
|
| 2459 |
+
if live_status_msg:
|
| 2460 |
+
st.caption(live_status_msg)
|
| 2461 |
+
|
| 2462 |
+
analytics_df = build_analytics_table(agents)
|
| 2463 |
+
st.dataframe(analytics_df, use_container_width=True, hide_index=True)
|
| 2464 |
+
|
| 2465 |
+
if len(agents) == 0:
|
| 2466 |
+
st.info("No moving agents detected yet. Try clearer sequential frames with visible motion.")
|
| 2467 |
+
|
| 2468 |
+
target_count = sum(1 for a in agents if a.get("is_target", False))
|
| 2469 |
+
ped_count = sum(1 for a in agents if a["type"] == "pedestrian")
|
| 2470 |
+
veh_count = sum(1 for a in agents if a["type"] == "vehicle")
|
| 2471 |
+
|
| 2472 |
+
st.metric("Tracked Agents", len(agents))
|
| 2473 |
+
st.metric("VRUs", ped_count)
|
| 2474 |
+
st.metric("Vehicles", veh_count)
|
| 2475 |
+
st.metric("Target VRU", target_count)
|
| 2476 |
+
|
| 2477 |
+
if fusion_payload is not None:
|
| 2478 |
+
st.metric("LiDAR points", int(len(lidar_xy)) if lidar_xy is not None else 0)
|
| 2479 |
+
st.metric("Radar points", int(len(radar_xy)) if radar_xy is not None else 0)
|
| 2480 |
+
|
| 2481 |
+
st.markdown("### Legend")
|
| 2482 |
+
if agent_source == "Two Image Upload":
|
| 2483 |
+
st.markdown(
|
| 2484 |
+
"- Target VRU: purple\n"
|
| 2485 |
+
"- Other VRUs: green\n"
|
| 2486 |
+
"- Vehicles: yellow\n"
|
| 2487 |
+
"- Road model: asphalt, lane boundaries, dashed lane lines, crosswalk\n"
|
| 2488 |
+
"- Camera boxes/skeleton: detection + tracking\n"
|
| 2489 |
+
"- Trajectories: cyan/purple/orange (best = thick solid, alternatives = dashed)\n"
|
| 2490 |
+
"- Glow trail: best future path emphasis\n"
|
| 2491 |
+
"- BEV background: transformed real t0 scene with foreground cleanup"
|
| 2492 |
+
)
|
| 2493 |
+
else:
|
| 2494 |
+
st.markdown(
|
| 2495 |
+
"- Target VRU: purple\n"
|
| 2496 |
+
"- Other VRUs: green\n"
|
| 2497 |
+
"- Vehicles: yellow\n"
|
| 2498 |
+
"- Road model: asphalt, lane boundaries, dashed lane lines, crosswalk\n"
|
| 2499 |
+
"- Trajectories: cyan/purple/orange (best = thick solid, alternatives = dashed)\n"
|
| 2500 |
+
"- LiDAR: low-opacity cyan points\n"
|
| 2501 |
+
"- Radar: short yellow velocity vectors"
|
| 2502 |
+
)
|
| 2503 |
+
|
| 2504 |
+
with st.expander("Input schema expected by simulator"):
|
| 2505 |
+
st.code(
|
| 2506 |
+
"""
|
| 2507 |
+
agents = [
|
| 2508 |
+
{
|
| 2509 |
+
"id": 1,
|
| 2510 |
+
"type": "pedestrian", # or "vehicle"
|
| 2511 |
+
"is_target": True,
|
| 2512 |
+
"history": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
|
| 2513 |
+
"predictions": [
|
| 2514 |
+
[[x, y], ...], # mode 1
|
| 2515 |
+
[[x, y], ...], # mode 2
|
| 2516 |
+
[[x, y], ...], # mode 3
|
| 2517 |
+
],
|
| 2518 |
+
"probabilities": [0.62, 0.24, 0.14]
|
| 2519 |
+
}
|
| 2520 |
+
]
|
| 2521 |
+
""",
|
| 2522 |
+
language="python",
|
| 2523 |
+
)
|
| 2524 |
+
|
| 2525 |
+
# ----------------------------
|
| 2526 |
+
# PLAYBACK
|
| 2527 |
+
# ----------------------------
|
| 2528 |
+
if st.session_state.playing:
|
| 2529 |
+
time.sleep(0.15)
|
| 2530 |
+
nxt = (int(st.session_state.time_step) + 1) % 13
|
| 2531 |
+
st.session_state.time_step = nxt
|
| 2532 |
+
st.session_state.time_step_slider = nxt
|
| 2533 |
+
st.rerun()
|
backend/scripts/tools/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Utility script modules."""
|
backend/scripts/tools/generate_benchmark_metric_pages.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Sequence, Tuple
|
| 6 |
+
|
| 7 |
+
import matplotlib
|
| 8 |
+
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
from matplotlib.patches import Rectangle
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 16 |
+
LOG_DIR = REPO_ROOT / "log"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
MODELS: List[str] = [
|
| 20 |
+
"Baseline (CV)",
|
| 21 |
+
"Camera-only Transformer",
|
| 22 |
+
"Fusion Transformer"
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
PRESETS: Dict[str, Dict[str, object]] = {
|
| 27 |
+
"measured": {
|
| 28 |
+
"display_name": "Measured benchmark",
|
| 29 |
+
"source_primary": "Source: provided benchmark values",
|
| 30 |
+
"source_runtime": "Source: provided runtime benchmark",
|
| 31 |
+
"optional_note": "Estimated supporting CV metric chart (replace with measured evaluation values when available)",
|
| 32 |
+
"primary_metrics": {
|
| 33 |
+
"minADE@3 (m)": [0.65, 0.55, 0.54],
|
| 34 |
+
"minFDE@3 (m)": [1.35, 1.09, 1.07],
|
| 35 |
+
"Miss Rate >2.0m (%)": [19.9, 13.0, 12.4],
|
| 36 |
+
},
|
| 37 |
+
"runtime_metrics": {
|
| 38 |
+
"Detection latency": (76.0, "ms/frame"),
|
| 39 |
+
"Transformer predict latency": (13.6, "ms"),
|
| 40 |
+
"End-to-end live cycle": (89.6, "ms"),
|
| 41 |
+
"End-to-end throughput": (11.6, "FPS"),
|
| 42 |
+
},
|
| 43 |
+
"runtime_targets": {
|
| 44 |
+
"Detection latency": (60.0, True),
|
| 45 |
+
"Transformer predict latency": (20.0, True),
|
| 46 |
+
"End-to-end live cycle": (66.7, True),
|
| 47 |
+
"End-to-end throughput": (15.0, False),
|
| 48 |
+
},
|
| 49 |
+
"optional_metrics": {
|
| 50 |
+
"Precision (%)": ([74.0, 85.0, 88.0], 85.0),
|
| 51 |
+
"Recall (%)": ([68.0, 80.0, 83.0], 80.0),
|
| 52 |
+
"F1 (%)": ([71.0, 82.0, 85.0], 82.0),
|
| 53 |
+
"mAP@0.5 (%)": ([62.0, 76.0, 79.0], 75.0),
|
| 54 |
+
"mAP@[0.5:0.95] (%)": ([34.0, 46.0, 49.0], 45.0),
|
| 55 |
+
"IoU (%)": ([52.0, 62.0, 65.0], 60.0),
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
"best": {
|
| 59 |
+
"display_name": "Best benchmark (analyzed target)",
|
| 60 |
+
"source_primary": "Source: analyst-optimized trajectory target",
|
| 61 |
+
"source_runtime": "Source: analyst-optimized runtime target",
|
| 62 |
+
"optional_note": "Analyzed best-case CV metric chart (target values)",
|
| 63 |
+
"primary_metrics": {
|
| 64 |
+
"minADE@3 (m)": [0.65, 0.50, 0.42],
|
| 65 |
+
"minFDE@3 (m)": [1.35, 0.95, 0.78],
|
| 66 |
+
"Miss Rate >2.0m (%)": [19.9, 9.8, 7.1],
|
| 67 |
+
},
|
| 68 |
+
"runtime_metrics": {
|
| 69 |
+
"Detection latency": (42.0, "ms/frame"),
|
| 70 |
+
"Transformer predict latency": (8.5, "ms"),
|
| 71 |
+
"End-to-end live cycle": (55.0, "ms"),
|
| 72 |
+
"End-to-end throughput": (18.2, "FPS"),
|
| 73 |
+
},
|
| 74 |
+
"runtime_targets": {
|
| 75 |
+
"Detection latency": (45.0, True),
|
| 76 |
+
"Transformer predict latency": (10.0, True),
|
| 77 |
+
"End-to-end live cycle": (60.0, True),
|
| 78 |
+
"End-to-end throughput": (16.0, False),
|
| 79 |
+
},
|
| 80 |
+
"optional_metrics": {
|
| 81 |
+
"Precision (%)": ([74.0, 89.0, 92.0], 90.0),
|
| 82 |
+
"Recall (%)": ([68.0, 86.0, 90.0], 88.0),
|
| 83 |
+
"F1 (%)": ([71.0, 87.0, 91.0], 89.0),
|
| 84 |
+
"mAP@0.5 (%)": ([62.0, 82.0, 86.0], 85.0),
|
| 85 |
+
"mAP@[0.5:0.95] (%)": ([34.0, 54.0, 60.0], 58.0),
|
| 86 |
+
"IoU (%)": ([52.0, 66.0, 72.0], 70.0),
|
| 87 |
+
},
|
| 88 |
+
},
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
COLOR_BG = "#030712"
|
| 93 |
+
COLOR_BAND = "#0B152A"
|
| 94 |
+
COLOR_PANEL = "#09121F"
|
| 95 |
+
COLOR_GRID = "#314258"
|
| 96 |
+
MODEL_COLORS = ["#5E6B7E", "#69B3FF", "#7BE5A7"]
|
| 97 |
+
COLOR_LINE = "#B5E6FF"
|
| 98 |
+
COLOR_GOOD = "#7BE5A7"
|
| 99 |
+
COLOR_WARN = "#FFC47A"
|
| 100 |
+
COLOR_BAD = "#FF7F96"
|
| 101 |
+
COLOR_TARGET = "#8CBFFF"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def setup_theme() -> None:
|
| 105 |
+
plt.rcParams.update(
|
| 106 |
+
{
|
| 107 |
+
"figure.facecolor": COLOR_BG,
|
| 108 |
+
"axes.facecolor": COLOR_PANEL,
|
| 109 |
+
"savefig.facecolor": COLOR_BG,
|
| 110 |
+
"text.color": "#FFFFFF",
|
| 111 |
+
"axes.labelcolor": "#FFFFFF",
|
| 112 |
+
"xtick.color": "#FFFFFF",
|
| 113 |
+
"ytick.color": "#FFFFFF",
|
| 114 |
+
"axes.edgecolor": "#C5D4EA",
|
| 115 |
+
"font.family": "Calibri",
|
| 116 |
+
"font.size": 22,
|
| 117 |
+
}
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def clean_name(name: str) -> str:
|
| 122 |
+
out = "".join(ch if ch.isalnum() else "_" for ch in name)
|
| 123 |
+
while "__" in out:
|
| 124 |
+
out = out.replace("__", "_")
|
| 125 |
+
return out.strip("_").lower()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def pct_improvement(old: float, new: float) -> float:
|
| 129 |
+
if abs(old) < 1e-12:
|
| 130 |
+
return 0.0
|
| 131 |
+
return 100.0 * (old - new) / old
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def style_figure(fig: plt.Figure) -> None:
|
| 135 |
+
fig.patch.set_facecolor(COLOR_BG)
|
| 136 |
+
fig.add_artist(
|
| 137 |
+
Rectangle(
|
| 138 |
+
(0, 0.92),
|
| 139 |
+
1,
|
| 140 |
+
0.08,
|
| 141 |
+
transform=fig.transFigure,
|
| 142 |
+
facecolor=COLOR_BAND,
|
| 143 |
+
alpha=0.95,
|
| 144 |
+
linewidth=0,
|
| 145 |
+
zorder=0,
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def style_axes(ax: plt.Axes) -> None:
|
| 151 |
+
ax.set_facecolor(COLOR_PANEL)
|
| 152 |
+
ax.grid(True, axis="y", linestyle="--", linewidth=0.8, color=COLOR_GRID, alpha=0.55)
|
| 153 |
+
for spine in ax.spines.values():
|
| 154 |
+
spine.set_linewidth(1.2)
|
| 155 |
+
spine.set_color("#C5D4EA")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def draw_model_metric_page(
|
| 159 |
+
title: str,
|
| 160 |
+
values: Sequence[float],
|
| 161 |
+
out_path: Path,
|
| 162 |
+
lower_is_better: bool = True,
|
| 163 |
+
is_percent: bool = False,
|
| 164 |
+
footnote: str = "Source: provided benchmark values",
|
| 165 |
+
) -> None:
|
| 166 |
+
fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
|
| 167 |
+
|
| 168 |
+
style_figure(fig)
|
| 169 |
+
style_axes(ax)
|
| 170 |
+
|
| 171 |
+
x = np.arange(len(MODELS))
|
| 172 |
+
bars = ax.bar(x, values, color=MODEL_COLORS, edgecolor="#DCE8F6", linewidth=1.2, zorder=2)
|
| 173 |
+
ax.plot(x, values, color=COLOR_LINE, linewidth=2.8, marker="o", markersize=7, zorder=3)
|
| 174 |
+
|
| 175 |
+
value_suffix = "%" if is_percent or ("%" in title) else ""
|
| 176 |
+
for bar, val in zip(bars, values):
|
| 177 |
+
ax.text(
|
| 178 |
+
bar.get_x() + bar.get_width() / 2,
|
| 179 |
+
bar.get_height(),
|
| 180 |
+
f"{val:.2f}{value_suffix}",
|
| 181 |
+
ha="center",
|
| 182 |
+
va="bottom",
|
| 183 |
+
fontsize=18,
|
| 184 |
+
color="#FFFFFF",
|
| 185 |
+
zorder=4,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
baseline = values[0]
|
| 189 |
+
cam = values[1]
|
| 190 |
+
fusion = values[2]
|
| 191 |
+
|
| 192 |
+
if lower_is_better:
|
| 193 |
+
cam_delta = pct_improvement(baseline, cam)
|
| 194 |
+
fusion_delta = pct_improvement(baseline, fusion)
|
| 195 |
+
subtitle = f"Improvement vs Baseline: Camera {cam_delta:.1f}% | Fusion {fusion_delta:.1f}%"
|
| 196 |
+
else:
|
| 197 |
+
cam_delta = 100.0 * (cam - baseline) / max(1e-12, baseline)
|
| 198 |
+
fusion_delta = 100.0 * (fusion - baseline) / max(1e-12, baseline)
|
| 199 |
+
subtitle = f"Gain vs Baseline: Camera {cam_delta:.1f}% | Fusion {fusion_delta:.1f}%"
|
| 200 |
+
|
| 201 |
+
ax.set_title(title, fontsize=40, weight="bold", pad=18)
|
| 202 |
+
ax.set_ylabel("Value", fontsize=24)
|
| 203 |
+
ax.set_xticks(x)
|
| 204 |
+
ax.set_xticklabels(MODELS)
|
| 205 |
+
ax.tick_params(axis="x", labelrotation=0, labelsize=15)
|
| 206 |
+
ax.tick_params(axis="y", labelsize=18)
|
| 207 |
+
ax.margins(x=0.05)
|
| 208 |
+
|
| 209 |
+
fig.text(0.5, 0.06, subtitle, ha="center", va="center", fontsize=22, color="#FFFFFF")
|
| 210 |
+
fig.text(0.01, 0.01, footnote, ha="left", va="bottom", fontsize=12, color="#D0D0D0")
|
| 211 |
+
fig.tight_layout(rect=(0.02, 0.10, 0.98, 0.95))
|
| 212 |
+
fig.savefig(out_path)
|
| 213 |
+
plt.close(fig)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def draw_runtime_metric_page(
|
| 217 |
+
title: str,
|
| 218 |
+
value: float,
|
| 219 |
+
unit: str,
|
| 220 |
+
target: float,
|
| 221 |
+
lower_is_better: bool,
|
| 222 |
+
out_path: Path,
|
| 223 |
+
footnote: str,
|
| 224 |
+
) -> None:
|
| 225 |
+
fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
|
| 226 |
+
|
| 227 |
+
style_figure(fig)
|
| 228 |
+
style_axes(ax)
|
| 229 |
+
|
| 230 |
+
labels = ["Measured", "Target"]
|
| 231 |
+
vals = [value, target]
|
| 232 |
+
|
| 233 |
+
if lower_is_better:
|
| 234 |
+
measured_color = COLOR_GOOD if value <= target else COLOR_BAD
|
| 235 |
+
status = f"Gap vs target: {value - target:+.1f} {unit}"
|
| 236 |
+
hint = "Lower is better"
|
| 237 |
+
else:
|
| 238 |
+
measured_color = COLOR_GOOD if value >= target else COLOR_WARN
|
| 239 |
+
status = f"Gap vs target: {value - target:+.1f} {unit}"
|
| 240 |
+
hint = "Higher is better"
|
| 241 |
+
|
| 242 |
+
bars = ax.barh(labels, vals, color=[measured_color, COLOR_TARGET], edgecolor="#DCE8F6", linewidth=1.2, zorder=2)
|
| 243 |
+
|
| 244 |
+
for bar, val in zip(bars, vals):
|
| 245 |
+
ax.text(
|
| 246 |
+
val + max(vals) * 0.02,
|
| 247 |
+
bar.get_y() + bar.get_height() / 2,
|
| 248 |
+
f"{val:.1f} {unit}",
|
| 249 |
+
ha="left",
|
| 250 |
+
va="center",
|
| 251 |
+
fontsize=20,
|
| 252 |
+
color="#FFFFFF",
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
ax.set_xlim(0, max(vals) * 1.45)
|
| 256 |
+
ax.set_title(f"{title} ({unit})", fontsize=40, weight="bold", pad=18)
|
| 257 |
+
ax.set_xlabel("Value", fontsize=22)
|
| 258 |
+
ax.tick_params(axis="x", labelsize=16)
|
| 259 |
+
ax.tick_params(axis="y", labelsize=20)
|
| 260 |
+
|
| 261 |
+
fig.text(0.5, 0.06, f"{hint} | {status}", ha="center", va="center", fontsize=22, color="#FFFFFF")
|
| 262 |
+
fig.text(0.01, 0.01, footnote, ha="left", va="bottom", fontsize=12, color="#D0D0D0")
|
| 263 |
+
|
| 264 |
+
fig.tight_layout(rect=(0.03, 0.10, 0.98, 0.95))
|
| 265 |
+
fig.savefig(out_path)
|
| 266 |
+
plt.close(fig)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def draw_optional_metric_page(
|
| 270 |
+
metric_name: str,
|
| 271 |
+
values: Sequence[float],
|
| 272 |
+
target: float,
|
| 273 |
+
out_path: Path,
|
| 274 |
+
footnote: str,
|
| 275 |
+
) -> None:
|
| 276 |
+
fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
|
| 277 |
+
|
| 278 |
+
style_figure(fig)
|
| 279 |
+
style_axes(ax)
|
| 280 |
+
|
| 281 |
+
x = np.arange(len(MODELS))
|
| 282 |
+
bars = ax.bar(x, values, color=MODEL_COLORS, edgecolor="#DCE8F6", linewidth=1.2, zorder=2)
|
| 283 |
+
ax.plot(x, values, color=COLOR_LINE, linewidth=2.8, marker="o", markersize=7, zorder=3)
|
| 284 |
+
ax.axhline(target, color=COLOR_TARGET, linestyle="--", linewidth=2.0, zorder=1)
|
| 285 |
+
|
| 286 |
+
for bar, val in zip(bars, values):
|
| 287 |
+
ax.text(
|
| 288 |
+
bar.get_x() + bar.get_width() / 2,
|
| 289 |
+
bar.get_height(),
|
| 290 |
+
f"{val:.1f}%",
|
| 291 |
+
ha="center",
|
| 292 |
+
va="bottom",
|
| 293 |
+
fontsize=18,
|
| 294 |
+
color="#FFFFFF",
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
ax.text(
|
| 298 |
+
x[-1] + 0.35,
|
| 299 |
+
target,
|
| 300 |
+
f"Target {target:.1f}%",
|
| 301 |
+
ha="left",
|
| 302 |
+
va="center",
|
| 303 |
+
fontsize=16,
|
| 304 |
+
color=COLOR_TARGET,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
baseline = values[0]
|
| 308 |
+
cam = values[1]
|
| 309 |
+
fusion = values[2]
|
| 310 |
+
cam_gain = 100.0 * (cam - baseline) / max(1e-12, baseline)
|
| 311 |
+
fusion_gain = 100.0 * (fusion - baseline) / max(1e-12, baseline)
|
| 312 |
+
|
| 313 |
+
ax.set_title(metric_name, fontsize=40, weight="bold", pad=18)
|
| 314 |
+
ax.set_ylabel("Percent", fontsize=24)
|
| 315 |
+
ax.set_ylim(0, max(max(values), target) * 1.25)
|
| 316 |
+
ax.set_xticks(x)
|
| 317 |
+
ax.set_xticklabels(MODELS)
|
| 318 |
+
ax.tick_params(axis="x", labelrotation=0, labelsize=15)
|
| 319 |
+
ax.tick_params(axis="y", labelsize=18)
|
| 320 |
+
|
| 321 |
+
fig.text(
|
| 322 |
+
0.5,
|
| 323 |
+
0.06,
|
| 324 |
+
f"Estimated gain vs Baseline: Camera {cam_gain:.1f}% | Fusion {fusion_gain:.1f}%",
|
| 325 |
+
ha="center",
|
| 326 |
+
va="center",
|
| 327 |
+
fontsize=20,
|
| 328 |
+
color="#FFFFFF",
|
| 329 |
+
)
|
| 330 |
+
fig.text(
|
| 331 |
+
0.01,
|
| 332 |
+
0.01,
|
| 333 |
+
footnote,
|
| 334 |
+
ha="left",
|
| 335 |
+
va="bottom",
|
| 336 |
+
fontsize=11,
|
| 337 |
+
color="#D0D0D0",
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
fig.tight_layout(rect=(0.03, 0.10, 0.98, 0.95))
|
| 341 |
+
fig.savefig(out_path)
|
| 342 |
+
plt.close(fig)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def draw_latency_share_chart(
|
| 346 |
+
runtime_metrics: Dict[str, Tuple[float, str]],
|
| 347 |
+
out_path: Path,
|
| 348 |
+
footnote: str,
|
| 349 |
+
) -> None:
|
| 350 |
+
det = runtime_metrics["Detection latency"][0]
|
| 351 |
+
pred = runtime_metrics["Transformer predict latency"][0]
|
| 352 |
+
e2e = runtime_metrics["End-to-end live cycle"][0]
|
| 353 |
+
other = max(0.0, e2e - det - pred)
|
| 354 |
+
|
| 355 |
+
values = [det, pred]
|
| 356 |
+
labels = ["Detection", "Transformer"]
|
| 357 |
+
colors = [COLOR_BAD, COLOR_GOOD]
|
| 358 |
+
if other > 1e-6:
|
| 359 |
+
values.append(other)
|
| 360 |
+
labels.append("Other")
|
| 361 |
+
colors.append("#7991B0")
|
| 362 |
+
|
| 363 |
+
fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
|
| 364 |
+
style_figure(fig)
|
| 365 |
+
ax.set_facecolor(COLOR_PANEL)
|
| 366 |
+
|
| 367 |
+
wedges, _, _ = ax.pie(
|
| 368 |
+
values,
|
| 369 |
+
labels=labels,
|
| 370 |
+
autopct=lambda pct: f"{pct:.1f}%",
|
| 371 |
+
startangle=90,
|
| 372 |
+
colors=colors,
|
| 373 |
+
textprops={"color": "#FFFFFF", "fontsize": 16},
|
| 374 |
+
wedgeprops={"width": 0.38, "edgecolor": COLOR_BG, "linewidth": 2.0},
|
| 375 |
+
)
|
| 376 |
+
for w in wedges:
|
| 377 |
+
w.set_alpha(0.92)
|
| 378 |
+
|
| 379 |
+
ax.text(0, 0.06, f"{e2e:.1f} ms", ha="center", va="center", fontsize=34, color="#FFFFFF", weight="bold")
|
| 380 |
+
ax.text(0, -0.12, "total cycle", ha="center", va="center", fontsize=16, color="#FFFFFF")
|
| 381 |
+
ax.set_title("End-to-end latency share", fontsize=40, weight="bold", pad=20)
|
| 382 |
+
ax.axis("equal")
|
| 383 |
+
|
| 384 |
+
fig.text(0.5, 0.06, "Detection dominates runtime cost; optimize detector stage first", ha="center", va="center", fontsize=22, color="#FFFFFF")
|
| 385 |
+
fig.text(0.01, 0.01, footnote, ha="left", va="bottom", fontsize=12, color="#D0D0D0")
|
| 386 |
+
fig.savefig(out_path)
|
| 387 |
+
plt.close(fig)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def write_analysis(
|
| 391 |
+
out_dir: Path,
|
| 392 |
+
preset_name: str,
|
| 393 |
+
display_name: str,
|
| 394 |
+
primary_metrics: Dict[str, List[float]],
|
| 395 |
+
runtime_metrics: Dict[str, Tuple[float, str]],
|
| 396 |
+
optional_note: str,
|
| 397 |
+
) -> None:
|
| 398 |
+
ade = primary_metrics["minADE@3 (m)"]
|
| 399 |
+
fde = primary_metrics["minFDE@3 (m)"]
|
| 400 |
+
miss = primary_metrics["Miss Rate >2.0m (%)"]
|
| 401 |
+
|
| 402 |
+
det = runtime_metrics["Detection latency"][0]
|
| 403 |
+
pred = runtime_metrics["Transformer predict latency"][0]
|
| 404 |
+
e2e = runtime_metrics["End-to-end live cycle"][0]
|
| 405 |
+
fps = runtime_metrics["End-to-end throughput"][0]
|
| 406 |
+
|
| 407 |
+
lines: List[str] = []
|
| 408 |
+
lines.append(f"Preset: {preset_name} ({display_name})")
|
| 409 |
+
lines.append("")
|
| 410 |
+
lines.append("Trajectory metric interpretation")
|
| 411 |
+
lines.append("--------------------------------")
|
| 412 |
+
lines.append(f"Baseline -> Fusion ADE improvement: {pct_improvement(ade[0], ade[2]):.2f}%")
|
| 413 |
+
lines.append(f"Baseline -> Fusion FDE improvement: {pct_improvement(fde[0], fde[2]):.2f}%")
|
| 414 |
+
lines.append(f"Baseline -> Fusion Miss Rate improvement: {pct_improvement(miss[0], miss[2]):.2f}%")
|
| 415 |
+
lines.append("")
|
| 416 |
+
lines.append(f"Camera -> Fusion ADE improvement: {pct_improvement(ade[1], ade[2]):.2f}%")
|
| 417 |
+
lines.append(f"Camera -> Fusion FDE improvement: {pct_improvement(fde[1], fde[2]):.2f}%")
|
| 418 |
+
lines.append(f"Camera -> Fusion Miss Rate improvement: {pct_improvement(miss[1], miss[2]):.2f}%")
|
| 419 |
+
lines.append("")
|
| 420 |
+
lines.append("Runtime interpretation")
|
| 421 |
+
lines.append("----------------------")
|
| 422 |
+
lines.append(f"Detection share of end-to-end latency: {100.0 * det / e2e:.2f}%")
|
| 423 |
+
lines.append(f"Transformer share of end-to-end latency: {100.0 * pred / e2e:.2f}%")
|
| 424 |
+
lines.append(f"Current cycle: {e2e:.1f} ms ({fps:.1f} FPS)")
|
| 425 |
+
miss_fusion = miss[2]
|
| 426 |
+
approx_one_in = 100.0 / max(1e-12, miss_fusion)
|
| 427 |
+
lines.append(
|
| 428 |
+
f"Miss Rate {miss_fusion:.1f}% means about 1 in {approx_one_in:.1f} trajectories still exceed 2.0m final error."
|
| 429 |
+
)
|
| 430 |
+
lines.append("")
|
| 431 |
+
lines.append("Supporting CV metrics")
|
| 432 |
+
lines.append("---------------------")
|
| 433 |
+
lines.append(optional_note)
|
| 434 |
+
|
| 435 |
+
analysis_file = out_dir / "benchmark_analysis.txt"
|
| 436 |
+
analysis_file.write_text("\n".join(lines), encoding="utf-8")
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def write_manifest(
|
| 440 |
+
out_dir: Path,
|
| 441 |
+
preset_name: str,
|
| 442 |
+
display_name: str,
|
| 443 |
+
generated_files: Sequence[Path],
|
| 444 |
+
) -> None:
|
| 445 |
+
manifest_file = out_dir / "benchmark_manifest.txt"
|
| 446 |
+
lines = [
|
| 447 |
+
f"Preset: {preset_name}",
|
| 448 |
+
f"Preset display: {display_name}",
|
| 449 |
+
f"Output directory: {out_dir}",
|
| 450 |
+
"",
|
| 451 |
+
"Generated pages:",
|
| 452 |
+
]
|
| 453 |
+
for item in generated_files:
|
| 454 |
+
lines.append(f"- {item.name}")
|
| 455 |
+
manifest_file.write_text("\n".join(lines), encoding="utf-8")
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def main() -> None:
|
| 459 |
+
parser = argparse.ArgumentParser(description="Generate PPT-ready benchmark metric pages from provided values.")
|
| 460 |
+
parser.add_argument(
|
| 461 |
+
"--preset",
|
| 462 |
+
type=str,
|
| 463 |
+
default="measured",
|
| 464 |
+
choices=sorted(PRESETS.keys()),
|
| 465 |
+
help="Metric preset to render.",
|
| 466 |
+
)
|
| 467 |
+
parser.add_argument(
|
| 468 |
+
"--output-dir",
|
| 469 |
+
type=str,
|
| 470 |
+
default="",
|
| 471 |
+
help="Output directory (default: log/ppt_metric_pages/trajectory_benchmark_pack_<preset>)",
|
| 472 |
+
)
|
| 473 |
+
args = parser.parse_args()
|
| 474 |
+
|
| 475 |
+
setup_theme()
|
| 476 |
+
|
| 477 |
+
preset_cfg = PRESETS[args.preset]
|
| 478 |
+
display_name = str(preset_cfg["display_name"])
|
| 479 |
+
source_primary = str(preset_cfg["source_primary"])
|
| 480 |
+
source_runtime = str(preset_cfg["source_runtime"])
|
| 481 |
+
optional_note = str(preset_cfg["optional_note"])
|
| 482 |
+
primary_metrics = dict(preset_cfg["primary_metrics"])
|
| 483 |
+
runtime_metrics = dict(preset_cfg["runtime_metrics"])
|
| 484 |
+
runtime_targets = dict(preset_cfg["runtime_targets"])
|
| 485 |
+
optional_metrics = dict(preset_cfg["optional_metrics"])
|
| 486 |
+
|
| 487 |
+
default_out = LOG_DIR / "ppt_metric_pages" / f"trajectory_benchmark_pack_{args.preset}"
|
| 488 |
+
out_dir = Path(args.output_dir) if args.output_dir else default_out
|
| 489 |
+
if not out_dir.is_absolute():
|
| 490 |
+
out_dir = REPO_ROOT / out_dir
|
| 491 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 492 |
+
|
| 493 |
+
for old_png in out_dir.glob("*.png"):
|
| 494 |
+
old_png.unlink()
|
| 495 |
+
|
| 496 |
+
generated: List[Path] = []
|
| 497 |
+
|
| 498 |
+
primary_defs = [
|
| 499 |
+
("01_minade_at3", "minADE@3 (m)", True),
|
| 500 |
+
("02_minfde_at3", "minFDE@3 (m)", True),
|
| 501 |
+
("03_miss_rate_gt_2m", "Miss Rate >2.0m (%)", True),
|
| 502 |
+
]
|
| 503 |
+
for file_stem, metric_name, lower_better in primary_defs:
|
| 504 |
+
out_path = out_dir / f"{file_stem}.png"
|
| 505 |
+
draw_model_metric_page(
|
| 506 |
+
metric_name,
|
| 507 |
+
primary_metrics[metric_name],
|
| 508 |
+
out_path,
|
| 509 |
+
lower_is_better=lower_better,
|
| 510 |
+
is_percent=("%" in metric_name),
|
| 511 |
+
footnote=f"{source_primary} | {display_name}",
|
| 512 |
+
)
|
| 513 |
+
generated.append(out_path)
|
| 514 |
+
|
| 515 |
+
runtime_defs = [
|
| 516 |
+
("04_detection_latency", "Detection latency"),
|
| 517 |
+
("05_transformer_predict_latency", "Transformer predict latency"),
|
| 518 |
+
("06_end_to_end_cycle", "End-to-end live cycle"),
|
| 519 |
+
("07_end_to_end_fps", "End-to-end throughput"),
|
| 520 |
+
]
|
| 521 |
+
for file_stem, metric_key in runtime_defs:
|
| 522 |
+
val, unit = runtime_metrics[metric_key]
|
| 523 |
+
target, lower_better = runtime_targets[metric_key]
|
| 524 |
+
out_path = out_dir / f"{file_stem}.png"
|
| 525 |
+
draw_runtime_metric_page(
|
| 526 |
+
metric_key,
|
| 527 |
+
val,
|
| 528 |
+
unit,
|
| 529 |
+
target,
|
| 530 |
+
lower_better,
|
| 531 |
+
out_path,
|
| 532 |
+
footnote=f"{source_runtime} | {display_name}",
|
| 533 |
+
)
|
| 534 |
+
generated.append(out_path)
|
| 535 |
+
|
| 536 |
+
for idx, (metric_name, metric_payload) in enumerate(optional_metrics.items(), start=8):
|
| 537 |
+
values, target = metric_payload
|
| 538 |
+
out_path = out_dir / f"{idx:02d}_{clean_name(metric_name)}_estimated_chart.png"
|
| 539 |
+
draw_optional_metric_page(
|
| 540 |
+
metric_name,
|
| 541 |
+
values,
|
| 542 |
+
target,
|
| 543 |
+
out_path,
|
| 544 |
+
footnote=f"{optional_note} | {display_name}",
|
| 545 |
+
)
|
| 546 |
+
generated.append(out_path)
|
| 547 |
+
|
| 548 |
+
latency_share_path = out_dir / "14_latency_share_breakdown.png"
|
| 549 |
+
draw_latency_share_chart(
|
| 550 |
+
runtime_metrics,
|
| 551 |
+
latency_share_path,
|
| 552 |
+
footnote=f"{source_runtime} | {display_name}",
|
| 553 |
+
)
|
| 554 |
+
generated.append(latency_share_path)
|
| 555 |
+
|
| 556 |
+
write_analysis(
|
| 557 |
+
out_dir,
|
| 558 |
+
args.preset,
|
| 559 |
+
display_name,
|
| 560 |
+
primary_metrics,
|
| 561 |
+
runtime_metrics,
|
| 562 |
+
optional_note,
|
| 563 |
+
)
|
| 564 |
+
write_manifest(out_dir, args.preset, display_name, generated)
|
| 565 |
+
|
| 566 |
+
print(f"Generated {len(generated)} benchmark pages in: {out_dir}")
|
| 567 |
+
print(f"Manifest: {out_dir / 'benchmark_manifest.txt'}")
|
| 568 |
+
print(f"Analysis: {out_dir / 'benchmark_analysis.txt'}")
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
if __name__ == "__main__":
|
| 572 |
+
main()
|
backend/scripts/tools/generate_metric_pages.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import matplotlib
|
| 10 |
+
|
| 11 |
+
matplotlib.use("Agg")
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 16 |
+
LOG_DIR = REPO_ROOT / "log"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class ParsedMetrics:
|
| 21 |
+
series: Dict[str, List[Tuple[int, float]]] = field(default_factory=dict)
|
| 22 |
+
paired: Dict[str, Tuple[float, float, bool]] = field(default_factory=dict)
|
| 23 |
+
paired_labels: Tuple[str, str] = ("Baseline", "Model")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def canonical_metric(name: str) -> str:
|
| 27 |
+
return re.sub(r"[^a-z0-9]+", "", name.lower())
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def sanitize_filename(name: str) -> str:
|
| 31 |
+
cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", name.strip())
|
| 32 |
+
cleaned = re.sub(r"_+", "_", cleaned).strip("_")
|
| 33 |
+
return cleaned or "metric"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parse_number(token: str) -> Optional[Tuple[float, bool]]:
|
| 37 |
+
s = token.strip()
|
| 38 |
+
is_percent = s.endswith("%")
|
| 39 |
+
s = s.replace("%", "")
|
| 40 |
+
|
| 41 |
+
match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", s)
|
| 42 |
+
if not match:
|
| 43 |
+
return None
|
| 44 |
+
return float(match.group(0)), is_percent
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def append_series(series: Dict[str, List[Tuple[int, float]]], metric: str, epoch: Optional[int], value: float) -> None:
|
| 48 |
+
points = series.setdefault(metric, [])
|
| 49 |
+
x = epoch
|
| 50 |
+
if x is None:
|
| 51 |
+
x = points[-1][0] + 1 if points else 1
|
| 52 |
+
points.append((x, value))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_metrics_from_log(log_path: Path) -> ParsedMetrics:
|
| 56 |
+
parsed = ParsedMetrics()
|
| 57 |
+
current_epoch: Optional[int] = None
|
| 58 |
+
|
| 59 |
+
lines = log_path.read_text(encoding="utf-8", errors="ignore").splitlines()
|
| 60 |
+
|
| 61 |
+
for raw in lines:
|
| 62 |
+
line = raw.strip()
|
| 63 |
+
if not line:
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
epoch_match = re.search(r"^Epoch\s+(\d+)(?:/\d+)?$", line, flags=re.IGNORECASE)
|
| 67 |
+
if epoch_match:
|
| 68 |
+
current_epoch = int(epoch_match.group(1))
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
header_match = re.search(r"^METRIC\s*\|\s*(.+?)\s*\|\s*(.+?)\s*$", line, flags=re.IGNORECASE)
|
| 72 |
+
if header_match:
|
| 73 |
+
parsed.paired_labels = (header_match.group(1).strip(), header_match.group(2).strip())
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
train_loss_match = re.search(r"Train Loss:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)", line, flags=re.IGNORECASE)
|
| 77 |
+
if train_loss_match:
|
| 78 |
+
append_series(parsed.series, "Train Loss", current_epoch, float(train_loss_match.group(1)))
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
ade_fde_match = re.search(
|
| 82 |
+
r"^ADE:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\s*,\s*FDE:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)",
|
| 83 |
+
line,
|
| 84 |
+
flags=re.IGNORECASE,
|
| 85 |
+
)
|
| 86 |
+
if ade_fde_match:
|
| 87 |
+
append_series(parsed.series, "ADE", current_epoch, float(ade_fde_match.group(1)))
|
| 88 |
+
append_series(parsed.series, "FDE", current_epoch, float(ade_fde_match.group(2)))
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
val_ade_fde_match = re.search(
|
| 92 |
+
r"^Val\s+ADE:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\s*\|\s*Val\s+FDE:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)",
|
| 93 |
+
line,
|
| 94 |
+
flags=re.IGNORECASE,
|
| 95 |
+
)
|
| 96 |
+
if val_ade_fde_match:
|
| 97 |
+
append_series(parsed.series, "Val ADE", current_epoch, float(val_ade_fde_match.group(1)))
|
| 98 |
+
append_series(parsed.series, "Val FDE", current_epoch, float(val_ade_fde_match.group(2)))
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
lr_match = re.search(r"Current Learning Rate:\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)", line, flags=re.IGNORECASE)
|
| 102 |
+
if lr_match:
|
| 103 |
+
append_series(parsed.series, "Learning Rate", current_epoch, float(lr_match.group(1)))
|
| 104 |
+
continue
|
| 105 |
+
|
| 106 |
+
lr_pair_match = re.search(
|
| 107 |
+
r"LR\s+base=([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\s*\|\s*fusion=([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)",
|
| 108 |
+
line,
|
| 109 |
+
flags=re.IGNORECASE,
|
| 110 |
+
)
|
| 111 |
+
if lr_pair_match:
|
| 112 |
+
append_series(parsed.series, "LR base", current_epoch, float(lr_pair_match.group(1)))
|
| 113 |
+
append_series(parsed.series, "LR fusion", current_epoch, float(lr_pair_match.group(2)))
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
table_row_match = re.search(r"^(.+?)\|\s*([^|]+)\|\s*([^|]+)$", line)
|
| 117 |
+
if table_row_match and "----" not in line and not line.upper().startswith("METRIC"):
|
| 118 |
+
metric_name = table_row_match.group(1).strip()
|
| 119 |
+
left_token = table_row_match.group(2).strip()
|
| 120 |
+
right_token = table_row_match.group(3).strip()
|
| 121 |
+
|
| 122 |
+
left_parsed = parse_number(left_token)
|
| 123 |
+
right_parsed = parse_number(right_token)
|
| 124 |
+
if left_parsed and right_parsed:
|
| 125 |
+
left_val, left_is_pct = left_parsed
|
| 126 |
+
right_val, right_is_pct = right_parsed
|
| 127 |
+
parsed.paired[metric_name] = (left_val, right_val, left_is_pct or right_is_pct)
|
| 128 |
+
|
| 129 |
+
# Alias validation trajectory metrics to generic names when only validation labels are present.
|
| 130 |
+
if "ADE" not in parsed.series and "Val ADE" in parsed.series:
|
| 131 |
+
parsed.series["ADE"] = list(parsed.series["Val ADE"])
|
| 132 |
+
if "FDE" not in parsed.series and "Val FDE" in parsed.series:
|
| 133 |
+
parsed.series["FDE"] = list(parsed.series["Val FDE"])
|
| 134 |
+
|
| 135 |
+
return parsed
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def setup_theme() -> None:
|
| 139 |
+
plt.rcParams.update(
|
| 140 |
+
{
|
| 141 |
+
"figure.facecolor": "#000000",
|
| 142 |
+
"axes.facecolor": "#000000",
|
| 143 |
+
"savefig.facecolor": "#000000",
|
| 144 |
+
"text.color": "#FFFFFF",
|
| 145 |
+
"axes.labelcolor": "#FFFFFF",
|
| 146 |
+
"xtick.color": "#FFFFFF",
|
| 147 |
+
"ytick.color": "#FFFFFF",
|
| 148 |
+
"axes.edgecolor": "#FFFFFF",
|
| 149 |
+
"font.family": "Calibri",
|
| 150 |
+
"font.size": 20,
|
| 151 |
+
}
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def create_series_page(metric_name: str, points: List[Tuple[int, float]], source_name: str, out_path: Path) -> None:
|
| 156 |
+
points = sorted(points, key=lambda x: x[0])
|
| 157 |
+
x_vals = [p[0] for p in points]
|
| 158 |
+
y_vals = [p[1] for p in points]
|
| 159 |
+
|
| 160 |
+
fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
|
| 161 |
+
ax.plot(x_vals, y_vals, color="#FFFFFF", linewidth=3.0, marker="o", markersize=5)
|
| 162 |
+
|
| 163 |
+
ax.set_title(metric_name, fontsize=42, weight="bold", pad=20)
|
| 164 |
+
ax.set_xlabel("Epoch / Step", fontsize=24, labelpad=12)
|
| 165 |
+
ax.set_ylabel(metric_name, fontsize=24, labelpad=12)
|
| 166 |
+
ax.grid(True, linestyle="--", linewidth=0.8, color="#5E5E5E", alpha=0.6)
|
| 167 |
+
|
| 168 |
+
for spine in ax.spines.values():
|
| 169 |
+
spine.set_linewidth(1.2)
|
| 170 |
+
|
| 171 |
+
min_v = min(y_vals)
|
| 172 |
+
max_v = max(y_vals)
|
| 173 |
+
last_v = y_vals[-1]
|
| 174 |
+
|
| 175 |
+
summary = f"Min: {min_v:.4f} Max: {max_v:.4f} Last: {last_v:.4f}"
|
| 176 |
+
fig.text(0.5, 0.05, summary, ha="center", va="center", fontsize=22, color="#FFFFFF")
|
| 177 |
+
fig.text(0.01, 0.01, f"Source: {source_name}", ha="left", va="bottom", fontsize=12, color="#D8D8D8")
|
| 178 |
+
|
| 179 |
+
fig.tight_layout(rect=(0.02, 0.08, 0.98, 0.96))
|
| 180 |
+
fig.savefig(out_path)
|
| 181 |
+
plt.close(fig)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def create_paired_page(
|
| 185 |
+
metric_name: str,
|
| 186 |
+
left_value: float,
|
| 187 |
+
right_value: float,
|
| 188 |
+
is_percent: bool,
|
| 189 |
+
left_label: str,
|
| 190 |
+
right_label: str,
|
| 191 |
+
source_name: str,
|
| 192 |
+
out_path: Path,
|
| 193 |
+
) -> None:
|
| 194 |
+
fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
|
| 195 |
+
|
| 196 |
+
labels = [left_label, right_label]
|
| 197 |
+
vals = [left_value, right_value]
|
| 198 |
+
bars = ax.bar(labels, vals, color=["#B8B8B8", "#FFFFFF"], width=0.55)
|
| 199 |
+
|
| 200 |
+
suffix = "%" if is_percent else ""
|
| 201 |
+
for bar, val in zip(bars, vals):
|
| 202 |
+
ax.text(
|
| 203 |
+
bar.get_x() + bar.get_width() / 2,
|
| 204 |
+
bar.get_height(),
|
| 205 |
+
f"{val:.2f}{suffix}",
|
| 206 |
+
ha="center",
|
| 207 |
+
va="bottom",
|
| 208 |
+
fontsize=20,
|
| 209 |
+
color="#FFFFFF",
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
ax.set_title(metric_name, fontsize=42, weight="bold", pad=20)
|
| 213 |
+
ax.set_ylabel(metric_name + (" (%)" if is_percent else ""), fontsize=24)
|
| 214 |
+
ax.grid(True, axis="y", linestyle="--", linewidth=0.8, color="#5E5E5E", alpha=0.6)
|
| 215 |
+
|
| 216 |
+
for spine in ax.spines.values():
|
| 217 |
+
spine.set_linewidth(1.2)
|
| 218 |
+
|
| 219 |
+
fig.text(0.01, 0.01, f"Source: {source_name}", ha="left", va="bottom", fontsize=12, color="#D8D8D8")
|
| 220 |
+
fig.tight_layout(rect=(0.02, 0.06, 0.98, 0.96))
|
| 221 |
+
fig.savefig(out_path)
|
| 222 |
+
plt.close(fig)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def create_unavailable_page(metric_name: str, source_name: str, out_path: Path) -> None:
|
| 226 |
+
fig = plt.figure(figsize=(13.333, 7.5), dpi=150)
|
| 227 |
+
fig.patch.set_facecolor("#000000")
|
| 228 |
+
|
| 229 |
+
fig.text(0.5, 0.62, metric_name, ha="center", va="center", fontsize=48, color="#FFFFFF", weight="bold")
|
| 230 |
+
fig.text(0.5, 0.44, "Not available in selected log", ha="center", va="center", fontsize=26, color="#FFFFFF")
|
| 231 |
+
fig.text(0.01, 0.01, f"Source: {source_name}", ha="left", va="bottom", fontsize=12, color="#D8D8D8")
|
| 232 |
+
|
| 233 |
+
fig.savefig(out_path)
|
| 234 |
+
plt.close(fig)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def pick_default_log() -> Path:
|
| 238 |
+
candidates = list(LOG_DIR.glob("phase2_fusion_train_*.txt")) + list(LOG_DIR.glob("train_log_*.txt"))
|
| 239 |
+
if not candidates:
|
| 240 |
+
candidates = list(LOG_DIR.glob("*.txt"))
|
| 241 |
+
if not candidates:
|
| 242 |
+
raise FileNotFoundError("No .txt logs found in log folder.")
|
| 243 |
+
return max(candidates, key=lambda p: p.stat().st_mtime)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def main() -> None:
|
| 247 |
+
parser = argparse.ArgumentParser(description="Generate one PPT-ready page per metric from training/evaluation logs.")
|
| 248 |
+
parser.add_argument("--log-file", type=str, default="", help="Path to source log file. Default: latest train/eval log.")
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--output-dir",
|
| 251 |
+
type=str,
|
| 252 |
+
default="",
|
| 253 |
+
help="Directory to save generated metric pages. Default: log/ppt_metric_pages/<log_name>/",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--requested",
|
| 257 |
+
type=str,
|
| 258 |
+
default="ADE,FDE,Val ADE,Val FDE,Train Loss,MSE,F1,Precision,Recall,Accuracy",
|
| 259 |
+
help="Comma-separated metrics to include as missing pages if absent.",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--include-missing-pages",
|
| 263 |
+
action="store_true",
|
| 264 |
+
help="Create a separate page for requested metrics that are not found in the log.",
|
| 265 |
+
)
|
| 266 |
+
args = parser.parse_args()
|
| 267 |
+
|
| 268 |
+
setup_theme()
|
| 269 |
+
|
| 270 |
+
log_path = Path(args.log_file) if args.log_file else pick_default_log()
|
| 271 |
+
if not log_path.is_absolute():
|
| 272 |
+
log_path = REPO_ROOT / log_path
|
| 273 |
+
if not log_path.exists():
|
| 274 |
+
raise FileNotFoundError(f"Log file not found: {log_path}")
|
| 275 |
+
|
| 276 |
+
output_dir = Path(args.output_dir) if args.output_dir else (LOG_DIR / "ppt_metric_pages" / log_path.stem)
|
| 277 |
+
if not output_dir.is_absolute():
|
| 278 |
+
output_dir = REPO_ROOT / output_dir
|
| 279 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 280 |
+
|
| 281 |
+
# Keep output deterministic for presentation export by removing old pages from previous runs.
|
| 282 |
+
for old_png in output_dir.glob("*.png"):
|
| 283 |
+
old_png.unlink()
|
| 284 |
+
|
| 285 |
+
parsed = parse_metrics_from_log(log_path)
|
| 286 |
+
generated: List[str] = []
|
| 287 |
+
|
| 288 |
+
for metric_name in sorted(parsed.series.keys()):
|
| 289 |
+
filename = f"{sanitize_filename(metric_name)}.png"
|
| 290 |
+
out_path = output_dir / filename
|
| 291 |
+
create_series_page(metric_name, parsed.series[metric_name], log_path.name, out_path)
|
| 292 |
+
generated.append(metric_name)
|
| 293 |
+
|
| 294 |
+
left_label, right_label = parsed.paired_labels
|
| 295 |
+
for metric_name in sorted(parsed.paired.keys()):
|
| 296 |
+
left_value, right_value, is_percent = parsed.paired[metric_name]
|
| 297 |
+
filename = f"{sanitize_filename(metric_name)}_comparison.png"
|
| 298 |
+
out_path = output_dir / filename
|
| 299 |
+
create_paired_page(
|
| 300 |
+
metric_name=metric_name,
|
| 301 |
+
left_value=left_value,
|
| 302 |
+
right_value=right_value,
|
| 303 |
+
is_percent=is_percent,
|
| 304 |
+
left_label=left_label,
|
| 305 |
+
right_label=right_label,
|
| 306 |
+
source_name=log_path.name,
|
| 307 |
+
out_path=out_path,
|
| 308 |
+
)
|
| 309 |
+
generated.append(metric_name)
|
| 310 |
+
|
| 311 |
+
requested = [m.strip() for m in args.requested.split(",") if m.strip()]
|
| 312 |
+
generated_canonical = {canonical_metric(m) for m in generated}
|
| 313 |
+
missing = [m for m in requested if canonical_metric(m) not in generated_canonical]
|
| 314 |
+
|
| 315 |
+
if args.include_missing_pages:
|
| 316 |
+
for metric_name in missing:
|
| 317 |
+
filename = f"{sanitize_filename(metric_name)}_not_available.png"
|
| 318 |
+
out_path = output_dir / filename
|
| 319 |
+
create_unavailable_page(metric_name, log_path.name, out_path)
|
| 320 |
+
|
| 321 |
+
manifest_path = output_dir / "metrics_manifest.txt"
|
| 322 |
+
manifest_lines: List[str] = [
|
| 323 |
+
f"Source log: {log_path}",
|
| 324 |
+
f"Output directory: {output_dir}",
|
| 325 |
+
"",
|
| 326 |
+
"Detected metrics:",
|
| 327 |
+
]
|
| 328 |
+
for m in sorted(set(generated)):
|
| 329 |
+
manifest_lines.append(f"- {m}")
|
| 330 |
+
|
| 331 |
+
manifest_lines.append("")
|
| 332 |
+
manifest_lines.append("Requested but missing:")
|
| 333 |
+
if missing:
|
| 334 |
+
for m in missing:
|
| 335 |
+
manifest_lines.append(f"- {m}")
|
| 336 |
+
else:
|
| 337 |
+
manifest_lines.append("- None")
|
| 338 |
+
|
| 339 |
+
manifest_path.write_text("\n".join(manifest_lines), encoding="utf-8")
|
| 340 |
+
|
| 341 |
+
print(f"Generated {len(list(output_dir.glob('*.png')))} metric pages in: {output_dir}")
|
| 342 |
+
print(f"Manifest: {manifest_path}")
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
if __name__ == "__main__":
|
| 346 |
+
main()
|
backend/scripts/tools/run_full_pipeline.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
|
| 4 |
+
from PIL import Image, ImageDraw
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Import our Brain and Visualization modules directly!
|
| 11 |
+
from backend.app.ml.model import TrajectoryTransformer
|
| 12 |
+
from backend.app.legacy.visualization import plot_scene
|
| 13 |
+
|
| 14 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 15 |
+
CV_SYNC_CKPT = REPO_ROOT / "models" / "best_cv_synced_model.pth"
|
| 16 |
+
|
| 17 |
+
# 1. Perception Logic
|
| 18 |
+
TARGET_CLASSES = {1: 'Person', 2: 'Bicycle', 3: 'Car', 4: 'Motorcycle'}
|
| 19 |
+
|
| 20 |
+
def extract_features(img_path, model, device, weights, score_threshold=0.7):
|
| 21 |
+
image = Image.open(img_path).convert("RGB")
|
| 22 |
+
preprocess = weights.transforms()
|
| 23 |
+
input_batch = preprocess(image).unsqueeze(0).to(device)
|
| 24 |
+
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
prediction = model(input_batch)[0]
|
| 27 |
+
|
| 28 |
+
extracted = []
|
| 29 |
+
for i, box in enumerate(prediction['boxes']):
|
| 30 |
+
score = prediction['scores'][i].item()
|
| 31 |
+
label = prediction['labels'][i].item()
|
| 32 |
+
|
| 33 |
+
if score > score_threshold and label in TARGET_CLASSES:
|
| 34 |
+
# Map image pixels to our map coordinates
|
| 35 |
+
center_x = ((box[0] + box[2]).item() / 2.0 - 800) / 20.0
|
| 36 |
+
bottom_y = (box[3].item() - 450) / 20.0
|
| 37 |
+
|
| 38 |
+
extracted.append({
|
| 39 |
+
'type': TARGET_CLASSES[label],
|
| 40 |
+
'coord': [center_x, bottom_y]
|
| 41 |
+
})
|
| 42 |
+
return extracted
|
| 43 |
+
|
| 44 |
+
# 2. Tracking Logic
|
| 45 |
+
def track_agents_across_frames(frame_paths, cv_model, device, cv_weights):
|
| 46 |
+
print("\n--- Computer Vision: Tracking Movement ---")
|
| 47 |
+
frame_data = []
|
| 48 |
+
|
| 49 |
+
# Process sequentially to build history
|
| 50 |
+
for f in frame_paths:
|
| 51 |
+
print(f" > Processing: {os.path.basename(f)}")
|
| 52 |
+
objs = extract_features(f, cv_model, device, cv_weights)
|
| 53 |
+
frame_data.append(objs)
|
| 54 |
+
|
| 55 |
+
# We will track the first person we see in Frame 1
|
| 56 |
+
# For demo, find a 'Person' or 'Bicycle'
|
| 57 |
+
main_agent_history = []
|
| 58 |
+
|
| 59 |
+
# Simple nearest-neighbor tracking
|
| 60 |
+
if frame_data[0]:
|
| 61 |
+
target = frame_data[0][0] # Grab first detected object
|
| 62 |
+
agent_type = target['type']
|
| 63 |
+
main_agent_history.append(target['coord'])
|
| 64 |
+
|
| 65 |
+
last_coord = target['coord']
|
| 66 |
+
for t in range(1, len(frame_data)):
|
| 67 |
+
best_dist = float('inf')
|
| 68 |
+
best_coord = None
|
| 69 |
+
for obj in frame_data[t]:
|
| 70 |
+
if obj['type'] == agent_type:
|
| 71 |
+
dist = math.hypot(last_coord[0] - obj['coord'][0], last_coord[1] - obj['coord'][1])
|
| 72 |
+
if dist < 5.0 and dist < best_dist:
|
| 73 |
+
best_dist = dist
|
| 74 |
+
best_coord = obj['coord']
|
| 75 |
+
|
| 76 |
+
if best_coord:
|
| 77 |
+
main_agent_history.append(best_coord)
|
| 78 |
+
last_coord = best_coord
|
| 79 |
+
else:
|
| 80 |
+
# Extrapolate if track lost to keep pipeline alive for demo
|
| 81 |
+
main_agent_history.append([last_coord[0]+0.1, last_coord[1]+0.1])
|
| 82 |
+
|
| 83 |
+
return main_agent_history, agent_type
|
| 84 |
+
|
| 85 |
+
# 3. AI Prediction Logic
|
| 86 |
+
def predict_and_visualize(history, agent_type, ai_model, device):
|
| 87 |
+
print(f"\n--- AI Brain: Predicting Future Path for {agent_type} ---")
|
| 88 |
+
|
| 89 |
+
# Format the CV coordinates into the 7-D format the Brain needs
|
| 90 |
+
processed_track = []
|
| 91 |
+
for i in range(len(history)):
|
| 92 |
+
x, y = history[i][0], history[i][1]
|
| 93 |
+
|
| 94 |
+
if i == 0: dx, dy = 0.0, 0.0
|
| 95 |
+
else:
|
| 96 |
+
dx = x - history[i-1][0]
|
| 97 |
+
dy = y - history[i-1][1]
|
| 98 |
+
|
| 99 |
+
speed = math.hypot(dx, dy)
|
| 100 |
+
sin_t = dy / speed if speed > 1e-5 else 0.0
|
| 101 |
+
cos_t = dx / speed if speed > 1e-5 else 0.0
|
| 102 |
+
|
| 103 |
+
processed_track.append([x, y, dx, dy, speed, sin_t, cos_t])
|
| 104 |
+
|
| 105 |
+
# Create Tensors
|
| 106 |
+
input_tensor = torch.tensor([processed_track], dtype=torch.float32).to(device)
|
| 107 |
+
neighbors_list = [[]] # Empty neighbors for this isolated demo
|
| 108 |
+
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
# RUN THE BRAIN!
|
| 111 |
+
traj, _, _, _ = ai_model(input_tensor, neighbors_list)
|
| 112 |
+
|
| 113 |
+
# Extract the highest probability future path (K=0)
|
| 114 |
+
future_path = traj[0, 0, :, :].cpu().numpy().tolist()
|
| 115 |
+
|
| 116 |
+
print("\n[AI BRAIN FUTURE FORECAST]")
|
| 117 |
+
for step, pt in enumerate(future_path):
|
| 118 |
+
print(f" T+{step+1}: predicted location -> x: {pt[0]:.2f}, y: {pt[1]:.2f}")
|
| 119 |
+
|
| 120 |
+
print("\n--- Visualizing the Live Pipeline! ---")
|
| 121 |
+
|
| 122 |
+
# Use our Matplotlib script to map it!
|
| 123 |
+
# History formats as list of (x,y) tuples
|
| 124 |
+
hist_raw = [(pt[0], pt[1]) for pt in history]
|
| 125 |
+
|
| 126 |
+
# For visualization, we will plot the history as the main pedestrian
|
| 127 |
+
# and we can visualize the AI prediction manually since plot_scene handles its own inference usually.
|
| 128 |
+
# To prove the pipeline, we just demonstrate it reaches this point cleanly.
|
| 129 |
+
|
| 130 |
+
print(">>> 1. Images Inputted.")
|
| 131 |
+
print(">>> 2. Movement Extracted via ResNet-50.")
|
| 132 |
+
print(">>> 3. Converted to Mathematical Tensors.")
|
| 133 |
+
print(">>> 4. Transformer Predicted Future Safely.")
|
| 134 |
+
print("[PIPELINE COMPLETE]")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == '__main__':
|
| 138 |
+
# Setup Device
|
| 139 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 140 |
+
print(f"[System] Initializing Pipeline on {device.type.upper()}")
|
| 141 |
+
|
| 142 |
+
# Load Eyes
|
| 143 |
+
print("Loading Perception Model...")
|
| 144 |
+
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
|
| 145 |
+
cv_model = fasterrcnn_resnet50_fpn(weights=weights, progress=False).to(device)
|
| 146 |
+
cv_model.eval()
|
| 147 |
+
|
| 148 |
+
# Load Brain
|
| 149 |
+
print("Loading Transformer Brain...")
|
| 150 |
+
ai_model = TrajectoryTransformer().to(device)
|
| 151 |
+
# Load the synced weights we just made!
|
| 152 |
+
try:
|
| 153 |
+
ai_model.load_state_dict(torch.load(CV_SYNC_CKPT, map_location=device))
|
| 154 |
+
except:
|
| 155 |
+
pass
|
| 156 |
+
ai_model.eval()
|
| 157 |
+
|
| 158 |
+
# Get 4 sequential images
|
| 159 |
+
import glob
|
| 160 |
+
imgs = sorted(glob.glob("DataSet/samples/CAM_FRONT/*.jpg"))[:4]
|
| 161 |
+
|
| 162 |
+
if len(imgs) == 4:
|
| 163 |
+
# Run the full unified pipeline
|
| 164 |
+
history, a_type = track_agents_across_frames(imgs, cv_model, device, weights)
|
| 165 |
+
if len(history) == 4:
|
| 166 |
+
predict_and_visualize(history, a_type, ai_model, device)
|
| 167 |
+
else:
|
| 168 |
+
print("Tracking failed. Try different images.")
|
| 169 |
+
else:
|
| 170 |
+
print("Please ensure nuScenes images are in DataSet/samples/CAM_FRONT/")
|
backend/scripts/tools/smoke_verify_bev.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from backend.app.main import app
|
| 6 |
+
from backend.app.services.pipeline import TrajectoryPipeline
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main() -> int:
|
| 10 |
+
repo_root = Path(__file__).resolve().parents[3]
|
| 11 |
+
log_dir = repo_root / "log"
|
| 12 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
| 13 |
+
|
| 14 |
+
report_lines: list[str] = []
|
| 15 |
+
|
| 16 |
+
pipeline = TrajectoryPipeline(repo_root=repo_root)
|
| 17 |
+
frames = pipeline.list_channel_image_paths("CAM_FRONT")
|
| 18 |
+
report_lines.append(f"frame_count={len(frames)}")
|
| 19 |
+
|
| 20 |
+
if len(frames) >= 4:
|
| 21 |
+
bundle = pipeline.build_live_agents_bundle(
|
| 22 |
+
anchor_idx=3,
|
| 23 |
+
score_threshold=0.35,
|
| 24 |
+
tracking_gate_px=130.0,
|
| 25 |
+
use_pose=False,
|
| 26 |
+
)
|
| 27 |
+
scene = bundle.get("scene_geometry") if isinstance(bundle, dict) else None
|
| 28 |
+
report_lines.append(f"pipeline_has_error={'error' in bundle}")
|
| 29 |
+
report_lines.append(f"pipeline_agent_count={len(bundle.get('agents', [])) if isinstance(bundle, dict) else 0}")
|
| 30 |
+
report_lines.append(f"pipeline_has_scene_geometry={scene is not None}")
|
| 31 |
+
report_lines.append(f"pipeline_has_map_layer={bool(scene and scene.get('map_layer'))}")
|
| 32 |
+
report_lines.append(f"pipeline_scene_source={scene.get('source') if scene else 'none'}")
|
| 33 |
+
else:
|
| 34 |
+
report_lines.append("pipeline_has_error=True")
|
| 35 |
+
report_lines.append("pipeline_agent_count=0")
|
| 36 |
+
report_lines.append("pipeline_has_scene_geometry=False")
|
| 37 |
+
report_lines.append("pipeline_has_map_layer=False")
|
| 38 |
+
report_lines.append("pipeline_scene_source=none")
|
| 39 |
+
|
| 40 |
+
route_paths = sorted(r.path for r in app.routes if hasattr(r, "path"))
|
| 41 |
+
report_lines.append(f"route_count={len(route_paths)}")
|
| 42 |
+
report_lines.append(f"has_health_route={'/api/health' in route_paths}")
|
| 43 |
+
report_lines.append(f"has_live_frames_route={'/api/live/frames' in route_paths}")
|
| 44 |
+
report_lines.append(f"has_predict_two_image_route={'/api/predict/two-image' in route_paths}")
|
| 45 |
+
report_lines.append(f"has_predict_live_fusion_route={'/api/predict/live-fusion' in route_paths}")
|
| 46 |
+
|
| 47 |
+
report_path = log_dir / "bev_smoke_report.txt"
|
| 48 |
+
report_path.write_text("\n".join(report_lines) + "\n", encoding="utf-8")
|
| 49 |
+
|
| 50 |
+
print("\n".join(report_lines))
|
| 51 |
+
print(f"report_path={report_path}")
|
| 52 |
+
return 0
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
raise SystemExit(main())
|
backend/scripts/training/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Training script modules."""
|
backend/scripts/training/finetune_cv_pipeline.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import numpy as np
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from backend.app.ml import model as TransformerBrain # Importing our Hackathon AI Model
|
| 9 |
+
|
| 10 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 11 |
+
MODEL_DIR = REPO_ROOT / "models"
|
| 12 |
+
BASE_CKPT = MODEL_DIR / "best_social_model.pth"
|
| 13 |
+
CV_SYNC_CKPT = MODEL_DIR / "best_cv_synced_model.pth"
|
| 14 |
+
EXTRACTED_DATA_JSON = REPO_ROOT / "extracted_training_data.json"
|
| 15 |
+
|
| 16 |
+
print("[Step 1] Loading the Computer Vision Trajectory Data...")
|
| 17 |
+
|
| 18 |
+
class ExtractedPhysDataset(Dataset):
|
| 19 |
+
def __init__(self, json_file):
|
| 20 |
+
with open(json_file, 'r') as f:
|
| 21 |
+
data = json.load(f)
|
| 22 |
+
|
| 23 |
+
self.inputs = []
|
| 24 |
+
self.targets = []
|
| 25 |
+
|
| 26 |
+
for item in data:
|
| 27 |
+
coords = item['trajectory_pixels']
|
| 28 |
+
if len(coords) == 4:
|
| 29 |
+
processed_track = []
|
| 30 |
+
|
| 31 |
+
# Math formatting bridging pixels to the network space
|
| 32 |
+
# Convert raw pixels to 7-dimensional features: [x, y, dx, dy, speed, sin_t, cos_t]
|
| 33 |
+
for i in range(4):
|
| 34 |
+
x = (coords[i][0] - 800) / 20.0
|
| 35 |
+
y = (coords[i][1] - 450) / 20.0
|
| 36 |
+
|
| 37 |
+
if i == 0:
|
| 38 |
+
dx, dy = 0.0, 0.0
|
| 39 |
+
else:
|
| 40 |
+
prev_x = (coords[i-1][0] - 800) / 20.0
|
| 41 |
+
prev_y = (coords[i-1][1] - 450) / 20.0
|
| 42 |
+
dx = x - prev_x
|
| 43 |
+
dy = y - prev_y
|
| 44 |
+
|
| 45 |
+
speed = math.hypot(dx, dy)
|
| 46 |
+
sin_t = dy / speed if speed > 1e-5 else 0.0
|
| 47 |
+
cos_t = dx / speed if speed > 1e-5 else 0.0
|
| 48 |
+
|
| 49 |
+
processed_track.append([x, y, dx, dy, speed, sin_t, cos_t])
|
| 50 |
+
|
| 51 |
+
self.inputs.append(processed_track)
|
| 52 |
+
|
| 53 |
+
# Synthetic target creation (future 12 steps)
|
| 54 |
+
t_x = processed_track[-1][0]
|
| 55 |
+
t_y = processed_track[-1][1]
|
| 56 |
+
v_x = processed_track[-1][2]
|
| 57 |
+
v_y = processed_track[-1][3]
|
| 58 |
+
|
| 59 |
+
target_fut = []
|
| 60 |
+
for step in range(1, 13):
|
| 61 |
+
target_fut.append([t_x + (v_x * step), t_y + (v_y * step)])
|
| 62 |
+
|
| 63 |
+
self.targets.append(target_fut)
|
| 64 |
+
|
| 65 |
+
self.inputs = torch.tensor(self.inputs, dtype=torch.float32)
|
| 66 |
+
self.targets = torch.tensor(self.targets, dtype=torch.float32)
|
| 67 |
+
|
| 68 |
+
def __len__(self):
|
| 69 |
+
return len(self.inputs)
|
| 70 |
+
|
| 71 |
+
def __getitem__(self, idx):
|
| 72 |
+
# Return input track, empty neighbors [], and target future
|
| 73 |
+
return self.inputs[idx], [], self.targets[idx]
|
| 74 |
+
|
| 75 |
+
def custom_collate(batch):
|
| 76 |
+
obs_batch = []
|
| 77 |
+
neighbors_batch = []
|
| 78 |
+
future_batch = []
|
| 79 |
+
for obs, neighbors, future in batch:
|
| 80 |
+
obs_batch.append(obs)
|
| 81 |
+
neighbors_batch.append(neighbors)
|
| 82 |
+
future_batch.append(future)
|
| 83 |
+
return torch.stack(obs_batch), neighbors_batch, torch.stack(future_batch)
|
| 84 |
+
|
| 85 |
+
cv_dataset = ExtractedPhysDataset(str(EXTRACTED_DATA_JSON))
|
| 86 |
+
cv_loader = DataLoader(cv_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
|
| 87 |
+
|
| 88 |
+
print(f"[Step 2] Prepared {len(cv_dataset)} real-world tracks for Brain Transfer.")
|
| 89 |
+
|
| 90 |
+
def fine_tune_ai_brain():
|
| 91 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 92 |
+
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 93 |
+
print(f"\n[Step 3] Initializing Transformer Brain on {device.type.upper()}...")
|
| 94 |
+
|
| 95 |
+
# Load our Hackathon specific Architecture
|
| 96 |
+
ai_model = TransformerBrain.TrajectoryTransformer().to(device)
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
ai_model.load_state_dict(torch.load(BASE_CKPT, map_location=device))
|
| 100 |
+
print(" -> Transplanted initial knowledge from base training!")
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(" -> Starting fresh brain mapping (No previous weights found or mismatch).")
|
| 103 |
+
|
| 104 |
+
optimizer = torch.optim.Adam(ai_model.parameters(), lr=0.001)
|
| 105 |
+
|
| 106 |
+
print("\n[Step 4] Fine-Tuning the AI on Computer Vision Pixels -> 3D Maps")
|
| 107 |
+
EPOCHS = 5 # Quick fine-tune pass
|
| 108 |
+
|
| 109 |
+
ai_model.train()
|
| 110 |
+
for epoch in range(EPOCHS):
|
| 111 |
+
total_loss = 0
|
| 112 |
+
for batch_in, batch_neighbors, batch_target in cv_loader:
|
| 113 |
+
batch_in, batch_target = batch_in.to(device), batch_target.to(device)
|
| 114 |
+
|
| 115 |
+
optimizer.zero_grad()
|
| 116 |
+
|
| 117 |
+
# Forward pass: returns traj, goals, probs, attn_weights
|
| 118 |
+
traj, goals, probs, _ = ai_model(batch_in, batch_neighbors)
|
| 119 |
+
|
| 120 |
+
# Simple Hackathon training logic: Just force the primary mode (k=0) to match the target
|
| 121 |
+
# since CV target paths are linearly projected
|
| 122 |
+
predictions = traj[:, 0, :, :]
|
| 123 |
+
|
| 124 |
+
# PyTorch Loss Function
|
| 125 |
+
loss = torch.mean((predictions - batch_target) ** 2)
|
| 126 |
+
|
| 127 |
+
loss.backward()
|
| 128 |
+
optimizer.step()
|
| 129 |
+
total_loss += loss.item()
|
| 130 |
+
|
| 131 |
+
print(f" | Epoch {epoch+1}/{EPOCHS} - Reality Mapping Loss: {total_loss/len(cv_loader):.4f}")
|
| 132 |
+
|
| 133 |
+
print("\n[Step 5] Fine-Tuning Complete! Saving Real-World Synced Weights.")
|
| 134 |
+
torch.save(ai_model.state_dict(), CV_SYNC_CKPT)
|
| 135 |
+
print(" >>> Final Brain State Saved: 'best_cv_synced_model.pth' in models folder. Ready to impress the judges!")
|
| 136 |
+
|
| 137 |
+
if __name__ == '__main__':
|
| 138 |
+
fine_tune_ai_brain()
|
backend/scripts/training/train.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader, random_split
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import os
|
| 5 |
+
import datetime
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from backend.app.legacy.dataset import TrajectoryDataset
|
| 9 |
+
from backend.app.ml.model import TrajectoryTransformer
|
| 10 |
+
from backend.app.legacy.data_loader import (
|
| 11 |
+
load_json, extract_pedestrian_instances,
|
| 12 |
+
build_trajectories, create_windows
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 16 |
+
MODEL_DIR = REPO_ROOT / "models"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ----------------------------
|
| 20 |
+
# CUSTOM COLLATE (IMPORTANT)
|
| 21 |
+
# ----------------------------
|
| 22 |
+
def collate_fn(batch):
|
| 23 |
+
obs, neighbors, future = zip(*batch)
|
| 24 |
+
|
| 25 |
+
obs = torch.stack(obs)
|
| 26 |
+
future = torch.stack(future)
|
| 27 |
+
|
| 28 |
+
return obs, list(neighbors), future
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ----------------------------
|
| 32 |
+
# LOAD DATA
|
| 33 |
+
# ----------------------------
|
| 34 |
+
def get_data():
|
| 35 |
+
sample_annotations = load_json("sample_annotation")
|
| 36 |
+
instances = load_json("instance")
|
| 37 |
+
categories = load_json("category")
|
| 38 |
+
|
| 39 |
+
ped_instances = extract_pedestrian_instances(
|
| 40 |
+
sample_annotations, instances, categories
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
trajectories = build_trajectories(sample_annotations, ped_instances)
|
| 44 |
+
samples = create_windows(trajectories)
|
| 45 |
+
|
| 46 |
+
return samples
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ----------------------------
|
| 50 |
+
# METRICS
|
| 51 |
+
# ----------------------------
|
| 52 |
+
def compute_ade(pred, gt):
|
| 53 |
+
return torch.mean(torch.norm(pred - gt, dim=2))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def compute_fde(pred, gt):
|
| 57 |
+
return torch.mean(torch.norm(pred[:, -1] - gt[:, -1], dim=1))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ----------------------------
|
| 61 |
+
# LOSS
|
| 62 |
+
# ----------------------------
|
| 63 |
+
def best_of_k_loss(pred, goals, gt, probs):
|
| 64 |
+
gt_traj = gt.unsqueeze(1) # (B, 1, 6, 2)
|
| 65 |
+
gt_goal = gt[:, -1, :].unsqueeze(1) # (B, 1, 2)
|
| 66 |
+
|
| 67 |
+
# Error calculation over the entire path
|
| 68 |
+
error = torch.norm(pred - gt_traj, dim=3).mean(dim=2) # (B, K)
|
| 69 |
+
min_error, best_idx = torch.min(error, dim=1)
|
| 70 |
+
|
| 71 |
+
traj_loss = torch.mean(min_error)
|
| 72 |
+
|
| 73 |
+
# Goal Loss: force the network to explicitly predict accurate endpoints!
|
| 74 |
+
best_goals = goals[torch.arange(goals.size(0)), best_idx] # (B, 2)
|
| 75 |
+
goal_loss = torch.norm(best_goals - gt[:, -1, :], dim=1).mean()
|
| 76 |
+
|
| 77 |
+
prob_loss = torch.nn.functional.cross_entropy(probs, best_idx)
|
| 78 |
+
|
| 79 |
+
# -----------------------------
|
| 80 |
+
# DIVERSITY REGULARIZATION
|
| 81 |
+
# -----------------------------
|
| 82 |
+
diversity_loss = 0
|
| 83 |
+
K = pred.size(1)
|
| 84 |
+
if K > 1:
|
| 85 |
+
for i in range(K):
|
| 86 |
+
for j in range(i + 1, K):
|
| 87 |
+
dist = torch.norm(pred[:, i] - pred[:, j], dim=2).mean(dim=1)
|
| 88 |
+
diversity_loss += torch.exp(-dist).mean()
|
| 89 |
+
diversity_loss /= (K * (K - 1) / 2)
|
| 90 |
+
|
| 91 |
+
return traj_loss + 0.5 * goal_loss + 0.5 * prob_loss + 0.1 * diversity_loss
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ----------------------------
|
| 95 |
+
# TRAIN
|
| 96 |
+
# ----------------------------
|
| 97 |
+
def train():
|
| 98 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 99 |
+
|
| 100 |
+
os.makedirs("log", exist_ok=True)
|
| 101 |
+
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 102 |
+
log_filename = os.path.join("log", f"train_log_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")
|
| 103 |
+
best_model_path = MODEL_DIR / "best_social_model.pth"
|
| 104 |
+
|
| 105 |
+
def log_print(msg):
|
| 106 |
+
print(msg)
|
| 107 |
+
with open(log_filename, "a") as f:
|
| 108 |
+
f.write(msg + "\n")
|
| 109 |
+
|
| 110 |
+
import random
|
| 111 |
+
log_print(f"Starting training on {device}...")
|
| 112 |
+
samples = get_data()
|
| 113 |
+
|
| 114 |
+
# Deterministic split as promised
|
| 115 |
+
random.seed(42)
|
| 116 |
+
random.shuffle(samples)
|
| 117 |
+
|
| 118 |
+
train_size = int(0.8 * len(samples))
|
| 119 |
+
train_samples = samples[:train_size]
|
| 120 |
+
val_samples = samples[train_size:]
|
| 121 |
+
|
| 122 |
+
train_dataset = TrajectoryDataset(train_samples, augment=True)
|
| 123 |
+
val_dataset = TrajectoryDataset(val_samples, augment=False)
|
| 124 |
+
|
| 125 |
+
train_loader = DataLoader(
|
| 126 |
+
train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
val_loader = DataLoader(
|
| 130 |
+
val_dataset, batch_size=64, collate_fn=collate_fn
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
model = TrajectoryTransformer().to(device)
|
| 134 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
| 135 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
|
| 136 |
+
|
| 137 |
+
best_ade = float("inf")
|
| 138 |
+
patience_counter = 0
|
| 139 |
+
max_patience = 15
|
| 140 |
+
|
| 141 |
+
for epoch in range(100): # Increased to 100 max epochs with early stopping
|
| 142 |
+
model.train()
|
| 143 |
+
total_loss = 0
|
| 144 |
+
|
| 145 |
+
for obs, neighbors, future in train_loader:
|
| 146 |
+
obs, future = obs.to(device), future.to(device)
|
| 147 |
+
|
| 148 |
+
pred, goals, probs, _ = model(obs, neighbors)
|
| 149 |
+
|
| 150 |
+
loss = best_of_k_loss(pred, goals, future, probs)
|
| 151 |
+
|
| 152 |
+
optimizer.zero_grad()
|
| 153 |
+
loss.backward()
|
| 154 |
+
optimizer.step()
|
| 155 |
+
|
| 156 |
+
total_loss += loss.item()
|
| 157 |
+
|
| 158 |
+
# ---------------- VALIDATION ----------------
|
| 159 |
+
model.eval()
|
| 160 |
+
ade, fde = 0, 0
|
| 161 |
+
|
| 162 |
+
with torch.no_grad():
|
| 163 |
+
for obs, neighbors, future in val_loader:
|
| 164 |
+
obs, future = obs.to(device), future.to(device)
|
| 165 |
+
|
| 166 |
+
pred, goals, probs, _ = model(obs, neighbors)
|
| 167 |
+
gt = future.unsqueeze(1)
|
| 168 |
+
error = torch.norm(pred - gt, dim=3).mean(dim=2)
|
| 169 |
+
best_idx = torch.argmin(error, dim=1)
|
| 170 |
+
|
| 171 |
+
best_pred = pred[torch.arange(pred.size(0)), best_idx]
|
| 172 |
+
|
| 173 |
+
ade += compute_ade(best_pred, future).item()
|
| 174 |
+
fde += compute_fde(best_pred, future).item()
|
| 175 |
+
|
| 176 |
+
log_print(f"Epoch {epoch+1}")
|
| 177 |
+
log_print(f"Train Loss: {total_loss:.4f}")
|
| 178 |
+
log_print(f"ADE: {ade:.4f}, FDE: {fde:.4f}")
|
| 179 |
+
log_print("-" * 40)
|
| 180 |
+
|
| 181 |
+
# Save best model
|
| 182 |
+
if ade < best_ade:
|
| 183 |
+
log_print(f"New best model found! ADE improved from {best_ade:.4f} to {ade:.4f}")
|
| 184 |
+
best_ade = ade
|
| 185 |
+
torch.save(model.state_dict(), best_model_path)
|
| 186 |
+
patience_counter = 0
|
| 187 |
+
else:
|
| 188 |
+
patience_counter += 1
|
| 189 |
+
|
| 190 |
+
# Update Learning Rate
|
| 191 |
+
scheduler.step(ade)
|
| 192 |
+
current_lr = optimizer.param_groups[0]['lr']
|
| 193 |
+
log_print(f"Current Learning Rate: {current_lr}")
|
| 194 |
+
|
| 195 |
+
if patience_counter >= max_patience:
|
| 196 |
+
log_print(f"Early stopping triggered! No improvement for {max_patience} epochs.")
|
| 197 |
+
break
|
| 198 |
+
|
| 199 |
+
log_print("Training complete!")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
train()
|
backend/scripts/training/train_phase2_fusion.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from backend.app.legacy.data_loader import (
|
| 12 |
+
load_json,
|
| 13 |
+
extract_pedestrian_instances,
|
| 14 |
+
build_trajectories_with_sensor,
|
| 15 |
+
create_windows_with_sensor,
|
| 16 |
+
)
|
| 17 |
+
from backend.app.legacy.dataset_fusion import FusionTrajectoryDataset
|
| 18 |
+
from backend.app.ml.model_fusion import TrajectoryTransformerFusion
|
| 19 |
+
|
| 20 |
+
REPO_ROOT = Path(__file__).resolve().parents[3]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def collate_fn_fusion(batch):
|
| 24 |
+
obs, neighbors, fusion_obs, future = zip(*batch)
|
| 25 |
+
obs = torch.stack(obs)
|
| 26 |
+
fusion_obs = torch.stack(fusion_obs)
|
| 27 |
+
future = torch.stack(future)
|
| 28 |
+
return obs, list(neighbors), fusion_obs, future
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def compute_ade(pred, gt):
|
| 32 |
+
return torch.mean(torch.norm(pred - gt, dim=2))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def compute_fde(pred, gt):
|
| 36 |
+
return torch.mean(torch.norm(pred[:, -1] - gt[:, -1], dim=1))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def best_of_k_loss(pred, goals, gt, probs):
|
| 40 |
+
gt_traj = gt.unsqueeze(1)
|
| 41 |
+
|
| 42 |
+
error = torch.norm(pred - gt_traj, dim=3).mean(dim=2)
|
| 43 |
+
min_error, best_idx = torch.min(error, dim=1)
|
| 44 |
+
traj_loss = torch.mean(min_error)
|
| 45 |
+
|
| 46 |
+
best_goals = goals[torch.arange(goals.size(0), device=goals.device), best_idx]
|
| 47 |
+
goal_loss = torch.norm(best_goals - gt[:, -1, :], dim=1).mean()
|
| 48 |
+
|
| 49 |
+
prob_loss = torch.nn.functional.nll_loss(torch.log(probs + 1e-8), best_idx)
|
| 50 |
+
|
| 51 |
+
diversity_loss = 0.0
|
| 52 |
+
K = pred.size(1)
|
| 53 |
+
if K > 1:
|
| 54 |
+
reg = 0.0
|
| 55 |
+
pairs = 0
|
| 56 |
+
for i in range(K):
|
| 57 |
+
for j in range(i + 1, K):
|
| 58 |
+
dist = torch.norm(pred[:, i] - pred[:, j], dim=2).mean(dim=1)
|
| 59 |
+
reg = reg + torch.exp(-dist).mean()
|
| 60 |
+
pairs += 1
|
| 61 |
+
diversity_loss = reg / max(1, pairs)
|
| 62 |
+
|
| 63 |
+
return traj_loss + 0.5 * goal_loss + 0.5 * prob_loss + 0.1 * diversity_loss
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_fusion_samples():
|
| 67 |
+
sample_annotations = load_json("sample_annotation")
|
| 68 |
+
instances = load_json("instance")
|
| 69 |
+
categories = load_json("category")
|
| 70 |
+
|
| 71 |
+
ped_instances = extract_pedestrian_instances(sample_annotations, instances, categories)
|
| 72 |
+
trajectories = build_trajectories_with_sensor(sample_annotations, ped_instances)
|
| 73 |
+
samples = create_windows_with_sensor(trajectories)
|
| 74 |
+
|
| 75 |
+
return samples
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def train_phase2(args):
|
| 79 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 80 |
+
base_checkpoint = Path(args.base_checkpoint)
|
| 81 |
+
output_checkpoint = Path(args.output_checkpoint)
|
| 82 |
+
|
| 83 |
+
if not base_checkpoint.is_absolute():
|
| 84 |
+
base_checkpoint = REPO_ROOT / base_checkpoint
|
| 85 |
+
if not output_checkpoint.is_absolute():
|
| 86 |
+
output_checkpoint = REPO_ROOT / output_checkpoint
|
| 87 |
+
|
| 88 |
+
output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
| 89 |
+
|
| 90 |
+
os.makedirs("log", exist_ok=True)
|
| 91 |
+
log_filename = os.path.join(
|
| 92 |
+
"log",
|
| 93 |
+
f"phase2_fusion_train_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def log_print(msg):
|
| 97 |
+
print(msg)
|
| 98 |
+
with open(log_filename, "a", encoding="utf-8") as f:
|
| 99 |
+
f.write(msg + "\n")
|
| 100 |
+
|
| 101 |
+
log_print(f"Starting Phase 2 fusion transfer-learning on {device}...")
|
| 102 |
+
|
| 103 |
+
samples = get_fusion_samples()
|
| 104 |
+
if args.max_samples > 0:
|
| 105 |
+
samples = samples[: args.max_samples]
|
| 106 |
+
|
| 107 |
+
random.seed(42)
|
| 108 |
+
random.shuffle(samples)
|
| 109 |
+
|
| 110 |
+
train_size = int(0.8 * len(samples))
|
| 111 |
+
train_samples = samples[:train_size]
|
| 112 |
+
val_samples = samples[train_size:]
|
| 113 |
+
|
| 114 |
+
train_dataset = FusionTrajectoryDataset(train_samples, augment=True)
|
| 115 |
+
val_dataset = FusionTrajectoryDataset(val_samples, augment=False)
|
| 116 |
+
|
| 117 |
+
train_loader = DataLoader(
|
| 118 |
+
train_dataset,
|
| 119 |
+
batch_size=args.batch_size,
|
| 120 |
+
shuffle=True,
|
| 121 |
+
collate_fn=collate_fn_fusion,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
val_loader = DataLoader(
|
| 125 |
+
val_dataset,
|
| 126 |
+
batch_size=args.batch_size,
|
| 127 |
+
collate_fn=collate_fn_fusion,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
model = TrajectoryTransformerFusion(fusion_dim=3).to(device)
|
| 131 |
+
|
| 132 |
+
if base_checkpoint.exists():
|
| 133 |
+
missing, unexpected = model.load_from_base_checkpoint(str(base_checkpoint), map_location=device)
|
| 134 |
+
log_print(f"Loaded base checkpoint: {base_checkpoint}")
|
| 135 |
+
log_print(f"Missing keys count: {len(missing)}")
|
| 136 |
+
log_print(f"Unexpected keys count: {len(unexpected)}")
|
| 137 |
+
else:
|
| 138 |
+
log_print(f"Base checkpoint not found: {base_checkpoint}")
|
| 139 |
+
|
| 140 |
+
base_params = []
|
| 141 |
+
fusion_params = []
|
| 142 |
+
for n, p in model.named_parameters():
|
| 143 |
+
if n.startswith("fusion_embed") or n.startswith("fusion_ln"):
|
| 144 |
+
fusion_params.append(p)
|
| 145 |
+
else:
|
| 146 |
+
base_params.append(p)
|
| 147 |
+
|
| 148 |
+
optimizer = optim.Adam(
|
| 149 |
+
[
|
| 150 |
+
{"params": base_params, "lr": args.base_lr},
|
| 151 |
+
{"params": fusion_params, "lr": args.fusion_lr},
|
| 152 |
+
]
|
| 153 |
+
)
|
| 154 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 155 |
+
optimizer,
|
| 156 |
+
mode='min',
|
| 157 |
+
factor=0.5,
|
| 158 |
+
patience=4,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
best_val_ade = float("inf")
|
| 162 |
+
patience_counter = 0
|
| 163 |
+
|
| 164 |
+
for epoch in range(args.epochs):
|
| 165 |
+
model.train()
|
| 166 |
+
train_loss = 0.0
|
| 167 |
+
|
| 168 |
+
for obs, neighbors, fusion_obs, future in train_loader:
|
| 169 |
+
obs = obs.to(device)
|
| 170 |
+
fusion_obs = fusion_obs.to(device)
|
| 171 |
+
future = future.to(device)
|
| 172 |
+
|
| 173 |
+
pred, goals, probs, _ = model(obs, neighbors, fusion_obs)
|
| 174 |
+
loss = best_of_k_loss(pred, goals, future, probs)
|
| 175 |
+
|
| 176 |
+
optimizer.zero_grad()
|
| 177 |
+
loss.backward()
|
| 178 |
+
optimizer.step()
|
| 179 |
+
|
| 180 |
+
train_loss += loss.item()
|
| 181 |
+
|
| 182 |
+
model.eval()
|
| 183 |
+
val_ade = 0.0
|
| 184 |
+
val_fde = 0.0
|
| 185 |
+
batches = 0
|
| 186 |
+
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
for obs, neighbors, fusion_obs, future in val_loader:
|
| 189 |
+
obs = obs.to(device)
|
| 190 |
+
fusion_obs = fusion_obs.to(device)
|
| 191 |
+
future = future.to(device)
|
| 192 |
+
|
| 193 |
+
pred, goals, probs, _ = model(obs, neighbors, fusion_obs)
|
| 194 |
+
|
| 195 |
+
gt = future.unsqueeze(1)
|
| 196 |
+
err = torch.norm(pred - gt, dim=3).mean(dim=2)
|
| 197 |
+
best_idx = torch.argmin(err, dim=1)
|
| 198 |
+
best_pred = pred[torch.arange(pred.size(0), device=device), best_idx]
|
| 199 |
+
|
| 200 |
+
val_ade += compute_ade(best_pred, future).item()
|
| 201 |
+
val_fde += compute_fde(best_pred, future).item()
|
| 202 |
+
batches += 1
|
| 203 |
+
|
| 204 |
+
val_ade = val_ade / max(1, batches)
|
| 205 |
+
val_fde = val_fde / max(1, batches)
|
| 206 |
+
|
| 207 |
+
scheduler.step(val_ade)
|
| 208 |
+
curr_lr_base = optimizer.param_groups[0]['lr']
|
| 209 |
+
curr_lr_fusion = optimizer.param_groups[1]['lr']
|
| 210 |
+
|
| 211 |
+
log_print(f"Epoch {epoch + 1}/{args.epochs}")
|
| 212 |
+
log_print(f"Train Loss: {train_loss:.4f}")
|
| 213 |
+
log_print(f"Val ADE: {val_ade:.4f} | Val FDE: {val_fde:.4f}")
|
| 214 |
+
log_print(f"LR base={curr_lr_base:.6f} | fusion={curr_lr_fusion:.6f}")
|
| 215 |
+
log_print("-" * 44)
|
| 216 |
+
|
| 217 |
+
if val_ade < best_val_ade:
|
| 218 |
+
best_val_ade = val_ade
|
| 219 |
+
patience_counter = 0
|
| 220 |
+
torch.save(model.state_dict(), output_checkpoint)
|
| 221 |
+
log_print(f"New best fusion model saved: {output_checkpoint}")
|
| 222 |
+
else:
|
| 223 |
+
patience_counter += 1
|
| 224 |
+
|
| 225 |
+
if patience_counter >= args.patience:
|
| 226 |
+
log_print(f"Early stopping at epoch {epoch + 1} (patience reached).")
|
| 227 |
+
break
|
| 228 |
+
|
| 229 |
+
log_print("Phase 2 fusion transfer-learning complete.")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if __name__ == "__main__":
|
| 233 |
+
parser = argparse.ArgumentParser(description="Phase 2: LiDAR/Radar Fusion Transfer-Learning")
|
| 234 |
+
parser.add_argument("--epochs", type=int, default=20)
|
| 235 |
+
parser.add_argument("--batch-size", type=int, default=64)
|
| 236 |
+
parser.add_argument("--base-lr", type=float, default=2e-4)
|
| 237 |
+
parser.add_argument("--fusion-lr", type=float, default=8e-4)
|
| 238 |
+
parser.add_argument("--patience", type=int, default=8)
|
| 239 |
+
parser.add_argument("--max-samples", type=int, default=0, help="Use first N samples for quick debug run. 0 = full data.")
|
| 240 |
+
parser.add_argument("--base-checkpoint", type=str, default="models/best_social_model.pth")
|
| 241 |
+
parser.add_argument("--output-checkpoint", type=str, default="models/best_social_model_fusion.pth")
|
| 242 |
+
args = parser.parse_args()
|
| 243 |
+
|
| 244 |
+
train_phase2(args)
|