HemanthSai7 commited on
Commit
d1e491b
·
verified ·
1 Parent(s): 4ca2f01

Delete configuration_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_utils.py +0 -1287
configuration_utils.py DELETED
@@ -1,1287 +0,0 @@
1
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Configuration base class and utilities."""
16
-
17
- import copy
18
- import json
19
- import math
20
- import os
21
- from collections.abc import Sequence
22
- from dataclasses import dataclass
23
- from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, Union
24
-
25
- from huggingface_hub import create_repo
26
- from huggingface_hub.dataclasses import strict
27
- from packaging import version
28
-
29
- from . import __version__
30
- from .dynamic_module_utils import custom_object_save
31
- from .generation.configuration_utils import GenerationConfig
32
- from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
33
- from .modeling_rope_utils import RotaryEmbeddingConfigMixin
34
- from .tokenization_utils_base import PreTrainedTokenizerBase
35
- from .utils import (
36
- CONFIG_NAME,
37
- PushToHubMixin,
38
- cached_file,
39
- copy_func,
40
- extract_commit_hash,
41
- is_torch_available,
42
- logging,
43
- )
44
- from .utils.generic import is_timm_config_dict
45
-
46
-
47
- if TYPE_CHECKING:
48
- import torch
49
-
50
-
51
- logger = logging.get_logger(__name__)
52
-
53
-
54
- # type hinting: specifying the type of config class that inherits from PreTrainedConfig
55
- SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
56
-
57
- _FLOAT_TAG_KEY = "__float__"
58
- _FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")}
59
-
60
-
61
- ALLOWED_LAYER_TYPES = (
62
- "full_attention",
63
- "sliding_attention",
64
- "chunked_attention",
65
- "linear_attention", # used in minimax
66
- "conv", # used in LFMv2
67
- "mamba",
68
- "attention",
69
- "sparse",
70
- "dense",
71
- )
72
-
73
-
74
- @strict(accept_kwargs=True)
75
- @dataclass(repr=False)
76
- class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
77
- # no-format
78
- r"""
79
- Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
80
- methods for loading/downloading/saving configurations.
81
-
82
- <Tip>
83
-
84
- A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
85
- initialize a model does **not** load the model weights. It only affects the model's configuration.
86
-
87
- </Tip>
88
-
89
- Class attributes (overridden by derived classes):
90
-
91
- - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
92
- the correct object in [`~transformers.AutoConfig`].
93
- - **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
94
- Some configurations requires inputs to be defined at init and have no default values, usually these are composite configs,
95
- (but not necessarily) such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be initialized from
96
- two or more configs of type [`~transformers.PreTrainedConfig`].
97
- - **keys_to_ignore_at_inference** (`list[str]`) -- A list of keys to ignore by default when looking at dictionary
98
- outputs of the model during inference.
99
- - **attribute_map** (`dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
100
- naming of attributes.
101
- - **base_model_tp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
102
- parallel plan applied to the sub-module when `model.tensor_parallel` is called.
103
- - **base_model_pp_plan** (`dict[str, tuple[list[str]]]`) -- A dict that maps child-modules of a base model to a
104
- pipeline parallel plan that enables users to place the child-module on the appropriate device.
105
-
106
- Common attributes (present in all subclasses):
107
-
108
- - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
109
- embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
110
- - **hidden_size** (`int`) -- The hidden size of the model.
111
- - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
112
- model.
113
- - **num_hidden_layers** (`int`) -- The number of blocks in the model.
114
-
115
- <Tip warning={true}>
116
-
117
- Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
118
- some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
119
- them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
120
- information about the individual parameters.
121
-
122
- </Tip>
123
-
124
- Arg:
125
- name_or_path (`str`, *optional*, defaults to `""`):
126
- Store the string that was passed to [`PreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path`
127
- if the configuration was created with such a method.
128
- output_hidden_states (`bool`, *optional*, defaults to `False`):
129
- Whether or not the model should return all hidden-states.
130
- output_attentions (`bool`, *optional*, defaults to `False`):
131
- Whether or not the model should returns all attentions.
132
- return_dict (`bool`, *optional*, defaults to `True`):
133
- Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
134
- is_encoder_decoder (`bool`, *optional*, defaults to `False`):
135
- Whether the model is used as an encoder/decoder or not.
136
- chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
137
- The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
138
- the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
139
- sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
140
- Forward Chunking work?](../glossary.html#feed-forward-chunking).
141
-
142
- > Parameters for fine-tuning tasks
143
-
144
- architectures (`list[str]`, *optional*):
145
- Model architectures that can be used with the model pretrained weights.
146
- id2label (`dict[int, str]`, *optional*):
147
- A map from index (for instance prediction index, or target index) to label.
148
- label2id (`dict[str, int]`, *optional*):
149
- A map from label to index for the model.
150
- num_labels (`int`, *optional*):
151
- Number of labels to use in the last layer added to the model, typically for a classification task.
152
- problem_type (`str`, *optional*):
153
- Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
154
- `"single_label_classification"` or `"multi_label_classification"`.
155
-
156
- > PyTorch specific parameters
157
-
158
- dtype (`str`, *optional*):
159
- The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
160
- (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
161
- model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
162
- `float16` weights.
163
- """
164
-
165
- # Class attributes that we don't want to save or have in `self.__dict__`
166
- # They are not supposed to be set/changed by users. Each field is set when
167
- # creating a model class
168
- base_config_key: ClassVar[str] = ""
169
- sub_configs: ClassVar[dict[str, type["PreTrainedConfig"]]] = {}
170
- has_no_defaults_at_init: ClassVar[bool] = False
171
- keys_to_ignore_at_inference: ClassVar[list[str]] = []
172
- attribute_map: ClassVar[dict[str, str]] = {}
173
- base_model_tp_plan: ClassVar[dict[str, Any] | None] = None
174
- base_model_pp_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
175
- base_model_ep_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
176
- _auto_class: ClassVar[str | None] = None
177
-
178
- # Attributes set internally when saving and used to infer model
179
- # class for `Auto` mapping
180
- model_type: ClassVar[str] = ""
181
- transformers_version: str | None = None
182
- architectures: list[str] | None = None
183
-
184
- # Common attributes for all models
185
- output_hidden_states: bool | None = False
186
- return_dict: bool | None = True
187
- dtype: Union[str, "torch.dtype"] | None = None
188
- chunk_size_feed_forward: int = 0
189
- is_encoder_decoder: bool = False
190
-
191
- # Fine-tuning task arguments
192
- id2label: dict[int, str] | dict[str, str] | None = None
193
- label2id: dict[str, int] | dict[str, str] | None = None
194
- problem_type: Literal["regression", "single_label_classification", "multi_label_classification"] | None = None
195
-
196
- # Tokenizer kwargs
197
- tokenizer_class: str | PreTrainedTokenizerBase | None = None
198
-
199
- def __post_init__(self, **kwargs):
200
- # BC for the `torch_dtype` argument instead of the simpler `dtype`
201
- # Do not warn, as it would otherwise always be triggered since most configs on the hub have `torch_dtype`
202
- if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
203
- # If both are provided, keep `dtype`
204
- self.dtype = self.dtype if self.dtype is not None else torch_dtype
205
- if self.dtype is not None and isinstance(self.dtype, str) and is_torch_available():
206
- # we will start using self.dtype in v5, but to be consistent with
207
- # from_pretrained's dtype arg convert it to an actual torch.dtype object
208
- import torch
209
-
210
- self.dtype = getattr(torch, self.dtype)
211
-
212
- # Keep the default value of `num_labels=2` in case users have saved a classfier with 2 labels
213
- # Our configs prev wouldn't save `id2label` for 2 labels because it is the default. In all other
214
- # cases we expect the config dict to have an `id2label` field if it's a clf model, or not otherwise
215
- if self.id2label is None:
216
- self.num_labels = kwargs.get("num_labels", 2)
217
- else:
218
- if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"):
219
- logger.warning(
220
- f"You passed `num_labels={kwargs.get('num_labels')}` which is incompatible to "
221
- f"the `id2label` map of length `{len(self.id2label)}`."
222
- )
223
- # Keys are always strings in JSON so convert ids to int
224
- self.id2label = {int(key): value for key, value in self.id2label.items()}
225
-
226
- # BC for rotary embeddings. We will pop out legacy keys from kwargs and rename to new format
227
- if hasattr(self, "rope_parameters"):
228
- kwargs = self.convert_rope_params_to_dict(**kwargs)
229
-
230
- # Parameters for sequence generation saved in the config are popped instead of loading them.
231
- for parameter_name in GenerationConfig._get_default_generation_params().keys():
232
- kwargs.pop(parameter_name, None)
233
-
234
- # Name or path to the pretrained checkpoint
235
- self._name_or_path = str(kwargs.pop("name_or_path", ""))
236
- self._commit_hash = kwargs.pop("_commit_hash", None)
237
-
238
- # Attention/Experts implementation to use, if relevant (it sets it recursively on sub-configs)
239
- self._output_attentions: bool | None = kwargs.pop("output_attentions", False)
240
- self._attn_implementation: str | None = kwargs.pop("attn_implementation", None)
241
- self._experts_implementation: str | None = kwargs.pop("experts_implementation", None)
242
-
243
- # Additional attributes without default values
244
- for key, value in kwargs.items():
245
- # Check this to avoid deserializing problematic fields from hub configs - they should use the public field
246
- if key not in ("_attn_implementation_internal", "_experts_implementation_internal"):
247
- try:
248
- setattr(self, key, value)
249
- except AttributeError as err:
250
- logger.error(f"Can't set {key} with value {value} for {self}")
251
- raise err
252
-
253
- def __init_subclass__(cls, *args, **kwargs):
254
- super().__init_subclass__(*args, **kwargs)
255
- cls = dataclass(cls, repr=False)
256
-
257
- @property
258
- def name_or_path(self) -> str | None:
259
- return getattr(self, "_name_or_path", None)
260
-
261
- @name_or_path.setter
262
- def name_or_path(self, value):
263
- self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
264
-
265
- @property
266
- def num_labels(self) -> int:
267
- """
268
- `int`: The number of labels for classification models.
269
- """
270
- return len(self.id2label) if self.id2label is not None else None
271
-
272
- @num_labels.setter
273
- def num_labels(self, num_labels: int):
274
- # we do not store `num_labels` attribute in config, but instead
275
- # compute it based on the length of the `id2label` map
276
- if self.id2label is None or self.num_labels != num_labels:
277
- self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
278
- self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
279
-
280
- @property
281
- def output_attentions(self):
282
- """
283
- `bool`: Whether or not the model should returns all attentions.
284
- """
285
- return self._output_attentions
286
-
287
- @output_attentions.setter
288
- def output_attentions(self, value: bool):
289
- # If we set `output_attentions` explicitly before the attn implementation, dispatch eager
290
- if value and self._attn_implementation is None:
291
- self._attn_implementation = "eager"
292
- if value and self._attn_implementation != "eager":
293
- raise ValueError(
294
- "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
295
- f"{self._attn_implementation}. Please set it to 'eager' instead."
296
- )
297
- self._output_attentions = value
298
-
299
- @property
300
- def _attn_implementation(self):
301
- return self._attn_implementation_internal
302
-
303
- @_attn_implementation.setter
304
- def _attn_implementation(self, value: str | dict | None):
305
- """We set it recursively on the sub-configs as well"""
306
- # Set if for current config
307
- current_attn = getattr(self, "_attn_implementation", None)
308
- attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
309
- self._attn_implementation_internal = attn_implementation
310
-
311
- # Set it recursively on the subconfigs
312
- for subconfig_key in self.sub_configs:
313
- subconfig = getattr(self, subconfig_key, None)
314
- if subconfig is not None:
315
- current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
316
- sub_implementation = (
317
- value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
318
- )
319
- subconfig._attn_implementation = sub_implementation
320
-
321
- @property
322
- def _experts_implementation(self):
323
- return self._experts_implementation_internal
324
-
325
- @_experts_implementation.setter
326
- def _experts_implementation(self, value: str | dict | None):
327
- """We set it recursively on the sub-configs as well"""
328
- # Set if for current config
329
- current_moe = getattr(self, "_experts_implementation", None)
330
- experts_implementation = value if not isinstance(value, dict) else value.get("", current_moe)
331
- self._experts_implementation_internal = experts_implementation
332
-
333
- # Set it recursively on the subconfigs
334
- for subconfig_key in self.sub_configs:
335
- subconfig = getattr(self, subconfig_key, None)
336
- if subconfig is not None:
337
- current_subconfig_moe = getattr(subconfig, "_experts_implementation", None)
338
- sub_implementation = (
339
- value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_moe)
340
- )
341
- subconfig._experts_implementation = sub_implementation
342
-
343
- @property
344
- def torch_dtype(self):
345
- logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
346
- return self.dtype
347
-
348
- @property
349
- def use_return_dict(self):
350
- logger.warning_once("`use_return_dict` is deprecated! Use `return_dict` instead!")
351
- return self.return_dict
352
-
353
- @torch_dtype.setter
354
- def torch_dtype(self, value):
355
- logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
356
- self.dtype = value
357
-
358
- def __setattr__(self, key, value):
359
- if key in super().__getattribute__("attribute_map"):
360
- key = super().__getattribute__("attribute_map")[key]
361
- super().__setattr__(key, value)
362
-
363
- def __getattribute__(self, key):
364
- if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
365
- key = super().__getattribute__("attribute_map")[key]
366
- return super().__getattribute__(key)
367
-
368
- def validate_output_attentions(self):
369
- if self.output_attentions and self._attn_implementation not in ["eager", None]:
370
- raise ValueError(
371
- "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
372
- f"{self._attn_implementation}. Please set it to 'eager' instead."
373
- )
374
-
375
- def validate_architecture(self):
376
- """Part of `@strict`-powered validation. Validates the architecture of the config."""
377
- if (
378
- hasattr(self, "head_dim")
379
- and hasattr(self, "num_heads")
380
- and hasattr(self, "embed_dim")
381
- and self.head_dim * self.num_heads != self.embed_dim
382
- ):
383
- raise ValueError(
384
- f"The embed_dim ({self.embed_dim}) is not a multiple of the number of attention "
385
- f"heads ({self.num_heads})."
386
- )
387
-
388
- def validate_token_ids(self):
389
- """Part of `@strict`-powered validation. Validates the contents of the special tokens."""
390
- text_config = self.get_text_config(decoder=True)
391
- vocab_size = getattr(text_config, "vocab_size", None)
392
- if vocab_size is not None:
393
- # Check for all special tokens, e..g. pad_token_id, image_token_id, audio_token_id
394
- for value in text_config:
395
- if value.endswith("_token_id") and isinstance(value, int) and not 0 <= value < vocab_size:
396
- # Can't be an exception until we can load configs that fail validation: several configs on the Hub
397
- # store invalid special tokens, e.g. `pad_token_id=-1`
398
- logger.warning_once(
399
- f"Model config: {value} must be `None` or an integer within the vocabulary (between 0 "
400
- f"and {vocab_size - 1}), got {value}. This may result in unexpected behavior."
401
- )
402
-
403
- def validate_layer_type(self):
404
- """Check that `layer_types` is correctly defined."""
405
- if not (getattr(self, "layer_types", None) is not None and hasattr(self, "num_hidden_layers")):
406
- return
407
- elif not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in self.layer_types):
408
- raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}")
409
- elif self.num_hidden_layers is not None and self.num_hidden_layers != len(self.layer_types):
410
- raise ValueError(
411
- f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types "
412
- f"({len(self.layer_types)})"
413
- )
414
-
415
- @property
416
- def rope_scaling(self):
417
- return self.rope_parameters
418
-
419
- @rope_scaling.setter
420
- def rope_scaling(self, value):
421
- self.rope_parameters = value
422
-
423
- def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
424
- """
425
- Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
426
- [`~PreTrainedConfig.from_pretrained`] class method.
427
-
428
- Args:
429
- save_directory (`str` or `os.PathLike`):
430
- Directory where the configuration JSON file will be saved (will be created if it does not exist).
431
- push_to_hub (`bool`, *optional*, defaults to `False`):
432
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
433
- repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
434
- namespace).
435
- kwargs (`dict[str, Any]`, *optional*):
436
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
437
- """
438
- if os.path.isfile(save_directory):
439
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
440
-
441
- generation_parameters = self._get_generation_parameters()
442
- if len(generation_parameters) > 0:
443
- raise ValueError(
444
- "Some generation parameters are set in the model config. These should go into `model.generation_config`"
445
- f"as opposed to `model.config`. \nGeneration parameters found: {str(generation_parameters)}",
446
- )
447
-
448
- os.makedirs(save_directory, exist_ok=True)
449
-
450
- if push_to_hub:
451
- commit_message = kwargs.pop("commit_message", None)
452
- repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
453
- repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
454
- files_timestamps = self._get_files_timestamps(save_directory)
455
-
456
- # This attribute is important to know on load, but should not be serialized on save.
457
- if "transformers_weights" in self:
458
- delattr(self, "transformers_weights")
459
-
460
- # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
461
- # loaded from the Hub.
462
- if self._auto_class is not None:
463
- custom_object_save(self, save_directory, config=self)
464
-
465
- # If we save using the predefined names, we can load using `from_pretrained`
466
- output_config_file = os.path.join(save_directory, CONFIG_NAME)
467
-
468
- # Strict validation at save-time: prevent bad patterns from propagating
469
- # Using `strict` decorator guarantees that `self.validate` exists , but not all
470
- # model config might have the decorator added
471
- if hasattr(self, "validate"):
472
- self.validate()
473
- self.to_json_file(output_config_file, use_diff=True)
474
- logger.info(f"Configuration saved in {output_config_file}")
475
-
476
- if push_to_hub:
477
- self._upload_modified_files(
478
- save_directory,
479
- repo_id,
480
- files_timestamps,
481
- commit_message=commit_message,
482
- token=kwargs.get("token"),
483
- )
484
-
485
- @classmethod
486
- def from_pretrained(
487
- cls: type[SpecificPreTrainedConfigType],
488
- pretrained_model_name_or_path: str | os.PathLike,
489
- cache_dir: str | os.PathLike | None = None,
490
- force_download: bool = False,
491
- local_files_only: bool = False,
492
- token: str | bool | None = None,
493
- revision: str = "main",
494
- **kwargs,
495
- ) -> SpecificPreTrainedConfigType:
496
- r"""
497
- Instantiate a [`PreTrainedConfig`] (or a derived class) from a pretrained model configuration.
498
-
499
- Args:
500
- pretrained_model_name_or_path (`str` or `os.PathLike`):
501
- This can be either:
502
-
503
- - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
504
- huggingface.co.
505
- - a path to a *directory* containing a configuration file saved using the
506
- [`~PreTrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
507
- - a path to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
508
- cache_dir (`str` or `os.PathLike`, *optional*):
509
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
510
- standard cache should not be used.
511
- force_download (`bool`, *optional*, defaults to `False`):
512
- Whether or not to force to (re-)download the configuration files and override the cached versions if
513
- they exist.
514
- proxies (`dict[str, str]`, *optional*):
515
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
516
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
517
- token (`str` or `bool`, *optional*):
518
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
519
- the token generated when running `hf auth login` (stored in `~/.huggingface`).
520
- revision (`str`, *optional*, defaults to `"main"`):
521
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
522
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
523
- identifier allowed by git.
524
-
525
- <Tip>
526
-
527
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
528
-
529
- </Tip>
530
-
531
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
532
- If `False`, then this function returns just the final configuration object.
533
-
534
- If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
535
- dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
536
- part of `kwargs` which has not been used to update `config` and is otherwise ignored.
537
- subfolder (`str`, *optional*, defaults to `""`):
538
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
539
- specify the folder name here.
540
- kwargs (`dict[str, Any]`, *optional*):
541
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded
542
- values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
543
- by the `return_unused_kwargs` keyword parameter.
544
-
545
- Returns:
546
- [`PreTrainedConfig`]: The configuration object instantiated from this pretrained model.
547
-
548
- Examples:
549
-
550
- ```python
551
- # We can't instantiate directly the base class *PreTrainedConfig* so let's show the examples on a
552
- # derived class: BertConfig
553
- config = BertConfig.from_pretrained(
554
- "google-bert/bert-base-uncased"
555
- ) # Download configuration from huggingface.co and cache.
556
- config = BertConfig.from_pretrained(
557
- "./test/saved_model/"
558
- ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
559
- config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
560
- config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
561
- assert config.output_attentions == True
562
- config, unused_kwargs = BertConfig.from_pretrained(
563
- "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
564
- )
565
- assert config.output_attentions == True
566
- assert unused_kwargs == {"foo": False}
567
- ```"""
568
- kwargs["cache_dir"] = cache_dir
569
- kwargs["force_download"] = force_download
570
- kwargs["local_files_only"] = local_files_only
571
- kwargs["revision"] = revision
572
-
573
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
574
- if cls.base_config_key and cls.base_config_key in config_dict:
575
- config_dict = config_dict[cls.base_config_key]
576
-
577
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
578
- # sometimes the config has no `base_config_key` if the config is used in several composite models
579
- # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
580
- for v in config_dict.values():
581
- if isinstance(v, dict) and v.get("model_type") == cls.model_type:
582
- config_dict = v
583
-
584
- # raise warning only if we still can't see a match in `model_type`
585
- if config_dict["model_type"] != cls.model_type:
586
- logger.warning(
587
- f"You are using a model of type `{config_dict['model_type']}` to instantiate a model of type "
588
- f"`{cls.model_type}`. This may be expected if you are loading a checkpoint that shares a subset "
589
- f"of the architecture (e.g., loading a `sam2_video` checkpoint into `Sam2Model`), but is otherwise "
590
- f"not supported and can yield errors. Please verify that the checkpoint is compatible with the "
591
- f"model you are instantiating."
592
- )
593
-
594
- return cls.from_dict(config_dict, **kwargs)
595
-
596
- @classmethod
597
- def get_config_dict(
598
- cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
599
- ) -> tuple[dict[str, Any], dict[str, Any]]:
600
- """
601
- From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
602
- [`PreTrainedConfig`] using `from_dict`.
603
-
604
- Parameters:
605
- pretrained_model_name_or_path (`str` or `os.PathLike`):
606
- The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
607
-
608
- Returns:
609
- `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
610
-
611
- """
612
- original_kwargs = copy.deepcopy(kwargs)
613
- # Get config dict associated with the base config file
614
- config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
615
- if config_dict is None:
616
- return {}, kwargs
617
- if "_commit_hash" in config_dict:
618
- original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
619
-
620
- # That config file may point us toward another config file to use.
621
- if "configuration_files" in config_dict:
622
- configuration_file = get_configuration_file(config_dict["configuration_files"])
623
- config_dict, kwargs = cls._get_config_dict(
624
- pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
625
- )
626
-
627
- return config_dict, kwargs
628
-
629
- @classmethod
630
- def _get_config_dict(
631
- cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
632
- ) -> tuple[dict[str, Any], dict[str, Any]]:
633
- cache_dir = kwargs.pop("cache_dir", None)
634
- force_download = kwargs.pop("force_download", False)
635
- proxies = kwargs.pop("proxies", None)
636
- token = kwargs.pop("token", None)
637
- local_files_only = kwargs.pop("local_files_only", False)
638
- revision = kwargs.pop("revision", None)
639
- trust_remote_code = kwargs.pop("trust_remote_code", None)
640
- subfolder = kwargs.pop("subfolder", "")
641
- from_pipeline = kwargs.pop("_from_pipeline", None)
642
- from_auto_class = kwargs.pop("_from_auto", False)
643
- commit_hash = kwargs.pop("_commit_hash", None)
644
-
645
- gguf_file = kwargs.get("gguf_file")
646
-
647
- if trust_remote_code is True:
648
- logger.warning(
649
- "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
650
- " ignored."
651
- )
652
-
653
- user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
654
- if from_pipeline is not None:
655
- user_agent["using_pipeline"] = from_pipeline
656
-
657
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
658
-
659
- is_local = os.path.isdir(pretrained_model_name_or_path)
660
- if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
661
- # Special case when pretrained_model_name_or_path is a local file
662
- resolved_config_file = pretrained_model_name_or_path
663
- is_local = True
664
- else:
665
- configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
666
-
667
- try:
668
- # Load from local folder or from cache or download from model Hub and cache
669
- resolved_config_file = cached_file(
670
- pretrained_model_name_or_path,
671
- configuration_file,
672
- cache_dir=cache_dir,
673
- force_download=force_download,
674
- proxies=proxies,
675
- local_files_only=local_files_only,
676
- token=token,
677
- user_agent=user_agent,
678
- revision=revision,
679
- subfolder=subfolder,
680
- _commit_hash=commit_hash,
681
- )
682
- if resolved_config_file is None:
683
- return None, kwargs
684
- commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
685
- except OSError:
686
- # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
687
- # the original exception.
688
- raise
689
- except Exception:
690
- # For any other exception, we throw a generic error.
691
- raise OSError(
692
- f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
693
- " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
694
- f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
695
- f" containing a {configuration_file} file"
696
- )
697
-
698
- try:
699
- if gguf_file:
700
- config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
701
- else:
702
- # Load config dict
703
- config_dict = cls._dict_from_json_file(resolved_config_file)
704
-
705
- config_dict["_commit_hash"] = commit_hash
706
- except (json.JSONDecodeError, UnicodeDecodeError):
707
- raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.")
708
-
709
- if is_local:
710
- logger.info(f"loading configuration file {resolved_config_file}")
711
- else:
712
- logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
713
-
714
- # timm models are not saved with the model_type in the config file
715
- if "model_type" not in config_dict and is_timm_config_dict(config_dict):
716
- config_dict["model_type"] = "timm_wrapper"
717
-
718
- return config_dict, kwargs
719
-
720
- @classmethod
721
- def from_dict(
722
- cls: type[SpecificPreTrainedConfigType], config_dict: dict[str, Any], **kwargs
723
- ) -> SpecificPreTrainedConfigType:
724
- """
725
- Instantiates a [`PreTrainedConfig`] from a Python dictionary of parameters.
726
-
727
- Args:
728
- config_dict (`dict[str, Any]`):
729
- Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
730
- retrieved from a pretrained checkpoint by leveraging the [`~PreTrainedConfig.get_config_dict`] method.
731
- kwargs (`dict[str, Any]`):
732
- Additional parameters from which to initialize the configuration object.
733
-
734
- Returns:
735
- [`PreTrainedConfig`]: The configuration object instantiated from those parameters.
736
- """
737
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
738
-
739
- # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
740
- if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
741
- kwargs.setdefault("_commit_hash", config_dict["_commit_hash"])
742
-
743
- # To remove arg here are those passed along for our internal telemetry but we still need to remove them
744
- to_remove = ["_from_auto", "_from_pipeline"]
745
- valid_fields = [
746
- "num_labels",
747
- "attn_implementation",
748
- "experts_implementation",
749
- "output_attentions",
750
- "torch_dtype",
751
- "dtype",
752
- "name_or_path",
753
- ]
754
- for key, value in kwargs.items():
755
- if key in valid_fields:
756
- if key not in ["torch_dtype", "dtype"]:
757
- config_dict[key] = value
758
- to_remove.append(key)
759
- elif value != "auto":
760
- config_dict[key] = value
761
-
762
- config = cls(**config_dict)
763
-
764
- for key, value in kwargs.items():
765
- if hasattr(config, key):
766
- current_attr = getattr(config, key)
767
- # To authorize passing a custom subconfig as kwarg in models that have nested configs.
768
- # We need to update only custom kwarg values instead and keep other attr in subconfig.
769
- if isinstance(current_attr, PreTrainedConfig) and isinstance(value, dict):
770
- current_attr_updated = current_attr.to_dict()
771
- current_attr_updated.update(value)
772
- value = current_attr.__class__(**current_attr_updated)
773
- setattr(config, key, value)
774
- to_remove.append(key)
775
-
776
- for key in to_remove:
777
- kwargs.pop(key, None)
778
-
779
- logger.info(f"Model config {config}")
780
- if return_unused_kwargs:
781
- return config, kwargs
782
- else:
783
- return config
784
-
785
- @classmethod
786
- def from_json_file(
787
- cls: type[SpecificPreTrainedConfigType], json_file: str | os.PathLike
788
- ) -> SpecificPreTrainedConfigType:
789
- """
790
- Instantiates a [`PreTrainedConfig`] from the path to a JSON file of parameters.
791
-
792
- Args:
793
- json_file (`str` or `os.PathLike`):
794
- Path to the JSON file containing the parameters.
795
-
796
- Returns:
797
- [`PreTrainedConfig`]: The configuration object instantiated from that JSON file.
798
-
799
- """
800
- config_dict = cls._dict_from_json_file(json_file)
801
- return cls(**config_dict)
802
-
803
- @classmethod
804
- def _dict_from_json_file(cls, json_file: str | os.PathLike):
805
- with open(json_file, encoding="utf-8") as reader:
806
- text = reader.read()
807
- config_dict = json.loads(text)
808
-
809
- return cls._decode_special_floats(config_dict)
810
-
811
- @classmethod
812
- def _encode_special_floats(cls, obj: Any) -> Any:
813
- """
814
- Iterates over the passed object and encode specific floats that cannot be JSON-serialized. Python's JSON
815
- engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
816
-
817
- It serializes floats like `Infinity` as an object: `{'__float__': Infinity}`.
818
- """
819
- if isinstance(obj, float):
820
- if math.isnan(obj):
821
- return {_FLOAT_TAG_KEY: "NaN"}
822
- if obj == float("inf"):
823
- return {_FLOAT_TAG_KEY: "Infinity"}
824
- if obj == float("-inf"):
825
- return {_FLOAT_TAG_KEY: "-Infinity"}
826
- return obj
827
-
828
- if isinstance(obj, dict):
829
- return {k: cls._encode_special_floats(v) for k, v in obj.items()}
830
-
831
- if isinstance(obj, (list, tuple)):
832
- return [cls._encode_special_floats(v) for v in obj]
833
-
834
- return obj
835
-
836
- @classmethod
837
- def _decode_special_floats(cls, obj: Any) -> Any:
838
- """
839
- Iterates over the passed object and decode specific floats that cannot be JSON-serialized. Python's JSON
840
- engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
841
-
842
- This method deserializes objects like `{'__float__': Infinity}` to their float values like `Infinity`.
843
- """
844
- if isinstance(obj, dict):
845
- if set(obj.keys()) == {_FLOAT_TAG_KEY} and isinstance(obj[_FLOAT_TAG_KEY], str):
846
- tag = obj[_FLOAT_TAG_KEY]
847
- if tag in _FLOAT_TAG_VALUES:
848
- return _FLOAT_TAG_VALUES[tag]
849
- return obj
850
-
851
- return {k: cls._decode_special_floats(v) for k, v in obj.items()}
852
-
853
- if isinstance(obj, list):
854
- return [cls._decode_special_floats(v) for v in obj]
855
-
856
- return obj
857
-
858
- def __eq__(self, other):
859
- return isinstance(other, PreTrainedConfig) and (self.__dict__ == other.__dict__)
860
-
861
- def __repr__(self):
862
- return f"{self.__class__.__name__} {self.to_json_string()}"
863
-
864
- def __iter__(self):
865
- yield from self.__dict__
866
-
867
- def to_diff_dict(self) -> dict[str, Any]:
868
- """
869
- Removes all attributes from the configuration that correspond to the default config attributes for
870
- better readability, while always retaining the `config` attribute from the class. Serializes to a
871
- Python dictionary.
872
-
873
- Returns:
874
- dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
875
- """
876
- config_dict = self.to_dict()
877
-
878
- # Get the default config dict (from a fresh PreTrainedConfig instance)
879
- default_config_dict = PreTrainedConfig().to_dict()
880
-
881
- # get class specific config dict
882
- class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
883
-
884
- serializable_config_dict = {}
885
-
886
- # Only serialize values that differ from the default config,
887
- # except always keep the 'config' attribute.
888
- for key, value in config_dict.items():
889
- if (
890
- isinstance(getattr(self, key, None), PreTrainedConfig)
891
- and key in class_config_dict
892
- and isinstance(class_config_dict[key], dict)
893
- ):
894
- # For nested configs we need to clean the diff recursively
895
- diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
896
- if "model_type" in value:
897
- # Needs to be set even if it's not in the diff
898
- diff["model_type"] = value["model_type"]
899
-
900
- serializable_config_dict[key] = diff
901
- elif (
902
- key not in default_config_dict
903
- or key == "transformers_version"
904
- or key == "vocab_file"
905
- or value != default_config_dict[key]
906
- or (key in default_config_dict and value != class_config_dict.get(key, value))
907
- ):
908
- serializable_config_dict[key] = value
909
-
910
- self._remove_keys_not_serialized(serializable_config_dict)
911
-
912
- # Key removed only in diff dict
913
- if "_name_or_path" in serializable_config_dict:
914
- del serializable_config_dict["_name_or_path"]
915
-
916
- if hasattr(self, "quantization_config"):
917
- serializable_config_dict["quantization_config"] = (
918
- self.quantization_config.to_dict()
919
- if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
920
- else self.quantization_config
921
- )
922
- self.dict_dtype_to_str(serializable_config_dict)
923
-
924
- return serializable_config_dict
925
-
926
- def to_dict(self) -> dict[str, Any]:
927
- """
928
- Serializes this instance to a Python dictionary.
929
-
930
- Returns:
931
- `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
932
- """
933
- output = copy.deepcopy(self.__dict__)
934
- if hasattr(self.__class__, "model_type"):
935
- output["model_type"] = self.__class__.model_type
936
-
937
- # Transformers version when serializing the model
938
- output["transformers_version"] = __version__
939
-
940
- # Pop "kwargs" since they are unpacked and set in the post init
941
- output.pop("kwargs", None)
942
-
943
- def to_list(value):
944
- if isinstance(value, tuple):
945
- value = [to_list(item) for item in value]
946
- return value
947
-
948
- for key, value in output.items():
949
- # Deal with nested configs like CLIP
950
- if isinstance(value, PreTrainedConfig):
951
- value = value.to_dict()
952
- del value["transformers_version"]
953
-
954
- # Some models have defaults as tuples because dataclass
955
- # doesn't allow mutables. Let's convert back to `list``
956
- elif isinstance(value, tuple):
957
- value = to_list(value)
958
-
959
- output[key] = value
960
-
961
- self._remove_keys_not_serialized(output)
962
-
963
- if hasattr(self, "quantization_config"):
964
- output["quantization_config"] = (
965
- self.quantization_config.to_dict()
966
- if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
967
- else self.quantization_config
968
- )
969
- self.dict_dtype_to_str(output)
970
-
971
- return output
972
-
973
- def to_json_string(self, use_diff: bool = True) -> str:
974
- """
975
- Serializes this instance to a JSON string.
976
-
977
- Args:
978
- use_diff (`bool`, *optional*, defaults to `True`):
979
- If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
980
- is serialized to JSON string.
981
-
982
- Returns:
983
- `str`: String containing all the attributes that make up this configuration instance in JSON format.
984
- """
985
- if use_diff is True:
986
- config_dict = self.to_diff_dict()
987
- else:
988
- config_dict = self.to_dict()
989
-
990
- # Handle +/-Infinity and NaNs
991
- config_dict = self._encode_special_floats(config_dict)
992
-
993
- return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
994
-
995
- def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):
996
- """
997
- Save this instance to a JSON file.
998
-
999
- Args:
1000
- json_file_path (`str` or `os.PathLike`):
1001
- Path to the JSON file in which this configuration instance's parameters will be saved.
1002
- use_diff (`bool`, *optional*, defaults to `True`):
1003
- If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
1004
- is serialized to JSON file.
1005
- """
1006
- with open(json_file_path, "w", encoding="utf-8") as writer:
1007
- writer.write(self.to_json_string(use_diff=use_diff))
1008
-
1009
- def update(self, config_dict: dict[str, Any]):
1010
- """
1011
- Updates attributes of this class with attributes from `config_dict`.
1012
-
1013
- Args:
1014
- config_dict (`dict[str, Any]`): Dictionary of attributes that should be updated for this class.
1015
- """
1016
- for key, value in config_dict.items():
1017
- setattr(self, key, value)
1018
-
1019
- def update_from_string(self, update_str: str):
1020
- """
1021
- Updates attributes of this class with attributes from `update_str`.
1022
-
1023
- The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
1024
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
1025
-
1026
- The keys to change have to already exist in the config object.
1027
-
1028
- Args:
1029
- update_str (`str`): String with attributes that should be updated for this class.
1030
-
1031
- """
1032
-
1033
- d = dict(x.split("=") for x in update_str.split(","))
1034
- for k, v in d.items():
1035
- if not hasattr(self, k):
1036
- raise ValueError(f"key {k} isn't in the original config dict")
1037
-
1038
- old_v = getattr(self, k)
1039
- if isinstance(old_v, bool):
1040
- if v.lower() in ["true", "1", "y", "yes"]:
1041
- v = True
1042
- elif v.lower() in ["false", "0", "n", "no"]:
1043
- v = False
1044
- else:
1045
- raise ValueError(f"can't derive true or false from {v} (key {k})")
1046
- elif isinstance(old_v, int):
1047
- v = int(v)
1048
- elif isinstance(old_v, float):
1049
- v = float(v)
1050
- elif not isinstance(old_v, str):
1051
- raise TypeError(
1052
- f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
1053
- )
1054
-
1055
- setattr(self, k, v)
1056
-
1057
- def dict_dtype_to_str(self, d: dict[str, Any]) -> None:
1058
- """
1059
- Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None,
1060
- converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
1061
- string, which can then be stored in the json format.
1062
- """
1063
- if d.get("dtype") is not None:
1064
- if isinstance(d["dtype"], dict):
1065
- d["dtype"] = {k: str(v).split(".")[-1] for k, v in d["dtype"].items()}
1066
- # models like Emu3 can have "dtype" as token in config's vocabulary map,
1067
- # so we also exclude int type here to avoid error in this special case.
1068
- elif not isinstance(d["dtype"], (str, int)):
1069
- d["dtype"] = str(d["dtype"]).split(".")[1]
1070
- for value in d.values():
1071
- if isinstance(value, dict):
1072
- self.dict_dtype_to_str(value)
1073
-
1074
- def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
1075
- """
1076
- Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
1077
- Runs recursive check on the dict, to remove from all sub configs.
1078
- """
1079
-
1080
- for key_to_remove in [
1081
- "_is_quantized",
1082
- "_auto_class",
1083
- "_commit_hash",
1084
- "_attn_implementation_internal",
1085
- "_experts_implementation_internal",
1086
- "ignore_keys_at_rope_validation",
1087
- "base_model_tp_plan",
1088
- "base_model_pp_plan",
1089
- ]:
1090
- d.pop(key_to_remove, None)
1091
-
1092
- if "_output_attentions" in d:
1093
- d["output_attentions"] = d.pop("_output_attentions")
1094
-
1095
- for value in d.values():
1096
- if isinstance(value, dict):
1097
- self._remove_keys_not_serialized(value)
1098
-
1099
- @classmethod
1100
- def register_for_auto_class(cls, auto_class="AutoConfig"):
1101
- """
1102
- Register this class with a given auto class. This should only be used for custom configurations as the ones in
1103
- the library are already mapped with `AutoConfig`.
1104
-
1105
-
1106
-
1107
- Args:
1108
- auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
1109
- The auto class to register this new configuration with.
1110
- """
1111
- if not isinstance(auto_class, str):
1112
- auto_class = auto_class.__name__
1113
-
1114
- import transformers.models.auto as auto_module
1115
-
1116
- if not hasattr(auto_module, auto_class):
1117
- raise ValueError(f"{auto_class} is not a valid auto class.")
1118
-
1119
- cls._auto_class = auto_class
1120
-
1121
- def _get_generation_parameters(self) -> dict[str, Any]:
1122
- """
1123
- Checks if there are generation parameters in `PreTrainedConfig` instance. Note that
1124
- we should not save generation params in PreTrainedConfig, and we will raise error
1125
- if there are any.
1126
- """
1127
- generation_params = {}
1128
- default_config = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
1129
- for key in GenerationConfig._get_default_generation_params().keys():
1130
- if key == "use_cache":
1131
- continue # common key for most models
1132
- if hasattr(self, key) and getattr(self, key) is not None and key not in default_config:
1133
- generation_params[key] = getattr(self, key)
1134
-
1135
- return generation_params
1136
-
1137
- def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig":
1138
- """
1139
- Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
1140
- `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
1141
- which is useful on models that have both text input and output modalities.
1142
-
1143
- There are three possible outcomes of using this method:
1144
- 1. On most models, it returns the original config instance itself.
1145
- 2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
1146
- of valid names.
1147
- 3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
1148
-
1149
- Args:
1150
- decoder (`Optional[bool]`, *optional*):
1151
- If set to `True`, then only search for decoder config names.
1152
- encoder (`Optional[bool]`, *optional*):
1153
- If set to `True`, then only search for encoder config names.
1154
- """
1155
- return_both = decoder == encoder # both unset or both set -> search all possible names
1156
-
1157
- decoder_possible_text_config_names = ("decoder", "generator", "text_config")
1158
- encoder_possible_text_config_names = ("text_encoder",)
1159
- if return_both:
1160
- possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
1161
- elif decoder:
1162
- possible_text_config_names = decoder_possible_text_config_names
1163
- else:
1164
- possible_text_config_names = encoder_possible_text_config_names
1165
-
1166
- valid_text_config_names = []
1167
- for text_config_name in possible_text_config_names:
1168
- if hasattr(self, text_config_name):
1169
- text_config = getattr(self, text_config_name, None)
1170
- if text_config is not None:
1171
- valid_text_config_names += [text_config_name]
1172
-
1173
- if len(valid_text_config_names) > 1:
1174
- raise ValueError(
1175
- f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
1176
- "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
1177
- "e.g. `text_config = config.sub_config_name`"
1178
- )
1179
- elif len(valid_text_config_names) == 1:
1180
- config_to_return = getattr(self, valid_text_config_names[0])
1181
- else:
1182
- config_to_return = self
1183
-
1184
- # handle legacy models with flat config structure, when we only want one of the configs
1185
- if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
1186
- config_to_return = copy.deepcopy(config_to_return)
1187
- prefix_to_keep = "decoder" if decoder else "encoder"
1188
- for key in config_to_return.to_dict():
1189
- # NOTE: We can't discard keys because:
1190
- # 1) we can't truly delete a cls attribte on a dataclass; 2) we can't set the value to `None` due to
1191
- # strict validation. So we just keep it as is, since there are only a couple old models falling in this condition
1192
- if key.startswith(prefix_to_keep):
1193
- # [encoder/decoder]_layers -> num_hidden_layers
1194
- if key == prefix_to_keep + "_layers":
1195
- new_key = "num_hidden_layers"
1196
- # [encoder/decoder]_attention_heads -> num_attention_heads
1197
- elif key == prefix_to_keep + "_attention_heads":
1198
- new_key = "num_attention_heads"
1199
- # e.g. encoder_hidden_act -> hidden_act
1200
- else:
1201
- new_key = key[len(prefix_to_keep) + 1 :]
1202
-
1203
- # Does the class map the new key into a different attribute name at read time? if so, let's write
1204
- # into that attribute instead
1205
- if new_key in config_to_return.attribute_map:
1206
- new_key = config_to_return.attribute_map[new_key]
1207
-
1208
- value = getattr(config_to_return, key)
1209
- delattr(config_to_return, key)
1210
- setattr(config_to_return, new_key, value)
1211
-
1212
- return config_to_return
1213
-
1214
-
1215
- def get_configuration_file(configuration_files: list[str]) -> str:
1216
- """
1217
- Get the configuration file to use for this version of transformers.
1218
-
1219
- Args:
1220
- configuration_files (`list[str]`): The list of available configuration files.
1221
-
1222
- Returns:
1223
- `str`: The configuration file to use.
1224
- """
1225
- configuration_files_map = {}
1226
- for file_name in configuration_files:
1227
- if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json":
1228
- v = file_name.removeprefix("config.").removesuffix(".json")
1229
- configuration_files_map[v] = file_name
1230
- available_versions = sorted(configuration_files_map.keys())
1231
-
1232
- # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
1233
- configuration_file = CONFIG_NAME
1234
- transformers_version = version.parse(__version__)
1235
- for v in available_versions:
1236
- if version.parse(v) <= transformers_version:
1237
- configuration_file = configuration_files_map[v]
1238
- else:
1239
- # No point going further since the versions are sorted.
1240
- break
1241
-
1242
- return configuration_file
1243
-
1244
-
1245
- def recursive_diff_dict(dict_a, dict_b, config_obj=None):
1246
- """
1247
- Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
1248
- values from `dict_a` that are different from values in `dict_b`.
1249
-
1250
- dict_b : the default config dictionary. We want to remove values that are in this one
1251
- """
1252
- diff = {}
1253
- default = config_obj.__class__().to_dict() if config_obj is not None else {}
1254
- for key, value in dict_a.items():
1255
- obj_value = getattr(config_obj, str(key), None)
1256
- if isinstance(obj_value, PreTrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
1257
- diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
1258
- diff[key] = diff_value
1259
- elif key not in dict_b or (value != default[key]):
1260
- diff[key] = value
1261
- return diff
1262
-
1263
-
1264
- PreTrainedConfig.push_to_hub = copy_func(PreTrainedConfig.push_to_hub)
1265
- if PreTrainedConfig.push_to_hub.__doc__ is not None:
1266
- PreTrainedConfig.push_to_hub.__doc__ = PreTrainedConfig.push_to_hub.__doc__.format(
1267
- object="config", object_class="AutoConfig", object_files="configuration file"
1268
- )
1269
-
1270
-
1271
- # The alias is only here for BC - we did not have the correct CamelCasing before
1272
- PretrainedConfig = PreTrainedConfig
1273
-
1274
-
1275
- def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True):
1276
- logger.warning(
1277
- "`layer_type_validation` is deprecated and will be removed in v5.20. "
1278
- "Use `PreTrainedConfig.validate_layer_type` instead"
1279
- )
1280
-
1281
- if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
1282
- raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
1283
- if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
1284
- raise ValueError(
1285
- f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
1286
- f"({len(layer_types)})"
1287
- )