AnikS22 committed on
Commit
2a62959
·
verified ·
1 Parent(s): 359815a

Deploy MidasMap Gradio app; weights downloaded from model repo at runtime

Files changed (5)
  1. README.md +45 -17
  2. app.py +326 -115
  3. requirements.txt +2 -0
  4. src/ensemble.py +32 -3
  5. src/model.py +4 -2
README.md CHANGED
@@ -12,27 +12,55 @@ license: mit
 
 # MidasMap Space
 
-This folder is a **template** for creating a [Hugging Face Space](https://huggingface.co/docs/hub/spaces-overview).
-
-**Why not Vercel for the model?** Vercel serverless functions have strict size and time limits; they are not meant for PyTorch + a ~100MB checkpoint and multi-second GPU/CPU inference. **Host the Gradio app + weights on a Space** (CPU free tier or GPU upgrade).
-
-## Create the Space
-
-1. On Hugging Face: **New Space** → SDK **Gradio** → name e.g. `MidasMap`.
-2. Clone the Space repo locally, or connect **GitHub** and set the Space root to this monorepo with **App file** pointing to the copied `app.py`.
-3. Copy into the Space repository root:
-   - `app.py` from the main MidasMap repo (project root), **or** symlink / duplicate.
-   - `src/` (entire package)
-   - `requirements-space.txt` from this folder as **`requirements.txt`**
-4. In Space **Settings → Repository secrets** (if needed): none required for public weights.
-5. Ensure `checkpoints/final/final_model.pth` is present:
-   - Upload via **Files** tab, or
-   - Add a startup script to download from `AnikS22/MidasMap` on the Hub (see HF docs for `hf_hub_download`).
-
-After the Space builds, point your **Vercel** site (`vercel-site`) at it:
-
-`https://yoursite.vercel.app/?embed=https://huggingface.co/spaces/YOUR_USER/YOUR_SPACE`
-
----
-
-Gradio app and model logic: [github.com/AnikS22/MidasMap](https://github.com/AnikS22/MidasMap)
+Gradio demo for **[MidasMap](https://github.com/AnikS22/MidasMap)** (immunogold particle detection in TEM synapse images).
+
+## Deploy from your laptop
+
+From the **MidasMap** repo root:
+
+```bash
+export HF_TOKEN=hf_...  # write token
+# Recommended: do not upload the ~100MB checkpoint into the Space (avoids LFS / size issues).
+export HF_SPACE_SKIP_CHECKPOINT=1
+./scripts/upload_hf_space.sh
+```
+
+If **`upload_hf_space.sh` fails**, use **git + LFS** instead (often more reliable):
+
+```bash
+brew install git-lfs && git lfs install  # once
+export HF_TOKEN=hf_...
+./scripts/push_hf_space_git.sh
+```
+
+Full options: [docs/DEPLOY.md](../docs/DEPLOY.md) in the main repo.
+
+Create the Space once if needed (the Gradio SDK is required for auto-create):
+
+```bash
+huggingface-cli repo create MidasMap --type space --space_sdk gradio -y
+```
+
+Weights are loaded from the **model** repo `AnikS22/MidasMap` at `checkpoints/final/final_model.pth` when the file is not in the Space. Override with Space secrets / env vars: `MIDASMAP_HF_WEIGHTS_REPO`, `MIDASMAP_HF_WEIGHTS_FILE`.
+
+To bundle the checkpoint in the Space instead (larger upload):
+
+```bash
+export HF_SPACE_SKIP_CHECKPOINT=0
+./scripts/upload_hf_space.sh
+```
+
+## Troubleshooting uploads
+
+| Symptom | What to do |
+|---------|------------|
+| **401 / not logged in** | `export HF_TOKEN=hf_...` with a token that has **write** access, or `huggingface-cli login`. |
+| **LFS / authorization / upload stuck** | Use `HF_SPACE_SKIP_CHECKPOINT=1` so only code uploads; ensure the **model** repo (not the Space) contains `checkpoints/final/final_model.pth`. |
+| **Space does not exist** | Create it in the HF web UI (**New Space** → **Gradio**) or run `huggingface-cli repo create ... --type space --space_sdk gradio`. |
+| **"No space_sdk provided"** | The Space repo must be created as **Gradio** (or pass `--space_sdk gradio` when using `repo create`). |
+| **Model not found on Space** | First boot downloads weights from the Hub; public repos need no token. For a private model repo, add `HF_TOKEN` as a Space **secret** (read access). |
+| **Still failing** | Try `pip install hf_transfer` and `export HF_HUB_ENABLE_HF_TRANSFER=1` before uploading. Or git-clone the Space (with git-lfs), copy the files in, commit, and push. |
+
+## Vercel embed
+
+`https://yoursite.vercel.app/?embed=https://huggingface.co/spaces/YOUR_USER/YOUR_SPACE`
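The runtime weight-loading described above (local file if present, otherwise the Hub model repo, with env-var overrides) can be sketched as follows. This is a minimal sketch, not the repo's exact code: `resolve_weight_source` and `fetch_weights` are illustrative names; the env vars, default repo, and default path are the ones the README documents.

```python
import os
from pathlib import Path

DEFAULT_REPO = "AnikS22/MidasMap"
DEFAULT_FILE = "checkpoints/final/final_model.pth"


def resolve_weight_source() -> tuple[str, str]:
    """Read the Space secrets / env overrides, falling back to the public model repo."""
    repo_id = os.environ.get("MIDASMAP_HF_WEIGHTS_REPO", DEFAULT_REPO).strip()
    filename = os.environ.get("MIDASMAP_HF_WEIGHTS_FILE", DEFAULT_FILE).strip()
    return repo_id, filename


def fetch_weights(local_path: str = DEFAULT_FILE) -> Path:
    """Prefer a bundled checkpoint; otherwise download from the Hub model repo."""
    p = Path(local_path)
    if p.is_file():
        return p
    # Lazy import: only needed when the checkpoint is not bundled in the Space.
    from huggingface_hub import hf_hub_download
    repo_id, filename = resolve_weight_source()
    return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model"))
```

For a public model repo no token is needed; a private repo additionally requires `HF_TOKEN` to be set as a Space secret.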
app.py CHANGED
@@ -91,7 +91,11 @@ def load_model(checkpoint_path: str):
         if torch.backends.mps.is_available()
         else "cpu"
     )
-    MODEL = ImmunogoldCenterNet(bifpn_channels=128, bifpn_rounds=2)
+    MODEL = ImmunogoldCenterNet(
+        bifpn_channels=128,
+        bifpn_rounds=2,
+        imagenet_encoder_fallback=False,
+    )
     ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
     MODEL.load_state_dict(ckpt["model_state_dict"])
     MODEL.to(DEVICE)
@@ -176,6 +180,94 @@ def _df_to_preview_html(df: pd.DataFrame) -> str:
     )
 
 
+def _numpy_image_to_uint8_rgb(img: np.ndarray) -> np.ndarray:
+    """Normalize various arrays to HxWx3 uint8 for cropping / display."""
+    if img is None:
+        return None
+    arr = np.asarray(img)
+    if arr.size == 0:
+        return None
+    if arr.ndim == 2:
+        arr = np.stack([arr, arr, arr], axis=-1)
+    elif arr.ndim == 3 and arr.shape[2] == 4:
+        arr = arr[:, :, :3]
+    if arr.dtype in (np.float32, np.float64):
+        mx = float(arr.max()) if arr.size else 1.0
+        if mx <= 1.0:
+            arr = (np.clip(arr, 0, 1) * 255.0).astype(np.uint8)
+        else:
+            arr = np.clip(arr, 0, 255).astype(np.uint8)
+    else:
+        arr = np.clip(arr, 0, 255).astype(np.uint8)
+    return arr
+
+
+def magnifier_zoom(
+    store: dict,
+    view: str,
+    center_x_pct: float,
+    center_y_pct: float,
+    zoom: float,
+    output_px: int,
+) -> np.ndarray | None:
+    """
+    Crop a square region around (center_x_pct, center_y_pct) and upscale for a loupe view.
+    zoom: 1 = see ~full width in the loupe; larger = stronger magnification (smaller crop).
+    """
+    if not store or not isinstance(store, dict):
+        return None
+    key = {"Overlay": "overlay", "Heatmaps": "heatmap", "Summary": "stats"}.get(view, "overlay")
+    img = _numpy_image_to_uint8_rgb(store.get(key))
+    if img is None:
+        return None
+    h, w = img.shape[:2]
+    cx = int(np.clip(center_x_pct / 100.0 * (w - 1), 0, w - 1))
+    cy = int(np.clip(center_y_pct / 100.0 * (h - 1), 0, h - 1))
+    z = max(1.0, float(zoom))
+    half_w = max(1, int(w / (2.0 * z)))
+    half_h = max(1, int(h / (2.0 * z)))
+    x0, x1 = max(0, cx - half_w), min(w, cx + half_w)
+    y0, y1 = max(0, cy - half_h), min(h, cy + half_h)
+    if x1 <= x0 or y1 <= y0:
+        crop = img
+    else:
+        crop = img[y0:y1, x0:x1]
+    side = int(np.clip(output_px, 256, 1024))
+    try:
+        from PIL import Image as PILImage
+
+        pil = PILImage.fromarray(crop)
+        pil = pil.resize((side, side), PILImage.Resampling.LANCZOS)
+        return np.asarray(pil)
+    except Exception:
+        from skimage.transform import resize
+
+        up = resize(crop, (side, side), order=1, preserve_range=True)
+        return np.clip(up, 0, 255).astype(np.uint8)
+
+
+def run_detection(
+    image_file,
+    conf_threshold: float,
+    nms_6nm: int,
+    nms_12nm: int,
+    px_per_um: float,
+    progress=gr.Progress(track_tqdm=False),
+):
+    """Run the model and return outputs plus viz state for the magnifier."""
+    out = detect_particles(
+        image_file,
+        conf_threshold,
+        nms_6nm,
+        nms_12nm,
+        px_per_um,
+        progress=progress,
+    )
+    overlay, hm, stats, csvp, table, summary = out
+    store = {"overlay": overlay, "heatmap": hm, "stats": stats}
+    return overlay, hm, stats, csvp, table, summary, store
+
+
 def detect_particles(
     image_file,
     conf_threshold: float = 0.25,
@@ -250,8 +342,27 @@ def detect_particles(
 
     from skimage.transform import resize
 
-    hm6_up = resize(hm_np[0], (h, w), order=1)
-    hm12_up = resize(hm_np[1], (h, w), order=1)
+    hm6_up = np.clip(
+        np.nan_to_num(resize(hm_np[0], (h, w), order=1), nan=0.0),
+        0.0,
+        1.0,
+    )
+    hm12_up = np.clip(
+        np.nan_to_num(resize(hm_np[1], (h, w), order=1), nan=0.0),
+        0.0,
+        1.0,
+    )
+
+    def _heatmap_vmax(hm: np.ndarray) -> float:
+        """Stable color scale: avoid invisible overlays when the max is tiny or flat."""
+        flat = hm.ravel()
+        if flat.size == 0:
+            return 0.3
+        mx = float(np.max(flat))
+        if mx < 1e-6:
+            return 0.3
+        p99 = float(np.percentile(flat, 99.0))
+        return float(np.clip(max(0.12, p99 * 1.05, mx * 0.95), 0.05, 1.0))
 
     # --- Overlay (publication-style legend + scale bar) ---
     fig_overlay, ax = plt.subplots(figsize=(11, 11))
@@ -291,21 +402,54 @@ def detect_particles(
     overlay_img = np.asarray(fig_overlay.canvas.renderer.buffer_rgba())[:, :, :3]
     plt.close(fig_overlay)
 
-    # --- Heatmaps ---
-    fig_hm, axes = plt.subplots(1, 2, figsize=(14, 6.2))
-    axes[0].imshow(img, cmap="gray", aspect="equal")
-    axes[0].imshow(hm6_up, cmap="magma", alpha=0.55, vmin=0, vmax=max(0.3, float(hm6_up.max())))
-    axes[0].set_title(f"AMPA (6 nm) channel · n = {n_6nm}", fontsize=11)
-    axes[0].axis("off")
-
-    axes[1].imshow(img, cmap="gray", aspect="equal")
-    axes[1].imshow(hm12_up, cmap="inferno", alpha=0.55, vmin=0, vmax=max(0.3, float(hm12_up.max())))
-    axes[1].set_title(f"NR1 (12 nm) channel · n = {n_12nm}", fontsize=11)
-    axes[1].axis("off")
+    # --- Heatmaps: row 1 = overlay on EM; row 2 = model heat only (debug-friendly) ---
+    # Training uses Gaussian GT; inference heatmaps are learned sigmoid blobs, not analytic Gaussians.
+    v6, v12 = _heatmap_vmax(hm6_up), _heatmap_vmax(hm12_up)
+    fig_hm, axes = plt.subplots(2, 2, figsize=(14, 12))
+    ax00, ax01 = axes[0]
+    ax10, ax11 = axes[1]
+
+    for ax, hm, v, cmap, title in (
+        (ax00, hm6_up, v6, "magma", f"AMPA overlay · n={n_6nm} · vmax={v6:.2f}"),
+        (ax01, hm12_up, v12, "inferno", f"NR1 overlay · n={n_12nm} · vmax={v12:.2f}"),
+    ):
+        ax.imshow(img, cmap="gray", aspect="equal", interpolation="nearest")
+        ax.imshow(
+            hm,
+            cmap=cmap,
+            alpha=0.6,
+            vmin=0.0,
+            vmax=v,
+            interpolation="bilinear",
+        )
+        ax.set_title(title, fontsize=10)
+        ax.axis("off")
+
+    ax10.imshow(hm6_up, cmap="magma", vmin=0.0, vmax=v6, interpolation="nearest")
+    ax10.set_title(f"AMPA heatmap only · max={float(np.max(hm6_up)):.4f}", fontsize=10)
+    ax10.axis("off")
+
+    ax11.imshow(hm12_up, cmap="inferno", vmin=0.0, vmax=v12, interpolation="nearest")
+    ax11.set_title(f"NR1 heatmap only · max={float(np.max(hm12_up)):.4f}", fontsize=10)
+    ax11.axis("off")
 
     plt.tight_layout()
-    fig_hm.canvas.draw()
-    heatmap_img = np.asarray(fig_hm.canvas.renderer.buffer_rgba())[:, :, :3]
+    # PNG raster → uint8 RGB (more reliable in Gradio than the raw canvas buffer on some setups)
+    from io import BytesIO
+
+    _buf = BytesIO()
+    fig_hm.savefig(_buf, format="png", dpi=120, bbox_inches="tight", facecolor="white")
     plt.close(fig_hm)
+    _buf.seek(0)
+    try:
+        from PIL import Image as _PILImage
+
+        heatmap_img = np.asarray(_PILImage.open(_buf).convert("RGB"))
+    except Exception:
+        import matplotlib.image as _mimg
+
+        _buf.seek(0)
+        heatmap_img = (_mimg.imread(_buf)[:, :, :3] * 255.0).clip(0, 255).astype(np.uint8)
 
     # --- Stats (µm where helpful) ---
     fig_stats, axes = plt.subplots(1, 3, figsize=(16, 4.8))
@@ -411,58 +555,73 @@ def detect_particles(
 
 
 MM_CSS = """
-.gradio-container { max-width: 1320px !important; margin: auto !important; }
+@import url("https://fonts.googleapis.com/css2?family=Libre+Baskerville:wght@700&family=Source+Sans+3:wght@400;600;700&display=swap");
+.gradio-container { max-width: 1280px !important; margin: auto !important; padding: 1rem 0.75rem 2rem !important; }
 .mm-brand-bar {
     display: flex; align-items: center; justify-content: space-between;
-    flex-wrap: wrap; gap: 0.75rem;
-    padding: 0.6rem 0 1.25rem;
-    border-bottom: 1px solid var(--border-color-primary);
-    margin-bottom: 1.25rem;
+    flex-wrap: wrap; gap: 0.5rem 1rem;
+    padding: 0 0 1rem;
+    margin-bottom: 1rem;
+    border-bottom: 1px solid rgba(148, 163, 184, 0.2);
 }
 .mm-brand-bar span {
-    font-size: 0.72rem; letter-spacing: 0.14em; text-transform: uppercase;
-    color: var(--body-text-color-subdued); font-weight: 600;
+    font-size: 0.7rem; letter-spacing: 0.06em;
+    color: var(--body-text-color-subdued); font-weight: 500;
 }
 .mm-hero {
-    padding: 1.5rem 1.35rem 1.35rem;
-    margin-bottom: 0.25rem;
-    border-radius: 10px;
-    background: linear-gradient(145deg, #0c4a6e22 0%, #0f172a 48%, #1e1b4b33 100%);
-    border: 1px solid #33415588;
+    padding: 1.35rem 1.5rem;
+    margin-bottom: 1.25rem;
+    border-radius: 16px;
+    background: linear-gradient(155deg, rgba(13, 148, 136, 0.12) 0%, rgba(15, 23, 42, 0.95) 42%, rgba(30, 27, 75, 0.15) 100%);
+    border: 1px solid rgba(148, 163, 184, 0.15);
+    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
 }
 .mm-hero h1 {
     font-family: "Libre Baskerville", Georgia, serif;
     font-weight: 700;
     letter-spacing: -0.02em;
-    margin: 0 0 0.4rem 0;
-    font-size: 1.65rem;
-    color: #f1f5f9;
+    margin: 0 0 0.5rem 0;
+    font-size: 1.75rem;
+    color: #f8fafc;
 }
 .mm-hero .mm-sub {
-    margin: 0 0 0.85rem 0;
-    color: #94a3b8;
-    font-size: 0.92rem;
-    line-height: 1.55;
-    max-width: 58ch;
+    margin: 0 0 1rem 0;
+    color: #cbd5e1;
+    font-size: 0.95rem;
+    line-height: 1.6;
+    max-width: 62ch;
 }
-.mm-badge-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
+.mm-badge-row { display: flex; flex-wrap: wrap; gap: 0.45rem; }
 .mm-badge {
-    font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.07em;
-    padding: 0.2rem 0.5rem; border-radius: 4px;
-    background: #0e749033; color: #99f6e4; border: 1px solid #14b8a644;
+    font-size: 0.62rem; letter-spacing: 0.05em; font-weight: 600;
+    padding: 0.28rem 0.55rem; border-radius: 999px;
+    background: rgba(45, 212, 191, 0.12); color: #5eead4;
+    border: 1px solid rgba(45, 212, 191, 0.25);
 }
-.mm-layout { display: flex; gap: 1.25rem; align-items: flex-start; flex-wrap: wrap; }
+.mm-layout { display: flex; gap: 1.5rem; align-items: flex-start; flex-wrap: wrap; }
 .mm-sidebar {
-    flex: 1 1 280px; max-width: 340px;
-    padding: 1rem 1.1rem; border-radius: 10px;
-    border: 1px solid var(--border-color-primary);
+    flex: 1 1 300px; max-width: 360px;
+    padding: 1.25rem 1.35rem; border-radius: 16px;
+    border: 1px solid rgba(148, 163, 184, 0.12);
     background: var(--block-background-fill);
+    box-shadow: 0 4px 24px rgba(0, 0, 0, 0.12);
+}
+.mm-main {
+    flex: 1 1 480px; min-width: 0;
+    padding: 0.25rem 0.15rem;
+    border-radius: 16px;
 }
-.mm-main { flex: 3 1 520px; min-width: 0; }
 .mm-panel-title {
-    font-size: 0.7rem; text-transform: uppercase; letter-spacing: 0.1em;
-    color: var(--body-text-color-subdued); font-weight: 600; margin: 0 0 0.65rem 0;
+    font-size: 0.72rem; text-transform: uppercase; letter-spacing: 0.08em;
+    color: var(--body-text-color-subdued); font-weight: 600; margin: 0 0 0.75rem 0;
 }
+.mm-loupe-help {
+    font-size: 0.82rem; line-height: 1.45; color: var(--body-text-color-subdued);
+    margin: 0 0 0.75rem 0; padding: 0.65rem 0.85rem;
+    border-radius: 10px; background: rgba(30, 41, 59, 0.45);
+    border: 1px solid rgba(148, 163, 184, 0.12);
+}
+.tabs > .tab-nav button { font-weight: 500 !important; letter-spacing: 0.01em; }
 .mm-callout {
     margin: 0; padding: 0.75rem 0.9rem; border-radius: 8px;
     background: #1e293b66; border: 1px solid var(--border-color-primary);
@@ -513,21 +672,10 @@ table.mm-table td { padding: 0.35rem 0.5rem; border-bottom: 1px solid #33415544;
 
 
 def build_app():
+    # Use named hues only (no custom Color objects): avoids Gradio/Jinja template bugs on some stacks (e.g. HF Spaces + Py3.13).
     theme = gr.themes.Soft(
-        primary_hue=gr.themes.Color(
-            c50="#f0fdfa",
-            c100="#ccfbf1",
-            c200="#99f6e4",
-            c300="#5eead4",
-            c400="#2dd4bf",
-            c500="#14b8a6",
-            c600="#0d9488",
-            c700="#0f766e",
-            c800="#115e59",
-            c900="#134e4a",
-            c950="#042f2e",
-        ),
-        neutral_hue=gr.themes.colors.slate,
+        primary_hue="teal",
+        neutral_hue="slate",
         font=("Source Sans 3", "ui-sans-serif", "system-ui", "sans-serif"),
         font_mono=("IBM Plex Mono", "ui-monospace", "monospace"),
     ).set(
@@ -539,51 +687,47 @@ def build_app():
         block_label_text_size="*text_sm",
     )
 
-    head = """
-    <link href="https://fonts.googleapis.com/css2?family=Libre+Baskerville:wght@700&family=Source+Sans+3:wght@400;600;700&display=swap" rel="stylesheet">
-    """
-
     with gr.Blocks(
         title="MidasMap — Immunogold analysis",
         theme=theme,
         css=MM_CSS,
-        head=head,
     ) as app:
         gr.HTML(
            """
            <div class="mm-brand-bar">
-                <span>Quantitative EM · synapse immunogold</span>
-                <span>Research use · validate critical counts manually</span>
+                <span>MidasMap · immunogold on TEM synapses</span>
+                <span>For research use · verify important counts by eye</span>
            </div>
            <div class="mm-hero">
                <h1>MidasMap</h1>
                <p class="mm-sub">
-                    Automated particle picking for <strong>freeze-fracture replica immunolabeling (FFRIL)</strong> TEM:
-                    <strong>6 nm</strong> gold (AMPA receptors) and <strong>12 nm</strong> gold (NR1 / NMDA receptors).
-                    Coordinates export in <strong>µm</strong> for comparison to physiology and super-resolution data—set calibration to match your microscope.
+                    Find <strong>6 nm</strong> (AMPA) and <strong>12 nm</strong> (NR1) gold particles in
+                    <strong>FFRIL</strong> micrographs. Set the <strong>calibration</strong> so exports are in µm.
+                    Use the <strong>magnifying glass</strong> below to inspect beads and heatmaps up close.
                </p>
                <div class="mm-badge-row">
-                    <span class="mm-badge">FFRIL / TEM</span>
+                    <span class="mm-badge">FFRIL</span>
                    <span class="mm-badge">CenterNet</span>
-                    <span class="mm-badge">CEM500K backbone</span>
-                    <span class="mm-badge">LOOCV F1 ≈ 0.94</span>
+                    <span class="mm-badge">CEM500K</span>
+                    <span class="mm-badge">F1 ≈ 0.94 LOOCV</span>
                </div>
            </div>
            """
        )
 
+        viz_state = gr.State({"overlay": None, "heatmap": None, "stats": None})
+
        with gr.Row(elem_classes=["mm-layout"]):
            with gr.Column(elem_classes=["mm-sidebar"]):
-                gr.HTML('<p class="mm-panel-title">Micrograph & calibration</p>')
+                gr.HTML('<p class="mm-panel-title">1 · Upload & settings</p>')
                image_input = gr.File(
-                    label="Upload image",
+                    label="Micrograph",
                    file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"],
                )
                px_per_um_in = gr.Number(
                    value=DEFAULT_PX_PER_UM,
-                    label="Calibration (pixels per µm)",
-                    info=f"Default {DEFAULT_PX_PER_UM:.0f} matches the published training set. "
-                    "Update if your acquisition scale differs.",
+                    label="Pixels per µm",
+                    info=f"Default {DEFAULT_PX_PER_UM:.0f} matches the training corpus. Change if your scale differs.",
                    minimum=1,
                    maximum=1e6,
                )
@@ -592,28 +736,64 @@ def build_app():
                    maximum=0.95,
                    value=0.25,
                    step=0.05,
-                    label="Confidence threshold",
-                    info="Higher fewer, sharper peaks. Lower recall with more false positives.",
+                    label="Confidence",
+                    info="Higher = stricter (fewer hits). Lower = more sensitive.",
                )
-                with gr.Accordion("Advanced · non-max suppression", open=False):
+                with gr.Accordion("Advanced · peak spacing (NMS)", open=False):
                    nms_6nm = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=3,
                        step=2,
-                        label="NMS · 6 nm channel",
-                        info="Minimum spacing between AMPA peaks on the heatmap grid.",
+                        label="Spacing · 6 nm channel",
+                        info="Minimum gap between AMPA peaks on the model grid.",
                    )
                    nms_12nm = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=5,
                        step=2,
-                        label="NMS · 12 nm channel",
+                        label="Spacing · 12 nm channel",
                    )
                detect_btn = gr.Button("Run detection", variant="primary", size="lg")
 
-                with gr.Accordion("For neuroscientists — interpretation", open=False):
+                with gr.Accordion("Magnifying glass", open=True):
+                    gr.HTML(
+                        """<p class="mm-loupe-help" style="margin-top:0">
+                        After you run detection, pick which result to inspect and adjust the sliders.
+                        <strong>Magnification</strong> zooms in (smaller crop, upscaled). Use the fullscreen icon on any image for a larger view.
+                        </p>"""
+                    )
+                    mag_view = gr.Radio(
+                        choices=["Overlay", "Heatmaps", "Summary"],
+                        value="Overlay",
+                        label="Source image",
+                    )
+                    mag_cx = gr.Slider(
+                        0, 100, value=50, step=0.5,
+                        label="Pan left ↔ right (%)",
+                    )
+                    mag_cy = gr.Slider(
+                        0, 100, value=50, step=0.5,
+                        label="Pan up ↔ down (%)",
+                    )
+                    mag_zoom = gr.Slider(
+                        1, 10, value=2.5, step=0.25,
+                        label="Magnification",
+                        info="Higher = stronger zoom (smaller region).",
+                    )
+                    mag_out = gr.Slider(
+                        256, 768, value=480, step=64,
+                        label="Loupe window (px)",
+                    )
+                    mag_out_img = gr.Image(
+                        label="Loupe preview",
+                        type="numpy",
+                        height=380,
+                        show_fullscreen_button=True,
+                    )
+
+                with gr.Accordion("Notes for scientists", open=False):
                    gr.Markdown(
                        """
                        #### What the model outputs
@@ -634,8 +814,9 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
                    )
 
            with gr.Column(elem_classes=["mm-main"]):
+                gr.HTML('<p class="mm-panel-title">2 · Results</p>')
                summary_md = gr.HTML(
-                    value="<p class='mm-callout'>Upload a synapse micrograph to begin. Adjust calibration before export if your scale differs from the default.</p>"
+                    value="<p class='mm-callout'>Upload a micrograph and tap <strong>Run detection</strong>. Set pixels/µm before exporting if your scale differs.</p>"
                )
                with gr.Tabs():
                    with gr.Tab("Overlay"):
@@ -643,18 +824,21 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
                            label="Detections + scale bar",
                            type="numpy",
                            height=540,
+                            show_fullscreen_button=True,
                        )
                    with gr.Tab("Heatmaps"):
                        heatmap_output = gr.Image(
                            label="Class-specific maps",
                            type="numpy",
                            height=540,
+                            show_fullscreen_button=True,
                        )
-                    with gr.Tab("Quant summary"):
+                    with gr.Tab("Summary"):
                        stats_output = gr.Image(
-                            label="Distributions & table",
+                            label="Counts & distributions",
                            type="numpy",
                            height=440,
+                            show_fullscreen_button=True,
                        )
                    with gr.Tab("Table & export"):
                        table_output = gr.HTML(
@@ -674,8 +858,10 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
            """
        )
 
+        mag_inputs = [viz_state, mag_view, mag_cx, mag_cy, mag_zoom, mag_out]
+
        detect_btn.click(
-            fn=detect_particles,
+            fn=run_detection,
            inputs=[image_input, conf_slider, nms_6nm, nms_12nm, px_per_um_in],
            outputs=[
                overlay_output,
@@ -684,12 +870,52 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
                csv_output,
                table_output,
                summary_md,
+                viz_state,
            ],
-        )
+        ).then(magnifier_zoom, mag_inputs, mag_out_img)
+
+        for _ctrl in (mag_view, mag_cx, mag_cy, mag_zoom, mag_out):
+            _ctrl.change(magnifier_zoom, mag_inputs, mag_out_img)
 
     return app
 
 
+def _running_on_hf_space() -> bool:
+    """Hugging Face Spaces injects these env vars; Gradio must bind 0.0.0.0 and never use share=True."""
+    return bool(
+        os.environ.get("SPACE_REPO_NAME")
+        or os.environ.get("SPACE_AUTHOR_NAME")
+        or os.environ.get("SPACE_ID")
+    )
+
+
+def _resolve_checkpoint(ckpt: Path) -> Path:
+    """Use a local .pth if present; on an HF Space, fetch from the Hub model repo if missing (keeps Space uploads small)."""
+    if ckpt.is_file():
+        return ckpt
+    if _running_on_hf_space():
+        try:
+            from huggingface_hub import hf_hub_download
+        except ImportError as e:
+            raise SystemExit(
+                "huggingface_hub is required on the Space to download weights. "
+                "Add it to requirements.txt or bundle checkpoints/final/final_model.pth in the Space."
+            ) from e
+        repo_id = os.environ.get("MIDASMAP_HF_WEIGHTS_REPO", "AnikS22/MidasMap").strip()
+        filename = os.environ.get(
+            "MIDASMAP_HF_WEIGHTS_FILE", "checkpoints/final/final_model.pth"
+        ).strip()
+        print(f"Checkpoint not found at {ckpt}; downloading {filename} from model repo {repo_id} ...")
+        cached = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
+        return Path(cached)
+    raise SystemExit(
+        f"Checkpoint not found: {ckpt}\n"
+        "Train with train_final.py or download from Hugging Face:\n"
+        "  huggingface-cli download AnikS22/MidasMap checkpoints/final/final_model.pth "
+        "--local-dir . --repo-type model"
+    )
+
+
 def main():
     parser = argparse.ArgumentParser(description="MidasMap web dashboard")
     parser.add_argument(
@@ -712,39 +938,24 @@ def main():
     if os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes"):
         args.share = True
 
-    ckpt = Path(args.checkpoint)
-    if not ckpt.is_file():
-        raise SystemExit(
-            f"Checkpoint not found: {ckpt}\n"
-            "Train with train_final.py or download from Hugging Face:\n"
-            "  huggingface-cli download AnikS22/MidasMap checkpoints/final/final_model.pth "
-            "--local-dir ."
-        )
+    if _running_on_hf_space():
+        args.share = False
+        if not args.server_name:
+            args.server_name = "0.0.0.0"
+
+    ckpt = _resolve_checkpoint(Path(args.checkpoint))
 
     load_model(str(ckpt))
     demo = build_app()
+    port = int(os.environ.get("GRADIO_SERVER_PORT", os.environ.get("PORT", str(args.port))))
     launch_kw = dict(
         share=args.share,
-        server_port=args.port,
+        server_port=port,
         server_name=args.server_name,
         show_api=False,
         inbrowser=False,
     )
-    try:
-        demo.launch(**launch_kw)
-    except ValueError as err:
-        if (
-            "localhost is not accessible" in str(err)
-            and not launch_kw.get("share")
-            and os.environ.get("GRADIO_SHARE", "").lower() not in ("1", "true", "yes")
-        ):
-            print(
-                "Localhost check failed in this environment; starting with share=True "
-                "(Gradio tunnel). Use --share next time, or set GRADIO_SHARE=1."
-            )
-            build_app().launch(**{**launch_kw, "share": True})
-        else:
-            raise
+    demo.launch(**launch_kw)
 
 
 if __name__ == "__main__":
requirements.txt CHANGED
@@ -14,5 +14,7 @@ PyYAML>=6.0
 albumentations>=1.3.0
 opencv-python-headless>=4.7.0
 gradio==4.44.1
+# Avoid Jinja2 3.2+ cache key issues with some Gradio/Starlette stacks on HF Spaces.
+jinja2>=3.1.0,<3.2.0
 huggingface_hub>=0.20.0,<0.25.0
 tqdm>=4.65.0
src/ensemble.py CHANGED
@@ -163,6 +163,22 @@ def ensemble_predict(
     return np.mean(all_heatmaps, axis=0), np.mean(all_offsets, axis=0)


+def _tile_origins(axis_len: int, patch: int, stride_step: int) -> list:
+    """
+    Starting indices for sliding windows along one axis so the last window
+    flush-aligns with the far edge. Plain range(0, n - patch + 1, step) misses
+    the bottom/right of many image sizes (e.g. 2000 with patch 512, step 384),
+    leaving heatmap strips at zero.
+    """
+    if axis_len <= patch:
+        return [0]
+    last = axis_len - patch
+    starts = list(range(0, last + 1, stride_step))
+    if starts[-1] != last:
+        starts.append(last)
+    return starts
+
+
 def sliding_window_inference(
     model: ImmunogoldCenterNet,
     image: np.ndarray,
@@ -188,10 +204,18 @@
     offsets: (2, H/2, W/2) numpy array
     """
     model.eval()
+    orig_h, orig_w = image.shape[:2]
+    # Pad bottom/right so each dim >= patch_size; otherwise range() for tiles is empty
+    # and heatmaps stay all zeros (looks like a "broken" heatmap in the UI).
+    pad_h = max(0, patch_size - orig_h)
+    pad_w = max(0, patch_size - orig_w)
+    if pad_h > 0 or pad_w > 0:
+        image = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")
+
     h, w = image.shape[:2]
     stride_step = patch_size - overlap

-    # Output dimensions at model stride
+    # Output dimensions at model stride (padded image)
     out_h = h // 2
     out_w = w // 2
     out_patch = patch_size // 2
@@ -200,8 +224,8 @@
     offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
     count = np.zeros((out_h, out_w), dtype=np.float32)

-    for y0 in range(0, h - patch_size + 1, stride_step):
-        for x0 in range(0, w - patch_size + 1, stride_step):
+    for y0 in _tile_origins(h, patch_size, stride_step):
+        for x0 in _tile_origins(w, patch_size, stride_step):
             patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]
             tensor = (
                 torch.from_numpy(patch)
@@ -233,4 +257,9 @@
     count = np.maximum(count, 1)
     offsets /= count[np.newaxis, :, :]

+    # Crop back to original (pre-pad) spatial extent in heatmap space
+    crop_h, crop_w = orig_h // 2, orig_w // 2
+    heatmap = heatmap[:, :crop_h, :crop_w]
+    offsets = offsets[:, :crop_h, :crop_w]
+
     return heatmap, offsets
src/model.py CHANGED
@@ -215,6 +215,7 @@ class ImmunogoldCenterNet(nn.Module):
         bifpn_channels: int = 128,
         bifpn_rounds: int = 2,
         num_classes: int = 2,
+        imagenet_encoder_fallback: bool = True,
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -229,13 +230,14 @@
         # Load pretrained weights
         if pretrained_path:
             self._load_pretrained(backbone, pretrained_path)
-        else:
-            # Use ImageNet weights as fallback, adapting conv1
+        elif imagenet_encoder_fallback:
+            # Training: better init when the CEM500K path is missing (downloads ~100MB).
             imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
             state = imagenet_backbone.state_dict()
             # Mean-pool RGB conv1 weights → grayscale
             state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
             backbone.load_state_dict(state, strict=False)
+        # else: random encoder init — use when loading a full checkpoint immediately (Gradio, predict).

         # Extract encoder stages
         self.stem = nn.Sequential(
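The new flag turns the previous two-way init into a three-way decision. A tiny sketch of that ladder (the function name and the example path are hypothetical; the real logic sits inline in `ImmunogoldCenterNet.__init__`):

```python
from typing import Optional

def encoder_init_mode(pretrained_path: Optional[str],
                      imagenet_encoder_fallback: bool = True) -> str:
    """Which encoder initialisation the constructor ends up with after this diff."""
    if pretrained_path:
        return "cem500k"    # explicit pretrained weights always win
    if imagenet_encoder_fallback:
        return "imagenet"   # training default: ResNet-50 with conv1 mean-pooled to grayscale
    return "random"         # inference path: the full checkpoint overwrites the encoder anyway

print(encoder_init_mode("checkpoints/cem500k.pth"))                    # cem500k
print(encoder_init_mode(None))                                         # imagenet
print(encoder_init_mode(None, imagenet_encoder_fallback=False))        # random
```

Skipping the ImageNet download in the third case is what makes the Gradio/predict path start fast on a CPU Space.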