| |
| |
| |
| |
| |
| |
|
|
| import contextlib |
| from typing import Union |
| from warnings import warn |
|
|
| import psutil |
| import torch |
| from torch import nn |
| from torch.autograd.graph import saved_tensors_hooks |
|
|
| from torchtitan.tools.logging import logger |
|
|
| try: |
| import torchao |
| from torchao.dtypes.nf4tensor import NF4Tensor |
| except ImportError: |
| torchao = None |
| NF4Tensor = None |
| logger.warning("torchao not found. ") |
|
|
| |
|
|
|
|
class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than
    ``min_offload_size`` bytes will be offloaded to CPU in the forward and brought back in the backward.
    This is in contrast to maintaining the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime vs memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires a torch build
            after torch-2.5.0. Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increasing memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Only used when ``use_streams`` is True. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if ``use_streams`` is True and ``max_fwd_stash_size`` is not at least 1.

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>> loss = ...
        >>> loss.backward()
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:
        self.use_streams: bool = use_streams

        # Tensors smaller than this many bytes are never offloaded.
        self.min_tensor_size_bytes = min_offload_size
        # tensor_id -> (tensor, modified). ``modified`` is True when the stored tensor
        # is a CPU copy that has to be moved back to CUDA on unpack.
        self.tracker = {}
        self.tensor_id: int = 0
        # Phase flags used to detect the first pack/unpack of each forward/backward.
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        self.use_pin_memory: bool = use_pin_memory
        # On the first backward, warn if host RAM usage (percent) is above this threshold.
        self.virtual_memory_safe_pct = 60

        # Compute stream.
        self.s0 = torch.cuda.default_stream()

        if self.use_streams:
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            # Side stream carrying the GPU<->CPU copies.
            self.s1 = torch.cuda.Stream()
            # tensor_id -> (GPU activation, event recorded on s1 after its copy was issued);
            # keeps the source activation alive until s1 has finished copying it.
            self.fwd_stash = {}
            self.max_fwd_stash_size = max_fwd_stash_size
            # tensor_id -> GPU tensor brought back from CPU during backward.
            self.bwd_tensor_stash = {}
            # tensor_id -> event recorded on s0 in the consuming node's post-hook.
            self.bwd_ev_stash = {}
            self.curr_graph_id = None
            self.curr_autograd_node = None

        def verify_sufficient_virtual_memory() -> None:
            """Warn if host RAM usage exceeds ``virtual_memory_safe_pct``."""
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            """Return the percentage of host virtual memory currently in use."""
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            """Return a fresh, monotonically increasing id for a packed tensor."""
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            """Return the number of bytes occupied by the elements of ``x``."""
            return x.element_size() * x.nelement()

        def pack_tensor(activation: torch.Tensor) -> int:
            """Pack hook: register ``activation`` under a new id, offloading it to CPU if eligible."""
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # Entering a forward pass: the next unpack starts a new backward.
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # Only offload sufficiently large CUDA tensors that are not parameters or
            # buffers (the heuristic used here for "is an activation").
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # Release stashed activations whose offload was issued at least
                    # max_fwd_stash_size ids ago. Entries are inserted with increasing
                    # ids, so stop at the first one that is still recent.
                    for stashed_id in list(self.fwd_stash.keys()):
                        if stashed_id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[stashed_id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[stashed_id]
                        else:
                            break

                    # The copy stream must see all compute work producing this activation.
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        # NF4Tensor is None when torchao is not installed; guard it so
                        # isinstance() does not raise TypeError and mask the real error.
                        # Versions are compared per PEP 440, not lexicographically.
                        if (
                            NF4Tensor is not None
                            and isinstance(activation, NF4Tensor)
                            and torch.torch_version.TorchVersion(torchao.__version__)
                            < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    # True: the stored tensor is a CPU copy to be moved back on unpack.
                    self.tracker[tensor_id] = (cpu_tensor, True)

                if self.use_streams:
                    event = self.s1.record_event()
                    # Keep the GPU activation alive until s1 is done copying it.
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                # False: the tensor is kept as-is (not offloaded).
                self.tracker[tensor_id] = (activation, False)

            return tensor_id

        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            """Unpack hook without a side stream: return the tensor for ``unpack_tensor_id``."""
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                # Entering a backward pass: the next pack starts a new forward.
                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                maybe_gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)

            # Drop our reference; unpacking the same id again would fail the assert above.
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            """Unpack hook with a side stream: copy back on s1 and manage GPU tensor lifetimes."""
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    # End-of-backward cleanup: release every tensor still stashed, after
                    # s1 waits on the event recorded when compute finished with it.
                    for remaining_id in list(self.bwd_tensor_stash.keys()):
                        event = self.bwd_ev_stash.pop(remaining_id)
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[remaining_id]

                # Run the cleanup when the current autograd graph task completes.
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                # Entering a backward pass: the next pack starts a new forward.
                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # On the first unpack of a new node within the same graph task, every
                # tensor still in bwd_tensor_stash was unpacked for an earlier node;
                # remember them so this node's hook can release them.
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = list(self.bwd_tensor_stash.keys())

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    # The GPU activation is still alive from the forward pass: reuse it.
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Issue the host-to-device copy on the side stream.
                    with torch.cuda.stream(self.s1):
                        maybe_gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)

                    # Compute must not read the tensor before the copy completes.
                    self.s0.wait_stream(self.s1)

                    # Keep the copied tensor alive until compute is done with it. Only
                    # brought-back tensors are stashed, since only they get an event in
                    # bwd_ev_stash from the hook below.
                    self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                    # Baseline storage refcount. If it has grown by the time the hook
                    # runs, something else still references this storage (presumably a
                    # view created by the node), so it must outlive this node.
                    storage_refcount = torch._C._storage_Use_Count(
                        maybe_gpu_tensor.untyped_storage()._cdata
                    )

                def hook(outputs, inputs):
                    if brought_back_from_cpu:
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            # Still referenced elsewhere: mark its use on s0 via
                            # record_stream and stop tracking it ourselves.
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            # Released later, once s1 has waited on this event.
                            self.bwd_ev_stash[unpack_tensor_id] = self.s0.record_event()

                    # Backward is running, so anything left in the forward stash can go.
                    for stashed_id in list(self.fwd_stash.keys()):
                        _, ev = self.fwd_stash[stashed_id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[stashed_id]

                    # Release tensors consumed by earlier nodes; pop their events too so
                    # bwd_ev_stash does not grow across iterations.
                    for prev_id in prev_node_ids:
                        event = self.bwd_ev_stash.pop(prev_id)
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[prev_id]

                    return outputs

                node.register_hook(hook)

            # Drop our reference; unpacking the same id again would fail the assert above.
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)
|
|
|
|
class NoOpManager(saved_tensors_hooks):
    """Pass-through saved-tensors hooks that shadow any outer saved_tensors_hooks manager.

    Only the most recently entered saved_tensors_hooks manager takes effect. Entering
    this one inside another manager (for instance an activation offloading context)
    therefore causes tensors saved for backward to be stored and returned unchanged.

    A typical use is exempting a local region of code from activation offloading,
    which is normally applied globally so that its bookkeeping stays consistent.
    """

    @staticmethod
    def _identity(saved):
        # Used as both the pack and the unpack hook: hand the value back untouched.
        return saved

    def __init__(self) -> None:
        super().__init__(self._identity, self._identity)
|
|
|
|
def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Return the activation offloading context manager for ``model``.

    If ``enable_activation_offloading`` is True, an ``OffloadActivations`` context manager
    is returned. Additionally, if ``model.output`` exists and is an ``nn.Module``, forward
    pre-/post-hooks are registered on it that enter/exit a ``NoOpManager``. As a result,
    tensors saved inside the output head are not offloaded. If no such output head is
    found, a warning is logged instead.

    If ``enable_activation_offloading`` is False, a ``contextlib.nullcontext`` is returned.

    Note: the hooks are registered on ``model.output`` each time this function is called
    with offloading enabled, so call it once per model.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        Union[OffloadActivations, contextlib.nullcontext]: the context manager to run the
        model's forward pass under.
    """
    if not enable_activation_offloading:
        return contextlib.nullcontext()

    activations_handling_ctx = OffloadActivations()

    # Opt the output head out of offloading. Entering a NoOpManager pushes pass-through
    # saved-tensor hooks that take precedence over the offloading hooks while
    # model.output runs. The hooks are popped again after its forward.
    noop_ctx = NoOpManager()

    if hasattr(model, "output") and isinstance(model.output, nn.Module):
        model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
        # always_call=True so the no-op hooks are popped even if the forward raises.
        model.output.register_forward_hook(
            lambda *args: noop_ctx.__exit__(), always_call=True
        )
        logger.info("Registered activation offloading opt-out hooks on model.output")
    else:
        logger.warning(
            "During activation offloading, no output head was detected. "
            "If your model has an output head, it will be offloaded. "
            "This usually greatly slows training, given the large vocabulary size. "
            "To change this behavior, set your output head as model.output and make it "
            "an nn.Module."
        )

    return activations_handling_ctx
|
|