| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import gc |
| from collections import defaultdict |
| from functools import partial |
| from typing import Callable, Union |
|
|
| import torch |
| from torch import nn |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp._runtime_utils import _lazy_init |
| from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy |
| from torch.optim import Optimizer |
| from transformers import PreTrainedModel |
| from transformers.trainer_pt_utils import get_module_class_from_name |
|
|
|
|
def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]:
    """Build a per-module initializer that re-allocates parameters on ``device``.

    The returned callable replaces every parameter registered directly on the
    module it receives with a new, uninitialized ``nn.Parameter`` of the same
    shape/dtype (``torch.empty_like``) placed on ``device``. Parameters that
    appear under more than one name in ``model`` (tied weights) are allocated
    once and the same new parameter is reused for every occurrence, so the
    tying is preserved.

    Args:
        model: Model that is scanned once, up front, to find tied parameters.
        device: Device on which the replacement parameters are allocated.

    Returns:
        A function taking a single ``nn.Module`` and mutating its
        ``_parameters`` mapping in place; it returns ``None``.
    """
    # Count how often each parameter object shows up; remove_duplicate=False
    # is what makes tied parameters appear more than once.
    param_occurrence = defaultdict(int)
    for _, param in model.named_parameters(remove_duplicate=False):
        param_occurrence[param] += 1

    duplicated_params = {param for param, count in param_occurrence.items() if count > 1}
    # Maps an original tied parameter to its (single) replacement.
    materialized_params = {}

    def _materialize(param: nn.Parameter) -> nn.Parameter:
        # Uninitialized storage only: the values are not copied from ``param``.
        return nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad)

    def init_fn(module: nn.Module) -> None:
        for name, param in module.named_parameters(recurse=False):
            if param not in duplicated_params:
                module._parameters[name] = _materialize(param)
                continue

            # Allocate only on first sight of a tied parameter. Passing the new
            # tensor as a ``setdefault`` default would build (and throw away)
            # a full-size tensor on every later occurrence.
            if param not in materialized_params:
                materialized_params[param] = _materialize(param)
            module._parameters[name] = materialized_params[param]

    return init_fn
|
|
|
|
def get_fsdp_wrap_policy(model: PreTrainedModel):
    """Get FSDP wrap policy for the model.

    Each class name listed in ``model._no_split_modules`` is resolved to the
    actual module class found inside ``model``; the resulting classes are the
    ones the returned policy passes as ``transformer_layer_cls``.

    Args:
        model: The pretrained model to get the wrap policy for. A missing or
            ``None`` ``_no_split_modules`` attribute is treated as an empty
            list, which yields a policy with no layer classes.

    Returns:
        ``transformer_auto_wrap_policy`` partially applied with the resolved
        set of layer classes.

    Raises:
        ValueError: If a name in ``_no_split_modules`` cannot be resolved to a
            module class inside ``model``.
    """
    # ``or ()`` guards against the attribute being None, which would otherwise
    # fail with an unhelpful "NoneType is not iterable".
    no_split_module_names = getattr(model, "_no_split_modules", None) or ()

    transformer_cls_to_wrap = set()
    for module_name in no_split_module_names:
        transformer_cls = get_module_class_from_name(model, module_name)
        if transformer_cls is None:
            raise ValueError(f"Cannot find {module_name} in pretrained model.")
        transformer_cls_to_wrap.add(transformer_cls)

    return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap)
|
|
|
|
@torch.no_grad()
def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
    """Move the flat parameters of a root FSDP model to CPU.

    Args:
        model: Root FSDP module; a non-root module fails the assertion below.
        empty_cache: When true, ``torch.cuda.empty_cache()`` is called once all
            handles have been processed.
    """
    # NOTE(review): presumably forces FSDP's deferred initialization so that
    # ``_is_root`` and ``_all_handles`` are populated - confirm against torch.
    _lazy_init(model, model)
    assert model._is_root, "Only support root model offloading to CPU"

    for fsdp_handle in model._all_handles:
        # Handles flagged with ``_offload_params`` are left untouched.
        if not fsdp_handle._offload_params:
            shard_param = fsdp_handle.flat_param

            # Sanity checks before moving: ``.data`` and ``._local_shard`` must
            # share storage and size while being distinct Python objects.
            assert shard_param.data.data_ptr() == shard_param._local_shard.data_ptr()
            assert id(shard_param.data) != id(shard_param._local_shard)
            assert shard_param.data.size() == shard_param._local_shard.size()

            fsdp_handle.flat_param_to("cpu", non_blocking=True)

            # Re-point the local shard at the moved data; the two are still
            # expected to be different Python objects afterwards.
            shard_param._local_shard = shard_param.data
            assert id(shard_param._local_shard) != id(shard_param.data)

    if empty_cache:
        torch.cuda.empty_cache()
|
|
|
|
@torch.no_grad()
def load_fsdp_model(model: FSDP, empty_cache: bool = True):
    """Move the flat parameters of a root FSDP model to the CUDA device.

    Args:
        model: Root FSDP module; a non-root module fails the assertion below.
        empty_cache: When true, ``gc.collect()`` is run once all handles have
            been processed.
    """
    # NOTE(review): presumably forces FSDP's deferred initialization so that
    # ``_is_root`` and ``_all_handles`` are populated - confirm against torch.
    _lazy_init(model, model)
    assert model._is_root, "Only support root model loading to GPU"

    for fsdp_handle in model._all_handles:
        # Handles flagged with ``_offload_params`` are left untouched.
        if not fsdp_handle._offload_params:
            shard_param = fsdp_handle.flat_param
            fsdp_handle.flat_param_to("cuda", non_blocking=True)
            # Keep the local shard pointing at the data that was just moved.
            shard_param._local_shard = shard_param.data

    if empty_cache:
        gc.collect()
|
|
|
|
@torch.no_grad()
def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
    """Move every tensor stored in the optimizer state to CPU.

    Only parameters listed in ``optimizer.param_groups`` are visited, and only
    state values that are ``torch.Tensor`` instances are moved; other values
    are left as they are. The state dicts are updated in place.

    Args:
        optimizer: Optimizer whose per-parameter state is offloaded. Nothing
            happens if its state is empty.
        empty_cache: When true, ``torch.cuda.empty_cache()`` is called after
            the tensors have been moved.
    """
    if not optimizer.state:
        return

    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            # Use ``.get`` instead of indexing: torch optimizers keep ``state``
            # in a defaultdict, so ``state[param]`` would silently insert an
            # empty entry for every parameter that has no state yet.
            state = optimizer.state.get(param)
            if not state:
                continue

            # Replacing values for existing keys is safe while iterating.
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)

    if empty_cache:
        torch.cuda.empty_cache()
|
|
|
|
@torch.no_grad()
def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
    """Move every tensor stored in the optimizer state to the CUDA device.

    Only parameters listed in ``optimizer.param_groups`` are visited, and only
    state values that are ``torch.Tensor`` instances are moved; other values
    are left as they are. The state dicts are updated in place.

    Args:
        optimizer: Optimizer whose per-parameter state is loaded. Nothing
            happens if its state is empty.
        empty_cache: When true, ``gc.collect()`` is run after the tensors have
            been moved.
    """
    if not optimizer.state:
        return

    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            # Use ``.get`` instead of indexing: torch optimizers keep ``state``
            # in a defaultdict, so ``state[param]`` would silently insert an
            # empty entry for every parameter that has no state yet.
            state = optimizer.state.get(param)
            if not state:
                continue

            # Replacing values for existing keys is safe while iterating.
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cuda", non_blocking=True)

    if empty_cache:
        gc.collect()
|
|