Upload 5 files

- README.md +55 -11
- app.py +240 -0
- huggingface.yaml +3 -0
- inference.py +120 -0
- requirements.txt +10 -0
README.md
CHANGED
@@ -1,14 +1,58 @@

# 🫀 SYNTAX-Video — Multi-study Inference

**SYNTAX-Video** is an interface for automatic **SYNTAX score** estimation from coronary angiography video recordings.
The model analyzes DICOM files of the left and right coronary arteries, computing each artery's contribution and the final total score.

---

## 🚀 Features

* Upload of studies in **DICOM** format
* Support for a **multi-model ensemble** (averaging over models only)
* Processing of **multiple clips** of one artery as a single sequence
* Interface for **multiple studies** in one session
* Automatic discovery of model weights (`weights/left`, `weights/right`)

---

## 🧭 How to use

1. Enter a **study ID** and, if needed, a short **description**.
2. Upload **DICOM files** for the **left** and/or **right** artery.
3. Click **"Add study"**, then **"Run inference"**.
4. Results are shown as a JSON structure (see the sketch after this list) with:

   * predictions from each model,
   * the mean value for each artery,
   * the total SYNTAX score,
   * a "High-risk" flag when the threshold is exceeded.
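
The field names below match what `app.py` returns for each queued study; the numbers and the model filename are illustrative placeholders only:

```python
# Hypothetical result for one study (keys match app.py; values are made up).
{
    "study": "S1234",
    "description": "",
    "left":  {"mean": 17.4, "per_model": [{"model": "fold0.ckpt", "score": 17.4}], "n_files": 2},
    "right": {"mean": 6.1,  "per_model": [{"model": "fold0.ckpt", "score": 6.1}],  "n_files": 1},
    "total": {"mean": 23.5, "High-risk (≥22.0)": True},
}
```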

---

## ⚙️ Technical details

* **Backbone:** `r3d_18` (torchvision)
* **Head:** `lstm_mean` (other variants are available: GRU, mean, BERT head, etc.)
* **Input format:** `(1, S, C, T, H, W)`, a pack of clips of one artery
* **Normalization:** ImageNet statistics
* **Averaging:** over ensemble models only (no averaging over clips within a study); see the sketch after this list
* **Thresholds:**

  * left artery ≥ 15
  * right artery ≥ 5
  * whole study ≥ 22
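
A minimal sketch of the packing and scoring logic, mirroring `_pack_study_side_to_tensor` and `_score_side_by_models` in `app.py`. The tensor sizes and the regression values are hypothetical; in the app they come from the config and the models:

```python
import numpy as np
import torch

# Packing: S clips of one artery, each already transformed to (C, T, H, W),
# are stacked into a single (1, S, C, T, H, W) batch.
S, C, T, H, W = 3, 3, 32, 224, 224          # hypothetical sizes
clips = [torch.zeros(C, T, H, W) for _ in range(S)]
x = torch.stack(clips, dim=0).unsqueeze(0)  # (1, S, C, T, H, W)

# Scoring: each model outputs a (1, 2) tensor; index 1 holds log(1 + score),
# which is inverted per model and then averaged over the ensemble.
reg_logs = [2.91, 2.74, 2.88]               # hypothetical per-model outputs
scores = [max(0.0, float(np.exp(v) - 1.0)) for v in reg_logs]  # invert log1p
left_mean = float(np.mean(scores))          # average over models only, never over clips

total = left_mean + 0.0                     # plus the right-artery mean (0.0 here)
high_risk = total >= 22                     # study-level threshold from the list above
```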

---

## 🧩 Project structure

```
configs/
└── default.yaml        # Main configuration
src/syntax_pred/
├── model.py            # Architecture and weight loading
├── preprocess.py       # DICOM handling and transforms
├── data.py             # Minimal dataset for inference
├── utils.py            # Helpers: device selection, weight discovery
app.py                  # Gradio interface
```
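
`configs/default.yaml` itself is not part of this upload. The sketch below lists the keys that `app.py` and `inference.py` actually read from it; every value shown is a placeholder assumption, not the project's real configuration:

```python
from omegaconf import OmegaConf

# Keys read by app.py / inference.py; all values here are placeholders.
cfg = OmegaConf.create({
    "device": "auto",                          # passed to pick_device()
    "num_classes": 1,                          # assumed
    "variant": "lstm_mean",                    # head variant named in the README
    "precision": 32,                           # assumed
    "frames_per_clip": 32,                     # assumed
    "video_size": [224, 224],                  # assumed (H, W)
    "a_rnn": {"hidden_div": 2, "dropout": 0.1},
    "bert": {"nhead": 8, "num_layers": 2, "ff_div": 2, "dropout": 0.1},
    "weights": {"left": None, "right": None},  # None -> discover_weights(...)
    "thresholds": {"left": 15, "right": 5, "both": 22},
})
```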
app.py
ADDED
@@ -0,0 +1,240 @@
# app.py — SYNTAX-Video (multi-study UI) — pack-per-study, average only over models
import os
import base64
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch
import gradio as gr
from omegaconf import OmegaConf

from src.syntax_pred.utils import pick_device, discover_weights
from src.syntax_pred.model import SyntaxLightningModule
from src.syntax_pred.preprocess import (
    read_dicom_uint8,
    ensure_length_center_crop,
    test_like_transform,
    IMAGENET_MEAN,
    IMAGENET_STD,
)

# -------- Globals --------
CFG = OmegaConf.load("configs/default.yaml")
DEVICE = pick_device(CFG.device)

# ===== Logo handling (base64) =====
DEFAULT_LOGO = "assets/logo.png"
LOGO_PATH = os.environ.get("LOGO_PATH", DEFAULT_LOGO)

def _logo_html() -> str:
    path = LOGO_PATH
    if not path or not os.path.exists(path):
        return ""
    try:
        with open(path, "rb") as f:
            b64 = base64.b64encode(f.read()).decode("ascii")
        ext = os.path.splitext(path)[1].lower()
        mime = "image/jpeg" if ext in {".jpg", ".jpeg"} else "image/png"
        data_uri = f"data:{mime};base64,{b64}"
        return (
            f'<img src="{data_uri}" alt="logo" '
            f'style="height:40px;vertical-align:middle;display:inline-block;'
            f'image-rendering:auto;object-fit:contain;margin-right:12px;" />'
        )
    except Exception:
        return ""

# ===== Models =====
def build_model(weight_path: str):
    return SyntaxLightningModule(
        num_classes=CFG.num_classes,
        lr=1e-5,
        variant=CFG.variant,
        pl_weight_path=weight_path,
        rnn_hidden_div=CFG.a_rnn.hidden_div,
        rnn_dropout=CFG.a_rnn.dropout,
        bert_nhead=CFG.bert.nhead,
        bert_layers=CFG.bert.num_layers,
        bert_ff_div=CFG.bert.ff_div,
        bert_dropout=CFG.bert.dropout,
        precision=CFG.precision,
    ).to(DEVICE).eval()

def list_weights() -> Tuple[List[str], List[str]]:
    left = CFG.weights.left or discover_weights("weights/left")
    right = CFG.weights.right or discover_weights("weights/right")
    return left, right

# ===== Data structures =====
@dataclass
class Study:
    name: str
    description: str
    left_paths: List[str]
    right_paths: List[str]

def _files_to_paths(files) -> List[str]:
    return [f.name for f in (files or []) if hasattr(f, "name") and os.path.exists(f.name)]

# ===== Packing: within a study, all clips of one artery -> one sequence =====
def _pack_study_side_to_tensor(
    file_paths: List[str],
    frames_per_clip: int,
    video_size: Tuple[int, int],
) -> Optional[torch.Tensor]:
    if not file_paths:
        return None
    tx = test_like_transform(video_size)
    clips = []
    for p in file_paths:
        arr = read_dicom_uint8(p)                   # (T,H,W) uint8
        arr = ensure_length_center_crop(arr, frames_per_clip)
        thwc = np.stack([arr, arr, arr], axis=-1)   # (T,H,W,3)
        thwc = torch.tensor(thwc, dtype=torch.uint8)
        cthw = tx(thwc)                             # (C,T,H,W)
        clips.append(cthw)
    if not clips:
        return None
    return torch.stack(clips, dim=0).unsqueeze(0)   # (1,S,C,T,H,W)

# ===== Inference logic (average over models only) =====
@torch.no_grad()
def _score_side_by_models(
    side_paths: List[str],
    model_paths: List[str],
    frames_per_clip: int,
    video_size: Tuple[int, int],
) -> Dict[str, Any]:
    if not side_paths:
        return {"mean": 0.0, "per_model": [], "n_files": 0}

    x = _pack_study_side_to_tensor(side_paths, frames_per_clip, video_size)
    if x is None:
        return {"mean": 0.0, "per_model": [], "n_files": 0}
    x = x.to(DEVICE)

    per_model_scores: List[float] = []
    used_models: List[str] = []

    for wp in model_paths:
        try:
            m = build_model(wp)
            y = m(x)  # (1,2)
            reg_log = float(y[0, 1].detach().cpu().numpy())
            score = float(max(0.0, np.exp(reg_log) - 1.0))  # inverse of log(1+score)
            per_model_scores.append(score)
            used_models.append(os.path.basename(wp))
        except Exception as e:
            print(f"[WARN] model {wp} failed: {e}")

    mean_score = float(np.mean(per_model_scores)) if per_model_scores else 0.0
    return {
        "mean": mean_score,
        "per_model": [{"model": n, "score": round(s, 3)} for n, s in zip(used_models, per_model_scores)],
        "n_files": len(side_paths),
    }

def run_inference(studies: List[Study]) -> Dict[str, Any]:
    left_w, right_w = list_weights()
    if not left_w and not right_w:
        return {"error": "No weights found. Upload to weights/left and weights/right."}

    results = {"studies": []}
    thr = CFG.thresholds.both
    video_size = tuple(CFG.video_size)
    frames = CFG.frames_per_clip

    for st in studies:
        left_res = _score_side_by_models(st.left_paths, left_w, frames, video_size)
        right_res = _score_side_by_models(st.right_paths, right_w, frames, video_size)
        total = left_res["mean"] + right_res["mean"]

        results["studies"].append({
            "study": st.name,
            "description": st.description or "",
            "left": {"mean": round(left_res["mean"], 3), "per_model": left_res["per_model"], "n_files": left_res["n_files"]},
            "right": {"mean": round(right_res["mean"], 3), "per_model": right_res["per_model"], "n_files": right_res["n_files"]},
            "total": {"mean": round(total, 3), f"High-risk (≥{thr:.1f})": bool(total >= thr)},
        })
    return results

# ===== UI =====
def ui():
    with gr.Blocks() as demo:
        gr.HTML(
            f"""
            <div style="display:flex;align-items:center;gap:10px;margin-bottom:8px;">
              {_logo_html()}
              <h1 style="margin:0;font-weight:800;text-align:center;flex:1;">SYNTAX-Video — Multi-study Inference</h1>
            </div>
            <ol style="margin:0 0 12px 20px; color:#475569; line-height:1.5;">
              <li>Enter a study ID and an optional description.</li>
              <li>Upload DICOM files for the LEFT and/or RIGHT artery.</li>
              <li>Click "Add study" to queue the study, then "Run inference" to start the analysis.</li>
            </ol>
            """
        )

        studies_state = gr.State([])  # list[dict]

        # --- Top row: Study ID + Description ---
        with gr.Row():
            study_name = gr.Textbox(label="Study ID", placeholder="e.g., S1234")
            study_desc = gr.Textbox(label="Description (optional)", placeholder="Free text...")

        # --- Bottom row: the two uploaders side by side ---
        with gr.Row():
            add_left = gr.File(label="LEFT artery DICOM(s)", file_count="multiple")
            add_right = gr.File(label="RIGHT artery DICOM(s)", file_count="multiple")

        with gr.Row():
            btn_add = gr.Button("➕ Add study")
            btn_clear = gr.Button("🗑️ Clear all")

        queue_table = gr.Dataframe(
            headers=["Study", "Description", "Left paths", "Right paths"],
            datatype=["str", "str", "str", "str"],
            interactive=False,
            label="Queued studies (full paths)",
            row_count=(0, "dynamic"),
        )

        def _add_study_fn(studies: List[Dict[str, Any]], name, desc, left_files, right_files):
            name = (name or "").strip() or f"Study_{len(studies)+1}"
            desc = (desc or "").strip()
            left_paths = _files_to_paths(left_files)
            right_paths = _files_to_paths(right_files)
            new = Study(name=name, description=desc, left_paths=left_paths, right_paths=right_paths)
            studies = studies + [asdict(new)]
            table = [[s["name"], s.get("description", ""), "\n".join(s["left_paths"]), "\n".join(s["right_paths"])] for s in studies]
            return studies, table, "", "", None, None

        btn_add.click(
            _add_study_fn,
            inputs=[studies_state, study_name, study_desc, add_left, add_right],
            outputs=[studies_state, queue_table, study_name, study_desc, add_left, add_right],
        )

        def _clear_all():
            return [], []

        btn_clear.click(_clear_all, inputs=None, outputs=[studies_state, queue_table])

        run_btn = gr.Button("🚀 Run inference", variant="primary")
        out_json = gr.JSON(label="Results")

        def _run_infer(studies):
            study_objs = [Study(**s) for s in (studies or [])]
            return run_inference(study_objs)

        run_btn.click(_run_infer, inputs=[studies_state], outputs=[out_json])

        gr.Markdown("⚠️ Research-only. Not a medical device. Predictions depend on input quality and domain shift.")
    return demo


if __name__ == "__main__":
    favicon = LOGO_PATH if (LOGO_PATH and os.path.exists(LOGO_PATH)) else None
    ui().launch(favicon_path=favicon)
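
Run locally with `python app.py`; on Hugging Face Spaces, the `app_file: app.py` entry in `huggingface.yaml` (below) launches the same interface.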
huggingface.yaml
ADDED
@@ -0,0 +1,3 @@
sdk: gradio
python_version: "3.11"
app_file: app.py
inference.py
ADDED
@@ -0,0 +1,120 @@
# inference.py — pack-per-side, average only over models (parity with app.py)

from __future__ import annotations
import argparse
import os
import numpy as np
import torch
from omegaconf import OmegaConf

from src.syntax_pred.model import SyntaxLightningModule
from src.syntax_pred.utils import pick_device, discover_weights
from src.syntax_pred.preprocess import (
    read_dicom_uint8,
    ensure_length_center_crop,
    test_like_transform,
)

CFG = None  # set in main() from --config

def build_model(cfg, weight_path: str) -> SyntaxLightningModule:
    """Builds the model and loads the weights of the whole module."""
    m = SyntaxLightningModule(
        num_classes=cfg.num_classes,
        lr=1e-5,
        variant=cfg.variant,
        weight_decay=0.0,
        max_epochs=1,
        pl_weight_path=weight_path,
        rnn_hidden_div=cfg.a_rnn.hidden_div,
        rnn_dropout=cfg.a_rnn.dropout,
        bert_nhead=cfg.bert.nhead,
        bert_layers=cfg.bert.num_layers,
        bert_ff_div=cfg.bert.ff_div,
        bert_dropout=cfg.bert.dropout,
        precision=cfg.precision,
    )
    m.eval()
    return m

def _pack_side_to_tensor(file_paths, frames_per_clip, video_size) -> torch.Tensor | None:
    """
    Packs a list of DICOM files into one batch tensor of shape (1, S, C, T, H, W),
    where S is the number of clips/files for one artery.
    """
    paths = [p for p in (file_paths or []) if os.path.exists(p)]
    if not paths:
        return None
    tx = test_like_transform(tuple(video_size))
    clips = []
    for p in paths:
        arr = read_dicom_uint8(p)                   # (T,H,W) uint8
        arr = ensure_length_center_crop(arr, int(frames_per_clip))
        thwc = np.stack([arr, arr, arr], axis=-1)   # (T,H,W,3)
        thwc = torch.tensor(thwc, dtype=torch.uint8)
        cthw = tx(thwc)                             # (C,T,H,W)
        clips.append(cthw)
    if not clips:
        return None
    return torch.stack(clips, dim=0).unsqueeze(0)   # (1,S,C,T,H,W)

@torch.no_grad()
def score_side_by_models(file_paths, model_paths, frames_per_clip, video_size, device) -> dict:
    """
    For one artery: pack all DICOMs into a single sequence and run it through
    each model. The result is the mean over models.
    """
    x = _pack_side_to_tensor(file_paths, frames_per_clip, video_size)
    if x is None:
        return {"mean": 0.0, "per_model": [], "n_files": 0}
    x = x.to(device)

    per_model = []
    scores = []
    for wp in model_paths:
        try:
            m = build_model(CFG, wp).to(device)
            y = m(x)  # (1,2): [logit, log(1+score)]
            reg_log = float(y[0, 1].item())
            s = float(max(0.0, np.exp(reg_log) - 1.0))
            per_model.append({"model": os.path.basename(wp), "score": round(s, 6)})
            scores.append(s)
        except Exception as e:
            print(f"[WARN] model {wp} failed: {e}")
    mean_score = float(np.mean(scores)) if scores else 0.0
    return {"mean": mean_score, "per_model": per_model, "n_files": len(file_paths or [])}

def list_weights(cfg) -> tuple[list[str], list[str]]:
    """Takes paths from the config or searches the weights/left and weights/right directories."""
    left = cfg.weights.left or discover_weights("weights/left")
    right = cfg.weights.right or discover_weights("weights/right")
    return left, right

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default="configs/default.yaml")
    ap.add_argument("--left", nargs="*", default=[], help="List of LEFT DICOM files")
    ap.add_argument("--right", nargs="*", default=[], help="List of RIGHT DICOM files")
    args = ap.parse_args()

    global CFG
    CFG = OmegaConf.load(args.config)

    device = pick_device(CFG.device)
    left_w, right_w = list_weights(CFG)

    if not left_w and not right_w:
        print({"error": "No weights found. Upload to weights/left and weights/right."})
        return

    left_res = score_side_by_models(args.left, left_w, CFG.frames_per_clip, CFG.video_size, device)
    right_res = score_side_by_models(args.right, right_w, CFG.frames_per_clip, CFG.video_size, device)

    total = left_res["mean"] + right_res["mean"]
    out = {
        "left": {"mean": round(left_res["mean"], 6), "per_model": left_res["per_model"], "n_files": left_res["n_files"]},
        "right": {"mean": round(right_res["mean"], 6), "per_model": right_res["per_model"], "n_files": right_res["n_files"]},
        "total": {"mean": round(total, 6)},
    }
    print(out)

if __name__ == "__main__":
    main()
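
A hypothetical invocation, assuming weights are already present under `weights/left` and `weights/right`; the DICOM filenames are placeholders:

```
python inference.py --config configs/default.yaml \
    --left left_view1.dcm left_view2.dcm \
    --right right_view1.dcm
```

The script prints a dict with `left`, `right`, and `total` means, matching the structure produced by `app.py` (minus the per-study metadata and the high-risk flag).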
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch
torchvision
lightning
pydicom
pytorchvideo
omegaconf
gradio
numpy
scikit-learn
plotly