MesserMMP commited on
Commit
8f49763
·
verified ·
1 Parent(s): 7ae43e9

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +55 -11
  2. app.py +240 -0
  3. huggingface.yaml +3 -0
  4. inference.py +120 -0
  5. requirements.txt +10 -0
README.md CHANGED
@@ -1,14 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Syntax Video Infer
3
- emoji: 💻
4
- colorFrom: blue
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- license: unknown
11
- short_description: Интерфейс для автоматической оценки SYNTAX-score
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🫀 SYNTAX-Video — Multi-study Inference
2
+
3
+ **SYNTAX-Video** — это интерфейс для автоматической оценки **SYNTAX-score** по видеозаписям коронарной ангиографии.
4
+ Модель анализирует DICOM-файлы левой и правой коронарной артерии, вычисляя вклад каждой и итоговый суммарный показатель.
5
+
6
+ ---
7
+
8
+ ## 🚀 Возможности
9
+
10
+ * Загрузка исследований в формате **DICOM**
11
+ * Поддержка **многомодельного ансамбля** (усреднение только по моделям)
12
+ * Обработка **нескольких клипов** одной артерии как единой последовательности
13
+ * Интерфейс для **множественных исследований** в одном сеансе
14
+ * Автоматическое определение весов модели (`weights/left`, `weights/right`)
15
+
16
+ ---
17
+
18
+ ## 🧭 Как использовать
19
+
20
+ 1. Укажите **ID исследования** и при необходимости краткое **описание**.
21
+ 2. Загрузите **DICOM-файлы** для **левой** и/или **правой** артерии.
22
+ 3. Нажмите **“Add study”**, затем **“Run inference”**.
23
+ 4. Результаты отобразятся в виде JSON-структуры с:
24
+
25
+ * предсказаниями по каждой модели,
26
+ * средним значением для каждой артерии,
27
+ * суммарным SYNTAX-score,
28
+ * пометкой «High-risk» при превышении порога.
29
+
30
  ---
31
+
32
+ ## ⚙️ Технические детали
33
+
34
+ * **Backbone:** `r3d_18` (torchvision)
35
+ * **Head:** `lstm_mean` (возможно использование других вариантов: GRU, mean, BERT-head и др.)
36
+ * **Формат входа:** `(1, S, C, T, H, W)` — пакет клипов одной артерии
37
+ * **Нормализация:** стандарт ImageNet
38
+ * **Усреднение:** только по моделям ансамбля (без усреднения по клипам внутри исследования)
39
+ * **Пороговые значения:**
40
+
41
+ * левая артерия ≥ 15
42
+ * правая артерия ≥ 5
43
+ * общее исследование ≥ 22
44
+
45
  ---
46
 
47
+ ## 🧩 Структура проекта
48
+
49
+ ```
50
+ configs/
51
+ └── default.yaml # Основная конфигурация
52
+ src/syntax_pred/
53
+ ├── model.py # Архитектура и загрузка весов
54
+ ├── preprocess.py # Обработка DICOM и трансформации
55
+ ├── data.py # Минималистичный датасет для инференса
56
+ ├── utils.py # Поддержка: выбор устройства, поиск весов
57
+ app.py # Gradio-интерфейс
58
+ ```
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — SYNTAX-Video (multi-study UI) — pack-per-study, average only over models
2
+ import os
3
+ import base64
4
+ from dataclasses import dataclass, asdict
5
+ from typing import List, Dict, Any, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import gradio as gr
10
+ from omegaconf import OmegaConf
11
+
12
+ from src.syntax_pred.utils import pick_device, discover_weights
13
+ from src.syntax_pred.model import SyntaxLightningModule
14
+ from src.syntax_pred.preprocess import (
15
+ read_dicom_uint8,
16
+ ensure_length_center_crop,
17
+ test_like_transform,
18
+ IMAGENET_MEAN,
19
+ IMAGENET_STD,
20
+ )
21
+
22
+ # -------- Globals --------
23
+ CFG = OmegaConf.load("configs/default.yaml")
24
+ DEVICE = pick_device(CFG.device)
25
+
26
+ # ===== Logo handling (base64) =====
27
+ DEFAULT_LOGO = "assets/logo.png"
28
+ LOGO_PATH = os.environ.get("LOGO_PATH", DEFAULT_LOGO)
29
+
30
+ def _logo_html() -> str:
31
+ path = LOGO_PATH
32
+ if not path or not os.path.exists(path):
33
+ return ""
34
+ try:
35
+ with open(path, "rb") as f:
36
+ b64 = base64.b64encode(f.read()).decode("ascii")
37
+ ext = os.path.splitext(path)[1].lower()
38
+ mime = "image/png" if ext in {".png", ""} else "image/jpeg" if ext in {".jpg", ".jpeg"} else "image/png"
39
+ data_uri = f"data:{mime};base64,{b64}"
40
+ return (
41
+ f'<img src="{data_uri}" alt="logo" '
42
+ f'style="height:40px;vertical-align:middle;display:inline-block;'
43
+ f'image-rendering:auto;object-fit:contain;margin-right:12px;" />'
44
+ )
45
+ except Exception:
46
+ return ""
47
+
48
+ # ===== Models =====
49
+ def build_model(weight_path: str):
50
+ return SyntaxLightningModule(
51
+ num_classes=CFG.num_classes,
52
+ lr=1e-5,
53
+ variant=CFG.variant,
54
+ pl_weight_path=weight_path,
55
+ rnn_hidden_div=CFG.a_rnn.hidden_div,
56
+ rnn_dropout=CFG.a_rnn.dropout,
57
+ bert_nhead=CFG.bert.nhead,
58
+ bert_layers=CFG.bert.num_layers,
59
+ bert_ff_div=CFG.bert.ff_div,
60
+ bert_dropout=CFG.bert.dropout,
61
+ precision=CFG.precision,
62
+ ).to(DEVICE).eval()
63
+
64
+ def list_weights() -> Tuple[List[str], List[str]]:
65
+ left = CFG.weights.left or discover_weights("weights/left")
66
+ right = CFG.weights.right or discover_weights("weights/right")
67
+ return left, right
68
+
69
+ # ===== Data structures =====
70
+ @dataclass
71
+ class Study:
72
+ name: str
73
+ description: str
74
+ left_paths: List[str]
75
+ right_paths: List[str]
76
+
77
+ def _files_to_paths(files) -> List[str]:
78
+ return [f.name for f in (files or []) if hasattr(f, "name") and os.path.exists(f.name)]
79
+
80
+ # ===== Packing: внутри исследования все клипы одной артерии → одна последовательность =====
81
+ def _pack_study_side_to_tensor(
82
+ file_paths: List[str],
83
+ frames_per_clip: int,
84
+ video_size: Tuple[int, int],
85
+ ) -> torch.Tensor:
86
+ if not file_paths:
87
+ return None
88
+ tx = test_like_transform(video_size)
89
+ clips = []
90
+ for p in file_paths:
91
+ arr = read_dicom_uint8(p) # (T,H,W) uint8
92
+ arr = ensure_length_center_crop(arr, frames_per_clip)
93
+ thwc = np.stack([arr, arr, arr], axis=-1) # (T,H,W,3)
94
+ thwc = torch.tensor(thwc, dtype=torch.uint8)
95
+ cthw = tx(thwc) # (C,T,H,W)
96
+ clips.append(cthw)
97
+ if not clips:
98
+ return None
99
+ return torch.stack(clips, dim=0).unsqueeze(0) # (1,S,C,T,H,W)
100
+
101
+ # ===== Inference logic (среднее только по моделям) =====
102
+ @torch.no_grad()
103
+ def _score_side_by_models(
104
+ side_paths: List[str],
105
+ model_paths: List[str],
106
+ frames_per_clip: int,
107
+ video_size: Tuple[int, int],
108
+ ) -> Dict[str, Any]:
109
+ if not side_paths:
110
+ return {"mean": 0.0, "per_model": [], "n_files": 0}
111
+
112
+ x = _pack_study_side_to_tensor(side_paths, frames_per_clip, video_size)
113
+ if x is None:
114
+ return {"mean": 0.0, "per_model": [], "n_files": 0}
115
+ x = x.to(DEVICE)
116
+
117
+ per_model_scores: List[float] = []
118
+ used_models: List[str] = []
119
+
120
+ for wp in model_paths:
121
+ try:
122
+ m = build_model(wp)
123
+ y = m(x) # (1,2)
124
+ reg_log = float(y[0, 1].detach().cpu().numpy())
125
+ score = float(max(0.0, np.exp(reg_log) - 1.0)) # inverse log(1+score)
126
+ per_model_scores.append(score)
127
+ used_models.append(os.path.basename(wp))
128
+ except Exception as e:
129
+ print(f"[WARN] model {wp} failed: {e}")
130
+
131
+ mean_score = float(np.mean(per_model_scores)) if per_model_scores else 0.0
132
+ return {
133
+ "mean": mean_score,
134
+ "per_model": [{"model": n, "score": round(s, 3)} for n, s in zip(used_models, per_model_scores)],
135
+ "n_files": len(side_paths),
136
+ }
137
+
138
+ def run_inference(studies: List[Study]) -> Dict[str, Any]:
139
+ left_w, right_w = list_weights()
140
+ if not left_w and not right_w:
141
+ return {"error": "No weights found. Upload to weights/left and weights/right."}
142
+
143
+ results = {"studies": []}
144
+ thr = CFG.thresholds.both
145
+ video_size = tuple(CFG.video_size)
146
+ frames = CFG.frames_per_clip
147
+
148
+ for st in studies:
149
+ left_res = _score_side_by_models(st.left_paths, left_w, frames, video_size)
150
+ right_res = _score_side_by_models(st.right_paths, right_w, frames, video_size)
151
+ total = left_res["mean"] + right_res["mean"]
152
+
153
+ results["studies"].append({
154
+ "study": st.name,
155
+ "description": st.description or "",
156
+ "left": {"mean": round(left_res["mean"], 3), "per_model": left_res["per_model"], "n_files": left_res["n_files"]},
157
+ "right": {"mean": round(right_res["mean"], 3), "per_model": right_res["per_model"], "n_files": right_res["n_files"]},
158
+ "total": {"mean": round(total, 3), f"High-risk (≥{thr:.1f})": bool(total >= thr)},
159
+ })
160
+ return results
161
+
162
+ # ===== UI =====
163
+ def ui():
164
+ with gr.Blocks() as demo:
165
+ gr.HTML(
166
+ f"""
167
+ <div style="display:flex;align-items:center;gap:10px;margin-bottom:8px;">
168
+ {_logo_html()}
169
+ <h1 style="margin:0;font-weight:800;text-align:center;flex:1;">SYNTAX-Video — Multi-study Inference</h1>
170
+ </div>
171
+ <ol style="margin:0 0 12px 20px; color:#475569; line-height:1.5;">
172
+ <li>Укажите ID исследования и (необязательно) описание.</li>
173
+ <li>Загрузите DICOM-файлы для ЛЕВОЙ и/или ПРАВОЙ артерии.</li>
174
+ <li>Нажмите “Add study”, чтобы добавить исследование, и затем “Run inference” для запуска анализа.</li>
175
+ </ol>
176
+ """
177
+ )
178
+
179
+
180
+ studies_state = gr.State([]) # list[dict]
181
+
182
+ # --- ВЕРХНИЙ РЯД: Study ID + Description ---
183
+ with gr.Row():
184
+ study_name = gr.Textbox(label="Study ID", placeholder="e.g., S1234")
185
+ study_desc = gr.Textbox(label="Description (optional)", placeholder="Free text...")
186
+
187
+ # --- НИЖНИЙ РЯД: два загрузчика бок-о-бок ---
188
+ with gr.Row():
189
+ add_left = gr.File(label="LEFT artery DICOM(s)", file_count="multiple")
190
+ add_right = gr.File(label="RIGHT artery DICOM(s)", file_count="multiple")
191
+
192
+ with gr.Row():
193
+ btn_add = gr.Button("➕ Add study")
194
+ btn_clear = gr.Button("🗑️ Clear all")
195
+
196
+ queue_table = gr.Dataframe(
197
+ headers=["Study", "Description", "Left paths", "Right paths"],
198
+ datatype=["str", "str", "str", "str"],
199
+ interactive=False,
200
+ label="Queued studies (full paths)",
201
+ row_count=(0, "dynamic"),
202
+ )
203
+
204
+ def _add_study_fn(studies: List[Dict[str, Any]], name, desc, left_files, right_files):
205
+ name = (name or "").strip() or f"Study_{len(studies)+1}"
206
+ desc = (desc or "").strip()
207
+ left_paths = _files_to_paths(left_files)
208
+ right_paths = _files_to_paths(right_files)
209
+ new = Study(name=name, description=desc, left_paths=left_paths, right_paths=right_paths)
210
+ studies = studies + [asdict(new)]
211
+ table = [[s["name"], s.get("description",""), "\n".join(s["left_paths"]), "\n".join(s["right_paths"])] for s in studies]
212
+ return studies, table, "", "", None, None
213
+
214
+ btn_add.click(
215
+ _add_study_fn,
216
+ inputs=[studies_state, study_name, study_desc, add_left, add_right],
217
+ outputs=[studies_state, queue_table, study_name, study_desc, add_left, add_right],
218
+ )
219
+
220
+ def _clear_all():
221
+ return [], []
222
+
223
+ btn_clear.click(_clear_all, inputs=None, outputs=[studies_state, queue_table])
224
+
225
+ run_btn = gr.Button("🚀 Run inference", variant="primary")
226
+ out_json = gr.JSON(label="Results")
227
+
228
+ def _run_infer(studies):
229
+ study_objs = [Study(**s) for s in (studies or [])]
230
+ return run_inference(study_objs)
231
+
232
+ run_btn.click(_run_infer, inputs=[studies_state], outputs=[out_json])
233
+
234
+ gr.Markdown("⚠️ Research-only. Not a medical device. Predictions depend on input quality and domain shift.")
235
+ return demo
236
+
237
+
238
+ if __name__ == "__main__":
239
+ favicon = LOGO_PATH if (LOGO_PATH and os.path.exists(LOGO_PATH)) else None
240
+ ui().launch(favicon_path=favicon)
huggingface.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ sdk: gradio
2
+ python_version: "3.11"
3
+ app_file: app.py
inference.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py — pack-per-side, average only over models (parity with app.py)
2
+
3
+ from __future__ import annotations
4
+ import argparse
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ from omegaconf import OmegaConf
9
+
10
+ from src.syntax_pred.model import SyntaxLightningModule
11
+ from src.syntax_pred.utils import pick_device, discover_weights
12
+ from src.syntax_pred.preprocess import (
13
+ read_dicom_uint8,
14
+ ensure_length_center_crop,
15
+ test_like_transform,
16
+ )
17
+
18
+ def build_model(cfg, weight_path: str) -> SyntaxLightningModule:
19
+ """Создаёт модель и загружает веса целого модуля."""
20
+ m = SyntaxLightningModule(
21
+ num_classes=cfg.num_classes,
22
+ lr=1e-5,
23
+ variant=cfg.variant,
24
+ weight_decay=0.0,
25
+ max_epochs=1,
26
+ pl_weight_path=weight_path,
27
+ rnn_hidden_div=cfg.a_rnn.hidden_div,
28
+ rnn_dropout=cfg.a_rnn.dropout,
29
+ bert_nhead=cfg.bert.nhead,
30
+ bert_layers=cfg.bert.num_layers,
31
+ bert_ff_div=cfg.bert.ff_div,
32
+ bert_dropout=cfg.bert.dropout,
33
+ precision=cfg.precision,
34
+ )
35
+ m.eval()
36
+ return m
37
+
38
+ def _pack_side_to_tensor(file_paths, frames_per_clip, video_size) -> torch.Tensor | None:
39
+ """
40
+ Собирает список DICOM в один батч тензоров формы (1, S, C, T, H, W),
41
+ где S — число клипов/файлов для одной артерии.
42
+ """
43
+ paths = [p for p in (file_paths or []) if os.path.exists(p)]
44
+ if not paths:
45
+ return None
46
+ tx = test_like_transform(tuple(video_size))
47
+ clips = []
48
+ for p in paths:
49
+ arr = read_dicom_uint8(p) # (T,H,W) uint8
50
+ arr = ensure_length_center_crop(arr, int(frames_per_clip))
51
+ thwc = np.stack([arr, arr, arr], axis=-1) # (T,H,W,3)
52
+ thwc = torch.tensor(thwc, dtype=torch.uint8)
53
+ cthw = tx(thwc) # (C,T,H,W)
54
+ clips.append(cthw)
55
+ if not clips:
56
+ return None
57
+ return torch.stack(clips, dim=0).unsqueeze(0) # (1,S,C,T,H,W)
58
+
59
+ @torch.no_grad()
60
+ def score_side_by_models(file_paths, model_paths, frames_per_clip, video_size, device) -> dict:
61
+ """
62
+ Внутри одной артерии: пакуем все DICOM в одну последовательность и
63
+ прогоняем через каждую модель. Итог — среднее по моделям.
64
+ """
65
+ x = _pack_side_to_tensor(file_paths, frames_per_clip, video_size)
66
+ if x is None:
67
+ return {"mean": 0.0, "per_model": [], "n_files": 0}
68
+ x = x.to(device)
69
+
70
+ per_model = []
71
+ scores = []
72
+ for wp in model_paths:
73
+ try:
74
+ m = build_model(CFG, wp).to(device)
75
+ y = m(x) # (1,2): [logit, log(1+score)]
76
+ reg_log = float(y[0, 1].item())
77
+ s = float(max(0.0, np.exp(reg_log) - 1.0))
78
+ per_model.append({"model": os.path.basename(wp), "score": round(s, 6)})
79
+ scores.append(s)
80
+ except Exception as e:
81
+ print(f"[WARN] model {wp} failed: {e}")
82
+ mean_score = float(np.mean(scores)) if scores else 0.0
83
+ return {"mean": mean_score, "per_model": per_model, "n_files": len(file_paths or [])}
84
+
85
+ def list_weights(cfg) -> tuple[list[str], list[str]]:
86
+ """Берёт пути из конфига или ищет по директориям weights/left и weights/right."""
87
+ left = cfg.weights.left or discover_weights("weights/left")
88
+ right = cfg.weights.right or discover_weights("weights/right")
89
+ return left, right
90
+
91
+ def main():
92
+ ap = argparse.ArgumentParser()
93
+ ap.add_argument("--config", default="configs/default.yaml")
94
+ ap.add_argument("--left", nargs="*", default=[], help="List of LEFT DICOM files")
95
+ ap.add_argument("--right", nargs="*", default=[], help="List of RIGHT DICOM files")
96
+ args = ap.parse_args()
97
+
98
+ global CFG
99
+ CFG = OmegaConf.load(args.config)
100
+
101
+ device = pick_device(CFG.device)
102
+ left_w, right_w = list_weights(CFG)
103
+
104
+ if not left_w and not right_w:
105
+ print({"error": "No weights found. Upload to weights/left and weights/right."})
106
+ return
107
+
108
+ left_res = score_side_by_models(args.left, left_w, CFG.frames_per_clip, CFG.video_size, device)
109
+ right_res = score_side_by_models(args.right, right_w, CFG.frames_per_clip, CFG.video_size, device)
110
+
111
+ total = left_res["mean"] + right_res["mean"]
112
+ out = {
113
+ "left": {"mean": round(left_res["mean"], 6), "per_model": left_res["per_model"], "n_files": left_res["n_files"]},
114
+ "right": {"mean": round(right_res["mean"], 6), "per_model": right_res["per_model"], "n_files": right_res["n_files"]},
115
+ "total": {"mean": round(total, 6)},
116
+ }
117
+ print(out)
118
+
119
+ if __name__ == "__main__":
120
+ main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ lightning
4
+ pydicom
5
+ pytorchvideo
6
+ omegaconf
7
+ gradio
8
+ numpy
9
+ scikit-learn
10
+ plotly