IFMedTechdemo commited on
Commit
57892d7
·
verified ·
1 Parent(s): e3b4744

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +56 -52
  3. machine_measurements.csv +3 -0
  4. requirements.txt +1 -0
.gitattributes CHANGED
@@ -38,3 +38,4 @@ examples/43522917.dat filter=lfs diff=lfs merge=lfs -text
38
  examples/45227415.dat filter=lfs diff=lfs merge=lfs -text
39
  examples/46642833.dat filter=lfs diff=lfs merge=lfs -text
40
  examples/49036311.dat filter=lfs diff=lfs merge=lfs -text
 
 
38
  examples/45227415.dat filter=lfs diff=lfs merge=lfs -text
39
  examples/46642833.dat filter=lfs diff=lfs merge=lfs -text
40
  examples/49036311.dat filter=lfs diff=lfs merge=lfs -text
41
+ machine_measurements.csv filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -6,59 +6,76 @@ import numpy as np
6
  import matplotlib.pyplot as plt
7
  import os
8
  import glob
 
9
  from labels_refined import get_refined_labels, CLASSES
10
  from model import ResNet1d
11
  from dataset import MIMICECGDataset
12
 
13
  # --- Configuration ---
14
- # HF Space configuration: Data is local
15
  DATA_DIR = "./examples"
16
- MODEL_PATH = "resnet_advanced.pth"
17
- DEVICE = torch.device("cpu") # Spaces usually CPU unless GPU requested
18
 
19
  # --- Load Resources ---
20
- print("Loading Model...")
 
 
 
 
21
  model = ResNet1d(num_classes=5).to(DEVICE)
22
  try:
23
- state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
24
  except:
25
- state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
26
  model.load_state_dict(state_dict)
27
  model.eval()
28
 
29
- # --- Pre-defined Metadata for Examples ---
30
- # Hardcoded to avoid uploading the sensitive/huge patient CSV
31
- example_metadata = {
32
- "40689238": {
33
- "diagnosis": "Sinus Rhythm (Normal)",
34
- "text": "Sinus rhythm\nNormal ECG"
35
- },
36
- "46642833": {
37
- "diagnosis": "Atrial Fibrillation",
38
- "text": "Atrial fibrillation\nRapid ventricular response"
39
- },
40
- "49036311": {
41
- "diagnosis": "Sinus Tachycardia",
42
- "text": "Sinus tachycardia\nPossible Left Atrial Enlargement"
43
- },
44
- "43522917": {
45
- "diagnosis": "Sinus Bradycardia",
46
- "text": "Sinus bradycardia\nOtherwise normal"
47
- },
48
- "45227415": {
49
- "diagnosis": "Ventricular Tachycardia (Rare)",
50
- "text": "Ventricular tachycardia\nUrgent attention required"
51
- }
52
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def load_signal(path):
55
- # Reusing logic from dataset.py
56
  if not os.path.exists(path):
57
  return None
58
 
59
  gain = 200.0
60
  with open(path, 'rb') as f:
61
- # File is raw int16 binary
62
  raw_data = np.fromfile(f, dtype=np.int16)
63
 
64
  n_leads = 12
@@ -77,12 +94,9 @@ def load_signal(path):
77
  return signal
78
 
79
  def plot_ecg(signal, title="12-Lead ECG"):
80
- """Generates a matplotlib figure for the 12-lead ECG"""
81
  leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
82
-
83
  fig, axes = plt.subplots(12, 1, figsize=(10, 20), sharex=True)
84
  plt.subplots_adjust(hspace=0.2)
85
-
86
  for i in range(12):
87
  axes[i].plot(signal[i], color='k', linewidth=0.8)
88
  axes[i].set_ylabel(leads[i], rotation=0, labelpad=20, fontsize=10, fontweight='bold')
@@ -91,44 +105,39 @@ def plot_ecg(signal, title="12-Lead ECG"):
91
  axes[i].spines['bottom'].set_visible(False if i < 11 else True)
92
  axes[i].spines['left'].set_visible(True)
93
  axes[i].grid(True, linestyle='--', alpha=0.5)
94
-
95
  axes[11].set_xlabel("Samples (500Hz)", fontsize=12)
96
  fig.suptitle(title, fontsize=16, y=0.90)
97
-
98
  return fig
99
 
100
  def predict_ecg(study_id):
101
- # Path is local in examples/
102
  path = os.path.join(DATA_DIR, f"{study_id}.dat")
103
-
104
  if not os.path.exists(path):
105
  return None, f"File not found for study {study_id}", {}
106
 
107
- # Load Signal
108
  signal = load_signal(path)
109
  if signal is None:
110
  return None, "Error loading signal", {}
111
 
112
- # Generate Plot
113
  fig = plot_ecg(signal, title=f"Study {study_id}")
114
 
115
- # Inference
116
- tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE) # (1, 12, 5000)
117
  with torch.no_grad():
118
  logits = model(tensor_sig)
119
  probs = torch.sigmoid(logits).cpu().numpy()[0]
120
 
121
- # Format Results
122
  results = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
123
 
124
- # Get True Text
125
  full_text = example_metadata.get(study_id, {}).get("text", "Unknown")
126
 
127
  return fig, results, full_text
128
 
129
  # --- Gradio UI ---
130
  examples = [[k, v["diagnosis"]] for k, v in example_metadata.items()]
131
- example_ids = [k for k in example_metadata.keys()]
 
 
 
 
132
 
133
  with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
134
  gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
@@ -136,17 +145,12 @@ with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
136
 
137
  with gr.Row():
138
  with gr.Column(scale=1):
139
- # Input
140
- study_input = gr.Dropdown(choices=example_ids, label="Select Example Study ID", value=example_ids[0])
141
-
142
- # Info
143
  gr.Markdown("### Example Descriptions")
144
  gr.DataFrame(headers=["Study ID", "Diagnosis"], value=examples, interactive=False)
145
-
146
  analyze_btn = gr.Button("Analyze ECG", variant="primary")
147
 
148
  with gr.Column(scale=2):
149
- # Output
150
  plot_output = gr.Plot(label="12-Lead ECG Visualization")
151
  label_output = gr.Label(label="AI Predictions")
152
  text_output = gr.Textbox(label="Original Clinical Report (Ground Truth context)", lines=5)
 
6
  import matplotlib.pyplot as plt
7
  import os
8
  import glob
9
+ from huggingface_hub import hf_hub_download
10
  from labels_refined import get_refined_labels, CLASSES
11
  from model import ResNet1d
12
  from dataset import MIMICECGDataset
13
 
14
  # --- Configuration ---
 
15
  DATA_DIR = "./examples"
16
+ CSV_PATH = "machine_measurements.csv" # Now local in Space
17
+ DEVICE = torch.device("cpu")
18
 
19
  # --- Load Resources ---
20
+ print("Downloading Model from Hub...")
21
+ # Downloads to local cache and returns path
22
+ model_path = hf_hub_download(repo_id="IFMedTech/ECG_Model", filename="resnet_advanced.pth")
23
+
24
+ print(f"Loading Model from {model_path}...")
25
  model = ResNet1d(num_classes=5).to(DEVICE)
26
  try:
27
+ state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
28
  except:
29
+ state_dict = torch.load(model_path, map_location=DEVICE)
30
  model.load_state_dict(state_dict)
31
  model.eval()
32
 
33
+ print("Loading Dataset Index...")
34
+ # Use CSV to dynamically find info for available examples
35
+ try:
36
+ df = pd.read_csv(CSV_PATH, low_memory=False)
37
+ print(f"Loaded CSV with {len(df)} records.")
38
+ except Exception as e:
39
+ print(f"Error loading CSV: {e}")
40
+ df = pd.DataFrame() # Fallback
41
+
42
+ # Scan examples folder for .dat files
43
+ example_files = glob.glob(os.path.join(DATA_DIR, "*.dat"))
44
+ available_study_ids = [os.path.splitext(os.path.basename(f))[0] for f in example_files]
45
+ print(f"Found examples: {available_study_ids}")
46
+
47
+ # Build Metadata for Gradio
48
+ example_metadata = {}
49
+ for sid in available_study_ids:
50
+ if df.empty:
51
+ example_metadata[sid] = {"diagnosis": "Unknown (CSV Missing)", "text": "N/A"}
52
+ continue
53
+
54
+ row = df[df['study_id'].astype(str) == str(sid)]
55
+ if not row.empty:
56
+ cols = [c for c in df.columns if 'report_' in c]
57
+ lines = [str(row.iloc[0][c]).strip() for c in cols if pd.notna(row.iloc[0][c]) and str(row.iloc[0][c]).strip() != '']
58
+ full_text = '\n'.join(lines)
59
+
60
+ # Simple diagnosis estimation from labels for display title
61
+ labels_vec = get_refined_labels(' '.join(lines))
62
+ active_classes = [CLASSES[i] for i, val in enumerate(labels_vec) if val == 1.0]
63
+ diagnosis = ", ".join(active_classes) if active_classes else "Normal/Other"
64
+
65
+ example_metadata[sid] = {
66
+ "diagnosis": diagnosis,
67
+ "text": full_text
68
+ }
69
+ else:
70
+ example_metadata[sid] = {"diagnosis": "Metadata Not Found", "text": "N/A"}
71
+
72
 
73
  def load_signal(path):
 
74
  if not os.path.exists(path):
75
  return None
76
 
77
  gain = 200.0
78
  with open(path, 'rb') as f:
 
79
  raw_data = np.fromfile(f, dtype=np.int16)
80
 
81
  n_leads = 12
 
94
  return signal
95
 
96
  def plot_ecg(signal, title="12-Lead ECG"):
 
97
  leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
 
98
  fig, axes = plt.subplots(12, 1, figsize=(10, 20), sharex=True)
99
  plt.subplots_adjust(hspace=0.2)
 
100
  for i in range(12):
101
  axes[i].plot(signal[i], color='k', linewidth=0.8)
102
  axes[i].set_ylabel(leads[i], rotation=0, labelpad=20, fontsize=10, fontweight='bold')
 
105
  axes[i].spines['bottom'].set_visible(False if i < 11 else True)
106
  axes[i].spines['left'].set_visible(True)
107
  axes[i].grid(True, linestyle='--', alpha=0.5)
 
108
  axes[11].set_xlabel("Samples (500Hz)", fontsize=12)
109
  fig.suptitle(title, fontsize=16, y=0.90)
 
110
  return fig
111
 
112
  def predict_ecg(study_id):
 
113
  path = os.path.join(DATA_DIR, f"{study_id}.dat")
 
114
  if not os.path.exists(path):
115
  return None, f"File not found for study {study_id}", {}
116
 
 
117
  signal = load_signal(path)
118
  if signal is None:
119
  return None, "Error loading signal", {}
120
 
 
121
  fig = plot_ecg(signal, title=f"Study {study_id}")
122
 
123
+ tensor_sig = torch.from_numpy(signal).float().unsqueeze(0).to(DEVICE)
 
124
  with torch.no_grad():
125
  logits = model(tensor_sig)
126
  probs = torch.sigmoid(logits).cpu().numpy()[0]
127
 
 
128
  results = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
129
 
 
130
  full_text = example_metadata.get(study_id, {}).get("text", "Unknown")
131
 
132
  return fig, results, full_text
133
 
134
  # --- Gradio UI ---
135
  examples = [[k, v["diagnosis"]] for k, v in example_metadata.items()]
136
+ examples.sort(key=lambda x: x[0])
137
+ example_ids = [k[0] for k in examples]
138
+
139
+ if not example_ids:
140
+ example_ids = ["No Examples Found"]
141
 
142
  with gr.Blocks(title="ECG Arrhythmia Classifier") as demo:
143
  gr.Markdown("# 🫀 AI ECG Arrhythmia Classifier")
 
145
 
146
  with gr.Row():
147
  with gr.Column(scale=1):
148
+ study_input = gr.Dropdown(choices=example_ids, label="Select Example Study ID", value=example_ids[0] if example_ids else None)
 
 
 
149
  gr.Markdown("### Example Descriptions")
150
  gr.DataFrame(headers=["Study ID", "Diagnosis"], value=examples, interactive=False)
 
151
  analyze_btn = gr.Button("Analyze ECG", variant="primary")
152
 
153
  with gr.Column(scale=2):
 
154
  plot_output = gr.Plot(label="12-Lead ECG Visualization")
155
  label_output = gr.Label(label="AI Predictions")
156
  text_output = gr.Textbox(label="Original Clinical Report (Ground Truth context)", lines=5)
machine_measurements.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56f6b1413221bce95bd6f48b28ca1acf27ae0b073d6f2c1d12f3af7500eabbb6
3
+ size 182674683
requirements.txt CHANGED
@@ -4,3 +4,4 @@ numpy
4
  matplotlib
5
  gradio
6
  scipy
 
 
4
  matplotlib
5
  gradio
6
  scipy
7
+ huggingface_hub