| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| import itertools |
| import os |
| from functools import partial |
| from typing import Any, Callable, List, Optional, Tuple, Union |
|
|
| import torch |
| from torch import Tensor, device |
|
|
| from diffusers import __version__ |
| from diffusers.utils import ( |
| CONFIG_NAME, |
| DIFFUSERS_CACHE, |
| FLAX_WEIGHTS_NAME, |
| HF_HUB_OFFLINE, |
| SAFETENSORS_WEIGHTS_NAME, |
| WEIGHTS_NAME, |
| _add_variant, |
| _get_model_file, |
| deprecate, |
| is_accelerate_available, |
| is_safetensors_available, |
| is_torch_version, |
| logging, |
| ) |
|
|
|
|
# Module-level logger, created through the project's logging helper.
logger = logging.get_logger(__name__)

# Default for the `low_cpu_mem_usage` option of `from_pretrained`: enabled only when the installed torch is
# at least 1.9.0.
_LOW_CPU_MEM_USAGE_DEFAULT = True if is_torch_version(">=", "1.9.0") else False
|
|
|
|
| if is_accelerate_available(): |
| import accelerate |
| from accelerate.utils import set_module_tensor_to_device |
| from accelerate.utils.versions import is_torch_version |
|
|
| if is_safetensors_available(): |
| import safetensors |
|
|
|
|
def get_parameter_device(parameter: torch.nn.Module):
    """
    Return the device of the first tensor found on a module.

    Registered parameters are looked at first, then registered buffers. When the module has neither, the
    module tree is scanned for plain tensor attributes (tensors stored directly in a module's `__dict__`)
    and the device of the first one is returned.

    Args:
        parameter (`torch.nn.Module`): The module to inspect.

    Returns:
        `torch.device`: The device of the first parameter, buffer or tensor attribute found.

    Raises:
        StopIteration: If the module holds no parameter, buffer or tensor attribute at all.
    """
    registered_tensors = itertools.chain(parameter.parameters(), parameter.buffers())
    first_registered = next(registered_tensors, None)
    if first_registered is not None:
        return first_registered.device

    # Nothing is registered: fall back to tensors that were assigned as ordinary attributes.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(attr_name, attr) for attr_name, attr in module.__dict__.items() if torch.is_tensor(attr)]

    tensor_attributes = parameter._named_members(get_members_fn=find_tensor_attributes)
    # `next` without a default deliberately lets StopIteration propagate when nothing is found.
    _, first_tensor = next(tensor_attributes)
    return first_tensor.device
|
|
|
|
def get_parameter_dtype(parameter: torch.nn.Module) -> Optional[torch.dtype]:
    """
    Return the dtype of the first tensor found on a module.

    Lookup order: the first registered parameter, then the first registered buffer, then the first plain
    tensor attribute (a tensor stored directly in a module's `__dict__`) found while walking the module tree.

    Args:
        parameter (`torch.nn.Module`): The module to inspect.

    Returns:
        `torch.dtype` or `None`: The dtype of the first tensor found, or `None` when the module holds no
        parameter, buffer or tensor attribute.
    """
    params = tuple(parameter.parameters())
    if params:
        return params[0].dtype

    buffers = tuple(parameter.buffers())
    if buffers:
        return buffers[0].dtype

    # Fallback for modules that only carry tensors as ordinary attributes. This used to sit in an
    # `except StopIteration` clause that could never trigger (building a tuple does not raise it), so such
    # modules silently produced `None`.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    # Keep returning `None` (the historical result) when there is truly nothing to inspect.
    return None if first_tuple is None else first_tuple[1].dtype
|
|
|
|
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    The file is read with `torch.load` when its base name equals the (variant-adjusted) PyTorch weights file
    name, and with `safetensors.torch.load_file` otherwise. Tensors are mapped to the CPU in both cases.

    Args:
        checkpoint_file (`str` or `os.PathLike`): Path of the checkpoint to read.
        variant (`str`, *optional*): Variant inserted into the expected PyTorch weights file name.

    Returns:
        The object returned by the loader that was used (the checkpoint's state dict).

    Raises:
        OSError: If loading fails. The message points at a missing git-lfs setup when the file starts with
            `version` (i.e. looks like a git-lfs pointer file); otherwise it reports that the weights could
            not be loaded. Errors raised while opening the file itself (e.g. a missing file) propagate as is.
    """
    try:
        if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
            # NOTE(security): `torch.load` unpickles the file; only use it on checkpoints you trust.
            return torch.load(checkpoint_file, map_location="cpu")
        else:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # Only the first characters matter: never pull a multi-gigabyte checkpoint into memory as
                # text just to look at its prefix.
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # Chain the loader's own exception so the real reason for the failure stays visible.
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            ) from e
|
|
|
|
| def _load_state_dict_into_model(model_to_load, state_dict): |
| |
| |
| state_dict = state_dict.copy() |
| error_msgs = [] |
|
|
| |
| |
| def load(module: torch.nn.Module, prefix=""): |
| args = (state_dict, prefix, {}, True, [], [], error_msgs) |
| module._load_from_state_dict(*args) |
|
|
| for name, child in module._modules.items(): |
| if child is not None: |
| load(child, prefix + name + ".") |
|
|
| load(model_to_load) |
|
|
| return error_msgs |
|
|
|
|
class ModelMixin(torch.nn.Module):
    r"""
    Base class for all models.

    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
    and saving models.

        - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
          [`~models.ModelMixin.save_pretrained`].
    """
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False

    def __init__(self):
        super().__init__()

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
        __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """
        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = (
                f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is"
                f" deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g."
                f" 'unet.config.{name}'."
            )
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
            return self._internal_dict[name]

        # Everything else goes through the regular `torch.nn.Module` attribute lookup.
        return super().__getattr__(name)

    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        Whether gradient checkpointing is activated for this model or not.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        """
        Activates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".

        Raises:
            ValueError: If the class does not set `_supports_gradient_checkpointing`.
        """
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # `_set_gradient_checkpointing` is not defined in this class; classes that support the feature are
        # expected to provide it.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self):
        """
        Deactivates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations". Does nothing when the class does not support gradient checkpointing.
        """
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        """
        Forward the xformers attention setting to every descendant module that exposes
        `set_use_memory_efficient_attention_xformers`.

        Args:
            valid (`bool`): Value handed to the descendants' setter.
            attention_op (`Callable`, *optional*): Operator handed to the descendants' setter.
        """

        # Start from the children (not from `self`): `self` has this very method, so including it would
        # recurse forever.
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import UNet2DConditionModel
        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

        >>> model = UNet2DConditionModel.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
        ... )
        >>> model = model.to("cuda")
        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        ```
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = False,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `[`~models.ModelMixin.from_pretrained`]` class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist. If it points to an existing file,
                an error is logged and nothing is saved.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions. Only the configuration file is guarded by this flag; the
                weights are written regardless.
            save_function (`Callable`, *optional*):
                The function to use to save the state dictionary, called as `save_function(state_dict, path)`.
                Useful on distributed training like TPUs when one need to replace `torch.save` by another method.
                Defaults to `torch.save`. Not used when `safe_serialization=True`.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                If specified, weights are saved in the format pytorch_model.<variant>.bin.

        Raises:
            ImportError: If `safe_serialization=True` but `safetensors` is not installed.
        """
        if safe_serialization and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        model_to_save = self

        # Only one process should write the configuration file.
        if is_main_process:
            model_to_save.save_config(save_directory)

        state_dict = model_to_save.state_dict()

        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
        weights_name = _add_variant(weights_name, variant)
        weights_path = os.path.join(save_directory, weights_name)

        if safe_serialization:
            safetensors.torch.save_file(state_dict, weights_path, metadata={"format": "pt"})
        else:
            # Honour the caller-supplied saver; it used to be accepted and then silently ignored.
            save_fn = torch.save if save_function is None else save_function
            save_fn(state_dict, weights_path)

        logger.info(f"Model weights saved in {weights_path}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you should first set it back in training mode with `model.train()`.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using
                      [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. Any value that is not a
                `torch.dtype` instance raises a `ValueError`.
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                Only used when `low_cpu_mem_usage=False`. If `True`, checkpoint entries whose shape differs from the
                corresponding model tensor are dropped instead of loaded, and reported under `mismatched_keys`.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            from_flax (`bool`, *optional*, defaults to `False`):
                Load the model weights from a Flax checkpoint save file.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
            offload_state_dict (`bool`, *optional*):
                If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
                RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
                `True` when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
                ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.

        Returns:
            The loaded model, or the tuple `(model, loading_info)` when `output_loading_info=True`.

        <Tip>

         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>

        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip"
                " install safetensors`"
            )

        # Falling back to the pickle-based weights is only allowed when the caller did not explicitly ask
        # for safetensors.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # Reject an invalid dtype before anything is downloaded or loaded instead of after the fact.
        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )

        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # Load the configuration; keyword arguments it does not consume come back as `unused_kwargs` and are
        # forwarded to `from_config`.
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            user_agent=user_agent,
            **kwargs,
        )

        # Arguments shared by every weight-file lookup below.
        model_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "resume_download": resume_download,
            "proxies": proxies,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "user_agent": user_agent,
            "commit_hash": commit_hash,
        }

        # Defined up front so that every branch (including the Flax one) can return it.
        loading_info = {
            "missing_keys": [],
            "unexpected_keys": [],
            "mismatched_keys": [],
            "error_msgs": [],
        }

        model_file = None
        if from_flax:
            model_file = _get_model_file(
                pretrained_model_name_or_path, weights_name=FLAX_WEIGHTS_NAME, **model_file_kwargs
            )
            model = cls.from_config(config, **unused_kwargs)

            # Imported lazily: only needed for the Flax conversion path.
            from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            if use_safetensors:
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        **model_file_kwargs,
                    )
                except IOError:
                    # Safetensors were explicitly requested: do not silently fall back to pickle.
                    if not allow_pickle:
                        raise
            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    **model_file_kwargs,
                )

            if low_cpu_mem_usage:
                # Build the model without allocating real weights.
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **unused_kwargs)

                if device_map is None:
                    # No device map: materialize every tensor of the checkpoint on the CPU.
                    param_device = "cpu"
                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)

                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )

                    empty_state_dict = model.state_dict()
                    # Whether the installed accelerate accepts a `dtype` argument; same answer for every tensor.
                    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters

                    for param_name, param in state_dict.items():
                        # NOTE: a checkpoint key that the model does not have raises `KeyError` here.
                        if empty_state_dict[param_name].shape != param.shape:
                            raise ValueError(
                                f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape"
                                f" {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to"
                                " instead overwrite randomly initialized weights, please make sure to pass both"
                                " `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more"
                                " information, see also:"
                                " https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an"
                                " example."
                            )

                        if accepts_dtype:
                            set_module_tensor_to_device(
                                model, param_name, param_device, value=param, dtype=torch_dtype
                            )
                        else:
                            set_module_tensor_to_device(model, param_name, param_device, value=param)
                else:
                    # A device map was given: let accelerate load the checkpoint and dispatch the weights.
                    accelerate.load_checkpoint_and_dispatch(
                        model,
                        model_file,
                        device_map,
                        max_memory=max_memory,
                        offload_folder=offload_folder,
                        offload_state_dict=offload_state_dict,
                        dtype=torch_dtype,
                    )
            else:
                model = cls.from_config(config, **unused_kwargs)

                state_dict = load_state_dict(model_file, variant=variant)
                model._convert_deprecated_attention_blocks(state_dict)

                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                    model,
                    state_dict,
                    model_file,
                    pretrained_model_name_or_path,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                )

                loading_info = {
                    "missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "mismatched_keys": mismatched_keys,
                    "error_msgs": error_msgs,
                }

        # `torch_dtype` was validated above, so a non-None value is a real dtype here.
        if torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Evaluation mode by default (see docstring).
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
    ):
        """
        Load `state_dict` into `model` and report how well the two matched.

        Args:
            model: The instantiated model receiving the weights.
            state_dict: Checkpoint mapping from tensor names to tensors. When `ignore_mismatched_sizes` is
                `True`, entries whose shape differs from the model's are deleted from this mapping in place.
            resolved_archive_file: Not used by this method.
            pretrained_model_name_or_path: Only used in log messages.
            ignore_mismatched_sizes (`bool`, defaults to `False`): Whether shape mismatches are skipped (and
                reported) instead of being surfaced as loading errors.

        Returns:
            The tuple `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`, where
            `mismatched_keys` holds `(key, checkpoint_shape, model_shape)` triples.

        Raises:
            RuntimeError: If loading the state dict produced any error message.
        """
        model_state_dict = model.state_dict()
        loaded_keys = list(state_dict.keys())
        expected_keys = list(model_state_dict.keys())

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        # Drop checkpoint tensors whose shape does not fit the model, remembering what was dropped.
        # `loaded_keys` is a snapshot, so deleting from `state_dict` while looping over it is safe.
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                if (
                    checkpoint_key in model_state_dict
                    and state_dict[checkpoint_key].shape != model_state_dict[checkpoint_key].shape
                ):
                    mismatched_keys.append(
                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[checkpoint_key].shape)
                    )
                    del state_dict[checkpoint_key]

        error_msgs = _load_state_dict_into_model(model, state_dict)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                " identical (initializing a BertForSequenceClassification model from a"
                " BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                " without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                " able to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (optionally, trainable or non-embeddings) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters

            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embeddings parameters

        Returns:
            `int`: The number of parameters.
        """
        if exclude_embeddings:
            # A set keeps the per-parameter membership test cheap.
            embedding_param_names = {
                f"{name}.weight"
                for name, module in self.named_modules()
                if isinstance(module, torch.nn.Embedding)
            }
            counted_parameters = (
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            )
        else:
            counted_parameters = self.parameters()

        return sum(p.numel() for p in counted_parameters if p.requires_grad or not only_trainable)

    def _convert_deprecated_attention_blocks(self, state_dict):
        """
        Rename, in place, the `state_dict` keys of attention blocks flagged as deprecated.

        Every module of the tree whose `_from_deprecated_attn_block` attribute is truthy has its
        `query` / `key` / `value` / `proj_attn` weight and bias keys renamed to `to_q` / `to_k` / `to_v` /
        `to_out.0`. Keys that are absent are left alone.

        Args:
            state_dict: Checkpoint mapping to update in place.
        """
        deprecated_attention_block_paths = []

        def recursive_find_attn_block(name, module):
            if getattr(module, "_from_deprecated_attn_block", False):
                deprecated_attention_block_paths.append(name)

            for sub_name, sub_module in module.named_children():
                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
                recursive_find_attn_block(sub_name, sub_module)

        recursive_find_attn_block("", self)

        # (deprecated sub-module name, current sub-module name)
        renamed_submodules = (
            ("query", "to_q"),
            ("key", "to_k"),
            ("value", "to_v"),
            ("proj_attn", "to_out.0"),
        )

        for path in deprecated_attention_block_paths:
            for old_name, new_name in renamed_submodules:
                for tensor_name in ("weight", "bias"):
                    old_key = f"{path}.{old_name}.{tensor_name}"
                    if old_key in state_dict:
                        state_dict[f"{path}.{new_name}.{tensor_name}"] = state_dict.pop(old_key)
|
|