from __future__ import annotations

import collections
import collections.abc
import functools
import json
import random
import time
from contextlib import ContextDecorator
from typing import Any, Callable, TypeVar

import numpy as np
import termcolor
import torch

from cosmos1.utils import distributed, log


def to(
    data: Any,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
    memory_format: torch.memory_format = torch.preserve_format,
) -> Any:
    """Recursively cast data into the specified device, dtype, and/or memory_format.

    The input data can be a tensor, a list of tensors, or a dict of tensors.
    See the documentation for torch.Tensor.to() for details.

    Args:
        data (Any): Input data.
        device (str | torch.device): Target device (default: None).
        dtype (torch.dtype): Target data type (default: None).
        memory_format (torch.memory_format): Memory organization format (default: torch.preserve_format).

    Returns:
        data (Any): Data cast to the specified device, dtype, and/or memory_format.
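
    Example (illustrative; assumes a CUDA device is available):
        batch = {"image": torch.zeros(2, 3, 8, 8), "labels": [torch.tensor(0), torch.tensor(1)]}
        batch = to(batch, device="cuda", dtype=torch.bfloat16)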
| """ |
| assert ( |
| device is not None or dtype is not None or memory_format is not None |
| ), "at least one of device, dtype, memory_format should be specified" |
| if isinstance(data, torch.Tensor): |
| is_cpu = (isinstance(device, str) and device == "cpu") or ( |
| isinstance(device, torch.device) and device.type == "cpu" |
| ) |
| data = data.to( |
| device=device, |
| dtype=dtype, |
| memory_format=memory_format, |
| non_blocking=(not is_cpu), |
| ) |
| return data |
| elif isinstance(data, collections.abc.Mapping): |
| return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) |
| elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): |
| return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) |
| else: |
| return data |
|
|
|
|
def serialize(data: Any) -> Any:
    """Serialize data by hierarchically traversing through iterables.

    Anything that is not JSON-serializable is replaced with its string representation.

    Args:
        data (Any): Input data.

    Returns:
        data (Any): Serialized data.
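
    Example:
        serialize({"dtype": torch.float32, "shape": [1, 2]})
        # -> {"dtype": "torch.float32", "shape": [1, 2]}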
| """ |
| if isinstance(data, collections.abc.Mapping): |
| return type(data)({key: serialize(data[key]) for key in data}) |
| elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): |
| return type(data)([serialize(elem) for elem in data]) |
| else: |
| try: |
| json.dumps(data) |
| except TypeError: |
| data = str(data) |
| return data |
|
|
|
|
def set_random_seed(seed: int, by_rank: bool = False) -> None:
    """Set the random seed for the random, numpy, and PyTorch generators.

    Args:
        seed (int): Random seed.
        by_rank (bool): If True, each rank uses a different random seed (seed + rank).
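
    Example:
        set_random_seed(42, by_rank=True)  # rank k seeds with 42 + k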
| """ |
| if by_rank: |
| seed += distributed.get_rank() |
| log.info(f"Using random seed {seed}.") |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
|
|
|
|
def arch_invariant_rand(
    shape: list[int] | tuple[int, ...], dtype: torch.dtype, device: str | torch.device, seed: int | None = None
) -> torch.Tensor:
    """Produce a GPU-architecture-invariant randomized Torch tensor.

    The tensor is sampled on the CPU with NumPy, so the values depend only on the seed,
    not on the GPU architecture the code runs on.

    Args:
        shape (list or tuple of ints): Output tensor shape.
        dtype (torch.dtype): Output tensor type.
        device (str | torch.device): Device holding the output.
        seed (int): Optional randomization seed.

    Returns:
        tensor (torch.Tensor): Randomly-generated tensor.
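
    Example:
        noise = arch_invariant_rand((4, 4), torch.float32, "cpu", seed=0)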
| """ |
| |
| rng = np.random.RandomState(seed) |
|
|
| |
| random_array = rng.standard_normal(shape).astype(np.float32) |
|
|
| |
| return torch.from_numpy(random_array).to(dtype=dtype, device=device) |
|
|
|
|
T = TypeVar("T", bound=Callable[..., Any])


class timer(ContextDecorator):
    """Simple timer for timing the execution of code.

    It can be used as either a context manager or a function decorator. The timing result is logged upon exit.

    Example:
        def func_a():
            time.sleep(1)
        with timer("func_a"):
            func_a()

        @timer("func_b")
        def func_b():
            time.sleep(1)
        func_b()
    """

    def __init__(self, context: str, debug: bool = False):
        self.context = context
        self.debug = debug

    def __enter__(self) -> None:
        self.tic = time.time()

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        time_spent = time.time() - self.tic
        # Log at debug level when requested, otherwise at info level.
        if self.debug:
            log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds")
        else:
            log.info(f"Time spent on {self.context}: {time_spent:.4f} seconds")

    def __call__(self, func: T) -> T:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            tic = time.time()
            result = func(*args, **kwargs)
            time_spent = time.time() - tic
            if self.debug:
                log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds")
            else:
                log.info(f"Time spent on {self.context}: {time_spent:.4f} seconds")
            return result

        return wrapper


class Color:
    """A convenience class to colorize strings in the console.

    Example:
        print(f"This is {Color.red('important')}.")
    """

    @staticmethod
    def red(x: str) -> str:
        return termcolor.colored(str(x), color="red")

    @staticmethod
    def green(x: str) -> str:
        return termcolor.colored(str(x), color="green")

    @staticmethod
    def cyan(x: str) -> str:
        return termcolor.colored(str(x), color="cyan")

    @staticmethod
    def yellow(x: str) -> str:
        return termcolor.colored(str(x), color="yellow")