Aryagm commited on
Commit
222d844
·
verified ·
1 Parent(s): fde1d5f

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +71 -7
app.py CHANGED
@@ -99,22 +99,75 @@ def min_max_normalize(data):
99
  return (data - data_min) / (data_max - data_min)
100
 
101
  def preprocess_volume(data):
102
- """Preprocess MRI volume for model input"""
 
 
 
 
 
103
  # Normalize
104
  data = min_max_normalize(data)
105
 
106
  # Ensure float32
107
  data = data.astype(np.float32)
108
 
109
- # Transpose if needed (depends on model training)
110
- # Model expects [batch, D, H, W, channels]
111
- data = np.transpose(data, (2, 1, 0)) # Adjust axes as needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  # Add batch and channel dimensions
114
  data = np.expand_dims(data, axis=0) # batch
115
  data = np.expand_dims(data, axis=-1) # channel
116
 
117
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def run_inference(data):
120
  """Run model inference on preprocessed data"""
@@ -177,13 +230,14 @@ async def segment(file: UploadFile = File(...)):
177
 
178
  # Preprocess
179
  preprocess_start = time.time()
180
- processed = preprocess_volume(data)
181
  preprocess_time = time.time() - preprocess_start
182
  print(f"Preprocessed shape: {processed.shape}, Time: {preprocess_time:.2f}s")
183
 
184
  # Run inference
185
  inference_start = time.time()
186
  segmentation = run_inference(processed)
 
187
  inference_time = time.time() - inference_start
188
  print(f"Inference time: {inference_time:.2f}s")
189
 
@@ -231,9 +285,19 @@ async def segment_compact(file: UploadFile = File(...)):
231
  raise HTTPException(400, "File must be a NIfTI file (.nii or .nii.gz)")
232
 
233
  file_bytes = await file.read()
 
 
234
  data, header = parse_nifti(file_bytes, file.filename)
235
- processed = preprocess_volume(data)
 
 
 
 
236
  segmentation = run_inference(processed)
 
 
 
 
237
 
238
  total_time = time.time() - start_time
239
 
 
99
  return (data - data_min) / (data_max - data_min)
100
 
101
  def preprocess_volume(data):
102
+ """
103
+ Preprocess MRI volume for model input.
104
+ Returns preprocessed data and info needed to unpad output.
105
+ """
106
+ original_shape = data.shape
107
+
108
  # Normalize
109
  data = min_max_normalize(data)
110
 
111
  # Ensure float32
112
  data = data.astype(np.float32)
113
 
114
+ # Transpose to match model expectations
115
+ data = np.transpose(data, (2, 1, 0))
116
+ transposed_shape = data.shape
117
+
118
+ # Pad to 256x256x256 if needed (model requires fixed input size)
119
+ target_shape = (256, 256, 256)
120
+ current_shape = data.shape
121
+ pad_info = None
122
+
123
+ if current_shape != target_shape:
124
+ print(f"Padding volume from {current_shape} to {target_shape}")
125
+ padded = np.zeros(target_shape, dtype=np.float32)
126
+
127
+ # Calculate padding offsets (center the volume)
128
+ offsets = [(t - c) // 2 for t, c in zip(target_shape, current_shape)]
129
+
130
+ # Handle cases where input is larger than target (crop instead)
131
+ slices_src = []
132
+ slices_dst = []
133
+ for i in range(3):
134
+ if current_shape[i] <= target_shape[i]:
135
+ # Pad: source is full, destination is offset
136
+ slices_src.append(slice(0, current_shape[i]))
137
+ slices_dst.append(slice(offsets[i], offsets[i] + current_shape[i]))
138
+ else:
139
+ # Crop: source is cropped, destination is full
140
+ start = (current_shape[i] - target_shape[i]) // 2
141
+ slices_src.append(slice(start, start + target_shape[i]))
142
+ slices_dst.append(slice(0, target_shape[i]))
143
+
144
+ padded[slices_dst[0], slices_dst[1], slices_dst[2]] = data[slices_src[0], slices_src[1], slices_src[2]]
145
+ data = padded
146
+ pad_info = {
147
+ 'original_transposed_shape': transposed_shape,
148
+ 'slices_dst': slices_dst
149
+ }
150
 
151
  # Add batch and channel dimensions
152
  data = np.expand_dims(data, axis=0) # batch
153
  data = np.expand_dims(data, axis=-1) # channel
154
 
155
+ return data, pad_info, original_shape
156
+
157
+
158
+ def postprocess_segmentation(segmentation, pad_info, original_shape):
159
+ """
160
+ Remove padding from segmentation output and transpose back to original orientation.
161
+ """
162
+ # If we padded, extract the original region
163
+ if pad_info is not None:
164
+ slices = pad_info['slices_dst']
165
+ segmentation = segmentation[slices[0], slices[1], slices[2]]
166
+
167
+ # Transpose back to original orientation
168
+ segmentation = np.transpose(segmentation, (2, 1, 0))
169
+
170
+ return segmentation
171
 
172
  def run_inference(data):
173
  """Run model inference on preprocessed data"""
 
230
 
231
  # Preprocess
232
  preprocess_start = time.time()
233
+ processed, pad_info, original_shape = preprocess_volume(data)
234
  preprocess_time = time.time() - preprocess_start
235
  print(f"Preprocessed shape: {processed.shape}, Time: {preprocess_time:.2f}s")
236
 
237
  # Run inference
238
  inference_start = time.time()
239
  segmentation = run_inference(processed)
240
+ segmentation = postprocess_segmentation(segmentation, pad_info, original_shape)
241
  inference_time = time.time() - inference_start
242
  print(f"Inference time: {inference_time:.2f}s")
243
 
 
285
  raise HTTPException(400, "File must be a NIfTI file (.nii or .nii.gz)")
286
 
287
  file_bytes = await file.read()
288
+ print(f"Processing: {file.filename}, size: {len(file_bytes)} bytes")
289
+
290
  data, header = parse_nifti(file_bytes, file.filename)
291
+ print(f"Parsed volume shape: {data.shape}")
292
+
293
+ processed, pad_info, original_shape = preprocess_volume(data)
294
+ print(f"Preprocessed shape: {processed.shape}")
295
+
296
  segmentation = run_inference(processed)
297
+ print(f"Raw segmentation shape: {segmentation.shape}")
298
+
299
+ segmentation = postprocess_segmentation(segmentation, pad_info, original_shape)
300
+ print(f"Final segmentation shape: {segmentation.shape}")
301
 
302
  total_time = time.time() - start_time
303