Aryagm commited on
Commit
a2af649
·
verified ·
1 Parent(s): c090080

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +11 -1
  2. app.py +106 -6
README.md CHANGED
@@ -32,15 +32,25 @@ The `model18cls/` folder should now contain `model.h5` and `saved_model/`.
32
 
33
  ## API Endpoints
34
 
 
 
 
 
 
 
 
 
 
35
  ### POST /segment
36
  Upload a NIfTI file (.nii or .nii.gz) for segmentation.
 
37
 
38
  ```bash
39
  curl -X POST -F "file=@brain.nii.gz" https://YOUR-SPACE.hf.space/segment
40
  ```
41
 
42
  ### POST /segment/compact
43
- Same as above but returns base64-gzipped results (more efficient for large volumes).
44
 
45
  ### GET /health
46
  Check API status and GPU availability.
 
32
 
33
  ## API Endpoints
34
 
35
+ ### POST /segment/tensor (Recommended)
36
+ Upload pre-processed tensor data (256³ uint8, gzipped) for GPU inference.
37
+ This is the recommended endpoint as it ensures identical preprocessing to local inference.
38
+
39
+ ```bash
40
+ # Frontend sends the conformed tensor from NiiVue directly
41
+ # See atrophy/main.js for implementation
42
+ ```
43
+
44
  ### POST /segment
45
  Upload a NIfTI file (.nii or .nii.gz) for segmentation.
46
+ Note: Server-side preprocessing may differ slightly from local NiiVue preprocessing.
47
 
48
  ```bash
49
  curl -X POST -F "file=@brain.nii.gz" https://YOUR-SPACE.hf.space/segment
50
  ```
51
 
52
  ### POST /segment/compact
53
+ Same as /segment but returns base64-gzipped results.
54
 
55
  ### GET /health
56
  Check API status and GPU availability.
app.py CHANGED
@@ -67,7 +67,7 @@ def load_model():
67
  return model
68
 
69
  def parse_nifti(file_bytes: bytes, filename: str = "temp.nii"):
70
- """Parse NIfTI file from bytes"""
71
  import tempfile
72
 
73
  # Determine file extension for nibabel
@@ -81,8 +81,16 @@ def parse_nifti(file_bytes: bytes, filename: str = "temp.nii"):
81
 
82
  try:
83
  img = nib.load(tmp_path)
84
- data = img.get_fdata()
85
- header = img.header
 
 
 
 
 
 
 
 
86
  finally:
87
  # Clean up temp file
88
  import os
@@ -166,8 +174,12 @@ def preprocess_volume(data, header):
166
  # Ensure float32
167
  data = data.astype(np.float32)
168
 
169
- # Transpose to match model expectations (the model was trained with this orientation)
 
 
 
170
  data = np.transpose(data, (2, 1, 0))
 
171
 
172
  # Add batch and channel dimensions
173
  data = np.expand_dims(data, axis=0) # batch
@@ -178,10 +190,10 @@ def preprocess_volume(data, header):
178
 
179
  def postprocess_segmentation(segmentation):
180
  """
181
- Transpose segmentation back to standard orientation.
182
  Output is 256^3 (conformed space).
183
  """
184
- # Transpose back
185
  segmentation = np.transpose(segmentation, (2, 1, 0))
186
  return segmentation
187
 
@@ -337,6 +349,94 @@ async def segment_compact(file: UploadFile = File(...)):
337
  traceback.print_exc()
338
  raise HTTPException(500, f"Segmentation failed: {str(e)}")
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  if __name__ == "__main__":
341
  import uvicorn
342
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
67
  return model
68
 
69
  def parse_nifti(file_bytes: bytes, filename: str = "temp.nii"):
70
+ """Parse NIfTI file from bytes and reorient to canonical (RAS+) orientation"""
71
  import tempfile
72
 
73
  # Determine file extension for nibabel
 
81
 
82
  try:
83
  img = nib.load(tmp_path)
84
+
85
+ # Reorient to canonical RAS+ orientation (like NiiVue does)
86
+ # This ensures consistent orientation regardless of how the file was saved
87
+ img_canonical = nib.as_closest_canonical(img)
88
+ data = img_canonical.get_fdata()
89
+ header = img_canonical.header
90
+
91
+ print(f"Original orientation: {nib.aff2axcodes(img.affine)}")
92
+ print(f"Canonical orientation: {nib.aff2axcodes(img_canonical.affine)}")
93
+
94
  finally:
95
  # Clean up temp file
96
  import os
 
174
  # Ensure float32
175
  data = data.astype(np.float32)
176
 
177
+ # The model expects input in a specific orientation
178
+ # After canonical reorientation, data is in RAS+ (Right-Anterior-Superior)
179
+ # The tfjs model was trained with transposed input, so we transpose here
180
+ # This matches the local frontend's behavior
181
  data = np.transpose(data, (2, 1, 0))
182
+ print(f"After transpose shape: {data.shape}")
183
 
184
  # Add batch and channel dimensions
185
  data = np.expand_dims(data, axis=0) # batch
 
190
 
191
  def postprocess_segmentation(segmentation):
192
  """
193
+ Transpose segmentation back to standard RAS+ orientation.
194
  Output is 256^3 (conformed space).
195
  """
196
+ # Transpose back to RAS+ orientation
197
  segmentation = np.transpose(segmentation, (2, 1, 0))
198
  return segmentation
199
 
 
349
  traceback.print_exc()
350
  raise HTTPException(500, f"Segmentation failed: {str(e)}")
351
 
352
+
353
+ @app.post("/segment/tensor")
354
+ async def segment_tensor(file: UploadFile = File(...)):
355
+ """
356
+ Segment using pre-processed tensor from frontend.
357
+
358
+ Accepts gzipped raw tensor data (256x256x256 uint8) that has already been
359
+ conformed by NiiVue. This ensures identical preprocessing to local inference.
360
+
361
+ The frontend sends the conformed volume, server just runs inference.
362
+ """
363
+ import base64
364
+
365
+ try:
366
+ start_time = time.time()
367
+
368
+ # Read gzipped tensor data
369
+ compressed_bytes = await file.read()
370
+ print(f"Received {len(compressed_bytes)} bytes of compressed tensor data")
371
+
372
+ # Decompress
373
+ try:
374
+ raw_bytes = gzip.decompress(compressed_bytes)
375
+ except:
376
+ # Maybe not compressed
377
+ raw_bytes = compressed_bytes
378
+
379
+ expected_size = 256 * 256 * 256
380
+ if len(raw_bytes) != expected_size:
381
+ raise HTTPException(400, f"Expected {expected_size} bytes (256³), got {len(raw_bytes)}")
382
+
383
+ # Convert to numpy array
384
+ data = np.frombuffer(raw_bytes, dtype=np.uint8).reshape((256, 256, 256))
385
+ print(f"Tensor shape: {data.shape}, dtype: {data.dtype}")
386
+
387
+ # Normalize to [0, 1] - same as brainchop's minMaxNormalizeVolumeData
388
+ data = data.astype(np.float32)
389
+ data_min = data.min()
390
+ data_max = data.max()
391
+ if data_max - data_min > 0:
392
+ data = (data - data_min) / (data_max - data_min)
393
+ print(f"Normalized range: [{data.min():.3f}, {data.max():.3f}]")
394
+
395
+ # Transpose - same as brainchop with enableTranspose=true
396
+ data = np.transpose(data, (2, 1, 0))
397
+ print(f"After transpose: {data.shape}")
398
+
399
+ # Add batch and channel dimensions
400
+ data = np.expand_dims(data, axis=0) # batch
401
+ data = np.expand_dims(data, axis=-1) # channel
402
+ print(f"Model input shape: {data.shape}")
403
+
404
+ # Run inference
405
+ inference_start = time.time()
406
+ loaded_model = load_model()
407
+ prediction = loaded_model.predict(data, verbose=0)
408
+ segmentation = np.argmax(prediction, axis=-1)[0]
409
+
410
+ # Transpose back to match frontend expectations
411
+ segmentation = np.transpose(segmentation, (2, 1, 0))
412
+ inference_time = time.time() - inference_start
413
+ print(f"Inference time: {inference_time:.2f}s, output shape: {segmentation.shape}")
414
+
415
+ total_time = time.time() - start_time
416
+
417
+ # Compress and encode result
418
+ seg_bytes = segmentation.astype(np.uint8).tobytes()
419
+ compressed = gzip.compress(seg_bytes)
420
+ encoded = base64.b64encode(compressed).decode('utf-8')
421
+
422
+ return JSONResponse({
423
+ "success": True,
424
+ "shape": list(segmentation.shape),
425
+ "dtype": "uint8",
426
+ "encoding": "base64_gzip",
427
+ "inference_time": round(inference_time, 3),
428
+ "total_time": round(total_time, 3),
429
+ "data": encoded
430
+ })
431
+
432
+ except HTTPException:
433
+ raise
434
+ except Exception as e:
435
+ import traceback
436
+ print(f"ERROR in /segment/tensor: {str(e)}")
437
+ traceback.print_exc()
438
+ raise HTTPException(500, f"Tensor inference failed: {str(e)}")
439
+
440
  if __name__ == "__main__":
441
  import uvicorn
442
  uvicorn.run(app, host="0.0.0.0", port=7860)