Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dev_scripts/dockerci.sh +88 -0
- .gitattributes +1 -0
- .github/workflows/citest.yaml +75 -0
- docs/resources/grpo_code.png +3 -0
- docs/transformers/build/lib/transformers/models/cpm/tokenization_cpm.py +350 -0
- docs/transformers/build/lib/transformers/models/cpmant/modeling_cpmant.py +860 -0
- docs/transformers/build/lib/transformers/models/cpmant/tokenization_cpmant.py +270 -0
- docs/transformers/build/lib/transformers/models/ctrl/configuration_ctrl.py +116 -0
- docs/transformers/build/lib/transformers/models/ctrl/modeling_ctrl.py +844 -0
- docs/transformers/build/lib/transformers/models/ctrl/modeling_tf_ctrl.py +922 -0
- docs/transformers/build/lib/transformers/models/ctrl/tokenization_ctrl.py +251 -0
- docs/transformers/build/lib/transformers/models/cvt/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/cvt/configuration_cvt.py +146 -0
- docs/transformers/build/lib/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py +362 -0
- docs/transformers/build/lib/transformers/models/cvt/modeling_cvt.py +727 -0
- docs/transformers/build/lib/transformers/models/cvt/modeling_tf_cvt.py +1096 -0
- docs/transformers/build/lib/transformers/models/dab_detr/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/dab_detr/configuration_dab_detr.py +260 -0
- docs/transformers/build/lib/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py +233 -0
- docs/transformers/build/lib/transformers/models/dab_detr/modeling_dab_detr.py +1716 -0
- docs/transformers/build/lib/transformers/models/dac/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/dac/configuration_dac.py +114 -0
- docs/transformers/build/lib/transformers/models/dac/convert_dac_checkpoint.py +261 -0
- docs/transformers/build/lib/transformers/models/dac/feature_extraction_dac.py +173 -0
- docs/transformers/build/lib/transformers/models/dac/modeling_dac.py +724 -0
- docs/transformers/build/lib/transformers/models/data2vec/__init__.py +32 -0
- docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_audio.py +288 -0
- docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_text.py +154 -0
- docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_vision.py +194 -0
- docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py +285 -0
- docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py +207 -0
- docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py +374 -0
- docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_audio.py +1746 -0
- docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_text.py +1553 -0
- docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_vision.py +1449 -0
- docs/transformers/build/lib/transformers/models/data2vec/modeling_tf_data2vec_vision.py +1724 -0
- docs/transformers/build/lib/transformers/models/data2vec/modular_data2vec_audio.py +400 -0
- docs/transformers/build/lib/transformers/models/dbrx/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/dbrx/configuration_dbrx.py +232 -0
- docs/transformers/build/lib/transformers/models/dbrx/modeling_dbrx.py +1392 -0
- docs/transformers/build/lib/transformers/models/deberta/__init__.py +30 -0
- docs/transformers/build/lib/transformers/models/deberta/configuration_deberta.py +199 -0
- docs/transformers/build/lib/transformers/models/deberta/modeling_deberta.py +1352 -0
- docs/transformers/build/lib/transformers/models/deberta/modeling_tf_deberta.py +1652 -0
- docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta.py +396 -0
- docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta_fast.py +239 -0
- docs/transformers/build/lib/transformers/models/deberta_v2/__init__.py +30 -0
- docs/transformers/build/lib/transformers/models/deberta_v2/configuration_deberta_v2.py +198 -0
- docs/transformers/build/lib/transformers/models/deberta_v2/modeling_deberta_v2.py +1523 -0
- docs/transformers/build/lib/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +1881 -0
.dev_scripts/dockerci.sh
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
MODELSCOPE_CACHE_DIR_IN_CONTAINER=/modelscope_cache
|
| 3 |
+
CODE_DIR=$PWD
|
| 4 |
+
CODE_DIR_IN_CONTAINER=/ms-swift
|
| 5 |
+
echo "$USER"
|
| 6 |
+
gpus='0,1 2,3'
|
| 7 |
+
cpu_sets='0-15 16-31'
|
| 8 |
+
cpu_sets_arr=($cpu_sets)
|
| 9 |
+
is_get_file_lock=false
|
| 10 |
+
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
|
| 11 |
+
echo "ci command: $CI_COMMAND"
|
| 12 |
+
PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
|
| 13 |
+
echo "PR modified files: $PR_CHANGED_FILES"
|
| 14 |
+
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
|
| 15 |
+
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
|
| 16 |
+
idx=0
|
| 17 |
+
for gpu in $gpus
|
| 18 |
+
do
|
| 19 |
+
exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
|
| 20 |
+
flock -n "$lock_fd" || { echo "WARN: gpu $gpu is in use!" >&2; idx=$((idx+1)); continue; }
|
| 21 |
+
echo "get gpu lock $gpu"
|
| 22 |
+
|
| 23 |
+
CONTAINER_NAME="swift-ci-$idx"
|
| 24 |
+
let is_get_file_lock=true
|
| 25 |
+
|
| 26 |
+
# pull image if there are update
|
| 27 |
+
docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
|
| 28 |
+
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
| 29 |
+
echo 'debugging'
|
| 30 |
+
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
|
| 31 |
+
--cpuset-cpus=${cpu_sets_arr[$idx]} \
|
| 32 |
+
--gpus='"'"device=$gpu"'"' \
|
| 33 |
+
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
|
| 34 |
+
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 35 |
+
-v $MODELSCOPE_HOME_CACHE/$idx:/root \
|
| 36 |
+
-v /home/admin/pre-commit:/home/admin/pre-commit \
|
| 37 |
+
-e CI_TEST=True \
|
| 38 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 39 |
+
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 40 |
+
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
|
| 41 |
+
-e MODELSCOPE_SDK_DEBUG=True \
|
| 42 |
+
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
|
| 43 |
+
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
|
| 44 |
+
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
|
| 45 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 46 |
+
-e MODELSCOPE_ENVIRONMENT='ci' \
|
| 47 |
+
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
|
| 48 |
+
-e MODEL_TAG_URL=$MODEL_TAG_URL \
|
| 49 |
+
-e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
|
| 50 |
+
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
|
| 51 |
+
--workdir=$CODE_DIR_IN_CONTAINER \
|
| 52 |
+
${IMAGE_NAME}:${IMAGE_VERSION} \
|
| 53 |
+
$CI_COMMAND
|
| 54 |
+
else
|
| 55 |
+
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
|
| 56 |
+
--cpuset-cpus=${cpu_sets_arr[$idx]} \
|
| 57 |
+
--gpus='"'"device=$gpu"'"' \
|
| 58 |
+
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
|
| 59 |
+
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 60 |
+
-v $MODELSCOPE_HOME_CACHE/$idx:/root \
|
| 61 |
+
-v /home/admin/pre-commit:/home/admin/pre-commit \
|
| 62 |
+
-e CI_TEST=True \
|
| 63 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 64 |
+
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 65 |
+
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
|
| 66 |
+
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
|
| 67 |
+
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
|
| 68 |
+
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
|
| 69 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 70 |
+
-e MODELSCOPE_ENVIRONMENT='ci' \
|
| 71 |
+
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
|
| 72 |
+
-e MODEL_TAG_URL=$MODEL_TAG_URL \
|
| 73 |
+
-e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
|
| 74 |
+
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
|
| 75 |
+
--workdir=$CODE_DIR_IN_CONTAINER \
|
| 76 |
+
${IMAGE_NAME}:${IMAGE_VERSION} \
|
| 77 |
+
$CI_COMMAND
|
| 78 |
+
fi
|
| 79 |
+
if [ $? -ne 0 ]; then
|
| 80 |
+
echo "Running test case failed, please check the log!"
|
| 81 |
+
exit -1
|
| 82 |
+
fi
|
| 83 |
+
break
|
| 84 |
+
done
|
| 85 |
+
if [ "$is_get_file_lock" = false ] ; then
|
| 86 |
+
echo 'No free GPU!'
|
| 87 |
+
exit 1
|
| 88 |
+
fi
|
.gitattributes
CHANGED
|
@@ -55,3 +55,4 @@ docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
|
|
| 55 |
docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 55 |
docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
docs/resources/grpo_code.png filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/citest.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: citest
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- master
|
| 7 |
+
- "release/**"
|
| 8 |
+
paths-ignore:
|
| 9 |
+
- "setup.*"
|
| 10 |
+
- "requirements.txt"
|
| 11 |
+
- "requirements/**"
|
| 12 |
+
- "docs/**"
|
| 13 |
+
- "tools/**"
|
| 14 |
+
- ".dev_scripts/**"
|
| 15 |
+
- "README.md"
|
| 16 |
+
- "README_*.md"
|
| 17 |
+
- "NOTICE"
|
| 18 |
+
- ".github/workflows/lint.yaml"
|
| 19 |
+
- ".github/workflows/publish.yaml"
|
| 20 |
+
|
| 21 |
+
pull_request:
|
| 22 |
+
paths-ignore:
|
| 23 |
+
- "setup.*"
|
| 24 |
+
- "requirements.txt"
|
| 25 |
+
- "requirements/**"
|
| 26 |
+
- "docs/**"
|
| 27 |
+
- "tools/**"
|
| 28 |
+
- ".dev_scripts/**"
|
| 29 |
+
- "README.md"
|
| 30 |
+
- "README_*.md"
|
| 31 |
+
- "NOTICE"
|
| 32 |
+
- ".github/workflows/lint.yaml"
|
| 33 |
+
- ".github/workflows/publish.yaml"
|
| 34 |
+
|
| 35 |
+
concurrency:
|
| 36 |
+
group: ${{ github.workflow }}-${{ github.ref }}
|
| 37 |
+
cancel-in-progress: true
|
| 38 |
+
|
| 39 |
+
jobs:
|
| 40 |
+
unittest:
|
| 41 |
+
# The type of runner that the job will run on
|
| 42 |
+
runs-on: [self-hosted]
|
| 43 |
+
timeout-minutes: 240
|
| 44 |
+
steps:
|
| 45 |
+
- name: ResetFileMode
|
| 46 |
+
shell: bash
|
| 47 |
+
run: |
|
| 48 |
+
# reset filemode to allow action runner to delete files
|
| 49 |
+
# generated by root in docker
|
| 50 |
+
set -e
|
| 51 |
+
source ~/.bashrc
|
| 52 |
+
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
|
| 53 |
+
|
| 54 |
+
- name: Checkout
|
| 55 |
+
uses: actions/checkout@v3
|
| 56 |
+
with:
|
| 57 |
+
lfs: 'true'
|
| 58 |
+
submodules: 'true'
|
| 59 |
+
fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
|
| 60 |
+
- name: Get changed files
|
| 61 |
+
id: changed-files
|
| 62 |
+
run: |
|
| 63 |
+
if ${{ github.event_name == 'pull_request' }}; then
|
| 64 |
+
echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
|
| 65 |
+
else
|
| 66 |
+
echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
|
| 67 |
+
fi
|
| 68 |
+
- name: Checkout LFS objects
|
| 69 |
+
run: git lfs checkout
|
| 70 |
+
- name: Run unittest
|
| 71 |
+
shell: bash
|
| 72 |
+
run: |
|
| 73 |
+
set -e
|
| 74 |
+
source /mnt/modelscope/ci_env.sh
|
| 75 |
+
bash .dev_scripts/dockerci.sh
|
docs/resources/grpo_code.png
ADDED
|
Git LFS Details
|
docs/transformers/build/lib/transformers/models/cpm/tokenization_cpm.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import unicodedata
|
| 19 |
+
from shutil import copyfile
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import sentencepiece as spm
|
| 23 |
+
|
| 24 |
+
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 25 |
+
from ...utils import SPIECE_UNDERLINE, logging
|
| 26 |
+
from ...utils.import_utils import requires
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@requires(backends=("sentencepiece",))
|
| 35 |
+
class CpmTokenizer(PreTrainedTokenizer):
|
| 36 |
+
"""Runs pre-tokenization with Jieba segmentation tool. It is used in CPM models."""
|
| 37 |
+
|
| 38 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
vocab_file,
|
| 43 |
+
do_lower_case=False,
|
| 44 |
+
remove_space=True,
|
| 45 |
+
keep_accents=False,
|
| 46 |
+
bos_token="<s>",
|
| 47 |
+
eos_token="</s>",
|
| 48 |
+
unk_token="<unk>",
|
| 49 |
+
sep_token="<sep>",
|
| 50 |
+
pad_token="<pad>",
|
| 51 |
+
cls_token="<cls>",
|
| 52 |
+
mask_token="<mask>",
|
| 53 |
+
additional_special_tokens=["<eop>", "<eod>"],
|
| 54 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 55 |
+
**kwargs,
|
| 56 |
+
) -> None:
|
| 57 |
+
"""
|
| 58 |
+
Construct a CPM tokenizer. Based on [Jieba](https://pypi.org/project/jieba/) and
|
| 59 |
+
[SentencePiece](https://github.com/google/sentencepiece).
|
| 60 |
+
|
| 61 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
|
| 62 |
+
refer to this superclass for more information regarding those methods.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
vocab_file (`str`):
|
| 66 |
+
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
|
| 67 |
+
contains the vocabulary necessary to instantiate a tokenizer.
|
| 68 |
+
do_lower_case (`bool`, *optional*, defaults to `True`):
|
| 69 |
+
Whether to lowercase the input when tokenizing.
|
| 70 |
+
remove_space (`bool`, *optional*, defaults to `True`):
|
| 71 |
+
Whether to strip the text when tokenizing (removing excess spaces before and after the string).
|
| 72 |
+
keep_accents (`bool`, *optional*, defaults to `False`):
|
| 73 |
+
Whether to keep accents when tokenizing.
|
| 74 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 75 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier
|
| 76 |
+
token.
|
| 77 |
+
|
| 78 |
+
<Tip>
|
| 79 |
+
|
| 80 |
+
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
| 81 |
+
sequence. The token used is the `cls_token`.
|
| 82 |
+
|
| 83 |
+
</Tip>
|
| 84 |
+
|
| 85 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 86 |
+
The end of sequence token.
|
| 87 |
+
|
| 88 |
+
<Tip>
|
| 89 |
+
|
| 90 |
+
When building a sequence using special tokens, this is not the token that is used for the end of
|
| 91 |
+
sequence. The token used is the `sep_token`.
|
| 92 |
+
|
| 93 |
+
</Tip>
|
| 94 |
+
|
| 95 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 96 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
|
| 97 |
+
this token instead.
|
| 98 |
+
sep_token (`str`, *optional*, defaults to `"<sep>"`):
|
| 99 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 100 |
+
for sequence classification or for a text and a question for question answering. It is also used as the
|
| 101 |
+
last token of a sequence built with special tokens.
|
| 102 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 103 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 104 |
+
cls_token (`str`, *optional*, defaults to `"<cls>"`):
|
| 105 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 106 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 107 |
+
special tokens.
|
| 108 |
+
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
| 109 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 110 |
+
modeling. This is the token which the model will try to predict.
|
| 111 |
+
additional_special_tokens (`List[str]`, *optional*, defaults to `["<eop>", "<eod>"]`):
|
| 112 |
+
Additional special tokens used by the tokenizer.
|
| 113 |
+
|
| 114 |
+
Attributes:
|
| 115 |
+
sp_model (`SentencePieceProcessor`):
|
| 116 |
+
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
|
| 117 |
+
"""
|
| 118 |
+
# Mask token behave like a normal word, i.e. include the space before it
|
| 119 |
+
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
| 120 |
+
|
| 121 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 122 |
+
|
| 123 |
+
self.do_lower_case = do_lower_case
|
| 124 |
+
self.remove_space = remove_space
|
| 125 |
+
self.keep_accents = keep_accents
|
| 126 |
+
self.vocab_file = vocab_file
|
| 127 |
+
|
| 128 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 129 |
+
self.sp_model.Load(vocab_file)
|
| 130 |
+
|
| 131 |
+
try:
|
| 132 |
+
import jieba
|
| 133 |
+
except ModuleNotFoundError as error:
|
| 134 |
+
raise error.__class__(
|
| 135 |
+
"You need to install jieba to use CpmTokenizer or CpmTokenizerFast. "
|
| 136 |
+
"See https://pypi.org/project/jieba/ for installation."
|
| 137 |
+
)
|
| 138 |
+
self.jieba = jieba
|
| 139 |
+
self.translator = str.maketrans(" \n", "\u2582\u2583")
|
| 140 |
+
|
| 141 |
+
super().__init__(
|
| 142 |
+
do_lower_case=do_lower_case,
|
| 143 |
+
remove_space=remove_space,
|
| 144 |
+
keep_accents=keep_accents,
|
| 145 |
+
bos_token=bos_token,
|
| 146 |
+
eos_token=eos_token,
|
| 147 |
+
unk_token=unk_token,
|
| 148 |
+
sep_token=sep_token,
|
| 149 |
+
pad_token=pad_token,
|
| 150 |
+
cls_token=cls_token,
|
| 151 |
+
mask_token=mask_token,
|
| 152 |
+
additional_special_tokens=additional_special_tokens,
|
| 153 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
| 154 |
+
**kwargs,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
self._pad_token_type_id = 3
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.vocab_size
|
| 161 |
+
def vocab_size(self):
|
| 162 |
+
return len(self.sp_model)
|
| 163 |
+
|
| 164 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.get_vocab
|
| 165 |
+
def get_vocab(self):
|
| 166 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 167 |
+
vocab.update(self.added_tokens_encoder)
|
| 168 |
+
return vocab
|
| 169 |
+
|
| 170 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.__getstate__
|
| 171 |
+
def __getstate__(self):
|
| 172 |
+
state = self.__dict__.copy()
|
| 173 |
+
state["sp_model"] = None
|
| 174 |
+
return state
|
| 175 |
+
|
| 176 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.__setstate__
|
| 177 |
+
def __setstate__(self, d):
|
| 178 |
+
self.__dict__ = d
|
| 179 |
+
|
| 180 |
+
# for backward compatibility
|
| 181 |
+
if not hasattr(self, "sp_model_kwargs"):
|
| 182 |
+
self.sp_model_kwargs = {}
|
| 183 |
+
|
| 184 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 185 |
+
self.sp_model.Load(self.vocab_file)
|
| 186 |
+
|
| 187 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.preprocess_text
|
| 188 |
+
def preprocess_text(self, inputs):
|
| 189 |
+
if self.remove_space:
|
| 190 |
+
outputs = " ".join(inputs.strip().split())
|
| 191 |
+
else:
|
| 192 |
+
outputs = inputs
|
| 193 |
+
outputs = outputs.replace("``", '"').replace("''", '"')
|
| 194 |
+
|
| 195 |
+
if not self.keep_accents:
|
| 196 |
+
outputs = unicodedata.normalize("NFKD", outputs)
|
| 197 |
+
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
|
| 198 |
+
if self.do_lower_case:
|
| 199 |
+
outputs = outputs.lower()
|
| 200 |
+
|
| 201 |
+
return outputs
|
| 202 |
+
|
| 203 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._tokenize
|
| 204 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 205 |
+
"""Tokenize a string."""
|
| 206 |
+
text = self.preprocess_text(text)
|
| 207 |
+
pieces = self.sp_model.encode(text, out_type=str)
|
| 208 |
+
new_pieces = []
|
| 209 |
+
for piece in pieces:
|
| 210 |
+
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
|
| 211 |
+
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
|
| 212 |
+
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
|
| 213 |
+
if len(cur_pieces[0]) == 1:
|
| 214 |
+
cur_pieces = cur_pieces[1:]
|
| 215 |
+
else:
|
| 216 |
+
cur_pieces[0] = cur_pieces[0][1:]
|
| 217 |
+
cur_pieces.append(piece[-1])
|
| 218 |
+
new_pieces.extend(cur_pieces)
|
| 219 |
+
else:
|
| 220 |
+
new_pieces.append(piece)
|
| 221 |
+
|
| 222 |
+
return new_pieces
|
| 223 |
+
|
| 224 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._convert_token_to_id
|
| 225 |
+
def _convert_token_to_id(self, token):
|
| 226 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 227 |
+
return self.sp_model.PieceToId(token)
|
| 228 |
+
|
| 229 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._convert_id_to_token
|
| 230 |
+
def _convert_id_to_token(self, index):
|
| 231 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 232 |
+
return self.sp_model.IdToPiece(index)
|
| 233 |
+
|
| 234 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.convert_tokens_to_string
|
| 235 |
+
def convert_tokens_to_string(self, tokens):
|
| 236 |
+
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
| 237 |
+
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
| 238 |
+
return out_string
|
| 239 |
+
|
| 240 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.build_inputs_with_special_tokens
|
| 241 |
+
def build_inputs_with_special_tokens(
|
| 242 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 243 |
+
) -> List[int]:
|
| 244 |
+
"""
|
| 245 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 246 |
+
adding special tokens. An XLNet sequence has the following format:
|
| 247 |
+
|
| 248 |
+
- single sequence: `X <sep> <cls>`
|
| 249 |
+
- pair of sequences: `A <sep> B <sep> <cls>`
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
token_ids_0 (`List[int]`):
|
| 253 |
+
List of IDs to which the special tokens will be added.
|
| 254 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 255 |
+
Optional second list of IDs for sequence pairs.
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 259 |
+
"""
|
| 260 |
+
sep = [self.sep_token_id]
|
| 261 |
+
cls = [self.cls_token_id]
|
| 262 |
+
if token_ids_1 is None:
|
| 263 |
+
return token_ids_0 + sep + cls
|
| 264 |
+
return token_ids_0 + sep + token_ids_1 + sep + cls
|
| 265 |
+
|
| 266 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.get_special_tokens_mask
|
| 267 |
+
def get_special_tokens_mask(
|
| 268 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 269 |
+
) -> List[int]:
|
| 270 |
+
"""
|
| 271 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 272 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
token_ids_0 (`List[int]`):
|
| 276 |
+
List of IDs.
|
| 277 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 278 |
+
Optional second list of IDs for sequence pairs.
|
| 279 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 280 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 281 |
+
|
| 282 |
+
Returns:
|
| 283 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
if already_has_special_tokens:
|
| 287 |
+
return super().get_special_tokens_mask(
|
| 288 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
if token_ids_1 is not None:
|
| 292 |
+
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1]
|
| 293 |
+
return ([0] * len(token_ids_0)) + [1, 1]
|
| 294 |
+
|
| 295 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.create_token_type_ids_from_sequences
|
| 296 |
+
def create_token_type_ids_from_sequences(
|
| 297 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 298 |
+
) -> List[int]:
|
| 299 |
+
"""
|
| 300 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet
|
| 301 |
+
sequence pair mask has the following format:
|
| 302 |
+
|
| 303 |
+
```
|
| 304 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 305 |
+
| first sequence | second sequence |
|
| 306 |
+
```
|
| 307 |
+
|
| 308 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
token_ids_0 (`List[int]`):
|
| 312 |
+
List of IDs.
|
| 313 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 314 |
+
Optional second list of IDs for sequence pairs.
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 318 |
+
"""
|
| 319 |
+
sep = [self.sep_token_id]
|
| 320 |
+
cls_segment_id = [2]
|
| 321 |
+
|
| 322 |
+
if token_ids_1 is None:
|
| 323 |
+
return len(token_ids_0 + sep) * [0] + cls_segment_id
|
| 324 |
+
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
|
| 325 |
+
|
| 326 |
+
# Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.save_vocabulary
|
| 327 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 328 |
+
if not os.path.isdir(save_directory):
|
| 329 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 330 |
+
return
|
| 331 |
+
out_vocab_file = os.path.join(
|
| 332 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 336 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 337 |
+
elif not os.path.isfile(self.vocab_file):
|
| 338 |
+
with open(out_vocab_file, "wb") as fi:
|
| 339 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 340 |
+
fi.write(content_spiece_model)
|
| 341 |
+
|
| 342 |
+
return (out_vocab_file,)
|
| 343 |
+
|
| 344 |
+
def _decode(self, *args, **kwargs):
|
| 345 |
+
text = super()._decode(*args, **kwargs)
|
| 346 |
+
text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
|
| 347 |
+
return text
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
__all__ = ["CpmTokenizer"]
|
docs/transformers/build/lib/transformers/models/cpmant/modeling_cpmant.py
ADDED
|
@@ -0,0 +1,860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch CPMAnt"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import CrossEntropyLoss
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...generation import GenerationMixin
|
| 28 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 29 |
+
from ...modeling_utils import PreTrainedModel
|
| 30 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 31 |
+
from .configuration_cpmant import CpmAntConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
_CHECKPOINT_FOR_DOC = "openbmb/cpm-ant-10b"
|
| 37 |
+
_CONFIG_FOR_DOC = "CpmAntConfig"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CpmAntLayerNorm(nn.Module):
|
| 41 |
+
"""
|
| 42 |
+
We use Root Mean Square (RMS) Layer Normalization, please see https://arxiv.org/abs/1910.07467 for details."
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, config: CpmAntConfig):
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
self.eps = config.eps
|
| 49 |
+
self.dim_norm = config.hidden_size
|
| 50 |
+
self.weight = nn.Parameter(torch.empty(config.hidden_size))
|
| 51 |
+
|
| 52 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 53 |
+
"""
|
| 54 |
+
Args:
|
| 55 |
+
hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
|
| 56 |
+
"""
|
| 57 |
+
if hidden_states.size(-1) != self.dim_norm:
|
| 58 |
+
raise AssertionError("hidden_states.size(-1) != self.dim_norm")
|
| 59 |
+
old_dtype = hidden_states.dtype
|
| 60 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
| 61 |
+
hidden_states = (hidden_states * torch.rsqrt(variance + self.eps)).to(old_dtype) * self.weight
|
| 62 |
+
return hidden_states
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class CpmAntAttention(nn.Module):
|
| 66 |
+
def __init__(self, config: CpmAntConfig):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.dim_model = config.hidden_size
|
| 69 |
+
self.num_heads = config.num_attention_heads
|
| 70 |
+
self.dim_head = config.dim_head
|
| 71 |
+
|
| 72 |
+
self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
|
| 73 |
+
self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
|
| 74 |
+
self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
|
| 75 |
+
|
| 76 |
+
self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)
|
| 77 |
+
|
| 78 |
+
self.softmax = torch.nn.Softmax(dim=-1)
|
| 79 |
+
|
| 80 |
+
if config.dropout_p is not None:
|
| 81 |
+
self.dropout = torch.nn.Dropout(p=config.dropout_p)
|
| 82 |
+
else:
|
| 83 |
+
self.dropout = None
|
| 84 |
+
|
| 85 |
+
def forward(
|
| 86 |
+
self,
|
| 87 |
+
hidden_q: torch.Tensor,
|
| 88 |
+
hidden_kv: torch.Tensor,
|
| 89 |
+
attention_mask: torch.BoolTensor,
|
| 90 |
+
position_bias: torch.Tensor,
|
| 91 |
+
output_attentions: Optional[bool] = False,
|
| 92 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 93 |
+
use_cache: Optional[bool] = None,
|
| 94 |
+
):
|
| 95 |
+
"""
|
| 96 |
+
Args:
|
| 97 |
+
hidden_q (`torch.Tensor`):
|
| 98 |
+
Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
|
| 99 |
+
hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`)):
|
| 100 |
+
Tensor *key_value* and *query* of shape `(batch, len_k, dim_model)`
|
| 101 |
+
attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
| 102 |
+
Avoid invalid areas to participate in the calculation of self-attention.
|
| 103 |
+
position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
| 104 |
+
Provide positional information to self-attention block.
|
| 105 |
+
output_attentions (`bool`, *optional*):
|
| 106 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 107 |
+
past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
|
| 108 |
+
Cached past key and value projection states.
|
| 109 |
+
use_cache (`bool`, *optional*):
|
| 110 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 111 |
+
(see `past_key_values`).
|
| 112 |
+
"""
|
| 113 |
+
batch_size = hidden_q.size(0)
|
| 114 |
+
len_q = hidden_q.size(1)
|
| 115 |
+
len_k = hidden_kv.size(1)
|
| 116 |
+
|
| 117 |
+
query = self.project_q(hidden_q)
|
| 118 |
+
key = self.project_k(hidden_kv)
|
| 119 |
+
value = self.project_v(hidden_kv)
|
| 120 |
+
|
| 121 |
+
query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
| 122 |
+
key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
| 123 |
+
value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
| 124 |
+
|
| 125 |
+
if past_key_values is not None:
|
| 126 |
+
key = torch.cat([past_key_values[0], key], dim=-2)
|
| 127 |
+
value = torch.cat([past_key_values[1], value], dim=-2)
|
| 128 |
+
len_k = key.size(-2)
|
| 129 |
+
|
| 130 |
+
# (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
|
| 131 |
+
score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
|
| 132 |
+
score = score + position_bias
|
| 133 |
+
|
| 134 |
+
score = torch.masked_fill(
|
| 135 |
+
score,
|
| 136 |
+
attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
|
| 137 |
+
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
| 138 |
+
)
|
| 139 |
+
score = self.softmax(score)
|
| 140 |
+
|
| 141 |
+
score = torch.masked_fill(
|
| 142 |
+
score,
|
| 143 |
+
attention_mask.view(batch_size, 1, len_q, len_k) == torch.tensor(False),
|
| 144 |
+
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
| 145 |
+
)
|
| 146 |
+
if output_attentions:
|
| 147 |
+
attn_weights = score
|
| 148 |
+
else:
|
| 149 |
+
attn_weights = None
|
| 150 |
+
|
| 151 |
+
if self.dropout is not None:
|
| 152 |
+
score = self.dropout(score)
|
| 153 |
+
|
| 154 |
+
# (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
|
| 155 |
+
score = torch.matmul(score, value)
|
| 156 |
+
|
| 157 |
+
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
|
| 158 |
+
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
|
| 159 |
+
|
| 160 |
+
score = self.attention_out(score)
|
| 161 |
+
|
| 162 |
+
past_key_values = None
|
| 163 |
+
if use_cache:
|
| 164 |
+
past_key_values = (key, value)
|
| 165 |
+
|
| 166 |
+
return score, attn_weights, past_key_values
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class CpmAntSelfAttentionBlock(nn.Module):
|
| 170 |
+
def __init__(self, config: CpmAntConfig):
|
| 171 |
+
super().__init__()
|
| 172 |
+
self.layernorm_before_attention = CpmAntLayerNorm(config)
|
| 173 |
+
self.self_attention = CpmAntAttention(config)
|
| 174 |
+
if config.dropout_p:
|
| 175 |
+
self.dropout = torch.nn.Dropout(config.dropout_p)
|
| 176 |
+
else:
|
| 177 |
+
self.dropout = None
|
| 178 |
+
|
| 179 |
+
def forward(
|
| 180 |
+
self,
|
| 181 |
+
hidden_states: torch.Tensor,
|
| 182 |
+
attention_mask: torch.Tensor,
|
| 183 |
+
position_bias: Optional[torch.Tensor] = None,
|
| 184 |
+
output_attentions: Optional[bool] = False,
|
| 185 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 186 |
+
use_cache: Optional[bool] = None,
|
| 187 |
+
):
|
| 188 |
+
"""
|
| 189 |
+
Args:
|
| 190 |
+
hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
|
| 191 |
+
Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
|
| 192 |
+
attention_mask (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
| 193 |
+
Avoid invalid areas to participate in the calculation of self-attention.
|
| 194 |
+
position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
|
| 195 |
+
Provide positional information to self-attention block.
|
| 196 |
+
output_attentions (`bool`, *optional*):
|
| 197 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 198 |
+
past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
|
| 199 |
+
Cached past key and value projection states.
|
| 200 |
+
use_cache (`bool`, *optional*):
|
| 201 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 202 |
+
(see `past_key_values`).
|
| 203 |
+
"""
|
| 204 |
+
outputs = self.layernorm_before_attention(hidden_states)
|
| 205 |
+
outputs = self.self_attention(
|
| 206 |
+
outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
outputs, attn_weights, current_key_value = outputs
|
| 210 |
+
|
| 211 |
+
if self.dropout is not None:
|
| 212 |
+
outputs = self.dropout(outputs)
|
| 213 |
+
hidden_states = hidden_states + outputs
|
| 214 |
+
|
| 215 |
+
return hidden_states, attn_weights, current_key_value
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class CpmAntDenseGatedACT(nn.Module):
|
| 219 |
+
def __init__(self, config: CpmAntConfig):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.w_0 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
|
| 222 |
+
self.w_1 = nn.Linear(config.hidden_size, config.dim_ff, bias=False)
|
| 223 |
+
self.act = torch.nn.GELU()
|
| 224 |
+
|
| 225 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 226 |
+
"""Transform an input tensor from one feature space to another via a nonlinear operation
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
|
| 230 |
+
"""
|
| 231 |
+
gate_score = self.act(self.w_0(hidden_states))
|
| 232 |
+
hidden_states = self.w_1(hidden_states)
|
| 233 |
+
|
| 234 |
+
hidden_states = gate_score * hidden_states
|
| 235 |
+
return hidden_states
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class CpmAntFeedForward(nn.Module):
|
| 239 |
+
def __init__(self, config: CpmAntConfig):
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.w_in = CpmAntDenseGatedACT(config)
|
| 242 |
+
if config.dropout_p is not None:
|
| 243 |
+
self.dropout = torch.nn.Dropout(config.dropout_p)
|
| 244 |
+
else:
|
| 245 |
+
self.dropout = None
|
| 246 |
+
|
| 247 |
+
self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)
|
| 248 |
+
|
| 249 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 250 |
+
"""
|
| 251 |
+
Args:
|
| 252 |
+
hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)
|
| 253 |
+
"""
|
| 254 |
+
hidden_states = self.w_in(hidden_states)
|
| 255 |
+
|
| 256 |
+
if self.dropout is not None:
|
| 257 |
+
hidden_states = self.dropout(hidden_states)
|
| 258 |
+
|
| 259 |
+
hidden_states = self.w_out(hidden_states)
|
| 260 |
+
|
| 261 |
+
return hidden_states
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class CpmAntFFNBlock(nn.Module):
|
| 265 |
+
def __init__(self, config: CpmAntConfig):
|
| 266 |
+
super().__init__()
|
| 267 |
+
self.layernorm_before_ffn = CpmAntLayerNorm(config)
|
| 268 |
+
self.ffn = CpmAntFeedForward(config)
|
| 269 |
+
if config.dropout_p:
|
| 270 |
+
self.dropout = torch.nn.Dropout(config.dropout_p)
|
| 271 |
+
else:
|
| 272 |
+
self.dropout = None
|
| 273 |
+
|
| 274 |
+
def forward(
|
| 275 |
+
self,
|
| 276 |
+
hidden_states: torch.Tensor,
|
| 277 |
+
):
|
| 278 |
+
"""
|
| 279 |
+
Args:
|
| 280 |
+
hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
|
| 281 |
+
Hidden states before feed forward layer.
|
| 282 |
+
"""
|
| 283 |
+
ln_outputs = self.layernorm_before_ffn(hidden_states)
|
| 284 |
+
outputs = self.ffn(ln_outputs)
|
| 285 |
+
if self.dropout is not None:
|
| 286 |
+
outputs = self.dropout(outputs)
|
| 287 |
+
hidden_states = hidden_states + outputs
|
| 288 |
+
return hidden_states
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class CpmAntTransformerBlock(nn.Module):
|
| 292 |
+
def __init__(self, config: CpmAntConfig):
|
| 293 |
+
super().__init__()
|
| 294 |
+
self.self_att = CpmAntSelfAttentionBlock(config)
|
| 295 |
+
self.ffn = CpmAntFFNBlock(config)
|
| 296 |
+
|
| 297 |
+
def forward(
|
| 298 |
+
self,
|
| 299 |
+
hidden_states: torch.Tensor,
|
| 300 |
+
attention_mask: torch.Tensor,
|
| 301 |
+
position_bias: Optional[torch.Tensor] = None,
|
| 302 |
+
output_attentions: Optional[bool] = False,
|
| 303 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 304 |
+
use_cache: Optional[bool] = None,
|
| 305 |
+
):
|
| 306 |
+
"""
|
| 307 |
+
Args:
|
| 308 |
+
hidden_states (`torch.Tensor`):
|
| 309 |
+
Input to the layer of shape `(batch, seq_len, dim_model)`
|
| 310 |
+
attention_mask (`torch.Tensor`):
|
| 311 |
+
Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
|
| 312 |
+
position_bias (`torch.Tensor`):
|
| 313 |
+
Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
|
| 314 |
+
output_attentions (`bool`, *optional*):
|
| 315 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 316 |
+
past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
|
| 317 |
+
Cached past key and value projection states
|
| 318 |
+
use_cache (`bool`, *optional*):
|
| 319 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 320 |
+
(see `past_key_values`).
|
| 321 |
+
"""
|
| 322 |
+
hidden_states = self.self_att(
|
| 323 |
+
hidden_states,
|
| 324 |
+
attention_mask=attention_mask,
|
| 325 |
+
position_bias=position_bias,
|
| 326 |
+
output_attentions=output_attentions,
|
| 327 |
+
past_key_values=past_key_values,
|
| 328 |
+
use_cache=use_cache,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
hidden_states, attn_weights, current_key_value = hidden_states
|
| 332 |
+
|
| 333 |
+
hidden_states = self.ffn(hidden_states)
|
| 334 |
+
|
| 335 |
+
return hidden_states, attn_weights, current_key_value
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class CpmAntEncoder(nn.Module):
|
| 339 |
+
def __init__(self, config: CpmAntConfig):
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.num_layers = config.num_hidden_layers
|
| 342 |
+
self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for ith in range(self.num_layers)])
|
| 343 |
+
|
| 344 |
+
self.output_layernorm = CpmAntLayerNorm(config)
|
| 345 |
+
|
| 346 |
+
def forward(
|
| 347 |
+
self,
|
| 348 |
+
hidden_states: torch.Tensor,
|
| 349 |
+
attention_mask: torch.Tensor,
|
| 350 |
+
position_bias: torch.Tensor,
|
| 351 |
+
output_attentions: Optional[bool] = None,
|
| 352 |
+
output_hidden_states: Optional[bool] = None,
|
| 353 |
+
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 354 |
+
use_cache: Optional[bool] = None,
|
| 355 |
+
):
|
| 356 |
+
"""
|
| 357 |
+
Args:
|
| 358 |
+
hidden_states (`torch.Tensor`):
|
| 359 |
+
Input to the layer of shape `(batch, seq_len, dim_model)`
|
| 360 |
+
attention_mask (`torch.Tensor`):
|
| 361 |
+
Avoid invalid areas to participate in the calculation of shape `(batch, seq_len, seq_len)`
|
| 362 |
+
position_bias (`torch.Tensor`):
|
| 363 |
+
Provides position information to attention mechanism of shape `(num_heads, seq_len, seq_len)`
|
| 364 |
+
output_attentions (`bool`, *optional*):
|
| 365 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 366 |
+
output_hidden_states (`bool`, *optional*):
|
| 367 |
+
Whether or not to return the hidden states of all layers.
|
| 368 |
+
past_key_values (`Tuple[torch.Tensor, torch.Tensor])`, *optional*):
|
| 369 |
+
Cached past key and value projection states
|
| 370 |
+
use_cache (`bool`, *optional*):
|
| 371 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 372 |
+
(see `past_key_values`).
|
| 373 |
+
"""
|
| 374 |
+
all_hidden_states = () if output_hidden_states else None
|
| 375 |
+
all_self_attns = () if output_attentions else None
|
| 376 |
+
current_key_values = () if use_cache else None
|
| 377 |
+
|
| 378 |
+
for i, layer in enumerate(self.layers):
|
| 379 |
+
if output_hidden_states:
|
| 380 |
+
all_hidden_states += (hidden_states,)
|
| 381 |
+
layer_outputs = layer(
|
| 382 |
+
hidden_states,
|
| 383 |
+
attention_mask,
|
| 384 |
+
position_bias,
|
| 385 |
+
output_attentions=output_attentions,
|
| 386 |
+
past_key_values=past_key_values[i] if past_key_values else None,
|
| 387 |
+
use_cache=use_cache,
|
| 388 |
+
)
|
| 389 |
+
hidden_states, attn_weights, current_key_value = layer_outputs
|
| 390 |
+
if output_attentions:
|
| 391 |
+
all_self_attns += (attn_weights,)
|
| 392 |
+
if current_key_value is not None:
|
| 393 |
+
current_key_values = current_key_values + (current_key_value,)
|
| 394 |
+
|
| 395 |
+
hidden_states = self.output_layernorm(hidden_states)
|
| 396 |
+
|
| 397 |
+
if output_hidden_states:
|
| 398 |
+
all_hidden_states += (hidden_states,)
|
| 399 |
+
|
| 400 |
+
return hidden_states, current_key_values, all_hidden_states, all_self_attns
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
|
| 404 |
+
class CpmAntIntermediate(nn.Module):
|
| 405 |
+
def __init__(self, config):
|
| 406 |
+
super().__init__()
|
| 407 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 408 |
+
if isinstance(config.hidden_act, str):
|
| 409 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 410 |
+
else:
|
| 411 |
+
self.intermediate_act_fn = config.hidden_act
|
| 412 |
+
|
| 413 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 414 |
+
hidden_states = self.dense(hidden_states)
|
| 415 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 416 |
+
return hidden_states
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class CpmAntSegmentPositionEmbedding(nn.Module):
|
| 420 |
+
def __init__(self, config: CpmAntConfig):
|
| 421 |
+
super().__init__()
|
| 422 |
+
|
| 423 |
+
self.num_heads = config.num_attention_heads
|
| 424 |
+
self.num_buckets = config.position_bias_num_buckets
|
| 425 |
+
self.max_distance = config.position_bias_max_distance
|
| 426 |
+
self.num_segments = config.segment_types
|
| 427 |
+
|
| 428 |
+
self.relative_attention_bias = nn.Parameter(
|
| 429 |
+
torch.empty(
|
| 430 |
+
config.segment_types * config.segment_types + config.position_bias_num_buckets,
|
| 431 |
+
config.num_attention_heads,
|
| 432 |
+
)
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
def forward(
|
| 436 |
+
self,
|
| 437 |
+
key_pos: torch.Tensor,
|
| 438 |
+
query_pos: torch.Tensor,
|
| 439 |
+
key_segment: torch.Tensor,
|
| 440 |
+
query_segment: torch.Tensor,
|
| 441 |
+
):
|
| 442 |
+
with torch.no_grad():
|
| 443 |
+
batch = key_pos.size(0)
|
| 444 |
+
keylen = key_pos.size(1)
|
| 445 |
+
querylen = query_pos.size(1)
|
| 446 |
+
|
| 447 |
+
if key_pos.size(0) != query_pos.size(0):
|
| 448 |
+
raise AssertionError(
|
| 449 |
+
f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
|
| 450 |
+
)
|
| 451 |
+
if keylen != key_segment.size(1) or querylen != query_segment.size(1):
|
| 452 |
+
raise AssertionError(
|
| 453 |
+
f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
|
| 454 |
+
)
|
| 455 |
+
if querylen != query_segment.size(1):
|
| 456 |
+
raise AssertionError(
|
| 457 |
+
f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.szie(1)}!"
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
key_pos = key_pos.view(batch, -1, keylen)
|
| 461 |
+
query_pos = query_pos.view(batch, querylen, -1)
|
| 462 |
+
key_segment = key_segment.view(batch, -1, keylen)
|
| 463 |
+
query_segment = query_segment.view(batch, querylen, -1)
|
| 464 |
+
|
| 465 |
+
relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
|
| 466 |
+
relative_position_bucket = relative_position_bucket + self.num_buckets
|
| 467 |
+
|
| 468 |
+
# (batch, len_q, len_k)
|
| 469 |
+
absolute_position_bucket = self._position_bucket(
|
| 470 |
+
torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
|
| 471 |
+
- torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
|
| 472 |
+
num_buckets=self.num_buckets,
|
| 473 |
+
max_distance=self.max_distance,
|
| 474 |
+
)
|
| 475 |
+
relative_position_bucket = torch.where(
|
| 476 |
+
(key_segment == query_segment),
|
| 477 |
+
absolute_position_bucket[None, :, :],
|
| 478 |
+
relative_position_bucket,
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# (batch, len_q, len_k, num_heads)
|
| 482 |
+
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
|
| 483 |
+
# (batch, num_heads, len_q, len_k)
|
| 484 |
+
embeds = embeds.permute(0, 3, 1, 2).contiguous()
|
| 485 |
+
return embeds
|
| 486 |
+
|
| 487 |
+
def _segment_relative_position_bucket(self, query_segment, key_segment):
|
| 488 |
+
return query_segment * self.num_segments + key_segment
|
| 489 |
+
|
| 490 |
+
def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
|
| 491 |
+
relative_buckets = 0
|
| 492 |
+
# always bidirectional in CPMAnt
|
| 493 |
+
num_buckets //= 2
|
| 494 |
+
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
|
| 495 |
+
relative_position = torch.abs(relative_position)
|
| 496 |
+
max_exact = num_buckets // 2
|
| 497 |
+
is_small = relative_position < max_exact
|
| 498 |
+
relative_postion_if_large = max_exact + (
|
| 499 |
+
torch.log(relative_position.float() / max_exact)
|
| 500 |
+
/ math.log(max_distance / max_exact)
|
| 501 |
+
* (num_buckets - max_exact)
|
| 502 |
+
).to(torch.int32)
|
| 503 |
+
relative_postion_if_large = torch.min(
|
| 504 |
+
relative_postion_if_large,
|
| 505 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
| 506 |
+
)
|
| 507 |
+
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
|
| 508 |
+
return relative_buckets
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
|
| 512 |
+
class CpmAntOutput(nn.Module):
|
| 513 |
+
def __init__(self, config):
|
| 514 |
+
super().__init__()
|
| 515 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 516 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 517 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 518 |
+
|
| 519 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 520 |
+
hidden_states = self.dense(hidden_states)
|
| 521 |
+
hidden_states = self.dropout(hidden_states)
|
| 522 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 523 |
+
return hidden_states
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class CpmAntPreTrainedModel(PreTrainedModel):
|
| 527 |
+
"""
|
| 528 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 529 |
+
models.
|
| 530 |
+
"""
|
| 531 |
+
|
| 532 |
+
config_class = CpmAntConfig
|
| 533 |
+
base_model_prefix = "cpmant"
|
| 534 |
+
|
| 535 |
+
def _init_weights(self, module):
|
| 536 |
+
"""Initialize the weights"""
|
| 537 |
+
if isinstance(module, nn.Linear):
|
| 538 |
+
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
| 539 |
+
if module.bias is not None:
|
| 540 |
+
module.bias.data.zero_()
|
| 541 |
+
elif isinstance(module, nn.Embedding):
|
| 542 |
+
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
| 543 |
+
if module.padding_idx is not None:
|
| 544 |
+
module.weight.data[module.padding_idx].zero_()
|
| 545 |
+
elif isinstance(module, nn.LayerNorm):
|
| 546 |
+
module.bias.data.zero_()
|
| 547 |
+
module.weight.data.fill_(1.0)
|
| 548 |
+
elif isinstance(module, CpmAntLayerNorm):
|
| 549 |
+
module.weight.data.fill_(1.0)
|
| 550 |
+
elif isinstance(module, CpmAntSegmentPositionEmbedding):
|
| 551 |
+
module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
CPMANT_START_DOCSTRING = r"""
|
| 555 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 556 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 557 |
+
behavior.
|
| 558 |
+
|
| 559 |
+
Parameters
|
| 560 |
+
config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the
|
| 561 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 562 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
CPMANT_INPUTS_DOCSTRING = r"""
|
| 566 |
+
Args:
|
| 567 |
+
input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
| 568 |
+
Indices of input sequence tokens in the vocabulary.
|
| 569 |
+
|
| 570 |
+
Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 571 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 572 |
+
|
| 573 |
+
[What are input IDs?](../glossary#input-ids)
|
| 574 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 575 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 576 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
| 577 |
+
use_cache (`bool`, *optional*):
|
| 578 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 579 |
+
`past_key_values`).
|
| 580 |
+
output_attentions (`bool`, *optional*):
|
| 581 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 582 |
+
output_hidden_states (`bool`, *optional*):
|
| 583 |
+
Whether or not to return the hidden states of all layers.
|
| 584 |
+
return_dict (`bool`, *optional*):
|
| 585 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 586 |
+
"""
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@add_start_docstrings(
|
| 590 |
+
"The bare CPMAnt Model outputting raw hidden-states without any specific head on top.",
|
| 591 |
+
CPMANT_START_DOCSTRING,
|
| 592 |
+
)
|
| 593 |
+
class CpmAntModel(CpmAntPreTrainedModel):
|
| 594 |
+
def __init__(self, config: CpmAntConfig):
|
| 595 |
+
super().__init__(config)
|
| 596 |
+
self.encoder = CpmAntEncoder(config)
|
| 597 |
+
self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
|
| 598 |
+
self.input_embedding = nn.Embedding(
|
| 599 |
+
config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
|
| 600 |
+
)
|
| 601 |
+
self.position_bias = CpmAntSegmentPositionEmbedding(config)
|
| 602 |
+
self.prompt_length = config.prompt_length
|
| 603 |
+
self.vocab_size = config.vocab_size
|
| 604 |
+
|
| 605 |
+
self.post_init()
|
| 606 |
+
|
| 607 |
+
def get_input_embeddings(self):
|
| 608 |
+
return self.input_embedding
|
| 609 |
+
|
| 610 |
+
def set_input_embeddings(self, embeddings, **kwargs):
|
| 611 |
+
self.input_embedding = embeddings
|
| 612 |
+
|
| 613 |
+
def _prepare_attention_mask(self, input_ids, span, context, length):
|
| 614 |
+
batch = input_ids.size(0)
|
| 615 |
+
seqlen = input_ids.size(1)
|
| 616 |
+
device = input_ids.device
|
| 617 |
+
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
| 618 |
+
attention_mask = context[:, None, :] | (
|
| 619 |
+
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
| 620 |
+
)
|
| 621 |
+
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
| 622 |
+
# mask for left padding
|
| 623 |
+
mask_1d = (
|
| 624 |
+
torch.tensor(list(range(seqlen - self.prompt_length))[::-1], device=device)[None, :].repeat(batch, 1)
|
| 625 |
+
< length[:, None]
|
| 626 |
+
)
|
| 627 |
+
mask_1d = torch.cat((torch.ones(batch, self.prompt_length, device=device).bool(), mask_1d), dim=1)
|
| 628 |
+
attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
|
| 629 |
+
return attention_mask
|
| 630 |
+
|
| 631 |
+
@add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING)
|
| 632 |
+
@add_code_sample_docstrings(
|
| 633 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 634 |
+
output_type=BaseModelOutputWithPast,
|
| 635 |
+
config_class=_CONFIG_FOR_DOC,
|
| 636 |
+
)
|
| 637 |
+
def forward(
|
| 638 |
+
self,
|
| 639 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 640 |
+
output_attentions: Optional[bool] = None,
|
| 641 |
+
output_hidden_states: Optional[bool] = None,
|
| 642 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 643 |
+
use_cache: Optional[bool] = None,
|
| 644 |
+
return_dict: Optional[bool] = None,
|
| 645 |
+
**kwargs,
|
| 646 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
| 647 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 648 |
+
output_hidden_states = (
|
| 649 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 650 |
+
)
|
| 651 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 652 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 653 |
+
|
| 654 |
+
# add prompts ahead
|
| 655 |
+
if input_ids.dtype != torch.int32:
|
| 656 |
+
input_ids = input_ids.to(torch.int32)
|
| 657 |
+
dtype, device = input_ids.dtype, input_ids.device
|
| 658 |
+
segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
|
| 659 |
+
length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
|
| 660 |
+
input_ids = torch.cat(
|
| 661 |
+
(
|
| 662 |
+
torch.arange(
|
| 663 |
+
self.prompt_length * 2 + self.vocab_size,
|
| 664 |
+
self.prompt_length * 3 + self.vocab_size,
|
| 665 |
+
dtype=dtype,
|
| 666 |
+
device=device,
|
| 667 |
+
).repeat(input_ids.size(0), 1),
|
| 668 |
+
input_ids,
|
| 669 |
+
),
|
| 670 |
+
dim=1,
|
| 671 |
+
)
|
| 672 |
+
batch, seq_length = input_ids.size()
|
| 673 |
+
segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
|
| 674 |
+
context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
|
| 675 |
+
position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
|
| 676 |
+
span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)
|
| 677 |
+
|
| 678 |
+
if past_key_values is None:
|
| 679 |
+
past_length = 0
|
| 680 |
+
past_key_values = tuple([None] * self.encoder.num_layers)
|
| 681 |
+
input_ids = input_ids.contiguous()
|
| 682 |
+
hidden_states = self.input_embedding(input_ids)
|
| 683 |
+
segment_states = self.segment_embedding(segment)
|
| 684 |
+
hidden_states = hidden_states + segment_states
|
| 685 |
+
else:
|
| 686 |
+
past_length = past_key_values[0][0].size(-2)
|
| 687 |
+
segment_states = self.segment_embedding(segment)
|
| 688 |
+
hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :]
|
| 689 |
+
|
| 690 |
+
attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
|
| 691 |
+
position_bias = self.position_bias(position, position, segment, segment)
|
| 692 |
+
|
| 693 |
+
attention_mask = attention_mask[:, past_length:, :]
|
| 694 |
+
position_bias = position_bias[:, :, past_length:, :]
|
| 695 |
+
hidden_states = hidden_states[:, past_length:, :]
|
| 696 |
+
|
| 697 |
+
hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
|
| 698 |
+
hidden_states,
|
| 699 |
+
attention_mask,
|
| 700 |
+
position_bias,
|
| 701 |
+
output_attentions,
|
| 702 |
+
output_hidden_states,
|
| 703 |
+
past_key_values,
|
| 704 |
+
use_cache,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
if past_length == 0:
|
| 708 |
+
hidden_states = hidden_states[:, self.prompt_length :, :]
|
| 709 |
+
# drop the prompt
|
| 710 |
+
if all_attentions is not None:
|
| 711 |
+
new_attentions = ()
|
| 712 |
+
for attention in all_attentions:
|
| 713 |
+
new_attentions += (attention[:, :, self.prompt_length :, self.prompt_length :],)
|
| 714 |
+
all_attentions = new_attentions
|
| 715 |
+
if all_hidden_states is not None:
|
| 716 |
+
new_hidden_states = ()
|
| 717 |
+
for hidden_state in all_hidden_states:
|
| 718 |
+
new_hidden_states += (hidden_state[:, self.prompt_length :, :],)
|
| 719 |
+
all_hidden_states = new_hidden_states
|
| 720 |
+
|
| 721 |
+
if not return_dict:
|
| 722 |
+
return tuple(
|
| 723 |
+
v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
return BaseModelOutputWithPast(
|
| 727 |
+
last_hidden_state=hidden_states,
|
| 728 |
+
past_key_values=present_key_values,
|
| 729 |
+
hidden_states=all_hidden_states,
|
| 730 |
+
attentions=all_attentions,
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
@add_start_docstrings(
|
| 735 |
+
"""
|
| 736 |
+
The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
| 737 |
+
""",
|
| 738 |
+
CPMANT_START_DOCSTRING,
|
| 739 |
+
)
|
| 740 |
+
class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
|
| 741 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 742 |
+
|
| 743 |
+
def __init__(self, config: CpmAntConfig):
|
| 744 |
+
super().__init__(config)
|
| 745 |
+
self.cpmant = CpmAntModel(config)
|
| 746 |
+
|
| 747 |
+
# lm_head.weight is tied to cpmant.input_embedding.weight
|
| 748 |
+
self.lm_head = nn.Linear(
|
| 749 |
+
config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
|
| 750 |
+
)
|
| 751 |
+
self.post_init()
|
| 752 |
+
|
| 753 |
+
@add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING)
|
| 754 |
+
@add_code_sample_docstrings(
|
| 755 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 756 |
+
output_type=CausalLMOutputWithPast,
|
| 757 |
+
config_class=_CONFIG_FOR_DOC,
|
| 758 |
+
)
|
| 759 |
+
def forward(
|
| 760 |
+
self,
|
| 761 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 762 |
+
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 763 |
+
use_cache: Optional[bool] = None,
|
| 764 |
+
output_attentions: Optional[bool] = None,
|
| 765 |
+
output_hidden_states: Optional[bool] = None,
|
| 766 |
+
labels: Optional[torch.Tensor] = None,
|
| 767 |
+
return_dict: Optional[bool] = None,
|
| 768 |
+
attention_mask: Optional[torch.Tensor] = None, # dummy parameter for text-generation pipeline
|
| 769 |
+
**kwargs,
|
| 770 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 771 |
+
r"""
|
| 772 |
+
Args:
|
| 773 |
+
input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
| 774 |
+
Indices of input sequence tokens in the vocabulary.
|
| 775 |
+
|
| 776 |
+
Indices can be obtained using [`CPMAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 777 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 778 |
+
|
| 779 |
+
[What are input IDs?](../glossary#input-ids)
|
| 780 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 781 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
| 782 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
| 783 |
+
use_cache (`bool`, *optional*):
|
| 784 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 785 |
+
(see `past_key_values`).
|
| 786 |
+
output_attentions (`bool`, *optional*):
|
| 787 |
+
Whether or not to return the attentions tensors of all attention layers.
|
| 788 |
+
output_hidden_states (`bool`, *optional*):
|
| 789 |
+
Whether or not to return the hidden states of all layers.
|
| 790 |
+
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 791 |
+
Labels for computing the masked language modeling loss.
|
| 792 |
+
return_dict (`bool`, *optional*):
|
| 793 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 794 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 795 |
+
CPMAnt will process attention mask automatically, this parameter is a dummy parameter for
|
| 796 |
+
text-generation pipeline.
|
| 797 |
+
|
| 798 |
+
Example:
|
| 799 |
+
|
| 800 |
+
Text Generation with CpmAntForCausalLM.
|
| 801 |
+
```python
|
| 802 |
+
>>> from transformers import CPMAntTokenizer, CpmAntForCausalLM
|
| 803 |
+
|
| 804 |
+
>>> texts = "今天天气不错,"
|
| 805 |
+
>>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
|
| 806 |
+
>>> tokenizer = CPMAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
|
| 807 |
+
>>> input_ids = tokenizer(texts, return_tensors="pt")
|
| 808 |
+
>>> outputs = model.generate(**input_ids)
|
| 809 |
+
>>> output_texts = tokenizer.batch_decode(outputs)
|
| 810 |
+
>>> print(output_texts)
|
| 811 |
+
['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
|
| 812 |
+
```
|
| 813 |
+
"""
|
| 814 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 815 |
+
|
| 816 |
+
model_output = self.cpmant(
|
| 817 |
+
input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict
|
| 818 |
+
)
|
| 819 |
+
hidden_states = model_output.last_hidden_state if return_dict else model_output[0]
|
| 820 |
+
|
| 821 |
+
logits = self.lm_head(hidden_states)
|
| 822 |
+
|
| 823 |
+
loss = None
|
| 824 |
+
if labels is not None:
|
| 825 |
+
loss_func = CrossEntropyLoss()
|
| 826 |
+
loss = loss_func(logits.view(-1, logits.size(-1)), labels.view(-1))
|
| 827 |
+
|
| 828 |
+
if not return_dict:
|
| 829 |
+
output = (logits,) + model_output[1:]
|
| 830 |
+
return ((loss,) + output) if loss is not None else output
|
| 831 |
+
|
| 832 |
+
return CausalLMOutputWithPast(
|
| 833 |
+
loss=loss,
|
| 834 |
+
logits=logits,
|
| 835 |
+
past_key_values=model_output.past_key_values,
|
| 836 |
+
hidden_states=model_output.hidden_states,
|
| 837 |
+
attentions=model_output.attentions,
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
def get_input_embeddings(self):
|
| 841 |
+
return self.cpmant.input_embedding
|
| 842 |
+
|
| 843 |
+
def set_input_embeddings(self, embeddings):
|
| 844 |
+
self.cpmant.input_embedding = embeddings
|
| 845 |
+
|
| 846 |
+
def get_output_embeddings(self):
|
| 847 |
+
return self.lm_head
|
| 848 |
+
|
| 849 |
+
def set_output_embeddings(self, new_embeddings):
|
| 850 |
+
self.lm_head = new_embeddings
|
| 851 |
+
|
| 852 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
| 853 |
+
past_key_values = [list(each) if each is not None else each for each in past_key_values]
|
| 854 |
+
for key_value_layer in past_key_values:
|
| 855 |
+
key_value_layer[0] = key_value_layer[0][beam_idx]
|
| 856 |
+
key_value_layer[1] = key_value_layer[1][beam_idx]
|
| 857 |
+
return past_key_values
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
__all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/cpmant/tokenization_cpmant.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for CPMAnt."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import os
|
| 19 |
+
from typing import List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
from transformers.utils import is_jieba_available, requires_backends
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if is_jieba_available():
|
| 25 |
+
import jieba
|
| 26 |
+
|
| 27 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 28 |
+
from ...utils import logging
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_vocab(vocab_file):
|
| 37 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 38 |
+
vocab = collections.OrderedDict()
|
| 39 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 40 |
+
tokens = reader.readlines()
|
| 41 |
+
for index, token in enumerate(tokens):
|
| 42 |
+
token = token.rstrip("\n")
|
| 43 |
+
vocab[token] = index
|
| 44 |
+
return vocab
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class WordpieceTokenizer:
|
| 48 |
+
def __init__(self, vocab, unk_token="<unk>", max_input_chars_per_word=200):
|
| 49 |
+
self.vocab = vocab
|
| 50 |
+
self.unk_token = unk_token
|
| 51 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
| 52 |
+
|
| 53 |
+
def tokenize(self, token):
|
| 54 |
+
chars = list(token)
|
| 55 |
+
if len(chars) > self.max_input_chars_per_word:
|
| 56 |
+
return [self.unk_token]
|
| 57 |
+
|
| 58 |
+
start = 0
|
| 59 |
+
sub_tokens = []
|
| 60 |
+
while start < len(chars):
|
| 61 |
+
end = len(chars)
|
| 62 |
+
cur_substr = None
|
| 63 |
+
while start < end:
|
| 64 |
+
substr = "".join(chars[start:end])
|
| 65 |
+
if substr in self.vocab:
|
| 66 |
+
cur_substr = substr
|
| 67 |
+
break
|
| 68 |
+
end -= 1
|
| 69 |
+
if cur_substr is None:
|
| 70 |
+
sub_tokens.append(self.unk_token)
|
| 71 |
+
start += 1
|
| 72 |
+
else:
|
| 73 |
+
sub_tokens.append(cur_substr)
|
| 74 |
+
start = end
|
| 75 |
+
|
| 76 |
+
return sub_tokens
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CpmAntTokenizer(PreTrainedTokenizer):
|
| 80 |
+
"""
|
| 81 |
+
Construct a CPMAnt tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
vocab_file (`str`):
|
| 85 |
+
Path to the vocabulary file.
|
| 86 |
+
bod_token (`str`, *optional*, defaults to `"<d>"`):
|
| 87 |
+
The beginning of document token.
|
| 88 |
+
eod_token (`str`, *optional*, defaults to `"</d>"`):
|
| 89 |
+
The end of document token.
|
| 90 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 91 |
+
The beginning of sequence token.
|
| 92 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 93 |
+
The end of sequence token.
|
| 94 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 95 |
+
The token used for padding.
|
| 96 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 97 |
+
The unknown token.
|
| 98 |
+
line_token (`str`, *optional*, defaults to `"</n>"`):
|
| 99 |
+
The line token.
|
| 100 |
+
space_token (`str`, *optional*, defaults to `"</_>"`):
|
| 101 |
+
The space token.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 105 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 106 |
+
add_prefix_space = False
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
vocab_file,
|
| 111 |
+
bod_token="<d>",
|
| 112 |
+
eod_token="</d>",
|
| 113 |
+
bos_token="<s>",
|
| 114 |
+
eos_token="</s>",
|
| 115 |
+
pad_token="<pad>",
|
| 116 |
+
unk_token="<unk>",
|
| 117 |
+
line_token="</n>",
|
| 118 |
+
space_token="</_>",
|
| 119 |
+
padding_side="left",
|
| 120 |
+
**kwargs,
|
| 121 |
+
):
|
| 122 |
+
requires_backends(self, ["jieba"])
|
| 123 |
+
self.bod_token = bod_token
|
| 124 |
+
self.eod_token = eod_token
|
| 125 |
+
self.encoder = load_vocab(vocab_file)
|
| 126 |
+
self.encoder[" "] = self.encoder[space_token]
|
| 127 |
+
self.encoder["\n"] = self.encoder[line_token]
|
| 128 |
+
|
| 129 |
+
del self.encoder[space_token]
|
| 130 |
+
del self.encoder[line_token]
|
| 131 |
+
|
| 132 |
+
self.encoder = collections.OrderedDict(sorted(self.encoder.items(), key=lambda x: x[1]))
|
| 133 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 134 |
+
|
| 135 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder, unk_token=unk_token)
|
| 136 |
+
|
| 137 |
+
super().__init__(
|
| 138 |
+
bod_token=bod_token,
|
| 139 |
+
eod_token=eod_token,
|
| 140 |
+
bos_token=bos_token,
|
| 141 |
+
eos_token=eos_token,
|
| 142 |
+
pad_token=pad_token,
|
| 143 |
+
unk_token=unk_token,
|
| 144 |
+
line_token=line_token,
|
| 145 |
+
space_token=space_token,
|
| 146 |
+
padding_side=padding_side,
|
| 147 |
+
**kwargs,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def bod_token_id(self):
|
| 152 |
+
return self.encoder[self.bod_token]
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def eod_token_id(self):
|
| 156 |
+
return self.encoder[self.eod_token]
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
def newline_id(self):
|
| 160 |
+
return self.encoder["\n"]
|
| 161 |
+
|
| 162 |
+
@property
|
| 163 |
+
def vocab_size(self) -> int:
|
| 164 |
+
return len(self.encoder)
|
| 165 |
+
|
| 166 |
+
def get_vocab(self):
|
| 167 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
| 168 |
+
|
| 169 |
+
def _tokenize(self, text):
|
| 170 |
+
"""Tokenize a string."""
|
| 171 |
+
output_tokens = []
|
| 172 |
+
for x in jieba.cut(text, cut_all=False):
|
| 173 |
+
output_tokens.extend(self.wordpiece_tokenizer.tokenize(x))
|
| 174 |
+
return output_tokens
|
| 175 |
+
|
| 176 |
+
def _decode(self, token_ids, **kwargs):
|
| 177 |
+
"""Decode ids into a string."""
|
| 178 |
+
token_ids = [i for i in token_ids if i >= 0]
|
| 179 |
+
token_ids = [
|
| 180 |
+
x for x in token_ids if x != self.pad_token_id and x != self.eos_token_id and x != self.bos_token_id
|
| 181 |
+
]
|
| 182 |
+
return super()._decode(token_ids, **kwargs)
|
| 183 |
+
|
| 184 |
+
def check(self, token):
|
| 185 |
+
return token in self.encoder
|
| 186 |
+
|
| 187 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 188 |
+
return "".join(tokens)
|
| 189 |
+
|
| 190 |
+
def _convert_token_to_id(self, token):
|
| 191 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 192 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
| 193 |
+
|
| 194 |
+
def _convert_id_to_token(self, index):
|
| 195 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 196 |
+
return self.decoder.get(index, self.unk_token)
|
| 197 |
+
|
| 198 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 199 |
+
if os.path.isdir(save_directory):
|
| 200 |
+
vocab_file = os.path.join(
|
| 201 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 202 |
+
)
|
| 203 |
+
else:
|
| 204 |
+
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
| 205 |
+
index = 0
|
| 206 |
+
if " " in self.encoder:
|
| 207 |
+
self.encoder["</_>"] = self.encoder[" "]
|
| 208 |
+
del self.encoder[" "]
|
| 209 |
+
if "\n" in self.encoder:
|
| 210 |
+
self.encoder["</n>"] = self.encoder["\n"]
|
| 211 |
+
del self.encoder["\n"]
|
| 212 |
+
self.encoder = collections.OrderedDict(sorted(self.encoder.items(), key=lambda x: x[1]))
|
| 213 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 214 |
+
for token, token_index in self.encoder.items():
|
| 215 |
+
if index != token_index:
|
| 216 |
+
logger.warning(
|
| 217 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
| 218 |
+
" Please check that the vocabulary is not corrupted!"
|
| 219 |
+
)
|
| 220 |
+
index = token_index
|
| 221 |
+
writer.write(token + "\n")
|
| 222 |
+
index += 1
|
| 223 |
+
return (vocab_file,)
|
| 224 |
+
|
| 225 |
+
def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: List[int] = None) -> List[int]:
|
| 226 |
+
"""
|
| 227 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 228 |
+
adding special tokens. A CPMAnt sequence has the following format:
|
| 229 |
+
|
| 230 |
+
- single sequence: `[BOS] Sequence`.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
token_ids_0 (`List[int]`): The first tokenized sequence that special tokens will be added.
|
| 234 |
+
token_ids_1 (`List[int]`): The optional second tokenized sequence that special tokens will be added.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
`List[int]`: The model input with special tokens.
|
| 238 |
+
"""
|
| 239 |
+
if token_ids_1 is None:
|
| 240 |
+
return [self.bos_token_id] + token_ids_0
|
| 241 |
+
return [self.bos_token_id] + token_ids_0 + [self.bos_token_id] + token_ids_1
|
| 242 |
+
|
| 243 |
+
def get_special_tokens_mask(
|
| 244 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 245 |
+
) -> List[int]:
|
| 246 |
+
"""
|
| 247 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 248 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
token_ids_0 (`List[int]`): List of IDs.
|
| 252 |
+
token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs.
|
| 253 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 254 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
if already_has_special_tokens:
|
| 261 |
+
return super().get_special_tokens_mask(
|
| 262 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if token_ids_1 is not None:
|
| 266 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
|
| 267 |
+
return [1] + ([0] * len(token_ids_0))
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
__all__ = ["CpmAntTokenizer"]
|
docs/transformers/build/lib/transformers/models/ctrl/configuration_ctrl.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 Salesforce and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Salesforce CTRL configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CTRLConfig(PretrainedConfig):
|
| 25 |
+
"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`CTRLModel`] or a [`TFCTRLModel`]. It is used to
|
| 27 |
+
instantiate a CTRL model according to the specified arguments, defining the model architecture. Instantiating a
|
| 28 |
+
configuration with the defaults will yield a similar configuration to that of the
|
| 29 |
+
[Salesforce/ctrl](https://huggingface.co/Salesforce/ctrl) architecture from SalesForce.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
vocab_size (`int`, *optional*, defaults to 246534):
|
| 36 |
+
Vocabulary size of the CTRL model. Defines the number of different tokens that can be represented by the
|
| 37 |
+
`inputs_ids` passed when calling [`CTRLModel`] or [`TFCTRLModel`].
|
| 38 |
+
n_positions (`int`, *optional*, defaults to 256):
|
| 39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 41 |
+
n_embd (`int`, *optional*, defaults to 1280):
|
| 42 |
+
Dimensionality of the embeddings and hidden states.
|
| 43 |
+
dff (`int`, *optional*, defaults to 8192):
|
| 44 |
+
Dimensionality of the inner dimension of the feed forward networks (FFN).
|
| 45 |
+
n_layer (`int`, *optional*, defaults to 48):
|
| 46 |
+
Number of hidden layers in the Transformer encoder.
|
| 47 |
+
n_head (`int`, *optional*, defaults to 16):
|
| 48 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 49 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
| 50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 51 |
+
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
| 52 |
+
The dropout ratio for the embeddings.
|
| 53 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-06):
|
| 54 |
+
The epsilon to use in the layer normalization layers
|
| 55 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 56 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 57 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 58 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
Examples:
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
>>> from transformers import CTRLConfig, CTRLModel
|
| 65 |
+
|
| 66 |
+
>>> # Initializing a CTRL configuration
|
| 67 |
+
>>> configuration = CTRLConfig()
|
| 68 |
+
|
| 69 |
+
>>> # Initializing a model (with random weights) from the configuration
|
| 70 |
+
>>> model = CTRLModel(configuration)
|
| 71 |
+
|
| 72 |
+
>>> # Accessing the model configuration
|
| 73 |
+
>>> configuration = model.config
|
| 74 |
+
```"""
|
| 75 |
+
|
| 76 |
+
model_type = "ctrl"
|
| 77 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 78 |
+
attribute_map = {
|
| 79 |
+
"max_position_embeddings": "n_positions",
|
| 80 |
+
"hidden_size": "n_embd",
|
| 81 |
+
"num_attention_heads": "n_head",
|
| 82 |
+
"num_hidden_layers": "n_layer",
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
vocab_size=246534,
|
| 88 |
+
n_positions=256,
|
| 89 |
+
n_embd=1280,
|
| 90 |
+
dff=8192,
|
| 91 |
+
n_layer=48,
|
| 92 |
+
n_head=16,
|
| 93 |
+
resid_pdrop=0.1,
|
| 94 |
+
embd_pdrop=0.1,
|
| 95 |
+
layer_norm_epsilon=1e-6,
|
| 96 |
+
initializer_range=0.02,
|
| 97 |
+
use_cache=True,
|
| 98 |
+
**kwargs,
|
| 99 |
+
):
|
| 100 |
+
self.vocab_size = vocab_size
|
| 101 |
+
self.n_positions = n_positions
|
| 102 |
+
self.n_embd = n_embd
|
| 103 |
+
self.n_layer = n_layer
|
| 104 |
+
self.n_head = n_head
|
| 105 |
+
self.dff = dff
|
| 106 |
+
self.resid_pdrop = resid_pdrop
|
| 107 |
+
self.embd_pdrop = embd_pdrop
|
| 108 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
| 109 |
+
self.initializer_range = initializer_range
|
| 110 |
+
|
| 111 |
+
self.use_cache = use_cache
|
| 112 |
+
|
| 113 |
+
super().__init__(**kwargs)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
__all__ = ["CTRLConfig"]
|
docs/transformers/build/lib/transformers/models/ctrl/modeling_ctrl.py
ADDED
|
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 Salesforce and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch CTRL model."""
|
| 17 |
+
|
| 18 |
+
from typing import Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
+
|
| 25 |
+
from ...generation import GenerationMixin
|
| 26 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
|
| 27 |
+
from ...modeling_utils import PreTrainedModel
|
| 28 |
+
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
|
| 29 |
+
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
| 30 |
+
from .configuration_ctrl import CTRLConfig
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
_CONFIG_FOR_DOC = "CTRLConfig"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def angle_defn(pos, i, d_model_size):
|
| 39 |
+
angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
|
| 40 |
+
return pos * angle_rates
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def positional_encoding(position, d_model_size, dtype):
|
| 44 |
+
# create the sinusoidal pattern for the positional encoding
|
| 45 |
+
angle_rads = angle_defn(
|
| 46 |
+
torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1),
|
| 47 |
+
torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0),
|
| 48 |
+
d_model_size,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
sines = torch.sin(angle_rads[:, 0::2])
|
| 52 |
+
cosines = torch.cos(angle_rads[:, 1::2])
|
| 53 |
+
|
| 54 |
+
pos_encoding = torch.cat([sines, cosines], dim=-1)
|
| 55 |
+
return pos_encoding
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
|
| 59 |
+
# calculate attention
|
| 60 |
+
matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
|
| 61 |
+
|
| 62 |
+
dk = k.shape[-1]
|
| 63 |
+
scaled_attention_logits = matmul_qk / np.sqrt(dk)
|
| 64 |
+
|
| 65 |
+
if mask is not None:
|
| 66 |
+
nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
|
| 67 |
+
scaled_attention_logits += mask[ns - nd : ns, :ns] * -1e4
|
| 68 |
+
|
| 69 |
+
if attention_mask is not None:
|
| 70 |
+
# Apply the attention mask
|
| 71 |
+
scaled_attention_logits = scaled_attention_logits + attention_mask
|
| 72 |
+
|
| 73 |
+
attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
|
| 74 |
+
|
| 75 |
+
# Mask heads if we want to
|
| 76 |
+
if head_mask is not None:
|
| 77 |
+
attention_weights = attention_weights * head_mask
|
| 78 |
+
|
| 79 |
+
output = torch.matmul(attention_weights, v)
|
| 80 |
+
|
| 81 |
+
return output, attention_weights
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultiHeadAttention(nn.Module):
|
| 85 |
+
def __init__(self, d_model_size, num_heads):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.num_heads = num_heads
|
| 88 |
+
self.d_model_size = d_model_size
|
| 89 |
+
|
| 90 |
+
self.depth = int(d_model_size / self.num_heads)
|
| 91 |
+
|
| 92 |
+
self.Wq = nn.Linear(d_model_size, d_model_size)
|
| 93 |
+
self.Wk = nn.Linear(d_model_size, d_model_size)
|
| 94 |
+
self.Wv = nn.Linear(d_model_size, d_model_size)
|
| 95 |
+
|
| 96 |
+
self.dense = nn.Linear(d_model_size, d_model_size)
|
| 97 |
+
self.pruned_heads = set()
|
| 98 |
+
|
| 99 |
+
def prune_heads(self, heads):
|
| 100 |
+
attention_head_size = self.d_model_size // self.num_heads
|
| 101 |
+
if len(heads) == 0:
|
| 102 |
+
return
|
| 103 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, attention_head_size, self.pruned_heads)
|
| 104 |
+
|
| 105 |
+
# Prune linear layers
|
| 106 |
+
self.Wq = prune_linear_layer(self.Wq, index)
|
| 107 |
+
self.Wk = prune_linear_layer(self.Wk, index)
|
| 108 |
+
self.Wv = prune_linear_layer(self.Wv, index)
|
| 109 |
+
self.dense = prune_linear_layer(self.dense, index, dim=1)
|
| 110 |
+
|
| 111 |
+
# Update hyper params
|
| 112 |
+
self.num_heads = self.num_heads - len(heads)
|
| 113 |
+
self.d_model_size = attention_head_size * self.num_heads
|
| 114 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 115 |
+
|
| 116 |
+
def split_into_heads(self, x, batch_size):
|
| 117 |
+
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
|
| 118 |
+
return x.permute([0, 2, 1, 3])
|
| 119 |
+
|
| 120 |
+
def forward(
|
| 121 |
+
self,
|
| 122 |
+
v,
|
| 123 |
+
k,
|
| 124 |
+
q,
|
| 125 |
+
mask,
|
| 126 |
+
layer_past=None,
|
| 127 |
+
attention_mask=None,
|
| 128 |
+
head_mask=None,
|
| 129 |
+
use_cache=False,
|
| 130 |
+
output_attentions=False,
|
| 131 |
+
):
|
| 132 |
+
batch_size = q.shape[0]
|
| 133 |
+
|
| 134 |
+
q = self.Wq(q)
|
| 135 |
+
k = self.Wk(k)
|
| 136 |
+
v = self.Wv(v)
|
| 137 |
+
|
| 138 |
+
q = self.split_into_heads(q, batch_size)
|
| 139 |
+
k = self.split_into_heads(k, batch_size)
|
| 140 |
+
v = self.split_into_heads(v, batch_size)
|
| 141 |
+
if layer_past is not None:
|
| 142 |
+
past_key, past_value = layer_past[0], layer_past[1]
|
| 143 |
+
k = torch.cat((past_key, k), dim=-2)
|
| 144 |
+
v = torch.cat((past_value, v), dim=-2)
|
| 145 |
+
|
| 146 |
+
if use_cache is True:
|
| 147 |
+
present = torch.stack((k, v))
|
| 148 |
+
else:
|
| 149 |
+
present = (None,)
|
| 150 |
+
|
| 151 |
+
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
| 152 |
+
scaled_attention = output[0].permute([0, 2, 1, 3])
|
| 153 |
+
attn = output[1]
|
| 154 |
+
original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
|
| 155 |
+
output = self.dense(original_size_attention)
|
| 156 |
+
|
| 157 |
+
outputs = (output, present)
|
| 158 |
+
if output_attentions:
|
| 159 |
+
outputs = outputs + (attn,)
|
| 160 |
+
return outputs
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def point_wise_feed_forward_network(d_model_size, dff):
|
| 164 |
+
return nn.Sequential(nn.Linear(d_model_size, dff), nn.ReLU(), nn.Linear(dff, d_model_size))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class EncoderLayer(nn.Module):
|
| 168 |
+
def __init__(self, d_model_size, num_heads, dff, rate=0.1):
|
| 169 |
+
super().__init__()
|
| 170 |
+
|
| 171 |
+
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
|
| 172 |
+
self.ffn = point_wise_feed_forward_network(d_model_size, dff)
|
| 173 |
+
|
| 174 |
+
self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6)
|
| 175 |
+
self.layernorm2 = nn.LayerNorm(d_model_size, eps=1e-6)
|
| 176 |
+
|
| 177 |
+
self.dropout1 = nn.Dropout(rate)
|
| 178 |
+
self.dropout2 = nn.Dropout(rate)
|
| 179 |
+
|
| 180 |
+
def forward(
|
| 181 |
+
self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
|
| 182 |
+
):
|
| 183 |
+
normed = self.layernorm1(x)
|
| 184 |
+
attn_outputs = self.multi_head_attention(
|
| 185 |
+
normed,
|
| 186 |
+
normed,
|
| 187 |
+
normed,
|
| 188 |
+
mask,
|
| 189 |
+
layer_past=layer_past,
|
| 190 |
+
attention_mask=attention_mask,
|
| 191 |
+
head_mask=head_mask,
|
| 192 |
+
use_cache=use_cache,
|
| 193 |
+
output_attentions=output_attentions,
|
| 194 |
+
)
|
| 195 |
+
attn_output = attn_outputs[0]
|
| 196 |
+
attn_output = self.dropout1(attn_output)
|
| 197 |
+
out1 = x + attn_output
|
| 198 |
+
|
| 199 |
+
out2 = self.layernorm2(out1)
|
| 200 |
+
ffn_output = self.ffn(out2)
|
| 201 |
+
ffn_output = self.dropout2(ffn_output)
|
| 202 |
+
out2 = out1 + ffn_output
|
| 203 |
+
|
| 204 |
+
outputs = (out2,) + attn_outputs[1:]
|
| 205 |
+
return outputs
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class CTRLPreTrainedModel(PreTrainedModel):
|
| 209 |
+
"""
|
| 210 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 211 |
+
models.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
config_class = CTRLConfig
|
| 215 |
+
base_model_prefix = "transformer"
|
| 216 |
+
|
| 217 |
+
def _init_weights(self, module):
|
| 218 |
+
"""Initialize the weights."""
|
| 219 |
+
if isinstance(module, (nn.Linear, Conv1D)):
|
| 220 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 221 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 222 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 223 |
+
if module.bias is not None:
|
| 224 |
+
module.bias.data.zero_()
|
| 225 |
+
elif isinstance(module, nn.Embedding):
|
| 226 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 227 |
+
if module.padding_idx is not None:
|
| 228 |
+
module.weight.data[module.padding_idx].zero_()
|
| 229 |
+
elif isinstance(module, nn.LayerNorm):
|
| 230 |
+
module.bias.data.zero_()
|
| 231 |
+
module.weight.data.fill_(1.0)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
CTRL_START_DOCSTRING = r"""
|
| 235 |
+
|
| 236 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 237 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 238 |
+
etc.)
|
| 239 |
+
|
| 240 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 241 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 242 |
+
and behavior.
|
| 243 |
+
|
| 244 |
+
Parameters:
|
| 245 |
+
config ([`CTRLConfig`]): Model configuration class with all the parameters of the model.
|
| 246 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 247 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
CTRL_INPUTS_DOCSTRING = r"""
|
| 251 |
+
Args:
|
| 252 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 253 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
|
| 254 |
+
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
|
| 255 |
+
|
| 256 |
+
If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
|
| 257 |
+
`input_ids`.
|
| 258 |
+
|
| 259 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
|
| 260 |
+
[`PreTrainedTokenizer.encode`] for details.
|
| 261 |
+
|
| 262 |
+
[What are input IDs?](../glossary#input-ids)
|
| 263 |
+
past_key_values (`Tuple[Tuple[torch.FloatTensor]]` of length `config.n_layers`):
|
| 264 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
|
| 265 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
| 266 |
+
their past given to this model should not be passed as input ids as they have already been computed.
|
| 267 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 268 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 269 |
+
|
| 270 |
+
- 1 for tokens that are **not masked**,
|
| 271 |
+
- 0 for tokens that are **masked**.
|
| 272 |
+
|
| 273 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 274 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 275 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 276 |
+
1]`:
|
| 277 |
+
|
| 278 |
+
- 0 corresponds to a *sentence A* token,
|
| 279 |
+
- 1 corresponds to a *sentence B* token.
|
| 280 |
+
|
| 281 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 282 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 283 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 284 |
+
config.max_position_embeddings - 1]`.
|
| 285 |
+
|
| 286 |
+
[What are position IDs?](../glossary#position-ids)
|
| 287 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 288 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 289 |
+
|
| 290 |
+
- 1 indicates the head is **not masked**,
|
| 291 |
+
- 0 indicates the head is **masked**.
|
| 292 |
+
|
| 293 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 294 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 295 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 296 |
+
model's internal embedding lookup matrix.
|
| 297 |
+
use_cache (`bool`, *optional*):
|
| 298 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 299 |
+
`past_key_values`).
|
| 300 |
+
output_attentions (`bool`, *optional*):
|
| 301 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 302 |
+
tensors for more detail.
|
| 303 |
+
output_hidden_states (`bool`, *optional*):
|
| 304 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 305 |
+
more detail.
|
| 306 |
+
return_dict (`bool`, *optional*):
|
| 307 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
@add_start_docstrings(
|
| 312 |
+
"The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
|
| 313 |
+
CTRL_START_DOCSTRING,
|
| 314 |
+
)
|
| 315 |
+
class CTRLModel(CTRLPreTrainedModel):
|
| 316 |
+
def __init__(self, config):
|
| 317 |
+
super().__init__(config)
|
| 318 |
+
|
| 319 |
+
self.d_model_size = config.n_embd
|
| 320 |
+
self.num_layers = config.n_layer
|
| 321 |
+
|
| 322 |
+
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
|
| 323 |
+
|
| 324 |
+
self.w = nn.Embedding(config.vocab_size, config.n_embd)
|
| 325 |
+
|
| 326 |
+
self.dropout = nn.Dropout(config.embd_pdrop)
|
| 327 |
+
self.h = nn.ModuleList(
|
| 328 |
+
[EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)]
|
| 329 |
+
)
|
| 330 |
+
self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
| 331 |
+
|
| 332 |
+
# Initialize weights and apply final processing
|
| 333 |
+
self.post_init()
|
| 334 |
+
|
| 335 |
+
def get_input_embeddings(self):
|
| 336 |
+
return self.w
|
| 337 |
+
|
| 338 |
+
def set_input_embeddings(self, new_embeddings):
|
| 339 |
+
self.w = new_embeddings
|
| 340 |
+
|
| 341 |
+
def _prune_heads(self, heads_to_prune):
|
| 342 |
+
"""
|
| 343 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 344 |
+
"""
|
| 345 |
+
for layer, heads in heads_to_prune.items():
|
| 346 |
+
self.h[layer].multi_head_attention.prune_heads(heads)
|
| 347 |
+
|
| 348 |
+
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
| 349 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 350 |
+
def forward(
|
| 351 |
+
self,
|
| 352 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 353 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 354 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 355 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 356 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 357 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 358 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 359 |
+
use_cache: Optional[bool] = None,
|
| 360 |
+
output_attentions: Optional[bool] = None,
|
| 361 |
+
output_hidden_states: Optional[bool] = None,
|
| 362 |
+
return_dict: Optional[bool] = None,
|
| 363 |
+
**kwargs, # NOOP kwargs, for now
|
| 364 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
| 365 |
+
r"""
|
| 366 |
+
Returns:
|
| 367 |
+
|
| 368 |
+
Example:
|
| 369 |
+
|
| 370 |
+
```python
|
| 371 |
+
>>> from transformers import AutoTokenizer, CTRLModel
|
| 372 |
+
>>> import torch
|
| 373 |
+
|
| 374 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
| 375 |
+
>>> model = CTRLModel.from_pretrained("Salesforce/ctrl")
|
| 376 |
+
|
| 377 |
+
>>> # CTRL was trained with control codes as the first token
|
| 378 |
+
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
| 379 |
+
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
| 380 |
+
|
| 381 |
+
>>> outputs = model(**inputs)
|
| 382 |
+
|
| 383 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 384 |
+
>>> list(last_hidden_states.shape)
|
| 385 |
+
[1, 5, 1280]
|
| 386 |
+
```"""
|
| 387 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 388 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 389 |
+
output_hidden_states = (
|
| 390 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 391 |
+
)
|
| 392 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 393 |
+
|
| 394 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 395 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 396 |
+
elif input_ids is not None:
|
| 397 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 398 |
+
input_shape = input_ids.size()
|
| 399 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 400 |
+
batch_size = input_ids.shape[0]
|
| 401 |
+
elif inputs_embeds is not None:
|
| 402 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 403 |
+
batch_size = inputs_embeds.shape[0]
|
| 404 |
+
else:
|
| 405 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 406 |
+
|
| 407 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 408 |
+
|
| 409 |
+
if past_key_values is None:
|
| 410 |
+
past_length = 0
|
| 411 |
+
past_key_values = tuple([None] * len(self.h))
|
| 412 |
+
else:
|
| 413 |
+
past_length = past_key_values[0][0].size(-2)
|
| 414 |
+
if position_ids is None:
|
| 415 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
| 416 |
+
position_ids = position_ids.unsqueeze(0)
|
| 417 |
+
|
| 418 |
+
# Attention mask.
|
| 419 |
+
if attention_mask is not None:
|
| 420 |
+
if batch_size <= 0:
|
| 421 |
+
raise ValueError("batch_size has to be defined and > 0")
|
| 422 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
| 423 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 424 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 425 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 426 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 427 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 428 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 429 |
+
|
| 430 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 431 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 432 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
| 433 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 434 |
+
# effectively the same as removing these entirely.
|
| 435 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 436 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
| 437 |
+
|
| 438 |
+
# Prepare head mask if needed
|
| 439 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 440 |
+
|
| 441 |
+
if token_type_ids is not None:
|
| 442 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
| 443 |
+
token_type_embeds = self.w(token_type_ids)
|
| 444 |
+
token_type_embeds *= np.sqrt(self.d_model_size)
|
| 445 |
+
else:
|
| 446 |
+
token_type_embeds = 0
|
| 447 |
+
|
| 448 |
+
if inputs_embeds is None:
|
| 449 |
+
inputs_embeds = self.w(input_ids)
|
| 450 |
+
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
| 451 |
+
seq_len = input_shape[-1]
|
| 452 |
+
mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(device)
|
| 453 |
+
|
| 454 |
+
inputs_embeds *= np.sqrt(self.d_model_size)
|
| 455 |
+
|
| 456 |
+
# `self.pos_encoding` won't be sent to the correct device along the model, so we do it manually.
|
| 457 |
+
self.pos_encoding = self.pos_encoding.to(device)
|
| 458 |
+
pos_embeds = self.pos_encoding[position_ids, :]
|
| 459 |
+
|
| 460 |
+
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
|
| 461 |
+
|
| 462 |
+
hidden_states = self.dropout(hidden_states)
|
| 463 |
+
|
| 464 |
+
presents = () if use_cache else None
|
| 465 |
+
all_hidden_states = () if output_hidden_states else None
|
| 466 |
+
all_attentions = () if output_attentions else None
|
| 467 |
+
for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
|
| 468 |
+
if output_hidden_states:
|
| 469 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 470 |
+
outputs = h(
|
| 471 |
+
hidden_states,
|
| 472 |
+
mask,
|
| 473 |
+
layer_past=layer_past,
|
| 474 |
+
attention_mask=attention_mask,
|
| 475 |
+
head_mask=head_mask[i],
|
| 476 |
+
use_cache=use_cache,
|
| 477 |
+
output_attentions=output_attentions,
|
| 478 |
+
)
|
| 479 |
+
hidden_states, present = outputs[:2]
|
| 480 |
+
if use_cache is True:
|
| 481 |
+
presents = presents + (present,)
|
| 482 |
+
|
| 483 |
+
if output_attentions:
|
| 484 |
+
all_attentions += (outputs[2],)
|
| 485 |
+
|
| 486 |
+
hidden_states = self.layernorm(hidden_states)
|
| 487 |
+
if output_hidden_states:
|
| 488 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 489 |
+
|
| 490 |
+
if not return_dict:
|
| 491 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
| 492 |
+
|
| 493 |
+
return BaseModelOutputWithPast(
|
| 494 |
+
last_hidden_state=hidden_states,
|
| 495 |
+
past_key_values=presents,
|
| 496 |
+
hidden_states=all_hidden_states,
|
| 497 |
+
attentions=all_attentions,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
@add_start_docstrings(
|
| 502 |
+
"""
|
| 503 |
+
The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 504 |
+
embeddings).
|
| 505 |
+
""",
|
| 506 |
+
CTRL_START_DOCSTRING,
|
| 507 |
+
)
|
| 508 |
+
class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
|
| 509 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 510 |
+
|
| 511 |
+
def __init__(self, config):
|
| 512 |
+
super().__init__(config)
|
| 513 |
+
self.transformer = CTRLModel(config)
|
| 514 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
|
| 515 |
+
|
| 516 |
+
# Initialize weights and apply final processing
|
| 517 |
+
self.post_init()
|
| 518 |
+
|
| 519 |
+
def get_output_embeddings(self):
|
| 520 |
+
return self.lm_head
|
| 521 |
+
|
| 522 |
+
def set_output_embeddings(self, new_embeddings):
|
| 523 |
+
self.lm_head = new_embeddings
|
| 524 |
+
|
| 525 |
+
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
| 526 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 527 |
+
def forward(
|
| 528 |
+
self,
|
| 529 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 530 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 531 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 532 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 533 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 534 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 535 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 536 |
+
labels: Optional[torch.LongTensor] = None,
|
| 537 |
+
use_cache: Optional[bool] = None,
|
| 538 |
+
output_attentions: Optional[bool] = None,
|
| 539 |
+
output_hidden_states: Optional[bool] = None,
|
| 540 |
+
return_dict: Optional[bool] = None,
|
| 541 |
+
**kwargs,
|
| 542 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
|
| 543 |
+
r"""
|
| 544 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 545 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 546 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 547 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
|
| 551 |
+
Example:
|
| 552 |
+
|
| 553 |
+
```python
|
| 554 |
+
>>> import torch
|
| 555 |
+
>>> from transformers import AutoTokenizer, CTRLLMHeadModel
|
| 556 |
+
|
| 557 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
| 558 |
+
>>> model = CTRLLMHeadModel.from_pretrained("Salesforce/ctrl")
|
| 559 |
+
|
| 560 |
+
>>> # CTRL was trained with control codes as the first token
|
| 561 |
+
>>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
|
| 562 |
+
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
| 563 |
+
|
| 564 |
+
>>> sequence_ids = model.generate(inputs["input_ids"])
|
| 565 |
+
>>> sequences = tokenizer.batch_decode(sequence_ids)
|
| 566 |
+
>>> sequences
|
| 567 |
+
['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']
|
| 568 |
+
|
| 569 |
+
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
| 570 |
+
>>> round(outputs.loss.item(), 2)
|
| 571 |
+
9.21
|
| 572 |
+
|
| 573 |
+
>>> list(outputs.logits.shape)
|
| 574 |
+
[1, 5, 246534]
|
| 575 |
+
```"""
|
| 576 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 577 |
+
|
| 578 |
+
transformer_outputs = self.transformer(
|
| 579 |
+
input_ids,
|
| 580 |
+
past_key_values=past_key_values,
|
| 581 |
+
attention_mask=attention_mask,
|
| 582 |
+
token_type_ids=token_type_ids,
|
| 583 |
+
position_ids=position_ids,
|
| 584 |
+
head_mask=head_mask,
|
| 585 |
+
inputs_embeds=inputs_embeds,
|
| 586 |
+
use_cache=use_cache,
|
| 587 |
+
output_attentions=output_attentions,
|
| 588 |
+
output_hidden_states=output_hidden_states,
|
| 589 |
+
return_dict=return_dict,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
hidden_states = transformer_outputs[0]
|
| 593 |
+
|
| 594 |
+
lm_logits = self.lm_head(hidden_states)
|
| 595 |
+
|
| 596 |
+
loss = None
|
| 597 |
+
if labels is not None:
|
| 598 |
+
loss = self.loss_function(
|
| 599 |
+
lm_logits,
|
| 600 |
+
labels,
|
| 601 |
+
vocab_size=self.config.vocab_size,
|
| 602 |
+
**kwargs,
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
if not return_dict:
|
| 606 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
| 607 |
+
return ((loss,) + output) if loss is not None else output
|
| 608 |
+
|
| 609 |
+
return CausalLMOutputWithPast(
|
| 610 |
+
loss=loss,
|
| 611 |
+
logits=lm_logits,
|
| 612 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 613 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 614 |
+
attentions=transformer_outputs.attentions,
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
|
| 618 |
+
# Overwritten -- inputs_embeds not working properly
|
| 619 |
+
|
| 620 |
+
# only last tokens for inputs_ids if past is defined in kwargs
|
| 621 |
+
if past_key_values is not None:
|
| 622 |
+
past_length = past_key_values[0][0].shape[2]
|
| 623 |
+
|
| 624 |
+
# Some generation methods already pass only the last input ID
|
| 625 |
+
if input_ids.shape[1] > past_length:
|
| 626 |
+
remove_prefix_length = past_length
|
| 627 |
+
else:
|
| 628 |
+
# Default to old behavior: keep only final ID
|
| 629 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
| 630 |
+
|
| 631 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
| 632 |
+
|
| 633 |
+
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
|
| 634 |
+
|
| 635 |
+
@staticmethod
|
| 636 |
+
def _reorder_cache(
|
| 637 |
+
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
| 638 |
+
) -> Tuple[Tuple[torch.Tensor]]:
|
| 639 |
+
"""
|
| 640 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
| 641 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
| 642 |
+
beam_idx at every generation step.
|
| 643 |
+
"""
|
| 644 |
+
return tuple(
|
| 645 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
| 646 |
+
for layer_past in past_key_values
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
@add_start_docstrings(
|
| 651 |
+
"""
|
| 652 |
+
The CTRL Model transformer with a sequence classification head on top (linear layer).
|
| 653 |
+
[`CTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 654 |
+
(e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last
|
| 655 |
+
token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in
|
| 656 |
+
each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
|
| 657 |
+
guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last
|
| 658 |
+
value in each row of the batch).
|
| 659 |
+
""",
|
| 660 |
+
CTRL_START_DOCSTRING,
|
| 661 |
+
)
|
| 662 |
+
class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
| 663 |
+
def __init__(self, config):
|
| 664 |
+
super().__init__(config)
|
| 665 |
+
self.num_labels = config.num_labels
|
| 666 |
+
self.transformer = CTRLModel(config)
|
| 667 |
+
self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
| 668 |
+
|
| 669 |
+
# Initialize weights and apply final processing
|
| 670 |
+
self.post_init()
|
| 671 |
+
|
| 672 |
+
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
| 673 |
+
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
| 674 |
+
def forward(
|
| 675 |
+
self,
|
| 676 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 677 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 678 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 679 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 680 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 681 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 682 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 683 |
+
labels: Optional[torch.LongTensor] = None,
|
| 684 |
+
use_cache: Optional[bool] = None,
|
| 685 |
+
output_attentions: Optional[bool] = None,
|
| 686 |
+
output_hidden_states: Optional[bool] = None,
|
| 687 |
+
return_dict: Optional[bool] = None,
|
| 688 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 689 |
+
r"""
|
| 690 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 691 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 692 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 693 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 694 |
+
|
| 695 |
+
Returns:
|
| 696 |
+
|
| 697 |
+
Example of single-label classification:
|
| 698 |
+
|
| 699 |
+
```python
|
| 700 |
+
>>> import torch
|
| 701 |
+
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
|
| 702 |
+
|
| 703 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
| 704 |
+
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl")
|
| 705 |
+
|
| 706 |
+
>>> # CTRL was trained with control codes as the first token
|
| 707 |
+
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
| 708 |
+
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
| 709 |
+
|
| 710 |
+
>>> with torch.no_grad():
|
| 711 |
+
... logits = model(**inputs).logits
|
| 712 |
+
|
| 713 |
+
>>> predicted_class_id = logits.argmax().item()
|
| 714 |
+
>>> model.config.id2label[predicted_class_id]
|
| 715 |
+
'LABEL_0'
|
| 716 |
+
```
|
| 717 |
+
|
| 718 |
+
```python
|
| 719 |
+
>>> import torch
|
| 720 |
+
|
| 721 |
+
>>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
|
| 722 |
+
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
| 723 |
+
>>> num_labels = len(model.config.id2label)
|
| 724 |
+
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
|
| 725 |
+
|
| 726 |
+
>>> labels = torch.tensor(1)
|
| 727 |
+
>>> loss = model(**inputs, labels=labels).loss
|
| 728 |
+
>>> round(loss.item(), 2)
|
| 729 |
+
0.93
|
| 730 |
+
```
|
| 731 |
+
|
| 732 |
+
Example of multi-label classification:
|
| 733 |
+
|
| 734 |
+
```python
|
| 735 |
+
>>> import torch
|
| 736 |
+
>>> from transformers import AutoTokenizer, CTRLForSequenceClassification
|
| 737 |
+
|
| 738 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
|
| 739 |
+
>>> model = CTRLForSequenceClassification.from_pretrained(
|
| 740 |
+
... "Salesforce/ctrl", problem_type="multi_label_classification"
|
| 741 |
+
... )
|
| 742 |
+
|
| 743 |
+
>>> # CTRL was trained with control codes as the first token
|
| 744 |
+
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
| 745 |
+
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
| 746 |
+
|
| 747 |
+
>>> with torch.no_grad():
|
| 748 |
+
... logits = model(**inputs).logits
|
| 749 |
+
|
| 750 |
+
>>> predicted_class_id = logits.argmax().item()
|
| 751 |
+
>>> model.config.id2label[predicted_class_id]
|
| 752 |
+
'LABEL_0'
|
| 753 |
+
```
|
| 754 |
+
|
| 755 |
+
```python
|
| 756 |
+
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
| 757 |
+
>>> num_labels = len(model.config.id2label)
|
| 758 |
+
>>> model = CTRLForSequenceClassification.from_pretrained("Salesforce/ctrl", num_labels=num_labels)
|
| 759 |
+
|
| 760 |
+
>>> num_labels = len(model.config.id2label)
|
| 761 |
+
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
|
| 762 |
+
... torch.float
|
| 763 |
+
... )
|
| 764 |
+
>>> loss = model(**inputs, labels=labels).loss
|
| 765 |
+
>>> loss.backward() # doctest: +IGNORE_RESULT
|
| 766 |
+
```"""
|
| 767 |
+
|
| 768 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 769 |
+
|
| 770 |
+
transformer_outputs = self.transformer(
|
| 771 |
+
input_ids,
|
| 772 |
+
past_key_values=past_key_values,
|
| 773 |
+
attention_mask=attention_mask,
|
| 774 |
+
token_type_ids=token_type_ids,
|
| 775 |
+
position_ids=position_ids,
|
| 776 |
+
head_mask=head_mask,
|
| 777 |
+
inputs_embeds=inputs_embeds,
|
| 778 |
+
use_cache=use_cache,
|
| 779 |
+
output_attentions=output_attentions,
|
| 780 |
+
output_hidden_states=output_hidden_states,
|
| 781 |
+
return_dict=return_dict,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
hidden_states = transformer_outputs[0]
|
| 785 |
+
logits = self.classifier(hidden_states)
|
| 786 |
+
|
| 787 |
+
if input_ids is not None:
|
| 788 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
| 789 |
+
else:
|
| 790 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
| 791 |
+
|
| 792 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
| 793 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 794 |
+
if self.config.pad_token_id is None:
|
| 795 |
+
last_non_pad_token = -1
|
| 796 |
+
elif input_ids is not None:
|
| 797 |
+
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
| 798 |
+
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
| 799 |
+
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
| 800 |
+
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
| 801 |
+
else:
|
| 802 |
+
last_non_pad_token = -1
|
| 803 |
+
logger.warning_once(
|
| 804 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 805 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
| 809 |
+
|
| 810 |
+
loss = None
|
| 811 |
+
if labels is not None:
|
| 812 |
+
if self.config.problem_type is None:
|
| 813 |
+
if self.num_labels == 1:
|
| 814 |
+
self.config.problem_type = "regression"
|
| 815 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 816 |
+
self.config.problem_type = "single_label_classification"
|
| 817 |
+
else:
|
| 818 |
+
self.config.problem_type = "multi_label_classification"
|
| 819 |
+
|
| 820 |
+
if self.config.problem_type == "regression":
|
| 821 |
+
loss_fct = MSELoss()
|
| 822 |
+
if self.num_labels == 1:
|
| 823 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 824 |
+
else:
|
| 825 |
+
loss = loss_fct(pooled_logits, labels)
|
| 826 |
+
elif self.config.problem_type == "single_label_classification":
|
| 827 |
+
loss_fct = CrossEntropyLoss()
|
| 828 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 829 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 830 |
+
loss_fct = BCEWithLogitsLoss()
|
| 831 |
+
loss = loss_fct(pooled_logits, labels)
|
| 832 |
+
if not return_dict:
|
| 833 |
+
output = (pooled_logits,) + transformer_outputs[2:]
|
| 834 |
+
return ((loss,) + output) if loss is not None else output
|
| 835 |
+
|
| 836 |
+
return SequenceClassifierOutput(
|
| 837 |
+
loss=loss,
|
| 838 |
+
logits=pooled_logits,
|
| 839 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 840 |
+
attentions=transformer_outputs.attentions,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
__all__ = ["CTRLForSequenceClassification", "CTRLLMHeadModel", "CTRLModel", "CTRLPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/ctrl/modeling_tf_ctrl.py
ADDED
|
@@ -0,0 +1,922 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 Salesforce and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""TF 2.0 CTRL model."""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tensorflow as tf
|
| 24 |
+
|
| 25 |
+
from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput
|
| 26 |
+
from ...modeling_tf_utils import (
|
| 27 |
+
TFCausalLanguageModelingLoss,
|
| 28 |
+
TFModelInputType,
|
| 29 |
+
TFPreTrainedModel,
|
| 30 |
+
TFSequenceClassificationLoss,
|
| 31 |
+
get_initializer,
|
| 32 |
+
keras,
|
| 33 |
+
keras_serializable,
|
| 34 |
+
unpack_inputs,
|
| 35 |
+
)
|
| 36 |
+
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
|
| 37 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 38 |
+
from .configuration_ctrl import CTRLConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
_CHECKPOINT_FOR_DOC = "Salesforce/ctrl"
|
| 44 |
+
_CONFIG_FOR_DOC = "CTRLConfig"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def angle_defn(pos, i, d_model_size):
|
| 48 |
+
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
|
| 49 |
+
return pos * angle_rates
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def positional_encoding(position, d_model_size):
|
| 53 |
+
# create the sinusoidal pattern for the positional encoding
|
| 54 |
+
angle_rads = angle_defn(np.arange(position)[:, np.newaxis], np.arange(d_model_size)[np.newaxis, :], d_model_size)
|
| 55 |
+
|
| 56 |
+
sines = np.sin(angle_rads[:, 0::2])
|
| 57 |
+
cosines = np.cos(angle_rads[:, 1::2])
|
| 58 |
+
pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
|
| 59 |
+
|
| 60 |
+
return pos_encoding
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
|
| 64 |
+
# calculate attention
|
| 65 |
+
matmul_qk = tf.matmul(q, k, transpose_b=True)
|
| 66 |
+
|
| 67 |
+
dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
|
| 68 |
+
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
|
| 69 |
+
|
| 70 |
+
if mask is not None:
|
| 71 |
+
scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
|
| 72 |
+
|
| 73 |
+
if attention_mask is not None:
|
| 74 |
+
# Apply the attention mask
|
| 75 |
+
attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
|
| 76 |
+
scaled_attention_logits = scaled_attention_logits + attention_mask
|
| 77 |
+
|
| 78 |
+
attention_weights = stable_softmax(scaled_attention_logits, axis=-1)
|
| 79 |
+
|
| 80 |
+
# Mask heads if we want to
|
| 81 |
+
if head_mask is not None:
|
| 82 |
+
attention_weights = attention_weights * head_mask
|
| 83 |
+
|
| 84 |
+
output = tf.matmul(attention_weights, v)
|
| 85 |
+
|
| 86 |
+
return output, attention_weights
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TFMultiHeadAttention(keras.layers.Layer):
|
| 90 |
+
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
|
| 91 |
+
super().__init__(**kwargs)
|
| 92 |
+
self.num_heads = num_heads
|
| 93 |
+
self.d_model_size = d_model_size
|
| 94 |
+
self.output_attentions = output_attentions
|
| 95 |
+
|
| 96 |
+
self.depth = int(d_model_size / self.num_heads)
|
| 97 |
+
|
| 98 |
+
self.Wq = keras.layers.Dense(d_model_size, name="Wq")
|
| 99 |
+
self.Wk = keras.layers.Dense(d_model_size, name="Wk")
|
| 100 |
+
self.Wv = keras.layers.Dense(d_model_size, name="Wv")
|
| 101 |
+
|
| 102 |
+
self.dense = keras.layers.Dense(d_model_size, name="dense")
|
| 103 |
+
|
| 104 |
+
def split_into_heads(self, x, batch_size):
|
| 105 |
+
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
|
| 106 |
+
return tf.transpose(x, perm=[0, 2, 1, 3])
|
| 107 |
+
|
| 108 |
+
def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
| 109 |
+
batch_size = shape_list(q)[0]
|
| 110 |
+
|
| 111 |
+
q = self.Wq(q)
|
| 112 |
+
k = self.Wk(k)
|
| 113 |
+
v = self.Wv(v)
|
| 114 |
+
|
| 115 |
+
q = self.split_into_heads(q, batch_size)
|
| 116 |
+
k = self.split_into_heads(k, batch_size)
|
| 117 |
+
v = self.split_into_heads(v, batch_size)
|
| 118 |
+
|
| 119 |
+
if layer_past is not None:
|
| 120 |
+
past_key, past_value = tf.unstack(layer_past, axis=0)
|
| 121 |
+
k = tf.concat((past_key, k), axis=-2)
|
| 122 |
+
v = tf.concat((past_value, v), axis=-2)
|
| 123 |
+
|
| 124 |
+
if use_cache:
|
| 125 |
+
present = tf.stack((k, v), axis=0)
|
| 126 |
+
else:
|
| 127 |
+
present = (None,)
|
| 128 |
+
|
| 129 |
+
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
| 130 |
+
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
|
| 131 |
+
attn = output[1]
|
| 132 |
+
original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
|
| 133 |
+
output = self.dense(original_size_attention)
|
| 134 |
+
outputs = (output, present)
|
| 135 |
+
|
| 136 |
+
if output_attentions:
|
| 137 |
+
outputs = outputs + (attn,)
|
| 138 |
+
|
| 139 |
+
return outputs
|
| 140 |
+
|
| 141 |
+
def build(self, input_shape=None):
|
| 142 |
+
if self.built:
|
| 143 |
+
return
|
| 144 |
+
self.built = True
|
| 145 |
+
if getattr(self, "Wq", None) is not None:
|
| 146 |
+
with tf.name_scope(self.Wq.name):
|
| 147 |
+
self.Wq.build([None, None, self.d_model_size])
|
| 148 |
+
if getattr(self, "Wk", None) is not None:
|
| 149 |
+
with tf.name_scope(self.Wk.name):
|
| 150 |
+
self.Wk.build([None, None, self.d_model_size])
|
| 151 |
+
if getattr(self, "Wv", None) is not None:
|
| 152 |
+
with tf.name_scope(self.Wv.name):
|
| 153 |
+
self.Wv.build([None, None, self.d_model_size])
|
| 154 |
+
if getattr(self, "dense", None) is not None:
|
| 155 |
+
with tf.name_scope(self.dense.name):
|
| 156 |
+
self.dense.build([None, None, self.d_model_size])
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class TFPointWiseFeedForwardLayer(keras.layers.Layer):
|
| 160 |
+
def __init__(self, d_model_size, dff, **kwargs):
|
| 161 |
+
super().__init__(**kwargs)
|
| 162 |
+
|
| 163 |
+
self.dense_0 = keras.layers.Dense(dff, activation="relu", name="0")
|
| 164 |
+
self.dense_2 = keras.layers.Dense(d_model_size, name="2")
|
| 165 |
+
self.d_model_size = d_model_size
|
| 166 |
+
self.dff = dff
|
| 167 |
+
|
| 168 |
+
def call(self, inputs, trainable=False):
|
| 169 |
+
dense_0_output = self.dense_0(inputs)
|
| 170 |
+
dense_2_output = self.dense_2(dense_0_output)
|
| 171 |
+
|
| 172 |
+
return dense_2_output
|
| 173 |
+
|
| 174 |
+
def build(self, input_shape=None):
|
| 175 |
+
if self.built:
|
| 176 |
+
return
|
| 177 |
+
self.built = True
|
| 178 |
+
if getattr(self, "dense_0", None) is not None:
|
| 179 |
+
with tf.name_scope(self.dense_0.name):
|
| 180 |
+
self.dense_0.build([None, None, self.d_model_size])
|
| 181 |
+
if getattr(self, "dense_2", None) is not None:
|
| 182 |
+
with tf.name_scope(self.dense_2.name):
|
| 183 |
+
self.dense_2.build([None, None, self.dff])
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class TFEncoderLayer(keras.layers.Layer):
|
| 187 |
+
def __init__(
|
| 188 |
+
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
|
| 189 |
+
):
|
| 190 |
+
super().__init__(**kwargs)
|
| 191 |
+
|
| 192 |
+
self.output_attentions = output_attentions
|
| 193 |
+
|
| 194 |
+
self.multi_head_attention = TFMultiHeadAttention(
|
| 195 |
+
d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
|
| 196 |
+
)
|
| 197 |
+
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
|
| 198 |
+
|
| 199 |
+
self.layernorm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
|
| 200 |
+
self.layernorm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")
|
| 201 |
+
|
| 202 |
+
self.dropout1 = keras.layers.Dropout(rate)
|
| 203 |
+
self.dropout2 = keras.layers.Dropout(rate)
|
| 204 |
+
self.d_model_size = d_model_size
|
| 205 |
+
|
| 206 |
+
def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
| 207 |
+
normed = self.layernorm1(x)
|
| 208 |
+
attn_outputs = self.multi_head_attention(
|
| 209 |
+
normed,
|
| 210 |
+
normed,
|
| 211 |
+
normed,
|
| 212 |
+
mask,
|
| 213 |
+
layer_past,
|
| 214 |
+
attention_mask,
|
| 215 |
+
head_mask,
|
| 216 |
+
use_cache,
|
| 217 |
+
output_attentions,
|
| 218 |
+
training=training,
|
| 219 |
+
)
|
| 220 |
+
attn_output = attn_outputs[0]
|
| 221 |
+
attn_output = self.dropout1(attn_output, training=training)
|
| 222 |
+
out1 = x + attn_output
|
| 223 |
+
|
| 224 |
+
out2 = self.layernorm2(out1)
|
| 225 |
+
ffn_output = self.ffn(out2)
|
| 226 |
+
ffn_output = self.dropout2(ffn_output, training=training)
|
| 227 |
+
out2 = out1 + ffn_output
|
| 228 |
+
|
| 229 |
+
outputs = (out2,) + attn_outputs[1:]
|
| 230 |
+
return outputs
|
| 231 |
+
|
| 232 |
+
def build(self, input_shape=None):
|
| 233 |
+
if self.built:
|
| 234 |
+
return
|
| 235 |
+
self.built = True
|
| 236 |
+
if getattr(self, "multi_head_attention", None) is not None:
|
| 237 |
+
with tf.name_scope(self.multi_head_attention.name):
|
| 238 |
+
self.multi_head_attention.build(None)
|
| 239 |
+
if getattr(self, "ffn", None) is not None:
|
| 240 |
+
with tf.name_scope(self.ffn.name):
|
| 241 |
+
self.ffn.build(None)
|
| 242 |
+
if getattr(self, "layernorm1", None) is not None:
|
| 243 |
+
with tf.name_scope(self.layernorm1.name):
|
| 244 |
+
self.layernorm1.build([None, None, self.d_model_size])
|
| 245 |
+
if getattr(self, "layernorm2", None) is not None:
|
| 246 |
+
with tf.name_scope(self.layernorm2.name):
|
| 247 |
+
self.layernorm2.build([None, None, self.d_model_size])
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@keras_serializable
|
| 251 |
+
class TFCTRLMainLayer(keras.layers.Layer):
|
| 252 |
+
config_class = CTRLConfig
|
| 253 |
+
|
| 254 |
+
def __init__(self, config, **kwargs):
|
| 255 |
+
super().__init__(**kwargs)
|
| 256 |
+
|
| 257 |
+
self.config = config
|
| 258 |
+
self.output_hidden_states = config.output_hidden_states
|
| 259 |
+
self.output_attentions = config.output_attentions
|
| 260 |
+
self.use_cache = config.use_cache
|
| 261 |
+
self.return_dict = config.use_return_dict
|
| 262 |
+
|
| 263 |
+
self.d_model_size = config.n_embd
|
| 264 |
+
self.num_layers = config.n_layer
|
| 265 |
+
|
| 266 |
+
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
|
| 267 |
+
|
| 268 |
+
self.w = keras.layers.Embedding(
|
| 269 |
+
input_dim=config.vocab_size,
|
| 270 |
+
output_dim=config.n_embd,
|
| 271 |
+
embeddings_initializer=get_initializer(config.initializer_range),
|
| 272 |
+
name="w",
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
self.dropout = keras.layers.Dropout(config.embd_pdrop)
|
| 276 |
+
self.h = [
|
| 277 |
+
TFEncoderLayer(
|
| 278 |
+
config.n_embd,
|
| 279 |
+
config.n_head,
|
| 280 |
+
config.dff,
|
| 281 |
+
config.resid_pdrop,
|
| 282 |
+
config.layer_norm_epsilon,
|
| 283 |
+
self.output_attentions,
|
| 284 |
+
name=f"h_._{i}",
|
| 285 |
+
)
|
| 286 |
+
for i in range(config.n_layer)
|
| 287 |
+
]
|
| 288 |
+
self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
|
| 289 |
+
|
| 290 |
+
def get_input_embeddings(self):
|
| 291 |
+
return self.w
|
| 292 |
+
|
| 293 |
+
def set_input_embeddings(self, new_embeddings):
|
| 294 |
+
self.w = new_embeddings
|
| 295 |
+
|
| 296 |
+
def _prune_heads(self, heads_to_prune):
|
| 297 |
+
"""
|
| 298 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 299 |
+
"""
|
| 300 |
+
raise NotImplementedError
|
| 301 |
+
|
| 302 |
+
@unpack_inputs
|
| 303 |
+
def call(
|
| 304 |
+
self,
|
| 305 |
+
input_ids: TFModelInputType | None = None,
|
| 306 |
+
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
| 307 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 308 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 309 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 310 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 311 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 312 |
+
use_cache: Optional[bool] = None,
|
| 313 |
+
output_attentions: Optional[bool] = None,
|
| 314 |
+
output_hidden_states: Optional[bool] = None,
|
| 315 |
+
return_dict: Optional[bool] = None,
|
| 316 |
+
training: Optional[bool] = False,
|
| 317 |
+
) -> Union[Tuple, TFBaseModelOutputWithPast]:
|
| 318 |
+
# If using past key value states, only the last tokens
|
| 319 |
+
# should be given as an input
|
| 320 |
+
if past_key_values is not None:
|
| 321 |
+
if input_ids is not None:
|
| 322 |
+
input_ids = input_ids[:, -1:]
|
| 323 |
+
if inputs_embeds is not None:
|
| 324 |
+
inputs_embeds = inputs_embeds[:, -1:]
|
| 325 |
+
if token_type_ids is not None:
|
| 326 |
+
token_type_ids = token_type_ids[:, -1:]
|
| 327 |
+
|
| 328 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 329 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 330 |
+
elif input_ids is not None:
|
| 331 |
+
input_shape = shape_list(input_ids)
|
| 332 |
+
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
| 333 |
+
elif inputs_embeds is not None:
|
| 334 |
+
input_shape = shape_list(inputs_embeds)[:-1]
|
| 335 |
+
else:
|
| 336 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 337 |
+
|
| 338 |
+
if past_key_values is None:
|
| 339 |
+
past_length = 0
|
| 340 |
+
past_key_values = [None] * len(self.h)
|
| 341 |
+
else:
|
| 342 |
+
past_length = shape_list(past_key_values[0][0])[-2]
|
| 343 |
+
if position_ids is None:
|
| 344 |
+
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0)
|
| 345 |
+
position_ids = tf.tile(position_ids, [input_shape[0], 1])
|
| 346 |
+
|
| 347 |
+
# Attention mask.
|
| 348 |
+
if attention_mask is not None:
|
| 349 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 350 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 351 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 352 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 353 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 354 |
+
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))
|
| 355 |
+
|
| 356 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 357 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 358 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 359 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 360 |
+
# effectively the same as removing these entirely.
|
| 361 |
+
|
| 362 |
+
one_cst = tf.constant(1.0)
|
| 363 |
+
ten_thousand_cst = tf.constant(-10000.0)
|
| 364 |
+
attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
|
| 365 |
+
attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), ten_thousand_cst)
|
| 366 |
+
|
| 367 |
+
# Prepare head mask if needed
|
| 368 |
+
# 1.0 in head_mask indicate we keep the head
|
| 369 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 370 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
| 371 |
+
if head_mask is not None:
|
| 372 |
+
raise NotImplementedError
|
| 373 |
+
else:
|
| 374 |
+
head_mask = [None] * self.num_layers
|
| 375 |
+
|
| 376 |
+
if token_type_ids is not None:
|
| 377 |
+
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
| 378 |
+
token_type_embeds = self.w(token_type_ids)
|
| 379 |
+
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
|
| 380 |
+
else:
|
| 381 |
+
token_type_embeds = tf.constant(0.0)
|
| 382 |
+
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
| 383 |
+
|
| 384 |
+
if inputs_embeds is None:
|
| 385 |
+
check_embeddings_within_bounds(input_ids, self.w.input_dim)
|
| 386 |
+
inputs_embeds = self.w(input_ids)
|
| 387 |
+
seq_len = input_shape[-1]
|
| 388 |
+
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
| 389 |
+
|
| 390 |
+
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, inputs_embeds.dtype))
|
| 391 |
+
|
| 392 |
+
pos_embeds = tf.gather(self.pos_encoding, position_ids)
|
| 393 |
+
pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
|
| 394 |
+
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
|
| 395 |
+
|
| 396 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 397 |
+
|
| 398 |
+
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
| 399 |
+
presents = () if use_cache else None
|
| 400 |
+
all_hidden_states = () if output_hidden_states else None
|
| 401 |
+
all_attentions = () if output_attentions else None
|
| 402 |
+
for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
|
| 403 |
+
if output_hidden_states:
|
| 404 |
+
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
| 405 |
+
outputs = h(
|
| 406 |
+
hidden_states,
|
| 407 |
+
mask,
|
| 408 |
+
layer_past,
|
| 409 |
+
attention_mask,
|
| 410 |
+
head_mask[i],
|
| 411 |
+
use_cache,
|
| 412 |
+
output_attentions,
|
| 413 |
+
training=training,
|
| 414 |
+
)
|
| 415 |
+
hidden_states, present = outputs[:2]
|
| 416 |
+
|
| 417 |
+
if use_cache:
|
| 418 |
+
presents = presents + (present,)
|
| 419 |
+
|
| 420 |
+
if output_attentions:
|
| 421 |
+
all_attentions = all_attentions + (outputs[2],)
|
| 422 |
+
|
| 423 |
+
hidden_states = self.layernorm(hidden_states)
|
| 424 |
+
hidden_states = tf.reshape(hidden_states, output_shape)
|
| 425 |
+
if output_hidden_states:
|
| 426 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 427 |
+
|
| 428 |
+
if output_attentions:
|
| 429 |
+
# let the number of heads free (-1) so we can extract attention even after head pruning
|
| 430 |
+
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
| 431 |
+
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
| 432 |
+
|
| 433 |
+
if not return_dict:
|
| 434 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
| 435 |
+
|
| 436 |
+
return TFBaseModelOutputWithPast(
|
| 437 |
+
last_hidden_state=hidden_states,
|
| 438 |
+
past_key_values=presents,
|
| 439 |
+
hidden_states=all_hidden_states,
|
| 440 |
+
attentions=all_attentions,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
def build(self, input_shape=None):
|
| 444 |
+
if self.built:
|
| 445 |
+
return
|
| 446 |
+
self.built = True
|
| 447 |
+
if getattr(self, "w", None) is not None:
|
| 448 |
+
with tf.name_scope(self.w.name):
|
| 449 |
+
self.w.build(None)
|
| 450 |
+
if getattr(self, "layernorm", None) is not None:
|
| 451 |
+
with tf.name_scope(self.layernorm.name):
|
| 452 |
+
self.layernorm.build([None, None, self.config.n_embd])
|
| 453 |
+
if getattr(self, "h", None) is not None:
|
| 454 |
+
for layer in self.h:
|
| 455 |
+
with tf.name_scope(layer.name):
|
| 456 |
+
layer.build(None)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class TFCTRLPreTrainedModel(TFPreTrainedModel):
|
| 460 |
+
"""
|
| 461 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 462 |
+
models.
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
config_class = CTRLConfig
|
| 466 |
+
base_model_prefix = "transformer"
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
CTRL_START_DOCSTRING = r"""
|
| 470 |
+
|
| 471 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 472 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 473 |
+
etc.)
|
| 474 |
+
|
| 475 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 476 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 477 |
+
behavior.
|
| 478 |
+
|
| 479 |
+
<Tip>
|
| 480 |
+
|
| 481 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 482 |
+
|
| 483 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 484 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 485 |
+
|
| 486 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 487 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 488 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 489 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 490 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 491 |
+
positional argument:
|
| 492 |
+
|
| 493 |
+
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
|
| 494 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 495 |
+
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
|
| 496 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 497 |
+
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
|
| 498 |
+
|
| 499 |
+
Note that when creating models and layers with
|
| 500 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 501 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 502 |
+
|
| 503 |
+
</Tip>
|
| 504 |
+
|
| 505 |
+
Parameters:
|
| 506 |
+
config ([`CTRLConfig`]): Model configuration class with all the parameters of the model.
|
| 507 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 508 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
CTRL_INPUTS_DOCSTRING = r"""
|
| 512 |
+
Args:
|
| 513 |
+
input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
|
| 514 |
+
`input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of
|
| 515 |
+
input past key value states).
|
| 516 |
+
|
| 517 |
+
Indices of input sequence tokens in the vocabulary.
|
| 518 |
+
|
| 519 |
+
If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`.
|
| 520 |
+
|
| 521 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
|
| 522 |
+
[`PreTrainedTokenizer.encode`] for details.
|
| 523 |
+
|
| 524 |
+
[What are input IDs?](../glossary#input-ids)
|
| 525 |
+
past (`List[tf.Tensor]` of length `config.n_layers`):
|
| 526 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
|
| 527 |
+
`past` output below). Can be used to speed up sequential decoding. The token ids which have their past
|
| 528 |
+
given to this model should not be passed as input ids as they have already been computed.
|
| 529 |
+
attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 530 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 531 |
+
|
| 532 |
+
- 1 for tokens that are **not masked**,
|
| 533 |
+
- 0 for tokens that are **masked**.
|
| 534 |
+
|
| 535 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 536 |
+
token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 537 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 538 |
+
1]`:
|
| 539 |
+
|
| 540 |
+
- 0 corresponds to a *sentence A* token,
|
| 541 |
+
- 1 corresponds to a *sentence B* token.
|
| 542 |
+
|
| 543 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 544 |
+
position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 545 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 546 |
+
config.max_position_embeddings - 1]`.
|
| 547 |
+
|
| 548 |
+
[What are position IDs?](../glossary#position-ids)
|
| 549 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 550 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 551 |
+
|
| 552 |
+
- 1 indicates the head is **not masked**,
|
| 553 |
+
- 0 indicates the head is **masked**.
|
| 554 |
+
|
| 555 |
+
inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 556 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 557 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 558 |
+
model's internal embedding lookup matrix.
|
| 559 |
+
use_cache (`bool`, *optional*):
|
| 560 |
+
If set to `True`, `past` key value states are returned and can be used to speed up decoding (see `past`).
|
| 561 |
+
output_attentions (`bool`, *optional*):
|
| 562 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 563 |
+
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
| 564 |
+
config will be used instead.
|
| 565 |
+
output_hidden_states (`bool`, *optional*):
|
| 566 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 567 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 568 |
+
used instead.
|
| 569 |
+
return_dict (`bool`, *optional*):
|
| 570 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
|
| 571 |
+
eager mode, in graph mode the value will always be set to True.
|
| 572 |
+
training (`bool`, *optional*, defaults to `False`):
|
| 573 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 574 |
+
behaviors between training and evaluation).
|
| 575 |
+
"""
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
@add_start_docstrings(
|
| 579 |
+
"The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
|
| 580 |
+
CTRL_START_DOCSTRING,
|
| 581 |
+
)
|
| 582 |
+
class TFCTRLModel(TFCTRLPreTrainedModel):
|
| 583 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 584 |
+
super().__init__(config, *inputs, **kwargs)
|
| 585 |
+
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
| 586 |
+
|
| 587 |
+
@unpack_inputs
|
| 588 |
+
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
| 589 |
+
@add_code_sample_docstrings(
|
| 590 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 591 |
+
output_type=TFBaseModelOutputWithPast,
|
| 592 |
+
config_class=_CONFIG_FOR_DOC,
|
| 593 |
+
)
|
| 594 |
+
def call(
|
| 595 |
+
self,
|
| 596 |
+
input_ids: TFModelInputType | None = None,
|
| 597 |
+
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
| 598 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 599 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 600 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 601 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 602 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 603 |
+
use_cache: Optional[bool] = None,
|
| 604 |
+
output_attentions: Optional[bool] = None,
|
| 605 |
+
output_hidden_states: Optional[bool] = None,
|
| 606 |
+
return_dict: Optional[bool] = None,
|
| 607 |
+
training: Optional[bool] = False,
|
| 608 |
+
) -> Union[Tuple, TFBaseModelOutputWithPast]:
|
| 609 |
+
outputs = self.transformer(
|
| 610 |
+
input_ids=input_ids,
|
| 611 |
+
past_key_values=past_key_values,
|
| 612 |
+
attention_mask=attention_mask,
|
| 613 |
+
token_type_ids=token_type_ids,
|
| 614 |
+
position_ids=position_ids,
|
| 615 |
+
head_mask=head_mask,
|
| 616 |
+
inputs_embeds=inputs_embeds,
|
| 617 |
+
use_cache=use_cache,
|
| 618 |
+
output_attentions=output_attentions,
|
| 619 |
+
output_hidden_states=output_hidden_states,
|
| 620 |
+
return_dict=return_dict,
|
| 621 |
+
training=training,
|
| 622 |
+
)
|
| 623 |
+
return outputs
|
| 624 |
+
|
| 625 |
+
def build(self, input_shape=None):
|
| 626 |
+
if self.built:
|
| 627 |
+
return
|
| 628 |
+
self.built = True
|
| 629 |
+
if getattr(self, "transformer", None) is not None:
|
| 630 |
+
with tf.name_scope(self.transformer.name):
|
| 631 |
+
self.transformer.build(None)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class TFCTRLBiasLayer(keras.layers.Layer):
|
| 635 |
+
"""
|
| 636 |
+
Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
|
| 637 |
+
so all weights have to be registered in a layer.
|
| 638 |
+
"""
|
| 639 |
+
|
| 640 |
+
def __init__(self, shape, initializer, trainable, name, **kwargs):
|
| 641 |
+
super().__init__(name=name, **kwargs)
|
| 642 |
+
self.shape = shape
|
| 643 |
+
self.initializer = initializer
|
| 644 |
+
self.trainable = trainable
|
| 645 |
+
|
| 646 |
+
def build(self, input_shape):
|
| 647 |
+
self.bias = self.add_weight(
|
| 648 |
+
name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
|
| 649 |
+
)
|
| 650 |
+
super().build(input_shape)
|
| 651 |
+
|
| 652 |
+
def call(self, x):
|
| 653 |
+
return x + self.bias
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
@add_start_docstrings(
|
| 657 |
+
"""
|
| 658 |
+
The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 659 |
+
embeddings).
|
| 660 |
+
""",
|
| 661 |
+
CTRL_START_DOCSTRING,
|
| 662 |
+
)
|
| 663 |
+
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
| 664 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 665 |
+
super().__init__(config, *inputs, **kwargs)
|
| 666 |
+
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
| 667 |
+
self.bias_layer = TFCTRLBiasLayer(
|
| 668 |
+
name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
def get_output_embeddings(self):
|
| 672 |
+
return self.get_input_embeddings()
|
| 673 |
+
|
| 674 |
+
def set_output_embeddings(self, value):
|
| 675 |
+
self.set_input_embeddings(value)
|
| 676 |
+
|
| 677 |
+
def get_bias(self):
|
| 678 |
+
return {"lm_head.bias": self.bias_layer.bias}
|
| 679 |
+
|
| 680 |
+
def set_bias(self, value):
|
| 681 |
+
# Replaces the existing layers containing bias for correct (de)serialization.
|
| 682 |
+
vocab_size = value["lm_head.bias"].shape[-1]
|
| 683 |
+
self.bias_layer = TFCTRLBiasLayer(
|
| 684 |
+
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
|
| 685 |
+
)
|
| 686 |
+
self.bias_layer.build(None)
|
| 687 |
+
self.bias_layer.bias.assign(value["lm_head.bias"])
|
| 688 |
+
|
| 689 |
+
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
|
| 690 |
+
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
|
| 691 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
| 692 |
+
# only last token for inputs_ids if past is defined in kwargs
|
| 693 |
+
if past_key_values:
|
| 694 |
+
inputs = tf.expand_dims(inputs[:, -1], -1)
|
| 695 |
+
if token_type_ids is not None:
|
| 696 |
+
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
|
| 697 |
+
|
| 698 |
+
position_ids = kwargs.get("position_ids", None)
|
| 699 |
+
attention_mask = kwargs.get("attention_mask", None)
|
| 700 |
+
|
| 701 |
+
if attention_mask is not None and position_ids is None:
|
| 702 |
+
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
|
| 703 |
+
if past_key_values:
|
| 704 |
+
position_ids = tf.expand_dims(position_ids[:, -1], -1)
|
| 705 |
+
|
| 706 |
+
return {
|
| 707 |
+
"input_ids": inputs,
|
| 708 |
+
"attention_mask": attention_mask,
|
| 709 |
+
"position_ids": position_ids,
|
| 710 |
+
"past_key_values": past_key_values,
|
| 711 |
+
"use_cache": use_cache,
|
| 712 |
+
"token_type_ids": token_type_ids,
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
@unpack_inputs
|
| 716 |
+
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
| 717 |
+
@add_code_sample_docstrings(
|
| 718 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 719 |
+
output_type=TFCausalLMOutputWithPast,
|
| 720 |
+
config_class=_CONFIG_FOR_DOC,
|
| 721 |
+
)
|
| 722 |
+
def call(
|
| 723 |
+
self,
|
| 724 |
+
input_ids: TFModelInputType | None = None,
|
| 725 |
+
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
| 726 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 727 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 728 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 729 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 730 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 731 |
+
use_cache: Optional[bool] = None,
|
| 732 |
+
output_attentions: Optional[bool] = None,
|
| 733 |
+
output_hidden_states: Optional[bool] = None,
|
| 734 |
+
return_dict: Optional[bool] = None,
|
| 735 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 736 |
+
training: Optional[bool] = False,
|
| 737 |
+
) -> Union[Tuple, TFCausalLMOutputWithPast]:
|
| 738 |
+
r"""
|
| 739 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 740 |
+
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
|
| 741 |
+
config.vocab_size - 1]`.
|
| 742 |
+
"""
|
| 743 |
+
transformer_outputs = self.transformer(
|
| 744 |
+
input_ids=input_ids,
|
| 745 |
+
past_key_values=past_key_values,
|
| 746 |
+
attention_mask=attention_mask,
|
| 747 |
+
token_type_ids=token_type_ids,
|
| 748 |
+
position_ids=position_ids,
|
| 749 |
+
head_mask=head_mask,
|
| 750 |
+
inputs_embeds=inputs_embeds,
|
| 751 |
+
use_cache=use_cache,
|
| 752 |
+
output_attentions=output_attentions,
|
| 753 |
+
output_hidden_states=output_hidden_states,
|
| 754 |
+
return_dict=return_dict,
|
| 755 |
+
training=training,
|
| 756 |
+
)
|
| 757 |
+
hidden_states = transformer_outputs[0]
|
| 758 |
+
logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
|
| 759 |
+
logits = self.bias_layer(logits)
|
| 760 |
+
|
| 761 |
+
loss = None
|
| 762 |
+
if labels is not None:
|
| 763 |
+
# shift labels to the left and cut last logit token
|
| 764 |
+
shifted_logits = logits[:, :-1]
|
| 765 |
+
labels = labels[:, 1:]
|
| 766 |
+
loss = self.hf_compute_loss(labels, shifted_logits)
|
| 767 |
+
|
| 768 |
+
if not return_dict:
|
| 769 |
+
output = (logits,) + transformer_outputs[1:]
|
| 770 |
+
return ((loss,) + output) if loss is not None else output
|
| 771 |
+
|
| 772 |
+
return TFCausalLMOutputWithPast(
|
| 773 |
+
loss=loss,
|
| 774 |
+
logits=logits,
|
| 775 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 776 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 777 |
+
attentions=transformer_outputs.attentions,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
def build(self, input_shape=None):
|
| 781 |
+
if self.built:
|
| 782 |
+
return
|
| 783 |
+
self.built = True
|
| 784 |
+
if getattr(self, "transformer", None) is not None:
|
| 785 |
+
with tf.name_scope(self.transformer.name):
|
| 786 |
+
self.transformer.build(None)
|
| 787 |
+
if getattr(self, "bias_layer", None) is not None:
|
| 788 |
+
with tf.name_scope(self.bias_layer.name):
|
| 789 |
+
self.bias_layer.build(None)
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
@add_start_docstrings(
|
| 793 |
+
"""
|
| 794 |
+
The CTRL Model transformer with a sequence classification head on top (linear layer).
|
| 795 |
+
|
| 796 |
+
[`TFCTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 797 |
+
(e.g. GPT-1, GPT-2) do.
|
| 798 |
+
|
| 799 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
| 800 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 801 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 802 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 803 |
+
each row of the batch).
|
| 804 |
+
""",
|
| 805 |
+
CTRL_START_DOCSTRING,
|
| 806 |
+
)
|
| 807 |
+
class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
|
| 808 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 809 |
+
super().__init__(config, *inputs, **kwargs)
|
| 810 |
+
self.num_labels = config.num_labels
|
| 811 |
+
self.classifier = keras.layers.Dense(
|
| 812 |
+
config.num_labels,
|
| 813 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 814 |
+
name="classifier",
|
| 815 |
+
use_bias=False,
|
| 816 |
+
)
|
| 817 |
+
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
| 818 |
+
self.config = config
|
| 819 |
+
|
| 820 |
+
def get_output_embeddings(self):
|
| 821 |
+
# Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too.
|
| 822 |
+
logger.warning(
|
| 823 |
+
"Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed "
|
| 824 |
+
"in transformers v4.32."
|
| 825 |
+
)
|
| 826 |
+
return self.transformer.w
|
| 827 |
+
|
| 828 |
+
@unpack_inputs
|
| 829 |
+
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
| 830 |
+
@add_code_sample_docstrings(
|
| 831 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 832 |
+
output_type=TFSequenceClassifierOutput,
|
| 833 |
+
config_class=_CONFIG_FOR_DOC,
|
| 834 |
+
)
|
| 835 |
+
def call(
|
| 836 |
+
self,
|
| 837 |
+
input_ids: TFModelInputType | None = None,
|
| 838 |
+
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
| 839 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 840 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 841 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 842 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 843 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 844 |
+
use_cache: Optional[bool] = None,
|
| 845 |
+
output_attentions: Optional[bool] = None,
|
| 846 |
+
output_hidden_states: Optional[bool] = None,
|
| 847 |
+
return_dict: Optional[bool] = None,
|
| 848 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 849 |
+
training: Optional[bool] = False,
|
| 850 |
+
) -> Union[Tuple, TFSequenceClassifierOutput]:
|
| 851 |
+
r"""
|
| 852 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 853 |
+
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
|
| 854 |
+
config.vocab_size - 1]`.
|
| 855 |
+
"""
|
| 856 |
+
|
| 857 |
+
transformer_outputs = self.transformer(
|
| 858 |
+
input_ids=input_ids,
|
| 859 |
+
past_key_values=past_key_values,
|
| 860 |
+
attention_mask=attention_mask,
|
| 861 |
+
token_type_ids=token_type_ids,
|
| 862 |
+
position_ids=position_ids,
|
| 863 |
+
head_mask=head_mask,
|
| 864 |
+
inputs_embeds=inputs_embeds,
|
| 865 |
+
use_cache=use_cache,
|
| 866 |
+
output_attentions=output_attentions,
|
| 867 |
+
output_hidden_states=output_hidden_states,
|
| 868 |
+
return_dict=return_dict,
|
| 869 |
+
training=training,
|
| 870 |
+
)
|
| 871 |
+
hidden_states = transformer_outputs[0]
|
| 872 |
+
logits = self.classifier(hidden_states)
|
| 873 |
+
logits_shape = shape_list(logits)
|
| 874 |
+
batch_size = logits_shape[0]
|
| 875 |
+
|
| 876 |
+
if self.config.pad_token_id is None:
|
| 877 |
+
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
| 878 |
+
else:
|
| 879 |
+
if input_ids is not None:
|
| 880 |
+
token_indices = tf.range(shape_list(input_ids)[-1])
|
| 881 |
+
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
| 882 |
+
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
| 883 |
+
else:
|
| 884 |
+
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
| 885 |
+
logger.warning_once(
|
| 886 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 887 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 888 |
+
)
|
| 889 |
+
loss = None
|
| 890 |
+
|
| 891 |
+
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
| 892 |
+
|
| 893 |
+
if labels is not None:
|
| 894 |
+
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
| 895 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 896 |
+
|
| 897 |
+
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
| 898 |
+
|
| 899 |
+
if not return_dict:
|
| 900 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 901 |
+
return ((loss,) + output) if loss is not None else output
|
| 902 |
+
|
| 903 |
+
return TFSequenceClassifierOutput(
|
| 904 |
+
loss=loss,
|
| 905 |
+
logits=pooled_logits,
|
| 906 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 907 |
+
attentions=transformer_outputs.attentions,
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
def build(self, input_shape=None):
|
| 911 |
+
if self.built:
|
| 912 |
+
return
|
| 913 |
+
self.built = True
|
| 914 |
+
if getattr(self, "classifier", None) is not None:
|
| 915 |
+
with tf.name_scope(self.classifier.name):
|
| 916 |
+
self.classifier.build([None, None, self.config.n_embd])
|
| 917 |
+
if getattr(self, "transformer", None) is not None:
|
| 918 |
+
with tf.name_scope(self.transformer.name):
|
| 919 |
+
self.transformer.build(None)
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
__all__ = ["TFCTRLForSequenceClassification", "TFCTRLLMHeadModel", "TFCTRLModel", "TFCTRLPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/ctrl/tokenization_ctrl.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 Salesforce and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for Salesforce CTRL."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import regex as re
|
| 22 |
+
|
| 23 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
VOCAB_FILES_NAMES = {
|
| 30 |
+
"vocab_file": "vocab.json",
|
| 31 |
+
"merges_file": "merges.txt",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
CONTROL_CODES = {
|
| 36 |
+
"Pregnancy": 168629,
|
| 37 |
+
"Christianity": 7675,
|
| 38 |
+
"Explain": 106423,
|
| 39 |
+
"Fitness": 63440,
|
| 40 |
+
"Saving": 63163,
|
| 41 |
+
"Ask": 27171,
|
| 42 |
+
"Ass": 95985,
|
| 43 |
+
"Joke": 163509,
|
| 44 |
+
"Questions": 45622,
|
| 45 |
+
"Thoughts": 49605,
|
| 46 |
+
"Retail": 52342,
|
| 47 |
+
"Feminism": 164338,
|
| 48 |
+
"Writing": 11992,
|
| 49 |
+
"Atheism": 192263,
|
| 50 |
+
"Netflix": 48616,
|
| 51 |
+
"Computing": 39639,
|
| 52 |
+
"Opinion": 43213,
|
| 53 |
+
"Alone": 44967,
|
| 54 |
+
"Funny": 58917,
|
| 55 |
+
"Gaming": 40358,
|
| 56 |
+
"Human": 4088,
|
| 57 |
+
"India": 1331,
|
| 58 |
+
"Joker": 77138,
|
| 59 |
+
"Diet": 36206,
|
| 60 |
+
"Legal": 11859,
|
| 61 |
+
"Norman": 4939,
|
| 62 |
+
"Tip": 72689,
|
| 63 |
+
"Weight": 52343,
|
| 64 |
+
"Movies": 46273,
|
| 65 |
+
"Running": 23425,
|
| 66 |
+
"Science": 2090,
|
| 67 |
+
"Horror": 37793,
|
| 68 |
+
"Confession": 60572,
|
| 69 |
+
"Finance": 12250,
|
| 70 |
+
"Politics": 16360,
|
| 71 |
+
"Scary": 191985,
|
| 72 |
+
"Support": 12654,
|
| 73 |
+
"Technologies": 32516,
|
| 74 |
+
"Teenage": 66160,
|
| 75 |
+
"Event": 32769,
|
| 76 |
+
"Learned": 67460,
|
| 77 |
+
"Notion": 182770,
|
| 78 |
+
"Wikipedia": 37583,
|
| 79 |
+
"Books": 6665,
|
| 80 |
+
"Extract": 76050,
|
| 81 |
+
"Confessions": 102701,
|
| 82 |
+
"Conspiracy": 75932,
|
| 83 |
+
"Links": 63674,
|
| 84 |
+
"Narcissus": 150425,
|
| 85 |
+
"Relationship": 54766,
|
| 86 |
+
"Relationships": 134796,
|
| 87 |
+
"Reviews": 41671,
|
| 88 |
+
"News": 4256,
|
| 89 |
+
"Translation": 26820,
|
| 90 |
+
"multilingual": 128406,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_pairs(word):
|
| 95 |
+
"""
|
| 96 |
+
Return set of symbol pairs in a word.
|
| 97 |
+
|
| 98 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 99 |
+
"""
|
| 100 |
+
pairs = set()
|
| 101 |
+
prev_char = word[0]
|
| 102 |
+
for char in word[1:]:
|
| 103 |
+
pairs.add((prev_char, char))
|
| 104 |
+
prev_char = char
|
| 105 |
+
|
| 106 |
+
pairs = set(pairs)
|
| 107 |
+
return pairs
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class CTRLTokenizer(PreTrainedTokenizer):
|
| 111 |
+
"""
|
| 112 |
+
Construct a CTRL tokenizer. Based on Byte-Pair-Encoding.
|
| 113 |
+
|
| 114 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 115 |
+
this superclass for more information regarding those methods.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
vocab_file (`str`):
|
| 119 |
+
Path to the vocabulary file.
|
| 120 |
+
merges_file (`str`):
|
| 121 |
+
Path to the merges file.
|
| 122 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 123 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 124 |
+
token instead.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 128 |
+
control_codes = CONTROL_CODES
|
| 129 |
+
|
| 130 |
+
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
|
| 131 |
+
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
| 132 |
+
self.encoder = json.load(vocab_handle)
|
| 133 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 134 |
+
with open(merges_file, encoding="utf-8") as merges_handle:
|
| 135 |
+
merges = merges_handle.read().split("\n")[1:-1]
|
| 136 |
+
merges = [tuple(merge.split()) for merge in merges]
|
| 137 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
| 138 |
+
self.cache = {}
|
| 139 |
+
super().__init__(unk_token=unk_token, **kwargs)
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def vocab_size(self):
|
| 143 |
+
return len(self.encoder)
|
| 144 |
+
|
| 145 |
+
def get_vocab(self):
|
| 146 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
| 147 |
+
|
| 148 |
+
def bpe(self, token):
|
| 149 |
+
if token in self.cache:
|
| 150 |
+
return self.cache[token]
|
| 151 |
+
word = tuple(token)
|
| 152 |
+
word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
|
| 153 |
+
pairs = get_pairs(word)
|
| 154 |
+
|
| 155 |
+
if not pairs:
|
| 156 |
+
return token
|
| 157 |
+
|
| 158 |
+
while True:
|
| 159 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
| 160 |
+
if bigram not in self.bpe_ranks:
|
| 161 |
+
break
|
| 162 |
+
first, second = bigram
|
| 163 |
+
new_word = []
|
| 164 |
+
i = 0
|
| 165 |
+
while i < len(word):
|
| 166 |
+
try:
|
| 167 |
+
j = word.index(first, i)
|
| 168 |
+
except ValueError:
|
| 169 |
+
new_word.extend(word[i:])
|
| 170 |
+
break
|
| 171 |
+
else:
|
| 172 |
+
new_word.extend(word[i:j])
|
| 173 |
+
i = j
|
| 174 |
+
|
| 175 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 176 |
+
new_word.append(first + second)
|
| 177 |
+
i += 2
|
| 178 |
+
else:
|
| 179 |
+
new_word.append(word[i])
|
| 180 |
+
i += 1
|
| 181 |
+
new_word = tuple(new_word)
|
| 182 |
+
word = new_word
|
| 183 |
+
if len(word) == 1:
|
| 184 |
+
break
|
| 185 |
+
else:
|
| 186 |
+
pairs = get_pairs(word)
|
| 187 |
+
word = "@@ ".join(word)
|
| 188 |
+
word = word[:-4]
|
| 189 |
+
self.cache[token] = word
|
| 190 |
+
return word
|
| 191 |
+
|
| 192 |
+
def _tokenize(self, text):
|
| 193 |
+
"""Tokenize a string."""
|
| 194 |
+
split_tokens = []
|
| 195 |
+
|
| 196 |
+
words = re.findall(r"\S+\n?", text)
|
| 197 |
+
|
| 198 |
+
for token in words:
|
| 199 |
+
split_tokens.extend(list(self.bpe(token).split(" ")))
|
| 200 |
+
return split_tokens
|
| 201 |
+
|
| 202 |
+
def _convert_token_to_id(self, token):
|
| 203 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 204 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
| 205 |
+
|
| 206 |
+
def _convert_id_to_token(self, index):
|
| 207 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 208 |
+
return self.decoder.get(index, self.unk_token)
|
| 209 |
+
|
| 210 |
+
def convert_tokens_to_string(self, tokens):
|
| 211 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 212 |
+
out_string = " ".join(tokens).replace("@@ ", "").strip()
|
| 213 |
+
return out_string
|
| 214 |
+
|
| 215 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 216 |
+
if not os.path.isdir(save_directory):
|
| 217 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 218 |
+
return
|
| 219 |
+
vocab_file = os.path.join(
|
| 220 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 221 |
+
)
|
| 222 |
+
merge_file = os.path.join(
|
| 223 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 227 |
+
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
| 228 |
+
|
| 229 |
+
index = 0
|
| 230 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
| 231 |
+
writer.write("#version: 0.2\n")
|
| 232 |
+
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
| 233 |
+
if index != token_index:
|
| 234 |
+
logger.warning(
|
| 235 |
+
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
| 236 |
+
" Please check that the tokenizer is not corrupted!"
|
| 237 |
+
)
|
| 238 |
+
index = token_index
|
| 239 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
| 240 |
+
index += 1
|
| 241 |
+
|
| 242 |
+
return vocab_file, merge_file
|
| 243 |
+
|
| 244 |
+
# def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
| 245 |
+
# filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
|
| 246 |
+
# tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
|
| 247 |
+
# tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
|
| 248 |
+
# return ''.join(tokens_generated_so_far)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
__all__ = ["CTRLTokenizer"]
|
docs/transformers/build/lib/transformers/models/cvt/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_cvt import *
|
| 22 |
+
from .modeling_cvt import *
|
| 23 |
+
from .modeling_tf_cvt import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/cvt/configuration_cvt.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""CvT model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CvtConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model
|
| 27 |
+
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 28 |
+
defaults will yield a similar configuration to that of the CvT
|
| 29 |
+
[microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 36 |
+
The number of input channels.
|
| 37 |
+
patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3]`):
|
| 38 |
+
The kernel size of each encoder's patch embedding.
|
| 39 |
+
patch_stride (`List[int]`, *optional*, defaults to `[4, 2, 2]`):
|
| 40 |
+
The stride size of each encoder's patch embedding.
|
| 41 |
+
patch_padding (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
|
| 42 |
+
The padding size of each encoder's patch embedding.
|
| 43 |
+
embed_dim (`List[int]`, *optional*, defaults to `[64, 192, 384]`):
|
| 44 |
+
Dimension of each of the encoder blocks.
|
| 45 |
+
num_heads (`List[int]`, *optional*, defaults to `[1, 3, 6]`):
|
| 46 |
+
Number of attention heads for each attention layer in each block of the Transformer encoder.
|
| 47 |
+
depth (`List[int]`, *optional*, defaults to `[1, 2, 10]`):
|
| 48 |
+
The number of layers in each encoder block.
|
| 49 |
+
mlp_ratios (`List[float]`, *optional*, defaults to `[4.0, 4.0, 4.0, 4.0]`):
|
| 50 |
+
Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
|
| 51 |
+
encoder blocks.
|
| 52 |
+
attention_drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
|
| 53 |
+
The dropout ratio for the attention probabilities.
|
| 54 |
+
drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
|
| 55 |
+
The dropout ratio for the patch embeddings probabilities.
|
| 56 |
+
drop_path_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.1]`):
|
| 57 |
+
The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
|
| 58 |
+
qkv_bias (`List[bool]`, *optional*, defaults to `[True, True, True]`):
|
| 59 |
+
The bias bool for query, key and value in attentions
|
| 60 |
+
cls_token (`List[bool]`, *optional*, defaults to `[False, False, True]`):
|
| 61 |
+
Whether or not to add a classification token to the output of each of the last 3 stages.
|
| 62 |
+
qkv_projection_method (`List[string]`, *optional*, defaults to ["dw_bn", "dw_bn", "dw_bn"]`):
|
| 63 |
+
The projection method for query, key and value Default is depth-wise convolutions with batch norm. For
|
| 64 |
+
Linear projection use "avg".
|
| 65 |
+
kernel_qkv (`List[int]`, *optional*, defaults to `[3, 3, 3]`):
|
| 66 |
+
The kernel size for query, key and value in attention layer
|
| 67 |
+
padding_kv (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
|
| 68 |
+
The padding size for key and value in attention layer
|
| 69 |
+
stride_kv (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
|
| 70 |
+
The stride size for key and value in attention layer
|
| 71 |
+
padding_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
|
| 72 |
+
The padding size for query in attention layer
|
| 73 |
+
stride_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
|
| 74 |
+
The stride size for query in attention layer
|
| 75 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 76 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 77 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
| 78 |
+
The epsilon used by the layer normalization layers.
|
| 79 |
+
|
| 80 |
+
Example:
|
| 81 |
+
|
| 82 |
+
```python
|
| 83 |
+
>>> from transformers import CvtConfig, CvtModel
|
| 84 |
+
|
| 85 |
+
>>> # Initializing a Cvt msft/cvt style configuration
|
| 86 |
+
>>> configuration = CvtConfig()
|
| 87 |
+
|
| 88 |
+
>>> # Initializing a model (with random weights) from the msft/cvt style configuration
|
| 89 |
+
>>> model = CvtModel(configuration)
|
| 90 |
+
|
| 91 |
+
>>> # Accessing the model configuration
|
| 92 |
+
>>> configuration = model.config
|
| 93 |
+
```"""
|
| 94 |
+
|
| 95 |
+
model_type = "cvt"
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
num_channels=3,
|
| 100 |
+
patch_sizes=[7, 3, 3],
|
| 101 |
+
patch_stride=[4, 2, 2],
|
| 102 |
+
patch_padding=[2, 1, 1],
|
| 103 |
+
embed_dim=[64, 192, 384],
|
| 104 |
+
num_heads=[1, 3, 6],
|
| 105 |
+
depth=[1, 2, 10],
|
| 106 |
+
mlp_ratio=[4.0, 4.0, 4.0],
|
| 107 |
+
attention_drop_rate=[0.0, 0.0, 0.0],
|
| 108 |
+
drop_rate=[0.0, 0.0, 0.0],
|
| 109 |
+
drop_path_rate=[0.0, 0.0, 0.1],
|
| 110 |
+
qkv_bias=[True, True, True],
|
| 111 |
+
cls_token=[False, False, True],
|
| 112 |
+
qkv_projection_method=["dw_bn", "dw_bn", "dw_bn"],
|
| 113 |
+
kernel_qkv=[3, 3, 3],
|
| 114 |
+
padding_kv=[1, 1, 1],
|
| 115 |
+
stride_kv=[2, 2, 2],
|
| 116 |
+
padding_q=[1, 1, 1],
|
| 117 |
+
stride_q=[1, 1, 1],
|
| 118 |
+
initializer_range=0.02,
|
| 119 |
+
layer_norm_eps=1e-12,
|
| 120 |
+
**kwargs,
|
| 121 |
+
):
|
| 122 |
+
super().__init__(**kwargs)
|
| 123 |
+
self.num_channels = num_channels
|
| 124 |
+
self.patch_sizes = patch_sizes
|
| 125 |
+
self.patch_stride = patch_stride
|
| 126 |
+
self.patch_padding = patch_padding
|
| 127 |
+
self.embed_dim = embed_dim
|
| 128 |
+
self.num_heads = num_heads
|
| 129 |
+
self.depth = depth
|
| 130 |
+
self.mlp_ratio = mlp_ratio
|
| 131 |
+
self.attention_drop_rate = attention_drop_rate
|
| 132 |
+
self.drop_rate = drop_rate
|
| 133 |
+
self.drop_path_rate = drop_path_rate
|
| 134 |
+
self.qkv_bias = qkv_bias
|
| 135 |
+
self.cls_token = cls_token
|
| 136 |
+
self.qkv_projection_method = qkv_projection_method
|
| 137 |
+
self.kernel_qkv = kernel_qkv
|
| 138 |
+
self.padding_kv = padding_kv
|
| 139 |
+
self.stride_kv = stride_kv
|
| 140 |
+
self.padding_q = padding_q
|
| 141 |
+
self.stride_q = stride_q
|
| 142 |
+
self.initializer_range = initializer_range
|
| 143 |
+
self.layer_norm_eps = layer_norm_eps
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
__all__ = ["CvtConfig"]
|
docs/transformers/build/lib/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert CvT checkpoints from the original repository.
|
| 16 |
+
|
| 17 |
+
URL: https://github.com/microsoft/CvT"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
from collections import OrderedDict
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from huggingface_hub import hf_hub_download
|
| 26 |
+
|
| 27 |
+
from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def embeddings(idx):
|
| 31 |
+
"""
|
| 32 |
+
The function helps in renaming embedding layer weights.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
idx: stage number in original model
|
| 36 |
+
"""
|
| 37 |
+
embed = []
|
| 38 |
+
embed.append(
|
| 39 |
+
(
|
| 40 |
+
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight",
|
| 41 |
+
f"stage{idx}.patch_embed.proj.weight",
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
embed.append(
|
| 45 |
+
(
|
| 46 |
+
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias",
|
| 47 |
+
f"stage{idx}.patch_embed.proj.bias",
|
| 48 |
+
)
|
| 49 |
+
)
|
| 50 |
+
embed.append(
|
| 51 |
+
(
|
| 52 |
+
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight",
|
| 53 |
+
f"stage{idx}.patch_embed.norm.weight",
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
embed.append(
|
| 57 |
+
(
|
| 58 |
+
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias",
|
| 59 |
+
f"stage{idx}.patch_embed.norm.bias",
|
| 60 |
+
)
|
| 61 |
+
)
|
| 62 |
+
return embed
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def attention(idx, cnt):
|
| 66 |
+
"""
|
| 67 |
+
The function helps in renaming attention block layers weights.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
idx: stage number in original model
|
| 71 |
+
cnt: count of blocks in each stage
|
| 72 |
+
"""
|
| 73 |
+
attention_weights = []
|
| 74 |
+
attention_weights.append(
|
| 75 |
+
(
|
| 76 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight",
|
| 77 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight",
|
| 78 |
+
)
|
| 79 |
+
)
|
| 80 |
+
attention_weights.append(
|
| 81 |
+
(
|
| 82 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight",
|
| 83 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight",
|
| 84 |
+
)
|
| 85 |
+
)
|
| 86 |
+
attention_weights.append(
|
| 87 |
+
(
|
| 88 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias",
|
| 89 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias",
|
| 90 |
+
)
|
| 91 |
+
)
|
| 92 |
+
attention_weights.append(
|
| 93 |
+
(
|
| 94 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean",
|
| 95 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean",
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
attention_weights.append(
|
| 99 |
+
(
|
| 100 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var",
|
| 101 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var",
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
attention_weights.append(
|
| 105 |
+
(
|
| 106 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked",
|
| 107 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked",
|
| 108 |
+
)
|
| 109 |
+
)
|
| 110 |
+
attention_weights.append(
|
| 111 |
+
(
|
| 112 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight",
|
| 113 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight",
|
| 114 |
+
)
|
| 115 |
+
)
|
| 116 |
+
attention_weights.append(
|
| 117 |
+
(
|
| 118 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight",
|
| 119 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight",
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
attention_weights.append(
|
| 123 |
+
(
|
| 124 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias",
|
| 125 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias",
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
attention_weights.append(
|
| 129 |
+
(
|
| 130 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean",
|
| 131 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean",
|
| 132 |
+
)
|
| 133 |
+
)
|
| 134 |
+
attention_weights.append(
|
| 135 |
+
(
|
| 136 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var",
|
| 137 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var",
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
attention_weights.append(
|
| 141 |
+
(
|
| 142 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked",
|
| 143 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked",
|
| 144 |
+
)
|
| 145 |
+
)
|
| 146 |
+
attention_weights.append(
|
| 147 |
+
(
|
| 148 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight",
|
| 149 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight",
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
attention_weights.append(
|
| 153 |
+
(
|
| 154 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight",
|
| 155 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight",
|
| 156 |
+
)
|
| 157 |
+
)
|
| 158 |
+
attention_weights.append(
|
| 159 |
+
(
|
| 160 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias",
|
| 161 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias",
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
attention_weights.append(
|
| 165 |
+
(
|
| 166 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean",
|
| 167 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean",
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
attention_weights.append(
|
| 171 |
+
(
|
| 172 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var",
|
| 173 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var",
|
| 174 |
+
)
|
| 175 |
+
)
|
| 176 |
+
attention_weights.append(
|
| 177 |
+
(
|
| 178 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked",
|
| 179 |
+
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked",
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
attention_weights.append(
|
| 183 |
+
(
|
| 184 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight",
|
| 185 |
+
f"stage{idx}.blocks.{cnt}.attn.proj_q.weight",
|
| 186 |
+
)
|
| 187 |
+
)
|
| 188 |
+
attention_weights.append(
|
| 189 |
+
(
|
| 190 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias",
|
| 191 |
+
f"stage{idx}.blocks.{cnt}.attn.proj_q.bias",
|
| 192 |
+
)
|
| 193 |
+
)
|
| 194 |
+
attention_weights.append(
|
| 195 |
+
(
|
| 196 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight",
|
| 197 |
+
f"stage{idx}.blocks.{cnt}.attn.proj_k.weight",
|
| 198 |
+
)
|
| 199 |
+
)
|
| 200 |
+
attention_weights.append(
|
| 201 |
+
(
|
| 202 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias",
|
| 203 |
+
f"stage{idx}.blocks.{cnt}.attn.proj_k.bias",
|
| 204 |
+
)
|
| 205 |
+
)
|
| 206 |
+
attention_weights.append(
|
| 207 |
+
(
|
| 208 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight",
|
| 209 |
+
f"stage{idx}.blocks.{cnt}.attn.proj_v.weight",
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
attention_weights.append(
|
| 213 |
+
(
|
| 214 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias",
|
| 215 |
+
f"stage{idx}.blocks.{cnt}.attn.proj_v.bias",
|
| 216 |
+
)
|
| 217 |
+
)
|
| 218 |
+
attention_weights.append(
|
| 219 |
+
(
|
| 220 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight",
|
| 221 |
+
f"stage{idx}.blocks.{cnt}.attn.proj.weight",
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
attention_weights.append(
|
| 225 |
+
(
|
| 226 |
+
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias",
|
| 227 |
+
f"stage{idx}.blocks.{cnt}.attn.proj.bias",
|
| 228 |
+
)
|
| 229 |
+
)
|
| 230 |
+
attention_weights.append(
|
| 231 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight")
|
| 232 |
+
)
|
| 233 |
+
attention_weights.append(
|
| 234 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias")
|
| 235 |
+
)
|
| 236 |
+
attention_weights.append(
|
| 237 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight")
|
| 238 |
+
)
|
| 239 |
+
attention_weights.append(
|
| 240 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias")
|
| 241 |
+
)
|
| 242 |
+
attention_weights.append(
|
| 243 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight")
|
| 244 |
+
)
|
| 245 |
+
attention_weights.append(
|
| 246 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias")
|
| 247 |
+
)
|
| 248 |
+
attention_weights.append(
|
| 249 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight")
|
| 250 |
+
)
|
| 251 |
+
attention_weights.append(
|
| 252 |
+
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias")
|
| 253 |
+
)
|
| 254 |
+
return attention_weights
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def cls_token(idx):
|
| 258 |
+
"""
|
| 259 |
+
Function helps in renaming cls_token weights
|
| 260 |
+
"""
|
| 261 |
+
token = []
|
| 262 |
+
token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token"))
|
| 263 |
+
return token
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def final():
|
| 267 |
+
"""
|
| 268 |
+
Function helps in renaming final classification layer
|
| 269 |
+
"""
|
| 270 |
+
head = []
|
| 271 |
+
head.append(("layernorm.weight", "norm.weight"))
|
| 272 |
+
head.append(("layernorm.bias", "norm.bias"))
|
| 273 |
+
head.append(("classifier.weight", "head.weight"))
|
| 274 |
+
head.append(("classifier.bias", "head.bias"))
|
| 275 |
+
return head
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
|
| 279 |
+
"""
|
| 280 |
+
Fucntion to convert the microsoft cvt checkpoint to huggingface checkpoint
|
| 281 |
+
"""
|
| 282 |
+
img_labels_file = "imagenet-1k-id2label.json"
|
| 283 |
+
num_labels = 1000
|
| 284 |
+
|
| 285 |
+
repo_id = "huggingface/label-files"
|
| 286 |
+
num_labels = num_labels
|
| 287 |
+
id2label = json.loads(Path(hf_hub_download(repo_id, img_labels_file, repo_type="dataset")).read_text())
|
| 288 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 289 |
+
|
| 290 |
+
id2label = id2label
|
| 291 |
+
label2id = {v: k for k, v in id2label.items()}
|
| 292 |
+
|
| 293 |
+
config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
|
| 294 |
+
|
| 295 |
+
# For depth size 13 (13 = 1+2+10)
|
| 296 |
+
if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
|
| 297 |
+
config.depth = [1, 2, 10]
|
| 298 |
+
|
| 299 |
+
# For depth size 21 (21 = 1+4+16)
|
| 300 |
+
elif cvt_model.rsplit("/", 1)[-1][4:6] == "21":
|
| 301 |
+
config.depth = [1, 4, 16]
|
| 302 |
+
|
| 303 |
+
# For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 20)
|
| 304 |
+
else:
|
| 305 |
+
config.depth = [2, 2, 20]
|
| 306 |
+
config.num_heads = [3, 12, 16]
|
| 307 |
+
config.embed_dim = [192, 768, 1024]
|
| 308 |
+
|
| 309 |
+
model = CvtForImageClassification(config)
|
| 310 |
+
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
|
| 311 |
+
image_processor.size["shortest_edge"] = image_size
|
| 312 |
+
original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"), weights_only=True)
|
| 313 |
+
|
| 314 |
+
huggingface_weights = OrderedDict()
|
| 315 |
+
list_of_state_dict = []
|
| 316 |
+
|
| 317 |
+
for idx in range(len(config.depth)):
|
| 318 |
+
if config.cls_token[idx]:
|
| 319 |
+
list_of_state_dict = list_of_state_dict + cls_token(idx)
|
| 320 |
+
list_of_state_dict = list_of_state_dict + embeddings(idx)
|
| 321 |
+
for cnt in range(config.depth[idx]):
|
| 322 |
+
list_of_state_dict = list_of_state_dict + attention(idx, cnt)
|
| 323 |
+
|
| 324 |
+
list_of_state_dict = list_of_state_dict + final()
|
| 325 |
+
for gg in list_of_state_dict:
|
| 326 |
+
print(gg)
|
| 327 |
+
for i in range(len(list_of_state_dict)):
|
| 328 |
+
huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]
|
| 329 |
+
|
| 330 |
+
model.load_state_dict(huggingface_weights)
|
| 331 |
+
model.save_pretrained(pytorch_dump_folder)
|
| 332 |
+
image_processor.save_pretrained(pytorch_dump_folder)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al
|
| 336 |
+
|
| 337 |
+
if __name__ == "__main__":
|
| 338 |
+
parser = argparse.ArgumentParser()
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--cvt_model",
|
| 341 |
+
default="cvt-w24",
|
| 342 |
+
type=str,
|
| 343 |
+
help="Name of the cvt model you'd like to convert.",
|
| 344 |
+
)
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--image_size",
|
| 347 |
+
default=384,
|
| 348 |
+
type=int,
|
| 349 |
+
help="Input Image Size",
|
| 350 |
+
)
|
| 351 |
+
parser.add_argument(
|
| 352 |
+
"--cvt_file_name",
|
| 353 |
+
default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth",
|
| 354 |
+
type=str,
|
| 355 |
+
help="Input Image Size",
|
| 356 |
+
)
|
| 357 |
+
parser.add_argument(
|
| 358 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
args = parser.parse_args()
|
| 362 |
+
convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)
|
docs/transformers/build/lib/transformers/models/cvt/modeling_cvt.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch CvT model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
| 27 |
+
from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
|
| 28 |
+
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
| 29 |
+
from ...utils import logging
|
| 30 |
+
from .configuration_cvt import CvtConfig
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
# General docstring
|
| 36 |
+
_CONFIG_FOR_DOC = "CvtConfig"
|
| 37 |
+
|
| 38 |
+
# Base docstring
|
| 39 |
+
_CHECKPOINT_FOR_DOC = "microsoft/cvt-13"
|
| 40 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 384, 14, 14]
|
| 41 |
+
|
| 42 |
+
# Image classification docstring
|
| 43 |
+
_IMAGE_CLASS_CHECKPOINT = "microsoft/cvt-13"
|
| 44 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class BaseModelOutputWithCLSToken(ModelOutput):
|
| 49 |
+
"""
|
| 50 |
+
Base class for model's outputs, with potential hidden states and attentions.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 54 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 55 |
+
cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
|
| 56 |
+
Classification token at the output of the last layer of the model.
|
| 57 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 58 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 59 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 60 |
+
plus the initial embedding outputs.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 64 |
+
cls_token_value: Optional[torch.FloatTensor] = None
|
| 65 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
| 69 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
| 70 |
+
"""
|
| 71 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 72 |
+
|
| 73 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 74 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 75 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 76 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 77 |
+
argument.
|
| 78 |
+
"""
|
| 79 |
+
if drop_prob == 0.0 or not training:
|
| 80 |
+
return input
|
| 81 |
+
keep_prob = 1 - drop_prob
|
| 82 |
+
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 83 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
| 84 |
+
random_tensor.floor_() # binarize
|
| 85 |
+
output = input.div(keep_prob) * random_tensor
|
| 86 |
+
return output
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
|
| 90 |
+
class CvtDropPath(nn.Module):
|
| 91 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 92 |
+
|
| 93 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.drop_prob = drop_prob
|
| 96 |
+
|
| 97 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 98 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
| 99 |
+
|
| 100 |
+
def extra_repr(self) -> str:
|
| 101 |
+
return "p={}".format(self.drop_prob)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class CvtEmbeddings(nn.Module):
|
| 105 |
+
"""
|
| 106 |
+
Construct the CvT embeddings.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.convolution_embeddings = CvtConvEmbeddings(
|
| 112 |
+
patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding
|
| 113 |
+
)
|
| 114 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 115 |
+
|
| 116 |
+
def forward(self, pixel_values):
|
| 117 |
+
hidden_state = self.convolution_embeddings(pixel_values)
|
| 118 |
+
hidden_state = self.dropout(hidden_state)
|
| 119 |
+
return hidden_state
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class CvtConvEmbeddings(nn.Module):
|
| 123 |
+
"""
|
| 124 |
+
Image to Conv Embedding.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(self, patch_size, num_channels, embed_dim, stride, padding):
|
| 128 |
+
super().__init__()
|
| 129 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 130 |
+
self.patch_size = patch_size
|
| 131 |
+
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
|
| 132 |
+
self.normalization = nn.LayerNorm(embed_dim)
|
| 133 |
+
|
| 134 |
+
def forward(self, pixel_values):
|
| 135 |
+
pixel_values = self.projection(pixel_values)
|
| 136 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 137 |
+
hidden_size = height * width
|
| 138 |
+
# rearrange "b c h w -> b (h w) c"
|
| 139 |
+
pixel_values = pixel_values.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
|
| 140 |
+
if self.normalization:
|
| 141 |
+
pixel_values = self.normalization(pixel_values)
|
| 142 |
+
# rearrange "b (h w) c" -> b c h w"
|
| 143 |
+
pixel_values = pixel_values.permute(0, 2, 1).view(batch_size, num_channels, height, width)
|
| 144 |
+
return pixel_values
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class CvtSelfAttentionConvProjection(nn.Module):
|
| 148 |
+
def __init__(self, embed_dim, kernel_size, padding, stride):
|
| 149 |
+
super().__init__()
|
| 150 |
+
self.convolution = nn.Conv2d(
|
| 151 |
+
embed_dim,
|
| 152 |
+
embed_dim,
|
| 153 |
+
kernel_size=kernel_size,
|
| 154 |
+
padding=padding,
|
| 155 |
+
stride=stride,
|
| 156 |
+
bias=False,
|
| 157 |
+
groups=embed_dim,
|
| 158 |
+
)
|
| 159 |
+
self.normalization = nn.BatchNorm2d(embed_dim)
|
| 160 |
+
|
| 161 |
+
def forward(self, hidden_state):
|
| 162 |
+
hidden_state = self.convolution(hidden_state)
|
| 163 |
+
hidden_state = self.normalization(hidden_state)
|
| 164 |
+
return hidden_state
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class CvtSelfAttentionLinearProjection(nn.Module):
|
| 168 |
+
def forward(self, hidden_state):
|
| 169 |
+
batch_size, num_channels, height, width = hidden_state.shape
|
| 170 |
+
hidden_size = height * width
|
| 171 |
+
# rearrange " b c h w -> b (h w) c"
|
| 172 |
+
hidden_state = hidden_state.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
|
| 173 |
+
return hidden_state
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class CvtSelfAttentionProjection(nn.Module):
|
| 177 |
+
def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"):
|
| 178 |
+
super().__init__()
|
| 179 |
+
if projection_method == "dw_bn":
|
| 180 |
+
self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride)
|
| 181 |
+
self.linear_projection = CvtSelfAttentionLinearProjection()
|
| 182 |
+
|
| 183 |
+
def forward(self, hidden_state):
|
| 184 |
+
hidden_state = self.convolution_projection(hidden_state)
|
| 185 |
+
hidden_state = self.linear_projection(hidden_state)
|
| 186 |
+
return hidden_state
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class CvtSelfAttention(nn.Module):
|
| 190 |
+
def __init__(
|
| 191 |
+
self,
|
| 192 |
+
num_heads,
|
| 193 |
+
embed_dim,
|
| 194 |
+
kernel_size,
|
| 195 |
+
padding_q,
|
| 196 |
+
padding_kv,
|
| 197 |
+
stride_q,
|
| 198 |
+
stride_kv,
|
| 199 |
+
qkv_projection_method,
|
| 200 |
+
qkv_bias,
|
| 201 |
+
attention_drop_rate,
|
| 202 |
+
with_cls_token=True,
|
| 203 |
+
**kwargs,
|
| 204 |
+
):
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.scale = embed_dim**-0.5
|
| 207 |
+
self.with_cls_token = with_cls_token
|
| 208 |
+
self.embed_dim = embed_dim
|
| 209 |
+
self.num_heads = num_heads
|
| 210 |
+
|
| 211 |
+
self.convolution_projection_query = CvtSelfAttentionProjection(
|
| 212 |
+
embed_dim,
|
| 213 |
+
kernel_size,
|
| 214 |
+
padding_q,
|
| 215 |
+
stride_q,
|
| 216 |
+
projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
|
| 217 |
+
)
|
| 218 |
+
self.convolution_projection_key = CvtSelfAttentionProjection(
|
| 219 |
+
embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
|
| 220 |
+
)
|
| 221 |
+
self.convolution_projection_value = CvtSelfAttentionProjection(
|
| 222 |
+
embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
| 226 |
+
self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
| 227 |
+
self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
| 228 |
+
|
| 229 |
+
self.dropout = nn.Dropout(attention_drop_rate)
|
| 230 |
+
|
| 231 |
+
def rearrange_for_multi_head_attention(self, hidden_state):
|
| 232 |
+
batch_size, hidden_size, _ = hidden_state.shape
|
| 233 |
+
head_dim = self.embed_dim // self.num_heads
|
| 234 |
+
# rearrange 'b t (h d) -> b h t d'
|
| 235 |
+
return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
|
| 236 |
+
|
| 237 |
+
def forward(self, hidden_state, height, width):
|
| 238 |
+
if self.with_cls_token:
|
| 239 |
+
cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
|
| 240 |
+
batch_size, hidden_size, num_channels = hidden_state.shape
|
| 241 |
+
# rearrange "b (h w) c -> b c h w"
|
| 242 |
+
hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
|
| 243 |
+
|
| 244 |
+
key = self.convolution_projection_key(hidden_state)
|
| 245 |
+
query = self.convolution_projection_query(hidden_state)
|
| 246 |
+
value = self.convolution_projection_value(hidden_state)
|
| 247 |
+
|
| 248 |
+
if self.with_cls_token:
|
| 249 |
+
query = torch.cat((cls_token, query), dim=1)
|
| 250 |
+
key = torch.cat((cls_token, key), dim=1)
|
| 251 |
+
value = torch.cat((cls_token, value), dim=1)
|
| 252 |
+
|
| 253 |
+
head_dim = self.embed_dim // self.num_heads
|
| 254 |
+
|
| 255 |
+
query = self.rearrange_for_multi_head_attention(self.projection_query(query))
|
| 256 |
+
key = self.rearrange_for_multi_head_attention(self.projection_key(key))
|
| 257 |
+
value = self.rearrange_for_multi_head_attention(self.projection_value(value))
|
| 258 |
+
|
| 259 |
+
attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
|
| 260 |
+
attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
|
| 261 |
+
attention_probs = self.dropout(attention_probs)
|
| 262 |
+
|
| 263 |
+
context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
|
| 264 |
+
# rearrange"b h t d -> b t (h d)"
|
| 265 |
+
_, _, hidden_size, _ = context.shape
|
| 266 |
+
context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim)
|
| 267 |
+
return context
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class CvtSelfOutput(nn.Module):
|
| 271 |
+
"""
|
| 272 |
+
The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
|
| 273 |
+
layernorm applied before each block.
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
def __init__(self, embed_dim, drop_rate):
|
| 277 |
+
super().__init__()
|
| 278 |
+
self.dense = nn.Linear(embed_dim, embed_dim)
|
| 279 |
+
self.dropout = nn.Dropout(drop_rate)
|
| 280 |
+
|
| 281 |
+
def forward(self, hidden_state, input_tensor):
|
| 282 |
+
hidden_state = self.dense(hidden_state)
|
| 283 |
+
hidden_state = self.dropout(hidden_state)
|
| 284 |
+
return hidden_state
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class CvtAttention(nn.Module):
|
| 288 |
+
def __init__(
|
| 289 |
+
self,
|
| 290 |
+
num_heads,
|
| 291 |
+
embed_dim,
|
| 292 |
+
kernel_size,
|
| 293 |
+
padding_q,
|
| 294 |
+
padding_kv,
|
| 295 |
+
stride_q,
|
| 296 |
+
stride_kv,
|
| 297 |
+
qkv_projection_method,
|
| 298 |
+
qkv_bias,
|
| 299 |
+
attention_drop_rate,
|
| 300 |
+
drop_rate,
|
| 301 |
+
with_cls_token=True,
|
| 302 |
+
):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.attention = CvtSelfAttention(
|
| 305 |
+
num_heads,
|
| 306 |
+
embed_dim,
|
| 307 |
+
kernel_size,
|
| 308 |
+
padding_q,
|
| 309 |
+
padding_kv,
|
| 310 |
+
stride_q,
|
| 311 |
+
stride_kv,
|
| 312 |
+
qkv_projection_method,
|
| 313 |
+
qkv_bias,
|
| 314 |
+
attention_drop_rate,
|
| 315 |
+
with_cls_token,
|
| 316 |
+
)
|
| 317 |
+
self.output = CvtSelfOutput(embed_dim, drop_rate)
|
| 318 |
+
self.pruned_heads = set()
|
| 319 |
+
|
| 320 |
+
def prune_heads(self, heads):
|
| 321 |
+
if len(heads) == 0:
|
| 322 |
+
return
|
| 323 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 324 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Prune linear layers
|
| 328 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 329 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 330 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 331 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 332 |
+
|
| 333 |
+
# Update hyper params and store pruned heads
|
| 334 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
| 335 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
| 336 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 337 |
+
|
| 338 |
+
def forward(self, hidden_state, height, width):
|
| 339 |
+
self_output = self.attention(hidden_state, height, width)
|
| 340 |
+
attention_output = self.output(self_output, hidden_state)
|
| 341 |
+
return attention_output
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class CvtIntermediate(nn.Module):
|
| 345 |
+
def __init__(self, embed_dim, mlp_ratio):
|
| 346 |
+
super().__init__()
|
| 347 |
+
self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))
|
| 348 |
+
self.activation = nn.GELU()
|
| 349 |
+
|
| 350 |
+
def forward(self, hidden_state):
|
| 351 |
+
hidden_state = self.dense(hidden_state)
|
| 352 |
+
hidden_state = self.activation(hidden_state)
|
| 353 |
+
return hidden_state
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class CvtOutput(nn.Module):
|
| 357 |
+
def __init__(self, embed_dim, mlp_ratio, drop_rate):
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
|
| 360 |
+
self.dropout = nn.Dropout(drop_rate)
|
| 361 |
+
|
| 362 |
+
def forward(self, hidden_state, input_tensor):
|
| 363 |
+
hidden_state = self.dense(hidden_state)
|
| 364 |
+
hidden_state = self.dropout(hidden_state)
|
| 365 |
+
hidden_state = hidden_state + input_tensor
|
| 366 |
+
return hidden_state
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class CvtLayer(nn.Module):
|
| 370 |
+
"""
|
| 371 |
+
CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
def __init__(
|
| 375 |
+
self,
|
| 376 |
+
num_heads,
|
| 377 |
+
embed_dim,
|
| 378 |
+
kernel_size,
|
| 379 |
+
padding_q,
|
| 380 |
+
padding_kv,
|
| 381 |
+
stride_q,
|
| 382 |
+
stride_kv,
|
| 383 |
+
qkv_projection_method,
|
| 384 |
+
qkv_bias,
|
| 385 |
+
attention_drop_rate,
|
| 386 |
+
drop_rate,
|
| 387 |
+
mlp_ratio,
|
| 388 |
+
drop_path_rate,
|
| 389 |
+
with_cls_token=True,
|
| 390 |
+
):
|
| 391 |
+
super().__init__()
|
| 392 |
+
self.attention = CvtAttention(
|
| 393 |
+
num_heads,
|
| 394 |
+
embed_dim,
|
| 395 |
+
kernel_size,
|
| 396 |
+
padding_q,
|
| 397 |
+
padding_kv,
|
| 398 |
+
stride_q,
|
| 399 |
+
stride_kv,
|
| 400 |
+
qkv_projection_method,
|
| 401 |
+
qkv_bias,
|
| 402 |
+
attention_drop_rate,
|
| 403 |
+
drop_rate,
|
| 404 |
+
with_cls_token,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
self.intermediate = CvtIntermediate(embed_dim, mlp_ratio)
|
| 408 |
+
self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate)
|
| 409 |
+
self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
| 410 |
+
self.layernorm_before = nn.LayerNorm(embed_dim)
|
| 411 |
+
self.layernorm_after = nn.LayerNorm(embed_dim)
|
| 412 |
+
|
| 413 |
+
def forward(self, hidden_state, height, width):
|
| 414 |
+
self_attention_output = self.attention(
|
| 415 |
+
self.layernorm_before(hidden_state), # in Cvt, layernorm is applied before self-attention
|
| 416 |
+
height,
|
| 417 |
+
width,
|
| 418 |
+
)
|
| 419 |
+
attention_output = self_attention_output
|
| 420 |
+
attention_output = self.drop_path(attention_output)
|
| 421 |
+
|
| 422 |
+
# first residual connection
|
| 423 |
+
hidden_state = attention_output + hidden_state
|
| 424 |
+
|
| 425 |
+
# in Cvt, layernorm is also applied after self-attention
|
| 426 |
+
layer_output = self.layernorm_after(hidden_state)
|
| 427 |
+
layer_output = self.intermediate(layer_output)
|
| 428 |
+
|
| 429 |
+
# second residual connection is done here
|
| 430 |
+
layer_output = self.output(layer_output, hidden_state)
|
| 431 |
+
layer_output = self.drop_path(layer_output)
|
| 432 |
+
return layer_output
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class CvtStage(nn.Module):
|
| 436 |
+
def __init__(self, config, stage):
|
| 437 |
+
super().__init__()
|
| 438 |
+
self.config = config
|
| 439 |
+
self.stage = stage
|
| 440 |
+
if self.config.cls_token[self.stage]:
|
| 441 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, self.config.embed_dim[-1]))
|
| 442 |
+
|
| 443 |
+
self.embedding = CvtEmbeddings(
|
| 444 |
+
patch_size=config.patch_sizes[self.stage],
|
| 445 |
+
stride=config.patch_stride[self.stage],
|
| 446 |
+
num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
|
| 447 |
+
embed_dim=config.embed_dim[self.stage],
|
| 448 |
+
padding=config.patch_padding[self.stage],
|
| 449 |
+
dropout_rate=config.drop_rate[self.stage],
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
drop_path_rates = [
|
| 453 |
+
x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage], device="cpu")
|
| 454 |
+
]
|
| 455 |
+
|
| 456 |
+
self.layers = nn.Sequential(
|
| 457 |
+
*[
|
| 458 |
+
CvtLayer(
|
| 459 |
+
num_heads=config.num_heads[self.stage],
|
| 460 |
+
embed_dim=config.embed_dim[self.stage],
|
| 461 |
+
kernel_size=config.kernel_qkv[self.stage],
|
| 462 |
+
padding_q=config.padding_q[self.stage],
|
| 463 |
+
padding_kv=config.padding_kv[self.stage],
|
| 464 |
+
stride_kv=config.stride_kv[self.stage],
|
| 465 |
+
stride_q=config.stride_q[self.stage],
|
| 466 |
+
qkv_projection_method=config.qkv_projection_method[self.stage],
|
| 467 |
+
qkv_bias=config.qkv_bias[self.stage],
|
| 468 |
+
attention_drop_rate=config.attention_drop_rate[self.stage],
|
| 469 |
+
drop_rate=config.drop_rate[self.stage],
|
| 470 |
+
drop_path_rate=drop_path_rates[self.stage],
|
| 471 |
+
mlp_ratio=config.mlp_ratio[self.stage],
|
| 472 |
+
with_cls_token=config.cls_token[self.stage],
|
| 473 |
+
)
|
| 474 |
+
for _ in range(config.depth[self.stage])
|
| 475 |
+
]
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
def forward(self, hidden_state):
|
| 479 |
+
cls_token = None
|
| 480 |
+
hidden_state = self.embedding(hidden_state)
|
| 481 |
+
batch_size, num_channels, height, width = hidden_state.shape
|
| 482 |
+
# rearrange b c h w -> b (h w) c"
|
| 483 |
+
hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
| 484 |
+
if self.config.cls_token[self.stage]:
|
| 485 |
+
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
| 486 |
+
hidden_state = torch.cat((cls_token, hidden_state), dim=1)
|
| 487 |
+
|
| 488 |
+
for layer in self.layers:
|
| 489 |
+
layer_outputs = layer(hidden_state, height, width)
|
| 490 |
+
hidden_state = layer_outputs
|
| 491 |
+
|
| 492 |
+
if self.config.cls_token[self.stage]:
|
| 493 |
+
cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
|
| 494 |
+
hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
|
| 495 |
+
return hidden_state, cls_token
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class CvtEncoder(nn.Module):
|
| 499 |
+
def __init__(self, config):
|
| 500 |
+
super().__init__()
|
| 501 |
+
self.config = config
|
| 502 |
+
self.stages = nn.ModuleList([])
|
| 503 |
+
for stage_idx in range(len(config.depth)):
|
| 504 |
+
self.stages.append(CvtStage(config, stage_idx))
|
| 505 |
+
|
| 506 |
+
def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
|
| 507 |
+
all_hidden_states = () if output_hidden_states else None
|
| 508 |
+
hidden_state = pixel_values
|
| 509 |
+
|
| 510 |
+
cls_token = None
|
| 511 |
+
for _, (stage_module) in enumerate(self.stages):
|
| 512 |
+
hidden_state, cls_token = stage_module(hidden_state)
|
| 513 |
+
if output_hidden_states:
|
| 514 |
+
all_hidden_states = all_hidden_states + (hidden_state,)
|
| 515 |
+
|
| 516 |
+
if not return_dict:
|
| 517 |
+
return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
|
| 518 |
+
|
| 519 |
+
return BaseModelOutputWithCLSToken(
|
| 520 |
+
last_hidden_state=hidden_state,
|
| 521 |
+
cls_token_value=cls_token,
|
| 522 |
+
hidden_states=all_hidden_states,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class CvtPreTrainedModel(PreTrainedModel):
|
| 527 |
+
"""
|
| 528 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 529 |
+
models.
|
| 530 |
+
"""
|
| 531 |
+
|
| 532 |
+
config_class = CvtConfig
|
| 533 |
+
base_model_prefix = "cvt"
|
| 534 |
+
main_input_name = "pixel_values"
|
| 535 |
+
_no_split_modules = ["CvtLayer"]
|
| 536 |
+
|
| 537 |
+
def _init_weights(self, module):
|
| 538 |
+
"""Initialize the weights"""
|
| 539 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 540 |
+
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
|
| 541 |
+
if module.bias is not None:
|
| 542 |
+
module.bias.data.zero_()
|
| 543 |
+
elif isinstance(module, nn.LayerNorm):
|
| 544 |
+
module.bias.data.zero_()
|
| 545 |
+
module.weight.data.fill_(1.0)
|
| 546 |
+
elif isinstance(module, CvtStage):
|
| 547 |
+
if self.config.cls_token[module.stage]:
|
| 548 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
| 549 |
+
module.cls_token.data, mean=0.0, std=self.config.initializer_range
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
CVT_START_DOCSTRING = r"""
|
| 554 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 555 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 556 |
+
behavior.
|
| 557 |
+
|
| 558 |
+
Parameters:
|
| 559 |
+
config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
|
| 560 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 561 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 562 |
+
"""
|
| 563 |
+
|
| 564 |
+
CVT_INPUTS_DOCSTRING = r"""
|
| 565 |
+
Args:
|
| 566 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 567 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
|
| 568 |
+
for details.
|
| 569 |
+
output_hidden_states (`bool`, *optional*):
|
| 570 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 571 |
+
more detail.
|
| 572 |
+
return_dict (`bool`, *optional*):
|
| 573 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
| 574 |
+
"""
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
@add_start_docstrings(
|
| 578 |
+
"The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
|
| 579 |
+
CVT_START_DOCSTRING,
|
| 580 |
+
)
|
| 581 |
+
class CvtModel(CvtPreTrainedModel):
|
| 582 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 583 |
+
super().__init__(config)
|
| 584 |
+
self.config = config
|
| 585 |
+
self.encoder = CvtEncoder(config)
|
| 586 |
+
self.post_init()
|
| 587 |
+
|
| 588 |
+
def _prune_heads(self, heads_to_prune):
|
| 589 |
+
"""
|
| 590 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 591 |
+
class PreTrainedModel
|
| 592 |
+
"""
|
| 593 |
+
for layer, heads in heads_to_prune.items():
|
| 594 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 595 |
+
|
| 596 |
+
@add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
|
| 597 |
+
@add_code_sample_docstrings(
|
| 598 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 599 |
+
output_type=BaseModelOutputWithCLSToken,
|
| 600 |
+
config_class=_CONFIG_FOR_DOC,
|
| 601 |
+
modality="vision",
|
| 602 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 603 |
+
)
|
| 604 |
+
def forward(
|
| 605 |
+
self,
|
| 606 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 607 |
+
output_hidden_states: Optional[bool] = None,
|
| 608 |
+
return_dict: Optional[bool] = None,
|
| 609 |
+
) -> Union[Tuple, BaseModelOutputWithCLSToken]:
|
| 610 |
+
output_hidden_states = (
|
| 611 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 612 |
+
)
|
| 613 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 614 |
+
|
| 615 |
+
if pixel_values is None:
|
| 616 |
+
raise ValueError("You have to specify pixel_values")
|
| 617 |
+
|
| 618 |
+
encoder_outputs = self.encoder(
|
| 619 |
+
pixel_values,
|
| 620 |
+
output_hidden_states=output_hidden_states,
|
| 621 |
+
return_dict=return_dict,
|
| 622 |
+
)
|
| 623 |
+
sequence_output = encoder_outputs[0]
|
| 624 |
+
|
| 625 |
+
if not return_dict:
|
| 626 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 627 |
+
|
| 628 |
+
return BaseModelOutputWithCLSToken(
|
| 629 |
+
last_hidden_state=sequence_output,
|
| 630 |
+
cls_token_value=encoder_outputs.cls_token_value,
|
| 631 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
@add_start_docstrings(
|
| 636 |
+
"""
|
| 637 |
+
Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
| 638 |
+
the [CLS] token) e.g. for ImageNet.
|
| 639 |
+
""",
|
| 640 |
+
CVT_START_DOCSTRING,
|
| 641 |
+
)
|
| 642 |
+
class CvtForImageClassification(CvtPreTrainedModel):
|
| 643 |
+
def __init__(self, config):
|
| 644 |
+
super().__init__(config)
|
| 645 |
+
|
| 646 |
+
self.num_labels = config.num_labels
|
| 647 |
+
self.cvt = CvtModel(config, add_pooling_layer=False)
|
| 648 |
+
self.layernorm = nn.LayerNorm(config.embed_dim[-1])
|
| 649 |
+
# Classifier head
|
| 650 |
+
self.classifier = (
|
| 651 |
+
nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# Initialize weights and apply final processing
|
| 655 |
+
self.post_init()
|
| 656 |
+
|
| 657 |
+
@add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
|
| 658 |
+
@add_code_sample_docstrings(
|
| 659 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 660 |
+
output_type=ImageClassifierOutputWithNoAttention,
|
| 661 |
+
config_class=_CONFIG_FOR_DOC,
|
| 662 |
+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
| 663 |
+
)
|
| 664 |
+
def forward(
|
| 665 |
+
self,
|
| 666 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 667 |
+
labels: Optional[torch.Tensor] = None,
|
| 668 |
+
output_hidden_states: Optional[bool] = None,
|
| 669 |
+
return_dict: Optional[bool] = None,
|
| 670 |
+
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
|
| 671 |
+
r"""
|
| 672 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 673 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 674 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 675 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 676 |
+
"""
|
| 677 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 678 |
+
outputs = self.cvt(
|
| 679 |
+
pixel_values,
|
| 680 |
+
output_hidden_states=output_hidden_states,
|
| 681 |
+
return_dict=return_dict,
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
sequence_output = outputs[0]
|
| 685 |
+
cls_token = outputs[1]
|
| 686 |
+
if self.config.cls_token[-1]:
|
| 687 |
+
sequence_output = self.layernorm(cls_token)
|
| 688 |
+
else:
|
| 689 |
+
batch_size, num_channels, height, width = sequence_output.shape
|
| 690 |
+
# rearrange "b c h w -> b (h w) c"
|
| 691 |
+
sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
| 692 |
+
sequence_output = self.layernorm(sequence_output)
|
| 693 |
+
|
| 694 |
+
sequence_output_mean = sequence_output.mean(dim=1)
|
| 695 |
+
logits = self.classifier(sequence_output_mean)
|
| 696 |
+
|
| 697 |
+
loss = None
|
| 698 |
+
if labels is not None:
|
| 699 |
+
if self.config.problem_type is None:
|
| 700 |
+
if self.config.num_labels == 1:
|
| 701 |
+
self.config.problem_type = "regression"
|
| 702 |
+
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 703 |
+
self.config.problem_type = "single_label_classification"
|
| 704 |
+
else:
|
| 705 |
+
self.config.problem_type = "multi_label_classification"
|
| 706 |
+
|
| 707 |
+
if self.config.problem_type == "regression":
|
| 708 |
+
loss_fct = MSELoss()
|
| 709 |
+
if self.config.num_labels == 1:
|
| 710 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 711 |
+
else:
|
| 712 |
+
loss = loss_fct(logits, labels)
|
| 713 |
+
elif self.config.problem_type == "single_label_classification":
|
| 714 |
+
loss_fct = CrossEntropyLoss()
|
| 715 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 716 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 717 |
+
loss_fct = BCEWithLogitsLoss()
|
| 718 |
+
loss = loss_fct(logits, labels)
|
| 719 |
+
|
| 720 |
+
if not return_dict:
|
| 721 |
+
output = (logits,) + outputs[2:]
|
| 722 |
+
return ((loss,) + output) if loss is not None else output
|
| 723 |
+
|
| 724 |
+
return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
__all__ = ["CvtForImageClassification", "CvtModel", "CvtPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/cvt/modeling_tf_cvt.py
ADDED
|
@@ -0,0 +1,1096 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TF 2.0 Cvt model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import collections.abc
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import tensorflow as tf
|
| 24 |
+
|
| 25 |
+
from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention
|
| 26 |
+
from ...modeling_tf_utils import (
|
| 27 |
+
TFModelInputType,
|
| 28 |
+
TFPreTrainedModel,
|
| 29 |
+
TFSequenceClassificationLoss,
|
| 30 |
+
get_initializer,
|
| 31 |
+
keras,
|
| 32 |
+
keras_serializable,
|
| 33 |
+
unpack_inputs,
|
| 34 |
+
)
|
| 35 |
+
from ...tf_utils import shape_list, stable_softmax
|
| 36 |
+
from ...utils import (
|
| 37 |
+
ModelOutput,
|
| 38 |
+
add_start_docstrings,
|
| 39 |
+
add_start_docstrings_to_model_forward,
|
| 40 |
+
logging,
|
| 41 |
+
replace_return_docstrings,
|
| 42 |
+
)
|
| 43 |
+
from .configuration_cvt import CvtConfig
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
# General docstring
|
| 49 |
+
_CONFIG_FOR_DOC = "CvtConfig"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class TFBaseModelOutputWithCLSToken(ModelOutput):
|
| 54 |
+
"""
|
| 55 |
+
Base class for model's outputs.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 59 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 60 |
+
cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`):
|
| 61 |
+
Classification token at the output of the last layer of the model.
|
| 62 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 63 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 64 |
+
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
|
| 65 |
+
the initial embedding outputs.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
last_hidden_state: Optional[tf.Tensor] = None
|
| 69 |
+
cls_token_value: Optional[tf.Tensor] = None
|
| 70 |
+
hidden_states: Tuple[tf.Tensor, ...] | None = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TFCvtDropPath(keras.layers.Layer):
|
| 74 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 75 |
+
References:
|
| 76 |
+
(1) github.com:rwightman/pytorch-image-models
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, drop_prob: float, **kwargs):
|
| 80 |
+
super().__init__(**kwargs)
|
| 81 |
+
self.drop_prob = drop_prob
|
| 82 |
+
|
| 83 |
+
def call(self, x: tf.Tensor, training=None):
|
| 84 |
+
if self.drop_prob == 0.0 or not training:
|
| 85 |
+
return x
|
| 86 |
+
keep_prob = 1 - self.drop_prob
|
| 87 |
+
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
|
| 88 |
+
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
|
| 89 |
+
random_tensor = tf.floor(random_tensor)
|
| 90 |
+
return (x / keep_prob) * random_tensor
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TFCvtEmbeddings(keras.layers.Layer):
|
| 94 |
+
"""Construct the Convolutional Token Embeddings."""
|
| 95 |
+
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
config: CvtConfig,
|
| 99 |
+
patch_size: int,
|
| 100 |
+
num_channels: int,
|
| 101 |
+
embed_dim: int,
|
| 102 |
+
stride: int,
|
| 103 |
+
padding: int,
|
| 104 |
+
dropout_rate: float,
|
| 105 |
+
**kwargs,
|
| 106 |
+
):
|
| 107 |
+
super().__init__(**kwargs)
|
| 108 |
+
self.convolution_embeddings = TFCvtConvEmbeddings(
|
| 109 |
+
config,
|
| 110 |
+
patch_size=patch_size,
|
| 111 |
+
num_channels=num_channels,
|
| 112 |
+
embed_dim=embed_dim,
|
| 113 |
+
stride=stride,
|
| 114 |
+
padding=padding,
|
| 115 |
+
name="convolution_embeddings",
|
| 116 |
+
)
|
| 117 |
+
self.dropout = keras.layers.Dropout(dropout_rate)
|
| 118 |
+
|
| 119 |
+
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 120 |
+
hidden_state = self.convolution_embeddings(pixel_values)
|
| 121 |
+
hidden_state = self.dropout(hidden_state, training=training)
|
| 122 |
+
return hidden_state
|
| 123 |
+
|
| 124 |
+
def build(self, input_shape=None):
|
| 125 |
+
if self.built:
|
| 126 |
+
return
|
| 127 |
+
self.built = True
|
| 128 |
+
if getattr(self, "convolution_embeddings", None) is not None:
|
| 129 |
+
with tf.name_scope(self.convolution_embeddings.name):
|
| 130 |
+
self.convolution_embeddings.build(None)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class TFCvtConvEmbeddings(keras.layers.Layer):
|
| 134 |
+
"""Image to Convolution Embeddings. This convolutional operation aims to model local spatial contexts."""
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
config: CvtConfig,
|
| 139 |
+
patch_size: int,
|
| 140 |
+
num_channels: int,
|
| 141 |
+
embed_dim: int,
|
| 142 |
+
stride: int,
|
| 143 |
+
padding: int,
|
| 144 |
+
**kwargs,
|
| 145 |
+
):
|
| 146 |
+
super().__init__(**kwargs)
|
| 147 |
+
self.padding = keras.layers.ZeroPadding2D(padding=padding)
|
| 148 |
+
self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 149 |
+
self.projection = keras.layers.Conv2D(
|
| 150 |
+
filters=embed_dim,
|
| 151 |
+
kernel_size=patch_size,
|
| 152 |
+
strides=stride,
|
| 153 |
+
padding="valid",
|
| 154 |
+
data_format="channels_last",
|
| 155 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 156 |
+
name="projection",
|
| 157 |
+
)
|
| 158 |
+
# Using the same default epsilon as PyTorch
|
| 159 |
+
self.normalization = keras.layers.LayerNormalization(epsilon=1e-5, name="normalization")
|
| 160 |
+
self.num_channels = num_channels
|
| 161 |
+
self.embed_dim = embed_dim
|
| 162 |
+
|
| 163 |
+
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
|
| 164 |
+
if isinstance(pixel_values, dict):
|
| 165 |
+
pixel_values = pixel_values["pixel_values"]
|
| 166 |
+
|
| 167 |
+
pixel_values = self.projection(self.padding(pixel_values))
|
| 168 |
+
|
| 169 |
+
# "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
|
| 170 |
+
batch_size, height, width, num_channels = shape_list(pixel_values)
|
| 171 |
+
hidden_size = height * width
|
| 172 |
+
pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels))
|
| 173 |
+
pixel_values = self.normalization(pixel_values)
|
| 174 |
+
|
| 175 |
+
# "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
|
| 176 |
+
pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels))
|
| 177 |
+
return pixel_values
|
| 178 |
+
|
| 179 |
+
def build(self, input_shape=None):
|
| 180 |
+
if self.built:
|
| 181 |
+
return
|
| 182 |
+
self.built = True
|
| 183 |
+
if getattr(self, "projection", None) is not None:
|
| 184 |
+
with tf.name_scope(self.projection.name):
|
| 185 |
+
self.projection.build([None, None, None, self.num_channels])
|
| 186 |
+
if getattr(self, "normalization", None) is not None:
|
| 187 |
+
with tf.name_scope(self.normalization.name):
|
| 188 |
+
self.normalization.build([None, None, self.embed_dim])
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class TFCvtSelfAttentionConvProjection(keras.layers.Layer):
|
| 192 |
+
"""Convolutional projection layer."""
|
| 193 |
+
|
| 194 |
+
def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs):
|
| 195 |
+
super().__init__(**kwargs)
|
| 196 |
+
self.padding = keras.layers.ZeroPadding2D(padding=padding)
|
| 197 |
+
self.convolution = keras.layers.Conv2D(
|
| 198 |
+
filters=embed_dim,
|
| 199 |
+
kernel_size=kernel_size,
|
| 200 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 201 |
+
padding="valid",
|
| 202 |
+
strides=stride,
|
| 203 |
+
use_bias=False,
|
| 204 |
+
name="convolution",
|
| 205 |
+
groups=embed_dim,
|
| 206 |
+
)
|
| 207 |
+
# Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
|
| 208 |
+
self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
|
| 209 |
+
self.embed_dim = embed_dim
|
| 210 |
+
|
| 211 |
+
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 212 |
+
hidden_state = self.convolution(self.padding(hidden_state))
|
| 213 |
+
hidden_state = self.normalization(hidden_state, training=training)
|
| 214 |
+
return hidden_state
|
| 215 |
+
|
| 216 |
+
def build(self, input_shape=None):
|
| 217 |
+
if self.built:
|
| 218 |
+
return
|
| 219 |
+
self.built = True
|
| 220 |
+
if getattr(self, "convolution", None) is not None:
|
| 221 |
+
with tf.name_scope(self.convolution.name):
|
| 222 |
+
self.convolution.build([None, None, None, self.embed_dim])
|
| 223 |
+
if getattr(self, "normalization", None) is not None:
|
| 224 |
+
with tf.name_scope(self.normalization.name):
|
| 225 |
+
self.normalization.build([None, None, None, self.embed_dim])
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class TFCvtSelfAttentionLinearProjection(keras.layers.Layer):
|
| 229 |
+
"""Linear projection layer used to flatten tokens into 1D."""
|
| 230 |
+
|
| 231 |
+
def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
|
| 232 |
+
# "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
|
| 233 |
+
batch_size, height, width, num_channels = shape_list(hidden_state)
|
| 234 |
+
hidden_size = height * width
|
| 235 |
+
hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))
|
| 236 |
+
return hidden_state
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class TFCvtSelfAttentionProjection(keras.layers.Layer):
|
| 240 |
+
"""Convolutional Projection for Attention."""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
config: CvtConfig,
|
| 245 |
+
embed_dim: int,
|
| 246 |
+
kernel_size: int,
|
| 247 |
+
stride: int,
|
| 248 |
+
padding: int,
|
| 249 |
+
projection_method: str = "dw_bn",
|
| 250 |
+
**kwargs,
|
| 251 |
+
):
|
| 252 |
+
super().__init__(**kwargs)
|
| 253 |
+
if projection_method == "dw_bn":
|
| 254 |
+
self.convolution_projection = TFCvtSelfAttentionConvProjection(
|
| 255 |
+
config, embed_dim, kernel_size, stride, padding, name="convolution_projection"
|
| 256 |
+
)
|
| 257 |
+
self.linear_projection = TFCvtSelfAttentionLinearProjection()
|
| 258 |
+
|
| 259 |
+
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 260 |
+
hidden_state = self.convolution_projection(hidden_state, training=training)
|
| 261 |
+
hidden_state = self.linear_projection(hidden_state)
|
| 262 |
+
return hidden_state
|
| 263 |
+
|
| 264 |
+
def build(self, input_shape=None):
|
| 265 |
+
if self.built:
|
| 266 |
+
return
|
| 267 |
+
self.built = True
|
| 268 |
+
if getattr(self, "convolution_projection", None) is not None:
|
| 269 |
+
with tf.name_scope(self.convolution_projection.name):
|
| 270 |
+
self.convolution_projection.build(None)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class TFCvtSelfAttention(keras.layers.Layer):
|
| 274 |
+
"""
|
| 275 |
+
Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection), is applied for
|
| 276 |
+
query, key, and value embeddings.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(
|
| 280 |
+
self,
|
| 281 |
+
config: CvtConfig,
|
| 282 |
+
num_heads: int,
|
| 283 |
+
embed_dim: int,
|
| 284 |
+
kernel_size: int,
|
| 285 |
+
stride_q: int,
|
| 286 |
+
stride_kv: int,
|
| 287 |
+
padding_q: int,
|
| 288 |
+
padding_kv: int,
|
| 289 |
+
qkv_projection_method: str,
|
| 290 |
+
qkv_bias: bool,
|
| 291 |
+
attention_drop_rate: float,
|
| 292 |
+
with_cls_token: bool = True,
|
| 293 |
+
**kwargs,
|
| 294 |
+
):
|
| 295 |
+
super().__init__(**kwargs)
|
| 296 |
+
self.scale = embed_dim**-0.5
|
| 297 |
+
self.with_cls_token = with_cls_token
|
| 298 |
+
self.embed_dim = embed_dim
|
| 299 |
+
self.num_heads = num_heads
|
| 300 |
+
|
| 301 |
+
self.convolution_projection_query = TFCvtSelfAttentionProjection(
|
| 302 |
+
config,
|
| 303 |
+
embed_dim,
|
| 304 |
+
kernel_size,
|
| 305 |
+
stride_q,
|
| 306 |
+
padding_q,
|
| 307 |
+
projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
|
| 308 |
+
name="convolution_projection_query",
|
| 309 |
+
)
|
| 310 |
+
self.convolution_projection_key = TFCvtSelfAttentionProjection(
|
| 311 |
+
config,
|
| 312 |
+
embed_dim,
|
| 313 |
+
kernel_size,
|
| 314 |
+
stride_kv,
|
| 315 |
+
padding_kv,
|
| 316 |
+
projection_method=qkv_projection_method,
|
| 317 |
+
name="convolution_projection_key",
|
| 318 |
+
)
|
| 319 |
+
self.convolution_projection_value = TFCvtSelfAttentionProjection(
|
| 320 |
+
config,
|
| 321 |
+
embed_dim,
|
| 322 |
+
kernel_size,
|
| 323 |
+
stride_kv,
|
| 324 |
+
padding_kv,
|
| 325 |
+
projection_method=qkv_projection_method,
|
| 326 |
+
name="convolution_projection_value",
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
self.projection_query = keras.layers.Dense(
|
| 330 |
+
units=embed_dim,
|
| 331 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 332 |
+
use_bias=qkv_bias,
|
| 333 |
+
bias_initializer="zeros",
|
| 334 |
+
name="projection_query",
|
| 335 |
+
)
|
| 336 |
+
self.projection_key = keras.layers.Dense(
|
| 337 |
+
units=embed_dim,
|
| 338 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 339 |
+
use_bias=qkv_bias,
|
| 340 |
+
bias_initializer="zeros",
|
| 341 |
+
name="projection_key",
|
| 342 |
+
)
|
| 343 |
+
self.projection_value = keras.layers.Dense(
|
| 344 |
+
units=embed_dim,
|
| 345 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 346 |
+
use_bias=qkv_bias,
|
| 347 |
+
bias_initializer="zeros",
|
| 348 |
+
name="projection_value",
|
| 349 |
+
)
|
| 350 |
+
self.dropout = keras.layers.Dropout(attention_drop_rate)
|
| 351 |
+
|
| 352 |
+
def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor:
|
| 353 |
+
batch_size, hidden_size, _ = shape_list(hidden_state)
|
| 354 |
+
head_dim = self.embed_dim // self.num_heads
|
| 355 |
+
hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim))
|
| 356 |
+
hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3))
|
| 357 |
+
return hidden_state
|
| 358 |
+
|
| 359 |
+
def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
|
| 360 |
+
if self.with_cls_token:
|
| 361 |
+
cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)
|
| 362 |
+
|
| 363 |
+
# "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
|
| 364 |
+
batch_size, hidden_size, num_channels = shape_list(hidden_state)
|
| 365 |
+
hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
|
| 366 |
+
|
| 367 |
+
key = self.convolution_projection_key(hidden_state, training=training)
|
| 368 |
+
query = self.convolution_projection_query(hidden_state, training=training)
|
| 369 |
+
value = self.convolution_projection_value(hidden_state, training=training)
|
| 370 |
+
|
| 371 |
+
if self.with_cls_token:
|
| 372 |
+
query = tf.concat((cls_token, query), axis=1)
|
| 373 |
+
key = tf.concat((cls_token, key), axis=1)
|
| 374 |
+
value = tf.concat((cls_token, value), axis=1)
|
| 375 |
+
|
| 376 |
+
head_dim = self.embed_dim // self.num_heads
|
| 377 |
+
|
| 378 |
+
query = self.rearrange_for_multi_head_attention(self.projection_query(query))
|
| 379 |
+
key = self.rearrange_for_multi_head_attention(self.projection_key(key))
|
| 380 |
+
value = self.rearrange_for_multi_head_attention(self.projection_value(value))
|
| 381 |
+
|
| 382 |
+
attention_score = tf.matmul(query, key, transpose_b=True) * self.scale
|
| 383 |
+
attention_probs = stable_softmax(logits=attention_score, axis=-1)
|
| 384 |
+
attention_probs = self.dropout(attention_probs, training=training)
|
| 385 |
+
|
| 386 |
+
context = tf.matmul(attention_probs, value)
|
| 387 |
+
# "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)"
|
| 388 |
+
_, _, hidden_size, _ = shape_list(context)
|
| 389 |
+
context = tf.transpose(context, perm=(0, 2, 1, 3))
|
| 390 |
+
context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim))
|
| 391 |
+
return context
|
| 392 |
+
|
| 393 |
+
def build(self, input_shape=None):
|
| 394 |
+
if self.built:
|
| 395 |
+
return
|
| 396 |
+
self.built = True
|
| 397 |
+
if getattr(self, "convolution_projection_query", None) is not None:
|
| 398 |
+
with tf.name_scope(self.convolution_projection_query.name):
|
| 399 |
+
self.convolution_projection_query.build(None)
|
| 400 |
+
if getattr(self, "convolution_projection_key", None) is not None:
|
| 401 |
+
with tf.name_scope(self.convolution_projection_key.name):
|
| 402 |
+
self.convolution_projection_key.build(None)
|
| 403 |
+
if getattr(self, "convolution_projection_value", None) is not None:
|
| 404 |
+
with tf.name_scope(self.convolution_projection_value.name):
|
| 405 |
+
self.convolution_projection_value.build(None)
|
| 406 |
+
if getattr(self, "projection_query", None) is not None:
|
| 407 |
+
with tf.name_scope(self.projection_query.name):
|
| 408 |
+
self.projection_query.build([None, None, self.embed_dim])
|
| 409 |
+
if getattr(self, "projection_key", None) is not None:
|
| 410 |
+
with tf.name_scope(self.projection_key.name):
|
| 411 |
+
self.projection_key.build([None, None, self.embed_dim])
|
| 412 |
+
if getattr(self, "projection_value", None) is not None:
|
| 413 |
+
with tf.name_scope(self.projection_value.name):
|
| 414 |
+
self.projection_value.build([None, None, self.embed_dim])
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class TFCvtSelfOutput(keras.layers.Layer):
|
| 418 |
+
"""Output of the Attention layer ."""
|
| 419 |
+
|
| 420 |
+
def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs):
|
| 421 |
+
super().__init__(**kwargs)
|
| 422 |
+
self.dense = keras.layers.Dense(
|
| 423 |
+
units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 424 |
+
)
|
| 425 |
+
self.dropout = keras.layers.Dropout(drop_rate)
|
| 426 |
+
self.embed_dim = embed_dim
|
| 427 |
+
|
| 428 |
+
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 429 |
+
hidden_state = self.dense(inputs=hidden_state)
|
| 430 |
+
hidden_state = self.dropout(inputs=hidden_state, training=training)
|
| 431 |
+
return hidden_state
|
| 432 |
+
|
| 433 |
+
def build(self, input_shape=None):
|
| 434 |
+
if self.built:
|
| 435 |
+
return
|
| 436 |
+
self.built = True
|
| 437 |
+
if getattr(self, "dense", None) is not None:
|
| 438 |
+
with tf.name_scope(self.dense.name):
|
| 439 |
+
self.dense.build([None, None, self.embed_dim])
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class TFCvtAttention(keras.layers.Layer):
|
| 443 |
+
"""Attention layer. First chunk of the convolutional transformer block."""
|
| 444 |
+
|
| 445 |
+
def __init__(
|
| 446 |
+
self,
|
| 447 |
+
config: CvtConfig,
|
| 448 |
+
num_heads: int,
|
| 449 |
+
embed_dim: int,
|
| 450 |
+
kernel_size: int,
|
| 451 |
+
stride_q: int,
|
| 452 |
+
stride_kv: int,
|
| 453 |
+
padding_q: int,
|
| 454 |
+
padding_kv: int,
|
| 455 |
+
qkv_projection_method: str,
|
| 456 |
+
qkv_bias: bool,
|
| 457 |
+
attention_drop_rate: float,
|
| 458 |
+
drop_rate: float,
|
| 459 |
+
with_cls_token: bool = True,
|
| 460 |
+
**kwargs,
|
| 461 |
+
):
|
| 462 |
+
super().__init__(**kwargs)
|
| 463 |
+
self.attention = TFCvtSelfAttention(
|
| 464 |
+
config,
|
| 465 |
+
num_heads,
|
| 466 |
+
embed_dim,
|
| 467 |
+
kernel_size,
|
| 468 |
+
stride_q,
|
| 469 |
+
stride_kv,
|
| 470 |
+
padding_q,
|
| 471 |
+
padding_kv,
|
| 472 |
+
qkv_projection_method,
|
| 473 |
+
qkv_bias,
|
| 474 |
+
attention_drop_rate,
|
| 475 |
+
with_cls_token,
|
| 476 |
+
name="attention",
|
| 477 |
+
)
|
| 478 |
+
self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output")
|
| 479 |
+
|
| 480 |
+
def prune_heads(self, heads):
|
| 481 |
+
raise NotImplementedError
|
| 482 |
+
|
| 483 |
+
def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False):
|
| 484 |
+
self_output = self.attention(hidden_state, height, width, training=training)
|
| 485 |
+
attention_output = self.dense_output(self_output, training=training)
|
| 486 |
+
return attention_output
|
| 487 |
+
|
| 488 |
+
def build(self, input_shape=None):
|
| 489 |
+
if self.built:
|
| 490 |
+
return
|
| 491 |
+
self.built = True
|
| 492 |
+
if getattr(self, "attention", None) is not None:
|
| 493 |
+
with tf.name_scope(self.attention.name):
|
| 494 |
+
self.attention.build(None)
|
| 495 |
+
if getattr(self, "dense_output", None) is not None:
|
| 496 |
+
with tf.name_scope(self.dense_output.name):
|
| 497 |
+
self.dense_output.build(None)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class TFCvtIntermediate(keras.layers.Layer):
|
| 501 |
+
"""Intermediate dense layer. Second chunk of the convolutional transformer block."""
|
| 502 |
+
|
| 503 |
+
def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs):
|
| 504 |
+
super().__init__(**kwargs)
|
| 505 |
+
self.dense = keras.layers.Dense(
|
| 506 |
+
units=int(embed_dim * mlp_ratio),
|
| 507 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 508 |
+
activation="gelu",
|
| 509 |
+
name="dense",
|
| 510 |
+
)
|
| 511 |
+
self.embed_dim = embed_dim
|
| 512 |
+
|
| 513 |
+
def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
|
| 514 |
+
hidden_state = self.dense(hidden_state)
|
| 515 |
+
return hidden_state
|
| 516 |
+
|
| 517 |
+
def build(self, input_shape=None):
|
| 518 |
+
if self.built:
|
| 519 |
+
return
|
| 520 |
+
self.built = True
|
| 521 |
+
if getattr(self, "dense", None) is not None:
|
| 522 |
+
with tf.name_scope(self.dense.name):
|
| 523 |
+
self.dense.build([None, None, self.embed_dim])
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class TFCvtOutput(keras.layers.Layer):
|
| 527 |
+
"""
|
| 528 |
+
Output of the Convolutional Transformer Block (last chunk). It consists of a MLP and a residual connection.
|
| 529 |
+
"""
|
| 530 |
+
|
| 531 |
+
def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, drop_rate: int, **kwargs):
|
| 532 |
+
super().__init__(**kwargs)
|
| 533 |
+
self.dense = keras.layers.Dense(
|
| 534 |
+
units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 535 |
+
)
|
| 536 |
+
self.dropout = keras.layers.Dropout(drop_rate)
|
| 537 |
+
self.embed_dim = embed_dim
|
| 538 |
+
self.mlp_ratio = mlp_ratio
|
| 539 |
+
|
| 540 |
+
def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 541 |
+
hidden_state = self.dense(inputs=hidden_state)
|
| 542 |
+
hidden_state = self.dropout(inputs=hidden_state, training=training)
|
| 543 |
+
hidden_state = hidden_state + input_tensor
|
| 544 |
+
return hidden_state
|
| 545 |
+
|
| 546 |
+
def build(self, input_shape=None):
|
| 547 |
+
if self.built:
|
| 548 |
+
return
|
| 549 |
+
self.built = True
|
| 550 |
+
if getattr(self, "dense", None) is not None:
|
| 551 |
+
with tf.name_scope(self.dense.name):
|
| 552 |
+
self.dense.build([None, None, int(self.embed_dim * self.mlp_ratio)])
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
class TFCvtLayer(keras.layers.Layer):
|
| 556 |
+
"""
|
| 557 |
+
Convolutional Transformer Block composed by attention layers, normalization and multi-layer perceptrons (mlps). It
|
| 558 |
+
consists of 3 chunks : an attention layer, an intermediate dense layer and an output layer. This corresponds to the
|
| 559 |
+
`Block` class in the original implementation.
|
| 560 |
+
"""
|
| 561 |
+
|
| 562 |
+
def __init__(
|
| 563 |
+
self,
|
| 564 |
+
config: CvtConfig,
|
| 565 |
+
num_heads: int,
|
| 566 |
+
embed_dim: int,
|
| 567 |
+
kernel_size: int,
|
| 568 |
+
stride_q: int,
|
| 569 |
+
stride_kv: int,
|
| 570 |
+
padding_q: int,
|
| 571 |
+
padding_kv: int,
|
| 572 |
+
qkv_projection_method: str,
|
| 573 |
+
qkv_bias: bool,
|
| 574 |
+
attention_drop_rate: float,
|
| 575 |
+
drop_rate: float,
|
| 576 |
+
mlp_ratio: float,
|
| 577 |
+
drop_path_rate: float,
|
| 578 |
+
with_cls_token: bool = True,
|
| 579 |
+
**kwargs,
|
| 580 |
+
):
|
| 581 |
+
super().__init__(**kwargs)
|
| 582 |
+
self.attention = TFCvtAttention(
|
| 583 |
+
config,
|
| 584 |
+
num_heads,
|
| 585 |
+
embed_dim,
|
| 586 |
+
kernel_size,
|
| 587 |
+
stride_q,
|
| 588 |
+
stride_kv,
|
| 589 |
+
padding_q,
|
| 590 |
+
padding_kv,
|
| 591 |
+
qkv_projection_method,
|
| 592 |
+
qkv_bias,
|
| 593 |
+
attention_drop_rate,
|
| 594 |
+
drop_rate,
|
| 595 |
+
with_cls_token,
|
| 596 |
+
name="attention",
|
| 597 |
+
)
|
| 598 |
+
self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate")
|
| 599 |
+
self.dense_output = TFCvtOutput(config, embed_dim, mlp_ratio, drop_rate, name="output")
|
| 600 |
+
# Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour.
|
| 601 |
+
self.drop_path = (
|
| 602 |
+
TFCvtDropPath(drop_path_rate, name="drop_path")
|
| 603 |
+
if drop_path_rate > 0.0
|
| 604 |
+
else keras.layers.Activation("linear", name="drop_path")
|
| 605 |
+
)
|
| 606 |
+
# Using the same default epsilon as PyTorch
|
| 607 |
+
self.layernorm_before = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before")
|
| 608 |
+
self.layernorm_after = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after")
|
| 609 |
+
self.embed_dim = embed_dim
|
| 610 |
+
|
| 611 |
+
def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
|
| 612 |
+
# in Cvt, layernorm is applied before self-attention
|
| 613 |
+
attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training)
|
| 614 |
+
attention_output = self.drop_path(attention_output, training=training)
|
| 615 |
+
|
| 616 |
+
# first residual connection
|
| 617 |
+
hidden_state = attention_output + hidden_state
|
| 618 |
+
|
| 619 |
+
# in Cvt, layernorm is also applied after self-attention
|
| 620 |
+
layer_output = self.layernorm_after(hidden_state)
|
| 621 |
+
layer_output = self.intermediate(layer_output)
|
| 622 |
+
|
| 623 |
+
# second residual connection is done here
|
| 624 |
+
layer_output = self.dense_output(layer_output, hidden_state)
|
| 625 |
+
layer_output = self.drop_path(layer_output, training=training)
|
| 626 |
+
return layer_output
|
| 627 |
+
|
| 628 |
+
def build(self, input_shape=None):
|
| 629 |
+
if self.built:
|
| 630 |
+
return
|
| 631 |
+
self.built = True
|
| 632 |
+
if getattr(self, "attention", None) is not None:
|
| 633 |
+
with tf.name_scope(self.attention.name):
|
| 634 |
+
self.attention.build(None)
|
| 635 |
+
if getattr(self, "intermediate", None) is not None:
|
| 636 |
+
with tf.name_scope(self.intermediate.name):
|
| 637 |
+
self.intermediate.build(None)
|
| 638 |
+
if getattr(self, "dense_output", None) is not None:
|
| 639 |
+
with tf.name_scope(self.dense_output.name):
|
| 640 |
+
self.dense_output.build(None)
|
| 641 |
+
if getattr(self, "drop_path", None) is not None:
|
| 642 |
+
with tf.name_scope(self.drop_path.name):
|
| 643 |
+
self.drop_path.build(None)
|
| 644 |
+
if getattr(self, "layernorm_before", None) is not None:
|
| 645 |
+
with tf.name_scope(self.layernorm_before.name):
|
| 646 |
+
self.layernorm_before.build([None, None, self.embed_dim])
|
| 647 |
+
if getattr(self, "layernorm_after", None) is not None:
|
| 648 |
+
with tf.name_scope(self.layernorm_after.name):
|
| 649 |
+
self.layernorm_after.build([None, None, self.embed_dim])
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class TFCvtStage(keras.layers.Layer):
|
| 653 |
+
"""
|
| 654 |
+
Cvt stage (encoder block). Each stage has 2 parts :
|
| 655 |
+
- (1) A Convolutional Token Embedding layer
|
| 656 |
+
- (2) A Convolutional Transformer Block (layer).
|
| 657 |
+
The classification token is added only in the last stage.
|
| 658 |
+
|
| 659 |
+
Args:
|
| 660 |
+
config ([`CvtConfig`]): Model configuration class.
|
| 661 |
+
stage (`int`): Stage number.
|
| 662 |
+
"""
|
| 663 |
+
|
| 664 |
+
def __init__(self, config: CvtConfig, stage: int, **kwargs):
|
| 665 |
+
super().__init__(**kwargs)
|
| 666 |
+
self.config = config
|
| 667 |
+
self.stage = stage
|
| 668 |
+
if self.config.cls_token[self.stage]:
|
| 669 |
+
self.cls_token = self.add_weight(
|
| 670 |
+
shape=(1, 1, self.config.embed_dim[-1]),
|
| 671 |
+
initializer=get_initializer(self.config.initializer_range),
|
| 672 |
+
trainable=True,
|
| 673 |
+
name="cvt.encoder.stages.2.cls_token",
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
self.embedding = TFCvtEmbeddings(
|
| 677 |
+
self.config,
|
| 678 |
+
patch_size=config.patch_sizes[self.stage],
|
| 679 |
+
num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
|
| 680 |
+
stride=config.patch_stride[self.stage],
|
| 681 |
+
embed_dim=config.embed_dim[self.stage],
|
| 682 |
+
padding=config.patch_padding[self.stage],
|
| 683 |
+
dropout_rate=config.drop_rate[self.stage],
|
| 684 |
+
name="embedding",
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage])
|
| 688 |
+
drop_path_rates = [x.numpy().item() for x in drop_path_rates]
|
| 689 |
+
self.layers = [
|
| 690 |
+
TFCvtLayer(
|
| 691 |
+
config,
|
| 692 |
+
num_heads=config.num_heads[self.stage],
|
| 693 |
+
embed_dim=config.embed_dim[self.stage],
|
| 694 |
+
kernel_size=config.kernel_qkv[self.stage],
|
| 695 |
+
stride_q=config.stride_q[self.stage],
|
| 696 |
+
stride_kv=config.stride_kv[self.stage],
|
| 697 |
+
padding_q=config.padding_q[self.stage],
|
| 698 |
+
padding_kv=config.padding_kv[self.stage],
|
| 699 |
+
qkv_projection_method=config.qkv_projection_method[self.stage],
|
| 700 |
+
qkv_bias=config.qkv_bias[self.stage],
|
| 701 |
+
attention_drop_rate=config.attention_drop_rate[self.stage],
|
| 702 |
+
drop_rate=config.drop_rate[self.stage],
|
| 703 |
+
mlp_ratio=config.mlp_ratio[self.stage],
|
| 704 |
+
drop_path_rate=drop_path_rates[self.stage],
|
| 705 |
+
with_cls_token=config.cls_token[self.stage],
|
| 706 |
+
name=f"layers.{j}",
|
| 707 |
+
)
|
| 708 |
+
for j in range(config.depth[self.stage])
|
| 709 |
+
]
|
| 710 |
+
|
| 711 |
+
def call(self, hidden_state: tf.Tensor, training: bool = False):
|
| 712 |
+
cls_token = None
|
| 713 |
+
hidden_state = self.embedding(hidden_state, training)
|
| 714 |
+
|
| 715 |
+
# "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
|
| 716 |
+
batch_size, height, width, num_channels = shape_list(hidden_state)
|
| 717 |
+
hidden_size = height * width
|
| 718 |
+
hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))
|
| 719 |
+
|
| 720 |
+
if self.config.cls_token[self.stage]:
|
| 721 |
+
cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
|
| 722 |
+
hidden_state = tf.concat((cls_token, hidden_state), axis=1)
|
| 723 |
+
|
| 724 |
+
for layer in self.layers:
|
| 725 |
+
layer_outputs = layer(hidden_state, height, width, training=training)
|
| 726 |
+
hidden_state = layer_outputs
|
| 727 |
+
|
| 728 |
+
if self.config.cls_token[self.stage]:
|
| 729 |
+
cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)
|
| 730 |
+
|
| 731 |
+
# "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
|
| 732 |
+
hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
|
| 733 |
+
return hidden_state, cls_token
|
| 734 |
+
|
| 735 |
+
def build(self, input_shape=None):
|
| 736 |
+
if self.built:
|
| 737 |
+
return
|
| 738 |
+
self.built = True
|
| 739 |
+
if getattr(self, "embedding", None) is not None:
|
| 740 |
+
with tf.name_scope(self.embedding.name):
|
| 741 |
+
self.embedding.build(None)
|
| 742 |
+
if getattr(self, "layers", None) is not None:
|
| 743 |
+
for layer in self.layers:
|
| 744 |
+
with tf.name_scope(layer.name):
|
| 745 |
+
layer.build(None)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
class TFCvtEncoder(keras.layers.Layer):
|
| 749 |
+
"""
|
| 750 |
+
Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers
|
| 751 |
+
(depth) being 1, 2 and 10.
|
| 752 |
+
|
| 753 |
+
Args:
|
| 754 |
+
config ([`CvtConfig`]): Model configuration class.
|
| 755 |
+
"""
|
| 756 |
+
|
| 757 |
+
config_class = CvtConfig
|
| 758 |
+
|
| 759 |
+
def __init__(self, config: CvtConfig, **kwargs):
|
| 760 |
+
super().__init__(**kwargs)
|
| 761 |
+
self.config = config
|
| 762 |
+
self.stages = [
|
| 763 |
+
TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth))
|
| 764 |
+
]
|
| 765 |
+
|
| 766 |
+
def call(
|
| 767 |
+
self,
|
| 768 |
+
pixel_values: TFModelInputType,
|
| 769 |
+
output_hidden_states: Optional[bool] = False,
|
| 770 |
+
return_dict: Optional[bool] = True,
|
| 771 |
+
training: Optional[bool] = False,
|
| 772 |
+
) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
|
| 773 |
+
all_hidden_states = () if output_hidden_states else None
|
| 774 |
+
hidden_state = pixel_values
|
| 775 |
+
# When running on CPU, `keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
|
| 776 |
+
# as input format. So change the input format to (batch_size, height, width, num_channels).
|
| 777 |
+
hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1))
|
| 778 |
+
|
| 779 |
+
cls_token = None
|
| 780 |
+
for _, (stage_module) in enumerate(self.stages):
|
| 781 |
+
hidden_state, cls_token = stage_module(hidden_state, training=training)
|
| 782 |
+
if output_hidden_states:
|
| 783 |
+
all_hidden_states = all_hidden_states + (hidden_state,)
|
| 784 |
+
|
| 785 |
+
# Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules
|
| 786 |
+
hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2))
|
| 787 |
+
if output_hidden_states:
|
| 788 |
+
all_hidden_states = tuple([tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states])
|
| 789 |
+
|
| 790 |
+
if not return_dict:
|
| 791 |
+
return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
|
| 792 |
+
|
| 793 |
+
return TFBaseModelOutputWithCLSToken(
|
| 794 |
+
last_hidden_state=hidden_state,
|
| 795 |
+
cls_token_value=cls_token,
|
| 796 |
+
hidden_states=all_hidden_states,
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
def build(self, input_shape=None):
|
| 800 |
+
if self.built:
|
| 801 |
+
return
|
| 802 |
+
self.built = True
|
| 803 |
+
if getattr(self, "stages", None) is not None:
|
| 804 |
+
for layer in self.stages:
|
| 805 |
+
with tf.name_scope(layer.name):
|
| 806 |
+
layer.build(None)
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
@keras_serializable
|
| 810 |
+
class TFCvtMainLayer(keras.layers.Layer):
|
| 811 |
+
"""Construct the Cvt model."""
|
| 812 |
+
|
| 813 |
+
config_class = CvtConfig
|
| 814 |
+
|
| 815 |
+
def __init__(self, config: CvtConfig, **kwargs):
|
| 816 |
+
super().__init__(**kwargs)
|
| 817 |
+
self.config = config
|
| 818 |
+
self.encoder = TFCvtEncoder(config, name="encoder")
|
| 819 |
+
|
| 820 |
+
@unpack_inputs
|
| 821 |
+
def call(
|
| 822 |
+
self,
|
| 823 |
+
pixel_values: TFModelInputType | None = None,
|
| 824 |
+
output_hidden_states: Optional[bool] = None,
|
| 825 |
+
return_dict: Optional[bool] = None,
|
| 826 |
+
training: Optional[bool] = False,
|
| 827 |
+
) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
|
| 828 |
+
if pixel_values is None:
|
| 829 |
+
raise ValueError("You have to specify pixel_values")
|
| 830 |
+
|
| 831 |
+
encoder_outputs = self.encoder(
|
| 832 |
+
pixel_values,
|
| 833 |
+
output_hidden_states=output_hidden_states,
|
| 834 |
+
return_dict=return_dict,
|
| 835 |
+
training=training,
|
| 836 |
+
)
|
| 837 |
+
|
| 838 |
+
sequence_output = encoder_outputs[0]
|
| 839 |
+
|
| 840 |
+
if not return_dict:
|
| 841 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 842 |
+
|
| 843 |
+
return TFBaseModelOutputWithCLSToken(
|
| 844 |
+
last_hidden_state=sequence_output,
|
| 845 |
+
cls_token_value=encoder_outputs.cls_token_value,
|
| 846 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
def build(self, input_shape=None):
|
| 850 |
+
if self.built:
|
| 851 |
+
return
|
| 852 |
+
self.built = True
|
| 853 |
+
if getattr(self, "encoder", None) is not None:
|
| 854 |
+
with tf.name_scope(self.encoder.name):
|
| 855 |
+
self.encoder.build(None)
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
class TFCvtPreTrainedModel(TFPreTrainedModel):
|
| 859 |
+
"""
|
| 860 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 861 |
+
models.
|
| 862 |
+
"""
|
| 863 |
+
|
| 864 |
+
config_class = CvtConfig
|
| 865 |
+
base_model_prefix = "cvt"
|
| 866 |
+
main_input_name = "pixel_values"
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
TFCVT_START_DOCSTRING = r"""
|
| 870 |
+
|
| 871 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 872 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 873 |
+
etc.)
|
| 874 |
+
|
| 875 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 876 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 877 |
+
behavior.
|
| 878 |
+
|
| 879 |
+
<Tip>
|
| 880 |
+
|
| 881 |
+
TF 2.0 models accepts two formats as inputs:
|
| 882 |
+
|
| 883 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 884 |
+
- having all inputs as a list, tuple or dict in the first positional arguments.
|
| 885 |
+
|
| 886 |
+
This second option is useful when using [`keras.Model.fit`] method which currently requires having all the
|
| 887 |
+
tensors in the first argument of the model call function: `model(inputs)`.
|
| 888 |
+
|
| 889 |
+
</Tip>
|
| 890 |
+
|
| 891 |
+
Args:
|
| 892 |
+
config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
|
| 893 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 894 |
+
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 895 |
+
"""
|
| 896 |
+
|
| 897 |
+
TFCVT_INPUTS_DOCSTRING = r"""
|
| 898 |
+
Args:
|
| 899 |
+
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
|
| 900 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
|
| 901 |
+
for details.
|
| 902 |
+
|
| 903 |
+
output_hidden_states (`bool`, *optional*):
|
| 904 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 905 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 906 |
+
used instead.
|
| 907 |
+
return_dict (`bool`, *optional*):
|
| 908 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
|
| 909 |
+
eager mode, in graph mode the value will always be set to True.
|
| 910 |
+
training (`bool`, *optional*, defaults to `False``):
|
| 911 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 912 |
+
behaviors between training and evaluation).
|
| 913 |
+
"""
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
@add_start_docstrings(
|
| 917 |
+
"The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
|
| 918 |
+
TFCVT_START_DOCSTRING,
|
| 919 |
+
)
|
| 920 |
+
class TFCvtModel(TFCvtPreTrainedModel):
|
| 921 |
+
def __init__(self, config: CvtConfig, *inputs, **kwargs):
|
| 922 |
+
super().__init__(config, *inputs, **kwargs)
|
| 923 |
+
|
| 924 |
+
self.cvt = TFCvtMainLayer(config, name="cvt")
|
| 925 |
+
|
| 926 |
+
@unpack_inputs
|
| 927 |
+
@add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
|
| 928 |
+
@replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC)
|
| 929 |
+
def call(
|
| 930 |
+
self,
|
| 931 |
+
pixel_values: tf.Tensor | None = None,
|
| 932 |
+
output_hidden_states: Optional[bool] = None,
|
| 933 |
+
return_dict: Optional[bool] = None,
|
| 934 |
+
training: Optional[bool] = False,
|
| 935 |
+
) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
|
| 936 |
+
r"""
|
| 937 |
+
Returns:
|
| 938 |
+
|
| 939 |
+
Examples:
|
| 940 |
+
|
| 941 |
+
```python
|
| 942 |
+
>>> from transformers import AutoImageProcessor, TFCvtModel
|
| 943 |
+
>>> from PIL import Image
|
| 944 |
+
>>> import requests
|
| 945 |
+
|
| 946 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 947 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 948 |
+
|
| 949 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
|
| 950 |
+
>>> model = TFCvtModel.from_pretrained("microsoft/cvt-13")
|
| 951 |
+
|
| 952 |
+
>>> inputs = image_processor(images=image, return_tensors="tf")
|
| 953 |
+
>>> outputs = model(**inputs)
|
| 954 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 955 |
+
```"""
|
| 956 |
+
|
| 957 |
+
if pixel_values is None:
|
| 958 |
+
raise ValueError("You have to specify pixel_values")
|
| 959 |
+
|
| 960 |
+
outputs = self.cvt(
|
| 961 |
+
pixel_values=pixel_values,
|
| 962 |
+
output_hidden_states=output_hidden_states,
|
| 963 |
+
return_dict=return_dict,
|
| 964 |
+
training=training,
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
if not return_dict:
|
| 968 |
+
return (outputs[0],) + outputs[1:]
|
| 969 |
+
|
| 970 |
+
return TFBaseModelOutputWithCLSToken(
|
| 971 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 972 |
+
cls_token_value=outputs.cls_token_value,
|
| 973 |
+
hidden_states=outputs.hidden_states,
|
| 974 |
+
)
|
| 975 |
+
|
| 976 |
+
def build(self, input_shape=None):
|
| 977 |
+
if self.built:
|
| 978 |
+
return
|
| 979 |
+
self.built = True
|
| 980 |
+
if getattr(self, "cvt", None) is not None:
|
| 981 |
+
with tf.name_scope(self.cvt.name):
|
| 982 |
+
self.cvt.build(None)
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
@add_start_docstrings(
|
| 986 |
+
"""
|
| 987 |
+
Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
| 988 |
+
the [CLS] token) e.g. for ImageNet.
|
| 989 |
+
""",
|
| 990 |
+
TFCVT_START_DOCSTRING,
|
| 991 |
+
)
|
| 992 |
+
class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss):
|
| 993 |
+
def __init__(self, config: CvtConfig, *inputs, **kwargs):
|
| 994 |
+
super().__init__(config, *inputs, **kwargs)
|
| 995 |
+
|
| 996 |
+
self.num_labels = config.num_labels
|
| 997 |
+
self.cvt = TFCvtMainLayer(config, name="cvt")
|
| 998 |
+
# Using same default epsilon as in the original implementation.
|
| 999 |
+
self.layernorm = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")
|
| 1000 |
+
|
| 1001 |
+
# Classifier head
|
| 1002 |
+
self.classifier = keras.layers.Dense(
|
| 1003 |
+
units=config.num_labels,
|
| 1004 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1005 |
+
use_bias=True,
|
| 1006 |
+
bias_initializer="zeros",
|
| 1007 |
+
name="classifier",
|
| 1008 |
+
)
|
| 1009 |
+
self.config = config
|
| 1010 |
+
|
| 1011 |
+
@unpack_inputs
|
| 1012 |
+
@add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
|
| 1013 |
+
@replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
|
| 1014 |
+
def call(
|
| 1015 |
+
self,
|
| 1016 |
+
pixel_values: tf.Tensor | None = None,
|
| 1017 |
+
labels: tf.Tensor | None = None,
|
| 1018 |
+
output_hidden_states: Optional[bool] = None,
|
| 1019 |
+
return_dict: Optional[bool] = None,
|
| 1020 |
+
training: Optional[bool] = False,
|
| 1021 |
+
) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]:
|
| 1022 |
+
r"""
|
| 1023 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1024 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 1025 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1026 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1027 |
+
|
| 1028 |
+
Returns:
|
| 1029 |
+
|
| 1030 |
+
Examples:
|
| 1031 |
+
|
| 1032 |
+
```python
|
| 1033 |
+
>>> from transformers import AutoImageProcessor, TFCvtForImageClassification
|
| 1034 |
+
>>> import tensorflow as tf
|
| 1035 |
+
>>> from PIL import Image
|
| 1036 |
+
>>> import requests
|
| 1037 |
+
|
| 1038 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1039 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1040 |
+
|
| 1041 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
|
| 1042 |
+
>>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13")
|
| 1043 |
+
|
| 1044 |
+
>>> inputs = image_processor(images=image, return_tensors="tf")
|
| 1045 |
+
>>> outputs = model(**inputs)
|
| 1046 |
+
>>> logits = outputs.logits
|
| 1047 |
+
>>> # model predicts one of the 1000 ImageNet classes
|
| 1048 |
+
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
|
| 1049 |
+
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
|
| 1050 |
+
```"""
|
| 1051 |
+
|
| 1052 |
+
outputs = self.cvt(
|
| 1053 |
+
pixel_values,
|
| 1054 |
+
output_hidden_states=output_hidden_states,
|
| 1055 |
+
return_dict=return_dict,
|
| 1056 |
+
training=training,
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
sequence_output = outputs[0]
|
| 1060 |
+
cls_token = outputs[1]
|
| 1061 |
+
if self.config.cls_token[-1]:
|
| 1062 |
+
sequence_output = self.layernorm(cls_token)
|
| 1063 |
+
else:
|
| 1064 |
+
# rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels"
|
| 1065 |
+
batch_size, num_channels, height, width = shape_list(sequence_output)
|
| 1066 |
+
sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width))
|
| 1067 |
+
sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1))
|
| 1068 |
+
sequence_output = self.layernorm(sequence_output)
|
| 1069 |
+
|
| 1070 |
+
sequence_output_mean = tf.reduce_mean(sequence_output, axis=1)
|
| 1071 |
+
logits = self.classifier(sequence_output_mean)
|
| 1072 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1073 |
+
|
| 1074 |
+
if not return_dict:
|
| 1075 |
+
output = (logits,) + outputs[2:]
|
| 1076 |
+
return ((loss,) + output) if loss is not None else output
|
| 1077 |
+
|
| 1078 |
+
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
|
| 1079 |
+
|
| 1080 |
+
def build(self, input_shape=None):
|
| 1081 |
+
if self.built:
|
| 1082 |
+
return
|
| 1083 |
+
self.built = True
|
| 1084 |
+
if getattr(self, "cvt", None) is not None:
|
| 1085 |
+
with tf.name_scope(self.cvt.name):
|
| 1086 |
+
self.cvt.build(None)
|
| 1087 |
+
if getattr(self, "layernorm", None) is not None:
|
| 1088 |
+
with tf.name_scope(self.layernorm.name):
|
| 1089 |
+
self.layernorm.build([None, None, self.config.embed_dim[-1]])
|
| 1090 |
+
if getattr(self, "classifier", None) is not None:
|
| 1091 |
+
if hasattr(self.classifier, "name"):
|
| 1092 |
+
with tf.name_scope(self.classifier.name):
|
| 1093 |
+
self.classifier.build([None, None, self.config.embed_dim[-1]])
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
__all__ = ["TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/dab_detr/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import TYPE_CHECKING
|
| 16 |
+
|
| 17 |
+
from ...utils import _LazyModule
|
| 18 |
+
from ...utils.import_utils import define_import_structure
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from .configuration_dab_detr import *
|
| 23 |
+
from .modeling_dab_detr import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/dab_detr/configuration_dab_detr.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DAB-DETR model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
from ...utils.backbone_utils import verify_backbone_config_arguments
|
| 20 |
+
from ..auto import CONFIG_MAPPING
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DabDetrConfig(PretrainedConfig):
|
| 27 |
+
r"""
|
| 28 |
+
This is the configuration class to store the configuration of a [`DabDetrModel`]. It is used to instantiate
|
| 29 |
+
a DAB-DETR model according to the specified arguments, defining the model architecture. Instantiating a
|
| 30 |
+
configuration with the defaults will yield a similar configuration to that of the DAB-DETR
|
| 31 |
+
[IDEA-Research/dab_detr-base](https://huggingface.co/IDEA-Research/dab_detr-base) architecture.
|
| 32 |
+
|
| 33 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 34 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
use_timm_backbone (`bool`, *optional*, defaults to `True`):
|
| 38 |
+
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
|
| 39 |
+
API.
|
| 40 |
+
backbone_config (`PretrainedConfig` or `dict`, *optional*):
|
| 41 |
+
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
|
| 42 |
+
case it will default to `ResNetConfig()`.
|
| 43 |
+
backbone (`str`, *optional*, defaults to `"resnet50"`):
|
| 44 |
+
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
| 45 |
+
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
| 46 |
+
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
| 47 |
+
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
| 48 |
+
Whether to use pretrained weights for the backbone.
|
| 49 |
+
backbone_kwargs (`dict`, *optional*):
|
| 50 |
+
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
| 51 |
+
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
| 52 |
+
num_queries (`int`, *optional*, defaults to 300):
|
| 53 |
+
Number of object queries, i.e. detection slots. This is the maximal number of objects
|
| 54 |
+
[`DabDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
|
| 55 |
+
encoder_layers (`int`, *optional*, defaults to 6):
|
| 56 |
+
Number of encoder layers.
|
| 57 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
| 58 |
+
Dimension of the "intermediate" (often named feed-forward) layer in encoder.
|
| 59 |
+
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 60 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 61 |
+
decoder_layers (`int`, *optional*, defaults to 6):
|
| 62 |
+
Number of decoder layers.
|
| 63 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
| 64 |
+
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
| 65 |
+
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 66 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 67 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 68 |
+
Indicates whether the transformer model architecture is an encoder-decoder or not.
|
| 69 |
+
activation_function (`str` or `function`, *optional*, defaults to `"prelu"`):
|
| 70 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 71 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 72 |
+
hidden_size (`int`, *optional*, defaults to 256):
|
| 73 |
+
This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others.
|
| 74 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 75 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 76 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 77 |
+
The dropout ratio for the attention probabilities.
|
| 78 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 79 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 80 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
| 81 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 82 |
+
init_xavier_std (`float`, *optional*, defaults to 1.0):
|
| 83 |
+
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
|
| 84 |
+
auxiliary_loss (`bool`, *optional*, defaults to `False`):
|
| 85 |
+
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 86 |
+
dilation (`bool`, *optional*, defaults to `False`):
|
| 87 |
+
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when `use_timm_backbone` = `True`.
|
| 88 |
+
class_cost (`float`, *optional*, defaults to 2):
|
| 89 |
+
Relative weight of the classification error in the Hungarian matching cost.
|
| 90 |
+
bbox_cost (`float`, *optional*, defaults to 5):
|
| 91 |
+
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
|
| 92 |
+
giou_cost (`float`, *optional*, defaults to 2):
|
| 93 |
+
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
|
| 94 |
+
cls_loss_coefficient (`float`, *optional*, defaults to 2):
|
| 95 |
+
Relative weight of the classification loss in the object detection loss function.
|
| 96 |
+
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
|
| 97 |
+
Relative weight of the L1 bounding box loss in the object detection loss.
|
| 98 |
+
giou_loss_coefficient (`float`, *optional*, defaults to 2):
|
| 99 |
+
Relative weight of the generalized IoU loss in the object detection loss.
|
| 100 |
+
focal_alpha (`float`, *optional*, defaults to 0.25):
|
| 101 |
+
Alpha parameter in the focal loss.
|
| 102 |
+
temperature_height (`int`, *optional*, defaults to 20):
|
| 103 |
+
Temperature parameter to tune the flatness of positional attention (HEIGHT)
|
| 104 |
+
temperature_width (`int`, *optional*, defaults to 20):
|
| 105 |
+
Temperature parameter to tune the flatness of positional attention (WIDTH)
|
| 106 |
+
query_dim (`int`, *optional*, defaults to 4):
|
| 107 |
+
Query dimension parameter represents the size of the output vector.
|
| 108 |
+
random_refpoints_xy (`bool`, *optional*, defaults to `False`):
|
| 109 |
+
Whether to fix the x and y coordinates of the anchor boxes with random initialization.
|
| 110 |
+
keep_query_pos (`bool`, *optional*, defaults to `False`):
|
| 111 |
+
Whether to concatenate the projected positional embedding from the object query into the original query (key) in every decoder layer.
|
| 112 |
+
num_patterns (`int`, *optional*, defaults to 0):
|
| 113 |
+
Number of pattern embeddings.
|
| 114 |
+
normalize_before (`bool`, *optional*, defaults to `False`):
|
| 115 |
+
Whether we use a normalization layer in the Encoder or not.
|
| 116 |
+
sine_position_embedding_scale (`float`, *optional*, defaults to 'None'):
|
| 117 |
+
Scaling factor applied to the normalized positional encodings.
|
| 118 |
+
initializer_bias_prior_prob (`float`, *optional*):
|
| 119 |
+
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
|
| 120 |
+
If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
Examples:
|
| 124 |
+
|
| 125 |
+
```python
|
| 126 |
+
>>> from transformers import DabDetrConfig, DabDetrModel
|
| 127 |
+
|
| 128 |
+
>>> # Initializing a DAB-DETR IDEA-Research/dab_detr-base style configuration
|
| 129 |
+
>>> configuration = DabDetrConfig()
|
| 130 |
+
|
| 131 |
+
>>> # Initializing a model (with random weights) from the IDEA-Research/dab_detr-base style configuration
|
| 132 |
+
>>> model = DabDetrModel(configuration)
|
| 133 |
+
|
| 134 |
+
>>> # Accessing the model configuration
|
| 135 |
+
>>> configuration = model.config
|
| 136 |
+
```"""
|
| 137 |
+
|
| 138 |
+
model_type = "dab-detr"
|
| 139 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 140 |
+
attribute_map = {
|
| 141 |
+
"num_attention_heads": "encoder_attention_heads",
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
use_timm_backbone=True,
|
| 147 |
+
backbone_config=None,
|
| 148 |
+
backbone="resnet50",
|
| 149 |
+
use_pretrained_backbone=True,
|
| 150 |
+
backbone_kwargs=None,
|
| 151 |
+
num_queries=300,
|
| 152 |
+
encoder_layers=6,
|
| 153 |
+
encoder_ffn_dim=2048,
|
| 154 |
+
encoder_attention_heads=8,
|
| 155 |
+
decoder_layers=6,
|
| 156 |
+
decoder_ffn_dim=2048,
|
| 157 |
+
decoder_attention_heads=8,
|
| 158 |
+
is_encoder_decoder=True,
|
| 159 |
+
activation_function="prelu",
|
| 160 |
+
hidden_size=256,
|
| 161 |
+
dropout=0.1,
|
| 162 |
+
attention_dropout=0.0,
|
| 163 |
+
activation_dropout=0.0,
|
| 164 |
+
init_std=0.02,
|
| 165 |
+
init_xavier_std=1.0,
|
| 166 |
+
auxiliary_loss=False,
|
| 167 |
+
dilation=False,
|
| 168 |
+
class_cost=2,
|
| 169 |
+
bbox_cost=5,
|
| 170 |
+
giou_cost=2,
|
| 171 |
+
cls_loss_coefficient=2,
|
| 172 |
+
bbox_loss_coefficient=5,
|
| 173 |
+
giou_loss_coefficient=2,
|
| 174 |
+
focal_alpha=0.25,
|
| 175 |
+
temperature_height=20,
|
| 176 |
+
temperature_width=20,
|
| 177 |
+
query_dim=4,
|
| 178 |
+
random_refpoints_xy=False,
|
| 179 |
+
keep_query_pos=False,
|
| 180 |
+
num_patterns=0,
|
| 181 |
+
normalize_before=False,
|
| 182 |
+
sine_position_embedding_scale=None,
|
| 183 |
+
initializer_bias_prior_prob=None,
|
| 184 |
+
**kwargs,
|
| 185 |
+
):
|
| 186 |
+
if query_dim != 4:
|
| 187 |
+
raise ValueError("The query dimensions has to be 4.")
|
| 188 |
+
|
| 189 |
+
# We default to values which were previously hard-coded in the model. This enables configurability of the config
|
| 190 |
+
# while keeping the default behavior the same.
|
| 191 |
+
if use_timm_backbone and backbone_kwargs is None:
|
| 192 |
+
backbone_kwargs = {}
|
| 193 |
+
if dilation:
|
| 194 |
+
backbone_kwargs["output_stride"] = 16
|
| 195 |
+
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
|
| 196 |
+
backbone_kwargs["in_chans"] = 3 # num_channels
|
| 197 |
+
# Backwards compatibility
|
| 198 |
+
elif not use_timm_backbone and backbone in (None, "resnet50"):
|
| 199 |
+
if backbone_config is None:
|
| 200 |
+
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
| 201 |
+
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
|
| 202 |
+
elif isinstance(backbone_config, dict):
|
| 203 |
+
backbone_model_type = backbone_config.get("model_type")
|
| 204 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 205 |
+
backbone_config = config_class.from_dict(backbone_config)
|
| 206 |
+
backbone = None
|
| 207 |
+
# set timm attributes to None
|
| 208 |
+
dilation = None
|
| 209 |
+
|
| 210 |
+
verify_backbone_config_arguments(
|
| 211 |
+
use_timm_backbone=use_timm_backbone,
|
| 212 |
+
use_pretrained_backbone=use_pretrained_backbone,
|
| 213 |
+
backbone=backbone,
|
| 214 |
+
backbone_config=backbone_config,
|
| 215 |
+
backbone_kwargs=backbone_kwargs,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
self.use_timm_backbone = use_timm_backbone
|
| 219 |
+
self.backbone_config = backbone_config
|
| 220 |
+
self.num_queries = num_queries
|
| 221 |
+
self.hidden_size = hidden_size
|
| 222 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 223 |
+
self.encoder_layers = encoder_layers
|
| 224 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 225 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 226 |
+
self.decoder_layers = decoder_layers
|
| 227 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 228 |
+
self.dropout = dropout
|
| 229 |
+
self.attention_dropout = attention_dropout
|
| 230 |
+
self.activation_dropout = activation_dropout
|
| 231 |
+
self.activation_function = activation_function
|
| 232 |
+
self.init_std = init_std
|
| 233 |
+
self.init_xavier_std = init_xavier_std
|
| 234 |
+
self.num_hidden_layers = encoder_layers
|
| 235 |
+
self.auxiliary_loss = auxiliary_loss
|
| 236 |
+
self.backbone = backbone
|
| 237 |
+
self.use_pretrained_backbone = use_pretrained_backbone
|
| 238 |
+
self.backbone_kwargs = backbone_kwargs
|
| 239 |
+
# Hungarian matcher
|
| 240 |
+
self.class_cost = class_cost
|
| 241 |
+
self.bbox_cost = bbox_cost
|
| 242 |
+
self.giou_cost = giou_cost
|
| 243 |
+
# Loss coefficients
|
| 244 |
+
self.cls_loss_coefficient = cls_loss_coefficient
|
| 245 |
+
self.bbox_loss_coefficient = bbox_loss_coefficient
|
| 246 |
+
self.giou_loss_coefficient = giou_loss_coefficient
|
| 247 |
+
self.focal_alpha = focal_alpha
|
| 248 |
+
self.query_dim = query_dim
|
| 249 |
+
self.random_refpoints_xy = random_refpoints_xy
|
| 250 |
+
self.keep_query_pos = keep_query_pos
|
| 251 |
+
self.num_patterns = num_patterns
|
| 252 |
+
self.normalize_before = normalize_before
|
| 253 |
+
self.temperature_width = temperature_width
|
| 254 |
+
self.temperature_height = temperature_height
|
| 255 |
+
self.sine_position_embedding_scale = sine_position_embedding_scale
|
| 256 |
+
self.initializer_bias_prior_prob = initializer_bias_prior_prob
|
| 257 |
+
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
__all__ = ["DabDetrConfig"]
|
docs/transformers/build/lib/transformers/models/dab_detr/convert_dab_detr_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert DAB-DETR checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import gc
|
| 19 |
+
import json
|
| 20 |
+
import re
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
+
|
| 26 |
+
from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logging.set_verbosity_info()
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
| 34 |
+
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
| 35 |
+
# for dab-DETR, also convert reference point head and query scale MLP
|
| 36 |
+
r"input_proj\.(bias|weight)": r"input_projection.\1",
|
| 37 |
+
r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight",
|
| 38 |
+
r"class_embed\.(bias|weight)": r"class_embed.\1",
|
| 39 |
+
# negative lookbehind because of the overlap
|
| 40 |
+
r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2",
|
| 41 |
+
r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2",
|
| 42 |
+
r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
|
| 43 |
+
r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1",
|
| 44 |
+
r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2",
|
| 45 |
+
r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2",
|
| 46 |
+
r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2",
|
| 47 |
+
r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1",
|
| 48 |
+
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function
|
| 49 |
+
# output projection
|
| 50 |
+
r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2",
|
| 51 |
+
# FFN layers
|
| 52 |
+
r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3",
|
| 53 |
+
# normalization layers
|
| 54 |
+
# nm1
|
| 55 |
+
r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2",
|
| 56 |
+
# nm2
|
| 57 |
+
r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2",
|
| 58 |
+
# activation function weight
|
| 59 |
+
r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight",
|
| 60 |
+
#########################################################################################################################################
|
| 61 |
+
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activiation function weight
|
| 62 |
+
r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2",
|
| 63 |
+
r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2",
|
| 64 |
+
# FFNs
|
| 65 |
+
r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3",
|
| 66 |
+
# nm1
|
| 67 |
+
r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2",
|
| 68 |
+
# nm2
|
| 69 |
+
r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2",
|
| 70 |
+
# nm3
|
| 71 |
+
r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2",
|
| 72 |
+
# activation function weight
|
| 73 |
+
r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight",
|
| 74 |
+
# q, k, v projections and biases in self-attention in decoder
|
| 75 |
+
r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2",
|
| 76 |
+
r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2",
|
| 77 |
+
r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2",
|
| 78 |
+
r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2",
|
| 79 |
+
r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2",
|
| 80 |
+
# q, k, v projections in cross-attention in decoder
|
| 81 |
+
r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2",
|
| 82 |
+
r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2",
|
| 83 |
+
r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2",
|
| 84 |
+
r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2",
|
| 85 |
+
r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2",
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys
|
| 90 |
+
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
|
| 91 |
+
"""
|
| 92 |
+
This function should be applied only once, on the concatenated keys to efficiently rename using
|
| 93 |
+
the key mappings.
|
| 94 |
+
"""
|
| 95 |
+
output_dict = {}
|
| 96 |
+
if state_dict_keys is not None:
|
| 97 |
+
old_text = "\n".join(state_dict_keys)
|
| 98 |
+
new_text = old_text
|
| 99 |
+
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
| 100 |
+
if replacement is None:
|
| 101 |
+
new_text = re.sub(pattern, "", new_text) # an empty line
|
| 102 |
+
continue
|
| 103 |
+
new_text = re.sub(pattern, replacement, new_text)
|
| 104 |
+
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
| 105 |
+
return output_dict
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub):
|
| 109 |
+
logger.info("Converting image processor...")
|
| 110 |
+
format = "coco_detection"
|
| 111 |
+
image_processor = ConditionalDetrImageProcessor(format=format)
|
| 112 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 113 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 114 |
+
|
| 115 |
+
if push_to_hub:
|
| 116 |
+
image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@torch.no_grad()
|
| 120 |
+
def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
|
| 121 |
+
# load modified config. Why? After loading the default config, the backbone kwargs are already set.
|
| 122 |
+
if "dc5" in model_name:
|
| 123 |
+
config = DabDetrConfig(dilation=True)
|
| 124 |
+
else:
|
| 125 |
+
# load default config
|
| 126 |
+
config = DabDetrConfig()
|
| 127 |
+
# set other attributes
|
| 128 |
+
if "dab-detr-resnet-50-dc5" == model_name:
|
| 129 |
+
config.temperature_height = 10
|
| 130 |
+
config.temperature_width = 10
|
| 131 |
+
if "fixxy" in model_name:
|
| 132 |
+
config.random_refpoints_xy = True
|
| 133 |
+
if "pat3" in model_name:
|
| 134 |
+
config.num_patterns = 3
|
| 135 |
+
# only when the number of patterns (num_patterns parameter in config) are more than 0 like r50-pat3 or r50dc5-pat3
|
| 136 |
+
ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"})
|
| 137 |
+
|
| 138 |
+
config.num_labels = 91
|
| 139 |
+
repo_id = "huggingface/label-files"
|
| 140 |
+
filename = "coco-detection-id2label.json"
|
| 141 |
+
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
| 142 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 143 |
+
config.id2label = id2label
|
| 144 |
+
config.label2id = {v: k for k, v in id2label.items()}
|
| 145 |
+
# load original model from local path
|
| 146 |
+
loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"), weights_only=True)["model"]
|
| 147 |
+
# Renaming the original model state dictionary to HF compatibile
|
| 148 |
+
all_keys = list(loaded.keys())
|
| 149 |
+
new_keys = convert_old_keys_to_new_keys(all_keys)
|
| 150 |
+
state_dict = {}
|
| 151 |
+
for key in all_keys:
|
| 152 |
+
if "backbone.0.body" in key:
|
| 153 |
+
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone")
|
| 154 |
+
state_dict[new_key] = loaded[key]
|
| 155 |
+
# Q, K, V encoder values mapping
|
| 156 |
+
elif re.search("self_attn.in_proj_(weight|bias)", key):
|
| 157 |
+
# Dynamically find the layer number
|
| 158 |
+
pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)"
|
| 159 |
+
match = re.search(pattern, key)
|
| 160 |
+
if match:
|
| 161 |
+
layer_num = match.group(1)
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError(f"Pattern not found in key: {key}")
|
| 164 |
+
|
| 165 |
+
in_proj_value = loaded.pop(key)
|
| 166 |
+
if "weight" in key:
|
| 167 |
+
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :]
|
| 168 |
+
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :]
|
| 169 |
+
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :]
|
| 170 |
+
elif "bias" in key:
|
| 171 |
+
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256]
|
| 172 |
+
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512]
|
| 173 |
+
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:]
|
| 174 |
+
else:
|
| 175 |
+
new_key = new_keys[key]
|
| 176 |
+
state_dict[new_key] = loaded[key]
|
| 177 |
+
|
| 178 |
+
del loaded
|
| 179 |
+
gc.collect()
|
| 180 |
+
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
| 181 |
+
prefix = "model."
|
| 182 |
+
for key in state_dict.copy().keys():
|
| 183 |
+
if not key.startswith("class_embed") and not key.startswith("bbox_predictor"):
|
| 184 |
+
val = state_dict.pop(key)
|
| 185 |
+
state_dict[prefix + key] = val
|
| 186 |
+
# finally, create HuggingFace model and load state dict
|
| 187 |
+
model = DabDetrForObjectDetection(config)
|
| 188 |
+
model.load_state_dict(state_dict)
|
| 189 |
+
model.eval()
|
| 190 |
+
logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...")
|
| 191 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 192 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 193 |
+
|
| 194 |
+
if push_to_hub:
|
| 195 |
+
model.push_to_hub(repo_id=model_name, commit_message="Add new model")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
|
| 199 |
+
logger.info("Converting image processor...")
|
| 200 |
+
write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub)
|
| 201 |
+
|
| 202 |
+
logger.info(f"Converting model {model_name}...")
|
| 203 |
+
write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
|
| 207 |
+
parser = argparse.ArgumentParser()
|
| 208 |
+
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--model_name",
|
| 211 |
+
default="dab-detr-resnet-50",
|
| 212 |
+
type=str,
|
| 213 |
+
help="Name of the DAB_DETR model you'd like to convert.",
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--pretrained_model_weights_path",
|
| 217 |
+
default="modelzoo/R50/checkpoint.pth",
|
| 218 |
+
type=str,
|
| 219 |
+
help="The path of the original model weights like: modelzoo/checkpoint.pth",
|
| 220 |
+
)
|
| 221 |
+
parser.add_argument(
|
| 222 |
+
"--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model."
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--push_to_hub",
|
| 226 |
+
default=True,
|
| 227 |
+
type=bool,
|
| 228 |
+
help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.",
|
| 229 |
+
)
|
| 230 |
+
args = parser.parse_args()
|
| 231 |
+
convert_dab_detr_checkpoint(
|
| 232 |
+
args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub
|
| 233 |
+
)
|
docs/transformers/build/lib/transformers/models/dab_detr/modeling_dab_detr.py
ADDED
|
@@ -0,0 +1,1716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 IDEA Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DAB-DETR model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import Tensor, nn
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 26 |
+
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
|
| 27 |
+
from ...modeling_utils import PreTrainedModel
|
| 28 |
+
from ...utils import (
|
| 29 |
+
ModelOutput,
|
| 30 |
+
add_start_docstrings,
|
| 31 |
+
add_start_docstrings_to_model_forward,
|
| 32 |
+
logging,
|
| 33 |
+
replace_return_docstrings,
|
| 34 |
+
)
|
| 35 |
+
from ...utils.backbone_utils import load_backbone
|
| 36 |
+
from .configuration_dab_detr import DabDetrConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
_CONFIG_FOR_DOC = "DabDetrConfig"
|
| 42 |
+
_CHECKPOINT_FOR_DOC = "IDEA-Research/dab_detr-base"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points)
|
| 47 |
+
class DabDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
| 48 |
+
"""
|
| 49 |
+
Base class for outputs of the Conditional DETR decoder. This class adds one attribute to
|
| 50 |
+
BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output
|
| 51 |
+
of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary
|
| 52 |
+
decoding losses.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 56 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 57 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 58 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 59 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 60 |
+
plus the initial embedding outputs.
|
| 61 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 62 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 63 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 64 |
+
the self-attention heads.
|
| 65 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
|
| 66 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 67 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 68 |
+
used to compute the weighted average in the cross-attention heads.
|
| 69 |
+
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
| 70 |
+
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
| 71 |
+
layernorm.
|
| 72 |
+
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
|
| 73 |
+
Reference points (reference points of each layer of the decoder).
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
| 77 |
+
reference_points: Optional[Tuple[torch.FloatTensor]] = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrModelOutput with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR,2 (anchor points)->4 (anchor points)
|
| 82 |
+
class DabDetrModelOutput(Seq2SeqModelOutput):
|
| 83 |
+
"""
|
| 84 |
+
Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to
|
| 85 |
+
Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder
|
| 86 |
+
layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding
|
| 87 |
+
losses.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 91 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 92 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 93 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 94 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
|
| 95 |
+
layer plus the initial embedding outputs.
|
| 96 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 97 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 98 |
+
sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
|
| 99 |
+
weighted average in the self-attention heads.
|
| 100 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 101 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 102 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 103 |
+
used to compute the weighted average in the cross-attention heads.
|
| 104 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 105 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 106 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 107 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 108 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
|
| 109 |
+
layer plus the initial embedding outputs.
|
| 110 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 111 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 112 |
+
sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
|
| 113 |
+
weighted average in the self-attention heads.
|
| 114 |
+
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
| 115 |
+
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
| 116 |
+
layernorm.
|
| 117 |
+
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
|
| 118 |
+
Reference points (reference points of each layer of the decoder).
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
| 122 |
+
reference_points: Optional[Tuple[torch.FloatTensor]] = None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
|
| 126 |
+
# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->DabDetr
|
| 127 |
+
class DabDetrObjectDetectionOutput(ModelOutput):
|
| 128 |
+
"""
|
| 129 |
+
Output type of [`DabDetrForObjectDetection`].
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
|
| 133 |
+
Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
|
| 134 |
+
bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
|
| 135 |
+
scale-invariant IoU loss.
|
| 136 |
+
loss_dict (`Dict`, *optional*):
|
| 137 |
+
A dictionary containing the individual losses. Useful for logging.
|
| 138 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
|
| 139 |
+
Classification logits (including no-object) for all queries.
|
| 140 |
+
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
| 141 |
+
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
|
| 142 |
+
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
|
| 143 |
+
possible padding). You can use [`~DabDetrImageProcessor.post_process_object_detection`] to retrieve the
|
| 144 |
+
unnormalized bounding boxes.
|
| 145 |
+
auxiliary_outputs (`list[Dict]`, *optional*):
|
| 146 |
+
Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
|
| 147 |
+
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
|
| 148 |
+
`pred_boxes`) for each decoder layer.
|
| 149 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 150 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
| 151 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 152 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 153 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each
|
| 154 |
+
layer plus the initial embedding outputs.
|
| 155 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 156 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 157 |
+
sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the
|
| 158 |
+
weighted average in the self-attention heads.
|
| 159 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 160 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 161 |
+
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
|
| 162 |
+
used to compute the weighted average in the cross-attention heads.
|
| 163 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 164 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
| 165 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 166 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 167 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
|
| 168 |
+
layer plus the initial embedding outputs.
|
| 169 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 170 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 171 |
+
sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the
|
| 172 |
+
weighted average in the self-attention heads.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
loss: Optional[torch.FloatTensor] = None
|
| 176 |
+
loss_dict: Optional[Dict] = None
|
| 177 |
+
logits: Optional[torch.FloatTensor] = None
|
| 178 |
+
pred_boxes: Optional[torch.FloatTensor] = None
|
| 179 |
+
auxiliary_outputs: Optional[List[Dict]] = None
|
| 180 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 181 |
+
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 182 |
+
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 183 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 184 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 185 |
+
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 186 |
+
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DabDetr
|
| 190 |
+
class DabDetrFrozenBatchNorm2d(nn.Module):
|
| 191 |
+
"""
|
| 192 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
| 193 |
+
|
| 194 |
+
Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
|
| 195 |
+
torchvision.models.resnet[18,34,50,101] produce nans.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(self, n):
|
| 199 |
+
super().__init__()
|
| 200 |
+
self.register_buffer("weight", torch.ones(n))
|
| 201 |
+
self.register_buffer("bias", torch.zeros(n))
|
| 202 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
| 203 |
+
self.register_buffer("running_var", torch.ones(n))
|
| 204 |
+
|
| 205 |
+
def _load_from_state_dict(
|
| 206 |
+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
| 207 |
+
):
|
| 208 |
+
num_batches_tracked_key = prefix + "num_batches_tracked"
|
| 209 |
+
if num_batches_tracked_key in state_dict:
|
| 210 |
+
del state_dict[num_batches_tracked_key]
|
| 211 |
+
|
| 212 |
+
super()._load_from_state_dict(
|
| 213 |
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def forward(self, x):
|
| 217 |
+
# move reshapes to the beginning
|
| 218 |
+
# to make it user-friendly
|
| 219 |
+
weight = self.weight.reshape(1, -1, 1, 1)
|
| 220 |
+
bias = self.bias.reshape(1, -1, 1, 1)
|
| 221 |
+
running_var = self.running_var.reshape(1, -1, 1, 1)
|
| 222 |
+
running_mean = self.running_mean.reshape(1, -1, 1, 1)
|
| 223 |
+
epsilon = 1e-5
|
| 224 |
+
scale = weight * (running_var + epsilon).rsqrt()
|
| 225 |
+
bias = bias - running_mean * scale
|
| 226 |
+
return x * scale + bias
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DabDetr
|
| 230 |
+
def replace_batch_norm(model):
|
| 231 |
+
r"""
|
| 232 |
+
Recursively replace all `torch.nn.BatchNorm2d` with `DabDetrFrozenBatchNorm2d`.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
model (torch.nn.Module):
|
| 236 |
+
input model
|
| 237 |
+
"""
|
| 238 |
+
for name, module in model.named_children():
|
| 239 |
+
if isinstance(module, nn.BatchNorm2d):
|
| 240 |
+
new_module = DabDetrFrozenBatchNorm2d(module.num_features)
|
| 241 |
+
|
| 242 |
+
if not module.weight.device == torch.device("meta"):
|
| 243 |
+
new_module.weight.data.copy_(module.weight)
|
| 244 |
+
new_module.bias.data.copy_(module.bias)
|
| 245 |
+
new_module.running_mean.data.copy_(module.running_mean)
|
| 246 |
+
new_module.running_var.data.copy_(module.running_var)
|
| 247 |
+
|
| 248 |
+
model._modules[name] = new_module
|
| 249 |
+
|
| 250 |
+
if len(list(module.children())) > 0:
|
| 251 |
+
replace_batch_norm(module)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Modified from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->DabDetr
|
| 255 |
+
class DabDetrConvEncoder(nn.Module):
|
| 256 |
+
"""
|
| 257 |
+
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
|
| 258 |
+
|
| 259 |
+
nn.BatchNorm2d layers are replaced by DabDetrFrozenBatchNorm2d as defined above.
|
| 260 |
+
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
def __init__(self, config: DabDetrConfig):
|
| 264 |
+
super().__init__()
|
| 265 |
+
|
| 266 |
+
self.config = config
|
| 267 |
+
backbone = load_backbone(config)
|
| 268 |
+
|
| 269 |
+
# replace batch norm by frozen batch norm
|
| 270 |
+
with torch.no_grad():
|
| 271 |
+
replace_batch_norm(backbone)
|
| 272 |
+
self.model = backbone
|
| 273 |
+
self.intermediate_channel_sizes = self.model.channels
|
| 274 |
+
|
| 275 |
+
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
|
| 276 |
+
# send pixel_values through the model to get list of feature maps
|
| 277 |
+
features = self.model(pixel_values).feature_maps
|
| 278 |
+
|
| 279 |
+
out = []
|
| 280 |
+
for feature_map in features:
|
| 281 |
+
# downsample pixel_mask to match shape of corresponding feature_map
|
| 282 |
+
mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
|
| 283 |
+
out.append((feature_map, mask))
|
| 284 |
+
return out
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DabDetr
|
| 288 |
+
class DabDetrConvModel(nn.Module):
|
| 289 |
+
"""
|
| 290 |
+
This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def __init__(self, conv_encoder, position_embedding):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.conv_encoder = conv_encoder
|
| 296 |
+
self.position_embedding = position_embedding
|
| 297 |
+
|
| 298 |
+
def forward(self, pixel_values, pixel_mask):
|
| 299 |
+
# send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
|
| 300 |
+
out = self.conv_encoder(pixel_values, pixel_mask)
|
| 301 |
+
pos = []
|
| 302 |
+
for feature_map, mask in out:
|
| 303 |
+
# position encoding
|
| 304 |
+
pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
|
| 305 |
+
|
| 306 |
+
return out, pos
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrSinePositionEmbedding with ConditionalDetr->DabDetr
|
| 310 |
+
class DabDetrSinePositionEmbedding(nn.Module):
|
| 311 |
+
"""
|
| 312 |
+
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
|
| 313 |
+
need paper, generalized to work on images.
|
| 314 |
+
"""
|
| 315 |
+
|
| 316 |
+
def __init__(self, config: DabDetrConfig):
|
| 317 |
+
super().__init__()
|
| 318 |
+
self.config = config
|
| 319 |
+
self.embedding_dim = config.hidden_size / 2
|
| 320 |
+
self.temperature_height = config.temperature_height
|
| 321 |
+
self.temperature_width = config.temperature_width
|
| 322 |
+
scale = config.sine_position_embedding_scale
|
| 323 |
+
if scale is None:
|
| 324 |
+
scale = 2 * math.pi
|
| 325 |
+
self.scale = scale
|
| 326 |
+
|
| 327 |
+
def forward(self, pixel_values, pixel_mask):
|
| 328 |
+
if pixel_mask is None:
|
| 329 |
+
raise ValueError("No pixel mask provided")
|
| 330 |
+
y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
|
| 331 |
+
x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
|
| 332 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
|
| 333 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
|
| 334 |
+
|
| 335 |
+
# We use float32 to ensure reproducibility of the original implementation
|
| 336 |
+
dim_tx = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
|
| 337 |
+
# Modifying dim_tx in place to avoid extra memory allocation -> dim_tx = self.temperature_width ** (2 * (dim_tx // 2) / self.embedding_dim)
|
| 338 |
+
dim_tx //= 2
|
| 339 |
+
dim_tx.mul_(2 / self.embedding_dim)
|
| 340 |
+
dim_tx.copy_(self.temperature_width**dim_tx)
|
| 341 |
+
pos_x = x_embed[:, :, :, None] / dim_tx
|
| 342 |
+
|
| 343 |
+
# We use float32 to ensure reproducibility of the original implementation
|
| 344 |
+
dim_ty = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
|
| 345 |
+
# Modifying dim_ty in place to avoid extra memory allocation -> dim_ty = self.temperature_height ** (2 * (dim_ty // 2) / self.embedding_dim)
|
| 346 |
+
dim_ty //= 2
|
| 347 |
+
dim_ty.mul_(2 / self.embedding_dim)
|
| 348 |
+
dim_ty.copy_(self.temperature_height**dim_ty)
|
| 349 |
+
pos_y = y_embed[:, :, :, None] / dim_ty
|
| 350 |
+
|
| 351 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 352 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 353 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 354 |
+
return pos
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# function to generate sine positional embedding for 4d coordinates
|
| 358 |
+
def gen_sine_position_embeddings(pos_tensor, hidden_size=256):
|
| 359 |
+
"""
|
| 360 |
+
This function computes position embeddings using sine and cosine functions from the input positional tensor,
|
| 361 |
+
which has a shape of (batch_size, num_queries, 4).
|
| 362 |
+
The last dimension of `pos_tensor` represents the following coordinates:
|
| 363 |
+
- 0: x-coord
|
| 364 |
+
- 1: y-coord
|
| 365 |
+
- 2: width
|
| 366 |
+
- 3: height
|
| 367 |
+
|
| 368 |
+
The output shape is (batch_size, num_queries, 512), where final dim (hidden_size*2 = 512) is the total embedding dimension
|
| 369 |
+
achieved by concatenating the sine and cosine values for each coordinate.
|
| 370 |
+
"""
|
| 371 |
+
scale = 2 * math.pi
|
| 372 |
+
dim = hidden_size // 2
|
| 373 |
+
dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
|
| 374 |
+
dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
|
| 375 |
+
x_embed = pos_tensor[:, :, 0] * scale
|
| 376 |
+
y_embed = pos_tensor[:, :, 1] * scale
|
| 377 |
+
pos_x = x_embed[:, :, None] / dim_t
|
| 378 |
+
pos_y = y_embed[:, :, None] / dim_t
|
| 379 |
+
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
|
| 380 |
+
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
|
| 381 |
+
if pos_tensor.size(-1) == 4:
|
| 382 |
+
w_embed = pos_tensor[:, :, 2] * scale
|
| 383 |
+
pos_w = w_embed[:, :, None] / dim_t
|
| 384 |
+
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
|
| 385 |
+
|
| 386 |
+
h_embed = pos_tensor[:, :, 3] * scale
|
| 387 |
+
pos_h = h_embed[:, :, None] / dim_t
|
| 388 |
+
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
|
| 389 |
+
|
| 390 |
+
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
|
| 391 |
+
else:
|
| 392 |
+
raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
|
| 393 |
+
return pos
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def inverse_sigmoid(x, eps=1e-5):
|
| 397 |
+
x = x.clamp(min=0, max=1)
|
| 398 |
+
x1 = x.clamp(min=eps)
|
| 399 |
+
x2 = (1 - x).clamp(min=eps)
|
| 400 |
+
return torch.log(x1 / x2)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# Modified from transformers.models.detr.modeling_detr.DetrAttention
|
| 404 |
+
class DetrAttention(nn.Module):
|
| 405 |
+
"""
|
| 406 |
+
Multi-headed attention from 'Attention Is All You Need' paper.
|
| 407 |
+
|
| 408 |
+
Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
|
| 409 |
+
"""
|
| 410 |
+
|
| 411 |
+
def __init__(
|
| 412 |
+
self,
|
| 413 |
+
config: DabDetrConfig,
|
| 414 |
+
bias: bool = True,
|
| 415 |
+
):
|
| 416 |
+
super().__init__()
|
| 417 |
+
self.config = config
|
| 418 |
+
self.hidden_size = config.hidden_size
|
| 419 |
+
self.num_heads = config.encoder_attention_heads
|
| 420 |
+
self.attention_dropout = config.attention_dropout
|
| 421 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 422 |
+
if self.head_dim * self.num_heads != self.hidden_size:
|
| 423 |
+
raise ValueError(
|
| 424 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
| 425 |
+
f" {self.num_heads})."
|
| 426 |
+
)
|
| 427 |
+
self.scaling = self.head_dim**-0.5
|
| 428 |
+
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
|
| 429 |
+
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
|
| 430 |
+
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
|
| 431 |
+
self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias)
|
| 432 |
+
|
| 433 |
+
def forward(
|
| 434 |
+
self,
|
| 435 |
+
hidden_states: torch.Tensor,
|
| 436 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 437 |
+
object_queries: Optional[torch.Tensor] = None,
|
| 438 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 439 |
+
output_attentions: bool = False,
|
| 440 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 441 |
+
"""Input shape: Batch x Time x Channel"""
|
| 442 |
+
batch_size, q_len, embed_dim = hidden_states.size()
|
| 443 |
+
# add position embeddings to the hidden states before projecting to queries and keys
|
| 444 |
+
if object_queries is not None:
|
| 445 |
+
hidden_states_original = hidden_states
|
| 446 |
+
hidden_states = hidden_states + object_queries
|
| 447 |
+
|
| 448 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
| 449 |
+
key_states = self.k_proj(hidden_states)
|
| 450 |
+
value_states = self.v_proj(hidden_states_original)
|
| 451 |
+
|
| 452 |
+
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 453 |
+
key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 454 |
+
value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 455 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
| 456 |
+
|
| 457 |
+
if attention_mask is not None:
|
| 458 |
+
attn_weights = attn_weights + attention_mask
|
| 459 |
+
|
| 460 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 461 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 462 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 463 |
+
|
| 464 |
+
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
|
| 467 |
+
f" {attn_output.size()}"
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 471 |
+
|
| 472 |
+
attn_output = attn_output.reshape(batch_size, q_len, embed_dim)
|
| 473 |
+
attn_output = self.out_proj(attn_output)
|
| 474 |
+
|
| 475 |
+
if not output_attentions:
|
| 476 |
+
attn_weights = None
|
| 477 |
+
|
| 478 |
+
return attn_output, attn_weights
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrAttention with ConditionalDetr->DABDETR,Conditional DETR->DabDetr
|
| 482 |
+
class DabDetrAttention(nn.Module):
|
| 483 |
+
"""
|
| 484 |
+
Cross-Attention used in DAB-DETR 'DAB-DETR for Fast Training Convergence' paper.
|
| 485 |
+
|
| 486 |
+
The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be
|
| 487 |
+
different to v.
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
def __init__(self, config: DabDetrConfig, bias: bool = True, is_cross: bool = False):
|
| 491 |
+
super().__init__()
|
| 492 |
+
self.config = config
|
| 493 |
+
self.embed_dim = config.hidden_size * 2 if is_cross else config.hidden_size
|
| 494 |
+
self.output_dim = config.hidden_size
|
| 495 |
+
self.attention_heads = config.decoder_attention_heads
|
| 496 |
+
self.attention_dropout = config.attention_dropout
|
| 497 |
+
self.attention_head_dim = self.embed_dim // self.attention_heads
|
| 498 |
+
if self.attention_head_dim * self.attention_heads != self.embed_dim:
|
| 499 |
+
raise ValueError(
|
| 500 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `attention_heads`:"
|
| 501 |
+
f" {self.attention_heads})."
|
| 502 |
+
)
|
| 503 |
+
# head dimension of values
|
| 504 |
+
self.values_head_dim = self.output_dim // self.attention_heads
|
| 505 |
+
if self.values_head_dim * self.attention_heads != self.output_dim:
|
| 506 |
+
raise ValueError(
|
| 507 |
+
f"output_dim must be divisible by attention_heads (got `output_dim`: {self.output_dim} and `attention_heads`: {self.attention_heads})."
|
| 508 |
+
)
|
| 509 |
+
self.scaling = self.attention_head_dim**-0.5
|
| 510 |
+
self.output_proj = nn.Linear(self.output_dim, self.output_dim, bias=bias)
|
| 511 |
+
|
| 512 |
+
def forward(
|
| 513 |
+
self,
|
| 514 |
+
hidden_states: torch.Tensor,
|
| 515 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 516 |
+
key_states: Optional[torch.Tensor] = None,
|
| 517 |
+
value_states: Optional[torch.Tensor] = None,
|
| 518 |
+
output_attentions: Optional[bool] = None,
|
| 519 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 520 |
+
"""Input shape: Batch x Time x Channel"""
|
| 521 |
+
|
| 522 |
+
batch_size, q_len, _ = hidden_states.size()
|
| 523 |
+
|
| 524 |
+
# scaling query and refactor key-, value states
|
| 525 |
+
query_states = hidden_states * self.scaling
|
| 526 |
+
query_states = query_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2)
|
| 527 |
+
key_states = key_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2)
|
| 528 |
+
value_states = value_states.view(batch_size, -1, self.attention_heads, self.values_head_dim).transpose(1, 2)
|
| 529 |
+
|
| 530 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
| 531 |
+
|
| 532 |
+
if attention_mask is not None:
|
| 533 |
+
attn_weights = attn_weights + attention_mask
|
| 534 |
+
|
| 535 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 536 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 537 |
+
attn_output = torch.matmul(attn_probs, value_states)
|
| 538 |
+
|
| 539 |
+
if attn_output.size() != (batch_size, self.attention_heads, q_len, self.values_head_dim):
|
| 540 |
+
raise ValueError(
|
| 541 |
+
f"`attn_output` should be of size {(batch_size, self.attention_heads, q_len, self.values_head_dim)}, but is"
|
| 542 |
+
f" {attn_output.size()}"
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 546 |
+
|
| 547 |
+
attn_output = attn_output.reshape(batch_size, q_len, self.output_dim)
|
| 548 |
+
attn_output = self.output_proj(attn_output)
|
| 549 |
+
|
| 550 |
+
if not output_attentions:
|
| 551 |
+
attn_weights = None
|
| 552 |
+
|
| 553 |
+
return attn_output, attn_weights
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
class DabDetrDecoderLayerSelfAttention(nn.Module):
|
| 557 |
+
def __init__(self, config: DabDetrConfig):
|
| 558 |
+
super().__init__()
|
| 559 |
+
self.dropout = config.dropout
|
| 560 |
+
self.self_attn_query_content_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 561 |
+
self.self_attn_query_pos_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 562 |
+
self.self_attn_key_content_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 563 |
+
self.self_attn_key_pos_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 564 |
+
self.self_attn_value_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
| 565 |
+
self.self_attn = DabDetrAttention(config)
|
| 566 |
+
self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
|
| 567 |
+
|
| 568 |
+
def forward(
|
| 569 |
+
self,
|
| 570 |
+
hidden_states: torch.Tensor,
|
| 571 |
+
query_position_embeddings: Optional[torch.Tensor] = None,
|
| 572 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 573 |
+
output_attentions: Optional[bool] = None,
|
| 574 |
+
):
|
| 575 |
+
residual = hidden_states
|
| 576 |
+
query_content = self.self_attn_query_content_proj(hidden_states)
|
| 577 |
+
query_pos = self.self_attn_query_pos_proj(query_position_embeddings)
|
| 578 |
+
key_content = self.self_attn_key_content_proj(hidden_states)
|
| 579 |
+
key_pos = self.self_attn_key_pos_proj(query_position_embeddings)
|
| 580 |
+
value = self.self_attn_value_proj(hidden_states)
|
| 581 |
+
|
| 582 |
+
query = query_content + query_pos
|
| 583 |
+
key = key_content + key_pos
|
| 584 |
+
|
| 585 |
+
hidden_states, attn_weights = self.self_attn(
|
| 586 |
+
hidden_states=query,
|
| 587 |
+
attention_mask=attention_mask,
|
| 588 |
+
key_states=key,
|
| 589 |
+
value_states=value,
|
| 590 |
+
output_attentions=True,
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 594 |
+
hidden_states = residual + hidden_states
|
| 595 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 596 |
+
|
| 597 |
+
return hidden_states, attn_weights
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class DabDetrDecoderLayerCrossAttention(nn.Module):
|
| 601 |
+
def __init__(self, config: DabDetrConfig, is_first: bool = False):
|
| 602 |
+
super().__init__()
|
| 603 |
+
hidden_size = config.hidden_size
|
| 604 |
+
self.cross_attn_query_content_proj = nn.Linear(hidden_size, hidden_size)
|
| 605 |
+
self.cross_attn_query_pos_proj = nn.Linear(hidden_size, hidden_size)
|
| 606 |
+
self.cross_attn_key_content_proj = nn.Linear(hidden_size, hidden_size)
|
| 607 |
+
self.cross_attn_key_pos_proj = nn.Linear(hidden_size, hidden_size)
|
| 608 |
+
self.cross_attn_value_proj = nn.Linear(hidden_size, hidden_size)
|
| 609 |
+
self.cross_attn_query_pos_sine_proj = nn.Linear(hidden_size, hidden_size)
|
| 610 |
+
self.decoder_attention_heads = config.decoder_attention_heads
|
| 611 |
+
self.cross_attn_layer_norm = nn.LayerNorm(hidden_size)
|
| 612 |
+
self.cross_attn = DabDetrAttention(config, is_cross=True)
|
| 613 |
+
|
| 614 |
+
self.keep_query_pos = config.keep_query_pos
|
| 615 |
+
|
| 616 |
+
if not self.keep_query_pos and not is_first:
|
| 617 |
+
self.cross_attn_query_pos_proj = None
|
| 618 |
+
|
| 619 |
+
self.is_first = is_first
|
| 620 |
+
self.dropout = config.dropout
|
| 621 |
+
|
| 622 |
+
def forward(
|
| 623 |
+
self,
|
| 624 |
+
hidden_states: torch.Tensor,
|
| 625 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 626 |
+
query_position_embeddings: Optional[torch.Tensor] = None,
|
| 627 |
+
object_queries: Optional[torch.Tensor] = None,
|
| 628 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 629 |
+
query_sine_embed: Optional[torch.Tensor] = None,
|
| 630 |
+
output_attentions: Optional[bool] = None,
|
| 631 |
+
):
|
| 632 |
+
query_content = self.cross_attn_query_content_proj(hidden_states)
|
| 633 |
+
key_content = self.cross_attn_key_content_proj(encoder_hidden_states)
|
| 634 |
+
value = self.cross_attn_value_proj(encoder_hidden_states)
|
| 635 |
+
|
| 636 |
+
batch_size, num_queries, n_model = query_content.shape
|
| 637 |
+
_, height_width, _ = key_content.shape
|
| 638 |
+
|
| 639 |
+
key_pos = self.cross_attn_key_pos_proj(object_queries)
|
| 640 |
+
|
| 641 |
+
# For the first decoder layer, we add the positional embedding predicted from
|
| 642 |
+
# the object query (the positional embedding) into the original query (key) in DETR.
|
| 643 |
+
if self.is_first or self.keep_query_pos:
|
| 644 |
+
query_pos = self.cross_attn_query_pos_proj(query_position_embeddings)
|
| 645 |
+
query = query_content + query_pos
|
| 646 |
+
key = key_content + key_pos
|
| 647 |
+
else:
|
| 648 |
+
query = query_content
|
| 649 |
+
key = key_content
|
| 650 |
+
|
| 651 |
+
query = query.view(
|
| 652 |
+
batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads
|
| 653 |
+
)
|
| 654 |
+
query_sine_embed = self.cross_attn_query_pos_sine_proj(query_sine_embed)
|
| 655 |
+
query_sine_embed = query_sine_embed.view(
|
| 656 |
+
batch_size, num_queries, self.decoder_attention_heads, n_model // self.decoder_attention_heads
|
| 657 |
+
)
|
| 658 |
+
query = torch.cat([query, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
|
| 659 |
+
key = key.view(batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads)
|
| 660 |
+
key_pos = key_pos.view(
|
| 661 |
+
batch_size, height_width, self.decoder_attention_heads, n_model // self.decoder_attention_heads
|
| 662 |
+
)
|
| 663 |
+
key = torch.cat([key, key_pos], dim=3).view(batch_size, height_width, n_model * 2)
|
| 664 |
+
|
| 665 |
+
# Cross-Attention Block
|
| 666 |
+
cross_attn_weights = None
|
| 667 |
+
if encoder_hidden_states is not None:
|
| 668 |
+
residual = hidden_states
|
| 669 |
+
|
| 670 |
+
hidden_states, cross_attn_weights = self.cross_attn(
|
| 671 |
+
hidden_states=query,
|
| 672 |
+
attention_mask=encoder_attention_mask,
|
| 673 |
+
key_states=key,
|
| 674 |
+
value_states=value,
|
| 675 |
+
output_attentions=output_attentions,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 679 |
+
hidden_states = residual + hidden_states
|
| 680 |
+
hidden_states = self.cross_attn_layer_norm(hidden_states)
|
| 681 |
+
|
| 682 |
+
return hidden_states, cross_attn_weights
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
class DabDetrDecoderLayerFFN(nn.Module):
|
| 686 |
+
def __init__(self, config: DabDetrConfig):
|
| 687 |
+
super().__init__()
|
| 688 |
+
hidden_size = config.hidden_size
|
| 689 |
+
self.final_layer_norm = nn.LayerNorm(hidden_size)
|
| 690 |
+
self.fc1 = nn.Linear(hidden_size, config.decoder_ffn_dim)
|
| 691 |
+
self.fc2 = nn.Linear(config.decoder_ffn_dim, hidden_size)
|
| 692 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
| 693 |
+
self.dropout = config.dropout
|
| 694 |
+
self.activation_dropout = config.activation_dropout
|
| 695 |
+
self.keep_query_pos = config.keep_query_pos
|
| 696 |
+
|
| 697 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 698 |
+
residual = hidden_states
|
| 699 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 700 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
| 701 |
+
hidden_states = self.fc2(hidden_states)
|
| 702 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 703 |
+
hidden_states = residual + hidden_states
|
| 704 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 705 |
+
|
| 706 |
+
return hidden_states
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
# Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig
|
| 710 |
+
class DabDetrEncoderLayer(nn.Module):
|
| 711 |
+
def __init__(self, config: DabDetrConfig):
|
| 712 |
+
super().__init__()
|
| 713 |
+
self.hidden_size = config.hidden_size
|
| 714 |
+
self.self_attn = DetrAttention(config)
|
| 715 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
|
| 716 |
+
self.dropout = config.dropout
|
| 717 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
| 718 |
+
self.fc1 = nn.Linear(self.hidden_size, config.encoder_ffn_dim)
|
| 719 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.hidden_size)
|
| 720 |
+
self.final_layer_norm = nn.LayerNorm(self.hidden_size)
|
| 721 |
+
|
| 722 |
+
def forward(
|
| 723 |
+
self,
|
| 724 |
+
hidden_states: torch.Tensor,
|
| 725 |
+
attention_mask: torch.Tensor,
|
| 726 |
+
object_queries: torch.Tensor,
|
| 727 |
+
output_attentions: Optional[bool] = None,
|
| 728 |
+
):
|
| 729 |
+
"""
|
| 730 |
+
Args:
|
| 731 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
| 732 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
| 733 |
+
`(batch, source_len)` where padding elements are indicated by very large negative
|
| 734 |
+
values.
|
| 735 |
+
object_queries (`torch.FloatTensor`, *optional*):
|
| 736 |
+
Object queries (also called content embeddings), to be added to the hidden states.
|
| 737 |
+
output_attentions (`bool`, *optional*):
|
| 738 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 739 |
+
returned tensors for more detail.
|
| 740 |
+
"""
|
| 741 |
+
residual = hidden_states
|
| 742 |
+
hidden_states, attn_weights = self.self_attn(
|
| 743 |
+
hidden_states=hidden_states,
|
| 744 |
+
attention_mask=attention_mask,
|
| 745 |
+
object_queries=object_queries,
|
| 746 |
+
output_attentions=output_attentions,
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 750 |
+
hidden_states = residual + hidden_states
|
| 751 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 752 |
+
|
| 753 |
+
residual = hidden_states
|
| 754 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 755 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 756 |
+
|
| 757 |
+
hidden_states = self.fc2(hidden_states)
|
| 758 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 759 |
+
|
| 760 |
+
hidden_states = residual + hidden_states
|
| 761 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 762 |
+
|
| 763 |
+
outputs = (hidden_states,)
|
| 764 |
+
|
| 765 |
+
if output_attentions:
|
| 766 |
+
outputs += (attn_weights,)
|
| 767 |
+
|
| 768 |
+
return outputs
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr
|
| 772 |
+
class DabDetrDecoderLayer(nn.Module):
|
| 773 |
+
def __init__(self, config: DabDetrConfig, is_first: bool = False):
|
| 774 |
+
super().__init__()
|
| 775 |
+
self.self_attn = DabDetrDecoderLayerSelfAttention(config)
|
| 776 |
+
self.cross_attn = DabDetrDecoderLayerCrossAttention(config, is_first)
|
| 777 |
+
self.mlp = DabDetrDecoderLayerFFN(config)
|
| 778 |
+
|
| 779 |
+
def forward(
|
| 780 |
+
self,
|
| 781 |
+
hidden_states: torch.Tensor,
|
| 782 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 783 |
+
object_queries: Optional[torch.Tensor] = None,
|
| 784 |
+
query_position_embeddings: Optional[torch.Tensor] = None,
|
| 785 |
+
query_sine_embed: Optional[torch.Tensor] = None,
|
| 786 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 787 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 788 |
+
output_attentions: Optional[bool] = None,
|
| 789 |
+
):
|
| 790 |
+
"""
|
| 791 |
+
Args:
|
| 792 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
| 793 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
| 794 |
+
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
| 795 |
+
values.
|
| 796 |
+
object_queries (`torch.FloatTensor`, *optional*):
|
| 797 |
+
object_queries that are added to the queries and keys
|
| 798 |
+
in the cross-attention layer.
|
| 799 |
+
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
| 800 |
+
object_queries that are added to the queries and keys
|
| 801 |
+
in the self-attention layer.
|
| 802 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
| 803 |
+
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
| 804 |
+
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
| 805 |
+
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
| 806 |
+
values.
|
| 807 |
+
output_attentions (`bool`, *optional*):
|
| 808 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 809 |
+
returned tensors for more detail.
|
| 810 |
+
|
| 811 |
+
"""
|
| 812 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 813 |
+
hidden_states=hidden_states,
|
| 814 |
+
query_position_embeddings=query_position_embeddings,
|
| 815 |
+
attention_mask=attention_mask,
|
| 816 |
+
output_attentions=output_attentions,
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
hidden_states, cross_attn_weights = self.cross_attn(
|
| 820 |
+
hidden_states=hidden_states,
|
| 821 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 822 |
+
query_position_embeddings=query_position_embeddings,
|
| 823 |
+
object_queries=object_queries,
|
| 824 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 825 |
+
query_sine_embed=query_sine_embed,
|
| 826 |
+
output_attentions=output_attentions,
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
hidden_states = self.mlp(hidden_states=hidden_states)
|
| 830 |
+
|
| 831 |
+
outputs = (hidden_states,)
|
| 832 |
+
|
| 833 |
+
if output_attentions:
|
| 834 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
| 835 |
+
|
| 836 |
+
return outputs
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
# Modified from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->DabDetrMLP
|
| 840 |
+
class DabDetrMLP(nn.Module):
|
| 841 |
+
"""
|
| 842 |
+
Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
|
| 843 |
+
height and width of a bounding box w.r.t. an image.
|
| 844 |
+
|
| 845 |
+
Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
| 846 |
+
|
| 847 |
+
"""
|
| 848 |
+
|
| 849 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
| 850 |
+
super().__init__()
|
| 851 |
+
self.num_layers = num_layers
|
| 852 |
+
h = [hidden_dim] * (num_layers - 1)
|
| 853 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
| 854 |
+
|
| 855 |
+
def forward(self, input_tensor):
|
| 856 |
+
for i, layer in enumerate(self.layers):
|
| 857 |
+
input_tensor = nn.functional.relu(layer(input_tensor)) if i < self.num_layers - 1 else layer(input_tensor)
|
| 858 |
+
return input_tensor
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
# Modified from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->DabDetr
|
| 862 |
+
class DabDetrPreTrainedModel(PreTrainedModel):
|
| 863 |
+
config_class = DabDetrConfig
|
| 864 |
+
base_model_prefix = "model"
|
| 865 |
+
main_input_name = "pixel_values"
|
| 866 |
+
_no_split_modules = [r"DabDetrConvEncoder", r"DabDetrEncoderLayer", r"DabDetrDecoderLayer"]
|
| 867 |
+
|
| 868 |
+
def _init_weights(self, module):
|
| 869 |
+
std = self.config.init_std
|
| 870 |
+
xavier_std = self.config.init_xavier_std
|
| 871 |
+
|
| 872 |
+
if isinstance(module, DabDetrMHAttentionMap):
|
| 873 |
+
nn.init.zeros_(module.k_linear.bias)
|
| 874 |
+
nn.init.zeros_(module.q_linear.bias)
|
| 875 |
+
nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
|
| 876 |
+
nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
|
| 877 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
| 878 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 879 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 880 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 881 |
+
if module.bias is not None:
|
| 882 |
+
module.bias.data.zero_()
|
| 883 |
+
elif isinstance(module, nn.Embedding):
|
| 884 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 885 |
+
if module.padding_idx is not None:
|
| 886 |
+
module.weight.data[module.padding_idx].zero_()
|
| 887 |
+
elif isinstance(module, DabDetrForObjectDetection):
|
| 888 |
+
nn.init.constant_(module.bbox_predictor.layers[-1].weight.data, 0)
|
| 889 |
+
nn.init.constant_(module.bbox_predictor.layers[-1].bias.data, 0)
|
| 890 |
+
|
| 891 |
+
# init prior_prob setting for focal loss
|
| 892 |
+
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
| 893 |
+
bias_value = -math.log((1 - prior_prob) / prior_prob)
|
| 894 |
+
module.class_embed.bias.data.fill_(bias_value)
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
DAB_DETR_START_DOCSTRING = r"""
|
| 898 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 899 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 900 |
+
etc.)
|
| 901 |
+
|
| 902 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 903 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 904 |
+
and behavior.
|
| 905 |
+
|
| 906 |
+
Parameters:
|
| 907 |
+
config ([`DabDetrConfig`]):
|
| 908 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 909 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 910 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 911 |
+
"""
|
| 912 |
+
|
| 913 |
+
DAB_DETR_INPUTS_DOCSTRING = r"""
|
| 914 |
+
Args:
|
| 915 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 916 |
+
Pixel values. Padding will be ignored by default should you provide it.
|
| 917 |
+
|
| 918 |
+
Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`]
|
| 919 |
+
for details.
|
| 920 |
+
|
| 921 |
+
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 922 |
+
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
|
| 923 |
+
|
| 924 |
+
- 1 for pixels that are real (i.e. **not masked**),
|
| 925 |
+
- 0 for pixels that are padding (i.e. **masked**).
|
| 926 |
+
|
| 927 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 928 |
+
|
| 929 |
+
decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
|
| 930 |
+
Not used by default. Can be used to mask object queries.
|
| 931 |
+
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
| 932 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
| 933 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
| 934 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
| 935 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 936 |
+
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
|
| 937 |
+
can choose to directly pass a flattened representation of an image.
|
| 938 |
+
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
| 939 |
+
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
|
| 940 |
+
embedded representation.
|
| 941 |
+
output_attentions (`bool`, *optional*):
|
| 942 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 943 |
+
tensors for more detail.
|
| 944 |
+
output_hidden_states (`bool`, *optional*):
|
| 945 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 946 |
+
more detail.
|
| 947 |
+
return_dict (`bool`, *optional*):
|
| 948 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 949 |
+
"""
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
# Modified from transformers.models.detr.modeling_detr.DetrEncoder with Detr->DabDetr,DETR->ConditionalDETR
|
| 953 |
+
class DabDetrEncoder(DabDetrPreTrainedModel):
|
| 954 |
+
"""
|
| 955 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
| 956 |
+
[`DabDetrEncoderLayer`].
|
| 957 |
+
|
| 958 |
+
The encoder updates the flattened feature map through multiple self-attention layers.
|
| 959 |
+
|
| 960 |
+
Small tweak for DAB-DETR:
|
| 961 |
+
|
| 962 |
+
- object_queries are added to the forward pass.
|
| 963 |
+
|
| 964 |
+
Args:
|
| 965 |
+
config: DabDetrConfig
|
| 966 |
+
"""
|
| 967 |
+
|
| 968 |
+
def __init__(self, config: DabDetrConfig):
|
| 969 |
+
super().__init__(config)
|
| 970 |
+
|
| 971 |
+
self.dropout = config.dropout
|
| 972 |
+
self.query_scale = DabDetrMLP(config.hidden_size, config.hidden_size, config.hidden_size, 2)
|
| 973 |
+
self.layers = nn.ModuleList([DabDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
|
| 974 |
+
self.norm = nn.LayerNorm(config.hidden_size) if config.normalize_before else None
|
| 975 |
+
self.gradient_checkpointing = False
|
| 976 |
+
|
| 977 |
+
# Initialize weights and apply final processing
|
| 978 |
+
self.post_init()
|
| 979 |
+
|
| 980 |
+
def forward(
|
| 981 |
+
self,
|
| 982 |
+
inputs_embeds,
|
| 983 |
+
attention_mask,
|
| 984 |
+
object_queries,
|
| 985 |
+
output_attentions: Optional[bool] = None,
|
| 986 |
+
output_hidden_states: Optional[bool] = None,
|
| 987 |
+
return_dict: Optional[bool] = None,
|
| 988 |
+
):
|
| 989 |
+
r"""
|
| 990 |
+
Args:
|
| 991 |
+
inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`):
|
| 992 |
+
Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
|
| 993 |
+
|
| 994 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 995 |
+
Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
|
| 996 |
+
|
| 997 |
+
- 1 for pixel features that are real (i.e. **not masked**),
|
| 998 |
+
- 0 for pixel features that are padding (i.e. **masked**).
|
| 999 |
+
|
| 1000 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1001 |
+
|
| 1002 |
+
object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`):
|
| 1003 |
+
Object queries that are added to the queries in each self-attention layer.
|
| 1004 |
+
|
| 1005 |
+
output_attentions (`bool`, *optional*):
|
| 1006 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1007 |
+
returned tensors for more detail.
|
| 1008 |
+
output_hidden_states (`bool`, *optional*):
|
| 1009 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 1010 |
+
for more detail.
|
| 1011 |
+
return_dict (`bool`, *optional*):
|
| 1012 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1013 |
+
"""
|
| 1014 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1015 |
+
output_hidden_states = (
|
| 1016 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1017 |
+
)
|
| 1018 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1019 |
+
|
| 1020 |
+
hidden_states = inputs_embeds
|
| 1021 |
+
|
| 1022 |
+
# expand attention_mask
|
| 1023 |
+
if attention_mask is not None:
|
| 1024 |
+
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
|
| 1025 |
+
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
| 1026 |
+
|
| 1027 |
+
encoder_states = () if output_hidden_states else None
|
| 1028 |
+
all_attentions = () if output_attentions else None
|
| 1029 |
+
|
| 1030 |
+
for encoder_layer in self.layers:
|
| 1031 |
+
if output_hidden_states:
|
| 1032 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 1033 |
+
# pos scaler
|
| 1034 |
+
pos_scales = self.query_scale(hidden_states)
|
| 1035 |
+
# we add object_queries * pos_scaler as extra input to the encoder_layer
|
| 1036 |
+
scaled_object_queries = object_queries * pos_scales
|
| 1037 |
+
|
| 1038 |
+
if self.gradient_checkpointing and self.training:
|
| 1039 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1040 |
+
encoder_layer.__call__,
|
| 1041 |
+
hidden_states,
|
| 1042 |
+
attention_mask,
|
| 1043 |
+
scaled_object_queries,
|
| 1044 |
+
output_attentions,
|
| 1045 |
+
)
|
| 1046 |
+
else:
|
| 1047 |
+
layer_outputs = encoder_layer(
|
| 1048 |
+
hidden_states,
|
| 1049 |
+
attention_mask=attention_mask,
|
| 1050 |
+
object_queries=scaled_object_queries,
|
| 1051 |
+
output_attentions=output_attentions,
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
hidden_states = layer_outputs[0]
|
| 1055 |
+
|
| 1056 |
+
if output_attentions:
|
| 1057 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 1058 |
+
|
| 1059 |
+
if self.norm:
|
| 1060 |
+
hidden_states = self.norm(hidden_states)
|
| 1061 |
+
|
| 1062 |
+
if output_hidden_states:
|
| 1063 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 1064 |
+
|
| 1065 |
+
if not return_dict:
|
| 1066 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
| 1067 |
+
return BaseModelOutput(
|
| 1068 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
| 1069 |
+
)
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
# Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoder with ConditionalDetr->DabDetr,Conditional DETR->DAB-DETR
|
| 1073 |
+
class DabDetrDecoder(DabDetrPreTrainedModel):
|
| 1074 |
+
"""
|
| 1075 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DabDetrDecoderLayer`].
|
| 1076 |
+
|
| 1077 |
+
The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
|
| 1078 |
+
|
| 1079 |
+
Some small tweaks for DAB-DETR:
|
| 1080 |
+
|
| 1081 |
+
- object_queries and query_position_embeddings are added to the forward pass.
|
| 1082 |
+
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
| 1083 |
+
|
| 1084 |
+
Args:
|
| 1085 |
+
config: DabDetrConfig
|
| 1086 |
+
"""
|
| 1087 |
+
|
| 1088 |
+
def __init__(self, config: DabDetrConfig):
|
| 1089 |
+
super().__init__(config)
|
| 1090 |
+
self.config = config
|
| 1091 |
+
self.dropout = config.dropout
|
| 1092 |
+
self.num_layers = config.decoder_layers
|
| 1093 |
+
self.gradient_checkpointing = False
|
| 1094 |
+
|
| 1095 |
+
self.layers = nn.ModuleList(
|
| 1096 |
+
[DabDetrDecoderLayer(config, is_first=(layer_id == 0)) for layer_id in range(config.decoder_layers)]
|
| 1097 |
+
)
|
| 1098 |
+
# in DAB-DETR, the decoder uses layernorm after the last decoder layer output
|
| 1099 |
+
self.hidden_size = config.hidden_size
|
| 1100 |
+
self.layernorm = nn.LayerNorm(self.hidden_size)
|
| 1101 |
+
|
| 1102 |
+
# Default cond-elewise
|
| 1103 |
+
self.query_scale = DabDetrMLP(self.hidden_size, self.hidden_size, self.hidden_size, 2)
|
| 1104 |
+
|
| 1105 |
+
self.ref_point_head = DabDetrMLP(
|
| 1106 |
+
config.query_dim // 2 * self.hidden_size, self.hidden_size, self.hidden_size, 2
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
self.bbox_embed = None
|
| 1110 |
+
|
| 1111 |
+
# Default decoder_modulate_hw_attn is True
|
| 1112 |
+
self.ref_anchor_head = DabDetrMLP(self.hidden_size, self.hidden_size, 2, 2)
|
| 1113 |
+
|
| 1114 |
+
# Initialize weights and apply final processing
|
| 1115 |
+
self.post_init()
|
| 1116 |
+
|
| 1117 |
+
def forward(
|
| 1118 |
+
self,
|
| 1119 |
+
inputs_embeds,
|
| 1120 |
+
encoder_hidden_states,
|
| 1121 |
+
memory_key_padding_mask,
|
| 1122 |
+
object_queries,
|
| 1123 |
+
query_position_embeddings,
|
| 1124 |
+
output_attentions: Optional[bool] = None,
|
| 1125 |
+
output_hidden_states: Optional[bool] = None,
|
| 1126 |
+
return_dict: Optional[bool] = None,
|
| 1127 |
+
):
|
| 1128 |
+
r"""
|
| 1129 |
+
Args:
|
| 1130 |
+
inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`):
|
| 1131 |
+
The query embeddings that are passed into the decoder.
|
| 1132 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(encoder_sequence_length, batch_size, hidden_size)`, *optional*):
|
| 1133 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
| 1134 |
+
of the decoder.
|
| 1135 |
+
memory_key_padding_mask (`torch.Tensor.bool` of shape `(batch_size, sequence_length)`):
|
| 1136 |
+
The memory_key_padding_mask indicates which positions in the memory (encoder outputs) should be ignored during the attention computation,
|
| 1137 |
+
ensuring padding tokens do not influence the attention mechanism.
|
| 1138 |
+
object_queries (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`, *optional*):
|
| 1139 |
+
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
| 1140 |
+
query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, number_of_anchor_points)`):
|
| 1141 |
+
Position embeddings that are added to the queries and keys in each self-attention layer.
|
| 1142 |
+
output_attentions (`bool`, *optional*):
|
| 1143 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1144 |
+
returned tensors for more detail.
|
| 1145 |
+
output_hidden_states (`bool`, *optional*):
|
| 1146 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 1147 |
+
for more detail.
|
| 1148 |
+
return_dict (`bool`, *optional*):
|
| 1149 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1150 |
+
"""
|
| 1151 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1152 |
+
output_hidden_states = (
|
| 1153 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1154 |
+
)
|
| 1155 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1156 |
+
|
| 1157 |
+
if inputs_embeds is not None:
|
| 1158 |
+
hidden_states = inputs_embeds
|
| 1159 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 1160 |
+
|
| 1161 |
+
# decoder layers
|
| 1162 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1163 |
+
all_self_attns = () if output_attentions else None
|
| 1164 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
| 1165 |
+
|
| 1166 |
+
intermediate = []
|
| 1167 |
+
reference_points = query_position_embeddings.sigmoid()
|
| 1168 |
+
ref_points = [reference_points]
|
| 1169 |
+
|
| 1170 |
+
# expand encoder attention mask
|
| 1171 |
+
if encoder_hidden_states is not None and memory_key_padding_mask is not None:
|
| 1172 |
+
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
|
| 1173 |
+
memory_key_padding_mask = _prepare_4d_attention_mask(
|
| 1174 |
+
memory_key_padding_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
for layer_id, decoder_layer in enumerate(self.layers):
|
| 1178 |
+
if output_hidden_states:
|
| 1179 |
+
all_hidden_states += (hidden_states,)
|
| 1180 |
+
|
| 1181 |
+
obj_center = reference_points[..., : self.config.query_dim]
|
| 1182 |
+
query_sine_embed = gen_sine_position_embeddings(obj_center, self.hidden_size)
|
| 1183 |
+
query_pos = self.ref_point_head(query_sine_embed)
|
| 1184 |
+
|
| 1185 |
+
# For the first decoder layer, we do not apply transformation over p_s
|
| 1186 |
+
pos_transformation = 1 if layer_id == 0 else self.query_scale(hidden_states)
|
| 1187 |
+
|
| 1188 |
+
# apply transformation
|
| 1189 |
+
query_sine_embed = query_sine_embed[..., : self.hidden_size] * pos_transformation
|
| 1190 |
+
|
| 1191 |
+
# modulated Height Width attentions
|
| 1192 |
+
reference_anchor_size = self.ref_anchor_head(hidden_states).sigmoid() # nq, bs, 2
|
| 1193 |
+
query_sine_embed[..., self.hidden_size // 2 :] *= (
|
| 1194 |
+
reference_anchor_size[..., 0] / obj_center[..., 2]
|
| 1195 |
+
).unsqueeze(-1)
|
| 1196 |
+
query_sine_embed[..., : self.hidden_size // 2] *= (
|
| 1197 |
+
reference_anchor_size[..., 1] / obj_center[..., 3]
|
| 1198 |
+
).unsqueeze(-1)
|
| 1199 |
+
|
| 1200 |
+
if self.gradient_checkpointing and self.training:
|
| 1201 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1202 |
+
decoder_layer.__call__,
|
| 1203 |
+
hidden_states,
|
| 1204 |
+
None,
|
| 1205 |
+
object_queries,
|
| 1206 |
+
query_pos,
|
| 1207 |
+
query_sine_embed,
|
| 1208 |
+
encoder_hidden_states,
|
| 1209 |
+
memory_key_padding_mask,
|
| 1210 |
+
output_attentions,
|
| 1211 |
+
)
|
| 1212 |
+
else:
|
| 1213 |
+
layer_outputs = decoder_layer(
|
| 1214 |
+
hidden_states,
|
| 1215 |
+
attention_mask=None,
|
| 1216 |
+
object_queries=object_queries,
|
| 1217 |
+
query_position_embeddings=query_pos,
|
| 1218 |
+
query_sine_embed=query_sine_embed,
|
| 1219 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1220 |
+
encoder_attention_mask=memory_key_padding_mask,
|
| 1221 |
+
output_attentions=output_attentions,
|
| 1222 |
+
)
|
| 1223 |
+
|
| 1224 |
+
# iter update
|
| 1225 |
+
hidden_states = layer_outputs[0]
|
| 1226 |
+
|
| 1227 |
+
if self.bbox_embed is not None:
|
| 1228 |
+
new_reference_points = self.bbox_embed(hidden_states)
|
| 1229 |
+
|
| 1230 |
+
new_reference_points[..., : self.config.query_dim] += inverse_sigmoid(reference_points)
|
| 1231 |
+
new_reference_points = new_reference_points[..., : self.config.query_dim].sigmoid()
|
| 1232 |
+
if layer_id != self.num_layers - 1:
|
| 1233 |
+
ref_points.append(new_reference_points)
|
| 1234 |
+
reference_points = new_reference_points.detach()
|
| 1235 |
+
|
| 1236 |
+
intermediate.append(self.layernorm(hidden_states))
|
| 1237 |
+
|
| 1238 |
+
if output_attentions:
|
| 1239 |
+
all_self_attns += (layer_outputs[1],)
|
| 1240 |
+
|
| 1241 |
+
if encoder_hidden_states is not None:
|
| 1242 |
+
all_cross_attentions += (layer_outputs[2],)
|
| 1243 |
+
|
| 1244 |
+
# Layer normalization on hidden states
|
| 1245 |
+
hidden_states = self.layernorm(hidden_states)
|
| 1246 |
+
|
| 1247 |
+
if output_hidden_states:
|
| 1248 |
+
all_hidden_states += (hidden_states,)
|
| 1249 |
+
|
| 1250 |
+
output_intermediate_hidden_states = torch.stack(intermediate)
|
| 1251 |
+
output_reference_points = torch.stack(ref_points)
|
| 1252 |
+
|
| 1253 |
+
if not return_dict:
|
| 1254 |
+
return tuple(
|
| 1255 |
+
v
|
| 1256 |
+
for v in [
|
| 1257 |
+
hidden_states,
|
| 1258 |
+
all_hidden_states,
|
| 1259 |
+
all_self_attns,
|
| 1260 |
+
all_cross_attentions,
|
| 1261 |
+
output_intermediate_hidden_states,
|
| 1262 |
+
output_reference_points,
|
| 1263 |
+
]
|
| 1264 |
+
if v is not None
|
| 1265 |
+
)
|
| 1266 |
+
return DabDetrDecoderOutput(
|
| 1267 |
+
last_hidden_state=hidden_states,
|
| 1268 |
+
hidden_states=all_hidden_states,
|
| 1269 |
+
attentions=all_self_attns,
|
| 1270 |
+
cross_attentions=all_cross_attentions,
|
| 1271 |
+
intermediate_hidden_states=output_intermediate_hidden_states,
|
| 1272 |
+
reference_points=output_reference_points,
|
| 1273 |
+
)
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
@add_start_docstrings(
|
| 1277 |
+
"""
|
| 1278 |
+
The bare DAB-DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
|
| 1279 |
+
hidden-states, intermediate hidden states, reference points, output coordinates without any specific head on top.
|
| 1280 |
+
""",
|
| 1281 |
+
DAB_DETR_START_DOCSTRING,
|
| 1282 |
+
)
|
| 1283 |
+
class DabDetrModel(DabDetrPreTrainedModel):
|
| 1284 |
+
def __init__(self, config: DabDetrConfig):
|
| 1285 |
+
super().__init__(config)
|
| 1286 |
+
|
| 1287 |
+
self.auxiliary_loss = config.auxiliary_loss
|
| 1288 |
+
|
| 1289 |
+
# Create backbone + positional encoding
|
| 1290 |
+
self.backbone = DabDetrConvEncoder(config)
|
| 1291 |
+
object_queries = DabDetrSinePositionEmbedding(config)
|
| 1292 |
+
|
| 1293 |
+
self.query_refpoint_embeddings = nn.Embedding(config.num_queries, config.query_dim)
|
| 1294 |
+
self.random_refpoints_xy = config.random_refpoints_xy
|
| 1295 |
+
if self.random_refpoints_xy:
|
| 1296 |
+
self.query_refpoint_embeddings.weight.data[:, :2].uniform_(0, 1)
|
| 1297 |
+
self.query_refpoint_embeddings.weight.data[:, :2] = inverse_sigmoid(
|
| 1298 |
+
self.query_refpoint_embeddings.weight.data[:, :2]
|
| 1299 |
+
)
|
| 1300 |
+
self.query_refpoint_embeddings.weight.data[:, :2].requires_grad = False
|
| 1301 |
+
|
| 1302 |
+
# Create projection layer
|
| 1303 |
+
self.input_projection = nn.Conv2d(
|
| 1304 |
+
self.backbone.intermediate_channel_sizes[-1], config.hidden_size, kernel_size=1
|
| 1305 |
+
)
|
| 1306 |
+
self.backbone = DabDetrConvModel(self.backbone, object_queries)
|
| 1307 |
+
|
| 1308 |
+
self.encoder = DabDetrEncoder(config)
|
| 1309 |
+
self.decoder = DabDetrDecoder(config)
|
| 1310 |
+
|
| 1311 |
+
# decoder related variables
|
| 1312 |
+
self.hidden_size = config.hidden_size
|
| 1313 |
+
self.num_queries = config.num_queries
|
| 1314 |
+
|
| 1315 |
+
self.num_patterns = config.num_patterns
|
| 1316 |
+
if not isinstance(self.num_patterns, int):
|
| 1317 |
+
logger.warning("num_patterns should be int but {}".format(type(self.num_patterns)))
|
| 1318 |
+
self.num_patterns = 0
|
| 1319 |
+
if self.num_patterns > 0:
|
| 1320 |
+
self.patterns = nn.Embedding(self.num_patterns, self.hidden_size)
|
| 1321 |
+
|
| 1322 |
+
self.aux_loss = config.auxiliary_loss
|
| 1323 |
+
|
| 1324 |
+
# Initialize weights and apply final processing
|
| 1325 |
+
self.post_init()
|
| 1326 |
+
|
| 1327 |
+
def get_encoder(self):
|
| 1328 |
+
return self.encoder
|
| 1329 |
+
|
| 1330 |
+
def get_decoder(self):
|
| 1331 |
+
return self.decoder
|
| 1332 |
+
|
| 1333 |
+
def freeze_backbone(self):
|
| 1334 |
+
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
| 1335 |
+
param.requires_grad_(False)
|
| 1336 |
+
|
| 1337 |
+
def unfreeze_backbone(self):
|
| 1338 |
+
for name, param in self.backbone.conv_encoder.model.named_parameters():
|
| 1339 |
+
param.requires_grad_(True)
|
| 1340 |
+
|
| 1341 |
+
@add_start_docstrings_to_model_forward(DAB_DETR_INPUTS_DOCSTRING)
|
| 1342 |
+
@replace_return_docstrings(output_type=DabDetrModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 1343 |
+
def forward(
|
| 1344 |
+
self,
|
| 1345 |
+
pixel_values: torch.FloatTensor,
|
| 1346 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1347 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 1348 |
+
encoder_outputs: Optional[torch.FloatTensor] = None,
|
| 1349 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1350 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1351 |
+
output_attentions: Optional[bool] = None,
|
| 1352 |
+
output_hidden_states: Optional[bool] = None,
|
| 1353 |
+
return_dict: Optional[bool] = None,
|
| 1354 |
+
) -> Union[Tuple[torch.FloatTensor], DabDetrModelOutput]:
|
| 1355 |
+
r"""
|
| 1356 |
+
Returns:
|
| 1357 |
+
|
| 1358 |
+
Examples:
|
| 1359 |
+
|
| 1360 |
+
```python
|
| 1361 |
+
>>> from transformers import AutoImageProcessor, AutoModel
|
| 1362 |
+
>>> from PIL import Image
|
| 1363 |
+
>>> import requests
|
| 1364 |
+
|
| 1365 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1366 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1367 |
+
|
| 1368 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab_detr-base")
|
| 1369 |
+
>>> model = AutoModel.from_pretrained("IDEA-Research/dab_detr-base")
|
| 1370 |
+
|
| 1371 |
+
>>> # prepare image for the model
|
| 1372 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 1373 |
+
|
| 1374 |
+
>>> # forward pass
|
| 1375 |
+
>>> outputs = model(**inputs)
|
| 1376 |
+
|
| 1377 |
+
>>> # the last hidden states are the final query embeddings of the Transformer decoder
|
| 1378 |
+
>>> # these are of shape (batch_size, num_queries, hidden_size)
|
| 1379 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 1380 |
+
>>> list(last_hidden_states.shape)
|
| 1381 |
+
[1, 300, 256]
|
| 1382 |
+
```"""
|
| 1383 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1384 |
+
output_hidden_states = (
|
| 1385 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1386 |
+
)
|
| 1387 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1388 |
+
|
| 1389 |
+
batch_size, _, height, width = pixel_values.shape
|
| 1390 |
+
device = pixel_values.device
|
| 1391 |
+
|
| 1392 |
+
if pixel_mask is None:
|
| 1393 |
+
pixel_mask = torch.ones(((batch_size, height, width)), device=device)
|
| 1394 |
+
|
| 1395 |
+
# First, sent pixel_values + pixel_mask through Backbone to obtain the features
|
| 1396 |
+
# pixel_values should be of shape (batch_size, num_channels, height, width)
|
| 1397 |
+
# pixel_mask should be of shape (batch_size, height, width)
|
| 1398 |
+
features, object_queries_list = self.backbone(pixel_values, pixel_mask)
|
| 1399 |
+
|
| 1400 |
+
# get final feature map and downsampled mask
|
| 1401 |
+
feature_map, mask = features[-1]
|
| 1402 |
+
|
| 1403 |
+
if mask is None:
|
| 1404 |
+
raise ValueError("Backbone does not return downsampled pixel mask")
|
| 1405 |
+
|
| 1406 |
+
flattened_mask = mask.flatten(1)
|
| 1407 |
+
|
| 1408 |
+
# Second, apply 1x1 convolution to reduce the channel dimension to hidden_size (256 by default)
|
| 1409 |
+
projected_feature_map = self.input_projection(feature_map)
|
| 1410 |
+
|
| 1411 |
+
# Third, flatten the feature map + object_queries of shape NxCxHxW to HWxNxC, and permute it to NxHWxC
|
| 1412 |
+
# In other words, turn their shape into ( sequence_length, batch_size, hidden_size)
|
| 1413 |
+
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
| 1414 |
+
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
|
| 1415 |
+
reference_position_embeddings = self.query_refpoint_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
| 1416 |
+
|
| 1417 |
+
# Fourth, sent flattened_features + flattened_mask + object_queries through encoder
|
| 1418 |
+
# flattened_features is a Tensor of shape (heigth*width, batch_size, hidden_size)
|
| 1419 |
+
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
| 1420 |
+
if encoder_outputs is None:
|
| 1421 |
+
encoder_outputs = self.encoder(
|
| 1422 |
+
inputs_embeds=flattened_features,
|
| 1423 |
+
attention_mask=flattened_mask,
|
| 1424 |
+
object_queries=object_queries,
|
| 1425 |
+
output_attentions=output_attentions,
|
| 1426 |
+
output_hidden_states=output_hidden_states,
|
| 1427 |
+
return_dict=return_dict,
|
| 1428 |
+
)
|
| 1429 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
| 1430 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
| 1431 |
+
encoder_outputs = BaseModelOutput(
|
| 1432 |
+
last_hidden_state=encoder_outputs[0],
|
| 1433 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
| 1434 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
| 1435 |
+
)
|
| 1436 |
+
|
| 1437 |
+
# Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
|
| 1438 |
+
num_queries = reference_position_embeddings.shape[1]
|
| 1439 |
+
if self.num_patterns == 0:
|
| 1440 |
+
queries = torch.zeros(batch_size, num_queries, self.hidden_size, device=device)
|
| 1441 |
+
else:
|
| 1442 |
+
queries = (
|
| 1443 |
+
self.patterns.weight[:, None, None, :]
|
| 1444 |
+
.repeat(1, self.num_queries, batch_size, 1)
|
| 1445 |
+
.flatten(0, 1)
|
| 1446 |
+
.permute(1, 0, 2)
|
| 1447 |
+
) # bs, n_q*n_pat, hidden_size
|
| 1448 |
+
reference_position_embeddings = reference_position_embeddings.repeat(
|
| 1449 |
+
1, self.num_patterns, 1
|
| 1450 |
+
) # bs, n_q*n_pat, hidden_size
|
| 1451 |
+
|
| 1452 |
+
# decoder outputs consists of (dec_features, dec_hidden, dec_attn)
|
| 1453 |
+
decoder_outputs = self.decoder(
|
| 1454 |
+
inputs_embeds=queries,
|
| 1455 |
+
query_position_embeddings=reference_position_embeddings,
|
| 1456 |
+
object_queries=object_queries,
|
| 1457 |
+
encoder_hidden_states=encoder_outputs[0],
|
| 1458 |
+
memory_key_padding_mask=flattened_mask,
|
| 1459 |
+
output_attentions=output_attentions,
|
| 1460 |
+
output_hidden_states=output_hidden_states,
|
| 1461 |
+
return_dict=return_dict,
|
| 1462 |
+
)
|
| 1463 |
+
|
| 1464 |
+
if not return_dict:
|
| 1465 |
+
# last_hidden_state
|
| 1466 |
+
output = (decoder_outputs[0],)
|
| 1467 |
+
reference_points = decoder_outputs[-1]
|
| 1468 |
+
intermediate_hidden_states = decoder_outputs[-2]
|
| 1469 |
+
|
| 1470 |
+
# it has to follow the order of DABDETRModelOutput that is based on ModelOutput
|
| 1471 |
+
# If we only use one of the variables then the indexing will change.
|
| 1472 |
+
# E.g: if we return everything then 'decoder_attentions' is decoder_outputs[2], if we only use output_attentions then its decoder_outputs[1]
|
| 1473 |
+
if output_hidden_states and output_attentions:
|
| 1474 |
+
output += (
|
| 1475 |
+
decoder_outputs[1],
|
| 1476 |
+
decoder_outputs[2],
|
| 1477 |
+
decoder_outputs[3],
|
| 1478 |
+
encoder_outputs[0],
|
| 1479 |
+
encoder_outputs[1],
|
| 1480 |
+
encoder_outputs[2],
|
| 1481 |
+
)
|
| 1482 |
+
elif output_hidden_states:
|
| 1483 |
+
# decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states
|
| 1484 |
+
output += (
|
| 1485 |
+
decoder_outputs[1],
|
| 1486 |
+
encoder_outputs[0],
|
| 1487 |
+
encoder_outputs[1],
|
| 1488 |
+
)
|
| 1489 |
+
elif output_attentions:
|
| 1490 |
+
# decoder_self_attention, decoder_cross_attention, encoder_attentions
|
| 1491 |
+
output += (
|
| 1492 |
+
decoder_outputs[1],
|
| 1493 |
+
decoder_outputs[2],
|
| 1494 |
+
encoder_outputs[1],
|
| 1495 |
+
)
|
| 1496 |
+
|
| 1497 |
+
output += (intermediate_hidden_states, reference_points)
|
| 1498 |
+
|
| 1499 |
+
return output
|
| 1500 |
+
|
| 1501 |
+
reference_points = decoder_outputs.reference_points
|
| 1502 |
+
intermediate_hidden_states = decoder_outputs.intermediate_hidden_states
|
| 1503 |
+
|
| 1504 |
+
return DabDetrModelOutput(
|
| 1505 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
| 1506 |
+
decoder_hidden_states=decoder_outputs.hidden_states if output_hidden_states else None,
|
| 1507 |
+
decoder_attentions=decoder_outputs.attentions if output_attentions else None,
|
| 1508 |
+
cross_attentions=decoder_outputs.cross_attentions if output_attentions else None,
|
| 1509 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state if output_hidden_states else None,
|
| 1510 |
+
encoder_hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
|
| 1511 |
+
encoder_attentions=encoder_outputs.attentions if output_attentions else None,
|
| 1512 |
+
intermediate_hidden_states=intermediate_hidden_states,
|
| 1513 |
+
reference_points=reference_points,
|
| 1514 |
+
)
|
| 1515 |
+
|
| 1516 |
+
|
| 1517 |
+
# Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->DabDetr
|
| 1518 |
+
class DabDetrMHAttentionMap(nn.Module):
|
| 1519 |
+
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
|
| 1520 |
+
|
| 1521 |
+
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
|
| 1522 |
+
super().__init__()
|
| 1523 |
+
self.num_heads = num_heads
|
| 1524 |
+
self.hidden_dim = hidden_dim
|
| 1525 |
+
self.dropout = nn.Dropout(dropout)
|
| 1526 |
+
|
| 1527 |
+
self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 1528 |
+
self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 1529 |
+
|
| 1530 |
+
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
|
| 1531 |
+
|
| 1532 |
+
def forward(self, q, k, mask: Optional[Tensor] = None):
|
| 1533 |
+
q = self.q_linear(q)
|
| 1534 |
+
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
|
| 1535 |
+
queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
|
| 1536 |
+
keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
|
| 1537 |
+
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
|
| 1538 |
+
|
| 1539 |
+
if mask is not None:
|
| 1540 |
+
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
|
| 1541 |
+
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
|
| 1542 |
+
weights = self.dropout(weights)
|
| 1543 |
+
return weights
|
| 1544 |
+
|
| 1545 |
+
|
| 1546 |
+
@add_start_docstrings(
|
| 1547 |
+
"""
|
| 1548 |
+
DAB_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
|
| 1549 |
+
top, for tasks such as COCO detection.
|
| 1550 |
+
""",
|
| 1551 |
+
DAB_DETR_START_DOCSTRING,
|
| 1552 |
+
)
|
| 1553 |
+
class DabDetrForObjectDetection(DabDetrPreTrainedModel):
|
| 1554 |
+
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
| 1555 |
+
_tied_weights_keys = [
|
| 1556 |
+
r"bbox_predictor\.layers\.\d+\.(weight|bias)",
|
| 1557 |
+
r"model\.decoder\.bbox_embed\.layers\.\d+\.(weight|bias)",
|
| 1558 |
+
]
|
| 1559 |
+
|
| 1560 |
+
def __init__(self, config: DabDetrConfig):
|
| 1561 |
+
super().__init__(config)
|
| 1562 |
+
|
| 1563 |
+
self.config = config
|
| 1564 |
+
self.auxiliary_loss = config.auxiliary_loss
|
| 1565 |
+
self.query_dim = config.query_dim
|
| 1566 |
+
# DAB-DETR encoder-decoder model
|
| 1567 |
+
self.model = DabDetrModel(config)
|
| 1568 |
+
|
| 1569 |
+
_bbox_embed = DabDetrMLP(config.hidden_size, config.hidden_size, 4, 3)
|
| 1570 |
+
# Object detection heads
|
| 1571 |
+
self.class_embed = nn.Linear(config.hidden_size, config.num_labels)
|
| 1572 |
+
|
| 1573 |
+
# Default bbox_embed_diff_each_layer is False
|
| 1574 |
+
self.bbox_predictor = _bbox_embed
|
| 1575 |
+
|
| 1576 |
+
# Default iter_update is True
|
| 1577 |
+
self.model.decoder.bbox_embed = self.bbox_predictor
|
| 1578 |
+
|
| 1579 |
+
# Initialize weights and apply final processing
|
| 1580 |
+
self.post_init()
|
| 1581 |
+
|
| 1582 |
+
# taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/dab_detr.py
|
| 1583 |
+
@torch.jit.unused
|
| 1584 |
+
def _set_aux_loss(self, outputs_class, outputs_coord):
|
| 1585 |
+
# this is a workaround to make torchscript happy, as torchscript
|
| 1586 |
+
# doesn't support dictionary with non-homogeneous values, such
|
| 1587 |
+
# as a dict having both a Tensor and a list.
|
| 1588 |
+
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
| 1589 |
+
|
| 1590 |
+
@add_start_docstrings_to_model_forward(DAB_DETR_INPUTS_DOCSTRING)
|
| 1591 |
+
@replace_return_docstrings(output_type=DabDetrObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
| 1592 |
+
def forward(
|
| 1593 |
+
self,
|
| 1594 |
+
pixel_values: torch.FloatTensor,
|
| 1595 |
+
pixel_mask: Optional[torch.LongTensor] = None,
|
| 1596 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 1597 |
+
encoder_outputs: Optional[torch.FloatTensor] = None,
|
| 1598 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1599 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1600 |
+
labels: Optional[List[dict]] = None,
|
| 1601 |
+
output_attentions: Optional[bool] = None,
|
| 1602 |
+
output_hidden_states: Optional[bool] = None,
|
| 1603 |
+
return_dict: Optional[bool] = None,
|
| 1604 |
+
) -> Union[Tuple[torch.FloatTensor], DabDetrObjectDetectionOutput]:
|
| 1605 |
+
r"""
|
| 1606 |
+
labels (`List[Dict]` of len `(batch_size,)`, *optional*):
|
| 1607 |
+
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
|
| 1608 |
+
following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
|
| 1609 |
+
respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
|
| 1610 |
+
in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
|
| 1611 |
+
|
| 1612 |
+
Returns:
|
| 1613 |
+
|
| 1614 |
+
Examples:
|
| 1615 |
+
|
| 1616 |
+
```python
|
| 1617 |
+
>>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
|
| 1618 |
+
>>> from PIL import Image
|
| 1619 |
+
>>> import requests
|
| 1620 |
+
|
| 1621 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1622 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1623 |
+
|
| 1624 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
|
| 1625 |
+
>>> model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
|
| 1626 |
+
|
| 1627 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 1628 |
+
|
| 1629 |
+
>>> with torch.no_grad():
|
| 1630 |
+
>>> outputs = model(**inputs)
|
| 1631 |
+
|
| 1632 |
+
>>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
|
| 1633 |
+
>>> target_sizes = torch.tensor([(image.height, image.width)])
|
| 1634 |
+
>>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
|
| 1635 |
+
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
| 1636 |
+
... box = [round(i, 2) for i in box.tolist()]
|
| 1637 |
+
... print(
|
| 1638 |
+
... f"Detected {model.config.id2label[label.item()]} with confidence "
|
| 1639 |
+
... f"{round(score.item(), 3)} at location {box}"
|
| 1640 |
+
... )
|
| 1641 |
+
Detected remote with confidence 0.833 at location [38.31, 72.1, 177.63, 118.45]
|
| 1642 |
+
Detected cat with confidence 0.831 at location [9.2, 51.38, 321.13, 469.0]
|
| 1643 |
+
Detected cat with confidence 0.804 at location [340.3, 16.85, 642.93, 370.95]
|
| 1644 |
+
Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
|
| 1645 |
+
Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
|
| 1646 |
+
```"""
|
| 1647 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1648 |
+
output_hidden_states = (
|
| 1649 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1650 |
+
)
|
| 1651 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1652 |
+
|
| 1653 |
+
# First, sent images through DAB_DETR base model to obtain encoder + decoder outputs
|
| 1654 |
+
model_outputs = self.model(
|
| 1655 |
+
pixel_values,
|
| 1656 |
+
pixel_mask=pixel_mask,
|
| 1657 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 1658 |
+
encoder_outputs=encoder_outputs,
|
| 1659 |
+
inputs_embeds=inputs_embeds,
|
| 1660 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1661 |
+
output_attentions=output_attentions,
|
| 1662 |
+
output_hidden_states=output_hidden_states,
|
| 1663 |
+
return_dict=return_dict,
|
| 1664 |
+
)
|
| 1665 |
+
|
| 1666 |
+
reference_points = model_outputs.reference_points if return_dict else model_outputs[-1]
|
| 1667 |
+
intermediate_hidden_states = model_outputs.intermediate_hidden_states if return_dict else model_outputs[-2]
|
| 1668 |
+
|
| 1669 |
+
# class logits + predicted bounding boxes
|
| 1670 |
+
logits = self.class_embed(intermediate_hidden_states[-1])
|
| 1671 |
+
|
| 1672 |
+
reference_before_sigmoid = inverse_sigmoid(reference_points)
|
| 1673 |
+
bbox_with_refinement = self.bbox_predictor(intermediate_hidden_states)
|
| 1674 |
+
bbox_with_refinement[..., : self.query_dim] += reference_before_sigmoid
|
| 1675 |
+
outputs_coord = bbox_with_refinement.sigmoid()
|
| 1676 |
+
|
| 1677 |
+
pred_boxes = outputs_coord[-1]
|
| 1678 |
+
|
| 1679 |
+
loss, loss_dict, auxiliary_outputs = None, None, None
|
| 1680 |
+
if labels is not None:
|
| 1681 |
+
outputs_class = None
|
| 1682 |
+
if self.config.auxiliary_loss:
|
| 1683 |
+
outputs_class = self.class_embed(intermediate_hidden_states)
|
| 1684 |
+
loss, loss_dict, auxiliary_outputs = self.loss_function(
|
| 1685 |
+
logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
|
| 1686 |
+
)
|
| 1687 |
+
|
| 1688 |
+
if not return_dict:
|
| 1689 |
+
if auxiliary_outputs is not None:
|
| 1690 |
+
output = (logits, pred_boxes) + auxiliary_outputs + model_outputs
|
| 1691 |
+
else:
|
| 1692 |
+
output = (logits, pred_boxes) + model_outputs
|
| 1693 |
+
# Since DabDetrObjectDetectionOutput doesn't have reference points + intermedieate_hidden_states we cut down.
|
| 1694 |
+
return ((loss, loss_dict) + output) if loss is not None else output[:-2]
|
| 1695 |
+
|
| 1696 |
+
return DabDetrObjectDetectionOutput(
|
| 1697 |
+
loss=loss,
|
| 1698 |
+
loss_dict=loss_dict,
|
| 1699 |
+
logits=logits,
|
| 1700 |
+
pred_boxes=pred_boxes,
|
| 1701 |
+
auxiliary_outputs=auxiliary_outputs,
|
| 1702 |
+
last_hidden_state=model_outputs.last_hidden_state,
|
| 1703 |
+
decoder_hidden_states=model_outputs.decoder_hidden_states if output_hidden_states else None,
|
| 1704 |
+
decoder_attentions=model_outputs.decoder_attentions if output_attentions else None,
|
| 1705 |
+
cross_attentions=model_outputs.cross_attentions if output_attentions else None,
|
| 1706 |
+
encoder_last_hidden_state=model_outputs.encoder_last_hidden_state if output_hidden_states else None,
|
| 1707 |
+
encoder_hidden_states=model_outputs.encoder_hidden_states if output_hidden_states else None,
|
| 1708 |
+
encoder_attentions=model_outputs.encoder_attentions if output_attentions else None,
|
| 1709 |
+
)
|
| 1710 |
+
|
| 1711 |
+
|
| 1712 |
+
__all__ = [
|
| 1713 |
+
"DabDetrForObjectDetection",
|
| 1714 |
+
"DabDetrModel",
|
| 1715 |
+
"DabDetrPreTrainedModel",
|
| 1716 |
+
]
|
docs/transformers/build/lib/transformers/models/dac/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_dac import *
|
| 22 |
+
from .feature_extraction_dac import *
|
| 23 |
+
from .modeling_dac import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/dac/configuration_dac.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Dac model configuration"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...configuration_utils import PretrainedConfig
|
| 22 |
+
from ...utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DacConfig(PretrainedConfig):
|
| 29 |
+
r"""
|
| 30 |
+
This is the configuration class to store the configuration of an [`DacModel`]. It is used to instantiate a
|
| 31 |
+
Dac model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 32 |
+
with the defaults will yield a similar configuration to that of the
|
| 33 |
+
[descript/dac_16khz](https://huggingface.co/descript/dac_16khz) architecture.
|
| 34 |
+
|
| 35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 36 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
encoder_hidden_size (`int`, *optional*, defaults to 64):
|
| 40 |
+
Intermediate representation dimension for the encoder.
|
| 41 |
+
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 8, 8]`):
|
| 42 |
+
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
|
| 43 |
+
decoder_hidden_size (`int`, *optional*, defaults to 1536):
|
| 44 |
+
Intermediate representation dimension for the decoder.
|
| 45 |
+
n_codebooks (`int`, *optional*, defaults to 9):
|
| 46 |
+
Number of codebooks in the VQVAE.
|
| 47 |
+
codebook_size (`int`, *optional*, defaults to 1024):
|
| 48 |
+
Number of discrete codes in each codebook.
|
| 49 |
+
codebook_dim (`int`, *optional*, defaults to 8):
|
| 50 |
+
Dimension of the codebook vectors. If not defined, uses `encoder_hidden_size`.
|
| 51 |
+
quantizer_dropout (`bool`, *optional*, defaults to 0):
|
| 52 |
+
Whether to apply dropout to the quantizer.
|
| 53 |
+
commitment_loss_weight (float, *optional*, defaults to 0.25):
|
| 54 |
+
Weight of the commitment loss term in the VQVAE loss function.
|
| 55 |
+
codebook_loss_weight (float, *optional*, defaults to 1.0):
|
| 56 |
+
Weight of the codebook loss term in the VQVAE loss function.
|
| 57 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 58 |
+
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
|
| 59 |
+
Example:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
>>> from transformers import DacModel, DacConfig
|
| 63 |
+
|
| 64 |
+
>>> # Initializing a "descript/dac_16khz" style configuration
|
| 65 |
+
>>> configuration = DacConfig()
|
| 66 |
+
|
| 67 |
+
>>> # Initializing a model (with random weights) from the "descript/dac_16khz" style configuration
|
| 68 |
+
>>> model = DacModel(configuration)
|
| 69 |
+
|
| 70 |
+
>>> # Accessing the model configuration
|
| 71 |
+
>>> configuration = model.config
|
| 72 |
+
```"""
|
| 73 |
+
|
| 74 |
+
model_type = "dac"
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
encoder_hidden_size=64,
|
| 79 |
+
downsampling_ratios=[2, 4, 8, 8],
|
| 80 |
+
decoder_hidden_size=1536,
|
| 81 |
+
n_codebooks=9,
|
| 82 |
+
codebook_size=1024,
|
| 83 |
+
codebook_dim=8,
|
| 84 |
+
quantizer_dropout=0,
|
| 85 |
+
commitment_loss_weight=0.25,
|
| 86 |
+
codebook_loss_weight=1.0,
|
| 87 |
+
sampling_rate=16000,
|
| 88 |
+
**kwargs,
|
| 89 |
+
):
|
| 90 |
+
self.encoder_hidden_size = encoder_hidden_size
|
| 91 |
+
self.downsampling_ratios = downsampling_ratios
|
| 92 |
+
self.decoder_hidden_size = decoder_hidden_size
|
| 93 |
+
self.upsampling_ratios = downsampling_ratios[::-1]
|
| 94 |
+
self.n_codebooks = n_codebooks
|
| 95 |
+
self.codebook_size = codebook_size
|
| 96 |
+
self.codebook_dim = codebook_dim
|
| 97 |
+
self.quantizer_dropout = quantizer_dropout
|
| 98 |
+
self.sampling_rate = sampling_rate
|
| 99 |
+
|
| 100 |
+
self.hidden_size = encoder_hidden_size * (2 ** len(downsampling_ratios))
|
| 101 |
+
|
| 102 |
+
self.hop_length = int(np.prod(downsampling_ratios))
|
| 103 |
+
self.commitment_loss_weight = commitment_loss_weight
|
| 104 |
+
self.codebook_loss_weight = codebook_loss_weight
|
| 105 |
+
|
| 106 |
+
super().__init__(**kwargs)
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def frame_rate(self) -> int:
|
| 110 |
+
hop_length = np.prod(self.upsampling_ratios)
|
| 111 |
+
return math.ceil(self.sampling_rate / hop_length)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
__all__ = ["DacConfig"]
|
docs/transformers/build/lib/transformers/models/dac/convert_dac_checkpoint.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import argparse
|
| 16 |
+
import fnmatch
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from transformers import (
|
| 22 |
+
DacConfig,
|
| 23 |
+
DacFeatureExtractor,
|
| 24 |
+
DacModel,
|
| 25 |
+
logging,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# checkpoints downloaded using:
|
| 30 |
+
# pip install descript-audio-codec
|
| 31 |
+
# python3 -m dac download # downloads the default 44kHz variant
|
| 32 |
+
# python3 -m dac download --model_type 44khz # downloads the 44kHz variant
|
| 33 |
+
# python3 -m dac download --model_type 24khz # downloads the 24kHz variant
|
| 34 |
+
# python3 -m dac download --model_type 16khz # downloads the 16kHz variant
|
| 35 |
+
# More informations: https://github.com/descriptinc/descript-audio-codec/tree/main
|
| 36 |
+
|
| 37 |
+
logging.set_verbosity_info()
|
| 38 |
+
logger = logging.get_logger("transformers.models.dac")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def match_pattern(string, pattern):
|
| 42 |
+
# Split the pattern into parts
|
| 43 |
+
pattern_parts = pattern.split(".")
|
| 44 |
+
string_parts = string.split(".")
|
| 45 |
+
|
| 46 |
+
pattern_block_count = string_block_count = 0
|
| 47 |
+
|
| 48 |
+
for part in pattern_parts:
|
| 49 |
+
if part.startswith("block"):
|
| 50 |
+
pattern_block_count += 1
|
| 51 |
+
|
| 52 |
+
for part in string_parts:
|
| 53 |
+
if part.startswith("block"):
|
| 54 |
+
string_block_count += 1
|
| 55 |
+
|
| 56 |
+
return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
TOP_LEVEL_KEYS = []
|
| 60 |
+
IGNORE_KEYS = []
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
MAPPING_ENCODER = {
|
| 64 |
+
"encoder.block.0": ["encoder.conv1"],
|
| 65 |
+
"encoder.block.5": ["encoder.snake1"],
|
| 66 |
+
"encoder.block.6": ["encoder.conv2"],
|
| 67 |
+
"encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
|
| 68 |
+
"encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
|
| 69 |
+
"encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
|
| 70 |
+
"encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
|
| 71 |
+
"encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
|
| 72 |
+
"encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
MAPPING_QUANTIZER = {
|
| 76 |
+
"quantizer.quantizers.*": ["quantizer.quantizers.*"],
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
MAPPING_DECODER = {
|
| 80 |
+
"decoder.model.0": ["decoder.conv1"],
|
| 81 |
+
"decoder.model.5": ["decoder.snake1"],
|
| 82 |
+
"decoder.model.6": ["decoder.conv2"],
|
| 83 |
+
"decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
|
| 84 |
+
"decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
|
| 85 |
+
"decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
|
| 86 |
+
"decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
|
| 87 |
+
"decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
|
| 88 |
+
"decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
MAPPING = {
|
| 93 |
+
**MAPPING_ENCODER,
|
| 94 |
+
**MAPPING_QUANTIZER,
|
| 95 |
+
**MAPPING_DECODER,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
| 100 |
+
for attribute in key.split("."):
|
| 101 |
+
hf_pointer = getattr(hf_pointer, attribute)
|
| 102 |
+
|
| 103 |
+
if weight_type is not None:
|
| 104 |
+
hf_shape = getattr(hf_pointer, weight_type).shape
|
| 105 |
+
else:
|
| 106 |
+
hf_shape = hf_pointer.shape
|
| 107 |
+
|
| 108 |
+
if hf_shape != value.shape:
|
| 109 |
+
raise ValueError(
|
| 110 |
+
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
| 111 |
+
f" {value.shape} for {full_name}"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
if weight_type == "weight":
|
| 115 |
+
hf_pointer.weight.data = value
|
| 116 |
+
elif weight_type == "weight_g":
|
| 117 |
+
hf_pointer.weight_g.data = value
|
| 118 |
+
elif weight_type == "weight_v":
|
| 119 |
+
hf_pointer.weight_v.data = value
|
| 120 |
+
elif weight_type == "bias":
|
| 121 |
+
hf_pointer.bias.data = value
|
| 122 |
+
elif weight_type == "alpha":
|
| 123 |
+
hf_pointer.alpha.data = value
|
| 124 |
+
logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def should_ignore(name, ignore_keys):
|
| 128 |
+
for key in ignore_keys:
|
| 129 |
+
if key.endswith(".*"):
|
| 130 |
+
if name.startswith(key[:-1]):
|
| 131 |
+
return True
|
| 132 |
+
elif ".*." in key:
|
| 133 |
+
prefix, suffix = key.split(".*.")
|
| 134 |
+
if prefix in name and suffix in name:
|
| 135 |
+
return True
|
| 136 |
+
elif key in name:
|
| 137 |
+
return True
|
| 138 |
+
return False
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def recursively_load_weights(orig_dict, hf_model, model_name):
|
| 142 |
+
unused_weights = []
|
| 143 |
+
|
| 144 |
+
if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
|
| 145 |
+
raise ValueError(f"Unsupported model: {model_name}")
|
| 146 |
+
|
| 147 |
+
for name, value in orig_dict.items():
|
| 148 |
+
is_used = False
|
| 149 |
+
for key, mapped_key in MAPPING.items():
|
| 150 |
+
regex = re.compile(key)
|
| 151 |
+
if regex.search(name):
|
| 152 |
+
if len(mapped_key) == 1:
|
| 153 |
+
if mapped_key[0][0] == "q":
|
| 154 |
+
mapped_key = ".".join(name.split(".")[:-1])
|
| 155 |
+
else:
|
| 156 |
+
mapped_key = mapped_key[0]
|
| 157 |
+
elif len(mapped_key) == 3:
|
| 158 |
+
integers = re.findall(r"\b\d+\b", name)
|
| 159 |
+
if mapped_key[0][0] == "d":
|
| 160 |
+
mapped_key = "{}.{}.{}{}.{}".format(
|
| 161 |
+
mapped_key[0],
|
| 162 |
+
str(int(integers[0]) - 1),
|
| 163 |
+
mapped_key[1],
|
| 164 |
+
str(int(integers[1]) - 1),
|
| 165 |
+
mapped_key[2],
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
mapped_key = "{}.{}.{}{}.{}".format(
|
| 169 |
+
mapped_key[0],
|
| 170 |
+
str(int(integers[0]) - 1),
|
| 171 |
+
mapped_key[1],
|
| 172 |
+
str(int(integers[1]) + 1),
|
| 173 |
+
mapped_key[2],
|
| 174 |
+
)
|
| 175 |
+
elif len(mapped_key) == 2:
|
| 176 |
+
integers = re.findall(r"\b\d+\b", name)
|
| 177 |
+
mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])
|
| 178 |
+
|
| 179 |
+
is_used = True
|
| 180 |
+
if "weight_g" in name:
|
| 181 |
+
weight_type = "weight_g"
|
| 182 |
+
elif "weight_v" in name:
|
| 183 |
+
weight_type = "weight_v"
|
| 184 |
+
elif "bias" in name:
|
| 185 |
+
weight_type = "bias"
|
| 186 |
+
elif "alpha" in name:
|
| 187 |
+
weight_type = "alpha"
|
| 188 |
+
elif "weight" in name:
|
| 189 |
+
weight_type = "weight"
|
| 190 |
+
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
| 191 |
+
|
| 192 |
+
if not is_used:
|
| 193 |
+
unused_weights.append(name)
|
| 194 |
+
|
| 195 |
+
print(list(set(unused_weights)))
|
| 196 |
+
|
| 197 |
+
logger.warning(f"Unused weights: {unused_weights}")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@torch.no_grad()
|
| 201 |
+
def convert_checkpoint(
|
| 202 |
+
model_name,
|
| 203 |
+
checkpoint_path,
|
| 204 |
+
pytorch_dump_folder_path,
|
| 205 |
+
sample_rate=16000,
|
| 206 |
+
repo_id=None,
|
| 207 |
+
):
|
| 208 |
+
model_dict = torch.load(checkpoint_path, "cpu", weights_only=True)
|
| 209 |
+
|
| 210 |
+
config = DacConfig()
|
| 211 |
+
|
| 212 |
+
metadata = model_dict["metadata"]["kwargs"]
|
| 213 |
+
config.encoder_hidden_size = metadata["encoder_dim"]
|
| 214 |
+
config.downsampling_ratios = metadata["encoder_rates"]
|
| 215 |
+
config.codebook_size = metadata["codebook_size"]
|
| 216 |
+
config.n_codebooks = metadata["n_codebooks"]
|
| 217 |
+
config.codebook_dim = metadata["codebook_dim"]
|
| 218 |
+
config.decoder_hidden_size = metadata["decoder_dim"]
|
| 219 |
+
config.upsampling_ratios = metadata["decoder_rates"]
|
| 220 |
+
config.quantizer_dropout = float(metadata["quantizer_dropout"])
|
| 221 |
+
config.sampling_rate = sample_rate
|
| 222 |
+
|
| 223 |
+
model = DacModel(config)
|
| 224 |
+
feature_extractor = DacFeatureExtractor()
|
| 225 |
+
feature_extractor.sampling_rate = sample_rate
|
| 226 |
+
|
| 227 |
+
original_checkpoint = model_dict["state_dict"]
|
| 228 |
+
|
| 229 |
+
model.apply_weight_norm()
|
| 230 |
+
recursively_load_weights(original_checkpoint, model, model_name)
|
| 231 |
+
model.remove_weight_norm()
|
| 232 |
+
|
| 233 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 234 |
+
|
| 235 |
+
if repo_id:
|
| 236 |
+
print("Pushing to the hub...")
|
| 237 |
+
feature_extractor.push_to_hub(repo_id)
|
| 238 |
+
model.push_to_hub(repo_id)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
if __name__ == "__main__":
|
| 242 |
+
parser = argparse.ArgumentParser()
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--model",
|
| 245 |
+
default="dac_44khz",
|
| 246 |
+
type=str,
|
| 247 |
+
help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
|
| 252 |
+
)
|
| 253 |
+
parser.add_argument(
|
| 254 |
+
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor")
|
| 257 |
+
args = parser.parse_args()
|
| 258 |
+
|
| 259 |
+
convert_checkpoint(
|
| 260 |
+
args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
|
| 261 |
+
)
|
docs/transformers/build/lib/transformers/models/dac/feature_extraction_dac.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Feature extractor class for DAC"""
|
| 16 |
+
|
| 17 |
+
from typing import List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 22 |
+
from ...feature_extraction_utils import BatchFeature
|
| 23 |
+
from ...utils import PaddingStrategy, TensorType, logging
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DacFeatureExtractor(SequenceFeatureExtractor):
|
| 30 |
+
r"""
|
| 31 |
+
Constructs an Dac feature extractor.
|
| 32 |
+
|
| 33 |
+
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
| 34 |
+
most of the main methods. Users should refer to this superclass for more information regarding those methods.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
feature_size (`int`, *optional*, defaults to 1):
|
| 38 |
+
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
|
| 39 |
+
sampling_rate (`int`, *optional*, defaults to 16000):
|
| 40 |
+
The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
|
| 41 |
+
padding_value (`float`, *optional*, defaults to 0.0):
|
| 42 |
+
The value that is used for padding.
|
| 43 |
+
hop_length (`int`, *optional*, defaults to 512):
|
| 44 |
+
Overlap length between successive windows.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
model_input_names = ["input_values", "n_quantizers"]
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
feature_size: int = 1,
|
| 52 |
+
sampling_rate: int = 16000,
|
| 53 |
+
padding_value: float = 0.0,
|
| 54 |
+
hop_length: int = 512,
|
| 55 |
+
**kwargs,
|
| 56 |
+
):
|
| 57 |
+
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
| 58 |
+
self.hop_length = hop_length
|
| 59 |
+
|
| 60 |
+
def __call__(
|
| 61 |
+
self,
|
| 62 |
+
raw_audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
| 63 |
+
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
|
| 64 |
+
truncation: Optional[bool] = False,
|
| 65 |
+
max_length: Optional[int] = None,
|
| 66 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 67 |
+
sampling_rate: Optional[int] = None,
|
| 68 |
+
) -> BatchFeature:
|
| 69 |
+
"""
|
| 70 |
+
Main method to featurize and prepare for the model one or several sequence(s).
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
raw_audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
| 74 |
+
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
|
| 75 |
+
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
|
| 76 |
+
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
|
| 77 |
+
(`feature_size = 2`).
|
| 78 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 79 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
| 80 |
+
index) among:
|
| 81 |
+
|
| 82 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
| 83 |
+
sequence if provided).
|
| 84 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 85 |
+
acceptable input length for the model if that argument is not provided.
|
| 86 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
| 87 |
+
lengths).
|
| 88 |
+
truncation (`bool`, *optional*, defaults to `False`):
|
| 89 |
+
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
| 90 |
+
max_length (`int`, *optional*):
|
| 91 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 92 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
|
| 93 |
+
If set, will return tensors instead of list of python integers. Acceptable values are:
|
| 94 |
+
|
| 95 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 96 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 97 |
+
- `'np'`: Return Numpy `np.ndarray` objects.
|
| 98 |
+
sampling_rate (`int`, *optional*):
|
| 99 |
+
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
|
| 100 |
+
`sampling_rate` at the forward call to prevent silent errors.
|
| 101 |
+
"""
|
| 102 |
+
if sampling_rate is not None:
|
| 103 |
+
if sampling_rate != self.sampling_rate:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
| 106 |
+
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
|
| 107 |
+
f" {self.sampling_rate} and not {sampling_rate}."
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
logger.warning(
|
| 111 |
+
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
|
| 112 |
+
"Failing to do so can result in silent errors that might be hard to debug."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if padding and truncation:
|
| 116 |
+
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
|
| 117 |
+
elif padding is None:
|
| 118 |
+
# by default let's pad the inputs
|
| 119 |
+
padding = True
|
| 120 |
+
|
| 121 |
+
is_batched = bool(
|
| 122 |
+
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if is_batched:
|
| 126 |
+
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
|
| 127 |
+
elif not is_batched and not isinstance(raw_audio, np.ndarray):
|
| 128 |
+
raw_audio = np.asarray(raw_audio, dtype=np.float32)
|
| 129 |
+
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
|
| 130 |
+
raw_audio = raw_audio.astype(np.float32)
|
| 131 |
+
|
| 132 |
+
# always return batch
|
| 133 |
+
if not is_batched:
|
| 134 |
+
raw_audio = [np.asarray(raw_audio).T]
|
| 135 |
+
|
| 136 |
+
# verify inputs are valid
|
| 137 |
+
for idx, example in enumerate(raw_audio):
|
| 138 |
+
if example.ndim > 2:
|
| 139 |
+
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
|
| 140 |
+
if self.feature_size == 1 and example.ndim != 1:
|
| 141 |
+
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
|
| 142 |
+
if self.feature_size == 2:
|
| 143 |
+
raise ValueError("Stereo audio isn't supported for now")
|
| 144 |
+
|
| 145 |
+
input_values = BatchFeature({"input_values": raw_audio})
|
| 146 |
+
|
| 147 |
+
# normal padding on batch
|
| 148 |
+
padded_inputs = self.pad(
|
| 149 |
+
input_values,
|
| 150 |
+
max_length=max_length,
|
| 151 |
+
truncation=truncation,
|
| 152 |
+
padding=padding,
|
| 153 |
+
return_attention_mask=False,
|
| 154 |
+
pad_to_multiple_of=self.hop_length,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
if padding:
|
| 158 |
+
padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]
|
| 159 |
+
|
| 160 |
+
input_values = []
|
| 161 |
+
for example in padded_inputs.pop("input_values"):
|
| 162 |
+
if self.feature_size == 1:
|
| 163 |
+
example = example[..., None]
|
| 164 |
+
input_values.append(example.T)
|
| 165 |
+
|
| 166 |
+
padded_inputs["input_values"] = input_values
|
| 167 |
+
if return_tensors is not None:
|
| 168 |
+
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
| 169 |
+
|
| 170 |
+
return padded_inputs
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
__all__ = ["DacFeatureExtractor"]
|
docs/transformers/build/lib/transformers/models/dac/modeling_dac.py
ADDED
|
@@ -0,0 +1,724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Transformers DAC model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
from ...modeling_utils import PreTrainedModel
|
| 27 |
+
from ...utils import (
|
| 28 |
+
ModelOutput,
|
| 29 |
+
add_start_docstrings,
|
| 30 |
+
add_start_docstrings_to_model_forward,
|
| 31 |
+
replace_return_docstrings,
|
| 32 |
+
)
|
| 33 |
+
from .configuration_dac import DacConfig
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# General docstring
|
| 37 |
+
_CONFIG_FOR_DOC = "DacConfig"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class DacOutput(ModelOutput):
|
| 42 |
+
"""
|
| 43 |
+
Args:
|
| 44 |
+
loss (`torch.Tensor`):
|
| 45 |
+
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
|
| 46 |
+
audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
|
| 47 |
+
Reconstructed audio data.
|
| 48 |
+
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
| 49 |
+
Quantized continuous representation of input.
|
| 50 |
+
audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
|
| 51 |
+
Codebook indices for each codebook (quantized discrete representation of input).
|
| 52 |
+
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
|
| 53 |
+
Projected latents (continuous representation of input before quantization).
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
loss: Optional[torch.FloatTensor] = None
|
| 57 |
+
audio_values: Optional[torch.FloatTensor] = None
|
| 58 |
+
quantized_representation: Optional[torch.FloatTensor] = None
|
| 59 |
+
audio_codes: Optional[torch.LongTensor] = None
|
| 60 |
+
projected_latents: Optional[torch.FloatTensor] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class DacEncoderOutput(ModelOutput):
|
| 65 |
+
"""
|
| 66 |
+
Args:
|
| 67 |
+
loss (`torch.Tensor`):
|
| 68 |
+
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
|
| 69 |
+
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
|
| 70 |
+
Quantized continuous representation of input.
|
| 71 |
+
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
|
| 72 |
+
Codebook indices for each codebook (quantized discrete representation of input).
|
| 73 |
+
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
|
| 74 |
+
Projected latents (continuous representation of input before quantization).
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
loss: Optional[torch.FloatTensor] = None
|
| 78 |
+
quantized_representation: Optional[torch.FloatTensor] = None
|
| 79 |
+
audio_codes: Optional[torch.FloatTensor] = None
|
| 80 |
+
projected_latents: Optional[torch.FloatTensor] = None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
|
| 85 |
+
class DacDecoderOutput(ModelOutput):
|
| 86 |
+
"""
|
| 87 |
+
Args:
|
| 88 |
+
audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
|
| 89 |
+
Decoded audio values, obtained using the decoder part of Dac.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
audio_values: Optional[torch.FloatTensor] = None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Snake1d(nn.Module):
|
| 96 |
+
"""
|
| 97 |
+
A 1-dimensional Snake activation function module.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, hidden_dim):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
|
| 103 |
+
|
| 104 |
+
def forward(self, hidden_states):
|
| 105 |
+
shape = hidden_states.shape
|
| 106 |
+
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
|
| 107 |
+
hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
|
| 108 |
+
hidden_states = hidden_states.reshape(shape)
|
| 109 |
+
return hidden_states
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class DacVectorQuantize(nn.Module):
|
| 113 |
+
"""
|
| 114 |
+
Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)
|
| 115 |
+
|
| 116 |
+
Additionally uses following tricks from improved VQGAN
|
| 117 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
| 118 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 119 |
+
for improved codebook usage
|
| 120 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 121 |
+
improves training stability
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(self, config: DacConfig):
|
| 125 |
+
super().__init__()
|
| 126 |
+
|
| 127 |
+
self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
|
| 128 |
+
self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
|
| 129 |
+
self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
|
| 130 |
+
|
| 131 |
+
def forward(self, hidden_state):
|
| 132 |
+
"""
|
| 133 |
+
Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
|
| 137 |
+
Input tensor.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
quantized_representation (`torch.Tensor`of shape `(batch_size, dimension, time_steps)`):
|
| 141 |
+
Quantized continuous representation of input.
|
| 142 |
+
commitment_loss (`torch.FloatTensor`of shape `(1)`):
|
| 143 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries.
|
| 144 |
+
codebook_loss (`torch.FloatTensor`of shape `(1)`):
|
| 145 |
+
Codebook loss to update the codebook.
|
| 146 |
+
audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
|
| 147 |
+
Codebook indices for each codebook, quantized discrete representation of input.
|
| 148 |
+
projected_latents (torch.FloatTensor of shape `(batch_size, num_codebooks * dimension, time_steps)`):
|
| 149 |
+
Projected latents (continuous representation of input before quantization).
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
projected_latents = self.in_proj(hidden_state)
|
| 153 |
+
quantized_representation, audio_codes = self.decode_latents(projected_latents)
|
| 154 |
+
|
| 155 |
+
commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
|
| 156 |
+
codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
|
| 157 |
+
# noop in forward pass, straight-through gradient estimator in backward pass
|
| 158 |
+
quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
|
| 159 |
+
quantized_representation = self.out_proj(quantized_representation)
|
| 160 |
+
|
| 161 |
+
return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents
|
| 162 |
+
|
| 163 |
+
def decode_latents(self, hidden_states):
|
| 164 |
+
batch_size, hidden_dim, sequence_length = hidden_states.shape
|
| 165 |
+
encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
|
| 166 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
| 167 |
+
|
| 168 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
| 169 |
+
encodings = F.normalize(encodings)
|
| 170 |
+
codebook = F.normalize(codebook)
|
| 171 |
+
|
| 172 |
+
# Compute euclidean distance with codebook
|
| 173 |
+
l2_norm = encodings.pow(2).sum(1, keepdim=True)
|
| 174 |
+
dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()
|
| 175 |
+
|
| 176 |
+
indices = dist.max(1)[1]
|
| 177 |
+
indices = indices.reshape(hidden_states.size(0), -1)
|
| 178 |
+
quantized_representation = self.codebook(indices).transpose(1, 2)
|
| 179 |
+
return quantized_representation, indices
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class DacResidualUnit(nn.Module):
|
| 183 |
+
"""
|
| 184 |
+
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(self, dimension: int = 16, dilation: int = 1):
|
| 188 |
+
super().__init__()
|
| 189 |
+
pad = ((7 - 1) * dilation) // 2
|
| 190 |
+
|
| 191 |
+
self.snake1 = Snake1d(dimension)
|
| 192 |
+
self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
|
| 193 |
+
self.snake2 = Snake1d(dimension)
|
| 194 |
+
self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
|
| 195 |
+
|
| 196 |
+
def forward(self, hidden_state):
|
| 197 |
+
"""
|
| 198 |
+
Forward pass through the residual unit.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
|
| 202 |
+
Input tensor .
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
|
| 206 |
+
Input tensor after passing through the residual unit.
|
| 207 |
+
"""
|
| 208 |
+
output_tensor = hidden_state
|
| 209 |
+
output_tensor = self.conv1(self.snake1(output_tensor))
|
| 210 |
+
output_tensor = self.conv2(self.snake2(output_tensor))
|
| 211 |
+
|
| 212 |
+
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
|
| 213 |
+
if padding > 0:
|
| 214 |
+
hidden_state = hidden_state[..., padding:-padding]
|
| 215 |
+
output_tensor = hidden_state + output_tensor
|
| 216 |
+
return output_tensor
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class DacEncoderBlock(nn.Module):
|
| 220 |
+
"""Encoder block used in DAC encoder."""
|
| 221 |
+
|
| 222 |
+
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
|
| 223 |
+
super().__init__()
|
| 224 |
+
|
| 225 |
+
dimension = config.encoder_hidden_size * 2**stride_index
|
| 226 |
+
self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
|
| 227 |
+
self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
|
| 228 |
+
self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
|
| 229 |
+
self.snake1 = Snake1d(dimension // 2)
|
| 230 |
+
self.conv1 = nn.Conv1d(
|
| 231 |
+
dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
def forward(self, hidden_state):
|
| 235 |
+
hidden_state = self.res_unit1(hidden_state)
|
| 236 |
+
hidden_state = self.res_unit2(hidden_state)
|
| 237 |
+
hidden_state = self.snake1(self.res_unit3(hidden_state))
|
| 238 |
+
hidden_state = self.conv1(hidden_state)
|
| 239 |
+
|
| 240 |
+
return hidden_state
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class DacDecoderBlock(nn.Module):
|
| 244 |
+
"""Decoder block used in DAC decoder."""
|
| 245 |
+
|
| 246 |
+
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
|
| 247 |
+
super().__init__()
|
| 248 |
+
|
| 249 |
+
input_dim = config.decoder_hidden_size // 2**stride_index
|
| 250 |
+
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
|
| 251 |
+
self.snake1 = Snake1d(input_dim)
|
| 252 |
+
self.conv_t1 = nn.ConvTranspose1d(
|
| 253 |
+
input_dim,
|
| 254 |
+
output_dim,
|
| 255 |
+
kernel_size=2 * stride,
|
| 256 |
+
stride=stride,
|
| 257 |
+
padding=math.ceil(stride / 2),
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
|
| 261 |
+
self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
|
| 262 |
+
self.res_unit3 = DacResidualUnit(output_dim, dilation=9)
|
| 263 |
+
|
| 264 |
+
def forward(self, hidden_state):
|
| 265 |
+
hidden_state = self.snake1(hidden_state)
|
| 266 |
+
hidden_state = self.conv_t1(hidden_state)
|
| 267 |
+
hidden_state = self.res_unit1(hidden_state)
|
| 268 |
+
hidden_state = self.res_unit2(hidden_state)
|
| 269 |
+
hidden_state = self.res_unit3(hidden_state)
|
| 270 |
+
|
| 271 |
+
return hidden_state
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class DacResidualVectorQuantize(nn.Module):
|
| 275 |
+
"""
|
| 276 |
+
ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://arxiv.org/abs/2107.03312)
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(self, config: DacConfig):
|
| 280 |
+
super().__init__()
|
| 281 |
+
|
| 282 |
+
n_codebooks = config.n_codebooks
|
| 283 |
+
quantizer_dropout = config.quantizer_dropout
|
| 284 |
+
|
| 285 |
+
self.n_codebooks = n_codebooks
|
| 286 |
+
|
| 287 |
+
self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
|
| 288 |
+
self.quantizer_dropout = quantizer_dropout
|
| 289 |
+
|
| 290 |
+
def forward(self, hidden_state, n_quantizers: Optional[int] = None):
|
| 291 |
+
"""
|
| 292 |
+
Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
|
| 293 |
+
Args:
|
| 294 |
+
hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
| 295 |
+
Input tensor to be quantized.
|
| 296 |
+
n_quantizers (`int`, *optional*):
|
| 297 |
+
Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
|
| 298 |
+
this argument is ignored during training, and a random number of quantizers is used.
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
| 302 |
+
Quantized continuous representation of input.
|
| 303 |
+
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
|
| 304 |
+
Codebook indices for each codebook (quantized discrete representation of input).
|
| 305 |
+
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
|
| 306 |
+
Projected latents (continuous representation of input before quantization).
|
| 307 |
+
commitment_loss (`torch.Tensor` of shape `(1)`):
|
| 308 |
+
Commitment loss to train the encoder to predict vectors closer to codebook entries.
|
| 309 |
+
codebook_loss (`torch.Tensor` of shape `(1)`):
|
| 310 |
+
Codebook loss to update the codebook.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
quantized_representation = 0
|
| 314 |
+
residual = hidden_state
|
| 315 |
+
commitment_loss = 0
|
| 316 |
+
codebook_loss = 0
|
| 317 |
+
|
| 318 |
+
audio_codes = []
|
| 319 |
+
projected_latents = []
|
| 320 |
+
|
| 321 |
+
n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
|
| 322 |
+
if self.training:
|
| 323 |
+
n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
|
| 324 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
|
| 325 |
+
n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
|
| 326 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 327 |
+
n_quantizers = n_quantizers.to(hidden_state.device)
|
| 328 |
+
|
| 329 |
+
for i, quantizer in enumerate(self.quantizers):
|
| 330 |
+
if self.training is False and i >= n_quantizers:
|
| 331 |
+
break
|
| 332 |
+
|
| 333 |
+
quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
|
| 334 |
+
residual
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# Create mask to apply quantizer dropout
|
| 338 |
+
mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
|
| 339 |
+
quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
|
| 340 |
+
residual = residual - quantized_representation_i
|
| 341 |
+
|
| 342 |
+
# Sum losses
|
| 343 |
+
commitment_loss += commitment_loss_i * mask
|
| 344 |
+
codebook_loss += codebook_loss_i * mask
|
| 345 |
+
|
| 346 |
+
audio_codes.append(indices_i)
|
| 347 |
+
projected_latents.append(projected_latents_i)
|
| 348 |
+
|
| 349 |
+
audio_codes = torch.stack(audio_codes, dim=1)
|
| 350 |
+
projected_latents = torch.cat(projected_latents, dim=1)
|
| 351 |
+
|
| 352 |
+
return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
|
| 353 |
+
|
| 354 |
+
def from_codes(self, audio_codes: torch.Tensor):
|
| 355 |
+
"""
|
| 356 |
+
Reconstructs the continuous representation from quantized codes.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
|
| 360 |
+
Quantized discrete representation of input.
|
| 361 |
+
|
| 362 |
+
Returns:
|
| 363 |
+
quantized_representation (`torch.Tensor`):
|
| 364 |
+
Quantized continuous representation of input.
|
| 365 |
+
projected_latents (`torch.Tensor`):
|
| 366 |
+
List of projected latents (continuous representations of input before quantization)
|
| 367 |
+
for each codebook.
|
| 368 |
+
audio_codes (`torch.Tensor`):
|
| 369 |
+
Codebook indices for each codebook.
|
| 370 |
+
"""
|
| 371 |
+
quantized_representation = 0.0
|
| 372 |
+
projected_latents = []
|
| 373 |
+
n_codebooks = audio_codes.shape[1]
|
| 374 |
+
for i in range(n_codebooks):
|
| 375 |
+
projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
|
| 376 |
+
projected_latents.append(projected_latents_i)
|
| 377 |
+
quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
|
| 378 |
+
return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
|
| 379 |
+
|
| 380 |
+
def from_latents(self, latents: torch.Tensor):
|
| 381 |
+
"""Reconstructs the quantized representation from unquantized latents.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
|
| 385 |
+
Continuous representation of input after projection.
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
| 389 |
+
Quantized representation of the full-projected space.
|
| 390 |
+
quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
| 391 |
+
Quantized representation of the latent space (continuous representation before quantization).
|
| 392 |
+
"""
|
| 393 |
+
quantized_representation = 0
|
| 394 |
+
quantized_latents = []
|
| 395 |
+
codes = []
|
| 396 |
+
codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
|
| 397 |
+
dims = torch.cumsum(codebook_dims_tensor, dim=0)
|
| 398 |
+
|
| 399 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
|
| 400 |
+
for i in range(n_codebooks):
|
| 401 |
+
hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
|
| 402 |
+
quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
|
| 403 |
+
quantized_latents.append(quantized_latents_i)
|
| 404 |
+
codes.append(codes_i)
|
| 405 |
+
|
| 406 |
+
quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
|
| 407 |
+
quantized_representation = quantized_representation + quantized_representation_i
|
| 408 |
+
|
| 409 |
+
return quantized_representation, torch.cat(quantized_latents, dim=1)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class DacDecoder(nn.Module):
|
| 413 |
+
"""DAC Decoder"""
|
| 414 |
+
|
| 415 |
+
def __init__(self, config: DacConfig):
|
| 416 |
+
super().__init__()
|
| 417 |
+
|
| 418 |
+
input_channel = config.hidden_size
|
| 419 |
+
channels = config.decoder_hidden_size
|
| 420 |
+
strides = config.upsampling_ratios
|
| 421 |
+
|
| 422 |
+
# Add first conv layer
|
| 423 |
+
self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)
|
| 424 |
+
|
| 425 |
+
# Add upsampling + MRF blocks
|
| 426 |
+
block = []
|
| 427 |
+
for stride_index, stride in enumerate(strides):
|
| 428 |
+
block += [DacDecoderBlock(config, stride, stride_index)]
|
| 429 |
+
|
| 430 |
+
self.block = nn.ModuleList(block)
|
| 431 |
+
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
|
| 432 |
+
self.snake1 = Snake1d(output_dim)
|
| 433 |
+
self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
|
| 434 |
+
self.tanh = nn.Tanh()
|
| 435 |
+
|
| 436 |
+
def forward(self, hidden_state):
|
| 437 |
+
hidden_state = self.conv1(hidden_state)
|
| 438 |
+
|
| 439 |
+
for layer in self.block:
|
| 440 |
+
hidden_state = layer(hidden_state)
|
| 441 |
+
|
| 442 |
+
hidden_state = self.snake1(hidden_state)
|
| 443 |
+
hidden_state = self.conv2(hidden_state)
|
| 444 |
+
hidden_state = self.tanh(hidden_state)
|
| 445 |
+
|
| 446 |
+
return hidden_state
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class DacEncoder(nn.Module):
|
| 450 |
+
"""DAC Encoder"""
|
| 451 |
+
|
| 452 |
+
def __init__(self, config: DacConfig):
|
| 453 |
+
super().__init__()
|
| 454 |
+
|
| 455 |
+
strides = config.downsampling_ratios
|
| 456 |
+
# Create first convolution
|
| 457 |
+
self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
|
| 458 |
+
|
| 459 |
+
self.block = []
|
| 460 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
| 461 |
+
for stride_index, stride in enumerate(strides):
|
| 462 |
+
stride_index = stride_index + 1
|
| 463 |
+
self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]
|
| 464 |
+
|
| 465 |
+
self.block = nn.ModuleList(self.block)
|
| 466 |
+
d_model = config.encoder_hidden_size * 2**stride_index
|
| 467 |
+
self.snake1 = Snake1d(d_model)
|
| 468 |
+
self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
|
| 469 |
+
|
| 470 |
+
def forward(self, hidden_state):
|
| 471 |
+
hidden_state = self.conv1(hidden_state)
|
| 472 |
+
|
| 473 |
+
for module in self.block:
|
| 474 |
+
hidden_state = module(hidden_state)
|
| 475 |
+
|
| 476 |
+
hidden_state = self.snake1(hidden_state)
|
| 477 |
+
hidden_state = self.conv2(hidden_state)
|
| 478 |
+
|
| 479 |
+
return hidden_state
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class DacPreTrainedModel(PreTrainedModel):
|
| 483 |
+
"""
|
| 484 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
| 485 |
+
"""
|
| 486 |
+
|
| 487 |
+
config_class = DacConfig
|
| 488 |
+
base_model_prefix = "dac"
|
| 489 |
+
main_input_name = "input_values"
|
| 490 |
+
|
| 491 |
+
def _init_weights(self, module):
|
| 492 |
+
if isinstance(module, nn.Conv1d):
|
| 493 |
+
nn.init.trunc_normal_(module.weight, std=0.02)
|
| 494 |
+
nn.init.constant_(module.bias, 0)
|
| 495 |
+
|
| 496 |
+
def apply_weight_norm(self):
|
| 497 |
+
weight_norm = nn.utils.weight_norm
|
| 498 |
+
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
| 499 |
+
weight_norm = nn.utils.parametrizations.weight_norm
|
| 500 |
+
|
| 501 |
+
for layer in self.quantizer.quantizers:
|
| 502 |
+
weight_norm(layer.in_proj)
|
| 503 |
+
weight_norm(layer.out_proj)
|
| 504 |
+
|
| 505 |
+
weight_norm(self.encoder.conv1)
|
| 506 |
+
weight_norm(self.encoder.conv2)
|
| 507 |
+
|
| 508 |
+
for layer in self.encoder.block:
|
| 509 |
+
weight_norm(layer.conv1)
|
| 510 |
+
weight_norm(layer.res_unit1.conv1)
|
| 511 |
+
weight_norm(layer.res_unit1.conv2)
|
| 512 |
+
weight_norm(layer.res_unit2.conv1)
|
| 513 |
+
weight_norm(layer.res_unit2.conv2)
|
| 514 |
+
weight_norm(layer.res_unit3.conv1)
|
| 515 |
+
weight_norm(layer.res_unit3.conv2)
|
| 516 |
+
|
| 517 |
+
weight_norm(self.decoder.conv1)
|
| 518 |
+
weight_norm(self.decoder.conv2)
|
| 519 |
+
|
| 520 |
+
for layer in self.decoder.block:
|
| 521 |
+
weight_norm(layer.conv_t1)
|
| 522 |
+
weight_norm(layer.res_unit1.conv1)
|
| 523 |
+
weight_norm(layer.res_unit1.conv2)
|
| 524 |
+
weight_norm(layer.res_unit2.conv1)
|
| 525 |
+
weight_norm(layer.res_unit2.conv2)
|
| 526 |
+
weight_norm(layer.res_unit3.conv1)
|
| 527 |
+
weight_norm(layer.res_unit3.conv2)
|
| 528 |
+
|
| 529 |
+
def remove_weight_norm(self):
|
| 530 |
+
for layer in self.quantizer.quantizers:
|
| 531 |
+
nn.utils.remove_weight_norm(layer.in_proj)
|
| 532 |
+
nn.utils.remove_weight_norm(layer.out_proj)
|
| 533 |
+
|
| 534 |
+
nn.utils.remove_weight_norm(self.encoder.conv1)
|
| 535 |
+
nn.utils.remove_weight_norm(self.encoder.conv2)
|
| 536 |
+
|
| 537 |
+
for layer in self.encoder.block:
|
| 538 |
+
nn.utils.remove_weight_norm(layer.conv1)
|
| 539 |
+
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
|
| 540 |
+
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
|
| 541 |
+
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
|
| 542 |
+
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
|
| 543 |
+
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
|
| 544 |
+
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
|
| 545 |
+
|
| 546 |
+
nn.utils.remove_weight_norm(self.decoder.conv1)
|
| 547 |
+
nn.utils.remove_weight_norm(self.decoder.conv2)
|
| 548 |
+
|
| 549 |
+
for layer in self.decoder.block:
|
| 550 |
+
nn.utils.remove_weight_norm(layer.conv_t1)
|
| 551 |
+
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
|
| 552 |
+
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
|
| 553 |
+
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
|
| 554 |
+
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
|
| 555 |
+
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
|
| 556 |
+
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
DAC_START_DOCSTRING = r"""
|
| 560 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 561 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 562 |
+
etc.)
|
| 563 |
+
|
| 564 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 565 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 566 |
+
and behavior.
|
| 567 |
+
|
| 568 |
+
Parameters:
|
| 569 |
+
config ([`DacConfig`]):
|
| 570 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 571 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 572 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
DAC_INPUTS_DOCSTRING = r"""
|
| 576 |
+
Args:
|
| 577 |
+
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`).
|
| 578 |
+
Audio data to encode,
|
| 579 |
+
n_quantizers (`int`, *optional*):
|
| 580 |
+
Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
|
| 581 |
+
return_dict (`bool`, *optional*):
|
| 582 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 583 |
+
"""
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
@add_start_docstrings(
|
| 587 |
+
"The DAC (Descript Audio Codec) model.",
|
| 588 |
+
DAC_START_DOCSTRING,
|
| 589 |
+
)
|
| 590 |
+
class DacModel(DacPreTrainedModel):
|
| 591 |
+
def __init__(self, config: DacConfig):
|
| 592 |
+
super().__init__(config)
|
| 593 |
+
self.config = config
|
| 594 |
+
|
| 595 |
+
self.encoder = DacEncoder(config)
|
| 596 |
+
self.decoder = DacDecoder(config)
|
| 597 |
+
|
| 598 |
+
self.quantizer = DacResidualVectorQuantize(config)
|
| 599 |
+
|
| 600 |
+
self.bits_per_codebook = int(math.log2(self.config.codebook_size))
|
| 601 |
+
if 2**self.bits_per_codebook != self.config.codebook_size:
|
| 602 |
+
raise ValueError("The codebook_size must be a power of 2.")
|
| 603 |
+
|
| 604 |
+
# Initialize weights and apply final processing
|
| 605 |
+
self.post_init()
|
| 606 |
+
|
| 607 |
+
@replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
| 608 |
+
def encode(
|
| 609 |
+
self,
|
| 610 |
+
input_values: torch.Tensor,
|
| 611 |
+
n_quantizers: Optional[int] = None,
|
| 612 |
+
return_dict: Optional[bool] = None,
|
| 613 |
+
):
|
| 614 |
+
"""
|
| 615 |
+
Encode given audio data and return quantized latent codes
|
| 616 |
+
|
| 617 |
+
Args:
|
| 618 |
+
input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
|
| 619 |
+
Input audio data to encode,
|
| 620 |
+
n_quantizers (int, *optional*):
|
| 621 |
+
Number of quantizers to use. If None, all quantizers are used. Default is None.
|
| 622 |
+
return_dict (`bool`, *optional*):
|
| 623 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 624 |
+
Returns:
|
| 625 |
+
|
| 626 |
+
"""
|
| 627 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 628 |
+
|
| 629 |
+
quantized_representation = self.encoder(input_values)
|
| 630 |
+
quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
|
| 631 |
+
quantized_representation, n_quantizers
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
|
| 635 |
+
|
| 636 |
+
if not return_dict:
|
| 637 |
+
return (loss, quantized_representation, audio_codes, projected_latents)
|
| 638 |
+
|
| 639 |
+
return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
|
| 640 |
+
|
| 641 |
+
@replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
|
| 642 |
+
def decode(
|
| 643 |
+
self,
|
| 644 |
+
quantized_representation: Optional[torch.Tensor] = None,
|
| 645 |
+
audio_codes: Optional[torch.Tensor] = None,
|
| 646 |
+
return_dict: Optional[bool] = None,
|
| 647 |
+
):
|
| 648 |
+
"""Decode given latent codes and return audio data
|
| 649 |
+
|
| 650 |
+
Args:
|
| 651 |
+
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
|
| 652 |
+
Quantized continuous representation of input.
|
| 653 |
+
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
|
| 654 |
+
The codebook indices for each codebook, representing the quantized discrete
|
| 655 |
+
representation of the input. This parameter should be provided if you want
|
| 656 |
+
to decode directly from the audio codes (it will overwrite quantized_representation).
|
| 657 |
+
return_dict (`bool`, *optional*):
|
| 658 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 659 |
+
|
| 660 |
+
Returns:
|
| 661 |
+
|
| 662 |
+
"""
|
| 663 |
+
|
| 664 |
+
if quantized_representation is None and audio_codes is None:
|
| 665 |
+
raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
|
| 666 |
+
|
| 667 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 668 |
+
|
| 669 |
+
if audio_codes is not None:
|
| 670 |
+
quantized_representation = self.quantizer.from_codes(audio_codes)[0]
|
| 671 |
+
|
| 672 |
+
audio_values = self.decoder(quantized_representation).squeeze(1)
|
| 673 |
+
|
| 674 |
+
if not return_dict:
|
| 675 |
+
return (audio_values,)
|
| 676 |
+
|
| 677 |
+
return DacDecoderOutput(audio_values)
|
| 678 |
+
|
| 679 |
+
@add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
|
| 680 |
+
@replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
|
| 681 |
+
def forward(
|
| 682 |
+
self,
|
| 683 |
+
input_values: torch.Tensor,
|
| 684 |
+
n_quantizers: Optional[int] = None,
|
| 685 |
+
return_dict: Optional[bool] = None,
|
| 686 |
+
):
|
| 687 |
+
"""
|
| 688 |
+
Returns:
|
| 689 |
+
Examples:
|
| 690 |
+
|
| 691 |
+
```python
|
| 692 |
+
>>> from datasets import load_dataset, Audio
|
| 693 |
+
>>> from transformers import DacModel, AutoProcessor
|
| 694 |
+
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
| 695 |
+
|
| 696 |
+
>>> model = DacModel.from_pretrained("descript/dac_16khz")
|
| 697 |
+
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
|
| 698 |
+
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
| 699 |
+
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
| 700 |
+
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
|
| 701 |
+
|
| 702 |
+
>>> encoder_outputs = model.encode(inputs["input_values"])
|
| 703 |
+
>>> # Get the intermediate audio codes
|
| 704 |
+
>>> audio_codes = encoder_outputs.audio_codes
|
| 705 |
+
>>> # Reconstruct the audio from its quantized representation
|
| 706 |
+
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
|
| 707 |
+
>>> # or the equivalent with a forward pass
|
| 708 |
+
>>> audio_values = model(inputs["input_values"]).audio_values
|
| 709 |
+
```"""
|
| 710 |
+
|
| 711 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 712 |
+
length = input_values.shape[-1]
|
| 713 |
+
loss, quantized_representation, audio_codes, projected_latents = self.encode(
|
| 714 |
+
input_values, n_quantizers, return_dict=False
|
| 715 |
+
)
|
| 716 |
+
audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]
|
| 717 |
+
|
| 718 |
+
if not return_dict:
|
| 719 |
+
return (loss, audio_values, quantized_representation, audio_codes, projected_latents)
|
| 720 |
+
|
| 721 |
+
return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
__all__ = ["DacModel", "DacPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/data2vec/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_data2vec_audio import *
|
| 22 |
+
from .configuration_data2vec_text import *
|
| 23 |
+
from .configuration_data2vec_vision import *
|
| 24 |
+
from .modeling_data2vec_audio import *
|
| 25 |
+
from .modeling_data2vec_text import *
|
| 26 |
+
from .modeling_data2vec_vision import *
|
| 27 |
+
from .modeling_tf_data2vec_vision import *
|
| 28 |
+
else:
|
| 29 |
+
import sys
|
| 30 |
+
|
| 31 |
+
_file = globals()["__file__"]
|
| 32 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_audio.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Data2VecText configuration"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
from ...configuration_utils import PretrainedConfig
|
| 20 |
+
from ...utils import logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Data2VecAudioConfig(PretrainedConfig):
|
| 27 |
+
r"""
|
| 28 |
+
This is the configuration class to store the configuration of a [`Data2VecAudioModel`]. It is used to instantiate
|
| 29 |
+
an Data2VecAudio model according to the specified arguments, defining the model architecture. Instantiating a
|
| 30 |
+
configuration with the defaults will yield a similar configuration to that of the Data2VecAudio
|
| 31 |
+
[facebook/data2vec-audio-base-960h](https://huggingface.co/facebook/data2vec-audio-base-960h) architecture.
|
| 32 |
+
|
| 33 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 34 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
vocab_size (`int`, *optional*, defaults to 32):
|
| 39 |
+
Vocabulary size of the Data2VecAudio model. Defines the number of different tokens that can be represented
|
| 40 |
+
by the `inputs_ids` passed when calling [`Data2VecAudioModel`] or [`TFData2VecAudioModel`]. Vocabulary size
|
| 41 |
+
of the model. Defines the different tokens that can be represented by the *inputs_ids* passed to the
|
| 42 |
+
forward method of [`Data2VecAudioModel`].
|
| 43 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 44 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 45 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 46 |
+
Number of hidden layers in the Transformer encoder.
|
| 47 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 48 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 49 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 50 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 51 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 52 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 53 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 54 |
+
hidden_dropout (`float`, *optional*, defaults to 0.1):
|
| 55 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 56 |
+
activation_dropout (`float`, *optional*, defaults to 0.1):
|
| 57 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 58 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 59 |
+
The dropout ratio for the attention probabilities.
|
| 60 |
+
final_dropout (`float`, *optional*, defaults to 0.1):
|
| 61 |
+
The dropout probability for the final projection layer of [`Data2VecAudioForCTC`].
|
| 62 |
+
layerdrop (`float`, *optional*, defaults to 0.1):
|
| 63 |
+
The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more
|
| 64 |
+
details.
|
| 65 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 66 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 67 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 68 |
+
The epsilon used by the layer normalization layers.
|
| 69 |
+
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
|
| 70 |
+
The dropout probability for output of the feature encoder.
|
| 71 |
+
feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
|
| 72 |
+
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
|
| 73 |
+
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 74 |
+
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
|
| 75 |
+
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
|
| 76 |
+
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
|
| 77 |
+
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
|
| 78 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
|
| 79 |
+
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
|
| 80 |
+
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
|
| 81 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
|
| 82 |
+
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
|
| 83 |
+
*conv_dim*.
|
| 84 |
+
conv_bias (`bool`, *optional*, defaults to `False`):
|
| 85 |
+
Whether the 1D convolutional layers have a bias.
|
| 86 |
+
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
|
| 87 |
+
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
|
| 88 |
+
embeddings layer.
|
| 89 |
+
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
| 90 |
+
Number of groups of 1D convolutional positional embeddings layer.
|
| 91 |
+
mask_time_prob (`float`, *optional*, defaults to 0.05):
|
| 92 |
+
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
|
| 93 |
+
procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
|
| 94 |
+
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
|
| 95 |
+
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
|
| 96 |
+
mask_time_length (`int`, *optional*, defaults to 10):
|
| 97 |
+
Length of vector span along the time axis.
|
| 98 |
+
mask_time_min_masks (`int`, *optional*, defaults to 2),:
|
| 99 |
+
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
|
| 100 |
+
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
|
| 101 |
+
mask_time_min_masks''
|
| 102 |
+
mask_feature_prob (`float`, *optional*, defaults to 0.0):
|
| 103 |
+
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
|
| 104 |
+
masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
|
| 105 |
+
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
|
| 106 |
+
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
|
| 107 |
+
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
|
| 108 |
+
True`.
|
| 109 |
+
mask_feature_length (`int`, *optional*, defaults to 10):
|
| 110 |
+
Length of vector span along the feature axis.
|
| 111 |
+
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
|
| 112 |
+
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
| 113 |
+
step, irrespectively of `mask_feature_prob`. Only relevant if
|
| 114 |
+
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
|
| 115 |
+
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
|
| 116 |
+
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
|
| 117 |
+
instance of [`Data2VecAudioForCTC`].
|
| 118 |
+
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
|
| 119 |
+
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
|
| 120 |
+
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
|
| 121 |
+
of [`Data2VecAudioForCTC`].
|
| 122 |
+
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
|
| 123 |
+
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
| 124 |
+
instance of [`Data2VecAudioForSequenceClassification`].
|
| 125 |
+
classifier_proj_size (`int`, *optional*, defaults to 256):
|
| 126 |
+
Dimensionality of the projection before token mean-pooling for classification.
|
| 127 |
+
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
|
| 128 |
+
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
|
| 129 |
+
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
|
| 130 |
+
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
|
| 131 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
|
| 132 |
+
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
|
| 133 |
+
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
|
| 134 |
+
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
|
| 135 |
+
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
|
| 136 |
+
xvector_output_dim (`int`, *optional*, defaults to 512):
|
| 137 |
+
Dimensionality of the *XVector* embedding vectors.
|
| 138 |
+
add_adapter (`bool`, *optional*, defaults to `False`):
|
| 139 |
+
Whether a convolutional network should be stacked on top of the Data2VecAudio Encoder. Can be very useful
|
| 140 |
+
for warm-starting Data2VecAudio for SpeechEncoderDecoder models.
|
| 141 |
+
adapter_kernel_size (`int`, *optional*, defaults to 3):
|
| 142 |
+
Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 143 |
+
adapter_stride (`int`, *optional*, defaults to 2):
|
| 144 |
+
Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 145 |
+
num_adapter_layers (`int`, *optional*, defaults to 3):
|
| 146 |
+
Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
|
| 147 |
+
True`.
|
| 148 |
+
output_hidden_size (`int`, *optional*):
|
| 149 |
+
Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
|
| 150 |
+
if `add_adapter is True`.
|
| 151 |
+
|
| 152 |
+
Example:
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
>>> from transformers import Data2VecAudioConfig, Data2VecAudioModel
|
| 156 |
+
|
| 157 |
+
>>> # Initializing a Data2VecAudio facebook/data2vec-audio-base-960h style configuration
|
| 158 |
+
>>> configuration = Data2VecAudioConfig()
|
| 159 |
+
|
| 160 |
+
>>> # Initializing a model (with random weights) from the facebook/data2vec-audio-base-960h style configuration
|
| 161 |
+
>>> model = Data2VecAudioModel(configuration)
|
| 162 |
+
|
| 163 |
+
>>> # Accessing the model configuration
|
| 164 |
+
>>> configuration = model.config
|
| 165 |
+
```"""
|
| 166 |
+
|
| 167 |
+
model_type = "data2vec-audio"
|
| 168 |
+
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
vocab_size=32,
|
| 172 |
+
hidden_size=768,
|
| 173 |
+
num_hidden_layers=12,
|
| 174 |
+
num_attention_heads=12,
|
| 175 |
+
intermediate_size=3072,
|
| 176 |
+
hidden_act="gelu",
|
| 177 |
+
hidden_dropout=0.1,
|
| 178 |
+
activation_dropout=0.1,
|
| 179 |
+
attention_dropout=0.1,
|
| 180 |
+
feat_proj_dropout=0.0,
|
| 181 |
+
final_dropout=0.1,
|
| 182 |
+
layerdrop=0.1,
|
| 183 |
+
initializer_range=0.02,
|
| 184 |
+
layer_norm_eps=1e-5,
|
| 185 |
+
feat_extract_activation="gelu",
|
| 186 |
+
conv_dim=(512, 512, 512, 512, 512, 512, 512),
|
| 187 |
+
conv_stride=(5, 2, 2, 2, 2, 2, 2),
|
| 188 |
+
conv_kernel=(10, 3, 3, 3, 3, 2, 2),
|
| 189 |
+
conv_bias=False,
|
| 190 |
+
num_conv_pos_embedding_groups=16,
|
| 191 |
+
conv_pos_kernel_size=19,
|
| 192 |
+
num_conv_pos_embeddings=5,
|
| 193 |
+
mask_time_prob=0.05,
|
| 194 |
+
mask_time_length=10,
|
| 195 |
+
mask_time_min_masks=2,
|
| 196 |
+
mask_feature_prob=0.0,
|
| 197 |
+
mask_feature_length=10,
|
| 198 |
+
mask_feature_min_masks=0,
|
| 199 |
+
ctc_loss_reduction="sum",
|
| 200 |
+
ctc_zero_infinity=False,
|
| 201 |
+
use_weighted_layer_sum=False,
|
| 202 |
+
classifier_proj_size=256,
|
| 203 |
+
tdnn_dim=(512, 512, 512, 512, 1500),
|
| 204 |
+
tdnn_kernel=(5, 3, 3, 1, 1),
|
| 205 |
+
tdnn_dilation=(1, 2, 3, 1, 1),
|
| 206 |
+
xvector_output_dim=512,
|
| 207 |
+
pad_token_id=0,
|
| 208 |
+
bos_token_id=1,
|
| 209 |
+
eos_token_id=2,
|
| 210 |
+
add_adapter=False,
|
| 211 |
+
adapter_kernel_size=3,
|
| 212 |
+
adapter_stride=2,
|
| 213 |
+
num_adapter_layers=3,
|
| 214 |
+
output_hidden_size=None,
|
| 215 |
+
**kwargs,
|
| 216 |
+
):
|
| 217 |
+
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
| 218 |
+
self.hidden_size = hidden_size
|
| 219 |
+
self.feat_extract_activation = feat_extract_activation
|
| 220 |
+
self.conv_dim = list(conv_dim)
|
| 221 |
+
self.conv_stride = list(conv_stride)
|
| 222 |
+
self.conv_kernel = list(conv_kernel)
|
| 223 |
+
self.conv_bias = conv_bias
|
| 224 |
+
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
| 225 |
+
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
| 226 |
+
self.conv_pos_kernel_size = conv_pos_kernel_size
|
| 227 |
+
self.num_feat_extract_layers = len(self.conv_dim)
|
| 228 |
+
self.num_hidden_layers = num_hidden_layers
|
| 229 |
+
self.intermediate_size = intermediate_size
|
| 230 |
+
self.hidden_act = hidden_act
|
| 231 |
+
self.num_attention_heads = num_attention_heads
|
| 232 |
+
self.hidden_dropout = hidden_dropout
|
| 233 |
+
self.attention_dropout = attention_dropout
|
| 234 |
+
self.activation_dropout = activation_dropout
|
| 235 |
+
self.feat_proj_dropout = feat_proj_dropout
|
| 236 |
+
self.final_dropout = final_dropout
|
| 237 |
+
self.layerdrop = layerdrop
|
| 238 |
+
self.layer_norm_eps = layer_norm_eps
|
| 239 |
+
self.initializer_range = initializer_range
|
| 240 |
+
self.vocab_size = vocab_size
|
| 241 |
+
self.use_weighted_layer_sum = use_weighted_layer_sum
|
| 242 |
+
|
| 243 |
+
if (
|
| 244 |
+
(len(self.conv_stride) != self.num_feat_extract_layers)
|
| 245 |
+
or (len(self.conv_kernel) != self.num_feat_extract_layers)
|
| 246 |
+
or (len(self.conv_dim) != self.num_feat_extract_layers)
|
| 247 |
+
):
|
| 248 |
+
raise ValueError(
|
| 249 |
+
"Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
|
| 250 |
+
" `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
|
| 251 |
+
f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
|
| 252 |
+
f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
| 256 |
+
self.mask_time_prob = mask_time_prob
|
| 257 |
+
self.mask_time_length = mask_time_length
|
| 258 |
+
self.mask_time_min_masks = mask_time_min_masks
|
| 259 |
+
self.mask_feature_prob = mask_feature_prob
|
| 260 |
+
self.mask_feature_length = mask_feature_length
|
| 261 |
+
self.mask_feature_min_masks = mask_feature_min_masks
|
| 262 |
+
|
| 263 |
+
# ctc loss
|
| 264 |
+
self.ctc_loss_reduction = ctc_loss_reduction
|
| 265 |
+
self.ctc_zero_infinity = ctc_zero_infinity
|
| 266 |
+
|
| 267 |
+
# adapter
|
| 268 |
+
self.add_adapter = add_adapter
|
| 269 |
+
self.adapter_kernel_size = adapter_kernel_size
|
| 270 |
+
self.adapter_stride = adapter_stride
|
| 271 |
+
self.num_adapter_layers = num_adapter_layers
|
| 272 |
+
self.output_hidden_size = output_hidden_size or hidden_size
|
| 273 |
+
|
| 274 |
+
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
|
| 275 |
+
self.classifier_proj_size = classifier_proj_size
|
| 276 |
+
|
| 277 |
+
# XVector-specific parameters. Feel free to ignore for other classes.
|
| 278 |
+
self.tdnn_dim = list(tdnn_dim)
|
| 279 |
+
self.tdnn_kernel = list(tdnn_kernel)
|
| 280 |
+
self.tdnn_dilation = list(tdnn_dilation)
|
| 281 |
+
self.xvector_output_dim = xvector_output_dim
|
| 282 |
+
|
| 283 |
+
@property
|
| 284 |
+
def inputs_to_logits_ratio(self):
|
| 285 |
+
return math.prod(self.conv_stride)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
__all__ = ["Data2VecAudioConfig"]
|
docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_text.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Data2VecText configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from typing import Mapping
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...onnx import OnnxConfig
|
| 22 |
+
from ...utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Data2VecTextConfig(PretrainedConfig):
|
| 29 |
+
r"""
|
| 30 |
+
This is the configuration class to store the configuration of a [`Data2VecTextModel`] and [`Data2VecTextModel`]. It
|
| 31 |
+
is used to instantiate a Data2VecText model according to the specified arguments, defining the model architecture.
|
| 32 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the Data2VecText
|
| 33 |
+
[facebook/data2vec-text-base](https://huggingface.co/facebook/data2vec-text-base) architecture.
|
| 34 |
+
|
| 35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 36 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
| 41 |
+
Vocabulary size of the DATA2VEC model. Defines the number of different tokens that can be represented by
|
| 42 |
+
the `inputs_ids` passed when calling [`Data2VecModel`].
|
| 43 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 44 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 45 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 46 |
+
Number of hidden layers in the Transformer encoder.
|
| 47 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 48 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 49 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 50 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
| 51 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
| 52 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 53 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 54 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 55 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 56 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 57 |
+
The dropout ratio for the attention probabilities.
|
| 58 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
| 59 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 60 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 61 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
| 62 |
+
The vocabulary size of the `token_type_ids` passed when calling [`Data2VecModel`].
|
| 63 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 64 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 65 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 66 |
+
The epsilon used by the layer normalization layers.
|
| 67 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
| 68 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
| 69 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
| 70 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
| 71 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
| 72 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
| 73 |
+
is_decoder (`bool`, *optional*, defaults to `False`):
|
| 74 |
+
Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
|
| 75 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 76 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 77 |
+
relevant if `config.is_decoder=True`.
|
| 78 |
+
classifier_dropout (`float`, *optional*):
|
| 79 |
+
The dropout ratio for the classification head.
|
| 80 |
+
|
| 81 |
+
Examples:
|
| 82 |
+
|
| 83 |
+
```python
|
| 84 |
+
>>> from transformers import Data2VecTextConfig, Data2VecTextModel
|
| 85 |
+
|
| 86 |
+
>>> # Initializing a Data2VecText facebook/data2vec-text-base style configuration
|
| 87 |
+
>>> configuration = Data2VecTextConfig()
|
| 88 |
+
|
| 89 |
+
>>> # Initializing a model (with random weights) from the facebook/data2vec-text-base style configuration
|
| 90 |
+
>>> model = Data2VecTextModel(configuration)
|
| 91 |
+
|
| 92 |
+
>>> # Accessing the model configuration
|
| 93 |
+
>>> configuration = model.config
|
| 94 |
+
```"""
|
| 95 |
+
|
| 96 |
+
model_type = "data2vec-text"
|
| 97 |
+
|
| 98 |
+
def __init__(
|
| 99 |
+
self,
|
| 100 |
+
vocab_size=30522,
|
| 101 |
+
hidden_size=768,
|
| 102 |
+
num_hidden_layers=12,
|
| 103 |
+
num_attention_heads=12,
|
| 104 |
+
intermediate_size=3072,
|
| 105 |
+
hidden_act="gelu",
|
| 106 |
+
hidden_dropout_prob=0.1,
|
| 107 |
+
attention_probs_dropout_prob=0.1,
|
| 108 |
+
max_position_embeddings=512,
|
| 109 |
+
type_vocab_size=2,
|
| 110 |
+
initializer_range=0.02,
|
| 111 |
+
layer_norm_eps=1e-12,
|
| 112 |
+
pad_token_id=1,
|
| 113 |
+
bos_token_id=0,
|
| 114 |
+
eos_token_id=2,
|
| 115 |
+
position_embedding_type="absolute",
|
| 116 |
+
use_cache=True,
|
| 117 |
+
classifier_dropout=None,
|
| 118 |
+
**kwargs,
|
| 119 |
+
):
|
| 120 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 121 |
+
|
| 122 |
+
self.vocab_size = vocab_size
|
| 123 |
+
self.hidden_size = hidden_size
|
| 124 |
+
self.num_hidden_layers = num_hidden_layers
|
| 125 |
+
self.num_attention_heads = num_attention_heads
|
| 126 |
+
self.hidden_act = hidden_act
|
| 127 |
+
self.intermediate_size = intermediate_size
|
| 128 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 129 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 130 |
+
self.max_position_embeddings = max_position_embeddings
|
| 131 |
+
self.type_vocab_size = type_vocab_size
|
| 132 |
+
self.initializer_range = initializer_range
|
| 133 |
+
self.layer_norm_eps = layer_norm_eps
|
| 134 |
+
self.position_embedding_type = position_embedding_type
|
| 135 |
+
self.use_cache = use_cache
|
| 136 |
+
self.classifier_dropout = classifier_dropout
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Data2VecTextOnnxConfig(OnnxConfig):
|
| 140 |
+
@property
|
| 141 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 142 |
+
if self.task == "multiple-choice":
|
| 143 |
+
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
| 144 |
+
else:
|
| 145 |
+
dynamic_axis = {0: "batch", 1: "sequence"}
|
| 146 |
+
return OrderedDict(
|
| 147 |
+
[
|
| 148 |
+
("input_ids", dynamic_axis),
|
| 149 |
+
("attention_mask", dynamic_axis),
|
| 150 |
+
]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
__all__ = ["Data2VecTextConfig", "Data2VecTextOnnxConfig"]
|
docs/transformers/build/lib/transformers/models/data2vec/configuration_data2vec_vision.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright Meta Platforms and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Data2VecVision model configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from typing import Mapping
|
| 19 |
+
|
| 20 |
+
from packaging import version
|
| 21 |
+
|
| 22 |
+
from ...configuration_utils import PretrainedConfig
|
| 23 |
+
from ...onnx import OnnxConfig
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Data2VecVisionConfig(PretrainedConfig):
|
| 31 |
+
r"""
|
| 32 |
+
This is the configuration class to store the configuration of a [`Data2VecVisionModel`]. It is used to instantiate
|
| 33 |
+
an Data2VecVision model according to the specified arguments, defining the model architecture. Instantiating a
|
| 34 |
+
configuration with the defaults will yield a similar configuration to that of the Data2VecVision
|
| 35 |
+
[facebook/data2vec-vision-base](https://huggingface.co/facebook/data2vec-vision-base) architecture.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 39 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 40 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 41 |
+
Number of hidden layers in the Transformer encoder.
|
| 42 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 43 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 44 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 45 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 46 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 47 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 48 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 49 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 50 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 51 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 52 |
+
The dropout ratio for the attention probabilities.
|
| 53 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 54 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 55 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 56 |
+
The epsilon used by the layer normalization layers.
|
| 57 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 58 |
+
The size (resolution) of each image.
|
| 59 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 60 |
+
The size (resolution) of each patch.
|
| 61 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 62 |
+
The number of input channels.
|
| 63 |
+
use_mask_token (`bool`, *optional*, defaults to `False`):
|
| 64 |
+
Whether to use a mask token for masked image modeling.
|
| 65 |
+
use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
|
| 66 |
+
Whether to use BERT-style absolute position embeddings.
|
| 67 |
+
use_relative_position_bias (`bool`, *optional*, defaults to `False`):
|
| 68 |
+
Whether to use T5-style relative position embeddings in the self-attention layers.
|
| 69 |
+
use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
|
| 70 |
+
Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
|
| 71 |
+
layer_scale_init_value (`float`, *optional*, defaults to 0.1):
|
| 72 |
+
Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
|
| 73 |
+
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
| 74 |
+
Stochastic depth rate per sample (when applied in the main path of residual layers).
|
| 75 |
+
use_mean_pooling (`bool`, *optional*, defaults to `True`):
|
| 76 |
+
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
|
| 77 |
+
CLS token, before applying the classification head.
|
| 78 |
+
out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
|
| 79 |
+
Indices of the feature maps to use for semantic segmentation.
|
| 80 |
+
pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
|
| 81 |
+
Pooling scales used in Pooling Pyramid Module applied on the last feature map.
|
| 82 |
+
use_auxiliary_head (`bool`, *optional*, defaults to `True`):
|
| 83 |
+
Whether to use an auxiliary head during training.
|
| 84 |
+
auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
|
| 85 |
+
Weight of the cross-entropy loss of the auxiliary head.
|
| 86 |
+
auxiliary_channels (`int`, *optional*, defaults to 256):
|
| 87 |
+
Number of channels to use in the auxiliary head.
|
| 88 |
+
auxiliary_num_convs (`int`, *optional*, defaults to 1):
|
| 89 |
+
Number of convolutional layers to use in the auxiliary head.
|
| 90 |
+
auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
|
| 91 |
+
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
| 92 |
+
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
| 93 |
+
The index that is ignored by the loss function of the semantic segmentation model.
|
| 94 |
+
|
| 95 |
+
Example:
|
| 96 |
+
|
| 97 |
+
```python
|
| 98 |
+
>>> from transformers import Data2VecVisionConfig, Data2VecVisionModel
|
| 99 |
+
|
| 100 |
+
>>> # Initializing a Data2VecVision data2vec_vision-base-patch16-224-in22k style configuration
|
| 101 |
+
>>> configuration = Data2VecVisionConfig()
|
| 102 |
+
|
| 103 |
+
>>> # Initializing a model (with random weights) from the data2vec_vision-base-patch16-224-in22k style configuration
|
| 104 |
+
>>> model = Data2VecVisionModel(configuration)
|
| 105 |
+
|
| 106 |
+
>>> # Accessing the model configuration
|
| 107 |
+
>>> configuration = model.config
|
| 108 |
+
```"""
|
| 109 |
+
|
| 110 |
+
model_type = "data2vec-vision"
|
| 111 |
+
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
hidden_size=768,
|
| 115 |
+
num_hidden_layers=12,
|
| 116 |
+
num_attention_heads=12,
|
| 117 |
+
intermediate_size=3072,
|
| 118 |
+
hidden_act="gelu",
|
| 119 |
+
hidden_dropout_prob=0.0,
|
| 120 |
+
attention_probs_dropout_prob=0.0,
|
| 121 |
+
initializer_range=0.02,
|
| 122 |
+
layer_norm_eps=1e-12,
|
| 123 |
+
image_size=224,
|
| 124 |
+
patch_size=16,
|
| 125 |
+
num_channels=3,
|
| 126 |
+
use_mask_token=False,
|
| 127 |
+
use_absolute_position_embeddings=False,
|
| 128 |
+
use_relative_position_bias=False,
|
| 129 |
+
use_shared_relative_position_bias=False,
|
| 130 |
+
layer_scale_init_value=0.1,
|
| 131 |
+
drop_path_rate=0.1,
|
| 132 |
+
use_mean_pooling=True,
|
| 133 |
+
out_indices=[3, 5, 7, 11],
|
| 134 |
+
pool_scales=[1, 2, 3, 6],
|
| 135 |
+
use_auxiliary_head=True,
|
| 136 |
+
auxiliary_loss_weight=0.4,
|
| 137 |
+
auxiliary_channels=256,
|
| 138 |
+
auxiliary_num_convs=1,
|
| 139 |
+
auxiliary_concat_input=False,
|
| 140 |
+
semantic_loss_ignore_index=255,
|
| 141 |
+
**kwargs,
|
| 142 |
+
):
|
| 143 |
+
super().__init__(**kwargs)
|
| 144 |
+
|
| 145 |
+
self.hidden_size = hidden_size
|
| 146 |
+
self.num_hidden_layers = num_hidden_layers
|
| 147 |
+
self.num_attention_heads = num_attention_heads
|
| 148 |
+
self.intermediate_size = intermediate_size
|
| 149 |
+
self.hidden_act = hidden_act
|
| 150 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 151 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 152 |
+
self.initializer_range = initializer_range
|
| 153 |
+
self.layer_norm_eps = layer_norm_eps
|
| 154 |
+
|
| 155 |
+
self.image_size = image_size
|
| 156 |
+
self.patch_size = patch_size
|
| 157 |
+
self.num_channels = num_channels
|
| 158 |
+
self.use_mask_token = use_mask_token
|
| 159 |
+
self.use_absolute_position_embeddings = use_absolute_position_embeddings
|
| 160 |
+
self.use_relative_position_bias = use_relative_position_bias
|
| 161 |
+
self.use_shared_relative_position_bias = use_shared_relative_position_bias
|
| 162 |
+
self.layer_scale_init_value = layer_scale_init_value
|
| 163 |
+
self.drop_path_rate = drop_path_rate
|
| 164 |
+
self.use_mean_pooling = use_mean_pooling
|
| 165 |
+
# decode head attributes (semantic segmentation)
|
| 166 |
+
self.out_indices = out_indices
|
| 167 |
+
self.pool_scales = pool_scales
|
| 168 |
+
# auxiliary head attributes (semantic segmentation)
|
| 169 |
+
self.use_auxiliary_head = use_auxiliary_head
|
| 170 |
+
self.auxiliary_loss_weight = auxiliary_loss_weight
|
| 171 |
+
self.auxiliary_channels = auxiliary_channels
|
| 172 |
+
self.auxiliary_num_convs = auxiliary_num_convs
|
| 173 |
+
self.auxiliary_concat_input = auxiliary_concat_input
|
| 174 |
+
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
|
| 178 |
+
class Data2VecVisionOnnxConfig(OnnxConfig):
|
| 179 |
+
torch_onnx_minimum_version = version.parse("1.11")
|
| 180 |
+
|
| 181 |
+
@property
|
| 182 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 183 |
+
return OrderedDict(
|
| 184 |
+
[
|
| 185 |
+
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
| 186 |
+
]
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
@property
|
| 190 |
+
def atol_for_validation(self) -> float:
|
| 191 |
+
return 1e-4
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
__all__ = ["Data2VecVisionConfig", "Data2VecVisionOnnxConfig"]
|
docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert Wav2Vec2 checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
from functools import reduce
|
| 20 |
+
|
| 21 |
+
import fairseq
|
| 22 |
+
import torch
|
| 23 |
+
from datasets import load_dataset
|
| 24 |
+
|
| 25 |
+
from transformers import Wav2Vec2Processor, logging
|
| 26 |
+
from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig
|
| 27 |
+
|
| 28 |
+
# Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
|
| 29 |
+
from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy # noqa: F401
|
| 30 |
+
from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logging.set_verbosity_info()
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
MAPPING = {
|
| 37 |
+
"post_extract_proj": "feature_projection.projection",
|
| 38 |
+
"models.0.layer_norm": "feature_projection.layer_norm",
|
| 39 |
+
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
| 40 |
+
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
| 41 |
+
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
| 42 |
+
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
| 43 |
+
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
| 44 |
+
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
| 45 |
+
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
| 46 |
+
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
| 47 |
+
"encoder.layer_norm": "encoder.layer_norm",
|
| 48 |
+
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
| 49 |
+
"w2v_encoder.proj": "lm_head",
|
| 50 |
+
"mask_emb": "masked_spec_embed",
|
| 51 |
+
}
|
| 52 |
+
TOP_LEVEL_KEYS = [
|
| 53 |
+
"lm_head",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
| 58 |
+
for attribute in key.split("."):
|
| 59 |
+
hf_pointer = getattr(hf_pointer, attribute)
|
| 60 |
+
|
| 61 |
+
if weight_type is not None:
|
| 62 |
+
hf_shape = getattr(hf_pointer, weight_type).shape
|
| 63 |
+
else:
|
| 64 |
+
hf_shape = hf_pointer.shape
|
| 65 |
+
|
| 66 |
+
if hf_shape != value.shape:
|
| 67 |
+
raise ValueError(
|
| 68 |
+
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
| 69 |
+
f" {value.shape} for {full_name}"
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if weight_type == "weight":
|
| 73 |
+
hf_pointer.weight.data = value
|
| 74 |
+
elif weight_type == "weight_g":
|
| 75 |
+
hf_pointer.weight_g.data = value
|
| 76 |
+
elif weight_type == "weight_v":
|
| 77 |
+
hf_pointer.weight_v.data = value
|
| 78 |
+
elif weight_type == "bias":
|
| 79 |
+
hf_pointer.bias.data = value
|
| 80 |
+
else:
|
| 81 |
+
hf_pointer.data = value
|
| 82 |
+
|
| 83 |
+
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def recursively_load_weights(fairseq_model, hf_model, is_headless):
|
| 87 |
+
unused_weights = []
|
| 88 |
+
fairseq_dict = fairseq_model.state_dict()
|
| 89 |
+
|
| 90 |
+
if not is_headless:
|
| 91 |
+
feature_extractor = hf_model.data2vec_audio.feature_extractor
|
| 92 |
+
pos_conv_embedding = hf_model.data2vec_audio.encoder.pos_conv_embed
|
| 93 |
+
|
| 94 |
+
else:
|
| 95 |
+
feature_extractor = hf_model.feature_extractor
|
| 96 |
+
pos_conv_embedding = hf_model.encoder.pos_conv_embed
|
| 97 |
+
|
| 98 |
+
for name, value in fairseq_dict.items():
|
| 99 |
+
is_used = False
|
| 100 |
+
if "conv_layers" in name:
|
| 101 |
+
load_conv_layer(
|
| 102 |
+
name,
|
| 103 |
+
value,
|
| 104 |
+
feature_extractor,
|
| 105 |
+
unused_weights,
|
| 106 |
+
)
|
| 107 |
+
is_used = True
|
| 108 |
+
elif "pos_conv" in name:
|
| 109 |
+
load_pos_conv_layer(
|
| 110 |
+
name,
|
| 111 |
+
value,
|
| 112 |
+
pos_conv_embedding,
|
| 113 |
+
unused_weights,
|
| 114 |
+
)
|
| 115 |
+
is_used = True
|
| 116 |
+
else:
|
| 117 |
+
for key, mapped_key in MAPPING.items():
|
| 118 |
+
if not is_headless:
|
| 119 |
+
mapped_key = "data2vec_audio." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
|
| 120 |
+
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
| 121 |
+
is_used = True
|
| 122 |
+
if "*" in mapped_key:
|
| 123 |
+
layer_index = name.split(key)[0].split(".")[-2]
|
| 124 |
+
mapped_key = mapped_key.replace("*", layer_index)
|
| 125 |
+
if "weight_g" in name:
|
| 126 |
+
weight_type = "weight_g"
|
| 127 |
+
elif "weight_v" in name:
|
| 128 |
+
weight_type = "weight_v"
|
| 129 |
+
elif "bias" in name:
|
| 130 |
+
weight_type = "bias"
|
| 131 |
+
elif "weight" in name:
|
| 132 |
+
# TODO: don't match quantizer.weight_proj
|
| 133 |
+
weight_type = "weight"
|
| 134 |
+
else:
|
| 135 |
+
weight_type = None
|
| 136 |
+
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
| 137 |
+
continue
|
| 138 |
+
if not is_used:
|
| 139 |
+
unused_weights.append(name)
|
| 140 |
+
|
| 141 |
+
logger.warning(f"Unused weights: {unused_weights}")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def access_by_string(module, path):
|
| 145 |
+
names = path.split(".")
|
| 146 |
+
return reduce(getattr, names, module)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def set_weights(full_name, module, fsq_value, hf_weight_path):
|
| 150 |
+
hf_weight = access_by_string(module, hf_weight_path)
|
| 151 |
+
hf_value = hf_weight.data
|
| 152 |
+
|
| 153 |
+
if fsq_value.shape != hf_value.shape:
|
| 154 |
+
raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.")
|
| 155 |
+
hf_weight.data = fsq_value
|
| 156 |
+
logger.info(f"{full_name} was correctly initialized from {hf_weight_path}.")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def load_conv_layer(full_name, value, feature_extractor, unused_weights):
|
| 160 |
+
name = full_name.split("conv_layers.")[-1]
|
| 161 |
+
items = name.split(".")
|
| 162 |
+
layer_id = int(items[0])
|
| 163 |
+
type_id = int(items[1])
|
| 164 |
+
|
| 165 |
+
weight_type = name.split(".")[-1]
|
| 166 |
+
if type_id == 0:
|
| 167 |
+
layer_type = "conv"
|
| 168 |
+
elif type_id == 2:
|
| 169 |
+
layer_type = "layer_norm"
|
| 170 |
+
else:
|
| 171 |
+
unused_weights.append(full_name)
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{weight_type}")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights):
|
| 178 |
+
name = full_name.split("pos_conv.")[-1]
|
| 179 |
+
items = name.split(".")
|
| 180 |
+
layer_id = int(items[0])
|
| 181 |
+
type_id = int(items[1])
|
| 182 |
+
|
| 183 |
+
weight_type = name.split(".")[-1]
|
| 184 |
+
if type_id != 0:
|
| 185 |
+
unused_weights.append(full_name)
|
| 186 |
+
return
|
| 187 |
+
else:
|
| 188 |
+
layer_type = "conv"
|
| 189 |
+
|
| 190 |
+
set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.{layer_type}.{weight_type}")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@torch.no_grad()
|
| 194 |
+
def convert_wav2vec2_checkpoint(
|
| 195 |
+
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
|
| 196 |
+
):
|
| 197 |
+
"""
|
| 198 |
+
Copy/paste/tweak model's weights to transformers design.
|
| 199 |
+
"""
|
| 200 |
+
if config_path is not None:
|
| 201 |
+
config = Data2VecAudioConfig.from_pretrained(config_path)
|
| 202 |
+
else:
|
| 203 |
+
config = Data2VecAudioConfig()
|
| 204 |
+
|
| 205 |
+
if not is_finetuned:
|
| 206 |
+
# Modify final_proj layer name
|
| 207 |
+
hf_wav2vec = Data2VecAudioModel(config)
|
| 208 |
+
data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)
|
| 209 |
+
|
| 210 |
+
state_dict = torch.load(checkpoint_path, weights_only=True)
|
| 211 |
+
state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
|
| 212 |
+
state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
|
| 213 |
+
converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
|
| 214 |
+
torch.save(state_dict, converted_ckpt)
|
| 215 |
+
else:
|
| 216 |
+
hf_wav2vec = Data2VecAudioForCTC(config)
|
| 217 |
+
converted_ckpt = checkpoint_path
|
| 218 |
+
|
| 219 |
+
def load_data2vec(path):
|
| 220 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
|
| 221 |
+
return model[0].eval()
|
| 222 |
+
|
| 223 |
+
model = load_data2vec(converted_ckpt)
|
| 224 |
+
|
| 225 |
+
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
|
| 226 |
+
|
| 227 |
+
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
|
| 228 |
+
|
| 229 |
+
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
|
| 230 |
+
input_audio = [x["array"] for x in ds[:4]["audio"]]
|
| 231 |
+
|
| 232 |
+
inputs = processor(input_audio, return_tensors="pt", padding=True)
|
| 233 |
+
|
| 234 |
+
input_values = inputs.input_values
|
| 235 |
+
attention_mask = inputs.attention_mask
|
| 236 |
+
# input_values = inputs.input_values[:, :-1]
|
| 237 |
+
# attention_mask = inputs.attention_mask[:, :-1]
|
| 238 |
+
|
| 239 |
+
hf_wav2vec.eval()
|
| 240 |
+
model.eval()
|
| 241 |
+
if is_finetuned:
|
| 242 |
+
their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
|
| 243 |
+
"encoder_out"
|
| 244 |
+
].transpose(0, 1)
|
| 245 |
+
our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"]
|
| 246 |
+
|
| 247 |
+
pred_ids = torch.argmax(our_output, dim=-1)
|
| 248 |
+
output_string = processor.batch_decode(pred_ids)
|
| 249 |
+
|
| 250 |
+
print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}")
|
| 251 |
+
else:
|
| 252 |
+
their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
|
| 253 |
+
"layer_results"
|
| 254 |
+
][-1][0].transpose(0, 1)
|
| 255 |
+
our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"]
|
| 256 |
+
|
| 257 |
+
print(our_output.shape, their_output.shape)
|
| 258 |
+
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
| 259 |
+
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
| 260 |
+
success = torch.allclose(our_output, their_output, atol=1e-3)
|
| 261 |
+
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
| 262 |
+
if not success:
|
| 263 |
+
raise Exception("Something went wRoNg")
|
| 264 |
+
|
| 265 |
+
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
| 266 |
+
|
| 267 |
+
if is_finetuned:
|
| 268 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 269 |
+
else:
|
| 270 |
+
processor.feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
|
| 274 |
+
parser = argparse.ArgumentParser()
|
| 275 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
| 276 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
| 277 |
+
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
| 278 |
+
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
| 279 |
+
parser.add_argument(
|
| 280 |
+
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
| 281 |
+
)
|
| 282 |
+
args = parser.parse_args()
|
| 283 |
+
convert_wav2vec2_checkpoint(
|
| 284 |
+
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
|
| 285 |
+
)
|
docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert data2vec checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import pathlib
|
| 20 |
+
|
| 21 |
+
import fairseq
|
| 22 |
+
import torch
|
| 23 |
+
from fairseq.modules import TransformerSentenceEncoderLayer
|
| 24 |
+
from packaging import version
|
| 25 |
+
|
| 26 |
+
from transformers import (
|
| 27 |
+
Data2VecTextConfig,
|
| 28 |
+
Data2VecTextForMaskedLM,
|
| 29 |
+
Data2VecTextForSequenceClassification,
|
| 30 |
+
Data2VecTextModel,
|
| 31 |
+
)
|
| 32 |
+
from transformers.models.bert.modeling_bert import (
|
| 33 |
+
BertIntermediate,
|
| 34 |
+
BertLayer,
|
| 35 |
+
BertOutput,
|
| 36 |
+
BertSelfAttention,
|
| 37 |
+
BertSelfOutput,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# IMPORTANT: In order for this script to run, please make sure to download the dictionary: `dict.txt` from wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
|
| 41 |
+
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
|
| 42 |
+
from transformers.utils import logging
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
| 46 |
+
raise Exception("requires fairseq >= 0.9.0")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logging.set_verbosity_info()
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def convert_data2vec_checkpoint_to_pytorch(
|
| 56 |
+
data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
|
| 57 |
+
):
|
| 58 |
+
"""
|
| 59 |
+
Copy/paste/tweak data2vec's weights to our BERT structure.
|
| 60 |
+
"""
|
| 61 |
+
data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
|
| 62 |
+
data2vec = Data2VecTextModel.from_pretrained(
|
| 63 |
+
data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
|
| 64 |
+
)
|
| 65 |
+
data2vec.eval() # disable dropout
|
| 66 |
+
data2vec_model = data2vec.models[0]
|
| 67 |
+
data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
|
| 68 |
+
config = Data2VecTextConfig(
|
| 69 |
+
vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
|
| 70 |
+
hidden_size=data2vec_model.args.encoder_embed_dim,
|
| 71 |
+
num_hidden_layers=data2vec_model.args.encoder_layers,
|
| 72 |
+
num_attention_heads=data2vec_model.args.encoder_attention_heads,
|
| 73 |
+
intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
|
| 74 |
+
max_position_embeddings=514,
|
| 75 |
+
type_vocab_size=1,
|
| 76 |
+
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
| 77 |
+
)
|
| 78 |
+
if classification_head:
|
| 79 |
+
config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
|
| 80 |
+
print("Our BERT config:", config)
|
| 81 |
+
|
| 82 |
+
model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
|
| 83 |
+
model.eval()
|
| 84 |
+
|
| 85 |
+
# Now let's copy all the weights.
|
| 86 |
+
# Embeddings
|
| 87 |
+
model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
|
| 88 |
+
model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
|
| 89 |
+
model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
|
| 90 |
+
model.data2vec_text.embeddings.token_type_embeddings.weight
|
| 91 |
+
) # just zero them out b/c data2vec doesn't use them.
|
| 92 |
+
model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
|
| 93 |
+
model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias
|
| 94 |
+
|
| 95 |
+
for i in range(config.num_hidden_layers):
|
| 96 |
+
# Encoder: start of layer
|
| 97 |
+
layer: BertLayer = model.data2vec_text.encoder.layer[i]
|
| 98 |
+
data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]
|
| 99 |
+
|
| 100 |
+
# self attention
|
| 101 |
+
self_attn: BertSelfAttention = layer.attention.self
|
| 102 |
+
assert data2vec_layer.self_attn.k_proj.weight.data.shape == torch.Size(
|
| 103 |
+
(config.hidden_size, config.hidden_size)
|
| 104 |
+
), (
|
| 105 |
+
"Shape for data2vec_layer.self_attn.k_proj.weight.data should be"
|
| 106 |
+
f" {torch.Size((config.hidden_size, config.hidden_size))}"
|
| 107 |
+
)
|
| 108 |
+
assert data2vec_layer.self_attn.q_proj.weight.data.shape == torch.Size(
|
| 109 |
+
(config.hidden_size, config.hidden_size)
|
| 110 |
+
), (
|
| 111 |
+
"Shape for data2vec_layer.self_attn.q_proj.weight.data should be"
|
| 112 |
+
f" {torch.Size((config.hidden_size, config.hidden_size))}"
|
| 113 |
+
)
|
| 114 |
+
assert data2vec_layer.self_attn.v_proj.weight.data.shape == torch.Size(
|
| 115 |
+
(config.hidden_size, config.hidden_size)
|
| 116 |
+
), (
|
| 117 |
+
"Shape for data2vec_layer.self_attn.v_proj.weight.data should be"
|
| 118 |
+
f" {torch.Size((config.hidden_size, config.hidden_size))}"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
|
| 122 |
+
self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
|
| 123 |
+
self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
|
| 124 |
+
self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
|
| 125 |
+
self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
|
| 126 |
+
self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias
|
| 127 |
+
|
| 128 |
+
# self-attention output
|
| 129 |
+
self_output: BertSelfOutput = layer.attention.output
|
| 130 |
+
assert self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape, (
|
| 131 |
+
f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
|
| 132 |
+
)
|
| 133 |
+
self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
|
| 134 |
+
self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
|
| 135 |
+
self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
|
| 136 |
+
self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias
|
| 137 |
+
|
| 138 |
+
# intermediate
|
| 139 |
+
intermediate: BertIntermediate = layer.intermediate
|
| 140 |
+
assert intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape, (
|
| 141 |
+
f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
|
| 142 |
+
)
|
| 143 |
+
intermediate.dense.weight = data2vec_layer.fc1.weight
|
| 144 |
+
intermediate.dense.bias = data2vec_layer.fc1.bias
|
| 145 |
+
|
| 146 |
+
# output
|
| 147 |
+
bert_output: BertOutput = layer.output
|
| 148 |
+
assert bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape, (
|
| 149 |
+
f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
|
| 150 |
+
)
|
| 151 |
+
bert_output.dense.weight = data2vec_layer.fc2.weight
|
| 152 |
+
bert_output.dense.bias = data2vec_layer.fc2.bias
|
| 153 |
+
bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
|
| 154 |
+
bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
|
| 155 |
+
# end of layer
|
| 156 |
+
|
| 157 |
+
if classification_head:
|
| 158 |
+
model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
|
| 159 |
+
model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
|
| 160 |
+
model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
|
| 161 |
+
model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
|
| 162 |
+
else:
|
| 163 |
+
# LM Head
|
| 164 |
+
model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
|
| 165 |
+
model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
|
| 166 |
+
model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
|
| 167 |
+
model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
|
| 168 |
+
model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
|
| 169 |
+
model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias
|
| 170 |
+
|
| 171 |
+
# Let's check that we get the same results.
|
| 172 |
+
input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
| 173 |
+
|
| 174 |
+
our_output = model(input_ids)[0]
|
| 175 |
+
if classification_head:
|
| 176 |
+
their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
|
| 177 |
+
else:
|
| 178 |
+
their_output = data2vec_model(input_ids)[0]
|
| 179 |
+
print(our_output.shape, their_output.shape)
|
| 180 |
+
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
| 181 |
+
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
| 182 |
+
success = torch.allclose(our_output, their_output, atol=1e-3)
|
| 183 |
+
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
| 184 |
+
if not success:
|
| 185 |
+
raise Exception("Something went wRoNg")
|
| 186 |
+
|
| 187 |
+
pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
|
| 188 |
+
print(f"Saving model to {pytorch_dump_folder_path}")
|
| 189 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
if __name__ == "__main__":
|
| 193 |
+
parser = argparse.ArgumentParser()
|
| 194 |
+
# Required parameters
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
| 197 |
+
)
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
| 200 |
+
)
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--classification_head", action="store_true", help="Whether to convert a final classification head."
|
| 203 |
+
)
|
| 204 |
+
args = parser.parse_args()
|
| 205 |
+
convert_data2vec_checkpoint_to_pytorch(
|
| 206 |
+
args.checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
|
| 207 |
+
)
|
docs/transformers/build/lib/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from timm.models import create_model
|
| 9 |
+
|
| 10 |
+
from transformers import (
|
| 11 |
+
BeitImageProcessor,
|
| 12 |
+
Data2VecVisionConfig,
|
| 13 |
+
Data2VecVisionForImageClassification,
|
| 14 |
+
Data2VecVisionModel,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_rename_keys(config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec."):
|
| 19 |
+
prefix = "backbone." if is_semantic else ""
|
| 20 |
+
|
| 21 |
+
rename_keys = []
|
| 22 |
+
for i in range(config.num_hidden_layers):
|
| 23 |
+
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
| 24 |
+
rename_keys.append(
|
| 25 |
+
(f"{prefix}blocks.{i}.norm1.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_before.weight")
|
| 26 |
+
)
|
| 27 |
+
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_before.bias"))
|
| 28 |
+
rename_keys.append(
|
| 29 |
+
(f"{prefix}blocks.{i}.attn.proj.weight", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.weight")
|
| 30 |
+
)
|
| 31 |
+
rename_keys.append(
|
| 32 |
+
(f"{prefix}blocks.{i}.attn.proj.bias", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.bias")
|
| 33 |
+
)
|
| 34 |
+
rename_keys.append(
|
| 35 |
+
(f"{prefix}blocks.{i}.norm2.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_after.weight")
|
| 36 |
+
)
|
| 37 |
+
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_after.bias"))
|
| 38 |
+
rename_keys.append(
|
| 39 |
+
(f"{prefix}blocks.{i}.mlp.fc1.weight", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.weight")
|
| 40 |
+
)
|
| 41 |
+
rename_keys.append(
|
| 42 |
+
(f"{prefix}blocks.{i}.mlp.fc1.bias", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.bias")
|
| 43 |
+
)
|
| 44 |
+
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"{hf_prefix}encoder.layer.{i}.output.dense.weight"))
|
| 45 |
+
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"{hf_prefix}encoder.layer.{i}.output.dense.bias"))
|
| 46 |
+
|
| 47 |
+
# projection layer + position embeddings
|
| 48 |
+
rename_keys.extend(
|
| 49 |
+
[
|
| 50 |
+
(f"{prefix}cls_token", f"{hf_prefix}embeddings.cls_token"),
|
| 51 |
+
(f"{prefix}patch_embed.proj.weight", f"{hf_prefix}embeddings.patch_embeddings.projection.weight"),
|
| 52 |
+
(f"{prefix}patch_embed.proj.bias", f"{hf_prefix}embeddings.patch_embeddings.projection.bias"),
|
| 53 |
+
]
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if has_lm_head:
|
| 57 |
+
# mask token + shared relative position bias + layernorm
|
| 58 |
+
rename_keys.extend(
|
| 59 |
+
[
|
| 60 |
+
("mask_token", f"{hf_prefix}embeddings.mask_token"),
|
| 61 |
+
(
|
| 62 |
+
"rel_pos_bias.relative_position_bias_table",
|
| 63 |
+
f"{hf_prefix}encoder.relative_position_bias.relative_position_bias_table",
|
| 64 |
+
),
|
| 65 |
+
(
|
| 66 |
+
"rel_pos_bias.relative_position_index",
|
| 67 |
+
f"{hf_prefix}encoder.relative_position_bias.relative_position_index",
|
| 68 |
+
),
|
| 69 |
+
("norm.weight", "layernorm.weight"),
|
| 70 |
+
("norm.bias", "layernorm.bias"),
|
| 71 |
+
]
|
| 72 |
+
)
|
| 73 |
+
elif is_semantic:
|
| 74 |
+
# semantic segmentation classification heads
|
| 75 |
+
rename_keys.extend(
|
| 76 |
+
[
|
| 77 |
+
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
| 78 |
+
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
| 79 |
+
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
| 80 |
+
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
| 81 |
+
]
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
# layernorm + classification head
|
| 85 |
+
rename_keys.extend(
|
| 86 |
+
[
|
| 87 |
+
("fc_norm.weight", f"{hf_prefix}pooler.layernorm.weight"),
|
| 88 |
+
("fc_norm.bias", f"{hf_prefix}pooler.layernorm.bias"),
|
| 89 |
+
("head.weight", "classifier.weight"),
|
| 90 |
+
("head.bias", "classifier.bias"),
|
| 91 |
+
]
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
return rename_keys
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec_vision."):
|
| 98 |
+
for i in range(config.num_hidden_layers):
|
| 99 |
+
prefix = "backbone." if is_semantic else ""
|
| 100 |
+
# queries, keys and values
|
| 101 |
+
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
| 102 |
+
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
| 103 |
+
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
| 104 |
+
|
| 105 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
| 106 |
+
: config.hidden_size, :
|
| 107 |
+
]
|
| 108 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
| 109 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
| 110 |
+
config.hidden_size : config.hidden_size * 2, :
|
| 111 |
+
]
|
| 112 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
| 113 |
+
-config.hidden_size :, :
|
| 114 |
+
]
|
| 115 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
| 116 |
+
|
| 117 |
+
# gamma_1 and gamma_2
|
| 118 |
+
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
| 119 |
+
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
| 120 |
+
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
| 121 |
+
|
| 122 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_1"] = gamma_1
|
| 123 |
+
state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_2"] = gamma_2
|
| 124 |
+
|
| 125 |
+
# relative_position bias table + index
|
| 126 |
+
if not has_lm_head:
|
| 127 |
+
# each layer has its own relative position bias
|
| 128 |
+
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
| 129 |
+
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
| 130 |
+
|
| 131 |
+
state_dict[
|
| 132 |
+
f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
| 133 |
+
] = table
|
| 134 |
+
state_dict[
|
| 135 |
+
f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
| 136 |
+
] = index
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_args():
|
| 140 |
+
parser = argparse.ArgumentParser(
|
| 141 |
+
"Convert Data2VecVision to HF for image classification and pretraining", add_help=False
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument("--hf_checkpoint_name", type=str)
|
| 144 |
+
parser.add_argument("--input_size", default=224, type=int, help="images input size")
|
| 145 |
+
parser.add_argument("--beit_checkpoint", default="", help="beit checkpoint")
|
| 146 |
+
|
| 147 |
+
return parser.parse_args()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def load_beit_model(args, is_finetuned, is_large):
|
| 151 |
+
def load_state_dict(model, state_dict, prefix="", ignore_missing="relative_position_index"):
|
| 152 |
+
missing_keys = []
|
| 153 |
+
unexpected_keys = []
|
| 154 |
+
error_msgs = []
|
| 155 |
+
# copy state_dict so _load_from_state_dict can modify it
|
| 156 |
+
metadata = getattr(state_dict, "_metadata", None)
|
| 157 |
+
state_dict = state_dict.copy()
|
| 158 |
+
if metadata is not None:
|
| 159 |
+
state_dict._metadata = metadata
|
| 160 |
+
|
| 161 |
+
def load(module, prefix=""):
|
| 162 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
| 163 |
+
module._load_from_state_dict(
|
| 164 |
+
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
| 165 |
+
)
|
| 166 |
+
for name, child in module._modules.items():
|
| 167 |
+
if child is not None:
|
| 168 |
+
load(child, prefix + name + ".")
|
| 169 |
+
|
| 170 |
+
load(model, prefix=prefix)
|
| 171 |
+
|
| 172 |
+
warn_missing_keys = []
|
| 173 |
+
ignore_missing_keys = []
|
| 174 |
+
for key in missing_keys:
|
| 175 |
+
keep_flag = True
|
| 176 |
+
for ignore_key in ignore_missing.split("|"):
|
| 177 |
+
if ignore_key in key:
|
| 178 |
+
keep_flag = False
|
| 179 |
+
break
|
| 180 |
+
if keep_flag:
|
| 181 |
+
warn_missing_keys.append(key)
|
| 182 |
+
else:
|
| 183 |
+
ignore_missing_keys.append(key)
|
| 184 |
+
|
| 185 |
+
missing_keys = warn_missing_keys
|
| 186 |
+
|
| 187 |
+
if len(missing_keys) > 0:
|
| 188 |
+
print(
|
| 189 |
+
"Weights of {} not initialized from pretrained model: {}".format(
|
| 190 |
+
model.__class__.__name__, missing_keys
|
| 191 |
+
)
|
| 192 |
+
)
|
| 193 |
+
if len(unexpected_keys) > 0:
|
| 194 |
+
print("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys))
|
| 195 |
+
if len(ignore_missing_keys) > 0:
|
| 196 |
+
print(
|
| 197 |
+
"Ignored weights of {} not initialized from pretrained model: {}".format(
|
| 198 |
+
model.__class__.__name__, ignore_missing_keys
|
| 199 |
+
)
|
| 200 |
+
)
|
| 201 |
+
if len(error_msgs) > 0:
|
| 202 |
+
print("\n".join(error_msgs))
|
| 203 |
+
|
| 204 |
+
model_kwargs = {
|
| 205 |
+
"pretrained": False,
|
| 206 |
+
"use_shared_rel_pos_bias": True,
|
| 207 |
+
"use_abs_pos_emb": False,
|
| 208 |
+
"init_values": 0.1,
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
if is_finetuned:
|
| 212 |
+
model_kwargs.update(
|
| 213 |
+
{
|
| 214 |
+
"num_classes": 1000,
|
| 215 |
+
"use_mean_pooling": True,
|
| 216 |
+
"init_scale": 0.001,
|
| 217 |
+
"use_rel_pos_bias": True,
|
| 218 |
+
}
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
model = create_model(
|
| 222 |
+
"beit_large_patch16_224" if is_large else "beit_base_patch16_224",
|
| 223 |
+
**model_kwargs,
|
| 224 |
+
)
|
| 225 |
+
patch_size = model.patch_embed.patch_size
|
| 226 |
+
args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
|
| 227 |
+
checkpoint = torch.load(args.beit_checkpoint, map_location="cpu", weights_only=True)
|
| 228 |
+
|
| 229 |
+
print(f"Load ckpt from {args.beit_checkpoint}")
|
| 230 |
+
checkpoint_model = None
|
| 231 |
+
for model_key in ("model", "module"):
|
| 232 |
+
if model_key in checkpoint:
|
| 233 |
+
checkpoint_model = checkpoint[model_key]
|
| 234 |
+
print(f"Load state_dict by model_key = {model_key}")
|
| 235 |
+
break
|
| 236 |
+
|
| 237 |
+
all_keys = list(checkpoint_model.keys())
|
| 238 |
+
for key in all_keys:
|
| 239 |
+
if "relative_position_index" in key:
|
| 240 |
+
checkpoint_model.pop(key)
|
| 241 |
+
|
| 242 |
+
if "relative_position_bias_table" in key:
|
| 243 |
+
rel_pos_bias = checkpoint_model[key]
|
| 244 |
+
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
| 245 |
+
dst_num_pos, _ = model.state_dict()[key].size()
|
| 246 |
+
dst_patch_shape = model.patch_embed.patch_shape
|
| 247 |
+
if dst_patch_shape[0] != dst_patch_shape[1]:
|
| 248 |
+
raise NotImplementedError()
|
| 249 |
+
|
| 250 |
+
load_state_dict(model, checkpoint_model, prefix="")
|
| 251 |
+
|
| 252 |
+
return model
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def main():
|
| 256 |
+
args = get_args()
|
| 257 |
+
|
| 258 |
+
is_finetuned = "ft1k" in args.hf_checkpoint_name
|
| 259 |
+
is_large = "large" in args.hf_checkpoint_name
|
| 260 |
+
|
| 261 |
+
if is_finetuned:
|
| 262 |
+
# To convert Beit's data2vec_vision to HF you need to copy
|
| 263 |
+
# https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_finetune.py
|
| 264 |
+
# into this folder.
|
| 265 |
+
import modeling_finetune # noqa: F401
|
| 266 |
+
else:
|
| 267 |
+
# To convert Beit's data2vec_vision to HF you need to copy
|
| 268 |
+
# https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_cyclical.py
|
| 269 |
+
# into this folder
|
| 270 |
+
# IMPORTANT: Note that for now we've only converted the down-stream
|
| 271 |
+
# model and not the full pretrained model. This means for the integration
|
| 272 |
+
# test you need to add a `return x` after the following line:
|
| 273 |
+
# https://github.com/facebookresearch/data2vec_vision/blob/af9a36349aaed59ae66e69b5dabeef2d62fdc5da/beit/modeling_cyclical.py#L197
|
| 274 |
+
# to make the integration test pass.
|
| 275 |
+
import modeling_cyclical # noqa: F401
|
| 276 |
+
|
| 277 |
+
# 1. Create model config
|
| 278 |
+
config = Data2VecVisionConfig()
|
| 279 |
+
if is_finetuned:
|
| 280 |
+
config.use_relative_position_bias = True
|
| 281 |
+
config.use_shared_relative_position_bias = False
|
| 282 |
+
config.use_mean_pooling = True
|
| 283 |
+
config.num_labels = 1000
|
| 284 |
+
|
| 285 |
+
repo_id = "huggingface/label-files"
|
| 286 |
+
filename = "imagenet-1k-id2label.json"
|
| 287 |
+
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
| 288 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 289 |
+
config.id2label = id2label
|
| 290 |
+
config.label2id = {v: k for k, v in id2label.items()}
|
| 291 |
+
else:
|
| 292 |
+
config.use_relative_position_bias = False
|
| 293 |
+
config.use_shared_relative_position_bias = True
|
| 294 |
+
config.use_mean_pooling = False
|
| 295 |
+
|
| 296 |
+
if is_large:
|
| 297 |
+
config.hidden_size = 1024
|
| 298 |
+
config.intermediate_size = 4096
|
| 299 |
+
config.num_hidden_layers = 24
|
| 300 |
+
config.num_attention_heads = 16
|
| 301 |
+
|
| 302 |
+
# 2. Load Beit model
|
| 303 |
+
orig_model = load_beit_model(args, is_finetuned, is_large)
|
| 304 |
+
orig_model.eval()
|
| 305 |
+
|
| 306 |
+
# 3. Forward Beit model
|
| 307 |
+
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
| 308 |
+
image = Image.open("../../../../tests/fixtures/tests_samples/COCO/000000039769.png")
|
| 309 |
+
encoding = image_processor(images=image, return_tensors="pt")
|
| 310 |
+
pixel_values = encoding["pixel_values"]
|
| 311 |
+
|
| 312 |
+
orig_args = (pixel_values,) if is_finetuned else (pixel_values, None)
|
| 313 |
+
with torch.no_grad():
|
| 314 |
+
orig_model_output = orig_model(*orig_args)
|
| 315 |
+
|
| 316 |
+
# 4. Load HF Data2VecVision model
|
| 317 |
+
if is_finetuned:
|
| 318 |
+
hf_model = Data2VecVisionForImageClassification(config)
|
| 319 |
+
hf_model.eval()
|
| 320 |
+
has_lm_head = False
|
| 321 |
+
hf_prefix = "data2vec_vision."
|
| 322 |
+
else:
|
| 323 |
+
hf_model = Data2VecVisionModel(config)
|
| 324 |
+
hf_model.eval()
|
| 325 |
+
has_lm_head = True
|
| 326 |
+
hf_prefix = ""
|
| 327 |
+
|
| 328 |
+
rename_keys = create_rename_keys(config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
|
| 329 |
+
state_dict = orig_model.state_dict()
|
| 330 |
+
for src, dest in rename_keys:
|
| 331 |
+
val = state_dict.pop(src)
|
| 332 |
+
state_dict[dest] = val
|
| 333 |
+
|
| 334 |
+
read_in_q_k_v(state_dict, config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
|
| 335 |
+
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
| 336 |
+
print("HF missing", missing_keys)
|
| 337 |
+
print("HF unexpected_keys", unexpected_keys)
|
| 338 |
+
|
| 339 |
+
# 5. Forward HF Data2VecVision model
|
| 340 |
+
with torch.no_grad():
|
| 341 |
+
hf_model_output = hf_model(pixel_values)
|
| 342 |
+
|
| 343 |
+
hf_output = hf_model_output.logits if is_finetuned else hf_model_output.last_hidden_state
|
| 344 |
+
|
| 345 |
+
# 6. Compare
|
| 346 |
+
max_absolute_diff = torch.max(torch.abs(hf_output - orig_model_output)).item()
|
| 347 |
+
|
| 348 |
+
print(f"max_absolute_diff = {max_absolute_diff}")
|
| 349 |
+
success = torch.allclose(hf_output, orig_model_output, atol=1e-3)
|
| 350 |
+
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
| 351 |
+
if not success:
|
| 352 |
+
raise Exception("Something went wRoNg")
|
| 353 |
+
|
| 354 |
+
# 7. Save
|
| 355 |
+
print(f"Saving to {args.hf_checkpoint_name}")
|
| 356 |
+
hf_model.save_pretrained(args.hf_checkpoint_name)
|
| 357 |
+
image_processor.save_pretrained(args.hf_checkpoint_name)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
|
| 361 |
+
main()
|
| 362 |
+
# Run the following to convert checkpoints
|
| 363 |
+
# python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
|
| 364 |
+
# --beit_checkpoint ./pretrained_base.pt \
|
| 365 |
+
# --hf_checkpoint_name "./data2vec-vision-base"
|
| 366 |
+
# python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
|
| 367 |
+
# --beit_checkpoint ./finetuned_base.pt \
|
| 368 |
+
# --hf_checkpoint_name "./data2vec-vision-base-ft1k"
|
| 369 |
+
# python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
|
| 370 |
+
# --beit_checkpoint ./pretrained_large.pt \
|
| 371 |
+
# --hf_checkpoint_name "./data2vec-vision-large"
|
| 372 |
+
# python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
|
| 373 |
+
# --beit_checkpoint ./finetuned_large.pt \
|
| 374 |
+
# --hf_checkpoint_name "./data2vec-vision-large-ft1k"
|
docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_audio.py
ADDED
|
@@ -0,0 +1,1746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_data2vec_audio.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
import math
|
| 8 |
+
import warnings
|
| 9 |
+
from typing import Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
from torch.nn import CrossEntropyLoss
|
| 15 |
+
|
| 16 |
+
from ...activations import ACT2FN
|
| 17 |
+
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 18 |
+
from ...integrations.fsdp import is_fsdp_managed_module
|
| 19 |
+
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
| 20 |
+
from ...modeling_outputs import (
|
| 21 |
+
BaseModelOutput,
|
| 22 |
+
CausalLMOutput,
|
| 23 |
+
SequenceClassifierOutput,
|
| 24 |
+
TokenClassifierOutput,
|
| 25 |
+
Wav2Vec2BaseModelOutput,
|
| 26 |
+
XVectorOutput,
|
| 27 |
+
)
|
| 28 |
+
from ...modeling_utils import PreTrainedModel
|
| 29 |
+
from ...utils import (
|
| 30 |
+
add_code_sample_docstrings,
|
| 31 |
+
add_start_docstrings,
|
| 32 |
+
add_start_docstrings_to_model_forward,
|
| 33 |
+
is_peft_available,
|
| 34 |
+
logging,
|
| 35 |
+
)
|
| 36 |
+
from .configuration_data2vec_audio import Data2VecAudioConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if is_flash_attn_available():
|
| 40 |
+
from ...modeling_flash_attention_utils import _flash_attention_forward
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
# Base docstring
|
| 46 |
+
_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
|
| 47 |
+
|
| 48 |
+
# General docstring
|
| 49 |
+
_CONFIG_FOR_DOC = "Data2VecAudioConfig"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Data2VecAudioConvLayer(nn.Module):
|
| 53 |
+
def __init__(self, config, layer_id=0):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 56 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 57 |
+
|
| 58 |
+
self.conv = nn.Conv1d(
|
| 59 |
+
self.in_conv_dim,
|
| 60 |
+
self.out_conv_dim,
|
| 61 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 62 |
+
stride=config.conv_stride[layer_id],
|
| 63 |
+
bias=config.conv_bias,
|
| 64 |
+
)
|
| 65 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 66 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 67 |
+
|
| 68 |
+
def forward(self, hidden_states):
|
| 69 |
+
hidden_states = self.conv(hidden_states)
|
| 70 |
+
|
| 71 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 72 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 73 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 74 |
+
|
| 75 |
+
hidden_states = self.activation(hidden_states)
|
| 76 |
+
return hidden_states
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Data2VecAudioPadLayer(nn.Module):
|
| 80 |
+
def __init__(self, num_conv_pos_embeddings):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 83 |
+
|
| 84 |
+
def forward(self, hidden_states):
|
| 85 |
+
if self.num_pad_remove > 0:
|
| 86 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 87 |
+
return hidden_states
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Data2VecAudioPositionalConvLayer(nn.Module):
|
| 91 |
+
def __init__(self, config):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.conv = nn.Conv1d(
|
| 94 |
+
config.hidden_size,
|
| 95 |
+
config.hidden_size,
|
| 96 |
+
kernel_size=config.conv_pos_kernel_size,
|
| 97 |
+
padding=config.conv_pos_kernel_size // 2,
|
| 98 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
|
| 102 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 103 |
+
# no learnable parameters
|
| 104 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
|
| 105 |
+
|
| 106 |
+
def forward(self, hidden_states):
|
| 107 |
+
hidden_states = self.conv(hidden_states)
|
| 108 |
+
hidden_states = self.padding(hidden_states)
|
| 109 |
+
|
| 110 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 111 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 112 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 113 |
+
hidden_states = self.activation(hidden_states)
|
| 114 |
+
return hidden_states
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Data2VecAudioPositionalConvEmbedding(nn.Module):
|
| 118 |
+
def __init__(self, config):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.layers = nn.ModuleList(
|
| 121 |
+
[Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def forward(self, hidden_states):
|
| 125 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 126 |
+
for layer in self.layers:
|
| 127 |
+
hidden_states = layer(hidden_states)
|
| 128 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 129 |
+
return hidden_states
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class Data2VecAudioFeatureEncoder(nn.Module):
|
| 133 |
+
"""Construct the features from raw audio waveform"""
|
| 134 |
+
|
| 135 |
+
def __init__(self, config):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.conv_layers = nn.ModuleList(
|
| 138 |
+
[Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
|
| 139 |
+
)
|
| 140 |
+
self.gradient_checkpointing = False
|
| 141 |
+
self._requires_grad = True
|
| 142 |
+
|
| 143 |
+
def _freeze_parameters(self):
|
| 144 |
+
for param in self.parameters():
|
| 145 |
+
param.requires_grad = False
|
| 146 |
+
self._requires_grad = False
|
| 147 |
+
|
| 148 |
+
def forward(self, input_values):
|
| 149 |
+
hidden_states = input_values[:, None]
|
| 150 |
+
|
| 151 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
| 152 |
+
if self._requires_grad and self.training:
|
| 153 |
+
hidden_states.requires_grad = True
|
| 154 |
+
|
| 155 |
+
for conv_layer in self.conv_layers:
|
| 156 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 157 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 158 |
+
conv_layer.__call__,
|
| 159 |
+
hidden_states,
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
hidden_states = conv_layer(hidden_states)
|
| 163 |
+
|
| 164 |
+
return hidden_states
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Data2VecAudioFeatureProjection(nn.Module):
|
| 168 |
+
def __init__(self, config):
|
| 169 |
+
super().__init__()
|
| 170 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
| 171 |
+
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 172 |
+
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
| 173 |
+
|
| 174 |
+
def forward(self, hidden_states):
|
| 175 |
+
# non-projected hidden states are needed for quantization
|
| 176 |
+
norm_hidden_states = self.layer_norm(hidden_states)
|
| 177 |
+
hidden_states = self.projection(norm_hidden_states)
|
| 178 |
+
hidden_states = self.dropout(hidden_states)
|
| 179 |
+
return hidden_states, norm_hidden_states
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class Data2VecAudioAttention(nn.Module):
|
| 183 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 184 |
+
|
| 185 |
+
def __init__(
|
| 186 |
+
self,
|
| 187 |
+
embed_dim: int,
|
| 188 |
+
num_heads: int,
|
| 189 |
+
dropout: float = 0.0,
|
| 190 |
+
is_decoder: bool = False,
|
| 191 |
+
bias: bool = True,
|
| 192 |
+
is_causal: bool = False,
|
| 193 |
+
config: Optional[Data2VecAudioConfig] = None,
|
| 194 |
+
):
|
| 195 |
+
super().__init__()
|
| 196 |
+
self.embed_dim = embed_dim
|
| 197 |
+
self.num_heads = num_heads
|
| 198 |
+
self.dropout = dropout
|
| 199 |
+
self.head_dim = embed_dim // num_heads
|
| 200 |
+
self.config = config
|
| 201 |
+
|
| 202 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
| 203 |
+
raise ValueError(
|
| 204 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
| 205 |
+
f" and `num_heads`: {num_heads})."
|
| 206 |
+
)
|
| 207 |
+
self.scaling = self.head_dim**-0.5
|
| 208 |
+
self.is_decoder = is_decoder
|
| 209 |
+
self.is_causal = is_causal
|
| 210 |
+
|
| 211 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 212 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 213 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 214 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 215 |
+
|
| 216 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 217 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 218 |
+
|
| 219 |
+
def forward(
|
| 220 |
+
self,
|
| 221 |
+
hidden_states: torch.Tensor,
|
| 222 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 223 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 224 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 225 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 226 |
+
output_attentions: bool = False,
|
| 227 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 228 |
+
"""Input shape: Batch x Time x Channel"""
|
| 229 |
+
|
| 230 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 231 |
+
# for the decoder
|
| 232 |
+
is_cross_attention = key_value_states is not None
|
| 233 |
+
|
| 234 |
+
bsz, tgt_len, _ = hidden_states.size()
|
| 235 |
+
|
| 236 |
+
# get query proj
|
| 237 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
| 238 |
+
# get key, value proj
|
| 239 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
| 240 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
| 241 |
+
# the provided `key_value_states` to support prefix tuning
|
| 242 |
+
if (
|
| 243 |
+
is_cross_attention
|
| 244 |
+
and past_key_value is not None
|
| 245 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 246 |
+
):
|
| 247 |
+
# reuse k,v, cross_attentions
|
| 248 |
+
key_states = past_key_value[0]
|
| 249 |
+
value_states = past_key_value[1]
|
| 250 |
+
elif is_cross_attention:
|
| 251 |
+
# cross_attentions
|
| 252 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 253 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 254 |
+
elif past_key_value is not None:
|
| 255 |
+
# reuse k, v, self_attention
|
| 256 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 257 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 258 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 259 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 260 |
+
else:
|
| 261 |
+
# self_attention
|
| 262 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 263 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 264 |
+
|
| 265 |
+
if self.is_decoder:
|
| 266 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 267 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 268 |
+
# key/value_states (first "if" case)
|
| 269 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 270 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 271 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 272 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 273 |
+
past_key_value = (key_states, value_states)
|
| 274 |
+
|
| 275 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
| 276 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
| 277 |
+
key_states = key_states.reshape(*proj_shape)
|
| 278 |
+
value_states = value_states.reshape(*proj_shape)
|
| 279 |
+
|
| 280 |
+
src_len = key_states.size(1)
|
| 281 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
| 282 |
+
|
| 283 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
| 284 |
+
raise ValueError(
|
| 285 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
| 286 |
+
f" {attn_weights.size()}"
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if attention_mask is not None:
|
| 290 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
| 291 |
+
raise ValueError(
|
| 292 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
| 293 |
+
)
|
| 294 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
| 295 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 296 |
+
|
| 297 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 298 |
+
|
| 299 |
+
if layer_head_mask is not None:
|
| 300 |
+
if layer_head_mask.size() != (self.num_heads,):
|
| 301 |
+
raise ValueError(
|
| 302 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
| 303 |
+
f" {layer_head_mask.size()}"
|
| 304 |
+
)
|
| 305 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 306 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 307 |
+
|
| 308 |
+
if output_attentions:
|
| 309 |
+
# this operation is a bit awkward, but it's required to
|
| 310 |
+
# make sure that attn_weights keeps its gradient.
|
| 311 |
+
# In order to do so, attn_weights have to be reshaped
|
| 312 |
+
# twice and have to be reused in the following
|
| 313 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 314 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
| 315 |
+
else:
|
| 316 |
+
attn_weights_reshaped = None
|
| 317 |
+
|
| 318 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 319 |
+
|
| 320 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
| 321 |
+
|
| 322 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
| 323 |
+
raise ValueError(
|
| 324 |
+
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
| 325 |
+
f" {attn_output.size()}"
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
| 329 |
+
attn_output = attn_output.transpose(1, 2)
|
| 330 |
+
|
| 331 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
| 332 |
+
# partitioned across GPUs when using tensor-parallelism.
|
| 333 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
| 334 |
+
|
| 335 |
+
attn_output = self.out_proj(attn_output)
|
| 336 |
+
|
| 337 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
|
| 341 |
+
"""
|
| 342 |
+
Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
|
| 343 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 344 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
def __init__(self, *args, **kwargs):
|
| 348 |
+
super().__init__(*args, **kwargs)
|
| 349 |
+
|
| 350 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 351 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 352 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 353 |
+
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
| 354 |
+
|
| 355 |
+
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 356 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
| 357 |
+
|
| 358 |
+
def forward(
|
| 359 |
+
self,
|
| 360 |
+
hidden_states: torch.Tensor,
|
| 361 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 362 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 363 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 364 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 365 |
+
output_attentions: bool = False,
|
| 366 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 367 |
+
# Data2VecAudioFlashAttention2 attention does not support output_attentions
|
| 368 |
+
if output_attentions:
|
| 369 |
+
raise ValueError("Data2VecAudioFlashAttention2 attention does not support output_attentions")
|
| 370 |
+
|
| 371 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 372 |
+
# for the decoder
|
| 373 |
+
is_cross_attention = key_value_states is not None
|
| 374 |
+
|
| 375 |
+
bsz, q_len, _ = hidden_states.size()
|
| 376 |
+
|
| 377 |
+
# get query proj
|
| 378 |
+
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
| 379 |
+
# get key, value proj
|
| 380 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
| 381 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
| 382 |
+
# the provided `key_value_states` to support prefix tuning
|
| 383 |
+
if (
|
| 384 |
+
is_cross_attention
|
| 385 |
+
and past_key_value is not None
|
| 386 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 387 |
+
):
|
| 388 |
+
# reuse k,v, cross_attentions
|
| 389 |
+
key_states = past_key_value[0].transpose(1, 2)
|
| 390 |
+
value_states = past_key_value[1].transpose(1, 2)
|
| 391 |
+
elif is_cross_attention:
|
| 392 |
+
# cross_attentions
|
| 393 |
+
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
| 394 |
+
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
| 395 |
+
elif past_key_value is not None:
|
| 396 |
+
# reuse k, v, self_attention
|
| 397 |
+
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 398 |
+
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
| 399 |
+
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
| 400 |
+
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
| 401 |
+
else:
|
| 402 |
+
# self_attention
|
| 403 |
+
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 404 |
+
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
| 405 |
+
|
| 406 |
+
if self.is_decoder:
|
| 407 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 408 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 409 |
+
# key/value_states (first "if" case)
|
| 410 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 411 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 412 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 413 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 414 |
+
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
| 415 |
+
|
| 416 |
+
kv_seq_len = key_states.shape[-2]
|
| 417 |
+
if past_key_value is not None:
|
| 418 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 419 |
+
|
| 420 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 421 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 422 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 423 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 424 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
| 425 |
+
|
| 426 |
+
input_dtype = query_states.dtype
|
| 427 |
+
if input_dtype == torch.float32:
|
| 428 |
+
if torch.is_autocast_enabled():
|
| 429 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 430 |
+
# Handle the case where the model is quantized
|
| 431 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 432 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 433 |
+
else:
|
| 434 |
+
target_dtype = self.q_proj.weight.dtype
|
| 435 |
+
|
| 436 |
+
logger.warning_once(
|
| 437 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 438 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 439 |
+
f" {target_dtype}."
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
query_states = query_states.to(target_dtype)
|
| 443 |
+
key_states = key_states.to(target_dtype)
|
| 444 |
+
value_states = value_states.to(target_dtype)
|
| 445 |
+
|
| 446 |
+
attn_output = _flash_attention_forward(
|
| 447 |
+
query_states,
|
| 448 |
+
key_states,
|
| 449 |
+
value_states,
|
| 450 |
+
attention_mask,
|
| 451 |
+
q_len,
|
| 452 |
+
dropout=self.dropout if self.training else 0.0,
|
| 453 |
+
is_causal=self.is_causal,
|
| 454 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
| 458 |
+
attn_output = self.out_proj(attn_output)
|
| 459 |
+
|
| 460 |
+
if not output_attentions:
|
| 461 |
+
attn_weights = None
|
| 462 |
+
|
| 463 |
+
return attn_output, attn_weights, past_key_value
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
|
| 467 |
+
def forward(
|
| 468 |
+
self,
|
| 469 |
+
hidden_states: torch.Tensor,
|
| 470 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 471 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 472 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 473 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
| 474 |
+
output_attentions: bool = False,
|
| 475 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 476 |
+
"""Input shape: Batch x Time x Channel"""
|
| 477 |
+
if output_attentions or layer_head_mask is not None:
|
| 478 |
+
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
| 479 |
+
logger.warning_once(
|
| 480 |
+
"Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
| 481 |
+
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 482 |
+
)
|
| 483 |
+
return super().forward(
|
| 484 |
+
hidden_states,
|
| 485 |
+
key_value_states=key_value_states,
|
| 486 |
+
past_key_value=past_key_value,
|
| 487 |
+
attention_mask=attention_mask,
|
| 488 |
+
layer_head_mask=layer_head_mask,
|
| 489 |
+
output_attentions=output_attentions,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 493 |
+
# for the decoder
|
| 494 |
+
is_cross_attention = key_value_states is not None
|
| 495 |
+
|
| 496 |
+
bsz, tgt_len, _ = hidden_states.size()
|
| 497 |
+
|
| 498 |
+
# get query proj
|
| 499 |
+
query_states = self.q_proj(hidden_states)
|
| 500 |
+
# get key, value proj
|
| 501 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
| 502 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
| 503 |
+
# the provided `key_value_states` to support prefix tuning
|
| 504 |
+
if (
|
| 505 |
+
is_cross_attention
|
| 506 |
+
and past_key_value is not None
|
| 507 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 508 |
+
):
|
| 509 |
+
# reuse k,v, cross_attentions
|
| 510 |
+
key_states = past_key_value[0]
|
| 511 |
+
value_states = past_key_value[1]
|
| 512 |
+
elif is_cross_attention:
|
| 513 |
+
# cross_attentions
|
| 514 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 515 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 516 |
+
elif past_key_value is not None:
|
| 517 |
+
# reuse k, v, self_attention
|
| 518 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 519 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 520 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 521 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 522 |
+
else:
|
| 523 |
+
# self_attention
|
| 524 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 525 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
| 526 |
+
|
| 527 |
+
if self.is_decoder:
|
| 528 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 529 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 530 |
+
# key/value_states (first "if" case)
|
| 531 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 532 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 533 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 534 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 535 |
+
past_key_value = (key_states, value_states)
|
| 536 |
+
|
| 537 |
+
query_states = self._shape(query_states, tgt_len, bsz)
|
| 538 |
+
|
| 539 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 540 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 541 |
+
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
| 542 |
+
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
| 543 |
+
|
| 544 |
+
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
| 545 |
+
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
| 546 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 547 |
+
query_states,
|
| 548 |
+
key_states,
|
| 549 |
+
value_states,
|
| 550 |
+
attn_mask=attention_mask,
|
| 551 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 552 |
+
is_causal=is_causal,
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
| 556 |
+
raise ValueError(
|
| 557 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
| 558 |
+
f" {attn_output.size()}"
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
attn_output = attn_output.transpose(1, 2)
|
| 562 |
+
|
| 563 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
| 564 |
+
# partitioned across GPUs when using tensor-parallelism.
|
| 565 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
| 566 |
+
|
| 567 |
+
attn_output = self.out_proj(attn_output)
|
| 568 |
+
|
| 569 |
+
return attn_output, None, past_key_value
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class Data2VecAudioFeedForward(nn.Module):
|
| 573 |
+
def __init__(self, config):
|
| 574 |
+
super().__init__()
|
| 575 |
+
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
| 576 |
+
|
| 577 |
+
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 578 |
+
if isinstance(config.hidden_act, str):
|
| 579 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 580 |
+
else:
|
| 581 |
+
self.intermediate_act_fn = config.hidden_act
|
| 582 |
+
|
| 583 |
+
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 584 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
| 585 |
+
|
| 586 |
+
def forward(self, hidden_states):
|
| 587 |
+
hidden_states = self.intermediate_dense(hidden_states)
|
| 588 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 589 |
+
hidden_states = self.intermediate_dropout(hidden_states)
|
| 590 |
+
|
| 591 |
+
hidden_states = self.output_dense(hidden_states)
|
| 592 |
+
hidden_states = self.output_dropout(hidden_states)
|
| 593 |
+
return hidden_states
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
DATA2VEC_AUDIO_ATTENTION_CLASSES = {
|
| 597 |
+
"eager": Data2VecAudioAttention,
|
| 598 |
+
"sdpa": Data2VecAudioSdpaAttention,
|
| 599 |
+
"flash_attention_2": Data2VecAudioFlashAttention2,
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class Data2VecAudioEncoderLayer(nn.Module):
|
| 604 |
+
def __init__(self, config):
|
| 605 |
+
super().__init__()
|
| 606 |
+
self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation](
|
| 607 |
+
embed_dim=config.hidden_size,
|
| 608 |
+
num_heads=config.num_attention_heads,
|
| 609 |
+
dropout=config.attention_dropout,
|
| 610 |
+
is_decoder=False,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 614 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 615 |
+
self.feed_forward = Data2VecAudioFeedForward(config)
|
| 616 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 617 |
+
|
| 618 |
+
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
| 619 |
+
attn_residual = hidden_states
|
| 620 |
+
hidden_states, attn_weights, _ = self.attention(
|
| 621 |
+
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
| 622 |
+
)
|
| 623 |
+
hidden_states = self.dropout(hidden_states)
|
| 624 |
+
hidden_states = attn_residual + hidden_states
|
| 625 |
+
|
| 626 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 627 |
+
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
| 628 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 629 |
+
|
| 630 |
+
outputs = (hidden_states,)
|
| 631 |
+
|
| 632 |
+
if output_attentions:
|
| 633 |
+
outputs += (attn_weights,)
|
| 634 |
+
|
| 635 |
+
return outputs
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class Data2VecAudioEncoder(nn.Module):
|
| 639 |
+
def __init__(self, config):
|
| 640 |
+
super().__init__()
|
| 641 |
+
self.config = config
|
| 642 |
+
self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config)
|
| 643 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 644 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 645 |
+
self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 646 |
+
self.gradient_checkpointing = False
|
| 647 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 648 |
+
|
| 649 |
+
def forward(
|
| 650 |
+
self,
|
| 651 |
+
hidden_states: torch.tensor,
|
| 652 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 653 |
+
output_attentions: bool = False,
|
| 654 |
+
output_hidden_states: bool = False,
|
| 655 |
+
return_dict: bool = True,
|
| 656 |
+
):
|
| 657 |
+
all_hidden_states = () if output_hidden_states else None
|
| 658 |
+
all_self_attentions = () if output_attentions else None
|
| 659 |
+
|
| 660 |
+
if attention_mask is not None:
|
| 661 |
+
# make sure padded tokens output 0
|
| 662 |
+
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
|
| 663 |
+
hidden_states[~expand_attention_mask] = 0
|
| 664 |
+
if self._use_flash_attention_2:
|
| 665 |
+
# 2d mask is passed through the layers
|
| 666 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
| 667 |
+
else:
|
| 668 |
+
# extend attention_mask
|
| 669 |
+
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
|
| 670 |
+
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
| 671 |
+
attention_mask = attention_mask.expand(
|
| 672 |
+
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
position_embeddings = self.pos_conv_embed(hidden_states)
|
| 676 |
+
hidden_states = hidden_states + position_embeddings
|
| 677 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 678 |
+
hidden_states = self.dropout(hidden_states)
|
| 679 |
+
|
| 680 |
+
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
| 681 |
+
|
| 682 |
+
for layer in self.layers:
|
| 683 |
+
if output_hidden_states:
|
| 684 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 685 |
+
|
| 686 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 687 |
+
dropout_probability = torch.rand([])
|
| 688 |
+
|
| 689 |
+
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
| 690 |
+
if not skip_the_layer or synced_gpus:
|
| 691 |
+
# under fsdp or deepspeed zero3 all gpus must run in sync
|
| 692 |
+
if self.gradient_checkpointing and self.training:
|
| 693 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 694 |
+
layer.__call__,
|
| 695 |
+
hidden_states,
|
| 696 |
+
attention_mask,
|
| 697 |
+
output_attentions,
|
| 698 |
+
)
|
| 699 |
+
else:
|
| 700 |
+
layer_outputs = layer(
|
| 701 |
+
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
| 702 |
+
)
|
| 703 |
+
hidden_states = layer_outputs[0]
|
| 704 |
+
|
| 705 |
+
if skip_the_layer:
|
| 706 |
+
layer_outputs = (None, None)
|
| 707 |
+
|
| 708 |
+
if output_attentions:
|
| 709 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 710 |
+
|
| 711 |
+
if output_hidden_states:
|
| 712 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 713 |
+
|
| 714 |
+
if not return_dict:
|
| 715 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 716 |
+
return BaseModelOutput(
|
| 717 |
+
last_hidden_state=hidden_states,
|
| 718 |
+
hidden_states=all_hidden_states,
|
| 719 |
+
attentions=all_self_attentions,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class Data2VecAudioAdapterLayer(nn.Module):
|
| 724 |
+
def __init__(self, config):
|
| 725 |
+
super().__init__()
|
| 726 |
+
self.conv = nn.Conv1d(
|
| 727 |
+
config.output_hidden_size,
|
| 728 |
+
2 * config.output_hidden_size,
|
| 729 |
+
config.adapter_kernel_size,
|
| 730 |
+
stride=config.adapter_stride,
|
| 731 |
+
padding=1,
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
def forward(self, hidden_states):
|
| 735 |
+
hidden_states = self.conv(hidden_states)
|
| 736 |
+
hidden_states = nn.functional.glu(hidden_states, dim=1)
|
| 737 |
+
|
| 738 |
+
return hidden_states
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
class Data2VecAudioAdapter(nn.Module):
|
| 742 |
+
def __init__(self, config):
|
| 743 |
+
super().__init__()
|
| 744 |
+
|
| 745 |
+
# feature dim might need to be down-projected
|
| 746 |
+
if config.output_hidden_size != config.hidden_size:
|
| 747 |
+
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
|
| 748 |
+
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
|
| 749 |
+
else:
|
| 750 |
+
self.proj = self.proj_layer_norm = None
|
| 751 |
+
|
| 752 |
+
self.layers = nn.ModuleList(Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers))
|
| 753 |
+
self.layerdrop = config.layerdrop
|
| 754 |
+
|
| 755 |
+
def forward(self, hidden_states):
|
| 756 |
+
# down project hidden_states if necessary
|
| 757 |
+
if self.proj is not None and self.proj_layer_norm is not None:
|
| 758 |
+
hidden_states = self.proj(hidden_states)
|
| 759 |
+
hidden_states = self.proj_layer_norm(hidden_states)
|
| 760 |
+
|
| 761 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 762 |
+
|
| 763 |
+
for layer in self.layers:
|
| 764 |
+
layerdrop_prob = np.random.random()
|
| 765 |
+
if not self.training or (layerdrop_prob > self.layerdrop):
|
| 766 |
+
hidden_states = layer(hidden_states)
|
| 767 |
+
|
| 768 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 769 |
+
return hidden_states
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
class Data2VecAudioPreTrainedModel(PreTrainedModel):
|
| 773 |
+
"""
|
| 774 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 775 |
+
models.
|
| 776 |
+
"""
|
| 777 |
+
|
| 778 |
+
config_class = Data2VecAudioConfig
|
| 779 |
+
base_model_prefix = "data2vec_audio"
|
| 780 |
+
main_input_name = "input_values"
|
| 781 |
+
supports_gradient_checkpointing = True
|
| 782 |
+
_supports_flash_attn_2 = True
|
| 783 |
+
_supports_sdpa = True
|
| 784 |
+
|
| 785 |
+
def _init_weights(self, module):
|
| 786 |
+
"""Initialize the weights"""
|
| 787 |
+
if isinstance(module, Data2VecAudioFeatureProjection):
|
| 788 |
+
k = math.sqrt(1 / module.projection.in_features)
|
| 789 |
+
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
| 790 |
+
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
| 791 |
+
elif isinstance(module, Data2VecAudioPositionalConvLayer):
|
| 792 |
+
nn.init.constant_(module.conv.bias, 0)
|
| 793 |
+
elif isinstance(module, nn.Linear):
|
| 794 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 795 |
+
|
| 796 |
+
if module.bias is not None:
|
| 797 |
+
module.bias.data.zero_()
|
| 798 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 799 |
+
if module.bias is not None:
|
| 800 |
+
module.bias.data.zero_()
|
| 801 |
+
if module.weight is not None:
|
| 802 |
+
module.weight.data.fill_(1.0)
|
| 803 |
+
elif isinstance(module, nn.Conv1d):
|
| 804 |
+
nn.init.kaiming_normal_(module.weight)
|
| 805 |
+
|
| 806 |
+
if module.bias is not None:
|
| 807 |
+
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
| 808 |
+
nn.init.uniform_(module.bias, a=-k, b=k)
|
| 809 |
+
|
| 810 |
+
def _get_feat_extract_output_lengths(
|
| 811 |
+
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
|
| 812 |
+
):
|
| 813 |
+
"""
|
| 814 |
+
Computes the output length of the convolutional layers
|
| 815 |
+
"""
|
| 816 |
+
|
| 817 |
+
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
| 818 |
+
|
| 819 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 820 |
+
# 1D convolutional layer output length formula taken
|
| 821 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 822 |
+
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 823 |
+
|
| 824 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
| 825 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 826 |
+
|
| 827 |
+
if add_adapter:
|
| 828 |
+
for _ in range(self.config.num_adapter_layers):
|
| 829 |
+
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
| 830 |
+
|
| 831 |
+
return input_lengths
|
| 832 |
+
|
| 833 |
+
def _get_feature_vector_attention_mask(
|
| 834 |
+
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
|
| 835 |
+
):
|
| 836 |
+
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
| 837 |
+
# on inference mode.
|
| 838 |
+
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
| 839 |
+
|
| 840 |
+
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
| 841 |
+
output_lengths = output_lengths.to(torch.long)
|
| 842 |
+
|
| 843 |
+
batch_size = attention_mask.shape[0]
|
| 844 |
+
|
| 845 |
+
attention_mask = torch.zeros(
|
| 846 |
+
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
| 847 |
+
)
|
| 848 |
+
# these two operations makes sure that all values before the output lengths idxs are attended to
|
| 849 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
| 850 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 851 |
+
return attention_mask
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
def _compute_mask_indices(
|
| 855 |
+
shape: Tuple[int, int],
|
| 856 |
+
mask_prob: float,
|
| 857 |
+
mask_length: int,
|
| 858 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 859 |
+
min_masks: int = 0,
|
| 860 |
+
) -> np.ndarray:
|
| 861 |
+
"""
|
| 862 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 863 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 864 |
+
CPU as part of the preprocessing during training.
|
| 865 |
+
|
| 866 |
+
Args:
|
| 867 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 868 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
| 869 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 870 |
+
independently generated mask spans of length `mask_length` is computed by
|
| 871 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 872 |
+
actual percentage will be smaller.
|
| 873 |
+
mask_length: size of the mask
|
| 874 |
+
min_masks: minimum number of masked spans
|
| 875 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 876 |
+
each batch dimension.
|
| 877 |
+
"""
|
| 878 |
+
batch_size, sequence_length = shape
|
| 879 |
+
|
| 880 |
+
if mask_length < 1:
|
| 881 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 882 |
+
|
| 883 |
+
if mask_length > sequence_length:
|
| 884 |
+
raise ValueError(
|
| 885 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 886 |
+
f" and `sequence_length`: {sequence_length}`"
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
# epsilon is used for probabilistic rounding
|
| 890 |
+
epsilon = np.random.rand(1).item()
|
| 891 |
+
|
| 892 |
+
def compute_num_masked_span(input_length):
|
| 893 |
+
"""Given input length, compute how many spans should be masked"""
|
| 894 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 895 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 896 |
+
|
| 897 |
+
# make sure num masked span <= sequence_length
|
| 898 |
+
if num_masked_span * mask_length > sequence_length:
|
| 899 |
+
num_masked_span = sequence_length // mask_length
|
| 900 |
+
|
| 901 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 902 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
| 903 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 904 |
+
|
| 905 |
+
return num_masked_span
|
| 906 |
+
|
| 907 |
+
# compute number of masked spans in batch
|
| 908 |
+
input_lengths = (
|
| 909 |
+
attention_mask.detach().sum(-1).tolist()
|
| 910 |
+
if attention_mask is not None
|
| 911 |
+
else [sequence_length for _ in range(batch_size)]
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
# SpecAugment mask to fill
|
| 915 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 916 |
+
spec_aug_mask_idxs = []
|
| 917 |
+
|
| 918 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 919 |
+
|
| 920 |
+
if max_num_masked_span == 0:
|
| 921 |
+
return spec_aug_mask
|
| 922 |
+
|
| 923 |
+
for input_length in input_lengths:
|
| 924 |
+
# compute num of masked spans for this input
|
| 925 |
+
num_masked_span = compute_num_masked_span(input_length)
|
| 926 |
+
|
| 927 |
+
# get random indices to mask
|
| 928 |
+
spec_aug_mask_idx = np.random.choice(
|
| 929 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
| 933 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
| 934 |
+
# Picking first sample just pads those vectors twice.
|
| 935 |
+
if len(spec_aug_mask_idx) == 0:
|
| 936 |
+
# this case can only happen if `input_length` is strictly smaller then
|
| 937 |
+
# `sequence_length` in which case the last token has to be a padding
|
| 938 |
+
# token which we can use as a dummy mask id
|
| 939 |
+
dummy_mask_idx = sequence_length - 1
|
| 940 |
+
else:
|
| 941 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 942 |
+
|
| 943 |
+
spec_aug_mask_idx = np.concatenate(
|
| 944 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 945 |
+
)
|
| 946 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 947 |
+
|
| 948 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 949 |
+
|
| 950 |
+
# expand masked indices to masked spans
|
| 951 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
| 952 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 953 |
+
)
|
| 954 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 955 |
+
|
| 956 |
+
# add offset to the starting indexes so that indexes now create a span
|
| 957 |
+
offsets = np.arange(mask_length)[None, None, :]
|
| 958 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 959 |
+
batch_size, max_num_masked_span * mask_length
|
| 960 |
+
)
|
| 961 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 962 |
+
|
| 963 |
+
# ensure that we cannot have indices larger than sequence_length
|
| 964 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 965 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 966 |
+
|
| 967 |
+
# scatter indices to mask
|
| 968 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 969 |
+
|
| 970 |
+
return spec_aug_mask
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
DATA2VEC_AUDIO_START_DOCSTRING = r"""
|
| 977 |
+
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
|
| 978 |
+
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
|
| 979 |
+
Michael Auli.
|
| 980 |
+
|
| 981 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 982 |
+
library implements for all its model (such as downloading or saving etc.).
|
| 983 |
+
|
| 984 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 985 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 986 |
+
behavior.
|
| 987 |
+
|
| 988 |
+
Parameters:
|
| 989 |
+
config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
|
| 990 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 991 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 992 |
+
"""
|
| 993 |
+
|
| 994 |
+
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
|
| 995 |
+
Args:
|
| 996 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 997 |
+
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
|
| 998 |
+
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
| 999 |
+
soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
|
| 1000 |
+
conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
|
| 1001 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1002 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1003 |
+
1]`:
|
| 1004 |
+
|
| 1005 |
+
- 1 for tokens that are **not masked**,
|
| 1006 |
+
- 0 for tokens that are **masked**.
|
| 1007 |
+
|
| 1008 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1009 |
+
|
| 1010 |
+
<Tip warning={true}>
|
| 1011 |
+
|
| 1012 |
+
`attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
|
| 1013 |
+
True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that that even with
|
| 1014 |
+
`attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
|
| 1015 |
+
because there are more than one convolutional layer in the positional encodings. For a more detailed
|
| 1016 |
+
explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).
|
| 1017 |
+
|
| 1018 |
+
</Tip>
|
| 1019 |
+
|
| 1020 |
+
output_attentions (`bool`, *optional*):
|
| 1021 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1022 |
+
tensors for more detail.
|
| 1023 |
+
output_hidden_states (`bool`, *optional*):
|
| 1024 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1025 |
+
more detail.
|
| 1026 |
+
return_dict (`bool`, *optional*):
|
| 1027 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1028 |
+
"""
|
| 1029 |
+
|
| 1030 |
+
Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
@add_start_docstrings(
|
| 1034 |
+
"The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1035 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 1036 |
+
)
|
| 1037 |
+
class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
|
| 1038 |
+
def __init__(self, config: Data2VecAudioConfig):
|
| 1039 |
+
super().__init__(config)
|
| 1040 |
+
self.config = config
|
| 1041 |
+
self.feature_extractor = Data2VecAudioFeatureEncoder(config)
|
| 1042 |
+
self.feature_projection = Data2VecAudioFeatureProjection(config)
|
| 1043 |
+
|
| 1044 |
+
# model only needs masking vector if mask prob is > 0.0
|
| 1045 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 1046 |
+
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
|
| 1047 |
+
|
| 1048 |
+
self.encoder = Data2VecAudioEncoder(config)
|
| 1049 |
+
|
| 1050 |
+
self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
|
| 1051 |
+
|
| 1052 |
+
# Initialize weights and apply final processing
|
| 1053 |
+
self.post_init()
|
| 1054 |
+
|
| 1055 |
+
def freeze_feature_encoder(self):
|
| 1056 |
+
"""
|
| 1057 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1058 |
+
not be updated during training.
|
| 1059 |
+
"""
|
| 1060 |
+
self.feature_extractor._freeze_parameters()
|
| 1061 |
+
|
| 1062 |
+
def _mask_hidden_states(
|
| 1063 |
+
self,
|
| 1064 |
+
hidden_states: torch.FloatTensor,
|
| 1065 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1066 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1067 |
+
):
|
| 1068 |
+
"""
|
| 1069 |
+
Masks extracted features along time axis and/or along feature axis according to
|
| 1070 |
+
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
| 1071 |
+
"""
|
| 1072 |
+
|
| 1073 |
+
# `config.apply_spec_augment` can set masking to False
|
| 1074 |
+
if not getattr(self.config, "apply_spec_augment", True):
|
| 1075 |
+
return hidden_states
|
| 1076 |
+
|
| 1077 |
+
# generate indices & apply SpecAugment along time axis
|
| 1078 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 1079 |
+
|
| 1080 |
+
if mask_time_indices is not None:
|
| 1081 |
+
# apply SpecAugment along time axis with given mask_time_indices
|
| 1082 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1083 |
+
elif self.config.mask_time_prob > 0 and self.training:
|
| 1084 |
+
mask_time_indices = _compute_mask_indices(
|
| 1085 |
+
(batch_size, sequence_length),
|
| 1086 |
+
mask_prob=self.config.mask_time_prob,
|
| 1087 |
+
mask_length=self.config.mask_time_length,
|
| 1088 |
+
attention_mask=attention_mask,
|
| 1089 |
+
min_masks=self.config.mask_time_min_masks,
|
| 1090 |
+
)
|
| 1091 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1092 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1093 |
+
|
| 1094 |
+
if self.config.mask_feature_prob > 0 and self.training:
|
| 1095 |
+
# generate indices & apply SpecAugment along feature axis
|
| 1096 |
+
mask_feature_indices = _compute_mask_indices(
|
| 1097 |
+
(batch_size, hidden_size),
|
| 1098 |
+
mask_prob=self.config.mask_feature_prob,
|
| 1099 |
+
mask_length=self.config.mask_feature_length,
|
| 1100 |
+
min_masks=self.config.mask_feature_min_masks,
|
| 1101 |
+
)
|
| 1102 |
+
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1103 |
+
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
| 1104 |
+
hidden_states[mask_feature_indices] = 0
|
| 1105 |
+
|
| 1106 |
+
return hidden_states
|
| 1107 |
+
|
| 1108 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 1109 |
+
@add_code_sample_docstrings(
|
| 1110 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1111 |
+
output_type=Data2VecAudioBaseModelOutput,
|
| 1112 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1113 |
+
modality="audio",
|
| 1114 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 1115 |
+
)
|
| 1116 |
+
def forward(
|
| 1117 |
+
self,
|
| 1118 |
+
input_values: Optional[torch.Tensor],
|
| 1119 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1120 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1121 |
+
output_attentions: Optional[bool] = None,
|
| 1122 |
+
output_hidden_states: Optional[bool] = None,
|
| 1123 |
+
return_dict: Optional[bool] = None,
|
| 1124 |
+
) -> Union[Tuple, Data2VecAudioBaseModelOutput]:
|
| 1125 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1126 |
+
output_hidden_states = (
|
| 1127 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1128 |
+
)
|
| 1129 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1130 |
+
|
| 1131 |
+
extract_features = self.feature_extractor(input_values)
|
| 1132 |
+
extract_features = extract_features.transpose(1, 2)
|
| 1133 |
+
|
| 1134 |
+
if attention_mask is not None:
|
| 1135 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 1136 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 1137 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
| 1141 |
+
hidden_states = self._mask_hidden_states(
|
| 1142 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
| 1143 |
+
)
|
| 1144 |
+
|
| 1145 |
+
encoder_outputs = self.encoder(
|
| 1146 |
+
hidden_states,
|
| 1147 |
+
attention_mask=attention_mask,
|
| 1148 |
+
output_attentions=output_attentions,
|
| 1149 |
+
output_hidden_states=output_hidden_states,
|
| 1150 |
+
return_dict=return_dict,
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
hidden_states = encoder_outputs[0]
|
| 1154 |
+
|
| 1155 |
+
if self.adapter is not None:
|
| 1156 |
+
hidden_states = self.adapter(hidden_states)
|
| 1157 |
+
|
| 1158 |
+
if not return_dict:
|
| 1159 |
+
return (hidden_states, extract_features) + encoder_outputs[1:]
|
| 1160 |
+
|
| 1161 |
+
return Data2VecAudioBaseModelOutput(
|
| 1162 |
+
last_hidden_state=hidden_states,
|
| 1163 |
+
extract_features=extract_features,
|
| 1164 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1165 |
+
attentions=encoder_outputs.attentions,
|
| 1166 |
+
)
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 1170 |
+
|
| 1171 |
+
# CTC docstring
|
| 1172 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 1173 |
+
_CTC_EXPECTED_LOSS = 66.95
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
@add_start_docstrings(
|
| 1177 |
+
"""Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 1178 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 1179 |
+
)
|
| 1180 |
+
class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
|
| 1181 |
+
def __init__(self, config):
|
| 1182 |
+
super().__init__(config)
|
| 1183 |
+
|
| 1184 |
+
self.data2vec_audio = Data2VecAudioModel(config)
|
| 1185 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
| 1186 |
+
|
| 1187 |
+
if config.vocab_size is None:
|
| 1188 |
+
raise ValueError(
|
| 1189 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 1190 |
+
"does not define the vocabulary size of the language model head. Please "
|
| 1191 |
+
"instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 1192 |
+
"or define `vocab_size` of your model's configuration."
|
| 1193 |
+
)
|
| 1194 |
+
output_hidden_size = (
|
| 1195 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 1196 |
+
)
|
| 1197 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 1198 |
+
|
| 1199 |
+
# Initialize weights and apply final processing
|
| 1200 |
+
self.post_init()
|
| 1201 |
+
|
| 1202 |
+
def freeze_feature_extractor(self):
|
| 1203 |
+
"""
|
| 1204 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1205 |
+
not be updated during training.
|
| 1206 |
+
"""
|
| 1207 |
+
warnings.warn(
|
| 1208 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1209 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1210 |
+
FutureWarning,
|
| 1211 |
+
)
|
| 1212 |
+
self.freeze_feature_encoder()
|
| 1213 |
+
|
| 1214 |
+
def freeze_feature_encoder(self):
|
| 1215 |
+
"""
|
| 1216 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1217 |
+
not be updated during training.
|
| 1218 |
+
"""
|
| 1219 |
+
self.data2vec_audio.feature_extractor._freeze_parameters()
|
| 1220 |
+
|
| 1221 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 1222 |
+
@add_code_sample_docstrings(
|
| 1223 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1224 |
+
output_type=CausalLMOutput,
|
| 1225 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1226 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 1227 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
| 1228 |
+
)
|
| 1229 |
+
def forward(
|
| 1230 |
+
self,
|
| 1231 |
+
input_values: Optional[torch.Tensor],
|
| 1232 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1233 |
+
output_attentions: Optional[bool] = None,
|
| 1234 |
+
output_hidden_states: Optional[bool] = None,
|
| 1235 |
+
return_dict: Optional[bool] = None,
|
| 1236 |
+
labels: Optional[torch.Tensor] = None,
|
| 1237 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 1238 |
+
r"""
|
| 1239 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 1240 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 1241 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 1242 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1243 |
+
config.vocab_size - 1]`.
|
| 1244 |
+
"""
|
| 1245 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1246 |
+
|
| 1247 |
+
if labels is not None and labels.max() >= self.config.vocab_size:
|
| 1248 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 1249 |
+
|
| 1250 |
+
outputs = self.data2vec_audio(
|
| 1251 |
+
input_values,
|
| 1252 |
+
attention_mask=attention_mask,
|
| 1253 |
+
output_attentions=output_attentions,
|
| 1254 |
+
output_hidden_states=output_hidden_states,
|
| 1255 |
+
return_dict=return_dict,
|
| 1256 |
+
)
|
| 1257 |
+
|
| 1258 |
+
hidden_states = outputs[0]
|
| 1259 |
+
hidden_states = self.dropout(hidden_states)
|
| 1260 |
+
|
| 1261 |
+
logits = self.lm_head(hidden_states)
|
| 1262 |
+
|
| 1263 |
+
loss = None
|
| 1264 |
+
if labels is not None:
|
| 1265 |
+
# retrieve loss input_lengths from attention_mask
|
| 1266 |
+
attention_mask = (
|
| 1267 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 1268 |
+
)
|
| 1269 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1270 |
+
|
| 1271 |
+
# assuming that padded tokens are filled with -100
|
| 1272 |
+
# when not being attended to
|
| 1273 |
+
labels_mask = labels >= 0
|
| 1274 |
+
target_lengths = labels_mask.sum(-1)
|
| 1275 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 1276 |
+
|
| 1277 |
+
# ctc_loss doesn't support fp16
|
| 1278 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 1279 |
+
|
| 1280 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 1281 |
+
loss = nn.functional.ctc_loss(
|
| 1282 |
+
log_probs,
|
| 1283 |
+
flattened_targets,
|
| 1284 |
+
input_lengths,
|
| 1285 |
+
target_lengths,
|
| 1286 |
+
blank=self.config.pad_token_id,
|
| 1287 |
+
reduction=self.config.ctc_loss_reduction,
|
| 1288 |
+
zero_infinity=self.config.ctc_zero_infinity,
|
| 1289 |
+
)
|
| 1290 |
+
|
| 1291 |
+
if not return_dict:
|
| 1292 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1293 |
+
return ((loss,) + output) if loss is not None else output
|
| 1294 |
+
|
| 1295 |
+
return CausalLMOutput(
|
| 1296 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1297 |
+
)
|
| 1298 |
+
|
| 1299 |
+
|
| 1300 |
+
@add_start_docstrings(
|
| 1301 |
+
"""
|
| 1302 |
+
Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
|
| 1303 |
+
like SUPERB Keyword Spotting.
|
| 1304 |
+
""",
|
| 1305 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 1306 |
+
)
|
| 1307 |
+
class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
|
| 1308 |
+
def __init__(self, config):
|
| 1309 |
+
super().__init__(config)
|
| 1310 |
+
|
| 1311 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1312 |
+
raise ValueError(
|
| 1313 |
+
"Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
|
| 1314 |
+
)
|
| 1315 |
+
self.data2vec_audio = Data2VecAudioModel(config)
|
| 1316 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1317 |
+
if config.use_weighted_layer_sum:
|
| 1318 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1319 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1320 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1321 |
+
|
| 1322 |
+
# Initialize weights and apply final processing
|
| 1323 |
+
self.post_init()
|
| 1324 |
+
|
| 1325 |
+
def freeze_feature_extractor(self):
|
| 1326 |
+
"""
|
| 1327 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
|
| 1328 |
+
not be updated during training.
|
| 1329 |
+
"""
|
| 1330 |
+
warnings.warn(
|
| 1331 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1332 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1333 |
+
FutureWarning,
|
| 1334 |
+
)
|
| 1335 |
+
self.freeze_feature_encoder()
|
| 1336 |
+
|
| 1337 |
+
def freeze_feature_encoder(self):
|
| 1338 |
+
"""
|
| 1339 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1340 |
+
not be updated during training.
|
| 1341 |
+
"""
|
| 1342 |
+
self.data2vec_audio.feature_extractor._freeze_parameters()
|
| 1343 |
+
|
| 1344 |
+
def freeze_base_model(self):
|
| 1345 |
+
"""
|
| 1346 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1347 |
+
be updated during training. Only the classification head will be updated.
|
| 1348 |
+
"""
|
| 1349 |
+
for param in self.data2vec_audio.parameters():
|
| 1350 |
+
param.requires_grad = False
|
| 1351 |
+
|
| 1352 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 1353 |
+
@add_code_sample_docstrings(
|
| 1354 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1355 |
+
output_type=SequenceClassifierOutput,
|
| 1356 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1357 |
+
modality="audio",
|
| 1358 |
+
)
|
| 1359 |
+
def forward(
|
| 1360 |
+
self,
|
| 1361 |
+
input_values: Optional[torch.Tensor],
|
| 1362 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1363 |
+
output_attentions: Optional[bool] = None,
|
| 1364 |
+
output_hidden_states: Optional[bool] = None,
|
| 1365 |
+
return_dict: Optional[bool] = None,
|
| 1366 |
+
labels: Optional[torch.Tensor] = None,
|
| 1367 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1368 |
+
r"""
|
| 1369 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1370 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1371 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1372 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1373 |
+
"""
|
| 1374 |
+
|
| 1375 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1376 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1377 |
+
|
| 1378 |
+
outputs = self.data2vec_audio(
|
| 1379 |
+
input_values,
|
| 1380 |
+
attention_mask=attention_mask,
|
| 1381 |
+
output_attentions=output_attentions,
|
| 1382 |
+
output_hidden_states=output_hidden_states,
|
| 1383 |
+
return_dict=return_dict,
|
| 1384 |
+
)
|
| 1385 |
+
|
| 1386 |
+
if self.config.use_weighted_layer_sum:
|
| 1387 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1388 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1389 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1390 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1391 |
+
else:
|
| 1392 |
+
hidden_states = outputs[0]
|
| 1393 |
+
|
| 1394 |
+
hidden_states = self.projector(hidden_states)
|
| 1395 |
+
if attention_mask is None:
|
| 1396 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 1397 |
+
else:
|
| 1398 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1399 |
+
expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
|
| 1400 |
+
hidden_states[~expand_padding_mask] = 0.0
|
| 1401 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1402 |
+
|
| 1403 |
+
logits = self.classifier(pooled_output)
|
| 1404 |
+
|
| 1405 |
+
loss = None
|
| 1406 |
+
if labels is not None:
|
| 1407 |
+
loss_fct = CrossEntropyLoss()
|
| 1408 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1409 |
+
|
| 1410 |
+
if not return_dict:
|
| 1411 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1412 |
+
return ((loss,) + output) if loss is not None else output
|
| 1413 |
+
|
| 1414 |
+
return SequenceClassifierOutput(
|
| 1415 |
+
loss=loss,
|
| 1416 |
+
logits=logits,
|
| 1417 |
+
hidden_states=outputs.hidden_states,
|
| 1418 |
+
attentions=outputs.attentions,
|
| 1419 |
+
)
|
| 1420 |
+
|
| 1421 |
+
|
| 1422 |
+
@add_start_docstrings(
|
| 1423 |
+
"""
|
| 1424 |
+
Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
|
| 1425 |
+
""",
|
| 1426 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 1427 |
+
)
|
| 1428 |
+
class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
|
| 1429 |
+
def __init__(self, config):
|
| 1430 |
+
super().__init__(config)
|
| 1431 |
+
|
| 1432 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1433 |
+
raise ValueError(
|
| 1434 |
+
"Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
|
| 1435 |
+
)
|
| 1436 |
+
self.data2vec_audio = Data2VecAudioModel(config)
|
| 1437 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1438 |
+
if config.use_weighted_layer_sum:
|
| 1439 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1440 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1441 |
+
self.num_labels = config.num_labels
|
| 1442 |
+
|
| 1443 |
+
self.init_weights()
|
| 1444 |
+
|
| 1445 |
+
def freeze_feature_extractor(self):
|
| 1446 |
+
"""
|
| 1447 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1448 |
+
not be updated during training.
|
| 1449 |
+
"""
|
| 1450 |
+
warnings.warn(
|
| 1451 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1452 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1453 |
+
FutureWarning,
|
| 1454 |
+
)
|
| 1455 |
+
self.freeze_feature_encoder()
|
| 1456 |
+
|
| 1457 |
+
def freeze_feature_encoder(self):
|
| 1458 |
+
"""
|
| 1459 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1460 |
+
not be updated during training.
|
| 1461 |
+
"""
|
| 1462 |
+
self.data2vec_audio.feature_extractor._freeze_parameters()
|
| 1463 |
+
|
| 1464 |
+
def freeze_base_model(self):
|
| 1465 |
+
"""
|
| 1466 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1467 |
+
be updated during training. Only the classification head will be updated.
|
| 1468 |
+
"""
|
| 1469 |
+
for param in self.data2vec_audio.parameters():
|
| 1470 |
+
param.requires_grad = False
|
| 1471 |
+
|
| 1472 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 1473 |
+
@add_code_sample_docstrings(
|
| 1474 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1475 |
+
output_type=TokenClassifierOutput,
|
| 1476 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1477 |
+
modality="audio",
|
| 1478 |
+
)
|
| 1479 |
+
def forward(
|
| 1480 |
+
self,
|
| 1481 |
+
input_values: Optional[torch.Tensor],
|
| 1482 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1483 |
+
labels: Optional[torch.Tensor] = None,
|
| 1484 |
+
output_attentions: Optional[bool] = None,
|
| 1485 |
+
output_hidden_states: Optional[bool] = None,
|
| 1486 |
+
return_dict: Optional[bool] = None,
|
| 1487 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1488 |
+
r"""
|
| 1489 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1490 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1491 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1492 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1493 |
+
"""
|
| 1494 |
+
|
| 1495 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1496 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1497 |
+
|
| 1498 |
+
outputs = self.data2vec_audio(
|
| 1499 |
+
input_values,
|
| 1500 |
+
attention_mask=attention_mask,
|
| 1501 |
+
output_attentions=output_attentions,
|
| 1502 |
+
output_hidden_states=output_hidden_states,
|
| 1503 |
+
return_dict=return_dict,
|
| 1504 |
+
)
|
| 1505 |
+
|
| 1506 |
+
if self.config.use_weighted_layer_sum:
|
| 1507 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1508 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1509 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1510 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1511 |
+
else:
|
| 1512 |
+
hidden_states = outputs[0]
|
| 1513 |
+
|
| 1514 |
+
logits = self.classifier(hidden_states)
|
| 1515 |
+
|
| 1516 |
+
loss = None
|
| 1517 |
+
if labels is not None:
|
| 1518 |
+
loss_fct = CrossEntropyLoss()
|
| 1519 |
+
loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
|
| 1520 |
+
|
| 1521 |
+
if not return_dict:
|
| 1522 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1523 |
+
return output
|
| 1524 |
+
|
| 1525 |
+
return TokenClassifierOutput(
|
| 1526 |
+
loss=loss,
|
| 1527 |
+
logits=logits,
|
| 1528 |
+
hidden_states=outputs.hidden_states,
|
| 1529 |
+
attentions=outputs.attentions,
|
| 1530 |
+
)
|
| 1531 |
+
|
| 1532 |
+
|
| 1533 |
+
class AMSoftmaxLoss(nn.Module):
|
| 1534 |
+
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
|
| 1535 |
+
super(AMSoftmaxLoss, self).__init__()
|
| 1536 |
+
self.scale = scale
|
| 1537 |
+
self.margin = margin
|
| 1538 |
+
self.num_labels = num_labels
|
| 1539 |
+
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
|
| 1540 |
+
self.loss = nn.CrossEntropyLoss()
|
| 1541 |
+
|
| 1542 |
+
def forward(self, hidden_states, labels):
|
| 1543 |
+
labels = labels.flatten()
|
| 1544 |
+
weight = nn.functional.normalize(self.weight, dim=0)
|
| 1545 |
+
hidden_states = nn.functional.normalize(hidden_states, dim=1)
|
| 1546 |
+
cos_theta = torch.mm(hidden_states, weight)
|
| 1547 |
+
psi = cos_theta - self.margin
|
| 1548 |
+
|
| 1549 |
+
onehot = nn.functional.one_hot(labels, self.num_labels)
|
| 1550 |
+
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
|
| 1551 |
+
loss = self.loss(logits, labels)
|
| 1552 |
+
|
| 1553 |
+
return loss
|
| 1554 |
+
|
| 1555 |
+
|
| 1556 |
+
class TDNNLayer(nn.Module):
|
| 1557 |
+
def __init__(self, config, layer_id=0):
|
| 1558 |
+
super().__init__()
|
| 1559 |
+
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
|
| 1560 |
+
self.out_conv_dim = config.tdnn_dim[layer_id]
|
| 1561 |
+
self.kernel_size = config.tdnn_kernel[layer_id]
|
| 1562 |
+
self.dilation = config.tdnn_dilation[layer_id]
|
| 1563 |
+
|
| 1564 |
+
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
|
| 1565 |
+
self.activation = nn.ReLU()
|
| 1566 |
+
|
| 1567 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1568 |
+
if is_peft_available():
|
| 1569 |
+
from peft.tuners.lora import LoraLayer
|
| 1570 |
+
|
| 1571 |
+
if is_peft_available():
|
| 1572 |
+
if isinstance(self.kernel, LoraLayer):
|
| 1573 |
+
warnings.warn(
|
| 1574 |
+
"Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
|
| 1575 |
+
"You should exclude TDNNLayer from LoRA's target modules.",
|
| 1576 |
+
)
|
| 1577 |
+
|
| 1578 |
+
# for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
|
| 1579 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1580 |
+
weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
|
| 1581 |
+
hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
|
| 1582 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1583 |
+
|
| 1584 |
+
hidden_states = self.activation(hidden_states)
|
| 1585 |
+
return hidden_states
|
| 1586 |
+
|
| 1587 |
+
|
| 1588 |
+
@add_start_docstrings(
|
| 1589 |
+
"""
|
| 1590 |
+
Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
|
| 1591 |
+
""",
|
| 1592 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 1593 |
+
)
|
| 1594 |
+
class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
|
| 1595 |
+
def __init__(self, config):
|
| 1596 |
+
super().__init__(config)
|
| 1597 |
+
|
| 1598 |
+
self.data2vec_audio = Data2VecAudioModel(config)
|
| 1599 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1600 |
+
if config.use_weighted_layer_sum:
|
| 1601 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1602 |
+
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
|
| 1603 |
+
|
| 1604 |
+
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
|
| 1605 |
+
self.tdnn = nn.ModuleList(tdnn_layers)
|
| 1606 |
+
|
| 1607 |
+
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
|
| 1608 |
+
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
|
| 1609 |
+
|
| 1610 |
+
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
|
| 1611 |
+
|
| 1612 |
+
self.init_weights()
|
| 1613 |
+
|
| 1614 |
+
def freeze_feature_extractor(self):
|
| 1615 |
+
"""
|
| 1616 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1617 |
+
not be updated during training.
|
| 1618 |
+
"""
|
| 1619 |
+
warnings.warn(
|
| 1620 |
+
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
| 1621 |
+
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
| 1622 |
+
FutureWarning,
|
| 1623 |
+
)
|
| 1624 |
+
self.freeze_feature_encoder()
|
| 1625 |
+
|
| 1626 |
+
def freeze_feature_encoder(self):
|
| 1627 |
+
"""
|
| 1628 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1629 |
+
not be updated during training.
|
| 1630 |
+
"""
|
| 1631 |
+
self.data2vec_audio.feature_extractor._freeze_parameters()
|
| 1632 |
+
|
| 1633 |
+
def freeze_base_model(self):
|
| 1634 |
+
"""
|
| 1635 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1636 |
+
be updated during training. Only the classification head will be updated.
|
| 1637 |
+
"""
|
| 1638 |
+
for param in self.data2vec_audio.parameters():
|
| 1639 |
+
param.requires_grad = False
|
| 1640 |
+
|
| 1641 |
+
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
| 1642 |
+
"""
|
| 1643 |
+
Computes the output length of the TDNN layers
|
| 1644 |
+
"""
|
| 1645 |
+
|
| 1646 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 1647 |
+
# 1D convolutional layer output length formula taken
|
| 1648 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 1649 |
+
return (input_length - kernel_size) // stride + 1
|
| 1650 |
+
|
| 1651 |
+
for kernel_size in self.config.tdnn_kernel:
|
| 1652 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
|
| 1653 |
+
|
| 1654 |
+
return input_lengths
|
| 1655 |
+
|
| 1656 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 1657 |
+
@add_code_sample_docstrings(
|
| 1658 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1659 |
+
output_type=XVectorOutput,
|
| 1660 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1661 |
+
modality="audio",
|
| 1662 |
+
)
|
| 1663 |
+
def forward(
|
| 1664 |
+
self,
|
| 1665 |
+
input_values: Optional[torch.Tensor],
|
| 1666 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1667 |
+
output_attentions: Optional[bool] = None,
|
| 1668 |
+
output_hidden_states: Optional[bool] = None,
|
| 1669 |
+
return_dict: Optional[bool] = None,
|
| 1670 |
+
labels: Optional[torch.Tensor] = None,
|
| 1671 |
+
) -> Union[Tuple, XVectorOutput]:
|
| 1672 |
+
r"""
|
| 1673 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1674 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1675 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1676 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1677 |
+
"""
|
| 1678 |
+
|
| 1679 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1680 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1681 |
+
|
| 1682 |
+
outputs = self.data2vec_audio(
|
| 1683 |
+
input_values,
|
| 1684 |
+
attention_mask=attention_mask,
|
| 1685 |
+
output_attentions=output_attentions,
|
| 1686 |
+
output_hidden_states=output_hidden_states,
|
| 1687 |
+
return_dict=return_dict,
|
| 1688 |
+
)
|
| 1689 |
+
|
| 1690 |
+
if self.config.use_weighted_layer_sum:
|
| 1691 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1692 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1693 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1694 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1695 |
+
else:
|
| 1696 |
+
hidden_states = outputs[0]
|
| 1697 |
+
|
| 1698 |
+
hidden_states = self.projector(hidden_states)
|
| 1699 |
+
|
| 1700 |
+
for tdnn_layer in self.tdnn:
|
| 1701 |
+
hidden_states = tdnn_layer(hidden_states)
|
| 1702 |
+
|
| 1703 |
+
# Statistic Pooling
|
| 1704 |
+
if attention_mask is None:
|
| 1705 |
+
mean_features = hidden_states.mean(dim=1)
|
| 1706 |
+
std_features = hidden_states.std(dim=1)
|
| 1707 |
+
else:
|
| 1708 |
+
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
|
| 1709 |
+
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
|
| 1710 |
+
mean_features = []
|
| 1711 |
+
std_features = []
|
| 1712 |
+
for i, length in enumerate(tdnn_output_lengths):
|
| 1713 |
+
mean_features.append(hidden_states[i, :length].mean(dim=0))
|
| 1714 |
+
std_features.append(hidden_states[i, :length].std(dim=0))
|
| 1715 |
+
mean_features = torch.stack(mean_features)
|
| 1716 |
+
std_features = torch.stack(std_features)
|
| 1717 |
+
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
|
| 1718 |
+
|
| 1719 |
+
output_embeddings = self.feature_extractor(statistic_pooling)
|
| 1720 |
+
logits = self.classifier(output_embeddings)
|
| 1721 |
+
|
| 1722 |
+
loss = None
|
| 1723 |
+
if labels is not None:
|
| 1724 |
+
loss = self.objective(logits, labels)
|
| 1725 |
+
|
| 1726 |
+
if not return_dict:
|
| 1727 |
+
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1728 |
+
return ((loss,) + output) if loss is not None else output
|
| 1729 |
+
|
| 1730 |
+
return XVectorOutput(
|
| 1731 |
+
loss=loss,
|
| 1732 |
+
logits=logits,
|
| 1733 |
+
embeddings=output_embeddings,
|
| 1734 |
+
hidden_states=outputs.hidden_states,
|
| 1735 |
+
attentions=outputs.attentions,
|
| 1736 |
+
)
|
| 1737 |
+
|
| 1738 |
+
|
| 1739 |
+
__all__ = [
|
| 1740 |
+
"Data2VecAudioForAudioFrameClassification",
|
| 1741 |
+
"Data2VecAudioForCTC",
|
| 1742 |
+
"Data2VecAudioForSequenceClassification",
|
| 1743 |
+
"Data2VecAudioForXVector",
|
| 1744 |
+
"Data2VecAudioModel",
|
| 1745 |
+
"Data2VecAudioPreTrainedModel",
|
| 1746 |
+
]
|
docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_text.py
ADDED
|
@@ -0,0 +1,1553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Data2VecText model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 24 |
+
|
| 25 |
+
from ...activations import ACT2FN, gelu
|
| 26 |
+
from ...generation import GenerationMixin
|
| 27 |
+
from ...modeling_outputs import (
|
| 28 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 29 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 30 |
+
CausalLMOutputWithCrossAttentions,
|
| 31 |
+
MaskedLMOutput,
|
| 32 |
+
MultipleChoiceModelOutput,
|
| 33 |
+
QuestionAnsweringModelOutput,
|
| 34 |
+
SequenceClassifierOutput,
|
| 35 |
+
TokenClassifierOutput,
|
| 36 |
+
)
|
| 37 |
+
from ...modeling_utils import PreTrainedModel
|
| 38 |
+
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 39 |
+
from ...utils import (
|
| 40 |
+
add_code_sample_docstrings,
|
| 41 |
+
add_start_docstrings,
|
| 42 |
+
add_start_docstrings_to_model_forward,
|
| 43 |
+
logging,
|
| 44 |
+
replace_return_docstrings,
|
| 45 |
+
)
|
| 46 |
+
from .configuration_data2vec_text import Data2VecTextConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 53 |
+
|
| 54 |
+
# General docstring
|
| 55 |
+
_CHECKPOINT_FOR_DOC = "facebook/data2vec-text-base"
|
| 56 |
+
_CONFIG_FOR_DOC = "Data2VecTextConfig"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText
|
| 60 |
+
class Data2VecTextForTextEmbeddings(nn.Module):
|
| 61 |
+
"""
|
| 62 |
+
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
|
| 66 |
+
def __init__(self, config):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 69 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 70 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 71 |
+
|
| 72 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 73 |
+
# any TensorFlow checkpoint file
|
| 74 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 75 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 76 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 77 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 78 |
+
self.register_buffer(
|
| 79 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 80 |
+
)
|
| 81 |
+
self.register_buffer(
|
| 82 |
+
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# End copy
|
| 86 |
+
self.padding_idx = config.pad_token_id
|
| 87 |
+
self.position_embeddings = nn.Embedding(
|
| 88 |
+
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 93 |
+
):
|
| 94 |
+
if position_ids is None:
|
| 95 |
+
if input_ids is not None:
|
| 96 |
+
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
| 97 |
+
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
|
| 98 |
+
else:
|
| 99 |
+
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
| 100 |
+
|
| 101 |
+
if input_ids is not None:
|
| 102 |
+
input_shape = input_ids.size()
|
| 103 |
+
else:
|
| 104 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 105 |
+
|
| 106 |
+
seq_length = input_shape[1]
|
| 107 |
+
|
| 108 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
| 109 |
+
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
| 110 |
+
# issue #5664
|
| 111 |
+
if token_type_ids is None:
|
| 112 |
+
if hasattr(self, "token_type_ids"):
|
| 113 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
| 114 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
| 115 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 116 |
+
else:
|
| 117 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 118 |
+
|
| 119 |
+
if inputs_embeds is None:
|
| 120 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 121 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 122 |
+
|
| 123 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 124 |
+
if self.position_embedding_type == "absolute":
|
| 125 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 126 |
+
embeddings += position_embeddings
|
| 127 |
+
embeddings = self.LayerNorm(embeddings)
|
| 128 |
+
embeddings = self.dropout(embeddings)
|
| 129 |
+
return embeddings
|
| 130 |
+
|
| 131 |
+
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
| 132 |
+
"""
|
| 133 |
+
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
inputs_embeds: torch.Tensor
|
| 137 |
+
|
| 138 |
+
Returns: torch.Tensor
|
| 139 |
+
"""
|
| 140 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 141 |
+
sequence_length = input_shape[1]
|
| 142 |
+
|
| 143 |
+
position_ids = torch.arange(
|
| 144 |
+
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
|
| 145 |
+
)
|
| 146 |
+
return position_ids.unsqueeze(0).expand(input_shape)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
|
| 150 |
+
class Data2VecTextSelfAttention(nn.Module):
|
| 151 |
+
def __init__(self, config, position_embedding_type=None):
|
| 152 |
+
super().__init__()
|
| 153 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 154 |
+
raise ValueError(
|
| 155 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 156 |
+
f"heads ({config.num_attention_heads})"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.num_attention_heads = config.num_attention_heads
|
| 160 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 161 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 162 |
+
|
| 163 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 164 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 165 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 166 |
+
|
| 167 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 168 |
+
self.position_embedding_type = position_embedding_type or getattr(
|
| 169 |
+
config, "position_embedding_type", "absolute"
|
| 170 |
+
)
|
| 171 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 172 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 173 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 174 |
+
|
| 175 |
+
self.is_decoder = config.is_decoder
|
| 176 |
+
|
| 177 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 178 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 179 |
+
x = x.view(new_x_shape)
|
| 180 |
+
return x.permute(0, 2, 1, 3)
|
| 181 |
+
|
| 182 |
+
def forward(
|
| 183 |
+
self,
|
| 184 |
+
hidden_states: torch.Tensor,
|
| 185 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 186 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 187 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 188 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 189 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 190 |
+
output_attentions: Optional[bool] = False,
|
| 191 |
+
) -> Tuple[torch.Tensor]:
|
| 192 |
+
mixed_query_layer = self.query(hidden_states)
|
| 193 |
+
|
| 194 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 195 |
+
# and values come from an encoder; the attention mask needs to be
|
| 196 |
+
# such that the encoder's padding tokens are not attended to.
|
| 197 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 198 |
+
|
| 199 |
+
if is_cross_attention and past_key_value is not None:
|
| 200 |
+
# reuse k,v, cross_attentions
|
| 201 |
+
key_layer = past_key_value[0]
|
| 202 |
+
value_layer = past_key_value[1]
|
| 203 |
+
attention_mask = encoder_attention_mask
|
| 204 |
+
elif is_cross_attention:
|
| 205 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 206 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 207 |
+
attention_mask = encoder_attention_mask
|
| 208 |
+
elif past_key_value is not None:
|
| 209 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 210 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 211 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 212 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 213 |
+
else:
|
| 214 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 215 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 216 |
+
|
| 217 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 218 |
+
|
| 219 |
+
use_cache = past_key_value is not None
|
| 220 |
+
if self.is_decoder:
|
| 221 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
| 222 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
| 223 |
+
# key/value_states (first "if" case)
|
| 224 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
| 225 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 226 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 227 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 228 |
+
past_key_value = (key_layer, value_layer)
|
| 229 |
+
|
| 230 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 231 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 232 |
+
|
| 233 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 234 |
+
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
| 235 |
+
if use_cache:
|
| 236 |
+
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
| 237 |
+
-1, 1
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 241 |
+
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 242 |
+
distance = position_ids_l - position_ids_r
|
| 243 |
+
|
| 244 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 245 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 246 |
+
|
| 247 |
+
if self.position_embedding_type == "relative_key":
|
| 248 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 249 |
+
attention_scores = attention_scores + relative_position_scores
|
| 250 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 251 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 252 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 253 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 254 |
+
|
| 255 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 256 |
+
if attention_mask is not None:
|
| 257 |
+
# Apply the attention mask is (precomputed for all layers in Data2VecTextModel forward() function)
|
| 258 |
+
attention_scores = attention_scores + attention_mask
|
| 259 |
+
|
| 260 |
+
# Normalize the attention scores to probabilities.
|
| 261 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 262 |
+
|
| 263 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 264 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 265 |
+
attention_probs = self.dropout(attention_probs)
|
| 266 |
+
|
| 267 |
+
# Mask heads if we want to
|
| 268 |
+
if head_mask is not None:
|
| 269 |
+
attention_probs = attention_probs * head_mask
|
| 270 |
+
|
| 271 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 272 |
+
|
| 273 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 274 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 275 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 276 |
+
|
| 277 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 278 |
+
|
| 279 |
+
if self.is_decoder:
|
| 280 |
+
outputs = outputs + (past_key_value,)
|
| 281 |
+
return outputs
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
|
| 285 |
+
class Data2VecTextSelfOutput(nn.Module):
|
| 286 |
+
def __init__(self, config):
|
| 287 |
+
super().__init__()
|
| 288 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 289 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 290 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 291 |
+
|
| 292 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 293 |
+
hidden_states = self.dense(hidden_states)
|
| 294 |
+
hidden_states = self.dropout(hidden_states)
|
| 295 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 296 |
+
return hidden_states
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
|
| 300 |
+
"eager": Data2VecTextSelfAttention,
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
|
| 305 |
+
class Data2VecTextAttention(nn.Module):
|
| 306 |
+
def __init__(self, config, position_embedding_type=None):
|
| 307 |
+
super().__init__()
|
| 308 |
+
self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
| 309 |
+
config, position_embedding_type=position_embedding_type
|
| 310 |
+
)
|
| 311 |
+
self.output = Data2VecTextSelfOutput(config)
|
| 312 |
+
self.pruned_heads = set()
|
| 313 |
+
|
| 314 |
+
def prune_heads(self, heads):
|
| 315 |
+
if len(heads) == 0:
|
| 316 |
+
return
|
| 317 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 318 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Prune linear layers
|
| 322 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 323 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 324 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 325 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 326 |
+
|
| 327 |
+
# Update hyper params and store pruned heads
|
| 328 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 329 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 330 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 331 |
+
|
| 332 |
+
def forward(
|
| 333 |
+
self,
|
| 334 |
+
hidden_states: torch.Tensor,
|
| 335 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 336 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 337 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 338 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 339 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 340 |
+
output_attentions: Optional[bool] = False,
|
| 341 |
+
) -> Tuple[torch.Tensor]:
|
| 342 |
+
self_outputs = self.self(
|
| 343 |
+
hidden_states,
|
| 344 |
+
attention_mask,
|
| 345 |
+
head_mask,
|
| 346 |
+
encoder_hidden_states,
|
| 347 |
+
encoder_attention_mask,
|
| 348 |
+
past_key_value,
|
| 349 |
+
output_attentions,
|
| 350 |
+
)
|
| 351 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 352 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 353 |
+
return outputs
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
|
| 357 |
+
class Data2VecTextIntermediate(nn.Module):
|
| 358 |
+
def __init__(self, config):
|
| 359 |
+
super().__init__()
|
| 360 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 361 |
+
if isinstance(config.hidden_act, str):
|
| 362 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 363 |
+
else:
|
| 364 |
+
self.intermediate_act_fn = config.hidden_act
|
| 365 |
+
|
| 366 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 367 |
+
hidden_states = self.dense(hidden_states)
|
| 368 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 369 |
+
return hidden_states
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput
|
| 373 |
+
class Data2VecTextOutput(nn.Module):
|
| 374 |
+
def __init__(self, config):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 377 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 378 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 379 |
+
|
| 380 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 381 |
+
hidden_states = self.dense(hidden_states)
|
| 382 |
+
hidden_states = self.dropout(hidden_states)
|
| 383 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 384 |
+
return hidden_states
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
|
| 388 |
+
class Data2VecTextLayer(nn.Module):
|
| 389 |
+
def __init__(self, config):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 392 |
+
self.seq_len_dim = 1
|
| 393 |
+
self.attention = Data2VecTextAttention(config)
|
| 394 |
+
self.is_decoder = config.is_decoder
|
| 395 |
+
self.add_cross_attention = config.add_cross_attention
|
| 396 |
+
if self.add_cross_attention:
|
| 397 |
+
if not self.is_decoder:
|
| 398 |
+
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
| 399 |
+
self.crossattention = Data2VecTextAttention(config, position_embedding_type="absolute")
|
| 400 |
+
self.intermediate = Data2VecTextIntermediate(config)
|
| 401 |
+
self.output = Data2VecTextOutput(config)
|
| 402 |
+
|
| 403 |
+
def forward(
|
| 404 |
+
self,
|
| 405 |
+
hidden_states: torch.Tensor,
|
| 406 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 407 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 408 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 409 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 410 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 411 |
+
output_attentions: Optional[bool] = False,
|
| 412 |
+
) -> Tuple[torch.Tensor]:
|
| 413 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 414 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 415 |
+
self_attention_outputs = self.attention(
|
| 416 |
+
hidden_states,
|
| 417 |
+
attention_mask,
|
| 418 |
+
head_mask,
|
| 419 |
+
output_attentions=output_attentions,
|
| 420 |
+
past_key_value=self_attn_past_key_value,
|
| 421 |
+
)
|
| 422 |
+
attention_output = self_attention_outputs[0]
|
| 423 |
+
|
| 424 |
+
# if decoder, the last output is tuple of self-attn cache
|
| 425 |
+
if self.is_decoder:
|
| 426 |
+
outputs = self_attention_outputs[1:-1]
|
| 427 |
+
present_key_value = self_attention_outputs[-1]
|
| 428 |
+
else:
|
| 429 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 430 |
+
|
| 431 |
+
cross_attn_present_key_value = None
|
| 432 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
| 433 |
+
if not hasattr(self, "crossattention"):
|
| 434 |
+
raise ValueError(
|
| 435 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
| 436 |
+
" by setting `config.add_cross_attention=True`"
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
| 440 |
+
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
| 441 |
+
cross_attention_outputs = self.crossattention(
|
| 442 |
+
attention_output,
|
| 443 |
+
attention_mask,
|
| 444 |
+
head_mask,
|
| 445 |
+
encoder_hidden_states,
|
| 446 |
+
encoder_attention_mask,
|
| 447 |
+
cross_attn_past_key_value,
|
| 448 |
+
output_attentions,
|
| 449 |
+
)
|
| 450 |
+
attention_output = cross_attention_outputs[0]
|
| 451 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
| 452 |
+
|
| 453 |
+
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
| 454 |
+
cross_attn_present_key_value = cross_attention_outputs[-1]
|
| 455 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
| 456 |
+
|
| 457 |
+
layer_output = apply_chunking_to_forward(
|
| 458 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
| 459 |
+
)
|
| 460 |
+
outputs = (layer_output,) + outputs
|
| 461 |
+
|
| 462 |
+
# if decoder, return the attn key/values as the last output
|
| 463 |
+
if self.is_decoder:
|
| 464 |
+
outputs = outputs + (present_key_value,)
|
| 465 |
+
|
| 466 |
+
return outputs
|
| 467 |
+
|
| 468 |
+
def feed_forward_chunk(self, attention_output):
|
| 469 |
+
intermediate_output = self.intermediate(attention_output)
|
| 470 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 471 |
+
return layer_output
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText
|
| 475 |
+
class Data2VecTextEncoder(nn.Module):
|
| 476 |
+
def __init__(self, config):
|
| 477 |
+
super().__init__()
|
| 478 |
+
self.config = config
|
| 479 |
+
self.layer = nn.ModuleList([Data2VecTextLayer(config) for _ in range(config.num_hidden_layers)])
|
| 480 |
+
self.gradient_checkpointing = False
|
| 481 |
+
|
| 482 |
+
def forward(
|
| 483 |
+
self,
|
| 484 |
+
hidden_states: torch.Tensor,
|
| 485 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 486 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 487 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 488 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 489 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 490 |
+
use_cache: Optional[bool] = None,
|
| 491 |
+
output_attentions: Optional[bool] = False,
|
| 492 |
+
output_hidden_states: Optional[bool] = False,
|
| 493 |
+
return_dict: Optional[bool] = True,
|
| 494 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
| 495 |
+
all_hidden_states = () if output_hidden_states else None
|
| 496 |
+
all_self_attentions = () if output_attentions else None
|
| 497 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 498 |
+
|
| 499 |
+
if self.gradient_checkpointing and self.training:
|
| 500 |
+
if use_cache:
|
| 501 |
+
logger.warning_once(
|
| 502 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 503 |
+
)
|
| 504 |
+
use_cache = False
|
| 505 |
+
|
| 506 |
+
next_decoder_cache = () if use_cache else None
|
| 507 |
+
for i, layer_module in enumerate(self.layer):
|
| 508 |
+
if output_hidden_states:
|
| 509 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 510 |
+
|
| 511 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 512 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 513 |
+
|
| 514 |
+
if self.gradient_checkpointing and self.training:
|
| 515 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 516 |
+
layer_module.__call__,
|
| 517 |
+
hidden_states,
|
| 518 |
+
attention_mask,
|
| 519 |
+
layer_head_mask,
|
| 520 |
+
encoder_hidden_states,
|
| 521 |
+
encoder_attention_mask,
|
| 522 |
+
past_key_value,
|
| 523 |
+
output_attentions,
|
| 524 |
+
)
|
| 525 |
+
else:
|
| 526 |
+
layer_outputs = layer_module(
|
| 527 |
+
hidden_states,
|
| 528 |
+
attention_mask,
|
| 529 |
+
layer_head_mask,
|
| 530 |
+
encoder_hidden_states,
|
| 531 |
+
encoder_attention_mask,
|
| 532 |
+
past_key_value,
|
| 533 |
+
output_attentions,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
hidden_states = layer_outputs[0]
|
| 537 |
+
if use_cache:
|
| 538 |
+
next_decoder_cache += (layer_outputs[-1],)
|
| 539 |
+
if output_attentions:
|
| 540 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 541 |
+
if self.config.add_cross_attention:
|
| 542 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
| 543 |
+
|
| 544 |
+
if output_hidden_states:
|
| 545 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 546 |
+
|
| 547 |
+
if not return_dict:
|
| 548 |
+
return tuple(
|
| 549 |
+
v
|
| 550 |
+
for v in [
|
| 551 |
+
hidden_states,
|
| 552 |
+
next_decoder_cache,
|
| 553 |
+
all_hidden_states,
|
| 554 |
+
all_self_attentions,
|
| 555 |
+
all_cross_attentions,
|
| 556 |
+
]
|
| 557 |
+
if v is not None
|
| 558 |
+
)
|
| 559 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 560 |
+
last_hidden_state=hidden_states,
|
| 561 |
+
past_key_values=next_decoder_cache,
|
| 562 |
+
hidden_states=all_hidden_states,
|
| 563 |
+
attentions=all_self_attentions,
|
| 564 |
+
cross_attentions=all_cross_attentions,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# Copied from transformers.models.bert.modeling_bert.BertPooler
|
| 569 |
+
class Data2VecTextPooler(nn.Module):
|
| 570 |
+
def __init__(self, config):
|
| 571 |
+
super().__init__()
|
| 572 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 573 |
+
self.activation = nn.Tanh()
|
| 574 |
+
|
| 575 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 576 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 577 |
+
# to the first token.
|
| 578 |
+
first_token_tensor = hidden_states[:, 0]
|
| 579 |
+
pooled_output = self.dense(first_token_tensor)
|
| 580 |
+
pooled_output = self.activation(pooled_output)
|
| 581 |
+
return pooled_output
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class Data2VecTextPreTrainedModel(PreTrainedModel):
|
| 585 |
+
"""
|
| 586 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 587 |
+
models.
|
| 588 |
+
"""
|
| 589 |
+
|
| 590 |
+
config_class = Data2VecTextConfig
|
| 591 |
+
base_model_prefix = "data2vec_text"
|
| 592 |
+
supports_gradient_checkpointing = True
|
| 593 |
+
_no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"]
|
| 594 |
+
|
| 595 |
+
def _init_weights(self, module):
|
| 596 |
+
"""Initialize the weights"""
|
| 597 |
+
if isinstance(module, nn.Linear):
|
| 598 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 599 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 600 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 601 |
+
if module.bias is not None:
|
| 602 |
+
module.bias.data.zero_()
|
| 603 |
+
elif isinstance(module, nn.Embedding):
|
| 604 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 605 |
+
if module.padding_idx is not None:
|
| 606 |
+
module.weight.data[module.padding_idx].zero_()
|
| 607 |
+
elif isinstance(module, nn.LayerNorm):
|
| 608 |
+
if hasattr(module, "bias") and module.bias is not None:
|
| 609 |
+
module.bias.data.zero_()
|
| 610 |
+
if hasattr(module, "weight") and module.weight is not None:
|
| 611 |
+
module.weight.data.fill_(1.0)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
DATA2VECTEXT_START_DOCSTRING = r"""
|
| 615 |
+
Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
|
| 616 |
+
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
|
| 617 |
+
Michael Auli.
|
| 618 |
+
|
| 619 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 620 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 621 |
+
etc.)
|
| 622 |
+
|
| 623 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 624 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 625 |
+
and behavior.
|
| 626 |
+
|
| 627 |
+
Parameters:
|
| 628 |
+
config ([`Data2VecTextConfig`]): Model configuration class with all the parameters of the
|
| 629 |
+
model. Initializing with a config file does not load the weights associated with the model, only the
|
| 630 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
DATA2VECTEXT_INPUTS_DOCSTRING = r"""
|
| 634 |
+
Args:
|
| 635 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 636 |
+
Indices of input sequence tokens in the vocabulary.
|
| 637 |
+
|
| 638 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 639 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 640 |
+
|
| 641 |
+
[What are input IDs?](../glossary#input-ids)
|
| 642 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 643 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 644 |
+
|
| 645 |
+
- 1 for tokens that are **not masked**,
|
| 646 |
+
- 0 for tokens that are **masked**.
|
| 647 |
+
|
| 648 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 649 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 650 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 651 |
+
1]`:
|
| 652 |
+
|
| 653 |
+
- 0 corresponds to a *sentence A* token,
|
| 654 |
+
- 1 corresponds to a *sentence B* token.
|
| 655 |
+
|
| 656 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 657 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 658 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 659 |
+
config.max_position_embeddings - 1]`.
|
| 660 |
+
|
| 661 |
+
[What are position IDs?](../glossary#position-ids)
|
| 662 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 663 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 664 |
+
|
| 665 |
+
- 1 indicates the head is **not masked**,
|
| 666 |
+
- 0 indicates the head is **masked**.
|
| 667 |
+
|
| 668 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 669 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 670 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 671 |
+
model's internal embedding lookup matrix.
|
| 672 |
+
output_attentions (`bool`, *optional*):
|
| 673 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 674 |
+
tensors for more detail.
|
| 675 |
+
output_hidden_states (`bool`, *optional*):
|
| 676 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 677 |
+
more detail.
|
| 678 |
+
return_dict (`bool`, *optional*):
|
| 679 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 680 |
+
"""
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
@add_start_docstrings(
|
| 684 |
+
"The bare Data2VecText Model for text transformer outputting raw hidden-states without any specific head on top.",
|
| 685 |
+
DATA2VECTEXT_START_DOCSTRING,
|
| 686 |
+
)
|
| 687 |
+
class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
| 688 |
+
"""
|
| 689 |
+
|
| 690 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 691 |
+
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
|
| 692 |
+
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
| 693 |
+
Kaiser and Illia Polosukhin.
|
| 694 |
+
|
| 695 |
+
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
| 696 |
+
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
| 697 |
+
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
| 698 |
+
|
| 699 |
+
.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
|
| 700 |
+
|
| 701 |
+
"""
|
| 702 |
+
|
| 703 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 704 |
+
super().__init__(config)
|
| 705 |
+
self.config = config
|
| 706 |
+
|
| 707 |
+
self.embeddings = Data2VecTextForTextEmbeddings(config)
|
| 708 |
+
self.encoder = Data2VecTextEncoder(config)
|
| 709 |
+
|
| 710 |
+
self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None
|
| 711 |
+
|
| 712 |
+
# Initialize weights and apply final processing
|
| 713 |
+
self.post_init()
|
| 714 |
+
|
| 715 |
+
def get_input_embeddings(self):
|
| 716 |
+
return self.embeddings.word_embeddings
|
| 717 |
+
|
| 718 |
+
def set_input_embeddings(self, value):
|
| 719 |
+
self.embeddings.word_embeddings = value
|
| 720 |
+
|
| 721 |
+
def _prune_heads(self, heads_to_prune):
|
| 722 |
+
"""
|
| 723 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 724 |
+
class PreTrainedModel
|
| 725 |
+
"""
|
| 726 |
+
for layer, heads in heads_to_prune.items():
|
| 727 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 728 |
+
|
| 729 |
+
@add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 730 |
+
@add_code_sample_docstrings(
|
| 731 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 732 |
+
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
| 733 |
+
config_class=_CONFIG_FOR_DOC,
|
| 734 |
+
)
|
| 735 |
+
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
| 736 |
+
def forward(
|
| 737 |
+
self,
|
| 738 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 739 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 740 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 741 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 742 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 743 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 744 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 745 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 746 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 747 |
+
use_cache: Optional[bool] = None,
|
| 748 |
+
output_attentions: Optional[bool] = None,
|
| 749 |
+
output_hidden_states: Optional[bool] = None,
|
| 750 |
+
return_dict: Optional[bool] = None,
|
| 751 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
| 752 |
+
r"""
|
| 753 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 754 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 755 |
+
the model is configured as a decoder.
|
| 756 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 757 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 758 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
| 759 |
+
|
| 760 |
+
- 1 for tokens that are **not masked**,
|
| 761 |
+
- 0 for tokens that are **masked**.
|
| 762 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 763 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 764 |
+
|
| 765 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
| 766 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
| 767 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 768 |
+
use_cache (`bool`, *optional*):
|
| 769 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 770 |
+
`past_key_values`).
|
| 771 |
+
"""
|
| 772 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 773 |
+
output_hidden_states = (
|
| 774 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 775 |
+
)
|
| 776 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 777 |
+
|
| 778 |
+
if self.config.is_decoder:
|
| 779 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 780 |
+
else:
|
| 781 |
+
use_cache = False
|
| 782 |
+
|
| 783 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 784 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 785 |
+
elif input_ids is not None:
|
| 786 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 787 |
+
input_shape = input_ids.size()
|
| 788 |
+
elif inputs_embeds is not None:
|
| 789 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 790 |
+
else:
|
| 791 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 792 |
+
|
| 793 |
+
batch_size, seq_length = input_shape
|
| 794 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 795 |
+
|
| 796 |
+
# past_key_values_length
|
| 797 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 798 |
+
|
| 799 |
+
if attention_mask is None:
|
| 800 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
| 801 |
+
|
| 802 |
+
if token_type_ids is None:
|
| 803 |
+
if hasattr(self.embeddings, "token_type_ids"):
|
| 804 |
+
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
| 805 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
| 806 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 807 |
+
else:
|
| 808 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 809 |
+
|
| 810 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 811 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 812 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
| 813 |
+
|
| 814 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 815 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 816 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
| 817 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 818 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 819 |
+
if encoder_attention_mask is None:
|
| 820 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 821 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 822 |
+
else:
|
| 823 |
+
encoder_extended_attention_mask = None
|
| 824 |
+
|
| 825 |
+
# Prepare head mask if needed
|
| 826 |
+
# 1.0 in head_mask indicate we keep the head
|
| 827 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 828 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 829 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 830 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 831 |
+
|
| 832 |
+
embedding_output = self.embeddings(
|
| 833 |
+
input_ids=input_ids,
|
| 834 |
+
position_ids=position_ids,
|
| 835 |
+
token_type_ids=token_type_ids,
|
| 836 |
+
inputs_embeds=inputs_embeds,
|
| 837 |
+
past_key_values_length=past_key_values_length,
|
| 838 |
+
)
|
| 839 |
+
encoder_outputs = self.encoder(
|
| 840 |
+
embedding_output,
|
| 841 |
+
attention_mask=extended_attention_mask,
|
| 842 |
+
head_mask=head_mask,
|
| 843 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 844 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 845 |
+
past_key_values=past_key_values,
|
| 846 |
+
use_cache=use_cache,
|
| 847 |
+
output_attentions=output_attentions,
|
| 848 |
+
output_hidden_states=output_hidden_states,
|
| 849 |
+
return_dict=return_dict,
|
| 850 |
+
)
|
| 851 |
+
sequence_output = encoder_outputs[0]
|
| 852 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 853 |
+
|
| 854 |
+
if not return_dict:
|
| 855 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 856 |
+
|
| 857 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 858 |
+
last_hidden_state=sequence_output,
|
| 859 |
+
pooler_output=pooled_output,
|
| 860 |
+
past_key_values=encoder_outputs.past_key_values,
|
| 861 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 862 |
+
attentions=encoder_outputs.attentions,
|
| 863 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
@add_start_docstrings(
|
| 868 |
+
"""Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING
|
| 869 |
+
)
|
| 870 |
+
class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
|
| 871 |
+
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
| 872 |
+
|
| 873 |
+
def __init__(self, config):
|
| 874 |
+
super().__init__(config)
|
| 875 |
+
|
| 876 |
+
if not config.is_decoder:
|
| 877 |
+
logger.warning("If you want to use `Data2VecTextLMHeadModel` as a standalone, add `is_decoder=True.`")
|
| 878 |
+
|
| 879 |
+
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
|
| 880 |
+
self.lm_head = Data2VecTextLMHead(config)
|
| 881 |
+
|
| 882 |
+
# Initialize weights and apply final processing
|
| 883 |
+
self.post_init()
|
| 884 |
+
|
| 885 |
+
def get_output_embeddings(self):
|
| 886 |
+
return self.lm_head.decoder
|
| 887 |
+
|
| 888 |
+
def set_output_embeddings(self, new_embeddings):
|
| 889 |
+
self.lm_head.decoder = new_embeddings
|
| 890 |
+
|
| 891 |
+
@add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 892 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
| 893 |
+
def forward(
|
| 894 |
+
self,
|
| 895 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 896 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 897 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 898 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 899 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 900 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 901 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 902 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 903 |
+
labels: Optional[torch.LongTensor] = None,
|
| 904 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 905 |
+
use_cache: Optional[bool] = None,
|
| 906 |
+
output_attentions: Optional[bool] = None,
|
| 907 |
+
output_hidden_states: Optional[bool] = None,
|
| 908 |
+
return_dict: Optional[bool] = None,
|
| 909 |
+
**kwargs,
|
| 910 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
| 911 |
+
r"""
|
| 912 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 913 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 914 |
+
the model is configured as a decoder.
|
| 915 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 916 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 917 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
| 918 |
+
|
| 919 |
+
- 1 for tokens that are **not masked**,
|
| 920 |
+
- 0 for tokens that are **masked**.
|
| 921 |
+
|
| 922 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 923 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
| 924 |
+
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
| 925 |
+
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 926 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 927 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 928 |
+
|
| 929 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
| 930 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
| 931 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 932 |
+
use_cache (`bool`, *optional*):
|
| 933 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 934 |
+
`past_key_values`).
|
| 935 |
+
|
| 936 |
+
Returns:
|
| 937 |
+
|
| 938 |
+
Example:
|
| 939 |
+
|
| 940 |
+
```python
|
| 941 |
+
>>> from transformers import AutoTokenizer, Data2VecTextForCausalLM, Data2VecTextConfig
|
| 942 |
+
>>> import torch
|
| 943 |
+
|
| 944 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
|
| 945 |
+
>>> config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base")
|
| 946 |
+
>>> config.is_decoder = True
|
| 947 |
+
>>> model = Data2VecTextForCausalLM.from_pretrained("facebook/data2vec-text-base", config=config)
|
| 948 |
+
|
| 949 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 950 |
+
>>> outputs = model(**inputs)
|
| 951 |
+
|
| 952 |
+
>>> prediction_logits = outputs.logits
|
| 953 |
+
```"""
|
| 954 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 955 |
+
if labels is not None:
|
| 956 |
+
use_cache = False
|
| 957 |
+
|
| 958 |
+
outputs = self.data2vec_text(
|
| 959 |
+
input_ids,
|
| 960 |
+
attention_mask=attention_mask,
|
| 961 |
+
token_type_ids=token_type_ids,
|
| 962 |
+
position_ids=position_ids,
|
| 963 |
+
head_mask=head_mask,
|
| 964 |
+
inputs_embeds=inputs_embeds,
|
| 965 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 966 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 967 |
+
past_key_values=past_key_values,
|
| 968 |
+
use_cache=use_cache,
|
| 969 |
+
output_attentions=output_attentions,
|
| 970 |
+
output_hidden_states=output_hidden_states,
|
| 971 |
+
return_dict=return_dict,
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
+
sequence_output = outputs[0]
|
| 975 |
+
prediction_scores = self.lm_head(sequence_output)
|
| 976 |
+
|
| 977 |
+
lm_loss = None
|
| 978 |
+
if labels is not None:
|
| 979 |
+
lm_loss = self.loss_function(
|
| 980 |
+
prediction_scores,
|
| 981 |
+
labels,
|
| 982 |
+
vocab_size=self.config.vocab_size,
|
| 983 |
+
**kwargs,
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
if not return_dict:
|
| 987 |
+
output = (prediction_scores,) + outputs[2:]
|
| 988 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
| 989 |
+
|
| 990 |
+
return CausalLMOutputWithCrossAttentions(
|
| 991 |
+
loss=lm_loss,
|
| 992 |
+
logits=prediction_scores,
|
| 993 |
+
past_key_values=outputs.past_key_values,
|
| 994 |
+
hidden_states=outputs.hidden_states,
|
| 995 |
+
attentions=outputs.attentions,
|
| 996 |
+
cross_attentions=outputs.cross_attentions,
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
| 1000 |
+
reordered_past = ()
|
| 1001 |
+
for layer_past in past_key_values:
|
| 1002 |
+
reordered_past += (
|
| 1003 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 1004 |
+
)
|
| 1005 |
+
return reordered_past
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
@add_start_docstrings("""data2vec Model with a `language modeling` head on top.""", DATA2VECTEXT_START_DOCSTRING)
|
| 1009 |
+
class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
|
| 1010 |
+
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
| 1011 |
+
|
| 1012 |
+
def __init__(self, config):
|
| 1013 |
+
super().__init__(config)
|
| 1014 |
+
|
| 1015 |
+
if config.is_decoder:
|
| 1016 |
+
logger.warning(
|
| 1017 |
+
"If you want to use `Data2VecTextForMaskedLM` make sure `config.is_decoder=False` for "
|
| 1018 |
+
"bi-directional self-attention."
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
|
| 1022 |
+
self.lm_head = Data2VecTextLMHead(config)
|
| 1023 |
+
|
| 1024 |
+
# Initialize weights and apply final processing
|
| 1025 |
+
self.post_init()
|
| 1026 |
+
|
| 1027 |
+
def get_output_embeddings(self):
|
| 1028 |
+
return self.lm_head.decoder
|
| 1029 |
+
|
| 1030 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1031 |
+
self.lm_head.decoder = new_embeddings
|
| 1032 |
+
|
| 1033 |
+
@add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1034 |
+
@add_code_sample_docstrings(
|
| 1035 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1036 |
+
output_type=MaskedLMOutput,
|
| 1037 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1038 |
+
mask="<mask>",
|
| 1039 |
+
)
|
| 1040 |
+
def forward(
|
| 1041 |
+
self,
|
| 1042 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1043 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1044 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1045 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1046 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1047 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1048 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 1049 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 1050 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1051 |
+
output_attentions: Optional[bool] = None,
|
| 1052 |
+
output_hidden_states: Optional[bool] = None,
|
| 1053 |
+
return_dict: Optional[bool] = None,
|
| 1054 |
+
) -> Union[Tuple, MaskedLMOutput]:
|
| 1055 |
+
r"""
|
| 1056 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1057 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 1058 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 1059 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 1060 |
+
kwargs (`Dict[str, any]`, *optional*, defaults to *{}*):
|
| 1061 |
+
Used to hide legacy arguments that have been deprecated.
|
| 1062 |
+
"""
|
| 1063 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1064 |
+
|
| 1065 |
+
outputs = self.data2vec_text(
|
| 1066 |
+
input_ids,
|
| 1067 |
+
attention_mask=attention_mask,
|
| 1068 |
+
token_type_ids=token_type_ids,
|
| 1069 |
+
position_ids=position_ids,
|
| 1070 |
+
head_mask=head_mask,
|
| 1071 |
+
inputs_embeds=inputs_embeds,
|
| 1072 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1073 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1074 |
+
output_attentions=output_attentions,
|
| 1075 |
+
output_hidden_states=output_hidden_states,
|
| 1076 |
+
return_dict=return_dict,
|
| 1077 |
+
)
|
| 1078 |
+
sequence_output = outputs[0]
|
| 1079 |
+
prediction_scores = self.lm_head(sequence_output)
|
| 1080 |
+
|
| 1081 |
+
masked_lm_loss = None
|
| 1082 |
+
if labels is not None:
|
| 1083 |
+
loss_fct = CrossEntropyLoss()
|
| 1084 |
+
|
| 1085 |
+
labels = labels.to(prediction_scores.device)
|
| 1086 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1087 |
+
|
| 1088 |
+
if not return_dict:
|
| 1089 |
+
output = (prediction_scores,) + outputs[2:]
|
| 1090 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 1091 |
+
|
| 1092 |
+
return MaskedLMOutput(
|
| 1093 |
+
loss=masked_lm_loss,
|
| 1094 |
+
logits=prediction_scores,
|
| 1095 |
+
hidden_states=outputs.hidden_states,
|
| 1096 |
+
attentions=outputs.attentions,
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Data2VecText
|
| 1101 |
+
class Data2VecTextLMHead(nn.Module):
|
| 1102 |
+
"""Data2VecText Head for masked language modeling."""
|
| 1103 |
+
|
| 1104 |
+
def __init__(self, config):
|
| 1105 |
+
super().__init__()
|
| 1106 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 1107 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 1108 |
+
|
| 1109 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
| 1110 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 1111 |
+
self.decoder.bias = self.bias
|
| 1112 |
+
|
| 1113 |
+
def forward(self, features, **kwargs):
|
| 1114 |
+
x = self.dense(features)
|
| 1115 |
+
x = gelu(x)
|
| 1116 |
+
x = self.layer_norm(x)
|
| 1117 |
+
|
| 1118 |
+
# project back to size of vocabulary with bias
|
| 1119 |
+
x = self.decoder(x)
|
| 1120 |
+
|
| 1121 |
+
return x
|
| 1122 |
+
|
| 1123 |
+
def _tie_weights(self):
|
| 1124 |
+
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
| 1125 |
+
# For accelerate compatibility and to not break backward compatibility
|
| 1126 |
+
if self.decoder.bias.device.type == "meta":
|
| 1127 |
+
self.decoder.bias = self.bias
|
| 1128 |
+
else:
|
| 1129 |
+
self.bias = self.decoder.bias
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
@add_start_docstrings(
|
| 1133 |
+
"""
|
| 1134 |
+
Data2VecText Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
| 1135 |
+
pooled output) e.g. for GLUE tasks.
|
| 1136 |
+
""",
|
| 1137 |
+
DATA2VECTEXT_START_DOCSTRING,
|
| 1138 |
+
)
|
| 1139 |
+
class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
|
| 1140 |
+
def __init__(self, config):
|
| 1141 |
+
super().__init__(config)
|
| 1142 |
+
self.num_labels = config.num_labels
|
| 1143 |
+
self.config = config
|
| 1144 |
+
|
| 1145 |
+
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
|
| 1146 |
+
self.classifier = Data2VecTextClassificationHead(config)
|
| 1147 |
+
|
| 1148 |
+
# Initialize weights and apply final processing
|
| 1149 |
+
self.post_init()
|
| 1150 |
+
|
| 1151 |
+
@add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1152 |
+
@add_code_sample_docstrings(
|
| 1153 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1154 |
+
output_type=SequenceClassifierOutput,
|
| 1155 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1156 |
+
)
|
| 1157 |
+
def forward(
|
| 1158 |
+
self,
|
| 1159 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1160 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1161 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1162 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1163 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1164 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1165 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1166 |
+
output_attentions: Optional[bool] = None,
|
| 1167 |
+
output_hidden_states: Optional[bool] = None,
|
| 1168 |
+
return_dict: Optional[bool] = None,
|
| 1169 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1170 |
+
r"""
|
| 1171 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1172 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1173 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1174 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1175 |
+
"""
|
| 1176 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1177 |
+
|
| 1178 |
+
outputs = self.data2vec_text(
|
| 1179 |
+
input_ids,
|
| 1180 |
+
attention_mask=attention_mask,
|
| 1181 |
+
token_type_ids=token_type_ids,
|
| 1182 |
+
position_ids=position_ids,
|
| 1183 |
+
head_mask=head_mask,
|
| 1184 |
+
inputs_embeds=inputs_embeds,
|
| 1185 |
+
output_attentions=output_attentions,
|
| 1186 |
+
output_hidden_states=output_hidden_states,
|
| 1187 |
+
return_dict=return_dict,
|
| 1188 |
+
)
|
| 1189 |
+
sequence_output = outputs[0]
|
| 1190 |
+
logits = self.classifier(sequence_output)
|
| 1191 |
+
|
| 1192 |
+
loss = None
|
| 1193 |
+
if labels is not None:
|
| 1194 |
+
labels = labels.to(logits.device)
|
| 1195 |
+
|
| 1196 |
+
if self.config.problem_type is None:
|
| 1197 |
+
if self.num_labels == 1:
|
| 1198 |
+
self.config.problem_type = "regression"
|
| 1199 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1200 |
+
self.config.problem_type = "single_label_classification"
|
| 1201 |
+
else:
|
| 1202 |
+
self.config.problem_type = "multi_label_classification"
|
| 1203 |
+
|
| 1204 |
+
if self.config.problem_type == "regression":
|
| 1205 |
+
loss_fct = MSELoss()
|
| 1206 |
+
if self.num_labels == 1:
|
| 1207 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 1208 |
+
else:
|
| 1209 |
+
loss = loss_fct(logits, labels)
|
| 1210 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1211 |
+
loss_fct = CrossEntropyLoss()
|
| 1212 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1213 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1214 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1215 |
+
loss = loss_fct(logits, labels)
|
| 1216 |
+
|
| 1217 |
+
if not return_dict:
|
| 1218 |
+
output = (logits,) + outputs[2:]
|
| 1219 |
+
return ((loss,) + output) if loss is not None else output
|
| 1220 |
+
|
| 1221 |
+
return SequenceClassifierOutput(
|
| 1222 |
+
loss=loss,
|
| 1223 |
+
logits=logits,
|
| 1224 |
+
hidden_states=outputs.hidden_states,
|
| 1225 |
+
attentions=outputs.attentions,
|
| 1226 |
+
)
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
+
@add_start_docstrings(
|
| 1230 |
+
"""
|
| 1231 |
+
Data2VecText Model with a multiple choice classification head on top (a linear layer on top of the pooled output
|
| 1232 |
+
and a softmax) e.g. for RocStories/SWAG tasks.
|
| 1233 |
+
""",
|
| 1234 |
+
DATA2VECTEXT_START_DOCSTRING,
|
| 1235 |
+
)
|
| 1236 |
+
class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
|
| 1237 |
+
def __init__(self, config):
|
| 1238 |
+
super().__init__(config)
|
| 1239 |
+
|
| 1240 |
+
self.data2vec_text = Data2VecTextModel(config)
|
| 1241 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1242 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
| 1243 |
+
|
| 1244 |
+
# Initialize weights and apply final processing
|
| 1245 |
+
self.post_init()
|
| 1246 |
+
|
| 1247 |
+
@add_start_docstrings_to_model_forward(
|
| 1248 |
+
DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
| 1249 |
+
)
|
| 1250 |
+
@add_code_sample_docstrings(
|
| 1251 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1252 |
+
output_type=MultipleChoiceModelOutput,
|
| 1253 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1254 |
+
)
|
| 1255 |
+
def forward(
|
| 1256 |
+
self,
|
| 1257 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1258 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1259 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1260 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1261 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1262 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1263 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1264 |
+
output_attentions: Optional[bool] = None,
|
| 1265 |
+
output_hidden_states: Optional[bool] = None,
|
| 1266 |
+
return_dict: Optional[bool] = None,
|
| 1267 |
+
) -> Union[Tuple, MultipleChoiceModelOutput]:
|
| 1268 |
+
r"""
|
| 1269 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1270 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
| 1271 |
+
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
| 1272 |
+
`input_ids` above)
|
| 1273 |
+
"""
|
| 1274 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1275 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
| 1276 |
+
|
| 1277 |
+
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
| 1278 |
+
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
| 1279 |
+
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
| 1280 |
+
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 1281 |
+
flat_inputs_embeds = (
|
| 1282 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
| 1283 |
+
if inputs_embeds is not None
|
| 1284 |
+
else None
|
| 1285 |
+
)
|
| 1286 |
+
|
| 1287 |
+
outputs = self.data2vec_text(
|
| 1288 |
+
flat_input_ids,
|
| 1289 |
+
position_ids=flat_position_ids,
|
| 1290 |
+
token_type_ids=flat_token_type_ids,
|
| 1291 |
+
attention_mask=flat_attention_mask,
|
| 1292 |
+
head_mask=head_mask,
|
| 1293 |
+
inputs_embeds=flat_inputs_embeds,
|
| 1294 |
+
output_attentions=output_attentions,
|
| 1295 |
+
output_hidden_states=output_hidden_states,
|
| 1296 |
+
return_dict=return_dict,
|
| 1297 |
+
)
|
| 1298 |
+
pooled_output = outputs[1]
|
| 1299 |
+
|
| 1300 |
+
pooled_output = self.dropout(pooled_output)
|
| 1301 |
+
logits = self.classifier(pooled_output)
|
| 1302 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 1303 |
+
|
| 1304 |
+
loss = None
|
| 1305 |
+
if labels is not None:
|
| 1306 |
+
loss_fct = CrossEntropyLoss()
|
| 1307 |
+
|
| 1308 |
+
labels = labels.to(reshaped_logits.device)
|
| 1309 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 1310 |
+
|
| 1311 |
+
if not return_dict:
|
| 1312 |
+
output = (reshaped_logits,) + outputs[2:]
|
| 1313 |
+
return ((loss,) + output) if loss is not None else output
|
| 1314 |
+
|
| 1315 |
+
return MultipleChoiceModelOutput(
|
| 1316 |
+
loss=loss,
|
| 1317 |
+
logits=reshaped_logits,
|
| 1318 |
+
hidden_states=outputs.hidden_states,
|
| 1319 |
+
attentions=outputs.attentions,
|
| 1320 |
+
)
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
+
@add_start_docstrings(
|
| 1324 |
+
"""
|
| 1325 |
+
Data2VecText Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
|
| 1326 |
+
for Named-Entity-Recognition (NER) tasks.
|
| 1327 |
+
""",
|
| 1328 |
+
DATA2VECTEXT_START_DOCSTRING,
|
| 1329 |
+
)
|
| 1330 |
+
class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
|
| 1331 |
+
def __init__(self, config):
|
| 1332 |
+
super().__init__(config)
|
| 1333 |
+
self.num_labels = config.num_labels
|
| 1334 |
+
|
| 1335 |
+
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
|
| 1336 |
+
classifier_dropout = (
|
| 1337 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
| 1338 |
+
)
|
| 1339 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 1340 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1341 |
+
|
| 1342 |
+
# Initialize weights and apply final processing
|
| 1343 |
+
self.post_init()
|
| 1344 |
+
|
| 1345 |
+
@add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1346 |
+
@add_code_sample_docstrings(
|
| 1347 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1348 |
+
output_type=TokenClassifierOutput,
|
| 1349 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1350 |
+
)
|
| 1351 |
+
def forward(
|
| 1352 |
+
self,
|
| 1353 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1354 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1355 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1356 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1357 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1358 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1359 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1360 |
+
output_attentions: Optional[bool] = None,
|
| 1361 |
+
output_hidden_states: Optional[bool] = None,
|
| 1362 |
+
return_dict: Optional[bool] = None,
|
| 1363 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1364 |
+
r"""
|
| 1365 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1366 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1367 |
+
"""
|
| 1368 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1369 |
+
|
| 1370 |
+
outputs = self.data2vec_text(
|
| 1371 |
+
input_ids,
|
| 1372 |
+
attention_mask=attention_mask,
|
| 1373 |
+
token_type_ids=token_type_ids,
|
| 1374 |
+
position_ids=position_ids,
|
| 1375 |
+
head_mask=head_mask,
|
| 1376 |
+
inputs_embeds=inputs_embeds,
|
| 1377 |
+
output_attentions=output_attentions,
|
| 1378 |
+
output_hidden_states=output_hidden_states,
|
| 1379 |
+
return_dict=return_dict,
|
| 1380 |
+
)
|
| 1381 |
+
|
| 1382 |
+
sequence_output = outputs[0]
|
| 1383 |
+
|
| 1384 |
+
sequence_output = self.dropout(sequence_output)
|
| 1385 |
+
logits = self.classifier(sequence_output)
|
| 1386 |
+
|
| 1387 |
+
loss = None
|
| 1388 |
+
if labels is not None:
|
| 1389 |
+
loss_fct = CrossEntropyLoss()
|
| 1390 |
+
|
| 1391 |
+
labels = labels.to(logits.device)
|
| 1392 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1393 |
+
|
| 1394 |
+
if not return_dict:
|
| 1395 |
+
output = (logits,) + outputs[2:]
|
| 1396 |
+
return ((loss,) + output) if loss is not None else output
|
| 1397 |
+
|
| 1398 |
+
return TokenClassifierOutput(
|
| 1399 |
+
loss=loss,
|
| 1400 |
+
logits=logits,
|
| 1401 |
+
hidden_states=outputs.hidden_states,
|
| 1402 |
+
attentions=outputs.attentions,
|
| 1403 |
+
)
|
| 1404 |
+
|
| 1405 |
+
|
| 1406 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Data2VecText
|
| 1407 |
+
class Data2VecTextClassificationHead(nn.Module):
|
| 1408 |
+
"""Head for sentence-level classification tasks."""
|
| 1409 |
+
|
| 1410 |
+
def __init__(self, config):
|
| 1411 |
+
super().__init__()
|
| 1412 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 1413 |
+
classifier_dropout = (
|
| 1414 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
| 1415 |
+
)
|
| 1416 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 1417 |
+
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
|
| 1418 |
+
|
| 1419 |
+
def forward(self, features, **kwargs):
|
| 1420 |
+
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
| 1421 |
+
x = self.dropout(x)
|
| 1422 |
+
x = self.dense(x)
|
| 1423 |
+
x = torch.tanh(x)
|
| 1424 |
+
x = self.dropout(x)
|
| 1425 |
+
x = self.out_proj(x)
|
| 1426 |
+
return x
|
| 1427 |
+
|
| 1428 |
+
|
| 1429 |
+
@add_start_docstrings(
|
| 1430 |
+
"""
|
| 1431 |
+
Data2VecText Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
|
| 1432 |
+
linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1433 |
+
""",
|
| 1434 |
+
DATA2VECTEXT_START_DOCSTRING,
|
| 1435 |
+
)
|
| 1436 |
+
class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
|
| 1437 |
+
def __init__(self, config):
|
| 1438 |
+
super().__init__(config)
|
| 1439 |
+
self.num_labels = config.num_labels
|
| 1440 |
+
|
| 1441 |
+
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
|
| 1442 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 1443 |
+
|
| 1444 |
+
# Initialize weights and apply final processing
|
| 1445 |
+
self.post_init()
|
| 1446 |
+
|
| 1447 |
+
@add_start_docstrings_to_model_forward(DATA2VECTEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1448 |
+
@add_code_sample_docstrings(
|
| 1449 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1450 |
+
output_type=QuestionAnsweringModelOutput,
|
| 1451 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1452 |
+
)
|
| 1453 |
+
def forward(
|
| 1454 |
+
self,
|
| 1455 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1456 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1457 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1458 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1459 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1460 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1461 |
+
start_positions: Optional[torch.LongTensor] = None,
|
| 1462 |
+
end_positions: Optional[torch.LongTensor] = None,
|
| 1463 |
+
output_attentions: Optional[bool] = None,
|
| 1464 |
+
output_hidden_states: Optional[bool] = None,
|
| 1465 |
+
return_dict: Optional[bool] = None,
|
| 1466 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
| 1467 |
+
r"""
|
| 1468 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1469 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1470 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1471 |
+
are not taken into account for computing the loss.
|
| 1472 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1473 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1474 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1475 |
+
are not taken into account for computing the loss.
|
| 1476 |
+
"""
|
| 1477 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1478 |
+
|
| 1479 |
+
outputs = self.data2vec_text(
|
| 1480 |
+
input_ids,
|
| 1481 |
+
attention_mask=attention_mask,
|
| 1482 |
+
token_type_ids=token_type_ids,
|
| 1483 |
+
position_ids=position_ids,
|
| 1484 |
+
head_mask=head_mask,
|
| 1485 |
+
inputs_embeds=inputs_embeds,
|
| 1486 |
+
output_attentions=output_attentions,
|
| 1487 |
+
output_hidden_states=output_hidden_states,
|
| 1488 |
+
return_dict=return_dict,
|
| 1489 |
+
)
|
| 1490 |
+
|
| 1491 |
+
sequence_output = outputs[0]
|
| 1492 |
+
|
| 1493 |
+
logits = self.qa_outputs(sequence_output)
|
| 1494 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1495 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1496 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1497 |
+
|
| 1498 |
+
total_loss = None
|
| 1499 |
+
if start_positions is not None and end_positions is not None:
|
| 1500 |
+
# If we are on multi-GPU, split add a dimension
|
| 1501 |
+
if len(start_positions.size()) > 1:
|
| 1502 |
+
start_positions = start_positions.squeeze(-1)
|
| 1503 |
+
if len(end_positions.size()) > 1:
|
| 1504 |
+
end_positions = end_positions.squeeze(-1)
|
| 1505 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1506 |
+
ignored_index = start_logits.size(1)
|
| 1507 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1508 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1509 |
+
|
| 1510 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1511 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1512 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1513 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1514 |
+
|
| 1515 |
+
if not return_dict:
|
| 1516 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 1517 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1518 |
+
|
| 1519 |
+
return QuestionAnsweringModelOutput(
|
| 1520 |
+
loss=total_loss,
|
| 1521 |
+
start_logits=start_logits,
|
| 1522 |
+
end_logits=end_logits,
|
| 1523 |
+
hidden_states=outputs.hidden_states,
|
| 1524 |
+
attentions=outputs.attentions,
|
| 1525 |
+
)
|
| 1526 |
+
|
| 1527 |
+
|
| 1528 |
+
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
| 1529 |
+
"""
|
| 1530 |
+
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
| 1531 |
+
are ignored. This is modified from fairseq's `utils.make_positions`.
|
| 1532 |
+
|
| 1533 |
+
Args:
|
| 1534 |
+
x: torch.Tensor x:
|
| 1535 |
+
|
| 1536 |
+
Returns: torch.Tensor
|
| 1537 |
+
"""
|
| 1538 |
+
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
| 1539 |
+
mask = input_ids.ne(padding_idx).int()
|
| 1540 |
+
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
| 1541 |
+
return incremental_indices.long() + padding_idx
|
| 1542 |
+
|
| 1543 |
+
|
| 1544 |
+
__all__ = [
|
| 1545 |
+
"Data2VecTextForCausalLM",
|
| 1546 |
+
"Data2VecTextForMaskedLM",
|
| 1547 |
+
"Data2VecTextForMultipleChoice",
|
| 1548 |
+
"Data2VecTextForQuestionAnswering",
|
| 1549 |
+
"Data2VecTextForSequenceClassification",
|
| 1550 |
+
"Data2VecTextForTokenClassification",
|
| 1551 |
+
"Data2VecTextModel",
|
| 1552 |
+
"Data2VecTextPreTrainedModel",
|
| 1553 |
+
]
|
docs/transformers/build/lib/transformers/models/data2vec/modeling_data2vec_vision.py
ADDED
|
@@ -0,0 +1,1449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Data2VecVision model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
import math
|
| 19 |
+
import warnings
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
+
|
| 28 |
+
from ...activations import ACT2FN
|
| 29 |
+
from ...modeling_outputs import (
|
| 30 |
+
BaseModelOutput,
|
| 31 |
+
BaseModelOutputWithPooling,
|
| 32 |
+
ImageClassifierOutput,
|
| 33 |
+
SemanticSegmenterOutput,
|
| 34 |
+
)
|
| 35 |
+
from ...modeling_utils import PreTrainedModel
|
| 36 |
+
from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
|
| 37 |
+
from ...utils import (
|
| 38 |
+
add_code_sample_docstrings,
|
| 39 |
+
add_start_docstrings,
|
| 40 |
+
add_start_docstrings_to_model_forward,
|
| 41 |
+
logging,
|
| 42 |
+
replace_return_docstrings,
|
| 43 |
+
torch_int,
|
| 44 |
+
)
|
| 45 |
+
from .configuration_data2vec_vision import Data2VecVisionConfig
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
# General docstring
|
| 51 |
+
_CONFIG_FOR_DOC = "Data2VecVisionConfig"
|
| 52 |
+
|
| 53 |
+
# Base docstring
|
| 54 |
+
_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
|
| 55 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
| 56 |
+
|
| 57 |
+
# Image classification docstring
|
| 58 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
|
| 59 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
|
| 63 |
+
# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision
|
| 64 |
+
class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
|
| 65 |
+
"""
|
| 66 |
+
Class for outputs of [`Data2VecVisionModel`].
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 70 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 71 |
+
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
| 72 |
+
Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
|
| 73 |
+
*config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
|
| 74 |
+
will be returned.
|
| 75 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 76 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 77 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 78 |
+
|
| 79 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 80 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 81 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 82 |
+
sequence_length)`.
|
| 83 |
+
|
| 84 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 85 |
+
heads.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
| 90 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
| 91 |
+
"""
|
| 92 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 93 |
+
|
| 94 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 95 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 96 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 97 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 98 |
+
argument.
|
| 99 |
+
"""
|
| 100 |
+
if drop_prob == 0.0 or not training:
|
| 101 |
+
return input
|
| 102 |
+
keep_prob = 1 - drop_prob
|
| 103 |
+
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 104 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
| 105 |
+
random_tensor.floor_() # binarize
|
| 106 |
+
output = input.div(keep_prob) * random_tensor
|
| 107 |
+
return output
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision
|
| 111 |
+
class Data2VecVisionDropPath(nn.Module):
|
| 112 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 113 |
+
|
| 114 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.drop_prob = drop_prob
|
| 117 |
+
|
| 118 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 119 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
| 120 |
+
|
| 121 |
+
def extra_repr(self) -> str:
|
| 122 |
+
return "p={}".format(self.drop_prob)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
|
| 126 |
+
class Data2VecVisionEmbeddings(nn.Module):
|
| 127 |
+
"""
|
| 128 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
| 129 |
+
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 136 |
+
if config.use_mask_token:
|
| 137 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 138 |
+
else:
|
| 139 |
+
self.mask_token = None
|
| 140 |
+
self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
|
| 141 |
+
self.patch_size = config.patch_size
|
| 142 |
+
self.image_size = (
|
| 143 |
+
config.image_size
|
| 144 |
+
if isinstance(config.image_size, collections.abc.Iterable)
|
| 145 |
+
else (config.image_size, config.image_size)
|
| 146 |
+
)
|
| 147 |
+
num_patches = self.patch_embeddings.num_patches
|
| 148 |
+
if config.use_absolute_position_embeddings:
|
| 149 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
| 150 |
+
else:
|
| 151 |
+
self.position_embeddings = None
|
| 152 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 153 |
+
|
| 154 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
| 155 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 156 |
+
"""
|
| 157 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
| 158 |
+
images. This method is also adapted to support torch.jit tracing.
|
| 159 |
+
|
| 160 |
+
Adapted from:
|
| 161 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 162 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
num_patches = embeddings.shape[1] - 1
|
| 166 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 167 |
+
|
| 168 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 169 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 170 |
+
return self.position_embeddings
|
| 171 |
+
|
| 172 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 173 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 174 |
+
|
| 175 |
+
dim = embeddings.shape[-1]
|
| 176 |
+
|
| 177 |
+
new_height = height // self.patch_size
|
| 178 |
+
new_width = width // self.patch_size
|
| 179 |
+
|
| 180 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 181 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 182 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 183 |
+
|
| 184 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 185 |
+
patch_pos_embed,
|
| 186 |
+
size=(new_height, new_width),
|
| 187 |
+
mode="bicubic",
|
| 188 |
+
align_corners=False,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 192 |
+
|
| 193 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 194 |
+
|
| 195 |
+
def forward(
|
| 196 |
+
self,
|
| 197 |
+
pixel_values: torch.Tensor,
|
| 198 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 199 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 200 |
+
) -> torch.Tensor:
|
| 201 |
+
if self.position_embeddings is not None and interpolate_pos_encoding is not None:
|
| 202 |
+
warnings.warn(
|
| 203 |
+
"`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
|
| 204 |
+
"interpolated to the input image size. The argument will be removed in transformers v4.51.0."
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
_, _, height, width = pixel_values.shape
|
| 208 |
+
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
|
| 209 |
+
batch_size, seq_len, _ = embeddings.size()
|
| 210 |
+
|
| 211 |
+
if bool_masked_pos is not None:
|
| 212 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
| 213 |
+
# replace the masked visual tokens by mask_tokens
|
| 214 |
+
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
| 215 |
+
embeddings = embeddings * (1 - w) + mask_tokens * w
|
| 216 |
+
|
| 217 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 218 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 219 |
+
|
| 220 |
+
if self.position_embeddings is not None:
|
| 221 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 222 |
+
|
| 223 |
+
embeddings = self.dropout(embeddings)
|
| 224 |
+
|
| 225 |
+
return embeddings, (patch_height, patch_width)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
|
| 229 |
+
class Data2VecVisionPatchEmbeddings(nn.Module):
|
| 230 |
+
"""
|
| 231 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 232 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 233 |
+
Transformer.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
def __init__(self, config):
|
| 237 |
+
super().__init__()
|
| 238 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 239 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 240 |
+
|
| 241 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 242 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 243 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 244 |
+
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
| 245 |
+
self.image_size = image_size
|
| 246 |
+
self.patch_size = patch_size
|
| 247 |
+
self.num_channels = num_channels
|
| 248 |
+
self.num_patches = num_patches
|
| 249 |
+
self.patch_shape = patch_shape
|
| 250 |
+
|
| 251 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
| 252 |
+
|
| 253 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 254 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 255 |
+
if num_channels != self.num_channels:
|
| 256 |
+
raise ValueError(
|
| 257 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
embeddings = self.projection(pixel_values)
|
| 261 |
+
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
|
| 262 |
+
embeddings = embeddings.flatten(2).transpose(1, 2)
|
| 263 |
+
|
| 264 |
+
return embeddings, (patch_height, patch_width)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
|
| 268 |
+
class Data2VecVisionSelfAttention(nn.Module):
|
| 269 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
|
| 270 |
+
super().__init__()
|
| 271 |
+
self.config = config
|
| 272 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 273 |
+
raise ValueError(
|
| 274 |
+
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
| 275 |
+
f"heads {config.num_attention_heads}."
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
self.num_attention_heads = config.num_attention_heads
|
| 279 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 280 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 281 |
+
|
| 282 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 283 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
|
| 284 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 285 |
+
|
| 286 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 287 |
+
|
| 288 |
+
self.has_relative_position_bias = bool(window_size)
|
| 289 |
+
if self.has_relative_position_bias:
|
| 290 |
+
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
|
| 291 |
+
|
| 292 |
+
def transpose_for_scores(self, x):
|
| 293 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 294 |
+
x = x.view(*new_x_shape)
|
| 295 |
+
return x.permute(0, 2, 1, 3)
|
| 296 |
+
|
| 297 |
+
def forward(
|
| 298 |
+
self,
|
| 299 |
+
hidden_states: torch.Tensor,
|
| 300 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 301 |
+
output_attentions: bool = False,
|
| 302 |
+
relative_position_bias: Optional[torch.Tensor] = None,
|
| 303 |
+
interpolate_pos_encoding: bool = False,
|
| 304 |
+
resolution: Optional[Tuple[int]] = None,
|
| 305 |
+
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
| 306 |
+
mixed_query_layer = self.query(hidden_states)
|
| 307 |
+
|
| 308 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 309 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 310 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 311 |
+
|
| 312 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 313 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 314 |
+
|
| 315 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 316 |
+
|
| 317 |
+
# Add relative position bias if present.
|
| 318 |
+
if self.has_relative_position_bias:
|
| 319 |
+
height, width = resolution
|
| 320 |
+
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
| 321 |
+
attention_scores = attention_scores + self.relative_position_bias(
|
| 322 |
+
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Add shared relative position bias if provided.
|
| 326 |
+
if relative_position_bias is not None:
|
| 327 |
+
attention_scores = attention_scores + relative_position_bias
|
| 328 |
+
|
| 329 |
+
# Normalize the attention scores to probabilities.
|
| 330 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 331 |
+
|
| 332 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 333 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 334 |
+
attention_probs = self.dropout(attention_probs)
|
| 335 |
+
|
| 336 |
+
# Mask heads if we want to
|
| 337 |
+
if head_mask is not None:
|
| 338 |
+
attention_probs = attention_probs * head_mask
|
| 339 |
+
|
| 340 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 341 |
+
|
| 342 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 343 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 344 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 345 |
+
|
| 346 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 347 |
+
|
| 348 |
+
return outputs
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision
|
| 352 |
+
class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
|
| 353 |
+
def forward(
|
| 354 |
+
self,
|
| 355 |
+
hidden_states: torch.Tensor,
|
| 356 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 357 |
+
output_attentions: bool = False,
|
| 358 |
+
relative_position_bias: Optional[torch.Tensor] = None,
|
| 359 |
+
interpolate_pos_encoding: bool = False,
|
| 360 |
+
resolution: Optional[Tuple[int]] = None,
|
| 361 |
+
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
| 362 |
+
if output_attentions or head_mask is not None:
|
| 363 |
+
logger.warning_once(
|
| 364 |
+
"`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
|
| 365 |
+
"support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
|
| 366 |
+
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
| 367 |
+
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 368 |
+
)
|
| 369 |
+
return super().forward(
|
| 370 |
+
hidden_states=hidden_states,
|
| 371 |
+
head_mask=head_mask,
|
| 372 |
+
output_attentions=output_attentions,
|
| 373 |
+
relative_position_bias=relative_position_bias,
|
| 374 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 375 |
+
resolution=resolution,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
mixed_query_layer = self.query(hidden_states)
|
| 379 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 380 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 381 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 382 |
+
|
| 383 |
+
attn_bias = None
|
| 384 |
+
if self.has_relative_position_bias:
|
| 385 |
+
height, width = resolution
|
| 386 |
+
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
| 387 |
+
attn_bias = self.relative_position_bias(
|
| 388 |
+
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# Add shared relative position bias if provided.
|
| 392 |
+
if relative_position_bias is not None:
|
| 393 |
+
if attn_bias is None:
|
| 394 |
+
attn_bias = relative_position_bias
|
| 395 |
+
else:
|
| 396 |
+
attn_bias += relative_position_bias
|
| 397 |
+
|
| 398 |
+
scaling = 1 / math.sqrt(self.attention_head_size)
|
| 399 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
| 400 |
+
query_layer,
|
| 401 |
+
key_layer,
|
| 402 |
+
value_layer,
|
| 403 |
+
attn_mask=attn_bias,
|
| 404 |
+
dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
|
| 405 |
+
is_causal=False,
|
| 406 |
+
scale=scaling,
|
| 407 |
+
)
|
| 408 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 409 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 410 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 411 |
+
return context_layer, None
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
|
| 415 |
+
class Data2VecVisionSelfOutput(nn.Module):
|
| 416 |
+
"""
|
| 417 |
+
The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the
|
| 418 |
+
layernorm applied before each block.
|
| 419 |
+
"""
|
| 420 |
+
|
| 421 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 422 |
+
super().__init__()
|
| 423 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 424 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 425 |
+
|
| 426 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
|
| 427 |
+
hidden_states = self.dense(hidden_states)
|
| 428 |
+
hidden_states = self.dropout(hidden_states)
|
| 429 |
+
|
| 430 |
+
return hidden_states
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
|
| 434 |
+
"eager": Data2VecVisionSelfAttention,
|
| 435 |
+
"sdpa": Data2VecVisionSdpaSelfAttention,
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# Copied from tests.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION
|
| 440 |
+
class Data2VecVisionAttention(nn.Module):
|
| 441 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
|
| 442 |
+
super().__init__()
|
| 443 |
+
self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
| 444 |
+
config, window_size=window_size
|
| 445 |
+
)
|
| 446 |
+
self.output = Data2VecVisionSelfOutput(config)
|
| 447 |
+
self.pruned_heads = set()
|
| 448 |
+
|
| 449 |
+
def prune_heads(self, heads):
|
| 450 |
+
if len(heads) == 0:
|
| 451 |
+
return
|
| 452 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 453 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
# Prune linear layers
|
| 457 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 458 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 459 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 460 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 461 |
+
|
| 462 |
+
# Update hyper params and store pruned heads
|
| 463 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
| 464 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
| 465 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 466 |
+
|
| 467 |
+
def forward(
|
| 468 |
+
self,
|
| 469 |
+
hidden_states: torch.Tensor,
|
| 470 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 471 |
+
output_attentions: bool = False,
|
| 472 |
+
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
| 473 |
+
interpolate_pos_encoding: bool = False,
|
| 474 |
+
resolution: Optional[Tuple[int]] = None,
|
| 475 |
+
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
| 476 |
+
self_outputs = self.attention(
|
| 477 |
+
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 481 |
+
|
| 482 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 483 |
+
return outputs
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision
|
| 487 |
+
class Data2VecVisionIntermediate(nn.Module):
|
| 488 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 489 |
+
super().__init__()
|
| 490 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 491 |
+
if isinstance(config.hidden_act, str):
|
| 492 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 493 |
+
else:
|
| 494 |
+
self.intermediate_act_fn = config.hidden_act
|
| 495 |
+
|
| 496 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 497 |
+
hidden_states = self.dense(hidden_states)
|
| 498 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 499 |
+
|
| 500 |
+
return hidden_states
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision
|
| 504 |
+
class Data2VecVisionOutput(nn.Module):
|
| 505 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 506 |
+
super().__init__()
|
| 507 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 508 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 509 |
+
|
| 510 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 511 |
+
hidden_states = self.dense(hidden_states)
|
| 512 |
+
hidden_states = self.dropout(hidden_states)
|
| 513 |
+
|
| 514 |
+
return hidden_states
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
|
| 518 |
+
class Data2VecVisionLayer(nn.Module):
|
| 519 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 520 |
+
|
| 521 |
+
def __init__(
|
| 522 |
+
self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0
|
| 523 |
+
) -> None:
|
| 524 |
+
super().__init__()
|
| 525 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 526 |
+
self.seq_len_dim = 1
|
| 527 |
+
self.attention = Data2VecVisionAttention(config, window_size=window_size)
|
| 528 |
+
self.intermediate = Data2VecVisionIntermediate(config)
|
| 529 |
+
self.output = Data2VecVisionOutput(config)
|
| 530 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 531 |
+
self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
| 532 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 533 |
+
|
| 534 |
+
init_values = config.layer_scale_init_value
|
| 535 |
+
if init_values > 0:
|
| 536 |
+
self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
| 537 |
+
self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
| 538 |
+
else:
|
| 539 |
+
self.lambda_1, self.lambda_2 = None, None
|
| 540 |
+
|
| 541 |
+
def forward(
|
| 542 |
+
self,
|
| 543 |
+
hidden_states: torch.Tensor,
|
| 544 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 545 |
+
output_attentions: bool = False,
|
| 546 |
+
relative_position_bias: Optional[torch.Tensor] = None,
|
| 547 |
+
interpolate_pos_encoding: bool = False,
|
| 548 |
+
resolution: Optional[Tuple[int]] = None,
|
| 549 |
+
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
| 550 |
+
self_attention_outputs = self.attention(
|
| 551 |
+
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
|
| 552 |
+
head_mask,
|
| 553 |
+
output_attentions=output_attentions,
|
| 554 |
+
relative_position_bias=relative_position_bias,
|
| 555 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 556 |
+
resolution=resolution,
|
| 557 |
+
)
|
| 558 |
+
attention_output = self_attention_outputs[0]
|
| 559 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 560 |
+
|
| 561 |
+
# apply lambda_1 if present
|
| 562 |
+
if self.lambda_1 is not None:
|
| 563 |
+
attention_output = self.lambda_1 * attention_output
|
| 564 |
+
|
| 565 |
+
# first residual connection
|
| 566 |
+
hidden_states = self.drop_path(attention_output) + hidden_states
|
| 567 |
+
|
| 568 |
+
# in Data2VecVision, layernorm is also applied after self-attention
|
| 569 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 570 |
+
|
| 571 |
+
layer_output = self.intermediate(layer_output)
|
| 572 |
+
layer_output = self.output(layer_output)
|
| 573 |
+
|
| 574 |
+
if self.lambda_2 is not None:
|
| 575 |
+
layer_output = self.lambda_2 * layer_output
|
| 576 |
+
|
| 577 |
+
# second residual connection
|
| 578 |
+
layer_output = self.drop_path(layer_output) + hidden_states
|
| 579 |
+
|
| 580 |
+
outputs = (layer_output,) + outputs
|
| 581 |
+
|
| 582 |
+
return outputs
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision
|
| 586 |
+
class Data2VecVisionRelativePositionBias(nn.Module):
|
| 587 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None:
|
| 588 |
+
super().__init__()
|
| 589 |
+
self.window_size = window_size
|
| 590 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 591 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 592 |
+
torch.zeros(self.num_relative_distance, config.num_attention_heads)
|
| 593 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
| 594 |
+
# cls to token & token 2 cls & cls to cls
|
| 595 |
+
|
| 596 |
+
@compile_compatible_method_lru_cache(maxsize=10)
|
| 597 |
+
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
|
| 598 |
+
"""
|
| 599 |
+
This method creates the relative position index, modified to support arbitrary window sizes,
|
| 600 |
+
as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
|
| 601 |
+
"""
|
| 602 |
+
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 603 |
+
# cls to token & token 2 cls & cls to cls
|
| 604 |
+
# get pair-wise relative position index for each token inside the window
|
| 605 |
+
window_area = window_size[0] * window_size[1]
|
| 606 |
+
grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
|
| 607 |
+
coords = torch.stack(grid) # 2, Wh, Ww
|
| 608 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 609 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 610 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 611 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
| 612 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
| 613 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
| 614 |
+
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
|
| 615 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 616 |
+
relative_position_index[0, 0:] = num_relative_distance - 3
|
| 617 |
+
relative_position_index[0:, 0] = num_relative_distance - 2
|
| 618 |
+
relative_position_index[0, 0] = num_relative_distance - 1
|
| 619 |
+
return relative_position_index
|
| 620 |
+
|
| 621 |
+
def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
|
| 622 |
+
"""
|
| 623 |
+
Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
|
| 624 |
+
"""
|
| 625 |
+
old_height = 2 * self.window_size[0] - 1
|
| 626 |
+
old_width = 2 * self.window_size[1] - 1
|
| 627 |
+
|
| 628 |
+
new_height = 2 * window_size[0] - 1
|
| 629 |
+
new_width = 2 * window_size[1] - 1
|
| 630 |
+
|
| 631 |
+
old_relative_position_bias_table = self.relative_position_bias_table
|
| 632 |
+
|
| 633 |
+
old_num_relative_distance = self.num_relative_distance
|
| 634 |
+
new_num_relative_distance = new_height * new_width + 3
|
| 635 |
+
|
| 636 |
+
old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
|
| 637 |
+
|
| 638 |
+
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
|
| 639 |
+
new_sub_table = nn.functional.interpolate(
|
| 640 |
+
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
|
| 641 |
+
)
|
| 642 |
+
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
|
| 643 |
+
|
| 644 |
+
new_relative_position_bias_table = torch.cat(
|
| 645 |
+
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
relative_position_index = self.generate_relative_position_index(window_size)
|
| 649 |
+
relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
|
| 650 |
+
|
| 651 |
+
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
|
| 652 |
+
relative_position_bias = relative_position_bias.view(
|
| 653 |
+
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
|
| 654 |
+
)
|
| 655 |
+
# num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
|
| 656 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
| 657 |
+
|
| 658 |
+
if interpolate_pos_encoding:
|
| 659 |
+
relative_position_bias = nn.functional.interpolate(
|
| 660 |
+
relative_position_bias.unsqueeze(1),
|
| 661 |
+
size=(dim_size, dim_size),
|
| 662 |
+
mode="bilinear",
|
| 663 |
+
align_corners=False,
|
| 664 |
+
).squeeze(1)
|
| 665 |
+
|
| 666 |
+
return relative_position_bias.unsqueeze(0)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
|
| 670 |
+
class Data2VecVisionEncoder(nn.Module):
|
| 671 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
|
| 672 |
+
super().__init__()
|
| 673 |
+
self.config = config
|
| 674 |
+
self.has_relative_position_bias = config.use_shared_relative_position_bias
|
| 675 |
+
if self.has_relative_position_bias:
|
| 676 |
+
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
|
| 677 |
+
|
| 678 |
+
# stochastic depth decay rule
|
| 679 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
|
| 680 |
+
self.layer = nn.ModuleList(
|
| 681 |
+
[
|
| 682 |
+
Data2VecVisionLayer(
|
| 683 |
+
config,
|
| 684 |
+
window_size=window_size if config.use_relative_position_bias else None,
|
| 685 |
+
drop_path_rate=dpr[i],
|
| 686 |
+
)
|
| 687 |
+
for i in range(config.num_hidden_layers)
|
| 688 |
+
]
|
| 689 |
+
)
|
| 690 |
+
self.gradient_checkpointing = False
|
| 691 |
+
|
| 692 |
+
def forward(
|
| 693 |
+
self,
|
| 694 |
+
hidden_states: torch.Tensor,
|
| 695 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 696 |
+
output_attentions: bool = False,
|
| 697 |
+
output_hidden_states: bool = False,
|
| 698 |
+
interpolate_pos_encoding: bool = False,
|
| 699 |
+
resolution: Optional[Tuple[int, int]] = None,
|
| 700 |
+
return_dict: bool = True,
|
| 701 |
+
) -> Union[tuple, BaseModelOutput]:
|
| 702 |
+
all_hidden_states = () if output_hidden_states else None
|
| 703 |
+
all_self_attentions = () if output_attentions else None
|
| 704 |
+
|
| 705 |
+
for i, layer_module in enumerate(self.layer):
|
| 706 |
+
if output_hidden_states:
|
| 707 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 708 |
+
|
| 709 |
+
if self.has_relative_position_bias:
|
| 710 |
+
height, width = resolution
|
| 711 |
+
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
| 712 |
+
relative_position_bias = self.relative_position_bias(
|
| 713 |
+
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
| 714 |
+
)
|
| 715 |
+
else:
|
| 716 |
+
relative_position_bias = None
|
| 717 |
+
|
| 718 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 719 |
+
|
| 720 |
+
if self.gradient_checkpointing and self.training:
|
| 721 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 722 |
+
layer_module.__call__,
|
| 723 |
+
hidden_states,
|
| 724 |
+
layer_head_mask,
|
| 725 |
+
output_attentions,
|
| 726 |
+
relative_position_bias,
|
| 727 |
+
interpolate_pos_encoding,
|
| 728 |
+
resolution,
|
| 729 |
+
)
|
| 730 |
+
else:
|
| 731 |
+
layer_outputs = layer_module(
|
| 732 |
+
hidden_states,
|
| 733 |
+
layer_head_mask,
|
| 734 |
+
output_attentions,
|
| 735 |
+
relative_position_bias,
|
| 736 |
+
interpolate_pos_encoding,
|
| 737 |
+
resolution,
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
hidden_states = layer_outputs[0]
|
| 741 |
+
|
| 742 |
+
if output_attentions:
|
| 743 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 744 |
+
|
| 745 |
+
if output_hidden_states:
|
| 746 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 747 |
+
|
| 748 |
+
if not return_dict:
|
| 749 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 750 |
+
return BaseModelOutput(
|
| 751 |
+
last_hidden_state=hidden_states,
|
| 752 |
+
hidden_states=all_hidden_states,
|
| 753 |
+
attentions=all_self_attentions,
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision
|
| 758 |
+
class Data2VecVisionPreTrainedModel(PreTrainedModel):
|
| 759 |
+
"""
|
| 760 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 761 |
+
models.
|
| 762 |
+
"""
|
| 763 |
+
|
| 764 |
+
config_class = Data2VecVisionConfig
|
| 765 |
+
base_model_prefix = "data2vec_vision"
|
| 766 |
+
main_input_name = "pixel_values"
|
| 767 |
+
supports_gradient_checkpointing = True
|
| 768 |
+
_no_split_modules = ["Data2VecVisionLayer"]
|
| 769 |
+
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
|
| 770 |
+
_supports_sdpa = True
|
| 771 |
+
|
| 772 |
+
def _init_weights(self, module):
|
| 773 |
+
"""Initialize the weights"""
|
| 774 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
| 775 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 776 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 777 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 778 |
+
if module.bias is not None:
|
| 779 |
+
module.bias.data.zero_()
|
| 780 |
+
elif isinstance(module, nn.Embedding):
|
| 781 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 782 |
+
if module.padding_idx is not None:
|
| 783 |
+
module.weight.data[module.padding_idx].zero_()
|
| 784 |
+
elif isinstance(module, nn.LayerNorm):
|
| 785 |
+
module.bias.data.zero_()
|
| 786 |
+
module.weight.data.fill_(1.0)
|
| 787 |
+
elif isinstance(module, Data2VecVisionEmbeddings):
|
| 788 |
+
module.cls_token.data.zero_()
|
| 789 |
+
if module.mask_token is not None:
|
| 790 |
+
module.mask_token.data.zero_()
|
| 791 |
+
if module.position_embeddings is not None:
|
| 792 |
+
module.position_embeddings.data.zero_()
|
| 793 |
+
elif isinstance(module, Data2VecVisionRelativePositionBias):
|
| 794 |
+
module.relative_position_bias_table.data.zero_()
|
| 795 |
+
elif isinstance(module, Data2VecVisionLayer):
|
| 796 |
+
if module.lambda_1 is not None:
|
| 797 |
+
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
|
| 798 |
+
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
DATA2VEC_VISION_START_DOCSTRING = r"""
|
| 802 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 803 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 804 |
+
behavior.
|
| 805 |
+
|
| 806 |
+
Parameters:
|
| 807 |
+
config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
|
| 808 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 809 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 810 |
+
"""
|
| 811 |
+
|
| 812 |
+
DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
|
| 813 |
+
Args:
|
| 814 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 815 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 816 |
+
[`BeitImageProcessor.__call__`] for details.
|
| 817 |
+
|
| 818 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 819 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 820 |
+
|
| 821 |
+
- 1 indicates the head is **not masked**,
|
| 822 |
+
- 0 indicates the head is **masked**.
|
| 823 |
+
|
| 824 |
+
output_attentions (`bool`, *optional*):
|
| 825 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 826 |
+
tensors for more detail.
|
| 827 |
+
output_hidden_states (`bool`, *optional*):
|
| 828 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 829 |
+
more detail.
|
| 830 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 831 |
+
Whether to interpolate the pre-trained position encodings.
|
| 832 |
+
return_dict (`bool`, *optional*):
|
| 833 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 834 |
+
"""
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
@add_start_docstrings(
|
| 838 |
+
"The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
|
| 839 |
+
DATA2VEC_VISION_START_DOCSTRING,
|
| 840 |
+
)
|
| 841 |
+
# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False
|
| 842 |
+
class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
| 843 |
+
def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None:
|
| 844 |
+
super().__init__(config)
|
| 845 |
+
self.config = config
|
| 846 |
+
|
| 847 |
+
self.embeddings = Data2VecVisionEmbeddings(config)
|
| 848 |
+
self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
|
| 849 |
+
|
| 850 |
+
self.layernorm = (
|
| 851 |
+
nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 852 |
+
)
|
| 853 |
+
self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None
|
| 854 |
+
|
| 855 |
+
# Initialize weights and apply final processing
|
| 856 |
+
self.post_init()
|
| 857 |
+
|
| 858 |
+
def get_input_embeddings(self):
|
| 859 |
+
return self.embeddings.patch_embeddings
|
| 860 |
+
|
| 861 |
+
def _prune_heads(self, heads_to_prune):
|
| 862 |
+
"""
|
| 863 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 864 |
+
class PreTrainedModel
|
| 865 |
+
"""
|
| 866 |
+
for layer, heads in heads_to_prune.items():
|
| 867 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 868 |
+
|
| 869 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
|
| 870 |
+
@add_code_sample_docstrings(
|
| 871 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 872 |
+
output_type=Data2VecVisionModelOutputWithPooling,
|
| 873 |
+
config_class=_CONFIG_FOR_DOC,
|
| 874 |
+
modality="vision",
|
| 875 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 876 |
+
)
|
| 877 |
+
def forward(
|
| 878 |
+
self,
|
| 879 |
+
pixel_values: torch.Tensor,
|
| 880 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 881 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 882 |
+
output_attentions: Optional[bool] = None,
|
| 883 |
+
output_hidden_states: Optional[bool] = None,
|
| 884 |
+
interpolate_pos_encoding: bool = False,
|
| 885 |
+
return_dict: Optional[bool] = None,
|
| 886 |
+
) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
|
| 887 |
+
r"""
|
| 888 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
| 889 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 890 |
+
"""
|
| 891 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 892 |
+
output_hidden_states = (
|
| 893 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 894 |
+
)
|
| 895 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 896 |
+
|
| 897 |
+
# Prepare head mask if needed
|
| 898 |
+
# 1.0 in head_mask indicate we keep the head
|
| 899 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 900 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 901 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 902 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 903 |
+
|
| 904 |
+
embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
| 905 |
+
resolution = pixel_values.shape[2:]
|
| 906 |
+
|
| 907 |
+
encoder_outputs = self.encoder(
|
| 908 |
+
embedding_output,
|
| 909 |
+
head_mask=head_mask,
|
| 910 |
+
output_attentions=output_attentions,
|
| 911 |
+
output_hidden_states=output_hidden_states,
|
| 912 |
+
resolution=resolution,
|
| 913 |
+
return_dict=return_dict,
|
| 914 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 915 |
+
)
|
| 916 |
+
sequence_output = encoder_outputs[0]
|
| 917 |
+
sequence_output = self.layernorm(sequence_output)
|
| 918 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 919 |
+
|
| 920 |
+
if not return_dict:
|
| 921 |
+
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
| 922 |
+
return head_outputs + encoder_outputs[1:]
|
| 923 |
+
|
| 924 |
+
return Data2VecVisionModelOutputWithPooling(
|
| 925 |
+
last_hidden_state=sequence_output,
|
| 926 |
+
pooler_output=pooled_output,
|
| 927 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 928 |
+
attentions=encoder_outputs.attentions,
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision
|
| 933 |
+
class Data2VecVisionPooler(nn.Module):
|
| 934 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 935 |
+
super().__init__()
|
| 936 |
+
self.layernorm = (
|
| 937 |
+
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 941 |
+
if self.layernorm is not None:
|
| 942 |
+
# Mean pool the final hidden states of the patch tokens
|
| 943 |
+
patch_tokens = hidden_states[:, 1:, :]
|
| 944 |
+
pooled_output = self.layernorm(patch_tokens.mean(1))
|
| 945 |
+
else:
|
| 946 |
+
# Pool by simply taking the final hidden state of the [CLS] token
|
| 947 |
+
pooled_output = hidden_states[:, 0]
|
| 948 |
+
|
| 949 |
+
return pooled_output
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
@add_start_docstrings(
|
| 953 |
+
"""
|
| 954 |
+
Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
|
| 955 |
+
the final hidden states of the patch tokens) e.g. for ImageNet.
|
| 956 |
+
""",
|
| 957 |
+
DATA2VEC_VISION_START_DOCSTRING,
|
| 958 |
+
)
|
| 959 |
+
# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision
|
| 960 |
+
class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
|
| 961 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 962 |
+
super().__init__(config)
|
| 963 |
+
|
| 964 |
+
self.num_labels = config.num_labels
|
| 965 |
+
self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True)
|
| 966 |
+
|
| 967 |
+
# Classifier head
|
| 968 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
| 969 |
+
|
| 970 |
+
# Initialize weights and apply final processing
|
| 971 |
+
self.post_init()
|
| 972 |
+
|
| 973 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
|
| 974 |
+
@add_code_sample_docstrings(
|
| 975 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 976 |
+
output_type=ImageClassifierOutput,
|
| 977 |
+
config_class=_CONFIG_FOR_DOC,
|
| 978 |
+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
| 979 |
+
)
|
| 980 |
+
def forward(
|
| 981 |
+
self,
|
| 982 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 983 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 984 |
+
labels: Optional[torch.Tensor] = None,
|
| 985 |
+
output_attentions: Optional[bool] = None,
|
| 986 |
+
output_hidden_states: Optional[bool] = None,
|
| 987 |
+
interpolate_pos_encoding: bool = False,
|
| 988 |
+
return_dict: Optional[bool] = None,
|
| 989 |
+
) -> Union[tuple, ImageClassifierOutput]:
|
| 990 |
+
r"""
|
| 991 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 992 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 993 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 994 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 995 |
+
"""
|
| 996 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 997 |
+
outputs = self.data2vec_vision(
|
| 998 |
+
pixel_values,
|
| 999 |
+
head_mask=head_mask,
|
| 1000 |
+
output_attentions=output_attentions,
|
| 1001 |
+
output_hidden_states=output_hidden_states,
|
| 1002 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1003 |
+
return_dict=return_dict,
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
pooled_output = outputs.pooler_output if return_dict else outputs[1]
|
| 1007 |
+
|
| 1008 |
+
logits = self.classifier(pooled_output)
|
| 1009 |
+
|
| 1010 |
+
loss = None
|
| 1011 |
+
if labels is not None:
|
| 1012 |
+
if self.config.problem_type is None:
|
| 1013 |
+
if self.num_labels == 1:
|
| 1014 |
+
self.config.problem_type = "regression"
|
| 1015 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1016 |
+
self.config.problem_type = "single_label_classification"
|
| 1017 |
+
else:
|
| 1018 |
+
self.config.problem_type = "multi_label_classification"
|
| 1019 |
+
|
| 1020 |
+
if self.config.problem_type == "regression":
|
| 1021 |
+
loss_fct = MSELoss()
|
| 1022 |
+
if self.num_labels == 1:
|
| 1023 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 1024 |
+
else:
|
| 1025 |
+
loss = loss_fct(logits, labels)
|
| 1026 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1027 |
+
loss_fct = CrossEntropyLoss()
|
| 1028 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1029 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1030 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1031 |
+
loss = loss_fct(logits, labels)
|
| 1032 |
+
if not return_dict:
|
| 1033 |
+
output = (logits,) + outputs[2:]
|
| 1034 |
+
return ((loss,) + output) if loss is not None else output
|
| 1035 |
+
|
| 1036 |
+
return ImageClassifierOutput(
|
| 1037 |
+
loss=loss,
|
| 1038 |
+
logits=logits,
|
| 1039 |
+
hidden_states=outputs.hidden_states,
|
| 1040 |
+
attentions=outputs.attentions,
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
|
| 1044 |
+
# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision
|
| 1045 |
+
class Data2VecVisionConvModule(nn.Module):
|
| 1046 |
+
"""
|
| 1047 |
+
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
|
| 1048 |
+
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
| 1049 |
+
|
| 1050 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1051 |
+
"""
|
| 1052 |
+
|
| 1053 |
+
def __init__(
|
| 1054 |
+
self,
|
| 1055 |
+
in_channels: int,
|
| 1056 |
+
out_channels: int,
|
| 1057 |
+
kernel_size: Union[int, Tuple[int, int]],
|
| 1058 |
+
padding: Union[int, Tuple[int, int], str] = 0,
|
| 1059 |
+
bias: bool = False,
|
| 1060 |
+
dilation: Union[int, Tuple[int, int]] = 1,
|
| 1061 |
+
) -> None:
|
| 1062 |
+
super().__init__()
|
| 1063 |
+
self.conv = nn.Conv2d(
|
| 1064 |
+
in_channels=in_channels,
|
| 1065 |
+
out_channels=out_channels,
|
| 1066 |
+
kernel_size=kernel_size,
|
| 1067 |
+
padding=padding,
|
| 1068 |
+
bias=bias,
|
| 1069 |
+
dilation=dilation,
|
| 1070 |
+
)
|
| 1071 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 1072 |
+
self.activation = nn.ReLU()
|
| 1073 |
+
|
| 1074 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 1075 |
+
output = self.conv(input)
|
| 1076 |
+
output = self.bn(output)
|
| 1077 |
+
output = self.activation(output)
|
| 1078 |
+
|
| 1079 |
+
return output
|
| 1080 |
+
|
| 1081 |
+
|
| 1082 |
+
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
|
| 1083 |
+
class Data2VecVisionPyramidPoolingBlock(nn.Module):
|
| 1084 |
+
def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
|
| 1085 |
+
super().__init__()
|
| 1086 |
+
self.layers = [
|
| 1087 |
+
nn.AdaptiveAvgPool2d(pool_scale),
|
| 1088 |
+
Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
|
| 1089 |
+
]
|
| 1090 |
+
for i, layer in enumerate(self.layers):
|
| 1091 |
+
self.add_module(str(i), layer)
|
| 1092 |
+
|
| 1093 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 1094 |
+
hidden_state = input
|
| 1095 |
+
for layer in self.layers:
|
| 1096 |
+
hidden_state = layer(hidden_state)
|
| 1097 |
+
return hidden_state
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
|
| 1101 |
+
class Data2VecVisionPyramidPoolingModule(nn.Module):
|
| 1102 |
+
"""
|
| 1103 |
+
Pyramid Pooling Module (PPM) used in PSPNet.
|
| 1104 |
+
|
| 1105 |
+
Args:
|
| 1106 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
| 1107 |
+
Module.
|
| 1108 |
+
in_channels (int): Input channels.
|
| 1109 |
+
channels (int): Channels after modules, before conv_seg.
|
| 1110 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 1111 |
+
|
| 1112 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1113 |
+
"""
|
| 1114 |
+
|
| 1115 |
+
def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
|
| 1116 |
+
super().__init__()
|
| 1117 |
+
self.pool_scales = pool_scales
|
| 1118 |
+
self.align_corners = align_corners
|
| 1119 |
+
self.in_channels = in_channels
|
| 1120 |
+
self.channels = channels
|
| 1121 |
+
self.blocks = []
|
| 1122 |
+
for i, pool_scale in enumerate(pool_scales):
|
| 1123 |
+
block = Data2VecVisionPyramidPoolingBlock(
|
| 1124 |
+
pool_scale=pool_scale, in_channels=in_channels, channels=channels
|
| 1125 |
+
)
|
| 1126 |
+
self.blocks.append(block)
|
| 1127 |
+
self.add_module(str(i), block)
|
| 1128 |
+
|
| 1129 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 1130 |
+
ppm_outs = []
|
| 1131 |
+
for ppm in self.blocks:
|
| 1132 |
+
ppm_out = ppm(x)
|
| 1133 |
+
upsampled_ppm_out = nn.functional.interpolate(
|
| 1134 |
+
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
|
| 1135 |
+
)
|
| 1136 |
+
ppm_outs.append(upsampled_ppm_out)
|
| 1137 |
+
return ppm_outs
|
| 1138 |
+
|
| 1139 |
+
|
| 1140 |
+
# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision
|
| 1141 |
+
class Data2VecVisionUperHead(nn.Module):
|
| 1142 |
+
"""
|
| 1143 |
+
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
|
| 1144 |
+
[UPerNet](https://arxiv.org/abs/1807.10221).
|
| 1145 |
+
|
| 1146 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1147 |
+
"""
|
| 1148 |
+
|
| 1149 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 1150 |
+
super().__init__()
|
| 1151 |
+
|
| 1152 |
+
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
|
| 1153 |
+
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
|
| 1154 |
+
self.channels = config.hidden_size
|
| 1155 |
+
self.align_corners = False
|
| 1156 |
+
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
|
| 1157 |
+
|
| 1158 |
+
# PSP Module
|
| 1159 |
+
self.psp_modules = Data2VecVisionPyramidPoolingModule(
|
| 1160 |
+
self.pool_scales,
|
| 1161 |
+
self.in_channels[-1],
|
| 1162 |
+
self.channels,
|
| 1163 |
+
align_corners=self.align_corners,
|
| 1164 |
+
)
|
| 1165 |
+
self.bottleneck = Data2VecVisionConvModule(
|
| 1166 |
+
self.in_channels[-1] + len(self.pool_scales) * self.channels,
|
| 1167 |
+
self.channels,
|
| 1168 |
+
kernel_size=3,
|
| 1169 |
+
padding=1,
|
| 1170 |
+
)
|
| 1171 |
+
# FPN Module
|
| 1172 |
+
self.lateral_convs = nn.ModuleList()
|
| 1173 |
+
self.fpn_convs = nn.ModuleList()
|
| 1174 |
+
for in_channels in self.in_channels[:-1]: # skip the top layer
|
| 1175 |
+
l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1)
|
| 1176 |
+
fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1)
|
| 1177 |
+
self.lateral_convs.append(l_conv)
|
| 1178 |
+
self.fpn_convs.append(fpn_conv)
|
| 1179 |
+
|
| 1180 |
+
self.fpn_bottleneck = Data2VecVisionConvModule(
|
| 1181 |
+
len(self.in_channels) * self.channels,
|
| 1182 |
+
self.channels,
|
| 1183 |
+
kernel_size=3,
|
| 1184 |
+
padding=1,
|
| 1185 |
+
)
|
| 1186 |
+
|
| 1187 |
+
def psp_forward(self, inputs):
|
| 1188 |
+
x = inputs[-1]
|
| 1189 |
+
psp_outs = [x]
|
| 1190 |
+
psp_outs.extend(self.psp_modules(x))
|
| 1191 |
+
psp_outs = torch.cat(psp_outs, dim=1)
|
| 1192 |
+
output = self.bottleneck(psp_outs)
|
| 1193 |
+
|
| 1194 |
+
return output
|
| 1195 |
+
|
| 1196 |
+
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1197 |
+
# build laterals
|
| 1198 |
+
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
|
| 1199 |
+
|
| 1200 |
+
laterals.append(self.psp_forward(encoder_hidden_states))
|
| 1201 |
+
|
| 1202 |
+
# build top-down path
|
| 1203 |
+
used_backbone_levels = len(laterals)
|
| 1204 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
| 1205 |
+
prev_shape = laterals[i - 1].shape[2:]
|
| 1206 |
+
laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
|
| 1207 |
+
laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
|
| 1208 |
+
)
|
| 1209 |
+
|
| 1210 |
+
# build outputs
|
| 1211 |
+
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
|
| 1212 |
+
# append psp feature
|
| 1213 |
+
fpn_outs.append(laterals[-1])
|
| 1214 |
+
|
| 1215 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
| 1216 |
+
fpn_outs[i] = nn.functional.interpolate(
|
| 1217 |
+
fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
|
| 1218 |
+
)
|
| 1219 |
+
fpn_outs = torch.cat(fpn_outs, dim=1)
|
| 1220 |
+
output = self.fpn_bottleneck(fpn_outs)
|
| 1221 |
+
output = self.classifier(output)
|
| 1222 |
+
|
| 1223 |
+
return output
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
# Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision
|
| 1227 |
+
class Data2VecVisionFCNHead(nn.Module):
|
| 1228 |
+
"""
|
| 1229 |
+
Fully Convolution Networks for Semantic Segmentation. This head is implemented of
|
| 1230 |
+
[FCNNet](https://arxiv.org/abs/1411.4038>).
|
| 1231 |
+
|
| 1232 |
+
Args:
|
| 1233 |
+
config (Data2VecVisionConfig): Configuration.
|
| 1234 |
+
in_channels
|
| 1235 |
+
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
| 1236 |
+
dilation (int): The dilation rate for convs in the head. Default: 1.
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1240 |
+
"""
|
| 1241 |
+
|
| 1242 |
+
def __init__(
|
| 1243 |
+
self,
|
| 1244 |
+
config: Data2VecVisionConfig,
|
| 1245 |
+
in_index: int = 2,
|
| 1246 |
+
kernel_size: int = 3,
|
| 1247 |
+
dilation: Union[int, Tuple[int, int]] = 1,
|
| 1248 |
+
) -> None:
|
| 1249 |
+
super().__init__()
|
| 1250 |
+
self.in_channels = config.hidden_size
|
| 1251 |
+
self.channels = config.auxiliary_channels
|
| 1252 |
+
self.num_convs = config.auxiliary_num_convs
|
| 1253 |
+
self.concat_input = config.auxiliary_concat_input
|
| 1254 |
+
self.in_index = in_index
|
| 1255 |
+
|
| 1256 |
+
conv_padding = (kernel_size // 2) * dilation
|
| 1257 |
+
convs = []
|
| 1258 |
+
convs.append(
|
| 1259 |
+
Data2VecVisionConvModule(
|
| 1260 |
+
self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
|
| 1261 |
+
)
|
| 1262 |
+
)
|
| 1263 |
+
for i in range(self.num_convs - 1):
|
| 1264 |
+
convs.append(
|
| 1265 |
+
Data2VecVisionConvModule(
|
| 1266 |
+
self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
|
| 1267 |
+
)
|
| 1268 |
+
)
|
| 1269 |
+
if self.num_convs == 0:
|
| 1270 |
+
self.convs = nn.Identity()
|
| 1271 |
+
else:
|
| 1272 |
+
self.convs = nn.Sequential(*convs)
|
| 1273 |
+
if self.concat_input:
|
| 1274 |
+
self.conv_cat = Data2VecVisionConvModule(
|
| 1275 |
+
self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
|
| 1276 |
+
)
|
| 1277 |
+
|
| 1278 |
+
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
|
| 1279 |
+
|
| 1280 |
+
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1281 |
+
# just take the relevant feature maps
|
| 1282 |
+
hidden_states = encoder_hidden_states[self.in_index]
|
| 1283 |
+
output = self.convs(hidden_states)
|
| 1284 |
+
if self.concat_input:
|
| 1285 |
+
output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
|
| 1286 |
+
output = self.classifier(output)
|
| 1287 |
+
return output
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
@add_start_docstrings(
|
| 1291 |
+
"""
|
| 1292 |
+
Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
|
| 1293 |
+
""",
|
| 1294 |
+
DATA2VEC_VISION_START_DOCSTRING,
|
| 1295 |
+
)
|
| 1296 |
+
# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision
|
| 1297 |
+
class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
|
| 1298 |
+
def __init__(self, config: Data2VecVisionConfig) -> None:
|
| 1299 |
+
super().__init__(config)
|
| 1300 |
+
|
| 1301 |
+
self.num_labels = config.num_labels
|
| 1302 |
+
self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)
|
| 1303 |
+
|
| 1304 |
+
# FPNs
|
| 1305 |
+
if len(self.config.out_indices) != 4:
|
| 1306 |
+
raise ValueError(
|
| 1307 |
+
"Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
|
| 1308 |
+
"specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
|
| 1309 |
+
"a base-sized architecture."
|
| 1310 |
+
)
|
| 1311 |
+
self.fpn1 = nn.Sequential(
|
| 1312 |
+
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
| 1313 |
+
nn.BatchNorm2d(config.hidden_size),
|
| 1314 |
+
nn.GELU(),
|
| 1315 |
+
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
| 1316 |
+
)
|
| 1317 |
+
self.fpn2 = nn.Sequential(
|
| 1318 |
+
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
| 1319 |
+
)
|
| 1320 |
+
self.fpn3 = nn.Identity()
|
| 1321 |
+
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 1322 |
+
|
| 1323 |
+
# Semantic segmentation head(s)
|
| 1324 |
+
self.decode_head = Data2VecVisionUperHead(config)
|
| 1325 |
+
self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None
|
| 1326 |
+
|
| 1327 |
+
# Initialize weights and apply final processing
|
| 1328 |
+
self.post_init()
|
| 1329 |
+
|
| 1330 |
+
def compute_loss(self, logits, auxiliary_logits, labels):
|
| 1331 |
+
# upsample logits to the images' original size
|
| 1332 |
+
upsampled_logits = nn.functional.interpolate(
|
| 1333 |
+
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
| 1334 |
+
)
|
| 1335 |
+
if auxiliary_logits is not None:
|
| 1336 |
+
upsampled_auxiliary_logits = nn.functional.interpolate(
|
| 1337 |
+
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
| 1338 |
+
)
|
| 1339 |
+
# compute weighted loss
|
| 1340 |
+
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
| 1341 |
+
main_loss = loss_fct(upsampled_logits, labels)
|
| 1342 |
+
loss = main_loss
|
| 1343 |
+
if auxiliary_logits is not None:
|
| 1344 |
+
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
|
| 1345 |
+
loss += self.config.auxiliary_loss_weight * auxiliary_loss
|
| 1346 |
+
|
| 1347 |
+
return loss
|
| 1348 |
+
|
| 1349 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
|
| 1350 |
+
@replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
|
| 1351 |
+
def forward(
|
| 1352 |
+
self,
|
| 1353 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1354 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1355 |
+
labels: Optional[torch.Tensor] = None,
|
| 1356 |
+
output_attentions: Optional[bool] = None,
|
| 1357 |
+
output_hidden_states: Optional[bool] = None,
|
| 1358 |
+
interpolate_pos_encoding: bool = False,
|
| 1359 |
+
return_dict: Optional[bool] = None,
|
| 1360 |
+
) -> Union[tuple, SemanticSegmenterOutput]:
|
| 1361 |
+
r"""
|
| 1362 |
+
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
| 1363 |
+
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
| 1364 |
+
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
| 1365 |
+
|
| 1366 |
+
Returns:
|
| 1367 |
+
|
| 1368 |
+
Examples:
|
| 1369 |
+
|
| 1370 |
+
```python
|
| 1371 |
+
>>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation
|
| 1372 |
+
>>> from PIL import Image
|
| 1373 |
+
>>> import requests
|
| 1374 |
+
|
| 1375 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1376 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1377 |
+
|
| 1378 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
|
| 1379 |
+
>>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
|
| 1380 |
+
|
| 1381 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 1382 |
+
>>> outputs = model(**inputs)
|
| 1383 |
+
>>> # logits are of shape (batch_size, num_labels, height, width)
|
| 1384 |
+
>>> logits = outputs.logits
|
| 1385 |
+
```"""
|
| 1386 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1387 |
+
output_hidden_states = (
|
| 1388 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1389 |
+
)
|
| 1390 |
+
|
| 1391 |
+
if labels is not None and self.config.num_labels == 1:
|
| 1392 |
+
raise ValueError("The number of labels should be greater than one")
|
| 1393 |
+
|
| 1394 |
+
outputs = self.data2vec_vision(
|
| 1395 |
+
pixel_values,
|
| 1396 |
+
head_mask=head_mask,
|
| 1397 |
+
output_attentions=output_attentions,
|
| 1398 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
| 1399 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1400 |
+
return_dict=return_dict,
|
| 1401 |
+
)
|
| 1402 |
+
|
| 1403 |
+
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
| 1404 |
+
|
| 1405 |
+
# only keep certain features, and reshape
|
| 1406 |
+
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
|
| 1407 |
+
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
|
| 1408 |
+
batch_size = pixel_values.shape[0]
|
| 1409 |
+
patch_resolution = self.config.image_size // self.config.patch_size
|
| 1410 |
+
features = [
|
| 1411 |
+
x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
|
| 1412 |
+
]
|
| 1413 |
+
|
| 1414 |
+
# apply FPNs
|
| 1415 |
+
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
|
| 1416 |
+
for i in range(len(features)):
|
| 1417 |
+
features[i] = ops[i](features[i])
|
| 1418 |
+
|
| 1419 |
+
logits = self.decode_head(features)
|
| 1420 |
+
|
| 1421 |
+
auxiliary_logits = None
|
| 1422 |
+
if self.auxiliary_head is not None:
|
| 1423 |
+
auxiliary_logits = self.auxiliary_head(features)
|
| 1424 |
+
|
| 1425 |
+
loss = None
|
| 1426 |
+
if labels is not None:
|
| 1427 |
+
loss = self.compute_loss(logits, auxiliary_logits, labels)
|
| 1428 |
+
|
| 1429 |
+
if not return_dict:
|
| 1430 |
+
if output_hidden_states:
|
| 1431 |
+
output = (logits,) + outputs[1:]
|
| 1432 |
+
else:
|
| 1433 |
+
output = (logits,) + outputs[2:]
|
| 1434 |
+
return ((loss,) + output) if loss is not None else output
|
| 1435 |
+
|
| 1436 |
+
return SemanticSegmenterOutput(
|
| 1437 |
+
loss=loss,
|
| 1438 |
+
logits=logits,
|
| 1439 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 1440 |
+
attentions=outputs.attentions,
|
| 1441 |
+
)
|
| 1442 |
+
|
| 1443 |
+
|
| 1444 |
+
__all__ = [
|
| 1445 |
+
"Data2VecVisionForImageClassification",
|
| 1446 |
+
"Data2VecVisionForSemanticSegmentation",
|
| 1447 |
+
"Data2VecVisionModel",
|
| 1448 |
+
"Data2VecVisionPreTrainedModel",
|
| 1449 |
+
]
|
docs/transformers/build/lib/transformers/models/data2vec/modeling_tf_data2vec_vision.py
ADDED
|
@@ -0,0 +1,1724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TF 2.0 Data2Vec Vision model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import collections.abc
|
| 20 |
+
import math
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import tensorflow as tf
|
| 26 |
+
|
| 27 |
+
from ...activations_tf import get_tf_activation
|
| 28 |
+
from ...modeling_tf_outputs import (
|
| 29 |
+
TFBaseModelOutput,
|
| 30 |
+
TFBaseModelOutputWithPooling,
|
| 31 |
+
TFSemanticSegmenterOutput,
|
| 32 |
+
TFSequenceClassifierOutput,
|
| 33 |
+
)
|
| 34 |
+
from ...modeling_tf_utils import (
|
| 35 |
+
TFModelInputType,
|
| 36 |
+
TFPreTrainedModel,
|
| 37 |
+
TFSequenceClassificationLoss,
|
| 38 |
+
get_initializer,
|
| 39 |
+
keras,
|
| 40 |
+
keras_serializable,
|
| 41 |
+
unpack_inputs,
|
| 42 |
+
)
|
| 43 |
+
from ...tf_utils import shape_list, stable_softmax
|
| 44 |
+
from ...utils import (
|
| 45 |
+
add_code_sample_docstrings,
|
| 46 |
+
add_start_docstrings,
|
| 47 |
+
add_start_docstrings_to_model_forward,
|
| 48 |
+
logging,
|
| 49 |
+
replace_return_docstrings,
|
| 50 |
+
)
|
| 51 |
+
from .configuration_data2vec_vision import Data2VecVisionConfig
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
# General docstring
|
| 57 |
+
_CONFIG_FOR_DOC = "Data2VecVisionConfig"
|
| 58 |
+
|
| 59 |
+
# Base docstring
|
| 60 |
+
_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
|
| 61 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
| 62 |
+
|
| 63 |
+
# Image classification docstring
|
| 64 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
|
| 65 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
|
| 70 |
+
"""
|
| 71 |
+
Class for outputs of [`TFData2VecVisionModel`].
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 75 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 76 |
+
pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
|
| 77 |
+
Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
|
| 78 |
+
*config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
|
| 79 |
+
will be returned.
|
| 80 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 81 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 82 |
+
`(batch_size, sequence_length, hidden_size)`.
|
| 83 |
+
|
| 84 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 85 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 86 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 87 |
+
sequence_length)`.
|
| 88 |
+
|
| 89 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 90 |
+
heads.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
last_hidden_state: Optional[tf.Tensor] = None
|
| 94 |
+
pooler_output: Optional[tf.Tensor] = None
|
| 95 |
+
hidden_states: Tuple[tf.Tensor] | None = None
|
| 96 |
+
attentions: Tuple[tf.Tensor] | None = None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class TFData2VecVisionDropPath(keras.layers.Layer):
|
| 100 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 101 |
+
References:
|
| 102 |
+
(1) github.com:rwightman/pytorch-image-models
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(self, drop_path, **kwargs):
|
| 106 |
+
super().__init__(**kwargs)
|
| 107 |
+
self.drop_path = drop_path
|
| 108 |
+
|
| 109 |
+
def call(self, x, training=None):
|
| 110 |
+
if training:
|
| 111 |
+
keep_prob = 1 - self.drop_path
|
| 112 |
+
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
|
| 113 |
+
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
|
| 114 |
+
random_tensor = tf.floor(random_tensor)
|
| 115 |
+
return (x / keep_prob) * random_tensor
|
| 116 |
+
return x
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TFData2VecVisionEmbeddings(keras.layers.Layer):
|
| 120 |
+
"""
|
| 121 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
| 122 |
+
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
| 126 |
+
super().__init__(**kwargs)
|
| 127 |
+
self.config = config
|
| 128 |
+
|
| 129 |
+
self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings")
|
| 130 |
+
self.num_patches = self.patch_embeddings.num_patches
|
| 131 |
+
self.config = config
|
| 132 |
+
|
| 133 |
+
self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
|
| 134 |
+
|
| 135 |
+
def build(self, input_shape=None):
|
| 136 |
+
self.cls_token = self.add_weight(
|
| 137 |
+
shape=(1, 1, self.config.hidden_size),
|
| 138 |
+
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
|
| 139 |
+
trainable=True,
|
| 140 |
+
name="cls_token",
|
| 141 |
+
)
|
| 142 |
+
if self.config.use_mask_token:
|
| 143 |
+
self.mask_token = self.add_weight(
|
| 144 |
+
shape=(1, 1, self.config.hidden_size),
|
| 145 |
+
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
|
| 146 |
+
trainable=True,
|
| 147 |
+
name="mask_token",
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
self.mask_token = None
|
| 151 |
+
|
| 152 |
+
if self.config.use_absolute_position_embeddings:
|
| 153 |
+
self.position_embeddings = self.add_weight(
|
| 154 |
+
shape=(1, self.num_patches + 1, self.config.hidden_size),
|
| 155 |
+
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
|
| 156 |
+
trainable=True,
|
| 157 |
+
name="position_embeddings",
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
self.position_embeddings = None
|
| 161 |
+
|
| 162 |
+
if self.built:
|
| 163 |
+
return
|
| 164 |
+
self.built = True
|
| 165 |
+
if getattr(self, "patch_embeddings", None) is not None:
|
| 166 |
+
with tf.name_scope(self.patch_embeddings.name):
|
| 167 |
+
self.patch_embeddings.build(None)
|
| 168 |
+
|
| 169 |
+
def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor:
|
| 170 |
+
embeddings = self.patch_embeddings(pixel_values)
|
| 171 |
+
batch_size, seq_len, projection_dim = shape_list(embeddings)
|
| 172 |
+
|
| 173 |
+
cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1))
|
| 174 |
+
|
| 175 |
+
if bool_masked_pos is not None:
|
| 176 |
+
mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim))
|
| 177 |
+
# replace the masked visual tokens by mask_tokens
|
| 178 |
+
w = bool_masked_pos[..., None]
|
| 179 |
+
w = tf.cast(w, mask_tokens.dtype)
|
| 180 |
+
# since TF doesn't support eager tensor assignment
|
| 181 |
+
embeddings = embeddings * (1 - w) + mask_tokens * w
|
| 182 |
+
|
| 183 |
+
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
|
| 184 |
+
if self.position_embeddings is not None:
|
| 185 |
+
embeddings = embeddings + self.position_embeddings
|
| 186 |
+
embeddings = self.dropout(embeddings)
|
| 187 |
+
|
| 188 |
+
return embeddings
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class TFData2VecVisionPatchEmbeddings(keras.layers.Layer):
|
| 192 |
+
"""
|
| 193 |
+
Image to Patch Embedding.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
| 197 |
+
super().__init__(**kwargs)
|
| 198 |
+
self.config = config
|
| 199 |
+
|
| 200 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 201 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 202 |
+
|
| 203 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 204 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 205 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 206 |
+
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
| 207 |
+
self.image_size = image_size
|
| 208 |
+
self.patch_size = patch_size
|
| 209 |
+
self.num_patches = num_patches
|
| 210 |
+
self.patch_shape = patch_shape
|
| 211 |
+
self.num_channels = num_channels
|
| 212 |
+
|
| 213 |
+
self.projection = keras.layers.Conv2D(
|
| 214 |
+
filters=hidden_size,
|
| 215 |
+
kernel_size=patch_size,
|
| 216 |
+
strides=patch_size,
|
| 217 |
+
padding="valid",
|
| 218 |
+
data_format="channels_last",
|
| 219 |
+
kernel_initializer="glorot_uniform", # following torch.nn.Linear
|
| 220 |
+
bias_initializer="zeros",
|
| 221 |
+
name="projection",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 225 |
+
batch_size, num_channels, height, width = shape_list(pixel_values)
|
| 226 |
+
if tf.executing_eagerly():
|
| 227 |
+
if num_channels != self.num_channels:
|
| 228 |
+
raise ValueError(
|
| 229 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
| 230 |
+
" configuration."
|
| 231 |
+
)
|
| 232 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
| 235 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
|
| 239 |
+
# So change the input format from `NCHW` to `NHWC`.
|
| 240 |
+
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
| 241 |
+
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
| 242 |
+
|
| 243 |
+
projection = self.projection(pixel_values)
|
| 244 |
+
|
| 245 |
+
# Change the 2D spatial dimensions to a single temporal dimension.
|
| 246 |
+
# shape = (batch_size, num_patches, out_channels=embed_dim)
|
| 247 |
+
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
|
| 248 |
+
|
| 249 |
+
return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
|
| 250 |
+
|
| 251 |
+
def build(self, input_shape=None):
|
| 252 |
+
if self.built:
|
| 253 |
+
return
|
| 254 |
+
self.built = True
|
| 255 |
+
if getattr(self, "projection", None) is not None:
|
| 256 |
+
with tf.name_scope(self.projection.name):
|
| 257 |
+
self.projection.build([None, None, None, self.num_channels])
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class TFData2VecVisionSelfAttention(keras.layers.Layer):
|
| 261 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
|
| 262 |
+
super().__init__(**kwargs)
|
| 263 |
+
|
| 264 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 265 |
+
raise ValueError(
|
| 266 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
| 267 |
+
f"of attention heads ({config.num_attention_heads})"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.num_attention_heads = config.num_attention_heads
|
| 271 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 272 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 273 |
+
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
|
| 274 |
+
|
| 275 |
+
self.query = keras.layers.Dense(
|
| 276 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
| 277 |
+
)
|
| 278 |
+
self.key = keras.layers.Dense(
|
| 279 |
+
units=self.all_head_size,
|
| 280 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 281 |
+
name="key",
|
| 282 |
+
use_bias=False,
|
| 283 |
+
)
|
| 284 |
+
self.value = keras.layers.Dense(
|
| 285 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
| 286 |
+
)
|
| 287 |
+
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
| 288 |
+
|
| 289 |
+
if window_size:
|
| 290 |
+
self.relative_position_bias = TFData2VecVisionRelativePositionBias(
|
| 291 |
+
config, window_size=window_size, name="relative_position_bias"
|
| 292 |
+
)
|
| 293 |
+
else:
|
| 294 |
+
self.relative_position_bias = None
|
| 295 |
+
self.config = config
|
| 296 |
+
|
| 297 |
+
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
| 298 |
+
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 299 |
+
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
| 300 |
+
|
| 301 |
+
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
| 302 |
+
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
| 303 |
+
|
| 304 |
+
def call(
|
| 305 |
+
self,
|
| 306 |
+
hidden_states: tf.Tensor,
|
| 307 |
+
head_mask: tf.Tensor,
|
| 308 |
+
output_attentions: bool,
|
| 309 |
+
relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
|
| 310 |
+
training: bool = False,
|
| 311 |
+
) -> Tuple[tf.Tensor]:
|
| 312 |
+
batch_size = shape_list(hidden_states)[0]
|
| 313 |
+
mixed_query_layer = self.query(inputs=hidden_states)
|
| 314 |
+
mixed_key_layer = self.key(inputs=hidden_states)
|
| 315 |
+
mixed_value_layer = self.value(inputs=hidden_states)
|
| 316 |
+
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
| 317 |
+
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
| 318 |
+
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
| 319 |
+
|
| 320 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 321 |
+
# (batch size, num_heads, seq_len_q, seq_len_k)
|
| 322 |
+
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
| 323 |
+
attention_scores = attention_scores / self.sqrt_att_head_size
|
| 324 |
+
|
| 325 |
+
# Add relative position bias if present.
|
| 326 |
+
if self.relative_position_bias is not None:
|
| 327 |
+
# Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
|
| 328 |
+
# might complain about `Layer.call()` not being invoked properly. In this case this input
|
| 329 |
+
# i.e., 0.0 is not going to be used in any calculations so we're safe.
|
| 330 |
+
attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...]
|
| 331 |
+
|
| 332 |
+
# Add shared relative position bias if provided.
|
| 333 |
+
if relative_position_bias is not None:
|
| 334 |
+
attention_scores = attention_scores + relative_position_bias
|
| 335 |
+
|
| 336 |
+
# Normalize the attention scores to probabilities.
|
| 337 |
+
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
|
| 338 |
+
|
| 339 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 340 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 341 |
+
attention_probs = self.dropout(inputs=attention_probs, training=training)
|
| 342 |
+
|
| 343 |
+
# Mask heads if we want to
|
| 344 |
+
if head_mask is not None:
|
| 345 |
+
attention_probs = tf.multiply(attention_probs, head_mask)
|
| 346 |
+
|
| 347 |
+
attention_output = tf.matmul(attention_probs, value_layer)
|
| 348 |
+
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
|
| 349 |
+
|
| 350 |
+
# (batch_size, seq_len_q, all_head_size)
|
| 351 |
+
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
| 352 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
| 353 |
+
|
| 354 |
+
return outputs
|
| 355 |
+
|
| 356 |
+
def build(self, input_shape=None):
|
| 357 |
+
if self.built:
|
| 358 |
+
return
|
| 359 |
+
self.built = True
|
| 360 |
+
if getattr(self, "query", None) is not None:
|
| 361 |
+
with tf.name_scope(self.query.name):
|
| 362 |
+
self.query.build([None, None, self.config.hidden_size])
|
| 363 |
+
if getattr(self, "key", None) is not None:
|
| 364 |
+
with tf.name_scope(self.key.name):
|
| 365 |
+
self.key.build([None, None, self.config.hidden_size])
|
| 366 |
+
if getattr(self, "value", None) is not None:
|
| 367 |
+
with tf.name_scope(self.value.name):
|
| 368 |
+
self.value.build([None, None, self.config.hidden_size])
|
| 369 |
+
if getattr(self, "relative_position_bias", None) is not None:
|
| 370 |
+
with tf.name_scope(self.relative_position_bias.name):
|
| 371 |
+
self.relative_position_bias.build(None)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class TFData2VecVisionSelfOutput(keras.layers.Layer):
|
| 375 |
+
"""
|
| 376 |
+
The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due
|
| 377 |
+
to the layernorm applied before each block.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
| 381 |
+
super().__init__(**kwargs)
|
| 382 |
+
|
| 383 |
+
self.dense = keras.layers.Dense(
|
| 384 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 385 |
+
)
|
| 386 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 387 |
+
self.config = config
|
| 388 |
+
|
| 389 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor:
|
| 390 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 391 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 392 |
+
|
| 393 |
+
return hidden_states
|
| 394 |
+
|
| 395 |
+
def build(self, input_shape=None):
|
| 396 |
+
if self.built:
|
| 397 |
+
return
|
| 398 |
+
self.built = True
|
| 399 |
+
if getattr(self, "dense", None) is not None:
|
| 400 |
+
with tf.name_scope(self.dense.name):
|
| 401 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class TFData2VecVisionAttention(keras.layers.Layer):
|
| 405 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
|
| 406 |
+
super().__init__(**kwargs)
|
| 407 |
+
|
| 408 |
+
self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention")
|
| 409 |
+
self.dense_output = TFData2VecVisionSelfOutput(config, name="output")
|
| 410 |
+
|
| 411 |
+
def prune_heads(self, heads):
|
| 412 |
+
raise NotImplementedError
|
| 413 |
+
|
| 414 |
+
def call(
|
| 415 |
+
self,
|
| 416 |
+
input_tensor: tf.Tensor,
|
| 417 |
+
head_mask: tf.Tensor,
|
| 418 |
+
output_attentions: bool,
|
| 419 |
+
relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
|
| 420 |
+
training: bool = False,
|
| 421 |
+
) -> Tuple[tf.Tensor]:
|
| 422 |
+
self_outputs = self.attention(
|
| 423 |
+
hidden_states=input_tensor,
|
| 424 |
+
head_mask=head_mask,
|
| 425 |
+
output_attentions=output_attentions,
|
| 426 |
+
relative_position_bias=relative_position_bias,
|
| 427 |
+
training=training,
|
| 428 |
+
)
|
| 429 |
+
attention_output = self.dense_output(
|
| 430 |
+
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
| 431 |
+
)
|
| 432 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 433 |
+
|
| 434 |
+
return outputs
|
| 435 |
+
|
| 436 |
+
def build(self, input_shape=None):
|
| 437 |
+
if self.built:
|
| 438 |
+
return
|
| 439 |
+
self.built = True
|
| 440 |
+
if getattr(self, "attention", None) is not None:
|
| 441 |
+
with tf.name_scope(self.attention.name):
|
| 442 |
+
self.attention.build(None)
|
| 443 |
+
if getattr(self, "dense_output", None) is not None:
|
| 444 |
+
with tf.name_scope(self.dense_output.name):
|
| 445 |
+
self.dense_output.build(None)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision
|
| 449 |
+
class TFData2VecVisionIntermediate(keras.layers.Layer):
|
| 450 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
| 451 |
+
super().__init__(**kwargs)
|
| 452 |
+
|
| 453 |
+
self.dense = keras.layers.Dense(
|
| 454 |
+
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
if isinstance(config.hidden_act, str):
|
| 458 |
+
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
| 459 |
+
else:
|
| 460 |
+
self.intermediate_act_fn = config.hidden_act
|
| 461 |
+
self.config = config
|
| 462 |
+
|
| 463 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 464 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 465 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 466 |
+
|
| 467 |
+
return hidden_states
|
| 468 |
+
|
| 469 |
+
def build(self, input_shape=None):
|
| 470 |
+
if self.built:
|
| 471 |
+
return
|
| 472 |
+
self.built = True
|
| 473 |
+
if getattr(self, "dense", None) is not None:
|
| 474 |
+
with tf.name_scope(self.dense.name):
|
| 475 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class TFData2VecVisionOutput(keras.layers.Layer):
|
| 479 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
| 480 |
+
super().__init__(**kwargs)
|
| 481 |
+
|
| 482 |
+
self.dense = keras.layers.Dense(
|
| 483 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 484 |
+
)
|
| 485 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 486 |
+
self.config = config
|
| 487 |
+
|
| 488 |
+
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 489 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 490 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 491 |
+
|
| 492 |
+
return hidden_states
|
| 493 |
+
|
| 494 |
+
def build(self, input_shape=None):
|
| 495 |
+
if self.built:
|
| 496 |
+
return
|
| 497 |
+
self.built = True
|
| 498 |
+
if getattr(self, "dense", None) is not None:
|
| 499 |
+
with tf.name_scope(self.dense.name):
|
| 500 |
+
self.dense.build([None, None, self.config.intermediate_size])
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class TFData2VecVisionLayer(keras.layers.Layer):
|
| 504 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 505 |
+
|
| 506 |
+
def __init__(
|
| 507 |
+
self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0, **kwargs
|
| 508 |
+
):
|
| 509 |
+
super().__init__(**kwargs)
|
| 510 |
+
self.config = config
|
| 511 |
+
|
| 512 |
+
self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention")
|
| 513 |
+
self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate")
|
| 514 |
+
self.data2vec_output = TFData2VecVisionOutput(config, name="output")
|
| 515 |
+
|
| 516 |
+
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
|
| 517 |
+
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
|
| 518 |
+
# Using `layers.Activation` instead of `tf.identity` to better control `training`
|
| 519 |
+
# behaviour.
|
| 520 |
+
self.drop_path = (
|
| 521 |
+
TFData2VecVisionDropPath(drop_path_rate, name="drop_path")
|
| 522 |
+
if drop_path_rate > 0.0
|
| 523 |
+
else keras.layers.Activation("linear", name="drop_path")
|
| 524 |
+
)
|
| 525 |
+
self.init_values = config.layer_scale_init_value
|
| 526 |
+
|
| 527 |
+
def build(self, input_shape: tf.TensorShape = None):
|
| 528 |
+
if self.init_values > 0:
|
| 529 |
+
self.lambda_1 = self.add_weight(
|
| 530 |
+
shape=(self.config.hidden_size),
|
| 531 |
+
initializer="ones",
|
| 532 |
+
trainable=True,
|
| 533 |
+
name="lambda_1",
|
| 534 |
+
)
|
| 535 |
+
self.lambda_2 = self.add_weight(
|
| 536 |
+
shape=(self.config.hidden_size),
|
| 537 |
+
initializer="ones",
|
| 538 |
+
trainable=True,
|
| 539 |
+
name="lambda_2",
|
| 540 |
+
)
|
| 541 |
+
self.lambda_1.assign(self.init_values * tf.ones((self.config.hidden_size)))
|
| 542 |
+
self.lambda_2.assign(self.init_values * tf.ones((self.config.hidden_size)))
|
| 543 |
+
else:
|
| 544 |
+
self.lambda_1, self.lambda_2 = None, None
|
| 545 |
+
|
| 546 |
+
if self.built:
|
| 547 |
+
return
|
| 548 |
+
self.built = True
|
| 549 |
+
if getattr(self, "attention", None) is not None:
|
| 550 |
+
with tf.name_scope(self.attention.name):
|
| 551 |
+
self.attention.build(None)
|
| 552 |
+
if getattr(self, "intermediate", None) is not None:
|
| 553 |
+
with tf.name_scope(self.intermediate.name):
|
| 554 |
+
self.intermediate.build(None)
|
| 555 |
+
if getattr(self, "data2vec_output", None) is not None:
|
| 556 |
+
with tf.name_scope(self.data2vec_output.name):
|
| 557 |
+
self.data2vec_output.build(None)
|
| 558 |
+
if getattr(self, "layernorm_before", None) is not None:
|
| 559 |
+
with tf.name_scope(self.layernorm_before.name):
|
| 560 |
+
self.layernorm_before.build([None, None, self.config.hidden_size])
|
| 561 |
+
if getattr(self, "layernorm_after", None) is not None:
|
| 562 |
+
with tf.name_scope(self.layernorm_after.name):
|
| 563 |
+
self.layernorm_after.build([None, None, self.config.hidden_size])
|
| 564 |
+
if getattr(self, "drop_path", None) is not None:
|
| 565 |
+
with tf.name_scope(self.drop_path.name):
|
| 566 |
+
self.drop_path.build(None)
|
| 567 |
+
|
| 568 |
+
def call(
|
| 569 |
+
self,
|
| 570 |
+
hidden_states: tf.Tensor,
|
| 571 |
+
head_mask: tf.Tensor,
|
| 572 |
+
output_attentions: bool,
|
| 573 |
+
relative_position_bias: Optional["TFData2VecVisionRelativePositionBias"] = None,
|
| 574 |
+
training: bool = False,
|
| 575 |
+
) -> Tuple[tf.Tensor]:
|
| 576 |
+
self_attention_outputs = self.attention(
|
| 577 |
+
# in Data2VecVision, layernorm is applied before self-attention
|
| 578 |
+
input_tensor=self.layernorm_before(inputs=hidden_states),
|
| 579 |
+
head_mask=head_mask,
|
| 580 |
+
output_attentions=output_attentions,
|
| 581 |
+
relative_position_bias=relative_position_bias,
|
| 582 |
+
training=training,
|
| 583 |
+
)
|
| 584 |
+
attention_output = self_attention_outputs[0]
|
| 585 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 586 |
+
|
| 587 |
+
# apply lambda_1 if present
|
| 588 |
+
if self.lambda_1 is not None:
|
| 589 |
+
attention_output = self.lambda_1 * attention_output
|
| 590 |
+
|
| 591 |
+
# first residual connection
|
| 592 |
+
hidden_states = self.drop_path(attention_output) + hidden_states
|
| 593 |
+
|
| 594 |
+
# in Data2VecVision, layernorm is also applied after self-attention
|
| 595 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 596 |
+
|
| 597 |
+
layer_output = self.intermediate(layer_output)
|
| 598 |
+
layer_output = self.data2vec_output(layer_output)
|
| 599 |
+
|
| 600 |
+
if self.lambda_2 is not None:
|
| 601 |
+
layer_output = self.lambda_2 * layer_output
|
| 602 |
+
|
| 603 |
+
# second residual connection
|
| 604 |
+
layer_output = self.drop_path(layer_output) + hidden_states
|
| 605 |
+
|
| 606 |
+
outputs = (layer_output,) + outputs
|
| 607 |
+
|
| 608 |
+
return outputs
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Taken and modified from here:
|
| 612 |
+
# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28
|
| 613 |
+
class TFData2VecVisionRelativePositionBias(keras.layers.Layer):
|
| 614 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None:
|
| 615 |
+
super().__init__(**kwargs)
|
| 616 |
+
self.config = config
|
| 617 |
+
|
| 618 |
+
self.window_size = window_size
|
| 619 |
+
# +3 for cls_token_pos_len
|
| 620 |
+
# window_size can be something like (14, 14)
|
| 621 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 622 |
+
|
| 623 |
+
self.relative_position_index = self.get_position_index()
|
| 624 |
+
|
| 625 |
+
def build(self, input_shape):
|
| 626 |
+
self.relative_position_bias_table = self.add_weight(
|
| 627 |
+
shape=(self.num_relative_distance, self.config.num_attention_heads),
|
| 628 |
+
initializer="zeros",
|
| 629 |
+
trainable=True,
|
| 630 |
+
name="relative_position_bias_table",
|
| 631 |
+
) # [2*Wh-1 * 2*Ww-1, nH]
|
| 632 |
+
# cls to token & token 2 cls & cls to cls
|
| 633 |
+
|
| 634 |
+
super().build(input_shape)
|
| 635 |
+
|
| 636 |
+
def get_position_index(self):
|
| 637 |
+
# get pair-wise relative position index for each token inside the window
|
| 638 |
+
xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1]))
|
| 639 |
+
coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww]
|
| 640 |
+
coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww]
|
| 641 |
+
|
| 642 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Wh*Ww, Wh*Ww]
|
| 643 |
+
relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0]) # [Wh*Ww, Wh*Ww, 2]
|
| 644 |
+
|
| 645 |
+
xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
|
| 646 |
+
yy = relative_coords[:, :, 1] + self.window_size[1] - 1
|
| 647 |
+
relative_coords = tf.stack([xx, yy], axis=-1)
|
| 648 |
+
|
| 649 |
+
relative_position_index = tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww]
|
| 650 |
+
|
| 651 |
+
top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * (
|
| 652 |
+
self.num_relative_distance - 3
|
| 653 |
+
)
|
| 654 |
+
left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * (
|
| 655 |
+
self.num_relative_distance - 2
|
| 656 |
+
)
|
| 657 |
+
corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1)
|
| 658 |
+
|
| 659 |
+
left_corner = tf.concat([corner, left], axis=0)
|
| 660 |
+
relative_position_index = tf.concat([top, relative_position_index], axis=0)
|
| 661 |
+
relative_position_index = tf.concat([left_corner, relative_position_index], axis=1) # [Wh*Ww + 1, Wh*Ww + 1]
|
| 662 |
+
return relative_position_index
|
| 663 |
+
|
| 664 |
+
def call(self, inputs=None) -> tf.Tensor:
|
| 665 |
+
relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0)
|
| 666 |
+
return tf.transpose(relative_position_bias, [2, 0, 1])
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
class TFData2VecVisionEncoder(keras.layers.Layer):
|
| 670 |
+
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, **kwargs):
|
| 671 |
+
super().__init__(**kwargs)
|
| 672 |
+
self.config = config
|
| 673 |
+
if config.use_shared_relative_position_bias:
|
| 674 |
+
self.relative_position_bias = TFData2VecVisionRelativePositionBias(
|
| 675 |
+
config, window_size=window_size, name="relative_position_bias"
|
| 676 |
+
)
|
| 677 |
+
else:
|
| 678 |
+
self.relative_position_bias = None
|
| 679 |
+
|
| 680 |
+
# stochastic depth decay rule
|
| 681 |
+
dpr = list(tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers))
|
| 682 |
+
self.layer = [
|
| 683 |
+
TFData2VecVisionLayer(
|
| 684 |
+
config,
|
| 685 |
+
window_size=window_size if config.use_relative_position_bias else None,
|
| 686 |
+
drop_path_rate=dpr[i],
|
| 687 |
+
name=f"layer_._{i}",
|
| 688 |
+
)
|
| 689 |
+
for i in range(config.num_hidden_layers)
|
| 690 |
+
]
|
| 691 |
+
|
| 692 |
+
def call(
|
| 693 |
+
self,
|
| 694 |
+
hidden_states: tf.Tensor,
|
| 695 |
+
head_mask: tf.Tensor | None = None,
|
| 696 |
+
output_attentions: bool = False,
|
| 697 |
+
output_hidden_states: bool = False,
|
| 698 |
+
return_dict: bool = True,
|
| 699 |
+
) -> Union[tuple, TFBaseModelOutput]:
|
| 700 |
+
all_hidden_states = () if output_hidden_states else None
|
| 701 |
+
all_self_attentions = () if output_attentions else None
|
| 702 |
+
|
| 703 |
+
for i, layer_module in enumerate(self.layer):
|
| 704 |
+
if output_hidden_states:
|
| 705 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 706 |
+
|
| 707 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 708 |
+
# Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
|
| 709 |
+
# might complain about `Layer.call()` not being invoked properly. In this case this input
|
| 710 |
+
# i.e., 0.0 is not going to be used in any calculations so we're safe.
|
| 711 |
+
relative_position_bias = (
|
| 712 |
+
self.relative_position_bias(0.0) if self.relative_position_bias is not None else None
|
| 713 |
+
)
|
| 714 |
+
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
|
| 715 |
+
|
| 716 |
+
hidden_states = layer_outputs[0]
|
| 717 |
+
|
| 718 |
+
if output_attentions:
|
| 719 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 720 |
+
|
| 721 |
+
if output_hidden_states:
|
| 722 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 723 |
+
|
| 724 |
+
if not return_dict:
|
| 725 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 726 |
+
|
| 727 |
+
return TFBaseModelOutput(
|
| 728 |
+
last_hidden_state=hidden_states,
|
| 729 |
+
hidden_states=all_hidden_states,
|
| 730 |
+
attentions=all_self_attentions,
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
def build(self, input_shape=None):
|
| 734 |
+
if self.built:
|
| 735 |
+
return
|
| 736 |
+
self.built = True
|
| 737 |
+
if getattr(self, "relative_position_bias", None) is not None:
|
| 738 |
+
with tf.name_scope(self.relative_position_bias.name):
|
| 739 |
+
self.relative_position_bias.build(None)
|
| 740 |
+
if getattr(self, "layer", None) is not None:
|
| 741 |
+
for layer in self.layer:
|
| 742 |
+
with tf.name_scope(layer.name):
|
| 743 |
+
layer.build(None)
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
@keras_serializable
|
| 747 |
+
class TFData2VecVisionMainLayer(keras.layers.Layer):
|
| 748 |
+
config_class = Data2VecVisionConfig
|
| 749 |
+
|
| 750 |
+
def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs):
|
| 751 |
+
super().__init__(**kwargs)
|
| 752 |
+
|
| 753 |
+
self.config = config
|
| 754 |
+
self.add_pooling_layer = add_pooling_layer
|
| 755 |
+
|
| 756 |
+
self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings")
|
| 757 |
+
self.encoder = TFData2VecVisionEncoder(
|
| 758 |
+
config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder"
|
| 759 |
+
)
|
| 760 |
+
self.layernorm = (
|
| 761 |
+
tf.identity
|
| 762 |
+
if config.use_mean_pooling
|
| 763 |
+
else keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
# We are setting the `data_format` like so because from here on we will revert to the
|
| 767 |
+
# NCHW output format
|
| 768 |
+
self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None
|
| 769 |
+
|
| 770 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 771 |
+
return self.embeddings.patch_embeddings
|
| 772 |
+
|
| 773 |
+
def _prune_heads(self, heads_to_prune):
|
| 774 |
+
"""
|
| 775 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 776 |
+
class PreTrainedModel
|
| 777 |
+
"""
|
| 778 |
+
raise NotImplementedError
|
| 779 |
+
|
| 780 |
+
@unpack_inputs
|
| 781 |
+
def call(
|
| 782 |
+
self,
|
| 783 |
+
pixel_values: tf.Tensor | None = None,
|
| 784 |
+
bool_masked_pos: tf.Tensor | None = None,
|
| 785 |
+
head_mask: tf.Tensor | None = None,
|
| 786 |
+
output_attentions: Optional[bool] = None,
|
| 787 |
+
output_hidden_states: Optional[bool] = None,
|
| 788 |
+
return_dict: Optional[bool] = None,
|
| 789 |
+
training: bool = False,
|
| 790 |
+
) -> Union[tuple, TFData2VecVisionModelOutputWithPooling]:
|
| 791 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 792 |
+
output_hidden_states = (
|
| 793 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 794 |
+
)
|
| 795 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 796 |
+
|
| 797 |
+
if pixel_values is None:
|
| 798 |
+
raise ValueError("You have to specify pixel_values")
|
| 799 |
+
|
| 800 |
+
# Prepare head mask if needed
|
| 801 |
+
# 1.0 in head_mask indicate we keep the head
|
| 802 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 803 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 804 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 805 |
+
if head_mask is not None:
|
| 806 |
+
raise NotImplementedError
|
| 807 |
+
else:
|
| 808 |
+
head_mask = [None] * self.config.num_hidden_layers
|
| 809 |
+
|
| 810 |
+
embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training)
|
| 811 |
+
|
| 812 |
+
encoder_outputs = self.encoder(
|
| 813 |
+
embedding_output,
|
| 814 |
+
head_mask=head_mask,
|
| 815 |
+
output_attentions=output_attentions,
|
| 816 |
+
output_hidden_states=output_hidden_states,
|
| 817 |
+
return_dict=return_dict,
|
| 818 |
+
training=training,
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
sequence_output = encoder_outputs[0]
|
| 822 |
+
sequence_output = self.layernorm(sequence_output)
|
| 823 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 824 |
+
|
| 825 |
+
if not return_dict:
|
| 826 |
+
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
| 827 |
+
return head_outputs + encoder_outputs[1:]
|
| 828 |
+
|
| 829 |
+
return TFData2VecVisionModelOutputWithPooling(
|
| 830 |
+
last_hidden_state=sequence_output,
|
| 831 |
+
pooler_output=pooled_output,
|
| 832 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 833 |
+
attentions=encoder_outputs.attentions,
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
def build(self, input_shape=None):
|
| 837 |
+
if self.built:
|
| 838 |
+
return
|
| 839 |
+
self.built = True
|
| 840 |
+
if getattr(self, "embeddings", None) is not None:
|
| 841 |
+
with tf.name_scope(self.embeddings.name):
|
| 842 |
+
self.embeddings.build(None)
|
| 843 |
+
if getattr(self, "encoder", None) is not None:
|
| 844 |
+
with tf.name_scope(self.encoder.name):
|
| 845 |
+
self.encoder.build(None)
|
| 846 |
+
if getattr(self, "layernorm", None) is not None:
|
| 847 |
+
if hasattr(self.layernorm, "name"):
|
| 848 |
+
with tf.name_scope(self.layernorm.name):
|
| 849 |
+
self.layernorm.build((None, self.config.hidden_size))
|
| 850 |
+
if getattr(self, "pooler", None) is not None:
|
| 851 |
+
with tf.name_scope(self.pooler.name):
|
| 852 |
+
self.pooler.build(None)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
class TFData2VecVisionPooler(keras.layers.Layer):
|
| 856 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs):
|
| 857 |
+
super().__init__(**kwargs)
|
| 858 |
+
self.layernorm = (
|
| 859 |
+
keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
|
| 860 |
+
if config.use_mean_pooling
|
| 861 |
+
else None
|
| 862 |
+
)
|
| 863 |
+
self.config = config
|
| 864 |
+
|
| 865 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 866 |
+
if self.layernorm is not None:
|
| 867 |
+
# Mean pool the final hidden states of the patch tokens
|
| 868 |
+
patch_tokens = hidden_states[:, 1:, :]
|
| 869 |
+
pooled_output = self.layernorm(tf.reduce_mean(patch_tokens, axis=1))
|
| 870 |
+
else:
|
| 871 |
+
# Pool by simply taking the final hidden state of the [CLS] token
|
| 872 |
+
pooled_output = hidden_states[:, 0]
|
| 873 |
+
|
| 874 |
+
return pooled_output
|
| 875 |
+
|
| 876 |
+
def build(self, input_shape=None):
|
| 877 |
+
if self.built:
|
| 878 |
+
return
|
| 879 |
+
self.built = True
|
| 880 |
+
if getattr(self, "layernorm", None) is not None:
|
| 881 |
+
if hasattr(self.layernorm, "name"):
|
| 882 |
+
with tf.name_scope(self.layernorm.name):
|
| 883 |
+
self.layernorm.build((None, self.config.hidden_size))
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
|
| 887 |
+
"""
|
| 888 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 889 |
+
models.
|
| 890 |
+
"""
|
| 891 |
+
|
| 892 |
+
config_class = Data2VecVisionConfig
|
| 893 |
+
base_model_prefix = "data2vec_vision"
|
| 894 |
+
main_input_name = "pixel_values"
|
| 895 |
+
_keys_to_ignore_on_load_unexpected = [r"relative_position_index"]
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
DATA2VEC_VISION_START_DOCSTRING = r"""
|
| 899 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 900 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 901 |
+
etc.).
|
| 902 |
+
|
| 903 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 904 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 905 |
+
behavior.
|
| 906 |
+
|
| 907 |
+
<Tip>
|
| 908 |
+
|
| 909 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 910 |
+
|
| 911 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 912 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 913 |
+
|
| 914 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 915 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 916 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 917 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 918 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 919 |
+
positional argument:
|
| 920 |
+
|
| 921 |
+
- a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
|
| 922 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 923 |
+
`model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
|
| 924 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 925 |
+
`model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
|
| 926 |
+
|
| 927 |
+
Note that when creating models and layers with
|
| 928 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 929 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 930 |
+
|
| 931 |
+
</Tip>
|
| 932 |
+
|
| 933 |
+
Args:
|
| 934 |
+
config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
|
| 935 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 936 |
+
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 937 |
+
"""
|
| 938 |
+
|
| 939 |
+
DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
|
| 940 |
+
Args:
|
| 941 |
+
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
|
| 942 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 943 |
+
[`BeitImageProcessor.__call__`] for details.
|
| 944 |
+
|
| 945 |
+
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 946 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 947 |
+
- 1 indicates the head is **not masked**,
|
| 948 |
+
- 0 indicates the head is **masked**.
|
| 949 |
+
|
| 950 |
+
output_attentions (`bool`, *optional*):
|
| 951 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 952 |
+
tensors for more detail.
|
| 953 |
+
|
| 954 |
+
output_hidden_states (`bool`, *optional*):
|
| 955 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 956 |
+
more detail.
|
| 957 |
+
|
| 958 |
+
return_dict (`bool`, *optional*):
|
| 959 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
|
| 960 |
+
in eager mode, in graph mode the value will always be set to True.
|
| 961 |
+
|
| 962 |
+
training (`bool`, *optional*, defaults to `False``):
|
| 963 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 964 |
+
behaviors between training and evaluation).
|
| 965 |
+
"""
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
@add_start_docstrings(
|
| 969 |
+
"The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
|
| 970 |
+
DATA2VEC_VISION_START_DOCSTRING,
|
| 971 |
+
)
|
| 972 |
+
class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
|
| 973 |
+
def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs):
|
| 974 |
+
super().__init__(config, *inputs, **kwargs)
|
| 975 |
+
self.config = config
|
| 976 |
+
|
| 977 |
+
self.data2vec_vision = TFData2VecVisionMainLayer(
|
| 978 |
+
config, add_pooling_layer=add_pooling_layer, name="data2vec_vision"
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
def get_input_embeddings(self):
|
| 982 |
+
return self.data2vec_vision.get_input_embeddings()
|
| 983 |
+
|
| 984 |
+
@unpack_inputs
|
| 985 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
|
| 986 |
+
@add_code_sample_docstrings(
|
| 987 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 988 |
+
output_type=TFData2VecVisionModelOutputWithPooling,
|
| 989 |
+
config_class=_CONFIG_FOR_DOC,
|
| 990 |
+
modality="vision",
|
| 991 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 992 |
+
)
|
| 993 |
+
def call(
|
| 994 |
+
self,
|
| 995 |
+
pixel_values: TFModelInputType | None = None,
|
| 996 |
+
bool_masked_pos: tf.Tensor | None = None,
|
| 997 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 998 |
+
output_attentions: Optional[bool] = None,
|
| 999 |
+
output_hidden_states: Optional[bool] = None,
|
| 1000 |
+
return_dict: Optional[bool] = None,
|
| 1001 |
+
training: bool = False,
|
| 1002 |
+
) -> Union[tuple, TFData2VecVisionModelOutputWithPooling]:
|
| 1003 |
+
r"""
|
| 1004 |
+
bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*):
|
| 1005 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 1006 |
+
"""
|
| 1007 |
+
outputs = self.data2vec_vision(
|
| 1008 |
+
pixel_values=pixel_values,
|
| 1009 |
+
bool_masked_pos=bool_masked_pos,
|
| 1010 |
+
head_mask=head_mask,
|
| 1011 |
+
output_attentions=output_attentions,
|
| 1012 |
+
output_hidden_states=output_hidden_states,
|
| 1013 |
+
return_dict=return_dict,
|
| 1014 |
+
training=training,
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
return outputs
|
| 1018 |
+
|
| 1019 |
+
def build(self, input_shape=None):
|
| 1020 |
+
if self.built:
|
| 1021 |
+
return
|
| 1022 |
+
self.built = True
|
| 1023 |
+
if getattr(self, "data2vec_vision", None) is not None:
|
| 1024 |
+
with tf.name_scope(self.data2vec_vision.name):
|
| 1025 |
+
self.data2vec_vision.build(None)
|
| 1026 |
+
|
| 1027 |
+
|
| 1028 |
+
@add_start_docstrings(
|
| 1029 |
+
"""
|
| 1030 |
+
Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
|
| 1031 |
+
the final hidden states of the patch tokens) e.g. for ImageNet.
|
| 1032 |
+
""",
|
| 1033 |
+
DATA2VEC_VISION_START_DOCSTRING,
|
| 1034 |
+
)
|
| 1035 |
+
class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss):
|
| 1036 |
+
def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs):
|
| 1037 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1038 |
+
|
| 1039 |
+
self.num_labels = config.num_labels
|
| 1040 |
+
self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision")
|
| 1041 |
+
|
| 1042 |
+
# Classifier head
|
| 1043 |
+
self.classifier = keras.layers.Dense(
|
| 1044 |
+
units=config.num_labels,
|
| 1045 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1046 |
+
name="classifier",
|
| 1047 |
+
)
|
| 1048 |
+
self.config = config
|
| 1049 |
+
|
| 1050 |
+
@unpack_inputs
|
| 1051 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
|
| 1052 |
+
@add_code_sample_docstrings(
|
| 1053 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 1054 |
+
output_type=TFSequenceClassifierOutput,
|
| 1055 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1056 |
+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
| 1057 |
+
)
|
| 1058 |
+
def call(
|
| 1059 |
+
self,
|
| 1060 |
+
pixel_values: TFModelInputType | None = None,
|
| 1061 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1062 |
+
output_attentions: Optional[bool] = None,
|
| 1063 |
+
output_hidden_states: Optional[bool] = None,
|
| 1064 |
+
return_dict: Optional[bool] = None,
|
| 1065 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1066 |
+
training: Optional[bool] = False,
|
| 1067 |
+
) -> Union[TFSequenceClassifierOutput, tuple]:
|
| 1068 |
+
r"""
|
| 1069 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1070 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 1071 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1072 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1073 |
+
"""
|
| 1074 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1075 |
+
|
| 1076 |
+
outputs = self.data2vec_vision(
|
| 1077 |
+
pixel_values=pixel_values,
|
| 1078 |
+
head_mask=head_mask,
|
| 1079 |
+
output_attentions=output_attentions,
|
| 1080 |
+
output_hidden_states=output_hidden_states,
|
| 1081 |
+
return_dict=return_dict,
|
| 1082 |
+
training=training,
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
pooled_output = outputs.pooler_output if return_dict else outputs[1]
|
| 1086 |
+
logits = self.classifier(pooled_output)
|
| 1087 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1088 |
+
|
| 1089 |
+
if not return_dict:
|
| 1090 |
+
output = (logits,) + outputs[2:]
|
| 1091 |
+
return ((loss,) + output) if loss is not None else output
|
| 1092 |
+
|
| 1093 |
+
return TFSequenceClassifierOutput(
|
| 1094 |
+
loss=loss,
|
| 1095 |
+
logits=logits,
|
| 1096 |
+
hidden_states=outputs.hidden_states,
|
| 1097 |
+
attentions=outputs.attentions,
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
def build(self, input_shape=None):
|
| 1101 |
+
if self.built:
|
| 1102 |
+
return
|
| 1103 |
+
self.built = True
|
| 1104 |
+
if getattr(self, "data2vec_vision", None) is not None:
|
| 1105 |
+
with tf.name_scope(self.data2vec_vision.name):
|
| 1106 |
+
self.data2vec_vision.build(None)
|
| 1107 |
+
if getattr(self, "classifier", None) is not None:
|
| 1108 |
+
with tf.name_scope(self.classifier.name):
|
| 1109 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1110 |
+
|
| 1111 |
+
|
| 1112 |
+
class TFData2VecVisionConvModule(keras.layers.Layer):
|
| 1113 |
+
"""
|
| 1114 |
+
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
|
| 1115 |
+
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
| 1116 |
+
|
| 1117 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1118 |
+
"""
|
| 1119 |
+
|
| 1120 |
+
def __init__(
|
| 1121 |
+
self,
|
| 1122 |
+
in_channels: int,
|
| 1123 |
+
out_channels: int,
|
| 1124 |
+
kernel_size: Union[int, Tuple[int, int]],
|
| 1125 |
+
padding: str = "valid",
|
| 1126 |
+
bias: bool = False,
|
| 1127 |
+
dilation: Union[int, Tuple[int, int]] = 1,
|
| 1128 |
+
**kwargs,
|
| 1129 |
+
) -> None:
|
| 1130 |
+
super().__init__(**kwargs)
|
| 1131 |
+
self.conv = keras.layers.Conv2D(
|
| 1132 |
+
filters=out_channels,
|
| 1133 |
+
kernel_size=kernel_size,
|
| 1134 |
+
padding=padding,
|
| 1135 |
+
use_bias=bias,
|
| 1136 |
+
dilation_rate=dilation,
|
| 1137 |
+
name="conv",
|
| 1138 |
+
)
|
| 1139 |
+
self.bn = keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5)
|
| 1140 |
+
self.activation = tf.nn.relu
|
| 1141 |
+
self.in_channels = in_channels
|
| 1142 |
+
self.out_channels = out_channels
|
| 1143 |
+
|
| 1144 |
+
def call(self, input: tf.Tensor) -> tf.Tensor:
|
| 1145 |
+
output = self.conv(input)
|
| 1146 |
+
output = self.bn(output)
|
| 1147 |
+
output = self.activation(output)
|
| 1148 |
+
return output
|
| 1149 |
+
|
| 1150 |
+
def build(self, input_shape=None):
|
| 1151 |
+
if self.built:
|
| 1152 |
+
return
|
| 1153 |
+
self.built = True
|
| 1154 |
+
if getattr(self, "conv", None) is not None:
|
| 1155 |
+
with tf.name_scope(self.conv.name):
|
| 1156 |
+
self.conv.build([None, None, None, self.in_channels])
|
| 1157 |
+
if getattr(self, "bn", None) is not None:
|
| 1158 |
+
with tf.name_scope(self.bn.name):
|
| 1159 |
+
self.bn.build((None, None, None, self.out_channels))
|
| 1160 |
+
|
| 1161 |
+
|
| 1162 |
+
class TFAdaptiveAvgPool2D(keras.layers.Layer):
|
| 1163 |
+
def __init__(self, output_dims: Tuple[int, int], input_ordering: str = "NHWC", **kwargs):
|
| 1164 |
+
super().__init__(**kwargs)
|
| 1165 |
+
self.output_dims = output_dims
|
| 1166 |
+
self.input_ordering = input_ordering
|
| 1167 |
+
if input_ordering not in ("NCHW", "NHWC"):
|
| 1168 |
+
raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!")
|
| 1169 |
+
self.h_axis = input_ordering.index("H")
|
| 1170 |
+
self.w_axis = input_ordering.index("W")
|
| 1171 |
+
|
| 1172 |
+
def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
|
| 1173 |
+
# Figure out which axis we're pooling on
|
| 1174 |
+
if h_pooling:
|
| 1175 |
+
axis = self.h_axis
|
| 1176 |
+
output_dim = self.output_dims[0]
|
| 1177 |
+
else:
|
| 1178 |
+
axis = self.w_axis
|
| 1179 |
+
output_dim = self.output_dims[1]
|
| 1180 |
+
input_dim = inputs.shape[axis]
|
| 1181 |
+
|
| 1182 |
+
# Figure out the potential pooling windows
|
| 1183 |
+
# This is the key idea - the torch op always uses only two
|
| 1184 |
+
# consecutive pooling window sizes, like 3 and 4. Therefore,
|
| 1185 |
+
# if we pool with both possible sizes, we simply need to gather
|
| 1186 |
+
# the 'correct' pool at each position to reimplement the torch op.
|
| 1187 |
+
small_window = math.ceil(input_dim / output_dim)
|
| 1188 |
+
big_window = small_window + 1
|
| 1189 |
+
if h_pooling:
|
| 1190 |
+
output_dim = self.output_dims[0]
|
| 1191 |
+
small_window_shape = (small_window, 1)
|
| 1192 |
+
big_window_shape = (big_window, 1)
|
| 1193 |
+
else:
|
| 1194 |
+
output_dim = self.output_dims[1]
|
| 1195 |
+
small_window_shape = (1, small_window)
|
| 1196 |
+
big_window_shape = (1, big_window)
|
| 1197 |
+
|
| 1198 |
+
# For resizes to 1, or integer resizes, we can take quick shortcuts
|
| 1199 |
+
if output_dim == input_dim:
|
| 1200 |
+
return inputs
|
| 1201 |
+
elif output_dim == 1:
|
| 1202 |
+
return tf.reduce_mean(inputs, axis=axis, keepdims=True)
|
| 1203 |
+
elif input_dim % output_dim == 0:
|
| 1204 |
+
return tf.nn.avg_pool2d(
|
| 1205 |
+
inputs,
|
| 1206 |
+
ksize=small_window_shape,
|
| 1207 |
+
strides=small_window_shape,
|
| 1208 |
+
padding="VALID",
|
| 1209 |
+
data_format=self.input_ordering,
|
| 1210 |
+
)
|
| 1211 |
+
# When upscaling by an integer factor we can also take a quick shortcut
|
| 1212 |
+
elif output_dim > input_dim and output_dim % input_dim == 0:
|
| 1213 |
+
return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis)
|
| 1214 |
+
|
| 1215 |
+
# For non-integer resizes, we pool with both possible window sizes and concatenate them
|
| 1216 |
+
if output_dim < input_dim:
|
| 1217 |
+
small_pool = tf.nn.avg_pool2d(
|
| 1218 |
+
inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
|
| 1219 |
+
)
|
| 1220 |
+
big_pool = tf.nn.avg_pool2d(
|
| 1221 |
+
inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
|
| 1222 |
+
)
|
| 1223 |
+
both_pool = tf.concat([small_pool, big_pool], axis=axis)
|
| 1224 |
+
else:
|
| 1225 |
+
# When we're actually upscaling instead, then we build the pools a bit differently
|
| 1226 |
+
small_pool = inputs
|
| 1227 |
+
big_pool = tf.nn.avg_pool2d(
|
| 1228 |
+
inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
|
| 1229 |
+
)
|
| 1230 |
+
both_pool = tf.concat([small_pool, big_pool], axis=axis)
|
| 1231 |
+
|
| 1232 |
+
# We compute vectors of the start and end positions for each pooling window
|
| 1233 |
+
# Each (start, end) pair here corresponds to a single output position
|
| 1234 |
+
window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim)
|
| 1235 |
+
window_starts = tf.cast(window_starts, tf.int64)
|
| 1236 |
+
window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim)
|
| 1237 |
+
window_ends = tf.cast(window_ends, tf.int64)
|
| 1238 |
+
|
| 1239 |
+
# pool_selector is a boolean array of shape (output_dim,) where 1 indicates that output position
|
| 1240 |
+
# has a big receptive field and 0 indicates that that output position has a small receptive field
|
| 1241 |
+
pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool)
|
| 1242 |
+
|
| 1243 |
+
# Since we concatenated the small and big pools, we need to do a bit of
|
| 1244 |
+
# pointer arithmetic to get the indices of the big pools
|
| 1245 |
+
small_indices = window_starts
|
| 1246 |
+
big_indices = window_starts + small_pool.shape[axis]
|
| 1247 |
+
|
| 1248 |
+
# Finally, we use the pool_selector to generate a list of indices, one per output position
|
| 1249 |
+
gather_indices = tf.where(pool_selector, big_indices, small_indices)
|
| 1250 |
+
|
| 1251 |
+
# Gathering from those indices yields the final, correct pooling
|
| 1252 |
+
return tf.gather(both_pool, gather_indices, axis=axis)
|
| 1253 |
+
|
| 1254 |
+
def call(self, inputs: tf.Tensor):
|
| 1255 |
+
if self.input_ordering == "NHWC":
|
| 1256 |
+
input_shape = inputs.shape[1:3]
|
| 1257 |
+
else:
|
| 1258 |
+
input_shape = inputs.shape[2:]
|
| 1259 |
+
|
| 1260 |
+
# We break the task down into each possible case
|
| 1261 |
+
# Firstly, if we're resizing down to 1, it's just tf.reduce_mean
|
| 1262 |
+
if self.output_dims[0] == self.output_dims[1] == 1:
|
| 1263 |
+
if self.input_ordering == "NHWC":
|
| 1264 |
+
reduce_dims = [1, 2]
|
| 1265 |
+
else:
|
| 1266 |
+
reduce_dims = [2, 3]
|
| 1267 |
+
return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True)
|
| 1268 |
+
# Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut
|
| 1269 |
+
elif input_shape[0] % self.output_dims[0] == 0 and input_shape[1] % self.output_dims[1] == 0:
|
| 1270 |
+
h_resize = int(input_shape[0] // self.output_dims[0])
|
| 1271 |
+
w_resize = int(input_shape[1] // self.output_dims[1])
|
| 1272 |
+
return tf.nn.avg_pool2d(
|
| 1273 |
+
inputs,
|
| 1274 |
+
ksize=(h_resize, w_resize),
|
| 1275 |
+
strides=(h_resize, w_resize),
|
| 1276 |
+
padding="VALID",
|
| 1277 |
+
data_format=self.input_ordering,
|
| 1278 |
+
)
|
| 1279 |
+
else:
|
| 1280 |
+
# Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut
|
| 1281 |
+
# for dimensions where an integer resize is possible. It can also handle upscaling.
|
| 1282 |
+
h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
|
| 1283 |
+
return self.pseudo_1d_pool(h_pooled, h_pooling=False)
|
| 1284 |
+
|
| 1285 |
+
|
| 1286 |
+
class TFData2VecVisionPyramidPoolingModule(keras.layers.Layer):
|
| 1287 |
+
"""
|
| 1288 |
+
Pyramid Pooling Module (PPM) used in PSPNet.
|
| 1289 |
+
|
| 1290 |
+
Args:
|
| 1291 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
| 1292 |
+
Module.
|
| 1293 |
+
channels (int): Channels after modules, before conv_seg.
|
| 1294 |
+
|
| 1295 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1296 |
+
"""
|
| 1297 |
+
|
| 1298 |
+
def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None:
|
| 1299 |
+
super().__init__(**kwargs)
|
| 1300 |
+
self.pool_scales = pool_scales
|
| 1301 |
+
self.in_channels = in_channels
|
| 1302 |
+
self.out_channels = out_channels
|
| 1303 |
+
|
| 1304 |
+
self.layer_list = []
|
| 1305 |
+
for idx, pool_scale in enumerate(pool_scales):
|
| 1306 |
+
pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
|
| 1307 |
+
self.layer_list.append(
|
| 1308 |
+
[
|
| 1309 |
+
TFAdaptiveAvgPool2D(output_dims=pool_scale),
|
| 1310 |
+
TFData2VecVisionConvModule(
|
| 1311 |
+
in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1"
|
| 1312 |
+
),
|
| 1313 |
+
]
|
| 1314 |
+
)
|
| 1315 |
+
|
| 1316 |
+
def call(self, x: tf.Tensor) -> List[tf.Tensor]:
|
| 1317 |
+
ppm_outs = []
|
| 1318 |
+
inputs = x
|
| 1319 |
+
|
| 1320 |
+
for ppm in self.layer_list:
|
| 1321 |
+
for layer_module in ppm:
|
| 1322 |
+
ppm_out = layer_module(x)
|
| 1323 |
+
x = ppm_out
|
| 1324 |
+
|
| 1325 |
+
upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear")
|
| 1326 |
+
ppm_outs.append(upsampled_ppm_out)
|
| 1327 |
+
return ppm_outs
|
| 1328 |
+
|
| 1329 |
+
def build(self, input_shape=None):
|
| 1330 |
+
for layer in self.layer_list:
|
| 1331 |
+
for layer_module in layer:
|
| 1332 |
+
with tf.name_scope(layer_module.name):
|
| 1333 |
+
layer_module.build(None)
|
| 1334 |
+
|
| 1335 |
+
|
| 1336 |
+
class TFData2VecVisionUperHead(keras.layers.Layer):
|
| 1337 |
+
"""
|
| 1338 |
+
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
|
| 1339 |
+
[UPerNet](https://arxiv.org/abs/1807.10221).
|
| 1340 |
+
|
| 1341 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1342 |
+
"""
|
| 1343 |
+
|
| 1344 |
+
def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
|
| 1345 |
+
super().__init__(**kwargs)
|
| 1346 |
+
|
| 1347 |
+
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
|
| 1348 |
+
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
|
| 1349 |
+
self.channels = config.hidden_size
|
| 1350 |
+
self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
|
| 1351 |
+
|
| 1352 |
+
# PSP Module
|
| 1353 |
+
self.psp_modules = TFData2VecVisionPyramidPoolingModule(
|
| 1354 |
+
self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules"
|
| 1355 |
+
)
|
| 1356 |
+
self.bottleneck = TFData2VecVisionConvModule(
|
| 1357 |
+
self.in_channels[-1] + len(self.pool_scales) * self.channels,
|
| 1358 |
+
self.channels,
|
| 1359 |
+
kernel_size=3,
|
| 1360 |
+
padding="same",
|
| 1361 |
+
name="bottleneck",
|
| 1362 |
+
)
|
| 1363 |
+
# FPN Module
|
| 1364 |
+
self.lateral_convs = []
|
| 1365 |
+
self.fpn_convs = []
|
| 1366 |
+
for idx, in_channels in enumerate(self.in_channels[:-1]): # skip the top layer
|
| 1367 |
+
l_conv = TFData2VecVisionConvModule(
|
| 1368 |
+
in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}"
|
| 1369 |
+
)
|
| 1370 |
+
fpn_conv = TFData2VecVisionConvModule(
|
| 1371 |
+
in_channels=self.channels,
|
| 1372 |
+
out_channels=self.channels,
|
| 1373 |
+
kernel_size=3,
|
| 1374 |
+
padding="same",
|
| 1375 |
+
name=f"fpn_convs.{idx}",
|
| 1376 |
+
)
|
| 1377 |
+
self.lateral_convs.append(l_conv)
|
| 1378 |
+
self.fpn_convs.append(fpn_conv)
|
| 1379 |
+
|
| 1380 |
+
self.fpn_bottleneck = TFData2VecVisionConvModule(
|
| 1381 |
+
in_channels=len(self.in_channels) * self.channels,
|
| 1382 |
+
out_channels=self.channels,
|
| 1383 |
+
kernel_size=3,
|
| 1384 |
+
padding="same",
|
| 1385 |
+
name="fpn_bottleneck",
|
| 1386 |
+
)
|
| 1387 |
+
|
| 1388 |
+
def psp_forward(self, inputs):
|
| 1389 |
+
x = inputs[-1]
|
| 1390 |
+
psp_outs = [x]
|
| 1391 |
+
psp_outs.extend(self.psp_modules(x))
|
| 1392 |
+
psp_outs = tf.concat(psp_outs, axis=-1)
|
| 1393 |
+
output = self.bottleneck(psp_outs)
|
| 1394 |
+
|
| 1395 |
+
return output
|
| 1396 |
+
|
| 1397 |
+
def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
|
| 1398 |
+
# build laterals
|
| 1399 |
+
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
|
| 1400 |
+
|
| 1401 |
+
laterals.append(self.psp_forward(encoder_hidden_states))
|
| 1402 |
+
|
| 1403 |
+
# build top-down path
|
| 1404 |
+
used_backbone_levels = len(laterals)
|
| 1405 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
| 1406 |
+
prev_shape = shape_list(laterals[i - 1])[1:-1]
|
| 1407 |
+
laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear")
|
| 1408 |
+
|
| 1409 |
+
# build outputs
|
| 1410 |
+
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
|
| 1411 |
+
# append psp feature
|
| 1412 |
+
fpn_outs.append(laterals[-1])
|
| 1413 |
+
|
| 1414 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
| 1415 |
+
fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear")
|
| 1416 |
+
fpn_outs = tf.concat(fpn_outs, axis=-1)
|
| 1417 |
+
output = self.fpn_bottleneck(fpn_outs)
|
| 1418 |
+
output = self.classifier(output)
|
| 1419 |
+
|
| 1420 |
+
return output
|
| 1421 |
+
|
| 1422 |
+
def build(self, input_shape=None):
|
| 1423 |
+
if self.built:
|
| 1424 |
+
return
|
| 1425 |
+
self.built = True
|
| 1426 |
+
if getattr(self, "classifier", None) is not None:
|
| 1427 |
+
with tf.name_scope(self.classifier.name):
|
| 1428 |
+
self.classifier.build([None, None, None, self.channels])
|
| 1429 |
+
if getattr(self, "psp_modules", None) is not None:
|
| 1430 |
+
with tf.name_scope(self.psp_modules.name):
|
| 1431 |
+
self.psp_modules.build(None)
|
| 1432 |
+
if getattr(self, "bottleneck", None) is not None:
|
| 1433 |
+
with tf.name_scope(self.bottleneck.name):
|
| 1434 |
+
self.bottleneck.build(None)
|
| 1435 |
+
if getattr(self, "fpn_bottleneck", None) is not None:
|
| 1436 |
+
with tf.name_scope(self.fpn_bottleneck.name):
|
| 1437 |
+
self.fpn_bottleneck.build(None)
|
| 1438 |
+
for layer in self.lateral_convs:
|
| 1439 |
+
with tf.name_scope(layer.name):
|
| 1440 |
+
layer.build(None)
|
| 1441 |
+
for layer in self.fpn_convs:
|
| 1442 |
+
with tf.name_scope(layer.name):
|
| 1443 |
+
layer.build(None)
|
| 1444 |
+
|
| 1445 |
+
|
| 1446 |
+
class TFData2VecVisionFCNHead(keras.layers.Layer):
|
| 1447 |
+
"""
|
| 1448 |
+
Fully Convolution Networks for Semantic Segmentation. This head is implemented from
|
| 1449 |
+
[FCNNet](https://arxiv.org/abs/1411.4038).
|
| 1450 |
+
|
| 1451 |
+
Args:
|
| 1452 |
+
config (Data2VecVisionConfig): Configuration.
|
| 1453 |
+
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
| 1454 |
+
dilation (int): The dilation rate for convs in the head. Default: 1.
|
| 1455 |
+
|
| 1456 |
+
|
| 1457 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
| 1458 |
+
"""
|
| 1459 |
+
|
| 1460 |
+
def __init__(
|
| 1461 |
+
self,
|
| 1462 |
+
config: Data2VecVisionConfig,
|
| 1463 |
+
in_index: int = 2,
|
| 1464 |
+
kernel_size: int = 3,
|
| 1465 |
+
dilation: Union[int, Tuple[int, int]] = 1,
|
| 1466 |
+
**kwargs,
|
| 1467 |
+
) -> None:
|
| 1468 |
+
super().__init__(**kwargs)
|
| 1469 |
+
self.in_channels = config.hidden_size
|
| 1470 |
+
self.channels = config.auxiliary_channels
|
| 1471 |
+
self.num_convs = config.auxiliary_num_convs
|
| 1472 |
+
self.concat_input = config.auxiliary_concat_input
|
| 1473 |
+
self.in_index = in_index
|
| 1474 |
+
|
| 1475 |
+
convs = []
|
| 1476 |
+
convs.append(
|
| 1477 |
+
TFData2VecVisionConvModule(
|
| 1478 |
+
in_channels=self.in_channels,
|
| 1479 |
+
out_channels=self.channels,
|
| 1480 |
+
kernel_size=kernel_size,
|
| 1481 |
+
padding="same",
|
| 1482 |
+
dilation=dilation,
|
| 1483 |
+
name="convs.0",
|
| 1484 |
+
)
|
| 1485 |
+
)
|
| 1486 |
+
for i in range(self.num_convs - 1):
|
| 1487 |
+
convs.append(
|
| 1488 |
+
TFData2VecVisionConvModule(
|
| 1489 |
+
in_channels=self.channels,
|
| 1490 |
+
out_channels=self.channels,
|
| 1491 |
+
kernel_size=kernel_size,
|
| 1492 |
+
padding="same",
|
| 1493 |
+
dilation=dilation,
|
| 1494 |
+
name=f"conv_module_{i + 2}",
|
| 1495 |
+
)
|
| 1496 |
+
)
|
| 1497 |
+
if self.num_convs == 0:
|
| 1498 |
+
self.convs = [tf.identity]
|
| 1499 |
+
else:
|
| 1500 |
+
self.convs = convs
|
| 1501 |
+
if self.concat_input:
|
| 1502 |
+
self.conv_cat = TFData2VecVisionConvModule(
|
| 1503 |
+
self.in_channels + self.channels,
|
| 1504 |
+
out_channels=self.channels,
|
| 1505 |
+
kernel_size=kernel_size,
|
| 1506 |
+
padding="same",
|
| 1507 |
+
name="conv_cat",
|
| 1508 |
+
)
|
| 1509 |
+
|
| 1510 |
+
self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
|
| 1511 |
+
|
| 1512 |
+
def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
|
| 1513 |
+
# just take the relevant feature maps
|
| 1514 |
+
hidden_states = encoder_hidden_states[self.in_index]
|
| 1515 |
+
output = hidden_states
|
| 1516 |
+
for layer_module in self.convs:
|
| 1517 |
+
output = layer_module(output)
|
| 1518 |
+
if self.concat_input:
|
| 1519 |
+
output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
|
| 1520 |
+
output = self.classifier(output)
|
| 1521 |
+
return output
|
| 1522 |
+
|
| 1523 |
+
def build(self, input_shape=None):
|
| 1524 |
+
if self.built:
|
| 1525 |
+
return
|
| 1526 |
+
self.built = True
|
| 1527 |
+
if getattr(self, "classifier", None) is not None:
|
| 1528 |
+
with tf.name_scope(self.classifier.name):
|
| 1529 |
+
self.classifier.build([None, None, None, self.channels])
|
| 1530 |
+
if getattr(self, "conv_cat", None) is not None:
|
| 1531 |
+
with tf.name_scope(self.conv_cat.name):
|
| 1532 |
+
self.conv_cat.build(None)
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
@add_start_docstrings(
|
| 1536 |
+
"""
|
| 1537 |
+
Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
|
| 1538 |
+
""",
|
| 1539 |
+
DATA2VEC_VISION_START_DOCSTRING,
|
| 1540 |
+
)
|
| 1541 |
+
class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
|
| 1542 |
+
def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
|
| 1543 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1544 |
+
self.num_labels = config.num_labels
|
| 1545 |
+
self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")
|
| 1546 |
+
|
| 1547 |
+
# FPNs
|
| 1548 |
+
self.fpn1 = [
|
| 1549 |
+
keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
|
| 1550 |
+
keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5),
|
| 1551 |
+
keras.layers.Activation("gelu"),
|
| 1552 |
+
keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
|
| 1553 |
+
]
|
| 1554 |
+
self.fpn2 = [keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]
|
| 1555 |
+
|
| 1556 |
+
self.fpn3 = tf.identity
|
| 1557 |
+
self.fpn4 = keras.layers.MaxPool2D(pool_size=2, strides=2)
|
| 1558 |
+
|
| 1559 |
+
# Semantic segmentation head(s)
|
| 1560 |
+
self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
|
| 1561 |
+
self.auxiliary_head = (
|
| 1562 |
+
TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
|
| 1563 |
+
)
|
| 1564 |
+
|
| 1565 |
+
def compute_loss(self, logits, auxiliary_logits, labels):
|
| 1566 |
+
# upsample logits to the images' original size
|
| 1567 |
+
if len(shape_list(labels)) > 3:
|
| 1568 |
+
label_interp_shape = shape_list(labels)[1:-1]
|
| 1569 |
+
else:
|
| 1570 |
+
label_interp_shape = shape_list(labels)[-2:]
|
| 1571 |
+
|
| 1572 |
+
upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
|
| 1573 |
+
if auxiliary_logits is not None:
|
| 1574 |
+
upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
|
| 1575 |
+
# compute weighted loss
|
| 1576 |
+
loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
|
| 1577 |
+
|
| 1578 |
+
# Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
|
| 1579 |
+
# Utility to mask the index to ignore during computing the loss.
|
| 1580 |
+
def masked_loss(real, pred):
|
| 1581 |
+
mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
|
| 1582 |
+
loss_ = loss_fct(real, pred)
|
| 1583 |
+
mask = tf.cast(mask, dtype=loss_.dtype)
|
| 1584 |
+
loss_ *= mask
|
| 1585 |
+
reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
|
| 1586 |
+
return tf.reshape(reduced_masked_loss, (1,))
|
| 1587 |
+
|
| 1588 |
+
main_loss = masked_loss(labels, upsampled_logits)
|
| 1589 |
+
auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
|
| 1590 |
+
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
|
| 1591 |
+
|
| 1592 |
+
return loss
|
| 1593 |
+
|
| 1594 |
+
@unpack_inputs
|
| 1595 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
|
| 1596 |
+
@replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
|
| 1597 |
+
def call(
|
| 1598 |
+
self,
|
| 1599 |
+
pixel_values: tf.Tensor | None = None,
|
| 1600 |
+
head_mask: tf.Tensor | None = None,
|
| 1601 |
+
labels: tf.Tensor | None = None,
|
| 1602 |
+
output_attentions: Optional[bool] = None,
|
| 1603 |
+
output_hidden_states: Optional[bool] = None,
|
| 1604 |
+
return_dict: Optional[bool] = None,
|
| 1605 |
+
) -> Union[tuple, TFSemanticSegmenterOutput]:
|
| 1606 |
+
r"""
|
| 1607 |
+
labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
|
| 1608 |
+
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
| 1609 |
+
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
| 1610 |
+
|
| 1611 |
+
Returns:
|
| 1612 |
+
|
| 1613 |
+
Examples:
|
| 1614 |
+
|
| 1615 |
+
```python
|
| 1616 |
+
>>> from transformers import AutoImageProcessor, TFData2VecVisionForSemanticSegmentation
|
| 1617 |
+
>>> from PIL import Image
|
| 1618 |
+
>>> import requests
|
| 1619 |
+
|
| 1620 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1621 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1622 |
+
|
| 1623 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
|
| 1624 |
+
>>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
|
| 1625 |
+
|
| 1626 |
+
>>> inputs = image_processor(images=image, return_tensors="pt")
|
| 1627 |
+
>>> outputs = model(**inputs)
|
| 1628 |
+
>>> # logits are of shape (batch_size, num_labels, height, width)
|
| 1629 |
+
>>> logits = outputs.logits
|
| 1630 |
+
```"""
|
| 1631 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1632 |
+
output_hidden_states = (
|
| 1633 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1634 |
+
)
|
| 1635 |
+
|
| 1636 |
+
if labels is not None and self.config.num_labels == 1:
|
| 1637 |
+
raise ValueError("The number of labels should be greater than one")
|
| 1638 |
+
|
| 1639 |
+
outputs = self.data2vec_vision(
|
| 1640 |
+
pixel_values,
|
| 1641 |
+
head_mask=head_mask,
|
| 1642 |
+
output_attentions=output_attentions,
|
| 1643 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
| 1644 |
+
return_dict=return_dict,
|
| 1645 |
+
)
|
| 1646 |
+
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
| 1647 |
+
|
| 1648 |
+
# only keep certain features, and reshape
|
| 1649 |
+
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
|
| 1650 |
+
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
|
| 1651 |
+
patch_resolution = self.config.image_size // self.config.patch_size
|
| 1652 |
+
|
| 1653 |
+
def reshape_features(x):
|
| 1654 |
+
# We do it this way so TF can always infer the non-batch dims at compile time
|
| 1655 |
+
x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size))
|
| 1656 |
+
return x
|
| 1657 |
+
|
| 1658 |
+
features = [reshape_features(x[:, 1:, :]) for x in features]
|
| 1659 |
+
|
| 1660 |
+
# apply FPNs
|
| 1661 |
+
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
|
| 1662 |
+
for module in ops[0]:
|
| 1663 |
+
features[0] = module(features[0])
|
| 1664 |
+
features[1] = ops[1][0](features[1])
|
| 1665 |
+
for i in range(len(features[2:])):
|
| 1666 |
+
features[i + 2] = ops[i + 2](features[i + 2])
|
| 1667 |
+
|
| 1668 |
+
logits = self.decode_head(features)
|
| 1669 |
+
# Tranpose the logits to maintain consistency in the output formats.
|
| 1670 |
+
transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])
|
| 1671 |
+
|
| 1672 |
+
auxiliary_logits = None
|
| 1673 |
+
if self.auxiliary_head is not None:
|
| 1674 |
+
auxiliary_logits = self.auxiliary_head(features)
|
| 1675 |
+
|
| 1676 |
+
loss = None
|
| 1677 |
+
if labels is not None:
|
| 1678 |
+
loss = self.compute_loss(logits, auxiliary_logits, labels)
|
| 1679 |
+
|
| 1680 |
+
if not return_dict:
|
| 1681 |
+
if output_hidden_states:
|
| 1682 |
+
output = (logits,) + outputs[1:]
|
| 1683 |
+
else:
|
| 1684 |
+
output = (logits,) + outputs[2:]
|
| 1685 |
+
return ((loss,) + output) if loss is not None else output
|
| 1686 |
+
|
| 1687 |
+
return TFSemanticSegmenterOutput(
|
| 1688 |
+
loss=loss,
|
| 1689 |
+
logits=transposed_logits,
|
| 1690 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
| 1691 |
+
attentions=outputs.attentions,
|
| 1692 |
+
)
|
| 1693 |
+
|
| 1694 |
+
def build(self, input_shape=None):
|
| 1695 |
+
if self.built:
|
| 1696 |
+
return
|
| 1697 |
+
self.built = True
|
| 1698 |
+
if getattr(self, "data2vec_vision", None) is not None:
|
| 1699 |
+
with tf.name_scope(self.data2vec_vision.name):
|
| 1700 |
+
self.data2vec_vision.build(None)
|
| 1701 |
+
if getattr(self, "decode_head", None) is not None:
|
| 1702 |
+
with tf.name_scope(self.decode_head.name):
|
| 1703 |
+
self.decode_head.build(None)
|
| 1704 |
+
if getattr(self, "auxiliary_head", None) is not None:
|
| 1705 |
+
with tf.name_scope(self.auxiliary_head.name):
|
| 1706 |
+
self.auxiliary_head.build(None)
|
| 1707 |
+
if getattr(self, "fpn1", None) is not None:
|
| 1708 |
+
with tf.name_scope(self.fpn1[0].name):
|
| 1709 |
+
self.fpn1[0].build([None, None, None, self.config.hidden_size])
|
| 1710 |
+
with tf.name_scope(self.fpn1[1].name):
|
| 1711 |
+
self.fpn1[1].build((None, None, None, self.config.hidden_size))
|
| 1712 |
+
with tf.name_scope(self.fpn1[3].name):
|
| 1713 |
+
self.fpn1[3].build([None, None, None, self.config.hidden_size])
|
| 1714 |
+
if getattr(self, "fpn2", None) is not None:
|
| 1715 |
+
with tf.name_scope(self.fpn2[0].name):
|
| 1716 |
+
self.fpn2[0].build([None, None, None, self.config.hidden_size])
|
| 1717 |
+
|
| 1718 |
+
|
| 1719 |
+
__all__ = [
|
| 1720 |
+
"TFData2VecVisionForImageClassification",
|
| 1721 |
+
"TFData2VecVisionForSemanticSegmentation",
|
| 1722 |
+
"TFData2VecVisionModel",
|
| 1723 |
+
"TFData2VecVisionPreTrainedModel",
|
| 1724 |
+
]
|
docs/transformers/build/lib/transformers/models/data2vec/modular_data2vec_audio.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from ...activations import ACT2FN
|
| 7 |
+
from ...modeling_outputs import (
|
| 8 |
+
CausalLMOutput,
|
| 9 |
+
SequenceClassifierOutput,
|
| 10 |
+
TokenClassifierOutput,
|
| 11 |
+
Wav2Vec2BaseModelOutput,
|
| 12 |
+
XVectorOutput,
|
| 13 |
+
)
|
| 14 |
+
from ...modeling_utils import PreTrainedModel
|
| 15 |
+
from ...utils import (
|
| 16 |
+
add_code_sample_docstrings,
|
| 17 |
+
add_start_docstrings,
|
| 18 |
+
add_start_docstrings_to_model_forward,
|
| 19 |
+
)
|
| 20 |
+
from ..wav2vec2.modeling_wav2vec2 import (
|
| 21 |
+
Wav2Vec2Adapter,
|
| 22 |
+
Wav2Vec2Encoder,
|
| 23 |
+
Wav2Vec2FeatureEncoder,
|
| 24 |
+
Wav2Vec2FeatureProjection,
|
| 25 |
+
Wav2Vec2ForAudioFrameClassification,
|
| 26 |
+
Wav2Vec2ForCTC,
|
| 27 |
+
Wav2Vec2ForSequenceClassification,
|
| 28 |
+
Wav2Vec2ForXVector,
|
| 29 |
+
Wav2Vec2Model,
|
| 30 |
+
Wav2Vec2PreTrainedModel,
|
| 31 |
+
Wav2Vec2SamePadLayer,
|
| 32 |
+
)
|
| 33 |
+
from .configuration_data2vec_audio import Data2VecAudioConfig
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 37 |
+
|
| 38 |
+
# General docstring
|
| 39 |
+
_CONFIG_FOR_DOC = "Data2VecAudioConfig"
|
| 40 |
+
|
| 41 |
+
# Base docstring
|
| 42 |
+
_CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
|
| 43 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
|
| 44 |
+
|
| 45 |
+
# CTC docstring
|
| 46 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 47 |
+
_CTC_EXPECTED_LOSS = 66.95
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Data2VecAudioConvLayer(nn.Module):
|
| 51 |
+
def __init__(self, config, layer_id=0):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 54 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 55 |
+
|
| 56 |
+
self.conv = nn.Conv1d(
|
| 57 |
+
self.in_conv_dim,
|
| 58 |
+
self.out_conv_dim,
|
| 59 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 60 |
+
stride=config.conv_stride[layer_id],
|
| 61 |
+
bias=config.conv_bias,
|
| 62 |
+
)
|
| 63 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 64 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 65 |
+
|
| 66 |
+
def forward(self, hidden_states):
|
| 67 |
+
hidden_states = self.conv(hidden_states)
|
| 68 |
+
|
| 69 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 70 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 71 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 72 |
+
|
| 73 |
+
hidden_states = self.activation(hidden_states)
|
| 74 |
+
return hidden_states
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer):
|
| 78 |
+
pass
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class Data2VecAudioPositionalConvLayer(nn.Module):
|
| 82 |
+
def __init__(self, config):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.conv = nn.Conv1d(
|
| 85 |
+
config.hidden_size,
|
| 86 |
+
config.hidden_size,
|
| 87 |
+
kernel_size=config.conv_pos_kernel_size,
|
| 88 |
+
padding=config.conv_pos_kernel_size // 2,
|
| 89 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
|
| 93 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 94 |
+
# no learnable parameters
|
| 95 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
|
| 96 |
+
|
| 97 |
+
def forward(self, hidden_states):
|
| 98 |
+
hidden_states = self.conv(hidden_states)
|
| 99 |
+
hidden_states = self.padding(hidden_states)
|
| 100 |
+
|
| 101 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 102 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 103 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 104 |
+
hidden_states = self.activation(hidden_states)
|
| 105 |
+
return hidden_states
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Data2VecAudioPositionalConvEmbedding(nn.Module):
|
| 109 |
+
def __init__(self, config):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.layers = nn.ModuleList(
|
| 112 |
+
[Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def forward(self, hidden_states):
|
| 116 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 117 |
+
for layer in self.layers:
|
| 118 |
+
hidden_states = layer(hidden_states)
|
| 119 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 120 |
+
return hidden_states
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder, nn.Module):
|
| 124 |
+
def __init__(self, config):
|
| 125 |
+
nn.Module.__init__()
|
| 126 |
+
self.conv_layers = nn.ModuleList(
|
| 127 |
+
[Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
|
| 128 |
+
)
|
| 129 |
+
self.gradient_checkpointing = False
|
| 130 |
+
self._requires_grad = True
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection):
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Data2VecAudioEncoder(Wav2Vec2Encoder):
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Data2VecAudioAdapter(Wav2Vec2Adapter):
|
| 142 |
+
pass
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
|
| 146 |
+
"""
|
| 147 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 148 |
+
models.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
config_class = Data2VecAudioConfig
|
| 152 |
+
base_model_prefix = "data2vec_audio"
|
| 153 |
+
main_input_name = "input_values"
|
| 154 |
+
supports_gradient_checkpointing = True
|
| 155 |
+
_supports_flash_attn_2 = True
|
| 156 |
+
_supports_sdpa = True
|
| 157 |
+
|
| 158 |
+
def _init_weights(self, module):
|
| 159 |
+
"""Initialize the weights"""
|
| 160 |
+
if isinstance(module, Data2VecAudioFeatureProjection):
|
| 161 |
+
k = math.sqrt(1 / module.projection.in_features)
|
| 162 |
+
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
| 163 |
+
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
| 164 |
+
elif isinstance(module, Data2VecAudioPositionalConvLayer):
|
| 165 |
+
nn.init.constant_(module.conv.bias, 0)
|
| 166 |
+
elif isinstance(module, nn.Linear):
|
| 167 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 168 |
+
|
| 169 |
+
if module.bias is not None:
|
| 170 |
+
module.bias.data.zero_()
|
| 171 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 172 |
+
if module.bias is not None:
|
| 173 |
+
module.bias.data.zero_()
|
| 174 |
+
if module.weight is not None:
|
| 175 |
+
module.weight.data.fill_(1.0)
|
| 176 |
+
elif isinstance(module, nn.Conv1d):
|
| 177 |
+
nn.init.kaiming_normal_(module.weight)
|
| 178 |
+
|
| 179 |
+
if module.bias is not None:
|
| 180 |
+
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
| 181 |
+
nn.init.uniform_(module.bias, a=-k, b=k)
|
| 182 |
+
|
| 183 |
+
def _get_adapters(self):
|
| 184 |
+
raise AttributeError("Not needed for Data2VecAudio")
|
| 185 |
+
|
| 186 |
+
def init_adapter_layers(self):
|
| 187 |
+
raise AttributeError("Not needed for Data2VecAudio")
|
| 188 |
+
|
| 189 |
+
def load_adapter(self):
|
| 190 |
+
raise AttributeError("Not needed for Data2VecAudio")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
DATA2VEC_AUDIO_START_DOCSTRING = r"""
|
| 194 |
+
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
|
| 195 |
+
Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
|
| 196 |
+
Michael Auli.
|
| 197 |
+
|
| 198 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 199 |
+
library implements for all its model (such as downloading or saving etc.).
|
| 200 |
+
|
| 201 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 202 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 203 |
+
behavior.
|
| 204 |
+
|
| 205 |
+
Parameters:
|
| 206 |
+
config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
|
| 207 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 208 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
|
| 212 |
+
Args:
|
| 213 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 214 |
+
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
|
| 215 |
+
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
| 216 |
+
soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
|
| 217 |
+
conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
|
| 218 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 219 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 220 |
+
1]`:
|
| 221 |
+
|
| 222 |
+
- 1 for tokens that are **not masked**,
|
| 223 |
+
- 0 for tokens that are **masked**.
|
| 224 |
+
|
| 225 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 226 |
+
|
| 227 |
+
<Tip warning={true}>
|
| 228 |
+
|
| 229 |
+
`attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
|
| 230 |
+
True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that that even with
|
| 231 |
+
`attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
|
| 232 |
+
because there are more than one convolutional layer in the positional encodings. For a more detailed
|
| 233 |
+
explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).
|
| 234 |
+
|
| 235 |
+
</Tip>
|
| 236 |
+
|
| 237 |
+
output_attentions (`bool`, *optional*):
|
| 238 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 239 |
+
tensors for more detail.
|
| 240 |
+
output_hidden_states (`bool`, *optional*):
|
| 241 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 242 |
+
more detail.
|
| 243 |
+
return_dict (`bool`, *optional*):
|
| 244 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@add_start_docstrings(
|
| 251 |
+
"The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
|
| 252 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 253 |
+
)
|
| 254 |
+
class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
|
| 255 |
+
def __init__(self, config: Data2VecAudioConfig):
|
| 256 |
+
Data2VecAudioPreTrainedModel.__init__(config)
|
| 257 |
+
self.config = config
|
| 258 |
+
self.feature_extractor = Data2VecAudioFeatureEncoder(config)
|
| 259 |
+
self.feature_projection = Data2VecAudioFeatureProjection(config)
|
| 260 |
+
|
| 261 |
+
# model only needs masking vector if mask prob is > 0.0
|
| 262 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 263 |
+
self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
|
| 264 |
+
|
| 265 |
+
self.encoder = Data2VecAudioEncoder(config)
|
| 266 |
+
|
| 267 |
+
self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
|
| 268 |
+
|
| 269 |
+
# Initialize weights and apply final processing
|
| 270 |
+
self.post_init()
|
| 271 |
+
|
| 272 |
+
def freeze_feature_extractor(self):
|
| 273 |
+
raise AttributeError("Not needed for Data2VecAudio")
|
| 274 |
+
|
| 275 |
+
def freeze_feature_encoder(self):
|
| 276 |
+
"""
|
| 277 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 278 |
+
not be updated during training.
|
| 279 |
+
"""
|
| 280 |
+
self.feature_extractor._freeze_parameters()
|
| 281 |
+
|
| 282 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 283 |
+
@add_code_sample_docstrings(
|
| 284 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 285 |
+
output_type=Data2VecAudioBaseModelOutput,
|
| 286 |
+
config_class=_CONFIG_FOR_DOC,
|
| 287 |
+
modality="audio",
|
| 288 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 289 |
+
)
|
| 290 |
+
def forward(self, **super_kwargs):
|
| 291 |
+
return super().forward(**super_kwargs)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@add_start_docstrings(
|
| 295 |
+
"""Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 296 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 297 |
+
)
|
| 298 |
+
class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
|
| 299 |
+
def __init__(self, config):
|
| 300 |
+
Data2VecAudioPreTrainedModel.__init__(config)
|
| 301 |
+
|
| 302 |
+
self.data2vec_audio = Data2VecAudioModel(config)
|
| 303 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
| 304 |
+
|
| 305 |
+
if config.vocab_size is None:
|
| 306 |
+
raise ValueError(
|
| 307 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 308 |
+
"does not define the vocabulary size of the language model head. Please "
|
| 309 |
+
"instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 310 |
+
"or define `vocab_size` of your model's configuration."
|
| 311 |
+
)
|
| 312 |
+
output_hidden_size = (
|
| 313 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 314 |
+
)
|
| 315 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 316 |
+
|
| 317 |
+
# Initialize weights and apply final processing
|
| 318 |
+
self.post_init()
|
| 319 |
+
|
| 320 |
+
def freeze_base_model(self):
|
| 321 |
+
raise AttributeError("Not needed for Data2VecAudio")
|
| 322 |
+
|
| 323 |
+
def tie_weights(self):
|
| 324 |
+
raise AttributeError("Not needed for Data2VecAudio")
|
| 325 |
+
|
| 326 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 327 |
+
@add_code_sample_docstrings(
|
| 328 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 329 |
+
output_type=CausalLMOutput,
|
| 330 |
+
config_class=_CONFIG_FOR_DOC,
|
| 331 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 332 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
| 333 |
+
)
|
| 334 |
+
def forward(self, **super_kwargs):
|
| 335 |
+
return super().forward(**super_kwargs)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@add_start_docstrings(
|
| 339 |
+
"""
|
| 340 |
+
Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
|
| 341 |
+
like SUPERB Keyword Spotting.
|
| 342 |
+
""",
|
| 343 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 344 |
+
)
|
| 345 |
+
class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification):
|
| 346 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 347 |
+
@add_code_sample_docstrings(
|
| 348 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 349 |
+
output_type=SequenceClassifierOutput,
|
| 350 |
+
config_class=_CONFIG_FOR_DOC,
|
| 351 |
+
modality="audio",
|
| 352 |
+
)
|
| 353 |
+
def forward(self, **super_kwargs):
|
| 354 |
+
return super().forward(**super_kwargs)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@add_start_docstrings(
|
| 358 |
+
"""
|
| 359 |
+
Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
|
| 360 |
+
""",
|
| 361 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 362 |
+
)
|
| 363 |
+
class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
|
| 364 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 365 |
+
@add_code_sample_docstrings(
|
| 366 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 367 |
+
output_type=TokenClassifierOutput,
|
| 368 |
+
config_class=_CONFIG_FOR_DOC,
|
| 369 |
+
modality="audio",
|
| 370 |
+
)
|
| 371 |
+
def forward(self, **super_kwargs):
|
| 372 |
+
return super().forward(**super_kwargs)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
@add_start_docstrings(
|
| 376 |
+
"""
|
| 377 |
+
Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
|
| 378 |
+
""",
|
| 379 |
+
DATA2VEC_AUDIO_START_DOCSTRING,
|
| 380 |
+
)
|
| 381 |
+
class Data2VecAudioForXVector(Wav2Vec2ForXVector):
|
| 382 |
+
@add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
|
| 383 |
+
@add_code_sample_docstrings(
|
| 384 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 385 |
+
output_type=XVectorOutput,
|
| 386 |
+
config_class=_CONFIG_FOR_DOC,
|
| 387 |
+
modality="audio",
|
| 388 |
+
)
|
| 389 |
+
def forward(self, **super_kwargs):
|
| 390 |
+
return super().forward(**super_kwargs)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
__all__ = [
|
| 394 |
+
"Data2VecAudioForAudioFrameClassification",
|
| 395 |
+
"Data2VecAudioForCTC",
|
| 396 |
+
"Data2VecAudioForSequenceClassification",
|
| 397 |
+
"Data2VecAudioForXVector",
|
| 398 |
+
"Data2VecAudioModel",
|
| 399 |
+
"Data2VecAudioPreTrainedModel",
|
| 400 |
+
]
|
docs/transformers/build/lib/transformers/models/dbrx/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_dbrx import *
|
| 22 |
+
from .modeling_dbrx import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/dbrx/configuration_dbrx.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DBRX model configuration"""
|
| 16 |
+
|
| 17 |
+
from typing import Any, Optional
|
| 18 |
+
|
| 19 |
+
from ...configuration_utils import PretrainedConfig
|
| 20 |
+
from ...utils import logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DbrxAttentionConfig(PretrainedConfig):
|
| 27 |
+
"""Configuration class for Dbrx Attention.
|
| 28 |
+
|
| 29 |
+
[`DbrxAttention`] class. It is used to instantiate attention layers
|
| 30 |
+
according to the specified arguments, defining the layers architecture.
|
| 31 |
+
|
| 32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 33 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
attn_pdrop (`float`, *optional*, defaults to 0.0):
|
| 37 |
+
The dropout probability for the attention layers.
|
| 38 |
+
clip_qkv (`float`, *optional*):
|
| 39 |
+
If set, clip the queries, keys, and values in the attention layer to this value.
|
| 40 |
+
kv_n_heads (`int`, *optional*, defaults to 1): For grouped_query_attention only, allow user to specify number of kv heads.
|
| 41 |
+
rope_theta (`float`, *optional*, defaults to 10000.0): The base frequency for rope.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
base_config_key = "attn_config"
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
attn_pdrop: float = 0.0,
|
| 49 |
+
clip_qkv: Optional[float] = None,
|
| 50 |
+
kv_n_heads: int = 1,
|
| 51 |
+
rope_theta: float = 10000.0,
|
| 52 |
+
**kwargs: Any,
|
| 53 |
+
):
|
| 54 |
+
super().__init__(**kwargs)
|
| 55 |
+
self.attn_pdrop = attn_pdrop
|
| 56 |
+
self.clip_qkv = clip_qkv
|
| 57 |
+
self.kv_n_heads = kv_n_heads
|
| 58 |
+
self.rope_theta = rope_theta
|
| 59 |
+
|
| 60 |
+
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
|
| 61 |
+
if k in kwargs:
|
| 62 |
+
kwargs.pop(k)
|
| 63 |
+
if len(kwargs) != 0:
|
| 64 |
+
raise ValueError(f"Found unknown {kwargs=}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class DbrxFFNConfig(PretrainedConfig):
|
| 68 |
+
"""Configuration class for Dbrx FFN.
|
| 69 |
+
|
| 70 |
+
[`DbrxFFN`] class. It is used to instantiate feedforward layers according to
|
| 71 |
+
the specified arguments, defining the layers architecture.
|
| 72 |
+
|
| 73 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 74 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying activation function for the FFN.
|
| 78 |
+
The dict should have a key 'name' with the value being the name of the activation function along with
|
| 79 |
+
any additional keyword arguments. If `None`, then set to `{"name": "silu"}`.
|
| 80 |
+
ffn_hidden_size (`int`, *optional*, defaults to 3584): The hidden size of the feedforward network.
|
| 81 |
+
moe_num_experts (`int`, *optional*, defaults to 4): The number of experts in the mixture of experts layer.
|
| 82 |
+
moe_top_k (`int`, *optional*, defaults to 1): The number of experts to use in the mixture of experts layer.
|
| 83 |
+
moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon for the mixture of experts layer.
|
| 84 |
+
moe_loss_weight (`float`, *optional*, defaults to 0.01): The loss weight for the mixture of experts layer.
|
| 85 |
+
moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
base_config_key = "ffn_config"
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
ffn_act_fn: dict = None,
|
| 93 |
+
ffn_hidden_size: int = 3584,
|
| 94 |
+
moe_num_experts: int = 4,
|
| 95 |
+
moe_top_k: int = 1,
|
| 96 |
+
moe_jitter_eps: Optional[float] = None,
|
| 97 |
+
moe_loss_weight: float = 0.01,
|
| 98 |
+
moe_normalize_expert_weights: Optional[float] = 1.0,
|
| 99 |
+
**kwargs: Any,
|
| 100 |
+
):
|
| 101 |
+
super().__init__()
|
| 102 |
+
if ffn_act_fn is None:
|
| 103 |
+
ffn_act_fn = {"name": "silu"}
|
| 104 |
+
self.ffn_act_fn = ffn_act_fn
|
| 105 |
+
self.ffn_hidden_size = ffn_hidden_size
|
| 106 |
+
self.moe_num_experts = moe_num_experts
|
| 107 |
+
self.moe_top_k = moe_top_k
|
| 108 |
+
self.moe_jitter_eps = moe_jitter_eps
|
| 109 |
+
self.moe_loss_weight = moe_loss_weight
|
| 110 |
+
self.moe_normalize_expert_weights = moe_normalize_expert_weights
|
| 111 |
+
|
| 112 |
+
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
|
| 113 |
+
if k in kwargs:
|
| 114 |
+
kwargs.pop(k)
|
| 115 |
+
if len(kwargs) != 0:
|
| 116 |
+
raise ValueError(f"Found unknown {kwargs=}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class DbrxConfig(PretrainedConfig):
|
| 120 |
+
r"""
|
| 121 |
+
|
| 122 |
+
This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
|
| 123 |
+
specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 124 |
+
defaults will yield a different configuration to that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.
|
| 125 |
+
|
| 126 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 127 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
d_model (`int`, *optional*, defaults to 2048):
|
| 132 |
+
Dimensionality of the embeddings and hidden states.
|
| 133 |
+
n_heads (`int`, *optional*, defaults to 16):
|
| 134 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 135 |
+
n_layers (`int`, *optional*, defaults to 24):
|
| 136 |
+
Number of hidden layers in the Transformer encoder.
|
| 137 |
+
max_seq_len (`int`, *optional*, defaults to 2048):
|
| 138 |
+
The maximum sequence length of the model.
|
| 139 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 140 |
+
Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
|
| 141 |
+
the `inputs_ids` passed when calling [`DbrxModel`].
|
| 142 |
+
resid_pdrop (`float`, *optional*, defaults to 0.0):
|
| 143 |
+
The dropout probability applied to the attention output before combining with residual.
|
| 144 |
+
emb_pdrop (`float`, *optional*, defaults to 0.0):
|
| 145 |
+
The dropout probability for the embedding layer.
|
| 146 |
+
attn_config (`dict`, *optional*):
|
| 147 |
+
A dictionary used to configure the model's attention module.
|
| 148 |
+
ffn_config (`dict`, *optional*):
|
| 149 |
+
A dictionary used to configure the model's FFN module.
|
| 150 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 151 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
| 152 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 153 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 154 |
+
output_router_logits (`bool`, *optional*, defaults to `False`):
|
| 155 |
+
Whether or not the router logits should be returned by the model. Enabling this will also
|
| 156 |
+
allow the model to output the auxiliary loss. See [here]() for more details.
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
Example:
|
| 160 |
+
```python
|
| 161 |
+
>>> from transformers import DbrxConfig, DbrxModel
|
| 162 |
+
|
| 163 |
+
>>> # Initializing a Dbrx configuration
|
| 164 |
+
>>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
|
| 165 |
+
|
| 166 |
+
>>> # Initializing a model (with random weights) from the configuration
|
| 167 |
+
>>> model = DbrxModel(configuration)
|
| 168 |
+
|
| 169 |
+
>>> # Accessing the model configuration
|
| 170 |
+
>>> configuration = model.config
|
| 171 |
+
```
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
model_type = "dbrx"
|
| 175 |
+
sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig}
|
| 176 |
+
attribute_map = {
|
| 177 |
+
"num_attention_heads": "n_heads",
|
| 178 |
+
"hidden_size": "d_model",
|
| 179 |
+
"num_hidden_layers": "n_layers",
|
| 180 |
+
"max_position_embeddings": "max_seq_len",
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
def __init__(
|
| 184 |
+
self,
|
| 185 |
+
d_model: int = 2048,
|
| 186 |
+
n_heads: int = 16,
|
| 187 |
+
n_layers: int = 24,
|
| 188 |
+
max_seq_len: int = 2048,
|
| 189 |
+
vocab_size: int = 32000,
|
| 190 |
+
resid_pdrop: float = 0.0,
|
| 191 |
+
emb_pdrop: float = 0.0,
|
| 192 |
+
attn_config: Optional[DbrxAttentionConfig] = None,
|
| 193 |
+
ffn_config: Optional[DbrxFFNConfig] = None,
|
| 194 |
+
use_cache: bool = True,
|
| 195 |
+
initializer_range: float = 0.02,
|
| 196 |
+
output_router_logits: bool = False,
|
| 197 |
+
**kwargs: Any,
|
| 198 |
+
):
|
| 199 |
+
if attn_config is None:
|
| 200 |
+
self.attn_config = DbrxAttentionConfig()
|
| 201 |
+
elif isinstance(attn_config, dict):
|
| 202 |
+
self.attn_config = DbrxAttentionConfig(**attn_config)
|
| 203 |
+
else:
|
| 204 |
+
self.attn_config = attn_config
|
| 205 |
+
|
| 206 |
+
if ffn_config is None:
|
| 207 |
+
self.ffn_config = DbrxFFNConfig()
|
| 208 |
+
elif isinstance(ffn_config, dict):
|
| 209 |
+
self.ffn_config = DbrxFFNConfig(**ffn_config)
|
| 210 |
+
else:
|
| 211 |
+
self.ffn_config = ffn_config
|
| 212 |
+
|
| 213 |
+
self.d_model = d_model
|
| 214 |
+
self.n_heads = n_heads
|
| 215 |
+
self.n_layers = n_layers
|
| 216 |
+
self.max_seq_len = max_seq_len
|
| 217 |
+
self.vocab_size = vocab_size
|
| 218 |
+
self.resid_pdrop = resid_pdrop
|
| 219 |
+
self.emb_pdrop = emb_pdrop
|
| 220 |
+
self.use_cache = use_cache
|
| 221 |
+
self.initializer_range = initializer_range
|
| 222 |
+
self.output_router_logits = output_router_logits
|
| 223 |
+
self.num_key_value_heads = self.attn_config.kv_n_heads
|
| 224 |
+
|
| 225 |
+
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
|
| 226 |
+
if tie_word_embeddings:
|
| 227 |
+
raise ValueError("tie_word_embeddings is not supported for DBRX models.")
|
| 228 |
+
|
| 229 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
__all__ = ["DbrxConfig"]
|
docs/transformers/build/lib/transformers/models/dbrx/modeling_dbrx.py
ADDED
|
@@ -0,0 +1,1392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DBRX model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...cache_utils import Cache, DynamicCache, StaticCache
|
| 26 |
+
from ...generation import GenerationMixin
|
| 27 |
+
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
| 28 |
+
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
| 29 |
+
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
| 30 |
+
from ...modeling_utils import PreTrainedModel
|
| 31 |
+
from ...utils import (
|
| 32 |
+
add_start_docstrings,
|
| 33 |
+
add_start_docstrings_to_model_forward,
|
| 34 |
+
is_torch_flex_attn_available,
|
| 35 |
+
logging,
|
| 36 |
+
replace_return_docstrings,
|
| 37 |
+
)
|
| 38 |
+
from .configuration_dbrx import DbrxConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if is_torch_flex_attn_available():
|
| 42 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 43 |
+
|
| 44 |
+
from ...integrations.flex_attention import make_flex_block_causal_mask
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_flash_attn_available():
|
| 48 |
+
from ...modeling_flash_attention_utils import _flash_attention_forward
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
_CONFIG_FOR_DOC = "DbrxConfig"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class DbrxRotaryEmbedding(nn.Module):
|
| 56 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 57 |
+
super().__init__()
|
| 58 |
+
|
| 59 |
+
self.dim = dim
|
| 60 |
+
self.max_position_embeddings = max_position_embeddings
|
| 61 |
+
self.base = base
|
| 62 |
+
|
| 63 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
|
| 64 |
+
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
|
| 67 |
+
def forward(self, x, position_ids, seq_len=None):
|
| 68 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 69 |
+
self.inv_freq.to(x.device)
|
| 70 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 71 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 72 |
+
# Force float32 since bfloat16 loses precision on long contexts
|
| 73 |
+
# See https://github.com/huggingface/transformers/pull/29285
|
| 74 |
+
device_type = x.device.type
|
| 75 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 76 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 77 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 78 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 79 |
+
cos = emb.cos()
|
| 80 |
+
sin = emb.sin()
|
| 81 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 85 |
+
def rotate_half(x):
|
| 86 |
+
"""Rotates half the hidden dims of the input."""
|
| 87 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 88 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 89 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 93 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 94 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
q (`torch.Tensor`): The query tensor.
|
| 98 |
+
k (`torch.Tensor`): The key tensor.
|
| 99 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 100 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 101 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 102 |
+
Deprecated and unused.
|
| 103 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 104 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 105 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 106 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 107 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 108 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 109 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 110 |
+
Returns:
|
| 111 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 112 |
+
"""
|
| 113 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 114 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 115 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 116 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 117 |
+
return q_embed, k_embed
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 121 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 122 |
+
"""
|
| 123 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 124 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 125 |
+
"""
|
| 126 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 127 |
+
if n_rep == 1:
|
| 128 |
+
return hidden_states
|
| 129 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 130 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_balancing_loss_func(
|
| 134 |
+
gate_logits: torch.Tensor,
|
| 135 |
+
num_experts: int,
|
| 136 |
+
top_k: int,
|
| 137 |
+
attention_mask: Optional[torch.Tensor],
|
| 138 |
+
) -> torch.Tensor:
|
| 139 |
+
r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
| 140 |
+
|
| 141 |
+
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
|
| 142 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
| 143 |
+
experts is too unbalanced.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
|
| 147 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
| 148 |
+
shape [batch_size X sequence_length, num_experts].
|
| 149 |
+
num_experts (`int`):
|
| 150 |
+
Number of experts.
|
| 151 |
+
top_k (`int`):
|
| 152 |
+
The number of experts each token is routed to.
|
| 153 |
+
attention_mask (`torch.Tensor`, *optional*):
|
| 154 |
+
The attention_mask used in forward function
|
| 155 |
+
shape [batch_size X sequence_length] if not None.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
The auxiliary loss.
|
| 159 |
+
"""
|
| 160 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
| 161 |
+
return torch.tensor(0.0)
|
| 162 |
+
|
| 163 |
+
if isinstance(gate_logits, tuple):
|
| 164 |
+
compute_device = gate_logits[0].device
|
| 165 |
+
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
| 166 |
+
|
| 167 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
| 168 |
+
|
| 169 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
| 170 |
+
|
| 171 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 172 |
+
|
| 173 |
+
if attention_mask is None:
|
| 174 |
+
# Compute the percentage of tokens routed to each experts
|
| 175 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 176 |
+
|
| 177 |
+
# Compute the average probability of routing to these experts
|
| 178 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 179 |
+
else:
|
| 180 |
+
batch_size, sequence_length = attention_mask.shape
|
| 181 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
| 182 |
+
|
| 183 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
| 184 |
+
expert_attention_mask = (
|
| 185 |
+
attention_mask[None, :, :, None, None]
|
| 186 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
| 187 |
+
.reshape(-1, top_k, num_experts)
|
| 188 |
+
.to(compute_device)
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Compute the percentage of tokens routed to each experts
|
| 192 |
+
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
| 193 |
+
expert_attention_mask, dim=0
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
| 197 |
+
router_per_expert_attention_mask = (
|
| 198 |
+
attention_mask[None, :, :, None]
|
| 199 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
| 200 |
+
.reshape(-1, num_experts)
|
| 201 |
+
.to(compute_device)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Compute the average probability of routing to these experts
|
| 205 |
+
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
| 206 |
+
router_per_expert_attention_mask, dim=0
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 210 |
+
return overall_loss * num_experts
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class DbrxAttention(nn.Module):
|
| 214 |
+
"""Multi-head self attention."""
|
| 215 |
+
|
| 216 |
+
def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.config = config
|
| 219 |
+
self.hidden_size = config.d_model
|
| 220 |
+
self.num_heads = config.n_heads
|
| 221 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 222 |
+
self.max_position_embeddings = config.max_seq_len
|
| 223 |
+
self.block_idx = block_idx
|
| 224 |
+
if block_idx is None:
|
| 225 |
+
logger.warning_once(
|
| 226 |
+
f"Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will "
|
| 227 |
+
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` "
|
| 228 |
+
+ "when creating this class."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
attn_config = config.attn_config
|
| 232 |
+
self.attn_pdrop = attn_config.attn_pdrop
|
| 233 |
+
self.clip_qkv = attn_config.clip_qkv
|
| 234 |
+
self.num_key_value_heads = attn_config.kv_n_heads
|
| 235 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 236 |
+
self.rope_theta = attn_config.rope_theta
|
| 237 |
+
self.is_causal = True
|
| 238 |
+
|
| 239 |
+
self.Wqkv = nn.Linear(
|
| 240 |
+
self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
|
| 241 |
+
)
|
| 242 |
+
self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
| 243 |
+
self.rotary_emb = DbrxRotaryEmbedding(
|
| 244 |
+
self.head_dim,
|
| 245 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 246 |
+
base=self.rope_theta,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
hidden_states: torch.Tensor,
|
| 252 |
+
position_ids: torch.LongTensor,
|
| 253 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 254 |
+
past_key_value: Optional[Cache] = None,
|
| 255 |
+
output_attentions: bool = False,
|
| 256 |
+
use_cache: bool = False,
|
| 257 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 258 |
+
**kwargs: Any,
|
| 259 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
| 260 |
+
bsz, q_len, _ = hidden_states.size()
|
| 261 |
+
|
| 262 |
+
qkv_states = self.Wqkv(hidden_states)
|
| 263 |
+
min_val = -self.clip_qkv if self.clip_qkv is not None else None
|
| 264 |
+
max_val = self.clip_qkv
|
| 265 |
+
qkv_states = qkv_states.clamp(min=min_val, max=max_val)
|
| 266 |
+
|
| 267 |
+
query_states, key_states, value_states = qkv_states.split(
|
| 268 |
+
[
|
| 269 |
+
self.hidden_size,
|
| 270 |
+
self.num_key_value_heads * self.head_dim,
|
| 271 |
+
self.num_key_value_heads * self.head_dim,
|
| 272 |
+
],
|
| 273 |
+
dim=2,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 277 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 278 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 279 |
+
|
| 280 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 281 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 282 |
+
|
| 283 |
+
if past_key_value is not None:
|
| 284 |
+
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
| 285 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 286 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
|
| 287 |
+
|
| 288 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 289 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 290 |
+
|
| 291 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 292 |
+
|
| 293 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 294 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 295 |
+
attn_weights = attn_weights + causal_mask
|
| 296 |
+
|
| 297 |
+
# upcast attention to fp32
|
| 298 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 299 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
|
| 300 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 301 |
+
|
| 302 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 303 |
+
raise ValueError(
|
| 304 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 305 |
+
+ f" {attn_output.size()}"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 309 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 310 |
+
attn_output = self.out_proj(attn_output)
|
| 311 |
+
|
| 312 |
+
if not output_attentions:
|
| 313 |
+
attn_weights = None
|
| 314 |
+
|
| 315 |
+
return attn_output, attn_weights, past_key_value
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class DbrxFlashAttention2(DbrxAttention):
|
| 319 |
+
"""Dbrx flash attention module.
|
| 320 |
+
|
| 321 |
+
This module inherits from `DbrxAttention` as the weights of the module stays
|
| 322 |
+
untouched. The only required change would be on the forward pass where it
|
| 323 |
+
calls the public API of flash attention.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
def __init__(self, *args, **kwargs):
|
| 327 |
+
super().__init__(*args, **kwargs)
|
| 328 |
+
|
| 329 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 330 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 331 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 332 |
+
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
| 333 |
+
|
| 334 |
+
def forward(
|
| 335 |
+
self,
|
| 336 |
+
hidden_states: torch.Tensor,
|
| 337 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 338 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 339 |
+
past_key_value: Optional[Cache] = None,
|
| 340 |
+
output_attentions: bool = False,
|
| 341 |
+
use_cache: bool = False,
|
| 342 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 343 |
+
**kwargs: Any,
|
| 344 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 345 |
+
if isinstance(past_key_value, StaticCache):
|
| 346 |
+
raise ValueError(
|
| 347 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
| 348 |
+
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
| 349 |
+
)
|
| 350 |
+
logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
|
| 351 |
+
output_attentions = False
|
| 352 |
+
|
| 353 |
+
bsz, q_len, _ = hidden_states.size()
|
| 354 |
+
|
| 355 |
+
qkv_states = self.Wqkv(hidden_states)
|
| 356 |
+
if self.clip_qkv is not None:
|
| 357 |
+
qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
| 358 |
+
|
| 359 |
+
query_states, key_states, value_states = qkv_states.split(
|
| 360 |
+
[
|
| 361 |
+
self.hidden_size,
|
| 362 |
+
self.num_key_value_heads * self.head_dim,
|
| 363 |
+
self.num_key_value_heads * self.head_dim,
|
| 364 |
+
],
|
| 365 |
+
dim=2,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Flash attention requires the input to have the shape
|
| 369 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 370 |
+
# therefore we just need to keep the original shape
|
| 371 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 372 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 373 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 374 |
+
|
| 375 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 376 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 377 |
+
|
| 378 |
+
if past_key_value is not None:
|
| 379 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 380 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 381 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
|
| 382 |
+
|
| 383 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout
|
| 384 |
+
# [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 385 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 386 |
+
query_states = query_states.transpose(1, 2)
|
| 387 |
+
key_states = key_states.transpose(1, 2)
|
| 388 |
+
value_states = value_states.transpose(1, 2)
|
| 389 |
+
|
| 390 |
+
dropout_rate = self.attn_pdrop if self.training else 0.0
|
| 391 |
+
|
| 392 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 393 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 394 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 395 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 396 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
| 397 |
+
input_dtype = query_states.dtype
|
| 398 |
+
if input_dtype == torch.float32:
|
| 399 |
+
if torch.is_autocast_enabled():
|
| 400 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 401 |
+
# Handle the case where the model is quantized
|
| 402 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 403 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 404 |
+
else:
|
| 405 |
+
target_dtype = query_states.dtype
|
| 406 |
+
|
| 407 |
+
logger.warning_once(
|
| 408 |
+
"The input hidden states seems to be silently casted in float32, this might be "
|
| 409 |
+
+ "related to the fact you have upcasted embedding or layer norm layers in "
|
| 410 |
+
+ f"float32. We will cast back the input in {target_dtype}."
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
query_states = query_states.to(target_dtype)
|
| 414 |
+
key_states = key_states.to(target_dtype)
|
| 415 |
+
value_states = value_states.to(target_dtype)
|
| 416 |
+
|
| 417 |
+
attn_output = _flash_attention_forward(
|
| 418 |
+
query_states,
|
| 419 |
+
key_states,
|
| 420 |
+
value_states,
|
| 421 |
+
attention_mask,
|
| 422 |
+
q_len,
|
| 423 |
+
position_ids=position_ids,
|
| 424 |
+
dropout=dropout_rate,
|
| 425 |
+
is_causal=self.is_causal,
|
| 426 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 430 |
+
attn_output = self.out_proj(attn_output)
|
| 431 |
+
|
| 432 |
+
if not output_attentions:
|
| 433 |
+
attn_weights = None
|
| 434 |
+
|
| 435 |
+
return attn_output, attn_weights, past_key_value
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class DbrxSdpaAttention(DbrxAttention):
|
| 439 |
+
"""
|
| 440 |
+
Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 441 |
+
`DbrxAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 442 |
+
SDPA API.
|
| 443 |
+
"""
|
| 444 |
+
|
| 445 |
+
def forward(
|
| 446 |
+
self,
|
| 447 |
+
hidden_states: torch.Tensor,
|
| 448 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 449 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 450 |
+
past_key_value: Optional[Cache] = None,
|
| 451 |
+
output_attentions: bool = False,
|
| 452 |
+
use_cache: bool = False,
|
| 453 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 454 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 455 |
+
if output_attentions:
|
| 456 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 457 |
+
logger.warning_once(
|
| 458 |
+
"DbrxModel is using DbrxSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 459 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 460 |
+
)
|
| 461 |
+
return super().forward(
|
| 462 |
+
hidden_states=hidden_states,
|
| 463 |
+
attention_mask=attention_mask,
|
| 464 |
+
position_ids=position_ids,
|
| 465 |
+
past_key_value=past_key_value,
|
| 466 |
+
output_attentions=output_attentions,
|
| 467 |
+
use_cache=use_cache,
|
| 468 |
+
cache_position=cache_position,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
bsz, q_len, _ = hidden_states.size()
|
| 472 |
+
|
| 473 |
+
qkv_states = self.Wqkv(hidden_states)
|
| 474 |
+
if self.clip_qkv is not None:
|
| 475 |
+
qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
| 476 |
+
|
| 477 |
+
query_states, key_states, value_states = qkv_states.split(
|
| 478 |
+
[
|
| 479 |
+
self.hidden_size,
|
| 480 |
+
self.num_key_value_heads * self.head_dim,
|
| 481 |
+
self.num_key_value_heads * self.head_dim,
|
| 482 |
+
],
|
| 483 |
+
dim=2,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 487 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 488 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 489 |
+
|
| 490 |
+
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
| 491 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
| 492 |
+
|
| 493 |
+
if past_key_value is not None:
|
| 494 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 495 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 496 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
|
| 497 |
+
|
| 498 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 499 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 500 |
+
|
| 501 |
+
causal_mask = attention_mask
|
| 502 |
+
if attention_mask is not None:
|
| 503 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
| 504 |
+
|
| 505 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 506 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 507 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 508 |
+
query_states = query_states.contiguous()
|
| 509 |
+
key_states = key_states.contiguous()
|
| 510 |
+
value_states = value_states.contiguous()
|
| 511 |
+
|
| 512 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 513 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 514 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 515 |
+
|
| 516 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 517 |
+
query_states,
|
| 518 |
+
key_states,
|
| 519 |
+
value_states,
|
| 520 |
+
attn_mask=causal_mask,
|
| 521 |
+
dropout_p=self.attn_pdrop if self.training else 0.0,
|
| 522 |
+
is_causal=is_causal,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 526 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
| 527 |
+
|
| 528 |
+
attn_output = self.out_proj(attn_output)
|
| 529 |
+
|
| 530 |
+
return attn_output, None, past_key_value
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
DBRX_ATTENTION_CLASSES = {
|
| 534 |
+
"eager": DbrxAttention,
|
| 535 |
+
"flash_attention_2": DbrxFlashAttention2,
|
| 536 |
+
"sdpa": DbrxSdpaAttention,
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class DbrxNormAttentionNorm(nn.Module):
|
| 541 |
+
def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
|
| 542 |
+
super().__init__()
|
| 543 |
+
self.block_idx = block_idx
|
| 544 |
+
self.resid_pdrop = config.resid_pdrop
|
| 545 |
+
self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
|
| 546 |
+
self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
|
| 547 |
+
config=config,
|
| 548 |
+
block_idx=block_idx,
|
| 549 |
+
)
|
| 550 |
+
self.norm_2 = nn.LayerNorm(config.d_model, bias=False)
|
| 551 |
+
|
| 552 |
+
def forward(
|
| 553 |
+
self,
|
| 554 |
+
hidden_states: torch.Tensor,
|
| 555 |
+
position_ids: torch.LongTensor,
|
| 556 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 557 |
+
past_key_value: Optional[Cache] = None,
|
| 558 |
+
output_attentions: bool = False,
|
| 559 |
+
use_cache: bool = False,
|
| 560 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 561 |
+
**kwargs: Any,
|
| 562 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
| 563 |
+
residual_states = hidden_states
|
| 564 |
+
hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
|
| 565 |
+
|
| 566 |
+
hidden_states, attn_weights, past_key_value = self.attn(
|
| 567 |
+
hidden_states=hidden_states,
|
| 568 |
+
attention_mask=attention_mask,
|
| 569 |
+
position_ids=position_ids,
|
| 570 |
+
past_key_value=past_key_value,
|
| 571 |
+
output_attentions=output_attentions,
|
| 572 |
+
use_cache=use_cache,
|
| 573 |
+
cache_position=cache_position,
|
| 574 |
+
**kwargs,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
|
| 578 |
+
hidden_states = hidden_states + residual_states
|
| 579 |
+
|
| 580 |
+
residual_states = hidden_states
|
| 581 |
+
hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
|
| 582 |
+
|
| 583 |
+
return residual_states, hidden_states, attn_weights, past_key_value
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class DbrxRouter(nn.Module):
|
| 587 |
+
def __init__(
|
| 588 |
+
self,
|
| 589 |
+
hidden_size: int,
|
| 590 |
+
moe_num_experts: int,
|
| 591 |
+
moe_top_k: int,
|
| 592 |
+
moe_jitter_eps: Optional[float],
|
| 593 |
+
moe_normalize_expert_weights: Optional[float],
|
| 594 |
+
):
|
| 595 |
+
super().__init__()
|
| 596 |
+
self.hidden_size = hidden_size
|
| 597 |
+
self.moe_num_experts = moe_num_experts
|
| 598 |
+
self.moe_top_k = moe_top_k
|
| 599 |
+
self.moe_jitter_eps = moe_jitter_eps
|
| 600 |
+
self.moe_normalize_expert_weights = moe_normalize_expert_weights
|
| 601 |
+
|
| 602 |
+
self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
|
| 603 |
+
|
| 604 |
+
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
|
| 605 |
+
if self.training and self.moe_jitter_eps is not None:
|
| 606 |
+
hidden_states *= torch.empty_like(hidden_states).uniform_(
|
| 607 |
+
1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
|
| 608 |
+
)
|
| 609 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 610 |
+
weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
|
| 611 |
+
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
|
| 612 |
+
|
| 613 |
+
top_weights_scale = (
|
| 614 |
+
torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
|
| 615 |
+
if self.moe_normalize_expert_weights is not None
|
| 616 |
+
else 1.0
|
| 617 |
+
)
|
| 618 |
+
top_weights = top_weights / top_weights_scale
|
| 619 |
+
|
| 620 |
+
weights = weights.to(hidden_states.dtype)
|
| 621 |
+
top_weights = top_weights.to(hidden_states.dtype)
|
| 622 |
+
return weights, top_weights, top_experts
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class DbrxExpertGLU(nn.Module):
|
| 626 |
+
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
|
| 627 |
+
super().__init__()
|
| 628 |
+
self.hidden_size = hidden_size
|
| 629 |
+
self.ffn_hidden_size = ffn_hidden_size
|
| 630 |
+
self.moe_num_experts = moe_num_experts
|
| 631 |
+
|
| 632 |
+
self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
| 633 |
+
self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
| 634 |
+
self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
|
| 635 |
+
|
| 636 |
+
act_fn_name = ffn_act_fn.get("name", "silu")
|
| 637 |
+
self.activation_fn = ACT2FN[act_fn_name]
|
| 638 |
+
|
| 639 |
+
def forward(
|
| 640 |
+
self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
|
| 641 |
+
) -> torch.Tensor:
|
| 642 |
+
gate_proj = x.matmul(expert_w1.t())
|
| 643 |
+
up_proj = x.matmul(expert_v1.t())
|
| 644 |
+
gate_proj = self.activation_fn(gate_proj)
|
| 645 |
+
intermediate_states = gate_proj * up_proj
|
| 646 |
+
down_proj = intermediate_states.matmul(expert_w2)
|
| 647 |
+
return down_proj
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
class DbrxExperts(nn.Module):
|
| 651 |
+
def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
|
| 652 |
+
super().__init__()
|
| 653 |
+
self.moe_num_experts = moe_num_experts
|
| 654 |
+
self.mlp = DbrxExpertGLU(
|
| 655 |
+
hidden_size=hidden_size,
|
| 656 |
+
ffn_hidden_size=ffn_hidden_size,
|
| 657 |
+
moe_num_experts=moe_num_experts,
|
| 658 |
+
ffn_act_fn=ffn_act_fn,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
def forward(
|
| 662 |
+
self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
|
| 663 |
+
) -> torch.Tensor:
|
| 664 |
+
bsz, q_len, hidden_size = x.shape
|
| 665 |
+
x = x.view(-1, hidden_size)
|
| 666 |
+
out = torch.zeros_like(x)
|
| 667 |
+
|
| 668 |
+
expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
|
| 669 |
+
# Chunk experts at once to avoid storing full parameter multiple times in autograd
|
| 670 |
+
w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
|
| 671 |
+
self.moe_num_experts, dim=0
|
| 672 |
+
)
|
| 673 |
+
v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
|
| 674 |
+
self.moe_num_experts, dim=0
|
| 675 |
+
)
|
| 676 |
+
w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
|
| 677 |
+
self.moe_num_experts, dim=0
|
| 678 |
+
)
|
| 679 |
+
w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
|
| 680 |
+
v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
|
| 681 |
+
w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
|
| 682 |
+
for expert_idx in range(0, self.moe_num_experts):
|
| 683 |
+
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
|
| 684 |
+
# (set torch._dynamo.config.capture_dynamic_output_shape_ops = True may help but not tested)
|
| 685 |
+
topk_idx, token_idx = torch.where(expert_mask[expert_idx])
|
| 686 |
+
if token_idx.shape[0] == 0:
|
| 687 |
+
continue
|
| 688 |
+
|
| 689 |
+
token_list = token_idx
|
| 690 |
+
topk_list = topk_idx
|
| 691 |
+
|
| 692 |
+
expert_tokens = x[None, token_list].reshape(-1, hidden_size)
|
| 693 |
+
expert_out = (
|
| 694 |
+
self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
|
| 695 |
+
* top_weights[token_list, topk_list, None]
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
out.index_add_(0, token_idx, expert_out)
|
| 699 |
+
|
| 700 |
+
out = out.reshape(bsz, q_len, hidden_size)
|
| 701 |
+
return out
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
class DbrxFFN(nn.Module):
|
| 705 |
+
def __init__(self, config: DbrxConfig):
|
| 706 |
+
super().__init__()
|
| 707 |
+
|
| 708 |
+
ffn_config = config.ffn_config
|
| 709 |
+
self.router = DbrxRouter(
|
| 710 |
+
hidden_size=config.d_model,
|
| 711 |
+
moe_num_experts=ffn_config.moe_num_experts,
|
| 712 |
+
moe_top_k=ffn_config.moe_top_k,
|
| 713 |
+
moe_jitter_eps=ffn_config.moe_jitter_eps,
|
| 714 |
+
moe_normalize_expert_weights=ffn_config.moe_normalize_expert_weights,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
self.experts = DbrxExperts(
|
| 718 |
+
hidden_size=config.d_model,
|
| 719 |
+
ffn_hidden_size=ffn_config.ffn_hidden_size,
|
| 720 |
+
moe_num_experts=ffn_config.moe_num_experts,
|
| 721 |
+
ffn_act_fn=ffn_config.ffn_act_fn,
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 725 |
+
weights, top_weights, top_experts = self.router(x)
|
| 726 |
+
out = self.experts(x, weights, top_weights, top_experts)
|
| 727 |
+
return out, weights
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
class DbrxBlock(nn.Module):
|
| 731 |
+
def __init__(self, config: DbrxConfig, block_idx: int):
|
| 732 |
+
super().__init__()
|
| 733 |
+
self.hidden_size = config.d_model
|
| 734 |
+
self.resid_pdrop = config.resid_pdrop
|
| 735 |
+
self.block_idx = block_idx
|
| 736 |
+
self.norm_attn_norm = DbrxNormAttentionNorm(
|
| 737 |
+
config=config,
|
| 738 |
+
block_idx=block_idx,
|
| 739 |
+
)
|
| 740 |
+
self.ffn = DbrxFFN(config=config)
|
| 741 |
+
|
| 742 |
+
def forward(
|
| 743 |
+
self,
|
| 744 |
+
hidden_states: torch.Tensor,
|
| 745 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 746 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 747 |
+
past_key_value: Optional[Cache] = None,
|
| 748 |
+
output_attentions: Optional[bool] = False,
|
| 749 |
+
output_router_logits: Optional[bool] = False,
|
| 750 |
+
use_cache: Optional[bool] = False,
|
| 751 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 752 |
+
**kwargs: Any,
|
| 753 |
+
) -> Union[
|
| 754 |
+
Tuple[torch.Tensor],
|
| 755 |
+
Tuple[torch.Tensor, Optional[torch.Tensor]],
|
| 756 |
+
Tuple[torch.Tensor, Optional[Cache]],
|
| 757 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
|
| 758 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
|
| 759 |
+
Tuple[torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
|
| 760 |
+
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor]],
|
| 761 |
+
]:
|
| 762 |
+
"""Forward function for DbrxBlock.
|
| 763 |
+
|
| 764 |
+
Args:
|
| 765 |
+
hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 766 |
+
position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
|
| 767 |
+
attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length)
|
| 768 |
+
if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
|
| 769 |
+
if default attention is used.
|
| 770 |
+
past_key_value (`Tuple(torch.Tensor)`, *optional*): cached past key and value projection states
|
| 771 |
+
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all
|
| 772 |
+
attention layers. See `attentions` under returned tensors for more detail.
|
| 773 |
+
output_router_logits (`bool`, *optional*): Whether or not to return the router logits.
|
| 774 |
+
use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are
|
| 775 |
+
returned and can be used to speed up decoding (see `past_key_values`).
|
| 776 |
+
cache_position (`torch.LongTensor`, *optional*): position ids of the cache
|
| 777 |
+
"""
|
| 778 |
+
|
| 779 |
+
# Norm + Attention + Norm
|
| 780 |
+
resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
|
| 781 |
+
hidden_states=hidden_states,
|
| 782 |
+
attention_mask=attention_mask,
|
| 783 |
+
position_ids=position_ids,
|
| 784 |
+
past_key_value=past_key_value,
|
| 785 |
+
output_attentions=output_attentions,
|
| 786 |
+
use_cache=use_cache,
|
| 787 |
+
cache_position=cache_position,
|
| 788 |
+
**kwargs,
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
# Fully Connected
|
| 792 |
+
hidden_states, router_logits = self.ffn(hidden_states)
|
| 793 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
|
| 794 |
+
hidden_states = resid_states + hidden_states
|
| 795 |
+
|
| 796 |
+
outputs = (hidden_states,)
|
| 797 |
+
|
| 798 |
+
if output_attentions:
|
| 799 |
+
outputs += (self_attn_weights,)
|
| 800 |
+
|
| 801 |
+
if use_cache:
|
| 802 |
+
outputs += (present_key_value,)
|
| 803 |
+
|
| 804 |
+
if output_router_logits:
|
| 805 |
+
outputs += (router_logits,)
|
| 806 |
+
|
| 807 |
+
return outputs
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
DBRX_START_DOCSTRING = r"""
|
| 811 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 812 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 813 |
+
etc.)
|
| 814 |
+
|
| 815 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 816 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 817 |
+
and behavior.
|
| 818 |
+
|
| 819 |
+
Parameters:
|
| 820 |
+
config ([`DbrxConfig`]):
|
| 821 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 822 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 823 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 824 |
+
"""
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
@add_start_docstrings(
|
| 828 |
+
"The bare DBRX Model outputting raw hidden-states without any specific head on top.",
|
| 829 |
+
DBRX_START_DOCSTRING,
|
| 830 |
+
)
|
| 831 |
+
class DbrxPreTrainedModel(PreTrainedModel):
|
| 832 |
+
config_class = DbrxConfig
|
| 833 |
+
base_model_prefix = "transformer"
|
| 834 |
+
supports_gradient_checkpointing = True
|
| 835 |
+
_no_split_modules = ["DbrxBlock"]
|
| 836 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 837 |
+
_supports_flash_attn_2 = True
|
| 838 |
+
_supports_sdpa = True
|
| 839 |
+
_supports_cache_class = True
|
| 840 |
+
_supports_quantized_cache = True
|
| 841 |
+
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
| 842 |
+
|
| 843 |
+
def _init_weights(self, module: nn.Module):
|
| 844 |
+
std = self.config.initializer_range
|
| 845 |
+
if isinstance(module, nn.Linear):
|
| 846 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 847 |
+
if module.bias is not None:
|
| 848 |
+
module.bias.data.zero_()
|
| 849 |
+
elif isinstance(module, nn.Embedding):
|
| 850 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 851 |
+
if module.padding_idx is not None:
|
| 852 |
+
module.weight.data[module.padding_idx].zero_()
|
| 853 |
+
elif isinstance(module, nn.LayerNorm):
|
| 854 |
+
module.weight.data.fill_(1.0)
|
| 855 |
+
if module.bias is not None:
|
| 856 |
+
module.bias.data.zero_()
|
| 857 |
+
elif isinstance(module, DbrxExpertGLU):
|
| 858 |
+
module.w1.data.normal_(mean=0.0, std=std)
|
| 859 |
+
module.v1.data.normal_(mean=0.0, std=std)
|
| 860 |
+
module.w2.data.normal_(mean=0.0, std=std)
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
DBRX_INPUTS_DOCSTRING = r"""
|
| 864 |
+
Args:
|
| 865 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 866 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 867 |
+
it.
|
| 868 |
+
|
| 869 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 870 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 871 |
+
|
| 872 |
+
[What are input IDs?](../glossary#input-ids)
|
| 873 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 874 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 875 |
+
|
| 876 |
+
- 1 for tokens that are **not masked**,
|
| 877 |
+
- 0 for tokens that are **masked**.
|
| 878 |
+
|
| 879 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 880 |
+
|
| 881 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 882 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 883 |
+
|
| 884 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
| 885 |
+
`past_key_values`).
|
| 886 |
+
|
| 887 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 888 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 889 |
+
information on the default strategy.
|
| 890 |
+
|
| 891 |
+
- 1 indicates the head is **not masked**,
|
| 892 |
+
- 0 indicates the head is **masked**.
|
| 893 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 894 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 895 |
+
config.n_positions - 1]`.
|
| 896 |
+
|
| 897 |
+
[What are position IDs?](../glossary#position-ids)
|
| 898 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 899 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 900 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 901 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 902 |
+
|
| 903 |
+
Two formats are allowed:
|
| 904 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 905 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 906 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 907 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 908 |
+
cache format.
|
| 909 |
+
|
| 910 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 911 |
+
legacy cache format will be returned.
|
| 912 |
+
|
| 913 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 914 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 915 |
+
of shape `(batch_size, sequence_length)`.
|
| 916 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 917 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 918 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 919 |
+
model's internal embedding lookup matrix.
|
| 920 |
+
use_cache (`bool`, *optional*):
|
| 921 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 922 |
+
`past_key_values`).
|
| 923 |
+
output_attentions (`bool`, *optional*):
|
| 924 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 925 |
+
tensors for more detail.
|
| 926 |
+
output_hidden_states (`bool`, *optional*):
|
| 927 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 928 |
+
more detail.
|
| 929 |
+
output_router_logits (`bool`, *optional*):
|
| 930 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
| 931 |
+
should not be returned during inference.
|
| 932 |
+
return_dict (`bool`, *optional*):
|
| 933 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 934 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 935 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 936 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 937 |
+
the complete sequence length.
|
| 938 |
+
"""
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
@add_start_docstrings(
|
| 942 |
+
"The bare DBRX Model outputting raw hidden-states without any specific head on top.",
|
| 943 |
+
DBRX_START_DOCSTRING,
|
| 944 |
+
)
|
| 945 |
+
class DbrxModel(DbrxPreTrainedModel):
|
| 946 |
+
"""Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer.
|
| 947 |
+
|
| 948 |
+
Args:
|
| 949 |
+
config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
|
| 950 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 951 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 952 |
+
"""
|
| 953 |
+
|
| 954 |
+
def __init__(self, config: DbrxConfig):
|
| 955 |
+
super().__init__(config)
|
| 956 |
+
self.padding_idx = config.pad_token_id
|
| 957 |
+
self.vocab_size = config.vocab_size
|
| 958 |
+
self.emb_pdrop = config.emb_pdrop
|
| 959 |
+
|
| 960 |
+
self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
| 961 |
+
self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
|
| 962 |
+
self.norm_f = nn.LayerNorm(config.d_model, bias=False)
|
| 963 |
+
self.gradient_checkpointing = False
|
| 964 |
+
|
| 965 |
+
# Initialize weights and apply final processing
|
| 966 |
+
self.post_init()
|
| 967 |
+
|
| 968 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 969 |
+
return self.wte
|
| 970 |
+
|
| 971 |
+
def set_input_embeddings(self, value: nn.Embedding):
|
| 972 |
+
self.wte = value
|
| 973 |
+
|
| 974 |
+
@add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING)
|
| 975 |
+
def forward(
|
| 976 |
+
self,
|
| 977 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 978 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 979 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 980 |
+
past_key_values: Optional[Cache] = None,
|
| 981 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 982 |
+
use_cache: Optional[bool] = None,
|
| 983 |
+
output_attentions: Optional[bool] = None,
|
| 984 |
+
output_hidden_states: Optional[bool] = None,
|
| 985 |
+
output_router_logits: Optional[bool] = None,
|
| 986 |
+
return_dict: Optional[bool] = None,
|
| 987 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 988 |
+
**kwargs, # NOOP kwargs, for now
|
| 989 |
+
) -> Union[Tuple, MoeModelOutputWithPast]:
|
| 990 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 991 |
+
output_hidden_states = (
|
| 992 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 993 |
+
)
|
| 994 |
+
output_router_logits = (
|
| 995 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 996 |
+
)
|
| 997 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 998 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 999 |
+
|
| 1000 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1001 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1002 |
+
|
| 1003 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 1004 |
+
logger.warning_once(
|
| 1005 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 1006 |
+
)
|
| 1007 |
+
use_cache = False
|
| 1008 |
+
|
| 1009 |
+
if inputs_embeds is None:
|
| 1010 |
+
inputs_embeds = self.wte(input_ids)
|
| 1011 |
+
|
| 1012 |
+
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
|
| 1013 |
+
|
| 1014 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 1015 |
+
return_legacy_cache = False
|
| 1016 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 1017 |
+
return_legacy_cache = True
|
| 1018 |
+
if past_key_values is None:
|
| 1019 |
+
past_key_values = DynamicCache()
|
| 1020 |
+
else:
|
| 1021 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 1022 |
+
logger.warning_once(
|
| 1023 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 1024 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 1025 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 1026 |
+
)
|
| 1027 |
+
|
| 1028 |
+
if cache_position is None:
|
| 1029 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1030 |
+
cache_position = torch.arange(
|
| 1031 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 1032 |
+
)
|
| 1033 |
+
|
| 1034 |
+
if position_ids is None:
|
| 1035 |
+
position_ids = cache_position.unsqueeze(0)
|
| 1036 |
+
|
| 1037 |
+
causal_mask = self._update_causal_mask(
|
| 1038 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
# embed positions
|
| 1042 |
+
hidden_states = inputs_embeds
|
| 1043 |
+
|
| 1044 |
+
# decoder layers
|
| 1045 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1046 |
+
all_self_attns = () if output_attentions else None
|
| 1047 |
+
all_router_logits = () if output_router_logits else None
|
| 1048 |
+
next_decoder_cache = None
|
| 1049 |
+
|
| 1050 |
+
for block in self.blocks:
|
| 1051 |
+
if output_hidden_states:
|
| 1052 |
+
all_hidden_states += (hidden_states,)
|
| 1053 |
+
|
| 1054 |
+
if self.gradient_checkpointing and self.training:
|
| 1055 |
+
block_outputs = self._gradient_checkpointing_func(
|
| 1056 |
+
block.__call__,
|
| 1057 |
+
hidden_states,
|
| 1058 |
+
causal_mask,
|
| 1059 |
+
position_ids,
|
| 1060 |
+
past_key_values,
|
| 1061 |
+
output_attentions,
|
| 1062 |
+
output_router_logits,
|
| 1063 |
+
use_cache,
|
| 1064 |
+
cache_position,
|
| 1065 |
+
)
|
| 1066 |
+
else:
|
| 1067 |
+
block_outputs = block(
|
| 1068 |
+
hidden_states,
|
| 1069 |
+
attention_mask=causal_mask,
|
| 1070 |
+
position_ids=position_ids,
|
| 1071 |
+
past_key_value=past_key_values,
|
| 1072 |
+
output_attentions=output_attentions,
|
| 1073 |
+
output_router_logits=output_router_logits,
|
| 1074 |
+
use_cache=use_cache,
|
| 1075 |
+
cache_position=cache_position,
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
hidden_states = block_outputs[0]
|
| 1079 |
+
|
| 1080 |
+
if use_cache:
|
| 1081 |
+
next_decoder_cache = block_outputs[2 if output_attentions else 1]
|
| 1082 |
+
|
| 1083 |
+
if output_attentions:
|
| 1084 |
+
all_self_attns += (block_outputs[1],)
|
| 1085 |
+
|
| 1086 |
+
if output_router_logits:
|
| 1087 |
+
all_router_logits += (block_outputs[-1],)
|
| 1088 |
+
|
| 1089 |
+
hidden_states = self.norm_f(hidden_states)
|
| 1090 |
+
|
| 1091 |
+
# add hidden states from the last decoder layer
|
| 1092 |
+
if output_hidden_states:
|
| 1093 |
+
all_hidden_states += (hidden_states,)
|
| 1094 |
+
|
| 1095 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 1096 |
+
if return_legacy_cache:
|
| 1097 |
+
next_cache = next_cache.to_legacy_cache()
|
| 1098 |
+
|
| 1099 |
+
if not return_dict:
|
| 1100 |
+
return tuple(
|
| 1101 |
+
v
|
| 1102 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
|
| 1103 |
+
if v is not None
|
| 1104 |
+
)
|
| 1105 |
+
return MoeModelOutputWithPast(
|
| 1106 |
+
last_hidden_state=hidden_states,
|
| 1107 |
+
past_key_values=next_cache,
|
| 1108 |
+
hidden_states=all_hidden_states,
|
| 1109 |
+
attentions=all_self_attns,
|
| 1110 |
+
router_logits=all_router_logits,
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
| 1114 |
+
def _update_causal_mask(
|
| 1115 |
+
self,
|
| 1116 |
+
attention_mask: Union[torch.Tensor, "BlockMask"],
|
| 1117 |
+
input_tensor: torch.Tensor,
|
| 1118 |
+
cache_position: torch.Tensor,
|
| 1119 |
+
past_key_values: Cache,
|
| 1120 |
+
output_attentions: bool = False,
|
| 1121 |
+
):
|
| 1122 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1123 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 1124 |
+
return attention_mask
|
| 1125 |
+
return None
|
| 1126 |
+
if self.config._attn_implementation == "flex_attention":
|
| 1127 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 1128 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 1129 |
+
return attention_mask
|
| 1130 |
+
|
| 1131 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1132 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1133 |
+
# to infer the attention mask.
|
| 1134 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1135 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1136 |
+
|
| 1137 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1138 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 1139 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1140 |
+
attention_mask,
|
| 1141 |
+
inputs_embeds=input_tensor,
|
| 1142 |
+
past_key_values_length=past_seen_tokens,
|
| 1143 |
+
is_training=self.training,
|
| 1144 |
+
):
|
| 1145 |
+
return None
|
| 1146 |
+
|
| 1147 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 1148 |
+
sequence_length = input_tensor.shape[1]
|
| 1149 |
+
if using_static_cache:
|
| 1150 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 1151 |
+
else:
|
| 1152 |
+
target_length = (
|
| 1153 |
+
attention_mask.shape[-1]
|
| 1154 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1155 |
+
else past_seen_tokens + sequence_length + 1
|
| 1156 |
+
)
|
| 1157 |
+
|
| 1158 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1159 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 1160 |
+
attention_mask,
|
| 1161 |
+
sequence_length=sequence_length,
|
| 1162 |
+
target_length=target_length,
|
| 1163 |
+
dtype=dtype,
|
| 1164 |
+
device=device,
|
| 1165 |
+
cache_position=cache_position,
|
| 1166 |
+
batch_size=input_tensor.shape[0],
|
| 1167 |
+
)
|
| 1168 |
+
|
| 1169 |
+
if (
|
| 1170 |
+
self.config._attn_implementation == "sdpa"
|
| 1171 |
+
and attention_mask is not None
|
| 1172 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 1173 |
+
and not output_attentions
|
| 1174 |
+
):
|
| 1175 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1176 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1177 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1178 |
+
min_dtype = torch.finfo(dtype).min
|
| 1179 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 1180 |
+
|
| 1181 |
+
return causal_mask
|
| 1182 |
+
|
| 1183 |
+
@staticmethod
|
| 1184 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
| 1185 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1186 |
+
attention_mask: torch.Tensor,
|
| 1187 |
+
sequence_length: int,
|
| 1188 |
+
target_length: int,
|
| 1189 |
+
dtype: torch.dtype,
|
| 1190 |
+
device: torch.device,
|
| 1191 |
+
cache_position: torch.Tensor,
|
| 1192 |
+
batch_size: int,
|
| 1193 |
+
**kwargs,
|
| 1194 |
+
):
|
| 1195 |
+
"""
|
| 1196 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 1197 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 1198 |
+
|
| 1199 |
+
Args:
|
| 1200 |
+
attention_mask (`torch.Tensor`):
|
| 1201 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 1202 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 1203 |
+
sequence_length (`int`):
|
| 1204 |
+
The sequence length being processed.
|
| 1205 |
+
target_length (`int`):
|
| 1206 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 1207 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 1208 |
+
dtype (`torch.dtype`):
|
| 1209 |
+
The dtype to use for the 4D attention mask.
|
| 1210 |
+
device (`torch.device`):
|
| 1211 |
+
The device to place the 4D attention mask on.
|
| 1212 |
+
cache_position (`torch.Tensor`):
|
| 1213 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 1214 |
+
batch_size (`torch.Tensor`):
|
| 1215 |
+
Batch size.
|
| 1216 |
+
"""
|
| 1217 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 1218 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 1219 |
+
causal_mask = attention_mask
|
| 1220 |
+
else:
|
| 1221 |
+
min_dtype = torch.finfo(dtype).min
|
| 1222 |
+
causal_mask = torch.full(
|
| 1223 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 1224 |
+
)
|
| 1225 |
+
if sequence_length != 1:
|
| 1226 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 1227 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 1228 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 1229 |
+
if attention_mask is not None:
|
| 1230 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 1231 |
+
mask_length = attention_mask.shape[-1]
|
| 1232 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 1233 |
+
causal_mask.device
|
| 1234 |
+
)
|
| 1235 |
+
padding_mask = padding_mask == 0
|
| 1236 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 1237 |
+
padding_mask, min_dtype
|
| 1238 |
+
)
|
| 1239 |
+
|
| 1240 |
+
return causal_mask
|
| 1241 |
+
|
| 1242 |
+
|
| 1243 |
+
@add_start_docstrings("The DBRX Model transformer for causal language modeling.", DBRX_START_DOCSTRING)
|
| 1244 |
+
class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
|
| 1245 |
+
def __init__(self, config: DbrxConfig):
|
| 1246 |
+
super().__init__(config)
|
| 1247 |
+
self.transformer = DbrxModel(config)
|
| 1248 |
+
self.vocab_size = config.vocab_size
|
| 1249 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1250 |
+
self.moe_loss_weight = config.ffn_config.moe_loss_weight
|
| 1251 |
+
self.num_experts = config.ffn_config.moe_num_experts
|
| 1252 |
+
self.num_experts_per_tok = config.ffn_config.moe_top_k
|
| 1253 |
+
|
| 1254 |
+
# Initialize weights and apply final processing
|
| 1255 |
+
self.post_init()
|
| 1256 |
+
|
| 1257 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 1258 |
+
return self.transformer.get_input_embeddings()
|
| 1259 |
+
|
| 1260 |
+
def set_input_embeddings(self, value: nn.Embedding):
|
| 1261 |
+
self.transformer.set_input_embeddings(value)
|
| 1262 |
+
|
| 1263 |
+
def get_output_embeddings(self) -> nn.Linear:
|
| 1264 |
+
return self.lm_head
|
| 1265 |
+
|
| 1266 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear):
|
| 1267 |
+
self.lm_head = new_embeddings
|
| 1268 |
+
|
| 1269 |
+
def set_decoder(self, decoder: DbrxModel):
|
| 1270 |
+
self.transformer = decoder
|
| 1271 |
+
|
| 1272 |
+
def get_decoder(self) -> DbrxModel:
|
| 1273 |
+
return self.transformer
|
| 1274 |
+
|
| 1275 |
+
@add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING)
|
| 1276 |
+
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1277 |
+
def forward(
|
| 1278 |
+
self,
|
| 1279 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1280 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1281 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1282 |
+
past_key_values: Optional[Cache] = None,
|
| 1283 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1284 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1285 |
+
use_cache: Optional[bool] = None,
|
| 1286 |
+
output_attentions: Optional[bool] = None,
|
| 1287 |
+
output_hidden_states: Optional[bool] = None,
|
| 1288 |
+
output_router_logits: Optional[bool] = None,
|
| 1289 |
+
return_dict: Optional[bool] = None,
|
| 1290 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1291 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1292 |
+
**kwargs,
|
| 1293 |
+
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
| 1294 |
+
r"""
|
| 1295 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1296 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1297 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1298 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1299 |
+
|
| 1300 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 1301 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 1302 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 1303 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 1304 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 1305 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 1306 |
+
|
| 1307 |
+
Returns:
|
| 1308 |
+
|
| 1309 |
+
Example:
|
| 1310 |
+
|
| 1311 |
+
```python
|
| 1312 |
+
>> from transformers import AutoTokenizer, DbrxForCausalLM
|
| 1313 |
+
|
| 1314 |
+
>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
|
| 1315 |
+
>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
|
| 1316 |
+
|
| 1317 |
+
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1318 |
+
>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1319 |
+
|
| 1320 |
+
>> # Generate
|
| 1321 |
+
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1322 |
+
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1323 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1324 |
+
```
|
| 1325 |
+
"""
|
| 1326 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1327 |
+
output_hidden_states = (
|
| 1328 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1329 |
+
)
|
| 1330 |
+
output_router_logits = (
|
| 1331 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1332 |
+
)
|
| 1333 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1334 |
+
|
| 1335 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1336 |
+
outputs = self.transformer(
|
| 1337 |
+
input_ids=input_ids,
|
| 1338 |
+
attention_mask=attention_mask,
|
| 1339 |
+
position_ids=position_ids,
|
| 1340 |
+
past_key_values=past_key_values,
|
| 1341 |
+
inputs_embeds=inputs_embeds,
|
| 1342 |
+
use_cache=use_cache,
|
| 1343 |
+
output_attentions=output_attentions,
|
| 1344 |
+
output_hidden_states=output_hidden_states,
|
| 1345 |
+
output_router_logits=output_router_logits,
|
| 1346 |
+
return_dict=return_dict,
|
| 1347 |
+
cache_position=cache_position,
|
| 1348 |
+
)
|
| 1349 |
+
|
| 1350 |
+
hidden_states = outputs[0]
|
| 1351 |
+
# No upscaling to float was ever done for Dbrx
|
| 1352 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1353 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1354 |
+
|
| 1355 |
+
loss = None
|
| 1356 |
+
if labels is not None:
|
| 1357 |
+
loss = self.loss_function(
|
| 1358 |
+
logits,
|
| 1359 |
+
labels,
|
| 1360 |
+
vocab_size=self.config.vocab_size,
|
| 1361 |
+
**kwargs,
|
| 1362 |
+
)
|
| 1363 |
+
|
| 1364 |
+
aux_loss = None
|
| 1365 |
+
if output_router_logits:
|
| 1366 |
+
aux_loss = load_balancing_loss_func(
|
| 1367 |
+
outputs.router_logits if return_dict else outputs[-1],
|
| 1368 |
+
self.num_experts,
|
| 1369 |
+
self.num_experts_per_tok,
|
| 1370 |
+
attention_mask,
|
| 1371 |
+
)
|
| 1372 |
+
if labels is not None and loss is not None:
|
| 1373 |
+
loss += self.moe_loss_weight * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 1374 |
+
|
| 1375 |
+
if not return_dict:
|
| 1376 |
+
output = (logits,) + outputs[1:]
|
| 1377 |
+
if output_router_logits:
|
| 1378 |
+
output = (aux_loss,) + output
|
| 1379 |
+
return (loss,) + output if loss is not None else output
|
| 1380 |
+
|
| 1381 |
+
return MoeCausalLMOutputWithPast(
|
| 1382 |
+
loss=loss,
|
| 1383 |
+
aux_loss=aux_loss,
|
| 1384 |
+
logits=logits,
|
| 1385 |
+
past_key_values=outputs.past_key_values,
|
| 1386 |
+
hidden_states=outputs.hidden_states,
|
| 1387 |
+
attentions=outputs.attentions,
|
| 1388 |
+
router_logits=outputs.router_logits,
|
| 1389 |
+
)
|
| 1390 |
+
|
| 1391 |
+
|
| 1392 |
+
__all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/deberta/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_deberta import *
|
| 22 |
+
from .modeling_deberta import *
|
| 23 |
+
from .modeling_tf_deberta import *
|
| 24 |
+
from .tokenization_deberta import *
|
| 25 |
+
from .tokenization_deberta_fast import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deberta/configuration_deberta.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020, Microsoft and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DeBERTa model configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...onnx import OnnxConfig
|
| 22 |
+
from ...utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DebertaConfig(PretrainedConfig):
|
| 33 |
+
r"""
|
| 34 |
+
This is the configuration class to store the configuration of a [`DebertaModel`] or a [`TFDebertaModel`]. It is
|
| 35 |
+
used to instantiate a DeBERTa model according to the specified arguments, defining the model architecture.
|
| 36 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the DeBERTa
|
| 37 |
+
[microsoft/deberta-base](https://huggingface.co/microsoft/deberta-base) architecture.
|
| 38 |
+
|
| 39 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 40 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 41 |
+
|
| 42 |
+
Arguments:
|
| 43 |
+
vocab_size (`int`, *optional*, defaults to 50265):
|
| 44 |
+
Vocabulary size of the DeBERTa model. Defines the number of different tokens that can be represented by the
|
| 45 |
+
`inputs_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
|
| 46 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 47 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 48 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 49 |
+
Number of hidden layers in the Transformer encoder.
|
| 50 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 51 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 52 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 53 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
| 54 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
| 55 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 56 |
+
`"relu"`, `"silu"`, `"gelu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"`
|
| 57 |
+
are supported.
|
| 58 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 59 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 60 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 61 |
+
The dropout ratio for the attention probabilities.
|
| 62 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
| 63 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 64 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 65 |
+
type_vocab_size (`int`, *optional*, defaults to 0):
|
| 66 |
+
The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
|
| 67 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 68 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 69 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 70 |
+
The epsilon used by the layer normalization layers.
|
| 71 |
+
relative_attention (`bool`, *optional*, defaults to `False`):
|
| 72 |
+
Whether use relative position encoding.
|
| 73 |
+
max_relative_positions (`int`, *optional*, defaults to 1):
|
| 74 |
+
The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value
|
| 75 |
+
as `max_position_embeddings`.
|
| 76 |
+
pad_token_id (`int`, *optional*, defaults to 0):
|
| 77 |
+
The value used to pad input_ids.
|
| 78 |
+
position_biased_input (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether add absolute position embedding to content embedding.
|
| 80 |
+
pos_att_type (`List[str]`, *optional*):
|
| 81 |
+
The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`,
|
| 82 |
+
`["p2c", "c2p"]`.
|
| 83 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 84 |
+
The epsilon used by the layer normalization layers.
|
| 85 |
+
legacy (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
Whether or not the model should use the legacy `LegacyDebertaOnlyMLMHead`, which does not work properly
|
| 87 |
+
for mask infilling tasks.
|
| 88 |
+
|
| 89 |
+
Example:
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
>>> from transformers import DebertaConfig, DebertaModel
|
| 93 |
+
|
| 94 |
+
>>> # Initializing a DeBERTa microsoft/deberta-base style configuration
|
| 95 |
+
>>> configuration = DebertaConfig()
|
| 96 |
+
|
| 97 |
+
>>> # Initializing a model (with random weights) from the microsoft/deberta-base style configuration
|
| 98 |
+
>>> model = DebertaModel(configuration)
|
| 99 |
+
|
| 100 |
+
>>> # Accessing the model configuration
|
| 101 |
+
>>> configuration = model.config
|
| 102 |
+
```"""
|
| 103 |
+
|
| 104 |
+
model_type = "deberta"
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
vocab_size=50265,
|
| 109 |
+
hidden_size=768,
|
| 110 |
+
num_hidden_layers=12,
|
| 111 |
+
num_attention_heads=12,
|
| 112 |
+
intermediate_size=3072,
|
| 113 |
+
hidden_act="gelu",
|
| 114 |
+
hidden_dropout_prob=0.1,
|
| 115 |
+
attention_probs_dropout_prob=0.1,
|
| 116 |
+
max_position_embeddings=512,
|
| 117 |
+
type_vocab_size=0,
|
| 118 |
+
initializer_range=0.02,
|
| 119 |
+
layer_norm_eps=1e-7,
|
| 120 |
+
relative_attention=False,
|
| 121 |
+
max_relative_positions=-1,
|
| 122 |
+
pad_token_id=0,
|
| 123 |
+
position_biased_input=True,
|
| 124 |
+
pos_att_type=None,
|
| 125 |
+
pooler_dropout=0,
|
| 126 |
+
pooler_hidden_act="gelu",
|
| 127 |
+
legacy=True,
|
| 128 |
+
**kwargs,
|
| 129 |
+
):
|
| 130 |
+
super().__init__(**kwargs)
|
| 131 |
+
|
| 132 |
+
self.hidden_size = hidden_size
|
| 133 |
+
self.num_hidden_layers = num_hidden_layers
|
| 134 |
+
self.num_attention_heads = num_attention_heads
|
| 135 |
+
self.intermediate_size = intermediate_size
|
| 136 |
+
self.hidden_act = hidden_act
|
| 137 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 138 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 139 |
+
self.max_position_embeddings = max_position_embeddings
|
| 140 |
+
self.type_vocab_size = type_vocab_size
|
| 141 |
+
self.initializer_range = initializer_range
|
| 142 |
+
self.relative_attention = relative_attention
|
| 143 |
+
self.max_relative_positions = max_relative_positions
|
| 144 |
+
self.pad_token_id = pad_token_id
|
| 145 |
+
self.position_biased_input = position_biased_input
|
| 146 |
+
|
| 147 |
+
# Backwards compatibility
|
| 148 |
+
if isinstance(pos_att_type, str):
|
| 149 |
+
pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]
|
| 150 |
+
|
| 151 |
+
self.pos_att_type = pos_att_type
|
| 152 |
+
self.vocab_size = vocab_size
|
| 153 |
+
self.layer_norm_eps = layer_norm_eps
|
| 154 |
+
|
| 155 |
+
self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
|
| 156 |
+
self.pooler_dropout = pooler_dropout
|
| 157 |
+
self.pooler_hidden_act = pooler_hidden_act
|
| 158 |
+
self.legacy = legacy
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Copied from transformers.models.deberta_v2.configuration_deberta_v2.DebertaV2OnnxConfig
|
| 162 |
+
class DebertaOnnxConfig(OnnxConfig):
|
| 163 |
+
@property
|
| 164 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 165 |
+
if self.task == "multiple-choice":
|
| 166 |
+
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
| 167 |
+
else:
|
| 168 |
+
dynamic_axis = {0: "batch", 1: "sequence"}
|
| 169 |
+
if self._config.type_vocab_size > 0:
|
| 170 |
+
return OrderedDict(
|
| 171 |
+
[("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ("token_type_ids", dynamic_axis)]
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
return OrderedDict([("input_ids", dynamic_axis), ("attention_mask", dynamic_axis)])
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def default_onnx_opset(self) -> int:
|
| 178 |
+
return 12
|
| 179 |
+
|
| 180 |
+
def generate_dummy_inputs(
|
| 181 |
+
self,
|
| 182 |
+
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
|
| 183 |
+
batch_size: int = -1,
|
| 184 |
+
seq_length: int = -1,
|
| 185 |
+
num_choices: int = -1,
|
| 186 |
+
is_pair: bool = False,
|
| 187 |
+
framework: Optional["TensorType"] = None,
|
| 188 |
+
num_channels: int = 3,
|
| 189 |
+
image_width: int = 40,
|
| 190 |
+
image_height: int = 40,
|
| 191 |
+
tokenizer: "PreTrainedTokenizerBase" = None,
|
| 192 |
+
) -> Mapping[str, Any]:
|
| 193 |
+
dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework)
|
| 194 |
+
if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs:
|
| 195 |
+
del dummy_inputs["token_type_ids"]
|
| 196 |
+
return dummy_inputs
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
__all__ = ["DebertaConfig", "DebertaOnnxConfig"]
|
docs/transformers/build/lib/transformers/models/deberta/modeling_deberta.py
ADDED
|
@@ -0,0 +1,1352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 Microsoft and the Hugging Face Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DeBERTa model."""
|
| 16 |
+
|
| 17 |
+
from typing import Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
from torch import nn
|
| 22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...modeling_outputs import (
|
| 26 |
+
BaseModelOutput,
|
| 27 |
+
MaskedLMOutput,
|
| 28 |
+
QuestionAnsweringModelOutput,
|
| 29 |
+
SequenceClassifierOutput,
|
| 30 |
+
TokenClassifierOutput,
|
| 31 |
+
)
|
| 32 |
+
from ...modeling_utils import PreTrainedModel
|
| 33 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 34 |
+
from .configuration_deberta import DebertaConfig
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
_CONFIG_FOR_DOC = "DebertaConfig"
|
| 39 |
+
_CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
|
| 40 |
+
|
| 41 |
+
# Masked LM docstring
|
| 42 |
+
_CHECKPOINT_FOR_MASKED_LM = "lsanochkin/deberta-large-feedback"
|
| 43 |
+
_MASKED_LM_EXPECTED_OUTPUT = "' Paris'"
|
| 44 |
+
_MASKED_LM_EXPECTED_LOSS = "0.54"
|
| 45 |
+
|
| 46 |
+
# QuestionAnswering docstring
|
| 47 |
+
_CHECKPOINT_FOR_QA = "Palak/microsoft_deberta-large_squad"
|
| 48 |
+
_QA_EXPECTED_OUTPUT = "' a nice puppet'"
|
| 49 |
+
_QA_EXPECTED_LOSS = 0.14
|
| 50 |
+
_QA_TARGET_START_INDEX = 12
|
| 51 |
+
_QA_TARGET_END_INDEX = 14
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DebertaLayerNorm(nn.Module):
|
| 55 |
+
"""LayerNorm module in the TF style (epsilon inside the square root)."""
|
| 56 |
+
|
| 57 |
+
def __init__(self, size, eps=1e-12):
|
| 58 |
+
super().__init__()
|
| 59 |
+
self.weight = nn.Parameter(torch.ones(size))
|
| 60 |
+
self.bias = nn.Parameter(torch.zeros(size))
|
| 61 |
+
self.variance_epsilon = eps
|
| 62 |
+
|
| 63 |
+
def forward(self, hidden_states):
|
| 64 |
+
input_type = hidden_states.dtype
|
| 65 |
+
hidden_states = hidden_states.float()
|
| 66 |
+
mean = hidden_states.mean(-1, keepdim=True)
|
| 67 |
+
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
|
| 68 |
+
hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
|
| 69 |
+
hidden_states = hidden_states.to(input_type)
|
| 70 |
+
y = self.weight * hidden_states + self.bias
|
| 71 |
+
return y
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class DebertaSelfOutput(nn.Module):
|
| 75 |
+
def __init__(self, config):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 78 |
+
self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 79 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 80 |
+
|
| 81 |
+
def forward(self, hidden_states, input_tensor):
|
| 82 |
+
hidden_states = self.dense(hidden_states)
|
| 83 |
+
hidden_states = self.dropout(hidden_states)
|
| 84 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 85 |
+
return hidden_states
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@torch.jit.script
|
| 89 |
+
def build_relative_position(query_layer, key_layer):
|
| 90 |
+
"""
|
| 91 |
+
Build relative position according to the query and key
|
| 92 |
+
|
| 93 |
+
We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
|
| 94 |
+
\\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
|
| 95 |
+
P_k\\)
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
query_size (int): the length of query
|
| 99 |
+
key_size (int): the length of key
|
| 100 |
+
|
| 101 |
+
Return:
|
| 102 |
+
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
|
| 103 |
+
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
query_size = query_layer.size(-2)
|
| 107 |
+
key_size = key_layer.size(-2)
|
| 108 |
+
|
| 109 |
+
q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
|
| 110 |
+
k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
|
| 111 |
+
rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
|
| 112 |
+
rel_pos_ids = rel_pos_ids[:query_size, :]
|
| 113 |
+
rel_pos_ids = rel_pos_ids.unsqueeze(0)
|
| 114 |
+
return rel_pos_ids
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@torch.jit.script
|
| 118 |
+
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
|
| 119 |
+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@torch.jit.script
|
| 123 |
+
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
|
| 124 |
+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@torch.jit.script
|
| 128 |
+
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
|
| 129 |
+
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
###### To support a general trace, we have to define these operation as they use python objects (sizes) ##################
|
| 133 |
+
# which are not supported by torch.jit.trace.
|
| 134 |
+
# Full credits to @Szustarol
|
| 135 |
+
@torch.jit.script
|
| 136 |
+
def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
|
| 137 |
+
return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@torch.jit.script
|
| 141 |
+
def build_rpos(query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
|
| 142 |
+
if query_layer.size(-2) != key_layer.size(-2):
|
| 143 |
+
return build_relative_position(query_layer, key_layer)
|
| 144 |
+
else:
|
| 145 |
+
return relative_pos
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@torch.jit.script
|
| 149 |
+
def compute_attention_span(query_layer: torch.Tensor, key_layer: torch.Tensor, max_relative_positions: int):
|
| 150 |
+
return torch.tensor(min(max(query_layer.size(-2), key_layer.size(-2)), max_relative_positions))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@torch.jit.script
|
| 154 |
+
def uneven_size_corrected(p2c_att, query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
|
| 155 |
+
if query_layer.size(-2) != key_layer.size(-2):
|
| 156 |
+
pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
|
| 157 |
+
return torch.gather(p2c_att, dim=2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
|
| 158 |
+
else:
|
| 159 |
+
return p2c_att
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
########################################################################################################################
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class DisentangledSelfAttention(nn.Module):
|
| 166 |
+
"""
|
| 167 |
+
Disentangled self-attention module
|
| 168 |
+
|
| 169 |
+
Parameters:
|
| 170 |
+
config (`str`):
|
| 171 |
+
A model config class instance with the configuration to build a new model. The schema is similar to
|
| 172 |
+
*BertConfig*, for more details, please refer [`DebertaConfig`]
|
| 173 |
+
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
def __init__(self, config):
|
| 177 |
+
super().__init__()
|
| 178 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 179 |
+
raise ValueError(
|
| 180 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 181 |
+
f"heads ({config.num_attention_heads})"
|
| 182 |
+
)
|
| 183 |
+
self.num_attention_heads = config.num_attention_heads
|
| 184 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 185 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 186 |
+
self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
|
| 187 |
+
self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
|
| 188 |
+
self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
|
| 189 |
+
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
|
| 190 |
+
|
| 191 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 192 |
+
self.talking_head = getattr(config, "talking_head", False)
|
| 193 |
+
|
| 194 |
+
if self.talking_head:
|
| 195 |
+
self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
|
| 196 |
+
self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
|
| 197 |
+
else:
|
| 198 |
+
self.head_logits_proj = None
|
| 199 |
+
self.head_weights_proj = None
|
| 200 |
+
|
| 201 |
+
if self.relative_attention:
|
| 202 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 203 |
+
if self.max_relative_positions < 1:
|
| 204 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 205 |
+
self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 206 |
+
|
| 207 |
+
if "c2p" in self.pos_att_type:
|
| 208 |
+
self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
|
| 209 |
+
if "p2c" in self.pos_att_type:
|
| 210 |
+
self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
|
| 211 |
+
|
| 212 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 213 |
+
|
| 214 |
+
def transpose_for_scores(self, x):
|
| 215 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
|
| 216 |
+
x = x.view(new_x_shape)
|
| 217 |
+
return x.permute(0, 2, 1, 3)
|
| 218 |
+
|
| 219 |
+
def forward(
|
| 220 |
+
self,
|
| 221 |
+
hidden_states: torch.Tensor,
|
| 222 |
+
attention_mask: torch.Tensor,
|
| 223 |
+
output_attentions: bool = False,
|
| 224 |
+
query_states: Optional[torch.Tensor] = None,
|
| 225 |
+
relative_pos: Optional[torch.Tensor] = None,
|
| 226 |
+
rel_embeddings: Optional[torch.Tensor] = None,
|
| 227 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 228 |
+
"""
|
| 229 |
+
Call the module
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
hidden_states (`torch.FloatTensor`):
|
| 233 |
+
Input states to the module usually the output from previous layer, it will be the Q,K and V in
|
| 234 |
+
*Attention(Q,K,V)*
|
| 235 |
+
|
| 236 |
+
attention_mask (`torch.BoolTensor`):
|
| 237 |
+
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
|
| 238 |
+
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
|
| 239 |
+
th token.
|
| 240 |
+
|
| 241 |
+
output_attentions (`bool`, *optional*):
|
| 242 |
+
Whether return the attention matrix.
|
| 243 |
+
|
| 244 |
+
query_states (`torch.FloatTensor`, *optional*):
|
| 245 |
+
The *Q* state in *Attention(Q,K,V)*.
|
| 246 |
+
|
| 247 |
+
relative_pos (`torch.LongTensor`):
|
| 248 |
+
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
|
| 249 |
+
values ranging in [*-max_relative_positions*, *max_relative_positions*].
|
| 250 |
+
|
| 251 |
+
rel_embeddings (`torch.FloatTensor`):
|
| 252 |
+
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
|
| 253 |
+
\\text{max_relative_positions}\\), *hidden_size*].
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
"""
|
| 257 |
+
if query_states is None:
|
| 258 |
+
qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
|
| 259 |
+
query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
|
| 260 |
+
else:
|
| 261 |
+
ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
|
| 262 |
+
qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
|
| 263 |
+
q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype))
|
| 264 |
+
k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype))
|
| 265 |
+
v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype))
|
| 266 |
+
query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
|
| 267 |
+
|
| 268 |
+
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
|
| 269 |
+
value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
|
| 270 |
+
|
| 271 |
+
rel_att: int = 0
|
| 272 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 273 |
+
scale_factor = 1 + len(self.pos_att_type)
|
| 274 |
+
scale = scaled_size_sqrt(query_layer, scale_factor)
|
| 275 |
+
query_layer = query_layer / scale.to(dtype=query_layer.dtype)
|
| 276 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 277 |
+
|
| 278 |
+
if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
|
| 279 |
+
rel_embeddings = self.pos_dropout(rel_embeddings)
|
| 280 |
+
rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
|
| 281 |
+
|
| 282 |
+
if rel_att is not None:
|
| 283 |
+
attention_scores = attention_scores + rel_att
|
| 284 |
+
|
| 285 |
+
# bxhxlxd
|
| 286 |
+
if self.head_logits_proj is not None:
|
| 287 |
+
attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
| 288 |
+
|
| 289 |
+
attention_mask = attention_mask.bool()
|
| 290 |
+
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
|
| 291 |
+
# bsz x height x length x dimension
|
| 292 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 293 |
+
|
| 294 |
+
attention_probs = self.dropout(attention_probs)
|
| 295 |
+
if self.head_weights_proj is not None:
|
| 296 |
+
attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
| 297 |
+
|
| 298 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 299 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 300 |
+
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
| 301 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 302 |
+
if not output_attentions:
|
| 303 |
+
return (context_layer, None)
|
| 304 |
+
return (context_layer, attention_probs)
|
| 305 |
+
|
| 306 |
+
def disentangled_att_bias(
|
| 307 |
+
self,
|
| 308 |
+
query_layer: torch.Tensor,
|
| 309 |
+
key_layer: torch.Tensor,
|
| 310 |
+
relative_pos: torch.Tensor,
|
| 311 |
+
rel_embeddings: torch.Tensor,
|
| 312 |
+
scale_factor: int,
|
| 313 |
+
):
|
| 314 |
+
if relative_pos is None:
|
| 315 |
+
relative_pos = build_relative_position(query_layer, key_layer, query_layer.device)
|
| 316 |
+
if relative_pos.dim() == 2:
|
| 317 |
+
relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
|
| 318 |
+
elif relative_pos.dim() == 3:
|
| 319 |
+
relative_pos = relative_pos.unsqueeze(1)
|
| 320 |
+
# bxhxqxk
|
| 321 |
+
elif relative_pos.dim() != 4:
|
| 322 |
+
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
|
| 323 |
+
|
| 324 |
+
att_span = compute_attention_span(query_layer, key_layer, self.max_relative_positions)
|
| 325 |
+
relative_pos = relative_pos.long()
|
| 326 |
+
rel_embeddings = rel_embeddings[
|
| 327 |
+
self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
|
| 328 |
+
].unsqueeze(0)
|
| 329 |
+
|
| 330 |
+
score = 0
|
| 331 |
+
|
| 332 |
+
# content->position
|
| 333 |
+
if "c2p" in self.pos_att_type:
|
| 334 |
+
pos_key_layer = self.pos_proj(rel_embeddings)
|
| 335 |
+
pos_key_layer = self.transpose_for_scores(pos_key_layer)
|
| 336 |
+
c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
|
| 337 |
+
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
|
| 338 |
+
c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
|
| 339 |
+
score += c2p_att
|
| 340 |
+
|
| 341 |
+
# position->content
|
| 342 |
+
if "p2c" in self.pos_att_type:
|
| 343 |
+
pos_query_layer = self.pos_q_proj(rel_embeddings)
|
| 344 |
+
pos_query_layer = self.transpose_for_scores(pos_query_layer)
|
| 345 |
+
pos_query_layer /= scaled_size_sqrt(pos_query_layer, scale_factor)
|
| 346 |
+
r_pos = build_rpos(
|
| 347 |
+
query_layer,
|
| 348 |
+
key_layer,
|
| 349 |
+
relative_pos,
|
| 350 |
+
)
|
| 351 |
+
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
|
| 352 |
+
p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
|
| 353 |
+
p2c_att = torch.gather(
|
| 354 |
+
p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
|
| 355 |
+
).transpose(-1, -2)
|
| 356 |
+
|
| 357 |
+
p2c_att = uneven_size_corrected(p2c_att, query_layer, key_layer, relative_pos)
|
| 358 |
+
score += p2c_att
|
| 359 |
+
|
| 360 |
+
return score
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class DebertaEmbeddings(nn.Module):
|
| 364 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 365 |
+
|
| 366 |
+
def __init__(self, config):
|
| 367 |
+
super().__init__()
|
| 368 |
+
pad_token_id = getattr(config, "pad_token_id", 0)
|
| 369 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 370 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
|
| 371 |
+
|
| 372 |
+
self.position_biased_input = getattr(config, "position_biased_input", True)
|
| 373 |
+
if not self.position_biased_input:
|
| 374 |
+
self.position_embeddings = None
|
| 375 |
+
else:
|
| 376 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
|
| 377 |
+
|
| 378 |
+
if config.type_vocab_size > 0:
|
| 379 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
|
| 380 |
+
else:
|
| 381 |
+
self.token_type_embeddings = None
|
| 382 |
+
|
| 383 |
+
if self.embedding_size != config.hidden_size:
|
| 384 |
+
self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
|
| 385 |
+
else:
|
| 386 |
+
self.embed_proj = None
|
| 387 |
+
|
| 388 |
+
self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 389 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 390 |
+
self.config = config
|
| 391 |
+
|
| 392 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 393 |
+
self.register_buffer(
|
| 394 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
|
| 398 |
+
if input_ids is not None:
|
| 399 |
+
input_shape = input_ids.size()
|
| 400 |
+
else:
|
| 401 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 402 |
+
|
| 403 |
+
seq_length = input_shape[1]
|
| 404 |
+
|
| 405 |
+
if position_ids is None:
|
| 406 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 407 |
+
|
| 408 |
+
if token_type_ids is None:
|
| 409 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 410 |
+
|
| 411 |
+
if inputs_embeds is None:
|
| 412 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 413 |
+
|
| 414 |
+
if self.position_embeddings is not None:
|
| 415 |
+
position_embeddings = self.position_embeddings(position_ids.long())
|
| 416 |
+
else:
|
| 417 |
+
position_embeddings = torch.zeros_like(inputs_embeds)
|
| 418 |
+
|
| 419 |
+
embeddings = inputs_embeds
|
| 420 |
+
if self.position_biased_input:
|
| 421 |
+
embeddings = embeddings + position_embeddings
|
| 422 |
+
if self.token_type_embeddings is not None:
|
| 423 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 424 |
+
embeddings = embeddings + token_type_embeddings
|
| 425 |
+
|
| 426 |
+
if self.embed_proj is not None:
|
| 427 |
+
embeddings = self.embed_proj(embeddings)
|
| 428 |
+
|
| 429 |
+
embeddings = self.LayerNorm(embeddings)
|
| 430 |
+
|
| 431 |
+
if mask is not None:
|
| 432 |
+
if mask.dim() != embeddings.dim():
|
| 433 |
+
if mask.dim() == 4:
|
| 434 |
+
mask = mask.squeeze(1).squeeze(1)
|
| 435 |
+
mask = mask.unsqueeze(2)
|
| 436 |
+
mask = mask.to(embeddings.dtype)
|
| 437 |
+
|
| 438 |
+
embeddings = embeddings * mask
|
| 439 |
+
|
| 440 |
+
embeddings = self.dropout(embeddings)
|
| 441 |
+
return embeddings
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class DebertaAttention(nn.Module):
|
| 445 |
+
def __init__(self, config):
|
| 446 |
+
super().__init__()
|
| 447 |
+
self.self = DisentangledSelfAttention(config)
|
| 448 |
+
self.output = DebertaSelfOutput(config)
|
| 449 |
+
self.config = config
|
| 450 |
+
|
| 451 |
+
def forward(
|
| 452 |
+
self,
|
| 453 |
+
hidden_states,
|
| 454 |
+
attention_mask,
|
| 455 |
+
output_attentions: bool = False,
|
| 456 |
+
query_states=None,
|
| 457 |
+
relative_pos=None,
|
| 458 |
+
rel_embeddings=None,
|
| 459 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 460 |
+
self_output, att_matrix = self.self(
|
| 461 |
+
hidden_states,
|
| 462 |
+
attention_mask,
|
| 463 |
+
output_attentions,
|
| 464 |
+
query_states=query_states,
|
| 465 |
+
relative_pos=relative_pos,
|
| 466 |
+
rel_embeddings=rel_embeddings,
|
| 467 |
+
)
|
| 468 |
+
if query_states is None:
|
| 469 |
+
query_states = hidden_states
|
| 470 |
+
attention_output = self.output(self_output, query_states)
|
| 471 |
+
|
| 472 |
+
if output_attentions:
|
| 473 |
+
return (attention_output, att_matrix)
|
| 474 |
+
else:
|
| 475 |
+
return (attention_output, None)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
|
| 479 |
+
class DebertaIntermediate(nn.Module):
|
| 480 |
+
def __init__(self, config):
|
| 481 |
+
super().__init__()
|
| 482 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 483 |
+
if isinstance(config.hidden_act, str):
|
| 484 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 485 |
+
else:
|
| 486 |
+
self.intermediate_act_fn = config.hidden_act
|
| 487 |
+
|
| 488 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 489 |
+
hidden_states = self.dense(hidden_states)
|
| 490 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 491 |
+
return hidden_states
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
class DebertaOutput(nn.Module):
|
| 495 |
+
def __init__(self, config):
|
| 496 |
+
super().__init__()
|
| 497 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 498 |
+
self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 499 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 500 |
+
self.config = config
|
| 501 |
+
|
| 502 |
+
def forward(self, hidden_states, input_tensor):
|
| 503 |
+
hidden_states = self.dense(hidden_states)
|
| 504 |
+
hidden_states = self.dropout(hidden_states)
|
| 505 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 506 |
+
return hidden_states
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
class DebertaLayer(nn.Module):
|
| 510 |
+
def __init__(self, config):
|
| 511 |
+
super().__init__()
|
| 512 |
+
self.attention = DebertaAttention(config)
|
| 513 |
+
self.intermediate = DebertaIntermediate(config)
|
| 514 |
+
self.output = DebertaOutput(config)
|
| 515 |
+
|
| 516 |
+
def forward(
|
| 517 |
+
self,
|
| 518 |
+
hidden_states,
|
| 519 |
+
attention_mask,
|
| 520 |
+
query_states=None,
|
| 521 |
+
relative_pos=None,
|
| 522 |
+
rel_embeddings=None,
|
| 523 |
+
output_attentions: bool = False,
|
| 524 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 525 |
+
attention_output, att_matrix = self.attention(
|
| 526 |
+
hidden_states,
|
| 527 |
+
attention_mask,
|
| 528 |
+
output_attentions=output_attentions,
|
| 529 |
+
query_states=query_states,
|
| 530 |
+
relative_pos=relative_pos,
|
| 531 |
+
rel_embeddings=rel_embeddings,
|
| 532 |
+
)
|
| 533 |
+
intermediate_output = self.intermediate(attention_output)
|
| 534 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 535 |
+
|
| 536 |
+
if output_attentions:
|
| 537 |
+
return (layer_output, att_matrix)
|
| 538 |
+
else:
|
| 539 |
+
return (layer_output, None)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class DebertaEncoder(nn.Module):
|
| 543 |
+
"""Modified BertEncoder with relative position bias support"""
|
| 544 |
+
|
| 545 |
+
def __init__(self, config):
|
| 546 |
+
super().__init__()
|
| 547 |
+
self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)])
|
| 548 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 549 |
+
if self.relative_attention:
|
| 550 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 551 |
+
if self.max_relative_positions < 1:
|
| 552 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 553 |
+
self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
|
| 554 |
+
self.gradient_checkpointing = False
|
| 555 |
+
|
| 556 |
+
def get_rel_embedding(self):
|
| 557 |
+
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
|
| 558 |
+
return rel_embeddings
|
| 559 |
+
|
| 560 |
+
def get_attention_mask(self, attention_mask):
|
| 561 |
+
if attention_mask.dim() <= 2:
|
| 562 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 563 |
+
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
|
| 564 |
+
elif attention_mask.dim() == 3:
|
| 565 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 566 |
+
|
| 567 |
+
return attention_mask
|
| 568 |
+
|
| 569 |
+
def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
|
| 570 |
+
if self.relative_attention and relative_pos is None:
|
| 571 |
+
if query_states is not None:
|
| 572 |
+
relative_pos = build_relative_position(query_states, hidden_states)
|
| 573 |
+
else:
|
| 574 |
+
relative_pos = build_relative_position(hidden_states, hidden_states)
|
| 575 |
+
return relative_pos
|
| 576 |
+
|
| 577 |
+
def forward(
|
| 578 |
+
self,
|
| 579 |
+
hidden_states: torch.Tensor,
|
| 580 |
+
attention_mask: torch.Tensor,
|
| 581 |
+
output_hidden_states: bool = True,
|
| 582 |
+
output_attentions: bool = False,
|
| 583 |
+
query_states=None,
|
| 584 |
+
relative_pos=None,
|
| 585 |
+
return_dict: bool = True,
|
| 586 |
+
):
|
| 587 |
+
attention_mask = self.get_attention_mask(attention_mask)
|
| 588 |
+
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
|
| 589 |
+
|
| 590 |
+
all_hidden_states: Optional[Tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None
|
| 591 |
+
all_attentions = () if output_attentions else None
|
| 592 |
+
|
| 593 |
+
next_kv = hidden_states
|
| 594 |
+
|
| 595 |
+
rel_embeddings = self.get_rel_embedding()
|
| 596 |
+
for i, layer_module in enumerate(self.layer):
|
| 597 |
+
if self.gradient_checkpointing and self.training:
|
| 598 |
+
hidden_states, att_m = self._gradient_checkpointing_func(
|
| 599 |
+
layer_module.__call__,
|
| 600 |
+
next_kv,
|
| 601 |
+
attention_mask,
|
| 602 |
+
query_states,
|
| 603 |
+
relative_pos,
|
| 604 |
+
rel_embeddings,
|
| 605 |
+
output_attentions,
|
| 606 |
+
)
|
| 607 |
+
else:
|
| 608 |
+
hidden_states, att_m = layer_module(
|
| 609 |
+
next_kv,
|
| 610 |
+
attention_mask,
|
| 611 |
+
query_states=query_states,
|
| 612 |
+
relative_pos=relative_pos,
|
| 613 |
+
rel_embeddings=rel_embeddings,
|
| 614 |
+
output_attentions=output_attentions,
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
if output_hidden_states:
|
| 618 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 619 |
+
|
| 620 |
+
if query_states is not None:
|
| 621 |
+
query_states = hidden_states
|
| 622 |
+
else:
|
| 623 |
+
next_kv = hidden_states
|
| 624 |
+
|
| 625 |
+
if output_attentions:
|
| 626 |
+
all_attentions = all_attentions + (att_m,)
|
| 627 |
+
|
| 628 |
+
if not return_dict:
|
| 629 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 630 |
+
return BaseModelOutput(
|
| 631 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class DebertaPreTrainedModel(PreTrainedModel):
|
| 636 |
+
"""
|
| 637 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 638 |
+
models.
|
| 639 |
+
"""
|
| 640 |
+
|
| 641 |
+
config_class = DebertaConfig
|
| 642 |
+
base_model_prefix = "deberta"
|
| 643 |
+
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
|
| 644 |
+
supports_gradient_checkpointing = True
|
| 645 |
+
|
| 646 |
+
def _init_weights(self, module):
|
| 647 |
+
"""Initialize the weights."""
|
| 648 |
+
if isinstance(module, nn.Linear):
|
| 649 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 650 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 651 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 652 |
+
if module.bias is not None:
|
| 653 |
+
module.bias.data.zero_()
|
| 654 |
+
elif isinstance(module, nn.Embedding):
|
| 655 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 656 |
+
if module.padding_idx is not None:
|
| 657 |
+
module.weight.data[module.padding_idx].zero_()
|
| 658 |
+
elif isinstance(module, (nn.LayerNorm, DebertaLayerNorm)):
|
| 659 |
+
module.weight.data.fill_(1.0)
|
| 660 |
+
module.bias.data.zero_()
|
| 661 |
+
elif isinstance(module, DisentangledSelfAttention):
|
| 662 |
+
module.q_bias.data.zero_()
|
| 663 |
+
module.v_bias.data.zero_()
|
| 664 |
+
elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)):
|
| 665 |
+
module.bias.data.zero_()
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
DEBERTA_START_DOCSTRING = r"""
|
| 669 |
+
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
|
| 670 |
+
Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
|
| 671 |
+
on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
|
| 672 |
+
improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
|
| 673 |
+
|
| 674 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 675 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 676 |
+
and behavior.
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
Parameters:
|
| 680 |
+
config ([`DebertaConfig`]): Model configuration class with all the parameters of the model.
|
| 681 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 682 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 683 |
+
"""
|
| 684 |
+
|
| 685 |
+
DEBERTA_INPUTS_DOCSTRING = r"""
|
| 686 |
+
Args:
|
| 687 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 688 |
+
Indices of input sequence tokens in the vocabulary.
|
| 689 |
+
|
| 690 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 691 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 692 |
+
|
| 693 |
+
[What are input IDs?](../glossary#input-ids)
|
| 694 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 695 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 696 |
+
|
| 697 |
+
- 1 for tokens that are **not masked**,
|
| 698 |
+
- 0 for tokens that are **masked**.
|
| 699 |
+
|
| 700 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 701 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 702 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 703 |
+
1]`:
|
| 704 |
+
|
| 705 |
+
- 0 corresponds to a *sentence A* token,
|
| 706 |
+
- 1 corresponds to a *sentence B* token.
|
| 707 |
+
|
| 708 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 709 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 710 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 711 |
+
config.max_position_embeddings - 1]`.
|
| 712 |
+
|
| 713 |
+
[What are position IDs?](../glossary#position-ids)
|
| 714 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 715 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 716 |
+
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
| 717 |
+
model's internal embedding lookup matrix.
|
| 718 |
+
output_attentions (`bool`, *optional*):
|
| 719 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 720 |
+
tensors for more detail.
|
| 721 |
+
output_hidden_states (`bool`, *optional*):
|
| 722 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 723 |
+
more detail.
|
| 724 |
+
return_dict (`bool`, *optional*):
|
| 725 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 726 |
+
"""
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
@add_start_docstrings(
|
| 730 |
+
"The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
| 731 |
+
DEBERTA_START_DOCSTRING,
|
| 732 |
+
)
|
| 733 |
+
class DebertaModel(DebertaPreTrainedModel):
|
| 734 |
+
def __init__(self, config):
|
| 735 |
+
super().__init__(config)
|
| 736 |
+
|
| 737 |
+
self.embeddings = DebertaEmbeddings(config)
|
| 738 |
+
self.encoder = DebertaEncoder(config)
|
| 739 |
+
self.z_steps = 0
|
| 740 |
+
self.config = config
|
| 741 |
+
# Initialize weights and apply final processing
|
| 742 |
+
self.post_init()
|
| 743 |
+
|
| 744 |
+
def get_input_embeddings(self):
|
| 745 |
+
return self.embeddings.word_embeddings
|
| 746 |
+
|
| 747 |
+
def set_input_embeddings(self, new_embeddings):
|
| 748 |
+
self.embeddings.word_embeddings = new_embeddings
|
| 749 |
+
|
| 750 |
+
def _prune_heads(self, heads_to_prune):
|
| 751 |
+
"""
|
| 752 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 753 |
+
class PreTrainedModel
|
| 754 |
+
"""
|
| 755 |
+
raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
|
| 756 |
+
|
| 757 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 758 |
+
@add_code_sample_docstrings(
|
| 759 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 760 |
+
output_type=BaseModelOutput,
|
| 761 |
+
config_class=_CONFIG_FOR_DOC,
|
| 762 |
+
)
|
| 763 |
+
def forward(
|
| 764 |
+
self,
|
| 765 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 766 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 767 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 768 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 769 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 770 |
+
output_attentions: Optional[bool] = None,
|
| 771 |
+
output_hidden_states: Optional[bool] = None,
|
| 772 |
+
return_dict: Optional[bool] = None,
|
| 773 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 774 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 775 |
+
output_hidden_states = (
|
| 776 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 777 |
+
)
|
| 778 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 779 |
+
|
| 780 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 781 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 782 |
+
elif input_ids is not None:
|
| 783 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 784 |
+
input_shape = input_ids.size()
|
| 785 |
+
elif inputs_embeds is not None:
|
| 786 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 787 |
+
else:
|
| 788 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 789 |
+
|
| 790 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 791 |
+
|
| 792 |
+
if attention_mask is None:
|
| 793 |
+
attention_mask = torch.ones(input_shape, device=device)
|
| 794 |
+
if token_type_ids is None:
|
| 795 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 796 |
+
|
| 797 |
+
embedding_output = self.embeddings(
|
| 798 |
+
input_ids=input_ids,
|
| 799 |
+
token_type_ids=token_type_ids,
|
| 800 |
+
position_ids=position_ids,
|
| 801 |
+
mask=attention_mask,
|
| 802 |
+
inputs_embeds=inputs_embeds,
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
encoder_outputs = self.encoder(
|
| 806 |
+
embedding_output,
|
| 807 |
+
attention_mask,
|
| 808 |
+
output_hidden_states=True,
|
| 809 |
+
output_attentions=output_attentions,
|
| 810 |
+
return_dict=return_dict,
|
| 811 |
+
)
|
| 812 |
+
encoded_layers = encoder_outputs[1]
|
| 813 |
+
|
| 814 |
+
if self.z_steps > 1:
|
| 815 |
+
hidden_states = encoded_layers[-2]
|
| 816 |
+
layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
|
| 817 |
+
query_states = encoded_layers[-1]
|
| 818 |
+
rel_embeddings = self.encoder.get_rel_embedding()
|
| 819 |
+
attention_mask = self.encoder.get_attention_mask(attention_mask)
|
| 820 |
+
rel_pos = self.encoder.get_rel_pos(embedding_output)
|
| 821 |
+
for layer in layers[1:]:
|
| 822 |
+
query_states = layer(
|
| 823 |
+
hidden_states,
|
| 824 |
+
attention_mask,
|
| 825 |
+
output_attentions=False,
|
| 826 |
+
query_states=query_states,
|
| 827 |
+
relative_pos=rel_pos,
|
| 828 |
+
rel_embeddings=rel_embeddings,
|
| 829 |
+
)
|
| 830 |
+
encoded_layers.append(query_states)
|
| 831 |
+
|
| 832 |
+
sequence_output = encoded_layers[-1]
|
| 833 |
+
|
| 834 |
+
if not return_dict:
|
| 835 |
+
return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
|
| 836 |
+
|
| 837 |
+
return BaseModelOutput(
|
| 838 |
+
last_hidden_state=sequence_output,
|
| 839 |
+
hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
|
| 840 |
+
attentions=encoder_outputs.attentions,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
class LegacyDebertaPredictionHeadTransform(nn.Module):
|
| 845 |
+
def __init__(self, config):
|
| 846 |
+
super().__init__()
|
| 847 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 848 |
+
|
| 849 |
+
self.dense = nn.Linear(config.hidden_size, self.embedding_size)
|
| 850 |
+
if isinstance(config.hidden_act, str):
|
| 851 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 852 |
+
else:
|
| 853 |
+
self.transform_act_fn = config.hidden_act
|
| 854 |
+
self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
|
| 855 |
+
|
| 856 |
+
def forward(self, hidden_states):
|
| 857 |
+
hidden_states = self.dense(hidden_states)
|
| 858 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 859 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 860 |
+
return hidden_states
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
class LegacyDebertaLMPredictionHead(nn.Module):
|
| 864 |
+
def __init__(self, config):
|
| 865 |
+
super().__init__()
|
| 866 |
+
self.transform = LegacyDebertaPredictionHeadTransform(config)
|
| 867 |
+
|
| 868 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 869 |
+
# The output weights are the same as the input embeddings, but there is
|
| 870 |
+
# an output-only bias for each token.
|
| 871 |
+
self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
|
| 872 |
+
|
| 873 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 874 |
+
|
| 875 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 876 |
+
self.decoder.bias = self.bias
|
| 877 |
+
|
| 878 |
+
def _tie_weights(self):
|
| 879 |
+
self.decoder.bias = self.bias
|
| 880 |
+
|
| 881 |
+
def forward(self, hidden_states):
|
| 882 |
+
hidden_states = self.transform(hidden_states)
|
| 883 |
+
hidden_states = self.decoder(hidden_states)
|
| 884 |
+
return hidden_states
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LegacyDeberta
|
| 888 |
+
class LegacyDebertaOnlyMLMHead(nn.Module):
|
| 889 |
+
def __init__(self, config):
|
| 890 |
+
super().__init__()
|
| 891 |
+
self.predictions = LegacyDebertaLMPredictionHead(config)
|
| 892 |
+
|
| 893 |
+
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
| 894 |
+
prediction_scores = self.predictions(sequence_output)
|
| 895 |
+
return prediction_scores
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
class DebertaLMPredictionHead(nn.Module):
|
| 899 |
+
"""https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270"""
|
| 900 |
+
|
| 901 |
+
def __init__(self, config):
|
| 902 |
+
super().__init__()
|
| 903 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 904 |
+
|
| 905 |
+
if isinstance(config.hidden_act, str):
|
| 906 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 907 |
+
else:
|
| 908 |
+
self.transform_act_fn = config.hidden_act
|
| 909 |
+
|
| 910 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True)
|
| 911 |
+
|
| 912 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 913 |
+
|
| 914 |
+
# note that the input embeddings must be passed as an argument
|
| 915 |
+
def forward(self, hidden_states, word_embeddings):
|
| 916 |
+
hidden_states = self.dense(hidden_states)
|
| 917 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 918 |
+
hidden_states = self.LayerNorm(
|
| 919 |
+
hidden_states
|
| 920 |
+
) # original used MaskedLayerNorm, but passed no mask. This is equivalent.
|
| 921 |
+
hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
|
| 922 |
+
return hidden_states
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
class DebertaOnlyMLMHead(nn.Module):
|
| 926 |
+
def __init__(self, config):
|
| 927 |
+
super().__init__()
|
| 928 |
+
self.lm_head = DebertaLMPredictionHead(config)
|
| 929 |
+
|
| 930 |
+
# note that the input embeddings must be passed as an argument
|
| 931 |
+
def forward(self, sequence_output, word_embeddings):
|
| 932 |
+
prediction_scores = self.lm_head(sequence_output, word_embeddings)
|
| 933 |
+
return prediction_scores
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
| 937 |
+
class DebertaForMaskedLM(DebertaPreTrainedModel):
|
| 938 |
+
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
| 939 |
+
|
| 940 |
+
def __init__(self, config):
|
| 941 |
+
super().__init__(config)
|
| 942 |
+
self.legacy = config.legacy
|
| 943 |
+
self.deberta = DebertaModel(config)
|
| 944 |
+
if self.legacy:
|
| 945 |
+
self.cls = LegacyDebertaOnlyMLMHead(config)
|
| 946 |
+
else:
|
| 947 |
+
self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"]
|
| 948 |
+
self.lm_predictions = DebertaOnlyMLMHead(config)
|
| 949 |
+
|
| 950 |
+
# Initialize weights and apply final processing
|
| 951 |
+
self.post_init()
|
| 952 |
+
|
| 953 |
+
def get_output_embeddings(self):
|
| 954 |
+
if self.legacy:
|
| 955 |
+
return self.cls.predictions.decoder
|
| 956 |
+
else:
|
| 957 |
+
return self.lm_predictions.lm_head.dense
|
| 958 |
+
|
| 959 |
+
def set_output_embeddings(self, new_embeddings):
|
| 960 |
+
if self.legacy:
|
| 961 |
+
self.cls.predictions.decoder = new_embeddings
|
| 962 |
+
self.cls.predictions.bias = new_embeddings.bias
|
| 963 |
+
else:
|
| 964 |
+
self.lm_predictions.lm_head.dense = new_embeddings
|
| 965 |
+
self.lm_predictions.lm_head.bias = new_embeddings.bias
|
| 966 |
+
|
| 967 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 968 |
+
@add_code_sample_docstrings(
|
| 969 |
+
checkpoint=_CHECKPOINT_FOR_MASKED_LM,
|
| 970 |
+
output_type=MaskedLMOutput,
|
| 971 |
+
config_class=_CONFIG_FOR_DOC,
|
| 972 |
+
mask="[MASK]",
|
| 973 |
+
expected_output=_MASKED_LM_EXPECTED_OUTPUT,
|
| 974 |
+
expected_loss=_MASKED_LM_EXPECTED_LOSS,
|
| 975 |
+
)
|
| 976 |
+
def forward(
|
| 977 |
+
self,
|
| 978 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 979 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 980 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 981 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 982 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 983 |
+
labels: Optional[torch.Tensor] = None,
|
| 984 |
+
output_attentions: Optional[bool] = None,
|
| 985 |
+
output_hidden_states: Optional[bool] = None,
|
| 986 |
+
return_dict: Optional[bool] = None,
|
| 987 |
+
) -> Union[Tuple, MaskedLMOutput]:
|
| 988 |
+
r"""
|
| 989 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 990 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 991 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 992 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 993 |
+
"""
|
| 994 |
+
|
| 995 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 996 |
+
|
| 997 |
+
outputs = self.deberta(
|
| 998 |
+
input_ids,
|
| 999 |
+
attention_mask=attention_mask,
|
| 1000 |
+
token_type_ids=token_type_ids,
|
| 1001 |
+
position_ids=position_ids,
|
| 1002 |
+
inputs_embeds=inputs_embeds,
|
| 1003 |
+
output_attentions=output_attentions,
|
| 1004 |
+
output_hidden_states=output_hidden_states,
|
| 1005 |
+
return_dict=return_dict,
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
sequence_output = outputs[0]
|
| 1009 |
+
if self.legacy:
|
| 1010 |
+
prediction_scores = self.cls(sequence_output)
|
| 1011 |
+
else:
|
| 1012 |
+
prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)
|
| 1013 |
+
|
| 1014 |
+
masked_lm_loss = None
|
| 1015 |
+
if labels is not None:
|
| 1016 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
| 1017 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1018 |
+
|
| 1019 |
+
if not return_dict:
|
| 1020 |
+
output = (prediction_scores,) + outputs[1:]
|
| 1021 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 1022 |
+
|
| 1023 |
+
return MaskedLMOutput(
|
| 1024 |
+
loss=masked_lm_loss,
|
| 1025 |
+
logits=prediction_scores,
|
| 1026 |
+
hidden_states=outputs.hidden_states,
|
| 1027 |
+
attentions=outputs.attentions,
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
class ContextPooler(nn.Module):
|
| 1032 |
+
def __init__(self, config):
|
| 1033 |
+
super().__init__()
|
| 1034 |
+
self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
|
| 1035 |
+
self.dropout = nn.Dropout(config.pooler_dropout)
|
| 1036 |
+
self.config = config
|
| 1037 |
+
|
| 1038 |
+
def forward(self, hidden_states):
|
| 1039 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 1040 |
+
# to the first token.
|
| 1041 |
+
|
| 1042 |
+
context_token = hidden_states[:, 0]
|
| 1043 |
+
context_token = self.dropout(context_token)
|
| 1044 |
+
pooled_output = self.dense(context_token)
|
| 1045 |
+
pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
|
| 1046 |
+
return pooled_output
|
| 1047 |
+
|
| 1048 |
+
@property
|
| 1049 |
+
def output_dim(self):
|
| 1050 |
+
return self.config.hidden_size
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
@add_start_docstrings(
|
| 1054 |
+
"""
|
| 1055 |
+
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
| 1056 |
+
pooled output) e.g. for GLUE tasks.
|
| 1057 |
+
""",
|
| 1058 |
+
DEBERTA_START_DOCSTRING,
|
| 1059 |
+
)
|
| 1060 |
+
class DebertaForSequenceClassification(DebertaPreTrainedModel):
|
| 1061 |
+
def __init__(self, config):
|
| 1062 |
+
super().__init__(config)
|
| 1063 |
+
|
| 1064 |
+
num_labels = getattr(config, "num_labels", 2)
|
| 1065 |
+
self.num_labels = num_labels
|
| 1066 |
+
|
| 1067 |
+
self.deberta = DebertaModel(config)
|
| 1068 |
+
self.pooler = ContextPooler(config)
|
| 1069 |
+
output_dim = self.pooler.output_dim
|
| 1070 |
+
|
| 1071 |
+
self.classifier = nn.Linear(output_dim, num_labels)
|
| 1072 |
+
drop_out = getattr(config, "cls_dropout", None)
|
| 1073 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
| 1074 |
+
self.dropout = nn.Dropout(drop_out)
|
| 1075 |
+
|
| 1076 |
+
# Initialize weights and apply final processing
|
| 1077 |
+
self.post_init()
|
| 1078 |
+
|
| 1079 |
+
def get_input_embeddings(self):
|
| 1080 |
+
return self.deberta.get_input_embeddings()
|
| 1081 |
+
|
| 1082 |
+
def set_input_embeddings(self, new_embeddings):
|
| 1083 |
+
self.deberta.set_input_embeddings(new_embeddings)
|
| 1084 |
+
|
| 1085 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1086 |
+
@add_code_sample_docstrings(
|
| 1087 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1088 |
+
output_type=SequenceClassifierOutput,
|
| 1089 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1090 |
+
)
|
| 1091 |
+
def forward(
|
| 1092 |
+
self,
|
| 1093 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1094 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1095 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1096 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1097 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1098 |
+
labels: Optional[torch.Tensor] = None,
|
| 1099 |
+
output_attentions: Optional[bool] = None,
|
| 1100 |
+
output_hidden_states: Optional[bool] = None,
|
| 1101 |
+
return_dict: Optional[bool] = None,
|
| 1102 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1103 |
+
r"""
|
| 1104 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1105 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1106 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1107 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1108 |
+
"""
|
| 1109 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1110 |
+
|
| 1111 |
+
outputs = self.deberta(
|
| 1112 |
+
input_ids,
|
| 1113 |
+
token_type_ids=token_type_ids,
|
| 1114 |
+
attention_mask=attention_mask,
|
| 1115 |
+
position_ids=position_ids,
|
| 1116 |
+
inputs_embeds=inputs_embeds,
|
| 1117 |
+
output_attentions=output_attentions,
|
| 1118 |
+
output_hidden_states=output_hidden_states,
|
| 1119 |
+
return_dict=return_dict,
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
encoder_layer = outputs[0]
|
| 1123 |
+
pooled_output = self.pooler(encoder_layer)
|
| 1124 |
+
pooled_output = self.dropout(pooled_output)
|
| 1125 |
+
logits = self.classifier(pooled_output)
|
| 1126 |
+
|
| 1127 |
+
loss = None
|
| 1128 |
+
if labels is not None:
|
| 1129 |
+
if self.config.problem_type is None:
|
| 1130 |
+
if self.num_labels == 1:
|
| 1131 |
+
# regression task
|
| 1132 |
+
loss_fn = nn.MSELoss()
|
| 1133 |
+
logits = logits.view(-1).to(labels.dtype)
|
| 1134 |
+
loss = loss_fn(logits, labels.view(-1))
|
| 1135 |
+
elif labels.dim() == 1 or labels.size(-1) == 1:
|
| 1136 |
+
label_index = (labels >= 0).nonzero()
|
| 1137 |
+
labels = labels.long()
|
| 1138 |
+
if label_index.size(0) > 0:
|
| 1139 |
+
labeled_logits = torch.gather(
|
| 1140 |
+
logits, 0, label_index.expand(label_index.size(0), logits.size(1))
|
| 1141 |
+
)
|
| 1142 |
+
labels = torch.gather(labels, 0, label_index.view(-1))
|
| 1143 |
+
loss_fct = CrossEntropyLoss()
|
| 1144 |
+
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
|
| 1145 |
+
else:
|
| 1146 |
+
loss = torch.tensor(0).to(logits)
|
| 1147 |
+
else:
|
| 1148 |
+
log_softmax = nn.LogSoftmax(-1)
|
| 1149 |
+
loss = -((log_softmax(logits) * labels).sum(-1)).mean()
|
| 1150 |
+
elif self.config.problem_type == "regression":
|
| 1151 |
+
loss_fct = MSELoss()
|
| 1152 |
+
if self.num_labels == 1:
|
| 1153 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 1154 |
+
else:
|
| 1155 |
+
loss = loss_fct(logits, labels)
|
| 1156 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1157 |
+
loss_fct = CrossEntropyLoss()
|
| 1158 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1159 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1160 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1161 |
+
loss = loss_fct(logits, labels)
|
| 1162 |
+
if not return_dict:
|
| 1163 |
+
output = (logits,) + outputs[1:]
|
| 1164 |
+
return ((loss,) + output) if loss is not None else output
|
| 1165 |
+
|
| 1166 |
+
return SequenceClassifierOutput(
|
| 1167 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
@add_start_docstrings(
|
| 1172 |
+
"""
|
| 1173 |
+
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1174 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1175 |
+
""",
|
| 1176 |
+
DEBERTA_START_DOCSTRING,
|
| 1177 |
+
)
|
| 1178 |
+
class DebertaForTokenClassification(DebertaPreTrainedModel):
|
| 1179 |
+
def __init__(self, config):
|
| 1180 |
+
super().__init__(config)
|
| 1181 |
+
self.num_labels = config.num_labels
|
| 1182 |
+
|
| 1183 |
+
self.deberta = DebertaModel(config)
|
| 1184 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1185 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1186 |
+
|
| 1187 |
+
# Initialize weights and apply final processing
|
| 1188 |
+
self.post_init()
|
| 1189 |
+
|
| 1190 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1191 |
+
@add_code_sample_docstrings(
|
| 1192 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1193 |
+
output_type=TokenClassifierOutput,
|
| 1194 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1195 |
+
)
|
| 1196 |
+
def forward(
|
| 1197 |
+
self,
|
| 1198 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1199 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1200 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1201 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1202 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1203 |
+
labels: Optional[torch.Tensor] = None,
|
| 1204 |
+
output_attentions: Optional[bool] = None,
|
| 1205 |
+
output_hidden_states: Optional[bool] = None,
|
| 1206 |
+
return_dict: Optional[bool] = None,
|
| 1207 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1208 |
+
r"""
|
| 1209 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1210 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1211 |
+
"""
|
| 1212 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1213 |
+
|
| 1214 |
+
outputs = self.deberta(
|
| 1215 |
+
input_ids,
|
| 1216 |
+
attention_mask=attention_mask,
|
| 1217 |
+
token_type_ids=token_type_ids,
|
| 1218 |
+
position_ids=position_ids,
|
| 1219 |
+
inputs_embeds=inputs_embeds,
|
| 1220 |
+
output_attentions=output_attentions,
|
| 1221 |
+
output_hidden_states=output_hidden_states,
|
| 1222 |
+
return_dict=return_dict,
|
| 1223 |
+
)
|
| 1224 |
+
|
| 1225 |
+
sequence_output = outputs[0]
|
| 1226 |
+
|
| 1227 |
+
sequence_output = self.dropout(sequence_output)
|
| 1228 |
+
logits = self.classifier(sequence_output)
|
| 1229 |
+
|
| 1230 |
+
loss = None
|
| 1231 |
+
if labels is not None:
|
| 1232 |
+
loss_fct = CrossEntropyLoss()
|
| 1233 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1234 |
+
|
| 1235 |
+
if not return_dict:
|
| 1236 |
+
output = (logits,) + outputs[1:]
|
| 1237 |
+
return ((loss,) + output) if loss is not None else output
|
| 1238 |
+
|
| 1239 |
+
return TokenClassifierOutput(
|
| 1240 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1241 |
+
)
|
| 1242 |
+
|
| 1243 |
+
|
| 1244 |
+
@add_start_docstrings(
|
| 1245 |
+
"""
|
| 1246 |
+
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1247 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1248 |
+
""",
|
| 1249 |
+
DEBERTA_START_DOCSTRING,
|
| 1250 |
+
)
|
| 1251 |
+
class DebertaForQuestionAnswering(DebertaPreTrainedModel):
|
| 1252 |
+
def __init__(self, config):
|
| 1253 |
+
super().__init__(config)
|
| 1254 |
+
self.num_labels = config.num_labels
|
| 1255 |
+
|
| 1256 |
+
self.deberta = DebertaModel(config)
|
| 1257 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 1258 |
+
|
| 1259 |
+
# Initialize weights and apply final processing
|
| 1260 |
+
self.post_init()
|
| 1261 |
+
|
| 1262 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1263 |
+
@add_code_sample_docstrings(
|
| 1264 |
+
checkpoint=_CHECKPOINT_FOR_QA,
|
| 1265 |
+
output_type=QuestionAnsweringModelOutput,
|
| 1266 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1267 |
+
expected_output=_QA_EXPECTED_OUTPUT,
|
| 1268 |
+
expected_loss=_QA_EXPECTED_LOSS,
|
| 1269 |
+
qa_target_start_index=_QA_TARGET_START_INDEX,
|
| 1270 |
+
qa_target_end_index=_QA_TARGET_END_INDEX,
|
| 1271 |
+
)
|
| 1272 |
+
def forward(
|
| 1273 |
+
self,
|
| 1274 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1275 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1276 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1277 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1278 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1279 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 1280 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 1281 |
+
output_attentions: Optional[bool] = None,
|
| 1282 |
+
output_hidden_states: Optional[bool] = None,
|
| 1283 |
+
return_dict: Optional[bool] = None,
|
| 1284 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
| 1285 |
+
r"""
|
| 1286 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1287 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1288 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1289 |
+
are not taken into account for computing the loss.
|
| 1290 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1291 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1292 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1293 |
+
are not taken into account for computing the loss.
|
| 1294 |
+
"""
|
| 1295 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1296 |
+
|
| 1297 |
+
outputs = self.deberta(
|
| 1298 |
+
input_ids,
|
| 1299 |
+
attention_mask=attention_mask,
|
| 1300 |
+
token_type_ids=token_type_ids,
|
| 1301 |
+
position_ids=position_ids,
|
| 1302 |
+
inputs_embeds=inputs_embeds,
|
| 1303 |
+
output_attentions=output_attentions,
|
| 1304 |
+
output_hidden_states=output_hidden_states,
|
| 1305 |
+
return_dict=return_dict,
|
| 1306 |
+
)
|
| 1307 |
+
|
| 1308 |
+
sequence_output = outputs[0]
|
| 1309 |
+
|
| 1310 |
+
logits = self.qa_outputs(sequence_output)
|
| 1311 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1312 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1313 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1314 |
+
|
| 1315 |
+
total_loss = None
|
| 1316 |
+
if start_positions is not None and end_positions is not None:
|
| 1317 |
+
# If we are on multi-GPU, split add a dimension
|
| 1318 |
+
if len(start_positions.size()) > 1:
|
| 1319 |
+
start_positions = start_positions.squeeze(-1)
|
| 1320 |
+
if len(end_positions.size()) > 1:
|
| 1321 |
+
end_positions = end_positions.squeeze(-1)
|
| 1322 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1323 |
+
ignored_index = start_logits.size(1)
|
| 1324 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1325 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1326 |
+
|
| 1327 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1328 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1329 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1330 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1331 |
+
|
| 1332 |
+
if not return_dict:
|
| 1333 |
+
output = (start_logits, end_logits) + outputs[1:]
|
| 1334 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1335 |
+
|
| 1336 |
+
return QuestionAnsweringModelOutput(
|
| 1337 |
+
loss=total_loss,
|
| 1338 |
+
start_logits=start_logits,
|
| 1339 |
+
end_logits=end_logits,
|
| 1340 |
+
hidden_states=outputs.hidden_states,
|
| 1341 |
+
attentions=outputs.attentions,
|
| 1342 |
+
)
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
__all__ = [
|
| 1346 |
+
"DebertaForMaskedLM",
|
| 1347 |
+
"DebertaForQuestionAnswering",
|
| 1348 |
+
"DebertaForSequenceClassification",
|
| 1349 |
+
"DebertaForTokenClassification",
|
| 1350 |
+
"DebertaModel",
|
| 1351 |
+
"DebertaPreTrainedModel",
|
| 1352 |
+
]
|
docs/transformers/build/lib/transformers/models/deberta/modeling_tf_deberta.py
ADDED
|
@@ -0,0 +1,1652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TF 2.0 DeBERTa model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from typing import Dict, Optional, Sequence, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tensorflow as tf
|
| 24 |
+
|
| 25 |
+
from ...activations_tf import get_tf_activation
|
| 26 |
+
from ...modeling_tf_outputs import (
|
| 27 |
+
TFBaseModelOutput,
|
| 28 |
+
TFMaskedLMOutput,
|
| 29 |
+
TFQuestionAnsweringModelOutput,
|
| 30 |
+
TFSequenceClassifierOutput,
|
| 31 |
+
TFTokenClassifierOutput,
|
| 32 |
+
)
|
| 33 |
+
from ...modeling_tf_utils import (
|
| 34 |
+
TFMaskedLanguageModelingLoss,
|
| 35 |
+
TFModelInputType,
|
| 36 |
+
TFPreTrainedModel,
|
| 37 |
+
TFQuestionAnsweringLoss,
|
| 38 |
+
TFSequenceClassificationLoss,
|
| 39 |
+
TFTokenClassificationLoss,
|
| 40 |
+
get_initializer,
|
| 41 |
+
keras,
|
| 42 |
+
unpack_inputs,
|
| 43 |
+
)
|
| 44 |
+
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
|
| 45 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 46 |
+
from .configuration_deberta import DebertaConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
_CONFIG_FOR_DOC = "DebertaConfig"
|
| 53 |
+
_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-base"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TFDebertaContextPooler(keras.layers.Layer):
|
| 57 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 58 |
+
super().__init__(**kwargs)
|
| 59 |
+
self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense")
|
| 60 |
+
self.dropout = TFDebertaStableDropout(config.pooler_dropout, name="dropout")
|
| 61 |
+
self.config = config
|
| 62 |
+
|
| 63 |
+
def call(self, hidden_states, training: bool = False):
|
| 64 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 65 |
+
# to the first token.
|
| 66 |
+
context_token = hidden_states[:, 0]
|
| 67 |
+
context_token = self.dropout(context_token, training=training)
|
| 68 |
+
pooled_output = self.dense(context_token)
|
| 69 |
+
pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output)
|
| 70 |
+
return pooled_output
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def output_dim(self) -> int:
|
| 74 |
+
return self.config.hidden_size
|
| 75 |
+
|
| 76 |
+
def build(self, input_shape=None):
|
| 77 |
+
if self.built:
|
| 78 |
+
return
|
| 79 |
+
self.built = True
|
| 80 |
+
if getattr(self, "dense", None) is not None:
|
| 81 |
+
with tf.name_scope(self.dense.name):
|
| 82 |
+
self.dense.build([None, None, self.config.pooler_hidden_size])
|
| 83 |
+
if getattr(self, "dropout", None) is not None:
|
| 84 |
+
with tf.name_scope(self.dropout.name):
|
| 85 |
+
self.dropout.build(None)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TFDebertaXSoftmax(keras.layers.Layer):
|
| 89 |
+
"""
|
| 90 |
+
Masked Softmax which is optimized for saving memory
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
input (`tf.Tensor`): The input tensor that will apply softmax.
|
| 94 |
+
mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
|
| 95 |
+
dim (int): The dimension that will apply softmax
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(self, axis=-1, **kwargs):
|
| 99 |
+
super().__init__(**kwargs)
|
| 100 |
+
self.axis = axis
|
| 101 |
+
|
| 102 |
+
def call(self, inputs: tf.Tensor, mask: tf.Tensor):
|
| 103 |
+
rmask = tf.logical_not(tf.cast(mask, tf.bool))
|
| 104 |
+
output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)
|
| 105 |
+
output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)
|
| 106 |
+
output = tf.where(rmask, 0.0, output)
|
| 107 |
+
return output
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class TFDebertaStableDropout(keras.layers.Layer):
|
| 111 |
+
"""
|
| 112 |
+
Optimized dropout module for stabilizing the training
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
drop_prob (float): the dropout probabilities
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self, drop_prob, **kwargs):
|
| 119 |
+
super().__init__(**kwargs)
|
| 120 |
+
self.drop_prob = drop_prob
|
| 121 |
+
|
| 122 |
+
@tf.custom_gradient
|
| 123 |
+
def xdropout(self, inputs):
|
| 124 |
+
"""
|
| 125 |
+
Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
|
| 126 |
+
"""
|
| 127 |
+
mask = tf.cast(
|
| 128 |
+
1
|
| 129 |
+
- tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
|
| 130 |
+
tf.bool,
|
| 131 |
+
)
|
| 132 |
+
scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)
|
| 133 |
+
if self.drop_prob > 0:
|
| 134 |
+
inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale
|
| 135 |
+
|
| 136 |
+
def grad(upstream):
|
| 137 |
+
if self.drop_prob > 0:
|
| 138 |
+
return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale
|
| 139 |
+
else:
|
| 140 |
+
return upstream
|
| 141 |
+
|
| 142 |
+
return inputs, grad
|
| 143 |
+
|
| 144 |
+
def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
|
| 145 |
+
if training:
|
| 146 |
+
return self.xdropout(inputs)
|
| 147 |
+
return inputs
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class TFDebertaLayerNorm(keras.layers.Layer):
|
| 151 |
+
"""LayerNorm module in the TF style (epsilon inside the square root)."""
|
| 152 |
+
|
| 153 |
+
def __init__(self, size, eps=1e-12, **kwargs):
|
| 154 |
+
super().__init__(**kwargs)
|
| 155 |
+
self.size = size
|
| 156 |
+
self.eps = eps
|
| 157 |
+
|
| 158 |
+
def build(self, input_shape):
|
| 159 |
+
self.gamma = self.add_weight(shape=[self.size], initializer=tf.ones_initializer(), name="weight")
|
| 160 |
+
self.beta = self.add_weight(shape=[self.size], initializer=tf.zeros_initializer(), name="bias")
|
| 161 |
+
return super().build(input_shape)
|
| 162 |
+
|
| 163 |
+
def call(self, x: tf.Tensor) -> tf.Tensor:
|
| 164 |
+
mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
|
| 165 |
+
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
|
| 166 |
+
std = tf.math.sqrt(variance + self.eps)
|
| 167 |
+
return self.gamma * (x - mean) / std + self.beta
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class TFDebertaSelfOutput(keras.layers.Layer):
|
| 171 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 172 |
+
super().__init__(**kwargs)
|
| 173 |
+
self.dense = keras.layers.Dense(config.hidden_size, name="dense")
|
| 174 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 175 |
+
self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
|
| 176 |
+
self.config = config
|
| 177 |
+
|
| 178 |
+
def call(self, hidden_states, input_tensor, training: bool = False):
|
| 179 |
+
hidden_states = self.dense(hidden_states)
|
| 180 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 181 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 182 |
+
return hidden_states
|
| 183 |
+
|
| 184 |
+
def build(self, input_shape=None):
|
| 185 |
+
if self.built:
|
| 186 |
+
return
|
| 187 |
+
self.built = True
|
| 188 |
+
if getattr(self, "dense", None) is not None:
|
| 189 |
+
with tf.name_scope(self.dense.name):
|
| 190 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 191 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 192 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 193 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 194 |
+
if getattr(self, "dropout", None) is not None:
|
| 195 |
+
with tf.name_scope(self.dropout.name):
|
| 196 |
+
self.dropout.build(None)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class TFDebertaAttention(keras.layers.Layer):
|
| 200 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 201 |
+
super().__init__(**kwargs)
|
| 202 |
+
self.self = TFDebertaDisentangledSelfAttention(config, name="self")
|
| 203 |
+
self.dense_output = TFDebertaSelfOutput(config, name="output")
|
| 204 |
+
self.config = config
|
| 205 |
+
|
| 206 |
+
def call(
|
| 207 |
+
self,
|
| 208 |
+
input_tensor: tf.Tensor,
|
| 209 |
+
attention_mask: tf.Tensor,
|
| 210 |
+
query_states: Optional[tf.Tensor] = None,
|
| 211 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 212 |
+
rel_embeddings: Optional[tf.Tensor] = None,
|
| 213 |
+
output_attentions: bool = False,
|
| 214 |
+
training: bool = False,
|
| 215 |
+
) -> Tuple[tf.Tensor]:
|
| 216 |
+
self_outputs = self.self(
|
| 217 |
+
hidden_states=input_tensor,
|
| 218 |
+
attention_mask=attention_mask,
|
| 219 |
+
query_states=query_states,
|
| 220 |
+
relative_pos=relative_pos,
|
| 221 |
+
rel_embeddings=rel_embeddings,
|
| 222 |
+
output_attentions=output_attentions,
|
| 223 |
+
training=training,
|
| 224 |
+
)
|
| 225 |
+
if query_states is None:
|
| 226 |
+
query_states = input_tensor
|
| 227 |
+
attention_output = self.dense_output(
|
| 228 |
+
hidden_states=self_outputs[0], input_tensor=query_states, training=training
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
output = (attention_output,) + self_outputs[1:]
|
| 232 |
+
|
| 233 |
+
return output
|
| 234 |
+
|
| 235 |
+
def build(self, input_shape=None):
|
| 236 |
+
if self.built:
|
| 237 |
+
return
|
| 238 |
+
self.built = True
|
| 239 |
+
if getattr(self, "self", None) is not None:
|
| 240 |
+
with tf.name_scope(self.self.name):
|
| 241 |
+
self.self.build(None)
|
| 242 |
+
if getattr(self, "dense_output", None) is not None:
|
| 243 |
+
with tf.name_scope(self.dense_output.name):
|
| 244 |
+
self.dense_output.build(None)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class TFDebertaIntermediate(keras.layers.Layer):
|
| 248 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 249 |
+
super().__init__(**kwargs)
|
| 250 |
+
|
| 251 |
+
self.dense = keras.layers.Dense(
|
| 252 |
+
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if isinstance(config.hidden_act, str):
|
| 256 |
+
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
| 257 |
+
else:
|
| 258 |
+
self.intermediate_act_fn = config.hidden_act
|
| 259 |
+
self.config = config
|
| 260 |
+
|
| 261 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 262 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 263 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 264 |
+
|
| 265 |
+
return hidden_states
|
| 266 |
+
|
| 267 |
+
def build(self, input_shape=None):
|
| 268 |
+
if self.built:
|
| 269 |
+
return
|
| 270 |
+
self.built = True
|
| 271 |
+
if getattr(self, "dense", None) is not None:
|
| 272 |
+
with tf.name_scope(self.dense.name):
|
| 273 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class TFDebertaOutput(keras.layers.Layer):
|
| 277 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 278 |
+
super().__init__(**kwargs)
|
| 279 |
+
|
| 280 |
+
self.dense = keras.layers.Dense(
|
| 281 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 282 |
+
)
|
| 283 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 284 |
+
self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
|
| 285 |
+
self.config = config
|
| 286 |
+
|
| 287 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 288 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 289 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 290 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 291 |
+
|
| 292 |
+
return hidden_states
|
| 293 |
+
|
| 294 |
+
def build(self, input_shape=None):
|
| 295 |
+
if self.built:
|
| 296 |
+
return
|
| 297 |
+
self.built = True
|
| 298 |
+
if getattr(self, "dense", None) is not None:
|
| 299 |
+
with tf.name_scope(self.dense.name):
|
| 300 |
+
self.dense.build([None, None, self.config.intermediate_size])
|
| 301 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 302 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 303 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 304 |
+
if getattr(self, "dropout", None) is not None:
|
| 305 |
+
with tf.name_scope(self.dropout.name):
|
| 306 |
+
self.dropout.build(None)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class TFDebertaLayer(keras.layers.Layer):
|
| 310 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 311 |
+
super().__init__(**kwargs)
|
| 312 |
+
|
| 313 |
+
self.attention = TFDebertaAttention(config, name="attention")
|
| 314 |
+
self.intermediate = TFDebertaIntermediate(config, name="intermediate")
|
| 315 |
+
self.bert_output = TFDebertaOutput(config, name="output")
|
| 316 |
+
|
| 317 |
+
def call(
|
| 318 |
+
self,
|
| 319 |
+
hidden_states: tf.Tensor,
|
| 320 |
+
attention_mask: tf.Tensor,
|
| 321 |
+
query_states: Optional[tf.Tensor] = None,
|
| 322 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 323 |
+
rel_embeddings: Optional[tf.Tensor] = None,
|
| 324 |
+
output_attentions: bool = False,
|
| 325 |
+
training: bool = False,
|
| 326 |
+
) -> Tuple[tf.Tensor]:
|
| 327 |
+
attention_outputs = self.attention(
|
| 328 |
+
input_tensor=hidden_states,
|
| 329 |
+
attention_mask=attention_mask,
|
| 330 |
+
query_states=query_states,
|
| 331 |
+
relative_pos=relative_pos,
|
| 332 |
+
rel_embeddings=rel_embeddings,
|
| 333 |
+
output_attentions=output_attentions,
|
| 334 |
+
training=training,
|
| 335 |
+
)
|
| 336 |
+
attention_output = attention_outputs[0]
|
| 337 |
+
intermediate_output = self.intermediate(hidden_states=attention_output)
|
| 338 |
+
layer_output = self.bert_output(
|
| 339 |
+
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
| 340 |
+
)
|
| 341 |
+
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
| 342 |
+
|
| 343 |
+
return outputs
|
| 344 |
+
|
| 345 |
+
def build(self, input_shape=None):
|
| 346 |
+
if self.built:
|
| 347 |
+
return
|
| 348 |
+
self.built = True
|
| 349 |
+
if getattr(self, "attention", None) is not None:
|
| 350 |
+
with tf.name_scope(self.attention.name):
|
| 351 |
+
self.attention.build(None)
|
| 352 |
+
if getattr(self, "intermediate", None) is not None:
|
| 353 |
+
with tf.name_scope(self.intermediate.name):
|
| 354 |
+
self.intermediate.build(None)
|
| 355 |
+
if getattr(self, "bert_output", None) is not None:
|
| 356 |
+
with tf.name_scope(self.bert_output.name):
|
| 357 |
+
self.bert_output.build(None)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class TFDebertaEncoder(keras.layers.Layer):
|
| 361 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 362 |
+
super().__init__(**kwargs)
|
| 363 |
+
|
| 364 |
+
self.layer = [TFDebertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
| 365 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 366 |
+
self.config = config
|
| 367 |
+
if self.relative_attention:
|
| 368 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 369 |
+
if self.max_relative_positions < 1:
|
| 370 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 371 |
+
|
| 372 |
+
def build(self, input_shape=None):
|
| 373 |
+
if self.built:
|
| 374 |
+
return
|
| 375 |
+
self.built = True
|
| 376 |
+
if self.relative_attention:
|
| 377 |
+
self.rel_embeddings = self.add_weight(
|
| 378 |
+
name="rel_embeddings.weight",
|
| 379 |
+
shape=[self.max_relative_positions * 2, self.config.hidden_size],
|
| 380 |
+
initializer=get_initializer(self.config.initializer_range),
|
| 381 |
+
)
|
| 382 |
+
if getattr(self, "layer", None) is not None:
|
| 383 |
+
for layer in self.layer:
|
| 384 |
+
with tf.name_scope(layer.name):
|
| 385 |
+
layer.build(None)
|
| 386 |
+
|
| 387 |
+
def get_rel_embedding(self):
|
| 388 |
+
rel_embeddings = self.rel_embeddings if self.relative_attention else None
|
| 389 |
+
return rel_embeddings
|
| 390 |
+
|
| 391 |
+
def get_attention_mask(self, attention_mask):
|
| 392 |
+
if len(shape_list(attention_mask)) <= 2:
|
| 393 |
+
extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2)
|
| 394 |
+
attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1)
|
| 395 |
+
attention_mask = tf.cast(attention_mask, tf.uint8)
|
| 396 |
+
elif len(shape_list(attention_mask)) == 3:
|
| 397 |
+
attention_mask = tf.expand_dims(attention_mask, 1)
|
| 398 |
+
|
| 399 |
+
return attention_mask
|
| 400 |
+
|
| 401 |
+
def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
|
| 402 |
+
if self.relative_attention and relative_pos is None:
|
| 403 |
+
q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2]
|
| 404 |
+
relative_pos = build_relative_position(q, shape_list(hidden_states)[-2])
|
| 405 |
+
return relative_pos
|
| 406 |
+
|
| 407 |
+
def call(
|
| 408 |
+
self,
|
| 409 |
+
hidden_states: tf.Tensor,
|
| 410 |
+
attention_mask: tf.Tensor,
|
| 411 |
+
query_states: Optional[tf.Tensor] = None,
|
| 412 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 413 |
+
output_attentions: bool = False,
|
| 414 |
+
output_hidden_states: bool = False,
|
| 415 |
+
return_dict: bool = True,
|
| 416 |
+
training: bool = False,
|
| 417 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 418 |
+
all_hidden_states = () if output_hidden_states else None
|
| 419 |
+
all_attentions = () if output_attentions else None
|
| 420 |
+
|
| 421 |
+
attention_mask = self.get_attention_mask(attention_mask)
|
| 422 |
+
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
|
| 423 |
+
|
| 424 |
+
if isinstance(hidden_states, Sequence):
|
| 425 |
+
next_kv = hidden_states[0]
|
| 426 |
+
else:
|
| 427 |
+
next_kv = hidden_states
|
| 428 |
+
|
| 429 |
+
rel_embeddings = self.get_rel_embedding()
|
| 430 |
+
|
| 431 |
+
for i, layer_module in enumerate(self.layer):
|
| 432 |
+
if output_hidden_states:
|
| 433 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 434 |
+
|
| 435 |
+
layer_outputs = layer_module(
|
| 436 |
+
hidden_states=next_kv,
|
| 437 |
+
attention_mask=attention_mask,
|
| 438 |
+
query_states=query_states,
|
| 439 |
+
relative_pos=relative_pos,
|
| 440 |
+
rel_embeddings=rel_embeddings,
|
| 441 |
+
output_attentions=output_attentions,
|
| 442 |
+
training=training,
|
| 443 |
+
)
|
| 444 |
+
hidden_states = layer_outputs[0]
|
| 445 |
+
|
| 446 |
+
if query_states is not None:
|
| 447 |
+
query_states = hidden_states
|
| 448 |
+
if isinstance(hidden_states, Sequence):
|
| 449 |
+
next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
|
| 450 |
+
else:
|
| 451 |
+
next_kv = hidden_states
|
| 452 |
+
|
| 453 |
+
if output_attentions:
|
| 454 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 455 |
+
|
| 456 |
+
# Add last layer
|
| 457 |
+
if output_hidden_states:
|
| 458 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 459 |
+
|
| 460 |
+
if not return_dict:
|
| 461 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 462 |
+
|
| 463 |
+
return TFBaseModelOutput(
|
| 464 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def build_relative_position(query_size, key_size):
|
| 469 |
+
"""
|
| 470 |
+
Build relative position according to the query and key
|
| 471 |
+
|
| 472 |
+
We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
|
| 473 |
+
\\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
|
| 474 |
+
P_k\\)
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
query_size (int): the length of query
|
| 478 |
+
key_size (int): the length of key
|
| 479 |
+
|
| 480 |
+
Return:
|
| 481 |
+
`tf.Tensor`: A tensor with shape [1, query_size, key_size]
|
| 482 |
+
|
| 483 |
+
"""
|
| 484 |
+
q_ids = tf.range(query_size, dtype=tf.int32)
|
| 485 |
+
k_ids = tf.range(key_size, dtype=tf.int32)
|
| 486 |
+
rel_pos_ids = q_ids[:, None] - tf.tile(tf.reshape(k_ids, [1, -1]), [query_size, 1])
|
| 487 |
+
rel_pos_ids = rel_pos_ids[:query_size, :]
|
| 488 |
+
rel_pos_ids = tf.expand_dims(rel_pos_ids, axis=0)
|
| 489 |
+
return tf.cast(rel_pos_ids, tf.int64)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
|
| 493 |
+
shapes = [
|
| 494 |
+
shape_list(query_layer)[0],
|
| 495 |
+
shape_list(query_layer)[1],
|
| 496 |
+
shape_list(query_layer)[2],
|
| 497 |
+
shape_list(relative_pos)[-1],
|
| 498 |
+
]
|
| 499 |
+
return tf.broadcast_to(c2p_pos, shapes)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
|
| 503 |
+
shapes = [
|
| 504 |
+
shape_list(query_layer)[0],
|
| 505 |
+
shape_list(query_layer)[1],
|
| 506 |
+
shape_list(key_layer)[-2],
|
| 507 |
+
shape_list(key_layer)[-2],
|
| 508 |
+
]
|
| 509 |
+
return tf.broadcast_to(c2p_pos, shapes)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
|
| 513 |
+
shapes = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]]
|
| 514 |
+
return tf.broadcast_to(pos_index, shapes)
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def torch_gather(x, indices, gather_axis):
|
| 518 |
+
if gather_axis < 0:
|
| 519 |
+
gather_axis = tf.rank(x) + gather_axis
|
| 520 |
+
|
| 521 |
+
if gather_axis != tf.rank(x) - 1:
|
| 522 |
+
pre_roll = tf.rank(x) - 1 - gather_axis
|
| 523 |
+
permutation = tf.roll(tf.range(tf.rank(x)), pre_roll, axis=0)
|
| 524 |
+
x = tf.transpose(x, perm=permutation)
|
| 525 |
+
indices = tf.transpose(indices, perm=permutation)
|
| 526 |
+
else:
|
| 527 |
+
pre_roll = 0
|
| 528 |
+
|
| 529 |
+
flat_x = tf.reshape(x, (-1, tf.shape(x)[-1]))
|
| 530 |
+
flat_indices = tf.reshape(indices, (-1, tf.shape(indices)[-1]))
|
| 531 |
+
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
|
| 532 |
+
gathered = tf.reshape(gathered, tf.shape(indices))
|
| 533 |
+
|
| 534 |
+
if pre_roll != 0:
|
| 535 |
+
permutation = tf.roll(tf.range(tf.rank(x)), -pre_roll, axis=0)
|
| 536 |
+
gathered = tf.transpose(gathered, perm=permutation)
|
| 537 |
+
|
| 538 |
+
return gathered
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
class TFDebertaDisentangledSelfAttention(keras.layers.Layer):
|
| 542 |
+
"""
|
| 543 |
+
Disentangled self-attention module
|
| 544 |
+
|
| 545 |
+
Parameters:
|
| 546 |
+
config (`str`):
|
| 547 |
+
A model config class instance with the configuration to build a new model. The schema is similar to
|
| 548 |
+
*BertConfig*, for more details, please refer [`DebertaConfig`]
|
| 549 |
+
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 553 |
+
super().__init__(**kwargs)
|
| 554 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 555 |
+
raise ValueError(
|
| 556 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 557 |
+
f"heads ({config.num_attention_heads})"
|
| 558 |
+
)
|
| 559 |
+
self.num_attention_heads = config.num_attention_heads
|
| 560 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 561 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 562 |
+
self.in_proj = keras.layers.Dense(
|
| 563 |
+
self.all_head_size * 3,
|
| 564 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 565 |
+
name="in_proj",
|
| 566 |
+
use_bias=False,
|
| 567 |
+
)
|
| 568 |
+
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
|
| 569 |
+
|
| 570 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 571 |
+
self.talking_head = getattr(config, "talking_head", False)
|
| 572 |
+
|
| 573 |
+
if self.talking_head:
|
| 574 |
+
self.head_logits_proj = keras.layers.Dense(
|
| 575 |
+
self.num_attention_heads,
|
| 576 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 577 |
+
name="head_logits_proj",
|
| 578 |
+
use_bias=False,
|
| 579 |
+
)
|
| 580 |
+
self.head_weights_proj = keras.layers.Dense(
|
| 581 |
+
self.num_attention_heads,
|
| 582 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 583 |
+
name="head_weights_proj",
|
| 584 |
+
use_bias=False,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
self.softmax = TFDebertaXSoftmax(axis=-1)
|
| 588 |
+
|
| 589 |
+
if self.relative_attention:
|
| 590 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 591 |
+
if self.max_relative_positions < 1:
|
| 592 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 593 |
+
self.pos_dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="pos_dropout")
|
| 594 |
+
if "c2p" in self.pos_att_type:
|
| 595 |
+
self.pos_proj = keras.layers.Dense(
|
| 596 |
+
self.all_head_size,
|
| 597 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 598 |
+
name="pos_proj",
|
| 599 |
+
use_bias=False,
|
| 600 |
+
)
|
| 601 |
+
if "p2c" in self.pos_att_type:
|
| 602 |
+
self.pos_q_proj = keras.layers.Dense(
|
| 603 |
+
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="pos_q_proj"
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
self.dropout = TFDebertaStableDropout(config.attention_probs_dropout_prob, name="dropout")
|
| 607 |
+
self.config = config
|
| 608 |
+
|
| 609 |
+
def build(self, input_shape=None):
|
| 610 |
+
if self.built:
|
| 611 |
+
return
|
| 612 |
+
self.built = True
|
| 613 |
+
self.q_bias = self.add_weight(
|
| 614 |
+
name="q_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros()
|
| 615 |
+
)
|
| 616 |
+
self.v_bias = self.add_weight(
|
| 617 |
+
name="v_bias", shape=(self.all_head_size), initializer=keras.initializers.Zeros()
|
| 618 |
+
)
|
| 619 |
+
if getattr(self, "in_proj", None) is not None:
|
| 620 |
+
with tf.name_scope(self.in_proj.name):
|
| 621 |
+
self.in_proj.build([None, None, self.config.hidden_size])
|
| 622 |
+
if getattr(self, "dropout", None) is not None:
|
| 623 |
+
with tf.name_scope(self.dropout.name):
|
| 624 |
+
self.dropout.build(None)
|
| 625 |
+
if getattr(self, "head_logits_proj", None) is not None:
|
| 626 |
+
with tf.name_scope(self.head_logits_proj.name):
|
| 627 |
+
self.head_logits_proj.build(None)
|
| 628 |
+
if getattr(self, "head_weights_proj", None) is not None:
|
| 629 |
+
with tf.name_scope(self.head_weights_proj.name):
|
| 630 |
+
self.head_weights_proj.build(None)
|
| 631 |
+
if getattr(self, "pos_dropout", None) is not None:
|
| 632 |
+
with tf.name_scope(self.pos_dropout.name):
|
| 633 |
+
self.pos_dropout.build(None)
|
| 634 |
+
if getattr(self, "pos_proj", None) is not None:
|
| 635 |
+
with tf.name_scope(self.pos_proj.name):
|
| 636 |
+
self.pos_proj.build([self.config.hidden_size])
|
| 637 |
+
if getattr(self, "pos_q_proj", None) is not None:
|
| 638 |
+
with tf.name_scope(self.pos_q_proj.name):
|
| 639 |
+
self.pos_q_proj.build([self.config.hidden_size])
|
| 640 |
+
|
| 641 |
+
def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
|
| 642 |
+
shape = shape_list(tensor)[:-1] + [self.num_attention_heads, -1]
|
| 643 |
+
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 644 |
+
tensor = tf.reshape(tensor=tensor, shape=shape)
|
| 645 |
+
|
| 646 |
+
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
| 647 |
+
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
| 648 |
+
|
| 649 |
+
def call(
|
| 650 |
+
self,
|
| 651 |
+
hidden_states: tf.Tensor,
|
| 652 |
+
attention_mask: tf.Tensor,
|
| 653 |
+
query_states: Optional[tf.Tensor] = None,
|
| 654 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 655 |
+
rel_embeddings: Optional[tf.Tensor] = None,
|
| 656 |
+
output_attentions: bool = False,
|
| 657 |
+
training: bool = False,
|
| 658 |
+
) -> Tuple[tf.Tensor]:
|
| 659 |
+
"""
|
| 660 |
+
Call the module
|
| 661 |
+
|
| 662 |
+
Args:
|
| 663 |
+
hidden_states (`tf.Tensor`):
|
| 664 |
+
Input states to the module usually the output from previous layer, it will be the Q,K and V in
|
| 665 |
+
*Attention(Q,K,V)*
|
| 666 |
+
|
| 667 |
+
attention_mask (`tf.Tensor`):
|
| 668 |
+
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
|
| 669 |
+
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
|
| 670 |
+
th token.
|
| 671 |
+
|
| 672 |
+
return_att (`bool`, *optional*):
|
| 673 |
+
Whether return the attention matrix.
|
| 674 |
+
|
| 675 |
+
query_states (`tf.Tensor`, *optional*):
|
| 676 |
+
The *Q* state in *Attention(Q,K,V)*.
|
| 677 |
+
|
| 678 |
+
relative_pos (`tf.Tensor`):
|
| 679 |
+
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
|
| 680 |
+
values ranging in [*-max_relative_positions*, *max_relative_positions*].
|
| 681 |
+
|
| 682 |
+
rel_embeddings (`tf.Tensor`):
|
| 683 |
+
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
|
| 684 |
+
\\text{max_relative_positions}\\), *hidden_size*].
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
"""
|
| 688 |
+
if query_states is None:
|
| 689 |
+
qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
|
| 690 |
+
query_layer, key_layer, value_layer = tf.split(
|
| 691 |
+
self.transpose_for_scores(qp), num_or_size_splits=3, axis=-1
|
| 692 |
+
)
|
| 693 |
+
else:
|
| 694 |
+
|
| 695 |
+
def linear(w, b, x):
|
| 696 |
+
out = tf.matmul(x, w, transpose_b=True)
|
| 697 |
+
if b is not None:
|
| 698 |
+
out += tf.transpose(b)
|
| 699 |
+
return out
|
| 700 |
+
|
| 701 |
+
ws = tf.split(
|
| 702 |
+
tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
|
| 703 |
+
)
|
| 704 |
+
qkvw = tf.TensorArray(dtype=self.dtype, size=3)
|
| 705 |
+
for k in tf.range(3):
|
| 706 |
+
qkvw_inside = tf.TensorArray(dtype=self.dtype, size=self.num_attention_heads)
|
| 707 |
+
for i in tf.range(self.num_attention_heads):
|
| 708 |
+
qkvw_inside = qkvw_inside.write(i, ws[i * 3 + k])
|
| 709 |
+
qkvw = qkvw.write(k, qkvw_inside.concat())
|
| 710 |
+
qkvb = [None] * 3
|
| 711 |
+
|
| 712 |
+
q = linear(qkvw[0], qkvb[0], query_states)
|
| 713 |
+
k = linear(qkvw[1], qkvb[1], hidden_states)
|
| 714 |
+
v = linear(qkvw[2], qkvb[2], hidden_states)
|
| 715 |
+
query_layer = self.transpose_for_scores(q)
|
| 716 |
+
key_layer = self.transpose_for_scores(k)
|
| 717 |
+
value_layer = self.transpose_for_scores(v)
|
| 718 |
+
|
| 719 |
+
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
|
| 720 |
+
value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
|
| 721 |
+
|
| 722 |
+
rel_att = None
|
| 723 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 724 |
+
scale_factor = 1 + len(self.pos_att_type)
|
| 725 |
+
scale = math.sqrt(shape_list(query_layer)[-1] * scale_factor)
|
| 726 |
+
query_layer = query_layer / scale
|
| 727 |
+
|
| 728 |
+
attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 1, 3, 2]))
|
| 729 |
+
if self.relative_attention:
|
| 730 |
+
rel_embeddings = self.pos_dropout(rel_embeddings, training=training)
|
| 731 |
+
rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
|
| 732 |
+
|
| 733 |
+
if rel_att is not None:
|
| 734 |
+
attention_scores = attention_scores + rel_att
|
| 735 |
+
|
| 736 |
+
if self.talking_head:
|
| 737 |
+
attention_scores = tf.transpose(
|
| 738 |
+
self.head_logits_proj(tf.transpose(attention_scores, [0, 2, 3, 1])), [0, 3, 1, 2]
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
attention_probs = self.softmax(attention_scores, attention_mask)
|
| 742 |
+
attention_probs = self.dropout(attention_probs, training=training)
|
| 743 |
+
if self.talking_head:
|
| 744 |
+
attention_probs = tf.transpose(
|
| 745 |
+
self.head_weights_proj(tf.transpose(attention_probs, [0, 2, 3, 1])), [0, 3, 1, 2]
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
context_layer = tf.matmul(attention_probs, value_layer)
|
| 749 |
+
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
|
| 750 |
+
context_layer_shape = shape_list(context_layer)
|
| 751 |
+
# Set the final dimension here explicitly.
|
| 752 |
+
# Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
|
| 753 |
+
# the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
|
| 754 |
+
# requires final input dimension to be defined
|
| 755 |
+
new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
|
| 756 |
+
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
| 757 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 758 |
+
return outputs
|
| 759 |
+
|
| 760 |
+
def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
|
| 761 |
+
if relative_pos is None:
|
| 762 |
+
q = shape_list(query_layer)[-2]
|
| 763 |
+
relative_pos = build_relative_position(q, shape_list(key_layer)[-2])
|
| 764 |
+
shape_list_pos = shape_list(relative_pos)
|
| 765 |
+
if len(shape_list_pos) == 2:
|
| 766 |
+
relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
|
| 767 |
+
elif len(shape_list_pos) == 3:
|
| 768 |
+
relative_pos = tf.expand_dims(relative_pos, 1)
|
| 769 |
+
# bxhxqxk
|
| 770 |
+
elif len(shape_list_pos) != 4:
|
| 771 |
+
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")
|
| 772 |
+
|
| 773 |
+
att_span = tf.cast(
|
| 774 |
+
tf.minimum(
|
| 775 |
+
tf.maximum(shape_list(query_layer)[-2], shape_list(key_layer)[-2]), self.max_relative_positions
|
| 776 |
+
),
|
| 777 |
+
tf.int64,
|
| 778 |
+
)
|
| 779 |
+
rel_embeddings = tf.expand_dims(
|
| 780 |
+
rel_embeddings[self.max_relative_positions - att_span : self.max_relative_positions + att_span, :], 0
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
score = 0
|
| 784 |
+
|
| 785 |
+
# content->position
|
| 786 |
+
if "c2p" in self.pos_att_type:
|
| 787 |
+
pos_key_layer = self.pos_proj(rel_embeddings)
|
| 788 |
+
pos_key_layer = self.transpose_for_scores(pos_key_layer)
|
| 789 |
+
c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 1, 3, 2]))
|
| 790 |
+
c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1)
|
| 791 |
+
c2p_att = torch_gather(c2p_att, c2p_dynamic_expand(c2p_pos, query_layer, relative_pos), -1)
|
| 792 |
+
score += c2p_att
|
| 793 |
+
|
| 794 |
+
# position->content
|
| 795 |
+
if "p2c" in self.pos_att_type:
|
| 796 |
+
pos_query_layer = self.pos_q_proj(rel_embeddings)
|
| 797 |
+
pos_query_layer = self.transpose_for_scores(pos_query_layer)
|
| 798 |
+
pos_query_layer /= tf.math.sqrt(
|
| 799 |
+
tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype)
|
| 800 |
+
)
|
| 801 |
+
if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]:
|
| 802 |
+
r_pos = build_relative_position(shape_list(key_layer)[-2], shape_list(key_layer)[-2])
|
| 803 |
+
else:
|
| 804 |
+
r_pos = relative_pos
|
| 805 |
+
p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)
|
| 806 |
+
p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 1, 3, 2]))
|
| 807 |
+
p2c_att = tf.transpose(
|
| 808 |
+
torch_gather(p2c_att, p2c_dynamic_expand(p2c_pos, query_layer, key_layer), -1), [0, 1, 3, 2]
|
| 809 |
+
)
|
| 810 |
+
if shape_list(query_layer)[-2] != shape_list(key_layer)[-2]:
|
| 811 |
+
pos_index = tf.expand_dims(relative_pos[:, :, :, 0], -1)
|
| 812 |
+
p2c_att = torch_gather(p2c_att, pos_dynamic_expand(pos_index, p2c_att, key_layer), -2)
|
| 813 |
+
score += p2c_att
|
| 814 |
+
|
| 815 |
+
return score
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
class TFDebertaEmbeddings(keras.layers.Layer):
|
| 819 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 820 |
+
|
| 821 |
+
def __init__(self, config, **kwargs):
|
| 822 |
+
super().__init__(**kwargs)
|
| 823 |
+
|
| 824 |
+
self.config = config
|
| 825 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 826 |
+
self.hidden_size = config.hidden_size
|
| 827 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 828 |
+
self.position_biased_input = getattr(config, "position_biased_input", True)
|
| 829 |
+
self.initializer_range = config.initializer_range
|
| 830 |
+
if self.embedding_size != config.hidden_size:
|
| 831 |
+
self.embed_proj = keras.layers.Dense(
|
| 832 |
+
config.hidden_size,
|
| 833 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 834 |
+
name="embed_proj",
|
| 835 |
+
use_bias=False,
|
| 836 |
+
)
|
| 837 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 838 |
+
self.dropout = TFDebertaStableDropout(config.hidden_dropout_prob, name="dropout")
|
| 839 |
+
|
| 840 |
+
def build(self, input_shape=None):
|
| 841 |
+
with tf.name_scope("word_embeddings"):
|
| 842 |
+
self.weight = self.add_weight(
|
| 843 |
+
name="weight",
|
| 844 |
+
shape=[self.config.vocab_size, self.embedding_size],
|
| 845 |
+
initializer=get_initializer(self.initializer_range),
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
with tf.name_scope("token_type_embeddings"):
|
| 849 |
+
if self.config.type_vocab_size > 0:
|
| 850 |
+
self.token_type_embeddings = self.add_weight(
|
| 851 |
+
name="embeddings",
|
| 852 |
+
shape=[self.config.type_vocab_size, self.embedding_size],
|
| 853 |
+
initializer=get_initializer(self.initializer_range),
|
| 854 |
+
)
|
| 855 |
+
else:
|
| 856 |
+
self.token_type_embeddings = None
|
| 857 |
+
|
| 858 |
+
with tf.name_scope("position_embeddings"):
|
| 859 |
+
if self.position_biased_input:
|
| 860 |
+
self.position_embeddings = self.add_weight(
|
| 861 |
+
name="embeddings",
|
| 862 |
+
shape=[self.max_position_embeddings, self.hidden_size],
|
| 863 |
+
initializer=get_initializer(self.initializer_range),
|
| 864 |
+
)
|
| 865 |
+
else:
|
| 866 |
+
self.position_embeddings = None
|
| 867 |
+
|
| 868 |
+
if self.built:
|
| 869 |
+
return
|
| 870 |
+
self.built = True
|
| 871 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 872 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 873 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 874 |
+
if getattr(self, "dropout", None) is not None:
|
| 875 |
+
with tf.name_scope(self.dropout.name):
|
| 876 |
+
self.dropout.build(None)
|
| 877 |
+
if getattr(self, "embed_proj", None) is not None:
|
| 878 |
+
with tf.name_scope(self.embed_proj.name):
|
| 879 |
+
self.embed_proj.build([None, None, self.embedding_size])
|
| 880 |
+
|
| 881 |
+
def call(
|
| 882 |
+
self,
|
| 883 |
+
input_ids: Optional[tf.Tensor] = None,
|
| 884 |
+
position_ids: Optional[tf.Tensor] = None,
|
| 885 |
+
token_type_ids: Optional[tf.Tensor] = None,
|
| 886 |
+
inputs_embeds: Optional[tf.Tensor] = None,
|
| 887 |
+
mask: Optional[tf.Tensor] = None,
|
| 888 |
+
training: bool = False,
|
| 889 |
+
) -> tf.Tensor:
|
| 890 |
+
"""
|
| 891 |
+
Applies embedding based on inputs tensor.
|
| 892 |
+
|
| 893 |
+
Returns:
|
| 894 |
+
final_embeddings (`tf.Tensor`): output embedding tensor.
|
| 895 |
+
"""
|
| 896 |
+
if input_ids is None and inputs_embeds is None:
|
| 897 |
+
raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
|
| 898 |
+
|
| 899 |
+
if input_ids is not None:
|
| 900 |
+
check_embeddings_within_bounds(input_ids, self.config.vocab_size)
|
| 901 |
+
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
|
| 902 |
+
|
| 903 |
+
input_shape = shape_list(inputs_embeds)[:-1]
|
| 904 |
+
|
| 905 |
+
if token_type_ids is None:
|
| 906 |
+
token_type_ids = tf.fill(dims=input_shape, value=0)
|
| 907 |
+
|
| 908 |
+
if position_ids is None:
|
| 909 |
+
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
| 910 |
+
|
| 911 |
+
final_embeddings = inputs_embeds
|
| 912 |
+
if self.position_biased_input:
|
| 913 |
+
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
| 914 |
+
final_embeddings += position_embeds
|
| 915 |
+
if self.config.type_vocab_size > 0:
|
| 916 |
+
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
|
| 917 |
+
final_embeddings += token_type_embeds
|
| 918 |
+
|
| 919 |
+
if self.embedding_size != self.hidden_size:
|
| 920 |
+
final_embeddings = self.embed_proj(final_embeddings)
|
| 921 |
+
|
| 922 |
+
final_embeddings = self.LayerNorm(final_embeddings)
|
| 923 |
+
|
| 924 |
+
if mask is not None:
|
| 925 |
+
if len(shape_list(mask)) != len(shape_list(final_embeddings)):
|
| 926 |
+
if len(shape_list(mask)) == 4:
|
| 927 |
+
mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1)
|
| 928 |
+
mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)
|
| 929 |
+
|
| 930 |
+
final_embeddings = final_embeddings * mask
|
| 931 |
+
|
| 932 |
+
final_embeddings = self.dropout(final_embeddings, training=training)
|
| 933 |
+
|
| 934 |
+
return final_embeddings
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
class TFDebertaPredictionHeadTransform(keras.layers.Layer):
|
| 938 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 939 |
+
super().__init__(**kwargs)
|
| 940 |
+
|
| 941 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 942 |
+
|
| 943 |
+
self.dense = keras.layers.Dense(
|
| 944 |
+
units=self.embedding_size,
|
| 945 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 946 |
+
name="dense",
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
if isinstance(config.hidden_act, str):
|
| 950 |
+
self.transform_act_fn = get_tf_activation(config.hidden_act)
|
| 951 |
+
else:
|
| 952 |
+
self.transform_act_fn = config.hidden_act
|
| 953 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 954 |
+
self.config = config
|
| 955 |
+
|
| 956 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 957 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 958 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 959 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 960 |
+
|
| 961 |
+
return hidden_states
|
| 962 |
+
|
| 963 |
+
def build(self, input_shape=None):
|
| 964 |
+
if self.built:
|
| 965 |
+
return
|
| 966 |
+
self.built = True
|
| 967 |
+
if getattr(self, "dense", None) is not None:
|
| 968 |
+
with tf.name_scope(self.dense.name):
|
| 969 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 970 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 971 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 972 |
+
self.LayerNorm.build([None, None, self.embedding_size])
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
class TFDebertaLMPredictionHead(keras.layers.Layer):
|
| 976 |
+
def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs):
|
| 977 |
+
super().__init__(**kwargs)
|
| 978 |
+
|
| 979 |
+
self.config = config
|
| 980 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 981 |
+
|
| 982 |
+
self.transform = TFDebertaPredictionHeadTransform(config, name="transform")
|
| 983 |
+
|
| 984 |
+
# The output weights are the same as the input embeddings, but there is
|
| 985 |
+
# an output-only bias for each token.
|
| 986 |
+
self.input_embeddings = input_embeddings
|
| 987 |
+
|
| 988 |
+
def build(self, input_shape=None):
|
| 989 |
+
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
| 990 |
+
|
| 991 |
+
if self.built:
|
| 992 |
+
return
|
| 993 |
+
self.built = True
|
| 994 |
+
if getattr(self, "transform", None) is not None:
|
| 995 |
+
with tf.name_scope(self.transform.name):
|
| 996 |
+
self.transform.build(None)
|
| 997 |
+
|
| 998 |
+
def get_output_embeddings(self) -> keras.layers.Layer:
|
| 999 |
+
return self.input_embeddings
|
| 1000 |
+
|
| 1001 |
+
def set_output_embeddings(self, value: tf.Variable):
|
| 1002 |
+
self.input_embeddings.weight = value
|
| 1003 |
+
self.input_embeddings.vocab_size = shape_list(value)[0]
|
| 1004 |
+
|
| 1005 |
+
def get_bias(self) -> Dict[str, tf.Variable]:
|
| 1006 |
+
return {"bias": self.bias}
|
| 1007 |
+
|
| 1008 |
+
def set_bias(self, value: tf.Variable):
|
| 1009 |
+
self.bias = value["bias"]
|
| 1010 |
+
self.config.vocab_size = shape_list(value["bias"])[0]
|
| 1011 |
+
|
| 1012 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 1013 |
+
hidden_states = self.transform(hidden_states=hidden_states)
|
| 1014 |
+
seq_length = shape_list(hidden_states)[1]
|
| 1015 |
+
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
|
| 1016 |
+
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
|
| 1017 |
+
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
|
| 1018 |
+
hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
|
| 1019 |
+
|
| 1020 |
+
return hidden_states
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
class TFDebertaOnlyMLMHead(keras.layers.Layer):
|
| 1024 |
+
def __init__(self, config: DebertaConfig, input_embeddings: keras.layers.Layer, **kwargs):
|
| 1025 |
+
super().__init__(**kwargs)
|
| 1026 |
+
self.predictions = TFDebertaLMPredictionHead(config, input_embeddings, name="predictions")
|
| 1027 |
+
|
| 1028 |
+
def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
|
| 1029 |
+
prediction_scores = self.predictions(hidden_states=sequence_output)
|
| 1030 |
+
|
| 1031 |
+
return prediction_scores
|
| 1032 |
+
|
| 1033 |
+
def build(self, input_shape=None):
|
| 1034 |
+
if self.built:
|
| 1035 |
+
return
|
| 1036 |
+
self.built = True
|
| 1037 |
+
if getattr(self, "predictions", None) is not None:
|
| 1038 |
+
with tf.name_scope(self.predictions.name):
|
| 1039 |
+
self.predictions.build(None)
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
# @keras_serializable
|
| 1043 |
+
class TFDebertaMainLayer(keras.layers.Layer):
|
| 1044 |
+
config_class = DebertaConfig
|
| 1045 |
+
|
| 1046 |
+
def __init__(self, config: DebertaConfig, **kwargs):
|
| 1047 |
+
super().__init__(**kwargs)
|
| 1048 |
+
|
| 1049 |
+
self.config = config
|
| 1050 |
+
|
| 1051 |
+
self.embeddings = TFDebertaEmbeddings(config, name="embeddings")
|
| 1052 |
+
self.encoder = TFDebertaEncoder(config, name="encoder")
|
| 1053 |
+
|
| 1054 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 1055 |
+
return self.embeddings
|
| 1056 |
+
|
| 1057 |
+
def set_input_embeddings(self, value: tf.Variable):
|
| 1058 |
+
self.embeddings.weight = value
|
| 1059 |
+
self.embeddings.vocab_size = shape_list(value)[0]
|
| 1060 |
+
|
| 1061 |
+
def _prune_heads(self, heads_to_prune):
|
| 1062 |
+
"""
|
| 1063 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 1064 |
+
class PreTrainedModel
|
| 1065 |
+
"""
|
| 1066 |
+
raise NotImplementedError
|
| 1067 |
+
|
| 1068 |
+
@unpack_inputs
|
| 1069 |
+
def call(
|
| 1070 |
+
self,
|
| 1071 |
+
input_ids: TFModelInputType | None = None,
|
| 1072 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1073 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1074 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1075 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1076 |
+
output_attentions: Optional[bool] = None,
|
| 1077 |
+
output_hidden_states: Optional[bool] = None,
|
| 1078 |
+
return_dict: Optional[bool] = None,
|
| 1079 |
+
training: bool = False,
|
| 1080 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 1081 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 1082 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 1083 |
+
elif input_ids is not None:
|
| 1084 |
+
input_shape = shape_list(input_ids)
|
| 1085 |
+
elif inputs_embeds is not None:
|
| 1086 |
+
input_shape = shape_list(inputs_embeds)[:-1]
|
| 1087 |
+
else:
|
| 1088 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 1089 |
+
|
| 1090 |
+
if attention_mask is None:
|
| 1091 |
+
attention_mask = tf.fill(dims=input_shape, value=1)
|
| 1092 |
+
|
| 1093 |
+
if token_type_ids is None:
|
| 1094 |
+
token_type_ids = tf.fill(dims=input_shape, value=0)
|
| 1095 |
+
|
| 1096 |
+
embedding_output = self.embeddings(
|
| 1097 |
+
input_ids=input_ids,
|
| 1098 |
+
position_ids=position_ids,
|
| 1099 |
+
token_type_ids=token_type_ids,
|
| 1100 |
+
inputs_embeds=inputs_embeds,
|
| 1101 |
+
mask=attention_mask,
|
| 1102 |
+
training=training,
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
encoder_outputs = self.encoder(
|
| 1106 |
+
hidden_states=embedding_output,
|
| 1107 |
+
attention_mask=attention_mask,
|
| 1108 |
+
output_attentions=output_attentions,
|
| 1109 |
+
output_hidden_states=output_hidden_states,
|
| 1110 |
+
return_dict=return_dict,
|
| 1111 |
+
training=training,
|
| 1112 |
+
)
|
| 1113 |
+
|
| 1114 |
+
sequence_output = encoder_outputs[0]
|
| 1115 |
+
|
| 1116 |
+
if not return_dict:
|
| 1117 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 1118 |
+
|
| 1119 |
+
return TFBaseModelOutput(
|
| 1120 |
+
last_hidden_state=sequence_output,
|
| 1121 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1122 |
+
attentions=encoder_outputs.attentions,
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
def build(self, input_shape=None):
|
| 1126 |
+
if self.built:
|
| 1127 |
+
return
|
| 1128 |
+
self.built = True
|
| 1129 |
+
if getattr(self, "embeddings", None) is not None:
|
| 1130 |
+
with tf.name_scope(self.embeddings.name):
|
| 1131 |
+
self.embeddings.build(None)
|
| 1132 |
+
if getattr(self, "encoder", None) is not None:
|
| 1133 |
+
with tf.name_scope(self.encoder.name):
|
| 1134 |
+
self.encoder.build(None)
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
class TFDebertaPreTrainedModel(TFPreTrainedModel):
|
| 1138 |
+
"""
|
| 1139 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 1140 |
+
models.
|
| 1141 |
+
"""
|
| 1142 |
+
|
| 1143 |
+
config_class = DebertaConfig
|
| 1144 |
+
base_model_prefix = "deberta"
|
| 1145 |
+
|
| 1146 |
+
|
| 1147 |
+
DEBERTA_START_DOCSTRING = r"""
|
| 1148 |
+
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
|
| 1149 |
+
Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
|
| 1150 |
+
on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
|
| 1151 |
+
improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
|
| 1152 |
+
|
| 1153 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 1154 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 1155 |
+
behavior.
|
| 1156 |
+
|
| 1157 |
+
<Tip>
|
| 1158 |
+
|
| 1159 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 1160 |
+
|
| 1161 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 1162 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 1163 |
+
|
| 1164 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 1165 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 1166 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 1167 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 1168 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 1169 |
+
positional argument:
|
| 1170 |
+
|
| 1171 |
+
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
|
| 1172 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 1173 |
+
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
|
| 1174 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 1175 |
+
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
|
| 1176 |
+
|
| 1177 |
+
Note that when creating models and layers with
|
| 1178 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 1179 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 1180 |
+
|
| 1181 |
+
</Tip>
|
| 1182 |
+
|
| 1183 |
+
Parameters:
|
| 1184 |
+
config ([`DebertaConfig`]): Model configuration class with all the parameters of the model.
|
| 1185 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1186 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1187 |
+
"""
|
| 1188 |
+
|
| 1189 |
+
DEBERTA_INPUTS_DOCSTRING = r"""
|
| 1190 |
+
Args:
|
| 1191 |
+
input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
|
| 1192 |
+
Indices of input sequence tokens in the vocabulary.
|
| 1193 |
+
|
| 1194 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1195 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1196 |
+
|
| 1197 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1198 |
+
attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
|
| 1199 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1200 |
+
|
| 1201 |
+
- 1 for tokens that are **not masked**,
|
| 1202 |
+
- 0 for tokens that are **masked**.
|
| 1203 |
+
|
| 1204 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1205 |
+
token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
|
| 1206 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 1207 |
+
1]`:
|
| 1208 |
+
|
| 1209 |
+
- 0 corresponds to a *sentence A* token,
|
| 1210 |
+
- 1 corresponds to a *sentence B* token.
|
| 1211 |
+
|
| 1212 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 1213 |
+
position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
|
| 1214 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 1215 |
+
config.max_position_embeddings - 1]`.
|
| 1216 |
+
|
| 1217 |
+
[What are position IDs?](../glossary#position-ids)
|
| 1218 |
+
inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
|
| 1219 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 1220 |
+
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
| 1221 |
+
model's internal embedding lookup matrix.
|
| 1222 |
+
output_attentions (`bool`, *optional*):
|
| 1223 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1224 |
+
tensors for more detail.
|
| 1225 |
+
output_hidden_states (`bool`, *optional*):
|
| 1226 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1227 |
+
more detail.
|
| 1228 |
+
return_dict (`bool`, *optional*):
|
| 1229 |
+
Whether or not to return a [`~utils.ModelOutput``] instead of a plain tuple.
|
| 1230 |
+
"""
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
@add_start_docstrings(
|
| 1234 |
+
"The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1235 |
+
DEBERTA_START_DOCSTRING,
|
| 1236 |
+
)
|
| 1237 |
+
class TFDebertaModel(TFDebertaPreTrainedModel):
|
| 1238 |
+
def __init__(self, config: DebertaConfig, *inputs, **kwargs):
|
| 1239 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1240 |
+
|
| 1241 |
+
self.deberta = TFDebertaMainLayer(config, name="deberta")
|
| 1242 |
+
|
| 1243 |
+
@unpack_inputs
|
| 1244 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1245 |
+
@add_code_sample_docstrings(
|
| 1246 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1247 |
+
output_type=TFBaseModelOutput,
|
| 1248 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1249 |
+
)
|
| 1250 |
+
def call(
|
| 1251 |
+
self,
|
| 1252 |
+
input_ids: TFModelInputType | None = None,
|
| 1253 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1254 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1255 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1256 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1257 |
+
output_attentions: Optional[bool] = None,
|
| 1258 |
+
output_hidden_states: Optional[bool] = None,
|
| 1259 |
+
return_dict: Optional[bool] = None,
|
| 1260 |
+
training: Optional[bool] = False,
|
| 1261 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 1262 |
+
outputs = self.deberta(
|
| 1263 |
+
input_ids=input_ids,
|
| 1264 |
+
attention_mask=attention_mask,
|
| 1265 |
+
token_type_ids=token_type_ids,
|
| 1266 |
+
position_ids=position_ids,
|
| 1267 |
+
inputs_embeds=inputs_embeds,
|
| 1268 |
+
output_attentions=output_attentions,
|
| 1269 |
+
output_hidden_states=output_hidden_states,
|
| 1270 |
+
return_dict=return_dict,
|
| 1271 |
+
training=training,
|
| 1272 |
+
)
|
| 1273 |
+
|
| 1274 |
+
return outputs
|
| 1275 |
+
|
| 1276 |
+
def build(self, input_shape=None):
|
| 1277 |
+
if self.built:
|
| 1278 |
+
return
|
| 1279 |
+
self.built = True
|
| 1280 |
+
if getattr(self, "deberta", None) is not None:
|
| 1281 |
+
with tf.name_scope(self.deberta.name):
|
| 1282 |
+
self.deberta.build(None)
|
| 1283 |
+
|
| 1284 |
+
|
| 1285 |
+
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
| 1286 |
+
class TFDebertaForMaskedLM(TFDebertaPreTrainedModel, TFMaskedLanguageModelingLoss):
|
| 1287 |
+
def __init__(self, config: DebertaConfig, *inputs, **kwargs):
|
| 1288 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1289 |
+
|
| 1290 |
+
if config.is_decoder:
|
| 1291 |
+
logger.warning(
|
| 1292 |
+
"If you want to use `TFDebertaForMaskedLM` make sure `config.is_decoder=False` for "
|
| 1293 |
+
"bi-directional self-attention."
|
| 1294 |
+
)
|
| 1295 |
+
|
| 1296 |
+
self.deberta = TFDebertaMainLayer(config, name="deberta")
|
| 1297 |
+
self.mlm = TFDebertaOnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls")
|
| 1298 |
+
|
| 1299 |
+
def get_lm_head(self) -> keras.layers.Layer:
|
| 1300 |
+
return self.mlm.predictions
|
| 1301 |
+
|
| 1302 |
+
@unpack_inputs
|
| 1303 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1304 |
+
@add_code_sample_docstrings(
|
| 1305 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1306 |
+
output_type=TFMaskedLMOutput,
|
| 1307 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1308 |
+
)
|
| 1309 |
+
def call(
|
| 1310 |
+
self,
|
| 1311 |
+
input_ids: TFModelInputType | None = None,
|
| 1312 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1313 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1314 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1315 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1316 |
+
output_attentions: Optional[bool] = None,
|
| 1317 |
+
output_hidden_states: Optional[bool] = None,
|
| 1318 |
+
return_dict: Optional[bool] = None,
|
| 1319 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1320 |
+
training: Optional[bool] = False,
|
| 1321 |
+
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
|
| 1322 |
+
r"""
|
| 1323 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1324 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 1325 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 1326 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 1327 |
+
"""
|
| 1328 |
+
outputs = self.deberta(
|
| 1329 |
+
input_ids=input_ids,
|
| 1330 |
+
attention_mask=attention_mask,
|
| 1331 |
+
token_type_ids=token_type_ids,
|
| 1332 |
+
position_ids=position_ids,
|
| 1333 |
+
inputs_embeds=inputs_embeds,
|
| 1334 |
+
output_attentions=output_attentions,
|
| 1335 |
+
output_hidden_states=output_hidden_states,
|
| 1336 |
+
return_dict=return_dict,
|
| 1337 |
+
training=training,
|
| 1338 |
+
)
|
| 1339 |
+
sequence_output = outputs[0]
|
| 1340 |
+
prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
|
| 1341 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
|
| 1342 |
+
|
| 1343 |
+
if not return_dict:
|
| 1344 |
+
output = (prediction_scores,) + outputs[2:]
|
| 1345 |
+
return ((loss,) + output) if loss is not None else output
|
| 1346 |
+
|
| 1347 |
+
return TFMaskedLMOutput(
|
| 1348 |
+
loss=loss,
|
| 1349 |
+
logits=prediction_scores,
|
| 1350 |
+
hidden_states=outputs.hidden_states,
|
| 1351 |
+
attentions=outputs.attentions,
|
| 1352 |
+
)
|
| 1353 |
+
|
| 1354 |
+
def build(self, input_shape=None):
|
| 1355 |
+
if self.built:
|
| 1356 |
+
return
|
| 1357 |
+
self.built = True
|
| 1358 |
+
if getattr(self, "deberta", None) is not None:
|
| 1359 |
+
with tf.name_scope(self.deberta.name):
|
| 1360 |
+
self.deberta.build(None)
|
| 1361 |
+
if getattr(self, "mlm", None) is not None:
|
| 1362 |
+
with tf.name_scope(self.mlm.name):
|
| 1363 |
+
self.mlm.build(None)
|
| 1364 |
+
|
| 1365 |
+
|
| 1366 |
+
@add_start_docstrings(
|
| 1367 |
+
"""
|
| 1368 |
+
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
| 1369 |
+
pooled output) e.g. for GLUE tasks.
|
| 1370 |
+
""",
|
| 1371 |
+
DEBERTA_START_DOCSTRING,
|
| 1372 |
+
)
|
| 1373 |
+
class TFDebertaForSequenceClassification(TFDebertaPreTrainedModel, TFSequenceClassificationLoss):
|
| 1374 |
+
def __init__(self, config: DebertaConfig, *inputs, **kwargs):
|
| 1375 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1376 |
+
|
| 1377 |
+
self.num_labels = config.num_labels
|
| 1378 |
+
|
| 1379 |
+
self.deberta = TFDebertaMainLayer(config, name="deberta")
|
| 1380 |
+
self.pooler = TFDebertaContextPooler(config, name="pooler")
|
| 1381 |
+
|
| 1382 |
+
drop_out = getattr(config, "cls_dropout", None)
|
| 1383 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
| 1384 |
+
self.dropout = TFDebertaStableDropout(drop_out, name="cls_dropout")
|
| 1385 |
+
self.classifier = keras.layers.Dense(
|
| 1386 |
+
units=config.num_labels,
|
| 1387 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1388 |
+
name="classifier",
|
| 1389 |
+
)
|
| 1390 |
+
self.output_dim = self.pooler.output_dim
|
| 1391 |
+
|
| 1392 |
+
@unpack_inputs
|
| 1393 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1394 |
+
@add_code_sample_docstrings(
|
| 1395 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1396 |
+
output_type=TFSequenceClassifierOutput,
|
| 1397 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1398 |
+
)
|
| 1399 |
+
def call(
|
| 1400 |
+
self,
|
| 1401 |
+
input_ids: TFModelInputType | None = None,
|
| 1402 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1403 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1404 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1405 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1406 |
+
output_attentions: Optional[bool] = None,
|
| 1407 |
+
output_hidden_states: Optional[bool] = None,
|
| 1408 |
+
return_dict: Optional[bool] = None,
|
| 1409 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1410 |
+
training: Optional[bool] = False,
|
| 1411 |
+
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
| 1412 |
+
r"""
|
| 1413 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1414 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1415 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1416 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1417 |
+
"""
|
| 1418 |
+
outputs = self.deberta(
|
| 1419 |
+
input_ids=input_ids,
|
| 1420 |
+
attention_mask=attention_mask,
|
| 1421 |
+
token_type_ids=token_type_ids,
|
| 1422 |
+
position_ids=position_ids,
|
| 1423 |
+
inputs_embeds=inputs_embeds,
|
| 1424 |
+
output_attentions=output_attentions,
|
| 1425 |
+
output_hidden_states=output_hidden_states,
|
| 1426 |
+
return_dict=return_dict,
|
| 1427 |
+
training=training,
|
| 1428 |
+
)
|
| 1429 |
+
sequence_output = outputs[0]
|
| 1430 |
+
pooled_output = self.pooler(sequence_output, training=training)
|
| 1431 |
+
pooled_output = self.dropout(pooled_output, training=training)
|
| 1432 |
+
logits = self.classifier(pooled_output)
|
| 1433 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1434 |
+
|
| 1435 |
+
if not return_dict:
|
| 1436 |
+
output = (logits,) + outputs[1:]
|
| 1437 |
+
|
| 1438 |
+
return ((loss,) + output) if loss is not None else output
|
| 1439 |
+
|
| 1440 |
+
return TFSequenceClassifierOutput(
|
| 1441 |
+
loss=loss,
|
| 1442 |
+
logits=logits,
|
| 1443 |
+
hidden_states=outputs.hidden_states,
|
| 1444 |
+
attentions=outputs.attentions,
|
| 1445 |
+
)
|
| 1446 |
+
|
| 1447 |
+
def build(self, input_shape=None):
|
| 1448 |
+
if self.built:
|
| 1449 |
+
return
|
| 1450 |
+
self.built = True
|
| 1451 |
+
if getattr(self, "deberta", None) is not None:
|
| 1452 |
+
with tf.name_scope(self.deberta.name):
|
| 1453 |
+
self.deberta.build(None)
|
| 1454 |
+
if getattr(self, "pooler", None) is not None:
|
| 1455 |
+
with tf.name_scope(self.pooler.name):
|
| 1456 |
+
self.pooler.build(None)
|
| 1457 |
+
if getattr(self, "dropout", None) is not None:
|
| 1458 |
+
with tf.name_scope(self.dropout.name):
|
| 1459 |
+
self.dropout.build(None)
|
| 1460 |
+
if getattr(self, "classifier", None) is not None:
|
| 1461 |
+
with tf.name_scope(self.classifier.name):
|
| 1462 |
+
self.classifier.build([None, None, self.output_dim])
|
| 1463 |
+
|
| 1464 |
+
|
| 1465 |
+
@add_start_docstrings(
|
| 1466 |
+
"""
|
| 1467 |
+
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1468 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1469 |
+
""",
|
| 1470 |
+
DEBERTA_START_DOCSTRING,
|
| 1471 |
+
)
|
| 1472 |
+
class TFDebertaForTokenClassification(TFDebertaPreTrainedModel, TFTokenClassificationLoss):
|
| 1473 |
+
def __init__(self, config: DebertaConfig, *inputs, **kwargs):
|
| 1474 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1475 |
+
|
| 1476 |
+
self.num_labels = config.num_labels
|
| 1477 |
+
|
| 1478 |
+
self.deberta = TFDebertaMainLayer(config, name="deberta")
|
| 1479 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 1480 |
+
self.classifier = keras.layers.Dense(
|
| 1481 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1482 |
+
)
|
| 1483 |
+
self.config = config
|
| 1484 |
+
|
| 1485 |
+
@unpack_inputs
|
| 1486 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1487 |
+
@add_code_sample_docstrings(
|
| 1488 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1489 |
+
output_type=TFTokenClassifierOutput,
|
| 1490 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1491 |
+
)
|
| 1492 |
+
def call(
|
| 1493 |
+
self,
|
| 1494 |
+
input_ids: TFModelInputType | None = None,
|
| 1495 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1496 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1497 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1498 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1499 |
+
output_attentions: Optional[bool] = None,
|
| 1500 |
+
output_hidden_states: Optional[bool] = None,
|
| 1501 |
+
return_dict: Optional[bool] = None,
|
| 1502 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1503 |
+
training: Optional[bool] = False,
|
| 1504 |
+
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
|
| 1505 |
+
r"""
|
| 1506 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1507 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1508 |
+
"""
|
| 1509 |
+
outputs = self.deberta(
|
| 1510 |
+
input_ids=input_ids,
|
| 1511 |
+
attention_mask=attention_mask,
|
| 1512 |
+
token_type_ids=token_type_ids,
|
| 1513 |
+
position_ids=position_ids,
|
| 1514 |
+
inputs_embeds=inputs_embeds,
|
| 1515 |
+
output_attentions=output_attentions,
|
| 1516 |
+
output_hidden_states=output_hidden_states,
|
| 1517 |
+
return_dict=return_dict,
|
| 1518 |
+
training=training,
|
| 1519 |
+
)
|
| 1520 |
+
sequence_output = outputs[0]
|
| 1521 |
+
sequence_output = self.dropout(sequence_output, training=training)
|
| 1522 |
+
logits = self.classifier(inputs=sequence_output)
|
| 1523 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1524 |
+
|
| 1525 |
+
if not return_dict:
|
| 1526 |
+
output = (logits,) + outputs[1:]
|
| 1527 |
+
return ((loss,) + output) if loss is not None else output
|
| 1528 |
+
|
| 1529 |
+
return TFTokenClassifierOutput(
|
| 1530 |
+
loss=loss,
|
| 1531 |
+
logits=logits,
|
| 1532 |
+
hidden_states=outputs.hidden_states,
|
| 1533 |
+
attentions=outputs.attentions,
|
| 1534 |
+
)
|
| 1535 |
+
|
| 1536 |
+
def build(self, input_shape=None):
|
| 1537 |
+
if self.built:
|
| 1538 |
+
return
|
| 1539 |
+
self.built = True
|
| 1540 |
+
if getattr(self, "deberta", None) is not None:
|
| 1541 |
+
with tf.name_scope(self.deberta.name):
|
| 1542 |
+
self.deberta.build(None)
|
| 1543 |
+
if getattr(self, "classifier", None) is not None:
|
| 1544 |
+
with tf.name_scope(self.classifier.name):
|
| 1545 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1546 |
+
|
| 1547 |
+
|
| 1548 |
+
@add_start_docstrings(
|
| 1549 |
+
"""
|
| 1550 |
+
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1551 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1552 |
+
""",
|
| 1553 |
+
DEBERTA_START_DOCSTRING,
|
| 1554 |
+
)
|
| 1555 |
+
class TFDebertaForQuestionAnswering(TFDebertaPreTrainedModel, TFQuestionAnsweringLoss):
|
| 1556 |
+
def __init__(self, config: DebertaConfig, *inputs, **kwargs):
|
| 1557 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1558 |
+
|
| 1559 |
+
self.num_labels = config.num_labels
|
| 1560 |
+
|
| 1561 |
+
self.deberta = TFDebertaMainLayer(config, name="deberta")
|
| 1562 |
+
self.qa_outputs = keras.layers.Dense(
|
| 1563 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
| 1564 |
+
)
|
| 1565 |
+
self.config = config
|
| 1566 |
+
|
| 1567 |
+
@unpack_inputs
|
| 1568 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1569 |
+
@add_code_sample_docstrings(
|
| 1570 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1571 |
+
output_type=TFQuestionAnsweringModelOutput,
|
| 1572 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1573 |
+
)
|
| 1574 |
+
def call(
|
| 1575 |
+
self,
|
| 1576 |
+
input_ids: TFModelInputType | None = None,
|
| 1577 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1578 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1579 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1580 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1581 |
+
output_attentions: Optional[bool] = None,
|
| 1582 |
+
output_hidden_states: Optional[bool] = None,
|
| 1583 |
+
return_dict: Optional[bool] = None,
|
| 1584 |
+
start_positions: np.ndarray | tf.Tensor | None = None,
|
| 1585 |
+
end_positions: np.ndarray | tf.Tensor | None = None,
|
| 1586 |
+
training: Optional[bool] = False,
|
| 1587 |
+
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
|
| 1588 |
+
r"""
|
| 1589 |
+
start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1590 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1591 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1592 |
+
are not taken into account for computing the loss.
|
| 1593 |
+
end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1594 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1595 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1596 |
+
are not taken into account for computing the loss.
|
| 1597 |
+
"""
|
| 1598 |
+
outputs = self.deberta(
|
| 1599 |
+
input_ids=input_ids,
|
| 1600 |
+
attention_mask=attention_mask,
|
| 1601 |
+
token_type_ids=token_type_ids,
|
| 1602 |
+
position_ids=position_ids,
|
| 1603 |
+
inputs_embeds=inputs_embeds,
|
| 1604 |
+
output_attentions=output_attentions,
|
| 1605 |
+
output_hidden_states=output_hidden_states,
|
| 1606 |
+
return_dict=return_dict,
|
| 1607 |
+
training=training,
|
| 1608 |
+
)
|
| 1609 |
+
sequence_output = outputs[0]
|
| 1610 |
+
logits = self.qa_outputs(inputs=sequence_output)
|
| 1611 |
+
start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
|
| 1612 |
+
start_logits = tf.squeeze(input=start_logits, axis=-1)
|
| 1613 |
+
end_logits = tf.squeeze(input=end_logits, axis=-1)
|
| 1614 |
+
loss = None
|
| 1615 |
+
|
| 1616 |
+
if start_positions is not None and end_positions is not None:
|
| 1617 |
+
labels = {"start_position": start_positions}
|
| 1618 |
+
labels["end_position"] = end_positions
|
| 1619 |
+
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
|
| 1620 |
+
|
| 1621 |
+
if not return_dict:
|
| 1622 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 1623 |
+
return ((loss,) + output) if loss is not None else output
|
| 1624 |
+
|
| 1625 |
+
return TFQuestionAnsweringModelOutput(
|
| 1626 |
+
loss=loss,
|
| 1627 |
+
start_logits=start_logits,
|
| 1628 |
+
end_logits=end_logits,
|
| 1629 |
+
hidden_states=outputs.hidden_states,
|
| 1630 |
+
attentions=outputs.attentions,
|
| 1631 |
+
)
|
| 1632 |
+
|
| 1633 |
+
def build(self, input_shape=None):
|
| 1634 |
+
if self.built:
|
| 1635 |
+
return
|
| 1636 |
+
self.built = True
|
| 1637 |
+
if getattr(self, "deberta", None) is not None:
|
| 1638 |
+
with tf.name_scope(self.deberta.name):
|
| 1639 |
+
self.deberta.build(None)
|
| 1640 |
+
if getattr(self, "qa_outputs", None) is not None:
|
| 1641 |
+
with tf.name_scope(self.qa_outputs.name):
|
| 1642 |
+
self.qa_outputs.build([None, None, self.config.hidden_size])
|
| 1643 |
+
|
| 1644 |
+
|
| 1645 |
+
__all__ = [
|
| 1646 |
+
"TFDebertaForMaskedLM",
|
| 1647 |
+
"TFDebertaForQuestionAnswering",
|
| 1648 |
+
"TFDebertaForSequenceClassification",
|
| 1649 |
+
"TFDebertaForTokenClassification",
|
| 1650 |
+
"TFDebertaModel",
|
| 1651 |
+
"TFDebertaPreTrainedModel",
|
| 1652 |
+
]
|
docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 Microsoft and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization class for model DeBERTa."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import regex as re
|
| 22 |
+
|
| 23 |
+
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
| 33 |
+
def bytes_to_unicode():
|
| 34 |
+
"""
|
| 35 |
+
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
|
| 36 |
+
characters the bpe code barfs on.
|
| 37 |
+
|
| 38 |
+
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
| 39 |
+
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
| 40 |
+
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
| 41 |
+
tables between utf-8 bytes and unicode strings.
|
| 42 |
+
"""
|
| 43 |
+
bs = (
|
| 44 |
+
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
| 45 |
+
)
|
| 46 |
+
cs = bs[:]
|
| 47 |
+
n = 0
|
| 48 |
+
for b in range(2**8):
|
| 49 |
+
if b not in bs:
|
| 50 |
+
bs.append(b)
|
| 51 |
+
cs.append(2**8 + n)
|
| 52 |
+
n += 1
|
| 53 |
+
cs = [chr(n) for n in cs]
|
| 54 |
+
return dict(zip(bs, cs))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
|
| 58 |
+
def get_pairs(word):
|
| 59 |
+
"""
|
| 60 |
+
Return set of symbol pairs in a word.
|
| 61 |
+
|
| 62 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 63 |
+
"""
|
| 64 |
+
pairs = set()
|
| 65 |
+
prev_char = word[0]
|
| 66 |
+
for char in word[1:]:
|
| 67 |
+
pairs.add((prev_char, char))
|
| 68 |
+
prev_char = char
|
| 69 |
+
return pairs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DebertaTokenizer(PreTrainedTokenizer):
|
| 73 |
+
"""
|
| 74 |
+
Construct a DeBERTa tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 75 |
+
|
| 76 |
+
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
|
| 77 |
+
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
| 78 |
+
|
| 79 |
+
```python
|
| 80 |
+
>>> from transformers import DebertaTokenizer
|
| 81 |
+
|
| 82 |
+
>>> tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
|
| 83 |
+
>>> tokenizer("Hello world")["input_ids"]
|
| 84 |
+
[1, 31414, 232, 2]
|
| 85 |
+
|
| 86 |
+
>>> tokenizer(" Hello world")["input_ids"]
|
| 87 |
+
[1, 20920, 232, 2]
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
|
| 91 |
+
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
|
| 92 |
+
|
| 93 |
+
<Tip>
|
| 94 |
+
|
| 95 |
+
When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
|
| 96 |
+
|
| 97 |
+
</Tip>
|
| 98 |
+
|
| 99 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 100 |
+
this superclass for more information regarding those methods.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
vocab_file (`str`):
|
| 104 |
+
Path to the vocabulary file.
|
| 105 |
+
merges_file (`str`):
|
| 106 |
+
Path to the merges file.
|
| 107 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
| 108 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
| 109 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
| 110 |
+
bos_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 111 |
+
The beginning of sequence token.
|
| 112 |
+
eos_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 113 |
+
The end of sequence token.
|
| 114 |
+
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 115 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 116 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 117 |
+
token of a sequence built with special tokens.
|
| 118 |
+
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 119 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 120 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 121 |
+
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
|
| 122 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 123 |
+
token instead.
|
| 124 |
+
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
|
| 125 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 126 |
+
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
| 127 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 128 |
+
modeling. This is the token which the model will try to predict.
|
| 129 |
+
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
| 130 |
+
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
| 131 |
+
other word. (Deberta tokenizer detect beginning of words by the preceding space).
|
| 132 |
+
add_bos_token (`bool`, *optional*, defaults to `False`):
|
| 133 |
+
Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
|
| 134 |
+
any other word.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 138 |
+
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
| 139 |
+
|
| 140 |
+
def __init__(
|
| 141 |
+
self,
|
| 142 |
+
vocab_file,
|
| 143 |
+
merges_file,
|
| 144 |
+
errors="replace",
|
| 145 |
+
bos_token="[CLS]",
|
| 146 |
+
eos_token="[SEP]",
|
| 147 |
+
sep_token="[SEP]",
|
| 148 |
+
cls_token="[CLS]",
|
| 149 |
+
unk_token="[UNK]",
|
| 150 |
+
pad_token="[PAD]",
|
| 151 |
+
mask_token="[MASK]",
|
| 152 |
+
add_prefix_space=False,
|
| 153 |
+
add_bos_token=False,
|
| 154 |
+
**kwargs,
|
| 155 |
+
):
|
| 156 |
+
bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
|
| 157 |
+
eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
|
| 158 |
+
sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token
|
| 159 |
+
cls_token = AddedToken(cls_token, special=True) if isinstance(cls_token, str) else cls_token
|
| 160 |
+
unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
|
| 161 |
+
pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
|
| 162 |
+
|
| 163 |
+
# Mask token behave like a normal word, i.e. include the space before it
|
| 164 |
+
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
| 165 |
+
self.add_bos_token = add_bos_token
|
| 166 |
+
|
| 167 |
+
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
| 168 |
+
self.encoder = json.load(vocab_handle)
|
| 169 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 170 |
+
self.errors = errors # how to handle errors in decoding
|
| 171 |
+
self.byte_encoder = bytes_to_unicode()
|
| 172 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 173 |
+
with open(merges_file, encoding="utf-8") as merges_handle:
|
| 174 |
+
bpe_merges = merges_handle.read().split("\n")[1:-1]
|
| 175 |
+
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
|
| 176 |
+
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
| 177 |
+
self.cache = {}
|
| 178 |
+
self.add_prefix_space = add_prefix_space
|
| 179 |
+
|
| 180 |
+
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
| 181 |
+
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
| 182 |
+
|
| 183 |
+
super().__init__(
|
| 184 |
+
errors=errors,
|
| 185 |
+
bos_token=bos_token,
|
| 186 |
+
eos_token=eos_token,
|
| 187 |
+
unk_token=unk_token,
|
| 188 |
+
sep_token=sep_token,
|
| 189 |
+
cls_token=cls_token,
|
| 190 |
+
pad_token=pad_token,
|
| 191 |
+
mask_token=mask_token,
|
| 192 |
+
add_prefix_space=add_prefix_space,
|
| 193 |
+
add_bos_token=add_bos_token,
|
| 194 |
+
**kwargs,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
@property
|
| 198 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.vocab_size
|
| 199 |
+
def vocab_size(self):
|
| 200 |
+
return len(self.encoder)
|
| 201 |
+
|
| 202 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
|
| 203 |
+
def get_vocab(self):
|
| 204 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
| 205 |
+
|
| 206 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
|
| 207 |
+
def bpe(self, token):
|
| 208 |
+
if token in self.cache:
|
| 209 |
+
return self.cache[token]
|
| 210 |
+
word = tuple(token)
|
| 211 |
+
pairs = get_pairs(word)
|
| 212 |
+
|
| 213 |
+
if not pairs:
|
| 214 |
+
return token
|
| 215 |
+
|
| 216 |
+
while True:
|
| 217 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
| 218 |
+
if bigram not in self.bpe_ranks:
|
| 219 |
+
break
|
| 220 |
+
first, second = bigram
|
| 221 |
+
new_word = []
|
| 222 |
+
i = 0
|
| 223 |
+
while i < len(word):
|
| 224 |
+
try:
|
| 225 |
+
j = word.index(first, i)
|
| 226 |
+
except ValueError:
|
| 227 |
+
new_word.extend(word[i:])
|
| 228 |
+
break
|
| 229 |
+
else:
|
| 230 |
+
new_word.extend(word[i:j])
|
| 231 |
+
i = j
|
| 232 |
+
|
| 233 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
| 234 |
+
new_word.append(first + second)
|
| 235 |
+
i += 2
|
| 236 |
+
else:
|
| 237 |
+
new_word.append(word[i])
|
| 238 |
+
i += 1
|
| 239 |
+
new_word = tuple(new_word)
|
| 240 |
+
word = new_word
|
| 241 |
+
if len(word) == 1:
|
| 242 |
+
break
|
| 243 |
+
else:
|
| 244 |
+
pairs = get_pairs(word)
|
| 245 |
+
word = " ".join(word)
|
| 246 |
+
self.cache[token] = word
|
| 247 |
+
return word
|
| 248 |
+
|
| 249 |
+
def build_inputs_with_special_tokens(
|
| 250 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 251 |
+
) -> List[int]:
|
| 252 |
+
"""
|
| 253 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 254 |
+
adding special tokens. A DeBERTa sequence has the following format:
|
| 255 |
+
|
| 256 |
+
- single sequence: [CLS] X [SEP]
|
| 257 |
+
- pair of sequences: [CLS] A [SEP] B [SEP]
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
token_ids_0 (`List[int]`):
|
| 261 |
+
List of IDs to which the special tokens will be added.
|
| 262 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 263 |
+
Optional second list of IDs for sequence pairs.
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 267 |
+
"""
|
| 268 |
+
if token_ids_1 is None:
|
| 269 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 270 |
+
cls = [self.cls_token_id]
|
| 271 |
+
sep = [self.sep_token_id]
|
| 272 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 273 |
+
|
| 274 |
+
def get_special_tokens_mask(
|
| 275 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 276 |
+
) -> List[int]:
|
| 277 |
+
"""
|
| 278 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 279 |
+
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
token_ids_0 (`List[int]`):
|
| 283 |
+
List of IDs.
|
| 284 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 285 |
+
Optional second list of IDs for sequence pairs.
|
| 286 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 287 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 291 |
+
"""
|
| 292 |
+
if already_has_special_tokens:
|
| 293 |
+
return super().get_special_tokens_mask(
|
| 294 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if token_ids_1 is None:
|
| 298 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 299 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 300 |
+
|
| 301 |
+
def create_token_type_ids_from_sequences(
|
| 302 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 303 |
+
) -> List[int]:
|
| 304 |
+
"""
|
| 305 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
|
| 306 |
+
sequence pair mask has the following format:
|
| 307 |
+
|
| 308 |
+
```
|
| 309 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 310 |
+
| first sequence | second sequence |
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
token_ids_0 (`List[int]`):
|
| 317 |
+
List of IDs.
|
| 318 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 319 |
+
Optional second list of IDs for sequence pairs.
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 323 |
+
"""
|
| 324 |
+
sep = [self.sep_token_id]
|
| 325 |
+
cls = [self.cls_token_id]
|
| 326 |
+
|
| 327 |
+
if token_ids_1 is None:
|
| 328 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 329 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 330 |
+
|
| 331 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
|
| 332 |
+
def _tokenize(self, text):
|
| 333 |
+
"""Tokenize a string."""
|
| 334 |
+
bpe_tokens = []
|
| 335 |
+
for token in re.findall(self.pat, text):
|
| 336 |
+
token = "".join(
|
| 337 |
+
self.byte_encoder[b] for b in token.encode("utf-8")
|
| 338 |
+
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
|
| 339 |
+
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
| 340 |
+
return bpe_tokens
|
| 341 |
+
|
| 342 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
|
| 343 |
+
def _convert_token_to_id(self, token):
|
| 344 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 345 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
| 346 |
+
|
| 347 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
|
| 348 |
+
def _convert_id_to_token(self, index):
|
| 349 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 350 |
+
return self.decoder.get(index)
|
| 351 |
+
|
| 352 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
|
| 353 |
+
def convert_tokens_to_string(self, tokens):
|
| 354 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 355 |
+
text = "".join(tokens)
|
| 356 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
| 357 |
+
return text
|
| 358 |
+
|
| 359 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
|
| 360 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 361 |
+
if not os.path.isdir(save_directory):
|
| 362 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 363 |
+
return
|
| 364 |
+
vocab_file = os.path.join(
|
| 365 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 366 |
+
)
|
| 367 |
+
merge_file = os.path.join(
|
| 368 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 372 |
+
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
| 373 |
+
|
| 374 |
+
index = 0
|
| 375 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
| 376 |
+
writer.write("#version: 0.2\n")
|
| 377 |
+
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
| 378 |
+
if index != token_index:
|
| 379 |
+
logger.warning(
|
| 380 |
+
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
| 381 |
+
" Please check that the tokenizer is not corrupted!"
|
| 382 |
+
)
|
| 383 |
+
index = token_index
|
| 384 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
| 385 |
+
index += 1
|
| 386 |
+
|
| 387 |
+
return vocab_file, merge_file
|
| 388 |
+
|
| 389 |
+
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
| 390 |
+
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
|
| 391 |
+
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
|
| 392 |
+
text = " " + text
|
| 393 |
+
return (text, kwargs)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
__all__ = ["DebertaTokenizer"]
|
docs/transformers/build/lib/transformers/models/deberta/tokenization_deberta_fast.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 Microsoft and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Fast Tokenization class for model DeBERTa."""
|
| 16 |
+
|
| 17 |
+
from typing import List, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
| 20 |
+
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
| 21 |
+
from ...utils import logging
|
| 22 |
+
from .tokenization_deberta import DebertaTokenizer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DebertaTokenizerFast(PreTrainedTokenizerFast):
|
| 31 |
+
"""
|
| 32 |
+
Construct a "fast" DeBERTa tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
|
| 33 |
+
Byte-Pair-Encoding.
|
| 34 |
+
|
| 35 |
+
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
|
| 36 |
+
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
| 37 |
+
|
| 38 |
+
```python
|
| 39 |
+
>>> from transformers import DebertaTokenizerFast
|
| 40 |
+
|
| 41 |
+
>>> tokenizer = DebertaTokenizerFast.from_pretrained("microsoft/deberta-base")
|
| 42 |
+
>>> tokenizer("Hello world")["input_ids"]
|
| 43 |
+
[1, 31414, 232, 2]
|
| 44 |
+
|
| 45 |
+
>>> tokenizer(" Hello world")["input_ids"]
|
| 46 |
+
[1, 20920, 232, 2]
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
|
| 50 |
+
the model was not pretrained this way, it might yield a decrease in performance.
|
| 51 |
+
|
| 52 |
+
<Tip>
|
| 53 |
+
|
| 54 |
+
When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
|
| 55 |
+
|
| 56 |
+
</Tip>
|
| 57 |
+
|
| 58 |
+
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
| 59 |
+
refer to this superclass for more information regarding those methods.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
vocab_file (`str`, *optional*):
|
| 63 |
+
Path to the vocabulary file.
|
| 64 |
+
merges_file (`str`, *optional*):
|
| 65 |
+
Path to the merges file.
|
| 66 |
+
tokenizer_file (`str`, *optional*):
|
| 67 |
+
The path to a tokenizer file to use instead of the vocab file.
|
| 68 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
| 69 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
| 70 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
| 71 |
+
bos_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 72 |
+
The beginning of sequence token.
|
| 73 |
+
eos_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 74 |
+
The end of sequence token.
|
| 75 |
+
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
| 76 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 77 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 78 |
+
token of a sequence built with special tokens.
|
| 79 |
+
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
| 80 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 81 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 82 |
+
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
|
| 83 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 84 |
+
token instead.
|
| 85 |
+
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
|
| 86 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 87 |
+
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
| 88 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 89 |
+
modeling. This is the token which the model will try to predict.
|
| 90 |
+
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
| 91 |
+
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
| 92 |
+
other word. (Deberta tokenizer detect beginning of words by the preceding space).
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 96 |
+
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
| 97 |
+
slow_tokenizer_class = DebertaTokenizer
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
vocab_file=None,
|
| 102 |
+
merges_file=None,
|
| 103 |
+
tokenizer_file=None,
|
| 104 |
+
errors="replace",
|
| 105 |
+
bos_token="[CLS]",
|
| 106 |
+
eos_token="[SEP]",
|
| 107 |
+
sep_token="[SEP]",
|
| 108 |
+
cls_token="[CLS]",
|
| 109 |
+
unk_token="[UNK]",
|
| 110 |
+
pad_token="[PAD]",
|
| 111 |
+
mask_token="[MASK]",
|
| 112 |
+
add_prefix_space=False,
|
| 113 |
+
**kwargs,
|
| 114 |
+
):
|
| 115 |
+
super().__init__(
|
| 116 |
+
vocab_file,
|
| 117 |
+
merges_file,
|
| 118 |
+
tokenizer_file=tokenizer_file,
|
| 119 |
+
errors=errors,
|
| 120 |
+
bos_token=bos_token,
|
| 121 |
+
eos_token=eos_token,
|
| 122 |
+
unk_token=unk_token,
|
| 123 |
+
sep_token=sep_token,
|
| 124 |
+
cls_token=cls_token,
|
| 125 |
+
pad_token=pad_token,
|
| 126 |
+
mask_token=mask_token,
|
| 127 |
+
add_prefix_space=add_prefix_space,
|
| 128 |
+
**kwargs,
|
| 129 |
+
)
|
| 130 |
+
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
def mask_token(self) -> str:
|
| 134 |
+
"""
|
| 135 |
+
`str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
|
| 136 |
+
having been set.
|
| 137 |
+
|
| 138 |
+
Deberta tokenizer has a special mask token to be used in the fill-mask pipeline. The mask token will greedily
|
| 139 |
+
comprise the space before the *[MASK]*.
|
| 140 |
+
"""
|
| 141 |
+
if self._mask_token is None:
|
| 142 |
+
if self.verbose:
|
| 143 |
+
logger.error("Using mask_token, but it is not set yet.")
|
| 144 |
+
return None
|
| 145 |
+
return str(self._mask_token)
|
| 146 |
+
|
| 147 |
+
@mask_token.setter
|
| 148 |
+
def mask_token(self, value):
|
| 149 |
+
"""
|
| 150 |
+
Overriding the default behavior of the mask token to have it eat the space before it.
|
| 151 |
+
"""
|
| 152 |
+
# Mask token behave like a normal word, i.e. include the space before it
|
| 153 |
+
# So we set lstrip to True
|
| 154 |
+
value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
|
| 155 |
+
self._mask_token = value
|
| 156 |
+
|
| 157 |
+
def build_inputs_with_special_tokens(
|
| 158 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 159 |
+
) -> List[int]:
|
| 160 |
+
"""
|
| 161 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 162 |
+
adding special tokens. A DeBERTa sequence has the following format:
|
| 163 |
+
|
| 164 |
+
- single sequence: [CLS] X [SEP]
|
| 165 |
+
- pair of sequences: [CLS] A [SEP] B [SEP]
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
token_ids_0 (`List[int]`):
|
| 169 |
+
List of IDs to which the special tokens will be added.
|
| 170 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 171 |
+
Optional second list of IDs for sequence pairs.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 175 |
+
"""
|
| 176 |
+
if token_ids_1 is None:
|
| 177 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 178 |
+
cls = [self.cls_token_id]
|
| 179 |
+
sep = [self.sep_token_id]
|
| 180 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 181 |
+
|
| 182 |
+
def create_token_type_ids_from_sequences(
|
| 183 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 184 |
+
) -> List[int]:
|
| 185 |
+
"""
|
| 186 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
|
| 187 |
+
sequence pair mask has the following format:
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 191 |
+
| first sequence | second sequence |
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
token_ids_0 (`List[int]`):
|
| 198 |
+
List of IDs.
|
| 199 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 200 |
+
Optional second list of IDs for sequence pairs.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
| 204 |
+
"""
|
| 205 |
+
sep = [self.sep_token_id]
|
| 206 |
+
cls = [self.cls_token_id]
|
| 207 |
+
|
| 208 |
+
if token_ids_1 is None:
|
| 209 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 210 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 211 |
+
|
| 212 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._batch_encode_plus
|
| 213 |
+
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
| 214 |
+
is_split_into_words = kwargs.get("is_split_into_words", False)
|
| 215 |
+
assert self.add_prefix_space or not is_split_into_words, (
|
| 216 |
+
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
|
| 217 |
+
"to use it with pretokenized inputs."
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
return super()._batch_encode_plus(*args, **kwargs)
|
| 221 |
+
|
| 222 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._encode_plus
|
| 223 |
+
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
| 224 |
+
is_split_into_words = kwargs.get("is_split_into_words", False)
|
| 225 |
+
|
| 226 |
+
assert self.add_prefix_space or not is_split_into_words, (
|
| 227 |
+
f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
|
| 228 |
+
"to use it with pretokenized inputs."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return super()._encode_plus(*args, **kwargs)
|
| 232 |
+
|
| 233 |
+
# Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
|
| 234 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 235 |
+
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
| 236 |
+
return tuple(files)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
__all__ = ["DebertaTokenizerFast"]
|
docs/transformers/build/lib/transformers/models/deberta_v2/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_deberta_v2 import *
|
| 22 |
+
from .modeling_deberta_v2 import *
|
| 23 |
+
from .modeling_tf_deberta_v2 import *
|
| 24 |
+
from .tokenization_deberta_v2 import *
|
| 25 |
+
from .tokenization_deberta_v2_fast import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/build/lib/transformers/models/deberta_v2/configuration_deberta_v2.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020, Microsoft and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""DeBERTa-v2 model configuration"""
|
| 16 |
+
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...onnx import OnnxConfig
|
| 22 |
+
from ...utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DebertaV2Config(PretrainedConfig):
|
| 33 |
+
r"""
|
| 34 |
+
This is the configuration class to store the configuration of a [`DebertaV2Model`]. It is used to instantiate a
|
| 35 |
+
DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a
|
| 36 |
+
configuration with the defaults will yield a similar configuration to that of the DeBERTa
|
| 37 |
+
[microsoft/deberta-v2-xlarge](https://huggingface.co/microsoft/deberta-v2-xlarge) architecture.
|
| 38 |
+
|
| 39 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 40 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 41 |
+
|
| 42 |
+
Arguments:
|
| 43 |
+
vocab_size (`int`, *optional*, defaults to 128100):
|
| 44 |
+
Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by
|
| 45 |
+
the `inputs_ids` passed when calling [`DebertaV2Model`].
|
| 46 |
+
hidden_size (`int`, *optional*, defaults to 1536):
|
| 47 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 48 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
| 49 |
+
Number of hidden layers in the Transformer encoder.
|
| 50 |
+
num_attention_heads (`int`, *optional*, defaults to 24):
|
| 51 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 52 |
+
intermediate_size (`int`, *optional*, defaults to 6144):
|
| 53 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
| 54 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
| 55 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 56 |
+
`"relu"`, `"silu"`, `"gelu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"`
|
| 57 |
+
are supported.
|
| 58 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 59 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 60 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
| 61 |
+
The dropout ratio for the attention probabilities.
|
| 62 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
| 63 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 64 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 65 |
+
type_vocab_size (`int`, *optional*, defaults to 0):
|
| 66 |
+
The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`].
|
| 67 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 68 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 69 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-7):
|
| 70 |
+
The epsilon used by the layer normalization layers.
|
| 71 |
+
relative_attention (`bool`, *optional*, defaults to `True`):
|
| 72 |
+
Whether use relative position encoding.
|
| 73 |
+
max_relative_positions (`int`, *optional*, defaults to -1):
|
| 74 |
+
The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value
|
| 75 |
+
as `max_position_embeddings`.
|
| 76 |
+
pad_token_id (`int`, *optional*, defaults to 0):
|
| 77 |
+
The value used to pad input_ids.
|
| 78 |
+
position_biased_input (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether add absolute position embedding to content embedding.
|
| 80 |
+
pos_att_type (`List[str]`, *optional*):
|
| 81 |
+
The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`,
|
| 82 |
+
`["p2c", "c2p"]`, `["p2c", "c2p"]`.
|
| 83 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 84 |
+
The epsilon used by the layer normalization layers.
|
| 85 |
+
legacy (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
Whether or not the model should use the legacy `LegacyDebertaOnlyMLMHead`, which does not work properly
|
| 87 |
+
for mask infilling tasks.
|
| 88 |
+
|
| 89 |
+
Example:
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
>>> from transformers import DebertaV2Config, DebertaV2Model
|
| 93 |
+
|
| 94 |
+
>>> # Initializing a DeBERTa-v2 microsoft/deberta-v2-xlarge style configuration
|
| 95 |
+
>>> configuration = DebertaV2Config()
|
| 96 |
+
|
| 97 |
+
>>> # Initializing a model (with random weights) from the microsoft/deberta-v2-xlarge style configuration
|
| 98 |
+
>>> model = DebertaV2Model(configuration)
|
| 99 |
+
|
| 100 |
+
>>> # Accessing the model configuration
|
| 101 |
+
>>> configuration = model.config
|
| 102 |
+
```"""
|
| 103 |
+
|
| 104 |
+
model_type = "deberta-v2"
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
vocab_size=128100,
|
| 109 |
+
hidden_size=1536,
|
| 110 |
+
num_hidden_layers=24,
|
| 111 |
+
num_attention_heads=24,
|
| 112 |
+
intermediate_size=6144,
|
| 113 |
+
hidden_act="gelu",
|
| 114 |
+
hidden_dropout_prob=0.1,
|
| 115 |
+
attention_probs_dropout_prob=0.1,
|
| 116 |
+
max_position_embeddings=512,
|
| 117 |
+
type_vocab_size=0,
|
| 118 |
+
initializer_range=0.02,
|
| 119 |
+
layer_norm_eps=1e-7,
|
| 120 |
+
relative_attention=False,
|
| 121 |
+
max_relative_positions=-1,
|
| 122 |
+
pad_token_id=0,
|
| 123 |
+
position_biased_input=True,
|
| 124 |
+
pos_att_type=None,
|
| 125 |
+
pooler_dropout=0,
|
| 126 |
+
pooler_hidden_act="gelu",
|
| 127 |
+
legacy=True,
|
| 128 |
+
**kwargs,
|
| 129 |
+
):
|
| 130 |
+
super().__init__(**kwargs)
|
| 131 |
+
|
| 132 |
+
self.hidden_size = hidden_size
|
| 133 |
+
self.num_hidden_layers = num_hidden_layers
|
| 134 |
+
self.num_attention_heads = num_attention_heads
|
| 135 |
+
self.intermediate_size = intermediate_size
|
| 136 |
+
self.hidden_act = hidden_act
|
| 137 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 138 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 139 |
+
self.max_position_embeddings = max_position_embeddings
|
| 140 |
+
self.type_vocab_size = type_vocab_size
|
| 141 |
+
self.initializer_range = initializer_range
|
| 142 |
+
self.relative_attention = relative_attention
|
| 143 |
+
self.max_relative_positions = max_relative_positions
|
| 144 |
+
self.pad_token_id = pad_token_id
|
| 145 |
+
self.position_biased_input = position_biased_input
|
| 146 |
+
|
| 147 |
+
# Backwards compatibility
|
| 148 |
+
if isinstance(pos_att_type, str):
|
| 149 |
+
pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]
|
| 150 |
+
|
| 151 |
+
self.pos_att_type = pos_att_type
|
| 152 |
+
self.vocab_size = vocab_size
|
| 153 |
+
self.layer_norm_eps = layer_norm_eps
|
| 154 |
+
|
| 155 |
+
self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
|
| 156 |
+
self.pooler_dropout = pooler_dropout
|
| 157 |
+
self.pooler_hidden_act = pooler_hidden_act
|
| 158 |
+
self.legacy = legacy
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class DebertaV2OnnxConfig(OnnxConfig):
|
| 162 |
+
@property
|
| 163 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
| 164 |
+
if self.task == "multiple-choice":
|
| 165 |
+
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
| 166 |
+
else:
|
| 167 |
+
dynamic_axis = {0: "batch", 1: "sequence"}
|
| 168 |
+
if self._config.type_vocab_size > 0:
|
| 169 |
+
return OrderedDict(
|
| 170 |
+
[("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ("token_type_ids", dynamic_axis)]
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
return OrderedDict([("input_ids", dynamic_axis), ("attention_mask", dynamic_axis)])
|
| 174 |
+
|
| 175 |
+
@property
|
| 176 |
+
def default_onnx_opset(self) -> int:
|
| 177 |
+
return 12
|
| 178 |
+
|
| 179 |
+
def generate_dummy_inputs(
|
| 180 |
+
self,
|
| 181 |
+
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
|
| 182 |
+
batch_size: int = -1,
|
| 183 |
+
seq_length: int = -1,
|
| 184 |
+
num_choices: int = -1,
|
| 185 |
+
is_pair: bool = False,
|
| 186 |
+
framework: Optional["TensorType"] = None,
|
| 187 |
+
num_channels: int = 3,
|
| 188 |
+
image_width: int = 40,
|
| 189 |
+
image_height: int = 40,
|
| 190 |
+
tokenizer: "PreTrainedTokenizerBase" = None,
|
| 191 |
+
) -> Mapping[str, Any]:
|
| 192 |
+
dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework)
|
| 193 |
+
if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs:
|
| 194 |
+
del dummy_inputs["token_type_ids"]
|
| 195 |
+
return dummy_inputs
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
__all__ = ["DebertaV2Config", "DebertaV2OnnxConfig"]
|
docs/transformers/build/lib/transformers/models/deberta_v2/modeling_deberta_v2.py
ADDED
|
@@ -0,0 +1,1523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 Microsoft and the Hugging Face Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch DeBERTa-v2 model."""
|
| 16 |
+
|
| 17 |
+
from collections.abc import Sequence
|
| 18 |
+
from typing import Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
| 24 |
+
|
| 25 |
+
from ...activations import ACT2FN
|
| 26 |
+
from ...modeling_outputs import (
|
| 27 |
+
BaseModelOutput,
|
| 28 |
+
MaskedLMOutput,
|
| 29 |
+
MultipleChoiceModelOutput,
|
| 30 |
+
QuestionAnsweringModelOutput,
|
| 31 |
+
SequenceClassifierOutput,
|
| 32 |
+
TokenClassifierOutput,
|
| 33 |
+
)
|
| 34 |
+
from ...modeling_utils import PreTrainedModel
|
| 35 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 36 |
+
from .configuration_deberta_v2 import DebertaV2Config
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
_CONFIG_FOR_DOC = "DebertaV2Config"
|
| 42 |
+
_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
|
| 43 |
+
_QA_TARGET_START_INDEX = 2
|
| 44 |
+
_QA_TARGET_END_INDEX = 9
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
|
| 48 |
+
class DebertaV2SelfOutput(nn.Module):
|
| 49 |
+
def __init__(self, config):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 52 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 53 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 54 |
+
|
| 55 |
+
def forward(self, hidden_states, input_tensor):
|
| 56 |
+
hidden_states = self.dense(hidden_states)
|
| 57 |
+
hidden_states = self.dropout(hidden_states)
|
| 58 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 59 |
+
return hidden_states
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@torch.jit.script
|
| 63 |
+
def make_log_bucket_position(relative_pos, bucket_size: int, max_position: int):
|
| 64 |
+
sign = torch.sign(relative_pos)
|
| 65 |
+
mid = bucket_size // 2
|
| 66 |
+
abs_pos = torch.where(
|
| 67 |
+
(relative_pos < mid) & (relative_pos > -mid),
|
| 68 |
+
torch.tensor(mid - 1).type_as(relative_pos),
|
| 69 |
+
torch.abs(relative_pos),
|
| 70 |
+
)
|
| 71 |
+
log_pos = (
|
| 72 |
+
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
|
| 73 |
+
)
|
| 74 |
+
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
|
| 75 |
+
return bucket_pos
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_relative_position(query_layer, key_layer, bucket_size: int = -1, max_position: int = -1):
|
| 79 |
+
"""
|
| 80 |
+
Build relative position according to the query and key
|
| 81 |
+
|
| 82 |
+
We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
|
| 83 |
+
\\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
|
| 84 |
+
P_k\\)
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
query_size (int): the length of query
|
| 88 |
+
key_size (int): the length of key
|
| 89 |
+
bucket_size (int): the size of position bucket
|
| 90 |
+
max_position (int): the maximum allowed absolute position
|
| 91 |
+
device (`torch.device`): the device on which tensors will be created.
|
| 92 |
+
|
| 93 |
+
Return:
|
| 94 |
+
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
|
| 95 |
+
"""
|
| 96 |
+
query_size = query_layer.size(-2)
|
| 97 |
+
key_size = key_layer.size(-2)
|
| 98 |
+
|
| 99 |
+
q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
|
| 100 |
+
k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
|
| 101 |
+
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
|
| 102 |
+
if bucket_size > 0 and max_position > 0:
|
| 103 |
+
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
|
| 104 |
+
rel_pos_ids = rel_pos_ids.to(torch.long)
|
| 105 |
+
rel_pos_ids = rel_pos_ids[:query_size, :]
|
| 106 |
+
rel_pos_ids = rel_pos_ids.unsqueeze(0)
|
| 107 |
+
return rel_pos_ids
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@torch.jit.script
|
| 111 |
+
# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
|
| 112 |
+
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
|
| 113 |
+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@torch.jit.script
|
| 117 |
+
# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
|
| 118 |
+
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
|
| 119 |
+
return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@torch.jit.script
|
| 123 |
+
# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
|
| 124 |
+
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
|
| 125 |
+
return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@torch.jit.script
|
| 129 |
+
def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
|
| 130 |
+
return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@torch.jit.script
|
| 134 |
+
def build_rpos(query_layer, key_layer, relative_pos, position_buckets: int, max_relative_positions: int):
|
| 135 |
+
if key_layer.size(-2) != query_layer.size(-2):
|
| 136 |
+
return build_relative_position(
|
| 137 |
+
key_layer,
|
| 138 |
+
key_layer,
|
| 139 |
+
bucket_size=position_buckets,
|
| 140 |
+
max_position=max_relative_positions,
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
return relative_pos
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class DisentangledSelfAttention(nn.Module):
|
| 147 |
+
"""
|
| 148 |
+
Disentangled self-attention module
|
| 149 |
+
|
| 150 |
+
Parameters:
|
| 151 |
+
config (`DebertaV2Config`):
|
| 152 |
+
A model config class instance with the configuration to build a new model. The schema is similar to
|
| 153 |
+
*BertConfig*, for more details, please refer [`DebertaV2Config`]
|
| 154 |
+
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, config):
|
| 158 |
+
super().__init__()
|
| 159 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 160 |
+
raise ValueError(
|
| 161 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 162 |
+
f"heads ({config.num_attention_heads})"
|
| 163 |
+
)
|
| 164 |
+
self.num_attention_heads = config.num_attention_heads
|
| 165 |
+
_attention_head_size = config.hidden_size // config.num_attention_heads
|
| 166 |
+
self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
|
| 167 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 168 |
+
self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 169 |
+
self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 170 |
+
self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 171 |
+
|
| 172 |
+
self.share_att_key = getattr(config, "share_att_key", False)
|
| 173 |
+
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
|
| 174 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 175 |
+
|
| 176 |
+
if self.relative_attention:
|
| 177 |
+
self.position_buckets = getattr(config, "position_buckets", -1)
|
| 178 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 179 |
+
if self.max_relative_positions < 1:
|
| 180 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 181 |
+
self.pos_ebd_size = self.max_relative_positions
|
| 182 |
+
if self.position_buckets > 0:
|
| 183 |
+
self.pos_ebd_size = self.position_buckets
|
| 184 |
+
|
| 185 |
+
self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 186 |
+
|
| 187 |
+
if not self.share_att_key:
|
| 188 |
+
if "c2p" in self.pos_att_type:
|
| 189 |
+
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
| 190 |
+
if "p2c" in self.pos_att_type:
|
| 191 |
+
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
|
| 192 |
+
|
| 193 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 194 |
+
|
| 195 |
+
def transpose_for_scores(self, x, attention_heads) -> torch.Tensor:
|
| 196 |
+
new_x_shape = x.size()[:-1] + (attention_heads, -1)
|
| 197 |
+
x = x.view(new_x_shape)
|
| 198 |
+
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
|
| 199 |
+
|
| 200 |
+
def forward(
|
| 201 |
+
self,
|
| 202 |
+
hidden_states,
|
| 203 |
+
attention_mask,
|
| 204 |
+
output_attentions=False,
|
| 205 |
+
query_states=None,
|
| 206 |
+
relative_pos=None,
|
| 207 |
+
rel_embeddings=None,
|
| 208 |
+
):
|
| 209 |
+
"""
|
| 210 |
+
Call the module
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
hidden_states (`torch.FloatTensor`):
|
| 214 |
+
Input states to the module usually the output from previous layer, it will be the Q,K and V in
|
| 215 |
+
*Attention(Q,K,V)*
|
| 216 |
+
|
| 217 |
+
attention_mask (`torch.BoolTensor`):
|
| 218 |
+
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
|
| 219 |
+
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
|
| 220 |
+
th token.
|
| 221 |
+
|
| 222 |
+
output_attentions (`bool`, *optional*):
|
| 223 |
+
Whether return the attention matrix.
|
| 224 |
+
|
| 225 |
+
query_states (`torch.FloatTensor`, *optional*):
|
| 226 |
+
The *Q* state in *Attention(Q,K,V)*.
|
| 227 |
+
|
| 228 |
+
relative_pos (`torch.LongTensor`):
|
| 229 |
+
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
|
| 230 |
+
values ranging in [*-max_relative_positions*, *max_relative_positions*].
|
| 231 |
+
|
| 232 |
+
rel_embeddings (`torch.FloatTensor`):
|
| 233 |
+
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
|
| 234 |
+
\\text{max_relative_positions}\\), *hidden_size*].
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
"""
|
| 238 |
+
if query_states is None:
|
| 239 |
+
query_states = hidden_states
|
| 240 |
+
query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
|
| 241 |
+
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
|
| 242 |
+
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
|
| 243 |
+
|
| 244 |
+
rel_att = None
|
| 245 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 246 |
+
scale_factor = 1
|
| 247 |
+
if "c2p" in self.pos_att_type:
|
| 248 |
+
scale_factor += 1
|
| 249 |
+
if "p2c" in self.pos_att_type:
|
| 250 |
+
scale_factor += 1
|
| 251 |
+
scale = scaled_size_sqrt(query_layer, scale_factor)
|
| 252 |
+
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
|
| 253 |
+
if self.relative_attention:
|
| 254 |
+
rel_embeddings = self.pos_dropout(rel_embeddings)
|
| 255 |
+
rel_att = self.disentangled_attention_bias(
|
| 256 |
+
query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if rel_att is not None:
|
| 260 |
+
attention_scores = attention_scores + rel_att
|
| 261 |
+
attention_scores = attention_scores
|
| 262 |
+
attention_scores = attention_scores.view(
|
| 263 |
+
-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
attention_mask = attention_mask.bool()
|
| 267 |
+
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
|
| 268 |
+
# bsz x height x length x dimension
|
| 269 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 270 |
+
|
| 271 |
+
attention_probs = self.dropout(attention_probs)
|
| 272 |
+
context_layer = torch.bmm(
|
| 273 |
+
attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
|
| 274 |
+
)
|
| 275 |
+
context_layer = (
|
| 276 |
+
context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
|
| 277 |
+
.permute(0, 2, 1, 3)
|
| 278 |
+
.contiguous()
|
| 279 |
+
)
|
| 280 |
+
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
|
| 281 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
| 282 |
+
if not output_attentions:
|
| 283 |
+
return (context_layer, None)
|
| 284 |
+
return (context_layer, attention_probs)
|
| 285 |
+
|
| 286 |
+
def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
|
| 287 |
+
if relative_pos is None:
|
| 288 |
+
relative_pos = build_relative_position(
|
| 289 |
+
query_layer,
|
| 290 |
+
key_layer,
|
| 291 |
+
bucket_size=self.position_buckets,
|
| 292 |
+
max_position=self.max_relative_positions,
|
| 293 |
+
)
|
| 294 |
+
if relative_pos.dim() == 2:
|
| 295 |
+
relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
|
| 296 |
+
elif relative_pos.dim() == 3:
|
| 297 |
+
relative_pos = relative_pos.unsqueeze(1)
|
| 298 |
+
# bsz x height x query x key
|
| 299 |
+
elif relative_pos.dim() != 4:
|
| 300 |
+
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
|
| 301 |
+
|
| 302 |
+
att_span = self.pos_ebd_size
|
| 303 |
+
relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)
|
| 304 |
+
|
| 305 |
+
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
|
| 306 |
+
if self.share_att_key:
|
| 307 |
+
pos_query_layer = self.transpose_for_scores(
|
| 308 |
+
self.query_proj(rel_embeddings), self.num_attention_heads
|
| 309 |
+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
|
| 310 |
+
pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
|
| 311 |
+
query_layer.size(0) // self.num_attention_heads, 1, 1
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
if "c2p" in self.pos_att_type:
|
| 315 |
+
pos_key_layer = self.transpose_for_scores(
|
| 316 |
+
self.pos_key_proj(rel_embeddings), self.num_attention_heads
|
| 317 |
+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
|
| 318 |
+
if "p2c" in self.pos_att_type:
|
| 319 |
+
pos_query_layer = self.transpose_for_scores(
|
| 320 |
+
self.pos_query_proj(rel_embeddings), self.num_attention_heads
|
| 321 |
+
).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) # .split(self.all_head_size, dim=-1)
|
| 322 |
+
|
| 323 |
+
score = 0
|
| 324 |
+
# content->position
|
| 325 |
+
if "c2p" in self.pos_att_type:
|
| 326 |
+
scale = scaled_size_sqrt(pos_key_layer, scale_factor)
|
| 327 |
+
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
|
| 328 |
+
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
|
| 329 |
+
c2p_att = torch.gather(
|
| 330 |
+
c2p_att,
|
| 331 |
+
dim=-1,
|
| 332 |
+
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
|
| 333 |
+
)
|
| 334 |
+
score += c2p_att / scale.to(dtype=c2p_att.dtype)
|
| 335 |
+
|
| 336 |
+
# position->content
|
| 337 |
+
if "p2c" in self.pos_att_type:
|
| 338 |
+
scale = scaled_size_sqrt(pos_query_layer, scale_factor)
|
| 339 |
+
r_pos = build_rpos(
|
| 340 |
+
query_layer,
|
| 341 |
+
key_layer,
|
| 342 |
+
relative_pos,
|
| 343 |
+
self.max_relative_positions,
|
| 344 |
+
self.position_buckets,
|
| 345 |
+
)
|
| 346 |
+
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
|
| 347 |
+
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
|
| 348 |
+
p2c_att = torch.gather(
|
| 349 |
+
p2c_att,
|
| 350 |
+
dim=-1,
|
| 351 |
+
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
|
| 352 |
+
).transpose(-1, -2)
|
| 353 |
+
score += p2c_att / scale.to(dtype=p2c_att.dtype)
|
| 354 |
+
|
| 355 |
+
return score
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
|
| 359 |
+
class DebertaV2Attention(nn.Module):
|
| 360 |
+
def __init__(self, config):
|
| 361 |
+
super().__init__()
|
| 362 |
+
self.self = DisentangledSelfAttention(config)
|
| 363 |
+
self.output = DebertaV2SelfOutput(config)
|
| 364 |
+
self.config = config
|
| 365 |
+
|
| 366 |
+
def forward(
|
| 367 |
+
self,
|
| 368 |
+
hidden_states,
|
| 369 |
+
attention_mask,
|
| 370 |
+
output_attentions: bool = False,
|
| 371 |
+
query_states=None,
|
| 372 |
+
relative_pos=None,
|
| 373 |
+
rel_embeddings=None,
|
| 374 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 375 |
+
self_output, att_matrix = self.self(
|
| 376 |
+
hidden_states,
|
| 377 |
+
attention_mask,
|
| 378 |
+
output_attentions,
|
| 379 |
+
query_states=query_states,
|
| 380 |
+
relative_pos=relative_pos,
|
| 381 |
+
rel_embeddings=rel_embeddings,
|
| 382 |
+
)
|
| 383 |
+
if query_states is None:
|
| 384 |
+
query_states = hidden_states
|
| 385 |
+
attention_output = self.output(self_output, query_states)
|
| 386 |
+
|
| 387 |
+
if output_attentions:
|
| 388 |
+
return (attention_output, att_matrix)
|
| 389 |
+
else:
|
| 390 |
+
return (attention_output, None)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
|
| 394 |
+
class DebertaV2Intermediate(nn.Module):
|
| 395 |
+
def __init__(self, config):
|
| 396 |
+
super().__init__()
|
| 397 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 398 |
+
if isinstance(config.hidden_act, str):
|
| 399 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 400 |
+
else:
|
| 401 |
+
self.intermediate_act_fn = config.hidden_act
|
| 402 |
+
|
| 403 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 404 |
+
hidden_states = self.dense(hidden_states)
|
| 405 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 406 |
+
return hidden_states
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
|
| 410 |
+
class DebertaV2Output(nn.Module):
|
| 411 |
+
def __init__(self, config):
|
| 412 |
+
super().__init__()
|
| 413 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 414 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 415 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 416 |
+
self.config = config
|
| 417 |
+
|
| 418 |
+
def forward(self, hidden_states, input_tensor):
|
| 419 |
+
hidden_states = self.dense(hidden_states)
|
| 420 |
+
hidden_states = self.dropout(hidden_states)
|
| 421 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 422 |
+
return hidden_states
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
|
| 426 |
+
class DebertaV2Layer(nn.Module):
|
| 427 |
+
def __init__(self, config):
|
| 428 |
+
super().__init__()
|
| 429 |
+
self.attention = DebertaV2Attention(config)
|
| 430 |
+
self.intermediate = DebertaV2Intermediate(config)
|
| 431 |
+
self.output = DebertaV2Output(config)
|
| 432 |
+
|
| 433 |
+
def forward(
|
| 434 |
+
self,
|
| 435 |
+
hidden_states,
|
| 436 |
+
attention_mask,
|
| 437 |
+
query_states=None,
|
| 438 |
+
relative_pos=None,
|
| 439 |
+
rel_embeddings=None,
|
| 440 |
+
output_attentions: bool = False,
|
| 441 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 442 |
+
attention_output, att_matrix = self.attention(
|
| 443 |
+
hidden_states,
|
| 444 |
+
attention_mask,
|
| 445 |
+
output_attentions=output_attentions,
|
| 446 |
+
query_states=query_states,
|
| 447 |
+
relative_pos=relative_pos,
|
| 448 |
+
rel_embeddings=rel_embeddings,
|
| 449 |
+
)
|
| 450 |
+
intermediate_output = self.intermediate(attention_output)
|
| 451 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 452 |
+
|
| 453 |
+
if output_attentions:
|
| 454 |
+
return (layer_output, att_matrix)
|
| 455 |
+
else:
|
| 456 |
+
return (layer_output, None)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class ConvLayer(nn.Module):
|
| 460 |
+
def __init__(self, config):
|
| 461 |
+
super().__init__()
|
| 462 |
+
kernel_size = getattr(config, "conv_kernel_size", 3)
|
| 463 |
+
groups = getattr(config, "conv_groups", 1)
|
| 464 |
+
self.conv_act = getattr(config, "conv_act", "tanh")
|
| 465 |
+
self.conv = nn.Conv1d(
|
| 466 |
+
config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
|
| 467 |
+
)
|
| 468 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 469 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 470 |
+
self.config = config
|
| 471 |
+
|
| 472 |
+
def forward(self, hidden_states, residual_states, input_mask):
|
| 473 |
+
out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
|
| 474 |
+
rmask = (1 - input_mask).bool()
|
| 475 |
+
out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
|
| 476 |
+
out = ACT2FN[self.conv_act](self.dropout(out))
|
| 477 |
+
|
| 478 |
+
layer_norm_input = residual_states + out
|
| 479 |
+
output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
|
| 480 |
+
|
| 481 |
+
if input_mask is None:
|
| 482 |
+
output_states = output
|
| 483 |
+
else:
|
| 484 |
+
if input_mask.dim() != layer_norm_input.dim():
|
| 485 |
+
if input_mask.dim() == 4:
|
| 486 |
+
input_mask = input_mask.squeeze(1).squeeze(1)
|
| 487 |
+
input_mask = input_mask.unsqueeze(2)
|
| 488 |
+
|
| 489 |
+
input_mask = input_mask.to(output.dtype)
|
| 490 |
+
output_states = output * input_mask
|
| 491 |
+
|
| 492 |
+
return output_states
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm,Deberta->DebertaV2
|
| 496 |
+
class DebertaV2Embeddings(nn.Module):
|
| 497 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 498 |
+
|
| 499 |
+
def __init__(self, config):
|
| 500 |
+
super().__init__()
|
| 501 |
+
pad_token_id = getattr(config, "pad_token_id", 0)
|
| 502 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 503 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
|
| 504 |
+
|
| 505 |
+
self.position_biased_input = getattr(config, "position_biased_input", True)
|
| 506 |
+
if not self.position_biased_input:
|
| 507 |
+
self.position_embeddings = None
|
| 508 |
+
else:
|
| 509 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
|
| 510 |
+
|
| 511 |
+
if config.type_vocab_size > 0:
|
| 512 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
|
| 513 |
+
else:
|
| 514 |
+
self.token_type_embeddings = None
|
| 515 |
+
|
| 516 |
+
if self.embedding_size != config.hidden_size:
|
| 517 |
+
self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
|
| 518 |
+
else:
|
| 519 |
+
self.embed_proj = None
|
| 520 |
+
|
| 521 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
| 522 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 523 |
+
self.config = config
|
| 524 |
+
|
| 525 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 526 |
+
self.register_buffer(
|
| 527 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
|
| 531 |
+
if input_ids is not None:
|
| 532 |
+
input_shape = input_ids.size()
|
| 533 |
+
else:
|
| 534 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 535 |
+
|
| 536 |
+
seq_length = input_shape[1]
|
| 537 |
+
|
| 538 |
+
if position_ids is None:
|
| 539 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 540 |
+
|
| 541 |
+
if token_type_ids is None:
|
| 542 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 543 |
+
|
| 544 |
+
if inputs_embeds is None:
|
| 545 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 546 |
+
|
| 547 |
+
if self.position_embeddings is not None:
|
| 548 |
+
position_embeddings = self.position_embeddings(position_ids.long())
|
| 549 |
+
else:
|
| 550 |
+
position_embeddings = torch.zeros_like(inputs_embeds)
|
| 551 |
+
|
| 552 |
+
embeddings = inputs_embeds
|
| 553 |
+
if self.position_biased_input:
|
| 554 |
+
embeddings = embeddings + position_embeddings
|
| 555 |
+
if self.token_type_embeddings is not None:
|
| 556 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 557 |
+
embeddings = embeddings + token_type_embeddings
|
| 558 |
+
|
| 559 |
+
if self.embed_proj is not None:
|
| 560 |
+
embeddings = self.embed_proj(embeddings)
|
| 561 |
+
|
| 562 |
+
embeddings = self.LayerNorm(embeddings)
|
| 563 |
+
|
| 564 |
+
if mask is not None:
|
| 565 |
+
if mask.dim() != embeddings.dim():
|
| 566 |
+
if mask.dim() == 4:
|
| 567 |
+
mask = mask.squeeze(1).squeeze(1)
|
| 568 |
+
mask = mask.unsqueeze(2)
|
| 569 |
+
mask = mask.to(embeddings.dtype)
|
| 570 |
+
|
| 571 |
+
embeddings = embeddings * mask
|
| 572 |
+
|
| 573 |
+
embeddings = self.dropout(embeddings)
|
| 574 |
+
return embeddings
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class DebertaV2Encoder(nn.Module):
|
| 578 |
+
"""Modified BertEncoder with relative position bias support"""
|
| 579 |
+
|
| 580 |
+
def __init__(self, config):
|
| 581 |
+
super().__init__()
|
| 582 |
+
|
| 583 |
+
self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
|
| 584 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 585 |
+
|
| 586 |
+
if self.relative_attention:
|
| 587 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 588 |
+
if self.max_relative_positions < 1:
|
| 589 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 590 |
+
|
| 591 |
+
self.position_buckets = getattr(config, "position_buckets", -1)
|
| 592 |
+
pos_ebd_size = self.max_relative_positions * 2
|
| 593 |
+
|
| 594 |
+
if self.position_buckets > 0:
|
| 595 |
+
pos_ebd_size = self.position_buckets * 2
|
| 596 |
+
|
| 597 |
+
self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
|
| 598 |
+
|
| 599 |
+
self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
|
| 600 |
+
|
| 601 |
+
if "layer_norm" in self.norm_rel_ebd:
|
| 602 |
+
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
|
| 603 |
+
|
| 604 |
+
self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
|
| 605 |
+
self.gradient_checkpointing = False
|
| 606 |
+
|
| 607 |
+
def get_rel_embedding(self):
|
| 608 |
+
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
|
| 609 |
+
if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
|
| 610 |
+
rel_embeddings = self.LayerNorm(rel_embeddings)
|
| 611 |
+
return rel_embeddings
|
| 612 |
+
|
| 613 |
+
def get_attention_mask(self, attention_mask):
|
| 614 |
+
if attention_mask.dim() <= 2:
|
| 615 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 616 |
+
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
|
| 617 |
+
elif attention_mask.dim() == 3:
|
| 618 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 619 |
+
|
| 620 |
+
return attention_mask
|
| 621 |
+
|
| 622 |
+
def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
|
| 623 |
+
if self.relative_attention and relative_pos is None:
|
| 624 |
+
if query_states is not None:
|
| 625 |
+
relative_pos = build_relative_position(
|
| 626 |
+
query_states,
|
| 627 |
+
hidden_states,
|
| 628 |
+
bucket_size=self.position_buckets,
|
| 629 |
+
max_position=self.max_relative_positions,
|
| 630 |
+
)
|
| 631 |
+
else:
|
| 632 |
+
relative_pos = build_relative_position(
|
| 633 |
+
hidden_states,
|
| 634 |
+
hidden_states,
|
| 635 |
+
bucket_size=self.position_buckets,
|
| 636 |
+
max_position=self.max_relative_positions,
|
| 637 |
+
)
|
| 638 |
+
return relative_pos
|
| 639 |
+
|
| 640 |
+
def forward(
|
| 641 |
+
self,
|
| 642 |
+
hidden_states,
|
| 643 |
+
attention_mask,
|
| 644 |
+
output_hidden_states=True,
|
| 645 |
+
output_attentions=False,
|
| 646 |
+
query_states=None,
|
| 647 |
+
relative_pos=None,
|
| 648 |
+
return_dict=True,
|
| 649 |
+
):
|
| 650 |
+
if attention_mask.dim() <= 2:
|
| 651 |
+
input_mask = attention_mask
|
| 652 |
+
else:
|
| 653 |
+
input_mask = attention_mask.sum(-2) > 0
|
| 654 |
+
attention_mask = self.get_attention_mask(attention_mask)
|
| 655 |
+
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
|
| 656 |
+
|
| 657 |
+
all_hidden_states: Optional[Tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None
|
| 658 |
+
all_attentions = () if output_attentions else None
|
| 659 |
+
|
| 660 |
+
next_kv = hidden_states
|
| 661 |
+
rel_embeddings = self.get_rel_embedding()
|
| 662 |
+
for i, layer_module in enumerate(self.layer):
|
| 663 |
+
if self.gradient_checkpointing and self.training:
|
| 664 |
+
output_states, attn_weights = self._gradient_checkpointing_func(
|
| 665 |
+
layer_module.__call__,
|
| 666 |
+
next_kv,
|
| 667 |
+
attention_mask,
|
| 668 |
+
query_states,
|
| 669 |
+
relative_pos,
|
| 670 |
+
rel_embeddings,
|
| 671 |
+
output_attentions,
|
| 672 |
+
)
|
| 673 |
+
else:
|
| 674 |
+
output_states, attn_weights = layer_module(
|
| 675 |
+
next_kv,
|
| 676 |
+
attention_mask,
|
| 677 |
+
query_states=query_states,
|
| 678 |
+
relative_pos=relative_pos,
|
| 679 |
+
rel_embeddings=rel_embeddings,
|
| 680 |
+
output_attentions=output_attentions,
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
if output_attentions:
|
| 684 |
+
all_attentions = all_attentions + (attn_weights,)
|
| 685 |
+
|
| 686 |
+
if i == 0 and self.conv is not None:
|
| 687 |
+
output_states = self.conv(hidden_states, output_states, input_mask)
|
| 688 |
+
|
| 689 |
+
if output_hidden_states:
|
| 690 |
+
all_hidden_states = all_hidden_states + (output_states,)
|
| 691 |
+
|
| 692 |
+
if query_states is not None:
|
| 693 |
+
query_states = output_states
|
| 694 |
+
if isinstance(hidden_states, Sequence):
|
| 695 |
+
next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
|
| 696 |
+
else:
|
| 697 |
+
next_kv = output_states
|
| 698 |
+
|
| 699 |
+
if not return_dict:
|
| 700 |
+
return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
|
| 701 |
+
return BaseModelOutput(
|
| 702 |
+
last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
class DebertaV2PreTrainedModel(PreTrainedModel):
|
| 707 |
+
"""
|
| 708 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 709 |
+
models.
|
| 710 |
+
"""
|
| 711 |
+
|
| 712 |
+
config_class = DebertaV2Config
|
| 713 |
+
base_model_prefix = "deberta"
|
| 714 |
+
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
|
| 715 |
+
supports_gradient_checkpointing = True
|
| 716 |
+
|
| 717 |
+
def _init_weights(self, module):
|
| 718 |
+
"""Initialize the weights."""
|
| 719 |
+
if isinstance(module, nn.Linear):
|
| 720 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 721 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 722 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 723 |
+
if module.bias is not None:
|
| 724 |
+
module.bias.data.zero_()
|
| 725 |
+
elif isinstance(module, nn.Embedding):
|
| 726 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 727 |
+
if module.padding_idx is not None:
|
| 728 |
+
module.weight.data[module.padding_idx].zero_()
|
| 729 |
+
elif isinstance(module, nn.LayerNorm):
|
| 730 |
+
module.weight.data.fill_(1.0)
|
| 731 |
+
module.bias.data.zero_()
|
| 732 |
+
elif isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)):
|
| 733 |
+
module.bias.data.zero_()
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
DEBERTA_START_DOCSTRING = r"""
|
| 737 |
+
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
|
| 738 |
+
Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
|
| 739 |
+
on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
|
| 740 |
+
improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
|
| 741 |
+
|
| 742 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 743 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 744 |
+
and behavior.
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
Parameters:
|
| 748 |
+
config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
|
| 749 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 750 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 751 |
+
"""
|
| 752 |
+
|
| 753 |
+
DEBERTA_INPUTS_DOCSTRING = r"""
|
| 754 |
+
Args:
|
| 755 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 756 |
+
Indices of input sequence tokens in the vocabulary.
|
| 757 |
+
|
| 758 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 759 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 760 |
+
|
| 761 |
+
[What are input IDs?](../glossary#input-ids)
|
| 762 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 763 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 764 |
+
|
| 765 |
+
- 1 for tokens that are **not masked**,
|
| 766 |
+
- 0 for tokens that are **masked**.
|
| 767 |
+
|
| 768 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 769 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 770 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 771 |
+
1]`:
|
| 772 |
+
|
| 773 |
+
- 0 corresponds to a *sentence A* token,
|
| 774 |
+
- 1 corresponds to a *sentence B* token.
|
| 775 |
+
|
| 776 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 777 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 778 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 779 |
+
config.max_position_embeddings - 1]`.
|
| 780 |
+
|
| 781 |
+
[What are position IDs?](../glossary#position-ids)
|
| 782 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 783 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 784 |
+
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
| 785 |
+
model's internal embedding lookup matrix.
|
| 786 |
+
output_attentions (`bool`, *optional*):
|
| 787 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 788 |
+
tensors for more detail.
|
| 789 |
+
output_hidden_states (`bool`, *optional*):
|
| 790 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 791 |
+
more detail.
|
| 792 |
+
return_dict (`bool`, *optional*):
|
| 793 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 794 |
+
"""
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
@add_start_docstrings(
|
| 798 |
+
"The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
| 799 |
+
DEBERTA_START_DOCSTRING,
|
| 800 |
+
)
|
| 801 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
|
| 802 |
+
class DebertaV2Model(DebertaV2PreTrainedModel):
|
| 803 |
+
def __init__(self, config):
|
| 804 |
+
super().__init__(config)
|
| 805 |
+
|
| 806 |
+
self.embeddings = DebertaV2Embeddings(config)
|
| 807 |
+
self.encoder = DebertaV2Encoder(config)
|
| 808 |
+
self.z_steps = 0
|
| 809 |
+
self.config = config
|
| 810 |
+
# Initialize weights and apply final processing
|
| 811 |
+
self.post_init()
|
| 812 |
+
|
| 813 |
+
def get_input_embeddings(self):
|
| 814 |
+
return self.embeddings.word_embeddings
|
| 815 |
+
|
| 816 |
+
def set_input_embeddings(self, new_embeddings):
|
| 817 |
+
self.embeddings.word_embeddings = new_embeddings
|
| 818 |
+
|
| 819 |
+
def _prune_heads(self, heads_to_prune):
|
| 820 |
+
"""
|
| 821 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 822 |
+
class PreTrainedModel
|
| 823 |
+
"""
|
| 824 |
+
raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
|
| 825 |
+
|
| 826 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 827 |
+
@add_code_sample_docstrings(
|
| 828 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 829 |
+
output_type=BaseModelOutput,
|
| 830 |
+
config_class=_CONFIG_FOR_DOC,
|
| 831 |
+
)
|
| 832 |
+
def forward(
|
| 833 |
+
self,
|
| 834 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 835 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 836 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 837 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 838 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 839 |
+
output_attentions: Optional[bool] = None,
|
| 840 |
+
output_hidden_states: Optional[bool] = None,
|
| 841 |
+
return_dict: Optional[bool] = None,
|
| 842 |
+
) -> Union[Tuple, BaseModelOutput]:
|
| 843 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 844 |
+
output_hidden_states = (
|
| 845 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 846 |
+
)
|
| 847 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 848 |
+
|
| 849 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 850 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 851 |
+
elif input_ids is not None:
|
| 852 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 853 |
+
input_shape = input_ids.size()
|
| 854 |
+
elif inputs_embeds is not None:
|
| 855 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 856 |
+
else:
|
| 857 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 858 |
+
|
| 859 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 860 |
+
|
| 861 |
+
if attention_mask is None:
|
| 862 |
+
attention_mask = torch.ones(input_shape, device=device)
|
| 863 |
+
if token_type_ids is None:
|
| 864 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 865 |
+
|
| 866 |
+
embedding_output = self.embeddings(
|
| 867 |
+
input_ids=input_ids,
|
| 868 |
+
token_type_ids=token_type_ids,
|
| 869 |
+
position_ids=position_ids,
|
| 870 |
+
mask=attention_mask,
|
| 871 |
+
inputs_embeds=inputs_embeds,
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
encoder_outputs = self.encoder(
|
| 875 |
+
embedding_output,
|
| 876 |
+
attention_mask,
|
| 877 |
+
output_hidden_states=True,
|
| 878 |
+
output_attentions=output_attentions,
|
| 879 |
+
return_dict=return_dict,
|
| 880 |
+
)
|
| 881 |
+
encoded_layers = encoder_outputs[1]
|
| 882 |
+
|
| 883 |
+
if self.z_steps > 1:
|
| 884 |
+
hidden_states = encoded_layers[-2]
|
| 885 |
+
layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
|
| 886 |
+
query_states = encoded_layers[-1]
|
| 887 |
+
rel_embeddings = self.encoder.get_rel_embedding()
|
| 888 |
+
attention_mask = self.encoder.get_attention_mask(attention_mask)
|
| 889 |
+
rel_pos = self.encoder.get_rel_pos(embedding_output)
|
| 890 |
+
for layer in layers[1:]:
|
| 891 |
+
query_states = layer(
|
| 892 |
+
hidden_states,
|
| 893 |
+
attention_mask,
|
| 894 |
+
output_attentions=False,
|
| 895 |
+
query_states=query_states,
|
| 896 |
+
relative_pos=rel_pos,
|
| 897 |
+
rel_embeddings=rel_embeddings,
|
| 898 |
+
)
|
| 899 |
+
encoded_layers.append(query_states)
|
| 900 |
+
|
| 901 |
+
sequence_output = encoded_layers[-1]
|
| 902 |
+
|
| 903 |
+
if not return_dict:
|
| 904 |
+
return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
|
| 905 |
+
|
| 906 |
+
return BaseModelOutput(
|
| 907 |
+
last_hidden_state=sequence_output,
|
| 908 |
+
hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
|
| 909 |
+
attentions=encoder_outputs.attentions,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
# Copied from transformers.models.deberta.modeling_deberta.LegacyDebertaPredictionHeadTransform with Deberta->DebertaV2
|
| 914 |
+
class LegacyDebertaV2PredictionHeadTransform(nn.Module):
|
| 915 |
+
def __init__(self, config):
|
| 916 |
+
super().__init__()
|
| 917 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 918 |
+
|
| 919 |
+
self.dense = nn.Linear(config.hidden_size, self.embedding_size)
|
| 920 |
+
if isinstance(config.hidden_act, str):
|
| 921 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 922 |
+
else:
|
| 923 |
+
self.transform_act_fn = config.hidden_act
|
| 924 |
+
self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
|
| 925 |
+
|
| 926 |
+
def forward(self, hidden_states):
|
| 927 |
+
hidden_states = self.dense(hidden_states)
|
| 928 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 929 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 930 |
+
return hidden_states
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
class LegacyDebertaV2LMPredictionHead(nn.Module):
|
| 934 |
+
def __init__(self, config):
|
| 935 |
+
super().__init__()
|
| 936 |
+
self.transform = LegacyDebertaV2PredictionHeadTransform(config)
|
| 937 |
+
|
| 938 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 939 |
+
# The output weights are the same as the input embeddings, but there is
|
| 940 |
+
# an output-only bias for each token.
|
| 941 |
+
self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
|
| 942 |
+
|
| 943 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 944 |
+
|
| 945 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 946 |
+
self.decoder.bias = self.bias
|
| 947 |
+
|
| 948 |
+
def _tie_weights(self):
|
| 949 |
+
self.decoder.bias = self.bias
|
| 950 |
+
|
| 951 |
+
def forward(self, hidden_states):
|
| 952 |
+
hidden_states = self.transform(hidden_states)
|
| 953 |
+
hidden_states = self.decoder(hidden_states)
|
| 954 |
+
return hidden_states
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
class LegacyDebertaV2OnlyMLMHead(nn.Module):
|
| 958 |
+
def __init__(self, config):
|
| 959 |
+
super().__init__()
|
| 960 |
+
self.predictions = LegacyDebertaV2LMPredictionHead(config)
|
| 961 |
+
|
| 962 |
+
def forward(self, sequence_output):
|
| 963 |
+
prediction_scores = self.predictions(sequence_output)
|
| 964 |
+
return prediction_scores
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
class DebertaV2LMPredictionHead(nn.Module):
|
| 968 |
+
"""https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270"""
|
| 969 |
+
|
| 970 |
+
def __init__(self, config):
|
| 971 |
+
super().__init__()
|
| 972 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 973 |
+
|
| 974 |
+
if isinstance(config.hidden_act, str):
|
| 975 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 976 |
+
else:
|
| 977 |
+
self.transform_act_fn = config.hidden_act
|
| 978 |
+
|
| 979 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True)
|
| 980 |
+
|
| 981 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 982 |
+
|
| 983 |
+
# note that the input embeddings must be passed as an argument
|
| 984 |
+
def forward(self, hidden_states, word_embeddings):
|
| 985 |
+
hidden_states = self.dense(hidden_states)
|
| 986 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 987 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 988 |
+
hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
|
| 989 |
+
return hidden_states
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
class DebertaV2OnlyMLMHead(nn.Module):
|
| 993 |
+
def __init__(self, config):
|
| 994 |
+
super().__init__()
|
| 995 |
+
self.lm_head = DebertaV2LMPredictionHead(config)
|
| 996 |
+
|
| 997 |
+
# note that the input embeddings must be passed as an argument
|
| 998 |
+
def forward(self, sequence_output, word_embeddings):
|
| 999 |
+
prediction_scores = self.lm_head(sequence_output, word_embeddings)
|
| 1000 |
+
return prediction_scores
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
| 1004 |
+
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
|
| 1005 |
+
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
| 1006 |
+
_keys_to_ignore_on_load_unexpected = r"mask_predictions.*"
|
| 1007 |
+
|
| 1008 |
+
def __init__(self, config):
|
| 1009 |
+
super().__init__(config)
|
| 1010 |
+
self.legacy = config.legacy
|
| 1011 |
+
self.deberta = DebertaV2Model(config)
|
| 1012 |
+
if self.legacy:
|
| 1013 |
+
self.cls = LegacyDebertaV2OnlyMLMHead(config)
|
| 1014 |
+
else:
|
| 1015 |
+
self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"]
|
| 1016 |
+
self.lm_predictions = DebertaV2OnlyMLMHead(config)
|
| 1017 |
+
# Initialize weights and apply final processing
|
| 1018 |
+
self.post_init()
|
| 1019 |
+
|
| 1020 |
+
def get_output_embeddings(self):
|
| 1021 |
+
if self.legacy:
|
| 1022 |
+
return self.cls.predictions.decoder
|
| 1023 |
+
else:
|
| 1024 |
+
return self.lm_predictions.lm_head.dense
|
| 1025 |
+
|
| 1026 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1027 |
+
if self.legacy:
|
| 1028 |
+
self.cls.predictions.decoder = new_embeddings
|
| 1029 |
+
self.cls.predictions.bias = new_embeddings.bias
|
| 1030 |
+
else:
|
| 1031 |
+
self.lm_predictions.lm_head.dense = new_embeddings
|
| 1032 |
+
self.lm_predictions.lm_head.bias = new_embeddings.bias
|
| 1033 |
+
|
| 1034 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1035 |
+
@add_code_sample_docstrings(
|
| 1036 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1037 |
+
output_type=MaskedLMOutput,
|
| 1038 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1039 |
+
mask="[MASK]",
|
| 1040 |
+
)
|
| 1041 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2
|
| 1042 |
+
def forward(
|
| 1043 |
+
self,
|
| 1044 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1045 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1046 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1047 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1048 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1049 |
+
labels: Optional[torch.Tensor] = None,
|
| 1050 |
+
output_attentions: Optional[bool] = None,
|
| 1051 |
+
output_hidden_states: Optional[bool] = None,
|
| 1052 |
+
return_dict: Optional[bool] = None,
|
| 1053 |
+
) -> Union[Tuple, MaskedLMOutput]:
|
| 1054 |
+
r"""
|
| 1055 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1056 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 1057 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 1058 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 1059 |
+
"""
|
| 1060 |
+
|
| 1061 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1062 |
+
|
| 1063 |
+
outputs = self.deberta(
|
| 1064 |
+
input_ids,
|
| 1065 |
+
attention_mask=attention_mask,
|
| 1066 |
+
token_type_ids=token_type_ids,
|
| 1067 |
+
position_ids=position_ids,
|
| 1068 |
+
inputs_embeds=inputs_embeds,
|
| 1069 |
+
output_attentions=output_attentions,
|
| 1070 |
+
output_hidden_states=output_hidden_states,
|
| 1071 |
+
return_dict=return_dict,
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
sequence_output = outputs[0]
|
| 1075 |
+
if self.legacy:
|
| 1076 |
+
prediction_scores = self.cls(sequence_output)
|
| 1077 |
+
else:
|
| 1078 |
+
prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)
|
| 1079 |
+
|
| 1080 |
+
masked_lm_loss = None
|
| 1081 |
+
if labels is not None:
|
| 1082 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
| 1083 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1084 |
+
|
| 1085 |
+
if not return_dict:
|
| 1086 |
+
output = (prediction_scores,) + outputs[1:]
|
| 1087 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 1088 |
+
|
| 1089 |
+
return MaskedLMOutput(
|
| 1090 |
+
loss=masked_lm_loss,
|
| 1091 |
+
logits=prediction_scores,
|
| 1092 |
+
hidden_states=outputs.hidden_states,
|
| 1093 |
+
attentions=outputs.attentions,
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
# Copied from transformers.models.deberta.modeling_deberta.ContextPooler
|
| 1098 |
+
class ContextPooler(nn.Module):
|
| 1099 |
+
def __init__(self, config):
|
| 1100 |
+
super().__init__()
|
| 1101 |
+
self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
|
| 1102 |
+
self.dropout = nn.Dropout(config.pooler_dropout)
|
| 1103 |
+
self.config = config
|
| 1104 |
+
|
| 1105 |
+
def forward(self, hidden_states):
|
| 1106 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 1107 |
+
# to the first token.
|
| 1108 |
+
|
| 1109 |
+
context_token = hidden_states[:, 0]
|
| 1110 |
+
context_token = self.dropout(context_token)
|
| 1111 |
+
pooled_output = self.dense(context_token)
|
| 1112 |
+
pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
|
| 1113 |
+
return pooled_output
|
| 1114 |
+
|
| 1115 |
+
@property
|
| 1116 |
+
def output_dim(self):
|
| 1117 |
+
return self.config.hidden_size
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
@add_start_docstrings(
|
| 1121 |
+
"""
|
| 1122 |
+
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
| 1123 |
+
pooled output) e.g. for GLUE tasks.
|
| 1124 |
+
""",
|
| 1125 |
+
DEBERTA_START_DOCSTRING,
|
| 1126 |
+
)
|
| 1127 |
+
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
|
| 1128 |
+
def __init__(self, config):
|
| 1129 |
+
super().__init__(config)
|
| 1130 |
+
|
| 1131 |
+
num_labels = getattr(config, "num_labels", 2)
|
| 1132 |
+
self.num_labels = num_labels
|
| 1133 |
+
|
| 1134 |
+
self.deberta = DebertaV2Model(config)
|
| 1135 |
+
self.pooler = ContextPooler(config)
|
| 1136 |
+
output_dim = self.pooler.output_dim
|
| 1137 |
+
|
| 1138 |
+
self.classifier = nn.Linear(output_dim, num_labels)
|
| 1139 |
+
drop_out = getattr(config, "cls_dropout", None)
|
| 1140 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
| 1141 |
+
self.dropout = nn.Dropout(drop_out)
|
| 1142 |
+
|
| 1143 |
+
# Initialize weights and apply final processing
|
| 1144 |
+
self.post_init()
|
| 1145 |
+
|
| 1146 |
+
def get_input_embeddings(self):
|
| 1147 |
+
return self.deberta.get_input_embeddings()
|
| 1148 |
+
|
| 1149 |
+
def set_input_embeddings(self, new_embeddings):
|
| 1150 |
+
self.deberta.set_input_embeddings(new_embeddings)
|
| 1151 |
+
|
| 1152 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1153 |
+
@add_code_sample_docstrings(
|
| 1154 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1155 |
+
output_type=SequenceClassifierOutput,
|
| 1156 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1157 |
+
)
|
| 1158 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2
|
| 1159 |
+
def forward(
|
| 1160 |
+
self,
|
| 1161 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1162 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1163 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1164 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1165 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1166 |
+
labels: Optional[torch.Tensor] = None,
|
| 1167 |
+
output_attentions: Optional[bool] = None,
|
| 1168 |
+
output_hidden_states: Optional[bool] = None,
|
| 1169 |
+
return_dict: Optional[bool] = None,
|
| 1170 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1171 |
+
r"""
|
| 1172 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1173 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1174 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1175 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1176 |
+
"""
|
| 1177 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1178 |
+
|
| 1179 |
+
outputs = self.deberta(
|
| 1180 |
+
input_ids,
|
| 1181 |
+
token_type_ids=token_type_ids,
|
| 1182 |
+
attention_mask=attention_mask,
|
| 1183 |
+
position_ids=position_ids,
|
| 1184 |
+
inputs_embeds=inputs_embeds,
|
| 1185 |
+
output_attentions=output_attentions,
|
| 1186 |
+
output_hidden_states=output_hidden_states,
|
| 1187 |
+
return_dict=return_dict,
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
encoder_layer = outputs[0]
|
| 1191 |
+
pooled_output = self.pooler(encoder_layer)
|
| 1192 |
+
pooled_output = self.dropout(pooled_output)
|
| 1193 |
+
logits = self.classifier(pooled_output)
|
| 1194 |
+
|
| 1195 |
+
loss = None
|
| 1196 |
+
if labels is not None:
|
| 1197 |
+
if self.config.problem_type is None:
|
| 1198 |
+
if self.num_labels == 1:
|
| 1199 |
+
# regression task
|
| 1200 |
+
loss_fn = nn.MSELoss()
|
| 1201 |
+
logits = logits.view(-1).to(labels.dtype)
|
| 1202 |
+
loss = loss_fn(logits, labels.view(-1))
|
| 1203 |
+
elif labels.dim() == 1 or labels.size(-1) == 1:
|
| 1204 |
+
label_index = (labels >= 0).nonzero()
|
| 1205 |
+
labels = labels.long()
|
| 1206 |
+
if label_index.size(0) > 0:
|
| 1207 |
+
labeled_logits = torch.gather(
|
| 1208 |
+
logits, 0, label_index.expand(label_index.size(0), logits.size(1))
|
| 1209 |
+
)
|
| 1210 |
+
labels = torch.gather(labels, 0, label_index.view(-1))
|
| 1211 |
+
loss_fct = CrossEntropyLoss()
|
| 1212 |
+
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
|
| 1213 |
+
else:
|
| 1214 |
+
loss = torch.tensor(0).to(logits)
|
| 1215 |
+
else:
|
| 1216 |
+
log_softmax = nn.LogSoftmax(-1)
|
| 1217 |
+
loss = -((log_softmax(logits) * labels).sum(-1)).mean()
|
| 1218 |
+
elif self.config.problem_type == "regression":
|
| 1219 |
+
loss_fct = MSELoss()
|
| 1220 |
+
if self.num_labels == 1:
|
| 1221 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 1222 |
+
else:
|
| 1223 |
+
loss = loss_fct(logits, labels)
|
| 1224 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1225 |
+
loss_fct = CrossEntropyLoss()
|
| 1226 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1227 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1228 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1229 |
+
loss = loss_fct(logits, labels)
|
| 1230 |
+
if not return_dict:
|
| 1231 |
+
output = (logits,) + outputs[1:]
|
| 1232 |
+
return ((loss,) + output) if loss is not None else output
|
| 1233 |
+
|
| 1234 |
+
return SequenceClassifierOutput(
|
| 1235 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1236 |
+
)
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
@add_start_docstrings(
|
| 1240 |
+
"""
|
| 1241 |
+
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1242 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1243 |
+
""",
|
| 1244 |
+
DEBERTA_START_DOCSTRING,
|
| 1245 |
+
)
|
| 1246 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
|
| 1247 |
+
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
|
| 1248 |
+
def __init__(self, config):
|
| 1249 |
+
super().__init__(config)
|
| 1250 |
+
self.num_labels = config.num_labels
|
| 1251 |
+
|
| 1252 |
+
self.deberta = DebertaV2Model(config)
|
| 1253 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1254 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1255 |
+
|
| 1256 |
+
# Initialize weights and apply final processing
|
| 1257 |
+
self.post_init()
|
| 1258 |
+
|
| 1259 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1260 |
+
@add_code_sample_docstrings(
|
| 1261 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1262 |
+
output_type=TokenClassifierOutput,
|
| 1263 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1264 |
+
)
|
| 1265 |
+
def forward(
|
| 1266 |
+
self,
|
| 1267 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1268 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1269 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1270 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1271 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1272 |
+
labels: Optional[torch.Tensor] = None,
|
| 1273 |
+
output_attentions: Optional[bool] = None,
|
| 1274 |
+
output_hidden_states: Optional[bool] = None,
|
| 1275 |
+
return_dict: Optional[bool] = None,
|
| 1276 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1277 |
+
r"""
|
| 1278 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1279 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1280 |
+
"""
|
| 1281 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1282 |
+
|
| 1283 |
+
outputs = self.deberta(
|
| 1284 |
+
input_ids,
|
| 1285 |
+
attention_mask=attention_mask,
|
| 1286 |
+
token_type_ids=token_type_ids,
|
| 1287 |
+
position_ids=position_ids,
|
| 1288 |
+
inputs_embeds=inputs_embeds,
|
| 1289 |
+
output_attentions=output_attentions,
|
| 1290 |
+
output_hidden_states=output_hidden_states,
|
| 1291 |
+
return_dict=return_dict,
|
| 1292 |
+
)
|
| 1293 |
+
|
| 1294 |
+
sequence_output = outputs[0]
|
| 1295 |
+
|
| 1296 |
+
sequence_output = self.dropout(sequence_output)
|
| 1297 |
+
logits = self.classifier(sequence_output)
|
| 1298 |
+
|
| 1299 |
+
loss = None
|
| 1300 |
+
if labels is not None:
|
| 1301 |
+
loss_fct = CrossEntropyLoss()
|
| 1302 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1303 |
+
|
| 1304 |
+
if not return_dict:
|
| 1305 |
+
output = (logits,) + outputs[1:]
|
| 1306 |
+
return ((loss,) + output) if loss is not None else output
|
| 1307 |
+
|
| 1308 |
+
return TokenClassifierOutput(
|
| 1309 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1310 |
+
)
|
| 1311 |
+
|
| 1312 |
+
|
| 1313 |
+
@add_start_docstrings(
|
| 1314 |
+
"""
|
| 1315 |
+
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1316 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1317 |
+
""",
|
| 1318 |
+
DEBERTA_START_DOCSTRING,
|
| 1319 |
+
)
|
| 1320 |
+
class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
|
| 1321 |
+
def __init__(self, config):
|
| 1322 |
+
super().__init__(config)
|
| 1323 |
+
self.num_labels = config.num_labels
|
| 1324 |
+
|
| 1325 |
+
self.deberta = DebertaV2Model(config)
|
| 1326 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 1327 |
+
|
| 1328 |
+
# Initialize weights and apply final processing
|
| 1329 |
+
self.post_init()
|
| 1330 |
+
|
| 1331 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1332 |
+
@add_code_sample_docstrings(
|
| 1333 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1334 |
+
output_type=QuestionAnsweringModelOutput,
|
| 1335 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1336 |
+
qa_target_start_index=_QA_TARGET_START_INDEX,
|
| 1337 |
+
qa_target_end_index=_QA_TARGET_END_INDEX,
|
| 1338 |
+
)
|
| 1339 |
+
# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering.forward with Deberta->DebertaV2
|
| 1340 |
+
def forward(
|
| 1341 |
+
self,
|
| 1342 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1343 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1344 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1345 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1346 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1347 |
+
start_positions: Optional[torch.Tensor] = None,
|
| 1348 |
+
end_positions: Optional[torch.Tensor] = None,
|
| 1349 |
+
output_attentions: Optional[bool] = None,
|
| 1350 |
+
output_hidden_states: Optional[bool] = None,
|
| 1351 |
+
return_dict: Optional[bool] = None,
|
| 1352 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
| 1353 |
+
r"""
|
| 1354 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1355 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1356 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1357 |
+
are not taken into account for computing the loss.
|
| 1358 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1359 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1360 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1361 |
+
are not taken into account for computing the loss.
|
| 1362 |
+
"""
|
| 1363 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1364 |
+
|
| 1365 |
+
outputs = self.deberta(
|
| 1366 |
+
input_ids,
|
| 1367 |
+
attention_mask=attention_mask,
|
| 1368 |
+
token_type_ids=token_type_ids,
|
| 1369 |
+
position_ids=position_ids,
|
| 1370 |
+
inputs_embeds=inputs_embeds,
|
| 1371 |
+
output_attentions=output_attentions,
|
| 1372 |
+
output_hidden_states=output_hidden_states,
|
| 1373 |
+
return_dict=return_dict,
|
| 1374 |
+
)
|
| 1375 |
+
|
| 1376 |
+
sequence_output = outputs[0]
|
| 1377 |
+
|
| 1378 |
+
logits = self.qa_outputs(sequence_output)
|
| 1379 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1380 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1381 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1382 |
+
|
| 1383 |
+
total_loss = None
|
| 1384 |
+
if start_positions is not None and end_positions is not None:
|
| 1385 |
+
# If we are on multi-GPU, split add a dimension
|
| 1386 |
+
if len(start_positions.size()) > 1:
|
| 1387 |
+
start_positions = start_positions.squeeze(-1)
|
| 1388 |
+
if len(end_positions.size()) > 1:
|
| 1389 |
+
end_positions = end_positions.squeeze(-1)
|
| 1390 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1391 |
+
ignored_index = start_logits.size(1)
|
| 1392 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1393 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1394 |
+
|
| 1395 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1396 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1397 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1398 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1399 |
+
|
| 1400 |
+
if not return_dict:
|
| 1401 |
+
output = (start_logits, end_logits) + outputs[1:]
|
| 1402 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1403 |
+
|
| 1404 |
+
return QuestionAnsweringModelOutput(
|
| 1405 |
+
loss=total_loss,
|
| 1406 |
+
start_logits=start_logits,
|
| 1407 |
+
end_logits=end_logits,
|
| 1408 |
+
hidden_states=outputs.hidden_states,
|
| 1409 |
+
attentions=outputs.attentions,
|
| 1410 |
+
)
|
| 1411 |
+
|
| 1412 |
+
|
| 1413 |
+
@add_start_docstrings(
|
| 1414 |
+
"""
|
| 1415 |
+
DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
| 1416 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
| 1417 |
+
""",
|
| 1418 |
+
DEBERTA_START_DOCSTRING,
|
| 1419 |
+
)
|
| 1420 |
+
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
|
| 1421 |
+
def __init__(self, config):
|
| 1422 |
+
super().__init__(config)
|
| 1423 |
+
|
| 1424 |
+
num_labels = getattr(config, "num_labels", 2)
|
| 1425 |
+
self.num_labels = num_labels
|
| 1426 |
+
|
| 1427 |
+
self.deberta = DebertaV2Model(config)
|
| 1428 |
+
self.pooler = ContextPooler(config)
|
| 1429 |
+
output_dim = self.pooler.output_dim
|
| 1430 |
+
|
| 1431 |
+
self.classifier = nn.Linear(output_dim, 1)
|
| 1432 |
+
drop_out = getattr(config, "cls_dropout", None)
|
| 1433 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
| 1434 |
+
self.dropout = nn.Dropout(drop_out)
|
| 1435 |
+
|
| 1436 |
+
self.init_weights()
|
| 1437 |
+
|
| 1438 |
+
def get_input_embeddings(self):
|
| 1439 |
+
return self.deberta.get_input_embeddings()
|
| 1440 |
+
|
| 1441 |
+
def set_input_embeddings(self, new_embeddings):
|
| 1442 |
+
self.deberta.set_input_embeddings(new_embeddings)
|
| 1443 |
+
|
| 1444 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1445 |
+
@add_code_sample_docstrings(
|
| 1446 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1447 |
+
output_type=MultipleChoiceModelOutput,
|
| 1448 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1449 |
+
)
|
| 1450 |
+
def forward(
|
| 1451 |
+
self,
|
| 1452 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1453 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1454 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 1455 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1456 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1457 |
+
labels: Optional[torch.Tensor] = None,
|
| 1458 |
+
output_attentions: Optional[bool] = None,
|
| 1459 |
+
output_hidden_states: Optional[bool] = None,
|
| 1460 |
+
return_dict: Optional[bool] = None,
|
| 1461 |
+
) -> Union[Tuple, MultipleChoiceModelOutput]:
|
| 1462 |
+
r"""
|
| 1463 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1464 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
| 1465 |
+
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
| 1466 |
+
`input_ids` above)
|
| 1467 |
+
"""
|
| 1468 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1469 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
| 1470 |
+
|
| 1471 |
+
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
| 1472 |
+
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
| 1473 |
+
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
| 1474 |
+
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 1475 |
+
flat_inputs_embeds = (
|
| 1476 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
| 1477 |
+
if inputs_embeds is not None
|
| 1478 |
+
else None
|
| 1479 |
+
)
|
| 1480 |
+
|
| 1481 |
+
outputs = self.deberta(
|
| 1482 |
+
flat_input_ids,
|
| 1483 |
+
position_ids=flat_position_ids,
|
| 1484 |
+
token_type_ids=flat_token_type_ids,
|
| 1485 |
+
attention_mask=flat_attention_mask,
|
| 1486 |
+
inputs_embeds=flat_inputs_embeds,
|
| 1487 |
+
output_attentions=output_attentions,
|
| 1488 |
+
output_hidden_states=output_hidden_states,
|
| 1489 |
+
return_dict=return_dict,
|
| 1490 |
+
)
|
| 1491 |
+
|
| 1492 |
+
encoder_layer = outputs[0]
|
| 1493 |
+
pooled_output = self.pooler(encoder_layer)
|
| 1494 |
+
pooled_output = self.dropout(pooled_output)
|
| 1495 |
+
logits = self.classifier(pooled_output)
|
| 1496 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 1497 |
+
|
| 1498 |
+
loss = None
|
| 1499 |
+
if labels is not None:
|
| 1500 |
+
loss_fct = CrossEntropyLoss()
|
| 1501 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 1502 |
+
|
| 1503 |
+
if not return_dict:
|
| 1504 |
+
output = (reshaped_logits,) + outputs[1:]
|
| 1505 |
+
return ((loss,) + output) if loss is not None else output
|
| 1506 |
+
|
| 1507 |
+
return MultipleChoiceModelOutput(
|
| 1508 |
+
loss=loss,
|
| 1509 |
+
logits=reshaped_logits,
|
| 1510 |
+
hidden_states=outputs.hidden_states,
|
| 1511 |
+
attentions=outputs.attentions,
|
| 1512 |
+
)
|
| 1513 |
+
|
| 1514 |
+
|
| 1515 |
+
__all__ = [
|
| 1516 |
+
"DebertaV2ForMaskedLM",
|
| 1517 |
+
"DebertaV2ForMultipleChoice",
|
| 1518 |
+
"DebertaV2ForQuestionAnswering",
|
| 1519 |
+
"DebertaV2ForSequenceClassification",
|
| 1520 |
+
"DebertaV2ForTokenClassification",
|
| 1521 |
+
"DebertaV2Model",
|
| 1522 |
+
"DebertaV2PreTrainedModel",
|
| 1523 |
+
]
|
docs/transformers/build/lib/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
ADDED
|
@@ -0,0 +1,1881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TF 2.0 DeBERTa-v2 model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Dict, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import tensorflow as tf
|
| 23 |
+
|
| 24 |
+
from ...activations_tf import get_tf_activation
|
| 25 |
+
from ...modeling_tf_outputs import (
|
| 26 |
+
TFBaseModelOutput,
|
| 27 |
+
TFMaskedLMOutput,
|
| 28 |
+
TFMultipleChoiceModelOutput,
|
| 29 |
+
TFQuestionAnsweringModelOutput,
|
| 30 |
+
TFSequenceClassifierOutput,
|
| 31 |
+
TFTokenClassifierOutput,
|
| 32 |
+
)
|
| 33 |
+
from ...modeling_tf_utils import (
|
| 34 |
+
TFMaskedLanguageModelingLoss,
|
| 35 |
+
TFModelInputType,
|
| 36 |
+
TFMultipleChoiceLoss,
|
| 37 |
+
TFPreTrainedModel,
|
| 38 |
+
TFQuestionAnsweringLoss,
|
| 39 |
+
TFSequenceClassificationLoss,
|
| 40 |
+
TFTokenClassificationLoss,
|
| 41 |
+
get_initializer,
|
| 42 |
+
keras,
|
| 43 |
+
unpack_inputs,
|
| 44 |
+
)
|
| 45 |
+
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
|
| 46 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 47 |
+
from .configuration_deberta_v2 import DebertaV2Config
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
_CONFIG_FOR_DOC = "DebertaV2Config"
|
| 53 |
+
_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaContextPooler with Deberta->DebertaV2
|
| 57 |
+
class TFDebertaV2ContextPooler(keras.layers.Layer):
|
| 58 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 59 |
+
super().__init__(**kwargs)
|
| 60 |
+
self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense")
|
| 61 |
+
self.dropout = TFDebertaV2StableDropout(config.pooler_dropout, name="dropout")
|
| 62 |
+
self.config = config
|
| 63 |
+
|
| 64 |
+
def call(self, hidden_states, training: bool = False):
|
| 65 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 66 |
+
# to the first token.
|
| 67 |
+
context_token = hidden_states[:, 0]
|
| 68 |
+
context_token = self.dropout(context_token, training=training)
|
| 69 |
+
pooled_output = self.dense(context_token)
|
| 70 |
+
pooled_output = get_tf_activation(self.config.pooler_hidden_act)(pooled_output)
|
| 71 |
+
return pooled_output
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def output_dim(self) -> int:
|
| 75 |
+
return self.config.hidden_size
|
| 76 |
+
|
| 77 |
+
def build(self, input_shape=None):
|
| 78 |
+
if self.built:
|
| 79 |
+
return
|
| 80 |
+
self.built = True
|
| 81 |
+
if getattr(self, "dense", None) is not None:
|
| 82 |
+
with tf.name_scope(self.dense.name):
|
| 83 |
+
self.dense.build([None, None, self.config.pooler_hidden_size])
|
| 84 |
+
if getattr(self, "dropout", None) is not None:
|
| 85 |
+
with tf.name_scope(self.dropout.name):
|
| 86 |
+
self.dropout.build(None)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXSoftmax with Deberta->DebertaV2
|
| 90 |
+
class TFDebertaV2XSoftmax(keras.layers.Layer):
|
| 91 |
+
"""
|
| 92 |
+
Masked Softmax which is optimized for saving memory
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
input (`tf.Tensor`): The input tensor that will apply softmax.
|
| 96 |
+
mask (`tf.Tensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
|
| 97 |
+
dim (int): The dimension that will apply softmax
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, axis=-1, **kwargs):
|
| 101 |
+
super().__init__(**kwargs)
|
| 102 |
+
self.axis = axis
|
| 103 |
+
|
| 104 |
+
def call(self, inputs: tf.Tensor, mask: tf.Tensor):
|
| 105 |
+
rmask = tf.logical_not(tf.cast(mask, tf.bool))
|
| 106 |
+
output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)
|
| 107 |
+
output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)
|
| 108 |
+
output = tf.where(rmask, 0.0, output)
|
| 109 |
+
return output
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
|
| 113 |
+
class TFDebertaV2StableDropout(keras.layers.Layer):
|
| 114 |
+
"""
|
| 115 |
+
Optimized dropout module for stabilizing the training
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
drop_prob (float): the dropout probabilities
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
def __init__(self, drop_prob, **kwargs):
|
| 122 |
+
super().__init__(**kwargs)
|
| 123 |
+
self.drop_prob = drop_prob
|
| 124 |
+
|
| 125 |
+
@tf.custom_gradient
|
| 126 |
+
def xdropout(self, inputs):
|
| 127 |
+
"""
|
| 128 |
+
Applies dropout to the inputs, as vanilla dropout, but also scales the remaining elements up by 1/drop_prob.
|
| 129 |
+
"""
|
| 130 |
+
mask = tf.cast(
|
| 131 |
+
1
|
| 132 |
+
- tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(sample_shape=shape_list(inputs)),
|
| 133 |
+
tf.bool,
|
| 134 |
+
)
|
| 135 |
+
scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)
|
| 136 |
+
if self.drop_prob > 0:
|
| 137 |
+
inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale
|
| 138 |
+
|
| 139 |
+
def grad(upstream):
|
| 140 |
+
if self.drop_prob > 0:
|
| 141 |
+
return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale
|
| 142 |
+
else:
|
| 143 |
+
return upstream
|
| 144 |
+
|
| 145 |
+
return inputs, grad
|
| 146 |
+
|
| 147 |
+
def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
|
| 148 |
+
if training:
|
| 149 |
+
return self.xdropout(inputs)
|
| 150 |
+
return inputs
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaSelfOutput with Deberta->DebertaV2
|
| 154 |
+
class TFDebertaV2SelfOutput(keras.layers.Layer):
|
| 155 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 156 |
+
super().__init__(**kwargs)
|
| 157 |
+
self.dense = keras.layers.Dense(config.hidden_size, name="dense")
|
| 158 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 159 |
+
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
|
| 160 |
+
self.config = config
|
| 161 |
+
|
| 162 |
+
def call(self, hidden_states, input_tensor, training: bool = False):
|
| 163 |
+
hidden_states = self.dense(hidden_states)
|
| 164 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 165 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 166 |
+
return hidden_states
|
| 167 |
+
|
| 168 |
+
def build(self, input_shape=None):
|
| 169 |
+
if self.built:
|
| 170 |
+
return
|
| 171 |
+
self.built = True
|
| 172 |
+
if getattr(self, "dense", None) is not None:
|
| 173 |
+
with tf.name_scope(self.dense.name):
|
| 174 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 175 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 176 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 177 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 178 |
+
if getattr(self, "dropout", None) is not None:
|
| 179 |
+
with tf.name_scope(self.dropout.name):
|
| 180 |
+
self.dropout.build(None)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaAttention with Deberta->DebertaV2
|
| 184 |
+
class TFDebertaV2Attention(keras.layers.Layer):
|
| 185 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 186 |
+
super().__init__(**kwargs)
|
| 187 |
+
self.self = TFDebertaV2DisentangledSelfAttention(config, name="self")
|
| 188 |
+
self.dense_output = TFDebertaV2SelfOutput(config, name="output")
|
| 189 |
+
self.config = config
|
| 190 |
+
|
| 191 |
+
def call(
|
| 192 |
+
self,
|
| 193 |
+
input_tensor: tf.Tensor,
|
| 194 |
+
attention_mask: tf.Tensor,
|
| 195 |
+
query_states: Optional[tf.Tensor] = None,
|
| 196 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 197 |
+
rel_embeddings: Optional[tf.Tensor] = None,
|
| 198 |
+
output_attentions: bool = False,
|
| 199 |
+
training: bool = False,
|
| 200 |
+
) -> Tuple[tf.Tensor]:
|
| 201 |
+
self_outputs = self.self(
|
| 202 |
+
hidden_states=input_tensor,
|
| 203 |
+
attention_mask=attention_mask,
|
| 204 |
+
query_states=query_states,
|
| 205 |
+
relative_pos=relative_pos,
|
| 206 |
+
rel_embeddings=rel_embeddings,
|
| 207 |
+
output_attentions=output_attentions,
|
| 208 |
+
training=training,
|
| 209 |
+
)
|
| 210 |
+
if query_states is None:
|
| 211 |
+
query_states = input_tensor
|
| 212 |
+
attention_output = self.dense_output(
|
| 213 |
+
hidden_states=self_outputs[0], input_tensor=query_states, training=training
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
output = (attention_output,) + self_outputs[1:]
|
| 217 |
+
|
| 218 |
+
return output
|
| 219 |
+
|
| 220 |
+
def build(self, input_shape=None):
|
| 221 |
+
if self.built:
|
| 222 |
+
return
|
| 223 |
+
self.built = True
|
| 224 |
+
if getattr(self, "self", None) is not None:
|
| 225 |
+
with tf.name_scope(self.self.name):
|
| 226 |
+
self.self.build(None)
|
| 227 |
+
if getattr(self, "dense_output", None) is not None:
|
| 228 |
+
with tf.name_scope(self.dense_output.name):
|
| 229 |
+
self.dense_output.build(None)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaIntermediate with Deberta->DebertaV2
|
| 233 |
+
class TFDebertaV2Intermediate(keras.layers.Layer):
|
| 234 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 235 |
+
super().__init__(**kwargs)
|
| 236 |
+
|
| 237 |
+
self.dense = keras.layers.Dense(
|
| 238 |
+
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
if isinstance(config.hidden_act, str):
|
| 242 |
+
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
| 243 |
+
else:
|
| 244 |
+
self.intermediate_act_fn = config.hidden_act
|
| 245 |
+
self.config = config
|
| 246 |
+
|
| 247 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 248 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 249 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 250 |
+
|
| 251 |
+
return hidden_states
|
| 252 |
+
|
| 253 |
+
def build(self, input_shape=None):
|
| 254 |
+
if self.built:
|
| 255 |
+
return
|
| 256 |
+
self.built = True
|
| 257 |
+
if getattr(self, "dense", None) is not None:
|
| 258 |
+
with tf.name_scope(self.dense.name):
|
| 259 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOutput with Deberta->DebertaV2
|
| 263 |
+
class TFDebertaV2Output(keras.layers.Layer):
|
| 264 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 265 |
+
super().__init__(**kwargs)
|
| 266 |
+
|
| 267 |
+
self.dense = keras.layers.Dense(
|
| 268 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 269 |
+
)
|
| 270 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 271 |
+
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
|
| 272 |
+
self.config = config
|
| 273 |
+
|
| 274 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 275 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 276 |
+
hidden_states = self.dropout(hidden_states, training=training)
|
| 277 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 278 |
+
|
| 279 |
+
return hidden_states
|
| 280 |
+
|
| 281 |
+
def build(self, input_shape=None):
|
| 282 |
+
if self.built:
|
| 283 |
+
return
|
| 284 |
+
self.built = True
|
| 285 |
+
if getattr(self, "dense", None) is not None:
|
| 286 |
+
with tf.name_scope(self.dense.name):
|
| 287 |
+
self.dense.build([None, None, self.config.intermediate_size])
|
| 288 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 289 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 290 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 291 |
+
if getattr(self, "dropout", None) is not None:
|
| 292 |
+
with tf.name_scope(self.dropout.name):
|
| 293 |
+
self.dropout.build(None)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLayer with Deberta->DebertaV2
|
| 297 |
+
class TFDebertaV2Layer(keras.layers.Layer):
|
| 298 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 299 |
+
super().__init__(**kwargs)
|
| 300 |
+
|
| 301 |
+
self.attention = TFDebertaV2Attention(config, name="attention")
|
| 302 |
+
self.intermediate = TFDebertaV2Intermediate(config, name="intermediate")
|
| 303 |
+
self.bert_output = TFDebertaV2Output(config, name="output")
|
| 304 |
+
|
| 305 |
+
def call(
|
| 306 |
+
self,
|
| 307 |
+
hidden_states: tf.Tensor,
|
| 308 |
+
attention_mask: tf.Tensor,
|
| 309 |
+
query_states: Optional[tf.Tensor] = None,
|
| 310 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 311 |
+
rel_embeddings: Optional[tf.Tensor] = None,
|
| 312 |
+
output_attentions: bool = False,
|
| 313 |
+
training: bool = False,
|
| 314 |
+
) -> Tuple[tf.Tensor]:
|
| 315 |
+
attention_outputs = self.attention(
|
| 316 |
+
input_tensor=hidden_states,
|
| 317 |
+
attention_mask=attention_mask,
|
| 318 |
+
query_states=query_states,
|
| 319 |
+
relative_pos=relative_pos,
|
| 320 |
+
rel_embeddings=rel_embeddings,
|
| 321 |
+
output_attentions=output_attentions,
|
| 322 |
+
training=training,
|
| 323 |
+
)
|
| 324 |
+
attention_output = attention_outputs[0]
|
| 325 |
+
intermediate_output = self.intermediate(hidden_states=attention_output)
|
| 326 |
+
layer_output = self.bert_output(
|
| 327 |
+
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
| 328 |
+
)
|
| 329 |
+
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
| 330 |
+
|
| 331 |
+
return outputs
|
| 332 |
+
|
| 333 |
+
def build(self, input_shape=None):
|
| 334 |
+
if self.built:
|
| 335 |
+
return
|
| 336 |
+
self.built = True
|
| 337 |
+
if getattr(self, "attention", None) is not None:
|
| 338 |
+
with tf.name_scope(self.attention.name):
|
| 339 |
+
self.attention.build(None)
|
| 340 |
+
if getattr(self, "intermediate", None) is not None:
|
| 341 |
+
with tf.name_scope(self.intermediate.name):
|
| 342 |
+
self.intermediate.build(None)
|
| 343 |
+
if getattr(self, "bert_output", None) is not None:
|
| 344 |
+
with tf.name_scope(self.bert_output.name):
|
| 345 |
+
self.bert_output.build(None)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class TFDebertaV2ConvLayer(keras.layers.Layer):
|
| 349 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 350 |
+
super().__init__(**kwargs)
|
| 351 |
+
|
| 352 |
+
self.kernel_size = getattr(config, "conv_kernel_size", 3)
|
| 353 |
+
# groups = getattr(config, "conv_groups", 1)
|
| 354 |
+
self.conv_act = get_tf_activation(getattr(config, "conv_act", "tanh"))
|
| 355 |
+
self.padding = (self.kernel_size - 1) // 2
|
| 356 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 357 |
+
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
|
| 358 |
+
self.config = config
|
| 359 |
+
|
| 360 |
+
def build(self, input_shape=None):
|
| 361 |
+
if self.built:
|
| 362 |
+
return
|
| 363 |
+
self.built = True
|
| 364 |
+
with tf.name_scope("conv"):
|
| 365 |
+
self.conv_kernel = self.add_weight(
|
| 366 |
+
name="kernel",
|
| 367 |
+
shape=[self.kernel_size, self.config.hidden_size, self.config.hidden_size],
|
| 368 |
+
initializer=get_initializer(self.config.initializer_range),
|
| 369 |
+
)
|
| 370 |
+
self.conv_bias = self.add_weight(
|
| 371 |
+
name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer()
|
| 372 |
+
)
|
| 373 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 374 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 375 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 376 |
+
if getattr(self, "dropout", None) is not None:
|
| 377 |
+
with tf.name_scope(self.dropout.name):
|
| 378 |
+
self.dropout.build(None)
|
| 379 |
+
|
| 380 |
+
def call(
|
| 381 |
+
self, hidden_states: tf.Tensor, residual_states: tf.Tensor, input_mask: tf.Tensor, training: bool = False
|
| 382 |
+
) -> tf.Tensor:
|
| 383 |
+
out = tf.nn.conv2d(
|
| 384 |
+
tf.expand_dims(hidden_states, 1),
|
| 385 |
+
tf.expand_dims(self.conv_kernel, 0),
|
| 386 |
+
strides=1,
|
| 387 |
+
padding=[[0, 0], [0, 0], [self.padding, self.padding], [0, 0]],
|
| 388 |
+
)
|
| 389 |
+
out = tf.squeeze(tf.nn.bias_add(out, self.conv_bias), 1)
|
| 390 |
+
rmask = tf.cast(1 - input_mask, tf.bool)
|
| 391 |
+
out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
|
| 392 |
+
out = self.dropout(out, training=training)
|
| 393 |
+
out = self.conv_act(out)
|
| 394 |
+
|
| 395 |
+
layer_norm_input = residual_states + out
|
| 396 |
+
output = self.LayerNorm(layer_norm_input)
|
| 397 |
+
|
| 398 |
+
if input_mask is None:
|
| 399 |
+
output_states = output
|
| 400 |
+
else:
|
| 401 |
+
if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)):
|
| 402 |
+
if len(shape_list(input_mask)) == 4:
|
| 403 |
+
input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1)
|
| 404 |
+
input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), dtype=self.compute_dtype)
|
| 405 |
+
|
| 406 |
+
output_states = output * input_mask
|
| 407 |
+
|
| 408 |
+
return output_states
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class TFDebertaV2Encoder(keras.layers.Layer):
|
| 412 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 413 |
+
super().__init__(**kwargs)
|
| 414 |
+
|
| 415 |
+
self.layer = [TFDebertaV2Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
| 416 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 417 |
+
self.config = config
|
| 418 |
+
if self.relative_attention:
|
| 419 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 420 |
+
if self.max_relative_positions < 1:
|
| 421 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 422 |
+
|
| 423 |
+
self.position_buckets = getattr(config, "position_buckets", -1)
|
| 424 |
+
self.pos_ebd_size = self.max_relative_positions * 2
|
| 425 |
+
|
| 426 |
+
if self.position_buckets > 0:
|
| 427 |
+
self.pos_ebd_size = self.position_buckets * 2
|
| 428 |
+
|
| 429 |
+
self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
|
| 430 |
+
|
| 431 |
+
if "layer_norm" in self.norm_rel_ebd:
|
| 432 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 433 |
+
|
| 434 |
+
self.conv = TFDebertaV2ConvLayer(config, name="conv") if getattr(config, "conv_kernel_size", 0) > 0 else None
|
| 435 |
+
|
| 436 |
+
def build(self, input_shape=None):
|
| 437 |
+
if self.built:
|
| 438 |
+
return
|
| 439 |
+
self.built = True
|
| 440 |
+
if self.relative_attention:
|
| 441 |
+
self.rel_embeddings = self.add_weight(
|
| 442 |
+
name="rel_embeddings.weight",
|
| 443 |
+
shape=[self.pos_ebd_size, self.config.hidden_size],
|
| 444 |
+
initializer=get_initializer(self.config.initializer_range),
|
| 445 |
+
)
|
| 446 |
+
if getattr(self, "conv", None) is not None:
|
| 447 |
+
with tf.name_scope(self.conv.name):
|
| 448 |
+
self.conv.build(None)
|
| 449 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 450 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 451 |
+
self.LayerNorm.build([None, self.config.hidden_size])
|
| 452 |
+
if getattr(self, "layer", None) is not None:
|
| 453 |
+
for layer in self.layer:
|
| 454 |
+
with tf.name_scope(layer.name):
|
| 455 |
+
layer.build(None)
|
| 456 |
+
|
| 457 |
+
def get_rel_embedding(self):
|
| 458 |
+
rel_embeddings = self.rel_embeddings if self.relative_attention else None
|
| 459 |
+
if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
|
| 460 |
+
rel_embeddings = self.LayerNorm(rel_embeddings)
|
| 461 |
+
return rel_embeddings
|
| 462 |
+
|
| 463 |
+
def get_attention_mask(self, attention_mask):
|
| 464 |
+
if len(shape_list(attention_mask)) <= 2:
|
| 465 |
+
extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2)
|
| 466 |
+
attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1)
|
| 467 |
+
attention_mask = tf.cast(attention_mask, tf.uint8)
|
| 468 |
+
elif len(shape_list(attention_mask)) == 3:
|
| 469 |
+
attention_mask = tf.expand_dims(attention_mask, 1)
|
| 470 |
+
|
| 471 |
+
return attention_mask
|
| 472 |
+
|
| 473 |
+
def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
|
| 474 |
+
if self.relative_attention and relative_pos is None:
|
| 475 |
+
q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2]
|
| 476 |
+
relative_pos = build_relative_position(
|
| 477 |
+
q,
|
| 478 |
+
shape_list(hidden_states)[-2],
|
| 479 |
+
bucket_size=self.position_buckets,
|
| 480 |
+
max_position=self.max_relative_positions,
|
| 481 |
+
)
|
| 482 |
+
return relative_pos
|
| 483 |
+
|
| 484 |
+
def call(
|
| 485 |
+
self,
|
| 486 |
+
hidden_states: tf.Tensor,
|
| 487 |
+
attention_mask: tf.Tensor,
|
| 488 |
+
query_states: Optional[tf.Tensor] = None,
|
| 489 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 490 |
+
output_attentions: bool = False,
|
| 491 |
+
output_hidden_states: bool = False,
|
| 492 |
+
return_dict: bool = True,
|
| 493 |
+
training: bool = False,
|
| 494 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 495 |
+
if len(shape_list(attention_mask)) <= 2:
|
| 496 |
+
input_mask = attention_mask
|
| 497 |
+
else:
|
| 498 |
+
input_mask = tf.cast(tf.math.reduce_sum(attention_mask, axis=-2) > 0, dtype=tf.uint8)
|
| 499 |
+
|
| 500 |
+
all_hidden_states = () if output_hidden_states else None
|
| 501 |
+
all_attentions = () if output_attentions else None
|
| 502 |
+
|
| 503 |
+
attention_mask = self.get_attention_mask(attention_mask)
|
| 504 |
+
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
|
| 505 |
+
|
| 506 |
+
next_kv = hidden_states
|
| 507 |
+
|
| 508 |
+
rel_embeddings = self.get_rel_embedding()
|
| 509 |
+
output_states = next_kv
|
| 510 |
+
for i, layer_module in enumerate(self.layer):
|
| 511 |
+
if output_hidden_states:
|
| 512 |
+
all_hidden_states = all_hidden_states + (output_states,)
|
| 513 |
+
|
| 514 |
+
layer_outputs = layer_module(
|
| 515 |
+
hidden_states=next_kv,
|
| 516 |
+
attention_mask=attention_mask,
|
| 517 |
+
query_states=query_states,
|
| 518 |
+
relative_pos=relative_pos,
|
| 519 |
+
rel_embeddings=rel_embeddings,
|
| 520 |
+
output_attentions=output_attentions,
|
| 521 |
+
training=training,
|
| 522 |
+
)
|
| 523 |
+
output_states = layer_outputs[0]
|
| 524 |
+
|
| 525 |
+
if i == 0 and self.conv is not None:
|
| 526 |
+
output_states = self.conv(hidden_states, output_states, input_mask)
|
| 527 |
+
|
| 528 |
+
next_kv = output_states
|
| 529 |
+
|
| 530 |
+
if output_attentions:
|
| 531 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 532 |
+
|
| 533 |
+
# Add last layer
|
| 534 |
+
if output_hidden_states:
|
| 535 |
+
all_hidden_states = all_hidden_states + (output_states,)
|
| 536 |
+
|
| 537 |
+
if not return_dict:
|
| 538 |
+
return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
|
| 539 |
+
|
| 540 |
+
return TFBaseModelOutput(
|
| 541 |
+
last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def make_log_bucket_position(relative_pos, bucket_size, max_position):
|
| 546 |
+
sign = tf.math.sign(relative_pos)
|
| 547 |
+
mid = bucket_size // 2
|
| 548 |
+
abs_pos = tf.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, tf.math.abs(relative_pos))
|
| 549 |
+
log_pos = tf.math.ceil(
|
| 550 |
+
tf.cast(tf.math.log(abs_pos / mid), tf.float32)
|
| 551 |
+
/ tf.cast(tf.math.log((max_position - 1) / mid), tf.float32)
|
| 552 |
+
* tf.cast(mid - 1, tf.float32) # in graph mode
|
| 553 |
+
) + tf.cast(mid, tf.float32)
|
| 554 |
+
bucket_pos = tf.cast(
|
| 555 |
+
tf.where(abs_pos <= mid, tf.cast(relative_pos, tf.float32), log_pos * tf.cast(sign, tf.float32)), tf.int32
|
| 556 |
+
)
|
| 557 |
+
return bucket_pos
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
|
| 561 |
+
"""
|
| 562 |
+
Build relative position according to the query and key
|
| 563 |
+
|
| 564 |
+
We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
|
| 565 |
+
\\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
|
| 566 |
+
P_k\\)
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
query_size (int): the length of query
|
| 570 |
+
key_size (int): the length of key
|
| 571 |
+
bucket_size (int): the size of position bucket
|
| 572 |
+
max_position (int): the maximum allowed absolute position
|
| 573 |
+
|
| 574 |
+
Return:
|
| 575 |
+
`tf.Tensor`: A tensor with shape [1, query_size, key_size]
|
| 576 |
+
|
| 577 |
+
"""
|
| 578 |
+
q_ids = tf.range(query_size, dtype=tf.int32)
|
| 579 |
+
k_ids = tf.range(key_size, dtype=tf.int32)
|
| 580 |
+
rel_pos_ids = q_ids[:, None] - tf.tile(tf.expand_dims(k_ids, axis=0), [shape_list(q_ids)[0], 1])
|
| 581 |
+
if bucket_size > 0 and max_position > 0:
|
| 582 |
+
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
|
| 583 |
+
rel_pos_ids = rel_pos_ids[:query_size, :]
|
| 584 |
+
rel_pos_ids = tf.expand_dims(rel_pos_ids, axis=0)
|
| 585 |
+
return tf.cast(rel_pos_ids, tf.int64)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
|
| 589 |
+
shapes = [
|
| 590 |
+
shape_list(query_layer)[0],
|
| 591 |
+
shape_list(query_layer)[1],
|
| 592 |
+
shape_list(query_layer)[2],
|
| 593 |
+
shape_list(relative_pos)[-1],
|
| 594 |
+
]
|
| 595 |
+
return tf.broadcast_to(c2p_pos, shapes)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
|
| 599 |
+
shapes = [
|
| 600 |
+
shape_list(query_layer)[0],
|
| 601 |
+
shape_list(query_layer)[1],
|
| 602 |
+
shape_list(key_layer)[-2],
|
| 603 |
+
shape_list(key_layer)[-2],
|
| 604 |
+
]
|
| 605 |
+
return tf.broadcast_to(c2p_pos, shapes)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
|
| 609 |
+
shapes = shape_list(p2c_att)[:2] + [shape_list(pos_index)[-2], shape_list(key_layer)[-2]]
|
| 610 |
+
return tf.broadcast_to(pos_index, shapes)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def take_along_axis(x, indices):
|
| 614 |
+
# Only a valid port of np.take_along_axis when the gather axis is -1
|
| 615 |
+
|
| 616 |
+
# TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
|
| 617 |
+
if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
|
| 618 |
+
# [B, S, P] -> [B, S, P, D]
|
| 619 |
+
one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
|
| 620 |
+
|
| 621 |
+
# if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
|
| 622 |
+
# grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
|
| 623 |
+
gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)
|
| 624 |
+
|
| 625 |
+
# GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls
|
| 626 |
+
else:
|
| 627 |
+
gathered = tf.gather(x, indices, batch_dims=2)
|
| 628 |
+
|
| 629 |
+
return gathered
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class TFDebertaV2DisentangledSelfAttention(keras.layers.Layer):
|
| 633 |
+
"""
|
| 634 |
+
Disentangled self-attention module
|
| 635 |
+
|
| 636 |
+
Parameters:
|
| 637 |
+
config (`DebertaV2Config`):
|
| 638 |
+
A model config class instance with the configuration to build a new model. The schema is similar to
|
| 639 |
+
*BertConfig*, for more details, please refer [`DebertaV2Config`]
|
| 640 |
+
|
| 641 |
+
"""
|
| 642 |
+
|
| 643 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 644 |
+
super().__init__(**kwargs)
|
| 645 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 646 |
+
raise ValueError(
|
| 647 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 648 |
+
f"heads ({config.num_attention_heads})"
|
| 649 |
+
)
|
| 650 |
+
self.num_attention_heads = config.num_attention_heads
|
| 651 |
+
_attention_head_size = config.hidden_size // config.num_attention_heads
|
| 652 |
+
self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
|
| 653 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 654 |
+
self.query_proj = keras.layers.Dense(
|
| 655 |
+
self.all_head_size,
|
| 656 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 657 |
+
name="query_proj",
|
| 658 |
+
use_bias=True,
|
| 659 |
+
)
|
| 660 |
+
self.key_proj = keras.layers.Dense(
|
| 661 |
+
self.all_head_size,
|
| 662 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 663 |
+
name="key_proj",
|
| 664 |
+
use_bias=True,
|
| 665 |
+
)
|
| 666 |
+
self.value_proj = keras.layers.Dense(
|
| 667 |
+
self.all_head_size,
|
| 668 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 669 |
+
name="value_proj",
|
| 670 |
+
use_bias=True,
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
self.share_att_key = getattr(config, "share_att_key", False)
|
| 674 |
+
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
|
| 675 |
+
self.relative_attention = getattr(config, "relative_attention", False)
|
| 676 |
+
|
| 677 |
+
if self.relative_attention:
|
| 678 |
+
self.position_buckets = getattr(config, "position_buckets", -1)
|
| 679 |
+
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
|
| 680 |
+
if self.max_relative_positions < 1:
|
| 681 |
+
self.max_relative_positions = config.max_position_embeddings
|
| 682 |
+
self.pos_ebd_size = self.max_relative_positions
|
| 683 |
+
if self.position_buckets > 0:
|
| 684 |
+
self.pos_ebd_size = self.position_buckets
|
| 685 |
+
|
| 686 |
+
self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout")
|
| 687 |
+
|
| 688 |
+
if not self.share_att_key:
|
| 689 |
+
if "c2p" in self.pos_att_type:
|
| 690 |
+
self.pos_key_proj = keras.layers.Dense(
|
| 691 |
+
self.all_head_size,
|
| 692 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 693 |
+
name="pos_proj",
|
| 694 |
+
use_bias=True,
|
| 695 |
+
)
|
| 696 |
+
if "p2c" in self.pos_att_type:
|
| 697 |
+
self.pos_query_proj = keras.layers.Dense(
|
| 698 |
+
self.all_head_size,
|
| 699 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 700 |
+
name="pos_q_proj",
|
| 701 |
+
)
|
| 702 |
+
self.softmax = TFDebertaV2XSoftmax(axis=-1)
|
| 703 |
+
self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
|
| 704 |
+
self.config = config
|
| 705 |
+
|
| 706 |
+
def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
|
| 707 |
+
tensor_shape = shape_list(tensor)
|
| 708 |
+
# In graph mode mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None
|
| 709 |
+
shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
|
| 710 |
+
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 711 |
+
tensor = tf.reshape(tensor=tensor, shape=shape)
|
| 712 |
+
tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
|
| 713 |
+
x_shape = shape_list(tensor)
|
| 714 |
+
tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]])
|
| 715 |
+
return tensor
|
| 716 |
+
|
| 717 |
+
def call(
|
| 718 |
+
self,
|
| 719 |
+
hidden_states: tf.Tensor,
|
| 720 |
+
attention_mask: tf.Tensor,
|
| 721 |
+
query_states: Optional[tf.Tensor] = None,
|
| 722 |
+
relative_pos: Optional[tf.Tensor] = None,
|
| 723 |
+
rel_embeddings: Optional[tf.Tensor] = None,
|
| 724 |
+
output_attentions: bool = False,
|
| 725 |
+
training: bool = False,
|
| 726 |
+
) -> Tuple[tf.Tensor]:
|
| 727 |
+
"""
|
| 728 |
+
Call the module
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
hidden_states (`tf.Tensor`):
|
| 732 |
+
Input states to the module usually the output from previous layer, it will be the Q,K and V in
|
| 733 |
+
*Attention(Q,K,V)*
|
| 734 |
+
|
| 735 |
+
attention_mask (`tf.Tensor`):
|
| 736 |
+
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
|
| 737 |
+
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
|
| 738 |
+
th token.
|
| 739 |
+
|
| 740 |
+
return_att (`bool`, *optional*):
|
| 741 |
+
Whether return the attention matrix.
|
| 742 |
+
|
| 743 |
+
query_states (`tf.Tensor`, *optional*):
|
| 744 |
+
The *Q* state in *Attention(Q,K,V)*.
|
| 745 |
+
|
| 746 |
+
relative_pos (`tf.Tensor`):
|
| 747 |
+
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
|
| 748 |
+
values ranging in [*-max_relative_positions*, *max_relative_positions*].
|
| 749 |
+
|
| 750 |
+
rel_embeddings (`tf.Tensor`):
|
| 751 |
+
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
|
| 752 |
+
\\text{max_relative_positions}\\), *hidden_size*].
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
"""
|
| 756 |
+
if query_states is None:
|
| 757 |
+
query_states = hidden_states
|
| 758 |
+
query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
|
| 759 |
+
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
|
| 760 |
+
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
|
| 761 |
+
|
| 762 |
+
rel_att = None
|
| 763 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 764 |
+
scale_factor = 1
|
| 765 |
+
if "c2p" in self.pos_att_type:
|
| 766 |
+
scale_factor += 1
|
| 767 |
+
if "p2c" in self.pos_att_type:
|
| 768 |
+
scale_factor += 1
|
| 769 |
+
scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, dtype=self.compute_dtype))
|
| 770 |
+
attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1]) / scale)
|
| 771 |
+
if self.relative_attention:
|
| 772 |
+
rel_embeddings = self.pos_dropout(rel_embeddings)
|
| 773 |
+
rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
|
| 774 |
+
|
| 775 |
+
if rel_att is not None:
|
| 776 |
+
attention_scores = attention_scores + rel_att
|
| 777 |
+
attention_scores = tf.reshape(
|
| 778 |
+
attention_scores,
|
| 779 |
+
(-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
# bsz x height x length x dimension
|
| 783 |
+
attention_probs = self.softmax(attention_scores, attention_mask)
|
| 784 |
+
attention_probs = self.dropout(attention_probs, training=training)
|
| 785 |
+
context_layer = tf.matmul(
|
| 786 |
+
tf.reshape(attention_probs, [-1, shape_list(attention_probs)[-2], shape_list(attention_probs)[-1]]),
|
| 787 |
+
value_layer,
|
| 788 |
+
)
|
| 789 |
+
context_layer = tf.transpose(
|
| 790 |
+
tf.reshape(
|
| 791 |
+
context_layer,
|
| 792 |
+
[-1, self.num_attention_heads, shape_list(context_layer)[-2], shape_list(context_layer)[-1]],
|
| 793 |
+
),
|
| 794 |
+
[0, 2, 1, 3],
|
| 795 |
+
)
|
| 796 |
+
# Set the final dimension here explicitly.
|
| 797 |
+
# Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
|
| 798 |
+
# the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
|
| 799 |
+
# requires final input dimension to be defined
|
| 800 |
+
context_layer_shape = shape_list(context_layer)
|
| 801 |
+
new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
|
| 802 |
+
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
| 803 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 804 |
+
return outputs
|
| 805 |
+
|
| 806 |
+
def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
|
| 807 |
+
if relative_pos is None:
|
| 808 |
+
q = shape_list(query_layer)[-2]
|
| 809 |
+
relative_pos = build_relative_position(
|
| 810 |
+
q,
|
| 811 |
+
shape_list(key_layer)[-2],
|
| 812 |
+
bucket_size=self.position_buckets,
|
| 813 |
+
max_position=self.max_relative_positions,
|
| 814 |
+
)
|
| 815 |
+
shape_list_pos = shape_list(relative_pos)
|
| 816 |
+
if len(shape_list_pos) == 2:
|
| 817 |
+
relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
|
| 818 |
+
elif len(shape_list_pos) == 3:
|
| 819 |
+
relative_pos = tf.expand_dims(relative_pos, 1)
|
| 820 |
+
# bsz x height x query x key
|
| 821 |
+
elif len(shape_list_pos) != 4:
|
| 822 |
+
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")
|
| 823 |
+
|
| 824 |
+
att_span = self.pos_ebd_size
|
| 825 |
+
rel_embeddings = tf.expand_dims(
|
| 826 |
+
rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :], 0
|
| 827 |
+
)
|
| 828 |
+
if self.share_att_key:
|
| 829 |
+
pos_query_layer = tf.tile(
|
| 830 |
+
self.transpose_for_scores(self.query_proj(rel_embeddings), self.num_attention_heads),
|
| 831 |
+
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
| 832 |
+
)
|
| 833 |
+
pos_key_layer = tf.tile(
|
| 834 |
+
self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads),
|
| 835 |
+
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
| 836 |
+
)
|
| 837 |
+
else:
|
| 838 |
+
if "c2p" in self.pos_att_type:
|
| 839 |
+
pos_key_layer = tf.tile(
|
| 840 |
+
self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads),
|
| 841 |
+
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
| 842 |
+
) # .split(self.all_head_size, dim=-1)
|
| 843 |
+
if "p2c" in self.pos_att_type:
|
| 844 |
+
pos_query_layer = tf.tile(
|
| 845 |
+
self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads),
|
| 846 |
+
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
| 847 |
+
) # .split(self.all_head_size, dim=-1)
|
| 848 |
+
|
| 849 |
+
score = 0
|
| 850 |
+
# content->position
|
| 851 |
+
if "c2p" in self.pos_att_type:
|
| 852 |
+
scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, dtype=self.compute_dtype))
|
| 853 |
+
c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 2, 1]))
|
| 854 |
+
c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, att_span * 2 - 1)
|
| 855 |
+
c2p_att = take_along_axis(
|
| 856 |
+
c2p_att,
|
| 857 |
+
tf.broadcast_to(
|
| 858 |
+
tf.squeeze(c2p_pos, 0),
|
| 859 |
+
[shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
|
| 860 |
+
),
|
| 861 |
+
)
|
| 862 |
+
score += c2p_att / scale
|
| 863 |
+
|
| 864 |
+
# position->content
|
| 865 |
+
if "p2c" in self.pos_att_type:
|
| 866 |
+
scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype))
|
| 867 |
+
if shape_list(key_layer)[-2] != shape_list(query_layer)[-2]:
|
| 868 |
+
r_pos = build_relative_position(
|
| 869 |
+
shape_list(key_layer)[-2],
|
| 870 |
+
shape_list(key_layer)[-2],
|
| 871 |
+
bucket_size=self.position_buckets,
|
| 872 |
+
max_position=self.max_relative_positions,
|
| 873 |
+
)
|
| 874 |
+
r_pos = tf.expand_dims(r_pos, 0)
|
| 875 |
+
else:
|
| 876 |
+
r_pos = relative_pos
|
| 877 |
+
|
| 878 |
+
p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)
|
| 879 |
+
|
| 880 |
+
p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1]))
|
| 881 |
+
p2c_att = tf.transpose(
|
| 882 |
+
take_along_axis(
|
| 883 |
+
p2c_att,
|
| 884 |
+
tf.broadcast_to(
|
| 885 |
+
tf.squeeze(p2c_pos, 0),
|
| 886 |
+
[shape_list(query_layer)[0], shape_list(key_layer)[-2], shape_list(key_layer)[-2]],
|
| 887 |
+
),
|
| 888 |
+
),
|
| 889 |
+
[0, 2, 1],
|
| 890 |
+
)
|
| 891 |
+
score += p2c_att / scale
|
| 892 |
+
|
| 893 |
+
return score
|
| 894 |
+
|
| 895 |
+
def build(self, input_shape=None):
|
| 896 |
+
if self.built:
|
| 897 |
+
return
|
| 898 |
+
self.built = True
|
| 899 |
+
if getattr(self, "query_proj", None) is not None:
|
| 900 |
+
with tf.name_scope(self.query_proj.name):
|
| 901 |
+
self.query_proj.build([None, None, self.config.hidden_size])
|
| 902 |
+
if getattr(self, "key_proj", None) is not None:
|
| 903 |
+
with tf.name_scope(self.key_proj.name):
|
| 904 |
+
self.key_proj.build([None, None, self.config.hidden_size])
|
| 905 |
+
if getattr(self, "value_proj", None) is not None:
|
| 906 |
+
with tf.name_scope(self.value_proj.name):
|
| 907 |
+
self.value_proj.build([None, None, self.config.hidden_size])
|
| 908 |
+
if getattr(self, "dropout", None) is not None:
|
| 909 |
+
with tf.name_scope(self.dropout.name):
|
| 910 |
+
self.dropout.build(None)
|
| 911 |
+
if getattr(self, "pos_dropout", None) is not None:
|
| 912 |
+
with tf.name_scope(self.pos_dropout.name):
|
| 913 |
+
self.pos_dropout.build(None)
|
| 914 |
+
if getattr(self, "pos_key_proj", None) is not None:
|
| 915 |
+
with tf.name_scope(self.pos_key_proj.name):
|
| 916 |
+
self.pos_key_proj.build([None, None, self.config.hidden_size])
|
| 917 |
+
if getattr(self, "pos_query_proj", None) is not None:
|
| 918 |
+
with tf.name_scope(self.pos_query_proj.name):
|
| 919 |
+
self.pos_query_proj.build([None, None, self.config.hidden_size])
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaEmbeddings Deberta->DebertaV2
|
| 923 |
+
class TFDebertaV2Embeddings(keras.layers.Layer):
|
| 924 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 925 |
+
|
| 926 |
+
def __init__(self, config, **kwargs):
|
| 927 |
+
super().__init__(**kwargs)
|
| 928 |
+
|
| 929 |
+
self.config = config
|
| 930 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 931 |
+
self.hidden_size = config.hidden_size
|
| 932 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 933 |
+
self.position_biased_input = getattr(config, "position_biased_input", True)
|
| 934 |
+
self.initializer_range = config.initializer_range
|
| 935 |
+
if self.embedding_size != config.hidden_size:
|
| 936 |
+
self.embed_proj = keras.layers.Dense(
|
| 937 |
+
config.hidden_size,
|
| 938 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 939 |
+
name="embed_proj",
|
| 940 |
+
use_bias=False,
|
| 941 |
+
)
|
| 942 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 943 |
+
self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
|
| 944 |
+
|
| 945 |
+
def build(self, input_shape=None):
|
| 946 |
+
with tf.name_scope("word_embeddings"):
|
| 947 |
+
self.weight = self.add_weight(
|
| 948 |
+
name="weight",
|
| 949 |
+
shape=[self.config.vocab_size, self.embedding_size],
|
| 950 |
+
initializer=get_initializer(self.initializer_range),
|
| 951 |
+
)
|
| 952 |
+
|
| 953 |
+
with tf.name_scope("token_type_embeddings"):
|
| 954 |
+
if self.config.type_vocab_size > 0:
|
| 955 |
+
self.token_type_embeddings = self.add_weight(
|
| 956 |
+
name="embeddings",
|
| 957 |
+
shape=[self.config.type_vocab_size, self.embedding_size],
|
| 958 |
+
initializer=get_initializer(self.initializer_range),
|
| 959 |
+
)
|
| 960 |
+
else:
|
| 961 |
+
self.token_type_embeddings = None
|
| 962 |
+
|
| 963 |
+
with tf.name_scope("position_embeddings"):
|
| 964 |
+
if self.position_biased_input:
|
| 965 |
+
self.position_embeddings = self.add_weight(
|
| 966 |
+
name="embeddings",
|
| 967 |
+
shape=[self.max_position_embeddings, self.hidden_size],
|
| 968 |
+
initializer=get_initializer(self.initializer_range),
|
| 969 |
+
)
|
| 970 |
+
else:
|
| 971 |
+
self.position_embeddings = None
|
| 972 |
+
|
| 973 |
+
if self.built:
|
| 974 |
+
return
|
| 975 |
+
self.built = True
|
| 976 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 977 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 978 |
+
self.LayerNorm.build([None, None, self.config.hidden_size])
|
| 979 |
+
if getattr(self, "dropout", None) is not None:
|
| 980 |
+
with tf.name_scope(self.dropout.name):
|
| 981 |
+
self.dropout.build(None)
|
| 982 |
+
if getattr(self, "embed_proj", None) is not None:
|
| 983 |
+
with tf.name_scope(self.embed_proj.name):
|
| 984 |
+
self.embed_proj.build([None, None, self.embedding_size])
|
| 985 |
+
|
| 986 |
+
def call(
|
| 987 |
+
self,
|
| 988 |
+
input_ids: Optional[tf.Tensor] = None,
|
| 989 |
+
position_ids: Optional[tf.Tensor] = None,
|
| 990 |
+
token_type_ids: Optional[tf.Tensor] = None,
|
| 991 |
+
inputs_embeds: Optional[tf.Tensor] = None,
|
| 992 |
+
mask: Optional[tf.Tensor] = None,
|
| 993 |
+
training: bool = False,
|
| 994 |
+
) -> tf.Tensor:
|
| 995 |
+
"""
|
| 996 |
+
Applies embedding based on inputs tensor.
|
| 997 |
+
|
| 998 |
+
Returns:
|
| 999 |
+
final_embeddings (`tf.Tensor`): output embedding tensor.
|
| 1000 |
+
"""
|
| 1001 |
+
if input_ids is None and inputs_embeds is None:
|
| 1002 |
+
raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
|
| 1003 |
+
|
| 1004 |
+
if input_ids is not None:
|
| 1005 |
+
check_embeddings_within_bounds(input_ids, self.config.vocab_size)
|
| 1006 |
+
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
|
| 1007 |
+
|
| 1008 |
+
input_shape = shape_list(inputs_embeds)[:-1]
|
| 1009 |
+
|
| 1010 |
+
if token_type_ids is None:
|
| 1011 |
+
token_type_ids = tf.fill(dims=input_shape, value=0)
|
| 1012 |
+
|
| 1013 |
+
if position_ids is None:
|
| 1014 |
+
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
| 1015 |
+
|
| 1016 |
+
final_embeddings = inputs_embeds
|
| 1017 |
+
if self.position_biased_input:
|
| 1018 |
+
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
| 1019 |
+
final_embeddings += position_embeds
|
| 1020 |
+
if self.config.type_vocab_size > 0:
|
| 1021 |
+
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
|
| 1022 |
+
final_embeddings += token_type_embeds
|
| 1023 |
+
|
| 1024 |
+
if self.embedding_size != self.hidden_size:
|
| 1025 |
+
final_embeddings = self.embed_proj(final_embeddings)
|
| 1026 |
+
|
| 1027 |
+
final_embeddings = self.LayerNorm(final_embeddings)
|
| 1028 |
+
|
| 1029 |
+
if mask is not None:
|
| 1030 |
+
if len(shape_list(mask)) != len(shape_list(final_embeddings)):
|
| 1031 |
+
if len(shape_list(mask)) == 4:
|
| 1032 |
+
mask = tf.squeeze(tf.squeeze(mask, axis=1), axis=1)
|
| 1033 |
+
mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)
|
| 1034 |
+
|
| 1035 |
+
final_embeddings = final_embeddings * mask
|
| 1036 |
+
|
| 1037 |
+
final_embeddings = self.dropout(final_embeddings, training=training)
|
| 1038 |
+
|
| 1039 |
+
return final_embeddings
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPredictionHeadTransform with Deberta->DebertaV2
|
| 1043 |
+
class TFDebertaV2PredictionHeadTransform(keras.layers.Layer):
|
| 1044 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 1045 |
+
super().__init__(**kwargs)
|
| 1046 |
+
|
| 1047 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 1048 |
+
|
| 1049 |
+
self.dense = keras.layers.Dense(
|
| 1050 |
+
units=self.embedding_size,
|
| 1051 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1052 |
+
name="dense",
|
| 1053 |
+
)
|
| 1054 |
+
|
| 1055 |
+
if isinstance(config.hidden_act, str):
|
| 1056 |
+
self.transform_act_fn = get_tf_activation(config.hidden_act)
|
| 1057 |
+
else:
|
| 1058 |
+
self.transform_act_fn = config.hidden_act
|
| 1059 |
+
self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
| 1060 |
+
self.config = config
|
| 1061 |
+
|
| 1062 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 1063 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 1064 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 1065 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 1066 |
+
|
| 1067 |
+
return hidden_states
|
| 1068 |
+
|
| 1069 |
+
def build(self, input_shape=None):
|
| 1070 |
+
if self.built:
|
| 1071 |
+
return
|
| 1072 |
+
self.built = True
|
| 1073 |
+
if getattr(self, "dense", None) is not None:
|
| 1074 |
+
with tf.name_scope(self.dense.name):
|
| 1075 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 1076 |
+
if getattr(self, "LayerNorm", None) is not None:
|
| 1077 |
+
with tf.name_scope(self.LayerNorm.name):
|
| 1078 |
+
self.LayerNorm.build([None, None, self.embedding_size])
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLMPredictionHead with Deberta->DebertaV2
|
| 1082 |
+
class TFDebertaV2LMPredictionHead(keras.layers.Layer):
|
| 1083 |
+
def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs):
|
| 1084 |
+
super().__init__(**kwargs)
|
| 1085 |
+
|
| 1086 |
+
self.config = config
|
| 1087 |
+
self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
|
| 1088 |
+
|
| 1089 |
+
self.transform = TFDebertaV2PredictionHeadTransform(config, name="transform")
|
| 1090 |
+
|
| 1091 |
+
# The output weights are the same as the input embeddings, but there is
|
| 1092 |
+
# an output-only bias for each token.
|
| 1093 |
+
self.input_embeddings = input_embeddings
|
| 1094 |
+
|
| 1095 |
+
def build(self, input_shape=None):
|
| 1096 |
+
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
| 1097 |
+
|
| 1098 |
+
if self.built:
|
| 1099 |
+
return
|
| 1100 |
+
self.built = True
|
| 1101 |
+
if getattr(self, "transform", None) is not None:
|
| 1102 |
+
with tf.name_scope(self.transform.name):
|
| 1103 |
+
self.transform.build(None)
|
| 1104 |
+
|
| 1105 |
+
def get_output_embeddings(self) -> keras.layers.Layer:
|
| 1106 |
+
return self.input_embeddings
|
| 1107 |
+
|
| 1108 |
+
def set_output_embeddings(self, value: tf.Variable):
|
| 1109 |
+
self.input_embeddings.weight = value
|
| 1110 |
+
self.input_embeddings.vocab_size = shape_list(value)[0]
|
| 1111 |
+
|
| 1112 |
+
def get_bias(self) -> Dict[str, tf.Variable]:
|
| 1113 |
+
return {"bias": self.bias}
|
| 1114 |
+
|
| 1115 |
+
def set_bias(self, value: tf.Variable):
|
| 1116 |
+
self.bias = value["bias"]
|
| 1117 |
+
self.config.vocab_size = shape_list(value["bias"])[0]
|
| 1118 |
+
|
| 1119 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 1120 |
+
hidden_states = self.transform(hidden_states=hidden_states)
|
| 1121 |
+
seq_length = shape_list(hidden_states)[1]
|
| 1122 |
+
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
|
| 1123 |
+
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
|
| 1124 |
+
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
|
| 1125 |
+
hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
|
| 1126 |
+
|
| 1127 |
+
return hidden_states
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOnlyMLMHead with Deberta->DebertaV2
|
| 1131 |
+
class TFDebertaV2OnlyMLMHead(keras.layers.Layer):
|
| 1132 |
+
def __init__(self, config: DebertaV2Config, input_embeddings: keras.layers.Layer, **kwargs):
|
| 1133 |
+
super().__init__(**kwargs)
|
| 1134 |
+
self.predictions = TFDebertaV2LMPredictionHead(config, input_embeddings, name="predictions")
|
| 1135 |
+
|
| 1136 |
+
def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
|
| 1137 |
+
prediction_scores = self.predictions(hidden_states=sequence_output)
|
| 1138 |
+
|
| 1139 |
+
return prediction_scores
|
| 1140 |
+
|
| 1141 |
+
def build(self, input_shape=None):
|
| 1142 |
+
if self.built:
|
| 1143 |
+
return
|
| 1144 |
+
self.built = True
|
| 1145 |
+
if getattr(self, "predictions", None) is not None:
|
| 1146 |
+
with tf.name_scope(self.predictions.name):
|
| 1147 |
+
self.predictions.build(None)
|
| 1148 |
+
|
| 1149 |
+
|
| 1150 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaMainLayer with Deberta->DebertaV2
|
| 1151 |
+
class TFDebertaV2MainLayer(keras.layers.Layer):
|
| 1152 |
+
config_class = DebertaV2Config
|
| 1153 |
+
|
| 1154 |
+
def __init__(self, config: DebertaV2Config, **kwargs):
|
| 1155 |
+
super().__init__(**kwargs)
|
| 1156 |
+
|
| 1157 |
+
self.config = config
|
| 1158 |
+
|
| 1159 |
+
self.embeddings = TFDebertaV2Embeddings(config, name="embeddings")
|
| 1160 |
+
self.encoder = TFDebertaV2Encoder(config, name="encoder")
|
| 1161 |
+
|
| 1162 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 1163 |
+
return self.embeddings
|
| 1164 |
+
|
| 1165 |
+
def set_input_embeddings(self, value: tf.Variable):
|
| 1166 |
+
self.embeddings.weight = value
|
| 1167 |
+
self.embeddings.vocab_size = shape_list(value)[0]
|
| 1168 |
+
|
| 1169 |
+
def _prune_heads(self, heads_to_prune):
|
| 1170 |
+
"""
|
| 1171 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 1172 |
+
class PreTrainedModel
|
| 1173 |
+
"""
|
| 1174 |
+
raise NotImplementedError
|
| 1175 |
+
|
| 1176 |
+
@unpack_inputs
|
| 1177 |
+
def call(
|
| 1178 |
+
self,
|
| 1179 |
+
input_ids: TFModelInputType | None = None,
|
| 1180 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1181 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1182 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1183 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1184 |
+
output_attentions: Optional[bool] = None,
|
| 1185 |
+
output_hidden_states: Optional[bool] = None,
|
| 1186 |
+
return_dict: Optional[bool] = None,
|
| 1187 |
+
training: bool = False,
|
| 1188 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 1189 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 1190 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 1191 |
+
elif input_ids is not None:
|
| 1192 |
+
input_shape = shape_list(input_ids)
|
| 1193 |
+
elif inputs_embeds is not None:
|
| 1194 |
+
input_shape = shape_list(inputs_embeds)[:-1]
|
| 1195 |
+
else:
|
| 1196 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 1197 |
+
|
| 1198 |
+
if attention_mask is None:
|
| 1199 |
+
attention_mask = tf.fill(dims=input_shape, value=1)
|
| 1200 |
+
|
| 1201 |
+
if token_type_ids is None:
|
| 1202 |
+
token_type_ids = tf.fill(dims=input_shape, value=0)
|
| 1203 |
+
|
| 1204 |
+
embedding_output = self.embeddings(
|
| 1205 |
+
input_ids=input_ids,
|
| 1206 |
+
position_ids=position_ids,
|
| 1207 |
+
token_type_ids=token_type_ids,
|
| 1208 |
+
inputs_embeds=inputs_embeds,
|
| 1209 |
+
mask=attention_mask,
|
| 1210 |
+
training=training,
|
| 1211 |
+
)
|
| 1212 |
+
|
| 1213 |
+
encoder_outputs = self.encoder(
|
| 1214 |
+
hidden_states=embedding_output,
|
| 1215 |
+
attention_mask=attention_mask,
|
| 1216 |
+
output_attentions=output_attentions,
|
| 1217 |
+
output_hidden_states=output_hidden_states,
|
| 1218 |
+
return_dict=return_dict,
|
| 1219 |
+
training=training,
|
| 1220 |
+
)
|
| 1221 |
+
|
| 1222 |
+
sequence_output = encoder_outputs[0]
|
| 1223 |
+
|
| 1224 |
+
if not return_dict:
|
| 1225 |
+
return (sequence_output,) + encoder_outputs[1:]
|
| 1226 |
+
|
| 1227 |
+
return TFBaseModelOutput(
|
| 1228 |
+
last_hidden_state=sequence_output,
|
| 1229 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1230 |
+
attentions=encoder_outputs.attentions,
|
| 1231 |
+
)
|
| 1232 |
+
|
| 1233 |
+
def build(self, input_shape=None):
|
| 1234 |
+
if self.built:
|
| 1235 |
+
return
|
| 1236 |
+
self.built = True
|
| 1237 |
+
if getattr(self, "embeddings", None) is not None:
|
| 1238 |
+
with tf.name_scope(self.embeddings.name):
|
| 1239 |
+
self.embeddings.build(None)
|
| 1240 |
+
if getattr(self, "encoder", None) is not None:
|
| 1241 |
+
with tf.name_scope(self.encoder.name):
|
| 1242 |
+
self.encoder.build(None)
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaPreTrainedModel with Deberta->DebertaV2
|
| 1246 |
+
class TFDebertaV2PreTrainedModel(TFPreTrainedModel):
|
| 1247 |
+
"""
|
| 1248 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 1249 |
+
models.
|
| 1250 |
+
"""
|
| 1251 |
+
|
| 1252 |
+
config_class = DebertaV2Config
|
| 1253 |
+
base_model_prefix = "deberta"
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
DEBERTA_START_DOCSTRING = r"""
|
| 1257 |
+
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
|
| 1258 |
+
Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
|
| 1259 |
+
on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
|
| 1260 |
+
improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
|
| 1261 |
+
|
| 1262 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 1263 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 1264 |
+
behavior.
|
| 1265 |
+
|
| 1266 |
+
<Tip>
|
| 1267 |
+
|
| 1268 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 1269 |
+
|
| 1270 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 1271 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 1272 |
+
|
| 1273 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 1274 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 1275 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 1276 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 1277 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 1278 |
+
positional argument:
|
| 1279 |
+
|
| 1280 |
+
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
|
| 1281 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 1282 |
+
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
|
| 1283 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 1284 |
+
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
|
| 1285 |
+
|
| 1286 |
+
Note that when creating models and layers with
|
| 1287 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 1288 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 1289 |
+
|
| 1290 |
+
</Tip>
|
| 1291 |
+
|
| 1292 |
+
Parameters:
|
| 1293 |
+
config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
|
| 1294 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1295 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1296 |
+
"""
|
| 1297 |
+
|
| 1298 |
+
DEBERTA_INPUTS_DOCSTRING = r"""
|
| 1299 |
+
Args:
|
| 1300 |
+
input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
|
| 1301 |
+
Indices of input sequence tokens in the vocabulary.
|
| 1302 |
+
|
| 1303 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1304 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1305 |
+
|
| 1306 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1307 |
+
attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
|
| 1308 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1309 |
+
|
| 1310 |
+
- 1 for tokens that are **not masked**,
|
| 1311 |
+
- 0 for tokens that are **masked**.
|
| 1312 |
+
|
| 1313 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1314 |
+
token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
|
| 1315 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 1316 |
+
1]`:
|
| 1317 |
+
|
| 1318 |
+
- 0 corresponds to a *sentence A* token,
|
| 1319 |
+
- 1 corresponds to a *sentence B* token.
|
| 1320 |
+
|
| 1321 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 1322 |
+
position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
|
| 1323 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 1324 |
+
config.max_position_embeddings - 1]`.
|
| 1325 |
+
|
| 1326 |
+
[What are position IDs?](../glossary#position-ids)
|
| 1327 |
+
inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
|
| 1328 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 1329 |
+
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
| 1330 |
+
model's internal embedding lookup matrix.
|
| 1331 |
+
output_attentions (`bool`, *optional*):
|
| 1332 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1333 |
+
tensors for more detail.
|
| 1334 |
+
output_hidden_states (`bool`, *optional*):
|
| 1335 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1336 |
+
more detail.
|
| 1337 |
+
return_dict (`bool`, *optional*):
|
| 1338 |
+
Whether or not to return a [`~utils.ModelOutput``] instead of a plain tuple.
|
| 1339 |
+
"""
|
| 1340 |
+
|
| 1341 |
+
|
| 1342 |
+
@add_start_docstrings(
|
| 1343 |
+
"The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1344 |
+
DEBERTA_START_DOCSTRING,
|
| 1345 |
+
)
|
| 1346 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaModel with Deberta->DebertaV2
|
| 1347 |
+
class TFDebertaV2Model(TFDebertaV2PreTrainedModel):
|
| 1348 |
+
def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
|
| 1349 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1350 |
+
|
| 1351 |
+
self.deberta = TFDebertaV2MainLayer(config, name="deberta")
|
| 1352 |
+
|
| 1353 |
+
@unpack_inputs
|
| 1354 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1355 |
+
@add_code_sample_docstrings(
|
| 1356 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1357 |
+
output_type=TFBaseModelOutput,
|
| 1358 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1359 |
+
)
|
| 1360 |
+
def call(
|
| 1361 |
+
self,
|
| 1362 |
+
input_ids: TFModelInputType | None = None,
|
| 1363 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1364 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1365 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1366 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1367 |
+
output_attentions: Optional[bool] = None,
|
| 1368 |
+
output_hidden_states: Optional[bool] = None,
|
| 1369 |
+
return_dict: Optional[bool] = None,
|
| 1370 |
+
training: Optional[bool] = False,
|
| 1371 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 1372 |
+
outputs = self.deberta(
|
| 1373 |
+
input_ids=input_ids,
|
| 1374 |
+
attention_mask=attention_mask,
|
| 1375 |
+
token_type_ids=token_type_ids,
|
| 1376 |
+
position_ids=position_ids,
|
| 1377 |
+
inputs_embeds=inputs_embeds,
|
| 1378 |
+
output_attentions=output_attentions,
|
| 1379 |
+
output_hidden_states=output_hidden_states,
|
| 1380 |
+
return_dict=return_dict,
|
| 1381 |
+
training=training,
|
| 1382 |
+
)
|
| 1383 |
+
|
| 1384 |
+
return outputs
|
| 1385 |
+
|
| 1386 |
+
def build(self, input_shape=None):
|
| 1387 |
+
if self.built:
|
| 1388 |
+
return
|
| 1389 |
+
self.built = True
|
| 1390 |
+
if getattr(self, "deberta", None) is not None:
|
| 1391 |
+
with tf.name_scope(self.deberta.name):
|
| 1392 |
+
self.deberta.build(None)
|
| 1393 |
+
|
| 1394 |
+
|
| 1395 |
+
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
| 1396 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForMaskedLM with Deberta->DebertaV2
|
| 1397 |
+
class TFDebertaV2ForMaskedLM(TFDebertaV2PreTrainedModel, TFMaskedLanguageModelingLoss):
|
| 1398 |
+
def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
|
| 1399 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1400 |
+
|
| 1401 |
+
if config.is_decoder:
|
| 1402 |
+
logger.warning(
|
| 1403 |
+
"If you want to use `TFDebertaV2ForMaskedLM` make sure `config.is_decoder=False` for "
|
| 1404 |
+
"bi-directional self-attention."
|
| 1405 |
+
)
|
| 1406 |
+
|
| 1407 |
+
self.deberta = TFDebertaV2MainLayer(config, name="deberta")
|
| 1408 |
+
self.mlm = TFDebertaV2OnlyMLMHead(config, input_embeddings=self.deberta.embeddings, name="cls")
|
| 1409 |
+
|
| 1410 |
+
def get_lm_head(self) -> keras.layers.Layer:
|
| 1411 |
+
return self.mlm.predictions
|
| 1412 |
+
|
| 1413 |
+
@unpack_inputs
|
| 1414 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1415 |
+
@add_code_sample_docstrings(
|
| 1416 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1417 |
+
output_type=TFMaskedLMOutput,
|
| 1418 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1419 |
+
)
|
| 1420 |
+
def call(
|
| 1421 |
+
self,
|
| 1422 |
+
input_ids: TFModelInputType | None = None,
|
| 1423 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1424 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1425 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1426 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1427 |
+
output_attentions: Optional[bool] = None,
|
| 1428 |
+
output_hidden_states: Optional[bool] = None,
|
| 1429 |
+
return_dict: Optional[bool] = None,
|
| 1430 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1431 |
+
training: Optional[bool] = False,
|
| 1432 |
+
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
|
| 1433 |
+
r"""
|
| 1434 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1435 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 1436 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 1437 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 1438 |
+
"""
|
| 1439 |
+
outputs = self.deberta(
|
| 1440 |
+
input_ids=input_ids,
|
| 1441 |
+
attention_mask=attention_mask,
|
| 1442 |
+
token_type_ids=token_type_ids,
|
| 1443 |
+
position_ids=position_ids,
|
| 1444 |
+
inputs_embeds=inputs_embeds,
|
| 1445 |
+
output_attentions=output_attentions,
|
| 1446 |
+
output_hidden_states=output_hidden_states,
|
| 1447 |
+
return_dict=return_dict,
|
| 1448 |
+
training=training,
|
| 1449 |
+
)
|
| 1450 |
+
sequence_output = outputs[0]
|
| 1451 |
+
prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
|
| 1452 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
|
| 1453 |
+
|
| 1454 |
+
if not return_dict:
|
| 1455 |
+
output = (prediction_scores,) + outputs[2:]
|
| 1456 |
+
return ((loss,) + output) if loss is not None else output
|
| 1457 |
+
|
| 1458 |
+
return TFMaskedLMOutput(
|
| 1459 |
+
loss=loss,
|
| 1460 |
+
logits=prediction_scores,
|
| 1461 |
+
hidden_states=outputs.hidden_states,
|
| 1462 |
+
attentions=outputs.attentions,
|
| 1463 |
+
)
|
| 1464 |
+
|
| 1465 |
+
def build(self, input_shape=None):
|
| 1466 |
+
if self.built:
|
| 1467 |
+
return
|
| 1468 |
+
self.built = True
|
| 1469 |
+
if getattr(self, "deberta", None) is not None:
|
| 1470 |
+
with tf.name_scope(self.deberta.name):
|
| 1471 |
+
self.deberta.build(None)
|
| 1472 |
+
if getattr(self, "mlm", None) is not None:
|
| 1473 |
+
with tf.name_scope(self.mlm.name):
|
| 1474 |
+
self.mlm.build(None)
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
@add_start_docstrings(
|
| 1478 |
+
"""
|
| 1479 |
+
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
| 1480 |
+
pooled output) e.g. for GLUE tasks.
|
| 1481 |
+
""",
|
| 1482 |
+
DEBERTA_START_DOCSTRING,
|
| 1483 |
+
)
|
| 1484 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForSequenceClassification with Deberta->DebertaV2
|
| 1485 |
+
class TFDebertaV2ForSequenceClassification(TFDebertaV2PreTrainedModel, TFSequenceClassificationLoss):
|
| 1486 |
+
def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
|
| 1487 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1488 |
+
|
| 1489 |
+
self.num_labels = config.num_labels
|
| 1490 |
+
|
| 1491 |
+
self.deberta = TFDebertaV2MainLayer(config, name="deberta")
|
| 1492 |
+
self.pooler = TFDebertaV2ContextPooler(config, name="pooler")
|
| 1493 |
+
|
| 1494 |
+
drop_out = getattr(config, "cls_dropout", None)
|
| 1495 |
+
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
|
| 1496 |
+
self.dropout = TFDebertaV2StableDropout(drop_out, name="cls_dropout")
|
| 1497 |
+
self.classifier = keras.layers.Dense(
|
| 1498 |
+
units=config.num_labels,
|
| 1499 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1500 |
+
name="classifier",
|
| 1501 |
+
)
|
| 1502 |
+
self.output_dim = self.pooler.output_dim
|
| 1503 |
+
|
| 1504 |
+
@unpack_inputs
|
| 1505 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1506 |
+
@add_code_sample_docstrings(
|
| 1507 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1508 |
+
output_type=TFSequenceClassifierOutput,
|
| 1509 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1510 |
+
)
|
| 1511 |
+
def call(
|
| 1512 |
+
self,
|
| 1513 |
+
input_ids: TFModelInputType | None = None,
|
| 1514 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1515 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1516 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1517 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1518 |
+
output_attentions: Optional[bool] = None,
|
| 1519 |
+
output_hidden_states: Optional[bool] = None,
|
| 1520 |
+
return_dict: Optional[bool] = None,
|
| 1521 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1522 |
+
training: Optional[bool] = False,
|
| 1523 |
+
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
| 1524 |
+
r"""
|
| 1525 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1526 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1527 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1528 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1529 |
+
"""
|
| 1530 |
+
outputs = self.deberta(
|
| 1531 |
+
input_ids=input_ids,
|
| 1532 |
+
attention_mask=attention_mask,
|
| 1533 |
+
token_type_ids=token_type_ids,
|
| 1534 |
+
position_ids=position_ids,
|
| 1535 |
+
inputs_embeds=inputs_embeds,
|
| 1536 |
+
output_attentions=output_attentions,
|
| 1537 |
+
output_hidden_states=output_hidden_states,
|
| 1538 |
+
return_dict=return_dict,
|
| 1539 |
+
training=training,
|
| 1540 |
+
)
|
| 1541 |
+
sequence_output = outputs[0]
|
| 1542 |
+
pooled_output = self.pooler(sequence_output, training=training)
|
| 1543 |
+
pooled_output = self.dropout(pooled_output, training=training)
|
| 1544 |
+
logits = self.classifier(pooled_output)
|
| 1545 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1546 |
+
|
| 1547 |
+
if not return_dict:
|
| 1548 |
+
output = (logits,) + outputs[1:]
|
| 1549 |
+
|
| 1550 |
+
return ((loss,) + output) if loss is not None else output
|
| 1551 |
+
|
| 1552 |
+
return TFSequenceClassifierOutput(
|
| 1553 |
+
loss=loss,
|
| 1554 |
+
logits=logits,
|
| 1555 |
+
hidden_states=outputs.hidden_states,
|
| 1556 |
+
attentions=outputs.attentions,
|
| 1557 |
+
)
|
| 1558 |
+
|
| 1559 |
+
def build(self, input_shape=None):
|
| 1560 |
+
if self.built:
|
| 1561 |
+
return
|
| 1562 |
+
self.built = True
|
| 1563 |
+
if getattr(self, "deberta", None) is not None:
|
| 1564 |
+
with tf.name_scope(self.deberta.name):
|
| 1565 |
+
self.deberta.build(None)
|
| 1566 |
+
if getattr(self, "pooler", None) is not None:
|
| 1567 |
+
with tf.name_scope(self.pooler.name):
|
| 1568 |
+
self.pooler.build(None)
|
| 1569 |
+
if getattr(self, "dropout", None) is not None:
|
| 1570 |
+
with tf.name_scope(self.dropout.name):
|
| 1571 |
+
self.dropout.build(None)
|
| 1572 |
+
if getattr(self, "classifier", None) is not None:
|
| 1573 |
+
with tf.name_scope(self.classifier.name):
|
| 1574 |
+
self.classifier.build([None, None, self.output_dim])
|
| 1575 |
+
|
| 1576 |
+
|
| 1577 |
+
@add_start_docstrings(
|
| 1578 |
+
"""
|
| 1579 |
+
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1580 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1581 |
+
""",
|
| 1582 |
+
DEBERTA_START_DOCSTRING,
|
| 1583 |
+
)
|
| 1584 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForTokenClassification with Deberta->DebertaV2
|
| 1585 |
+
class TFDebertaV2ForTokenClassification(TFDebertaV2PreTrainedModel, TFTokenClassificationLoss):
|
| 1586 |
+
def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
|
| 1587 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1588 |
+
|
| 1589 |
+
self.num_labels = config.num_labels
|
| 1590 |
+
|
| 1591 |
+
self.deberta = TFDebertaV2MainLayer(config, name="deberta")
|
| 1592 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 1593 |
+
self.classifier = keras.layers.Dense(
|
| 1594 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1595 |
+
)
|
| 1596 |
+
self.config = config
|
| 1597 |
+
|
| 1598 |
+
@unpack_inputs
|
| 1599 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1600 |
+
@add_code_sample_docstrings(
|
| 1601 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1602 |
+
output_type=TFTokenClassifierOutput,
|
| 1603 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1604 |
+
)
|
| 1605 |
+
def call(
|
| 1606 |
+
self,
|
| 1607 |
+
input_ids: TFModelInputType | None = None,
|
| 1608 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1609 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1610 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1611 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1612 |
+
output_attentions: Optional[bool] = None,
|
| 1613 |
+
output_hidden_states: Optional[bool] = None,
|
| 1614 |
+
return_dict: Optional[bool] = None,
|
| 1615 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1616 |
+
training: Optional[bool] = False,
|
| 1617 |
+
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
|
| 1618 |
+
r"""
|
| 1619 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1620 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1621 |
+
"""
|
| 1622 |
+
outputs = self.deberta(
|
| 1623 |
+
input_ids=input_ids,
|
| 1624 |
+
attention_mask=attention_mask,
|
| 1625 |
+
token_type_ids=token_type_ids,
|
| 1626 |
+
position_ids=position_ids,
|
| 1627 |
+
inputs_embeds=inputs_embeds,
|
| 1628 |
+
output_attentions=output_attentions,
|
| 1629 |
+
output_hidden_states=output_hidden_states,
|
| 1630 |
+
return_dict=return_dict,
|
| 1631 |
+
training=training,
|
| 1632 |
+
)
|
| 1633 |
+
sequence_output = outputs[0]
|
| 1634 |
+
sequence_output = self.dropout(sequence_output, training=training)
|
| 1635 |
+
logits = self.classifier(inputs=sequence_output)
|
| 1636 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1637 |
+
|
| 1638 |
+
if not return_dict:
|
| 1639 |
+
output = (logits,) + outputs[1:]
|
| 1640 |
+
return ((loss,) + output) if loss is not None else output
|
| 1641 |
+
|
| 1642 |
+
return TFTokenClassifierOutput(
|
| 1643 |
+
loss=loss,
|
| 1644 |
+
logits=logits,
|
| 1645 |
+
hidden_states=outputs.hidden_states,
|
| 1646 |
+
attentions=outputs.attentions,
|
| 1647 |
+
)
|
| 1648 |
+
|
| 1649 |
+
def build(self, input_shape=None):
|
| 1650 |
+
if self.built:
|
| 1651 |
+
return
|
| 1652 |
+
self.built = True
|
| 1653 |
+
if getattr(self, "deberta", None) is not None:
|
| 1654 |
+
with tf.name_scope(self.deberta.name):
|
| 1655 |
+
self.deberta.build(None)
|
| 1656 |
+
if getattr(self, "classifier", None) is not None:
|
| 1657 |
+
with tf.name_scope(self.classifier.name):
|
| 1658 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1659 |
+
|
| 1660 |
+
|
| 1661 |
+
@add_start_docstrings(
|
| 1662 |
+
"""
|
| 1663 |
+
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1664 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1665 |
+
""",
|
| 1666 |
+
DEBERTA_START_DOCSTRING,
|
| 1667 |
+
)
|
| 1668 |
+
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaForQuestionAnswering with Deberta->DebertaV2
|
| 1669 |
+
class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsweringLoss):
|
| 1670 |
+
def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
|
| 1671 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1672 |
+
|
| 1673 |
+
self.num_labels = config.num_labels
|
| 1674 |
+
|
| 1675 |
+
self.deberta = TFDebertaV2MainLayer(config, name="deberta")
|
| 1676 |
+
self.qa_outputs = keras.layers.Dense(
|
| 1677 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
| 1678 |
+
)
|
| 1679 |
+
self.config = config
|
| 1680 |
+
|
| 1681 |
+
@unpack_inputs
|
| 1682 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1683 |
+
@add_code_sample_docstrings(
|
| 1684 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1685 |
+
output_type=TFQuestionAnsweringModelOutput,
|
| 1686 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1687 |
+
)
|
| 1688 |
+
def call(
|
| 1689 |
+
self,
|
| 1690 |
+
input_ids: TFModelInputType | None = None,
|
| 1691 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1692 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1693 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1694 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1695 |
+
output_attentions: Optional[bool] = None,
|
| 1696 |
+
output_hidden_states: Optional[bool] = None,
|
| 1697 |
+
return_dict: Optional[bool] = None,
|
| 1698 |
+
start_positions: np.ndarray | tf.Tensor | None = None,
|
| 1699 |
+
end_positions: np.ndarray | tf.Tensor | None = None,
|
| 1700 |
+
training: Optional[bool] = False,
|
| 1701 |
+
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
|
| 1702 |
+
r"""
|
| 1703 |
+
start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1704 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1705 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1706 |
+
are not taken into account for computing the loss.
|
| 1707 |
+
end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1708 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1709 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1710 |
+
are not taken into account for computing the loss.
|
| 1711 |
+
"""
|
| 1712 |
+
outputs = self.deberta(
|
| 1713 |
+
input_ids=input_ids,
|
| 1714 |
+
attention_mask=attention_mask,
|
| 1715 |
+
token_type_ids=token_type_ids,
|
| 1716 |
+
position_ids=position_ids,
|
| 1717 |
+
inputs_embeds=inputs_embeds,
|
| 1718 |
+
output_attentions=output_attentions,
|
| 1719 |
+
output_hidden_states=output_hidden_states,
|
| 1720 |
+
return_dict=return_dict,
|
| 1721 |
+
training=training,
|
| 1722 |
+
)
|
| 1723 |
+
sequence_output = outputs[0]
|
| 1724 |
+
logits = self.qa_outputs(inputs=sequence_output)
|
| 1725 |
+
start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
|
| 1726 |
+
start_logits = tf.squeeze(input=start_logits, axis=-1)
|
| 1727 |
+
end_logits = tf.squeeze(input=end_logits, axis=-1)
|
| 1728 |
+
loss = None
|
| 1729 |
+
|
| 1730 |
+
if start_positions is not None and end_positions is not None:
|
| 1731 |
+
labels = {"start_position": start_positions}
|
| 1732 |
+
labels["end_position"] = end_positions
|
| 1733 |
+
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
|
| 1734 |
+
|
| 1735 |
+
if not return_dict:
|
| 1736 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 1737 |
+
return ((loss,) + output) if loss is not None else output
|
| 1738 |
+
|
| 1739 |
+
return TFQuestionAnsweringModelOutput(
|
| 1740 |
+
loss=loss,
|
| 1741 |
+
start_logits=start_logits,
|
| 1742 |
+
end_logits=end_logits,
|
| 1743 |
+
hidden_states=outputs.hidden_states,
|
| 1744 |
+
attentions=outputs.attentions,
|
| 1745 |
+
)
|
| 1746 |
+
|
| 1747 |
+
def build(self, input_shape=None):
|
| 1748 |
+
if self.built:
|
| 1749 |
+
return
|
| 1750 |
+
self.built = True
|
| 1751 |
+
if getattr(self, "deberta", None) is not None:
|
| 1752 |
+
with tf.name_scope(self.deberta.name):
|
| 1753 |
+
self.deberta.build(None)
|
| 1754 |
+
if getattr(self, "qa_outputs", None) is not None:
|
| 1755 |
+
with tf.name_scope(self.qa_outputs.name):
|
| 1756 |
+
self.qa_outputs.build([None, None, self.config.hidden_size])
|
| 1757 |
+
|
| 1758 |
+
|
| 1759 |
+
@add_start_docstrings(
|
| 1760 |
+
"""
|
| 1761 |
+
DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
| 1762 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
| 1763 |
+
""",
|
| 1764 |
+
DEBERTA_START_DOCSTRING,
|
| 1765 |
+
)
|
| 1766 |
+
class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceLoss):
|
| 1767 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1768 |
+
# _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
|
| 1769 |
+
# _keys_to_ignore_on_load_missing = [r"dropout"]
|
| 1770 |
+
|
| 1771 |
+
def __init__(self, config: DebertaV2Config, *inputs, **kwargs):
|
| 1772 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1773 |
+
|
| 1774 |
+
self.deberta = TFDebertaV2MainLayer(config, name="deberta")
|
| 1775 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 1776 |
+
self.pooler = TFDebertaV2ContextPooler(config, name="pooler")
|
| 1777 |
+
self.classifier = keras.layers.Dense(
|
| 1778 |
+
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1779 |
+
)
|
| 1780 |
+
self.output_dim = self.pooler.output_dim
|
| 1781 |
+
|
| 1782 |
+
@unpack_inputs
|
| 1783 |
+
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
| 1784 |
+
@add_code_sample_docstrings(
|
| 1785 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1786 |
+
output_type=TFMultipleChoiceModelOutput,
|
| 1787 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1788 |
+
)
|
| 1789 |
+
def call(
|
| 1790 |
+
self,
|
| 1791 |
+
input_ids: TFModelInputType | None = None,
|
| 1792 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1793 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1794 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1795 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1796 |
+
output_attentions: Optional[bool] = None,
|
| 1797 |
+
output_hidden_states: Optional[bool] = None,
|
| 1798 |
+
return_dict: Optional[bool] = None,
|
| 1799 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1800 |
+
training: Optional[bool] = False,
|
| 1801 |
+
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
|
| 1802 |
+
r"""
|
| 1803 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 1804 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
|
| 1805 |
+
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
|
| 1806 |
+
"""
|
| 1807 |
+
if input_ids is not None:
|
| 1808 |
+
num_choices = shape_list(input_ids)[1]
|
| 1809 |
+
seq_length = shape_list(input_ids)[2]
|
| 1810 |
+
else:
|
| 1811 |
+
num_choices = shape_list(inputs_embeds)[1]
|
| 1812 |
+
seq_length = shape_list(inputs_embeds)[2]
|
| 1813 |
+
|
| 1814 |
+
flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
|
| 1815 |
+
flat_attention_mask = (
|
| 1816 |
+
tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
|
| 1817 |
+
)
|
| 1818 |
+
flat_token_type_ids = (
|
| 1819 |
+
tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
|
| 1820 |
+
)
|
| 1821 |
+
flat_position_ids = (
|
| 1822 |
+
tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
|
| 1823 |
+
)
|
| 1824 |
+
flat_inputs_embeds = (
|
| 1825 |
+
tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
|
| 1826 |
+
if inputs_embeds is not None
|
| 1827 |
+
else None
|
| 1828 |
+
)
|
| 1829 |
+
outputs = self.deberta(
|
| 1830 |
+
input_ids=flat_input_ids,
|
| 1831 |
+
attention_mask=flat_attention_mask,
|
| 1832 |
+
token_type_ids=flat_token_type_ids,
|
| 1833 |
+
position_ids=flat_position_ids,
|
| 1834 |
+
inputs_embeds=flat_inputs_embeds,
|
| 1835 |
+
output_attentions=output_attentions,
|
| 1836 |
+
output_hidden_states=output_hidden_states,
|
| 1837 |
+
return_dict=return_dict,
|
| 1838 |
+
training=training,
|
| 1839 |
+
)
|
| 1840 |
+
sequence_output = outputs[0]
|
| 1841 |
+
pooled_output = self.pooler(sequence_output, training=training)
|
| 1842 |
+
pooled_output = self.dropout(pooled_output, training=training)
|
| 1843 |
+
logits = self.classifier(pooled_output)
|
| 1844 |
+
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
|
| 1845 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
|
| 1846 |
+
|
| 1847 |
+
if not return_dict:
|
| 1848 |
+
output = (reshaped_logits,) + outputs[2:]
|
| 1849 |
+
return ((loss,) + output) if loss is not None else output
|
| 1850 |
+
|
| 1851 |
+
return TFMultipleChoiceModelOutput(
|
| 1852 |
+
loss=loss,
|
| 1853 |
+
logits=reshaped_logits,
|
| 1854 |
+
hidden_states=outputs.hidden_states,
|
| 1855 |
+
attentions=outputs.attentions,
|
| 1856 |
+
)
|
| 1857 |
+
|
| 1858 |
+
def build(self, input_shape=None):
|
| 1859 |
+
if self.built:
|
| 1860 |
+
return
|
| 1861 |
+
self.built = True
|
| 1862 |
+
if getattr(self, "deberta", None) is not None:
|
| 1863 |
+
with tf.name_scope(self.deberta.name):
|
| 1864 |
+
self.deberta.build(None)
|
| 1865 |
+
if getattr(self, "pooler", None) is not None:
|
| 1866 |
+
with tf.name_scope(self.pooler.name):
|
| 1867 |
+
self.pooler.build(None)
|
| 1868 |
+
if getattr(self, "classifier", None) is not None:
|
| 1869 |
+
with tf.name_scope(self.classifier.name):
|
| 1870 |
+
self.classifier.build([None, None, self.output_dim])
|
| 1871 |
+
|
| 1872 |
+
|
| 1873 |
+
__all__ = [
|
| 1874 |
+
"TFDebertaV2ForMaskedLM",
|
| 1875 |
+
"TFDebertaV2ForQuestionAnswering",
|
| 1876 |
+
"TFDebertaV2ForMultipleChoice",
|
| 1877 |
+
"TFDebertaV2ForSequenceClassification",
|
| 1878 |
+
"TFDebertaV2ForTokenClassification",
|
| 1879 |
+
"TFDebertaV2Model",
|
| 1880 |
+
"TFDebertaV2PreTrainedModel",
|
| 1881 |
+
]
|