Student0809 committed on
Commit
73c08c8
·
verified ·
1 Parent(s): 3b47bbc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dev_scripts/dockerci.sh +88 -0
  2. .gitattributes +1 -0
  3. .github/workflows/citest.yaml +75 -0
  4. docs/resources/grpo_code.png +3 -0
  5. docs/transformers/build/lib/transformers/models/cpm/tokenization_cpm.py +350 -0
  6. docs/transformers/build/lib/transformers/models/cpmant/modeling_cpmant.py +860 -0
  7. docs/transformers/build/lib/transformers/models/cpmant/tokenization_cpmant.py +270 -0
  8. docs/transformers/build/lib/transformers/models/ctrl/configuration_ctrl.py +116 -0
  9. docs/transformers/build/lib/transformers/models/ctrl/modeling_ctrl.py +844 -0
  10. docs/transformers/build/lib/transformers/models/ctrl/modeling_tf_ctrl.py +922 -0
  11. docs/transformers/build/lib/transformers/models/ctrl/tokenization_ctrl.py +251 -0
  12. docs/transformers/build/lib/transformers/models/cvt/__init__.py +28 -0
  13. docs/transformers/build/lib/transformers/models/cvt/configuration_cvt.py +146 -0
  14. docs/transformers/build/lib/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py +362 -0
  15. docs/transformers/build/lib/transformers/models/cvt/modeling_cvt.py +727 -0
  16. docs/transformers/build/lib/transformers/models/cvt/modeling_tf_cvt.py +1096 -0
  17. docs/transformers/build/lib/transformers/models/dab_detr/__init__.py +28 -0
  18. docs/transformers/build/lib/transformers/models/dab_detr/configuration_dab_detr.py +260 -0
  19. docs/transformers/build/lib/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py +233 -0
  20. docs/transformers/build/lib/transformers/models/dab_detr/modeling_dab_detr.py +1716 -0
  21. docs/transformers/build/lib/transformers/models/dac/__init__.py +28 -0
  22. docs/transformers/build/lib/transformers/models/dac/configuration_dac.py +114 -0
  23. docs/transformers/build/lib/transformers/models/dac/convert_dac_checkpoint.py +261 -0
  24. docs/transformers/build/lib/transformers/models/dac/feature_extraction_dac.py +173 -0
  25. docs/transformers/build/lib/transformers/models/dac/modeling_dac.py +724 -0
  26. docs/transformers/build/lib/transformers/models/data2vec/__init__.py +32 -0
  27. docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_audio.py +288 -0
  28. docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_text.py +154 -0
  29. docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_vision.py +194 -0
  30. docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py +285 -0
  31. docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py +207 -0
  32. docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py +374 -0
  33. docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_audio.py +1746 -0
  34. docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_text.py +1553 -0
  35. docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_vision.py +1449 -0
  36. docs/transformers/build/lib/transformers/models/data2vec/modeling_tf_data2vec_vision.py +1724 -0
  37. docs/transformers/build/lib/transformers/models/data2vec/modular_data2vec_audio.py +400 -0
  38. docs/transformers/build/lib/transformers/models/dbrx/__init__.py +27 -0
  39. docs/transformers/build/lib/transformers/models/dbrx/configuration_dbrx.py +232 -0
  40. docs/transformers/build/lib/transformers/models/dbrx/modeling_dbrx.py +1392 -0
  41. docs/transformers/build/lib/transformers/models/deberta/__init__.py +30 -0
  42. docs/transformers/build/lib/transformers/models/deberta/configuration_deberta.py +199 -0
  43. docs/transformers/build/lib/transformers/models/deberta/modeling_deberta.py +1352 -0
  44. docs/transformers/build/lib/transformers/models/deberta/modeling_tf_deberta.py +1652 -0
  45. docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta.py +396 -0
  46. docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta_fast.py +239 -0
  47. docs/transformers/build/lib/transformers/models/deberta_v2/__init__.py +30 -0
  48. docs/transformers/build/lib/transformers/models/deberta_v2/configuration_deberta_v2.py +198 -0
  49. docs/transformers/build/lib/transformers/models/deberta_v2/modeling_deberta_v2.py +1523 -0
  50. docs/transformers/build/lib/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +1881 -0
.dev_scripts/dockerci.sh ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ MODELSCOPE_CACHE_DIR_IN_CONTAINER=/modelscope_cache
3
+ CODE_DIR=$PWD
4
+ CODE_DIR_IN_CONTAINER=/ms-swift
5
+ echo "$USER"
6
+ gpus='0,1 2,3'
7
+ cpu_sets='0-15 16-31'
8
+ cpu_sets_arr=($cpu_sets)
9
+ is_get_file_lock=false
10
+ CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
11
+ echo "ci command: $CI_COMMAND"
12
+ PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
13
+ echo "PR modified files: $PR_CHANGED_FILES"
14
+ PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
15
+ echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
16
+ idx=0
17
+ for gpu in $gpus
18
+ do
19
+ exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
20
+ flock -n "$lock_fd" || { echo "WARN: gpu $gpu is in use!" >&2; idx=$((idx+1)); continue; }
21
+ echo "get gpu lock $gpu"
22
+
23
+ CONTAINER_NAME="swift-ci-$idx"
24
+ let is_get_file_lock=true
25
+
26
+ # pull image if there are update
27
+ docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
28
+ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
29
+ echo 'debugging'
30
+ docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
31
+ --cpuset-cpus=${cpu_sets_arr[$idx]} \
32
+ --gpus='"'"device=$gpu"'"' \
33
+ -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
34
+ -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
35
+ -v $MODELSCOPE_HOME_CACHE/$idx:/root \
36
+ -v /home/admin/pre-commit:/home/admin/pre-commit \
37
+ -e CI_TEST=True \
38
+ -e TEST_LEVEL=$TEST_LEVEL \
39
+ -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
40
+ -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
41
+ -e MODELSCOPE_SDK_DEBUG=True \
42
+ -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
43
+ -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
44
+ -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
45
+ -e TEST_LEVEL=$TEST_LEVEL \
46
+ -e MODELSCOPE_ENVIRONMENT='ci' \
47
+ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
48
+ -e MODEL_TAG_URL=$MODEL_TAG_URL \
49
+ -e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
50
+ -e PR_CHANGED_FILES=$PR_CHANGED_FILES \
51
+ --workdir=$CODE_DIR_IN_CONTAINER \
52
+ ${IMAGE_NAME}:${IMAGE_VERSION} \
53
+ $CI_COMMAND
54
+ else
55
+ docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
56
+ --cpuset-cpus=${cpu_sets_arr[$idx]} \
57
+ --gpus='"'"device=$gpu"'"' \
58
+ -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
59
+ -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
60
+ -v $MODELSCOPE_HOME_CACHE/$idx:/root \
61
+ -v /home/admin/pre-commit:/home/admin/pre-commit \
62
+ -e CI_TEST=True \
63
+ -e TEST_LEVEL=$TEST_LEVEL \
64
+ -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
65
+ -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
66
+ -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
67
+ -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
68
+ -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
69
+ -e TEST_LEVEL=$TEST_LEVEL \
70
+ -e MODELSCOPE_ENVIRONMENT='ci' \
71
+ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
72
+ -e MODEL_TAG_URL=$MODEL_TAG_URL \
73
+ -e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
74
+ -e PR_CHANGED_FILES=$PR_CHANGED_FILES \
75
+ --workdir=$CODE_DIR_IN_CONTAINER \
76
+ ${IMAGE_NAME}:${IMAGE_VERSION} \
77
+ $CI_COMMAND
78
+ fi
79
+ if [ $? -ne 0 ]; then
80
+ echo "Running test case failed, please check the log!"
81
+ exit -1
82
+ fi
83
+ break
84
+ done
85
+ if [ "$is_get_file_lock" = false ] ; then
86
+ echo 'No free GPU!'
87
+ exit 1
88
+ fi
.gitattributes CHANGED
@@ -55,3 +55,4 @@ docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
55
  docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
56
  docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
57
  docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
 
 
55
  docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
56
  docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
57
  docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
58
+ docs/resources/grpo_code.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/citest.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: citest
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - master
7
+ - "release/**"
8
+ paths-ignore:
9
+ - "setup.*"
10
+ - "requirements.txt"
11
+ - "requirements/**"
12
+ - "docs/**"
13
+ - "tools/**"
14
+ - ".dev_scripts/**"
15
+ - "README.md"
16
+ - "README_*.md"
17
+ - "NOTICE"
18
+ - ".github/workflows/lint.yaml"
19
+ - ".github/workflows/publish.yaml"
20
+
21
+ pull_request:
22
+ paths-ignore:
23
+ - "setup.*"
24
+ - "requirements.txt"
25
+ - "requirements/**"
26
+ - "docs/**"
27
+ - "tools/**"
28
+ - ".dev_scripts/**"
29
+ - "README.md"
30
+ - "README_*.md"
31
+ - "NOTICE"
32
+ - ".github/workflows/lint.yaml"
33
+ - ".github/workflows/publish.yaml"
34
+
35
+ concurrency:
36
+ group: ${{ github.workflow }}-${{ github.ref }}
37
+ cancel-in-progress: true
38
+
39
+ jobs:
40
+ unittest:
41
+ # The type of runner that the job will run on
42
+ runs-on: [self-hosted]
43
+ timeout-minutes: 240
44
+ steps:
45
+ - name: ResetFileMode
46
+ shell: bash
47
+ run: |
48
+ # reset filemode to allow action runner to delete files
49
+ # generated by root in docker
50
+ set -e
51
+ source ~/.bashrc
52
+ sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
53
+
54
+ - name: Checkout
55
+ uses: actions/checkout@v3
56
+ with:
57
+ lfs: 'true'
58
+ submodules: 'true'
59
+ fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
60
+ - name: Get changed files
61
+ id: changed-files
62
+ run: |
63
+ if ${{ github.event_name == 'pull_request' }}; then
64
+ echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
65
+ else
66
+ echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
67
+ fi
68
+ - name: Checkout LFS objects
69
+ run: git lfs checkout
70
+ - name: Run unittest
71
+ shell: bash
72
+ run: |
73
+ set -e
74
+ source /mnt/modelscope/ci_env.sh
75
+ bash .dev_scripts/dockerci.sh
docs/resources/grpo_code.png ADDED

Git LFS Details

  • SHA256: 5f396d9ce5ce9de323d7a6ffa8d53f783d938a242088b191f46a293268193b64
  • Pointer size: 131 Bytes
  • Size of remote file: 294 kB
docs/transformers/build/lib/transformers/models/cpm/tokenization_cpm.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ import os
18
+ import unicodedata
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import sentencepiece as spm
23
+
24
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
25
+ from ...utils import SPIECE_UNDERLINE, logging
26
+ from ...utils.import_utils import requires
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
32
+
33
+
34
@requires(backends=("sentencepiece",))
class CpmTokenizer(PreTrainedTokenizer):
    """Runs pre-tokenization with Jieba segmentation tool. It is used in CPM models."""

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        remove_space=True,
        keep_accents=False,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        sep_token="<sep>",
        pad_token="<pad>",
        cls_token="<cls>",
        mask_token="<mask>",
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        Construct a CPM tokenizer. Based on [Jieba](https://pypi.org/project/jieba/) and
        [SentencePiece](https://github.com/google/sentencepiece).

        This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
        refer to this superclass for more information regarding those methods.

        Args:
            vocab_file (`str`):
                [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
                contains the vocabulary necessary to instantiate a tokenizer.
            do_lower_case (`bool`, *optional*, defaults to `True`):
                Whether to lowercase the input when tokenizing.
            remove_space (`bool`, *optional*, defaults to `True`):
                Whether to strip the text when tokenizing (removing excess spaces before and after the string).
            keep_accents (`bool`, *optional*, defaults to `False`):
                Whether to keep accents when tokenizing.
            bos_token (`str`, *optional*, defaults to `"<s>"`):
                The beginning of sequence token that was used during pretraining. Can be used a sequence classifier
                token.

                <Tip>

                When building a sequence using special tokens, this is not the token that is used for the beginning of
                sequence. The token used is the `cls_token`.

                </Tip>

            eos_token (`str`, *optional*, defaults to `"</s>"`):
                The end of sequence token.

                <Tip>

                When building a sequence using special tokens, this is not the token that is used for the end of
                sequence. The token used is the `sep_token`.

                </Tip>

            unk_token (`str`, *optional*, defaults to `"<unk>"`):
                The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
                this token instead.
            sep_token (`str`, *optional*, defaults to `"<sep>"`):
                The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
                for sequence classification or for a text and a question for question answering. It is also used as the
                last token of a sequence built with special tokens.
            pad_token (`str`, *optional*, defaults to `"<pad>"`):
                The token used for padding, for example when batching sequences of different lengths.
            cls_token (`str`, *optional*, defaults to `"<cls>"`):
                The classifier token which is used when doing sequence classification (classification of the whole
                sequence instead of per-token classification). It is the first token of the sequence when built with
                special tokens.
            mask_token (`str`, *optional*, defaults to `"<mask>"`):
                The token used for masking values. This is the token used when training this model with masked language
                modeling. This is the token which the model will try to predict.
            additional_special_tokens (`List[str]`, *optional*, defaults to `["<eop>", "<eod>"]`):
                Additional special tokens used by the tokenizer.

        Attributes:
            sp_model (`SentencePieceProcessor`):
                The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
        """
        # FIX: the signature previously used a mutable default argument
        # (`additional_special_tokens=["<eop>", "<eod>"]`) -- one list object
        # shared by every call and instance. Resolve the default per call.
        if additional_special_tokens is None:
            additional_special_tokens = ["<eop>", "<eod>"]
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents
        self.vocab_file = vocab_file

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        try:
            import jieba
        except ModuleNotFoundError as error:
            raise error.__class__(
                "You need to install jieba to use CpmTokenizer or CpmTokenizerFast. "
                "See https://pypi.org/project/jieba/ for installation."
            )
        self.jieba = jieba
        # Maps " " -> U+2582 and "\n" -> U+2583 before SentencePiece; _decode reverses it.
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        super().__init__(
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            keep_accents=keep_accents,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

        self._pad_token_type_id = 3

    @property
    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.vocab_size
    def vocab_size(self):
        return len(self.sp_model)

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.get_vocab
    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.__getstate__
    def __getstate__(self):
        # The SentencePiece processor is not picklable; drop it here and
        # re-load it from self.vocab_file in __setstate__.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.__setstate__
    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.preprocess_text
    def preprocess_text(self, inputs):
        """Normalize raw text (whitespace, quotes, accents, case) before SentencePiece."""
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._tokenize
    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string."""
        text = self.preprocess_text(text)
        pieces = self.sp_model.encode(text, out_type=str)
        new_pieces = []
        for piece in pieces:
            # Split pieces like "2," so the digits and the trailing comma
            # become separate sub-tokens.
            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        return new_pieces

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.PieceToId(token)

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An XLNet sequence has the following format:

        - single sequence: `X <sep> <cls>`
        - pair of sequences: `A <sep> B <sep> <cls>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return token_ids_0 + sep + cls
        return token_ids_0 + sep + token_ids_1 + sep + cls

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1]
        return ([0] * len(token_ids_0)) + [1, 1]

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls_segment_id = [2]

        if token_ids_1 is None:
            return len(token_ids_0 + sep) * [0] + cls_segment_id
        return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id

    # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save the SentencePiece vocabulary to `save_directory`; returns the written path, or None on error."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file is gone (e.g. unpickled tokenizer): dump the
            # in-memory model instead of copying.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def _decode(self, *args, **kwargs):
        """Decode, then undo the Jieba-era mapping: drop literal spaces and map U+2582/U+2583 back to space/newline."""
        text = super()._decode(*args, **kwargs)
        text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
        return text
348
+
349
+
350
+ __all__ = ["CpmTokenizer"]
docs/transformers/build/lib/transformers/models/cpmant/modeling_cpmant.py ADDED
@@ -0,0 +1,860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CPMAnt"""
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...generation import GenerationMixin
28
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
31
+ from .configuration_cpmant import CpmAntConfig
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ _CHECKPOINT_FOR_DOC = "openbmb/cpm-ant-10b"
37
+ _CONFIG_FOR_DOC = "CpmAntConfig"
38
+
39
+
40
class CpmAntLayerNorm(nn.Module):
    """
    We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details."
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()

        self.eps = config.eps
        self.dim_norm = config.hidden_size
        # Learnable per-channel scale; initialized later from a checkpoint.
        self.weight = nn.Parameter(torch.empty(config.hidden_size))

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        # Guard against a mismatched hidden dimension.
        if hidden_states.size(-1) != self.dim_norm:
            raise AssertionError("hidden_states.size(-1) != self.dim_norm")
        input_dtype = hidden_states.dtype
        # Accumulate the mean square in float32 for numerical stability.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (hidden_states * inv_rms).to(input_dtype)
        return normalized * self.weight
63
+
64
+
65
class CpmAntAttention(nn.Module):
    """Multi-head attention used by CPM-Ant, with an additive position bias on the scores."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.dim_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dim_head = config.dim_head

        # Bias-free projections for queries, keys and values.
        self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
        self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
        self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)

        self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)

        self.dropout = torch.nn.Dropout(p=config.dropout_p) if config.dropout_p is not None else None

    def forward(
        self,
        hidden_q: torch.Tensor,
        hidden_kv: torch.Tensor,
        attention_mask: torch.BoolTensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_q (`torch.Tensor`):
                Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
                Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
            attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Avoid invalid areas to participate in the calculation of self-attention.
            position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Provide positional information to self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        batch_size, len_q = hidden_q.size(0), hidden_q.size(1)
        len_k = hidden_kv.size(1)

        def _split_heads(states: torch.Tensor, length: int) -> torch.Tensor:
            # (batch, length, heads * head_dim) -> (batch, heads, length, head_dim)
            return states.view(batch_size, length, self.num_heads, self.dim_head).permute(0, 2, 1, 3)

        query = _split_heads(self.project_q(hidden_q), len_q)
        key = _split_heads(self.project_k(hidden_kv), len_k)
        value = _split_heads(self.project_v(hidden_kv), len_k)

        if past_key_values is not None:
            # Prepend the cached keys/values along the sequence axis.
            key = torch.cat([past_key_values[0], key], dim=-2)
            value = torch.cat([past_key_values[1], value], dim=-2)
            len_k = key.size(-2)

        # (batch, heads, len_q, head_dim) @ (batch, heads, head_dim, len_k) -> (batch, heads, len_q, len_k)
        score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
        score = score + position_bias

        # Disallowed positions, computed once and reused for both masking passes.
        invalid = attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False)
        score = torch.masked_fill(
            score,
            invalid,
            torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
        )
        score = self.softmax(score)
        # Re-zero masked positions after the softmax.
        score = torch.masked_fill(
            score,
            invalid,
            torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
        )

        attn_weights = score if output_attentions else None

        if self.dropout is not None:
            score = self.dropout(score)

        # (batch, heads, len_q, len_k) @ (batch, heads, len_k, head_dim) -> (batch, heads, len_q, head_dim)
        score = torch.matmul(score, value)

        # Merge the heads back: (batch, len_q, heads * head_dim).
        score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
        score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)

        score = self.attention_out(score)

        present_key_values = (key, value) if use_cache else None
        return score, attn_weights, present_key_values
167
+
168
+
169
class CpmAntSelfAttentionBlock(nn.Module):
    """Pre-LayerNorm self-attention sub-block with a residual connection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_attention = CpmAntLayerNorm(config)
        self.self_attention = CpmAntAttention(config)
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Apply LayerNorm, self-attention and (optional) dropout, then add the result to the input.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
            attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Avoid invalid areas to participate in the calculation of self-attention.
            position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
                Provide positional information to self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        normed = self.layernorm_before_attention(hidden_states)
        attn_output, attn_weights, current_key_value = self.self_attention(
            normed, normed, attention_mask, position_bias, output_attentions, past_key_values, use_cache
        )

        if self.dropout is not None:
            attn_output = self.dropout(attn_output)

        # Residual connection around the attention sub-layer.
        hidden_states = hidden_states + attn_output
        return hidden_states, attn_weights, current_key_value
216
+
217
+
218
class CpmAntDenseGatedACT(nn.Module):
    """GLU-style gated projection: one branch is passed through GELU and gates the other."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
        self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
        self.act = torch.nn.GELU()

    def forward(self, hidden_states: torch.Tensor):
        """Transform an input tensor from one feature space to another via a nonlinear operation

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        gate = self.act(self.w_0(hidden_states))
        return gate * self.w_1(hidden_states)
236
+
237
+
238
class CpmAntFeedForward(nn.Module):
    """Position-wise FFN: gated projection up to `dim_ff`, optional dropout, projection back down."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_in = CpmAntDenseGatedACT(config)
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p is not None else None
        self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
        """
        hidden_states = self.w_in(hidden_states)
        if self.dropout is not None:
            hidden_states = self.dropout(hidden_states)
        return self.w_out(hidden_states)
262
+
263
+
264
class CpmAntFFNBlock(nn.Module):
    """Pre-LayerNorm feed-forward sub-block with a residual connection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_ffn = CpmAntLayerNorm(config)
        self.ffn = CpmAntFeedForward(config)
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before feed forward layer.
        """
        ffn_output = self.ffn(self.layernorm_before_ffn(hidden_states))
        if self.dropout is not None:
            ffn_output = self.dropout(ffn_output)
        # Residual connection around the feed-forward sub-layer.
        return hidden_states + ffn_output
289
+
290
+
291
class CpmAntTransformerBlock(nn.Module):
    """One decoder layer: a self-attention sub-block followed by a feed-forward sub-block."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.self_att = CpmAntSelfAttentionBlock(config)
        self.ffn = CpmAntFFNBlock(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        hidden_states, attn_weights, current_key_value = self.self_att(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        hidden_states = self.ffn(hidden_states)

        return hidden_states, attn_weights, current_key_value
336
+
337
+
338
class CpmAntEncoder(nn.Module):
    """Stack of `num_hidden_layers` transformer blocks followed by a final RMS LayerNorm."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.num_layers = config.num_hidden_layers
        self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for _ in range(self.num_layers)])

        self.output_layernorm = CpmAntLayerNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
            position_bias (`torch.Tensor`):
                Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        current_key_values = () if use_cache else None

        for layer_idx, layer_module in enumerate(self.layers):
            if output_hidden_states:
                # Record the states *entering* each layer.
                all_hidden_states += (hidden_states,)
            layer_past = past_key_values[layer_idx] if past_key_values else None
            hidden_states, attn_weights, current_key_value = layer_module(
                hidden_states,
                attention_mask,
                position_bias,
                output_attentions=output_attentions,
                past_key_values=layer_past,
                use_cache=use_cache,
            )
            if output_attentions:
                all_self_attns += (attn_weights,)
            if current_key_value is not None:
                current_key_values = current_key_values + (current_key_value,)

        hidden_states = self.output_layernorm(hidden_states)

        if output_hidden_states:
            # The final (post-LayerNorm) states are appended as well.
            all_hidden_states += (hidden_states,)

        return hidden_states, current_key_values, all_hidden_states, all_self_attns
401
+
402
+
403
# Functionally equivalent to transformers.models.bert.modeling_bert.BertIntermediate.
class CpmAntIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to `intermediate_size` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be given either as a string key into ACT2FN or as a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
417
+
418
+
419
class CpmAntSegmentPositionEmbedding(nn.Module):
    """
    Relative attention bias combining T5-style bucketed relative positions (used when query and
    key belong to the same segment) with learned segment-pair embeddings (used otherwise).
    """

    def __init__(self, config: "CpmAntConfig"):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_buckets = config.position_bias_num_buckets
        self.max_distance = config.position_bias_max_distance
        self.num_segments = config.segment_types

        # One row per (query_segment, key_segment) pair plus one row per positional
        # bucket; one column per attention head.
        self.relative_attention_bias = nn.Parameter(
            torch.empty(
                config.segment_types * config.segment_types + config.position_bias_num_buckets,
                config.num_attention_heads,
            )
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """
        Compute the additive attention bias of shape `(batch, num_heads, len_q, len_k)`.

        Args:
            key_pos (`torch.Tensor` of shape `(batch, len_k)`): positions of the keys.
            query_pos (`torch.Tensor` of shape `(batch, len_q)`): positions of the queries.
            key_segment (`torch.Tensor` of shape `(batch, len_k)`): segment ids of the keys.
            query_segment (`torch.Tensor` of shape `(batch, len_q)`): segment ids of the queries.

        Raises:
            AssertionError: if the batch sizes or sequence lengths of the inputs disagree.
        """
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            if key_pos.size(0) != query_pos.size(0):
                raise AssertionError(
                    f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
                )
            # Check each length against its own segment tensor separately, so the error
            # message matches the mismatch that actually occurred. (Previously a querylen
            # mismatch raised the keylen message and the querylen branch — which also
            # contained a `.szie(1)` typo — was unreachable.)
            if keylen != key_segment.size(1):
                raise AssertionError(
                    f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
                )
            if querylen != query_segment.size(1):
                raise AssertionError(
                    f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
                )

            key_pos = key_pos.view(batch, -1, keylen)
            query_pos = query_pos.view(batch, querylen, -1)
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            # Segment-pair bucket, offset by `num_buckets` so both kinds of bias can
            # share a single embedding table (positional buckets occupy rows [0, num_buckets)).
            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (batch, len_q, len_k)
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Same segment -> distance-based bucket; different segments -> segment-pair bucket.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )

        # Embedding lookup stays outside `no_grad` so `relative_attention_bias` is trained.
        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        # Flat index of the (query_segment, key_segment) pair.
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """Map signed relative positions to bucket ids (T5-style, always bidirectional)."""
        relative_buckets = 0
        # always bidirectional in CPMAnt: half the buckets for each sign.
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)
        # Small distances get exact buckets; larger ones are binned logarithmically.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
509
+
510
+
511
# Functionally equivalent to transformers.models.bert.modeling_bert.BertOutput.
class CpmAntOutput(nn.Module):
    """Feed-forward output projection with dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm (post-LN ordering).
        return self.LayerNorm(projected + input_tensor)
524
+
525
+
526
class CpmAntPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CpmAntConfig
    base_model_prefix = "cpmant"

    def _init_weights(self, module):
        """Initialize the weights: N(0, init_std) for learned matrices, identity for norms."""
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding embedding at exactly zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, CpmAntLayerNorm):
            # RMS norm has no bias term; only the scale is reset.
            module.weight.data.fill_(1.0)
        elif isinstance(module, CpmAntSegmentPositionEmbedding):
            module.relative_attention_bias.data.normal_(mean=0.0, std=std)
552
+
553
+
554
# Shared docstring prologue injected into the model classes via `@add_start_docstrings`.
CPMANT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters
        config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation injected into `forward` via `@add_start_docstrings_to_model_forward`.
CPMANT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
587
+
588
+
589
@add_start_docstrings(
    "The bare CPMAnt Model outputting raw hidden-states without any specific head on top.",
    CPMANT_START_DOCSTRING,
)
class CpmAntModel(CpmAntPreTrainedModel):
    """Bare CPM-Ant transformer: token + segment embeddings with a learned prompt prefix,
    a decoder stack, and a segment/position attention bias."""

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.encoder = CpmAntEncoder(config)
        self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
        # Embedding table covers the normal vocabulary plus `prompt_types * prompt_length`
        # extra rows used for the learned prompt ids prepended in `forward`.
        self.input_embedding = nn.Embedding(
            config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
        )
        self.position_bias = CpmAntSegmentPositionEmbedding(config)
        self.prompt_length = config.prompt_length
        self.vocab_size = config.vocab_size

        self.post_init()

    def get_input_embeddings(self):
        return self.input_embedding

    def set_input_embeddings(self, embeddings, **kwargs):
        self.input_embedding = embeddings

    def _prepare_attention_mask(self, input_ids, span, context, length):
        """Build the `(batch, seq_len, seq_len)` boolean attention mask.

        A key is visible to a query if (the key is a context position, or the query is a
        non-context position and the key is not in its future), the two share the same
        span id, and neither is a left-padding position.
        """
        batch = input_ids.size(0)
        seqlen = input_ids.size(1)
        device = input_ids.device
        # Causal pattern: position i may see j iff j <= i.
        directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
        attention_mask = context[:, None, :] | (
            context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
        )
        # Restrict attention to positions carrying the same span id.
        attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
        # mask for left padding
        mask_1d = (
            torch.tensor(list(range(seqlen - self.prompt_length))[::-1], device=device)[None, :].repeat(batch, 1)
            < length[:, None]
        )
        # Prompt positions are never treated as padding.
        mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
        attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
        return attention_mask

    @add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # add prompts ahead
        if input_ids.dtype != torch.int32:
            input_ids = input_ids.to(torch.int32)
        dtype, device = input_ids.dtype, input_ids.device
        # Segment id 2 marks real tokens; id 0 marks padding (token id 0).
        segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
        length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
        # Prepend the learned prompt ids, taken from rows
        # [2 * prompt_length + vocab_size, 3 * prompt_length + vocab_size) of the embedding table.
        input_ids = torch.cat(
            (
                torch.arange(
                    self.prompt_length * 2 + self.vocab_size,
                    self.prompt_length * 3 + self.vocab_size,
                    dtype=dtype,
                    device=device,
                ).repeat(input_ids.size(0), 1),
                input_ids,
            ),
            dim=1,
        )
        batch, seq_length = input_ids.size()
        # Prompt positions get segment id 0; context/span are constant here
        # (context all ones, span all zeros), so only padding ends up masked.
        segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
        context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
        position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
        span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.encoder.num_layers)
            input_ids = input_ids.contiguous()
            hidden_states = self.input_embedding(input_ids)
            segment_states = self.segment_embedding(segment)
            hidden_states = hidden_states + segment_states
        else:
            # Incremental decoding: only the newest token is embedded.
            # NOTE(review): this pairs each new token with the last segment embedding —
            # assumes single-token generation steps; confirm for multi-token inputs.
            past_length = past_key_values[0][0].size(-2)
            segment_states = self.segment_embedding(segment)
            hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :]

        attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
        position_bias = self.position_bias(position, position, segment, segment)

        # Keep only the rows for positions not yet covered by the cache.
        attention_mask = attention_mask[:, past_length:, :]
        position_bias = position_bias[:, :, past_length:, :]
        hidden_states = hidden_states[:, past_length:, :]

        hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
            hidden_states,
            attention_mask,
            position_bias,
            output_attentions,
            output_hidden_states,
            past_key_values,
            use_cache,
        )

        if past_length == 0:
            hidden_states = hidden_states[:, self.prompt_length :, :]
            # drop the prompt
            if all_attentions is not None:
                new_attentions = ()
                for attention in all_attentions:
                    new_attentions += (attention[:, :, self.prompt_length :, self.prompt_length :],)
                all_attentions = new_attentions
            if all_hidden_states is not None:
                new_hidden_states = ()
                for hidden_state in all_hidden_states:
                    new_hidden_states += (hidden_state[:, self.prompt_length :, :],)
                all_hidden_states = new_hidden_states

        if not return_dict:
            return tuple(
                v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
732
+
733
+
734
@add_start_docstrings(
    """
    The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
    """,
    CPMANT_START_DOCSTRING,
)
class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
    # Declares the tie between lm_head.weight and cpmant.input_embedding.weight.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.cpmant = CpmAntModel(config)

        # lm_head.weight is tied to cpmant.input_embedding.weight
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
        )
        self.post_init()

    @add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,  # dummy parameter for text-generation pipeline
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                CPMAnt will process attention mask automatically, this parameter is a dummy parameter for
                text-generation pipeline.

        Example:

        Text Generation with CpmAntForCausalLM.
        ```python
        >>> from transformers import CPMAntTokenizer, CpmAntForCausalLM

        >>> texts = "今天天气不错,"
        >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
        >>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
        >>> input_ids = tokenizer(texts, return_tensors="pt")
        >>> outputs = model.generate(**input_ids)
        >>> output_texts = tokenizer.batch_decode(outputs)
        >>> print(output_texts)
        ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Positional arguments follow CpmAntModel.forward's parameter order.
        model_output = self.cpmant(
            input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict
        )
        hidden_states = model_output.last_hidden_state if return_dict else model_output[0]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared at the same positions — no
            # one-token shift is applied here, so callers must supply labels already
            # aligned with the next-token targets; confirm against the CPM-Ant training code.
            loss_func = CrossEntropyLoss()
            loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )

    def get_input_embeddings(self):
        return self.cpmant.input_embedding

    def set_input_embeddings(self, embeddings):
        self.cpmant.input_embedding = embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _reorder_cache(self, past_key_values, beam_idx):
        # Reorder each layer's cached keys/values along the batch axis so the cache
        # follows the beams selected during beam search.
        past_key_values = [list(each) if each is not None else each for each in past_key_values]
        for key_value_layer in past_key_values:
            key_value_layer[0] = key_value_layer[0][beam_idx]
            key_value_layer[1] = key_value_layer[1][beam_idx]
        return past_key_values
858
+
859
+
860
# Names exported when this module is imported via `from ... import *`.
__all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"]
docs/transformers/build/lib/transformers/models/cpmant/tokenization_cpmant.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for CPMAnt."""
16
+
17
+ import collections
18
+ import os
19
+ from typing import List, Optional, Tuple
20
+
21
+ from transformers.utils import is_jieba_available, requires_backends
22
+
23
+
24
+ if is_jieba_available():
25
+ import jieba
26
+
27
+ from ...tokenization_utils import PreTrainedTokenizer
28
+ from ...utils import logging
29
+
30
+
31
# Module-scoped logger following the transformers logging convention.
logger = logging.get_logger(__name__)

# File name the vocabulary is stored under inside a saved tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
34
+
35
+
36
def load_vocab(vocab_file):
    """Load a newline-delimited vocabulary file into an ordered token -> id mapping.

    Each line holds one token; its (0-based) line number becomes its id. Only
    trailing newlines are stripped, so tokens may contain other whitespace.
    """
    with open(vocab_file, "r", encoding="utf-8") as reader:
        return collections.OrderedDict(
            (line.rstrip("\n"), index) for index, line in enumerate(reader)
        )
45
+
46
+
47
class WordpieceTokenizer:
    """Greedy longest-match-first subword tokenizer over a fixed vocabulary."""

    def __init__(self, vocab, unk_token="<unk>", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, token):
        """Split `token` into the longest vocabulary pieces, left to right.

        Over-long inputs collapse to a single unknown token; characters with no
        matching piece of any length each emit the unknown token.
        """
        if len(token) > self.max_input_chars_per_word:
            return [self.unk_token]

        pieces = []
        position = 0
        length = len(token)
        while position < length:
            # Shrink the candidate substring from the right until it is in the vocab.
            stop = length
            while stop > position and token[position:stop] not in self.vocab:
                stop -= 1
            if stop == position:
                # No match of any length: emit <unk> and advance a single character.
                pieces.append(self.unk_token)
                position += 1
            else:
                pieces.append(token[position:stop])
                position = stop

        return pieces
77
+
78
+
79
class CpmAntTokenizer(PreTrainedTokenizer):
    """
    Construct a CPMAnt tokenizer. Based on byte-level Byte-Pair-Encoding.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        bod_token (`str`, *optional*, defaults to `"<d>"`):
            The beginning of document token.
        eod_token (`str`, *optional*, defaults to `"</d>"`):
            The end of document token.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token.
        line_token (`str`, *optional*, defaults to `"</n>"`):
            The line token.
        space_token (`str`, *optional*, defaults to `"</_>"`):
            The space token.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    add_prefix_space = False

    def __init__(
        self,
        vocab_file,
        bod_token="<d>",
        eod_token="</d>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        unk_token="<unk>",
        line_token="</n>",
        space_token="</_>",
        padding_side="left",
        **kwargs,
    ):
        # jieba is required for the coarse word segmentation in `_tokenize`.
        requires_backends(self, ["jieba"])
        self.bod_token = bod_token
        self.eod_token = eod_token
        self.encoder = load_vocab(vocab_file)
        # Remap the surrogate space/newline tokens onto the literal characters
        # so raw text can be matched against the vocabulary directly.
        self.encoder[" "] = self.encoder[space_token]
        self.encoder["\n"] = self.encoder[line_token]

        del self.encoder[space_token]
        del self.encoder[line_token]

        # Keep the mapping ordered by token id; `decoder` is the inverse map.
        self.encoder = collections.OrderedDict(sorted(self.encoder.items(), key=lambda x: x[1]))
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder, unk_token=unk_token)

        # NOTE: the vocab attributes above must exist before calling
        # super().__init__(), which may already invoke tokenization helpers.
        super().__init__(
            bod_token=bod_token,
            eod_token=eod_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            line_token=line_token,
            space_token=space_token,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def bod_token_id(self):
        # Id of the beginning-of-document token.
        return self.encoder[self.bod_token]

    @property
    def eod_token_id(self):
        # Id of the end-of-document token.
        return self.encoder[self.eod_token]

    @property
    def newline_id(self):
        # Id of the literal newline character (remapped from `line_token` in __init__).
        return self.encoder["\n"]

    @property
    def vocab_size(self) -> int:
        # Size of the base vocabulary (excludes tokens added after init).
        return len(self.encoder)

    def get_vocab(self):
        # Base vocabulary merged with any tokens added after initialization.
        return dict(self.encoder, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Tokenize a string."""
        # Coarse segmentation with jieba, then wordpiece within each segment.
        output_tokens = []
        for x in jieba.cut(text, cut_all=False):
            output_tokens.extend(self.wordpiece_tokenizer.tokenize(x))
        return output_tokens

    def _decode(self, token_ids, **kwargs):
        """Decode ids into a string."""
        # Drop invalid (negative) ids and the pad/eos/bos specials before decoding.
        token_ids = [i for i in token_ids if i >= 0]
        token_ids = [
            x for x in token_ids if x != self.pad_token_id and x != self.eos_token_id and x != self.bos_token_id
        ]
        return super()._decode(token_ids, **kwargs)

    def check(self, token):
        # True if `token` belongs to the base vocabulary.
        return token in self.encoder

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # CPMAnt pieces carry their own spacing, so plain concatenation suffices.
        return "".join(tokens)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Write the vocabulary back to disk, one token per line, ordered by id.
        # If `save_directory` is not a directory it is treated as the file path.
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        index = 0
        # Restore the on-disk surrogate tokens for space/newline — the inverse
        # of the substitution done in __init__. NOTE(review): this mutates
        # self.encoder in place.
        if " " in self.encoder:
            self.encoder["</_>"] = self.encoder[" "]
            del self.encoder[" "]
        if "\n" in self.encoder:
            self.encoder["</n>"] = self.encoder["\n"]
            del self.encoder["\n"]
        self.encoder = collections.OrderedDict(sorted(self.encoder.items(), key=lambda x: x[1]))
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in self.encoder.items():
                if index != token_index:
                    # A gap in the ids means the vocabulary is non-consecutive
                    # (round-tripping it would change token ids).
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: List[int] = None) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A CPMAnt sequence has the following format:

        - single sequence: `[BOS] Sequence`.

        Args:
            token_ids_0 (`List[int]`): The first tokenized sequence that special tokens will be added.
            token_ids_1 (`List[int]`): The optional second tokenized sequence that special tokens will be added.

        Returns:
            `List[int]`: The model input with special tokens.
        """
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0
        return [self.bos_token_id] + token_ids_0 + [self.bos_token_id] + token_ids_1

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`): List of IDs.
            token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Each sequence is prefixed by one BOS (see build_inputs_with_special_tokens).
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
        return [1] + ([0] * len(token_ids_0))
268
+
269
+
270
+ __all__ = ["CpmAntTokenizer"]
docs/transformers/build/lib/transformers/models/ctrl/configuration_ctrl.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 Salesforce and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Salesforce CTRL configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
# Module-scoped logger following the transformers logging convention.
logger = logging.get_logger(__name__)
22
+
23
+
24
class CTRLConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`CTRLModel`] or a [`TFCTRLModel`]. It is used to
    instantiate a CTRL model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the
    [Salesforce/ctrl](https://huggingface.co/Salesforce/ctrl) architecture from SalesForce.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 246534):
            Vocabulary size of the CTRL model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`CTRLModel`] or [`TFCTRLModel`].
        n_positions (`int`, *optional*, defaults to 256):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 1280):
            Dimensionality of the embeddings and hidden states.
        dff (`int`, *optional*, defaults to 8192):
            Dimensionality of the inner dimension of the feed forward networks (FFN).
        n_layer (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`int`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-06):
            The epsilon to use in the layer normalization layers
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).


    Examples:

    ```python
    >>> from transformers import CTRLConfig, CTRLModel

    >>> # Initializing a CTRL configuration
    >>> configuration = CTRLConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = CTRLModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ctrl"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map the generic transformers attribute names onto CTRL's historical ones.
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=246534,
        n_positions=256,
        n_embd=1280,
        dff=8192,
        n_layer=48,
        n_head=16,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        layer_norm_epsilon=1e-6,
        initializer_range=0.02,
        use_cache=True,
        **kwargs,
    ):
        # Store every architecture hyper-parameter on the instance, then let
        # PretrainedConfig consume the remaining generic kwargs.
        for attribute, value in (
            ("vocab_size", vocab_size),
            ("n_positions", n_positions),
            ("n_embd", n_embd),
            ("n_layer", n_layer),
            ("n_head", n_head),
            ("dff", dff),
            ("resid_pdrop", resid_pdrop),
            ("embd_pdrop", embd_pdrop),
            ("layer_norm_epsilon", layer_norm_epsilon),
            ("initializer_range", initializer_range),
            ("use_cache", use_cache),
        ):
            setattr(self, attribute, value)

        super().__init__(**kwargs)
114
+
115
+
116
+ __all__ = ["CTRLConfig"]
docs/transformers/build/lib/transformers/models/ctrl/modeling_ctrl.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 Salesforce and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch CTRL model."""
17
+
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from ...generation import GenerationMixin
26
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
29
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
30
+ from .configuration_ctrl import CTRLConfig
31
+
32
+
33
# Module-scoped logger following the transformers logging convention.
logger = logging.get_logger(__name__)

# Config class name substituted into the auto-generated docstrings below.
_CONFIG_FOR_DOC = "CTRLConfig"
36
+
37
+
38
def angle_defn(pos, i, d_model_size):
    """Angles for the sinusoidal positional encoding.

    Dimension pairs (2j, 2j+1) share one frequency, 1 / 10000^(2j / d_model),
    so the matching sine/cosine columns rotate at the same rate.
    """
    inv_freq = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
    return pos * inv_freq
41
+
42
+
43
def positional_encoding(position, d_model_size, dtype):
    """Build the (position, d_model_size) sinusoidal positional-encoding table."""
    positions = torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1)
    dims = torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0)
    # Inverse-frequency angles: dimension pairs (2j, 2j+1) share one frequency.
    angle_rads = positions * (1 / torch.pow(10000, (2 * (dims // 2)) / d_model_size))

    # Sine over even feature columns, cosine over odd ones; the halves are
    # concatenated (not interleaved), matching the original CTRL layout.
    sines = torch.sin(angle_rads[:, 0::2])
    cosines = torch.cos(angle_rads[:, 1::2])
    return torch.cat([sines, cosines], dim=-1)
56
+
57
+
58
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    """Scaled dot-product attention with optional causal, padding and head masks.

    Shapes follow (batch, heads, seq, depth); returns (context, attention_weights).
    """
    dk = k.shape[-1]
    # (..., q_len, depth) x (..., depth, k_len) -> raw attention logits.
    logits = torch.matmul(q, k.permute(0, 1, 3, 2)) / dk**0.5

    if mask is not None:
        q_len, k_len = logits.size(-2), logits.size(-1)
        # Large negative bias on masked positions; the slice aligns the causal
        # mask's lower-right corner with the query block when cached keys make
        # k_len exceed q_len.
        logits += mask[k_len - q_len : k_len, :k_len] * -1e4

    if attention_mask is not None:
        # Additive padding mask supplied by the caller.
        logits = logits + attention_mask

    attention_weights = torch.softmax(logits, dim=-1)

    if head_mask is not None:
        # Zero out pruned/masked heads.
        attention_weights = attention_weights * head_mask

    return torch.matmul(attention_weights, v), attention_weights
82
+
83
+
84
class MultiHeadAttention(nn.Module):
    """Multi-head attention block used by CTRL's encoder layers."""

    def __init__(self, d_model_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model_size = d_model_size

        # Feature size handled by each individual head.
        self.depth = int(d_model_size / self.num_heads)

        self.Wq = nn.Linear(d_model_size, d_model_size)
        self.Wk = nn.Linear(d_model_size, d_model_size)
        self.Wv = nn.Linear(d_model_size, d_model_size)

        self.dense = nn.Linear(d_model_size, d_model_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads and shrink the projections accordingly."""
        attention_head_size = self.d_model_size // self.num_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)

        # Shrink q/k/v on their output dim, `dense` on its input dim.
        for proj_name in ("Wq", "Wk", "Wv"):
            setattr(self, proj_name, prune_linear_layer(getattr(self, proj_name), index))
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Book-keeping so later calls see the reduced geometry.
        self.num_heads = self.num_heads - len(heads)
        self.d_model_size = attention_head_size * self.num_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def split_into_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, heads, seq, depth)
        reshaped = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return reshaped.permute([0, 2, 1, 3])

    def forward(
        self,
        v,
        k,
        q,
        mask,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        batch_size = q.shape[0]

        # Project each stream, then split into per-head views.
        q = self.split_into_heads(self.Wq(q), batch_size)
        k = self.split_into_heads(self.Wk(k), batch_size)
        v = self.split_into_heads(self.Wv(v), batch_size)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence axis.
            k = torch.cat((layer_past[0], k), dim=-2)
            v = torch.cat((layer_past[1], v), dim=-2)

        # Cache for the next decoding step, or a (None,) placeholder.
        present = torch.stack((k, v)) if use_cache is True else (None,)

        context, attention_weights = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        # (batch, heads, seq, depth) -> (batch, seq, d_model), then final projection.
        context = context.permute([0, 2, 1, 3]).reshape(batch_size, -1, self.d_model_size)
        projected = self.dense(context)

        outputs = (projected, present)
        if output_attentions:
            outputs = outputs + (attention_weights,)
        return outputs
161
+
162
+
163
def point_wise_feed_forward_network(d_model_size, dff):
    """Two-layer position-wise MLP: expand to `dff`, ReLU, project back."""
    layers = [
        nn.Linear(d_model_size, dff),
        nn.ReLU(),
        nn.Linear(dff, d_model_size),
    ]
    return nn.Sequential(*layers)
165
+
166
+
167
class EncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention followed by a position-wise MLP."""

    def __init__(self, d_model_size, num_heads, dff, rate=0.1):
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model_size, dff)

        self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(d_model_size, eps=1e-6)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(
        self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
    ):
        # Self-attention over the pre-normalized input (q = k = v).
        normed = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
            normed,
            normed,
            normed,
            mask,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # First residual connection around the attention sub-layer.
        attn_residual = x + self.dropout1(attn_outputs[0])

        # MLP over the re-normalized stream, then the second residual connection.
        mlp_out = self.dropout2(self.ffn(self.layernorm2(attn_residual)))
        block_out = attn_residual + mlp_out

        # Forward the cache/attention extras produced by the attention module.
        return (block_out,) + attn_outputs[1:]
206
+
207
+
208
class CTRLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CTRLConfig
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """Initialize the weights."""
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization; cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding embedding stays at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm starts out as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
232
+
233
+
234
# Shared docstring templates: injected into the model classes below via the
# `add_start_docstrings` / `add_start_docstrings_to_model_forward` decorators.
CTRL_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`CTRLConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CTRL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.FloatTensor]]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as input ids as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
309
+
310
+
311
+ @add_start_docstrings(
312
+ "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
313
+ CTRL_START_DOCSTRING,
314
+ )
315
+ class CTRLModel(CTRLPreTrainedModel):
316
+ def __init__(self, config):
317
+ super().__init__(config)
318
+
319
+ self.d_model_size = config.n_embd
320
+ self.num_layers = config.n_layer
321
+
322
+ self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
323
+
324
+ self.w = nn.Embedding(config.vocab_size, config.n_embd)
325
+
326
+ self.dropout = nn.Dropout(config.embd_pdrop)
327
+ self.h = nn.ModuleList(
328
+ [EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)]
329
+ )
330
+ self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
331
+
332
+ # Initialize weights and apply final processing
333
+ self.post_init()
334
+
335
+ def get_input_embeddings(self):
336
+ return self.w
337
+
338
+ def set_input_embeddings(self, new_embeddings):
339
+ self.w = new_embeddings
340
+
341
+ def _prune_heads(self, heads_to_prune):
342
+ """
343
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
344
+ """
345
+ for layer, heads in heads_to_prune.items():
346
+ self.h[layer].multi_head_attention.prune_heads(heads)
347
+
348
    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # NOOP kwargs, for now
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CTRLModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLModel.from_pretrained("Salesforce/ctrl")

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 5, 1280]
        ```"""
        # Resolve output/caching flags, falling back to the model config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # `past_length` is the number of tokens already cached per layer (0 on the first call).
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            # Positions continue from the end of the cache.
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # Prepare head mask if needed (one entry per layer).
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
            # Token-type embeddings share the word-embedding table and the same sqrt(d_model) scaling.
            token_type_embeds = self.w(token_type_ids)
            token_type_embeds *= np.sqrt(self.d_model_size)
        else:
            token_type_embeds = 0

        if inputs_embeds is None:
            inputs_embeds = self.w(input_ids)
        # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
        seq_len = input_shape[-1]
        # Upper-triangular causal mask over the full (past + current) sequence length.
        mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(device)

        # Scale embeddings by sqrt(d_model), as in the original CTRL/Transformer formulation.
        inputs_embeds *= np.sqrt(self.d_model_size)

        # `self.pos_encoding` won't be sent to the correct device along the model, so we do it manually.
        self.pos_encoding = self.pos_encoding.to(device)
        pos_embeds = self.pos_encoding[position_ids, :]

        hidden_states = inputs_embeds + pos_embeds + token_type_embeds

        hidden_states = self.dropout(hidden_states)

        # Accumulators for optional outputs.
        presents = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = h(
                hidden_states,
                mask,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_attentions += (outputs[2],)

        # Final layer norm after the stack, then append the last hidden state if requested.
        hidden_states = self.layernorm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
499
+
500
+
501
@add_start_docstrings(
    """
    The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    CTRL_START_DOCSTRING,
)
class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
    # Keys tied to the input embeddings (handled by the weight-tying machinery).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = CTRLModel(config)
        # Projection to the vocabulary; note it carries a bias, unlike some other LM heads.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CTRLLMHeadModel

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> sequence_ids = model.generate(inputs["input_ids"])
        >>> sequences = tokenizer.batch_decode(sequence_ids)
        >>> sequences
        ['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']

        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> round(outputs.loss.item(), 2)
        9.21

        >>> list(outputs.logits.shape)
        [1, 5, 246534]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the base transformer; logits are produced from its last hidden state.
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # `loss_function` handles the causal shift of labels internally.
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
        # Overwritten -- inputs_embeds not working properly

        # only last tokens for inputs_ids if past is defined in kwargs
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        # Select the batch entries of every cached tensor according to the chosen beams.
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
648
+
649
+
650
@add_start_docstrings(
    """
    The CTRL Model transformer with a sequence classification head on top (linear layer).
    [`CTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last
    token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in
    each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last
    value in each row of the batch).
    """,
    CTRL_START_DOCSTRING,
)
class CTRLForSequenceClassification(CTRLPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = CTRLModel(config)
        # Classification head on the pooled (last non-pad) token; no bias by design.
        self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Example of single-label classification:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CTRLForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_class_id = logits.argmax().item()
        >>> model.config.id2label[predicted_class_id]
        'LABEL_0'
        ```

        ```python
        >>> import torch

        >>> torch.manual_seed(42)  # doctest: +IGNORE_RESULT
        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
        >>> num_labels = len(model.config.id2label)
        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)

        >>> labels = torch.tensor(1)
        >>> loss = model(**inputs, labels=labels).loss
        >>> round(loss.item(), 2)
        0.93
        ```

        Example of multi-label classification:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CTRLForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
        >>> model = CTRLForSequenceClassification.from_pretrained(
        ...     "Salesforce/ctrl", problem_type="multi_label_classification"
        ... )

        >>> # CTRL was trained with control codes as the first token
        >>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
        >>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_class_id = logits.argmax().item()
        >>> model.config.id2label[predicted_class_id]
        'LABEL_0'
        ```

        ```python
        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
        >>> num_labels = len(model.config.id2label)
        >>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)

        >>> num_labels = len(model.config.id2label)
        >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
        ...     torch.float
        ... )
        >>> loss = model(**inputs, labels=labels).loss
        >>> loss.backward()  # doctest: +IGNORE_RESULT
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        # Per-token logits; pooled below to one logit vector per sequence.
        logits = self.classifier(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # Locate the token used for pooling: last non-pad token when pad_token_id is known,
        # otherwise the final position of each row.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            # Infer the problem type once (regression / single-label / multi-label) if unset.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
842
+
843
+
844
# Public API of this module, re-exported by the package's __init__.
__all__ = ["CTRLForSequenceClassification", "CTRLLMHeadModel", "CTRLModel", "CTRLPreTrainedModel"]
docs/transformers/build/lib/transformers/models/ctrl/modeling_tf_ctrl.py ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 Salesforce and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """TF 2.0 CTRL model."""
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+ from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput
26
+ from ...modeling_tf_utils import (
27
+ TFCausalLanguageModelingLoss,
28
+ TFModelInputType,
29
+ TFPreTrainedModel,
30
+ TFSequenceClassificationLoss,
31
+ get_initializer,
32
+ keras,
33
+ keras_serializable,
34
+ unpack_inputs,
35
+ )
36
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
37
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
38
+ from .configuration_ctrl import CTRLConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CHECKPOINT_FOR_DOC = "Salesforce/ctrl"
44
+ _CONFIG_FOR_DOC = "CTRLConfig"
45
+
46
+
47
+ def angle_defn(pos, i, d_model_size):
48
+ angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
49
+ return pos * angle_rates
50
+
51
+
52
+ def positional_encoding(position, d_model_size):
53
+ # create the sinusoidal pattern for the positional encoding
54
+ angle_rads = angle_defn(np.arange(position)[:, np.newaxis], np.arange(d_model_size)[np.newaxis, :], d_model_size)
55
+
56
+ sines = np.sin(angle_rads[:, 0::2])
57
+ cosines = np.cos(angle_rads[:, 1::2])
58
+ pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
59
+
60
+ return pos_encoding
61
+
62
+
63
+ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
64
+ # calculate attention
65
+ matmul_qk = tf.matmul(q, k, transpose_b=True)
66
+
67
+ dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
68
+ scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
69
+
70
+ if mask is not None:
71
+ scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
72
+
73
+ if attention_mask is not None:
74
+ # Apply the attention mask
75
+ attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
76
+ scaled_attention_logits = scaled_attention_logits + attention_mask
77
+
78
+ attention_weights = stable_softmax(scaled_attention_logits, axis=-1)
79
+
80
+ # Mask heads if we want to
81
+ if head_mask is not None:
82
+ attention_weights = attention_weights * head_mask
83
+
84
+ output = tf.matmul(attention_weights, v)
85
+
86
+ return output, attention_weights
87
+
88
+
89
class TFMultiHeadAttention(keras.layers.Layer):
    """Multi-head scaled-dot-product attention block used by each TF-CTRL encoder layer."""

    def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model_size = d_model_size
        self.output_attentions = output_attentions

        # Per-head feature width; assumes d_model_size is divisible by num_heads — TODO confirm upstream.
        self.depth = int(d_model_size / self.num_heads)

        # Sub-layer names ("Wq"/"Wk"/"Wv"/"dense") must stay fixed to load checkpoint weights.
        self.Wq = keras.layers.Dense(d_model_size, name="Wq")
        self.Wk = keras.layers.Dense(d_model_size, name="Wk")
        self.Wv = keras.layers.Dense(d_model_size, name="Wv")

        self.dense = keras.layers.Dense(d_model_size, name="dense")

    def split_into_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, num_heads, seq, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        batch_size = shape_list(q)[0]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, batch_size)
        k = self.split_into_heads(k, batch_size)
        v = self.split_into_heads(v, batch_size)

        # Prepend cached keys/values along the sequence axis for incremental decoding.
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=0)
            k = tf.concat((past_key, k), axis=-2)
            v = tf.concat((past_value, v), axis=-2)

        # When caching, keys/values are stacked into one tensor; otherwise a (None,) placeholder.
        if use_cache:
            present = tf.stack((k, v), axis=0)
        else:
            present = (None,)

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
        attn = output[1]
        # Merge heads back into a single (batch, seq, d_model) tensor before the output projection.
        original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
        output = self.dense(original_size_attention)
        outputs = (output, present)

        if output_attentions:
            outputs = outputs + (attn,)

        return outputs

    def build(self, input_shape=None):
        # Build sub-layers explicitly under their own name scopes so weight names stay stable.
        if self.built:
            return
        self.built = True
        if getattr(self, "Wq", None) is not None:
            with tf.name_scope(self.Wq.name):
                self.Wq.build([None, None, self.d_model_size])
        if getattr(self, "Wk", None) is not None:
            with tf.name_scope(self.Wk.name):
                self.Wk.build([None, None, self.d_model_size])
        if getattr(self, "Wv", None) is not None:
            with tf.name_scope(self.Wv.name):
                self.Wv.build([None, None, self.d_model_size])
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.d_model_size])
+ self.dense.build([None, None, self.d_model_size])
157
+
158
+
159
class TFPointWiseFeedForwardLayer(keras.layers.Layer):
    """Position-wise feed-forward block: Dense(dff, relu) -> Dense(d_model_size)."""

    def __init__(self, d_model_size, dff, **kwargs):
        super().__init__(**kwargs)

        # Numeric layer names ("0", "2") mirror the original checkpoint's weight naming.
        self.dense_0 = keras.layers.Dense(dff, activation="relu", name="0")
        self.dense_2 = keras.layers.Dense(d_model_size, name="2")
        self.d_model_size = d_model_size
        self.dff = dff

    # NOTE(review): the parameter is spelled `trainable`, not the usual Keras `training`,
    # and is unused — kept as-is for interface compatibility.
    def call(self, inputs, trainable=False):
        dense_0_output = self.dense_0(inputs)
        dense_2_output = self.dense_2(dense_0_output)

        return dense_2_output

    def build(self, input_shape=None):
        # Build sub-layers explicitly so weight shapes/names are fixed.
        if self.built:
            return
        self.built = True
        if getattr(self, "dense_0", None) is not None:
            with tf.name_scope(self.dense_0.name):
                self.dense_0.build([None, None, self.d_model_size])
        if getattr(self, "dense_2", None) is not None:
            with tf.name_scope(self.dense_2.name):
                self.dense_2.build([None, None, self.dff])
+ self.dense_2.build([None, None, self.dff])
184
+
185
+
186
class TFEncoderLayer(keras.layers.Layer):
    """One TF-CTRL transformer block: pre-layernorm attention + pre-layernorm FFN, each with residual add."""

    def __init__(
        self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
    ):
        super().__init__(**kwargs)

        self.output_attentions = output_attentions

        self.multi_head_attention = TFMultiHeadAttention(
            d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
        )
        self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")

        # `rate` is the residual dropout probability applied after attention and FFN.
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)
        self.d_model_size = d_model_size

    def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        # Pre-norm self-attention with residual connection.
        normed = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
            normed,
            normed,
            normed,
            mask,
            layer_past,
            attention_mask,
            head_mask,
            use_cache,
            output_attentions,
            training=training,
        )
        attn_output = attn_outputs[0]
        attn_output = self.dropout1(attn_output, training=training)
        out1 = x + attn_output

        # Pre-norm feed-forward with residual connection.
        out2 = self.layernorm2(out1)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = out1 + ffn_output

        # Forward the attention block's extras (present cache, optional attention weights).
        outputs = (out2,) + attn_outputs[1:]
        return outputs

    def build(self, input_shape=None):
        # Build sub-layers explicitly under their own name scopes for stable weight naming.
        if self.built:
            return
        self.built = True
        if getattr(self, "multi_head_attention", None) is not None:
            with tf.name_scope(self.multi_head_attention.name):
                self.multi_head_attention.build(None)
        if getattr(self, "ffn", None) is not None:
            with tf.name_scope(self.ffn.name):
                self.ffn.build(None)
        if getattr(self, "layernorm1", None) is not None:
            with tf.name_scope(self.layernorm1.name):
                self.layernorm1.build([None, None, self.d_model_size])
        if getattr(self, "layernorm2", None) is not None:
            with tf.name_scope(self.layernorm2.name):
                self.layernorm2.build([None, None, self.d_model_size])
+ self.layernorm2.build([None, None, self.d_model_size])
248
+
249
+
250
@keras_serializable
class TFCTRLMainLayer(keras.layers.Layer):
    # Core CTRL transformer stack: token embeddings + fixed sinusoidal position
    # encodings, a stack of TFEncoderLayer blocks, and a final layer norm.
    config_class = CTRLConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache
        self.return_dict = config.use_return_dict

        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer

        # Precomputed (non-learned) sinusoidal position encodings.
        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)

        # Token embedding matrix; also reused as the tied LM-head weights by TFCTRLLMHeadModel.
        self.w = keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.n_embd,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="w",
        )

        self.dropout = keras.layers.Dropout(config.embd_pdrop)
        self.h = [
            TFEncoderLayer(
                config.n_embd,
                config.n_head,
                config.dff,
                config.resid_pdrop,
                config.layer_norm_epsilon,
                self.output_attentions,
                name=f"h_._{i}",
            )
            for i in range(config.n_layer)
        ]
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")

    def get_input_embeddings(self):
        # Returns the token embedding layer (shared with the LM head).
        return self.w

    def set_input_embeddings(self, new_embeddings):
        self.w = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPast]:
        # If using past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_length = 0
            past_key_values = [None] * len(self.h)
        else:
            # Infer how many positions are already cached from the key tensor's seq axis.
            past_length = shape_list(past_key_values[0][0])[-2]
        if position_ids is None:
            # Default positions continue from the end of the cached prefix.
            position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0)
            position_ids = tf.tile(position_ids, [input_shape[0], 1])

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.

            one_cst = tf.constant(1.0)
            ten_thousand_cst = tf.constant(-10000.0)
            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
            attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), ten_thousand_cst)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_layers

        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.w(token_type_ids)
            # Scale like the token embeddings (sqrt(d_model)), matching the original CTRL.
            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
        else:
            # Scalar zero broadcasts harmlessly into the embedding sum below.
            token_type_embeds = tf.constant(0.0)
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.w.input_dim)
            inputs_embeds = self.w(input_ids)
        seq_len = input_shape[-1]
        # Causal mask: 1.0 above the diagonal (future positions), 0.0 elsewhere.
        mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

        inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, inputs_embeds.dtype))

        pos_embeds = tf.gather(self.pos_encoding, position_ids)
        pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
        hidden_states = inputs_embeds + pos_embeds + token_type_embeds

        hidden_states = self.dropout(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]
        presents = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
            outputs = h(
                hidden_states,
                mask,
                layer_past,
                attention_mask,
                head_mask[i],
                use_cache,
                output_attentions,
                training=training,
            )
            hidden_states, present = outputs[:2]

            if use_cache:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)

        hidden_states = self.layernorm(hidden_states)
        hidden_states = tf.reshape(hidden_states, output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def build(self, input_shape=None):
        # Idempotent Keras build: create sublayer weights under their own name scopes.
        if self.built:
            return
        self.built = True
        if getattr(self, "w", None) is not None:
            with tf.name_scope(self.w.name):
                self.w.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.n_embd])
        if getattr(self, "h", None) is not None:
            for layer in self.h:
                with tf.name_scope(layer.name):
                    layer.build(None)
457
+
458
+
459
class TFCTRLPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Config class consulted by from_pretrained / save_pretrained.
    config_class = CTRLConfig
    # Attribute name under which subclasses store the TFCTRLMainLayer.
    base_model_prefix = "transformer"
467
+
468
+
469
# Class-level docstring shared by all CTRL model heads; injected via @add_start_docstrings.
CTRL_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`CTRLConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
510
+
511
# Per-argument documentation injected into each model's `call` via
# @add_start_docstrings_to_model_forward.
CTRL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of
            input past key value states).

            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        past (`List[tf.Tensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past` output below). Can be used to speed up sequential decoding. The token ids which have their past
            given to this model should not be passed as input ids as they have already been computed.
        attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past` key value states are returned and can be used to speed up decoding (see `past`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
576
+
577
+
578
@add_start_docstrings(
    "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
    CTRL_START_DOCSTRING,
)
class TFCTRLModel(TFCTRLPreTrainedModel):
    """Thin head-less wrapper that exposes `TFCTRLMainLayer` as a standalone model."""

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # The shared CTRL transformer stack does all of the work.
        self.transformer = TFCTRLMainLayer(config, name="transformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPast]:
        # Pure delegation: forward every argument to the main layer unchanged
        # and hand back whatever it produces (tuple or TFBaseModelOutputWithPast).
        return self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        # Idempotent build: construct the wrapped transformer under its name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
632
+
633
+
634
class TFCTRLBiasLayer(keras.layers.Layer):
    """
    Wraps a single bias vector in its own Keras layer. `keras.Model.save_weights` serializes weights on a
    per-layer basis, so registering the bias inside a layer is what makes it (de)serializable.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Remember the construction parameters; the actual variable is created in build().
        self.shape = shape
        self.initializer = initializer
        self.trainable = trainable

    def build(self, input_shape):
        # Register the bias variable so Keras tracks and serializes it.
        self.bias = self.add_weight(
            name="bias",
            shape=self.shape,
            initializer=self.initializer,
            trainable=self.trainable,
        )
        super().build(input_shape)

    def call(self, x):
        # Broadcast-add the bias onto the incoming tensor.
        return x + self.bias
654
+
655
+
656
@add_start_docstrings(
    """
    The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    CTRL_START_DOCSTRING,
)
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")
        # The LM head reuses the embedding matrix; only the output bias is a new
        # weight, kept in its own layer so save_weights can serialize it.
        self.bias_layer = TFCTRLBiasLayer(
            name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
        )

    def get_output_embeddings(self):
        # Output embeddings are tied to the input embedding matrix.
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        self.set_input_embeddings(value)

    def get_bias(self):
        return {"lm_head.bias": self.bias_layer.bias}

    def set_bias(self, value):
        # Replaces the existing layers containing bias for correct (de)serialization.
        vocab_size = value["lm_head.bias"].shape[-1]
        self.bias_layer = TFCTRLBiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
        )
        self.bias_layer.build(None)
        self.bias_layer.bias.assign(value["lm_head.bias"])

    # Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            inputs = tf.expand_dims(inputs[:, -1], -1)
            if token_type_ids is not None:
                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

        position_ids = kwargs.get("position_ids", None)
        attention_mask = kwargs.get("attention_mask", None)

        if attention_mask is not None and position_ids is None:
            # Derive position ids from the running count of attended (non-pad) tokens.
            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
            if past_key_values:
                position_ids = tf.expand_dims(position_ids[:, -1], -1)

        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "token_type_ids": token_type_ids,
        }

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFCausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFCausalLMOutputWithPast]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
            config.vocab_size - 1]`.
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = transformer_outputs[0]
        # Project onto the vocabulary with the (tied) embedding matrix, then add the bias.
        logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
        logits = self.bias_layer(logits)

        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
            shifted_logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.hf_compute_loss(labels, shifted_logits)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build(self, input_shape=None):
        # Idempotent build of the transformer stack and the LM-head bias layer.
        if self.built:
            return
        self.built = True
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)
790
+
791
+
792
@add_start_docstrings(
    """
    The CTRL Model transformer with a sequence classification head on top (linear layer).

    [`TFCTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1, GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    CTRL_START_DOCSTRING,
)
class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        # Bias-free linear head mapping the pooled hidden state to class logits.
        self.classifier = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
            use_bias=False,
        )
        self.transformer = TFCTRLMainLayer(config, name="transformer")
        self.config = config

    def get_output_embeddings(self):
        # Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too.
        logger.warning(
            "Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed "
            "in transformers v4.32."
        )
        return self.transformer.w

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFSequenceClassifierOutput]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
            config.vocab_size - 1]`.
        """

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = transformer_outputs[0]
        logits = self.classifier(hidden_states)
        logits_shape = shape_list(logits)
        batch_size = logits_shape[0]

        # Locate, per batch row, the index of the token whose logits should be
        # used for classification (last non-padding token when possible).
        if self.config.pad_token_id is None:
            last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
        else:
            if input_ids is not None:
                # Highest index whose token is not the pad token; pad positions
                # contribute 0 to the max via the mask.
                token_indices = tf.range(shape_list(input_ids)[-1])
                non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
                last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
            else:
                # With inputs_embeds padding cannot be detected; fall back to the last position.
                last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )
        loss = None

        # Gather one logits row per example at its pooling index.
        pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)

        if labels is not None:
            if self.config.pad_token_id is None and logits_shape[0] != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

            loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build(self, input_shape=None):
        # Idempotent build of the classifier head and the transformer stack.
        if self.built:
            return
        self.built = True
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.n_embd])
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
920
+
921
+
922
# Public symbols re-exported by the ctrl model package.
__all__ = ["TFCTRLForSequenceClassification", "TFCTRLLMHeadModel", "TFCTRLModel", "TFCTRLPreTrainedModel"]
docs/transformers/build/lib/transformers/models/ctrl/tokenization_ctrl.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 Salesforce and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Salesforce CTRL."""
16
+
17
+ import json
18
+ import os
19
+ from typing import Optional, Tuple
20
+
21
+ import regex as re
22
+
23
+ from ...tokenization_utils import PreTrainedTokenizer
24
+ from ...utils import logging
25
+
26
+
27
logger = logging.get_logger(__name__)

# File names expected inside a pretrained CTRL checkpoint directory.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}
33
+
34
+
35
# CTRL control-code prompts mapped to their vocabulary ids. Prepending one of
# these tokens steers generation toward the named domain/style.
CONTROL_CODES = {
    "Pregnancy": 168629,
    "Christianity": 7675,
    "Explain": 106423,
    "Fitness": 63440,
    "Saving": 63163,
    "Ask": 27171,
    "Ass": 95985,
    "Joke": 163509,
    "Questions": 45622,
    "Thoughts": 49605,
    "Retail": 52342,
    "Feminism": 164338,
    "Writing": 11992,
    "Atheism": 192263,
    "Netflix": 48616,
    "Computing": 39639,
    "Opinion": 43213,
    "Alone": 44967,
    "Funny": 58917,
    "Gaming": 40358,
    "Human": 4088,
    "India": 1331,
    "Joker": 77138,
    "Diet": 36206,
    "Legal": 11859,
    "Norman": 4939,
    "Tip": 72689,
    "Weight": 52343,
    "Movies": 46273,
    "Running": 23425,
    "Science": 2090,
    "Horror": 37793,
    "Confession": 60572,
    "Finance": 12250,
    "Politics": 16360,
    "Scary": 191985,
    "Support": 12654,
    "Technologies": 32516,
    "Teenage": 66160,
    "Event": 32769,
    "Learned": 67460,
    "Notion": 182770,
    "Wikipedia": 37583,
    "Books": 6665,
    "Extract": 76050,
    "Confessions": 102701,
    "Conspiracy": 75932,
    "Links": 63674,
    "Narcissus": 150425,
    "Relationship": 54766,
    "Relationships": 134796,
    "Reviews": 41671,
    "News": 4256,
    "Translation": 26820,
    "multilingual": 128406,
}
92
+
93
+
94
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    Returns an empty set for words with fewer than two symbols (the original
    implementation raised IndexError on an empty word).
    """
    # zip(word, word[1:]) yields each adjacent (prev, next) pair exactly once;
    # this replaces the manual loop and the redundant `set(pairs)` re-wrap.
    return set(zip(word, word[1:]))
108
+
109
+
110
+ class CTRLTokenizer(PreTrainedTokenizer):
111
+ """
112
+ Construct a CTRL tokenizer. Based on Byte-Pair-Encoding.
113
+
114
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
115
+ this superclass for more information regarding those methods.
116
+
117
+ Args:
118
+ vocab_file (`str`):
119
+ Path to the vocabulary file.
120
+ merges_file (`str`):
121
+ Path to the merges file.
122
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
123
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
124
+ token instead.
125
+ """
126
+
127
+ vocab_files_names = VOCAB_FILES_NAMES
128
+ control_codes = CONTROL_CODES
129
+
130
    def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
        # Load the token -> id vocabulary and build the reverse id -> token map.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Load BPE merge rules; [1:-1] drops the header line and the trailing empty line.
        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        # Earlier merges get lower ranks, i.e. higher merge priority in bpe().
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Memoization cache for bpe() results.
        self.cache = {}
        # NOTE(review): vocab/merges are loaded before the superclass __init__ —
        # presumably because the base class may consult them; keep this ordering.
        super().__init__(unk_token=unk_token, **kwargs)
140
+
141
+ @property
142
+ def vocab_size(self):
143
+ return len(self.encoder)
144
+
145
+ def get_vocab(self):
146
+ return dict(self.encoder, **self.added_tokens_encoder)
147
+
148
+ def bpe(self, token):
149
+ if token in self.cache:
150
+ return self.cache[token]
151
+ word = tuple(token)
152
+ word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
153
+ pairs = get_pairs(word)
154
+
155
+ if not pairs:
156
+ return token
157
+
158
+ while True:
159
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
160
+ if bigram not in self.bpe_ranks:
161
+ break
162
+ first, second = bigram
163
+ new_word = []
164
+ i = 0
165
+ while i < len(word):
166
+ try:
167
+ j = word.index(first, i)
168
+ except ValueError:
169
+ new_word.extend(word[i:])
170
+ break
171
+ else:
172
+ new_word.extend(word[i:j])
173
+ i = j
174
+
175
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
176
+ new_word.append(first + second)
177
+ i += 2
178
+ else:
179
+ new_word.append(word[i])
180
+ i += 1
181
+ new_word = tuple(new_word)
182
+ word = new_word
183
+ if len(word) == 1:
184
+ break
185
+ else:
186
+ pairs = get_pairs(word)
187
+ word = "@@ ".join(word)
188
+ word = word[:-4]
189
+ self.cache[token] = word
190
+ return word
191
+
192
+ def _tokenize(self, text):
193
+ """Tokenize a string."""
194
+ split_tokens = []
195
+
196
+ words = re.findall(r"\S+\n?", text)
197
+
198
+ for token in words:
199
+ split_tokens.extend(list(self.bpe(token).split(" ")))
200
+ return split_tokens
201
+
202
+ def _convert_token_to_id(self, token):
203
+ """Converts a token (str) in an id using the vocab."""
204
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
205
+
206
+ def _convert_id_to_token(self, index):
207
+ """Converts an index (integer) in a token (str) using the vocab."""
208
+ return self.decoder.get(index, self.unk_token)
209
+
210
+ def convert_tokens_to_string(self, tokens):
211
+ """Converts a sequence of tokens (string) in a single string."""
212
+ out_string = " ".join(tokens).replace("@@ ", "").strip()
213
+ return out_string
214
+
215
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
216
+ if not os.path.isdir(save_directory):
217
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
218
+ return
219
+ vocab_file = os.path.join(
220
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
221
+ )
222
+ merge_file = os.path.join(
223
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
224
+ )
225
+
226
+ with open(vocab_file, "w", encoding="utf-8") as f:
227
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
228
+
229
+ index = 0
230
+ with open(merge_file, "w", encoding="utf-8") as writer:
231
+ writer.write("#version: 0.2\n")
232
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
233
+ if index != token_index:
234
+ logger.warning(
235
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
236
+ " Please check that the tokenizer is not corrupted!"
237
+ )
238
+ index = token_index
239
+ writer.write(" ".join(bpe_tokens) + "\n")
240
+ index += 1
241
+
242
+ return vocab_file, merge_file
243
+
244
+ # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
245
+ # filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
246
+ # tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
247
+ # tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
248
+ # return ''.join(tokens_generated_so_far)
249
+
250
+
251
+ __all__ = ["CTRLTokenizer"]
docs/transformers/build/lib/transformers/models/cvt/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Static type checkers resolve the real submodules directly.
    from .configuration_cvt import *
    from .modeling_cvt import *
    from .modeling_tf_cvt import *
else:
    import sys

    # At runtime, swap this module for a lazy proxy so the heavy submodule
    # imports only run on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/cvt/configuration_cvt.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """CvT model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class CvtConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the CvT
    [microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3]`):
            The kernel size of each encoder's patch embedding.
        patch_stride (`List[int]`, *optional*, defaults to `[4, 2, 2]`):
            The stride size of each encoder's patch embedding.
        patch_padding (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            The padding size of each encoder's patch embedding.
        embed_dim (`List[int]`, *optional*, defaults to `[64, 192, 384]`):
            Dimension of each of the encoder blocks.
        num_heads (`List[int]`, *optional*, defaults to `[1, 3, 6]`):
            Number of attention heads for each attention layer in each block of the Transformer encoder.
        depth (`List[int]`, *optional*, defaults to `[1, 2, 10]`):
            The number of layers in each encoder block.
        mlp_ratio (`List[float]`, *optional*, defaults to `[4.0, 4.0, 4.0]`):
            Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
            encoder blocks.
        attention_drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
            The dropout ratio for the attention probabilities.
        drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
            The dropout ratio for the patch embeddings probabilities.
        drop_path_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.1]`):
            The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
        qkv_bias (`List[bool]`, *optional*, defaults to `[True, True, True]`):
            The bias bool for query, key and value in attentions
        cls_token (`List[bool]`, *optional*, defaults to `[False, False, True]`):
            Whether or not to add a classification token to the output of each of the last 3 stages.
        qkv_projection_method (`List[string]`, *optional*, defaults to `["dw_bn", "dw_bn", "dw_bn"]`):
            The projection method for query, key and value Default is depth-wise convolutions with batch norm. For
            Linear projection use "avg".
        kernel_qkv (`List[int]`, *optional*, defaults to `[3, 3, 3]`):
            The kernel size for query, key and value in attention layer
        padding_kv (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
            The padding size for key and value in attention layer
        stride_kv (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            The stride size for key and value in attention layer
        padding_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
            The padding size for query in attention layer
        stride_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
            The stride size for query in attention layer
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.

    Example:

    ```python
    >>> from transformers import CvtConfig, CvtModel

    >>> # Initializing a Cvt msft/cvt style configuration
    >>> configuration = CvtConfig()

    >>> # Initializing a model (with random weights) from the msft/cvt style configuration
    >>> model = CvtModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "cvt"

    def __init__(
        self,
        num_channels=3,
        patch_sizes=None,
        patch_stride=None,
        patch_padding=None,
        embed_dim=None,
        num_heads=None,
        depth=None,
        mlp_ratio=None,
        attention_drop_rate=None,
        drop_rate=None,
        drop_path_rate=None,
        qkv_bias=None,
        cls_token=None,
        qkv_projection_method=None,
        kernel_qkv=None,
        padding_kv=None,
        stride_kv=None,
        padding_q=None,
        stride_q=None,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The per-stage list arguments previously used mutable list defaults,
        # which are shared across all instances (mutating e.g.
        # `config.patch_sizes` in place would leak into later configs). `None`
        # sentinels with the same effective defaults fix that while staying
        # backward compatible.
        self.num_channels = num_channels
        self.patch_sizes = [7, 3, 3] if patch_sizes is None else patch_sizes
        self.patch_stride = [4, 2, 2] if patch_stride is None else patch_stride
        self.patch_padding = [2, 1, 1] if patch_padding is None else patch_padding
        self.embed_dim = [64, 192, 384] if embed_dim is None else embed_dim
        self.num_heads = [1, 3, 6] if num_heads is None else num_heads
        self.depth = [1, 2, 10] if depth is None else depth
        self.mlp_ratio = [4.0, 4.0, 4.0] if mlp_ratio is None else mlp_ratio
        self.attention_drop_rate = [0.0, 0.0, 0.0] if attention_drop_rate is None else attention_drop_rate
        self.drop_rate = [0.0, 0.0, 0.0] if drop_rate is None else drop_rate
        self.drop_path_rate = [0.0, 0.0, 0.1] if drop_path_rate is None else drop_path_rate
        self.qkv_bias = [True, True, True] if qkv_bias is None else qkv_bias
        self.cls_token = [False, False, True] if cls_token is None else cls_token
        self.qkv_projection_method = (
            ["dw_bn", "dw_bn", "dw_bn"] if qkv_projection_method is None else qkv_projection_method
        )
        self.kernel_qkv = [3, 3, 3] if kernel_qkv is None else kernel_qkv
        self.padding_kv = [1, 1, 1] if padding_kv is None else padding_kv
        self.stride_kv = [2, 2, 2] if stride_kv is None else stride_kv
        self.padding_q = [1, 1, 1] if padding_q is None else padding_q
        self.stride_q = [1, 1, 1] if stride_q is None else stride_q
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps


__all__ = ["CvtConfig"]
docs/transformers/build/lib/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert CvT checkpoints from the original repository.
16
+
17
+ URL: https://github.com/microsoft/CvT"""
18
+
19
+ import argparse
20
+ import json
21
+ from collections import OrderedDict
22
+ from pathlib import Path
23
+
24
+ import torch
25
+ from huggingface_hub import hf_hub_download
26
+
27
+ from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification
28
+
29
+
30
def embeddings(idx):
    """
    Rename map for the patch-embedding layer of one encoder stage.

    Args:
        idx: stage number in original model

    Returns:
        List of `(huggingface_name, original_name)` tuples.
    """
    hf_prefix = f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings"
    orig_prefix = f"stage{idx}.patch_embed"
    return [
        (f"{hf_prefix}.projection.weight", f"{orig_prefix}.proj.weight"),
        (f"{hf_prefix}.projection.bias", f"{orig_prefix}.proj.bias"),
        (f"{hf_prefix}.normalization.weight", f"{orig_prefix}.norm.weight"),
        (f"{hf_prefix}.normalization.bias", f"{orig_prefix}.norm.bias"),
    ]
63
+
64
+
65
def attention(idx, cnt):
    """
    Rename map for one attention block's weights.

    Args:
        idx: stage number in original model
        cnt: count of blocks in each stage

    Returns:
        List of `(huggingface_name, original_name)` tuples, in the same order
        as the original hand-written list.
    """
    hf_layer = f"cvt.encoder.stages.{idx}.layers.{cnt}"
    hf_attn = f"{hf_layer}.attention.attention"
    orig_block = f"stage{idx}.blocks.{cnt}"

    attention_weights = []

    # Depth-wise convolutional projections for query/key/value: one conv weight
    # plus the full set of batch-norm parameters and buffers each.
    for name, short in (("query", "q"), ("key", "k"), ("value", "v")):
        hf_proj = f"{hf_attn}.convolution_projection_{name}.convolution_projection"
        orig_proj = f"{orig_block}.attn.conv_proj_{short}"
        attention_weights.append((f"{hf_proj}.convolution.weight", f"{orig_proj}.conv.weight"))
        for param in ("weight", "bias", "running_mean", "running_var", "num_batches_tracked"):
            attention_weights.append((f"{hf_proj}.normalization.{param}", f"{orig_proj}.bn.{param}"))

    # Linear projections for query/key/value.
    for name, short in (("query", "q"), ("key", "k"), ("value", "v")):
        for param in ("weight", "bias"):
            attention_weights.append(
                (f"{hf_attn}.projection_{name}.{param}", f"{orig_block}.attn.proj_{short}.{param}")
            )

    # Attention output projection, the two MLP linears, and the two layer norms.
    for hf_name, orig_name in (
        ("attention.output.dense", "attn.proj"),
        ("intermediate.dense", "mlp.fc1"),
        ("output.dense", "mlp.fc2"),
        ("layernorm_before", "norm1"),
        ("layernorm_after", "norm2"),
    ):
        for param in ("weight", "bias"):
            attention_weights.append((f"{hf_layer}.{hf_name}.{param}", f"{orig_block}.{orig_name}.{param}"))

    return attention_weights
255
+
256
+
257
def cls_token(idx):
    """
    Rename map for the classification token. Only the last stage carries one in
    the original checkpoint, hence the hard-coded ``stage2`` source name.
    """
    return [(f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token")]
264
+
265
+
266
def final():
    """Rename map for the final layer norm and the classification head."""
    return [
        ("layernorm.weight", "norm.weight"),
        ("layernorm.bias", "norm.bias"),
        ("classifier.weight", "head.weight"),
        ("classifier.bias", "head.bias"),
    ]
276
+
277
+
278
def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
    """
    Convert a Microsoft CvT checkpoint to the Hugging Face format.

    Args:
        cvt_model: name of the CvT variant (e.g. "cvt-13", "cvt-21", "cvt-w24");
            depth/head/embedding sizes are derived from the two digits after "cvt-".
        image_size: input image size stored in the saved image processor.
        cvt_file_name: path to the original ``.pth`` checkpoint.
        pytorch_dump_folder: output directory for the converted model and processor.
    """
    # ImageNet-1k label mapping, fetched from the shared label-files dataset.
    img_labels_file = "imagenet-1k-id2label.json"
    num_labels = 1000
    repo_id = "huggingface/label-files"
    id2label = json.loads(Path(hf_hub_download(repo_id, img_labels_file, repo_type="dataset")).read_text())
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    # Original code had dead self-assignments (`num_labels = num_labels`,
    # `id2label = id2label`) and a double assignment (`config = config = ...`);
    # all removed.
    config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)

    depth_code = cvt_model.rsplit("/", 1)[-1][4:6]
    # For depth size 13 (13 = 1 + 2 + 10)
    if depth_code == "13":
        config.depth = [1, 2, 10]
    # For depth size 21 (21 = 1 + 4 + 16)
    elif depth_code == "21":
        config.depth = [1, 4, 16]
    # For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 + 20)
    else:
        config.depth = [2, 2, 20]
        config.num_heads = [3, 12, 16]
        config.embed_dim = [192, 768, 1024]

    model = CvtForImageClassification(config)
    image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
    image_processor.size["shortest_edge"] = image_size
    # weights_only=True: the checkpoint may come from an untrusted source.
    original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"), weights_only=True)

    # Build the full (huggingface_name, original_name) rename map.
    list_of_state_dict = []
    for idx in range(len(config.depth)):
        if config.cls_token[idx]:
            list_of_state_dict = list_of_state_dict + cls_token(idx)
        list_of_state_dict = list_of_state_dict + embeddings(idx)
        for cnt in range(config.depth[idx]):
            list_of_state_dict = list_of_state_dict + attention(idx, cnt)
    list_of_state_dict = list_of_state_dict + final()

    # Remap the weights; tuple unpacking replaces the original index loop.
    huggingface_weights = OrderedDict()
    for hf_name, original_name in list_of_state_dict:
        print((hf_name, original_name))
        huggingface_weights[hf_name] = original_weights[original_name]

    model.load_state_dict(huggingface_weights)
    model.save_pretrained(pytorch_dump_folder)
    image_processor.save_pretrained(pytorch_dump_folder)
333
+
334
+
335
# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cvt_model",
        default="cvt-w24",
        type=str,
        help="Name of the cvt model you'd like to convert.",
    )
    parser.add_argument(
        "--image_size",
        default=384,
        type=int,
        help="Input Image Size",
    )
    parser.add_argument(
        "--cvt_file_name",
        default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth",
        type=str,
        # Fixed: help text was a copy-paste of --image_size's ("Input Image Size").
        help="Path to the original CvT checkpoint (.pth) file.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)
docs/transformers/build/lib/transformers/models/cvt/modeling_cvt.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CvT model."""
16
+
17
+ import collections.abc
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
27
+ from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
28
+ from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
29
+ from ...utils import logging
30
+ from .configuration_cvt import CvtConfig
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ # General docstring
36
+ _CONFIG_FOR_DOC = "CvtConfig"
37
+
38
+ # Base docstring
39
+ _CHECKPOINT_FOR_DOC = "microsoft/cvt-13"
40
+ _EXPECTED_OUTPUT_SHAPE = [1, 384, 14, 14]
41
+
42
+ # Image classification docstring
43
+ _IMAGE_CLASS_CHECKPOINT = "microsoft/cvt-13"
44
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
45
+
46
+
47
@dataclass
class BaseModelOutputWithCLSToken(ModelOutput):
    """
    Output type carrying the final hidden states, the classification token, and
    (optionally) the per-layer hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
            Classification token at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cls_token_value: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
66
+
67
+
68
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    # Identity at inference time or when the drop probability is zero.
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    # (works with any tensor rank, not just 2D ConvNets).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    keep_mask = keep_prob + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    keep_mask.floor_()  # binarize to 0/1
    # Scale surviving samples up so the expected value is unchanged.
    return input.div(keep_prob) * keep_mask
87
+
88
+
89
# Equivalent to transformers.models.beit.modeling_beit.BeitDropPath
class CvtDropPath(nn.Module):
    """Module wrapper around `drop_path` (stochastic depth per sample in residual main paths)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegates to the functional form; active only while the module is in training mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
102
+
103
+
104
class CvtEmbeddings(nn.Module):
    """
    CvT embedding stage: convolutional token embedding followed by dropout.
    """

    def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate):
        super().__init__()
        # Conv + layernorm patch projection that maps the image (or previous feature map) to tokens.
        self.convolution_embeddings = CvtConvEmbeddings(
            patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, pixel_values):
        return self.dropout(self.convolution_embeddings(pixel_values))
120
+
121
+
122
class CvtConvEmbeddings(nn.Module):
    """
    Image to Conv Embedding: projects a `b c h w` input with a strided convolution and layer-normalizes
    the resulting tokens over the channel axis, returning a `b c h w` feature map.
    """

    def __init__(self, patch_size, num_channels, embed_dim, stride, padding):
        super().__init__()
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
        self.normalization = nn.LayerNorm(embed_dim)

    def forward(self, pixel_values):
        embeddings = self.projection(pixel_values)
        batch_size, num_channels, height, width = embeddings.shape
        # Flatten the grid so LayerNorm sees channels last: b c h w -> b (h w) c
        embeddings = embeddings.flatten(2).transpose(1, 2)
        if self.normalization:
            embeddings = self.normalization(embeddings)
        # Restore the spatial layout: b (h w) c -> b c h w
        embeddings = embeddings.transpose(1, 2).reshape(batch_size, num_channels, height, width)
        return embeddings
145
+
146
+
147
class CvtSelfAttentionConvProjection(nn.Module):
    """Depthwise convolution + batch norm used to project the token grid before attention."""

    def __init__(self, embed_dim, kernel_size, padding, stride):
        super().__init__()
        # groups=embed_dim makes the convolution depthwise (one filter per channel), bias folded into BN.
        self.convolution = nn.Conv2d(
            embed_dim,
            embed_dim,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=False,
            groups=embed_dim,
        )
        self.normalization = nn.BatchNorm2d(embed_dim)

    def forward(self, hidden_state):
        return self.normalization(self.convolution(hidden_state))
165
+
166
+
167
class CvtSelfAttentionLinearProjection(nn.Module):
    """Flattens a `b c h w` feature map into the `b (h w) c` token layout expected by attention."""

    def forward(self, hidden_state):
        # b c h w -> b (h w) c
        return hidden_state.flatten(2).transpose(1, 2)
174
+
175
+
176
class CvtSelfAttentionProjection(nn.Module):
    """
    Projects a `b c h w` feature map into the token sequence used by attention: a (depthwise conv + BN)
    spatial projection followed by flattening to `b (h w) c`.
    """

    def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"):
        super().__init__()
        # NOTE(review): only "dw_bn" assigns `self.convolution_projection`; any other value (e.g. the
        # "linear" method passed for the query path when `qkv_projection_method == "avg"`) leaves the
        # attribute undefined, so `forward` would raise AttributeError — confirm configs never hit that path.
        if projection_method == "dw_bn":
            self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride)
        self.linear_projection = CvtSelfAttentionLinearProjection()

    def forward(self, hidden_state):
        hidden_state = self.convolution_projection(hidden_state)
        hidden_state = self.linear_projection(hidden_state)
        return hidden_state
187
+
188
+
189
class CvtSelfAttention(nn.Module):
    """
    Convolutional multi-head self-attention: queries, keys and values are produced by convolutional
    projections over the 2-D token grid (strides can shrink the key/value sequence), then standard
    scaled dot-product attention is applied across heads. The optional cls token bypasses the
    convolutional projections and is re-attached before the linear Q/K/V projections.
    """

    def __init__(
        self,
        num_heads,
        embed_dim,
        kernel_size,
        padding_q,
        padding_kv,
        stride_q,
        stride_kv,
        qkv_projection_method,
        qkv_bias,
        attention_drop_rate,
        with_cls_token=True,
        **kwargs,
    ):
        super().__init__()
        # NOTE(review): scores are scaled by embed_dim**-0.5 (the full embedding width), not the per-head
        # dimension — presumably intentional to match the reference CvT implementation; confirm.
        self.scale = embed_dim**-0.5
        self.with_cls_token = with_cls_token
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # The query path falls back to a plain linear projection when "avg" is requested;
        # keys/values keep the configured method.
        self.convolution_projection_query = CvtSelfAttentionProjection(
            embed_dim,
            kernel_size,
            padding_q,
            stride_q,
            projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
        )
        self.convolution_projection_key = CvtSelfAttentionProjection(
            embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
        )
        self.convolution_projection_value = CvtSelfAttentionProjection(
            embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
        )

        # Learned linear Q/K/V projections applied after the convolutional token projections.
        self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)

        self.dropout = nn.Dropout(attention_drop_rate)

    def rearrange_for_multi_head_attention(self, hidden_state):
        # Split the channel axis into heads: "b t (h d) -> b h t d".
        # (`hidden_size` here is the sequence length, i.e. the number of tokens.)
        batch_size, hidden_size, _ = hidden_state.shape
        head_dim = self.embed_dim // self.num_heads
        return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3)

    def forward(self, hidden_state, height, width):
        if self.with_cls_token:
            # Detach the cls token so only the height*width spatial tokens go through the conv projections.
            cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
        batch_size, hidden_size, num_channels = hidden_state.shape
        # "b (h w) c -> b c h w": restore the 2-D grid for the convolutional projections.
        hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)

        key = self.convolution_projection_key(hidden_state)
        query = self.convolution_projection_query(hidden_state)
        value = self.convolution_projection_value(hidden_state)

        if self.with_cls_token:
            # Re-attach the cls token at the front of each projected sequence.
            query = torch.cat((cls_token, query), dim=1)
            key = torch.cat((cls_token, key), dim=1)
            value = torch.cat((cls_token, value), dim=1)

        head_dim = self.embed_dim // self.num_heads

        query = self.rearrange_for_multi_head_attention(self.projection_query(query))
        key = self.rearrange_for_multi_head_attention(self.projection_key(key))
        value = self.rearrange_for_multi_head_attention(self.projection_value(value))

        # Scaled dot-product attention; l = query length, t = key/value length (may differ via stride_kv).
        attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
        attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
        # Merge the heads back: "b h t d -> b t (h d)".
        _, _, hidden_size, _ = context.shape
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim)
        return context
268
+
269
+
270
class CvtSelfOutput(nn.Module):
    """
    Output projection applied to the attention context. The residual connection is defined in CvtLayer
    instead of here (as is the case with other models), due to the layernorm applied before each block.
    """

    def __init__(self, embed_dim, drop_rate):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, hidden_state, input_tensor):
        # `input_tensor` is accepted for interface parity with other models but intentionally unused.
        return self.dropout(self.dense(hidden_state))
285
+
286
+
287
class CvtAttention(nn.Module):
    """Attention sub-block: convolutional self-attention followed by the output projection."""

    def __init__(
        self,
        num_heads,
        embed_dim,
        kernel_size,
        padding_q,
        padding_kv,
        stride_q,
        stride_kv,
        qkv_projection_method,
        qkv_bias,
        attention_drop_rate,
        drop_rate,
        with_cls_token=True,
    ):
        super().__init__()
        self.attention = CvtSelfAttention(
            num_heads,
            embed_dim,
            kernel_size,
            padding_q,
            padding_kv,
            stride_q,
            stride_kv,
            qkv_projection_method,
            qkv_bias,
            attention_drop_rate,
            with_cls_token,
        )
        self.output = CvtSelfOutput(embed_dim, drop_rate)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        # NOTE(review): this references `self.attention.num_attention_heads`, `.attention_head_size` and
        # `.query/.key/.value`, none of which exist on `CvtSelfAttention` (which defines `num_heads` and
        # `projection_query/key/value`) — calling it with a non-empty `heads` would raise AttributeError.
        # Presumably copied from a BERT-style template; confirm it is effectively dead code.
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_state, height, width):
        # Self-attention over the token grid, then the linear output projection (no residual here —
        # CvtLayer owns the residual connections).
        self_output = self.attention(hidden_state, height, width)
        attention_output = self.output(self_output, hidden_state)
        return attention_output
342
+
343
+
344
class CvtIntermediate(nn.Module):
    """Feed-forward expansion: linear up-projection by `mlp_ratio` followed by GELU."""

    def __init__(self, embed_dim, mlp_ratio):
        super().__init__()
        self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))
        self.activation = nn.GELU()

    def forward(self, hidden_state):
        return self.activation(self.dense(hidden_state))
354
+
355
+
356
class CvtOutput(nn.Module):
    """Feed-forward down-projection with dropout, adding back the residual `input_tensor`."""

    def __init__(self, embed_dim, mlp_ratio, drop_rate):
        super().__init__()
        self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, hidden_state, input_tensor):
        projected = self.dropout(self.dense(hidden_state))
        # Second residual connection of the transformer block.
        return projected + input_tensor
367
+
368
+
369
class CvtLayer(nn.Module):
    """
    CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).

    Pre-layernorm block: LN -> conv attention -> drop-path -> residual, then LN -> MLP -> residual ->
    drop-path.
    """

    def __init__(
        self,
        num_heads,
        embed_dim,
        kernel_size,
        padding_q,
        padding_kv,
        stride_q,
        stride_kv,
        qkv_projection_method,
        qkv_bias,
        attention_drop_rate,
        drop_rate,
        mlp_ratio,
        drop_path_rate,
        with_cls_token=True,
    ):
        super().__init__()
        self.attention = CvtAttention(
            num_heads,
            embed_dim,
            kernel_size,
            padding_q,
            padding_kv,
            stride_q,
            stride_kv,
            qkv_projection_method,
            qkv_bias,
            attention_drop_rate,
            drop_rate,
            with_cls_token,
        )

        self.intermediate = CvtIntermediate(embed_dim, mlp_ratio)
        self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate)
        # Stochastic depth; identity when the rate is 0 so eval/traced graphs stay clean.
        self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_before = nn.LayerNorm(embed_dim)
        self.layernorm_after = nn.LayerNorm(embed_dim)

    def forward(self, hidden_state, height, width):
        self_attention_output = self.attention(
            self.layernorm_before(hidden_state),  # in Cvt, layernorm is applied before self-attention
            height,
            width,
        )
        attention_output = self_attention_output
        attention_output = self.drop_path(attention_output)

        # first residual connection
        hidden_state = attention_output + hidden_state

        # in Cvt, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_state)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here (inside CvtOutput, which adds `hidden_state` back)
        layer_output = self.output(layer_output, hidden_state)
        # NOTE(review): drop-path is applied after the residual sum here, so the skip connection itself
        # can be dropped — unusual ordering, but it is what the code does; confirm against the reference.
        layer_output = self.drop_path(layer_output)
        return layer_output
433
+
434
+
435
class CvtStage(nn.Module):
    """
    One stage of the CvT backbone: a convolutional token embedding (which downsamples the grid) followed
    by a stack of `CvtLayer` blocks, with an optional learned cls token.
    """

    def __init__(self, config, stage):
        super().__init__()
        self.config = config
        self.stage = stage  # index of this stage within the backbone
        if self.config.cls_token[self.stage]:
            # NOTE(review): sized with embed_dim[-1] (the final stage's width) rather than this stage's.
            # That only works if the cls token is enabled solely for the last stage — confirm configs.
            self.cls_token = nn.Parameter(torch.randn(1, 1, self.config.embed_dim[-1]))

        self.embedding = CvtEmbeddings(
            patch_size=config.patch_sizes[self.stage],
            stride=config.patch_stride[self.stage],
            num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
            embed_dim=config.embed_dim[self.stage],
            padding=config.patch_padding[self.stage],
            dropout_rate=config.drop_rate[self.stage],
        )

        # Linearly increasing stochastic-depth schedule across this stage's depth.
        drop_path_rates = [
            x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage], device="cpu")
        ]

        self.layers = nn.Sequential(
            *[
                CvtLayer(
                    num_heads=config.num_heads[self.stage],
                    embed_dim=config.embed_dim[self.stage],
                    kernel_size=config.kernel_qkv[self.stage],
                    padding_q=config.padding_q[self.stage],
                    padding_kv=config.padding_kv[self.stage],
                    stride_kv=config.stride_kv[self.stage],
                    stride_q=config.stride_q[self.stage],
                    qkv_projection_method=config.qkv_projection_method[self.stage],
                    qkv_bias=config.qkv_bias[self.stage],
                    attention_drop_rate=config.attention_drop_rate[self.stage],
                    drop_rate=config.drop_rate[self.stage],
                    # NOTE(review): indexed by stage, not by layer — every layer in the stage gets the
                    # same rate and most of `drop_path_rates` is unused; confirm against the reference.
                    drop_path_rate=drop_path_rates[self.stage],
                    mlp_ratio=config.mlp_ratio[self.stage],
                    with_cls_token=config.cls_token[self.stage],
                )
                for _ in range(config.depth[self.stage])
            ]
        )

    def forward(self, hidden_state):
        cls_token = None
        hidden_state = self.embedding(hidden_state)
        batch_size, num_channels, height, width = hidden_state.shape
        # b c h w -> b (h w) c: flatten the grid into a token sequence for the transformer layers
        hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1)
        if self.config.cls_token[self.stage]:
            # Prepend a per-sample copy of the learned cls token.
            cls_token = self.cls_token.expand(batch_size, -1, -1)
            hidden_state = torch.cat((cls_token, hidden_state), dim=1)

        for layer in self.layers:
            layer_outputs = layer(hidden_state, height, width)
            hidden_state = layer_outputs

        if self.config.cls_token[self.stage]:
            # Split the (updated) cls token back off before restoring the spatial layout.
            cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
        hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
        return hidden_state, cls_token
496
+
497
+
498
class CvtEncoder(nn.Module):
    """
    The full CvT backbone: runs the input through every `CvtStage` in sequence.

    Each stage downsamples the spatial grid and returns `(feature_map, cls_token)`; only the last
    stage's cls token (or `None` when that stage has none) is kept.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One CvtStage per entry in config.depth, built directly instead of append-in-a-loop.
        self.stages = nn.ModuleList([CvtStage(config, stage_idx) for stage_idx in range(len(config.depth))])

    def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
        """
        Args:
            pixel_values: input image tensor of shape `(batch, channels, height, width)`.
            output_hidden_states: when True, collect every stage's output feature map.
            return_dict: when True, wrap the result in `BaseModelOutputWithCLSToken`.
        """
        all_hidden_states = () if output_hidden_states else None
        hidden_state = pixel_values

        cls_token = None
        # Plain iteration — the original `enumerate` discarded the index.
        for stage_module in self.stages:
            hidden_state, cls_token = stage_module(hidden_state)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            # Tuple form drops None entries, matching the dict output's optional fields.
            return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)

        return BaseModelOutputWithCLSToken(
            last_hidden_state=hidden_state,
            cls_token_value=cls_token,
            hidden_states=all_hidden_states,
        )
524
+
525
+
526
class CvtPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CvtConfig
    base_model_prefix = "cvt"
    main_input_name = "pixel_values"
    # CvtLayer must stay on a single device when the model is sharded.
    _no_split_modules = ["CvtLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Truncated-normal init for dense/conv weights; biases start at zero.
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, CvtStage):
            # The cls token parameter lives on the stage, not on a Linear/LayerNorm, so it is
            # re-initialized here with the same truncated-normal scheme.
            if self.config.cls_token[module.stage]:
                module.cls_token.data = nn.init.trunc_normal_(
                    module.cls_token.data, mean=0.0, std=self.config.initializer_range
                )
551
+
552
+
553
# Shared docstring fragments injected into the model classes below via the `add_start_docstrings*`
# decorators: CVT_START_DOCSTRING documents construction, CVT_INPUTS_DOCSTRING documents `forward` args.
CVT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CVT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
            for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
575
+
576
+
577
@add_start_docstrings(
    "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
    CVT_START_DOCSTRING,
)
class CvtModel(CvtPreTrainedModel):
    # Bare backbone: just the CvtEncoder, no task head.
    def __init__(self, config, add_pooling_layer=True):
        # NOTE(review): `add_pooling_layer` is accepted but never used — presumably kept for signature
        # parity with other vision models; confirm before removing.
        super().__init__(config)
        self.config = config
        self.encoder = CvtEncoder(config)
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # NOTE(review): `CvtEncoder` defines `self.stages`, not `self.encoder.layer` — this loop would
        # raise AttributeError if ever called with a non-empty dict; confirm it is dead code.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithCLSToken,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithCLSToken]:
        # Fall back to config defaults when the flags are not given explicitly.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithCLSToken(
            last_hidden_state=sequence_output,
            cls_token_value=encoder_outputs.cls_token_value,
            hidden_states=encoder_outputs.hidden_states,
        )
633
+
634
+
635
@add_start_docstrings(
    """
    Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    CVT_START_DOCSTRING,
)
class CvtForImageClassification(CvtPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.cvt = CvtModel(config, add_pooling_layer=False)
        # Norm over the final stage's width, applied to either the cls token or the pooled tokens.
        self.layernorm = nn.LayerNorm(config.embed_dim[-1])
        # Classifier head
        self.classifier = (
            nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.cvt(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        cls_token = outputs[1]
        if self.config.cls_token[-1]:
            # Classify from the cls token when the last stage has one...
            sequence_output = self.layernorm(cls_token)
        else:
            # ...otherwise flatten the final feature map and mean-pool over all tokens below.
            batch_size, num_channels, height, width = sequence_output.shape
            # rearrange "b c h w -> b (h w) c"
            sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1)
            sequence_output = self.layernorm(sequence_output)

        # For the cls-token path this averages over a single token (a no-op); for the pooled path it
        # averages over all spatial tokens.
        sequence_output_mean = sequence_output.mean(dim=1)
        logits = self.classifier(sequence_output_mean)

        loss = None
        if labels is not None:
            # Standard HF problem-type dispatch: infer regression vs single- vs multi-label once,
            # then pick the matching loss.
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
725
+
726
+
727
+ __all__ = ["CvtForImageClassification", "CvtModel", "CvtPreTrainedModel"]
docs/transformers/build/lib/transformers/models/cvt/modeling_tf_cvt.py ADDED
@@ -0,0 +1,1096 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 Cvt model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import collections.abc
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import tensorflow as tf
24
+
25
+ from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention
26
+ from ...modeling_tf_utils import (
27
+ TFModelInputType,
28
+ TFPreTrainedModel,
29
+ TFSequenceClassificationLoss,
30
+ get_initializer,
31
+ keras,
32
+ keras_serializable,
33
+ unpack_inputs,
34
+ )
35
+ from ...tf_utils import shape_list, stable_softmax
36
+ from ...utils import (
37
+ ModelOutput,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
+ from .configuration_cvt import CvtConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ # General docstring
49
+ _CONFIG_FOR_DOC = "CvtConfig"
50
+
51
+
52
@dataclass
class TFBaseModelOutputWithCLSToken(ModelOutput):
    """
    Base class for model's outputs.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`):
            Classification token at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
    """

    # Final feature map of the last stage.
    last_hidden_state: Optional[tf.Tensor] = None
    # Cls token of the last stage, or None when the last stage has no cls token.
    cls_token_value: Optional[tf.Tensor] = None
    # One entry per stage when output_hidden_states is requested.
    hidden_states: Tuple[tf.Tensor, ...] | None = None
+ hidden_states: Tuple[tf.Tensor, ...] | None = None
71
+
72
+
73
class TFCvtDropPath(keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x: tf.Tensor, training=None):
        # Identity at inference time or when the rate is zero.
        if self.drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
        random_tensor = tf.floor(random_tensor)
        # Rescale survivors by 1/keep_prob so the expected value is unchanged.
        return (x / keep_prob) * random_tensor
+ return (x / keep_prob) * random_tensor
91
+
92
+
93
+ class TFCvtEmbeddings(keras.layers.Layer):
94
+ """Construct the Convolutional Token Embeddings."""
95
+
96
+ def __init__(
97
+ self,
98
+ config: CvtConfig,
99
+ patch_size: int,
100
+ num_channels: int,
101
+ embed_dim: int,
102
+ stride: int,
103
+ padding: int,
104
+ dropout_rate: float,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+ self.convolution_embeddings = TFCvtConvEmbeddings(
109
+ config,
110
+ patch_size=patch_size,
111
+ num_channels=num_channels,
112
+ embed_dim=embed_dim,
113
+ stride=stride,
114
+ padding=padding,
115
+ name="convolution_embeddings",
116
+ )
117
+ self.dropout = keras.layers.Dropout(dropout_rate)
118
+
119
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
120
+ hidden_state = self.convolution_embeddings(pixel_values)
121
+ hidden_state = self.dropout(hidden_state, training=training)
122
+ return hidden_state
123
+
124
+ def build(self, input_shape=None):
125
+ if self.built:
126
+ return
127
+ self.built = True
128
+ if getattr(self, "convolution_embeddings", None) is not None:
129
+ with tf.name_scope(self.convolution_embeddings.name):
130
+ self.convolution_embeddings.build(None)
131
+
132
+
133
class TFCvtConvEmbeddings(keras.layers.Layer):
    """Image to Convolution Embeddings. This convolutional operation aims to model local spatial contexts."""

    def __init__(
        self,
        config: CvtConfig,
        patch_size: int,
        num_channels: int,
        embed_dim: int,
        stride: int,
        padding: int,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Explicit zero-padding layer + "valid" conv reproduces the PyTorch Conv2d padding behavior.
        self.padding = keras.layers.ZeroPadding2D(padding=padding)
        self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.projection = keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=stride,
            padding="valid",
            data_format="channels_last",
            kernel_initializer=get_initializer(config.initializer_range),
            name="projection",
        )
        # Using the same default epsilon as PyTorch
        self.normalization = keras.layers.LayerNormalization(epsilon=1e-5, name="normalization")
        self.num_channels = num_channels
        self.embed_dim = embed_dim

    def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
        if isinstance(pixel_values, dict):
            pixel_values = pixel_values["pixel_values"]

        pixel_values = self.projection(self.padding(pixel_values))

        # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
        batch_size, height, width, num_channels = shape_list(pixel_values)
        hidden_size = height * width
        pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels))
        pixel_values = self.normalization(pixel_values)

        # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
        pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels))
        return pixel_values

    def build(self, input_shape=None):
        # Build sub-layers with explicit shapes so weights exist before any call (needed for
        # cross-framework weight loading).
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
        if getattr(self, "normalization", None) is not None:
            with tf.name_scope(self.normalization.name):
                self.normalization.build([None, None, self.embed_dim])
+ self.normalization.build([None, None, self.embed_dim])
189
+
190
+
191
class TFCvtSelfAttentionConvProjection(keras.layers.Layer):
    """Convolutional projection layer.

    A depth-wise convolution (groups == filters == embed_dim) followed by batch normalization,
    used to project query/key/value feature maps inside the attention block.
    """

    def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs):
        super().__init__(**kwargs)
        self.padding = keras.layers.ZeroPadding2D(padding=padding)
        self.convolution = keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=kernel_size,
            kernel_initializer=get_initializer(config.initializer_range),
            padding="valid",
            strides=stride,
            use_bias=False,
            name="convolution",
            # groups == embed_dim makes this a depth-wise convolution.
            groups=embed_dim,
        )
        # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
        self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
        self.embed_dim = embed_dim

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply pad -> depth-wise conv -> batch norm; `training` controls the batch-norm statistics."""
        hidden_state = self.convolution(self.padding(hidden_state))
        hidden_state = self.normalization(hidden_state, training=training)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "convolution", None) is not None:
            with tf.name_scope(self.convolution.name):
                self.convolution.build([None, None, None, self.embed_dim])
        if getattr(self, "normalization", None) is not None:
            with tf.name_scope(self.normalization.name):
                self.normalization.build([None, None, None, self.embed_dim])
226
+
227
+
228
class TFCvtSelfAttentionLinearProjection(keras.layers.Layer):
    """Flattens the spatial token grid into a single sequence dimension (tokens become 1D)."""

    def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
        # (batch_size, height, width, num_channels) -> (batch_size, height * width, num_channels)
        batch, height, width, channels = shape_list(hidden_state)
        return tf.reshape(hidden_state, shape=(batch, height * width, channels))
237
+
238
+
239
class TFCvtSelfAttentionProjection(keras.layers.Layer):
    """Convolutional Projection for Attention.

    Applies the convolutional projection (only created when `projection_method == "dw_bn"`)
    and then flattens the spatial grid into a token sequence.
    """

    def __init__(
        self,
        config: CvtConfig,
        embed_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        projection_method: str = "dw_bn",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE(review): `convolution_projection` is only created for "dw_bn"; `call` assumes it
        # exists, so other projection methods would raise AttributeError here — confirm callers
        # only pass "dw_bn" by the time this layer is used.
        if projection_method == "dw_bn":
            self.convolution_projection = TFCvtSelfAttentionConvProjection(
                config, embed_dim, kernel_size, stride, padding, name="convolution_projection"
            )
        self.linear_projection = TFCvtSelfAttentionLinearProjection()

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project the feature map and flatten it into (batch_size, seq_len, channels)."""
        hidden_state = self.convolution_projection(hidden_state, training=training)
        hidden_state = self.linear_projection(hidden_state)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "convolution_projection", None) is not None:
            with tf.name_scope(self.convolution_projection.name):
                self.convolution_projection.build(None)
271
+
272
+
273
class TFCvtSelfAttention(keras.layers.Layer):
    """
    Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), is applied for
    query, key, and value embeddings.

    When `with_cls_token` is set, the leading classification token is split off before the
    convolutional projections (which need a 2D spatial layout) and re-attached afterwards.
    """

    def __init__(
        self,
        config: CvtConfig,
        num_heads: int,
        embed_dim: int,
        kernel_size: int,
        stride_q: int,
        stride_kv: int,
        padding_q: int,
        padding_kv: int,
        qkv_projection_method: str,
        qkv_bias: bool,
        attention_drop_rate: float,
        with_cls_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE(review): the score scaling uses the full embedding dimension, not the per-head
        # dimension — this mirrors the original implementation; confirm against the PyTorch model.
        self.scale = embed_dim**-0.5
        self.with_cls_token = with_cls_token
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.convolution_projection_query = TFCvtSelfAttentionProjection(
            config,
            embed_dim,
            kernel_size,
            stride_q,
            padding_q,
            # "avg" is mapped to "linear" for the query projection only.
            projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
            name="convolution_projection_query",
        )
        self.convolution_projection_key = TFCvtSelfAttentionProjection(
            config,
            embed_dim,
            kernel_size,
            stride_kv,
            padding_kv,
            projection_method=qkv_projection_method,
            name="convolution_projection_key",
        )
        self.convolution_projection_value = TFCvtSelfAttentionProjection(
            config,
            embed_dim,
            kernel_size,
            stride_kv,
            padding_kv,
            projection_method=qkv_projection_method,
            name="convolution_projection_value",
        )

        self.projection_query = keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=qkv_bias,
            bias_initializer="zeros",
            name="projection_query",
        )
        self.projection_key = keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=qkv_bias,
            bias_initializer="zeros",
            name="projection_key",
        )
        self.projection_value = keras.layers.Dense(
            units=embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            use_bias=qkv_bias,
            bias_initializer="zeros",
            name="projection_value",
        )
        self.dropout = keras.layers.Dropout(attention_drop_rate)

    def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor:
        """Reshape (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)."""
        batch_size, hidden_size, _ = shape_list(hidden_state)
        head_dim = self.embed_dim // self.num_heads
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim))
        hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3))
        return hidden_state

    def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
        """Run convolutionally-projected multi-head self-attention over the token sequence.

        `height`/`width` give the spatial layout of the (non-cls) tokens so they can be
        reshaped back into a 2D grid for the convolutional projections.
        """
        if self.with_cls_token:
            # Split the cls token off: conv projections only operate on the spatial tokens.
            cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)

        # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
        batch_size, hidden_size, num_channels = shape_list(hidden_state)
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))

        key = self.convolution_projection_key(hidden_state, training=training)
        query = self.convolution_projection_query(hidden_state, training=training)
        value = self.convolution_projection_value(hidden_state, training=training)

        if self.with_cls_token:
            # Re-attach the cls token in front of each projected sequence.
            query = tf.concat((cls_token, query), axis=1)
            key = tf.concat((cls_token, key), axis=1)
            value = tf.concat((cls_token, value), axis=1)

        head_dim = self.embed_dim // self.num_heads

        query = self.rearrange_for_multi_head_attention(self.projection_query(query))
        key = self.rearrange_for_multi_head_attention(self.projection_key(key))
        value = self.rearrange_for_multi_head_attention(self.projection_value(value))

        attention_score = tf.matmul(query, key, transpose_b=True) * self.scale
        attention_probs = stable_softmax(logits=attention_score, axis=-1)
        attention_probs = self.dropout(attention_probs, training=training)

        context = tf.matmul(attention_probs, value)
        # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)"
        _, _, hidden_size, _ = shape_list(context)
        context = tf.transpose(context, perm=(0, 2, 1, 3))
        context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim))
        return context

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "convolution_projection_query", None) is not None:
            with tf.name_scope(self.convolution_projection_query.name):
                self.convolution_projection_query.build(None)
        if getattr(self, "convolution_projection_key", None) is not None:
            with tf.name_scope(self.convolution_projection_key.name):
                self.convolution_projection_key.build(None)
        if getattr(self, "convolution_projection_value", None) is not None:
            with tf.name_scope(self.convolution_projection_value.name):
                self.convolution_projection_value.build(None)
        if getattr(self, "projection_query", None) is not None:
            with tf.name_scope(self.projection_query.name):
                self.projection_query.build([None, None, self.embed_dim])
        if getattr(self, "projection_key", None) is not None:
            with tf.name_scope(self.projection_key.name):
                self.projection_key.build([None, None, self.embed_dim])
        if getattr(self, "projection_value", None) is not None:
            with tf.name_scope(self.projection_value.name):
                self.projection_value.build([None, None, self.embed_dim])
414
+ self.projection_value.build([None, None, self.embed_dim])
415
+
416
+
417
class TFCvtSelfOutput(keras.layers.Layer):
    """Output of the Attention layer: a dense projection back to `embed_dim` followed by dropout."""

    def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = keras.layers.Dropout(drop_rate)
        self.embed_dim = embed_dim

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project the attention context and apply dropout (active only when `training`)."""
        hidden_state = self.dense(inputs=hidden_state)
        hidden_state = self.dropout(inputs=hidden_state, training=training)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.embed_dim])
440
+
441
+
442
class TFCvtAttention(keras.layers.Layer):
    """Attention layer. First chunk of the convolutional transformer block.

    Combines `TFCvtSelfAttention` with the `TFCvtSelfOutput` projection.
    """

    def __init__(
        self,
        config: CvtConfig,
        num_heads: int,
        embed_dim: int,
        kernel_size: int,
        stride_q: int,
        stride_kv: int,
        padding_q: int,
        padding_kv: int,
        qkv_projection_method: str,
        qkv_bias: bool,
        attention_drop_rate: float,
        drop_rate: float,
        with_cls_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attention = TFCvtSelfAttention(
            config,
            num_heads,
            embed_dim,
            kernel_size,
            stride_q,
            stride_kv,
            padding_q,
            padding_kv,
            qkv_projection_method,
            qkv_bias,
            attention_drop_rate,
            with_cls_token,
            name="attention",
        )
        # Keras name "output" is kept for checkpoint compatibility, while the attribute avoids
        # clashing with `keras.layers.Layer.output`.
        self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output")

    def prune_heads(self, heads):
        # Head pruning is not supported for this model.
        raise NotImplementedError

    def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False):
        """Run self-attention over the tokens and project the result back to `embed_dim`."""
        self_output = self.attention(hidden_state, height, width, training=training)
        attention_output = self.dense_output(self_output, training=training)
        return attention_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)
498
+
499
+
500
class TFCvtIntermediate(keras.layers.Layer):
    """Intermediate dense layer. Second chunk of the convolutional transformer block.

    Expands the hidden size by `mlp_ratio` with a GELU-activated dense layer.
    """

    def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=int(embed_dim * mlp_ratio),
            kernel_initializer=get_initializer(config.initializer_range),
            activation="gelu",
            name="dense",
        )
        self.embed_dim = embed_dim

    def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
        """Apply the GELU-activated expansion projection."""
        hidden_state = self.dense(hidden_state)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.embed_dim])
524
+
525
+
526
class TFCvtOutput(keras.layers.Layer):
    """
    Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection.

    Projects the expanded intermediate representation back to `embed_dim`, applies dropout,
    and adds the residual `input_tensor`.
    """

    def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, drop_rate: int, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = keras.layers.Dropout(drop_rate)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio

    def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project, dropout, then add the residual `input_tensor`."""
        hidden_state = self.dense(inputs=hidden_state)
        hidden_state = self.dropout(inputs=hidden_state, training=training)
        hidden_state = hidden_state + input_tensor
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # Input width is the expanded intermediate size.
                self.dense.build([None, None, int(self.embed_dim * self.mlp_ratio)])
553
+
554
+
555
class TFCvtLayer(keras.layers.Layer):
    """
    Convolutional Transformer Block composed by attention layers, normalization and multi-layer perceptrons (mlps). It
    consists of 3 chunks : an attention layer, an intermediate dense layer and an output layer. This corresponds to the
    `Block` class in the original implementation.

    Pre-norm layout: layernorm -> attention -> drop-path -> residual, then
    layernorm -> MLP -> residual -> drop-path.
    """

    def __init__(
        self,
        config: CvtConfig,
        num_heads: int,
        embed_dim: int,
        kernel_size: int,
        stride_q: int,
        stride_kv: int,
        padding_q: int,
        padding_kv: int,
        qkv_projection_method: str,
        qkv_bias: bool,
        attention_drop_rate: float,
        drop_rate: float,
        mlp_ratio: float,
        drop_path_rate: float,
        with_cls_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attention = TFCvtAttention(
            config,
            num_heads,
            embed_dim,
            kernel_size,
            stride_q,
            stride_kv,
            padding_q,
            padding_kv,
            qkv_projection_method,
            qkv_bias,
            attention_drop_rate,
            drop_rate,
            with_cls_token,
            name="attention",
        )
        self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate")
        self.dense_output = TFCvtOutput(config, embed_dim, mlp_ratio, drop_rate, name="output")
        # Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour.
        self.drop_path = (
            TFCvtDropPath(drop_path_rate, name="drop_path")
            if drop_path_rate > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )
        # Using the same default epsilon as PyTorch
        self.layernorm_before = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before")
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after")
        self.embed_dim = embed_dim

    def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
        """Run one pre-norm transformer block over the token sequence.

        `height`/`width` describe the spatial layout of the non-cls tokens and are forwarded to
        the attention layer's convolutional projections.
        """
        # in Cvt, layernorm is applied before self-attention
        attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training)
        attention_output = self.drop_path(attention_output, training=training)

        # first residual connection
        hidden_state = attention_output + hidden_state

        # in Cvt, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_state)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.dense_output(layer_output, hidden_state)
        layer_output = self.drop_path(layer_output, training=training)
        return layer_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)
        if getattr(self, "layernorm_before", None) is not None:
            with tf.name_scope(self.layernorm_before.name):
                self.layernorm_before.build([None, None, self.embed_dim])
        if getattr(self, "layernorm_after", None) is not None:
            with tf.name_scope(self.layernorm_after.name):
                self.layernorm_after.build([None, None, self.embed_dim])
650
+
651
+
652
class TFCvtStage(keras.layers.Layer):
    """
    Cvt stage (encoder block). Each stage has 2 parts :
    - (1) A Convolutional Token Embedding layer
    - (2) A Convolutional Transformer Block (layer).
    The classification token is added only in the last stage.

    Args:
        config ([`CvtConfig`]): Model configuration class.
        stage (`int`): Stage number.
    """

    def __init__(self, config: CvtConfig, stage: int, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.stage = stage
        if self.config.cls_token[self.stage]:
            # NOTE(review): the weight name hard-codes "stages.2" — presumably for checkpoint
            # compatibility since only the last stage uses a cls token; confirm against the
            # PyTorch weight names before changing.
            self.cls_token = self.add_weight(
                shape=(1, 1, self.config.embed_dim[-1]),
                initializer=get_initializer(self.config.initializer_range),
                trainable=True,
                name="cvt.encoder.stages.2.cls_token",
            )

        self.embedding = TFCvtEmbeddings(
            self.config,
            patch_size=config.patch_sizes[self.stage],
            # Stage 0 consumes the raw image channels; later stages consume the previous embed_dim.
            num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
            stride=config.patch_stride[self.stage],
            embed_dim=config.embed_dim[self.stage],
            padding=config.patch_padding[self.stage],
            dropout_rate=config.drop_rate[self.stage],
            name="embedding",
        )

        drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage])
        drop_path_rates = [x.numpy().item() for x in drop_path_rates]
        # NOTE(review): every layer in this stage uses drop_path_rates[self.stage] rather than
        # indexing by `j`; this matches the original implementation's behavior — confirm intended.
        self.layers = [
            TFCvtLayer(
                config,
                num_heads=config.num_heads[self.stage],
                embed_dim=config.embed_dim[self.stage],
                kernel_size=config.kernel_qkv[self.stage],
                stride_q=config.stride_q[self.stage],
                stride_kv=config.stride_kv[self.stage],
                padding_q=config.padding_q[self.stage],
                padding_kv=config.padding_kv[self.stage],
                qkv_projection_method=config.qkv_projection_method[self.stage],
                qkv_bias=config.qkv_bias[self.stage],
                attention_drop_rate=config.attention_drop_rate[self.stage],
                drop_rate=config.drop_rate[self.stage],
                mlp_ratio=config.mlp_ratio[self.stage],
                drop_path_rate=drop_path_rates[self.stage],
                with_cls_token=config.cls_token[self.stage],
                name=f"layers.{j}",
            )
            for j in range(config.depth[self.stage])
        ]

    def call(self, hidden_state: tf.Tensor, training: bool = False):
        """Run the stage: patch embedding, optional cls token, transformer layers.

        Returns a `(hidden_state, cls_token)` tuple; `cls_token` is `None` unless this stage
        is configured with a classification token.
        """
        cls_token = None
        hidden_state = self.embedding(hidden_state, training)

        # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
        batch_size, height, width, num_channels = shape_list(hidden_state)
        hidden_size = height * width
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))

        if self.config.cls_token[self.stage]:
            # Tile the learned cls token across the batch and prepend it to the sequence.
            cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
            hidden_state = tf.concat((cls_token, hidden_state), axis=1)

        for layer in self.layers:
            layer_outputs = layer(hidden_state, height, width, training=training)
            hidden_state = layer_outputs

        if self.config.cls_token[self.stage]:
            # Separate the cls token back out before restoring the spatial layout.
            cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)

        # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
        hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
        return hidden_state, cls_token

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embedding", None) is not None:
            with tf.name_scope(self.embedding.name):
                self.embedding.build(None)
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
746
+
747
+
748
class TFCvtEncoder(keras.layers.Layer):
    """
    Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers
    (depth) being 1, 2 and 10.

    Args:
        config ([`CvtConfig`]): Model configuration class.
    """

    config_class = CvtConfig

    def __init__(self, config: CvtConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.stages = [
            TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth))
        ]

    def call(
        self,
        pixel_values: TFModelInputType,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
        """Run all encoder stages over `pixel_values`.

        Returns either a `TFBaseModelOutputWithCLSToken` or, when `return_dict` is falsy, a tuple
        of the non-`None` values among (last hidden state, cls token, all hidden states).
        """
        all_hidden_states = () if output_hidden_states else None
        hidden_state = pixel_values
        # When running on CPU, `keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
        # as input format. So change the input format to (batch_size, height, width, num_channels).
        hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1))

        cls_token = None
        # Idiom fix: iterate the stages directly instead of `for _, (stage_module) in enumerate(...)`,
        # which created and discarded an index.
        for stage_module in self.stages:
            hidden_state, cls_token = stage_module(hidden_state, training=training)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

        # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules
        hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2))
        if output_hidden_states:
            all_hidden_states = tuple(tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)

        return TFBaseModelOutputWithCLSToken(
            last_hidden_state=hidden_state,
            cls_token_value=cls_token,
            hidden_states=all_hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "stages", None) is not None:
            for layer in self.stages:
                with tf.name_scope(layer.name):
                    layer.build(None)
807
+
808
+
809
@keras_serializable
class TFCvtMainLayer(keras.layers.Layer):
    """Construct the Cvt model.

    Thin wrapper around `TFCvtEncoder` that validates the input and normalizes the output
    into a `TFBaseModelOutputWithCLSToken` (or tuple).
    """

    config_class = CvtConfig

    def __init__(self, config: CvtConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.encoder = TFCvtEncoder(config, name="encoder")

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
        """Forward `pixel_values` through the encoder.

        Raises:
            ValueError: if `pixel_values` is None.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutputWithCLSToken(
            last_hidden_state=sequence_output,
            cls_token_value=encoder_outputs.cls_token_value,
            hidden_states=encoder_outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
856
+
857
+
858
class TFCvtPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class used by `from_pretrained` for this model family.
    config_class = CvtConfig
    # Attribute name under which the base model is stored on derived classes.
    base_model_prefix = "cvt"
    # Name of the primary input expected by `call`.
    main_input_name = "pixel_values"
867
+
868
+
869
+ TFCVT_START_DOCSTRING = r"""
870
+
871
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
872
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
873
+ etc.)
874
+
875
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
876
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
877
+ behavior.
878
+
879
+ <Tip>
880
+
881
+ TF 2.0 models accepts two formats as inputs:
882
+
883
+ - having all inputs as keyword arguments (like PyTorch models), or
884
+ - having all inputs as a list, tuple or dict in the first positional arguments.
885
+
886
+ This second option is useful when using [`keras.Model.fit`] method which currently requires having all the
887
+ tensors in the first argument of the model call function: `model(inputs)`.
888
+
889
+ </Tip>
890
+
891
+ Args:
892
+ config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
893
+ Initializing with a config file does not load the weights associated with the model, only the
894
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
895
+ """
896
+
897
+ TFCVT_INPUTS_DOCSTRING = r"""
898
+ Args:
899
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, each example must have the shape `(batch_size, num_channels, height, width)`):
900
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
901
+ for details.
902
+
903
+ output_hidden_states (`bool`, *optional*):
904
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
905
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
906
+ used instead.
907
+ return_dict (`bool`, *optional*):
908
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
909
+ eager mode, in graph mode the value will always be set to True.
910
        training (`bool`, *optional*, defaults to `False`):
911
+ Whether or not to use the model in training mode (some modules like dropout modules have different
912
+ behaviors between training and evaluation).
913
+ """
914
+
915
+
916
@add_start_docstrings(
    "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
    TFCVT_START_DOCSTRING,
)
class TFCvtModel(TFCvtPreTrainedModel):
    def __init__(self, config: CvtConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.cvt = TFCvtMainLayer(config, name="cvt")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, TFCvtModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
        >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13")

        >>> inputs = image_processor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        outputs = self.cvt(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            return (outputs[0],) + outputs[1:]

        return TFBaseModelOutputWithCLSToken(
            last_hidden_state=outputs.last_hidden_state,
            cls_token_value=outputs.cls_token_value,
            hidden_states=outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "cvt", None) is not None:
            with tf.name_scope(self.cvt.name):
                self.cvt.build(None)
983
+
984
+
985
+ @add_start_docstrings(
986
+ """
987
+ Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
988
+ the [CLS] token) e.g. for ImageNet.
989
+ """,
990
+ TFCVT_START_DOCSTRING,
991
+ )
992
+ class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss):
993
+ def __init__(self, config: CvtConfig, *inputs, **kwargs):
994
+ super().__init__(config, *inputs, **kwargs)
995
+
996
+ self.num_labels = config.num_labels
997
+ self.cvt = TFCvtMainLayer(config, name="cvt")
998
+ # Using same default epsilon as in the original implementation.
999
+ self.layernorm = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")
1000
+
1001
+ # Classifier head
1002
+ self.classifier = keras.layers.Dense(
1003
+ units=config.num_labels,
1004
+ kernel_initializer=get_initializer(config.initializer_range),
1005
+ use_bias=True,
1006
+ bias_initializer="zeros",
1007
+ name="classifier",
1008
+ )
1009
+ self.config = config
1010
+
1011
+ @unpack_inputs
1012
+ @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
1013
+ @replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
1014
+ def call(
1015
+ self,
1016
+ pixel_values: tf.Tensor | None = None,
1017
+ labels: tf.Tensor | None = None,
1018
+ output_hidden_states: Optional[bool] = None,
1019
+ return_dict: Optional[bool] = None,
1020
+ training: Optional[bool] = False,
1021
+ ) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]:
1022
+ r"""
1023
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
1024
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1025
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1026
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1027
+
1028
+ Returns:
1029
+
1030
+ Examples:
1031
+
1032
+ ```python
1033
+ >>> from transformers import AutoImageProcessor, TFCvtForImageClassification
1034
+ >>> import tensorflow as tf
1035
+ >>> from PIL import Image
1036
+ >>> import requests
1037
+
1038
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1039
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1040
+
1041
+ >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
1042
+ >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13")
1043
+
1044
+ >>> inputs = image_processor(images=image, return_tensors="tf")
1045
+ >>> outputs = model(**inputs)
1046
+ >>> logits = outputs.logits
1047
+ >>> # model predicts one of the 1000 ImageNet classes
1048
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
1049
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
1050
+ ```"""
1051
+
1052
+ outputs = self.cvt(
1053
+ pixel_values,
1054
+ output_hidden_states=output_hidden_states,
1055
+ return_dict=return_dict,
1056
+ training=training,
1057
+ )
1058
+
1059
+ sequence_output = outputs[0]
1060
+ cls_token = outputs[1]
1061
+ if self.config.cls_token[-1]:
1062
+ sequence_output = self.layernorm(cls_token)
1063
+ else:
1064
+ # rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels"
1065
+ batch_size, num_channels, height, width = shape_list(sequence_output)
1066
+ sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width))
1067
+ sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1))
1068
+ sequence_output = self.layernorm(sequence_output)
1069
+
1070
+ sequence_output_mean = tf.reduce_mean(sequence_output, axis=1)
1071
+ logits = self.classifier(sequence_output_mean)
1072
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
1073
+
1074
+ if not return_dict:
1075
+ output = (logits,) + outputs[2:]
1076
+ return ((loss,) + output) if loss is not None else output
1077
+
1078
+ return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
1079
+
1080
+ def build(self, input_shape=None):
1081
+ if self.built:
1082
+ return
1083
+ self.built = True
1084
+ if getattr(self, "cvt", None) is not None:
1085
+ with tf.name_scope(self.cvt.name):
1086
+ self.cvt.build(None)
1087
+ if getattr(self, "layernorm", None) is not None:
1088
+ with tf.name_scope(self.layernorm.name):
1089
+ self.layernorm.build([None, None, self.config.embed_dim[-1]])
1090
+ if getattr(self, "classifier", None) is not None:
1091
+ if hasattr(self.classifier, "name"):
1092
+ with tf.name_scope(self.classifier.name):
1093
+ self.classifier.build([None, None, self.config.embed_dim[-1]])
1094
+
1095
+
1096
+ __all__ = ["TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel"]
docs/transformers/build/lib/transformers/models/dab_detr/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ...utils import _LazyModule
18
+ from ...utils.import_utils import define_import_structure
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from .configuration_dab_detr import *
23
+ from .modeling_dab_detr import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/dab_detr/configuration_dab_detr.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DAB-DETR model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import verify_backbone_config_arguments
20
+ from ..auto import CONFIG_MAPPING
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class DabDetrConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`DabDetrModel`]. It is used to instantiate
29
+ a DAB-DETR model according to the specified arguments, defining the model architecture. Instantiating a
30
+ configuration with the defaults will yield a similar configuration to that of the DAB-DETR
31
+ [IDEA-Research/dab_detr-base](https://huggingface.co/IDEA-Research/dab_detr-base) architecture.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+ Args:
37
+ use_timm_backbone (`bool`, *optional*, defaults to `True`):
38
+ Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
39
+ API.
40
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
41
+ The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
42
+ case it will default to `ResNetConfig()`.
43
+ backbone (`str`, *optional*, defaults to `"resnet50"`):
44
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
45
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
46
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
47
+ use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
48
+ Whether to use pretrained weights for the backbone.
49
+ backbone_kwargs (`dict`, *optional*):
50
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
51
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
52
+ num_queries (`int`, *optional*, defaults to 300):
53
+ Number of object queries, i.e. detection slots. This is the maximal number of objects
54
+ [`DabDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
55
+ encoder_layers (`int`, *optional*, defaults to 6):
56
+ Number of encoder layers.
57
+ encoder_ffn_dim (`int`, *optional*, defaults to 2048):
58
+ Dimension of the "intermediate" (often named feed-forward) layer in encoder.
59
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
60
+ Number of attention heads for each attention layer in the Transformer encoder.
61
+ decoder_layers (`int`, *optional*, defaults to 6):
62
+ Number of decoder layers.
63
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
64
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
65
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
66
+ Number of attention heads for each attention layer in the Transformer decoder.
67
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
68
+ Indicates whether the transformer model architecture is an encoder-decoder or not.
69
+ activation_function (`str` or `function`, *optional*, defaults to `"prelu"`):
70
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
71
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
72
+ hidden_size (`int`, *optional*, defaults to 256):
73
+ This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others.
74
+ dropout (`float`, *optional*, defaults to 0.1):
75
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
76
+ attention_dropout (`float`, *optional*, defaults to 0.0):
77
+ The dropout ratio for the attention probabilities.
78
+ activation_dropout (`float`, *optional*, defaults to 0.0):
79
+ The dropout ratio for activations inside the fully connected layer.
80
+ init_std (`float`, *optional*, defaults to 0.02):
81
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
82
+ init_xavier_std (`float`, *optional*, defaults to 1.0):
83
+ The scaling factor used for the Xavier initialization gain in the HM Attention map module.
84
+ auxiliary_loss (`bool`, *optional*, defaults to `False`):
85
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
86
+ dilation (`bool`, *optional*, defaults to `False`):
87
+ Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when `use_timm_backbone` = `True`.
88
+ class_cost (`float`, *optional*, defaults to 2):
89
+ Relative weight of the classification error in the Hungarian matching cost.
90
+ bbox_cost (`float`, *optional*, defaults to 5):
91
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
92
+ giou_cost (`float`, *optional*, defaults to 2):
93
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
94
+ cls_loss_coefficient (`float`, *optional*, defaults to 2):
95
+ Relative weight of the classification loss in the object detection loss function.
96
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
97
+ Relative weight of the L1 bounding box loss in the object detection loss.
98
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
99
+ Relative weight of the generalized IoU loss in the object detection loss.
100
+ focal_alpha (`float`, *optional*, defaults to 0.25):
101
+ Alpha parameter in the focal loss.
102
+ temperature_height (`int`, *optional*, defaults to 20):
103
+ Temperature parameter to tune the flatness of positional attention (HEIGHT)
104
+ temperature_width (`int`, *optional*, defaults to 20):
105
+ Temperature parameter to tune the flatness of positional attention (WIDTH)
106
+ query_dim (`int`, *optional*, defaults to 4):
107
+ Query dimension parameter represents the size of the output vector.
108
+ random_refpoints_xy (`bool`, *optional*, defaults to `False`):
109
+ Whether to fix the x and y coordinates of the anchor boxes with random initialization.
110
+ keep_query_pos (`bool`, *optional*, defaults to `False`):
111
+ Whether to concatenate the projected positional embedding from the object query into the original query (key) in every decoder layer.
112
+ num_patterns (`int`, *optional*, defaults to 0):
113
+ Number of pattern embeddings.
114
+ normalize_before (`bool`, *optional*, defaults to `False`):
115
+ Whether we use a normalization layer in the Encoder or not.
116
+ sine_position_embedding_scale (`float`, *optional*, defaults to `None`):
117
+ Scaling factor applied to the normalized positional encodings.
118
+ initializer_bias_prior_prob (`float`, *optional*):
119
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
120
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
121
+
122
+
123
+ Examples:
124
+
125
+ ```python
126
+ >>> from transformers import DabDetrConfig, DabDetrModel
127
+
128
+ >>> # Initializing a DAB-DETR IDEA-Research/dab_detr-base style configuration
129
+ >>> configuration = DabDetrConfig()
130
+
131
+ >>> # Initializing a model (with random weights) from the IDEA-Research/dab_detr-base style configuration
132
+ >>> model = DabDetrModel(configuration)
133
+
134
+ >>> # Accessing the model configuration
135
+ >>> configuration = model.config
136
+ ```"""
137
+
138
+ model_type = "dab-detr"
139
+ keys_to_ignore_at_inference = ["past_key_values"]
140
+ attribute_map = {
141
+ "num_attention_heads": "encoder_attention_heads",
142
+ }
143
+
144
+ def __init__(
145
+ self,
146
+ use_timm_backbone=True,
147
+ backbone_config=None,
148
+ backbone="resnet50",
149
+ use_pretrained_backbone=True,
150
+ backbone_kwargs=None,
151
+ num_queries=300,
152
+ encoder_layers=6,
153
+ encoder_ffn_dim=2048,
154
+ encoder_attention_heads=8,
155
+ decoder_layers=6,
156
+ decoder_ffn_dim=2048,
157
+ decoder_attention_heads=8,
158
+ is_encoder_decoder=True,
159
+ activation_function="prelu",
160
+ hidden_size=256,
161
+ dropout=0.1,
162
+ attention_dropout=0.0,
163
+ activation_dropout=0.0,
164
+ init_std=0.02,
165
+ init_xavier_std=1.0,
166
+ auxiliary_loss=False,
167
+ dilation=False,
168
+ class_cost=2,
169
+ bbox_cost=5,
170
+ giou_cost=2,
171
+ cls_loss_coefficient=2,
172
+ bbox_loss_coefficient=5,
173
+ giou_loss_coefficient=2,
174
+ focal_alpha=0.25,
175
+ temperature_height=20,
176
+ temperature_width=20,
177
+ query_dim=4,
178
+ random_refpoints_xy=False,
179
+ keep_query_pos=False,
180
+ num_patterns=0,
181
+ normalize_before=False,
182
+ sine_position_embedding_scale=None,
183
+ initializer_bias_prior_prob=None,
184
+ **kwargs,
185
+ ):
186
+ if query_dim != 4:
187
+ raise ValueError("The query dimensions has to be 4.")
188
+
189
+ # We default to values which were previously hard-coded in the model. This enables configurability of the config
190
+ # while keeping the default behavior the same.
191
+ if use_timm_backbone and backbone_kwargs is None:
192
+ backbone_kwargs = {}
193
+ if dilation:
194
+ backbone_kwargs["output_stride"] = 16
195
+ backbone_kwargs["out_indices"] = [1, 2, 3, 4]
196
+ backbone_kwargs["in_chans"] = 3 # num_channels
197
+ # Backwards compatibility
198
+ elif not use_timm_backbone and backbone in (None, "resnet50"):
199
+ if backbone_config is None:
200
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
201
+ backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
202
+ elif isinstance(backbone_config, dict):
203
+ backbone_model_type = backbone_config.get("model_type")
204
+ config_class = CONFIG_MAPPING[backbone_model_type]
205
+ backbone_config = config_class.from_dict(backbone_config)
206
+ backbone = None
207
+ # set timm attributes to None
208
+ dilation = None
209
+
210
+ verify_backbone_config_arguments(
211
+ use_timm_backbone=use_timm_backbone,
212
+ use_pretrained_backbone=use_pretrained_backbone,
213
+ backbone=backbone,
214
+ backbone_config=backbone_config,
215
+ backbone_kwargs=backbone_kwargs,
216
+ )
217
+
218
+ self.use_timm_backbone = use_timm_backbone
219
+ self.backbone_config = backbone_config
220
+ self.num_queries = num_queries
221
+ self.hidden_size = hidden_size
222
+ self.encoder_ffn_dim = encoder_ffn_dim
223
+ self.encoder_layers = encoder_layers
224
+ self.encoder_attention_heads = encoder_attention_heads
225
+ self.decoder_ffn_dim = decoder_ffn_dim
226
+ self.decoder_layers = decoder_layers
227
+ self.decoder_attention_heads = decoder_attention_heads
228
+ self.dropout = dropout
229
+ self.attention_dropout = attention_dropout
230
+ self.activation_dropout = activation_dropout
231
+ self.activation_function = activation_function
232
+ self.init_std = init_std
233
+ self.init_xavier_std = init_xavier_std
234
+ self.num_hidden_layers = encoder_layers
235
+ self.auxiliary_loss = auxiliary_loss
236
+ self.backbone = backbone
237
+ self.use_pretrained_backbone = use_pretrained_backbone
238
+ self.backbone_kwargs = backbone_kwargs
239
+ # Hungarian matcher
240
+ self.class_cost = class_cost
241
+ self.bbox_cost = bbox_cost
242
+ self.giou_cost = giou_cost
243
+ # Loss coefficients
244
+ self.cls_loss_coefficient = cls_loss_coefficient
245
+ self.bbox_loss_coefficient = bbox_loss_coefficient
246
+ self.giou_loss_coefficient = giou_loss_coefficient
247
+ self.focal_alpha = focal_alpha
248
+ self.query_dim = query_dim
249
+ self.random_refpoints_xy = random_refpoints_xy
250
+ self.keep_query_pos = keep_query_pos
251
+ self.num_patterns = num_patterns
252
+ self.normalize_before = normalize_before
253
+ self.temperature_width = temperature_width
254
+ self.temperature_height = temperature_height
255
+ self.sine_position_embedding_scale = sine_position_embedding_scale
256
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
257
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
258
+
259
+
260
+ __all__ = ["DabDetrConfig"]
docs/transformers/build/lib/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert DAB-DETR checkpoints."""
16
+
17
+ import argparse
18
+ import gc
19
+ import json
20
+ import re
21
+ from pathlib import Path
22
+
23
+ import torch
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection
27
+ from transformers.utils import logging
28
+
29
+
30
+ logging.set_verbosity_info()
31
+ logger = logging.get_logger(__name__)
32
+
33
+ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
34
+ # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
35
+ # for dab-DETR, also convert reference point head and query scale MLP
36
+ r"input_proj\.(bias|weight)": r"input_projection.\1",
37
+ r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight",
38
+ r"class_embed\.(bias|weight)": r"class_embed.\1",
39
+ # negative lookbehind because of the overlap
40
+ r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2",
41
+ r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2",
42
+ r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
43
+ r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1",
44
+ r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2",
45
+ r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2",
46
+ r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2",
47
+ r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1",
48
+ # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function
49
+ # output projection
50
+ r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2",
51
+ # FFN layers
52
+ r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3",
53
+ # normalization layers
54
+ # nm1
55
+ r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2",
56
+ # nm2
57
+ r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2",
58
+ # activation function weight
59
+ r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight",
60
+ #########################################################################################################################################
61
+ # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activation function weight
62
+ r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2",
63
+ r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2",
64
+ # FFNs
65
+ r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3",
66
+ # nm1
67
+ r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2",
68
+ # nm2
69
+ r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2",
70
+ # nm3
71
+ r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2",
72
+ # activation function weight
73
+ r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight",
74
+ # q, k, v projections and biases in self-attention in decoder
75
+ r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2",
76
+ r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2",
77
+ r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2",
78
+ r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2",
79
+ r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2",
80
+ # q, k, v projections in cross-attention in decoder
81
+ r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2",
82
+ r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2",
83
+ r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2",
84
+ r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2",
85
+ r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2",
86
+ }
87
+
88
+
89
+ # Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys
90
+ def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
91
+ """
92
+ This function should be applied only once, on the concatenated keys to efficiently rename using
93
+ the key mappings.
94
+ """
95
+ output_dict = {}
96
+ if state_dict_keys is not None:
97
+ old_text = "\n".join(state_dict_keys)
98
+ new_text = old_text
99
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
100
+ if replacement is None:
101
+ new_text = re.sub(pattern, "", new_text) # an empty line
102
+ continue
103
+ new_text = re.sub(pattern, replacement, new_text)
104
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
105
+ return output_dict
106
+
107
+
108
+ def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub):
109
+ logger.info("Converting image processor...")
110
+ format = "coco_detection"
111
+ image_processor = ConditionalDetrImageProcessor(format=format)
112
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
113
+ image_processor.save_pretrained(pytorch_dump_folder_path)
114
+
115
+ if push_to_hub:
116
+ image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor")
117
+
118
+
119
+ @torch.no_grad()
120
+ def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
121
+ # load modified config. Why? After loading the default config, the backbone kwargs are already set.
122
+ if "dc5" in model_name:
123
+ config = DabDetrConfig(dilation=True)
124
+ else:
125
+ # load default config
126
+ config = DabDetrConfig()
127
+ # set other attributes
128
+ if "dab-detr-resnet-50-dc5" == model_name:
129
+ config.temperature_height = 10
130
+ config.temperature_width = 10
131
+ if "fixxy" in model_name:
132
+ config.random_refpoints_xy = True
133
+ if "pat3" in model_name:
134
+ config.num_patterns = 3
135
+ # only when the number of patterns (num_patterns parameter in config) are more than 0 like r50-pat3 or r50dc5-pat3
136
+ ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"})
137
+
138
+ config.num_labels = 91
139
+ repo_id = "huggingface/label-files"
140
+ filename = "coco-detection-id2label.json"
141
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
142
+ id2label = {int(k): v for k, v in id2label.items()}
143
+ config.id2label = id2label
144
+ config.label2id = {v: k for k, v in id2label.items()}
145
+ # load original model from local path
146
+ loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"), weights_only=True)["model"]
147
+ # Renaming the original model state dictionary to be HF compatible
148
+ all_keys = list(loaded.keys())
149
+ new_keys = convert_old_keys_to_new_keys(all_keys)
150
+ state_dict = {}
151
+ for key in all_keys:
152
+ if "backbone.0.body" in key:
153
+ new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone")
154
+ state_dict[new_key] = loaded[key]
155
+ # Q, K, V encoder values mapping
156
+ elif re.search("self_attn.in_proj_(weight|bias)", key):
157
+ # Dynamically find the layer number
158
+ pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)"
159
+ match = re.search(pattern, key)
160
+ if match:
161
+ layer_num = match.group(1)
162
+ else:
163
+ raise ValueError(f"Pattern not found in key: {key}")
164
+
165
+ in_proj_value = loaded.pop(key)
166
+ if "weight" in key:
167
+ state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :]
168
+ state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :]
169
+ state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :]
170
+ elif "bias" in key:
171
+ state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256]
172
+ state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512]
173
+ state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:]
174
+ else:
175
+ new_key = new_keys[key]
176
+ state_dict[new_key] = loaded[key]
177
+
178
+ del loaded
179
+ gc.collect()
180
+ # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
181
+ prefix = "model."
182
+ for key in state_dict.copy().keys():
183
+ if not key.startswith("class_embed") and not key.startswith("bbox_predictor"):
184
+ val = state_dict.pop(key)
185
+ state_dict[prefix + key] = val
186
+ # finally, create HuggingFace model and load state dict
187
+ model = DabDetrForObjectDetection(config)
188
+ model.load_state_dict(state_dict)
189
+ model.eval()
190
+ logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...")
191
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
192
+ model.save_pretrained(pytorch_dump_folder_path)
193
+
194
+ if push_to_hub:
195
+ model.push_to_hub(repo_id=model_name, commit_message="Add new model")
196
+
197
+
198
+ def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
199
+ logger.info("Converting image processor...")
200
+ write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub)
201
+
202
+ logger.info(f"Converting model {model_name}...")
203
+ write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub)
204
+
205
+
206
+ if __name__ == "__main__":
207
+ parser = argparse.ArgumentParser()
208
+
209
+ parser.add_argument(
210
+ "--model_name",
211
+ default="dab-detr-resnet-50",
212
+ type=str,
213
+ help="Name of the DAB_DETR model you'd like to convert.",
214
+ )
215
+ parser.add_argument(
216
+ "--pretrained_model_weights_path",
217
+ default="modelzoo/R50/checkpoint.pth",
218
+ type=str,
219
+ help="The path of the original model weights like: modelzoo/checkpoint.pth",
220
+ )
221
+ parser.add_argument(
222
+ "--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model."
223
+ )
224
+ parser.add_argument(
225
+ "--push_to_hub",
226
+ default=True,
227
+ type=bool,
228
+ help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.",
229
+ )
230
+ args = parser.parse_args()
231
+ convert_dab_detr_checkpoint(
232
+ args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub
233
+ )
docs/transformers/build/lib/transformers/models/dab_detr/modeling_dab_detr.py ADDED
@@ -0,0 +1,1716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 IDEA Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DAB-DETR model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import Tensor, nn
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
26
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from ...utils.backbone_utils import load_backbone
36
+ from .configuration_dab_detr import DabDetrConfig
37
+
38
+
39
logger = logging.get_logger(__name__)

# Names injected into auto-generated docstrings — presumably consumed by the
# `add_start_docstrings*` / `replace_return_docstrings` decorators imported above.
_CONFIG_FOR_DOC = "DabDetrConfig"
_CHECKPOINT_FOR_DOC = "IDEA-Research/dab_detr-base"
43
+
44
+
45
@dataclass
# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points)
class DabDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    """
    Base class for outputs of the DAB-DETR decoder. This class adds attributes to
    BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output
    of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary
    decoding losses.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
            layernorm.
        reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4 (anchor points))`):
            Reference points (reference points of each layer of the decoder).
    """

    # Per-layer decoder outputs (after layernorm); populated for auxiliary decoding losses.
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    # Per-layer 4-D anchor boxes used as dynamic queries; see `gen_sine_position_embeddings`.
    reference_points: Optional[Tuple[torch.FloatTensor]] = None
78
+
79
+
80
@dataclass
# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrModelOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points)
class DabDetrModelOutput(Seq2SeqModelOutput):
    """
    Base class for outputs of the DAB-DETR encoder-decoder model. This class adds attributes to
    Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder
    layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding
    losses.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
            layernorm.
        reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4 (anchor points))`):
            Reference points (reference points of each layer of the decoder).
    """

    # Per-layer decoder outputs (after layernorm); populated for auxiliary decoding losses.
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    # Per-layer 4-D anchor boxes used as dynamic queries; see `gen_sine_position_embeddings`.
    reference_points: Optional[Tuple[torch.FloatTensor]] = None
123
+
124
+
125
@dataclass
# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->DabDetr
class DabDetrObjectDetectionOutput(ModelOutput):
    """
    Output type of [`DabDetrForObjectDetection`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~DabDetrImageProcessor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
            layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[Dict] = None
    logits: Optional[torch.FloatTensor] = None
    pred_boxes: Optional[torch.FloatTensor] = None
    auxiliary_outputs: Optional[List[Dict]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
187
+
188
+
189
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DabDetr
class DabDetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d whose batch statistics and affine parameters are frozen.

    Copy-paste from torchvision.misc.ops with an added eps before rsqrt; without it, any model other than
    torchvision.models.resnet[18,34,50,101] produces nans.
    """

    def __init__(self, n):
        super().__init__()
        # All four tensors are buffers (not Parameters): they are loaded from checkpoints
        # but never updated by the optimizer.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Vanilla BatchNorm2d checkpoints carry a tracking counter this module does not have;
        # drop it so strict loading does not complain.
        tracked_key = prefix + "num_batches_tracked"
        if tracked_key in state_dict:
            del state_dict[tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # Reshape the per-channel buffers up front so they broadcast over
        # (batch, channel, height, width) inputs.
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        # Fold normalization + affine transform into a single scale-and-shift.
        scale = weight * (var + eps).rsqrt()
        shift = bias - mean * scale
        return x * scale + shift
227
+
228
+
229
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DabDetr
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `DabDetrFrozenBatchNorm2d`.

    Args:
        model (torch.nn.Module):
            input model
    """
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = DabDetrFrozenBatchNorm2d(child.num_features)

            # Meta-device modules hold no real data, so there is nothing to copy.
            if child.weight.device != torch.device("meta"):
                frozen.weight.data.copy_(child.weight)
                frozen.bias.data.copy_(child.bias)
                frozen.running_mean.data.copy_(child.running_mean)
                frozen.running_var.data.copy_(child.running_var)

            # Swap the module in place on the parent.
            model._modules[child_name] = frozen

        # Recurse into containers (Sequential, nested blocks, ...).
        if len(list(child.children())) > 0:
            replace_batch_norm(child)
252
+
253
+
254
# Modified from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->DabDetr
class DabDetrConvEncoder(nn.Module):
    """
    Convolutional backbone, using either the AutoBackbone API or one from the timm library.

    nn.BatchNorm2d layers are replaced by DabDetrFrozenBatchNorm2d as defined above.

    """

    def __init__(self, config: DabDetrConfig):
        super().__init__()

        self.config = config
        backbone = load_backbone(config)

        # Swap every BatchNorm2d for its frozen counterpart so the statistics stay fixed.
        with torch.no_grad():
            replace_batch_norm(backbone)
        self.model = backbone
        self.intermediate_channel_sizes = self.model.channels

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        """Return one (feature_map, mask) pair per backbone stage."""
        feature_maps = self.model(pixel_values).feature_maps

        stage_outputs = []
        for stage_map in feature_maps:
            # Downsample the padding mask to this stage's spatial resolution.
            resized = nn.functional.interpolate(pixel_mask[None].float(), size=stage_map.shape[-2:])
            stage_mask = resized.to(torch.bool)[0]
            stage_outputs.append((stage_map, stage_mask))
        return stage_outputs
285
+
286
+
287
# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DabDetr
class DabDetrConvModel(nn.Module):
    """
    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
    """

    def __init__(self, conv_encoder, position_embedding):
        super().__init__()
        self.conv_encoder = conv_encoder
        self.position_embedding = position_embedding

    def forward(self, pixel_values, pixel_mask):
        # The encoder yields one (feature_map, mask) pair per backbone stage.
        feature_mask_pairs = self.conv_encoder(pixel_values, pixel_mask)
        # Compute a position encoding for each stage, cast to the feature map's dtype.
        position_encodings = [
            self.position_embedding(feature_map, mask).to(feature_map.dtype)
            for feature_map, mask in feature_mask_pairs
        ]
        return feature_mask_pairs, position_encodings
307
+
308
+
309
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrSinePositionEmbedding with ConditionalDetr->DabDetr
class DabDetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(self, config: DabDetrConfig):
        super().__init__()
        self.config = config
        # Half the hidden size per axis: the x and y embeddings are concatenated back to hidden_size.
        # NOTE: true division, so this is a float; torch.arange below accepts a float end value.
        self.embedding_dim = config.hidden_size / 2
        # DAB-DETR uses separate sine temperatures for the two spatial axes.
        self.temperature_height = config.temperature_height
        self.temperature_width = config.temperature_width
        scale = config.sine_position_embedding_scale
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, pixel_values, pixel_mask):
        """Build (batch, hidden_size, height, width) sine/cosine position encodings from the padding mask."""
        if pixel_mask is None:
            raise ValueError("No pixel mask provided")
        # Cumulative sums over the mask give each valid pixel its (row, column) index in the unpadded region.
        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
        # Normalize coordinates to [0, scale]; eps guards against division by zero on fully-masked rows/columns.
        y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        # We use float32 to ensure reproducibility of the original implementation
        dim_tx = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
        # Modifying dim_tx in place to avoid extra memory allocation -> dim_tx = self.temperature_width ** (2 * (dim_tx // 2) / self.embedding_dim)
        # NOTE: the order of these three in-place ops is significant; do not reorder.
        dim_tx //= 2
        dim_tx.mul_(2 / self.embedding_dim)
        dim_tx.copy_(self.temperature_width**dim_tx)
        pos_x = x_embed[:, :, :, None] / dim_tx

        # We use float32 to ensure reproducibility of the original implementation
        dim_ty = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
        # Modifying dim_ty in place to avoid extra memory allocation -> dim_ty = self.temperature_height ** (2 * (dim_ty // 2) / self.embedding_dim)
        dim_ty //= 2
        dim_ty.mul_(2 / self.embedding_dim)
        dim_ty.copy_(self.temperature_height**dim_ty)
        pos_y = y_embed[:, :, :, None] / dim_ty

        # Interleave sin (even channels) and cos (odd channels), then move channels to dim 1.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
355
+
356
+
357
# function to generate sine positional embedding for 4d coordinates
def gen_sine_position_embeddings(pos_tensor, hidden_size=256):
    """
    This function computes position embeddings using sine and cosine functions from the input positional tensor,
    which has a shape of (batch_size, num_queries, 4).
    The last dimension of `pos_tensor` represents the following coordinates:
    - 0: x-coord
    - 1: y-coord
    - 2: width
    - 3: height

    The output shape is (batch_size, num_queries, 512), where final dim (hidden_size*2 = 512) is the total embedding dimension
    achieved by concatenating the sine and cosine values for each coordinate.
    """
    scale = 2 * math.pi
    half_dim = hidden_size // 2
    # Inverse frequencies; consecutive channels share a frequency, forming sin/cos pairs.
    frequencies = torch.arange(half_dim, dtype=torch.float32, device=pos_tensor.device)
    frequencies = 10000 ** (2 * torch.div(frequencies, 2, rounding_mode="floor") / half_dim)

    def _embed_coordinate(coords):
        # coords: (batch_size, num_queries) -> (batch_size, num_queries, half_dim)
        phases = coords[:, :, None] * scale / frequencies
        return torch.stack((phases[:, :, 0::2].sin(), phases[:, :, 1::2].cos()), dim=3).flatten(2)

    pos_x = _embed_coordinate(pos_tensor[:, :, 0])
    pos_y = _embed_coordinate(pos_tensor[:, :, 1])
    if pos_tensor.size(-1) == 4:
        pos_w = _embed_coordinate(pos_tensor[:, :, 2])
        pos_h = _embed_coordinate(pos_tensor[:, :, 3])
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
394
+
395
+
396
def inverse_sigmoid(x, eps=1e-5):
    """Numerically-stable inverse of the sigmoid (logit); inputs are clamped to [0, 1] first."""
    x = x.clamp(min=0, max=1)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return torch.log(numerator / denominator)
401
+
402
+
403
+ # Modified from transformers.models.detr.modeling_detr.DetrAttention
404
+ class DetrAttention(nn.Module):
405
+ """
406
+ Multi-headed attention from 'Attention Is All You Need' paper.
407
+
408
+ Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
409
+ """
410
+
411
+ def __init__(
412
+ self,
413
+ config: DabDetrConfig,
414
+ bias: bool = True,
415
+ ):
416
+ super().__init__()
417
+ self.config = config
418
+ self.hidden_size = config.hidden_size
419
+ self.num_heads = config.encoder_attention_heads
420
+ self.attention_dropout = config.attention_dropout
421
+ self.head_dim = self.hidden_size // self.num_heads
422
+ if self.head_dim * self.num_heads != self.hidden_size:
423
+ raise ValueError(
424
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
425
+ f" {self.num_heads})."
426
+ )
427
+ self.scaling = self.head_dim**-0.5
428
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
429
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
430
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
431
+ self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
432
+
433
+ def forward(
434
+ self,
435
+ hidden_states: torch.Tensor,
436
+ attention_mask: Optional[torch.Tensor] = None,
437
+ object_queries: Optional[torch.Tensor] = None,
438
+ key_value_states: Optional[torch.Tensor] = None,
439
+ output_attentions: bool = False,
440
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
441
+ """Input shape: Batch x Time x Channel"""
442
+ batch_size, q_len, embed_dim = hidden_states.size()
443
+ # add position embeddings to the hidden states before projecting to queries and keys
444
+ if object_queries is not None:
445
+ hidden_states_original = hidden_states
446
+ hidden_states = hidden_states + object_queries
447
+
448
+ query_states = self.q_proj(hidden_states) * self.scaling
449
+ key_states = self.k_proj(hidden_states)
450
+ value_states = self.v_proj(hidden_states_original)
451
+
452
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
453
+ key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
454
+ value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
455
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
456
+
457
+ if attention_mask is not None:
458
+ attn_weights = attn_weights + attention_mask
459
+
460
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
461
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
462
+ attn_output = torch.matmul(attn_weights, value_states)
463
+
464
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
465
+ raise ValueError(
466
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
467
+ f" {attn_output.size()}"
468
+ )
469
+
470
+ attn_output = attn_output.transpose(1, 2).contiguous()
471
+
472
+ attn_output = attn_output.reshape(batch_size, q_len, embed_dim)
473
+ attn_output = self.out_proj(attn_output)
474
+
475
+ if not output_attentions:
476
+ attn_weights = None
477
+
478
+ return attn_output, attn_weights
479
+
480
+
481
+ # Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrAttention with ConditionalDetr->DABDETR,Conditional DETR->DabDetr
482
+ class DabDetrAttention(nn.Module):
483
+ """
484
+ Cross-Attention used in DAB-DETR 'DAB-DETR for Fast Training Convergence' paper.
485
+
486
+ The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be
487
+ different to v.
488
+ """
489
+
490
+ def __init__(self, config: DabDetrConfig, bias: bool = True, is_cross: bool = False):
491
+ super().__init__()
492
+ self.config = config
493
+ self.embed_dim = config.hidden_size * 2 if is_cross else config.hidden_size
494
+ self.output_dim = config.hidden_size
495
+ self.attention_heads = config.decoder_attention_heads
496
+ self.attention_dropout = config.attention_dropout
497
+ self.attention_head_dim = self.embed_dim // self.attention_heads
498
+ if self.attention_head_dim * self.attention_heads != self.embed_dim:
499
+ raise ValueError(
500
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `attention_heads`:"
501
+ f" {self.attention_heads})."
502
+ )
503
+ # head dimension of values
504
+ self.values_head_dim = self.output_dim // self.attention_heads
505
+ if self.values_head_dim * self.attention_heads != self.output_dim:
506
+ raise ValueError(
507
+ f"output_dim must be divisible by attention_heads (got `output_dim`: {self.output_dim} and `attention_heads`: {self.attention_heads})."
508
+ )
509
+ self.scaling = self.attention_head_dim**-0.5
510
+ self.output_proj = nn.Linear(self.output_dim, self.output_dim, bias=bias)
511
+
512
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_states: Optional[torch.Tensor] = None,
        value_states: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Multi-head scaled dot-product attention over pre-projected inputs.

        Args:
            hidden_states: query tensor of shape `(batch, q_len, embed_dim)`; the
                caller has already mixed in any positional information.
            attention_mask: optional additive mask broadcastable to the raw
                attention scores; masked positions are expected to hold large
                negative values.
            key_states: pre-projected keys (projection happens in the calling
                layer, not here).
            value_states: pre-projected values; may have a different per-head
                width (`values_head_dim`) than queries/keys.
            output_attentions: when falsy, the returned attention weights are
                `None`.

        Returns:
            `(attn_output, attn_weights)` — attn_weights is `None` unless
            `output_attentions` is truthy.
        """

        batch_size, q_len, _ = hidden_states.size()

        # scaling query and refactor key-, value states
        query_states = hidden_states * self.scaling
        # Reshape to (batch, heads, seq, head_dim). Values use values_head_dim,
        # which presumably differs from attention_head_dim in the cross-attention
        # configuration — TODO confirm against DabDetrAttention.__init__.
        query_states = query_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, -1, self.attention_heads, self.values_head_dim).transpose(1, 2)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))

        if attention_mask is not None:
            # Additive mask: masked positions carry large negative values.
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then cast back to the
        # compute dtype of the queries.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_probs, value_states)

        # Defensive shape check before the reshape below.
        if attn_output.size() != (batch_size, self.attention_heads, q_len, self.values_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.attention_heads, q_len, self.values_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        # Merge heads back: (batch, q_len, output_dim), then final projection.
        attn_output = attn_output.reshape(batch_size, q_len, self.output_dim)
        attn_output = self.output_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        # NOTE(review): returns a 2-tuple; the original annotation claimed a
        # 3-tuple — annotation corrected above.
        return attn_output, attn_weights
554
+
555
+
556
class DabDetrDecoderLayerSelfAttention(nn.Module):
    """Self-attention sub-block of a DAB-DETR decoder layer.

    Content and positional embeddings are projected separately for queries and
    keys, summed, and fed through multi-head self-attention over the object
    queries; the result is wrapped with dropout, a residual connection and
    layer normalization.
    """

    def __init__(self, config: DabDetrConfig):
        super().__init__()
        self.dropout = config.dropout
        hidden_size = config.hidden_size
        # Separate projections for the content and positional parts of the
        # queries/keys; values carry content only.
        self.self_attn_query_content_proj = nn.Linear(hidden_size, hidden_size)
        self.self_attn_query_pos_proj = nn.Linear(hidden_size, hidden_size)
        self.self_attn_key_content_proj = nn.Linear(hidden_size, hidden_size)
        self.self_attn_key_pos_proj = nn.Linear(hidden_size, hidden_size)
        self.self_attn_value_proj = nn.Linear(hidden_size, hidden_size)
        self.self_attn = DabDetrAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query_position_embeddings: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
    ):
        residual = hidden_states

        # Query/key = content projection + positional projection.
        queries = self.self_attn_query_content_proj(hidden_states) + self.self_attn_query_pos_proj(
            query_position_embeddings
        )
        keys = self.self_attn_key_content_proj(hidden_states) + self.self_attn_key_pos_proj(
            query_position_embeddings
        )
        values = self.self_attn_value_proj(hidden_states)

        # Attention weights are always computed here (output_attentions=True),
        # matching the original implementation; the caller decides whether to
        # surface them.
        hidden_states, attn_weights = self.self_attn(
            hidden_states=queries,
            attention_mask=attention_mask,
            key_states=keys,
            value_states=values,
            output_attentions=True,
        )

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(residual + hidden_states)

        return hidden_states, attn_weights
598
+
599
+
600
class DabDetrDecoderLayerCrossAttention(nn.Module):
    """Cross-attention sub-block of a DAB-DETR decoder layer.

    Attends from the object queries to the encoder output. Queries and keys are
    built by concatenating (per head) a content part with a positional part
    (sine embeddings for queries, `object_queries` for keys), doubling the
    channel width fed into the attention module.
    """

    def __init__(self, config: DabDetrConfig, is_first: bool = False):
        super().__init__()
        hidden_size = config.hidden_size
        # Content/positional projections for queries, keys and values.
        self.cross_attn_query_content_proj = nn.Linear(hidden_size, hidden_size)
        self.cross_attn_query_pos_proj = nn.Linear(hidden_size, hidden_size)
        self.cross_attn_key_content_proj = nn.Linear(hidden_size, hidden_size)
        self.cross_attn_key_pos_proj = nn.Linear(hidden_size, hidden_size)
        self.cross_attn_value_proj = nn.Linear(hidden_size, hidden_size)
        # Projects the sine positional embedding of the box center.
        self.cross_attn_query_pos_sine_proj = nn.Linear(hidden_size, hidden_size)
        self.decoder_attention_heads = config.decoder_attention_heads
        self.cross_attn_layer_norm = nn.LayerNorm(hidden_size)
        # is_cross=True: the attention operates on doubled query/key width.
        self.cross_attn = DabDetrAttention(config, is_cross=True)

        self.keep_query_pos = config.keep_query_pos

        # The query positional projection is only needed in the first layer
        # (or in every layer when keep_query_pos is set); drop it otherwise.
        if not self.keep_query_pos and not is_first:
            self.cross_attn_query_pos_proj = None

        self.is_first = is_first
        self.dropout = config.dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        query_sine_embed: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
    ):
        query_content = self.cross_attn_query_content_proj(hidden_states)
        key_content = self.cross_attn_key_content_proj(encoder_hidden_states)
        value = self.cross_attn_value_proj(encoder_hidden_states)

        batch_size, num_queries, n_model = query_content.shape
        # Keys come from the flattened feature map: length == height * width.
        _, height_width, _ = key_content.shape

        key_pos = self.cross_attn_key_pos_proj(object_queries)

        # For the first decoder layer, we add the positional embedding predicted from
        # the object query (the positional embedding) into the original query (key) in DETR.
        if self.is_first or self.keep_query_pos:
            query_pos = self.cross_attn_query_pos_proj(query_position_embeddings)
            query = query_content + query_pos
            key = key_content + key_pos
        else:
            query = query_content
            key = key_content

        # Split into heads, then concatenate the positional part per head so
        # each head sees (content, position) side by side: width becomes 2*n_model.
        query = query.view(
            batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads
        )
        query_sine_embed = self.cross_attn_query_pos_sine_proj(query_sine_embed)
        query_sine_embed = query_sine_embed.view(
            batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads
        )
        query = torch.cat([query, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
        key = key.view(batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads)
        key_pos = key_pos.view(
            batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads
        )
        key = torch.cat([key, key_pos], dim=3).view(batch_size, height_width, n_model * 2)

        # Cross-Attention Block
        # NOTE(review): encoder_hidden_states was already dereferenced above, so
        # this guard is effectively always true when we reach this point.
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.cross_attn(
                hidden_states=query,
                attention_mask=encoder_attention_mask,
                key_states=key,
                value_states=value,
                output_attentions=output_attentions,
            )

            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.cross_attn_layer_norm(hidden_states)

        return hidden_states, cross_attn_weights
683
+
684
+
685
class DabDetrDecoderLayerFFN(nn.Module):
    """Feed-forward sub-block of a DAB-DETR decoder layer.

    Two-layer MLP (fc1 -> activation -> fc2) wrapped with dropout, a residual
    connection and a final layer normalization.
    """

    def __init__(self, config: DabDetrConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.final_layer_norm = nn.LayerNorm(hidden_size)
        self.fc1 = nn.Linear(hidden_size, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, hidden_size)
        self.activation_fn = ACT2FN[config.activation_function]
        self.dropout = config.dropout
        self.activation_dropout = config.activation_dropout
        self.keep_query_pos = config.keep_query_pos

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        # Expand -> activate -> (activation) dropout.
        intermediate = self.activation_fn(self.fc1(hidden_states))
        intermediate = nn.functional.dropout(intermediate, p=self.activation_dropout, training=self.training)
        # Project back -> dropout -> residual add -> layer norm.
        projected = self.fc2(intermediate)
        projected = nn.functional.dropout(projected, p=self.dropout, training=self.training)
        return self.final_layer_norm(residual + projected)
707
+
708
+
709
# Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig
class DabDetrEncoderLayer(nn.Module):
    """Single DAB-DETR encoder layer.

    Self-attention followed by a feed-forward network, each wrapped with
    dropout, a residual connection and layer normalization.
    """

    def __init__(self, config: DabDetrConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = DetrAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.fc1 = nn.Linear(self.hidden_size, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.hidden_size)
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        object_queries: torch.Tensor,
        output_attentions: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # --- self-attention block ---
        residual = hidden_states
        attn_output, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            object_queries=object_queries,
            output_attentions=output_attentions,
        )
        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(residual + attn_output)

        # --- feed-forward block ---
        residual = hidden_states
        ffn_output = self.activation_fn(self.fc1(hidden_states))
        ffn_output = nn.functional.dropout(ffn_output, p=self.dropout, training=self.training)
        ffn_output = self.fc2(ffn_output)
        ffn_output = nn.functional.dropout(ffn_output, p=self.dropout, training=self.training)
        hidden_states = self.final_layer_norm(residual + ffn_output)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
769
+
770
+
771
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr
class DabDetrDecoderLayer(nn.Module):
    """One DAB-DETR decoder layer.

    Composed of three sub-blocks applied in order: self-attention over the
    object queries, cross-attention against the encoder output, and a
    feed-forward network.
    """

    def __init__(self, config: DabDetrConfig, is_first: bool = False):
        super().__init__()
        self.self_attn = DabDetrDecoderLayerSelfAttention(config)
        self.cross_attn = DabDetrDecoderLayerCrossAttention(config, is_first)
        self.mlp = DabDetrDecoderLayerFFN(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        query_sine_embed: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            object_queries (`torch.FloatTensor`, *optional*):
                object_queries that are added to the queries and keys
                in the cross-attention layer.
            query_position_embeddings (`torch.FloatTensor`, *optional*):
                object_queries that are added to the queries and keys
                in the self-attention layer.
            query_sine_embed (`torch.FloatTensor`, *optional*):
                sine positional embedding of the box center, consumed by the
                cross-attention sub-block.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # 1) self-attention over the object queries
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            query_position_embeddings=query_position_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        # 2) cross-attention into the encoder output
        hidden_states, cross_attn_weights = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            query_position_embeddings=query_position_embeddings,
            object_queries=object_queries,
            encoder_attention_mask=encoder_attention_mask,
            query_sine_embed=query_sine_embed,
            output_attentions=output_attentions,
        )

        # 3) feed-forward network
        hidden_states = self.mlp(hidden_states=hidden_states)

        if output_attentions:
            return (hidden_states, self_attn_weights, cross_attn_weights)
        return (hidden_states,)
837
+
838
+
839
# Modified from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->DabDetrMLP
class DabDetrMLP(nn.Module):
    """
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    ReLU is applied between layers but not after the last one.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py

    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer widths: input_dim -> hidden_dim (x num_layers-1) -> output_dim.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(in_features, out_features) for in_features, out_features in zip(dims[:-1], dims[1:])
        )

    def forward(self, input_tensor):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            input_tensor = layer(input_tensor)
            if idx < last:
                input_tensor = nn.functional.relu(input_tensor)
        return input_tensor
859
+
860
+
861
# Modified from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->DabDetr
class DabDetrPreTrainedModel(PreTrainedModel):
    """Base class wiring DAB-DETR models into the Transformers weight-init and
    loading machinery."""

    config_class = DabDetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    _no_split_modules = [r"DabDetrConvEncoder", r"DabDetrEncoderLayer", r"DabDetrDecoderLayer"]

    def _init_weights(self, module):
        """Initialize one submodule's weights (called per-module by `post_init`)."""
        std = self.config.init_std
        xavier_std = self.config.init_xavier_std

        # Attention-map heads get xavier-uniform weights and zero biases.
        if isinstance(module, DabDetrMHAttentionMap):
            nn.init.zeros_(module.k_linear.bias)
            nn.init.zeros_(module.q_linear.bias)
            nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
            nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
        # Deliberately `if`, not `elif`: this branch is independent of the one above.
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, DabDetrForObjectDetection):
            # Zero the last bbox-prediction layer so initial box deltas are neutral.
            nn.init.constant_(module.bbox_predictor.layers[-1].weight.data, 0)
            nn.init.constant_(module.bbox_predictor.layers[-1].bias.data, 0)

            # init prior_prob setting for focal loss
            prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            module.class_embed.bias.data.fill_(bias_value)
895
+
896
+
897
+ DAB_DETR_START_DOCSTRING = r"""
898
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
899
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
900
+ etc.)
901
+
902
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
903
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
904
+ and behavior.
905
+
906
+ Parameters:
907
+ config ([`DabDetrConfig`]):
908
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
909
+ load the weights associated with the model, only the configuration. Check out the
910
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
911
+ """
912
+
913
+ DAB_DETR_INPUTS_DOCSTRING = r"""
914
+ Args:
915
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
916
+ Pixel values. Padding will be ignored by default should you provide it.
917
+
918
+ Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`]
919
+ for details.
920
+
921
+ pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
922
+ Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
923
+
924
+ - 1 for pixels that are real (i.e. **not masked**),
925
+ - 0 for pixels that are padding (i.e. **masked**).
926
+
927
+ [What are attention masks?](../glossary#attention-mask)
928
+
929
+ decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
930
+ Not used by default. Can be used to mask object queries.
931
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
932
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
933
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
934
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
935
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
936
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
937
+ can choose to directly pass a flattened representation of an image.
938
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
939
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
940
+ embedded representation.
941
+ output_attentions (`bool`, *optional*):
942
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
943
+ tensors for more detail.
944
+ output_hidden_states (`bool`, *optional*):
945
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
946
+ more detail.
947
+ return_dict (`bool`, *optional*):
948
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
949
+ """
950
+
951
+
952
# Modified from transformers.models.detr.modeling_detr.DetrEncoder with Detr->DabDetr,DETR->ConditionalDETR
class DabDetrEncoder(DabDetrPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`DabDetrEncoderLayer`].

    The encoder updates the flattened feature map through multiple self-attention layers.

    Small tweak for DAB-DETR:

    - object_queries are added to the forward pass.

    Args:
        config: DabDetrConfig
    """

    def __init__(self, config: DabDetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        # MLP producing a per-position scale applied to the object queries
        # before each encoder layer.
        self.query_scale = DabDetrMLP(config.hidden_size, config.hidden_size, config.hidden_size, 2)
        self.layers = nn.ModuleList([DabDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
        # Optional final layer norm (only when normalize_before is configured).
        self.norm = nn.LayerNorm(config.hidden_size) if config.normalize_before else None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds,
        attention_mask,
        object_queries,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.

            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).

                [What are attention masks?](../glossary#attention-mask)

            object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`):
                Object queries that are added to the queries in each self-attention layer.

            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Resolve output flags from config when not explicitly passed.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeds

        # expand attention_mask
        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # pos scaler
            pos_scales = self.query_scale(hidden_states)
            # we add object_queries * pos_scaler as extra input to the encoder_layer
            scaled_object_queries = object_queries * pos_scales

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: recompute this layer on backward.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    scaled_object_queries,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    object_queries=scaled_object_queries,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if self.norm:
            hidden_states = self.norm(hidden_states)

        # Append the final hidden state so hidden_states has num_layers + 1 entries.
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
1070
+
1071
+
1072
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoder with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR
class DabDetrDecoder(DabDetrPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DabDetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some small tweaks for DAB-DETR:

    - object_queries and query_position_embeddings are added to the forward pass.
    - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

    Args:
        config: DabDetrConfig
    """

    def __init__(self, config: DabDetrConfig):
        super().__init__(config)
        self.config = config
        self.dropout = config.dropout
        self.num_layers = config.decoder_layers
        self.gradient_checkpointing = False

        # Only the first layer uses the extra query-position projection (is_first=True).
        self.layers = nn.ModuleList(
            [DabDetrDecoderLayer(config, is_first=(layer_id == 0)) for layer_id in range(config.decoder_layers)]
        )
        # in DAB-DETR, the decoder uses layernorm after the last decoder layer output
        self.hidden_size = config.hidden_size
        self.layernorm = nn.LayerNorm(self.hidden_size)

        # Default cond-elewise
        self.query_scale = DabDetrMLP(self.hidden_size, self.hidden_size, self.hidden_size, 2)

        # Maps the (query_dim/2 * hidden) sine embedding of the anchor box to a
        # query positional embedding.
        self.ref_point_head = DabDetrMLP(
            config.query_dim // 2 * self.hidden_size, self.hidden_size, self.hidden_size, 2
        )

        # Set externally (e.g. by DabDetrForObjectDetection) to enable iterative
        # box refinement inside the decoder loop.
        self.bbox_embed = None

        # Default decoder_modulate_hw_attn is True
        self.ref_anchor_head = DabDetrMLP(self.hidden_size, self.hidden_size, 2, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds,
        encoder_hidden_states,
        memory_key_padding_mask,
        object_queries,
        query_position_embeddings,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`):
                The query embeddings that are passed into the decoder.
            encoder_hidden_states (`torch.FloatTensor` of shape `(encoder_sequence_length, batch_size, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            memory_key_padding_mask (`torch.Tensor.bool` of shape `(batch_size, sequence_length)`):
                The memory_key_padding_mask indicates which positions in the memory (encoder outputs) should be ignored during the attention computation,
                ensuring padding tokens do not influence the attention mechanism.
            object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each cross-attention layer.
            query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, number_of_anchor_points)`):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            input_shape = inputs_embeds.size()[:-1]

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Per-layer layernormed hidden states (for auxiliary losses) and the
        # evolving anchor boxes.
        intermediate = []
        # query_position_embeddings carry inverse-sigmoid anchor coordinates;
        # sigmoid maps them into [0, 1] box space.
        reference_points = query_position_embeddings.sigmoid()
        ref_points = [reference_points]

        # expand encoder attention mask
        if encoder_hidden_states is not None and memory_key_padding_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            memory_key_padding_mask = _prepare_4d_attention_mask(
                memory_key_padding_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        for layer_id, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Anchor box (x, y, w, h) for the current iteration.
            obj_center = reference_points[..., : self.config.query_dim]
            query_sine_embed = gen_sine_position_embeddings(obj_center, self.hidden_size)
            query_pos = self.ref_point_head(query_sine_embed)

            # For the first decoder layer, we do not apply transformation over p_s
            pos_transformation = 1 if layer_id == 0 else self.query_scale(hidden_states)

            # apply transformation
            query_sine_embed = query_sine_embed[..., : self.hidden_size] * pos_transformation

            # modulated Height Width attentions
            reference_anchor_size = self.ref_anchor_head(hidden_states).sigmoid()  # nq, bs, 2
            query_sine_embed[..., self.hidden_size // 2 :] *= (
                reference_anchor_size[..., 0] / obj_center[..., 2]
            ).unsqueeze(-1)
            query_sine_embed[..., : self.hidden_size // 2] *= (
                reference_anchor_size[..., 1] / obj_center[..., 3]
            ).unsqueeze(-1)

            if self.gradient_checkpointing and self.training:
                # Positional args must match DabDetrDecoderLayer.forward's
                # parameter order (attention_mask is None here).
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    None,
                    object_queries,
                    query_pos,
                    query_sine_embed,
                    encoder_hidden_states,
                    memory_key_padding_mask,
                    output_attentions,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=None,
                    object_queries=object_queries,
                    query_position_embeddings=query_pos,
                    query_sine_embed=query_sine_embed,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=memory_key_padding_mask,
                    output_attentions=output_attentions,
                )

            # iter update
            hidden_states = layer_outputs[0]

            if self.bbox_embed is not None:
                # Iterative refinement: predict a delta in inverse-sigmoid space,
                # add it to the current anchors, and map back through sigmoid.
                new_reference_points = self.bbox_embed(hidden_states)

                new_reference_points[..., : self.config.query_dim] += inverse_sigmoid(reference_points)
                new_reference_points = new_reference_points[..., : self.config.query_dim].sigmoid()
                if layer_id != self.num_layers - 1:
                    ref_points.append(new_reference_points)
                # detach: gradients do not flow through the anchor update chain.
                reference_points = new_reference_points.detach()

            intermediate.append(self.layernorm(hidden_states))

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Layer normalization on hidden states
        hidden_states = self.layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output_intermediate_hidden_states = torch.stack(intermediate)
        output_reference_points = torch.stack(ref_points)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                    output_intermediate_hidden_states,
                    output_reference_points,
                ]
                if v is not None
            )
        return DabDetrDecoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
            intermediate_hidden_states=output_intermediate_hidden_states,
            reference_points=output_reference_points,
        )
1274
+
1275
+
1276
@add_start_docstrings(
    """
    The bare DAB-DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
    hidden-states, intermediate hidden states, reference points, output coordinates without any specific head on top.
    """,
    DAB_DETR_START_DOCSTRING,
)
class DabDetrModel(DabDetrPreTrainedModel):
    """Backbone + encoder-decoder Transformer for DAB-DETR, without task heads."""

    def __init__(self, config: DabDetrConfig):
        super().__init__(config)

        self.auxiliary_loss = config.auxiliary_loss

        # Create backbone + positional encoding
        self.backbone = DabDetrConvEncoder(config)
        object_queries = DabDetrSinePositionEmbedding(config)

        # Learned anchor boxes (one per query) in inverse-sigmoid (logit) space.
        self.query_refpoint_embeddings = nn.Embedding(config.num_queries, config.query_dim)
        self.random_refpoints_xy = config.random_refpoints_xy
        if self.random_refpoints_xy:
            # Draw random x/y centers once and freeze them (only w/h stay learnable).
            self.query_refpoint_embeddings.weight.data[:, :2].uniform_(0, 1)
            self.query_refpoint_embeddings.weight.data[:, :2] = inverse_sigmoid(
                self.query_refpoint_embeddings.weight.data[:, :2]
            )
            self.query_refpoint_embeddings.weight.data[:, :2].requires_grad = False

        # 1x1 conv projecting backbone channels down to the Transformer hidden size.
        self.input_projection = nn.Conv2d(
            self.backbone.intermediate_channel_sizes[-1], config.hidden_size, kernel_size=1
        )
        self.backbone = DabDetrConvModel(self.backbone, object_queries)

        self.encoder = DabDetrEncoder(config)
        self.decoder = DabDetrDecoder(config)

        # Decoder related variables.
        self.hidden_size = config.hidden_size
        self.num_queries = config.num_queries

        self.num_patterns = config.num_patterns
        if not isinstance(self.num_patterns, int):
            logger.warning("num_patterns should be int but {}".format(type(self.num_patterns)))
            self.num_patterns = 0
        if self.num_patterns > 0:
            self.patterns = nn.Embedding(self.num_patterns, self.hidden_size)

        # NOTE(review): duplicates `self.auxiliary_loss` set above — kept for
        # backward compatibility with code that may reference either name.
        self.aux_loss = config.auxiliary_loss

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        """Return the Transformer encoder module."""
        return self.encoder

    def get_decoder(self):
        """Return the Transformer decoder module."""
        return self.decoder

    def freeze_backbone(self):
        """Disable gradient computation for all backbone parameters."""
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(False)

    def unfreeze_backbone(self):
        """Re-enable gradient computation for all backbone parameters."""
        for name, param in self.backbone.conv_encoder.model.named_parameters():
            param.requires_grad_(True)

    @add_start_docstrings_to_model_forward(DAB_DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DabDetrModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DabDetrModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab_detr-base")
        >>> model = AutoModel.from_pretrained("IDEA-Research/dab_detr-base")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)

        >>> # the last hidden states are the final query embeddings of the Transformer decoder
        >>> # these are of shape (batch_size, num_queries, hidden_size)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 300, 256]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, _, height, width = pixel_values.shape
        device = pixel_values.device

        if pixel_mask is None:
            pixel_mask = torch.ones(((batch_size, height, width)), device=device)

        # First, sent pixel_values + pixel_mask through Backbone to obtain the features
        # pixel_values should be of shape (batch_size, num_channels, height, width)
        # pixel_mask should be of shape (batch_size, height, width)
        features, object_queries_list = self.backbone(pixel_values, pixel_mask)

        # get final feature map and downsampled mask
        feature_map, mask = features[-1]

        if mask is None:
            raise ValueError("Backbone does not return downsampled pixel mask")

        flattened_mask = mask.flatten(1)

        # Second, apply 1x1 convolution to reduce the channel dimension to hidden_size (256 by default)
        projected_feature_map = self.input_projection(feature_map)

        # Third, flatten the feature map + object_queries of shape NxCxHxW to HWxNxC, and permute it to NxHWxC
        # In other words, turn their shape into ( sequence_length, batch_size, hidden_size)
        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
        reference_position_embeddings = self.query_refpoint_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)

        # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
        # flattened_features is a Tensor of shape (heigth*width, batch_size, hidden_size)
        # flattened_mask is a Tensor of shape (batch_size, heigth*width)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs_embeds=flattened_features,
                attention_mask=flattened_mask,
                object_queries=object_queries,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
        num_queries = reference_position_embeddings.shape[1]
        if self.num_patterns == 0:
            queries = torch.zeros(batch_size, num_queries, self.hidden_size, device=device)
        else:
            queries = (
                self.patterns.weight[:, None, None, :]
                .repeat(1, self.num_queries, batch_size, 1)
                .flatten(0, 1)
                .permute(1, 0, 2)
            )  # bs, n_q*n_pat, hidden_size
            reference_position_embeddings = reference_position_embeddings.repeat(
                1, self.num_patterns, 1
            )  # bs, n_q*n_pat, hidden_size

        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            inputs_embeds=queries,
            query_position_embeddings=reference_position_embeddings,
            object_queries=object_queries,
            encoder_hidden_states=encoder_outputs[0],
            memory_key_padding_mask=flattened_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # last_hidden_state
            output = (decoder_outputs[0],)
            reference_points = decoder_outputs[-1]
            intermediate_hidden_states = decoder_outputs[-2]

            # it has to follow the order of DABDETRModelOutput that is based on ModelOutput
            # If we only use one of the variables then the indexing will change.
            # E.g: if we return everything then 'decoder_attentions' is decoder_outputs[2], if we only use output_attentions then its decoder_outputs[1]
            if output_hidden_states and output_attentions:
                output += (
                    decoder_outputs[1],
                    decoder_outputs[2],
                    decoder_outputs[3],
                    encoder_outputs[0],
                    encoder_outputs[1],
                    encoder_outputs[2],
                )
            elif output_hidden_states:
                # decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states
                output += (
                    decoder_outputs[1],
                    encoder_outputs[0],
                    encoder_outputs[1],
                )
            elif output_attentions:
                # decoder_self_attention, decoder_cross_attention, encoder_attentions
                output += (
                    decoder_outputs[1],
                    decoder_outputs[2],
                    encoder_outputs[1],
                )

            output += (intermediate_hidden_states, reference_points)

            return output

        reference_points = decoder_outputs.reference_points
        intermediate_hidden_states = decoder_outputs.intermediate_hidden_states

        return DabDetrModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states if output_hidden_states else None,
            decoder_attentions=decoder_outputs.attentions if output_attentions else None,
            cross_attentions=decoder_outputs.cross_attentions if output_attentions else None,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state if output_hidden_states else None,
            encoder_hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
            encoder_attentions=encoder_outputs.attentions if output_attentions else None,
            intermediate_hidden_states=intermediate_hidden_states,
            reference_points=reference_points,
        )
1515
+
1516
+
1517
# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->DabDetr
class DabDetrMHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        # Standard scaled-dot-product normalization: 1 / sqrt(head_dim).
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        """Compute per-head attention weights between queries `q` (batch, num_queries, query_dim)
        and a 2D feature map `k` (batch, query_dim, height, width); returns
        (batch, num_queries, num_heads, height, width) softmax weights."""
        q = self.q_linear(q)
        # Apply k_linear as a 1x1 convolution so the spatial layout of k is preserved.
        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

        if mask is not None:
            # Masked positions get the dtype minimum so softmax assigns them ~0 probability.
            weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
        # Softmax jointly over heads and spatial positions (flattened), then restore the shape.
        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights = self.dropout(weights)
        return weights
1544
+
1545
+
1546
@add_start_docstrings(
    """
    DAB_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
    top, for tasks such as COCO detection.
    """,
    DAB_DETR_START_DOCSTRING,
)
class DabDetrForObjectDetection(DabDetrPreTrainedModel):
    # When using clones, all layers > 0 will be clones, but layer 0 *is* required
    _tied_weights_keys = [
        r"bbox_predictor\.layers\.\d+\.(weight|bias)",
        r"model\.decoder\.bbox_embed\.layers\.\d+\.(weight|bias)",
    ]

    def __init__(self, config: DabDetrConfig):
        super().__init__(config)

        self.config = config
        self.auxiliary_loss = config.auxiliary_loss
        self.query_dim = config.query_dim
        # DAB-DETR encoder-decoder model
        self.model = DabDetrModel(config)

        # 3-layer MLP predicting 4 box coordinates; shared with the decoder (tied weights).
        _bbox_embed = DabDetrMLP(config.hidden_size, config.hidden_size, 4, 3)
        # Object detection heads
        self.class_embed = nn.Linear(config.hidden_size, config.num_labels)

        # Default bbox_embed_diff_each_layer is False
        self.bbox_predictor = _bbox_embed

        # Default iter_update is True
        self.model.decoder.bbox_embed = self.bbox_predictor

        # Initialize weights and apply final processing
        self.post_init()

    # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/dab_detr.py
    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    @add_start_docstrings_to_model_forward(DAB_DETR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DabDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[dict]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DabDetrObjectDetectionOutput]:
        r"""
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
        >>> model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> target_sizes = torch.tensor([(image.height, image.width)])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.833 at location [38.31, 72.1, 177.63, 118.45]
        Detected cat with confidence 0.831 at location [9.2, 51.38, 321.13, 469.0]
        Detected cat with confidence 0.804 at location [340.3, 16.85, 642.93, 370.95]
        Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
        Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # First, sent images through DAB_DETR base model to obtain encoder + decoder outputs
        model_outputs = self.model(
            pixel_values,
            pixel_mask=pixel_mask,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        reference_points = model_outputs.reference_points if return_dict else model_outputs[-1]
        intermediate_hidden_states = model_outputs.intermediate_hidden_states if return_dict else model_outputs[-2]

        # class logits + predicted bounding boxes
        logits = self.class_embed(intermediate_hidden_states[-1])

        # Box refinement: predicted deltas are added to the reference anchors in logit space.
        reference_before_sigmoid = inverse_sigmoid(reference_points)
        bbox_with_refinement = self.bbox_predictor(intermediate_hidden_states)
        bbox_with_refinement[..., : self.query_dim] += reference_before_sigmoid
        outputs_coord = bbox_with_refinement.sigmoid()

        pred_boxes = outputs_coord[-1]

        loss, loss_dict, auxiliary_outputs = None, None, None
        if labels is not None:
            outputs_class = None
            if self.config.auxiliary_loss:
                outputs_class = self.class_embed(intermediate_hidden_states)
            loss, loss_dict, auxiliary_outputs = self.loss_function(
                logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
            )

        if not return_dict:
            if auxiliary_outputs is not None:
                output = (logits, pred_boxes) + auxiliary_outputs + model_outputs
            else:
                output = (logits, pred_boxes) + model_outputs
            # Since DabDetrObjectDetectionOutput doesn't have reference points + intermedieate_hidden_states we cut down.
            return ((loss, loss_dict) + output) if loss is not None else output[:-2]

        return DabDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            pred_boxes=pred_boxes,
            auxiliary_outputs=auxiliary_outputs,
            last_hidden_state=model_outputs.last_hidden_state,
            decoder_hidden_states=model_outputs.decoder_hidden_states if output_hidden_states else None,
            decoder_attentions=model_outputs.decoder_attentions if output_attentions else None,
            cross_attentions=model_outputs.cross_attentions if output_attentions else None,
            encoder_last_hidden_state=model_outputs.encoder_last_hidden_state if output_hidden_states else None,
            encoder_hidden_states=model_outputs.encoder_hidden_states if output_hidden_states else None,
            encoder_attentions=model_outputs.encoder_attentions if output_attentions else None,
        )
1710
+
1711
+
1712
# Public API of this module.
__all__ = [
    "DabDetrForObjectDetection",
    "DabDetrModel",
    "DabDetrPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/dac/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static type checkers see the real symbols directly.
    from .configuration_dac import *
    from .feature_extraction_dac import *
    from .modeling_dac import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy so heavy submodules
    # are only imported when their attributes are first accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/dac/configuration_dac.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Dac model configuration"""
16
+
17
+ import math
18
+
19
+ import numpy as np
20
+
21
+ from ...configuration_utils import PretrainedConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
class DacConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`DacModel`]. It is used to instantiate a
    Dac model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [descript/dac_16khz](https://huggingface.co/descript/dac_16khz) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_hidden_size (`int`, *optional*, defaults to 64):
            Intermediate representation dimension for the encoder.
        downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 8, 8]`):
            Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
        decoder_hidden_size (`int`, *optional*, defaults to 1536):
            Intermediate representation dimension for the decoder.
        n_codebooks (`int`, *optional*, defaults to 9):
            Number of codebooks in the VQVAE.
        codebook_size (`int`, *optional*, defaults to 1024):
            Number of discrete codes in each codebook.
        codebook_dim (`int`, *optional*, defaults to 8):
            Dimension of the codebook vectors. If not defined, uses `encoder_hidden_size`.
        quantizer_dropout (`bool`, *optional*, defaults to 0):
            Whether to apply dropout to the quantizer.
        commitment_loss_weight (float, *optional*, defaults to 0.25):
            Weight of the commitment loss term in the VQVAE loss function.
        codebook_loss_weight (float, *optional*, defaults to 1.0):
            Weight of the codebook loss term in the VQVAE loss function.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
    Example:

    ```python
    >>> from transformers import DacModel, DacConfig

    >>> # Initializing a "descript/dac_16khz" style configuration
    >>> configuration = DacConfig()

    >>> # Initializing a model (with random weights) from the "descript/dac_16khz" style configuration
    >>> model = DacModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dac"

    def __init__(
        self,
        encoder_hidden_size=64,
        downsampling_ratios=[2, 4, 8, 8],
        decoder_hidden_size=1536,
        n_codebooks=9,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0,
        commitment_loss_weight=0.25,
        codebook_loss_weight=1.0,
        sampling_rate=16000,
        **kwargs,
    ):
        self.encoder_hidden_size = encoder_hidden_size
        self.downsampling_ratios = downsampling_ratios
        self.decoder_hidden_size = decoder_hidden_size
        # The decoder upsamples in the reverse order of the encoder's downsampling.
        self.upsampling_ratios = downsampling_ratios[::-1]
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        # NOTE(review): the docstring says codebook_dim falls back to
        # encoder_hidden_size when unset, but no such fallback exists here — confirm.
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sampling_rate = sampling_rate

        # Encoder channel width doubles at every downsampling stage.
        self.hidden_size = encoder_hidden_size * (2 ** len(downsampling_ratios))

        # Total stride of the encoder (samples per latent frame).
        self.hop_length = int(np.prod(downsampling_ratios))
        self.commitment_loss_weight = commitment_loss_weight
        self.codebook_loss_weight = codebook_loss_weight

        super().__init__(**kwargs)

    @property
    def frame_rate(self) -> int:
        """Number of latent frames produced per second of audio (rounded up)."""
        hop_length = np.prod(self.upsampling_ratios)
        return math.ceil(self.sampling_rate / hop_length)


__all__ = ["DacConfig"]
docs/transformers/build/lib/transformers/models/dac/convert_dac_checkpoint.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ import fnmatch
17
+ import re
18
+
19
+ import torch
20
+
21
+ from transformers import (
22
+ DacConfig,
23
+ DacFeatureExtractor,
24
+ DacModel,
25
+ logging,
26
+ )
27
+
28
+
29
+ # checkpoints downloaded using:
30
+ # pip install descript-audio-codec
31
+ # python3 -m dac download # downloads the default 44kHz variant
32
+ # python3 -m dac download --model_type 44khz # downloads the 44kHz variant
33
+ # python3 -m dac download --model_type 24khz # downloads the 24kHz variant
34
+ # python3 -m dac download --model_type 16khz # downloads the 16kHz variant
35
+ # More informations: https://github.com/descriptinc/descript-audio-codec/tree/main
36
+
37
+ logging.set_verbosity_info()
38
+ logger = logging.get_logger("transformers.models.dac")
39
+
40
+
41
def match_pattern(string, pattern):
    """Return True if `string` matches the glob `pattern` AND both contain the
    same number of dot-separated parts starting with "block".

    The extra block-count check prevents `*` wildcards from greedily matching
    across nested `block.N` levels of the original DAC state-dict keys.
    """
    # Split the pattern into parts
    pattern_parts = pattern.split(".")
    string_parts = string.split(".")

    pattern_block_count = string_block_count = 0

    for part in pattern_parts:
        if part.startswith("block"):
            pattern_block_count += 1

    for part in string_parts:
        if part.startswith("block"):
            string_block_count += 1

    return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
57
+
58
+
59
TOP_LEVEL_KEYS = []
IGNORE_KEYS = []


# Regex-keyed mappings from original DAC state-dict names to HF module paths.
# Single-element values are direct renames; three-element values describe
# nested block/res_unit targets resolved by the loading logic.
MAPPING_ENCODER = {
    "encoder.block.0": ["encoder.conv1"],
    "encoder.block.5": ["encoder.snake1"],
    "encoder.block.6": ["encoder.conv2"],
    "encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
    "encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
    "encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
    "encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
    "encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
    "encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
}

MAPPING_QUANTIZER = {
    "quantizer.quantizers.*": ["quantizer.quantizers.*"],
}

MAPPING_DECODER = {
    "decoder.model.0": ["decoder.conv1"],
    "decoder.model.5": ["decoder.snake1"],
    "decoder.model.6": ["decoder.conv2"],
    "decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
    "decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
    "decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
    "decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
    "decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
    "decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
}


MAPPING = {
    **MAPPING_ENCODER,
    **MAPPING_QUANTIZER,
    **MAPPING_DECODER,
}
97
+
98
+
99
def set_recursively(hf_pointer, key, value, full_name, weight_type):
    """Walk the dotted `key` path on `hf_pointer` and copy `value` into the
    attribute selected by `weight_type` ("weight", "weight_g", "weight_v",
    "bias" or "alpha"), after a shape sanity check.

    Raises:
        ValueError: if the destination shape does not match `value.shape`.
    """
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    elif weight_type == "alpha":
        hf_pointer.alpha.data = value
    # Logged even when weight_type matched nothing above — the copy is then a no-op.
    logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
125
+
126
+
127
def should_ignore(name, ignore_keys):
    """Return True if parameter `name` matches any pattern in `ignore_keys`.

    Three pattern forms are supported: a trailing ".*" matches as a prefix,
    an embedded ".*." matches when both halves occur anywhere in `name`, and
    a plain pattern matches as a substring.
    """
    for pattern in ignore_keys:
        if pattern.endswith(".*"):
            matched = name.startswith(pattern[:-1])
        elif ".*." in pattern:
            head, tail = pattern.split(".*.")
            matched = head in name and tail in name
        else:
            matched = pattern in name
        if matched:
            return True
    return False
139
+
140
+
141
def recursively_load_weights(orig_dict, hf_model, model_name):
    """Load every tensor of the original state dict into the HF model.

    Args:
        orig_dict (dict): Original (Descript) checkpoint state dict.
        hf_model: Target HF `DacModel` (with weight norm applied).
        model_name (str): One of "dac_16khz", "dac_24khz", "dac_44khz".

    Raises:
        ValueError: If `model_name` is not a supported DAC variant.
    """
    unused_weights = []

    if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
        raise ValueError(f"Unsupported model: {model_name}")

    for name, value in orig_dict.items():
        is_used = False
        for key, mapped_key in MAPPING.items():
            regex = re.compile(key)
            if regex.search(name):
                if len(mapped_key) == 1:
                    if mapped_key[0][0] == "q":
                        # Quantizer weights keep their original dotted path,
                        # minus the trailing parameter name.
                        mapped_key = ".".join(name.split(".")[:-1])
                    else:
                        mapped_key = mapped_key[0]
                elif len(mapped_key) == 3:
                    # Template [block, res_unit, param]: rebuild the HF path
                    # from the integers embedded in the original name.
                    integers = re.findall(r"\b\d+\b", name)
                    if mapped_key[0][0] == "d":
                        # Decoder: both indices shift down by one.
                        mapped_key = "{}.{}.{}{}.{}".format(
                            mapped_key[0],
                            str(int(integers[0]) - 1),
                            mapped_key[1],
                            str(int(integers[1]) - 1),
                            mapped_key[2],
                        )
                    else:
                        # Encoder: the residual-unit index shifts up by one.
                        mapped_key = "{}.{}.{}{}.{}".format(
                            mapped_key[0],
                            str(int(integers[0]) - 1),
                            mapped_key[1],
                            str(int(integers[1]) + 1),
                            mapped_key[2],
                        )
                elif len(mapped_key) == 2:
                    integers = re.findall(r"\b\d+\b", name)
                    mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])

                is_used = True
                # Fixed: previously `weight_type` was left unbound when none of
                # the suffixes matched, raising UnboundLocalError downstream.
                weight_type = None
                if "weight_g" in name:
                    weight_type = "weight_g"
                elif "weight_v" in name:
                    weight_type = "weight_v"
                elif "bias" in name:
                    weight_type = "bias"
                elif "alpha" in name:
                    weight_type = "alpha"
                elif "weight" in name:
                    weight_type = "weight"
                set_recursively(hf_model, mapped_key, value, name, weight_type)

        if not is_used:
            unused_weights.append(name)

    # Fixed: removed a leftover debug `print` that duplicated this warning.
    logger.warning(f"Unused weights: {unused_weights}")
198
+
199
+
200
@torch.no_grad()
def convert_checkpoint(
    model_name,
    checkpoint_path,
    pytorch_dump_folder_path,
    sample_rate=16000,
    repo_id=None,
):
    """Convert an original Descript DAC checkpoint into the HF `DacModel` format.

    Args:
        model_name (str): One of "dac_16khz", "dac_24khz", "dac_44khz".
        checkpoint_path (str): Path to the original checkpoint file.
        pytorch_dump_folder_path (str): Directory to save the converted model.
        sample_rate (int, *optional*, defaults to 16000): Sampling rate written
            to the config and the feature extractor.
        repo_id (str, *optional*): If set, push model and feature extractor to
            this Hub repository.
    """
    # Positional second arg is `map_location`; `weights_only=True` avoids
    # executing arbitrary pickled code from the checkpoint.
    model_dict = torch.load(checkpoint_path, "cpu", weights_only=True)

    config = DacConfig()

    # Architecture hyper-parameters are read from the checkpoint metadata.
    metadata = model_dict["metadata"]["kwargs"]
    config.encoder_hidden_size = metadata["encoder_dim"]
    config.downsampling_ratios = metadata["encoder_rates"]
    config.codebook_size = metadata["codebook_size"]
    config.n_codebooks = metadata["n_codebooks"]
    config.codebook_dim = metadata["codebook_dim"]
    config.decoder_hidden_size = metadata["decoder_dim"]
    config.upsampling_ratios = metadata["decoder_rates"]
    config.quantizer_dropout = float(metadata["quantizer_dropout"])
    config.sampling_rate = sample_rate

    model = DacModel(config)
    feature_extractor = DacFeatureExtractor()
    feature_extractor.sampling_rate = sample_rate

    original_checkpoint = model_dict["state_dict"]

    # The original checkpoint stores weight-normalized parameters
    # (weight_g / weight_v), so apply weight norm before loading and strip it
    # again afterwards.
    model.apply_weight_norm()
    recursively_load_weights(original_checkpoint, model, model_name)
    model.remove_weight_norm()

    model.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        feature_extractor.push_to_hub(repo_id)
        model.push_to_hub(repo_id)
239
+
240
+
241
if __name__ == "__main__":
    # Command-line entry point for the conversion script.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="dac_44khz",
        type=str,
        help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
    )
    parser.add_argument("--checkpoint_path", required=True, type=str, help="Path to original checkpoint")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )
    # Fixed: the sample rate is numeric and must default to the value
    # `convert_checkpoint` expects. Previously `type=str, default=None`
    # propagated a string (or None, overriding the function's 16000 default)
    # into `config.sampling_rate`.
    parser.add_argument("--sample_rate", default=16000, type=int, help="Sample rate used by DacFeatureExtractor")
    args = parser.parse_args()

    convert_checkpoint(
        args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
    )
docs/transformers/build/lib/transformers/models/dac/feature_extraction_dac.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for DAC"""
16
+
17
+ from typing import List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
22
+ from ...feature_extraction_utils import BatchFeature
23
+ from ...utils import PaddingStrategy, TensorType, logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
class DacFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a DAC feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
        padding_value (`float`, *optional*, defaults to 0.0):
            The value that is used for padding.
        hop_length (`int`, *optional*, defaults to 512):
            Overlap length between successive windows.
    """

    model_input_names = ["input_values", "n_quantizers"]

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 16000,
        padding_value: float = 0.0,
        hop_length: int = 512,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        # Inputs are padded to a multiple of this length (see `pad_to_multiple_of`
        # in `__call__`).
        self.hop_length = hop_length

    def __call__(
        self,
        raw_audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
                `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
                (`feature_size = 2`).
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, *optional*, defaults to `False`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
        """
        # Refuse audio recorded at a different rate than the model was trained on.
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if padding and truncation:
            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
        elif padding is None:
            # by default let's pad the inputs
            padding = True

        # A batch is a list/tuple whose first element is itself a sequence.
        is_batched = bool(
            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
        )

        # Normalize everything to float32; batched inputs are transposed from
        # (channels, samples) to (samples, channels) for `self.pad`.
        if is_batched:
            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
        elif not is_batched and not isinstance(raw_audio, np.ndarray):
            raw_audio = np.asarray(raw_audio, dtype=np.float32)
        elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
            raw_audio = raw_audio.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_audio = [np.asarray(raw_audio).T]

        # verify inputs are valid
        for idx, example in enumerate(raw_audio):
            if example.ndim > 2:
                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
            # NOTE(review): the message reports `example.shape[-1]` as the channel
            # count; after the transpose above the last axis is channels for 2-D
            # input, but verify this reads correctly for all input forms.
            if self.feature_size == 1 and example.ndim != 1:
                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
            if self.feature_size == 2:
                raise ValueError("Stereo audio isn't supported for now")

        input_values = BatchFeature({"input_values": raw_audio})

        # normal padding on batch
        padded_inputs = self.pad(
            input_values,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
            return_attention_mask=False,
            pad_to_multiple_of=self.hop_length,
        )

        # After padding, insert a channel axis: (batch, samples) -> (batch, 1, samples).
        if padding:
            padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]

        # Re-transpose each example back to channels-first layout.
        input_values = []
        for example in padded_inputs.pop("input_values"):
            if self.feature_size == 1:
                example = example[..., None]
            input_values.append(example.T)

        padded_inputs["input_values"] = input_values
        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs
171
+
172
+
173
# Public API of this module.
__all__ = ["DacFeatureExtractor"]
docs/transformers/build/lib/transformers/models/dac/modeling_dac.py ADDED
@@ -0,0 +1,724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Transformers DAC model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from ...modeling_utils import PreTrainedModel
27
+ from ...utils import (
28
+ ModelOutput,
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ replace_return_docstrings,
32
+ )
33
+ from .configuration_dac import DacConfig
34
+
35
+
36
+ # General docstring
37
+ _CONFIG_FOR_DOC = "DacConfig"
38
+
39
+
40
@dataclass
class DacOutput(ModelOutput):
    """
    Output type for the full encode/decode pass of the DAC model.

    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
            Reconstructed audio data.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
            Quantized continuous representation of input.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
            Projected latents (continuous representation of input before quantization).
    """

    # All fields default to None so partially-populated outputs are valid.
    loss: Optional[torch.FloatTensor] = None
    audio_values: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None
61
+
62
+
63
@dataclass
class DacEncoderOutput(ModelOutput):
    """
    Output type of the DAC encoder (quantized codes and latents, no decoded audio).

    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
            Quantized continuous representation of input.
        audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
            Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.FloatTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None
81
+
82
+
83
@dataclass
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    """
    Output type of the DAC decoder.

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
            Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: Optional[torch.FloatTensor] = None
93
+
94
+
95
class Snake1d(nn.Module):
    """
    A 1-dimensional Snake activation: x + sin^2(alpha * x) / alpha, applied per channel.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # One learnable parameter per channel, broadcast over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))

    def forward(self, hidden_states):
        original_shape = hidden_states.shape
        # Flatten trailing dims so alpha broadcasts over (batch, channel, time).
        flat = hidden_states.reshape(original_shape[0], original_shape[1], -1)
        # The 1e-9 term guards against division by a zero-valued alpha.
        flat = flat + torch.sin(self.alpha * flat).square() / (self.alpha + 1e-9)
        return flat.reshape(original_shape)
110
+
111
+
112
class DacVectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)

    Additionally uses following tricks from improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        # Project down to the (low-dimensional) codebook space and back up.
        self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
        self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
        self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)

    def forward(self, hidden_state):
        """
        Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.

        Args:
            hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor.

        Returns:
            quantized_representation (`torch.Tensor`of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            commitment_loss (`torch.FloatTensor`of shape `(1)`):
                Commitment loss to train encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.FloatTensor`of shape `(1)`):
                Codebook loss to update the codebook.
            audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
                Codebook indices for each codebook, quantized discrete representation of input.
            projected_latents (torch.FloatTensor of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
        """

        projected_latents = self.in_proj(hidden_state)
        quantized_representation, audio_codes = self.decode_latents(projected_latents)

        # Commitment loss pulls the encoder toward the (frozen) codebook entry;
        # codebook loss pulls the codebook toward the (frozen) encoder output.
        commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
        codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
        # noop in forward pass, straight-through gradient estimator in backward pass
        quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
        quantized_representation = self.out_proj(quantized_representation)

        return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents

    def decode_latents(self, hidden_states):
        """Nearest-neighbor lookup of each latent frame in the codebook."""
        batch_size, hidden_dim, sequence_length = hidden_states.shape
        # Flatten to (batch * time, dim) for the distance computation.
        encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        l2_norm = encodings.pow(2).sum(1, keepdim=True)
        # `dist` is the negated squared euclidean distance, so argmax picks the
        # nearest codebook entry.
        dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()

        indices = dist.max(1)[1]
        indices = indices.reshape(hidden_states.size(0), -1)
        # Look up the (unnormalized) codebook vectors and restore (B, D, T).
        quantized_representation = self.codebook(indices).transpose(1, 2)
        return quantized_representation, indices
180
+
181
+
182
class DacResidualUnit(nn.Module):
    """
    A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        # "Same"-style padding for a kernel of size 7 at the given dilation.
        pad = ((7 - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
        self.snake2 = Snake1d(dimension)
        self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)

    def forward(self, hidden_state):
        """
        Forward pass through the residual unit.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor.

        Returns:
            `torch.Tensor` of shape `(batch_size, channels, time_steps)`:
                Input tensor after passing through the residual unit.
        """
        residual = self.conv2(self.snake2(self.conv1(self.snake1(hidden_state))))
        # Center-crop the skip connection if the convolutions shortened the sequence.
        trim = (hidden_state.shape[-1] - residual.shape[-1]) // 2
        skip = hidden_state[..., trim:-trim] if trim > 0 else hidden_state
        return skip + residual
217
+
218
+
219
class DacEncoderBlock(nn.Module):
    """Encoder block used in DAC encoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        # Channel count doubles at each downsampling stage.
        out_channels = config.encoder_hidden_size * 2**stride_index
        in_channels = out_channels // 2
        self.res_unit1 = DacResidualUnit(in_channels, dilation=1)
        self.res_unit2 = DacResidualUnit(in_channels, dilation=3)
        self.res_unit3 = DacResidualUnit(in_channels, dilation=9)
        self.snake1 = Snake1d(in_channels)
        # The strided convolution performs the temporal downsampling.
        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
        )

    def forward(self, hidden_state):
        # Residual units -> activation -> strided conv, in that order.
        for layer in (self.res_unit1, self.res_unit2, self.res_unit3, self.snake1, self.conv1):
            hidden_state = layer(hidden_state)
        return hidden_state
241
+
242
+
243
class DacDecoderBlock(nn.Module):
    """Decoder block used in DAC decoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        # Channel count halves at each upsampling stage.
        in_channels = config.decoder_hidden_size // 2**stride_index
        out_channels = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(in_channels)
        # The transposed convolution performs the temporal upsampling.
        self.conv_t1 = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )

        self.res_unit1 = DacResidualUnit(out_channels, dilation=1)
        self.res_unit2 = DacResidualUnit(out_channels, dilation=3)
        self.res_unit3 = DacResidualUnit(out_channels, dilation=9)

    def forward(self, hidden_state):
        # Activation -> transposed conv -> residual units, in that order.
        hidden_state = self.conv_t1(self.snake1(hidden_state))
        for res_unit in (self.res_unit1, self.res_unit2, self.res_unit3):
            hidden_state = res_unit(hidden_state)
        return hidden_state
272
+
273
+
274
class DacResidualVectorQuantize(nn.Module):
    """
    ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://arxiv.org/abs/2107.03312)
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        n_codebooks = config.n_codebooks
        quantizer_dropout = config.quantizer_dropout

        self.n_codebooks = n_codebooks

        # One vector quantizer per residual stage.
        self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
        self.quantizer_dropout = quantizer_dropout

    def forward(self, hidden_state, n_quantizers: Optional[int] = None):
        """
        Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor to be quantized.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
                this argument is ignored during training, and a random number of quantizers is used.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Codebook indices for each codebook (quantized discrete representation of input).
            projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
            commitment_loss (`torch.Tensor` of shape `(1)`):
                Commitment loss to train the encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.Tensor` of shape `(1)`):
                Codebook loss to update the codebook.
        """

        quantized_representation = 0
        residual = hidden_state
        commitment_loss = 0
        codebook_loss = 0

        audio_codes = []
        projected_latents = []

        n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
        if self.training:
            # Quantizer dropout: the first `n_dropout` samples of the batch get
            # a random number of active codebooks; the rest use all of them.
            n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
            n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(hidden_state.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
            quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
            # Each stage quantizes whatever the previous stages left unexplained.
            residual = residual - quantized_representation_i

            # Sum losses
            # NOTE(review): scalar loss times a per-sample boolean mask yields a
            # per-sample loss vector in training mode — confirm the downstream
            # reduction expects that shape.
            commitment_loss += commitment_loss_i * mask
            codebook_loss += codebook_loss_i * mask

            audio_codes.append(indices_i)
            projected_latents.append(projected_latents_i)

        audio_codes = torch.stack(audio_codes, dim=1)
        projected_latents = torch.cat(projected_latents, dim=1)

        return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss

    def from_codes(self, audio_codes: torch.Tensor):
        """
        Reconstructs the continuous representation from quantized codes.

        Args:
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Quantized discrete representation of input.

        Returns:
            quantized_representation (`torch.Tensor`):
                Quantized continuous representation of input.
            projected_latents (`torch.Tensor`):
                List of projected latents (continuous representations of input before quantization)
                for each codebook.
            audio_codes (`torch.Tensor`):
                Codebook indices for each codebook.
        """
        quantized_representation = 0.0
        projected_latents = []
        n_codebooks = audio_codes.shape[1]
        for i in range(n_codebooks):
            # Codebook lookup + per-quantizer output projection, summed over stages.
            projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
            projected_latents.append(projected_latents_i)
            quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
        return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes

    def from_latents(self, latents: torch.Tensor):
        """Reconstructs the quantized representation from unquantized latents.

        Args:
            latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
                Continuous representation of input after projection.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the full-projected space.
            quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the latent space (continuous representation before quantization).
        """
        quantized_representation = 0
        quantized_latents = []
        codes = []
        # Cumulative per-quantizer latent dimensions: slice boundaries into `latents`.
        codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
        dims = torch.cumsum(codebook_dims_tensor, dim=0)

        # Use only as many codebooks as fit within the provided latent dimension.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
            quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
            quantized_latents.append(quantized_latents_i)
            # NOTE(review): `codes` is collected but never returned.
            codes.append(codes_i)

            quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
            quantized_representation = quantized_representation + quantized_representation_i

        return quantized_representation, torch.cat(quantized_latents, dim=1)
410
+
411
+
412
class DacDecoder(nn.Module):
    """DAC Decoder: maps quantized latent features back to a mono waveform."""

    def __init__(self, config: DacConfig):
        super().__init__()

        in_channels = config.hidden_size
        decoder_channels = config.decoder_hidden_size
        strides = config.upsampling_ratios

        # Initial convolution projecting the latents into the decoder width.
        self.conv1 = nn.Conv1d(in_channels, decoder_channels, kernel_size=7, padding=3)

        # One upsampling decoder block per stride; channel width halves per stage.
        self.block = nn.ModuleList(
            DacDecoderBlock(config, stride, index) for index, stride in enumerate(strides)
        )

        # After len(strides) halvings, the remaining channel width.
        final_dim = config.decoder_hidden_size // 2 ** len(strides)
        self.snake1 = Snake1d(final_dim)
        self.conv2 = nn.Conv1d(final_dim, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, hidden_state):
        """Decode latents into a single-channel waveform squashed to [-1, 1]."""
        hidden_state = self.conv1(hidden_state)

        for decoder_block in self.block:
            hidden_state = decoder_block(hidden_state)

        return self.tanh(self.conv2(self.snake1(hidden_state)))
447
+
448
+
449
class DacEncoder(nn.Module):
    """DAC Encoder: maps a mono waveform to continuous latent features."""

    def __init__(self, config: DacConfig):
        super().__init__()

        strides = config.downsampling_ratios

        # Initial convolution lifting the 1-channel waveform to the encoder width.
        self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)

        # One block per stride; each downsamples by `stride` and doubles channels
        # (stride_index starts at 1, matching the channel-doubling exponent).
        self.block = nn.ModuleList(
            DacEncoderBlock(config, stride=stride, stride_index=index + 1)
            for index, stride in enumerate(strides)
        )

        # After len(strides) doublings, the resulting channel width.
        d_model = config.encoder_hidden_size * 2 ** len(strides)
        self.snake1 = Snake1d(d_model)
        self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)

    def forward(self, hidden_state):
        """Encode the input waveform through the convolutional stack."""
        hidden_state = self.conv1(hidden_state)

        for encoder_block in self.block:
            hidden_state = encoder_block(hidden_state)

        return self.conv2(self.snake1(hidden_state))
480
+
481
+
482
class DacPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
    """

    config_class = DacConfig
    base_model_prefix = "dac"
    main_input_name = "input_values"

    def _init_weights(self, module):
        # Truncated-normal init for conv weights, zero bias.
        if isinstance(module, nn.Conv1d):
            nn.init.trunc_normal_(module.weight, std=0.02)
            # Guard: Conv1d may be created with bias=False elsewhere; avoid
            # crashing on a None bias.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _weight_normed_layers(self):
        """Yield every convolutional layer that weight normalization is applied to.

        Single source of truth for `apply_weight_norm` / `remove_weight_norm`,
        which previously duplicated this enumeration line by line.
        """
        for layer in self.quantizer.quantizers:
            yield layer.in_proj
            yield layer.out_proj

        yield self.encoder.conv1
        yield self.encoder.conv2
        for layer in self.encoder.block:
            yield layer.conv1
            for res_unit in (layer.res_unit1, layer.res_unit2, layer.res_unit3):
                yield res_unit.conv1
                yield res_unit.conv2

        yield self.decoder.conv1
        yield self.decoder.conv2
        for layer in self.decoder.block:
            yield layer.conv_t1
            for res_unit in (layer.res_unit1, layer.res_unit2, layer.res_unit3):
                yield res_unit.conv1
                yield res_unit.conv2

    def apply_weight_norm(self):
        """Apply weight normalization to all conv layers of encoder/decoder/quantizer."""
        # Prefer the parametrization-based API when available (torch >= 2.1);
        # fall back to the deprecated hook-based nn.utils.weight_norm otherwise.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for module in self._weight_normed_layers():
            weight_norm(module)

    def remove_weight_norm(self):
        """Remove the weight normalization added by `apply_weight_norm`.

        NOTE(review): this uses the hook-based `nn.utils.remove_weight_norm`,
        which does not undo the parametrization-based norm that
        `apply_weight_norm` may have installed on newer torch — asymmetry kept
        as-is to preserve existing behavior; confirm upstream intent.
        """
        for module in self._weight_normed_layers():
            nn.utils.remove_weight_norm(module)
557
+
558
+
559
# Docstring fragment prepended to DAC model classes via `add_start_docstrings`.
DAC_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DacConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared description of `forward` arguments, injected by
# `add_start_docstrings_to_model_forward` on `DacModel.forward`.
DAC_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`).
            Audio data to encode,
        n_quantizers (`int`, *optional*):
            Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
584
+
585
+
586
@add_start_docstrings(
    "The DAC (Descript Audio Codec) model.",
    DAC_START_DOCSTRING,
)
class DacModel(DacPreTrainedModel):
    # Full codec: encoder -> residual vector quantizer -> decoder.
    def __init__(self, config: DacConfig):
        super().__init__(config)
        self.config = config

        self.encoder = DacEncoder(config)
        self.decoder = DacDecoder(config)

        self.quantizer = DacResidualVectorQuantize(config)

        # Codebook size must be a power of two so each code fits an integral
        # number of bits.
        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing
        self.post_init()

    @replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Encode given audio data and return quantized latent codes

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
                Input audio data to encode,
            n_quantizers (int, *optional*):
                Number of quantizers to use. If None, all quantizers are used. Default is None.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Returns:

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        quantized_representation = self.encoder(input_values)
        quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
            quantized_representation, n_quantizers
        )

        # Training loss is a weighted sum of the two VQ losses.
        loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss

        if not return_dict:
            return (loss, quantized_representation, audio_codes, projected_latents)

        return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)

    @replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        quantized_representation: Optional[torch.Tensor] = None,
        audio_codes: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Decode given latent codes and return audio data

        Args:
            quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
                The codebook indices for each codebook, representing the quantized discrete
                representation of the input. This parameter should be provided if you want
                to decode directly from the audio codes (it will overwrite quantized_representation).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        """

        if quantized_representation is None and audio_codes is None:
            raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Codes take precedence: re-derive the continuous representation from them.
        if audio_codes is not None:
            quantized_representation = self.quantizer.from_codes(audio_codes)[0]

        # Decoder outputs (batch, 1, time); squeeze the channel dim.
        audio_values = self.decoder(quantized_representation).squeeze(1)

        if not return_dict:
            return (audio_values,)

        return DacDecoderOutput(audio_values)

    @add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_values: torch.Tensor,
        n_quantizers: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Returns:
        Examples:

        ```python
        >>> from datasets import load_dataset, Audio
        >>> from transformers import DacModel, AutoProcessor
        >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> model = DacModel.from_pretrained("descript/dac_16khz")
        >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
        >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
        >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

        >>> encoder_outputs = model.encode(inputs["input_values"])
        >>> # Get the intermediate audio codes
        >>> audio_codes = encoder_outputs.audio_codes
        >>> # Reconstruct the audio from its quantized representation
        >>> audio_values = model.decode(encoder_outputs.quantized_representation)
        >>> # or the equivalent with a forward pass
        >>> audio_values = model(inputs["input_values"]).audio_values
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.return_dict
        length = input_values.shape[-1]
        loss, quantized_representation, audio_codes, projected_latents = self.encode(
            input_values, n_quantizers, return_dict=False
        )
        # The conv stack may pad internally; truncate the reconstruction back
        # to the original input length.
        audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]

        if not return_dict:
            return (loss, audio_values, quantized_representation, audio_codes, projected_latents)

        return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)
722
+
723
+
724
# Public API of this module.
__all__ = ["DacModel", "DacPreTrainedModel"]
docs/transformers/build/lib/transformers/models/data2vec/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Static type checkers see the real (eager) star-imports of every
    # submodule's public API.
    from .configuration_data2vec_audio import *
    from .configuration_data2vec_text import *
    from .configuration_data2vec_vision import *
    from .modeling_data2vec_audio import *
    from .modeling_data2vec_text import *
    from .modeling_data2vec_vision import *
    from .modeling_tf_data2vec_vision import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy so heavy submodules
    # (torch/TF model files) are only imported on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_audio.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Data2VecAudio configuration"""
16
+
17
+ import math
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class Data2VecAudioConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Data2VecAudioModel`]. It is used to instantiate
    a Data2VecAudio model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Data2VecAudio
    [facebook/data2vec-audio-base-960h](https://huggingface.co/facebook/data2vec-audio-base-960h) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32):
            Vocabulary size of the Data2VecAudio model. Defines the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`Data2VecAudioModel`] or [`TFData2VecAudioModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        activation_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for activations inside the fully connected layer.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        final_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for the final projection layer of [`Data2VecAudioForCTC`].
        layerdrop (`float`, *optional*, defaults to 0.1):
            The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
            details.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for output of the feature encoder.
        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
        conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
        conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
        conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
            *conv_dim*.
        conv_bias (`bool`, *optional*, defaults to `False`):
            Whether the 1D convolutional layers have a bias.
        num_conv_pos_embeddings (`int`, *optional*, defaults to 5):
            Number of 1D convolutional positional embeddings layers.
        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
            Number of groups of 1D convolutional positional embeddings layer.
        conv_pos_kernel_size (`int`, *optional*, defaults to 19):
            Kernel size of each 1D convolutional positional embeddings layer.
        mask_time_prob (`float`, *optional*, defaults to 0.05):
            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
            procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease
            the actual percentage of masked vectors.
        mask_time_length (`int`, *optional*, defaults to 10):
            Length of vector span along the time axis.
        mask_time_min_masks (`int`, *optional*, defaults to 2):
            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
            irrespectively of `mask_time_prob`. Only relevant if
            `mask_time_prob*len(time_axis)/mask_time_length < mask_time_min_masks`.
        mask_feature_prob (`float`, *optional*, defaults to 0.0):
            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
            overlap may decrease the actual percentage of masked vectors.
        mask_feature_length (`int`, *optional*, defaults to 10):
            Length of vector span along the feature axis.
        mask_feature_min_masks (`int`, *optional*, defaults to 0):
            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
            step, irrespectively of `mask_feature_prob`. Only relevant if
            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
            instance of [`Data2VecAudioForCTC`].
        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
            of [`Data2VecAudioForCTC`].
        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
            instance of [`Data2VecAudioForSequenceClassification`].
        classifier_proj_size (`int`, *optional*, defaults to 256):
            Dimensionality of the projection before token mean-pooling for classification.
        tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
            A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
            module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
        tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
            *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
        tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
            A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
            *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
        xvector_output_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the *XVector* embedding vectors.
        add_adapter (`bool`, *optional*, defaults to `False`):
            Whether a convolutional network should be stacked on top of the Data2VecAudio Encoder. Can be very useful
            for warm-starting Data2VecAudio for SpeechEncoderDecoder models.
        adapter_kernel_size (`int`, *optional*, defaults to 3):
            Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
        adapter_stride (`int`, *optional*, defaults to 2):
            Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
        num_adapter_layers (`int`, *optional*, defaults to 3):
            Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
            True`.
        output_hidden_size (`int`, *optional*):
            Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
            if `add_adapter is True`.

    Example:

    ```python
    >>> from transformers import Data2VecAudioConfig, Data2VecAudioModel

    >>> # Initializing a Data2VecAudio facebook/data2vec-audio-base-960h style configuration
    >>> configuration = Data2VecAudioConfig()

    >>> # Initializing a model (with random weights) from the facebook/data2vec-audio-base-960h style configuration
    >>> model = Data2VecAudioModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "data2vec-audio"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_dropout=0.0,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embedding_groups=16,
        conv_pos_kernel_size=19,
        num_conv_pos_embeddings=5,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_time_min_masks=2,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        mask_feature_min_masks=0,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        tdnn_dim=(512, 512, 512, 512, 1500),
        tdnn_kernel=(5, 3, 3, 1, 1),
        tdnn_dilation=(1, 2, 3, 1, 1),
        xvector_output_dim=512,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        add_adapter=False,
        adapter_kernel_size=3,
        adapter_stride=2,
        num_adapter_layers=3,
        output_hidden_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.hidden_size = hidden_size
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.conv_pos_kernel_size = conv_pos_kernel_size
        self.num_feat_extract_layers = len(self.conv_dim)
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.use_weighted_layer_sum = use_weighted_layer_sum

        # The three conv_* lists describe the same stack of layers, so their
        # lengths must agree.
        if (
            (len(self.conv_stride) != self.num_feat_extract_layers)
            or (len(self.conv_kernel) != self.num_feat_extract_layers)
            or (len(self.conv_dim) != self.num_feat_extract_layers)
        ):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # ctc loss
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # adapter
        self.add_adapter = add_adapter
        self.adapter_kernel_size = adapter_kernel_size
        self.adapter_stride = adapter_stride
        self.num_adapter_layers = num_adapter_layers
        self.output_hidden_size = output_hidden_size or hidden_size

        # SequenceClassification-specific parameter. Feel free to ignore for other classes.
        self.classifier_proj_size = classifier_proj_size

        # XVector-specific parameters. Feel free to ignore for other classes.
        self.tdnn_dim = list(tdnn_dim)
        self.tdnn_kernel = list(tdnn_kernel)
        self.tdnn_dilation = list(tdnn_dilation)
        self.xvector_output_dim = xvector_output_dim

    @property
    def inputs_to_logits_ratio(self):
        # Overall downsampling factor of the feature encoder: product of all
        # convolutional strides.
        return math.prod(self.conv_stride)
286
+
287
+
288
+ __all__ = ["Data2VecAudioConfig"]
docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_text.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Data2VecText configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class Data2VecTextConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`Data2VecTextModel`] and [`Data2VecTextModel`]. It
31
+ is used to instantiate a Data2VecText model according to the specified arguments, defining the model architecture.
32
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Data2VecText
33
+ [facebook/data2vec-text-base](https://huggingface.co/facebook/data2vec-text-base) architecture.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 30522):
41
+ Vocabulary size of the DATA2VEC model. Defines the number of different tokens that can be represented by
42
+ the `inputs_ids` passed when calling [`Data2VecModel`].
43
+ hidden_size (`int`, *optional*, defaults to 768):
44
+ Dimensionality of the encoder layers and the pooler layer.
45
+ num_hidden_layers (`int`, *optional*, defaults to 12):
46
+ Number of hidden layers in the Transformer encoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 12):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ intermediate_size (`int`, *optional*, defaults to 3072):
50
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
51
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
52
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
53
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
54
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
55
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
56
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
57
+ The dropout ratio for the attention probabilities.
58
+ max_position_embeddings (`int`, *optional*, defaults to 512):
59
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
60
+ just in case (e.g., 512 or 1024 or 2048).
61
+ type_vocab_size (`int`, *optional*, defaults to 2):
62
+ The vocabulary size of the `token_type_ids` passed when calling [`Data2VecModel`].
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
66
+ The epsilon used by the layer normalization layers.
67
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
68
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
69
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
70
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
71
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
72
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
73
+ is_decoder (`bool`, *optional*, defaults to `False`):
74
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
75
+ use_cache (`bool`, *optional*, defaults to `True`):
76
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
77
+ relevant if `config.is_decoder=True`.
78
+ classifier_dropout (`float`, *optional*):
79
+ The dropout ratio for the classification head.
80
+
81
+ Examples:
82
+
83
+ ```python
84
+ >>> from transformers import Data2VecTextConfig, Data2VecTextModel
85
+
86
+ >>> # Initializing a Data2VecText facebook/data2vec-text-base style configuration
87
+ >>> configuration = Data2VecTextConfig()
88
+
89
+ >>> # Initializing a model (with random weights) from the facebook/data2vec-text-base style configuration
90
+ >>> model = Data2VecTextModel(configuration)
91
+
92
+ >>> # Accessing the model configuration
93
+ >>> configuration = model.config
94
+ ```"""
95
+
96
+ model_type = "data2vec-text"
97
+
98
+ def __init__(
99
+ self,
100
+ vocab_size=30522,
101
+ hidden_size=768,
102
+ num_hidden_layers=12,
103
+ num_attention_heads=12,
104
+ intermediate_size=3072,
105
+ hidden_act="gelu",
106
+ hidden_dropout_prob=0.1,
107
+ attention_probs_dropout_prob=0.1,
108
+ max_position_embeddings=512,
109
+ type_vocab_size=2,
110
+ initializer_range=0.02,
111
+ layer_norm_eps=1e-12,
112
+ pad_token_id=1,
113
+ bos_token_id=0,
114
+ eos_token_id=2,
115
+ position_embedding_type="absolute",
116
+ use_cache=True,
117
+ classifier_dropout=None,
118
+ **kwargs,
119
+ ):
120
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
121
+
122
+ self.vocab_size = vocab_size
123
+ self.hidden_size = hidden_size
124
+ self.num_hidden_layers = num_hidden_layers
125
+ self.num_attention_heads = num_attention_heads
126
+ self.hidden_act = hidden_act
127
+ self.intermediate_size = intermediate_size
128
+ self.hidden_dropout_prob = hidden_dropout_prob
129
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
130
+ self.max_position_embeddings = max_position_embeddings
131
+ self.type_vocab_size = type_vocab_size
132
+ self.initializer_range = initializer_range
133
+ self.layer_norm_eps = layer_norm_eps
134
+ self.position_embedding_type = position_embedding_type
135
+ self.use_cache = use_cache
136
+ self.classifier_dropout = classifier_dropout
137
+
138
+
139
+ class Data2VecTextOnnxConfig(OnnxConfig):
140
+ @property
141
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
142
+ if self.task == "multiple-choice":
143
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
144
+ else:
145
+ dynamic_axis = {0: "batch", 1: "sequence"}
146
+ return OrderedDict(
147
+ [
148
+ ("input_ids", dynamic_axis),
149
+ ("attention_mask", dynamic_axis),
150
+ ]
151
+ )
152
+
153
+
154
+ __all__ = ["Data2VecTextConfig", "Data2VecTextOnnxConfig"]
docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_vision.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright Meta Platforms and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Data2VecVision model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from packaging import version
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...onnx import OnnxConfig
24
+ from ...utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class Data2VecVisionConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`Data2VecVisionModel`]. It is used to instantiate
33
+ an Data2VecVision model according to the specified arguments, defining the model architecture. Instantiating a
34
+ configuration with the defaults will yield a similar configuration to that of the Data2VecVision
35
+ [facebook/data2vec-vision-base](https://huggingface.co/facebook/data2vec-vision-base) architecture.
36
+
37
+ Args:
38
+ hidden_size (`int`, *optional*, defaults to 768):
39
+ Dimensionality of the encoder layers and the pooler layer.
40
+ num_hidden_layers (`int`, *optional*, defaults to 12):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 12):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ intermediate_size (`int`, *optional*, defaults to 3072):
45
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
46
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
49
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
52
+ The dropout ratio for the attention probabilities.
53
+ initializer_range (`float`, *optional*, defaults to 0.02):
54
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
55
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
56
+ The epsilon used by the layer normalization layers.
57
+ image_size (`int`, *optional*, defaults to 224):
58
+ The size (resolution) of each image.
59
+ patch_size (`int`, *optional*, defaults to 16):
60
+ The size (resolution) of each patch.
61
+ num_channels (`int`, *optional*, defaults to 3):
62
+ The number of input channels.
63
+ use_mask_token (`bool`, *optional*, defaults to `False`):
64
+ Whether to use a mask token for masked image modeling.
65
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
66
+ Whether to use BERT-style absolute position embeddings.
67
+ use_relative_position_bias (`bool`, *optional*, defaults to `False`):
68
+ Whether to use T5-style relative position embeddings in the self-attention layers.
69
+ use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
70
+ Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
71
+ layer_scale_init_value (`float`, *optional*, defaults to 0.1):
72
+ Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
73
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
74
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
75
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
76
+ Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
77
+ CLS token, before applying the classification head.
78
+ out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
79
+ Indices of the feature maps to use for semantic segmentation.
80
+ pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
81
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
82
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
83
+ Whether to use an auxiliary head during training.
84
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
85
+ Weight of the cross-entropy loss of the auxiliary head.
86
+ auxiliary_channels (`int`, *optional*, defaults to 256):
87
+ Number of channels to use in the auxiliary head.
88
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
89
+ Number of convolutional layers to use in the auxiliary head.
90
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
91
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
92
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
93
+ The index that is ignored by the loss function of the semantic segmentation model.
94
+
95
+ Example:
96
+
97
+ ```python
98
+ >>> from transformers import Data2VecVisionConfig, Data2VecVisionModel
99
+
100
+ >>> # Initializing a Data2VecVision data2vec_vision-base-patch16-224-in22k style configuration
101
+ >>> configuration = Data2VecVisionConfig()
102
+
103
+ >>> # Initializing a model (with random weights) from the data2vec_vision-base-patch16-224-in22k style configuration
104
+ >>> model = Data2VecVisionModel(configuration)
105
+
106
+ >>> # Accessing the model configuration
107
+ >>> configuration = model.config
108
+ ```"""
109
+
110
+ model_type = "data2vec-vision"
111
+
112
+ def __init__(
113
+ self,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ intermediate_size=3072,
118
+ hidden_act="gelu",
119
+ hidden_dropout_prob=0.0,
120
+ attention_probs_dropout_prob=0.0,
121
+ initializer_range=0.02,
122
+ layer_norm_eps=1e-12,
123
+ image_size=224,
124
+ patch_size=16,
125
+ num_channels=3,
126
+ use_mask_token=False,
127
+ use_absolute_position_embeddings=False,
128
+ use_relative_position_bias=False,
129
+ use_shared_relative_position_bias=False,
130
+ layer_scale_init_value=0.1,
131
+ drop_path_rate=0.1,
132
+ use_mean_pooling=True,
133
+ out_indices=[3, 5, 7, 11],
134
+ pool_scales=[1, 2, 3, 6],
135
+ use_auxiliary_head=True,
136
+ auxiliary_loss_weight=0.4,
137
+ auxiliary_channels=256,
138
+ auxiliary_num_convs=1,
139
+ auxiliary_concat_input=False,
140
+ semantic_loss_ignore_index=255,
141
+ **kwargs,
142
+ ):
143
+ super().__init__(**kwargs)
144
+
145
+ self.hidden_size = hidden_size
146
+ self.num_hidden_layers = num_hidden_layers
147
+ self.num_attention_heads = num_attention_heads
148
+ self.intermediate_size = intermediate_size
149
+ self.hidden_act = hidden_act
150
+ self.hidden_dropout_prob = hidden_dropout_prob
151
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
152
+ self.initializer_range = initializer_range
153
+ self.layer_norm_eps = layer_norm_eps
154
+
155
+ self.image_size = image_size
156
+ self.patch_size = patch_size
157
+ self.num_channels = num_channels
158
+ self.use_mask_token = use_mask_token
159
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
160
+ self.use_relative_position_bias = use_relative_position_bias
161
+ self.use_shared_relative_position_bias = use_shared_relative_position_bias
162
+ self.layer_scale_init_value = layer_scale_init_value
163
+ self.drop_path_rate = drop_path_rate
164
+ self.use_mean_pooling = use_mean_pooling
165
+ # decode head attributes (semantic segmentation)
166
+ self.out_indices = out_indices
167
+ self.pool_scales = pool_scales
168
+ # auxiliary head attributes (semantic segmentation)
169
+ self.use_auxiliary_head = use_auxiliary_head
170
+ self.auxiliary_loss_weight = auxiliary_loss_weight
171
+ self.auxiliary_channels = auxiliary_channels
172
+ self.auxiliary_num_convs = auxiliary_num_convs
173
+ self.auxiliary_concat_input = auxiliary_concat_input
174
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
175
+
176
+
177
+ # Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
178
+ class Data2VecVisionOnnxConfig(OnnxConfig):
179
+ torch_onnx_minimum_version = version.parse("1.11")
180
+
181
+ @property
182
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
183
+ return OrderedDict(
184
+ [
185
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
186
+ ]
187
+ )
188
+
189
+ @property
190
+ def atol_for_validation(self) -> float:
191
+ return 1e-4
192
+
193
+
194
+ __all__ = ["Data2VecVisionConfig", "Data2VecVisionOnnxConfig"]
docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Wav2Vec2 checkpoint."""
16
+
17
+ import argparse
18
+ import os
19
+ from functools import reduce
20
+
21
+ import fairseq
22
+ import torch
23
+ from datasets import load_dataset
24
+
25
+ from transformers import Wav2Vec2Processor, logging
26
+ from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig
27
+
28
+ # Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
29
+ from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy # noqa: F401
30
+ from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel
31
+
32
+
33
+ logging.set_verbosity_info()
34
+ logger = logging.get_logger(__name__)
35
+
36
+ MAPPING = {
37
+ "post_extract_proj": "feature_projection.projection",
38
+ "models.0.layer_norm": "feature_projection.layer_norm",
39
+ "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
40
+ "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
41
+ "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
42
+ "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
43
+ "self_attn_layer_norm": "encoder.layers.*.layer_norm",
44
+ "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
45
+ "fc2": "encoder.layers.*.feed_forward.output_dense",
46
+ "final_layer_norm": "encoder.layers.*.final_layer_norm",
47
+ "encoder.layer_norm": "encoder.layer_norm",
48
+ "w2v_model.layer_norm": "feature_projection.layer_norm",
49
+ "w2v_encoder.proj": "lm_head",
50
+ "mask_emb": "masked_spec_embed",
51
+ }
52
+ TOP_LEVEL_KEYS = [
53
+ "lm_head",
54
+ ]
55
+
56
+
57
+ def set_recursively(hf_pointer, key, value, full_name, weight_type):
58
+ for attribute in key.split("."):
59
+ hf_pointer = getattr(hf_pointer, attribute)
60
+
61
+ if weight_type is not None:
62
+ hf_shape = getattr(hf_pointer, weight_type).shape
63
+ else:
64
+ hf_shape = hf_pointer.shape
65
+
66
+ if hf_shape != value.shape:
67
+ raise ValueError(
68
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
69
+ f" {value.shape} for {full_name}"
70
+ )
71
+
72
+ if weight_type == "weight":
73
+ hf_pointer.weight.data = value
74
+ elif weight_type == "weight_g":
75
+ hf_pointer.weight_g.data = value
76
+ elif weight_type == "weight_v":
77
+ hf_pointer.weight_v.data = value
78
+ elif weight_type == "bias":
79
+ hf_pointer.bias.data = value
80
+ else:
81
+ hf_pointer.data = value
82
+
83
+ logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
84
+
85
+
86
+ def recursively_load_weights(fairseq_model, hf_model, is_headless):
87
+ unused_weights = []
88
+ fairseq_dict = fairseq_model.state_dict()
89
+
90
+ if not is_headless:
91
+ feature_extractor = hf_model.data2vec_audio.feature_extractor
92
+ pos_conv_embedding = hf_model.data2vec_audio.encoder.pos_conv_embed
93
+
94
+ else:
95
+ feature_extractor = hf_model.feature_extractor
96
+ pos_conv_embedding = hf_model.encoder.pos_conv_embed
97
+
98
+ for name, value in fairseq_dict.items():
99
+ is_used = False
100
+ if "conv_layers" in name:
101
+ load_conv_layer(
102
+ name,
103
+ value,
104
+ feature_extractor,
105
+ unused_weights,
106
+ )
107
+ is_used = True
108
+ elif "pos_conv" in name:
109
+ load_pos_conv_layer(
110
+ name,
111
+ value,
112
+ pos_conv_embedding,
113
+ unused_weights,
114
+ )
115
+ is_used = True
116
+ else:
117
+ for key, mapped_key in MAPPING.items():
118
+ if not is_headless:
119
+ mapped_key = "data2vec_audio." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
120
+ if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
121
+ is_used = True
122
+ if "*" in mapped_key:
123
+ layer_index = name.split(key)[0].split(".")[-2]
124
+ mapped_key = mapped_key.replace("*", layer_index)
125
+ if "weight_g" in name:
126
+ weight_type = "weight_g"
127
+ elif "weight_v" in name:
128
+ weight_type = "weight_v"
129
+ elif "bias" in name:
130
+ weight_type = "bias"
131
+ elif "weight" in name:
132
+ # TODO: don't match quantizer.weight_proj
133
+ weight_type = "weight"
134
+ else:
135
+ weight_type = None
136
+ set_recursively(hf_model, mapped_key, value, name, weight_type)
137
+ continue
138
+ if not is_used:
139
+ unused_weights.append(name)
140
+
141
+ logger.warning(f"Unused weights: {unused_weights}")
142
+
143
+
144
+ def access_by_string(module, path):
145
+ names = path.split(".")
146
+ return reduce(getattr, names, module)
147
+
148
+
149
+ def set_weights(full_name, module, fsq_value, hf_weight_path):
150
+ hf_weight = access_by_string(module, hf_weight_path)
151
+ hf_value = hf_weight.data
152
+
153
+ if fsq_value.shape != hf_value.shape:
154
+ raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.")
155
+ hf_weight.data = fsq_value
156
+ logger.info(f"{full_name} was correctly initialized from {hf_weight_path}.")
157
+
158
+
159
+ def load_conv_layer(full_name, value, feature_extractor, unused_weights):
160
+ name = full_name.split("conv_layers.")[-1]
161
+ items = name.split(".")
162
+ layer_id = int(items[0])
163
+ type_id = int(items[1])
164
+
165
+ weight_type = name.split(".")[-1]
166
+ if type_id == 0:
167
+ layer_type = "conv"
168
+ elif type_id == 2:
169
+ layer_type = "layer_norm"
170
+ else:
171
+ unused_weights.append(full_name)
172
+ return
173
+
174
+ set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{weight_type}")
175
+
176
+
177
+ def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights):
178
+ name = full_name.split("pos_conv.")[-1]
179
+ items = name.split(".")
180
+ layer_id = int(items[0])
181
+ type_id = int(items[1])
182
+
183
+ weight_type = name.split(".")[-1]
184
+ if type_id != 0:
185
+ unused_weights.append(full_name)
186
+ return
187
+ else:
188
+ layer_type = "conv"
189
+
190
+ set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.{layer_type}.{weight_type}")
191
+
192
+
193
+ @torch.no_grad()
194
+ def convert_wav2vec2_checkpoint(
195
+ checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
196
+ ):
197
+ """
198
+ Copy/paste/tweak model's weights to transformers design.
199
+ """
200
+ if config_path is not None:
201
+ config = Data2VecAudioConfig.from_pretrained(config_path)
202
+ else:
203
+ config = Data2VecAudioConfig()
204
+
205
+ if not is_finetuned:
206
+ # Modify final_proj layer name
207
+ hf_wav2vec = Data2VecAudioModel(config)
208
+ data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)
209
+
210
+ state_dict = torch.load(checkpoint_path, weights_only=True)
211
+ state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
212
+ state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
213
+ converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
214
+ torch.save(state_dict, converted_ckpt)
215
+ else:
216
+ hf_wav2vec = Data2VecAudioForCTC(config)
217
+ converted_ckpt = checkpoint_path
218
+
219
+ def load_data2vec(path):
220
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
221
+ return model[0].eval()
222
+
223
+ model = load_data2vec(converted_ckpt)
224
+
225
+ recursively_load_weights(model, hf_wav2vec, not is_finetuned)
226
+
227
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
228
+
229
+ ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
230
+ input_audio = [x["array"] for x in ds[:4]["audio"]]
231
+
232
+ inputs = processor(input_audio, return_tensors="pt", padding=True)
233
+
234
+ input_values = inputs.input_values
235
+ attention_mask = inputs.attention_mask
236
+ # input_values = inputs.input_values[:, :-1]
237
+ # attention_mask = inputs.attention_mask[:, :-1]
238
+
239
+ hf_wav2vec.eval()
240
+ model.eval()
241
+ if is_finetuned:
242
+ their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
243
+ "encoder_out"
244
+ ].transpose(0, 1)
245
+ our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"]
246
+
247
+ pred_ids = torch.argmax(our_output, dim=-1)
248
+ output_string = processor.batch_decode(pred_ids)
249
+
250
+ print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}")
251
+ else:
252
+ their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
253
+ "layer_results"
254
+ ][-1][0].transpose(0, 1)
255
+ our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"]
256
+
257
+ print(our_output.shape, their_output.shape)
258
+ max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
259
+ print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
260
+ success = torch.allclose(our_output, their_output, atol=1e-3)
261
+ print("Do both models output the same tensors?", "🔥" if success else "💩")
262
+ if not success:
263
+ raise Exception("Something went wRoNg")
264
+
265
+ hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
266
+
267
+ if is_finetuned:
268
+ processor.save_pretrained(pytorch_dump_folder_path)
269
+ else:
270
+ processor.feature_extractor.save_pretrained(pytorch_dump_folder_path)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ parser = argparse.ArgumentParser()
275
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
276
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
277
+ parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
278
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
279
+ parser.add_argument(
280
+ "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
281
+ )
282
+ args = parser.parse_args()
283
+ convert_wav2vec2_checkpoint(
284
+ args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
285
+ )
docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert data2vec checkpoint."""
16
+
17
+ import argparse
18
+ import os
19
+ import pathlib
20
+
21
+ import fairseq
22
+ import torch
23
+ from fairseq.modules import TransformerSentenceEncoderLayer
24
+ from packaging import version
25
+
26
+ from transformers import (
27
+ Data2VecTextConfig,
28
+ Data2VecTextForMaskedLM,
29
+ Data2VecTextForSequenceClassification,
30
+ Data2VecTextModel,
31
+ )
32
+ from transformers.models.bert.modeling_bert import (
33
+ BertIntermediate,
34
+ BertLayer,
35
+ BertOutput,
36
+ BertSelfAttention,
37
+ BertSelfOutput,
38
+ )
39
+
40
+ # IMPORTANT: In order for this script to run, please make sure to download the dictionary: `dict.txt` from wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
41
+ # File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
42
+ from transformers.utils import logging
43
+
44
+
45
+ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
46
+ raise Exception("requires fairseq >= 0.9.0")
47
+
48
+
49
+ logging.set_verbosity_info()
50
+ logger = logging.get_logger(__name__)
51
+
52
+ SAMPLE_TEXT = "Hello world! cécé herlolip"
53
+
54
+
55
def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Copy/paste/tweak data2vec's weights to our BERT structure.

    Args:
        data2vec_checkpoint_path: Path to the fairseq data2vec text checkpoint file. The directory
            containing it is passed to fairseq's hub loader, so it must also hold the fairseq
            `dict.txt` (see the note near the imports at the top of this file).
        pytorch_dump_folder_path: Output directory for the converted HF model (created if missing).
        classification_head: If True, also convert the "mnli" classification head and emit a
            `Data2VecTextForSequenceClassification`; otherwise emit a `Data2VecTextForMaskedLM`.

    Raises:
        Exception: If the converted model's outputs differ from the original fairseq model's
            outputs by more than an absolute tolerance of 1e-3.
    """
    data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
    # Load the original model through fairseq's hub interface.
    data2vec = Data2VecTextModel.from_pretrained(
        data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
    )
    data2vec.eval()  # disable dropout
    data2vec_model = data2vec.models[0]
    data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
    # Mirror the fairseq architecture hyper-parameters into the HF config.
    config = Data2VecTextConfig(
        vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=data2vec_model.args.encoder_embed_dim,
        num_hidden_layers=data2vec_model.args.encoder_layers,
        num_attention_heads=data2vec_model.args.encoder_attention_heads,
        intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        # Number of labels is read off the MNLI head's output projection.
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
    model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
    model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.data2vec_text.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c data2vec doesn't use them.
    model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
    model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[i]
        data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]

        # self attention: fairseq keeps separate q/k/v projections, same as BERT's self-attention.
        self_attn: BertSelfAttention = layer.attention.self
        assert data2vec_layer.self_attn.k_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.k_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.q_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.q_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.v_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.v_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )

        self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape, (
            f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
        )
        self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
        self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias

        # intermediate (fairseq fc1 == BERT intermediate dense)
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape, (
            f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
        )
        intermediate.dense.weight = data2vec_layer.fc1.weight
        intermediate.dense.bias = data2vec_layer.fc1.bias

        # output (fairseq fc2 == BERT output dense)
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape, (
            f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
        )
        bert_output.dense.weight = data2vec_layer.fc2.weight
        bert_output.dense.bias = data2vec_layer.fc2.bias
        bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
        model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
    else:
        their_output = data2vec_model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
190
+
191
+
192
if __name__ == "__main__":
    # CLI entry point: converts a fairseq data2vec text checkpoint into an HF checkpoint.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )
docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+
5
+ import torch
6
+ from huggingface_hub import hf_hub_download
7
+ from PIL import Image
8
+ from timm.models import create_model
9
+
10
+ from transformers import (
11
+ BeitImageProcessor,
12
+ Data2VecVisionConfig,
13
+ Data2VecVisionForImageClassification,
14
+ Data2VecVisionModel,
15
+ )
16
+
17
+
18
def create_rename_keys(config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec."):
    """Build the list of ``(original_key, hf_key)`` pairs used to rename BEiT state-dict entries.

    Covers the per-layer encoder weights, the patch/cls embeddings, and — depending on the
    flags — either the pretraining (LM-head) extras, the semantic-segmentation heads, or the
    image-classification head.
    """
    prefix = "backbone." if is_semantic else ""

    rename_keys = []
    for layer_idx in range(config.num_hidden_layers):
        src = f"{prefix}blocks.{layer_idx}"
        dst = f"{hf_prefix}encoder.layer.{layer_idx}"
        # output projection, 2 feedforward neural networks and 2 layernorms per encoder layer
        rename_keys += [
            (f"{src}.norm1.weight", f"{dst}.layernorm_before.weight"),
            (f"{src}.norm1.bias", f"{dst}.layernorm_before.bias"),
            (f"{src}.attn.proj.weight", f"{dst}.attention.output.dense.weight"),
            (f"{src}.attn.proj.bias", f"{dst}.attention.output.dense.bias"),
            (f"{src}.norm2.weight", f"{dst}.layernorm_after.weight"),
            (f"{src}.norm2.bias", f"{dst}.layernorm_after.bias"),
            (f"{src}.mlp.fc1.weight", f"{dst}.intermediate.dense.weight"),
            (f"{src}.mlp.fc1.bias", f"{dst}.intermediate.dense.bias"),
            (f"{src}.mlp.fc2.weight", f"{dst}.output.dense.weight"),
            (f"{src}.mlp.fc2.bias", f"{dst}.output.dense.bias"),
        ]

    # projection layer + position embeddings
    rename_keys += [
        (f"{prefix}cls_token", f"{hf_prefix}embeddings.cls_token"),
        (f"{prefix}patch_embed.proj.weight", f"{hf_prefix}embeddings.patch_embeddings.projection.weight"),
        (f"{prefix}patch_embed.proj.bias", f"{hf_prefix}embeddings.patch_embeddings.projection.bias"),
    ]

    if has_lm_head:
        # mask token + shared relative position bias + final layernorm
        rename_keys += [
            ("mask_token", f"{hf_prefix}embeddings.mask_token"),
            (
                "rel_pos_bias.relative_position_bias_table",
                f"{hf_prefix}encoder.relative_position_bias.relative_position_bias_table",
            ),
            (
                "rel_pos_bias.relative_position_index",
                f"{hf_prefix}encoder.relative_position_bias.relative_position_index",
            ),
            ("norm.weight", "layernorm.weight"),
            ("norm.bias", "layernorm.bias"),
        ]
    elif is_semantic:
        # semantic segmentation classification heads
        rename_keys += [
            ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
            ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
            ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
            ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
        ]
    else:
        # pooling layernorm + classification head
        rename_keys += [
            ("fc_norm.weight", f"{hf_prefix}pooler.layernorm.weight"),
            ("fc_norm.bias", f"{hf_prefix}pooler.layernorm.bias"),
            ("head.weight", "classifier.weight"),
            ("head.bias", "classifier.bias"),
        ]

    return rename_keys
95
+
96
+
97
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec_vision."):
    """Split each layer's fused ``qkv`` projection into separate HF query/key/value entries.

    Mutates ``state_dict`` in place: pops the fused weights/biases and the BEiT-specific
    ``gamma_*`` / relative-position entries, re-inserting them under the HF key names.
    Note BEiT has no key bias, so only query and value biases are transferred.
    """
    hidden = config.hidden_size
    for layer_idx in range(config.num_hidden_layers):
        src = ("backbone." if is_semantic else "") + f"blocks.{layer_idx}"
        dst = f"{hf_prefix}encoder.layer.{layer_idx}"
        attn_dst = f"{dst}.attention.attention"

        # queries, keys and values (fused weight is stacked [q; k; v] along dim 0)
        fused_qkv_weight = state_dict.pop(f"{src}.attn.qkv.weight")
        query_bias = state_dict.pop(f"{src}.attn.q_bias")
        value_bias = state_dict.pop(f"{src}.attn.v_bias")

        state_dict[f"{attn_dst}.query.weight"] = fused_qkv_weight[:hidden, :]
        state_dict[f"{attn_dst}.query.bias"] = query_bias
        state_dict[f"{attn_dst}.key.weight"] = fused_qkv_weight[hidden : hidden * 2, :]
        state_dict[f"{attn_dst}.value.weight"] = fused_qkv_weight[-hidden:, :]
        state_dict[f"{attn_dst}.value.bias"] = value_bias

        # gamma_1 and gamma_2
        # we call them lambda because otherwise they are renamed when using .from_pretrained
        state_dict[f"{dst}.lambda_1"] = state_dict.pop(f"{src}.gamma_1")
        state_dict[f"{dst}.lambda_2"] = state_dict.pop(f"{src}.gamma_2")

        # relative_position bias table + index
        if not has_lm_head:
            # each layer has its own relative position bias
            state_dict[f"{attn_dst}.relative_position_bias.relative_position_bias_table"] = state_dict.pop(
                f"{src}.attn.relative_position_bias_table"
            )
            state_dict[f"{attn_dst}.relative_position_bias.relative_position_index"] = state_dict.pop(
                f"{src}.attn.relative_position_index"
            )
137
+
138
+
139
def get_args():
    """Parse the conversion script's command-line arguments.

    Returns an ``argparse.Namespace`` with ``hf_checkpoint_name`` (output/HF name),
    ``input_size`` (image resolution, default 224) and ``beit_checkpoint`` (path to
    the original BEiT-style checkpoint).
    """
    argument_parser = argparse.ArgumentParser(
        "Convert Data2VecVision to HF for image classification and pretraining", add_help=False
    )
    argument_parser.add_argument("--hf_checkpoint_name", type=str)
    argument_parser.add_argument("--input_size", default=224, type=int, help="images input size")
    argument_parser.add_argument("--beit_checkpoint", default="", help="beit checkpoint")
    return argument_parser.parse_args()
148
+
149
+
150
def load_beit_model(args, is_finetuned, is_large):
    """Instantiate the original timm BEiT model and load the data2vec-vision checkpoint into it.

    Args:
        args: Parsed CLI namespace; reads `beit_checkpoint` and `input_size`, and writes
            `window_size` back onto it as a side effect.
        is_finetuned: Whether the checkpoint is an ImageNet-finetuned model (adds the
            classification-head kwargs).
        is_large: Whether to build the large (24-layer) variant instead of the base one.

    Returns:
        The timm model with the checkpoint weights loaded (non-strict: some keys may be
        reported as missing/unexpected via stdout).
    """

    def load_state_dict(model, state_dict, prefix="", ignore_missing="relative_position_index"):
        # Non-strict recursive loader: collects missing/unexpected keys instead of raising,
        # and filters out expected-missing keys matched by the "|"-separated `ignore_missing`.
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            # Mirrors torch.nn.Module.load_state_dict's internal recursion.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(model, prefix=prefix)

        # Partition missing keys into ones worth warning about vs. deliberately ignored ones.
        warn_missing_keys = []
        ignore_missing_keys = []
        for key in missing_keys:
            keep_flag = True
            for ignore_key in ignore_missing.split("|"):
                if ignore_key in key:
                    keep_flag = False
                    break
            if keep_flag:
                warn_missing_keys.append(key)
            else:
                ignore_missing_keys.append(key)

        missing_keys = warn_missing_keys

        if len(missing_keys) > 0:
            print(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys
                )
            )
        if len(unexpected_keys) > 0:
            print("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys))
        if len(ignore_missing_keys) > 0:
            print(
                "Ignored weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, ignore_missing_keys
                )
            )
        if len(error_msgs) > 0:
            print("\n".join(error_msgs))

    # Architecture kwargs shared by pretrained and finetuned variants.
    model_kwargs = {
        "pretrained": False,
        "use_shared_rel_pos_bias": True,
        "use_abs_pos_emb": False,
        "init_values": 0.1,
    }

    if is_finetuned:
        # Finetuned checkpoints carry per-layer relative position bias and a 1000-class head.
        model_kwargs.update(
            {
                "num_classes": 1000,
                "use_mean_pooling": True,
                "init_scale": 0.001,
                "use_rel_pos_bias": True,
            }
        )

    model = create_model(
        "beit_large_patch16_224" if is_large else "beit_base_patch16_224",
        **model_kwargs,
    )
    patch_size = model.patch_embed.patch_size
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    checkpoint = torch.load(args.beit_checkpoint, map_location="cpu", weights_only=True)

    print(f"Load ckpt from {args.beit_checkpoint}")
    checkpoint_model = None
    # Checkpoints store their weights under either "model" or "module".
    # NOTE(review): if neither key is present, checkpoint_model stays None and the
    # .keys() call below raises — presumably all supported checkpoints have one of them.
    for model_key in ("model", "module"):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print(f"Load state_dict by model_key = {model_key}")
            break

    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        # Indices are recomputed by the model, so drop them from the checkpoint.
        if "relative_position_index" in key:
            checkpoint_model.pop(key)

        if "relative_position_bias_table" in key:
            # Only square patch grids are supported (no bias-table interpolation here).
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()

    load_state_dict(model, checkpoint_model, prefix="")

    return model
253
+
254
+
255
def main():
    """End-to-end conversion: load the original BEiT/data2vec-vision model, rename its
    state dict into the HF layout, verify both models produce matching outputs on a
    sample image, and save the HF model + image processor.

    Raises:
        Exception: If the HF model's output differs from the original's by more than 1e-3.
    """
    args = get_args()

    # Variant is inferred from the target checkpoint name.
    is_finetuned = "ft1k" in args.hf_checkpoint_name
    is_large = "large" in args.hf_checkpoint_name

    if is_finetuned:
        # To convert Beit's data2vec_vision to HF you need to copy
        # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_finetune.py
        # into this folder.
        import modeling_finetune  # noqa: F401
    else:
        # To convert Beit's data2vec_vision to HF you need to copy
        # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_cyclical.py
        # into this folder
        # IMPORTANT: Note that for now we've only converted the down-stream
        # model and not the full pretrained model. This means for the integration
        # test you need to add a `return x` after the following line:
        # https://github.com/facebookresearch/data2vec_vision/blob/af9a36349aaed59ae66e69b5dabeef2d62fdc5da/beit/modeling_cyclical.py#L197
        # to make the integration test pass.
        import modeling_cyclical  # noqa: F401

    # 1. Create model config
    config = Data2VecVisionConfig()
    if is_finetuned:
        config.use_relative_position_bias = True
        config.use_shared_relative_position_bias = False
        config.use_mean_pooling = True
        config.num_labels = 1000

        # ImageNet-1k label mapping is fetched from the HF hub.
        repo_id = "huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    else:
        config.use_relative_position_bias = False
        config.use_shared_relative_position_bias = True
        config.use_mean_pooling = False

    if is_large:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # 2. Load Beit model
    orig_model = load_beit_model(args, is_finetuned, is_large)
    orig_model.eval()

    # 3. Forward Beit model
    image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
    # Fixture image; path is relative to the script's location in the transformers repo.
    image = Image.open("../../../../tests/fixtures/tests_samples/COCO/000000039769.png")
    encoding = image_processor(images=image, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    # The pretraining model takes an extra (here unused) mask argument.
    orig_args = (pixel_values,) if is_finetuned else (pixel_values, None)
    with torch.no_grad():
        orig_model_output = orig_model(*orig_args)

    # 4. Load HF Data2VecVision model
    if is_finetuned:
        hf_model = Data2VecVisionForImageClassification(config)
        hf_model.eval()
        has_lm_head = False
        hf_prefix = "data2vec_vision."
    else:
        hf_model = Data2VecVisionModel(config)
        hf_model.eval()
        has_lm_head = True
        hf_prefix = ""

    # Rename flat keys, then split the fused qkv projections.
    rename_keys = create_rename_keys(config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
    state_dict = orig_model.state_dict()
    for src, dest in rename_keys:
        val = state_dict.pop(src)
        state_dict[dest] = val

    read_in_q_k_v(state_dict, config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
    missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
    print("HF missing", missing_keys)
    print("HF unexpected_keys", unexpected_keys)

    # 5. Forward HF Data2VecVision model
    with torch.no_grad():
        hf_model_output = hf_model(pixel_values)

    hf_output = hf_model_output.logits if is_finetuned else hf_model_output.last_hidden_state

    # 6. Compare
    max_absolute_diff = torch.max(torch.abs(hf_output - orig_model_output)).item()

    print(f"max_absolute_diff = {max_absolute_diff}")
    success = torch.allclose(hf_output, orig_model_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    # 7. Save
    print(f"Saving to {args.hf_checkpoint_name}")
    hf_model.save_pretrained(args.hf_checkpoint_name)
    image_processor.save_pretrained(args.hf_checkpoint_name)
358
+
359
+
360
if __name__ == "__main__":
    main()
    # Run the following to convert checkpoints
    # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #         --beit_checkpoint ./pretrained_base.pt \
    #         --hf_checkpoint_name "./data2vec-vision-base"
    # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #         --beit_checkpoint ./finetuned_base.pt \
    #         --hf_checkpoint_name "./data2vec-vision-base-ft1k"
    # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #         --beit_checkpoint ./pretrained_large.pt \
    #         --hf_checkpoint_name "./data2vec-vision-large"
    # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #         --beit_checkpoint ./finetuned_large.pt \
    #         --hf_checkpoint_name "./data2vec-vision-large-ft1k"
docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_audio.py ADDED
@@ -0,0 +1,1746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_data2vec_audio.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ import math
8
+ import warnings
9
+ from typing import Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from ...activations import ACT2FN
17
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
18
+ from ...integrations.fsdp import is_fsdp_managed_module
19
+ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
20
+ from ...modeling_outputs import (
21
+ BaseModelOutput,
22
+ CausalLMOutput,
23
+ SequenceClassifierOutput,
24
+ TokenClassifierOutput,
25
+ Wav2Vec2BaseModelOutput,
26
+ XVectorOutput,
27
+ )
28
+ from ...modeling_utils import PreTrainedModel
29
+ from ...utils import (
30
+ add_code_sample_docstrings,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ is_peft_available,
34
+ logging,
35
+ )
36
+ from .configuration_data2vec_audio import Data2VecAudioConfig
37
+
38
+
39
+ if is_flash_attn_available():
40
+ from ...modeling_flash_attention_utils import _flash_attention_forward
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ # Base docstring
46
+ _CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
47
+
48
+ # General docstring
49
+ _CONFIG_FOR_DOC = "Data2VecAudioConfig"
50
+
51
+
52
class Data2VecAudioConvLayer(nn.Module):
    """One temporal block of the feature encoder: Conv1d -> LayerNorm -> activation.

    Layer 0 consumes the raw single-channel waveform; later layers consume the
    previous layer's `conv_dim` channels.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        # LayerNorm normalizes the last axis, so move channels last and back again.
        normed = self.layer_norm(hidden_states.transpose(-2, -1)).transpose(-2, -1)
        return self.activation(normed)
77
+
78
+
79
class Data2VecAudioPadLayer(nn.Module):
    """Trims the trailing frame produced by an even-sized padded convolution.

    With symmetric padding, an even kernel yields one extra output frame; this
    layer removes it so the sequence length is preserved. Odd kernels pass through.
    """

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0

    def forward(self, hidden_states):
        if self.num_pad_remove > 0:
            return hidden_states[:, :, : -self.num_pad_remove]
        return hidden_states
88
+
89
+
90
class Data2VecAudioPositionalConvLayer(nn.Module):
    """Single grouped-convolution positional-embedding layer: conv -> trim -> LayerNorm -> activation."""

    def __init__(self, config):
        super().__init__()
        kernel_size = config.conv_pos_kernel_size
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        self.padding = Data2VecAudioPadLayer(kernel_size)
        self.activation = ACT2FN[config.feat_extract_activation]
        # no learnable parameters
        self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)

    def forward(self, hidden_states):
        hidden_states = self.padding(self.conv(hidden_states))
        # normalize over channels: swap to (batch, time, channels) and back
        hidden_states = self.layer_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        return self.activation(hidden_states)
115
+
116
+
117
class Data2VecAudioPositionalConvEmbedding(nn.Module):
    """Stack of convolutional positional-embedding layers.

    Input/output layout is (batch, time, channels); internally transposed to the
    channels-first layout Conv1d expects.
    """

    def __init__(self, config):
        super().__init__()
        conv_layers = [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
        self.layers = nn.ModuleList(conv_layers)

    def forward(self, hidden_states):
        hidden_states = hidden_states.transpose(1, 2)
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states.transpose(1, 2)
130
+
131
+
132
class Data2VecAudioFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()
        conv_blocks = [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
        self.conv_layers = nn.ModuleList(conv_blocks)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Disable gradients for the entire feature encoder."""
        for parameter in self.parameters():
            parameter.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        # add a channel dimension: (batch, time) -> (batch, 1, time)
        hidden_states = input_values[:, None]

        # make sure hidden_states require grad for gradient_checkpointing
        if self.training and self._requires_grad:
            hidden_states.requires_grad = True

        use_checkpointing = self._requires_grad and self.gradient_checkpointing and self.training
        for conv_layer in self.conv_layers:
            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(conv_layer.__call__, hidden_states)
            else:
                hidden_states = conv_layer(hidden_states)

        return hidden_states
165
+
166
+
167
class Data2VecAudioFeatureProjection(nn.Module):
    """Normalize extractor features and project them to the transformer hidden size."""

    def __init__(self, config):
        super().__init__()
        feat_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feat_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(feat_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        # non-projected hidden states are needed for quantization
        norm_hidden_states = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(norm_hidden_states))
        return projected, norm_hidden_states
180
+
181
+
182
class Data2VecAudioAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[Data2VecAudioConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        # Per-head dimension; validated below to divide embed_dim evenly.
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim), contiguous.
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns a tuple of (attn_output, attn_weights_or_None, past_key_value_or_None).
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj (pre-scaled by 1/sqrt(head_dim))
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Flatten batch and head dims for batched matmul.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # Additive mask (large negative values on masked positions).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # Multiplicative per-head mask applied after the softmax.
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
338
+
339
+
340
class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
    """
    Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Flash attention expects (bsz, seq_len, num_heads, head_dim) — no transpose,
        # unlike the eager `_shape` helper.
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Data2VecAudioFlashAttention2 attention does not support output_attentions
        if output_attentions:
            raise ValueError("Data2VecAudioFlashAttention2 attention does not support output_attentions")

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        # get query proj
        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            # Cache is stored (bsz, heads, seq, head_dim); transpose to flash layout.
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
        elif is_cross_attention:
            # cross_attentions
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
        else:
            # self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        # NOTE(review): `kv_seq_len` is computed but never used below, and at this
        # point `past_key_value` has already been folded into `key_states`, so the
        # increment double-counts the cache — looks like dead code; confirm upstream.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=self.dropout if self.training else 0.0,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
464
+
465
+
466
class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
    """Variant of `Data2VecAudioAttention` that dispatches to
    `torch.nn.functional.scaled_dot_product_attention`, falling back to the eager
    implementation when `output_attentions` or a head mask is requested."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        if output_attentions or layer_head_mask is not None:
            # SDPA cannot return attention weights or apply a per-head mask,
            # so fall back to the eager parent implementation.
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj (NOT pre-scaled here — SDPA applies the scaling internally)
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value
570
+
571
+
572
class Data2VecAudioFeedForward(nn.Module):
    """Position-wise feed-forward block: dense -> activation -> dense, with dropouts."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a string key into ACT2FN or a callable itself.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        intermediate = self.intermediate_dense(hidden_states)
        intermediate = self.intermediate_act_fn(intermediate)
        intermediate = self.intermediate_dropout(intermediate)

        output = self.output_dense(intermediate)
        return self.output_dropout(output)
594
+
595
+
596
# Maps `config._attn_implementation` to the attention class instantiated by the
# encoder layers.
DATA2VEC_AUDIO_ATTENTION_CLASSES = {
    "eager": Data2VecAudioAttention,
    "sdpa": Data2VecAudioSdpaAttention,
    "flash_attention_2": Data2VecAudioFlashAttention2,
}
601
+
602
+
603
class Data2VecAudioEncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention and a feed-forward block,
    each with a residual connection, followed by post-LayerNorms."""

    def __init__(self, config):
        super().__init__()
        attention_cls = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation]
        self.attention = attention_cls(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )

        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = Data2VecAudioFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        residual = hidden_states
        attn_out, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = residual + self.dropout(attn_out)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
636
+
637
+
638
class Data2VecAudioEncoder(nn.Module):
    """Transformer encoder: convolutional positional embedding, LayerNorm,
    dropout, then a stack of encoder layers with LayerDrop."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the encoder stack.

        Returns a `BaseModelOutput` (or a tuple when `return_dict=False`) with the
        last hidden state and, optionally, all hidden states / attentions.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0
            if self._use_flash_attention_2:
                # 2d mask is passed through the layers
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                # extend attention_mask: convert the 0/1 padding mask into an
                # additive 4-D mask with large negative values on padded positions
                attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
                attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
                attention_mask = attention_mask.expand(
                    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
                )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # A skipped layer contributes `None` attention weights.
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
721
+
722
+
723
class Data2VecAudioAdapterLayer(nn.Module):
    """Strided 1-D convolution whose doubled channel output is gated back down by a GLU."""

    def __init__(self, config):
        super().__init__()
        out_channels = 2 * config.output_hidden_size
        self.conv = nn.Conv1d(
            config.output_hidden_size,
            out_channels,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1,
        )

    def forward(self, hidden_states):
        # GLU over the channel axis splits the 2*C channels into value/gate halves.
        return nn.functional.glu(self.conv(hidden_states), dim=1)
739
+
740
+
741
class Data2VecAudioAdapter(nn.Module):
    """Optional down-projection followed by a stack of strided adapter conv layers
    with LayerDrop applied during training."""

    def __init__(self, config):
        super().__init__()

        # feature dim might need to be down-projected
        if config.output_hidden_size != config.hidden_size:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
        else:
            self.proj = None
            self.proj_layer_norm = None

        self.layers = nn.ModuleList(
            Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers)
        )
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
        # down project hidden_states if necessary
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        # adapter convs operate on (batch, channels, time)
        hidden_states = hidden_states.transpose(1, 2)

        for layer in self.layers:
            # LayerDrop: draw unconditionally (keeps RNG consumption identical
            # between train and eval), skip the layer only while training.
            drop_prob = np.random.random()
            if not self.training or drop_prob > self.layerdrop:
                hidden_states = layer(hidden_states)

        return hidden_states.transpose(1, 2)
770
+
771
+
772
class Data2VecAudioPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecAudioConfig
    base_model_prefix = "data2vec_audio"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, Data2VecAudioFeatureProjection):
            # Uniform init matching nn.Linear's default fan-in bound.
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, Data2VecAudioPositionalConvLayer):
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # Adapter convs are treated as kernel_size=1 here; the actual adapter
            # layers use `adapter_kernel_size` with padding — presumably these
            # cancel out, TODO confirm against the adapter definition.
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """Derive a frame-level boolean attention mask from the sample-level mask."""
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask
852
+
853
+
854
+ def _compute_mask_indices(
855
+ shape: Tuple[int, int],
856
+ mask_prob: float,
857
+ mask_length: int,
858
+ attention_mask: Optional[torch.LongTensor] = None,
859
+ min_masks: int = 0,
860
+ ) -> np.ndarray:
861
+ """
862
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
863
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
864
+ CPU as part of the preprocessing during training.
865
+
866
+ Args:
867
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
868
+ the first element is the batch size and the second element is the length of the axis to span.
869
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
870
+ independently generated mask spans of length `mask_length` is computed by
871
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
872
+ actual percentage will be smaller.
873
+ mask_length: size of the mask
874
+ min_masks: minimum number of masked spans
875
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
876
+ each batch dimension.
877
+ """
878
+ batch_size, sequence_length = shape
879
+
880
+ if mask_length < 1:
881
+ raise ValueError("`mask_length` has to be bigger than 0.")
882
+
883
+ if mask_length > sequence_length:
884
+ raise ValueError(
885
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
886
+ f" and `sequence_length`: {sequence_length}`"
887
+ )
888
+
889
+ # epsilon is used for probabilistic rounding
890
+ epsilon = np.random.rand(1).item()
891
+
892
+ def compute_num_masked_span(input_length):
893
+ """Given input length, compute how many spans should be masked"""
894
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
895
+ num_masked_span = max(num_masked_span, min_masks)
896
+
897
+ # make sure num masked span <= sequence_length
898
+ if num_masked_span * mask_length > sequence_length:
899
+ num_masked_span = sequence_length // mask_length
900
+
901
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
902
+ if input_length - (mask_length - 1) < num_masked_span:
903
+ num_masked_span = max(input_length - (mask_length - 1), 0)
904
+
905
+ return num_masked_span
906
+
907
+ # compute number of masked spans in batch
908
+ input_lengths = (
909
+ attention_mask.detach().sum(-1).tolist()
910
+ if attention_mask is not None
911
+ else [sequence_length for _ in range(batch_size)]
912
+ )
913
+
914
+ # SpecAugment mask to fill
915
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
916
+ spec_aug_mask_idxs = []
917
+
918
+ max_num_masked_span = compute_num_masked_span(sequence_length)
919
+
920
+ if max_num_masked_span == 0:
921
+ return spec_aug_mask
922
+
923
+ for input_length in input_lengths:
924
+ # compute num of masked spans for this input
925
+ num_masked_span = compute_num_masked_span(input_length)
926
+
927
+ # get random indices to mask
928
+ spec_aug_mask_idx = np.random.choice(
929
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
930
+ )
931
+
932
+ # pick first sampled index that will serve as a dummy index to pad vector
933
+ # to ensure same dimension for all batches due to probabilistic rounding
934
+ # Picking first sample just pads those vectors twice.
935
+ if len(spec_aug_mask_idx) == 0:
936
+ # this case can only happen if `input_length` is strictly smaller then
937
+ # `sequence_length` in which case the last token has to be a padding
938
+ # token which we can use as a dummy mask id
939
+ dummy_mask_idx = sequence_length - 1
940
+ else:
941
+ dummy_mask_idx = spec_aug_mask_idx[0]
942
+
943
+ spec_aug_mask_idx = np.concatenate(
944
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
945
+ )
946
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
947
+
948
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
949
+
950
+ # expand masked indices to masked spans
951
+ spec_aug_mask_idxs = np.broadcast_to(
952
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
953
+ )
954
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
955
+
956
+ # add offset to the starting indexes so that indexes now create a span
957
+ offsets = np.arange(mask_length)[None, None, :]
958
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
959
+ batch_size, max_num_masked_span * mask_length
960
+ )
961
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
962
+
963
+ # ensure that we cannot have indices larger than sequence_length
964
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
965
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
966
+
967
+ # scatter indices to mask
968
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
969
+
970
+ return spec_aug_mask
971
+
972
+
973
# Expected output shape of the base model for the doc-example checkpoint,
# consumed by `add_code_sample_docstrings` below.
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]


# Prepended to every model-class docstring via `add_start_docstrings`.
DATA2VEC_AUDIO_START_DOCSTRING = r"""
    Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
    Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
    Michael Auli.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Prepended to every `forward` docstring via `add_start_docstrings_to_model_forward`.
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
            soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
            True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that that even with
            `attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
            because there are more than one convolutional layer in the positional encodings. For a more detailed
            explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).

            </Tip>

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# The base-model output type is structurally identical to Wav2Vec2's, so it is aliased.
Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
1031
+
1032
+
1033
@add_start_docstrings(
    "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
    # Pipeline: CNN feature encoder -> feature projection -> (optional SpecAugment
    # masking during training) -> transformer encoder -> optional adapter.
    def __init__(self, config: Data2VecAudioConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = Data2VecAudioFeatureEncoder(config)
        self.feature_projection = Data2VecAudioFeatureProjection(config)

        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            # learned embedding that replaces masked time steps
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        self.encoder = Data2VecAudioEncoder(config)

        self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        NOTE: `hidden_states` is modified in place and also returned.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            # sample time masks on the fly (training only)
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            # the same feature channels are zeroed for every time step
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Data2VecAudioBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Data2VecAudioBaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # (batch, channels, frames) -> (batch, frames, channels)
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors;
            # add_adapter=False: the mask is sized for the pre-adapter feature length
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Data2VecAudioBaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
1167
+
1168
+
1169
# Index of `hidden_states` inside the tuple returned by the base model when
# `return_dict=False` (positions 0 and 1 are `last_hidden_state` and `extract_features`).
_HIDDEN_STATES_START_POSITION = 2

# CTC docstring
# Expected transcription and loss for the doc-example checkpoint, consumed by
# `add_code_sample_docstrings` on `Data2VecAudioForCTC.forward`.
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 66.95
1174
+
1175
+
1176
@add_start_docstrings(
    """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
    # Base model + dropout + linear LM head; trained with CTC loss.
    def __init__(self, config):
        super().__init__(config)

        self.data2vec_audio = Data2VecAudioModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # the LM head projects from the adapter output size when an adapter is configured
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.

        Deprecated alias of [`freeze_feature_encoder`].
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # fail fast on out-of-vocabulary label ids
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            # convert raw-sample counts to feature-frame counts for CTC
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            # cuDNN's CTC kernel has input restrictions, so disable it here
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
1298
+
1299
+
1300
@add_start_docstrings(
    """
    Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
    # Base model + projector + mean-pooling + linear classifier over the pooled output.
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
            )
        self.data2vec_audio = Data2VecAudioModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # learnable per-layer weights, initialized uniform
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.

        Deprecated alias of [`freeze_feature_encoder`].
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.data2vec_audio.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # weighted layer sum needs every layer's hidden states
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-normalized learned combination over all layers
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # mean-pool only over non-padding frames
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1420
+
1421
+
1422
@add_start_docstrings(
    """
    Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
    # Base model + per-frame linear classifier (no pooling).
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
            )
        self.data2vec_audio = Data2VecAudioModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.

        Deprecated alias of [`freeze_feature_encoder`].
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.data2vec_audio.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # weighted layer sum needs every layer's hidden states
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # NOTE(review): labels appear to be expected per-frame one-hot/score vectors
            # (argmax over the last dim recovers class indices) — confirm against callers.
            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            # NOTE(review): unlike the CTC and sequence-classification heads, the computed
            # loss is NOT prepended to the tuple here; this mirrors the sibling Wav2Vec2
            # head — verify before relying on tuple outputs with labels.
            return output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1531
+
1532
+
1533
class AMSoftmaxLoss(nn.Module):
    """Additive-margin softmax (AM-Softmax) objective, used for speaker verification.

    Embeddings and class weights are L2-normalized, so the logits are cosine
    similarities; the target class's cosine is reduced by `margin` and every
    logit is multiplied by `scale` before the cross-entropy is taken.
    """

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        # (input_dim, num_labels) class-weight matrix; columns are normalized in forward()
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        flat_labels = labels.flatten()
        unit_weight = nn.functional.normalize(self.weight, dim=0)
        unit_feats = nn.functional.normalize(hidden_states, dim=1)

        # cosine similarity between each embedding and each class prototype
        cosine = unit_feats @ unit_weight
        margined = cosine - self.margin

        # apply the margin only on each sample's target class
        target_mask = nn.functional.one_hot(flat_labels, self.num_labels).bool()
        logits = self.scale * torch.where(target_mask, margined, cosine)

        return self.loss(logits, flat_labels)
1554
+
1555
+
1556
class TDNNLayer(nn.Module):
    """One time-delay (TDNN) layer: a dilated 1-D convolution followed by ReLU.

    The weights are stored in an `nn.Linear` for checkpoint backward
    compatibility; the forward pass reshapes them and calls `F.conv1d` instead.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # LoRA cannot wrap self.kernel here because the weight is used via conv1d,
        # bypassing Linear.forward — warn so users exclude this layer from LoRA targets.
        if is_peft_available():
            from peft.tuners.lora import LoraLayer

            if isinstance(self.kernel, LoraLayer):
                warnings.warn(
                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
                    "You should exclude TDNNLayer from LoRA's target modules.",
                )

        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
        features = hidden_states.transpose(1, 2)
        conv_weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
        features = nn.functional.conv1d(features, conv_weight, self.kernel.bias, dilation=self.dilation)
        features = features.transpose(1, 2)

        return self.activation(features)
1586
+
1587
+
1588
@add_start_docstrings(
    """
    Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.data2vec_audio = Data2VecAudioModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # Learnable weights over all hidden states, initialized uniform; softmaxed in forward().
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        # Stack of time-delay (1D conv) layers producing frame-level speaker features.
        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # `* 2` because mean and std statistics are concatenated before this projection.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        # Deprecated alias kept for backward compatibility; delegates to freeze_feature_encoder().
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.data2vec_audio.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.data2vec_audio.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        # NOTE(review): stride is hard-coded to 1 and dilation is not folded into the
        # effective kernel size here — presumably every TDNN layer uses stride 1; confirm
        # this matches TDNNLayer's conv1d configuration.
        for kernel_size in self.config.tdnn_kernel:
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, XVectorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Weighted layer sum needs every hidden state, regardless of the caller's flag.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.data2vec_audio(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Combine all layers' hidden states with softmax-normalized learned weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # With padding, pool each sample only over its valid (non-padded) frames.
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1737
+
1738
+
1739
# Public symbols re-exported by the package-level __init__.
__all__ = [
    "Data2VecAudioForAudioFrameClassification",
    "Data2VecAudioForCTC",
    "Data2VecAudioForSequenceClassification",
    "Data2VecAudioForXVector",
    "Data2VecAudioModel",
    "Data2VecAudioPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_text.py ADDED
@@ -0,0 +1,1553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Data2VecText model."""
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from ...activations import ACT2FN, gelu
26
+ from ...generation import GenerationMixin
27
+ from ...modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ CausalLMOutputWithCrossAttentions,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ QuestionAnsweringModelOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
39
+ from ...utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ logging,
44
+ replace_return_docstrings,
45
+ )
46
+ from .configuration_data2vec_text import Data2VecTextConfig
47
+
48
+
49
logger = logging.get_logger(__name__)


# Index of `hidden_states` in the tuple returned when `return_dict=False`.
_HIDDEN_STATES_START_POSITION = 2

# General docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-text-base"
_CONFIG_FOR_DOC = "Data2VecTextConfig"
58
+
59
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText
class Data2VecTextForTextEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

        # End copy
        # RoBERTa-style positions: the pad token id doubles as the position padding
        # index, so real positions start at `padding_idx + 1`. This assignment
        # intentionally replaces the position embedding created above.
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Sum word, token-type and (absolute) position embeddings, then LayerNorm + dropout."""
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        # Position embeddings are only added for the "absolute" scheme; relative schemes
        # are handled inside self-attention instead.
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
147
+
148
+
149
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
class Data2VecTextSelfAttention(nn.Module):
    """Multi-head self/cross attention (eager implementation) with optional relative positions and KV caching."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per possible (query - key) distance in [-(max-1), max-1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Uni-directional self-attention: append new K/V to the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                # Incremental decoding: only the last position's distances are needed.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift distances into [0, 2*max - 2] to index the embedding table.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in Data2VecTextModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
282
+
283
+
284
class Data2VecTextSelfOutput(nn.Module):
    """Post-attention projection: dense, dropout, then residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Project and regularize the attention output, then normalize around the residual.
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
297
+
298
+
299
# Maps `config._attn_implementation` to the self-attention class to instantiate.
# Only the eager (pure PyTorch) implementation is registered for this model.
DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
    "eager": Data2VecTextSelfAttention,
}
302
+
303
+
304
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
class Data2VecTextAttention(nn.Module):
    """Wraps self-attention plus its output projection, with head-pruning support."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Attention implementation is selected via the config's `_attn_implementation`.
        self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = Data2VecTextSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the Q/K/V and output projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Residual projection around the raw attention context.
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
354
+
355
+
356
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
357
+ class Data2VecTextIntermediate(nn.Module):
358
+ def __init__(self, config):
359
+ super().__init__()
360
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
361
+ if isinstance(config.hidden_act, str):
362
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
363
+ else:
364
+ self.intermediate_act_fn = config.hidden_act
365
+
366
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
367
+ hidden_states = self.dense(hidden_states)
368
+ hidden_states = self.intermediate_act_fn(hidden_states)
369
+ return hidden_states
370
+
371
+
372
class Data2VecTextOutput(nn.Module):
    """Second half of the transformer FFN: down-projection, dropout, residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Project back to hidden size, regularize, then normalize around the residual.
        down = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(down + input_tensor)
385
+
386
+
387
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
class Data2VecTextLayer(nn.Module):
    """A full transformer block: self-attention, optional cross-attention, and feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Data2VecTextAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention always uses absolute positions, regardless of config.
            self.crossattention = Data2VecTextAttention(config, position_embedding_type="absolute")
        self.intermediate = Data2VecTextIntermediate(config)
        self.output = Data2VecTextOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward is applied in chunks along the sequence dim to cap peak memory.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Apply the FFN (intermediate + output with residual) to one chunk of positions."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
472
+
473
+
474
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText
class Data2VecTextEncoder(nn.Module):
    """Stack of `num_hidden_layers` transformer layers with optional caching and gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Data2VecTextLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            # Checkpointing recomputes activations, which is incompatible with caching K/V.
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the hidden states *entering* each layer (input embeddings first).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer's present key/value cache is always its last output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output skips any collections that were not requested.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
566
+
567
+
568
# Copied from transformers.models.bert.modeling_bert.BertPooler
class Data2VecTextPooler(nn.Module):
    """Pools a sequence of hidden states into one vector per example.

    Pooling here is simply a dense projection + tanh applied to the hidden
    state of the first token of the sequence.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The first token's hidden state stands in for the whole sequence.
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))
582
+
583
+
584
class Data2VecTextPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecTextConfig
    base_model_prefix = "data2vec_text"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            # LayerNorm starts out as the identity transform.
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()
            if getattr(module, "weight", None) is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal
            # for initialization (cf https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
            if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                # The padding embedding must stay at zero.
                module.weight.data[module.padding_idx].zero_()
611
+ module.weight.data.fill_(1.0)
612
+
613
+
614
# Documentation templates consumed by the `add_start_docstrings*` decorators on the
# model classes below; `{0}` in the inputs template is filled with the input shape.
DATA2VECTEXT_START_DOCSTRING = r"""
    Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
    Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
    Michael Auli.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Data2VecTextConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DATA2VECTEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
681
+
682
+
683
@add_start_docstrings(
    "The bare Data2VecText Model for text transformer outputting raw hidden-states without any specific head on top.",
    DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextModel(Data2VecTextPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762

    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = Data2VecTextForTextEmbeddings(config)
        self.encoder = Data2VecTextEncoder(config)

        # The pooler is optional: heads that work on per-token states (e.g. MLM)
        # construct the model with add_pooling_layer=False.
        self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the token embedding matrix used by the embeddings module.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Replaces the token embedding matrix (used e.g. when resizing the vocab).
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching only makes sense for decoder-style (causal) usage.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length: length of the cached prefix, read off the cached keys.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # A default mask must cover the cached prefix as well as the new tokens.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
865
+
866
+
867
@add_start_docstrings(
    """Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING
)
class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
    # LM head decoder weights are tied to the input embeddings.
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `Data2VecTextLMHeadModel` as a standalone, add `is_decoder=True.`")

        # No pooler needed: the LM head only consumes per-token hidden states.
        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
        self.lm_head = Data2VecTextLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # The decoder projection of the LM head acts as the output embedding.
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Data2VecTextForCausalLM, Data2VecTextConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
        >>> config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base")
        >>> config.is_decoder = True
        >>> model = Data2VecTextForCausalLM.from_pretrained("facebook/data2vec-text-base", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Caching is pointless (and wasteful) during training with labels.
        if labels is not None:
            use_cache = False

        outputs = self.data2vec_text(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        lm_loss = None
        if labels is not None:
            # Shifting of logits vs. labels is handled inside self.loss_function.
            lm_loss = self.loss_function(
                prediction_scores,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # Rearranges each layer's cached states along the batch dimension so the
        # cache follows the selected beams during beam search.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
1006
+
1007
+
1008
@add_start_docstrings("""data2vec Model with a `language modeling` head on top.""", DATA2VECTEXT_START_DOCSTRING)
class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
    # LM head decoder weights are tied to the input embeddings.
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `Data2VecTextForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # No pooler needed: the MLM head only consumes per-token hidden states.
        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
        self.lm_head = Data2VecTextLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, *optional*, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.data2vec_text(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss ignores label index -100 by default, which masks
            # out the non-target positions.
            loss_fct = CrossEntropyLoss()

            labels = labels.to(prediction_scores.device)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1098
+
1099
+
1100
# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Data2VecText
class Data2VecTextLMHead(nn.Module):
    """Data2VecText Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        # Vocabulary projection; its bias is the shared `self.bias` parameter so
        # that resizing/tying keeps the two in sync.
        self.decoder = nn.Linear(hidden, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        # dense -> GELU -> LayerNorm, then project back to vocabulary size (with bias).
        transformed = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(transformed)

    def _tie_weights(self):
        # Re-tie the two bias parameters if they get disconnected (on TPU or when
        # the bias is resized); kept for accelerate and backward compatibility.
        if self.decoder.bias.device.type == "meta":
            self.decoder.bias = self.bias
        else:
            self.bias = self.decoder.bias
1130
+
1131
+
1132
@add_start_docstrings(
    """
    Data2VecText Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # The classification head works on the first token's hidden state, so
        # the base model is built without its pooler.
        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
        self.classifier = Data2VecTextClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.data2vec_text(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once from num_labels and label dtype; the
            # choice is cached on the config for subsequent calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1227
+
1228
+
1229
@add_start_docstrings(
    """
    Data2VecText Model with a multiple choice classification head on top (a linear layer on top of the pooled output
    and a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.data2vec_text = Data2VecTextModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One scalar score per choice; scores are reshaped to (batch_size, num_choices) before the loss.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Collapse (batch_size, num_choices, ...) -> (batch_size * num_choices, ...) so every choice is
        # encoded independently by the base model.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.data2vec_text(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # outputs[1] is the pooled representation of each flattened (example, choice) pair.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Back to one score per choice: (batch_size, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            # move labels to the logits device (supports model parallelism)
            labels = labels.to(reshaped_logits.device)
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1321
+
1322
+
1323
@add_start_docstrings(
    """
    Data2VecText Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooling layer: classification is performed per token on the sequence output.
        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
        # Fall back to the generic hidden dropout when no classifier-specific dropout is configured.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.data2vec_text(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token hidden states: (batch_size, sequence_length, hidden_size).
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            # move labels to the logits device (supports model parallelism)
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1404
+
1405
+
1406
class Data2VecTextClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Takes the `<s>` (CLS-equivalent) token's hidden state, passes it through a dense+tanh
    projection with dropout on both sides, and maps it to `num_labels` logits.
    """

    def __init__(self, config):
        super().__init__()
        # Prefer the classifier-specific dropout; otherwise reuse the generic hidden dropout.
        dropout_prob = (
            config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # First token is <s>, the sentence-level representation (equiv. to [CLS]).
        cls_state = features[:, 0, :]
        hidden = self.dense(self.dropout(cls_state))
        hidden = self.dropout(torch.tanh(hidden))
        return self.out_proj(hidden)
1427
+
1428
+
1429
@add_start_docstrings(
    """
    Data2VecText Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooling layer: span logits are computed from every token's hidden state.
        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
        # Projects each token's hidden state to two logits: span start and span end.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.data2vec_text(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the two per-token logits into start/end scores, each (batch_size, sequence_length).
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Out-of-range targets were clamped to `ignored_index`, which CrossEntropyLoss then skips.
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1526
+
1527
+
1528
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids (`torch.Tensor`):
            Token ids of shape `(batch_size, sequence_length)`.
        padding_idx (`int`):
            Id of the padding token. Padding positions keep the value `padding_idx`.
        past_key_values_length (`int`, *optional*, defaults to 0):
            Number of already-cached positions; offsets the computed positions accordingly.

    Returns:
        `torch.Tensor`: Position ids with the same shape as `input_ids`.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    # cumsum numbers non-padding tokens 1..k per row; multiplying by the mask zeroes padding slots again.
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx
1542
+
1543
+
1544
# Explicit public API of this module.
__all__ = [
    "Data2VecTextForCausalLM",
    "Data2VecTextForMaskedLM",
    "Data2VecTextForMultipleChoice",
    "Data2VecTextForQuestionAnswering",
    "Data2VecTextForSequenceClassification",
    "Data2VecTextForTokenClassification",
    "Data2VecTextModel",
    "Data2VecTextPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_vision.py ADDED
@@ -0,0 +1,1449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Data2VecVision model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ import warnings
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from ...activations import ACT2FN
29
+ from ...modeling_outputs import (
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ SemanticSegmenterOutput,
34
+ )
35
+ from ...modeling_utils import PreTrainedModel
36
+ from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
37
+ from ...utils import (
38
+ add_code_sample_docstrings,
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ logging,
42
+ replace_return_docstrings,
43
+ torch_int,
44
+ )
45
+ from .configuration_data2vec_vision import Data2VecVisionConfig
46
+
47
+
48
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Data2VecVisionConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
# Expected output shape of the doc example: (batch_size, sequence_length, hidden_size).
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
60
+
61
+
62
@dataclass
# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision
class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
    """
    Class for outputs of [`Data2VecVisionModel`].

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
            will be returned.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # No extra fields: this subclass only re-documents `pooler_output`, whose meaning here depends on
    # `config.use_mean_pooling`.
87
+
88
+
89
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    # No-op at inference time or when the drop probability is zero.
    if not training or drop_prob == 0.0:
        return input
    keep_prob = 1 - drop_prob
    # One random decision per sample, broadcast across all remaining dimensions.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = torch.rand(mask_shape, dtype=input.dtype, device=input.device) + keep_prob
    random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
    # Scale kept samples by 1/keep_prob so the expected value is unchanged.
    return input.div(keep_prob) * random_tensor
108
+
109
+
110
class Data2VecVisionDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional implementation; only active while training.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown in the module's repr, e.g. "Data2VecVisionDropPath(p=0.1)".
        return f"p={self.drop_prob}"
123
+
124
+
125
# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
class Data2VecVisionEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        if config.use_mask_token:
            # Learned token substituted for masked patches during masked-image modeling.
            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        else:
            self.mask_token = None
        self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
        self.patch_size = config.patch_size
        self.image_size = (
            config.image_size
            if isinstance(config.image_size, collections.abc.Iterable)
            else (config.image_size, config.image_size)
        )
        num_patches = self.patch_embeddings.num_patches
        if config.use_absolute_position_embeddings:
            # Learned absolute positions for [CLS] + every patch token.
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        else:
            self.position_embeddings = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        # Only the patch positions are interpolated; the [CLS] position is reattached unchanged below.
        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # Reshape assumes the pre-trained positions form a square grid — TODO confirm for non-square checkpoints.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
    ) -> torch.Tensor:
        # The flag is deprecated: interpolation now always happens when position embeddings exist.
        if self.position_embeddings is not None and interpolate_pos_encoding is not None:
            warnings.warn(
                "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
                "interpolated to the input image size. The argument will be removed in transformers v4.51.0."
            )

        _, _, height, width = pixel_values.shape
        embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        # Prepend the [CLS] token to the patch sequence.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        if self.position_embeddings is not None:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = self.dropout(embeddings)

        return embeddings, (patch_height, patch_width)
226
+
227
+
228
class Data2VecVisionPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer, using a strided convolution as the patch projection.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Accept either a single int or an (height, width) iterable.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)
        grid_height = image_size[0] // patch_size[0]
        grid_width = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = grid_height * grid_width
        self.patch_shape = (grid_height, grid_width)

        # Each kernel application produces one patch embedding of dimension hidden_size.
        self.projection = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        projected = self.projection(pixel_values)
        patch_height, patch_width = projected.shape[2], projected.shape[3]
        # (batch, hidden, gh, gw) -> (batch, gh*gw, hidden)
        embeddings = projected.flatten(2).transpose(1, 2)

        return embeddings, (patch_height, patch_width)
265
+
266
+
267
# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
class Data2VecVisionSelfAttention(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # The key projection deliberately has no bias term.
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # A per-layer relative position bias is only instantiated when a window size is supplied.
        self.has_relative_position_bias = bool(window_size)
        if self.has_relative_position_bias:
            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Add relative position bias if present.
        if self.has_relative_position_bias:
            height, width = resolution
            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            attention_scores = attention_scores + self.relative_position_bias(
                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
            )

        # Add shared relative position bias if provided.
        if relative_position_bias is not None:
            attention_scores = attention_scores + relative_position_bias

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, num_heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
349
+
350
+
351
# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision
class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
    """
    Variant of `Data2VecVisionSelfAttention` that computes attention with
    `torch.nn.functional.scaled_dot_product_attention` (SDPA). It falls back to the eager
    parent implementation whenever SDPA cannot honor the requested options.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # SDPA can neither return attention probabilities nor apply a per-head mask,
        # so fall back to the manual (eager) parent implementation in those cases.
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
                "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
                "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                head_mask=head_mask,
                output_attentions=output_attentions,
                relative_position_bias=relative_position_bias,
                interpolate_pos_encoding=interpolate_pos_encoding,
                resolution=resolution,
            )

        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Build the additive attention bias: the layer's own relative position bias
        # (if configured), plus any shared bias handed down by the encoder.
        attn_bias = None
        if self.has_relative_position_bias:
            height, width = resolution
            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            attn_bias = self.relative_position_bias(
                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
            )

        # Add shared relative position bias if provided.
        if relative_position_bias is not None:
            if attn_bias is None:
                attn_bias = relative_position_bias
            else:
                attn_bias += relative_position_bias

        scaling = 1 / math.sqrt(self.attention_head_size)
        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attn_bias,
            # Attention dropout is only active in training mode.
            dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=scaling,
        )
        # (batch, heads, seq, head_dim) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # Second element is None: SDPA does not expose attention probabilities.
        return context_layer, None
412
+
413
+
414
# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
class Data2VecVisionSelfOutput(nn.Module):
    """
    Projects the attention context back to `hidden_size` and applies dropout.

    The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
        # `input_tensor` and `gamma` are accepted for interface compatibility but unused here.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
431
+
432
+
433
# Maps `config._attn_implementation` to the self-attention class instantiated by
# `Data2VecVisionAttention`.
DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
    "eager": Data2VecVisionSelfAttention,
    "sdpa": Data2VecVisionSdpaSelfAttention,
}
437
+
438
+
439
# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION
class Data2VecVisionAttention(nn.Module):
    """
    Combines a self-attention implementation (chosen via `config._attn_implementation`)
    with the output projection, and supports head pruning.
    """

    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, window_size=window_size
        )
        self.output = Data2VecVisionSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the query/key/value and output projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # Self-attention, then the output projection (residual is applied by the caller).
        self_outputs = self.attention(
            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
        )

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
484
+
485
+
486
# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision
class Data2VecVisionIntermediate(nn.Module):
    """Feed-forward expansion: `hidden_size` -> `intermediate_size`, then the configured activation."""

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be a string key into ACT2FN or an already-callable activation.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
501
+
502
+
503
# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision
class Data2VecVisionOutput(nn.Module):
    """Feed-forward contraction: `intermediate_size` -> `hidden_size`, followed by dropout."""

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
515
+
516
+
517
# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
class Data2VecVisionLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation.

    Pre-norm transformer block: layernorm -> attention -> (LayerScale) -> drop-path +
    residual, then layernorm -> MLP -> (LayerScale) -> drop-path + residual.
    """

    def __init__(
        self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0
    ) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Data2VecVisionAttention(config, window_size=window_size)
        self.intermediate = Data2VecVisionIntermediate(config)
        self.output = Data2VecVisionOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Stochastic depth; identity when this layer's drop-path rate is zero.
        self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # LayerScale: learnable per-channel scaling of each sub-block's output.
        # Disabled (None) when `layer_scale_init_value` <= 0.
        init_values = config.layer_scale_init_value
        if init_values > 0:
            self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
            self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
        else:
            self.lambda_1, self.lambda_2 = None, None

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in Data2VecVision, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            interpolate_pos_encoding=interpolate_pos_encoding,
            resolution=resolution,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # apply lambda_1 (LayerScale) if present
        if self.lambda_1 is not None:
            attention_output = self.lambda_1 * attention_output

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # in Data2VecVision, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)

        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output)

        # apply lambda_2 (LayerScale) if present
        if self.lambda_2 is not None:
            layer_output = self.lambda_2 * layer_output

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
583
+
584
+
585
# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision
class Data2VecVisionRelativePositionBias(nn.Module):
    """
    Learnable relative position bias added to attention scores. The table stores one bias
    per head for every relative offset inside the window, plus 3 special entries for
    cls->token, token->cls and cls->cls interactions.
    """

    def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        # (2*Wh-1)*(2*Ww-1) distinct in-window offsets + 3 cls-related entries.
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

    @compile_compatible_method_lru_cache(maxsize=10)
    def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
        """
        This method creates the relative position index, modified to support arbitrary window sizes,
        as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
        """
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
        window_area = window_size[0] * window_size[1]
        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
        coords = torch.stack(grid)  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # Flatten the 2D offset into a single table index.
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        # Row/column 0 hold the cls-token interactions (the 3 special table entries).
        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1
        return relative_position_index

    def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
        """
        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.

        The trained bias table (sized for `self.window_size`) is bilinearly interpolated
        to the requested `window_size` before being gathered into a per-head bias map.
        """
        old_height = 2 * self.window_size[0] - 1
        old_width = 2 * self.window_size[1] - 1

        new_height = 2 * window_size[0] - 1
        new_width = 2 * window_size[1] - 1

        old_relative_position_bias_table = self.relative_position_bias_table

        old_num_relative_distance = self.num_relative_distance
        new_num_relative_distance = new_height * new_width + 3

        # Interpolate only the in-window part of the table; the last 3 (cls) entries
        # are carried over unchanged.
        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]

        old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
        new_sub_table = nn.functional.interpolate(
            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
        )
        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

        new_relative_position_bias_table = torch.cat(
            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
        )

        relative_position_index = self.generate_relative_position_index(window_size)
        relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]

        # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
        relative_position_bias = relative_position_bias.view(
            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
        )
        # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()

        # Optionally resize the bias map to match an interpolated sequence length.
        if interpolate_pos_encoding:
            relative_position_bias = nn.functional.interpolate(
                relative_position_bias.unsqueeze(1),
                size=(dim_size, dim_size),
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

        # Leading batch dimension so the bias broadcasts over the batch.
        return relative_position_bias.unsqueeze(0)
667
+
668
+
669
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
class Data2VecVisionEncoder(nn.Module):
    """
    Stack of `Data2VecVisionLayer`s with per-layer stochastic-depth rates and an optional
    relative position bias shared across all layers.
    """

    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.config = config
        # Shared bias: one bias module used by every layer (vs. per-layer bias, which is
        # configured inside the layers themselves via `use_relative_position_bias`).
        self.has_relative_position_bias = config.use_shared_relative_position_bias
        if self.has_relative_position_bias:
            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)

        # stochastic depth decay rule: drop-path rate grows linearly with depth
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
        self.layer = nn.ModuleList(
            [
                Data2VecVisionLayer(
                    config,
                    window_size=window_size if config.use_relative_position_bias else None,
                    drop_path_rate=dpr[i],
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int, int]] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Recompute the shared bias each iteration (the bias module caches the index
            # internally via lru_cache).
            if self.has_relative_position_bias:
                height, width = resolution
                window_size = (height // self.config.patch_size, width // self.config.patch_size)
                relative_position_bias = self.relative_position_bias(
                    window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
                )
            else:
                relative_position_bias = None

            layer_head_mask = head_mask[i] if head_mask is not None else None

            # Under gradient checkpointing, activations are recomputed in backward to
            # save memory; arguments must be passed positionally to the checkpoint func.
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                    relative_position_bias,
                    interpolate_pos_encoding,
                    resolution,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                    relative_position_bias,
                    interpolate_pos_encoding,
                    resolution,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
755
+
756
+
757
# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision
class Data2VecVisionPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecVisionConfig
    base_model_prefix = "data2vec_vision"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Data2VecVisionLayer"]
    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule, dispatched on its type."""
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Data2VecVisionEmbeddings):
            # Special tokens and position embeddings start at zero.
            module.cls_token.data.zero_()
            if module.mask_token is not None:
                module.mask_token.data.zero_()
            if module.position_embeddings is not None:
                module.position_embeddings.data.zero_()
        elif isinstance(module, Data2VecVisionRelativePositionBias):
            module.relative_position_bias_table.data.zero_()
        elif isinstance(module, Data2VecVisionLayer):
            # Reset LayerScale parameters to their configured initial value.
            if module.lambda_1 is not None:
                module.lambda_1.data.fill_(self.config.layer_scale_init_value)
                module.lambda_2.data.fill_(self.config.layer_scale_init_value)
799
+
800
+
801
# Docstring fragment prepended to model classes via `@add_start_docstrings`.
DATA2VEC_VISION_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring fragment for forward methods via `@add_start_docstrings_to_model_forward`.
DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BeitImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
835
+
836
+
837
@add_start_docstrings(
    "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
    DATA2VEC_VISION_START_DOCSTRING,
)
# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False
class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
    """Patch embedding -> transformer encoder -> (optional) pooler."""

    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None:
        super().__init__(config)
        self.config = config

        self.embeddings = Data2VecVisionEmbeddings(config)
        self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)

        # With mean pooling the final layernorm lives in the pooler instead.
        self.layernorm = (
            nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        )
        self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Data2VecVisionModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # `embeddings` returns a tuple; only the embedded sequence is needed here.
        embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        # Spatial resolution of the input, used to size the relative position bias.
        resolution = pixel_values.shape[2:]

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            resolution=resolution,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return Data2VecVisionModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
930
+
931
+
932
# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision
class Data2VecVisionPooler(nn.Module):
    """
    Pools the encoder output: mean of patch tokens followed by layernorm when
    `use_mean_pooling` is set, otherwise the final [CLS] hidden state.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.layernorm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layernorm is None:
            # [CLS] pooling: take the first token's final hidden state.
            return hidden_states[:, 0]
        # Mean pooling: average the patch tokens (everything after [CLS]), then normalize.
        return self.layernorm(hidden_states[:, 1:, :].mean(1))
950
+
951
+
952
@add_start_docstrings(
    """
    Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
    the final hidden states of the patch tokens) e.g. for ImageNet.
    """,
    DATA2VEC_VISION_START_DOCSTRING,
)
# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision
class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # Pooling layer is required: the classifier consumes the pooled output.
        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and label dtype; note this
            # mutates `self.config` so later calls reuse the same choice.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1042
+
1043
+
1044
# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision
class Data2VecVisionConvModule(nn.Module):
    """
    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int], str] = 0,
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        # Bias defaults to False since BatchNorm follows the convolution.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # conv -> batch-norm -> ReLU
        return self.activation(self.bn(self.conv(input)))
1080
+
1081
+
1082
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingBlock(nn.Module):
    """One PPM stage: adaptive average pooling followed by a 1x1 conv module."""

    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        pool = nn.AdaptiveAvgPool2d(pool_scale)
        conv = Data2VecVisionConvModule(in_channels, channels, kernel_size=1)
        # Keep a plain-list handle for iteration and register each stage under a
        # string name ("0", "1") so the state-dict layout matches the original.
        self.layers = [pool, conv]
        self.add_module("0", pool)
        self.add_module("1", conv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = input
        for stage in self.layers:
            output = stage(output)
        return output
1098
+
1099
+
1100
+ # Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
1101
class Data2VecVisionPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        # One pooling block per scale, kept in a plain list and registered by index
        # so each block appears in the state dict under its numeric name.
        self.blocks = []
        for index, scale in enumerate(pool_scales):
            pooling_block = Data2VecVisionPyramidPoolingBlock(
                pool_scale=scale, in_channels=in_channels, channels=channels
            )
            self.blocks.append(pooling_block)
            self.add_module(str(index), pooling_block)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return one feature map per pooling scale, each upsampled back to `x`'s spatial size."""
        target_size = x.size()[2:]
        return [
            nn.functional.interpolate(
                block(x), size=target_size, mode="bilinear", align_corners=self.align_corners
            )
            for block in self.blocks
        ]
1138
+
1139
+
1140
# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision
class Data2VecVisionUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://arxiv.org/abs/1807.10221).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()

        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        # All four feature levels come from the same transformer backbone, so they
        # share the same channel dimension.
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = config.hidden_size
        self.align_corners = False
        # Final 1x1 convolution producing per-pixel class logits.
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        # PSP Module: multi-scale context pooling applied to the last feature level.
        self.psp_modules = Data2VecVisionPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        # Fuses the raw last feature map concatenated with all pooled branches
        # back down to `self.channels`.
        self.bottleneck = Data2VecVisionConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # Fuses the channel-wise concatenation of all FPN levels into a single map
        # that is fed to the classifier.
        self.fpn_bottleneck = Data2VecVisionConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def psp_forward(self, inputs):
        """Apply pyramid pooling to the last feature level and fuse it with the raw map."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Fuse the four backbone feature maps top-down (FPN-style) and return per-pixel
        logits at the spatial resolution of the first (finest) feature level.
        """
        # build laterals
        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path: each coarser level is upsampled and added into the
        # next finer lateral, mutating `laterals` in place from coarse to fine.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        # Bring every level to the finest resolution before concatenating.
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output
1224
+
1225
+
1226
+ # Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision
1227
class Data2VecVisionFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is implemented of
    [FCNNet](https://arxiv.org/abs/1411.4038>).

    Args:
        config (Data2VecVisionConfig): Configuration.
        in_channels
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.


    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        config: Data2VecVisionConfig,
        in_index: int = 2,
        kernel_size: int = 3,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.in_channels = config.hidden_size
        self.channels = config.auxiliary_channels
        self.num_convs = config.auxiliary_num_convs
        self.concat_input = config.auxiliary_concat_input
        self.in_index = in_index

        # "same"-style padding so the spatial resolution is preserved.
        conv_padding = (kernel_size // 2) * dilation
        # First conv maps backbone channels to head channels; the rest keep the width.
        conv_layers = [
            Data2VecVisionConvModule(
                self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
            )
        ]
        conv_layers.extend(
            Data2VecVisionConvModule(
                self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
            )
            for _ in range(self.num_convs - 1)
        )
        # With zero convs the head degenerates to a pass-through before the classifier.
        self.convs = nn.Identity() if self.num_convs == 0 else nn.Sequential(*conv_layers)
        if self.concat_input:
            self.conv_cat = Data2VecVisionConvModule(
                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
            )

        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the selected feature level through the conv stack and classify per pixel."""
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        return self.classifier(output)
1288
+
1289
+
1290
@add_start_docstrings(
    """
    Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    DATA2VEC_VISION_START_DOCSTRING,
)
# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision
class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)

        # FPNs
        # Exactly four backbone levels are required because four fixed ops
        # (fpn1..fpn4) rescale them to a feature pyramid below.
        if len(self.config.out_indices) != 4:
            raise ValueError(
                "Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
                "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                "a base-sized architecture."
            )
        # fpn1 upsamples 4x, fpn2 2x, fpn3 keeps the resolution, fpn4 downsamples 2x,
        # turning four same-resolution ViT maps into a multi-scale pyramid.
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm2d(config.hidden_size),
            nn.GELU(),
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn3 = nn.Identity()
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Semantic segmentation head(s)
        self.decode_head = Data2VecVisionUperHead(config)
        self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None

        # Initialize weights and apply final processing
        self.post_init()

    def compute_loss(self, logits, auxiliary_logits, labels):
        """Cross-entropy loss on upsampled logits, plus a weighted auxiliary-head term."""
        # upsample logits to the images' original size
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
        # compute weighted loss
        loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
        main_loss = loss_fct(upsampled_logits, labels)
        loss = main_loss
        if auxiliary_logits is not None:
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            loss += self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
        >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features, and reshape
        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
        batch_size = pixel_values.shape[0]
        # NOTE(review): patch_resolution is derived from the configured image size, so this
        # reshape presumably breaks for inputs of a different size even when
        # interpolate_pos_encoding=True — verify against upstream before relying on it.
        patch_resolution = self.config.image_size // self.config.patch_size
        # Drop the [CLS] token and fold the patch sequence back into a 2D feature map.
        features = [
            x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
        ]

        # apply FPNs
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        logits = self.decode_head(features)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(features)

        loss = None
        if labels is not None:
            loss = self.compute_loss(logits, auxiliary_logits, labels)

        if not return_dict:
            # The backbone was always called with output_hidden_states=True, so the
            # tuple slicing below decides whether hidden states are exposed to the caller.
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
1442
+
1443
+
1444
# Public API of this module, consumed by `from ... import *` and the lazy import machinery.
__all__ = [
    "Data2VecVisionForImageClassification",
    "Data2VecVisionForSemanticSegmentation",
    "Data2VecVisionModel",
    "Data2VecVisionPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/data2vec/modeling_tf_data2vec_vision.py ADDED
@@ -0,0 +1,1724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 Data2Vec Vision model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import collections.abc
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ import tensorflow as tf
26
+
27
+ from ...activations_tf import get_tf_activation
28
+ from ...modeling_tf_outputs import (
29
+ TFBaseModelOutput,
30
+ TFBaseModelOutputWithPooling,
31
+ TFSemanticSegmenterOutput,
32
+ TFSequenceClassifierOutput,
33
+ )
34
+ from ...modeling_tf_utils import (
35
+ TFModelInputType,
36
+ TFPreTrainedModel,
37
+ TFSequenceClassificationLoss,
38
+ get_initializer,
39
+ keras,
40
+ keras_serializable,
41
+ unpack_inputs,
42
+ )
43
+ from ...tf_utils import shape_list, stable_softmax
44
+ from ...utils import (
45
+ add_code_sample_docstrings,
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from .configuration_data2vec_vision import Data2VecVisionConfig
52
+
53
+
54
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Data2VecVisionConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
# Expected last_hidden_state shape for the doc example
# (presumably 196 patches + 1 [CLS] token — verify against the checkpoint's config).
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
66
+
67
+
68
@dataclass
class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
    """
    Class for outputs of [`TFData2VecVisionModel`].

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
            will be returned.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # All fields default to None so the dataclass can be built incrementally by the model.
    last_hidden_state: Optional[tf.Tensor] = None
    pooler_output: Optional[tf.Tensor] = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
97
+
98
+
99
class TFData2VecVisionDropPath(keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        # Stochastic depth is only active during training; at inference (or when
        # `training` is None) the input passes through unchanged.
        if not training:
            return x
        keep_prob = 1 - self.drop_path
        # One Bernoulli draw per sample, broadcast across all remaining dimensions.
        mask_shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
        binary_mask = tf.floor(keep_prob + tf.random.uniform(mask_shape, 0, 1))
        # Rescale by 1/keep_prob so the expected activation magnitude is preserved.
        return (x / keep_prob) * binary_mask
117
+
118
+
119
class TFData2VecVisionEmbeddings(keras.layers.Layer):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    """

    def __init__(self, config: Data2VecVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings")
        self.num_patches = self.patch_embeddings.num_patches
        # NOTE(review): duplicate of the assignment above — harmless but redundant.
        self.config = config

        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape=None):
        # Weights are created here (not in __init__) following the Keras build pattern;
        # their `name=` arguments determine the checkpoint variable names.
        self.cls_token = self.add_weight(
            shape=(1, 1, self.config.hidden_size),
            initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
            trainable=True,
            name="cls_token",
        )
        # The mask token is only needed for masked-image-modeling style inputs.
        if self.config.use_mask_token:
            self.mask_token = self.add_weight(
                shape=(1, 1, self.config.hidden_size),
                initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
                trainable=True,
                name="mask_token",
            )
        else:
            self.mask_token = None

        # Learned absolute position embeddings cover the patches plus the [CLS] slot.
        if self.config.use_absolute_position_embeddings:
            self.position_embeddings = self.add_weight(
                shape=(1, self.num_patches + 1, self.config.hidden_size),
                initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
                trainable=True,
                name="position_embeddings",
            )
        else:
            self.position_embeddings = None

        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)

    def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor:
        """Embed `pixel_values` into patch tokens, optionally mask positions, and prepend [CLS]."""
        embeddings = self.patch_embeddings(pixel_values)
        batch_size, seq_len, projection_dim = shape_list(embeddings)

        cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1))

        if bool_masked_pos is not None:
            mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim))
            # replace the masked visual tokens by mask_tokens
            w = bool_masked_pos[..., None]
            w = tf.cast(w, mask_tokens.dtype)
            # since TF doesn't support eager tensor assignment
            embeddings = embeddings * (1 - w) + mask_tokens * w

        embeddings = tf.concat([cls_tokens, embeddings], axis=1)
        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)

        return embeddings
189
+
190
+
191
class TFData2VecVisionPatchEmbeddings(keras.layers.Layer):
    """
    Image to Patch Embedding.

    Splits an image into non-overlapping patches and linearly projects each patch to
    `hidden_size` via a strided convolution.
    """

    def __init__(self, config: Data2VecVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # Accept either a single int or an (h, w) pair for both sizes.
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.patch_shape = patch_shape
        self.num_channels = num_channels

        # A conv with kernel == stride == patch_size is equivalent to a per-patch
        # linear projection.
        self.projection = keras.layers.Conv2D(
            filters=hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last",
            kernel_initializer="glorot_uniform",  # following torch.nn.Linear
            bias_initializer="zeros",
            name="projection",
        )

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project NCHW `pixel_values` into a (batch, num_patches, hidden_size) sequence."""
        batch_size, num_channels, height, width = shape_list(pixel_values)
        # Shape validation only runs eagerly; in graph mode the dims may be symbolic.
        if tf.executing_eagerly():
            if num_channels != self.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one set in the"
                    " configuration."
                )
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )

        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        projection = self.projection(pixel_values)

        # Change the 2D spatial dimensions to a single temporal dimension.
        # shape = (batch_size, num_patches, out_channels=embed_dim)
        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])

        return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                # NHWC input with only the channel count known statically.
                self.projection.build([None, None, None, self.num_channels])
+
260
+ class TFData2VecVisionSelfAttention(keras.layers.Layer):
261
+ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
262
+ super().__init__(**kwargs)
263
+
264
+ if config.hidden_size % config.num_attention_heads != 0:
265
+ raise ValueError(
266
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
267
+ f"of attention heads ({config.num_attention_heads})"
268
+ )
269
+
270
+ self.num_attention_heads = config.num_attention_heads
271
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
272
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
273
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
274
+
275
+ self.query = keras.layers.Dense(
276
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
277
+ )
278
+ self.key = keras.layers.Dense(
279
+ units=self.all_head_size,
280
+ kernel_initializer=get_initializer(config.initializer_range),
281
+ name="key",
282
+ use_bias=False,
283
+ )
284
+ self.value = keras.layers.Dense(
285
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
286
+ )
287
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
288
+
289
+ if window_size:
290
+ self.relative_position_bias = TFData2VecVisionRelativePositionBias(
291
+ config, window_size=window_size, name="relative_position_bias"
292
+ )
293
+ else:
294
+ self.relative_position_bias = None
295
+ self.config = config
296
+
297
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
298
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
299
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
300
+
301
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
302
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
303
+
304
    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Run scaled dot-product multi-head self-attention over `hidden_states`.

        Args:
            hidden_states: Input sequence; first dimension is the batch size.
            head_mask: Multiplicative mask applied to the attention probabilities
                (1.0 keeps a head), or `None` to keep all heads.
            output_attentions: If `True`, also return the attention probabilities.
            relative_position_bias: Optional *shared* bias tensor added to the raw
                attention scores (on top of this layer's own bias, if any).
            training: Whether dropout should be active.

        Returns:
            `(attention_output,)`, plus `(attention_probs,)` when `output_attentions`.
        """
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        attention_scores = attention_scores / self.sqrt_att_head_size

        # Add this layer's own relative position bias if present.
        if self.relative_position_bias is not None:
            # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
            # might complain about `Layer.call()` not being invoked properly. In this case this input
            # i.e., 0.0 is not going to be used in any calculations so we're safe.
            attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...]

        # Add shared relative position bias if provided.
        if relative_position_bias is not None:
            attention_scores = attention_scores + relative_position_bias

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        return outputs
355
+
356
    def build(self, input_shape=None):
        # Build the q/k/v projections (and the optional relative position bias)
        # inside their own name scopes so checkpoint weight names stay stable.
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
        if getattr(self, "relative_position_bias", None) is not None:
            with tf.name_scope(self.relative_position_bias.name):
                self.relative_position_bias.build(None)
372
+
373
+
374
class TFData2VecVisionSelfOutput(keras.layers.Layer):
    """
    The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due
    to the layernorm applied before each block.
    """

    def __init__(self, config: Data2VecVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor:
        # `input_tensor` and `gamma` are accepted for interface compatibility only;
        # the residual add happens in the caller.
        projected = self.dense(inputs=hidden_states)
        return self.dropout(inputs=projected, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense_layer = getattr(self, "dense", None)
        if dense_layer is not None:
            with tf.name_scope(dense_layer.name):
                dense_layer.build([None, None, self.config.hidden_size])
402
+
403
+
404
class TFData2VecVisionAttention(keras.layers.Layer):
    """Self-attention followed by its output projection (residual is added by the caller)."""

    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention")
        self.dense_output = TFData2VecVisionSelfOutput(config, name="output")

    def prune_heads(self, heads):
        # Head pruning is not supported for this architecture.
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attn_outputs = self.attention(
            hidden_states=input_tensor,
            head_mask=head_mask,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            training=training,
        )
        projected = self.dense_output(
            hidden_states=attn_outputs[0], input_tensor=input_tensor, training=training
        )
        # Prepend the projected context vector; keep attention maps if requested.
        return (projected,) + attn_outputs[1:]

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (getattr(self, "attention", None), getattr(self, "dense_output", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
446
+
447
+
448
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision
class TFData2VecVisionIntermediate(keras.layers.Layer):
    """Feed-forward expansion: dense projection to `intermediate_size` plus activation."""

    def __init__(self, config: Data2VecVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=config.intermediate_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        # The config may provide either an activation name or a callable directly.
        self.intermediate_act_fn = (
            get_tf_activation(config.hidden_act) if isinstance(config.hidden_act, str) else config.hidden_act
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        return self.intermediate_act_fn(self.dense(inputs=hidden_states))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense_layer = getattr(self, "dense", None)
        if dense_layer is not None:
            with tf.name_scope(dense_layer.name):
                dense_layer.build([None, None, self.config.hidden_size])
476
+
477
+
478
class TFData2VecVisionOutput(keras.layers.Layer):
    """Feed-forward contraction: projects back to `hidden_size` and applies dropout."""

    def __init__(self, config: Data2VecVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        projected = self.dense(inputs=hidden_states)
        return self.dropout(inputs=projected, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense_layer = getattr(self, "dense", None)
        if dense_layer is not None:
            with tf.name_scope(dense_layer.name):
                # Input to this projection is the intermediate (expanded) width.
                dense_layer.build([None, None, self.config.intermediate_size])
501
+
502
+
503
class TFData2VecVisionLayer(keras.layers.Layer):
    """This corresponds to the Block class in the timm implementation.

    Pre-norm transformer block: layernorm -> attention -> (LayerScale) -> drop-path
    residual, then layernorm -> MLP -> (LayerScale) -> drop-path residual.
    """

    def __init__(
        self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0, **kwargs
    ):
        super().__init__(**kwargs)
        self.config = config

        self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention")
        self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate")
        self.data2vec_output = TFData2VecVisionOutput(config, name="output")

        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
        # Using `layers.Activation` instead of `tf.identity` to better control `training`
        # behaviour.
        self.drop_path = (
            TFData2VecVisionDropPath(drop_path_rate, name="drop_path")
            if drop_path_rate > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )
        self.init_values = config.layer_scale_init_value

    def build(self, input_shape: tf.TensorShape = None):
        # LayerScale parameters: per-channel scales for each residual branch,
        # initialised to `config.layer_scale_init_value`; disabled when it is <= 0.
        # NOTE(review): these weights are created before the `self.built` guard,
        # mirroring the original ordering — do not move them below the guard.
        if self.init_values > 0:
            self.lambda_1 = self.add_weight(
                shape=(self.config.hidden_size),
                initializer="ones",
                trainable=True,
                name="lambda_1",
            )
            self.lambda_2 = self.add_weight(
                shape=(self.config.hidden_size),
                initializer="ones",
                trainable=True,
                name="lambda_2",
            )
            self.lambda_1.assign(self.init_values * tf.ones((self.config.hidden_size)))
            self.lambda_2.assign(self.init_values * tf.ones((self.config.hidden_size)))
        else:
            self.lambda_1, self.lambda_2 = None, None

        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "data2vec_output", None) is not None:
            with tf.name_scope(self.data2vec_output.name):
                self.data2vec_output.build(None)
        if getattr(self, "layernorm_before", None) is not None:
            with tf.name_scope(self.layernorm_before.name):
                self.layernorm_before.build([None, None, self.config.hidden_size])
        if getattr(self, "layernorm_after", None) is not None:
            with tf.name_scope(self.layernorm_after.name):
                self.layernorm_after.build([None, None, self.config.hidden_size])
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Apply one transformer block; returns `(hidden_states,)` plus optional attentions."""
        self_attention_outputs = self.attention(
            # in Data2VecVision, layernorm is applied before self-attention
            input_tensor=self.layernorm_before(inputs=hidden_states),
            head_mask=head_mask,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            training=training,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # apply lambda_1 (LayerScale) if present
        if self.lambda_1 is not None:
            attention_output = self.lambda_1 * attention_output

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # in Data2VecVision, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)

        layer_output = self.intermediate(layer_output)
        layer_output = self.data2vec_output(layer_output)

        if self.lambda_2 is not None:
            layer_output = self.lambda_2 * layer_output

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
609
+
610
+
611
# Taken and modified from here:
# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28
class TFData2VecVisionRelativePositionBias(keras.layers.Layer):
    """Learnable relative position bias for windowed attention, including [CLS] entries."""

    def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config

        self.window_size = window_size
        # +3 for cls_token_pos_len
        # window_size can be something like (14, 14)
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3

        # Precompute the lookup indices once; they depend only on window_size.
        self.relative_position_index = self.get_position_index()

    def build(self, input_shape):
        self.relative_position_bias_table = self.add_weight(
            shape=(self.num_relative_distance, self.config.num_attention_heads),
            initializer="zeros",
            trainable=True,
            name="relative_position_bias_table",
        )  # [2*Wh-1 * 2*Ww-1, nH]
        # cls to token & token 2 cls & cls to cls

        super().build(input_shape)

    def get_position_index(self):
        """Build the [Wh*Ww + 1, Wh*Ww + 1] index map into the bias table."""
        # get pair-wise relative position index for each token inside the window
        xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1]))
        coords = tf.stack([yy, xx], axis=0)  # [2, Wh, Ww]
        coords_flatten = tf.reshape(coords, [2, -1])  # [2, Wh*Ww]

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Wh*Ww, Wh*Ww]
        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])  # [Wh*Ww, Wh*Ww, 2]

        # Shift both axes to start from 0 and flatten the 2-D offset into one index.
        xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
        yy = relative_coords[:, :, 1] + self.window_size[1] - 1
        relative_coords = tf.stack([xx, yy], axis=-1)

        relative_position_index = tf.reduce_sum(relative_coords, axis=-1)  # [Wh*Ww, Wh*Ww]

        # The last three table rows are reserved for cls->token, token->cls and cls->cls.
        top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * (
            self.num_relative_distance - 3
        )
        left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * (
            self.num_relative_distance - 2
        )
        corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1)

        left_corner = tf.concat([corner, left], axis=0)
        relative_position_index = tf.concat([top, relative_position_index], axis=0)
        relative_position_index = tf.concat([left_corner, relative_position_index], axis=1)  # [Wh*Ww + 1, Wh*Ww + 1]
        return relative_position_index

    def call(self, inputs=None) -> tf.Tensor:
        # Gather per-pair biases and move heads to the leading axis: [nH, N, N].
        relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0)
        return tf.transpose(relative_position_bias, [2, 0, 1])
667
+
668
+
669
class TFData2VecVisionEncoder(keras.layers.Layer):
    """Stack of `TFData2VecVisionLayer` blocks with optional shared relative position bias."""

    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        if config.use_shared_relative_position_bias:
            # One bias shared across every layer (as opposed to per-layer biases).
            self.relative_position_bias = TFData2VecVisionRelativePositionBias(
                config, window_size=window_size, name="relative_position_bias"
            )
        else:
            self.relative_position_bias = None

        # stochastic depth decay rule
        dpr = list(tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers))
        self.layer = [
            TFData2VecVisionLayer(
                config,
                window_size=window_size if config.use_relative_position_bias else None,
                drop_path_rate=dpr[i],
                name=f"layer_._{i}",
            )
            for i in range(config.num_hidden_layers)
        ]

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, TFBaseModelOutput]:
        """Run all layers; optionally collect per-layer hidden states and attentions."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
            # might complain about `Layer.call()` not being invoked properly. In this case this input
            # i.e., 0.0 is not going to be used in any calculations so we're safe.
            relative_position_bias = (
                self.relative_position_bias(0.0) if self.relative_position_bias is not None else None
            )
            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "relative_position_bias", None) is not None:
            with tf.name_scope(self.relative_position_bias.name):
                self.relative_position_bias.build(None)
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)
744
+
745
+
746
@keras_serializable
class TFData2VecVisionMainLayer(keras.layers.Layer):
    """Core Data2VecVision stack: patch embeddings -> encoder -> (layernorm) -> optional pooler."""

    config_class = Data2VecVisionConfig

    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.add_pooling_layer = add_pooling_layer

        self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings")
        self.encoder = TFData2VecVisionEncoder(
            config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder"
        )
        # With mean pooling the final layernorm lives in the pooler instead.
        self.layernorm = (
            tf.identity
            if config.use_mean_pooling
            else keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        )

        # We are setting the `data_format` like so because from here on we will revert to the
        # NCHW output format
        self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        bool_masked_pos: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFData2VecVisionModelOutputWithPooling]:
        """Embed `pixel_values`, encode them, and optionally pool the result."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return TFData2VecVisionModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "layernorm", None) is not None:
            # `layernorm` may be `tf.identity` (no `name`) when mean pooling is used.
            if hasattr(self.layernorm, "name"):
                with tf.name_scope(self.layernorm.name):
                    self.layernorm.build((None, self.config.hidden_size))
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
853
+
854
+
855
class TFData2VecVisionPooler(keras.layers.Layer):
    """Pools encoder output by mean-pooling patch tokens, or by taking the [CLS] token."""

    def __init__(self, config: Data2VecVisionConfig, **kwargs):
        super().__init__(**kwargs)
        if config.use_mean_pooling:
            self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        else:
            self.layernorm = None
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        if self.layernorm is None:
            # Pool by simply taking the final hidden state of the [CLS] token.
            return hidden_states[:, 0]
        # Mean pool the final hidden states of the patch tokens (everything after [CLS]).
        patch_tokens = hidden_states[:, 1:, :]
        return self.layernorm(tf.reduce_mean(patch_tokens, axis=1))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        norm = getattr(self, "layernorm", None)
        if norm is not None and hasattr(norm, "name"):
            with tf.name_scope(norm.name):
                norm.build((None, self.config.hidden_size))
884
+
885
+
886
class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecVisionConfig
    base_model_prefix = "data2vec_vision"
    main_input_name = "pixel_values"
    # `relative_position_index` buffers are recomputed, not loaded from checkpoints.
    _keys_to_ignore_on_load_unexpected = [r"relative_position_index"]
896
+
897
+
898
# Shared class-level documentation injected into every Data2VecVision model docstring.
DATA2VEC_VISION_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.).

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
938
+
939
# Shared `call()` argument documentation for the Data2VecVision models.
# Fix: removed a stray extra backtick after `False` in the `training` entry,
# which rendered as broken markup in the generated docs.
DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BeitImageProcessor.__call__`] for details.

        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
            in eager mode, in graph mode the value will always be set to True.

        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
966
+
967
+
968
@add_start_docstrings(
    "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
    DATA2VEC_VISION_START_DOCSTRING,
)
class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
    """Thin wrapper around `TFData2VecVisionMainLayer` exposing the bare encoder outputs."""

    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.config = config

        self.data2vec_vision = TFData2VecVisionMainLayer(
            config, add_pooling_layer=add_pooling_layer, name="data2vec_vision"
        )

    def get_input_embeddings(self):
        return self.data2vec_vision.get_input_embeddings()

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFData2VecVisionModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        bool_masked_pos: tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFData2VecVisionModelOutputWithPooling]:
        r"""
        bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # Delegate entirely to the main layer; no head is applied here.
        outputs = self.data2vec_vision(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "data2vec_vision", None) is not None:
            with tf.name_scope(self.data2vec_vision.name):
                self.data2vec_vision.build(None)
1026
+
1027
+
1028
@add_start_docstrings(
    """
    Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
    the final hidden states of the patch tokens) e.g. for ImageNet.
    """,
    DATA2VEC_VISION_START_DOCSTRING,
)
class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss):
    """Data2VecVision encoder with pooling plus a linear classification head."""

    def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        # The pooler supplies the averaged patch representation fed to the classifier.
        self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision")

        # Classifier head
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, tuple]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.data2vec_vision(
            pixel_values=pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Classify from the pooled representation; loss only when labels are given.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        logits = self.classifier(pooled_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "data2vec_vision", None) is not None:
            with tf.name_scope(self.data2vec_vision.name):
                self.data2vec_vision.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
1110
+
1111
+
1112
+ class TFData2VecVisionConvModule(keras.layers.Layer):
1113
+ """
1114
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
1115
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
1116
+
1117
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1118
+ """
1119
+
1120
+ def __init__(
1121
+ self,
1122
+ in_channels: int,
1123
+ out_channels: int,
1124
+ kernel_size: Union[int, Tuple[int, int]],
1125
+ padding: str = "valid",
1126
+ bias: bool = False,
1127
+ dilation: Union[int, Tuple[int, int]] = 1,
1128
+ **kwargs,
1129
+ ) -> None:
1130
+ super().__init__(**kwargs)
1131
+ self.conv = keras.layers.Conv2D(
1132
+ filters=out_channels,
1133
+ kernel_size=kernel_size,
1134
+ padding=padding,
1135
+ use_bias=bias,
1136
+ dilation_rate=dilation,
1137
+ name="conv",
1138
+ )
1139
+ self.bn = keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5)
1140
+ self.activation = tf.nn.relu
1141
+ self.in_channels = in_channels
1142
+ self.out_channels = out_channels
1143
+
1144
+ def call(self, input: tf.Tensor) -> tf.Tensor:
1145
+ output = self.conv(input)
1146
+ output = self.bn(output)
1147
+ output = self.activation(output)
1148
+ return output
1149
+
1150
+ def build(self, input_shape=None):
1151
+ if self.built:
1152
+ return
1153
+ self.built = True
1154
+ if getattr(self, "conv", None) is not None:
1155
+ with tf.name_scope(self.conv.name):
1156
+ self.conv.build([None, None, None, self.in_channels])
1157
+ if getattr(self, "bn", None) is not None:
1158
+ with tf.name_scope(self.bn.name):
1159
+ self.bn.build((None, None, None, self.out_channels))
1160
+
1161
+
1162
+ class TFAdaptiveAvgPool2D(keras.layers.Layer):
1163
+ def __init__(self, output_dims: Tuple[int, int], input_ordering: str = "NHWC", **kwargs):
1164
+ super().__init__(**kwargs)
1165
+ self.output_dims = output_dims
1166
+ self.input_ordering = input_ordering
1167
+ if input_ordering not in ("NCHW", "NHWC"):
1168
+ raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!")
1169
+ self.h_axis = input_ordering.index("H")
1170
+ self.w_axis = input_ordering.index("W")
1171
+
1172
+ def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
1173
+ # Figure out which axis we're pooling on
1174
+ if h_pooling:
1175
+ axis = self.h_axis
1176
+ output_dim = self.output_dims[0]
1177
+ else:
1178
+ axis = self.w_axis
1179
+ output_dim = self.output_dims[1]
1180
+ input_dim = inputs.shape[axis]
1181
+
1182
+ # Figure out the potential pooling windows
1183
+ # This is the key idea - the torch op always uses only two
1184
+ # consecutive pooling window sizes, like 3 and 4. Therefore,
1185
+ # if we pool with both possible sizes, we simply need to gather
1186
+ # the 'correct' pool at each position to reimplement the torch op.
1187
+ small_window = math.ceil(input_dim / output_dim)
1188
+ big_window = small_window + 1
1189
+ if h_pooling:
1190
+ output_dim = self.output_dims[0]
1191
+ small_window_shape = (small_window, 1)
1192
+ big_window_shape = (big_window, 1)
1193
+ else:
1194
+ output_dim = self.output_dims[1]
1195
+ small_window_shape = (1, small_window)
1196
+ big_window_shape = (1, big_window)
1197
+
1198
+ # For resizes to 1, or integer resizes, we can take quick shortcuts
1199
+ if output_dim == input_dim:
1200
+ return inputs
1201
+ elif output_dim == 1:
1202
+ return tf.reduce_mean(inputs, axis=axis, keepdims=True)
1203
+ elif input_dim % output_dim == 0:
1204
+ return tf.nn.avg_pool2d(
1205
+ inputs,
1206
+ ksize=small_window_shape,
1207
+ strides=small_window_shape,
1208
+ padding="VALID",
1209
+ data_format=self.input_ordering,
1210
+ )
1211
+ # When upscaling by an integer factor we can also take a quick shortcut
1212
+ elif output_dim > input_dim and output_dim % input_dim == 0:
1213
+ return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis)
1214
+
1215
+ # For non-integer resizes, we pool with both possible window sizes and concatenate them
1216
+ if output_dim < input_dim:
1217
+ small_pool = tf.nn.avg_pool2d(
1218
+ inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
1219
+ )
1220
+ big_pool = tf.nn.avg_pool2d(
1221
+ inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
1222
+ )
1223
+ both_pool = tf.concat([small_pool, big_pool], axis=axis)
1224
+ else:
1225
+ # When we're actually upscaling instead, then we build the pools a bit differently
1226
+ small_pool = inputs
1227
+ big_pool = tf.nn.avg_pool2d(
1228
+ inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
1229
+ )
1230
+ both_pool = tf.concat([small_pool, big_pool], axis=axis)
1231
+
1232
+ # We compute vectors of the start and end positions for each pooling window
1233
+ # Each (start, end) pair here corresponds to a single output position
1234
+ window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim)
1235
+ window_starts = tf.cast(window_starts, tf.int64)
1236
+ window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim)
1237
+ window_ends = tf.cast(window_ends, tf.int64)
1238
+
1239
+ # pool_selector is a boolean array of shape (output_dim,) where 1 indicates that output position
1240
+ # has a big receptive field and 0 indicates that that output position has a small receptive field
1241
+ pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool)
1242
+
1243
+ # Since we concatenated the small and big pools, we need to do a bit of
1244
+ # pointer arithmetic to get the indices of the big pools
1245
+ small_indices = window_starts
1246
+ big_indices = window_starts + small_pool.shape[axis]
1247
+
1248
+ # Finally, we use the pool_selector to generate a list of indices, one per output position
1249
+ gather_indices = tf.where(pool_selector, big_indices, small_indices)
1250
+
1251
+ # Gathering from those indices yields the final, correct pooling
1252
+ return tf.gather(both_pool, gather_indices, axis=axis)
1253
+
1254
+ def call(self, inputs: tf.Tensor):
1255
+ if self.input_ordering == "NHWC":
1256
+ input_shape = inputs.shape[1:3]
1257
+ else:
1258
+ input_shape = inputs.shape[2:]
1259
+
1260
+ # We break the task down into each possible case
1261
+ # Firstly, if we're resizing down to 1, it's just tf.reduce_mean
1262
+ if self.output_dims[0] == self.output_dims[1] == 1:
1263
+ if self.input_ordering == "NHWC":
1264
+ reduce_dims = [1, 2]
1265
+ else:
1266
+ reduce_dims = [2, 3]
1267
+ return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True)
1268
+ # Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut
1269
+ elif input_shape[0] % self.output_dims[0] == 0 and input_shape[1] % self.output_dims[1] == 0:
1270
+ h_resize = int(input_shape[0] // self.output_dims[0])
1271
+ w_resize = int(input_shape[1] // self.output_dims[1])
1272
+ return tf.nn.avg_pool2d(
1273
+ inputs,
1274
+ ksize=(h_resize, w_resize),
1275
+ strides=(h_resize, w_resize),
1276
+ padding="VALID",
1277
+ data_format=self.input_ordering,
1278
+ )
1279
+ else:
1280
+ # Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut
1281
+ # for dimensions where an integer resize is possible. It can also handle upscaling.
1282
+ h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
1283
+ return self.pseudo_1d_pool(h_pooled, h_pooling=False)
1284
+
1285
+
1286
+ class TFData2VecVisionPyramidPoolingModule(keras.layers.Layer):
1287
+ """
1288
+ Pyramid Pooling Module (PPM) used in PSPNet.
1289
+
1290
+ Args:
1291
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
1292
+ Module.
1293
+ channels (int): Channels after modules, before conv_seg.
1294
+
1295
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1296
+ """
1297
+
1298
+ def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None:
1299
+ super().__init__(**kwargs)
1300
+ self.pool_scales = pool_scales
1301
+ self.in_channels = in_channels
1302
+ self.out_channels = out_channels
1303
+
1304
+ self.layer_list = []
1305
+ for idx, pool_scale in enumerate(pool_scales):
1306
+ pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
1307
+ self.layer_list.append(
1308
+ [
1309
+ TFAdaptiveAvgPool2D(output_dims=pool_scale),
1310
+ TFData2VecVisionConvModule(
1311
+ in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1"
1312
+ ),
1313
+ ]
1314
+ )
1315
+
1316
+ def call(self, x: tf.Tensor) -> List[tf.Tensor]:
1317
+ ppm_outs = []
1318
+ inputs = x
1319
+
1320
+ for ppm in self.layer_list:
1321
+ for layer_module in ppm:
1322
+ ppm_out = layer_module(x)
1323
+ x = ppm_out
1324
+
1325
+ upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear")
1326
+ ppm_outs.append(upsampled_ppm_out)
1327
+ return ppm_outs
1328
+
1329
+ def build(self, input_shape=None):
1330
+ for layer in self.layer_list:
1331
+ for layer_module in layer:
1332
+ with tf.name_scope(layer_module.name):
1333
+ layer_module.build(None)
1334
+
1335
+
1336
+ class TFData2VecVisionUperHead(keras.layers.Layer):
1337
+ """
1338
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
1339
+ [UPerNet](https://arxiv.org/abs/1807.10221).
1340
+
1341
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1342
+ """
1343
+
1344
+ def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
1345
+ super().__init__(**kwargs)
1346
+
1347
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
1348
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
1349
+ self.channels = config.hidden_size
1350
+ self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
1351
+
1352
+ # PSP Module
1353
+ self.psp_modules = TFData2VecVisionPyramidPoolingModule(
1354
+ self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules"
1355
+ )
1356
+ self.bottleneck = TFData2VecVisionConvModule(
1357
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
1358
+ self.channels,
1359
+ kernel_size=3,
1360
+ padding="same",
1361
+ name="bottleneck",
1362
+ )
1363
+ # FPN Module
1364
+ self.lateral_convs = []
1365
+ self.fpn_convs = []
1366
+ for idx, in_channels in enumerate(self.in_channels[:-1]): # skip the top layer
1367
+ l_conv = TFData2VecVisionConvModule(
1368
+ in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}"
1369
+ )
1370
+ fpn_conv = TFData2VecVisionConvModule(
1371
+ in_channels=self.channels,
1372
+ out_channels=self.channels,
1373
+ kernel_size=3,
1374
+ padding="same",
1375
+ name=f"fpn_convs.{idx}",
1376
+ )
1377
+ self.lateral_convs.append(l_conv)
1378
+ self.fpn_convs.append(fpn_conv)
1379
+
1380
+ self.fpn_bottleneck = TFData2VecVisionConvModule(
1381
+ in_channels=len(self.in_channels) * self.channels,
1382
+ out_channels=self.channels,
1383
+ kernel_size=3,
1384
+ padding="same",
1385
+ name="fpn_bottleneck",
1386
+ )
1387
+
1388
+ def psp_forward(self, inputs):
1389
+ x = inputs[-1]
1390
+ psp_outs = [x]
1391
+ psp_outs.extend(self.psp_modules(x))
1392
+ psp_outs = tf.concat(psp_outs, axis=-1)
1393
+ output = self.bottleneck(psp_outs)
1394
+
1395
+ return output
1396
+
1397
+ def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
1398
+ # build laterals
1399
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
1400
+
1401
+ laterals.append(self.psp_forward(encoder_hidden_states))
1402
+
1403
+ # build top-down path
1404
+ used_backbone_levels = len(laterals)
1405
+ for i in range(used_backbone_levels - 1, 0, -1):
1406
+ prev_shape = shape_list(laterals[i - 1])[1:-1]
1407
+ laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear")
1408
+
1409
+ # build outputs
1410
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
1411
+ # append psp feature
1412
+ fpn_outs.append(laterals[-1])
1413
+
1414
+ for i in range(used_backbone_levels - 1, 0, -1):
1415
+ fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear")
1416
+ fpn_outs = tf.concat(fpn_outs, axis=-1)
1417
+ output = self.fpn_bottleneck(fpn_outs)
1418
+ output = self.classifier(output)
1419
+
1420
+ return output
1421
+
1422
+ def build(self, input_shape=None):
1423
+ if self.built:
1424
+ return
1425
+ self.built = True
1426
+ if getattr(self, "classifier", None) is not None:
1427
+ with tf.name_scope(self.classifier.name):
1428
+ self.classifier.build([None, None, None, self.channels])
1429
+ if getattr(self, "psp_modules", None) is not None:
1430
+ with tf.name_scope(self.psp_modules.name):
1431
+ self.psp_modules.build(None)
1432
+ if getattr(self, "bottleneck", None) is not None:
1433
+ with tf.name_scope(self.bottleneck.name):
1434
+ self.bottleneck.build(None)
1435
+ if getattr(self, "fpn_bottleneck", None) is not None:
1436
+ with tf.name_scope(self.fpn_bottleneck.name):
1437
+ self.fpn_bottleneck.build(None)
1438
+ for layer in self.lateral_convs:
1439
+ with tf.name_scope(layer.name):
1440
+ layer.build(None)
1441
+ for layer in self.fpn_convs:
1442
+ with tf.name_scope(layer.name):
1443
+ layer.build(None)
1444
+
1445
+
1446
+ class TFData2VecVisionFCNHead(keras.layers.Layer):
1447
+ """
1448
+ Fully Convolution Networks for Semantic Segmentation. This head is implemented from
1449
+ [FCNNet](https://arxiv.org/abs/1411.4038).
1450
+
1451
+ Args:
1452
+ config (Data2VecVisionConfig): Configuration.
1453
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
1454
+ dilation (int): The dilation rate for convs in the head. Default: 1.
1455
+
1456
+
1457
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1458
+ """
1459
+
1460
+ def __init__(
1461
+ self,
1462
+ config: Data2VecVisionConfig,
1463
+ in_index: int = 2,
1464
+ kernel_size: int = 3,
1465
+ dilation: Union[int, Tuple[int, int]] = 1,
1466
+ **kwargs,
1467
+ ) -> None:
1468
+ super().__init__(**kwargs)
1469
+ self.in_channels = config.hidden_size
1470
+ self.channels = config.auxiliary_channels
1471
+ self.num_convs = config.auxiliary_num_convs
1472
+ self.concat_input = config.auxiliary_concat_input
1473
+ self.in_index = in_index
1474
+
1475
+ convs = []
1476
+ convs.append(
1477
+ TFData2VecVisionConvModule(
1478
+ in_channels=self.in_channels,
1479
+ out_channels=self.channels,
1480
+ kernel_size=kernel_size,
1481
+ padding="same",
1482
+ dilation=dilation,
1483
+ name="convs.0",
1484
+ )
1485
+ )
1486
+ for i in range(self.num_convs - 1):
1487
+ convs.append(
1488
+ TFData2VecVisionConvModule(
1489
+ in_channels=self.channels,
1490
+ out_channels=self.channels,
1491
+ kernel_size=kernel_size,
1492
+ padding="same",
1493
+ dilation=dilation,
1494
+ name=f"conv_module_{i + 2}",
1495
+ )
1496
+ )
1497
+ if self.num_convs == 0:
1498
+ self.convs = [tf.identity]
1499
+ else:
1500
+ self.convs = convs
1501
+ if self.concat_input:
1502
+ self.conv_cat = TFData2VecVisionConvModule(
1503
+ self.in_channels + self.channels,
1504
+ out_channels=self.channels,
1505
+ kernel_size=kernel_size,
1506
+ padding="same",
1507
+ name="conv_cat",
1508
+ )
1509
+
1510
+ self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
1511
+
1512
+ def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
1513
+ # just take the relevant feature maps
1514
+ hidden_states = encoder_hidden_states[self.in_index]
1515
+ output = hidden_states
1516
+ for layer_module in self.convs:
1517
+ output = layer_module(output)
1518
+ if self.concat_input:
1519
+ output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
1520
+ output = self.classifier(output)
1521
+ return output
1522
+
1523
+ def build(self, input_shape=None):
1524
+ if self.built:
1525
+ return
1526
+ self.built = True
1527
+ if getattr(self, "classifier", None) is not None:
1528
+ with tf.name_scope(self.classifier.name):
1529
+ self.classifier.build([None, None, None, self.channels])
1530
+ if getattr(self, "conv_cat", None) is not None:
1531
+ with tf.name_scope(self.conv_cat.name):
1532
+ self.conv_cat.build(None)
1533
+
1534
+
1535
+ @add_start_docstrings(
1536
+ """
1537
+ Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
1538
+ """,
1539
+ DATA2VEC_VISION_START_DOCSTRING,
1540
+ )
1541
+ class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
1542
+ def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
1543
+ super().__init__(config, *inputs, **kwargs)
1544
+ self.num_labels = config.num_labels
1545
+ self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")
1546
+
1547
+ # FPNs
1548
+ self.fpn1 = [
1549
+ keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
1550
+ keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5),
1551
+ keras.layers.Activation("gelu"),
1552
+ keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
1553
+ ]
1554
+ self.fpn2 = [keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]
1555
+
1556
+ self.fpn3 = tf.identity
1557
+ self.fpn4 = keras.layers.MaxPool2D(pool_size=2, strides=2)
1558
+
1559
+ # Semantic segmentation head(s)
1560
+ self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
1561
+ self.auxiliary_head = (
1562
+ TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
1563
+ )
1564
+
1565
+ def compute_loss(self, logits, auxiliary_logits, labels):
1566
+ # upsample logits to the images' original size
1567
+ if len(shape_list(labels)) > 3:
1568
+ label_interp_shape = shape_list(labels)[1:-1]
1569
+ else:
1570
+ label_interp_shape = shape_list(labels)[-2:]
1571
+
1572
+ upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
1573
+ if auxiliary_logits is not None:
1574
+ upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
1575
+ # compute weighted loss
1576
+ loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
1577
+
1578
+ # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
1579
+ # Utility to mask the index to ignore during computing the loss.
1580
+ def masked_loss(real, pred):
1581
+ mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
1582
+ loss_ = loss_fct(real, pred)
1583
+ mask = tf.cast(mask, dtype=loss_.dtype)
1584
+ loss_ *= mask
1585
+ reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
1586
+ return tf.reshape(reduced_masked_loss, (1,))
1587
+
1588
+ main_loss = masked_loss(labels, upsampled_logits)
1589
+ auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
1590
+ loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
1591
+
1592
+ return loss
1593
+
1594
+ @unpack_inputs
1595
+ @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
1596
+ @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
1597
+ def call(
1598
+ self,
1599
+ pixel_values: tf.Tensor | None = None,
1600
+ head_mask: tf.Tensor | None = None,
1601
+ labels: tf.Tensor | None = None,
1602
+ output_attentions: Optional[bool] = None,
1603
+ output_hidden_states: Optional[bool] = None,
1604
+ return_dict: Optional[bool] = None,
1605
+ ) -> Union[tuple, TFSemanticSegmenterOutput]:
1606
+ r"""
1607
+ labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
1608
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
1609
+ config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
1610
+
1611
+ Returns:
1612
+
1613
+ Examples:
1614
+
1615
+ ```python
1616
+ >>> from transformers import AutoImageProcessor, TFData2VecVisionForSemanticSegmentation
1617
+ >>> from PIL import Image
1618
+ >>> import requests
1619
+
1620
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1621
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1622
+
1623
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
1624
+ >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
1625
+
1626
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1627
+ >>> outputs = model(**inputs)
1628
+ >>> # logits are of shape (batch_size, num_labels, height, width)
1629
+ >>> logits = outputs.logits
1630
+ ```"""
1631
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1632
+ output_hidden_states = (
1633
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1634
+ )
1635
+
1636
+ if labels is not None and self.config.num_labels == 1:
1637
+ raise ValueError("The number of labels should be greater than one")
1638
+
1639
+ outputs = self.data2vec_vision(
1640
+ pixel_values,
1641
+ head_mask=head_mask,
1642
+ output_attentions=output_attentions,
1643
+ output_hidden_states=True, # we need the intermediate hidden states
1644
+ return_dict=return_dict,
1645
+ )
1646
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
1647
+
1648
+ # only keep certain features, and reshape
1649
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
1650
+ features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
1651
+ patch_resolution = self.config.image_size // self.config.patch_size
1652
+
1653
+ def reshape_features(x):
1654
+ # We do it this way so TF can always infer the non-batch dims at compile time
1655
+ x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size))
1656
+ return x
1657
+
1658
+ features = [reshape_features(x[:, 1:, :]) for x in features]
1659
+
1660
+ # apply FPNs
1661
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
1662
+ for module in ops[0]:
1663
+ features[0] = module(features[0])
1664
+ features[1] = ops[1][0](features[1])
1665
+ for i in range(len(features[2:])):
1666
+ features[i + 2] = ops[i + 2](features[i + 2])
1667
+
1668
+ logits = self.decode_head(features)
1669
+ # Tranpose the logits to maintain consistency in the output formats.
1670
+ transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])
1671
+
1672
+ auxiliary_logits = None
1673
+ if self.auxiliary_head is not None:
1674
+ auxiliary_logits = self.auxiliary_head(features)
1675
+
1676
+ loss = None
1677
+ if labels is not None:
1678
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
1679
+
1680
+ if not return_dict:
1681
+ if output_hidden_states:
1682
+ output = (logits,) + outputs[1:]
1683
+ else:
1684
+ output = (logits,) + outputs[2:]
1685
+ return ((loss,) + output) if loss is not None else output
1686
+
1687
+ return TFSemanticSegmenterOutput(
1688
+ loss=loss,
1689
+ logits=transposed_logits,
1690
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1691
+ attentions=outputs.attentions,
1692
+ )
1693
+
1694
+ def build(self, input_shape=None):
1695
+ if self.built:
1696
+ return
1697
+ self.built = True
1698
+ if getattr(self, "data2vec_vision", None) is not None:
1699
+ with tf.name_scope(self.data2vec_vision.name):
1700
+ self.data2vec_vision.build(None)
1701
+ if getattr(self, "decode_head", None) is not None:
1702
+ with tf.name_scope(self.decode_head.name):
1703
+ self.decode_head.build(None)
1704
+ if getattr(self, "auxiliary_head", None) is not None:
1705
+ with tf.name_scope(self.auxiliary_head.name):
1706
+ self.auxiliary_head.build(None)
1707
+ if getattr(self, "fpn1", None) is not None:
1708
+ with tf.name_scope(self.fpn1[0].name):
1709
+ self.fpn1[0].build([None, None, None, self.config.hidden_size])
1710
+ with tf.name_scope(self.fpn1[1].name):
1711
+ self.fpn1[1].build((None, None, None, self.config.hidden_size))
1712
+ with tf.name_scope(self.fpn1[3].name):
1713
+ self.fpn1[3].build([None, None, None, self.config.hidden_size])
1714
+ if getattr(self, "fpn2", None) is not None:
1715
+ with tf.name_scope(self.fpn2[0].name):
1716
+ self.fpn2[0].build([None, None, None, self.config.hidden_size])
1717
+
1718
+
1719
+ __all__ = [
1720
+ "TFData2VecVisionForImageClassification",
1721
+ "TFData2VecVisionForSemanticSegmentation",
1722
+ "TFData2VecVisionModel",
1723
+ "TFData2VecVisionPreTrainedModel",
1724
+ ]
docs/transformers/build/lib/transformers/models/data2vec/modular_data2vec_audio.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from ...activations import ACT2FN
7
+ from ...modeling_outputs import (
8
+ CausalLMOutput,
9
+ SequenceClassifierOutput,
10
+ TokenClassifierOutput,
11
+ Wav2Vec2BaseModelOutput,
12
+ XVectorOutput,
13
+ )
14
+ from ...modeling_utils import PreTrainedModel
15
+ from ...utils import (
16
+ add_code_sample_docstrings,
17
+ add_start_docstrings,
18
+ add_start_docstrings_to_model_forward,
19
+ )
20
+ from ..wav2vec2.modeling_wav2vec2 import (
21
+ Wav2Vec2Adapter,
22
+ Wav2Vec2Encoder,
23
+ Wav2Vec2FeatureEncoder,
24
+ Wav2Vec2FeatureProjection,
25
+ Wav2Vec2ForAudioFrameClassification,
26
+ Wav2Vec2ForCTC,
27
+ Wav2Vec2ForSequenceClassification,
28
+ Wav2Vec2ForXVector,
29
+ Wav2Vec2Model,
30
+ Wav2Vec2PreTrainedModel,
31
+ Wav2Vec2SamePadLayer,
32
+ )
33
+ from .configuration_data2vec_audio import Data2VecAudioConfig
34
+
35
+
36
+ _HIDDEN_STATES_START_POSITION = 2
37
+
38
+ # General docstring
39
+ _CONFIG_FOR_DOC = "Data2VecAudioConfig"
40
+
41
+ # Base docstring
42
+ _CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
43
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
44
+
45
+ # CTC docstring
46
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
47
+ _CTC_EXPECTED_LOSS = 66.95
48
+
49
+
50
+ class Data2VecAudioConvLayer(nn.Module):
51
+ def __init__(self, config, layer_id=0):
52
+ super().__init__()
53
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
54
+ self.out_conv_dim = config.conv_dim[layer_id]
55
+
56
+ self.conv = nn.Conv1d(
57
+ self.in_conv_dim,
58
+ self.out_conv_dim,
59
+ kernel_size=config.conv_kernel[layer_id],
60
+ stride=config.conv_stride[layer_id],
61
+ bias=config.conv_bias,
62
+ )
63
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
64
+ self.activation = ACT2FN[config.feat_extract_activation]
65
+
66
+ def forward(self, hidden_states):
67
+ hidden_states = self.conv(hidden_states)
68
+
69
+ hidden_states = hidden_states.transpose(-2, -1)
70
+ hidden_states = self.layer_norm(hidden_states)
71
+ hidden_states = hidden_states.transpose(-2, -1)
72
+
73
+ hidden_states = self.activation(hidden_states)
74
+ return hidden_states
75
+
76
+
77
+ class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer):
78
+ pass
79
+
80
+
81
+ class Data2VecAudioPositionalConvLayer(nn.Module):
82
+ def __init__(self, config):
83
+ super().__init__()
84
+ self.conv = nn.Conv1d(
85
+ config.hidden_size,
86
+ config.hidden_size,
87
+ kernel_size=config.conv_pos_kernel_size,
88
+ padding=config.conv_pos_kernel_size // 2,
89
+ groups=config.num_conv_pos_embedding_groups,
90
+ )
91
+
92
+ self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
93
+ self.activation = ACT2FN[config.feat_extract_activation]
94
+ # no learnable parameters
95
+ self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
96
+
97
+ def forward(self, hidden_states):
98
+ hidden_states = self.conv(hidden_states)
99
+ hidden_states = self.padding(hidden_states)
100
+
101
+ hidden_states = hidden_states.transpose(1, 2)
102
+ hidden_states = self.layer_norm(hidden_states)
103
+ hidden_states = hidden_states.transpose(1, 2)
104
+ hidden_states = self.activation(hidden_states)
105
+ return hidden_states
106
+
107
+
108
+ class Data2VecAudioPositionalConvEmbedding(nn.Module):
109
+ def __init__(self, config):
110
+ super().__init__()
111
+ self.layers = nn.ModuleList(
112
+ [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
113
+ )
114
+
115
+ def forward(self, hidden_states):
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ for layer in self.layers:
118
+ hidden_states = layer(hidden_states)
119
+ hidden_states = hidden_states.transpose(1, 2)
120
+ return hidden_states
121
+
122
+
123
+ class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder, nn.Module):
124
+ def __init__(self, config):
125
+ nn.Module.__init__()
126
+ self.conv_layers = nn.ModuleList(
127
+ [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
128
+ )
129
+ self.gradient_checkpointing = False
130
+ self._requires_grad = True
131
+
132
+
133
+ class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection):
134
+ pass
135
+
136
+
137
+ class Data2VecAudioEncoder(Wav2Vec2Encoder):
138
+ pass
139
+
140
+
141
+ class Data2VecAudioAdapter(Wav2Vec2Adapter):
142
+ pass
143
+
144
+
145
class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecAudioConfig
    base_model_prefix = "data2vec_audio"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Module-specific initialization; order matters because the first two
        # branches target composite Data2VecAudio modules before the generic
        # nn.Linear / nn.Conv1d fallbacks.
        if isinstance(module, Data2VecAudioFeatureProjection):
            # Uniform init scaled by fan-in of the projection layer.
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, Data2VecAudioPositionalConvLayer):
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # Identity-normalization start: zero shift, unit scale.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                # Uniform bias init scaled by the conv layer's effective fan-in.
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    # The three adapter hooks below exist in the Wav2Vec2 base class; raising
    # AttributeError removes them from the Data2VecAudio public surface.
    def _get_adapters(self):
        raise AttributeError("Not needed for Data2VecAudio")

    def init_adapter_layers(self):
        raise AttributeError("Not needed for Data2VecAudio")

    def load_adapter(self):
        raise AttributeError("Not needed for Data2VecAudio")
191
+
192
+
193
# Shared class-level documentation injected via `@add_start_docstrings`.
# Fix: "all its model" -> "all its models".
DATA2VEC_AUDIO_START_DOCSTRING = r"""
    Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
    Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
    Michael Auli.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
210
+
211
# Shared forward-method documentation injected via `@add_start_docstrings_to_model_forward`.
# Fix: duplicated word "that that" -> "that".
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
            soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
            True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that even with
            `attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
            because there are more than one convolutional layer in the positional encodings. For a more detailed
            explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).

            </Tip>

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
246
+
247
# Data2VecAudio reuses the Wav2Vec2 base-model output class unchanged.
Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
248
+
249
+
250
@add_start_docstrings(
    "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
    def __init__(self, config: Data2VecAudioConfig):
        # Fix: an unbound `__init__` call must pass the instance explicitly; the
        # original `Data2VecAudioPreTrainedModel.__init__(config)` raises a
        # TypeError (missing `self`). `super()` is intentionally avoided so the
        # Wav2Vec2Model constructor does not run.
        Data2VecAudioPreTrainedModel.__init__(self, config)
        self.config = config
        self.feature_extractor = Data2VecAudioFeatureEncoder(config)
        self.feature_projection = Data2VecAudioFeatureProjection(config)

        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        self.encoder = Data2VecAudioEncoder(config)

        self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        # Deprecated Wav2Vec2 hook; explicitly removed from the Data2VecAudio API.
        raise AttributeError("Not needed for Data2VecAudio")

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.feature_extractor._freeze_parameters()

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Data2VecAudioBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(self, **super_kwargs):
        # Delegates to the Wav2Vec2 implementation; this override exists only to
        # attach Data2VecAudio-specific documentation to the method.
        return super().forward(**super_kwargs)
292
+
293
+
294
@add_start_docstrings(
    """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
    def __init__(self, config):
        # Fix: an unbound `__init__` call must pass the instance explicitly; the
        # original `Data2VecAudioPreTrainedModel.__init__(config)` raises a
        # TypeError (missing `self`). `super()` is intentionally avoided so the
        # Wav2Vec2ForCTC constructor does not run.
        Data2VecAudioPreTrainedModel.__init__(self, config)

        self.data2vec_audio = Data2VecAudioModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # When an adapter is configured, the LM head consumes its output size
        # instead of the encoder hidden size.
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    # Both hooks exist on the Wav2Vec2 base class; raising AttributeError removes
    # them from the Data2VecAudio public surface.
    def freeze_base_model(self):
        raise AttributeError("Not needed for Data2VecAudio")

    def tie_weights(self):
        raise AttributeError("Not needed for Data2VecAudio")

    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(self, **super_kwargs):
        # Delegates to the Wav2Vec2 implementation; this override exists only to
        # attach Data2VecAudio-specific documentation to the method.
        return super().forward(**super_kwargs)
336
+
337
+
338
@add_start_docstrings(
    """
    Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification):
    # Inherits the full Wav2Vec2 implementation; only forward's docs are re-bound.
    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(self, **super_kwargs):
        # Delegates to the Wav2Vec2 implementation; this override exists only to
        # attach Data2VecAudio-specific documentation to the method.
        return super().forward(**super_kwargs)
355
+
356
+
357
@add_start_docstrings(
    """
    Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
    # Inherits the full Wav2Vec2 implementation; only forward's docs are re-bound.
    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(self, **super_kwargs):
        # Delegates to the Wav2Vec2 implementation; this override exists only to
        # attach Data2VecAudio-specific documentation to the method.
        return super().forward(**super_kwargs)
373
+
374
+
375
@add_start_docstrings(
    """
    Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    DATA2VEC_AUDIO_START_DOCSTRING,
)
class Data2VecAudioForXVector(Wav2Vec2ForXVector):
    # Inherits the full Wav2Vec2 implementation; only forward's docs are re-bound.
    @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(self, **super_kwargs):
        # Delegates to the Wav2Vec2 implementation; this override exists only to
        # attach Data2VecAudio-specific documentation to the method.
        return super().forward(**super_kwargs)
391
+
392
+
393
# Public symbols exported by this module.
__all__ = [
    "Data2VecAudioForAudioFrameClassification",
    "Data2VecAudioForCTC",
    "Data2VecAudioForSequenceClassification",
    "Data2VecAudioForXVector",
    "Data2VecAudioModel",
    "Data2VecAudioPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/dbrx/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static type checkers see the eager star-imports ...
    from .configuration_dbrx import *
    from .modeling_dbrx import *
else:
    import sys

    # ... while at runtime the package is replaced by a lazy module that imports
    # the heavy submodules only on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/dbrx/configuration_dbrx.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DBRX model configuration"""
16
+
17
+ from typing import Any, Optional
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration for the [`DbrxAttention`] layers.

    Instantiates attention layers according to the given arguments, which define
    the layer architecture. Inherits from [`PretrainedConfig`]; see its
    documentation for the common configuration machinery.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            Dropout probability applied inside the attention layers.
        clip_qkv (`float`, *optional*):
            When set, queries, keys and values are clipped to this magnitude.
        kv_n_heads (`int`, *optional*, defaults to 1):
            Number of key/value heads (grouped-query attention only).
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Base frequency used by the rotary position embedding.
    """

    base_config_key = "attn_config"

    def __init__(
        self,
        attn_pdrop: float = 0.0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        # Bookkeeping keys injected by the serialization machinery are harmless;
        # anything left over after removing them is a genuine user error.
        for ignored_key in ("model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"):
            kwargs.pop(ignored_key, None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")
65
+
66
+
67
class DbrxFFNConfig(PretrainedConfig):
    """Configuration for the [`DbrxFFN`] layers.

    Instantiates feedforward layers according to the given arguments, which define
    the layer architecture. Inherits from [`PretrainedConfig`]; see its
    documentation for the common configuration machinery.

    Args:
        ffn_act_fn (`dict`, *optional*, defaults to `None`):
            Dict with key `'name'` naming the activation function, plus any extra
            keyword arguments for it. `None` means `{"name": "silu"}`.
        ffn_hidden_size (`int`, *optional*, defaults to 3584):
            Hidden size of the feedforward network.
        moe_num_experts (`int`, *optional*, defaults to 4):
            Number of experts in the mixture-of-experts layer.
        moe_top_k (`int`, *optional*, defaults to 1):
            Number of experts each token is routed to.
        moe_jitter_eps (`float`, *optional*, defaults to `None`):
            Jitter epsilon for the mixture-of-experts layer, if not `None`.
        moe_loss_weight (`float`, *optional*, defaults to 0.01):
            Weight of the auxiliary load-balancing loss.
        moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0):
            Normalization factor for the expert weights.
    """

    base_config_key = "ffn_config"

    def __init__(
        self,
        ffn_act_fn: dict = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1.0,
        **kwargs: Any,
    ):
        super().__init__()
        # Default activation is SiLU when none is supplied.
        self.ffn_act_fn = ffn_act_fn if ffn_act_fn is not None else {"name": "silu"}
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights

        # Bookkeeping keys injected by the serialization machinery are harmless;
        # anything left over after removing them is a genuine user error.
        for ignored_key in ("model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"):
            kwargs.pop(ignored_key, None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")
117
+
118
+
119
class DbrxConfig(PretrainedConfig):
    r"""

    This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a different configuration to that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        d_model (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 2048):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details.


    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig}
    # Map the canonical transformers attribute names onto DBRX's native names.
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        **kwargs: Any,
    ):
        # Sub-configs may arrive as ready-made objects or as plain dicts
        # (e.g. deserialized from JSON); normalize both to config objects.
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        # Mirror the attention sub-config's kv head count at the top level.
        self.num_key_value_heads = self.attn_config.kv_n_heads

        # DBRX never ties input/output embeddings; reject an explicit request.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for DBRX models.")

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
230
+
231
+
232
# Public symbols exported by this module.
__all__ = ["DbrxConfig"]
docs/transformers/build/lib/transformers/models/dbrx/modeling_dbrx.py ADDED
@@ -0,0 +1,1392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DBRX model."""
16
+
17
+ import math
18
+ from typing import Any, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+
24
+ from ...activations import ACT2FN
25
+ from ...cache_utils import Cache, DynamicCache, StaticCache
26
+ from ...generation import GenerationMixin
27
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
28
+ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
29
+ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
30
+ from ...modeling_utils import PreTrainedModel
31
+ from ...utils import (
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ is_torch_flex_attn_available,
35
+ logging,
36
+ replace_return_docstrings,
37
+ )
38
+ from .configuration_dbrx import DbrxConfig
39
+
40
+
41
+ if is_torch_flex_attn_available():
42
+ from torch.nn.attention.flex_attention import BlockMask
43
+
44
+ from ...integrations.flex_attention import make_flex_block_causal_mask
45
+
46
+
47
+ if is_flash_attn_available():
48
+ from ...modeling_flash_attention_utils import _flash_attention_forward
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "DbrxConfig"
53
+
54
+
55
class DbrxRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) helper for DBRX attention.

    Precomputes inverse frequencies for `dim` rotary channels and, given position
    ids, returns the cos/sin tables consumed by `apply_rotary_pos_emb`.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inv_freq[i] = base^(-2i/dim); non-persistent so it moves with the module
        # but is not stored in checkpoints.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its device/dtype are used.
        # Fix: `self.inv_freq.to(x.device)` returns a new tensor and was discarded
        # (a no-op); move the expanded tensor to the input's device instead.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        # Return in the input's dtype: [bs, seq_len, dim] each.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
82
+
83
+
84
def rotate_half(x):
    """Rotate half the hidden dims of the input (RoPE helper)."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
90
+
91
+
92
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they broadcast
            against `q`/`k`. With `cos`/`sin` of shape [batch_size, seq_len, head_dim],
            use 1 when `q`/`k` are [batch_size, heads, seq_len, head_dim] and 2 when
            they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the query and key tensors rotated with RoPE.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Equivalent to the module-level rotate_half, inlined here.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = q * cos + _rotate(q) * sin
    k_embed = k * cos + _rotate(k) * sin
    return q_embed, k_embed
118
+
119
+
120
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate each key/value head `n_rep` times along the head dimension, turning
    (batch, num_key_value_heads, seqlen, head_dim) into
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). Equivalent to
    torch.repeat_interleave(x, dim=1, repeats=n_rep).
    """
    if n_rep == 1:
        # Nothing to do; return the input unchanged.
        return hidden_states
    return hidden_states.repeat_interleave(n_rep, dim=1)
131
+
132
+
133
def load_balancing_loss_func(
    gate_logits: torch.Tensor,
    num_experts: int,
    top_k: int,
    attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts (`int`):
            Number of experts.
        top_k (`int`):
            The number of experts each token is routed to.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    # No router logits available (e.g. output_router_logits disabled) -> zero loss.
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return torch.tensor(0.0)

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        # Concatenate all layers' logits on one device:
        # [num_layers * batch * seq, num_experts].
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # One-hot routing decisions: [tokens, top_k, num_experts].
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        # Infers the layer count from the concatenated token dimension.
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    # Dot product of routed fractions and mean router probabilities (eq. 4-6).
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
211
+
212
+
213
class DbrxAttention(nn.Module):
    """Multi-head self attention.

    Eager (matmul + softmax) implementation of DBRX grouped-query attention with
    rotary position embeddings and an optional symmetric clamp on the fused QKV
    projection output (`clip_qkv`).
    """

    def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        self.num_heads = config.n_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_seq_len
        # Index of the enclosing block; used as the layer key when updating the KV cache.
        self.block_idx = block_idx
        if block_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will "
                + "lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` "
                + "when creating this class."
            )

        attn_config = config.attn_config
        self.attn_pdrop = attn_config.attn_pdrop
        # Optional symmetric clamp bound applied to the QKV projection output (None disables it).
        self.clip_qkv = attn_config.clip_qkv
        self.num_key_value_heads = attn_config.kv_n_heads
        # Grouped-query attention: each KV head serves `num_key_value_groups` query heads.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.rope_theta = attn_config.rope_theta
        self.is_causal = True

        # Fused projection producing Q (hidden_size) followed by K and V
        # (num_key_value_heads * head_dim each).
        self.Wqkv = nn.Linear(
            self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
        )
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rotary_emb = DbrxRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run eager attention over `hidden_states`.

        Returns `(attn_output, attn_weights or None, past_key_value)`; `attn_weights`
        is only kept when `output_attentions=True`.
        """
        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.Wqkv(hidden_states)
        # When clip_qkv is None both bounds are None and clamp is a no-op; behaviorally
        # equivalent to the `if self.clip_qkv is not None:` guard used by the other variants.
        min_val = -self.clip_qkv if self.clip_qkv is not None else None
        max_val = self.clip_qkv
        qkv_states = qkv_states.clamp(min=min_val, max=max_val)

        # Split the fused projection back into Q / K / V along the feature dim.
        query_states, key_states, value_states = qkv_states.split(
            [
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            ],
            dim=2,
        )

        # (bsz, q_len, heads*head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)

        # Expand KV heads so each query head has a matching key/value head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                + f" {attn_output.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size), then output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
316
+
317
+
318
class DbrxFlashAttention2(DbrxAttention):
    """Dbrx flash attention module.

    This module inherits from `DbrxAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it
    calls the public API of flash attention.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run flash attention. `attn_weights` is always None (FA does not expose them)."""
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )
        # NOTE(review): this logs unconditionally on every forward call, not just when the
        # caller asked for attentions — matches the original code but is noisy.
        logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        query_states, key_states, value_states = qkv_states.split(
            [
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            ],
            dim=2,
        )

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attn_pdrop if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = query_states.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be "
                + "related to the fact you have upcasted embedding or layer norm layers in "
                + f"float32. We will cast back the input in {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # KV-head repetition for GQA is handled inside the flash kernel, not here.
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.out_proj(attn_output)

        # output_attentions is forced False above, so attn_weights is always None here.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
436
+
437
+
438
class DbrxSdpaAttention(DbrxAttention):
    """
    Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DbrxAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run SDPA attention; falls back to the eager path when attentions are requested."""
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "DbrxModel is using DbrxSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            # SDPA cannot return attention weights, so delegate to the eager parent implementation.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        query_states, key_states, value_states = qkv_states.split(
            [
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            ],
            dim=2,
        )

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): unlike the eager path, this passes `seq_len=None` / an extra trailing
        # `None` — presumably accepted (and ignored) by the rotary helpers; confirm signatures.
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)

        # Expand KV heads so each query head has a matching key/value head (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attn_pdrop if self.training else 0.0,
            is_causal=is_causal,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size), then output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.out_proj(attn_output)

        # SDPA never exposes attention weights, hence the explicit None.
        return attn_output, None, past_key_value
531
+
532
+
533
# Dispatch table mapping `config._attn_implementation` to the attention variant
# instantiated by `DbrxNormAttentionNorm`.
DBRX_ATTENTION_CLASSES = {
    "eager": DbrxAttention,
    "flash_attention_2": DbrxFlashAttention2,
    "sdpa": DbrxSdpaAttention,
}
538
+
539
+
540
class DbrxNormAttentionNorm(nn.Module):
    """Pre-norm attention sub-block: LayerNorm -> attention -> dropout -> residual -> LayerNorm.

    Returns both the post-residual stream (used by the caller as the FFN residual) and its
    normalized counterpart (the FFN input), plus the attention weights and cache.
    """

    def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
        super().__init__()
        self.block_idx = block_idx
        self.resid_pdrop = config.resid_pdrop
        self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
        # Pick the attention implementation requested by the config (eager / flash / sdpa).
        attn_cls = DBRX_ATTENTION_CLASSES[config._attn_implementation]
        self.attn = attn_cls(config=config, block_idx=block_idx)
        self.norm_2 = nn.LayerNorm(config.d_model, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        residual = hidden_states
        normed = self.norm_1(hidden_states).to(hidden_states.dtype)

        attn_out, attn_weights, past_key_value = self.attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # Residual connection around attention, with residual dropout applied first.
        attn_out = nn.functional.dropout(attn_out, p=self.resid_pdrop, training=self.training)
        post_attn = attn_out + residual

        # The second norm feeds the FFN; the un-normed stream is returned as the FFN residual.
        normed_post_attn = self.norm_2(post_attn).to(post_attn.dtype)

        return post_attn, normed_post_attn, attn_weights, past_key_value
584
+
585
+
586
class DbrxRouter(nn.Module):
    """Linear router that scores every token against each expert and keeps the top-k.

    Returns the full softmax routing probabilities, the (optionally renormalized)
    top-k weights, and the indices of the selected experts.
    """

    def __init__(
        self,
        hidden_size: int,
        moe_num_experts: int,
        moe_top_k: int,
        moe_jitter_eps: Optional[float],
        moe_normalize_expert_weights: Optional[float],
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
        # Multiplicative jitter on the inputs, training-time only and only when configured.
        if self.training and self.moe_jitter_eps is not None:
            noise = torch.empty_like(hidden_states).uniform_(
                1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
            )
            hidden_states *= noise

        # Flatten (batch, seq, hidden) -> (tokens, hidden).
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        # Softmax in float32 for numerically stable routing probabilities.
        weights = self.layer(flat).softmax(dim=-1, dtype=torch.float32)
        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)

        # Optionally renormalize the selected weights by their p-norm
        # (dividing by 1.0 when disabled is a no-op, so skipping it is equivalent).
        if self.moe_normalize_expert_weights is not None:
            scale = torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
            top_weights = top_weights / scale

        weights = weights.to(flat.dtype)
        top_weights = top_weights.to(flat.dtype)
        return weights, top_weights, top_experts
623
+
624
+
625
class DbrxExpertGLU(nn.Module):
    """Fused GLU expert weights: every expert's w1/v1/w2 stored in flat parameters.

    `forward` applies a single expert's gated MLP given that expert's weight slices.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts

        # One flat parameter per projection covering all experts; sliced per expert by the caller.
        flat_shape = (moe_num_experts * ffn_hidden_size, hidden_size)
        self.w1 = nn.Parameter(torch.empty(*flat_shape))
        self.v1 = nn.Parameter(torch.empty(*flat_shape))
        self.w2 = nn.Parameter(torch.empty(*flat_shape))

        # Activation defaults to SiLU when the config dict does not name one.
        self.activation_fn = ACT2FN[ffn_act_fn.get("name", "silu")]

    def forward(
        self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
    ) -> torch.Tensor:
        # Gated linear unit: act(x @ W1^T) * (x @ V1^T), projected back down with W2.
        gated = self.activation_fn(x.matmul(expert_w1.t()))
        up = x.matmul(expert_v1.t())
        return (gated * up).matmul(expert_w2)
648
+
649
+
650
class DbrxExperts(nn.Module):
    """Runs each selected expert over the tokens routed to it and sums the weighted outputs."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
        super().__init__()
        self.moe_num_experts = moe_num_experts
        self.mlp = DbrxExpertGLU(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            moe_num_experts=moe_num_experts,
            ffn_act_fn=ffn_act_fn,
        )

    def forward(
        self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
    ) -> torch.Tensor:
        # NOTE(review): `weights` (the full routing probabilities) is accepted but never read
        # here; only `top_weights`/`top_experts` drive the computation.
        bsz, q_len, hidden_size = x.shape
        # Flatten to (tokens, hidden) and accumulate each expert's contribution into `out`.
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)

        # After the permute, expert_mask[e, k, t] == 1 iff token t picked expert e in top-k slot k.
        expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        # Chunk experts at once to avoid storing full parameter multiple times in autograd
        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
            self.moe_num_experts, dim=0
        )
        v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
            self.moe_num_experts, dim=0
        )
        w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
            self.moe_num_experts, dim=0
        )
        # Drop the leading singleton expert dim from each chunk: (1, ffn, hidden) -> (ffn, hidden).
        w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
        v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
        w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
        for expert_idx in range(0, self.moe_num_experts):
            # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
            # (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested)
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                # No token routed to this expert in this batch.
                continue

            # NOTE(review): these are straight aliases of token_idx/topk_idx — redundant.
            token_list = token_idx
            topk_list = topk_idx

            # Gather this expert's tokens; `x[None, token_list].reshape(...)` is equivalent
            # to `x[token_list]` here.
            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            # Run the expert MLP and scale each token's output by its routing weight.
            expert_out = (
                self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
                * top_weights[token_list, topk_list, None]
            )

            # Scatter-add the weighted expert output back into the per-token accumulator.
            out.index_add_(0, token_idx, expert_out)

        out = out.reshape(bsz, q_len, hidden_size)
        return out
702
+
703
+
704
class DbrxFFN(nn.Module):
    """Sparse mixture-of-experts feed-forward layer: a router followed by the expert bank."""

    def __init__(self, config: DbrxConfig):
        super().__init__()
        ffn_config = config.ffn_config

        self.router = DbrxRouter(
            hidden_size=config.d_model,
            moe_num_experts=ffn_config.moe_num_experts,
            moe_top_k=ffn_config.moe_top_k,
            moe_jitter_eps=ffn_config.moe_jitter_eps,
            moe_normalize_expert_weights=ffn_config.moe_normalize_expert_weights,
        )
        self.experts = DbrxExperts(
            hidden_size=config.d_model,
            ffn_hidden_size=ffn_config.ffn_hidden_size,
            moe_num_experts=ffn_config.moe_num_experts,
            ffn_act_fn=ffn_config.ffn_act_fn,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route `x` through the experts; returns (output, full router weights)."""
        weights, top_weights, top_experts = self.router(x)
        output = self.experts(x, weights, top_weights, top_experts)
        # The full routing weights double as the router logits surfaced to the caller
        # (consumed by the load-balancing loss).
        return output, weights
728
+
729
+
730
class DbrxBlock(nn.Module):
    """One DBRX transformer block: norm/attention/norm followed by the MoE feed-forward."""

    def __init__(self, config: DbrxConfig, block_idx: int):
        super().__init__()
        self.hidden_size = config.d_model
        self.resid_pdrop = config.resid_pdrop
        self.block_idx = block_idx
        self.norm_attn_norm = DbrxNormAttentionNorm(
            config=config,
            block_idx=block_idx,
        )
        self.ffn = DbrxFFN(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[
        Tuple[torch.Tensor],
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[Cache]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor]],
    ]:
        """Forward function for DbrxBlock.

        Args:
            hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
            attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length)
                if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
                if default attention is used.
            past_key_value (`Tuple(torch.Tensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all
                attention layers. See `attentions` under returned tensors for more detail.
            output_router_logits (`bool`, *optional*): Whether or not to return the router logits.
            use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are
                returned and can be used to speed up decoding (see `past_key_values`).
            cache_position (`torch.LongTensor`, *optional*): position ids of the cache
        """

        # Norm + Attention + Norm
        # `resid_states` is the post-attention residual stream; `hidden_states` is its
        # normalized counterpart, which feeds the FFN.
        resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # Fully Connected
        hidden_states, router_logits = self.ffn(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
        hidden_states = resid_states + hidden_states

        # Assemble the output tuple in the fixed order the caller unpacks:
        # hidden_states [, attn_weights] [, present_key_value] [, router_logits].
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
808
+
809
+
810
# Shared model-class docstring preamble, injected via `@add_start_docstrings`.
# This string is runtime data (it becomes part of the generated docs), so its
# content must stay byte-identical.
DBRX_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DbrxConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
825
+
826
+
827
@add_start_docstrings(
    "The bare DBRX Model outputting raw hidden-states without any specific head on top.",
    DBRX_START_DOCSTRING,
)
class DbrxPreTrainedModel(PreTrainedModel):
    """Base class wiring DBRX models into the library's loading and init machinery."""

    config_class = DbrxConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DbrxBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

    def _init_weights(self, module: nn.Module):
        """Initialize `module` in place: normal(0, initializer_range) weights, zero biases."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, DbrxExpertGLU):
            # Expert weights are raw nn.Parameters (not Linear layers), so handle them explicitly.
            for param in (module.w1, module.v1, module.w2):
                param.data.normal_(mean=0.0, std=std)
861
+
862
+
863
+ DBRX_INPUTS_DOCSTRING = r"""
864
+ Args:
865
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
866
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
867
+ it.
868
+
869
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
870
+ [`PreTrainedTokenizer.__call__`] for details.
871
+
872
+ [What are input IDs?](../glossary#input-ids)
873
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
874
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
875
+
876
+ - 1 for tokens that are **not masked**,
877
+ - 0 for tokens that are **masked**.
878
+
879
+ [What are attention masks?](../glossary#attention-mask)
880
+
881
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
882
+ [`PreTrainedTokenizer.__call__`] for details.
883
+
884
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
885
+ `past_key_values`).
886
+
887
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
888
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
889
+ information on the default strategy.
890
+
891
+ - 1 indicates the head is **not masked**,
892
+ - 0 indicates the head is **masked**.
893
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
894
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
895
+ config.n_positions - 1]`.
896
+
897
+ [What are position IDs?](../glossary#position-ids)
898
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
899
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
900
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
901
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
902
+
903
+ Two formats are allowed:
904
+ - a [`~cache_utils.Cache`] instance, see our
905
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
906
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
907
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
908
+ cache format.
909
+
910
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
911
+ legacy cache format will be returned.
912
+
913
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
914
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
915
+ of shape `(batch_size, sequence_length)`.
916
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
917
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
918
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
919
+ model's internal embedding lookup matrix.
920
+ use_cache (`bool`, *optional*):
921
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
922
+ `past_key_values`).
923
+ output_attentions (`bool`, *optional*):
924
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
925
+ tensors for more detail.
926
+ output_hidden_states (`bool`, *optional*):
927
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
928
+ more detail.
929
+ output_router_logits (`bool`, *optional*):
930
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
931
+ should not be returned during inference.
932
+ return_dict (`bool`, *optional*):
933
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
934
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
935
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
936
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
937
+ the complete sequence length.
938
+ """
939
+
940
+
941
@add_start_docstrings(
    "The bare DBRX Model outputting raw hidden-states without any specific head on top.",
    DBRX_START_DOCSTRING,
)
class DbrxModel(DbrxPreTrainedModel):
    """Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer.

    Args:
        config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: DbrxConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # Dropout probability applied to the token embeddings in forward().
        self.emb_pdrop = config.emb_pdrop

        # Token embedding table; pad_token_id rows receive zero gradient updates.
        self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
        self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
        # Final layer norm applied after the last decoder block (no bias, DBRX convention).
        self.norm_f = nn.LayerNorm(config.d_model, bias=False)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.wte

    def set_input_embeddings(self, value: nn.Embedding):
        self.wte = value

    @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,  # NOOP kwargs, for now
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        # Resolve all output flags against the config defaults when not explicitly passed.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided (XOR check).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        # Embedding dropout (a no-op when emb_pdrop == 0 or in eval mode).
        inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # Derive absolute cache positions for the new tokens when not supplied by the caller.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Gradient-checkpointed path trades compute for memory during training.
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    cache_position,
                )
            else:
                block_outputs = block(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = block_outputs[0]

            # Block output layout depends on the flags: the cache slot shifts by one
            # when attentions are also returned; router logits are always last.
            if use_cache:
                next_decoder_cache = block_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (block_outputs[1],)

            if output_router_logits:
                all_router_logits += (block_outputs[-1],)

        hidden_states = self.norm_f(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            # Convert back to tuple-of-tuples so legacy callers get the format they passed in.
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        # Flash Attention 2 handles causality internally; only forward a mask when
        # there is real padding (a 0.0 entry), otherwise pass None.
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        # Static caches are pre-allocated, so the mask must span the full cache length.
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            # Start from an all-masked (min_dtype) matrix, then open the lower triangle
            # (and any cached positions) by zeroing entries at or before cache_position.
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Positions where both the causal mask allows attention (0) and the padding
                # mask forbids it (0) must be re-masked with min_dtype.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
1241
+
1242
+
1243
@add_start_docstrings("The DBRX Model transformer for causal language modeling.", DBRX_START_DOCSTRING)
class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
    # DBRX decoder with a tied-free LM head on top; also computes the MoE
    # load-balancing auxiliary loss when router logits are requested.
    def __init__(self, config: DbrxConfig):
        super().__init__(config)
        self.transformer = DbrxModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # MoE hyper-parameters used for the auxiliary load-balancing loss.
        self.moe_loss_weight = config.ffn_config.moe_loss_weight
        self.num_experts = config.ffn_config.moe_num_experts
        self.num_experts_per_tok = config.ffn_config.moe_top_k

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.transformer.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Embedding):
        self.transformer.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder: DbrxModel):
        self.transformer = decoder

    def get_decoder(self) -> DbrxModel:
        return self.transformer

    @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >> from transformers import AutoTokenizer, DbrxForCausalLM

        >> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
        >> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")

        >> prompt = "Hey, are you conscious? Can you talk to me?"
        >> inputs = tokenizer(prompt, return_tensors="pt")

        >> # Generate
        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```
        """
        # Resolve output flags against the config defaults when not explicitly passed.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # No upscaling to float was ever done for Dbrx
        # Project only the requested trailing positions through the LM head to save memory.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        aux_loss = None
        if output_router_logits:
            # MoE load-balancing auxiliary loss, computed from the per-layer router logits.
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None and loss is not None:
                loss += self.moe_loss_weight * aux_loss.to(loss.device)  # make sure to reside in the same device

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
1390
+
1391
+
1392
# Public API of this module.
__all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]
docs/transformers/build/lib/transformers/models/deberta/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static type checkers import the real submodules so public symbols resolve.
    from .configuration_deberta import *
    from .modeling_deberta import *
    from .modeling_tf_deberta import *
    from .tokenization_deberta import *
    from .tokenization_deberta_fast import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy so heavy
    # submodules are only imported on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deberta/configuration_deberta.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020, Microsoft and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DeBERTa model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig
22
+ from ...utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class DebertaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DebertaModel`] or a [`TFDebertaModel`]. It is
    used to instantiate a DeBERTa model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the DeBERTa
    [microsoft/deberta-base](https://huggingface.co/microsoft/deberta-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the DeBERTa model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"`, `"gelu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"`
            are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 0):
            The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-7):
            The epsilon used by the layer normalization layers.
        relative_attention (`bool`, *optional*, defaults to `False`):
            Whether use relative position encoding.
        max_relative_positions (`int`, *optional*, defaults to -1):
            The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value
            as `max_position_embeddings`.
        pad_token_id (`int`, *optional*, defaults to 0):
            The value used to pad input_ids.
        position_biased_input (`bool`, *optional*, defaults to `True`):
            Whether add absolute position embedding to content embedding.
        pos_att_type (`List[str]`, *optional*):
            The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`,
            `["p2c", "c2p"]`.
        pooler_dropout (`float`, *optional*, defaults to 0):
            The dropout probability used by the pooler head.
        pooler_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function (function name) used by the pooler head.
        legacy (`bool`, *optional*, defaults to `True`):
            Whether or not the model should use the legacy `LegacyDebertaOnlyMLMHead`, which does not work properly
            for mask infilling tasks.

    Example:

    ```python
    >>> from transformers import DebertaConfig, DebertaModel

    >>> # Initializing a DeBERTa microsoft/deberta-base style configuration
    >>> configuration = DebertaConfig()

    >>> # Initializing a model (with random weights) from the microsoft/deberta-base style configuration
    >>> model = DebertaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deberta"

    def __init__(
        self,
        vocab_size=50265,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=0,
        initializer_range=0.02,
        layer_norm_eps=1e-7,
        relative_attention=False,
        max_relative_positions=-1,
        pad_token_id=0,
        position_biased_input=True,
        pos_att_type=None,
        pooler_dropout=0,
        pooler_hidden_act="gelu",
        legacy=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.relative_attention = relative_attention
        self.max_relative_positions = max_relative_positions
        self.pad_token_id = pad_token_id
        self.position_biased_input = position_biased_input

        # Backwards compatibility
        if isinstance(pos_att_type, str):
            # Older configs stored pos_att_type as a "|"-separated string, e.g. "p2c|c2p".
            pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]

        self.pos_att_type = pos_att_type
        self.vocab_size = vocab_size
        self.layer_norm_eps = layer_norm_eps

        # pooler_hidden_size defaults to hidden_size unless explicitly overridden via kwargs.
        self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
        self.pooler_dropout = pooler_dropout
        self.pooler_hidden_act = pooler_hidden_act
        self.legacy = legacy
159
+
160
+
161
# Mirrors transformers.models.deberta_v2.configuration_deberta_v2.DebertaV2OnnxConfig
class DebertaOnnxConfig(OnnxConfig):
    """ONNX export configuration for DeBERTa models."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Multiple-choice tasks carry an extra "choice" axis between batch and sequence.
        if self.task == "multiple-choice":
            axes = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            axes = {0: "batch", 1: "sequence"}
        # token_type_ids is only a model input when the config defines token types.
        input_names = ["input_ids", "attention_mask"]
        if self._config.type_vocab_size > 0:
            input_names.append("token_type_ids")
        return OrderedDict((name, axes) for name in input_names)

    @property
    def default_onnx_opset(self) -> int:
        # Minimum opset required by the DeBERTa export.
        return 12

    def generate_dummy_inputs(
        self,
        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
        batch_size: int = -1,
        seq_length: int = -1,
        num_choices: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
        num_channels: int = 3,
        image_width: int = 40,
        image_height: int = 40,
        tokenizer: "PreTrainedTokenizerBase" = None,
    ) -> Mapping[str, Any]:
        # Delegate dummy-input creation to the base class, then drop token_type_ids
        # when the model does not use token types (type_vocab_size == 0).
        dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework)
        if self._config.type_vocab_size == 0:
            dummy_inputs.pop("token_type_ids", None)
        return dummy_inputs
197
+
198
+
199
# Public API of this module.
__all__ = ["DebertaConfig", "DebertaOnnxConfig"]
docs/transformers/build/lib/transformers/models/deberta/modeling_deberta.py ADDED
@@ -0,0 +1,1352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DeBERTa model."""
16
+
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_outputs import (
26
+ BaseModelOutput,
27
+ MaskedLMOutput,
28
+ QuestionAnsweringModelOutput,
29
+ SequenceClassifierOutput,
30
+ TokenClassifierOutput,
31
+ )
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
34
+ from .configuration_deberta import DebertaConfig
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+ _CONFIG_FOR_DOC = "DebertaConfig"
39
+ _CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
40
+
41
+ # Masked LM docstring
42
+ _CHECKPOINT_FOR_MASKED_LM = "lsanochkin/deberta-large-feedback"
43
+ _MASKED_LM_EXPECTED_OUTPUT = "' Paris'"
44
+ _MASKED_LM_EXPECTED_LOSS = "0.54"
45
+
46
+ # QuestionAnswering docstring
47
+ _CHECKPOINT_FOR_QA = "Palak/microsoft_deberta-large_squad"
48
+ _QA_EXPECTED_OUTPUT = "' a nice puppet'"
49
+ _QA_EXPECTED_LOSS = 0.14
50
+ _QA_TARGET_START_INDEX = 12
51
+ _QA_TARGET_END_INDEX = 14
52
+
53
+
54
+ class DebertaLayerNorm(nn.Module):
55
+ """LayerNorm module in the TF style (epsilon inside the square root)."""
56
+
57
+ def __init__(self, size, eps=1e-12):
58
+ super().__init__()
59
+ self.weight = nn.Parameter(torch.ones(size))
60
+ self.bias = nn.Parameter(torch.zeros(size))
61
+ self.variance_epsilon = eps
62
+
63
+ def forward(self, hidden_states):
64
+ input_type = hidden_states.dtype
65
+ hidden_states = hidden_states.float()
66
+ mean = hidden_states.mean(-1, keepdim=True)
67
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
68
+ hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
69
+ hidden_states = hidden_states.to(input_type)
70
+ y = self.weight * hidden_states + self.bias
71
+ return y
72
+
73
+
74
+ class DebertaSelfOutput(nn.Module):
75
+ def __init__(self, config):
76
+ super().__init__()
77
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
78
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
79
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
80
+
81
+ def forward(self, hidden_states, input_tensor):
82
+ hidden_states = self.dense(hidden_states)
83
+ hidden_states = self.dropout(hidden_states)
84
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
85
+ return hidden_states
86
+
87
+
88
+ @torch.jit.script
89
+ def build_relative_position(query_layer, key_layer):
90
+ """
91
+ Build relative position according to the query and key
92
+
93
+ We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
94
+ \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
95
+ P_k\\)
96
+
97
+ Args:
98
+ query_size (int): the length of query
99
+ key_size (int): the length of key
100
+
101
+ Return:
102
+ `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
103
+
104
+ """
105
+
106
+ query_size = query_layer.size(-2)
107
+ key_size = key_layer.size(-2)
108
+
109
+ q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
110
+ k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
111
+ rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
112
+ rel_pos_ids = rel_pos_ids[:query_size, :]
113
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
114
+ return rel_pos_ids
115
+
116
+
117
+ @torch.jit.script
118
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
119
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
120
+
121
+
122
+ @torch.jit.script
123
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
124
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
125
+
126
+
127
+ @torch.jit.script
128
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
129
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
130
+
131
+
132
+ ###### To support a general trace, we have to define these operation as they use python objects (sizes) ##################
133
+ # which are not supported by torch.jit.trace.
134
+ # Full credits to @Szustarol
135
+ @torch.jit.script
136
+ def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
137
+ return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
138
+
139
+
140
+ @torch.jit.script
141
+ def build_rpos(query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
142
+ if query_layer.size(-2) != key_layer.size(-2):
143
+ return build_relative_position(query_layer, key_layer)
144
+ else:
145
+ return relative_pos
146
+
147
+
148
+ @torch.jit.script
149
+ def compute_attention_span(query_layer: torch.Tensor, key_layer: torch.Tensor, max_relative_positions: int):
150
+ return torch.tensor(min(max(query_layer.size(-2), key_layer.size(-2)), max_relative_positions))
151
+
152
+
153
+ @torch.jit.script
154
+ def uneven_size_corrected(p2c_att, query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
155
+ if query_layer.size(-2) != key_layer.size(-2):
156
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
157
+ return torch.gather(p2c_att, dim=2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
158
+ else:
159
+ return p2c_att
160
+
161
+
162
+ ########################################################################################################################
163
+
164
+
165
+ class DisentangledSelfAttention(nn.Module):
166
+ """
167
+ Disentangled self-attention module
168
+
169
+ Parameters:
170
+ config (`str`):
171
+ A model config class instance with the configuration to build a new model. The schema is similar to
172
+ *BertConfig*, for more details, please refer [`DebertaConfig`]
173
+
174
+ """
175
+
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ if config.hidden_size % config.num_attention_heads != 0:
179
+ raise ValueError(
180
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
181
+ f"heads ({config.num_attention_heads})"
182
+ )
183
+ self.num_attention_heads = config.num_attention_heads
184
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
185
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
186
+ self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
187
+ self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
188
+ self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
189
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
190
+
191
+ self.relative_attention = getattr(config, "relative_attention", False)
192
+ self.talking_head = getattr(config, "talking_head", False)
193
+
194
+ if self.talking_head:
195
+ self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
196
+ self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
197
+ else:
198
+ self.head_logits_proj = None
199
+ self.head_weights_proj = None
200
+
201
+ if self.relative_attention:
202
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
203
+ if self.max_relative_positions < 1:
204
+ self.max_relative_positions = config.max_position_embeddings
205
+ self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
206
+
207
+ if "c2p" in self.pos_att_type:
208
+ self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
209
+ if "p2c" in self.pos_att_type:
210
+ self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
211
+
212
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
213
+
214
+ def transpose_for_scores(self, x):
215
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
216
+ x = x.view(new_x_shape)
217
+ return x.permute(0, 2, 1, 3)
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ attention_mask: torch.Tensor,
223
+ output_attentions: bool = False,
224
+ query_states: Optional[torch.Tensor] = None,
225
+ relative_pos: Optional[torch.Tensor] = None,
226
+ rel_embeddings: Optional[torch.Tensor] = None,
227
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
228
+ """
229
+ Call the module
230
+
231
+ Args:
232
+ hidden_states (`torch.FloatTensor`):
233
+ Input states to the module usually the output from previous layer, it will be the Q,K and V in
234
+ *Attention(Q,K,V)*
235
+
236
+ attention_mask (`torch.BoolTensor`):
237
+ An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
238
+ sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
239
+ th token.
240
+
241
+ output_attentions (`bool`, *optional*):
242
+ Whether return the attention matrix.
243
+
244
+ query_states (`torch.FloatTensor`, *optional*):
245
+ The *Q* state in *Attention(Q,K,V)*.
246
+
247
+ relative_pos (`torch.LongTensor`):
248
+ The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
249
+ values ranging in [*-max_relative_positions*, *max_relative_positions*].
250
+
251
+ rel_embeddings (`torch.FloatTensor`):
252
+ The embedding of relative distances. It's a tensor of shape [\\(2 \\times
253
+ \\text{max_relative_positions}\\), *hidden_size*].
254
+
255
+
256
+ """
257
+ if query_states is None:
258
+ qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
259
+ query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
260
+ else:
261
+ ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
262
+ qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
263
+ q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype))
264
+ k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype))
265
+ v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype))
266
+ query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
267
+
268
+ query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
269
+ value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
270
+
271
+ rel_att: int = 0
272
+ # Take the dot product between "query" and "key" to get the raw attention scores.
273
+ scale_factor = 1 + len(self.pos_att_type)
274
+ scale = scaled_size_sqrt(query_layer, scale_factor)
275
+ query_layer = query_layer / scale.to(dtype=query_layer.dtype)
276
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
277
+
278
+ if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
279
+ rel_embeddings = self.pos_dropout(rel_embeddings)
280
+ rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
281
+
282
+ if rel_att is not None:
283
+ attention_scores = attention_scores + rel_att
284
+
285
+ # bxhxlxd
286
+ if self.head_logits_proj is not None:
287
+ attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
288
+
289
+ attention_mask = attention_mask.bool()
290
+ attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
291
+ # bsz x height x length x dimension
292
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
293
+
294
+ attention_probs = self.dropout(attention_probs)
295
+ if self.head_weights_proj is not None:
296
+ attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
297
+
298
+ context_layer = torch.matmul(attention_probs, value_layer)
299
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
300
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
301
+ context_layer = context_layer.view(new_context_layer_shape)
302
+ if not output_attentions:
303
+ return (context_layer, None)
304
+ return (context_layer, attention_probs)
305
+
306
+ def disentangled_att_bias(
307
+ self,
308
+ query_layer: torch.Tensor,
309
+ key_layer: torch.Tensor,
310
+ relative_pos: torch.Tensor,
311
+ rel_embeddings: torch.Tensor,
312
+ scale_factor: int,
313
+ ):
314
+ if relative_pos is None:
315
+ relative_pos = build_relative_position(query_layer, key_layer, query_layer.device)
316
+ if relative_pos.dim() == 2:
317
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
318
+ elif relative_pos.dim() == 3:
319
+ relative_pos = relative_pos.unsqueeze(1)
320
+ # bxhxqxk
321
+ elif relative_pos.dim() != 4:
322
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
323
+
324
+ att_span = compute_attention_span(query_layer, key_layer, self.max_relative_positions)
325
+ relative_pos = relative_pos.long()
326
+ rel_embeddings = rel_embeddings[
327
+ self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
328
+ ].unsqueeze(0)
329
+
330
+ score = 0
331
+
332
+ # content->position
333
+ if "c2p" in self.pos_att_type:
334
+ pos_key_layer = self.pos_proj(rel_embeddings)
335
+ pos_key_layer = self.transpose_for_scores(pos_key_layer)
336
+ c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
337
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
338
+ c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
339
+ score += c2p_att
340
+
341
+ # position->content
342
+ if "p2c" in self.pos_att_type:
343
+ pos_query_layer = self.pos_q_proj(rel_embeddings)
344
+ pos_query_layer = self.transpose_for_scores(pos_query_layer)
345
+ pos_query_layer /= scaled_size_sqrt(pos_query_layer, scale_factor)
346
+ r_pos = build_rpos(
347
+ query_layer,
348
+ key_layer,
349
+ relative_pos,
350
+ )
351
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
352
+ p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
353
+ p2c_att = torch.gather(
354
+ p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
355
+ ).transpose(-1, -2)
356
+
357
+ p2c_att = uneven_size_corrected(p2c_att, query_layer, key_layer, relative_pos)
358
+ score += p2c_att
359
+
360
+ return score
361
+
362
+
363
+ class DebertaEmbeddings(nn.Module):
364
+ """Construct the embeddings from word, position and token_type embeddings."""
365
+
366
+ def __init__(self, config):
367
+ super().__init__()
368
+ pad_token_id = getattr(config, "pad_token_id", 0)
369
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
370
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
371
+
372
+ self.position_biased_input = getattr(config, "position_biased_input", True)
373
+ if not self.position_biased_input:
374
+ self.position_embeddings = None
375
+ else:
376
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
377
+
378
+ if config.type_vocab_size > 0:
379
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
380
+ else:
381
+ self.token_type_embeddings = None
382
+
383
+ if self.embedding_size != config.hidden_size:
384
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
385
+ else:
386
+ self.embed_proj = None
387
+
388
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
389
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
390
+ self.config = config
391
+
392
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
393
+ self.register_buffer(
394
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
395
+ )
396
+
397
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
398
+ if input_ids is not None:
399
+ input_shape = input_ids.size()
400
+ else:
401
+ input_shape = inputs_embeds.size()[:-1]
402
+
403
+ seq_length = input_shape[1]
404
+
405
+ if position_ids is None:
406
+ position_ids = self.position_ids[:, :seq_length]
407
+
408
+ if token_type_ids is None:
409
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
410
+
411
+ if inputs_embeds is None:
412
+ inputs_embeds = self.word_embeddings(input_ids)
413
+
414
+ if self.position_embeddings is not None:
415
+ position_embeddings = self.position_embeddings(position_ids.long())
416
+ else:
417
+ position_embeddings = torch.zeros_like(inputs_embeds)
418
+
419
+ embeddings = inputs_embeds
420
+ if self.position_biased_input:
421
+ embeddings = embeddings + position_embeddings
422
+ if self.token_type_embeddings is not None:
423
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
424
+ embeddings = embeddings + token_type_embeddings
425
+
426
+ if self.embed_proj is not None:
427
+ embeddings = self.embed_proj(embeddings)
428
+
429
+ embeddings = self.LayerNorm(embeddings)
430
+
431
+ if mask is not None:
432
+ if mask.dim() != embeddings.dim():
433
+ if mask.dim() == 4:
434
+ mask = mask.squeeze(1).squeeze(1)
435
+ mask = mask.unsqueeze(2)
436
+ mask = mask.to(embeddings.dtype)
437
+
438
+ embeddings = embeddings * mask
439
+
440
+ embeddings = self.dropout(embeddings)
441
+ return embeddings
442
+
443
+
444
+ class DebertaAttention(nn.Module):
445
+ def __init__(self, config):
446
+ super().__init__()
447
+ self.self = DisentangledSelfAttention(config)
448
+ self.output = DebertaSelfOutput(config)
449
+ self.config = config
450
+
451
+ def forward(
452
+ self,
453
+ hidden_states,
454
+ attention_mask,
455
+ output_attentions: bool = False,
456
+ query_states=None,
457
+ relative_pos=None,
458
+ rel_embeddings=None,
459
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
460
+ self_output, att_matrix = self.self(
461
+ hidden_states,
462
+ attention_mask,
463
+ output_attentions,
464
+ query_states=query_states,
465
+ relative_pos=relative_pos,
466
+ rel_embeddings=rel_embeddings,
467
+ )
468
+ if query_states is None:
469
+ query_states = hidden_states
470
+ attention_output = self.output(self_output, query_states)
471
+
472
+ if output_attentions:
473
+ return (attention_output, att_matrix)
474
+ else:
475
+ return (attention_output, None)
476
+
477
+
478
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
479
+ class DebertaIntermediate(nn.Module):
480
+ def __init__(self, config):
481
+ super().__init__()
482
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
483
+ if isinstance(config.hidden_act, str):
484
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
485
+ else:
486
+ self.intermediate_act_fn = config.hidden_act
487
+
488
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
489
+ hidden_states = self.dense(hidden_states)
490
+ hidden_states = self.intermediate_act_fn(hidden_states)
491
+ return hidden_states
492
+
493
+
494
+ class DebertaOutput(nn.Module):
495
+ def __init__(self, config):
496
+ super().__init__()
497
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
498
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
499
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
500
+ self.config = config
501
+
502
+ def forward(self, hidden_states, input_tensor):
503
+ hidden_states = self.dense(hidden_states)
504
+ hidden_states = self.dropout(hidden_states)
505
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
506
+ return hidden_states
507
+
508
+
509
+ class DebertaLayer(nn.Module):
510
+ def __init__(self, config):
511
+ super().__init__()
512
+ self.attention = DebertaAttention(config)
513
+ self.intermediate = DebertaIntermediate(config)
514
+ self.output = DebertaOutput(config)
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states,
519
+ attention_mask,
520
+ query_states=None,
521
+ relative_pos=None,
522
+ rel_embeddings=None,
523
+ output_attentions: bool = False,
524
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
525
+ attention_output, att_matrix = self.attention(
526
+ hidden_states,
527
+ attention_mask,
528
+ output_attentions=output_attentions,
529
+ query_states=query_states,
530
+ relative_pos=relative_pos,
531
+ rel_embeddings=rel_embeddings,
532
+ )
533
+ intermediate_output = self.intermediate(attention_output)
534
+ layer_output = self.output(intermediate_output, attention_output)
535
+
536
+ if output_attentions:
537
+ return (layer_output, att_matrix)
538
+ else:
539
+ return (layer_output, None)
540
+
541
+
542
+ class DebertaEncoder(nn.Module):
543
+ """Modified BertEncoder with relative position bias support"""
544
+
545
+ def __init__(self, config):
546
+ super().__init__()
547
+ self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)])
548
+ self.relative_attention = getattr(config, "relative_attention", False)
549
+ if self.relative_attention:
550
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
551
+ if self.max_relative_positions < 1:
552
+ self.max_relative_positions = config.max_position_embeddings
553
+ self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
554
+ self.gradient_checkpointing = False
555
+
556
+ def get_rel_embedding(self):
557
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
558
+ return rel_embeddings
559
+
560
+ def get_attention_mask(self, attention_mask):
561
+ if attention_mask.dim() <= 2:
562
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
563
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
564
+ elif attention_mask.dim() == 3:
565
+ attention_mask = attention_mask.unsqueeze(1)
566
+
567
+ return attention_mask
568
+
569
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
570
+ if self.relative_attention and relative_pos is None:
571
+ if query_states is not None:
572
+ relative_pos = build_relative_position(query_states, hidden_states)
573
+ else:
574
+ relative_pos = build_relative_position(hidden_states, hidden_states)
575
+ return relative_pos
576
+
577
+ def forward(
578
+ self,
579
+ hidden_states: torch.Tensor,
580
+ attention_mask: torch.Tensor,
581
+ output_hidden_states: bool = True,
582
+ output_attentions: bool = False,
583
+ query_states=None,
584
+ relative_pos=None,
585
+ return_dict: bool = True,
586
+ ):
587
+ attention_mask = self.get_attention_mask(attention_mask)
588
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
589
+
590
+ all_hidden_states: Optional[Tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None
591
+ all_attentions = () if output_attentions else None
592
+
593
+ next_kv = hidden_states
594
+
595
+ rel_embeddings = self.get_rel_embedding()
596
+ for i, layer_module in enumerate(self.layer):
597
+ if self.gradient_checkpointing and self.training:
598
+ hidden_states, att_m = self._gradient_checkpointing_func(
599
+ layer_module.__call__,
600
+ next_kv,
601
+ attention_mask,
602
+ query_states,
603
+ relative_pos,
604
+ rel_embeddings,
605
+ output_attentions,
606
+ )
607
+ else:
608
+ hidden_states, att_m = layer_module(
609
+ next_kv,
610
+ attention_mask,
611
+ query_states=query_states,
612
+ relative_pos=relative_pos,
613
+ rel_embeddings=rel_embeddings,
614
+ output_attentions=output_attentions,
615
+ )
616
+
617
+ if output_hidden_states:
618
+ all_hidden_states = all_hidden_states + (hidden_states,)
619
+
620
+ if query_states is not None:
621
+ query_states = hidden_states
622
+ else:
623
+ next_kv = hidden_states
624
+
625
+ if output_attentions:
626
+ all_attentions = all_attentions + (att_m,)
627
+
628
+ if not return_dict:
629
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
630
+ return BaseModelOutput(
631
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
632
+ )
633
+
634
+
635
+ class DebertaPreTrainedModel(PreTrainedModel):
636
+ """
637
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
638
+ models.
639
+ """
640
+
641
+ config_class = DebertaConfig
642
+ base_model_prefix = "deberta"
643
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
644
+ supports_gradient_checkpointing = True
645
+
646
+ def _init_weights(self, module):
647
+ """Initialize the weights."""
648
+ if isinstance(module, nn.Linear):
649
+ # Slightly different from the TF version which uses truncated_normal for initialization
650
+ # cf https://github.com/pytorch/pytorch/pull/5617
651
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
652
+ if module.bias is not None:
653
+ module.bias.data.zero_()
654
+ elif isinstance(module, nn.Embedding):
655
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
656
+ if module.padding_idx is not None:
657
+ module.weight.data[module.padding_idx].zero_()
658
+ elif isinstance(module, (nn.LayerNorm, DebertaLayerNorm)):
659
+ module.weight.data.fill_(1.0)
660
+ module.bias.data.zero_()
661
+ elif isinstance(module, DisentangledSelfAttention):
662
+ module.q_bias.data.zero_()
663
+ module.v_bias.data.zero_()
664
+ elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)):
665
+ module.bias.data.zero_()
666
+
667
+
668
+ DEBERTA_START_DOCSTRING = r"""
669
+ The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
670
+ Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
671
+ on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
672
+ improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
673
+
674
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
675
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
676
+ and behavior.
677
+
678
+
679
+ Parameters:
680
+ config ([`DebertaConfig`]): Model configuration class with all the parameters of the model.
681
+ Initializing with a config file does not load the weights associated with the model, only the
682
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
683
+ """
684
+
685
# Shared `forward` docstring for all DeBERTa heads; the `{0}` placeholder is
# filled with the concrete input shape by `add_start_docstrings_to_model_forward`.
DEBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
727
+
728
+
729
@add_start_docstrings(
    "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    DEBERTA_START_DOCSTRING,
)
class DebertaModel(DebertaPreTrainedModel):
    # Encoder-only backbone: embeddings -> disentangled-attention encoder.
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = DebertaEmbeddings(config)
        self.encoder = DebertaEncoder(config)
        # Number of extra refinement passes through the last encoder layer.
        # Hard-coded to 0 here, so the `z_steps` branch in `forward` never runs
        # for this model as written.
        self.z_steps = 0
        self.config = config
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The word-embedding table lives on the embeddings module.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Head pruning is deliberately unsupported for DeBERTa.
        raise NotImplementedError("The prune function is not implemented in DeBERTa model.")

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Run the embeddings and encoder; return the last hidden state (plus optional
        hidden states / attentions). Exactly one of `input_ids` / `inputs_embeds` must be given."""
        # Fall back to config-level defaults when flags are not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Default mask attends to every position; default segment ids are all zero.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        # `output_hidden_states=True` is forced here because the z-steps branch
        # below needs the per-layer states (index 1 of the encoder output).
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        encoded_layers = encoder_outputs[1]

        # Dead branch while `self.z_steps == 0` (see __init__): would re-run the
        # last encoder layer `z_steps` times using the previous top layer as the
        # query stream.
        if self.z_steps > 1:
            hidden_states = encoded_layers[-2]
            layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
            query_states = encoded_layers[-1]
            rel_embeddings = self.encoder.get_rel_embedding()
            attention_mask = self.encoder.get_attention_mask(attention_mask)
            rel_pos = self.encoder.get_rel_pos(embedding_output)
            for layer in layers[1:]:
                query_states = layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=False,
                    query_states=query_states,
                    relative_pos=rel_pos,
                    rel_embeddings=rel_embeddings,
                )
                encoded_layers.append(query_states)

        sequence_output = encoded_layers[-1]

        if not return_dict:
            # Drop the (always-computed) hidden-states entry when the caller did
            # not ask for it.
            return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
            attentions=encoder_outputs.attentions,
        )
842
+
843
+
844
class LegacyDebertaPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the legacy MLM decoder."""

    def __init__(self, config):
        super().__init__()
        # Fall back to hidden_size when the config carries no explicit embedding_size.
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
        # `hidden_act` may be either the name of an activation or an already-callable function.
        self.transform_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Project `hidden_states` to embedding size, activate, then layer-normalize."""
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
861
+
862
+
863
class LegacyDebertaLMPredictionHead(nn.Module):
    """Legacy MLM head: transform followed by a vocab-sized decoder with its own bias."""

    def __init__(self, config):
        super().__init__()
        self.transform = LegacyDebertaPredictionHeadTransform(config)
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        # The decoder weights are tied to the input embeddings by the framework;
        # only the bias below is decoder-specific.
        self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Link the two so the bias is correctly resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-establish the bias link after the decoder weights are (re)tied.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return per-token vocabulary logits for `hidden_states`."""
        return self.decoder(self.transform(hidden_states))
885
+
886
+
887
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LegacyDeberta
class LegacyDebertaOnlyMLMHead(nn.Module):
    """Thin wrapper exposing the legacy LM prediction head as `predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = LegacyDebertaLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Map encoder hidden states to per-token vocabulary logits."""
        return self.predictions(sequence_output)
896
+
897
+
898
class DebertaLMPredictionHead(nn.Module):
    """https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270"""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either an activation name or a callable.
        self.transform_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    # note that the input embeddings must be passed as an argument
    def forward(self, hidden_states, word_embeddings):
        """Score transformed hidden states against the (externally supplied) embedding table."""
        transformed = self.dense(hidden_states)
        transformed = self.transform_act_fn(transformed)
        # original used MaskedLayerNorm, but passed no mask. This is equivalent.
        transformed = self.LayerNorm(transformed)
        return torch.matmul(transformed, word_embeddings.weight.t()) + self.bias
923
+
924
+
925
class DebertaOnlyMLMHead(nn.Module):
    """Non-legacy MLM head; word embeddings are supplied at call time for weight tying."""

    def __init__(self, config):
        super().__init__()
        self.lm_head = DebertaLMPredictionHead(config)

    # note that the input embeddings must be passed as an argument
    def forward(self, sequence_output, word_embeddings):
        """Return vocabulary logits for every position in `sequence_output`."""
        return self.lm_head(sequence_output, word_embeddings)
934
+
935
+
936
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class DebertaForMaskedLM(DebertaPreTrainedModel):
    # Default tying keys for the legacy head; overridden in __init__ for the
    # non-legacy head, which ties directly to the input word embeddings.
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        # `config.legacy` selects between the BERT-style head (`self.cls`) and
        # the original-DeBERTa head (`self.lm_predictions`).
        self.legacy = config.legacy
        self.deberta = DebertaModel(config)
        if self.legacy:
            self.cls = LegacyDebertaOnlyMLMHead(config)
        else:
            self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"]
            self.lm_predictions = DebertaOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        if self.legacy:
            return self.cls.predictions.decoder
        else:
            # NOTE(review): `lm_head.dense` is the hidden->hidden transform, not a
            # vocab-sized decoder (the non-legacy head scores against the word
            # embeddings at call time) — confirm this is the intended tying target.
            return self.lm_predictions.lm_head.dense

    def set_output_embeddings(self, new_embeddings):
        if self.legacy:
            self.cls.predictions.decoder = new_embeddings
            self.cls.predictions.bias = new_embeddings.bias
        else:
            self.lm_predictions.lm_head.dense = new_embeddings
            self.lm_predictions.lm_head.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_MASKED_LM,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="[MASK]",
        expected_output=_MASKED_LM_EXPECTED_OUTPUT,
        expected_loss=_MASKED_LM_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        if self.legacy:
            prediction_scores = self.cls(sequence_output)
        else:
            # The non-legacy head needs the embedding table to score against.
            prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1029
+
1030
+
1031
class ContextPooler(nn.Module):
    """Pools a sequence by transforming the hidden state of its first ([CLS]) token."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
        self.dropout = nn.Dropout(config.pooler_dropout)
        self.config = config

    def forward(self, hidden_states):
        # "Pool" by using only the first token's hidden state; dropout is applied
        # before the projection, matching the original DeBERTa pooler.
        first_token = hidden_states[:, 0]
        first_token = self.dropout(first_token)
        pooled = self.dense(first_token)
        return ACT2FN[self.config.pooler_hidden_act](pooled)

    @property
    def output_dim(self):
        # Downstream heads size their classifier from this.
        return self.config.hidden_size
1051
+
1052
+
1053
@add_start_docstrings(
    """
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaForSequenceClassification(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Default to binary classification when the config omits `num_labels`.
        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels

        self.deberta = DebertaModel(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim

        self.classifier = nn.Linear(output_dim, num_labels)
        # `cls_dropout` overrides the generic hidden dropout for the classifier.
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = nn.Dropout(drop_out)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.deberta.set_input_embeddings(new_embeddings)

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Legacy DeBERTa loss selection, driven by label shape/dtype
                # rather than an explicit problem_type.
                if self.num_labels == 1:
                    # regression task
                    loss_fn = nn.MSELoss()
                    logits = logits.view(-1).to(labels.dtype)
                    loss = loss_fn(logits, labels.view(-1))
                elif labels.dim() == 1 or labels.size(-1) == 1:
                    # Hard-label classification: rows with a negative label are
                    # treated as unlabeled and excluded from the loss.
                    label_index = (labels >= 0).nonzero()
                    labels = labels.long()
                    if label_index.size(0) > 0:
                        labeled_logits = torch.gather(
                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
                        )
                        labels = torch.gather(labels, 0, label_index.view(-1))
                        loss_fct = CrossEntropyLoss()
                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
                    else:
                        # No labeled rows at all: zero loss on the logits' device/dtype.
                        loss = torch.tensor(0).to(logits)
                else:
                    # Soft labels (a distribution per row): cross-entropy as the
                    # expectation of -log-softmax under the label distribution.
                    log_softmax = nn.LogSoftmax(-1)
                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
            elif self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
1169
+
1170
+
1171
@add_start_docstrings(
    """
    DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaForTokenClassification(DebertaPreTrainedModel):
    """DeBERTa backbone plus a per-token linear classifier."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        encoder_outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits from the regularized encoder states.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss, logits=logits, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions
            )

        output = (logits,) + encoder_outputs[1:]
        return output if loss is None else (loss,) + output
1242
+
1243
+
1244
@add_start_docstrings(
    """
    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaForQuestionAnswering(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # num_labels is 2 for QA (start / end); kept configurable via the config.
        self.num_labels = config.num_labels

        self.deberta = DebertaModel(config)
        # Single projection producing both start and end logits per token.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_QA,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_QA_EXPECTED_OUTPUT,
        expected_loss=_QA_EXPECTED_LOSS,
        qa_target_start_index=_QA_TARGET_START_INDEX,
        qa_target_end_index=_QA_TARGET_END_INDEX,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the 2-channel projection into per-token start and end logits.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # `ignored_index` equals the sequence length, so any position clamped
            # to it is dropped by CrossEntropyLoss via `ignore_index`.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # Average the two span-boundary losses.
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1343
+
1344
+
1345
# Public API of this module, re-exported by `transformers.models.deberta`.
__all__ = [
    "DebertaForMaskedLM",
    "DebertaForQuestionAnswering",
    "DebertaForSequenceClassification",
    "DebertaForTokenClassification",
    "DebertaModel",
    "DebertaPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/deberta/modeling_tf_deberta.py ADDED
@@ -0,0 +1,1652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 DeBERTa model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from typing import Dict, Optional, Sequence, Tuple, Union
21
+
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+ from ...activations_tf import get_tf_activation
26
+ from ...modeling_tf_outputs import (
27
+ TFBaseModelOutput,
28
+ TFMaskedLMOutput,
29
+ TFQuestionAnsweringModelOutput,
30
+ TFSequenceClassifierOutput,
31
+ TFTokenClassifierOutput,
32
+ )
33
+ from ...modeling_tf_utils import (
34
+ TFMaskedLanguageModelingLoss,
35
+ TFModelInputType,
36
+ TFPreTrainedModel,
37
+ TFQuestionAnsweringLoss,
38
+ TFSequenceClassificationLoss,
39
+ TFTokenClassificationLoss,
40
+ get_initializer,
41
+ keras,
42
+ unpack_inputs,
43
+ )
44
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
45
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
46
+ from .configuration_deberta import DebertaConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ _CONFIG_FOR_DOC = "DebertaConfig"
53
+ _CHECKPOINT_FOR_DOC = "kamalkraj/deberta-base"
54
+
55
+
56
class TFDebertaContextPooler(keras.layers.Layer):
    """Pools a sequence by projecting the hidden state of its first ([CLS]) token."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense")
        self.dropout = TFDebertaStableDropout(config.pooler_dropout, name="dropout")
        self.config = config

    def call(self, hidden_states, training: bool = False):
        # "Pool" by using only the first token's hidden state; dropout runs
        # before the projection, as in the PyTorch pooler.
        first_token = hidden_states[:, 0]
        first_token = self.dropout(first_token, training=training)
        pooled = self.dense(first_token)
        return get_tf_activation(self.config.pooler_hidden_act)(pooled)

    @property
    def output_dim(self) -> int:
        # Downstream heads size their classifier from this.
        return self.config.hidden_size

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is not None:
            with tf.name_scope(dense.name):
                dense.build([None, None, self.config.pooler_hidden_size])
        dropout = getattr(self, "dropout", None)
        if dropout is not None:
            with tf.name_scope(dropout.name):
                dropout.build(None)
86
+
87
+
88
class TFDebertaXSoftmax(keras.layers.Layer):
    """
    Masked Softmax which is optimized for saving memory

    Args:
        input (`tf.Tensor`): The input tensor that will apply softmax.
        mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
        dim (int): The dimension that will apply softmax
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs: tf.Tensor, mask: tf.Tensor):
        # Positions whose mask value is 0 are excluded from the softmax.
        ignored = tf.logical_not(tf.cast(mask, tf.bool))
        neg_inf = tf.cast(float("-inf"), dtype=self.compute_dtype)
        masked_inputs = tf.where(ignored, neg_inf, inputs)
        # Run the softmax itself in float32 for numerical stability.
        probs = stable_softmax(tf.cast(masked_inputs, dtype=tf.float32), self.axis)
        # Zero the ignored positions so they contribute nothing downstream.
        return tf.where(ignored, 0.0, probs)
108
+
109
+
110
class TFDebertaStableDropout(keras.layers.Layer):
    """
    Optimized dropout module for stabilizing the training

    Args:
        drop_prob (float): the dropout probability
    """

    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        # Probability of dropping an element; 0 makes this layer a no-op even in training.
        self.drop_prob = drop_prob

    @tf.custom_gradient
    def xdropout(self, inputs):
        """
        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
        """
        # Bernoulli(1 - drop_prob) samples 1 for "keep"; `1 - sample` marks the dropped positions.
        mask = tf.cast(
            1
            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
            tf.bool,
        )
        # Inverted-dropout scaling keeps the expected activation magnitude unchanged.
        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)
        if self.drop_prob > 0:
            inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale

        def grad(upstream):
            # Reuse the same mask/scale so the backward pass matches the forward pass exactly.
            if self.drop_prob > 0:
                return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale
            else:
                return upstream

        return inputs, grad

    def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
        # Dropout is only applied in training mode; inference is the identity.
        if training:
            return self.xdropout(inputs)
        return inputs
148
+
149
+
150
class TFDebertaLayerNorm(keras.layers.Layer):
    """LayerNorm module in the TF style (epsilon inside the square root)."""

    def __init__(self, size, eps=1e-12, **kwargs):
        super().__init__(**kwargs)
        self.size = size
        self.eps = eps

    def build(self, input_shape):
        # Learned affine parameters, stored under the PyTorch-compatible names "weight"/"bias".
        self.gamma = self.add_weight(shape=[self.size], initializer=tf.ones_initializer(), name="weight")
        self.beta = self.add_weight(shape=[self.size], initializer=tf.zeros_initializer(), name="bias")
        return super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # Normalize over the last (feature) axis, then apply the learned affine transform.
        mu = tf.reduce_mean(x, axis=[-1], keepdims=True)
        var = tf.reduce_mean(tf.square(x - mu), axis=[-1], keepdims=True)
        denom = tf.math.sqrt(var + self.eps)
        return self.gamma * (x - mu) / denom + self.beta
168
+
169
+
170
class TFDebertaSelfOutput(keras.layers.Layer):
    """Output sublayer of the attention block: dense projection, dropout, residual add + LayerNorm."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(config.hidden_size, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def call(self, hidden_states, input_tensor, training: bool = False):
        # Project and regularize, then normalize the residual sum.
        projected = self.dense(hidden_states)
        projected = self.dropout(projected, training=training)
        return self.LayerNorm(projected + input_tensor)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        hidden = self.config.hidden_size
        for sublayer, shape in (
            (getattr(self, "dense", None), [None, None, hidden]),
            (getattr(self, "LayerNorm", None), [None, None, hidden]),
            (getattr(self, "dropout", None), None),
        ):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(shape)
197
+
198
+
199
class TFDebertaAttention(keras.layers.Layer):
    """Disentangled self-attention plus its output projection/residual sublayer."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)
        self.self = TFDebertaDisentangledSelfAttention(config, name="self")
        self.dense_output = TFDebertaSelfOutput(config, name="output")
        self.config = config

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        rel_embeddings: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attn_results = self.self(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
            output_attentions=output_attentions,
            training=training,
        )
        # The residual connection targets the query states when they differ from the input.
        residual = input_tensor if query_states is None else query_states
        projected = self.dense_output(
            hidden_states=attn_results[0], input_tensor=residual, training=training
        )
        # Forward any attention probabilities returned by the self-attention module.
        return (projected,) + attn_results[1:]

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (getattr(self, "self", None), getattr(self, "dense_output", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
245
+
246
+
247
class TFDebertaIntermediate(keras.layers.Layer):
    """Feed-forward expansion: dense projection to the intermediate size plus the hidden activation."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # `hidden_act` may be either an activation name or a callable.
        self.intermediate_act_fn = (
            get_tf_activation(config.hidden_act) if isinstance(config.hidden_act, str) else config.hidden_act
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Expand to the intermediate size, then apply the nonlinearity.
        return self.intermediate_act_fn(self.dense(inputs=hidden_states))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is not None:
            with tf.name_scope(dense.name):
                dense.build([None, None, self.config.hidden_size])
274
+
275
+
276
class TFDebertaOutput(keras.layers.Layer):
    """Feed-forward contraction: project back to the hidden size, dropout, residual add + LayerNorm."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Contract back to the hidden size, regularize, then normalize the residual sum.
        contracted = self.dense(inputs=hidden_states)
        contracted = self.dropout(contracted, training=training)
        return self.LayerNorm(contracted + input_tensor)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer, shape in (
            (getattr(self, "dense", None), [None, None, self.config.intermediate_size]),
            (getattr(self, "LayerNorm", None), [None, None, self.config.hidden_size]),
            (getattr(self, "dropout", None), None),
        ):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(shape)
307
+
308
+
309
class TFDebertaLayer(keras.layers.Layer):
    """One DeBERTa transformer block: disentangled attention followed by a feed-forward sublayer."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFDebertaAttention(config, name="attention")
        self.intermediate = TFDebertaIntermediate(config, name="intermediate")
        self.bert_output = TFDebertaOutput(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        rel_embeddings: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attn_results = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
            output_attentions=output_attentions,
            training=training,
        )
        attn_out = attn_results[0]
        ffn_hidden = self.intermediate(hidden_states=attn_out)
        block_out = self.bert_output(hidden_states=ffn_hidden, input_tensor=attn_out, training=training)
        # Append the attention probabilities when the caller asked for them.
        return (block_out,) + attn_results[1:]

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (
            getattr(self, "attention", None),
            getattr(self, "intermediate", None),
            getattr(self, "bert_output", None),
        ):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
358
+
359
+
360
class TFDebertaEncoder(keras.layers.Layer):
    """Stack of `TFDebertaLayer`s, optionally sharing a relative-position embedding table."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)

        self.layer = [TFDebertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
        self.relative_attention = getattr(config, "relative_attention", False)
        self.config = config
        if self.relative_attention:
            # max_relative_positions < 1 means "use the position-embedding size as the bound".
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if self.relative_attention:
            # One shared relative-position embedding table for all layers, covering
            # distances in [-max_relative_positions, max_relative_positions).
            self.rel_embeddings = self.add_weight(
                name="rel_embeddings.weight",
                shape=[self.max_relative_positions * 2, self.config.hidden_size],
                initializer=get_initializer(self.config.initializer_range),
            )
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)

    def get_rel_embedding(self):
        # Returns the shared relative-position table, or None when relative attention is off.
        rel_embeddings = self.rel_embeddings if self.relative_attention else None
        return rel_embeddings

    def get_attention_mask(self, attention_mask):
        """Expand a 2D (or 3D) mask to the 4D [batch, 1|heads, query, key] form the layers expect."""
        if len(shape_list(attention_mask)) <= 2:
            # Outer product of the per-token mask with itself gives the pairwise mask.
            extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2)
            attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1)
            attention_mask = tf.cast(attention_mask, tf.uint8)
        elif len(shape_list(attention_mask)) == 3:
            attention_mask = tf.expand_dims(attention_mask, 1)

        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        # Lazily build the relative-position ids when relative attention is enabled
        # and the caller did not supply them.
        if self.relative_attention and relative_pos is None:
            q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2]
            relative_pos = build_relative_position(q, shape_list(hidden_states)[-2])
        return relative_pos

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """Run the input through every layer, optionally collecting hidden states/attentions."""
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        # `hidden_states` may be a sequence of per-layer key/value inputs (used with
        # external query states); otherwise the same tensor feeds every layer.
        if isinstance(hidden_states, Sequence):
            next_kv = hidden_states[0]
        else:
            next_kv = hidden_states

        rel_embeddings = self.get_rel_embedding()

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states=next_kv,
                attention_mask=attention_mask,
                query_states=query_states,
                relative_pos=relative_pos,
                rel_embeddings=rel_embeddings,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            # When external query states are used, each layer's output becomes the
            # next query, and the next key/value comes from the input sequence.
            if query_states is not None:
                query_states = hidden_states
                if isinstance(hidden_states, Sequence):
                    next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
            else:
                next_kv = hidden_states

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
466
+
467
+
468
def build_relative_position(query_size, key_size):
    """
    Build relative position according to the query and key

    We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
    \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
    P_k\\)

    Args:
        query_size (int): the length of query
        key_size (int): the length of key

    Return:
        `tf.Tensor`: A tensor with shape [1, query_size, key_size]

    """
    query_positions = tf.range(query_size, dtype=tf.int32)
    key_positions = tf.range(key_size, dtype=tf.int32)
    # Pairwise difference: entry [q, k] holds q - k.
    distances = query_positions[:, None] - tf.tile(tf.reshape(key_positions, [1, -1]), [query_size, 1])
    distances = distances[:query_size, :]
    # Prepend a batch dimension of size 1 and return int64 ids.
    return tf.cast(tf.expand_dims(distances, axis=0), tf.int64)
490
+
491
+
492
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
    # Broadcast the content->position index tensor to
    # [batch, heads, query_len, last-dim-of-relative_pos] so it can drive a gather.
    target_shape = shape_list(query_layer)[:3] + [shape_list(relative_pos)[-1]]
    return tf.broadcast_to(c2p_pos, target_shape)
500
+
501
+
502
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
    # Broadcast the position->content index tensor to [batch, heads, key_len, key_len]
    # so it can drive a gather.
    key_len = shape_list(key_layer)[-2]
    target_shape = shape_list(query_layer)[:2] + [key_len, key_len]
    return tf.broadcast_to(c2p_pos, target_shape)
510
+
511
+
512
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
    # Broadcast pos_index over the leading [batch, heads] dims of p2c_att, with the
    # trailing dims taken from pos_index's row count and key_layer's sequence length.
    leading = shape_list(p2c_att)[:2]
    trailing = [shape_list(pos_index)[-2], shape_list(key_layer)[-2]]
    return tf.broadcast_to(pos_index, leading + trailing)
515
+
516
+
517
def torch_gather(x, indices, gather_axis):
    """TF re-implementation of `torch.gather`: select elements of `x` along `gather_axis`
    at the positions given by `indices` (same rank as `x`)."""
    # Normalize a negative axis to its positive equivalent.
    if gather_axis < 0:
        gather_axis = tf.rank(x) + gather_axis

    # `tf.gather(..., batch_dims=1)` below gathers along the last axis only, so rotate
    # the target axis into last position first (remembering how far we rolled).
    if gather_axis != tf.rank(x) - 1:
        pre_roll = tf.rank(x) - 1 - gather_axis
        permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0)
        x = tf.transpose(x, perm=permutation)
        indices = tf.transpose(indices, perm=permutation)
    else:
        pre_roll = 0

    # Collapse all leading dims so a single batched gather covers every row.
    flat_x = tf.reshape(x, (-1, tf.shape(x)[-1]))
    flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
    gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
    gathered = tf.reshape(gathered, tf.shape(indices))

    # Undo the axis rotation so the result matches the original layout.
    if pre_roll != 0:
        permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0)
        gathered = tf.transpose(gathered, perm=permutation)

    return gathered
539
+
540
+
541
class TFDebertaDisentangledSelfAttention(keras.layers.Layer):
    """
    Disentangled self-attention module

    Parameters:
        config (`str`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer [`DebertaConfig`]

    """

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # A single bias-free projection produces Q, K and V concatenated along the last axis.
        self.in_proj = keras.layers.Dense(
            self.all_head_size * 3,
            kernel_initializer=get_initializer(config.initializer_range),
            name="in_proj",
            use_bias=False,
        )
        # Which disentangled bias terms to compute: "c2p" (content->position) and/or "p2c".
        self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []

        self.relative_attention = getattr(config, "relative_attention", False)
        self.talking_head = getattr(config, "talking_head", False)

        if self.talking_head:
            # Talking-heads: mix attention logits/weights across heads with learned projections.
            self.head_logits_proj = keras.layers.Dense(
                self.num_attention_heads,
                kernel_initializer=get_initializer(config.initializer_range),
                name="head_logits_proj",
                use_bias=False,
            )
            self.head_weights_proj = keras.layers.Dense(
                self.num_attention_heads,
                kernel_initializer=get_initializer(config.initializer_range),
                name="head_weights_proj",
                use_bias=False,
            )

        self.softmax = TFDebertaXSoftmax(axis=-1)

        if self.relative_attention:
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.pos_dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="pos_dropout")
            if "c2p" in self.pos_att_type:
                self.pos_proj = keras.layers.Dense(
                    self.all_head_size,
                    kernel_initializer=get_initializer(config.initializer_range),
                    name="pos_proj",
                    use_bias=False,
                )
            if "p2c" in self.pos_att_type:
                self.pos_q_proj = keras.layers.Dense(
                    self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="pos_q_proj"
                )

        self.dropout = TFDebertaStableDropout(config.attention_probs_dropout_prob, name="dropout")
        self.config = config

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Separate query/value biases (the shared in_proj is bias-free).
        self.q_bias = self.add_weight(
            name="q_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros()
        )
        self.v_bias = self.add_weight(
            name="v_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros()
        )
        # Optional sublayers below only exist for certain config flags, hence the getattr guards.
        if getattr(self, "in_proj", None) is not None:
            with tf.name_scope(self.in_proj.name):
                self.in_proj.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "head_logits_proj", None) is not None:
            with tf.name_scope(self.head_logits_proj.name):
                self.head_logits_proj.build(None)
        if getattr(self, "head_weights_proj", None) is not None:
            with tf.name_scope(self.head_weights_proj.name):
                self.head_weights_proj.build(None)
        if getattr(self, "pos_dropout", None) is not None:
            with tf.name_scope(self.pos_dropout.name):
                self.pos_dropout.build(None)
        if getattr(self, "pos_proj", None) is not None:
            with tf.name_scope(self.pos_proj.name):
                self.pos_proj.build([self.config.hidden_size])
        if getattr(self, "pos_q_proj", None) is not None:
            with tf.name_scope(self.pos_q_proj.name):
                self.pos_q_proj.build([self.config.hidden_size])

    def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
        shape = shape_list(tensor)[:-1] + [self.num_attention_heads, -1]
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=shape)

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        rel_embeddings: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Call the module

        Args:
            hidden_states (`tf.Tensor`):
                Input states to the module usually the output from previous layer, it will be the Q,K and V in
                *Attention(Q,K,V)*

            attention_mask (`tf.Tensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                th token.

            return_att (`bool`, *optional*):
                Whether return the attention matrix.

            query_states (`tf.Tensor`, *optional*):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`tf.Tensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
                values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`tf.Tensor`):
                The embedding of relative distances. It's a tensor of shape [\\(2 \\times
                \\text{max_relative_positions}\\), *hidden_size*].


        """
        if query_states is None:
            # Standard self-attention: one projection produces Q, K, V stacked on the last axis.
            qp = self.in_proj(hidden_states)  # .split(self.all_head_size, dim=-1)
            query_layer, key_layer, value_layer = tf.split(
                self.transpose_for_scores(qp), num_or_size_splits=3, axis=-1
            )
        else:
            # External query states: slice the shared in_proj kernel into per-head Q/K/V
            # weight matrices and apply them manually.

            def linear(w, b, x):
                out = tf.matmul(x, w, transpose_b=True)
                if b is not None:
                    out += tf.transpose(b)
                return out

            # NOTE(review): keras `Dense` exposes its kernel via `.kernel` / `.weights[0]`,
            # not `.weight` — this branch looks unexercised; confirm before relying on it.
            ws = tf.split(
                tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
            )
            # Regroup the interleaved per-head slices into one weight matrix each for Q, K, V.
            qkvw = tf.TensorArray(dtype=self.dtype, size=3)
            for k in tf.range(3):
                qkvw_inside = tf.TensorArray(dtype=self.dtype, size=self.num_attention_heads)
                for i in tf.range(self.num_attention_heads):
                    qkvw_inside = qkvw_inside.write(i, ws[i * 3 + k])
                qkvw = qkvw.write(k, qkvw_inside.concat())
            qkvb = [None] * 3

            q = linear(qkvw[0], qkvb[0], query_states)
            k = linear(qkvw[1], qkvb[1], hidden_states)
            v = linear(qkvw[2], qkvb[2], hidden_states)
            query_layer = self.transpose_for_scores(q)
            key_layer = self.transpose_for_scores(k)
            value_layer = self.transpose_for_scores(v)

        # Add the separately-learned query/value biases (the in_proj itself is bias-free).
        query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
        value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])

        rel_att = None
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The scale accounts for the extra disentangled bias terms being summed in.
        scale_factor = 1 + len(self.pos_att_type)
        scale = math.sqrt(shape_list(query_layer)[-1] * scale_factor)
        query_layer = query_layer / scale

        attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 1, 3, 2]))
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings, training=training)
            rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)

        if rel_att is not None:
            attention_scores = attention_scores + rel_att

        if self.talking_head:
            # Mix logits across heads (move heads to the last axis for the Dense projection).
            attention_scores = tf.transpose(
                self.head_logits_proj(tf.transpose(attention_scores, [0, 2, 3, 1])), [0, 3, 1, 2]
            )

        attention_probs = self.softmax(attention_scores, attention_mask)
        attention_probs = self.dropout(attention_probs, training=training)
        if self.talking_head:
            attention_probs = tf.transpose(
                self.head_weights_proj(tf.transpose(attention_probs, [0, 2, 3, 1])), [0, 3, 1, 2]
            )

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
        context_layer_shape = shape_list(context_layer)
        # Set the final dimension here explicitly.
        # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
        # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
        # requires final input dimension to be defined
        new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
        context_layer = tf.reshape(context_layer, new_context_layer_shape)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        """Compute the content->position and/or position->content attention bias terms."""
        if relative_pos is None:
            q = shape_list(query_layer)[-2]
            relative_pos = build_relative_position(q, shape_list(key_layer)[-2])
        shape_list_pos = shape_list(relative_pos)
        # Normalize relative_pos to rank 4: [batch, heads, query, key].
        if len(shape_list_pos) == 2:
            relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
        elif len(shape_list_pos) == 3:
            relative_pos = tf.expand_dims(relative_pos, 1)
        # bxhxqxk
        elif len(shape_list_pos) != 4:
            raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")

        # Span of usable relative distances, capped at max_relative_positions.
        att_span = tf.cast(
            tf.minimum(
                tf.maximum(shape_list(query_layer)[-2], shape_list(key_layer)[-2]), self.max_relative_positions
            ),
            tf.int64,
        )
        # Center slice of the embedding table covering [-att_span, att_span).
        rel_embeddings = tf.expand_dims(
            rel_embeddings[self.max_relative_positions - att_span : self.max_relative_positions + att_span, :], 0
        )

        score = 0

        # content->position
        if "c2p" in self.pos_att_type:
            pos_key_layer = self.pos_proj(rel_embeddings)
            pos_key_layer = self.transpose_for_scores(pos_key_layer)
            c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 1, 3, 2]))
            # Shift relative positions into [0, 2*att_span) index range for the gather.
            c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1)
            c2p_att = torch_gather(c2p_att, c2p_dynamic_expand(c2p_pos, query_layer, relative_pos), -1)
            score += c2p_att

        # position->content
        if "p2c" in self.pos_att_type:
            pos_query_layer = self.pos_q_proj(rel_embeddings)
            pos_query_layer = self.transpose_for_scores(pos_query_layer)
            pos_query_layer /= tf.math.sqrt(
                tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype)
            )
            # When query and key lengths differ, rebuild key-to-key relative positions.
            if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]:
                r_pos = build_relative_position(shape_list(key_layer)[-2], shape_list(key_layer)[-2])
            else:
                r_pos = relative_pos
            p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)
            p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 1, 3, 2]))
            p2c_att = tf.transpose(
                torch_gather(p2c_att, p2c_dynamic_expand(p2c_pos, query_layer, key_layer), -1), [0, 1, 3, 2]
            )
            if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]:
                pos_index = tf.expand_dims(relative_pos[:, :, :, 0], -1)
                p2c_att = torch_gather(p2c_att, pos_dynamic_expand(pos_index, p2c_att, key_layer), -2)
            score += p2c_att

        return score
816
+
817
+
818
class TFDebertaEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        # embedding_size may differ from hidden_size; a projection bridges the gap below.
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        # When False, absolute position embeddings are not added to the input.
        self.position_biased_input = getattr(config, "position_biased_input", True)
        self.initializer_range = config.initializer_range
        if self.embedding_size != config.hidden_size:
            self.embed_proj = keras.layers.Dense(
                config.hidden_size,
                kernel_initializer=get_initializer(config.initializer_range),
                name="embed_proj",
                use_bias=False,
            )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")

    def build(self, input_shape=None):
        # Weight names/scopes below must match the checkpoint layout, hence the
        # explicit tf.name_scope blocks.
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            if self.config.type_vocab_size > 0:
                self.token_type_embeddings = self.add_weight(
                    name="embeddings",
                    shape=[self.config.type_vocab_size, self.embedding_size],
                    initializer=get_initializer(self.initializer_range),
                )
            else:
                self.token_type_embeddings = None

        with tf.name_scope("position_embeddings"):
            if self.position_biased_input:
                self.position_embeddings = self.add_weight(
                    name="embeddings",
                    shape=[self.max_position_embeddings, self.hidden_size],
                    initializer=get_initializer(self.initializer_range),
                )
            else:
                self.position_embeddings = None

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "embed_proj", None) is not None:
            with tf.name_scope(self.embed_proj.name):
                self.embed_proj.build([None, None, self.embedding_size])

    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        mask: Optional[tf.Tensor] = None,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            # Default: consecutive positions 0..seq_len-1 for every row.
            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

        final_embeddings = inputs_embeds
        if self.position_biased_input:
            position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
            final_embeddings += position_embeds
        if self.config.type_vocab_size > 0:
            token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
            final_embeddings += token_type_embeds

        if self.embedding_size != self.hidden_size:
            final_embeddings = self.embed_proj(final_embeddings)

        final_embeddings = self.LayerNorm(final_embeddings)

        if mask is not None:
            # Zero out embeddings of padded positions; a 4D attention-style mask is
            # first squeezed back to [batch, seq_len].
            if len(shape_list(mask)) != len(shape_list(final_embeddings)):
                if len(shape_list(mask)) == 4:
                    mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1)
                mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)

            final_embeddings = final_embeddings * mask

        final_embeddings = self.dropout(final_embeddings, training=training)

        return final_embeddings
935
+
936
+
937
class TFDebertaPredictionHeadTransform(keras.layers.Layer):
    """Dense + activation + LayerNorm transform applied before the MLM output projection."""

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)

        # Falls back to hidden_size when the config has no separate embedding_size.
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

        self.dense = keras.layers.Dense(
            units=self.embedding_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # hidden_act may be a string name or an already-callable activation.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # Input comes from the encoder, so last dim is hidden_size.
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # Normalizes the dense output, whose last dim is embedding_size.
                self.LayerNorm.build([None, None, self.embedding_size])
973
+
974
+
975
class TFDebertaLMPredictionHead(keras.layers.Layer):
    """Masked-LM head whose output projection is tied to the input embedding matrix."""

    def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

        self.transform = TFDebertaPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape=None):
        # NOTE(review): the bias is created before the `self.built` early-return,
        # mirroring the original ordering — a second build() call would invoke
        # add_weight again; confirm build is only called once.
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        if self.built:
            return
        self.built = True
        if getattr(self, "transform", None) is not None:
            with tf.name_scope(self.transform.name):
                self.transform.build(None)

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    def set_output_embeddings(self, value: tf.Variable):
        # Re-tie the projection to a new embedding matrix and track its size.
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.transform(hidden_states=hidden_states)
        seq_length = shape_list(hidden_states)[1]
        # Flatten to 2D, project through the transposed embedding matrix,
        # then restore the (batch, seq, vocab) shape and add the bias.
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states
1021
+
1022
+
1023
class TFDebertaOnlyMLMHead(keras.layers.Layer):
    """Thin wrapper exposing the DeBERTa masked-LM prediction head as `predictions`."""

    def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFDebertaLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        # Delegate directly: scores over the vocabulary for each position.
        return self.predictions(hidden_states=sequence_output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        head = getattr(self, "predictions", None)
        if head is not None:
            # Build under the child's own name scope so variable names match.
            with tf.name_scope(head.name):
                head.build(None)
1040
+
1041
+
1042
# @keras_serializable
class TFDebertaMainLayer(keras.layers.Layer):
    """Core DeBERTa stack: embeddings followed by the encoder."""

    config_class = DebertaConfig

    def __init__(self, config: DebertaConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFDebertaEmbeddings(config, name="embeddings")
        self.encoder = TFDebertaEncoder(config, name="encoder")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        # Swap in a new word-embedding matrix and keep vocab_size in sync.
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Head pruning is not supported for this TF implementation.
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """Run embeddings + encoder; returns last hidden state (plus optional extras)."""
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Default: attend to everything, single segment.
        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            mask=attention_mask,
            training=training,
        )

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]

        if not return_dict:
            # Tuple output: (last_hidden_state, *optional hidden states / attentions).
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
1135
+
1136
+
1137
class TFDebertaPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Config class used by from_pretrained / save_pretrained machinery.
    config_class = DebertaConfig
    # Attribute name of the base transformer inside task-specific subclasses.
    base_model_prefix = "deberta"
1145
+
1146
+
1147
# Shared docstring prepended to each model class via `@add_start_docstrings`.
# Fixes two typos from the original text: "It's build" -> "It's built",
# "it out perform" -> "it outperforms".
DEBERTA_START_DOCSTRING = r"""
    The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
    Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built
    on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
    improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB pretraining data.

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`DebertaConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1188
+
1189
# Per-argument docstring injected into each model's `call` via
# `@add_start_docstrings_to_model_forward`; `{0}` is filled with the input shape.
# Fixes broken markdown backticks from the original text:
# "``Dict[str, tf.Tensor]`" -> "`Dict[str, tf.Tensor]`" and
# "[`~utils.ModelOutput``]" -> "[`~utils.ModelOutput`]".
DEBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1231
+
1232
+
1233
@add_start_docstrings(
    "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    DEBERTA_START_DOCSTRING,
)
class TFDebertaModel(TFDebertaPreTrainedModel):
    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # All transformer logic lives in the shared main layer.
        self.deberta = TFDebertaMainLayer(config, name="deberta")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # Pure pass-through: the main layer produces the final hidden states.
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "deberta", None) is not None:
            with tf.name_scope(self.deberta.name):
                self.deberta.build(None)
1283
+
1284
+
1285
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # MLM requires bi-directional attention; warn if configured as a decoder.
        if config.is_decoder:
            logger.warning(
                "If you want to use `TFDebertaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.deberta = TFDebertaMainLayer(config, name="deberta")
        # MLM head ties its projection to the main layer's embedding matrix.
        self.mlm = TFDebertaOnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.mlm.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
        # Loss only when labels are supplied (training/eval with targets).
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "deberta", None) is not None:
            with tf.name_scope(self.deberta.name):
                self.deberta.build(None)
        if getattr(self, "mlm", None) is not None:
            with tf.name_scope(self.mlm.name):
                self.mlm.build(None)
1364
+
1365
+
1366
@add_start_docstrings(
    """
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.deberta = TFDebertaMainLayer(config, name="deberta")
        self.pooler = TFDebertaContextPooler(config, name="pooler")

        # Classification head can use its own dropout rate (`cls_dropout`),
        # falling back to the model-wide hidden dropout.
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = TFDebertaStableDropout(drop_out, name="cls_dropout")
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        # Pooler output width, used to build the classifier below.
        self.output_dim = self.pooler.output_dim

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        # Pool sequence to a single vector, regularize, then project to labels.
        pooled_output = self.pooler(sequence_output, training=training)
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[1:]

            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "deberta", None) is not None:
            with tf.name_scope(self.deberta.name):
                self.deberta.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.output_dim])
1463
+
1464
+
1465
@add_start_docstrings(
    """
    DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.deberta = TFDebertaMainLayer(config, name="deberta")
        # Plain Keras dropout here (unlike the stable dropout of the
        # sequence-classification head).
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        # Per-token logits: dropout then linear projection over every position.
        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(inputs=sequence_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "deberta", None) is not None:
            with tf.name_scope(self.deberta.name):
                self.deberta.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
1546
+
1547
+
1548
@add_start_docstrings(
    """
    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DEBERTA_START_DOCSTRING,
)
class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config: DebertaConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.deberta = TFDebertaMainLayer(config, name="deberta")
        # Single dense layer producing two logits per token (span start / end).
        self.qa_outputs = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        # Split the 2-unit projection into separate start/end logit tensors of
        # shape (batch, seq_len).
        logits = self.qa_outputs(inputs=sequence_output)
        start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
        start_logits = tf.squeeze(input=start_logits, axis=-1)
        end_logits = tf.squeeze(input=end_logits, axis=-1)
        loss = None

        if start_positions is not None and end_positions is not None:
            # hf_compute_loss expects a dict with these exact keys.
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "deberta", None) is not None:
            with tf.name_scope(self.deberta.name):
                self.deberta.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.hidden_size])
1643
+
1644
+
1645
# Public API of this module, re-exported by the package's __init__.
__all__ = [
    "TFDebertaForMaskedLM",
    "TFDebertaForQuestionAnswering",
    "TFDebertaForSequenceClassification",
    "TFDebertaForTokenClassification",
    "TFDebertaModel",
    "TFDebertaPreTrainedModel",
]
docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization class for model DeBERTa."""
16
+
17
+ import json
18
+ import os
19
+ from typing import List, Optional, Tuple
20
+
21
+ import regex as re
22
+
23
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
24
+ from ...utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
30
+
31
+
32
# Equivalent to transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode.
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by byte-level BPE.

    The BPE merge code operates on unicode strings and cannot handle
    whitespace/control characters, so only "printable" bytes map to their own
    character; every remaining byte is shifted into the range starting at
    U+0100 (256). The result is a one-to-one mapping over all 256 byte values.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    # Bytes outside the printable ranges, in ascending order.
    shifted = [b for b in range(2**8) if b not in printable]
    # Printable bytes keep their own character; shifted bytes get chr(256), chr(257), ...
    chars = [chr(b) for b in printable] + [chr(2**8 + offset) for offset in range(len(shifted))]
    return dict(zip(printable + shifted, chars))
55
+
56
+
57
# Adapted from transformers.models.gpt2.tokenization_gpt2.get_pairs (hardened for short inputs).
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).

    Args:
        word: Tuple of symbol strings.

    Returns:
        `set`: Every adjacent (left, right) symbol pair; empty for words with
        fewer than two symbols.
    """
    # zip pairs each symbol with its successor and stops at the shortest
    # iterable; unlike indexing word[0] up front, this is safe for empty and
    # single-symbol words instead of raising IndexError.
    return set(zip(word, word[1:]))
70
+
71
+
72
class DebertaTokenizer(PreTrainedTokenizer):
    """
    Construct a DeBERTa tokenizer. Based on byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import DebertaTokenizer

    >>> tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    >>> tokenizer("Hello world")["input_ids"]
    [1, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [1, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (Deberta tokenizer detect beginning of words by the preceding space).
        add_bos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
            any other word.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        bos_token="[CLS]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        unk_token="[UNK]",
        pad_token="[PAD]",
        mask_token="[MASK]",
        add_prefix_space=False,
        add_bos_token=False,
        **kwargs,
    ):
        # Normalize plain-string special tokens into AddedToken objects.
        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
        sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token
        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token

        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
        self.add_bos_token = add_bos_token

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # First merges line is a version header ("#version: ..."), last is empty -> drop both.
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        # Lower rank == more frequent pair == merged earlier by `bpe`.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # Unbounded per-instance cache: pre-token string -> space-joined BPE string.
        self.cache = {}
        self.add_prefix_space = add_prefix_space

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            add_bos_token=add_bos_token,
            **kwargs,
        )

    @property
    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.vocab_size
    def vocab_size(self):
        # Size of the base vocabulary (excludes tokens added after init).
        return len(self.encoder)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
    def get_vocab(self):
        # Base vocab merged with any user-added tokens.
        return dict(self.encoder, **self.added_tokens_encoder)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
    def bpe(self, token):
        """Run byte-pair merges on a single pre-token; return its BPE symbols joined by spaces."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        # Greedily merge the lowest-ranked (most frequent) adjacent pair until
        # no remaining pair appears in the merge table.
        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A DeBERTa sequence has the following format:

        - single sequence: [CLS] X [SEP]
        - pair of sequences: [CLS] A [SEP] B [SEP]

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Mask mirrors build_inputs_with_special_tokens: [CLS] ... [SEP] (... [SEP]).
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # NOTE(review): returns None for out-of-vocabulary indices (dict.get default).
        return self.decoder.get(index)

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        text = "".join(tokens)
        # Reverse the byte-level encoding applied in _tokenize.
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write vocab.json and merges.txt into `save_directory`; return their paths."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            # NOTE(review): returns None here despite the Tuple[str] annotation — upstream behavior kept.
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # Merges are written in rank order so the file round-trips through __init__.
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """Optionally prepend a space so the leading word is tokenized like a mid-sentence word."""
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
            text = " " + text
        return (text, kwargs)
394
+
395
+
396
# Public API of this module.
__all__ = ["DebertaTokenizer"]
docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta_fast.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Tokenization class for model DeBERTa."""
16
+
17
+ from typing import List, Optional, Tuple
18
+
19
+ from ...tokenization_utils_base import AddedToken, BatchEncoding
20
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
21
+ from ...utils import logging
22
+ from .tokenization_deberta import DebertaTokenizer
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
28
+
29
+
30
class DebertaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" DeBERTa tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import DebertaTokenizerFast

    >>> tokenizer = DebertaTokenizerFast.from_pretrained("microsoft/deberta-base")
    >>> tokenizer("Hello world")["input_ids"]
    [1, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [1, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
    the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            The path to a tokenizer file to use instead of the vocab file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (Deberta tokenizer detect beginning of words by the preceding space).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
    # Slow-tokenizer class used when converting from a slow checkpoint.
    slow_tokenizer_class = DebertaTokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        errors="replace",
        bos_token="[CLS]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        unk_token="[UNK]",
        pad_token="[PAD]",
        mask_token="[MASK]",
        add_prefix_space=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        # Consume add_bos_token into a local attribute; note it was also forwarded
        # to the base class above via **kwargs.
        self.add_bos_token = kwargs.pop("add_bos_token", False)

    @property
    def mask_token(self) -> str:
        """
        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
        having been set.

        Deberta tokenizer has a special mask token to be used in the fill-mask pipeline. The mask token will greedily
        comprise the space before the *[MASK]*.
        """
        if self._mask_token is None:
            if self.verbose:
                logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        Overriding the default behavior of the mask token to have it eat the space before it.
        """
        # Mask token behave like a normal word, i.e. include the space before it
        # So we set lstrip to True
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A DeBERTa sequence has the following format:

        - single sequence: [CLS] X [SEP]
        - pair of sequences: [CLS] A [SEP] B [SEP]

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._batch_encode_plus
    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)
        # NOTE(review): `assert` is stripped under `python -O`; a ValueError would be
        # more robust, but this matches the GPT-2 fast tokenizer it is copied from.
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._batch_encode_plus(*args, **kwargs)

    # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._encode_plus
    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        # NOTE(review): same `assert`-under-`-O` caveat as _batch_encode_plus.
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._encode_plus(*args, **kwargs)

    # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Delegate saving of the vocab/merges files to the underlying Rust tokenizer model."""
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)
237
+
238
+
239
# Public API of this module.
__all__ = ["DebertaTokenizerFast"]
docs/transformers/build/lib/transformers/models/deberta_v2/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # For static type checkers, expose the real submodule symbols eagerly.
    from .configuration_deberta_v2 import *
    from .modeling_deberta_v2 import *
    from .modeling_tf_deberta_v2 import *
    from .tokenization_deberta_v2 import *
    from .tokenization_deberta_v2_fast import *
else:
    import sys

    # At runtime, replace this module in sys.modules with a lazy proxy so the
    # heavy submodules are only imported on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/deberta_v2/configuration_deberta_v2.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020, Microsoft and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """DeBERTa-v2 model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig
22
+ from ...utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class DebertaV2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DebertaV2Model`]. It is used to instantiate a
    DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the DeBERTa
    [microsoft/deberta-v2-xlarge](https://huggingface.co/microsoft/deberta-v2-xlarge) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 128100):
            Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DebertaV2Model`].
        hidden_size (`int`, *optional*, defaults to 1536):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 24):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"` are
            supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 0):
            The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-7):
            The epsilon used by the layer normalization layers.
        relative_attention (`bool`, *optional*, defaults to `False`):
            Whether use relative position encoding.
        max_relative_positions (`int`, *optional*, defaults to -1):
            The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value
            as `max_position_embeddings`.
        pad_token_id (`int`, *optional*, defaults to 0):
            The value used to pad input_ids.
        position_biased_input (`bool`, *optional*, defaults to `True`):
            Whether add absolute position embedding to content embedding.
        pos_att_type (`List[str]`, *optional*):
            The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]` or
            `["p2c", "c2p"]`. For backwards compatibility a `"|"`-separated string such as `"p2c|c2p"` is also
            accepted and split into a list.
        pooler_dropout (`float`, *optional*, defaults to 0):
            The dropout probability used by the pooler.
        pooler_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the pooler. The pooler's hidden size can also be overridden by passing
            `pooler_hidden_size` as a keyword argument (it defaults to `hidden_size`).
        legacy (`bool`, *optional*, defaults to `True`):
            Whether or not the model should use the legacy `LegacyDebertaOnlyMLMHead`, which does not work properly
            for mask infilling tasks.

    Example:

    ```python
    >>> from transformers import DebertaV2Config, DebertaV2Model

    >>> # Initializing a DeBERTa-v2 microsoft/deberta-v2-xlarge style configuration
    >>> configuration = DebertaV2Config()

    >>> # Initializing a model (with random weights) from the microsoft/deberta-v2-xlarge style configuration
    >>> model = DebertaV2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deberta-v2"

    def __init__(
        self,
        vocab_size=128100,
        hidden_size=1536,
        num_hidden_layers=24,
        num_attention_heads=24,
        intermediate_size=6144,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=0,
        initializer_range=0.02,
        layer_norm_eps=1e-7,
        relative_attention=False,
        max_relative_positions=-1,
        pad_token_id=0,
        position_biased_input=True,
        pos_att_type=None,
        pooler_dropout=0,
        pooler_hidden_act="gelu",
        legacy=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.relative_attention = relative_attention
        self.max_relative_positions = max_relative_positions
        self.pad_token_id = pad_token_id
        self.position_biased_input = position_biased_input

        # Backwards compatibility
        if isinstance(pos_att_type, str):
            pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]

        self.pos_att_type = pos_att_type
        self.vocab_size = vocab_size
        self.layer_norm_eps = layer_norm_eps

        self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
        self.pooler_dropout = pooler_dropout
        self.pooler_hidden_act = pooler_hidden_act
        self.legacy = legacy
159
+
160
+
161
class DebertaV2OnnxConfig(OnnxConfig):
    """ONNX export configuration for DeBERTa-v2 models."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the dynamic-axis spec per input; `token_type_ids` is exported only when `type_vocab_size > 0`."""
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        input_names = ["input_ids", "attention_mask"]
        if self._config.type_vocab_size > 0:
            input_names.append("token_type_ids")
        return OrderedDict((name, dynamic_axis) for name in input_names)

    @property
    def default_onnx_opset(self) -> int:
        # Export with ONNX opset 12.
        return 12

    def generate_dummy_inputs(
        self,
        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
        batch_size: int = -1,
        seq_length: int = -1,
        num_choices: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
        num_channels: int = 3,
        image_width: int = 40,
        image_height: int = 40,
        tokenizer: "PreTrainedTokenizerBase" = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs via the base class, dropping `token_type_ids` when the model does not use them."""
        dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework)
        if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs:
            del dummy_inputs["token_type_ids"]
        return dummy_inputs
196
+
197
+
198
+ __all__ = ["DebertaV2Config", "DebertaV2OnnxConfig"]
docs/transformers/build/lib/transformers/models/deberta_v2/modeling_deberta_v2.py ADDED
@@ -0,0 +1,1523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DeBERTa-v2 model."""
16
+
17
+ from collections.abc import Sequence
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import (
27
+ BaseModelOutput,
28
+ MaskedLMOutput,
29
+ MultipleChoiceModelOutput,
30
+ QuestionAnsweringModelOutput,
31
+ SequenceClassifierOutput,
32
+ TokenClassifierOutput,
33
+ )
34
+ from ...modeling_utils import PreTrainedModel
35
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
36
+ from .configuration_deberta_v2 import DebertaV2Config
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ _CONFIG_FOR_DOC = "DebertaV2Config"
42
+ _CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
43
+ _QA_TARGET_START_INDEX = 2
44
+ _QA_TARGET_END_INDEX = 9
45
+
46
+
47
# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
class DebertaV2SelfOutput(nn.Module):
    """Projects the attention output, applies dropout, then a residual LayerNorm with the input tensor."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dense(hidden_states)
        projected = self.dropout(projected)
        # Residual add followed by layer normalization.
        return self.LayerNorm(projected + input_tensor)
60
+
61
+
62
@torch.jit.script
def make_log_bucket_position(relative_pos, bucket_size: int, max_position: int):
    """Map signed relative positions into log-spaced buckets.

    Positions with absolute value <= mid (mid = bucket_size // 2) are kept exactly;
    larger distances are compressed logarithmically into the remaining bucket range,
    preserving the sign of the original position.
    """
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    # Inside the exact window, substitute (mid - 1) so the log below stays finite;
    # these entries are overwritten with the exact position by the final torch.where.
    abs_pos = torch.where(
        (relative_pos < mid) & (relative_pos > -mid),
        torch.tensor(mid - 1).type_as(relative_pos),
        torch.abs(relative_pos),
    )
    # ceil(log(|pos|/mid) / log((max_position-1)/mid) * (mid-1)) + mid compresses
    # distances in (mid, max_position) onto bucket indices in (mid, bucket_size).
    log_pos = (
        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
    )
    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
    return bucket_pos
76
+
77
+
78
def build_relative_position(query_layer, key_layer, bucket_size: int = -1, max_position: int = -1):
    """
    Build relative position according to the query and key

    We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
    \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
    P_k\\)

    Args:
        query_layer (`torch.Tensor`): tensor whose second-to-last dimension is the query length
        key_layer (`torch.Tensor`): tensor whose second-to-last dimension is the key length
        bucket_size (int): the size of position bucket; log bucketing is applied only when > 0
        max_position (int): the maximum allowed absolute position; log bucketing is applied only when > 0

    Return:
        `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
    """
    query_size = query_layer.size(-2)
    key_size = key_layer.size(-2)

    # Positions are created on the layers' devices.
    q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
    if bucket_size > 0 and max_position > 0:
        # Compress long distances into log-spaced buckets.
        rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
    rel_pos_ids = rel_pos_ids.to(torch.long)
    rel_pos_ids = rel_pos_ids[:query_size, :]
    rel_pos_ids = rel_pos_ids.unsqueeze(0)
    return rel_pos_ids
108
+
109
+
110
@torch.jit.script
# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
    # Broadcast content->position indices to the query layer's leading dimensions for gathering.
    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
114
+
115
+
116
@torch.jit.script
# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
    # Broadcast position->content indices over the key length in the last two dimensions.
    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
120
+
121
+
122
@torch.jit.script
# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
    # Broadcast the position index to p2c attention's batch/head dims and the key length.
    return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
126
+
127
+
128
@torch.jit.script
def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
    # sqrt(head_dim * scale_factor): the denominator used to scale attention logits.
    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
131
+
132
+
133
@torch.jit.script
def build_rpos(query_layer, key_layer, relative_pos, position_buckets: int, max_relative_positions: int):
    """Return relative positions for p2c attention, rebuilding them key-vs-key when query/key lengths differ.

    NOTE(review): the last two parameters are (position_buckets, max_relative_positions), in that order;
    call sites passing them positionally must match this order — verify callers.
    """
    if key_layer.size(-2) != query_layer.size(-2):
        # Query and key lengths differ: recompute relative positions over the key sequence.
        return build_relative_position(
            key_layer,
            key_layer,
            bucket_size=position_buckets,
            max_position=max_relative_positions,
        )
    else:
        return relative_pos
144
+
145
+
146
class DisentangledSelfAttention(nn.Module):
    """
    Disentangled self-attention module

    Parameters:
        config (`DebertaV2Config`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer [`DebertaV2Config`]

    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        _attention_head_size = config.hidden_size // config.num_attention_heads
        # An explicit `attention_head_size` in the config overrides the derived value.
        self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

        self.share_att_key = getattr(config, "share_att_key", False)
        self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
        self.relative_attention = getattr(config, "relative_attention", False)

        if self.relative_attention:
            self.position_buckets = getattr(config, "position_buckets", -1)
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.pos_ebd_size = self.max_relative_positions
            if self.position_buckets > 0:
                self.pos_ebd_size = self.position_buckets

            self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)

            # When the content projections are not shared with the position projections,
            # create dedicated projections for the enabled attention biases.
            if not self.share_att_key:
                if "c2p" in self.pos_att_type:
                    self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
                if "p2c" in self.pos_att_type:
                    self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, attention_heads) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch * heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (attention_heads, -1)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions=False,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
    ):
        """
        Call the module

        Args:
            hidden_states (`torch.FloatTensor`):
                Input states to the module usually the output from previous layer, it will be the Q,K and V in
                *Attention(Q,K,V)*

            attention_mask (`torch.BoolTensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                th token.

            output_attentions (`bool`, *optional*):
                Whether return the attention matrix.

            query_states (`torch.FloatTensor`, *optional*):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`torch.LongTensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
                values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`torch.FloatTensor`):
                The embedding of relative distances. It's a tensor of shape [\\(2 \\times
                \\text{max_relative_positions}\\), *hidden_size*].
        """
        if query_states is None:
            query_states = hidden_states
        query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
        key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
        value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)

        rel_att = None
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The scale grows with each enabled disentangled bias so the combined logits stay normalized.
        scale_factor = 1
        if "c2p" in self.pos_att_type:
            scale_factor += 1
        if "p2c" in self.pos_att_type:
            scale_factor += 1
        scale = scaled_size_sqrt(query_layer, scale_factor)
        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings)
            rel_att = self.disentangled_attention_bias(
                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
            )

        if rel_att is not None:
            attention_scores = attention_scores + rel_att
        attention_scores = attention_scores.view(
            -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
        )

        # Masked positions are filled with the dtype's minimum so softmax assigns them ~0 probability.
        attention_mask = attention_mask.bool()
        attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
        # bsz x height x length x dimension
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        attention_probs = self.dropout(attention_probs)
        context_layer = torch.bmm(
            attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
        )
        # (batch * heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = (
            context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(new_context_layer_shape)
        if not output_attentions:
            return (context_layer, None)
        return (context_layer, attention_probs)

    def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        """Compute the content->position (c2p) and position->content (p2c) attention biases."""
        if relative_pos is None:
            relative_pos = build_relative_position(
                query_layer,
                key_layer,
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
            )
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
            relative_pos = relative_pos.unsqueeze(1)
        # bsz x height x query x key
        elif relative_pos.dim() != 4:
            raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")

        att_span = self.pos_ebd_size
        relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)

        rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
        if self.share_att_key:
            # Reuse the content projections for the relative-position embeddings.
            pos_query_layer = self.transpose_for_scores(
                self.query_proj(rel_embeddings), self.num_attention_heads
            ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
            pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
                query_layer.size(0) // self.num_attention_heads, 1, 1
            )
        else:
            if "c2p" in self.pos_att_type:
                pos_key_layer = self.transpose_for_scores(
                    self.pos_key_proj(rel_embeddings), self.num_attention_heads
                ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
            if "p2c" in self.pos_att_type:
                pos_query_layer = self.transpose_for_scores(
                    self.pos_query_proj(rel_embeddings), self.num_attention_heads
                ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)

        score = 0
        # content->position
        if "c2p" in self.pos_att_type:
            scale = scaled_size_sqrt(pos_key_layer, scale_factor)
            c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
            # Shift signed relative positions into [0, 2 * att_span) gather indices.
            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
            c2p_att = torch.gather(
                c2p_att,
                dim=-1,
                index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
            )
            score += c2p_att / scale.to(dtype=c2p_att.dtype)

        # position->content
        if "p2c" in self.pos_att_type:
            scale = scaled_size_sqrt(pos_query_layer, scale_factor)
            # Fix: build_rpos takes (..., position_buckets, max_relative_positions); the previous
            # call passed these two arguments swapped, corrupting the rebuilt relative positions
            # whenever the query and key sequence lengths differ.
            r_pos = build_rpos(
                query_layer,
                key_layer,
                relative_pos,
                self.position_buckets,
                self.max_relative_positions,
            )
            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
            p2c_att = torch.gather(
                p2c_att,
                dim=-1,
                index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
            ).transpose(-1, -2)
            score += p2c_att / scale.to(dtype=p2c_att.dtype)

        return score
356
+
357
+
358
# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
class DebertaV2Attention(nn.Module):
    """Disentangled self-attention followed by the residual output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = DisentangledSelfAttention(config)
        self.output = DebertaV2SelfOutput(config)
        self.config = config

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions: bool = False,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        self_output, att_matrix = self.self(
            hidden_states,
            attention_mask,
            output_attentions,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        # The residual connection uses the query states when provided, otherwise the input itself.
        residual = hidden_states if query_states is None else query_states
        attention_output = self.output(self_output, residual)
        return (attention_output, att_matrix) if output_attentions else (attention_output, None)
391
+
392
+
393
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
class DebertaV2Intermediate(nn.Module):
    """Feed-forward expansion from hidden_size to intermediate_size followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either an activation name looked up in ACT2FN or a callable used directly.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
407
+
408
+
409
# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
class DebertaV2Output(nn.Module):
    """Feed-forward contraction back to hidden_size, with dropout and a residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Residual add followed by layer normalization.
        return self.LayerNorm(contracted + input_tensor)
423
+
424
+
425
# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
class DebertaV2Layer(nn.Module):
    """One transformer block: disentangled attention plus a feed-forward sub-block, each with residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.attention = DebertaV2Attention(config)
        self.intermediate = DebertaV2Intermediate(config)
        self.output = DebertaV2Output(config)

    def forward(
        self,
        hidden_states,
        attention_mask,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        attention_output, att_matrix = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        # Feed-forward sub-block with its own residual connection.
        layer_output = self.output(self.intermediate(attention_output), attention_output)
        if output_attentions:
            return (layer_output, att_matrix)
        return (layer_output, None)
457
+
458
+
459
class ConvLayer(nn.Module):
    """Depthwise-capable 1D convolution over the token dimension, with activation, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Conv hyper-parameters fall back to defaults when absent from the config.
        kernel_size = getattr(config, "conv_kernel_size", 3)
        groups = getattr(config, "conv_groups", 1)
        self.conv_act = getattr(config, "conv_act", "tanh")
        # `same` padding for odd kernel sizes: output length matches input length.
        self.conv = nn.Conv1d(
            config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
        )
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, residual_states, input_mask):
        # Conv1d expects (batch, channels, seq): permute in, convolve, permute back.
        out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        # NOTE(review): `rmask` is computed from `input_mask` before the `input_mask is None`
        # check below; a None mask would raise here — confirm callers always pass a mask.
        rmask = (1 - input_mask).bool()
        # Zero out convolution outputs at padded positions before the activation.
        out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
        out = ACT2FN[self.conv_act](self.dropout(out))

        layer_norm_input = residual_states + out
        output = self.LayerNorm(layer_norm_input).to(layer_norm_input)

        if input_mask is None:
            output_states = output
        else:
            # Align the mask's rank with the activations: a 4D broadcast mask is
            # squeezed to (batch, seq), then a feature axis is appended.
            if input_mask.dim() != layer_norm_input.dim():
                if input_mask.dim() == 4:
                    input_mask = input_mask.squeeze(1).squeeze(1)
                input_mask = input_mask.unsqueeze(2)

            # Re-apply the mask after LayerNorm so padded positions stay zero.
            input_mask = input_mask.to(output.dtype)
            output_states = output * input_mask

        return output_states
493
+
494
+
495
# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm,Deberta->DebertaV2
class DebertaV2Embeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        pad_token_id = getattr(config, "pad_token_id", 0)
        # Embedding width may differ from hidden_size; embed_proj bridges the gap.
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)

        # When position_biased_input is False, absolute position embeddings are
        # not added to the inputs (relative attention carries position info).
        self.position_biased_input = getattr(config, "position_biased_input", True)
        if not self.position_biased_input:
            self.position_embeddings = None
        else:
            self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

        if config.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
        else:
            self.token_type_embeddings = None

        if self.embedding_size != config.hidden_size:
            self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
        else:
            self.embed_proj = None

        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
        # Exactly one of input_ids / inputs_embeds is expected from callers.
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.position_embeddings is not None:
            position_embeddings = self.position_embeddings(position_ids.long())
        else:
            position_embeddings = torch.zeros_like(inputs_embeds)

        embeddings = inputs_embeds
        if self.position_biased_input:
            embeddings = embeddings + position_embeddings
        if self.token_type_embeddings is not None:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings

        # Project up to hidden_size when the embedding width is smaller.
        if self.embed_proj is not None:
            embeddings = self.embed_proj(embeddings)

        embeddings = self.LayerNorm(embeddings)

        # Zero out embeddings at masked (padded) positions after normalization.
        if mask is not None:
            if mask.dim() != embeddings.dim():
                if mask.dim() == 4:
                    mask = mask.squeeze(1).squeeze(1)
                mask = mask.unsqueeze(2)
            mask = mask.to(embeddings.dtype)

            embeddings = embeddings * mask

        embeddings = self.dropout(embeddings)
        return embeddings
575
+
576
+
577
class DebertaV2Encoder(nn.Module):
    """Modified BertEncoder with relative position bias support"""

    def __init__(self, config):
        super().__init__()

        self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
        self.relative_attention = getattr(config, "relative_attention", False)

        if self.relative_attention:
            # A non-positive value means "fall back to max_position_embeddings".
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings

            self.position_buckets = getattr(config, "position_buckets", -1)
            pos_ebd_size = self.max_relative_positions * 2

            # With bucketed positions, the table only needs 2 * bucket-count rows.
            if self.position_buckets > 0:
                pos_ebd_size = self.position_buckets * 2

            self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)

        # "|"-separated list of normalizations to apply to the relative embeddings.
        self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]

        if "layer_norm" in self.norm_rel_ebd:
            self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)

        # Optional convolution fused into the output of the first layer (see forward).
        self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
        self.gradient_checkpointing = False

    def get_rel_embedding(self):
        # The (optionally layer-normalized) relative embedding table, or None
        # when relative attention is disabled.
        rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
        if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
            rel_embeddings = self.LayerNorm(rel_embeddings)
        return rel_embeddings

    def get_attention_mask(self, attention_mask):
        # Expand a per-token (<=2-D) mask into a pairwise query x key mask;
        # a 3-D mask only gains the head dimension; 4-D passes through.
        if attention_mask.dim() <= 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask.unsqueeze(1)

        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        # Build the relative position matrix lazily, only when relative
        # attention is enabled and the caller did not supply one.
        if self.relative_attention and relative_pos is None:
            if query_states is not None:
                relative_pos = build_relative_position(
                    query_states,
                    hidden_states,
                    bucket_size=self.position_buckets,
                    max_position=self.max_relative_positions,
                )
            else:
                relative_pos = build_relative_position(
                    hidden_states,
                    hidden_states,
                    bucket_size=self.position_buckets,
                    max_position=self.max_relative_positions,
                )
        return relative_pos

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_hidden_states=True,
        output_attentions=False,
        query_states=None,
        relative_pos=None,
        return_dict=True,
    ):
        # Keep a per-token mask around for the conv layer before expanding
        # attention_mask to its pairwise form.
        if attention_mask.dim() <= 2:
            input_mask = attention_mask
        else:
            input_mask = attention_mask.sum(-2) > 0
        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        all_hidden_states: Optional[Tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        next_kv = hidden_states
        # The relative embedding table is shared by every layer.
        rel_embeddings = self.get_rel_embedding()
        for i, layer_module in enumerate(self.layer):
            if self.gradient_checkpointing and self.training:
                output_states, attn_weights = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    next_kv,
                    attention_mask,
                    query_states,
                    relative_pos,
                    rel_embeddings,
                    output_attentions,
                )
            else:
                output_states, attn_weights = layer_module(
                    next_kv,
                    attention_mask,
                    query_states=query_states,
                    relative_pos=relative_pos,
                    rel_embeddings=rel_embeddings,
                    output_attentions=output_attentions,
                )

            if output_attentions:
                all_attentions = all_attentions + (attn_weights,)

            # The optional convolution is mixed into the first layer's output only.
            if i == 0 and self.conv is not None:
                output_states = self.conv(hidden_states, output_states, input_mask)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (output_states,)

            if query_states is not None:
                query_states = output_states
                # NOTE(review): in this branch hidden_states is apparently
                # expected to be a sequence of per-layer states — confirm
                # against the enhanced-mask-decoder caller.
                if isinstance(hidden_states, Sequence):
                    next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
            else:
                next_kv = output_states

        if not return_dict:
            return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
704
+
705
+
706
class DebertaV2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DebertaV2Config
    base_model_prefix = "deberta"
    _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)):
            # Prediction heads keep a free-standing vocab-sized bias.
            module.bias.data.zero_()
734
+
735
+
736
# Shared model docstring prologue consumed by the `add_start_docstrings` decorators below.
DEBERTA_START_DOCSTRING = r"""
    The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
    Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
    on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
    improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.


    Parameters:
        config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
752
+
753
# Shared forward-method docstring template; "{0}" is filled with the input shape string.
DEBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
795
+
796
+
797
@add_start_docstrings(
    "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    DEBERTA_START_DOCSTRING,
)
# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
class DebertaV2Model(DebertaV2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = DebertaV2Embeddings(config)
        self.encoder = DebertaV2Encoder(config)
        # Number of extra "enhanced mask decoder" steps that re-run the last
        # encoder layer; 0 disables the extra loop in forward below.
        self.z_steps = 0
        self.config = config
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError("The prune function is not implemented in DeBERTa model.")

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        # The encoder is always asked for hidden states because the z_steps
        # loop below needs the last two layer outputs.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        # Index 1 is hidden_states both in the plain tuple and in BaseModelOutput.
        encoded_layers = encoder_outputs[1]

        # Enhanced mask decoder: re-run the last encoder layer z_steps times,
        # feeding each result back in as the query states.
        if self.z_steps > 1:
            hidden_states = encoded_layers[-2]
            layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
            query_states = encoded_layers[-1]
            rel_embeddings = self.encoder.get_rel_embedding()
            attention_mask = self.encoder.get_attention_mask(attention_mask)
            rel_pos = self.encoder.get_rel_pos(embedding_output)
            for layer in layers[1:]:
                query_states = layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=False,
                    query_states=query_states,
                    relative_pos=rel_pos,
                    rel_embeddings=rel_embeddings,
                )
                encoded_layers.append(query_states)

        sequence_output = encoded_layers[-1]

        if not return_dict:
            # Drop the hidden-states entry when the caller did not ask for it.
            return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
            attentions=encoder_outputs.attentions,
        )
911
+
912
+
913
# Copied from transformers.models.deberta.modeling_deberta.LegacyDebertaPredictionHeadTransform with Deberta->DebertaV2
class LegacyDebertaV2PredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform preceding the legacy LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
        # hidden_act may be given as a string key into ACT2FN or as a callable.
        self.transform_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
931
+
932
+
933
class LegacyDebertaV2LMPredictionHead(nn.Module):
    """Legacy masked-LM head: transform followed by a vocab-sized decoder."""

    def __init__(self, config):
        super().__init__()
        self.transform = LegacyDebertaV2PredictionHeadTransform(config)
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

        # The decoder weight is tied to the input embeddings elsewhere; only
        # the per-token output bias lives here.
        self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Link the two so the bias is correctly resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
955
+
956
+
957
class LegacyDebertaV2OnlyMLMHead(nn.Module):
    """Thin wrapper exposing the legacy LM prediction head under `predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = LegacyDebertaV2LMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
965
+
966
+
967
class DebertaV2LMPredictionHead(nn.Module):
    """https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270"""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # hidden_act may be a string key into ACT2FN or a callable.
        self.transform_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    # note that the input embeddings must be passed as an argument
    def forward(self, hidden_states, word_embeddings):
        # Scores are computed against the (transposed) input embedding matrix
        # supplied by the caller, plus a free-standing vocab bias.
        transformed = self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
        return torch.matmul(transformed, word_embeddings.weight.t()) + self.bias
990
+
991
+
992
class DebertaV2OnlyMLMHead(nn.Module):
    """Wrapper around `DebertaV2LMPredictionHead` for the non-legacy MLM path."""

    def __init__(self, config):
        super().__init__()
        self.lm_head = DebertaV2LMPredictionHead(config)

    # note that the input embeddings must be passed as an argument
    def forward(self, sequence_output, word_embeddings):
        return self.lm_head(sequence_output, word_embeddings)
1001
+
1002
+
1003
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
    # Tied-weight keys for the legacy head; overridden in __init__ for the new head.
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = r"mask_predictions.*"

    def __init__(self, config):
        super().__init__(config)
        # config.legacy selects between the original decoder head (`cls`) and
        # the newer head that scores against the input word embeddings directly.
        self.legacy = config.legacy
        self.deberta = DebertaV2Model(config)
        if self.legacy:
            self.cls = LegacyDebertaV2OnlyMLMHead(config)
        else:
            self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"]
            self.lm_predictions = DebertaV2OnlyMLMHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        if self.legacy:
            return self.cls.predictions.decoder
        else:
            # NOTE(review): the non-legacy head has no decoder Linear; the
            # dense projection is returned here — confirm resize/tying behavior.
            return self.lm_predictions.lm_head.dense

    def set_output_embeddings(self, new_embeddings):
        if self.legacy:
            self.cls.predictions.decoder = new_embeddings
            self.cls.predictions.bias = new_embeddings.bias
        else:
            self.lm_predictions.lm_head.dense = new_embeddings
            self.lm_predictions.lm_head.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="[MASK]",
    )
    # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        if self.legacy:
            prediction_scores = self.cls(sequence_output)
        else:
            # The new head projects against the input embedding matrix.
            prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1095
+
1096
+
1097
# Copied from transformers.models.deberta.modeling_deberta.ContextPooler
class ContextPooler(nn.Module):
    """Pools a sequence by transforming the first token's hidden state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
        self.dropout = nn.Dropout(config.pooler_dropout)
        self.config = config

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token: dropout -> dense -> activation.
        first_token = hidden_states[:, 0]
        pooled = self.dense(self.dropout(first_token))
        return ACT2FN[self.config.pooler_hidden_act](pooled)

    @property
    def output_dim(self):
        return self.config.hidden_size
1118
+
1119
+
1120
@add_start_docstrings(
    """
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels

        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim

        self.classifier = nn.Linear(output_dim, num_labels)
        # cls_dropout falls back to the generic hidden dropout when unset.
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = nn.Dropout(drop_out)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.deberta.set_input_embeddings(new_embeddings)

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Legacy heuristics (kept for backward compatibility) used
                # when problem_type is not configured explicitly.
                if self.num_labels == 1:
                    # regression task
                    loss_fn = nn.MSELoss()
                    logits = logits.view(-1).to(labels.dtype)
                    loss = loss_fn(logits, labels.view(-1))
                elif labels.dim() == 1 or labels.size(-1) == 1:
                    # Hard labels; rows with a negative label are treated as
                    # unlabeled and excluded from the loss.
                    label_index = (labels >= 0).nonzero()
                    labels = labels.long()
                    if label_index.size(0) > 0:
                        labeled_logits = torch.gather(
                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
                        )
                        labels = torch.gather(labels, 0, label_index.view(-1))
                        loss_fct = CrossEntropyLoss()
                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
                    else:
                        # No labeled rows in the batch: zero loss.
                        loss = torch.tensor(0).to(logits)
                else:
                    # Soft labels: cross-entropy against a label distribution.
                    log_softmax = nn.LogSoftmax(-1)
                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
            elif self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
1237
+
1238
+
1239
@add_start_docstrings(
    """
    DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with a per-token linear classifier on top."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaV2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits: dropout on the sequence output, then the classifier.
        token_states = self.dropout(outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else ((loss,) + output)
1311
+
1312
+
1313
@add_start_docstrings(
    """
    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
    """Extractive QA head: DeBERTa-v2 encoder plus a linear layer that emits span start/end logits."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.deberta = DebertaV2Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=_QA_TARGET_START_INDEX,
        qa_target_end_index=_QA_TARGET_END_INDEX,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project each token to two scores and separate them into start/end logits.
        start_logits, end_logits = self.qa_outputs(encoder_outputs[0]).split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Multi-GPU gathering can add a trailing dimension; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the sequence are clamped to `ignored_index`,
            # which the loss then skips via `ignore_index`.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        tail = (start_logits, end_logits) + encoder_outputs[1:]
        return tail if total_loss is None else (total_loss,) + tail
1411
+
1412
+
1413
@add_start_docstrings(
    """
    DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
    """Multiple-choice head: choices are flattened into the batch dimension, encoded,
    pooled, and scored with a single-logit linear layer; scores are reshaped back to
    `(batch_size, num_choices)` for the cross-entropy loss over choices."""

    def __init__(self, config):
        super().__init__(config)

        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels

        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim

        # One score per choice; the softmax over choices is implicit in the loss.
        self.classifier = nn.Linear(output_dim, 1)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = nn.Dropout(drop_out)

        # Initialize weights and apply final processing. `post_init()` replaces the
        # deprecated `init_weights()` call, matching the other heads in this file.
        self.post_init()

    def get_input_embeddings(self):
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.deberta.set_input_embeddings(new_embeddings)

    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch dimension so the encoder sees
        # shape (batch_size * num_choices, sequence_length).
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.deberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Unfold back to one logit per choice: (batch_size, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1513
+
1514
+
1515
# Names re-exported by `from ... import *`; keeps the module's public API explicit.
__all__ = [
    "DebertaV2ForMaskedLM",
    "DebertaV2ForMultipleChoice",
    "DebertaV2ForQuestionAnswering",
    "DebertaV2ForSequenceClassification",
    "DebertaV2ForTokenClassification",
    "DebertaV2Model",
    "DebertaV2PreTrainedModel",
]
docs/transformers/build/lib/transformers/models/deberta_v2/modeling_tf_deberta_v2.py ADDED
@@ -0,0 +1,1881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 DeBERTa-v2 model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Dict, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import tensorflow as tf
23
+
24
+ from ...activations_tf import get_tf_activation
25
+ from ...modeling_tf_outputs import (
26
+ TFBaseModelOutput,
27
+ TFMaskedLMOutput,
28
+ TFMultipleChoiceModelOutput,
29
+ TFQuestionAnsweringModelOutput,
30
+ TFSequenceClassifierOutput,
31
+ TFTokenClassifierOutput,
32
+ )
33
+ from ...modeling_tf_utils import (
34
+ TFMaskedLanguageModelingLoss,
35
+ TFModelInputType,
36
+ TFMultipleChoiceLoss,
37
+ TFPreTrainedModel,
38
+ TFQuestionAnsweringLoss,
39
+ TFSequenceClassificationLoss,
40
+ TFTokenClassificationLoss,
41
+ get_initializer,
42
+ keras,
43
+ unpack_inputs,
44
+ )
45
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
46
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
47
+ from .configuration_deberta_v2 import DebertaV2Config
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "DebertaV2Config"
53
+ _CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge"
54
+
55
+
56
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaContextPooler with Deberta->DebertaV2
class TFDebertaV2ContextPooler(keras.layers.Layer):
    """Pools a sequence into a single vector: the first token's hidden state is
    dropout-regularized, projected to `config.pooler_hidden_size`, and passed through
    the configured pooler activation."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense")
        self.dropout = TFDebertaV2StableDropout(config.pooler_dropout, name="dropout")
        self.config = config

    def call(self, hidden_states, training: bool = False):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. Note dropout is applied *before* the dense projection.
        context_token = hidden_states[:, 0]
        context_token = self.dropout(context_token, training=training)
        pooled_output = self.dense(context_token)
        pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output)
        return pooled_output

    @property
    def output_dim(self) -> int:
        # NOTE(review): reports `hidden_size` although the dense layer projects to
        # `pooler_hidden_size`; presumably these match in practice — confirm against config.
        return self.config.hidden_size

    def build(self, input_shape=None):
        # Build each sublayer exactly once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.pooler_hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
87
+
88
+
89
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXSoftmax with Deberta->DebertaV2
class TFDebertaV2XSoftmax(keras.layers.Layer):
    """
    Masked Softmax which is optimized for saving memory

    Args:
        input (`tf.Tensor`): The input tensor that will apply softmax.
        mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
        dim (int): The dimension that will apply softmax
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs: tf.Tensor, mask: tf.Tensor):
        # Positions where mask == 0 are excluded: set them to -inf before the
        # softmax, then force their probability back to exactly 0.0 afterwards.
        rmask = tf.logical_not(tf.cast(mask, tf.bool))
        output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)
        # The softmax itself runs in float32 for numerical stability.
        output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)
        output = tf.where(rmask, 0.0, output)
        return output
110
+
111
+
112
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
class TFDebertaV2StableDropout(keras.layers.Layer):
    """
    Optimized dropout module for stabilizing the training

    Args:
        drop_prob (float): the dropout probabilities
    """

    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    @tf.custom_gradient
    def xdropout(self, inputs):
        """
        Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by
        1 / (1 - drop_prob).
        """
        # `mask` is True for elements that are dropped.
        mask = tf.cast(
            1
            - tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
            tf.bool,
        )
        scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)
        if self.drop_prob > 0:
            inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale

        def grad(upstream):
            # The same mask and scale are applied to the incoming gradient, so
            # dropped positions receive no gradient.
            if self.drop_prob > 0:
                return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale
            else:
                return upstream

        return inputs, grad

    def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
        # Dropout is only active in training mode; inference is the identity.
        if training:
            return self.xdropout(inputs)
        return inputs
151
+
152
+
153
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaSelfOutput with Deberta->DebertaV2
class TFDebertaV2SelfOutput(keras.layers.Layer):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(config.hidden_size, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def call(self, hidden_states, input_tensor, training: bool = False):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        # Residual connection with the attention input, then layer norm.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

    def build(self, input_shape=None):
        # Build each sublayer exactly once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
181
+
182
+
183
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaAttention with Deberta->DebertaV2
class TFDebertaV2Attention(keras.layers.Layer):
    """Wraps disentangled self-attention with its output projection/residual block."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.self = TFDebertaV2DisentangledSelfAttention(config, name="self")
        self.dense_output = TFDebertaV2SelfOutput(config, name="output")
        self.config = config

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        rel_embeddings: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        self_outputs = self.self(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
            output_attentions=output_attentions,
            training=training,
        )
        # When no separate query states are provided, the residual connection in
        # the output block uses the layer input itself.
        if query_states is None:
            query_states = input_tensor
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=query_states, training=training
        )

        # Forward any extra outputs (e.g. attention probabilities) unchanged.
        output = (attention_output,) + self_outputs[1:]

        return output

    def build(self, input_shape=None):
        # Build each sublayer exactly once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "self", None) is not None:
            with tf.name_scope(self.self.name):
                self.self.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)
230
+
231
+
232
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaIntermediate with Deberta->DebertaV2
class TFDebertaV2Intermediate(keras.layers.Layer):
    """Feed-forward expansion: dense projection to `intermediate_size` plus activation."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # `hidden_act` may be a string name or an already-resolved callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        # Build the dense sublayer exactly once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
260
+
261
+
262
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOutput with Deberta->DebertaV2
class TFDebertaV2Output(keras.layers.Layer):
    """Feed-forward contraction: dense back to `hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        # Residual connection with the feed-forward input, then layer norm.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        # Build each sublayer exactly once, under its own name scope. The dense
        # layer consumes the `intermediate_size`-wide feed-forward output.
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
294
+
295
+
296
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLayer with Deberta->DebertaV2
class TFDebertaV2Layer(keras.layers.Layer):
    """One transformer block: attention, intermediate (feed-forward up), output (feed-forward down)."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFDebertaV2Attention(config, name="attention")
        self.intermediate = TFDebertaV2Intermediate(config, name="intermediate")
        self.bert_output = TFDebertaV2Output(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        rel_embeddings: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        # Build each sublayer exactly once, under its own name scope.
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "bert_output", None) is not None:
            with tf.name_scope(self.bert_output.name):
                self.bert_output.build(None)
346
+
347
+
348
class TFDebertaV2ConvLayer(keras.layers.Layer):
    """Depth-preserving 1D convolution over the token axis (implemented via `tf.nn.conv2d`),
    combined with the embedding output through dropout, activation, a residual add and LayerNorm."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.kernel_size = getattr(config, "conv_kernel_size", 3)
        # groups = getattr(config, "conv_groups", 1)
        self.conv_act = get_tf_activation(getattr(config, "conv_act", "tanh"))
        # "Same" padding so the sequence length is preserved.
        self.padding = (self.kernel_size - 1) // 2
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def build(self, input_shape=None):
        # The convolution kernel/bias are raw weights (not a keras Conv layer) so the
        # variable names match the PyTorch checkpoint layout.
        if self.built:
            return
        self.built = True
        with tf.name_scope("conv"):
            self.conv_kernel = self.add_weight(
                name="kernel",
                shape=[self.kernel_size, self.config.hidden_size, self.config.hidden_size],
                initializer=get_initializer(self.config.initializer_range),
            )
            self.conv_bias = self.add_weight(
                name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer()
            )
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)

    def call(
        self, hidden_states: tf.Tensor, residual_states: tf.Tensor, input_mask: tf.Tensor, training: bool = False
    ) -> tf.Tensor:
        # 1D convolution expressed as a conv2d over a dummy height-1 dimension.
        out = tf.nn.conv2d(
            tf.expand_dims(hidden_states, 1),
            tf.expand_dims(self.conv_kernel, 0),
            strides=1,
            padding=[[0, 0], [0, 0], [self.padding, self.padding], [0, 0]],
        )
        out = tf.squeeze(tf.nn.bias_add(out, self.conv_bias), 1)
        # Zero out positions masked by the input mask before dropout/activation.
        rmask = tf.cast(1 - input_mask, tf.bool)
        out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
        out = self.dropout(out, training=training)
        out = self.conv_act(out)

        # Residual combination with the embedding output, then layer norm.
        layer_norm_input = residual_states + out
        output = self.LayerNorm(layer_norm_input)

        # NOTE(review): this None check looks unreachable — `1 - input_mask` above
        # already requires a non-None mask; confirm whether a None mask is ever passed.
        if input_mask is None:
            output_states = output
        else:
            if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)):
                # Collapse a 4D extended attention mask down to (batch, seq).
                if len(shape_list(input_mask)) == 4:
                    input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1)
                input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), dtype=self.compute_dtype)

            output_states = output * input_mask

        return output_states
409
+
410
+
411
class TFDebertaV2Encoder(keras.layers.Layer):
    """Stack of DeBERTa-v2 transformer layers.

    Unlike a BERT-style encoder, the relative-position embedding table
    (``rel_embeddings``) is owned by the encoder and handed to every layer, and an
    optional convolution branch (``conv``) is blended with the output of the first
    layer only.
    """

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.layer = [TFDebertaV2Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
        self.relative_attention = getattr(config, "relative_attention", False)
        self.config = config
        if self.relative_attention:
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                # Fall back to the absolute-position budget when unset or invalid.
                self.max_relative_positions = config.max_position_embeddings

            self.position_buckets = getattr(config, "position_buckets", -1)
            self.pos_ebd_size = self.max_relative_positions * 2

            if self.position_buckets > 0:
                # With log-bucketed positions the table only needs 2 * bucket-count rows.
                self.pos_ebd_size = self.position_buckets * 2

        # "norm_rel_ebd" is a "|"-separated option string; only "layer_norm" is acted on below.
        self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]

        if "layer_norm" in self.norm_rel_ebd:
            self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

        self.conv = TFDebertaV2ConvLayer(config, name="conv") if getattr(config, "conv_kernel_size", 0) > 0 else None

    def build(self, input_shape=None):
        """Create the relative-position embedding weight and build sub-layers."""
        if self.built:
            return
        self.built = True
        if self.relative_attention:
            self.rel_embeddings = self.add_weight(
                name="rel_embeddings.weight",
                shape=[self.pos_ebd_size, self.config.hidden_size],
                initializer=get_initializer(self.config.initializer_range),
            )
        if getattr(self, "conv", None) is not None:
            with tf.name_scope(self.conv.name):
                self.conv.build(None)
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, self.config.hidden_size])
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)

    def get_rel_embedding(self):
        # Return the (optionally layer-normalized) relative-position embeddings,
        # or None when relative attention is disabled.
        rel_embeddings = self.rel_embeddings if self.relative_attention else None
        if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
            rel_embeddings = self.LayerNorm(rel_embeddings)
        return rel_embeddings

    def get_attention_mask(self, attention_mask):
        # Expand a [B, N] padding mask into a pairwise [B, 1, N, N] mask where entry
        # [i, j] is 1 only if both token i and token j are attendable; a 3-D mask
        # just gains a broadcastable head axis.
        if len(shape_list(attention_mask)) <= 2:
            extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2)
            attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1)
            attention_mask = tf.cast(attention_mask, tf.uint8)
        elif len(shape_list(attention_mask)) == 3:
            attention_mask = tf.expand_dims(attention_mask, 1)

        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        # Lazily build the relative-position matrix when relative attention is on
        # and the caller did not supply one.
        if self.relative_attention and relative_pos is None:
            q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2]
            relative_pos = build_relative_position(
                q,
                shape_list(hidden_states)[-2],
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
            )
        return relative_pos

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """Run the full layer stack; returns a `TFBaseModelOutput` or tuple per `return_dict`."""
        # Derive a per-token mask for the conv branch: a token counts as present if it
        # can be attended by anything in a higher-rank mask.
        if len(shape_list(attention_mask)) <= 2:
            input_mask = attention_mask
        else:
            input_mask = tf.cast(tf.math.reduce_sum(attention_mask, axis=-2) > 0, dtype=tf.uint8)

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        next_kv = hidden_states

        rel_embeddings = self.get_rel_embedding()
        output_states = next_kv
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (output_states,)

            layer_outputs = layer_module(
                hidden_states=next_kv,
                attention_mask=attention_mask,
                query_states=query_states,
                relative_pos=relative_pos,
                rel_embeddings=rel_embeddings,
                output_attentions=output_attentions,
                training=training,
            )
            output_states = layer_outputs[0]

            # The convolution branch is combined with the output of the first layer only.
            if i == 0 and self.conv is not None:
                output_states = self.conv(hidden_states, output_states, input_mask)

            next_kv = output_states

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (output_states,)

        if not return_dict:
            return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
543
+
544
+
545
def make_log_bucket_position(relative_pos, bucket_size, max_position):
    """Squash signed relative positions into logarithmically spaced buckets.

    Positions within the central linear window keep their value; positions outside it
    are mapped onto log-scaled buckets, preserving their sign.
    """
    half = bucket_size // 2
    direction = tf.math.sign(relative_pos)
    # Clamp in-window magnitudes to half - 1 so the log below stays well defined.
    in_linear_window = (relative_pos < half) & (relative_pos > -half)
    magnitude = tf.where(in_linear_window, half - 1, tf.math.abs(relative_pos))
    # Normalized log distance, scaled to the bucket range (casts keep graph mode happy).
    log_ratio = tf.cast(tf.math.log(magnitude / half), tf.float32) / tf.cast(
        tf.math.log((max_position - 1) / half), tf.float32
    )
    log_bucket = tf.math.ceil(log_ratio * tf.cast(half - 1, tf.float32)) + tf.cast(half, tf.float32)
    signed_log_bucket = log_bucket * tf.cast(direction, tf.float32)
    return tf.cast(
        tf.where(magnitude <= half, tf.cast(relative_pos, tf.float32), signed_log_bucket), tf.int32
    )
558
+
559
+
560
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
    """
    Build relative position according to the query and key

    We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
    \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
    P_k\\)

    Args:
        query_size (int): the length of query
        key_size (int): the length of key
        bucket_size (int): the size of position bucket
        max_position (int): the maximum allowed absolute position

    Return:
        `tf.Tensor`: A tensor with shape [1, query_size, key_size]

    """
    query_positions = tf.range(query_size, dtype=tf.int32)
    key_positions = tf.range(key_size, dtype=tf.int32)
    # Pairwise signed distances: entry [q, k] holds q - k.
    key_grid = tf.tile(tf.expand_dims(key_positions, axis=0), [shape_list(query_positions)[0], 1])
    distances = query_positions[:, None] - key_grid
    if bucket_size > 0 and max_position > 0:
        distances = make_log_bucket_position(distances, bucket_size, max_position)
    # Keep only the query rows and prepend a broadcast batch axis.
    distances = tf.expand_dims(distances[:query_size, :], axis=0)
    return tf.cast(distances, tf.int64)
586
+
587
+
588
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
    """Broadcast content-to-position indices to the query layout.

    Target shape is the first three dims of `query_layer` plus the key axis of
    `relative_pos`.
    """
    target_shape = shape_list(query_layer)[:3] + [shape_list(relative_pos)[-1]]
    return tf.broadcast_to(c2p_pos, target_shape)
596
+
597
+
598
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
    """Broadcast position-to-content indices to a square key-length tail.

    Target shape keeps the leading two dims of `query_layer` and appends a
    key_len x key_len tail taken from `key_layer`.
    """
    key_len = shape_list(key_layer)[-2]
    target_shape = shape_list(query_layer)[:2] + [key_len, key_len]
    return tf.broadcast_to(c2p_pos, target_shape)
606
+
607
+
608
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
    """Broadcast a position index to the attention-score layout of `p2c_att`."""
    # Leading two dims come from the score tensor; trailing dims are the index
    # length and the key length.
    target_shape = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]]
    return tf.broadcast_to(pos_index, target_shape)
611
+
612
+
613
def take_along_axis(x, indices):
    """Gather along the last axis — a port of np.take_along_axis valid only for axis=-1."""
    # TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
    on_tpu = isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy)
    if on_tpu:
        # Express the gather as one-hot selection followed by a contraction, which
        # lowers better on TPU: [B, S, P, D] . [B, S, D] -> [B, S, P]
        # (grossly abusing notation, treating the first two dims as batch).
        selector = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
        return tf.einsum("ijkl,ijl->ijk", selector, x)
    # GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls.
    return tf.gather(x, indices, batch_dims=2)
630
+
631
+
632
class TFDebertaV2DisentangledSelfAttention(keras.layers.Layer):
    """
    Disentangled self-attention module

    Parameters:
        config (`DebertaV2Config`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer [`DebertaV2Config`]

    """

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        _attention_head_size = config.hidden_size // config.num_attention_heads
        self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query_proj = keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query_proj",
            use_bias=True,
        )
        self.key_proj = keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key_proj",
            use_bias=True,
        )
        self.value_proj = keras.layers.Dense(
            self.all_head_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value_proj",
            use_bias=True,
        )

        # When share_att_key is True, the content query/key projections are reused for
        # the positional embeddings, so no dedicated pos projections are created.
        self.share_att_key = getattr(config, "share_att_key", False)
        self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
        self.relative_attention = getattr(config, "relative_attention", False)

        if self.relative_attention:
            self.position_buckets = getattr(config, "position_buckets", -1)
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.pos_ebd_size = self.max_relative_positions
            if self.position_buckets > 0:
                self.pos_ebd_size = self.position_buckets

            self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout")

            if not self.share_att_key:
                # NOTE(review): layer names "pos_proj"/"pos_q_proj" differ from the
                # attribute names — presumably kept to match checkpoint weight names;
                # do not rename without a conversion step.
                if "c2p" in self.pos_att_type:
                    self.pos_key_proj = keras.layers.Dense(
                        self.all_head_size,
                        kernel_initializer=get_initializer(config.initializer_range),
                        name="pos_proj",
                        use_bias=True,
                    )
                if "p2c" in self.pos_att_type:
                    self.pos_query_proj = keras.layers.Dense(
                        self.all_head_size,
                        kernel_initializer=get_initializer(config.initializer_range),
                        name="pos_q_proj",
                    )
        self.softmax = TFDebertaV2XSoftmax(axis=-1)
        self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
        # Split the hidden axis into heads and fold the head axis into the batch axis:
        # [B, S, H*D] -> [B*H, S, D].
        tensor_shape = shape_list(tensor)
        # In graph mode mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None
        shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=shape)
        tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
        x_shape = shape_list(tensor)
        tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]])
        return tensor

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: Optional[tf.Tensor] = None,
        relative_pos: Optional[tf.Tensor] = None,
        rel_embeddings: Optional[tf.Tensor] = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Call the module

        Args:
            hidden_states (`tf.Tensor`):
                Input states to the module usually the output from previous layer, it will be the Q,K and V in
                *Attention(Q,K,V)*

            attention_mask (`tf.Tensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                th token.

            return_att (`bool`, *optional*):
                Whether return the attention matrix.

            query_states (`tf.Tensor`, *optional*):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`tf.Tensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
                values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`tf.Tensor`):
                The embedding of relative distances. It's a tensor of shape [\\(2 \\times
                \\text{max_relative_positions}\\), *hidden_size*].


        """
        if query_states is None:
            query_states = hidden_states
        query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
        key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
        value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)

        rel_att = None
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # scale_factor counts the score components (content-content plus each enabled
        # disentangled term) so the shared sqrt scaling stays comparable.
        scale_factor = 1
        if "c2p" in self.pos_att_type:
            scale_factor += 1
        if "p2c" in self.pos_att_type:
            scale_factor += 1
        scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, dtype=self.compute_dtype))
        attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1]) / scale)
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings)
            rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)

        if rel_att is not None:
            attention_scores = attention_scores + rel_att
        # Unfold the head axis back out of the batch axis for masking/softmax.
        attention_scores = tf.reshape(
            attention_scores,
            (-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
        )

        # bsz x height x length x dimension
        attention_probs = self.softmax(attention_scores, attention_mask)
        attention_probs = self.dropout(attention_probs, training=training)
        context_layer = tf.matmul(
            tf.reshape(attention_probs, [-1, shape_list(attention_probs)[-2], shape_list(attention_probs)[-1]]),
            value_layer,
        )
        context_layer = tf.transpose(
            tf.reshape(
                context_layer,
                [-1, self.num_attention_heads, shape_list(context_layer)[-2], shape_list(context_layer)[-1]],
            ),
            [0, 2, 1, 3],
        )
        # Set the final dimension here explicitly.
        # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
        # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
        # requires final input dimension to be defined
        context_layer_shape = shape_list(context_layer)
        new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
        context_layer = tf.reshape(context_layer, new_context_layer_shape)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        """Compute the content->position (c2p) and position->content (p2c) score terms."""
        if relative_pos is None:
            q = shape_list(query_layer)[-2]
            relative_pos = build_relative_position(
                q,
                shape_list(key_layer)[-2],
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
            )
        # Normalize relative_pos to rank 4: bsz x height x query x key.
        shape_list_pos = shape_list(relative_pos)
        if len(shape_list_pos) == 2:
            relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
        elif len(shape_list_pos) == 3:
            relative_pos = tf.expand_dims(relative_pos, 1)
        # bsz x height x query x key
        elif len(shape_list_pos) != 4:
            raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")

        att_span = self.pos_ebd_size
        rel_embeddings = tf.expand_dims(
            rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :], 0
        )
        # Project the relative embeddings and tile them across the batch (query_layer's
        # leading dim is batch*heads, hence the division by num_attention_heads).
        if self.share_att_key:
            pos_query_layer = tf.tile(
                self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads),
                [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
            )
            pos_key_layer = tf.tile(
                self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads),
                [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
            )
        else:
            if "c2p" in self.pos_att_type:
                pos_key_layer = tf.tile(
                    self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads),
                    [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
                )  # .split(self.all_head_size, dim=-1)
            if "p2c" in self.pos_att_type:
                pos_query_layer = tf.tile(
                    self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads),
                    [shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
                )  # .split(self.all_head_size, dim=-1)

        score = 0
        # content->position
        if "c2p" in self.pos_att_type:
            scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, dtype=self.compute_dtype))
            c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 2, 1]))
            # Shift signed positions into [0, 2*att_span) to index the embedding axis.
            c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1)
            c2p_att = take_along_axis(
                c2p_att,
                tf.broadcast_to(
                    tf.squeeze(c2p_pos, 0),
                    [shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
                ),
            )
            score += c2p_att / scale

        # position->content
        if "p2c" in self.pos_att_type:
            scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype))
            # When query and key lengths differ, rebuild key-to-key relative positions.
            if shape_list(key_layer)[-2] != shape_list(query_layer)[-2]:
                r_pos = build_relative_position(
                    shape_list(key_layer)[-2],
                    shape_list(key_layer)[-2],
                    bucket_size=self.position_buckets,
                    max_position=self.max_relative_positions,
                )
                r_pos = tf.expand_dims(r_pos, 0)
            else:
                r_pos = relative_pos

            # Negated positions: p2c looks up the mirror of the c2p offsets.
            p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)

            p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1]))
            p2c_att = tf.transpose(
                take_along_axis(
                    p2c_att,
                    tf.broadcast_to(
                        tf.squeeze(p2c_pos, 0),
                        [shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]],
                    ),
                ),
                [0, 2, 1],
            )
            score += p2c_att / scale

        return score

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query_proj", None) is not None:
            with tf.name_scope(self.query_proj.name):
                self.query_proj.build([None, None, self.config.hidden_size])
        if getattr(self, "key_proj", None) is not None:
            with tf.name_scope(self.key_proj.name):
                self.key_proj.build([None, None, self.config.hidden_size])
        if getattr(self, "value_proj", None) is not None:
            with tf.name_scope(self.value_proj.name):
                self.value_proj.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "pos_dropout", None) is not None:
            with tf.name_scope(self.pos_dropout.name):
                self.pos_dropout.build(None)
        if getattr(self, "pos_key_proj", None) is not None:
            with tf.name_scope(self.pos_key_proj.name):
                self.pos_key_proj.build([None, None, self.config.hidden_size])
        if getattr(self, "pos_query_proj", None) is not None:
            with tf.name_scope(self.pos_query_proj.name):
                self.pos_query_proj.build([None, None, self.config.hidden_size])
920
+
921
+
922
+ # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaEmbeddings Deberta->DebertaV2
923
class TFDebertaV2Embeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        # When False, positions are encoded only via relative attention, not added here.
        self.position_biased_input = getattr(config, "position_biased_input", True)
        self.initializer_range = config.initializer_range
        if self.embedding_size != config.hidden_size:
            # Projects factorized embeddings up to the model's hidden size.
            self.embed_proj = keras.layers.Dense(
                config.hidden_size,
                kernel_initializer=get_initializer(config.initializer_range),
                name="embed_proj",
                use_bias=False,
            )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")

    def build(self, input_shape=None):
        # Weights are created under explicit name scopes to match checkpoint layout;
        # note they are created even on repeated calls, before the `built` check below.
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            if self.config.type_vocab_size > 0:
                self.token_type_embeddings = self.add_weight(
                    name="embeddings",
                    shape=[self.config.type_vocab_size, self.embedding_size],
                    initializer=get_initializer(self.initializer_range),
                )
            else:
                self.token_type_embeddings = None

        with tf.name_scope("position_embeddings"):
            if self.position_biased_input:
                self.position_embeddings = self.add_weight(
                    name="embeddings",
                    shape=[self.max_position_embeddings, self.hidden_size],
                    initializer=get_initializer(self.initializer_range),
                )
            else:
                self.position_embeddings = None

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "embed_proj", None) is not None:
            with tf.name_scope(self.embed_proj.name):
                self.embed_proj.build([None, None, self.embedding_size])

    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        mask: Optional[tf.Tensor] = None,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

        # Sum the enabled embedding components on top of the word embeddings.
        final_embeddings = inputs_embeds
        if self.position_biased_input:
            position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
            final_embeddings += position_embeds
        if self.config.type_vocab_size > 0:
            token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
            final_embeddings += token_type_embeds

        if self.embedding_size != self.hidden_size:
            final_embeddings = self.embed_proj(final_embeddings)

        final_embeddings = self.LayerNorm(final_embeddings)

        # Zero out embeddings of padded positions; a 4-D attention mask is first
        # squeezed back to a [batch, seq] token mask.
        if mask is not None:
            if len(shape_list(mask)) != len(shape_list(final_embeddings)):
                if len(shape_list(mask)) == 4:
                    mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1)
                mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)

            final_embeddings = final_embeddings * mask

        final_embeddings = self.dropout(final_embeddings, training=training)

        return final_embeddings
1040
+
1041
+
1042
+ # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPredictionHeadTransform with Deberta->DebertaV2
1043
class TFDebertaV2PredictionHeadTransform(keras.layers.Layer):
    """Dense -> activation -> LayerNorm transform applied ahead of the MLM decoder."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

        self.dense = keras.layers.Dense(
            units=self.embedding_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # The config may specify the activation either by name or as a callable.
        self.transform_act_fn = (
            get_tf_activation(config.hidden_act) if isinstance(config.hidden_act, str) else config.hidden_act
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        transformed = self.dense(inputs=hidden_states)
        transformed = self.transform_act_fn(transformed)
        return self.LayerNorm(transformed)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Dense consumes hidden_size features; LayerNorm runs over embedding_size.
        for sublayer, feature_dim in (
            (getattr(self, "dense", None), self.config.hidden_size),
            (getattr(self, "LayerNorm", None), self.embedding_size),
        ):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build([None, None, feature_dim])
1079
+
1080
+
1081
+ # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLMPredictionHead with Deberta->DebertaV2
1082
class TFDebertaV2LMPredictionHead(keras.layers.Layer):
    """Masked-LM prediction head: transform, then decode against the tied input embeddings."""

    def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)

        self.transform = TFDebertaV2PredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape=None):
        # The bias is (re)registered before the `built` short-circuit, mirroring the
        # weight-creation pattern used elsewhere in this file.
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        if self.built:
            return
        self.built = True
        if getattr(self, "transform", None) is not None:
            with tf.name_scope(self.transform.name):
                self.transform.build(None)

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    def set_output_embeddings(self, value: tf.Variable):
        # Re-tie the decoder to new embedding weights and track the new vocab size.
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Project to vocab logits via the transposed (tied) embedding matrix.
        hidden_states = self.transform(hidden_states=hidden_states)
        seq_length = shape_list(hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states
1128
+
1129
+
1130
+ # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOnlyMLMHead with Deberta->DebertaV2
1131
class TFDebertaV2OnlyMLMHead(keras.layers.Layer):
    """Thin wrapper exposing the LM prediction head as a single layer."""

    def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFDebertaV2LMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        # Delegate straight to the LM head; no extra computation happens here.
        return self.predictions(hidden_states=sequence_output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        head = getattr(self, "predictions", None)
        if head is not None:
            with tf.name_scope(head.name):
                head.build(None)
1148
+
1149
+
1150
+ # Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaMainLayer with Deberta->DebertaV2
1151
class TFDebertaV2MainLayer(keras.layers.Layer):
    """Core DeBERTa-v2 model: embeddings followed by the encoder stack."""

    config_class = DebertaV2Config

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFDebertaV2Embeddings(config, name="embeddings")
        self.encoder = TFDebertaV2Encoder(config, name="encoder")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Head pruning is not supported for this model.
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """Embed the inputs and run the encoder; exactly one of input_ids/inputs_embeds is required."""
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Default to attending everything / a single segment when not provided.
        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            mask=attention_mask,
            training=training,
        )

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
1243
+
1244
+
1245
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPreTrainedModel with Deberta->DebertaV2
class TFDebertaV2PreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class consumed by `from_pretrained` / `from_config`.
    config_class = DebertaV2Config
    # Attribute name under which the base (headless) model is stored; also the
    # checkpoint weight-name prefix.
    base_model_prefix = "deberta"
1254
+
1255
+
1256
+ DEBERTA_START_DOCSTRING = r"""
1257
+ The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
1258
+ Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
1259
+ on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
1260
+ improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
1261
+
1262
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
1263
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
1264
+ behavior.
1265
+
1266
+ <Tip>
1267
+
1268
+ TensorFlow models and layers in `transformers` accept two formats as input:
1269
+
1270
+ - having all inputs as keyword arguments (like PyTorch models), or
1271
+ - having all inputs as a list, tuple or dict in the first positional argument.
1272
+
1273
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
1274
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
1275
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
1276
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
1277
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
1278
+ positional argument:
1279
+
1280
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
1281
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
1282
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
1283
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
1284
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
1285
+
1286
+ Note that when creating models and layers with
1287
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
1288
+ about any of this, as you can just pass inputs like you would to any other Python function!
1289
+
1290
+ </Tip>
1291
+
1292
+ Parameters:
1293
+ config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
1294
+ Initializing with a config file does not load the weights associated with the model, only the
1295
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1296
+ """
1297
+
1298
# Shared forward-pass docstring; fixed broken markdown: missing comma and
# stray backtick in the `input_ids` type list, and the double backtick in
# `[`~utils.ModelOutput`]` that broke doc rendering.
DEBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1340
+
1341
+
1342
@add_start_docstrings(
    "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    DEBERTA_START_DOCSTRING,
)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaModel with Deberta->DebertaV2
class TFDebertaV2Model(TFDebertaV2PreTrainedModel):
    # Headless DeBERTa-v2: a thin wrapper that forwards everything to the main layer.
    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.deberta = TFDebertaV2MainLayer(config, name="deberta")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # Pure pass-through: the main layer does all the work.
        return self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        backbone = getattr(self, "deberta", None)
        if backbone is not None:
            with tf.name_scope(backbone.name):
                backbone.build(None)
1393
+
1394
+
1395
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForMaskedLM with Deberta->DebertaV2
class TFDebertaV2ForMaskedLM(TFDebertaV2PreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # MLM requires bidirectional attention; warn if configured as a decoder.
        if config.is_decoder:
            logger.warning(
                "If you want to use `TFDebertaV2ForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
        self.mlm = TFDebertaV2OnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls")

    def get_lm_head(self) -> keras.layers.Layer:
        """Return the masked-LM prediction head."""
        return self.mlm.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden = outputs[0]
        prediction_scores = self.mlm(sequence_output=hidden, training=training)

        loss = None
        if labels is not None:
            loss = self.hf_compute_loss(labels=labels, logits=prediction_scores)

        if return_dict:
            return TFMaskedLMOutput(
                loss=loss,
                logits=prediction_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (prediction_scores,) + outputs[2:]
        return output if loss is None else ((loss,) + output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (getattr(self, "deberta", None), getattr(self, "mlm", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
1475
+
1476
+
1477
@add_start_docstrings(
    """
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForSequenceClassification with Deberta->DebertaV2
class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
        self.pooler = TFDebertaV2ContextPooler(config, name="pooler")

        # Classification-head dropout falls back to the model's hidden dropout.
        cls_drop = getattr(config, "cls_dropout", None)
        if cls_drop is None:
            cls_drop = self.config.hidden_dropout_prob
        self.dropout = TFDebertaV2StableDropout(cls_drop, name="cls_dropout")
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.output_dim = self.pooler.output_dim

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled = self.pooler(outputs[0], training=training)
        pooled = self.dropout(pooled, training=training)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = self.hf_compute_loss(labels=labels, logits=logits)

        if return_dict:
            return TFSequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else ((loss,) + output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (
            getattr(self, "deberta", None),
            getattr(self, "pooler", None),
            getattr(self, "dropout", None),
        ):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                # The classifier consumes the pooler's output dimension.
                self.classifier.build([None, None, self.output_dim])
1575
+
1576
+
1577
@add_start_docstrings(
    """
    DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForTokenClassification with Deberta->DebertaV2
class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Per-token logits: dropout on the final hidden states, then a dense head.
        dropped = self.dropout(outputs[0], training=training)
        logits = self.classifier(inputs=dropped)

        loss = None
        if labels is not None:
            loss = self.hf_compute_loss(labels=labels, logits=logits)

        if return_dict:
            return TFTokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else ((loss,) + output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        backbone = getattr(self, "deberta", None)
        if backbone is not None:
            with tf.name_scope(backbone.name):
                backbone.build(None)
        head = getattr(self, "classifier", None)
        if head is not None:
            with tf.name_scope(head.name):
                head.build([None, None, self.config.hidden_size])
1659
+
1660
+
1661
@add_start_docstrings(
    """
    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DEBERTA_START_DOCSTRING,
)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForQuestionAnswering with Deberta->DebertaV2
class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
        self.qa_outputs = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Project each token to two scores, then split into start/end logits.
        span_logits = self.qa_outputs(inputs=outputs[0])
        start_logits, end_logits = tf.split(value=span_logits, num_or_size_splits=2, axis=-1)
        start_logits = tf.squeeze(input=start_logits, axis=-1)
        end_logits = tf.squeeze(input=end_logits, axis=-1)

        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions, "end_position": end_positions}
            loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

        if return_dict:
            return TFQuestionAnsweringModelOutput(
                loss=loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (start_logits, end_logits) + outputs[2:]
        return output if loss is None else ((loss,) + output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        backbone = getattr(self, "deberta", None)
        if backbone is not None:
            with tf.name_scope(backbone.name):
                backbone.build(None)
        head = getattr(self, "qa_outputs", None)
        if head is not None:
            with tf.name_scope(head.name):
                head.build([None, None, self.config.hidden_size])
1757
+
1758
+
1759
@add_start_docstrings(
    """
    DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    DEBERTA_START_DOCSTRING,
)
class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    # _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
    # _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.deberta = TFDebertaV2MainLayer(config, name="deberta")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.pooler = TFDebertaV2ContextPooler(config, name="pooler")
        self.classifier = keras.layers.Dense(
            units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.output_dim = self.pooler.output_dim

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """
        # Inputs arrive as (batch, num_choices, seq_len[, hidden]); record the
        # choice count, then collapse the first two dims for the backbone.
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        def _flatten(tensor, shape):
            # Collapse (batch, num_choices) leading dims; pass None through untouched.
            return None if tensor is None else tf.reshape(tensor=tensor, shape=shape)

        flat_input_ids = _flatten(input_ids, (-1, seq_length))
        flat_attention_mask = _flatten(attention_mask, (-1, seq_length))
        flat_token_type_ids = _flatten(token_type_ids, (-1, seq_length))
        flat_position_ids = _flatten(position_ids, (-1, seq_length))
        flat_inputs_embeds = (
            _flatten(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.deberta(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled = self.pooler(outputs[0], training=training)
        pooled = self.dropout(pooled, training=training)
        logits = self.classifier(pooled)
        # Un-collapse: one score per choice.
        reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))

        loss = None
        if labels is not None:
            loss = self.hf_compute_loss(labels=labels, logits=reshaped_logits)

        if return_dict:
            return TFMultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (reshaped_logits,) + outputs[2:]
        return output if loss is None else ((loss,) + output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (getattr(self, "deberta", None), getattr(self, "pooler", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.output_dim])
1871
+
1872
+
1873
# Public API of this module (consumed by `from ... import *` and the lazy-import machinery).
__all__ = [
    "TFDebertaV2ForMaskedLM",
    "TFDebertaV2ForQuestionAnswering",
    "TFDebertaV2ForMultipleChoice",
    "TFDebertaV2ForSequenceClassification",
    "TFDebertaV2ForTokenClassification",
    "TFDebertaV2Model",
    "TFDebertaV2PreTrainedModel",
]