# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Shared utilities: validation decorator, batching, JSON I/O, seeding, tensor conversion."""

import inspect
import json
import random
from collections.abc import Mapping
from functools import wraps
from math import prod
from pathlib import Path
from typing import Any, Callable, ParamSpec, TypeVar, Union

import numpy as np
import torch


def validate(validator, save_args: bool = False, super_init: bool = False):
    """Create a decorator that validates user inputs.

    Args:
        validator: the validation callable (e.g. a pydantic dataclass) used to parse keyword arguments
        save_args (bool): save all validated attributes onto the object (args[0])
        super_init (bool): init the parent class with no arguments (useful when saving onto an nn.Module)

    Returns:
        decorator: the decorator function
    """

    def decorator(func):
        @wraps(func)
        def validated_func(*args, **kwargs):
            conf = validator(**kwargs)
            if save_args:
                assert len(args) != 0
                obj = args[0]
                if super_init:
                    # init the parent module
                    super(type(obj), obj).__init__()
                for key, val in conf.__dict__.items():
                    setattr(obj, key, val)
            return func(*args, conf)

        return validated_func

    return decorator
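

# Usage sketch for `validate` (illustrative comment only, not executed). Assumes pydantic
# is installed; `ExampleConfig` and `ExampleModule` are hypothetical names.
#
#     from pydantic.dataclasses import dataclass
#
#     @dataclass
#     class ExampleConfig:
#         hidden_dim: int = 128
#         dropout: float = 0.1
#
#     class ExampleModule(torch.nn.Module):
#         @validate(ExampleConfig, save_args=True, super_init=True)
#         def __init__(self, conf):
#             # `save_args` already copied hidden_dim/dropout onto `self`,
#             # and `super_init` ran nn.Module.__init__ beforehand.
#             self.proj = torch.nn.Linear(conf.hidden_dim, conf.hidden_dim)
#
#     module = ExampleModule(hidden_dim=64, dropout=0.2)  # kwargs are validated into ExampleConfig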


# Type alias for clarity
Tensor = Any

P = ParamSpec("P")
R = TypeVar("R")


def ensure_batched(**spec: int) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator that flattens arbitrary leading batch dimensions before calling `fn`.

    Each keyword in `spec` names an argument and gives the ndim that `fn` expects for it,
    including a single batch dimension. Extra leading batch dims are flattened on the way in
    and restored on the output. Arguments with 0 tail dims (e.g. 1D length tensors) are handled
    without slicing errors, and the reshape is skipped when the input is already flat.
    """
    if not spec:
        raise ValueError("At least one argument spec must be provided.")

    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(fn)

        @wraps(fn)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            def _sequence_shape(name: str, value: Any) -> tuple[int, ...]:
                if not isinstance(value, (list, tuple)):
                    return ()
                if len(value) == 0:
                    return (0,)
                first_shape = _sequence_shape(name, value[0])
                for item in value[1:]:
                    item_shape = _sequence_shape(name, item)
                    if item_shape != first_shape:
                        raise ValueError(f"'{name}' must be a rectangular nested sequence, got ragged shape.")
                return (len(value), *first_shape)

            def _shape_and_ndim(name: str, value: Any) -> tuple[tuple[int, ...], int]:
                if hasattr(value, "shape") and hasattr(value, "ndim"):
                    shape = tuple(value.shape)
                    return shape, int(value.ndim)
                if isinstance(value, (list, tuple)):
                    shape = _sequence_shape(name, value)
                    return shape, len(shape)
                raise TypeError(f"'{name}' must be tensor-like or a nested list/tuple, got {type(value)}.")

            def _reshape_like(value: Any, shape: tuple[int, ...], name: str) -> Any:
                if hasattr(value, "reshape"):
                    return value.reshape(*shape)
                if not isinstance(value, (list, tuple)):
                    raise TypeError(f"Cannot reshape '{name}' of type {type(value)}.")
                flat: list[Any] = []

                def _flatten(x: Any) -> None:
                    if isinstance(x, (list, tuple)):
                        for item in x:
                            _flatten(item)
                    else:
                        flat.append(x)

                _flatten(value)
                expected_size = prod(shape) if shape else 1
                if len(flat) != expected_size:
                    raise ValueError(f"Cannot reshape '{name}' with {len(flat)} elements into shape {shape}.")

                def _build(index: int, dims: tuple[int, ...]) -> tuple[Any, int]:
                    if not dims:
                        return flat[index], index + 1
                    items = []
                    for _ in range(dims[0]):
                        item, index = _build(index, dims[1:])
                        items.append(item)
                    return items, index

                rebuilt, used = _build(0, shape)
                if used != len(flat):
                    raise ValueError(f"Internal reshape error for '{name}': used {used}/{len(flat)} elements.")
                if isinstance(value, tuple) and isinstance(rebuilt, list):
                    return tuple(rebuilt)
                return rebuilt

            # --- 1. CANONICAL ARGUMENT ---
            spec_items = list(spec.items())
            canonical_name = None
            canonical_ndim = None
            x0 = None
            for name, ndim in spec_items:
                candidate = bound.arguments.get(name, None)
                if candidate is not None:
                    canonical_name = name
                    canonical_ndim = ndim
                    x0 = candidate
                    break
            if canonical_name is None:
                raise ValueError(
                    "All canonical candidates are None: " + ", ".join(f"'{name}'" for name, _ in spec_items)
                )

            # Split between batch dims and feature (tail) dims.
            expected_tail_dims = canonical_ndim - 1  # e.g. 3 - 1 = 2 (Sequence, Feat)
            x0_shape, x0_ndim = _shape_and_ndim(canonical_name, x0)

            # Validation
            if x0_ndim < expected_tail_dims:
                raise ValueError(f"'{canonical_name}' ndim={x0_ndim} < expected {expected_tail_dims} tail dims.")

            # Handle 0 tail dims without negative slicing.
            if expected_tail_dims == 0:
                orig_batch_shape = x0_shape
                tail_shape = ()
            else:
                orig_batch_shape = x0_shape[:-expected_tail_dims]
                tail_shape = x0_shape[-expected_tail_dims:]

            # Flattened batch size; if orig_batch_shape is () (scalar input), size is 1.
            B_flat = prod(orig_batch_shape) if orig_batch_shape else 1

            # Determine if we added a fake batch dim (unbatched input).
            is_unbatched_input = len(orig_batch_shape) == 0

            # Skip the reshape if the batch shape is already 1D,
            # e.g. there is no need to reshape (2, 140, 5) -> (2, 140, 5).
            is_already_flat = len(orig_batch_shape) == 1

            if is_unbatched_input:
                # (H, W) -> (1, H, W)
                x0_batched = _reshape_like(x0, (1, *tail_shape), canonical_name)
            elif is_already_flat:
                # (B, H, W) -> keep as is
                x0_batched = x0
            else:
                # (B1, B2, H, W) -> (B1*B2, H, W)
                x0_batched = _reshape_like(x0, (B_flat, *tail_shape), canonical_name)
            bound.arguments[canonical_name] = x0_batched

            # --- 2. OTHER ARGUMENTS ---
            for name, target_ndim in spec_items:
                if name == canonical_name:
                    continue
                val = bound.arguments.get(name, None)
                if val is None:
                    continue

                arg_tail_dims = target_ndim - 1  # e.g. for lengths=1, tail=0
                val_shape, val_ndim = _shape_and_ndim(name, val)

                # Validate
                if val_ndim < arg_tail_dims:
                    raise ValueError(f"'{name}' ndim={val_ndim} too small.")

                # Get batch shape (handling 0 tail dims).
                if arg_tail_dims == 0:
                    val_batch_shape = val_shape
                    val_tail_shape = ()
                else:
                    val_batch_shape = val_shape[:-arg_tail_dims]
                    val_tail_shape = val_shape[-arg_tail_dims:]

                # Check mismatch: unbatched inputs must match an unbatched canonical.
                if len(val_batch_shape) == 0:
                    if not is_unbatched_input:
                        raise ValueError(f"'{name}' is unbatched but canonical is batched.")
                    val_batched = _reshape_like(val, (1, *val_tail_shape), name)
                else:
                    # Batched inputs must match the canonical batch shape exactly.
                    if val_batch_shape != orig_batch_shape:
                        raise ValueError(
                            f"Batch dimensions mismatch! '{canonical_name}' has {orig_batch_shape}, "
                            f"but '{name}' has {val_batch_shape}."
                        )
                    # Optimization: don't reshape if already flat.
                    if is_already_flat:
                        val_batched = val
                    else:
                        val_batched = _reshape_like(val, (B_flat, *val_tail_shape), name)
                bound.arguments[name] = val_batched

            # --- 3. EXECUTION ---
            out = fn(**bound.arguments)

            # --- 4. RESTORE ---
            def restore(obj):
                if isinstance(obj, Mapping):
                    return {k: restore(v) for k, v in obj.items()}
                if isinstance(obj, (list, tuple)):
                    return type(obj)(restore(x) for x in obj)
                if hasattr(obj, "shape"):
                    if obj.ndim == 0:
                        return obj
                    # Verify the batch dimension exists and wasn't reduced.
                    if obj.shape[0] != B_flat:
                        return obj
                    # If input was simple (B, ...), return simple (B, ...).
                    if is_already_flat:
                        return obj
                    rest = obj.shape[1:]
                    if is_unbatched_input:
                        assert obj.shape[0] == 1, "The batch size should be 1 for unbatched input."
                        return obj[0]
                    return obj.reshape(*orig_batch_shape, *rest)
                return obj

            return restore(out)

        return wrapper

    return decorator
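

# Usage sketch for `ensure_batched` (illustrative comment only, not executed).
# `encode` is a hypothetical function that expects flat (B, T, D) inputs; the decorator
# lets callers pass extra leading batch dims and restores them on the output.
#
#     @ensure_batched(x=3, lengths=1)
#     def encode(x, lengths=None):
#         # here x is (B, T, D) and lengths is (B,), regardless of how the caller batched them
#         return x.mean(dim=1)
#
#     x = torch.randn(2, 4, 10, 8)        # (B1, B2, T, D)
#     lengths = torch.full((2, 4), 10)    # (B1, B2)
#     out = encode(x, lengths)            # internally (8, 10, 8) -> result restored to (2, 4, 8)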


def to_numpy(obj):
    """Recursively convert tensors in dicts/lists/tuples to numpy arrays; leave other types unchanged."""
    if isinstance(obj, Mapping):
        return {k: to_numpy(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_numpy(x) for x in obj)
    if isinstance(obj, torch.Tensor):
        # detach() so tensors that require grad can be converted
        return obj.detach().cpu().numpy()
    return obj


def to_torch(obj, device=None, dtype=None):
    """Recursively convert numpy arrays in dicts/lists/tuples to torch tensors; optionally move to
    device/dtype."""
    if isinstance(obj, Mapping):
        return {k: to_torch(v, device, dtype) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_torch(x, device, dtype) for x in obj)
    if isinstance(obj, np.ndarray):
        obj = torch.from_numpy(obj)
    if isinstance(obj, torch.Tensor):
        if dtype is not None:
            obj = obj.to(dtype=dtype)
        if device is None:
            return obj
        return obj.to(device)
    return obj
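

# Usage sketch for `to_numpy` / `to_torch` (illustrative comment only, not executed).
#
#     batch = {"feats": torch.randn(2, 3), "ids": ["a", "b"]}
#     np_batch = to_numpy(batch)                               # tensors -> np.ndarray, strings untouched
#     back = to_torch(np_batch, device="cpu", dtype=torch.float32)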


def seed_everything(seed: int, deterministic: bool = False) -> None:
    """Seed all random number generators."""
    random.seed(seed)  # Python's random module
    np.random.seed(seed)  # NumPy
    torch.manual_seed(seed)  # torch CPU RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA devices
    if deterministic:
        # Restrict cuDNN to deterministic kernels (may be slower).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
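

# Usage sketch for `seed_everything` (illustrative comment only, not executed): seed once
# before building models/dataloaders; `deterministic=True` also pins cuDNN kernel selection.
#
#     seed_everything(1234, deterministic=True)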


def load_json(path: Union[str, Path]) -> Any:
    """Load a JSON file and return its contents.

    Args:
        path (str | Path): Path to the JSON file.

    Returns:
        Any: Parsed JSON content (dict, list, etc.).

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is not valid JSON.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"JSON file not found: {path}")
    try:
        with path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in file {path}: {e}") from e


def save_json(path: Union[str, Path], data: Any) -> None:
    """Save data to a JSON file.

    Args:
        path (str | Path): Path to the JSON file.
        data (Any): Data to save (must be JSON serializable).

    Raises:
        ValueError: If the data is not JSON serializable.
    """
    path = Path(path)
    # Create parent directories if they don't exist.
    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        with path.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Data is not JSON serializable: {e}") from e