Buckets:

rydlrKE's picture
download
raw
13.4 kB
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared utilities: validation decorator, batching, JSON I/O, seeding, tensor conversion."""
import inspect
import json
import math
import random
from collections.abc import Mapping, Sequence
from functools import wraps
from math import prod
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, ParamSpec, TypeVar, Union
import numpy as np
import torch
def validate(validator, save_args: bool = False, super_init: bool = False):
    """Create a decorator that validates a function's keyword arguments.

    The wrapped function is called as ``func(*args, conf)``, where
    ``conf = validator(**kwargs)``. Positional arguments are forwarded
    unchanged, and all keyword arguments are replaced by the single validated
    config object.

    Args:
        validator: Callable that builds the validated config from the keyword
            arguments (e.g. a pydantic dataclass).
        save_args (bool): If True, copy every attribute of the validated config
            onto ``args[0]`` (typically ``self``) before calling ``func``.
        super_init (bool): Only used when ``save_args`` is True. First call the
            parent-class ``__init__`` with no arguments (useful when ``args[0]``
            is an ``nn.Module`` that must be initialised before attributes are
            set on it).

    Returns:
        decorator: the decorator function.

    Raises:
        TypeError: If ``save_args`` is True but the wrapped function is called
            without any positional argument to save the attributes on.
    """

    def decorator(func):
        @wraps(func)
        def validated_func(*args, **kwargs):
            conf = validator(**kwargs)
            if save_args:
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if not args:
                    raise TypeError(
                        f"{func.__qualname__} requires at least one positional argument "
                        "(the object to save attributes on) when save_args=True."
                    )
                obj = args[0]
                if super_init:
                    # Find the class whose namespace holds this decorated function.
                    # ``super(type(obj), obj)`` is wrong for subclasses: it resolves
                    # back to this very __init__ and recurses forever.
                    owner = next(
                        (
                            klass
                            for klass in type(obj).__mro__
                            if vars(klass).get(func.__name__) is validated_func
                        ),
                        type(obj),
                    )
                    super(owner, obj).__init__()
                for key, val in vars(conf).items():
                    setattr(obj, key, val)
            return func(*args, conf)

        return validated_func

    return decorator
# Type alias for clarity
Tensor = Any
P = ParamSpec("P")
R = TypeVar("R")
def _nested_shape(name: str, value: Any) -> tuple[int, ...]:
    """Return the shape of a rectangular nested list/tuple; ``()`` for a non-sequence leaf.

    Raises:
        ValueError: If sibling sub-sequences have different shapes (ragged input).
    """
    if not isinstance(value, (list, tuple)):
        return ()
    if not value:
        return (0,)
    first_shape = _nested_shape(name, value[0])
    for item in value[1:]:
        if _nested_shape(name, item) != first_shape:
            raise ValueError(f"'{name}' must be a rectangular nested sequence, got ragged shape.")
    return (len(value), *first_shape)


def _shape_and_ndim(name: str, value: Any) -> tuple[tuple[int, ...], int]:
    """Return ``(shape, ndim)`` of a tensor/array-like object or a nested list/tuple.

    Raises:
        TypeError: If ``value`` is neither tensor-like (has ``shape`` and ``ndim``)
            nor a list/tuple.
    """
    if hasattr(value, "shape") and hasattr(value, "ndim"):
        return tuple(value.shape), int(value.ndim)
    if isinstance(value, (list, tuple)):
        shape = _nested_shape(name, value)
        return shape, len(shape)
    raise TypeError(f"'{name}' must be tensor-like or a nested list/tuple, got {type(value)}.")


def _reshape_like(value: Any, shape: tuple[int, ...], name: str) -> Any:
    """Reshape ``value`` to ``shape``.

    Objects with a ``reshape`` method (tensors, arrays) use it directly. Nested
    lists/tuples are flattened and rebuilt as nested lists; a top-level tuple
    input is returned as a tuple.

    Raises:
        TypeError: If ``value`` has no ``reshape`` and is not a list/tuple.
        ValueError: If the number of leaf elements does not match ``shape``.
    """
    if hasattr(value, "reshape"):
        return value.reshape(*shape)
    if not isinstance(value, (list, tuple)):
        raise TypeError(f"Cannot reshape '{name}' of type {type(value)}.")

    flat: list[Any] = []

    def _flatten(node: Any) -> None:
        if isinstance(node, (list, tuple)):
            for child in node:
                _flatten(child)
        else:
            flat.append(node)

    _flatten(value)
    if len(flat) != prod(shape):  # prod(()) == 1
        raise ValueError(f"Cannot reshape '{name}' with {len(flat)} elements into shape {shape}.")

    leaves = iter(flat)

    def _build(dims: tuple[int, ...]) -> Any:
        # Consume leaves in row-major order; the count check above guarantees an exact fit.
        if not dims:
            return next(leaves)
        return [_build(dims[1:]) for _ in range(dims[0])]

    rebuilt = _build(shape)
    if isinstance(value, tuple) and isinstance(rebuilt, list):
        return tuple(rebuilt)
    return rebuilt


def _split_batch_tail(
    shape: tuple[int, ...], tail_dims: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Split ``shape`` into ``(batch_shape, tail_shape)``, keeping ``tail_dims`` trailing dims.

    Callers guarantee ``0 <= tail_dims <= len(shape)``.
    """
    cut = len(shape) - tail_dims
    return shape[:cut], shape[cut:]


def ensure_batched(**spec: int) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator that flattens arbitrary leading batch dimensions before calling ``fn``.

    Each keyword in ``spec`` names a parameter of the decorated function and
    gives the number of dimensions that argument has *inside* ``fn``: one
    flattened batch dimension followed by ``ndim - 1`` trailing ("tail") dims.

    The first spec entry (in keyword order) whose argument is not ``None`` is
    the canonical argument. Its leading dims define the batch shape:

    - ``(B1, ..., Bk, *tail)`` is reshaped to ``(B1*...*Bk, *tail)``;
    - ``(*tail)`` (no batch dims) is reshaped to ``(1, *tail)``;
    - ``(B, *tail)`` is passed through unchanged.

    Every other non-``None`` spec argument must have the same batch shape (or
    be unbatched when the canonical one is) and is reshaped the same way.
    Nested lists/tuples are accepted and rebuilt as nested lists.

    After the call, the result is walked through dicts (returned as plain
    ``dict``), lists, tuples and namedtuples. Every tensor-like leaf whose first
    dim equals the flattened batch size is reshaped back to the original batch
    shape, or has its dummy batch dim dropped for unbatched input. Everything
    else is returned unchanged.

    Example::

        @ensure_batched(x=3)            # fn sees x as (B, T, F)
        def f(x): ...
        f(torch.zeros(4, 2, 10, 5))     # x reaches f as (8, 10, 5)

    Args:
        **spec: Parameter name -> expected ndim (>= 1, batch dim included) inside ``fn``.

    Raises:
        ValueError: At decoration time if ``spec`` is empty or any ndim is < 1.
            At call time if every spec argument is ``None``, an argument has too
            few dims, or batch shapes disagree.
        TypeError: If a spec argument is neither tensor-like nor a nested list/tuple.
    """
    if not spec:
        raise ValueError("At least one argument spec must be provided.")
    invalid = {name: ndim for name, ndim in spec.items() if ndim < 1}
    if invalid:
        # ndim counts the flattened batch dim itself, so values below 1 are meaningless.
        raise ValueError(f"Every spec ndim must be >= 1 (batch dim included), got {invalid}.")
    spec_items = list(spec.items())

    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(fn)

        @wraps(fn)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = bound.arguments

            # --- 1. Canonical argument: first spec entry that is not None. ---
            for canonical_name, canonical_ndim in spec_items:
                x0 = arguments.get(canonical_name)
                if x0 is not None:
                    break
            else:
                raise ValueError(
                    "All canonical candidates are None: "
                    + ", ".join(f"'{name}'" for name, _ in spec_items)
                )

            expected_tail_dims = canonical_ndim - 1
            x0_shape, x0_ndim = _shape_and_ndim(canonical_name, x0)
            if x0_ndim < expected_tail_dims:
                raise ValueError(
                    f"'{canonical_name}' ndim={x0_ndim} < expected {expected_tail_dims} tail dims."
                )
            orig_batch_shape, tail_shape = _split_batch_tail(x0_shape, expected_tail_dims)
            b_flat = prod(orig_batch_shape)  # prod(()) == 1 for an unbatched input
            is_unbatched_input = not orig_batch_shape
            # A single batch dim is already the layout fn expects, so skip reshaping.
            is_already_flat = len(orig_batch_shape) == 1

            if is_unbatched_input:
                arguments[canonical_name] = _reshape_like(x0, (1, *tail_shape), canonical_name)
            elif not is_already_flat:
                arguments[canonical_name] = _reshape_like(
                    x0, (b_flat, *tail_shape), canonical_name
                )

            # --- 2. Other spec arguments must agree with the canonical batch shape. ---
            for name, target_ndim in spec_items:
                val = arguments.get(name)
                if name == canonical_name or val is None:
                    continue
                arg_tail_dims = target_ndim - 1
                val_shape, val_ndim = _shape_and_ndim(name, val)
                if val_ndim < arg_tail_dims:
                    raise ValueError(f"'{name}' ndim={val_ndim} too small.")
                val_batch_shape, val_tail_shape = _split_batch_tail(val_shape, arg_tail_dims)
                if not val_batch_shape:
                    if not is_unbatched_input:
                        raise ValueError(f"'{name}' is unbatched but canonical is batched.")
                    arguments[name] = _reshape_like(val, (1, *val_tail_shape), name)
                elif val_batch_shape != orig_batch_shape:
                    raise ValueError(
                        f"Batch dimensions mismatch! '{canonical_name}' has {orig_batch_shape}, "
                        f"but '{name}' has {val_batch_shape}."
                    )
                elif not is_already_flat:
                    arguments[name] = _reshape_like(val, (b_flat, *val_tail_shape), name)

            # --- 3. Execution ---
            # Re-expand through args/kwargs (not **arguments) so positional-only,
            # *args and **kwargs parameters are passed the way fn declares them.
            out = fn(*bound.args, **bound.kwargs)

            # --- 4. Restore the original batch shape on every output leaf. ---
            def restore(obj: Any) -> Any:
                if isinstance(obj, Mapping):
                    return {k: restore(v) for k, v in obj.items()}
                if isinstance(obj, tuple) and hasattr(obj, "_make"):
                    # namedtuple: its constructor takes fields positionally, not an iterable.
                    return obj._make(restore(x) for x in obj)
                if isinstance(obj, (list, tuple)):
                    return type(obj)(restore(x) for x in obj)
                if not hasattr(obj, "shape"):
                    return obj
                shape = tuple(obj.shape)
                # Leave scalars, reduced outputs (first dim != batch) and
                # already-flat inputs untouched.
                if not shape or shape[0] != b_flat or is_already_flat:
                    return obj
                if is_unbatched_input:
                    return obj[0]
                return obj.reshape(*orig_batch_shape, *shape[1:])

            return restore(out)

        return wrapper

    return decorator
def to_numpy(obj):
    """Recursively convert torch tensors inside dicts/lists/tuples to NumPy arrays.

    Tensors are detached from the autograd graph and moved to host memory
    first, so tensors with ``requires_grad=True`` convert instead of raising.
    Mappings come back as plain ``dict``; lists, tuples and namedtuples keep
    their type. Any other value is returned unchanged.
    """
    if isinstance(obj, Mapping):
        return {key: to_numpy(val) for key, val in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_make"):
        # namedtuple: its constructor takes fields positionally, not an iterable.
        return obj._make(to_numpy(item) for item in obj)
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_numpy(item) for item in obj)
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().numpy()
    return obj
def to_torch(obj, device=None, dtype=None):
    """Recursively convert NumPy arrays inside dicts/lists/tuples to torch tensors.

    Arrays are wrapped with ``torch.from_numpy`` (sharing memory when possible).
    Arrays with negative strides (e.g. ``arr[::-1]``), which ``from_numpy``
    rejects, are copied first. Tensors (converted or already present) are then
    cast/moved in a single ``.to`` call when ``dtype`` and/or ``device`` is given.
    Mappings come back as plain ``dict``; lists, tuples and namedtuples keep
    their type. Any other value is returned unchanged.

    Args:
        obj: Value or nested container to convert.
        device: Optional target device for every tensor.
        dtype: Optional target dtype for every tensor.
    """
    if isinstance(obj, Mapping):
        return {key: to_torch(val, device, dtype) for key, val in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_make"):
        # namedtuple: its constructor takes fields positionally, not an iterable.
        return obj._make(to_torch(item, device, dtype) for item in obj)
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_torch(item, device, dtype) for item in obj)
    if isinstance(obj, np.ndarray):
        if any(stride < 0 for stride in obj.strides):
            obj = obj.copy()  # C-ordered copy with non-negative strides
        obj = torch.from_numpy(obj)
    if isinstance(obj, torch.Tensor) and (device is not None or dtype is not None):
        obj = obj.to(device=device, dtype=dtype)
    return obj
def seed_everything(seed: int, deterministic: bool = False) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices).

    Args:
        seed: Seed value applied to every generator.
        deterministic: If True, also switch cuDNN to deterministic mode and turn
            off its benchmark autotuner. If False, the cuDNN flags are left as-is.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_json(path: Union[str, Path]) -> Any:
    """Read a UTF-8 JSON file and return the decoded value.

    Args:
        path (str | Path): Location of the JSON file.
    Returns:
        Any: The decoded JSON value (dict, list, str, number, ...).
    Raises:
        FileNotFoundError: If nothing exists at ``path``.
        ValueError: If the file content is not valid JSON.
    """
    json_path = Path(path)
    if not json_path.exists():
        raise FileNotFoundError(f"JSON file not found: {json_path}")
    with json_path.open("r", encoding="utf-8") as handle:
        raw_text = handle.read()
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON in file {json_path}: {err}") from err
def save_json(path: Union[str, Path], data: Any) -> None:
    """Save data to a JSON file (UTF-8, 2-space indent, non-ASCII kept verbatim).

    Missing parent directories are created. ``data`` is fully serialized and
    encoded before the file is opened, so a serialization error leaves any
    existing file at ``path`` untouched.

    Args:
        path (str | Path): Path to the JSON file.
        data (Any): Data to save (must be JSON serializable).
    Raises:
        ValueError: If the data is not JSON serializable (unsupported type,
            circular reference, or text that cannot be encoded as UTF-8).
    """
    path = Path(path)
    try:
        text = json.dumps(data, indent=2, ensure_ascii=False)
        # ensure_ascii=False can keep lone surrogates that only fail when encoded;
        # check here so such a failure cannot truncate an existing file.
        text.encode("utf-8")
    except (TypeError, ValueError) as e:
        raise ValueError(f"Data is not JSON serializable: {e}") from e
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")

Xet Storage Details

Size:
13.4 kB
·
Xet hash:
70a0a4e998f95046b73d3d685804cff6a0c4272ebcec8e664319a1ccc46bc2a4

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.