| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task. |
| """ |
|
|
| import contextlib |
| import functools |
| import glob |
| import inspect |
| import math |
| import os |
| import random |
| import re |
| import shutil |
| import sys |
| import time |
| import warnings |
| from collections.abc import Mapping |
| from distutils.util import strtobool |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
| from tqdm.auto import tqdm |
|
|
|
|
| |
| |
| from transformers.integrations import ( |
| default_hp_search_backend, |
| get_reporting_integration_callbacks, |
| hp_params, |
| is_fairscale_available, |
| is_optuna_available, |
| is_ray_tune_available, |
| is_sigopt_available, |
| is_wandb_available, |
| run_hp_search_optuna, |
| run_hp_search_ray, |
| run_hp_search_sigopt, |
| run_hp_search_wandb, |
| ) |
|
|
| |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| from huggingface_hub import Repository, create_repo |
| from packaging import version |
| from torch import nn |
| from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| from transformers import __version__ |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator |
| from transformers.debug_utils import DebugOption, DebugUnderflowOverflow |
| from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled |
| from transformers.dependency_versions_check import dep_version_check |
| from transformers.modelcard import TrainingSummary |
| from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model |
| from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES |
| from transformers.optimization import Adafactor, get_scheduler |
| from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 |
| from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
| from transformers.trainer_callback import ( |
| CallbackHandler, |
| DefaultFlowCallback, |
| PrinterCallback, |
| ProgressCallback, |
| TrainerCallback, |
| TrainerControl, |
| TrainerState, |
| ) |
| from transformers.trainer_pt_utils import ( |
| DistributedLengthGroupedSampler, |
| DistributedSamplerWithLoop, |
| DistributedTensorGatherer, |
| IterableDatasetShard, |
| LabelSmoother, |
| LengthGroupedSampler, |
| SequentialDistributedSampler, |
| ShardSampler, |
| distributed_broadcast_scalars, |
| distributed_concat, |
| find_batch_size, |
| get_module_class_from_name, |
| get_parameter_names, |
| nested_concat, |
| nested_detach, |
| nested_numpify, |
| nested_truncate, |
| nested_xla_mesh_reduce, |
| reissue_pt_warnings, |
| ) |
| from transformers.trainer_utils import ( |
| PREFIX_CHECKPOINT_DIR, |
| BestRun, |
| EvalLoopOutput, |
| EvalPrediction, |
| FSDPOption, |
| HPSearchBackend, |
| HubStrategy, |
| IntervalStrategy, |
| PredictionOutput, |
| RemoveColumnsCollator, |
| ShardedDDPOption, |
| TrainerMemoryTracker, |
| TrainOutput, |
| default_compute_objective, |
| default_hp_space, |
| denumpify_detensorize, |
| enable_full_determinism, |
| find_executable_batch_size, |
| get_last_checkpoint, |
| has_length, |
| number_of_arguments, |
| seed_worker, |
| set_seed, |
| speed_metrics, |
| ) |
| from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments |
| from transformers.utils import ( |
| CONFIG_NAME, |
| WEIGHTS_INDEX_NAME, |
| WEIGHTS_NAME, |
| can_return_loss, |
| find_labels, |
| get_full_repo_name, |
| is_accelerate_available, |
| is_apex_available, |
| is_datasets_available, |
| is_in_notebook, |
| is_ipex_available, |
| is_sagemaker_dp_enabled, |
| is_sagemaker_mp_enabled, |
| is_torch_compile_available, |
| is_torch_neuroncore_available, |
| is_torch_tpu_available, |
| logging, |
| ) |
| from transformers.utils.generic import ContextManagers |
|
|
|
|
| _is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 |
|
|
| DEFAULT_CALLBACKS = [DefaultFlowCallback] |
| DEFAULT_PROGRESS_CALLBACK = ProgressCallback |
|
|
| if is_in_notebook(): |
| from transformers.utils.notebook import NotebookProgressCallback |
|
|
| DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback |
|
|
| if is_apex_available(): |
| from apex import amp |
|
|
| if is_datasets_available(): |
| import datasets |
|
|
| if is_torch_tpu_available(check_device=False): |
| import torch_xla.core.xla_model as xm |
| import torch_xla.debug.metrics as met |
| import torch_xla.distributed.parallel_loader as pl |
|
|
| if is_fairscale_available(): |
| dep_version_check("fairscale") |
| import fairscale |
| from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP |
| from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP |
| from fairscale.nn.wrap import auto_wrap |
| from fairscale.optim import OSS |
| from fairscale.optim.grad_scaler import ShardedGradScaler |
|
|
|
|
| if is_sagemaker_mp_enabled(): |
| import smdistributed.modelparallel.torch as smp |
| from smdistributed.modelparallel import __version__ as SMP_VERSION |
|
|
| IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") |
|
|
| from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat |
| else: |
| IS_SAGEMAKER_MP_POST_1_10 = False |
|
|
|
|
| skip_first_batches = None |
| if is_accelerate_available(): |
| from accelerate import __version__ as accelerate_version |
|
|
| if version.parse(accelerate_version) >= version.parse("0.16"): |
| from accelerate import skip_first_batches |
|
|
|
|
| if TYPE_CHECKING: |
| import optuna |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| |
| TRAINING_ARGS_NAME = "training_args.bin" |
| TRAINER_STATE_NAME = "trainer_state.json" |
| OPTIMIZER_NAME = "optimizer.pt" |
| SCHEDULER_NAME = "scheduler.pt" |
| SCALER_NAME = "scaler.pt" |
|
|
|
|
| class Trainer: |
| """ |
| Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. |
| |
| Args: |
| model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): |
| The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. |
| |
| <Tip> |
| |
| [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use |
| your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers |
| models. |
| |
| </Tip> |
| |
| args ([`TrainingArguments`], *optional*): |
| The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the |
| `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. |
| data_collator (`DataCollator`, *optional*): |
| The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will |
| default to [`default_data_collator`] if no `tokenizer` is provided, an instance of |
| [`DataCollatorWithPadding`] otherwise. |
| train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): |
| The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the |
| `model.forward()` method are automatically removed. |
| |
| Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a |
| distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a |
| `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will |
| manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally |
| sets the seed of the RNGs used. |
| eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*): |
| The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the |
| `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each |
| dataset prepending the dictionary key to the metric name. |
| tokenizer ([`PreTrainedTokenizerBase`], *optional*): |
| The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the |
| maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an |
| interrupted training or reuse the fine-tuned model. |
| model_init (`Callable[[], PreTrainedModel]`, *optional*): |
| A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start |
| from a new instance of the model as given by this function. |
| |
| The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, to |
| be able to choose different architectures according to hyperparameters (such as layer count, sizes of |
| inner layers, dropout probabilities, etc.). |
| compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): |
| The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return |
| a dictionary mapping metric names (strings) to metric values. |
| callbacks (List of [`TrainerCallback`], *optional*): |
| A list of callbacks to customize the training loop. Will add those to the list of default callbacks |
| detailed in [here](callback). |
| |
| If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. |
| optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple |
| containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model |
| and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. |
| preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): |
| A function that preprocesses the logits right before caching them at each evaluation step. Must take two |
| tensors, the logits and the labels, and return the logits once processed as desired. The modifications made |
| by this function will be reflected in the predictions received by `compute_metrics`. |
| |
| Note that the labels (second parameter) will be `None` if the dataset does not have them. |
| |
| Important attributes: |
| |
| - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] |
| subclass. |
| - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the |
| original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, |
| the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner |
| model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. |
| - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from |
| data parallelism, this means some of the model layers are split on different GPUs). |
| - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set |
| to `False` if model parallel or deepspeed is used, or if the default |
| `TrainingArguments.place_model_on_device` is overridden to return `False` . |
| - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while |
| in `train`) |
| |
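| Example: |
| |
| A minimal usage sketch (the checkpoint name and the tiny in-memory dataset below are illustrative |
| assumptions, not part of this class): |
| |
| ```python |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments |
| |
| tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") |
| model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased") |
| |
| # A toy dataset: a list of dicts is enough, since it implements `__len__` and `__getitem__`. |
| texts, labels = ["great movie", "terrible movie"], [1, 0] |
| encodings = tokenizer(texts, truncation=True) |
| train_dataset = [{**{k: v[i] for k, v in encodings.items()}, "labels": labels[i]} for i in range(len(texts))] |
| |
| trainer = Trainer( |
|     model=model, |
|     args=TrainingArguments(output_dir="tmp_trainer", num_train_epochs=1), |
|     train_dataset=train_dataset, |
|     tokenizer=tokenizer, |
| ) |
| trainer.train() |
| ``` |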
| """ |
|
|
| from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state |
|
|
| def __init__( |
| self, |
| model: Union[PreTrainedModel, nn.Module] = None, |
| args: TrainingArguments = None, |
| data_collator: Optional[DataCollator] = None, |
| train_dataset: Optional[Dataset] = None, |
| eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, |
| tokenizer: Optional[PreTrainedTokenizerBase] = None, |
| model_init: Optional[Callable[[], PreTrainedModel]] = None, |
| compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, |
| callbacks: Optional[List[TrainerCallback]] = None, |
| optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), |
| preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, |
| save_prefixencoder: bool = False, |
| ): |
| self.save_prefixencoder = save_prefixencoder |
| if args is None: |
| output_dir = "tmp_trainer" |
| logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") |
| args = TrainingArguments(output_dir=output_dir) |
| self.args = args |
| |
| enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) |
| self.hp_name = None |
| self.deepspeed = None |
| self.is_in_train = False |
|
|
| |
| self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) |
| self._memory_tracker.start() |
|
|
| |
| log_level = args.get_process_log_level() |
| logging.set_verbosity(log_level) |
|
|
| |
| args._setup_devices |
|
|
| if model is None: |
| if model_init is not None: |
| self.model_init = model_init |
| model = self.call_model_init() |
| else: |
| raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") |
| else: |
| if model_init is not None: |
| warnings.warn( |
| "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" |
| " overwrite your model when calling the `train` method. This will become a fatal error in the next" |
| " release.", |
| FutureWarning, |
| ) |
| self.model_init = model_init |
|
|
| if model.__class__.__name__ in MODEL_MAPPING_NAMES: |
| raise ValueError( |
| f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " |
| "computes hidden states and does not accept any labels. You should choose a model with a head " |
| "suitable for your task like any of the `AutoModelForXxx` listed at " |
| "https://huggingface.co/docs/transformers/model_doc/auto." |
| ) |
|
|
| if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: |
| self.is_model_parallel = True |
| else: |
| self.is_model_parallel = False |
|
|
| |
| if getattr(model, "is_loaded_in_8bit", False): |
| if getattr(model, "_is_int8_training_enabled", False): |
| logger.info( |
| "The model is loaded in 8-bit precision. To train this model you need to add additional modules" |
| " inside the model such as adapters using `peft` library and freeze the model weights. Please" |
| " check " |
| " the examples in https://github.com/huggingface/peft for more details." |
| ) |
| else: |
| raise ValueError( |
| "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" |
| " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " |
| ) |
|
|
| |
| self.sharded_ddp = None |
| if len(args.sharded_ddp) > 0: |
| if args.deepspeed: |
| raise ValueError( |
| "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." |
| ) |
| if len(args.fsdp) > 0: |
| raise ValueError( |
| "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." |
| ) |
|
|
| if args.local_rank == -1: |
| raise ValueError("Using sharded DDP only works in distributed training.") |
| elif not is_fairscale_available(): |
| raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") |
| elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: |
| raise ImportError( |
| "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " |
| f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." |
| ) |
| elif ShardedDDPOption.SIMPLE in args.sharded_ddp: |
| self.sharded_ddp = ShardedDDPOption.SIMPLE |
| elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: |
| self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 |
| elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: |
| self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 |
|
|
| self.fsdp = None |
| if len(args.fsdp) > 0: |
| if args.deepspeed: |
| raise ValueError( |
| "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." |
| ) |
| if not args.fsdp_config["xla"] and args.local_rank == -1: |
| raise ValueError("Using fsdp only works in distributed training.") |
|
|
| |
| |
| |
| |
| if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): |
| raise ValueError("FSDP requires PyTorch >= 1.12.0") |
|
|
| from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy |
|
|
| if FSDPOption.FULL_SHARD in args.fsdp: |
| self.fsdp = ShardingStrategy.FULL_SHARD |
| elif FSDPOption.SHARD_GRAD_OP in args.fsdp: |
| self.fsdp = ShardingStrategy.SHARD_GRAD_OP |
| elif FSDPOption.NO_SHARD in args.fsdp: |
| self.fsdp = ShardingStrategy.NO_SHARD |
|
|
| self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE |
| if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch: |
| self.backward_prefetch = BackwardPrefetch.BACKWARD_POST |
|
|
| self.forward_prefetch = False |
| if self.args.fsdp_config.get("forward_prefetch", False): |
| self.forward_prefetch = True |
|
|
| self.limit_all_gathers = False |
| if self.args.fsdp_config.get("limit_all_gathers", False): |
| self.limit_all_gathers = True |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| self.place_model_on_device = args.place_model_on_device |
| if ( |
| self.is_model_parallel |
| or args.deepspeed |
| or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) |
| or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) |
| or (self.fsdp is not None) |
| ): |
| self.place_model_on_device = False |
|
|
| default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) |
| self.data_collator = data_collator if data_collator is not None else default_collator |
| self.train_dataset = train_dataset |
| self.eval_dataset = eval_dataset |
| self.tokenizer = tokenizer |
|
|
| if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False): |
| self._move_model_to_device(model, args.device) |
|
|
| |
| if self.is_model_parallel: |
| self.args._n_gpu = 1 |
|
|
| |
| self.model_wrapped = model |
| self.model = model |
|
|
| self.compute_metrics = compute_metrics |
| self.preprocess_logits_for_metrics = preprocess_logits_for_metrics |
| self.optimizer, self.lr_scheduler = optimizers |
| if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): |
| raise RuntimeError( |
| "Passing a `model_init` is incompatible with providing the `optimizers` argument. " |
| "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." |
| ) |
| if is_torch_tpu_available() and self.optimizer is not None: |
| for param in self.model.parameters(): |
| model_device = param.device |
| break |
| for param_group in self.optimizer.param_groups: |
| if len(param_group["params"]) > 0: |
| optimizer_device = param_group["params"][0].device |
| break |
| if model_device != optimizer_device: |
| raise ValueError( |
| "The model and the optimizer parameters are not on the same device, which probably means you" |
| " created an optimizer around your model **before** putting on the device and passing it to the" |
| " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" |
| " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." |
| ) |
| if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and ( |
| self.optimizer is not None or self.lr_scheduler is not None |
| ): |
| raise RuntimeError( |
| "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." |
| "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." |
| ) |
| default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) |
| callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks |
| self.callback_handler = CallbackHandler( |
| callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler |
| ) |
| self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) |
|
|
| |
| self._loggers_initialized = False |
|
|
| |
| if self.args.push_to_hub: |
| self.init_git_repo(at_init=True) |
| |
| if is_torch_tpu_available(): |
| xm.rendezvous("init git repo") |
| elif args.local_rank != -1: |
| dist.barrier() |
|
|
| if self.args.should_save: |
| os.makedirs(self.args.output_dir, exist_ok=True) |
|
|
| if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): |
| raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") |
|
|
| if args.max_steps > 0: |
| logger.info("max_steps is given, it will override any value given in num_train_epochs") |
|
|
| if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: |
| raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") |
|
|
| if ( |
| train_dataset is not None |
| and isinstance(train_dataset, torch.utils.data.IterableDataset) |
| and args.group_by_length |
| ): |
| raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") |
|
|
| self._signature_columns = None |
|
|
| |
| self.use_apex = False |
| self.use_cuda_amp = False |
| self.use_cpu_amp = False |
|
|
| |
| if is_sagemaker_mp_enabled(): |
| |
| if args.bf16: |
| raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") |
|
|
| if IS_SAGEMAKER_MP_POST_1_10: |
| |
| if args.fp16 != smp.state.cfg.fp16: |
| logger.warning( |
| f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," |
| f"but FP16 provided in trainer argument is {args.fp16}," |
| f"setting to {smp.state.cfg.fp16}" |
| ) |
| args.fp16 = smp.state.cfg.fp16 |
| else: |
| |
| if hasattr(smp.state.cfg, "fp16"): |
| logger.warning( |
| f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " |
| "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." |
| ) |
|
|
| if args.fp16 or args.bf16: |
| if args.half_precision_backend == "auto": |
| if args.device == torch.device("cpu"): |
| if args.fp16: |
| raise ValueError("Tried to use `fp16` but it is not supported on cpu") |
| elif _is_native_cpu_amp_available: |
| args.half_precision_backend = "cpu_amp" |
| else: |
| raise ValueError("Tried to use cpu amp but native cpu amp is not available") |
| else: |
| args.half_precision_backend = "cuda_amp" |
|
|
| logger.info(f"Using {args.half_precision_backend} half precision backend") |
|
|
| self.do_grad_scaling = False |
| if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()): |
| |
| if args.half_precision_backend == "cuda_amp": |
| self.use_cuda_amp = True |
| self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 |
| |
| self.do_grad_scaling = self.amp_dtype == torch.float16 |
| if self.do_grad_scaling: |
| if self.sharded_ddp is not None: |
| self.scaler = ShardedGradScaler() |
| elif self.fsdp is not None: |
| from torch.distributed.fsdp.sharded_grad_scaler import ( |
| ShardedGradScaler as FSDPShardedGradScaler, |
| ) |
|
|
| self.scaler = FSDPShardedGradScaler() |
| elif is_torch_tpu_available(): |
| from torch_xla.amp import GradScaler |
|
|
| self.scaler = GradScaler() |
| else: |
| self.scaler = torch.cuda.amp.GradScaler() |
| elif args.half_precision_backend == "cpu_amp": |
| self.use_cpu_amp = True |
| self.amp_dtype = torch.bfloat16 |
| else: |
| if not is_apex_available(): |
| raise ImportError( |
| "Using FP16 with APEX but APEX is not installed, please refer to" |
| " https://www.github.com/nvidia/apex." |
| ) |
| self.use_apex = True |
|
|
| |
| if ( |
| is_sagemaker_mp_enabled() |
| and self.use_cuda_amp |
| and args.max_grad_norm is not None |
| and args.max_grad_norm > 0 |
| ): |
| raise ValueError( |
| "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " |
| "along 'max_grad_norm': 0 in your hyperparameters." |
| ) |
|
|
| |
| if self.args.label_smoothing_factor != 0: |
| self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) |
| else: |
| self.label_smoother = None |
|
|
| self.state = TrainerState( |
| is_local_process_zero=self.is_local_process_zero(), |
| is_world_process_zero=self.is_world_process_zero(), |
| ) |
|
|
| self.control = TrainerControl() |
| |
| |
| self.current_flos = 0 |
| self.hp_search_backend = None |
| self.use_tune_checkpoints = False |
| default_label_names = find_labels(self.model.__class__) |
| self.label_names = default_label_names if self.args.label_names is None else self.args.label_names |
| self.can_return_loss = can_return_loss(self.model.__class__) |
| self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) |
|
|
| |
| self._train_batch_size = args.train_batch_size |
|
|
| |
| self._memory_tracker.stop_and_update_metrics() |
|
|
| |
| if args.torch_compile and not is_torch_compile_available(): |
| raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") |
|
|
| def add_callback(self, callback): |
| """ |
| Add a callback to the current list of [`~transformers.TrainerCallback`]. |
| |
| Args: |
| callback (`type` or [`~transformers.TrainerCallback`]): |
| A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the |
| first case, will instantiate a member of that class. |
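| |
| Example (a small sketch; assumes a `trainer` instance already exists): |
| |
| ```python |
| from transformers import EarlyStoppingCallback |
| |
| trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3)) |
| ``` |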
| """ |
| self.callback_handler.add_callback(callback) |
|
|
| def pop_callback(self, callback): |
| """ |
| Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it. |
| |
| If the callback is not found, returns `None` (and no error is raised). |
| |
| Args: |
| callback (`type` or [`~transformers.TrainerCallback`]): |
| A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the |
| first case, will pop the first member of that class found in the list of callbacks. |
| |
| Returns: |
| [`~transformers.TrainerCallback`]: The callback removed, if found. |
| """ |
| return self.callback_handler.pop_callback(callback) |
|
|
| def remove_callback(self, callback): |
| """ |
| Remove a callback from the current list of [`~transformers.TrainerCallback`]. |
| |
| Args: |
| callback (`type` or [`~transformers.TrainerCallback`]): |
| A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the |
| first case, will remove the first member of that class found in the list of callbacks. |
| """ |
| self.callback_handler.remove_callback(callback) |
|
|
| def _move_model_to_device(self, model, device): |
| model = model.to(device) |
| |
| if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): |
| model.tie_weights() |
|
|
| def _set_signature_columns_if_needed(self): |
| if self._signature_columns is None: |
| |
| signature = inspect.signature(self.model.forward) |
| self._signature_columns = list(signature.parameters.keys()) |
| |
| self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) |
|
|
| def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): |
| if not self.args.remove_unused_columns: |
| return dataset |
| self._set_signature_columns_if_needed() |
| signature_columns = self._signature_columns |
|
|
| ignored_columns = list(set(dataset.column_names) - set(signature_columns)) |
| if len(ignored_columns) > 0: |
| dset_description = "" if description is None else f"in the {description} set" |
| logger.info( |
| f"The following columns {dset_description} don't have a corresponding argument in " |
| f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." |
| f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " |
| " you can safely ignore this message." |
| ) |
|
|
| columns = [k for k in signature_columns if k in dataset.column_names] |
|
|
| if version.parse(datasets.__version__) < version.parse("1.4.0"): |
| dataset.set_format( |
| type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] |
| ) |
| return dataset |
| else: |
| return dataset.remove_columns(ignored_columns) |
|
|
| def _get_collator_with_removed_columns( |
| self, data_collator: Callable, description: Optional[str] = None |
| ) -> Callable: |
| """Wrap the data collator in a callable removing unused columns.""" |
| if not self.args.remove_unused_columns: |
| return data_collator |
| self._set_signature_columns_if_needed() |
| signature_columns = self._signature_columns |
|
|
| remove_columns_collator = RemoveColumnsCollator( |
| data_collator=data_collator, |
| signature_columns=signature_columns, |
| logger=logger, |
| description=description, |
| model_name=self.model.__class__.__name__, |
| ) |
| return remove_columns_collator |
|
|
| def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: |
| if self.train_dataset is None or not has_length(self.train_dataset): |
| return None |
|
|
| generator = None |
| if self.args.world_size <= 1: |
| generator = torch.Generator() |
| |
| |
| |
| if self.args.data_seed is None: |
| seed = int(torch.empty((), dtype=torch.int64).random_().item()) |
| else: |
| seed = self.args.data_seed |
| generator.manual_seed(seed) |
|
|
| seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed |
|
|
| |
| if self.args.group_by_length: |
| if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): |
| lengths = ( |
| self.train_dataset[self.args.length_column_name] |
| if self.args.length_column_name in self.train_dataset.column_names |
| else None |
| ) |
| else: |
| lengths = None |
| model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None |
| if self.args.world_size <= 1: |
| return LengthGroupedSampler( |
| self.args.train_batch_size * self.args.gradient_accumulation_steps, |
| dataset=self.train_dataset, |
| lengths=lengths, |
| model_input_name=model_input_name, |
| generator=generator, |
| ) |
| else: |
| return DistributedLengthGroupedSampler( |
| self.args.train_batch_size * self.args.gradient_accumulation_steps, |
| dataset=self.train_dataset, |
| num_replicas=self.args.world_size, |
| rank=self.args.process_index, |
| lengths=lengths, |
| model_input_name=model_input_name, |
| seed=seed, |
| ) |
|
|
| else: |
| if self.args.world_size <= 1: |
| return RandomSampler(self.train_dataset, generator=generator) |
| elif ( |
| self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] |
| and not self.args.dataloader_drop_last |
| ): |
| |
| return DistributedSamplerWithLoop( |
| self.train_dataset, |
| batch_size=self.args.per_device_train_batch_size, |
| num_replicas=self.args.world_size, |
| rank=self.args.process_index, |
| seed=seed, |
| ) |
| else: |
| return DistributedSampler( |
| self.train_dataset, |
| num_replicas=self.args.world_size, |
| rank=self.args.process_index, |
| seed=seed, |
| ) |
|
|
| def get_train_dataloader(self) -> DataLoader: |
| """ |
| Returns the training [`~torch.utils.data.DataLoader`]. |
| |
| Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed |
| training if necessary) otherwise. |
| |
| Subclass and override this method if you want to inject some custom behavior. |
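| |
| Example of a light override (a sketch; `MyTrainer` is a hypothetical subclass): |
| |
| ```python |
| class MyTrainer(Trainer): |
|     def get_train_dataloader(self): |
|         dataloader = super().get_train_dataloader() |
|         # Inspect or adjust the default dataloader here before training uses it. |
|         print(f"number of training batches per epoch: {len(dataloader)}") |
|         return dataloader |
| ``` |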
| """ |
| if self.train_dataset is None: |
| raise ValueError("Trainer: training requires a train_dataset.") |
|
|
| train_dataset = self.train_dataset |
| data_collator = self.data_collator |
| if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): |
| train_dataset = self._remove_unused_columns(train_dataset, description="training") |
| else: |
| data_collator = self._get_collator_with_removed_columns(data_collator, description="training") |
|
|
| if isinstance(train_dataset, torch.utils.data.IterableDataset): |
| if self.args.world_size > 1: |
| train_dataset = IterableDatasetShard( |
| train_dataset, |
| batch_size=self._train_batch_size, |
| drop_last=self.args.dataloader_drop_last, |
| num_processes=self.args.world_size, |
| process_index=self.args.process_index, |
| ) |
|
|
| return DataLoader( |
| train_dataset, |
| batch_size=self._train_batch_size, |
| collate_fn=data_collator, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| train_sampler = self._get_train_sampler() |
|
|
| return DataLoader( |
| train_dataset, |
| batch_size=self._train_batch_size, |
| sampler=train_sampler, |
| collate_fn=data_collator, |
| drop_last=self.args.dataloader_drop_last, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| worker_init_fn=seed_worker, |
| ) |
|
|
| def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: |
| |
| if self.args.use_legacy_prediction_loop: |
| if is_torch_tpu_available(): |
| return SequentialDistributedSampler( |
| eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() |
| ) |
| elif is_sagemaker_mp_enabled(): |
| return SequentialDistributedSampler( |
| eval_dataset, |
| num_replicas=smp.dp_size(), |
| rank=smp.dp_rank(), |
| batch_size=self.args.per_device_eval_batch_size, |
| ) |
| elif self.args.local_rank != -1: |
| return SequentialDistributedSampler(eval_dataset) |
| else: |
| return SequentialSampler(eval_dataset) |
|
|
| if self.args.world_size <= 1: |
| return SequentialSampler(eval_dataset) |
| else: |
| return ShardSampler( |
| eval_dataset, |
| batch_size=self.args.per_device_eval_batch_size, |
| num_processes=self.args.world_size, |
| process_index=self.args.process_index, |
| ) |
|
|
| def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: |
| """ |
| Returns the evaluation [`~torch.utils.data.DataLoader`]. |
| |
| Subclass and override this method if you want to inject some custom behavior. |
| |
| Args: |
| eval_dataset (`torch.utils.data.Dataset`, *optional*): |
| If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted |
| by the `model.forward()` method are automatically removed. It must implement `__len__`. |
| """ |
| if eval_dataset is None and self.eval_dataset is None: |
| raise ValueError("Trainer: evaluation requires an eval_dataset.") |
| eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset |
| data_collator = self.data_collator |
|
|
| if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): |
| eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") |
| else: |
| data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") |
|
|
| if isinstance(eval_dataset, torch.utils.data.IterableDataset): |
| if self.args.world_size > 1: |
| eval_dataset = IterableDatasetShard( |
| eval_dataset, |
| batch_size=self.args.per_device_eval_batch_size, |
| drop_last=self.args.dataloader_drop_last, |
| num_processes=self.args.world_size, |
| process_index=self.args.process_index, |
| ) |
| return DataLoader( |
| eval_dataset, |
| batch_size=self.args.eval_batch_size, |
| collate_fn=data_collator, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| eval_sampler = self._get_eval_sampler(eval_dataset) |
|
|
| return DataLoader( |
| eval_dataset, |
| sampler=eval_sampler, |
| batch_size=self.args.eval_batch_size, |
| collate_fn=data_collator, |
| drop_last=self.args.dataloader_drop_last, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: |
| """ |
| Returns the test [`~torch.utils.data.DataLoader`]. |
| |
| Subclass and override this method if you want to inject some custom behavior. |
| |
| Args: |
| test_dataset (`torch.utils.data.Dataset`): |
| The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the |
| `model.forward()` method are automatically removed. It must implement `__len__`. |
| """ |
| data_collator = self.data_collator |
|
|
| if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): |
| test_dataset = self._remove_unused_columns(test_dataset, description="test") |
| else: |
| data_collator = self._get_collator_with_removed_columns(data_collator, description="test") |
|
|
| if isinstance(test_dataset, torch.utils.data.IterableDataset): |
| if self.args.world_size > 1: |
| test_dataset = IterableDatasetShard( |
| test_dataset, |
| batch_size=self.args.eval_batch_size, |
| drop_last=self.args.dataloader_drop_last, |
| num_processes=self.args.world_size, |
| process_index=self.args.process_index, |
| ) |
| return DataLoader( |
| test_dataset, |
| batch_size=self.args.eval_batch_size, |
| collate_fn=data_collator, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| test_sampler = self._get_eval_sampler(test_dataset) |
|
|
| |
| return DataLoader( |
| test_dataset, |
| sampler=test_sampler, |
| batch_size=self.args.eval_batch_size, |
| collate_fn=data_collator, |
| drop_last=self.args.dataloader_drop_last, |
| num_workers=self.args.dataloader_num_workers, |
| pin_memory=self.args.dataloader_pin_memory, |
| ) |
|
|
| def create_optimizer_and_scheduler(self, num_training_steps: int): |
| """ |
| Setup the optimizer and the learning rate scheduler. |
| |
| We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the |
| Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or |
| `create_scheduler`). |
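| |
| Example of passing your own optimizer and scheduler instead (a sketch; `model`, `args` and |
| `train_dataset` are assumed to be defined already): |
| |
| ```python |
| from torch.optim import SGD |
| from torch.optim.lr_scheduler import LambdaLR |
| |
| optimizer = SGD(model.parameters(), lr=5e-5) |
| scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)  # constant learning rate |
| trainer = Trainer(model=model, args=args, train_dataset=train_dataset, optimizers=(optimizer, scheduler)) |
| ``` |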
| """ |
| self.create_optimizer() |
| if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: |
| |
| optimizer = self.optimizer.optimizer |
| else: |
| optimizer = self.optimizer |
| self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) |
|
|
| def create_optimizer(self): |
| """ |
| Setup the optimizer. |
| |
| We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the |
| Trainer's init through `optimizers`, or subclass and override this method. |
| """ |
| opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model |
|
|
| if self.optimizer is None: |
| decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) |
| decay_parameters = [name for name in decay_parameters if "bias" not in name] |
| optimizer_grouped_parameters = [ |
| { |
| "params": [ |
| p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) |
| ], |
| "weight_decay": self.args.weight_decay, |
| }, |
| { |
| "params": [ |
| p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) |
| ], |
| "weight_decay": 0.0, |
| }, |
| ] |
|
|
| optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) |
|
|
| if self.sharded_ddp == ShardedDDPOption.SIMPLE: |
| self.optimizer = OSS( |
| params=optimizer_grouped_parameters, |
| optim=optimizer_cls, |
| **optimizer_kwargs, |
| ) |
| else: |
| self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) |
| if optimizer_cls.__name__ == "Adam8bit": |
| import bitsandbytes |
|
|
| manager = bitsandbytes.optim.GlobalOptimManager.get_instance() |
|
|
| skipped = 0 |
| for module in opt_model.modules(): |
| if isinstance(module, nn.Embedding): |
| skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) |
| print(f"skipped {module}: {skipped/2**20}M params") |
| manager.register_module_override(module, "weight", {"optim_bits": 32}) |
| logger.debug(f"bitsandbytes: will optimize {module} in fp32") |
| print(f"skipped: {skipped/2**20}M params") |
|
|
| if is_sagemaker_mp_enabled(): |
| self.optimizer = smp.DistributedOptimizer(self.optimizer) |
|
|
| return self.optimizer |
|
|
| @staticmethod |
| def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: |
| """ |
| Returns the optimizer class and optimizer parameters based on the training arguments. |
| |
| Args: |
| args (`transformers.training_args.TrainingArguments`): |
| The training arguments for the training session. |
| |
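| A short usage sketch (assumes `training_args` and `model` are already defined): |
| |
| ```python |
| optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) |
| optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) |
| ``` |
| |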
| """ |
|
|
| |
| optim_args = {} |
| if args.optim_args: |
| for mapping in args.optim_args.replace(" ", "").split(","): |
| key, value = mapping.split("=") |
| optim_args[key] = value |
|
|
| optimizer_kwargs = {"lr": args.learning_rate} |
|
|
| adam_kwargs = { |
| "betas": (args.adam_beta1, args.adam_beta2), |
| "eps": args.adam_epsilon, |
| } |
| if args.optim == OptimizerNames.ADAFACTOR: |
| optimizer_cls = Adafactor |
| optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) |
| elif args.optim == OptimizerNames.ADAMW_HF: |
| from transformers.optimization import AdamW |
|
|
| optimizer_cls = AdamW |
| optimizer_kwargs.update(adam_kwargs) |
| elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: |
| from torch.optim import AdamW |
|
|
| optimizer_cls = AdamW |
| optimizer_kwargs.update(adam_kwargs) |
| if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: |
| optimizer_kwargs.update({"fused": True}) |
| elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: |
| try: |
| from torch_xla.amp.syncfree import AdamW |
|
|
| optimizer_cls = AdamW |
| optimizer_kwargs.update(adam_kwargs) |
| except ImportError: |
| raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") |
| elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: |
| try: |
| from apex.optimizers import FusedAdam |
|
|
| optimizer_cls = FusedAdam |
| optimizer_kwargs.update(adam_kwargs) |
| except ImportError: |
| raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") |
| elif args.optim == OptimizerNames.ADAMW_BNB: |
| try: |
| from bitsandbytes.optim import Adam8bit |
|
|
| optimizer_cls = Adam8bit |
| optimizer_kwargs.update(adam_kwargs) |
| except ImportError: |
| raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") |
| elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: |
| try: |
| from torchdistx.optimizers import AnyPrecisionAdamW |
|
|
| optimizer_cls = AnyPrecisionAdamW |
| optimizer_kwargs.update(adam_kwargs) |
|
|
| |
| optimizer_kwargs.update( |
| { |
| "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), |
| "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), |
| "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), |
| "compensation_buffer_dtype": getattr( |
| torch, optim_args.get("compensation_buffer_dtype", "bfloat16") |
| ), |
| } |
| ) |
| except ImportError: |
| raise ValueError("Please install https://github.com/pytorch/torchdistx") |
| elif args.optim == OptimizerNames.SGD: |
| optimizer_cls = torch.optim.SGD |
| elif args.optim == OptimizerNames.ADAGRAD: |
| optimizer_cls = torch.optim.Adagrad |
| else: |
| raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") |
| return optimizer_cls, optimizer_kwargs |
|
|
| def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): |
| """ |
| Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or |
| passed as an argument. |
| |
| Args: |
| num_training_steps (int): The number of training steps to do. |
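| |
| Example of overriding it to use a cosine schedule (a sketch; `MyTrainer` is a hypothetical subclass): |
| |
| ```python |
| from transformers import get_cosine_schedule_with_warmup |
| |
| class MyTrainer(Trainer): |
|     def create_scheduler(self, num_training_steps, optimizer=None): |
|         self.lr_scheduler = get_cosine_schedule_with_warmup( |
|             optimizer if optimizer is not None else self.optimizer, |
|             num_warmup_steps=self.args.get_warmup_steps(num_training_steps), |
|             num_training_steps=num_training_steps, |
|         ) |
|         return self.lr_scheduler |
| ``` |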
| """ |
| if self.lr_scheduler is None: |
| self.lr_scheduler = get_scheduler( |
| self.args.lr_scheduler_type, |
| optimizer=self.optimizer if optimizer is None else optimizer, |
| num_warmup_steps=self.args.get_warmup_steps(num_training_steps), |
| num_training_steps=num_training_steps, |
| ) |
| return self.lr_scheduler |
|
|
| def num_examples(self, dataloader: DataLoader) -> int: |
| """ |
| Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When |
| dataloader.dataset does not exist or has no length, estimates as best it can. |
| """ |
| try: |
| dataset = dataloader.dataset |
| |
| if isinstance(dataset, IterableDatasetShard): |
| return len(dataloader.dataset.dataset) |
| return len(dataloader.dataset) |
| except (NameError, AttributeError, TypeError): |
| return len(dataloader) * self.args.per_device_train_batch_size |
|
|
| def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): |
| """HP search setup code""" |
| self._trial = trial |
|
|
| if self.hp_search_backend is None or trial is None: |
| return |
| if self.hp_search_backend == HPSearchBackend.OPTUNA: |
| params = self.hp_space(trial) |
| elif self.hp_search_backend == HPSearchBackend.RAY: |
| params = trial |
| params.pop("wandb", None) |
| elif self.hp_search_backend == HPSearchBackend.SIGOPT: |
| params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} |
| elif self.hp_search_backend == HPSearchBackend.WANDB: |
| params = trial |
|
|
| for key, value in params.items(): |
| if not hasattr(self.args, key): |
| logger.warning( |
| f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" |
| " `TrainingArguments`." |
| ) |
| continue |
| old_attr = getattr(self.args, key, None) |
| |
| if old_attr is not None: |
| value = type(old_attr)(value) |
| setattr(self.args, key, value) |
| if self.hp_search_backend == HPSearchBackend.OPTUNA: |
| logger.info(f"Trial: {trial.params}") |
| if self.hp_search_backend == HPSearchBackend.SIGOPT: |
| logger.info(f"SigOpt Assignments: {trial.assignments}") |
| if self.hp_search_backend == HPSearchBackend.WANDB: |
| logger.info(f"W&B Sweep parameters: {trial}") |
| if self.args.deepspeed: |
| |
| from transformers.deepspeed import HfTrainerDeepSpeedConfig |
|
|
| self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) |
| self.args.hf_deepspeed_config.trainer_config_process(self.args) |
|
|
| def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): |
| if self.hp_search_backend is None or trial is None: |
| return |
| self.objective = self.compute_objective(metrics.copy()) |
| if self.hp_search_backend == HPSearchBackend.OPTUNA: |
| import optuna |
|
|
| trial.report(self.objective, step) |
| if trial.should_prune(): |
| self.callback_handler.on_train_end(self.args, self.state, self.control) |
| raise optuna.TrialPruned() |
| elif self.hp_search_backend == HPSearchBackend.RAY: |
| from ray import tune |
|
|
| if self.control.should_save: |
| self._tune_save_checkpoint() |
| tune.report(objective=self.objective, **metrics) |
|
|
| def _tune_save_checkpoint(self): |
| from ray import tune |
|
|
| if not self.use_tune_checkpoints: |
| return |
| with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: |
| output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") |
| self.save_model(output_dir, _internal_call=True) |
| if self.args.should_save: |
| self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) |
| torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) |
| torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) |
|
|
| def call_model_init(self, trial=None): |
| model_init_argcount = number_of_arguments(self.model_init) |
| if model_init_argcount == 0: |
| model = self.model_init() |
| elif model_init_argcount == 1: |
| model = self.model_init(trial) |
| else: |
| raise RuntimeError("model_init should have 0 or 1 argument.") |
|
|
| if model is None: |
| raise RuntimeError("model_init should not return None.") |
|
|
| return model |
|
|
| def torch_jit_model_eval(self, model, dataloader, training=False): |
| if not training: |
| if dataloader is None: |
| logger.warning("failed to use PyTorch jit mode due to current dataloader is none.") |
| return model |
| example_batch = next(iter(dataloader)) |
| example_batch = self._prepare_inputs(example_batch) |
| try: |
| jit_model = model.eval() |
| with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): |
| if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): |
| if isinstance(example_batch, dict): |
| jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) |
| else: |
| jit_model = torch.jit.trace( |
| jit_model, |
| example_kwarg_inputs={key: example_batch[key] for key in example_batch}, |
| strict=False, |
| ) |
| else: |
| jit_inputs = [] |
| for key in example_batch: |
| example_tensor = torch.ones_like(example_batch[key]) |
| jit_inputs.append(example_tensor) |
| jit_inputs = tuple(jit_inputs) |
| jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) |
| jit_model = torch.jit.freeze(jit_model) |
| with torch.no_grad(): |
| jit_model(**example_batch) |
| jit_model(**example_batch) |
| model = jit_model |
| self.use_cpu_amp = False |
| self.use_cuda_amp = False |
| except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: |
| logger.warning(f"failed to use PyTorch jit mode due to: {e}.") |
|
|
| return model |
|
|
| def ipex_optimize_model(self, model, training=False, dtype=torch.float32): |
| if not is_ipex_available(): |
| raise ImportError( |
| "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" |
| " to https://github.com/intel/intel-extension-for-pytorch." |
| ) |
|
|
| import intel_extension_for_pytorch as ipex |
|
|
| if not training: |
| model.eval() |
| dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype |
| |
| model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train) |
| else: |
| if not model.training: |
| model.train() |
| model, self.optimizer = ipex.optimize( |
| model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" |
| ) |
|
|
| return model |
|
|
| def _wrap_model(self, model, training=True, dataloader=None): |
| if self.args.torch_compile: |
| model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) |
|
|
| if self.args.use_ipex: |
| dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 |
| model = self.ipex_optimize_model(model, training, dtype=dtype) |
|
|
| if is_sagemaker_mp_enabled(): |
| |
| if isinstance(self.model_wrapped, smp.model.DistributedModel): |
| return self.model_wrapped |
| return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) |
|
|
| |
| if self.deepspeed: |
| return self.deepspeed |
|
|
| |
| if unwrap_model(model) is not model: |
| return model |
|
|
| |
| if self.use_apex and training: |
| model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) |
|
|
| |
| if self.args.n_gpu > 1: |
| model = nn.DataParallel(model) |
|
|
| if self.args.jit_mode_eval: |
| start_time = time.time() |
| model = self.torch_jit_model_eval(model, dataloader, training) |
| self.jit_compilation_time = round(time.time() - start_time, 4) |
|
|
| |
| |
| if not training: |
| return model |
|
|
| |
| if self.sharded_ddp is not None: |
| |
| if self.sharded_ddp == ShardedDDPOption.SIMPLE: |
| model = ShardedDDP(model, self.optimizer) |
| else: |
| mixed_precision = self.args.fp16 or self.args.bf16 |
| cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp |
| zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 |
| |
| if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: |
| model = auto_wrap(model) |
| self.model = model = FullyShardedDDP( |
| model, |
| mixed_precision=mixed_precision, |
| reshard_after_forward=zero_3, |
| cpu_offload=cpu_offload, |
| ).to(self.args.device) |
| |
| elif self.fsdp is not None: |
| if not self.args.fsdp_config["xla"]: |
| |
| from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision |
| from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy |
|
|
| if FSDPOption.OFFLOAD in self.args.fsdp: |
| cpu_offload = CPUOffload(offload_params=True) |
| else: |
| cpu_offload = CPUOffload(offload_params=False) |
|
|
| auto_wrap_policy = None |
|
|
| if FSDPOption.AUTO_WRAP in self.args.fsdp: |
| if self.args.fsdp_config["fsdp_min_num_params"] > 0: |
| auto_wrap_policy = functools.partial( |
| size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] |
| ) |
| elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: |
| transformer_cls_to_wrap = set() |
| for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: |
| transformer_cls = get_module_class_from_name(model, layer_class) |
| if transformer_cls is None: |
| raise Exception("Could not find the transformer layer class to wrap in the model.") |
| else: |
| transformer_cls_to_wrap.add(transformer_cls) |
| auto_wrap_policy = functools.partial( |
| transformer_auto_wrap_policy, |
| |
| transformer_layer_cls=transformer_cls_to_wrap, |
| ) |
| mixed_precision_policy = None |
| dtype = None |
| if self.args.fp16: |
| dtype = torch.float16 |
| elif self.args.bf16: |
| dtype = torch.bfloat16 |
| if dtype is not None: |
| mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) |
| if type(model) != FSDP: |
| |
| self.model = model = FSDP( |
| model, |
| sharding_strategy=self.fsdp, |
| cpu_offload=cpu_offload, |
| auto_wrap_policy=auto_wrap_policy, |
| mixed_precision=mixed_precision_policy, |
| device_id=self.args.device, |
| backward_prefetch=self.backward_prefetch, |
| forward_prefetch=self.forward_prefetch, |
| limit_all_gathers=self.limit_all_gathers, |
| ) |
| else: |
| try: |
| from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP |
| from torch_xla.distributed.fsdp import checkpoint_module |
| from torch_xla.distributed.fsdp.wrap import ( |
| size_based_auto_wrap_policy, |
| transformer_auto_wrap_policy, |
| ) |
| except ImportError: |
| raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") |
| auto_wrap_policy = None |
| auto_wrapper_callable = None |
| if self.args.fsdp_config["fsdp_min_num_params"] > 0: |
| auto_wrap_policy = functools.partial( |
| size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] |
| ) |
| elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: |
| transformer_cls_to_wrap = set() |
| for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: |
| transformer_cls = get_module_class_from_name(model, layer_class) |
| if transformer_cls is None: |
| raise Exception("Could not find the transformer layer class to wrap in the model.") |
| else: |
| transformer_cls_to_wrap.add(transformer_cls) |
| auto_wrap_policy = functools.partial( |
| transformer_auto_wrap_policy, |
| |
| transformer_layer_cls=transformer_cls_to_wrap, |
| ) |
| fsdp_kwargs = self.args.xla_fsdp_config |
| if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: |
| |
| def auto_wrapper_callable(m, *args, **kwargs): |
| return FSDP(checkpoint_module(m), *args, **kwargs) |
|
|
| |
| self.model = model = FSDP( |
| model, |
| auto_wrap_policy=auto_wrap_policy, |
| auto_wrapper_callable=auto_wrapper_callable, |
| **fsdp_kwargs, |
| ) |
|
|
| |
| |
| def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): |
| loss = optimizer.step(**optimizer_args) |
| if barrier: |
| xm.mark_step() |
| return loss |
|
|
| xm.optimizer_step = patched_optimizer_step |
| elif is_sagemaker_dp_enabled(): |
| model = nn.parallel.DistributedDataParallel( |
| model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] |
| ) |
| elif self.args.local_rank != -1: |
| kwargs = {} |
| if self.args.ddp_find_unused_parameters is not None: |
| kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters |
| elif isinstance(model, PreTrainedModel): |
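| # find_unused_parameters breaks gradient checkpointing, so only enable it when checkpointing is off. |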
| |
| |
| kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing |
| else: |
| kwargs["find_unused_parameters"] = True |
|
|
| if self.args.ddp_bucket_cap_mb is not None: |
| kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb |
| if is_torch_neuroncore_available(): |
| return model |
| model = nn.parallel.DistributedDataParallel( |
| model, |
| device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None, |
| output_device=self.args.local_rank if self.args._n_gpu != 0 else None, |
| **kwargs, |
| ) |
|
|
| return model |
|
|
| def train( |
| self, |
| resume_from_checkpoint: Optional[Union[str, bool]] = None, |
| trial: Union["optuna.Trial", Dict[str, Any]] = None, |
| ignore_keys_for_eval: Optional[List[str]] = None, |
| **kwargs, |
| ): |
| """ |
| Main training entry point. |
| |
| Args: |
| resume_from_checkpoint (`str` or `bool`, *optional*): |
| If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a |
| `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance |
| of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. |
| trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): |
| The trial run or the hyperparameter dictionary for hyperparameter search. |
| ignore_keys_for_eval (`List[str]`, *optional*): |
| A list of keys in the output of your model (if it is a dictionary) that should be ignored when |
| gathering predictions for evaluation during training. |
| kwargs: |
| Additional keyword arguments used to hide deprecated arguments |
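| |
| Example (a minimal sketch; `my_model`, `training_args` and `train_ds` stand in for objects you build |
| yourself and are not defined in this module): |
| |
| ```python |
| trainer = Trainer(model=my_model, args=training_args, train_dataset=train_ds) |
| trainer.train() |
| # Pass resume_from_checkpoint=True to pick up from the last checkpoint in args.output_dir. |
| ``` |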
| """ |
| if resume_from_checkpoint is False: |
| resume_from_checkpoint = None |
|
|
| |
| self._memory_tracker.start() |
|
|
| args = self.args |
|
|
| self.is_in_train = True |
|
|
| |
| |
| if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: |
| self._move_model_to_device(self.model, args.device) |
|
|
| if "model_path" in kwargs: |
| resume_from_checkpoint = kwargs.pop("model_path") |
| warnings.warn( |
| "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " |
| "instead.", |
| FutureWarning, |
| ) |
| if len(kwargs) > 0: |
| raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") |
| |
| self._hp_search_setup(trial) |
| self._train_batch_size = self.args.train_batch_size |
|
|
| |
| model_reloaded = False |
| if self.model_init is not None: |
| |
| enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) |
| self.model = self.call_model_init(trial) |
| model_reloaded = True |
| |
| self.optimizer, self.lr_scheduler = None, None |
|
|
| |
| if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: |
| resume_from_checkpoint = get_last_checkpoint(args.output_dir) |
| if resume_from_checkpoint is None: |
| raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") |
|
|
| if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None: |
| self._load_from_checkpoint(resume_from_checkpoint) |
|
|
| |
| if model_reloaded: |
| if self.place_model_on_device: |
| self._move_model_to_device(self.model, args.device) |
| self.model_wrapped = self.model |
|
|
| inner_training_loop = find_executable_batch_size( |
| self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size |
| ) |
| return inner_training_loop( |
| args=args, |
| resume_from_checkpoint=resume_from_checkpoint, |
| trial=trial, |
| ignore_keys_for_eval=ignore_keys_for_eval, |
| ) |
|
|
| def _inner_training_loop( |
| self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None |
| ): |
| self._train_batch_size = batch_size |
| |
| train_dataloader = self.get_train_dataloader() |
|
|
| |
| |
| |
| |
| total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size |
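| # Derive max_steps, num_train_epochs and num_update_steps_per_epoch, depending on whether the dataloader has a length. |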
|
|
| len_dataloader = None |
| if has_length(train_dataloader): |
| len_dataloader = len(train_dataloader) |
| num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps |
| num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) |
| num_examples = self.num_examples(train_dataloader) |
| if args.max_steps > 0: |
| max_steps = args.max_steps |
| num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( |
| args.max_steps % num_update_steps_per_epoch > 0 |
| ) |
| |
| |
| num_train_samples = args.max_steps * total_train_batch_size |
| else: |
| max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) |
| num_train_epochs = math.ceil(args.num_train_epochs) |
| num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs |
| elif args.max_steps > 0: |
| max_steps = args.max_steps |
| |
| num_train_epochs = sys.maxsize |
| num_update_steps_per_epoch = max_steps |
| num_examples = total_train_batch_size * args.max_steps |
| num_train_samples = args.max_steps * total_train_batch_size |
| else: |
| raise ValueError( |
| "args.max_steps must be set to a positive value if dataloader does not have a length, was" |
| f" {args.max_steps}" |
| ) |
|
|
| if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: |
| if self.args.n_gpu > 1: |
| |
| |
| raise ValueError( |
| "Currently --debug underflow_overflow is not supported under DP. Please use DDP" |
| " (torch.distributed.launch)." |
| ) |
| else: |
| debug_overflow = DebugUnderflowOverflow(self.model) |
|
|
| delay_optimizer_creation = ( |
| self.sharded_ddp is not None |
| and self.sharded_ddp != ShardedDDPOption.SIMPLE |
| or is_sagemaker_mp_enabled() |
| or self.fsdp is not None |
| ) |
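| # Sharded DDP (beyond SIMPLE), SageMaker MP and FSDP all require the optimizer to be created after the model is wrapped. |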
| if args.deepspeed: |
| deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( |
| self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint |
| ) |
| self.model = deepspeed_engine.module |
| self.model_wrapped = deepspeed_engine |
| self.deepspeed = deepspeed_engine |
| self.optimizer = optimizer |
| self.lr_scheduler = lr_scheduler |
| elif not delay_optimizer_creation: |
| self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
|
|
| self.state = TrainerState() |
| self.state.is_hyper_param_search = trial is not None |
|
|
| |
| if args.gradient_checkpointing: |
| self.model.gradient_checkpointing_enable() |
|
|
| model = self._wrap_model(self.model_wrapped) |
|
|
| if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: |
| self._load_from_checkpoint(resume_from_checkpoint, model) |
|
|
| |
| if model is not self.model: |
| self.model_wrapped = model |
|
|
| if delay_optimizer_creation: |
| self.create_optimizer_and_scheduler(num_training_steps=max_steps) |
|
|
| |
| self._load_optimizer_and_scheduler(resume_from_checkpoint) |
|
|
| |
| |
| |
|
|
| |
| logger.info("***** Running training *****") |
| logger.info(f" Num examples = {num_examples}") |
| logger.info(f" Num Epochs = {num_train_epochs}") |
| logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") |
| logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") |
| logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
| logger.info(f" Total optimization steps = {max_steps}") |
| logger.info( |
| f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}" |
| ) |
|
|
| self.state.epoch = 0 |
| start_time = time.time() |
| epochs_trained = 0 |
| steps_trained_in_current_epoch = 0 |
| steps_trained_progress_bar = None |
|
|
| |
| if resume_from_checkpoint is not None and os.path.isfile( |
| os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) |
| ): |
| self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) |
| epochs_trained = self.state.global_step // num_update_steps_per_epoch |
| if not args.ignore_data_skip: |
| steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) |
| steps_trained_in_current_epoch *= args.gradient_accumulation_steps |
| else: |
| steps_trained_in_current_epoch = 0 |
|
|
| logger.info(" Continuing training from checkpoint, will skip to saved global_step") |
| logger.info(f" Continuing training from epoch {epochs_trained}") |
| logger.info(f" Continuing training from global step {self.state.global_step}") |
| if not args.ignore_data_skip: |
| if skip_first_batches is None: |
| logger.info( |
| f" Will skip the first {epochs_trained} epochs then the first" |
| f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," |
| " you can install the latest version of Accelerate with `pip install -U accelerate`.You can" |
| " also add the `--ignore_data_skip` flag to your launch command, but you will resume the" |
| " training on data already seen by your model." |
| ) |
| else: |
| logger.info( |
| f" Will skip the first {epochs_trained} epochs then the first" |
| f" {steps_trained_in_current_epoch} batches in the first epoch." |
| ) |
| if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: |
| steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) |
| steps_trained_progress_bar.set_description("Skipping the first batches") |
|
|
| |
| self.callback_handler.model = self.model |
| self.callback_handler.optimizer = self.optimizer |
| self.callback_handler.lr_scheduler = self.lr_scheduler |
| self.callback_handler.train_dataloader = train_dataloader |
| if self.hp_name is not None and self._trial is not None: |
| |
| |
| self.state.trial_name = self.hp_name(self._trial) |
| if trial is not None: |
| assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial |
| self.state.trial_params = hp_params(assignments) |
| else: |
| self.state.trial_params = None |
| |
| |
| self.state.max_steps = max_steps |
| self.state.num_train_epochs = num_train_epochs |
| self.state.is_local_process_zero = self.is_local_process_zero() |
| self.state.is_world_process_zero = self.is_world_process_zero() |
|
|
| |
| tr_loss = torch.tensor(0.0).to(args.device) |
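| # Keep tr_loss as a tensor to avoid forcing a host-device synchronization via .item() on every step. |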
| |
| self._total_loss_scalar = 0.0 |
| self._globalstep_last_logged = self.state.global_step |
| model.zero_grad() |
|
|
| self.control = self.callback_handler.on_train_begin(args, self.state, self.control) |
|
|
| |
| if not args.ignore_data_skip: |
| for epoch in range(epochs_trained): |
| is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( |
| train_dataloader.sampler, RandomSampler |
| ) |
| if is_torch_less_than_1_11 or not is_random_sampler: |
| |
| |
| for _ in train_dataloader: |
| break |
| else: |
| |
| |
| _ = list(train_dataloader.sampler) |
|
|
| total_batched_samples = 0 |
| for epoch in range(epochs_trained, num_train_epochs): |
| if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): |
| train_dataloader.sampler.set_epoch(epoch) |
| elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): |
| train_dataloader.dataset.set_epoch(epoch) |
|
|
| if is_torch_tpu_available(): |
| parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) |
| epoch_iterator = parallel_loader |
| else: |
| epoch_iterator = train_dataloader |
|
|
| |
| if args.past_index >= 0: |
| self._past = None |
|
|
| steps_in_epoch = ( |
| len(epoch_iterator) |
| if len_dataloader is not None |
| else args.max_steps * args.gradient_accumulation_steps |
| ) |
| self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) |
|
|
| if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: |
| self._load_rng_state(resume_from_checkpoint) |
|
|
| rng_to_sync = False |
| steps_skipped = 0 |
| if skip_first_batches is not None and steps_trained_in_current_epoch > 0: |
| epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) |
| steps_skipped = steps_trained_in_current_epoch |
| steps_trained_in_current_epoch = 0 |
| rng_to_sync = True |
|
|
| step = -1 |
| for step, inputs in enumerate(epoch_iterator): |
| total_batched_samples += 1 |
| if rng_to_sync: |
| self._load_rng_state(resume_from_checkpoint) |
| rng_to_sync = False |
|
|
| |
| if steps_trained_in_current_epoch > 0: |
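| # Still skipping batches already seen before the checkpoint (slow path used when skip_first_batches is unavailable). |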
| steps_trained_in_current_epoch -= 1 |
| if steps_trained_progress_bar is not None: |
| steps_trained_progress_bar.update(1) |
| if steps_trained_in_current_epoch == 0: |
| self._load_rng_state(resume_from_checkpoint) |
| continue |
| elif steps_trained_progress_bar is not None: |
| steps_trained_progress_bar.close() |
| steps_trained_progress_bar = None |
|
|
| if step % args.gradient_accumulation_steps == 0: |
| self.control = self.callback_handler.on_step_begin(args, self.state, self.control) |
|
|
| if ( |
| (total_batched_samples % args.gradient_accumulation_steps != 0) |
| and args.local_rank != -1 |
| and args._no_sync_in_gradient_accumulation |
| ): |
| |
| with model.no_sync(): |
| tr_loss_step = self.training_step(model, inputs) |
| else: |
| tr_loss_step = self.training_step(model, inputs) |
|
|
| if ( |
| args.logging_nan_inf_filter |
| and not is_torch_tpu_available() |
| and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) |
| ): |
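| # Replace a NaN/Inf step loss with the running average of the losses seen since the last logging step. |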
| |
| tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) |
| else: |
| tr_loss += tr_loss_step |
|
|
| self.current_flos += float(self.floating_point_ops(inputs)) |
|
|
| |
| if self.deepspeed: |
| self.deepspeed.step() |
|
|
| if total_batched_samples % args.gradient_accumulation_steps == 0 or ( |
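| # ... or this is the final step of an epoch that is shorter than gradient_accumulation_steps. |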
| |
| steps_in_epoch <= args.gradient_accumulation_steps |
| and (step + 1) == steps_in_epoch |
| ): |
| |
| if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: |
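| # Gradient clipping; DeepSpeed is excluded because it clips gradients internally via its own config. |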
| |
|
|
| if self.do_grad_scaling: |
| |
| if is_torch_tpu_available(): |
| gradients = xm._fetch_gradients(self.optimizer) |
| xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) |
| |
| self.scaler.unscale_(self.optimizer) |
|
|
| if is_sagemaker_mp_enabled() and args.fp16: |
| self.optimizer.clip_master_grads(args.max_grad_norm) |
| elif hasattr(self.optimizer, "clip_grad_norm"): |
| |
| self.optimizer.clip_grad_norm(args.max_grad_norm) |
| elif hasattr(model, "clip_grad_norm_"): |
| |
| model.clip_grad_norm_(args.max_grad_norm) |
| else: |
| |
| nn.utils.clip_grad_norm_( |
| amp.master_params(self.optimizer) if self.use_apex else model.parameters(), |
| args.max_grad_norm, |
| ) |
|
|
| |
| optimizer_was_run = True |
| if self.deepspeed: |
| pass |
| elif is_torch_tpu_available(): |
| if self.do_grad_scaling: |
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| else: |
| xm.optimizer_step(self.optimizer) |
| elif self.do_grad_scaling: |
| scale_before = self.scaler.get_scale() |
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| scale_after = self.scaler.get_scale() |
| optimizer_was_run = scale_before <= scale_after |
| else: |
| self.optimizer.step() |
|
|
| if optimizer_was_run and not self.deepspeed: |
| self.lr_scheduler.step() |
|
|
| model.zero_grad() |
| self.state.global_step += 1 |
| self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch |
| self.control = self.callback_handler.on_step_end(args, self.state, self.control) |
|
|
| self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) |
| else: |
| self.control = self.callback_handler.on_substep_end(args, self.state, self.control) |
|
|
| if self.control.should_epoch_stop or self.control.should_training_stop: |
| break |
| if step < 0: |
| logger.warning( |
| "There seems to be not a single sample in your epoch_iterator, stopping training at step" |
| f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" |
| f" num_steps ({max_steps}) higher than the number of available samples." |
| ) |
| self.control.should_training_stop = True |
|
|
| self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) |
| self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) |
|
|
| if DebugOption.TPU_METRICS_DEBUG in self.args.debug: |
| if is_torch_tpu_available(): |
| |
| xm.master_print(met.metrics_report()) |
| else: |
| logger.warning( |
| "You enabled PyTorch/XLA debug metrics but you don't have a TPU " |
| "configured. Check your training configuration if this is unexpected." |
| ) |
| if self.control.should_training_stop: |
| break |
|
|
| if args.past_index and hasattr(self, "_past"): |
| |
| delattr(self, "_past") |
|
|
| logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") |
| if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: |
| |
| if is_torch_tpu_available(): |
| xm.rendezvous("load_best_model_at_end") |
| elif args.local_rank != -1: |
| dist.barrier() |
| elif is_sagemaker_mp_enabled(): |
| smp.barrier() |
|
|
| self._load_best_model() |
|
|
| |
| self._total_loss_scalar += tr_loss.item() |
| train_loss = self._total_loss_scalar / self.state.global_step |
|
|
| metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) |
| self.store_flos() |
| metrics["total_flos"] = self.state.total_flos |
| metrics["train_loss"] = train_loss |
|
|
| self.is_in_train = False |
|
|
| self._memory_tracker.stop_and_update_metrics(metrics) |
|
|
| self.log(metrics) |
|
|
| run_dir = self._get_output_dir(trial) |
| checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) |
|
|
| |
| if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: |
| for checkpoint in checkpoints_sorted: |
| if checkpoint != self.state.best_model_checkpoint: |
| logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") |
| shutil.rmtree(checkpoint) |
|
|
| self.control = self.callback_handler.on_train_end(args, self.state, self.control) |
|
|
| return TrainOutput(self.state.global_step, train_loss, metrics) |
|
|
| def _get_output_dir(self, trial): |
| if self.hp_search_backend is not None and trial is not None: |
| if self.hp_search_backend == HPSearchBackend.OPTUNA: |
| run_id = trial.number |
| elif self.hp_search_backend == HPSearchBackend.RAY: |
| from ray import tune |
|
|
| run_id = tune.get_trial_id() |
| elif self.hp_search_backend == HPSearchBackend.SIGOPT: |
| run_id = trial.id |
| elif self.hp_search_backend == HPSearchBackend.WANDB: |
| import wandb |
|
|
| run_id = wandb.run.id |
| run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" |
| run_dir = os.path.join(self.args.output_dir, run_name) |
| else: |
| run_dir = self.args.output_dir |
| return run_dir |
|
|
| def _load_from_checkpoint(self, resume_from_checkpoint, model=None): |
| if model is None: |
| model = self.model |
|
|
| if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile( |
| os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) |
| ): |
| raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") |
|
|
| logger.info(f"Loading model from {resume_from_checkpoint}.") |
|
|
| if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): |
| config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) |
| checkpoint_version = config.transformers_version |
| if checkpoint_version is not None and checkpoint_version != __version__: |
| logger.warning( |
| f"You are resuming training from a checkpoint trained with {checkpoint_version} of " |
| f"Transformers but your current version is {__version__}. This is not recommended and could " |
| "yield to errors or unwanted behaviors." |
| ) |
|
|
| if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): |
| |
| if is_sagemaker_mp_enabled(): |
| if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): |
| |
| |
| smp.resume_from_checkpoint( |
| path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False |
| ) |
| else: |
| |
| |
| if hasattr(self.args, "fp16") and self.args.fp16 is True: |
| logger.warning( |
| "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." |
| ) |
| state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") |
| |
| state_dict["_smp_is_partial"] = False |
| load_result = model.load_state_dict(state_dict, strict=True) |
| |
| del state_dict |
| else: |
| |
| state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") |
| |
| |
| load_result = model.load_state_dict(state_dict, False) |
| |
| del state_dict |
| self._issue_warnings_after_load(load_result) |
| else: |
| |
| load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled()) |
| if not is_sagemaker_mp_enabled(): |
| self._issue_warnings_after_load(load_result) |
|
|
| def _load_best_model(self): |
| logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") |
| best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) |
| model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model |
| if os.path.exists(best_model_path): |
| if self.deepspeed: |
| if self.model_wrapped is not None: |
| |
| self.model_wrapped.destroy() |
| self.model_wrapped = None |
|
|
| |
| deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( |
| self, |
| num_training_steps=self.args.max_steps, |
| resume_from_checkpoint=self.state.best_model_checkpoint, |
| ) |
| self.model = deepspeed_engine.module |
| self.model_wrapped = deepspeed_engine |
| self.deepspeed = deepspeed_engine |
| self.optimizer = optimizer |
| self.lr_scheduler = lr_scheduler |
| else: |
| if is_sagemaker_mp_enabled(): |
| if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): |
| |
| |
| smp.resume_from_checkpoint( |
| path=self.state.best_model_checkpoint, |
| tag=WEIGHTS_NAME, |
| partial=False, |
| load_optimizer=False, |
| ) |
| else: |
| |
| |
| state_dict = torch.load(best_model_path, map_location="cpu") |
| state_dict["_smp_is_partial"] = False |
| load_result = model.load_state_dict(state_dict, strict=True) |
| else: |
| |
| state_dict = torch.load(best_model_path, map_location="cpu") |
| |
| |
| |
| load_result = model.load_state_dict(state_dict, False) |
| if not is_sagemaker_mp_enabled(): |
| self._issue_warnings_after_load(load_result) |
| elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): |
| load_result = load_sharded_checkpoint( |
| model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() |
| ) |
| if not is_sagemaker_mp_enabled(): |
| self._issue_warnings_after_load(load_result) |
| else: |
| logger.warning( |
| f"Could not locate the best model at {best_model_path}, if you are running a distributed training " |
| "on multiple nodes, you should activate `--save_on_each_node`." |
| ) |
|
|
| def _issue_warnings_after_load(self, load_result): |
| if len(load_result.missing_keys) != 0: |
| if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( |
| self.model._keys_to_ignore_on_save |
| ): |
| self.model.tie_weights() |
| else: |
| logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") |
| if len(load_result.unexpected_keys) != 0: |
| logger.warning( |
| f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." |
| ) |
|
|
| def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): |
| if self.control.should_log: |
| if is_torch_tpu_available(): |
| xm.mark_step() |
|
|
| logs: Dict[str, float] = {} |
|
|
| |
| tr_loss_scalar = self._nested_gather(tr_loss).mean().item() |
|
|
| |
| tr_loss -= tr_loss |
|
|
| logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) |
| logs["learning_rate"] = self._get_learning_rate() |
|
|
| self._total_loss_scalar += tr_loss_scalar |
| self._globalstep_last_logged = self.state.global_step |
| self.store_flos() |
|
|
| self.log(logs) |
|
|
| metrics = None |
| if self.control.should_evaluate: |
| if isinstance(self.eval_dataset, dict): |
| for eval_dataset_name, eval_dataset in self.eval_dataset.items(): |
| metrics = self.evaluate( |
| eval_dataset=eval_dataset, |
| ignore_keys=ignore_keys_for_eval, |
| metric_key_prefix=f"eval_{eval_dataset_name}", |
| ) |
| else: |
| metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) |
| self._report_to_hp_search(trial, self.state.global_step, metrics) |
|
|
| if self.control.should_save: |
| self._save_checkpoint(model, trial, metrics=metrics) |
| self.control = self.callback_handler.on_save(self.args, self.state, self.control) |
|
|
| def _load_rng_state(self, checkpoint): |
| |
| if checkpoint is None: |
| return |
|
|
| if self.args.world_size > 1: |
| process_index = self.args.process_index |
| rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") |
| if not os.path.isfile(rng_file): |
| logger.info( |
| f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " |
| "wasn't launched in a distributed fashion, reproducibility is not guaranteed." |
| ) |
| return |
| else: |
| rng_file = os.path.join(checkpoint, "rng_state.pth") |
| if not os.path.isfile(rng_file): |
| logger.info( |
| "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " |
| "fashion, reproducibility is not guaranteed." |
| ) |
| return |
|
|
| checkpoint_rng_state = torch.load(rng_file) |
| random.setstate(checkpoint_rng_state["python"]) |
| np.random.set_state(checkpoint_rng_state["numpy"]) |
| torch.random.set_rng_state(checkpoint_rng_state["cpu"]) |
| if torch.cuda.is_available(): |
| if self.args.local_rank != -1: |
| torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) |
| else: |
| try: |
| torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) |
| except Exception as e: |
| logger.info( |
| f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" |
| "\nThis won't yield the same results as if the training had not been interrupted." |
| ) |
| if is_torch_tpu_available(): |
| xm.set_rng_state(checkpoint_rng_state["xla"]) |
|
|
| def _save_checkpoint(self, model, trial, metrics=None): |
| |
| |
| |
|
|
| |
| checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" |
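| # Checkpoints are written to <run_dir>/checkpoint-<global_step>. |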
|
|
| if self.hp_search_backend is None and trial is None: |
| self.store_flos() |
|
|
| run_dir = self._get_output_dir(trial=trial) |
| output_dir = os.path.join(run_dir, checkpoint_folder) |
| self.save_model(output_dir, _internal_call=True) |
| if self.deepspeed: |
| |
| |
| self.deepspeed.save_checkpoint(output_dir) |
|
|
| |
| if self.sharded_ddp == ShardedDDPOption.SIMPLE: |
| self.optimizer.consolidate_state_dict() |
|
|
| if is_torch_tpu_available(): |
| xm.rendezvous("saving_optimizer_states") |
| xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) |
| with warnings.catch_warnings(record=True) as caught_warnings: |
| xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) |
| reissue_pt_warnings(caught_warnings) |
| elif is_sagemaker_mp_enabled(): |
| opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) |
| smp.barrier() |
| if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: |
| smp.save( |
| opt_state_dict, |
| os.path.join(output_dir, OPTIMIZER_NAME), |
| partial=True, |
| v3=smp.state.cfg.shard_optimizer_state, |
| ) |
| if self.args.should_save: |
| with warnings.catch_warnings(record=True) as caught_warnings: |
| torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) |
| reissue_pt_warnings(caught_warnings) |
| if self.do_grad_scaling: |
| torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) |
| elif self.args.should_save and not self.deepspeed: |
| |
| torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) |
| with warnings.catch_warnings(record=True) as caught_warnings: |
| torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) |
| reissue_pt_warnings(caught_warnings) |
| if self.do_grad_scaling: |
| torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) |
|
|
| |
| if metrics is not None and self.args.metric_for_best_model is not None: |
| metric_to_check = self.args.metric_for_best_model |
| if not metric_to_check.startswith("eval_"): |
| metric_to_check = f"eval_{metric_to_check}" |
| metric_value = metrics[metric_to_check] |
|
|
| operator = np.greater if self.args.greater_is_better else np.less |
| if ( |
| self.state.best_metric is None |
| or self.state.best_model_checkpoint is None |
| or operator(metric_value, self.state.best_metric) |
| ): |
| self.state.best_metric = metric_value |
| self.state.best_model_checkpoint = output_dir |
|
|
| |
| if self.args.should_save: |
| self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) |
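| # Also save the RNG states (python, numpy, torch CPU/CUDA and XLA) so data order and dropout are reproducible on resume. |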
|
|
| |
| rng_states = { |
| "python": random.getstate(), |
| "numpy": np.random.get_state(), |
| "cpu": torch.random.get_rng_state(), |
| } |
| if torch.cuda.is_available(): |
| if self.args.local_rank == -1: |
| |
| rng_states["cuda"] = torch.cuda.random.get_rng_state_all() |
| else: |
| rng_states["cuda"] = torch.cuda.random.get_rng_state() |
|
|
| if is_torch_tpu_available(): |
| rng_states["xla"] = xm.get_rng_state() |
|
|
| |
| |
| os.makedirs(output_dir, exist_ok=True) |
|
|
| if self.args.world_size <= 1: |
| torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) |
| else: |
| torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) |
|
|
| if self.args.push_to_hub: |
| self._push_from_checkpoint(output_dir) |
|
|
| |
| if self.args.should_save: |
| self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) |
|
|
| def _load_optimizer_and_scheduler(self, checkpoint): |
| """If optimizer and scheduler states exist, load them.""" |
| if checkpoint is None: |
| return |
|
|
| if self.deepspeed: |
| |
| return |
|
|
| checkpoint_file_exists = ( |
| glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") |
| if is_sagemaker_mp_enabled() |
| else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) |
| ) |
| if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): |
| |
| if is_torch_tpu_available(): |
| |
| optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") |
| with warnings.catch_warnings(record=True) as caught_warnings: |
| lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") |
| reissue_pt_warnings(caught_warnings) |
|
|
| xm.send_cpu_data_to_device(optimizer_state, self.args.device) |
| xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) |
|
|
| self.optimizer.load_state_dict(optimizer_state) |
| self.lr_scheduler.load_state_dict(lr_scheduler_state) |
| else: |
| map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device |
| if is_sagemaker_mp_enabled(): |
| if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): |
| |
| def opt_load_hook(mod, opt): |
| opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) |
|
|
| else: |
| |
| def opt_load_hook(mod, opt): |
| if IS_SAGEMAKER_MP_POST_1_10: |
| opt.load_state_dict( |
| smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) |
| ) |
| else: |
| opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) |
|
|
| self.model_wrapped.register_post_step_hook(opt_load_hook) |
| else: |
| self.optimizer.load_state_dict( |
| torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) |
| ) |
| with warnings.catch_warnings(record=True) as caught_warnings: |
| self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) |
| reissue_pt_warnings(caught_warnings) |
| if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): |
| self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) |
|
|
| def hyperparameter_search( |
| self, |
| hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, |
| compute_objective: Optional[Callable[[Dict[str, float]], float]] = None, |
| n_trials: int = 20, |
| direction: str = "minimize", |
| backend: Optional[Union["str", HPSearchBackend]] = None, |
| hp_name: Optional[Callable[["optuna.Trial"], str]] = None, |
| **kwargs, |
| ) -> BestRun: |
| """ |
| Launch a hyperparameter search using `optuna`, `Ray Tune`, `SigOpt` or `wandb`. The optimized quantity is |
| determined by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is |
| provided, and the sum of all metrics otherwise. |
| |
| <Tip warning={true}> |
| |
| To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to |
| reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to |
| subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom |
| optimizer/scheduler. |
| |
| </Tip> |
| |
| Args: |
| hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*): |
| A function that defines the hyperparameter search space. Will default to |
| [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or |
| [`~trainer_utils.default_hp_space_sigopt`] depending on your backend. |
| compute_objective (`Callable[[Dict[str, float]], float]`, *optional*): |
| A function computing the objective to minimize or maximize from the metrics returned by the `evaluate` |
| method. Will default to [`~trainer_utils.default_compute_objective`]. |
| n_trials (`int`, *optional*, defaults to 20): |
| The number of trial runs to test. |
| direction (`str`, *optional*, defaults to `"minimize"`): |
| Whether to minimize or maximize the objective. Can be `"minimize"` or `"maximize"`; pick |
| `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics. |
| backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): |
| The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending |
| on which one is installed. If all are installed, will default to optuna. |
| hp_name (`Callable[["optuna.Trial"], str]`, *optional*): |
| A function that defines the trial/run name. Will default to None. |
| kwargs (`Dict[str, Any]`, *optional*): |
| Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more |
| information see: |
| |
| - the documentation of |
| [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) |
| - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run) |
| - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create) |
| |
| Returns: |
| [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in |
| `run_summary` attribute for Ray backend. |
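| |
| Example (a sketch assuming the optuna backend; `trainer` must have been created with a `model_init` |
| function, and the search space below is purely illustrative): |
| |
| ```python |
| def my_hp_space(trial): |
|     return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)} |
| |
| best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize") |
| print(best_run.hyperparameters) |
| ``` |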
| """ |
| if backend is None: |
| backend = default_hp_search_backend() |
| if backend is None: |
| raise RuntimeError( |
| "At least one of optuna or ray should be installed. " |
| "To install optuna run `pip install optuna`. " |
| "To install ray run `pip install ray[tune]`. " |
| "To install sigopt run `pip install sigopt`." |
| ) |
| backend = HPSearchBackend(backend) |
| if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): |
| raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") |
| if backend == HPSearchBackend.RAY and not is_ray_tune_available(): |
| raise RuntimeError( |
| "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." |
| ) |
| if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): |
| raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") |
| if backend == HPSearchBackend.WANDB and not is_wandb_available(): |
| raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") |
| self.hp_search_backend = backend |
| if self.model_init is None: |
| raise RuntimeError( |
| "To use hyperparameter search, you need to pass your model through a model_init function." |
| ) |
|
|
| self.hp_space = default_hp_space[backend] if hp_space is None else hp_space |
| self.hp_name = hp_name |
| self.compute_objective = default_compute_objective if compute_objective is None else compute_objective |
|
|
| backend_dict = { |
| HPSearchBackend.OPTUNA: run_hp_search_optuna, |
| HPSearchBackend.RAY: run_hp_search_ray, |
| HPSearchBackend.SIGOPT: run_hp_search_sigopt, |
| HPSearchBackend.WANDB: run_hp_search_wandb, |
| } |
| best_run = backend_dict[backend](self, n_trials, direction, **kwargs) |
|
|
| self.hp_search_backend = None |
| return best_run |
|
|
| def log(self, logs: Dict[str, float]) -> None: |
| """ |
| Log `logs` on the various objects watching training. |
| |
| Subclass and override this method to inject custom behavior. |
| |
| Args: |
| logs (`Dict[str, float]`): |
| The values to log. |
| """ |
| if self.state.epoch is not None: |
| logs["epoch"] = round(self.state.epoch, 2) |
|
|
| output = {**logs, **{"step": self.state.global_step}} |
| self.state.log_history.append(output) |
| self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) |
|
|
| def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: |
| """ |
| Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. |
| """ |
| if isinstance(data, Mapping): |
| return type(data)({k: self._prepare_input(v) for k, v in data.items()}) |
| elif isinstance(data, (tuple, list)): |
| return type(data)(self._prepare_input(v) for v in data) |
| elif isinstance(data, torch.Tensor): |
| kwargs = {"device": self.args.device} |
| if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)): |
| |
| |
| |
| kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()}) |
| return data.to(**kwargs) |
| return data |
|
|
| def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: |
| """ |
| Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and |
| handling potential state. |
| """ |
| inputs = self._prepare_input(inputs) |
| if len(inputs) == 0: |
| raise ValueError( |
| "The batch received was empty, your model won't be able to train on it. Double-check that your " |
| f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." |
| ) |
| if self.args.past_index >= 0 and self._past is not None: |
| inputs["mems"] = self._past |
|
|
| return inputs |
|
|
| def compute_loss_context_manager(self): |
| """ |
| A helper wrapper to group together context managers. |
| """ |
| return self.autocast_smart_context_manager() |
|
|
| def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): |
| """ |
| A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired |
| arguments, depending on the situation. |
| """ |
| if self.use_cuda_amp or self.use_cpu_amp: |
| if is_torch_greater_or_equal_than_1_10: |
| ctx_manager = ( |
| torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) |
| if self.use_cpu_amp |
| else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) |
| ) |
| else: |
| ctx_manager = torch.cuda.amp.autocast() |
| else: |
| ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() |
|
|
| return ctx_manager |
|
|
| def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: |
| """ |
| Perform a training step on a batch of inputs. |
| |
| Subclass and override to inject custom behavior. |
| |
| Args: |
| model (`nn.Module`): |
| The model to train. |
| inputs (`Dict[str, Union[torch.Tensor, Any]]`): |
| The inputs and targets of the model. |
| |
| The dictionary will be unpacked before being fed to the model. Most models expect the targets under the |
| argument `labels`. Check your model's documentation for all accepted arguments. |
| |
| Return: |
| `torch.Tensor`: The tensor with training loss on this batch. |
| """ |
| model.train() |
| inputs = self._prepare_inputs(inputs) |
|
|
| if is_sagemaker_mp_enabled(): |
| loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) |
| return loss_mb.reduce_mean().detach().to(self.args.device) |
|
|
| with self.compute_loss_context_manager(): |
| loss = self.compute_loss(model, inputs) |
|
|
| if self.args.n_gpu > 1: |
| loss = loss.mean() |
|
|
| if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: |
| |
| loss = loss / self.args.gradient_accumulation_steps |
|
|
| if self.do_grad_scaling: |
| self.scaler.scale(loss).backward() |
| elif self.use_apex: |
| with amp.scale_loss(loss, self.optimizer) as scaled_loss: |
| scaled_loss.backward() |
| elif self.deepspeed: |
| |
| loss = self.deepspeed.backward(loss) |
| else: |
| loss.backward() |
|
|
| return loss.detach() |
|
|
| def compute_loss(self, model, inputs, return_outputs=False): |
| """ |
| How the loss is computed by Trainer. By default, all models return the loss in the first element. |
| |
| Subclass and override for custom behavior. |
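| |
| Example of an override (a minimal sketch; `MyTrainer` and the plain cross-entropy loss are illustrative |
| and assume the model outputs `logits` and the batch carries integer `labels`): |
| |
| ```python |
| import torch.nn.functional as F |
| |
| class MyTrainer(Trainer): |
|     def compute_loss(self, model, inputs, return_outputs=False): |
|         labels = inputs.pop("labels") |
|         outputs = model(**inputs) |
|         # Flatten (batch, ..., num_classes) logits and the matching labels for a standard cross-entropy. |
|         logits = outputs.logits |
|         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1)) |
|         return (loss, outputs) if return_outputs else loss |
| ``` |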
| """ |
| if self.label_smoother is not None and "labels" in inputs: |
| labels = inputs.pop("labels") |
| else: |
| labels = None |
| outputs = model(**inputs) |
| |
| |
| if self.args.past_index >= 0: |
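| # Keep the model's "past" state (e.g. Transformer-XL mems) so _prepare_inputs can feed it back on the next step. |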
| self._past = outputs[self.args.past_index] |
|
|
| if labels is not None: |
| if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): |
| loss = self.label_smoother(outputs, labels, shift_labels=True) |
| else: |
| loss = self.label_smoother(outputs, labels) |
| else: |
| if isinstance(outputs, dict) and "loss" not in outputs: |
| raise ValueError( |
| "The model did not return a loss from the inputs, only the following keys: " |
| f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." |
| ) |
| |
| loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] |
|
|
| return (loss, outputs) if return_outputs else loss |
|
|
| def is_local_process_zero(self) -> bool: |
| """ |
| Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several |
| machines) main process. |
| """ |
| return self.args.local_process_index == 0 |
|
|
| def is_world_process_zero(self) -> bool: |
| """ |
| Whether or not this process is the global main process (when training in a distributed fashion on several |
| machines, this is only going to be `True` for one process). |
| """ |
| |
| |
| if is_sagemaker_mp_enabled(): |
| return smp.rank() == 0 |
| else: |
| return self.args.process_index == 0 |
|
|
| def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): |
| """ |
| Will save the model, so you can reload it using `from_pretrained()`. |
| |
| Will only save from the main process. |
| """ |
|
|
| if output_dir is None: |
| output_dir = self.args.output_dir |
|
|
| if is_torch_tpu_available(): |
| self._save_tpu(output_dir) |
| elif is_sagemaker_mp_enabled(): |
| |
| os.makedirs(output_dir, exist_ok=True) |
| state_dict = self.model_wrapped.state_dict() |
| if self.args.should_save: |
| self._save(output_dir, state_dict=state_dict) |
| if IS_SAGEMAKER_MP_POST_1_10: |
| |
| Path(os.path.join(output_dir, "user_content.pt")).touch() |
| elif ( |
| ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp |
| or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp |
| or self.fsdp is not None |
| ): |
| state_dict = self.model.state_dict() |
|
|
| if self.args.should_save: |
| self._save(output_dir, state_dict=state_dict) |
| elif self.deepspeed: |
| |
| if self.args.should_save: |
| self._save(output_dir) |
|
|
| if is_deepspeed_zero3_enabled(): |
| |
| |
| |
| |
| if self.args.should_save: |
| file = os.path.join(output_dir, WEIGHTS_NAME) |
| if os.path.isfile(file): |
| |
| os.remove(file) |
|
|
| |
| |
| |
| if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME): |
| logger.warning( |
| "deepspeed.save_16bit_model didn't save the model, since" |
| " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" |
| " zero_to_fp32.py to recover weights" |
| ) |
| self.deepspeed.save_checkpoint(output_dir) |
|
|
| elif self.args.should_save: |
| self._save(output_dir) |
|
|
| |
| if self.args.push_to_hub and not _internal_call: |
| self.push_to_hub(commit_message="Model save") |
|
|
| def _save_tpu(self, output_dir: Optional[str] = None): |
| output_dir = output_dir if output_dir is not None else self.args.output_dir |
| logger.info(f"Saving model checkpoint to {output_dir}") |
|
|
| if xm.is_master_ordinal(): |
| os.makedirs(output_dir, exist_ok=True) |
| torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
|
|
| |
| |
| xm.rendezvous("saving_checkpoint") |
| if not isinstance(self.model, PreTrainedModel): |
| if isinstance(unwrap_model(self.model), PreTrainedModel): |
| unwrap_model(self.model).save_pretrained( |
| output_dir, |
| is_main_process=self.args.should_save, |
| state_dict=self.model.state_dict(), |
| save_function=xm.save, |
| ) |
| else: |
| logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") |
| state_dict = self.model.state_dict() |
| xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) |
| else: |
| self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) |
| if self.tokenizer is not None and self.args.should_save: |
| self.tokenizer.save_pretrained(output_dir) |
|
|
| def _save(self, output_dir: Optional[str] = None, state_dict=None): |
| |
| output_dir = output_dir if output_dir is not None else self.args.output_dir |
| os.makedirs(output_dir, exist_ok=True) |
| logger.info(f"Saving model checkpoint to {output_dir}") |
| |
| |
| if not isinstance(self.model, PreTrainedModel): |
| if isinstance(unwrap_model(self.model), PreTrainedModel): |
| if state_dict is None: |
| state_dict = self.model.state_dict() |
| unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict) |
| else: |
| logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") |
| if state_dict is None: |
| state_dict = self.model.state_dict() |
| torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) |
| else: |
| if self.save_prefixencoder: |
| print("Saving PrefixEncoder") |
| state_dict = self.model.state_dict() |
| filtered_state_dict = {} |
| for k, v in self.model.named_parameters(): |
| if v.requires_grad: |
| filtered_state_dict[k] = state_dict[k] |
| self.model.save_pretrained(output_dir, state_dict=filtered_state_dict) |
| else: |
| print("Saving the whole model") |
| self.model.save_pretrained(output_dir, state_dict=state_dict) |
| if self.tokenizer is not None: |
| self.tokenizer.save_pretrained(output_dir) |
|
|
| |
| torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
|
|
| def store_flos(self): |
| |
| if self.args.local_rank != -1: |
| self.state.total_flos += ( |
| distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() |
| ) |
| self.current_flos = 0 |
| else: |
| self.state.total_flos += self.current_flos |
| self.current_flos = 0 |
|
|
| def _sorted_checkpoints( |
| self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False |
| ) -> List[str]: |
| ordering_and_checkpoint_path = [] |
|
|
| glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)] |
|
|
| for path in glob_checkpoints: |
| if use_mtime: |
| ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) |
| else: |
| regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) |
| if regex_match is not None and regex_match.groups() is not None: |
| ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) |
|
|
| checkpoints_sorted = sorted(ordering_and_checkpoint_path) |
| checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] |
| |
| if self.state.best_model_checkpoint is not None: |
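| # Move the best checkpoint towards the end of the list so checkpoint rotation does not delete it. |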
| best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) |
| for i in range(best_model_index, len(checkpoints_sorted) - 2): |
| checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] |
| return checkpoints_sorted |
|
|
| def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: |
| if self.args.save_total_limit is None or self.args.save_total_limit <= 0: |
| return |
|
|
| |
| checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir) |
| if len(checkpoints_sorted) <= self.args.save_total_limit: |
| return |
|
|
| |
| |
| save_total_limit = self.args.save_total_limit |
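| # If the best checkpoint is not the most recent one, keep two checkpoints so the latest one (needed to resume) survives. |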
| if ( |
| self.state.best_model_checkpoint is not None |
| and self.args.save_total_limit == 1 |
| and checkpoints_sorted[-1] != self.state.best_model_checkpoint |
| ): |
| save_total_limit = 2 |
|
|
| number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) |
| checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] |
| for checkpoint in checkpoints_to_be_deleted: |
| logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") |
| shutil.rmtree(checkpoint, ignore_errors=True) |
|
|
| def evaluate( |
| self, |
| eval_dataset: Optional[Dataset] = None, |
| ignore_keys: Optional[List[str]] = None, |
| metric_key_prefix: str = "eval", |
| ) -> Dict[str, float]: |
| """ |
| Run evaluation and returns metrics. |
| |
| The calling script will be responsible for providing a method to compute metrics, as they are task-dependent |
| (pass it to the init `compute_metrics` argument). |
| |
| You can also subclass and override this method to inject custom behavior. |
| |
| Args: |
| eval_dataset (`Dataset`, *optional*): |
| Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns |
| not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` |
| method. |
| ignore_keys (`List[str]`, *optional*): |
| A list of keys in the output of your model (if it is a dictionary) that should be ignored when |
| gathering predictions. |
| metric_key_prefix (`str`, *optional*, defaults to `"eval"`): |
| An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named |
| "eval_bleu" if the prefix is "eval" (default) |
| |
| Returns: |
| A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The |
| dictionary also contains the epoch number which comes from the training state. |
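| |
| Example (a sketch; assumes the `trainer` was built with an `eval_dataset` and, optionally, a |
| `compute_metrics` function): |
| |
| ```python |
| metrics = trainer.evaluate() |
| print(metrics.get("eval_loss"))  # reported whenever the model returns a loss on the eval set |
| ``` |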
| """ |
| |
| self._memory_tracker.start() |
|
|
| eval_dataloader = self.get_eval_dataloader(eval_dataset) |
| start_time = time.time() |
|
|
| eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop |
| output = eval_loop( |
| eval_dataloader, |
| description="Evaluation", |
| |
| |
| prediction_loss_only=True if self.compute_metrics is None else None, |
| ignore_keys=ignore_keys, |
| metric_key_prefix=metric_key_prefix, |
| ) |
|
|
| total_batch_size = self.args.eval_batch_size * self.args.world_size |
| if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: |
| start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] |
| output.metrics.update( |
| speed_metrics( |
| metric_key_prefix, |
| start_time, |
| num_samples=output.num_samples, |
| num_steps=math.ceil(output.num_samples / total_batch_size), |
| ) |
| ) |
|
|
| self.log(output.metrics) |
|
|
| if DebugOption.TPU_METRICS_DEBUG in self.args.debug: |
| |
| xm.master_print(met.metrics_report()) |
|
|
| self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) |
|
|
| self._memory_tracker.stop_and_update_metrics(output.metrics) |
|
|
| return output.metrics |
|
|
| def predict( |
| self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" |
| ) -> PredictionOutput: |
| """ |
| Run prediction and returns predictions and potential metrics. |
| |
| Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method |
| will also return metrics, like in `evaluate()`. |
| |
| Args: |
| test_dataset (`Dataset`): |
| Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the |
| `model.forward()` method are automatically removed. Has to implement the method `__len__` |
| ignore_keys (`List[str]`, *optional*): |
| A list of keys in the output of your model (if it is a dictionary) that should be ignored when |
| gathering predictions. |
| metric_key_prefix (`str`, *optional*, defaults to `"test"`): |
| An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named |
| "test_bleu" if the prefix is "test" (default) |
| |
| <Tip> |
| |
| If your predictions or labels have different sequence length (for instance because you're doing dynamic padding |
| in a token classification task) the predictions will be padded (on the right) to allow for concatenation into |
| one array. The padding index is -100. |
| |
| </Tip> |
| |
| Returns: *NamedTuple* A namedtuple with the following keys: |
| |
| - predictions (`np.ndarray`): The predictions on `test_dataset`. |
| - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). |
| - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained |
| labels). |
| """ |
| |
| self._memory_tracker.start() |
|
|
| test_dataloader = self.get_test_dataloader(test_dataset) |
| start_time = time.time() |
|
|
| eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop |
| output = eval_loop( |
| test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix |
| ) |
| total_batch_size = self.args.eval_batch_size * self.args.world_size |
| if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: |
| start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] |
| output.metrics.update( |
| speed_metrics( |
| metric_key_prefix, |
| start_time, |
| num_samples=output.num_samples, |
| num_steps=math.ceil(output.num_samples / total_batch_size), |
| ) |
| ) |
|
|
| self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) |
| self._memory_tracker.stop_and_update_metrics(output.metrics) |
|
|
| return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) |
|
|
| def evaluation_loop( |
| self, |
| dataloader: DataLoader, |
| description: str, |
| prediction_loss_only: Optional[bool] = None, |
| ignore_keys: Optional[List[str]] = None, |
| metric_key_prefix: str = "eval", |
| ) -> EvalLoopOutput: |
| """ |
| Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. |
| |
| Works both with or without labels. |
| """ |
| args = self.args |
|
|
| prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only |
|
|
| |
| if args.deepspeed and not self.deepspeed: |
| |
| |
| deepspeed_engine, _, _ = deepspeed_init( |
| self, num_training_steps=0, resume_from_checkpoint=None, inference=True |
| ) |
| self.model = deepspeed_engine.module |
| self.model_wrapped = deepspeed_engine |
| self.deepspeed = deepspeed_engine |
|
|
| model = self._wrap_model(self.model, training=False, dataloader=dataloader) |
|
|
| |
| |
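| # If full fp16/bf16 evaluation is requested and we are not inside train(), cast the model to the target dtype and move it to the device. |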
| if not self.is_in_train: |
| if args.fp16_full_eval: |
| model = model.to(dtype=torch.float16, device=args.device) |
| elif args.bf16_full_eval: |
| model = model.to(dtype=torch.bfloat16, device=args.device) |
|
|
| batch_size = self.args.eval_batch_size |
|
|
| logger.info(f"***** Running {description} *****") |
| if has_length(dataloader): |
| logger.info(f" Num examples = {self.num_examples(dataloader)}") |
| else: |
| logger.info(" Num examples: Unknown") |
| logger.info(f" Batch size = {batch_size}") |
|
|
| model.eval() |
|
|
| self.callback_handler.eval_dataloader = dataloader |
| |
| eval_dataset = getattr(dataloader, "dataset", None) |
|
|
| if is_torch_tpu_available(): |
| dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) |
|
|
| if args.past_index >= 0: |
| self._past = None |
|
|
| |
| |
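| # Containers on the device (GPU/TPU) accumulating gathered results for up to eval_accumulation_steps batches. |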
| losses_host = None |
| preds_host = None |
| labels_host = None |
| inputs_host = None |
|
|
| |
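| # Final containers on CPU holding the full-run results as numpy arrays. |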
| all_losses = None |
| all_preds = None |
| all_labels = None |
| all_inputs = None |
| |
|
|
| observed_num_examples = 0 |
| |
| for step, inputs in enumerate(dataloader): |
| |
| observed_batch_size = find_batch_size(inputs) |
| if observed_batch_size is not None: |
| observed_num_examples += observed_batch_size |
| |
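| # For batch samplers, batch_size is not known by the dataloader in advance, so set it from the first observed batch. |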
| if batch_size is None: |
| batch_size = observed_batch_size |
|
|
| |
| loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) |
| inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None |
|
|
| if is_torch_tpu_available(): |
| xm.mark_step() |
|
|
| |
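| # Gather loss/labels/inputs/logits across processes and accumulate them in the device containers. |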
| if loss is not None: |
| losses = self._nested_gather(loss.repeat(batch_size)) |
| losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) |
| if labels is not None: |
| labels = self._pad_across_processes(labels) |
| labels = self._nested_gather(labels) |
| labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) |
| if inputs_decode is not None: |
| inputs_decode = self._pad_across_processes(inputs_decode) |
| inputs_decode = self._nested_gather(inputs_decode) |
| inputs_host = ( |
| inputs_decode |
| if inputs_host is None |
| else nested_concat(inputs_host, inputs_decode, padding_index=-100) |
| ) |
| if logits is not None: |
| logits = self._pad_across_processes(logits) |
| logits = self._nested_gather(logits) |
| if self.preprocess_logits_for_metrics is not None: |
| logits = self.preprocess_logits_for_metrics(logits, labels) |
| preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) |
| self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) |
|
|
| |
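| # Every eval_accumulation_steps batches, move the accumulated tensors to the CPU as numpy arrays and reset the device containers to free memory. |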
| if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: |
| if losses_host is not None: |
| losses = nested_numpify(losses_host) |
| all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) |
| if preds_host is not None: |
| logits = nested_numpify(preds_host) |
| all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) |
| if inputs_host is not None: |
| inputs_decode = nested_numpify(inputs_host) |
| all_inputs = ( |
| inputs_decode |
| if all_inputs is None |
| else nested_concat(all_inputs, inputs_decode, padding_index=-100) |
| ) |
| if labels_host is not None: |
| labels = nested_numpify(labels_host) |
| all_labels = ( |
| labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) |
| ) |
|
|
| |
| losses_host, preds_host, inputs_host, labels_host = None, None, None, None |
|
|
| if args.past_index and hasattr(self, "_past"): |
| |
| delattr(self, "_past") |
|
|
| |
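| # Gather whatever is left over after the last, possibly partial, accumulation window. |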
| if losses_host is not None: |
| losses = nested_numpify(losses_host) |
| all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) |
| if preds_host is not None: |
| logits = nested_numpify(preds_host) |
| all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) |
| if inputs_host is not None: |
| inputs_decode = nested_numpify(inputs_host) |
| all_inputs = ( |
| inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) |
| ) |
| if labels_host is not None: |
| labels = nested_numpify(labels_host) |
| all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) |
|
|
| |
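| # Determine the number of samples: prefer the dataset length, then IterableDatasetShard.num_examples, then the dataloader length, and finally the number of examples actually observed. |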
| if has_length(eval_dataset): |
| num_samples = len(eval_dataset) |
| |
| |
| elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: |
| num_samples = eval_dataset.num_examples |
| else: |
| if has_length(dataloader): |
| num_samples = self.num_examples(dataloader) |
| else: |
| num_samples = observed_num_examples |
| if num_samples == 0 and observed_num_examples > 0: |
| num_samples = observed_num_examples |
|
|
| |
| |
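| # Losses are repeated batch_size times and distributed samplers may add extra samples, so truncate everything back to num_samples. |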
| if all_losses is not None: |
| all_losses = all_losses[:num_samples] |
| if all_preds is not None: |
| all_preds = nested_truncate(all_preds, num_samples) |
| if all_labels is not None: |
| all_labels = nested_truncate(all_labels, num_samples) |
| if all_inputs is not None: |
| all_inputs = nested_truncate(all_inputs, num_samples) |
|
|
| |
| if self.compute_metrics is not None and all_preds is not None and all_labels is not None: |
| if args.include_inputs_for_metrics: |
| metrics = self.compute_metrics( |
| EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) |
| ) |
| else: |
| metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) |
| else: |
| metrics = {} |
|
|
| |
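| # Make the metrics JSON-serializable by converting numpy types and zero-d tensors to plain Python numbers. |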
| metrics = denumpify_detensorize(metrics) |
|
|
| if all_losses is not None: |
| metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() |
| if hasattr(self, "jit_compilation_time"): |
| metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time |
|
|
| |
| for key in list(metrics.keys()): |
| if not key.startswith(f"{metric_key_prefix}_"): |
| metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) |
|
|
| return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) |
|
|
| def _nested_gather(self, tensors, name=None): |
| """ |
| Gather the value of `tensors` (a tensor or a nested list/tuple of tensors) from all processes so they can |
| safely be concatenated. |
| """ |
| if tensors is None: |
| return |
| if is_torch_tpu_available(): |
| if name is None: |
| name = "nested_gather" |
| tensors = nested_xla_mesh_reduce(tensors, name) |
| elif is_sagemaker_mp_enabled(): |
| tensors = smp_gather(tensors) |
| elif self.args.local_rank != -1: |
| tensors = distributed_concat(tensors) |
| return tensors |
|
|
| |
| def _pad_across_processes(self, tensor, pad_index=-100): |
| """ |
| Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so |
| they can safely be gathered. |
| """ |
| if isinstance(tensor, (list, tuple)): |
| return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) |
| elif isinstance(tensor, dict): |
| return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) |
| elif not isinstance(tensor, torch.Tensor): |
| raise TypeError( |
| f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." |
| ) |
|
|
| if len(tensor.shape) < 2: |
| return tensor |
| |
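| # Gather the tensor shapes from all processes to find the largest size along dimension 1. |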
| size = torch.tensor(tensor.shape, device=tensor.device)[None] |
| sizes = self._nested_gather(size).cpu() |
|
|
| max_size = max(s[1] for s in sizes) |
| |
| |
| if tensor.shape[1] >= max_size: |
| return tensor |
|
|
| |
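| # Allocate a tensor of the maximal size filled with pad_index and copy the original values into it. |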
| old_size = tensor.shape |
| new_size = list(old_size) |
| new_size[1] = max_size |
| new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index |
| new_tensor[:, : old_size[1]] = tensor |
| return new_tensor |
|
|
| def prediction_step( |
| self, |
| model: nn.Module, |
| inputs: Dict[str, Union[torch.Tensor, Any]], |
| prediction_loss_only: bool, |
| ignore_keys: Optional[List[str]] = None, |
| ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: |
| """ |
| Perform an evaluation step on `model` using `inputs`. |
| |
| Subclass and override to inject custom behavior. |
| |
| Args: |
| model (`nn.Module`): |
| The model to evaluate. |
| inputs (`Dict[str, Union[torch.Tensor, Any]]`): |
| The inputs and targets of the model. |
| |
| The dictionary will be unpacked before being fed to the model. Most models expect the targets under the |
| argument `labels`. Check your model's documentation for all accepted arguments. |
| prediction_loss_only (`bool`): |
| Whether or not to return the loss only. |
| ignore_keys (`List[str]`, *optional*): |
| A list of keys in the output of your model (if it is a dictionary) that should be ignored when |
| gathering predictions. |
| |
| Return: |
| Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, |
| logits and labels (each being optional). |
| """ |
| has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) |
| |
| |
| |
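| # Some models (e.g. CLIP-like ones) can return a loss without labels; respect `return_loss` from the inputs and fall back to the model's `can_return_loss` attribute. |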
| return_loss = inputs.get("return_loss", None) |
| if return_loss is None: |
| return_loss = self.can_return_loss |
| loss_without_labels = True if len(self.label_names) == 0 and return_loss else False |
|
|
| inputs = self._prepare_inputs(inputs) |
| if ignore_keys is None: |
| if hasattr(self.model, "config"): |
| ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) |
| else: |
| ignore_keys = [] |
|
|
| |
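| # Labels may be popped from the inputs when computing the loss (label smoothing, for instance), so detach and keep them first. |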
| if has_labels or loss_without_labels: |
| labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) |
| if len(labels) == 1: |
| labels = labels[0] |
| else: |
| labels = None |
|
|
| with torch.no_grad(): |
| if is_sagemaker_mp_enabled(): |
| raw_outputs = smp_forward_only(model, inputs) |
| if has_labels or loss_without_labels: |
| if isinstance(raw_outputs, dict): |
| loss_mb = raw_outputs["loss"] |
| logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) |
| else: |
| loss_mb = raw_outputs[0] |
| logits_mb = raw_outputs[1:] |
|
|
| loss = loss_mb.reduce_mean().detach().cpu() |
| logits = smp_nested_concat(logits_mb) |
| else: |
| loss = None |
| if isinstance(raw_outputs, dict): |
| logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) |
| else: |
| logits_mb = raw_outputs |
| logits = smp_nested_concat(logits_mb) |
| else: |
| if has_labels or loss_without_labels: |
| with self.compute_loss_context_manager(): |
| loss, outputs = self.compute_loss(model, inputs, return_outputs=True) |
| loss = loss.mean().detach() |
|
|
| if isinstance(outputs, dict): |
| logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) |
| else: |
| logits = outputs[1:] |
| else: |
| loss = None |
| with self.compute_loss_context_manager(): |
| outputs = model(**inputs) |
| if isinstance(outputs, dict): |
| logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) |
| else: |
| logits = outputs |
| |
| if self.args.past_index >= 0: |
| self._past = outputs[self.args.past_index - 1] |
|
|
| if prediction_loss_only: |
| return (loss, None, None) |
|
|
| logits = nested_detach(logits) |
| if len(logits) == 1: |
| logits = logits[0] |
|
|
| return (loss, logits, labels) |
|
|
| def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): |
| """ |
| For models that inherit from [`PreTrainedModel`], uses the model's own `floating_point_ops` method to compute |
| the number of floating point operations for every backward + forward pass. If using another model, either |
| implement such a method in the model or subclass and override this method. |
| |
| Args: |
| inputs (`Dict[str, Union[torch.Tensor, Any]]`): |
| The inputs and targets of the model. |
| |
| Returns: |
| `int`: The number of floating-point operations. |
| """ |
| if hasattr(self.model, "floating_point_ops"): |
| return self.model.floating_point_ops(inputs) |
| else: |
| return 0 |
|
|
| def init_git_repo(self, at_init: bool = False): |
| """ |
| Initializes a git repo in `self.args.output_dir`, cloned from the Hub repo `self.args.hub_model_id` (created if it does not exist). |
| |
| Args: |
| at_init (`bool`, *optional*, defaults to `False`): |
| Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is |
| `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped |
| out. |
| """ |
| if not self.is_world_process_zero(): |
| return |
| if self.args.hub_model_id is None: |
| repo_name = Path(self.args.output_dir).absolute().name |
| else: |
| repo_name = self.args.hub_model_id |
| if "/" not in repo_name: |
| repo_name = get_full_repo_name(repo_name, token=self.args.hub_token) |
|
|
| |
| create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) |
| try: |
| self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) |
| except EnvironmentError: |
| if self.args.overwrite_output_dir and at_init: |
| |
| shutil.rmtree(self.args.output_dir) |
| self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) |
| else: |
| raise |
|
|
| self.repo.git_pull() |
|
|
| |
| if ( |
| not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")) |
| and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS |
| ): |
| with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: |
| writer.writelines(["checkpoint-*/"]) |
|
|
| |
| if os.environ.get("SM_TRAINING_ENV"): |
| self._add_sm_patterns_to_gitignore() |
|
|
| self.push_in_progress = None |
|
|
| def create_model_card( |
| self, |
| language: Optional[str] = None, |
| license: Optional[str] = None, |
| tags: Union[str, List[str], None] = None, |
| model_name: Optional[str] = None, |
| finetuned_from: Optional[str] = None, |
| tasks: Union[str, List[str], None] = None, |
| dataset_tags: Union[str, List[str], None] = None, |
| dataset: Union[str, List[str], None] = None, |
| dataset_args: Union[str, List[str], None] = None, |
| ): |
| """ |
| Creates a draft of a model card using the information available to the `Trainer`. |
| |
| Args: |
| language (`str`, *optional*): |
| The language of the model (if applicable) |
| license (`str`, *optional*): |
| The license of the model. Will default to the license of the pretrained model used, if the original |
| model given to the `Trainer` comes from a repo on the Hub. |
| tags (`str` or `List[str]`, *optional*): |
| Some tags to be included in the metadata of the model card. |
| model_name (`str`, *optional*): |
| The name of the model. |
| finetuned_from (`str`, *optional*): |
| The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo |
| of the original model given to the `Trainer` (if it comes from the Hub). |
| tasks (`str` or `List[str]`, *optional*): |
| One or several task identifiers, to be included in the metadata of the model card. |
| dataset_tags (`str` or `List[str]`, *optional*): |
| One or several dataset tags, to be included in the metadata of the model card. |
| dataset (`str` or `List[str]`, *optional*): |
| One or several dataset identifiers, to be included in the metadata of the model card. |
| dataset_args (`str` or `List[str]`, *optional*): |
| One or several dataset arguments, to be included in the metadata of the model card. |
| """ |
| if not self.is_world_process_zero(): |
| return |
|
|
| training_summary = TrainingSummary.from_trainer( |
| self, |
| language=language, |
| license=license, |
| tags=tags, |
| model_name=model_name, |
| finetuned_from=finetuned_from, |
| tasks=tasks, |
| dataset_tags=dataset_tags, |
| dataset=dataset, |
| dataset_args=dataset_args, |
| ) |
| model_card = training_summary.to_model_card() |
| with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: |
| f.write(model_card) |
|
|
| def _push_from_checkpoint(self, checkpoint_folder): |
| |
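| # Only push from the main process; skip intermediate pushes when hub_strategy is END or when a previous asynchronous push has not finished yet. |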
| if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: |
| return |
| |
| if self.push_in_progress is not None and not self.push_in_progress.is_done: |
| return |
|
|
| output_dir = self.args.output_dir |
| |
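| # Copy the config and weights from the checkpoint folder so the root of the repo always holds the latest model files. |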
| modeling_files = [CONFIG_NAME, WEIGHTS_NAME] |
| for modeling_file in modeling_files: |
| if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): |
| shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) |
| |
| if self.tokenizer is not None: |
| self.tokenizer.save_pretrained(output_dir) |
| |
| torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) |
|
|
| try: |
| if self.args.hub_strategy == HubStrategy.CHECKPOINT: |
| |
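| # With HubStrategy.CHECKPOINT, temporarily move the checkpoint folder to "last-checkpoint" so it is included in the push; it is moved back in the finally block. |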
| tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") |
| |
| |
| if os.path.isdir(tmp_checkpoint): |
| shutil.rmtree(tmp_checkpoint) |
| shutil.move(checkpoint_folder, tmp_checkpoint) |
|
|
| if self.args.save_strategy == IntervalStrategy.STEPS: |
| commit_message = f"Training in progress, step {self.state.global_step}" |
| else: |
| commit_message = f"Training in progress, epoch {int(self.state.epoch)}" |
| _, self.push_in_progress = self.repo.push_to_hub( |
| commit_message=commit_message, blocking=False, auto_lfs_prune=True |
| ) |
| finally: |
| if self.args.hub_strategy == HubStrategy.CHECKPOINT: |
| |
| shutil.move(tmp_checkpoint, checkpoint_folder) |
|
|
| def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: |
| """ |
| Upload *self.model* and *self.tokenizer* to the π€ model hub on the repo *self.args.hub_model_id*. |
| |
| Parameters: |
| commit_message (`str`, *optional*, defaults to `"End of training"`): |
| Message to commit while pushing. |
| blocking (`bool`, *optional*, defaults to `True`): |
| Whether the function should return only when the `git push` has finished. |
| kwargs: |
| Additional keyword arguments passed along to [`~Trainer.create_model_card`]. |
| |
| Returns: |
| The URL of the repository commit of your model if `blocking=True`; if `blocking=False`, a tuple with that URL |
| and an object tracking the progress of the asynchronous push. |
| """ |
| |
| |
| if not hasattr(self, "repo"): |
| self.init_git_repo() |
|
|
| model_name = kwargs.pop("model_name", None) |
| if model_name is None and self.args.should_save: |
| if self.args.hub_model_id is None: |
| model_name = Path(self.args.output_dir).name |
| else: |
| model_name = self.args.hub_model_id.split("/")[-1] |
|
|
| |
| |
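| # save_model has to run on all processes (required for TPU training) but only writes files on the process where self.args.should_save is True. |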
| self.save_model(_internal_call=True) |
|
|
| |
| if not self.is_world_process_zero(): |
| return |
|
|
| |
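| # A pending asynchronous push would conflict with a blocking push, so cancel it; its local commits are pushed along with this one. |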
| if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done: |
| self.push_in_progress._process.kill() |
| self.push_in_progress = None |
|
|
| git_head_commit_url = self.repo.push_to_hub( |
| commit_message=commit_message, blocking=blocking, auto_lfs_prune=True |
| ) |
| |
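| # Push the model card in a separate commit so it stays independent from the rest of the model save. |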
| if self.args.should_save: |
| self.create_model_card(model_name=model_name, **kwargs) |
| try: |
| self.repo.push_to_hub( |
| commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True |
| ) |
| except EnvironmentError as exc: |
| logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") |
|
|
| return git_head_commit_url |
|
|
| |
| |
| |
|
|
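| # Legacy path: prediction_loop below is only used when args.use_legacy_prediction_loop is True. |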
| def prediction_loop( |
| self, |
| dataloader: DataLoader, |
| description: str, |
| prediction_loss_only: Optional[bool] = None, |
| ignore_keys: Optional[List[str]] = None, |
| metric_key_prefix: str = "eval", |
| ) -> EvalLoopOutput: |
| """ |
| Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. |
| |
| Works both with and without labels. |
| """ |
| args = self.args |
|
|
| if not has_length(dataloader): |
| raise ValueError("dataloader must implement a working __len__") |
|
|
| prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only |
|
|
| |
| if args.deepspeed and not self.deepspeed: |
| |
| |
| deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) |
| self.model = deepspeed_engine.module |
| self.model_wrapped = deepspeed_engine |
| self.deepspeed = deepspeed_engine |
| |
| |
| |
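| # The optimizer and lr scheduler created by deepspeed_init are not needed for inference, so drop them. |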
| deepspeed_engine.optimizer.optimizer = None |
| deepspeed_engine.lr_scheduler = None |
|
|
| model = self._wrap_model(self.model, training=False, dataloader=dataloader) |
|
|
| |
| |
| if not self.is_in_train: |
| if args.fp16_full_eval: |
| model = model.to(dtype=torch.float16, device=args.device) |
| elif args.bf16_full_eval: |
| model = model.to(dtype=torch.bfloat16, device=args.device) |
|
|
| batch_size = dataloader.batch_size |
| num_examples = self.num_examples(dataloader) |
| logger.info(f"***** Running {description} *****") |
| logger.info(f" Num examples = {num_examples}") |
| logger.info(f" Batch size = {batch_size}") |
| losses_host: torch.Tensor = None |
| preds_host: Union[torch.Tensor, List[torch.Tensor]] = None |
| labels_host: Union[torch.Tensor, List[torch.Tensor]] = None |
| inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None |
|
|
| world_size = max(1, args.world_size) |
|
|
| eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) |
| if not prediction_loss_only: |
| |
| |
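| # In distributed evaluation, SequentialDistributedSampler pads the dataset to a multiple of the batch size; the gatherers need that multiple to drop the padding at the end. |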
| make_multiple_of = None |
| if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): |
| make_multiple_of = dataloader.sampler.batch_size |
| preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) |
| labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) |
| inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) |
|
|
| model.eval() |
|
|
| if is_torch_tpu_available(): |
| dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) |
|
|
| if args.past_index >= 0: |
| self._past = None |
|
|
| self.callback_handler.eval_dataloader = dataloader |
|
|
| for step, inputs in enumerate(dataloader): |
| loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) |
| inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None |
|
|
| if loss is not None: |
| losses = loss.repeat(batch_size) |
| losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) |
| if logits is not None: |
| preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) |
| if labels is not None: |
| labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) |
| if inputs_decode is not None: |
| inputs_host = ( |
| inputs_decode |
| if inputs_host is None |
| else nested_concat(inputs_host, inputs_decode, padding_index=-100) |
| ) |
| self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) |
|
|
| |
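| # Every eval_accumulation_steps batches, gather the device tensors across processes, hand them to the gatherers as numpy arrays, and reset the containers. |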
| if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: |
| eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) |
| if not prediction_loss_only: |
| preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) |
| labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) |
| inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) |
|
|
| |
| losses_host, preds_host, labels_host, inputs_host = None, None, None, None |
|
|
| if args.past_index and hasattr(self, "_past"): |
| |
| delattr(self, "_past") |
|
|
| |
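| # Gather the leftovers from the last, possibly partial, accumulation window. |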
| eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) |
| if not prediction_loss_only: |
| preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) |
| labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) |
| inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) |
|
|
| eval_loss = eval_losses_gatherer.finalize() |
| preds = preds_gatherer.finalize() if not prediction_loss_only else None |
| label_ids = labels_gatherer.finalize() if not prediction_loss_only else None |
| inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None |
|
|
| if self.compute_metrics is not None and preds is not None and label_ids is not None: |
| if args.include_inputs_for_metrics: |
| metrics = self.compute_metrics( |
| EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) |
| ) |
| else: |
| metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) |
| else: |
| metrics = {} |
|
|
| |
| metrics = denumpify_detensorize(metrics) |
|
|
| if eval_loss is not None: |
| metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() |
|
|
| |
| for key in list(metrics.keys()): |
| if not key.startswith(f"{metric_key_prefix}_"): |
| metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) |
|
|
| return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) |
|
|
| def _gather_and_numpify(self, tensors, name): |
| """ |
| Gather the value of `tensors` (a tensor or a nested list/tuple of tensors) from all processes and convert them |
| to numpy before concatenating. |
| """ |
| if tensors is None: |
| return |
| if is_torch_tpu_available(): |
| tensors = nested_xla_mesh_reduce(tensors, name) |
| elif is_sagemaker_mp_enabled(): |
| tensors = smp_gather(tensors) |
| elif self.args.local_rank != -1: |
| tensors = distributed_concat(tensors) |
|
|
| return nested_numpify(tensors) |
|
|
| def _add_sm_patterns_to_gitignore(self) -> None: |
| """Add SageMaker Checkpointing patterns to .gitignore file.""" |
| |
| if not self.is_world_process_zero(): |
| return |
|
|
| patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] |
|
|
| |
| if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): |
| with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: |
| current_content = f.read() |
| else: |
| current_content = "" |
|
|
| |
| content = current_content |
| for pattern in patterns: |
| if pattern not in content: |
| if content.endswith("\n"): |
| content += pattern |
| else: |
| content += f"\n{pattern}" |
|
|
| |
| if content != current_content: |
| with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: |
| logger.debug(f"Writing .gitignore file. Content: {content}") |
| f.write(content) |
|
|
| self.repo.git_add(".gitignore") |
|
|
| |
| time.sleep(0.5) |
|
|
| if not self.repo.is_repo_clean(): |
| self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") |
| self.repo.git_push() |
|
|