Spaces:
Running
Running
Commit ·
2e50ccd
0
Parent(s):
Add flight booking website (Google Flights clone)
Browse filesFull-stack flight search app for computer-use agent testing:
- Backend: FastAPI with deterministic pricing engine, route finder
(direct + 1-stop + 2-stop via hub detection), timezone-aware
flight generation, airport autocomplete, calendar pricing
- Frontend: React + TypeScript + Tailwind CSS with search form,
autocomplete, results page, filters, sorting, URL-driven state
- Docker: Multi-stage build (Node frontend → Python backend)
- Data: 3,770 airports, 55,627 routes, 604 airlines from
airline_routes.json loaded in-memory (~21 MB)
- All elements have data-testid attributes for agent testing
- Same search params always produce same results (SHA-256 seeded)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +10 -0
- .gitattributes +1 -0
- .gitignore +8 -0
- CLAUDE.md +518 -0
- Dockerfile +26 -0
- JOURNAL.md +0 -0
- airline_routes.json +3 -0
- backend/__init__.py +0 -0
- backend/api/__init__.py +0 -0
- backend/api/airports.py +47 -0
- backend/api/calendar.py +70 -0
- backend/api/search.py +144 -0
- backend/config.py +93 -0
- backend/data_loader.py +164 -0
- backend/flight_generator.py +270 -0
- backend/hub_detector.py +52 -0
- backend/main.py +59 -0
- backend/models.py +145 -0
- backend/price_engine.py +113 -0
- backend/requirements.txt +3 -0
- backend/route_finder.py +141 -0
- backend/seed_utils.py +20 -0
- description.md +1122 -0
- docker-compose.yml +6 -0
- frontend/eslint.config.js +23 -0
- frontend/index.html +13 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +33 -0
- frontend/public/vite.svg +1 -0
- frontend/src/App.css +42 -0
- frontend/src/App.tsx +16 -0
- frontend/src/api/client.ts +39 -0
- frontend/src/api/types.ts +90 -0
- frontend/src/assets/react.svg +1 -0
- frontend/src/components/results/FilterPanel.tsx +133 -0
- frontend/src/components/results/FlightCard.tsx +112 -0
- frontend/src/components/results/FlightSegment.tsx +42 -0
- frontend/src/components/results/NoResults.tsx +29 -0
- frontend/src/components/results/SortBar.tsx +40 -0
- frontend/src/components/search/AirportInput.tsx +107 -0
- frontend/src/components/search/ClassSelector.tsx +30 -0
- frontend/src/components/search/DatePicker.tsx +22 -0
- frontend/src/components/search/PassengerSelector.tsx +86 -0
- frontend/src/components/search/SearchForm.tsx +114 -0
- frontend/src/components/search/SwapButton.tsx +18 -0
- frontend/src/components/search/TripTypeSelector.tsx +33 -0
- frontend/src/components/shared/Header.tsx +22 -0
- frontend/src/components/shared/Loading.tsx +8 -0
- frontend/src/hooks/useDebounce.ts +12 -0
- frontend/src/hooks/useFlightSearch.ts +70 -0
.dockerignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
node_modules
|
| 2 |
+
frontend/node_modules
|
| 3 |
+
frontend/dist
|
| 4 |
+
.venv
|
| 5 |
+
.venv_*
|
| 6 |
+
.git
|
| 7 |
+
__pycache__
|
| 8 |
+
*.pyc
|
| 9 |
+
.uv_cache
|
| 10 |
+
.uv_pythons
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
airline_routes.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
node_modules/
|
| 2 |
+
frontend/dist/
|
| 3 |
+
.venv/
|
| 4 |
+
.venv_*/
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
.uv_cache/
|
| 8 |
+
.uv_pythons/
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md — Rules for AI Assistants (ECMoE Project)
|
| 2 |
+
|
| 3 |
+
## MANDATORY FIRST STEPS
|
| 4 |
+
|
| 5 |
+
**Before taking ANY action on a task, you MUST:**
|
| 6 |
+
|
| 7 |
+
1. Tell the user you have read CLAUDE.md and how you'll follow the THREE RULES
|
| 8 |
+
2. **Actually read these files** (not optional):
|
| 9 |
+
- **README.md** — Directory structure, setup, how to run experiments
|
| 10 |
+
- **JOURNAL.md** — Recent bugs, what's broken/fixed, latest results
|
| 11 |
+
- **description.md** — Detailed method descriptions, design choices, hyperparameters
|
| 12 |
+
|
| 13 |
+
**Do NOT skip this to "get to work faster."** Skipping causes you to use wrong directories, miss known issues, and waste time on already-solved problems.
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## THE THREE RULES
|
| 18 |
+
|
| 19 |
+
### 1. EDIT, NEVER REWRITE
|
| 20 |
+
- **ALWAYS edit existing code, NEVER rewrite from scratch**
|
| 21 |
+
- Find the exact file/function, make surgical changes with Edit tool
|
| 22 |
+
- If you're about to write 50+ lines of new code doing something similar to existing code, STOP
|
| 23 |
+
- Reuse existing classes: `Compressor`, `Decompressor`, `StaleDecompressor`, `train_compressor`, etc.
|
| 24 |
+
|
| 25 |
+
### 2. VALIDATE DATA BEFORE PLOTTING
|
| 26 |
+
- Always load results from JSON files, never hardcode values
|
| 27 |
+
- If a number looks different than expected, investigate before proceeding
|
| 28 |
+
- Check `results/summary/all_results_summary.json` for the canonical results
|
| 29 |
+
|
| 30 |
+
### 3. COMMIT AND DOCUMENT IMMEDIATELY
|
| 31 |
+
- `git commit` after every fix (no remote configured — push when available)
|
| 32 |
+
- Update `JOURNAL.md` right after committing
|
| 33 |
+
- Don't batch changes — commit as you go
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## MINDSET: NO SHORTCUTS
|
| 38 |
+
|
| 39 |
+
- Academic rigor means doing things RIGHT, not just doing things FAST
|
| 40 |
+
- Be skeptical of your own first approach — question whether it could be better
|
| 41 |
+
- Don't simplify the requirement — solve the actual problem
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## Communication
|
| 46 |
+
|
| 47 |
+
**When showing results or finishing tasks:**
|
| 48 |
+
- ALWAYS provide the **full absolute path** to any files created or modified
|
| 49 |
+
- Example: "View the result at: `/project/6004852/lfy/ECMoE/results/summary/ppl_vs_ratio_all.png`"
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## Project-Specific Rules
|
| 54 |
+
|
| 55 |
+
### Environment Setup (Compute Canada)
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
# Modules MUST be loaded BEFORE activating venv
|
| 59 |
+
module load cuda/12.6 arrow/22.0.0
|
| 60 |
+
source .venv/bin/activate
|
| 61 |
+
|
| 62 |
+
# HuggingFace cache goes to persistent project dir (home quota is small)
|
| 63 |
+
export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Directory Structure
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
src/ # Python source code
|
| 70 |
+
scripts/ # Bash wrappers for each experiment
|
| 71 |
+
results/ # ALL experiment outputs (gitignored)
|
| 72 |
+
01_distribution/ # Task 1: distribution analysis
|
| 73 |
+
02_quantization/ # Task 2: quantization baseline
|
| 74 |
+
03_neural_compressor/ # Task 3: shared neural compressor
|
| 75 |
+
03b_perlayer_compressor/ # Task 3b: per-layer neural compressor
|
| 76 |
+
04a_stale_compressed/ # Task 4a: stale-conditioned (compressed stale)
|
| 77 |
+
04b_stale_uncompressed/ # Task 4b: stale-conditioned (uncompressed stale)
|
| 78 |
+
05a_e2e_perlayer/ # Task 5a: e2e per-layer compressor (no stale)
|
| 79 |
+
05b_e2e_stale/ # Task 5b: e2e stale-conditioned compressor
|
| 80 |
+
05c_e2e_baseline/ # Task 5c: baseline (no compression, same pipeline)
|
| 81 |
+
05c_megatron_e2e_baseline/ # Task 5c: baseline (Megatron variant)
|
| 82 |
+
06a_megatron_e2e_pretrained_perlayer/ # Task 6a: e2e with 3b init (Megatron)
|
| 83 |
+
06b_megatron_e2e_pretrained_stale/ # Task 6b: e2e with 4b init (Megatron)
|
| 84 |
+
07a_megatron_e2e_split_perlayer/ # Task 7a: split-mode e2e (router=original)
|
| 85 |
+
07b_megatron_e2e_split_stale/ # Task 7b: split-mode e2e + stale
|
| 86 |
+
08_ep_compression/ # Task 8: EP compression eval (uses 7a/7b weights)
|
| 87 |
+
summary/ # Cross-method comparison plots and tables
|
| 88 |
+
data/hidden_states/ # Cached MoE hidden states (gitignored, ~37 GB in bfloat16)
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### Key Code Architecture
|
| 92 |
+
|
| 93 |
+
- **`src/model_utils.py`** — Central library: model loading, MoE detection, hidden state
|
| 94 |
+
collection, ALL perplexity evaluation functions (baseline, shared, per-layer, stale)
|
| 95 |
+
- **`src/metrics.py`** — Reconstruction metrics: MSE, cosine similarity, relative error, SNR
|
| 96 |
+
- **`src/run_neural_compressor.py`** — Defines `Compressor`, `Decompressor`, `train_compressor()`.
|
| 97 |
+
Other scripts import from here — never duplicate these classes
|
| 98 |
+
- **`src/run_stale_compressor.py`** — Defines `StaleDecompressor`, `train_stale_compressor()`
|
| 99 |
+
- **`src/run_e2e_compressor.py`** — End-to-end training of per-layer compressors via LM loss.
|
| 100 |
+
Defines `E2ECompressorManager`, `SFTDataset`. Uses Dolci-Instruct-SFT with SFT mode
|
| 101 |
+
(response-only training). `_tokenize_sft_sample()` in `model_utils.py` handles the
|
| 102 |
+
response-only label masking.
|
| 103 |
+
- **`src/vllm_ep_compression.py`** — EP-aware compress/decompress registration for vLLM.
|
| 104 |
+
Sets `_ecmoe_compress_fn` / `_ecmoe_decompress_fn` on FusedMoE instances via
|
| 105 |
+
`apply_model()`. Supports per-layer and stale-conditioned methods. Requires patched
|
| 106 |
+
vLLM (`.venv_vllm_exp`).
|
| 107 |
+
- **`src/run_ep_compression_eval.py`** — Task 8 entry point: evaluates EP compression
|
| 108 |
+
with actual dispatch/combine in vLLM. Two modes: `simulation` (single-GPU) and `ep`
|
| 109 |
+
(multi-GPU with `enable_expert_parallel=True`). Uses Task 7a/7b weights.
|
| 110 |
+
- **`src/visualize_all_results.py`** — Generates all cross-method comparison plots and tables
|
| 111 |
+
- **`src/downstream_eval.py`** — Shared utility for downstream task evaluation via lm-eval-harness.
|
| 112 |
+
Provides hook registration functions (`register_quantization_hooks`, `register_perlayer_hooks`,
|
| 113 |
+
`register_stale_hooks`, `register_e2e_hooks`), `run_lm_eval()` wrapper, and result saving.
|
| 114 |
+
Imported by each task script when `--downstream-tasks` is specified.
|
| 115 |
+
Also provides vLLM backend support via apply_model pattern: `create_vllm_backend()`,
|
| 116 |
+
`register_perlayer_hooks_vllm()`, `register_stale_hooks_vllm()`,
|
| 117 |
+
`register_quantization_hooks_vllm()`, `remove_hooks_vllm()`.
|
| 118 |
+
Split (router-uncompressed) mode: `register_perlayer_hooks_split()`,
|
| 119 |
+
`register_stale_hooks_split()` for HF, and `register_perlayer_hooks_split_vllm()`,
|
| 120 |
+
`register_stale_hooks_split_vllm()` for vLLM. In split mode, the router sees original
|
| 121 |
+
hidden states while experts see decompressed — more realistic EP simulation.
|
| 122 |
+
- **`src/run_all_downstream.py`** — Standalone downstream evaluator. Loads model once,
|
| 123 |
+
evaluates all methods sequentially. Supports `--backend hf/vllm` and
|
| 124 |
+
`--router-mode compressed/uncompressed`.
|
| 125 |
+
|
| 126 |
+
### Known Issues / Gotchas
|
| 127 |
+
|
| 128 |
+
**Layer sorting:** Always use `sorted(keys, key=layer_index)` from `model_utils`. Lexicographic
|
| 129 |
+
sorting puts layer 10 before layer 2 (`model.layers.10` < `model.layers.2`).
|
| 130 |
+
|
| 131 |
+
**Dtype mismatch:** Dequantized tensors and neural compressor outputs must match the model's
|
| 132 |
+
activation dtype (bfloat16). Always cast: `.to(x.dtype).to(x.device)`.
|
| 133 |
+
|
| 134 |
+
**What went wrong (2026-02-11):** `absmax_dequantize` returned float32 but model expected
|
| 135 |
+
bfloat16, causing `RuntimeError` during perplexity eval. Fix: explicit `.to(scale.dtype)` cast.
|
| 136 |
+
|
| 137 |
+
**What went wrong (2026-02-11):** When asked to remove quantization for Tasks 1–4, the agent
|
| 138 |
+
implemented the change (default `load_in_4bit=False`, `device="auto"`) without the user having
|
| 139 |
+
specified this as a hyperparameter. The model loading precision (BF16 vs 4-bit NF4) is a key
|
| 140 |
+
experimental parameter — changing it retroactively means old results are no longer reproducible
|
| 141 |
+
with default settings. **Lesson:** Treat model loading precision as a hyperparameter. Do NOT
|
| 142 |
+
change defaults that affect reproducibility without explicit user instruction. When the user says
|
| 143 |
+
"remove quantization", ASK whether they want it as a new default or as a CLI override.
|
| 144 |
+
|
| 145 |
+
**Response-only hidden state collection:** `collect_hidden_states()` defaults to
|
| 146 |
+
`response_only=True` — only assistant-response tokens are captured (labels != -100).
|
| 147 |
+
This ensures offline compressor training (Tasks 2–4) trains on the same distribution
|
| 148 |
+
that PPL evaluation measures. Use `--no-response-only` in `run_distribution.py` for
|
| 149 |
+
legacy all-token collection. Metadata records `"response_only": true/false`.
|
| 150 |
+
|
| 151 |
+
**Legacy Megatron script deleted:** `src/run_megatron_e2e_compressor.py` was removed because
|
| 152 |
+
it used `PackedTokenDataset` + `labels=input_ids` (standard LM, not SFT response-only),
|
| 153 |
+
did not use `get_split_indices()`, and misreported effective batch size with DP > 1.
|
| 154 |
+
Always use `src/megatron_e2e/train.py` for Megatron-based training.
|
| 155 |
+
|
| 156 |
+
**Large data files:** Hidden states for 100K tokens are ~18.5 GB per file in bfloat16
|
| 157 |
+
(dispatch + gather = ~37 GB). These are gitignored. Never try to `git add` them.
|
| 158 |
+
|
| 159 |
+
**Model VRAM:** Model is loaded in full BF16 (~60 GB). Tasks 1–4 use single GPU
|
| 160 |
+
(`device="cuda:0"`) — the model fits on one H100 80 GB with headroom for inference.
|
| 161 |
+
Task 5 uses multi-GPU (`device_map="auto"`) because backprop needs extra VRAM.
|
| 162 |
+
4-bit NF4 loading (~15 GB) is available via `--load-in-4bit` but is NOT the default.
|
| 163 |
+
|
| 164 |
+
**device="auto" vs tensor ops:** When `device="auto"` is used for model loading (Task 5),
|
| 165 |
+
`"auto"` is NOT a valid torch device for tensor operations. Scripts that do `.to(device)` or
|
| 166 |
+
`train_compressor(device=...)` must use `compute_device` (resolved to `"cuda:0"` when
|
| 167 |
+
`device="auto"`). Only `load_model_and_tokenizer()` accepts `"auto"` directly.
|
| 168 |
+
Tasks 1–4 default to `device="cuda:0"` so this is only relevant for Task 5.
|
| 169 |
+
|
| 170 |
+
**Hook device safety (2026-02-17):** With `device_map="auto"`, model layers may reside on
|
| 171 |
+
different GPUs. PPL evaluation hooks in `model_utils.py` now explicitly call `.to(x.device)`
|
| 172 |
+
on compressor/decompressor outputs before returning them to the model. This is a no-op when
|
| 173 |
+
compressor and layer are on the same device but prevents cross-device errors when they differ.
|
| 174 |
+
|
| 175 |
+
### vLLM Environment (downstream evaluation)
|
| 176 |
+
|
| 177 |
+
**vLLM backend:** `src/downstream_eval.py` + `src/run_all_downstream.py` — vLLM 0.8.4+
|
| 178 |
+
for downstream task evaluation with compression hooks.
|
| 179 |
+
|
| 180 |
+
```bash
|
| 181 |
+
# Separate venv from HF-based experiments — CUDA 12.6
|
| 182 |
+
module load cuda/12.6 arrow/22.0.0
|
| 183 |
+
source .venv_vllm/bin/activate
|
| 184 |
+
export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
|
| 185 |
+
|
| 186 |
+
# Setup (first time only):
|
| 187 |
+
bash scripts/vllm_setup_env.sh
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
**Known issues / gotchas (vLLM):**
|
| 191 |
+
- **vLLM V1 engine (>= 0.15):** The model runs in a **separate subprocess** (EngineCore).
|
| 192 |
+
You CANNOT access the model directly from the main process. The old path
|
| 193 |
+
`llm_engine.model_executor.driver_worker.model_runner.model` does NOT work.
|
| 194 |
+
Instead, use `vllm.LLM.apply_model(func)` to send functions to the worker process.
|
| 195 |
+
Functions are serialized via cloudpickle — they must be self-contained (include their
|
| 196 |
+
own imports and class definitions). Requires `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.
|
| 197 |
+
`create_vllm_backend()` sets this automatically.
|
| 198 |
+
- **enforce_eager=True required:** vLLM's CUDA graph capture prevents PyTorch hooks
|
| 199 |
+
from being called. Always use `enforce_eager=True` when registering compression hooks.
|
| 200 |
+
`create_vllm_backend()` sets this automatically.
|
| 201 |
+
- **Hook registration pattern:** All vLLM hook functions use the apply_model pattern:
|
| 202 |
+
`_vllm_register_perlayer()` returns a closure → `vllm_llm.apply_model(closure)`.
|
| 203 |
+
The closure runs inside the worker, loads weights, creates compressor modules,
|
| 204 |
+
and registers PyTorch pre-hooks. Cleanup via `_vllm_remove_hooks()` → `remove_hooks_vllm()`.
|
| 205 |
+
- **Layer name mapping:** vLLM may use different module paths than HF. `_map_layer_name()`
|
| 206 |
+
maps by numeric layer index, which is robust to naming differences.
|
| 207 |
+
- **Two router modes (--router-mode):**
|
| 208 |
+
- `compressed` (default): Pre-hook compress→decompress. Router AND experts see
|
| 209 |
+
decompressed. Conservative lower bound — same as the original PPL evaluation hooks.
|
| 210 |
+
- `uncompressed`: Split forward — router sees ORIGINAL input, experts see decompressed.
|
| 211 |
+
More realistic EP simulation where router runs on source GPU with original data.
|
| 212 |
+
Both modes work for HF and vLLM backends.
|
| 213 |
+
- **No multi-device placement:** The plan called for `compressor_device` (attention GPU)
|
| 214 |
+
vs `decompressor_devices` (expert GPUs) to simulate the actual communication topology.
|
| 215 |
+
Current implementation puts both compressor and decompressor on the same device. This
|
| 216 |
+
doesn't affect quality measurement (the math is device-independent) but doesn't
|
| 217 |
+
demonstrate the real communication pattern or measure cross-device overhead.
|
| 218 |
+
- **No shared expert handling:** Split mode omits `shared_expert` /
|
| 219 |
+
`shared_expert_gate` logic. Qwen3-30B-A3B doesn't use shared experts so this is
|
| 220 |
+
correct for the current model, but reduces generality.
|
| 221 |
+
- **No separate E2E hooks for vLLM:** E2E and offline weights have identical format.
|
| 222 |
+
`register_perlayer_hooks_vllm()` works for 3b + 5a + 6a weights.
|
| 223 |
+
`register_stale_hooks_vllm()` works for 4a/4b + 5b + 6b weights.
|
| 224 |
+
- **TP > 1 with vLLM:** When using tensor parallelism, each rank has a partial model.
|
| 225 |
+
Hook registration should still work (hooks are on the full module), but compressor
|
| 226 |
+
modules stay on one device. Tested with TP=1 by default.
|
| 227 |
+
|
| 228 |
+
**vLLM-specific directories:**
|
| 229 |
+
```
|
| 230 |
+
.venv_vllm/ # Separate virtual environment (gitignored)
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
### vLLM EP Compression Environment (Task 8)
|
| 234 |
+
|
| 235 |
+
**EP compression:** `src/vllm_ep_compression.py` — Sets compress/decompress functions
|
| 236 |
+
on FusedMoE instances. Patched `forward_impl()` calls compress BEFORE dispatch and
|
| 237 |
+
decompress AFTER, achieving real communication reduction.
|
| 238 |
+
|
| 239 |
+
```bash
|
| 240 |
+
# Separate venv with patched vLLM 0.15.1 — CUDA 12.6
|
| 241 |
+
module load cuda/12.6 arrow/22.0.0
|
| 242 |
+
source .venv_vllm_exp/bin/activate
|
| 243 |
+
export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
|
| 244 |
+
|
| 245 |
+
# Setup (first time only):
|
| 246 |
+
bash scripts/vllm_exp_setup_env.sh
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
**Key differences from .venv_vllm:**
|
| 250 |
+
- vLLM 0.15.1 pinned (for patch compatibility)
|
| 251 |
+
- `FusedMoE.forward_impl()` patched with 3 insertion points (~12 lines)
|
| 252 |
+
- Uses `_ecmoe_compress_fn` / `_ecmoe_decompress_fn` attributes (not PyTorch hooks)
|
| 253 |
+
- Supports `enable_expert_parallel=True` for actual EP dispatch
|
| 254 |
+
|
| 255 |
+
**Known issues / gotchas (EP compression):**
|
| 256 |
+
- **allgather_reducescatter backend:** vLLM's default `all2all_backend`. After dispatch,
|
| 257 |
+
every rank has ALL tokens. Stale cache approach works because token ordering is
|
| 258 |
+
consistent across layers.
|
| 259 |
+
- **Router unaffected:** `router_logits` are computed at `Qwen3MoeSparseMoeBlock.forward()`
|
| 260 |
+
BEFORE `FusedMoE.forward_impl()`, so compression never affects routing decisions.
|
| 261 |
+
- **Stale piggybacking:** Reference layers concatenate `cat(compressed, stale)` before
|
| 262 |
+
dispatch. After dispatch, decompress_fn splits and caches stale globally. Non-reference
|
| 263 |
+
layers dispatch only compressed (max compression), retrieve cached stale for decompression.
|
| 264 |
+
|
| 265 |
+
**vLLM EP compression directories:**
|
| 266 |
+
```
|
| 267 |
+
.venv_vllm_exp/ # Patched vLLM environment (gitignored)
|
| 268 |
+
results/08_ep_compression/ # EP eval results
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
### Megatron-LM Environment (Task 5 Megatron variant)
|
| 272 |
+
|
| 273 |
+
**Megatron implementation:** `src/megatron_e2e/` package — EP-first, CUDA 12.9, Megatron Bridge.
|
| 274 |
+
(Legacy `src/run_megatron_e2e_compressor.py` was deleted due to SFT/split/batch bugs.)
|
| 275 |
+
|
| 276 |
+
```bash
|
| 277 |
+
# Separate venv from HF-based experiments — CUDA 12.9 required
|
| 278 |
+
module load cuda/12.9 nccl arrow/22.0.0
|
| 279 |
+
source .venv_megatron/bin/activate
|
| 280 |
+
export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
|
| 281 |
+
|
| 282 |
+
# Setup (first time only):
|
| 283 |
+
bash scripts/megatron_setup_env.sh
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
**Key differences from HF environment:**
|
| 287 |
+
- Uses `megatron-core` >=0.15.0 for model parallelism (EP, TP, DP, PP)
|
| 288 |
+
- Requires Transformer Engine (for Megatron Bridge and fused kernels)
|
| 289 |
+
- Uses `megatron-bridge` >=0.2.0 for HF→Megatron weight conversion
|
| 290 |
+
- Default parallelism: EP=4, TP=1, PP=1 (expert parallelism, not tensor)
|
| 291 |
+
- Launch via `torchrun`, not `python`
|
| 292 |
+
|
| 293 |
+
**Megatron-specific directories:**
|
| 294 |
+
```
|
| 295 |
+
src/megatron_e2e/ # Package-based implementation (recommended)
|
| 296 |
+
.venv_megatron/ # Separate virtual environment (gitignored)
|
| 297 |
+
.uv_cache/ # uv cache on project disk (gitignored)
|
| 298 |
+
.uv_pythons/ # uv Python installs (gitignored)
|
| 299 |
+
third_party/ # Apex, etc. (gitignored, legacy only)
|
| 300 |
+
data/megatron_dolci/ # Preprocessed binary dataset (gitignored)
|
| 301 |
+
```
|
| 302 |
+
|
| 303 |
+
**Known issues / gotchas (Megatron):**
|
| 304 |
+
- **CUDA version:** Megatron Bridge requires CUDA >= 12.8. Use `cuda/12.9` module
|
| 305 |
+
on Compute Canada, NOT `cuda/12.6`.
|
| 306 |
+
- **EP vs TP:** Default is EP=4 (expert parallelism). With EP, each GPU holds 32/128
|
| 307 |
+
experts per layer. TP=4 is the legacy approach and splits attention heads across GPUs.
|
| 308 |
+
- **Megatron layer names** differ from HF: `decoder.layers.N.mlp` vs `model.layers.N.mlp`.
|
| 309 |
+
`_megatron_to_hf_layer_name()` in `compressor_manager.py` handles conversion.
|
| 310 |
+
- Compressor weights are replicated across all ranks (not sharded), since they
|
| 311 |
+
are tiny (~200M total). Saved from rank 0 only.
|
| 312 |
+
- With EP>1, compressor is on source GPU (attention side), decompressor on
|
| 313 |
+
destination GPU (expert side) — different devices.
|
| 314 |
+
- `MegatronModelWrapper` bridges Megatron's forward interface to HF-style
|
| 315 |
+
`SimpleNamespace(loss=..., logits=...)`. Uses `vocab_parallel_cross_entropy`
|
| 316 |
+
for correct loss with TP > 1. SFT labels (-100) are clamped to 0 before
|
| 317 |
+
calling `vocab_parallel_cross_entropy`, and loss is masked via
|
| 318 |
+
`(per_token_loss * loss_mask).sum() / num_valid`.
|
| 319 |
+
- DistributedSampler must use DP rank/size (via `get_dp_info()`), NOT global
|
| 320 |
+
world size. All ranks in a TP group must see the SAME data.
|
| 321 |
+
- Saved weights use HF layer names (`model.layers.N.mlp`) for compatibility
|
| 322 |
+
with HF `E2ECompressorManager.load_weights()`.
|
| 323 |
+
- **Model loading:** `train.py` tries AutoBridge → MegatronBridge → manual fallback
|
| 324 |
+
for HF→Megatron conversion. If Bridge is not installed, falls back to manual
|
| 325 |
+
weight conversion using `load_megatron_qwen3()` from legacy code.
|
| 326 |
+
- **Train loss DP reduction (2026-02-17):** `train.py` now all-reduces step-level and
|
| 327 |
+
epoch-level train loss across DP ranks before logging. Previously, only rank 0's local
|
| 328 |
+
shard loss was logged, which was inaccurate with DP > 1. Wandb `train/loss` and
|
| 329 |
+
`train/epoch_loss` now reflect the true DP-averaged loss.
|
| 330 |
+
|
| 331 |
+
### Running Experiments
|
| 332 |
+
|
| 333 |
+
Task 1 must run first (caches hidden states for Tasks 2–4). Task 5 is independent.
|
| 334 |
+
Tasks 1–4 use 1 GPU each; Task 5a/5b use 4 GPUs each.
|
| 335 |
+
|
| 336 |
+
**Data selection:** All tasks use seed=42 for reproducible 80/10/10 train/val/test
|
| 337 |
+
split of dataset rows. Tasks 1–4 draw from TRAIN split, PPL evaluation from TEST
|
| 338 |
+
split. No data leakage between splits.
|
| 339 |
+
|
| 340 |
+
**Task 5 config (HF):** batch_size=2, grad_accum=8 (effective=16), max_sequences=500K,
|
| 341 |
+
max_length=2048, val_interval=2500 steps, val_batch_size=8, SFT mode
|
| 342 |
+
(response-only training), wandb enabled by default.
|
| 343 |
+
|
| 344 |
+
**Task 5/6 config (Megatron):** Same as HF except max_sequences=100K,
|
| 345 |
+
val_interval=1000 steps. Task 6 uses same Megatron config with `--init-weights-dir`.
|
| 346 |
+
|
| 347 |
+
Tail micro-batches (when `len(dataloader) % grad_accum != 0`) are handled by rescaling
|
| 348 |
+
accumulated gradients and performing the optimizer step.
|
| 349 |
+
|
| 350 |
+
**Two evaluation stages:** Training-time val loss uses the VAL split (50K seqs,
|
| 351 |
+
batch_size=8, every 2500 steps) for checkpoint selection and wandb monitoring.
|
| 352 |
+
Final PPL evaluation uses the TEST split (50K seqs, batch_size=1, in
|
| 353 |
+
`model_utils.py`) for reported results. Different code paths — `--val-batch-size`
|
| 354 |
+
only affects training-time eval.
|
| 355 |
+
|
| 356 |
+
**SFT data loading:** All E2E training (Task 5) and perplexity evaluation now use
|
| 357 |
+
SFT mode: each sample is one conversation, tokenized independently. Labels are
|
| 358 |
+
-100 for non-assistant tokens (system, user, template markup) and actual token
|
| 359 |
+
IDs for assistant responses. Loss and perplexity are computed on response tokens
|
| 360 |
+
only. Data is loaded by sampling N sequences from the dataset (not packing tokens).
|
| 361 |
+
`_tokenize_sft_sample()` in `model_utils.py` handles the tokenization.
|
| 362 |
+
|
| 363 |
+
```bash
|
| 364 |
+
# Phase 1: Megatron 5a + 5b in parallel (8 GPUs)
|
| 365 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh none &
|
| 366 |
+
CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/05_megatron_e2e.sh uncompressed &
|
| 367 |
+
wait
|
| 368 |
+
|
| 369 |
+
# Phase 2: Task 1 (re-cache with seed=42)
|
| 370 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/01_analyze_distribution.sh
|
| 371 |
+
|
| 372 |
+
# Phase 3: Tasks 2-4 + HF 5a (parallel)
|
| 373 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/02_run_quantization.sh &
|
| 374 |
+
CUDA_VISIBLE_DEVICES=1 bash scripts/03_run_neural_compressor.sh &
|
| 375 |
+
CUDA_VISIBLE_DEVICES=2 bash scripts/03b_run_perlayer_compressor.sh &
|
| 376 |
+
CUDA_VISIBLE_DEVICES=3 bash scripts/04_run_stale_compressor.sh compressed &
|
| 377 |
+
CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/05_run_e2e_compressor.sh none &
|
| 378 |
+
wait
|
| 379 |
+
|
| 380 |
+
# Phase 4: Task 4b + HF 5b (parallel)
|
| 381 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/04_run_stale_compressor.sh uncompressed &
|
| 382 |
+
CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/05_run_e2e_compressor.sh uncompressed &
|
| 383 |
+
wait
|
| 384 |
+
|
| 385 |
+
# Megatron-based E2E training (alternative to HF Task 5):
|
| 386 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh none # 5a
|
| 387 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh uncompressed # 5b
|
| 388 |
+
|
| 389 |
+
# Task 5c: Baseline evaluation (no compression, same pipeline):
|
| 390 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_run_e2e_compressor.sh baseline # HF
|
| 391 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh baseline # Megatron
|
| 392 |
+
|
| 393 |
+
# Task 6a/6b: E2E with pretrained init (requires Task 3b/4b weights):
|
| 394 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/06_megatron_e2e_pretrained.sh none & # 6a (init from 3b)
|
| 395 |
+
CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/06_megatron_e2e_pretrained.sh uncompressed & # 6b (init from 4b)
|
| 396 |
+
wait
|
| 397 |
+
|
| 398 |
+
# Task 7a/7b: Split-mode E2E (router sees original, experts see decompressed):
|
| 399 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/07_megatron_e2e_split.sh none & # 7a (init from 3b)
|
| 400 |
+
CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/07_megatron_e2e_split.sh uncompressed & # 7b (init from 4b)
|
| 401 |
+
wait
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
### Downstream Task Evaluation (lm-eval-harness)
|
| 405 |
+
|
| 406 |
+
Downstream eval is triggered by setting `DOWNSTREAM_TASKS` before running any script.
|
| 407 |
+
It runs **after** the existing PPL evaluation step, using `lm-eval-harness` with the
|
| 408 |
+
same compression hooks active. Results saved to `downstream_results.json` in each
|
| 409 |
+
task's output directory.
|
| 410 |
+
|
| 411 |
+
```bash
|
| 412 |
+
# Run Task 2 + PPL eval + downstream eval:
|
| 413 |
+
DOWNSTREAM_TASKS="gsm8k_cot" bash scripts/02_run_quantization.sh
|
| 414 |
+
|
| 415 |
+
# Run Task 5a + PPL eval + downstream eval:
|
| 416 |
+
DOWNSTREAM_TASKS="gsm8k_cot" bash scripts/05_run_e2e_compressor.sh none
|
| 417 |
+
|
| 418 |
+
# Eval-only mode + downstream:
|
| 419 |
+
DOWNSTREAM_TASKS="gsm8k_cot" python src/run_e2e_compressor.py \
|
| 420 |
+
--skip-training --output-dir results/05a_e2e_perlayer --stale-mode none
|
| 421 |
+
|
| 422 |
+
# Smoke test with 10 examples:
|
| 423 |
+
DOWNSTREAM_TASKS="gsm8k_cot" DOWNSTREAM_LIMIT=10 bash scripts/05_run_e2e_compressor.sh none
|
| 424 |
+
```
|
| 425 |
+
|
| 426 |
+
**Key code:** `src/downstream_eval.py` provides `register_*_hooks()` for each method,
|
| 427 |
+
`run_lm_eval()` wrapper, and `save_downstream_results()`. Each task script imports from
|
| 428 |
+
it when `--downstream-tasks` is specified. GSM8K variant: `gsm8k_cot` (8-shot CoT).
|
| 429 |
+
|
| 430 |
+
**vLLM backend:** Use `--backend vllm` (or `DOWNSTREAM_BACKEND=vllm`) for vLLM-based
|
| 431 |
+
downstream evaluation. Two router modes (`--router-mode compressed/uncompressed`):
|
| 432 |
+
|
| 433 |
+
```bash
|
| 434 |
+
# Standalone vLLM eval (all methods, default router=compressed):
|
| 435 |
+
source .venv_vllm/bin/activate
|
| 436 |
+
python src/run_all_downstream.py --backend vllm --tasks gsm8k_cot
|
| 437 |
+
|
| 438 |
+
# Router-uncompressed mode (split: router sees original, experts see decompressed):
|
| 439 |
+
python src/run_all_downstream.py --backend vllm --router-mode uncompressed --method e2e_perlayer --tasks gsm8k_cot
|
| 440 |
+
|
| 441 |
+
# With tensor parallelism:
|
| 442 |
+
python src/run_all_downstream.py --backend vllm --tensor-parallel-size 4 --tasks gsm8k_cot
|
| 443 |
+
|
| 444 |
+
# Via task scripts (HF model, vLLM downstream):
|
| 445 |
+
DOWNSTREAM_TASKS="gsm8k_cot" DOWNSTREAM_BACKEND=vllm bash scripts/05_run_e2e_compressor.sh none
|
| 446 |
+
```
|
| 447 |
+
|
| 448 |
+
### Visualization
|
| 449 |
+
|
| 450 |
+
Regenerate all summary plots and tables:
|
| 451 |
+
```bash
|
| 452 |
+
source .venv/bin/activate
|
| 453 |
+
python src/visualize_all_results.py
|
| 454 |
+
```
|
| 455 |
+
|
| 456 |
+
Outputs to `results/summary/`:
|
| 457 |
+
- `ppl_vs_ratio_all.png` — PPL vs compression ratio (log-log)
|
| 458 |
+
- `reconstruction_vs_ratio_all.png` — MSE and CosSim vs ratio
|
| 459 |
+
- `ppl_bar_practical.png` — Bar chart at 2x and 4x
|
| 460 |
+
- `all_results_summary.json` — Machine-readable summary
|
| 461 |
+
- `param_count_table.{csv,md,json}` — Parameter counts for all methods
|
| 462 |
+
|
| 463 |
+
---
|
| 464 |
+
|
| 465 |
+
## Code Changes
|
| 466 |
+
|
| 467 |
+
**Before changing any code:**
|
| 468 |
+
1. FIND the exact file that produces the current output
|
| 469 |
+
2. READ and understand it
|
| 470 |
+
3. EDIT only the specific lines needed (use Edit tool)
|
| 471 |
+
4. TEST that output matches except for your intended change
|
| 472 |
+
|
| 473 |
+
**Adding new compression methods:**
|
| 474 |
+
- Reuse `Compressor`, `Decompressor` from `run_neural_compressor.py`
|
| 475 |
+
- Reuse `train_compressor()` for standard autoencoder training
|
| 476 |
+
- Add new perplexity evaluation functions to `model_utils.py`
|
| 477 |
+
- Follow the same JSON output format as existing experiments
|
| 478 |
+
- Update `visualize_all_results.py` to include the new method
|
| 479 |
+
|
| 480 |
+
---
|
| 481 |
+
|
| 482 |
+
## NEVER GUESS SILENTLY
|
| 483 |
+
|
| 484 |
+
**When you encounter ambiguity:**
|
| 485 |
+
1. **STOP** — Do not make an arbitrary choice
|
| 486 |
+
2. **ASK** — Present the options to the user
|
| 487 |
+
3. **FLAG** — Note the documentation gap
|
| 488 |
+
4. **FIX** — Update README.md or CLAUDE.md
|
| 489 |
+
|
| 490 |
+
---
|
| 491 |
+
|
| 492 |
+
## Version Control
|
| 493 |
+
|
| 494 |
+
- Commit after EVERY fix (don't wait)
|
| 495 |
+
- Check `git status` and file sizes before committing (no files >100MB)
|
| 496 |
+
- Update JOURNAL.md immediately after committing
|
| 497 |
+
- No git remote is currently configured — commits are local only
|
| 498 |
+
|
| 499 |
+
---
|
| 500 |
+
|
| 501 |
+
## Investigation
|
| 502 |
+
|
| 503 |
+
**When something seems wrong:**
|
| 504 |
+
1. STOP — don't patch the visible symptom
|
| 505 |
+
2. ASK WHY — trace back to data generation
|
| 506 |
+
3. VERIFY — test hypotheses with minimal examples
|
| 507 |
+
4. FIX ROOT — fix the source, not downstream
|
| 508 |
+
|
| 509 |
+
---
|
| 510 |
+
|
| 511 |
+
## Meta-Rule: Continuous Improvement
|
| 512 |
+
|
| 513 |
+
**When a preventable issue occurs:**
|
| 514 |
+
1. Identify the root cause
|
| 515 |
+
2. Add a "What went wrong" example to this file
|
| 516 |
+
3. Commit the improvement
|
| 517 |
+
|
| 518 |
+
This file should evolve based on lessons learned.
|
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Stage 1: Build frontend
|
| 2 |
+
FROM node:22-alpine AS frontend-build
|
| 3 |
+
WORKDIR /app/frontend
|
| 4 |
+
COPY frontend/package.json frontend/package-lock.json ./
|
| 5 |
+
RUN npm ci
|
| 6 |
+
COPY frontend/ ./
|
| 7 |
+
RUN npm run build
|
| 8 |
+
|
| 9 |
+
# Stage 2: Run backend + serve frontend
|
| 10 |
+
FROM python:3.12-slim
|
| 11 |
+
WORKDIR /app
|
| 12 |
+
|
| 13 |
+
# Install Python deps
|
| 14 |
+
COPY backend/requirements.txt ./backend/
|
| 15 |
+
RUN pip install --no-cache-dir -r backend/requirements.txt
|
| 16 |
+
|
| 17 |
+
# Copy backend
|
| 18 |
+
COPY backend/ ./backend/
|
| 19 |
+
COPY airline_routes.json ./
|
| 20 |
+
|
| 21 |
+
# Copy built frontend
|
| 22 |
+
COPY --from=frontend-build /app/frontend/dist ./frontend/dist
|
| 23 |
+
|
| 24 |
+
EXPOSE 8080
|
| 25 |
+
|
| 26 |
+
CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8080"]
|
JOURNAL.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
airline_routes.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e8d97548f626230927e3e38b8ba9712710612ef90292e9a0696698d20b3bac3
|
| 3 |
+
size 21798276
|
backend/__init__.py
ADDED
|
File without changes
|
backend/api/__init__.py
ADDED
|
File without changes
|
backend/api/airports.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Airport autocomplete and info endpoints."""
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 4 |
+
|
| 5 |
+
from ..data_loader import get_route_graph
|
| 6 |
+
from ..models import AirportInfo, AutocompleteResult
|
| 7 |
+
|
| 8 |
+
router = APIRouter(prefix="/api/airports", tags=["airports"])
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@router.get("/autocomplete", response_model=list[AutocompleteResult])
|
| 12 |
+
async def autocomplete(q: str = Query(..., min_length=1, max_length=50)):
|
| 13 |
+
graph = get_route_graph()
|
| 14 |
+
airports = graph.search_airports(q, limit=10)
|
| 15 |
+
return [
|
| 16 |
+
AutocompleteResult(
|
| 17 |
+
iata=a.iata,
|
| 18 |
+
name=a.name,
|
| 19 |
+
city_name=a.city_name,
|
| 20 |
+
country=a.country,
|
| 21 |
+
display_name=a.display_name,
|
| 22 |
+
hub_score=a.hub_score,
|
| 23 |
+
)
|
| 24 |
+
for a in airports
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@router.get("/{iata}", response_model=AirportInfo)
|
| 29 |
+
async def get_airport(iata: str):
|
| 30 |
+
graph = get_route_graph()
|
| 31 |
+
iata = iata.upper()
|
| 32 |
+
airport = graph.airports.get(iata)
|
| 33 |
+
if not airport:
|
| 34 |
+
raise HTTPException(status_code=404, detail=f"Airport {iata} not found")
|
| 35 |
+
return AirportInfo(
|
| 36 |
+
iata=airport.iata,
|
| 37 |
+
name=airport.name,
|
| 38 |
+
city_name=airport.city_name,
|
| 39 |
+
country=airport.country,
|
| 40 |
+
country_code=airport.country_code,
|
| 41 |
+
continent=airport.continent,
|
| 42 |
+
latitude=airport.latitude,
|
| 43 |
+
longitude=airport.longitude,
|
| 44 |
+
timezone=airport.timezone,
|
| 45 |
+
hub_score=airport.hub_score,
|
| 46 |
+
route_count=len(airport.routes),
|
| 47 |
+
)
|
backend/api/calendar.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Calendar pricing endpoint — cheapest price per day for a month."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import calendar
|
| 6 |
+
from datetime import date
|
| 7 |
+
|
| 8 |
+
from fastapi import APIRouter, HTTPException, Query
|
| 9 |
+
|
| 10 |
+
from ..data_loader import get_route_graph
|
| 11 |
+
from ..models import CalendarDay, CalendarResponse, CabinClass
|
| 12 |
+
from ..price_engine import compute_calendar_price
|
| 13 |
+
from ..seed_utils import seeded_random
|
| 14 |
+
|
| 15 |
+
router = APIRouter(prefix="/api", tags=["calendar"])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.get("/calendar", response_model=CalendarResponse)
|
| 19 |
+
async def get_calendar(
|
| 20 |
+
origin: str = Query(..., min_length=3, max_length=3),
|
| 21 |
+
destination: str = Query(..., min_length=3, max_length=3),
|
| 22 |
+
year: int = Query(..., ge=2025, le=2028),
|
| 23 |
+
month: int = Query(..., ge=1, le=12),
|
| 24 |
+
cabin_class: CabinClass = Query(CabinClass.economy),
|
| 25 |
+
):
|
| 26 |
+
graph = get_route_graph()
|
| 27 |
+
origin = origin.upper()
|
| 28 |
+
destination = destination.upper()
|
| 29 |
+
|
| 30 |
+
if origin not in graph.airports:
|
| 31 |
+
raise HTTPException(status_code=404, detail=f"Airport {origin} not found")
|
| 32 |
+
if destination not in graph.airports:
|
| 33 |
+
raise HTTPException(status_code=404, detail=f"Airport {destination} not found")
|
| 34 |
+
|
| 35 |
+
route = graph.get_direct_route(origin, destination)
|
| 36 |
+
if not route:
|
| 37 |
+
# Try to estimate distance for pricing
|
| 38 |
+
from ..route_finder import _estimate_distance
|
| 39 |
+
distance = _estimate_distance(graph, origin, destination)
|
| 40 |
+
if distance is None:
|
| 41 |
+
raise HTTPException(status_code=404, detail="No route found")
|
| 42 |
+
num_carriers = 2 # default estimate
|
| 43 |
+
else:
|
| 44 |
+
distance = route.distance_km
|
| 45 |
+
num_carriers = len(route.carriers)
|
| 46 |
+
|
| 47 |
+
dest_airport = graph.airports[destination]
|
| 48 |
+
num_days = calendar.monthrange(year, month)[1]
|
| 49 |
+
|
| 50 |
+
days = []
|
| 51 |
+
for day in range(1, num_days + 1):
|
| 52 |
+
d = date(year, month, day)
|
| 53 |
+
rng = seeded_random(origin, destination, d.isoformat(), "calendar")
|
| 54 |
+
price = compute_calendar_price(
|
| 55 |
+
distance_km=distance,
|
| 56 |
+
cabin_class=cabin_class.value,
|
| 57 |
+
target_date=d,
|
| 58 |
+
num_carriers=num_carriers,
|
| 59 |
+
dest_continent=dest_airport.continent,
|
| 60 |
+
rng=rng,
|
| 61 |
+
)
|
| 62 |
+
days.append(CalendarDay(date=d, cheapest_price=price))
|
| 63 |
+
|
| 64 |
+
return CalendarResponse(
|
| 65 |
+
origin=origin,
|
| 66 |
+
destination=destination,
|
| 67 |
+
year=year,
|
| 68 |
+
month=month,
|
| 69 |
+
days=days,
|
| 70 |
+
)
|
backend/api/search.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flight search endpoint."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException
|
| 6 |
+
|
| 7 |
+
from ..config import MAX_RESULTS
|
| 8 |
+
from ..data_loader import get_route_graph
|
| 9 |
+
from ..flight_generator import generate_flights_for_route
|
| 10 |
+
from ..hub_detector import compute_hub_scores
|
| 11 |
+
from ..models import FlightOffer, SearchRequest, SearchResponse, SortBy
|
| 12 |
+
from ..route_finder import find_routes
|
| 13 |
+
from ..seed_utils import make_seed
|
| 14 |
+
|
| 15 |
+
router = APIRouter(prefix="/api", tags=["search"])
|
| 16 |
+
|
| 17 |
+
# Module-level hub cache
|
| 18 |
+
_hub_iatas: list[str] | None = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_hubs() -> list[str]:
|
| 22 |
+
global _hub_iatas
|
| 23 |
+
if _hub_iatas is None:
|
| 24 |
+
graph = get_route_graph()
|
| 25 |
+
_hub_iatas = compute_hub_scores(graph)
|
| 26 |
+
return _hub_iatas
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _apply_filters(flights: list[FlightOffer], req: SearchRequest) -> list[FlightOffer]:
|
| 30 |
+
f = req.filters
|
| 31 |
+
result = flights
|
| 32 |
+
|
| 33 |
+
if f.max_stops is not None:
|
| 34 |
+
result = [fl for fl in result if fl.stops <= f.max_stops]
|
| 35 |
+
|
| 36 |
+
if f.max_price is not None:
|
| 37 |
+
result = [fl for fl in result if fl.price_usd <= f.max_price]
|
| 38 |
+
|
| 39 |
+
if f.max_duration_minutes is not None:
|
| 40 |
+
result = [fl for fl in result if fl.total_duration_minutes <= f.max_duration_minutes]
|
| 41 |
+
|
| 42 |
+
if f.airlines:
|
| 43 |
+
airline_set = set(f.airlines)
|
| 44 |
+
result = [
|
| 45 |
+
fl for fl in result
|
| 46 |
+
if any(seg.airline_code in airline_set for seg in fl.segments)
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
if f.departure_time_min:
|
| 50 |
+
h, m = map(int, f.departure_time_min.split(":"))
|
| 51 |
+
min_minutes = h * 60 + m
|
| 52 |
+
result = [
|
| 53 |
+
fl for fl in result
|
| 54 |
+
if fl.departure.hour * 60 + fl.departure.minute >= min_minutes
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
if f.departure_time_max:
|
| 58 |
+
h, m = map(int, f.departure_time_max.split(":"))
|
| 59 |
+
max_minutes = h * 60 + m
|
| 60 |
+
result = [
|
| 61 |
+
fl for fl in result
|
| 62 |
+
if fl.departure.hour * 60 + fl.departure.minute <= max_minutes
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
return result
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _sort_flights(flights: list[FlightOffer], sort_by: SortBy) -> list[FlightOffer]:
|
| 69 |
+
if sort_by == SortBy.cheapest:
|
| 70 |
+
return sorted(flights, key=lambda f: f.price_usd)
|
| 71 |
+
elif sort_by == SortBy.fastest:
|
| 72 |
+
return sorted(flights, key=lambda f: f.total_duration_minutes)
|
| 73 |
+
else: # best: balance of price and duration
|
| 74 |
+
if not flights:
|
| 75 |
+
return flights
|
| 76 |
+
max_price = max(f.price_usd for f in flights) or 1
|
| 77 |
+
max_dur = max(f.total_duration_minutes for f in flights) or 1
|
| 78 |
+
return sorted(
|
| 79 |
+
flights,
|
| 80 |
+
key=lambda f: (f.price_usd / max_price) * 0.6 + (f.total_duration_minutes / max_dur) * 0.4,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@router.post("/search", response_model=SearchResponse)
|
| 85 |
+
async def search_flights(req: SearchRequest):
|
| 86 |
+
graph = get_route_graph()
|
| 87 |
+
hub_iatas = _get_hubs()
|
| 88 |
+
|
| 89 |
+
if not req.legs:
|
| 90 |
+
raise HTTPException(status_code=400, detail="At least one leg required")
|
| 91 |
+
|
| 92 |
+
# Validate airports
|
| 93 |
+
for leg in req.legs:
|
| 94 |
+
if leg.origin.upper() not in graph.airports:
|
| 95 |
+
raise HTTPException(status_code=404, detail=f"Airport {leg.origin} not found")
|
| 96 |
+
if leg.destination.upper() not in graph.airports:
|
| 97 |
+
raise HTTPException(status_code=404, detail=f"Airport {leg.destination} not found")
|
| 98 |
+
|
| 99 |
+
# Generate outbound flights
|
| 100 |
+
outbound_leg = req.legs[0]
|
| 101 |
+
origin = outbound_leg.origin.upper()
|
| 102 |
+
destination = outbound_leg.destination.upper()
|
| 103 |
+
|
| 104 |
+
max_stops = req.filters.max_stops
|
| 105 |
+
route_plans = find_routes(graph, origin, destination, hub_iatas, max_stops=max_stops)
|
| 106 |
+
|
| 107 |
+
outbound_flights: list[FlightOffer] = []
|
| 108 |
+
for plan in route_plans:
|
| 109 |
+
flights = generate_flights_for_route(
|
| 110 |
+
graph, plan, outbound_leg.date, req.cabin_class, hub_iatas
|
| 111 |
+
)
|
| 112 |
+
outbound_flights.extend(flights)
|
| 113 |
+
|
| 114 |
+
outbound_flights = _apply_filters(outbound_flights, req)
|
| 115 |
+
outbound_flights = _sort_flights(outbound_flights, req.sort_by)
|
| 116 |
+
outbound_flights = outbound_flights[:MAX_RESULTS]
|
| 117 |
+
|
| 118 |
+
# Generate return flights if round trip
|
| 119 |
+
return_flights: list[FlightOffer] = []
|
| 120 |
+
if req.trip_type.value == "round_trip" and len(req.legs) >= 2:
|
| 121 |
+
return_leg = req.legs[1]
|
| 122 |
+
ret_origin = return_leg.origin.upper()
|
| 123 |
+
ret_dest = return_leg.destination.upper()
|
| 124 |
+
ret_plans = find_routes(graph, ret_origin, ret_dest, hub_iatas, max_stops=max_stops)
|
| 125 |
+
|
| 126 |
+
for plan in ret_plans:
|
| 127 |
+
flights = generate_flights_for_route(
|
| 128 |
+
graph, plan, return_leg.date, req.cabin_class, hub_iatas
|
| 129 |
+
)
|
| 130 |
+
return_flights.extend(flights)
|
| 131 |
+
|
| 132 |
+
return_flights = _apply_filters(return_flights, req)
|
| 133 |
+
return_flights = _sort_flights(return_flights, req.sort_by)
|
| 134 |
+
return_flights = return_flights[:MAX_RESULTS]
|
| 135 |
+
|
| 136 |
+
search_id = str(make_seed(origin, destination, outbound_leg.date.isoformat()))
|
| 137 |
+
|
| 138 |
+
return SearchResponse(
|
| 139 |
+
outbound_flights=outbound_flights,
|
| 140 |
+
return_flights=return_flights,
|
| 141 |
+
search_id=search_id,
|
| 142 |
+
origin=origin,
|
| 143 |
+
destination=destination,
|
| 144 |
+
)
|
backend/config.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pricing constants and configuration."""
|
| 2 |
+
|
| 3 |
+
# Base price formula
|
| 4 |
+
BASE_FIXED_USD = 40
|
| 5 |
+
BASE_PER_KM_USD = 0.08
|
| 6 |
+
|
| 7 |
+
# Cabin class multipliers
|
| 8 |
+
CLASS_MULTIPLIERS = {
|
| 9 |
+
"economy": 1.0,
|
| 10 |
+
"premium_economy": 1.6,
|
| 11 |
+
"business": 3.2,
|
| 12 |
+
"first": 5.5,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
# Day-of-week multipliers (0=Monday, 6=Sunday)
|
| 16 |
+
DAY_MULTIPLIERS = {
|
| 17 |
+
0: 0.90, # Monday
|
| 18 |
+
1: 0.90, # Tuesday
|
| 19 |
+
2: 0.90, # Wednesday
|
| 20 |
+
3: 1.00, # Thursday
|
| 21 |
+
4: 1.15, # Friday
|
| 22 |
+
5: 1.05, # Saturday
|
| 23 |
+
6: 1.10, # Sunday
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
# Season multipliers by month
|
| 27 |
+
SEASON_MULTIPLIERS = {
|
| 28 |
+
1: 0.85, # January - off season
|
| 29 |
+
2: 0.85, # February - off season
|
| 30 |
+
3: 0.95, # March
|
| 31 |
+
4: 1.00, # April
|
| 32 |
+
5: 1.05, # May
|
| 33 |
+
6: 1.15, # June - summer peak
|
| 34 |
+
7: 1.20, # July - summer peak
|
| 35 |
+
8: 1.15, # August - summer peak
|
| 36 |
+
9: 0.90, # September - off season
|
| 37 |
+
10: 0.95, # October
|
| 38 |
+
11: 1.00, # November
|
| 39 |
+
12: 1.40, # December - Christmas
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
# Season bonus for EU destinations in summer
|
| 43 |
+
EU_SUMMER_BONUS = 0.15 # +15% on top of summer multiplier
|
| 44 |
+
EU_CONTINENTS = {"EU"}
|
| 45 |
+
EU_SUMMER_MONTHS = {6, 7, 8}
|
| 46 |
+
|
| 47 |
+
# Demand multipliers
|
| 48 |
+
MONOPOLY_ROUTE_BONUS = 0.20 # +20% if only 1 carrier
|
| 49 |
+
HIGH_COMPETITION_DISCOUNT = 0.05 # -5% if 4+ carriers
|
| 50 |
+
|
| 51 |
+
# Advance booking multipliers (days before departure)
|
| 52 |
+
ADVANCE_MULTIPLIERS = [
|
| 53 |
+
(3, 1.50), # 0-3 days: +50%
|
| 54 |
+
(7, 1.35), # 4-7 days: +35%
|
| 55 |
+
(14, 1.20), # 8-14 days: +20%
|
| 56 |
+
(21, 1.10), # 15-21 days: +10%
|
| 57 |
+
(60, 1.00), # 22-60 days: base
|
| 58 |
+
(90, 0.90), # 61-90 days: -10%
|
| 59 |
+
(float("inf"), 0.95), # 91+ days: -5%
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
# Jitter range (±8%)
|
| 63 |
+
JITTER_RANGE = 0.08
|
| 64 |
+
|
| 65 |
+
# Hub detection thresholds
|
| 66 |
+
HUB_MIN_ROUTES = 100
|
| 67 |
+
HUB_TOP_N = 125
|
| 68 |
+
|
| 69 |
+
# Connecting flight constraints
|
| 70 |
+
MAX_1STOP_DISTANCE_RATIO = 1.8 # Max total distance vs great-circle
|
| 71 |
+
MAX_2STOP_DISTANCE_RATIO = 2.5
|
| 72 |
+
MIN_LAYOVER_MINUTES = 60 # 1 hour
|
| 73 |
+
MAX_LAYOVER_MINUTES = 360 # 6 hours
|
| 74 |
+
|
| 75 |
+
# Flight generation
|
| 76 |
+
MIN_FLIGHTS_PER_DAY = 1
|
| 77 |
+
MAX_FLIGHTS_SINGLE_CARRIER = 3
|
| 78 |
+
MAX_FLIGHTS_MULTI_CARRIER = 15
|
| 79 |
+
DEPARTURE_HOUR_MIN = 5 # 05:00
|
| 80 |
+
DEPARTURE_HOUR_MAX = 23 # 23:00
|
| 81 |
+
|
| 82 |
+
# Aircraft types by distance
|
| 83 |
+
AIRCRAFT_BY_DISTANCE = [
|
| 84 |
+
(500, ["E190", "E175", "CRJ-900"]),
|
| 85 |
+
(2000, ["A320", "A321", "737-800", "737 MAX 8"]),
|
| 86 |
+
(5000, ["A321neo LR", "757-200", "767-300ER"]),
|
| 87 |
+
(10000, ["787-8", "787-9", "A330-300", "A350-900"]),
|
| 88 |
+
(float("inf"), ["777-300ER", "A350-1000", "787-10", "A380"]),
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
# Search limits
|
| 92 |
+
MAX_RESULTS = 200
|
| 93 |
+
MAX_AUTOCOMPLETE_RESULTS = 10
|
backend/data_loader.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load airline_routes.json and build in-memory route graph + search index."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class Route:
|
| 12 |
+
destination: str
|
| 13 |
+
distance_km: int
|
| 14 |
+
duration_min: int
|
| 15 |
+
carriers: list[dict] # [{"iata": "AA", "name": "American Airlines"}, ...]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class Airport:
|
| 20 |
+
iata: str
|
| 21 |
+
name: str
|
| 22 |
+
city_name: str
|
| 23 |
+
country: str
|
| 24 |
+
country_code: str
|
| 25 |
+
continent: str
|
| 26 |
+
latitude: float
|
| 27 |
+
longitude: float
|
| 28 |
+
timezone: str
|
| 29 |
+
elevation: int
|
| 30 |
+
icao: str
|
| 31 |
+
display_name: str
|
| 32 |
+
routes: list[Route] = field(default_factory=list)
|
| 33 |
+
hub_score: float = 0.0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RouteGraph:
|
| 37 |
+
"""In-memory route graph and search index."""
|
| 38 |
+
|
| 39 |
+
def __init__(self) -> None:
|
| 40 |
+
self.airports: dict[str, Airport] = {}
|
| 41 |
+
# route_map[origin_iata][dest_iata] = Route
|
| 42 |
+
self.route_map: dict[str, dict[str, Route]] = {}
|
| 43 |
+
# Search index: lowercase tokens → set of IATA codes
|
| 44 |
+
self._search_index: dict[str, set[str]] = {}
|
| 45 |
+
|
| 46 |
+
def load(self, filepath: str) -> None:
|
| 47 |
+
with open(filepath) as f:
|
| 48 |
+
data: dict = json.load(f)
|
| 49 |
+
|
| 50 |
+
for iata, info in data.items():
|
| 51 |
+
routes = []
|
| 52 |
+
for r in info.get("routes", []):
|
| 53 |
+
routes.append(Route(
|
| 54 |
+
destination=r["iata"],
|
| 55 |
+
distance_km=r["km"],
|
| 56 |
+
duration_min=r["min"],
|
| 57 |
+
carriers=r["carriers"],
|
| 58 |
+
))
|
| 59 |
+
|
| 60 |
+
airport = Airport(
|
| 61 |
+
iata=iata,
|
| 62 |
+
name=info["name"],
|
| 63 |
+
city_name=info["city_name"],
|
| 64 |
+
country=info["country"],
|
| 65 |
+
country_code=info["country_code"],
|
| 66 |
+
continent=info["continent"],
|
| 67 |
+
latitude=float(info["latitude"]) if info.get("latitude") is not None else 0.0,
|
| 68 |
+
longitude=float(info["longitude"]) if info.get("longitude") is not None else 0.0,
|
| 69 |
+
timezone=info.get("timezone", "UTC"),
|
| 70 |
+
elevation=info.get("elevation", 0),
|
| 71 |
+
icao=info.get("icao", ""),
|
| 72 |
+
display_name=info.get("display_name", f"{info['city_name']} ({iata})"),
|
| 73 |
+
routes=routes,
|
| 74 |
+
)
|
| 75 |
+
self.airports[iata] = airport
|
| 76 |
+
|
| 77 |
+
# Build route map
|
| 78 |
+
self.route_map.setdefault(iata, {})
|
| 79 |
+
for route in routes:
|
| 80 |
+
self.route_map[iata][route.destination] = route
|
| 81 |
+
|
| 82 |
+
# Build search index
|
| 83 |
+
self._index_airport(airport)
|
| 84 |
+
|
| 85 |
+
def _index_airport(self, airport: Airport) -> None:
|
| 86 |
+
tokens = set()
|
| 87 |
+
# IATA code
|
| 88 |
+
tokens.add(airport.iata.lower())
|
| 89 |
+
# City name tokens
|
| 90 |
+
for word in airport.city_name.lower().split():
|
| 91 |
+
tokens.add(word)
|
| 92 |
+
# Airport name tokens
|
| 93 |
+
for word in airport.name.lower().split():
|
| 94 |
+
tokens.add(word)
|
| 95 |
+
# Country
|
| 96 |
+
for word in airport.country.lower().split():
|
| 97 |
+
tokens.add(word)
|
| 98 |
+
# Country code
|
| 99 |
+
tokens.add(airport.country_code.lower())
|
| 100 |
+
|
| 101 |
+
for token in tokens:
|
| 102 |
+
# Index exact token and all prefixes ≥ 2 chars
|
| 103 |
+
for i in range(2, len(token) + 1):
|
| 104 |
+
prefix = token[:i]
|
| 105 |
+
self._search_index.setdefault(prefix, set()).add(airport.iata)
|
| 106 |
+
|
| 107 |
+
def search_airports(self, query: str, limit: int = 10) -> list[Airport]:
|
| 108 |
+
"""Search airports by IATA code, city, name, or country."""
|
| 109 |
+
q = query.strip().lower()
|
| 110 |
+
if not q:
|
| 111 |
+
return []
|
| 112 |
+
|
| 113 |
+
# Exact IATA match first
|
| 114 |
+
if len(q) == 3 and q.upper() in self.airports:
|
| 115 |
+
exact = self.airports[q.upper()]
|
| 116 |
+
results = [exact]
|
| 117 |
+
# Add more results from prefix search
|
| 118 |
+
candidates = self._search_index.get(q, set())
|
| 119 |
+
for iata in candidates:
|
| 120 |
+
if iata != exact.iata:
|
| 121 |
+
results.append(self.airports[iata])
|
| 122 |
+
if len(results) >= limit:
|
| 123 |
+
break
|
| 124 |
+
return results[:limit]
|
| 125 |
+
|
| 126 |
+
# Split query into tokens, intersect matches
|
| 127 |
+
query_tokens = q.split()
|
| 128 |
+
if not query_tokens:
|
| 129 |
+
return []
|
| 130 |
+
|
| 131 |
+
# Get candidates matching first token
|
| 132 |
+
candidates = self._search_index.get(query_tokens[0], set()).copy()
|
| 133 |
+
|
| 134 |
+
# Intersect with additional tokens
|
| 135 |
+
for token in query_tokens[1:]:
|
| 136 |
+
token_matches = self._search_index.get(token, set())
|
| 137 |
+
candidates &= token_matches
|
| 138 |
+
|
| 139 |
+
if not candidates:
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
# Sort by hub score (descending), then alphabetically
|
| 143 |
+
airports = [self.airports[iata] for iata in candidates if iata in self.airports]
|
| 144 |
+
airports.sort(key=lambda a: (-a.hub_score, a.city_name))
|
| 145 |
+
return airports[:limit]
|
| 146 |
+
|
| 147 |
+
def get_direct_route(self, origin: str, destination: str) -> Route | None:
|
| 148 |
+
return self.route_map.get(origin, {}).get(destination)
|
| 149 |
+
|
| 150 |
+
def get_outbound_routes(self, origin: str) -> dict[str, Route]:
|
| 151 |
+
return self.route_map.get(origin, {})
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# Singleton
|
| 155 |
+
_graph: RouteGraph | None = None
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_route_graph() -> RouteGraph:
|
| 159 |
+
global _graph
|
| 160 |
+
if _graph is None:
|
| 161 |
+
_graph = RouteGraph()
|
| 162 |
+
data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "airline_routes.json")
|
| 163 |
+
_graph.load(data_path)
|
| 164 |
+
return _graph
|
backend/flight_generator.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate concrete flights for a route + date."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from datetime import date, datetime, timedelta, timezone
|
| 7 |
+
|
| 8 |
+
from zoneinfo import ZoneInfo
|
| 9 |
+
|
| 10 |
+
from .config import (
|
| 11 |
+
AIRCRAFT_BY_DISTANCE,
|
| 12 |
+
DEPARTURE_HOUR_MAX,
|
| 13 |
+
DEPARTURE_HOUR_MIN,
|
| 14 |
+
MAX_FLIGHTS_MULTI_CARRIER,
|
| 15 |
+
MAX_FLIGHTS_SINGLE_CARRIER,
|
| 16 |
+
MAX_LAYOVER_MINUTES,
|
| 17 |
+
MIN_FLIGHTS_PER_DAY,
|
| 18 |
+
MIN_LAYOVER_MINUTES,
|
| 19 |
+
)
|
| 20 |
+
from .data_loader import Route, RouteGraph
|
| 21 |
+
from .models import CabinClass, FlightOffer, FlightSegment
|
| 22 |
+
from .price_engine import compute_price
|
| 23 |
+
from .route_finder import RoutePlan
|
| 24 |
+
from .seed_utils import seeded_random
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _pick_aircraft(distance_km: int, rng: random.Random) -> str:
|
| 28 |
+
for max_dist, aircraft_list in AIRCRAFT_BY_DISTANCE:
|
| 29 |
+
if distance_km <= max_dist:
|
| 30 |
+
return rng.choice(aircraft_list)
|
| 31 |
+
return "777-300ER"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _make_flight_number(carrier_iata: str, rng: random.Random) -> str:
|
| 35 |
+
return f"{carrier_iata}{rng.randint(100, 9999)}"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_timezone(graph: RouteGraph, iata: str) -> ZoneInfo:
|
| 39 |
+
airport = graph.airports.get(iata)
|
| 40 |
+
if airport and airport.timezone:
|
| 41 |
+
try:
|
| 42 |
+
return ZoneInfo(airport.timezone)
|
| 43 |
+
except KeyError:
|
| 44 |
+
pass
|
| 45 |
+
return ZoneInfo("UTC")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def generate_flights_for_route(
|
| 49 |
+
graph: RouteGraph,
|
| 50 |
+
route_plan: RoutePlan,
|
| 51 |
+
departure_date: date,
|
| 52 |
+
cabin_class: CabinClass,
|
| 53 |
+
hub_iatas: list[str],
|
| 54 |
+
) -> list[FlightOffer]:
|
| 55 |
+
"""Generate concrete flight offers for a route plan on a given date."""
|
| 56 |
+
origin = route_plan.waypoints[0]
|
| 57 |
+
destination = route_plan.waypoints[-1]
|
| 58 |
+
|
| 59 |
+
# Seed based on route + date for determinism
|
| 60 |
+
seed_key = f"{origin}-{destination}-{departure_date.isoformat()}-{cabin_class.value}"
|
| 61 |
+
rng = seeded_random(seed_key, *route_plan.waypoints)
|
| 62 |
+
|
| 63 |
+
if route_plan.stops == 0:
|
| 64 |
+
return _generate_direct_flights(graph, route_plan, departure_date, cabin_class, rng)
|
| 65 |
+
else:
|
| 66 |
+
return _generate_connecting_flights(graph, route_plan, departure_date, cabin_class, rng)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _generate_direct_flights(
|
| 70 |
+
graph: RouteGraph,
|
| 71 |
+
route_plan: RoutePlan,
|
| 72 |
+
departure_date: date,
|
| 73 |
+
cabin_class: CabinClass,
|
| 74 |
+
rng: random.Random,
|
| 75 |
+
) -> list[FlightOffer]:
|
| 76 |
+
"""Generate multiple direct flight options for a single-leg route."""
|
| 77 |
+
leg = route_plan.legs[0]
|
| 78 |
+
origin = route_plan.waypoints[0]
|
| 79 |
+
destination = route_plan.waypoints[1]
|
| 80 |
+
|
| 81 |
+
# Number of flights based on carrier count
|
| 82 |
+
num_carriers = len(leg.carriers)
|
| 83 |
+
if num_carriers == 1:
|
| 84 |
+
num_flights = rng.randint(MIN_FLIGHTS_PER_DAY, MAX_FLIGHTS_SINGLE_CARRIER)
|
| 85 |
+
elif num_carriers <= 3:
|
| 86 |
+
num_flights = rng.randint(3, 8)
|
| 87 |
+
else:
|
| 88 |
+
num_flights = rng.randint(8, MAX_FLIGHTS_MULTI_CARRIER)
|
| 89 |
+
|
| 90 |
+
# Generate departure times spread across the day
|
| 91 |
+
departure_hours = sorted([
|
| 92 |
+
rng.randint(DEPARTURE_HOUR_MIN * 60, DEPARTURE_HOUR_MAX * 60)
|
| 93 |
+
for _ in range(num_flights)
|
| 94 |
+
])
|
| 95 |
+
|
| 96 |
+
origin_tz = _get_timezone(graph, origin)
|
| 97 |
+
dest_tz = _get_timezone(graph, destination)
|
| 98 |
+
origin_airport = graph.airports[origin]
|
| 99 |
+
dest_airport = graph.airports[destination]
|
| 100 |
+
|
| 101 |
+
flights = []
|
| 102 |
+
for dep_minutes in departure_hours:
|
| 103 |
+
carrier = rng.choice(leg.carriers)
|
| 104 |
+
dep_hour = dep_minutes // 60
|
| 105 |
+
dep_min = dep_minutes % 60
|
| 106 |
+
|
| 107 |
+
departure_dt = datetime(
|
| 108 |
+
departure_date.year, departure_date.month, departure_date.day,
|
| 109 |
+
dep_hour, dep_min,
|
| 110 |
+
tzinfo=origin_tz,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Calculate arrival
|
| 114 |
+
arrival_dt = departure_dt + timedelta(minutes=leg.duration_min)
|
| 115 |
+
arrival_dt = arrival_dt.astimezone(dest_tz)
|
| 116 |
+
|
| 117 |
+
price = compute_price(
|
| 118 |
+
distance_km=leg.distance_km,
|
| 119 |
+
cabin_class=cabin_class.value,
|
| 120 |
+
departure_date=departure_date,
|
| 121 |
+
departure_hour=dep_hour,
|
| 122 |
+
num_carriers=num_carriers,
|
| 123 |
+
dest_continent=dest_airport.continent,
|
| 124 |
+
rng=rng,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
flight_id = f"{origin}{destination}{departure_date.isoformat()}{dep_minutes}{carrier['iata']}"
|
| 128 |
+
|
| 129 |
+
segment = FlightSegment(
|
| 130 |
+
airline_code=carrier["iata"],
|
| 131 |
+
airline_name=carrier["name"],
|
| 132 |
+
flight_number=_make_flight_number(carrier["iata"], rng),
|
| 133 |
+
aircraft=_pick_aircraft(leg.distance_km, rng),
|
| 134 |
+
origin=origin,
|
| 135 |
+
origin_city=origin_airport.city_name,
|
| 136 |
+
destination=destination,
|
| 137 |
+
destination_city=dest_airport.city_name,
|
| 138 |
+
departure=departure_dt,
|
| 139 |
+
arrival=arrival_dt,
|
| 140 |
+
duration_minutes=leg.duration_min,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
flights.append(FlightOffer(
|
| 144 |
+
id=flight_id,
|
| 145 |
+
segments=[segment],
|
| 146 |
+
total_duration_minutes=leg.duration_min,
|
| 147 |
+
stops=0,
|
| 148 |
+
price_usd=price,
|
| 149 |
+
cabin_class=cabin_class,
|
| 150 |
+
origin=origin,
|
| 151 |
+
destination=destination,
|
| 152 |
+
departure=departure_dt,
|
| 153 |
+
arrival=arrival_dt,
|
| 154 |
+
))
|
| 155 |
+
|
| 156 |
+
return flights
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _generate_connecting_flights(
|
| 160 |
+
graph: RouteGraph,
|
| 161 |
+
route_plan: RoutePlan,
|
| 162 |
+
departure_date: date,
|
| 163 |
+
cabin_class: CabinClass,
|
| 164 |
+
rng: random.Random,
|
| 165 |
+
) -> list[FlightOffer]:
|
| 166 |
+
"""Generate connecting flight options (1-stop or 2-stop)."""
|
| 167 |
+
origin = route_plan.waypoints[0]
|
| 168 |
+
destination = route_plan.waypoints[-1]
|
| 169 |
+
dest_airport = graph.airports[destination]
|
| 170 |
+
|
| 171 |
+
# Generate 2-5 options per connecting route
|
| 172 |
+
num_options = rng.randint(2, 5)
|
| 173 |
+
|
| 174 |
+
flights = []
|
| 175 |
+
for option_idx in range(num_options):
|
| 176 |
+
departure_minutes = rng.randint(DEPARTURE_HOUR_MIN * 60, DEPARTURE_HOUR_MAX * 60)
|
| 177 |
+
|
| 178 |
+
segments = []
|
| 179 |
+
current_time = datetime(
|
| 180 |
+
departure_date.year, departure_date.month, departure_date.day,
|
| 181 |
+
departure_minutes // 60, departure_minutes % 60,
|
| 182 |
+
tzinfo=_get_timezone(graph, origin),
|
| 183 |
+
)
|
| 184 |
+
total_price = 0.0
|
| 185 |
+
total_duration = 0
|
| 186 |
+
|
| 187 |
+
valid = True
|
| 188 |
+
for i, leg in enumerate(route_plan.legs):
|
| 189 |
+
leg_origin = route_plan.waypoints[i]
|
| 190 |
+
leg_dest = route_plan.waypoints[i + 1]
|
| 191 |
+
origin_tz = _get_timezone(graph, leg_origin)
|
| 192 |
+
dest_tz = _get_timezone(graph, leg_dest)
|
| 193 |
+
origin_ap = graph.airports[leg_origin]
|
| 194 |
+
dest_ap = graph.airports[leg_dest]
|
| 195 |
+
|
| 196 |
+
carrier = rng.choice(leg.carriers)
|
| 197 |
+
departure_dt = current_time.astimezone(origin_tz)
|
| 198 |
+
arrival_dt = departure_dt + timedelta(minutes=leg.duration_min)
|
| 199 |
+
arrival_dt = arrival_dt.astimezone(dest_tz)
|
| 200 |
+
|
| 201 |
+
# Per-leg price
|
| 202 |
+
leg_price = compute_price(
|
| 203 |
+
distance_km=leg.distance_km,
|
| 204 |
+
cabin_class=cabin_class.value,
|
| 205 |
+
departure_date=departure_date,
|
| 206 |
+
departure_hour=departure_dt.hour,
|
| 207 |
+
num_carriers=len(leg.carriers),
|
| 208 |
+
dest_continent=dest_ap.continent,
|
| 209 |
+
rng=rng,
|
| 210 |
+
)
|
| 211 |
+
# Connecting flights get a discount
|
| 212 |
+
leg_price *= 0.75
|
| 213 |
+
total_price += leg_price
|
| 214 |
+
|
| 215 |
+
segments.append(FlightSegment(
|
| 216 |
+
airline_code=carrier["iata"],
|
| 217 |
+
airline_name=carrier["name"],
|
| 218 |
+
flight_number=_make_flight_number(carrier["iata"], rng),
|
| 219 |
+
aircraft=_pick_aircraft(leg.distance_km, rng),
|
| 220 |
+
origin=leg_origin,
|
| 221 |
+
origin_city=origin_ap.city_name,
|
| 222 |
+
destination=leg_dest,
|
| 223 |
+
destination_city=dest_ap.city_name,
|
| 224 |
+
departure=departure_dt,
|
| 225 |
+
arrival=arrival_dt,
|
| 226 |
+
duration_minutes=leg.duration_min,
|
| 227 |
+
))
|
| 228 |
+
|
| 229 |
+
# Add layover time for next leg
|
| 230 |
+
if i < len(route_plan.legs) - 1:
|
| 231 |
+
layover = rng.randint(MIN_LAYOVER_MINUTES, MAX_LAYOVER_MINUTES)
|
| 232 |
+
current_time = arrival_dt + timedelta(minutes=layover)
|
| 233 |
+
total_duration += leg.duration_min + layover
|
| 234 |
+
|
| 235 |
+
# Check if layover pushes to next day too far
|
| 236 |
+
if (current_time - datetime(
|
| 237 |
+
departure_date.year, departure_date.month, departure_date.day,
|
| 238 |
+
tzinfo=origin_tz
|
| 239 |
+
)).days > 1:
|
| 240 |
+
valid = False
|
| 241 |
+
break
|
| 242 |
+
else:
|
| 243 |
+
total_duration += leg.duration_min
|
| 244 |
+
|
| 245 |
+
if not valid:
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
total_price = round(total_price, 0)
|
| 249 |
+
first_departure = segments[0].departure
|
| 250 |
+
last_arrival = segments[-1].arrival
|
| 251 |
+
|
| 252 |
+
flight_id = (
|
| 253 |
+
f"{origin}{destination}{departure_date.isoformat()}"
|
| 254 |
+
f"{departure_minutes}{'-'.join(route_plan.waypoints)}{option_idx}"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
flights.append(FlightOffer(
|
| 258 |
+
id=flight_id,
|
| 259 |
+
segments=segments,
|
| 260 |
+
total_duration_minutes=total_duration,
|
| 261 |
+
stops=route_plan.stops,
|
| 262 |
+
price_usd=total_price,
|
| 263 |
+
cabin_class=cabin_class,
|
| 264 |
+
origin=origin,
|
| 265 |
+
destination=destination,
|
| 266 |
+
departure=first_departure,
|
| 267 |
+
arrival=last_arrival,
|
| 268 |
+
))
|
| 269 |
+
|
| 270 |
+
return flights
|
backend/hub_detector.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compute hub scores for airports.
|
| 2 |
+
|
| 3 |
+
Hub score = route_count * carrier_diversity * continent_reach
|
| 4 |
+
Used for: connecting flight search (top hubs as waypoints) and autocomplete ranking.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from .config import HUB_MIN_ROUTES, HUB_TOP_N
|
| 10 |
+
from .data_loader import RouteGraph
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_hub_scores(graph: RouteGraph) -> list[str]:
|
| 14 |
+
"""Compute hub scores for all airports. Returns top hub IATA codes.
|
| 15 |
+
|
| 16 |
+
Modifies airports in-place to set hub_score.
|
| 17 |
+
"""
|
| 18 |
+
for airport in graph.airports.values():
|
| 19 |
+
route_count = len(airport.routes)
|
| 20 |
+
if route_count < 5:
|
| 21 |
+
airport.hub_score = 0.0
|
| 22 |
+
continue
|
| 23 |
+
|
| 24 |
+
# Carrier diversity: unique carriers across all routes
|
| 25 |
+
carriers = set()
|
| 26 |
+
for route in airport.routes:
|
| 27 |
+
for c in route.carriers:
|
| 28 |
+
carriers.add(c["iata"])
|
| 29 |
+
carrier_diversity = len(carriers)
|
| 30 |
+
|
| 31 |
+
# Continent reach: unique continents reachable
|
| 32 |
+
continents = set()
|
| 33 |
+
for route in airport.routes:
|
| 34 |
+
dest = graph.airports.get(route.destination)
|
| 35 |
+
if dest:
|
| 36 |
+
continents.add(dest.continent)
|
| 37 |
+
continent_reach = len(continents)
|
| 38 |
+
|
| 39 |
+
airport.hub_score = route_count * (carrier_diversity ** 0.5) * (continent_reach ** 0.3)
|
| 40 |
+
|
| 41 |
+
# Normalize scores to 0-100
|
| 42 |
+
max_score = max((a.hub_score for a in graph.airports.values()), default=1.0)
|
| 43 |
+
if max_score > 0:
|
| 44 |
+
for airport in graph.airports.values():
|
| 45 |
+
airport.hub_score = round(airport.hub_score / max_score * 100, 2)
|
| 46 |
+
|
| 47 |
+
# Return top hubs
|
| 48 |
+
hubs = sorted(
|
| 49 |
+
[a for a in graph.airports.values() if len(a.routes) >= HUB_MIN_ROUTES],
|
| 50 |
+
key=lambda a: -a.hub_score,
|
| 51 |
+
)
|
| 52 |
+
return [h.iata for h in hubs[:HUB_TOP_N]]
|
backend/main.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application — flight search backend."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from fastapi.staticfiles import StaticFiles
|
| 9 |
+
from fastapi.responses import FileResponse
|
| 10 |
+
|
| 11 |
+
from .api import airports, calendar, search
|
| 12 |
+
from .data_loader import get_route_graph
|
| 13 |
+
from .hub_detector import compute_hub_scores
|
| 14 |
+
|
| 15 |
+
app = FastAPI(title="Flight Search API", version="1.0.0")
|
| 16 |
+
|
| 17 |
+
# CORS for development
|
| 18 |
+
app.add_middleware(
|
| 19 |
+
CORSMiddleware,
|
| 20 |
+
allow_origins=["*"],
|
| 21 |
+
allow_credentials=True,
|
| 22 |
+
allow_methods=["*"],
|
| 23 |
+
allow_headers=["*"],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Register API routers
|
| 27 |
+
app.include_router(airports.router)
|
| 28 |
+
app.include_router(search.router)
|
| 29 |
+
app.include_router(calendar.router)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@app.on_event("startup")
|
| 33 |
+
async def startup():
|
| 34 |
+
"""Load data and compute hub scores on startup."""
|
| 35 |
+
t0 = time.time()
|
| 36 |
+
graph = get_route_graph()
|
| 37 |
+
hubs = compute_hub_scores(graph)
|
| 38 |
+
elapsed = time.time() - t0
|
| 39 |
+
print(f"Loaded {len(graph.airports)} airports, {len(hubs)} hubs in {elapsed:.1f}s")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@app.get("/api/health")
|
| 43 |
+
async def health():
|
| 44 |
+
graph = get_route_graph()
|
| 45 |
+
return {"status": "ok", "airports": len(graph.airports)}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Serve frontend static files (production)
|
| 49 |
+
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend", "dist")
|
| 50 |
+
if os.path.isdir(STATIC_DIR):
|
| 51 |
+
app.mount("/assets", StaticFiles(directory=os.path.join(STATIC_DIR, "assets")), name="assets")
|
| 52 |
+
|
| 53 |
+
@app.get("/{full_path:path}")
|
| 54 |
+
async def serve_frontend(full_path: str):
|
| 55 |
+
"""Serve the React SPA for all non-API routes."""
|
| 56 |
+
file_path = os.path.join(STATIC_DIR, full_path)
|
| 57 |
+
if os.path.isfile(file_path):
|
| 58 |
+
return FileResponse(file_path)
|
| 59 |
+
return FileResponse(os.path.join(STATIC_DIR, "index.html"))
|
backend/models.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for API request/response contracts."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from datetime import date, datetime
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CabinClass(str, Enum):
|
| 13 |
+
economy = "economy"
|
| 14 |
+
premium_economy = "premium_economy"
|
| 15 |
+
business = "business"
|
| 16 |
+
first = "first"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TripType(str, Enum):
|
| 20 |
+
one_way = "one_way"
|
| 21 |
+
round_trip = "round_trip"
|
| 22 |
+
multi_city = "multi_city"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SortBy(str, Enum):
|
| 26 |
+
best = "best"
|
| 27 |
+
cheapest = "cheapest"
|
| 28 |
+
fastest = "fastest"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# --- Airport ---
|
| 32 |
+
|
| 33 |
+
class AirportInfo(BaseModel):
|
| 34 |
+
iata: str
|
| 35 |
+
name: str
|
| 36 |
+
city_name: str
|
| 37 |
+
country: str
|
| 38 |
+
country_code: str
|
| 39 |
+
continent: str
|
| 40 |
+
latitude: float
|
| 41 |
+
longitude: float
|
| 42 |
+
timezone: str
|
| 43 |
+
hub_score: float = 0.0
|
| 44 |
+
route_count: int = 0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# --- Flight segment ---
|
| 48 |
+
|
| 49 |
+
class FlightSegment(BaseModel):
|
| 50 |
+
airline_code: str
|
| 51 |
+
airline_name: str
|
| 52 |
+
flight_number: str
|
| 53 |
+
aircraft: str
|
| 54 |
+
origin: str
|
| 55 |
+
origin_city: str
|
| 56 |
+
destination: str
|
| 57 |
+
destination_city: str
|
| 58 |
+
departure: datetime
|
| 59 |
+
arrival: datetime
|
| 60 |
+
duration_minutes: int
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# --- Flight offer (may have multiple segments) ---
|
| 64 |
+
|
| 65 |
+
class FlightOffer(BaseModel):
|
| 66 |
+
id: str
|
| 67 |
+
segments: list[FlightSegment]
|
| 68 |
+
total_duration_minutes: int
|
| 69 |
+
stops: int
|
| 70 |
+
price_usd: float
|
| 71 |
+
cabin_class: CabinClass
|
| 72 |
+
origin: str
|
| 73 |
+
destination: str
|
| 74 |
+
departure: datetime
|
| 75 |
+
arrival: datetime
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# --- Search request ---
|
| 79 |
+
|
| 80 |
+
class SearchLeg(BaseModel):
|
| 81 |
+
origin: str = Field(..., min_length=3, max_length=3, description="IATA code")
|
| 82 |
+
destination: str = Field(..., min_length=3, max_length=3, description="IATA code")
|
| 83 |
+
date: date
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Passengers(BaseModel):
|
| 87 |
+
adults: int = Field(1, ge=1, le=9)
|
| 88 |
+
children: int = Field(0, ge=0, le=9)
|
| 89 |
+
infants: int = Field(0, ge=0, le=4)
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def total(self) -> int:
|
| 93 |
+
return self.adults + self.children + self.infants
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class Filters(BaseModel):
|
| 97 |
+
max_stops: Optional[int] = None
|
| 98 |
+
max_price: Optional[float] = None
|
| 99 |
+
max_duration_minutes: Optional[int] = None
|
| 100 |
+
airlines: Optional[list[str]] = None # IATA codes to include
|
| 101 |
+
departure_time_min: Optional[str] = None # "06:00"
|
| 102 |
+
departure_time_max: Optional[str] = None # "18:00"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class SearchRequest(BaseModel):
|
| 106 |
+
trip_type: TripType = TripType.round_trip
|
| 107 |
+
legs: list[SearchLeg] = Field(..., min_length=1, max_length=6)
|
| 108 |
+
passengers: Passengers = Passengers()
|
| 109 |
+
cabin_class: CabinClass = CabinClass.economy
|
| 110 |
+
filters: Filters = Filters()
|
| 111 |
+
sort_by: SortBy = SortBy.best
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class SearchResponse(BaseModel):
|
| 115 |
+
outbound_flights: list[FlightOffer]
|
| 116 |
+
return_flights: list[FlightOffer] = []
|
| 117 |
+
search_id: str
|
| 118 |
+
origin: str
|
| 119 |
+
destination: str
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# --- Calendar ---
|
| 123 |
+
|
| 124 |
+
class CalendarDay(BaseModel):
|
| 125 |
+
date: date
|
| 126 |
+
cheapest_price: Optional[float] = None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class CalendarResponse(BaseModel):
|
| 130 |
+
origin: str
|
| 131 |
+
destination: str
|
| 132 |
+
year: int
|
| 133 |
+
month: int
|
| 134 |
+
days: list[CalendarDay]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# --- Autocomplete ---
|
| 138 |
+
|
| 139 |
+
class AutocompleteResult(BaseModel):
|
| 140 |
+
iata: str
|
| 141 |
+
name: str
|
| 142 |
+
city_name: str
|
| 143 |
+
country: str
|
| 144 |
+
display_name: str
|
| 145 |
+
hub_score: float = 0.0
|
backend/price_engine.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Price formula: base + 8 multipliers.
|
| 2 |
+
|
| 3 |
+
base_usd = 40 + (distance_km * 0.08)
|
| 4 |
+
final = base * class * day_of_week * time_of_day * season * demand * advance * jitter
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import random
|
| 10 |
+
from datetime import date, datetime, timedelta
|
| 11 |
+
|
| 12 |
+
from .config import (
|
| 13 |
+
ADVANCE_MULTIPLIERS,
|
| 14 |
+
BASE_FIXED_USD,
|
| 15 |
+
BASE_PER_KM_USD,
|
| 16 |
+
CLASS_MULTIPLIERS,
|
| 17 |
+
DAY_MULTIPLIERS,
|
| 18 |
+
EU_CONTINENTS,
|
| 19 |
+
EU_SUMMER_BONUS,
|
| 20 |
+
EU_SUMMER_MONTHS,
|
| 21 |
+
HIGH_COMPETITION_DISCOUNT,
|
| 22 |
+
JITTER_RANGE,
|
| 23 |
+
MONOPOLY_ROUTE_BONUS,
|
| 24 |
+
SEASON_MULTIPLIERS,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def compute_price(
|
| 29 |
+
distance_km: int,
|
| 30 |
+
cabin_class: str,
|
| 31 |
+
departure_date: date,
|
| 32 |
+
departure_hour: int,
|
| 33 |
+
num_carriers: int,
|
| 34 |
+
dest_continent: str,
|
| 35 |
+
rng: random.Random,
|
| 36 |
+
booking_date: date | None = None,
|
| 37 |
+
) -> float:
|
| 38 |
+
"""Compute flight price using the full pricing formula."""
|
| 39 |
+
base = BASE_FIXED_USD + (distance_km * BASE_PER_KM_USD)
|
| 40 |
+
|
| 41 |
+
# 1. Cabin class
|
| 42 |
+
class_mult = CLASS_MULTIPLIERS.get(cabin_class, 1.0)
|
| 43 |
+
|
| 44 |
+
# 2. Day of week
|
| 45 |
+
day_mult = DAY_MULTIPLIERS.get(departure_date.weekday(), 1.0)
|
| 46 |
+
|
| 47 |
+
# 3. Time of day
|
| 48 |
+
if 6 <= departure_hour <= 8:
|
| 49 |
+
time_mult = 1.10 # Morning peak
|
| 50 |
+
elif 16 <= departure_hour <= 19:
|
| 51 |
+
time_mult = 1.15 # Evening peak
|
| 52 |
+
elif departure_hour >= 22 or departure_hour <= 5:
|
| 53 |
+
time_mult = 0.85 # Red-eye discount
|
| 54 |
+
else:
|
| 55 |
+
time_mult = 1.00
|
| 56 |
+
|
| 57 |
+
# 4. Season
|
| 58 |
+
season_mult = SEASON_MULTIPLIERS.get(departure_date.month, 1.0)
|
| 59 |
+
# EU summer bonus
|
| 60 |
+
if dest_continent in EU_CONTINENTS and departure_date.month in EU_SUMMER_MONTHS:
|
| 61 |
+
season_mult += EU_SUMMER_BONUS
|
| 62 |
+
|
| 63 |
+
# 5. Demand (based on competition)
|
| 64 |
+
if num_carriers == 1:
|
| 65 |
+
demand_mult = 1.0 + MONOPOLY_ROUTE_BONUS
|
| 66 |
+
elif num_carriers >= 4:
|
| 67 |
+
demand_mult = 1.0 - HIGH_COMPETITION_DISCOUNT
|
| 68 |
+
else:
|
| 69 |
+
demand_mult = 1.0
|
| 70 |
+
|
| 71 |
+
# 6. Advance booking
|
| 72 |
+
if booking_date is None:
|
| 73 |
+
booking_date = date.today()
|
| 74 |
+
days_advance = (departure_date - booking_date).days
|
| 75 |
+
if days_advance < 0:
|
| 76 |
+
days_advance = 0
|
| 77 |
+
advance_mult = 1.0
|
| 78 |
+
for threshold, mult in ADVANCE_MULTIPLIERS:
|
| 79 |
+
if days_advance <= threshold:
|
| 80 |
+
advance_mult = mult
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
# 7. Jitter (seeded)
|
| 84 |
+
jitter = 1.0 + rng.uniform(-JITTER_RANGE, JITTER_RANGE)
|
| 85 |
+
|
| 86 |
+
price = base * class_mult * day_mult * time_mult * season_mult * demand_mult * advance_mult * jitter
|
| 87 |
+
|
| 88 |
+
# Round to nearest dollar, minimum $25
|
| 89 |
+
return max(25.0, round(price, 0))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def compute_calendar_price(
|
| 93 |
+
distance_km: int,
|
| 94 |
+
cabin_class: str,
|
| 95 |
+
target_date: date,
|
| 96 |
+
num_carriers: int,
|
| 97 |
+
dest_continent: str,
|
| 98 |
+
rng: random.Random,
|
| 99 |
+
) -> float:
|
| 100 |
+
"""Compute cheapest flight price for a given date (for calendar view).
|
| 101 |
+
|
| 102 |
+
Uses noon departure and 14-day advance booking as baseline.
|
| 103 |
+
"""
|
| 104 |
+
return compute_price(
|
| 105 |
+
distance_km=distance_km,
|
| 106 |
+
cabin_class=cabin_class,
|
| 107 |
+
departure_date=target_date,
|
| 108 |
+
departure_hour=12,
|
| 109 |
+
num_carriers=num_carriers,
|
| 110 |
+
dest_continent=dest_continent,
|
| 111 |
+
rng=rng,
|
| 112 |
+
booking_date=target_date - timedelta(days=14),
|
| 113 |
+
)
|
backend/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn[standard]>=0.27.0
|
| 3 |
+
pydantic>=2.6.0
|
backend/route_finder.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Direct + 1-stop + 2-stop route discovery."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from .config import MAX_1STOP_DISTANCE_RATIO, MAX_2STOP_DISTANCE_RATIO
|
| 8 |
+
from .data_loader import Route, RouteGraph
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class RoutePlan:
|
| 13 |
+
"""A planned route from origin to destination (may have multiple legs)."""
|
| 14 |
+
legs: list[Route] # Each leg has origin implicitly from position
|
| 15 |
+
waypoints: list[str] # [origin, hub1, ..., destination]
|
| 16 |
+
total_distance_km: int
|
| 17 |
+
total_duration_min: int
|
| 18 |
+
|
| 19 |
+
@property
|
| 20 |
+
def stops(self) -> int:
|
| 21 |
+
return len(self.legs) - 1
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def find_routes(
|
| 25 |
+
graph: RouteGraph,
|
| 26 |
+
origin: str,
|
| 27 |
+
destination: str,
|
| 28 |
+
hub_iatas: list[str],
|
| 29 |
+
max_stops: int | None = None,
|
| 30 |
+
) -> list[RoutePlan]:
|
| 31 |
+
"""Find all route plans from origin to destination.
|
| 32 |
+
|
| 33 |
+
Returns direct, 1-stop, and 2-stop routes.
|
| 34 |
+
"""
|
| 35 |
+
results: list[RoutePlan] = []
|
| 36 |
+
|
| 37 |
+
if max_stops is not None and max_stops < 0:
|
| 38 |
+
return results
|
| 39 |
+
|
| 40 |
+
# Direct route
|
| 41 |
+
direct = graph.get_direct_route(origin, destination)
|
| 42 |
+
if direct:
|
| 43 |
+
results.append(RoutePlan(
|
| 44 |
+
legs=[direct],
|
| 45 |
+
waypoints=[origin, destination],
|
| 46 |
+
total_distance_km=direct.distance_km,
|
| 47 |
+
total_duration_min=direct.duration_min,
|
| 48 |
+
))
|
| 49 |
+
|
| 50 |
+
if max_stops is not None and max_stops == 0:
|
| 51 |
+
return results
|
| 52 |
+
|
| 53 |
+
# 1-stop routes through hubs
|
| 54 |
+
direct_distance = direct.distance_km if direct else _estimate_distance(graph, origin, destination)
|
| 55 |
+
if direct_distance is None:
|
| 56 |
+
return results
|
| 57 |
+
|
| 58 |
+
origin_routes = graph.get_outbound_routes(origin)
|
| 59 |
+
|
| 60 |
+
for hub in hub_iatas:
|
| 61 |
+
if hub == origin or hub == destination:
|
| 62 |
+
continue
|
| 63 |
+
|
| 64 |
+
leg1 = origin_routes.get(hub)
|
| 65 |
+
if not leg1:
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
leg2 = graph.get_direct_route(hub, destination)
|
| 69 |
+
if not leg2:
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
total_dist = leg1.distance_km + leg2.distance_km
|
| 73 |
+
if total_dist > direct_distance * MAX_1STOP_DISTANCE_RATIO:
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
total_dur = leg1.duration_min + leg2.duration_min + 90 # +90 min layover estimate
|
| 77 |
+
results.append(RoutePlan(
|
| 78 |
+
legs=[leg1, leg2],
|
| 79 |
+
waypoints=[origin, hub, destination],
|
| 80 |
+
total_distance_km=total_dist,
|
| 81 |
+
total_duration_min=total_dur,
|
| 82 |
+
))
|
| 83 |
+
|
| 84 |
+
if max_stops is not None and max_stops <= 1:
|
| 85 |
+
return results
|
| 86 |
+
|
| 87 |
+
# 2-stop routes through pairs of hubs (limit to top hubs for performance)
|
| 88 |
+
top_hubs = hub_iatas[:60]
|
| 89 |
+
for hub1 in top_hubs:
|
| 90 |
+
if hub1 == origin or hub1 == destination:
|
| 91 |
+
continue
|
| 92 |
+
leg1 = origin_routes.get(hub1)
|
| 93 |
+
if not leg1:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
hub1_routes = graph.get_outbound_routes(hub1)
|
| 97 |
+
for hub2 in top_hubs:
|
| 98 |
+
if hub2 == origin or hub2 == destination or hub2 == hub1:
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
leg2 = hub1_routes.get(hub2)
|
| 102 |
+
if not leg2:
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
leg3 = graph.get_direct_route(hub2, destination)
|
| 106 |
+
if not leg3:
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
total_dist = leg1.distance_km + leg2.distance_km + leg3.distance_km
|
| 110 |
+
if total_dist > direct_distance * MAX_2STOP_DISTANCE_RATIO:
|
| 111 |
+
continue
|
| 112 |
+
|
| 113 |
+
total_dur = (leg1.duration_min + leg2.duration_min + leg3.duration_min
|
| 114 |
+
+ 90 + 90) # Two layovers
|
| 115 |
+
results.append(RoutePlan(
|
| 116 |
+
legs=[leg1, leg2, leg3],
|
| 117 |
+
waypoints=[origin, hub1, hub2, destination],
|
| 118 |
+
total_distance_km=total_dist,
|
| 119 |
+
total_duration_min=total_dur,
|
| 120 |
+
))
|
| 121 |
+
|
| 122 |
+
return results
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _estimate_distance(graph: RouteGraph, origin: str, destination: str) -> int | None:
|
| 126 |
+
"""Estimate great-circle distance between two airports using coordinates."""
|
| 127 |
+
import math
|
| 128 |
+
|
| 129 |
+
o = graph.airports.get(origin)
|
| 130 |
+
d = graph.airports.get(destination)
|
| 131 |
+
if not o or not d:
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
lat1, lon1 = math.radians(o.latitude), math.radians(o.longitude)
|
| 135 |
+
lat2, lon2 = math.radians(d.latitude), math.radians(d.longitude)
|
| 136 |
+
|
| 137 |
+
dlat = lat2 - lat1
|
| 138 |
+
dlon = lon2 - lon1
|
| 139 |
+
a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
|
| 140 |
+
c = 2 * math.asin(math.sqrt(a))
|
| 141 |
+
return int(c * 6371) # Earth radius in km
|
backend/seed_utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic seeding utilities.
|
| 2 |
+
|
| 3 |
+
Same search parameters always produce the same flights and prices.
|
| 4 |
+
Uses SHA-256 hash of search params → integer seed for random.Random.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import hashlib
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_seed(*parts: str | int | float) -> int:
|
| 12 |
+
"""Create a deterministic seed from arbitrary parts."""
|
| 13 |
+
key = "|".join(str(p) for p in parts)
|
| 14 |
+
h = hashlib.sha256(key.encode()).hexdigest()
|
| 15 |
+
return int(h[:16], 16)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def seeded_random(*parts: str | int | float) -> random.Random:
|
| 19 |
+
"""Return a seeded Random instance for the given search params."""
|
| 20 |
+
return random.Random(make_seed(*parts))
|
description.md
ADDED
|
@@ -0,0 +1,1122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ECMoE — Method and Experiment Description
|
| 2 |
+
|
| 3 |
+
## 1. Problem Statement
|
| 4 |
+
|
| 5 |
+
In Mixture-of-Experts (MoE) models with expert parallelism, each token's hidden state must be communicated between GPUs twice per MoE layer:
|
| 6 |
+
|
| 7 |
+
1. **Dispatch (all-to-all):** The hidden state is sent from the token's source GPU to the GPU hosting its assigned expert(s).
|
| 8 |
+
2. **Gather (all-to-all):** The expert output is sent back to the source GPU.
|
| 9 |
+
|
| 10 |
+
For a model like Qwen3-30B-A3B with `hidden_dim=2048` and 48 MoE layers, each token requires transmitting `2 × 48 × 2048 × 2 bytes = 384 KB` of data per forward pass (in BF16). At scale, this communication dominates inference latency.
|
| 11 |
+
|
| 12 |
+
This project investigates methods to **compress these hidden-state vectors** before transmission, reducing communication volume while preserving model quality.
|
| 13 |
+
|
| 14 |
+
### Training paradigms
|
| 15 |
+
|
| 16 |
+
This project uses two training paradigms:
|
| 17 |
+
|
| 18 |
+
**Offline (Tasks 2–4):** Compressors are trained on **cached hidden states**, not end-to-end through the LLM:
|
| 19 |
+
|
| 20 |
+
1. **Capture:** Run the unmodified LLM on calibration data and cache MoE layer inputs/outputs to disk.
|
| 21 |
+
2. **Train:** Train each compressor/decompressor pair independently on the cached data, minimizing a local reconstruction loss. No gradients flow through the LLM.
|
| 22 |
+
3. **Evaluate:** Insert trained compressors into the live model via forward hooks and measure perplexity.
|
| 23 |
+
|
| 24 |
+
Each pair is trained in isolation — no joint optimization across layers, no end-to-end backpropagation. This is cheap (minutes per layer) but means compressors cannot adapt to how errors compound across layers.
|
| 25 |
+
|
| 26 |
+
**End-to-end (Task 5):** Compressors are trained **through the live LLM** using the language modeling objective:
|
| 27 |
+
|
| 28 |
+
1. **Insert:** Register per-layer compressor/decompressor pairs as forward pre-hooks on each MoE layer.
|
| 29 |
+
2. **Train:** Run standard next-token prediction. The LLM weights are frozen; only compressor parameters receive gradients. Gradients flow through the entire frozen LLM.
|
| 30 |
+
3. **Evaluate:** Same hook-based perplexity evaluation as offline methods.
|
| 31 |
+
|
| 32 |
+
All 48 compressors are optimized jointly through a single global loss. This allows the system to learn how compression errors at early layers affect all downstream layers.
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## 2. Model Specification
|
| 37 |
+
|
| 38 |
+
| Property | Value |
|
| 39 |
+
|---|---|
|
| 40 |
+
| Architecture | Qwen3-30B-A3B-Instruct-2507 |
|
| 41 |
+
| Total parameters | 30.53B |
|
| 42 |
+
| Activated parameters | 3.35B |
|
| 43 |
+
| Hidden dimension | 2048 |
|
| 44 |
+
| Number of layers | 48 (all MoE) |
|
| 45 |
+
| Number of experts | 128 per layer |
|
| 46 |
+
| Top-k routing | 8 experts per token |
|
| 47 |
+
| Attention heads | 32 (Q), 4 (KV) |
|
| 48 |
+
| Head dimension | 128 |
|
| 49 |
+
| MoE expert FFN intermediate size | 768 |
|
| 50 |
+
| Vocabulary size | 151,936 |
|
| 51 |
+
|
| 52 |
+
All tasks use the same model variant and precision:
|
| 53 |
+
|
| 54 |
+
| Variant | Used in | Loading | VRAM |
|
| 55 |
+
|---|---|---|---|
|
| 56 |
+
| Qwen3-30B-A3B-Instruct-2507 | All tasks (1–5) | Full BF16 | ~60 GB |
|
| 57 |
+
|
| 58 |
+
**Tasks 1–4:** Single GPU (`device="cuda:0"`). The ~60 GB model fits on one H100 80 GB with headroom for inference activations. Using single-GPU avoids the overhead of cross-GPU communication from `device_map="auto"`.
|
| 59 |
+
|
| 60 |
+
**Task 5:** Multi-GPU via `device_map="auto"` across 4 GPUs. Backpropagation through the frozen model during end-to-end training requires additional VRAM for activations and gradient checkpoints that exceed single-GPU capacity.
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
## 3. Data Collection
|
| 65 |
+
|
| 66 |
+
### 3.1 Calibration Data
|
| 67 |
+
|
| 68 |
+
- **Dataset:** allenai/Dolci-Instruct-SFT (train split)
|
| 69 |
+
- **Format:** Chat-formatted instruction data, tokenized via `tokenizer.apply_chat_template()`
|
| 70 |
+
- **Sequences:** Up to 256 samples, each tokenized independently (one conversation = one sequence)
|
| 71 |
+
- **Max length:** 2048 tokens per sequence (configurable via `--max-length`)
|
| 72 |
+
- **SFT mode:** Labels mask non-assistant tokens with -100; perplexity computed on responses only
|
| 73 |
+
- **Response-only collection:** By default, only assistant-response tokens are captured
|
| 74 |
+
(positions where `labels != -100`). This ensures offline compressor training (Tasks 2–4)
|
| 75 |
+
trains on the same token distribution that PPL evaluation measures. Use `--no-response-only`
|
| 76 |
+
for legacy all-token collection.
|
| 77 |
+
- **Total tokens collected:** 100,000 per MoE layer (response tokens only by default)
|
| 78 |
+
|
| 79 |
+
### 3.2 Hidden State Capture
|
| 80 |
+
|
| 81 |
+
PyTorch forward hooks are registered on each MoE module:
|
| 82 |
+
- **Pre-forward hook** captures dispatch states (MoE layer inputs)
|
| 83 |
+
- **Post-forward hook** captures gather states (MoE layer outputs)
|
| 84 |
+
|
| 85 |
+
**Token filtering:** `MoEHiddenStateCollector` supports a per-sequence boolean mask
|
| 86 |
+
(`set_token_mask(mask)`). When `response_only=True` (default), the mask is derived from
|
| 87 |
+
`labels != -100` before each forward pass. The same mask is applied to all 48 MoE layers
|
| 88 |
+
within a sequence, preserving token alignment across layers. Positions where the mask is
|
| 89 |
+
`False` (system, user, template markup, padding) are not collected.
|
| 90 |
+
|
| 91 |
+
Each captured tensor has shape `[N, 2048]` where N = number of response tokens (or all
|
| 92 |
+
tokens if `response_only=False`). States are stored in the model's native dtype (`bfloat16`)
|
| 93 |
+
on CPU.
|
| 94 |
+
|
| 95 |
+
**Implementation:** `MoEHiddenStateCollector` class in `src/model_utils.py`.
|
| 96 |
+
|
| 97 |
+
### 3.3 Storage
|
| 98 |
+
|
| 99 |
+
```
|
| 100 |
+
data/hidden_states/
|
| 101 |
+
├── dispatch_states.pt # dict {layer_name: tensor [100000, 2048]}
|
| 102 |
+
├── gather_states.pt # dict {layer_name: tensor [100000, 2048]}
|
| 103 |
+
└── metadata.json # model name, dims, token count, layer names
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
Total size: ~37 GB (18.5 GB dispatch + 18.5 GB gather, bfloat16 = 2 bytes/value).
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 4. Evaluation Methodology
|
| 111 |
+
|
| 112 |
+
### 4.1 Reconstruction Metrics (Offline)
|
| 113 |
+
|
| 114 |
+
Computed on cached hidden states without running the full model:
|
| 115 |
+
|
| 116 |
+
| Metric | Formula | Notes |
|
| 117 |
+
|---|---|---|
|
| 118 |
+
| MSE | `mean((x - x')²)` | Mean squared error |
|
| 119 |
+
| Cosine Similarity | `mean(cos(x, x'))` | Per-token, averaged |
|
| 120 |
+
| Relative Error | `mean(‖x - x'‖₂ / ‖x‖₂)` | Per-token L2 relative error |
|
| 121 |
+
| SNR (dB) | `10 · log₁₀(signal_power / noise_power)` | Signal-to-noise ratio |
|
| 122 |
+
|
| 123 |
+
**Implementation:** `src/metrics.py`
|
| 124 |
+
|
| 125 |
+
### 4.2 End-to-End Perplexity (Online)
|
| 126 |
+
|
| 127 |
+
The true impact of compression is measured by evaluating cross-entropy perplexity on allenai/Dolci-Instruct-SFT (the same dataset used for calibration/training) with compression hooks active:
|
| 128 |
+
|
| 129 |
+
- **Dispatch compression:** A pre-forward hook on each MoE block applies `compress → decompress` to the input hidden states before they enter the block.
|
| 130 |
+
- **Evaluation:** 50,000 sequences, max length 2048 tokens.
|
| 131 |
+
- **SFT mode:** Perplexity is computed on assistant response tokens only. Non-response tokens
|
| 132 |
+
(system, user, template markup) are labeled with -100 and excluded from the loss.
|
| 133 |
+
This measures the model's ability to generate correct responses, not to predict prompt tokens.
|
| 134 |
+
|
| 135 |
+
**Caveat:** This simulation also affects the router's input. In real expert parallelism, the router runs on the original hidden state at the source node. Our simulation gives a **conservative lower bound** — the true impact would be smaller.
|
| 136 |
+
|
| 137 |
+
**Implementation:** `evaluate_perplexity_with_compression()`, `evaluate_perplexity_with_perlayer_compression()`, `evaluate_perplexity_with_stale_compression()` in `src/model_utils.py`.
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## 5. Method Descriptions
|
| 142 |
+
|
| 143 |
+
### 5.1 Quantization Baseline (Task 2)
|
| 144 |
+
|
| 145 |
+
**Idea:** Reduce the bit width of hidden-state elements from BF16 (16 bits) to INT8/INT4/INT2.
|
| 146 |
+
|
| 147 |
+
**Symmetric (absmax) quantization:**
|
| 148 |
+
```
|
| 149 |
+
scale = max(|x|) / (2^(bits-1) - 1) # per-token
|
| 150 |
+
x_q = round(x / scale) # quantize
|
| 151 |
+
x' = x_q * scale # dequantize
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
**Asymmetric (zero-point) quantization:**
|
| 155 |
+
```
|
| 156 |
+
scale = (max(x) - min(x)) / (2^bits - 1)
|
| 157 |
+
zero_point = round(-min(x) / scale)
|
| 158 |
+
x_q = round(x / scale + zero_point)
|
| 159 |
+
x' = (x_q - zero_point) * scale
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
**Compression ratios:**
|
| 163 |
+
|
| 164 |
+
| Bits | Effective Ratio | Bytes/token (hidden_dim=2048) |
|
| 165 |
+
|---|---|---|
|
| 166 |
+
| INT8 (absmax) | ~2.0x | 2050 (2048 + 2 for scale) |
|
| 167 |
+
| INT4 (absmax) | ~4.0x | 1026 (1024 + 2 for scale) |
|
| 168 |
+
| INT2 (absmax) | ~8.0x | 514 (512 + 2 for scale) |
|
| 169 |
+
|
| 170 |
+
**Additional parameters:** 0 (quantization is parameter-free).
|
| 171 |
+
|
| 172 |
+
**Implementation:** `src/run_quantization.py`
|
| 173 |
+
|
| 174 |
+
---
|
| 175 |
+
|
| 176 |
+
### 5.2 Shared Neural Compressor (Task 3)
|
| 177 |
+
|
| 178 |
+
**Idea:** Train a single-layer linear autoencoder shared across all 48 MoE layers.
|
| 179 |
+
|
| 180 |
+
**Architecture:**
|
| 181 |
+
```
|
| 182 |
+
Compressor: Linear(2048, bottleneck_dim) + bias
|
| 183 |
+
Decompressor: Linear(bottleneck_dim, 2048) + bias
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
One compressor-decompressor pair is shared across all layers. Training data pools dispatch states from all 48 layers. Training is offline: the compressor minimizes reconstruction loss on cached hidden states, with no gradients flowing through the LLM.
|
| 187 |
+
|
| 188 |
+
**Compression ratios:** `hidden_dim / bottleneck_dim` = {2x, 4x, 8x, 16x} corresponding to `bottleneck_dim` = {1024, 512, 256, 128}.
|
| 189 |
+
|
| 190 |
+
**Training hyperparameters:**
|
| 191 |
+
|
| 192 |
+
| Parameter | Value |
|
| 193 |
+
|---|---|
|
| 194 |
+
| Optimizer | Adam |
|
| 195 |
+
| Learning rate | 1e-3 |
|
| 196 |
+
| LR schedule | Cosine annealing (T_max = epochs) |
|
| 197 |
+
| Max epochs | 50 |
|
| 198 |
+
| Batch size | 2048 |
|
| 199 |
+
| Early stopping patience | 8 epochs |
|
| 200 |
+
| Validation fraction | 10% |
|
| 201 |
+
| Loss function | MSE + 0.1 × (1 - cosine_similarity) |
|
| 202 |
+
|
| 203 |
+
**Loss function:**
|
| 204 |
+
```
|
| 205 |
+
L = MSE(x', x) + λ · (1 - mean(cos_sim(x', x)))
|
| 206 |
+
```
|
| 207 |
+
where `λ = 0.1` (cosine_weight). The cosine term encourages preserving direction, not just magnitude.
|
| 208 |
+
|
| 209 |
+
**Parameter count:**
|
| 210 |
+
```
|
| 211 |
+
params = (2048 × b + b) + (b × 2048 + 2048)
|
| 212 |
+
```
|
| 213 |
+
where `b` = bottleneck_dim.
|
| 214 |
+
|
| 215 |
+
| Ratio | Bottleneck | Parameters | % of Activated |
|
| 216 |
+
|---|---|---|---|
|
| 217 |
+
| 2x | 1024 | 4.20M | 0.125% |
|
| 218 |
+
| 4x | 512 | 2.10M | 0.063% |
|
| 219 |
+
| 8x | 256 | 1.05M | 0.031% |
|
| 220 |
+
| 16x | 128 | 0.53M | 0.016% |
|
| 221 |
+
|
| 222 |
+
**Implementation:** `src/run_neural_compressor.py`
|
| 223 |
+
|
| 224 |
+
---
|
| 225 |
+
|
| 226 |
+
### 5.3 Per-Layer Neural Compressor (Task 3b)
|
| 227 |
+
|
| 228 |
+
**Motivation:** Hidden state distributions vary dramatically across layers:
|
| 229 |
+
- Standard deviation: 0.16 (layer 0) → 1.21 (layer 47)
|
| 230 |
+
- Kurtosis: 3 (near-Gaussian, early layers) → 81,340 (extremely heavy-tailed, late layers)
|
| 231 |
+
|
| 232 |
+
A single shared compressor cannot adapt to this variation.
|
| 233 |
+
|
| 234 |
+
**Architecture:** Same `Compressor` + `Decompressor` structure, but **48 independent pairs** — one per MoE layer. Each layer's compressor is trained independently and only on that layer's cached dispatch data. There is no joint optimization across layers.
|
| 235 |
+
|
| 236 |
+
**Compression ratios:** Same as shared: {2x, 4x, 8x, 16x}.
|
| 237 |
+
|
| 238 |
+
**Training:** Same hyperparameters as shared (see Section 5.2). Each layer is trained independently on its own 100K token dispatch data (90% train / 10% val).
|
| 239 |
+
|
| 240 |
+
**Parameter count:**
|
| 241 |
+
```
|
| 242 |
+
params = 48 × (2048 × b + b + b × 2048 + 2048)
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
| Ratio | Bottleneck | Parameters | % of Activated |
|
| 246 |
+
|---|---|---|---|
|
| 247 |
+
| 2x | 1024 | 201.47M | 6.008% |
|
| 248 |
+
| 4x | 512 | 100.79M | 3.006% |
|
| 249 |
+
| 8x | 256 | 50.44M | 1.504% |
|
| 250 |
+
| 16x | 128 | 25.27M | 0.754% |
|
| 251 |
+
|
| 252 |
+
**Implementation:** `src/run_perlayer_compressor.py`
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
### 5.4 Stale-Conditioned Compressor (Tasks 4a/4b)
|
| 257 |
+
|
| 258 |
+
**Motivation:** Adjacent MoE layers process the same token, so their hidden states are correlated. A decompressor can exploit this by receiving a "stale" signal — the hidden state from a nearby layer that was already transmitted — as side information.
|
| 259 |
+
|
| 260 |
+
**Reference layer grouping (stride=12):**
|
| 261 |
+
- Reference layers: {0, 12, 24, 36} (4 layers)
|
| 262 |
+
- Layer 1–11 → stale from layer 0
|
| 263 |
+
- Layer 13–23 → stale from layer 12
|
| 264 |
+
- Layer 25–35 → stale from layer 24
|
| 265 |
+
- Layer 37–47 → stale from layer 36
|
| 266 |
+
|
| 267 |
+
**Architecture:**
|
| 268 |
+
- **Reference layers** use standard per-layer `Compressor` + `Decompressor` (no stale signal).
|
| 269 |
+
- **Non-reference layers** use `Compressor` + `StaleDecompressor`:
|
| 270 |
+
|
| 271 |
+
```
|
| 272 |
+
Compressor: Linear(2048, bottleneck_dim) + bias
|
| 273 |
+
StaleDecompressor: Linear(bottleneck_dim + stale_dim, 2048) + bias
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
The decompressor receives `cat(compressed_current, stale_signal)` as input.
|
| 277 |
+
|
| 278 |
+
**Two stale modes:**
|
| 279 |
+
|
| 280 |
+
| Mode | Task | Stale signal | StaleDecompressor input dim |
|
| 281 |
+
|---|---|---|---|
|
| 282 |
+
| Compressed (4a) | `--stale-mode compressed` | Compressed ref layer input (via ref's compressor) | `bottleneck_dim + bottleneck_dim` |
|
| 283 |
+
| Uncompressed (4b) | `--stale-mode uncompressed` | Raw ref layer input (full hidden dim) | `bottleneck_dim + 2048` |
|
| 284 |
+
|
| 285 |
+
**Training:**
|
| 286 |
+
1. **Phase 1:** Train reference layer compressors independently (standard per-layer autoencoder, same hyperparameters as Section 5.2).
|
| 287 |
+
2. **Phase 2:** Train non-reference layer compressors independently. For each non-ref layer:
|
| 288 |
+
- Current data: that layer's cached dispatch states
|
| 289 |
+
- Stale data: the reference layer's cached dispatch states (compressed or raw, depending on mode)
|
| 290 |
+
- The stale signal is **pre-computed and frozen** — the reference layer's compressor is not jointly optimized with non-reference layers
|
| 291 |
+
- Token alignment is guaranteed: `dispatch[layer_0][i]` and `dispatch[layer_5][i]` correspond to the same token
|
| 292 |
+
|
| 293 |
+
As with all neural methods in this project, training is offline on cached hidden states. No gradients flow through the LLM, and each layer's compressor is trained in isolation.
|
| 294 |
+
|
| 295 |
+
**Stale-conditioned training loss:** Same as Section 5.2 (`MSE + 0.1 × (1 - cos_sim)`), but the decompressor receives the concatenated input.
|
| 296 |
+
|
| 297 |
+
**Parameter count:**
|
| 298 |
+
|
| 299 |
+
For compressed stale (`stale_dim = bottleneck_dim`):
|
| 300 |
+
```
|
| 301 |
+
ref_pair = (2048 × b + b) + (b × 2048 + 2048)
|
| 302 |
+
nonref_pair = (2048 × b + b) + ((b + b) × 2048 + 2048)
|
| 303 |
+
total = 4 × ref_pair + 44 × nonref_pair
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
For uncompressed stale (`stale_dim = 2048`):
|
| 307 |
+
```
|
| 308 |
+
ref_pair = (2048 × b + b) + (b × 2048 + 2048)
|
| 309 |
+
nonref_pair = (2048 × b + b) + ((b + 2048) × 2048 + 2048)
|
| 310 |
+
total = 4 × ref_pair + 44 × nonref_pair
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
| Mode | Ratio | Bottleneck | Stale dim | Parameters | % of Activated |
|
| 314 |
+
|---|---|---|---|---|---|
|
| 315 |
+
| Compressed | 2x | 1024 | 1024 | 293.75M | 8.760% |
|
| 316 |
+
| Compressed | 4x | 512 | 512 | 146.92M | 4.382% |
|
| 317 |
+
| Compressed | 8x | 256 | 256 | 73.51M | 2.192% |
|
| 318 |
+
| Compressed | 16x | 128 | 128 | 36.80M | 1.098% |
|
| 319 |
+
| Uncompressed | 2x | 1024 | 2048 | 386.02M | 11.512% |
|
| 320 |
+
| Uncompressed | 4x | 512 | 2048 | 285.34M | 8.509% |
|
| 321 |
+
| Uncompressed | 8x | 256 | 2048 | 234.99M | 7.008% |
|
| 322 |
+
| Uncompressed | 16x | 128 | 2048 | 209.82M | 6.257% |
|
| 323 |
+
|
| 324 |
+
Note: The uncompressed stale method's parameter count does not scale down as aggressively because the `StaleDecompressor` input always includes the full 2048-dim stale signal, making the `(2048 × 2048)` weight block dominant.
|
| 325 |
+
|
| 326 |
+
**Perplexity evaluation with stale hooks:** During forward pass, a shared `stale_cache` dictionary stores reference layer inputs. PyTorch processes layers 0→47 sequentially, so layer 0's pre-hook fires before layer 1's, guaranteeing the stale cache is populated in time.
|
| 327 |
+
|
| 328 |
+
**Implementation:** `src/run_stale_compressor.py`, `evaluate_perplexity_with_stale_compression()` in `src/model_utils.py`.
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
### 5.5 End-to-End Per-Layer Compressor (Tasks 5a/5b)
|
| 333 |
+
|
| 334 |
+
**Motivation:** All offline methods (Tasks 3–4) share a fundamental limitation: each compressor is trained to minimize *local* reconstruction error in isolation. It cannot account for how its errors compound through downstream layers during a full forward pass. Additionally, the stale signal used during offline training is the *unperturbed* reference layer input, but during inference the reference layer itself is compressed, creating a train-inference mismatch.
|
| 335 |
+
|
| 336 |
+
End-to-end training addresses both issues by optimizing compressors through the full LLM forward pass using the language modeling objective.
|
| 337 |
+
|
| 338 |
+
**Architecture:** Same `Compressor` + `Decompressor` (5a) or `Compressor` + `StaleDecompressor` (5b) structure as Tasks 3b/4b. The compressor modules are identical — only the training objective differs.
|
| 339 |
+
|
| 340 |
+
**Training paradigm:**
|
| 341 |
+
1. Load the LLM in full BF16 across 4 GPUs. Freeze all LLM weights.
|
| 342 |
+
2. Insert per-layer compressor/decompressor pairs as forward pre-hooks on each MoE layer. Each pair is placed on the same GPU as its MoE layer.
|
| 343 |
+
3. Run standard next-token prediction on training data. Only compressor/decompressor parameters receive gradients.
|
| 344 |
+
4. Gradients flow backward through the entire frozen LLM, from the cross-entropy loss at the output back through all 48 layers to every compressor.
|
| 345 |
+
|
| 346 |
+
**Key difference from offline: joint optimization.** All 48 compressors share a single loss function (cross-entropy). Layer 0's compressor receives gradient signal about how its reconstruction error affects layers 1–47. The system implicitly learns to allocate more fidelity to layers where errors are most harmful to the final prediction.
|
| 347 |
+
|
| 348 |
+
**Stale signal gradient flow (5b):** Unlike offline Task 4b where the stale signal is pre-computed and frozen, end-to-end training does **not** detach the stale signal. Gradients flow through the stale path:
|
| 349 |
+
- A non-reference layer's decompressor receives `cat(compressed_current, stale)` where `stale` is the raw input to the reference layer
|
| 350 |
+
- During backward, gradients flow from the non-ref layer through `stale` to the reference layer's input, and further back to earlier layers
|
| 351 |
+
- This means reference layers' compressors are optimized not just for their own reconstruction, but also for how their inputs serve as stale side information for all downstream non-reference layers
|
| 352 |
+
- This eliminates the train-inference mismatch: during training, the stale signal already reflects upstream compression artifacts
|
| 353 |
+
|
| 354 |
+
**Near-identity initialization:**
|
| 355 |
+
- Compressor `W_c`: first `bottleneck_dim` rows of the identity matrix
|
| 356 |
+
- Decompressor `W_d`: first `bottleneck_dim` columns of the identity matrix
|
| 357 |
+
- Composition `W_d @ W_c ≈ I` (projects to first `b` dimensions and reconstructs)
|
| 358 |
+
- This ensures the initial forward pass is close to uncompressed, avoiding catastrophic initial loss from random projections. The optimizer then refines from this starting point.
|
| 359 |
+
|
| 360 |
+
**Model and data:**
|
| 361 |
+
- **Model:** Qwen/Qwen3-30B-A3B-Instruct-2507 (full BF16, same as all tasks)
|
| 362 |
+
- **Training data:** allenai/Dolci-Instruct-SFT, 500K sequences (HF) / 100K sequences (Megatron) sampled from train split,
|
| 363 |
+
max_length=2048 tokens per sequence
|
| 364 |
+
- **SFT mode:** Each conversation is tokenized independently (one sample = one sequence).
|
| 365 |
+
Labels mask non-assistant tokens with -100; loss is computed on assistant responses only.
|
| 366 |
+
Data is loaded by sampling N sequences from the dataset (not by packing tokens).
|
| 367 |
+
- **Evaluation:** allenai/Dolci-Instruct-SFT (same dataset, response-only perplexity)
|
| 368 |
+
|
| 369 |
+
**Two modes:**
|
| 370 |
+
|
| 371 |
+
| Mode | Task | Stale signal | Decompressor |
|
| 372 |
+
|---|---|---|---|
|
| 373 |
+
| No stale (5a) | `--stale-mode none` | None | `Decompressor(bottleneck_dim, 2048)` |
|
| 374 |
+
| Uncompressed stale (5b) | `--stale-mode uncompressed` | Raw ref layer input | `StaleDecompressor(bottleneck_dim, 2048, 2048)` |
|
| 375 |
+
|
| 376 |
+
**Training hyperparameters:**
|
| 377 |
+
|
| 378 |
+
| Parameter | Value |
|
| 379 |
+
|---|---|
|
| 380 |
+
| Optimizer | AdamW |
|
| 381 |
+
| Learning rate | 1e-4 |
|
| 382 |
+
| Weight decay | 0.01 |
|
| 383 |
+
| LR schedule | Cosine with 10% linear warmup |
|
| 384 |
+
| Max epochs | 1 |
|
| 385 |
+
| Batch size | 2 (gradient accumulation: 8, effective: 16) |
|
| 386 |
+
| Gradient clipping | max_norm = 1.0 |
|
| 387 |
+
| Early stopping patience | 5 epochs |
|
| 388 |
+
| Validation interval | Every 2500 optimizer steps (HF) / 1000 (Megatron) (configurable via `--val-interval`) |
|
| 389 |
+
| Validation batch size | 8 (configurable via `--val-batch-size`; larger than train because no backward) |
|
| 390 |
+
| Validation fraction | 10% |
|
| 391 |
+
| Max sequence length | 2048 (configurable via `--max-length`) |
|
| 392 |
+
| Loss function | Cross-entropy (response tokens only, SFT mode) |
|
| 393 |
+
|
| 394 |
+
Note the lower learning rate (1e-4 vs 1e-3 for offline) — the LM loss landscape propagates gradients through 48 frozen transformer layers, requiring more conservative updates.
|
| 395 |
+
|
| 396 |
+
**Tail micro-batch handling:** When `len(dataloader) % grad_accum != 0`, the remaining micro-batches
|
| 397 |
+
have their accumulated gradients rescaled by `grad_accum / remainder` (correcting the divisor from
|
| 398 |
+
`1/grad_accum` to `1/remainder`) before performing a final optimizer step. This ensures no training
|
| 399 |
+
data is discarded. Applied to both HF (`run_e2e_compressor.py`) and Megatron (`train.py`).
|
| 400 |
+
|
| 401 |
+
**Two evaluation stages (different data, different code paths):**
|
| 402 |
+
|
| 403 |
+
| Stage | Split | Batch size | Function | Purpose |
|
| 404 |
+
|---|---|---|---|---|
|
| 405 |
+
| Training-time val | VAL (50K seqs) | `--val-batch-size` (default 8) | `evaluate_val_loss()` in training script | Checkpoint selection, wandb monitoring |
|
| 406 |
+
| Final PPL | TEST (50K seqs) | 1 (per-sample) | `evaluate_perplexity()` in `model_utils.py` | Reported results |
|
| 407 |
+
|
| 408 |
+
The training-time validation runs every `--val-interval` optimizer steps and at epoch end, using the VAL split. It drives best-checkpoint selection. The final perplexity evaluation runs after training on the held-out TEST split (never seen during training or checkpoint selection) and produces the numbers reported in the results tables. These are separate code paths — `--val-batch-size` only affects the training-time evaluation.
|
| 409 |
+
|
| 410 |
+
**Parameter count:** Same as Tasks 3b (5a) and 4b-uncompressed (5b):
|
| 411 |
+
|
| 412 |
+
| Mode | Ratio | Bottleneck | Parameters | % of Activated |
|
| 413 |
+
|---|---|---|---|---|
|
| 414 |
+
| No stale (5a) | 2x | 1024 | 201.47M | 6.008% |
|
| 415 |
+
| No stale (5a) | 4x | 512 | 100.79M | 3.006% |
|
| 416 |
+
| No stale (5a) | 8x | 256 | 50.44M | 1.504% |
|
| 417 |
+
| No stale (5a) | 16x | 128 | 25.27M | 0.754% |
|
| 418 |
+
| Uncompressed stale (5b) | 2x | 1024 | 386.02M | 11.512% |
|
| 419 |
+
| Uncompressed stale (5b) | 4x | 512 | 285.34M | 8.509% |
|
| 420 |
+
| Uncompressed stale (5b) | 8x | 256 | 234.99M | 7.008% |
|
| 421 |
+
| Uncompressed stale (5b) | 16x | 128 | 209.82M | 6.257% |
|
| 422 |
+
|
| 423 |
+
**Multi-GPU setup:**
|
| 424 |
+
- Model distributed across 4 GPUs via `device_map="auto"` (~15 GB/GPU)
|
| 425 |
+
- Gradient checkpointing enabled (`use_reentrant=False`) to reduce activation memory
|
| 426 |
+
- 8 GPUs available → 05a and 05b run in parallel on separate GPU sets (GPUs 0-3 and 4-7)
|
| 427 |
+
- Each compressor is automatically placed on the same GPU as its MoE layer
|
| 428 |
+
|
| 429 |
+
**Implementation:** `src/run_e2e_compressor.py`, `scripts/05_run_e2e_compressor.sh`.
|
| 430 |
+
|
| 431 |
+
---
|
| 432 |
+
|
| 433 |
+
### 5.6 Megatron-LM E2E Training (Task 5 — Megatron variant)
|
| 434 |
+
|
| 435 |
+
**Motivation:** The HuggingFace-based Task 5 uses `device_map="auto"` for naive layer-sharded model parallelism. Only one GPU is active at a time during forward pass (sequential layer execution), with no tensor or data parallelism. This limits training throughput and cannot scale to multi-node.
|
| 436 |
+
|
| 437 |
+
**Approach:** Replace HuggingFace with Megatron-LM to get proper tensor parallelism (TP), expert parallelism (EP), and data parallelism (DP):
|
| 438 |
+
- All 4 GPUs active simultaneously via TP (each GPU holds shards of every layer)
|
| 439 |
+
- Multi-node scaling via DP across nodes + TP within nodes
|
| 440 |
+
- Megatron's optimized kernels (fused LayerNorm, FlashAttention, etc.)
|
| 441 |
+
|
| 442 |
+
**Compressor/decompressor placement:**
|
| 443 |
+
|
| 444 |
+
In real expert parallelism, the compressor and decompressor are on DIFFERENT GPUs:
|
| 445 |
+
- **Compressor:** Same GPU as attention output (source GPU where token originates)
|
| 446 |
+
- **Decompressor:** Same GPU as MoE expert (destination GPU after dispatch)
|
| 447 |
+
|
| 448 |
+
This is more realistic than the HF hook-based simulation where the router sees compressed-then-decompressed input. With Megatron, the router sees the ORIGINAL hidden state; only the dispatch is compressed.
|
| 449 |
+
|
| 450 |
+
**Phase A (TP only, EP=1):** Compressor and decompressor on same GPU (same as current HF approach). TP=4 shards each layer across 4 GPUs.
|
| 451 |
+
|
| 452 |
+
**Phase B (with EP):** Compressor on attention GPU, decompressor on expert GPU. MoE dispatch sends compressed tokens (reduced all-to-all volume). The `CompressedMoETokenDispatcher` wraps Megatron's dispatcher to:
|
| 453 |
+
1. Compress on source GPU (attention side)
|
| 454 |
+
2. Dispatch compressed tokens (smaller all-to-all)
|
| 455 |
+
3. Decompress on destination GPU (expert side)
|
| 456 |
+
|
| 457 |
+
**Training pipeline:**
|
| 458 |
+
1. Convert Qwen3-30B-A3B from HF format to Megatron format via Megatron Bridge
|
| 459 |
+
2. Load with TP=4 (each GPU holds ~15-20 GB of sharded weights)
|
| 460 |
+
3. Freeze all LLM parameters
|
| 461 |
+
4. Insert per-layer compressor/decompressor pairs at MoE boundaries
|
| 462 |
+
5. Train compressors via language modeling objective (same as HF Task 5)
|
| 463 |
+
6. Save compressor weights (from rank 0, since all TP ranks have identical copies)
|
| 464 |
+
|
| 465 |
+
**TP-aware loss computation:** `MegatronModelWrapper._compute_loss()` uses
|
| 466 |
+
`vocab_parallel_cross_entropy` when TP > 1. SFT labels (-100) are clamped to 0 before
|
| 467 |
+
the call (avoiding garbage per-token loss for masked positions), and loss is computed as
|
| 468 |
+
`(per_token_loss * loss_mask).sum() / num_valid`. The non-TP path uses PyTorch's
|
| 469 |
+
`cross_entropy(ignore_index=-100)` which handles masking internally.
|
| 470 |
+
|
| 471 |
+
**Evaluation:** Uses existing HF-based evaluation code — load trained compressor weights into `E2ECompressorManager` and evaluate perplexity with hook-based simulation.
|
| 472 |
+
|
| 473 |
+
**Parallelism strategies:**
|
| 474 |
+
|
| 475 |
+
| Hardware | Configuration | Notes |
|
| 476 |
+
|---|---|---|
|
| 477 |
+
| 4 GPUs | TP=4, EP=1, PP=1, DP=1 | All GPUs active via tensor parallelism |
|
| 478 |
+
| 8 GPUs | TP=4, EP=1, PP=1, DP=2 | TP within 4 GPUs, DP across 2 replicas |
|
| 479 |
+
| N nodes × 4 GPUs | TP=4, DP=N | TP within node (NVLink), DP across nodes |
|
| 480 |
+
| EP variant | TP=2, EP=2, PP=1, DP=1 | Compressor on TP ranks, decompressor on EP ranks |
|
| 481 |
+
|
| 482 |
+
**Compressor weights with TP:**
|
| 483 |
+
- Compressors are replicated on all TP ranks (not sharded)
|
| 484 |
+
- Input is full hidden state (post-attention all-reduce)
|
| 485 |
+
- Gradients identical across ranks — no extra all-reduce needed
|
| 486 |
+
- Save from rank 0 only
|
| 487 |
+
|
| 488 |
+
**Implementation:** `src/megatron_e2e/` package with EP-first parallelism (EP=4, TP=1), CUDA 12.9, Megatron Bridge 0.2+, Transformer Engine. Entry point: `src/megatron_e2e/train.py`, bash wrapper: `scripts/05_megatron_e2e.sh`, setup: `scripts/megatron_setup_env.sh`. Multi-node: `scripts/05_megatron_e2e_multinode.sh`.
|
| 489 |
+
|
| 490 |
+
---
|
| 491 |
+
|
| 492 |
+
### 5.7 Baseline E2E Evaluation (Task 5c)
|
| 493 |
+
|
| 494 |
+
**Motivation:** Tasks 5a/5b report perplexity relative to an "untrained baseline" (the original model evaluated on the same test data). However, 5a/5b's training pipeline also loads and processes data through `load_e2e_data()`, computes SFT-masked loss on train/val splits, and may differ subtly from a raw model evaluation. Task 5c runs the exact same pipeline (same data loading, same loss computation, same evaluation) but WITHOUT inserting any compressors. This provides:
|
| 495 |
+
|
| 496 |
+
1. **Train/val loss context:** If 5c's train loss is ~1.0, and 5a-2x's is 1.11, the compression overhead is only +0.11 — not the raw 1.11 value.
|
| 497 |
+
2. **Pipeline consistency:** Confirms that the data pipeline itself does not introduce artifacts.
|
| 498 |
+
3. **Fair comparison:** All three (5a, 5b, 5c) use identical code paths except for the compression hooks.
|
| 499 |
+
|
| 500 |
+
**What it does:**
|
| 501 |
+
- Loads data via `load_e2e_data()` (same function as 5a/5b)
|
| 502 |
+
- Evaluates train and val loss using `evaluate_loss_no_hooks()` — same as `evaluate_val_loss()` but without a compressor manager
|
| 503 |
+
- Evaluates baseline PPL on the TEST split (same as 5a/5b)
|
| 504 |
+
- No training, no compression ratios, no weight files
|
| 505 |
+
|
| 506 |
+
**Implementation:** Added as `--stale-mode baseline` to both `src/run_e2e_compressor.py` (HF) and `src/megatron_e2e/train.py` (Megatron). Output dirs: `results/05c_e2e_baseline/` (HF), `results/05c_megatron_e2e_baseline/` (Megatron).
|
| 507 |
+
|
| 508 |
+
---
|
| 509 |
+
|
| 510 |
+
### 5.8 E2E with Pretrained Init (Tasks 6a/6b)
|
| 511 |
+
|
| 512 |
+
**Motivation:** Tasks 5a/5b initialize compressor/decompressor weights with a near-identity matrix — the first `bottleneck_dim` dimensions are preserved, and the rest are zeroed out. This is a reasonable starting point but the optimizer must learn the full compression mapping from scratch using only the LM loss signal.
|
| 513 |
+
|
| 514 |
+
Tasks 3b and 4b already train compressors to minimize reconstruction loss on cached hidden states. While this offline objective doesn't directly optimize for LM quality, the resulting weights encode the structure of hidden-state distributions and provide a potentially better starting point for E2E fine-tuning.
|
| 515 |
+
|
| 516 |
+
Tasks 6a/6b test this hypothesis: does initializing E2E training from reconstruction-optimized weights (instead of near-identity) lead to faster convergence or better final quality?
|
| 517 |
+
|
| 518 |
+
**Architecture:** Identical to Tasks 5a/5b — same `Compressor`, `Decompressor`, `StaleDecompressor` classes, same training objective (cross-entropy), same hyperparameters. The only difference is the initial weight values.
|
| 519 |
+
|
| 520 |
+
**Two modes:**
|
| 521 |
+
|
| 522 |
+
| Mode | Task | Init from | Stale signal |
|
| 523 |
+
|---|---|---|---|
|
| 524 |
+
| No stale (6a) | `--stale-mode none --init-weights-dir results/03b_perlayer_compressor` | Task 3b (per-layer offline) | None |
|
| 525 |
+
| Uncompressed stale (6b) | `--stale-mode uncompressed --init-weights-dir results/04b_stale_uncompressed` | Task 4b (stale offline) | Raw ref layer input |
|
| 526 |
+
|
| 527 |
+
**Weight compatibility:** Tasks 3b/4b save weights keyed by HF layer names (`model.layers.N.mlp`) with `compressor` and `decompressor` sub-keys. The `MegatronCompressorManager.load_weights()` expects the same format (it converts Megatron names to HF names via `_megatron_to_hf_layer_name()`). The offline and E2E architectures use identical module classes, so `load_state_dict()` works directly.
|
| 528 |
+
|
| 529 |
+
**Parameter count:** Same as Tasks 5a/5b (identical architecture).
|
| 530 |
+
|
| 531 |
+
**Training hyperparameters:** Same as Tasks 5a/5b (same LR, warmup, epochs, etc.).
|
| 532 |
+
|
| 533 |
+
**Implementation:** Added `--init-weights-dir` argument to `src/megatron_e2e/train.py`. Auto-detects weight file naming pattern. Bash wrapper: `scripts/06_megatron_e2e_pretrained.sh`. Output dirs: `results/06a_megatron_e2e_pretrained_perlayer/` (6a), `results/06b_megatron_e2e_pretrained_stale/` (6b).
|
| 534 |
+
|
| 535 |
+
### 5.9 Split-Mode E2E Training (Tasks 7a/7b)
|
| 536 |
+
|
| 537 |
+
**Motivation:** Tasks 5/6 use forward pre-hooks that compress→decompress the MoE input — both the router AND experts see the decompressed hidden state. This is a conservative lower bound on quality. In real expert parallelism, the router runs on the source GPU with the **original** hidden state (before compression), and only experts on the destination GPU see the decompressed version. Task 7 trains the compressor under this more realistic "split mode" to see whether the training signal improves when the router is not degraded by compression artifacts.
|
| 538 |
+
|
| 539 |
+
**Approach — Two-Level Pre-Hooks:**
|
| 540 |
+
|
| 541 |
+
Instead of monkey-patching MoE forward methods, two pre-hooks are registered per MoE layer:
|
| 542 |
+
|
| 543 |
+
1. **MoE pre-hook:** Saves the original input, then returns the compress→decompress result. The MoE module's `forward()` receives the decompressed tensor as its input.
|
| 544 |
+
2. **Router pre-hook:** Registered on the router/gate submodule. When the MoE's `forward()` calls `self.gate(hidden_states)`, this hook intercepts and swaps the input back to the saved original.
|
| 545 |
+
|
| 546 |
+
This works because:
|
| 547 |
+
- The MoE pre-hook changes what `forward()` receives (decompressed), so experts get decompressed data.
|
| 548 |
+
- The router pre-hook only affects the `gate` submodule's input, restoring the original.
|
| 549 |
+
- PyTorch hook execution order: MoE pre-hook runs first (on the outer module), then when `forward()` calls `self.gate(...)` internally, the gate pre-hook runs and swaps the argument.
|
| 550 |
+
|
| 551 |
+
**Two modes:**
|
| 552 |
+
|
| 553 |
+
| Mode | Task | Init from | Stale signal | Router input |
|
| 554 |
+
|---|---|---|---|---|
|
| 555 |
+
| No stale (7a) | `--stale-mode none --router-mode uncompressed --init-weights-dir results/03b_perlayer_compressor` | Task 3b | None | Original |
|
| 556 |
+
| Uncompressed stale (7b) | `--stale-mode uncompressed --router-mode uncompressed --init-weights-dir results/04b_stale_uncompressed` | Task 4b | Raw ref input | Original |
|
| 557 |
+
|
| 558 |
+
**Architecture:** Identical to Tasks 6a/6b — same classes, same init weights, same hyperparameters. The only difference is that `router_mode="uncompressed"` activates the two-level hook pattern during training and evaluation.
|
| 559 |
+
|
| 560 |
+
**Implementation:** Added `--router-mode` argument to `src/megatron_e2e/train.py` and `src/run_e2e_compressor.py`. Split-mode hooks added to `MegatronCompressorManager` (Megatron training) and `evaluate_perplexity_with_perlayer_compression`/`evaluate_perplexity_with_stale_compression` (HF PPL evaluation). Bash wrapper: `scripts/07_megatron_e2e_split.sh`. Output dirs: `results/07a_megatron_e2e_split_perlayer/` (7a), `results/07b_megatron_e2e_split_stale/` (7b).
|
| 561 |
+
|
| 562 |
+
---
|
| 563 |
+
|
| 564 |
+
## 6. Results
|
| 565 |
+
|
| 566 |
+
### 6.1 Summary Table — All Methods
|
| 567 |
+
|
| 568 |
+
**Model:** Qwen3-30B-A3B-Instruct-2507 (full BF16)
|
| 569 |
+
**Dataset:** allenai/Dolci-Instruct-SFT
|
| 570 |
+
|
| 571 |
+
| Method | Ratio | MSE | CosSim | PPL | PPL Delta | HF Strict | HF Flex |
|
| 572 |
+
|---|---|---|---|---|---|---|---|
|
| 573 |
+
| Baseline (Tasks 2–4) | — | — | — | 3.89 | — | 44.12% | 82.79% |
|
| 574 |
+
| Baseline (5c / Megatron) | — | — | — | 3.94 | — | 44.12% | 82.79% |
|
| 575 |
+
| Quant INT8 | 2.0x | — | — | 3.90 | +0.01 | 48.90% | 82.26% |
|
| 576 |
+
| Quant INT4 | 4.0x | — | — | 4.51 | +0.62 | 56.41% | 68.54% |
|
| 577 |
+
| Quant INT2 | 8.0x | — | — | 1532.59 | +1528.70 | 0.00% | 0.00% |
|
| 578 |
+
| Neural (per-layer) | 2x | 0.0535 | 0.922 | 21.07 | +17.18 | 0.00% | 1.52% |
|
| 579 |
+
| Neural (per-layer) | 4x | 0.1073 | 0.835 | 425.75 | +421.87 | 0.00% | 0.00% |
|
| 580 |
+
| Neural (per-layer) | 8x | 0.1523 | 0.755 | 7949.78 | +7945.89 | 0.00% | 0.00% |
|
| 581 |
+
| Neural (per-layer) | 16x | 0.1893 | 0.683 | 52440.05 | +52436.16 | 0.00% | 0.00% |
|
| 582 |
+
| Stale-cond. (compressed) | 2x | 0.0379 | 0.947 | 6.13 | +2.24 | 3.41% | 62.55% |
|
| 583 |
+
| Stale-cond. (compressed) | 4x | 0.0876 | 0.869 | 31.64 | +27.75 | 0.61% | 1.52% |
|
| 584 |
+
| Stale-cond. (compressed) | 8x | 0.1330 | 0.791 | 2982.23 | +2978.34 | 0.00% | 0.00% |
|
| 585 |
+
| Stale-cond. (compressed) | 16x | 0.1720 | 0.717 | 17486.21 | +17482.32 | 0.00% | 0.00% |
|
| 586 |
+
| Stale-cond. (uncompressed) | 2x | 0.0346 | 0.952 | 6.24 | +2.36 | 2.81% | 67.10% |
|
| 587 |
+
| Stale-cond. (uncompressed) | 4x | 0.0690 | 0.900 | 16.11 | +12.22 | 0.99% | 6.14% |
|
| 588 |
+
| Stale-cond. (uncompressed) | 8x | 0.0966 | 0.855 | 423.68 | +419.79 | 0.00% | 0.00% |
|
| 589 |
+
| Stale-cond. (uncompressed) | 16x | 0.1173 | 0.819 | 3740.41 | +3736.53 | 0.00% | 0.00% |
|
| 590 |
+
| Megatron E2E per-layer (5a) | 2x | — | — | 2.77 | -1.17 | 61.33% | 61.64% |
|
| 591 |
+
| Megatron E2E per-layer (5a) | 4x | — | — | 4.28 | +0.35 | 20.70% | 21.30% |
|
| 592 |
+
| Megatron E2E per-layer (5a) | 8x | — | — | 7.49 | +3.55 | 1.82% | 2.12% |
|
| 593 |
+
| Megatron E2E per-layer (5a) | 16x | — | — | 11.26 | +7.33 | 0.91% | 2.73% |
|
| 594 |
+
| Megatron E2E stale (5b) | 2x | — | — | 2.71 | -1.23 | 60.27% | 60.65% |
|
| 595 |
+
| Megatron E2E stale (5b) | 4x | — | — | 3.61 | -0.33 | 31.54% | 32.37% |
|
| 596 |
+
| Megatron E2E stale (5b) | 8x | — | — | 4.98 | +1.04 | 4.93% | 5.00% |
|
| 597 |
+
| Megatron E2E stale (5b) | 16x | — | — | 6.34 | +2.41 | 2.12% | 2.27% |
|
| 598 |
+
| Megatron E2E pretrained per-layer (6a) | 2x | — | — | 2.41 | -1.53 | 79.98% | 80.06% |
|
| 599 |
+
| Megatron E2E pretrained per-layer (6a) | 4x | — | — | 3.18 | -0.76 | 55.04% | 55.19% |
|
| 600 |
+
| Megatron E2E pretrained per-layer (6a) | 8x | — | — | 4.52 | +0.58 | 16.98% | 16.98% |
|
| 601 |
+
| Megatron E2E pretrained per-layer (6a) | 16x | — | — | 7.34 | +3.40 | 2.27% | 2.27% |
|
| 602 |
+
| Megatron E2E pretrained stale (6b) | 2x | — | — | 2.25 | -1.69 | 82.49% | 82.64% |
|
| 603 |
+
| Megatron E2E pretrained stale (6b) | 4x | — | — | 2.57 | -1.37 | 64.37% | 64.52% |
|
| 604 |
+
| Megatron E2E pretrained stale (6b) | 8x | — | — | 3.04 | -0.90 | 45.79% | 45.94% |
|
| 605 |
+
| Megatron E2E pretrained stale (6b) | 16x | — | — | 3.47 | -0.47 | 25.85% | 25.85% |
|
| 606 |
+
| Split-mode E2E per-layer (7a) | 2x | — | — | 2.58 | -1.31 | 79.91% | 79.98% |
|
| 607 |
+
| Split-mode E2E per-layer (7a) | 4x | — | — | 3.72 | -0.17 | 42.08% | 42.15% |
|
| 608 |
+
| Split-mode E2E per-layer (7a) | 8x | — | — | 6.43 | +2.54 | 4.93% | 5.46% |
|
| 609 |
+
| Split-mode E2E per-layer (7a) | 16x | — | — | 908.20 | +904.31 | 0.00% | 0.53% |
|
| 610 |
+
| Split-mode E2E stale (7b) | 2x | — | — | 2.34 | -1.55 | 80.67% | 80.67% |
|
| 611 |
+
| Split-mode E2E stale (7b) | 4x | — | — | 2.80 | -1.09 | 65.81% | 65.96% |
|
| 612 |
+
| Split-mode E2E stale (7b) | 8x | — | — | 3.37 | -0.51 | 35.63% | 35.63% |
|
| 613 |
+
| Split-mode E2E stale (7b) | 16x | — | — | 4.28 | +0.39 | 16.53% | 16.68% |
|
| 614 |
+
|
| 615 |
+
Note: Tasks 2–4 and 5c baselines differ in PPL (3.89 vs 3.94) due to different evaluation
|
| 616 |
+
code paths (single-GPU HF vs Megatron pipeline). PPL deltas for offline methods use 3.89;
|
| 617 |
+
E2E methods use 3.94. HF Strict/Flex: GSM8K evaluated via HF backend (lm-eval-harness,
|
| 618 |
+
router-compressed mode). For Tasks 7a/7b, HF Strict/Flex is compressed-router only.
|
| 619 |
+
Uncompressed-router results for Tasks 7a/7b are in a dedicated table below Section 6.4.
|
| 620 |
+
GSM8K scores are identical for both baselines because GSM8K evaluation uses the same raw HF
|
| 621 |
+
model. GSM8K uses Megatron-trained weights for E2E methods. "Strict" requires exact
|
| 622 |
+
`#### <number>` format; "flexible" extracts the number from anywhere in the output.
|
| 623 |
+
HF-trained E2E weights (Tasks 5a/5b) were not available.
|
| 624 |
+
|
| 625 |
+
### 6.2 Key Findings
|
| 626 |
+
|
| 627 |
+
1. **E2E training is transformative** — E2E methods achieve PPL *below* baseline (3.94) at 2x. E2E stale stays below baseline at 4x (PPL=3.61).
|
| 628 |
+
2. **E2E stale at 16x is moderate** — PPL=6.34 (+2.41), 61% above baseline, with GSM8K strict-match at 2.12%.
|
| 629 |
+
3. **E2E dramatically outperforms offline** — Same architecture, same params: offline per-layer 4x PPL=425.75 vs E2E 4x PPL=4.28 (99x better). At 16x: 52440 vs 11.26 (4658x better).
|
| 630 |
+
4. **Stale conditioning matters more at high compression** — At 2x the gap is small (E2E stale 2.71 vs E2E per-layer 2.77), but at 16x it's 1.8x (6.34 vs 11.26).
|
| 631 |
+
5. **INT8 quantization is nearly lossless** — PPL 3.90 vs baseline 3.89 at 2x (+0.01), with GSM8K preserved (48.90% strict, 82.26% flexible).
|
| 632 |
+
6. **INT4 quantization is acceptable** — PPL 4.51 at ~4x (+0.62 delta). GSM8K strict-match actually improves to 56.41%.
|
| 633 |
+
7. **INT2 is catastrophic** — PPL 1533 at ~8x, completely unusable.
|
| 634 |
+
8. **Offline methods degrade rapidly** — Per-layer neural: PPL=21 at 2x, PPL=425 at 4x, PPL=7950 at 8x. Stale-conditioning (uncompressed) helps at 2x (PPL=6.24) but collapses at 8x (PPL=424).
|
| 635 |
+
9. **Below-baseline PPL** suggests E2E compressors act as regularizers, filtering noise from hidden states while preserving task-relevant information. Confirmed by GSM8K: E2E 2x scores 61.33% vs baseline 44.12%.
|
| 636 |
+
10. **Downstream tasks are more sensitive than PPL** — Offline stale_uncomp_2x has PPL=6.24 (+2.36) but GSM8K drops from 44% to 3% strict-match. E2E methods maintain both PPL and GSM8K. See Section 6.4.
|
| 637 |
+
11. **Offline compression destroys output format but partially preserves reasoning** — stale_uncomp_2x: 2.81% strict but 67.10% flexible-extract. E2E methods show no such gap (~0.3 pp).
|
| 638 |
+
12. **Pretrained init (Task 6) dramatically improves E2E training** — Initializing from offline-trained weights (Tasks 3b/4b) instead of near-identity gives 13–45% PPL improvement and massive GSM8K gains. 6b at 2x achieves PPL=2.25 and 82.5% GSM8K strict-match (vs 5b: PPL=2.71, 60.3%). Even at 16x, 6b (PPL=3.47, GSM8K 25.9%) stays below baseline PPL (3.89) and retains meaningful downstream accuracy.
|
| 639 |
+
13. **Pretrained init benefits grow with compression ratio** — For stale-conditioned (6b vs 5b): PPL improvement goes from 17% at 2x to 45% at 16x; GSM8K goes from +22 pp at 2x to +24 pp at 16x. The offline-trained weights provide a much better starting point for E2E optimization, especially at high compression where near-identity init struggles.
|
| 640 |
+
14. **Split-mode training (Task 7) matches deployment reality** — Training with split-mode (router sees original, experts see decompressed) then evaluating in the same mode yields the best uncompressed-router results. 7b uncompressed at 2x achieves 83.3% GSM8K strict-match — the best result across all methods and modes.
|
| 641 |
+
15. **7b uncompressed stays below baseline PPL at ALL ratios** — Even at 16x compression, 7b uncompressed PPL=3.27 remains below the no-compression baseline (3.89). This is the only method to maintain below-baseline PPL at every compression ratio, demonstrating that stale-conditioned split-mode E2E compressors can be simultaneously lossy (16x compression) and beneficial (regularization effect).
|
| 642 |
+
16. **Split-mode training trades compressed-eval quality for uncompressed-eval quality** — 7a/7b compressed-eval PPL is worse than 6a/6b (e.g., 7a 16x compressed: 908 vs 6a: 8.49) because the model was not trained to have the router see decompressed data. But 7a/7b uncompressed-eval is better (7a 16x uncompressed: 6.64 vs 6a compressed: 8.49). This confirms the training mode should match the deployment mode.
|
| 643 |
+
17. **Catastrophic collapse at extreme compression without stale** — 7a 16x compressed PPL=908 (vs 7a 16x uncompressed=6.64), showing that when per-layer compression is too lossy, correct routing (from original hidden states) becomes critical. Stale conditioning (7b) avoids this entirely: 7b 16x compressed=4.28, uncompressed=3.27.
|
| 644 |
+
|
| 645 |
+
### 6.3 HF vs Megatron Comparison
|
| 646 |
+
|
| 647 |
+
**Note:** HF E2E results in this section are from an earlier training run. The HF E2E
|
| 648 |
+
weight files are no longer available in the current `results/05a_e2e_perlayer/` and
|
| 649 |
+
`results/05b_e2e_stale/` directories (only logs remain). The Megatron results are
|
| 650 |
+
from the current run and match the JSON files. The comparison below is preserved for
|
| 651 |
+
historical reference but the HF numbers cannot be independently verified from current data.
|
| 652 |
+
|
| 653 |
+
Both implementations use the same compressor architecture (Compressor + Decompressor / StaleDecompressor), the same model (Qwen3-30B-A3B-Instruct-2507), and the same training data (Dolci-Instruct-SFT). The key differences are in the distributed training strategy and model parallelism framework.
|
| 654 |
+
|
| 655 |
+
**Implementation differences:**
|
| 656 |
+
|
| 657 |
+
| Aspect | HuggingFace | Megatron |
|
| 658 |
+
|---|---|---|
|
| 659 |
+
| Framework | HF Transformers + `device_map="auto"` | Megatron-Core + AutoBridge |
|
| 660 |
+
| Parallelism | Naive layer sharding (sequential) | EP=4, TP=1, PP=1, DP=4 |
|
| 661 |
+
| GPU utilization | 1 GPU active at a time | All 4 GPUs active (DP) |
|
| 662 |
+
| Data parallelism | None (single data stream) | DP=4 (each rank sees 1/4 of data per step) |
|
| 663 |
+
| Optimizer | AdamW (single replica) | AdamW (replicated, gradients all-reduced) |
|
| 664 |
+
| CUDA | 12.6 | 12.9 |
|
| 665 |
+
|
| 666 |
+
**Task 5a — E2E per-layer (no stale):**
|
| 667 |
+
|
| 668 |
+
| Ratio | HF PPL | Megatron PPL | Gap (Meg−HF) |
|
| 669 |
+
|---|---|---|---|
|
| 670 |
+
| 2x | **2.645** (−1.58) | 2.682 (−1.54) | +0.04 |
|
| 671 |
+
| 4x | **3.687** (−0.54) | 4.410 (+0.19) | +0.72 |
|
| 672 |
+
| 8x | **6.371** (+2.15) | 8.182 (+3.96) | +1.81 |
|
| 673 |
+
| 16x | **9.157** (+4.93) | 11.670 (+7.44) | +2.51 |
|
| 674 |
+
|
| 675 |
+
**Task 5b — E2E stale-conditioned (uncompressed stale):**
|
| 676 |
+
|
| 677 |
+
| Ratio | HF PPL | Megatron PPL | Gap (Meg−HF) |
|
| 678 |
+
|---|---|---|---|
|
| 679 |
+
| 2x | 2.570 (−1.65) | **2.568** (−1.66) | −0.00 |
|
| 680 |
+
| 4x | **3.102** (−1.12) | 3.420 (−0.80) | +0.32 |
|
| 681 |
+
| 8x | **4.015** (−0.21) | 4.743 (+0.52) | +0.73 |
|
| 682 |
+
| 16x | **4.550** (+0.32) | 5.232 (+1.01) | +0.68 |
|
| 683 |
+
|
| 684 |
+
**Training losses (train / val):**
|
| 685 |
+
|
| 686 |
+
| Config | HF 5a | Megatron 5a | HF 5b | Megatron 5b |
|
| 687 |
+
|---|---|---|---|---|
|
| 688 |
+
| 2x | 1.215 / 1.093 | 1.258 / 1.109 | 1.193 / 1.070 | 1.210 / 1.068 |
|
| 689 |
+
| 4x | 1.786 / 1.447 | 2.103 / 1.627 | 1.579 / 1.286 | 1.784 / 1.375 |
|
| 690 |
+
| 8x | 2.412 / 2.004 | 2.776 / 2.242 | 1.921 / 1.555 | 2.206 / 1.724 |
|
| 691 |
+
| 16x | 2.768 / 2.326 | 3.180 / 2.567 | 2.069 / 1.686 | 2.344 / 1.823 |
|
| 692 |
+
|
| 693 |
+
**Analysis:**
|
| 694 |
+
|
| 695 |
+
1. **At 2x, both implementations converge to the same quality.** The gap is negligible (0.04 for 5a, −0.002 for 5b). Near-identity initialization gives a strong starting point, and 2x compression is easy enough that both optimizers find similar solutions.
|
| 696 |
+
|
| 697 |
+
2. **Megatron's gap grows at higher compression ratios for 5a** (no stale). At 4x the gap is +0.72, at 16x it's +2.51. The likely cause is that Megatron with DP=4 provides each rank with 1/4 of the data per step — effectively a noisier gradient estimate. HF's single-replica training sees the full data stream, leading to a slightly better optimizer trajectory for harder problems (higher compression).
|
| 698 |
+
|
| 699 |
+
3. **Stale conditioning dramatically narrows the Megatron-HF gap.** Adding stale conditioning reduces the gap by 50–73% at all ratios:
|
| 700 |
+
- 4x: +0.72 → +0.32 (56% reduction)
|
| 701 |
+
- 8x: +1.81 → +0.73 (60% reduction)
|
| 702 |
+
- 16x: +2.51 → +0.68 (73% reduction)
|
| 703 |
+
The stale signal acts as an anchor that partially corrects for the noisier optimization — it provides a strong prior about the expected hidden state, reducing the difficulty of the decompression task.
|
| 704 |
+
|
| 705 |
+
4. **Both Megatron variants produce usable compressors.** Megatron 5b at 4x (PPL=3.42) is still 19% below baseline, and even at 16x (PPL=5.23) the degradation is only +24%. For production deployment where Megatron's scalability is needed, these results are practical.
|
| 706 |
+
|
| 707 |
+
5. **Recommendation:** Use Megatron with stale conditioning (5b mode) for production. At 2–4x compression, results match HF quality. At 8–16x, there is a modest quality gap, but Megatron's multi-node scalability and proper expert parallelism make it the right choice for large-scale deployment.
|
| 708 |
+
|
| 709 |
+
### 6.4 Downstream Task Evaluation (GSM8K)
|
| 710 |
+
|
| 711 |
+
**Benchmark:** GSM8K chain-of-thought (gsm8k_cot), 8-shot, 1319 test examples.
|
| 712 |
+
Two metrics: **strict-match** (exact `#### <number>` format) and **flexible-extract**
|
| 713 |
+
(number extracted from anywhere in the output via regex).
|
| 714 |
+
Two router modes: **compressed** (router AND experts see decompressed hidden states)
|
| 715 |
+
and **uncompressed** (router sees original, experts see decompressed — more realistic EP
|
| 716 |
+
simulation). PPL, MSE, CosSim from HF-based evaluation (`model_utils.py`).
|
| 717 |
+
HF Strict/Flex from HF backend (lm-eval-harness, router-compressed mode).
|
| 718 |
+
vLLM columns from vLLM backend (`run_all_downstream.py`, both router modes).
|
| 719 |
+
For Tasks 7a/7b, vLLM Uncomp. columns show HF backend uncompressed-router results
|
| 720 |
+
(confirmed identical via both `run_all_downstream.py` and `run_e2e_compressor.py
|
| 721 |
+
--router-mode uncompressed`).
|
| 722 |
+
|
| 723 |
+
| Method | Ratio | MSE | CosSim | PPL | PPL Δ | HF Strict | HF Flex | vLLM Comp. Strict | vLLM Comp. Flex | vLLM Uncomp. Strict | vLLM Uncomp. Flex |
|
| 724 |
+
|---|---|---|---|---|---|---|---|---|---|---|---|
|
| 725 |
+
| Baseline | — | — | — | 3.89 | — | 44.1% | 82.8% | 43.3% | 82.9% | — | — |
|
| 726 |
+
| Quant INT8 | 2x | — | — | 3.90 | +0.01 | 48.9% | 82.3% | 43.7% | 82.2% | — | — |
|
| 727 |
+
| Quant INT4 | 4x | — | — | 4.51 | +0.62 | 56.4% | 68.5% | 46.8% | 65.4% | — | — |
|
| 728 |
+
| Quant INT2 | 8x | — | — | 1532.59 | +1528.70 | 0.0% | 0.0% | 0.0% | 0.0% | — | — |
|
| 729 |
+
| Neural (per-layer) | 2x | 0.0535 | 0.922 | 21.07 | +17.18 | 0.0% | 1.5% | 0.0% | 1.2% | 22.7% | 42.6% |
|
| 730 |
+
| Neural (per-layer) | 4x | 0.1073 | 0.835 | 425.75 | +421.87 | 0.0% | 0.0% | 0.0% | 0.4% | 1.0% | 2.4% |
|
| 731 |
+
| Neural (per-layer) | 8x | 0.1523 | 0.755 | 7949.78 | +7945.89 | 0.0% | 0.0% | 0.0% | 0.0% | 2.0% | 1.9% |
|
| 732 |
+
| Neural (per-layer) | 16x | 0.1893 | 0.683 | 52440.05 | +52436.16 | 0.0% | 0.0% | 0.0% | 0.0% | 1.5% | 1.5% |
|
| 733 |
+
| Stale-cond. (compressed) | 2x | 0.0379 | 0.947 | 6.13 | +2.24 | 3.4% | 62.6% | 0.2% | 0.8% | 34.1% | 69.7% |
|
| 734 |
+
| Stale-cond. (compressed) | 4x | 0.0876 | 0.869 | 31.64 | +27.75 | 0.6% | 1.5% | 0.0% | 0.6% | 2.7% | 4.9% |
|
| 735 |
+
| Stale-cond. (compressed) | 8x | 0.1330 | 0.791 | 2982.23 | +2978.34 | 0.0% | 0.0% | 0.0% | 0.0% | 1.3% | 1.8% |
|
| 736 |
+
| Stale-cond. (compressed) | 16x | 0.1720 | 0.717 | 17486.21 | +17482.32 | 0.0% | 0.0% | 0.0% | 0.0% | 1.8% | 2.0% |
|
| 737 |
+
| Stale-cond. (uncompressed) | 2x | 0.0346 | 0.952 | 6.24 | +2.36 | 2.8% | 67.1% | 0.2% | 1.1% | 30.7% | 72.6% |
|
| 738 |
+
| Stale-cond. (uncompressed) | 4x | 0.0690 | 0.900 | 16.11 | +12.22 | 1.0% | 6.1% | 0.0% | 0.6% | 6.1% | 9.3% |
|
| 739 |
+
| Stale-cond. (uncompressed) | 8x | 0.0966 | 0.855 | 423.68 | +419.79 | 0.0% | 0.0% | 0.0% | 0.0% | 1.2% | 2.5% |
|
| 740 |
+
| Stale-cond. (uncompressed) | 16x | 0.1173 | 0.819 | 3740.41 | +3736.53 | 0.0% | 0.0% | 0.0% | 0.0% | 1.4% | 2.0% |
|
| 741 |
+
| E2E per-layer (5a) | 2x | — | — | 2.77 | −1.17 | 61.3% | 61.6% | 61.5% | 61.6% | 52.4% | 59.6% |
|
| 742 |
+
| E2E per-layer (5a) | 4x | — | — | 4.28 | +0.35 | 20.7% | 21.3% | 21.2% | 22.4% | 11.0% | 12.9% |
|
| 743 |
+
| E2E per-layer (5a) | 8x | — | — | 7.49 | +3.55 | 1.8% | 2.1% | 0.0% | 0.0% | 0.0% | 0.0% |
|
| 744 |
+
| E2E per-layer (5a) | 16x | — | — | 11.26 | +7.33 | 0.9% | 2.7% | 0.0% | 0.0% | 0.0% | 0.1% |
|
| 745 |
+
| E2E stale (5b) | 2x | — | — | 2.71 | −1.23 | 60.3% | 60.7% | 61.3% | 61.6% | 53.2% | 61.2% |
|
| 746 |
+
| E2E stale (5b) | 4x | — | — | 3.61 | −0.33 | 31.5% | 32.4% | 33.0% | 33.2% | 18.6% | 22.1% |
|
| 747 |
+
| E2E stale (5b) | 8x | — | — | 4.98 | +1.04 | 4.9% | 5.0% | 3.4% | 4.3% | 0.2% | 2.4% |
|
| 748 |
+
| E2E stale (5b) | 16x | — | — | 6.34 | +2.41 | 2.1% | 2.3% | 0.0% | 0.2% | 0.0% | 0.1% |
|
| 749 |
+
| E2E pretrained per-layer (6a) | 2x | — | — | 2.41 | −1.53 | 80.0% | 80.1% | 80.1% | 80.0% | 80.6% | 80.8% |
|
| 750 |
+
| E2E pretrained per-layer (6a) | 4x | — | — | 3.18 | −0.76 | 55.0% | 55.2% | 52.8% | 52.9% | 43.3% | 43.9% |
|
| 751 |
+
| E2E pretrained per-layer (6a) | 8x | — | — | 4.52 | +0.58 | 17.0% | 17.0% | 13.5% | 14.0% | 6.7% | 7.6% |
|
| 752 |
+
| E2E pretrained per-layer (6a) | 16x | — | — | 7.34 | +3.40 | 2.3% | 2.3% | 0.3% | 1.1% | 1.1% | 2.1% |
|
| 753 |
+
| E2E pretrained stale (6b) | 2x | — | — | 2.25 | −1.69 | 82.5% | 82.6% | 82.0% | 82.3% | 83.9% | 84.0% |
|
| 754 |
+
| E2E pretrained stale (6b) | 4x | — | — | 2.57 | −1.37 | 64.4% | 64.5% | 71.0% | 71.1% | 68.8% | 68.9% |
|
| 755 |
+
| E2E pretrained stale (6b) | 8x | — | — | 3.04 | −0.90 | 45.8% | 45.9% | 37.6% | 37.6% | 24.3% | 24.3% |
|
| 756 |
+
| E2E pretrained stale (6b) | 16x | — | — | 3.47 | −0.47 | 25.9% | 25.9% | 18.7% | 18.7% | 9.0% | 9.6% |
|
| 757 |
+
| Split E2E per-layer (7a) | 2x | — | — | 2.58 | −1.31 | 79.9% | 80.0% | — | — | 79.5% | 79.7% |
|
| 758 |
+
| Split E2E per-layer (7a) | 4x | — | — | 3.72 | −0.17 | 42.1% | 42.2% | — | — | 51.6% | 51.8% |
|
| 759 |
+
| Split E2E per-layer (7a) | 8x | — | — | 6.43 | +2.54 | 4.9% | 5.5% | — | — | 18.5% | 18.7% |
|
| 760 |
+
| Split E2E per-layer (7a) | 16x | — | — | 908.20 | +904.31 | 0.0% | 0.5% | — | — | 2.0% | 2.5% |
|
| 761 |
+
| Split E2E stale (7b) | 2x | — | — | 2.34 | −1.55 | 80.7% | 80.7% | — | — | 83.3% | 83.4% |
|
| 762 |
+
| Split E2E stale (7b) | 4x | — | — | 2.80 | −1.09 | 65.8% | 66.0% | — | — | 70.7% | 70.7% |
|
| 763 |
+
| Split E2E stale (7b) | 8x | — | — | 3.37 | −0.51 | 35.6% | 35.6% | — | — | 47.2% | 47.2% |
|
| 764 |
+
| Split E2E stale (7b) | 16x | — | — | 4.28 | +0.39 | 16.5% | 16.7% | — | — | 27.1% | 27.1% |
|
| 765 |
+
|
| 766 |
+
Notes: HF = HF backend (router-compressed mode). vLLM Comp. = vLLM backend, router-compressed
|
| 767 |
+
(router+experts see decompressed). vLLM Uncomp. = vLLM backend, router-uncompressed (router sees
|
| 768 |
+
original, experts see decompressed — split forward). For Tasks 7a/7b, HF Strict/Flex = HF backend
|
| 769 |
+
with compressed router; vLLM Uncomp. = HF backend with uncompressed router (confirmed identical
|
| 770 |
+
results from both `run_all_downstream.py` and `run_e2e_compressor.py --router-mode uncompressed`).
|
| 771 |
+
Baseline and quantization have no split mode. PPL baseline: 3.89 (offline) / 3.94 (E2E). GSM8K
|
| 772 |
+
uses Megatron-trained weights for E2E methods. Task 7 PPL column shows compressed-router PPL.
|
| 773 |
+
Uncompressed-router results (confirmed identical via both original eval code path and
|
| 774 |
+
`run_e2e_compressor.py --router-mode uncompressed`):
|
| 775 |
+
|
| 776 |
+
| Ratio | 7a PPL | 7b PPL | Baseline PPL | 7a Strict | 7a Flex | 7b Strict | 7b Flex |
|
| 777 |
+
|-------|--------|--------|--------------|-----------|---------|-----------|---------|
|
| 778 |
+
| 2x | 2.38 | 2.23 | 3.89 | 79.5% | 79.7% | 83.3% | 83.4% |
|
| 779 |
+
| 4x | 3.08 | 2.53 | 3.89 | 51.6% | 51.8% | 70.7% | 70.7% |
|
| 780 |
+
| 8x | 4.18 | 2.89 | 3.89 | 18.5% | 18.7% | 47.2% | 47.2% |
|
| 781 |
+
| 16x | 6.64 | 3.27 | 3.89 | 2.0% | 2.5% | 27.1% | 27.1% |
|
| 782 |
+
|
| 783 |
+
**Key findings:**
|
| 784 |
+
|
| 785 |
+
1. **E2E compression improves GSM8K over baseline.** Baseline strict-match is 44.12%.
|
| 786 |
+
E2E per-layer 2x achieves 61.33% (+17.2 pp) and E2E stale 2x achieves 60.27%
|
| 787 |
+
(+16.2 pp). This mirrors the below-baseline PPL effect — E2E compressors act as
|
| 788 |
+
regularizers that improve both perplexity and downstream task performance.
|
| 789 |
+
|
| 790 |
+
2. **INT8 and INT4 quantization also improve strict-match.** INT8: 48.90% (+4.8 pp),
|
| 791 |
+
INT4: 56.41% (+12.3 pp). The flexible-extract gap is smaller (INT8: 82.26% vs
|
| 792 |
+
baseline 82.79%), suggesting quantization noise may regularize the strict output
|
| 793 |
+
format without hurting reasoning.
|
| 794 |
+
|
| 795 |
+
3. **Offline methods catastrophically fail on generation tasks.** Per-layer neural
|
| 796 |
+
compressors score 0% strict-match at all ratios (even 2x, which has PPL=21.07).
|
| 797 |
+
Stale-conditioned 2x scores only 2.81% strict / 67.10% flexible. The flexible-extract
|
| 798 |
+
score reveals that the model still produces correct numerical answers but the output
|
| 799 |
+
format is destroyed — compression disrupts the learned generation patterns.
|
| 800 |
+
|
| 801 |
+
4. **The strict-vs-flexible gap reveals a format disruption effect.** Offline methods
|
| 802 |
+
show huge gaps: stale_uncomp_2x has 2.81% strict but 67.10% flexible (64.3 pp gap).
|
| 803 |
+
E2E methods show almost no gap: e2e_2x has 61.33% strict vs 61.64% flexible (0.3 pp).
|
| 804 |
+
End-to-end training preserves both the model's reasoning ability AND its output
|
| 805 |
+
formatting, while offline compression preserves some reasoning but destroys formatting.
|
| 806 |
+
|
| 807 |
+
5. **GSM8K is more sensitive than PPL to compression quality.** Stale_uncomp_2x has
|
| 808 |
+
PPL=6.24 (only +2.36 above baseline) yet scores 2.81% on GSM8K strict-match (vs
|
| 809 |
+
44.12% baseline). E2E per-layer 4x has PPL=4.28 (only +0.35 above baseline) yet
|
| 810 |
+
drops to 20.70% GSM8K. Generation tasks amplify small distributional shifts that
|
| 811 |
+
PPL barely registers.
|
| 812 |
+
|
| 813 |
+
6. **Stale conditioning matters for downstream tasks.** At 4x: E2E stale gets 31.54%
|
| 814 |
+
vs E2E per-layer 20.70% (+10.8 pp). At 8x: stale gets 4.93% vs per-layer 1.82%.
|
| 815 |
+
The stale signal helps preserve generation quality, consistent with PPL findings.
|
| 816 |
+
|
| 817 |
+
7. **Pretrained init (Task 6) yields dramatic GSM8K improvements.** 6b stale at 2x
|
| 818 |
+
achieves 82.49% strict-match — nearly double baseline (44.12%) and +22 pp over 5b
|
| 819 |
+
(60.27%). 6a per-layer at 2x reaches 79.98% (+19 pp over 5a). Even at 8x, 6b retains
|
| 820 |
+
45.79% (exceeding baseline) while 5b collapses to 4.93%.
|
| 821 |
+
|
| 822 |
+
8. **Pretrained init enables useful compression at 16x.** 6b at 16x achieves 25.85%
|
| 823 |
+
GSM8K strict-match — down from baseline (44.12%) but still practically useful. Compare
|
| 824 |
+
with 5b at 16x (2.12%) or 5a at 16x (0.91%). Offline weights provide the optimizer
|
| 825 |
+
with a much better starting region of parameter space.
|
| 826 |
+
|
| 827 |
+
9. **Best overall result: 6b at 2–4x compression.** 6b at 2x (PPL=2.25, GSM8K=82.5%)
|
| 828 |
+
and 4x (PPL=2.57, GSM8K=64.4%) both outperform baseline on PPL and at 4x still retain
|
| 829 |
+
strong downstream performance. This suggests stale-conditioned E2E compression with
|
| 830 |
+
pretrained init is a viable approach for reducing MoE communication by 2–4x with
|
| 831 |
+
minimal or even improved model quality.
|
| 832 |
+
|
| 833 |
+
---
|
| 834 |
+
|
| 835 |
+
## 7. Design Choices and Trade-offs
|
| 836 |
+
|
| 837 |
+
### 7.1 Offline Independent Training vs End-to-End
|
| 838 |
+
|
| 839 |
+
**Offline training (Tasks 2–4)** trains compressors on cached hidden states, independently per layer:
|
| 840 |
+
|
| 841 |
+
| Aspect | Offline | End-to-End (Task 5) |
|
| 842 |
+
|---|---|---|
|
| 843 |
+
| Loss | MSE + cosine (reconstruction) | Cross-entropy (next-token prediction) |
|
| 844 |
+
| Optimization scope | Per-layer, independent | Joint, all 48 layers |
|
| 845 |
+
| Gradient flow | None through LLM | Through entire frozen LLM |
|
| 846 |
+
| Stale signal | Pre-computed, frozen | Live, gradients flow through |
|
| 847 |
+
| Model precision | Full BF16 (~60 GB, 1 GPU) | Full BF16 (~60 GB, 4 GPUs) |
|
| 848 |
+
| Training cost | Minutes per layer | Hours for all layers + ratios |
|
| 849 |
+
| Error compounding | Not accounted for | Naturally optimized via global loss |
|
| 850 |
+
|
| 851 |
+
**Offline advantages:**
|
| 852 |
+
- Fast and cheap (minutes per layer on a single GPU)
|
| 853 |
+
- No need to backpropagate through the full LLM
|
| 854 |
+
- Each layer's compressor can be trained in parallel
|
| 855 |
+
|
| 856 |
+
**Offline limitations (addressed by e2e):**
|
| 857 |
+
- Compressors cannot adapt to how their reconstruction errors compound across layers. A small error at layer 0 may shift the hidden state distribution at layer 1, but layer 1's compressor was trained on the *original* layer-1 distribution.
|
| 858 |
+
- No joint optimization means the system cannot learn to allocate more capacity to layers where errors are most harmful.
|
| 859 |
+
- The stale signal used during offline training is the *unperturbed* reference input, but during inference the reference layer itself is compressed, creating a train-inference mismatch.
|
| 860 |
+
|
| 861 |
+
**E2E advantages:**
|
| 862 |
+
- Compressors are optimized for the actual downstream impact of compression on model quality.
|
| 863 |
+
- Joint optimization: the system implicitly learns which layers need higher fidelity.
|
| 864 |
+
- Stale gradients flow: reference layer compressors are optimized for their dual role (own reconstruction + stale side information for downstream layers). The stale signal during training already reflects upstream compression artifacts, eliminating the train-inference mismatch.
|
| 865 |
+
|
| 866 |
+
**E2E limitations:**
|
| 867 |
+
- Requires full-precision model in memory for proper gradient flow (~60 GB across 4 GPUs).
|
| 868 |
+
- Training is slower (full forward + backward through 48 frozen transformer layers per step).
|
| 869 |
+
- More hyperparameter-sensitive (LR, warmup, gradient clipping matter more).
|
| 870 |
+
|
| 871 |
+
### 7.2 Linear vs Non-linear Compressors
|
| 872 |
+
|
| 873 |
+
All compressors are single-layer linear networks (no activation functions). This was a deliberate choice:
|
| 874 |
+
- Linear compressors are equivalent to learning an optimal projection/reconstruction pair (related to PCA)
|
| 875 |
+
- They are fast to train and apply (single matrix multiply)
|
| 876 |
+
- They establish a clean baseline before trying non-linear architectures
|
| 877 |
+
|
| 878 |
+
### 7.3 Loss Function
|
| 879 |
+
|
| 880 |
+
The combined `MSE + 0.1 × (1 - cos_sim)` loss was chosen because:
|
| 881 |
+
- MSE alone can be dominated by outlier values (which are common in later layers with kurtosis up to 81K)
|
| 882 |
+
- Cosine similarity preserves the direction of the hidden state vector, which matters more than exact magnitude for downstream attention and expert computations
|
| 883 |
+
- The 0.1 weighting keeps MSE as the primary objective while regularizing directions
|
| 884 |
+
|
| 885 |
+
### 7.4 Reference Layer Stride
|
| 886 |
+
|
| 887 |
+
The stride of 12 (giving reference layers {0, 12, 24, 36}) was chosen as a balance:
|
| 888 |
+
- More reference layers (smaller stride) → better stale signals but more communication (ref layers use standard compression without stale)
|
| 889 |
+
- Fewer reference layers (larger stride) → stale signals become less correlated with non-ref layers
|
| 890 |
+
- stride=12 gives 4 reference layers covering 48 layers, with each non-ref layer at most 11 layers away from its reference
|
| 891 |
+
|
| 892 |
+
### 7.5 Training Data Size
|
| 893 |
+
|
| 894 |
+
100,000 tokens per layer (increased from initial 10,000). Each token produces a 2048-dim vector, so training data per layer is 100K × 2048 = 204.8M values. This is sufficient for learning a linear map with ~4M parameters (2x compression, per-layer).
|
| 895 |
+
|
| 896 |
+
### 7.6 Model Precision
|
| 897 |
+
|
| 898 |
+
All tasks use the same model in full BF16 precision (no weight quantization). This ensures:
|
| 899 |
+
- Hidden states used for offline training exactly match inference conditions
|
| 900 |
+
- End-to-end training has proper gradient flow through frozen layers
|
| 901 |
+
- All methods share the same baseline perplexity, enabling direct comparison
|
| 902 |
+
- 4-bit NF4 quantization is available via `--load-in-4bit` but is not the default
|
| 903 |
+
|
| 904 |
+
---
|
| 905 |
+
|
| 906 |
+
## 8. Implementation Details
|
| 907 |
+
|
| 908 |
+
### 8.1 Hook-Based Evaluation and Training
|
| 909 |
+
|
| 910 |
+
Four hook modes are used across experiments:
|
| 911 |
+
|
| 912 |
+
| Mode | Hook type | Used in |
|
| 913 |
+
|---|---|---|
|
| 914 |
+
| `evaluate_perplexity_with_compression` | Same compress/decompress for all layers | Shared compressor (Task 3) |
|
| 915 |
+
| `evaluate_perplexity_with_perlayer_compression` | Per-layer compress/decompress dicts | Per-layer compressor (Task 3b) |
|
| 916 |
+
| `evaluate_perplexity_with_stale_compression` | Per-layer + stale cache + ref/non-ref split | Stale-conditioned (Tasks 4a/4b) |
|
| 917 |
+
| `E2ECompressorManager.register_hooks()` | Per-layer, trainable, with/without stale cache | E2E training + eval (Task 5) |
|
| 918 |
+
|
| 919 |
+
The stale evaluation maintains a `stale_cache` dictionary that is populated by reference layer pre-hooks and read by subsequent non-reference layer hooks. This works because PyTorch processes layers sequentially (layer 0 before layer 1, etc.).
|
| 920 |
+
|
| 921 |
+
**Device safety in evaluation hooks:** With `device_map="auto"`, model layers may reside on
|
| 922 |
+
different GPUs. All evaluation hooks in `model_utils.py` (`evaluate_perplexity_with_perlayer_compression`
|
| 923 |
+
and `evaluate_perplexity_with_stale_compression`) explicitly call `.to(x.device)` on
|
| 924 |
+
compressor/decompressor outputs before returning them to the model. This ensures correctness
|
| 925 |
+
when compressor weights and MoE layers are on different devices.
|
| 926 |
+
|
| 927 |
+
**E2E training hooks (Task 5)** differ from evaluation hooks in two ways:
|
| 928 |
+
1. Compressor/decompressor parameters have `requires_grad=True`, so the autograd graph is maintained through the hooks.
|
| 929 |
+
2. For stale mode (5b), the cached stale signal is **not detached** — gradients flow through the stale path to earlier layers, enabling true end-to-end optimization.
|
| 930 |
+
|
| 931 |
+
### 8.2 MoE Layer Detection
|
| 932 |
+
|
| 933 |
+
`find_moe_layers()` in `model_utils.py` detects MoE modules by:
|
| 934 |
+
1. Checking if the class name contains "Moe", "MoE", or "SparseMoe"
|
| 935 |
+
2. Checking for `experts` attribute
|
| 936 |
+
3. Checking for both `gate` and `experts` attributes
|
| 937 |
+
|
| 938 |
+
This is model-agnostic and works for Qwen3, Mixtral, and other MoE architectures.
|
| 939 |
+
|
| 940 |
+
### 8.3 File Organization
|
| 941 |
+
|
| 942 |
+
**Offline experiments (Tasks 1–4)** follow the same pattern:
|
| 943 |
+
1. Load cached hidden states from `data/hidden_states/`
|
| 944 |
+
2. Train compressors on dispatch states
|
| 945 |
+
3. Evaluate reconstruction metrics (offline, on cached data)
|
| 946 |
+
4. Load the full model and evaluate perplexity (online, with hooks)
|
| 947 |
+
5. Save results to `results/{experiment}/`
|
| 948 |
+
|
| 949 |
+
**End-to-end experiments (Task 5)** follow a different pattern:
|
| 950 |
+
1. Load the full model in BF16 across 4 GPUs
|
| 951 |
+
2. Load and tokenize training data (Dolci-Instruct-SFT)
|
| 952 |
+
3. For each compression ratio: create compressor manager, train e2e, save weights
|
| 953 |
+
4. Evaluate perplexity on Dolci-Instruct-SFT (with hooks, same as offline)
|
| 954 |
+
5. Save results to `results/05{a,b}_e2e_{perlayer,stale}/`
|
| 955 |
+
|
| 956 |
+
Bash wrappers in `scripts/` handle environment setup, module loading, and argument passing.
|
| 957 |
+
|
| 958 |
+
### 8.4 Progress Tracking and Logging
|
| 959 |
+
|
| 960 |
+
All long-running loops use `tqdm` progress bars (written to stderr) for real-time progress monitoring with elapsed time and ETA. Key loops instrumented:
|
| 961 |
+
|
| 962 |
+
- **Training loops:** Epoch progress with loss/cosine postfix (all training functions)
|
| 963 |
+
- **Layer loops:** Per-layer training iteration (Tasks 3b, 4a/4b)
|
| 964 |
+
- **Data loading:** Calibration data and tokenization progress
|
| 965 |
+
- **Evaluation:** Perplexity evaluation sequence progress, quantization config iteration
|
| 966 |
+
- **Ratio loops:** Outer compression ratio iteration (all tasks)
|
| 967 |
+
|
| 968 |
+
Each bash script redirects output to two log files in the task's output directory:
|
| 969 |
+
|
| 970 |
+
| File | Contents | Source |
|
| 971 |
+
|---|---|---|
|
| 972 |
+
| `run.log` | Full output (print statements, results, summaries) | stdout |
|
| 973 |
+
| `progress.log` | tqdm progress bars (elapsed time, ETA, loss metrics) | stderr |
|
| 974 |
+
|
| 975 |
+
Monitor progress of a running experiment: `tail -f results/<task>/progress.log`
|
| 976 |
+
|
| 977 |
+
---
|
| 978 |
+
|
| 979 |
+
## 9. Reproducibility
|
| 980 |
+
|
| 981 |
+
### 9.1 Software Environment
|
| 982 |
+
|
| 983 |
+
- Python 3.11
|
| 984 |
+
- PyTorch (via `pip install torch` with CUDA 12.6)
|
| 985 |
+
- Transformers (HuggingFace)
|
| 986 |
+
- bitsandbytes (optional, for 4-bit model loading)
|
| 987 |
+
- datasets (for allenai/Dolci-Instruct-SFT)
|
| 988 |
+
- matplotlib, numpy
|
| 989 |
+
|
| 990 |
+
### 9.2 Hardware
|
| 991 |
+
|
| 992 |
+
- NVIDIA H100 80 GB GPUs (8 available)
|
| 993 |
+
- Tasks 1–4: single GPU sufficient (model in full BF16, ~60 GB on one H100 80 GB)
|
| 994 |
+
- Task 5: 4 GPUs per job (model in full BF16, ~60 GB + backprop memory); 05a and 05b run in parallel on GPUs 0-3 and 4-7
|
| 995 |
+
- 500+ GB system RAM (required for loading ~37 GB of hidden states for offline tasks)
|
| 996 |
+
- Compute Canada cluster
|
| 997 |
+
|
| 998 |
+
### 9.3 Random Seeds and Data Splitting
|
| 999 |
+
|
| 1000 |
+
All experiments use **seed=42** for reproducibility. A deterministic 80/10/10
|
| 1001 |
+
train/val/test split of the Dolci-Instruct-SFT dataset rows is computed via
|
| 1002 |
+
`get_split_indices()` in `model_utils.py`:
|
| 1003 |
+
|
| 1004 |
+
```python
|
| 1005 |
+
rng = random.Random(42)
|
| 1006 |
+
indices = list(range(dataset_size))
|
| 1007 |
+
rng.shuffle(indices)
|
| 1008 |
+
# 80% train, 10% val, 10% test
|
| 1009 |
+
```
|
| 1010 |
+
|
| 1011 |
+
**Split consistency across tasks:**
|
| 1012 |
+
- Task 1 hidden state collection: TRAIN split (max_samples=10000)
|
| 1013 |
+
- Tasks 2–4 offline training: uses cached hidden states from Task 1 (TRAIN split)
|
| 1014 |
+
- Tasks 2–4 PPL evaluation: TEST split (max_samples_ppl=50000, response-only)
|
| 1015 |
+
- Task 5 E2E training: TRAIN split (500K sequences HF / 100K Megatron, SFT mode)
|
| 1016 |
+
- Task 5 E2E validation: VAL split sequences (SFT mode)
|
| 1017 |
+
- Task 5 PPL evaluation: TEST split (same as tasks 2–4, response-only)
|
| 1018 |
+
|
| 1019 |
+
**SFT data loading (Task 5 and PPL evaluation):**
|
| 1020 |
+
- Each conversation is tokenized independently (one sample = one sequence)
|
| 1021 |
+
- Labels are -100 for non-assistant tokens, actual token IDs for assistant responses
|
| 1022 |
+
- `_tokenize_sft_sample()` in `model_utils.py` finds assistant token boundaries
|
| 1023 |
+
via incremental prefix tokenization of the chat template
|
| 1024 |
+
- Max sequence length: 2048 (configurable via `--max-length`)
|
| 1025 |
+
- Loss and perplexity are computed on response tokens only
|
| 1026 |
+
|
| 1027 |
+
Additional seed setting in Task 5:
|
| 1028 |
+
- `random.seed(42)`, `np.random.seed(42)`, `torch.manual_seed(42)`,
|
| 1029 |
+
`torch.cuda.manual_seed_all(42)` at start of main()
|
| 1030 |
+
- DataLoader shuffling uses PyTorch's seeded RNG
|
| 1031 |
+
|
| 1032 |
+
### 9.4 Experiment Tracking (Wandb)
|
| 1033 |
+
|
| 1034 |
+
Both HF and Megatron E2E scripts support Weights & Biases logging:
|
| 1035 |
+
|
| 1036 |
+
- **CLI:** `--wandb` / `--no-wandb`, `--wandb-project <name>`
|
| 1037 |
+
- **Logged metrics:** `train/loss` and `train/lr` per optimizer step,
|
| 1038 |
+
`val/loss` every `--val-interval` steps (default 2500) and at end of epoch,
|
| 1039 |
+
`train/epoch_loss` per epoch
|
| 1040 |
+
- **Projects:** `ecmoe-e2e` (HF), `ecmoe-megatron-e2e` (Megatron)
|
| 1041 |
+
- **Default:** Enabled in bash scripts via `WANDB_FLAG`; disable with
|
| 1042 |
+
`WANDB_FLAG="--no-wandb" bash scripts/05_run_e2e_compressor.sh none`
|
| 1043 |
+
- Megatron: only rank 0 logs to wandb
|
| 1044 |
+
- Megatron `train/loss` and `train/epoch_loss` are DP-averaged (all-reduced across
|
| 1045 |
+
data-parallel ranks) before logging, so wandb values reflect the true global loss
|
| 1046 |
+
- Graceful fallback if wandb is not installed (HAS_WANDB flag)
|
| 1047 |
+
|
| 1048 |
+
---
|
| 1049 |
+
|
| 1050 |
+
## 10. Task 8: EP Communication Compression in vLLM
|
| 1051 |
+
|
| 1052 |
+
### 10.1 Motivation
|
| 1053 |
+
|
| 1054 |
+
Tasks 5–7 evaluate compression quality using PyTorch hooks that compress and decompress
|
| 1055 |
+
on the **same GPU** — simulating the quality impact but not achieving actual communication
|
| 1056 |
+
reduction. In real expert parallelism, the pipeline is:
|
| 1057 |
+
|
| 1058 |
+
1. Router computes logits from **original** hidden states (attention GPU)
|
| 1059 |
+
2. **Compressor** runs on attention GPU: `hidden_dim` → `bottleneck_dim`
|
| 1060 |
+
3. All-to-all dispatch sends only the **compressed** tensor (reduced volume!)
|
| 1061 |
+
4. **Decompressor** runs on expert GPU: `bottleneck_dim` → `hidden_dim`
|
| 1062 |
+
5. Experts compute on decompressed states
|
| 1063 |
+
|
| 1064 |
+
Task 8 modifies vLLM's `FusedMoE.forward_impl()` to implement this pipeline,
|
| 1065 |
+
compressing BEFORE dispatch and decompressing AFTER.
|
| 1066 |
+
|
| 1067 |
+
### 10.2 Implementation
|
| 1068 |
+
|
| 1069 |
+
**Patched vLLM (`scripts/patch_vllm_fused_moe.py`):** Adds ~12 lines to
|
| 1070 |
+
`FusedMoE.forward_impl()` at three locations:
|
| 1071 |
+
|
| 1072 |
+
1. **Compress before dispatch (EP mode):** `_ecmoe_compress_fn(hidden_states)` →
|
| 1073 |
+
dispatches compressed tensor instead of full hidden_dim.
|
| 1074 |
+
2. **Decompress after dispatch (EP mode):** After `get_ep_group().dispatch()`,
|
| 1075 |
+
`_ecmoe_decompress_fn(hidden_states_combined)` restores full hidden_dim.
|
| 1076 |
+
3. **Single-GPU fallback:** When `do_naive_dispatch_combine=False` (TP=1),
|
| 1077 |
+
applies compress→decompress in-place for simulation mode.
|
| 1078 |
+
|
| 1079 |
+
When `_ecmoe_compress_fn` is None (default), behavior is identical to stock vLLM.
|
| 1080 |
+
|
| 1081 |
+
**EP-aware registration (`src/vllm_ep_compression.py`):** Uses `apply_model()`
|
| 1082 |
+
to set compress/decompress functions on each FusedMoE instance:
|
| 1083 |
+
|
| 1084 |
+
- **Per-layer:** `register_ep_perlayer()` — Independent linear compress/decompress per layer.
|
| 1085 |
+
- **Stale-conditioned:** `register_ep_stale()` — Reference layers piggyback stale signal
|
| 1086 |
+
on compressed tensor before dispatch. Non-reference layers dispatch only compressed data.
|
| 1087 |
+
|
| 1088 |
+
### 10.3 Stale Broadcast via Dispatch Piggybacking
|
| 1089 |
+
|
| 1090 |
+
**Reference layers (0, 12, 24, 36):**
|
| 1091 |
+
- compress_fn: `cat(compressed[B, bottleneck], stale[B, stale_dim])` → dispatch `[B, bottleneck + stale_dim]`
|
| 1092 |
+
- decompress_fn: split → cache stale_part globally → decompress compressed_part
|
| 1093 |
+
|
| 1094 |
+
**Non-reference layers (all others):**
|
| 1095 |
+
- compress_fn: `compressed[B, bottleneck]` only → dispatch `[B, bottleneck]` (maximum compression!)
|
| 1096 |
+
- decompress_fn: retrieve cached stale → `cat(compressed, cached_stale)` → StaleDecomp
|
| 1097 |
+
|
| 1098 |
+
**Correctness:** vLLM's default `all2all_backend=allgather_reducescatter` means after
|
| 1099 |
+
dispatch, every rank has ALL tokens in consistent ordering. Stale cached from reference
|
| 1100 |
+
layers matches token ordering at non-reference layers.
|
| 1101 |
+
|
| 1102 |
+
### 10.4 Communication Savings
|
| 1103 |
+
|
| 1104 |
+
| Mode | Ref layers (4/48) | Non-ref layers (44/48) | Weighted avg | vs baseline 2048 |
|
| 1105 |
+
|------|-------------------|----------------------|--------------|-------------------|
|
| 1106 |
+
| perlayer 2x | 1024 | 1024 | 1024 | **50% saving** |
|
| 1107 |
+
| perlayer 4x | 512 | 512 | 512 | **75% saving** |
|
| 1108 |
+
| stale(comp) 4x | 1024 | 512 | 555 | **73% saving** |
|
| 1109 |
+
| stale(uncomp) 4x | 2560 | 512 | 683 | **67% saving** |
|
| 1110 |
+
| stale(uncomp) 2x | 3072 | 1024 | 1195 | **42% saving** |
|
| 1111 |
+
|
| 1112 |
+
Stale broadcast cost is amortized over ~11 non-reference layers per reference layer.
|
| 1113 |
+
|
| 1114 |
+
### 10.5 Evaluation Modes
|
| 1115 |
+
|
| 1116 |
+
- **simulation** (`--mode simulation`): Single-GPU (TP=1), no dispatch/combine.
|
| 1117 |
+
Validates numerical correctness against existing split-mode results.
|
| 1118 |
+
- **ep** (`--mode ep`): Multi-GPU (TP=4, `enable_expert_parallel=True`).
|
| 1119 |
+
Uses actual EP dispatch/combine with compressed tensors.
|
| 1120 |
+
|
| 1121 |
+
Both use Task 7a/7b weights (split-mode E2E trained) from
|
| 1122 |
+
`results/07a_megatron_e2e_split_perlayer/` and `results/07b_megatron_e2e_split_stale/`.
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
flight-search:
|
| 3 |
+
build: .
|
| 4 |
+
ports:
|
| 5 |
+
- "8080:8080"
|
| 6 |
+
restart: unless-stopped
|
frontend/eslint.config.js
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import js from '@eslint/js'
|
| 2 |
+
import globals from 'globals'
|
| 3 |
+
import reactHooks from 'eslint-plugin-react-hooks'
|
| 4 |
+
import reactRefresh from 'eslint-plugin-react-refresh'
|
| 5 |
+
import tseslint from 'typescript-eslint'
|
| 6 |
+
import { defineConfig, globalIgnores } from 'eslint/config'
|
| 7 |
+
|
| 8 |
+
export default defineConfig([
|
| 9 |
+
globalIgnores(['dist']),
|
| 10 |
+
{
|
| 11 |
+
files: ['**/*.{ts,tsx}'],
|
| 12 |
+
extends: [
|
| 13 |
+
js.configs.recommended,
|
| 14 |
+
tseslint.configs.recommended,
|
| 15 |
+
reactHooks.configs.flat.recommended,
|
| 16 |
+
reactRefresh.configs.vite,
|
| 17 |
+
],
|
| 18 |
+
languageOptions: {
|
| 19 |
+
ecmaVersion: 2020,
|
| 20 |
+
globals: globals.browser,
|
| 21 |
+
},
|
| 22 |
+
},
|
| 23 |
+
])
|
frontend/index.html
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>Flight Search</title>
|
| 8 |
+
</head>
|
| 9 |
+
<body>
|
| 10 |
+
<div id="root"></div>
|
| 11 |
+
<script type="module" src="/src/main.tsx"></script>
|
| 12 |
+
</body>
|
| 13 |
+
</html>
|
frontend/package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
frontend/package.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "frontend",
|
| 3 |
+
"private": true,
|
| 4 |
+
"version": "0.0.0",
|
| 5 |
+
"type": "module",
|
| 6 |
+
"scripts": {
|
| 7 |
+
"dev": "vite",
|
| 8 |
+
"build": "tsc -b && vite build",
|
| 9 |
+
"lint": "eslint .",
|
| 10 |
+
"preview": "vite preview"
|
| 11 |
+
},
|
| 12 |
+
"dependencies": {
|
| 13 |
+
"react": "^19.2.0",
|
| 14 |
+
"react-dom": "^19.2.0",
|
| 15 |
+
"react-router-dom": "^7.13.1"
|
| 16 |
+
},
|
| 17 |
+
"devDependencies": {
|
| 18 |
+
"@eslint/js": "^9.39.1",
|
| 19 |
+
"@tailwindcss/vite": "^4.2.1",
|
| 20 |
+
"@types/node": "^24.10.1",
|
| 21 |
+
"@types/react": "^19.2.7",
|
| 22 |
+
"@types/react-dom": "^19.2.3",
|
| 23 |
+
"@vitejs/plugin-react": "^5.1.1",
|
| 24 |
+
"eslint": "^9.39.1",
|
| 25 |
+
"eslint-plugin-react-hooks": "^7.0.1",
|
| 26 |
+
"eslint-plugin-react-refresh": "^0.4.24",
|
| 27 |
+
"globals": "^16.5.0",
|
| 28 |
+
"tailwindcss": "^4.2.1",
|
| 29 |
+
"typescript": "~5.9.3",
|
| 30 |
+
"typescript-eslint": "^8.48.0",
|
| 31 |
+
"vite": "^7.3.1"
|
| 32 |
+
}
|
| 33 |
+
}
|
frontend/public/vite.svg
ADDED
|
|
frontend/src/App.css
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#root {
|
| 2 |
+
max-width: 1280px;
|
| 3 |
+
margin: 0 auto;
|
| 4 |
+
padding: 2rem;
|
| 5 |
+
text-align: center;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
.logo {
|
| 9 |
+
height: 6em;
|
| 10 |
+
padding: 1.5em;
|
| 11 |
+
will-change: filter;
|
| 12 |
+
transition: filter 300ms;
|
| 13 |
+
}
|
| 14 |
+
.logo:hover {
|
| 15 |
+
filter: drop-shadow(0 0 2em #646cffaa);
|
| 16 |
+
}
|
| 17 |
+
.logo.react:hover {
|
| 18 |
+
filter: drop-shadow(0 0 2em #61dafbaa);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
@keyframes logo-spin {
|
| 22 |
+
from {
|
| 23 |
+
transform: rotate(0deg);
|
| 24 |
+
}
|
| 25 |
+
to {
|
| 26 |
+
transform: rotate(360deg);
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
@media (prefers-reduced-motion: no-preference) {
|
| 31 |
+
a:nth-of-type(2) .logo {
|
| 32 |
+
animation: logo-spin infinite 20s linear;
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
.card {
|
| 37 |
+
padding: 2em;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
.read-the-docs {
|
| 41 |
+
color: #888;
|
| 42 |
+
}
|
frontend/src/App.tsx
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { BrowserRouter, Route, Routes } from 'react-router-dom';
|
| 2 |
+
import Header from './components/shared/Header';
|
| 3 |
+
import SearchPage from './pages/SearchPage';
|
| 4 |
+
import ResultsPage from './pages/ResultsPage';
|
| 5 |
+
|
| 6 |
+
export default function App() {
|
| 7 |
+
return (
|
| 8 |
+
<BrowserRouter>
|
| 9 |
+
<Header />
|
| 10 |
+
<Routes>
|
| 11 |
+
<Route path="/" element={<SearchPage />} />
|
| 12 |
+
<Route path="/results" element={<ResultsPage />} />
|
| 13 |
+
</Routes>
|
| 14 |
+
</BrowserRouter>
|
| 15 |
+
);
|
| 16 |
+
}
|
frontend/src/api/client.ts
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { AutocompleteResult, CalendarResponse, SearchRequest, SearchResponse } from './types';
|
| 2 |
+
|
| 3 |
+
const BASE_URL = '/api';
|
| 4 |
+
|
| 5 |
+
async function fetchJson<T>(url: string, init?: RequestInit): Promise<T> {
|
| 6 |
+
const res = await fetch(url, init);
|
| 7 |
+
if (!res.ok) {
|
| 8 |
+
const text = await res.text();
|
| 9 |
+
throw new Error(`API error ${res.status}: ${text}`);
|
| 10 |
+
}
|
| 11 |
+
return res.json();
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export async function searchAirports(query: string): Promise<AutocompleteResult[]> {
|
| 15 |
+
if (!query || query.length < 1) return [];
|
| 16 |
+
return fetchJson<AutocompleteResult[]>(
|
| 17 |
+
`${BASE_URL}/airports/autocomplete?q=${encodeURIComponent(query)}`
|
| 18 |
+
);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
export async function searchFlights(req: SearchRequest): Promise<SearchResponse> {
|
| 22 |
+
return fetchJson<SearchResponse>(`${BASE_URL}/search`, {
|
| 23 |
+
method: 'POST',
|
| 24 |
+
headers: { 'Content-Type': 'application/json' },
|
| 25 |
+
body: JSON.stringify(req),
|
| 26 |
+
});
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
export async function getCalendar(
|
| 30 |
+
origin: string,
|
| 31 |
+
destination: string,
|
| 32 |
+
year: number,
|
| 33 |
+
month: number,
|
| 34 |
+
cabinClass: string = 'economy'
|
| 35 |
+
): Promise<CalendarResponse> {
|
| 36 |
+
return fetchJson<CalendarResponse>(
|
| 37 |
+
`${BASE_URL}/calendar?origin=${origin}&destination=${destination}&year=${year}&month=${month}&cabin_class=${cabinClass}`
|
| 38 |
+
);
|
| 39 |
+
}
|
frontend/src/api/types.ts
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export type CabinClass = 'economy' | 'premium_economy' | 'business' | 'first';
|
| 2 |
+
export type TripType = 'one_way' | 'round_trip' | 'multi_city';
|
| 3 |
+
export type SortBy = 'best' | 'cheapest' | 'fastest';
|
| 4 |
+
|
| 5 |
+
export interface AutocompleteResult {
|
| 6 |
+
iata: string;
|
| 7 |
+
name: string;
|
| 8 |
+
city_name: string;
|
| 9 |
+
country: string;
|
| 10 |
+
display_name: string;
|
| 11 |
+
hub_score: number;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export interface FlightSegment {
|
| 15 |
+
airline_code: string;
|
| 16 |
+
airline_name: string;
|
| 17 |
+
flight_number: string;
|
| 18 |
+
aircraft: string;
|
| 19 |
+
origin: string;
|
| 20 |
+
origin_city: string;
|
| 21 |
+
destination: string;
|
| 22 |
+
destination_city: string;
|
| 23 |
+
departure: string;
|
| 24 |
+
arrival: string;
|
| 25 |
+
duration_minutes: number;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
export interface FlightOffer {
|
| 29 |
+
id: string;
|
| 30 |
+
segments: FlightSegment[];
|
| 31 |
+
total_duration_minutes: number;
|
| 32 |
+
stops: number;
|
| 33 |
+
price_usd: number;
|
| 34 |
+
cabin_class: CabinClass;
|
| 35 |
+
origin: string;
|
| 36 |
+
destination: string;
|
| 37 |
+
departure: string;
|
| 38 |
+
arrival: string;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
export interface SearchLeg {
|
| 42 |
+
origin: string;
|
| 43 |
+
destination: string;
|
| 44 |
+
date: string; // YYYY-MM-DD
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
export interface Passengers {
|
| 48 |
+
adults: number;
|
| 49 |
+
children: number;
|
| 50 |
+
infants: number;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
export interface Filters {
|
| 54 |
+
max_stops?: number | null;
|
| 55 |
+
max_price?: number | null;
|
| 56 |
+
max_duration_minutes?: number | null;
|
| 57 |
+
airlines?: string[] | null;
|
| 58 |
+
departure_time_min?: string | null;
|
| 59 |
+
departure_time_max?: string | null;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
export interface SearchRequest {
|
| 63 |
+
trip_type: TripType;
|
| 64 |
+
legs: SearchLeg[];
|
| 65 |
+
passengers: Passengers;
|
| 66 |
+
cabin_class: CabinClass;
|
| 67 |
+
filters: Filters;
|
| 68 |
+
sort_by: SortBy;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
export interface SearchResponse {
|
| 72 |
+
outbound_flights: FlightOffer[];
|
| 73 |
+
return_flights: FlightOffer[];
|
| 74 |
+
search_id: string;
|
| 75 |
+
origin: string;
|
| 76 |
+
destination: string;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
export interface CalendarDay {
|
| 80 |
+
date: string;
|
| 81 |
+
cheapest_price: number | null;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
export interface CalendarResponse {
|
| 85 |
+
origin: string;
|
| 86 |
+
destination: string;
|
| 87 |
+
year: number;
|
| 88 |
+
month: number;
|
| 89 |
+
days: CalendarDay[];
|
| 90 |
+
}
|
frontend/src/assets/react.svg
ADDED
|
|
frontend/src/components/results/FilterPanel.tsx
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useMemo } from 'react';
|
| 2 |
+
import type { Filters, FlightOffer } from '../../api/types';
|
| 3 |
+
|
| 4 |
+
interface Props {
|
| 5 |
+
flights: FlightOffer[];
|
| 6 |
+
filters: Filters;
|
| 7 |
+
onChange: (f: Filters) => void;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
export default function FilterPanel({ flights, filters, onChange }: Props) {
|
| 11 |
+
// Compute available filter options from flights
|
| 12 |
+
const airlines = useMemo(() => {
|
| 13 |
+
const map = new Map<string, string>();
|
| 14 |
+
flights.forEach(f => f.segments.forEach(s => map.set(s.airline_code, s.airline_name)));
|
| 15 |
+
return Array.from(map.entries()).sort((a, b) => a[1].localeCompare(b[1]));
|
| 16 |
+
}, [flights]);
|
| 17 |
+
|
| 18 |
+
const maxStopsAvailable = useMemo(() => Math.max(...flights.map(f => f.stops), 0), [flights]);
|
| 19 |
+
|
| 20 |
+
return (
|
| 21 |
+
<div className="space-y-6" data-testid="filter-panel">
|
| 22 |
+
{/* Stops filter */}
|
| 23 |
+
<div>
|
| 24 |
+
<h3 className="mb-2 text-sm font-medium text-gray-900">Stops</h3>
|
| 25 |
+
<div className="space-y-1">
|
| 26 |
+
{[null, 0, 1, 2].filter(v => v === null || v <= maxStopsAvailable).map(v => (
|
| 27 |
+
<label key={String(v)} className="flex items-center gap-2 cursor-pointer">
|
| 28 |
+
<input
|
| 29 |
+
type="radio"
|
| 30 |
+
name="stops"
|
| 31 |
+
checked={filters.max_stops === v}
|
| 32 |
+
onChange={() => onChange({ ...filters, max_stops: v })}
|
| 33 |
+
className="accent-[#1a73e8]"
|
| 34 |
+
data-testid={`filter-stops-${v === null ? 'any' : v}`}
|
| 35 |
+
/>
|
| 36 |
+
<span className="text-sm text-gray-700">
|
| 37 |
+
{v === null ? 'Any' : v === 0 ? 'Nonstop only' : `Up to ${v} stop${v > 1 ? 's' : ''}`}
|
| 38 |
+
</span>
|
| 39 |
+
</label>
|
| 40 |
+
))}
|
| 41 |
+
</div>
|
| 42 |
+
</div>
|
| 43 |
+
|
| 44 |
+
{/* Price filter */}
|
| 45 |
+
<div>
|
| 46 |
+
<h3 className="mb-2 text-sm font-medium text-gray-900">Max price</h3>
|
| 47 |
+
<div className="flex items-center gap-2">
|
| 48 |
+
<span className="text-sm text-gray-500">$</span>
|
| 49 |
+
<input
|
| 50 |
+
type="number"
|
| 51 |
+
value={filters.max_price ?? ''}
|
| 52 |
+
onChange={e => onChange({ ...filters, max_price: e.target.value ? Number(e.target.value) : null })}
|
| 53 |
+
placeholder="Any"
|
| 54 |
+
className="w-24 rounded border border-gray-300 px-2 py-1 text-sm focus:border-[#1a73e8] focus:outline-none"
|
| 55 |
+
min={0}
|
| 56 |
+
data-testid="filter-max-price"
|
| 57 |
+
/>
|
| 58 |
+
</div>
|
| 59 |
+
</div>
|
| 60 |
+
|
| 61 |
+
{/* Departure time filter */}
|
| 62 |
+
<div>
|
| 63 |
+
<h3 className="mb-2 text-sm font-medium text-gray-900">Departure time</h3>
|
| 64 |
+
<div className="flex items-center gap-2">
|
| 65 |
+
<input
|
| 66 |
+
type="time"
|
| 67 |
+
value={filters.departure_time_min ?? ''}
|
| 68 |
+
onChange={e => onChange({ ...filters, departure_time_min: e.target.value || null })}
|
| 69 |
+
className="rounded border border-gray-300 px-2 py-1 text-sm focus:border-[#1a73e8] focus:outline-none"
|
| 70 |
+
data-testid="filter-dep-time-min"
|
| 71 |
+
/>
|
| 72 |
+
<span className="text-xs text-gray-500">to</span>
|
| 73 |
+
<input
|
| 74 |
+
type="time"
|
| 75 |
+
value={filters.departure_time_max ?? ''}
|
| 76 |
+
onChange={e => onChange({ ...filters, departure_time_max: e.target.value || null })}
|
| 77 |
+
className="rounded border border-gray-300 px-2 py-1 text-sm focus:border-[#1a73e8] focus:outline-none"
|
| 78 |
+
data-testid="filter-dep-time-max"
|
| 79 |
+
/>
|
| 80 |
+
</div>
|
| 81 |
+
</div>
|
| 82 |
+
|
| 83 |
+
{/* Airlines filter */}
|
| 84 |
+
{airlines.length > 1 && (
|
| 85 |
+
<div>
|
| 86 |
+
<h3 className="mb-2 text-sm font-medium text-gray-900">Airlines</h3>
|
| 87 |
+
<div className="max-h-48 space-y-1 overflow-y-auto">
|
| 88 |
+
{airlines.map(([code, name]) => {
|
| 89 |
+
const selected = !filters.airlines || filters.airlines.includes(code);
|
| 90 |
+
return (
|
| 91 |
+
<label key={code} className="flex items-center gap-2 cursor-pointer">
|
| 92 |
+
<input
|
| 93 |
+
type="checkbox"
|
| 94 |
+
checked={selected}
|
| 95 |
+
onChange={() => {
|
| 96 |
+
let next: string[] | null;
|
| 97 |
+
if (!filters.airlines) {
|
| 98 |
+
// First deselection: select all except this one
|
| 99 |
+
next = airlines.filter(([c]) => c !== code).map(([c]) => c);
|
| 100 |
+
} else if (selected) {
|
| 101 |
+
next = filters.airlines.filter(c => c !== code);
|
| 102 |
+
if (next.length === 0) next = null; // deselect all = any
|
| 103 |
+
} else {
|
| 104 |
+
next = [...filters.airlines, code];
|
| 105 |
+
if (next.length === airlines.length) next = null; // all selected = any
|
| 106 |
+
}
|
| 107 |
+
onChange({ ...filters, airlines: next });
|
| 108 |
+
}}
|
| 109 |
+
className="accent-[#1a73e8]"
|
| 110 |
+
data-testid={`filter-airline-${code}`}
|
| 111 |
+
/>
|
| 112 |
+
<span className="text-sm text-gray-700">{name} ({code})</span>
|
| 113 |
+
</label>
|
| 114 |
+
);
|
| 115 |
+
})}
|
| 116 |
+
</div>
|
| 117 |
+
</div>
|
| 118 |
+
)}
|
| 119 |
+
|
| 120 |
+
{/* Clear all */}
|
| 121 |
+
<button
|
| 122 |
+
onClick={() => onChange({
|
| 123 |
+
max_stops: null, max_price: null, max_duration_minutes: null,
|
| 124 |
+
airlines: null, departure_time_min: null, departure_time_max: null,
|
| 125 |
+
})}
|
| 126 |
+
className="text-sm text-[#1a73e8] hover:underline cursor-pointer"
|
| 127 |
+
data-testid="filter-clear"
|
| 128 |
+
>
|
| 129 |
+
Clear all filters
|
| 130 |
+
</button>
|
| 131 |
+
</div>
|
| 132 |
+
);
|
| 133 |
+
}
|
frontend/src/components/results/FlightCard.tsx
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState } from 'react';
|
| 2 |
+
import type { FlightOffer } from '../../api/types';
|
| 3 |
+
import { formatDuration, formatPrice, formatStops, formatTime } from '../../utils/format';
|
| 4 |
+
import FlightSegmentView from './FlightSegment';
|
| 5 |
+
|
| 6 |
+
interface Props {
|
| 7 |
+
flight: FlightOffer;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
export default function FlightCard({ flight }: Props) {
|
| 11 |
+
const [expanded, setExpanded] = useState(false);
|
| 12 |
+
const firstSeg = flight.segments[0];
|
| 13 |
+
|
| 14 |
+
// Check if arrival is on a different day
|
| 15 |
+
const depDate = new Date(flight.departure).toDateString();
|
| 16 |
+
const arrDate = new Date(flight.arrival).toDateString();
|
| 17 |
+
const dayDiff = depDate !== arrDate;
|
| 18 |
+
|
| 19 |
+
return (
|
| 20 |
+
<div
|
| 21 |
+
className="rounded-lg border border-gray-200 bg-white hover:shadow-md transition-shadow cursor-pointer"
|
| 22 |
+
onClick={() => flight.stops > 0 && setExpanded(!expanded)}
|
| 23 |
+
data-testid={`flight-card-${flight.id}`}
|
| 24 |
+
role="article"
|
| 25 |
+
aria-label={`Flight from ${flight.origin} to ${flight.destination}, ${formatPrice(flight.price_usd)}`}
|
| 26 |
+
>
|
| 27 |
+
<div className="flex items-center gap-4 p-4">
|
| 28 |
+
{/* Airline badge */}
|
| 29 |
+
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-gray-100 text-xs font-bold text-gray-600 flex-shrink-0" data-testid="airline-badge">
|
| 30 |
+
{firstSeg.airline_code}
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
{/* Main info */}
|
| 34 |
+
<div className="flex flex-1 items-center gap-6">
|
| 35 |
+
{/* Times */}
|
| 36 |
+
<div className="flex items-center gap-3 flex-1">
|
| 37 |
+
<div className="text-right">
|
| 38 |
+
<div className="text-lg font-medium" data-testid="departure-time">{formatTime(flight.departure)}</div>
|
| 39 |
+
<div className="text-xs text-gray-500">{flight.origin}</div>
|
| 40 |
+
</div>
|
| 41 |
+
|
| 42 |
+
<div className="flex flex-1 flex-col items-center px-2">
|
| 43 |
+
<div className="text-xs text-gray-500">{formatDuration(flight.total_duration_minutes)}</div>
|
| 44 |
+
<div className="relative my-1 h-px w-full bg-gray-300">
|
| 45 |
+
<div className="absolute left-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
|
| 46 |
+
<div className="absolute right-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
|
| 47 |
+
{/* Stop indicators */}
|
| 48 |
+
{flight.stops > 0 && flight.segments.slice(0, -1).map((_, i) => (
|
| 49 |
+
<div
|
| 50 |
+
key={i}
|
| 51 |
+
className="absolute top-1/2 h-2 w-2 -translate-y-1/2 rounded-full bg-gray-400"
|
| 52 |
+
style={{ left: `${((i + 1) / flight.segments.length) * 100}%` }}
|
| 53 |
+
/>
|
| 54 |
+
))}
|
| 55 |
+
</div>
|
| 56 |
+
<div className="text-xs text-gray-500" data-testid="stops">{formatStops(flight.stops)}</div>
|
| 57 |
+
</div>
|
| 58 |
+
|
| 59 |
+
<div>
|
| 60 |
+
<div className="text-lg font-medium" data-testid="arrival-time">
|
| 61 |
+
{formatTime(flight.arrival)}
|
| 62 |
+
{dayDiff && <sup className="ml-0.5 text-xs text-red-500">+1</sup>}
|
| 63 |
+
</div>
|
| 64 |
+
<div className="text-xs text-gray-500">{flight.destination}</div>
|
| 65 |
+
</div>
|
| 66 |
+
</div>
|
| 67 |
+
|
| 68 |
+
{/* Airline name */}
|
| 69 |
+
<div className="hidden md:block text-xs text-gray-500 min-w-[100px]" data-testid="airline-name">
|
| 70 |
+
{firstSeg.airline_name}
|
| 71 |
+
</div>
|
| 72 |
+
</div>
|
| 73 |
+
|
| 74 |
+
{/* Price */}
|
| 75 |
+
<div className="text-right pl-4 min-w-[80px]">
|
| 76 |
+
<div className="text-lg font-medium text-gray-900" data-testid="price">
|
| 77 |
+
{formatPrice(flight.price_usd)}
|
| 78 |
+
</div>
|
| 79 |
+
<div className="text-xs text-gray-500">{flight.cabin_class.replace('_', ' ')}</div>
|
| 80 |
+
</div>
|
| 81 |
+
|
| 82 |
+
{/* Expand icon for connecting flights */}
|
| 83 |
+
{flight.stops > 0 && (
|
| 84 |
+
<svg
|
| 85 |
+
className={`h-5 w-5 text-gray-400 transition-transform ${expanded ? 'rotate-180' : ''}`}
|
| 86 |
+
viewBox="0 0 20 20" fill="currentColor"
|
| 87 |
+
>
|
| 88 |
+
<path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
|
| 89 |
+
</svg>
|
| 90 |
+
)}
|
| 91 |
+
</div>
|
| 92 |
+
|
| 93 |
+
{/* Expanded segment details */}
|
| 94 |
+
{expanded && flight.stops > 0 && (
|
| 95 |
+
<div className="border-t border-gray-100 px-4 pb-4" data-testid="segments-detail">
|
| 96 |
+
{flight.segments.map((seg, i) => (
|
| 97 |
+
<div key={i}>
|
| 98 |
+
<FlightSegmentView segment={seg} showDetails />
|
| 99 |
+
{i < flight.segments.length - 1 && (
|
| 100 |
+
<div className="ml-14 border-l-2 border-dashed border-gray-200 py-2 pl-4">
|
| 101 |
+
<span className="text-xs text-gray-400">
|
| 102 |
+
Layover at {seg.destination} ({seg.destination_city})
|
| 103 |
+
</span>
|
| 104 |
+
</div>
|
| 105 |
+
)}
|
| 106 |
+
</div>
|
| 107 |
+
))}
|
| 108 |
+
</div>
|
| 109 |
+
)}
|
| 110 |
+
</div>
|
| 111 |
+
);
|
| 112 |
+
}
|
frontend/src/components/results/FlightSegment.tsx
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { FlightSegment as SegmentType } from '../../api/types';
|
| 2 |
+
import { formatDuration, formatTime } from '../../utils/format';
|
| 3 |
+
|
| 4 |
+
interface Props {
|
| 5 |
+
segment: SegmentType;
|
| 6 |
+
showDetails?: boolean;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
export default function FlightSegmentView({ segment, showDetails }: Props) {
|
| 10 |
+
return (
|
| 11 |
+
<div className="flex items-center gap-4 py-2" data-testid="flight-segment">
|
| 12 |
+
{/* Airline badge */}
|
| 13 |
+
<div className="flex h-9 w-9 items-center justify-center rounded bg-gray-100 text-xs font-bold text-gray-600 flex-shrink-0" data-testid="airline-badge">
|
| 14 |
+
{segment.airline_code}
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
{/* Timeline */}
|
| 18 |
+
<div className="flex flex-1 items-center gap-3">
|
| 19 |
+
<div className="text-right min-w-[70px]">
|
| 20 |
+
<div className="text-base font-medium" data-testid="departure-time">{formatTime(segment.departure)}</div>
|
| 21 |
+
<div className="text-xs text-gray-500" data-testid="origin-code">{segment.origin}</div>
|
| 22 |
+
</div>
|
| 23 |
+
|
| 24 |
+
<div className="flex flex-1 flex-col items-center px-2">
|
| 25 |
+
<div className="text-xs text-gray-500">{formatDuration(segment.duration_minutes)}</div>
|
| 26 |
+
<div className="relative my-1 h-px w-full bg-gray-300">
|
| 27 |
+
<div className="absolute left-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
|
| 28 |
+
<div className="absolute right-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
|
| 29 |
+
</div>
|
| 30 |
+
{showDetails && (
|
| 31 |
+
<div className="text-xs text-gray-400">{segment.airline_name} · {segment.flight_number} · {segment.aircraft}</div>
|
| 32 |
+
)}
|
| 33 |
+
</div>
|
| 34 |
+
|
| 35 |
+
<div className="min-w-[70px]">
|
| 36 |
+
<div className="text-base font-medium" data-testid="arrival-time">{formatTime(segment.arrival)}</div>
|
| 37 |
+
<div className="text-xs text-gray-500" data-testid="destination-code">{segment.destination}</div>
|
| 38 |
+
</div>
|
| 39 |
+
</div>
|
| 40 |
+
</div>
|
| 41 |
+
);
|
| 42 |
+
}
|
frontend/src/components/results/NoResults.tsx
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface Props {
|
| 2 |
+
hasFilters: boolean;
|
| 3 |
+
onClearFilters?: () => void;
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
export default function NoResults({ hasFilters, onClearFilters }: Props) {
|
| 7 |
+
return (
|
| 8 |
+
<div className="flex flex-col items-center justify-center py-16 text-center" data-testid="no-results">
|
| 9 |
+
<svg className="mb-4 h-16 w-16 text-gray-300" viewBox="0 0 24 24" fill="currentColor">
|
| 10 |
+
<path d="M21 16v-2l-8-5V3.5c0-.83-.67-1.5-1.5-1.5S10 2.67 10 3.5V9l-8 5v2l8-2.5V19l-2 1.5V22l3.5-1 3.5 1v-1.5L13 19v-5.5l8 2.5z"/>
|
| 11 |
+
</svg>
|
| 12 |
+
<h3 className="text-lg font-medium text-gray-700">No flights found</h3>
|
| 13 |
+
<p className="mt-1 text-sm text-gray-500">
|
| 14 |
+
{hasFilters
|
| 15 |
+
? 'Try adjusting your filters or search criteria.'
|
| 16 |
+
: 'No direct or connecting flights available for this route.'}
|
| 17 |
+
</p>
|
| 18 |
+
{hasFilters && onClearFilters && (
|
| 19 |
+
<button
|
| 20 |
+
onClick={onClearFilters}
|
| 21 |
+
className="mt-4 text-sm font-medium text-[#1a73e8] hover:underline cursor-pointer"
|
| 22 |
+
data-testid="clear-filters-button"
|
| 23 |
+
>
|
| 24 |
+
Clear all filters
|
| 25 |
+
</button>
|
| 26 |
+
)}
|
| 27 |
+
</div>
|
| 28 |
+
);
|
| 29 |
+
}
|
frontend/src/components/results/SortBar.tsx
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { SortBy } from '../../api/types';
|
| 2 |
+
|
| 3 |
+
const OPTIONS: { value: SortBy; label: string }[] = [
|
| 4 |
+
{ value: 'best', label: 'Best' },
|
| 5 |
+
{ value: 'cheapest', label: 'Cheapest' },
|
| 6 |
+
{ value: 'fastest', label: 'Fastest' },
|
| 7 |
+
];
|
| 8 |
+
|
| 9 |
+
interface Props {
|
| 10 |
+
value: SortBy;
|
| 11 |
+
onChange: (v: SortBy) => void;
|
| 12 |
+
resultCount: number;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
export default function SortBar({ value, onChange, resultCount }: Props) {
|
| 16 |
+
return (
|
| 17 |
+
<div className="flex items-center justify-between" data-testid="sort-bar">
|
| 18 |
+
<span className="text-sm text-gray-500" data-testid="result-count">
|
| 19 |
+
{resultCount} result{resultCount !== 1 ? 's' : ''}
|
| 20 |
+
</span>
|
| 21 |
+
<div className="flex gap-1 rounded-lg bg-gray-100 p-1">
|
| 22 |
+
{OPTIONS.map(o => (
|
| 23 |
+
<button
|
| 24 |
+
key={o.value}
|
| 25 |
+
onClick={() => onChange(o.value)}
|
| 26 |
+
className={`rounded-md px-4 py-1.5 text-sm font-medium transition-colors cursor-pointer ${
|
| 27 |
+
value === o.value
|
| 28 |
+
? 'bg-white text-[#1a73e8] shadow-sm'
|
| 29 |
+
: 'text-gray-600 hover:text-gray-900'
|
| 30 |
+
}`}
|
| 31 |
+
data-testid={`sort-${o.value}`}
|
| 32 |
+
aria-pressed={value === o.value}
|
| 33 |
+
>
|
| 34 |
+
{o.label}
|
| 35 |
+
</button>
|
| 36 |
+
))}
|
| 37 |
+
</div>
|
| 38 |
+
</div>
|
| 39 |
+
);
|
| 40 |
+
}
|
frontend/src/components/search/AirportInput.tsx
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useRef, useState } from 'react';
|
| 2 |
+
import { searchAirports } from '../../api/client';
|
| 3 |
+
import type { AutocompleteResult } from '../../api/types';
|
| 4 |
+
import { useDebounce } from '../../hooks/useDebounce';
|
| 5 |
+
|
| 6 |
+
interface Props {
|
| 7 |
+
label: string;
|
| 8 |
+
value: string; // IATA code
|
| 9 |
+
displayValue: string; // "New York (JFK)"
|
| 10 |
+
onChange: (iata: string, display: string) => void;
|
| 11 |
+
placeholder?: string;
|
| 12 |
+
testId?: string;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
export default function AirportInput({ label, value, displayValue, onChange, placeholder, testId }: Props) {
|
| 16 |
+
const [query, setQuery] = useState(displayValue);
|
| 17 |
+
const [results, setResults] = useState<AutocompleteResult[]>([]);
|
| 18 |
+
const [open, setOpen] = useState(false);
|
| 19 |
+
const [focused, setFocused] = useState(false);
|
| 20 |
+
const debouncedQuery = useDebounce(query, 200);
|
| 21 |
+
const wrapperRef = useRef<HTMLDivElement>(null);
|
| 22 |
+
|
| 23 |
+
// Sync display value when parent changes it
|
| 24 |
+
useEffect(() => {
|
| 25 |
+
if (!focused) setQuery(displayValue);
|
| 26 |
+
}, [displayValue, focused]);
|
| 27 |
+
|
| 28 |
+
// Fetch autocomplete results
|
| 29 |
+
useEffect(() => {
|
| 30 |
+
if (!focused) return;
|
| 31 |
+
if (debouncedQuery.length < 1) {
|
| 32 |
+
setResults([]);
|
| 33 |
+
return;
|
| 34 |
+
}
|
| 35 |
+
let cancelled = false;
|
| 36 |
+
searchAirports(debouncedQuery).then(r => {
|
| 37 |
+
if (!cancelled) {
|
| 38 |
+
setResults(r);
|
| 39 |
+
setOpen(r.length > 0);
|
| 40 |
+
}
|
| 41 |
+
});
|
| 42 |
+
return () => { cancelled = true; };
|
| 43 |
+
}, [debouncedQuery, focused]);
|
| 44 |
+
|
| 45 |
+
// Close on click outside
|
| 46 |
+
useEffect(() => {
|
| 47 |
+
function handler(e: MouseEvent) {
|
| 48 |
+
if (wrapperRef.current && !wrapperRef.current.contains(e.target as Node)) {
|
| 49 |
+
setOpen(false);
|
| 50 |
+
setFocused(false);
|
| 51 |
+
if (!value) setQuery('');
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
document.addEventListener('mousedown', handler);
|
| 55 |
+
return () => document.removeEventListener('mousedown', handler);
|
| 56 |
+
}, [value]);
|
| 57 |
+
|
| 58 |
+
function select(r: AutocompleteResult) {
|
| 59 |
+
onChange(r.iata, `${r.city_name} (${r.iata})`);
|
| 60 |
+
setQuery(`${r.city_name} (${r.iata})`);
|
| 61 |
+
setOpen(false);
|
| 62 |
+
setFocused(false);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
return (
|
| 66 |
+
<div ref={wrapperRef} className="relative flex-1 min-w-[180px]" data-testid={testId}>
|
| 67 |
+
<label className="absolute -top-2 left-3 bg-white px-1 text-xs text-gray-500 z-10">{label}</label>
|
| 68 |
+
<input
|
| 69 |
+
type="text"
|
| 70 |
+
value={query}
|
| 71 |
+
onChange={e => { setQuery(e.target.value); setOpen(true); }}
|
| 72 |
+
onFocus={() => { setFocused(true); setQuery(''); setOpen(true); }}
|
| 73 |
+
placeholder={placeholder || 'City or airport'}
|
| 74 |
+
className="w-full rounded-md border border-gray-300 px-3 py-3 text-sm text-gray-900 placeholder-gray-400 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none"
|
| 75 |
+
aria-label={label}
|
| 76 |
+
data-testid={testId ? `${testId}-input` : undefined}
|
| 77 |
+
autoComplete="off"
|
| 78 |
+
/>
|
| 79 |
+
{open && results.length > 0 && (
|
| 80 |
+
<ul
|
| 81 |
+
className="absolute top-full left-0 right-0 z-50 mt-1 max-h-64 overflow-y-auto rounded-lg border border-gray-200 bg-white shadow-lg"
|
| 82 |
+
data-testid={testId ? `${testId}-dropdown` : undefined}
|
| 83 |
+
role="listbox"
|
| 84 |
+
>
|
| 85 |
+
{results.map(r => (
|
| 86 |
+
<li
|
| 87 |
+
key={r.iata}
|
| 88 |
+
onClick={() => select(r)}
|
| 89 |
+
className="flex cursor-pointer items-center gap-3 px-4 py-3 hover:bg-gray-50"
|
| 90 |
+
role="option"
|
| 91 |
+
data-testid={`airport-option-${r.iata}`}
|
| 92 |
+
aria-selected={r.iata === value}
|
| 93 |
+
>
|
| 94 |
+
<svg className="h-5 w-5 flex-shrink-0 text-gray-400" viewBox="0 0 24 24" fill="none">
|
| 95 |
+
<path d="M12 2C8.13 2 5 5.13 5 9c0 5.25 7 13 7 13s7-7.75 7-13c0-3.87-3.13-7-7-7zm0 9.5c-1.38 0-2.5-1.12-2.5-2.5s1.12-2.5 2.5-2.5 2.5 1.12 2.5 2.5-1.12 2.5-2.5 2.5z" fill="currentColor"/>
|
| 96 |
+
</svg>
|
| 97 |
+
<div className="flex-1 min-w-0">
|
| 98 |
+
<div className="text-sm font-medium text-gray-900 truncate">{r.city_name} ({r.iata})</div>
|
| 99 |
+
<div className="text-xs text-gray-500 truncate">{r.name}, {r.country}</div>
|
| 100 |
+
</div>
|
| 101 |
+
</li>
|
| 102 |
+
))}
|
| 103 |
+
</ul>
|
| 104 |
+
)}
|
| 105 |
+
</div>
|
| 106 |
+
);
|
| 107 |
+
}
|
frontend/src/components/search/ClassSelector.tsx
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { CabinClass } from '../../api/types';
|
| 2 |
+
import { cabinClassLabel } from '../../utils/format';
|
| 3 |
+
|
| 4 |
+
const OPTIONS: CabinClass[] = ['economy', 'premium_economy', 'business', 'first'];
|
| 5 |
+
|
| 6 |
+
interface Props {
|
| 7 |
+
value: CabinClass;
|
| 8 |
+
onChange: (v: CabinClass) => void;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
export default function ClassSelector({ value, onChange }: Props) {
|
| 12 |
+
return (
|
| 13 |
+
<div className="relative" data-testid="class-selector">
|
| 14 |
+
<select
|
| 15 |
+
value={value}
|
| 16 |
+
onChange={e => onChange(e.target.value as CabinClass)}
|
| 17 |
+
className="appearance-none rounded-md border border-gray-300 bg-white px-3 py-2 pr-8 text-sm text-gray-700 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none cursor-pointer"
|
| 18 |
+
aria-label="Cabin class"
|
| 19 |
+
data-testid="class-select"
|
| 20 |
+
>
|
| 21 |
+
{OPTIONS.map(c => (
|
| 22 |
+
<option key={c} value={c}>{cabinClassLabel(c)}</option>
|
| 23 |
+
))}
|
| 24 |
+
</select>
|
| 25 |
+
<svg className="pointer-events-none absolute right-2 top-1/2 -translate-y-1/2 h-4 w-4 text-gray-500" viewBox="0 0 20 20" fill="currentColor">
|
| 26 |
+
<path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
|
| 27 |
+
</svg>
|
| 28 |
+
</div>
|
| 29 |
+
);
|
| 30 |
+
}
|
frontend/src/components/search/DatePicker.tsx
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface Props {
|
| 2 |
+
label: string;
|
| 3 |
+
value: string; // YYYY-MM-DD
|
| 4 |
+
onChange: (v: string) => void;
|
| 5 |
+
testId?: string;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
export default function DatePicker({ label, value, onChange, testId }: Props) {
|
| 9 |
+
return (
|
| 10 |
+
<div className="relative min-w-[150px]" data-testid={testId}>
|
| 11 |
+
<label className="absolute -top-2 left-3 bg-white px-1 text-xs text-gray-500 z-10">{label}</label>
|
| 12 |
+
<input
|
| 13 |
+
type="date"
|
| 14 |
+
value={value}
|
| 15 |
+
onChange={e => onChange(e.target.value)}
|
| 16 |
+
className="w-full rounded-md border border-gray-300 px-3 py-3 text-sm text-gray-900 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none cursor-pointer"
|
| 17 |
+
aria-label={label}
|
| 18 |
+
data-testid={testId ? `${testId}-input` : undefined}
|
| 19 |
+
/>
|
| 20 |
+
</div>
|
| 21 |
+
);
|
| 22 |
+
}
|
frontend/src/components/search/PassengerSelector.tsx
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useRef, useState } from 'react';
|
| 2 |
+
import type { Passengers } from '../../api/types';
|
| 3 |
+
|
| 4 |
+
interface Props {
|
| 5 |
+
value: Passengers;
|
| 6 |
+
onChange: (v: Passengers) => void;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
export default function PassengerSelector({ value, onChange }: Props) {
|
| 10 |
+
const [open, setOpen] = useState(false);
|
| 11 |
+
const ref = useRef<HTMLDivElement>(null);
|
| 12 |
+
|
| 13 |
+
useEffect(() => {
|
| 14 |
+
function handler(e: MouseEvent) {
|
| 15 |
+
if (ref.current && !ref.current.contains(e.target as Node)) setOpen(false);
|
| 16 |
+
}
|
| 17 |
+
document.addEventListener('mousedown', handler);
|
| 18 |
+
return () => document.removeEventListener('mousedown', handler);
|
| 19 |
+
}, []);
|
| 20 |
+
|
| 21 |
+
const total = value.adults + value.children + value.infants;
|
| 22 |
+
|
| 23 |
+
function update(field: keyof Passengers, delta: number) {
|
| 24 |
+
const v = { ...value };
|
| 25 |
+
v[field] = Math.max(field === 'adults' ? 1 : 0, Math.min(9, v[field] + delta));
|
| 26 |
+
onChange(v);
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
return (
|
| 30 |
+
<div ref={ref} className="relative" data-testid="passenger-selector">
|
| 31 |
+
<button
|
| 32 |
+
onClick={() => setOpen(!open)}
|
| 33 |
+
className="rounded-md border border-gray-300 px-3 py-3 text-sm text-gray-700 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none flex items-center gap-1 cursor-pointer"
|
| 34 |
+
aria-label="Passengers"
|
| 35 |
+
data-testid="passenger-button"
|
| 36 |
+
>
|
| 37 |
+
<svg className="h-4 w-4 text-gray-500" viewBox="0 0 24 24" fill="currentColor">
|
| 38 |
+
<path d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"/>
|
| 39 |
+
</svg>
|
| 40 |
+
<span>{total}</span>
|
| 41 |
+
<svg className="h-4 w-4 text-gray-400" viewBox="0 0 20 20" fill="currentColor">
|
| 42 |
+
<path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
|
| 43 |
+
</svg>
|
| 44 |
+
</button>
|
| 45 |
+
|
| 46 |
+
{open && (
|
| 47 |
+
<div className="absolute top-full right-0 z-50 mt-1 w-64 rounded-lg border border-gray-200 bg-white p-4 shadow-lg" data-testid="passenger-dropdown">
|
| 48 |
+
{([
|
| 49 |
+
{ key: 'adults' as const, label: 'Adults', sub: '' },
|
| 50 |
+
{ key: 'children' as const, label: 'Children', sub: 'Aged 2-11' },
|
| 51 |
+
{ key: 'infants' as const, label: 'Infants', sub: 'Under 2' },
|
| 52 |
+
]).map(row => (
|
| 53 |
+
<div key={row.key} className="flex items-center justify-between py-2">
|
| 54 |
+
<div>
|
| 55 |
+
<div className="text-sm font-medium text-gray-900">{row.label}</div>
|
| 56 |
+
{row.sub && <div className="text-xs text-gray-500">{row.sub}</div>}
|
| 57 |
+
</div>
|
| 58 |
+
<div className="flex items-center gap-3">
|
| 59 |
+
<button
|
| 60 |
+
onClick={() => update(row.key, -1)}
|
| 61 |
+
disabled={value[row.key] <= (row.key === 'adults' ? 1 : 0)}
|
| 62 |
+
className="h-8 w-8 rounded-full border border-gray-300 text-gray-600 hover:bg-gray-50 disabled:opacity-30 disabled:cursor-not-allowed cursor-pointer flex items-center justify-center"
|
| 63 |
+
aria-label={`Decrease ${row.label}`}
|
| 64 |
+
data-testid={`${row.key}-decrease`}
|
| 65 |
+
>−</button>
|
| 66 |
+
<span className="w-4 text-center text-sm" data-testid={`${row.key}-count`}>{value[row.key]}</span>
|
| 67 |
+
<button
|
| 68 |
+
onClick={() => update(row.key, 1)}
|
| 69 |
+
disabled={value[row.key] >= 9}
|
| 70 |
+
className="h-8 w-8 rounded-full border border-gray-300 text-gray-600 hover:bg-gray-50 disabled:opacity-30 disabled:cursor-not-allowed cursor-pointer flex items-center justify-center"
|
| 71 |
+
aria-label={`Increase ${row.label}`}
|
| 72 |
+
data-testid={`${row.key}-increase`}
|
| 73 |
+
>+</button>
|
| 74 |
+
</div>
|
| 75 |
+
</div>
|
| 76 |
+
))}
|
| 77 |
+
<button
|
| 78 |
+
onClick={() => setOpen(false)}
|
| 79 |
+
className="mt-2 w-full rounded-md bg-[#1a73e8] py-2 text-sm text-white hover:bg-[#1765cc] cursor-pointer"
|
| 80 |
+
data-testid="passenger-done"
|
| 81 |
+
>Done</button>
|
| 82 |
+
</div>
|
| 83 |
+
)}
|
| 84 |
+
</div>
|
| 85 |
+
);
|
| 86 |
+
}
|
frontend/src/components/search/SearchForm.tsx
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState } from 'react';
|
| 2 |
+
import type { CabinClass, Passengers, TripType } from '../../api/types';
|
| 3 |
+
import { getDefaultDepartureDate, getDefaultReturnDate } from '../../utils/date';
|
| 4 |
+
import AirportInput from './AirportInput';
|
| 5 |
+
import ClassSelector from './ClassSelector';
|
| 6 |
+
import DatePicker from './DatePicker';
|
| 7 |
+
import PassengerSelector from './PassengerSelector';
|
| 8 |
+
import SwapButton from './SwapButton';
|
| 9 |
+
import TripTypeSelector from './TripTypeSelector';
|
| 10 |
+
|
| 11 |
+
export interface SearchFormData {
|
| 12 |
+
tripType: TripType;
|
| 13 |
+
origin: string;
|
| 14 |
+
originDisplay: string;
|
| 15 |
+
destination: string;
|
| 16 |
+
destinationDisplay: string;
|
| 17 |
+
departureDate: string;
|
| 18 |
+
returnDate: string;
|
| 19 |
+
passengers: Passengers;
|
| 20 |
+
cabinClass: CabinClass;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
interface Props {
|
| 24 |
+
initial?: Partial<SearchFormData>;
|
| 25 |
+
onSearch: (data: SearchFormData) => void;
|
| 26 |
+
compact?: boolean;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
export default function SearchForm({ initial, onSearch, compact }: Props) {
|
| 30 |
+
const [tripType, setTripType] = useState<TripType>(initial?.tripType || 'round_trip');
|
| 31 |
+
const [origin, setOrigin] = useState(initial?.origin || '');
|
| 32 |
+
const [originDisplay, setOriginDisplay] = useState(initial?.originDisplay || '');
|
| 33 |
+
const [destination, setDestination] = useState(initial?.destination || '');
|
| 34 |
+
const [destinationDisplay, setDestinationDisplay] = useState(initial?.destinationDisplay || '');
|
| 35 |
+
const [departureDate, setDepartureDate] = useState(initial?.departureDate || getDefaultDepartureDate());
|
| 36 |
+
const [returnDate, setReturnDate] = useState(initial?.returnDate || getDefaultReturnDate());
|
| 37 |
+
const [passengers, setPassengers] = useState<Passengers>(initial?.passengers || { adults: 1, children: 0, infants: 0 });
|
| 38 |
+
const [cabinClass, setCabinClass] = useState<CabinClass>(initial?.cabinClass || 'economy');
|
| 39 |
+
|
| 40 |
+
function handleSwap() {
|
| 41 |
+
const tmpCode = origin;
|
| 42 |
+
const tmpDisplay = originDisplay;
|
| 43 |
+
setOrigin(destination);
|
| 44 |
+
setOriginDisplay(destinationDisplay);
|
| 45 |
+
setDestination(tmpCode);
|
| 46 |
+
setDestinationDisplay(tmpDisplay);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
function handleSubmit(e: React.FormEvent) {
|
| 50 |
+
e.preventDefault();
|
| 51 |
+
if (!origin || !destination) return;
|
| 52 |
+
onSearch({
|
| 53 |
+
tripType, origin, originDisplay, destination, destinationDisplay,
|
| 54 |
+
departureDate, returnDate, passengers, cabinClass,
|
| 55 |
+
});
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
return (
|
| 59 |
+
<form onSubmit={handleSubmit} data-testid="search-form">
|
| 60 |
+
{/* Row 1: Trip type, passengers, class */}
|
| 61 |
+
<div className="mb-3 flex flex-wrap items-center gap-2">
|
| 62 |
+
<TripTypeSelector value={tripType} onChange={setTripType} />
|
| 63 |
+
<PassengerSelector value={passengers} onChange={setPassengers} />
|
| 64 |
+
<ClassSelector value={cabinClass} onChange={setCabinClass} />
|
| 65 |
+
</div>
|
| 66 |
+
|
| 67 |
+
{/* Row 2: Airport inputs + dates + search button */}
|
| 68 |
+
<div className={`flex flex-wrap items-end gap-2 ${compact ? '' : 'rounded-xl border border-gray-300 p-3'}`}>
|
| 69 |
+
<AirportInput
|
| 70 |
+
label="Where from?"
|
| 71 |
+
value={origin}
|
| 72 |
+
displayValue={originDisplay}
|
| 73 |
+
onChange={(iata, display) => { setOrigin(iata); setOriginDisplay(display); }}
|
| 74 |
+
testId="origin"
|
| 75 |
+
/>
|
| 76 |
+
|
| 77 |
+
<SwapButton onClick={handleSwap} />
|
| 78 |
+
|
| 79 |
+
<AirportInput
|
| 80 |
+
label="Where to?"
|
| 81 |
+
value={destination}
|
| 82 |
+
displayValue={destinationDisplay}
|
| 83 |
+
onChange={(iata, display) => { setDestination(iata); setDestinationDisplay(display); }}
|
| 84 |
+
testId="destination"
|
| 85 |
+
/>
|
| 86 |
+
|
| 87 |
+
<DatePicker
|
| 88 |
+
label="Departure"
|
| 89 |
+
value={departureDate}
|
| 90 |
+
onChange={setDepartureDate}
|
| 91 |
+
testId="departure-date"
|
| 92 |
+
/>
|
| 93 |
+
|
| 94 |
+
{tripType === 'round_trip' && (
|
| 95 |
+
<DatePicker
|
| 96 |
+
label="Return"
|
| 97 |
+
value={returnDate}
|
| 98 |
+
onChange={setReturnDate}
|
| 99 |
+
testId="return-date"
|
| 100 |
+
/>
|
| 101 |
+
)}
|
| 102 |
+
|
| 103 |
+
<button
|
| 104 |
+
type="submit"
|
| 105 |
+
disabled={!origin || !destination}
|
| 106 |
+
className="rounded-full bg-[#1a73e8] px-6 py-3 text-sm font-medium text-white hover:bg-[#1765cc] hover:shadow-md disabled:opacity-40 disabled:cursor-not-allowed focus:outline-none cursor-pointer"
|
| 107 |
+
data-testid="search-button"
|
| 108 |
+
>
|
| 109 |
+
Search
|
| 110 |
+
</button>
|
| 111 |
+
</div>
|
| 112 |
+
</form>
|
| 113 |
+
);
|
| 114 |
+
}
|
frontend/src/components/search/SwapButton.tsx
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface Props {
|
| 2 |
+
onClick: () => void;
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
export default function SwapButton({ onClick }: Props) {
|
| 6 |
+
return (
|
| 7 |
+
<button
|
| 8 |
+
onClick={onClick}
|
| 9 |
+
className="flex h-10 w-10 items-center justify-center rounded-full border border-gray-300 bg-white text-gray-500 hover:bg-gray-50 hover:text-gray-700 focus:outline-none self-center cursor-pointer"
|
| 10 |
+
aria-label="Swap origin and destination"
|
| 11 |
+
data-testid="swap-button"
|
| 12 |
+
>
|
| 13 |
+
<svg className="h-5 w-5" viewBox="0 0 24 24" fill="currentColor">
|
| 14 |
+
<path d="M6.99 11L3 15l3.99 4v-3H14v-2H6.99v-3zM21 9l-3.99-4v3H10v2h7.01v3L21 9z"/>
|
| 15 |
+
</svg>
|
| 16 |
+
</button>
|
| 17 |
+
);
|
| 18 |
+
}
|
frontend/src/components/search/TripTypeSelector.tsx
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { TripType } from '../../api/types';
|
| 2 |
+
|
| 3 |
+
const OPTIONS: { value: TripType; label: string }[] = [
|
| 4 |
+
{ value: 'round_trip', label: 'Round trip' },
|
| 5 |
+
{ value: 'one_way', label: 'One way' },
|
| 6 |
+
{ value: 'multi_city', label: 'Multi-city' },
|
| 7 |
+
];
|
| 8 |
+
|
| 9 |
+
interface Props {
|
| 10 |
+
value: TripType;
|
| 11 |
+
onChange: (v: TripType) => void;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export default function TripTypeSelector({ value, onChange }: Props) {
|
| 15 |
+
return (
|
| 16 |
+
<div className="relative" data-testid="trip-type-selector">
|
| 17 |
+
<select
|
| 18 |
+
value={value}
|
| 19 |
+
onChange={e => onChange(e.target.value as TripType)}
|
| 20 |
+
className="appearance-none rounded-md border border-gray-300 bg-white px-3 py-2 pr-8 text-sm text-gray-700 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none cursor-pointer"
|
| 21 |
+
aria-label="Trip type"
|
| 22 |
+
data-testid="trip-type-select"
|
| 23 |
+
>
|
| 24 |
+
{OPTIONS.map(o => (
|
| 25 |
+
<option key={o.value} value={o.value}>{o.label}</option>
|
| 26 |
+
))}
|
| 27 |
+
</select>
|
| 28 |
+
<svg className="pointer-events-none absolute right-2 top-1/2 -translate-y-1/2 h-4 w-4 text-gray-500" viewBox="0 0 20 20" fill="currentColor">
|
| 29 |
+
<path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
|
| 30 |
+
</svg>
|
| 31 |
+
</div>
|
| 32 |
+
);
|
| 33 |
+
}
|
frontend/src/components/shared/Header.tsx
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useNavigate } from 'react-router-dom';
|
| 2 |
+
|
| 3 |
+
export default function Header() {
|
| 4 |
+
const navigate = useNavigate();
|
| 5 |
+
|
| 6 |
+
return (
|
| 7 |
+
<header className="border-b border-gray-200 bg-white" data-testid="header">
|
| 8 |
+
<div className="mx-auto max-w-7xl px-4 py-3 flex items-center gap-3">
|
| 9 |
+
<button
|
| 10 |
+
onClick={() => navigate('/')}
|
| 11 |
+
className="flex items-center gap-2 text-xl font-medium text-gray-900 hover:opacity-80 cursor-pointer"
|
| 12 |
+
data-testid="logo"
|
| 13 |
+
>
|
| 14 |
+
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" className="text-[#1a73e8]">
|
| 15 |
+
<path d="M21 16v-2l-8-5V3.5c0-.83-.67-1.5-1.5-1.5S10 2.67 10 3.5V9l-8 5v2l8-2.5V19l-2 1.5V22l3.5-1 3.5 1v-1.5L13 19v-5.5l8 2.5z" fill="currentColor"/>
|
| 16 |
+
</svg>
|
| 17 |
+
Flights
|
| 18 |
+
</button>
|
| 19 |
+
</div>
|
| 20 |
+
</header>
|
| 21 |
+
);
|
| 22 |
+
}
|
frontend/src/components/shared/Loading.tsx
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export default function Loading() {
|
| 2 |
+
return (
|
| 3 |
+
<div className="flex flex-col items-center justify-center py-20" data-testid="loading">
|
| 4 |
+
<div className="h-10 w-10 animate-spin rounded-full border-4 border-gray-200 border-t-[#1a73e8]" />
|
| 5 |
+
<p className="mt-4 text-sm text-gray-500">Searching flights...</p>
|
| 6 |
+
</div>
|
| 7 |
+
);
|
| 8 |
+
}
|
frontend/src/hooks/useDebounce.ts
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useState } from 'react';
|
| 2 |
+
|
| 3 |
+
export function useDebounce<T>(value: T, delay: number): T {
|
| 4 |
+
const [debounced, setDebounced] = useState(value);
|
| 5 |
+
|
| 6 |
+
useEffect(() => {
|
| 7 |
+
const timer = setTimeout(() => setDebounced(value), delay);
|
| 8 |
+
return () => clearTimeout(timer);
|
| 9 |
+
}, [value, delay]);
|
| 10 |
+
|
| 11 |
+
return debounced;
|
| 12 |
+
}
|
frontend/src/hooks/useFlightSearch.ts
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useCallback, useState } from 'react';
|
| 2 |
+
import { searchFlights } from '../api/client';
|
| 3 |
+
import type { CabinClass, Filters, FlightOffer, Passengers, SearchRequest, SortBy, TripType } from '../api/types';
|
| 4 |
+
|
| 5 |
+
interface SearchState {
|
| 6 |
+
outboundFlights: FlightOffer[];
|
| 7 |
+
returnFlights: FlightOffer[];
|
| 8 |
+
loading: boolean;
|
| 9 |
+
error: string | null;
|
| 10 |
+
searched: boolean;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
export function useFlightSearch() {
|
| 14 |
+
const [state, setState] = useState<SearchState>({
|
| 15 |
+
outboundFlights: [],
|
| 16 |
+
returnFlights: [],
|
| 17 |
+
loading: false,
|
| 18 |
+
error: null,
|
| 19 |
+
searched: false,
|
| 20 |
+
});
|
| 21 |
+
|
| 22 |
+
const search = useCallback(async (params: {
|
| 23 |
+
tripType: TripType;
|
| 24 |
+
origin: string;
|
| 25 |
+
destination: string;
|
| 26 |
+
departureDate: string;
|
| 27 |
+
returnDate?: string;
|
| 28 |
+
passengers: Passengers;
|
| 29 |
+
cabinClass: CabinClass;
|
| 30 |
+
filters: Filters;
|
| 31 |
+
sortBy: SortBy;
|
| 32 |
+
}) => {
|
| 33 |
+
setState(s => ({ ...s, loading: true, error: null }));
|
| 34 |
+
|
| 35 |
+
const legs = [{ origin: params.origin, destination: params.destination, date: params.departureDate }];
|
| 36 |
+
if (params.tripType === 'round_trip' && params.returnDate) {
|
| 37 |
+
legs.push({ origin: params.destination, destination: params.origin, date: params.returnDate });
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
const req: SearchRequest = {
|
| 41 |
+
trip_type: params.tripType,
|
| 42 |
+
legs,
|
| 43 |
+
passengers: params.passengers,
|
| 44 |
+
cabin_class: params.cabinClass,
|
| 45 |
+
filters: params.filters,
|
| 46 |
+
sort_by: params.sortBy,
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
try {
|
| 50 |
+
const res = await searchFlights(req);
|
| 51 |
+
setState({
|
| 52 |
+
outboundFlights: res.outbound_flights,
|
| 53 |
+
returnFlights: res.return_flights,
|
| 54 |
+
loading: false,
|
| 55 |
+
error: null,
|
| 56 |
+
searched: true,
|
| 57 |
+
});
|
| 58 |
+
} catch (err) {
|
| 59 |
+
setState({
|
| 60 |
+
outboundFlights: [],
|
| 61 |
+
returnFlights: [],
|
| 62 |
+
loading: false,
|
| 63 |
+
error: err instanceof Error ? err.message : 'Search failed',
|
| 64 |
+
searched: true,
|
| 65 |
+
});
|
| 66 |
+
}
|
| 67 |
+
}, []);
|
| 68 |
+
|
| 69 |
+
return { ...state, search };
|
| 70 |
+
}
|