fyliu Claude Opus 4.6 committed on
Commit
2e50ccd
·
0 Parent(s):

Add flight booking website (Google Flights clone)

Browse files

Full-stack flight search app for computer-use agent testing:
- Backend: FastAPI with deterministic pricing engine, route finder
(direct + 1-stop + 2-stop via hub detection), timezone-aware
flight generation, airport autocomplete, calendar pricing
- Frontend: React + TypeScript + Tailwind CSS with search form,
autocomplete, results page, filters, sorting, URL-driven state
- Docker: Multi-stage build (Node frontend → Python backend)
- Data: 3,770 airports, 55,627 routes, 604 airlines from
airline_routes.json loaded in-memory (~21 MB)
- All elements have data-testid attributes for agent testing
- Same search params always produce same results (SHA-256 seeded)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .dockerignore +10 -0
  2. .gitattributes +1 -0
  3. .gitignore +8 -0
  4. CLAUDE.md +518 -0
  5. Dockerfile +26 -0
  6. JOURNAL.md +0 -0
  7. airline_routes.json +3 -0
  8. backend/__init__.py +0 -0
  9. backend/api/__init__.py +0 -0
  10. backend/api/airports.py +47 -0
  11. backend/api/calendar.py +70 -0
  12. backend/api/search.py +144 -0
  13. backend/config.py +93 -0
  14. backend/data_loader.py +164 -0
  15. backend/flight_generator.py +270 -0
  16. backend/hub_detector.py +52 -0
  17. backend/main.py +59 -0
  18. backend/models.py +145 -0
  19. backend/price_engine.py +113 -0
  20. backend/requirements.txt +3 -0
  21. backend/route_finder.py +141 -0
  22. backend/seed_utils.py +20 -0
  23. description.md +1122 -0
  24. docker-compose.yml +6 -0
  25. frontend/eslint.config.js +23 -0
  26. frontend/index.html +13 -0
  27. frontend/package-lock.json +0 -0
  28. frontend/package.json +33 -0
  29. frontend/public/vite.svg +1 -0
  30. frontend/src/App.css +42 -0
  31. frontend/src/App.tsx +16 -0
  32. frontend/src/api/client.ts +39 -0
  33. frontend/src/api/types.ts +90 -0
  34. frontend/src/assets/react.svg +1 -0
  35. frontend/src/components/results/FilterPanel.tsx +133 -0
  36. frontend/src/components/results/FlightCard.tsx +112 -0
  37. frontend/src/components/results/FlightSegment.tsx +42 -0
  38. frontend/src/components/results/NoResults.tsx +29 -0
  39. frontend/src/components/results/SortBar.tsx +40 -0
  40. frontend/src/components/search/AirportInput.tsx +107 -0
  41. frontend/src/components/search/ClassSelector.tsx +30 -0
  42. frontend/src/components/search/DatePicker.tsx +22 -0
  43. frontend/src/components/search/PassengerSelector.tsx +86 -0
  44. frontend/src/components/search/SearchForm.tsx +114 -0
  45. frontend/src/components/search/SwapButton.tsx +18 -0
  46. frontend/src/components/search/TripTypeSelector.tsx +33 -0
  47. frontend/src/components/shared/Header.tsx +22 -0
  48. frontend/src/components/shared/Loading.tsx +8 -0
  49. frontend/src/hooks/useDebounce.ts +12 -0
  50. frontend/src/hooks/useFlightSearch.ts +70 -0
.dockerignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ node_modules
2
+ frontend/node_modules
3
+ frontend/dist
4
+ .venv
5
+ .venv_*
6
+ .git
7
+ __pycache__
8
+ *.pyc
9
+ .uv_cache
10
+ .uv_pythons
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ airline_routes.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ node_modules/
2
+ frontend/dist/
3
+ .venv/
4
+ .venv_*/
5
+ __pycache__/
6
+ *.pyc
7
+ .uv_cache/
8
+ .uv_pythons/
CLAUDE.md ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md — Rules for AI Assistants (ECMoE Project)
2
+
3
+ ## MANDATORY FIRST STEPS
4
+
5
+ **Before taking ANY action on a task, you MUST:**
6
+
7
+ 1. Tell the user you have read CLAUDE.md and how you'll follow the THREE RULES
8
+ 2. **Actually read these files** (not optional):
9
+ - **README.md** — Directory structure, setup, how to run experiments
10
+ - **JOURNAL.md** — Recent bugs, what's broken/fixed, latest results
11
+ - **description.md** — Detailed method descriptions, design choices, hyperparameters
12
+
13
+ **Do NOT skip this to "get to work faster."** Skipping causes you to use wrong directories, miss known issues, and waste time on already-solved problems.
14
+
15
+ ---
16
+
17
+ ## THE THREE RULES
18
+
19
+ ### 1. EDIT, NEVER REWRITE
20
+ - **ALWAYS edit existing code, NEVER rewrite from scratch**
21
+ - Find the exact file/function, make surgical changes with Edit tool
22
+ - If you're about to write 50+ lines of new code doing something similar to existing code, STOP
23
+ - Reuse existing classes: `Compressor`, `Decompressor`, `StaleDecompressor`, `train_compressor`, etc.
24
+
25
+ ### 2. VALIDATE DATA BEFORE PLOTTING
26
+ - Always load results from JSON files, never hardcode values
27
+ - If a number looks different than expected, investigate before proceeding
28
+ - Check `results/summary/all_results_summary.json` for the canonical results
29
+
30
+ ### 3. COMMIT AND DOCUMENT IMMEDIATELY
31
+ - `git commit` after every fix (no remote configured — push when available)
32
+ - Update `JOURNAL.md` right after committing
33
+ - Don't batch changes — commit as you go
34
+
35
+ ---
36
+
37
+ ## MINDSET: NO SHORTCUTS
38
+
39
+ - Academic rigor means doing things RIGHT, not just doing things FAST
40
+ - Be skeptical of your own first approach — question whether it could be better
41
+ - Don't simplify the requirement — solve the actual problem
42
+
43
+ ---
44
+
45
+ ## Communication
46
+
47
+ **When showing results or finishing tasks:**
48
+ - ALWAYS provide the **full absolute path** to any files created or modified
49
+ - Example: "View the result at: `/project/6004852/lfy/ECMoE/results/summary/ppl_vs_ratio_all.png`"
50
+
51
+ ---
52
+
53
+ ## Project-Specific Rules
54
+
55
+ ### Environment Setup (Compute Canada)
56
+
57
+ ```bash
58
+ # Modules MUST be loaded BEFORE activating venv
59
+ module load cuda/12.6 arrow/22.0.0
60
+ source .venv/bin/activate
61
+
62
+ # HuggingFace cache goes to persistent project dir (home quota is small)
63
+ export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
64
+ ```
65
+
66
+ ### Directory Structure
67
+
68
+ ```
69
+ src/ # Python source code
70
+ scripts/ # Bash wrappers for each experiment
71
+ results/ # ALL experiment outputs (gitignored)
72
+ 01_distribution/ # Task 1: distribution analysis
73
+ 02_quantization/ # Task 2: quantization baseline
74
+ 03_neural_compressor/ # Task 3: shared neural compressor
75
+ 03b_perlayer_compressor/ # Task 3b: per-layer neural compressor
76
+ 04a_stale_compressed/ # Task 4a: stale-conditioned (compressed stale)
77
+ 04b_stale_uncompressed/ # Task 4b: stale-conditioned (uncompressed stale)
78
+ 05a_e2e_perlayer/ # Task 5a: e2e per-layer compressor (no stale)
79
+ 05b_e2e_stale/ # Task 5b: e2e stale-conditioned compressor
80
+ 05c_e2e_baseline/ # Task 5c: baseline (no compression, same pipeline)
81
+ 05c_megatron_e2e_baseline/ # Task 5c: baseline (Megatron variant)
82
+ 06a_megatron_e2e_pretrained_perlayer/ # Task 6a: e2e with 3b init (Megatron)
83
+ 06b_megatron_e2e_pretrained_stale/ # Task 6b: e2e with 4b init (Megatron)
84
+ 07a_megatron_e2e_split_perlayer/ # Task 7a: split-mode e2e (router=original)
85
+ 07b_megatron_e2e_split_stale/ # Task 7b: split-mode e2e + stale
86
+ 08_ep_compression/ # Task 8: EP compression eval (uses 7a/7b weights)
87
+ summary/ # Cross-method comparison plots and tables
88
+ data/hidden_states/ # Cached MoE hidden states (gitignored, ~37 GB in bfloat16)
89
+ ```
90
+
91
+ ### Key Code Architecture
92
+
93
+ - **`src/model_utils.py`** — Central library: model loading, MoE detection, hidden state
94
+ collection, ALL perplexity evaluation functions (baseline, shared, per-layer, stale)
95
+ - **`src/metrics.py`** — Reconstruction metrics: MSE, cosine similarity, relative error, SNR
96
+ - **`src/run_neural_compressor.py`** — Defines `Compressor`, `Decompressor`, `train_compressor()`.
97
+ Other scripts import from here — never duplicate these classes
98
+ - **`src/run_stale_compressor.py`** — Defines `StaleDecompressor`, `train_stale_compressor()`
99
+ - **`src/run_e2e_compressor.py`** — End-to-end training of per-layer compressors via LM loss.
100
+ Defines `E2ECompressorManager`, `SFTDataset`. Uses Dolci-Instruct-SFT with SFT mode
101
+ (response-only training). `_tokenize_sft_sample()` in `model_utils.py` handles the
102
+ response-only label masking.
103
+ - **`src/vllm_ep_compression.py`** — EP-aware compress/decompress registration for vLLM.
104
+ Sets `_ecmoe_compress_fn` / `_ecmoe_decompress_fn` on FusedMoE instances via
105
+ `apply_model()`. Supports per-layer and stale-conditioned methods. Requires patched
106
+ vLLM (`.venv_vllm_exp`).
107
+ - **`src/run_ep_compression_eval.py`** — Task 8 entry point: evaluates EP compression
108
+ with actual dispatch/combine in vLLM. Two modes: `simulation` (single-GPU) and `ep`
109
+ (multi-GPU with `enable_expert_parallel=True`). Uses Task 7a/7b weights.
110
+ - **`src/visualize_all_results.py`** — Generates all cross-method comparison plots and tables
111
+ - **`src/downstream_eval.py`** — Shared utility for downstream task evaluation via lm-eval-harness.
112
+ Provides hook registration functions (`register_quantization_hooks`, `register_perlayer_hooks`,
113
+ `register_stale_hooks`, `register_e2e_hooks`), `run_lm_eval()` wrapper, and result saving.
114
+ Imported by each task script when `--downstream-tasks` is specified.
115
+ Also provides vLLM backend support via apply_model pattern: `create_vllm_backend()`,
116
+ `register_perlayer_hooks_vllm()`, `register_stale_hooks_vllm()`,
117
+ `register_quantization_hooks_vllm()`, `remove_hooks_vllm()`.
118
+ Split (router-uncompressed) mode: `register_perlayer_hooks_split()`,
119
+ `register_stale_hooks_split()` for HF, and `register_perlayer_hooks_split_vllm()`,
120
+ `register_stale_hooks_split_vllm()` for vLLM. In split mode, the router sees original
121
+ hidden states while experts see decompressed — more realistic EP simulation.
122
+ - **`src/run_all_downstream.py`** — Standalone downstream evaluator. Loads model once,
123
+ evaluates all methods sequentially. Supports `--backend hf/vllm` and
124
+ `--router-mode compressed/uncompressed`.
125
+
126
+ ### Known Issues / Gotchas
127
+
128
+ **Layer sorting:** Always use `sorted(keys, key=layer_index)` from `model_utils`. Lexicographic
129
+ sorting puts layer 10 before layer 2 (`model.layers.10` < `model.layers.2`).
130
+
131
+ **Dtype mismatch:** Dequantized tensors and neural compressor outputs must match the model's
132
+ activation dtype (bfloat16). Always cast: `.to(x.dtype).to(x.device)`.
133
+
134
+ **What went wrong (2026-02-11):** `absmax_dequantize` returned float32 but model expected
135
+ bfloat16, causing `RuntimeError` during perplexity eval. Fix: explicit `.to(scale.dtype)` cast.
136
+
137
+ **What went wrong (2026-02-11):** When asked to remove quantization for Tasks 1–4, the agent
138
+ implemented the change (default `load_in_4bit=False`, `device="auto"`) without the user having
139
+ specified this as a hyperparameter. The model loading precision (BF16 vs 4-bit NF4) is a key
140
+ experimental parameter — changing it retroactively means old results are no longer reproducible
141
+ with default settings. **Lesson:** Treat model loading precision as a hyperparameter. Do NOT
142
+ change defaults that affect reproducibility without explicit user instruction. When the user says
143
+ "remove quantization", ASK whether they want it as a new default or as a CLI override.
144
+
145
+ **Response-only hidden state collection:** `collect_hidden_states()` defaults to
146
+ `response_only=True` — only assistant-response tokens are captured (labels != -100).
147
+ This ensures offline compressor training (Tasks 2–4) trains on the same distribution
148
+ that PPL evaluation measures. Use `--no-response-only` in `run_distribution.py` for
149
+ legacy all-token collection. Metadata records `"response_only": true/false`.
150
+
151
+ **Legacy Megatron script deleted:** `src/run_megatron_e2e_compressor.py` was removed because
152
+ it used `PackedTokenDataset` + `labels=input_ids` (standard LM, not SFT response-only),
153
+ did not use `get_split_indices()`, and misreported effective batch size with DP > 1.
154
+ Always use `src/megatron_e2e/train.py` for Megatron-based training.
155
+
156
+ **Large data files:** Hidden states for 100K tokens are ~18.5 GB per file in bfloat16
157
+ (dispatch + gather = ~37 GB). These are gitignored. Never try to `git add` them.
158
+
159
+ **Model VRAM:** Model is loaded in full BF16 (~60 GB). Tasks 1–4 use single GPU
160
+ (`device="cuda:0"`) — the model fits on one H100 80 GB with headroom for inference.
161
+ Task 5 uses multi-GPU (`device_map="auto"`) because backprop needs extra VRAM.
162
+ 4-bit NF4 loading (~15 GB) is available via `--load-in-4bit` but is NOT the default.
163
+
164
+ **device="auto" vs tensor ops:** When `device="auto"` is used for model loading (Task 5),
165
+ `"auto"` is NOT a valid torch device for tensor operations. Scripts that do `.to(device)` or
166
+ `train_compressor(device=...)` must use `compute_device` (resolved to `"cuda:0"` when
167
+ `device="auto"`). Only `load_model_and_tokenizer()` accepts `"auto"` directly.
168
+ Tasks 1–4 default to `device="cuda:0"` so this is only relevant for Task 5.
169
+
170
+ **Hook device safety (2026-02-17):** With `device_map="auto"`, model layers may reside on
171
+ different GPUs. PPL evaluation hooks in `model_utils.py` now explicitly call `.to(x.device)`
172
+ on compressor/decompressor outputs before returning them to the model. This is a no-op when
173
+ compressor and layer are on the same device but prevents cross-device errors when they differ.
174
+
175
+ ### vLLM Environment (downstream evaluation)
176
+
177
+ **vLLM backend:** `src/downstream_eval.py` + `src/run_all_downstream.py` — vLLM 0.8.4+
178
+ for downstream task evaluation with compression hooks.
179
+
180
+ ```bash
181
+ # Separate venv from HF-based experiments — CUDA 12.6
182
+ module load cuda/12.6 arrow/22.0.0
183
+ source .venv_vllm/bin/activate
184
+ export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
185
+
186
+ # Setup (first time only):
187
+ bash scripts/vllm_setup_env.sh
188
+ ```
189
+
190
+ **Known issues / gotchas (vLLM):**
191
+ - **vLLM V1 engine (>= 0.15):** The model runs in a **separate subprocess** (EngineCore).
192
+ You CANNOT access the model directly from the main process. The old path
193
+ `llm_engine.model_executor.driver_worker.model_runner.model` does NOT work.
194
+ Instead, use `vllm.LLM.apply_model(func)` to send functions to the worker process.
195
+ Functions are serialized via cloudpickle — they must be self-contained (include their
196
+ own imports and class definitions). Requires `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.
197
+ `create_vllm_backend()` sets this automatically.
198
+ - **enforce_eager=True required:** vLLM's CUDA graph capture prevents PyTorch hooks
199
+ from being called. Always use `enforce_eager=True` when registering compression hooks.
200
+ `create_vllm_backend()` sets this automatically.
201
+ - **Hook registration pattern:** All vLLM hook functions use the apply_model pattern:
202
+ `_vllm_register_perlayer()` returns a closure → `vllm_llm.apply_model(closure)`.
203
+ The closure runs inside the worker, loads weights, creates compressor modules,
204
+ and registers PyTorch pre-hooks. Cleanup via `_vllm_remove_hooks()` → `remove_hooks_vllm()`.
205
+ - **Layer name mapping:** vLLM may use different module paths than HF. `_map_layer_name()`
206
+ maps by numeric layer index, which is robust to naming differences.
207
+ - **Two router modes (--router-mode):**
208
+ - `compressed` (default): Pre-hook compress→decompress. Router AND experts see
209
+ decompressed. Conservative lower bound — same as the original PPL evaluation hooks.
210
+ - `uncompressed`: Split forward — router sees ORIGINAL input, experts see decompressed.
211
+ More realistic EP simulation where router runs on source GPU with original data.
212
+ Both modes work for HF and vLLM backends.
213
+ - **No multi-device placement:** The plan called for `compressor_device` (attention GPU)
214
+ vs `decompressor_devices` (expert GPUs) to simulate the actual communication topology.
215
+ Current implementation puts both compressor and decompressor on the same device. This
216
+ doesn't affect quality measurement (the math is device-independent) but doesn't
217
+ demonstrate the real communication pattern or measure cross-device overhead.
218
+ - **No shared expert handling:** Split mode omits `shared_expert` /
219
+ `shared_expert_gate` logic. Qwen3-30B-A3B doesn't use shared experts so this is
220
+ correct for the current model, but reduces generality.
221
+ - **No separate E2E hooks for vLLM:** E2E and offline weights have identical format.
222
+ `register_perlayer_hooks_vllm()` works for 3b + 5a + 6a weights.
223
+ `register_stale_hooks_vllm()` works for 4a/4b + 5b + 6b weights.
224
+ - **TP > 1 with vLLM:** When using tensor parallelism, each rank has a partial model.
225
+ Hook registration should still work (hooks are on the full module), but compressor
226
+ modules stay on one device. Tested with TP=1 by default.
227
+
228
+ **vLLM-specific directories:**
229
+ ```
230
+ .venv_vllm/ # Separate virtual environment (gitignored)
231
+ ```
232
+
233
+ ### vLLM EP Compression Environment (Task 8)
234
+
235
+ **EP compression:** `src/vllm_ep_compression.py` — Sets compress/decompress functions
236
+ on FusedMoE instances. Patched `forward_impl()` calls compress BEFORE dispatch and
237
+ decompress AFTER, achieving real communication reduction.
238
+
239
+ ```bash
240
+ # Separate venv with patched vLLM 0.15.1 — CUDA 12.6
241
+ module load cuda/12.6 arrow/22.0.0
242
+ source .venv_vllm_exp/bin/activate
243
+ export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
244
+
245
+ # Setup (first time only):
246
+ bash scripts/vllm_exp_setup_env.sh
247
+ ```
248
+
249
+ **Key differences from .venv_vllm:**
250
+ - vLLM 0.15.1 pinned (for patch compatibility)
251
+ - `FusedMoE.forward_impl()` patched with 3 insertion points (~12 lines)
252
+ - Uses `_ecmoe_compress_fn` / `_ecmoe_decompress_fn` attributes (not PyTorch hooks)
253
+ - Supports `enable_expert_parallel=True` for actual EP dispatch
254
+
255
+ **Known issues / gotchas (EP compression):**
256
+ - **allgather_reducescatter backend:** vLLM's default `all2all_backend`. After dispatch,
257
+ every rank has ALL tokens. Stale cache approach works because token ordering is
258
+ consistent across layers.
259
+ - **Router unaffected:** `router_logits` are computed at `Qwen3MoeSparseMoeBlock.forward()`
260
+ BEFORE `FusedMoE.forward_impl()`, so compression never affects routing decisions.
261
+ - **Stale piggybacking:** Reference layers concatenate `cat(compressed, stale)` before
262
+ dispatch. After dispatch, decompress_fn splits and caches stale globally. Non-reference
263
+ layers dispatch only compressed (max compression), retrieve cached stale for decompression.
264
+
265
+ **vLLM EP compression directories:**
266
+ ```
267
+ .venv_vllm_exp/ # Patched vLLM environment (gitignored)
268
+ results/08_ep_compression/ # EP eval results
269
+ ```
270
+
271
+ ### Megatron-LM Environment (Task 5 Megatron variant)
272
+
273
+ **Megatron implementation:** `src/megatron_e2e/` package — EP-first, CUDA 12.9, Megatron Bridge.
274
+ (Legacy `src/run_megatron_e2e_compressor.py` was deleted due to SFT/split/batch bugs.)
275
+
276
+ ```bash
277
+ # Separate venv from HF-based experiments — CUDA 12.9 required
278
+ module load cuda/12.9 nccl arrow/22.0.0
279
+ source .venv_megatron/bin/activate
280
+ export HF_HOME=/home/lfy/projects/rrg-bengioy-ad/lfy/ECMoE/.cache/huggingface
281
+
282
+ # Setup (first time only):
283
+ bash scripts/megatron_setup_env.sh
284
+ ```
285
+
286
+ **Key differences from HF environment:**
287
+ - Uses `megatron-core` >=0.15.0 for model parallelism (EP, TP, DP, PP)
288
+ - Requires Transformer Engine (for Megatron Bridge and fused kernels)
289
+ - Uses `megatron-bridge` >=0.2.0 for HF→Megatron weight conversion
290
+ - Default parallelism: EP=4, TP=1, PP=1 (expert parallelism, not tensor)
291
+ - Launch via `torchrun`, not `python`
292
+
293
+ **Megatron-specific directories:**
294
+ ```
295
+ src/megatron_e2e/ # Package-based implementation (recommended)
296
+ .venv_megatron/ # Separate virtual environment (gitignored)
297
+ .uv_cache/ # uv cache on project disk (gitignored)
298
+ .uv_pythons/ # uv Python installs (gitignored)
299
+ third_party/ # Apex, etc. (gitignored, legacy only)
300
+ data/megatron_dolci/ # Preprocessed binary dataset (gitignored)
301
+ ```
302
+
303
+ **Known issues / gotchas (Megatron):**
304
+ - **CUDA version:** Megatron Bridge requires CUDA >= 12.8. Use `cuda/12.9` module
305
+ on Compute Canada, NOT `cuda/12.6`.
306
+ - **EP vs TP:** Default is EP=4 (expert parallelism). With EP, each GPU holds 32/128
307
+ experts per layer. TP=4 is the legacy approach and splits attention heads across GPUs.
308
+ - **Megatron layer names** differ from HF: `decoder.layers.N.mlp` vs `model.layers.N.mlp`.
309
+ `_megatron_to_hf_layer_name()` in `compressor_manager.py` handles conversion.
310
+ - Compressor weights are replicated across all ranks (not sharded), since they
311
+ are tiny (~200M total). Saved from rank 0 only.
312
+ - With EP>1, compressor is on source GPU (attention side), decompressor on
313
+ destination GPU (expert side) — different devices.
314
+ - `MegatronModelWrapper` bridges Megatron's forward interface to HF-style
315
+ `SimpleNamespace(loss=..., logits=...)`. Uses `vocab_parallel_cross_entropy`
316
+ for correct loss with TP > 1. SFT labels (-100) are clamped to 0 before
317
+ calling `vocab_parallel_cross_entropy`, and loss is masked via
318
+ `(per_token_loss * loss_mask).sum() / num_valid`.
319
+ - DistributedSampler must use DP rank/size (via `get_dp_info()`), NOT global
320
+ world size. All ranks in a TP group must see the SAME data.
321
+ - Saved weights use HF layer names (`model.layers.N.mlp`) for compatibility
322
+ with HF `E2ECompressorManager.load_weights()`.
323
+ - **Model loading:** `train.py` tries AutoBridge → MegatronBridge → manual fallback
324
+ for HF→Megatron conversion. If Bridge is not installed, falls back to manual
325
+ weight conversion using `load_megatron_qwen3()` from legacy code.
326
+ - **Train loss DP reduction (2026-02-17):** `train.py` now all-reduces step-level and
327
+ epoch-level train loss across DP ranks before logging. Previously, only rank 0's local
328
+ shard loss was logged, which was inaccurate with DP > 1. Wandb `train/loss` and
329
+ `train/epoch_loss` now reflect the true DP-averaged loss.
330
+
331
+ ### Running Experiments
332
+
333
+ Task 1 must run first (caches hidden states for Tasks 2–4). Task 5 is independent.
334
+ Tasks 1–4 use 1 GPU each; Task 5a/5b use 4 GPUs each.
335
+
336
+ **Data selection:** All tasks use seed=42 for reproducible 80/10/10 train/val/test
337
+ split of dataset rows. Tasks 1–4 draw from TRAIN split, PPL evaluation from TEST
338
+ split. No data leakage between splits.
339
+
340
+ **Task 5 config (HF):** batch_size=2, grad_accum=8 (effective=16), max_sequences=500K,
341
+ max_length=2048, val_interval=2500 steps, val_batch_size=8, SFT mode
342
+ (response-only training), wandb enabled by default.
343
+
344
+ **Task 5/6 config (Megatron):** Same as HF except max_sequences=100K,
345
+ val_interval=1000 steps. Task 6 uses same Megatron config with `--init-weights-dir`.
346
+
347
+ Tail micro-batches (when `len(dataloader) % grad_accum != 0`) are handled by rescaling
348
+ accumulated gradients and performing the optimizer step.
349
+
350
+ **Two evaluation stages:** Training-time val loss uses the VAL split (50K seqs,
351
+ batch_size=8, every 2500 steps) for checkpoint selection and wandb monitoring.
352
+ Final PPL evaluation uses the TEST split (50K seqs, batch_size=1, in
353
+ `model_utils.py`) for reported results. Different code paths — `--val-batch-size`
354
+ only affects training-time eval.
355
+
356
+ **SFT data loading:** All E2E training (Task 5) and perplexity evaluation now use
357
+ SFT mode: each sample is one conversation, tokenized independently. Labels are
358
+ -100 for non-assistant tokens (system, user, template markup) and actual token
359
+ IDs for assistant responses. Loss and perplexity are computed on response tokens
360
+ only. Data is loaded by sampling N sequences from the dataset (not packing tokens).
361
+ `_tokenize_sft_sample()` in `model_utils.py` handles the tokenization.
362
+
363
+ ```bash
364
+ # Phase 1: Megatron 5a + 5b in parallel (8 GPUs)
365
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh none &
366
+ CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/05_megatron_e2e.sh uncompressed &
367
+ wait
368
+
369
+ # Phase 2: Task 1 (re-cache with seed=42)
370
+ CUDA_VISIBLE_DEVICES=0 bash scripts/01_analyze_distribution.sh
371
+
372
+ # Phase 3: Tasks 2-4 + HF 5a (parallel)
373
+ CUDA_VISIBLE_DEVICES=0 bash scripts/02_run_quantization.sh &
374
+ CUDA_VISIBLE_DEVICES=1 bash scripts/03_run_neural_compressor.sh &
375
+ CUDA_VISIBLE_DEVICES=2 bash scripts/03b_run_perlayer_compressor.sh &
376
+ CUDA_VISIBLE_DEVICES=3 bash scripts/04_run_stale_compressor.sh compressed &
377
+ CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/05_run_e2e_compressor.sh none &
378
+ wait
379
+
380
+ # Phase 4: Task 4b + HF 5b (parallel)
381
+ CUDA_VISIBLE_DEVICES=0 bash scripts/04_run_stale_compressor.sh uncompressed &
382
+ CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/05_run_e2e_compressor.sh uncompressed &
383
+ wait
384
+
385
+ # Megatron-based E2E training (alternative to HF Task 5):
386
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh none # 5a
387
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh uncompressed # 5b
388
+
389
+ # Task 5c: Baseline evaluation (no compression, same pipeline):
390
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_run_e2e_compressor.sh baseline # HF
391
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/05_megatron_e2e.sh baseline # Megatron
392
+
393
+ # Task 6a/6b: E2E with pretrained init (requires Task 3b/4b weights):
394
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/06_megatron_e2e_pretrained.sh none & # 6a (init from 3b)
395
+ CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/06_megatron_e2e_pretrained.sh uncompressed & # 6b (init from 4b)
396
+ wait
397
+
398
+ # Task 7a/7b: Split-mode E2E (router sees original, experts see decompressed):
399
+ CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/07_megatron_e2e_split.sh none & # 7a (init from 3b)
400
+ CUDA_VISIBLE_DEVICES=4,5,6,7 bash scripts/07_megatron_e2e_split.sh uncompressed & # 7b (init from 4b)
401
+ wait
402
+ ```
403
+
404
+ ### Downstream Task Evaluation (lm-eval-harness)
405
+
406
+ Downstream eval is triggered by setting `DOWNSTREAM_TASKS` before running any script.
407
+ It runs **after** the existing PPL evaluation step, using `lm-eval-harness` with the
408
+ same compression hooks active. Results saved to `downstream_results.json` in each
409
+ task's output directory.
410
+
411
+ ```bash
412
+ # Run Task 2 + PPL eval + downstream eval:
413
+ DOWNSTREAM_TASKS="gsm8k_cot" bash scripts/02_run_quantization.sh
414
+
415
+ # Run Task 5a + PPL eval + downstream eval:
416
+ DOWNSTREAM_TASKS="gsm8k_cot" bash scripts/05_run_e2e_compressor.sh none
417
+
418
+ # Eval-only mode + downstream:
419
+ DOWNSTREAM_TASKS="gsm8k_cot" python src/run_e2e_compressor.py \
420
+ --skip-training --output-dir results/05a_e2e_perlayer --stale-mode none
421
+
422
+ # Smoke test with 10 examples:
423
+ DOWNSTREAM_TASKS="gsm8k_cot" DOWNSTREAM_LIMIT=10 bash scripts/05_run_e2e_compressor.sh none
424
+ ```
425
+
426
+ **Key code:** `src/downstream_eval.py` provides `register_*_hooks()` for each method,
427
+ `run_lm_eval()` wrapper, and `save_downstream_results()`. Each task script imports from
428
+ it when `--downstream-tasks` is specified. GSM8K variant: `gsm8k_cot` (8-shot CoT).
429
+
430
+ **vLLM backend:** Use `--backend vllm` (or `DOWNSTREAM_BACKEND=vllm`) for vLLM-based
431
+ downstream evaluation. Two router modes (`--router-mode compressed/uncompressed`):
432
+
433
+ ```bash
434
+ # Standalone vLLM eval (all methods, default router=compressed):
435
+ source .venv_vllm/bin/activate
436
+ python src/run_all_downstream.py --backend vllm --tasks gsm8k_cot
437
+
438
+ # Router-uncompressed mode (split: router sees original, experts see decompressed):
439
+ python src/run_all_downstream.py --backend vllm --router-mode uncompressed --method e2e_perlayer --tasks gsm8k_cot
440
+
441
+ # With tensor parallelism:
442
+ python src/run_all_downstream.py --backend vllm --tensor-parallel-size 4 --tasks gsm8k_cot
443
+
444
+ # Via task scripts (HF model, vLLM downstream):
445
+ DOWNSTREAM_TASKS="gsm8k_cot" DOWNSTREAM_BACKEND=vllm bash scripts/05_run_e2e_compressor.sh none
446
+ ```
447
+
448
+ ### Visualization
449
+
450
+ Regenerate all summary plots and tables:
451
+ ```bash
452
+ source .venv/bin/activate
453
+ python src/visualize_all_results.py
454
+ ```
455
+
456
+ Outputs to `results/summary/`:
457
+ - `ppl_vs_ratio_all.png` — PPL vs compression ratio (log-log)
458
+ - `reconstruction_vs_ratio_all.png` — MSE and CosSim vs ratio
459
+ - `ppl_bar_practical.png` — Bar chart at 2x and 4x
460
+ - `all_results_summary.json` — Machine-readable summary
461
+ - `param_count_table.{csv,md,json}` — Parameter counts for all methods
462
+
463
+ ---
464
+
465
+ ## Code Changes
466
+
467
+ **Before changing any code:**
468
+ 1. FIND the exact file that produces the current output
469
+ 2. READ and understand it
470
+ 3. EDIT only the specific lines needed (use Edit tool)
471
+ 4. TEST that output matches except for your intended change
472
+
473
+ **Adding new compression methods:**
474
+ - Reuse `Compressor`, `Decompressor` from `run_neural_compressor.py`
475
+ - Reuse `train_compressor()` for standard autoencoder training
476
+ - Add new perplexity evaluation functions to `model_utils.py`
477
+ - Follow the same JSON output format as existing experiments
478
+ - Update `visualize_all_results.py` to include the new method
479
+
480
+ ---
481
+
482
+ ## NEVER GUESS SILENTLY
483
+
484
+ **When you encounter ambiguity:**
485
+ 1. **STOP** — Do not make an arbitrary choice
486
+ 2. **ASK** — Present the options to the user
487
+ 3. **FLAG** — Note the documentation gap
488
+ 4. **FIX** — Update README.md or CLAUDE.md
489
+
490
+ ---
491
+
492
+ ## Version Control
493
+
494
+ - Commit after EVERY fix (don't wait)
495
+ - Check `git status` and file sizes before committing (no files >100MB)
496
+ - Update JOURNAL.md immediately after committing
497
+ - No git remote is currently configured — commits are local only
498
+
499
+ ---
500
+
501
+ ## Investigation
502
+
503
+ **When something seems wrong:**
504
+ 1. STOP — don't patch the visible symptom
505
+ 2. ASK WHY — trace back to data generation
506
+ 3. VERIFY — test hypotheses with minimal examples
507
+ 4. FIX ROOT — fix the source, not downstream
508
+
509
+ ---
510
+
511
+ ## Meta-Rule: Continuous Improvement
512
+
513
+ **When a preventable issue occurs:**
514
+ 1. Identify the root cause
515
+ 2. Add a "What went wrong" example to this file
516
+ 3. Commit the improvement
517
+
518
+ This file should evolve based on lessons learned.
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Stage 1: Build frontend
FROM node:22-alpine AS frontend-build
WORKDIR /app/frontend
# Copy manifests first so the npm-install layer is cached across source edits
COPY frontend/package.json frontend/package-lock.json ./
RUN npm ci
COPY frontend/ ./
RUN npm run build

# Stage 2: Run backend + serve frontend
FROM python:3.12-slim
WORKDIR /app

# Install Python deps
COPY backend/requirements.txt ./backend/
RUN pip install --no-cache-dir -r backend/requirements.txt

# Copy backend
COPY backend/ ./backend/
# Route/airport dataset; loaded in-memory at startup by backend.data_loader
COPY airline_routes.json ./

# Copy built frontend
# backend.main serves frontend/dist as the SPA when the directory exists
COPY --from=frontend-build /app/frontend/dist ./frontend/dist

EXPOSE 8080

CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8080"]
JOURNAL.md ADDED
The diff for this file is too large to render. See raw diff
 
airline_routes.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e8d97548f626230927e3e38b8ba9712710612ef90292e9a0696698d20b3bac3
3
+ size 21798276
backend/__init__.py ADDED
File without changes
backend/api/__init__.py ADDED
File without changes
backend/api/airports.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Airport autocomplete and info endpoints."""
2
+
3
+ from fastapi import APIRouter, HTTPException, Query
4
+
5
+ from ..data_loader import get_route_graph
6
+ from ..models import AirportInfo, AutocompleteResult
7
+
8
+ router = APIRouter(prefix="/api/airports", tags=["airports"])
9
+
10
+
11
@router.get("/autocomplete", response_model=list[AutocompleteResult])
async def autocomplete(q: str = Query(..., min_length=1, max_length=50)):
    """Prefix-search airports; returns up to 10 matches ranked by the graph."""
    matches = get_route_graph().search_airports(q, limit=10)
    return [
        AutocompleteResult(
            iata=airport.iata,
            name=airport.name,
            city_name=airport.city_name,
            country=airport.country,
            display_name=airport.display_name,
            hub_score=airport.hub_score,
        )
        for airport in matches
    ]
26
+
27
+
28
@router.get("/{iata}", response_model=AirportInfo)
async def get_airport(iata: str):
    """Full details for one airport, looked up by case-insensitive IATA code."""
    graph = get_route_graph()
    code = iata.upper()
    airport = graph.airports.get(code)
    if airport is None:
        raise HTTPException(status_code=404, detail=f"Airport {code} not found")
    return AirportInfo(
        iata=airport.iata,
        name=airport.name,
        city_name=airport.city_name,
        country=airport.country,
        country_code=airport.country_code,
        continent=airport.continent,
        latitude=airport.latitude,
        longitude=airport.longitude,
        timezone=airport.timezone,
        hub_score=airport.hub_score,
        route_count=len(airport.routes),
    )
backend/api/calendar.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calendar pricing endpoint — cheapest price per day for a month."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import calendar
6
+ from datetime import date
7
+
8
+ from fastapi import APIRouter, HTTPException, Query
9
+
10
+ from ..data_loader import get_route_graph
11
+ from ..models import CalendarDay, CalendarResponse, CabinClass
12
+ from ..price_engine import compute_calendar_price
13
+ from ..seed_utils import seeded_random
14
+
15
+ router = APIRouter(prefix="/api", tags=["calendar"])
16
+
17
+
18
@router.get("/calendar", response_model=CalendarResponse)
async def get_calendar(
    origin: str = Query(..., min_length=3, max_length=3),
    destination: str = Query(..., min_length=3, max_length=3),
    year: int = Query(..., ge=2025, le=2028),
    month: int = Query(..., ge=1, le=12),
    cabin_class: CabinClass = Query(CabinClass.economy),
):
    """Cheapest deterministic fare for every day of the requested month."""
    graph = get_route_graph()
    origin = origin.upper()
    destination = destination.upper()

    if origin not in graph.airports:
        raise HTTPException(status_code=404, detail=f"Airport {origin} not found")
    if destination not in graph.airports:
        raise HTTPException(status_code=404, detail=f"Airport {destination} not found")

    route = graph.get_direct_route(origin, destination)
    if route is not None:
        distance = route.distance_km
        num_carriers = len(route.carriers)
    else:
        # No direct route: estimate a distance so pricing is still possible
        from ..route_finder import _estimate_distance
        distance = _estimate_distance(graph, origin, destination)
        if distance is None:
            raise HTTPException(status_code=404, detail="No route found")
        num_carriers = 2  # default estimate

    dest_airport = graph.airports[destination]
    _, num_days = calendar.monthrange(year, month)

    days = []
    for day_of_month in range(1, num_days + 1):
        target = date(year, month, day_of_month)
        # Per-day seed keeps calendar prices stable across requests
        rng = seeded_random(origin, destination, target.isoformat(), "calendar")
        days.append(CalendarDay(
            date=target,
            cheapest_price=compute_calendar_price(
                distance_km=distance,
                cabin_class=cabin_class.value,
                target_date=target,
                num_carriers=num_carriers,
                dest_continent=dest_airport.continent,
                rng=rng,
            ),
        ))

    return CalendarResponse(
        origin=origin,
        destination=destination,
        year=year,
        month=month,
        days=days,
    )
backend/api/search.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flight search endpoint."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from fastapi import APIRouter, HTTPException
6
+
7
+ from ..config import MAX_RESULTS
8
+ from ..data_loader import get_route_graph
9
+ from ..flight_generator import generate_flights_for_route
10
+ from ..hub_detector import compute_hub_scores
11
+ from ..models import FlightOffer, SearchRequest, SearchResponse, SortBy
12
+ from ..route_finder import find_routes
13
+ from ..seed_utils import make_seed
14
+
15
+ router = APIRouter(prefix="/api", tags=["search"])
16
+
17
# Module-level hub cache
_hub_iatas: list[str] | None = None


def _get_hubs() -> list[str]:
    """Compute the top-hub IATA list once per process and memoize it."""
    global _hub_iatas
    if _hub_iatas is None:
        _hub_iatas = compute_hub_scores(get_route_graph())
    return _hub_iatas
27
+
28
+
29
+ def _apply_filters(flights: list[FlightOffer], req: SearchRequest) -> list[FlightOffer]:
30
+ f = req.filters
31
+ result = flights
32
+
33
+ if f.max_stops is not None:
34
+ result = [fl for fl in result if fl.stops <= f.max_stops]
35
+
36
+ if f.max_price is not None:
37
+ result = [fl for fl in result if fl.price_usd <= f.max_price]
38
+
39
+ if f.max_duration_minutes is not None:
40
+ result = [fl for fl in result if fl.total_duration_minutes <= f.max_duration_minutes]
41
+
42
+ if f.airlines:
43
+ airline_set = set(f.airlines)
44
+ result = [
45
+ fl for fl in result
46
+ if any(seg.airline_code in airline_set for seg in fl.segments)
47
+ ]
48
+
49
+ if f.departure_time_min:
50
+ h, m = map(int, f.departure_time_min.split(":"))
51
+ min_minutes = h * 60 + m
52
+ result = [
53
+ fl for fl in result
54
+ if fl.departure.hour * 60 + fl.departure.minute >= min_minutes
55
+ ]
56
+
57
+ if f.departure_time_max:
58
+ h, m = map(int, f.departure_time_max.split(":"))
59
+ max_minutes = h * 60 + m
60
+ result = [
61
+ fl for fl in result
62
+ if fl.departure.hour * 60 + fl.departure.minute <= max_minutes
63
+ ]
64
+
65
+ return result
66
+
67
+
68
def _sort_flights(flights: list[FlightOffer], sort_by: SortBy) -> list[FlightOffer]:
    """Return a new list ordered per the requested sort mode."""
    if sort_by == SortBy.cheapest:
        return sorted(flights, key=lambda offer: offer.price_usd)
    if sort_by == SortBy.fastest:
        return sorted(flights, key=lambda offer: offer.total_duration_minutes)

    # "best": weighted blend of normalized price (60%) and duration (40%)
    if not flights:
        return flights
    price_ceiling = max(offer.price_usd for offer in flights) or 1
    duration_ceiling = max(offer.total_duration_minutes for offer in flights) or 1

    def _score(offer) -> float:
        return (
            (offer.price_usd / price_ceiling) * 0.6
            + (offer.total_duration_minutes / duration_ceiling) * 0.4
        )

    return sorted(flights, key=_score)
82
+
83
+
84
@router.post("/search", response_model=SearchResponse)
async def search_flights(req: SearchRequest):
    """Search flights for the requested legs.

    Pipeline per leg: find candidate routings (direct + via hubs), expand each
    routing into concrete offers, then filter, sort, and cap at MAX_RESULTS.
    Round trips run the identical pipeline for the second leg — previously the
    pipeline was duplicated inline; it is now a single helper.

    Raises 400 when no legs are supplied, 404 for unknown airports.
    """
    graph = get_route_graph()
    hub_iatas = _get_hubs()

    if not req.legs:
        raise HTTPException(status_code=400, detail="At least one leg required")

    # Validate airports before doing any generation work
    for leg in req.legs:
        if leg.origin.upper() not in graph.airports:
            raise HTTPException(status_code=404, detail=f"Airport {leg.origin} not found")
        if leg.destination.upper() not in graph.airports:
            raise HTTPException(status_code=404, detail=f"Airport {leg.destination} not found")

    def _offers_for_leg(leg) -> list[FlightOffer]:
        # Shared route→generate→filter→sort→cap pipeline for one leg.
        leg_origin = leg.origin.upper()
        leg_dest = leg.destination.upper()
        plans = find_routes(graph, leg_origin, leg_dest, hub_iatas, max_stops=req.filters.max_stops)
        offers: list[FlightOffer] = []
        for plan in plans:
            offers.extend(
                generate_flights_for_route(graph, plan, leg.date, req.cabin_class, hub_iatas)
            )
        offers = _apply_filters(offers, req)
        offers = _sort_flights(offers, req.sort_by)
        return offers[:MAX_RESULTS]

    outbound_leg = req.legs[0]
    origin = outbound_leg.origin.upper()
    destination = outbound_leg.destination.upper()
    outbound_flights = _offers_for_leg(outbound_leg)

    # Generate return flights if round trip
    return_flights: list[FlightOffer] = []
    if req.trip_type.value == "round_trip" and len(req.legs) >= 2:
        return_flights = _offers_for_leg(req.legs[1])

    # Deterministic id derived from the outbound leg's search parameters
    search_id = str(make_seed(origin, destination, outbound_leg.date.isoformat()))

    return SearchResponse(
        outbound_flights=outbound_flights,
        return_flights=return_flights,
        search_id=search_id,
        origin=origin,
        destination=destination,
    )
backend/config.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Pricing constants and configuration.

All values feed the deterministic price engine and the flight/route
generators; changing any constant changes every generated fare.
"""

# Base price formula
BASE_FIXED_USD = 40
BASE_PER_KM_USD = 0.08

# Cabin class multipliers
CLASS_MULTIPLIERS = {
    "economy": 1.0,
    "premium_economy": 1.6,
    "business": 3.2,
    "first": 5.5,
}

# Day-of-week multipliers (0=Monday, 6=Sunday)
DAY_MULTIPLIERS = {
    0: 0.90,  # Monday
    1: 0.90,  # Tuesday
    2: 0.90,  # Wednesday
    3: 1.00,  # Thursday
    4: 1.15,  # Friday
    5: 1.05,  # Saturday
    6: 1.10,  # Sunday
}

# Season multipliers by month
SEASON_MULTIPLIERS = {
    1: 0.85,  # January - off season
    2: 0.85,  # February - off season
    3: 0.95,  # March
    4: 1.00,  # April
    5: 1.05,  # May
    6: 1.15,  # June - summer peak
    7: 1.20,  # July - summer peak
    8: 1.15,  # August - summer peak
    9: 0.90,  # September - off season
    10: 0.95,  # October
    11: 1.00,  # November
    12: 1.40,  # December - Christmas
}

# Season bonus for EU destinations in summer
EU_SUMMER_BONUS = 0.15  # +15% on top of summer multiplier
EU_CONTINENTS = {"EU"}
EU_SUMMER_MONTHS = {6, 7, 8}

# Demand multipliers
MONOPOLY_ROUTE_BONUS = 0.20  # +20% if only 1 carrier
HIGH_COMPETITION_DISCOUNT = 0.05  # -5% if 4+ carriers

# Advance booking multipliers (days before departure)
# NOTE(review): the tail is non-monotonic — 61-90 days out (-10%) is cheaper
# than 91+ days out (-5%). Confirm this is intentional.
ADVANCE_MULTIPLIERS = [
    (3, 1.50),  # 0-3 days: +50%
    (7, 1.35),  # 4-7 days: +35%
    (14, 1.20),  # 8-14 days: +20%
    (21, 1.10),  # 15-21 days: +10%
    (60, 1.00),  # 22-60 days: base
    (90, 0.90),  # 61-90 days: -10%
    (float("inf"), 0.95),  # 91+ days: -5%
]

# Jitter range (±8%)
JITTER_RANGE = 0.08

# Hub detection thresholds
HUB_MIN_ROUTES = 100
HUB_TOP_N = 125

# Connecting flight constraints
MAX_1STOP_DISTANCE_RATIO = 1.8  # Max total distance vs great-circle
MAX_2STOP_DISTANCE_RATIO = 2.5
MIN_LAYOVER_MINUTES = 60  # 1 hour
MAX_LAYOVER_MINUTES = 360  # 6 hours

# Flight generation
MIN_FLIGHTS_PER_DAY = 1
MAX_FLIGHTS_SINGLE_CARRIER = 3
MAX_FLIGHTS_MULTI_CARRIER = 15
DEPARTURE_HOUR_MIN = 5  # 05:00
DEPARTURE_HOUR_MAX = 23  # 23:00

# Aircraft types by distance: (max distance km, candidate fleet)
AIRCRAFT_BY_DISTANCE = [
    (500, ["E190", "E175", "CRJ-900"]),
    (2000, ["A320", "A321", "737-800", "737 MAX 8"]),
    (5000, ["A321neo LR", "757-200", "767-300ER"]),
    (10000, ["787-8", "787-9", "A330-300", "A350-900"]),
    (float("inf"), ["777-300ER", "A350-1000", "787-10", "A380"]),
]

# Search limits
MAX_RESULTS = 200
MAX_AUTOCOMPLETE_RESULTS = 10
backend/data_loader.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load airline_routes.json and build in-memory route graph + search index."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from dataclasses import dataclass, field
8
+
9
+
10
@dataclass
class Route:
    """One directed leg origin→destination as read from airline_routes.json."""

    destination: str  # destination IATA code
    distance_km: int  # leg distance in kilometres
    duration_min: int  # scheduled duration in minutes
    carriers: list[dict]  # [{"iata": "AA", "name": "American Airlines"}, ...]
16
+
17
+
18
@dataclass
class Airport:
    """Static airport record plus its outbound routes and computed hub score."""

    iata: str
    name: str
    city_name: str
    country: str
    country_code: str
    continent: str  # continent code, e.g. "EU"
    latitude: float
    longitude: float
    timezone: str  # IANA timezone name; defaults to "UTC" at load time
    elevation: int
    icao: str
    display_name: str
    routes: list[Route] = field(default_factory=list)  # outbound routes
    hub_score: float = 0.0  # set by hub_detector.compute_hub_scores (0-100)
34
+
35
+
36
class RouteGraph:
    """In-memory route graph and search index.

    Holds every airport keyed by IATA code, a nested origin→destination map
    for O(1) direct-route lookup, and a prefix index (every >=2-char prefix of
    every IATA/city/name/country token) for autocomplete.
    """

    def __init__(self) -> None:
        self.airports: dict[str, Airport] = {}
        # route_map[origin_iata][dest_iata] = Route
        self.route_map: dict[str, dict[str, Route]] = {}
        # Search index: lowercase tokens → set of IATA codes
        self._search_index: dict[str, set[str]] = {}

    def load(self, filepath: str) -> None:
        """Parse airline_routes.json and populate airports, route map, index."""
        with open(filepath) as f:
            data: dict = json.load(f)

        for iata, info in data.items():
            routes = []
            for r in info.get("routes", []):
                routes.append(Route(
                    destination=r["iata"],
                    distance_km=r["km"],
                    duration_min=r["min"],
                    carriers=r["carriers"],
                ))

            airport = Airport(
                iata=iata,
                name=info["name"],
                city_name=info["city_name"],
                country=info["country"],
                country_code=info["country_code"],
                continent=info["continent"],
                latitude=float(info["latitude"]) if info.get("latitude") is not None else 0.0,
                longitude=float(info["longitude"]) if info.get("longitude") is not None else 0.0,
                timezone=info.get("timezone", "UTC"),
                elevation=info.get("elevation", 0),
                icao=info.get("icao", ""),
                display_name=info.get("display_name", f"{info['city_name']} ({iata})"),
                routes=routes,
            )
            self.airports[iata] = airport

            # Build route map
            self.route_map.setdefault(iata, {})
            for route in routes:
                self.route_map[iata][route.destination] = route

            # Build search index
            self._index_airport(airport)

    def _index_airport(self, airport: Airport) -> None:
        """Add every >=2-char prefix of the airport's search tokens to the index."""
        tokens = set()
        # IATA code
        tokens.add(airport.iata.lower())
        # City name tokens
        for word in airport.city_name.lower().split():
            tokens.add(word)
        # Airport name tokens
        for word in airport.name.lower().split():
            tokens.add(word)
        # Country
        for word in airport.country.lower().split():
            tokens.add(word)
        # Country code
        tokens.add(airport.country_code.lower())

        for token in tokens:
            # Index exact token and all prefixes ≥ 2 chars
            for i in range(2, len(token) + 1):
                self._search_index.setdefault(token[:i], set()).add(airport.iata)

    def search_airports(self, query: str, limit: int = 10) -> list[Airport]:
        """Search airports by IATA code, city, name, or country.

        Ordering is deterministic: exact IATA match first, then matches ranked
        by hub score (descending) and city name.
        """
        q = query.strip().lower()
        if not q:
            return []

        # Exact IATA match first
        if len(q) == 3 and q.upper() in self.airports:
            exact = self.airports[q.upper()]
            # FIX: the previous implementation iterated the candidate *set*
            # directly, so the tail of the result list depended on hash order
            # (which varies across processes under string-hash randomization).
            # Sort like the general branch so results are reproducible.
            extras = sorted(
                (
                    self.airports[iata]
                    for iata in self._search_index.get(q, set())
                    if iata != exact.iata and iata in self.airports
                ),
                key=lambda a: (-a.hub_score, a.city_name),
            )
            return ([exact] + extras)[:limit]

        # Split query into tokens, intersect matches
        query_tokens = q.split()
        if not query_tokens:
            return []

        # Get candidates matching first token, then intersect with the rest
        candidates = self._search_index.get(query_tokens[0], set()).copy()
        for token in query_tokens[1:]:
            candidates &= self._search_index.get(token, set())

        if not candidates:
            return []

        # Sort by hub score (descending), then alphabetically by city
        airports = [self.airports[iata] for iata in candidates if iata in self.airports]
        airports.sort(key=lambda a: (-a.hub_score, a.city_name))
        return airports[:limit]

    def get_direct_route(self, origin: str, destination: str) -> Route | None:
        """Return the direct Route origin→destination, or None if absent."""
        return self.route_map.get(origin, {}).get(destination)

    def get_outbound_routes(self, origin: str) -> dict[str, Route]:
        """All routes departing `origin`, keyed by destination IATA."""
        return self.route_map.get(origin, {})
152
+
153
+
154
# Singleton
_graph: RouteGraph | None = None


def get_route_graph() -> RouteGraph:
    """Return the process-wide RouteGraph, loading the dataset on first call."""
    global _graph
    if _graph is None:
        _graph = RouteGraph()
        data_file = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "airline_routes.json"
        )
        _graph.load(data_file)
    return _graph
backend/flight_generator.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate concrete flights for a route + date."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from datetime import date, datetime, timedelta, timezone
7
+
8
+ from zoneinfo import ZoneInfo
9
+
10
+ from .config import (
11
+ AIRCRAFT_BY_DISTANCE,
12
+ DEPARTURE_HOUR_MAX,
13
+ DEPARTURE_HOUR_MIN,
14
+ MAX_FLIGHTS_MULTI_CARRIER,
15
+ MAX_FLIGHTS_SINGLE_CARRIER,
16
+ MAX_LAYOVER_MINUTES,
17
+ MIN_FLIGHTS_PER_DAY,
18
+ MIN_LAYOVER_MINUTES,
19
+ )
20
+ from .data_loader import Route, RouteGraph
21
+ from .models import CabinClass, FlightOffer, FlightSegment
22
+ from .price_engine import compute_price
23
+ from .route_finder import RoutePlan
24
+ from .seed_utils import seeded_random
25
+
26
+
27
def _pick_aircraft(distance_km: int, rng: random.Random) -> str:
    """Choose an aircraft type appropriate for the leg distance."""
    for max_distance, fleet in AIRCRAFT_BY_DISTANCE:
        if distance_km <= max_distance:
            return rng.choice(fleet)
    # Last tier's threshold is infinity, so this is effectively unreachable
    return "777-300ER"
32
+
33
+
34
+ def _make_flight_number(carrier_iata: str, rng: random.Random) -> str:
35
+ return f"{carrier_iata}{rng.randint(100, 9999)}"
36
+
37
+
38
+ def _get_timezone(graph: RouteGraph, iata: str) -> ZoneInfo:
39
+ airport = graph.airports.get(iata)
40
+ if airport and airport.timezone:
41
+ try:
42
+ return ZoneInfo(airport.timezone)
43
+ except KeyError:
44
+ pass
45
+ return ZoneInfo("UTC")
46
+
47
+
48
def generate_flights_for_route(
    graph: RouteGraph,
    route_plan: RoutePlan,
    departure_date: date,
    cabin_class: CabinClass,
    hub_iatas: list[str],
) -> list[FlightOffer]:
    """Generate concrete flight offers for a route plan on a given date.

    Seeds an RNG from the route + date + cabin so identical searches always
    produce identical offers, then dispatches on stop count.
    """
    waypoints = route_plan.waypoints
    seed_key = f"{waypoints[0]}-{waypoints[-1]}-{departure_date.isoformat()}-{cabin_class.value}"
    rng = seeded_random(seed_key, *waypoints)

    if route_plan.stops == 0:
        return _generate_direct_flights(graph, route_plan, departure_date, cabin_class, rng)
    return _generate_connecting_flights(graph, route_plan, departure_date, cabin_class, rng)
67
+
68
+
69
def _generate_direct_flights(
    graph: RouteGraph,
    route_plan: RoutePlan,
    departure_date: date,
    cabin_class: CabinClass,
    rng: random.Random,
) -> list[FlightOffer]:
    """Generate multiple direct flight options for a single-leg route.

    All randomness comes from `rng`, which the caller seeds from the search
    parameters, so the same route/date/cabin always yields the same offers.
    The rng call sequence is order-sensitive: reordering calls changes output.
    """
    leg = route_plan.legs[0]
    origin = route_plan.waypoints[0]
    destination = route_plan.waypoints[1]

    # Number of flights based on carrier count: more competition → more
    # departures per day.
    num_carriers = len(leg.carriers)
    if num_carriers == 1:
        num_flights = rng.randint(MIN_FLIGHTS_PER_DAY, MAX_FLIGHTS_SINGLE_CARRIER)
    elif num_carriers <= 3:
        num_flights = rng.randint(3, 8)
    else:
        num_flights = rng.randint(8, MAX_FLIGHTS_MULTI_CARRIER)

    # Generate departure times spread across the day.
    # NOTE(review): values are minutes after midnight despite the name.
    departure_hours = sorted([
        rng.randint(DEPARTURE_HOUR_MIN * 60, DEPARTURE_HOUR_MAX * 60)
        for _ in range(num_flights)
    ])

    origin_tz = _get_timezone(graph, origin)
    dest_tz = _get_timezone(graph, destination)
    origin_airport = graph.airports[origin]
    dest_airport = graph.airports[destination]

    flights = []
    for dep_minutes in departure_hours:
        carrier = rng.choice(leg.carriers)
        dep_hour = dep_minutes // 60
        dep_min = dep_minutes % 60

        # Departure is wall-clock time at the origin airport
        departure_dt = datetime(
            departure_date.year, departure_date.month, departure_date.day,
            dep_hour, dep_min,
            tzinfo=origin_tz,
        )

        # Calculate arrival, expressed in the destination's local timezone
        arrival_dt = departure_dt + timedelta(minutes=leg.duration_min)
        arrival_dt = arrival_dt.astimezone(dest_tz)

        price = compute_price(
            distance_km=leg.distance_km,
            cabin_class=cabin_class.value,
            departure_date=departure_date,
            departure_hour=dep_hour,
            num_carriers=num_carriers,
            dest_continent=dest_airport.continent,
            rng=rng,
        )

        # NOTE(review): two offers sampling the same minute with the same
        # carrier would share an id — confirm ids need not be unique.
        flight_id = f"{origin}{destination}{departure_date.isoformat()}{dep_minutes}{carrier['iata']}"

        segment = FlightSegment(
            airline_code=carrier["iata"],
            airline_name=carrier["name"],
            flight_number=_make_flight_number(carrier["iata"], rng),
            aircraft=_pick_aircraft(leg.distance_km, rng),
            origin=origin,
            origin_city=origin_airport.city_name,
            destination=destination,
            destination_city=dest_airport.city_name,
            departure=departure_dt,
            arrival=arrival_dt,
            duration_minutes=leg.duration_min,
        )

        flights.append(FlightOffer(
            id=flight_id,
            segments=[segment],
            total_duration_minutes=leg.duration_min,
            stops=0,
            price_usd=price,
            cabin_class=cabin_class,
            origin=origin,
            destination=destination,
            departure=departure_dt,
            arrival=arrival_dt,
        ))

    return flights
157
+
158
+
159
def _generate_connecting_flights(
    graph: RouteGraph,
    route_plan: RoutePlan,
    departure_date: date,
    cabin_class: CabinClass,
    rng: random.Random,
) -> list[FlightOffer]:
    """Generate connecting flight options (1-stop or 2-stop).

    Builds each itinerary leg-by-leg: random departure time, per-leg price
    (discounted 25% versus a standalone fare), random layovers between legs.
    Options whose itinerary drifts more than a day past the requested date
    are discarded. As with direct flights, the rng call order determines the
    output, so every option is deterministic per search parameters.
    """
    origin = route_plan.waypoints[0]
    destination = route_plan.waypoints[-1]
    dest_airport = graph.airports[destination]

    # Generate 2-5 options per connecting route
    num_options = rng.randint(2, 5)

    flights = []
    for option_idx in range(num_options):
        departure_minutes = rng.randint(DEPARTURE_HOUR_MIN * 60, DEPARTURE_HOUR_MAX * 60)

        segments = []
        # Running itinerary clock, starting at the chosen wall-clock time at origin
        current_time = datetime(
            departure_date.year, departure_date.month, departure_date.day,
            departure_minutes // 60, departure_minutes % 60,
            tzinfo=_get_timezone(graph, origin),
        )
        total_price = 0.0
        total_duration = 0

        valid = True
        for i, leg in enumerate(route_plan.legs):
            leg_origin = route_plan.waypoints[i]
            leg_dest = route_plan.waypoints[i + 1]
            origin_tz = _get_timezone(graph, leg_origin)
            dest_tz = _get_timezone(graph, leg_dest)
            origin_ap = graph.airports[leg_origin]
            dest_ap = graph.airports[leg_dest]

            carrier = rng.choice(leg.carriers)
            departure_dt = current_time.astimezone(origin_tz)
            arrival_dt = departure_dt + timedelta(minutes=leg.duration_min)
            arrival_dt = arrival_dt.astimezone(dest_tz)

            # Per-leg price
            leg_price = compute_price(
                distance_km=leg.distance_km,
                cabin_class=cabin_class.value,
                departure_date=departure_date,
                departure_hour=departure_dt.hour,
                num_carriers=len(leg.carriers),
                dest_continent=dest_ap.continent,
                rng=rng,
            )
            # Connecting flights get a discount
            leg_price *= 0.75
            total_price += leg_price

            segments.append(FlightSegment(
                airline_code=carrier["iata"],
                airline_name=carrier["name"],
                flight_number=_make_flight_number(carrier["iata"], rng),
                aircraft=_pick_aircraft(leg.distance_km, rng),
                origin=leg_origin,
                origin_city=origin_ap.city_name,
                destination=leg_dest,
                destination_city=dest_ap.city_name,
                departure=departure_dt,
                arrival=arrival_dt,
                duration_minutes=leg.duration_min,
            ))

            # Add layover time for next leg
            if i < len(route_plan.legs) - 1:
                layover = rng.randint(MIN_LAYOVER_MINUTES, MAX_LAYOVER_MINUTES)
                current_time = arrival_dt + timedelta(minutes=layover)
                total_duration += leg.duration_min + layover

                # Check if layover pushes to next day too far.
                # NOTE(review): midnight here is taken in the *current leg's*
                # origin timezone, not the itinerary origin's — confirm intended.
                if (current_time - datetime(
                    departure_date.year, departure_date.month, departure_date.day,
                    tzinfo=origin_tz
                )).days > 1:
                    valid = False
                    break
            else:
                total_duration += leg.duration_min

        if not valid:
            continue

        # Whole-dollar total for the combined itinerary
        total_price = round(total_price, 0)
        first_departure = segments[0].departure
        last_arrival = segments[-1].arrival

        flight_id = (
            f"{origin}{destination}{departure_date.isoformat()}"
            f"{departure_minutes}{'-'.join(route_plan.waypoints)}{option_idx}"
        )

        flights.append(FlightOffer(
            id=flight_id,
            segments=segments,
            total_duration_minutes=total_duration,
            stops=route_plan.stops,
            price_usd=total_price,
            cabin_class=cabin_class,
            origin=origin,
            destination=destination,
            departure=first_departure,
            arrival=last_arrival,
        ))

    return flights
backend/hub_detector.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compute hub scores for airports.
2
+
3
+ Hub score = route_count * carrier_diversity * continent_reach
4
+ Used for: connecting flight search (top hubs as waypoints) and autocomplete ranking.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from .config import HUB_MIN_ROUTES, HUB_TOP_N
10
+ from .data_loader import RouteGraph
11
+
12
+
13
def compute_hub_scores(graph: RouteGraph) -> list[str]:
    """Compute hub scores for all airports. Returns top hub IATA codes.

    Score = route_count * sqrt(carrier diversity) * continent_reach^0.3,
    normalized to 0-100. Mutates each Airport's hub_score in place.
    """
    for airport in graph.airports.values():
        num_routes = len(airport.routes)
        if num_routes < 5:
            airport.hub_score = 0.0
            continue

        # Carrier diversity: distinct operating carriers across all routes
        carrier_codes = {
            c["iata"] for route in airport.routes for c in route.carriers
        }
        # Continent reach: distinct continents among reachable destinations
        reachable_continents = {
            graph.airports[route.destination].continent
            for route in airport.routes
            if route.destination in graph.airports
        }
        airport.hub_score = (
            num_routes
            * (len(carrier_codes) ** 0.5)
            * (len(reachable_continents) ** 0.3)
        )

    # Normalize scores to 0-100
    top_score = max((a.hub_score for a in graph.airports.values()), default=1.0)
    if top_score > 0:
        for airport in graph.airports.values():
            airport.hub_score = round(airport.hub_score / top_score * 100, 2)

    # Return the best-scoring airports that meet the minimum route count
    ranked = sorted(
        (a for a in graph.airports.values() if len(a.routes) >= HUB_MIN_ROUTES),
        key=lambda a: -a.hub_score,
    )
    return [a.iata for a in ranked[:HUB_TOP_N]]
backend/main.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application — flight search backend."""
2
+
3
+ import os
4
+ import time
5
+
6
+ from fastapi import FastAPI
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.staticfiles import StaticFiles
9
+ from fastapi.responses import FileResponse
10
+
11
+ from .api import airports, calendar, search
12
+ from .data_loader import get_route_graph
13
+ from .hub_detector import compute_hub_scores
14
+
15
app = FastAPI(title="Flight Search API", version="1.0.0")

# CORS for development
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# acceptable for a local test harness; tighten before any real deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register API routers (the SPA catch-all route is added later in this module)
app.include_router(airports.router)
app.include_router(search.router)
app.include_router(calendar.router)
30
+
31
+
32
@app.on_event("startup")
async def startup():
    """Load data and compute hub scores on startup."""
    started = time.time()
    graph = get_route_graph()
    hubs = compute_hub_scores(graph)
    elapsed = time.time() - started
    print(f"Loaded {len(graph.airports)} airports, {len(hubs)} hubs in {elapsed:.1f}s")
40
+
41
+
42
@app.get("/api/health")
async def health():
    """Liveness probe: reports how many airports are loaded."""
    return {"status": "ok", "airports": len(get_route_graph().airports)}
46
+
47
+
48
# Serve frontend static files (production)
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend", "dist")
if os.path.isdir(STATIC_DIR):
    app.mount("/assets", StaticFiles(directory=os.path.join(STATIC_DIR, "assets")), name="assets")

    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve the React SPA; unknown paths fall back to index.html.

        SECURITY: the previous version joined the raw `:path` parameter onto
        STATIC_DIR, allowing ".." traversal to serve arbitrary files. Resolve
        the candidate and require it to stay inside the static root.
        """
        root = os.path.realpath(STATIC_DIR)
        candidate = os.path.realpath(os.path.join(root, full_path))
        if candidate.startswith(root + os.sep) and os.path.isfile(candidate):
            return FileResponse(candidate)
        return FileResponse(os.path.join(root, "index.html"))
backend/models.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for API request/response contracts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import date, datetime
6
+ from enum import Enum
7
+ from typing import Optional
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class CabinClass(str, Enum):
    """Service cabin on a flight; values are the wire-format strings."""
    economy = "economy"
    premium_economy = "premium_economy"
    business = "business"
    first = "first"
17
+
18
+
19
class TripType(str, Enum):
    """Itinerary shape: one leg, out-and-back, or several independent legs."""
    one_way = "one_way"
    round_trip = "round_trip"
    multi_city = "multi_city"
23
+
24
+
25
class SortBy(str, Enum):
    """Result ordering requested by the client."""
    best = "best"
    cheapest = "cheapest"
    fastest = "fastest"
29
+
30
+
31
+ # --- Airport ---
32
+
33
class AirportInfo(BaseModel):
    """Static metadata for one airport, identified by its IATA code."""
    iata: str  # 3-letter IATA code
    name: str
    city_name: str
    country: str
    country_code: str
    continent: str
    latitude: float
    longitude: float
    timezone: str  # presumably an IANA timezone name — confirm against the data source
    hub_score: float = 0.0  # connectivity score; 0.0 until hub detection has run
    route_count: int = 0  # presumably the number of routes from this airport — confirm in loader
45
+
46
+
47
+ # --- Flight segment ---
48
+
49
class FlightSegment(BaseModel):
    """One nonstop leg operated by a single airline.

    Airport codes follow the same 3-letter IATA convention as SearchLeg.
    """
    airline_code: str
    airline_name: str
    flight_number: str
    aircraft: str
    origin: str
    origin_city: str
    destination: str
    destination_city: str
    departure: datetime
    arrival: datetime
    duration_minutes: int
61
+
62
+
63
+ # --- Flight offer (may have multiple segments) ---
64
+
65
class FlightOffer(BaseModel):
    """A bookable itinerary: one or more segments with a total price.

    The top-level origin/destination/departure/arrival fields denormalize the
    first and last segment for convenient display and sorting.
    """
    id: str
    segments: list[FlightSegment]
    total_duration_minutes: int
    stops: int  # layover count — presumably len(segments) - 1; confirm in the generator
    price_usd: float
    cabin_class: CabinClass
    origin: str
    destination: str
    departure: datetime
    arrival: datetime
76
+
77
+
78
+ # --- Search request ---
79
+
80
class SearchLeg(BaseModel):
    """One origin → destination → date leg of a search (multi-city has several)."""
    origin: str = Field(..., min_length=3, max_length=3, description="IATA code")
    destination: str = Field(..., min_length=3, max_length=3, description="IATA code")
    # Field name shadows datetime.date inside this class body; with the file's
    # postponed annotations pydantic still resolves the type from the module scope.
    date: date
84
+
85
+
86
class Passengers(BaseModel):
    """Passenger counts; validation caps at 9 adults/children and 4 infants."""
    adults: int = Field(1, ge=1, le=9)
    children: int = Field(0, ge=0, le=9)
    infants: int = Field(0, ge=0, le=4)

    @property
    def total(self) -> int:
        # Total travellers across all passenger types.
        return self.adults + self.children + self.infants
94
+
95
+
96
class Filters(BaseModel):
    """Optional result filters; None disables the corresponding filter."""
    max_stops: Optional[int] = None
    max_price: Optional[float] = None
    max_duration_minutes: Optional[int] = None
    airlines: Optional[list[str]] = None  # IATA codes to include
    departure_time_min: Optional[str] = None  # "HH:MM", e.g. "06:00"
    departure_time_max: Optional[str] = None  # "HH:MM", e.g. "18:00"
103
+
104
+
105
class SearchRequest(BaseModel):
    """Payload for a flight search (1-6 legs)."""
    trip_type: TripType = TripType.round_trip
    legs: list[SearchLeg] = Field(..., min_length=1, max_length=6)
    # NOTE(review): model-instance defaults rely on pydantic copying defaults
    # per instantiation (v2 behavior) — verify, or use Field(default_factory=...).
    passengers: Passengers = Passengers()
    cabin_class: CabinClass = CabinClass.economy
    filters: Filters = Filters()
    sort_by: SortBy = SortBy.best
112
+
113
+
114
class SearchResponse(BaseModel):
    """Search results; return_flights stays empty for one-way trips."""
    outbound_flights: list[FlightOffer]
    # NOTE(review): mutable default is safe here only because pydantic copies
    # field defaults per instance — verify.
    return_flights: list[FlightOffer] = []
    search_id: str
    origin: str
    destination: str
120
+
121
+
122
+ # --- Calendar ---
123
+
124
class CalendarDay(BaseModel):
    """One day in the price calendar; price is None when no price is available."""
    date: date
    cheapest_price: Optional[float] = None
127
+
128
+
129
class CalendarResponse(BaseModel):
    """One month of cheapest-price-per-day data for a route."""
    origin: str
    destination: str
    year: int
    month: int
    days: list[CalendarDay]
135
+
136
+
137
+ # --- Autocomplete ---
138
+
139
class AutocompleteResult(BaseModel):
    """One suggestion row for the airport autocomplete widget."""
    iata: str
    name: str
    city_name: str
    country: str
    display_name: str  # pre-formatted label shown in the dropdown
    hub_score: float = 0.0  # used for ranking suggestions; 0.0 when unscored
backend/price_engine.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Price formula: base + 8 multipliers.
2
+
3
+ base_usd = 40 + (distance_km * 0.08)
4
+ final = base * class * day_of_week * time_of_day * season * demand * advance * jitter
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import random
10
+ from datetime import date, datetime, timedelta
11
+
12
+ from .config import (
13
+ ADVANCE_MULTIPLIERS,
14
+ BASE_FIXED_USD,
15
+ BASE_PER_KM_USD,
16
+ CLASS_MULTIPLIERS,
17
+ DAY_MULTIPLIERS,
18
+ EU_CONTINENTS,
19
+ EU_SUMMER_BONUS,
20
+ EU_SUMMER_MONTHS,
21
+ HIGH_COMPETITION_DISCOUNT,
22
+ JITTER_RANGE,
23
+ MONOPOLY_ROUTE_BONUS,
24
+ SEASON_MULTIPLIERS,
25
+ )
26
+
27
+
28
def compute_price(
    distance_km: int,
    cabin_class: str,
    departure_date: date,
    departure_hour: int,
    num_carriers: int,
    dest_continent: str,
    rng: random.Random,
    booking_date: date | None = None,
) -> float:
    """Compute flight price using the full pricing formula.

    price = (BASE_FIXED_USD + distance_km * BASE_PER_KM_USD)
            * class * day-of-week * time-of-day * season * demand
            * advance-booking * jitter

    Args:
        distance_km: Route distance in kilometres.
        cabin_class: Key into CLASS_MULTIPLIERS; unknown classes fall back to 1.0.
        departure_date: Departure date (drives day-of-week and season factors).
        departure_hour: Departure hour 0-23 (peak / red-eye pricing).
        num_carriers: Carriers serving the route (monopoly premium vs.
            competition discount).
        dest_continent: Destination continent, used for the EU summer surcharge.
        rng: Seeded RNG so identical searches produce identical prices.
        booking_date: Purchase date; defaults to today. Controls the
            advance-booking multiplier.

    Returns:
        Price in USD, rounded to whole dollars, floored at $25.
    """
    base = BASE_FIXED_USD + (distance_km * BASE_PER_KM_USD)

    # 1. Cabin class
    class_mult = CLASS_MULTIPLIERS.get(cabin_class, 1.0)

    # 2. Day of week
    day_mult = DAY_MULTIPLIERS.get(departure_date.weekday(), 1.0)

    # 3. Time of day
    if 6 <= departure_hour <= 8:
        time_mult = 1.10  # Morning peak
    elif 16 <= departure_hour <= 19:
        time_mult = 1.15  # Evening peak
    elif departure_hour >= 22 or departure_hour <= 5:
        time_mult = 0.85  # Red-eye discount
    else:
        time_mult = 1.00

    # 4. Season
    season_mult = SEASON_MULTIPLIERS.get(departure_date.month, 1.0)
    # EU summer bonus
    if dest_continent in EU_CONTINENTS and departure_date.month in EU_SUMMER_MONTHS:
        season_mult += EU_SUMMER_BONUS

    # 5. Demand (based on competition)
    if num_carriers == 1:
        demand_mult = 1.0 + MONOPOLY_ROUTE_BONUS
    elif num_carriers >= 4:
        demand_mult = 1.0 - HIGH_COMPETITION_DISCOUNT
    else:
        demand_mult = 1.0

    # 6. Advance booking. Past departure dates are clamped to 0 days ahead.
    if booking_date is None:
        booking_date = date.today()
    days_advance = (departure_date - booking_date).days
    if days_advance < 0:
        days_advance = 0
    advance_mult = 1.0
    # First matching threshold wins — assumes ADVANCE_MULTIPLIERS is sorted
    # ascending by threshold; verify in config. Beyond the last threshold the
    # multiplier stays 1.0.
    for threshold, mult in ADVANCE_MULTIPLIERS:
        if days_advance <= threshold:
            advance_mult = mult
            break

    # 7. Jitter (seeded, so deterministic for a given search)
    jitter = 1.0 + rng.uniform(-JITTER_RANGE, JITTER_RANGE)

    price = base * class_mult * day_mult * time_mult * season_mult * demand_mult * advance_mult * jitter

    # Round to nearest dollar, minimum $25
    return max(25.0, round(price, 0))
90
+
91
+
92
def compute_calendar_price(
    distance_km: int,
    cabin_class: str,
    target_date: date,
    num_carriers: int,
    dest_continent: str,
    rng: random.Random,
) -> float:
    """Representative cheapest price for one date, as shown in the calendar view.

    Baseline assumptions: a noon departure, booked 14 days in advance.
    """
    # Fix the booking date relative to the target so calendar prices are
    # stable regardless of when the calendar is requested.
    assumed_booking = target_date - timedelta(days=14)
    return compute_price(
        distance_km=distance_km,
        cabin_class=cabin_class,
        departure_date=target_date,
        departure_hour=12,  # noon baseline
        num_carriers=num_carriers,
        dest_continent=dest_continent,
        rng=rng,
        booking_date=assumed_booking,
    )
backend/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn[standard]>=0.27.0
3
+ pydantic>=2.6.0
backend/route_finder.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Direct + 1-stop + 2-stop route discovery."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from .config import MAX_1STOP_DISTANCE_RATIO, MAX_2STOP_DISTANCE_RATIO
8
+ from .data_loader import Route, RouteGraph
9
+
10
+
11
@dataclass
class RoutePlan:
    """A planned route from origin to destination (may have multiple legs)."""
    legs: list[Route]  # Each leg has origin implicitly from position
    waypoints: list[str]  # [origin, hub1, ..., destination]
    total_distance_km: int
    total_duration_min: int  # includes the 90-min layover padding added by find_routes

    @property
    def stops(self) -> int:
        # 1 leg = nonstop (0 stops), 2 legs = 1 stop, etc.
        return len(self.legs) - 1
22
+
23
+
24
def find_routes(
    graph: RouteGraph,
    origin: str,
    destination: str,
    hub_iatas: list[str],
    max_stops: int | None = None,
) -> list[RoutePlan]:
    """Find all route plans from origin to destination.

    Returns direct, 1-stop, and 2-stop routes.

    Args:
        graph: Loaded route graph (airports + adjacency).
        origin: Departure airport IATA code.
        destination: Arrival airport IATA code.
        hub_iatas: Candidate connection airports — presumably ordered
            best-first, since the 2-stop search only considers the first 60.
        max_stops: Cap on connections; None means no cap (up to 2 stops).
    """
    results: list[RoutePlan] = []

    # A negative cap means "no routes allowed".
    if max_stops is not None and max_stops < 0:
        return results

    # Direct route
    direct = graph.get_direct_route(origin, destination)
    if direct:
        results.append(RoutePlan(
            legs=[direct],
            waypoints=[origin, destination],
            total_distance_km=direct.distance_km,
            total_duration_min=direct.duration_min,
        ))

    if max_stops is not None and max_stops == 0:
        return results

    # 1-stop routes through hubs. Connecting itineraries are pruned when their
    # total distance exceeds a configured multiple of the direct distance
    # (great-circle estimate when no direct route exists).
    direct_distance = direct.distance_km if direct else _estimate_distance(graph, origin, destination)
    if direct_distance is None:
        # Coordinates unknown for an endpoint — cannot prune detours, give up.
        return results

    origin_routes = graph.get_outbound_routes(origin)

    for hub in hub_iatas:
        if hub == origin or hub == destination:
            continue

        leg1 = origin_routes.get(hub)
        if not leg1:
            continue

        leg2 = graph.get_direct_route(hub, destination)
        if not leg2:
            continue

        total_dist = leg1.distance_km + leg2.distance_km
        if total_dist > direct_distance * MAX_1STOP_DISTANCE_RATIO:
            continue  # Too much of a detour

        total_dur = leg1.duration_min + leg2.duration_min + 90  # +90 min layover estimate
        results.append(RoutePlan(
            legs=[leg1, leg2],
            waypoints=[origin, hub, destination],
            total_distance_km=total_dist,
            total_duration_min=total_dur,
        ))

    if max_stops is not None and max_stops <= 1:
        return results

    # 2-stop routes through pairs of hubs (limit to top hubs for performance)
    top_hubs = hub_iatas[:60]
    for hub1 in top_hubs:
        if hub1 == origin or hub1 == destination:
            continue
        leg1 = origin_routes.get(hub1)
        if not leg1:
            continue

        hub1_routes = graph.get_outbound_routes(hub1)
        for hub2 in top_hubs:
            if hub2 == origin or hub2 == destination or hub2 == hub1:
                continue

            leg2 = hub1_routes.get(hub2)
            if not leg2:
                continue

            leg3 = graph.get_direct_route(hub2, destination)
            if not leg3:
                continue

            total_dist = leg1.distance_km + leg2.distance_km + leg3.distance_km
            if total_dist > direct_distance * MAX_2STOP_DISTANCE_RATIO:
                continue

            total_dur = (leg1.duration_min + leg2.duration_min + leg3.duration_min
                         + 90 + 90)  # Two layovers
            results.append(RoutePlan(
                legs=[leg1, leg2, leg3],
                waypoints=[origin, hub1, hub2, destination],
                total_distance_km=total_dist,
                total_duration_min=total_dur,
            ))

    return results
123
+
124
+
125
+ def _estimate_distance(graph: RouteGraph, origin: str, destination: str) -> int | None:
126
+ """Estimate great-circle distance between two airports using coordinates."""
127
+ import math
128
+
129
+ o = graph.airports.get(origin)
130
+ d = graph.airports.get(destination)
131
+ if not o or not d:
132
+ return None
133
+
134
+ lat1, lon1 = math.radians(o.latitude), math.radians(o.longitude)
135
+ lat2, lon2 = math.radians(d.latitude), math.radians(d.longitude)
136
+
137
+ dlat = lat2 - lat1
138
+ dlon = lon2 - lon1
139
+ a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
140
+ c = 2 * math.asin(math.sqrt(a))
141
+ return int(c * 6371) # Earth radius in km
backend/seed_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deterministic seeding utilities.
2
+
3
+ Same search parameters always produce the same flights and prices.
4
+ Uses SHA-256 hash of search params → integer seed for random.Random.
5
+ """
6
+
7
+ import hashlib
8
+ import random
9
+
10
+
11
def make_seed(*parts: str | int | float) -> int:
    """Derive a reproducible 64-bit seed from arbitrary parts.

    The parts are stringified, joined with '|', and hashed with SHA-256;
    the first 16 hex digits (64 bits) of the digest become the seed.
    """
    joined = "|".join(map(str, parts))
    digest = hashlib.sha256(joined.encode()).hexdigest()
    return int(digest[:16], 16)
16
+
17
+
18
def seeded_random(*parts: str | int | float) -> random.Random:
    """Build a random.Random seeded deterministically from the given params."""
    seed = make_seed(*parts)
    return random.Random(seed)
description.md ADDED
@@ -0,0 +1,1122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ECMoE — Method and Experiment Description
2
+
3
+ ## 1. Problem Statement
4
+
5
+ In Mixture-of-Experts (MoE) models with expert parallelism, each token's hidden state must be communicated between GPUs twice per MoE layer:
6
+
7
+ 1. **Dispatch (all-to-all):** The hidden state is sent from the token's source GPU to the GPU hosting its assigned expert(s).
8
+ 2. **Gather (all-to-all):** The expert output is sent back to the source GPU.
9
+
10
+ For a model like Qwen3-30B-A3B with `hidden_dim=2048` and 48 MoE layers, each token requires transmitting `2 × 48 × 2048 × 2 bytes = 384 KB` of data per forward pass (in BF16). At scale, this communication dominates inference latency.
11
+
12
+ This project investigates methods to **compress these hidden-state vectors** before transmission, reducing communication volume while preserving model quality.
13
+
14
+ ### Training paradigms
15
+
16
+ This project uses two training paradigms:
17
+
18
+ **Offline (Tasks 2–4):** Compressors are trained on **cached hidden states**, not end-to-end through the LLM:
19
+
20
+ 1. **Capture:** Run the unmodified LLM on calibration data and cache MoE layer inputs/outputs to disk.
21
+ 2. **Train:** Train each compressor/decompressor pair independently on the cached data, minimizing a local reconstruction loss. No gradients flow through the LLM.
22
+ 3. **Evaluate:** Insert trained compressors into the live model via forward hooks and measure perplexity.
23
+
24
+ Each pair is trained in isolation — no joint optimization across layers, no end-to-end backpropagation. This is cheap (minutes per layer) but means compressors cannot adapt to how errors compound across layers.
25
+
26
+ **End-to-end (Task 5):** Compressors are trained **through the live LLM** using the language modeling objective:
27
+
28
+ 1. **Insert:** Register per-layer compressor/decompressor pairs as forward pre-hooks on each MoE layer.
29
+ 2. **Train:** Run standard next-token prediction. The LLM weights are frozen; only compressor parameters receive gradients. Gradients flow through the entire frozen LLM.
30
+ 3. **Evaluate:** Same hook-based perplexity evaluation as offline methods.
31
+
32
+ All 48 compressors are optimized jointly through a single global loss. This allows the system to learn how compression errors at early layers affect all downstream layers.
33
+
34
+ ---
35
+
36
+ ## 2. Model Specification
37
+
38
+ | Property | Value |
39
+ |---|---|
40
+ | Architecture | Qwen3-30B-A3B-Instruct-2507 |
41
+ | Total parameters | 30.53B |
42
+ | Activated parameters | 3.35B |
43
+ | Hidden dimension | 2048 |
44
+ | Number of layers | 48 (all MoE) |
45
+ | Number of experts | 128 per layer |
46
+ | Top-k routing | 8 experts per token |
47
+ | Attention heads | 32 (Q), 4 (KV) |
48
+ | Head dimension | 128 |
49
+ | MoE expert FFN intermediate size | 768 |
50
+ | Vocabulary size | 151,936 |
51
+
52
+ All tasks use the same model variant and precision:
53
+
54
+ | Variant | Used in | Loading | VRAM |
55
+ |---|---|---|---|
56
+ | Qwen3-30B-A3B-Instruct-2507 | All tasks (1–5) | Full BF16 | ~60 GB |
57
+
58
+ **Tasks 1–4:** Single GPU (`device="cuda:0"`). The ~60 GB model fits on one H100 80 GB with headroom for inference activations. Using single-GPU avoids the overhead of cross-GPU communication from `device_map="auto"`.
59
+
60
+ **Task 5:** Multi-GPU via `device_map="auto"` across 4 GPUs. Backpropagation through the frozen model during end-to-end training requires additional VRAM for activations and gradient checkpoints that exceed single-GPU capacity.
61
+
62
+ ---
63
+
64
+ ## 3. Data Collection
65
+
66
+ ### 3.1 Calibration Data
67
+
68
+ - **Dataset:** allenai/Dolci-Instruct-SFT (train split)
69
+ - **Format:** Chat-formatted instruction data, tokenized via `tokenizer.apply_chat_template()`
70
+ - **Sequences:** Up to 256 samples, each tokenized independently (one conversation = one sequence)
71
+ - **Max length:** 2048 tokens per sequence (configurable via `--max-length`)
72
+ - **SFT mode:** Labels mask non-assistant tokens with -100; perplexity computed on responses only
73
+ - **Response-only collection:** By default, only assistant-response tokens are captured
74
+ (positions where `labels != -100`). This ensures offline compressor training (Tasks 2–4)
75
+ trains on the same token distribution that PPL evaluation measures. Use `--no-response-only`
76
+ for legacy all-token collection.
77
+ - **Total tokens collected:** 100,000 per MoE layer (response tokens only by default)
78
+
79
+ ### 3.2 Hidden State Capture
80
+
81
+ PyTorch forward hooks are registered on each MoE module:
82
+ - **Pre-forward hook** captures dispatch states (MoE layer inputs)
83
+ - **Post-forward hook** captures gather states (MoE layer outputs)
84
+
85
+ **Token filtering:** `MoEHiddenStateCollector` supports a per-sequence boolean mask
86
+ (`set_token_mask(mask)`). When `response_only=True` (default), the mask is derived from
87
+ `labels != -100` before each forward pass. The same mask is applied to all 48 MoE layers
88
+ within a sequence, preserving token alignment across layers. Positions where the mask is
89
+ `False` (system, user, template markup, padding) are not collected.
90
+
91
+ Each captured tensor has shape `[N, 2048]` where N = number of response tokens (or all
92
+ tokens if `response_only=False`). States are stored in the model's native dtype (`bfloat16`)
93
+ on CPU.
94
+
95
+ **Implementation:** `MoEHiddenStateCollector` class in `src/model_utils.py`.
96
+
97
+ ### 3.3 Storage
98
+
99
+ ```
100
+ data/hidden_states/
101
+ ├── dispatch_states.pt # dict {layer_name: tensor [100000, 2048]}
102
+ ├── gather_states.pt # dict {layer_name: tensor [100000, 2048]}
103
+ └── metadata.json # model name, dims, token count, layer names
104
+ ```
105
+
106
+ Total size: ~37 GB (18.5 GB dispatch + 18.5 GB gather, bfloat16 = 2 bytes/value).
107
+
108
+ ---
109
+
110
+ ## 4. Evaluation Methodology
111
+
112
+ ### 4.1 Reconstruction Metrics (Offline)
113
+
114
+ Computed on cached hidden states without running the full model:
115
+
116
+ | Metric | Formula | Notes |
117
+ |---|---|---|
118
+ | MSE | `mean((x - x')²)` | Mean squared error |
119
+ | Cosine Similarity | `mean(cos(x, x'))` | Per-token, averaged |
120
+ | Relative Error | `mean(‖x - x'‖₂ / ‖x‖₂)` | Per-token L2 relative error |
121
+ | SNR (dB) | `10 · log₁₀(signal_power / noise_power)` | Signal-to-noise ratio |
122
+
123
+ **Implementation:** `src/metrics.py`
124
+
125
+ ### 4.2 End-to-End Perplexity (Online)
126
+
127
+ The true impact of compression is measured by evaluating cross-entropy perplexity on allenai/Dolci-Instruct-SFT (the same dataset used for calibration/training) with compression hooks active:
128
+
129
+ - **Dispatch compression:** A pre-forward hook on each MoE block applies `compress → decompress` to the input hidden states before they enter the block.
130
+ - **Evaluation:** 50,000 sequences, max length 2048 tokens.
131
+ - **SFT mode:** Perplexity is computed on assistant response tokens only. Non-response tokens
132
+ (system, user, template markup) are labeled with -100 and excluded from the loss.
133
+ This measures the model's ability to generate correct responses, not to predict prompt tokens.
134
+
135
+ **Caveat:** This simulation also affects the router's input. In real expert parallelism, the router runs on the original hidden state at the source node. Our simulation is therefore **conservative**: it upper-bounds the quality impact — the true degradation would be smaller.
136
+
137
+ **Implementation:** `evaluate_perplexity_with_compression()`, `evaluate_perplexity_with_perlayer_compression()`, `evaluate_perplexity_with_stale_compression()` in `src/model_utils.py`.
138
+
139
+ ---
140
+
141
+ ## 5. Method Descriptions
142
+
143
+ ### 5.1 Quantization Baseline (Task 2)
144
+
145
+ **Idea:** Reduce the bit width of hidden-state elements from BF16 (16 bits) to INT8/INT4/INT2.
146
+
147
+ **Symmetric (absmax) quantization:**
148
+ ```
149
+ scale = max(|x|) / (2^(bits-1) - 1) # per-token
150
+ x_q = round(x / scale) # quantize
151
+ x' = x_q * scale # dequantize
152
+ ```
153
+
154
+ **Asymmetric (zero-point) quantization:**
155
+ ```
156
+ scale = (max(x) - min(x)) / (2^bits - 1)
157
+ zero_point = round(-min(x) / scale)
158
+ x_q = round(x / scale + zero_point)
159
+ x' = (x_q - zero_point) * scale
160
+ ```
161
+
162
+ **Compression ratios:**
163
+
164
+ | Bits | Effective Ratio | Bytes/token (hidden_dim=2048) |
165
+ |---|---|---|
166
+ | INT8 (absmax) | ~2.0x | 2050 (2048 + 2 for scale) |
167
+ | INT4 (absmax) | ~4.0x | 1026 (1024 + 2 for scale) |
168
+ | INT2 (absmax) | ~8.0x | 514 (512 + 2 for scale) |
169
+
170
+ **Additional parameters:** 0 (quantization is parameter-free).
171
+
172
+ **Implementation:** `src/run_quantization.py`
173
+
174
+ ---
175
+
176
+ ### 5.2 Shared Neural Compressor (Task 3)
177
+
178
+ **Idea:** Train a single-layer linear autoencoder shared across all 48 MoE layers.
179
+
180
+ **Architecture:**
181
+ ```
182
+ Compressor: Linear(2048, bottleneck_dim) + bias
183
+ Decompressor: Linear(bottleneck_dim, 2048) + bias
184
+ ```
185
+
186
+ One compressor-decompressor pair is shared across all layers. Training data pools dispatch states from all 48 layers. Training is offline: the compressor minimizes reconstruction loss on cached hidden states, with no gradients flowing through the LLM.
187
+
188
+ **Compression ratios:** `hidden_dim / bottleneck_dim` = {2x, 4x, 8x, 16x} corresponding to `bottleneck_dim` = {1024, 512, 256, 128}.
189
+
190
+ **Training hyperparameters:**
191
+
192
+ | Parameter | Value |
193
+ |---|---|
194
+ | Optimizer | Adam |
195
+ | Learning rate | 1e-3 |
196
+ | LR schedule | Cosine annealing (T_max = epochs) |
197
+ | Max epochs | 50 |
198
+ | Batch size | 2048 |
199
+ | Early stopping patience | 8 epochs |
200
+ | Validation fraction | 10% |
201
+ | Loss function | MSE + 0.1 × (1 - cosine_similarity) |
202
+
203
+ **Loss function:**
204
+ ```
205
+ L = MSE(x', x) + λ · (1 - mean(cos_sim(x', x)))
206
+ ```
207
+ where `λ = 0.1` (cosine_weight). The cosine term encourages preserving direction, not just magnitude.
208
+
209
+ **Parameter count:**
210
+ ```
211
+ params = (2048 × b + b) + (b × 2048 + 2048)
212
+ ```
213
+ where `b` = bottleneck_dim.
214
+
215
+ | Ratio | Bottleneck | Parameters | % of Activated |
216
+ |---|---|---|---|
217
+ | 2x | 1024 | 4.20M | 0.125% |
218
+ | 4x | 512 | 2.10M | 0.063% |
219
+ | 8x | 256 | 1.05M | 0.031% |
220
+ | 16x | 128 | 0.53M | 0.016% |
221
+
222
+ **Implementation:** `src/run_neural_compressor.py`
223
+
224
+ ---
225
+
226
+ ### 5.3 Per-Layer Neural Compressor (Task 3b)
227
+
228
+ **Motivation:** Hidden state distributions vary dramatically across layers:
229
+ - Standard deviation: 0.16 (layer 0) → 1.21 (layer 47)
230
+ - Kurtosis: 3 (near-Gaussian, early layers) → 81,340 (extremely heavy-tailed, late layers)
231
+
232
+ A single shared compressor cannot adapt to this variation.
233
+
234
+ **Architecture:** Same `Compressor` + `Decompressor` structure, but **48 independent pairs** — one per MoE layer. Each layer's compressor is trained independently and only on that layer's cached dispatch data. There is no joint optimization across layers.
235
+
236
+ **Compression ratios:** Same as shared: {2x, 4x, 8x, 16x}.
237
+
238
+ **Training:** Same hyperparameters as shared (see Section 5.2). Each layer is trained independently on its own 100K token dispatch data (90% train / 10% val).
239
+
240
+ **Parameter count:**
241
+ ```
242
+ params = 48 × (2048 × b + b + b × 2048 + 2048)
243
+ ```
244
+
245
+ | Ratio | Bottleneck | Parameters | % of Activated |
246
+ |---|---|---|---|
247
+ | 2x | 1024 | 201.47M | 6.008% |
248
+ | 4x | 512 | 100.79M | 3.006% |
249
+ | 8x | 256 | 50.44M | 1.504% |
250
+ | 16x | 128 | 25.27M | 0.754% |
251
+
252
+ **Implementation:** `src/run_perlayer_compressor.py`
253
+
254
+ ---
255
+
256
+ ### 5.4 Stale-Conditioned Compressor (Tasks 4a/4b)
257
+
258
+ **Motivation:** Adjacent MoE layers process the same token, so their hidden states are correlated. A decompressor can exploit this by receiving a "stale" signal — the hidden state from a nearby layer that was already transmitted — as side information.
259
+
260
+ **Reference layer grouping (stride=12):**
261
+ - Reference layers: {0, 12, 24, 36} (4 layers)
262
+ - Layer 1–11 → stale from layer 0
263
+ - Layer 13–23 → stale from layer 12
264
+ - Layer 25–35 → stale from layer 24
265
+ - Layer 37–47 → stale from layer 36
266
+
267
+ **Architecture:**
268
+ - **Reference layers** use standard per-layer `Compressor` + `Decompressor` (no stale signal).
269
+ - **Non-reference layers** use `Compressor` + `StaleDecompressor`:
270
+
271
+ ```
272
+ Compressor: Linear(2048, bottleneck_dim) + bias
273
+ StaleDecompressor: Linear(bottleneck_dim + stale_dim, 2048) + bias
274
+ ```
275
+
276
+ The decompressor receives `cat(compressed_current, stale_signal)` as input.
277
+
278
+ **Two stale modes:**
279
+
280
+ | Mode | Task | Stale signal | StaleDecompressor input dim |
281
+ |---|---|---|---|
282
+ | Compressed (4a) | `--stale-mode compressed` | Compressed ref layer input (via ref's compressor) | `bottleneck_dim + bottleneck_dim` |
283
+ | Uncompressed (4b) | `--stale-mode uncompressed` | Raw ref layer input (full hidden dim) | `bottleneck_dim + 2048` |
284
+
285
+ **Training:**
286
+ 1. **Phase 1:** Train reference layer compressors independently (standard per-layer autoencoder, same hyperparameters as Section 5.2).
287
+ 2. **Phase 2:** Train non-reference layer compressors independently. For each non-ref layer:
288
+ - Current data: that layer's cached dispatch states
289
+ - Stale data: the reference layer's cached dispatch states (compressed or raw, depending on mode)
290
+ - The stale signal is **pre-computed and frozen** — the reference layer's compressor is not jointly optimized with non-reference layers
291
+ - Token alignment is guaranteed: `dispatch[layer_0][i]` and `dispatch[layer_5][i]` correspond to the same token
292
+
293
+ As with all neural methods in this project, training is offline on cached hidden states. No gradients flow through the LLM, and each layer's compressor is trained in isolation.
294
+
295
+ **Stale-conditioned training loss:** Same as Section 5.2 (`MSE + 0.1 × (1 - cos_sim)`), but the decompressor receives the concatenated input.
296
+
297
+ **Parameter count:**
298
+
299
+ For compressed stale (`stale_dim = bottleneck_dim`):
300
+ ```
301
+ ref_pair = (2048 × b + b) + (b × 2048 + 2048)
302
+ nonref_pair = (2048 × b + b) + ((b + b) × 2048 + 2048)
303
+ total = 4 × ref_pair + 44 × nonref_pair
304
+ ```
305
+
306
+ For uncompressed stale (`stale_dim = 2048`):
307
+ ```
308
+ ref_pair = (2048 × b + b) + (b × 2048 + 2048)
309
+ nonref_pair = (2048 × b + b) + ((b + 2048) × 2048 + 2048)
310
+ total = 4 × ref_pair + 44 × nonref_pair
311
+ ```
312
+
313
+ | Mode | Ratio | Bottleneck | Stale dim | Parameters | % of Activated |
314
+ |---|---|---|---|---|---|
315
+ | Compressed | 2x | 1024 | 1024 | 293.75M | 8.760% |
316
+ | Compressed | 4x | 512 | 512 | 146.92M | 4.382% |
317
+ | Compressed | 8x | 256 | 256 | 73.51M | 2.192% |
318
+ | Compressed | 16x | 128 | 128 | 36.80M | 1.098% |
319
+ | Uncompressed | 2x | 1024 | 2048 | 386.02M | 11.512% |
320
+ | Uncompressed | 4x | 512 | 2048 | 285.34M | 8.509% |
321
+ | Uncompressed | 8x | 256 | 2048 | 234.99M | 7.008% |
322
+ | Uncompressed | 16x | 128 | 2048 | 209.82M | 6.257% |
323
+
324
+ Note: The uncompressed stale method's parameter count does not scale down as aggressively because the `StaleDecompressor` input always includes the full 2048-dim stale signal, making the `(2048 × 2048)` weight block dominant.
325
+
326
+ **Perplexity evaluation with stale hooks:** During forward pass, a shared `stale_cache` dictionary stores reference layer inputs. PyTorch processes layers 0→47 sequentially, so layer 0's pre-hook fires before layer 1's, guaranteeing the stale cache is populated in time.
327
+
328
+ **Implementation:** `src/run_stale_compressor.py`, `evaluate_perplexity_with_stale_compression()` in `src/model_utils.py`.
329
+
330
+ ---
331
+
332
+ ### 5.5 End-to-End Per-Layer Compressor (Tasks 5a/5b)
333
+
334
+ **Motivation:** All offline methods (Tasks 3–4) share a fundamental limitation: each compressor is trained to minimize *local* reconstruction error in isolation. It cannot account for how its errors compound through downstream layers during a full forward pass. Additionally, the stale signal used during offline training is the *unperturbed* reference layer input, but during inference the reference layer itself is compressed, creating a train-inference mismatch.
335
+
336
+ End-to-end training addresses both issues by optimizing compressors through the full LLM forward pass using the language modeling objective.
337
+
338
+ **Architecture:** Same `Compressor` + `Decompressor` (5a) or `Compressor` + `StaleDecompressor` (5b) structure as Tasks 3b/4b. The compressor modules are identical — only the training objective differs.
339
+
340
+ **Training paradigm:**
341
+ 1. Load the LLM in full BF16 across 4 GPUs. Freeze all LLM weights.
342
+ 2. Insert per-layer compressor/decompressor pairs as forward pre-hooks on each MoE layer. Each pair is placed on the same GPU as its MoE layer.
343
+ 3. Run standard next-token prediction on training data. Only compressor/decompressor parameters receive gradients.
344
+ 4. Gradients flow backward through the entire frozen LLM, from the cross-entropy loss at the output back through all 48 layers to every compressor.
345
+
346
+ **Key difference from offline: joint optimization.** All 48 compressors share a single loss function (cross-entropy). Layer 0's compressor receives gradient signal about how its reconstruction error affects layers 1–47. The system implicitly learns to allocate more fidelity to layers where errors are most harmful to the final prediction.
347
+
348
+ **Stale signal gradient flow (5b):** Unlike offline Task 4b where the stale signal is pre-computed and frozen, end-to-end training does **not** detach the stale signal. Gradients flow through the stale path:
349
+ - A non-reference layer's decompressor receives `cat(compressed_current, stale)` where `stale` is the raw input to the reference layer
350
+ - During backward, gradients flow from the non-ref layer through `stale` to the reference layer's input, and further back to earlier layers
351
+ - This means reference layers' compressors are optimized not just for their own reconstruction, but also for how their inputs serve as stale side information for all downstream non-reference layers
352
+ - This eliminates the train-inference mismatch: during training, the stale signal already reflects upstream compression artifacts
353
+
354
+ **Near-identity initialization:**
355
+ - Compressor `W_c`: first `bottleneck_dim` rows of the identity matrix
356
+ - Decompressor `W_d`: first `bottleneck_dim` columns of the identity matrix
357
+ - Composition `W_d @ W_c ≈ I` (projects to first `b` dimensions and reconstructs)
358
+ - This ensures the initial forward pass is close to uncompressed, avoiding catastrophic initial loss from random projections. The optimizer then refines from this starting point.
359
+
360
+ **Model and data:**
361
+ - **Model:** Qwen/Qwen3-30B-A3B-Instruct-2507 (full BF16, same as all tasks)
362
+ - **Training data:** allenai/Dolci-Instruct-SFT, 500K sequences (HF) / 100K sequences (Megatron) sampled from train split,
363
+ max_length=2048 tokens per sequence
364
+ - **SFT mode:** Each conversation is tokenized independently (one sample = one sequence).
365
+ Labels mask non-assistant tokens with -100; loss is computed on assistant responses only.
366
+ Data is loaded by sampling N sequences from the dataset (not by packing tokens).
367
+ - **Evaluation:** allenai/Dolci-Instruct-SFT (same dataset, response-only perplexity)
368
+
369
+ **Two modes:**
370
+
371
+ | Mode | Task | Stale signal | Decompressor |
372
+ |---|---|---|---|
373
+ | No stale (5a) | `--stale-mode none` | None | `Decompressor(bottleneck_dim, 2048)` |
374
+ | Uncompressed stale (5b) | `--stale-mode uncompressed` | Raw ref layer input | `StaleDecompressor(bottleneck_dim, 2048, 2048)` |
375
+
376
+ **Training hyperparameters:**
377
+
378
+ | Parameter | Value |
379
+ |---|---|
380
+ | Optimizer | AdamW |
381
+ | Learning rate | 1e-4 |
382
+ | Weight decay | 0.01 |
383
+ | LR schedule | Cosine with 10% linear warmup |
384
+ | Max epochs | 1 |
385
+ | Batch size | 2 (gradient accumulation: 8, effective: 16) |
386
+ | Gradient clipping | max_norm = 1.0 |
387
+ | Early stopping patience | 5 epochs |
388
+ | Validation interval | Every 2500 optimizer steps (HF) / 1000 (Megatron) (configurable via `--val-interval`) |
389
+ | Validation batch size | 8 (configurable via `--val-batch-size`; larger than train because no backward) |
390
+ | Validation fraction | 10% |
391
+ | Max sequence length | 2048 (configurable via `--max-length`) |
392
+ | Loss function | Cross-entropy (response tokens only, SFT mode) |
393
+
394
+ Note the lower learning rate (1e-4 vs 1e-3 for offline) — the LM loss landscape propagates gradients through 48 frozen transformer layers, requiring more conservative updates.
395
+
396
+ **Tail micro-batch handling:** When `len(dataloader) % grad_accum != 0`, the remaining micro-batches
397
+ have their accumulated gradients rescaled by `grad_accum / remainder` (correcting the divisor from
398
+ `1/grad_accum` to `1/remainder`) before performing a final optimizer step. This ensures no training
399
+ data is discarded. Applied to both HF (`run_e2e_compressor.py`) and Megatron (`train.py`).
400
+
401
+ **Two evaluation stages (different data, different code paths):**
402
+
403
+ | Stage | Split | Batch size | Function | Purpose |
404
+ |---|---|---|---|---|
405
+ | Training-time val | VAL (50K seqs) | `--val-batch-size` (default 8) | `evaluate_val_loss()` in training script | Checkpoint selection, wandb monitoring |
406
+ | Final PPL | TEST (50K seqs) | 1 (per-sample) | `evaluate_perplexity()` in `model_utils.py` | Reported results |
407
+
408
+ The training-time validation runs every `--val-interval` optimizer steps and at epoch end, using the VAL split. It drives best-checkpoint selection. The final perplexity evaluation runs after training on the held-out TEST split (never seen during training or checkpoint selection) and produces the numbers reported in the results tables. These are separate code paths — `--val-batch-size` only affects the training-time evaluation.
409
+
410
+ **Parameter count:** Same as Tasks 3b (5a) and 4b-uncompressed (5b):
411
+
412
+ | Mode | Ratio | Bottleneck | Parameters | % of Activated |
413
+ |---|---|---|---|---|
414
+ | No stale (5a) | 2x | 1024 | 201.47M | 6.008% |
415
+ | No stale (5a) | 4x | 512 | 100.79M | 3.006% |
416
+ | No stale (5a) | 8x | 256 | 50.44M | 1.504% |
417
+ | No stale (5a) | 16x | 128 | 25.27M | 0.754% |
418
+ | Uncompressed stale (5b) | 2x | 1024 | 386.02M | 11.512% |
419
+ | Uncompressed stale (5b) | 4x | 512 | 285.34M | 8.509% |
420
+ | Uncompressed stale (5b) | 8x | 256 | 234.99M | 7.008% |
421
+ | Uncompressed stale (5b) | 16x | 128 | 209.82M | 6.257% |
422
+
423
+ **Multi-GPU setup:**
424
+ - Model distributed across 4 GPUs via `device_map="auto"` (~15 GB/GPU)
425
+ - Gradient checkpointing enabled (`use_reentrant=False`) to reduce activation memory
426
+ - 8 GPUs available → 05a and 05b run in parallel on separate GPU sets (GPUs 0-3 and 4-7)
427
+ - Each compressor is automatically placed on the same GPU as its MoE layer
428
+
429
+ **Implementation:** `src/run_e2e_compressor.py`, `scripts/05_run_e2e_compressor.sh`.
430
+
431
+ ---
432
+
433
+ ### 5.6 Megatron-LM E2E Training (Task 5 — Megatron variant)
434
+
435
+ **Motivation:** The HuggingFace-based Task 5 uses `device_map="auto"` for naive layer-sharded model parallelism. Only one GPU is active at a time during forward pass (sequential layer execution), with no tensor or data parallelism. This limits training throughput and cannot scale to multi-node.
436
+
437
+ **Approach:** Replace HuggingFace with Megatron-LM to get proper tensor parallelism (TP), expert parallelism (EP), and data parallelism (DP):
438
+ - All 4 GPUs active simultaneously via TP (each GPU holds shards of every layer)
439
+ - Multi-node scaling via DP across nodes + TP within nodes
440
+ - Megatron's optimized kernels (fused LayerNorm, FlashAttention, etc.)
441
+
442
+ **Compressor/decompressor placement:**
443
+
444
+ In real expert parallelism, the compressor and decompressor are on DIFFERENT GPUs:
445
+ - **Compressor:** Same GPU as attention output (source GPU where token originates)
446
+ - **Decompressor:** Same GPU as MoE expert (destination GPU after dispatch)
447
+
448
+ This is more realistic than the HF hook-based simulation where the router sees compressed-then-decompressed input. With Megatron, the router sees the ORIGINAL hidden state; only the dispatch is compressed.
449
+
450
+ **Phase A (TP only, EP=1):** Compressor and decompressor on same GPU (same as current HF approach). TP=4 shards each layer across 4 GPUs.
451
+
452
+ **Phase B (with EP):** Compressor on attention GPU, decompressor on expert GPU. MoE dispatch sends compressed tokens (reduced all-to-all volume). The `CompressedMoETokenDispatcher` wraps Megatron's dispatcher to:
453
+ 1. Compress on source GPU (attention side)
454
+ 2. Dispatch compressed tokens (smaller all-to-all)
455
+ 3. Decompress on destination GPU (expert side)
456
+
457
+ **Training pipeline:**
458
+ 1. Convert Qwen3-30B-A3B from HF format to Megatron format via Megatron Bridge
459
+ 2. Load with TP=4 (each GPU holds ~15-20 GB of sharded weights)
460
+ 3. Freeze all LLM parameters
461
+ 4. Insert per-layer compressor/decompressor pairs at MoE boundaries
462
+ 5. Train compressors via language modeling objective (same as HF Task 5)
463
+ 6. Save compressor weights (from rank 0, since all TP ranks have identical copies)
464
+
465
+ **TP-aware loss computation:** `MegatronModelWrapper._compute_loss()` uses
466
+ `vocab_parallel_cross_entropy` when TP > 1. SFT labels (-100) are clamped to 0 before
467
+ the call (avoiding garbage per-token loss for masked positions), and loss is computed as
468
+ `(per_token_loss * loss_mask).sum() / num_valid`. The non-TP path uses PyTorch's
469
+ `cross_entropy(ignore_index=-100)` which handles masking internally.
470
+
471
+ **Evaluation:** Uses existing HF-based evaluation code — load trained compressor weights into `E2ECompressorManager` and evaluate perplexity with hook-based simulation.
472
+
473
+ **Parallelism strategies:**
474
+
475
+ | Hardware | Configuration | Notes |
476
+ |---|---|---|
477
+ | 4 GPUs | TP=4, EP=1, PP=1, DP=1 | All GPUs active via tensor parallelism |
478
+ | 8 GPUs | TP=4, EP=1, PP=1, DP=2 | TP within 4 GPUs, DP across 2 replicas |
479
+ | N nodes × 4 GPUs | TP=4, DP=N | TP within node (NVLink), DP across nodes |
480
+ | EP variant | TP=2, EP=2, PP=1, DP=1 | Compressor on TP ranks, decompressor on EP ranks |
481
+
482
+ **Compressor weights with TP:**
483
+ - Compressors are replicated on all TP ranks (not sharded)
484
+ - Input is full hidden state (post-attention all-reduce)
485
+ - Gradients identical across ranks — no extra all-reduce needed
486
+ - Save from rank 0 only
487
+
488
+ **Implementation:** `src/megatron_e2e/` package with EP-first parallelism (EP=4, TP=1), CUDA 12.9, Megatron Bridge 0.2+, Transformer Engine. Entry point: `src/megatron_e2e/train.py`, bash wrapper: `scripts/05_megatron_e2e.sh`, setup: `scripts/megatron_setup_env.sh`. Multi-node: `scripts/05_megatron_e2e_multinode.sh`.
489
+
490
+ ---
491
+
492
+ ### 5.7 Baseline E2E Evaluation (Task 5c)
493
+
494
+ **Motivation:** Tasks 5a/5b report perplexity relative to an "untrained baseline" (the original model evaluated on the same test data). However, 5a/5b's training pipeline also loads and processes data through `load_e2e_data()`, computes SFT-masked loss on train/val splits, and may differ subtly from a raw model evaluation. Task 5c runs the exact same pipeline (same data loading, same loss computation, same evaluation) but WITHOUT inserting any compressors. This provides:
495
+
496
+ 1. **Train/val loss context:** If 5c's train loss is ~1.0, and 5a-2x's is 1.11, the compression overhead is only +0.11 — not the raw 1.11 value.
497
+ 2. **Pipeline consistency:** Confirms that the data pipeline itself does not introduce artifacts.
498
+ 3. **Fair comparison:** All three (5a, 5b, 5c) use identical code paths except for the compression hooks.
499
+
500
+ **What it does:**
501
+ - Loads data via `load_e2e_data()` (same function as 5a/5b)
502
+ - Evaluates train and val loss using `evaluate_loss_no_hooks()` — same as `evaluate_val_loss()` but without a compressor manager
503
+ - Evaluates baseline PPL on the TEST split (same as 5a/5b)
504
+ - No training, no compression ratios, no weight files
505
+
506
+ **Implementation:** Added as `--stale-mode baseline` to both `src/run_e2e_compressor.py` (HF) and `src/megatron_e2e/train.py` (Megatron). Output dirs: `results/05c_e2e_baseline/` (HF), `results/05c_megatron_e2e_baseline/` (Megatron).
507
+
508
+ ---
509
+
510
+ ### 5.8 E2E with Pretrained Init (Tasks 6a/6b)
511
+
512
+ **Motivation:** Tasks 5a/5b initialize compressor/decompressor weights with a near-identity matrix — the first `bottleneck_dim` dimensions are preserved, and the rest are zeroed out. This is a reasonable starting point but the optimizer must learn the full compression mapping from scratch using only the LM loss signal.
513
+
514
+ Tasks 3b and 4b already train compressors to minimize reconstruction loss on cached hidden states. While this offline objective doesn't directly optimize for LM quality, the resulting weights encode the structure of hidden-state distributions and provide a potentially better starting point for E2E fine-tuning.
515
+
516
+ Tasks 6a/6b test this hypothesis: does initializing E2E training from reconstruction-optimized weights (instead of near-identity) lead to faster convergence or better final quality?
517
+
518
+ **Architecture:** Identical to Tasks 5a/5b — same `Compressor`, `Decompressor`, `StaleDecompressor` classes, same training objective (cross-entropy), same hyperparameters. The only difference is the initial weight values.
519
+
520
+ **Two modes:**
521
+
522
+ | Mode | Task | Init from | Stale signal |
523
+ |---|---|---|---|
524
+ | No stale (6a) | `--stale-mode none --init-weights-dir results/03b_perlayer_compressor` | Task 3b (per-layer offline) | None |
525
+ | Uncompressed stale (6b) | `--stale-mode uncompressed --init-weights-dir results/04b_stale_uncompressed` | Task 4b (stale offline) | Raw ref layer input |
526
+
527
+ **Weight compatibility:** Tasks 3b/4b save weights keyed by HF layer names (`model.layers.N.mlp`) with `compressor` and `decompressor` sub-keys. The `MegatronCompressorManager.load_weights()` expects the same format (it converts Megatron names to HF names via `_megatron_to_hf_layer_name()`). The offline and E2E architectures use identical module classes, so `load_state_dict()` works directly.
528
+
529
+ **Parameter count:** Same as Tasks 5a/5b (identical architecture).
530
+
531
+ **Training hyperparameters:** Same as Tasks 5a/5b (same LR, warmup, epochs, etc.).
532
+
533
+ **Implementation:** Added `--init-weights-dir` argument to `src/megatron_e2e/train.py`. Auto-detects weight file naming pattern. Bash wrapper: `scripts/06_megatron_e2e_pretrained.sh`. Output dirs: `results/06a_megatron_e2e_pretrained_perlayer/` (6a), `results/06b_megatron_e2e_pretrained_stale/` (6b).
534
+
535
+ ### 5.9 Split-Mode E2E Training (Tasks 7a/7b)
536
+
537
+ **Motivation:** Tasks 5/6 use forward pre-hooks that compress→decompress the MoE input — both the router AND experts see the decompressed hidden state. This is a conservative lower bound on quality. In real expert parallelism, the router runs on the source GPU with the **original** hidden state (before compression), and only experts on the destination GPU see the decompressed version. Task 7 trains the compressor under this more realistic "split mode" to see whether the training signal improves when the router is not degraded by compression artifacts.
538
+
539
+ **Approach — Two-Level Pre-Hooks:**
540
+
541
+ Instead of monkey-patching MoE forward methods, two pre-hooks are registered per MoE layer:
542
+
543
+ 1. **MoE pre-hook:** Saves the original input, then returns the compress→decompress result. The MoE module's `forward()` receives the decompressed tensor as its input.
544
+ 2. **Router pre-hook:** Registered on the router/gate submodule. When the MoE's `forward()` calls `self.gate(hidden_states)`, this hook intercepts and swaps the input back to the saved original.
545
+
546
+ This works because:
547
+ - The MoE pre-hook changes what `forward()` receives (decompressed), so experts get decompressed data.
548
+ - The router pre-hook only affects the `gate` submodule's input, restoring the original.
549
+ - PyTorch hook execution order: MoE pre-hook runs first (on the outer module), then when `forward()` calls `self.gate(...)` internally, the gate pre-hook runs and swaps the argument.
550
+
551
+ **Two modes:**
552
+
553
+ | Mode | Task | Init from | Stale signal | Router input |
554
+ |---|---|---|---|---|
555
+ | No stale (7a) | `--stale-mode none --router-mode uncompressed --init-weights-dir results/03b_perlayer_compressor` | Task 3b | None | Original |
556
+ | Uncompressed stale (7b) | `--stale-mode uncompressed --router-mode uncompressed --init-weights-dir results/04b_stale_uncompressed` | Task 4b | Raw ref input | Original |
557
+
558
+ **Architecture:** Identical to Tasks 6a/6b — same classes, same init weights, same hyperparameters. The only difference is that `router_mode="uncompressed"` activates the two-level hook pattern during training and evaluation.
559
+
560
+ **Implementation:** Added `--router-mode` argument to `src/megatron_e2e/train.py` and `src/run_e2e_compressor.py`. Split-mode hooks added to `MegatronCompressorManager` (Megatron training) and `evaluate_perplexity_with_perlayer_compression`/`evaluate_perplexity_with_stale_compression` (HF PPL evaluation). Bash wrapper: `scripts/07_megatron_e2e_split.sh`. Output dirs: `results/07a_megatron_e2e_split_perlayer/` (7a), `results/07b_megatron_e2e_split_stale/` (7b).
561
+
562
+ ---
563
+
564
+ ## 6. Results
565
+
566
+ ### 6.1 Summary Table — All Methods
567
+
568
+ **Model:** Qwen3-30B-A3B-Instruct-2507 (full BF16)
569
+ **Dataset:** allenai/Dolci-Instruct-SFT
570
+
571
+ | Method | Ratio | MSE | CosSim | PPL | PPL Delta | HF Strict | HF Flex |
572
+ |---|---|---|---|---|---|---|---|
573
+ | Baseline (Tasks 2–4) | — | — | — | 3.89 | — | 44.12% | 82.79% |
574
+ | Baseline (5c / Megatron) | — | — | — | 3.94 | — | 44.12% | 82.79% |
575
+ | Quant INT8 | 2.0x | — | — | 3.90 | +0.01 | 48.90% | 82.26% |
576
+ | Quant INT4 | 4.0x | — | — | 4.51 | +0.62 | 56.41% | 68.54% |
577
+ | Quant INT2 | 8.0x | — | — | 1532.59 | +1528.70 | 0.00% | 0.00% |
578
+ | Neural (per-layer) | 2x | 0.0535 | 0.922 | 21.07 | +17.18 | 0.00% | 1.52% |
579
+ | Neural (per-layer) | 4x | 0.1073 | 0.835 | 425.75 | +421.87 | 0.00% | 0.00% |
580
+ | Neural (per-layer) | 8x | 0.1523 | 0.755 | 7949.78 | +7945.89 | 0.00% | 0.00% |
581
+ | Neural (per-layer) | 16x | 0.1893 | 0.683 | 52440.05 | +52436.16 | 0.00% | 0.00% |
582
+ | Stale-cond. (compressed) | 2x | 0.0379 | 0.947 | 6.13 | +2.24 | 3.41% | 62.55% |
583
+ | Stale-cond. (compressed) | 4x | 0.0876 | 0.869 | 31.64 | +27.75 | 0.61% | 1.52% |
584
+ | Stale-cond. (compressed) | 8x | 0.1330 | 0.791 | 2982.23 | +2978.34 | 0.00% | 0.00% |
585
+ | Stale-cond. (compressed) | 16x | 0.1720 | 0.717 | 17486.21 | +17482.32 | 0.00% | 0.00% |
586
+ | Stale-cond. (uncompressed) | 2x | 0.0346 | 0.952 | 6.24 | +2.36 | 2.81% | 67.10% |
587
+ | Stale-cond. (uncompressed) | 4x | 0.0690 | 0.900 | 16.11 | +12.22 | 0.99% | 6.14% |
588
+ | Stale-cond. (uncompressed) | 8x | 0.0966 | 0.855 | 423.68 | +419.79 | 0.00% | 0.00% |
589
+ | Stale-cond. (uncompressed) | 16x | 0.1173 | 0.819 | 3740.41 | +3736.53 | 0.00% | 0.00% |
590
+ | Megatron E2E per-layer (5a) | 2x | — | — | 2.77 | -1.17 | 61.33% | 61.64% |
591
+ | Megatron E2E per-layer (5a) | 4x | — | — | 4.28 | +0.35 | 20.70% | 21.30% |
592
+ | Megatron E2E per-layer (5a) | 8x | — | — | 7.49 | +3.55 | 1.82% | 2.12% |
593
+ | Megatron E2E per-layer (5a) | 16x | — | — | 11.26 | +7.33 | 0.91% | 2.73% |
594
+ | Megatron E2E stale (5b) | 2x | — | — | 2.71 | -1.23 | 60.27% | 60.65% |
595
+ | Megatron E2E stale (5b) | 4x | — | — | 3.61 | -0.33 | 31.54% | 32.37% |
596
+ | Megatron E2E stale (5b) | 8x | — | — | 4.98 | +1.04 | 4.93% | 5.00% |
597
+ | Megatron E2E stale (5b) | 16x | — | — | 6.34 | +2.41 | 2.12% | 2.27% |
598
+ | Megatron E2E pretrained per-layer (6a) | 2x | — | — | 2.41 | -1.53 | 79.98% | 80.06% |
599
+ | Megatron E2E pretrained per-layer (6a) | 4x | — | — | 3.18 | -0.76 | 55.04% | 55.19% |
600
+ | Megatron E2E pretrained per-layer (6a) | 8x | — | — | 4.52 | +0.58 | 16.98% | 16.98% |
601
+ | Megatron E2E pretrained per-layer (6a) | 16x | — | — | 7.34 | +3.40 | 2.27% | 2.27% |
602
+ | Megatron E2E pretrained stale (6b) | 2x | — | — | 2.25 | -1.69 | 82.49% | 82.64% |
603
+ | Megatron E2E pretrained stale (6b) | 4x | — | — | 2.57 | -1.37 | 64.37% | 64.52% |
604
+ | Megatron E2E pretrained stale (6b) | 8x | — | — | 3.04 | -0.90 | 45.79% | 45.94% |
605
+ | Megatron E2E pretrained stale (6b) | 16x | — | — | 3.47 | -0.47 | 25.85% | 25.85% |
606
+ | Split-mode E2E per-layer (7a) | 2x | — | — | 2.58 | -1.36 | 79.91% | 79.98% |
607
+ | Split-mode E2E per-layer (7a) | 4x | — | — | 3.72 | -0.22 | 42.08% | 42.15% |
608
+ | Split-mode E2E per-layer (7a) | 8x | — | — | 6.43 | +2.49 | 4.93% | 5.46% |
609
+ | Split-mode E2E per-layer (7a) | 16x | — | — | 908.20 | +904.26 | 0.00% | 0.53% |
610
+ | Split-mode E2E stale (7b) | 2x | — | — | 2.34 | -1.60 | 80.67% | 80.67% |
611
+ | Split-mode E2E stale (7b) | 4x | — | — | 2.80 | -1.14 | 65.81% | 65.96% |
612
+ | Split-mode E2E stale (7b) | 8x | — | — | 3.37 | -0.57 | 35.63% | 35.63% |
613
+ | Split-mode E2E stale (7b) | 16x | — | — | 4.28 | +0.34 | 16.53% | 16.68% |
614
+
615
+ Note: Tasks 2–4 and 5c baselines differ in PPL (3.89 vs 3.94) due to different evaluation
616
+ code paths (single-GPU HF vs Megatron pipeline). PPL deltas for offline methods use 3.89;
617
+ E2E methods use 3.94. HF Strict/Flex: GSM8K evaluated via HF backend (lm-eval-harness,
618
+ router-compressed mode). For Tasks 7a/7b, HF Strict/Flex is compressed-router only.
619
+ Uncompressed-router results for Tasks 7a/7b are in a dedicated table below Section 6.4.
620
+ GSM8K scores are identical for both baselines because GSM8K evaluation uses the same raw HF
621
+ model. GSM8K uses Megatron-trained weights for E2E methods. "Strict" requires exact
622
+ `#### <number>` format; "flexible" extracts the number from anywhere in the output.
623
+ HF-trained E2E weights (Tasks 5a/5b) were not available.
624
+
625
+ ### 6.2 Key Findings
626
+
627
+ 1. **E2E training is transformative** — E2E methods achieve PPL *below* baseline (3.94) at 2x. E2E stale stays below baseline at 4x (PPL=3.61).
628
+ 2. **E2E stale at 16x is moderate** — PPL=6.34 (+2.41), 61% above baseline, with GSM8K strict-match at 2.12%.
629
+ 3. **E2E dramatically outperforms offline** — Same architecture, same params: offline per-layer 4x PPL=425.75 vs E2E 4x PPL=4.28 (99x better). At 16x: 52440 vs 11.26 (4658x better).
630
+ 4. **Stale conditioning matters more at high compression** — At 2x the gap is small (E2E stale 2.71 vs E2E per-layer 2.77), but at 16x it's 1.8x (6.34 vs 11.26).
631
+ 5. **INT8 quantization is nearly lossless** — PPL 3.90 vs baseline 3.89 at 2x (+0.01), with GSM8K preserved (48.90% strict, 82.26% flexible).
632
+ 6. **INT4 quantization is acceptable** — PPL 4.51 at ~4x (+0.62 delta). GSM8K strict-match actually improves to 56.41%.
633
+ 7. **INT2 is catastrophic** — PPL 1533 at ~8x, completely unusable.
634
+ 8. **Offline methods degrade rapidly** — Per-layer neural: PPL=21 at 2x, PPL=425 at 4x, PPL=7950 at 8x. Stale-conditioning (uncompressed) helps at 2x (PPL=6.24) but collapses at 8x (PPL=424).
635
+ 9. **Below-baseline PPL** suggests E2E compressors act as regularizers, filtering noise from hidden states while preserving task-relevant information. Confirmed by GSM8K: E2E 2x scores 61.33% vs baseline 44.12%.
636
+ 10. **Downstream tasks are more sensitive than PPL** — Offline stale_uncomp_2x has PPL=6.24 (+2.36) but GSM8K drops from 44% to 3% strict-match. E2E methods maintain both PPL and GSM8K. See Section 6.4.
637
+ 11. **Offline compression destroys output format but partially preserves reasoning** — stale_uncomp_2x: 2.81% strict but 67.10% flexible-extract. E2E methods show no such gap (~0.3 pp).
638
+ 12. **Pretrained init (Task 6) dramatically improves E2E training** — Initializing from offline-trained weights (Tasks 3b/4b) instead of near-identity gives 13–45% PPL improvement and massive GSM8K gains. 6b at 2x achieves PPL=2.25 and 82.5% GSM8K strict-match (vs 5b: PPL=2.71, 60.3%). Even at 16x, 6b (PPL=3.47, GSM8K 25.9%) stays below baseline PPL (3.89) and retains meaningful downstream accuracy.
639
+ 13. **Pretrained init benefits grow with compression ratio** — For stale-conditioned (6b vs 5b): PPL improvement goes from 17% at 2x to 45% at 16x; GSM8K goes from +22 pp at 2x to +24 pp at 16x. The offline-trained weights provide a much better starting point for E2E optimization, especially at high compression where near-identity init struggles.
640
+ 14. **Split-mode training (Task 7) matches deployment reality** — Training with split-mode (router sees original, experts see decompressed) then evaluating in the same mode yields the best uncompressed-router results. 7b uncompressed at 2x achieves 83.3% GSM8K strict-match — the best result across all methods and modes.
641
+ 15. **7b uncompressed stays below baseline PPL at ALL ratios** — Even at 16x compression, 7b uncompressed PPL=3.27 remains below the no-compression baseline (3.89). This is the only method to maintain below-baseline PPL at every compression ratio, demonstrating that stale-conditioned split-mode E2E compressors can be simultaneously lossy (16x compression) and beneficial (regularization effect).
642
+ 16. **Split-mode training trades compressed-eval quality for uncompressed-eval quality** — 7a/7b compressed-eval PPL is worse than 6a/6b (e.g., 7a 16x compressed: 908 vs 6a: 8.49) because the model was not trained to have the router see decompressed data. But 7a/7b uncompressed-eval is better (7a 16x uncompressed: 6.64 vs 6a compressed: 8.49). This confirms the training mode should match the deployment mode.
643
+ 17. **Catastrophic collapse at extreme compression without stale** — 7a 16x compressed PPL=908 (vs 7a 16x uncompressed=6.64), showing that when per-layer compression is too lossy, correct routing (from original hidden states) becomes critical. Stale conditioning (7b) avoids this entirely: 7b 16x compressed=4.28, uncompressed=3.27.
644
+
645
+ ### 6.3 HF vs Megatron Comparison
646
+
647
+ **Note:** HF E2E results in this section are from an earlier training run. The HF E2E
648
+ weight files are no longer available in the current `results/05a_e2e_perlayer/` and
649
+ `results/05b_e2e_stale/` directories (only logs remain). The Megatron results are
650
+ from the current run and match the JSON files. The comparison below is preserved for
651
+ historical reference but the HF numbers cannot be independently verified from current data.
652
+
653
+ Both implementations use the same compressor architecture (Compressor + Decompressor / StaleDecompressor), the same model (Qwen3-30B-A3B-Instruct-2507), and the same training data (Dolci-Instruct-SFT). The key differences are in the distributed training strategy and model parallelism framework.
654
+
655
+ **Implementation differences:**
656
+
657
+ | Aspect | HuggingFace | Megatron |
658
+ |---|---|---|
659
+ | Framework | HF Transformers + `device_map="auto"` | Megatron-Core + AutoBridge |
660
+ | Parallelism | Naive layer sharding (sequential) | EP=4, TP=1, PP=1, DP=4 |
661
+ | GPU utilization | 1 GPU active at a time | All 4 GPUs active (DP) |
662
+ | Data parallelism | None (single data stream) | DP=4 (each rank sees 1/4 of data per step) |
663
+ | Optimizer | AdamW (single replica) | AdamW (replicated, gradients all-reduced) |
664
+ | CUDA | 12.6 | 12.9 |
665
+
666
+ **Task 5a — E2E per-layer (no stale):**
667
+
668
+ | Ratio | HF PPL | Megatron PPL | Gap (Meg−HF) |
669
+ |---|---|---|---|
670
+ | 2x | **2.645** (−1.58) | 2.682 (−1.54) | +0.04 |
671
+ | 4x | **3.687** (−0.54) | 4.410 (+0.19) | +0.72 |
672
+ | 8x | **6.371** (+2.15) | 8.182 (+3.96) | +1.81 |
673
+ | 16x | **9.157** (+4.93) | 11.670 (+7.44) | +2.51 |
674
+
675
+ **Task 5b — E2E stale-conditioned (uncompressed stale):**
676
+
677
+ | Ratio | HF PPL | Megatron PPL | Gap (Meg−HF) |
678
+ |---|---|---|---|
679
+ | 2x | 2.570 (−1.65) | **2.568** (−1.66) | −0.002 |
680
+ | 4x | **3.102** (−1.12) | 3.420 (−0.80) | +0.32 |
681
+ | 8x | **4.015** (−0.21) | 4.743 (+0.52) | +0.73 |
682
+ | 16x | **4.550** (+0.32) | 5.232 (+1.01) | +0.68 |
683
+
684
+ **Training losses (train / val):**
685
+
686
+ | Config | HF 5a | Megatron 5a | HF 5b | Megatron 5b |
687
+ |---|---|---|---|---|
688
+ | 2x | 1.215 / 1.093 | 1.258 / 1.109 | 1.193 / 1.070 | 1.210 / 1.068 |
689
+ | 4x | 1.786 / 1.447 | 2.103 / 1.627 | 1.579 / 1.286 | 1.784 / 1.375 |
690
+ | 8x | 2.412 / 2.004 | 2.776 / 2.242 | 1.921 / 1.555 | 2.206 / 1.724 |
691
+ | 16x | 2.768 / 2.326 | 3.180 / 2.567 | 2.069 / 1.686 | 2.344 / 1.823 |
692
+
693
+ **Analysis:**
694
+
695
+ 1. **At 2x, both implementations converge to the same quality.** The gap is negligible (0.04 for 5a, −0.002 for 5b). Near-identity initialization gives a strong starting point, and 2x compression is easy enough that both optimizers find similar solutions.
696
+
697
+ 2. **Megatron's gap grows at higher compression ratios for 5a** (no stale). At 4x the gap is +0.72, at 16x it's +2.51. The likely cause is that Megatron with DP=4 provides each rank with 1/4 of the data per step — effectively a noisier gradient estimate. HF's single-replica training sees the full data stream, leading to a slightly better optimizer trajectory for harder problems (higher compression).
698
+
699
+ 3. **Stale conditioning dramatically narrows the Megatron-HF gap.** Adding stale conditioning reduces the gap by 50–73% at all ratios:
700
+ - 4x: +0.72 → +0.32 (56% reduction)
701
+ - 8x: +1.81 → +0.73 (60% reduction)
702
+ - 16x: +2.51 → +0.68 (73% reduction)
703
+ The stale signal acts as an anchor that partially corrects for the noisier optimization — it provides a strong prior about the expected hidden state, reducing the difficulty of the decompression task.
704
+
705
+ 4. **Both Megatron variants produce usable compressors.** Megatron 5b at 4x (PPL=3.42) is still 19% below baseline, and even at 16x (PPL=5.23) the degradation is only +24%. For production deployment where Megatron's scalability is needed, these results are practical.
706
+
707
+ 5. **Recommendation:** Use Megatron with stale conditioning (5b mode) for production. At 2–4x compression, results match HF quality. At 8–16x, there is a modest quality gap, but Megatron's multi-node scalability and proper expert parallelism make it the right choice for large-scale deployment.
708
+
709
+ ### 6.4 Downstream Task Evaluation (GSM8K)
710
+
711
+ **Benchmark:** GSM8K chain-of-thought (gsm8k_cot), 8-shot, 1319 test examples.
712
+ Two metrics: **strict-match** (exact `#### <number>` format) and **flexible-extract**
713
+ (number extracted from anywhere in the output via regex).
714
+ Two router modes: **compressed** (router AND experts see decompressed hidden states)
715
+ and **uncompressed** (router sees original, experts see decompressed — more realistic EP
716
+ simulation). PPL, MSE, CosSim from HF-based evaluation (`model_utils.py`).
717
+ HF Strict/Flex from HF backend (lm-eval-harness, router-compressed mode).
718
+ vLLM columns from vLLM backend (`run_all_downstream.py`, both router modes).
719
+ For Tasks 7a/7b, vLLM Uncomp. columns show HF backend uncompressed-router results
720
+ (confirmed identical via both `run_all_downstream.py` and `run_e2e_compressor.py
721
+ --router-mode uncompressed`).
722
+
723
+ | Method | Ratio | MSE | CosSim | PPL | PPL Δ | HF Strict | HF Flex | vLLM Comp. Strict | vLLM Comp. Flex | vLLM Uncomp. Strict | vLLM Uncomp. Flex |
724
+ |---|---|---|---|---|---|---|---|---|---|---|---|
725
+ | Baseline | — | — | — | 3.89 | — | 44.1% | 82.8% | 43.3% | 82.9% | — | — |
726
+ | Quant INT8 | 2x | — | — | 3.90 | +0.01 | 48.9% | 82.3% | 43.7% | 82.2% | — | — |
727
+ | Quant INT4 | 4x | — | — | 4.51 | +0.62 | 56.4% | 68.5% | 46.8% | 65.4% | — | — |
728
+ | Quant INT2 | 8x | — | — | 1532.59 | +1528.70 | 0.0% | 0.0% | 0.0% | 0.0% | — | — |
729
+ | Neural (per-layer) | 2x | 0.0535 | 0.922 | 21.07 | +17.18 | 0.0% | 1.5% | 0.0% | 1.2% | 22.7% | 42.6% |
730
+ | Neural (per-layer) | 4x | 0.1073 | 0.835 | 425.75 | +421.87 | 0.0% | 0.0% | 0.0% | 0.4% | 1.0% | 2.4% |
731
+ | Neural (per-layer) | 8x | 0.1523 | 0.755 | 7949.78 | +7945.89 | 0.0% | 0.0% | 0.0% | 0.0% | 2.0% | 1.9% |
732
+ | Neural (per-layer) | 16x | 0.1893 | 0.683 | 52440.05 | +52436.16 | 0.0% | 0.0% | 0.0% | 0.0% | 1.5% | 1.5% |
733
+ | Stale-cond. (compressed) | 2x | 0.0379 | 0.947 | 6.13 | +2.24 | 3.4% | 62.6% | 0.2% | 0.8% | 34.1% | 69.7% |
734
+ | Stale-cond. (compressed) | 4x | 0.0876 | 0.869 | 31.64 | +27.75 | 0.6% | 1.5% | 0.0% | 0.6% | 2.7% | 4.9% |
735
+ | Stale-cond. (compressed) | 8x | 0.1330 | 0.791 | 2982.23 | +2978.34 | 0.0% | 0.0% | 0.0% | 0.0% | 1.3% | 1.8% |
736
+ | Stale-cond. (compressed) | 16x | 0.1720 | 0.717 | 17486.21 | +17482.32 | 0.0% | 0.0% | 0.0% | 0.0% | 1.8% | 2.0% |
737
+ | Stale-cond. (uncompressed) | 2x | 0.0346 | 0.952 | 6.24 | +2.36 | 2.8% | 67.1% | 0.2% | 1.1% | 30.7% | 72.6% |
738
+ | Stale-cond. (uncompressed) | 4x | 0.0690 | 0.900 | 16.11 | +12.22 | 1.0% | 6.1% | 0.0% | 0.6% | 6.1% | 9.3% |
739
+ | Stale-cond. (uncompressed) | 8x | 0.0966 | 0.855 | 423.68 | +419.79 | 0.0% | 0.0% | 0.0% | 0.0% | 1.2% | 2.5% |
740
+ | Stale-cond. (uncompressed) | 16x | 0.1173 | 0.819 | 3740.41 | +3736.53 | 0.0% | 0.0% | 0.0% | 0.0% | 1.4% | 2.0% |
741
+ | E2E per-layer (5a) | 2x | — | — | 2.77 | −1.17 | 61.3% | 61.6% | 61.5% | 61.6% | 52.4% | 59.6% |
742
+ | E2E per-layer (5a) | 4x | — | — | 4.28 | +0.35 | 20.7% | 21.3% | 21.2% | 22.4% | 11.0% | 12.9% |
743
+ | E2E per-layer (5a) | 8x | — | — | 7.49 | +3.55 | 1.8% | 2.1% | 0.0% | 0.0% | 0.0% | 0.0% |
744
+ | E2E per-layer (5a) | 16x | — | — | 11.26 | +7.33 | 0.9% | 2.7% | 0.0% | 0.0% | 0.0% | 0.1% |
745
+ | E2E stale (5b) | 2x | — | — | 2.71 | −1.23 | 60.3% | 60.7% | 61.3% | 61.6% | 53.2% | 61.2% |
746
+ | E2E stale (5b) | 4x | — | — | 3.61 | −0.33 | 31.5% | 32.4% | 33.0% | 33.2% | 18.6% | 22.1% |
747
+ | E2E stale (5b) | 8x | — | — | 4.98 | +1.04 | 4.9% | 5.0% | 3.4% | 4.3% | 0.2% | 2.4% |
748
+ | E2E stale (5b) | 16x | — | — | 6.34 | +2.41 | 2.1% | 2.3% | 0.0% | 0.2% | 0.0% | 0.1% |
749
+ | E2E pretrained per-layer (6a) | 2x | — | — | 2.41 | −1.53 | 80.0% | 80.1% | 80.1% | 80.0% | 80.6% | 80.8% |
750
+ | E2E pretrained per-layer (6a) | 4x | — | — | 3.18 | −0.76 | 55.0% | 55.2% | 52.8% | 52.9% | 43.3% | 43.9% |
751
+ | E2E pretrained per-layer (6a) | 8x | — | — | 4.52 | +0.58 | 17.0% | 17.0% | 13.5% | 14.0% | 6.7% | 7.6% |
752
+ | E2E pretrained per-layer (6a) | 16x | — | — | 7.34 | +3.40 | 2.3% | 2.3% | 0.3% | 1.1% | 1.1% | 2.1% |
753
+ | E2E pretrained stale (6b) | 2x | — | — | 2.25 | −1.69 | 82.5% | 82.6% | 82.0% | 82.3% | 83.9% | 84.0% |
754
+ | E2E pretrained stale (6b) | 4x | — | — | 2.57 | −1.37 | 64.4% | 64.5% | 71.0% | 71.1% | 68.8% | 68.9% |
755
+ | E2E pretrained stale (6b) | 8x | — | — | 3.04 | −0.90 | 45.8% | 45.9% | 37.6% | 37.6% | 24.3% | 24.3% |
756
+ | E2E pretrained stale (6b) | 16x | — | — | 3.47 | −0.47 | 25.9% | 25.9% | 18.7% | 18.7% | 9.0% | 9.6% |
757
+ | Split E2E per-layer (7a) | 2x | — | — | 2.58 | −1.31 | 79.9% | 80.0% | — | — | 79.5% | 79.7% |
758
+ | Split E2E per-layer (7a) | 4x | — | — | 3.72 | −0.17 | 42.1% | 42.2% | — | — | 51.6% | 51.8% |
759
+ | Split E2E per-layer (7a) | 8x | — | — | 6.43 | +2.54 | 4.9% | 5.5% | — | — | 18.5% | 18.7% |
760
+ | Split E2E per-layer (7a) | 16x | — | — | 908.20 | +904.31 | 0.0% | 0.5% | — | — | 2.0% | 2.5% |
761
+ | Split E2E stale (7b) | 2x | — | — | 2.34 | −1.55 | 80.7% | 80.7% | — | — | 83.3% | 83.4% |
762
+ | Split E2E stale (7b) | 4x | — | — | 2.80 | −1.09 | 65.8% | 66.0% | — | — | 70.7% | 70.7% |
763
+ | Split E2E stale (7b) | 8x | — | — | 3.37 | −0.51 | 35.6% | 35.6% | — | — | 47.2% | 47.2% |
764
+ | Split E2E stale (7b) | 16x | — | — | 4.28 | +0.39 | 16.5% | 16.7% | — | — | 27.1% | 27.1% |
765
+
766
+ Notes: HF = HF backend (router-compressed mode). vLLM Comp. = vLLM backend, router-compressed
767
+ (router+experts see decompressed). vLLM Uncomp. = vLLM backend, router-uncompressed (router sees
768
+ original, experts see decompressed — split forward). For Tasks 7a/7b, HF Strict/Flex = HF backend
769
+ with compressed router; vLLM Uncomp. = HF backend with uncompressed router (confirmed identical
770
+ results from both `run_all_downstream.py` and `run_e2e_compressor.py --router-mode uncompressed`).
771
+ Baseline and quantization have no split mode. PPL baseline: 3.89 (offline) / 3.94 (E2E). GSM8K
772
+ uses Megatron-trained weights for E2E methods. Task 7 PPL column shows compressed-router PPL.
773
+ Uncompressed-router results (confirmed identical via both original eval code path and
774
+ `run_e2e_compressor.py --router-mode uncompressed`):
775
+
776
+ | Ratio | 7a PPL | 7b PPL | Baseline PPL | 7a Strict | 7a Flex | 7b Strict | 7b Flex |
777
+ |-------|--------|--------|--------------|-----------|---------|-----------|---------|
778
+ | 2x | 2.38 | 2.23 | 3.89 | 79.5% | 79.7% | 83.3% | 83.4% |
779
+ | 4x | 3.08 | 2.53 | 3.89 | 51.6% | 51.8% | 70.7% | 70.7% |
780
+ | 8x | 4.18 | 2.89 | 3.89 | 18.5% | 18.7% | 47.2% | 47.2% |
781
+ | 16x | 6.64 | 3.27 | 3.89 | 2.0% | 2.5% | 27.1% | 27.1% |
782
+
783
+ **Key findings:**
784
+
785
+ 1. **E2E compression improves GSM8K over baseline.** Baseline strict-match is 44.12%.
786
+ E2E per-layer 2x achieves 61.33% (+17.2 pp) and E2E stale 2x achieves 60.27%
787
+ (+16.2 pp). This mirrors the below-baseline PPL effect — E2E compressors act as
788
+ regularizers that improve both perplexity and downstream task performance.
789
+
790
+ 2. **INT8 and INT4 quantization also improve strict-match.** INT8: 48.90% (+4.8 pp),
791
+ INT4: 56.41% (+12.3 pp). The flexible-extract gap is smaller (INT8: 82.26% vs
792
+ baseline 82.79%), suggesting quantization noise may regularize the strict output
793
+ format without hurting reasoning.
794
+
795
+ 3. **Offline methods catastrophically fail on generation tasks.** Per-layer neural
796
+ compressors score 0% strict-match at all ratios (even 2x, which has PPL=21.07).
797
+ Stale-conditioned 2x scores only 2.81% strict / 67.10% flexible. The flexible-extract
798
+ score reveals that the model still produces correct numerical answers but the output
799
+ format is destroyed — compression disrupts the learned generation patterns.
800
+
801
+ 4. **The strict-vs-flexible gap reveals a format disruption effect.** Offline methods
802
+ show huge gaps: stale_uncomp_2x has 2.81% strict but 67.10% flexible (64.3 pp gap).
803
+ E2E methods show almost no gap: e2e_2x has 61.33% strict vs 61.64% flexible (0.3 pp).
804
+ End-to-end training preserves both the model's reasoning ability AND its output
805
+ formatting, while offline compression preserves some reasoning but destroys formatting.
806
+
807
+ 5. **GSM8K is more sensitive than PPL to compression quality.** Stale_uncomp_2x has
808
+ PPL=6.24 (only +2.36 above baseline) yet scores 2.81% on GSM8K strict-match (vs
809
+ 44.12% baseline). E2E per-layer 4x has PPL=4.28 (only +0.35 above baseline) yet
810
+ drops to 20.70% GSM8K. Generation tasks amplify small distributional shifts that
811
+ PPL barely registers.
812
+
813
+ 6. **Stale conditioning matters for downstream tasks.** At 4x: E2E stale gets 31.54%
814
+ vs E2E per-layer 20.70% (+10.8 pp). At 8x: stale gets 4.93% vs per-layer 1.82%.
815
+ The stale signal helps preserve generation quality, consistent with PPL findings.
816
+
817
+ 7. **Pretrained init (Task 6) yields dramatic GSM8K improvements.** 6b stale at 2x
818
+ achieves 82.49% strict-match — nearly double baseline (44.12%) and +22 pp over 5b
819
+ (60.27%). 6a per-layer at 2x reaches 79.98% (+19 pp over 5a). Even at 8x, 6b retains
820
+ 45.79% (exceeding baseline) while 5b collapses to 4.93%.
821
+
822
+ 8. **Pretrained init enables useful compression at 16x.** 6b at 16x achieves 25.85%
823
+ GSM8K strict-match — down from baseline (44.12%) but still practically useful. Compare
824
+ with 5b at 16x (2.12%) or 5a at 16x (0.91%). Offline weights provide the optimizer
825
+ with a much better starting region of parameter space.
826
+
827
+ 9. **Best overall result: 6b at 2–4x compression.** 6b at 2x (PPL=2.25, GSM8K=82.5%)
828
+ and 4x (PPL=2.57, GSM8K=64.4%) both outperform baseline on PPL and at 4x still retain
829
+ strong downstream performance. This suggests stale-conditioned E2E compression with
830
+ pretrained init is a viable approach for reducing MoE communication by 2–4x with
831
+ minimal or even improved model quality.
832
+
833
+ ---
834
+
835
+ ## 7. Design Choices and Trade-offs
836
+
837
+ ### 7.1 Offline Independent Training vs End-to-End
838
+
839
+ **Offline training (Tasks 2–4)** trains compressors on cached hidden states, independently per layer:
840
+
841
+ | Aspect | Offline | End-to-End (Task 5) |
842
+ |---|---|---|
843
+ | Loss | MSE + cosine (reconstruction) | Cross-entropy (next-token prediction) |
844
+ | Optimization scope | Per-layer, independent | Joint, all 48 layers |
845
+ | Gradient flow | None through LLM | Through entire frozen LLM |
846
+ | Stale signal | Pre-computed, frozen | Live, gradients flow through |
847
+ | Model precision | Full BF16 (~60 GB, 1 GPU) | Full BF16 (~60 GB, 4 GPUs) |
848
+ | Training cost | Minutes per layer | Hours for all layers + ratios |
849
+ | Error compounding | Not accounted for | Naturally optimized via global loss |
850
+
851
+ **Offline advantages:**
852
+ - Fast and cheap (minutes per layer on a single GPU)
853
+ - No need to backpropagate through the full LLM
854
+ - Each layer's compressor can be trained in parallel
855
+
856
+ **Offline limitations (addressed by e2e):**
857
+ - Compressors cannot adapt to how their reconstruction errors compound across layers. A small error at layer 0 may shift the hidden state distribution at layer 1, but layer 1's compressor was trained on the *original* layer-1 distribution.
858
+ - No joint optimization means the system cannot learn to allocate more capacity to layers where errors are most harmful.
859
+ - The stale signal used during offline training is the *unperturbed* reference input, but during inference the reference layer itself is compressed, creating a train-inference mismatch.
860
+
861
+ **E2E advantages:**
862
+ - Compressors are optimized for the actual downstream impact of compression on model quality.
863
+ - Joint optimization: the system implicitly learns which layers need higher fidelity.
864
+ - Stale gradients flow: reference layer compressors are optimized for their dual role (own reconstruction + stale side information for downstream layers). The stale signal during training already reflects upstream compression artifacts, eliminating the train-inference mismatch.
865
+
866
+ **E2E limitations:**
867
+ - Requires full-precision model in memory for proper gradient flow (~60 GB across 4 GPUs).
868
+ - Training is slower (full forward + backward through 48 frozen transformer layers per step).
869
+ - More hyperparameter-sensitive (LR, warmup, gradient clipping matter more).
870
+
871
+ ### 7.2 Linear vs Non-linear Compressors
872
+
873
+ All compressors are single-layer linear networks (no activation functions). This was a deliberate choice:
874
+ - Linear compressors are equivalent to learning an optimal projection/reconstruction pair (related to PCA)
875
+ - They are fast to train and apply (single matrix multiply)
876
+ - They establish a clean baseline before trying non-linear architectures
877
+
878
+ ### 7.3 Loss Function
879
+
880
+ The combined `MSE + 0.1 × (1 - cos_sim)` loss was chosen because:
881
+ - MSE alone can be dominated by outlier values (which are common in later layers with kurtosis up to 81K)
882
+ - Cosine similarity preserves the direction of the hidden state vector, which matters more than exact magnitude for downstream attention and expert computations
883
+ - The 0.1 weighting keeps MSE as the primary objective while regularizing directions
884
+
885
+ ### 7.4 Reference Layer Stride
886
+
887
+ The stride of 12 (giving reference layers {0, 12, 24, 36}) was chosen as a balance:
888
+ - More reference layers (smaller stride) → better stale signals but more communication (ref layers use standard compression without stale)
889
+ - Fewer reference layers (larger stride) → stale signals become less correlated with non-ref layers
890
+ - stride=12 gives 4 reference layers covering 48 layers, with each non-ref layer at most 11 layers away from its reference
891
+
892
+ ### 7.5 Training Data Size
893
+
894
+ 100,000 tokens per layer (increased from initial 10,000). Each token produces a 2048-dim vector, so training data per layer is 100K × 2048 = 204.8M values. This is sufficient for learning a linear map with ~4M parameters (2x compression, per-layer).
895
+
896
+ ### 7.6 Model Precision
897
+
898
+ All tasks use the same model in full BF16 precision (no weight quantization). This ensures:
899
+ - Hidden states used for offline training exactly match inference conditions
900
+ - End-to-end training has proper gradient flow through frozen layers
901
+ - All methods share the same baseline perplexity, enabling direct comparison
902
+ - 4-bit NF4 quantization is available via `--load-in-4bit` but is not the default
903
+
904
+ ---
905
+
906
+ ## 8. Implementation Details
907
+
908
+ ### 8.1 Hook-Based Evaluation and Training
909
+
910
+ Four hook modes are used across experiments:
911
+
912
+ | Mode | Hook type | Used in |
913
+ |---|---|---|
914
+ | `evaluate_perplexity_with_compression` | Same compress/decompress for all layers | Shared compressor (Task 3) |
915
+ | `evaluate_perplexity_with_perlayer_compression` | Per-layer compress/decompress dicts | Per-layer compressor (Task 3b) |
916
+ | `evaluate_perplexity_with_stale_compression` | Per-layer + stale cache + ref/non-ref split | Stale-conditioned (Tasks 4a/4b) |
917
+ | `E2ECompressorManager.register_hooks()` | Per-layer, trainable, with/without stale cache | E2E training + eval (Task 5) |
918
+
919
+ The stale evaluation maintains a `stale_cache` dictionary that is populated by reference layer pre-hooks and read by subsequent non-reference layer hooks. This works because PyTorch processes layers sequentially (layer 0 before layer 1, etc.).
920
+
921
+ **Device safety in evaluation hooks:** With `device_map="auto"`, model layers may reside on
922
+ different GPUs. All evaluation hooks in `model_utils.py` (`evaluate_perplexity_with_perlayer_compression`
923
+ and `evaluate_perplexity_with_stale_compression`) explicitly call `.to(x.device)` on
924
+ compressor/decompressor outputs before returning them to the model. This ensures correctness
925
+ when compressor weights and MoE layers are on different devices.
926
+
927
+ **E2E training hooks (Task 5)** differ from evaluation hooks in two ways:
928
+ 1. Compressor/decompressor parameters have `requires_grad=True`, so the autograd graph is maintained through the hooks.
929
+ 2. For stale mode (5b), the cached stale signal is **not detached** — gradients flow through the stale path to earlier layers, enabling true end-to-end optimization.
930
+
931
+ ### 8.2 MoE Layer Detection
932
+
933
+ `find_moe_layers()` in `model_utils.py` detects MoE modules by:
934
+ 1. Checking if the class name contains "Moe", "MoE", or "SparseMoe"
935
+ 2. Checking for `experts` attribute
936
+ 3. Checking for both `gate` and `experts` attributes
937
+
938
+ This is model-agnostic and works for Qwen3, Mixtral, and other MoE architectures.
939
+
940
+ ### 8.3 File Organization
941
+
942
+ **Offline experiments (Tasks 1–4)** follow the same pattern:
943
+ 1. Load cached hidden states from `data/hidden_states/`
944
+ 2. Train compressors on dispatch states
945
+ 3. Evaluate reconstruction metrics (offline, on cached data)
946
+ 4. Load the full model and evaluate perplexity (online, with hooks)
947
+ 5. Save results to `results/{experiment}/`
948
+
949
+ **End-to-end experiments (Task 5)** follow a different pattern:
950
+ 1. Load the full model in BF16 across 4 GPUs
951
+ 2. Load and tokenize training data (Dolci-Instruct-SFT)
952
+ 3. For each compression ratio: create compressor manager, train e2e, save weights
953
+ 4. Evaluate perplexity on Dolci-Instruct-SFT (with hooks, same as offline)
954
+ 5. Save results to `results/05{a,b}_e2e_{perlayer,stale}/`
955
+
956
+ Bash wrappers in `scripts/` handle environment setup, module loading, and argument passing.
957
+
958
+ ### 8.4 Progress Tracking and Logging
959
+
960
+ All long-running loops use `tqdm` progress bars (written to stderr) for real-time progress monitoring with elapsed time and ETA. Key loops instrumented:
961
+
962
+ - **Training loops:** Epoch progress with loss/cosine postfix (all training functions)
963
+ - **Layer loops:** Per-layer training iteration (Tasks 3b, 4a/4b)
964
+ - **Data loading:** Calibration data and tokenization progress
965
+ - **Evaluation:** Perplexity evaluation sequence progress, quantization config iteration
966
+ - **Ratio loops:** Outer compression ratio iteration (all tasks)
967
+
968
+ Each bash script redirects output to two log files in the task's output directory:
969
+
970
+ | File | Contents | Source |
971
+ |---|---|---|
972
+ | `run.log` | Full output (print statements, results, summaries) | stdout |
973
+ | `progress.log` | tqdm progress bars (elapsed time, ETA, loss metrics) | stderr |
974
+
975
+ Monitor progress of a running experiment: `tail -f results/<task>/progress.log`
976
+
977
+ ---
978
+
979
+ ## 9. Reproducibility
980
+
981
+ ### 9.1 Software Environment
982
+
983
+ - Python 3.11
984
+ - PyTorch (via `pip install torch` with CUDA 12.6)
985
+ - Transformers (HuggingFace)
986
+ - bitsandbytes (optional, for 4-bit model loading)
987
+ - datasets (for allenai/Dolci-Instruct-SFT)
988
+ - matplotlib, numpy
989
+
990
+ ### 9.2 Hardware
991
+
992
+ - NVIDIA H100 80 GB GPUs (8 available)
993
+ - Tasks 1–4: single GPU sufficient (model in full BF16, ~60 GB on one H100 80 GB)
994
+ - Task 5: 4 GPUs per job (model in full BF16, ~60 GB + backprop memory); 05a and 05b run in parallel on GPUs 0-3 and 4-7
995
+ - 500+ GB system RAM (required for loading ~37 GB of hidden states for offline tasks)
996
+ - Compute Canada cluster
997
+
998
+ ### 9.3 Random Seeds and Data Splitting
999
+
1000
+ All experiments use **seed=42** for reproducibility. A deterministic 80/10/10
1001
+ train/val/test split of the Dolci-Instruct-SFT dataset rows is computed via
1002
+ `get_split_indices()` in `model_utils.py`:
1003
+
1004
+ ```python
1005
+ rng = random.Random(42)
1006
+ indices = list(range(dataset_size))
1007
+ rng.shuffle(indices)
1008
+ # 80% train, 10% val, 10% test
1009
+ ```
1010
+
1011
+ **Split consistency across tasks:**
1012
+ - Task 1 hidden state collection: TRAIN split (max_samples=10000)
1013
+ - Tasks 2–4 offline training: uses cached hidden states from Task 1 (TRAIN split)
1014
+ - Tasks 2–4 PPL evaluation: TEST split (max_samples_ppl=50000, response-only)
1015
+ - Task 5 E2E training: TRAIN split (500K sequences HF / 100K Megatron, SFT mode)
1016
+ - Task 5 E2E validation: VAL split sequences (SFT mode)
1017
+ - Task 5 PPL evaluation: TEST split (same as tasks 2–4, response-only)
1018
+
1019
+ **SFT data loading (Task 5 and PPL evaluation):**
1020
+ - Each conversation is tokenized independently (one sample = one sequence)
1021
+ - Labels are -100 for non-assistant tokens, actual token IDs for assistant responses
1022
+ - `_tokenize_sft_sample()` in `model_utils.py` finds assistant token boundaries
1023
+ via incremental prefix tokenization of the chat template
1024
+ - Max sequence length: 2048 (configurable via `--max-length`)
1025
+ - Loss and perplexity are computed on response tokens only
1026
+
1027
+ Additional seed setting in Task 5:
1028
+ - `random.seed(42)`, `np.random.seed(42)`, `torch.manual_seed(42)`,
1029
+ `torch.cuda.manual_seed_all(42)` at start of main()
1030
+ - DataLoader shuffling uses PyTorch's seeded RNG
1031
+
1032
+ ### 9.4 Experiment Tracking (Wandb)
1033
+
1034
+ Both HF and Megatron E2E scripts support Weights & Biases logging:
1035
+
1036
+ - **CLI:** `--wandb` / `--no-wandb`, `--wandb-project <name>`
1037
+ - **Logged metrics:** `train/loss` and `train/lr` per optimizer step,
1038
+ `val/loss` every `--val-interval` steps (default 2500) and at end of epoch,
1039
+ `train/epoch_loss` per epoch
1040
+ - **Projects:** `ecmoe-e2e` (HF), `ecmoe-megatron-e2e` (Megatron)
1041
+ - **Default:** Enabled in bash scripts via `WANDB_FLAG`; disable with
1042
+ `WANDB_FLAG="--no-wandb" bash scripts/05_run_e2e_compressor.sh none`
1043
+ - Megatron: only rank 0 logs to wandb
1044
+ - Megatron `train/loss` and `train/epoch_loss` are DP-averaged (all-reduced across
1045
+ data-parallel ranks) before logging, so wandb values reflect the true global loss
1046
+ - Graceful fallback if wandb is not installed (HAS_WANDB flag)
1047
+
1048
+ ---
1049
+
1050
+ ## 10. Task 8: EP Communication Compression in vLLM
1051
+
1052
+ ### 10.1 Motivation
1053
+
1054
+ Tasks 5–7 evaluate compression quality using PyTorch hooks that compress and decompress
1055
+ on the **same GPU** — simulating the quality impact but not achieving actual communication
1056
+ reduction. In real expert parallelism, the pipeline is:
1057
+
1058
+ 1. Router computes logits from **original** hidden states (attention GPU)
1059
+ 2. **Compressor** runs on attention GPU: `hidden_dim` → `bottleneck_dim`
1060
+ 3. All-to-all dispatch sends only the **compressed** tensor (reduced volume!)
1061
+ 4. **Decompressor** runs on expert GPU: `bottleneck_dim` → `hidden_dim`
1062
+ 5. Experts compute on decompressed states
1063
+
1064
+ Task 8 modifies vLLM's `FusedMoE.forward_impl()` to implement this pipeline,
1065
+ compressing BEFORE dispatch and decompressing AFTER.
1066
+
1067
+ ### 10.2 Implementation
1068
+
1069
+ **Patched vLLM (`scripts/patch_vllm_fused_moe.py`):** Adds ~12 lines to
1070
+ `FusedMoE.forward_impl()` at three locations:
1071
+
1072
+ 1. **Compress before dispatch (EP mode):** `_ecmoe_compress_fn(hidden_states)` →
1073
+ dispatches compressed tensor instead of full hidden_dim.
1074
+ 2. **Decompress after dispatch (EP mode):** After `get_ep_group().dispatch()`,
1075
+ `_ecmoe_decompress_fn(hidden_states_combined)` restores full hidden_dim.
1076
+ 3. **Single-GPU fallback:** When `do_naive_dispatch_combine=False` (TP=1),
1077
+ applies compress→decompress in-place for simulation mode.
1078
+
1079
+ When `_ecmoe_compress_fn` is None (default), behavior is identical to stock vLLM.
1080
+
1081
+ **EP-aware registration (`src/vllm_ep_compression.py`):** Uses `apply_model()`
1082
+ to set compress/decompress functions on each FusedMoE instance:
1083
+
1084
+ - **Per-layer:** `register_ep_perlayer()` — Independent linear compress/decompress per layer.
1085
+ - **Stale-conditioned:** `register_ep_stale()` — Reference layers piggyback stale signal
1086
+ on compressed tensor before dispatch. Non-reference layers dispatch only compressed data.
1087
+
1088
+ ### 10.3 Stale Broadcast via Dispatch Piggybacking
1089
+
1090
+ **Reference layers (0, 12, 24, 36):**
1091
+ - compress_fn: `cat(compressed[B, bottleneck], stale[B, stale_dim])` → dispatch `[B, bottleneck + stale_dim]`
1092
+ - decompress_fn: split → cache stale_part globally → decompress compressed_part
1093
+
1094
+ **Non-reference layers (all others):**
1095
+ - compress_fn: `compressed[B, bottleneck]` only → dispatch `[B, bottleneck]` (maximum compression!)
1096
+ - decompress_fn: retrieve cached stale → `cat(compressed, cached_stale)` → StaleDecomp
1097
+
1098
+ **Correctness:** vLLM's default `all2all_backend=allgather_reducescatter` means after
1099
+ dispatch, every rank has ALL tokens in consistent ordering. Stale cached from reference
1100
+ layers matches token ordering at non-reference layers.
1101
+
1102
+ ### 10.4 Communication Savings
1103
+
1104
+ | Mode | Ref layers (4/48) | Non-ref layers (44/48) | Weighted avg | vs baseline 2048 |
1105
+ |------|-------------------|----------------------|--------------|-------------------|
1106
+ | perlayer 2x | 1024 | 1024 | 1024 | **50% saving** |
1107
+ | perlayer 4x | 512 | 512 | 512 | **75% saving** |
1108
+ | stale(comp) 4x | 1024 | 512 | 555 | **73% saving** |
1109
+ | stale(uncomp) 4x | 2560 | 512 | 683 | **67% saving** |
1110
+ | stale(uncomp) 2x | 3072 | 1024 | 1195 | **42% saving** |
1111
+
1112
+ Stale broadcast cost is amortized over ~11 non-reference layers per reference layer.
1113
+
1114
+ ### 10.5 Evaluation Modes
1115
+
1116
+ - **simulation** (`--mode simulation`): Single-GPU (TP=1), no dispatch/combine.
1117
+ Validates numerical correctness against existing split-mode results.
1118
+ - **ep** (`--mode ep`): Multi-GPU (TP=4, `enable_expert_parallel=True`).
1119
+ Uses actual EP dispatch/combine with compressed tensors.
1120
+
1121
+ Both use Task 7a/7b weights (split-mode E2E trained) from
1122
+ `results/07a_megatron_e2e_split_perlayer/` and `results/07b_megatron_e2e_split_stale/`.
docker-compose.yml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ services:
2
+ flight-search:
3
+ build: .
4
+ ports:
5
+ - "8080:8080"
6
+ restart: unless-stopped
frontend/eslint.config.js ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import tseslint from 'typescript-eslint'
6
+ import { defineConfig, globalIgnores } from 'eslint/config'
7
+
8
+ export default defineConfig([
9
+ globalIgnores(['dist']),
10
+ {
11
+ files: ['**/*.{ts,tsx}'],
12
+ extends: [
13
+ js.configs.recommended,
14
+ tseslint.configs.recommended,
15
+ reactHooks.configs.flat.recommended,
16
+ reactRefresh.configs.vite,
17
+ ],
18
+ languageOptions: {
19
+ ecmaVersion: 2020,
20
+ globals: globals.browser,
21
+ },
22
+ },
23
+ ])
frontend/index.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Flight Search</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.tsx"></script>
12
+ </body>
13
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "frontend",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "tsc -b && vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "react": "^19.2.0",
14
+ "react-dom": "^19.2.0",
15
+ "react-router-dom": "^7.13.1"
16
+ },
17
+ "devDependencies": {
18
+ "@eslint/js": "^9.39.1",
19
+ "@tailwindcss/vite": "^4.2.1",
20
+ "@types/node": "^24.10.1",
21
+ "@types/react": "^19.2.7",
22
+ "@types/react-dom": "^19.2.3",
23
+ "@vitejs/plugin-react": "^5.1.1",
24
+ "eslint": "^9.39.1",
25
+ "eslint-plugin-react-hooks": "^7.0.1",
26
+ "eslint-plugin-react-refresh": "^0.4.24",
27
+ "globals": "^16.5.0",
28
+ "tailwindcss": "^4.2.1",
29
+ "typescript": "~5.9.3",
30
+ "typescript-eslint": "^8.48.0",
31
+ "vite": "^7.3.1"
32
+ }
33
+ }
frontend/public/vite.svg ADDED
frontend/src/App.css ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #root {
2
+ max-width: 1280px;
3
+ margin: 0 auto;
4
+ padding: 2rem;
5
+ text-align: center;
6
+ }
7
+
8
+ .logo {
9
+ height: 6em;
10
+ padding: 1.5em;
11
+ will-change: filter;
12
+ transition: filter 300ms;
13
+ }
14
+ .logo:hover {
15
+ filter: drop-shadow(0 0 2em #646cffaa);
16
+ }
17
+ .logo.react:hover {
18
+ filter: drop-shadow(0 0 2em #61dafbaa);
19
+ }
20
+
21
+ @keyframes logo-spin {
22
+ from {
23
+ transform: rotate(0deg);
24
+ }
25
+ to {
26
+ transform: rotate(360deg);
27
+ }
28
+ }
29
+
30
+ @media (prefers-reduced-motion: no-preference) {
31
+ a:nth-of-type(2) .logo {
32
+ animation: logo-spin infinite 20s linear;
33
+ }
34
+ }
35
+
36
+ .card {
37
+ padding: 2em;
38
+ }
39
+
40
+ .read-the-docs {
41
+ color: #888;
42
+ }
frontend/src/App.tsx ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { BrowserRouter, Route, Routes } from 'react-router-dom';
2
+ import Header from './components/shared/Header';
3
+ import SearchPage from './pages/SearchPage';
4
+ import ResultsPage from './pages/ResultsPage';
5
+
6
+ export default function App() {
7
+ return (
8
+ <BrowserRouter>
9
+ <Header />
10
+ <Routes>
11
+ <Route path="/" element={<SearchPage />} />
12
+ <Route path="/results" element={<ResultsPage />} />
13
+ </Routes>
14
+ </BrowserRouter>
15
+ );
16
+ }
frontend/src/api/client.ts ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { AutocompleteResult, CalendarResponse, SearchRequest, SearchResponse } from './types';
2
+
3
+ const BASE_URL = '/api';
4
+
5
+ async function fetchJson<T>(url: string, init?: RequestInit): Promise<T> {
6
+ const res = await fetch(url, init);
7
+ if (!res.ok) {
8
+ const text = await res.text();
9
+ throw new Error(`API error ${res.status}: ${text}`);
10
+ }
11
+ return res.json();
12
+ }
13
+
14
+ export async function searchAirports(query: string): Promise<AutocompleteResult[]> {
15
+ if (!query || query.length < 1) return [];
16
+ return fetchJson<AutocompleteResult[]>(
17
+ `${BASE_URL}/airports/autocomplete?q=${encodeURIComponent(query)}`
18
+ );
19
+ }
20
+
21
+ export async function searchFlights(req: SearchRequest): Promise<SearchResponse> {
22
+ return fetchJson<SearchResponse>(`${BASE_URL}/search`, {
23
+ method: 'POST',
24
+ headers: { 'Content-Type': 'application/json' },
25
+ body: JSON.stringify(req),
26
+ });
27
+ }
28
+
29
+ export async function getCalendar(
30
+ origin: string,
31
+ destination: string,
32
+ year: number,
33
+ month: number,
34
+ cabinClass: string = 'economy'
35
+ ): Promise<CalendarResponse> {
36
+ return fetchJson<CalendarResponse>(
37
+ `${BASE_URL}/calendar?origin=${origin}&destination=${destination}&year=${year}&month=${month}&cabin_class=${cabinClass}`
38
+ );
39
+ }
frontend/src/api/types.ts ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export type CabinClass = 'economy' | 'premium_economy' | 'business' | 'first';
2
+ export type TripType = 'one_way' | 'round_trip' | 'multi_city';
3
+ export type SortBy = 'best' | 'cheapest' | 'fastest';
4
+
5
+ export interface AutocompleteResult {
6
+ iata: string;
7
+ name: string;
8
+ city_name: string;
9
+ country: string;
10
+ display_name: string;
11
+ hub_score: number;
12
+ }
13
+
14
+ export interface FlightSegment {
15
+ airline_code: string;
16
+ airline_name: string;
17
+ flight_number: string;
18
+ aircraft: string;
19
+ origin: string;
20
+ origin_city: string;
21
+ destination: string;
22
+ destination_city: string;
23
+ departure: string;
24
+ arrival: string;
25
+ duration_minutes: number;
26
+ }
27
+
28
+ export interface FlightOffer {
29
+ id: string;
30
+ segments: FlightSegment[];
31
+ total_duration_minutes: number;
32
+ stops: number;
33
+ price_usd: number;
34
+ cabin_class: CabinClass;
35
+ origin: string;
36
+ destination: string;
37
+ departure: string;
38
+ arrival: string;
39
+ }
40
+
41
+ export interface SearchLeg {
42
+ origin: string;
43
+ destination: string;
44
+ date: string; // YYYY-MM-DD
45
+ }
46
+
47
+ export interface Passengers {
48
+ adults: number;
49
+ children: number;
50
+ infants: number;
51
+ }
52
+
53
+ export interface Filters {
54
+ max_stops?: number | null;
55
+ max_price?: number | null;
56
+ max_duration_minutes?: number | null;
57
+ airlines?: string[] | null;
58
+ departure_time_min?: string | null;
59
+ departure_time_max?: string | null;
60
+ }
61
+
62
+ export interface SearchRequest {
63
+ trip_type: TripType;
64
+ legs: SearchLeg[];
65
+ passengers: Passengers;
66
+ cabin_class: CabinClass;
67
+ filters: Filters;
68
+ sort_by: SortBy;
69
+ }
70
+
71
+ export interface SearchResponse {
72
+ outbound_flights: FlightOffer[];
73
+ return_flights: FlightOffer[];
74
+ search_id: string;
75
+ origin: string;
76
+ destination: string;
77
+ }
78
+
79
+ export interface CalendarDay {
80
+ date: string;
81
+ cheapest_price: number | null;
82
+ }
83
+
84
+ export interface CalendarResponse {
85
+ origin: string;
86
+ destination: string;
87
+ year: number;
88
+ month: number;
89
+ days: CalendarDay[];
90
+ }
frontend/src/assets/react.svg ADDED
frontend/src/components/results/FilterPanel.tsx ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useMemo } from 'react';
2
+ import type { Filters, FlightOffer } from '../../api/types';
3
+
4
interface Props {
  flights: FlightOffer[];
  filters: Filters;
  onChange: (f: Filters) => void;
}

/**
 * Filter sidebar for the results page: stops, max price, departure-time
 * window and airline selection. Fully controlled — every interaction is
 * reported through `onChange` with a fresh Filters object.
 */
export default function FilterPanel({ flights, filters, onChange }: Props) {
  // Distinct airlines appearing in any segment, sorted by display name.
  const airlines = useMemo(() => {
    const byCode = new Map<string, string>();
    for (const offer of flights) {
      for (const seg of offer.segments) {
        byCode.set(seg.airline_code, seg.airline_name);
      }
    }
    return [...byCode.entries()].sort((a, b) => a[1].localeCompare(b[1]));
  }, [flights]);

  // Highest stop count among the results (0 when the list is empty).
  const maxStopsAvailable = useMemo(
    () => flights.reduce((acc, offer) => Math.max(acc, offer.stops), 0),
    [flights],
  );

  const stopLabel = (choice: number | null): string => {
    if (choice === null) return 'Any';
    if (choice === 0) return 'Nonstop only';
    return `Up to ${choice} stop${choice > 1 ? 's' : ''}`;
  };

  // Airline checkboxes encode "all selected" as null. The first deselection
  // materialises the full list minus the clicked code; deselecting the last
  // one, or re-selecting all of them, collapses back to null ("any").
  const toggleAirline = (code: string, currentlySelected: boolean) => {
    let next: string[] | null;
    if (!filters.airlines) {
      next = airlines.filter(([c]) => c !== code).map(([c]) => c);
    } else if (currentlySelected) {
      next = filters.airlines.filter(c => c !== code);
      if (next.length === 0) next = null;
    } else {
      next = [...filters.airlines, code];
      if (next.length === airlines.length) next = null;
    }
    onChange({ ...filters, airlines: next });
  };

  return (
    <div className="space-y-6" data-testid="filter-panel">
      {/* Stops filter */}
      <div>
        <h3 className="mb-2 text-sm font-medium text-gray-900">Stops</h3>
        <div className="space-y-1">
          {[null, 0, 1, 2]
            .filter(choice => choice === null || choice <= maxStopsAvailable)
            .map(choice => (
              <label key={String(choice)} className="flex items-center gap-2 cursor-pointer">
                <input
                  type="radio"
                  name="stops"
                  checked={filters.max_stops === choice}
                  onChange={() => onChange({ ...filters, max_stops: choice })}
                  className="accent-[#1a73e8]"
                  data-testid={`filter-stops-${choice === null ? 'any' : choice}`}
                />
                <span className="text-sm text-gray-700">{stopLabel(choice)}</span>
              </label>
            ))}
        </div>
      </div>

      {/* Price filter */}
      <div>
        <h3 className="mb-2 text-sm font-medium text-gray-900">Max price</h3>
        <div className="flex items-center gap-2">
          <span className="text-sm text-gray-500">$</span>
          <input
            type="number"
            value={filters.max_price ?? ''}
            onChange={event =>
              onChange({
                ...filters,
                max_price: event.target.value ? Number(event.target.value) : null,
              })
            }
            placeholder="Any"
            className="w-24 rounded border border-gray-300 px-2 py-1 text-sm focus:border-[#1a73e8] focus:outline-none"
            min={0}
            data-testid="filter-max-price"
          />
        </div>
      </div>

      {/* Departure time filter */}
      <div>
        <h3 className="mb-2 text-sm font-medium text-gray-900">Departure time</h3>
        <div className="flex items-center gap-2">
          <input
            type="time"
            value={filters.departure_time_min ?? ''}
            onChange={event =>
              onChange({ ...filters, departure_time_min: event.target.value || null })
            }
            className="rounded border border-gray-300 px-2 py-1 text-sm focus:border-[#1a73e8] focus:outline-none"
            data-testid="filter-dep-time-min"
          />
          <span className="text-xs text-gray-500">to</span>
          <input
            type="time"
            value={filters.departure_time_max ?? ''}
            onChange={event =>
              onChange({ ...filters, departure_time_max: event.target.value || null })
            }
            className="rounded border border-gray-300 px-2 py-1 text-sm focus:border-[#1a73e8] focus:outline-none"
            data-testid="filter-dep-time-max"
          />
        </div>
      </div>

      {/* Airlines filter (only shown when there is something to choose) */}
      {airlines.length > 1 && (
        <div>
          <h3 className="mb-2 text-sm font-medium text-gray-900">Airlines</h3>
          <div className="max-h-48 space-y-1 overflow-y-auto">
            {airlines.map(([code, name]) => {
              const selected = !filters.airlines || filters.airlines.includes(code);
              return (
                <label key={code} className="flex items-center gap-2 cursor-pointer">
                  <input
                    type="checkbox"
                    checked={selected}
                    onChange={() => toggleAirline(code, selected)}
                    className="accent-[#1a73e8]"
                    data-testid={`filter-airline-${code}`}
                  />
                  <span className="text-sm text-gray-700">{name} ({code})</span>
                </label>
              );
            })}
          </div>
        </div>
      )}

      {/* Clear all */}
      <button
        onClick={() =>
          onChange({
            max_stops: null, max_price: null, max_duration_minutes: null,
            airlines: null, departure_time_min: null, departure_time_max: null,
          })
        }
        className="text-sm text-[#1a73e8] hover:underline cursor-pointer"
        data-testid="filter-clear"
      >
        Clear all filters
      </button>
    </div>
  );
}
frontend/src/components/results/FlightCard.tsx ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState } from 'react';
2
+ import type { FlightOffer } from '../../api/types';
3
+ import { formatDuration, formatPrice, formatStops, formatTime } from '../../utils/format';
4
+ import FlightSegmentView from './FlightSegment';
5
+
6
interface Props {
  flight: FlightOffer;
}

/**
 * One result row: airline badge, departure/arrival times, duration, stops
 * and price. Connecting flights (stops > 0) expand on click to show each
 * segment and its layover airport.
 */
export default function FlightCard({ flight }: Props) {
  const [expanded, setExpanded] = useState(false);
  const firstSeg = flight.segments[0];

  // Overnight ("+1") indicator: compare the date portion of the raw
  // timestamp strings. Parsing with `new Date()` first would convert both
  // to the browser's local timezone, which can mislabel flights whose
  // timestamps carry an airport-local UTC offset. Assumes the API returns
  // ISO-style strings starting with YYYY-MM-DD — confirm against backend.
  const dayDiff = flight.departure.slice(0, 10) !== flight.arrival.slice(0, 10);

  return (
    <div
      className="rounded-lg border border-gray-200 bg-white hover:shadow-md transition-shadow cursor-pointer"
      onClick={() => flight.stops > 0 && setExpanded(!expanded)}
      data-testid={`flight-card-${flight.id}`}
      role="article"
      aria-label={`Flight from ${flight.origin} to ${flight.destination}, ${formatPrice(flight.price_usd)}`}
    >
      <div className="flex items-center gap-4 p-4">
        {/* Airline badge */}
        <div className="flex h-10 w-10 items-center justify-center rounded-lg bg-gray-100 text-xs font-bold text-gray-600 flex-shrink-0" data-testid="airline-badge">
          {firstSeg.airline_code}
        </div>

        {/* Main info */}
        <div className="flex flex-1 items-center gap-6">
          {/* Times */}
          <div className="flex items-center gap-3 flex-1">
            <div className="text-right">
              <div className="text-lg font-medium" data-testid="departure-time">{formatTime(flight.departure)}</div>
              <div className="text-xs text-gray-500">{flight.origin}</div>
            </div>

            <div className="flex flex-1 flex-col items-center px-2">
              <div className="text-xs text-gray-500">{formatDuration(flight.total_duration_minutes)}</div>
              <div className="relative my-1 h-px w-full bg-gray-300">
                <div className="absolute left-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
                <div className="absolute right-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
                {/* Stop indicators: one dot per intermediate airport */}
                {flight.stops > 0 && flight.segments.slice(0, -1).map((_, i) => (
                  <div
                    key={i}
                    className="absolute top-1/2 h-2 w-2 -translate-y-1/2 rounded-full bg-gray-400"
                    style={{ left: `${((i + 1) / flight.segments.length) * 100}%` }}
                  />
                ))}
              </div>
              <div className="text-xs text-gray-500" data-testid="stops">{formatStops(flight.stops)}</div>
            </div>

            <div>
              <div className="text-lg font-medium" data-testid="arrival-time">
                {formatTime(flight.arrival)}
                {dayDiff && <sup className="ml-0.5 text-xs text-red-500">+1</sup>}
              </div>
              <div className="text-xs text-gray-500">{flight.destination}</div>
            </div>
          </div>

          {/* Airline name */}
          <div className="hidden md:block text-xs text-gray-500 min-w-[100px]" data-testid="airline-name">
            {firstSeg.airline_name}
          </div>
        </div>

        {/* Price */}
        <div className="text-right pl-4 min-w-[80px]">
          <div className="text-lg font-medium text-gray-900" data-testid="price">
            {formatPrice(flight.price_usd)}
          </div>
          <div className="text-xs text-gray-500">{flight.cabin_class.replace('_', ' ')}</div>
        </div>

        {/* Expand icon for connecting flights */}
        {flight.stops > 0 && (
          <svg
            className={`h-5 w-5 text-gray-400 transition-transform ${expanded ? 'rotate-180' : ''}`}
            viewBox="0 0 20 20" fill="currentColor"
          >
            <path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
          </svg>
        )}
      </div>

      {/* Expanded segment details */}
      {expanded && flight.stops > 0 && (
        <div className="border-t border-gray-100 px-4 pb-4" data-testid="segments-detail">
          {flight.segments.map((seg, i) => (
            <div key={i}>
              <FlightSegmentView segment={seg} showDetails />
              {i < flight.segments.length - 1 && (
                <div className="ml-14 border-l-2 border-dashed border-gray-200 py-2 pl-4">
                  <span className="text-xs text-gray-400">
                    Layover at {seg.destination} ({seg.destination_city})
                  </span>
                </div>
              )}
            </div>
          ))}
        </div>
      )}
    </div>
  );
}
frontend/src/components/results/FlightSegment.tsx ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { FlightSegment as SegmentType } from '../../api/types';
2
+ import { formatDuration, formatTime } from '../../utils/format';
3
+
4
interface Props {
  segment: SegmentType;
  showDetails?: boolean;
}

/**
 * Renders one flight segment as a timeline row: departure column, duration
 * bar, arrival column. With `showDetails`, the airline / flight number /
 * aircraft line is shown under the bar.
 */
export default function FlightSegmentView({ segment, showDetails }: Props) {
  const detailLine = `${segment.airline_name} · ${segment.flight_number} · ${segment.aircraft}`;

  return (
    <div className="flex items-center gap-4 py-2" data-testid="flight-segment">
      {/* Airline badge */}
      <div className="flex h-9 w-9 items-center justify-center rounded bg-gray-100 text-xs font-bold text-gray-600 flex-shrink-0" data-testid="airline-badge">
        {segment.airline_code}
      </div>

      {/* Timeline */}
      <div className="flex flex-1 items-center gap-3">
        <div className="text-right min-w-[70px]">
          <div className="text-base font-medium" data-testid="departure-time">{formatTime(segment.departure)}</div>
          <div className="text-xs text-gray-500" data-testid="origin-code">{segment.origin}</div>
        </div>

        <div className="flex flex-1 flex-col items-center px-2">
          <div className="text-xs text-gray-500">{formatDuration(segment.duration_minutes)}</div>
          <div className="relative my-1 h-px w-full bg-gray-300">
            <div className="absolute left-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
            <div className="absolute right-0 top-1/2 h-2 w-2 -translate-y-1/2 rounded-full border-2 border-gray-400 bg-white" />
          </div>
          {showDetails && <div className="text-xs text-gray-400">{detailLine}</div>}
        </div>

        <div className="min-w-[70px]">
          <div className="text-base font-medium" data-testid="arrival-time">{formatTime(segment.arrival)}</div>
          <div className="text-xs text-gray-500" data-testid="destination-code">{segment.destination}</div>
        </div>
      </div>
    </div>
  );
}
frontend/src/components/results/NoResults.tsx ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
interface Props {
  hasFilters: boolean;
  onClearFilters?: () => void;
}

/**
 * Empty-state panel for the results list. When filters are active (and a
 * clear callback is supplied) it also offers a one-click reset.
 */
export default function NoResults({ hasFilters, onClearFilters }: Props) {
  const hint = hasFilters
    ? 'Try adjusting your filters or search criteria.'
    : 'No direct or connecting flights available for this route.';

  return (
    <div className="flex flex-col items-center justify-center py-16 text-center" data-testid="no-results">
      <svg className="mb-4 h-16 w-16 text-gray-300" viewBox="0 0 24 24" fill="currentColor">
        <path d="M21 16v-2l-8-5V3.5c0-.83-.67-1.5-1.5-1.5S10 2.67 10 3.5V9l-8 5v2l8-2.5V19l-2 1.5V22l3.5-1 3.5 1v-1.5L13 19v-5.5l8 2.5z"/>
      </svg>
      <h3 className="text-lg font-medium text-gray-700">No flights found</h3>
      <p className="mt-1 text-sm text-gray-500">{hint}</p>
      {hasFilters && onClearFilters && (
        <button
          onClick={onClearFilters}
          className="mt-4 text-sm font-medium text-[#1a73e8] hover:underline cursor-pointer"
          data-testid="clear-filters-button"
        >
          Clear all filters
        </button>
      )}
    </div>
  );
}
frontend/src/components/results/SortBar.tsx ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { SortBy } from '../../api/types';
2
+
3
// The three sort modes exposed by the backend, in display order.
const OPTIONS: { value: SortBy; label: string }[] = [
  { value: 'best', label: 'Best' },
  { value: 'cheapest', label: 'Cheapest' },
  { value: 'fastest', label: 'Fastest' },
];

interface Props {
  value: SortBy;
  onChange: (v: SortBy) => void;
  resultCount: number;
}

/** Result count plus a segmented Best/Cheapest/Fastest sort control. */
export default function SortBar({ value, onChange, resultCount }: Props) {
  const plural = resultCount !== 1 ? 's' : '';

  return (
    <div className="flex items-center justify-between" data-testid="sort-bar">
      <span className="text-sm text-gray-500" data-testid="result-count">
        {resultCount} result{plural}
      </span>
      <div className="flex gap-1 rounded-lg bg-gray-100 p-1">
        {OPTIONS.map(opt => {
          const active = value === opt.value;
          return (
            <button
              key={opt.value}
              onClick={() => onChange(opt.value)}
              className={`rounded-md px-4 py-1.5 text-sm font-medium transition-colors cursor-pointer ${
                active
                  ? 'bg-white text-[#1a73e8] shadow-sm'
                  : 'text-gray-600 hover:text-gray-900'
              }`}
              data-testid={`sort-${opt.value}`}
              aria-pressed={active}
            >
              {opt.label}
            </button>
          );
        })}
      </div>
    </div>
  );
}
frontend/src/components/search/AirportInput.tsx ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useRef, useState } from 'react';
2
+ import { searchAirports } from '../../api/client';
3
+ import type { AutocompleteResult } from '../../api/types';
4
+ import { useDebounce } from '../../hooks/useDebounce';
5
+
6
interface Props {
  label: string;
  value: string; // IATA code
  displayValue: string; // "New York (JFK)"
  onChange: (iata: string, display: string) => void;
  placeholder?: string;
  testId?: string;
}

/**
 * Airport autocomplete input. The typed query is debounced (200 ms) before
 * hitting the autocomplete API; choosing a result reports both the IATA code
 * and a human-readable display string to the parent. When unfocused, the
 * field mirrors the parent-supplied `displayValue`.
 */
export default function AirportInput({ label, value, displayValue, onChange, placeholder, testId }: Props) {
  const [query, setQuery] = useState(displayValue);
  const [results, setResults] = useState<AutocompleteResult[]>([]);
  const [open, setOpen] = useState(false);
  const [focused, setFocused] = useState(false);
  const debouncedQuery = useDebounce(query, 200);
  const wrapperRef = useRef<HTMLDivElement>(null);

  // Sync display value when parent changes it
  useEffect(() => {
    if (!focused) setQuery(displayValue);
  }, [displayValue, focused]);

  // Fetch autocomplete results
  useEffect(() => {
    if (!focused) return;
    if (debouncedQuery.length < 1) {
      setResults([]);
      return;
    }
    let cancelled = false;
    searchAirports(debouncedQuery)
      .then(r => {
        if (!cancelled) {
          setResults(r);
          setOpen(r.length > 0);
        }
      })
      .catch(() => {
        // Network/API failure: close the dropdown quietly rather than
        // leaking an unhandled promise rejection and stale results.
        if (!cancelled) {
          setResults([]);
          setOpen(false);
        }
      });
    return () => { cancelled = true; };
  }, [debouncedQuery, focused]);

  // Close on click outside; revert to empty if no airport was committed.
  useEffect(() => {
    function handler(e: MouseEvent) {
      if (wrapperRef.current && !wrapperRef.current.contains(e.target as Node)) {
        setOpen(false);
        setFocused(false);
        if (!value) setQuery('');
      }
    }
    document.addEventListener('mousedown', handler);
    return () => document.removeEventListener('mousedown', handler);
  }, [value]);

  // Commit a dropdown selection to local state and the parent.
  function select(r: AutocompleteResult) {
    onChange(r.iata, `${r.city_name} (${r.iata})`);
    setQuery(`${r.city_name} (${r.iata})`);
    setOpen(false);
    setFocused(false);
  }

  return (
    <div ref={wrapperRef} className="relative flex-1 min-w-[180px]" data-testid={testId}>
      <label className="absolute -top-2 left-3 bg-white px-1 text-xs text-gray-500 z-10">{label}</label>
      <input
        type="text"
        value={query}
        onChange={e => { setQuery(e.target.value); setOpen(true); }}
        onFocus={() => { setFocused(true); setQuery(''); setOpen(true); }}
        placeholder={placeholder || 'City or airport'}
        className="w-full rounded-md border border-gray-300 px-3 py-3 text-sm text-gray-900 placeholder-gray-400 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none"
        aria-label={label}
        data-testid={testId ? `${testId}-input` : undefined}
        autoComplete="off"
      />
      {open && results.length > 0 && (
        <ul
          className="absolute top-full left-0 right-0 z-50 mt-1 max-h-64 overflow-y-auto rounded-lg border border-gray-200 bg-white shadow-lg"
          data-testid={testId ? `${testId}-dropdown` : undefined}
          role="listbox"
        >
          {results.map(r => (
            <li
              key={r.iata}
              onClick={() => select(r)}
              className="flex cursor-pointer items-center gap-3 px-4 py-3 hover:bg-gray-50"
              role="option"
              data-testid={`airport-option-${r.iata}`}
              aria-selected={r.iata === value}
            >
              <svg className="h-5 w-5 flex-shrink-0 text-gray-400" viewBox="0 0 24 24" fill="none">
                <path d="M12 2C8.13 2 5 5.13 5 9c0 5.25 7 13 7 13s7-7.75 7-13c0-3.87-3.13-7-7-7zm0 9.5c-1.38 0-2.5-1.12-2.5-2.5s1.12-2.5 2.5-2.5 2.5 1.12 2.5 2.5-1.12 2.5-2.5 2.5z" fill="currentColor"/>
              </svg>
              <div className="flex-1 min-w-0">
                <div className="text-sm font-medium text-gray-900 truncate">{r.city_name} ({r.iata})</div>
                <div className="text-xs text-gray-500 truncate">{r.name}, {r.country}</div>
              </div>
            </li>
          ))}
        </ul>
      )}
    </div>
  );
}
frontend/src/components/search/ClassSelector.tsx ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { CabinClass } from '../../api/types';
2
+ import { cabinClassLabel } from '../../utils/format';
3
+
4
// Cabin classes in ascending fare order; labels come from cabinClassLabel.
const OPTIONS: CabinClass[] = ['economy', 'premium_economy', 'business', 'first'];

interface Props {
  value: CabinClass;
  onChange: (v: CabinClass) => void;
}

/** Native <select> for cabin class with a custom chevron overlay. */
export default function ClassSelector({ value, onChange }: Props) {
  return (
    <div className="relative" data-testid="class-selector">
      <select
        value={value}
        onChange={event => onChange(event.target.value as CabinClass)}
        className="appearance-none rounded-md border border-gray-300 bg-white px-3 py-2 pr-8 text-sm text-gray-700 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none cursor-pointer"
        aria-label="Cabin class"
        data-testid="class-select"
      >
        {OPTIONS.map(cls => (
          <option key={cls} value={cls}>{cabinClassLabel(cls)}</option>
        ))}
      </select>
      <svg className="pointer-events-none absolute right-2 top-1/2 -translate-y-1/2 h-4 w-4 text-gray-500" viewBox="0 0 20 20" fill="currentColor">
        <path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
      </svg>
    </div>
  );
}
frontend/src/components/search/DatePicker.tsx ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
interface Props {
  label: string;
  value: string; // YYYY-MM-DD
  onChange: (v: string) => void;
  testId?: string;
}

/** Floating-label wrapper around a native date input. */
export default function DatePicker({ label, value, onChange, testId }: Props) {
  const inputTestId = testId ? `${testId}-input` : undefined;

  return (
    <div className="relative min-w-[150px]" data-testid={testId}>
      <label className="absolute -top-2 left-3 bg-white px-1 text-xs text-gray-500 z-10">{label}</label>
      <input
        type="date"
        value={value}
        onChange={event => onChange(event.target.value)}
        className="w-full rounded-md border border-gray-300 px-3 py-3 text-sm text-gray-900 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none cursor-pointer"
        aria-label={label}
        data-testid={inputTestId}
      />
    </div>
  );
}
frontend/src/components/search/PassengerSelector.tsx ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useRef, useState } from 'react';
2
+ import type { Passengers } from '../../api/types';
3
+
4
interface Props {
  value: Passengers;
  onChange: (v: Passengers) => void;
}

/**
 * Dropdown for choosing adult/child/infant counts (adults 1-9, others 0-9).
 * All buttons are explicitly type="button": this component is rendered
 * inside SearchForm's <form>, where the HTML default type ("submit") would
 * otherwise fire a search on every click.
 */
export default function PassengerSelector({ value, onChange }: Props) {
  const [open, setOpen] = useState(false);
  const ref = useRef<HTMLDivElement>(null);

  // Close the dropdown on any click outside the component.
  useEffect(() => {
    function handler(e: MouseEvent) {
      if (ref.current && !ref.current.contains(e.target as Node)) setOpen(false);
    }
    document.addEventListener('mousedown', handler);
    return () => document.removeEventListener('mousedown', handler);
  }, []);

  const total = value.adults + value.children + value.infants;

  // Apply delta and clamp: adults stay in [1, 9], children/infants in [0, 9].
  function update(field: keyof Passengers, delta: number) {
    const v = { ...value };
    v[field] = Math.max(field === 'adults' ? 1 : 0, Math.min(9, v[field] + delta));
    onChange(v);
  }

  return (
    <div ref={ref} className="relative" data-testid="passenger-selector">
      <button
        type="button"
        onClick={() => setOpen(!open)}
        className="rounded-md border border-gray-300 px-3 py-3 text-sm text-gray-700 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none flex items-center gap-1 cursor-pointer"
        aria-label="Passengers"
        data-testid="passenger-button"
      >
        <svg className="h-4 w-4 text-gray-500" viewBox="0 0 24 24" fill="currentColor">
          <path d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"/>
        </svg>
        <span>{total}</span>
        <svg className="h-4 w-4 text-gray-400" viewBox="0 0 20 20" fill="currentColor">
          <path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
        </svg>
      </button>

      {open && (
        <div className="absolute top-full right-0 z-50 mt-1 w-64 rounded-lg border border-gray-200 bg-white p-4 shadow-lg" data-testid="passenger-dropdown">
          {([
            { key: 'adults' as const, label: 'Adults', sub: '' },
            { key: 'children' as const, label: 'Children', sub: 'Aged 2-11' },
            { key: 'infants' as const, label: 'Infants', sub: 'Under 2' },
          ]).map(row => (
            <div key={row.key} className="flex items-center justify-between py-2">
              <div>
                <div className="text-sm font-medium text-gray-900">{row.label}</div>
                {row.sub && <div className="text-xs text-gray-500">{row.sub}</div>}
              </div>
              <div className="flex items-center gap-3">
                <button
                  type="button"
                  onClick={() => update(row.key, -1)}
                  disabled={value[row.key] <= (row.key === 'adults' ? 1 : 0)}
                  className="h-8 w-8 rounded-full border border-gray-300 text-gray-600 hover:bg-gray-50 disabled:opacity-30 disabled:cursor-not-allowed cursor-pointer flex items-center justify-center"
                  aria-label={`Decrease ${row.label}`}
                  data-testid={`${row.key}-decrease`}
                >−</button>
                <span className="w-4 text-center text-sm" data-testid={`${row.key}-count`}>{value[row.key]}</span>
                <button
                  type="button"
                  onClick={() => update(row.key, 1)}
                  disabled={value[row.key] >= 9}
                  className="h-8 w-8 rounded-full border border-gray-300 text-gray-600 hover:bg-gray-50 disabled:opacity-30 disabled:cursor-not-allowed cursor-pointer flex items-center justify-center"
                  aria-label={`Increase ${row.label}`}
                  data-testid={`${row.key}-increase`}
                >+</button>
              </div>
            </div>
          ))}
          <button
            type="button"
            onClick={() => setOpen(false)}
            className="mt-2 w-full rounded-md bg-[#1a73e8] py-2 text-sm text-white hover:bg-[#1765cc] cursor-pointer"
            data-testid="passenger-done"
          >Done</button>
        </div>
      )}
    </div>
  );
}
frontend/src/components/search/SearchForm.tsx ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState } from 'react';
2
+ import type { CabinClass, Passengers, TripType } from '../../api/types';
3
+ import { getDefaultDepartureDate, getDefaultReturnDate } from '../../utils/date';
4
+ import AirportInput from './AirportInput';
5
+ import ClassSelector from './ClassSelector';
6
+ import DatePicker from './DatePicker';
7
+ import PassengerSelector from './PassengerSelector';
8
+ import SwapButton from './SwapButton';
9
+ import TripTypeSelector from './TripTypeSelector';
10
+
11
export interface SearchFormData {
  tripType: TripType;
  origin: string;
  originDisplay: string;
  destination: string;
  destinationDisplay: string;
  departureDate: string;
  returnDate: string;
  passengers: Passengers;
  cabinClass: CabinClass;
}

interface Props {
  initial?: Partial<SearchFormData>;
  onSearch: (data: SearchFormData) => void;
  compact?: boolean;
}

/**
 * The main flight-search form: trip type / passengers / cabin class on the
 * first row, airports + dates + submit button on the second. State is seeded
 * from `initial` and handed back to the parent as one SearchFormData object
 * on submit. The Search button is disabled until both airports are chosen.
 */
export default function SearchForm({ initial, onSearch, compact }: Props) {
  const [tripType, setTripType] = useState<TripType>(initial?.tripType || 'round_trip');
  const [origin, setOrigin] = useState(initial?.origin || '');
  const [originDisplay, setOriginDisplay] = useState(initial?.originDisplay || '');
  const [destination, setDestination] = useState(initial?.destination || '');
  const [destinationDisplay, setDestinationDisplay] = useState(initial?.destinationDisplay || '');
  const [departureDate, setDepartureDate] = useState(initial?.departureDate || getDefaultDepartureDate());
  const [returnDate, setReturnDate] = useState(initial?.returnDate || getDefaultReturnDate());
  const [passengers, setPassengers] = useState<Passengers>(initial?.passengers || { adults: 1, children: 0, infants: 0 });
  const [cabinClass, setCabinClass] = useState<CabinClass>(initial?.cabinClass || 'economy');

  const canSearch = Boolean(origin) && Boolean(destination);

  // Exchange origin and destination — both the IATA codes and display text.
  function swapEndpoints() {
    setOrigin(destination);
    setOriginDisplay(destinationDisplay);
    setDestination(origin);
    setDestinationDisplay(originDisplay);
  }

  // Report the full form state to the parent; incomplete submits are ignored.
  function submit(e: React.FormEvent) {
    e.preventDefault();
    if (!canSearch) return;
    onSearch({
      tripType, origin, originDisplay, destination, destinationDisplay,
      departureDate, returnDate, passengers, cabinClass,
    });
  }

  return (
    <form onSubmit={submit} data-testid="search-form">
      {/* Row 1: Trip type, passengers, class */}
      <div className="mb-3 flex flex-wrap items-center gap-2">
        <TripTypeSelector value={tripType} onChange={setTripType} />
        <PassengerSelector value={passengers} onChange={setPassengers} />
        <ClassSelector value={cabinClass} onChange={setCabinClass} />
      </div>

      {/* Row 2: Airport inputs + dates + search button */}
      <div className={`flex flex-wrap items-end gap-2 ${compact ? '' : 'rounded-xl border border-gray-300 p-3'}`}>
        <AirportInput
          label="Where from?"
          value={origin}
          displayValue={originDisplay}
          onChange={(iata, display) => { setOrigin(iata); setOriginDisplay(display); }}
          testId="origin"
        />

        <SwapButton onClick={swapEndpoints} />

        <AirportInput
          label="Where to?"
          value={destination}
          displayValue={destinationDisplay}
          onChange={(iata, display) => { setDestination(iata); setDestinationDisplay(display); }}
          testId="destination"
        />

        <DatePicker
          label="Departure"
          value={departureDate}
          onChange={setDepartureDate}
          testId="departure-date"
        />

        {/* Return date only applies to round trips */}
        {tripType === 'round_trip' && (
          <DatePicker
            label="Return"
            value={returnDate}
            onChange={setReturnDate}
            testId="return-date"
          />
        )}

        <button
          type="submit"
          disabled={!canSearch}
          className="rounded-full bg-[#1a73e8] px-6 py-3 text-sm font-medium text-white hover:bg-[#1765cc] hover:shadow-md disabled:opacity-40 disabled:cursor-not-allowed focus:outline-none cursor-pointer"
          data-testid="search-button"
        >
          Search
        </button>
      </div>
    </form>
  );
}
frontend/src/components/search/SwapButton.tsx ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
interface Props {
  onClick: () => void;
}

/**
 * Round icon button that swaps origin and destination. Explicitly
 * type="button" because it is rendered inside SearchForm's <form>:
 * the HTML default type ("submit") would trigger a search on every swap.
 */
export default function SwapButton({ onClick }: Props) {
  return (
    <button
      type="button"
      onClick={onClick}
      className="flex h-10 w-10 items-center justify-center rounded-full border border-gray-300 bg-white text-gray-500 hover:bg-gray-50 hover:text-gray-700 focus:outline-none self-center cursor-pointer"
      aria-label="Swap origin and destination"
      data-testid="swap-button"
    >
      <svg className="h-5 w-5" viewBox="0 0 24 24" fill="currentColor">
        <path d="M6.99 11L3 15l3.99 4v-3H14v-2H6.99v-3zM21 9l-3.99-4v3H10v2h7.01v3L21 9z"/>
      </svg>
    </button>
  );
}
frontend/src/components/search/TripTypeSelector.tsx ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { TripType } from '../../api/types';
2
+
3
// Trip types supported by the search API, in display order.
const OPTIONS: { value: TripType; label: string }[] = [
  { value: 'round_trip', label: 'Round trip' },
  { value: 'one_way', label: 'One way' },
  { value: 'multi_city', label: 'Multi-city' },
];

interface Props {
  value: TripType;
  onChange: (v: TripType) => void;
}

/** Native <select> for trip type with a custom chevron overlay. */
export default function TripTypeSelector({ value, onChange }: Props) {
  return (
    <div className="relative" data-testid="trip-type-selector">
      <select
        value={value}
        onChange={event => onChange(event.target.value as TripType)}
        className="appearance-none rounded-md border border-gray-300 bg-white px-3 py-2 pr-8 text-sm text-gray-700 hover:border-gray-400 focus:border-[#1a73e8] focus:outline-none cursor-pointer"
        aria-label="Trip type"
        data-testid="trip-type-select"
      >
        {OPTIONS.map(opt => (
          <option key={opt.value} value={opt.value}>{opt.label}</option>
        ))}
      </select>
      <svg className="pointer-events-none absolute right-2 top-1/2 -translate-y-1/2 h-4 w-4 text-gray-500" viewBox="0 0 20 20" fill="currentColor">
        <path fillRule="evenodd" d="M5.23 7.21a.75.75 0 011.06.02L10 11.168l3.71-3.938a.75.75 0 111.08 1.04l-4.25 4.5a.75.75 0 01-1.08 0l-4.25-4.5a.75.75 0 01.02-1.06z" clipRule="evenodd"/>
      </svg>
    </div>
  );
}
frontend/src/components/shared/Header.tsx ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useNavigate } from 'react-router-dom';
2
+
3
+ export default function Header() {
4
+ const navigate = useNavigate();
5
+
6
+ return (
7
+ <header className="border-b border-gray-200 bg-white" data-testid="header">
8
+ <div className="mx-auto max-w-7xl px-4 py-3 flex items-center gap-3">
9
+ <button
10
+ onClick={() => navigate('/')}
11
+ className="flex items-center gap-2 text-xl font-medium text-gray-900 hover:opacity-80 cursor-pointer"
12
+ data-testid="logo"
13
+ >
14
+ <svg width="24" height="24" viewBox="0 0 24 24" fill="none" className="text-[#1a73e8]">
15
+ <path d="M21 16v-2l-8-5V3.5c0-.83-.67-1.5-1.5-1.5S10 2.67 10 3.5V9l-8 5v2l8-2.5V19l-2 1.5V22l3.5-1 3.5 1v-1.5L13 19v-5.5l8 2.5z" fill="currentColor"/>
16
+ </svg>
17
+ Flights
18
+ </button>
19
+ </div>
20
+ </header>
21
+ );
22
+ }
frontend/src/components/shared/Loading.tsx ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ export default function Loading() {
2
+ return (
3
+ <div className="flex flex-col items-center justify-center py-20" data-testid="loading">
4
+ <div className="h-10 w-10 animate-spin rounded-full border-4 border-gray-200 border-t-[#1a73e8]" />
5
+ <p className="mt-4 text-sm text-gray-500">Searching flights...</p>
6
+ </div>
7
+ );
8
+ }
frontend/src/hooks/useDebounce.ts ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useState } from 'react';
2
+
3
+ export function useDebounce<T>(value: T, delay: number): T {
4
+ const [debounced, setDebounced] = useState(value);
5
+
6
+ useEffect(() => {
7
+ const timer = setTimeout(() => setDebounced(value), delay);
8
+ return () => clearTimeout(timer);
9
+ }, [value, delay]);
10
+
11
+ return debounced;
12
+ }
frontend/src/hooks/useFlightSearch.ts ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useCallback, useState } from 'react';
2
+ import { searchFlights } from '../api/client';
3
+ import type { CabinClass, Filters, FlightOffer, Passengers, SearchRequest, SortBy, TripType } from '../api/types';
4
+
5
+ interface SearchState {
6
+ outboundFlights: FlightOffer[];
7
+ returnFlights: FlightOffer[];
8
+ loading: boolean;
9
+ error: string | null;
10
+ searched: boolean;
11
+ }
12
+
13
+ export function useFlightSearch() {
14
+ const [state, setState] = useState<SearchState>({
15
+ outboundFlights: [],
16
+ returnFlights: [],
17
+ loading: false,
18
+ error: null,
19
+ searched: false,
20
+ });
21
+
22
+ const search = useCallback(async (params: {
23
+ tripType: TripType;
24
+ origin: string;
25
+ destination: string;
26
+ departureDate: string;
27
+ returnDate?: string;
28
+ passengers: Passengers;
29
+ cabinClass: CabinClass;
30
+ filters: Filters;
31
+ sortBy: SortBy;
32
+ }) => {
33
+ setState(s => ({ ...s, loading: true, error: null }));
34
+
35
+ const legs = [{ origin: params.origin, destination: params.destination, date: params.departureDate }];
36
+ if (params.tripType === 'round_trip' && params.returnDate) {
37
+ legs.push({ origin: params.destination, destination: params.origin, date: params.returnDate });
38
+ }
39
+
40
+ const req: SearchRequest = {
41
+ trip_type: params.tripType,
42
+ legs,
43
+ passengers: params.passengers,
44
+ cabin_class: params.cabinClass,
45
+ filters: params.filters,
46
+ sort_by: params.sortBy,
47
+ };
48
+
49
+ try {
50
+ const res = await searchFlights(req);
51
+ setState({
52
+ outboundFlights: res.outbound_flights,
53
+ returnFlights: res.return_flights,
54
+ loading: false,
55
+ error: null,
56
+ searched: true,
57
+ });
58
+ } catch (err) {
59
+ setState({
60
+ outboundFlights: [],
61
+ returnFlights: [],
62
+ loading: false,
63
+ error: err instanceof Error ? err.message : 'Search failed',
64
+ searched: true,
65
+ });
66
+ }
67
+ }, []);
68
+
69
+ return { ...state, search };
70
+ }