Yatsuiii committed on
Commit
c71d06b
·
verified ·
1 Parent(s): 1bc9be2

Upload folder using huggingface_hub

Browse files
app.py CHANGED
@@ -17,18 +17,62 @@ _STEP = 3
17
  _MAX_WINDOWS = 30
18
  _FC_THRESHOLD = 0.2
19
 
20
- # CC200 atlas (Craddock 2012) → approximate Yeo 7-network parcellation
21
- _NET_NAMES = ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"]
22
- _NET_BOUNDS = [0, 38, 69, 99, 137, 165, 180, 200]
23
- _NET_COLORS = ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"]
24
-
25
- _CKPTS = {
26
- "NYU": Path("checkpoints/nyu.ckpt"),
27
- "USM": Path("checkpoints/usm.ckpt"),
28
- "UCLA": Path("checkpoints/ucla.ckpt"),
29
- "UM": Path("checkpoints/um.ckpt"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  }
31
 
 
 
 
 
 
 
 
 
 
32
  # ── preprocessing ──────────────────────────────────────────────────────────
33
 
34
  def _zscore(bold):
@@ -60,21 +104,23 @@ def preprocess(bold):
60
 
61
  # ── model loading ──────────────────────────────────────────────────────────
62
 
63
- _models = None
64
 
65
- def get_models():
66
- global _models
67
- if _models is not None:
68
- return _models
69
  from brain_gcn.tasks import ClassificationTask
70
- _models = []
71
- for site, ckpt in _CKPTS.items():
 
72
  if not ckpt.exists():
73
  continue
74
  task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
75
  task.eval()
76
- _models.append((site, task))
77
- return _models
 
78
 
79
  # ── gradient saliency ──────────────────────────────────────────────────────
80
 
@@ -99,7 +145,7 @@ _NET_MNI = np.array([
99
  [ 14, 4, 4], # Subcortical (thalamus)
100
  ], dtype=np.float32)
101
 
102
- def _saliency_figure(sal, p_mean):
103
  import matplotlib
104
  matplotlib.use("Agg")
105
  import matplotlib.pyplot as plt
@@ -107,18 +153,21 @@ def _saliency_figure(sal, p_mean):
107
  from mpl_toolkits.mplot3d.art3d import Line3DCollection
108
  from PIL import Image
109
 
110
- n_nets = len(_NET_NAMES)
 
 
 
111
 
112
- # Aggregate 200×200 saliency → 7×7 network-level matrix
113
  net_sal = np.zeros((n_nets, n_nets))
114
- for i, (s1, e1) in enumerate(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])):
115
- for j, (s2, e2) in enumerate(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])):
116
  net_sal[i, j] = sal[s1:e1, s2:e2].mean()
117
 
118
  # Network importance: mean outgoing + incoming saliency per network
119
  net_imp = np.array([
120
  sal[s:e, :].mean() + sal[:, s:e].mean()
121
- for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
122
  ])
123
 
124
  fig = plt.figure(figsize=(18, 5.5))
@@ -137,8 +186,8 @@ def _saliency_figure(sal, p_mean):
137
 
138
  ax.set_xticks(range(n_nets))
139
  ax.set_yticks(range(n_nets))
140
- ax.set_xticklabels(_NET_NAMES, rotation=40, ha="right", fontsize=9, color="#ccc")
141
- ax.set_yticklabels(_NET_NAMES, fontsize=9, color="#ccc")
142
  ax.tick_params(colors="#555", length=0)
143
  for sp in ax.spines.values():
144
  sp.set_color("#222")
@@ -173,7 +222,7 @@ def _saliency_figure(sal, p_mean):
173
 
174
  # Callout labels for top-3 cross-network edges
175
  for rank, (score, i, j) in enumerate(top3_edges):
176
- label = f"#{rank+1} {_NET_NAMES[i]}↔{_NET_NAMES[j]}"
177
  ax.annotate(label,
178
  xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
179
  fontsize=6, color="#fb923c", fontweight="600",
@@ -193,9 +242,9 @@ def _saliency_figure(sal, p_mean):
193
 
194
  order = net_imp.argsort()[::-1]
195
  bars = ax2.barh(range(n_nets), net_imp[order],
196
- color=[_NET_COLORS[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
197
  ax2.set_yticks(range(n_nets))
198
- ax2.set_yticklabels([_NET_NAMES[i] for i in order], fontsize=9.5, color="#ddd")
199
  ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
200
  ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
201
  ax2.invert_yaxis()
@@ -250,7 +299,7 @@ def _saliency_figure(sal, p_mean):
250
  ax3.set_box_aspect([1.2, 1.4, 1.0])
251
 
252
  fig.suptitle(
253
- f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
254
  color="#444", fontsize=8.5, y=1.02,
255
  )
256
  plt.tight_layout()
@@ -267,6 +316,7 @@ def run_gcn(file_path):
267
  return "", "", "", None
268
 
269
  path = Path(file_path)
 
270
  try:
271
  if path.suffix == ".npz":
272
  d = np.load(path, allow_pickle=True)
@@ -282,13 +332,40 @@ def run_gcn(file_path):
282
  adj_t = torch.FloatTensor(adj).unsqueeze(0)
283
  else:
284
  bold = np.loadtxt(path, dtype=np.float32)
285
- if bold.ndim != 2 or bold.shape[1] != 200:
286
- return f"Error: expected (200), got {bold.shape}", "", "", None
 
 
 
 
 
 
 
 
 
 
 
 
287
  bw_t, adj_t = preprocess(bold)
288
  except Exception as e:
289
  return f"Error loading file: {e}", "", "", None
290
 
291
- models = get_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  per_model = []
293
  with torch.no_grad():
294
  for site, task in models:
@@ -300,7 +377,12 @@ def run_gcn(file_path):
300
  conf = max(p_mean, 1 - p_mean) * 100
301
 
302
  try:
303
- sal_img = _saliency_figure(_compute_saliency(bw_t, adj_t, models), p_mean)
 
 
 
 
 
304
  except Exception:
305
  sal_img = None
306
 
@@ -398,7 +480,7 @@ LOSO AUC = 0.7872 · 529 held-out subjects · 4 institutions
398
  <div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">ICD-10 Classification</div>
399
  <div style="color:#cbd5e1;font-size:0.84rem;line-height:1.4">{icd}</div></div>
400
  <div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">Ensemble Confidence</div>
401
- <div style="color:#cbd5e1;font-size:0.84rem">{conf:.1f}% · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO</div></div>
402
  </div>
403
 
404
  <div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Impression</div>
@@ -576,7 +658,7 @@ ARCHITECTURE = """
576
  <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:14px 16px;min-width:130px;text-align:center;flex-shrink:0">
577
  <div style="color:#8b95a7;font-size:0.65rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:6px">Input</div>
578
  <div style="color:#f4f4f5;font-weight:600;font-size:0.88rem">fMRI BOLD</div>
579
- <div style="color:#5e6675;font-size:0.74rem;margin-top:3px">T × 200 ROIs</div>
580
  </div>
581
 
582
  <div style="color:#252a35;font-size:1.4rem;padding:0 6px;flex-shrink:0">→</div>
@@ -637,7 +719,7 @@ ARCHITECTURE = """
637
  <div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden">
638
  <table style="width:100%;border-collapse:collapse;font-size:0.85rem">
639
  <tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Dataset</td><td style="padding:10px 16px;color:#cbd5e1">ABIDE I · 1,102 subjects · 17 acquisition sites</td></tr>
640
- <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Parcellation</td><td style="padding:10px 16px;color:#cbd5e1">CC200 (Craddock 2012) · 200 functional ROIs</td></tr>
641
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Model</td><td style="padding:10px 16px;color:#cbd5e1">AdversarialBrainModeNetwork · K=16 modes · hidden_dim=64</td></tr>
642
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Validation</td><td style="padding:10px 16px;color:#cbd5e1">LOSO AUC = <span style="color:#ef4444;font-weight:600">0.7872</span> · 529 held-out subjects · 0 confident misclassifications</td></tr>
643
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Interpretability</td><td style="padding:10px 16px;color:#cbd5e1">Real-time gradient saliency · 7-network aggregation · 3D brain surface</td></tr>
@@ -715,7 +797,7 @@ with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as dem
715
 
716
  with gr.Tabs():
717
  with gr.Tab("Analysis"):
718
- file_input = gr.File(label="Upload CC200 fMRI (.1D or .npz)", type="filepath")
719
  gr.HTML("<div style='color:#8b95a7;font-size:0.7rem;text-transform:uppercase;letter-spacing:1.2px;margin:14px 0 8px;font-weight:500'>Or try a real ABIDE subject from a held-out site</div>")
720
  with gr.Row():
721
  btn_asd = gr.Button("ASD · Stanford 0051160", size="sm")
 
17
  _MAX_WINDOWS = 30
18
  _FC_THRESHOLD = 0.2
19
 
20
+ # ── Atlas configurations ────────────────────────────────────────────────────
21
+ # CC200 Yeo 7-network parcellation (approximate ROI ordering)
22
+ _ATLAS_CFG = {
23
+ "cc200": {
24
+ "n_rois": 200,
25
+ "label": "CC200",
26
+ "net_names": ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"],
27
+ "net_bounds": [0, 38, 69, 99, 137, 165, 180, 200],
28
+ "net_colors": ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"],
29
+ "ckpts": {
30
+ "NYU": Path("checkpoints/nyu.ckpt"),
31
+ "USM": Path("checkpoints/usm.ckpt"),
32
+ "UCLA": Path("checkpoints/ucla.ckpt"),
33
+ "UM": Path("checkpoints/um.ckpt"),
34
+ },
35
+ },
36
+ "aal": {
37
+ "n_rois": 116,
38
+ "label": "AAL-116",
39
+ # Approximate Yeo-7 parcellation for AAL-116 anatomical ordering:
40
+ # Frontal/FPN (1-28), Sensorimotor (29-40), DMN parietal (41-60),
41
+ # Temporal/DMN (61-76), Subcortical (77-90), Occipital/Visual (91-116)
42
+ "net_names": ["Frontoparietal", "Sensorimotor", "Dorsal Attn", "DMN", "Salience", "Subcortical", "Visual"],
43
+ "net_bounds": [0, 20, 34, 50, 68, 80, 92, 116],
44
+ "net_colors": ["#457b9d", "#2dc653", "#8b5cf6", "#e63946", "#f4a261", "#6b7280", "#a8dadc"],
45
+ "ckpts": {
46
+ "NYU": Path("checkpoints/aal_nyu.ckpt"),
47
+ "USM": Path("checkpoints/aal_usm.ckpt"),
48
+ "UCLA": Path("checkpoints/aal_ucla.ckpt"),
49
+ "UM": Path("checkpoints/aal_um.ckpt"),
50
+ },
51
+ },
52
+ "ho": {
53
+ "n_rois": 111,
54
+ "label": "Harvard-Oxford",
55
+ "net_names": ["Frontoparietal", "Sensorimotor", "DMN", "Salience", "Subcortical", "Visual", "Temporal"],
56
+ "net_bounds": [0, 18, 30, 48, 68, 80, 96, 111],
57
+ "net_colors": ["#457b9d", "#2dc653", "#e63946", "#f4a261", "#6b7280", "#a8dadc", "#8b5cf6"],
58
+ "ckpts": {
59
+ "NYU": Path("checkpoints/ho_nyu.ckpt"),
60
+ "USM": Path("checkpoints/ho_usm.ckpt"),
61
+ "UCLA": Path("checkpoints/ho_ucla.ckpt"),
62
+ "UM": Path("checkpoints/ho_um.ckpt"),
63
+ },
64
+ },
65
  }
66
 
67
+ # Resolve active atlas config by ROI count
68
+ _ROI_TO_ATLAS = {cfg["n_rois"]: key for key, cfg in _ATLAS_CFG.items()}
69
+
70
+ # Legacy aliases kept for backward compat
71
+ _NET_NAMES = _ATLAS_CFG["cc200"]["net_names"]
72
+ _NET_BOUNDS = _ATLAS_CFG["cc200"]["net_bounds"]
73
+ _NET_COLORS = _ATLAS_CFG["cc200"]["net_colors"]
74
+ _CKPTS = _ATLAS_CFG["cc200"]["ckpts"]
75
+
76
  # ── preprocessing ──────────────────────────────────────────────────────────
77
 
78
  def _zscore(bold):
 
104
 
105
  # ── model loading ──────────────────────────────────────────────────────────
106
 
107
# Per-atlas cache of loaded (site, task) ensembles.
_model_cache: dict[str, list] = {}

def get_models(atlas: str = "cc200"):
    """Load and memoise the LOSO ensemble for *atlas*.

    Unknown atlas keys fall back to "cc200". The fallback is resolved BEFORE
    the cache lookup so models are cached under the key actually used
    (previously an unknown key cached the cc200 models under the bogus name).
    Checkpoints missing from disk are skipped, so the returned list may hold
    fewer than four (site, task) pairs — possibly none.

    Returns:
        list of (site_name, ClassificationTask) pairs in eval mode.
    """
    if atlas not in _ATLAS_CFG:
        atlas = "cc200"
    if atlas in _model_cache:
        return _model_cache[atlas]
    from brain_gcn.tasks import ClassificationTask
    models = []
    for site, ckpt in _ATLAS_CFG[atlas]["ckpts"].items():
        if not ckpt.exists():
            continue  # checkpoint not shipped with this deployment
        task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
        task.eval()
        models.append((site, task))
    _model_cache[atlas] = models
    return models
124
 
125
  # ── gradient saliency ──────────────────────────────────────────────────────
126
 
 
145
  [ 14, 4, 4], # Subcortical (thalamus)
146
  ], dtype=np.float32)
147
 
148
+ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=None):
149
  import matplotlib
150
  matplotlib.use("Agg")
151
  import matplotlib.pyplot as plt
 
153
  from mpl_toolkits.mplot3d.art3d import Line3DCollection
154
  from PIL import Image
155
 
156
+ _nn = net_names if net_names is not None else _NET_NAMES
157
+ _nb = net_bounds if net_bounds is not None else _NET_BOUNDS
158
+ _nc = net_colors if net_colors is not None else _NET_COLORS
159
+ n_nets = len(_nn)
160
 
161
+ # Aggregate N×N saliency → 7×7 network-level matrix
162
  net_sal = np.zeros((n_nets, n_nets))
163
+ for i, (s1, e1) in enumerate(zip(_nb[:-1], _nb[1:])):
164
+ for j, (s2, e2) in enumerate(zip(_nb[:-1], _nb[1:])):
165
  net_sal[i, j] = sal[s1:e1, s2:e2].mean()
166
 
167
  # Network importance: mean outgoing + incoming saliency per network
168
  net_imp = np.array([
169
  sal[s:e, :].mean() + sal[:, s:e].mean()
170
+ for s, e in zip(_nb[:-1], _nb[1:])
171
  ])
172
 
173
  fig = plt.figure(figsize=(18, 5.5))
 
186
 
187
  ax.set_xticks(range(n_nets))
188
  ax.set_yticks(range(n_nets))
189
+ ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=9, color="#ccc")
190
+ ax.set_yticklabels(_nn, fontsize=9, color="#ccc")
191
  ax.tick_params(colors="#555", length=0)
192
  for sp in ax.spines.values():
193
  sp.set_color("#222")
 
222
 
223
  # Callout labels for top-3 cross-network edges
224
  for rank, (score, i, j) in enumerate(top3_edges):
225
+ label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
226
  ax.annotate(label,
227
  xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
228
  fontsize=6, color="#fb923c", fontweight="600",
 
242
 
243
  order = net_imp.argsort()[::-1]
244
  bars = ax2.barh(range(n_nets), net_imp[order],
245
+ color=[_nc[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
246
  ax2.set_yticks(range(n_nets))
247
+ ax2.set_yticklabels([_nn[i] for i in order], fontsize=9.5, color="#ddd")
248
  ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
249
  ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
250
  ax2.invert_yaxis()
 
299
  ax3.set_box_aspect([1.2, 1.4, 1.0])
300
 
301
  fig.suptitle(
302
+ f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
303
  color="#444", fontsize=8.5, y=1.02,
304
  )
305
  plt.tight_layout()
 
316
  return "", "", "", None
317
 
318
  path = Path(file_path)
319
+ atlas_key = "cc200" # default; overridden below for .1D files
320
  try:
321
  if path.suffix == ".npz":
322
  d = np.load(path, allow_pickle=True)
 
332
  adj_t = torch.FloatTensor(adj).unsqueeze(0)
333
  else:
334
  bold = np.loadtxt(path, dtype=np.float32)
335
+ if bold.ndim != 2:
336
+ return "<div style='color:#ef4444;padding:12px'>Error: file must be a 2D ROIs matrix.</div>", "", "", None
337
+ n_rois = bold.shape[1]
338
+ atlas_key = _ROI_TO_ATLAS.get(n_rois)
339
+ if atlas_key is None:
340
+ supported = ", ".join(f"{cfg['label']} ({cfg['n_rois']} ROIs)" for cfg in _ATLAS_CFG.values())
341
+ return (
342
+ f"<div style='background:#1a1015;border-left:3px solid #ef4444;padding:16px 20px;border-radius:8px;margin-top:14px'>"
343
+ f"<div style='color:#ef4444;font-weight:600;margin-bottom:6px'>Unsupported atlas ({n_rois} ROIs)</div>"
344
+ f"<div style='color:#cbd5e1;font-size:0.88rem;line-height:1.6'>"
345
+ f"Supported: {supported}.<br>"
346
+ f"Download from FCP-INDI S3: <code style='color:#fb923c'>rois_cc200/</code>, <code style='color:#fb923c'>rois_aal/</code>, or <code style='color:#fb923c'>rois_ho/</code>"
347
+ f"</div></div>"
348
+ ), "", "", None
349
  bw_t, adj_t = preprocess(bold)
350
  except Exception as e:
351
  return f"Error loading file: {e}", "", "", None
352
 
353
+
354
+
355
+ atlas_cfg = _ATLAS_CFG[atlas_key]
356
+ models = get_models(atlas_key)
357
+
358
+ if not models:
359
+ atlas_label = atlas_cfg["label"]
360
+ return (
361
+ f"<div style='background:#1a1015;border-left:3px solid #f59e0b;padding:16px 20px;border-radius:8px;margin-top:14px'>"
362
+ f"<div style='color:#f59e0b;font-weight:600;margin-bottom:6px'>{atlas_label} models not yet available</div>"
363
+ f"<div style='color:#cbd5e1;font-size:0.88rem;line-height:1.6'>"
364
+ f"Training is in progress. CC200 models are available now — convert your data with:<br>"
365
+ f"<code style='color:#fb923c;font-size:0.82rem'>aws s3 cp s3://fcp-indi/.../rois_cc200/ . --no-sign-request --recursive</code>"
366
+ f"</div></div>"
367
+ ), "", "", None
368
+
369
  per_model = []
370
  with torch.no_grad():
371
  for site, task in models:
 
377
  conf = max(p_mean, 1 - p_mean) * 100
378
 
379
  try:
380
+ sal_img = _saliency_figure(
381
+ _compute_saliency(bw_t, adj_t, models), p_mean,
382
+ net_names=atlas_cfg["net_names"],
383
+ net_bounds=atlas_cfg["net_bounds"],
384
+ net_colors=atlas_cfg["net_colors"],
385
+ )
386
  except Exception:
387
  sal_img = None
388
 
 
480
  <div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">ICD-10 Classification</div>
481
  <div style="color:#cbd5e1;font-size:0.84rem;line-height:1.4">{icd}</div></div>
482
  <div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">Ensemble Confidence</div>
483
+ <div style="color:#cbd5e1;font-size:0.84rem">{conf:.1f}% · p(ASD) = {p_mean:.3f} · {len(models)}-model LOSO</div></div>
484
  </div>
485
 
486
  <div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Impression</div>
 
658
  <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:14px 16px;min-width:130px;text-align:center;flex-shrink:0">
659
  <div style="color:#8b95a7;font-size:0.65rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:6px">Input</div>
660
  <div style="color:#f4f4f5;font-weight:600;font-size:0.88rem">fMRI BOLD</div>
661
+ <div style="color:#5e6675;font-size:0.74rem;margin-top:3px">T × ROIs (CC200/AAL/HO)</div>
662
  </div>
663
 
664
  <div style="color:#252a35;font-size:1.4rem;padding:0 6px;flex-shrink:0">→</div>
 
719
  <div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden">
720
  <table style="width:100%;border-collapse:collapse;font-size:0.85rem">
721
  <tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Dataset</td><td style="padding:10px 16px;color:#cbd5e1">ABIDE I · 1,102 subjects · 17 acquisition sites</td></tr>
722
+ <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Parcellation</td><td style="padding:10px 16px;color:#cbd5e1">CC200 (200 ROIs) · AAL-116 (116 ROIs) · Harvard-Oxford (111 ROIs)</td></tr>
723
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Model</td><td style="padding:10px 16px;color:#cbd5e1">AdversarialBrainModeNetwork · K=16 modes · hidden_dim=64</td></tr>
724
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Validation</td><td style="padding:10px 16px;color:#cbd5e1">LOSO AUC = <span style="color:#ef4444;font-weight:600">0.7872</span> · 529 held-out subjects · 0 confident misclassifications</td></tr>
725
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Interpretability</td><td style="padding:10px 16px;color:#cbd5e1">Real-time gradient saliency · 7-network aggregation · 3D brain surface</td></tr>
 
797
 
798
  with gr.Tabs():
799
  with gr.Tab("Analysis"):
800
+ file_input = gr.File(label="Upload fMRI time series — CC200 (200), AAL (116), or Harvard-Oxford (111) ROIs · .1D or .npz", type="filepath")
801
  gr.HTML("<div style='color:#8b95a7;font-size:0.7rem;text-transform:uppercase;letter-spacing:1.2px;margin:14px 0 8px;font-weight:500'>Or try a real ABIDE subject from a held-out site</div>")
802
  with gr.Row():
803
  btn_asd = gr.Button("ASD · Stanford 0051160", size="sm")
app_with_llm.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
3
+ Full pipeline: Adversarial GCN + Qwen2.5-7B fine-tuned on AMD MI300X.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import io
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ import gradio as gr
13
+
14
+ _WINDOW_LEN = 50
15
+ _STEP = 3
16
+ _MAX_WINDOWS = 30
17
+ _FC_THRESHOLD = 0.2
18
+
19
+ _CKPTS = {
20
+ "NYU": Path("checkpoints/nyu.ckpt"),
21
+ "USM": Path("checkpoints/usm.ckpt"),
22
+ "UCLA": Path("checkpoints/ucla.ckpt"),
23
+ "UM": Path("checkpoints/um.ckpt"),
24
+ }
25
+
26
+ _LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
27
+
28
+ SYSTEM_PROMPT = (
29
+ "You are a clinical AI assistant specializing in functional MRI brain "
30
+ "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
31
+ "You interpret outputs from a validated graph neural network (GCN) trained on "
32
+ "the ABIDE I dataset and provide structured clinical summaries for neurologists "
33
+ "and psychiatrists. Your reports are informative and evidence-based but always "
34
+ "clarify that findings are AI-assisted and should be integrated with full "
35
+ "clinical assessment. You do not make a diagnosis."
36
+ )
37
+
38
+ # ── preprocessing ──────────────────────────────────────────────────────────
39
+
40
+ def _zscore(bold):
41
+ mean = bold.mean(0, keepdims=True)
42
+ std = bold.std(0, keepdims=True)
43
+ std[std < 1e-8] = 1.0
44
+ return ((bold - mean) / std).astype(np.float32)
45
+
46
+ def _fc(bold):
47
+ fc = np.corrcoef(bold.T).astype(np.float32)
48
+ np.nan_to_num(fc, copy=False)
49
+ return fc
50
+
51
def _windows(bold):
    """Per-window ROI std features from a (T, N) BOLD array.

    Slides a _WINDOW_LEN-frame window with stride _STEP, takes the per-ROI
    std over time in each window, and pads (repeating the last row) or
    truncates so exactly _MAX_WINDOWS rows are returned.
    """
    n_frames = bold.shape[0]
    feats = [
        bold[s:s + _WINDOW_LEN].std(0)
        for s in range(0, n_frames - _WINDOW_LEN + 1, _STEP)
    ]
    w = np.stack(feats).astype(np.float32)
    if len(w) < _MAX_WINDOWS:
        pad = np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)
        w = np.concatenate([w, pad])
    return w[:_MAX_WINDOWS]
58
+
59
def preprocess(bold):
    """Turn a raw (T, N) BOLD array into model-ready tensors.

    Returns:
        (bw_t, adj_t): a (1, _MAX_WINDOWS, N) window-feature tensor and a
        (1, N, N) thresholded Fisher-z adjacency tensor.
    """
    z = _zscore(bold)
    # Fisher z-transform with clipping so arctanh stays finite at |r|→1.
    fisher = np.arctanh(np.clip(_fc(z), -0.9999, 0.9999))
    keep = np.abs(fisher) >= _FC_THRESHOLD
    adj = np.where(keep, fisher, 0.0).astype(np.float32)
    feats = _windows(z)
    return (
        torch.FloatTensor(feats).unsqueeze(0),
        torch.FloatTensor(adj).unsqueeze(0),
    )
66
+
67
+ # ── GCN model loading ──────────────────────────────────────────────────────
68
+
69
# Lazily-built list of (site, ClassificationTask) pairs; None until first use.
_models: list | None = None

def get_models():
    """Load every LOSO checkpoint present on disk, once, and memoise it.

    Missing checkpoint files are skipped silently, so the cached list may
    contain fewer than four entries.
    """
    global _models
    if _models is None:
        from brain_gcn.tasks import ClassificationTask
        _models = []
        for site, ckpt_path in _CKPTS.items():
            if not ckpt_path.exists():
                continue  # checkpoint not shipped with this deployment
            task = ClassificationTask.load_from_checkpoint(
                str(ckpt_path), map_location="cpu", strict=False
            )
            task.eval()
            _models.append((site, task))
    return _models
84
+
85
+ # ── LLM loading ────────────────────────────────────────────────────────────
86
+
87
# Memoised (model, tokenizer) pair; None until the first report is requested.
_llm = None

def get_llm():
    """Load the fine-tuned report LLM once; returns (model, tokenizer)."""
    global _llm
    if _llm is None:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        print(f"Loading LLM: {_LLM_MODEL}")
        tokenizer = AutoTokenizer.from_pretrained(_LLM_MODEL)
        # Qwen-style tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            _LLM_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model.eval()
        _llm = (model, tokenizer)
    return _llm
105
+
106
+
107
+ def _llm_report(p_mean: float, per_model: list) -> str:
108
+ consensus = sum(1 for _, p in per_model if p > 0.5)
109
+ per_model_str = "\n".join(
110
+ f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
111
+ )
112
+ conf_label = (
113
+ "HIGH" if p_mean >= 0.75 else
114
+ "MODERATE" if p_mean >= 0.6 else
115
+ "LOW / UNCERTAIN" if p_mean >= 0.4 else
116
+ "MODERATE (TC)" if p_mean >= 0.25 else
117
+ "HIGH (TC)"
118
+ )
119
+ user_msg = (
120
+ f"Brain Connectivity GCN Analysis Report\n"
121
+ f"{'='*40}\n"
122
+ f"p(ASD) : {p_mean:.3f}\n"
123
+ f"Confidence Level : {conf_label}\n"
124
+ f"Model Consensus : {consensus}/4 site-blind models predict ASD\n\n"
125
+ f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
126
+ f"Please provide a structured clinical interpretation of these findings."
127
+ )
128
+ try:
129
+ mdl, tok = get_llm()
130
+ messages = [
131
+ {"role": "system", "content": SYSTEM_PROMPT},
132
+ {"role": "user", "content": user_msg},
133
+ ]
134
+ text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
135
+ inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
136
+ with torch.no_grad():
137
+ out = mdl.generate(
138
+ **inputs,
139
+ max_new_tokens=512,
140
+ temperature=0.3,
141
+ do_sample=True,
142
+ pad_token_id=tok.eos_token_id,
143
+ )
144
+ generated = out[0][inputs["input_ids"].shape[1]:]
145
+ return tok.decode(generated, skip_special_tokens=True).strip()
146
+ except Exception as e:
147
+ return f"LLM unavailable: {e}"
148
+
149
+ # ── gradient saliency ──────────────────────────────────────────────────────
150
+
151
def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
    """Mean |∂p(ASD)/∂adjacency| across ensemble members, symmetrised.

    Each model gets its own leaf copy of the adjacency so gradients do not
    accumulate across members; the averaged map is symmetrised to N×N.
    """
    grad_maps = []
    for _, task in models:
        adj_in = adj_t.clone().requires_grad_(True)
        prob_asd = torch.softmax(task.model(bw_t, adj_in), -1)[0, 1]
        prob_asd.backward()
        grad_maps.append(adj_in.grad[0].abs().detach().numpy())
    mean_map = np.mean(grad_maps, axis=0)
    return (mean_map + mean_map.T) / 2
162
+
163
+
164
def _saliency_figure(sal: np.ndarray, p_mean: float):
    """Render the two-panel saliency figure and return it as a PIL image.

    Left panel: the FC edge-saliency matrix thresholded to its top 5%.
    Right panel: the 20 ROIs with the largest cumulative gradient.

    Args:
        sal: symmetric (N, N) saliency matrix from _compute_saliency.
        p_mean: ensemble-mean p(ASD); drives the verdict color and title.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend — no display available
    import matplotlib.pyplot as plt
    from PIL import Image

    # Keep only the strongest 5% of edges for the heat map.
    thresh = np.percentile(sal, 95)
    sal_top = np.where(sal >= thresh, sal, 0.0)
    roi_imp = sal.sum(1)
    top20 = roi_imp.argsort()[-20:][::-1]
    verdict_color = "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261"

    # BUGFIX: guard the module-level model cache — it is None until
    # get_models() has run, and len(None) raised TypeError here (silently
    # swallowed by the caller, so the figure never rendered).
    n_models = len(_models) if _models else 0

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    fig.patch.set_facecolor("#0d0d0d")

    ax = axes[0]
    ax.set_facecolor("#111")
    im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
    ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
    ax.set_xlabel("ROI index", color="#777", fontsize=9)
    ax.set_ylabel("ROI index", color="#777", fontsize=9)
    ax.tick_params(colors="#555", labelsize=8)
    cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
    plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666")
    for spine in ax.spines.values():
        spine.set_color("#333")

    ax2 = axes[1]
    ax2.set_facecolor("#111")
    ax2.barh(range(20), roi_imp[top20], color=verdict_color, alpha=0.75, edgecolor="none")
    ax2.set_yticks(range(20))
    ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc")
    ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
    ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
    ax2.tick_params(colors="#555", labelsize=8)
    ax2.invert_yaxis()
    for spine in ["top", "right"]:
        ax2.spines[spine].set_visible(False)
    for spine in ["bottom", "left"]:
        ax2.spines[spine].set_color("#333")

    fig.suptitle(
        f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {n_models} LOSO models",
        color="#888", fontsize=10, y=1.02,
    )
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
    plt.close(fig)
    buf.seek(0)
    # Copy so the image survives the buffer going out of scope.
    return Image.open(buf).copy()
216
+
217
+ # ── inference ──────────────────────────────────────────────────────────────
218
+
219
+ def run_gcn(file_path: str | None):
220
+ if file_path is None:
221
+ return "", "", "", None, ""
222
+
223
+ path = Path(file_path)
224
+ try:
225
+ if path.suffix == ".npz":
226
+ d = np.load(path, allow_pickle=True)
227
+ fc = d["mean_fc"].astype(np.float32)
228
+ fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
229
+ adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
230
+ bw = d["bold_windows"].astype(np.float32)
231
+ if len(bw) >= _MAX_WINDOWS:
232
+ bw = bw[:_MAX_WINDOWS]
233
+ else:
234
+ bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)])
235
+ bw_t = torch.FloatTensor(bw).unsqueeze(0)
236
+ adj_t = torch.FloatTensor(adj).unsqueeze(0)
237
+ else:
238
+ bold = np.loadtxt(path, dtype=np.float32)
239
+ if bold.ndim != 2 or bold.shape[1] != 200:
240
+ return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None, ""
241
+ bw_t, adj_t = preprocess(bold)
242
+ except Exception as e:
243
+ return f"⚠️ Error loading file: {e}", "", "", None, ""
244
+
245
+ models = get_models()
246
+
247
+ per_model = []
248
+ with torch.no_grad():
249
+ for site, task in models:
250
+ logits = task(bw_t, adj_t)
251
+ p = torch.softmax(logits, -1)[0, 1].item()
252
+ per_model.append((site, p))
253
+
254
+ p_mean = float(np.mean([p for _, p in per_model]))
255
+ consensus = sum(1 for _, p in per_model if p > 0.5)
256
+ conf = max(p_mean, 1 - p_mean) * 100
257
+
258
+ try:
259
+ sal = _compute_saliency(bw_t, adj_t, models)
260
+ sal_img = _saliency_figure(sal, p_mean)
261
+ except Exception:
262
+ sal_img = None
263
+
264
+ # Verdict
265
+ if p_mean > 0.6:
266
+ verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
267
+ <div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
268
+ <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; <b style="color:white">{consensus}/4</b> site-blind models agree</div>
269
+ </div>"""
270
+ elif p_mean < 0.4:
271
+ verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #2dc653;padding:24px 28px;border-radius:12px;margin-bottom:8px">
272
+ <div style="font-size:2rem;font-weight:800;color:#2dc653;letter-spacing:1px">TYPICAL CONTROL</div>
273
+ <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; <b style="color:white">{4-consensus}/4</b> site-blind models agree</div>
274
+ </div>"""
275
+ else:
276
+ verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #f4a261;padding:24px 28px;border-radius:12px;margin-bottom:8px">
277
+ <div style="font-size:2rem;font-weight:800;color:#f4a261;letter-spacing:1px">INCONCLUSIVE</div>
278
+ <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; Model disagreement — clinical review required</div>
279
+ </div>"""
280
+
281
+ # Ensemble breakdown
282
+ rows = ""
283
+ for site, p in per_model:
284
+ lbl = "ASD" if p > 0.5 else "TC"
285
+ color = "#e63946" if p > 0.5 else "#2dc653"
286
+ bar_w = int(p * 100)
287
+ rows += f"""<tr>
288
+ <td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td>
289
+ <td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px">
290
+ <div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td>
291
+ <td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td>
292
+ <td style="padding:8px 12px;color:#888">p={p:.3f}</td>
293
+ </tr>"""
294
+
295
+ ensemble = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
296
+ <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div>
297
+ <table style="width:100%;border-collapse:collapse">{rows}</table>
298
+ <div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {consensus}/4 models agree &nbsp;·&nbsp; LOSO AUC = 0.7872 across 529 held-out subjects</div>
299
+ </div>"""
300
+
301
+ # LLM clinical report
302
+ llm_text = _llm_report(p_mean, per_model)
303
+ report = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
304
+ <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Report — Qwen2.5-7B fine-tuned on AMD Instinct MI300X</div>
305
+ <div style="color:#ddd;font-size:0.95rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>
306
+ <div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#555;font-size:0.78rem;margin-top:16px">
307
+ ⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments (ADOS-2, ADI-R).
308
+ </div></div>"""
309
+
310
+ return verdict, ensemble, report, sal_img
311
+
312
+
313
+ # ── UI ─────────────────────────────────────────────────────────────────────
314
+
315
# Page-level CSS injected into the Gradio app: dark background and a
# centred, width-capped main column.
css = "\n".join([
    "",
    "body { background: #0d0d0d; }",
    ".gradio-container { max-width: 960px; margin: auto; }",
    "",
])
319
+
320
# App layout. Component creation order matters in a Blocks context:
# header banner, file input, result panels, saliency image, report, footer.
with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
    # Hero header with headline metrics.
    gr.HTML("""
<div style="text-align:center;padding:32px 0 16px">
<div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
<div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
<div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
</div>
</div>
""")

    # Input + result placeholders, filled by run_gcn on upload.
    uploader = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
    verdict_box = gr.HTML()
    ensemble_box = gr.HTML()

    # Saliency section heading + rendered figure.
    gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
    saliency_view = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")

    report_box = gr.HTML()

    # Run the full pipeline whenever a new file is selected.
    # NOTE(review): run_gcn's early error paths appear to return 5 values
    # while only 4 outputs are wired here — confirm and align the arity.
    uploader.change(
        fn=run_gcn,
        inputs=uploader,
        outputs=[verdict_box, ensemble_box, report_box, saliency_view],
    )

    # Footer with project credits and repository link.
    gr.HTML("""
<div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
Adversarial Brain-Mode GCN (k=16) · Qwen2.5-7B LoRA (AMD MI300X) · ABIDE I ·
<a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
</div>
""")
355
+
356
# Warm both model caches at import time so the first user request is not
# stalled by checkpoint/LLM loading. Order preserved: GCN ensemble first,
# then the report LLM.
for _banner, _loader in (
    ("Preloading GCN models...", get_models),
    ("Preloading LLM...", get_llm),
):
    print(_banner)
    _loader()
print("All models ready.")

if __name__ == "__main__":
    demo.launch()
checkpoints/aal_nyu.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d588597092f4b9483fdcfb0f8d5aae67030ea9d17df8cf7a0c027ef527b5657
3
+ size 253386
checkpoints/aal_ucla.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3af986f0e5909e8551c2473ee3ff876ff9cc814ea38e220de33fad96201eaa37
3
+ size 253386
checkpoints/aal_um.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b380df053b6a9a83dc6062070418cd6b8718b4f4c5c5f8b388330dcf0a9abf5
3
+ size 253386
checkpoints/aal_usm.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7410692b93237e59fe5a03726338b7f02f0b61acbb6185dc1ecc79dfbee6bdac
3
+ size 252813