HemanthSai7 commited on
Commit
63afef9
·
verified ·
1 Parent(s): 18694a9

Upload configuration_utils.py

Browse files
Files changed (1) hide show
  1. configuration_utils.py +1287 -0
configuration_utils.py ADDED
@@ -0,0 +1,1287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Configuration base class and utilities."""
16
+
17
+ import copy
18
+ import json
19
+ import math
20
+ import os
21
+ from collections.abc import Sequence
22
+ from dataclasses import dataclass
23
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, Union
24
+
25
+ from huggingface_hub import create_repo
26
+ from huggingface_hub.dataclasses import strict
27
+ from packaging import version
28
+
29
+ from . import __version__
30
+ from .dynamic_module_utils import custom_object_save
31
+ from .generation.configuration_utils import GenerationConfig
32
+ from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
33
+ from .modeling_rope_utils import RotaryEmbeddingConfigMixin
34
+ from .tokenization_utils_base import PreTrainedTokenizerBase
35
+ from .utils import (
36
+ CONFIG_NAME,
37
+ PushToHubMixin,
38
+ cached_file,
39
+ copy_func,
40
+ extract_commit_hash,
41
+ is_torch_available,
42
+ logging,
43
+ )
44
+ from .utils.generic import is_timm_config_dict
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ import torch
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ # type hinting: specifying the type of config class that inherits from PreTrainedConfig
55
+ SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
56
+
57
+ _FLOAT_TAG_KEY = "__float__"
58
+ _FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")}
59
+
60
+
61
+ ALLOWED_LAYER_TYPES = (
62
+ "full_attention",
63
+ "sliding_attention",
64
+ "chunked_attention",
65
+ "linear_attention", # used in minimax
66
+ "conv", # used in LFMv2
67
+ "mamba",
68
+ "attention",
69
+ "sparse",
70
+ "dense",
71
+ )
72
+
73
+
74
+ @strict(accept_kwargs=True)
75
+ @dataclass(repr=False)
76
+ class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
77
+ # no-format
78
+ r"""
79
+ Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
80
+ methods for loading/downloading/saving configurations.
81
+
82
+ <Tip>
83
+
84
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
85
+ initialize a model does **not** load the model weights. It only affects the model's configuration.
86
+
87
+ </Tip>
88
+
89
+ Class attributes (overridden by derived classes):
90
+
91
+ - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
92
+ the correct object in [`~transformers.AutoConfig`].
93
+ - **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
94
+ Some configurations requires inputs to be defined at init and have no default values, usually these are composite configs,
95
+ (but not necessarily) such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be initialized from
96
+ two or more configs of type [`~transformers.PreTrainedConfig`].
97
+ - **keys_to_ignore_at_inference** (`list[str]`) -- A list of keys to ignore by default when looking at dictionary
98
+ outputs of the model during inference.
99
+ - **attribute_map** (`dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
100
+ naming of attributes.
101
+ - **base_model_tp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
102
+ parallel plan applied to the sub-module when `model.tensor_parallel` is called.
103
+ - **base_model_pp_plan** (`dict[str, tuple[list[str]]]`) -- A dict that maps child-modules of a base model to a
104
+ pipeline parallel plan that enables users to place the child-module on the appropriate device.
105
+
106
+ Common attributes (present in all subclasses):
107
+
108
+ - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
109
+ embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
110
+ - **hidden_size** (`int`) -- The hidden size of the model.
111
+ - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
112
+ model.
113
+ - **num_hidden_layers** (`int`) -- The number of blocks in the model.
114
+
115
+ <Tip warning={true}>
116
+
117
+ Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
118
+ some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
119
+ them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
120
+ information about the individual parameters.
121
+
122
+ </Tip>
123
+
124
+ Arg:
125
+ name_or_path (`str`, *optional*, defaults to `""`):
126
+ Store the string that was passed to [`PreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path`
127
+ if the configuration was created with such a method.
128
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
129
+ Whether or not the model should return all hidden-states.
130
+ output_attentions (`bool`, *optional*, defaults to `False`):
131
+ Whether or not the model should returns all attentions.
132
+ return_dict (`bool`, *optional*, defaults to `True`):
133
+ Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
134
+ is_encoder_decoder (`bool`, *optional*, defaults to `False`):
135
+ Whether the model is used as an encoder/decoder or not.
136
+ chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
137
+ The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
138
+ the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
139
+ sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
140
+ Forward Chunking work?](../glossary.html#feed-forward-chunking).
141
+
142
+ > Parameters for fine-tuning tasks
143
+
144
+ architectures (`list[str]`, *optional*):
145
+ Model architectures that can be used with the model pretrained weights.
146
+ id2label (`dict[int, str]`, *optional*):
147
+ A map from index (for instance prediction index, or target index) to label.
148
+ label2id (`dict[str, int]`, *optional*):
149
+ A map from label to index for the model.
150
+ num_labels (`int`, *optional*):
151
+ Number of labels to use in the last layer added to the model, typically for a classification task.
152
+ problem_type (`str`, *optional*):
153
+ Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
154
+ `"single_label_classification"` or `"multi_label_classification"`.
155
+
156
+ > PyTorch specific parameters
157
+
158
+ dtype (`str`, *optional*):
159
+ The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
160
+ (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
161
+ model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
162
+ `float16` weights.
163
+ """
164
+
165
+ # Class attributes that we don't want to save or have in `self.__dict__`
166
+ # They are not supposed to be set/changed by users. Each field is set when
167
+ # creating a model class
168
+ base_config_key: ClassVar[str] = ""
169
+ sub_configs: ClassVar[dict[str, type["PreTrainedConfig"]]] = {}
170
+ has_no_defaults_at_init: ClassVar[bool] = False
171
+ keys_to_ignore_at_inference: ClassVar[list[str]] = []
172
+ attribute_map: ClassVar[dict[str, str]] = {}
173
+ base_model_tp_plan: ClassVar[dict[str, Any] | None] = None
174
+ base_model_pp_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
175
+ base_model_ep_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
176
+ _auto_class: ClassVar[str | None] = None
177
+
178
+ # Attributes set internally when saving and used to infer model
179
+ # class for `Auto` mapping
180
+ model_type: ClassVar[str] = ""
181
+ transformers_version: str | None = None
182
+ architectures: list[str] | None = None
183
+
184
+ # Common attributes for all models
185
+ output_hidden_states: bool | None = False
186
+ return_dict: bool | None = True
187
+ dtype: Union[str, "torch.dtype"] | None = None
188
+ chunk_size_feed_forward: int = 0
189
+ is_encoder_decoder: bool = False
190
+
191
+ # Fine-tuning task arguments
192
+ id2label: dict[int, str] | dict[str, str] | None = None
193
+ label2id: dict[str, int] | dict[str, str] | None = None
194
+ problem_type: Literal["regression", "single_label_classification", "multi_label_classification"] | None = None
195
+
196
+ # Tokenizer kwargs
197
+ tokenizer_class: str | PreTrainedTokenizerBase | None = None
198
+
199
+ def __post_init__(self, **kwargs):
200
+ # BC for the `torch_dtype` argument instead of the simpler `dtype`
201
+ # Do not warn, as it would otherwise always be triggered since most configs on the hub have `torch_dtype`
202
+ if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
203
+ # If both are provided, keep `dtype`
204
+ self.dtype = self.dtype if self.dtype is not None else torch_dtype
205
+ if self.dtype is not None and isinstance(self.dtype, str) and is_torch_available():
206
+ # we will start using self.dtype in v5, but to be consistent with
207
+ # from_pretrained's dtype arg convert it to an actual torch.dtype object
208
+ import torch
209
+
210
+ self.dtype = getattr(torch, self.dtype)
211
+
212
+ # Keep the default value of `num_labels=2` in case users have saved a classfier with 2 labels
213
+ # Our configs prev wouldn't save `id2label` for 2 labels because it is the default. In all other
214
+ # cases we expect the config dict to have an `id2label` field if it's a clf model, or not otherwise
215
+ if self.id2label is None:
216
+ self.num_labels = kwargs.get("num_labels", 2)
217
+ else:
218
+ if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"):
219
+ logger.warning(
220
+ f"You passed `num_labels={kwargs.get('num_labels')}` which is incompatible to "
221
+ f"the `id2label` map of length `{len(self.id2label)}`."
222
+ )
223
+ # Keys are always strings in JSON so convert ids to int
224
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
225
+
226
+ # BC for rotary embeddings. We will pop out legacy keys from kwargs and rename to new format
227
+ if hasattr(self, "rope_parameters"):
228
+ kwargs = self.convert_rope_params_to_dict(**kwargs)
229
+
230
+ # Parameters for sequence generation saved in the config are popped instead of loading them.
231
+ for parameter_name in GenerationConfig._get_default_generation_params().keys():
232
+ kwargs.pop(parameter_name, None)
233
+
234
+ # Name or path to the pretrained checkpoint
235
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
236
+ self._commit_hash = kwargs.pop("_commit_hash", None)
237
+
238
+ # Attention/Experts implementation to use, if relevant (it sets it recursively on sub-configs)
239
+ self._output_attentions: bool | None = kwargs.pop("output_attentions", False)
240
+ self._attn_implementation: str | None = kwargs.pop("attn_implementation", None)
241
+ self._experts_implementation: str | None = kwargs.pop("experts_implementation", None)
242
+
243
+ # Additional attributes without default values
244
+ for key, value in kwargs.items():
245
+ # Check this to avoid deserializing problematic fields from hub configs - they should use the public field
246
+ if key not in ("_attn_implementation_internal", "_experts_implementation_internal"):
247
+ try:
248
+ setattr(self, key, value)
249
+ except AttributeError as err:
250
+ logger.error(f"Can't set {key} with value {value} for {self}")
251
+ raise err
252
+
253
+ def __init_subclass__(cls, *args, **kwargs):
254
+ super().__init_subclass__(*args, **kwargs)
255
+ cls = dataclass(cls, repr=False)
256
+
257
+ @property
258
+ def name_or_path(self) -> str | None:
259
+ return getattr(self, "_name_or_path", None)
260
+
261
+ @name_or_path.setter
262
+ def name_or_path(self, value):
263
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
264
+
265
+ @property
266
+ def num_labels(self) -> int:
267
+ """
268
+ `int`: The number of labels for classification models.
269
+ """
270
+ return len(self.id2label) if self.id2label is not None else None
271
+
272
+ @num_labels.setter
273
+ def num_labels(self, num_labels: int):
274
+ # we do not store `num_labels` attribute in config, but instead
275
+ # compute it based on the length of the `id2label` map
276
+ if self.id2label is None or self.num_labels != num_labels:
277
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
278
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
279
+
280
+ @property
281
+ def output_attentions(self):
282
+ """
283
+ `bool`: Whether or not the model should returns all attentions.
284
+ """
285
+ return self._output_attentions
286
+
287
+ @output_attentions.setter
288
+ def output_attentions(self, value: bool):
289
+ # If we set `output_attentions` explicitly before the attn implementation, dispatch eager
290
+ if value and self._attn_implementation is None:
291
+ self._attn_implementation = "eager"
292
+ if value and self._attn_implementation != "eager":
293
+ raise ValueError(
294
+ "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
295
+ f"{self._attn_implementation}. Please set it to 'eager' instead."
296
+ )
297
+ self._output_attentions = value
298
+
299
+ @property
300
+ def _attn_implementation(self):
301
+ return self._attn_implementation_internal
302
+
303
+ @_attn_implementation.setter
304
+ def _attn_implementation(self, value: str | dict | None):
305
+ """We set it recursively on the sub-configs as well"""
306
+ # Set if for current config
307
+ current_attn = getattr(self, "_attn_implementation", None)
308
+ attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
309
+ self._attn_implementation_internal = attn_implementation
310
+
311
+ # Set it recursively on the subconfigs
312
+ for subconfig_key in self.sub_configs:
313
+ subconfig = getattr(self, subconfig_key, None)
314
+ if subconfig is not None:
315
+ current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
316
+ sub_implementation = (
317
+ value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
318
+ )
319
+ subconfig._attn_implementation = sub_implementation
320
+
321
+ @property
322
+ def _experts_implementation(self):
323
+ return self._experts_implementation_internal
324
+
325
+ @_experts_implementation.setter
326
+ def _experts_implementation(self, value: str | dict | None):
327
+ """We set it recursively on the sub-configs as well"""
328
+ # Set if for current config
329
+ current_moe = getattr(self, "_experts_implementation", None)
330
+ experts_implementation = value if not isinstance(value, dict) else value.get("", current_moe)
331
+ self._experts_implementation_internal = experts_implementation
332
+
333
+ # Set it recursively on the subconfigs
334
+ for subconfig_key in self.sub_configs:
335
+ subconfig = getattr(self, subconfig_key, None)
336
+ if subconfig is not None:
337
+ current_subconfig_moe = getattr(subconfig, "_experts_implementation", None)
338
+ sub_implementation = (
339
+ value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_moe)
340
+ )
341
+ subconfig._experts_implementation = sub_implementation
342
+
343
+ @property
344
+ def torch_dtype(self):
345
+ logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
346
+ return self.dtype
347
+
348
+ @property
349
+ def use_return_dict(self):
350
+ logger.warning_once("`use_return_dict` is deprecated! Use `return_dict` instead!")
351
+ return self.return_dict
352
+
353
+ @torch_dtype.setter
354
+ def torch_dtype(self, value):
355
+ logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
356
+ self.dtype = value
357
+
358
+ def __setattr__(self, key, value):
359
+ if key in super().__getattribute__("attribute_map"):
360
+ key = super().__getattribute__("attribute_map")[key]
361
+ super().__setattr__(key, value)
362
+
363
+ def __getattribute__(self, key):
364
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
365
+ key = super().__getattribute__("attribute_map")[key]
366
+ return super().__getattribute__(key)
367
+
368
+ def validate_output_attentions(self):
369
+ if self.output_attentions and self._attn_implementation not in ["eager", None]:
370
+ raise ValueError(
371
+ "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
372
+ f"{self._attn_implementation}. Please set it to 'eager' instead."
373
+ )
374
+
375
+ def validate_architecture(self):
376
+ """Part of `@strict`-powered validation. Validates the architecture of the config."""
377
+ if (
378
+ hasattr(self, "head_dim")
379
+ and hasattr(self, "num_heads")
380
+ and hasattr(self, "embed_dim")
381
+ and self.head_dim * self.num_heads != self.embed_dim
382
+ ):
383
+ raise ValueError(
384
+ f"The embed_dim ({self.embed_dim}) is not a multiple of the number of attention "
385
+ f"heads ({self.num_heads})."
386
+ )
387
+
388
+ def validate_token_ids(self):
389
+ """Part of `@strict`-powered validation. Validates the contents of the special tokens."""
390
+ text_config = self.get_text_config(decoder=True)
391
+ vocab_size = getattr(text_config, "vocab_size", None)
392
+ if vocab_size is not None:
393
+ # Check for all special tokens, e..g. pad_token_id, image_token_id, audio_token_id
394
+ for value in text_config:
395
+ if value.endswith("_token_id") and isinstance(value, int) and not 0 <= value < vocab_size:
396
+ # Can't be an exception until we can load configs that fail validation: several configs on the Hub
397
+ # store invalid special tokens, e.g. `pad_token_id=-1`
398
+ logger.warning_once(
399
+ f"Model config: {value} must be `None` or an integer within the vocabulary (between 0 "
400
+ f"and {vocab_size - 1}), got {value}. This may result in unexpected behavior."
401
+ )
402
+
403
+ def validate_layer_type(self):
404
+ """Check that `layer_types` is correctly defined."""
405
+ if not (getattr(self, "layer_types", None) is not None and hasattr(self, "num_hidden_layers")):
406
+ return
407
+ elif not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in self.layer_types):
408
+ raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}")
409
+ elif self.num_hidden_layers is not None and self.num_hidden_layers != len(self.layer_types):
410
+ raise ValueError(
411
+ f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types "
412
+ f"({len(self.layer_types)})"
413
+ )
414
+
415
+ @property
416
+ def rope_scaling(self):
417
+ return self.rope_parameters
418
+
419
+ @rope_scaling.setter
420
+ def rope_scaling(self, value):
421
+ self.rope_parameters = value
422
+
423
+ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
424
+ """
425
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
426
+ [`~PreTrainedConfig.from_pretrained`] class method.
427
+
428
+ Args:
429
+ save_directory (`str` or `os.PathLike`):
430
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
431
+ push_to_hub (`bool`, *optional*, defaults to `False`):
432
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
433
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
434
+ namespace).
435
+ kwargs (`dict[str, Any]`, *optional*):
436
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
437
+ """
438
+ if os.path.isfile(save_directory):
439
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
440
+
441
+ generation_parameters = self._get_generation_parameters()
442
+ if len(generation_parameters) > 0:
443
+ raise ValueError(
444
+ "Some generation parameters are set in the model config. These should go into `model.generation_config`"
445
+ f"as opposed to `model.config`. \nGeneration parameters found: {str(generation_parameters)}",
446
+ )
447
+
448
+ os.makedirs(save_directory, exist_ok=True)
449
+
450
+ if push_to_hub:
451
+ commit_message = kwargs.pop("commit_message", None)
452
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
453
+ repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
454
+ files_timestamps = self._get_files_timestamps(save_directory)
455
+
456
+ # This attribute is important to know on load, but should not be serialized on save.
457
+ if "transformers_weights" in self:
458
+ delattr(self, "transformers_weights")
459
+
460
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
461
+ # loaded from the Hub.
462
+ if self._auto_class is not None:
463
+ custom_object_save(self, save_directory, config=self)
464
+
465
+ # If we save using the predefined names, we can load using `from_pretrained`
466
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
467
+
468
+ # Strict validation at save-time: prevent bad patterns from propagating
469
+ # Using `strict` decorator guarantees that `self.validate` exists , but not all
470
+ # model config might have the decorator added
471
+ if hasattr(self, "validate"):
472
+ self.validate()
473
+ self.to_json_file(output_config_file, use_diff=True)
474
+ logger.info(f"Configuration saved in {output_config_file}")
475
+
476
+ if push_to_hub:
477
+ self._upload_modified_files(
478
+ save_directory,
479
+ repo_id,
480
+ files_timestamps,
481
+ commit_message=commit_message,
482
+ token=kwargs.get("token"),
483
+ )
484
+
485
+ @classmethod
486
+ def from_pretrained(
487
+ cls: type[SpecificPreTrainedConfigType],
488
+ pretrained_model_name_or_path: str | os.PathLike,
489
+ cache_dir: str | os.PathLike | None = None,
490
+ force_download: bool = False,
491
+ local_files_only: bool = False,
492
+ token: str | bool | None = None,
493
+ revision: str = "main",
494
+ **kwargs,
495
+ ) -> SpecificPreTrainedConfigType:
496
+ r"""
497
+ Instantiate a [`PreTrainedConfig`] (or a derived class) from a pretrained model configuration.
498
+
499
+ Args:
500
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
501
+ This can be either:
502
+
503
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
504
+ huggingface.co.
505
+ - a path to a *directory* containing a configuration file saved using the
506
+ [`~PreTrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
507
+ - a path to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
508
+ cache_dir (`str` or `os.PathLike`, *optional*):
509
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
510
+ standard cache should not be used.
511
+ force_download (`bool`, *optional*, defaults to `False`):
512
+ Whether or not to force to (re-)download the configuration files and override the cached versions if
513
+ they exist.
514
+ proxies (`dict[str, str]`, *optional*):
515
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
516
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
517
+ token (`str` or `bool`, *optional*):
518
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
519
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
520
+ revision (`str`, *optional*, defaults to `"main"`):
521
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
522
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
523
+ identifier allowed by git.
524
+
525
+ <Tip>
526
+
527
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
528
+
529
+ </Tip>
530
+
531
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
532
+ If `False`, then this function returns just the final configuration object.
533
+
534
+ If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
535
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
536
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
537
+ subfolder (`str`, *optional*, defaults to `""`):
538
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
539
+ specify the folder name here.
540
+ kwargs (`dict[str, Any]`, *optional*):
541
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
542
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
543
+ by the `return_unused_kwargs` keyword parameter.
544
+
545
+ Returns:
546
+ [`PreTrainedConfig`]: The configuration object instantiated from this pretrained model.
547
+
548
+ Examples:
549
+
550
+ ```python
551
+ # We can't instantiate directly the base class *PreTrainedConfig* so let's show the examples on a
552
+ # derived class: BertConfig
553
+ config = BertConfig.from_pretrained(
554
+ "google-bert/bert-base-uncased"
555
+ ) # Download configuration from huggingface.co and cache.
556
+ config = BertConfig.from_pretrained(
557
+ "./test/saved_model/"
558
+ ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
559
+ config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
560
+ config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
561
+ assert config.output_attentions == True
562
+ config, unused_kwargs = BertConfig.from_pretrained(
563
+ "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
564
+ )
565
+ assert config.output_attentions == True
566
+ assert unused_kwargs == {"foo": False}
567
+ ```"""
568
+ kwargs["cache_dir"] = cache_dir
569
+ kwargs["force_download"] = force_download
570
+ kwargs["local_files_only"] = local_files_only
571
+ kwargs["revision"] = revision
572
+
573
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
574
+ if cls.base_config_key and cls.base_config_key in config_dict:
575
+ config_dict = config_dict[cls.base_config_key]
576
+
577
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
578
+ # sometimes the config has no `base_config_key` if the config is used in several composite models
579
+ # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
580
+ for v in config_dict.values():
581
+ if isinstance(v, dict) and v.get("model_type") == cls.model_type:
582
+ config_dict = v
583
+
584
+ # raise warning only if we still can't see a match in `model_type`
585
+ if config_dict["model_type"] != cls.model_type:
586
+ logger.warning(
587
+ f"You are using a model of type `{config_dict['model_type']}` to instantiate a model of type "
588
+ f"`{cls.model_type}`. This may be expected if you are loading a checkpoint that shares a subset "
589
+ f"of the architecture (e.g., loading a `sam2_video` checkpoint into `Sam2Model`), but is otherwise "
590
+ f"not supported and can yield errors. Please verify that the checkpoint is compatible with the "
591
+ f"model you are instantiating."
592
+ )
593
+
594
+ return cls.from_dict(config_dict, **kwargs)
595
+
596
+ @classmethod
597
+ def get_config_dict(
598
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
599
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
600
+ """
601
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
602
+ [`PreTrainedConfig`] using `from_dict`.
603
+
604
+ Parameters:
605
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
606
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
607
+
608
+ Returns:
609
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
610
+
611
+ """
612
+ original_kwargs = copy.deepcopy(kwargs)
613
+ # Get config dict associated with the base config file
614
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
615
+ if config_dict is None:
616
+ return {}, kwargs
617
+ if "_commit_hash" in config_dict:
618
+ original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
619
+
620
+ # That config file may point us toward another config file to use.
621
+ if "configuration_files" in config_dict:
622
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
623
+ config_dict, kwargs = cls._get_config_dict(
624
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
625
+ )
626
+
627
+ return config_dict, kwargs
628
+
629
+ @classmethod
630
+ def _get_config_dict(
631
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
632
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
633
+ cache_dir = kwargs.pop("cache_dir", None)
634
+ force_download = kwargs.pop("force_download", False)
635
+ proxies = kwargs.pop("proxies", None)
636
+ token = kwargs.pop("token", None)
637
+ local_files_only = kwargs.pop("local_files_only", False)
638
+ revision = kwargs.pop("revision", None)
639
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
640
+ subfolder = kwargs.pop("subfolder", "")
641
+ from_pipeline = kwargs.pop("_from_pipeline", None)
642
+ from_auto_class = kwargs.pop("_from_auto", False)
643
+ commit_hash = kwargs.pop("_commit_hash", None)
644
+
645
+ gguf_file = kwargs.get("gguf_file")
646
+
647
+ if trust_remote_code is True:
648
+ logger.warning(
649
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
650
+ " ignored."
651
+ )
652
+
653
+ user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
654
+ if from_pipeline is not None:
655
+ user_agent["using_pipeline"] = from_pipeline
656
+
657
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
658
+
659
+ is_local = os.path.isdir(pretrained_model_name_or_path)
660
+ if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
661
+ # Special case when pretrained_model_name_or_path is a local file
662
+ resolved_config_file = pretrained_model_name_or_path
663
+ is_local = True
664
+ else:
665
+ configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
666
+
667
+ try:
668
+ # Load from local folder or from cache or download from model Hub and cache
669
+ resolved_config_file = cached_file(
670
+ pretrained_model_name_or_path,
671
+ configuration_file,
672
+ cache_dir=cache_dir,
673
+ force_download=force_download,
674
+ proxies=proxies,
675
+ local_files_only=local_files_only,
676
+ token=token,
677
+ user_agent=user_agent,
678
+ revision=revision,
679
+ subfolder=subfolder,
680
+ _commit_hash=commit_hash,
681
+ )
682
+ if resolved_config_file is None:
683
+ return None, kwargs
684
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
685
+ except OSError:
686
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
687
+ # the original exception.
688
+ raise
689
+ except Exception:
690
+ # For any other exception, we throw a generic error.
691
+ raise OSError(
692
+ f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
693
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
694
+ f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
695
+ f" containing a {configuration_file} file"
696
+ )
697
+
698
+ try:
699
+ if gguf_file:
700
+ config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
701
+ else:
702
+ # Load config dict
703
+ config_dict = cls._dict_from_json_file(resolved_config_file)
704
+
705
+ config_dict["_commit_hash"] = commit_hash
706
+ except (json.JSONDecodeError, UnicodeDecodeError):
707
+ raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.")
708
+
709
+ if is_local:
710
+ logger.info(f"loading configuration file {resolved_config_file}")
711
+ else:
712
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
713
+
714
+ # timm models are not saved with the model_type in the config file
715
+ if "model_type" not in config_dict and is_timm_config_dict(config_dict):
716
+ config_dict["model_type"] = "timm_wrapper"
717
+
718
+ return config_dict, kwargs
719
+
720
+ @classmethod
721
+ def from_dict(
722
+ cls: type[SpecificPreTrainedConfigType], config_dict: dict[str, Any], **kwargs
723
+ ) -> SpecificPreTrainedConfigType:
724
+ """
725
+ Instantiates a [`PreTrainedConfig`] from a Python dictionary of parameters.
726
+
727
+ Args:
728
+ config_dict (`dict[str, Any]`):
729
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
730
+ retrieved from a pretrained checkpoint by leveraging the [`~PreTrainedConfig.get_config_dict`] method.
731
+ kwargs (`dict[str, Any]`):
732
+ Additional parameters from which to initialize the configuration object.
733
+
734
+ Returns:
735
+ [`PreTrainedConfig`]: The configuration object instantiated from those parameters.
736
+ """
737
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
738
+
739
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
740
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
741
+ kwargs.setdefault("_commit_hash", config_dict["_commit_hash"])
742
+
743
+ # To remove arg here are those passed along for our internal telemetry but we still need to remove them
744
+ to_remove = ["_from_auto", "_from_pipeline"]
745
+ valid_fields = [
746
+ "num_labels",
747
+ "attn_implementation",
748
+ "experts_implementation",
749
+ "output_attentions",
750
+ "torch_dtype",
751
+ "dtype",
752
+ "name_or_path",
753
+ ]
754
+ for key, value in kwargs.items():
755
+ if key in valid_fields:
756
+ if key not in ["torch_dtype", "dtype"]:
757
+ config_dict[key] = value
758
+ to_remove.append(key)
759
+ elif value != "auto":
760
+ config_dict[key] = value
761
+
762
+ config = cls(**config_dict)
763
+
764
+ for key, value in kwargs.items():
765
+ if hasattr(config, key):
766
+ current_attr = getattr(config, key)
767
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
768
+ # We need to update only custom kwarg values instead and keep other attr in subconfig.
769
+ if isinstance(current_attr, PreTrainedConfig) and isinstance(value, dict):
770
+ current_attr_updated = current_attr.to_dict()
771
+ current_attr_updated.update(value)
772
+ value = current_attr.__class__(**current_attr_updated)
773
+ setattr(config, key, value)
774
+ to_remove.append(key)
775
+
776
+ for key in to_remove:
777
+ kwargs.pop(key, None)
778
+
779
+ logger.info(f"Model config {config}")
780
+ if return_unused_kwargs:
781
+ return config, kwargs
782
+ else:
783
+ return config
784
+
785
+ @classmethod
786
+ def from_json_file(
787
+ cls: type[SpecificPreTrainedConfigType], json_file: str | os.PathLike
788
+ ) -> SpecificPreTrainedConfigType:
789
+ """
790
+ Instantiates a [`PreTrainedConfig`] from the path to a JSON file of parameters.
791
+
792
+ Args:
793
+ json_file (`str` or `os.PathLike`):
794
+ Path to the JSON file containing the parameters.
795
+
796
+ Returns:
797
+ [`PreTrainedConfig`]: The configuration object instantiated from that JSON file.
798
+
799
+ """
800
+ config_dict = cls._dict_from_json_file(json_file)
801
+ return cls(**config_dict)
802
+
803
+ @classmethod
804
+ def _dict_from_json_file(cls, json_file: str | os.PathLike):
805
+ with open(json_file, encoding="utf-8") as reader:
806
+ text = reader.read()
807
+ config_dict = json.loads(text)
808
+
809
+ return cls._decode_special_floats(config_dict)
810
+
811
+ @classmethod
812
+ def _encode_special_floats(cls, obj: Any) -> Any:
813
+ """
814
+ Iterates over the passed object and encode specific floats that cannot be JSON-serialized. Python's JSON
815
+ engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
816
+
817
+ It serializes floats like `Infinity` as an object: `{'__float__': Infinity}`.
818
+ """
819
+ if isinstance(obj, float):
820
+ if math.isnan(obj):
821
+ return {_FLOAT_TAG_KEY: "NaN"}
822
+ if obj == float("inf"):
823
+ return {_FLOAT_TAG_KEY: "Infinity"}
824
+ if obj == float("-inf"):
825
+ return {_FLOAT_TAG_KEY: "-Infinity"}
826
+ return obj
827
+
828
+ if isinstance(obj, dict):
829
+ return {k: cls._encode_special_floats(v) for k, v in obj.items()}
830
+
831
+ if isinstance(obj, (list, tuple)):
832
+ return [cls._encode_special_floats(v) for v in obj]
833
+
834
+ return obj
835
+
836
+ @classmethod
837
+ def _decode_special_floats(cls, obj: Any) -> Any:
838
+ """
839
+ Iterates over the passed object and decode specific floats that cannot be JSON-serialized. Python's JSON
840
+ engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
841
+
842
+ This method deserializes objects like `{'__float__': Infinity}` to their float values like `Infinity`.
843
+ """
844
+ if isinstance(obj, dict):
845
+ if set(obj.keys()) == {_FLOAT_TAG_KEY} and isinstance(obj[_FLOAT_TAG_KEY], str):
846
+ tag = obj[_FLOAT_TAG_KEY]
847
+ if tag in _FLOAT_TAG_VALUES:
848
+ return _FLOAT_TAG_VALUES[tag]
849
+ return obj
850
+
851
+ return {k: cls._decode_special_floats(v) for k, v in obj.items()}
852
+
853
+ if isinstance(obj, list):
854
+ return [cls._decode_special_floats(v) for v in obj]
855
+
856
+ return obj
857
+
858
+ def __eq__(self, other):
859
+ return isinstance(other, PreTrainedConfig) and (self.__dict__ == other.__dict__)
860
+
861
+ def __repr__(self):
862
+ return f"{self.__class__.__name__} {self.to_json_string()}"
863
+
864
+ def __iter__(self):
865
+ yield from self.__dict__
866
+
867
+ def to_diff_dict(self) -> dict[str, Any]:
868
+ """
869
+ Removes all attributes from the configuration that correspond to the default config attributes for
870
+ better readability, while always retaining the `config` attribute from the class. Serializes to a
871
+ Python dictionary.
872
+
873
+ Returns:
874
+ dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
875
+ """
876
+ config_dict = self.to_dict()
877
+
878
+ # Get the default config dict (from a fresh PreTrainedConfig instance)
879
+ default_config_dict = PreTrainedConfig().to_dict()
880
+
881
+ # get class specific config dict
882
+ class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
883
+
884
+ serializable_config_dict = {}
885
+
886
+ # Only serialize values that differ from the default config,
887
+ # except always keep the 'config' attribute.
888
+ for key, value in config_dict.items():
889
+ if (
890
+ isinstance(getattr(self, key, None), PreTrainedConfig)
891
+ and key in class_config_dict
892
+ and isinstance(class_config_dict[key], dict)
893
+ ):
894
+ # For nested configs we need to clean the diff recursively
895
+ diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
896
+ if "model_type" in value:
897
+ # Needs to be set even if it's not in the diff
898
+ diff["model_type"] = value["model_type"]
899
+
900
+ serializable_config_dict[key] = diff
901
+ elif (
902
+ key not in default_config_dict
903
+ or key == "transformers_version"
904
+ or key == "vocab_file"
905
+ or value != default_config_dict[key]
906
+ or (key in default_config_dict and value != class_config_dict.get(key, value))
907
+ ):
908
+ serializable_config_dict[key] = value
909
+
910
+ self._remove_keys_not_serialized(serializable_config_dict)
911
+
912
+ # Key removed only in diff dict
913
+ if "_name_or_path" in serializable_config_dict:
914
+ del serializable_config_dict["_name_or_path"]
915
+
916
+ if hasattr(self, "quantization_config"):
917
+ serializable_config_dict["quantization_config"] = (
918
+ self.quantization_config.to_dict()
919
+ if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
920
+ else self.quantization_config
921
+ )
922
+ self.dict_dtype_to_str(serializable_config_dict)
923
+
924
+ return serializable_config_dict
925
+
926
+ def to_dict(self) -> dict[str, Any]:
927
+ """
928
+ Serializes this instance to a Python dictionary.
929
+
930
+ Returns:
931
+ `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
932
+ """
933
+ output = copy.deepcopy(self.__dict__)
934
+ if hasattr(self.__class__, "model_type"):
935
+ output["model_type"] = self.__class__.model_type
936
+
937
+ # Transformers version when serializing the model
938
+ output["transformers_version"] = __version__
939
+
940
+ # Pop "kwargs" since they are unpacked and set in the post init
941
+ output.pop("kwargs", None)
942
+
943
+ def to_list(value):
944
+ if isinstance(value, tuple):
945
+ value = [to_list(item) for item in value]
946
+ return value
947
+
948
+ for key, value in output.items():
949
+ # Deal with nested configs like CLIP
950
+ if isinstance(value, PreTrainedConfig):
951
+ value = value.to_dict()
952
+ del value["transformers_version"]
953
+
954
+ # Some models have defaults as tuples because dataclass
955
+ # doesn't allow mutables. Let's convert back to `list``
956
+ elif isinstance(value, tuple):
957
+ value = to_list(value)
958
+
959
+ output[key] = value
960
+
961
+ self._remove_keys_not_serialized(output)
962
+
963
+ if hasattr(self, "quantization_config"):
964
+ output["quantization_config"] = (
965
+ self.quantization_config.to_dict()
966
+ if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
967
+ else self.quantization_config
968
+ )
969
+ self.dict_dtype_to_str(output)
970
+
971
+ return output
972
+
973
+ def to_json_string(self, use_diff: bool = True) -> str:
974
+ """
975
+ Serializes this instance to a JSON string.
976
+
977
+ Args:
978
+ use_diff (`bool`, *optional*, defaults to `True`):
979
+ If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
980
+ is serialized to JSON string.
981
+
982
+ Returns:
983
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
984
+ """
985
+ if use_diff is True:
986
+ config_dict = self.to_diff_dict()
987
+ else:
988
+ config_dict = self.to_dict()
989
+
990
+ # Handle +/-Infinity and NaNs
991
+ config_dict = self._encode_special_floats(config_dict)
992
+
993
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
994
+
995
+ def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):
996
+ """
997
+ Save this instance to a JSON file.
998
+
999
+ Args:
1000
+ json_file_path (`str` or `os.PathLike`):
1001
+ Path to the JSON file in which this configuration instance's parameters will be saved.
1002
+ use_diff (`bool`, *optional*, defaults to `True`):
1003
+ If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
1004
+ is serialized to JSON file.
1005
+ """
1006
+ with open(json_file_path, "w", encoding="utf-8") as writer:
1007
+ writer.write(self.to_json_string(use_diff=use_diff))
1008
+
1009
+ def update(self, config_dict: dict[str, Any]):
1010
+ """
1011
+ Updates attributes of this class with attributes from `config_dict`.
1012
+
1013
+ Args:
1014
+ config_dict (`dict[str, Any]`): Dictionary of attributes that should be updated for this class.
1015
+ """
1016
+ for key, value in config_dict.items():
1017
+ setattr(self, key, value)
1018
+
1019
+ def update_from_string(self, update_str: str):
1020
+ """
1021
+ Updates attributes of this class with attributes from `update_str`.
1022
+
1023
+ The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
1024
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
1025
+
1026
+ The keys to change have to already exist in the config object.
1027
+
1028
+ Args:
1029
+ update_str (`str`): String with attributes that should be updated for this class.
1030
+
1031
+ """
1032
+
1033
+ d = dict(x.split("=") for x in update_str.split(","))
1034
+ for k, v in d.items():
1035
+ if not hasattr(self, k):
1036
+ raise ValueError(f"key {k} isn't in the original config dict")
1037
+
1038
+ old_v = getattr(self, k)
1039
+ if isinstance(old_v, bool):
1040
+ if v.lower() in ["true", "1", "y", "yes"]:
1041
+ v = True
1042
+ elif v.lower() in ["false", "0", "n", "no"]:
1043
+ v = False
1044
+ else:
1045
+ raise ValueError(f"can't derive true or false from {v} (key {k})")
1046
+ elif isinstance(old_v, int):
1047
+ v = int(v)
1048
+ elif isinstance(old_v, float):
1049
+ v = float(v)
1050
+ elif not isinstance(old_v, str):
1051
+ raise TypeError(
1052
+ f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
1053
+ )
1054
+
1055
+ setattr(self, k, v)
1056
+
1057
+ def dict_dtype_to_str(self, d: dict[str, Any]) -> None:
1058
+ """
1059
+ Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None,
1060
+ converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
1061
+ string, which can then be stored in the json format.
1062
+ """
1063
+ if d.get("dtype") is not None:
1064
+ if isinstance(d["dtype"], dict):
1065
+ d["dtype"] = {k: str(v).split(".")[-1] for k, v in d["dtype"].items()}
1066
+ # models like Emu3 can have "dtype" as token in config's vocabulary map,
1067
+ # so we also exclude int type here to avoid error in this special case.
1068
+ elif not isinstance(d["dtype"], (str, int)):
1069
+ d["dtype"] = str(d["dtype"]).split(".")[1]
1070
+ for value in d.values():
1071
+ if isinstance(value, dict):
1072
+ self.dict_dtype_to_str(value)
1073
+
1074
+ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
1075
+ """
1076
+ Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
1077
+ Runs recursive check on the dict, to remove from all sub configs.
1078
+ """
1079
+
1080
+ for key_to_remove in [
1081
+ "_is_quantized",
1082
+ "_auto_class",
1083
+ "_commit_hash",
1084
+ "_attn_implementation_internal",
1085
+ "_experts_implementation_internal",
1086
+ "ignore_keys_at_rope_validation",
1087
+ "base_model_tp_plan",
1088
+ "base_model_pp_plan",
1089
+ ]:
1090
+ d.pop(key_to_remove, None)
1091
+
1092
+ if "_output_attentions" in d:
1093
+ d["output_attentions"] = d.pop("_output_attentions")
1094
+
1095
+ for value in d.values():
1096
+ if isinstance(value, dict):
1097
+ self._remove_keys_not_serialized(value)
1098
+
1099
+ @classmethod
1100
+ def register_for_auto_class(cls, auto_class="AutoConfig"):
1101
+ """
1102
+ Register this class with a given auto class. This should only be used for custom configurations as the ones in
1103
+ the library are already mapped with `AutoConfig`.
1104
+
1105
+
1106
+
1107
+ Args:
1108
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
1109
+ The auto class to register this new configuration with.
1110
+ """
1111
+ if not isinstance(auto_class, str):
1112
+ auto_class = auto_class.__name__
1113
+
1114
+ import transformers.models.auto as auto_module
1115
+
1116
+ if not hasattr(auto_module, auto_class):
1117
+ raise ValueError(f"{auto_class} is not a valid auto class.")
1118
+
1119
+ cls._auto_class = auto_class
1120
+
1121
+ def _get_generation_parameters(self) -> dict[str, Any]:
1122
+ """
1123
+ Checks if there are generation parameters in `PreTrainedConfig` instance. Note that
1124
+ we should not save generation params in PreTrainedConfig, and we will raise error
1125
+ if there are any.
1126
+ """
1127
+ generation_params = {}
1128
+ default_config = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
1129
+ for key in GenerationConfig._get_default_generation_params().keys():
1130
+ if key == "use_cache":
1131
+ continue # common key for most models
1132
+ if hasattr(self, key) and getattr(self, key) is not None and key not in default_config:
1133
+ generation_params[key] = getattr(self, key)
1134
+
1135
+ return generation_params
1136
+
1137
+ def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig":
1138
+ """
1139
+ Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
1140
+ `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
1141
+ which is useful on models that have both text input and output modalities.
1142
+
1143
+ There are three possible outcomes of using this method:
1144
+ 1. On most models, it returns the original config instance itself.
1145
+ 2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
1146
+ of valid names.
1147
+ 3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
1148
+
1149
+ Args:
1150
+ decoder (`Optional[bool]`, *optional*):
1151
+ If set to `True`, then only search for decoder config names.
1152
+ encoder (`Optional[bool]`, *optional*):
1153
+ If set to `True`, then only search for encoder config names.
1154
+ """
1155
+ return_both = decoder == encoder # both unset or both set -> search all possible names
1156
+
1157
+ decoder_possible_text_config_names = ("decoder", "generator", "text_config")
1158
+ encoder_possible_text_config_names = ("text_encoder",)
1159
+ if return_both:
1160
+ possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
1161
+ elif decoder:
1162
+ possible_text_config_names = decoder_possible_text_config_names
1163
+ else:
1164
+ possible_text_config_names = encoder_possible_text_config_names
1165
+
1166
+ valid_text_config_names = []
1167
+ for text_config_name in possible_text_config_names:
1168
+ if hasattr(self, text_config_name):
1169
+ text_config = getattr(self, text_config_name, None)
1170
+ if text_config is not None:
1171
+ valid_text_config_names += [text_config_name]
1172
+
1173
+ if len(valid_text_config_names) > 1:
1174
+ raise ValueError(
1175
+ f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
1176
+ "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
1177
+ "e.g. `text_config = config.sub_config_name`"
1178
+ )
1179
+ elif len(valid_text_config_names) == 1:
1180
+ config_to_return = getattr(self, valid_text_config_names[0])
1181
+ else:
1182
+ config_to_return = self
1183
+
1184
+ # handle legacy models with flat config structure, when we only want one of the configs
1185
+ if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
1186
+ config_to_return = copy.deepcopy(config_to_return)
1187
+ prefix_to_keep = "decoder" if decoder else "encoder"
1188
+ for key in config_to_return.to_dict():
1189
+ # NOTE: We can't discard keys because:
1190
+ # 1) we can't truly delete a cls attribte on a dataclass; 2) we can't set the value to `None` due to
1191
+ # strict validation. So we just keep it as is, since there are only a couple old models falling in this condition
1192
+ if key.startswith(prefix_to_keep):
1193
+ # [encoder/decoder]_layers -> num_hidden_layers
1194
+ if key == prefix_to_keep + "_layers":
1195
+ new_key = "num_hidden_layers"
1196
+ # [encoder/decoder]_attention_heads -> num_attention_heads
1197
+ elif key == prefix_to_keep + "_attention_heads":
1198
+ new_key = "num_attention_heads"
1199
+ # e.g. encoder_hidden_act -> hidden_act
1200
+ else:
1201
+ new_key = key[len(prefix_to_keep) + 1 :]
1202
+
1203
+ # Does the class map the new key into a different attribute name at read time? if so, let's write
1204
+ # into that attribute instead
1205
+ if new_key in config_to_return.attribute_map:
1206
+ new_key = config_to_return.attribute_map[new_key]
1207
+
1208
+ value = getattr(config_to_return, key)
1209
+ delattr(config_to_return, key)
1210
+ setattr(config_to_return, new_key, value)
1211
+
1212
+ return config_to_return
1213
+
1214
+
1215
+ def get_configuration_file(configuration_files: list[str]) -> str:
1216
+ """
1217
+ Get the configuration file to use for this version of transformers.
1218
+
1219
+ Args:
1220
+ configuration_files (`list[str]`): The list of available configuration files.
1221
+
1222
+ Returns:
1223
+ `str`: The configuration file to use.
1224
+ """
1225
+ configuration_files_map = {}
1226
+ for file_name in configuration_files:
1227
+ if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json":
1228
+ v = file_name.removeprefix("config.").removesuffix(".json")
1229
+ configuration_files_map[v] = file_name
1230
+ available_versions = sorted(configuration_files_map.keys())
1231
+
1232
+ # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
1233
+ configuration_file = CONFIG_NAME
1234
+ transformers_version = version.parse(__version__)
1235
+ for v in available_versions:
1236
+ if version.parse(v) <= transformers_version:
1237
+ configuration_file = configuration_files_map[v]
1238
+ else:
1239
+ # No point going further since the versions are sorted.
1240
+ break
1241
+
1242
+ return configuration_file
1243
+
1244
+
1245
+ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
1246
+ """
1247
+ Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
1248
+ values from `dict_a` that are different from values in `dict_b`.
1249
+
1250
+ dict_b : the default config dictionary. We want to remove values that are in this one
1251
+ """
1252
+ diff = {}
1253
+ default = config_obj.__class__().to_dict() if config_obj is not None else {}
1254
+ for key, value in dict_a.items():
1255
+ obj_value = getattr(config_obj, str(key), None)
1256
+ if isinstance(obj_value, PreTrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
1257
+ diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
1258
+ diff[key] = diff_value
1259
+ elif key not in dict_b or (value != default[key]):
1260
+ diff[key] = value
1261
+ return diff
1262
+
1263
+
1264
+ PreTrainedConfig.push_to_hub = copy_func(PreTrainedConfig.push_to_hub)
1265
+ if PreTrainedConfig.push_to_hub.__doc__ is not None:
1266
+ PreTrainedConfig.push_to_hub.__doc__ = PreTrainedConfig.push_to_hub.__doc__.format(
1267
+ object="config", object_class="AutoConfig", object_files="configuration file"
1268
+ )
1269
+
1270
+
1271
+ # The alias is only here for BC - we did not have the correct CamelCasing before
1272
+ PretrainedConfig = PreTrainedConfig
1273
+
1274
+
1275
+ def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True):
1276
+ logger.warning(
1277
+ "`layer_type_validation` is deprecated and will be removed in v5.20. "
1278
+ "Use `PreTrainedConfig.validate_layer_type` instead"
1279
+ )
1280
+
1281
+ if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
1282
+ raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
1283
+ if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
1284
+ raise ValueError(
1285
+ f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
1286
+ f"({len(layer_types)})"
1287
+ )