Deploy MidasMap Gradio app; weights downloaded from model repo at runtime
Browse files- README.md +45 -17
- app.py +326 -115
- requirements.txt +2 -0
- src/ensemble.py +32 -3
- src/model.py +4 -2
README.md
CHANGED
|
@@ -12,27 +12,55 @@ license: mit
|
|
| 12 |
|
| 13 |
# MidasMap Space
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
4. In Space **Settings → Repository secrets** (if needed): none required for public weights.
|
| 28 |
-
5. Ensure `checkpoints/final/final_model.pth` is present:
|
| 29 |
-
- Upload via **Files** tab, or
|
| 30 |
-
- Add a startup script to download from `AnikS22/MidasMap` on the Hub (see HF docs for `hf_hub_download`).
|
| 31 |
|
| 32 |
-
|
| 33 |
|
| 34 |
-
`
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# MidasMap Space
|
| 14 |
|
| 15 |
+
Gradio demo for **[MidasMap](https://github.com/AnikS22/MidasMap)** (immunogold particle detection in TEM synapse images).
|
| 16 |
|
| 17 |
+
## Deploy from your laptop
|
| 18 |
|
| 19 |
+
From the **MidasMap** repo root:
|
| 20 |
|
| 21 |
+
```bash
|
| 22 |
+
export HF_TOKEN=hf_... # write token
|
| 23 |
+
# Recommended: do not upload the ~100MB checkpoint into the Space (avoids LFS / size issues).
|
| 24 |
+
export HF_SPACE_SKIP_CHECKPOINT=1
|
| 25 |
+
./scripts/upload_hf_space.sh
|
| 26 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
If **`upload_hf_space.sh` fails**, use **git + LFS** instead (often more reliable):
|
| 29 |
|
| 30 |
+
```bash
|
| 31 |
+
brew install git-lfs && git lfs install # once
|
| 32 |
+
export HF_TOKEN=hf_...
|
| 33 |
+
./scripts/push_hf_space_git.sh
|
| 34 |
+
```
|
| 35 |
|
| 36 |
+
Full options: [docs/DEPLOY.md](../docs/DEPLOY.md) in the main repo.
|
| 37 |
+
|
| 38 |
+
Create the Space once if needed (Gradio SDK required for auto-create):
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
huggingface-cli repo create MidasMap --type space --space_sdk gradio -y
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Weights are loaded from the **model** repo `AnikS22/MidasMap` at `checkpoints/final/final_model.pth` when the file is not in the Space. Override with Space secrets / env: `MIDASMAP_HF_WEIGHTS_REPO`, `MIDASMAP_HF_WEIGHTS_FILE`.
|
| 45 |
+
|
| 46 |
+
To bundle the checkpoint in the Space instead (larger upload):
|
| 47 |
|
| 48 |
+
```bash
|
| 49 |
+
export HF_SPACE_SKIP_CHECKPOINT=0
|
| 50 |
+
./scripts/upload_hf_space.sh
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Troubleshooting uploads
|
| 54 |
+
|
| 55 |
+
| Symptom | What to do |
|
| 56 |
+
|--------|----------------|
|
| 57 |
+
| **401 / not logged in** | `export HF_TOKEN=hf_...` with a token that has **write** access, or `huggingface-cli login`. |
|
| 58 |
+
| **LFS / authorization / upload stuck** | Use `HF_SPACE_SKIP_CHECKPOINT=1` so only code uploads; ensure the **model** repo (not the Space) contains `checkpoints/final/final_model.pth`. |
|
| 59 |
+
| **Space does not exist** | Create it in the HF web UI (**New Space** → **Gradio**) or run `huggingface-cli repo create ... --type space --space_sdk gradio`. |
|
| 60 |
+
| **“No space_sdk provided”** | The Space repo must be created as **Gradio** (or pass `--space_sdk gradio` when using `repo create`). |
|
| 61 |
+
| **Model not found on Space** | First boot downloads weights from the Hub; public repos need no token. Private model repo: add `HF_TOKEN` as a Space **secret** (read). |
|
| 62 |
+
| **Still failing** | Try `pip install hf_transfer` and `export HF_HUB_ENABLE_HF_TRANSFER=1` before upload. Or use **git** + **git lfs** clone of the Space, copy files, commit, push. |
|
| 63 |
+
|
| 64 |
+
## Vercel embed
|
| 65 |
+
|
| 66 |
+
`https://yoursite.vercel.app/?embed=https://huggingface.co/spaces/YOUR_USER/YOUR_SPACE`
|
app.py
CHANGED
|
@@ -91,7 +91,11 @@ def load_model(checkpoint_path: str):
|
|
| 91 |
if torch.backends.mps.is_available()
|
| 92 |
else "cpu"
|
| 93 |
)
|
| 94 |
-
MODEL = ImmunogoldCenterNet(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
| 96 |
MODEL.load_state_dict(ckpt["model_state_dict"])
|
| 97 |
MODEL.to(DEVICE)
|
|
@@ -176,6 +180,94 @@ def _df_to_preview_html(df: pd.DataFrame) -> str:
|
|
| 176 |
)
|
| 177 |
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
def detect_particles(
|
| 180 |
image_file,
|
| 181 |
conf_threshold: float = 0.25,
|
|
@@ -250,8 +342,27 @@ def detect_particles(
|
|
| 250 |
|
| 251 |
from skimage.transform import resize
|
| 252 |
|
| 253 |
-
hm6_up =
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
# --- Overlay (publication-style legend + scale bar) ---
|
| 257 |
fig_overlay, ax = plt.subplots(figsize=(11, 11))
|
|
@@ -291,21 +402,54 @@ def detect_particles(
|
|
| 291 |
overlay_img = np.asarray(fig_overlay.canvas.renderer.buffer_rgba())[:, :, :3]
|
| 292 |
plt.close(fig_overlay)
|
| 293 |
|
| 294 |
-
# --- Heatmaps ---
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
axes[
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
plt.tight_layout()
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
| 308 |
plt.close(fig_hm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
# --- Stats (µm where helpful) ---
|
| 311 |
fig_stats, axes = plt.subplots(1, 3, figsize=(16, 4.8))
|
|
@@ -411,58 +555,73 @@ def detect_particles(
|
|
| 411 |
|
| 412 |
|
| 413 |
MM_CSS = """
|
| 414 |
-
|
|
|
|
| 415 |
.mm-brand-bar {
|
| 416 |
display: flex; align-items: center; justify-content: space-between;
|
| 417 |
-
flex-wrap: wrap; gap: 0.
|
| 418 |
-
padding: 0
|
| 419 |
-
|
| 420 |
-
|
| 421 |
}
|
| 422 |
.mm-brand-bar span {
|
| 423 |
-
font-size: 0.
|
| 424 |
-
color: var(--body-text-color-subdued); font-weight:
|
| 425 |
}
|
| 426 |
.mm-hero {
|
| 427 |
-
padding: 1.
|
| 428 |
-
margin-bottom:
|
| 429 |
-
border-radius:
|
| 430 |
-
background: linear-gradient(
|
| 431 |
-
border: 1px solid
|
|
|
|
| 432 |
}
|
| 433 |
.mm-hero h1 {
|
| 434 |
font-family: "Libre Baskerville", Georgia, serif;
|
| 435 |
font-weight: 700;
|
| 436 |
letter-spacing: -0.02em;
|
| 437 |
-
margin: 0 0 0.
|
| 438 |
-
font-size: 1.
|
| 439 |
-
color: #
|
| 440 |
}
|
| 441 |
.mm-hero .mm-sub {
|
| 442 |
-
margin: 0 0
|
| 443 |
-
color: #
|
| 444 |
-
font-size: 0.
|
| 445 |
-
line-height: 1.
|
| 446 |
-
max-width:
|
| 447 |
}
|
| 448 |
-
.mm-badge-row { display: flex; flex-wrap: wrap; gap: 0.
|
| 449 |
.mm-badge {
|
| 450 |
-
font-size: 0.
|
| 451 |
-
padding: 0.
|
| 452 |
-
background:
|
|
|
|
| 453 |
}
|
| 454 |
-
.mm-layout { display: flex; gap: 1.
|
| 455 |
.mm-sidebar {
|
| 456 |
-
flex: 1 1
|
| 457 |
-
padding:
|
| 458 |
-
border: 1px solid
|
| 459 |
background: var(--block-background-fill);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
}
|
| 461 |
-
.mm-main { flex: 3 1 520px; min-width: 0; }
|
| 462 |
.mm-panel-title {
|
| 463 |
-
font-size: 0.
|
| 464 |
-
color: var(--body-text-color-subdued); font-weight: 600; margin: 0 0 0.
|
| 465 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
.mm-callout {
|
| 467 |
margin: 0; padding: 0.75rem 0.9rem; border-radius: 8px;
|
| 468 |
background: #1e293b66; border: 1px solid var(--border-color-primary);
|
|
@@ -513,21 +672,10 @@ table.mm-table td { padding: 0.35rem 0.5rem; border-bottom: 1px solid #33415544;
|
|
| 513 |
|
| 514 |
|
| 515 |
def build_app():
|
|
|
|
| 516 |
theme = gr.themes.Soft(
|
| 517 |
-
primary_hue=
|
| 518 |
-
|
| 519 |
-
c100="#ccfbf1",
|
| 520 |
-
c200="#99f6e4",
|
| 521 |
-
c300="#5eead4",
|
| 522 |
-
c400="#2dd4bf",
|
| 523 |
-
c500="#14b8a6",
|
| 524 |
-
c600="#0d9488",
|
| 525 |
-
c700="#0f766e",
|
| 526 |
-
c800="#115e59",
|
| 527 |
-
c900="#134e4a",
|
| 528 |
-
c950="#042f2e",
|
| 529 |
-
),
|
| 530 |
-
neutral_hue=gr.themes.colors.slate,
|
| 531 |
font=("Source Sans 3", "ui-sans-serif", "system-ui", "sans-serif"),
|
| 532 |
font_mono=("IBM Plex Mono", "ui-monospace", "monospace"),
|
| 533 |
).set(
|
|
@@ -539,51 +687,47 @@ def build_app():
|
|
| 539 |
block_label_text_size="*text_sm",
|
| 540 |
)
|
| 541 |
|
| 542 |
-
head = """
|
| 543 |
-
<link href="https://fonts.googleapis.com/css2?family=Libre+Baskerville:wght@700&family=Source+Sans+3:wght@400;600;700&display=swap" rel="stylesheet">
|
| 544 |
-
"""
|
| 545 |
-
|
| 546 |
with gr.Blocks(
|
| 547 |
title="MidasMap — Immunogold analysis",
|
| 548 |
theme=theme,
|
| 549 |
css=MM_CSS,
|
| 550 |
-
head=head,
|
| 551 |
) as app:
|
| 552 |
gr.HTML(
|
| 553 |
"""
|
| 554 |
<div class="mm-brand-bar">
|
| 555 |
-
<span>
|
| 556 |
-
<span>
|
| 557 |
</div>
|
| 558 |
<div class="mm-hero">
|
| 559 |
<h1>MidasMap</h1>
|
| 560 |
<p class="mm-sub">
|
| 561 |
-
|
| 562 |
-
<strong>
|
| 563 |
-
|
| 564 |
</p>
|
| 565 |
<div class="mm-badge-row">
|
| 566 |
-
<span class="mm-badge">FFRIL
|
| 567 |
<span class="mm-badge">CenterNet</span>
|
| 568 |
-
<span class="mm-badge">CEM500K
|
| 569 |
-
<span class="mm-badge">
|
| 570 |
</div>
|
| 571 |
</div>
|
| 572 |
"""
|
| 573 |
)
|
| 574 |
|
|
|
|
|
|
|
| 575 |
with gr.Row(elem_classes=["mm-layout"]):
|
| 576 |
with gr.Column(elem_classes=["mm-sidebar"]):
|
| 577 |
-
gr.HTML('<p class="mm-panel-title">
|
| 578 |
image_input = gr.File(
|
| 579 |
-
label="
|
| 580 |
file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"],
|
| 581 |
)
|
| 582 |
px_per_um_in = gr.Number(
|
| 583 |
value=DEFAULT_PX_PER_UM,
|
| 584 |
-
label="
|
| 585 |
-
info=f"Default {DEFAULT_PX_PER_UM:.0f} matches the
|
| 586 |
-
"Update if your acquisition scale differs.",
|
| 587 |
minimum=1,
|
| 588 |
maximum=1e6,
|
| 589 |
)
|
|
@@ -592,28 +736,64 @@ def build_app():
|
|
| 592 |
maximum=0.95,
|
| 593 |
value=0.25,
|
| 594 |
step=0.05,
|
| 595 |
-
label="Confidence
|
| 596 |
-
info="Higher
|
| 597 |
)
|
| 598 |
-
with gr.Accordion("Advanced ·
|
| 599 |
nms_6nm = gr.Slider(
|
| 600 |
minimum=1,
|
| 601 |
maximum=9,
|
| 602 |
value=3,
|
| 603 |
step=2,
|
| 604 |
-
label="
|
| 605 |
-
info="Minimum
|
| 606 |
)
|
| 607 |
nms_12nm = gr.Slider(
|
| 608 |
minimum=1,
|
| 609 |
maximum=9,
|
| 610 |
value=5,
|
| 611 |
step=2,
|
| 612 |
-
label="
|
| 613 |
)
|
| 614 |
detect_btn = gr.Button("Run detection", variant="primary", size="lg")
|
| 615 |
|
| 616 |
-
with gr.Accordion("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
gr.Markdown(
|
| 618 |
"""
|
| 619 |
#### What the model outputs
|
|
@@ -634,8 +814,9 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
|
|
| 634 |
)
|
| 635 |
|
| 636 |
with gr.Column(elem_classes=["mm-main"]):
|
|
|
|
| 637 |
summary_md = gr.HTML(
|
| 638 |
-
value="<p class='mm-callout'>Upload a
|
| 639 |
)
|
| 640 |
with gr.Tabs():
|
| 641 |
with gr.Tab("Overlay"):
|
|
@@ -643,18 +824,21 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
|
|
| 643 |
label="Detections + scale bar",
|
| 644 |
type="numpy",
|
| 645 |
height=540,
|
|
|
|
| 646 |
)
|
| 647 |
with gr.Tab("Heatmaps"):
|
| 648 |
heatmap_output = gr.Image(
|
| 649 |
label="Class-specific maps",
|
| 650 |
type="numpy",
|
| 651 |
height=540,
|
|
|
|
| 652 |
)
|
| 653 |
-
with gr.Tab("
|
| 654 |
stats_output = gr.Image(
|
| 655 |
-
label="
|
| 656 |
type="numpy",
|
| 657 |
height=440,
|
|
|
|
| 658 |
)
|
| 659 |
with gr.Tab("Table & export"):
|
| 660 |
table_output = gr.HTML(
|
|
@@ -674,8 +858,10 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
|
|
| 674 |
"""
|
| 675 |
)
|
| 676 |
|
|
|
|
|
|
|
| 677 |
detect_btn.click(
|
| 678 |
-
fn=
|
| 679 |
inputs=[image_input, conf_slider, nms_6nm, nms_12nm, px_per_um_in],
|
| 680 |
outputs=[
|
| 681 |
overlay_output,
|
|
@@ -684,12 +870,52 @@ Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
|
|
| 684 |
csv_output,
|
| 685 |
table_output,
|
| 686 |
summary_md,
|
|
|
|
| 687 |
],
|
| 688 |
-
)
|
|
|
|
|
|
|
|
|
|
| 689 |
|
| 690 |
return app
|
| 691 |
|
| 692 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
def main():
|
| 694 |
parser = argparse.ArgumentParser(description="MidasMap web dashboard")
|
| 695 |
parser.add_argument(
|
|
@@ -712,39 +938,24 @@ def main():
|
|
| 712 |
if os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes"):
|
| 713 |
args.share = True
|
| 714 |
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
"--local-dir ."
|
| 722 |
-
)
|
| 723 |
|
| 724 |
load_model(str(ckpt))
|
| 725 |
demo = build_app()
|
|
|
|
| 726 |
launch_kw = dict(
|
| 727 |
share=args.share,
|
| 728 |
-
server_port=
|
| 729 |
server_name=args.server_name,
|
| 730 |
show_api=False,
|
| 731 |
inbrowser=False,
|
| 732 |
)
|
| 733 |
-
|
| 734 |
-
demo.launch(**launch_kw)
|
| 735 |
-
except ValueError as err:
|
| 736 |
-
if (
|
| 737 |
-
"localhost is not accessible" in str(err)
|
| 738 |
-
and not launch_kw.get("share")
|
| 739 |
-
and os.environ.get("GRADIO_SHARE", "").lower() not in ("1", "true", "yes")
|
| 740 |
-
):
|
| 741 |
-
print(
|
| 742 |
-
"Localhost check failed in this environment; starting with share=True "
|
| 743 |
-
"(Gradio tunnel). Use --share next time, or set GRADIO_SHARE=1."
|
| 744 |
-
)
|
| 745 |
-
build_app().launch(**{**launch_kw, "share": True})
|
| 746 |
-
else:
|
| 747 |
-
raise
|
| 748 |
|
| 749 |
|
| 750 |
if __name__ == "__main__":
|
|
|
|
| 91 |
if torch.backends.mps.is_available()
|
| 92 |
else "cpu"
|
| 93 |
)
|
| 94 |
+
MODEL = ImmunogoldCenterNet(
|
| 95 |
+
bifpn_channels=128,
|
| 96 |
+
bifpn_rounds=2,
|
| 97 |
+
imagenet_encoder_fallback=False,
|
| 98 |
+
)
|
| 99 |
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
| 100 |
MODEL.load_state_dict(ckpt["model_state_dict"])
|
| 101 |
MODEL.to(DEVICE)
|
|
|
|
| 180 |
)
|
| 181 |
|
| 182 |
|
| 183 |
+
def _numpy_image_to_uint8_rgb(img: np.ndarray) -> np.ndarray:
|
| 184 |
+
"""Normalize various arrays to HxWx3 uint8 for cropping / display."""
|
| 185 |
+
if img is None:
|
| 186 |
+
return None
|
| 187 |
+
arr = np.asarray(img)
|
| 188 |
+
if arr.size == 0:
|
| 189 |
+
return None
|
| 190 |
+
if arr.ndim == 2:
|
| 191 |
+
arr = np.stack([arr, arr, arr], axis=-1)
|
| 192 |
+
elif arr.ndim == 3 and arr.shape[2] == 4:
|
| 193 |
+
arr = arr[:, :, :3]
|
| 194 |
+
if arr.dtype in (np.float32, np.float64):
|
| 195 |
+
mx = float(arr.max()) if arr.size else 1.0
|
| 196 |
+
if mx <= 1.0:
|
| 197 |
+
arr = (np.clip(arr, 0, 1) * 255.0).astype(np.uint8)
|
| 198 |
+
else:
|
| 199 |
+
arr = np.clip(arr, 0, 255).astype(np.uint8)
|
| 200 |
+
else:
|
| 201 |
+
arr = np.clip(arr, 0, 255).astype(np.uint8)
|
| 202 |
+
return arr
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def magnifier_zoom(
|
| 206 |
+
store: dict,
|
| 207 |
+
view: str,
|
| 208 |
+
center_x_pct: float,
|
| 209 |
+
center_y_pct: float,
|
| 210 |
+
zoom: float,
|
| 211 |
+
output_px: int,
|
| 212 |
+
) -> np.ndarray | None:
|
| 213 |
+
"""
|
| 214 |
+
Crop a square region around (center_x_pct, center_y_pct) and upscale for a loupe view.
|
| 215 |
+
zoom: 1 = see ~full width in loupe; larger = stronger magnification (smaller crop).
|
| 216 |
+
"""
|
| 217 |
+
if not store or not isinstance(store, dict):
|
| 218 |
+
return None
|
| 219 |
+
key = {"Overlay": "overlay", "Heatmaps": "heatmap", "Summary": "stats"}.get(view, "overlay")
|
| 220 |
+
img = _numpy_image_to_uint8_rgb(store.get(key))
|
| 221 |
+
if img is None:
|
| 222 |
+
return None
|
| 223 |
+
h, w = img.shape[:2]
|
| 224 |
+
cx = int(np.clip(center_x_pct / 100.0 * (w - 1), 0, w - 1))
|
| 225 |
+
cy = int(np.clip(center_y_pct / 100.0 * (h - 1), 0, h - 1))
|
| 226 |
+
z = max(1.0, float(zoom))
|
| 227 |
+
half_w = max(1, int(w / (2.0 * z)))
|
| 228 |
+
half_h = max(1, int(h / (2.0 * z)))
|
| 229 |
+
x0, x1 = max(0, cx - half_w), min(w, cx + half_w)
|
| 230 |
+
y0, y1 = max(0, cy - half_h), min(h, cy + half_h)
|
| 231 |
+
if x1 <= x0 or y1 <= y0:
|
| 232 |
+
crop = img
|
| 233 |
+
else:
|
| 234 |
+
crop = img[y0:y1, x0:x1]
|
| 235 |
+
side = int(np.clip(output_px, 256, 1024))
|
| 236 |
+
try:
|
| 237 |
+
from PIL import Image as PILImage
|
| 238 |
+
|
| 239 |
+
pil = PILImage.fromarray(crop)
|
| 240 |
+
pil = pil.resize((side, side), PILImage.Resampling.LANCZOS)
|
| 241 |
+
return np.asarray(pil)
|
| 242 |
+
except Exception:
|
| 243 |
+
from skimage.transform import resize
|
| 244 |
+
|
| 245 |
+
up = resize(crop, (side, side), order=1, preserve_range=True)
|
| 246 |
+
return np.clip(up, 0, 255).astype(np.uint8)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def run_detection(
|
| 250 |
+
image_file,
|
| 251 |
+
conf_threshold: float,
|
| 252 |
+
nms_6nm: int,
|
| 253 |
+
nms_12nm: int,
|
| 254 |
+
px_per_um: float,
|
| 255 |
+
progress=gr.Progress(track_tqdm=False),
|
| 256 |
+
):
|
| 257 |
+
"""Run model and return outputs plus viz state for the magnifier."""
|
| 258 |
+
out = detect_particles(
|
| 259 |
+
image_file,
|
| 260 |
+
conf_threshold,
|
| 261 |
+
nms_6nm,
|
| 262 |
+
nms_12nm,
|
| 263 |
+
px_per_um,
|
| 264 |
+
progress=progress,
|
| 265 |
+
)
|
| 266 |
+
overlay, hm, stats, csvp, table, summary = out
|
| 267 |
+
store = {"overlay": overlay, "heatmap": hm, "stats": stats}
|
| 268 |
+
return overlay, hm, stats, csvp, table, summary, store
|
| 269 |
+
|
| 270 |
+
|
| 271 |
def detect_particles(
|
| 272 |
image_file,
|
| 273 |
conf_threshold: float = 0.25,
|
|
|
|
| 342 |
|
| 343 |
from skimage.transform import resize
|
| 344 |
|
| 345 |
+
hm6_up = np.clip(
|
| 346 |
+
np.nan_to_num(resize(hm_np[0], (h, w), order=1), nan=0.0),
|
| 347 |
+
0.0,
|
| 348 |
+
1.0,
|
| 349 |
+
)
|
| 350 |
+
hm12_up = np.clip(
|
| 351 |
+
np.nan_to_num(resize(hm_np[1], (h, w), order=1), nan=0.0),
|
| 352 |
+
0.0,
|
| 353 |
+
1.0,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def _heatmap_vmax(hm: np.ndarray) -> float:
|
| 357 |
+
"""Stable color scale: avoid invisible overlays when max is tiny or flat."""
|
| 358 |
+
flat = hm.ravel()
|
| 359 |
+
if flat.size == 0:
|
| 360 |
+
return 0.3
|
| 361 |
+
mx = float(np.max(flat))
|
| 362 |
+
if mx < 1e-6:
|
| 363 |
+
return 0.3
|
| 364 |
+
p99 = float(np.percentile(flat, 99.0))
|
| 365 |
+
return float(np.clip(max(0.12, p99 * 1.05, mx * 0.95), 0.05, 1.0))
|
| 366 |
|
| 367 |
# --- Overlay (publication-style legend + scale bar) ---
|
| 368 |
fig_overlay, ax = plt.subplots(figsize=(11, 11))
|
|
|
|
| 402 |
overlay_img = np.asarray(fig_overlay.canvas.renderer.buffer_rgba())[:, :, :3]
|
| 403 |
plt.close(fig_overlay)
|
| 404 |
|
| 405 |
+
# --- Heatmaps: row1 = overlay on EM; row2 = model heat only (debug-friendly) ---
|
| 406 |
+
# Training uses Gaussian GT; inference heatmaps are learned sigmoid blobs, not analytic Gaussians.
|
| 407 |
+
v6, v12 = _heatmap_vmax(hm6_up), _heatmap_vmax(hm12_up)
|
| 408 |
+
fig_hm, axes = plt.subplots(2, 2, figsize=(14, 12))
|
| 409 |
+
ax00, ax01 = axes[0]
|
| 410 |
+
ax10, ax11 = axes[1]
|
| 411 |
+
|
| 412 |
+
for ax, hm, v, cmap, title in (
|
| 413 |
+
(ax00, hm6_up, v6, "magma", f"AMPA overlay · n={n_6nm} · vmax={v6:.2f}"),
|
| 414 |
+
(ax01, hm12_up, v12, "inferno", f"NR1 overlay · n={n_12nm} · vmax={v12:.2f}"),
|
| 415 |
+
):
|
| 416 |
+
ax.imshow(img, cmap="gray", aspect="equal", interpolation="nearest")
|
| 417 |
+
ax.imshow(
|
| 418 |
+
hm,
|
| 419 |
+
cmap=cmap,
|
| 420 |
+
alpha=0.6,
|
| 421 |
+
vmin=0.0,
|
| 422 |
+
vmax=v,
|
| 423 |
+
interpolation="bilinear",
|
| 424 |
+
)
|
| 425 |
+
ax.set_title(title, fontsize=10)
|
| 426 |
+
ax.axis("off")
|
| 427 |
+
|
| 428 |
+
ax10.imshow(hm6_up, cmap="magma", vmin=0.0, vmax=v6, interpolation="nearest")
|
| 429 |
+
ax10.set_title(f"AMPA heatmap only · max={float(np.max(hm6_up)):.4f}", fontsize=10)
|
| 430 |
+
ax10.axis("off")
|
| 431 |
+
|
| 432 |
+
ax11.imshow(hm12_up, cmap="inferno", vmin=0.0, vmax=v12, interpolation="nearest")
|
| 433 |
+
ax11.set_title(f"NR1 heatmap only · max={float(np.max(hm12_up)):.4f}", fontsize=10)
|
| 434 |
+
ax11.axis("off")
|
| 435 |
+
|
| 436 |
plt.tight_layout()
|
| 437 |
+
# PNG raster → uint8 RGB (reliable in Gradio vs raw canvas buffer on some setups)
|
| 438 |
+
from io import BytesIO
|
| 439 |
+
|
| 440 |
+
_buf = BytesIO()
|
| 441 |
+
fig_hm.savefig(_buf, format="png", dpi=120, bbox_inches="tight", facecolor="white")
|
| 442 |
plt.close(fig_hm)
|
| 443 |
+
_buf.seek(0)
|
| 444 |
+
try:
|
| 445 |
+
from PIL import Image as _PILImage
|
| 446 |
+
|
| 447 |
+
heatmap_img = np.asarray(_PILImage.open(_buf).convert("RGB"))
|
| 448 |
+
except Exception:
|
| 449 |
+
import matplotlib.image as _mimg
|
| 450 |
+
|
| 451 |
+
_buf.seek(0)
|
| 452 |
+
heatmap_img = (_mimg.imread(_buf)[:, :, :3] * 255.0).clip(0, 255).astype(np.uint8)
|
| 453 |
|
| 454 |
# --- Stats (µm where helpful) ---
|
| 455 |
fig_stats, axes = plt.subplots(1, 3, figsize=(16, 4.8))
|
|
|
|
| 555 |
|
| 556 |
|
| 557 |
MM_CSS = """
|
| 558 |
+
@import url("https://fonts.googleapis.com/css2?family=Libre+Baskerville:wght@700&family=Source+Sans+3:wght@400;600;700&display=swap");
|
| 559 |
+
.gradio-container { max-width: 1280px !important; margin: auto !important; padding: 1rem 0.75rem 2rem !important; }
|
| 560 |
.mm-brand-bar {
|
| 561 |
display: flex; align-items: center; justify-content: space-between;
|
| 562 |
+
flex-wrap: wrap; gap: 0.5rem 1rem;
|
| 563 |
+
padding: 0 0 1rem;
|
| 564 |
+
margin-bottom: 1rem;
|
| 565 |
+
border-bottom: 1px solid rgba(148, 163, 184, 0.2);
|
| 566 |
}
|
| 567 |
.mm-brand-bar span {
|
| 568 |
+
font-size: 0.7rem; letter-spacing: 0.06em;
|
| 569 |
+
color: var(--body-text-color-subdued); font-weight: 500;
|
| 570 |
}
|
| 571 |
.mm-hero {
|
| 572 |
+
padding: 1.35rem 1.5rem;
|
| 573 |
+
margin-bottom: 1.25rem;
|
| 574 |
+
border-radius: 16px;
|
| 575 |
+
background: linear-gradient(155deg, rgba(13, 148, 136, 0.12) 0%, rgba(15, 23, 42, 0.95) 42%, rgba(30, 27, 75, 0.15) 100%);
|
| 576 |
+
border: 1px solid rgba(148, 163, 184, 0.15);
|
| 577 |
+
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
|
| 578 |
}
|
| 579 |
.mm-hero h1 {
|
| 580 |
font-family: "Libre Baskerville", Georgia, serif;
|
| 581 |
font-weight: 700;
|
| 582 |
letter-spacing: -0.02em;
|
| 583 |
+
margin: 0 0 0.5rem 0;
|
| 584 |
+
font-size: 1.75rem;
|
| 585 |
+
color: #f8fafc;
|
| 586 |
}
|
| 587 |
.mm-hero .mm-sub {
|
| 588 |
+
margin: 0 0 1rem 0;
|
| 589 |
+
color: #cbd5e1;
|
| 590 |
+
font-size: 0.95rem;
|
| 591 |
+
line-height: 1.6;
|
| 592 |
+
max-width: 62ch;
|
| 593 |
}
|
| 594 |
+
.mm-badge-row { display: flex; flex-wrap: wrap; gap: 0.45rem; }
|
| 595 |
.mm-badge {
|
| 596 |
+
font-size: 0.62rem; letter-spacing: 0.05em; font-weight: 600;
|
| 597 |
+
padding: 0.28rem 0.55rem; border-radius: 999px;
|
| 598 |
+
background: rgba(45, 212, 191, 0.12); color: #5eead4;
|
| 599 |
+
border: 1px solid rgba(45, 212, 191, 0.25);
|
| 600 |
}
|
| 601 |
+
.mm-layout { display: flex; gap: 1.5rem; align-items: flex-start; flex-wrap: wrap; }
|
| 602 |
.mm-sidebar {
|
| 603 |
+
flex: 1 1 300px; max-width: 360px;
|
| 604 |
+
padding: 1.25rem 1.35rem; border-radius: 16px;
|
| 605 |
+
border: 1px solid rgba(148, 163, 184, 0.12);
|
| 606 |
background: var(--block-background-fill);
|
| 607 |
+
box-shadow: 0 4px 24px rgba(0, 0, 0, 0.12);
|
| 608 |
+
}
|
| 609 |
+
.mm-main {
|
| 610 |
+
flex: 1 1 480px; min-width: 0;
|
| 611 |
+
padding: 0.25rem 0.15rem;
|
| 612 |
+
border-radius: 16px;
|
| 613 |
}
|
|
|
|
| 614 |
.mm-panel-title {
|
| 615 |
+
font-size: 0.72rem; text-transform: uppercase; letter-spacing: 0.08em;
|
| 616 |
+
color: var(--body-text-color-subdued); font-weight: 600; margin: 0 0 0.75rem 0;
|
| 617 |
}
|
| 618 |
+
.mm-loupe-help {
|
| 619 |
+
font-size: 0.82rem; line-height: 1.45; color: var(--body-text-color-subdued);
|
| 620 |
+
margin: 0 0 0.75rem 0; padding: 0.65rem 0.85rem;
|
| 621 |
+
border-radius: 10px; background: rgba(30, 41, 59, 0.45);
|
| 622 |
+
border: 1px solid rgba(148, 163, 184, 0.12);
|
| 623 |
+
}
|
| 624 |
+
.tabs > .tab-nav button { font-weight: 500 !important; letter-spacing: 0.01em; }
|
| 625 |
.mm-callout {
|
| 626 |
margin: 0; padding: 0.75rem 0.9rem; border-radius: 8px;
|
| 627 |
background: #1e293b66; border: 1px solid var(--border-color-primary);
|
|
|
|
| 672 |
|
| 673 |
|
| 674 |
def build_app():
|
| 675 |
+
# Use named hues only (no custom Color dicts): avoids Gradio/Jinja template bugs on some stacks (e.g. HF Spaces + Py3.13).
|
| 676 |
theme = gr.themes.Soft(
|
| 677 |
+
primary_hue="teal",
|
| 678 |
+
neutral_hue="slate",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
font=("Source Sans 3", "ui-sans-serif", "system-ui", "sans-serif"),
|
| 680 |
font_mono=("IBM Plex Mono", "ui-monospace", "monospace"),
|
| 681 |
).set(
|
|
|
|
| 687 |
block_label_text_size="*text_sm",
|
| 688 |
)
|
| 689 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
with gr.Blocks(
|
| 691 |
title="MidasMap — Immunogold analysis",
|
| 692 |
theme=theme,
|
| 693 |
css=MM_CSS,
|
|
|
|
| 694 |
) as app:
|
| 695 |
gr.HTML(
|
| 696 |
"""
|
| 697 |
<div class="mm-brand-bar">
|
| 698 |
+
<span>MidasMap · immunogold on TEM synapses</span>
|
| 699 |
+
<span>For research — verify important counts by eye</span>
|
| 700 |
</div>
|
| 701 |
<div class="mm-hero">
|
| 702 |
<h1>MidasMap</h1>
|
| 703 |
<p class="mm-sub">
|
| 704 |
+
Find <strong>6 nm</strong> (AMPA) and <strong>12 nm</strong> (NR1) gold particles in
|
| 705 |
+
<strong>FFRIL</strong> micrographs. Set <strong>calibration</strong> so exports are in µm.
|
| 706 |
+
Use the <strong>magnifying glass</strong> below to inspect beads and heatmaps up close.
|
| 707 |
</p>
|
| 708 |
<div class="mm-badge-row">
|
| 709 |
+
<span class="mm-badge">FFRIL</span>
|
| 710 |
<span class="mm-badge">CenterNet</span>
|
| 711 |
+
<span class="mm-badge">CEM500K</span>
|
| 712 |
+
<span class="mm-badge">F1 ≈ 0.94 LOOCV</span>
|
| 713 |
</div>
|
| 714 |
</div>
|
| 715 |
"""
|
| 716 |
)
|
| 717 |
|
| 718 |
+
viz_state = gr.State({"overlay": None, "heatmap": None, "stats": None})
|
| 719 |
+
|
| 720 |
with gr.Row(elem_classes=["mm-layout"]):
|
| 721 |
with gr.Column(elem_classes=["mm-sidebar"]):
|
| 722 |
+
gr.HTML('<p class="mm-panel-title">1 · Upload & settings</p>')
|
| 723 |
image_input = gr.File(
|
| 724 |
+
label="Micrograph",
|
| 725 |
file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"],
|
| 726 |
)
|
| 727 |
px_per_um_in = gr.Number(
|
| 728 |
value=DEFAULT_PX_PER_UM,
|
| 729 |
+
label="Pixels per µm",
|
| 730 |
+
info=f"Default {DEFAULT_PX_PER_UM:.0f} matches the training corpus. Change if your scale differs.",
|
|
|
|
| 731 |
minimum=1,
|
| 732 |
maximum=1e6,
|
| 733 |
)
|
|
|
|
| 736 |
maximum=0.95,
|
| 737 |
value=0.25,
|
| 738 |
step=0.05,
|
| 739 |
+
label="Confidence",
|
| 740 |
+
info="Higher = stricter (fewer hits). Lower = more sensitive.",
|
| 741 |
)
|
| 742 |
+
with gr.Accordion("Advanced · peak spacing (NMS)", open=False):
|
| 743 |
nms_6nm = gr.Slider(
|
| 744 |
minimum=1,
|
| 745 |
maximum=9,
|
| 746 |
value=3,
|
| 747 |
step=2,
|
| 748 |
+
label="Spacing · 6 nm channel",
|
| 749 |
+
info="Minimum gap between AMPA peaks on the model grid.",
|
| 750 |
)
|
| 751 |
nms_12nm = gr.Slider(
|
| 752 |
minimum=1,
|
| 753 |
maximum=9,
|
| 754 |
value=5,
|
| 755 |
step=2,
|
| 756 |
+
label="Spacing · 12 nm channel",
|
| 757 |
)
|
| 758 |
detect_btn = gr.Button("Run detection", variant="primary", size="lg")
|
| 759 |
|
| 760 |
+
with gr.Accordion("Magnifying glass", open=True):
|
| 761 |
+
gr.HTML(
|
| 762 |
+
"""<p class="mm-loupe-help" style="margin-top:0">
|
| 763 |
+
After you run detection, pick which result to inspect and adjust the sliders.
|
| 764 |
+
<strong>Magnification</strong> zooms in (smaller crop, upscaled). Use the fullscreen icon on any image for a larger view.
|
| 765 |
+
</p>"""
|
| 766 |
+
)
|
| 767 |
+
mag_view = gr.Radio(
|
| 768 |
+
choices=["Overlay", "Heatmaps", "Summary"],
|
| 769 |
+
value="Overlay",
|
| 770 |
+
label="Source image",
|
| 771 |
+
)
|
| 772 |
+
mag_cx = gr.Slider(
|
| 773 |
+
0, 100, value=50, step=0.5,
|
| 774 |
+
label="Pan left ↔ right (%)",
|
| 775 |
+
)
|
| 776 |
+
mag_cy = gr.Slider(
|
| 777 |
+
0, 100, value=50, step=0.5,
|
| 778 |
+
label="Pan up ↔ down (%)",
|
| 779 |
+
)
|
| 780 |
+
mag_zoom = gr.Slider(
|
| 781 |
+
1, 10, value=2.5, step=0.25,
|
| 782 |
+
label="Magnification",
|
| 783 |
+
info="Higher = stronger zoom (smaller region).",
|
| 784 |
+
)
|
| 785 |
+
mag_out = gr.Slider(
|
| 786 |
+
256, 768, value=480, step=64,
|
| 787 |
+
label="Loupe window (px)",
|
| 788 |
+
)
|
| 789 |
+
mag_out_img = gr.Image(
|
| 790 |
+
label="Loupe preview",
|
| 791 |
+
type="numpy",
|
| 792 |
+
height=380,
|
| 793 |
+
show_fullscreen_button=True,
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
with gr.Accordion("Notes for scientists", open=False):
|
| 797 |
gr.Markdown(
|
| 798 |
"""
|
| 799 |
#### What the model outputs
|
|
|
|
| 814 |
)
|
| 815 |
|
| 816 |
with gr.Column(elem_classes=["mm-main"]):
|
| 817 |
+
gr.HTML('<p class="mm-panel-title">2 · Results</p>')
|
| 818 |
summary_md = gr.HTML(
|
| 819 |
+
value="<p class='mm-callout'>Upload a micrograph and tap <strong>Run detection</strong>. Set pixels/µm before exporting if your scale differs.</p>"
|
| 820 |
)
|
| 821 |
with gr.Tabs():
|
| 822 |
with gr.Tab("Overlay"):
|
|
|
|
| 824 |
label="Detections + scale bar",
|
| 825 |
type="numpy",
|
| 826 |
height=540,
|
| 827 |
+
show_fullscreen_button=True,
|
| 828 |
)
|
| 829 |
with gr.Tab("Heatmaps"):
|
| 830 |
heatmap_output = gr.Image(
|
| 831 |
label="Class-specific maps",
|
| 832 |
type="numpy",
|
| 833 |
height=540,
|
| 834 |
+
show_fullscreen_button=True,
|
| 835 |
)
|
| 836 |
+
with gr.Tab("Summary"):
|
| 837 |
stats_output = gr.Image(
|
| 838 |
+
label="Counts & distributions",
|
| 839 |
type="numpy",
|
| 840 |
height=440,
|
| 841 |
+
show_fullscreen_button=True,
|
| 842 |
)
|
| 843 |
with gr.Tab("Table & export"):
|
| 844 |
table_output = gr.HTML(
|
|
|
|
| 858 |
"""
|
| 859 |
)
|
| 860 |
|
| 861 |
+
mag_inputs = [viz_state, mag_view, mag_cx, mag_cy, mag_zoom, mag_out]
|
| 862 |
+
|
| 863 |
detect_btn.click(
|
| 864 |
+
fn=run_detection,
|
| 865 |
inputs=[image_input, conf_slider, nms_6nm, nms_12nm, px_per_um_in],
|
| 866 |
outputs=[
|
| 867 |
overlay_output,
|
|
|
|
| 870 |
csv_output,
|
| 871 |
table_output,
|
| 872 |
summary_md,
|
| 873 |
+
viz_state,
|
| 874 |
],
|
| 875 |
+
).then(magnifier_zoom, mag_inputs, mag_out_img)
|
| 876 |
+
|
| 877 |
+
for _ctrl in (mag_view, mag_cx, mag_cy, mag_zoom, mag_out):
|
| 878 |
+
_ctrl.change(magnifier_zoom, mag_inputs, mag_out_img)
|
| 879 |
|
| 880 |
return app
|
| 881 |
|
| 882 |
|
| 883 |
+
def _running_on_hf_space() -> bool:
|
| 884 |
+
"""Hugging Face Spaces injects these env vars; Gradio must bind 0.0.0.0 and never use share=True."""
|
| 885 |
+
return bool(
|
| 886 |
+
os.environ.get("SPACE_REPO_NAME")
|
| 887 |
+
or os.environ.get("SPACE_AUTHOR_NAME")
|
| 888 |
+
or os.environ.get("SPACE_ID")
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def _resolve_checkpoint(ckpt: Path) -> Path:
|
| 893 |
+
"""Use local .pth if present; on HF Space fetch from the Hub model repo if missing (smaller Space uploads)."""
|
| 894 |
+
if ckpt.is_file():
|
| 895 |
+
return ckpt
|
| 896 |
+
if _running_on_hf_space():
|
| 897 |
+
try:
|
| 898 |
+
from huggingface_hub import hf_hub_download
|
| 899 |
+
except ImportError as e:
|
| 900 |
+
raise SystemExit(
|
| 901 |
+
"huggingface_hub is required on the Space to download weights. "
|
| 902 |
+
"Add it to requirements.txt or bundle checkpoints/final/final_model.pth in the Space."
|
| 903 |
+
) from e
|
| 904 |
+
repo_id = os.environ.get("MIDASMAP_HF_WEIGHTS_REPO", "AnikS22/MidasMap").strip()
|
| 905 |
+
filename = os.environ.get(
|
| 906 |
+
"MIDASMAP_HF_WEIGHTS_FILE", "checkpoints/final/final_model.pth"
|
| 907 |
+
).strip()
|
| 908 |
+
print(f"Checkpoint not found at {ckpt}; downloading {filename} from model repo {repo_id} ...")
|
| 909 |
+
cached = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
|
| 910 |
+
return Path(cached)
|
| 911 |
+
raise SystemExit(
|
| 912 |
+
f"Checkpoint not found: {ckpt}\n"
|
| 913 |
+
"Train with train_final.py or download from Hugging Face:\n"
|
| 914 |
+
" huggingface-cli download AnikS22/MidasMap checkpoints/final/final_model.pth "
|
| 915 |
+
"--local-dir . --repo-type model"
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
|
| 919 |
def main():
|
| 920 |
parser = argparse.ArgumentParser(description="MidasMap web dashboard")
|
| 921 |
parser.add_argument(
|
|
|
|
| 938 |
if os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes"):
|
| 939 |
args.share = True
|
| 940 |
|
| 941 |
+
if _running_on_hf_space():
|
| 942 |
+
args.share = False
|
| 943 |
+
if not args.server_name:
|
| 944 |
+
args.server_name = "0.0.0.0"
|
| 945 |
+
|
| 946 |
+
ckpt = _resolve_checkpoint(Path(args.checkpoint))
|
|
|
|
|
|
|
| 947 |
|
| 948 |
load_model(str(ckpt))
|
| 949 |
demo = build_app()
|
| 950 |
+
port = int(os.environ.get("GRADIO_SERVER_PORT", os.environ.get("PORT", str(args.port))))
|
| 951 |
launch_kw = dict(
|
| 952 |
share=args.share,
|
| 953 |
+
server_port=port,
|
| 954 |
server_name=args.server_name,
|
| 955 |
show_api=False,
|
| 956 |
inbrowser=False,
|
| 957 |
)
|
| 958 |
+
demo.launch(**launch_kw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
|
| 960 |
|
| 961 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
|
@@ -14,5 +14,7 @@ PyYAML>=6.0
|
|
| 14 |
albumentations>=1.3.0
|
| 15 |
opencv-python-headless>=4.7.0
|
| 16 |
gradio==4.44.1
|
|
|
|
|
|
|
| 17 |
huggingface_hub>=0.20.0,<0.25.0
|
| 18 |
tqdm>=4.65.0
|
|
|
|
| 14 |
albumentations>=1.3.0
|
| 15 |
opencv-python-headless>=4.7.0
|
| 16 |
gradio==4.44.1
|
| 17 |
+
# Avoid Jinja2 3.2+ cache key issues with some Gradio/Starlette stacks on HF Spaces.
|
| 18 |
+
jinja2>=3.1.0,<3.2.0
|
| 19 |
huggingface_hub>=0.20.0,<0.25.0
|
| 20 |
tqdm>=4.65.0
|
src/ensemble.py
CHANGED
|
@@ -163,6 +163,22 @@ def ensemble_predict(
|
|
| 163 |
return np.mean(all_heatmaps, axis=0), np.mean(all_offsets, axis=0)
|
| 164 |
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
def sliding_window_inference(
|
| 167 |
model: ImmunogoldCenterNet,
|
| 168 |
image: np.ndarray,
|
|
@@ -188,10 +204,18 @@ def sliding_window_inference(
|
|
| 188 |
offsets: (2, H/2, W/2) numpy array
|
| 189 |
"""
|
| 190 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
h, w = image.shape[:2]
|
| 192 |
stride_step = patch_size - overlap
|
| 193 |
|
| 194 |
-
# Output dimensions at model stride
|
| 195 |
out_h = h // 2
|
| 196 |
out_w = w // 2
|
| 197 |
out_patch = patch_size // 2
|
|
@@ -200,8 +224,8 @@ def sliding_window_inference(
|
|
| 200 |
offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
|
| 201 |
count = np.zeros((out_h, out_w), dtype=np.float32)
|
| 202 |
|
| 203 |
-
for y0 in
|
| 204 |
-
for x0 in
|
| 205 |
patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]
|
| 206 |
tensor = (
|
| 207 |
torch.from_numpy(patch)
|
|
@@ -233,4 +257,9 @@ def sliding_window_inference(
|
|
| 233 |
count = np.maximum(count, 1)
|
| 234 |
offsets /= count[np.newaxis, :, :]
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
return heatmap, offsets
|
|
|
|
| 163 |
return np.mean(all_heatmaps, axis=0), np.mean(all_offsets, axis=0)
|
| 164 |
|
| 165 |
|
| 166 |
+
def _tile_origins(axis_len: int, patch: int, stride_step: int) -> list:
|
| 167 |
+
"""
|
| 168 |
+
Starting indices for sliding windows along one axis so the last window
|
| 169 |
+
flush-aligns with the far edge. Plain range(0, n-patch+1, step) misses
|
| 170 |
+
the bottom/right of most image sizes (e.g. 2048 with patch 512, step 384),
|
| 171 |
+
leaving heatmap strips at zero.
|
| 172 |
+
"""
|
| 173 |
+
if axis_len <= patch:
|
| 174 |
+
return [0]
|
| 175 |
+
last = axis_len - patch
|
| 176 |
+
starts = list(range(0, last + 1, stride_step))
|
| 177 |
+
if starts[-1] != last:
|
| 178 |
+
starts.append(last)
|
| 179 |
+
return starts
|
| 180 |
+
|
| 181 |
+
|
| 182 |
def sliding_window_inference(
|
| 183 |
model: ImmunogoldCenterNet,
|
| 184 |
image: np.ndarray,
|
|
|
|
| 204 |
offsets: (2, H/2, W/2) numpy array
|
| 205 |
"""
|
| 206 |
model.eval()
|
| 207 |
+
orig_h, orig_w = image.shape[:2]
|
| 208 |
+
# Pad bottom/right so each dim >= patch_size; otherwise range() for tiles is empty
|
| 209 |
+
# and heatmaps stay all zeros (looks like a "broken" heatmap in the UI).
|
| 210 |
+
pad_h = max(0, patch_size - orig_h)
|
| 211 |
+
pad_w = max(0, patch_size - orig_w)
|
| 212 |
+
if pad_h > 0 or pad_w > 0:
|
| 213 |
+
image = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")
|
| 214 |
+
|
| 215 |
h, w = image.shape[:2]
|
| 216 |
stride_step = patch_size - overlap
|
| 217 |
|
| 218 |
+
# Output dimensions at model stride (padded image)
|
| 219 |
out_h = h // 2
|
| 220 |
out_w = w // 2
|
| 221 |
out_patch = patch_size // 2
|
|
|
|
| 224 |
offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
|
| 225 |
count = np.zeros((out_h, out_w), dtype=np.float32)
|
| 226 |
|
| 227 |
+
for y0 in _tile_origins(h, patch_size, stride_step):
|
| 228 |
+
for x0 in _tile_origins(w, patch_size, stride_step):
|
| 229 |
patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]
|
| 230 |
tensor = (
|
| 231 |
torch.from_numpy(patch)
|
|
|
|
| 257 |
count = np.maximum(count, 1)
|
| 258 |
offsets /= count[np.newaxis, :, :]
|
| 259 |
|
| 260 |
+
# Crop back to original (pre-pad) spatial extent in heatmap space
|
| 261 |
+
crop_h, crop_w = orig_h // 2, orig_w // 2
|
| 262 |
+
heatmap = heatmap[:, :crop_h, :crop_w]
|
| 263 |
+
offsets = offsets[:, :crop_h, :crop_w]
|
| 264 |
+
|
| 265 |
return heatmap, offsets
|
src/model.py
CHANGED
|
@@ -215,6 +215,7 @@ class ImmunogoldCenterNet(nn.Module):
|
|
| 215 |
bifpn_channels: int = 128,
|
| 216 |
bifpn_rounds: int = 2,
|
| 217 |
num_classes: int = 2,
|
|
|
|
| 218 |
):
|
| 219 |
super().__init__()
|
| 220 |
self.num_classes = num_classes
|
|
@@ -229,13 +230,14 @@ class ImmunogoldCenterNet(nn.Module):
|
|
| 229 |
# Load pretrained weights
|
| 230 |
if pretrained_path:
|
| 231 |
self._load_pretrained(backbone, pretrained_path)
|
| 232 |
-
|
| 233 |
-
#
|
| 234 |
imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
|
| 235 |
state = imagenet_backbone.state_dict()
|
| 236 |
# Mean-pool RGB conv1 weights → grayscale
|
| 237 |
state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
|
| 238 |
backbone.load_state_dict(state, strict=False)
|
|
|
|
| 239 |
|
| 240 |
# Extract encoder stages
|
| 241 |
self.stem = nn.Sequential(
|
|
|
|
| 215 |
bifpn_channels: int = 128,
|
| 216 |
bifpn_rounds: int = 2,
|
| 217 |
num_classes: int = 2,
|
| 218 |
+
imagenet_encoder_fallback: bool = True,
|
| 219 |
):
|
| 220 |
super().__init__()
|
| 221 |
self.num_classes = num_classes
|
|
|
|
| 230 |
# Load pretrained weights
|
| 231 |
if pretrained_path:
|
| 232 |
self._load_pretrained(backbone, pretrained_path)
|
| 233 |
+
elif imagenet_encoder_fallback:
|
| 234 |
+
# Training: better init when CEM500K path is missing (downloads ~100MB).
|
| 235 |
imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
|
| 236 |
state = imagenet_backbone.state_dict()
|
| 237 |
# Mean-pool RGB conv1 weights → grayscale
|
| 238 |
state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
|
| 239 |
backbone.load_state_dict(state, strict=False)
|
| 240 |
+
# else: random encoder init — use when loading a full checkpoint immediately (Gradio, predict).
|
| 241 |
|
| 242 |
# Extract encoder stages
|
| 243 |
self.stem = nn.Sequential(
|