Sricharan Reddy Varra Claude Opus 4.6 (1M context) commited on
Commit
142ef3a
·
1 Parent(s): b3254d4

Replace xarray-ome with iohub native xarray support

Browse files

iohub@main has Position.to_xarray() which returns a dask-backed
DataArray with proper coordinates and scales. This eliminates the
xarray-ome dependency and the brittle zarr store path extraction.

Dims are now lowercase (t,c,z,y,x) per iohub convention.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. demo_utils.py +14 -35
  2. optimize_demo.py +6 -6
  3. requirements.txt +1 -2
demo_utils.py CHANGED
@@ -21,7 +21,7 @@ import numpy as np
21
  import torch
22
  import xarray as xr
23
  from numpy.typing import NDArray
24
- from xarray_ome import open_ome_dataset
25
 
26
  from waveorder import util
27
  from waveorder.models import isotropic_thin_3d
@@ -129,9 +129,7 @@ def load_fov_from_plate(
129
  plate, row: str, column: str, field: str, resolution: int = 0
130
  ) -> xr.DataArray:
131
  """
132
- Load a specific FOV from HCS plate using hybrid iohub + xarray-ome approach.
133
-
134
- Uses iohub for navigation, then xarray-ome for fast data loading.
135
 
136
  Parameters
137
  ----------
@@ -149,28 +147,11 @@ def load_fov_from_plate(
149
  Returns
150
  -------
151
  xr.DataArray
152
- Image data with labeled dimensions (T, C, Z, Y, X)
153
  """
154
- # Navigate to position using iohub (fast)
155
  position_key = f"{row}/{column}/{field}"
156
  position = plate[position_key]
157
-
158
- # Get full zarr path from position (handle both Zarr V2 and V3)
159
- store = position.zgroup.store
160
- if hasattr(store, 'path'):
161
- base_path = Path(store.path) # Zarr V2 (DirectoryStore)
162
- elif hasattr(store, 'root'):
163
- base_path = Path(store.root) # Zarr V3 (LocalStore)
164
- else:
165
- raise RuntimeError(f"Unknown store type: {type(store)}")
166
-
167
- position_path = base_path / position.zgroup.path
168
-
169
- # Load with xarray-ome (fast and reliable)
170
- fov_dataset = open_ome_dataset(position_path, resolution=resolution, validate=False)
171
- data_xr = fov_dataset["image"]
172
-
173
- return data_xr
174
 
175
 
176
  # === Data Loading ===
@@ -192,7 +173,7 @@ def load_ome_zarr_fov(
192
  Returns
193
  -------
194
  xr.DataArray
195
- Image data with labeled dimensions (T, C, Z, Y, X)
196
  """
197
  zarr_path = Path(zarr_path)
198
  fov_path = Path(fov_path)
@@ -200,13 +181,8 @@ def load_ome_zarr_fov(
200
  print(f"Loading zarr store from: {zarr_path}")
201
  print(f"Accessing FOV: {fov_path}")
202
 
203
- # Load as xarray Dataset
204
- fov_dataset: xr.Dataset = open_ome_dataset(
205
- zarr_path / fov_path, resolution=resolution, validate=False
206
- )
207
-
208
- # Extract the image DataArray
209
- data_xr = fov_dataset["image"]
210
 
211
  print(f"Loaded data shape: {dict(data_xr.sizes)}")
212
  print(f"Dimensions: {list(data_xr.dims)}")
@@ -316,11 +292,11 @@ def extract_2d_slice(
316
  # Build selection dictionary for indexed dimensions
317
  sel_dict = {}
318
  if t is not None:
319
- sel_dict["T"] = int(t)
320
  if c is not None:
321
- sel_dict["C"] = int(c)
322
  if z is not None:
323
- sel_dict["Z"] = int(z)
324
 
325
  # Extract slice using xarray's labeled indexing
326
  slice_xr = data_xr.isel(**sel_dict) if sel_dict else data_xr
@@ -455,7 +431,10 @@ def print_data_summary(data_xr: xr.DataArray) -> None:
455
  for dim in info["dims"]:
456
  coords = info["coords"][dim]
457
  if len(coords) > 0:
458
- print(f" {dim}: [{coords[0]:.2f} ... {coords[-1]:.2f}] (n={len(coords)})")
 
 
 
459
 
460
  # Print memory size estimate
461
  total_elements = np.prod(list(info["sizes"].values()))
 
21
  import torch
22
  import xarray as xr
23
  from numpy.typing import NDArray
24
+ from iohub import open_ome_zarr
25
 
26
  from waveorder import util
27
  from waveorder.models import isotropic_thin_3d
 
129
  plate, row: str, column: str, field: str, resolution: int = 0
130
  ) -> xr.DataArray:
131
  """
132
+ Load a specific FOV from HCS plate using iohub.
 
 
133
 
134
  Parameters
135
  ----------
 
147
  Returns
148
  -------
149
  xr.DataArray
150
+ Image data with labeled dimensions (t, c, z, y, x)
151
  """
 
152
  position_key = f"{row}/{column}/{field}"
153
  position = plate[position_key]
154
+ return position.to_xarray()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
 
157
  # === Data Loading ===
 
173
  Returns
174
  -------
175
  xr.DataArray
176
+ Image data with labeled dimensions (t, c, z, y, x)
177
  """
178
  zarr_path = Path(zarr_path)
179
  fov_path = Path(fov_path)
 
181
  print(f"Loading zarr store from: {zarr_path}")
182
  print(f"Accessing FOV: {fov_path}")
183
 
184
+ position = open_ome_zarr(str(zarr_path / fov_path), mode="r")
185
+ data_xr = position.to_xarray()
 
 
 
 
 
186
 
187
  print(f"Loaded data shape: {dict(data_xr.sizes)}")
188
  print(f"Dimensions: {list(data_xr.dims)}")
 
292
  # Build selection dictionary for indexed dimensions
293
  sel_dict = {}
294
  if t is not None:
295
+ sel_dict["t"] = int(t)
296
  if c is not None:
297
+ sel_dict["c"] = int(c)
298
  if z is not None:
299
+ sel_dict["z"] = int(z)
300
 
301
  # Extract slice using xarray's labeled indexing
302
  slice_xr = data_xr.isel(**sel_dict) if sel_dict else data_xr
 
431
  for dim in info["dims"]:
432
  coords = info["coords"][dim]
433
  if len(coords) > 0:
434
+ try:
435
+ print(f" {dim}: [{coords[0]:.2f} ... {coords[-1]:.2f}] (n={len(coords)})")
436
+ except (ValueError, TypeError):
437
+ print(f" {dim}: [{coords[0]} ... {coords[-1]}] (n={len(coords)})")
438
 
439
  # Print memory size estimate
440
  total_elements = np.prod(list(info["sizes"].values()))
optimize_demo.py CHANGED
@@ -159,7 +159,7 @@ def load_selected_fov(field: str, current_z: int, plate_metadata):
159
  new_pixel_scales = (Config.PIXEL_SIZE_Z, Config.PIXEL_SIZE_YX, Config.PIXEL_SIZE_YX)
160
 
161
  # Update Z slider
162
- z_max = new_data_xr.sizes["Z"] - 1
163
  new_z = min(current_z, z_max)
164
 
165
  print(f"✅ Loaded: {dict(new_data_xr.sizes)}")
@@ -232,7 +232,7 @@ def run_reconstruction_ui(
232
  Uses slider parameters directly for a single fast reconstruction.
233
  """
234
  # Extract full Z-stack for timepoint 0 (for reconstruction)
235
- zyx_stack = data_xr_state.isel(T=0, C=Config.CHANNEL).values
236
 
237
  # Get current Z-slice for comparison (left side of ImageSlider)
238
  original_normalized = extract_2d_slice(
@@ -276,7 +276,7 @@ def run_optimization_ui(
276
  iteration history, iteration slider, and SLIDER UPDATES.
277
  """
278
  # Extract full Z-stack for timepoint 0 (for reconstruction)
279
- zyx_stack = data_xr_state.isel(T=0, C=Config.CHANNEL).values
280
 
281
  # Get current Z-slice for comparison (left side of ImageSlider)
282
  original_normalized = extract_2d_slice(
@@ -529,7 +529,7 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
529
  data_xr,
530
  t=0,
531
  c=Config.CHANNEL,
532
- z=data_xr.sizes["Z"] // 2,
533
  normalize=True,
534
  verbose=False,
535
  )
@@ -547,8 +547,8 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
547
  gr.Markdown("### 🎛️ Navigation")
548
  z_slider = gr.Slider(
549
  minimum=0,
550
- maximum=data_xr.sizes["Z"] - 1,
551
- value=data_xr.sizes["Z"] // 2,
552
  step=1,
553
  label="Z-slice",
554
  scale=1,
 
159
  new_pixel_scales = (Config.PIXEL_SIZE_Z, Config.PIXEL_SIZE_YX, Config.PIXEL_SIZE_YX)
160
 
161
  # Update Z slider
162
+ z_max = new_data_xr.sizes["z"] - 1
163
  new_z = min(current_z, z_max)
164
 
165
  print(f"✅ Loaded: {dict(new_data_xr.sizes)}")
 
232
  Uses slider parameters directly for a single fast reconstruction.
233
  """
234
  # Extract full Z-stack for timepoint 0 (for reconstruction)
235
+ zyx_stack = data_xr_state.isel(t=0, c=Config.CHANNEL).values
236
 
237
  # Get current Z-slice for comparison (left side of ImageSlider)
238
  original_normalized = extract_2d_slice(
 
276
  iteration history, iteration slider, and SLIDER UPDATES.
277
  """
278
  # Extract full Z-stack for timepoint 0 (for reconstruction)
279
+ zyx_stack = data_xr_state.isel(t=0, c=Config.CHANNEL).values
280
 
281
  # Get current Z-slice for comparison (left side of ImageSlider)
282
  original_normalized = extract_2d_slice(
 
529
  data_xr,
530
  t=0,
531
  c=Config.CHANNEL,
532
+ z=data_xr.sizes["z"] // 2,
533
  normalize=True,
534
  verbose=False,
535
  )
 
547
  gr.Markdown("### 🎛️ Navigation")
548
  z_slider = gr.Slider(
549
  minimum=0,
550
+ maximum=data_xr.sizes["z"] - 1,
551
+ value=data_xr.sizes["z"] // 2,
552
  step=1,
553
  label="Z-slice",
554
  scale=1,
requirements.txt CHANGED
@@ -3,8 +3,7 @@
3
 
4
  # Install waveorder from gradio-demo branch
5
  git+https://github.com/mehta-lab/waveorder.git@gradio-demo
6
- git+https://github.com/ianhi/xarray-ome.git@main
7
- git+https://github.com/czbiohub-sf/iohub.git@v0.3.0a2
8
 
9
  # Gradio for web interface
10
  gradio>=6.0.2
 
3
 
4
  # Install waveorder from gradio-demo branch
5
  git+https://github.com/mehta-lab/waveorder.git@gradio-demo
6
+ git+https://github.com/czbiohub-sf/iohub.git@main
 
7
 
8
  # Gradio for web interface
9
  gradio>=6.0.2