| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
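"""
Decorators for entry logging, runtime logging, distributed barriers,
rank-gated execution, and running functions in a background thread.
"""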
|
|
| import functools |
| import threading |
| import time |
| from typing import Callable |
| import torch |
|
|
| from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank |
| from common.logger import get_logger |
|
|
# Module-level logger used by log_on_entry and log_runtime to emit their messages.
logger = get_logger(__name__)
|
|
|
|
def log_on_entry(func: Callable) -> Callable:
    """
    Log ``Entering <name>`` at INFO level each time the decorated function is called.

    The name logged is ``func.__name__`` of whatever object this decorator wraps.
    The returned wrapper preserves the wrapped function's metadata via
    ``functools.wraps``. When stacking with decorators that do not preserve
    metadata, apply this one innermost so the real function name is logged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that logs on entry, then calls ``func`` and returns its result.
    """

    @functools.wraps(func)
    def log_on_entry_wrapper(*args, **kwargs):
        logger.info(f"Entering {func.__name__}")
        return func(*args, **kwargs)

    return log_on_entry_wrapper
|
|
|
|
def barrier_on_entry(func: Callable) -> Callable:
    """
    Call ``barrier_if_distributed()`` before running the decorated function.

    In a distributed run, this keeps any rank from starting the function until
    all ranks have reached the call. The barrier helper's behavior outside a
    distributed run is defined in ``common.distributed``.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper (with ``func``'s metadata preserved) that synchronizes, then
        calls ``func`` and returns its result.
    """

    @functools.wraps(func)
    def barrier_on_entry_wrapper(*args, **kwargs):
        barrier_if_distributed()
        return func(*args, **kwargs)

    return barrier_on_entry_wrapper
|
|
|
|
| def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable: |
| """ |
| Helper function for local_rank_zero_only and global_rank_zero_only. |
| """ |
|
|
| def conditional_execute_wrapper(*args, **kwargs): |
| |
| result = func(*args, **kwargs) if execute else None |
| |
| barrier_if_distributed() |
| |
| return result |
|
|
| return conditional_execute_wrapper |
|
|
|
|
| def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable: |
| """ |
| Helper function for some functions with special constraints, |
| especially functions called by other global_rank_zero_only / local_rank_zero_only ones, |
| in case they are wrongly invoked in other scenarios. |
| """ |
|
|
| def asserted_execute_wrapper(*args, **kwargs): |
| assert condition, err_msg |
| result = func(*args, **kwargs) |
| return result |
|
|
| return asserted_execute_wrapper |
|
|
|
|
def local_rank_zero_only(func: Callable) -> Callable:
    """
    Run the decorated function only on the process whose local rank is zero.

    The local rank is read once, when the decorator is applied. Ranks other than
    local rank zero skip the call and receive ``None``. Every rank still goes
    through ``barrier_if_distributed()`` inside the shared wrapper factory.
    """
    is_local_leader = get_local_rank() == 0
    return _conditional_execute_wrapper_factory(is_local_leader, func)
|
|
|
|
def global_rank_zero_only(func: Callable) -> Callable:
    """
    Run the decorated function only on the process whose global rank is zero.

    The global rank is read once, when the decorator is applied. All other ranks
    skip the call and receive ``None``. Every rank still goes through
    ``barrier_if_distributed()`` inside the shared wrapper factory.
    """
    is_global_leader = get_global_rank() == 0
    return _conditional_execute_wrapper_factory(is_global_leader, func)
|
|
|
|
def assert_only_global_rank_zero(func: Callable) -> Callable:
    """
    Restrict the decorated function to the process with global rank zero.

    The rank check happens once, when the decorator is applied. Calling the
    wrapped function on any other rank trips the guard in
    ``_asserted_wrapper_factory``.
    """
    message = "Not accessible to processes with global_rank != 0"
    on_global_zero = get_global_rank() == 0
    return _asserted_wrapper_factory(on_global_zero, func, err_msg=message)
|
|
|
|
def assert_only_local_rank_zero(func: Callable) -> Callable:
    """
    Restrict the decorated function to the process with local rank zero.

    The rank check happens once, when the decorator is applied. Calling the
    wrapped function on any other local rank trips the guard in
    ``_asserted_wrapper_factory``.
    """
    message = "Not accessible to processes with local_rank != 0"
    on_local_zero = get_local_rank() == 0
    return _asserted_wrapper_factory(on_local_zero, func, err_msg=message)
|
|
|
|
def new_thread(func: Callable) -> Callable:
    """
    Make calls to the decorated function run in a new ``threading.Thread``.

    Each call starts a fresh thread with the given positional and keyword
    arguments and returns immediately. The function's own return value is
    discarded (it is the thread's target).

    Args:
        func: The callable to run in the background.

    Returns:
        A wrapper (with ``func``'s metadata preserved) that returns the
        started ``threading.Thread``; call ``.join()`` on it to wait for
        completion.
    """

    @functools.wraps(func)
    def new_thread_wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return new_thread_wrapper
|
|
|
|
def log_runtime(func: Callable) -> Callable:
    """
    Log how long the decorated function took, in seconds, at INFO level.

    When a ``torch.distributed`` process group is initialized, a barrier is
    entered before the timer starts and again after ``func`` returns. The
    measured time therefore includes waiting for the other ranks at the
    closing barrier. Without an initialized process group the barriers are
    skipped, so the decorator also works in single-process runs.

    Args:
        func: The callable to time.

    Returns:
        A wrapper (with ``func``'s metadata preserved) that returns ``func``'s result.
    """

    def _sync_if_distributed():
        # torch.distributed.barrier() raises when no default process group
        # exists, so only synchronize when one has been initialized.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.barrier()

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        _sync_if_distributed()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        _sync_if_distributed()
        logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
        return result

    return wrapped
|
|