| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| from functools import partial |
| from types import MethodType |
| from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple |
|
|
| import torch |
|
|
| from ...extras.constants import LAYERNORM_NAMES |
| from ...extras.logging import get_logger |
|
|
|
|
| if TYPE_CHECKING: |
| from transformers import PreTrainedModel |
|
|
| from ...hparams import ModelArguments |
|
|
|
|
| logger = get_logger(__name__) |
|
|
|
|
def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for block-wise optimizer.

    Args:
        self: the model this function is bound to (it is installed as a method via `MethodType`).
        gradient_checkpointing_kwargs: keyword arguments forwarded to `torch.utils.checkpoint.checkpoint`.
            Defaults to `{"use_reentrant": True}` when omitted.

    Raises:
        ValueError: if the model reports `supports_gradient_checkpointing` as falsy.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        # `func` is expected to be a bound method of the checkpointed module. It may also arrive
        # wrapped in one or more `functools.partial` objects, which do not expose `__self__`,
        # so peel them off before looking up the owning module.
        bound_method = func
        while isinstance(bound_method, partial):
            bound_method = bound_method.func

        module: "torch.nn.Module" = bound_method.__self__

        # Only mark the inputs when the module holds trainable parameters: floating-point
        # tensor arguments are switched to require gradients before entering the checkpoint.
        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        # The original (possibly partial-wrapped) callable is what gets checkpointed.
        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:
        # Old GC format: the setter takes `value` and cannot receive a custom checkpoint function.
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:
        # New format: the custom function above already takes care of input gradients.
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
|
|
|
def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    r"""
    Forward hook that returns the module output converted to float32.

    The `module` and `args` parameters are part of the hook signature and are not used.
    """
    upcast_output = output.to(dtype=torch.float32)
    return upcast_output
|
|
|
|
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
    r"""
    Prepares the model for training in place. Depending on `model_args`, this:
        (1) casts the 1-D layernorm parameters to fp32
        (2) enables gradient checkpointing and turns off `use_cache` in the model config
        (3) registers a hook that upcasts the lm_head outputs to fp32
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
        for param_name, weight in model.named_parameters():
            if weight.ndim != 1:
                continue

            # Layernorm parameters are recognised by a known substring in their name.
            if any(ln_name in param_name for ln_name in LAYERNORM_NAMES):
                weight.data = weight.data.to(torch.float32)

    if not model_args.disable_gradient_checkpointing:
        if getattr(model, "supports_gradient_checkpointing", False):
            # Swap in the customized enabler on this instance before calling it.
            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            model.config.use_cache = False
            logger.info("Gradient checkpointing enabled.")
        else:
            logger.warning("Current model does not support gradient checkpointing.")

    if model_args.upcast_lmhead_output:
        lm_head = model.get_output_embeddings()
        # Only a linear head whose weights are not already fp32 gets the upcasting hook.
        needs_upcast = isinstance(lm_head, torch.nn.Linear) and lm_head.weight.dtype != torch.float32
        if needs_upcast:
            logger.info("Upcasting lm_head outputs in float32.")
            lm_head.register_forward_hook(_fp32_forward_post_hook)
|
|