Leon299 commited on
Commit
fd1afc8
·
verified ·
1 Parent(s): 60f02df

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. __pycache__/inference_full.cpython-312.pyc +0 -0
  2. __pycache__/runtime_utils.cpython-312.pyc +0 -0
  3. batch_infer.sh +148 -0
  4. output_qwen3_plain_ar/checkpoint-17233/rng_state_18.pth +3 -0
  5. output_qwen3_plain_ar/checkpoint-17233/rng_state_19.pth +3 -0
  6. output_qwen3_plain_ar/checkpoint-17233/rng_state_2.pth +3 -0
  7. output_qwen3_plain_ar/checkpoint-17233/rng_state_20.pth +3 -0
  8. output_qwen3_plain_ar/checkpoint-17233/rng_state_21.pth +3 -0
  9. output_qwen3_plain_ar/checkpoint-17233/rng_state_22.pth +3 -0
  10. output_qwen3_plain_ar/checkpoint-17233/rng_state_23.pth +3 -0
  11. output_qwen3_plain_ar/checkpoint-17233/rng_state_24.pth +3 -0
  12. output_qwen3_plain_ar/checkpoint-17233/rng_state_25.pth +3 -0
  13. output_qwen3_plain_ar/checkpoint-17233/rng_state_26.pth +3 -0
  14. output_qwen3_plain_ar/checkpoint-17233/rng_state_27.pth +3 -0
  15. output_qwen3_plain_ar/checkpoint-18140/trainer_state.json +0 -0
  16. output_qwen3_plain_ar/checkpoint-18140/zero_to_fp32.py +760 -0
  17. output_qwen3_plain_ar/checkpoint-2721/config.json +66 -0
  18. output_qwen3_plain_ar/checkpoint-2721/generation_config.json +13 -0
  19. output_qwen3_plain_ar/checkpoint-2721/latest +1 -0
  20. output_qwen3_plain_ar/checkpoint-2721/trainer_state.json +1938 -0
  21. output_qwen3_plain_ar/checkpoint-2721/zero_to_fp32.py +760 -0
  22. output_qwen3_plain_ar/checkpoint-3628/config.json +66 -0
  23. output_qwen3_plain_ar/checkpoint-3628/generation_config.json +13 -0
  24. output_qwen3_plain_ar/checkpoint-3628/latest +1 -0
  25. output_qwen3_plain_ar/checkpoint-3628/trainer_state.json +2568 -0
  26. output_qwen3_plain_ar/checkpoint-3628/zero_to_fp32.py +760 -0
  27. output_qwen3_plain_ar/checkpoint-4535/config.json +66 -0
  28. output_qwen3_plain_ar/checkpoint-4535/generation_config.json +13 -0
  29. output_qwen3_plain_ar/checkpoint-4535/latest +1 -0
  30. output_qwen3_plain_ar/checkpoint-4535/trainer_state.json +3205 -0
  31. output_qwen3_plain_ar/checkpoint-4535/zero_to_fp32.py +760 -0
  32. output_qwen3_plain_ar/checkpoint-5442/config.json +66 -0
  33. output_qwen3_plain_ar/checkpoint-5442/generation_config.json +13 -0
  34. output_qwen3_plain_ar/checkpoint-5442/latest +1 -0
  35. output_qwen3_plain_ar/checkpoint-5442/trainer_state.json +0 -0
  36. output_qwen3_plain_ar/checkpoint-5442/zero_to_fp32.py +760 -0
  37. output_qwen3_plain_ar/checkpoint-6349/config.json +66 -0
  38. output_qwen3_plain_ar/checkpoint-6349/generation_config.json +13 -0
  39. output_qwen3_plain_ar/checkpoint-6349/latest +1 -0
  40. output_qwen3_plain_ar/checkpoint-6349/trainer_state.json +0 -0
  41. output_qwen3_plain_ar/checkpoint-6349/zero_to_fp32.py +760 -0
  42. output_qwen3_plain_ar/checkpoint-7256/config.json +66 -0
  43. output_qwen3_plain_ar/checkpoint-7256/generation_config.json +13 -0
  44. output_qwen3_plain_ar/checkpoint-7256/latest +1 -0
  45. output_qwen3_plain_ar/checkpoint-7256/trainer_state.json +0 -0
  46. output_qwen3_plain_ar/checkpoint-7256/zero_to_fp32.py +760 -0
  47. output_qwen3_plain_ar/checkpoint-8163/config.json +66 -0
  48. output_qwen3_plain_ar/checkpoint-8163/generation_config.json +13 -0
  49. output_qwen3_plain_ar/checkpoint-8163/latest +1 -0
  50. output_qwen3_plain_ar/checkpoint-8163/trainer_state.json +0 -0
__pycache__/inference_full.cpython-312.pyc CHANGED
Binary files a/__pycache__/inference_full.cpython-312.pyc and b/__pycache__/inference_full.cpython-312.pyc differ
 
__pycache__/runtime_utils.cpython-312.pyc CHANGED
Binary files a/__pycache__/runtime_utils.cpython-312.pyc and b/__pycache__/runtime_utils.cpython-312.pyc differ
 
batch_infer.sh ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ set -uo pipefail
4
+
5
+ ########################################
6
+ # 配置区(你只需要改这里)
7
+ ########################################
8
+
9
+ SCRIPT_PATH="qwen3_plain_ar.py"
10
+
11
+ DATASET_PATH="muse_mucodec_chord.ds"
12
+
13
+ # tokenizer(必须是带 chat_template 的)
14
+ TOKENIZER_PATH="/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/final"
15
+
16
+ # checkpoint 列表
17
+ CHECKPOINTS=(
18
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-907"
19
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-1814"
20
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-2721"
21
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-3628"
22
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-4535"
23
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-5442"
24
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-6349"
25
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-7256"
26
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-8163"
27
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-9070"
28
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-9977"
29
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-10884"
30
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-11791"
31
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-12698"
32
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-13605"
33
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-14512"
34
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-15419"
35
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-16326"
36
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-17233"
37
+ "/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/checkpoint-18140"
38
+ )
39
+
40
+ # 输出根目录
41
+ OUTPUT_ROOT="/root/batch_preditions_ablation"
42
+
43
+ # 每个 checkpoint 推理多少条
44
+ NUM_SAMPLES=20
45
+
46
+ ########################################
47
+ # 推理参数(可以调)
48
+ ########################################
49
+
50
+ DEVICE="cuda:0"
51
+ DTYPE="bfloat16"
52
+ ATTN_IMPLEMENTATION="sdpa"
53
+
54
+ TEMPERATURE=1.0
55
+ TOP_K=50
56
+ TOP_P=0.9
57
+
58
+ MAX_NEW_TOKENS=4096
59
+
60
+ # 是否跳过音频解码(调试建议先开)
61
+ SKIP_DECODE=false
62
+
63
+ ########################################
64
+ # 日志文件
65
+ ########################################
66
+
67
+ FAILED_LOG="${OUTPUT_ROOT}/failed_cases.log"
68
+ SUCCESS_LOG="${OUTPUT_ROOT}/success_cases.log"
69
+
70
+ ########################################
71
+ # 开始执行
72
+ ########################################
73
+
74
+ mkdir -p "${OUTPUT_ROOT}"
75
+ touch "${FAILED_LOG}"
76
+ touch "${SUCCESS_LOG}"
77
+
78
+ echo "======================================" | tee -a "${SUCCESS_LOG}"
79
+ echo "Batch inference started at $(date)" | tee -a "${SUCCESS_LOG}"
80
+ echo "Output root: ${OUTPUT_ROOT}" | tee -a "${SUCCESS_LOG}"
81
+ echo "======================================" | tee -a "${SUCCESS_LOG}"
82
+
83
+ for CKPT in "${CHECKPOINTS[@]}"; do
84
+ CKPT_NAME=$(basename "${CKPT}")
85
+ OUT_DIR="${OUTPUT_ROOT}/${CKPT_NAME}"
86
+ CKPT_LOG="${OUT_DIR}/run.log"
87
+
88
+ echo "======================================"
89
+ echo "Running checkpoint: ${CKPT_NAME}"
90
+ echo "Output dir: ${OUT_DIR}"
91
+ echo "======================================"
92
+
93
+ if [ ! -d "${CKPT}" ]; then
94
+ echo "[ERROR] checkpoint directory not found: ${CKPT}" | tee -a "${FAILED_LOG}"
95
+ continue
96
+ fi
97
+
98
+ mkdir -p "${OUT_DIR}"
99
+ touch "${CKPT_LOG}"
100
+
101
+ for ((i=0; i<NUM_SAMPLES; i++)); do
102
+ echo "[INFO] checkpoint=${CKPT_NAME} sample_idx=${i}" | tee -a "${CKPT_LOG}"
103
+
104
+ CMD=(
105
+ python "${SCRIPT_PATH}" infer
106
+ --model_path "${CKPT}"
107
+ --tokenizer_path "${TOKENIZER_PATH}"
108
+ --dataset_path "${DATASET_PATH}"
109
+ --split validation
110
+ --sample_idx "${i}"
111
+ --device "${DEVICE}"
112
+ --dtype "${DTYPE}"
113
+ --attn_implementation "${ATTN_IMPLEMENTATION}"
114
+ --temperature "${TEMPERATURE}"
115
+ --top_k "${TOP_K}"
116
+ --top_p "${TOP_P}"
117
+ --max_new_tokens_per_section "${MAX_NEW_TOKENS}"
118
+ --output_dir "${OUT_DIR}"
119
+ --output_prefix "sample_${i}"
120
+ )
121
+
122
+ if [ "${SKIP_DECODE}" = true ]; then
123
+ CMD+=(--skip_decode)
124
+ fi
125
+
126
+ {
127
+ echo "[CMD] ${CMD[*]}"
128
+ "${CMD[@]}"
129
+ } >> "${CKPT_LOG}" 2>&1
130
+
131
+ EXIT_CODE=$?
132
+
133
+ if [ ${EXIT_CODE} -ne 0 ]; then
134
+ echo "[ERROR] checkpoint=${CKPT_NAME} sample_idx=${i} exit_code=${EXIT_CODE}" | tee -a "${FAILED_LOG}"
135
+ continue
136
+ else
137
+ echo "[OK] checkpoint=${CKPT_NAME} sample_idx=${i}" | tee -a "${SUCCESS_LOG}"
138
+ fi
139
+ done
140
+
141
+ echo "[DONE] checkpoint=${CKPT_NAME}" | tee -a "${SUCCESS_LOG}"
142
+ done
143
+
144
+ echo "======================================" | tee -a "${SUCCESS_LOG}"
145
+ echo "Batch inference finished at $(date)" | tee -a "${SUCCESS_LOG}"
146
+ echo "Success log: ${SUCCESS_LOG}" | tee -a "${SUCCESS_LOG}"
147
+ echo "Failed log: ${FAILED_LOG}" | tee -a "${SUCCESS_LOG}"
148
+ echo "All done."
output_qwen3_plain_ar/checkpoint-17233/rng_state_18.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd17b3bc8809e659c44ad5767c09471365bf1aaf99af587ed6ceb8212a83647f
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_19.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48baee25b7a07901dc04101212281542a2c94ba2e39212de0833d10d73bcff15
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab8c0ea40d6071acfd801325ef5fce06795f95b25fd7cf033726000b4174406a
3
+ size 16325
output_qwen3_plain_ar/checkpoint-17233/rng_state_20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91b4d0f639063a020849c066d4bdf38ac11a72c2fb705b5d36090aaec746bc0c
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_21.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1f3338bc26fbbe65cb13f269b236abdee71e69af31f680d19dd714839c3cb60
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_22.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9190d2dd67803c301fed8e4bb7793ee439602ceff38c5fd9b3ddfa7f8b3ebee1
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_23.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcc224b33040d41584dd3cbdfb1bca14ca82f95e91c4ce8da7d1fa5af39fb996
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_24.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b75d0da8e588491425dfede07d4fcf9a4236407463ca5cdf5cdf29b1a8fca5d1
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_25.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f509a161c9a3c4078f25919a5284e639b26b17196b37c50b6359eb10892a6477
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_26.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa1a5e34bfbf6b2ef864a91a119b55c4972c987b01ac5ce6edbd95a8dbfaf56b
3
+ size 16340
output_qwen3_plain_ar/checkpoint-17233/rng_state_27.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9105492960e0409542231deb8d2dc0148b7dbd7179c17107bf77eab0f18f32f
3
+ size 16340
output_qwen3_plain_ar/checkpoint-18140/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
output_qwen3_plain_ar/checkpoint-18140/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-2721/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-2721/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-2721/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step2721
output_qwen3_plain_ar/checkpoint-2721/trainer_state.json ADDED
@@ -0,0 +1,1938 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 3.0,
6
+ "eval_steps": 500,
7
+ "global_step": 2721,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.011028398125172319,
14
+ "grad_norm": 435.2422180175781,
15
+ "learning_rate": 9e-07,
16
+ "loss": 20.84569549560547,
17
+ "step": 10
18
+ },
19
+ {
20
+ "epoch": 0.022056796250344637,
21
+ "grad_norm": 141.7341766357422,
22
+ "learning_rate": 1.9e-06,
23
+ "loss": 18.69615936279297,
24
+ "step": 20
25
+ },
26
+ {
27
+ "epoch": 0.033085194375516956,
28
+ "grad_norm": 74.42520904541016,
29
+ "learning_rate": 2.9e-06,
30
+ "loss": 16.079673767089844,
31
+ "step": 30
32
+ },
33
+ {
34
+ "epoch": 0.044113592500689275,
35
+ "grad_norm": 24.73248863220215,
36
+ "learning_rate": 3.9e-06,
37
+ "loss": 13.684315490722657,
38
+ "step": 40
39
+ },
40
+ {
41
+ "epoch": 0.055141990625861594,
42
+ "grad_norm": 7.049101829528809,
43
+ "learning_rate": 4.9000000000000005e-06,
44
+ "loss": 12.474874877929688,
45
+ "step": 50
46
+ },
47
+ {
48
+ "epoch": 0.06617038875103391,
49
+ "grad_norm": 2.3411474227905273,
50
+ "learning_rate": 5.9e-06,
51
+ "loss": 12.072142028808594,
52
+ "step": 60
53
+ },
54
+ {
55
+ "epoch": 0.07719878687620624,
56
+ "grad_norm": 1.126215934753418,
57
+ "learning_rate": 6.900000000000001e-06,
58
+ "loss": 11.938906860351562,
59
+ "step": 70
60
+ },
61
+ {
62
+ "epoch": 0.08822718500137855,
63
+ "grad_norm": 1.2050226926803589,
64
+ "learning_rate": 7.9e-06,
65
+ "loss": 11.81988296508789,
66
+ "step": 80
67
+ },
68
+ {
69
+ "epoch": 0.09925558312655088,
70
+ "grad_norm": 1.444793462753296,
71
+ "learning_rate": 8.9e-06,
72
+ "loss": 11.602033996582032,
73
+ "step": 90
74
+ },
75
+ {
76
+ "epoch": 0.11028398125172319,
77
+ "grad_norm": 5.791665077209473,
78
+ "learning_rate": 9.900000000000002e-06,
79
+ "loss": 11.201815032958985,
80
+ "step": 100
81
+ },
82
+ {
83
+ "epoch": 0.12131237937689551,
84
+ "grad_norm": 9.492277145385742,
85
+ "learning_rate": 1.09e-05,
86
+ "loss": 10.535708618164062,
87
+ "step": 110
88
+ },
89
+ {
90
+ "epoch": 0.13234077750206782,
91
+ "grad_norm": 2.7546133995056152,
92
+ "learning_rate": 1.19e-05,
93
+ "loss": 9.847169494628906,
94
+ "step": 120
95
+ },
96
+ {
97
+ "epoch": 0.14336917562724014,
98
+ "grad_norm": 1.0953313112258911,
99
+ "learning_rate": 1.29e-05,
100
+ "loss": 9.429026031494141,
101
+ "step": 130
102
+ },
103
+ {
104
+ "epoch": 0.15439757375241248,
105
+ "grad_norm": 0.7153559327125549,
106
+ "learning_rate": 1.3900000000000002e-05,
107
+ "loss": 9.266969299316406,
108
+ "step": 140
109
+ },
110
+ {
111
+ "epoch": 0.1654259718775848,
112
+ "grad_norm": 0.5888933539390564,
113
+ "learning_rate": 1.49e-05,
114
+ "loss": 9.1935546875,
115
+ "step": 150
116
+ },
117
+ {
118
+ "epoch": 0.1764543700027571,
119
+ "grad_norm": 0.4850365221500397,
120
+ "learning_rate": 1.59e-05,
121
+ "loss": 9.19604034423828,
122
+ "step": 160
123
+ },
124
+ {
125
+ "epoch": 0.1874827681279294,
126
+ "grad_norm": 0.5772538185119629,
127
+ "learning_rate": 1.69e-05,
128
+ "loss": 9.17010726928711,
129
+ "step": 170
130
+ },
131
+ {
132
+ "epoch": 0.19851116625310175,
133
+ "grad_norm": 0.4283920228481293,
134
+ "learning_rate": 1.79e-05,
135
+ "loss": 9.172830200195312,
136
+ "step": 180
137
+ },
138
+ {
139
+ "epoch": 0.20953956437827406,
140
+ "grad_norm": 0.8650698065757751,
141
+ "learning_rate": 1.8900000000000002e-05,
142
+ "loss": 9.154988098144532,
143
+ "step": 190
144
+ },
145
+ {
146
+ "epoch": 0.22056796250344637,
147
+ "grad_norm": 0.42017608880996704,
148
+ "learning_rate": 1.9900000000000003e-05,
149
+ "loss": 9.146849060058594,
150
+ "step": 200
151
+ },
152
+ {
153
+ "epoch": 0.23159636062861869,
154
+ "grad_norm": 0.9125994443893433,
155
+ "learning_rate": 2.09e-05,
156
+ "loss": 9.164442443847657,
157
+ "step": 210
158
+ },
159
+ {
160
+ "epoch": 0.24262475875379103,
161
+ "grad_norm": 0.6468876004219055,
162
+ "learning_rate": 2.19e-05,
163
+ "loss": 9.159596252441407,
164
+ "step": 220
165
+ },
166
+ {
167
+ "epoch": 0.25365315687896334,
168
+ "grad_norm": 0.4124819338321686,
169
+ "learning_rate": 2.29e-05,
170
+ "loss": 9.13860626220703,
171
+ "step": 230
172
+ },
173
+ {
174
+ "epoch": 0.26468155500413565,
175
+ "grad_norm": 1.990302562713623,
176
+ "learning_rate": 2.39e-05,
177
+ "loss": 9.145040893554688,
178
+ "step": 240
179
+ },
180
+ {
181
+ "epoch": 0.27570995312930796,
182
+ "grad_norm": 0.7875277400016785,
183
+ "learning_rate": 2.4900000000000002e-05,
184
+ "loss": 9.152925109863281,
185
+ "step": 250
186
+ },
187
+ {
188
+ "epoch": 0.2867383512544803,
189
+ "grad_norm": 0.8343706130981445,
190
+ "learning_rate": 2.5900000000000003e-05,
191
+ "loss": 9.132975769042968,
192
+ "step": 260
193
+ },
194
+ {
195
+ "epoch": 0.2977667493796526,
196
+ "grad_norm": 3.00996470451355,
197
+ "learning_rate": 2.6900000000000003e-05,
198
+ "loss": 9.097848510742187,
199
+ "step": 270
200
+ },
201
+ {
202
+ "epoch": 0.30879514750482495,
203
+ "grad_norm": 2.4282069206237793,
204
+ "learning_rate": 2.7900000000000004e-05,
205
+ "loss": 9.042235565185546,
206
+ "step": 280
207
+ },
208
+ {
209
+ "epoch": 0.31982354562999726,
210
+ "grad_norm": 4.171019554138184,
211
+ "learning_rate": 2.8899999999999998e-05,
212
+ "loss": 8.927298736572265,
213
+ "step": 290
214
+ },
215
+ {
216
+ "epoch": 0.3308519437551696,
217
+ "grad_norm": 2.197887659072876,
218
+ "learning_rate": 2.9900000000000002e-05,
219
+ "loss": 8.805252075195312,
220
+ "step": 300
221
+ },
222
+ {
223
+ "epoch": 0.3418803418803419,
224
+ "grad_norm": 10.306541442871094,
225
+ "learning_rate": 3.09e-05,
226
+ "loss": 8.673678588867187,
227
+ "step": 310
228
+ },
229
+ {
230
+ "epoch": 0.3529087400055142,
231
+ "grad_norm": 8.463860511779785,
232
+ "learning_rate": 3.19e-05,
233
+ "loss": 8.570347595214844,
234
+ "step": 320
235
+ },
236
+ {
237
+ "epoch": 0.3639371381306865,
238
+ "grad_norm": 3.999753475189209,
239
+ "learning_rate": 3.29e-05,
240
+ "loss": 8.429109191894531,
241
+ "step": 330
242
+ },
243
+ {
244
+ "epoch": 0.3749655362558588,
245
+ "grad_norm": 5.259007930755615,
246
+ "learning_rate": 3.3900000000000004e-05,
247
+ "loss": 8.334149169921876,
248
+ "step": 340
249
+ },
250
+ {
251
+ "epoch": 0.38599393438103113,
252
+ "grad_norm": 8.362598419189453,
253
+ "learning_rate": 3.49e-05,
254
+ "loss": 8.196139526367187,
255
+ "step": 350
256
+ },
257
+ {
258
+ "epoch": 0.3970223325062035,
259
+ "grad_norm": 10.273512840270996,
260
+ "learning_rate": 3.59e-05,
261
+ "loss": 8.040153503417969,
262
+ "step": 360
263
+ },
264
+ {
265
+ "epoch": 0.4080507306313758,
266
+ "grad_norm": 5.111108303070068,
267
+ "learning_rate": 3.69e-05,
268
+ "loss": 7.866473388671875,
269
+ "step": 370
270
+ },
271
+ {
272
+ "epoch": 0.4190791287565481,
273
+ "grad_norm": 9.192107200622559,
274
+ "learning_rate": 3.79e-05,
275
+ "loss": 7.695774841308594,
276
+ "step": 380
277
+ },
278
+ {
279
+ "epoch": 0.43010752688172044,
280
+ "grad_norm": 5.393336772918701,
281
+ "learning_rate": 3.8900000000000004e-05,
282
+ "loss": 7.498152160644532,
283
+ "step": 390
284
+ },
285
+ {
286
+ "epoch": 0.44113592500689275,
287
+ "grad_norm": 10.53490161895752,
288
+ "learning_rate": 3.99e-05,
289
+ "loss": 7.270246887207032,
290
+ "step": 400
291
+ },
292
+ {
293
+ "epoch": 0.45216432313206506,
294
+ "grad_norm": 6.174643516540527,
295
+ "learning_rate": 4.09e-05,
296
+ "loss": 7.127191162109375,
297
+ "step": 410
298
+ },
299
+ {
300
+ "epoch": 0.46319272125723737,
301
+ "grad_norm": 4.522936820983887,
302
+ "learning_rate": 4.19e-05,
303
+ "loss": 6.871500396728516,
304
+ "step": 420
305
+ },
306
+ {
307
+ "epoch": 0.4742211193824097,
308
+ "grad_norm": 4.3594207763671875,
309
+ "learning_rate": 4.29e-05,
310
+ "loss": 6.702586364746094,
311
+ "step": 430
312
+ },
313
+ {
314
+ "epoch": 0.48524951750758205,
315
+ "grad_norm": 5.950730323791504,
316
+ "learning_rate": 4.39e-05,
317
+ "loss": 6.493560791015625,
318
+ "step": 440
319
+ },
320
+ {
321
+ "epoch": 0.49627791563275436,
322
+ "grad_norm": 6.233413219451904,
323
+ "learning_rate": 4.49e-05,
324
+ "loss": 6.293489074707031,
325
+ "step": 450
326
+ },
327
+ {
328
+ "epoch": 0.5073063137579267,
329
+ "grad_norm": 7.656834125518799,
330
+ "learning_rate": 4.5900000000000004e-05,
331
+ "loss": 6.102347946166992,
332
+ "step": 460
333
+ },
334
+ {
335
+ "epoch": 0.518334711883099,
336
+ "grad_norm": 4.319094657897949,
337
+ "learning_rate": 4.69e-05,
338
+ "loss": 5.928083419799805,
339
+ "step": 470
340
+ },
341
+ {
342
+ "epoch": 0.5293631100082713,
343
+ "grad_norm": 5.585537433624268,
344
+ "learning_rate": 4.79e-05,
345
+ "loss": 5.77436637878418,
346
+ "step": 480
347
+ },
348
+ {
349
+ "epoch": 0.5403915081334436,
350
+ "grad_norm": 5.104014873504639,
351
+ "learning_rate": 4.89e-05,
352
+ "loss": 5.636859130859375,
353
+ "step": 490
354
+ },
355
+ {
356
+ "epoch": 0.5514199062586159,
357
+ "grad_norm": 5.453028202056885,
358
+ "learning_rate": 4.99e-05,
359
+ "loss": 5.507636260986328,
360
+ "step": 500
361
+ },
362
+ {
363
+ "epoch": 0.5624483043837882,
364
+ "grad_norm": 7.728854179382324,
365
+ "learning_rate": 5.0900000000000004e-05,
366
+ "loss": 5.411964416503906,
367
+ "step": 510
368
+ },
369
+ {
370
+ "epoch": 0.5734767025089605,
371
+ "grad_norm": 4.50288724899292,
372
+ "learning_rate": 5.19e-05,
373
+ "loss": 5.295291900634766,
374
+ "step": 520
375
+ },
376
+ {
377
+ "epoch": 0.5845051006341329,
378
+ "grad_norm": 4.245919704437256,
379
+ "learning_rate": 5.2900000000000005e-05,
380
+ "loss": 5.194162750244141,
381
+ "step": 530
382
+ },
383
+ {
384
+ "epoch": 0.5955334987593052,
385
+ "grad_norm": 6.278975963592529,
386
+ "learning_rate": 5.390000000000001e-05,
387
+ "loss": 5.113618087768555,
388
+ "step": 540
389
+ },
390
+ {
391
+ "epoch": 0.6065618968844775,
392
+ "grad_norm": 4.214662075042725,
393
+ "learning_rate": 5.4900000000000006e-05,
394
+ "loss": 5.038372039794922,
395
+ "step": 550
396
+ },
397
+ {
398
+ "epoch": 0.6175902950096499,
399
+ "grad_norm": 3.5404605865478516,
400
+ "learning_rate": 5.590000000000001e-05,
401
+ "loss": 4.935391235351562,
402
+ "step": 560
403
+ },
404
+ {
405
+ "epoch": 0.6286186931348222,
406
+ "grad_norm": 3.6460280418395996,
407
+ "learning_rate": 5.69e-05,
408
+ "loss": 4.896538543701172,
409
+ "step": 570
410
+ },
411
+ {
412
+ "epoch": 0.6396470912599945,
413
+ "grad_norm": 5.254800796508789,
414
+ "learning_rate": 5.79e-05,
415
+ "loss": 4.829419708251953,
416
+ "step": 580
417
+ },
418
+ {
419
+ "epoch": 0.6506754893851668,
420
+ "grad_norm": 5.132180690765381,
421
+ "learning_rate": 5.89e-05,
422
+ "loss": 4.793368148803711,
423
+ "step": 590
424
+ },
425
+ {
426
+ "epoch": 0.6617038875103392,
427
+ "grad_norm": 4.222960948944092,
428
+ "learning_rate": 5.99e-05,
429
+ "loss": 4.746239852905274,
430
+ "step": 600
431
+ },
432
+ {
433
+ "epoch": 0.6727322856355115,
434
+ "grad_norm": 4.070414066314697,
435
+ "learning_rate": 6.09e-05,
436
+ "loss": 4.688523864746093,
437
+ "step": 610
438
+ },
439
+ {
440
+ "epoch": 0.6837606837606838,
441
+ "grad_norm": 3.4652583599090576,
442
+ "learning_rate": 6.19e-05,
443
+ "loss": 4.692922973632813,
444
+ "step": 620
445
+ },
446
+ {
447
+ "epoch": 0.6947890818858561,
448
+ "grad_norm": 4.559128284454346,
449
+ "learning_rate": 6.29e-05,
450
+ "loss": 4.639920043945312,
451
+ "step": 630
452
+ },
453
+ {
454
+ "epoch": 0.7058174800110284,
455
+ "grad_norm": 3.197758436203003,
456
+ "learning_rate": 6.390000000000001e-05,
457
+ "loss": 4.601907348632812,
458
+ "step": 640
459
+ },
460
+ {
461
+ "epoch": 0.7168458781362007,
462
+ "grad_norm": 4.209578514099121,
463
+ "learning_rate": 6.49e-05,
464
+ "loss": 4.56639404296875,
465
+ "step": 650
466
+ },
467
+ {
468
+ "epoch": 0.727874276261373,
469
+ "grad_norm": 3.701484203338623,
470
+ "learning_rate": 6.59e-05,
471
+ "loss": 4.545608901977539,
472
+ "step": 660
473
+ },
474
+ {
475
+ "epoch": 0.7389026743865453,
476
+ "grad_norm": 3.951927900314331,
477
+ "learning_rate": 6.690000000000001e-05,
478
+ "loss": 4.493326187133789,
479
+ "step": 670
480
+ },
481
+ {
482
+ "epoch": 0.7499310725117176,
483
+ "grad_norm": 4.219130039215088,
484
+ "learning_rate": 6.790000000000001e-05,
485
+ "loss": 4.482691955566406,
486
+ "step": 680
487
+ },
488
+ {
489
+ "epoch": 0.76095947063689,
490
+ "grad_norm": 6.267204284667969,
491
+ "learning_rate": 6.89e-05,
492
+ "loss": 4.4599052429199215,
493
+ "step": 690
494
+ },
495
+ {
496
+ "epoch": 0.7719878687620623,
497
+ "grad_norm": 3.367382764816284,
498
+ "learning_rate": 6.99e-05,
499
+ "loss": 4.429808807373047,
500
+ "step": 700
501
+ },
502
+ {
503
+ "epoch": 0.7830162668872346,
504
+ "grad_norm": 3.8906455039978027,
505
+ "learning_rate": 7.09e-05,
506
+ "loss": 4.4144752502441404,
507
+ "step": 710
508
+ },
509
+ {
510
+ "epoch": 0.794044665012407,
511
+ "grad_norm": 6.759398460388184,
512
+ "learning_rate": 7.19e-05,
513
+ "loss": 4.385488891601563,
514
+ "step": 720
515
+ },
516
+ {
517
+ "epoch": 0.8050730631375793,
518
+ "grad_norm": 3.520167350769043,
519
+ "learning_rate": 7.29e-05,
520
+ "loss": 4.397706985473633,
521
+ "step": 730
522
+ },
523
+ {
524
+ "epoch": 0.8161014612627516,
525
+ "grad_norm": 2.7510974407196045,
526
+ "learning_rate": 7.390000000000001e-05,
527
+ "loss": 4.374617385864258,
528
+ "step": 740
529
+ },
530
+ {
531
+ "epoch": 0.8271298593879239,
532
+ "grad_norm": 4.395699977874756,
533
+ "learning_rate": 7.49e-05,
534
+ "loss": 4.3302146911621096,
535
+ "step": 750
536
+ },
537
+ {
538
+ "epoch": 0.8381582575130962,
539
+ "grad_norm": 3.277766704559326,
540
+ "learning_rate": 7.59e-05,
541
+ "loss": 4.313335418701172,
542
+ "step": 760
543
+ },
544
+ {
545
+ "epoch": 0.8491866556382686,
546
+ "grad_norm": 2.466207981109619,
547
+ "learning_rate": 7.69e-05,
548
+ "loss": 4.3226570129394535,
549
+ "step": 770
550
+ },
551
+ {
552
+ "epoch": 0.8602150537634409,
553
+ "grad_norm": 3.637355327606201,
554
+ "learning_rate": 7.790000000000001e-05,
555
+ "loss": 4.295929718017578,
556
+ "step": 780
557
+ },
558
+ {
559
+ "epoch": 0.8712434518886132,
560
+ "grad_norm": 3.155527353286743,
561
+ "learning_rate": 7.890000000000001e-05,
562
+ "loss": 4.287591552734375,
563
+ "step": 790
564
+ },
565
+ {
566
+ "epoch": 0.8822718500137855,
567
+ "grad_norm": 3.593884229660034,
568
+ "learning_rate": 7.99e-05,
569
+ "loss": 4.267314147949219,
570
+ "step": 800
571
+ },
572
+ {
573
+ "epoch": 0.8933002481389578,
574
+ "grad_norm": 2.361081123352051,
575
+ "learning_rate": 8.090000000000001e-05,
576
+ "loss": 4.265741348266602,
577
+ "step": 810
578
+ },
579
+ {
580
+ "epoch": 0.9043286462641301,
581
+ "grad_norm": 2.7084105014801025,
582
+ "learning_rate": 8.19e-05,
583
+ "loss": 4.261878204345703,
584
+ "step": 820
585
+ },
586
+ {
587
+ "epoch": 0.9153570443893024,
588
+ "grad_norm": 3.6093873977661133,
589
+ "learning_rate": 8.29e-05,
590
+ "loss": 4.211677551269531,
591
+ "step": 830
592
+ },
593
+ {
594
+ "epoch": 0.9263854425144747,
595
+ "grad_norm": 3.9739396572113037,
596
+ "learning_rate": 8.39e-05,
597
+ "loss": 4.224007034301758,
598
+ "step": 840
599
+ },
600
+ {
601
+ "epoch": 0.9374138406396471,
602
+ "grad_norm": 2.174050807952881,
603
+ "learning_rate": 8.49e-05,
604
+ "loss": 4.211782836914063,
605
+ "step": 850
606
+ },
607
+ {
608
+ "epoch": 0.9484422387648194,
609
+ "grad_norm": 2.7151405811309814,
610
+ "learning_rate": 8.59e-05,
611
+ "loss": 4.204391098022461,
612
+ "step": 860
613
+ },
614
+ {
615
+ "epoch": 0.9594706368899917,
616
+ "grad_norm": 3.7480661869049072,
617
+ "learning_rate": 8.69e-05,
618
+ "loss": 4.175582504272461,
619
+ "step": 870
620
+ },
621
+ {
622
+ "epoch": 0.9704990350151641,
623
+ "grad_norm": 3.1127700805664062,
624
+ "learning_rate": 8.790000000000001e-05,
625
+ "loss": 4.183733749389648,
626
+ "step": 880
627
+ },
628
+ {
629
+ "epoch": 0.9815274331403364,
630
+ "grad_norm": 2.750716209411621,
631
+ "learning_rate": 8.89e-05,
632
+ "loss": 4.167971801757813,
633
+ "step": 890
634
+ },
635
+ {
636
+ "epoch": 0.9925558312655087,
637
+ "grad_norm": 4.02509880065918,
638
+ "learning_rate": 8.99e-05,
639
+ "loss": 4.170472717285156,
640
+ "step": 900
641
+ },
642
+ {
643
+ "epoch": 1.0033085194375517,
644
+ "grad_norm": 3.0058505535125732,
645
+ "learning_rate": 9.090000000000001e-05,
646
+ "loss": 4.1449127197265625,
647
+ "step": 910
648
+ },
649
+ {
650
+ "epoch": 1.014336917562724,
651
+ "grad_norm": 2.553403377532959,
652
+ "learning_rate": 9.190000000000001e-05,
653
+ "loss": 4.1404258728027346,
654
+ "step": 920
655
+ },
656
+ {
657
+ "epoch": 1.0253653156878964,
658
+ "grad_norm": 2.8066084384918213,
659
+ "learning_rate": 9.290000000000001e-05,
660
+ "loss": 4.110780334472656,
661
+ "step": 930
662
+ },
663
+ {
664
+ "epoch": 1.0363937138130686,
665
+ "grad_norm": 3.904608726501465,
666
+ "learning_rate": 9.39e-05,
667
+ "loss": 4.134862899780273,
668
+ "step": 940
669
+ },
670
+ {
671
+ "epoch": 1.047422111938241,
672
+ "grad_norm": 2.217729330062866,
673
+ "learning_rate": 9.49e-05,
674
+ "loss": 4.112079620361328,
675
+ "step": 950
676
+ },
677
+ {
678
+ "epoch": 1.0584505100634134,
679
+ "grad_norm": 2.498760938644409,
680
+ "learning_rate": 9.59e-05,
681
+ "loss": 4.097566986083985,
682
+ "step": 960
683
+ },
684
+ {
685
+ "epoch": 1.0694789081885856,
686
+ "grad_norm": 3.577143907546997,
687
+ "learning_rate": 9.69e-05,
688
+ "loss": 4.081307220458984,
689
+ "step": 970
690
+ },
691
+ {
692
+ "epoch": 1.080507306313758,
693
+ "grad_norm": 3.283250570297241,
694
+ "learning_rate": 9.790000000000001e-05,
695
+ "loss": 4.103987503051758,
696
+ "step": 980
697
+ },
698
+ {
699
+ "epoch": 1.0915357044389302,
700
+ "grad_norm": 2.1897776126861572,
701
+ "learning_rate": 9.89e-05,
702
+ "loss": 4.084938812255859,
703
+ "step": 990
704
+ },
705
+ {
706
+ "epoch": 1.1025641025641026,
707
+ "grad_norm": 2.6925997734069824,
708
+ "learning_rate": 9.99e-05,
709
+ "loss": 4.058921051025391,
710
+ "step": 1000
711
+ },
712
+ {
713
+ "epoch": 1.1135925006892748,
714
+ "grad_norm": 3.4118456840515137,
715
+ "learning_rate": 9.994749124854142e-05,
716
+ "loss": 4.061585235595703,
717
+ "step": 1010
718
+ },
719
+ {
720
+ "epoch": 1.1246208988144473,
721
+ "grad_norm": 2.6139297485351562,
722
+ "learning_rate": 9.988914819136523e-05,
723
+ "loss": 4.070050048828125,
724
+ "step": 1020
725
+ },
726
+ {
727
+ "epoch": 1.1356492969396195,
728
+ "grad_norm": 1.8616399765014648,
729
+ "learning_rate": 9.983080513418903e-05,
730
+ "loss": 4.0413330078125,
731
+ "step": 1030
732
+ },
733
+ {
734
+ "epoch": 1.146677695064792,
735
+ "grad_norm": 2.361706018447876,
736
+ "learning_rate": 9.977246207701284e-05,
737
+ "loss": 4.023075866699219,
738
+ "step": 1040
739
+ },
740
+ {
741
+ "epoch": 1.157706093189964,
742
+ "grad_norm": 3.815014123916626,
743
+ "learning_rate": 9.971411901983664e-05,
744
+ "loss": 4.036756134033203,
745
+ "step": 1050
746
+ },
747
+ {
748
+ "epoch": 1.1687344913151365,
749
+ "grad_norm": 2.4410274028778076,
750
+ "learning_rate": 9.965577596266045e-05,
751
+ "loss": 4.020483779907226,
752
+ "step": 1060
753
+ },
754
+ {
755
+ "epoch": 1.1797628894403087,
756
+ "grad_norm": 2.768084764480591,
757
+ "learning_rate": 9.959743290548426e-05,
758
+ "loss": 4.021839141845703,
759
+ "step": 1070
760
+ },
761
+ {
762
+ "epoch": 1.1907912875654811,
763
+ "grad_norm": 1.9342570304870605,
764
+ "learning_rate": 9.953908984830806e-05,
765
+ "loss": 4.026360321044922,
766
+ "step": 1080
767
+ },
768
+ {
769
+ "epoch": 1.2018196856906533,
770
+ "grad_norm": 2.8184762001037598,
771
+ "learning_rate": 9.948074679113187e-05,
772
+ "loss": 4.007581329345703,
773
+ "step": 1090
774
+ },
775
+ {
776
+ "epoch": 1.2128480838158258,
777
+ "grad_norm": 3.2656188011169434,
778
+ "learning_rate": 9.942240373395566e-05,
779
+ "loss": 3.9965087890625,
780
+ "step": 1100
781
+ },
782
+ {
783
+ "epoch": 1.223876481940998,
784
+ "grad_norm": 2.4359538555145264,
785
+ "learning_rate": 9.936406067677947e-05,
786
+ "loss": 3.9959388732910157,
787
+ "step": 1110
788
+ },
789
+ {
790
+ "epoch": 1.2349048800661704,
791
+ "grad_norm": 1.9357632398605347,
792
+ "learning_rate": 9.930571761960327e-05,
793
+ "loss": 3.9851417541503906,
794
+ "step": 1120
795
+ },
796
+ {
797
+ "epoch": 1.2459332781913428,
798
+ "grad_norm": 2.1269352436065674,
799
+ "learning_rate": 9.924737456242708e-05,
800
+ "loss": 3.9773223876953123,
801
+ "step": 1130
802
+ },
803
+ {
804
+ "epoch": 1.256961676316515,
805
+ "grad_norm": 3.3491597175598145,
806
+ "learning_rate": 9.918903150525088e-05,
807
+ "loss": 3.9877471923828125,
808
+ "step": 1140
809
+ },
810
+ {
811
+ "epoch": 1.2679900744416872,
812
+ "grad_norm": 1.8646328449249268,
813
+ "learning_rate": 9.913068844807468e-05,
814
+ "loss": 3.9694965362548826,
815
+ "step": 1150
816
+ },
817
+ {
818
+ "epoch": 1.2790184725668596,
819
+ "grad_norm": 2.6204631328582764,
820
+ "learning_rate": 9.907234539089849e-05,
821
+ "loss": 3.9611881256103514,
822
+ "step": 1160
823
+ },
824
+ {
825
+ "epoch": 1.290046870692032,
826
+ "grad_norm": 1.872028112411499,
827
+ "learning_rate": 9.901400233372228e-05,
828
+ "loss": 3.964163970947266,
829
+ "step": 1170
830
+ },
831
+ {
832
+ "epoch": 1.3010752688172043,
833
+ "grad_norm": 3.490435838699341,
834
+ "learning_rate": 9.895565927654609e-05,
835
+ "loss": 3.959897994995117,
836
+ "step": 1180
837
+ },
838
+ {
839
+ "epoch": 1.3121036669423767,
840
+ "grad_norm": 2.862489700317383,
841
+ "learning_rate": 9.88973162193699e-05,
842
+ "loss": 3.9567939758300783,
843
+ "step": 1190
844
+ },
845
+ {
846
+ "epoch": 1.3231320650675489,
847
+ "grad_norm": 3.0570664405822754,
848
+ "learning_rate": 9.883897316219371e-05,
849
+ "loss": 3.9470645904541017,
850
+ "step": 1200
851
+ },
852
+ {
853
+ "epoch": 1.3341604631927213,
854
+ "grad_norm": 1.9254627227783203,
855
+ "learning_rate": 9.878063010501752e-05,
856
+ "loss": 3.9442317962646483,
857
+ "step": 1210
858
+ },
859
+ {
860
+ "epoch": 1.3451888613178935,
861
+ "grad_norm": 3.606224298477173,
862
+ "learning_rate": 9.872228704784131e-05,
863
+ "loss": 3.9380733489990236,
864
+ "step": 1220
865
+ },
866
+ {
867
+ "epoch": 1.356217259443066,
868
+ "grad_norm": 2.1184027194976807,
869
+ "learning_rate": 9.866394399066512e-05,
870
+ "loss": 3.9452835083007813,
871
+ "step": 1230
872
+ },
873
+ {
874
+ "epoch": 1.3672456575682381,
875
+ "grad_norm": 1.8997142314910889,
876
+ "learning_rate": 9.860560093348892e-05,
877
+ "loss": 3.9270603179931642,
878
+ "step": 1240
879
+ },
880
+ {
881
+ "epoch": 1.3782740556934105,
882
+ "grad_norm": 2.9672305583953857,
883
+ "learning_rate": 9.854725787631273e-05,
884
+ "loss": 3.9120155334472657,
885
+ "step": 1250
886
+ },
887
+ {
888
+ "epoch": 1.389302453818583,
889
+ "grad_norm": 1.9220951795578003,
890
+ "learning_rate": 9.848891481913652e-05,
891
+ "loss": 3.900279235839844,
892
+ "step": 1260
893
+ },
894
+ {
895
+ "epoch": 1.4003308519437552,
896
+ "grad_norm": 2.013521194458008,
897
+ "learning_rate": 9.843057176196033e-05,
898
+ "loss": 3.9147193908691404,
899
+ "step": 1270
900
+ },
901
+ {
902
+ "epoch": 1.4113592500689274,
903
+ "grad_norm": 1.451686143875122,
904
+ "learning_rate": 9.837222870478413e-05,
905
+ "loss": 3.906220245361328,
906
+ "step": 1280
907
+ },
908
+ {
909
+ "epoch": 1.4223876481940998,
910
+ "grad_norm": 4.606860637664795,
911
+ "learning_rate": 9.831388564760794e-05,
912
+ "loss": 3.905352020263672,
913
+ "step": 1290
914
+ },
915
+ {
916
+ "epoch": 1.4334160463192722,
917
+ "grad_norm": 1.779123306274414,
918
+ "learning_rate": 9.825554259043175e-05,
919
+ "loss": 3.9137496948242188,
920
+ "step": 1300
921
+ },
922
+ {
923
+ "epoch": 1.4444444444444444,
924
+ "grad_norm": 2.086585521697998,
925
+ "learning_rate": 9.819719953325554e-05,
926
+ "loss": 3.89554443359375,
927
+ "step": 1310
928
+ },
929
+ {
930
+ "epoch": 1.4554728425696168,
931
+ "grad_norm": 3.3514609336853027,
932
+ "learning_rate": 9.813885647607935e-05,
933
+ "loss": 3.8901123046875,
934
+ "step": 1320
935
+ },
936
+ {
937
+ "epoch": 1.466501240694789,
938
+ "grad_norm": 2.1145269870758057,
939
+ "learning_rate": 9.808051341890316e-05,
940
+ "loss": 3.8892486572265623,
941
+ "step": 1330
942
+ },
943
+ {
944
+ "epoch": 1.4775296388199615,
945
+ "grad_norm": 1.5503329038619995,
946
+ "learning_rate": 9.802217036172697e-05,
947
+ "loss": 3.8922355651855467,
948
+ "step": 1340
949
+ },
950
+ {
951
+ "epoch": 1.4885580369451337,
952
+ "grad_norm": 2.3014304637908936,
953
+ "learning_rate": 9.796382730455076e-05,
954
+ "loss": 3.8860099792480467,
955
+ "step": 1350
956
+ },
957
+ {
958
+ "epoch": 1.499586435070306,
959
+ "grad_norm": 1.9633557796478271,
960
+ "learning_rate": 9.790548424737457e-05,
961
+ "loss": 3.875183868408203,
962
+ "step": 1360
963
+ },
964
+ {
965
+ "epoch": 1.5106148331954783,
966
+ "grad_norm": 2.228351593017578,
967
+ "learning_rate": 9.784714119019837e-05,
968
+ "loss": 3.8726768493652344,
969
+ "step": 1370
970
+ },
971
+ {
972
+ "epoch": 1.5216432313206507,
973
+ "grad_norm": 3.0888657569885254,
974
+ "learning_rate": 9.778879813302218e-05,
975
+ "loss": 3.872690963745117,
976
+ "step": 1380
977
+ },
978
+ {
979
+ "epoch": 1.5326716294458231,
980
+ "grad_norm": 2.0078868865966797,
981
+ "learning_rate": 9.773045507584599e-05,
982
+ "loss": 3.8612388610839843,
983
+ "step": 1390
984
+ },
985
+ {
986
+ "epoch": 1.5437000275709953,
987
+ "grad_norm": 2.1966569423675537,
988
+ "learning_rate": 9.767211201866978e-05,
989
+ "loss": 3.8649852752685545,
990
+ "step": 1400
991
+ },
992
+ {
993
+ "epoch": 1.5547284256961675,
994
+ "grad_norm": 2.1047487258911133,
995
+ "learning_rate": 9.761376896149359e-05,
996
+ "loss": 3.8632328033447267,
997
+ "step": 1410
998
+ },
999
+ {
1000
+ "epoch": 1.56575682382134,
1001
+ "grad_norm": 1.9347233772277832,
1002
+ "learning_rate": 9.755542590431739e-05,
1003
+ "loss": 3.8362571716308596,
1004
+ "step": 1420
1005
+ },
1006
+ {
1007
+ "epoch": 1.5767852219465124,
1008
+ "grad_norm": 1.7961437702178955,
1009
+ "learning_rate": 9.74970828471412e-05,
1010
+ "loss": 3.8461585998535157,
1011
+ "step": 1430
1012
+ },
1013
+ {
1014
+ "epoch": 1.5878136200716846,
1015
+ "grad_norm": 2.4657342433929443,
1016
+ "learning_rate": 9.743873978996499e-05,
1017
+ "loss": 3.842551040649414,
1018
+ "step": 1440
1019
+ },
1020
+ {
1021
+ "epoch": 1.5988420181968568,
1022
+ "grad_norm": 2.043138027191162,
1023
+ "learning_rate": 9.73803967327888e-05,
1024
+ "loss": 3.8387855529785155,
1025
+ "step": 1450
1026
+ },
1027
+ {
1028
+ "epoch": 1.6098704163220292,
1029
+ "grad_norm": 3.732532262802124,
1030
+ "learning_rate": 9.732205367561261e-05,
1031
+ "loss": 3.8399681091308593,
1032
+ "step": 1460
1033
+ },
1034
+ {
1035
+ "epoch": 1.6208988144472016,
1036
+ "grad_norm": 2.43684720993042,
1037
+ "learning_rate": 9.726371061843642e-05,
1038
+ "loss": 3.8324966430664062,
1039
+ "step": 1470
1040
+ },
1041
+ {
1042
+ "epoch": 1.6319272125723738,
1043
+ "grad_norm": 2.4433460235595703,
1044
+ "learning_rate": 9.720536756126023e-05,
1045
+ "loss": 3.817783737182617,
1046
+ "step": 1480
1047
+ },
1048
+ {
1049
+ "epoch": 1.642955610697546,
1050
+ "grad_norm": 2.1049606800079346,
1051
+ "learning_rate": 9.714702450408402e-05,
1052
+ "loss": 3.804280090332031,
1053
+ "step": 1490
1054
+ },
1055
+ {
1056
+ "epoch": 1.6539840088227185,
1057
+ "grad_norm": 3.529686450958252,
1058
+ "learning_rate": 9.708868144690783e-05,
1059
+ "loss": 3.805449295043945,
1060
+ "step": 1500
1061
+ },
1062
+ {
1063
+ "epoch": 1.6650124069478909,
1064
+ "grad_norm": 2.0984089374542236,
1065
+ "learning_rate": 9.703033838973162e-05,
1066
+ "loss": 3.788246917724609,
1067
+ "step": 1510
1068
+ },
1069
+ {
1070
+ "epoch": 1.6760408050730633,
1071
+ "grad_norm": 1.9434291124343872,
1072
+ "learning_rate": 9.697199533255543e-05,
1073
+ "loss": 3.7875442504882812,
1074
+ "step": 1520
1075
+ },
1076
+ {
1077
+ "epoch": 1.6870692031982355,
1078
+ "grad_norm": 1.99173903465271,
1079
+ "learning_rate": 9.691365227537923e-05,
1080
+ "loss": 3.7807193756103517,
1081
+ "step": 1530
1082
+ },
1083
+ {
1084
+ "epoch": 1.6980976013234077,
1085
+ "grad_norm": 2.5006911754608154,
1086
+ "learning_rate": 9.685530921820304e-05,
1087
+ "loss": 3.744763946533203,
1088
+ "step": 1540
1089
+ },
1090
+ {
1091
+ "epoch": 1.7091259994485801,
1092
+ "grad_norm": 2.1816165447235107,
1093
+ "learning_rate": 9.679696616102685e-05,
1094
+ "loss": 3.760245513916016,
1095
+ "step": 1550
1096
+ },
1097
+ {
1098
+ "epoch": 1.7201543975737525,
1099
+ "grad_norm": 2.123291492462158,
1100
+ "learning_rate": 9.673862310385064e-05,
1101
+ "loss": 3.738916778564453,
1102
+ "step": 1560
1103
+ },
1104
+ {
1105
+ "epoch": 1.7311827956989247,
1106
+ "grad_norm": 2.378187894821167,
1107
+ "learning_rate": 9.668028004667445e-05,
1108
+ "loss": 3.734139251708984,
1109
+ "step": 1570
1110
+ },
1111
+ {
1112
+ "epoch": 1.742211193824097,
1113
+ "grad_norm": 2.54819393157959,
1114
+ "learning_rate": 9.662193698949825e-05,
1115
+ "loss": 3.715302276611328,
1116
+ "step": 1580
1117
+ },
1118
+ {
1119
+ "epoch": 1.7532395919492694,
1120
+ "grad_norm": 4.285822868347168,
1121
+ "learning_rate": 9.656359393232206e-05,
1122
+ "loss": 3.72213134765625,
1123
+ "step": 1590
1124
+ },
1125
+ {
1126
+ "epoch": 1.7642679900744418,
1127
+ "grad_norm": 1.8676700592041016,
1128
+ "learning_rate": 9.650525087514586e-05,
1129
+ "loss": 3.7252479553222657,
1130
+ "step": 1600
1131
+ },
1132
+ {
1133
+ "epoch": 1.775296388199614,
1134
+ "grad_norm": 1.6977792978286743,
1135
+ "learning_rate": 9.644690781796967e-05,
1136
+ "loss": 3.704994964599609,
1137
+ "step": 1610
1138
+ },
1139
+ {
1140
+ "epoch": 1.7863247863247862,
1141
+ "grad_norm": 1.8334232568740845,
1142
+ "learning_rate": 9.638856476079347e-05,
1143
+ "loss": 3.6980815887451173,
1144
+ "step": 1620
1145
+ },
1146
+ {
1147
+ "epoch": 1.7973531844499586,
1148
+ "grad_norm": 2.6574559211730957,
1149
+ "learning_rate": 9.633022170361728e-05,
1150
+ "loss": 3.683759307861328,
1151
+ "step": 1630
1152
+ },
1153
+ {
1154
+ "epoch": 1.808381582575131,
1155
+ "grad_norm": 2.085084915161133,
1156
+ "learning_rate": 9.627187864644109e-05,
1157
+ "loss": 3.67755126953125,
1158
+ "step": 1640
1159
+ },
1160
+ {
1161
+ "epoch": 1.8194099807003032,
1162
+ "grad_norm": 1.685441017150879,
1163
+ "learning_rate": 9.621353558926488e-05,
1164
+ "loss": 3.656099319458008,
1165
+ "step": 1650
1166
+ },
1167
+ {
1168
+ "epoch": 1.8304383788254754,
1169
+ "grad_norm": 2.4462475776672363,
1170
+ "learning_rate": 9.615519253208869e-05,
1171
+ "loss": 3.668656921386719,
1172
+ "step": 1660
1173
+ },
1174
+ {
1175
+ "epoch": 1.8414667769506479,
1176
+ "grad_norm": 1.54155433177948,
1177
+ "learning_rate": 9.609684947491249e-05,
1178
+ "loss": 3.66968994140625,
1179
+ "step": 1670
1180
+ },
1181
+ {
1182
+ "epoch": 1.8524951750758203,
1183
+ "grad_norm": 3.862130880355835,
1184
+ "learning_rate": 9.60385064177363e-05,
1185
+ "loss": 3.6412506103515625,
1186
+ "step": 1680
1187
+ },
1188
+ {
1189
+ "epoch": 1.8635235732009927,
1190
+ "grad_norm": 1.7317070960998535,
1191
+ "learning_rate": 9.598016336056009e-05,
1192
+ "loss": 3.639806365966797,
1193
+ "step": 1690
1194
+ },
1195
+ {
1196
+ "epoch": 1.874551971326165,
1197
+ "grad_norm": 2.2640931606292725,
1198
+ "learning_rate": 9.59218203033839e-05,
1199
+ "loss": 3.6341064453125,
1200
+ "step": 1700
1201
+ },
1202
+ {
1203
+ "epoch": 1.8855803694513371,
1204
+ "grad_norm": 3.653146743774414,
1205
+ "learning_rate": 9.586347724620771e-05,
1206
+ "loss": 3.6380882263183594,
1207
+ "step": 1710
1208
+ },
1209
+ {
1210
+ "epoch": 1.8966087675765095,
1211
+ "grad_norm": 1.8987306356430054,
1212
+ "learning_rate": 9.58051341890315e-05,
1213
+ "loss": 3.6405975341796877,
1214
+ "step": 1720
1215
+ },
1216
+ {
1217
+ "epoch": 1.907637165701682,
1218
+ "grad_norm": 2.202659845352173,
1219
+ "learning_rate": 9.574679113185531e-05,
1220
+ "loss": 3.6375991821289064,
1221
+ "step": 1730
1222
+ },
1223
+ {
1224
+ "epoch": 1.9186655638268542,
1225
+ "grad_norm": 1.5091872215270996,
1226
+ "learning_rate": 9.568844807467912e-05,
1227
+ "loss": 3.6208465576171873,
1228
+ "step": 1740
1229
+ },
1230
+ {
1231
+ "epoch": 1.9296939619520264,
1232
+ "grad_norm": 1.9811325073242188,
1233
+ "learning_rate": 9.563010501750293e-05,
1234
+ "loss": 3.600755310058594,
1235
+ "step": 1750
1236
+ },
1237
+ {
1238
+ "epoch": 1.9407223600771988,
1239
+ "grad_norm": 3.184499979019165,
1240
+ "learning_rate": 9.557176196032673e-05,
1241
+ "loss": 3.6109405517578126,
1242
+ "step": 1760
1243
+ },
1244
+ {
1245
+ "epoch": 1.9517507582023712,
1246
+ "grad_norm": 2.340125322341919,
1247
+ "learning_rate": 9.551341890315054e-05,
1248
+ "loss": 3.6129817962646484,
1249
+ "step": 1770
1250
+ },
1251
+ {
1252
+ "epoch": 1.9627791563275434,
1253
+ "grad_norm": 1.7258495092391968,
1254
+ "learning_rate": 9.545507584597433e-05,
1255
+ "loss": 3.590809631347656,
1256
+ "step": 1780
1257
+ },
1258
+ {
1259
+ "epoch": 1.9738075544527156,
1260
+ "grad_norm": 1.6129754781723022,
1261
+ "learning_rate": 9.539673278879814e-05,
1262
+ "loss": 3.5866302490234374,
1263
+ "step": 1790
1264
+ },
1265
+ {
1266
+ "epoch": 1.984835952577888,
1267
+ "grad_norm": 2.7458667755126953,
1268
+ "learning_rate": 9.533838973162195e-05,
1269
+ "loss": 3.596644973754883,
1270
+ "step": 1800
1271
+ },
1272
+ {
1273
+ "epoch": 1.9958643507030605,
1274
+ "grad_norm": 2.258280038833618,
1275
+ "learning_rate": 9.528004667444574e-05,
1276
+ "loss": 3.5881332397460937,
1277
+ "step": 1810
1278
+ },
1279
+ {
1280
+ "epoch": 2.0066170388751035,
1281
+ "grad_norm": 2.1228580474853516,
1282
+ "learning_rate": 9.522170361726955e-05,
1283
+ "loss": 3.5709766387939452,
1284
+ "step": 1820
1285
+ },
1286
+ {
1287
+ "epoch": 2.017645437000276,
1288
+ "grad_norm": 1.588876485824585,
1289
+ "learning_rate": 9.516336056009335e-05,
1290
+ "loss": 3.5627593994140625,
1291
+ "step": 1830
1292
+ },
1293
+ {
1294
+ "epoch": 2.028673835125448,
1295
+ "grad_norm": 2.451474189758301,
1296
+ "learning_rate": 9.510501750291716e-05,
1297
+ "loss": 3.5535301208496093,
1298
+ "step": 1840
1299
+ },
1300
+ {
1301
+ "epoch": 2.0397022332506203,
1302
+ "grad_norm": 2.0007503032684326,
1303
+ "learning_rate": 9.504667444574095e-05,
1304
+ "loss": 3.553875732421875,
1305
+ "step": 1850
1306
+ },
1307
+ {
1308
+ "epoch": 2.0507306313757927,
1309
+ "grad_norm": 1.4410080909729004,
1310
+ "learning_rate": 9.498833138856476e-05,
1311
+ "loss": 3.550189971923828,
1312
+ "step": 1860
1313
+ },
1314
+ {
1315
+ "epoch": 2.061759029500965,
1316
+ "grad_norm": 2.062835216522217,
1317
+ "learning_rate": 9.492998833138857e-05,
1318
+ "loss": 3.5456893920898436,
1319
+ "step": 1870
1320
+ },
1321
+ {
1322
+ "epoch": 2.072787427626137,
1323
+ "grad_norm": 2.4534783363342285,
1324
+ "learning_rate": 9.487164527421238e-05,
1325
+ "loss": 3.536829376220703,
1326
+ "step": 1880
1327
+ },
1328
+ {
1329
+ "epoch": 2.0838158257513095,
1330
+ "grad_norm": 2.2788970470428467,
1331
+ "learning_rate": 9.481330221703619e-05,
1332
+ "loss": 3.5525283813476562,
1333
+ "step": 1890
1334
+ },
1335
+ {
1336
+ "epoch": 2.094844223876482,
1337
+ "grad_norm": 1.4259227514266968,
1338
+ "learning_rate": 9.475495915985998e-05,
1339
+ "loss": 3.5479995727539064,
1340
+ "step": 1900
1341
+ },
1342
+ {
1343
+ "epoch": 2.1058726220016544,
1344
+ "grad_norm": 2.672534465789795,
1345
+ "learning_rate": 9.469661610268379e-05,
1346
+ "loss": 3.5359420776367188,
1347
+ "step": 1910
1348
+ },
1349
+ {
1350
+ "epoch": 2.116901020126827,
1351
+ "grad_norm": 2.0648045539855957,
1352
+ "learning_rate": 9.463827304550759e-05,
1353
+ "loss": 3.5452896118164063,
1354
+ "step": 1920
1355
+ },
1356
+ {
1357
+ "epoch": 2.1279294182519988,
1358
+ "grad_norm": 1.6846543550491333,
1359
+ "learning_rate": 9.45799299883314e-05,
1360
+ "loss": 3.5434345245361327,
1361
+ "step": 1930
1362
+ },
1363
+ {
1364
+ "epoch": 2.138957816377171,
1365
+ "grad_norm": 1.9105942249298096,
1366
+ "learning_rate": 9.452158693115519e-05,
1367
+ "loss": 3.5351535797119142,
1368
+ "step": 1940
1369
+ },
1370
+ {
1371
+ "epoch": 2.1499862145023436,
1372
+ "grad_norm": 1.8230890035629272,
1373
+ "learning_rate": 9.4463243873979e-05,
1374
+ "loss": 3.5190963745117188,
1375
+ "step": 1950
1376
+ },
1377
+ {
1378
+ "epoch": 2.161014612627516,
1379
+ "grad_norm": 1.6383274793624878,
1380
+ "learning_rate": 9.440490081680281e-05,
1381
+ "loss": 3.5228431701660154,
1382
+ "step": 1960
1383
+ },
1384
+ {
1385
+ "epoch": 2.172043010752688,
1386
+ "grad_norm": 1.7378439903259277,
1387
+ "learning_rate": 9.43465577596266e-05,
1388
+ "loss": 3.520981216430664,
1389
+ "step": 1970
1390
+ },
1391
+ {
1392
+ "epoch": 2.1830714088778604,
1393
+ "grad_norm": 1.941454529762268,
1394
+ "learning_rate": 9.428821470245041e-05,
1395
+ "loss": 3.519342803955078,
1396
+ "step": 1980
1397
+ },
1398
+ {
1399
+ "epoch": 2.194099807003033,
1400
+ "grad_norm": 1.8295516967773438,
1401
+ "learning_rate": 9.422987164527421e-05,
1402
+ "loss": 3.5412979125976562,
1403
+ "step": 1990
1404
+ },
1405
+ {
1406
+ "epoch": 2.2051282051282053,
1407
+ "grad_norm": 1.8052620887756348,
1408
+ "learning_rate": 9.417152858809802e-05,
1409
+ "loss": 3.5153289794921876,
1410
+ "step": 2000
1411
+ },
1412
+ {
1413
+ "epoch": 2.2161566032533773,
1414
+ "grad_norm": 2.1949570178985596,
1415
+ "learning_rate": 9.411318553092183e-05,
1416
+ "loss": 3.521608352661133,
1417
+ "step": 2010
1418
+ },
1419
+ {
1420
+ "epoch": 2.2271850013785497,
1421
+ "grad_norm": 1.746172308921814,
1422
+ "learning_rate": 9.405484247374564e-05,
1423
+ "loss": 3.5008296966552734,
1424
+ "step": 2020
1425
+ },
1426
+ {
1427
+ "epoch": 2.238213399503722,
1428
+ "grad_norm": 2.5374276638031006,
1429
+ "learning_rate": 9.399649941656943e-05,
1430
+ "loss": 3.5140228271484375,
1431
+ "step": 2030
1432
+ },
1433
+ {
1434
+ "epoch": 2.2492417976288945,
1435
+ "grad_norm": 1.7763218879699707,
1436
+ "learning_rate": 9.393815635939324e-05,
1437
+ "loss": 3.510652542114258,
1438
+ "step": 2040
1439
+ },
1440
+ {
1441
+ "epoch": 2.2602701957540665,
1442
+ "grad_norm": 1.6599587202072144,
1443
+ "learning_rate": 9.387981330221705e-05,
1444
+ "loss": 3.5122325897216795,
1445
+ "step": 2050
1446
+ },
1447
+ {
1448
+ "epoch": 2.271298593879239,
1449
+ "grad_norm": 2.1496078968048096,
1450
+ "learning_rate": 9.382147024504085e-05,
1451
+ "loss": 3.5139747619628907,
1452
+ "step": 2060
1453
+ },
1454
+ {
1455
+ "epoch": 2.2823269920044114,
1456
+ "grad_norm": 1.64266836643219,
1457
+ "learning_rate": 9.376312718786465e-05,
1458
+ "loss": 3.507743072509766,
1459
+ "step": 2070
1460
+ },
1461
+ {
1462
+ "epoch": 2.293355390129584,
1463
+ "grad_norm": 2.1241567134857178,
1464
+ "learning_rate": 9.370478413068845e-05,
1465
+ "loss": 3.5162708282470705,
1466
+ "step": 2080
1467
+ },
1468
+ {
1469
+ "epoch": 2.304383788254756,
1470
+ "grad_norm": 1.8391071557998657,
1471
+ "learning_rate": 9.364644107351226e-05,
1472
+ "loss": 3.4955375671386717,
1473
+ "step": 2090
1474
+ },
1475
+ {
1476
+ "epoch": 2.315412186379928,
1477
+ "grad_norm": 2.7478973865509033,
1478
+ "learning_rate": 9.358809801633605e-05,
1479
+ "loss": 3.497519302368164,
1480
+ "step": 2100
1481
+ },
1482
+ {
1483
+ "epoch": 2.3264405845051006,
1484
+ "grad_norm": 1.938588261604309,
1485
+ "learning_rate": 9.352975495915986e-05,
1486
+ "loss": 3.490141677856445,
1487
+ "step": 2110
1488
+ },
1489
+ {
1490
+ "epoch": 2.337468982630273,
1491
+ "grad_norm": 1.5637104511260986,
1492
+ "learning_rate": 9.347141190198366e-05,
1493
+ "loss": 3.499908447265625,
1494
+ "step": 2120
1495
+ },
1496
+ {
1497
+ "epoch": 2.3484973807554455,
1498
+ "grad_norm": 1.882504940032959,
1499
+ "learning_rate": 9.341306884480747e-05,
1500
+ "loss": 3.491979217529297,
1501
+ "step": 2130
1502
+ },
1503
+ {
1504
+ "epoch": 2.3595257788806174,
1505
+ "grad_norm": 1.8528521060943604,
1506
+ "learning_rate": 9.335472578763128e-05,
1507
+ "loss": 3.4961143493652345,
1508
+ "step": 2140
1509
+ },
1510
+ {
1511
+ "epoch": 2.37055417700579,
1512
+ "grad_norm": 1.8050177097320557,
1513
+ "learning_rate": 9.329638273045509e-05,
1514
+ "loss": 3.4948150634765627,
1515
+ "step": 2150
1516
+ },
1517
+ {
1518
+ "epoch": 2.3815825751309623,
1519
+ "grad_norm": 1.816784381866455,
1520
+ "learning_rate": 9.32380396732789e-05,
1521
+ "loss": 3.4910873413085937,
1522
+ "step": 2160
1523
+ },
1524
+ {
1525
+ "epoch": 2.3926109732561347,
1526
+ "grad_norm": 1.9779244661331177,
1527
+ "learning_rate": 9.317969661610269e-05,
1528
+ "loss": 3.492570495605469,
1529
+ "step": 2170
1530
+ },
1531
+ {
1532
+ "epoch": 2.4036393713813067,
1533
+ "grad_norm": 1.8939772844314575,
1534
+ "learning_rate": 9.31213535589265e-05,
1535
+ "loss": 3.473868560791016,
1536
+ "step": 2180
1537
+ },
1538
+ {
1539
+ "epoch": 2.414667769506479,
1540
+ "grad_norm": 2.1493656635284424,
1541
+ "learning_rate": 9.30630105017503e-05,
1542
+ "loss": 3.494515228271484,
1543
+ "step": 2190
1544
+ },
1545
+ {
1546
+ "epoch": 2.4256961676316515,
1547
+ "grad_norm": 1.8989397287368774,
1548
+ "learning_rate": 9.30046674445741e-05,
1549
+ "loss": 3.487537384033203,
1550
+ "step": 2200
1551
+ },
1552
+ {
1553
+ "epoch": 2.436724565756824,
1554
+ "grad_norm": 1.881856918334961,
1555
+ "learning_rate": 9.294632438739791e-05,
1556
+ "loss": 3.475904083251953,
1557
+ "step": 2210
1558
+ },
1559
+ {
1560
+ "epoch": 2.447752963881996,
1561
+ "grad_norm": 1.9463883638381958,
1562
+ "learning_rate": 9.288798133022171e-05,
1563
+ "loss": 3.4829254150390625,
1564
+ "step": 2220
1565
+ },
1566
+ {
1567
+ "epoch": 2.4587813620071683,
1568
+ "grad_norm": 2.01379656791687,
1569
+ "learning_rate": 9.282963827304552e-05,
1570
+ "loss": 3.472850036621094,
1571
+ "step": 2230
1572
+ },
1573
+ {
1574
+ "epoch": 2.4698097601323408,
1575
+ "grad_norm": 2.442741632461548,
1576
+ "learning_rate": 9.277129521586931e-05,
1577
+ "loss": 3.47030029296875,
1578
+ "step": 2240
1579
+ },
1580
+ {
1581
+ "epoch": 2.480838158257513,
1582
+ "grad_norm": 1.5051734447479248,
1583
+ "learning_rate": 9.271295215869312e-05,
1584
+ "loss": 3.489413833618164,
1585
+ "step": 2250
1586
+ },
1587
+ {
1588
+ "epoch": 2.4918665563826856,
1589
+ "grad_norm": 1.9489309787750244,
1590
+ "learning_rate": 9.265460910151692e-05,
1591
+ "loss": 3.464769744873047,
1592
+ "step": 2260
1593
+ },
1594
+ {
1595
+ "epoch": 2.5028949545078576,
1596
+ "grad_norm": 2.319654941558838,
1597
+ "learning_rate": 9.259626604434072e-05,
1598
+ "loss": 3.469140625,
1599
+ "step": 2270
1600
+ },
1601
+ {
1602
+ "epoch": 2.51392335263303,
1603
+ "grad_norm": 1.7984129190444946,
1604
+ "learning_rate": 9.253792298716453e-05,
1605
+ "loss": 3.466594696044922,
1606
+ "step": 2280
1607
+ },
1608
+ {
1609
+ "epoch": 2.5249517507582024,
1610
+ "grad_norm": 1.640869379043579,
1611
+ "learning_rate": 9.247957992998833e-05,
1612
+ "loss": 3.463022994995117,
1613
+ "step": 2290
1614
+ },
1615
+ {
1616
+ "epoch": 2.5359801488833744,
1617
+ "grad_norm": 1.6698195934295654,
1618
+ "learning_rate": 9.242123687281214e-05,
1619
+ "loss": 3.4695220947265626,
1620
+ "step": 2300
1621
+ },
1622
+ {
1623
+ "epoch": 2.547008547008547,
1624
+ "grad_norm": 2.2945683002471924,
1625
+ "learning_rate": 9.236289381563595e-05,
1626
+ "loss": 3.469150924682617,
1627
+ "step": 2310
1628
+ },
1629
+ {
1630
+ "epoch": 2.5580369451337193,
1631
+ "grad_norm": 1.7678370475769043,
1632
+ "learning_rate": 9.230455075845976e-05,
1633
+ "loss": 3.470307159423828,
1634
+ "step": 2320
1635
+ },
1636
+ {
1637
+ "epoch": 2.5690653432588917,
1638
+ "grad_norm": 1.8386255502700806,
1639
+ "learning_rate": 9.224620770128355e-05,
1640
+ "loss": 3.4638832092285154,
1641
+ "step": 2330
1642
+ },
1643
+ {
1644
+ "epoch": 2.580093741384064,
1645
+ "grad_norm": 2.0348527431488037,
1646
+ "learning_rate": 9.218786464410736e-05,
1647
+ "loss": 3.460480880737305,
1648
+ "step": 2340
1649
+ },
1650
+ {
1651
+ "epoch": 2.5911221395092365,
1652
+ "grad_norm": 1.845974326133728,
1653
+ "learning_rate": 9.212952158693116e-05,
1654
+ "loss": 3.4529083251953123,
1655
+ "step": 2350
1656
+ },
1657
+ {
1658
+ "epoch": 2.6021505376344085,
1659
+ "grad_norm": 2.0843095779418945,
1660
+ "learning_rate": 9.207117852975496e-05,
1661
+ "loss": 3.4576786041259764,
1662
+ "step": 2360
1663
+ },
1664
+ {
1665
+ "epoch": 2.613178935759581,
1666
+ "grad_norm": 1.7627031803131104,
1667
+ "learning_rate": 9.201283547257876e-05,
1668
+ "loss": 3.4450752258300783,
1669
+ "step": 2370
1670
+ },
1671
+ {
1672
+ "epoch": 2.6242073338847534,
1673
+ "grad_norm": 1.371972918510437,
1674
+ "learning_rate": 9.195449241540257e-05,
1675
+ "loss": 3.464734649658203,
1676
+ "step": 2380
1677
+ },
1678
+ {
1679
+ "epoch": 2.6352357320099253,
1680
+ "grad_norm": 1.6781940460205078,
1681
+ "learning_rate": 9.189614935822638e-05,
1682
+ "loss": 3.444991683959961,
1683
+ "step": 2390
1684
+ },
1685
+ {
1686
+ "epoch": 2.6462641301350978,
1687
+ "grad_norm": 1.8782585859298706,
1688
+ "learning_rate": 9.183780630105017e-05,
1689
+ "loss": 3.4558509826660155,
1690
+ "step": 2400
1691
+ },
1692
+ {
1693
+ "epoch": 2.65729252826027,
1694
+ "grad_norm": 1.942812204360962,
1695
+ "learning_rate": 9.177946324387398e-05,
1696
+ "loss": 3.4555503845214846,
1697
+ "step": 2410
1698
+ },
1699
+ {
1700
+ "epoch": 2.6683209263854426,
1701
+ "grad_norm": 1.404680609703064,
1702
+ "learning_rate": 9.172112018669778e-05,
1703
+ "loss": 3.438182830810547,
1704
+ "step": 2420
1705
+ },
1706
+ {
1707
+ "epoch": 2.679349324510615,
1708
+ "grad_norm": 1.7656677961349487,
1709
+ "learning_rate": 9.166277712952159e-05,
1710
+ "loss": 3.4622947692871096,
1711
+ "step": 2430
1712
+ },
1713
+ {
1714
+ "epoch": 2.690377722635787,
1715
+ "grad_norm": 1.8348901271820068,
1716
+ "learning_rate": 9.16044340723454e-05,
1717
+ "loss": 3.438182830810547,
1718
+ "step": 2440
1719
+ },
1720
+ {
1721
+ "epoch": 2.7014061207609594,
1722
+ "grad_norm": 2.0641167163848877,
1723
+ "learning_rate": 9.15460910151692e-05,
1724
+ "loss": 3.441473388671875,
1725
+ "step": 2450
1726
+ },
1727
+ {
1728
+ "epoch": 2.712434518886132,
1729
+ "grad_norm": 1.726035475730896,
1730
+ "learning_rate": 9.148774795799301e-05,
1731
+ "loss": 3.441991424560547,
1732
+ "step": 2460
1733
+ },
1734
+ {
1735
+ "epoch": 2.7234629170113043,
1736
+ "grad_norm": 1.854658603668213,
1737
+ "learning_rate": 9.142940490081681e-05,
1738
+ "loss": 3.4441551208496093,
1739
+ "step": 2470
1740
+ },
1741
+ {
1742
+ "epoch": 2.7344913151364763,
1743
+ "grad_norm": 1.8229296207427979,
1744
+ "learning_rate": 9.137106184364062e-05,
1745
+ "loss": 3.441034698486328,
1746
+ "step": 2480
1747
+ },
1748
+ {
1749
+ "epoch": 2.7455197132616487,
1750
+ "grad_norm": 1.6627975702285767,
1751
+ "learning_rate": 9.131271878646441e-05,
1752
+ "loss": 3.4399124145507813,
1753
+ "step": 2490
1754
+ },
1755
+ {
1756
+ "epoch": 2.756548111386821,
1757
+ "grad_norm": 1.4111251831054688,
1758
+ "learning_rate": 9.125437572928822e-05,
1759
+ "loss": 3.4374462127685548,
1760
+ "step": 2500
1761
+ },
1762
+ {
1763
+ "epoch": 2.7675765095119935,
1764
+ "grad_norm": 2.015869379043579,
1765
+ "learning_rate": 9.119603267211202e-05,
1766
+ "loss": 3.4262016296386717,
1767
+ "step": 2510
1768
+ },
1769
+ {
1770
+ "epoch": 2.778604907637166,
1771
+ "grad_norm": 2.2818591594696045,
1772
+ "learning_rate": 9.113768961493583e-05,
1773
+ "loss": 3.446285629272461,
1774
+ "step": 2520
1775
+ },
1776
+ {
1777
+ "epoch": 2.789633305762338,
1778
+ "grad_norm": 1.8643262386322021,
1779
+ "learning_rate": 9.107934655775962e-05,
1780
+ "loss": 3.4362293243408204,
1781
+ "step": 2530
1782
+ },
1783
+ {
1784
+ "epoch": 2.8006617038875103,
1785
+ "grad_norm": 1.248988151550293,
1786
+ "learning_rate": 9.102100350058343e-05,
1787
+ "loss": 3.441702651977539,
1788
+ "step": 2540
1789
+ },
1790
+ {
1791
+ "epoch": 2.8116901020126828,
1792
+ "grad_norm": 1.5247464179992676,
1793
+ "learning_rate": 9.096266044340724e-05,
1794
+ "loss": 3.4388256072998047,
1795
+ "step": 2550
1796
+ },
1797
+ {
1798
+ "epoch": 2.8227185001378547,
1799
+ "grad_norm": 1.9120620489120483,
1800
+ "learning_rate": 9.090431738623103e-05,
1801
+ "loss": 3.4206756591796874,
1802
+ "step": 2560
1803
+ },
1804
+ {
1805
+ "epoch": 2.833746898263027,
1806
+ "grad_norm": 1.4591054916381836,
1807
+ "learning_rate": 9.084597432905484e-05,
1808
+ "loss": 3.4229709625244142,
1809
+ "step": 2570
1810
+ },
1811
+ {
1812
+ "epoch": 2.8447752963881996,
1813
+ "grad_norm": 2.24849796295166,
1814
+ "learning_rate": 9.078763127187865e-05,
1815
+ "loss": 3.426911163330078,
1816
+ "step": 2580
1817
+ },
1818
+ {
1819
+ "epoch": 2.855803694513372,
1820
+ "grad_norm": 1.5658804178237915,
1821
+ "learning_rate": 9.072928821470246e-05,
1822
+ "loss": 3.445120620727539,
1823
+ "step": 2590
1824
+ },
1825
+ {
1826
+ "epoch": 2.8668320926385444,
1827
+ "grad_norm": 1.483583688735962,
1828
+ "learning_rate": 9.067094515752626e-05,
1829
+ "loss": 3.430312728881836,
1830
+ "step": 2600
1831
+ },
1832
+ {
1833
+ "epoch": 2.8778604907637164,
1834
+ "grad_norm": 1.5759658813476562,
1835
+ "learning_rate": 9.061260210035007e-05,
1836
+ "loss": 3.4178386688232423,
1837
+ "step": 2610
1838
+ },
1839
+ {
1840
+ "epoch": 2.888888888888889,
1841
+ "grad_norm": 1.9259848594665527,
1842
+ "learning_rate": 9.055425904317386e-05,
1843
+ "loss": 3.430949401855469,
1844
+ "step": 2620
1845
+ },
1846
+ {
1847
+ "epoch": 2.8999172870140613,
1848
+ "grad_norm": 1.470717191696167,
1849
+ "learning_rate": 9.049591598599767e-05,
1850
+ "loss": 3.439757537841797,
1851
+ "step": 2630
1852
+ },
1853
+ {
1854
+ "epoch": 2.9109456851392337,
1855
+ "grad_norm": 1.8934212923049927,
1856
+ "learning_rate": 9.043757292882148e-05,
1857
+ "loss": 3.430719757080078,
1858
+ "step": 2640
1859
+ },
1860
+ {
1861
+ "epoch": 2.9219740832644057,
1862
+ "grad_norm": 1.6267489194869995,
1863
+ "learning_rate": 9.037922987164527e-05,
1864
+ "loss": 3.4224998474121096,
1865
+ "step": 2650
1866
+ },
1867
+ {
1868
+ "epoch": 2.933002481389578,
1869
+ "grad_norm": 1.6213353872299194,
1870
+ "learning_rate": 9.032088681446908e-05,
1871
+ "loss": 3.4213233947753907,
1872
+ "step": 2660
1873
+ },
1874
+ {
1875
+ "epoch": 2.9440308795147505,
1876
+ "grad_norm": 1.961879849433899,
1877
+ "learning_rate": 9.026254375729288e-05,
1878
+ "loss": 3.4108352661132812,
1879
+ "step": 2670
1880
+ },
1881
+ {
1882
+ "epoch": 2.955059277639923,
1883
+ "grad_norm": 1.7363910675048828,
1884
+ "learning_rate": 9.020420070011669e-05,
1885
+ "loss": 3.423554229736328,
1886
+ "step": 2680
1887
+ },
1888
+ {
1889
+ "epoch": 2.9660876757650954,
1890
+ "grad_norm": 1.6161952018737793,
1891
+ "learning_rate": 9.014585764294048e-05,
1892
+ "loss": 3.418962860107422,
1893
+ "step": 2690
1894
+ },
1895
+ {
1896
+ "epoch": 2.9771160738902673,
1897
+ "grad_norm": 1.8065682649612427,
1898
+ "learning_rate": 9.008751458576429e-05,
1899
+ "loss": 3.4218765258789063,
1900
+ "step": 2700
1901
+ },
1902
+ {
1903
+ "epoch": 2.9881444720154398,
1904
+ "grad_norm": 1.4285337924957275,
1905
+ "learning_rate": 9.00291715285881e-05,
1906
+ "loss": 3.413957214355469,
1907
+ "step": 2710
1908
+ },
1909
+ {
1910
+ "epoch": 2.999172870140612,
1911
+ "grad_norm": 1.30274498462677,
1912
+ "learning_rate": 8.997082847141191e-05,
1913
+ "loss": 3.4176124572753905,
1914
+ "step": 2720
1915
+ }
1916
+ ],
1917
+ "logging_steps": 10,
1918
+ "max_steps": 18140,
1919
+ "num_input_tokens_seen": 0,
1920
+ "num_train_epochs": 20,
1921
+ "save_steps": 500,
1922
+ "stateful_callbacks": {
1923
+ "TrainerControl": {
1924
+ "args": {
1925
+ "should_epoch_stop": false,
1926
+ "should_evaluate": false,
1927
+ "should_log": false,
1928
+ "should_save": true,
1929
+ "should_training_stop": false
1930
+ },
1931
+ "attributes": {}
1932
+ }
1933
+ },
1934
+ "total_flos": 1083298732048384.0,
1935
+ "train_batch_size": 1,
1936
+ "trial_name": null,
1937
+ "trial_params": null
1938
+ }
output_qwen3_plain_ar/checkpoint-2721/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-3628/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-3628/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-3628/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step3628
output_qwen3_plain_ar/checkpoint-3628/trainer_state.json ADDED
@@ -0,0 +1,2568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 4.0,
6
+ "eval_steps": 500,
7
+ "global_step": 3628,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.011028398125172319,
14
+ "grad_norm": 435.2422180175781,
15
+ "learning_rate": 9e-07,
16
+ "loss": 20.84569549560547,
17
+ "step": 10
18
+ },
19
+ {
20
+ "epoch": 0.022056796250344637,
21
+ "grad_norm": 141.7341766357422,
22
+ "learning_rate": 1.9e-06,
23
+ "loss": 18.69615936279297,
24
+ "step": 20
25
+ },
26
+ {
27
+ "epoch": 0.033085194375516956,
28
+ "grad_norm": 74.42520904541016,
29
+ "learning_rate": 2.9e-06,
30
+ "loss": 16.079673767089844,
31
+ "step": 30
32
+ },
33
+ {
34
+ "epoch": 0.044113592500689275,
35
+ "grad_norm": 24.73248863220215,
36
+ "learning_rate": 3.9e-06,
37
+ "loss": 13.684315490722657,
38
+ "step": 40
39
+ },
40
+ {
41
+ "epoch": 0.055141990625861594,
42
+ "grad_norm": 7.049101829528809,
43
+ "learning_rate": 4.9000000000000005e-06,
44
+ "loss": 12.474874877929688,
45
+ "step": 50
46
+ },
47
+ {
48
+ "epoch": 0.06617038875103391,
49
+ "grad_norm": 2.3411474227905273,
50
+ "learning_rate": 5.9e-06,
51
+ "loss": 12.072142028808594,
52
+ "step": 60
53
+ },
54
+ {
55
+ "epoch": 0.07719878687620624,
56
+ "grad_norm": 1.126215934753418,
57
+ "learning_rate": 6.900000000000001e-06,
58
+ "loss": 11.938906860351562,
59
+ "step": 70
60
+ },
61
+ {
62
+ "epoch": 0.08822718500137855,
63
+ "grad_norm": 1.2050226926803589,
64
+ "learning_rate": 7.9e-06,
65
+ "loss": 11.81988296508789,
66
+ "step": 80
67
+ },
68
+ {
69
+ "epoch": 0.09925558312655088,
70
+ "grad_norm": 1.444793462753296,
71
+ "learning_rate": 8.9e-06,
72
+ "loss": 11.602033996582032,
73
+ "step": 90
74
+ },
75
+ {
76
+ "epoch": 0.11028398125172319,
77
+ "grad_norm": 5.791665077209473,
78
+ "learning_rate": 9.900000000000002e-06,
79
+ "loss": 11.201815032958985,
80
+ "step": 100
81
+ },
82
+ {
83
+ "epoch": 0.12131237937689551,
84
+ "grad_norm": 9.492277145385742,
85
+ "learning_rate": 1.09e-05,
86
+ "loss": 10.535708618164062,
87
+ "step": 110
88
+ },
89
+ {
90
+ "epoch": 0.13234077750206782,
91
+ "grad_norm": 2.7546133995056152,
92
+ "learning_rate": 1.19e-05,
93
+ "loss": 9.847169494628906,
94
+ "step": 120
95
+ },
96
+ {
97
+ "epoch": 0.14336917562724014,
98
+ "grad_norm": 1.0953313112258911,
99
+ "learning_rate": 1.29e-05,
100
+ "loss": 9.429026031494141,
101
+ "step": 130
102
+ },
103
+ {
104
+ "epoch": 0.15439757375241248,
105
+ "grad_norm": 0.7153559327125549,
106
+ "learning_rate": 1.3900000000000002e-05,
107
+ "loss": 9.266969299316406,
108
+ "step": 140
109
+ },
110
+ {
111
+ "epoch": 0.1654259718775848,
112
+ "grad_norm": 0.5888933539390564,
113
+ "learning_rate": 1.49e-05,
114
+ "loss": 9.1935546875,
115
+ "step": 150
116
+ },
117
+ {
118
+ "epoch": 0.1764543700027571,
119
+ "grad_norm": 0.4850365221500397,
120
+ "learning_rate": 1.59e-05,
121
+ "loss": 9.19604034423828,
122
+ "step": 160
123
+ },
124
+ {
125
+ "epoch": 0.1874827681279294,
126
+ "grad_norm": 0.5772538185119629,
127
+ "learning_rate": 1.69e-05,
128
+ "loss": 9.17010726928711,
129
+ "step": 170
130
+ },
131
+ {
132
+ "epoch": 0.19851116625310175,
133
+ "grad_norm": 0.4283920228481293,
134
+ "learning_rate": 1.79e-05,
135
+ "loss": 9.172830200195312,
136
+ "step": 180
137
+ },
138
+ {
139
+ "epoch": 0.20953956437827406,
140
+ "grad_norm": 0.8650698065757751,
141
+ "learning_rate": 1.8900000000000002e-05,
142
+ "loss": 9.154988098144532,
143
+ "step": 190
144
+ },
145
+ {
146
+ "epoch": 0.22056796250344637,
147
+ "grad_norm": 0.42017608880996704,
148
+ "learning_rate": 1.9900000000000003e-05,
149
+ "loss": 9.146849060058594,
150
+ "step": 200
151
+ },
152
+ {
153
+ "epoch": 0.23159636062861869,
154
+ "grad_norm": 0.9125994443893433,
155
+ "learning_rate": 2.09e-05,
156
+ "loss": 9.164442443847657,
157
+ "step": 210
158
+ },
159
+ {
160
+ "epoch": 0.24262475875379103,
161
+ "grad_norm": 0.6468876004219055,
162
+ "learning_rate": 2.19e-05,
163
+ "loss": 9.159596252441407,
164
+ "step": 220
165
+ },
166
+ {
167
+ "epoch": 0.25365315687896334,
168
+ "grad_norm": 0.4124819338321686,
169
+ "learning_rate": 2.29e-05,
170
+ "loss": 9.13860626220703,
171
+ "step": 230
172
+ },
173
+ {
174
+ "epoch": 0.26468155500413565,
175
+ "grad_norm": 1.990302562713623,
176
+ "learning_rate": 2.39e-05,
177
+ "loss": 9.145040893554688,
178
+ "step": 240
179
+ },
180
+ {
181
+ "epoch": 0.27570995312930796,
182
+ "grad_norm": 0.7875277400016785,
183
+ "learning_rate": 2.4900000000000002e-05,
184
+ "loss": 9.152925109863281,
185
+ "step": 250
186
+ },
187
+ {
188
+ "epoch": 0.2867383512544803,
189
+ "grad_norm": 0.8343706130981445,
190
+ "learning_rate": 2.5900000000000003e-05,
191
+ "loss": 9.132975769042968,
192
+ "step": 260
193
+ },
194
+ {
195
+ "epoch": 0.2977667493796526,
196
+ "grad_norm": 3.00996470451355,
197
+ "learning_rate": 2.6900000000000003e-05,
198
+ "loss": 9.097848510742187,
199
+ "step": 270
200
+ },
201
+ {
202
+ "epoch": 0.30879514750482495,
203
+ "grad_norm": 2.4282069206237793,
204
+ "learning_rate": 2.7900000000000004e-05,
205
+ "loss": 9.042235565185546,
206
+ "step": 280
207
+ },
208
+ {
209
+ "epoch": 0.31982354562999726,
210
+ "grad_norm": 4.171019554138184,
211
+ "learning_rate": 2.8899999999999998e-05,
212
+ "loss": 8.927298736572265,
213
+ "step": 290
214
+ },
215
+ {
216
+ "epoch": 0.3308519437551696,
217
+ "grad_norm": 2.197887659072876,
218
+ "learning_rate": 2.9900000000000002e-05,
219
+ "loss": 8.805252075195312,
220
+ "step": 300
221
+ },
222
+ {
223
+ "epoch": 0.3418803418803419,
224
+ "grad_norm": 10.306541442871094,
225
+ "learning_rate": 3.09e-05,
226
+ "loss": 8.673678588867187,
227
+ "step": 310
228
+ },
229
+ {
230
+ "epoch": 0.3529087400055142,
231
+ "grad_norm": 8.463860511779785,
232
+ "learning_rate": 3.19e-05,
233
+ "loss": 8.570347595214844,
234
+ "step": 320
235
+ },
236
+ {
237
+ "epoch": 0.3639371381306865,
238
+ "grad_norm": 3.999753475189209,
239
+ "learning_rate": 3.29e-05,
240
+ "loss": 8.429109191894531,
241
+ "step": 330
242
+ },
243
+ {
244
+ "epoch": 0.3749655362558588,
245
+ "grad_norm": 5.259007930755615,
246
+ "learning_rate": 3.3900000000000004e-05,
247
+ "loss": 8.334149169921876,
248
+ "step": 340
249
+ },
250
+ {
251
+ "epoch": 0.38599393438103113,
252
+ "grad_norm": 8.362598419189453,
253
+ "learning_rate": 3.49e-05,
254
+ "loss": 8.196139526367187,
255
+ "step": 350
256
+ },
257
+ {
258
+ "epoch": 0.3970223325062035,
259
+ "grad_norm": 10.273512840270996,
260
+ "learning_rate": 3.59e-05,
261
+ "loss": 8.040153503417969,
262
+ "step": 360
263
+ },
264
+ {
265
+ "epoch": 0.4080507306313758,
266
+ "grad_norm": 5.111108303070068,
267
+ "learning_rate": 3.69e-05,
268
+ "loss": 7.866473388671875,
269
+ "step": 370
270
+ },
271
+ {
272
+ "epoch": 0.4190791287565481,
273
+ "grad_norm": 9.192107200622559,
274
+ "learning_rate": 3.79e-05,
275
+ "loss": 7.695774841308594,
276
+ "step": 380
277
+ },
278
+ {
279
+ "epoch": 0.43010752688172044,
280
+ "grad_norm": 5.393336772918701,
281
+ "learning_rate": 3.8900000000000004e-05,
282
+ "loss": 7.498152160644532,
283
+ "step": 390
284
+ },
285
+ {
286
+ "epoch": 0.44113592500689275,
287
+ "grad_norm": 10.53490161895752,
288
+ "learning_rate": 3.99e-05,
289
+ "loss": 7.270246887207032,
290
+ "step": 400
291
+ },
292
+ {
293
+ "epoch": 0.45216432313206506,
294
+ "grad_norm": 6.174643516540527,
295
+ "learning_rate": 4.09e-05,
296
+ "loss": 7.127191162109375,
297
+ "step": 410
298
+ },
299
+ {
300
+ "epoch": 0.46319272125723737,
301
+ "grad_norm": 4.522936820983887,
302
+ "learning_rate": 4.19e-05,
303
+ "loss": 6.871500396728516,
304
+ "step": 420
305
+ },
306
+ {
307
+ "epoch": 0.4742211193824097,
308
+ "grad_norm": 4.3594207763671875,
309
+ "learning_rate": 4.29e-05,
310
+ "loss": 6.702586364746094,
311
+ "step": 430
312
+ },
313
+ {
314
+ "epoch": 0.48524951750758205,
315
+ "grad_norm": 5.950730323791504,
316
+ "learning_rate": 4.39e-05,
317
+ "loss": 6.493560791015625,
318
+ "step": 440
319
+ },
320
+ {
321
+ "epoch": 0.49627791563275436,
322
+ "grad_norm": 6.233413219451904,
323
+ "learning_rate": 4.49e-05,
324
+ "loss": 6.293489074707031,
325
+ "step": 450
326
+ },
327
+ {
328
+ "epoch": 0.5073063137579267,
329
+ "grad_norm": 7.656834125518799,
330
+ "learning_rate": 4.5900000000000004e-05,
331
+ "loss": 6.102347946166992,
332
+ "step": 460
333
+ },
334
+ {
335
+ "epoch": 0.518334711883099,
336
+ "grad_norm": 4.319094657897949,
337
+ "learning_rate": 4.69e-05,
338
+ "loss": 5.928083419799805,
339
+ "step": 470
340
+ },
341
+ {
342
+ "epoch": 0.5293631100082713,
343
+ "grad_norm": 5.585537433624268,
344
+ "learning_rate": 4.79e-05,
345
+ "loss": 5.77436637878418,
346
+ "step": 480
347
+ },
348
+ {
349
+ "epoch": 0.5403915081334436,
350
+ "grad_norm": 5.104014873504639,
351
+ "learning_rate": 4.89e-05,
352
+ "loss": 5.636859130859375,
353
+ "step": 490
354
+ },
355
+ {
356
+ "epoch": 0.5514199062586159,
357
+ "grad_norm": 5.453028202056885,
358
+ "learning_rate": 4.99e-05,
359
+ "loss": 5.507636260986328,
360
+ "step": 500
361
+ },
362
+ {
363
+ "epoch": 0.5624483043837882,
364
+ "grad_norm": 7.728854179382324,
365
+ "learning_rate": 5.0900000000000004e-05,
366
+ "loss": 5.411964416503906,
367
+ "step": 510
368
+ },
369
+ {
370
+ "epoch": 0.5734767025089605,
371
+ "grad_norm": 4.50288724899292,
372
+ "learning_rate": 5.19e-05,
373
+ "loss": 5.295291900634766,
374
+ "step": 520
375
+ },
376
+ {
377
+ "epoch": 0.5845051006341329,
378
+ "grad_norm": 4.245919704437256,
379
+ "learning_rate": 5.2900000000000005e-05,
380
+ "loss": 5.194162750244141,
381
+ "step": 530
382
+ },
383
+ {
384
+ "epoch": 0.5955334987593052,
385
+ "grad_norm": 6.278975963592529,
386
+ "learning_rate": 5.390000000000001e-05,
387
+ "loss": 5.113618087768555,
388
+ "step": 540
389
+ },
390
+ {
391
+ "epoch": 0.6065618968844775,
392
+ "grad_norm": 4.214662075042725,
393
+ "learning_rate": 5.4900000000000006e-05,
394
+ "loss": 5.038372039794922,
395
+ "step": 550
396
+ },
397
+ {
398
+ "epoch": 0.6175902950096499,
399
+ "grad_norm": 3.5404605865478516,
400
+ "learning_rate": 5.590000000000001e-05,
401
+ "loss": 4.935391235351562,
402
+ "step": 560
403
+ },
404
+ {
405
+ "epoch": 0.6286186931348222,
406
+ "grad_norm": 3.6460280418395996,
407
+ "learning_rate": 5.69e-05,
408
+ "loss": 4.896538543701172,
409
+ "step": 570
410
+ },
411
+ {
412
+ "epoch": 0.6396470912599945,
413
+ "grad_norm": 5.254800796508789,
414
+ "learning_rate": 5.79e-05,
415
+ "loss": 4.829419708251953,
416
+ "step": 580
417
+ },
418
+ {
419
+ "epoch": 0.6506754893851668,
420
+ "grad_norm": 5.132180690765381,
421
+ "learning_rate": 5.89e-05,
422
+ "loss": 4.793368148803711,
423
+ "step": 590
424
+ },
425
+ {
426
+ "epoch": 0.6617038875103392,
427
+ "grad_norm": 4.222960948944092,
428
+ "learning_rate": 5.99e-05,
429
+ "loss": 4.746239852905274,
430
+ "step": 600
431
+ },
432
+ {
433
+ "epoch": 0.6727322856355115,
434
+ "grad_norm": 4.070414066314697,
435
+ "learning_rate": 6.09e-05,
436
+ "loss": 4.688523864746093,
437
+ "step": 610
438
+ },
439
+ {
440
+ "epoch": 0.6837606837606838,
441
+ "grad_norm": 3.4652583599090576,
442
+ "learning_rate": 6.19e-05,
443
+ "loss": 4.692922973632813,
444
+ "step": 620
445
+ },
446
+ {
447
+ "epoch": 0.6947890818858561,
448
+ "grad_norm": 4.559128284454346,
449
+ "learning_rate": 6.29e-05,
450
+ "loss": 4.639920043945312,
451
+ "step": 630
452
+ },
453
+ {
454
+ "epoch": 0.7058174800110284,
455
+ "grad_norm": 3.197758436203003,
456
+ "learning_rate": 6.390000000000001e-05,
457
+ "loss": 4.601907348632812,
458
+ "step": 640
459
+ },
460
+ {
461
+ "epoch": 0.7168458781362007,
462
+ "grad_norm": 4.209578514099121,
463
+ "learning_rate": 6.49e-05,
464
+ "loss": 4.56639404296875,
465
+ "step": 650
466
+ },
467
+ {
468
+ "epoch": 0.727874276261373,
469
+ "grad_norm": 3.701484203338623,
470
+ "learning_rate": 6.59e-05,
471
+ "loss": 4.545608901977539,
472
+ "step": 660
473
+ },
474
+ {
475
+ "epoch": 0.7389026743865453,
476
+ "grad_norm": 3.951927900314331,
477
+ "learning_rate": 6.690000000000001e-05,
478
+ "loss": 4.493326187133789,
479
+ "step": 670
480
+ },
481
+ {
482
+ "epoch": 0.7499310725117176,
483
+ "grad_norm": 4.219130039215088,
484
+ "learning_rate": 6.790000000000001e-05,
485
+ "loss": 4.482691955566406,
486
+ "step": 680
487
+ },
488
+ {
489
+ "epoch": 0.76095947063689,
490
+ "grad_norm": 6.267204284667969,
491
+ "learning_rate": 6.89e-05,
492
+ "loss": 4.4599052429199215,
493
+ "step": 690
494
+ },
495
+ {
496
+ "epoch": 0.7719878687620623,
497
+ "grad_norm": 3.367382764816284,
498
+ "learning_rate": 6.99e-05,
499
+ "loss": 4.429808807373047,
500
+ "step": 700
501
+ },
502
+ {
503
+ "epoch": 0.7830162668872346,
504
+ "grad_norm": 3.8906455039978027,
505
+ "learning_rate": 7.09e-05,
506
+ "loss": 4.4144752502441404,
507
+ "step": 710
508
+ },
509
+ {
510
+ "epoch": 0.794044665012407,
511
+ "grad_norm": 6.759398460388184,
512
+ "learning_rate": 7.19e-05,
513
+ "loss": 4.385488891601563,
514
+ "step": 720
515
+ },
516
+ {
517
+ "epoch": 0.8050730631375793,
518
+ "grad_norm": 3.520167350769043,
519
+ "learning_rate": 7.29e-05,
520
+ "loss": 4.397706985473633,
521
+ "step": 730
522
+ },
523
+ {
524
+ "epoch": 0.8161014612627516,
525
+ "grad_norm": 2.7510974407196045,
526
+ "learning_rate": 7.390000000000001e-05,
527
+ "loss": 4.374617385864258,
528
+ "step": 740
529
+ },
530
+ {
531
+ "epoch": 0.8271298593879239,
532
+ "grad_norm": 4.395699977874756,
533
+ "learning_rate": 7.49e-05,
534
+ "loss": 4.3302146911621096,
535
+ "step": 750
536
+ },
537
+ {
538
+ "epoch": 0.8381582575130962,
539
+ "grad_norm": 3.277766704559326,
540
+ "learning_rate": 7.59e-05,
541
+ "loss": 4.313335418701172,
542
+ "step": 760
543
+ },
544
+ {
545
+ "epoch": 0.8491866556382686,
546
+ "grad_norm": 2.466207981109619,
547
+ "learning_rate": 7.69e-05,
548
+ "loss": 4.3226570129394535,
549
+ "step": 770
550
+ },
551
+ {
552
+ "epoch": 0.8602150537634409,
553
+ "grad_norm": 3.637355327606201,
554
+ "learning_rate": 7.790000000000001e-05,
555
+ "loss": 4.295929718017578,
556
+ "step": 780
557
+ },
558
+ {
559
+ "epoch": 0.8712434518886132,
560
+ "grad_norm": 3.155527353286743,
561
+ "learning_rate": 7.890000000000001e-05,
562
+ "loss": 4.287591552734375,
563
+ "step": 790
564
+ },
565
+ {
566
+ "epoch": 0.8822718500137855,
567
+ "grad_norm": 3.593884229660034,
568
+ "learning_rate": 7.99e-05,
569
+ "loss": 4.267314147949219,
570
+ "step": 800
571
+ },
572
+ {
573
+ "epoch": 0.8933002481389578,
574
+ "grad_norm": 2.361081123352051,
575
+ "learning_rate": 8.090000000000001e-05,
576
+ "loss": 4.265741348266602,
577
+ "step": 810
578
+ },
579
+ {
580
+ "epoch": 0.9043286462641301,
581
+ "grad_norm": 2.7084105014801025,
582
+ "learning_rate": 8.19e-05,
583
+ "loss": 4.261878204345703,
584
+ "step": 820
585
+ },
586
+ {
587
+ "epoch": 0.9153570443893024,
588
+ "grad_norm": 3.6093873977661133,
589
+ "learning_rate": 8.29e-05,
590
+ "loss": 4.211677551269531,
591
+ "step": 830
592
+ },
593
+ {
594
+ "epoch": 0.9263854425144747,
595
+ "grad_norm": 3.9739396572113037,
596
+ "learning_rate": 8.39e-05,
597
+ "loss": 4.224007034301758,
598
+ "step": 840
599
+ },
600
+ {
601
+ "epoch": 0.9374138406396471,
602
+ "grad_norm": 2.174050807952881,
603
+ "learning_rate": 8.49e-05,
604
+ "loss": 4.211782836914063,
605
+ "step": 850
606
+ },
607
+ {
608
+ "epoch": 0.9484422387648194,
609
+ "grad_norm": 2.7151405811309814,
610
+ "learning_rate": 8.59e-05,
611
+ "loss": 4.204391098022461,
612
+ "step": 860
613
+ },
614
+ {
615
+ "epoch": 0.9594706368899917,
616
+ "grad_norm": 3.7480661869049072,
617
+ "learning_rate": 8.69e-05,
618
+ "loss": 4.175582504272461,
619
+ "step": 870
620
+ },
621
+ {
622
+ "epoch": 0.9704990350151641,
623
+ "grad_norm": 3.1127700805664062,
624
+ "learning_rate": 8.790000000000001e-05,
625
+ "loss": 4.183733749389648,
626
+ "step": 880
627
+ },
628
+ {
629
+ "epoch": 0.9815274331403364,
630
+ "grad_norm": 2.750716209411621,
631
+ "learning_rate": 8.89e-05,
632
+ "loss": 4.167971801757813,
633
+ "step": 890
634
+ },
635
+ {
636
+ "epoch": 0.9925558312655087,
637
+ "grad_norm": 4.02509880065918,
638
+ "learning_rate": 8.99e-05,
639
+ "loss": 4.170472717285156,
640
+ "step": 900
641
+ },
642
+ {
643
+ "epoch": 1.0033085194375517,
644
+ "grad_norm": 3.0058505535125732,
645
+ "learning_rate": 9.090000000000001e-05,
646
+ "loss": 4.1449127197265625,
647
+ "step": 910
648
+ },
649
+ {
650
+ "epoch": 1.014336917562724,
651
+ "grad_norm": 2.553403377532959,
652
+ "learning_rate": 9.190000000000001e-05,
653
+ "loss": 4.1404258728027346,
654
+ "step": 920
655
+ },
656
+ {
657
+ "epoch": 1.0253653156878964,
658
+ "grad_norm": 2.8066084384918213,
659
+ "learning_rate": 9.290000000000001e-05,
660
+ "loss": 4.110780334472656,
661
+ "step": 930
662
+ },
663
+ {
664
+ "epoch": 1.0363937138130686,
665
+ "grad_norm": 3.904608726501465,
666
+ "learning_rate": 9.39e-05,
667
+ "loss": 4.134862899780273,
668
+ "step": 940
669
+ },
670
+ {
671
+ "epoch": 1.047422111938241,
672
+ "grad_norm": 2.217729330062866,
673
+ "learning_rate": 9.49e-05,
674
+ "loss": 4.112079620361328,
675
+ "step": 950
676
+ },
677
+ {
678
+ "epoch": 1.0584505100634134,
679
+ "grad_norm": 2.498760938644409,
680
+ "learning_rate": 9.59e-05,
681
+ "loss": 4.097566986083985,
682
+ "step": 960
683
+ },
684
+ {
685
+ "epoch": 1.0694789081885856,
686
+ "grad_norm": 3.577143907546997,
687
+ "learning_rate": 9.69e-05,
688
+ "loss": 4.081307220458984,
689
+ "step": 970
690
+ },
691
+ {
692
+ "epoch": 1.080507306313758,
693
+ "grad_norm": 3.283250570297241,
694
+ "learning_rate": 9.790000000000001e-05,
695
+ "loss": 4.103987503051758,
696
+ "step": 980
697
+ },
698
+ {
699
+ "epoch": 1.0915357044389302,
700
+ "grad_norm": 2.1897776126861572,
701
+ "learning_rate": 9.89e-05,
702
+ "loss": 4.084938812255859,
703
+ "step": 990
704
+ },
705
+ {
706
+ "epoch": 1.1025641025641026,
707
+ "grad_norm": 2.6925997734069824,
708
+ "learning_rate": 9.99e-05,
709
+ "loss": 4.058921051025391,
710
+ "step": 1000
711
+ },
712
+ {
713
+ "epoch": 1.1135925006892748,
714
+ "grad_norm": 3.4118456840515137,
715
+ "learning_rate": 9.994749124854142e-05,
716
+ "loss": 4.061585235595703,
717
+ "step": 1010
718
+ },
719
+ {
720
+ "epoch": 1.1246208988144473,
721
+ "grad_norm": 2.6139297485351562,
722
+ "learning_rate": 9.988914819136523e-05,
723
+ "loss": 4.070050048828125,
724
+ "step": 1020
725
+ },
726
+ {
727
+ "epoch": 1.1356492969396195,
728
+ "grad_norm": 1.8616399765014648,
729
+ "learning_rate": 9.983080513418903e-05,
730
+ "loss": 4.0413330078125,
731
+ "step": 1030
732
+ },
733
+ {
734
+ "epoch": 1.146677695064792,
735
+ "grad_norm": 2.361706018447876,
736
+ "learning_rate": 9.977246207701284e-05,
737
+ "loss": 4.023075866699219,
738
+ "step": 1040
739
+ },
740
+ {
741
+ "epoch": 1.157706093189964,
742
+ "grad_norm": 3.815014123916626,
743
+ "learning_rate": 9.971411901983664e-05,
744
+ "loss": 4.036756134033203,
745
+ "step": 1050
746
+ },
747
+ {
748
+ "epoch": 1.1687344913151365,
749
+ "grad_norm": 2.4410274028778076,
750
+ "learning_rate": 9.965577596266045e-05,
751
+ "loss": 4.020483779907226,
752
+ "step": 1060
753
+ },
754
+ {
755
+ "epoch": 1.1797628894403087,
756
+ "grad_norm": 2.768084764480591,
757
+ "learning_rate": 9.959743290548426e-05,
758
+ "loss": 4.021839141845703,
759
+ "step": 1070
760
+ },
761
+ {
762
+ "epoch": 1.1907912875654811,
763
+ "grad_norm": 1.9342570304870605,
764
+ "learning_rate": 9.953908984830806e-05,
765
+ "loss": 4.026360321044922,
766
+ "step": 1080
767
+ },
768
+ {
769
+ "epoch": 1.2018196856906533,
770
+ "grad_norm": 2.8184762001037598,
771
+ "learning_rate": 9.948074679113187e-05,
772
+ "loss": 4.007581329345703,
773
+ "step": 1090
774
+ },
775
+ {
776
+ "epoch": 1.2128480838158258,
777
+ "grad_norm": 3.2656188011169434,
778
+ "learning_rate": 9.942240373395566e-05,
779
+ "loss": 3.9965087890625,
780
+ "step": 1100
781
+ },
782
+ {
783
+ "epoch": 1.223876481940998,
784
+ "grad_norm": 2.4359538555145264,
785
+ "learning_rate": 9.936406067677947e-05,
786
+ "loss": 3.9959388732910157,
787
+ "step": 1110
788
+ },
789
+ {
790
+ "epoch": 1.2349048800661704,
791
+ "grad_norm": 1.9357632398605347,
792
+ "learning_rate": 9.930571761960327e-05,
793
+ "loss": 3.9851417541503906,
794
+ "step": 1120
795
+ },
796
+ {
797
+ "epoch": 1.2459332781913428,
798
+ "grad_norm": 2.1269352436065674,
799
+ "learning_rate": 9.924737456242708e-05,
800
+ "loss": 3.9773223876953123,
801
+ "step": 1130
802
+ },
803
+ {
804
+ "epoch": 1.256961676316515,
805
+ "grad_norm": 3.3491597175598145,
806
+ "learning_rate": 9.918903150525088e-05,
807
+ "loss": 3.9877471923828125,
808
+ "step": 1140
809
+ },
810
+ {
811
+ "epoch": 1.2679900744416872,
812
+ "grad_norm": 1.8646328449249268,
813
+ "learning_rate": 9.913068844807468e-05,
814
+ "loss": 3.9694965362548826,
815
+ "step": 1150
816
+ },
817
+ {
818
+ "epoch": 1.2790184725668596,
819
+ "grad_norm": 2.6204631328582764,
820
+ "learning_rate": 9.907234539089849e-05,
821
+ "loss": 3.9611881256103514,
822
+ "step": 1160
823
+ },
824
+ {
825
+ "epoch": 1.290046870692032,
826
+ "grad_norm": 1.872028112411499,
827
+ "learning_rate": 9.901400233372228e-05,
828
+ "loss": 3.964163970947266,
829
+ "step": 1170
830
+ },
831
+ {
832
+ "epoch": 1.3010752688172043,
833
+ "grad_norm": 3.490435838699341,
834
+ "learning_rate": 9.895565927654609e-05,
835
+ "loss": 3.959897994995117,
836
+ "step": 1180
837
+ },
838
+ {
839
+ "epoch": 1.3121036669423767,
840
+ "grad_norm": 2.862489700317383,
841
+ "learning_rate": 9.88973162193699e-05,
842
+ "loss": 3.9567939758300783,
843
+ "step": 1190
844
+ },
845
+ {
846
+ "epoch": 1.3231320650675489,
847
+ "grad_norm": 3.0570664405822754,
848
+ "learning_rate": 9.883897316219371e-05,
849
+ "loss": 3.9470645904541017,
850
+ "step": 1200
851
+ },
852
+ {
853
+ "epoch": 1.3341604631927213,
854
+ "grad_norm": 1.9254627227783203,
855
+ "learning_rate": 9.878063010501752e-05,
856
+ "loss": 3.9442317962646483,
857
+ "step": 1210
858
+ },
859
+ {
860
+ "epoch": 1.3451888613178935,
861
+ "grad_norm": 3.606224298477173,
862
+ "learning_rate": 9.872228704784131e-05,
863
+ "loss": 3.9380733489990236,
864
+ "step": 1220
865
+ },
866
+ {
867
+ "epoch": 1.356217259443066,
868
+ "grad_norm": 2.1184027194976807,
869
+ "learning_rate": 9.866394399066512e-05,
870
+ "loss": 3.9452835083007813,
871
+ "step": 1230
872
+ },
873
+ {
874
+ "epoch": 1.3672456575682381,
875
+ "grad_norm": 1.8997142314910889,
876
+ "learning_rate": 9.860560093348892e-05,
877
+ "loss": 3.9270603179931642,
878
+ "step": 1240
879
+ },
880
+ {
881
+ "epoch": 1.3782740556934105,
882
+ "grad_norm": 2.9672305583953857,
883
+ "learning_rate": 9.854725787631273e-05,
884
+ "loss": 3.9120155334472657,
885
+ "step": 1250
886
+ },
887
+ {
888
+ "epoch": 1.389302453818583,
889
+ "grad_norm": 1.9220951795578003,
890
+ "learning_rate": 9.848891481913652e-05,
891
+ "loss": 3.900279235839844,
892
+ "step": 1260
893
+ },
894
+ {
895
+ "epoch": 1.4003308519437552,
896
+ "grad_norm": 2.013521194458008,
897
+ "learning_rate": 9.843057176196033e-05,
898
+ "loss": 3.9147193908691404,
899
+ "step": 1270
900
+ },
901
+ {
902
+ "epoch": 1.4113592500689274,
903
+ "grad_norm": 1.451686143875122,
904
+ "learning_rate": 9.837222870478413e-05,
905
+ "loss": 3.906220245361328,
906
+ "step": 1280
907
+ },
908
+ {
909
+ "epoch": 1.4223876481940998,
910
+ "grad_norm": 4.606860637664795,
911
+ "learning_rate": 9.831388564760794e-05,
912
+ "loss": 3.905352020263672,
913
+ "step": 1290
914
+ },
915
+ {
916
+ "epoch": 1.4334160463192722,
917
+ "grad_norm": 1.779123306274414,
918
+ "learning_rate": 9.825554259043175e-05,
919
+ "loss": 3.9137496948242188,
920
+ "step": 1300
921
+ },
922
+ {
923
+ "epoch": 1.4444444444444444,
924
+ "grad_norm": 2.086585521697998,
925
+ "learning_rate": 9.819719953325554e-05,
926
+ "loss": 3.89554443359375,
927
+ "step": 1310
928
+ },
929
+ {
930
+ "epoch": 1.4554728425696168,
931
+ "grad_norm": 3.3514609336853027,
932
+ "learning_rate": 9.813885647607935e-05,
933
+ "loss": 3.8901123046875,
934
+ "step": 1320
935
+ },
936
+ {
937
+ "epoch": 1.466501240694789,
938
+ "grad_norm": 2.1145269870758057,
939
+ "learning_rate": 9.808051341890316e-05,
940
+ "loss": 3.8892486572265623,
941
+ "step": 1330
942
+ },
943
+ {
944
+ "epoch": 1.4775296388199615,
945
+ "grad_norm": 1.5503329038619995,
946
+ "learning_rate": 9.802217036172697e-05,
947
+ "loss": 3.8922355651855467,
948
+ "step": 1340
949
+ },
950
+ {
951
+ "epoch": 1.4885580369451337,
952
+ "grad_norm": 2.3014304637908936,
953
+ "learning_rate": 9.796382730455076e-05,
954
+ "loss": 3.8860099792480467,
955
+ "step": 1350
956
+ },
957
+ {
958
+ "epoch": 1.499586435070306,
959
+ "grad_norm": 1.9633557796478271,
960
+ "learning_rate": 9.790548424737457e-05,
961
+ "loss": 3.875183868408203,
962
+ "step": 1360
963
+ },
964
+ {
965
+ "epoch": 1.5106148331954783,
966
+ "grad_norm": 2.228351593017578,
967
+ "learning_rate": 9.784714119019837e-05,
968
+ "loss": 3.8726768493652344,
969
+ "step": 1370
970
+ },
971
+ {
972
+ "epoch": 1.5216432313206507,
973
+ "grad_norm": 3.0888657569885254,
974
+ "learning_rate": 9.778879813302218e-05,
975
+ "loss": 3.872690963745117,
976
+ "step": 1380
977
+ },
978
+ {
979
+ "epoch": 1.5326716294458231,
980
+ "grad_norm": 2.0078868865966797,
981
+ "learning_rate": 9.773045507584599e-05,
982
+ "loss": 3.8612388610839843,
983
+ "step": 1390
984
+ },
985
+ {
986
+ "epoch": 1.5437000275709953,
987
+ "grad_norm": 2.1966569423675537,
988
+ "learning_rate": 9.767211201866978e-05,
989
+ "loss": 3.8649852752685545,
990
+ "step": 1400
991
+ },
992
+ {
993
+ "epoch": 1.5547284256961675,
994
+ "grad_norm": 2.1047487258911133,
995
+ "learning_rate": 9.761376896149359e-05,
996
+ "loss": 3.8632328033447267,
997
+ "step": 1410
998
+ },
999
+ {
1000
+ "epoch": 1.56575682382134,
1001
+ "grad_norm": 1.9347233772277832,
1002
+ "learning_rate": 9.755542590431739e-05,
1003
+ "loss": 3.8362571716308596,
1004
+ "step": 1420
1005
+ },
1006
+ {
1007
+ "epoch": 1.5767852219465124,
1008
+ "grad_norm": 1.7961437702178955,
1009
+ "learning_rate": 9.74970828471412e-05,
1010
+ "loss": 3.8461585998535157,
1011
+ "step": 1430
1012
+ },
1013
+ {
1014
+ "epoch": 1.5878136200716846,
1015
+ "grad_norm": 2.4657342433929443,
1016
+ "learning_rate": 9.743873978996499e-05,
1017
+ "loss": 3.842551040649414,
1018
+ "step": 1440
1019
+ },
1020
+ {
1021
+ "epoch": 1.5988420181968568,
1022
+ "grad_norm": 2.043138027191162,
1023
+ "learning_rate": 9.73803967327888e-05,
1024
+ "loss": 3.8387855529785155,
1025
+ "step": 1450
1026
+ },
1027
+ {
1028
+ "epoch": 1.6098704163220292,
1029
+ "grad_norm": 3.732532262802124,
1030
+ "learning_rate": 9.732205367561261e-05,
1031
+ "loss": 3.8399681091308593,
1032
+ "step": 1460
1033
+ },
1034
+ {
1035
+ "epoch": 1.6208988144472016,
1036
+ "grad_norm": 2.43684720993042,
1037
+ "learning_rate": 9.726371061843642e-05,
1038
+ "loss": 3.8324966430664062,
1039
+ "step": 1470
1040
+ },
1041
+ {
1042
+ "epoch": 1.6319272125723738,
1043
+ "grad_norm": 2.4433460235595703,
1044
+ "learning_rate": 9.720536756126023e-05,
1045
+ "loss": 3.817783737182617,
1046
+ "step": 1480
1047
+ },
1048
+ {
1049
+ "epoch": 1.642955610697546,
1050
+ "grad_norm": 2.1049606800079346,
1051
+ "learning_rate": 9.714702450408402e-05,
1052
+ "loss": 3.804280090332031,
1053
+ "step": 1490
1054
+ },
1055
+ {
1056
+ "epoch": 1.6539840088227185,
1057
+ "grad_norm": 3.529686450958252,
1058
+ "learning_rate": 9.708868144690783e-05,
1059
+ "loss": 3.805449295043945,
1060
+ "step": 1500
1061
+ },
1062
+ {
1063
+ "epoch": 1.6650124069478909,
1064
+ "grad_norm": 2.0984089374542236,
1065
+ "learning_rate": 9.703033838973162e-05,
1066
+ "loss": 3.788246917724609,
1067
+ "step": 1510
1068
+ },
1069
+ {
1070
+ "epoch": 1.6760408050730633,
1071
+ "grad_norm": 1.9434291124343872,
1072
+ "learning_rate": 9.697199533255543e-05,
1073
+ "loss": 3.7875442504882812,
1074
+ "step": 1520
1075
+ },
1076
+ {
1077
+ "epoch": 1.6870692031982355,
1078
+ "grad_norm": 1.99173903465271,
1079
+ "learning_rate": 9.691365227537923e-05,
1080
+ "loss": 3.7807193756103517,
1081
+ "step": 1530
1082
+ },
1083
+ {
1084
+ "epoch": 1.6980976013234077,
1085
+ "grad_norm": 2.5006911754608154,
1086
+ "learning_rate": 9.685530921820304e-05,
1087
+ "loss": 3.744763946533203,
1088
+ "step": 1540
1089
+ },
1090
+ {
1091
+ "epoch": 1.7091259994485801,
1092
+ "grad_norm": 2.1816165447235107,
1093
+ "learning_rate": 9.679696616102685e-05,
1094
+ "loss": 3.760245513916016,
1095
+ "step": 1550
1096
+ },
1097
+ {
1098
+ "epoch": 1.7201543975737525,
1099
+ "grad_norm": 2.123291492462158,
1100
+ "learning_rate": 9.673862310385064e-05,
1101
+ "loss": 3.738916778564453,
1102
+ "step": 1560
1103
+ },
1104
+ {
1105
+ "epoch": 1.7311827956989247,
1106
+ "grad_norm": 2.378187894821167,
1107
+ "learning_rate": 9.668028004667445e-05,
1108
+ "loss": 3.734139251708984,
1109
+ "step": 1570
1110
+ },
1111
+ {
1112
+ "epoch": 1.742211193824097,
1113
+ "grad_norm": 2.54819393157959,
1114
+ "learning_rate": 9.662193698949825e-05,
1115
+ "loss": 3.715302276611328,
1116
+ "step": 1580
1117
+ },
1118
+ {
1119
+ "epoch": 1.7532395919492694,
1120
+ "grad_norm": 4.285822868347168,
1121
+ "learning_rate": 9.656359393232206e-05,
1122
+ "loss": 3.72213134765625,
1123
+ "step": 1590
1124
+ },
1125
+ {
1126
+ "epoch": 1.7642679900744418,
1127
+ "grad_norm": 1.8676700592041016,
1128
+ "learning_rate": 9.650525087514586e-05,
1129
+ "loss": 3.7252479553222657,
1130
+ "step": 1600
1131
+ },
1132
+ {
1133
+ "epoch": 1.775296388199614,
1134
+ "grad_norm": 1.6977792978286743,
1135
+ "learning_rate": 9.644690781796967e-05,
1136
+ "loss": 3.704994964599609,
1137
+ "step": 1610
1138
+ },
1139
+ {
1140
+ "epoch": 1.7863247863247862,
1141
+ "grad_norm": 1.8334232568740845,
1142
+ "learning_rate": 9.638856476079347e-05,
1143
+ "loss": 3.6980815887451173,
1144
+ "step": 1620
1145
+ },
1146
+ {
1147
+ "epoch": 1.7973531844499586,
1148
+ "grad_norm": 2.6574559211730957,
1149
+ "learning_rate": 9.633022170361728e-05,
1150
+ "loss": 3.683759307861328,
1151
+ "step": 1630
1152
+ },
1153
+ {
1154
+ "epoch": 1.808381582575131,
1155
+ "grad_norm": 2.085084915161133,
1156
+ "learning_rate": 9.627187864644109e-05,
1157
+ "loss": 3.67755126953125,
1158
+ "step": 1640
1159
+ },
1160
+ {
1161
+ "epoch": 1.8194099807003032,
1162
+ "grad_norm": 1.685441017150879,
1163
+ "learning_rate": 9.621353558926488e-05,
1164
+ "loss": 3.656099319458008,
1165
+ "step": 1650
1166
+ },
1167
+ {
1168
+ "epoch": 1.8304383788254754,
1169
+ "grad_norm": 2.4462475776672363,
1170
+ "learning_rate": 9.615519253208869e-05,
1171
+ "loss": 3.668656921386719,
1172
+ "step": 1660
1173
+ },
1174
+ {
1175
+ "epoch": 1.8414667769506479,
1176
+ "grad_norm": 1.54155433177948,
1177
+ "learning_rate": 9.609684947491249e-05,
1178
+ "loss": 3.66968994140625,
1179
+ "step": 1670
1180
+ },
1181
+ {
1182
+ "epoch": 1.8524951750758203,
1183
+ "grad_norm": 3.862130880355835,
1184
+ "learning_rate": 9.60385064177363e-05,
1185
+ "loss": 3.6412506103515625,
1186
+ "step": 1680
1187
+ },
1188
+ {
1189
+ "epoch": 1.8635235732009927,
1190
+ "grad_norm": 1.7317070960998535,
1191
+ "learning_rate": 9.598016336056009e-05,
1192
+ "loss": 3.639806365966797,
1193
+ "step": 1690
1194
+ },
1195
+ {
1196
+ "epoch": 1.874551971326165,
1197
+ "grad_norm": 2.2640931606292725,
1198
+ "learning_rate": 9.59218203033839e-05,
1199
+ "loss": 3.6341064453125,
1200
+ "step": 1700
1201
+ },
1202
+ {
1203
+ "epoch": 1.8855803694513371,
1204
+ "grad_norm": 3.653146743774414,
1205
+ "learning_rate": 9.586347724620771e-05,
1206
+ "loss": 3.6380882263183594,
1207
+ "step": 1710
1208
+ },
1209
+ {
1210
+ "epoch": 1.8966087675765095,
1211
+ "grad_norm": 1.8987306356430054,
1212
+ "learning_rate": 9.58051341890315e-05,
1213
+ "loss": 3.6405975341796877,
1214
+ "step": 1720
1215
+ },
1216
+ {
1217
+ "epoch": 1.907637165701682,
1218
+ "grad_norm": 2.202659845352173,
1219
+ "learning_rate": 9.574679113185531e-05,
1220
+ "loss": 3.6375991821289064,
1221
+ "step": 1730
1222
+ },
1223
+ {
1224
+ "epoch": 1.9186655638268542,
1225
+ "grad_norm": 1.5091872215270996,
1226
+ "learning_rate": 9.568844807467912e-05,
1227
+ "loss": 3.6208465576171873,
1228
+ "step": 1740
1229
+ },
1230
+ {
1231
+ "epoch": 1.9296939619520264,
1232
+ "grad_norm": 1.9811325073242188,
1233
+ "learning_rate": 9.563010501750293e-05,
1234
+ "loss": 3.600755310058594,
1235
+ "step": 1750
1236
+ },
1237
+ {
1238
+ "epoch": 1.9407223600771988,
1239
+ "grad_norm": 3.184499979019165,
1240
+ "learning_rate": 9.557176196032673e-05,
1241
+ "loss": 3.6109405517578126,
1242
+ "step": 1760
1243
+ },
1244
+ {
1245
+ "epoch": 1.9517507582023712,
1246
+ "grad_norm": 2.340125322341919,
1247
+ "learning_rate": 9.551341890315054e-05,
1248
+ "loss": 3.6129817962646484,
1249
+ "step": 1770
1250
+ },
1251
+ {
1252
+ "epoch": 1.9627791563275434,
1253
+ "grad_norm": 1.7258495092391968,
1254
+ "learning_rate": 9.545507584597433e-05,
1255
+ "loss": 3.590809631347656,
1256
+ "step": 1780
1257
+ },
1258
+ {
1259
+ "epoch": 1.9738075544527156,
1260
+ "grad_norm": 1.6129754781723022,
1261
+ "learning_rate": 9.539673278879814e-05,
1262
+ "loss": 3.5866302490234374,
1263
+ "step": 1790
1264
+ },
1265
+ {
1266
+ "epoch": 1.984835952577888,
1267
+ "grad_norm": 2.7458667755126953,
1268
+ "learning_rate": 9.533838973162195e-05,
1269
+ "loss": 3.596644973754883,
1270
+ "step": 1800
1271
+ },
1272
+ {
1273
+ "epoch": 1.9958643507030605,
1274
+ "grad_norm": 2.258280038833618,
1275
+ "learning_rate": 9.528004667444574e-05,
1276
+ "loss": 3.5881332397460937,
1277
+ "step": 1810
1278
+ },
1279
+ {
1280
+ "epoch": 2.0066170388751035,
1281
+ "grad_norm": 2.1228580474853516,
1282
+ "learning_rate": 9.522170361726955e-05,
1283
+ "loss": 3.5709766387939452,
1284
+ "step": 1820
1285
+ },
1286
+ {
1287
+ "epoch": 2.017645437000276,
1288
+ "grad_norm": 1.588876485824585,
1289
+ "learning_rate": 9.516336056009335e-05,
1290
+ "loss": 3.5627593994140625,
1291
+ "step": 1830
1292
+ },
1293
+ {
1294
+ "epoch": 2.028673835125448,
1295
+ "grad_norm": 2.451474189758301,
1296
+ "learning_rate": 9.510501750291716e-05,
1297
+ "loss": 3.5535301208496093,
1298
+ "step": 1840
1299
+ },
1300
+ {
1301
+ "epoch": 2.0397022332506203,
1302
+ "grad_norm": 2.0007503032684326,
1303
+ "learning_rate": 9.504667444574095e-05,
1304
+ "loss": 3.553875732421875,
1305
+ "step": 1850
1306
+ },
1307
+ {
1308
+ "epoch": 2.0507306313757927,
1309
+ "grad_norm": 1.4410080909729004,
1310
+ "learning_rate": 9.498833138856476e-05,
1311
+ "loss": 3.550189971923828,
1312
+ "step": 1860
1313
+ },
1314
+ {
1315
+ "epoch": 2.061759029500965,
1316
+ "grad_norm": 2.062835216522217,
1317
+ "learning_rate": 9.492998833138857e-05,
1318
+ "loss": 3.5456893920898436,
1319
+ "step": 1870
1320
+ },
1321
+ {
1322
+ "epoch": 2.072787427626137,
1323
+ "grad_norm": 2.4534783363342285,
1324
+ "learning_rate": 9.487164527421238e-05,
1325
+ "loss": 3.536829376220703,
1326
+ "step": 1880
1327
+ },
1328
+ {
1329
+ "epoch": 2.0838158257513095,
1330
+ "grad_norm": 2.2788970470428467,
1331
+ "learning_rate": 9.481330221703619e-05,
1332
+ "loss": 3.5525283813476562,
1333
+ "step": 1890
1334
+ },
1335
+ {
1336
+ "epoch": 2.094844223876482,
1337
+ "grad_norm": 1.4259227514266968,
1338
+ "learning_rate": 9.475495915985998e-05,
1339
+ "loss": 3.5479995727539064,
1340
+ "step": 1900
1341
+ },
1342
+ {
1343
+ "epoch": 2.1058726220016544,
1344
+ "grad_norm": 2.672534465789795,
1345
+ "learning_rate": 9.469661610268379e-05,
1346
+ "loss": 3.5359420776367188,
1347
+ "step": 1910
1348
+ },
1349
+ {
1350
+ "epoch": 2.116901020126827,
1351
+ "grad_norm": 2.0648045539855957,
1352
+ "learning_rate": 9.463827304550759e-05,
1353
+ "loss": 3.5452896118164063,
1354
+ "step": 1920
1355
+ },
1356
+ {
1357
+ "epoch": 2.1279294182519988,
1358
+ "grad_norm": 1.6846543550491333,
1359
+ "learning_rate": 9.45799299883314e-05,
1360
+ "loss": 3.5434345245361327,
1361
+ "step": 1930
1362
+ },
1363
+ {
1364
+ "epoch": 2.138957816377171,
1365
+ "grad_norm": 1.9105942249298096,
1366
+ "learning_rate": 9.452158693115519e-05,
1367
+ "loss": 3.5351535797119142,
1368
+ "step": 1940
1369
+ },
1370
+ {
1371
+ "epoch": 2.1499862145023436,
1372
+ "grad_norm": 1.8230890035629272,
1373
+ "learning_rate": 9.4463243873979e-05,
1374
+ "loss": 3.5190963745117188,
1375
+ "step": 1950
1376
+ },
1377
+ {
1378
+ "epoch": 2.161014612627516,
1379
+ "grad_norm": 1.6383274793624878,
1380
+ "learning_rate": 9.440490081680281e-05,
1381
+ "loss": 3.5228431701660154,
1382
+ "step": 1960
1383
+ },
1384
+ {
1385
+ "epoch": 2.172043010752688,
1386
+ "grad_norm": 1.7378439903259277,
1387
+ "learning_rate": 9.43465577596266e-05,
1388
+ "loss": 3.520981216430664,
1389
+ "step": 1970
1390
+ },
1391
+ {
1392
+ "epoch": 2.1830714088778604,
1393
+ "grad_norm": 1.941454529762268,
1394
+ "learning_rate": 9.428821470245041e-05,
1395
+ "loss": 3.519342803955078,
1396
+ "step": 1980
1397
+ },
1398
+ {
1399
+ "epoch": 2.194099807003033,
1400
+ "grad_norm": 1.8295516967773438,
1401
+ "learning_rate": 9.422987164527421e-05,
1402
+ "loss": 3.5412979125976562,
1403
+ "step": 1990
1404
+ },
1405
+ {
1406
+ "epoch": 2.2051282051282053,
1407
+ "grad_norm": 1.8052620887756348,
1408
+ "learning_rate": 9.417152858809802e-05,
1409
+ "loss": 3.5153289794921876,
1410
+ "step": 2000
1411
+ },
1412
+ {
1413
+ "epoch": 2.2161566032533773,
1414
+ "grad_norm": 2.1949570178985596,
1415
+ "learning_rate": 9.411318553092183e-05,
1416
+ "loss": 3.521608352661133,
1417
+ "step": 2010
1418
+ },
1419
+ {
1420
+ "epoch": 2.2271850013785497,
1421
+ "grad_norm": 1.746172308921814,
1422
+ "learning_rate": 9.405484247374564e-05,
1423
+ "loss": 3.5008296966552734,
1424
+ "step": 2020
1425
+ },
1426
+ {
1427
+ "epoch": 2.238213399503722,
1428
+ "grad_norm": 2.5374276638031006,
1429
+ "learning_rate": 9.399649941656943e-05,
1430
+ "loss": 3.5140228271484375,
1431
+ "step": 2030
1432
+ },
1433
+ {
1434
+ "epoch": 2.2492417976288945,
1435
+ "grad_norm": 1.7763218879699707,
1436
+ "learning_rate": 9.393815635939324e-05,
1437
+ "loss": 3.510652542114258,
1438
+ "step": 2040
1439
+ },
1440
+ {
1441
+ "epoch": 2.2602701957540665,
1442
+ "grad_norm": 1.6599587202072144,
1443
+ "learning_rate": 9.387981330221705e-05,
1444
+ "loss": 3.5122325897216795,
1445
+ "step": 2050
1446
+ },
1447
+ {
1448
+ "epoch": 2.271298593879239,
1449
+ "grad_norm": 2.1496078968048096,
1450
+ "learning_rate": 9.382147024504085e-05,
1451
+ "loss": 3.5139747619628907,
1452
+ "step": 2060
1453
+ },
1454
+ {
1455
+ "epoch": 2.2823269920044114,
1456
+ "grad_norm": 1.64266836643219,
1457
+ "learning_rate": 9.376312718786465e-05,
1458
+ "loss": 3.507743072509766,
1459
+ "step": 2070
1460
+ },
1461
+ {
1462
+ "epoch": 2.293355390129584,
1463
+ "grad_norm": 2.1241567134857178,
1464
+ "learning_rate": 9.370478413068845e-05,
1465
+ "loss": 3.5162708282470705,
1466
+ "step": 2080
1467
+ },
1468
+ {
1469
+ "epoch": 2.304383788254756,
1470
+ "grad_norm": 1.8391071557998657,
1471
+ "learning_rate": 9.364644107351226e-05,
1472
+ "loss": 3.4955375671386717,
1473
+ "step": 2090
1474
+ },
1475
+ {
1476
+ "epoch": 2.315412186379928,
1477
+ "grad_norm": 2.7478973865509033,
1478
+ "learning_rate": 9.358809801633605e-05,
1479
+ "loss": 3.497519302368164,
1480
+ "step": 2100
1481
+ },
1482
+ {
1483
+ "epoch": 2.3264405845051006,
1484
+ "grad_norm": 1.938588261604309,
1485
+ "learning_rate": 9.352975495915986e-05,
1486
+ "loss": 3.490141677856445,
1487
+ "step": 2110
1488
+ },
1489
+ {
1490
+ "epoch": 2.337468982630273,
1491
+ "grad_norm": 1.5637104511260986,
1492
+ "learning_rate": 9.347141190198366e-05,
1493
+ "loss": 3.499908447265625,
1494
+ "step": 2120
1495
+ },
1496
+ {
1497
+ "epoch": 2.3484973807554455,
1498
+ "grad_norm": 1.882504940032959,
1499
+ "learning_rate": 9.341306884480747e-05,
1500
+ "loss": 3.491979217529297,
1501
+ "step": 2130
1502
+ },
1503
+ {
1504
+ "epoch": 2.3595257788806174,
1505
+ "grad_norm": 1.8528521060943604,
1506
+ "learning_rate": 9.335472578763128e-05,
1507
+ "loss": 3.4961143493652345,
1508
+ "step": 2140
1509
+ },
1510
+ {
1511
+ "epoch": 2.37055417700579,
1512
+ "grad_norm": 1.8050177097320557,
1513
+ "learning_rate": 9.329638273045509e-05,
1514
+ "loss": 3.4948150634765627,
1515
+ "step": 2150
1516
+ },
1517
+ {
1518
+ "epoch": 2.3815825751309623,
1519
+ "grad_norm": 1.816784381866455,
1520
+ "learning_rate": 9.32380396732789e-05,
1521
+ "loss": 3.4910873413085937,
1522
+ "step": 2160
1523
+ },
1524
+ {
1525
+ "epoch": 2.3926109732561347,
1526
+ "grad_norm": 1.9779244661331177,
1527
+ "learning_rate": 9.317969661610269e-05,
1528
+ "loss": 3.492570495605469,
1529
+ "step": 2170
1530
+ },
1531
+ {
1532
+ "epoch": 2.4036393713813067,
1533
+ "grad_norm": 1.8939772844314575,
1534
+ "learning_rate": 9.31213535589265e-05,
1535
+ "loss": 3.473868560791016,
1536
+ "step": 2180
1537
+ },
1538
+ {
1539
+ "epoch": 2.414667769506479,
1540
+ "grad_norm": 2.1493656635284424,
1541
+ "learning_rate": 9.30630105017503e-05,
1542
+ "loss": 3.494515228271484,
1543
+ "step": 2190
1544
+ },
1545
+ {
1546
+ "epoch": 2.4256961676316515,
1547
+ "grad_norm": 1.8989397287368774,
1548
+ "learning_rate": 9.30046674445741e-05,
1549
+ "loss": 3.487537384033203,
1550
+ "step": 2200
1551
+ },
1552
+ {
1553
+ "epoch": 2.436724565756824,
1554
+ "grad_norm": 1.881856918334961,
1555
+ "learning_rate": 9.294632438739791e-05,
1556
+ "loss": 3.475904083251953,
1557
+ "step": 2210
1558
+ },
1559
+ {
1560
+ "epoch": 2.447752963881996,
1561
+ "grad_norm": 1.9463883638381958,
1562
+ "learning_rate": 9.288798133022171e-05,
1563
+ "loss": 3.4829254150390625,
1564
+ "step": 2220
1565
+ },
1566
+ {
1567
+ "epoch": 2.4587813620071683,
1568
+ "grad_norm": 2.01379656791687,
1569
+ "learning_rate": 9.282963827304552e-05,
1570
+ "loss": 3.472850036621094,
1571
+ "step": 2230
1572
+ },
1573
+ {
1574
+ "epoch": 2.4698097601323408,
1575
+ "grad_norm": 2.442741632461548,
1576
+ "learning_rate": 9.277129521586931e-05,
1577
+ "loss": 3.47030029296875,
1578
+ "step": 2240
1579
+ },
1580
+ {
1581
+ "epoch": 2.480838158257513,
1582
+ "grad_norm": 1.5051734447479248,
1583
+ "learning_rate": 9.271295215869312e-05,
1584
+ "loss": 3.489413833618164,
1585
+ "step": 2250
1586
+ },
1587
+ {
1588
+ "epoch": 2.4918665563826856,
1589
+ "grad_norm": 1.9489309787750244,
1590
+ "learning_rate": 9.265460910151692e-05,
1591
+ "loss": 3.464769744873047,
1592
+ "step": 2260
1593
+ },
1594
+ {
1595
+ "epoch": 2.5028949545078576,
1596
+ "grad_norm": 2.319654941558838,
1597
+ "learning_rate": 9.259626604434072e-05,
1598
+ "loss": 3.469140625,
1599
+ "step": 2270
1600
+ },
1601
+ {
1602
+ "epoch": 2.51392335263303,
1603
+ "grad_norm": 1.7984129190444946,
1604
+ "learning_rate": 9.253792298716453e-05,
1605
+ "loss": 3.466594696044922,
1606
+ "step": 2280
1607
+ },
1608
+ {
1609
+ "epoch": 2.5249517507582024,
1610
+ "grad_norm": 1.640869379043579,
1611
+ "learning_rate": 9.247957992998833e-05,
1612
+ "loss": 3.463022994995117,
1613
+ "step": 2290
1614
+ },
1615
+ {
1616
+ "epoch": 2.5359801488833744,
1617
+ "grad_norm": 1.6698195934295654,
1618
+ "learning_rate": 9.242123687281214e-05,
1619
+ "loss": 3.4695220947265626,
1620
+ "step": 2300
1621
+ },
1622
+ {
1623
+ "epoch": 2.547008547008547,
1624
+ "grad_norm": 2.2945683002471924,
1625
+ "learning_rate": 9.236289381563595e-05,
1626
+ "loss": 3.469150924682617,
1627
+ "step": 2310
1628
+ },
1629
+ {
1630
+ "epoch": 2.5580369451337193,
1631
+ "grad_norm": 1.7678370475769043,
1632
+ "learning_rate": 9.230455075845976e-05,
1633
+ "loss": 3.470307159423828,
1634
+ "step": 2320
1635
+ },
1636
+ {
1637
+ "epoch": 2.5690653432588917,
1638
+ "grad_norm": 1.8386255502700806,
1639
+ "learning_rate": 9.224620770128355e-05,
1640
+ "loss": 3.4638832092285154,
1641
+ "step": 2330
1642
+ },
1643
+ {
1644
+ "epoch": 2.580093741384064,
1645
+ "grad_norm": 2.0348527431488037,
1646
+ "learning_rate": 9.218786464410736e-05,
1647
+ "loss": 3.460480880737305,
1648
+ "step": 2340
1649
+ },
1650
+ {
1651
+ "epoch": 2.5911221395092365,
1652
+ "grad_norm": 1.845974326133728,
1653
+ "learning_rate": 9.212952158693116e-05,
1654
+ "loss": 3.4529083251953123,
1655
+ "step": 2350
1656
+ },
1657
+ {
1658
+ "epoch": 2.6021505376344085,
1659
+ "grad_norm": 2.0843095779418945,
1660
+ "learning_rate": 9.207117852975496e-05,
1661
+ "loss": 3.4576786041259764,
1662
+ "step": 2360
1663
+ },
1664
+ {
1665
+ "epoch": 2.613178935759581,
1666
+ "grad_norm": 1.7627031803131104,
1667
+ "learning_rate": 9.201283547257876e-05,
1668
+ "loss": 3.4450752258300783,
1669
+ "step": 2370
1670
+ },
1671
+ {
1672
+ "epoch": 2.6242073338847534,
1673
+ "grad_norm": 1.371972918510437,
1674
+ "learning_rate": 9.195449241540257e-05,
1675
+ "loss": 3.464734649658203,
1676
+ "step": 2380
1677
+ },
1678
+ {
1679
+ "epoch": 2.6352357320099253,
1680
+ "grad_norm": 1.6781940460205078,
1681
+ "learning_rate": 9.189614935822638e-05,
1682
+ "loss": 3.444991683959961,
1683
+ "step": 2390
1684
+ },
1685
+ {
1686
+ "epoch": 2.6462641301350978,
1687
+ "grad_norm": 1.8782585859298706,
1688
+ "learning_rate": 9.183780630105017e-05,
1689
+ "loss": 3.4558509826660155,
1690
+ "step": 2400
1691
+ },
1692
+ {
1693
+ "epoch": 2.65729252826027,
1694
+ "grad_norm": 1.942812204360962,
1695
+ "learning_rate": 9.177946324387398e-05,
1696
+ "loss": 3.4555503845214846,
1697
+ "step": 2410
1698
+ },
1699
+ {
1700
+ "epoch": 2.6683209263854426,
1701
+ "grad_norm": 1.404680609703064,
1702
+ "learning_rate": 9.172112018669778e-05,
1703
+ "loss": 3.438182830810547,
1704
+ "step": 2420
1705
+ },
1706
+ {
1707
+ "epoch": 2.679349324510615,
1708
+ "grad_norm": 1.7656677961349487,
1709
+ "learning_rate": 9.166277712952159e-05,
1710
+ "loss": 3.4622947692871096,
1711
+ "step": 2430
1712
+ },
1713
+ {
1714
+ "epoch": 2.690377722635787,
1715
+ "grad_norm": 1.8348901271820068,
1716
+ "learning_rate": 9.16044340723454e-05,
1717
+ "loss": 3.438182830810547,
1718
+ "step": 2440
1719
+ },
1720
+ {
1721
+ "epoch": 2.7014061207609594,
1722
+ "grad_norm": 2.0641167163848877,
1723
+ "learning_rate": 9.15460910151692e-05,
1724
+ "loss": 3.441473388671875,
1725
+ "step": 2450
1726
+ },
1727
+ {
1728
+ "epoch": 2.712434518886132,
1729
+ "grad_norm": 1.726035475730896,
1730
+ "learning_rate": 9.148774795799301e-05,
1731
+ "loss": 3.441991424560547,
1732
+ "step": 2460
1733
+ },
1734
+ {
1735
+ "epoch": 2.7234629170113043,
1736
+ "grad_norm": 1.854658603668213,
1737
+ "learning_rate": 9.142940490081681e-05,
1738
+ "loss": 3.4441551208496093,
1739
+ "step": 2470
1740
+ },
1741
+ {
1742
+ "epoch": 2.7344913151364763,
1743
+ "grad_norm": 1.8229296207427979,
1744
+ "learning_rate": 9.137106184364062e-05,
1745
+ "loss": 3.441034698486328,
1746
+ "step": 2480
1747
+ },
1748
+ {
1749
+ "epoch": 2.7455197132616487,
1750
+ "grad_norm": 1.6627975702285767,
1751
+ "learning_rate": 9.131271878646441e-05,
1752
+ "loss": 3.4399124145507813,
1753
+ "step": 2490
1754
+ },
1755
+ {
1756
+ "epoch": 2.756548111386821,
1757
+ "grad_norm": 1.4111251831054688,
1758
+ "learning_rate": 9.125437572928822e-05,
1759
+ "loss": 3.4374462127685548,
1760
+ "step": 2500
1761
+ },
1762
+ {
1763
+ "epoch": 2.7675765095119935,
1764
+ "grad_norm": 2.015869379043579,
1765
+ "learning_rate": 9.119603267211202e-05,
1766
+ "loss": 3.4262016296386717,
1767
+ "step": 2510
1768
+ },
1769
+ {
1770
+ "epoch": 2.778604907637166,
1771
+ "grad_norm": 2.2818591594696045,
1772
+ "learning_rate": 9.113768961493583e-05,
1773
+ "loss": 3.446285629272461,
1774
+ "step": 2520
1775
+ },
1776
+ {
1777
+ "epoch": 2.789633305762338,
1778
+ "grad_norm": 1.8643262386322021,
1779
+ "learning_rate": 9.107934655775962e-05,
1780
+ "loss": 3.4362293243408204,
1781
+ "step": 2530
1782
+ },
1783
+ {
1784
+ "epoch": 2.8006617038875103,
1785
+ "grad_norm": 1.248988151550293,
1786
+ "learning_rate": 9.102100350058343e-05,
1787
+ "loss": 3.441702651977539,
1788
+ "step": 2540
1789
+ },
1790
+ {
1791
+ "epoch": 2.8116901020126828,
1792
+ "grad_norm": 1.5247464179992676,
1793
+ "learning_rate": 9.096266044340724e-05,
1794
+ "loss": 3.4388256072998047,
1795
+ "step": 2550
1796
+ },
1797
+ {
1798
+ "epoch": 2.8227185001378547,
1799
+ "grad_norm": 1.9120620489120483,
1800
+ "learning_rate": 9.090431738623103e-05,
1801
+ "loss": 3.4206756591796874,
1802
+ "step": 2560
1803
+ },
1804
+ {
1805
+ "epoch": 2.833746898263027,
1806
+ "grad_norm": 1.4591054916381836,
1807
+ "learning_rate": 9.084597432905484e-05,
1808
+ "loss": 3.4229709625244142,
1809
+ "step": 2570
1810
+ },
1811
+ {
1812
+ "epoch": 2.8447752963881996,
1813
+ "grad_norm": 2.24849796295166,
1814
+ "learning_rate": 9.078763127187865e-05,
1815
+ "loss": 3.426911163330078,
1816
+ "step": 2580
1817
+ },
1818
+ {
1819
+ "epoch": 2.855803694513372,
1820
+ "grad_norm": 1.5658804178237915,
1821
+ "learning_rate": 9.072928821470246e-05,
1822
+ "loss": 3.445120620727539,
1823
+ "step": 2590
1824
+ },
1825
+ {
1826
+ "epoch": 2.8668320926385444,
1827
+ "grad_norm": 1.483583688735962,
1828
+ "learning_rate": 9.067094515752626e-05,
1829
+ "loss": 3.430312728881836,
1830
+ "step": 2600
1831
+ },
1832
+ {
1833
+ "epoch": 2.8778604907637164,
1834
+ "grad_norm": 1.5759658813476562,
1835
+ "learning_rate": 9.061260210035007e-05,
1836
+ "loss": 3.4178386688232423,
1837
+ "step": 2610
1838
+ },
1839
+ {
1840
+ "epoch": 2.888888888888889,
1841
+ "grad_norm": 1.9259848594665527,
1842
+ "learning_rate": 9.055425904317386e-05,
1843
+ "loss": 3.430949401855469,
1844
+ "step": 2620
1845
+ },
1846
+ {
1847
+ "epoch": 2.8999172870140613,
1848
+ "grad_norm": 1.470717191696167,
1849
+ "learning_rate": 9.049591598599767e-05,
1850
+ "loss": 3.439757537841797,
1851
+ "step": 2630
1852
+ },
1853
+ {
1854
+ "epoch": 2.9109456851392337,
1855
+ "grad_norm": 1.8934212923049927,
1856
+ "learning_rate": 9.043757292882148e-05,
1857
+ "loss": 3.430719757080078,
1858
+ "step": 2640
1859
+ },
1860
+ {
1861
+ "epoch": 2.9219740832644057,
1862
+ "grad_norm": 1.6267489194869995,
1863
+ "learning_rate": 9.037922987164527e-05,
1864
+ "loss": 3.4224998474121096,
1865
+ "step": 2650
1866
+ },
1867
+ {
1868
+ "epoch": 2.933002481389578,
1869
+ "grad_norm": 1.6213353872299194,
1870
+ "learning_rate": 9.032088681446908e-05,
1871
+ "loss": 3.4213233947753907,
1872
+ "step": 2660
1873
+ },
1874
+ {
1875
+ "epoch": 2.9440308795147505,
1876
+ "grad_norm": 1.961879849433899,
1877
+ "learning_rate": 9.026254375729288e-05,
1878
+ "loss": 3.4108352661132812,
1879
+ "step": 2670
1880
+ },
1881
+ {
1882
+ "epoch": 2.955059277639923,
1883
+ "grad_norm": 1.7363910675048828,
1884
+ "learning_rate": 9.020420070011669e-05,
1885
+ "loss": 3.423554229736328,
1886
+ "step": 2680
1887
+ },
1888
+ {
1889
+ "epoch": 2.9660876757650954,
1890
+ "grad_norm": 1.6161952018737793,
1891
+ "learning_rate": 9.014585764294048e-05,
1892
+ "loss": 3.418962860107422,
1893
+ "step": 2690
1894
+ },
1895
+ {
1896
+ "epoch": 2.9771160738902673,
1897
+ "grad_norm": 1.8065682649612427,
1898
+ "learning_rate": 9.008751458576429e-05,
1899
+ "loss": 3.4218765258789063,
1900
+ "step": 2700
1901
+ },
1902
+ {
1903
+ "epoch": 2.9881444720154398,
1904
+ "grad_norm": 1.4285337924957275,
1905
+ "learning_rate": 9.00291715285881e-05,
1906
+ "loss": 3.413957214355469,
1907
+ "step": 2710
1908
+ },
1909
+ {
1910
+ "epoch": 2.999172870140612,
1911
+ "grad_norm": 1.30274498462677,
1912
+ "learning_rate": 8.997082847141191e-05,
1913
+ "loss": 3.4176124572753905,
1914
+ "step": 2720
1915
+ },
1916
+ {
1917
+ "epoch": 3.009925558312655,
1918
+ "grad_norm": 1.5460416078567505,
1919
+ "learning_rate": 8.991248541423572e-05,
1920
+ "loss": 3.388013458251953,
1921
+ "step": 2730
1922
+ },
1923
+ {
1924
+ "epoch": 3.0209539564378276,
1925
+ "grad_norm": 1.5832446813583374,
1926
+ "learning_rate": 8.985414235705951e-05,
1927
+ "loss": 3.3929378509521486,
1928
+ "step": 2740
1929
+ },
1930
+ {
1931
+ "epoch": 3.0319823545629996,
1932
+ "grad_norm": 1.6086630821228027,
1933
+ "learning_rate": 8.979579929988332e-05,
1934
+ "loss": 3.3940502166748048,
1935
+ "step": 2750
1936
+ },
1937
+ {
1938
+ "epoch": 3.043010752688172,
1939
+ "grad_norm": 1.6624842882156372,
1940
+ "learning_rate": 8.973745624270712e-05,
1941
+ "loss": 3.388884353637695,
1942
+ "step": 2760
1943
+ },
1944
+ {
1945
+ "epoch": 3.0540391508133444,
1946
+ "grad_norm": 1.7352933883666992,
1947
+ "learning_rate": 8.967911318553093e-05,
1948
+ "loss": 3.409127426147461,
1949
+ "step": 2770
1950
+ },
1951
+ {
1952
+ "epoch": 3.065067548938517,
1953
+ "grad_norm": 1.45657217502594,
1954
+ "learning_rate": 8.962077012835472e-05,
1955
+ "loss": 3.389351654052734,
1956
+ "step": 2780
1957
+ },
1958
+ {
1959
+ "epoch": 3.076095947063689,
1960
+ "grad_norm": 1.4969090223312378,
1961
+ "learning_rate": 8.956242707117853e-05,
1962
+ "loss": 3.3988433837890626,
1963
+ "step": 2790
1964
+ },
1965
+ {
1966
+ "epoch": 3.0871243451888613,
1967
+ "grad_norm": 1.710800051689148,
1968
+ "learning_rate": 8.950408401400234e-05,
1969
+ "loss": 3.395826721191406,
1970
+ "step": 2800
1971
+ },
1972
+ {
1973
+ "epoch": 3.0981527433140337,
1974
+ "grad_norm": 1.6347870826721191,
1975
+ "learning_rate": 8.944574095682614e-05,
1976
+ "loss": 3.391011047363281,
1977
+ "step": 2810
1978
+ },
1979
+ {
1980
+ "epoch": 3.109181141439206,
1981
+ "grad_norm": 1.4630122184753418,
1982
+ "learning_rate": 8.938739789964995e-05,
1983
+ "loss": 3.401841735839844,
1984
+ "step": 2820
1985
+ },
1986
+ {
1987
+ "epoch": 3.120209539564378,
1988
+ "grad_norm": 1.547430157661438,
1989
+ "learning_rate": 8.932905484247374e-05,
1990
+ "loss": 3.3979782104492187,
1991
+ "step": 2830
1992
+ },
1993
+ {
1994
+ "epoch": 3.1312379376895505,
1995
+ "grad_norm": 1.5614186525344849,
1996
+ "learning_rate": 8.927071178529755e-05,
1997
+ "loss": 3.3884544372558594,
1998
+ "step": 2840
1999
+ },
2000
+ {
2001
+ "epoch": 3.142266335814723,
2002
+ "grad_norm": 1.4073251485824585,
2003
+ "learning_rate": 8.921236872812136e-05,
2004
+ "loss": 3.3886154174804686,
2005
+ "step": 2850
2006
+ },
2007
+ {
2008
+ "epoch": 3.1532947339398953,
2009
+ "grad_norm": 1.3639475107192993,
2010
+ "learning_rate": 8.915402567094517e-05,
2011
+ "loss": 3.383074951171875,
2012
+ "step": 2860
2013
+ },
2014
+ {
2015
+ "epoch": 3.1643231320650678,
2016
+ "grad_norm": 2.3929882049560547,
2017
+ "learning_rate": 8.909568261376896e-05,
2018
+ "loss": 3.3788246154785155,
2019
+ "step": 2870
2020
+ },
2021
+ {
2022
+ "epoch": 3.1753515301902397,
2023
+ "grad_norm": 1.7196829319000244,
2024
+ "learning_rate": 8.903733955659277e-05,
2025
+ "loss": 3.3822708129882812,
2026
+ "step": 2880
2027
+ },
2028
+ {
2029
+ "epoch": 3.186379928315412,
2030
+ "grad_norm": 1.526293396949768,
2031
+ "learning_rate": 8.897899649941658e-05,
2032
+ "loss": 3.381543731689453,
2033
+ "step": 2890
2034
+ },
2035
+ {
2036
+ "epoch": 3.1974083264405846,
2037
+ "grad_norm": 1.2336128950119019,
2038
+ "learning_rate": 8.892065344224038e-05,
2039
+ "loss": 3.3975807189941407,
2040
+ "step": 2900
2041
+ },
2042
+ {
2043
+ "epoch": 3.208436724565757,
2044
+ "grad_norm": 1.4868130683898926,
2045
+ "learning_rate": 8.886231038506419e-05,
2046
+ "loss": 3.3970687866210936,
2047
+ "step": 2910
2048
+ },
2049
+ {
2050
+ "epoch": 3.219465122690929,
2051
+ "grad_norm": 1.5349540710449219,
2052
+ "learning_rate": 8.880396732788798e-05,
2053
+ "loss": 3.385994720458984,
2054
+ "step": 2920
2055
+ },
2056
+ {
2057
+ "epoch": 3.2304935208161014,
2058
+ "grad_norm": 1.5333718061447144,
2059
+ "learning_rate": 8.874562427071179e-05,
2060
+ "loss": 3.362841796875,
2061
+ "step": 2930
2062
+ },
2063
+ {
2064
+ "epoch": 3.241521918941274,
2065
+ "grad_norm": 1.514235258102417,
2066
+ "learning_rate": 8.868728121353558e-05,
2067
+ "loss": 3.3816680908203125,
2068
+ "step": 2940
2069
+ },
2070
+ {
2071
+ "epoch": 3.2525503170664463,
2072
+ "grad_norm": 1.5870161056518555,
2073
+ "learning_rate": 8.86289381563594e-05,
2074
+ "loss": 3.3818199157714846,
2075
+ "step": 2950
2076
+ },
2077
+ {
2078
+ "epoch": 3.2635787151916182,
2079
+ "grad_norm": 1.6295320987701416,
2080
+ "learning_rate": 8.85705950991832e-05,
2081
+ "loss": 3.379594421386719,
2082
+ "step": 2960
2083
+ },
2084
+ {
2085
+ "epoch": 3.2746071133167907,
2086
+ "grad_norm": 1.533991813659668,
2087
+ "learning_rate": 8.8512252042007e-05,
2088
+ "loss": 3.387801742553711,
2089
+ "step": 2970
2090
+ },
2091
+ {
2092
+ "epoch": 3.285635511441963,
2093
+ "grad_norm": 2.2125084400177,
2094
+ "learning_rate": 8.845390898483081e-05,
2095
+ "loss": 3.3856468200683594,
2096
+ "step": 2980
2097
+ },
2098
+ {
2099
+ "epoch": 3.2966639095671355,
2100
+ "grad_norm": 1.800207495689392,
2101
+ "learning_rate": 8.839556592765462e-05,
2102
+ "loss": 3.3843597412109374,
2103
+ "step": 2990
2104
+ },
2105
+ {
2106
+ "epoch": 3.3076923076923075,
2107
+ "grad_norm": 1.3071027994155884,
2108
+ "learning_rate": 8.833722287047842e-05,
2109
+ "loss": 3.3861888885498046,
2110
+ "step": 3000
2111
+ },
2112
+ {
2113
+ "epoch": 3.31872070581748,
2114
+ "grad_norm": 1.7724641561508179,
2115
+ "learning_rate": 8.827887981330222e-05,
2116
+ "loss": 3.3929458618164063,
2117
+ "step": 3010
2118
+ },
2119
+ {
2120
+ "epoch": 3.3297491039426523,
2121
+ "grad_norm": 1.3397877216339111,
2122
+ "learning_rate": 8.822053675612603e-05,
2123
+ "loss": 3.3785301208496095,
2124
+ "step": 3020
2125
+ },
2126
+ {
2127
+ "epoch": 3.3407775020678248,
2128
+ "grad_norm": 1.352630376815796,
2129
+ "learning_rate": 8.816219369894982e-05,
2130
+ "loss": 3.3796306610107423,
2131
+ "step": 3030
2132
+ },
2133
+ {
2134
+ "epoch": 3.351805900192997,
2135
+ "grad_norm": 1.5996475219726562,
2136
+ "learning_rate": 8.810385064177363e-05,
2137
+ "loss": 3.362406921386719,
2138
+ "step": 3040
2139
+ },
2140
+ {
2141
+ "epoch": 3.362834298318169,
2142
+ "grad_norm": 1.6010814905166626,
2143
+ "learning_rate": 8.804550758459744e-05,
2144
+ "loss": 3.3811767578125,
2145
+ "step": 3050
2146
+ },
2147
+ {
2148
+ "epoch": 3.3738626964433416,
2149
+ "grad_norm": 1.3276373147964478,
2150
+ "learning_rate": 8.798716452742124e-05,
2151
+ "loss": 3.3732643127441406,
2152
+ "step": 3060
2153
+ },
2154
+ {
2155
+ "epoch": 3.384891094568514,
2156
+ "grad_norm": 1.7741515636444092,
2157
+ "learning_rate": 8.792882147024505e-05,
2158
+ "loss": 3.381968688964844,
2159
+ "step": 3070
2160
+ },
2161
+ {
2162
+ "epoch": 3.3959194926936864,
2163
+ "grad_norm": 1.7820576429367065,
2164
+ "learning_rate": 8.787047841306884e-05,
2165
+ "loss": 3.358811950683594,
2166
+ "step": 3080
2167
+ },
2168
+ {
2169
+ "epoch": 3.4069478908188584,
2170
+ "grad_norm": 1.389573574066162,
2171
+ "learning_rate": 8.781213535589265e-05,
2172
+ "loss": 3.36102180480957,
2173
+ "step": 3090
2174
+ },
2175
+ {
2176
+ "epoch": 3.417976288944031,
2177
+ "grad_norm": 1.1910648345947266,
2178
+ "learning_rate": 8.775379229871645e-05,
2179
+ "loss": 3.3652645111083985,
2180
+ "step": 3100
2181
+ },
2182
+ {
2183
+ "epoch": 3.4290046870692032,
2184
+ "grad_norm": 1.965219497680664,
2185
+ "learning_rate": 8.769544924154026e-05,
2186
+ "loss": 3.3735313415527344,
2187
+ "step": 3110
2188
+ },
2189
+ {
2190
+ "epoch": 3.4400330851943757,
2191
+ "grad_norm": 1.5992330312728882,
2192
+ "learning_rate": 8.763710618436406e-05,
2193
+ "loss": 3.362974166870117,
2194
+ "step": 3120
2195
+ },
2196
+ {
2197
+ "epoch": 3.4510614833195477,
2198
+ "grad_norm": 2.2293193340301514,
2199
+ "learning_rate": 8.757876312718787e-05,
2200
+ "loss": 3.3681709289550783,
2201
+ "step": 3130
2202
+ },
2203
+ {
2204
+ "epoch": 3.46208988144472,
2205
+ "grad_norm": 1.2978801727294922,
2206
+ "learning_rate": 8.752042007001168e-05,
2207
+ "loss": 3.3776336669921876,
2208
+ "step": 3140
2209
+ },
2210
+ {
2211
+ "epoch": 3.4731182795698925,
2212
+ "grad_norm": 1.227036714553833,
2213
+ "learning_rate": 8.746207701283548e-05,
2214
+ "loss": 3.3590301513671874,
2215
+ "step": 3150
2216
+ },
2217
+ {
2218
+ "epoch": 3.484146677695065,
2219
+ "grad_norm": 1.8023360967636108,
2220
+ "learning_rate": 8.740373395565929e-05,
2221
+ "loss": 3.35421142578125,
2222
+ "step": 3160
2223
+ },
2224
+ {
2225
+ "epoch": 3.495175075820237,
2226
+ "grad_norm": 1.6423453092575073,
2227
+ "learning_rate": 8.734539089848308e-05,
2228
+ "loss": 3.3748985290527345,
2229
+ "step": 3170
2230
+ },
2231
+ {
2232
+ "epoch": 3.5062034739454093,
2233
+ "grad_norm": 1.3261916637420654,
2234
+ "learning_rate": 8.728704784130689e-05,
2235
+ "loss": 3.36380615234375,
2236
+ "step": 3180
2237
+ },
2238
+ {
2239
+ "epoch": 3.5172318720705817,
2240
+ "grad_norm": 1.290014624595642,
2241
+ "learning_rate": 8.722870478413069e-05,
2242
+ "loss": 3.3596282958984376,
2243
+ "step": 3190
2244
+ },
2245
+ {
2246
+ "epoch": 3.528260270195754,
2247
+ "grad_norm": 2.0481576919555664,
2248
+ "learning_rate": 8.71703617269545e-05,
2249
+ "loss": 3.358118438720703,
2250
+ "step": 3200
2251
+ },
2252
+ {
2253
+ "epoch": 3.5392886683209266,
2254
+ "grad_norm": 1.4758331775665283,
2255
+ "learning_rate": 8.71120186697783e-05,
2256
+ "loss": 3.3536834716796875,
2257
+ "step": 3210
2258
+ },
2259
+ {
2260
+ "epoch": 3.5503170664460986,
2261
+ "grad_norm": 1.4340440034866333,
2262
+ "learning_rate": 8.70536756126021e-05,
2263
+ "loss": 3.358259582519531,
2264
+ "step": 3220
2265
+ },
2266
+ {
2267
+ "epoch": 3.561345464571271,
2268
+ "grad_norm": 1.6952699422836304,
2269
+ "learning_rate": 8.699533255542591e-05,
2270
+ "loss": 3.3730777740478515,
2271
+ "step": 3230
2272
+ },
2273
+ {
2274
+ "epoch": 3.5723738626964434,
2275
+ "grad_norm": 1.9069234132766724,
2276
+ "learning_rate": 8.69369894982497e-05,
2277
+ "loss": 3.3552001953125,
2278
+ "step": 3240
2279
+ },
2280
+ {
2281
+ "epoch": 3.5834022608216154,
2282
+ "grad_norm": 1.6194590330123901,
2283
+ "learning_rate": 8.687864644107351e-05,
2284
+ "loss": 3.3562744140625,
2285
+ "step": 3250
2286
+ },
2287
+ {
2288
+ "epoch": 3.594430658946788,
2289
+ "grad_norm": 1.33975350856781,
2290
+ "learning_rate": 8.682030338389732e-05,
2291
+ "loss": 3.3622581481933596,
2292
+ "step": 3260
2293
+ },
2294
+ {
2295
+ "epoch": 3.6054590570719602,
2296
+ "grad_norm": 1.3948160409927368,
2297
+ "learning_rate": 8.676196032672113e-05,
2298
+ "loss": 3.3645614624023437,
2299
+ "step": 3270
2300
+ },
2301
+ {
2302
+ "epoch": 3.6164874551971327,
2303
+ "grad_norm": 1.4972363710403442,
2304
+ "learning_rate": 8.670361726954493e-05,
2305
+ "loss": 3.3713829040527346,
2306
+ "step": 3280
2307
+ },
2308
+ {
2309
+ "epoch": 3.627515853322305,
2310
+ "grad_norm": 1.9456968307495117,
2311
+ "learning_rate": 8.664527421236874e-05,
2312
+ "loss": 3.3617935180664062,
2313
+ "step": 3290
2314
+ },
2315
+ {
2316
+ "epoch": 3.6385442514474775,
2317
+ "grad_norm": 1.8050702810287476,
2318
+ "learning_rate": 8.658693115519254e-05,
2319
+ "loss": 3.359496307373047,
2320
+ "step": 3300
2321
+ },
2322
+ {
2323
+ "epoch": 3.6495726495726495,
2324
+ "grad_norm": 1.294492244720459,
2325
+ "learning_rate": 8.652858809801634e-05,
2326
+ "loss": 3.361173629760742,
2327
+ "step": 3310
2328
+ },
2329
+ {
2330
+ "epoch": 3.660601047697822,
2331
+ "grad_norm": 1.7897614240646362,
2332
+ "learning_rate": 8.647024504084015e-05,
2333
+ "loss": 3.3475852966308595,
2334
+ "step": 3320
2335
+ },
2336
+ {
2337
+ "epoch": 3.6716294458229943,
2338
+ "grad_norm": 1.5647767782211304,
2339
+ "learning_rate": 8.641190198366394e-05,
2340
+ "loss": 3.3594207763671875,
2341
+ "step": 3330
2342
+ },
2343
+ {
2344
+ "epoch": 3.6826578439481663,
2345
+ "grad_norm": 1.3839472532272339,
2346
+ "learning_rate": 8.635355892648775e-05,
2347
+ "loss": 3.361709976196289,
2348
+ "step": 3340
2349
+ },
2350
+ {
2351
+ "epoch": 3.6936862420733387,
2352
+ "grad_norm": 1.543115258216858,
2353
+ "learning_rate": 8.629521586931155e-05,
2354
+ "loss": 3.349272918701172,
2355
+ "step": 3350
2356
+ },
2357
+ {
2358
+ "epoch": 3.704714640198511,
2359
+ "grad_norm": 1.2722103595733643,
2360
+ "learning_rate": 8.623687281213536e-05,
2361
+ "loss": 3.3600040435791017,
2362
+ "step": 3360
2363
+ },
2364
+ {
2365
+ "epoch": 3.7157430383236836,
2366
+ "grad_norm": 2.396493434906006,
2367
+ "learning_rate": 8.617852975495915e-05,
2368
+ "loss": 3.359762954711914,
2369
+ "step": 3370
2370
+ },
2371
+ {
2372
+ "epoch": 3.726771436448856,
2373
+ "grad_norm": 1.3756037950515747,
2374
+ "learning_rate": 8.612018669778296e-05,
2375
+ "loss": 3.3409027099609374,
2376
+ "step": 3380
2377
+ },
2378
+ {
2379
+ "epoch": 3.737799834574028,
2380
+ "grad_norm": 1.5124824047088623,
2381
+ "learning_rate": 8.606184364060677e-05,
2382
+ "loss": 3.346342849731445,
2383
+ "step": 3390
2384
+ },
2385
+ {
2386
+ "epoch": 3.7488282326992004,
2387
+ "grad_norm": 1.3679585456848145,
2388
+ "learning_rate": 8.600350058343058e-05,
2389
+ "loss": 3.3478328704833986,
2390
+ "step": 3400
2391
+ },
2392
+ {
2393
+ "epoch": 3.759856630824373,
2394
+ "grad_norm": 1.3470197916030884,
2395
+ "learning_rate": 8.594515752625439e-05,
2396
+ "loss": 3.352674865722656,
2397
+ "step": 3410
2398
+ },
2399
+ {
2400
+ "epoch": 3.770885028949545,
2401
+ "grad_norm": 1.4775781631469727,
2402
+ "learning_rate": 8.588681446907818e-05,
2403
+ "loss": 3.3504791259765625,
2404
+ "step": 3420
2405
+ },
2406
+ {
2407
+ "epoch": 3.7819134270747172,
2408
+ "grad_norm": 1.1987943649291992,
2409
+ "learning_rate": 8.582847141190199e-05,
2410
+ "loss": 3.3457687377929686,
2411
+ "step": 3430
2412
+ },
2413
+ {
2414
+ "epoch": 3.7929418251998896,
2415
+ "grad_norm": 1.8007314205169678,
2416
+ "learning_rate": 8.577012835472579e-05,
2417
+ "loss": 3.3557716369628907,
2418
+ "step": 3440
2419
+ },
2420
+ {
2421
+ "epoch": 3.803970223325062,
2422
+ "grad_norm": 1.4193800687789917,
2423
+ "learning_rate": 8.57117852975496e-05,
2424
+ "loss": 3.346666717529297,
2425
+ "step": 3450
2426
+ },
2427
+ {
2428
+ "epoch": 3.8149986214502345,
2429
+ "grad_norm": 1.600216031074524,
2430
+ "learning_rate": 8.56534422403734e-05,
2431
+ "loss": 3.354322814941406,
2432
+ "step": 3460
2433
+ },
2434
+ {
2435
+ "epoch": 3.826027019575407,
2436
+ "grad_norm": 1.6823015213012695,
2437
+ "learning_rate": 8.55950991831972e-05,
2438
+ "loss": 3.3344764709472656,
2439
+ "step": 3470
2440
+ },
2441
+ {
2442
+ "epoch": 3.837055417700579,
2443
+ "grad_norm": 1.8002822399139404,
2444
+ "learning_rate": 8.553675612602101e-05,
2445
+ "loss": 3.338224411010742,
2446
+ "step": 3480
2447
+ },
2448
+ {
2449
+ "epoch": 3.8480838158257513,
2450
+ "grad_norm": 1.019519567489624,
2451
+ "learning_rate": 8.54784130688448e-05,
2452
+ "loss": 3.342393493652344,
2453
+ "step": 3490
2454
+ },
2455
+ {
2456
+ "epoch": 3.8591122139509237,
2457
+ "grad_norm": 1.4397176504135132,
2458
+ "learning_rate": 8.542007001166861e-05,
2459
+ "loss": 3.3416332244873046,
2460
+ "step": 3500
2461
+ },
2462
+ {
2463
+ "epoch": 3.8701406120760957,
2464
+ "grad_norm": 1.398215889930725,
2465
+ "learning_rate": 8.536172695449241e-05,
2466
+ "loss": 3.3455711364746095,
2467
+ "step": 3510
2468
+ },
2469
+ {
2470
+ "epoch": 3.881169010201268,
2471
+ "grad_norm": 1.431221604347229,
2472
+ "learning_rate": 8.530338389731622e-05,
2473
+ "loss": 3.3510116577148437,
2474
+ "step": 3520
2475
+ },
2476
+ {
2477
+ "epoch": 3.8921974083264406,
2478
+ "grad_norm": 1.2339868545532227,
2479
+ "learning_rate": 8.524504084014003e-05,
2480
+ "loss": 3.333365631103516,
2481
+ "step": 3530
2482
+ },
2483
+ {
2484
+ "epoch": 3.903225806451613,
2485
+ "grad_norm": 1.2564575672149658,
2486
+ "learning_rate": 8.518669778296384e-05,
2487
+ "loss": 3.355131912231445,
2488
+ "step": 3540
2489
+ },
2490
+ {
2491
+ "epoch": 3.9142542045767854,
2492
+ "grad_norm": 1.44709050655365,
2493
+ "learning_rate": 8.512835472578765e-05,
2494
+ "loss": 3.352345275878906,
2495
+ "step": 3550
2496
+ },
2497
+ {
2498
+ "epoch": 3.9252826027019574,
2499
+ "grad_norm": 1.0984286069869995,
2500
+ "learning_rate": 8.507001166861144e-05,
2501
+ "loss": 3.3399391174316406,
2502
+ "step": 3560
2503
+ },
2504
+ {
2505
+ "epoch": 3.93631100082713,
2506
+ "grad_norm": 1.521567702293396,
2507
+ "learning_rate": 8.501166861143525e-05,
2508
+ "loss": 3.3333946228027345,
2509
+ "step": 3570
2510
+ },
2511
+ {
2512
+ "epoch": 3.9473393989523022,
2513
+ "grad_norm": 1.3443926572799683,
2514
+ "learning_rate": 8.495332555425905e-05,
2515
+ "loss": 3.3321746826171874,
2516
+ "step": 3580
2517
+ },
2518
+ {
2519
+ "epoch": 3.9583677970774747,
2520
+ "grad_norm": 1.539640188217163,
2521
+ "learning_rate": 8.489498249708285e-05,
2522
+ "loss": 3.335438537597656,
2523
+ "step": 3590
2524
+ },
2525
+ {
2526
+ "epoch": 3.9693961952026466,
2527
+ "grad_norm": 1.123307466506958,
2528
+ "learning_rate": 8.483663943990665e-05,
2529
+ "loss": 3.3397190093994142,
2530
+ "step": 3600
2531
+ },
2532
+ {
2533
+ "epoch": 3.980424593327819,
2534
+ "grad_norm": 1.6037691831588745,
2535
+ "learning_rate": 8.477829638273046e-05,
2536
+ "loss": 3.3357570648193358,
2537
+ "step": 3610
2538
+ },
2539
+ {
2540
+ "epoch": 3.9914529914529915,
2541
+ "grad_norm": 1.6570971012115479,
2542
+ "learning_rate": 8.471995332555425e-05,
2543
+ "loss": 3.341298294067383,
2544
+ "step": 3620
2545
+ }
2546
+ ],
2547
+ "logging_steps": 10,
2548
+ "max_steps": 18140,
2549
+ "num_input_tokens_seen": 0,
2550
+ "num_train_epochs": 20,
2551
+ "save_steps": 500,
2552
+ "stateful_callbacks": {
2553
+ "TrainerControl": {
2554
+ "args": {
2555
+ "should_epoch_stop": false,
2556
+ "should_evaluate": false,
2557
+ "should_log": false,
2558
+ "should_save": true,
2559
+ "should_training_stop": false
2560
+ },
2561
+ "attributes": {}
2562
+ }
2563
+ },
2564
+ "total_flos": 1444428795346944.0,
2565
+ "train_batch_size": 1,
2566
+ "trial_name": null,
2567
+ "trial_params": null
2568
+ }
output_qwen3_plain_ar/checkpoint-3628/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-4535/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-4535/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-4535/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step4535
output_qwen3_plain_ar/checkpoint-4535/trainer_state.json ADDED
@@ -0,0 +1,3205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 5.0,
6
+ "eval_steps": 500,
7
+ "global_step": 4535,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.011028398125172319,
14
+ "grad_norm": 435.2422180175781,
15
+ "learning_rate": 9e-07,
16
+ "loss": 20.84569549560547,
17
+ "step": 10
18
+ },
19
+ {
20
+ "epoch": 0.022056796250344637,
21
+ "grad_norm": 141.7341766357422,
22
+ "learning_rate": 1.9e-06,
23
+ "loss": 18.69615936279297,
24
+ "step": 20
25
+ },
26
+ {
27
+ "epoch": 0.033085194375516956,
28
+ "grad_norm": 74.42520904541016,
29
+ "learning_rate": 2.9e-06,
30
+ "loss": 16.079673767089844,
31
+ "step": 30
32
+ },
33
+ {
34
+ "epoch": 0.044113592500689275,
35
+ "grad_norm": 24.73248863220215,
36
+ "learning_rate": 3.9e-06,
37
+ "loss": 13.684315490722657,
38
+ "step": 40
39
+ },
40
+ {
41
+ "epoch": 0.055141990625861594,
42
+ "grad_norm": 7.049101829528809,
43
+ "learning_rate": 4.9000000000000005e-06,
44
+ "loss": 12.474874877929688,
45
+ "step": 50
46
+ },
47
+ {
48
+ "epoch": 0.06617038875103391,
49
+ "grad_norm": 2.3411474227905273,
50
+ "learning_rate": 5.9e-06,
51
+ "loss": 12.072142028808594,
52
+ "step": 60
53
+ },
54
+ {
55
+ "epoch": 0.07719878687620624,
56
+ "grad_norm": 1.126215934753418,
57
+ "learning_rate": 6.900000000000001e-06,
58
+ "loss": 11.938906860351562,
59
+ "step": 70
60
+ },
61
+ {
62
+ "epoch": 0.08822718500137855,
63
+ "grad_norm": 1.2050226926803589,
64
+ "learning_rate": 7.9e-06,
65
+ "loss": 11.81988296508789,
66
+ "step": 80
67
+ },
68
+ {
69
+ "epoch": 0.09925558312655088,
70
+ "grad_norm": 1.444793462753296,
71
+ "learning_rate": 8.9e-06,
72
+ "loss": 11.602033996582032,
73
+ "step": 90
74
+ },
75
+ {
76
+ "epoch": 0.11028398125172319,
77
+ "grad_norm": 5.791665077209473,
78
+ "learning_rate": 9.900000000000002e-06,
79
+ "loss": 11.201815032958985,
80
+ "step": 100
81
+ },
82
+ {
83
+ "epoch": 0.12131237937689551,
84
+ "grad_norm": 9.492277145385742,
85
+ "learning_rate": 1.09e-05,
86
+ "loss": 10.535708618164062,
87
+ "step": 110
88
+ },
89
+ {
90
+ "epoch": 0.13234077750206782,
91
+ "grad_norm": 2.7546133995056152,
92
+ "learning_rate": 1.19e-05,
93
+ "loss": 9.847169494628906,
94
+ "step": 120
95
+ },
96
+ {
97
+ "epoch": 0.14336917562724014,
98
+ "grad_norm": 1.0953313112258911,
99
+ "learning_rate": 1.29e-05,
100
+ "loss": 9.429026031494141,
101
+ "step": 130
102
+ },
103
+ {
104
+ "epoch": 0.15439757375241248,
105
+ "grad_norm": 0.7153559327125549,
106
+ "learning_rate": 1.3900000000000002e-05,
107
+ "loss": 9.266969299316406,
108
+ "step": 140
109
+ },
110
+ {
111
+ "epoch": 0.1654259718775848,
112
+ "grad_norm": 0.5888933539390564,
113
+ "learning_rate": 1.49e-05,
114
+ "loss": 9.1935546875,
115
+ "step": 150
116
+ },
117
+ {
118
+ "epoch": 0.1764543700027571,
119
+ "grad_norm": 0.4850365221500397,
120
+ "learning_rate": 1.59e-05,
121
+ "loss": 9.19604034423828,
122
+ "step": 160
123
+ },
124
+ {
125
+ "epoch": 0.1874827681279294,
126
+ "grad_norm": 0.5772538185119629,
127
+ "learning_rate": 1.69e-05,
128
+ "loss": 9.17010726928711,
129
+ "step": 170
130
+ },
131
+ {
132
+ "epoch": 0.19851116625310175,
133
+ "grad_norm": 0.4283920228481293,
134
+ "learning_rate": 1.79e-05,
135
+ "loss": 9.172830200195312,
136
+ "step": 180
137
+ },
138
+ {
139
+ "epoch": 0.20953956437827406,
140
+ "grad_norm": 0.8650698065757751,
141
+ "learning_rate": 1.8900000000000002e-05,
142
+ "loss": 9.154988098144532,
143
+ "step": 190
144
+ },
145
+ {
146
+ "epoch": 0.22056796250344637,
147
+ "grad_norm": 0.42017608880996704,
148
+ "learning_rate": 1.9900000000000003e-05,
149
+ "loss": 9.146849060058594,
150
+ "step": 200
151
+ },
152
+ {
153
+ "epoch": 0.23159636062861869,
154
+ "grad_norm": 0.9125994443893433,
155
+ "learning_rate": 2.09e-05,
156
+ "loss": 9.164442443847657,
157
+ "step": 210
158
+ },
159
+ {
160
+ "epoch": 0.24262475875379103,
161
+ "grad_norm": 0.6468876004219055,
162
+ "learning_rate": 2.19e-05,
163
+ "loss": 9.159596252441407,
164
+ "step": 220
165
+ },
166
+ {
167
+ "epoch": 0.25365315687896334,
168
+ "grad_norm": 0.4124819338321686,
169
+ "learning_rate": 2.29e-05,
170
+ "loss": 9.13860626220703,
171
+ "step": 230
172
+ },
173
+ {
174
+ "epoch": 0.26468155500413565,
175
+ "grad_norm": 1.990302562713623,
176
+ "learning_rate": 2.39e-05,
177
+ "loss": 9.145040893554688,
178
+ "step": 240
179
+ },
180
+ {
181
+ "epoch": 0.27570995312930796,
182
+ "grad_norm": 0.7875277400016785,
183
+ "learning_rate": 2.4900000000000002e-05,
184
+ "loss": 9.152925109863281,
185
+ "step": 250
186
+ },
187
+ {
188
+ "epoch": 0.2867383512544803,
189
+ "grad_norm": 0.8343706130981445,
190
+ "learning_rate": 2.5900000000000003e-05,
191
+ "loss": 9.132975769042968,
192
+ "step": 260
193
+ },
194
+ {
195
+ "epoch": 0.2977667493796526,
196
+ "grad_norm": 3.00996470451355,
197
+ "learning_rate": 2.6900000000000003e-05,
198
+ "loss": 9.097848510742187,
199
+ "step": 270
200
+ },
201
+ {
202
+ "epoch": 0.30879514750482495,
203
+ "grad_norm": 2.4282069206237793,
204
+ "learning_rate": 2.7900000000000004e-05,
205
+ "loss": 9.042235565185546,
206
+ "step": 280
207
+ },
208
+ {
209
+ "epoch": 0.31982354562999726,
210
+ "grad_norm": 4.171019554138184,
211
+ "learning_rate": 2.8899999999999998e-05,
212
+ "loss": 8.927298736572265,
213
+ "step": 290
214
+ },
215
+ {
216
+ "epoch": 0.3308519437551696,
217
+ "grad_norm": 2.197887659072876,
218
+ "learning_rate": 2.9900000000000002e-05,
219
+ "loss": 8.805252075195312,
220
+ "step": 300
221
+ },
222
+ {
223
+ "epoch": 0.3418803418803419,
224
+ "grad_norm": 10.306541442871094,
225
+ "learning_rate": 3.09e-05,
226
+ "loss": 8.673678588867187,
227
+ "step": 310
228
+ },
229
+ {
230
+ "epoch": 0.3529087400055142,
231
+ "grad_norm": 8.463860511779785,
232
+ "learning_rate": 3.19e-05,
233
+ "loss": 8.570347595214844,
234
+ "step": 320
235
+ },
236
+ {
237
+ "epoch": 0.3639371381306865,
238
+ "grad_norm": 3.999753475189209,
239
+ "learning_rate": 3.29e-05,
240
+ "loss": 8.429109191894531,
241
+ "step": 330
242
+ },
243
+ {
244
+ "epoch": 0.3749655362558588,
245
+ "grad_norm": 5.259007930755615,
246
+ "learning_rate": 3.3900000000000004e-05,
247
+ "loss": 8.334149169921876,
248
+ "step": 340
249
+ },
250
+ {
251
+ "epoch": 0.38599393438103113,
252
+ "grad_norm": 8.362598419189453,
253
+ "learning_rate": 3.49e-05,
254
+ "loss": 8.196139526367187,
255
+ "step": 350
256
+ },
257
+ {
258
+ "epoch": 0.3970223325062035,
259
+ "grad_norm": 10.273512840270996,
260
+ "learning_rate": 3.59e-05,
261
+ "loss": 8.040153503417969,
262
+ "step": 360
263
+ },
264
+ {
265
+ "epoch": 0.4080507306313758,
266
+ "grad_norm": 5.111108303070068,
267
+ "learning_rate": 3.69e-05,
268
+ "loss": 7.866473388671875,
269
+ "step": 370
270
+ },
271
+ {
272
+ "epoch": 0.4190791287565481,
273
+ "grad_norm": 9.192107200622559,
274
+ "learning_rate": 3.79e-05,
275
+ "loss": 7.695774841308594,
276
+ "step": 380
277
+ },
278
+ {
279
+ "epoch": 0.43010752688172044,
280
+ "grad_norm": 5.393336772918701,
281
+ "learning_rate": 3.8900000000000004e-05,
282
+ "loss": 7.498152160644532,
283
+ "step": 390
284
+ },
285
+ {
286
+ "epoch": 0.44113592500689275,
287
+ "grad_norm": 10.53490161895752,
288
+ "learning_rate": 3.99e-05,
289
+ "loss": 7.270246887207032,
290
+ "step": 400
291
+ },
292
+ {
293
+ "epoch": 0.45216432313206506,
294
+ "grad_norm": 6.174643516540527,
295
+ "learning_rate": 4.09e-05,
296
+ "loss": 7.127191162109375,
297
+ "step": 410
298
+ },
299
+ {
300
+ "epoch": 0.46319272125723737,
301
+ "grad_norm": 4.522936820983887,
302
+ "learning_rate": 4.19e-05,
303
+ "loss": 6.871500396728516,
304
+ "step": 420
305
+ },
306
+ {
307
+ "epoch": 0.4742211193824097,
308
+ "grad_norm": 4.3594207763671875,
309
+ "learning_rate": 4.29e-05,
310
+ "loss": 6.702586364746094,
311
+ "step": 430
312
+ },
313
+ {
314
+ "epoch": 0.48524951750758205,
315
+ "grad_norm": 5.950730323791504,
316
+ "learning_rate": 4.39e-05,
317
+ "loss": 6.493560791015625,
318
+ "step": 440
319
+ },
320
+ {
321
+ "epoch": 0.49627791563275436,
322
+ "grad_norm": 6.233413219451904,
323
+ "learning_rate": 4.49e-05,
324
+ "loss": 6.293489074707031,
325
+ "step": 450
326
+ },
327
+ {
328
+ "epoch": 0.5073063137579267,
329
+ "grad_norm": 7.656834125518799,
330
+ "learning_rate": 4.5900000000000004e-05,
331
+ "loss": 6.102347946166992,
332
+ "step": 460
333
+ },
334
+ {
335
+ "epoch": 0.518334711883099,
336
+ "grad_norm": 4.319094657897949,
337
+ "learning_rate": 4.69e-05,
338
+ "loss": 5.928083419799805,
339
+ "step": 470
340
+ },
341
+ {
342
+ "epoch": 0.5293631100082713,
343
+ "grad_norm": 5.585537433624268,
344
+ "learning_rate": 4.79e-05,
345
+ "loss": 5.77436637878418,
346
+ "step": 480
347
+ },
348
+ {
349
+ "epoch": 0.5403915081334436,
350
+ "grad_norm": 5.104014873504639,
351
+ "learning_rate": 4.89e-05,
352
+ "loss": 5.636859130859375,
353
+ "step": 490
354
+ },
355
+ {
356
+ "epoch": 0.5514199062586159,
357
+ "grad_norm": 5.453028202056885,
358
+ "learning_rate": 4.99e-05,
359
+ "loss": 5.507636260986328,
360
+ "step": 500
361
+ },
362
+ {
363
+ "epoch": 0.5624483043837882,
364
+ "grad_norm": 7.728854179382324,
365
+ "learning_rate": 5.0900000000000004e-05,
366
+ "loss": 5.411964416503906,
367
+ "step": 510
368
+ },
369
+ {
370
+ "epoch": 0.5734767025089605,
371
+ "grad_norm": 4.50288724899292,
372
+ "learning_rate": 5.19e-05,
373
+ "loss": 5.295291900634766,
374
+ "step": 520
375
+ },
376
+ {
377
+ "epoch": 0.5845051006341329,
378
+ "grad_norm": 4.245919704437256,
379
+ "learning_rate": 5.2900000000000005e-05,
380
+ "loss": 5.194162750244141,
381
+ "step": 530
382
+ },
383
+ {
384
+ "epoch": 0.5955334987593052,
385
+ "grad_norm": 6.278975963592529,
386
+ "learning_rate": 5.390000000000001e-05,
387
+ "loss": 5.113618087768555,
388
+ "step": 540
389
+ },
390
+ {
391
+ "epoch": 0.6065618968844775,
392
+ "grad_norm": 4.214662075042725,
393
+ "learning_rate": 5.4900000000000006e-05,
394
+ "loss": 5.038372039794922,
395
+ "step": 550
396
+ },
397
+ {
398
+ "epoch": 0.6175902950096499,
399
+ "grad_norm": 3.5404605865478516,
400
+ "learning_rate": 5.590000000000001e-05,
401
+ "loss": 4.935391235351562,
402
+ "step": 560
403
+ },
404
+ {
405
+ "epoch": 0.6286186931348222,
406
+ "grad_norm": 3.6460280418395996,
407
+ "learning_rate": 5.69e-05,
408
+ "loss": 4.896538543701172,
409
+ "step": 570
410
+ },
411
+ {
412
+ "epoch": 0.6396470912599945,
413
+ "grad_norm": 5.254800796508789,
414
+ "learning_rate": 5.79e-05,
415
+ "loss": 4.829419708251953,
416
+ "step": 580
417
+ },
418
+ {
419
+ "epoch": 0.6506754893851668,
420
+ "grad_norm": 5.132180690765381,
421
+ "learning_rate": 5.89e-05,
422
+ "loss": 4.793368148803711,
423
+ "step": 590
424
+ },
425
+ {
426
+ "epoch": 0.6617038875103392,
427
+ "grad_norm": 4.222960948944092,
428
+ "learning_rate": 5.99e-05,
429
+ "loss": 4.746239852905274,
430
+ "step": 600
431
+ },
432
+ {
433
+ "epoch": 0.6727322856355115,
434
+ "grad_norm": 4.070414066314697,
435
+ "learning_rate": 6.09e-05,
436
+ "loss": 4.688523864746093,
437
+ "step": 610
438
+ },
439
+ {
440
+ "epoch": 0.6837606837606838,
441
+ "grad_norm": 3.4652583599090576,
442
+ "learning_rate": 6.19e-05,
443
+ "loss": 4.692922973632813,
444
+ "step": 620
445
+ },
446
+ {
447
+ "epoch": 0.6947890818858561,
448
+ "grad_norm": 4.559128284454346,
449
+ "learning_rate": 6.29e-05,
450
+ "loss": 4.639920043945312,
451
+ "step": 630
452
+ },
453
+ {
454
+ "epoch": 0.7058174800110284,
455
+ "grad_norm": 3.197758436203003,
456
+ "learning_rate": 6.390000000000001e-05,
457
+ "loss": 4.601907348632812,
458
+ "step": 640
459
+ },
460
+ {
461
+ "epoch": 0.7168458781362007,
462
+ "grad_norm": 4.209578514099121,
463
+ "learning_rate": 6.49e-05,
464
+ "loss": 4.56639404296875,
465
+ "step": 650
466
+ },
467
+ {
468
+ "epoch": 0.727874276261373,
469
+ "grad_norm": 3.701484203338623,
470
+ "learning_rate": 6.59e-05,
471
+ "loss": 4.545608901977539,
472
+ "step": 660
473
+ },
474
+ {
475
+ "epoch": 0.7389026743865453,
476
+ "grad_norm": 3.951927900314331,
477
+ "learning_rate": 6.690000000000001e-05,
478
+ "loss": 4.493326187133789,
479
+ "step": 670
480
+ },
481
+ {
482
+ "epoch": 0.7499310725117176,
483
+ "grad_norm": 4.219130039215088,
484
+ "learning_rate": 6.790000000000001e-05,
485
+ "loss": 4.482691955566406,
486
+ "step": 680
487
+ },
488
+ {
489
+ "epoch": 0.76095947063689,
490
+ "grad_norm": 6.267204284667969,
491
+ "learning_rate": 6.89e-05,
492
+ "loss": 4.4599052429199215,
493
+ "step": 690
494
+ },
495
+ {
496
+ "epoch": 0.7719878687620623,
497
+ "grad_norm": 3.367382764816284,
498
+ "learning_rate": 6.99e-05,
499
+ "loss": 4.429808807373047,
500
+ "step": 700
501
+ },
502
+ {
503
+ "epoch": 0.7830162668872346,
504
+ "grad_norm": 3.8906455039978027,
505
+ "learning_rate": 7.09e-05,
506
+ "loss": 4.4144752502441404,
507
+ "step": 710
508
+ },
509
+ {
510
+ "epoch": 0.794044665012407,
511
+ "grad_norm": 6.759398460388184,
512
+ "learning_rate": 7.19e-05,
513
+ "loss": 4.385488891601563,
514
+ "step": 720
515
+ },
516
+ {
517
+ "epoch": 0.8050730631375793,
518
+ "grad_norm": 3.520167350769043,
519
+ "learning_rate": 7.29e-05,
520
+ "loss": 4.397706985473633,
521
+ "step": 730
522
+ },
523
+ {
524
+ "epoch": 0.8161014612627516,
525
+ "grad_norm": 2.7510974407196045,
526
+ "learning_rate": 7.390000000000001e-05,
527
+ "loss": 4.374617385864258,
528
+ "step": 740
529
+ },
530
+ {
531
+ "epoch": 0.8271298593879239,
532
+ "grad_norm": 4.395699977874756,
533
+ "learning_rate": 7.49e-05,
534
+ "loss": 4.3302146911621096,
535
+ "step": 750
536
+ },
537
+ {
538
+ "epoch": 0.8381582575130962,
539
+ "grad_norm": 3.277766704559326,
540
+ "learning_rate": 7.59e-05,
541
+ "loss": 4.313335418701172,
542
+ "step": 760
543
+ },
544
+ {
545
+ "epoch": 0.8491866556382686,
546
+ "grad_norm": 2.466207981109619,
547
+ "learning_rate": 7.69e-05,
548
+ "loss": 4.3226570129394535,
549
+ "step": 770
550
+ },
551
+ {
552
+ "epoch": 0.8602150537634409,
553
+ "grad_norm": 3.637355327606201,
554
+ "learning_rate": 7.790000000000001e-05,
555
+ "loss": 4.295929718017578,
556
+ "step": 780
557
+ },
558
+ {
559
+ "epoch": 0.8712434518886132,
560
+ "grad_norm": 3.155527353286743,
561
+ "learning_rate": 7.890000000000001e-05,
562
+ "loss": 4.287591552734375,
563
+ "step": 790
564
+ },
565
+ {
566
+ "epoch": 0.8822718500137855,
567
+ "grad_norm": 3.593884229660034,
568
+ "learning_rate": 7.99e-05,
569
+ "loss": 4.267314147949219,
570
+ "step": 800
571
+ },
572
+ {
573
+ "epoch": 0.8933002481389578,
574
+ "grad_norm": 2.361081123352051,
575
+ "learning_rate": 8.090000000000001e-05,
576
+ "loss": 4.265741348266602,
577
+ "step": 810
578
+ },
579
+ {
580
+ "epoch": 0.9043286462641301,
581
+ "grad_norm": 2.7084105014801025,
582
+ "learning_rate": 8.19e-05,
583
+ "loss": 4.261878204345703,
584
+ "step": 820
585
+ },
586
+ {
587
+ "epoch": 0.9153570443893024,
588
+ "grad_norm": 3.6093873977661133,
589
+ "learning_rate": 8.29e-05,
590
+ "loss": 4.211677551269531,
591
+ "step": 830
592
+ },
593
+ {
594
+ "epoch": 0.9263854425144747,
595
+ "grad_norm": 3.9739396572113037,
596
+ "learning_rate": 8.39e-05,
597
+ "loss": 4.224007034301758,
598
+ "step": 840
599
+ },
600
+ {
601
+ "epoch": 0.9374138406396471,
602
+ "grad_norm": 2.174050807952881,
603
+ "learning_rate": 8.49e-05,
604
+ "loss": 4.211782836914063,
605
+ "step": 850
606
+ },
607
+ {
608
+ "epoch": 0.9484422387648194,
609
+ "grad_norm": 2.7151405811309814,
610
+ "learning_rate": 8.59e-05,
611
+ "loss": 4.204391098022461,
612
+ "step": 860
613
+ },
614
+ {
615
+ "epoch": 0.9594706368899917,
616
+ "grad_norm": 3.7480661869049072,
617
+ "learning_rate": 8.69e-05,
618
+ "loss": 4.175582504272461,
619
+ "step": 870
620
+ },
621
+ {
622
+ "epoch": 0.9704990350151641,
623
+ "grad_norm": 3.1127700805664062,
624
+ "learning_rate": 8.790000000000001e-05,
625
+ "loss": 4.183733749389648,
626
+ "step": 880
627
+ },
628
+ {
629
+ "epoch": 0.9815274331403364,
630
+ "grad_norm": 2.750716209411621,
631
+ "learning_rate": 8.89e-05,
632
+ "loss": 4.167971801757813,
633
+ "step": 890
634
+ },
635
+ {
636
+ "epoch": 0.9925558312655087,
637
+ "grad_norm": 4.02509880065918,
638
+ "learning_rate": 8.99e-05,
639
+ "loss": 4.170472717285156,
640
+ "step": 900
641
+ },
642
+ {
643
+ "epoch": 1.0033085194375517,
644
+ "grad_norm": 3.0058505535125732,
645
+ "learning_rate": 9.090000000000001e-05,
646
+ "loss": 4.1449127197265625,
647
+ "step": 910
648
+ },
649
+ {
650
+ "epoch": 1.014336917562724,
651
+ "grad_norm": 2.553403377532959,
652
+ "learning_rate": 9.190000000000001e-05,
653
+ "loss": 4.1404258728027346,
654
+ "step": 920
655
+ },
656
+ {
657
+ "epoch": 1.0253653156878964,
658
+ "grad_norm": 2.8066084384918213,
659
+ "learning_rate": 9.290000000000001e-05,
660
+ "loss": 4.110780334472656,
661
+ "step": 930
662
+ },
663
+ {
664
+ "epoch": 1.0363937138130686,
665
+ "grad_norm": 3.904608726501465,
666
+ "learning_rate": 9.39e-05,
667
+ "loss": 4.134862899780273,
668
+ "step": 940
669
+ },
670
+ {
671
+ "epoch": 1.047422111938241,
672
+ "grad_norm": 2.217729330062866,
673
+ "learning_rate": 9.49e-05,
674
+ "loss": 4.112079620361328,
675
+ "step": 950
676
+ },
677
+ {
678
+ "epoch": 1.0584505100634134,
679
+ "grad_norm": 2.498760938644409,
680
+ "learning_rate": 9.59e-05,
681
+ "loss": 4.097566986083985,
682
+ "step": 960
683
+ },
684
+ {
685
+ "epoch": 1.0694789081885856,
686
+ "grad_norm": 3.577143907546997,
687
+ "learning_rate": 9.69e-05,
688
+ "loss": 4.081307220458984,
689
+ "step": 970
690
+ },
691
+ {
692
+ "epoch": 1.080507306313758,
693
+ "grad_norm": 3.283250570297241,
694
+ "learning_rate": 9.790000000000001e-05,
695
+ "loss": 4.103987503051758,
696
+ "step": 980
697
+ },
698
+ {
699
+ "epoch": 1.0915357044389302,
700
+ "grad_norm": 2.1897776126861572,
701
+ "learning_rate": 9.89e-05,
702
+ "loss": 4.084938812255859,
703
+ "step": 990
704
+ },
705
+ {
706
+ "epoch": 1.1025641025641026,
707
+ "grad_norm": 2.6925997734069824,
708
+ "learning_rate": 9.99e-05,
709
+ "loss": 4.058921051025391,
710
+ "step": 1000
711
+ },
712
+ {
713
+ "epoch": 1.1135925006892748,
714
+ "grad_norm": 3.4118456840515137,
715
+ "learning_rate": 9.994749124854142e-05,
716
+ "loss": 4.061585235595703,
717
+ "step": 1010
718
+ },
719
+ {
720
+ "epoch": 1.1246208988144473,
721
+ "grad_norm": 2.6139297485351562,
722
+ "learning_rate": 9.988914819136523e-05,
723
+ "loss": 4.070050048828125,
724
+ "step": 1020
725
+ },
726
+ {
727
+ "epoch": 1.1356492969396195,
728
+ "grad_norm": 1.8616399765014648,
729
+ "learning_rate": 9.983080513418903e-05,
730
+ "loss": 4.0413330078125,
731
+ "step": 1030
732
+ },
733
+ {
734
+ "epoch": 1.146677695064792,
735
+ "grad_norm": 2.361706018447876,
736
+ "learning_rate": 9.977246207701284e-05,
737
+ "loss": 4.023075866699219,
738
+ "step": 1040
739
+ },
740
+ {
741
+ "epoch": 1.157706093189964,
742
+ "grad_norm": 3.815014123916626,
743
+ "learning_rate": 9.971411901983664e-05,
744
+ "loss": 4.036756134033203,
745
+ "step": 1050
746
+ },
747
+ {
748
+ "epoch": 1.1687344913151365,
749
+ "grad_norm": 2.4410274028778076,
750
+ "learning_rate": 9.965577596266045e-05,
751
+ "loss": 4.020483779907226,
752
+ "step": 1060
753
+ },
754
+ {
755
+ "epoch": 1.1797628894403087,
756
+ "grad_norm": 2.768084764480591,
757
+ "learning_rate": 9.959743290548426e-05,
758
+ "loss": 4.021839141845703,
759
+ "step": 1070
760
+ },
761
+ {
762
+ "epoch": 1.1907912875654811,
763
+ "grad_norm": 1.9342570304870605,
764
+ "learning_rate": 9.953908984830806e-05,
765
+ "loss": 4.026360321044922,
766
+ "step": 1080
767
+ },
768
+ {
769
+ "epoch": 1.2018196856906533,
770
+ "grad_norm": 2.8184762001037598,
771
+ "learning_rate": 9.948074679113187e-05,
772
+ "loss": 4.007581329345703,
773
+ "step": 1090
774
+ },
775
+ {
776
+ "epoch": 1.2128480838158258,
777
+ "grad_norm": 3.2656188011169434,
778
+ "learning_rate": 9.942240373395566e-05,
779
+ "loss": 3.9965087890625,
780
+ "step": 1100
781
+ },
782
+ {
783
+ "epoch": 1.223876481940998,
784
+ "grad_norm": 2.4359538555145264,
785
+ "learning_rate": 9.936406067677947e-05,
786
+ "loss": 3.9959388732910157,
787
+ "step": 1110
788
+ },
789
+ {
790
+ "epoch": 1.2349048800661704,
791
+ "grad_norm": 1.9357632398605347,
792
+ "learning_rate": 9.930571761960327e-05,
793
+ "loss": 3.9851417541503906,
794
+ "step": 1120
795
+ },
796
+ {
797
+ "epoch": 1.2459332781913428,
798
+ "grad_norm": 2.1269352436065674,
799
+ "learning_rate": 9.924737456242708e-05,
800
+ "loss": 3.9773223876953123,
801
+ "step": 1130
802
+ },
803
+ {
804
+ "epoch": 1.256961676316515,
805
+ "grad_norm": 3.3491597175598145,
806
+ "learning_rate": 9.918903150525088e-05,
807
+ "loss": 3.9877471923828125,
808
+ "step": 1140
809
+ },
810
+ {
811
+ "epoch": 1.2679900744416872,
812
+ "grad_norm": 1.8646328449249268,
813
+ "learning_rate": 9.913068844807468e-05,
814
+ "loss": 3.9694965362548826,
815
+ "step": 1150
816
+ },
817
+ {
818
+ "epoch": 1.2790184725668596,
819
+ "grad_norm": 2.6204631328582764,
820
+ "learning_rate": 9.907234539089849e-05,
821
+ "loss": 3.9611881256103514,
822
+ "step": 1160
823
+ },
824
+ {
825
+ "epoch": 1.290046870692032,
826
+ "grad_norm": 1.872028112411499,
827
+ "learning_rate": 9.901400233372228e-05,
828
+ "loss": 3.964163970947266,
829
+ "step": 1170
830
+ },
831
+ {
832
+ "epoch": 1.3010752688172043,
833
+ "grad_norm": 3.490435838699341,
834
+ "learning_rate": 9.895565927654609e-05,
835
+ "loss": 3.959897994995117,
836
+ "step": 1180
837
+ },
838
+ {
839
+ "epoch": 1.3121036669423767,
840
+ "grad_norm": 2.862489700317383,
841
+ "learning_rate": 9.88973162193699e-05,
842
+ "loss": 3.9567939758300783,
843
+ "step": 1190
844
+ },
845
+ {
846
+ "epoch": 1.3231320650675489,
847
+ "grad_norm": 3.0570664405822754,
848
+ "learning_rate": 9.883897316219371e-05,
849
+ "loss": 3.9470645904541017,
850
+ "step": 1200
851
+ },
852
+ {
853
+ "epoch": 1.3341604631927213,
854
+ "grad_norm": 1.9254627227783203,
855
+ "learning_rate": 9.878063010501752e-05,
856
+ "loss": 3.9442317962646483,
857
+ "step": 1210
858
+ },
859
+ {
860
+ "epoch": 1.3451888613178935,
861
+ "grad_norm": 3.606224298477173,
862
+ "learning_rate": 9.872228704784131e-05,
863
+ "loss": 3.9380733489990236,
864
+ "step": 1220
865
+ },
866
+ {
867
+ "epoch": 1.356217259443066,
868
+ "grad_norm": 2.1184027194976807,
869
+ "learning_rate": 9.866394399066512e-05,
870
+ "loss": 3.9452835083007813,
871
+ "step": 1230
872
+ },
873
+ {
874
+ "epoch": 1.3672456575682381,
875
+ "grad_norm": 1.8997142314910889,
876
+ "learning_rate": 9.860560093348892e-05,
877
+ "loss": 3.9270603179931642,
878
+ "step": 1240
879
+ },
880
+ {
881
+ "epoch": 1.3782740556934105,
882
+ "grad_norm": 2.9672305583953857,
883
+ "learning_rate": 9.854725787631273e-05,
884
+ "loss": 3.9120155334472657,
885
+ "step": 1250
886
+ },
887
+ {
888
+ "epoch": 1.389302453818583,
889
+ "grad_norm": 1.9220951795578003,
890
+ "learning_rate": 9.848891481913652e-05,
891
+ "loss": 3.900279235839844,
892
+ "step": 1260
893
+ },
894
+ {
895
+ "epoch": 1.4003308519437552,
896
+ "grad_norm": 2.013521194458008,
897
+ "learning_rate": 9.843057176196033e-05,
898
+ "loss": 3.9147193908691404,
899
+ "step": 1270
900
+ },
901
+ {
902
+ "epoch": 1.4113592500689274,
903
+ "grad_norm": 1.451686143875122,
904
+ "learning_rate": 9.837222870478413e-05,
905
+ "loss": 3.906220245361328,
906
+ "step": 1280
907
+ },
908
+ {
909
+ "epoch": 1.4223876481940998,
910
+ "grad_norm": 4.606860637664795,
911
+ "learning_rate": 9.831388564760794e-05,
912
+ "loss": 3.905352020263672,
913
+ "step": 1290
914
+ },
915
+ {
916
+ "epoch": 1.4334160463192722,
917
+ "grad_norm": 1.779123306274414,
918
+ "learning_rate": 9.825554259043175e-05,
919
+ "loss": 3.9137496948242188,
920
+ "step": 1300
921
+ },
922
+ {
923
+ "epoch": 1.4444444444444444,
924
+ "grad_norm": 2.086585521697998,
925
+ "learning_rate": 9.819719953325554e-05,
926
+ "loss": 3.89554443359375,
927
+ "step": 1310
928
+ },
929
+ {
930
+ "epoch": 1.4554728425696168,
931
+ "grad_norm": 3.3514609336853027,
932
+ "learning_rate": 9.813885647607935e-05,
933
+ "loss": 3.8901123046875,
934
+ "step": 1320
935
+ },
936
+ {
937
+ "epoch": 1.466501240694789,
938
+ "grad_norm": 2.1145269870758057,
939
+ "learning_rate": 9.808051341890316e-05,
940
+ "loss": 3.8892486572265623,
941
+ "step": 1330
942
+ },
943
+ {
944
+ "epoch": 1.4775296388199615,
945
+ "grad_norm": 1.5503329038619995,
946
+ "learning_rate": 9.802217036172697e-05,
947
+ "loss": 3.8922355651855467,
948
+ "step": 1340
949
+ },
950
+ {
951
+ "epoch": 1.4885580369451337,
952
+ "grad_norm": 2.3014304637908936,
953
+ "learning_rate": 9.796382730455076e-05,
954
+ "loss": 3.8860099792480467,
955
+ "step": 1350
956
+ },
957
+ {
958
+ "epoch": 1.499586435070306,
959
+ "grad_norm": 1.9633557796478271,
960
+ "learning_rate": 9.790548424737457e-05,
961
+ "loss": 3.875183868408203,
962
+ "step": 1360
963
+ },
964
+ {
965
+ "epoch": 1.5106148331954783,
966
+ "grad_norm": 2.228351593017578,
967
+ "learning_rate": 9.784714119019837e-05,
968
+ "loss": 3.8726768493652344,
969
+ "step": 1370
970
+ },
971
+ {
972
+ "epoch": 1.5216432313206507,
973
+ "grad_norm": 3.0888657569885254,
974
+ "learning_rate": 9.778879813302218e-05,
975
+ "loss": 3.872690963745117,
976
+ "step": 1380
977
+ },
978
+ {
979
+ "epoch": 1.5326716294458231,
980
+ "grad_norm": 2.0078868865966797,
981
+ "learning_rate": 9.773045507584599e-05,
982
+ "loss": 3.8612388610839843,
983
+ "step": 1390
984
+ },
985
+ {
986
+ "epoch": 1.5437000275709953,
987
+ "grad_norm": 2.1966569423675537,
988
+ "learning_rate": 9.767211201866978e-05,
989
+ "loss": 3.8649852752685545,
990
+ "step": 1400
991
+ },
992
+ {
993
+ "epoch": 1.5547284256961675,
994
+ "grad_norm": 2.1047487258911133,
995
+ "learning_rate": 9.761376896149359e-05,
996
+ "loss": 3.8632328033447267,
997
+ "step": 1410
998
+ },
999
+ {
1000
+ "epoch": 1.56575682382134,
1001
+ "grad_norm": 1.9347233772277832,
1002
+ "learning_rate": 9.755542590431739e-05,
1003
+ "loss": 3.8362571716308596,
1004
+ "step": 1420
1005
+ },
1006
+ {
1007
+ "epoch": 1.5767852219465124,
1008
+ "grad_norm": 1.7961437702178955,
1009
+ "learning_rate": 9.74970828471412e-05,
1010
+ "loss": 3.8461585998535157,
1011
+ "step": 1430
1012
+ },
1013
+ {
1014
+ "epoch": 1.5878136200716846,
1015
+ "grad_norm": 2.4657342433929443,
1016
+ "learning_rate": 9.743873978996499e-05,
1017
+ "loss": 3.842551040649414,
1018
+ "step": 1440
1019
+ },
1020
+ {
1021
+ "epoch": 1.5988420181968568,
1022
+ "grad_norm": 2.043138027191162,
1023
+ "learning_rate": 9.73803967327888e-05,
1024
+ "loss": 3.8387855529785155,
1025
+ "step": 1450
1026
+ },
1027
+ {
1028
+ "epoch": 1.6098704163220292,
1029
+ "grad_norm": 3.732532262802124,
1030
+ "learning_rate": 9.732205367561261e-05,
1031
+ "loss": 3.8399681091308593,
1032
+ "step": 1460
1033
+ },
1034
+ {
1035
+ "epoch": 1.6208988144472016,
1036
+ "grad_norm": 2.43684720993042,
1037
+ "learning_rate": 9.726371061843642e-05,
1038
+ "loss": 3.8324966430664062,
1039
+ "step": 1470
1040
+ },
1041
+ {
1042
+ "epoch": 1.6319272125723738,
1043
+ "grad_norm": 2.4433460235595703,
1044
+ "learning_rate": 9.720536756126023e-05,
1045
+ "loss": 3.817783737182617,
1046
+ "step": 1480
1047
+ },
1048
+ {
1049
+ "epoch": 1.642955610697546,
1050
+ "grad_norm": 2.1049606800079346,
1051
+ "learning_rate": 9.714702450408402e-05,
1052
+ "loss": 3.804280090332031,
1053
+ "step": 1490
1054
+ },
1055
+ {
1056
+ "epoch": 1.6539840088227185,
1057
+ "grad_norm": 3.529686450958252,
1058
+ "learning_rate": 9.708868144690783e-05,
1059
+ "loss": 3.805449295043945,
1060
+ "step": 1500
1061
+ },
1062
+ {
1063
+ "epoch": 1.6650124069478909,
1064
+ "grad_norm": 2.0984089374542236,
1065
+ "learning_rate": 9.703033838973162e-05,
1066
+ "loss": 3.788246917724609,
1067
+ "step": 1510
1068
+ },
1069
+ {
1070
+ "epoch": 1.6760408050730633,
1071
+ "grad_norm": 1.9434291124343872,
1072
+ "learning_rate": 9.697199533255543e-05,
1073
+ "loss": 3.7875442504882812,
1074
+ "step": 1520
1075
+ },
1076
+ {
1077
+ "epoch": 1.6870692031982355,
1078
+ "grad_norm": 1.99173903465271,
1079
+ "learning_rate": 9.691365227537923e-05,
1080
+ "loss": 3.7807193756103517,
1081
+ "step": 1530
1082
+ },
1083
+ {
1084
+ "epoch": 1.6980976013234077,
1085
+ "grad_norm": 2.5006911754608154,
1086
+ "learning_rate": 9.685530921820304e-05,
1087
+ "loss": 3.744763946533203,
1088
+ "step": 1540
1089
+ },
1090
+ {
1091
+ "epoch": 1.7091259994485801,
1092
+ "grad_norm": 2.1816165447235107,
1093
+ "learning_rate": 9.679696616102685e-05,
1094
+ "loss": 3.760245513916016,
1095
+ "step": 1550
1096
+ },
1097
+ {
1098
+ "epoch": 1.7201543975737525,
1099
+ "grad_norm": 2.123291492462158,
1100
+ "learning_rate": 9.673862310385064e-05,
1101
+ "loss": 3.738916778564453,
1102
+ "step": 1560
1103
+ },
1104
+ {
1105
+ "epoch": 1.7311827956989247,
1106
+ "grad_norm": 2.378187894821167,
1107
+ "learning_rate": 9.668028004667445e-05,
1108
+ "loss": 3.734139251708984,
1109
+ "step": 1570
1110
+ },
1111
+ {
1112
+ "epoch": 1.742211193824097,
1113
+ "grad_norm": 2.54819393157959,
1114
+ "learning_rate": 9.662193698949825e-05,
1115
+ "loss": 3.715302276611328,
1116
+ "step": 1580
1117
+ },
1118
+ {
1119
+ "epoch": 1.7532395919492694,
1120
+ "grad_norm": 4.285822868347168,
1121
+ "learning_rate": 9.656359393232206e-05,
1122
+ "loss": 3.72213134765625,
1123
+ "step": 1590
1124
+ },
1125
+ {
1126
+ "epoch": 1.7642679900744418,
1127
+ "grad_norm": 1.8676700592041016,
1128
+ "learning_rate": 9.650525087514586e-05,
1129
+ "loss": 3.7252479553222657,
1130
+ "step": 1600
1131
+ },
1132
+ {
1133
+ "epoch": 1.775296388199614,
1134
+ "grad_norm": 1.6977792978286743,
1135
+ "learning_rate": 9.644690781796967e-05,
1136
+ "loss": 3.704994964599609,
1137
+ "step": 1610
1138
+ },
1139
+ {
1140
+ "epoch": 1.7863247863247862,
1141
+ "grad_norm": 1.8334232568740845,
1142
+ "learning_rate": 9.638856476079347e-05,
1143
+ "loss": 3.6980815887451173,
1144
+ "step": 1620
1145
+ },
1146
+ {
1147
+ "epoch": 1.7973531844499586,
1148
+ "grad_norm": 2.6574559211730957,
1149
+ "learning_rate": 9.633022170361728e-05,
1150
+ "loss": 3.683759307861328,
1151
+ "step": 1630
1152
+ },
1153
+ {
1154
+ "epoch": 1.808381582575131,
1155
+ "grad_norm": 2.085084915161133,
1156
+ "learning_rate": 9.627187864644109e-05,
1157
+ "loss": 3.67755126953125,
1158
+ "step": 1640
1159
+ },
1160
+ {
1161
+ "epoch": 1.8194099807003032,
1162
+ "grad_norm": 1.685441017150879,
1163
+ "learning_rate": 9.621353558926488e-05,
1164
+ "loss": 3.656099319458008,
1165
+ "step": 1650
1166
+ },
1167
+ {
1168
+ "epoch": 1.8304383788254754,
1169
+ "grad_norm": 2.4462475776672363,
1170
+ "learning_rate": 9.615519253208869e-05,
1171
+ "loss": 3.668656921386719,
1172
+ "step": 1660
1173
+ },
1174
+ {
1175
+ "epoch": 1.8414667769506479,
1176
+ "grad_norm": 1.54155433177948,
1177
+ "learning_rate": 9.609684947491249e-05,
1178
+ "loss": 3.66968994140625,
1179
+ "step": 1670
1180
+ },
1181
+ {
1182
+ "epoch": 1.8524951750758203,
1183
+ "grad_norm": 3.862130880355835,
1184
+ "learning_rate": 9.60385064177363e-05,
1185
+ "loss": 3.6412506103515625,
1186
+ "step": 1680
1187
+ },
1188
+ {
1189
+ "epoch": 1.8635235732009927,
1190
+ "grad_norm": 1.7317070960998535,
1191
+ "learning_rate": 9.598016336056009e-05,
1192
+ "loss": 3.639806365966797,
1193
+ "step": 1690
1194
+ },
1195
+ {
1196
+ "epoch": 1.874551971326165,
1197
+ "grad_norm": 2.2640931606292725,
1198
+ "learning_rate": 9.59218203033839e-05,
1199
+ "loss": 3.6341064453125,
1200
+ "step": 1700
1201
+ },
1202
+ {
1203
+ "epoch": 1.8855803694513371,
1204
+ "grad_norm": 3.653146743774414,
1205
+ "learning_rate": 9.586347724620771e-05,
1206
+ "loss": 3.6380882263183594,
1207
+ "step": 1710
1208
+ },
1209
+ {
1210
+ "epoch": 1.8966087675765095,
1211
+ "grad_norm": 1.8987306356430054,
1212
+ "learning_rate": 9.58051341890315e-05,
1213
+ "loss": 3.6405975341796877,
1214
+ "step": 1720
1215
+ },
1216
+ {
1217
+ "epoch": 1.907637165701682,
1218
+ "grad_norm": 2.202659845352173,
1219
+ "learning_rate": 9.574679113185531e-05,
1220
+ "loss": 3.6375991821289064,
1221
+ "step": 1730
1222
+ },
1223
+ {
1224
+ "epoch": 1.9186655638268542,
1225
+ "grad_norm": 1.5091872215270996,
1226
+ "learning_rate": 9.568844807467912e-05,
1227
+ "loss": 3.6208465576171873,
1228
+ "step": 1740
1229
+ },
1230
+ {
1231
+ "epoch": 1.9296939619520264,
1232
+ "grad_norm": 1.9811325073242188,
1233
+ "learning_rate": 9.563010501750293e-05,
1234
+ "loss": 3.600755310058594,
1235
+ "step": 1750
1236
+ },
1237
+ {
1238
+ "epoch": 1.9407223600771988,
1239
+ "grad_norm": 3.184499979019165,
1240
+ "learning_rate": 9.557176196032673e-05,
1241
+ "loss": 3.6109405517578126,
1242
+ "step": 1760
1243
+ },
1244
+ {
1245
+ "epoch": 1.9517507582023712,
1246
+ "grad_norm": 2.340125322341919,
1247
+ "learning_rate": 9.551341890315054e-05,
1248
+ "loss": 3.6129817962646484,
1249
+ "step": 1770
1250
+ },
1251
+ {
1252
+ "epoch": 1.9627791563275434,
1253
+ "grad_norm": 1.7258495092391968,
1254
+ "learning_rate": 9.545507584597433e-05,
1255
+ "loss": 3.590809631347656,
1256
+ "step": 1780
1257
+ },
1258
+ {
1259
+ "epoch": 1.9738075544527156,
1260
+ "grad_norm": 1.6129754781723022,
1261
+ "learning_rate": 9.539673278879814e-05,
1262
+ "loss": 3.5866302490234374,
1263
+ "step": 1790
1264
+ },
1265
+ {
1266
+ "epoch": 1.984835952577888,
1267
+ "grad_norm": 2.7458667755126953,
1268
+ "learning_rate": 9.533838973162195e-05,
1269
+ "loss": 3.596644973754883,
1270
+ "step": 1800
1271
+ },
1272
+ {
1273
+ "epoch": 1.9958643507030605,
1274
+ "grad_norm": 2.258280038833618,
1275
+ "learning_rate": 9.528004667444574e-05,
1276
+ "loss": 3.5881332397460937,
1277
+ "step": 1810
1278
+ },
1279
+ {
1280
+ "epoch": 2.0066170388751035,
1281
+ "grad_norm": 2.1228580474853516,
1282
+ "learning_rate": 9.522170361726955e-05,
1283
+ "loss": 3.5709766387939452,
1284
+ "step": 1820
1285
+ },
1286
+ {
1287
+ "epoch": 2.017645437000276,
1288
+ "grad_norm": 1.588876485824585,
1289
+ "learning_rate": 9.516336056009335e-05,
1290
+ "loss": 3.5627593994140625,
1291
+ "step": 1830
1292
+ },
1293
+ {
1294
+ "epoch": 2.028673835125448,
1295
+ "grad_norm": 2.451474189758301,
1296
+ "learning_rate": 9.510501750291716e-05,
1297
+ "loss": 3.5535301208496093,
1298
+ "step": 1840
1299
+ },
1300
+ {
1301
+ "epoch": 2.0397022332506203,
1302
+ "grad_norm": 2.0007503032684326,
1303
+ "learning_rate": 9.504667444574095e-05,
1304
+ "loss": 3.553875732421875,
1305
+ "step": 1850
1306
+ },
1307
+ {
1308
+ "epoch": 2.0507306313757927,
1309
+ "grad_norm": 1.4410080909729004,
1310
+ "learning_rate": 9.498833138856476e-05,
1311
+ "loss": 3.550189971923828,
1312
+ "step": 1860
1313
+ },
1314
+ {
1315
+ "epoch": 2.061759029500965,
1316
+ "grad_norm": 2.062835216522217,
1317
+ "learning_rate": 9.492998833138857e-05,
1318
+ "loss": 3.5456893920898436,
1319
+ "step": 1870
1320
+ },
1321
+ {
1322
+ "epoch": 2.072787427626137,
1323
+ "grad_norm": 2.4534783363342285,
1324
+ "learning_rate": 9.487164527421238e-05,
1325
+ "loss": 3.536829376220703,
1326
+ "step": 1880
1327
+ },
1328
+ {
1329
+ "epoch": 2.0838158257513095,
1330
+ "grad_norm": 2.2788970470428467,
1331
+ "learning_rate": 9.481330221703619e-05,
1332
+ "loss": 3.5525283813476562,
1333
+ "step": 1890
1334
+ },
1335
+ {
1336
+ "epoch": 2.094844223876482,
1337
+ "grad_norm": 1.4259227514266968,
1338
+ "learning_rate": 9.475495915985998e-05,
1339
+ "loss": 3.5479995727539064,
1340
+ "step": 1900
1341
+ },
1342
+ {
1343
+ "epoch": 2.1058726220016544,
1344
+ "grad_norm": 2.672534465789795,
1345
+ "learning_rate": 9.469661610268379e-05,
1346
+ "loss": 3.5359420776367188,
1347
+ "step": 1910
1348
+ },
1349
+ {
1350
+ "epoch": 2.116901020126827,
1351
+ "grad_norm": 2.0648045539855957,
1352
+ "learning_rate": 9.463827304550759e-05,
1353
+ "loss": 3.5452896118164063,
1354
+ "step": 1920
1355
+ },
1356
+ {
1357
+ "epoch": 2.1279294182519988,
1358
+ "grad_norm": 1.6846543550491333,
1359
+ "learning_rate": 9.45799299883314e-05,
1360
+ "loss": 3.5434345245361327,
1361
+ "step": 1930
1362
+ },
1363
+ {
1364
+ "epoch": 2.138957816377171,
1365
+ "grad_norm": 1.9105942249298096,
1366
+ "learning_rate": 9.452158693115519e-05,
1367
+ "loss": 3.5351535797119142,
1368
+ "step": 1940
1369
+ },
1370
+ {
1371
+ "epoch": 2.1499862145023436,
1372
+ "grad_norm": 1.8230890035629272,
1373
+ "learning_rate": 9.4463243873979e-05,
1374
+ "loss": 3.5190963745117188,
1375
+ "step": 1950
1376
+ },
1377
+ {
1378
+ "epoch": 2.161014612627516,
1379
+ "grad_norm": 1.6383274793624878,
1380
+ "learning_rate": 9.440490081680281e-05,
1381
+ "loss": 3.5228431701660154,
1382
+ "step": 1960
1383
+ },
1384
+ {
1385
+ "epoch": 2.172043010752688,
1386
+ "grad_norm": 1.7378439903259277,
1387
+ "learning_rate": 9.43465577596266e-05,
1388
+ "loss": 3.520981216430664,
1389
+ "step": 1970
1390
+ },
1391
+ {
1392
+ "epoch": 2.1830714088778604,
1393
+ "grad_norm": 1.941454529762268,
1394
+ "learning_rate": 9.428821470245041e-05,
1395
+ "loss": 3.519342803955078,
1396
+ "step": 1980
1397
+ },
1398
+ {
1399
+ "epoch": 2.194099807003033,
1400
+ "grad_norm": 1.8295516967773438,
1401
+ "learning_rate": 9.422987164527421e-05,
1402
+ "loss": 3.5412979125976562,
1403
+ "step": 1990
1404
+ },
1405
+ {
1406
+ "epoch": 2.2051282051282053,
1407
+ "grad_norm": 1.8052620887756348,
1408
+ "learning_rate": 9.417152858809802e-05,
1409
+ "loss": 3.5153289794921876,
1410
+ "step": 2000
1411
+ },
1412
+ {
1413
+ "epoch": 2.2161566032533773,
1414
+ "grad_norm": 2.1949570178985596,
1415
+ "learning_rate": 9.411318553092183e-05,
1416
+ "loss": 3.521608352661133,
1417
+ "step": 2010
1418
+ },
1419
+ {
1420
+ "epoch": 2.2271850013785497,
1421
+ "grad_norm": 1.746172308921814,
1422
+ "learning_rate": 9.405484247374564e-05,
1423
+ "loss": 3.5008296966552734,
1424
+ "step": 2020
1425
+ },
1426
+ {
1427
+ "epoch": 2.238213399503722,
1428
+ "grad_norm": 2.5374276638031006,
1429
+ "learning_rate": 9.399649941656943e-05,
1430
+ "loss": 3.5140228271484375,
1431
+ "step": 2030
1432
+ },
1433
+ {
1434
+ "epoch": 2.2492417976288945,
1435
+ "grad_norm": 1.7763218879699707,
1436
+ "learning_rate": 9.393815635939324e-05,
1437
+ "loss": 3.510652542114258,
1438
+ "step": 2040
1439
+ },
1440
+ {
1441
+ "epoch": 2.2602701957540665,
1442
+ "grad_norm": 1.6599587202072144,
1443
+ "learning_rate": 9.387981330221705e-05,
1444
+ "loss": 3.5122325897216795,
1445
+ "step": 2050
1446
+ },
1447
+ {
1448
+ "epoch": 2.271298593879239,
1449
+ "grad_norm": 2.1496078968048096,
1450
+ "learning_rate": 9.382147024504085e-05,
1451
+ "loss": 3.5139747619628907,
1452
+ "step": 2060
1453
+ },
1454
+ {
1455
+ "epoch": 2.2823269920044114,
1456
+ "grad_norm": 1.64266836643219,
1457
+ "learning_rate": 9.376312718786465e-05,
1458
+ "loss": 3.507743072509766,
1459
+ "step": 2070
1460
+ },
1461
+ {
1462
+ "epoch": 2.293355390129584,
1463
+ "grad_norm": 2.1241567134857178,
1464
+ "learning_rate": 9.370478413068845e-05,
1465
+ "loss": 3.5162708282470705,
1466
+ "step": 2080
1467
+ },
1468
+ {
1469
+ "epoch": 2.304383788254756,
1470
+ "grad_norm": 1.8391071557998657,
1471
+ "learning_rate": 9.364644107351226e-05,
1472
+ "loss": 3.4955375671386717,
1473
+ "step": 2090
1474
+ },
1475
+ {
1476
+ "epoch": 2.315412186379928,
1477
+ "grad_norm": 2.7478973865509033,
1478
+ "learning_rate": 9.358809801633605e-05,
1479
+ "loss": 3.497519302368164,
1480
+ "step": 2100
1481
+ },
1482
+ {
1483
+ "epoch": 2.3264405845051006,
1484
+ "grad_norm": 1.938588261604309,
1485
+ "learning_rate": 9.352975495915986e-05,
1486
+ "loss": 3.490141677856445,
1487
+ "step": 2110
1488
+ },
1489
+ {
1490
+ "epoch": 2.337468982630273,
1491
+ "grad_norm": 1.5637104511260986,
1492
+ "learning_rate": 9.347141190198366e-05,
1493
+ "loss": 3.499908447265625,
1494
+ "step": 2120
1495
+ },
1496
+ {
1497
+ "epoch": 2.3484973807554455,
1498
+ "grad_norm": 1.882504940032959,
1499
+ "learning_rate": 9.341306884480747e-05,
1500
+ "loss": 3.491979217529297,
1501
+ "step": 2130
1502
+ },
1503
+ {
1504
+ "epoch": 2.3595257788806174,
1505
+ "grad_norm": 1.8528521060943604,
1506
+ "learning_rate": 9.335472578763128e-05,
1507
+ "loss": 3.4961143493652345,
1508
+ "step": 2140
1509
+ },
1510
+ {
1511
+ "epoch": 2.37055417700579,
1512
+ "grad_norm": 1.8050177097320557,
1513
+ "learning_rate": 9.329638273045509e-05,
1514
+ "loss": 3.4948150634765627,
1515
+ "step": 2150
1516
+ },
1517
+ {
1518
+ "epoch": 2.3815825751309623,
1519
+ "grad_norm": 1.816784381866455,
1520
+ "learning_rate": 9.32380396732789e-05,
1521
+ "loss": 3.4910873413085937,
1522
+ "step": 2160
1523
+ },
1524
+ {
1525
+ "epoch": 2.3926109732561347,
1526
+ "grad_norm": 1.9779244661331177,
1527
+ "learning_rate": 9.317969661610269e-05,
1528
+ "loss": 3.492570495605469,
1529
+ "step": 2170
1530
+ },
1531
+ {
1532
+ "epoch": 2.4036393713813067,
1533
+ "grad_norm": 1.8939772844314575,
1534
+ "learning_rate": 9.31213535589265e-05,
1535
+ "loss": 3.473868560791016,
1536
+ "step": 2180
1537
+ },
1538
+ {
1539
+ "epoch": 2.414667769506479,
1540
+ "grad_norm": 2.1493656635284424,
1541
+ "learning_rate": 9.30630105017503e-05,
1542
+ "loss": 3.494515228271484,
1543
+ "step": 2190
1544
+ },
1545
+ {
1546
+ "epoch": 2.4256961676316515,
1547
+ "grad_norm": 1.8989397287368774,
1548
+ "learning_rate": 9.30046674445741e-05,
1549
+ "loss": 3.487537384033203,
1550
+ "step": 2200
1551
+ },
1552
+ {
1553
+ "epoch": 2.436724565756824,
1554
+ "grad_norm": 1.881856918334961,
1555
+ "learning_rate": 9.294632438739791e-05,
1556
+ "loss": 3.475904083251953,
1557
+ "step": 2210
1558
+ },
1559
+ {
1560
+ "epoch": 2.447752963881996,
1561
+ "grad_norm": 1.9463883638381958,
1562
+ "learning_rate": 9.288798133022171e-05,
1563
+ "loss": 3.4829254150390625,
1564
+ "step": 2220
1565
+ },
1566
+ {
1567
+ "epoch": 2.4587813620071683,
1568
+ "grad_norm": 2.01379656791687,
1569
+ "learning_rate": 9.282963827304552e-05,
1570
+ "loss": 3.472850036621094,
1571
+ "step": 2230
1572
+ },
1573
+ {
1574
+ "epoch": 2.4698097601323408,
1575
+ "grad_norm": 2.442741632461548,
1576
+ "learning_rate": 9.277129521586931e-05,
1577
+ "loss": 3.47030029296875,
1578
+ "step": 2240
1579
+ },
1580
+ {
1581
+ "epoch": 2.480838158257513,
1582
+ "grad_norm": 1.5051734447479248,
1583
+ "learning_rate": 9.271295215869312e-05,
1584
+ "loss": 3.489413833618164,
1585
+ "step": 2250
1586
+ },
1587
+ {
1588
+ "epoch": 2.4918665563826856,
1589
+ "grad_norm": 1.9489309787750244,
1590
+ "learning_rate": 9.265460910151692e-05,
1591
+ "loss": 3.464769744873047,
1592
+ "step": 2260
1593
+ },
1594
+ {
1595
+ "epoch": 2.5028949545078576,
1596
+ "grad_norm": 2.319654941558838,
1597
+ "learning_rate": 9.259626604434072e-05,
1598
+ "loss": 3.469140625,
1599
+ "step": 2270
1600
+ },
1601
+ {
1602
+ "epoch": 2.51392335263303,
1603
+ "grad_norm": 1.7984129190444946,
1604
+ "learning_rate": 9.253792298716453e-05,
1605
+ "loss": 3.466594696044922,
1606
+ "step": 2280
1607
+ },
1608
+ {
1609
+ "epoch": 2.5249517507582024,
1610
+ "grad_norm": 1.640869379043579,
1611
+ "learning_rate": 9.247957992998833e-05,
1612
+ "loss": 3.463022994995117,
1613
+ "step": 2290
1614
+ },
1615
+ {
1616
+ "epoch": 2.5359801488833744,
1617
+ "grad_norm": 1.6698195934295654,
1618
+ "learning_rate": 9.242123687281214e-05,
1619
+ "loss": 3.4695220947265626,
1620
+ "step": 2300
1621
+ },
1622
+ {
1623
+ "epoch": 2.547008547008547,
1624
+ "grad_norm": 2.2945683002471924,
1625
+ "learning_rate": 9.236289381563595e-05,
1626
+ "loss": 3.469150924682617,
1627
+ "step": 2310
1628
+ },
1629
+ {
1630
+ "epoch": 2.5580369451337193,
1631
+ "grad_norm": 1.7678370475769043,
1632
+ "learning_rate": 9.230455075845976e-05,
1633
+ "loss": 3.470307159423828,
1634
+ "step": 2320
1635
+ },
1636
+ {
1637
+ "epoch": 2.5690653432588917,
1638
+ "grad_norm": 1.8386255502700806,
1639
+ "learning_rate": 9.224620770128355e-05,
1640
+ "loss": 3.4638832092285154,
1641
+ "step": 2330
1642
+ },
1643
+ {
1644
+ "epoch": 2.580093741384064,
1645
+ "grad_norm": 2.0348527431488037,
1646
+ "learning_rate": 9.218786464410736e-05,
1647
+ "loss": 3.460480880737305,
1648
+ "step": 2340
1649
+ },
1650
+ {
1651
+ "epoch": 2.5911221395092365,
1652
+ "grad_norm": 1.845974326133728,
1653
+ "learning_rate": 9.212952158693116e-05,
1654
+ "loss": 3.4529083251953123,
1655
+ "step": 2350
1656
+ },
1657
+ {
1658
+ "epoch": 2.6021505376344085,
1659
+ "grad_norm": 2.0843095779418945,
1660
+ "learning_rate": 9.207117852975496e-05,
1661
+ "loss": 3.4576786041259764,
1662
+ "step": 2360
1663
+ },
1664
+ {
1665
+ "epoch": 2.613178935759581,
1666
+ "grad_norm": 1.7627031803131104,
1667
+ "learning_rate": 9.201283547257876e-05,
1668
+ "loss": 3.4450752258300783,
1669
+ "step": 2370
1670
+ },
1671
+ {
1672
+ "epoch": 2.6242073338847534,
1673
+ "grad_norm": 1.371972918510437,
1674
+ "learning_rate": 9.195449241540257e-05,
1675
+ "loss": 3.464734649658203,
1676
+ "step": 2380
1677
+ },
1678
+ {
1679
+ "epoch": 2.6352357320099253,
1680
+ "grad_norm": 1.6781940460205078,
1681
+ "learning_rate": 9.189614935822638e-05,
1682
+ "loss": 3.444991683959961,
1683
+ "step": 2390
1684
+ },
1685
+ {
1686
+ "epoch": 2.6462641301350978,
1687
+ "grad_norm": 1.8782585859298706,
1688
+ "learning_rate": 9.183780630105017e-05,
1689
+ "loss": 3.4558509826660155,
1690
+ "step": 2400
1691
+ },
1692
+ {
1693
+ "epoch": 2.65729252826027,
1694
+ "grad_norm": 1.942812204360962,
1695
+ "learning_rate": 9.177946324387398e-05,
1696
+ "loss": 3.4555503845214846,
1697
+ "step": 2410
1698
+ },
1699
+ {
1700
+ "epoch": 2.6683209263854426,
1701
+ "grad_norm": 1.404680609703064,
1702
+ "learning_rate": 9.172112018669778e-05,
1703
+ "loss": 3.438182830810547,
1704
+ "step": 2420
1705
+ },
1706
+ {
1707
+ "epoch": 2.679349324510615,
1708
+ "grad_norm": 1.7656677961349487,
1709
+ "learning_rate": 9.166277712952159e-05,
1710
+ "loss": 3.4622947692871096,
1711
+ "step": 2430
1712
+ },
1713
+ {
1714
+ "epoch": 2.690377722635787,
1715
+ "grad_norm": 1.8348901271820068,
1716
+ "learning_rate": 9.16044340723454e-05,
1717
+ "loss": 3.438182830810547,
1718
+ "step": 2440
1719
+ },
1720
+ {
1721
+ "epoch": 2.7014061207609594,
1722
+ "grad_norm": 2.0641167163848877,
1723
+ "learning_rate": 9.15460910151692e-05,
1724
+ "loss": 3.441473388671875,
1725
+ "step": 2450
1726
+ },
1727
+ {
1728
+ "epoch": 2.712434518886132,
1729
+ "grad_norm": 1.726035475730896,
1730
+ "learning_rate": 9.148774795799301e-05,
1731
+ "loss": 3.441991424560547,
1732
+ "step": 2460
1733
+ },
1734
+ {
1735
+ "epoch": 2.7234629170113043,
1736
+ "grad_norm": 1.854658603668213,
1737
+ "learning_rate": 9.142940490081681e-05,
1738
+ "loss": 3.4441551208496093,
1739
+ "step": 2470
1740
+ },
1741
+ {
1742
+ "epoch": 2.7344913151364763,
1743
+ "grad_norm": 1.8229296207427979,
1744
+ "learning_rate": 9.137106184364062e-05,
1745
+ "loss": 3.441034698486328,
1746
+ "step": 2480
1747
+ },
1748
+ {
1749
+ "epoch": 2.7455197132616487,
1750
+ "grad_norm": 1.6627975702285767,
1751
+ "learning_rate": 9.131271878646441e-05,
1752
+ "loss": 3.4399124145507813,
1753
+ "step": 2490
1754
+ },
1755
+ {
1756
+ "epoch": 2.756548111386821,
1757
+ "grad_norm": 1.4111251831054688,
1758
+ "learning_rate": 9.125437572928822e-05,
1759
+ "loss": 3.4374462127685548,
1760
+ "step": 2500
1761
+ },
1762
+ {
1763
+ "epoch": 2.7675765095119935,
1764
+ "grad_norm": 2.015869379043579,
1765
+ "learning_rate": 9.119603267211202e-05,
1766
+ "loss": 3.4262016296386717,
1767
+ "step": 2510
1768
+ },
1769
+ {
1770
+ "epoch": 2.778604907637166,
1771
+ "grad_norm": 2.2818591594696045,
1772
+ "learning_rate": 9.113768961493583e-05,
1773
+ "loss": 3.446285629272461,
1774
+ "step": 2520
1775
+ },
1776
+ {
1777
+ "epoch": 2.789633305762338,
1778
+ "grad_norm": 1.8643262386322021,
1779
+ "learning_rate": 9.107934655775962e-05,
1780
+ "loss": 3.4362293243408204,
1781
+ "step": 2530
1782
+ },
1783
+ {
1784
+ "epoch": 2.8006617038875103,
1785
+ "grad_norm": 1.248988151550293,
1786
+ "learning_rate": 9.102100350058343e-05,
1787
+ "loss": 3.441702651977539,
1788
+ "step": 2540
1789
+ },
1790
+ {
1791
+ "epoch": 2.8116901020126828,
1792
+ "grad_norm": 1.5247464179992676,
1793
+ "learning_rate": 9.096266044340724e-05,
1794
+ "loss": 3.4388256072998047,
1795
+ "step": 2550
1796
+ },
1797
+ {
1798
+ "epoch": 2.8227185001378547,
1799
+ "grad_norm": 1.9120620489120483,
1800
+ "learning_rate": 9.090431738623103e-05,
1801
+ "loss": 3.4206756591796874,
1802
+ "step": 2560
1803
+ },
1804
+ {
1805
+ "epoch": 2.833746898263027,
1806
+ "grad_norm": 1.4591054916381836,
1807
+ "learning_rate": 9.084597432905484e-05,
1808
+ "loss": 3.4229709625244142,
1809
+ "step": 2570
1810
+ },
1811
+ {
1812
+ "epoch": 2.8447752963881996,
1813
+ "grad_norm": 2.24849796295166,
1814
+ "learning_rate": 9.078763127187865e-05,
1815
+ "loss": 3.426911163330078,
1816
+ "step": 2580
1817
+ },
1818
+ {
1819
+ "epoch": 2.855803694513372,
1820
+ "grad_norm": 1.5658804178237915,
1821
+ "learning_rate": 9.072928821470246e-05,
1822
+ "loss": 3.445120620727539,
1823
+ "step": 2590
1824
+ },
1825
+ {
1826
+ "epoch": 2.8668320926385444,
1827
+ "grad_norm": 1.483583688735962,
1828
+ "learning_rate": 9.067094515752626e-05,
1829
+ "loss": 3.430312728881836,
1830
+ "step": 2600
1831
+ },
1832
+ {
1833
+ "epoch": 2.8778604907637164,
1834
+ "grad_norm": 1.5759658813476562,
1835
+ "learning_rate": 9.061260210035007e-05,
1836
+ "loss": 3.4178386688232423,
1837
+ "step": 2610
1838
+ },
1839
+ {
1840
+ "epoch": 2.888888888888889,
1841
+ "grad_norm": 1.9259848594665527,
1842
+ "learning_rate": 9.055425904317386e-05,
1843
+ "loss": 3.430949401855469,
1844
+ "step": 2620
1845
+ },
1846
+ {
1847
+ "epoch": 2.8999172870140613,
1848
+ "grad_norm": 1.470717191696167,
1849
+ "learning_rate": 9.049591598599767e-05,
1850
+ "loss": 3.439757537841797,
1851
+ "step": 2630
1852
+ },
1853
+ {
1854
+ "epoch": 2.9109456851392337,
1855
+ "grad_norm": 1.8934212923049927,
1856
+ "learning_rate": 9.043757292882148e-05,
1857
+ "loss": 3.430719757080078,
1858
+ "step": 2640
1859
+ },
1860
+ {
1861
+ "epoch": 2.9219740832644057,
1862
+ "grad_norm": 1.6267489194869995,
1863
+ "learning_rate": 9.037922987164527e-05,
1864
+ "loss": 3.4224998474121096,
1865
+ "step": 2650
1866
+ },
1867
+ {
1868
+ "epoch": 2.933002481389578,
1869
+ "grad_norm": 1.6213353872299194,
1870
+ "learning_rate": 9.032088681446908e-05,
1871
+ "loss": 3.4213233947753907,
1872
+ "step": 2660
1873
+ },
1874
+ {
1875
+ "epoch": 2.9440308795147505,
1876
+ "grad_norm": 1.961879849433899,
1877
+ "learning_rate": 9.026254375729288e-05,
1878
+ "loss": 3.4108352661132812,
1879
+ "step": 2670
1880
+ },
1881
+ {
1882
+ "epoch": 2.955059277639923,
1883
+ "grad_norm": 1.7363910675048828,
1884
+ "learning_rate": 9.020420070011669e-05,
1885
+ "loss": 3.423554229736328,
1886
+ "step": 2680
1887
+ },
1888
+ {
1889
+ "epoch": 2.9660876757650954,
1890
+ "grad_norm": 1.6161952018737793,
1891
+ "learning_rate": 9.014585764294048e-05,
1892
+ "loss": 3.418962860107422,
1893
+ "step": 2690
1894
+ },
1895
+ {
1896
+ "epoch": 2.9771160738902673,
1897
+ "grad_norm": 1.8065682649612427,
1898
+ "learning_rate": 9.008751458576429e-05,
1899
+ "loss": 3.4218765258789063,
1900
+ "step": 2700
1901
+ },
1902
+ {
1903
+ "epoch": 2.9881444720154398,
1904
+ "grad_norm": 1.4285337924957275,
1905
+ "learning_rate": 9.00291715285881e-05,
1906
+ "loss": 3.413957214355469,
1907
+ "step": 2710
1908
+ },
1909
+ {
1910
+ "epoch": 2.999172870140612,
1911
+ "grad_norm": 1.30274498462677,
1912
+ "learning_rate": 8.997082847141191e-05,
1913
+ "loss": 3.4176124572753905,
1914
+ "step": 2720
1915
+ },
1916
+ {
1917
+ "epoch": 3.009925558312655,
1918
+ "grad_norm": 1.5460416078567505,
1919
+ "learning_rate": 8.991248541423572e-05,
1920
+ "loss": 3.388013458251953,
1921
+ "step": 2730
1922
+ },
1923
+ {
1924
+ "epoch": 3.0209539564378276,
1925
+ "grad_norm": 1.5832446813583374,
1926
+ "learning_rate": 8.985414235705951e-05,
1927
+ "loss": 3.3929378509521486,
1928
+ "step": 2740
1929
+ },
1930
+ {
1931
+ "epoch": 3.0319823545629996,
1932
+ "grad_norm": 1.6086630821228027,
1933
+ "learning_rate": 8.979579929988332e-05,
1934
+ "loss": 3.3940502166748048,
1935
+ "step": 2750
1936
+ },
1937
+ {
1938
+ "epoch": 3.043010752688172,
1939
+ "grad_norm": 1.6624842882156372,
1940
+ "learning_rate": 8.973745624270712e-05,
1941
+ "loss": 3.388884353637695,
1942
+ "step": 2760
1943
+ },
1944
+ {
1945
+ "epoch": 3.0540391508133444,
1946
+ "grad_norm": 1.7352933883666992,
1947
+ "learning_rate": 8.967911318553093e-05,
1948
+ "loss": 3.409127426147461,
1949
+ "step": 2770
1950
+ },
1951
+ {
1952
+ "epoch": 3.065067548938517,
1953
+ "grad_norm": 1.45657217502594,
1954
+ "learning_rate": 8.962077012835472e-05,
1955
+ "loss": 3.389351654052734,
1956
+ "step": 2780
1957
+ },
1958
+ {
1959
+ "epoch": 3.076095947063689,
1960
+ "grad_norm": 1.4969090223312378,
1961
+ "learning_rate": 8.956242707117853e-05,
1962
+ "loss": 3.3988433837890626,
1963
+ "step": 2790
1964
+ },
1965
+ {
1966
+ "epoch": 3.0871243451888613,
1967
+ "grad_norm": 1.710800051689148,
1968
+ "learning_rate": 8.950408401400234e-05,
1969
+ "loss": 3.395826721191406,
1970
+ "step": 2800
1971
+ },
1972
+ {
1973
+ "epoch": 3.0981527433140337,
1974
+ "grad_norm": 1.6347870826721191,
1975
+ "learning_rate": 8.944574095682614e-05,
1976
+ "loss": 3.391011047363281,
1977
+ "step": 2810
1978
+ },
1979
+ {
1980
+ "epoch": 3.109181141439206,
1981
+ "grad_norm": 1.4630122184753418,
1982
+ "learning_rate": 8.938739789964995e-05,
1983
+ "loss": 3.401841735839844,
1984
+ "step": 2820
1985
+ },
1986
+ {
1987
+ "epoch": 3.120209539564378,
1988
+ "grad_norm": 1.547430157661438,
1989
+ "learning_rate": 8.932905484247374e-05,
1990
+ "loss": 3.3979782104492187,
1991
+ "step": 2830
1992
+ },
1993
+ {
1994
+ "epoch": 3.1312379376895505,
1995
+ "grad_norm": 1.5614186525344849,
1996
+ "learning_rate": 8.927071178529755e-05,
1997
+ "loss": 3.3884544372558594,
1998
+ "step": 2840
1999
+ },
2000
+ {
2001
+ "epoch": 3.142266335814723,
2002
+ "grad_norm": 1.4073251485824585,
2003
+ "learning_rate": 8.921236872812136e-05,
2004
+ "loss": 3.3886154174804686,
2005
+ "step": 2850
2006
+ },
2007
+ {
2008
+ "epoch": 3.1532947339398953,
2009
+ "grad_norm": 1.3639475107192993,
2010
+ "learning_rate": 8.915402567094517e-05,
2011
+ "loss": 3.383074951171875,
2012
+ "step": 2860
2013
+ },
2014
+ {
2015
+ "epoch": 3.1643231320650678,
2016
+ "grad_norm": 2.3929882049560547,
2017
+ "learning_rate": 8.909568261376896e-05,
2018
+ "loss": 3.3788246154785155,
2019
+ "step": 2870
2020
+ },
2021
+ {
2022
+ "epoch": 3.1753515301902397,
2023
+ "grad_norm": 1.7196829319000244,
2024
+ "learning_rate": 8.903733955659277e-05,
2025
+ "loss": 3.3822708129882812,
2026
+ "step": 2880
2027
+ },
2028
+ {
2029
+ "epoch": 3.186379928315412,
2030
+ "grad_norm": 1.526293396949768,
2031
+ "learning_rate": 8.897899649941658e-05,
2032
+ "loss": 3.381543731689453,
2033
+ "step": 2890
2034
+ },
2035
+ {
2036
+ "epoch": 3.1974083264405846,
2037
+ "grad_norm": 1.2336128950119019,
2038
+ "learning_rate": 8.892065344224038e-05,
2039
+ "loss": 3.3975807189941407,
2040
+ "step": 2900
2041
+ },
2042
+ {
2043
+ "epoch": 3.208436724565757,
2044
+ "grad_norm": 1.4868130683898926,
2045
+ "learning_rate": 8.886231038506419e-05,
2046
+ "loss": 3.3970687866210936,
2047
+ "step": 2910
2048
+ },
2049
+ {
2050
+ "epoch": 3.219465122690929,
2051
+ "grad_norm": 1.5349540710449219,
2052
+ "learning_rate": 8.880396732788798e-05,
2053
+ "loss": 3.385994720458984,
2054
+ "step": 2920
2055
+ },
2056
+ {
2057
+ "epoch": 3.2304935208161014,
2058
+ "grad_norm": 1.5333718061447144,
2059
+ "learning_rate": 8.874562427071179e-05,
2060
+ "loss": 3.362841796875,
2061
+ "step": 2930
2062
+ },
2063
+ {
2064
+ "epoch": 3.241521918941274,
2065
+ "grad_norm": 1.514235258102417,
2066
+ "learning_rate": 8.868728121353558e-05,
2067
+ "loss": 3.3816680908203125,
2068
+ "step": 2940
2069
+ },
2070
+ {
2071
+ "epoch": 3.2525503170664463,
2072
+ "grad_norm": 1.5870161056518555,
2073
+ "learning_rate": 8.86289381563594e-05,
2074
+ "loss": 3.3818199157714846,
2075
+ "step": 2950
2076
+ },
2077
+ {
2078
+ "epoch": 3.2635787151916182,
2079
+ "grad_norm": 1.6295320987701416,
2080
+ "learning_rate": 8.85705950991832e-05,
2081
+ "loss": 3.379594421386719,
2082
+ "step": 2960
2083
+ },
2084
+ {
2085
+ "epoch": 3.2746071133167907,
2086
+ "grad_norm": 1.533991813659668,
2087
+ "learning_rate": 8.8512252042007e-05,
2088
+ "loss": 3.387801742553711,
2089
+ "step": 2970
2090
+ },
2091
+ {
2092
+ "epoch": 3.285635511441963,
2093
+ "grad_norm": 2.2125084400177,
2094
+ "learning_rate": 8.845390898483081e-05,
2095
+ "loss": 3.3856468200683594,
2096
+ "step": 2980
2097
+ },
2098
+ {
2099
+ "epoch": 3.2966639095671355,
2100
+ "grad_norm": 1.800207495689392,
2101
+ "learning_rate": 8.839556592765462e-05,
2102
+ "loss": 3.3843597412109374,
2103
+ "step": 2990
2104
+ },
2105
+ {
2106
+ "epoch": 3.3076923076923075,
2107
+ "grad_norm": 1.3071027994155884,
2108
+ "learning_rate": 8.833722287047842e-05,
2109
+ "loss": 3.3861888885498046,
2110
+ "step": 3000
2111
+ },
2112
+ {
2113
+ "epoch": 3.31872070581748,
2114
+ "grad_norm": 1.7724641561508179,
2115
+ "learning_rate": 8.827887981330222e-05,
2116
+ "loss": 3.3929458618164063,
2117
+ "step": 3010
2118
+ },
2119
+ {
2120
+ "epoch": 3.3297491039426523,
2121
+ "grad_norm": 1.3397877216339111,
2122
+ "learning_rate": 8.822053675612603e-05,
2123
+ "loss": 3.3785301208496095,
2124
+ "step": 3020
2125
+ },
2126
+ {
2127
+ "epoch": 3.3407775020678248,
2128
+ "grad_norm": 1.352630376815796,
2129
+ "learning_rate": 8.816219369894982e-05,
2130
+ "loss": 3.3796306610107423,
2131
+ "step": 3030
2132
+ },
2133
+ {
2134
+ "epoch": 3.351805900192997,
2135
+ "grad_norm": 1.5996475219726562,
2136
+ "learning_rate": 8.810385064177363e-05,
2137
+ "loss": 3.362406921386719,
2138
+ "step": 3040
2139
+ },
2140
+ {
2141
+ "epoch": 3.362834298318169,
2142
+ "grad_norm": 1.6010814905166626,
2143
+ "learning_rate": 8.804550758459744e-05,
2144
+ "loss": 3.3811767578125,
2145
+ "step": 3050
2146
+ },
2147
+ {
2148
+ "epoch": 3.3738626964433416,
2149
+ "grad_norm": 1.3276373147964478,
2150
+ "learning_rate": 8.798716452742124e-05,
2151
+ "loss": 3.3732643127441406,
2152
+ "step": 3060
2153
+ },
2154
+ {
2155
+ "epoch": 3.384891094568514,
2156
+ "grad_norm": 1.7741515636444092,
2157
+ "learning_rate": 8.792882147024505e-05,
2158
+ "loss": 3.381968688964844,
2159
+ "step": 3070
2160
+ },
2161
+ {
2162
+ "epoch": 3.3959194926936864,
2163
+ "grad_norm": 1.7820576429367065,
2164
+ "learning_rate": 8.787047841306884e-05,
2165
+ "loss": 3.358811950683594,
2166
+ "step": 3080
2167
+ },
2168
+ {
2169
+ "epoch": 3.4069478908188584,
2170
+ "grad_norm": 1.389573574066162,
2171
+ "learning_rate": 8.781213535589265e-05,
2172
+ "loss": 3.36102180480957,
2173
+ "step": 3090
2174
+ },
2175
+ {
2176
+ "epoch": 3.417976288944031,
2177
+ "grad_norm": 1.1910648345947266,
2178
+ "learning_rate": 8.775379229871645e-05,
2179
+ "loss": 3.3652645111083985,
2180
+ "step": 3100
2181
+ },
2182
+ {
2183
+ "epoch": 3.4290046870692032,
2184
+ "grad_norm": 1.965219497680664,
2185
+ "learning_rate": 8.769544924154026e-05,
2186
+ "loss": 3.3735313415527344,
2187
+ "step": 3110
2188
+ },
2189
+ {
2190
+ "epoch": 3.4400330851943757,
2191
+ "grad_norm": 1.5992330312728882,
2192
+ "learning_rate": 8.763710618436406e-05,
2193
+ "loss": 3.362974166870117,
2194
+ "step": 3120
2195
+ },
2196
+ {
2197
+ "epoch": 3.4510614833195477,
2198
+ "grad_norm": 2.2293193340301514,
2199
+ "learning_rate": 8.757876312718787e-05,
2200
+ "loss": 3.3681709289550783,
2201
+ "step": 3130
2202
+ },
2203
+ {
2204
+ "epoch": 3.46208988144472,
2205
+ "grad_norm": 1.2978801727294922,
2206
+ "learning_rate": 8.752042007001168e-05,
2207
+ "loss": 3.3776336669921876,
2208
+ "step": 3140
2209
+ },
2210
+ {
2211
+ "epoch": 3.4731182795698925,
2212
+ "grad_norm": 1.227036714553833,
2213
+ "learning_rate": 8.746207701283548e-05,
2214
+ "loss": 3.3590301513671874,
2215
+ "step": 3150
2216
+ },
2217
+ {
2218
+ "epoch": 3.484146677695065,
2219
+ "grad_norm": 1.8023360967636108,
2220
+ "learning_rate": 8.740373395565929e-05,
2221
+ "loss": 3.35421142578125,
2222
+ "step": 3160
2223
+ },
2224
+ {
2225
+ "epoch": 3.495175075820237,
2226
+ "grad_norm": 1.6423453092575073,
2227
+ "learning_rate": 8.734539089848308e-05,
2228
+ "loss": 3.3748985290527345,
2229
+ "step": 3170
2230
+ },
2231
+ {
2232
+ "epoch": 3.5062034739454093,
2233
+ "grad_norm": 1.3261916637420654,
2234
+ "learning_rate": 8.728704784130689e-05,
2235
+ "loss": 3.36380615234375,
2236
+ "step": 3180
2237
+ },
2238
+ {
2239
+ "epoch": 3.5172318720705817,
2240
+ "grad_norm": 1.290014624595642,
2241
+ "learning_rate": 8.722870478413069e-05,
2242
+ "loss": 3.3596282958984376,
2243
+ "step": 3190
2244
+ },
2245
+ {
2246
+ "epoch": 3.528260270195754,
2247
+ "grad_norm": 2.0481576919555664,
2248
+ "learning_rate": 8.71703617269545e-05,
2249
+ "loss": 3.358118438720703,
2250
+ "step": 3200
2251
+ },
2252
+ {
2253
+ "epoch": 3.5392886683209266,
2254
+ "grad_norm": 1.4758331775665283,
2255
+ "learning_rate": 8.71120186697783e-05,
2256
+ "loss": 3.3536834716796875,
2257
+ "step": 3210
2258
+ },
2259
+ {
2260
+ "epoch": 3.5503170664460986,
2261
+ "grad_norm": 1.4340440034866333,
2262
+ "learning_rate": 8.70536756126021e-05,
2263
+ "loss": 3.358259582519531,
2264
+ "step": 3220
2265
+ },
2266
+ {
2267
+ "epoch": 3.561345464571271,
2268
+ "grad_norm": 1.6952699422836304,
2269
+ "learning_rate": 8.699533255542591e-05,
2270
+ "loss": 3.3730777740478515,
2271
+ "step": 3230
2272
+ },
2273
+ {
2274
+ "epoch": 3.5723738626964434,
2275
+ "grad_norm": 1.9069234132766724,
2276
+ "learning_rate": 8.69369894982497e-05,
2277
+ "loss": 3.3552001953125,
2278
+ "step": 3240
2279
+ },
2280
+ {
2281
+ "epoch": 3.5834022608216154,
2282
+ "grad_norm": 1.6194590330123901,
2283
+ "learning_rate": 8.687864644107351e-05,
2284
+ "loss": 3.3562744140625,
2285
+ "step": 3250
2286
+ },
2287
+ {
2288
+ "epoch": 3.594430658946788,
2289
+ "grad_norm": 1.33975350856781,
2290
+ "learning_rate": 8.682030338389732e-05,
2291
+ "loss": 3.3622581481933596,
2292
+ "step": 3260
2293
+ },
2294
+ {
2295
+ "epoch": 3.6054590570719602,
2296
+ "grad_norm": 1.3948160409927368,
2297
+ "learning_rate": 8.676196032672113e-05,
2298
+ "loss": 3.3645614624023437,
2299
+ "step": 3270
2300
+ },
2301
+ {
2302
+ "epoch": 3.6164874551971327,
2303
+ "grad_norm": 1.4972363710403442,
2304
+ "learning_rate": 8.670361726954493e-05,
2305
+ "loss": 3.3713829040527346,
2306
+ "step": 3280
2307
+ },
2308
+ {
2309
+ "epoch": 3.627515853322305,
2310
+ "grad_norm": 1.9456968307495117,
2311
+ "learning_rate": 8.664527421236874e-05,
2312
+ "loss": 3.3617935180664062,
2313
+ "step": 3290
2314
+ },
2315
+ {
2316
+ "epoch": 3.6385442514474775,
2317
+ "grad_norm": 1.8050702810287476,
2318
+ "learning_rate": 8.658693115519254e-05,
2319
+ "loss": 3.359496307373047,
2320
+ "step": 3300
2321
+ },
2322
+ {
2323
+ "epoch": 3.6495726495726495,
2324
+ "grad_norm": 1.294492244720459,
2325
+ "learning_rate": 8.652858809801634e-05,
2326
+ "loss": 3.361173629760742,
2327
+ "step": 3310
2328
+ },
2329
+ {
2330
+ "epoch": 3.660601047697822,
2331
+ "grad_norm": 1.7897614240646362,
2332
+ "learning_rate": 8.647024504084015e-05,
2333
+ "loss": 3.3475852966308595,
2334
+ "step": 3320
2335
+ },
2336
+ {
2337
+ "epoch": 3.6716294458229943,
2338
+ "grad_norm": 1.5647767782211304,
2339
+ "learning_rate": 8.641190198366394e-05,
2340
+ "loss": 3.3594207763671875,
2341
+ "step": 3330
2342
+ },
2343
+ {
2344
+ "epoch": 3.6826578439481663,
2345
+ "grad_norm": 1.3839472532272339,
2346
+ "learning_rate": 8.635355892648775e-05,
2347
+ "loss": 3.361709976196289,
2348
+ "step": 3340
2349
+ },
2350
+ {
2351
+ "epoch": 3.6936862420733387,
2352
+ "grad_norm": 1.543115258216858,
2353
+ "learning_rate": 8.629521586931155e-05,
2354
+ "loss": 3.349272918701172,
2355
+ "step": 3350
2356
+ },
2357
+ {
2358
+ "epoch": 3.704714640198511,
2359
+ "grad_norm": 1.2722103595733643,
2360
+ "learning_rate": 8.623687281213536e-05,
2361
+ "loss": 3.3600040435791017,
2362
+ "step": 3360
2363
+ },
2364
+ {
2365
+ "epoch": 3.7157430383236836,
2366
+ "grad_norm": 2.396493434906006,
2367
+ "learning_rate": 8.617852975495915e-05,
2368
+ "loss": 3.359762954711914,
2369
+ "step": 3370
2370
+ },
2371
+ {
2372
+ "epoch": 3.726771436448856,
2373
+ "grad_norm": 1.3756037950515747,
2374
+ "learning_rate": 8.612018669778296e-05,
2375
+ "loss": 3.3409027099609374,
2376
+ "step": 3380
2377
+ },
2378
+ {
2379
+ "epoch": 3.737799834574028,
2380
+ "grad_norm": 1.5124824047088623,
2381
+ "learning_rate": 8.606184364060677e-05,
2382
+ "loss": 3.346342849731445,
2383
+ "step": 3390
2384
+ },
2385
+ {
2386
+ "epoch": 3.7488282326992004,
2387
+ "grad_norm": 1.3679585456848145,
2388
+ "learning_rate": 8.600350058343058e-05,
2389
+ "loss": 3.3478328704833986,
2390
+ "step": 3400
2391
+ },
2392
+ {
2393
+ "epoch": 3.759856630824373,
2394
+ "grad_norm": 1.3470197916030884,
2395
+ "learning_rate": 8.594515752625439e-05,
2396
+ "loss": 3.352674865722656,
2397
+ "step": 3410
2398
+ },
2399
+ {
2400
+ "epoch": 3.770885028949545,
2401
+ "grad_norm": 1.4775781631469727,
2402
+ "learning_rate": 8.588681446907818e-05,
2403
+ "loss": 3.3504791259765625,
2404
+ "step": 3420
2405
+ },
2406
+ {
2407
+ "epoch": 3.7819134270747172,
2408
+ "grad_norm": 1.1987943649291992,
2409
+ "learning_rate": 8.582847141190199e-05,
2410
+ "loss": 3.3457687377929686,
2411
+ "step": 3430
2412
+ },
2413
+ {
2414
+ "epoch": 3.7929418251998896,
2415
+ "grad_norm": 1.8007314205169678,
2416
+ "learning_rate": 8.577012835472579e-05,
2417
+ "loss": 3.3557716369628907,
2418
+ "step": 3440
2419
+ },
2420
+ {
2421
+ "epoch": 3.803970223325062,
2422
+ "grad_norm": 1.4193800687789917,
2423
+ "learning_rate": 8.57117852975496e-05,
2424
+ "loss": 3.346666717529297,
2425
+ "step": 3450
2426
+ },
2427
+ {
2428
+ "epoch": 3.8149986214502345,
2429
+ "grad_norm": 1.600216031074524,
2430
+ "learning_rate": 8.56534422403734e-05,
2431
+ "loss": 3.354322814941406,
2432
+ "step": 3460
2433
+ },
2434
+ {
2435
+ "epoch": 3.826027019575407,
2436
+ "grad_norm": 1.6823015213012695,
2437
+ "learning_rate": 8.55950991831972e-05,
2438
+ "loss": 3.3344764709472656,
2439
+ "step": 3470
2440
+ },
2441
+ {
2442
+ "epoch": 3.837055417700579,
2443
+ "grad_norm": 1.8002822399139404,
2444
+ "learning_rate": 8.553675612602101e-05,
2445
+ "loss": 3.338224411010742,
2446
+ "step": 3480
2447
+ },
2448
+ {
2449
+ "epoch": 3.8480838158257513,
2450
+ "grad_norm": 1.019519567489624,
2451
+ "learning_rate": 8.54784130688448e-05,
2452
+ "loss": 3.342393493652344,
2453
+ "step": 3490
2454
+ },
2455
+ {
2456
+ "epoch": 3.8591122139509237,
2457
+ "grad_norm": 1.4397176504135132,
2458
+ "learning_rate": 8.542007001166861e-05,
2459
+ "loss": 3.3416332244873046,
2460
+ "step": 3500
2461
+ },
2462
+ {
2463
+ "epoch": 3.8701406120760957,
2464
+ "grad_norm": 1.398215889930725,
2465
+ "learning_rate": 8.536172695449241e-05,
2466
+ "loss": 3.3455711364746095,
2467
+ "step": 3510
2468
+ },
2469
+ {
2470
+ "epoch": 3.881169010201268,
2471
+ "grad_norm": 1.431221604347229,
2472
+ "learning_rate": 8.530338389731622e-05,
2473
+ "loss": 3.3510116577148437,
2474
+ "step": 3520
2475
+ },
2476
+ {
2477
+ "epoch": 3.8921974083264406,
2478
+ "grad_norm": 1.2339868545532227,
2479
+ "learning_rate": 8.524504084014003e-05,
2480
+ "loss": 3.333365631103516,
2481
+ "step": 3530
2482
+ },
2483
+ {
2484
+ "epoch": 3.903225806451613,
2485
+ "grad_norm": 1.2564575672149658,
2486
+ "learning_rate": 8.518669778296384e-05,
2487
+ "loss": 3.355131912231445,
2488
+ "step": 3540
2489
+ },
2490
+ {
2491
+ "epoch": 3.9142542045767854,
2492
+ "grad_norm": 1.44709050655365,
2493
+ "learning_rate": 8.512835472578765e-05,
2494
+ "loss": 3.352345275878906,
2495
+ "step": 3550
2496
+ },
2497
+ {
2498
+ "epoch": 3.9252826027019574,
2499
+ "grad_norm": 1.0984286069869995,
2500
+ "learning_rate": 8.507001166861144e-05,
2501
+ "loss": 3.3399391174316406,
2502
+ "step": 3560
2503
+ },
2504
+ {
2505
+ "epoch": 3.93631100082713,
2506
+ "grad_norm": 1.521567702293396,
2507
+ "learning_rate": 8.501166861143525e-05,
2508
+ "loss": 3.3333946228027345,
2509
+ "step": 3570
2510
+ },
2511
+ {
2512
+ "epoch": 3.9473393989523022,
2513
+ "grad_norm": 1.3443926572799683,
2514
+ "learning_rate": 8.495332555425905e-05,
2515
+ "loss": 3.3321746826171874,
2516
+ "step": 3580
2517
+ },
2518
+ {
2519
+ "epoch": 3.9583677970774747,
2520
+ "grad_norm": 1.539640188217163,
2521
+ "learning_rate": 8.489498249708285e-05,
2522
+ "loss": 3.335438537597656,
2523
+ "step": 3590
2524
+ },
2525
+ {
2526
+ "epoch": 3.9693961952026466,
2527
+ "grad_norm": 1.123307466506958,
2528
+ "learning_rate": 8.483663943990665e-05,
2529
+ "loss": 3.3397190093994142,
2530
+ "step": 3600
2531
+ },
2532
+ {
2533
+ "epoch": 3.980424593327819,
2534
+ "grad_norm": 1.6037691831588745,
2535
+ "learning_rate": 8.477829638273046e-05,
2536
+ "loss": 3.3357570648193358,
2537
+ "step": 3610
2538
+ },
2539
+ {
2540
+ "epoch": 3.9914529914529915,
2541
+ "grad_norm": 1.6570971012115479,
2542
+ "learning_rate": 8.471995332555425e-05,
2543
+ "loss": 3.341298294067383,
2544
+ "step": 3620
2545
+ },
2546
+ {
2547
+ "epoch": 4.0022056796250345,
2548
+ "grad_norm": 1.4301789999008179,
2549
+ "learning_rate": 8.466161026837806e-05,
2550
+ "loss": 3.3353721618652346,
2551
+ "step": 3630
2552
+ },
2553
+ {
2554
+ "epoch": 4.013234077750207,
2555
+ "grad_norm": 1.539963722229004,
2556
+ "learning_rate": 8.460326721120187e-05,
2557
+ "loss": 3.3291671752929686,
2558
+ "step": 3640
2559
+ },
2560
+ {
2561
+ "epoch": 4.024262475875379,
2562
+ "grad_norm": 1.5195462703704834,
2563
+ "learning_rate": 8.454492415402567e-05,
2564
+ "loss": 3.3193031311035157,
2565
+ "step": 3650
2566
+ },
2567
+ {
2568
+ "epoch": 4.035290874000552,
2569
+ "grad_norm": 1.423514485359192,
2570
+ "learning_rate": 8.448658109684948e-05,
2571
+ "loss": 3.316299057006836,
2572
+ "step": 3660
2573
+ },
2574
+ {
2575
+ "epoch": 4.046319272125724,
2576
+ "grad_norm": 1.4557220935821533,
2577
+ "learning_rate": 8.442823803967328e-05,
2578
+ "loss": 3.310700607299805,
2579
+ "step": 3670
2580
+ },
2581
+ {
2582
+ "epoch": 4.057347670250896,
2583
+ "grad_norm": 1.6277695894241333,
2584
+ "learning_rate": 8.43698949824971e-05,
2585
+ "loss": 3.3296432495117188,
2586
+ "step": 3680
2587
+ },
2588
+ {
2589
+ "epoch": 4.068376068376068,
2590
+ "grad_norm": 1.4026418924331665,
2591
+ "learning_rate": 8.431155192532089e-05,
2592
+ "loss": 3.316411590576172,
2593
+ "step": 3690
2594
+ },
2595
+ {
2596
+ "epoch": 4.079404466501241,
2597
+ "grad_norm": 1.3620136976242065,
2598
+ "learning_rate": 8.42532088681447e-05,
2599
+ "loss": 3.3130718231201173,
2600
+ "step": 3700
2601
+ },
2602
+ {
2603
+ "epoch": 4.090432864626413,
2604
+ "grad_norm": 1.4140877723693848,
2605
+ "learning_rate": 8.419486581096851e-05,
2606
+ "loss": 3.321166229248047,
2607
+ "step": 3710
2608
+ },
2609
+ {
2610
+ "epoch": 4.101461262751585,
2611
+ "grad_norm": 1.3145273923873901,
2612
+ "learning_rate": 8.41365227537923e-05,
2613
+ "loss": 3.3267845153808593,
2614
+ "step": 3720
2615
+ },
2616
+ {
2617
+ "epoch": 4.112489660876758,
2618
+ "grad_norm": 1.1830849647521973,
2619
+ "learning_rate": 8.407817969661611e-05,
2620
+ "loss": 3.315142059326172,
2621
+ "step": 3730
2622
+ },
2623
+ {
2624
+ "epoch": 4.12351805900193,
2625
+ "grad_norm": 1.4326401948928833,
2626
+ "learning_rate": 8.401983663943991e-05,
2627
+ "loss": 3.313446807861328,
2628
+ "step": 3740
2629
+ },
2630
+ {
2631
+ "epoch": 4.134546457127103,
2632
+ "grad_norm": 1.2179306745529175,
2633
+ "learning_rate": 8.396149358226372e-05,
2634
+ "loss": 3.3100254058837892,
2635
+ "step": 3750
2636
+ },
2637
+ {
2638
+ "epoch": 4.145574855252274,
2639
+ "grad_norm": 1.3347259759902954,
2640
+ "learning_rate": 8.390315052508751e-05,
2641
+ "loss": 3.3180007934570312,
2642
+ "step": 3760
2643
+ },
2644
+ {
2645
+ "epoch": 4.156603253377447,
2646
+ "grad_norm": 1.4468998908996582,
2647
+ "learning_rate": 8.384480746791132e-05,
2648
+ "loss": 3.307207489013672,
2649
+ "step": 3770
2650
+ },
2651
+ {
2652
+ "epoch": 4.167631651502619,
2653
+ "grad_norm": 1.5258162021636963,
2654
+ "learning_rate": 8.378646441073512e-05,
2655
+ "loss": 3.3205909729003906,
2656
+ "step": 3780
2657
+ },
2658
+ {
2659
+ "epoch": 4.1786600496277915,
2660
+ "grad_norm": 1.4104669094085693,
2661
+ "learning_rate": 8.372812135355892e-05,
2662
+ "loss": 3.309407424926758,
2663
+ "step": 3790
2664
+ },
2665
+ {
2666
+ "epoch": 4.189688447752964,
2667
+ "grad_norm": 1.4369711875915527,
2668
+ "learning_rate": 8.366977829638273e-05,
2669
+ "loss": 3.3081008911132814,
2670
+ "step": 3800
2671
+ },
2672
+ {
2673
+ "epoch": 4.200716845878136,
2674
+ "grad_norm": 1.2004350423812866,
2675
+ "learning_rate": 8.361143523920654e-05,
2676
+ "loss": 3.3130638122558596,
2677
+ "step": 3810
2678
+ },
2679
+ {
2680
+ "epoch": 4.211745244003309,
2681
+ "grad_norm": 1.2577087879180908,
2682
+ "learning_rate": 8.355309218203035e-05,
2683
+ "loss": 3.312338638305664,
2684
+ "step": 3820
2685
+ },
2686
+ {
2687
+ "epoch": 4.222773642128481,
2688
+ "grad_norm": 1.3649225234985352,
2689
+ "learning_rate": 8.349474912485415e-05,
2690
+ "loss": 3.323046875,
2691
+ "step": 3830
2692
+ },
2693
+ {
2694
+ "epoch": 4.233802040253654,
2695
+ "grad_norm": 1.3110648393630981,
2696
+ "learning_rate": 8.343640606767796e-05,
2697
+ "loss": 3.3168025970458985,
2698
+ "step": 3840
2699
+ },
2700
+ {
2701
+ "epoch": 4.244830438378825,
2702
+ "grad_norm": 1.493674635887146,
2703
+ "learning_rate": 8.337806301050175e-05,
2704
+ "loss": 3.320719909667969,
2705
+ "step": 3850
2706
+ },
2707
+ {
2708
+ "epoch": 4.2558588365039975,
2709
+ "grad_norm": 1.283460259437561,
2710
+ "learning_rate": 8.331971995332556e-05,
2711
+ "loss": 3.3127769470214843,
2712
+ "step": 3860
2713
+ },
2714
+ {
2715
+ "epoch": 4.26688723462917,
2716
+ "grad_norm": 1.4842219352722168,
2717
+ "learning_rate": 8.326137689614936e-05,
2718
+ "loss": 3.3042266845703123,
2719
+ "step": 3870
2720
+ },
2721
+ {
2722
+ "epoch": 4.277915632754342,
2723
+ "grad_norm": 1.1820423603057861,
2724
+ "learning_rate": 8.320303383897316e-05,
2725
+ "loss": 3.3116954803466796,
2726
+ "step": 3880
2727
+ },
2728
+ {
2729
+ "epoch": 4.288944030879515,
2730
+ "grad_norm": 1.5040090084075928,
2731
+ "learning_rate": 8.314469078179697e-05,
2732
+ "loss": 3.310440444946289,
2733
+ "step": 3890
2734
+ },
2735
+ {
2736
+ "epoch": 4.299972429004687,
2737
+ "grad_norm": 1.1614471673965454,
2738
+ "learning_rate": 8.308634772462077e-05,
2739
+ "loss": 3.3075687408447267,
2740
+ "step": 3900
2741
+ },
2742
+ {
2743
+ "epoch": 4.31100082712986,
2744
+ "grad_norm": 1.5577434301376343,
2745
+ "learning_rate": 8.302800466744458e-05,
2746
+ "loss": 3.312149429321289,
2747
+ "step": 3910
2748
+ },
2749
+ {
2750
+ "epoch": 4.322029225255032,
2751
+ "grad_norm": 1.6462024450302124,
2752
+ "learning_rate": 8.296966161026837e-05,
2753
+ "loss": 3.321173095703125,
2754
+ "step": 3920
2755
+ },
2756
+ {
2757
+ "epoch": 4.333057623380204,
2758
+ "grad_norm": 1.302138090133667,
2759
+ "learning_rate": 8.291131855309218e-05,
2760
+ "loss": 3.3210208892822264,
2761
+ "step": 3930
2762
+ },
2763
+ {
2764
+ "epoch": 4.344086021505376,
2765
+ "grad_norm": 1.6717387437820435,
2766
+ "learning_rate": 8.285297549591599e-05,
2767
+ "loss": 3.3135406494140627,
2768
+ "step": 3940
2769
+ },
2770
+ {
2771
+ "epoch": 4.3551144196305485,
2772
+ "grad_norm": 1.5899906158447266,
2773
+ "learning_rate": 8.27946324387398e-05,
2774
+ "loss": 3.31378059387207,
2775
+ "step": 3950
2776
+ },
2777
+ {
2778
+ "epoch": 4.366142817755721,
2779
+ "grad_norm": 1.2071844339370728,
2780
+ "learning_rate": 8.273628938156361e-05,
2781
+ "loss": 3.3018829345703127,
2782
+ "step": 3960
2783
+ },
2784
+ {
2785
+ "epoch": 4.377171215880893,
2786
+ "grad_norm": 1.8953418731689453,
2787
+ "learning_rate": 8.26779463243874e-05,
2788
+ "loss": 3.3119953155517576,
2789
+ "step": 3970
2790
+ },
2791
+ {
2792
+ "epoch": 4.388199614006066,
2793
+ "grad_norm": 1.7741807699203491,
2794
+ "learning_rate": 8.261960326721121e-05,
2795
+ "loss": 3.3027114868164062,
2796
+ "step": 3980
2797
+ },
2798
+ {
2799
+ "epoch": 4.399228012131238,
2800
+ "grad_norm": 1.3921217918395996,
2801
+ "learning_rate": 8.256126021003501e-05,
2802
+ "loss": 3.317920684814453,
2803
+ "step": 3990
2804
+ },
2805
+ {
2806
+ "epoch": 4.410256410256411,
2807
+ "grad_norm": 1.1690531969070435,
2808
+ "learning_rate": 8.250291715285882e-05,
2809
+ "loss": 3.29705810546875,
2810
+ "step": 4000
2811
+ },
2812
+ {
2813
+ "epoch": 4.421284808381582,
2814
+ "grad_norm": 1.3882209062576294,
2815
+ "learning_rate": 8.244457409568261e-05,
2816
+ "loss": 3.304886245727539,
2817
+ "step": 4010
2818
+ },
2819
+ {
2820
+ "epoch": 4.4323132065067545,
2821
+ "grad_norm": 2.1946423053741455,
2822
+ "learning_rate": 8.238623103850642e-05,
2823
+ "loss": 3.3152816772460936,
2824
+ "step": 4020
2825
+ },
2826
+ {
2827
+ "epoch": 4.443341604631927,
2828
+ "grad_norm": 1.517082929611206,
2829
+ "learning_rate": 8.232788798133022e-05,
2830
+ "loss": 3.3114837646484374,
2831
+ "step": 4030
2832
+ },
2833
+ {
2834
+ "epoch": 4.454370002757099,
2835
+ "grad_norm": 1.2431399822235107,
2836
+ "learning_rate": 8.226954492415403e-05,
2837
+ "loss": 3.306407165527344,
2838
+ "step": 4040
2839
+ },
2840
+ {
2841
+ "epoch": 4.465398400882272,
2842
+ "grad_norm": 1.5142467021942139,
2843
+ "learning_rate": 8.221120186697783e-05,
2844
+ "loss": 3.3055789947509764,
2845
+ "step": 4050
2846
+ },
2847
+ {
2848
+ "epoch": 4.476426799007444,
2849
+ "grad_norm": 1.1361483335494995,
2850
+ "learning_rate": 8.215285880980163e-05,
2851
+ "loss": 3.3048805236816405,
2852
+ "step": 4060
2853
+ },
2854
+ {
2855
+ "epoch": 4.487455197132617,
2856
+ "grad_norm": 1.1522105932235718,
2857
+ "learning_rate": 8.209451575262544e-05,
2858
+ "loss": 3.2948539733886717,
2859
+ "step": 4070
2860
+ },
2861
+ {
2862
+ "epoch": 4.498483595257789,
2863
+ "grad_norm": 1.1002084016799927,
2864
+ "learning_rate": 8.203617269544925e-05,
2865
+ "loss": 3.306837463378906,
2866
+ "step": 4080
2867
+ },
2868
+ {
2869
+ "epoch": 4.5095119933829615,
2870
+ "grad_norm": 1.4114456176757812,
2871
+ "learning_rate": 8.197782963827306e-05,
2872
+ "loss": 3.3008705139160157,
2873
+ "step": 4090
2874
+ },
2875
+ {
2876
+ "epoch": 4.520540391508133,
2877
+ "grad_norm": 1.3177834749221802,
2878
+ "learning_rate": 8.191948658109685e-05,
2879
+ "loss": 3.3015769958496093,
2880
+ "step": 4100
2881
+ },
2882
+ {
2883
+ "epoch": 4.5315687896333054,
2884
+ "grad_norm": 1.2859690189361572,
2885
+ "learning_rate": 8.186114352392066e-05,
2886
+ "loss": 3.3012962341308594,
2887
+ "step": 4110
2888
+ },
2889
+ {
2890
+ "epoch": 4.542597187758478,
2891
+ "grad_norm": 1.149977445602417,
2892
+ "learning_rate": 8.180280046674446e-05,
2893
+ "loss": 3.2883323669433593,
2894
+ "step": 4120
2895
+ },
2896
+ {
2897
+ "epoch": 4.55362558588365,
2898
+ "grad_norm": 1.1980609893798828,
2899
+ "learning_rate": 8.174445740956827e-05,
2900
+ "loss": 3.305525207519531,
2901
+ "step": 4130
2902
+ },
2903
+ {
2904
+ "epoch": 4.564653984008823,
2905
+ "grad_norm": 1.2316346168518066,
2906
+ "learning_rate": 8.168611435239207e-05,
2907
+ "loss": 3.296540069580078,
2908
+ "step": 4140
2909
+ },
2910
+ {
2911
+ "epoch": 4.575682382133995,
2912
+ "grad_norm": 1.456752896308899,
2913
+ "learning_rate": 8.162777129521587e-05,
2914
+ "loss": 3.301634979248047,
2915
+ "step": 4150
2916
+ },
2917
+ {
2918
+ "epoch": 4.586710780259168,
2919
+ "grad_norm": 1.5025802850723267,
2920
+ "learning_rate": 8.156942823803968e-05,
2921
+ "loss": 3.303561782836914,
2922
+ "step": 4160
2923
+ },
2924
+ {
2925
+ "epoch": 4.59773917838434,
2926
+ "grad_norm": 1.3021212816238403,
2927
+ "learning_rate": 8.151108518086347e-05,
2928
+ "loss": 3.302195739746094,
2929
+ "step": 4170
2930
+ },
2931
+ {
2932
+ "epoch": 4.608767576509512,
2933
+ "grad_norm": 1.758484125137329,
2934
+ "learning_rate": 8.145274212368728e-05,
2935
+ "loss": 3.298839569091797,
2936
+ "step": 4180
2937
+ },
2938
+ {
2939
+ "epoch": 4.619795974634684,
2940
+ "grad_norm": 1.034860372543335,
2941
+ "learning_rate": 8.139439906651108e-05,
2942
+ "loss": 3.308152770996094,
2943
+ "step": 4190
2944
+ },
2945
+ {
2946
+ "epoch": 4.630824372759856,
2947
+ "grad_norm": 1.233070969581604,
2948
+ "learning_rate": 8.133605600933489e-05,
2949
+ "loss": 3.297984313964844,
2950
+ "step": 4200
2951
+ },
2952
+ {
2953
+ "epoch": 4.641852770885029,
2954
+ "grad_norm": 1.7277765274047852,
2955
+ "learning_rate": 8.12777129521587e-05,
2956
+ "loss": 3.3045902252197266,
2957
+ "step": 4210
2958
+ },
2959
+ {
2960
+ "epoch": 4.652881169010201,
2961
+ "grad_norm": 1.2869057655334473,
2962
+ "learning_rate": 8.12193698949825e-05,
2963
+ "loss": 3.3063819885253904,
2964
+ "step": 4220
2965
+ },
2966
+ {
2967
+ "epoch": 4.663909567135374,
2968
+ "grad_norm": 1.1411103010177612,
2969
+ "learning_rate": 8.116102683780631e-05,
2970
+ "loss": 3.2905479431152345,
2971
+ "step": 4230
2972
+ },
2973
+ {
2974
+ "epoch": 4.674937965260546,
2975
+ "grad_norm": 1.342445969581604,
2976
+ "learning_rate": 8.110268378063011e-05,
2977
+ "loss": 3.2918365478515623,
2978
+ "step": 4240
2979
+ },
2980
+ {
2981
+ "epoch": 4.6859663633857185,
2982
+ "grad_norm": 1.206933617591858,
2983
+ "learning_rate": 8.104434072345392e-05,
2984
+ "loss": 3.303221893310547,
2985
+ "step": 4250
2986
+ },
2987
+ {
2988
+ "epoch": 4.696994761510891,
2989
+ "grad_norm": 1.3959113359451294,
2990
+ "learning_rate": 8.098599766627771e-05,
2991
+ "loss": 3.3067909240722657,
2992
+ "step": 4260
2993
+ },
2994
+ {
2995
+ "epoch": 4.708023159636063,
2996
+ "grad_norm": 1.9725914001464844,
2997
+ "learning_rate": 8.092765460910152e-05,
2998
+ "loss": 3.300902557373047,
2999
+ "step": 4270
3000
+ },
3001
+ {
3002
+ "epoch": 4.719051557761235,
3003
+ "grad_norm": 1.3540401458740234,
3004
+ "learning_rate": 8.086931155192532e-05,
3005
+ "loss": 3.3051612854003904,
3006
+ "step": 4280
3007
+ },
3008
+ {
3009
+ "epoch": 4.730079955886407,
3010
+ "grad_norm": 1.2321784496307373,
3011
+ "learning_rate": 8.081096849474913e-05,
3012
+ "loss": 3.2939830780029298,
3013
+ "step": 4290
3014
+ },
3015
+ {
3016
+ "epoch": 4.74110835401158,
3017
+ "grad_norm": 1.2586874961853027,
3018
+ "learning_rate": 8.075262543757294e-05,
3019
+ "loss": 3.301250457763672,
3020
+ "step": 4300
3021
+ },
3022
+ {
3023
+ "epoch": 4.752136752136752,
3024
+ "grad_norm": 1.1622635126113892,
3025
+ "learning_rate": 8.069428238039673e-05,
3026
+ "loss": 3.29620361328125,
3027
+ "step": 4310
3028
+ },
3029
+ {
3030
+ "epoch": 4.7631651502619246,
3031
+ "grad_norm": 1.204060673713684,
3032
+ "learning_rate": 8.063593932322054e-05,
3033
+ "loss": 3.301993179321289,
3034
+ "step": 4320
3035
+ },
3036
+ {
3037
+ "epoch": 4.774193548387097,
3038
+ "grad_norm": 1.2462209463119507,
3039
+ "learning_rate": 8.057759626604434e-05,
3040
+ "loss": 3.294866180419922,
3041
+ "step": 4330
3042
+ },
3043
+ {
3044
+ "epoch": 4.785221946512269,
3045
+ "grad_norm": 1.086969256401062,
3046
+ "learning_rate": 8.051925320886814e-05,
3047
+ "loss": 3.290216827392578,
3048
+ "step": 4340
3049
+ },
3050
+ {
3051
+ "epoch": 4.796250344637441,
3052
+ "grad_norm": 1.6685938835144043,
3053
+ "learning_rate": 8.046091015169195e-05,
3054
+ "loss": 3.291016387939453,
3055
+ "step": 4350
3056
+ },
3057
+ {
3058
+ "epoch": 4.807278742762613,
3059
+ "grad_norm": 1.279870867729187,
3060
+ "learning_rate": 8.040256709451576e-05,
3061
+ "loss": 3.2892833709716798,
3062
+ "step": 4360
3063
+ },
3064
+ {
3065
+ "epoch": 4.818307140887786,
3066
+ "grad_norm": 1.083748459815979,
3067
+ "learning_rate": 8.034422403733956e-05,
3068
+ "loss": 3.298921585083008,
3069
+ "step": 4370
3070
+ },
3071
+ {
3072
+ "epoch": 4.829335539012958,
3073
+ "grad_norm": 1.215922474861145,
3074
+ "learning_rate": 8.028588098016337e-05,
3075
+ "loss": 3.290033721923828,
3076
+ "step": 4380
3077
+ },
3078
+ {
3079
+ "epoch": 4.840363937138131,
3080
+ "grad_norm": 1.4302809238433838,
3081
+ "learning_rate": 8.022753792298718e-05,
3082
+ "loss": 3.289847564697266,
3083
+ "step": 4390
3084
+ },
3085
+ {
3086
+ "epoch": 4.851392335263303,
3087
+ "grad_norm": 1.2112072706222534,
3088
+ "learning_rate": 8.016919486581097e-05,
3089
+ "loss": 3.289959716796875,
3090
+ "step": 4400
3091
+ },
3092
+ {
3093
+ "epoch": 4.8624207333884755,
3094
+ "grad_norm": 1.5671532154083252,
3095
+ "learning_rate": 8.011085180863478e-05,
3096
+ "loss": 3.2878067016601564,
3097
+ "step": 4410
3098
+ },
3099
+ {
3100
+ "epoch": 4.873449131513648,
3101
+ "grad_norm": 1.6009515523910522,
3102
+ "learning_rate": 8.005250875145858e-05,
3103
+ "loss": 3.286448669433594,
3104
+ "step": 4420
3105
+ },
3106
+ {
3107
+ "epoch": 4.88447752963882,
3108
+ "grad_norm": 1.324246883392334,
3109
+ "learning_rate": 7.999416569428238e-05,
3110
+ "loss": 3.2920913696289062,
3111
+ "step": 4430
3112
+ },
3113
+ {
3114
+ "epoch": 4.895505927763992,
3115
+ "grad_norm": 1.2959766387939453,
3116
+ "learning_rate": 7.993582263710618e-05,
3117
+ "loss": 3.3023696899414063,
3118
+ "step": 4440
3119
+ },
3120
+ {
3121
+ "epoch": 4.906534325889164,
3122
+ "grad_norm": 1.0577853918075562,
3123
+ "learning_rate": 7.987747957992999e-05,
3124
+ "loss": 3.2800533294677736,
3125
+ "step": 4450
3126
+ },
3127
+ {
3128
+ "epoch": 4.917562724014337,
3129
+ "grad_norm": 1.5235346555709839,
3130
+ "learning_rate": 7.98191365227538e-05,
3131
+ "loss": 3.2923828125,
3132
+ "step": 4460
3133
+ },
3134
+ {
3135
+ "epoch": 4.928591122139509,
3136
+ "grad_norm": 1.4431898593902588,
3137
+ "learning_rate": 7.97607934655776e-05,
3138
+ "loss": 3.2987926483154295,
3139
+ "step": 4470
3140
+ },
3141
+ {
3142
+ "epoch": 4.9396195202646815,
3143
+ "grad_norm": 1.6988770961761475,
3144
+ "learning_rate": 7.97024504084014e-05,
3145
+ "loss": 3.2937850952148438,
3146
+ "step": 4480
3147
+ },
3148
+ {
3149
+ "epoch": 4.950647918389854,
3150
+ "grad_norm": 1.3248101472854614,
3151
+ "learning_rate": 7.964410735122521e-05,
3152
+ "loss": 3.284267044067383,
3153
+ "step": 4490
3154
+ },
3155
+ {
3156
+ "epoch": 4.961676316515026,
3157
+ "grad_norm": 1.3350111246109009,
3158
+ "learning_rate": 7.958576429404902e-05,
3159
+ "loss": 3.2882057189941407,
3160
+ "step": 4500
3161
+ },
3162
+ {
3163
+ "epoch": 4.972704714640199,
3164
+ "grad_norm": 1.434801459312439,
3165
+ "learning_rate": 7.952742123687282e-05,
3166
+ "loss": 3.28388671875,
3167
+ "step": 4510
3168
+ },
3169
+ {
3170
+ "epoch": 4.983733112765371,
3171
+ "grad_norm": 1.0145658254623413,
3172
+ "learning_rate": 7.946907817969662e-05,
3173
+ "loss": 3.288136291503906,
3174
+ "step": 4520
3175
+ },
3176
+ {
3177
+ "epoch": 4.994761510890543,
3178
+ "grad_norm": 1.1575376987457275,
3179
+ "learning_rate": 7.941073512252042e-05,
3180
+ "loss": 3.2970806121826173,
3181
+ "step": 4530
3182
+ }
3183
+ ],
3184
+ "logging_steps": 10,
3185
+ "max_steps": 18140,
3186
+ "num_input_tokens_seen": 0,
3187
+ "num_train_epochs": 20,
3188
+ "save_steps": 500,
3189
+ "stateful_callbacks": {
3190
+ "TrainerControl": {
3191
+ "args": {
3192
+ "should_epoch_stop": false,
3193
+ "should_evaluate": false,
3194
+ "should_log": false,
3195
+ "should_save": true,
3196
+ "should_training_stop": false
3197
+ },
3198
+ "attributes": {}
3199
+ }
3200
+ },
3201
+ "total_flos": 1805561643270144.0,
3202
+ "train_batch_size": 1,
3203
+ "trial_name": null,
3204
+ "trial_params": null
3205
+ }
output_qwen3_plain_ar/checkpoint-4535/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-5442/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-5442/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-5442/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step5442
output_qwen3_plain_ar/checkpoint-5442/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
output_qwen3_plain_ar/checkpoint-5442/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-6349/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-6349/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-6349/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step6349
output_qwen3_plain_ar/checkpoint-6349/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
output_qwen3_plain_ar/checkpoint-6349/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-7256/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-7256/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-7256/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step7256
output_qwen3_plain_ar/checkpoint-7256/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
output_qwen3_plain_ar/checkpoint-7256/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
output_qwen3_plain_ar/checkpoint-8163/config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "dtype": "bfloat16",
9
+ "eos_token_id": 151645,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention",
42
+ "full_attention",
43
+ "full_attention"
44
+ ],
45
+ "magel_chord_dropout_trigger_prob": 0.6,
46
+ "magel_num_audio_token": 16384,
47
+ "magel_structure_dropout_trigger_prob": 0.6,
48
+ "max_position_embeddings": 40960,
49
+ "max_window_layers": 28,
50
+ "model_type": "qwen3",
51
+ "num_attention_heads": 16,
52
+ "num_hidden_layers": 28,
53
+ "num_key_value_heads": 8,
54
+ "pad_token_id": null,
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_parameters": {
57
+ "rope_theta": 1000000,
58
+ "rope_type": "default"
59
+ },
60
+ "sliding_window": null,
61
+ "tie_word_embeddings": true,
62
+ "transformers_version": "5.4.0",
63
+ "use_cache": false,
64
+ "use_sliding_window": false,
65
+ "vocab_size": 168056
66
+ }
output_qwen3_plain_ar/checkpoint-8163/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "5.4.0"
13
+ }
output_qwen3_plain_ar/checkpoint-8163/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step8163
output_qwen3_plain_ar/checkpoint-8163/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff