Benny-Tang commited on
Commit
69dcdac
·
verified ·
1 Parent(s): 60fcd9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import json
2
  import numpy as np
3
  import gradio as gr
4
  from PIL import Image
@@ -16,44 +16,54 @@ MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
16
  MODEL.eval()
17
  PATHOLOGIES = MODEL.pathologies
18
 
19
- # TorchXRayVision transforms (keep separate, not in torchvision.Compose)
20
- center_crop = xrv.datasets.XRayCenterCrop()
21
- resize = xrv.datasets.XRayResizer(224)
22
-
23
  def imaging_agent(image_path: str):
24
  if not image_path:
25
  return "No image provided.", None
26
  try:
 
27
  img = Image.open(image_path).convert("L")
28
  arr = np.array(img).astype(np.float32)
 
 
29
  if arr.max() > 1:
30
  arr /= 255.0
 
 
31
  arr = xrv.datasets.normalize(arr, 4096)
32
 
33
- # TorchXRayVision expects dict
34
  sample = {"img": arr}
35
- sample = center_crop(sample)
36
- sample = resize(sample)
37
- arr = sample["img"]
38
 
 
 
 
 
 
 
 
39
  arr = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
40
 
 
41
  with torch.no_grad():
42
  preds = MODEL(arr)[0]
43
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
44
 
45
- # Focus on lung pathologies relevant to cancer
46
  focus_labels = ["Lung Opacity", "Mass", "Nodule"]
47
  focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
48
 
 
49
  lines = [f"{name}: {p*100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
50
  table = {name: round(p, 4) for name, p in focus}
 
51
  return "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines), json.dumps(table, indent=2)
 
52
  except Exception as e:
53
  return f"Imaging agent error: {e}", None
54
 
 
55
  # -----------------------------
56
- # Lab Agent (tumor markers, stub thresholds)
57
  # -----------------------------
58
  CANCER_MARKERS = {
59
  "psa": {"unit": "ng/mL", "high": 4},
@@ -80,6 +90,7 @@ def lab_agent(text: str):
80
  return "Could not parse tumor markers."
81
  return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + ("\nFlags: " + ", ".join(flags) if flags else "\nFlags: none")
82
 
 
83
  # -----------------------------
84
  # Coordinator
85
  # -----------------------------
@@ -90,6 +101,7 @@ def coordinator(imaging_txt, lab_txt):
90
  summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
91
  return summary
92
 
 
93
  # -----------------------------
94
  # Demo samples
95
  # -----------------------------
@@ -99,6 +111,7 @@ SAMPLES = {
99
  }
100
  SAMPLE_LABS = "PSA: 8 ng/mL\nCA125: 20 U/mL\nAFP: 15 ng/mL"
101
 
 
102
  # -----------------------------
103
  # Runner
104
  # -----------------------------
@@ -108,6 +121,7 @@ def run_all(image, labs):
108
  coord = coordinator(txt, lab)
109
  return txt, raw, lab, coord
110
 
 
111
  # -----------------------------
112
  # Gradio UI
113
  # -----------------------------
@@ -147,3 +161,4 @@ demo.launch()
147
 
148
 
149
 
 
 
1
+ import json
2
  import numpy as np
3
  import gradio as gr
4
  from PIL import Image
 
16
  MODEL.eval()
17
  PATHOLOGIES = MODEL.pathologies
18
 
 
 
 
 
19
  def imaging_agent(image_path: str):
20
  if not image_path:
21
  return "No image provided.", None
22
  try:
23
+ # Load grayscale X-ray
24
  img = Image.open(image_path).convert("L")
25
  arr = np.array(img).astype(np.float32)
26
+
27
+ # Normalize to [0,1]
28
  if arr.max() > 1:
29
  arr /= 255.0
30
+
31
+ # TorchXRayVision normalization
32
  arr = xrv.datasets.normalize(arr, 4096)
33
 
34
+ # Wrap in dict for TorchXRayVision
35
  sample = {"img": arr}
 
 
 
36
 
37
+ # Apply transforms step by step
38
+ sample = xrv.datasets.XRayCenterCrop()(sample)
39
+ sample = xrv.datasets.XRayResizer(224)(sample)
40
+
41
+ arr = sample["img"] # Unpack back to array
42
+
43
+ # Convert to torch tensor
44
  arr = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
45
 
46
+ # Inference
47
  with torch.no_grad():
48
  preds = MODEL(arr)[0]
49
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
50
 
51
+ # Focus only on cancer-relevant lung pathologies
52
  focus_labels = ["Lung Opacity", "Mass", "Nodule"]
53
  focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
54
 
55
+ # Format outputs
56
  lines = [f"{name}: {p*100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
57
  table = {name: round(p, 4) for name, p in focus}
58
+
59
  return "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines), json.dumps(table, indent=2)
60
+
61
  except Exception as e:
62
  return f"Imaging agent error: {e}", None
63
 
64
+
65
  # -----------------------------
66
+ # Lab Agent (tumor markers, thresholds stub)
67
  # -----------------------------
68
  CANCER_MARKERS = {
69
  "psa": {"unit": "ng/mL", "high": 4},
 
90
  return "Could not parse tumor markers."
91
  return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + ("\nFlags: " + ", ".join(flags) if flags else "\nFlags: none")
92
 
93
+
94
  # -----------------------------
95
  # Coordinator
96
  # -----------------------------
 
101
  summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
102
  return summary
103
 
104
+
105
  # -----------------------------
106
  # Demo samples
107
  # -----------------------------
 
111
  }
112
  SAMPLE_LABS = "PSA: 8 ng/mL\nCA125: 20 U/mL\nAFP: 15 ng/mL"
113
 
114
+
115
  # -----------------------------
116
  # Runner
117
  # -----------------------------
 
121
  coord = coordinator(txt, lab)
122
  return txt, raw, lab, coord
123
 
124
+
125
  # -----------------------------
126
  # Gradio UI
127
  # -----------------------------
 
161
 
162
 
163
 
164
+