Upload folder using huggingface_hub
Browse files- .gitattributes +5 -0
- RM-EN-01-30-2026/code/main.py +394 -0
- RM-EN-01-30-2026/code/model_utils.py +177 -0
- RM-EN-01-30-2026/code/raw_datasets.py +828 -0
- RM-EN-01-30-2026/code/reward_model.py +204 -0
- RM-EN-01-30-2026/data/rm_eval.jsonl +0 -0
- RM-EN-01-30-2026/data/rm_train.jsonl +3 -0
- RM-EN-01-30-2026/model/chat_template.jinja +89 -0
- RM-EN-01-30-2026/model/config.json +73 -0
- RM-EN-01-30-2026/model/model.safetensors +3 -0
- RM-EN-01-30-2026/model/tokenizer.json +3 -0
- RM-EN-01-30-2026/model/tokenizer_config.json +30 -0
- RM-EN-01-30-2026/model/training.log +0 -0
- RM-EN-01-30-2026/scripts/run_qwen3-4b.sh +27 -0
- SFT-EN-01-29-2026/README.md +25 -0
- SFT-EN-01-29-2026/code/data_utils.py +629 -0
- SFT-EN-01-29-2026/code/main.py +866 -0
- SFT-EN-01-29-2026/code/model_utils.py +168 -0
- SFT-EN-01-29-2026/code/prompt_eval.py +146 -0
- SFT-EN-01-29-2026/code/raw_datasets.py +828 -0
- SFT-EN-01-29-2026/code/utils.py +384 -0
- SFT-EN-01-29-2026/data/dev.jsonl +0 -0
- SFT-EN-01-29-2026/data/eval.jsonl +0 -0
- SFT-EN-01-29-2026/data/train.jsonl +3 -0
- SFT-EN-01-29-2026/model/chat_template.jinja +89 -0
- SFT-EN-01-29-2026/model/config.json +72 -0
- SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0 +3 -0
- SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0 +3 -0
- SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0 +3 -0
- SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0 +3 -0
- SFT-EN-01-29-2026/model/model.safetensors +3 -0
- SFT-EN-01-29-2026/model/tokenizer.json +3 -0
- SFT-EN-01-29-2026/model/tokenizer_config.json +30 -0
- SFT-EN-01-29-2026/model/training.log +317 -0
- SFT-EN-01-29-2026/scripts/run_qwen3-4b.sh +36 -0
- sft_model_backup/chat_template.jinja +89 -0
- sft_model_backup/config.json +72 -0
- sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0 +3 -0
- sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0 +3 -0
- sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0 +3 -0
- sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0 +3 -0
- sft_model_backup/model.safetensors +3 -0
- sft_model_backup/tokenizer.json +3 -0
- sft_model_backup/tokenizer_config.json +30 -0
- sft_model_backup/training.log +317 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
RM-EN-01-30-2026/data/rm_train.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
RM-EN-01-30-2026/model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
SFT-EN-01-29-2026/data/train.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
SFT-EN-01-29-2026/model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
sft_model_backup/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
RM-EN-01-30-2026/code/main.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright (c) Microsoft Corporation.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
# DeepSpeed Team
|
| 6 |
+
#!/usr/bin/env python
|
| 7 |
+
# Copyright (c) Microsoft Corporation.
|
| 8 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 9 |
+
|
| 10 |
+
# DeepSpeed Team
|
| 11 |
+
|
| 12 |
+
#!/usr/bin/env python
|
| 13 |
+
# Copyright (c) Microsoft Corporation.
|
| 14 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 15 |
+
|
| 16 |
+
# DeepSpeed Team
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import math
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
| 24 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
+
|
| 26 |
+
from transformers import (
|
| 27 |
+
SchedulerType,
|
| 28 |
+
get_scheduler,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
import deepspeed
|
| 32 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
| 33 |
+
from deepspeed.accelerator import get_accelerator
|
| 34 |
+
|
| 35 |
+
from dschat.utils.model.model_utils import create_critic_model
|
| 36 |
+
from dschat.utils.data.data_utils import create_prompt_dataset, DataCollatorReward
|
| 37 |
+
from dschat.utils.utils import print_rank_0, to_device, save_hf_format, save_hf_format_safetensors, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
|
| 38 |
+
from dschat.utils.ds_utils import get_train_ds_config
|
| 39 |
+
from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
|
| 40 |
+
|
| 41 |
+
def parse_args():
    """Parse command-line arguments for reward-model (RLHF step 2) training."""
    parser = argparse.ArgumentParser(
        description=
        "Finetune a transformers model on a causal language modeling task")

    # --- Data arguments -----------------------------------------------------
    parser.add_argument('--data_path',
                        nargs='*',
                        default=['Dahoas/rm-static'],
                        help='Path to the training dataset. Accepted format:'
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-path dataset2-path ...')
    parser.add_argument('--data_split',
                        type=str,
                        default='2,4,4',
                        help='Comma-separated list of proportions for training'
                        'phase 1, 2, and 3 data. For example the split `6,2,2`'
                        'will use 60%% of data for phase 1, 20%% for phase 2'
                        'and 20%% for phase 3.')
    parser.add_argument(
        '--data_output_path',
        type=str,
        default='/tmp/data_files/',
        help=
        'Where to store the data-related files such as shuffle index. This needs to be on a local storage of a node (not on a shared storage)'
    )

    # --- Model arguments ----------------------------------------------------
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument(
        "--num_padding_at_beginning",
        type=int,
        default=1,
        help=
        "OPT model has a fixed number (1) of padding tokens at the beginning of the input. We did not see this in other models but keep it as an option for now."
    )

    # --- Optimization arguments ---------------------------------------------
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.")
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the evaluation dataloader.")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=512,
                        help="The maximum sequence length.")
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help=
        "Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--num_train_epochs",
                        type=int,
                        default=1,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--lr_scheduler_type",
                        type=SchedulerType,
                        default="cosine",
                        choices=[
                            "linear", "cosine", "cosine_with_restarts",
                            "polynomial", "constant", "constant_with_warmup"
                        ],
                        help="The scheduler type to use.")
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir",
                        type=str,
                        default=None,
                        help="Where to store the model.")
    parser.add_argument("--seed",
                        type=int,
                        default=1234,
                        help="A seed for reproducible training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--gradient_checkpointing',
                        action='store_true',
                        help='Enable HF gradient checkpointing for model.')
    parser.add_argument('--disable_dropout',
                        action='store_true',
                        help='Disable the dropout of the model.')

    # --- DeepSpeed features -------------------------------------------------
    parser.add_argument('--offload',
                        action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp16',
                        choices=['fp16', 'bf16'],
                        help='Training data type')
    parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Actor model (and clones).')

    # --- LoRA for efficient training ----------------------------------------
    parser.add_argument("--lora_dim",
                        type=int,
                        default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--lora_module_name",
                        type=str,
                        default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument('--only_optimize_lora',
                        action='store_true',
                        help='Only optimize the LoRA parameters.')
    parser.add_argument(
        "--lora_learning_rate",
        type=float,
        default=5e-4,
        help=
        "Initial LoRA learning rate (after the potential warmup period) to use.")

    # --- Tensorboard logging ------------------------------------------------
    parser.add_argument('--enable_tensorboard',
                        action='store_true',
                        help='Enable tensorboard logging')
    parser.add_argument('--tensorboard_path',
                        type=str,
                        default="step2_tensorboard")

    # --- Loss printing ------------------------------------------------------
    parser.add_argument('--print_loss',
                        action='store_true',
                        help='Prints loss at each step.')

    # Let DeepSpeed register its own launcher/config flags on the same parser.
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()
| 201 |
+
|
| 202 |
+
def main():
    """Train and evaluate a reward model (RLHF step 2) under DeepSpeed.

    Pipeline: parse args -> set up distributed backend -> build critic/reward
    model -> build pairwise (chosen/rejected) dataloaders -> train with the
    DeepSpeed engine -> evaluate ranking accuracy -> save the final model.
    """
    args = parse_args()

    if args.local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(args.local_rank)
        device = torch.device(get_accelerator().device_name(), args.local_rank)
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        # torch.distributed.init_process_group(backend='nccl')
        deepspeed.init_distributed()

    args.global_rank = torch.distributed.get_rank()

    ds_config = get_train_ds_config(offload=args.offload,
                                    dtype=args.dtype,
                                    stage=args.zero_stage,
                                    enable_tensorboard=args.enable_tensorboard,
                                    tb_path=args.tensorboard_path,
                                    tb_name="step2_model")
    ds_config['train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
    ds_config['train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size() * args.gradient_accumulation_steps

    set_random_seed(args.seed)
    torch.distributed.barrier()

    tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=True)
    # The critic model is essentially a copy of the reward model: both are
    # initialized from the same pretrained backbone.
    rm_model = create_critic_model(args.model_name_or_path,
                                   tokenizer,
                                   ds_config,
                                   args.num_padding_at_beginning,
                                   disable_dropout=args.disable_dropout)

    if args.lora_dim > 0:
        rm_model = convert_linear_layer_to_lora(rm_model,
                                                args.lora_module_name,
                                                args.lora_dim)
        if args.only_optimize_lora:
            rm_model = only_optimize_lora_parameters(rm_model)
            rm_model = make_model_gradient_checkpointing_compatible(rm_model)

    # Phase 2 of the pipeline: reward-model training data.
    train_phase = 2
    train_dataset, eval_dataset = create_prompt_dataset(args.local_rank, args.data_path, args.data_split,
                                                        args.data_output_path, train_phase, args.seed,
                                                        tokenizer, args.max_seq_len)
    # DataCollatorReward (utils/data/data_utils.py) stacks chosen and rejected
    # sequences into a single batch.
    data_collator = DataCollatorReward()
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
        eval_sampler = SequentialSampler(eval_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
        eval_sampler = DistributedSampler(eval_dataset)
    # BUGFIX: a stray `eval_sampler = SequentialSampler(eval_dataset)` used to
    # follow this if/else, silently replacing the DistributedSampler so every
    # rank evaluated the full eval set. It has been removed.
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=data_collator,
                                  sampler=train_sampler,
                                  batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=data_collator,
                                 sampler=eval_sampler,
                                 batch_size=args.per_device_eval_batch_size)

    def evaluation_reward(model, eval_dataloader):
        """Return (mean chosen score, ranking accuracy) on the eval set.

        Accuracy counts a pair as correct when the chosen response scores
        higher than the rejected one. Evaluation stops after 100 batches to
        keep it fast.
        """
        model.eval()
        correct_predictions = 0
        total_predictions = 0
        scores = 0
        for step, batch in enumerate(eval_dataloader):
            batch = to_device(batch, device)
            # No gradients are needed during evaluation.
            with torch.no_grad():
                outputs = model(**batch)
            '''
            outputs: {
                'loss': tensor(),
                'chosen_mean_scores': tensor(batch_size,),
                'rejected_mean_scores': tensor(batch_size,)
            }
            '''
            # chosen.shape: (batch_size,), rejected.shape: (batch_size,)
            chosen = outputs["chosen_mean_scores"]
            rejected = outputs["rejected_mean_scores"]
            # A prediction is correct when the chosen response outscores the
            # rejected one, i.e. the pairwise ranking is right.
            correct_predictions += (chosen > rejected).sum()
            total_predictions += chosen.shape[0]
            # Accumulate the per-step mean chosen score.
            scores += outputs["chosen_mean_scores"].mean().float()
            if step == 99:  # For faster evaluation and debugging
                break
        acc = correct_predictions / total_predictions
        scores = scores / (step + 1)
        try:
            # Average the statistics across data-parallel ranks.
            acc = get_all_reduce_mean(acc).item()
            scores = get_all_reduce_mean(scores).item()
        except Exception:
            # BUGFIX: was a bare `except:`; keep the best-effort behavior for
            # single-process runs but stop swallowing KeyboardInterrupt etc.
            pass
        return scores, acc

    # Split parameters into a weight-decayed group and a non-decayed group.
    optimizer_grouped_parameters = get_optimizer_grouped_parameters(rm_model,
                                                                    args.weight_decay,
                                                                    args.lora_learning_rate)
    AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
    optimizer = AdamOptimizer(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              betas=(0.9, 0.95))

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(name=args.lr_scheduler_type,
                                 optimizer=optimizer,
                                 num_warmup_steps=args.num_warmup_steps,
                                 num_training_steps=args.num_train_epochs *
                                 num_update_steps_per_epoch)

    # Wrap model, optimizer and scheduler in the DeepSpeed engine.
    rm_model, optimizer, _, lr_scheduler = deepspeed.initialize(model=rm_model,
                                                                optimizer=optimizer,
                                                                args=args,
                                                                config=ds_config,
                                                                lr_scheduler=lr_scheduler,
                                                                dist_init_required=True)

    if args.gradient_checkpointing:
        rm_model.gradient_checkpointing_enable()

    print_rank_0("***** Running training *****", args.global_rank)

    # Baseline evaluation before any training.
    print_rank_0(f"***** Evaluating reward, Epoch {0}/{args.num_train_epochs} *****", args.global_rank)
    reward_score, acc = evaluation_reward(rm_model, eval_dataloader)
    print_rank_0(f"chosen_last_scores (higher is better) : {reward_score}, acc (higher is better) : {acc}", args.global_rank)

    for epoch in range(args.num_train_epochs):
        print_rank_0(f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}", args.global_rank)
        rm_model.train()
        mean_loss = 0
        for step, batch in enumerate(train_dataloader):
            batch = to_device(batch, device)
            # Forward pass produces the pairwise ranking loss.
            outputs = rm_model(**batch, use_cache=False)
            '''
            outputs: {
                'loss': tensor(),
                'chosen_mean_scores': tensor(batch_size,),
                'rejected_mean_scores': tensor(batch_size,)
            }
            '''
            loss = outputs["loss"]
            # Backward + optimizer step through the DeepSpeed engine.
            rm_model.backward(loss)
            rm_model.step()
            mean_loss += loss.item()
        print_rank_0(f"Epoch {epoch+1}/{args.num_train_epochs} with loss {mean_loss/(step+1)}", args.global_rank)
        print_rank_0(f"***** Evaluating reward, Epoch {epoch+1}/{args.num_train_epochs} *****", args.global_rank)
        reward_score, acc = evaluation_reward(rm_model, eval_dataloader)
        print_rank_0(f"chosen_last_scores (higher is better) : {reward_score}, acc (higher is better) : {acc}", args.global_rank)
        rm_model.tput_timer.update_epoch_count()

    if args.output_dir is not None:
        print_rank_0('saving model ...', args.global_rank)
        # Fold LoRA adapters back into plain linear layers before saving.
        rm_model = convert_lora_to_linear_layer(rm_model)

        if args.global_rank == 0:
            # save_hf_format(rm_model, tokenizer, args)
            # Qwen3 checkpoints are stored as safetensors, so use the
            # safetensors-aware save helper instead of save_hf_format.
            save_hf_format_safetensors(rm_model, tokenizer, args)
        if args.zero_stage == 3:
            # ZeRO-3 shards parameters across ranks; gather them for saving.
            save_zero_three_model(rm_model,
                                  args.global_rank,
                                  args.output_dir,
                                  zero_stage=args.zero_stage)


if __name__ == '__main__':
    main()
|
RM-EN-01-30-2026/code/model_utils.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# DeepSpeed Team
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import math
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import (
|
| 11 |
+
AutoConfig,
|
| 12 |
+
AutoModel,
|
| 13 |
+
)
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
from transformers.integrations import HfDeepSpeedConfig
|
| 16 |
+
|
| 17 |
+
from .reward_model import RewardModel
|
| 18 |
+
from ..utils import load_state_dict_into_model
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def configure_dropout(model_config, dropout):
    """Override every dropout-related field on *model_config* with *dropout*.

    Checks the four attribute names transformers configs commonly use and
    only touches the ones the config actually defines. A ``None`` dropout
    leaves the config untouched.
    """
    if dropout is None:
        return
    for key in ('dropout', 'attention_dropout', 'hidden_dropout',
                'activation_dropout'):
        if not hasattr(model_config, key):
            continue
        print(f"Setting model_config.{key} to {dropout}")
        setattr(model_config, key, dropout)
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def causal_lm_model_to_fp32_loss(model):
    """ Convert CausalLM model to calculate loss in fp32 """

    def causal_lm_forward(
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **deprecated_arguments,
    ):
        # LLaMA-family models do not accept head_mask; forward it otherwise.
        extra_kwargs = {} if model.config.model_type == "llama" else {
            "head_mask": head_mask
        }
        # Run the wrapped forward WITHOUT labels so it never computes its own
        # (possibly fp16) loss; we redo the loss below in fp32.
        output = model.__original_forward__(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **extra_kwargs)

        return_dict = isinstance(output, dict)
        lm_logits = output.logits if return_dict else output[0]
        loss = None
        if labels is not None:
            # Move labels to the logits' device to support model parallelism.
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict token n; cast logits to fp32
            # before the cross-entropy for numerical stability.
            shift_logits = lm_logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            loss = torch.nn.CrossEntropyLoss()(
                shift_logits.view(batch_size * seq_length, vocab_size),
                shift_labels.view(batch_size * seq_length))

        if not return_dict:
            # Tuple-style output: prepend the fp32 loss when present.
            return ((loss, ) + output) if loss is not None else output

        output.loss = loss
        return output

    # Keep a handle on the original forward and install the wrapper.
    model.__original_forward__ = model.forward
    model.forward = causal_lm_forward
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    dropout=None):
    """Instantiate a HF causal-LM backbone and align it with *tokenizer*.

    Args:
        model_class: HF model class used when building from config (RLHF path).
        model_name_or_path: checkpoint path or hub model id.
        tokenizer: tokenizer whose eos/pad ids and vocab size are applied.
        ds_config: optional DeepSpeed config; a ZeRO-3 config activates the
            HfDeepSpeedConfig integration so weights shard at load time.
        rlhf_training: when True, build from config only (weights are loaded
            later by create_critic_model) instead of from_pretrained.
        dropout: optional dropout probability pushed into the model config.

    Returns:
        The instantiated model with pad/end token ids set and its embedding
        table resized to a multiple of 8.
    """
    model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    configure_dropout(model_config, dropout)

    # Note: dschf is defined in function scope to avoid global effects
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    if rlhf_training:
        # the weight loading is handled by create critic model
        # BUGFIX: no_init_weights was used without being imported, which
        # raised NameError on every rlhf_training=True call.
        from transformers.modeling_utils import no_init_weights
        with no_init_weights():
            model = model_class.from_config(model_config)
    else:
        from transformers import AutoModelForCausalLM as _AutoModel
        model = _AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map=None)

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.resize_token_embeddings(int(
        8 *
        math.ceil(len(tokenizer) / 8.0)))  # make the vocab size multiple of 8

    return model
|
| 121 |
+
|
| 122 |
+
def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False,
                        disable_dropout=False,
                        zero_stage=0):
    """Build a RewardModel (critic) on top of a pretrained transformer.

    Loads the full causal LM via create_hf_model, extracts its base
    transformer (so forward returns hidden states instead of logits), wraps
    it in RewardModel, and — in the RLHF phase — loads reward-model weights
    from ``pytorch_model.bin`` in a ZeRO-3-compatible way.

    Args:
        model_name_or_path: checkpoint dir or hub id.
        tokenizer: tokenizer shared with the backbone.
        ds_config: DeepSpeed config forwarded to create_hf_model.
        num_padding_at_beginning: leading pad-token count (OPT quirk).
        rlhf_training: when True, also load the trained critic checkpoint.
        disable_dropout: when True, force all dropout probabilities to 0.
        zero_stage: ZeRO stage used for state-dict loading compatibility.
    """
    # BUGFIX: the boolean flag used to be forwarded directly as the `dropout`
    # probability. Since `False is not None`, configure_dropout() then set
    # every dropout field to False (i.e. 0) even when dropout should stay on.
    # Map the flag to an explicit probability / None instead.
    dropout = 0.0 if disable_dropout else None

    start = time.time()
    # Load via AutoModelForCausalLM, then peel off the base transformer.
    from transformers import AutoModelForCausalLM
    full_model = create_hf_model(AutoModelForCausalLM, model_name_or_path,
                                 tokenizer, ds_config, rlhf_training, dropout)
    # The base transformer returns hidden_states rather than logits.
    if hasattr(full_model, 'model'):
        critic_model = full_model.model  # Qwen3, LLaMA, etc.
    elif hasattr(full_model, 'transformer'):
        critic_model = full_model.transformer  # GPT-2, etc.
    else:
        critic_model = full_model
    end = time.time()
    # NOTE: when running the standalone step-2 eval script (run_eval.sh) this
    # rank check may fail if torch.distributed is uninitialized; comment out
    # the next two lines in that case.
    if torch.distributed.get_rank() == 0:
        print(f"> Creating model from_config took {end - start} seconds")

    critic_model = RewardModel(critic_model,
                               tokenizer,
                               num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        # load critic model from checkpoint
        if not os.path.isdir(model_name_or_path):
            model_name_or_path = snapshot_download(model_name_or_path)
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(model_ckpt_path), f"Cannot find model checkpoint at {model_ckpt_path}"

        start = time.time()
        model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"> torch.load took {end - start} seconds")

        # load critic model from checkpoint with zero-stage 3 compatibility
        # this functionality may be moved to DS checkpoint load API in future
        start = time.time()
        load_state_dict_into_model(critic_model,
                                   model_ckpt_state_dict,
                                   "",
                                   zero_stage=zero_stage)
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"> Loading model state dict took {end - start} seconds")

    return critic_model
|
RM-EN-01-30-2026/code/raw_datasets.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
from datasets import DatasetDict
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
# DeepSpeed Team
|
| 7 |
+
from datasets import load_dataset, load_from_disk
|
| 8 |
+
from torch.utils.data import Subset
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# The template prompt dataset class that all new dataset porting needs to
|
| 13 |
+
# follow in order to have a unified API and unified data format.
|
| 14 |
+
class PromptRawDataset(object):
    """Common interface for raw preference/prompt datasets.

    Every concrete dataset subclasses this so downstream code can consume
    prompts and chosen/rejected responses through one unified API. The
    expected formatting contract is noted on each getter; subclasses
    override all of them.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        # Bookkeeping shared by all dataset wrappers.
        self.output_path = output_path
        self.seed = seed
        self.local_rank = local_rank
        #if os.path.exists(dataset_name):
        #    self.raw_datasets = load_from_disk(dataset_name)
        if dataset_name != 'local/jsonfile':
            #self.raw_datasets = load_dataset(dataset_name)
            # Hub loading is disabled here; raw_datasets is expected to be
            # populated elsewhere. For 'local/jsonfile' the subclass sets
            # it itself after calling super().__init__().
            self.raw_datasets = None

    def get_train_data(self):
        """Return the training split (subclass responsibility)."""
        return None

    def get_eval_data(self):
        """Return the evaluation split (subclass responsibility)."""
        return None

    # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
    def get_prompt(self, sample):
        return None

    # The chosen response should be in the format of: " " + actual_response_sentence
    def get_chosen(self, sample):
        return None

    # The rejected response should be in the format of: " " + actual_response_sentence
    # If the dataset does not have rejected response, return None
    def get_rejected(self, sample):
        return None

    def get_prompt_and_chosen(self, sample):
        """Return the prompt immediately followed by the chosen response."""
        return None

    def get_prompt_and_rejected(self, sample):
        """Return the prompt immediately followed by the rejected response."""
        return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# English dataset
|
| 53 |
+
class DahoasRmstaticDataset(PromptRawDataset):
    """Wrapper for "Dahoas/rm-static".

    Samples already carry a formatted prompt plus chosen and rejected
    responses, so every getter is a direct field lookup.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/rm-static"
        self.dataset_name_clean = "Dahoas_rm_static"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        # The dataset's "test" split doubles as the evaluation set.
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return "".join((sample["prompt"], sample["chosen"]))

    def get_prompt_and_rejected(self, sample):
        return "".join((sample["prompt"], sample["rejected"]))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# English dataset
|
| 83 |
+
class DahoasFullhhrlhfDataset(PromptRawDataset):
    """Wrapper for "Dahoas/full-hh-rlhf"; samples are pre-formatted."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/full-hh-rlhf"
        self.dataset_name_clean = "Dahoas_full_hh_rlhf"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        # Evaluation uses the published "test" split.
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        prompt = sample["prompt"]
        return prompt + sample["chosen"]

    def get_prompt_and_rejected(self, sample):
        prompt = sample["prompt"]
        return prompt + sample["rejected"]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# English dataset
|
| 113 |
+
class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset):
    """Wrapper for "Dahoas/synthetic-instruct-gptj-pairwise".

    Only a "train" split is published, so it is partitioned 9:1 into
    local train/eval subsets via get_raw_dataset_split_index.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise"
        self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise"

    def _train_eval_subset(self, split_index):
        # split_index 0 -> the 9/10 train partition, 1 -> the 1/10 eval
        # partition; the index file is cached under output_path.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        keep = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                           self.dataset_name_clean, self.seed,
                                           "train_eval", "9,1", split_index,
                                           len(full))
        return Subset(full, keep)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        return " Human: " + sample['prompt'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['chosen']

    def get_rejected(self, sample):
        return " " + sample['rejected']

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['prompt'] + " Assistant: " + sample['chosen']

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample['prompt'] + " Assistant: " + sample['rejected']
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# English dataset
|
| 158 |
+
class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset):
    """Wrapper for "yitingxie/rlhf-reward-datasets"."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "yitingxie/rlhf-reward-datasets"
        self.dataset_name_clean = "yitingxie_rlhf_reward_datasets"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        # Prompts here lack the trailing assistant cue, so append it.
        return sample['prompt'] + "Assistant:"

    def get_chosen(self, sample):
        # Keep only the text after the last "Assistant:" marker.
        return sample['chosen'].rsplit("Assistant:", 1)[-1]

    def get_rejected(self, sample):
        return sample['rejected'].rsplit("Assistant:", 1)[-1]

    def get_prompt_and_chosen(self, sample):
        return sample['prompt'] + sample['chosen']

    def get_prompt_and_rejected(self, sample):
        return sample['prompt'] + sample['rejected']
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# English dataset
|
| 188 |
+
class OpenaiWebgptcomparisonsDataset(PromptRawDataset):
    """Wrapper for "openai/webgpt_comparisons".

    Each sample holds a question plus two answers (answer_0/answer_1)
    with numeric preference scores (score_0/score_1). The original code
    repeated the answer-selection and citation-stripping logic in four
    getters; it is factored into ``_select`` / ``_strip_citations`` here,
    with identical behavior (ties on scores favor answer_0 as chosen).
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "openai/webgpt_comparisons"
        self.dataset_name_clean = "openai_webgpt_comparisons"

    def get_train_data(self):
        # Only a "train" split is published; take the 9/10 partition.
        from .data_utils import get_raw_dataset_split_index
        dataset = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1", 0,
                                            len(dataset))
        return Subset(dataset, index)

    def get_eval_data(self):
        # The remaining 1/10 partition of "train" serves as eval.
        from .data_utils import get_raw_dataset_split_index
        dataset = self.raw_datasets["train"]
        index = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                            self.dataset_name_clean,
                                            self.seed, "train_eval", "9,1", 1,
                                            len(dataset))
        return Subset(dataset, index)

    @staticmethod
    def _strip_citations(response):
        # This data has citation square brackets and numbers (e.g., "[1]").
        # Right now we are not doing browser-assisted finetuning, thus we
        # remove these citations to avoid confusing the model.
        response = re.sub(r" [\(\[].*?[\)\]]", "", response)
        return re.sub(r"[\(\[].*?[\)\]]", "", response)

    def _select(self, sample, want_chosen):
        """Return the (citation-stripped) preferred or dispreferred answer.

        score_0 >= score_1 marks answer_0 as preferred; the rejected side
        is the exact complement (score_0 < score_1 -> answer_0 rejected),
        matching the original per-getter conditionals.
        """
        first_wins = float(sample['score_0']) >= float(sample['score_1'])
        use_first = first_wins if want_chosen else not first_wins
        response = sample['answer_0'] if use_first else sample['answer_1']
        return self._strip_citations(response)

    def get_prompt(self, sample):
        return " Human: " + sample['question']['full_text'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + self._select(sample, True)

    def get_rejected(self, sample):
        return " " + self._select(sample, False)

    def get_prompt_and_chosen(self, sample):
        return (" Human: " + sample['question']['full_text'] +
                " Assistant: " + self._select(sample, True))

    def get_prompt_and_rejected(self, sample):
        return (" Human: " + sample['question']['full_text'] +
                " Assistant: " + self._select(sample, False))
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# English dataset
|
| 261 |
+
class StanfordnlpSHPDataset(PromptRawDataset):
    """Wrapper for "stanfordnlp/SHP".

    Each sample holds two human replies (human_ref_A / human_ref_B); the
    ``labels`` field selects the preferred one (1 -> A, otherwise B).
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "stanfordnlp/SHP"
        self.dataset_name_clean = "stanfordnlp_SHP"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def _preferred(self, sample):
        # labels == 1 marks human_ref_A as the better reply.
        return sample["human_ref_A"] if int(
            sample["labels"]) == 1 else sample["human_ref_B"]

    def _dispreferred(self, sample):
        return sample["human_ref_B"] if int(
            sample["labels"]) == 1 else sample["human_ref_A"]

    def get_prompt(self, sample):
        return " Human: " + sample['history'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + self._preferred(sample)

    def get_rejected(self, sample):
        return " " + self._dispreferred(sample)

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['history'] + " Assistant: " + self._preferred(sample)

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample['history'] + " Assistant: " + self._dispreferred(sample)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# English dataset
|
| 307 |
+
class PvduySharegptalpacaoavicunaformatDataset(PromptRawDataset):
    """Wrapper for "pvduy/sharegpt_alpaca_oa_vicuna_format".

    Prompts use USER/ASSISTANT role markers, which are rewritten to the
    Human/Assistant convention. This dataset has no rejected responses.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "pvduy/sharegpt_alpaca_oa_vicuna_format"
        self.dataset_name_clean = "pvduy_sharegpt_alpaca_oa_vicuna_format"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        prompt = sample['prompt']
        if prompt is None or len(prompt) == 0:
            return None
        return prompt.replace("USER", "Human").replace(
            "ASSISTANT", "Assistant")

    def get_chosen(self, sample):
        label = sample['label']
        if label is None or len(label) == 0:
            return None
        return " " + label

    def get_rejected(self, sample):
        # No rejected responses exist in this dataset.
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        if sample['prompt'] is None or sample['label'] is None or len(
                sample['prompt']) == 0 or len(sample['label']) == 0:
            return None
        return sample['prompt'].replace("USER", "Human").replace(
            "ASSISTANT", "Assistant") + " " + sample['label']

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class LocalJsonFileDataset(PromptRawDataset):
    """Dataset backed by local JSON files under ``chat_path``/data.

    Loads train.json / eval.json via the 'json' dataset builder; each
    sample is expected to provide 'prompt', 'chosen' and 'rejected'
    fields (any of which may be null).
    """

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "local/jsonfile"
        self.dataset_name_clean = "jsonfile"
        data_files = {
            "train": chat_path + '/data/train.json',
            "eval": chat_path + '/data/eval.json',
        }
        self.raw_datasets = load_dataset('json', data_files=data_files)

    def get_train_data(self):
        train_split = self.raw_datasets['train']
        if train_split is None:
            return None
        return train_split

    def get_eval_data(self):
        eval_split = self.raw_datasets['eval']
        if eval_split is None:
            return None
        return eval_split

    # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
    def get_prompt(self, sample):
        if sample['prompt'] is None:
            return None
        return " " + sample['prompt']

    # The chosen response should be in the format of: " " + actual_response_sentence
    def get_chosen(self, sample):
        if sample['chosen'] is None:
            return None
        return " " + sample['chosen']

    # The rejected response should be in the format of: " " + actual_response_sentence
    # If the dataset does not have rejected response, return None
    def get_rejected(self, sample):
        if sample['rejected'] is None:
            return None
        return " " + sample['rejected']

    def get_prompt_and_chosen(self, sample):
        if sample['prompt'] is None or sample['chosen'] is None:
            return None
        return " " + sample['prompt'] + " " + sample['chosen']

    def get_prompt_and_rejected(self, sample):
        if sample['prompt'] is None or sample['rejected'] is None:
            return None
        return " " + sample['prompt'] + " " + sample['rejected']
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
# Chinese dataset
|
| 406 |
+
class Wangrui6ZhihuKOLDataset(PromptRawDataset):
    """Wrapper for "wangrui6/Zhihu-KOL" (INSTRUCTION/RESPONSE pairs; no
    rejected responses)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "wangrui6/Zhihu-KOL"
        self.dataset_name_clean = "wangrui6_Zhihu_KOL"

    def _train_eval_subset(self, split_index):
        # Only a "train" split is published; partition it 9:1 into
        # train (split_index 0) and eval (split_index 1) subsets.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        keep = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                           self.dataset_name_clean, self.seed,
                                           "train_eval", "9,1", split_index,
                                           len(full))
        return Subset(full, keep)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        instruction = sample['INSTRUCTION']
        if instruction is None:
            return None
        return " Human: " + instruction + " Assistant:"

    def get_chosen(self, sample):
        response = sample['RESPONSE']
        if response is None:
            return None
        return " " + response

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        if sample['INSTRUCTION'] is None or sample['RESPONSE'] is None:
            return None
        return " Human: " + sample['INSTRUCTION'] + " Assistant: " + sample['RESPONSE']

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# Chinese dataset
|
| 463 |
+
class CohereMiraclzhqueries2212Dataset(PromptRawDataset):
    """Wrapper for "Cohere/miracl-zh-queries-22-12".

    Positive passages serve as chosen responses and negative passages as
    rejected responses for each query.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-zh-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        # Evaluation uses the dataset's "dev" split.
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        return " Human: " + sample['query'] + " Assistant:"

    def get_chosen(self, sample):
        # First positive passage is taken as the preferred response.
        return " " + sample['positive_passages'][0]['text']

    def get_rejected(self, sample):
        return " " + sample['negative_passages'][0]['text']

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['query'] + " Assistant: " + sample[
            'positive_passages'][0]['text']

    def get_prompt_and_rejected(self, sample):
        # Guard against samples with an empty negative_passages list and
        # return None (the convention for a missing rejected response),
        # instead of raising IndexError. This mirrors the guard already
        # present in the -ja- variant of this dataset in this file.
        if len(sample['negative_passages']) > 0:
            return " Human: " + sample['query'] + " Assistant: " + sample[
                'negative_passages'][0]['text']
        return None
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# Chinese dataset
|
| 495 |
+
class HelloSimpleAIHC3ChineseDataset(PromptRawDataset):
    """Wrapper for "Hello-SimpleAI/HC3-Chinese".

    The first human answer acts as the chosen response; the dataset has
    no rejected responses.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Hello-SimpleAI/HC3-Chinese"
        self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese"

    def _train_eval_subset(self, split_index):
        # Only a "train" split exists; partition it 9:1 for train/eval.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        keep = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                           self.dataset_name_clean, self.seed,
                                           "train_eval", "9,1", split_index,
                                           len(full))
        return Subset(full, keep)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        question = sample['question']
        if question is None:
            return None
        return " Human: " + question + " Assistant:"

    def get_chosen(self, sample):
        answer = sample['human_answers'][0]
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        if sample['question'] is None or sample['human_answers'][0] is None:
            return None
        return " Human: " + sample['question'] + " Assistant: " + sample[
            'human_answers'][0]

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# Chinese dataset
|
| 553 |
+
class MkqaChineseDataset(PromptRawDataset):
    """Wrapper for the Chinese (zh_cn) slice of mkqa; no rejected
    responses are available."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Chinese"
        self.dataset_name_clean = "mkqa"

    def _train_eval_subset(self, split_index):
        # mkqa only ships a "train" split; carve out a 9:1 train/eval split.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        keep = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                           self.dataset_name_clean, self.seed,
                                           "train_eval", "9,1", split_index,
                                           len(full))
        return Subset(full, keep)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        query = sample['queries']['zh_cn']
        if query is None:
            return None
        return " Human: " + query + " Assistant:"

    def get_chosen(self, sample):
        answer = sample['answers']['zh_cn'][0]['text']
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        if sample['queries']['zh_cn'] is None or sample['answers']['zh_cn'][0][
                'text'] is None:
            return None
        return " Human: " + sample['queries'][
            'zh_cn'] + " Assistant: " + sample['answers']['zh_cn'][0]['text']

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Japanese dataset
|
| 612 |
+
class MkqaJapaneseDataset(PromptRawDataset):
|
| 613 |
+
|
| 614 |
+
def __init__(self, output_path, seed, local_rank, dataset_name):
|
| 615 |
+
super().__init__(output_path, seed, local_rank, dataset_name)
|
| 616 |
+
self.dataset_name = "mkqa-Japanese"
|
| 617 |
+
self.dataset_name_clean = "mkqa"
|
| 618 |
+
|
| 619 |
+
def get_train_data(self):
|
| 620 |
+
from .data_utils import get_raw_dataset_split_index
|
| 621 |
+
dataset = self.raw_datasets["train"]
|
| 622 |
+
index = get_raw_dataset_split_index(self.local_rank, self.output_path,
|
| 623 |
+
self.dataset_name_clean,
|
| 624 |
+
self.seed, "train_eval", "9,1", 0,
|
| 625 |
+
len(dataset))
|
| 626 |
+
dataset = Subset(dataset, index)
|
| 627 |
+
return dataset
|
| 628 |
+
|
| 629 |
+
def get_eval_data(self):
|
| 630 |
+
from .data_utils import get_raw_dataset_split_index
|
| 631 |
+
dataset = self.raw_datasets["train"]
|
| 632 |
+
index = get_raw_dataset_split_index(self.local_rank, self.output_path,
|
| 633 |
+
self.dataset_name_clean,
|
| 634 |
+
self.seed, "train_eval", "9,1", 1,
|
| 635 |
+
len(dataset))
|
| 636 |
+
dataset = Subset(dataset, index)
|
| 637 |
+
return dataset
|
| 638 |
+
|
| 639 |
+
def get_prompt(self, sample):
|
| 640 |
+
if sample['queries']['ja'] is not None:
|
| 641 |
+
return " Human: " + sample['queries']['ja'] + " Assistant:"
|
| 642 |
+
return None
|
| 643 |
+
|
| 644 |
+
def get_chosen(self, sample):
|
| 645 |
+
if sample['answers']['ja'][0]['text'] is not None:
|
| 646 |
+
return " " + sample['answers']['ja'][0]['text']
|
| 647 |
+
return None
|
| 648 |
+
|
| 649 |
+
def get_rejected(self, sample):
|
| 650 |
+
print(
|
| 651 |
+
f"Warning: dataset {self.dataset_name} does not include rejected response."
|
| 652 |
+
)
|
| 653 |
+
return None
|
| 654 |
+
|
| 655 |
+
def get_prompt_and_chosen(self, sample):
|
| 656 |
+
if sample['queries']['ja'] is not None and sample['answers']['ja'][0][
|
| 657 |
+
'text'] is not None:
|
| 658 |
+
return " Human: " + sample['queries'][
|
| 659 |
+
'ja'] + " Assistant: " + sample['answers']['ja'][0]['text']
|
| 660 |
+
return None
|
| 661 |
+
|
| 662 |
+
def get_prompt_and_rejected(self, sample):
    """mkqa has no rejected responses; warn and return None."""
    print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
    return None
# Japanese dataset
class CohereMiracljaqueries2212Dataset(PromptRawDataset):
    """Cohere MIRACL Japanese queries (22-12).

    Each sample holds a 'query' plus lists of 'positive_passages' and
    'negative_passages'; the first positive passage serves as the chosen
    response and the first negative passage (when present) as the rejected
    one.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-ja-queries-22-12"
        # Filesystem-safe name used for cached split-index files.
        self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        # Prompt format shared by all datasets in this file.
        return " Human: " + sample['query'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['positive_passages'][0]['text']

    def get_rejected(self, sample):
        # Fix: guard against samples without negative passages, matching
        # get_prompt_and_rejected below; previously this raised IndexError
        # on an empty 'negative_passages' list.
        if len(sample['negative_passages']) > 0:
            return " " + sample['negative_passages'][0]['text']
        return None

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['query'] + " Assistant: " + sample[
            'positive_passages'][0]['text']

    def get_prompt_and_rejected(self, sample):
        # Some samples have no negative passages; return None in that case.
        if len(sample['negative_passages']) > 0:
            return " Human: " + sample['query'] + " Assistant: " + sample[
                'negative_passages'][0]['text']
        return None
+
# Japanese dataset
class LmqgQgjaquadDataset(PromptRawDataset):
    """lmqg/qg_jaquad: Japanese question-generation data.

    The 'question' field acts as the prompt and the source 'sentence' as
    the chosen response; the corpus has no rejected responses.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qg_jaquad"
        self.dataset_name_clean = "lmqg_qg_jaquad"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        return " Human: " + sample['question'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['sentence']

    def get_rejected(self, sample):
        # No rejected responses exist in this corpus.
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['question'] + " Assistant: " + sample['sentence']

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
|
| 740 |
+
# Japanese dataset
class LmqgQagjaquadDataset(PromptRawDataset):
    """lmqg/qag_jaquad: Japanese question-and-answer generation data.

    The first entry of 'questions' acts as the prompt and the source
    'paragraph' as the chosen response; there are no rejected responses.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qag_jaquad"
        self.dataset_name_clean = "lmqg_qag_jaquad"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        return " Human: " + sample['questions'][0] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['paragraph']

    def get_rejected(self, sample):
        # No rejected responses exist in this corpus.
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['questions'][0] + " Assistant: " + sample['paragraph']

    def get_prompt_and_rejected(self, sample):
        print(
            f"Warning: dataset {self.dataset_name} does not include rejected response."
        )
        return None
# CustomDataset: user-defined dataset class for training a domain-specific
# model; inherits from the PromptRawDataset base class.
class CustomDataset(PromptRawDataset):
    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        # The custom dataset's display name can be chosen freely.
        self.dataset_name = "custom"
        self.dataset_name_clean = "custom"
        # Absolute paths of the dataset files to load.
        train_path = chat_path + '/data/train.jsonl'
        eval_path = chat_path + '/data/dev.jsonl'
        # Wrap the data in a DatasetDict, matching load_dataset()'s return type.
        self.raw_datasets = DatasetDict.from_json({'train': train_path, 'eval': eval_path})

    # Return the training split.
    def get_train_data(self):
        if self.raw_datasets['train'] is not None:
            return self.raw_datasets['train']
        return None

    # Return the validation split.
    def get_eval_data(self):
        if self.raw_datasets['eval'] is not None:
            return self.raw_datasets['eval']
        return None

    # Build the model-input prompt format: " Human: <prompt> Assistant:"
    def get_prompt(self, sample):
        if sample['prompt'] is not None:
            return " Human: " + sample['prompt'] + " Assistant:"
        return None

    # Build the chosen-response format: " <chosen>"
    def get_chosen(self, sample):
        if sample['chosen'] is not None:
            return " " + sample['chosen']
        return None

    # Build the rejected-response format: " <rejected>"
    def get_rejected(self, sample):
        if sample['rejected'] is not None:
            return " " + sample['rejected']
        return None

    # Stage-2 reward-model input: " Human: <prompt> Assistant: <chosen>"
    def get_prompt_and_chosen(self, sample):
        if sample['prompt'] is not None and sample['chosen'] is not None:
            return " Human: " + sample['prompt'] + " Assistant: " + sample['chosen']
        return None

    # Stage-2 reward-model input: " Human: <prompt> Assistant: <rejected>"
    def get_prompt_and_rejected(self, sample):
        if sample['prompt'] is not None and sample['rejected'] is not None:
            return " Human: " + sample['prompt'] + " Assistant: " + sample['rejected']
        return None
RM-EN-01-30-2026/code/reward_model.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# DeepSpeed Team
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
# RewardModel implements both the training-time forward() and the
# inference-time forward_value() entry points.
class RewardModel(nn.Module):
    """A scalar reward head (v_head) on top of a causal-LM backbone.

    forward() is used in RLHF stage 2 to compute a pairwise ranking loss
    over (chosen, rejected) pairs; forward_value() is used in stage 3 to
    score single sequences.
    """

    def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
        super().__init__()
        self.config = base_model.config
        # Number of pad tokens the tokenizer places at the *start* of a
        # sequence; used below to locate the first pad *after* the answer.
        self.num_padding_at_beginning = num_padding_at_beginning
        if hasattr(self.config, "word_embed_proj_dim"):
            # OPT models use word_embed_proj_dim as final output
            # v_head maps each position's hidden state to one scalar score.
            self.v_head = nn.Linear(self.config.word_embed_proj_dim,
                                    1,
                                    bias=False)
        else:
            self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
            # Same as above: one scalar score per sequence position.
            self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        # The backbone network; the full RM is backbone + linear v_head.
        # NOTE(review): attribute name "rwtranrsformer" is a typo but is kept
        # as-is — renaming would break state_dict/checkpoint compatibility.
        self.rwtranrsformer = base_model
        self.PAD_ID = tokenizer.pad_token_id
        # Flag for computing the loss in fp32; not referenced in this file's
        # visible code.
        self.compute_fp32_loss = False

    def gradient_checkpointing_enable(self):
        self.rwtranrsformer.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.rwtranrsformer.gradient_checkpointing_disable()

    # Stage-2 forward: returns the pairwise ranking loss and per-sequence
    # scores (forward_value() below is the stage-3 inference variant).
    def forward(self,
                input_ids=None,
                past_key_values=None,
                attention_mask=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                use_cache=False):
        loss = None
        if self.config.model_type == "llama":
            # llama-type backbones are called without head_mask.
            kwargs = dict()
        else:
            kwargs = dict(head_mask=head_mask)
        # rwtranrsformer is the base model, i.e. the backbone network.
        transformer_outputs = self.rwtranrsformer(input_ids,
                                                  past_key_values=past_key_values,
                                                  attention_mask=attention_mask,
                                                  inputs_embeds=inputs_embeds,
                                                  use_cache=use_cache,
                                                  **kwargs)
        # The backbone returns a sequence of outputs; index 0 is the last
        # layer's output tensor.
        # hidden_states.shape: (batch_size * 2, max_seq_len, hidden_size)
        hidden_states = transformer_outputs[0]
        # Project hidden_size -> 1 via v_head and squeeze the last dim:
        # rewards.shape: (batch_size * 2, max_seq_len), i.e. one predicted
        # value per token position.
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_mean_scores = []
        rejected_mean_scores = []
        # input_ids must be a 2-D tensor.
        assert len(input_ids.shape) == 2
        # DataCollatorReward (data_utils.py) stacks the chosen examples and
        # the rejected examples into one batch, so the true batch size is
        # half of the input's first dimension; split it in two here.
        bs = input_ids.shape[0] // 2
        seq_len = input_ids.shape[1]
        # Front half = chosen, back half = rejected.
        # All four tensors have shape (batch_size, max_seq_len).
        chosen_ids = input_ids[:bs]
        rejected_ids = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        # Compute the pairwise ranking loss.
        loss = 0
        for i in range(bs):
            # Token ids and per-token rewards of one (chosen, rejected) pair.
            # chosen_id.shape: (max_seq_len,)
            chosen_id = chosen_ids[i]
            rejected_id = rejected_ids[i]
            chosen_reward = chosen_rewards[i]
            rejected_reward = rejected_rewards[i]

            # The code below computes a split point:
            # c_ind is the index of the first pad token after the chosen
            # answer (skipping num_padding_at_beginning leading pads).
            # E.g. pad_token_id = 0, sentence = [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]
            # -> c_ind is the index of the first pad token = 6.
            c_inds = (chosen_id == self.PAD_ID).nonzero()
            c_ind = c_inds[self.num_padding_at_beginning].item() if len(c_inds) > self.num_padding_at_beginning else seq_len

            check_divergence = (chosen_id != rejected_id).nonzero()
            # divergence_ind: the first index where chosen and rejected
            # differ, i.e. the first token where the two responses deviate.
            if len(check_divergence) == 0:
                # Identical sequences: fall back to the last position only.
                end_ind = rejected_reward.size(-1)
                divergence_ind = end_ind - 1
                r_ind = c_ind
            else:
                # r_ind: likewise, the first pad token after the rejected
                # answer.
                r_inds = (rejected_id == self.PAD_ID).nonzero()
                r_ind = r_inds[self.num_padding_at_beginning].item() if len(r_inds) > self.num_padding_at_beginning else seq_len
                # The larger of the two ends bounds the comparison window.
                end_ind = max(c_ind, r_ind)
                divergence_ind = check_divergence[0]
            assert divergence_ind > 0

            # Compare rewards over the aligned span starting where the two
            # responses first differ and ending where generation stops.
            '''
            With max_seq_len = 10, pad_token_id = 0, and a chosen_sentence /
            reject_sentence pair sharing one prompt:
            prompt:          [1, 2, 3]
            chosen_sentence: [1, 2, 3, 4, 5, 6, 0, 0, 0, 0]
            reject_sentence: [1, 2, 3, 7, 8, 0, 0, 0, 0, 0]
            the "aligned answer span" (neither prompt nor leading padding,
            with both lengths matched) is:
            chosen_truncated: [4, 5, 6]
            reject_truncated: [7, 8, 0]
            '''
            c_truncated_reward = chosen_reward[divergence_ind:end_ind]
            r_truncated_reward = rejected_reward[divergence_ind:end_ind]

            # The loss below uses the whole aligned span, while the reported
            # per-sequence score is the reward of the last valid token. As
            # the DeepSpeed team notes in their paper, this is an open design
            # choice; alternatives include the mean reward over the answer or
            # a learned aggregation over the per-token rewards.
            # Take the value at the position just before the terminating pad
            # (the last valid token) as the reference score.
            chosen_mean_scores.append(chosen_reward[c_ind - 1])
            rejected_mean_scores.append(rejected_reward[r_ind - 1])

            # Core step: rank loss over the aligned chosen/rejected span,
            # following the original formulation — sigmoid, then log, then
            # mean over the span.
            # (c_truncated_reward - r_truncated_reward).shape: (truncated_seq_len,)
            loss += -torch.nn.functional.logsigmoid(c_truncated_reward - r_truncated_reward).mean()

        loss = loss / bs
        # Stack the per-pair scores: chosen_mean_scores.shape: (batch_size,)
        chosen_mean_scores = torch.stack(chosen_mean_scores)
        rejected_mean_scores = torch.stack(rejected_mean_scores)
        # The returned dict has three fields: loss, chosen scores, rejected
        # scores.
        return {"loss": loss,
                "chosen_mean_scores": chosen_mean_scores,
                "rejected_mean_scores": rejected_mean_scores}

    # Stage-3 inference: returns per-token values and, per sequence, the
    # end-of-answer reward score.
    def forward_value(self,
                      input_ids=None,
                      attention_mask=None,
                      past_key_values=None,
                      position_ids=None,
                      head_mask=None,
                      inputs_embeds=None,
                      return_value_only=False,
                      prompt_length=0,
                      use_cache=False):
        '''
        Difference from forward(): forward() takes (chosen, rejected) input
        pairs and computes/returns their ranking loss, whereas
        forward_value() takes a single input and returns its score.
        In short, forward() consumes data pairs (to compute the pairwise
        ranking loss) and forward_value() scores one sequence directly.
        return_value_only: if True, return as soon as values (the per-token
        score predictions over the sequence) are computed.
        '''
        if self.config.model_type == "llama":
            kwargs = dict()
        else:
            kwargs = dict(head_mask=head_mask)
        # rwtranrsformer is the base model (backbone).
        transformer_outputs = self.rwtranrsformer(input_ids,
                                                  past_key_values=past_key_values,
                                                  attention_mask=attention_mask,
                                                  inputs_embeds=inputs_embeds,
                                                  use_cache=use_cache,
                                                  **kwargs)
        # Index [0] is the backbone's last-layer output tensor.
        hidden_states = transformer_outputs[0]
        # hidden_states.shape: (batch_size, max_seq_len, hidden_size)
        # v_head predicts one scalar value per position.
        values = self.v_head(hidden_states).squeeze(-1)
        # values.shape: (batch_size, max_seq_len)

        if return_value_only:
            return values
        else:
            # [0 0 0 0 prompt, answer, 0 0 0 0 ] for step 3, we have padding at the beginning
            # [prompt, answer, 0, 0, 0, 0] this is normal
            assert prompt_length > 1, "prompt_length must be greater than 1 to help select the end score"
            bs = values.size(0)
            seq_len = input_ids.shape[1]
            # Same variable naming and semantics as in forward() above.
            chosen_end_scores = []
            for i in range(bs):
                input_id = input_ids[i]
                value = values[i]
                # value.shape: (max_seq_len,)
                # c_ind: index of the first pad token in the post-prompt
                # portion of the sequence.
                c_inds = (input_id[prompt_length:] == self.PAD_ID).nonzero()
                c_ind = c_inds[0].item() + prompt_length if len(c_inds) > 0 else seq_len
                # The reward score is the value at the index just before
                # c_ind — i.e. the answer's final token.
                chosen_end_scores.append(value[c_ind - 1])
            # After the loop, len(chosen_end_scores) == batch_size: one
            # score per sample in the batch.
            return {
                "values": values,
                "chosen_end_scores": torch.stack(chosen_end_scores)  # stacked to (batch_size,)
            }
RM-EN-01-30-2026/data/rm_eval.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
RM-EN-01-30-2026/data/rm_train.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20b4085690573224ca426fee9fc34363bb784b1bf46cf034016d17bd14b58c3a
|
| 3 |
+
size 43901233
|
RM-EN-01-30-2026/model/chat_template.jinja
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{{- messages[0].content + '\n\n' }}
|
| 5 |
+
{%- endif %}
|
| 6 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 7 |
+
{%- for tool in tools %}
|
| 8 |
+
{{- "\n" }}
|
| 9 |
+
{{- tool | tojson }}
|
| 10 |
+
{%- endfor %}
|
| 11 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 12 |
+
{%- else %}
|
| 13 |
+
{%- if messages[0].role == 'system' %}
|
| 14 |
+
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
{%- endif %}
|
| 17 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 18 |
+
{%- for message in messages[::-1] %}
|
| 19 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 20 |
+
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
| 21 |
+
{%- set ns.multi_step_tool = false %}
|
| 22 |
+
{%- set ns.last_query_index = index %}
|
| 23 |
+
{%- endif %}
|
| 24 |
+
{%- endfor %}
|
| 25 |
+
{%- for message in messages %}
|
| 26 |
+
{%- if message.content is string %}
|
| 27 |
+
{%- set content = message.content %}
|
| 28 |
+
{%- else %}
|
| 29 |
+
{%- set content = '' %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 32 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 33 |
+
{%- elif message.role == "assistant" %}
|
| 34 |
+
{%- set reasoning_content = '' %}
|
| 35 |
+
{%- if message.reasoning_content is string %}
|
| 36 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 37 |
+
{%- else %}
|
| 38 |
+
{%- if '</think>' in content %}
|
| 39 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 40 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
{%- endif %}
|
| 43 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 44 |
+
{%- if loop.last or (not loop.last and reasoning_content) %}
|
| 45 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
| 46 |
+
{%- else %}
|
| 47 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 48 |
+
{%- endif %}
|
| 49 |
+
{%- else %}
|
| 50 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 51 |
+
{%- endif %}
|
| 52 |
+
{%- if message.tool_calls %}
|
| 53 |
+
{%- for tool_call in message.tool_calls %}
|
| 54 |
+
{%- if (loop.first and content) or (not loop.first) %}
|
| 55 |
+
{{- '\n' }}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{%- if tool_call.function %}
|
| 58 |
+
{%- set tool_call = tool_call.function %}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 61 |
+
{{- tool_call.name }}
|
| 62 |
+
{{- '", "arguments": ' }}
|
| 63 |
+
{%- if tool_call.arguments is string %}
|
| 64 |
+
{{- tool_call.arguments }}
|
| 65 |
+
{%- else %}
|
| 66 |
+
{{- tool_call.arguments | tojson }}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{{- '}\n</tool_call>' }}
|
| 69 |
+
{%- endfor %}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{{- '<|im_end|>\n' }}
|
| 72 |
+
{%- elif message.role == "tool" %}
|
| 73 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 74 |
+
{{- '<|im_start|>user' }}
|
| 75 |
+
{%- endif %}
|
| 76 |
+
{{- '\n<tool_response>\n' }}
|
| 77 |
+
{{- content }}
|
| 78 |
+
{{- '\n</tool_response>' }}
|
| 79 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 80 |
+
{{- '<|im_end|>\n' }}
|
| 81 |
+
{%- endif %}
|
| 82 |
+
{%- endif %}
|
| 83 |
+
{%- endfor %}
|
| 84 |
+
{%- if add_generation_prompt %}
|
| 85 |
+
{{- '<|im_start|>assistant\n' }}
|
| 86 |
+
{%- if enable_thinking is defined and enable_thinking is false %}
|
| 87 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
{%- endif %}
|
RM-EN-01-30-2026/model/config.json
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"dtype": "bfloat16",
|
| 9 |
+
"end_token_id": 151645,
|
| 10 |
+
"eos_token_id": 151645,
|
| 11 |
+
"head_dim": 128,
|
| 12 |
+
"hidden_act": "silu",
|
| 13 |
+
"hidden_size": 2560,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 9728,
|
| 16 |
+
"layer_types": [
|
| 17 |
+
"full_attention",
|
| 18 |
+
"full_attention",
|
| 19 |
+
"full_attention",
|
| 20 |
+
"full_attention",
|
| 21 |
+
"full_attention",
|
| 22 |
+
"full_attention",
|
| 23 |
+
"full_attention",
|
| 24 |
+
"full_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"full_attention",
|
| 27 |
+
"full_attention",
|
| 28 |
+
"full_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"full_attention",
|
| 31 |
+
"full_attention",
|
| 32 |
+
"full_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"full_attention",
|
| 35 |
+
"full_attention",
|
| 36 |
+
"full_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"full_attention",
|
| 39 |
+
"full_attention",
|
| 40 |
+
"full_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"full_attention",
|
| 43 |
+
"full_attention",
|
| 44 |
+
"full_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"full_attention",
|
| 47 |
+
"full_attention",
|
| 48 |
+
"full_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"full_attention",
|
| 51 |
+
"full_attention",
|
| 52 |
+
"full_attention"
|
| 53 |
+
],
|
| 54 |
+
"max_position_embeddings": 40960,
|
| 55 |
+
"max_window_layers": 36,
|
| 56 |
+
"model_type": "qwen3",
|
| 57 |
+
"n_embd": 2560,
|
| 58 |
+
"num_attention_heads": 32,
|
| 59 |
+
"num_hidden_layers": 36,
|
| 60 |
+
"num_key_value_heads": 8,
|
| 61 |
+
"pad_token_id": 151645,
|
| 62 |
+
"rms_norm_eps": 1e-06,
|
| 63 |
+
"rope_parameters": {
|
| 64 |
+
"rope_theta": 1000000,
|
| 65 |
+
"rope_type": "default"
|
| 66 |
+
},
|
| 67 |
+
"sliding_window": null,
|
| 68 |
+
"tie_word_embeddings": true,
|
| 69 |
+
"transformers_version": "5.0.0",
|
| 70 |
+
"use_cache": true,
|
| 71 |
+
"use_sliding_window": false,
|
| 72 |
+
"vocab_size": 151672
|
| 73 |
+
}
|
RM-EN-01-30-2026/model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1d6ad694f70c04ed664794dbf658e4c5d5f494efa3b2f1db1f400a316f4cf4e
|
| 3 |
+
size 8043639192
|
RM-EN-01-30-2026/model/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
|
| 3 |
+
size 11422650
|
RM-EN-01-30-2026/model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": null,
|
| 5 |
+
"clean_up_tokenization_spaces": false,
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"errors": "replace",
|
| 8 |
+
"extra_special_tokens": [
|
| 9 |
+
"<|im_start|>",
|
| 10 |
+
"<|im_end|>",
|
| 11 |
+
"<|object_ref_start|>",
|
| 12 |
+
"<|object_ref_end|>",
|
| 13 |
+
"<|box_start|>",
|
| 14 |
+
"<|box_end|>",
|
| 15 |
+
"<|quad_start|>",
|
| 16 |
+
"<|quad_end|>",
|
| 17 |
+
"<|vision_start|>",
|
| 18 |
+
"<|vision_end|>",
|
| 19 |
+
"<|vision_pad|>",
|
| 20 |
+
"<|image_pad|>",
|
| 21 |
+
"<|video_pad|>"
|
| 22 |
+
],
|
| 23 |
+
"fast_tokenizer": true,
|
| 24 |
+
"is_local": true,
|
| 25 |
+
"model_max_length": 131072,
|
| 26 |
+
"pad_token": "<|im_end|>",
|
| 27 |
+
"split_special_tokens": false,
|
| 28 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 29 |
+
"unk_token": null
|
| 30 |
+
}
|
RM-EN-01-30-2026/model/training.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
RM-EN-01-30-2026/scripts/run_qwen3-4b.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train an English reward model (RLHF stage 2) on Qwen3-4B with DeepSpeed,
# using the "custom" dataset class (reads <chat_path>/data/*.jsonl).
# ZeRO stage 2 with optimizer offload, bf16, and gradient checkpointing are
# enabled to reduce GPU memory usage on a single GPU.
# All output is both shown and saved to $OUTPUT_DIR/training.log via tee.
OUTPUT_DIR=./output_rm_en
mkdir -p $OUTPUT_DIR

deepspeed --num_gpus 1 main.py \
   --model_name_or_path /workspace/Qwen3-4B \
   --data_path custom \
   --num_padding_at_beginning 0 \
   --per_device_train_batch_size 2 \
   --per_device_eval_batch_size 2 \
   --max_seq_len 512 \
   --learning_rate 1e-5 \
   --weight_decay 0.1 \
   --num_train_epochs 1 \
   --gradient_accumulation_steps 8 \
   --lr_scheduler_type cosine \
   --num_warmup_steps 50 \
   --seed 1234 \
   --gradient_checkpointing \
   --zero_stage 2 \
   --offload \
   --dtype bf16 \
   --enable_tensorboard \
   --tensorboard_path $OUTPUT_DIR/tensorboard \
   --output_dir $OUTPUT_DIR \
   --print_loss \
   --deepspeed 2>&1 | tee $OUTPUT_DIR/training.log
|
SFT-EN-01-29-2026/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SFT English Medical Model - Qwen3-4B
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
- Base Model: Qwen3-4B
|
| 5 |
+
- Training: DeepSpeed-Chat SFT with LoRA
|
| 6 |
+
- Dataset: UltraMedical English (9K train, 1K eval)
|
| 7 |
+
- Date: 2026-01-29
|
| 8 |
+
|
| 9 |
+
## Training Config
|
| 10 |
+
- LoRA dim: 64
|
| 11 |
+
- Learning rate: 2e-5
|
| 12 |
+
- Batch size: 2
|
| 13 |
+
- Gradient accumulation: 4
|
| 14 |
+
- ZeRO stage: 2
|
| 15 |
+
- Dtype: bf16
|
| 16 |
+
|
| 17 |
+
## Results
|
| 18 |
+
- Final PPL: 2.498
|
| 19 |
+
- Final Loss: 0.915
|
| 20 |
+
|
| 21 |
+
## Directory
|
| 22 |
+
- model/ - SFT model weights
|
| 23 |
+
- data/ - Training data
|
| 24 |
+
- scripts/ - Training scripts
|
| 25 |
+
- code/ - Modified DeepSpeed-Chat code
|
SFT-EN-01-29-2026/code/data_utils.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# DeepSpeed Team
|
| 5 |
+
"""
|
| 6 |
+
Part of the code was adopted from https://github.com/deepspeedai/Megatron-DeepSpeed/blob/main/megatron/data/dataset_utils.py
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset, Subset, ConcatDataset
|
| 10 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
import numpy as np
|
| 14 |
+
import os
|
| 15 |
+
import hashlib
|
| 16 |
+
from itertools import chain
|
| 17 |
+
from dschat.utils.data import raw_datasets
|
| 18 |
+
from deepspeed.accelerator import get_accelerator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_raw_dataset(dataset_name, output_path, seed, local_rank):
    """Return the raw-dataset wrapper matching ``dataset_name``.

    Dispatches on substring membership (not equality), so e.g. a HF hub path
    containing "Dahoas/rm-static" selects that wrapper. Order matters for
    overlapping keys (e.g. the mkqa variants).

    Args:
        dataset_name: Name/path of the dataset to load.
        output_path: Directory used by the wrappers for cached artifacts.
        seed: RNG seed forwarded to the wrapper.
        local_rank: Distributed local rank forwarded to the wrapper.

    Raises:
        RuntimeError: if no wrapper is registered for ``dataset_name``, or if
            the required local JSON files are missing for "local/jsonfile".
    """
    if "Dahoas/rm-static" in dataset_name:
        return raw_datasets.DahoasRmstaticDataset(output_path, seed,
                                                  local_rank, dataset_name)
    elif "Dahoas/full-hh-rlhf" in dataset_name:
        return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,
                                                    local_rank, dataset_name)
    elif "Dahoas/synthetic-instruct-gptj-pairwise" in dataset_name:
        return raw_datasets.DahoasSyntheticinstructgptjpairwiseDataset(
            output_path, seed, local_rank, dataset_name)
    elif "yitingxie/rlhf-reward-datasets" in dataset_name:
        return raw_datasets.YitingxieRlhfrewarddatasetsDataset(
            output_path, seed, local_rank, dataset_name)
    elif "openai/webgpt_comparisons" in dataset_name:
        return raw_datasets.OpenaiWebgptcomparisonsDataset(
            output_path, seed, local_rank, dataset_name)
    elif "stanfordnlp/SHP" in dataset_name:
        return raw_datasets.StanfordnlpSHPDataset(output_path, seed,
                                                  local_rank, dataset_name)
    elif "pvduy/sharegpt_alpaca_oa_vicuna_format" in dataset_name:
        return raw_datasets.PvduySharegptalpacaoavicunaformatDataset(
            output_path, seed, local_rank, dataset_name)
    elif "wangrui6/Zhihu-KOL" in dataset_name:
        return raw_datasets.Wangrui6ZhihuKOLDataset(output_path, seed,
                                                    local_rank, dataset_name)
    elif "Cohere/miracl-zh-queries-22-12" in dataset_name:
        return raw_datasets.CohereMiraclzhqueries2212Dataset(
            output_path, seed, local_rank, dataset_name)
    elif "Hello-SimpleAI/HC3-Chinese" in dataset_name:
        return raw_datasets.HelloSimpleAIHC3ChineseDataset(
            output_path, seed, local_rank, dataset_name)
    elif "mkqa-Chinese" in dataset_name:
        # Both mkqa variants load the underlying "mkqa" dataset.
        return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank,
                                               "mkqa")
    elif "mkqa-Japanese" in dataset_name:
        return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank,
                                                "mkqa")
    elif "Cohere/miracl-ja-queries-22-12" in dataset_name:
        return raw_datasets.CohereMiracljaqueries2212Dataset(
            output_path, seed, local_rank, dataset_name)
    elif "lmqg/qg_jaquad" in dataset_name:
        return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank,
                                                dataset_name)
    elif "lmqg/qag_jaquad" in dataset_name:
        return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank,
                                                 dataset_name)
    elif "local/jsonfile" in dataset_name:
        # Three levels up from this file is the DeepSpeed-Chat root.
        chat_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__), os.path.pardir,
                         os.path.pardir, os.path.pardir))
        train_json = os.path.join(chat_path, 'data', 'train.json')
        eval_json = os.path.join(chat_path, 'data', 'eval.json')
        if not (os.path.isfile(train_json) and os.path.isfile(eval_json)):
            raise RuntimeError(
                "Please check both the train.json and eval.json files in your applications/DeepSpeed-Chat/data directory."
            )
        return raw_datasets.LocalJsonFileDataset(output_path, seed, local_rank,
                                                 dataset_name, chat_path)
    elif "custom" in dataset_name:
        # Resolve this file's absolute location and go three levels up to
        # reach the DeepSpeed-Chat root directory.
        current_file_path = os.path.dirname(os.path.abspath(__file__))
        chat_path = os.path.abspath(
            os.path.join(current_file_path, os.path.pardir, os.path.pardir,
                         os.path.pardir))
        return raw_datasets.CustomDataset(output_path, seed, local_rank,
                                          dataset_name, chat_path)
    else:
        raise RuntimeError(
            f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
        )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_shuffle_idx(seed, size):
    """Return a deterministic random permutation of ``range(size)``.

    A dedicated ``RandomState`` seeded with ``seed`` is used so the
    permutation is reproducible and independent of global RNG state.
    """
    rng = np.random.RandomState(seed=seed)
    # uint32 keeps the index array compact; switch to int64 only when the
    # indices would not fit in 32 bits.
    idx_dtype = np.int64 if size >= (np.iinfo(np.uint32).max - 1) else np.uint32
    indices = np.arange(start=0, stop=size, step=1, dtype=idx_dtype)
    rng.shuffle(indices)
    return indices
|
| 99 |
+
|
| 100 |
+
# s=data_split, e.g., "6,2,2"
|
| 101 |
+
def get_raw_dataset_split_index(local_rank,
                                output_path,
                                dataset_name,
                                seed,
                                split_name,
                                data_split,
                                seed_split_index,
                                data_size):
    """Return the shuffled example indices for one split of a dataset.

    ``data_split`` is a comma-separated weight string such as "6,2,2"; the
    weights are normalized and the shuffled index range is partitioned
    proportionally. Index files are cached as .npy under ``output_path`` and
    reused on later calls with the same (dataset, seed, split) key.

    NOTE(review): every rank that reaches this code with a missing cache file
    writes the .npy files — presumably the identical seed makes the writes
    byte-identical across ranks, but confirm there is no rank-0 guard expected
    by callers.
    """
    index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
    # reindex each time when using local jsonfile since it's more likely to get modified
    if (not os.path.isfile(index_file_name)) or (dataset_name
                                                 == 'jsonfile'):
        splits = [float(s) for s in data_split.split(',')]
        splits_sum = sum(splits)
        # Normalize the weights so they sum to 1.
        splits = [split / splits_sum for split in splits]
        # Cumulative boundaries: splits_index[i]..splits_index[i+1] is split i.
        splits_index = [0]
        for index, split in enumerate(splits):
            splits_index.append(splits_index[index] +
                                int(round(split * float(data_size))))
        # Rounding can overshoot data_size; shift all boundaries after the
        # first so the final boundary lands exactly on data_size.
        diff = splits_index[-1] - data_size
        for index in range(1, len(splits_index)):
            splits_index[index] -= diff
        assert splits_index[-1] == data_size

        # One shared permutation, sliced per split, saved to its own file.
        shuffle_idx = get_shuffle_idx(seed, data_size)
        for split_i in range(len(splits)):
            shuffle_idx_split_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_i}.npy"
            shuffle_idx_split = shuffle_idx[
                splits_index[split_i]:splits_index[split_i + 1]]
            np.save(shuffle_idx_split_file_name,
                    shuffle_idx_split,
                    allow_pickle=True)
    index = np.load(index_file_name, allow_pickle=True)
    return index.tolist()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class PromptDataset(Dataset):
    """Phase-aware dataset wrapper for the three RLHF training stages.

    Depending on ``train_phase`` the items come from:
      1 (SFT)    -> chosen_dataset, returned as a dict with labels=input_ids;
      2 (reward) -> paired (chosen, rejected) tensors;
      3 (RLHF)   -> prompt tensors plus the pad token id.
    """

    def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
                 pad_token_id, train_phase) -> None:
        super().__init__()
        self.prompt_dataset = prompt_dataset
        self.chosen_dataset = chosen_dataset
        self.reject_dataset = reject_dataset
        self.pad_token_id = pad_token_id
        self.train_phase = train_phase

    def __len__(self):
        # Phase 3 iterates prompts; phases 1 and 2 iterate chosen answers.
        if self.train_phase == 3:
            return len(self.prompt_dataset)
        return len(self.chosen_dataset)

    def __getitem__(self, idx):
        if self.train_phase == 1:
            sample = self.chosen_dataset[idx]
            # SFT: labels are the (padded) input ids themselves.
            return {
                "input_ids": sample["input_ids"],
                "attention_mask": sample["attention_mask"],
                "labels": sample["input_ids"],
            }
        if self.train_phase == 2:
            chosen = self.chosen_dataset[idx]
            rejected = self.reject_dataset[idx]
            return (chosen["input_ids"], chosen["attention_mask"],
                    rejected["input_ids"], rejected["attention_mask"])
        if self.train_phase == 3:
            prompt = self.prompt_dataset[idx]
            return (prompt["input_ids"], prompt["attention_mask"],
                    self.pad_token_id)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
                         end_of_conversation_token, max_seq_len):
    """Tokenize one data split according to the training phase.

    Phase 1 (SFT) keeps only (prompt, chosen) pairs; phase 2 (reward model)
    keeps both the preferred and the rejected pair; phase 3 (RLHF) keeps only
    the prompt. Returns a ``PromptDataset`` wrapping the tokenized lists.
    """
    prompt_dataset = []
    chosen_dataset = []
    reject_dataset = []

    def _encode_padded(text):
        # Fixed-length encoding shared by phases 1 and 2.
        return tokenizer(text,
                         max_length=max_seq_len,
                         padding="max_length",
                         truncation=True,
                         return_tensors="pt")

    if train_phase == 1:
        # SFT: train on the full (prompt + chosen answer) sequence.
        for sample in current_dataset:
            chosen_sentence = raw_dataset.get_prompt_and_chosen(sample)
            if chosen_sentence is None:
                continue
            # Append the conversation terminator before tokenizing.
            tokens = _encode_padded(chosen_sentence +
                                    end_of_conversation_token)
            # Drop the batch dimension added by return_tensors="pt".
            tokens["input_ids"] = tokens["input_ids"].squeeze(0)
            tokens["attention_mask"] = tokens["attention_mask"].squeeze(0)
            chosen_dataset.append(tokens)

    elif train_phase == 2:
        # Reward model: need both the human-preferred and the rejected answer.
        for sample in current_dataset:
            chosen_sentence = raw_dataset.get_prompt_and_chosen(sample)
            reject_sentence = raw_dataset.get_prompt_and_rejected(sample)
            if chosen_sentence is None or reject_sentence is None:
                continue
            # Note: the (1, seq_len) batch dimension is kept here on purpose;
            # DataCollatorReward concatenates these along dim=0.
            chosen_dataset.append(
                _encode_padded(chosen_sentence + end_of_conversation_token))
            reject_dataset.append(
                _encode_padded(reject_sentence + end_of_conversation_token))

    elif train_phase == 3:
        # RLHF: only the prompt is needed; responses are generated online.
        # Prompts longer than max_seq_len are dropped (counted in `filtered`).
        filtered = 0
        for sample in current_dataset:
            prompt = raw_dataset.get_prompt(sample)
            if prompt is None:
                continue
            prompt_token = tokenizer(prompt, return_tensors="pt")
            if prompt_token["input_ids"].size()[-1] <= max_seq_len:
                # The token sequence is reversed here so that later padding
                # (appended on the right by the collator) ends up on the LEFT
                # once DataCollatorRLHF flips the batch back with flip(1).
                # Example with max len 5: [11, 22, 33] -> flip -> [33, 22, 11]
                # -> pad -> [33, 22, 11, 0, 0] -> flip -> [0, 0, 11, 22, 33],
                # so the model generates right after the prompt, not after a
                # run of pad tokens.
                for key_word in ["input_ids", "attention_mask"]:
                    prompt_token[key_word] = prompt_token[key_word].squeeze(
                        0).flip(0)
                prompt_dataset.append(prompt_token)
            else:
                filtered += 1

        print(f'Creating dataset {raw_dataset.dataset_name_clean} '
              f'for {train_phase=} size={len(prompt_dataset)} {filtered=}')

    # Wrap the tokenized lists in a torch-style Dataset for the DataLoader.
    return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
                         tokenizer.pad_token_id, train_phase)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# NOTE(review): this class re-defines PromptDataset, shadowing the identically
# named class declared earlier in this module. The two definitions appear
# behaviorally equivalent (the earlier one only differs by commented-out
# code); confirm which is intended and remove the duplicate.
class PromptDataset(Dataset):
    """Phase-aware dataset wrapper used by all three RLHF training stages."""

    def __init__(self, prompt_dataset, chosen_dataset, reject_dataset,
                 pad_token_id, train_phase) -> None:
        super().__init__()
        self.prompt_dataset = prompt_dataset
        self.chosen_dataset = chosen_dataset
        self.reject_dataset = reject_dataset
        self.pad_token_id = pad_token_id
        self.train_phase = train_phase

    def __len__(self):
        # Phase 3 iterates prompts; phases 1 and 2 iterate chosen answers.
        length = len(self.chosen_dataset)
        if self.train_phase == 3:
            length = len(self.prompt_dataset)
        return length

    def __getitem__(self, idx):
        # Data format returned for phase-1 (SFT) training.
        if self.train_phase == 1:
            return {
                "input_ids": self.chosen_dataset[idx]["input_ids"],
                "attention_mask": self.chosen_dataset[idx]["attention_mask"],
                "labels": self.chosen_dataset[idx]["input_ids"]
            }
        # Data format returned for phase-2 (reward model) training.
        elif self.train_phase == 2:
            return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
                self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"]
        # Data format returned for phase-3 (RLHF) training.
        elif self.train_phase == 3:
            return self.prompt_dataset[idx]["input_ids"], self.prompt_dataset[idx]["attention_mask"], \
                self.pad_token_id
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def create_dataset(local_rank, dataset_name, data_split, output_path,
                   train_phase, seed, tokenizer, end_of_conversation_token,
                   max_seq_len):
    """Build the (train, eval) tokenized datasets for one source dataset.

    NOTE(review): the ``dataset_name`` parameter is deliberately ignored — it
    is overwritten with "custom" below so a locally modified CustomDataset is
    always used (per the original author's comments for training a private
    domain model). Confirm this override is still wanted before reusing this
    module for other datasets.
    """
    # Force the custom local dataset branch in get_raw_dataset().
    dataset_name = "custom"
    raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank)

    # Raw training examples as provided by CustomDataset.get_train_data().
    train_dataset = raw_dataset.get_train_data()

    # Shuffled index list selecting this phase's share of the training split.
    train_index = get_raw_dataset_split_index(local_rank, output_path,
                                              raw_dataset.dataset_name_clean,
                                              seed, "train", data_split,
                                              train_phase - 1,
                                              len(train_dataset))

    # Subset restricts the dataset to the selected (already shuffled) indices.
    train_dataset = Subset(train_dataset, train_index)

    # Tokenize the subset according to the training phase.
    train_dataset = create_dataset_split(train_dataset, raw_dataset,
                                         train_phase, tokenizer,
                                         end_of_conversation_token,
                                         max_seq_len)

    # Same pipeline for the evaluation split.
    eval_dataset = raw_dataset.get_eval_data()

    eval_index = get_raw_dataset_split_index(local_rank, output_path,
                                             raw_dataset.dataset_name_clean,
                                             seed, "eval",
                                             data_split, train_phase - 1,
                                             len(eval_dataset))

    eval_dataset = Subset(eval_dataset, eval_index)
    eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase,
                                        tokenizer, end_of_conversation_token,
                                        max_seq_len)

    return train_dataset, eval_dataset
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def create_prompt_dataset(local_rank,
                          data_path,
                          data_split,
                          output_path,
                          train_phase,
                          seed,
                          tokenizer,
                          max_seq_len,
                          end_of_conversation_token="<|endoftext|>",
                          sft_only_data_path=None,
                          reload=False):
    """Create (and cache on disk) the tokenized train/eval prompt datasets.

    A cache key is derived from every argument that affects tokenization; if
    both cached .pt files exist on all ranks (checked via all_reduce) they are
    loaded directly, otherwise rank 0 rebuilds them. All ranks meet at a
    barrier before loading, so non-zero ranks read the files rank 0 wrote.

    Args:
        data_path: List of dataset names/paths; a single entry is used as-is,
            multiple entries are blended and reshuffled.
        sft_only_data_path: Optional extra datasets mixed in for phase 1 only.
            (Default changed from a shared mutable ``[]`` to ``None`` — same
            behavior, avoids the mutable-default pitfall.)
        reload: Force a rebuild even when the cache files exist.

    Returns:
        (train_dataset, eval_dataset) loaded from the cache files.
    """
    if sft_only_data_path is None:
        sft_only_data_path = []
    os.makedirs(output_path, exist_ok=True)
    fname = "_".join(data_path)
    sft_cache_key = "_".join(sft_only_data_path)
    tokenizer_name = tokenizer.init_kwargs["name_or_path"].replace("/", "_")
    fname = f"{fname}_split{data_split}_phase{train_phase}_seed{seed}_tokenizer{tokenizer_name}_seqlen{max_seq_len}_sft{sft_cache_key}"
    fname = "_".join(fname.split("/"))
    fname = hashlib.sha256(fname.encode()).hexdigest(
    )  # hash the file name to avoid too long file name
    train_fname = f"{output_path}/traindata_{fname}.pt"
    eval_fname = f"{output_path}/evaldata_{fname}.pt"

    cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
    # all_reduce sums the per-rank "cache missing" flags: non-zero means at
    # least one rank must (re)build, so rank 0 rebuilds for everyone.
    buf_create_cache = torch.ByteTensor([not cache_found]).to(
        get_accelerator().current_device_name())
    torch.distributed.all_reduce(buf_create_cache)

    if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
        print(f'Creating prompt dataset {data_path}, {reload=}')
        if len(data_path) == 1:  # Single dataset.
            train_dataset, eval_dataset = create_dataset(
                local_rank,
                data_path[0],
                data_split,
                output_path,
                train_phase,
                seed,
                tokenizer,
                end_of_conversation_token,
                max_seq_len,
            )
        else:  # Blending datasets.
            train_datasets = []
            eval_datasets = []
            train_size = 0
            eval_size = 0
            for d_path in data_path:
                train_dataset, eval_dataset = create_dataset(
                    local_rank,
                    d_path,
                    data_split,
                    output_path,
                    train_phase,
                    seed,
                    tokenizer,
                    end_of_conversation_token,
                    max_seq_len,
                )
                train_datasets.append(train_dataset)
                eval_datasets.append(eval_dataset)
                train_size += len(train_dataset)
                eval_size += len(eval_dataset)
            # Concatenate, then reshuffle so the blend is interleaved.
            train_dataset = ConcatDataset(train_datasets)
            shuffle_idx = get_shuffle_idx(seed, train_size)
            train_dataset = Subset(train_dataset, shuffle_idx.tolist())
            eval_dataset = ConcatDataset(eval_datasets)
            shuffle_idx = get_shuffle_idx(seed, eval_size)
            eval_dataset = Subset(eval_dataset, shuffle_idx.tolist())

        # Append the SFT-only dataset if it exists, and current phase is 1(SFT).
        if train_phase == 1 and sft_only_data_path:
            sft_train_datasets = []
            sft_eval_datasets = []
            sft_train_size = 0
            sft_eval_size = 0
            for sft_path in sft_only_data_path:
                # "10,0,0": the SFT-only data goes entirely to phase 1.
                sft_train_dataset, sft_eval_dataset = create_dataset(
                    local_rank,
                    sft_path,
                    "10,0,0",
                    output_path,
                    train_phase,
                    seed,
                    tokenizer,
                    end_of_conversation_token,
                    max_seq_len,
                )
                sft_train_datasets.append(sft_train_dataset)
                sft_eval_datasets.append(sft_eval_dataset)
                sft_train_size += len(sft_train_dataset)
                sft_eval_size += len(sft_eval_dataset)
            if sft_train_datasets:  # Check if sft_train_datasets is not empty
                sft_train_dataset = ConcatDataset(sft_train_datasets)
                train_dataset = ConcatDataset(
                    [train_dataset, sft_train_dataset])
                shuffle_idx = get_shuffle_idx(seed, len(train_dataset))
                train_dataset = Subset(train_dataset, shuffle_idx.tolist())
            if sft_eval_datasets:  # Check if sft_eval_datasets is not empty
                sft_eval_dataset = ConcatDataset(sft_eval_datasets)
                eval_dataset = ConcatDataset([eval_dataset, sft_eval_dataset])
                shuffle_idx = get_shuffle_idx(seed, len(eval_dataset))
                eval_dataset = Subset(eval_dataset, shuffle_idx.tolist())
        torch.save(train_dataset, train_fname)
        torch.save(eval_dataset, eval_fname)
    # All ranks wait for rank 0 to finish writing before loading the cache.
    torch.distributed.barrier()
    return torch.load(train_fname,
                      weights_only=False), torch.load(eval_fname,
                                                      weights_only=False)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class DataCollatorReward:
    """Collate pairwise reward-model samples into a single batch.

    Each sample is a 4-tuple of (1, seq_len) tensors:
    (chosen_ids, chosen_mask, rejected_ids, rejected_mask). The batch stacks
    every chosen sequence first, then every rejected one, along dim 0 — so the
    first half of the batch is "chosen" and the second half is "rejected".
    """

    def __call__(self, data):
        chosen_ids = [sample[0] for sample in data]
        chosen_masks = [sample[1] for sample in data]
        rejected_ids = [sample[2] for sample in data]
        rejected_masks = [sample[3] for sample in data]
        return {
            "input_ids": torch.cat(chosen_ids + rejected_ids, dim=0),
            "attention_mask": torch.cat(chosen_masks + rejected_masks, dim=0),
        }
|
| 492 |
+
|
| 493 |
+
# 3. RLHF数据集的处理
|
| 494 |
+
# 3. Collator for the RLHF (phase 3) prompt dataset.
class DataCollatorRLHF:
    """Pad reversed prompts to a fixed length, then flip back to left-padding.

    Each sample is (reversed_prompt_ids, reversed_prompt_mask, pad_token_id);
    the dataset stores prompts reversed so that right-padding here becomes
    left-padding after the final flip(1), letting the model generate directly
    after the prompt tokens.
    """

    def __init__(self, max_token_len, inference_tp_size):
        self.max_token_len = max_token_len
        self.inference_tp_size = inference_tp_size

    def __call__(self, data):
        # The pad token id travels with every sample; read it from the last.
        pad_token_id = data[-1][-1]

        prompts = pad_sequence([sample[0] for sample in data],
                               padding_value=pad_token_id,
                               batch_first=True)
        masks = pad_sequence([sample[1] for sample in data],
                             padding_value=0,
                             batch_first=True)

        # Extend to the fixed max_token_len so every batch has the same width.
        extra = self.max_token_len - prompts.size()[-1]
        if extra > 0:
            prompts = F.pad(prompts,
                            pad=(0, extra),
                            mode='constant',
                            value=pad_token_id)
            masks = F.pad(masks, pad=(0, extra), mode='constant', value=0)

        # Undo the per-sample reversal: padding ends up on the left.
        return {
            "prompt": prompts.flip(1),
            "prompt_att_mask": masks.flip(1),
        }
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def get_unsupervised_data(args, tokenizer):
    """Load, tokenize, and chunk the unsupervised (pretraining-style) dataset.

    The raw text is tokenized, concatenated, and split into fixed-size blocks
    of ``max_prompt_seq_len + max_answer_seq_len`` tokens with labels equal to
    the input ids. Returns the processed "train" split.
    """
    raw = load_dataset(args.unsupervised_dataset_name,
                       args.unsupervised_dataset_config_name)
    column_names = raw["train"].column_names
    # Prefer a "text" column; otherwise fall back to the first column.
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    tokenized = raw.map(
        tokenize_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )

    block_size = args.max_prompt_seq_len + args.max_answer_seq_len

    def group_texts(examples):
        # Concatenate all sequences per key, then re-split into equal blocks.
        concatenated = {
            key: list(chain(*examples[key]))
            for key in examples.keys()
        }
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the trailing remainder shorter than one block.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            key: [
                seq[i:i + block_size]
                for i in range(0, total_length, block_size)
            ]
            for key, seq in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    grouped = tokenized.map(
        group_texts,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )

    return grouped["train"]
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class MiniDataset:
    """Buffer up to ``max_size`` large batches, then re-slice them.

    ``add`` accumulates large batches; once ``max_size`` of them are held it
    re-slices everything into chunks of ``small_batch_size`` and clears the
    buffer. Supports list/tuple batches (sliced element-wise), dict batches
    (sliced value-wise), and plain sequences/tensors.
    """

    def __init__(self, max_size, small_batch_size):
        self.dataset = []
        self.max_size = max_size
        self.small_batch_size = small_batch_size

    # (misspelled name kept: it is the public API used by existing callers)
    def seperate(self):
        small_dataset = []
        step = self.small_batch_size
        for large_batch in self.dataset:
            batch_is_seq = type(large_batch) == list or type(
                large_batch) == tuple
            batch_is_map = type(large_batch) == dict
            if batch_is_seq:
                large_size = len(large_batch[0])
            elif batch_is_map:
                large_size = len(next(iter(large_batch.values())))
            else:
                large_size = len(large_batch)
            for start in range(0, large_size, step):
                if batch_is_seq:
                    small_dataset.append(
                        [part[start:start + step] for part in large_batch])
                elif batch_is_map:
                    small_dataset.append({
                        key: value[start:start + step]
                        for key, value in large_batch.items()
                    })
                else:
                    small_dataset.append(large_batch[start:start + step])
        # Drop the buffered large batches now that they have been re-sliced.
        self.free()
        return small_dataset

    def add(self, data):
        if len(self.dataset) >= self.max_size:
            raise ValueError(
                "The dataset is full but we did not stop it. There is a bug in the code."
            )
        self.dataset.append(data)
        # Only emit the re-sliced batches once the buffer is full.
        if len(self.dataset) == self.max_size:
            return self.seperate()
        return None

    def free(self):
        self.dataset = []
|
SFT-EN-01-29-2026/code/main.py
ADDED
|
@@ -0,0 +1,866 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright (c) Microsoft Corporation.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
# DeepSpeed Team
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import math
|
| 9 |
+
import sys
|
| 10 |
+
sys.path.append("/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat")
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
| 13 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 14 |
+
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoModelForCausalLM,
|
| 17 |
+
SchedulerType,
|
| 18 |
+
default_data_collator,
|
| 19 |
+
get_scheduler,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
import deepspeed
|
| 23 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
| 24 |
+
from deepspeed import get_accelerator
|
| 25 |
+
|
| 26 |
+
from dschat.utils.data.data_utils import create_prompt_dataset
|
| 27 |
+
from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer, save_hf_format_safetensors
|
| 28 |
+
from dschat.utils.ds_utils import get_train_ds_config
|
| 29 |
+
from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
|
| 30 |
+
from dschat.utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss
|
| 31 |
+
from dschat.utils.perf import print_throughput
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_args():
    """Parse command-line arguments for phase-1 (SFT) training.

    Returns:
        argparse.Namespace: all training hyperparameters, including the
        DeepSpeed config arguments appended by ``deepspeed.add_config_arguments``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Finetune a transformers model on a causal language modeling task")
    # We do not use the default dataset; --data_path carries the
    # domain-specific, personalized dataset for this project.
    # NOTE: adjacent string literals are concatenated, so each fragment must
    # end with a space — otherwise --help prints run-together words.
    parser.add_argument('--data_path',
                        nargs='*',
                        default=['Dahoas/rm-static'],
                        help='Path to the training dataset. Accepted format: '
                        '1) a single data path, 2) multiple datasets in the '
                        'form: dataset1-path dataset2-path ...')
    parser.add_argument('--data_split',
                        type=str,
                        default='6,2,2',
                        help='Comma-separated list of proportions for training '
                        'phase 1, 2, and 3 data. For example the split `6,2,2` '
                        'will use 60%% of data for phase 1, 20%% for phase 2 '
                        'and 20%% for phase 3.')
    parser.add_argument(
        '--sft_only_data_path',
        nargs='*',
        default=[],
        help='Path to the dataset for only using in SFT phase.')
    parser.add_argument(
        '--data_output_path',
        type=str,
        default='/tmp/data_files/',
        help=
        'Where to store the data-related files such as shuffle index. This needs to be on a local storage of a node (not on a shared storage)'
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=512,
        help="The maximum sequence length.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-3,
        help=
        "Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--num_train_epochs",
                        type=int,
                        default=1,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="cosine",
        help="The scheduler type to use.",
        choices=[
            "linear", "cosine", "cosine_with_restarts", "polynomial",
            "constant", "constant_with_warmup"
        ],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir",
                        type=str,
                        default=None,
                        help="Where to store the model.")
    parser.add_argument("--seed",
                        type=int,
                        default=1234,
                        help="A seed for reproducible training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--gradient_checkpointing',
                        action='store_true',
                        help='Enable HF gradient checkpointing for model.')
    parser.add_argument(
        "--dropout",
        type=float,
        default=None,
        help="If dropout configured, use it. "
        "Otherwise, keep the default dropout configuration of the model.")
    # deepspeed features
    parser.add_argument('--offload',
                        action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp16',
                        choices=['fp16', 'bf16'],
                        help='Training data type')
    parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Actor model (and clones).')
    ## LoRA for efficient training setting
    parser.add_argument("--lora_dim",
                        type=int,
                        default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--lora_module_name",
                        type=str,
                        default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument('--only_optimize_lora',
                        action='store_true',
                        help='Only optimize the LoRA parameters.')
    parser.add_argument(
        "--lora_learning_rate",
        type=float,
        default=5e-4,
        help=
        "Initial LoRA learning rate (after the potential warmup period) to use."
    )
    ## low precision
    parser.add_argument(
        '--compute_fp32_loss',
        action='store_true',
        help='Relevant for low precision dtypes (fp16, bf16, etc.). '
        'If specified, loss is calculated in fp32.')
    ## Tensorboard logging
    parser.add_argument('--enable_tensorboard',
                        action='store_true',
                        help='Enable tensorboard logging')
    parser.add_argument('--tensorboard_path',
                        type=str,
                        default="step1_tensorboard")
    ## Tokenizer
    parser.add_argument(
        "--add_eot_token",
        action='store_true',
        help="Add `eot_token` as additional special token to tokenizer")
    parser.add_argument(
        "--eot_token",
        type=str,
        default="<|endoftext|>",
        help="Specify the format of the `eot_token`",
    )
    ## Print loss
    parser.add_argument('--print_loss',
                        action='store_true',
                        help='Prints loss at each step.')
    # All hyperparameters and training options are declared above; DeepSpeed
    # appends its own config arguments (e.g. --deepspeed_config) here.
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    return args
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def main():
    """Phase-1 SFT entry point: build tokenizer/model/data, wrap everything
    with DeepSpeed, run the train/eval loop, and save the final model."""
    args = parse_args()

    # Device selection: single-process run (local_rank == -1) uses the default
    # accelerator device; distributed runs pin each process to its local rank.
    if args.local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(args.local_rank)
        device = torch.device(get_accelerator().device_name(), args.local_rank)
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        # torch.distributed.init_process_group(backend='nccl')
        deepspeed.init_distributed()

    args.global_rank = torch.distributed.get_rank()

    ds_config = get_train_ds_config(offload=args.offload,
                                    dtype=args.dtype,
                                    stage=args.zero_stage,
                                    enable_tensorboard=args.enable_tensorboard,
                                    tb_path=args.tensorboard_path,
                                    tb_name="step1_model")
    # Override the template config with the actual batch geometry:
    # global batch = per-device micro batch * world size * accumulation steps.
    ds_config[
        'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
    ds_config[
        'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size(
        ) * args.gradient_accumulation_steps

    # If passed along, set the training seed now.
    set_random_seed(args.seed)

    torch.distributed.barrier()

    # Instantiate the tokenizer and the model.
    # load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family
    additional_special_tokens = args.eot_token if args.add_eot_token else None
    tokenizer = load_hf_tokenizer(args.model_name_or_path,
                                  fast_tokenizer=True,
                                  add_special_tokens=additional_special_tokens)

    model = create_hf_model(AutoModelForCausalLM,
                            args.model_name_or_path,
                            tokenizer,
                            ds_config,
                            dropout=args.dropout)

    if args.compute_fp32_loss:
        print_rank_0(
            f"Using model {model.__class__.__name__} with loss in fp32",
            args.global_rank)
        causal_lm_model_to_fp32_loss(model)

    # Optional LoRA fine-tuning: replace matching linear layers with LoRA
    # layers, and optionally freeze everything except the LoRA parameters.
    if args.lora_dim > 0:
        model = convert_linear_layer_to_lora(model, args.lora_module_name,
                                             args.lora_dim)
        if args.only_optimize_lora:
            model = only_optimize_lora_parameters(model)
            model = make_model_gradient_checkpointing_compatible(model)

    # Prepare the training data; this script is stage 1 (SFT).
    # Prepare the data
    train_phase = 1
    print('args: ', args)
    print('data_path: ', args.data_path)
    train_dataset, eval_dataset = create_prompt_dataset(
        args.local_rank,
        args.data_path,
        args.data_split,
        args.data_output_path,
        train_phase,
        args.seed,
        tokenizer,
        args.max_seq_len,
        end_of_conversation_token=tokenizer.eos_token,
        sft_only_data_path=args.sft_only_data_path)
    # DataLoaders creation:
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
        eval_sampler = SequentialSampler(eval_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
        eval_sampler = DistributedSampler(eval_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=default_data_collator,
                                  sampler=train_sampler,
                                  batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=default_data_collator,
                                 sampler=eval_sampler,
                                 batch_size=args.per_device_eval_batch_size)

    # Evaluation helper defined inside main (closes over `device`).
    def evaluation(model, eval_dataloader):
        # Returns (perplexity, mean_loss) over the eval set.
        # NOTE(review): assumes eval_dataloader is non-empty, otherwise
        # `step` is unbound at the division below.
        model.eval()
        losses = 0
        for step, batch in enumerate(eval_dataloader):
            batch = to_device(batch, device)
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses += loss.float()
        losses = losses / (step + 1)
        # NOTE(review): bare except silently ignores any all-reduce failure
        # (e.g. running without an initialized process group); consider
        # narrowing to the expected exception type.
        try:
            losses = get_all_reduce_mean(losses)
        except:
            pass
        try:
            perplexity = torch.exp(losses).item()
        except OverflowError:
            perplexity = float("inf")
        return perplexity, losses.item()

    # Optimizer with grouped parameters:
    # Split weights in two groups, one with weight decay and the other not.
    optimizer_grouped_parameters = get_optimizer_grouped_parameters(
        model, args.weight_decay, args.lora_learning_rate)

    AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
    optimizer = AdamOptimizer(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              betas=(0.9, 0.95))

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_train_epochs * num_update_steps_per_epoch,
    )

    # Wrap model/optimizer/scheduler with the DeepSpeed engine (ZeRO,
    # mixed precision, etc. as configured in ds_config).
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Start training; log key information.
    # Train!
    print_rank_0("***** Running training *****", args.global_rank)
    # print_rank_0(
    #     f"***** Evaluating perplexity, Epoch {0}/{args.num_train_epochs} *****",
    #     args.global_rank)
    # perplexity, eval_loss = evaluation(model, eval_dataloader)
    # print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)

    # Classic two-level training loop: epochs over batches.
    for epoch in range(args.num_train_epochs):
        print_rank_0(
            f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
            args.global_rank)
        # Put the model in training mode.
        model.train()
        # NOTE(review): this import is re-executed each epoch; harmless
        # (imports are cached) but conventionally belongs at file top.
        import time
        for step, batch in enumerate(train_dataloader):
            start = time.time()
            batch = to_device(batch, device)
            # Forward pass; grab the loss. use_cache=False disables the KV
            # cache, which is only useful for generation.
            outputs = model(**batch, use_cache=False)
            loss = outputs.loss
            if args.print_loss:
                print(
                    f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
                )
            # Backward + step via the DeepSpeed engine (it handles gradient
            # accumulation and scheduler stepping internally).
            model.backward(loss)
            model.step()
            end = time.time()
            if torch.distributed.get_rank() == 0:
                print_throughput(model.model, args, end - start,
                                 args.global_rank)

        # Evaluate on the validation set to obtain perplexity.
        # Evaluate perplexity on the validation set.
        print_rank_0(
            f"***** Evaluating perplexity, Epoch {epoch+1}/{args.num_train_epochs} *****",
            args.global_rank)
        perplexity, eval_loss = evaluation(model, eval_dataloader)
        print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)
        model.tput_timer.update_epoch_count()

    if args.output_dir is not None:
        print_rank_0('saving the final model ...', args.global_rank)
        # Fold LoRA weights back into the base linear layers before saving.
        model = convert_lora_to_linear_layer(model)

        if args.global_rank == 0:
            # save_hf_format(model, tokenizer, args)
            # This project trains Qwen3, so the model is saved in
            # safetensors format.
            save_hf_format_safetensors(model, tokenizer, args)

        if args.zero_stage == 3:
            # Under ZeRO stage 3 each GPU holds only a shard of the model,
            # so a special gather-and-save function is required.
            # For zero stage 3, each gpu only has a part of the model, so we need a special save function
            #save_zero_three_model(model,
            #                      args.global_rank,
            #                      args.output_dir,
            #                      zero_stage=args.zero_stage)

            # NOTE(review): save_zero_three_model_safetensors and
            # save_model_config_and_tokenizer are not among the names
            # imported at the top of this file (only save_zero_three_model /
            # save_hf_format_safetensors are) — confirm they exist in
            # dschat.utils.utils, otherwise this branch raises NameError.
            # NOTE(review): the output path is hard-coded to
            # "./output/final_model" instead of args.output_dir, and
            # lora_alpha is passed args.lora_dim (the rank) — confirm both
            # are intentional.
            save_zero_three_model_safetensors(model,
                                              torch.distributed.get_rank(),
                                              "./output/final_model",
                                              zero_stage=args.zero_stage,
                                              lora_alpha=args.lora_dim,
                                              merge_lora=True)

            save_model_config_and_tokenizer(
                model.module if hasattr(model, 'module') else model,
                torch.distributed.get_rank(),
                "./output/final_model",
                base_model_path="workspace/Qwen3-4B"
            )
|
| 432 |
+
|
| 433 |
+
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
'''
|
| 438 |
+
import argparse
|
| 439 |
+
import math
|
| 440 |
+
import sys
|
| 441 |
+
sys.path.append("/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat")
|
| 442 |
+
import torch
|
| 443 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
| 444 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 445 |
+
|
| 446 |
+
from transformers import (
|
| 447 |
+
AutoModelForCausalLM,
|
| 448 |
+
SchedulerType,
|
| 449 |
+
default_data_collator,
|
| 450 |
+
get_scheduler,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
import deepspeed
|
| 454 |
+
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
| 455 |
+
from deepspeed import get_accelerator
|
| 456 |
+
|
| 457 |
+
from dschat.utils.data.data_utils import create_prompt_dataset
|
| 458 |
+
from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer, save_hf_format_safetensors
|
| 459 |
+
from dschat.utils.ds_utils import get_train_ds_config
|
| 460 |
+
from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
|
| 461 |
+
from dschat.utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss
|
| 462 |
+
from dschat.utils.perf import print_throughput
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def parse_args():
|
| 466 |
+
parser = argparse.ArgumentParser(
|
| 467 |
+
description=
|
| 468 |
+
"Finetune a transformers model on a causal language modeling task")
|
| 469 |
+
# 默认的数据集我们不⽤, data_path传参进来的是垂直领域的个性化数据集
|
| 470 |
+
parser.add_argument('--data_path',
|
| 471 |
+
nargs='*',
|
| 472 |
+
default=['Dahoas/rm-static'],
|
| 473 |
+
help='Path to the training dataset. Accepted format:'
|
| 474 |
+
'1) a single data path, 2) multiple datasets in the'
|
| 475 |
+
'form: dataset1-path dataset2-path ...')
|
| 476 |
+
parser.add_argument('--data_split',
|
| 477 |
+
type=str,
|
| 478 |
+
default='6,2,2',
|
| 479 |
+
help='Comma-separated list of proportions for training'
|
| 480 |
+
'phase 1, 2, and 3 data. For example the split `6,2,2`'
|
| 481 |
+
'will use 60%% of data for phase 1, 20%% for phase 2'
|
| 482 |
+
'and 20%% for phase 3.')
|
| 483 |
+
parser.add_argument(
|
| 484 |
+
'--sft_only_data_path',
|
| 485 |
+
nargs='*',
|
| 486 |
+
default=[],
|
| 487 |
+
help='Path to the dataset for only using in SFT phase.')
|
| 488 |
+
parser.add_argument(
|
| 489 |
+
'--data_output_path',
|
| 490 |
+
type=str,
|
| 491 |
+
default='/tmp/data_files/',
|
| 492 |
+
help=
|
| 493 |
+
'Where to store the data-related files such as shuffle index. This needs to be on a local storage of a node (not on a shared storage)'
|
| 494 |
+
)
|
| 495 |
+
parser.add_argument(
|
| 496 |
+
"--model_name_or_path",
|
| 497 |
+
type=str,
|
| 498 |
+
help=
|
| 499 |
+
"Path to pretrained model or model identifier from huggingface.co/models.",
|
| 500 |
+
required=True,
|
| 501 |
+
)
|
| 502 |
+
parser.add_argument(
|
| 503 |
+
"--per_device_train_batch_size",
|
| 504 |
+
type=int,
|
| 505 |
+
default=16,
|
| 506 |
+
help="Batch size (per device) for the training dataloader.",
|
| 507 |
+
)
|
| 508 |
+
parser.add_argument(
|
| 509 |
+
"--per_device_eval_batch_size",
|
| 510 |
+
type=int,
|
| 511 |
+
default=16,
|
| 512 |
+
help="Batch size (per device) for the evaluation dataloader.",
|
| 513 |
+
)
|
| 514 |
+
parser.add_argument(
|
| 515 |
+
"--max_seq_len",
|
| 516 |
+
type=int,
|
| 517 |
+
default=512,
|
| 518 |
+
help="The maximum sequence length.",
|
| 519 |
+
)
|
| 520 |
+
parser.add_argument(
|
| 521 |
+
"--learning_rate",
|
| 522 |
+
type=float,
|
| 523 |
+
default=1e-3,
|
| 524 |
+
help=
|
| 525 |
+
"Initial learning rate (after the potential warmup period) to use.",
|
| 526 |
+
)
|
| 527 |
+
parser.add_argument("--weight_decay",
|
| 528 |
+
type=float,
|
| 529 |
+
default=0.,
|
| 530 |
+
help="Weight decay to use.")
|
| 531 |
+
parser.add_argument("--num_train_epochs",
|
| 532 |
+
type=int,
|
| 533 |
+
default=1,
|
| 534 |
+
help="Total number of training epochs to perform.")
|
| 535 |
+
parser.add_argument(
|
| 536 |
+
"--gradient_accumulation_steps",
|
| 537 |
+
type=int,
|
| 538 |
+
default=1,
|
| 539 |
+
help=
|
| 540 |
+
"Number of updates steps to accumulate before performing a backward/update pass.",
|
| 541 |
+
)
|
| 542 |
+
parser.add_argument(
|
| 543 |
+
"--lr_scheduler_type",
|
| 544 |
+
type=SchedulerType,
|
| 545 |
+
default="cosine",
|
| 546 |
+
help="The scheduler type to use.",
|
| 547 |
+
choices=[
|
| 548 |
+
"linear", "cosine", "cosine_with_restarts", "polynomial",
|
| 549 |
+
"constant", "constant_with_warmup"
|
| 550 |
+
],
|
| 551 |
+
)
|
| 552 |
+
parser.add_argument(
|
| 553 |
+
"--num_warmup_steps",
|
| 554 |
+
type=int,
|
| 555 |
+
default=0,
|
| 556 |
+
help="Number of steps for the warmup in the lr scheduler.")
|
| 557 |
+
parser.add_argument("--output_dir",
|
| 558 |
+
type=str,
|
| 559 |
+
default=None,
|
| 560 |
+
help="Where to store the model.")
|
| 561 |
+
parser.add_argument("--seed",
|
| 562 |
+
type=int,
|
| 563 |
+
default=1234,
|
| 564 |
+
help="A seed for reproducible training.")
|
| 565 |
+
parser.add_argument("--local_rank",
|
| 566 |
+
type=int,
|
| 567 |
+
default=-1,
|
| 568 |
+
help="local_rank for distributed training on gpus")
|
| 569 |
+
parser.add_argument('--gradient_checkpointing',
|
| 570 |
+
action='store_true',
|
| 571 |
+
help='Enable HF gradient checkpointing for model.')
|
| 572 |
+
parser.add_argument(
|
| 573 |
+
"--dropout",
|
| 574 |
+
type=float,
|
| 575 |
+
default=None,
|
| 576 |
+
help="If dropout configured, use it. "
|
| 577 |
+
"Otherwise, keep the default dropout configuration of the model.")
|
| 578 |
+
# deepspeed features
|
| 579 |
+
parser.add_argument('--offload',
|
| 580 |
+
action='store_true',
|
| 581 |
+
help='Enable ZeRO Offload techniques.')
|
| 582 |
+
parser.add_argument('--dtype',
|
| 583 |
+
type=str,
|
| 584 |
+
default='fp16',
|
| 585 |
+
choices=['fp16', 'bf16'],
|
| 586 |
+
help='Training data type')
|
| 587 |
+
parser.add_argument(
|
| 588 |
+
'--zero_stage',
|
| 589 |
+
type=int,
|
| 590 |
+
default=0,
|
| 591 |
+
help='ZeRO optimization stage for Actor model (and clones).')
|
| 592 |
+
## LoRA for efficient training setting
|
| 593 |
+
parser.add_argument("--lora_dim",
|
| 594 |
+
type=int,
|
| 595 |
+
default=0,
|
| 596 |
+
help="If > 0, use LoRA for efficient training.")
|
| 597 |
+
parser.add_argument("--lora_module_name",
|
| 598 |
+
type=str,
|
| 599 |
+
default="decoder.layers.",
|
| 600 |
+
help="The scope of LoRA.")
|
| 601 |
+
parser.add_argument('--only_optimize_lora',
|
| 602 |
+
action='store_true',
|
| 603 |
+
help='Only optimize the LoRA parameters.')
|
| 604 |
+
parser.add_argument(
|
| 605 |
+
"--lora_learning_rate",
|
| 606 |
+
type=float,
|
| 607 |
+
default=5e-4,
|
| 608 |
+
help=
|
| 609 |
+
"Initial LoRA learning rate (after the potential warmup period) to use."
|
| 610 |
+
)
|
| 611 |
+
## low precision
|
| 612 |
+
parser.add_argument(
|
| 613 |
+
'--compute_fp32_loss',
|
| 614 |
+
action='store_true',
|
| 615 |
+
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
|
| 616 |
+
'If specified, loss is calculated in fp32.')
|
| 617 |
+
## Tensorboard logging
|
| 618 |
+
parser.add_argument('--enable_tensorboard',
|
| 619 |
+
action='store_true',
|
| 620 |
+
help='Enable tensorboard logging')
|
| 621 |
+
parser.add_argument('--tensorboard_path',
|
| 622 |
+
type=str,
|
| 623 |
+
default="step1_tensorboard")
|
| 624 |
+
## Tokenizer
|
| 625 |
+
parser.add_argument(
|
| 626 |
+
"--add_eot_token",
|
| 627 |
+
action='store_true',
|
| 628 |
+
help="Add `eot_token` as additional special token to tokenizer")
|
| 629 |
+
parser.add_argument(
|
| 630 |
+
"--eot_token",
|
| 631 |
+
type=str,
|
| 632 |
+
default="<|endoftext|>",
|
| 633 |
+
help="Specify the format of the `eot_token`",
|
| 634 |
+
)
|
| 635 |
+
## Print loss
|
| 636 |
+
parser.add_argument('--print_loss',
|
| 637 |
+
action='store_true',
|
| 638 |
+
help='Prints loss at each step.')
|
| 639 |
+
# 此处是所有超参数和训练参数的设置位置
|
| 640 |
+
parser = deepspeed.add_config_arguments(parser)
|
| 641 |
+
args = parser.parse_args()
|
| 642 |
+
|
| 643 |
+
return args
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def main():
|
| 647 |
+
args = parse_args()
|
| 648 |
+
|
| 649 |
+
if args.local_rank == -1:
|
| 650 |
+
device = torch.device(get_accelerator().device_name())
|
| 651 |
+
else:
|
| 652 |
+
get_accelerator().set_device(args.local_rank)
|
| 653 |
+
device = torch.device(get_accelerator().device_name(), args.local_rank)
|
| 654 |
+
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
| 655 |
+
# torch.distributed.init_process_group(backend='nccl')
|
| 656 |
+
deepspeed.init_distributed()
|
| 657 |
+
|
| 658 |
+
args.global_rank = torch.distributed.get_rank()
|
| 659 |
+
|
| 660 |
+
ds_config = get_train_ds_config(offload=args.offload,
|
| 661 |
+
dtype=args.dtype,
|
| 662 |
+
stage=args.zero_stage,
|
| 663 |
+
enable_tensorboard=args.enable_tensorboard,
|
| 664 |
+
tb_path=args.tensorboard_path,
|
| 665 |
+
tb_name="step1_model")
|
| 666 |
+
ds_config[
|
| 667 |
+
'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
|
| 668 |
+
ds_config[
|
| 669 |
+
'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size(
|
| 670 |
+
) * args.gradient_accumulation_steps
|
| 671 |
+
|
| 672 |
+
# If passed along, set the training seed now.
|
| 673 |
+
set_random_seed(args.seed)
|
| 674 |
+
|
| 675 |
+
torch.distributed.barrier()
|
| 676 |
+
|
| 677 |
+
# 实例化tokenizer和model
|
| 678 |
+
# load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family
|
| 679 |
+
additional_special_tokens = args.eot_token if args.add_eot_token else None
|
| 680 |
+
tokenizer = load_hf_tokenizer(args.model_name_or_path,
|
| 681 |
+
fast_tokenizer=True,
|
| 682 |
+
add_special_tokens=additional_special_tokens)
|
| 683 |
+
|
| 684 |
+
model = create_hf_model(AutoModelForCausalLM,
|
| 685 |
+
args.model_name_or_path,
|
| 686 |
+
tokenizer,
|
| 687 |
+
ds_config,
|
| 688 |
+
dropout=args.dropout)
|
| 689 |
+
|
| 690 |
+
if args.compute_fp32_loss:
|
| 691 |
+
print_rank_0(
|
| 692 |
+
f"Using model {model.__class__.__name__} with loss in fp32",
|
| 693 |
+
args.global_rank)
|
| 694 |
+
causal_lm_model_to_fp32_loss(model)
|
| 695 |
+
|
| 696 |
+
# 设置LoRA微调
|
| 697 |
+
if args.lora_dim > 0:
|
| 698 |
+
model = convert_linear_layer_to_lora(model, args.lora_module_name,
|
| 699 |
+
args.lora_dim)
|
| 700 |
+
if args.only_optimize_lora:
|
| 701 |
+
model = only_optimize_lora_parameters(model)
|
| 702 |
+
model = make_model_gradient_checkpointing_compatible(model)
|
| 703 |
+
|
| 704 |
+
# 准备训练数据, 注意当前处于第⼀阶段 SFT
|
| 705 |
+
# Prepare the data
|
| 706 |
+
train_phase = 1
|
| 707 |
+
print('args: ', args)
|
| 708 |
+
print('data_path: ', args.data_path)
|
| 709 |
+
train_dataset, eval_dataset = create_prompt_dataset(
|
| 710 |
+
args.local_rank,
|
| 711 |
+
args.data_path,
|
| 712 |
+
args.data_split,
|
| 713 |
+
args.data_output_path,
|
| 714 |
+
train_phase,
|
| 715 |
+
args.seed,
|
| 716 |
+
tokenizer,
|
| 717 |
+
args.max_seq_len,
|
| 718 |
+
end_of_conversation_token=tokenizer.eos_token,
|
| 719 |
+
sft_only_data_path=args.sft_only_data_path)
|
| 720 |
+
# DataLoaders creation:
|
| 721 |
+
if args.local_rank == -1:
|
| 722 |
+
train_sampler = RandomSampler(train_dataset)
|
| 723 |
+
eval_sampler = SequentialSampler(eval_dataset)
|
| 724 |
+
else:
|
| 725 |
+
train_sampler = DistributedSampler(train_dataset)
|
| 726 |
+
eval_sampler = DistributedSampler(eval_dataset)
|
| 727 |
+
train_dataloader = DataLoader(train_dataset,
|
| 728 |
+
collate_fn=default_data_collator,
|
| 729 |
+
sampler=train_sampler,
|
| 730 |
+
batch_size=args.per_device_train_batch_size)
|
| 731 |
+
eval_dataloader = DataLoader(eval_dataset,
|
| 732 |
+
collate_fn=default_data_collator,
|
| 733 |
+
sampler=eval_sampler,
|
| 734 |
+
batch_size=args.per_device_eval_batch_size)
|
| 735 |
+
|
| 736 |
+
# main内部定义的评估函数
|
| 737 |
+
def evaluation(model, eval_dataloader):
|
| 738 |
+
model.eval()
|
| 739 |
+
losses = 0
|
| 740 |
+
for step, batch in enumerate(eval_dataloader):
|
| 741 |
+
batch = to_device(batch, device)
|
| 742 |
+
with torch.no_grad():
|
| 743 |
+
outputs = model(**batch)
|
| 744 |
+
|
| 745 |
+
loss = outputs.loss
|
| 746 |
+
losses += loss.float()
|
| 747 |
+
losses = losses / (step + 1)
|
| 748 |
+
try:
|
| 749 |
+
losses = get_all_reduce_mean(losses)
|
| 750 |
+
except:
|
| 751 |
+
pass
|
| 752 |
+
try:
|
| 753 |
+
perplexity = torch.exp(losses).item()
|
| 754 |
+
except OverflowError:
|
| 755 |
+
perplexity = float("inf")
|
| 756 |
+
return perplexity, losses.item()
|
| 757 |
+
|
| 758 |
+
# 采⽤分组优化参数的优化器策略
|
| 759 |
+
# Split weights in two groups, one with weight decay and the other not.
|
| 760 |
+
optimizer_grouped_parameters = get_optimizer_grouped_parameters(
|
| 761 |
+
model, args.weight_decay, args.lora_learning_rate)
|
| 762 |
+
|
| 763 |
+
AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
|
| 764 |
+
optimizer = AdamOptimizer(optimizer_grouped_parameters,
|
| 765 |
+
lr=args.learning_rate,
|
| 766 |
+
betas=(0.9, 0.95))
|
| 767 |
+
|
| 768 |
+
num_update_steps_per_epoch = math.ceil(
|
| 769 |
+
len(train_dataloader) / args.gradient_accumulation_steps)
|
| 770 |
+
lr_scheduler = get_scheduler(
|
| 771 |
+
name=args.lr_scheduler_type,
|
| 772 |
+
optimizer=optimizer,
|
| 773 |
+
num_warmup_steps=args.num_warmup_steps,
|
| 774 |
+
num_training_steps=args.num_train_epochs * num_update_steps_per_epoch,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
# 采⽤deepspeed对相关组件进⾏封装, 本质上是进⾏加速优化
|
| 778 |
+
model, optimizer, _, lr_scheduler = deepspeed.initialize(
|
| 779 |
+
model=model,
|
| 780 |
+
optimizer=optimizer,
|
| 781 |
+
args=args,
|
| 782 |
+
config=ds_config,
|
| 783 |
+
lr_scheduler=lr_scheduler,
|
| 784 |
+
dist_init_required=True)
|
| 785 |
+
|
| 786 |
+
if args.gradient_checkpointing:
|
| 787 |
+
model.gradient_checkpointing_enable()
|
| 788 |
+
|
| 789 |
+
# 开始训练, 打印⼀些关键信息
|
| 790 |
+
# Train!
|
| 791 |
+
print_rank_0("***** Running training *****", args.global_rank)
|
| 792 |
+
print_rank_0(
|
| 793 |
+
f"***** Evaluating perplexity, Epoch {0}/{args.num_train_epochs} *****",
|
| 794 |
+
args.global_rank)
|
| 795 |
+
perplexity, eval_loss = evaluation(model, eval_dataloader)
|
| 796 |
+
print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)
|
| 797 |
+
|
| 798 |
+
# 经典的双重for循环训练模式
|
| 799 |
+
for epoch in range(args.num_train_epochs):
|
| 800 |
+
print_rank_0(
|
| 801 |
+
f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
|
| 802 |
+
args.global_rank)
|
| 803 |
+
# 将模型设置为训练模式
|
| 804 |
+
model.train()
|
| 805 |
+
import time
|
| 806 |
+
for step, batch in enumerate(train_dataloader):
|
| 807 |
+
start = time.time()
|
| 808 |
+
batch = to_device(batch, device)
|
| 809 |
+
# 模型的前向传播计算, 并取到损失值loss
|
| 810 |
+
outputs = model(**batch, use_cache=False)
|
| 811 |
+
loss = outputs.loss
|
| 812 |
+
if args.print_loss:
|
| 813 |
+
print(
|
| 814 |
+
f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
|
| 815 |
+
)
|
| 816 |
+
# 反向传播, "⽼三样"
|
| 817 |
+
model.backward(loss)
|
| 818 |
+
model.step()
|
| 819 |
+
end = time.time()
|
| 820 |
+
if torch.distributed.get_rank() == 0:
|
| 821 |
+
print_throughput(model.model, args, end - start,
|
| 822 |
+
args.global_rank)
|
| 823 |
+
|
| 824 |
+
# 在验证集上进⾏评估, 获取困惑度
|
| 825 |
+
# Evaluate perplexity on the validation set.
|
| 826 |
+
print_rank_0(
|
| 827 |
+
f"***** Evaluating perplexity, Epoch {epoch+1}/{args.num_train_epochs} *****",
|
| 828 |
+
args.global_rank)
|
| 829 |
+
perplexity, eval_loss = evaluation(model, eval_dataloader)
|
| 830 |
+
print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank)
|
| 831 |
+
model.tput_timer.update_epoch_count()
|
| 832 |
+
|
| 833 |
+
if args.output_dir is not None:
|
| 834 |
+
print_rank_0('saving the final model ...', args.global_rank)
|
| 835 |
+
model = convert_lora_to_linear_layer(model)
|
| 836 |
+
|
| 837 |
+
if args.global_rank == 0:
|
| 838 |
+
# save_hf_format(model, tokenizer, args)
|
| 839 |
+
# 因为我们项⽬中需要训练的是Qwen3⼤模型, 需要保存成safetensor的格式
|
| 840 |
+
save_hf_format_safetensors(model, tokenizer, args)
|
| 841 |
+
|
| 842 |
+
if args.zero_stage == 3:
|
| 843 |
+
# 在zero_stage==3时, 每⼀个GPU只包含model的⼀部分, 因此需要⼀个特殊的函数来进⾏模型的保存
|
| 844 |
+
# For zero stage 3, each gpu only has a part of the model, so we need a special save function
|
| 845 |
+
#save_zero_three_model(model,
|
| 846 |
+
# args.global_rank,
|
| 847 |
+
# args.output_dir,
|
| 848 |
+
# zero_stage=args.zero_stage)
|
| 849 |
+
|
| 850 |
+
save_zero_three_model_safetensors(model,
|
| 851 |
+
torch.distributed.get_rank(),
|
| 852 |
+
"./output/final_model",
|
| 853 |
+
zero_stage=args.zero_stage,
|
| 854 |
+
lora_alpha=args.lora_dim,
|
| 855 |
+
merge_lora=True)
|
| 856 |
+
|
| 857 |
+
save_model_config_and_tokenizer(
|
| 858 |
+
model.module if hasattr(model, 'module') else model,
|
| 859 |
+
torch.distributed.get_rank(),
|
| 860 |
+
"./output/final_model",
|
| 861 |
+
base_model_path="workspace/Qwen3-4B"
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
if __name__ == "__main__":
|
| 865 |
+
main()
|
| 866 |
+
'''
|
SFT-EN-01-29-2026/code/model_utils.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
# DeepSpeed Team
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import math
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import (
|
| 11 |
+
AutoConfig,
|
| 12 |
+
AutoModel,
|
| 13 |
+
)
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
from transformers.integrations import HfDeepSpeedConfig
|
| 16 |
+
|
| 17 |
+
from .reward_model import RewardModel
|
| 18 |
+
from ..utils import load_state_dict_into_model
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def configure_dropout(model_config, dropout):
    """Override every dropout-related field present on *model_config*.

    A ``dropout`` of ``None`` is a no-op, leaving the config untouched.
    Only attributes the config actually defines are overwritten.
    """
    if dropout is None:
        return
    dropout_keys = ('dropout', 'attention_dropout', 'hidden_dropout',
                    'activation_dropout')
    for attr in dropout_keys:
        if not hasattr(model_config, attr):
            continue
        print(f"Setting model_config.{attr} to {dropout}")
        setattr(model_config, attr, dropout)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def causal_lm_model_to_fp32_loss(model):
    """ Convert CausalLM model to calculate loss in fp32 """

    # Replacement forward: runs the wrapped model's original forward with
    # labels suppressed, then recomputes the LM loss from the logits in
    # float32 (see .float() below) to avoid low-precision loss issues.
    def causal_lm_forward(
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **deprecated_arguments,
    ):
        # Llama-family forwards take no head_mask, so drop it for them.
        kwargs = dict() if model.config.model_type == "llama" else dict(
            head_mask=head_mask)
        # Call the saved original forward WITHOUT labels so the model does
        # not compute its own (possibly low-precision) loss.
        output = model.__original_forward__(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        # NOTE: `return_dict` is rebound here to mean "output IS a dict-like".
        return_dict = isinstance(output, dict)
        lm_logits = output.logits if return_dict else output[0]
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n; .float() upcasts to fp32
            shift_logits = lm_logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size),
                shift_labels.view(batch_size * seq_length))

        if not return_dict:
            # re-pack output with fp32 loss (tuple convention: loss first)
            return ((loss, ) + output) if loss is not None else output

        output.loss = loss
        return output

    # Monkey-patch: keep the original forward reachable under
    # __original_forward__, then install the fp32-loss wrapper in its place.
    model.__original_forward__ = model.forward
    model.forward = causal_lm_forward
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    dropout=None):
    """Instantiate a HF model and align its config/embeddings with *tokenizer*.

    Args:
        model_class: HF auto-class; used only on the ``rlhf_training`` path.
        model_name_or_path: hub id or local path of the pretrained weights.
        tokenizer: tokenizer whose vocab size the embeddings are resized to.
        ds_config: optional DeepSpeed config; at ZeRO stage 3 an
            ``HfDeepSpeedConfig`` must be alive during ``from_pretrained`` so
            weights are partitioned at load time.
        rlhf_training: when True, build from config only (weights are loaded
            later by ``create_critic_model``).
        dropout: when not None, override all dropout probabilities.

    Returns:
        The instantiated model.
    """
    model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    configure_dropout(model_config, dropout)

    # Note: dschf is defined in function scope to avoid global effects
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        # Must stay referenced so it remains alive for the duration of init.
        dschf = HfDeepSpeedConfig(ds_config)  # noqa: F841
    else:
        dschf = None
    if rlhf_training:
        # BUGFIX: `no_init_weights` was used without ever being imported,
        # which raised NameError on every rlhf_training call. Import locally
        # to keep the fix self-contained.
        from transformers.modeling_utils import no_init_weights
        # The actual weight loading is handled by create_critic_model.
        with no_init_weights():
            model = model_class.from_config(model_config)
    else:
        from transformers import AutoModelForCausalLM as _AutoModel
        model = _AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map=None)

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    # Round the embedding table up to a multiple of 8 (tensor-core friendly).
    model.resize_token_embeddings(int(
        8 *
        math.ceil(len(tokenizer) / 8.0)))

    return model
|
| 121 |
+
|
| 122 |
+
def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False,
                        disable_dropout=False,
                        zero_stage=0):
    """Build a RewardModel (critic) on top of a base transformer.

    Args:
        model_name_or_path: HF hub id or local directory of the base model.
        tokenizer: tokenizer matching the base model.
        ds_config: DeepSpeed config dict (used for ZeRO-3 aware init).
        num_padding_at_beginning: leading padding tokens the reward head skips.
        rlhf_training: when True, also load reward-model weights from
            'pytorch_model.bin' under *model_name_or_path*.
        disable_dropout: when True, force every dropout probability to 0.
        zero_stage: ZeRO stage forwarded to load_state_dict_into_model.

    Returns:
        The constructed (and optionally checkpoint-loaded) RewardModel.
    """
    start = time.time()
    # BUGFIX: the boolean `disable_dropout` was previously passed directly as
    # create_hf_model's `dropout` value. Since `False is not None`, that
    # silently forced all dropout probabilities to False (i.e. 0) even when
    # dropout was meant to stay enabled. Translate the flag to the intended
    # dropout override instead.
    dropout = 0.0 if disable_dropout else None
    critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                   ds_config, rlhf_training, dropout)
    end = time.time()
    # NOTE(review): torch.distributed must already be initialized here; the
    # original in-file comment says these rank-guarded prints can fail when
    # the stage-2 eval script is run stand-alone — confirm before removing.
    if torch.distributed.get_rank() == 0:
        print(f"> Creating model from_config took {end - start} seconds")

    # Wrap the bare transformer with a scalar reward head.
    critic_model = RewardModel(critic_model,
                               tokenizer,
                               num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        # Load the reward-model checkpoint produced by stage-2 training.
        if not os.path.isdir(model_name_or_path):
            model_name_or_path = snapshot_download(model_name_or_path)
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(model_ckpt_path), f"Cannot find model checkpoint at {model_ckpt_path}"

        start = time.time()
        model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"> torch.load took {end - start} seconds")

        # ZeRO-stage-3-compatible state-dict load (each rank owns a shard);
        # this may move into the DeepSpeed checkpoint-load API in future.
        start = time.time()
        load_state_dict_into_model(critic_model,
                                   model_ckpt_state_dict,
                                   "",
                                   zero_stage=zero_stage)
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"> Loading model state dict took {end - start} seconds")

    return critic_model
|
SFT-EN-01-29-2026/code/prompt_eval.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoModelForCausalLM,
|
| 7 |
+
AutoTokenizer,
|
| 8 |
+
StoppingCriteria,
|
| 9 |
+
StoppingCriteriaList,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# Module-level logger (not referenced elsewhere in this file's visible code).
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_args():
    """Parse CLI options for the baseline-vs-finetuned comparison run."""
    parser = argparse.ArgumentParser(
        description="Eval baseline vs finetuned SFT model (clean compare)")
    parser.add_argument("--model_name_or_path_baseline", type=str,
                        required=True)
    parser.add_argument("--model_name_or_path_finetune", type=str,
                        required=True)
    parser.add_argument("--max_new_tokens", type=int, default=200)
    parser.add_argument("--language", type=str, default="English",
                        choices=["English", "Chinese"])
    parser.add_argument("--device", type=str, default=None,
                        help="cuda / cpu. default: auto")
    return parser.parse_args()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_tokenizer(path: str):
    """Load the tokenizer at *path*, defaulting the pad token to EOS."""
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    # Many causal-LM tokenizers ship without a pad token; reuse EOS so that
    # padding-aware APIs downstream do not fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class StopOnSubsequence(StoppingCriteria):
    """Stop generation when the running sequence ends with any stop subsequence.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stop_token_seqs):
        super().__init__()
        # List[List[int]]: token-id subsequences that terminate generation.
        self.stop_token_seqs = stop_token_seqs

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = input_ids[0].tolist()
        # Empty stop sequences are ignored (falsy short-circuit below).
        return any(
            bool(stop) and len(generated) >= len(stop)
            and generated[-len(stop):] == stop
            for stop in self.stop_token_seqs)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_stopping_criteria(tokenizer):
    """Build stopping criteria that halt generation at dialogue role markers."""
    stop_strings = ["\nHuman:", "\nAssistant:", "Human:", "Assistant:"]
    token_seqs = [
        tokenizer.encode(text, add_special_tokens=False)
        for text in stop_strings
    ]
    return StoppingCriteriaList([StopOnSubsequence(token_seqs)])
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def post_trim(text: str):
    """Truncate *text* at the earliest dialogue role marker, then strip.

    Catches leaked "Human:" / "Assistant:" turns the stopping criteria
    may have let through token-wise.
    """
    markers = ["\nHuman:", "\nAssistant:", "Human:", "Assistant:"]
    hits = [pos for pos in (text.find(m) for m in markers) if pos != -1]
    if hits:
        text = text[:min(hits)]
    return text.strip()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def generate_greedy(model, tokenizer, prompt, device, max_new_tokens=200):
    """Greedy-decode *prompt* with *model* and return the trimmed continuation."""
    encoded = tokenizer(prompt, return_tensors="pt", padding=False,
                        truncation=True, return_attention_mask=True)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            # Pure greedy decoding: sampling off, sampling knobs nulled out
            # so transformers does not warn about unused parameters.
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=build_stopping_criteria(tokenizer),
        )

    # Decode only the freshly generated tail, not the echoed prompt.
    continuation = generated[0][input_ids.shape[-1]:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True)
    return post_trim(decoded)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def load_model(path: str, device: torch.device):
    """Load a causal LM in bfloat16, move it to *device*, set eval mode."""
    loaded = AutoModelForCausalLM.from_pretrained(
        path,
        trust_remote_code=True,
        dtype=torch.bfloat16,
    )
    loaded = loaded.to(device)
    return loaded.eval()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_prompts(language: str):
    """Return the fixed evaluation prompts for *language*.

    "English" yields six medical-QA prompts; any other value falls back to
    the Chinese set (the CLI restricts choices to English/Chinese).
    """
    english_prompts = [
        "Human: My father was just diagnosed with diabetes a few days ago at the hospital. He is 60 years old with a blood sugar level of 10. What can he eat to improve his condition? Assistant:",
        "Human: What is hemorrhoid prolapse? What should I do about it? What are the dangers? How should it be treated? Assistant:",
        "Human: My grandmother is around 70 years old and has had high blood pressure for many years. Recently she has nosebleeds almost every day. A few days ago her submandibular lymph nodes were painful, and the hospital said there might be something wrong with her blood. Could this be leukemia? Assistant:",
        "Human: My wisdom tooth is inflamed. Yesterday the dentist packed it with medicine but now it is swollen again. What should I do? Assistant:",
        "Human: These past two days my child's nose seems to be blocked. When lying flat, breathing is difficult, but it gets better when picked up. Sometimes there's also coughing. What should I do? Assistant:",
        "Human: Four days after intercourse, I tested positive for pregnancy. Ultrasound showed a gestational sac of 15.8mm. Is this pregnancy from this recent intercourse, or was I already pregnant before? Assistant:",
    ]
    chinese_prompts = [
        "Human: 爸爸前几天检查出糖尿病,60岁血糖10,吃什么能好转? Assistant:",
        "Human: 什么是痔疮脱出?怎么治疗? Assistant:",
    ]
    return english_prompts if language == "English" else chinese_prompts
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main():
    """Compare greedy generations from a baseline and a finetuned model."""
    args = parse_args()
    # Auto-select CUDA only when the user did not pin a device explicitly.
    if args.device is None and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device(args.device or "cpu")

    tokenizer_baseline = load_tokenizer(args.model_name_or_path_baseline)
    tokenizer_finetune = load_tokenizer(args.model_name_or_path_finetune)

    print("Loading baseline model...")
    baseline_model = load_model(args.model_name_or_path_baseline, device)

    print("Loading finetuned model...")
    finetuned_model = load_model(args.model_name_or_path_finetune, device)

    for idx, prompt in enumerate(get_prompts(args.language)):
        print("\n" + "=" * 60)
        print(f"Prompt {idx+1}: {prompt[:80]}...")

        print("\n=== Baseline ===")
        baseline_out = generate_greedy(baseline_model, tokenizer_baseline,
                                       prompt, device, args.max_new_tokens)
        print(baseline_out if baseline_out else "(empty)")

        print("\n=== Finetuned ===")
        finetuned_out = generate_greedy(finetuned_model, tokenizer_finetune,
                                        prompt, device, args.max_new_tokens)
        print(finetuned_out if finetuned_out else "(empty)")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Script entry point: run the baseline-vs-finetuned comparison.
if __name__ == "__main__":
    main()
|
SFT-EN-01-29-2026/code/raw_datasets.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
from datasets import DatasetDict
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
# DeepSpeed Team
|
| 7 |
+
from datasets import load_dataset, load_from_disk
|
| 8 |
+
from torch.utils.data import Subset
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# The template prompt dataset class that all new dataset porting needs to
|
| 13 |
+
# follow in order to have a unified API and unified data format.
|
| 14 |
+
class PromptRawDataset(object):
    """Template base class every ported prompt dataset must follow.

    Subclasses expose a unified API / data format:
      - prompt:   " Human: " + prompt text + " Assistant:"
      - chosen:   " " + response text
      - rejected: " " + response text, or None if the dataset has none
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        self.output_path = output_path
        self.seed = seed
        self.local_rank = local_rank
        # Hub loading is intentionally disabled here (see the commented-out
        # load_dataset call in the original); 'local/jsonfile' datasets are
        # expected to populate raw_datasets in their own subclass.
        if dataset_name != 'local/jsonfile':
            self.raw_datasets = None

    def get_train_data(self):
        return None

    def get_eval_data(self):
        return None

    def get_prompt(self, sample):
        # Expected format: " Human: " + actual_prompt_sentence + " Assistant:"
        return None

    def get_chosen(self, sample):
        # Expected format: " " + actual_response_sentence
        return None

    def get_rejected(self, sample):
        # Expected format: " " + actual_response_sentence; None when the
        # dataset ships no rejected response.
        return None

    def get_prompt_and_chosen(self, sample):
        return None

    def get_prompt_and_rejected(self, sample):
        return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# English dataset
|
| 53 |
+
class DahoasRmstaticDataset(PromptRawDataset):
    """Dahoas/rm-static preference dataset (English)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/rm-static"
        self.dataset_name_clean = "Dahoas_rm_static"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        # Prompts in this dataset already carry the role markers.
        return sample['prompt']

    def get_chosen(self, sample):
        return sample['chosen']

    def get_rejected(self, sample):
        return sample['rejected']

    def get_prompt_and_chosen(self, sample):
        prompt, chosen = sample['prompt'], sample['chosen']
        return prompt + chosen

    def get_prompt_and_rejected(self, sample):
        prompt, rejected = sample['prompt'], sample['rejected']
        return prompt + rejected
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# English dataset
|
| 83 |
+
class DahoasFullhhrlhfDataset(PromptRawDataset):
    """Dahoas/full-hh-rlhf preference dataset (English)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/full-hh-rlhf"
        self.dataset_name_clean = "Dahoas_full_hh_rlhf"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        # Prompts in this dataset already carry the role markers.
        return sample['prompt']

    def get_chosen(self, sample):
        return sample['chosen']

    def get_rejected(self, sample):
        return sample['rejected']

    def get_prompt_and_chosen(self, sample):
        prompt, chosen = sample['prompt'], sample['chosen']
        return prompt + chosen

    def get_prompt_and_rejected(self, sample):
        prompt, rejected = sample['prompt'], sample['rejected']
        return prompt + rejected
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# English dataset
|
| 113 |
+
class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset):
    """Dahoas/synthetic-instruct-gptj-pairwise dataset (English).

    The upstream dataset ships only a 'train' split, so train and eval
    data are carved out of it with a deterministic 9:1 split.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise"
        self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise"

    def _split_subset(self, split_index):
        # split_index 0 -> the "9" (train) share, 1 -> the "1" (eval) share.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        idx = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                          self.dataset_name_clean,
                                          self.seed, "train_eval", "9,1",
                                          split_index, len(full))
        return Subset(full, idx)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    def get_prompt(self, sample):
        return " Human: " + sample['prompt'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample['chosen']

    def get_rejected(self, sample):
        return " " + sample['rejected']

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['prompt'] + " Assistant: " + sample['chosen']

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample['prompt'] + " Assistant: " + sample['rejected']
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# English dataset
|
| 158 |
+
class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset):
    """yitingxie/rlhf-reward-datasets preference dataset (English)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "yitingxie/rlhf-reward-datasets"
        self.dataset_name_clean = "yitingxie_rlhf_reward_datasets"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        # Prompts here lack the trailing assistant marker; append it.
        return sample['prompt'] + "Assistant:"

    def get_chosen(self, sample):
        # Keep only the text after the final "Assistant:" marker.
        chosen = sample['chosen']
        return chosen.split("Assistant:")[-1]

    def get_rejected(self, sample):
        rejected = sample['rejected']
        return rejected.split("Assistant:")[-1]

    def get_prompt_and_chosen(self, sample):
        prompt, chosen = sample['prompt'], sample['chosen']
        return prompt + chosen

    def get_prompt_and_rejected(self, sample):
        prompt, rejected = sample['prompt'], sample['rejected']
        return prompt + rejected
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# English dataset
|
| 188 |
+
class OpenaiWebgptcomparisonsDataset(PromptRawDataset):
    """openai/webgpt_comparisons preference dataset (English).

    Only a 'train' split exists upstream; train/eval are carved out of it
    with a deterministic 9:1 split. Answer scores decide chosen/rejected,
    with ties going to answer_0 as chosen (>=) and answer_1 as rejected.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "openai/webgpt_comparisons"
        self.dataset_name_clean = "openai_webgpt_comparisons"

    def _split_subset(self, split_index):
        # split_index 0 -> the "9" (train) share, 1 -> the "1" (eval) share.
        from .data_utils import get_raw_dataset_split_index
        full = self.raw_datasets["train"]
        idx = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                          self.dataset_name_clean,
                                          self.seed, "train_eval", "9,1",
                                          split_index, len(full))
        return Subset(full, idx)

    def get_train_data(self):
        return self._split_subset(0)

    def get_eval_data(self):
        return self._split_subset(1)

    @staticmethod
    def _strip_citations(response):
        # This data has citation square brackets and numbers (e.g., "[1]").
        # We are not doing browser-assisted finetuning, so strip these
        # citations (with and without a leading space) to avoid confusing
        # the model.
        response = re.sub(r" [\(\[].*?[\)\]]", "", response)
        response = re.sub(r"[\(\[].*?[\)\]]", "", response)
        return response

    def get_prompt(self, sample):
        return " Human: " + sample['question']['full_text'] + " Assistant:"

    def get_chosen(self, sample):
        winner = sample['answer_0'] if float(sample['score_0']) >= float(
            sample['score_1']) else sample['answer_1']
        return " " + self._strip_citations(winner)

    def get_rejected(self, sample):
        loser = sample['answer_0'] if float(sample['score_0']) < float(
            sample['score_1']) else sample['answer_1']
        return " " + self._strip_citations(loser)

    def get_prompt_and_chosen(self, sample):
        winner = sample['answer_0'] if float(sample['score_0']) >= float(
            sample['score_1']) else sample['answer_1']
        return " Human: " + sample['question'][
            'full_text'] + " Assistant: " + self._strip_citations(winner)

    def get_prompt_and_rejected(self, sample):
        loser = sample['answer_0'] if float(sample['score_0']) < float(
            sample['score_1']) else sample['answer_1']
        return " Human: " + sample['question'][
            'full_text'] + " Assistant: " + self._strip_citations(loser)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# English dataset
class StanfordnlpSHPDataset(PromptRawDataset):
    """stanfordnlp/SHP: Reddit posts with binary human preference labels.

    labels == 1 means human_ref_A is preferred, otherwise human_ref_B.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "stanfordnlp/SHP"
        self.dataset_name_clean = "stanfordnlp_SHP"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    @staticmethod
    def _pick(sample, preferred):
        # Select the preferred (or dispreferred) reference answer.
        a_wins = int(sample["labels"]) == 1
        if a_wins == preferred:
            return sample["human_ref_A"]
        return sample["human_ref_B"]

    def get_prompt(self, sample):
        return " Human: " + sample['history'] + " Assistant:"

    def get_chosen(self, sample):
        return " " + self._pick(sample, True)

    def get_rejected(self, sample):
        return " " + self._pick(sample, False)

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['history'] + " Assistant: " + self._pick(sample, True)

    def get_prompt_and_rejected(self, sample):
        return " Human: " + sample['history'] + " Assistant: " + self._pick(sample, False)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# English dataset
class PvduySharegptalpacaoavicunaformatDataset(PromptRawDataset):
    """pvduy/sharegpt_alpaca_oa_vicuna_format: SFT-style data, no rejected responses."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "pvduy/sharegpt_alpaca_oa_vicuna_format"
        self.dataset_name_clean = "pvduy_sharegpt_alpaca_oa_vicuna_format"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    @staticmethod
    def _to_human_assistant(text):
        # Map the vicuna role tags onto the Human/Assistant convention
        # used throughout this file.
        return text.replace("USER", "Human").replace("ASSISTANT", "Assistant")

    def get_prompt(self, sample):
        prompt = sample['prompt']
        if prompt is not None and len(prompt) > 0:
            return self._to_human_assistant(prompt)
        return None

    def get_chosen(self, sample):
        label = sample['label']
        if label is not None and len(label) > 0:
            return " " + label
        return None

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        prompt = sample['prompt']
        label = sample['label']
        if prompt is not None and label is not None and len(prompt) > 0 and len(label) > 0:
            return self._to_human_assistant(prompt) + " " + label
        return None

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class LocalJsonFileDataset(PromptRawDataset):
    """Local json files under <chat_path>/data with prompt/chosen/rejected fields."""

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "local/jsonfile"
        self.dataset_name_clean = "jsonfile"
        data_files = {
            "train": chat_path + '/data/train.json',
            "eval": chat_path + '/data/eval.json',
        }
        self.raw_datasets = load_dataset('json', data_files=data_files)

    def get_train_data(self):
        if self.raw_datasets['train'] is None:
            return None
        return self.raw_datasets['train']

    def get_eval_data(self):
        if self.raw_datasets['eval'] is None:
            return None
        return self.raw_datasets['eval']

    # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
    def get_prompt(self, sample):
        prompt = sample['prompt']
        return None if prompt is None else " " + prompt

    # The chosen response should be in the format of: " " + actual_response_sentence
    def get_chosen(self, sample):
        chosen = sample['chosen']
        return None if chosen is None else " " + chosen

    # The rejected response should be in the format of: " " + actual_response_sentence
    # If the dataset does not have rejected response, return None
    def get_rejected(self, sample):
        rejected = sample['rejected']
        return None if rejected is None else " " + rejected

    def get_prompt_and_chosen(self, sample):
        if sample['prompt'] is None or sample['chosen'] is None:
            return None
        return " " + sample['prompt'] + " " + sample['chosen']

    def get_prompt_and_rejected(self, sample):
        if sample['prompt'] is None or sample['rejected'] is None:
            return None
        return " " + sample['prompt'] + " " + sample['rejected']
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
# Chinese dataset
class Wangrui6ZhihuKOLDataset(PromptRawDataset):
    """wangrui6/Zhihu-KOL: Chinese Q&A pairs (no rejected responses)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "wangrui6/Zhihu-KOL"
        self.dataset_name_clean = "wangrui6_Zhihu_KOL"

    def _train_eval_subset(self, split_id):
        # Only a "train" split ships; derive a deterministic 90/10
        # train/eval partition (split_id 0 -> train, 1 -> eval).
        from .data_utils import get_raw_dataset_split_index
        data = self.raw_datasets["train"]
        idx = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                          self.dataset_name_clean, self.seed,
                                          "train_eval", "9,1", split_id,
                                          len(data))
        return Subset(data, idx)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        instruction = sample['INSTRUCTION']
        if instruction is None:
            return None
        return " Human: " + instruction + " Assistant:"

    def get_chosen(self, sample):
        response = sample['RESPONSE']
        if response is None:
            return None
        return " " + response

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        instruction = sample['INSTRUCTION']
        response = sample['RESPONSE']
        if instruction is None or response is None:
            return None
        return " Human: " + instruction + " Assistant: " + response

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# Chinese dataset
class CohereMiraclzhqueries2212Dataset(PromptRawDataset):
    """Cohere/miracl-zh-queries-22-12: query with positive/negative passages."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-zh-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        return " Human: " + sample['query'] + " Assistant:"

    def get_chosen(self, sample):
        # The top-ranked positive passage serves as the chosen response.
        return " " + sample['positive_passages'][0]['text']

    def get_rejected(self, sample):
        return " " + sample['negative_passages'][0]['text']

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['query'] + " Assistant: " + sample[
            'positive_passages'][0]['text']

    def get_prompt_and_rejected(self, sample):
        # Fix: guard against samples with no negative passages, mirroring
        # CohereMiracljaqueries2212Dataset; the original indexed [0]
        # unconditionally and raised IndexError on such samples.
        if len(sample['negative_passages']) > 0:
            return " Human: " + sample['query'] + " Assistant: " + sample[
                'negative_passages'][0]['text']
        return None
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# Chinese dataset
class HelloSimpleAIHC3ChineseDataset(PromptRawDataset):
    """Hello-SimpleAI/HC3-Chinese: questions with human answers (no rejected)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Hello-SimpleAI/HC3-Chinese"
        self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese"

    def _train_eval_subset(self, split_id):
        # Only a "train" split ships; derive a deterministic 90/10
        # train/eval partition (split_id 0 -> train, 1 -> eval).
        from .data_utils import get_raw_dataset_split_index
        data = self.raw_datasets["train"]
        idx = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                          self.dataset_name_clean, self.seed,
                                          "train_eval", "9,1", split_id,
                                          len(data))
        return Subset(data, idx)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        question = sample['question']
        if question is None:
            return None
        return " Human: " + question + " Assistant:"

    def get_chosen(self, sample):
        # The first human answer is treated as the chosen response.
        answer = sample['human_answers'][0]
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        question = sample['question']
        answer = sample['human_answers'][0]
        if question is None or answer is None:
            return None
        return " Human: " + question + " Assistant: " + answer

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# Chinese dataset
class MkqaChineseDataset(PromptRawDataset):
    """mkqa, zh_cn subset: multilingual open-domain QA (no rejected responses)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Chinese"
        self.dataset_name_clean = "mkqa"

    def _train_eval_subset(self, split_id):
        # Only a "train" split ships; derive a deterministic 90/10
        # train/eval partition (split_id 0 -> train, 1 -> eval).
        from .data_utils import get_raw_dataset_split_index
        data = self.raw_datasets["train"]
        idx = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                          self.dataset_name_clean, self.seed,
                                          "train_eval", "9,1", split_id,
                                          len(data))
        return Subset(data, idx)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        query = sample['queries']['zh_cn']
        if query is None:
            return None
        return " Human: " + query + " Assistant:"

    def get_chosen(self, sample):
        answer = sample['answers']['zh_cn'][0]['text']
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        query = sample['queries']['zh_cn']
        if query is None:
            return None
        answer = sample['answers']['zh_cn'][0]['text']
        if answer is None:
            return None
        return " Human: " + query + " Assistant: " + answer

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Japanese dataset
class MkqaJapaneseDataset(PromptRawDataset):
    """mkqa, ja subset: multilingual open-domain QA (no rejected responses)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "mkqa-Japanese"
        self.dataset_name_clean = "mkqa"

    def _train_eval_subset(self, split_id):
        # Only a "train" split ships; derive a deterministic 90/10
        # train/eval partition (split_id 0 -> train, 1 -> eval).
        from .data_utils import get_raw_dataset_split_index
        data = self.raw_datasets["train"]
        idx = get_raw_dataset_split_index(self.local_rank, self.output_path,
                                          self.dataset_name_clean, self.seed,
                                          "train_eval", "9,1", split_id,
                                          len(data))
        return Subset(data, idx)

    def get_train_data(self):
        return self._train_eval_subset(0)

    def get_eval_data(self):
        return self._train_eval_subset(1)

    def get_prompt(self, sample):
        query = sample['queries']['ja']
        if query is None:
            return None
        return " Human: " + query + " Assistant:"

    def get_chosen(self, sample):
        answer = sample['answers']['ja'][0]['text']
        if answer is None:
            return None
        return " " + answer

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        query = sample['queries']['ja']
        if query is None:
            return None
        answer = sample['answers']['ja'][0]['text']
        if answer is None:
            return None
        return " Human: " + query + " Assistant: " + answer

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
# Japanese dataset
class CohereMiracljaqueries2212Dataset(PromptRawDataset):
    """Cohere/miracl-ja-queries-22-12: query with positive/negative passages."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "Cohere/miracl-ja-queries-22-12"
        self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["dev"]

    def get_prompt(self, sample):
        return " Human: " + sample['query'] + " Assistant:"

    def get_chosen(self, sample):
        # Top-ranked positive passage serves as the chosen response.
        return " " + sample['positive_passages'][0]['text']

    def get_rejected(self, sample):
        return " " + sample['negative_passages'][0]['text']

    def get_prompt_and_chosen(self, sample):
        chosen = sample['positive_passages'][0]['text']
        return " Human: " + sample['query'] + " Assistant: " + chosen

    def get_prompt_and_rejected(self, sample):
        negatives = sample['negative_passages']
        if not negatives:
            # Some samples carry no negative passage at all.
            return None
        return " Human: " + sample['query'] + " Assistant: " + negatives[0]['text']
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
# Japanese dataset
class LmqgQgjaquadDataset(PromptRawDataset):
    """lmqg/qg_jaquad: question paired with its source sentence (no rejected)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qg_jaquad"
        self.dataset_name_clean = "lmqg_qg_jaquad"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        return " Human: " + sample['question'] + " Assistant:"

    def get_chosen(self, sample):
        # The sentence containing the answer serves as the response.
        return " " + sample['sentence']

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['question'] + " Assistant: " + sample['sentence']

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
# Japanese dataset
class LmqgQagjaquadDataset(PromptRawDataset):
    """lmqg/qag_jaquad: first generated question paired with its paragraph (no rejected)."""

    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "lmqg/qag_jaquad"
        self.dataset_name_clean = "lmqg_qag_jaquad"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["validation"]

    def get_prompt(self, sample):
        return " Human: " + sample['questions'][0] + " Assistant:"

    def get_chosen(self, sample):
        # The source paragraph serves as the response.
        return " " + sample['paragraph']

    def get_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None

    def get_prompt_and_chosen(self, sample):
        return " Human: " + sample['questions'][0] + " Assistant: " + sample['paragraph']

    def get_prompt_and_rejected(self, sample):
        print(f"Warning: dataset {self.dataset_name} does not include rejected response.")
        return None
|
| 775 |
+
# CustomDataset: custom dataset class for training a personalized,
# domain-specific model; inherits from the PromptRawDataset base class.
class CustomDataset(PromptRawDataset):

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        # The name of this custom dataset can be chosen freely.
        self.dataset_name = "custom"
        self.dataset_name_clean = "custom"
        # Absolute paths of the dataset files to read.
        train_path = chat_path + '/data/train.jsonl'
        eval_path = chat_path + '/data/dev.jsonl'
        # Wrap the data in a DatasetDict, mirroring what load_dataset() returns.
        self.raw_datasets = DatasetDict.from_json({
            'train': train_path,
            'eval': eval_path
        })

    # Return the training split.
    def get_train_data(self):
        if self.raw_datasets['train'] is not None:
            return self.raw_datasets['train']
        return None

    # Return the validation split.
    def get_eval_data(self):
        if self.raw_datasets['eval'] is not None:
            return self.raw_datasets['eval']
        return None

    # Prompt format fed to the model: " Human: <prompt> Assistant:"
    def get_prompt(self, sample):
        if sample['prompt'] is not None:
            return " Human: " + sample['prompt'] + " Assistant:"
        return None

    # Chosen-response format: " <chosen>"
    def get_chosen(self, sample):
        if sample['chosen'] is not None:
            return " " + sample['chosen']
        return None

    # Rejected-response format: " <rejected>"
    def get_rejected(self, sample):
        if sample['rejected'] is not None:
            # Fix: the sample key is 'rejected' (as checked above); the
            # original returned sample['reject'], which raises KeyError.
            return " " + sample['rejected']
        return None

    # Stage-2 reward-model input: " Human: <prompt> Assistant: <chosen>"
    def get_prompt_and_chosen(self, sample):
        if sample['prompt'] is not None and sample['chosen'] is not None:
            return " Human: " + sample['prompt'] + " Assistant: " + sample['chosen']
        return None

    # Stage-2 reward-model input: " Human: <prompt> Assistant: <rejected>"
    def get_prompt_and_rejected(self, sample):
        # Fix: use the 'rejected' key consistently; the original used
        # 'reject', which does not exist in the jsonl schema used by
        # get_rejected() and LocalJsonFileDataset.
        if sample['prompt'] is not None and sample['rejected'] is not None:
            return " Human: " + sample['prompt'] + " Assistant: " + sample['rejected']
        return None
|
SFT-EN-01-29-2026/code/utils.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
from safetensors.torch import save_file
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
# DeepSpeed Team
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import random
|
| 9 |
+
import numpy as np
|
| 10 |
+
from transformers import set_seed, AutoTokenizer
|
| 11 |
+
import json
|
| 12 |
+
import deepspeed
|
| 13 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 14 |
+
from deepspeed.accelerator import get_accelerator
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def print_rank_0(msg, rank=None):
    """Print *msg* exactly once in a (possibly distributed) run.

    When *rank* is supplied, print iff it is <= 0; otherwise print only
    when the current process is global rank 0.
    """
    caller_is_rank_zero = rank is not None and rank <= 0
    if caller_is_rank_zero or is_rank_0():
        print(msg)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def is_rank_0():
    """Check whether it is rank 0."""
    if not torch.distributed.is_initialized():
        # Non-distributed runs are treated as rank 0.
        return True
    return torch.distributed.get_rank() == 0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def to_device(batch, device):
    """Return a copy of *batch* with every movable value placed on *device*.

    Values without a ``.to`` method (ints, strings, lists, ...) are kept
    unchanged. The original used a bare ``except:``, which also swallowed
    KeyboardInterrupt/SystemExit and real device errors; narrowed to
    AttributeError, the only failure a non-tensor value produces.
    """
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except AttributeError:
            # v is not a tensor-like object; pass it through untouched.
            output[k] = v
    return output
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class MovingAverage:
    """Running arithmetic mean over every value passed to update()."""

    def __init__(self):
        self.count = 0   # number of observations seen so far
        self.total = 0   # running sum of observations
        self.mean = 0    # current arithmetic mean

    def update(self, num):
        """Fold *num* into the running average and return the new mean."""
        self.count += 1
        self.total += num
        self.mean = self.total / self.count
        return self.mean
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ExponentialMovingAverage:
    """Exponentially-weighted moving average with smoothing factor *alpha*."""

    def __init__(self, alpha=0.9):
        self.alpha = alpha
        self.ema = None  # None until the first update() call

    def update(self, num):
        """Fold *num* into the EMA and return the new value."""
        if self.ema is None:
            # Seed the average with the first observation
            # (alpha*x + (1-alpha)*x == x, so the first result is num).
            self.ema = num
        self.ema = self.alpha * self.ema + (1.0 - self.alpha) * num
        return self.ema

    def get(self):
        """Current EMA, or 0. before any update."""
        return 0. if self.ema is None else self.ema
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_tokenizer(model_name_or_path, fast_tokenizer=True):
    """Load a tokenizer for *model_name_or_path* and force right padding.

    Llama checkpoints get a dedicated LlamaTokenizer and a '[PAD]' token;
    everything else goes through AutoTokenizer with pad_token = eos_token.
    """
    if "llama" in model_name_or_path:
        from transformers.models.llama import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(
            model_name_or_path, fast_tokenizer=fast_tokenizer)
        if tokenizer.pad_token is None:
            # Llama ships without a pad token; register a dedicated one
            # instead of reusing eos.
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            tokenizer.padding_side = 'right'
        return tokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, fast_tokenizer=fast_tokenizer)
    tokenizer.pad_token = tokenizer.eos_token
    # Make sure the tokenizer is right padding in our logic.
    tokenizer.padding_side = 'right'
    return tokenizer
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def load_hf_tokenizer(model_name_or_path,
                      fast_tokenizer=True,
                      add_special_tokens=None):
    """Resolve and load a tokenizer for a local checkpoint dir or a hub name.

    For a local checkpoint, prefer the original hub name recorded under
    ``_name_or_path`` in config.json (local tokenizer loading has issues,
    so files are re-fetched by name). Fixes in this version: the config
    file handle is closed (the original leaked ``open(...)``), and
    ``tokenizer`` is assigned on every path, so a local directory without
    a config.json can no longer leave it unbound.

    Args:
        model_name_or_path: local directory or Hugging Face model id.
        fast_tokenizer: forwarded to get_tokenizer().
        add_special_tokens: optional str or list of str to register as
            additional special tokens.
    """
    if os.path.exists(model_name_or_path):
        model_json = os.path.join(model_name_or_path, "config.json")
        if os.path.exists(model_json):
            with open(model_json) as f:
                model_json_file = json.load(f)
            model_name_or_path = model_json_file.get("_name_or_path",
                                                     model_name_or_path)
    tokenizer = get_tokenizer(model_name_or_path,
                              fast_tokenizer=fast_tokenizer)

    if add_special_tokens is not None:
        add_special_tokens = [add_special_tokens] if isinstance(add_special_tokens, str) \
            else add_special_tokens
        tokenizer.add_special_tokens(
            {'additional_special_tokens': add_special_tokens})

    return tokenizer
|
| 118 |
+
|
| 119 |
+
def save_hf_format_safetensors(model, tokenizer, args, sub_folder=""):
    """Save model and tokenizer in Hugging Face format via safetensors.

    Correctly handles tensors that share storage (e.g. Qwen3's tied
    lm_head / embed_tokens): safetensors refuses aliased memory, so
    duplicates are cloned before serialization.

    Args:
        model: model to save (possibly DeepSpeed/DataParallel wrapped).
        tokenizer: tokenizer saved alongside the weights.
        args: object providing ``output_dir``.
        sub_folder (str, optional): sub-folder under the output directory.
    """
    # 1: unwrap DeepSpeed / DataParallel wrappers.
    model_to_save = model.module if hasattr(model, 'module') else model

    # 2: build the output path.
    output_dir = os.path.join(args.output_dir, sub_folder)
    os.makedirs(output_dir, exist_ok=True)

    # 3: grab the state dict.
    state_dict = model_to_save.state_dict()

    # 4: clone tensors whose storage was already seen — each saved tensor
    # must own distinct memory for safetensors.
    new_state_dict = {}
    # maps data pointer -> first key that used that storage
    seen_data_ptrs = {}

    for key, tensor in state_dict.items():
        data_ptr = tensor.data_ptr()

        if data_ptr in seen_data_ptrs:
            # Shared-storage tensor: clone an independent copy.
            # (Also fixes the mojibake in the original message text.)
            print(f"检测到共享内存张量 '{key}' 与 '{seen_data_ptrs[data_ptr]}' 共享内存, 正在克隆...")
            new_state_dict[key] = tensor.clone()
        else:
            # First time this storage is seen: keep as-is and record it.
            new_state_dict[key] = tensor
            seen_data_ptrs[data_ptr] = key

    # 5: strip LoRA weights if LoRA fine-tuning was used.
    if hasattr(model_to_save, 'peft_config') or any("lora" in k for k in new_state_dict.keys()):
        print("检测到LoRA权重, 正在移除...")
        keys_to_remove = [key for key in new_state_dict.keys() if "lora" in key]
        for key in keys_to_remove:
            del new_state_dict[key]
            print(f" 已移除: {key}")

    # 6: write the de-aliased weights with safetensors.
    output_safetensors_file = os.path.join(output_dir, "model.safetensors")
    # Note: this saves new_state_dict, not the original state_dict.
    save_file(new_state_dict, output_safetensors_file, metadata={"format": "pt"})
    print(f"✅ 模型权重已保存至: {output_safetensors_file}")

    # 7: save the model config.
    output_config_file = os.path.join(output_dir, "config.json")
    model_to_save.config.to_json_file(output_config_file)
    print(f"✅ 模型配置已保存至: {output_config_file}")

    # 8: save the tokenizer (standard recommended path).
    tokenizer.save_pretrained(output_dir)
    print(f"✅ 分词器文件已保存至: {output_dir}")

    # 9: optional round-trip load check on the written file.
    print("正在进行快速加载验证...")
    try:
        from safetensors.torch import load_file
        loaded_tensors = load_file(output_safetensors_file)
        print(f"✅ 验证通过! 成功加载了 {len(loaded_tensors)} 个张量.")
    except Exception as e:
        print(f"⚠ 加载验证时出现警告(可能不影响后续使用): {e}")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def save_hf_format(model, tokenizer, args, sub_folder=""):
    """Save weights, config and vocabulary in Hugging Face layout so the
    checkpoint can later be loaded with ``from_pretrained``."""
    # Unwrap DeepSpeed / DataParallel containers.
    unwrapped = model.module if hasattr(model, 'module') else model
    CONFIG_NAME = "config.json"
    WEIGHTS_NAME = "pytorch_model.bin"
    save_dir = os.path.join(args.output_dir, sub_folder)
    os.makedirs(save_dir, exist_ok=True)
    weights_path = os.path.join(save_dir, WEIGHTS_NAME)
    config_path = os.path.join(save_dir, CONFIG_NAME)
    state = unwrapped.state_dict()
    # Drop LoRA adapter tensors before serializing.
    for name in [k for k in state if "lora" in k]:
        del state[name]
    torch.save(state, weights_path)
    unwrapped.config.to_json_file(config_path)
    tokenizer.save_vocabulary(save_dir)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def set_random_seed(seed):
    """Seed every RNG in use (transformers, random, numpy, torch, and the
    accelerator) for reproducibility; a ``None`` seed is a no-op."""
    if seed is None:
        return
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    get_accelerator().manual_seed_all(seed)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def get_all_reduce_mean(tensor):
    """All-reduce ``tensor`` (sum) across ranks and return the mean.

    Note: mutates ``tensor`` in place via the all-reduce before dividing.
    """
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return tensor / torch.distributed.get_world_size()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# This function is a modified version of code available in the from_pretrained API of HuggingFace Transformers
|
| 227 |
+
# The code is copied and modified from: https://github.com/huggingface/transformers/blob/5ee9693a1c77c617ebc43ef20194b6d3b674318e/src/transformers/modeling_utils.py#L498
|
| 228 |
+
# This function helps load a HF format checkpoint into a DeepSpeed wrapped model that has been sharded using ZeRO Stage 3
|
| 229 |
+
# This function is a modified version of code available in the from_pretrained API of HuggingFace Transformers
# The code is copied and modified from: https://github.com/huggingface/transformers/blob/5ee9693a1c77c617ebc43ef20194b6d3b674318e/src/transformers/modeling_utils.py#L498
# This function helps load a HF format checkpoint into a DeepSpeed wrapped model that has been sharded using ZeRO Stage 3
def load_state_dict_into_model(model_to_load=None,
                               state_dict=None,
                               start_prefix="",
                               zero_stage=0):
    """Recursively load ``state_dict`` into ``model_to_load``.

    Under ZeRO stage 3 the partitioned parameters of each module are
    gathered (on rank 0) before loading and re-partitioned afterwards.

    Args:
        model_to_load: target ``nn.Module``.
        state_dict: checkpoint state dict (copied internally, not mutated).
        start_prefix: key prefix to start matching from.
        zero_stage: DeepSpeed ZeRO stage; 3 triggers the gather path.

    Returns:
        List of error messages accumulated by ``_load_from_state_dict``.
    """

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix=""):
        # prefix[:-1] strips the trailing "." to recover the module path key.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        # args mirror the _load_from_state_dict signature; strict=True,
        # missing/unexpected key lists are discarded, errors accumulate.
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if zero_stage == 3:
                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(
                    module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [
                    named_parameters[k] for k in state_dict.keys()
                    if k in named_parameters
                ]
                if len(params_to_gather) > 0:
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(params_to_gather,
                                                           modifier_rank=0):
                        # modifier_rank=0: only rank 0 writes; DeepSpeed
                        # broadcasts the modification on exit.
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        # Recurse into children, extending the dotted prefix.
        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def get_optimizer_grouped_parameters(
    model,
    weight_decay,
    lora_lr=5e-4,
    no_decay_name_list=[
        "bias", "layer_norm.weight", "layernorm.weight", "norm.weight",
        "ln_f.weight"
    ],
    lora_name_list=["lora_right_weight", "lora_left_weight"],
):
    """Partition trainable parameters into optimizer groups.

    Groups (in order): (1) weight-decayed regular params, (2) LoRA params
    with weight decay and their own learning rate, (3) no-decay params
    (biases / norm weights, which also take precedence over the LoRA tag).
    Empty groups are dropped so the optimizer never receives them.
    """

    def _matches(name, tags):
        lowered = name.lower()
        return any(tag in lowered for tag in tags)

    regular, lora, no_decay = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # no-decay names win even when they also match a LoRA tag,
        # matching the original comprehension logic.
        if _matches(name, no_decay_name_list):
            no_decay.append(param)
        elif _matches(name, lora_name_list):
            lora.append(param)
        else:
            regular.append(param)

    candidate_groups = [
        {"params": regular, "weight_decay": weight_decay},
        {"params": lora, "weight_decay": weight_decay, "lr": lora_lr},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return [group for group in candidate_groups if group["params"]]
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _z3_params_to_fetch(param_list):
|
| 335 |
+
return [
|
| 336 |
+
p for p in param_list
|
| 337 |
+
if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
|
| 338 |
+
]
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
    """Update ``model_ema`` in place as an exponential moving average of
    ``model``: ema = beta * ema + (1 - beta) * param.

    Under ZeRO stage 3 the partitioned parameters are gathered before the
    update and re-partitioned when the context exits.
    """
    is_zero3 = zero_stage == 3
    with torch.no_grad():
        for src, ema in zip(model.parameters(), model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            gather_list = _z3_params_to_fetch([src, ema]) if is_zero3 else []
            with deepspeed.zero.GatheredParameters(
                    gather_list, enabled=len(gather_list) > 0):
                src_data = src.data
                if device is not None:
                    src_data = src_data.to(device)
                ema.data.copy_(torch.lerp(src_data, ema.data, beta))
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save a (possibly ZeRO-3 sharded) model as ``pytorch_model.bin``.

    Outside stage 3, rank 0 dumps the state dict directly. Under stage 3,
    every rank walks the parameters, gathering each partitioned tensor so
    rank 0 can collect a full CPU copy; LoRA tensors are skipped.

    Args:
        model_ema: model (or wrapped model) to serialize.
        global_rank: distributed rank; only rank 0 writes to disk.
        save_dir: output directory (created if missing).
        zero_stage: DeepSpeed ZeRO stage; 3 triggers the gather path.
    """
    zero_stage_3 = (zero_stage == 3)
    os.makedirs(save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)

    # Unwrap DeepSpeed / DataParallel containers.
    model_to_save = model_ema.module if hasattr(model_ema,
                                                'module') else model_ema
    if not zero_stage_3:
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        output_state_dict = {}
        for k, v in model_to_save.named_parameters():

            # `ds_id` marks a DeepSpeed-partitioned parameter that must be
            # gathered across ranks before its full data is readable.
            if hasattr(v, 'ds_id'):
                with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v
                                                                            ]),
                                                       enabled=zero_stage_3):
                    v_p = v.data.cpu()
            else:
                v_p = v.cpu()
            # Only rank 0 accumulates the checkpoint; LoRA weights excluded.
            if global_rank == 0 and "lora" not in k:
                output_state_dict[k] = v_p
        if global_rank == 0:
            torch.save(output_state_dict, output_model_file)
        # Free the gathered CPU copies promptly.
        del output_state_dict
|
SFT-EN-01-29-2026/data/dev.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
SFT-EN-01-29-2026/data/eval.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
SFT-EN-01-29-2026/data/train.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de56cd90e05715d0521515aa4a90d718d3e0da27d49970ff0a83136652066906
|
| 3 |
+
size 25584972
|
SFT-EN-01-29-2026/model/chat_template.jinja
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{{- messages[0].content + '\n\n' }}
|
| 5 |
+
{%- endif %}
|
| 6 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 7 |
+
{%- for tool in tools %}
|
| 8 |
+
{{- "\n" }}
|
| 9 |
+
{{- tool | tojson }}
|
| 10 |
+
{%- endfor %}
|
| 11 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 12 |
+
{%- else %}
|
| 13 |
+
{%- if messages[0].role == 'system' %}
|
| 14 |
+
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
{%- endif %}
|
| 17 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 18 |
+
{%- for message in messages[::-1] %}
|
| 19 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 20 |
+
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
| 21 |
+
{%- set ns.multi_step_tool = false %}
|
| 22 |
+
{%- set ns.last_query_index = index %}
|
| 23 |
+
{%- endif %}
|
| 24 |
+
{%- endfor %}
|
| 25 |
+
{%- for message in messages %}
|
| 26 |
+
{%- if message.content is string %}
|
| 27 |
+
{%- set content = message.content %}
|
| 28 |
+
{%- else %}
|
| 29 |
+
{%- set content = '' %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 32 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 33 |
+
{%- elif message.role == "assistant" %}
|
| 34 |
+
{%- set reasoning_content = '' %}
|
| 35 |
+
{%- if message.reasoning_content is string %}
|
| 36 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 37 |
+
{%- else %}
|
| 38 |
+
{%- if '</think>' in content %}
|
| 39 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 40 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
{%- endif %}
|
| 43 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 44 |
+
{%- if loop.last or (not loop.last and reasoning_content) %}
|
| 45 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
| 46 |
+
{%- else %}
|
| 47 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 48 |
+
{%- endif %}
|
| 49 |
+
{%- else %}
|
| 50 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 51 |
+
{%- endif %}
|
| 52 |
+
{%- if message.tool_calls %}
|
| 53 |
+
{%- for tool_call in message.tool_calls %}
|
| 54 |
+
{%- if (loop.first and content) or (not loop.first) %}
|
| 55 |
+
{{- '\n' }}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{%- if tool_call.function %}
|
| 58 |
+
{%- set tool_call = tool_call.function %}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 61 |
+
{{- tool_call.name }}
|
| 62 |
+
{{- '", "arguments": ' }}
|
| 63 |
+
{%- if tool_call.arguments is string %}
|
| 64 |
+
{{- tool_call.arguments }}
|
| 65 |
+
{%- else %}
|
| 66 |
+
{{- tool_call.arguments | tojson }}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{{- '}\n</tool_call>' }}
|
| 69 |
+
{%- endfor %}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{{- '<|im_end|>\n' }}
|
| 72 |
+
{%- elif message.role == "tool" %}
|
| 73 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 74 |
+
{{- '<|im_start|>user' }}
|
| 75 |
+
{%- endif %}
|
| 76 |
+
{{- '\n<tool_response>\n' }}
|
| 77 |
+
{{- content }}
|
| 78 |
+
{{- '\n</tool_response>' }}
|
| 79 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 80 |
+
{{- '<|im_end|>\n' }}
|
| 81 |
+
{%- endif %}
|
| 82 |
+
{%- endif %}
|
| 83 |
+
{%- endfor %}
|
| 84 |
+
{%- if add_generation_prompt %}
|
| 85 |
+
{{- '<|im_start|>assistant\n' }}
|
| 86 |
+
{%- if enable_thinking is defined and enable_thinking is false %}
|
| 87 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
{%- endif %}
|
SFT-EN-01-29-2026/model/config.json
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"dtype": "bfloat16",
|
| 9 |
+
"end_token_id": 151645,
|
| 10 |
+
"eos_token_id": 151645,
|
| 11 |
+
"head_dim": 128,
|
| 12 |
+
"hidden_act": "silu",
|
| 13 |
+
"hidden_size": 2560,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 9728,
|
| 16 |
+
"layer_types": [
|
| 17 |
+
"full_attention",
|
| 18 |
+
"full_attention",
|
| 19 |
+
"full_attention",
|
| 20 |
+
"full_attention",
|
| 21 |
+
"full_attention",
|
| 22 |
+
"full_attention",
|
| 23 |
+
"full_attention",
|
| 24 |
+
"full_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"full_attention",
|
| 27 |
+
"full_attention",
|
| 28 |
+
"full_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"full_attention",
|
| 31 |
+
"full_attention",
|
| 32 |
+
"full_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"full_attention",
|
| 35 |
+
"full_attention",
|
| 36 |
+
"full_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"full_attention",
|
| 39 |
+
"full_attention",
|
| 40 |
+
"full_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"full_attention",
|
| 43 |
+
"full_attention",
|
| 44 |
+
"full_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"full_attention",
|
| 47 |
+
"full_attention",
|
| 48 |
+
"full_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"full_attention",
|
| 51 |
+
"full_attention",
|
| 52 |
+
"full_attention"
|
| 53 |
+
],
|
| 54 |
+
"max_position_embeddings": 40960,
|
| 55 |
+
"max_window_layers": 36,
|
| 56 |
+
"model_type": "qwen3",
|
| 57 |
+
"num_attention_heads": 32,
|
| 58 |
+
"num_hidden_layers": 36,
|
| 59 |
+
"num_key_value_heads": 8,
|
| 60 |
+
"pad_token_id": 151645,
|
| 61 |
+
"rms_norm_eps": 1e-06,
|
| 62 |
+
"rope_parameters": {
|
| 63 |
+
"rope_theta": 1000000,
|
| 64 |
+
"rope_type": "default"
|
| 65 |
+
},
|
| 66 |
+
"sliding_window": null,
|
| 67 |
+
"tie_word_embeddings": false,
|
| 68 |
+
"transformers_version": "5.0.0",
|
| 69 |
+
"use_cache": true,
|
| 70 |
+
"use_sliding_window": false,
|
| 71 |
+
"vocab_size": 151672
|
| 72 |
+
}
|
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bec6d2eb21f2a317beea31714d84e4c3fe34ba0c62365be1e2dc9ea98806cd55
|
| 3 |
+
size 204
|
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:543d7327217546fd37d3fee5322ec0f7ccbb530b17e65a16f14250a52daa3a4a
|
| 3 |
+
size 1448
|
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d996a4ea809b4cd08593bf3563f579d16402dfa2a1f23aa03d3110beda00a0e
|
| 3 |
+
size 37198
|
SFT-EN-01-29-2026/model/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:17e264bbafe4a0f20c4f2b0df948cee4cc696f181bd12f2530404e8c71c06444
|
| 3 |
+
size 37198
|
SFT-EN-01-29-2026/model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9956493572a7ae7ff86699c23789cba8d31a38d0a2d6333177d846b9d9cade23
|
| 3 |
+
size 8820191160
|
SFT-EN-01-29-2026/model/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
|
| 3 |
+
size 11422650
|
SFT-EN-01-29-2026/model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": null,
|
| 5 |
+
"clean_up_tokenization_spaces": false,
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"errors": "replace",
|
| 8 |
+
"extra_special_tokens": [
|
| 9 |
+
"<|im_start|>",
|
| 10 |
+
"<|im_end|>",
|
| 11 |
+
"<|object_ref_start|>",
|
| 12 |
+
"<|object_ref_end|>",
|
| 13 |
+
"<|box_start|>",
|
| 14 |
+
"<|box_end|>",
|
| 15 |
+
"<|quad_start|>",
|
| 16 |
+
"<|quad_end|>",
|
| 17 |
+
"<|vision_start|>",
|
| 18 |
+
"<|vision_end|>",
|
| 19 |
+
"<|vision_pad|>",
|
| 20 |
+
"<|image_pad|>",
|
| 21 |
+
"<|video_pad|>"
|
| 22 |
+
],
|
| 23 |
+
"fast_tokenizer": true,
|
| 24 |
+
"is_local": true,
|
| 25 |
+
"model_max_length": 131072,
|
| 26 |
+
"pad_token": "<|im_end|>",
|
| 27 |
+
"split_special_tokens": false,
|
| 28 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 29 |
+
"unk_token": null
|
| 30 |
+
}
|
SFT-EN-01-29-2026/model/training.log
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
|
| 2 |
+
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
|
| 3 |
+
[2026-01-29 22:24:38,868] [WARNING] [runner.py:232:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
|
| 4 |
+
[2026-01-29 22:24:38,868] [INFO] [runner.py:630:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None --log_level=info main.py --model_name_or_path /workspace/Qwen3-4B --data_path /home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl --weight_decay 0.1 --dropout 0.0 --gradient_accumulation_steps 8 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --zero_stage 3 --offload --dtype bf16 --enable_tensorboard --tensorboard_path ./output_sft_en --deepspeed --output_dir ./output_sft_en
|
| 5 |
+
/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
|
| 6 |
+
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
|
| 7 |
+
[2026-01-29 22:24:45,395] [INFO] [launch.py:162:main] WORLD INFO DICT: {'localhost': [0]}
|
| 8 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:168:main] nnodes=1, num_local_procs=1, node_rank=0
|
| 9 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:179:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
|
| 10 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:180:main] dist_world_size=1
|
| 11 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:184:main] Setting CUDA_VISIBLE_DEVICES=0
|
| 12 |
+
[2026-01-29 22:24:45,398] [INFO] [launch.py:272:main] process 31271 spawned with command: ['/usr/bin/python3', '-u', 'main.py', '--local_rank=0', '--model_name_or_path', '/workspace/Qwen3-4B', '--data_path', '/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl', '--weight_decay', '0.1', '--dropout', '0.0', '--gradient_accumulation_steps', '8', '--per_device_train_batch_size', '1', '--per_device_eval_batch_size', '1', '--zero_stage', '3', '--offload', '--dtype', 'bf16', '--enable_tensorboard', '--tensorboard_path', './output_sft_en', '--deepspeed', '--output_dir', './output_sft_en']
|
| 13 |
+
/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
|
| 14 |
+
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
|
| 15 |
+
[rank0]:[W129 22:24:52.444107661 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
|
| 16 |
+
Setting model_config.attention_dropout to 0.0
|
| 17 |
+
args: Namespace(data_path=['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl'], data_split='6,2,2', sft_only_data_path=[], data_output_path='/tmp/data_files/', model_name_or_path='/workspace/Qwen3-4B', per_device_train_batch_size=1, per_device_eval_batch_size=1, max_seq_len=512, learning_rate=0.001, weight_decay=0.1, num_train_epochs=1, gradient_accumulation_steps=8, lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>, num_warmup_steps=0, output_dir='./output_sft_en', seed=1234, local_rank=0, gradient_checkpointing=False, dropout=0.0, offload=True, dtype='bf16', zero_stage=3, lora_dim=0, lora_module_name='decoder.layers.', only_optimize_lora=False, lora_learning_rate=0.0005, compute_fp32_loss=False, enable_tensorboard=True, tensorboard_path='./output_sft_en', add_eot_token=False, eot_token='<|endoftext|>', print_loss=False, deepspeed=True, deepspeed_config=None, deepscale=False, deepscale_config=None, global_rank=0)
|
| 18 |
+
data_path: ['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl']
|
| 19 |
+
/usr/lib/python3/dist-packages/torch/utils/cpp_extension.py:2376: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
|
| 20 |
+
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
|
| 21 |
+
warnings.warn(
|
| 22 |
+
2026-01-29 22:25:34.798274: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
| 23 |
+
2026-01-29 22:25:34.808869: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
| 24 |
+
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
|
| 25 |
+
E0000 00:00:1769725534.821805 31271 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
| 26 |
+
E0000 00:00:1769725534.825823 31271 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
| 27 |
+
W0000 00:00:1769725534.835606 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 28 |
+
W0000 00:00:1769725534.835626 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 29 |
+
W0000 00:00:1769725534.835656 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 30 |
+
W0000 00:00:1769725534.835658 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 31 |
+
2026-01-29 22:25:34.838493: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
| 32 |
+
To enable the following instructions: AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
| 33 |
+
Stage 3 initialize beginning
|
| 34 |
+
MA 0.72 GB Max_MA 2.9 GB CA 2.9 GB Max_CA 3 GB
|
| 35 |
+
CPU Virtual Memory: used = 16.26 GB, percent = 7.4%
|
| 36 |
+
DeepSpeedZeRoOffload initialize [begin]
|
| 37 |
+
MA 0.72 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
|
| 38 |
+
CPU Virtual Memory: used = 16.25 GB, percent = 7.3%
|
| 39 |
+
Parameter Offload - Persistent parameters statistics: param_count = 145, numel = 196096
|
| 40 |
+
DeepSpeedZeRoOffload initialize [end]
|
| 41 |
+
MA 0.0 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
|
| 42 |
+
CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
|
| 43 |
+
Before creating fp16 partitions
|
| 44 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 45 |
+
CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
|
| 46 |
+
After creating fp16 partitions: 5
|
| 47 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 48 |
+
CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
|
| 49 |
+
Before creating fp32 partitions
|
| 50 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 51 |
+
CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
|
| 52 |
+
After creating fp32 partitions
|
| 53 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 54 |
+
CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
|
| 55 |
+
Before initializing optimizer states
|
| 56 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 57 |
+
CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
|
| 58 |
+
After initializing optimizer states
|
| 59 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 60 |
+
CPU Virtual Memory: used = 49.09 GB, percent = 22.2%
|
| 61 |
+
After initializing ZeRO optimizer
|
| 62 |
+
MA 0.93 GB Max_MA 2.38 GB CA 3.83 GB Max_CA 4 GB
|
| 63 |
+
CPU Virtual Memory: used = 56.32 GB, percent = 25.5%
|
| 64 |
+
***** Running training *****
|
| 65 |
+
Beginning of Epoch 1/1, Total Micro Batches 5400
|
| 66 |
+
Model Parameters: 4.022 B, Latency: 2.91s, TFLOPs: 3.40, Samples/sec: 0.34, Time/seq 2.91s, Batch Size: 1, Sequence Length: 512
|
| 67 |
+
Model Parameters: 4.022 B, Latency: 3.07s, TFLOPs: 3.22, Samples/sec: 0.33, Time/seq 3.07s, Batch Size: 1, Sequence Length: 512
|
| 68 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 69 |
+
Model Parameters: 4.022 B, Latency: 2.35s, TFLOPs: 4.20, Samples/sec: 0.43, Time/seq 2.35s, Batch Size: 1, Sequence Length: 512
|
| 70 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 71 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 72 |
+
Model Parameters: 4.022 B, Latency: 2.33s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.33s, Batch Size: 1, Sequence Length: 512
|
| 73 |
+
Model Parameters: 4.022 B, Latency: 6.18s, TFLOPs: 1.60, Samples/sec: 0.16, Time/seq 6.18s, Batch Size: 1, Sequence Length: 512
|
| 74 |
+
Model Parameters: 4.022 B, Latency: 2.11s, TFLOPs: 4.69, Samples/sec: 0.47, Time/seq 2.11s, Batch Size: 1, Sequence Length: 512
|
| 75 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 76 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 77 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 78 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 79 |
+
Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
|
| 80 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 81 |
+
Model Parameters: 4.022 B, Latency: 4.19s, TFLOPs: 2.36, Samples/sec: 0.24, Time/seq 4.19s, Batch Size: 1, Sequence Length: 512
|
| 82 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 83 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 84 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 85 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 86 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 87 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 88 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 89 |
+
Model Parameters: 4.022 B, Latency: 4.21s, TFLOPs: 2.35, Samples/sec: 0.24, Time/seq 4.21s, Batch Size: 1, Sequence Length: 512
|
| 90 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 91 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 92 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 93 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 94 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 95 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 96 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.01, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 97 |
+
Model Parameters: 4.022 B, Latency: 4.24s, TFLOPs: 2.33, Samples/sec: 0.24, Time/seq 4.24s, Batch Size: 1, Sequence Length: 512
|
| 98 |
+
Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.14, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
|
| 99 |
+
Model Parameters: 4.022 B, Latency: 2.36s, TFLOPs: 4.19, Samples/sec: 0.42, Time/seq 2.36s, Batch Size: 1, Sequence Length: 512
|
| 100 |
+
Model Parameters: 4.022 B, Latency: 2.31s, TFLOPs: 4.27, Samples/sec: 0.43, Time/seq 2.31s, Batch Size: 1, Sequence Length: 512
|
| 101 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 102 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 103 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.96, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 104 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 105 |
+
Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
|
| 106 |
+
Model Parameters: 4.022 B, Latency: 1.95s, TFLOPs: 5.06, Samples/sec: 0.51, Time/seq 1.95s, Batch Size: 1, Sequence Length: 512
|
| 107 |
+
Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
|
| 108 |
+
Model Parameters: 4.022 B, Latency: 1.93s, TFLOPs: 5.12, Samples/sec: 0.52, Time/seq 1.93s, Batch Size: 1, Sequence Length: 512
|
| 109 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 110 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 111 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 112 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 113 |
+
Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
|
| 114 |
+
Model Parameters: 4.022 B, Latency: 2.13s, TFLOPs: 4.64, Samples/sec: 0.47, Time/seq 2.13s, Batch Size: 1, Sequence Length: 512
|
| 115 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 116 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 117 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 118 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 119 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 120 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 121 |
+
Model Parameters: 4.022 B, Latency: 4.22s, TFLOPs: 2.34, Samples/sec: 0.24, Time/seq 4.22s, Batch Size: 1, Sequence Length: 512
|
| 122 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 123 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 124 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 125 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 126 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 127 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 128 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 129 |
+
Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
|
| 130 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 131 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 132 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 133 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 134 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 135 |
+
Model Parameters: 4.022 B, Latency: 2.25s, TFLOPs: 4.40, Samples/sec: 0.45, Time/seq 2.25s, Batch Size: 1, Sequence Length: 512
|
| 136 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 137 |
+
Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
|
| 138 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 139 |
+
Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.13, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
|
| 140 |
+
Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
|
| 141 |
+
Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.18, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
|
| 142 |
+
Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
|
| 143 |
+
Model Parameters: 4.022 B, Latency: 2.28s, TFLOPs: 4.33, Samples/sec: 0.44, Time/seq 2.28s, Batch Size: 1, Sequence Length: 512
|
| 144 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 145 |
+
Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
|
| 146 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 147 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 148 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 149 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 150 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 151 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 152 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 153 |
+
Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
|
| 154 |
+
Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.67, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
|
| 155 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 156 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 157 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 158 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 159 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 160 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 161 |
+
Model Parameters: 4.022 B, Latency: 4.37s, TFLOPs: 2.26, Samples/sec: 0.23, Time/seq 4.37s, Batch Size: 1, Sequence Length: 512
|
| 162 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 163 |
+
Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
|
| 164 |
+
Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.08, Samples/sec: 0.51, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
|
| 165 |
+
Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
|
| 166 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 167 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 168 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 169 |
+
Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
|
| 170 |
+
Model Parameters: 4.022 B, Latency: 2.17s, TFLOPs: 4.55, Samples/sec: 0.46, Time/seq 2.17s, Batch Size: 1, Sequence Length: 512
|
| 171 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 172 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 173 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 174 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 175 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 176 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 177 |
+
Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
|
| 178 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 179 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 180 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 181 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 182 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 183 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 184 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 185 |
+
Model Parameters: 4.022 B, Latency: 4.54s, TFLOPs: 2.17, Samples/sec: 0.22, Time/seq 4.54s, Batch Size: 1, Sequence Length: 512
|
| 186 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 187 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 188 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 189 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 190 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 191 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 192 |
+
Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.66, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
|
| 193 |
+
Model Parameters: 4.022 B, Latency: 4.17s, TFLOPs: 2.37, Samples/sec: 0.24, Time/seq 4.17s, Batch Size: 1, Sequence Length: 512
|
| 194 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 195 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 196 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 197 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 198 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 199 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 200 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 201 |
+
Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
|
| 202 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 203 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 204 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 205 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 206 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 207 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 208 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 209 |
+
Model Parameters: 4.022 B, Latency: 4.36s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.36s, Batch Size: 1, Sequence Length: 512
|
| 210 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 211 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 212 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 213 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 214 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 215 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 216 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 217 |
+
Model Parameters: 4.022 B, Latency: 4.26s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.26s, Batch Size: 1, Sequence Length: 512
|
| 218 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 219 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 220 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 221 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 222 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 223 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 224 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 225 |
+
Model Parameters: 4.022 B, Latency: 4.31s, TFLOPs: 2.29, Samples/sec: 0.23, Time/seq 4.31s, Batch Size: 1, Sequence Length: 512
|
| 226 |
+
Model Parameters: 4.022 B, Latency: 2.15s, TFLOPs: 4.60, Samples/sec: 0.46, Time/seq 2.15s, Batch Size: 1, Sequence Length: 512
|
| 227 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 228 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 229 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 230 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 231 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 232 |
+
Model Parameters: 4.022 B, Latency: 2.52s, TFLOPs: 3.92, Samples/sec: 0.40, Time/seq 2.52s, Batch Size: 1, Sequence Length: 512
|
| 233 |
+
Model Parameters: 4.022 B, Latency: 4.34s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.34s, Batch Size: 1, Sequence Length: 512
|
| 234 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 235 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 236 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 237 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 238 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 239 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 240 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 241 |
+
Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
|
| 242 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.48, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 243 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 244 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 245 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 246 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 247 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 248 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 249 |
+
Model Parameters: 4.022 B, Latency: 4.33s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.33s, Batch Size: 1, Sequence Length: 512
|
| 250 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 251 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 252 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 253 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 254 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 255 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 256 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 257 |
+
Model Parameters: 4.022 B, Latency: 4.25s, TFLOPs: 2.32, Samples/sec: 0.24, Time/seq 4.25s, Batch Size: 1, Sequence Length: 512
|
| 258 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 259 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 260 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 261 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 262 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 263 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 264 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 265 |
+
Model Parameters: 4.022 B, Latency: 4.35s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.35s, Batch Size: 1, Sequence Length: 512
|
| 266 |
+
Model Parameters: 4.022 B, Latency: 2.14s, TFLOPs: 4.61, Samples/sec: 0.47, Time/seq 2.14s, Batch Size: 1, Sequence Length: 512
|
| 267 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 268 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 269 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 270 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 271 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 272 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 273 |
+
[2026-01-29 22:33:59,925] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
|
| 274 |
+
[rank0]: Traceback (most recent call last):
|
| 275 |
+
[rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 434, in <module>
|
| 276 |
+
[rank0]: main()
|
| 277 |
+
[rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 387, in main
|
| 278 |
+
[rank0]: model.step()
|
| 279 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2690, in step
|
| 280 |
+
[rank0]: self._take_model_step(lr_kwargs)
|
| 281 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2585, in _take_model_step
|
| 282 |
+
[rank0]: self.optimizer.step()
|
| 283 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
|
| 284 |
+
[rank0]: ret_val = func(*args, **kwargs)
|
| 285 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2220, in step
|
| 286 |
+
[rank0]: self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
|
| 287 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
|
| 288 |
+
[rank0]: ret_val = func(*args, **kwargs)
|
| 289 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2168, in _reassign_or_swap_out_partitioned_parameters
|
| 290 |
+
[rank0]: self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
|
| 291 |
+
[rank0]: KeyboardInterrupt
|
| 292 |
+
Traceback (most recent call last):
|
| 293 |
+
File "/home/ubuntu/.local/bin/deepspeed", line 6, in <module>
|
| 294 |
+
main()
|
| 295 |
+
File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/launcher/runner.py", line 646, in main
|
| 296 |
+
result.wait()
|
| 297 |
+
File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
|
| 298 |
+
return self._wait(timeout=timeout)
|
| 299 |
+
File "/usr/lib/python3.10/subprocess.py", line 1959, in _wait
|
| 300 |
+
(pid, sts) = self._try_wait(0)
|
| 301 |
+
File "/usr/lib/python3.10/subprocess.py", line 1917, in _try_wait
|
| 302 |
+
(pid, sts) = os.waitpid(self.pid, wait_flags)
|
| 303 |
+
KeyboardInterrupt
|
| 304 |
+
[2026-01-29 22:34:00,546] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
|
| 305 |
+
Exception ignored in atexit callback: <function shutdown_compile_workers at 0x7d22457a00d0>
|
| 306 |
+
Traceback (most recent call last):
|
| 307 |
+
File "/usr/lib/python3/dist-packages/torch/_inductor/async_compile.py", line 113, in shutdown_compile_workers
|
| 308 |
+
pool.shutdown()
|
| 309 |
+
File "/usr/lib/python3/dist-packages/torch/_inductor/compile_worker/subproc_pool.py", line 239, in shutdown
|
| 310 |
+
self.process.wait(300)
|
| 311 |
+
File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
|
| 312 |
+
return self._wait(timeout=timeout)
|
| 313 |
+
File "/usr/lib/python3.10/subprocess.py", line 1953, in _wait
|
| 314 |
+
time.sleep(delay)
|
| 315 |
+
KeyboardInterrupt:
|
| 316 |
+
[2026-01-29 22:34:00,990] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
|
| 317 |
+
[2026-01-29 22:34:04,967] [INFO] [launch.py:344:sigkill_handler] Main process received SIGINT, exiting
|
SFT-EN-01-29-2026/scripts/run_qwen3-4b.sh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Step 1: SFT Training for English Medical Data (UltraMedical)
|
| 3 |
+
# Qwen3-4B with LoRA on H100
|
| 4 |
+
|
| 5 |
+
MODEL_PATH=/workspace/Qwen3-4B
|
| 6 |
+
DATA_PATH=/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl
|
| 7 |
+
OUTPUT_DIR=./output_sft_en
|
| 8 |
+
|
| 9 |
+
mkdir -p $OUTPUT_DIR
|
| 10 |
+
|
| 11 |
+
deepspeed --num_gpus 1 main.py \
|
| 12 |
+
--model_name_or_path $MODEL_PATH \
|
| 13 |
+
--data_path $DATA_PATH \
|
| 14 |
+
--per_device_train_batch_size 2 \
|
| 15 |
+
--per_device_eval_batch_size 2 \
|
| 16 |
+
--max_seq_len 512 \
|
| 17 |
+
--learning_rate 2e-5 \
|
| 18 |
+
--weight_decay 0.1 \
|
| 19 |
+
--num_train_epochs 1 \
|
| 20 |
+
--num_warmup_steps 100 \
|
| 21 |
+
--gradient_accumulation_steps 4 \
|
| 22 |
+
--lr_scheduler_type cosine \
|
| 23 |
+
--gradient_checkpointing \
|
| 24 |
+
--dropout 0.0 \
|
| 25 |
+
--zero_stage 2 \
|
| 26 |
+
--dtype bf16 \
|
| 27 |
+
--lora_dim 64 \
|
| 28 |
+
--lora_module_name "layers." \
|
| 29 |
+
--only_optimize_lora \
|
| 30 |
+
--lora_learning_rate 5e-4 \
|
| 31 |
+
--compute_fp32_loss \
|
| 32 |
+
--print_loss \
|
| 33 |
+
--enable_tensorboard \
|
| 34 |
+
--tensorboard_path $OUTPUT_DIR \
|
| 35 |
+
--deepspeed \
|
| 36 |
+
--output_dir $OUTPUT_DIR
|
sft_model_backup/chat_template.jinja
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{{- messages[0].content + '\n\n' }}
|
| 5 |
+
{%- endif %}
|
| 6 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 7 |
+
{%- for tool in tools %}
|
| 8 |
+
{{- "\n" }}
|
| 9 |
+
{{- tool | tojson }}
|
| 10 |
+
{%- endfor %}
|
| 11 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 12 |
+
{%- else %}
|
| 13 |
+
{%- if messages[0].role == 'system' %}
|
| 14 |
+
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
{%- endif %}
|
| 17 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 18 |
+
{%- for message in messages[::-1] %}
|
| 19 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 20 |
+
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
| 21 |
+
{%- set ns.multi_step_tool = false %}
|
| 22 |
+
{%- set ns.last_query_index = index %}
|
| 23 |
+
{%- endif %}
|
| 24 |
+
{%- endfor %}
|
| 25 |
+
{%- for message in messages %}
|
| 26 |
+
{%- if message.content is string %}
|
| 27 |
+
{%- set content = message.content %}
|
| 28 |
+
{%- else %}
|
| 29 |
+
{%- set content = '' %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 32 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 33 |
+
{%- elif message.role == "assistant" %}
|
| 34 |
+
{%- set reasoning_content = '' %}
|
| 35 |
+
{%- if message.reasoning_content is string %}
|
| 36 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 37 |
+
{%- else %}
|
| 38 |
+
{%- if '</think>' in content %}
|
| 39 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 40 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
{%- endif %}
|
| 43 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 44 |
+
{%- if loop.last or (not loop.last and reasoning_content) %}
|
| 45 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
| 46 |
+
{%- else %}
|
| 47 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 48 |
+
{%- endif %}
|
| 49 |
+
{%- else %}
|
| 50 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 51 |
+
{%- endif %}
|
| 52 |
+
{%- if message.tool_calls %}
|
| 53 |
+
{%- for tool_call in message.tool_calls %}
|
| 54 |
+
{%- if (loop.first and content) or (not loop.first) %}
|
| 55 |
+
{{- '\n' }}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{%- if tool_call.function %}
|
| 58 |
+
{%- set tool_call = tool_call.function %}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 61 |
+
{{- tool_call.name }}
|
| 62 |
+
{{- '", "arguments": ' }}
|
| 63 |
+
{%- if tool_call.arguments is string %}
|
| 64 |
+
{{- tool_call.arguments }}
|
| 65 |
+
{%- else %}
|
| 66 |
+
{{- tool_call.arguments | tojson }}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{{- '}\n</tool_call>' }}
|
| 69 |
+
{%- endfor %}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{{- '<|im_end|>\n' }}
|
| 72 |
+
{%- elif message.role == "tool" %}
|
| 73 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 74 |
+
{{- '<|im_start|>user' }}
|
| 75 |
+
{%- endif %}
|
| 76 |
+
{{- '\n<tool_response>\n' }}
|
| 77 |
+
{{- content }}
|
| 78 |
+
{{- '\n</tool_response>' }}
|
| 79 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 80 |
+
{{- '<|im_end|>\n' }}
|
| 81 |
+
{%- endif %}
|
| 82 |
+
{%- endif %}
|
| 83 |
+
{%- endfor %}
|
| 84 |
+
{%- if add_generation_prompt %}
|
| 85 |
+
{{- '<|im_start|>assistant\n' }}
|
| 86 |
+
{%- if enable_thinking is defined and enable_thinking is false %}
|
| 87 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
{%- endif %}
|
sft_model_backup/config.json
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"dtype": "bfloat16",
|
| 9 |
+
"end_token_id": 151645,
|
| 10 |
+
"eos_token_id": 151645,
|
| 11 |
+
"head_dim": 128,
|
| 12 |
+
"hidden_act": "silu",
|
| 13 |
+
"hidden_size": 2560,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 9728,
|
| 16 |
+
"layer_types": [
|
| 17 |
+
"full_attention",
|
| 18 |
+
"full_attention",
|
| 19 |
+
"full_attention",
|
| 20 |
+
"full_attention",
|
| 21 |
+
"full_attention",
|
| 22 |
+
"full_attention",
|
| 23 |
+
"full_attention",
|
| 24 |
+
"full_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"full_attention",
|
| 27 |
+
"full_attention",
|
| 28 |
+
"full_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"full_attention",
|
| 31 |
+
"full_attention",
|
| 32 |
+
"full_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"full_attention",
|
| 35 |
+
"full_attention",
|
| 36 |
+
"full_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"full_attention",
|
| 39 |
+
"full_attention",
|
| 40 |
+
"full_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"full_attention",
|
| 43 |
+
"full_attention",
|
| 44 |
+
"full_attention",
|
| 45 |
+
"full_attention",
|
| 46 |
+
"full_attention",
|
| 47 |
+
"full_attention",
|
| 48 |
+
"full_attention",
|
| 49 |
+
"full_attention",
|
| 50 |
+
"full_attention",
|
| 51 |
+
"full_attention",
|
| 52 |
+
"full_attention"
|
| 53 |
+
],
|
| 54 |
+
"max_position_embeddings": 40960,
|
| 55 |
+
"max_window_layers": 36,
|
| 56 |
+
"model_type": "qwen3",
|
| 57 |
+
"num_attention_heads": 32,
|
| 58 |
+
"num_hidden_layers": 36,
|
| 59 |
+
"num_key_value_heads": 8,
|
| 60 |
+
"pad_token_id": 151645,
|
| 61 |
+
"rms_norm_eps": 1e-06,
|
| 62 |
+
"rope_parameters": {
|
| 63 |
+
"rope_theta": 1000000,
|
| 64 |
+
"rope_type": "default"
|
| 65 |
+
},
|
| 66 |
+
"sliding_window": null,
|
| 67 |
+
"tie_word_embeddings": false,
|
| 68 |
+
"transformers_version": "5.0.0",
|
| 69 |
+
"use_cache": true,
|
| 70 |
+
"use_sliding_window": false,
|
| 71 |
+
"vocab_size": 151672
|
| 72 |
+
}
|
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725308.209-20-158-64.30075.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bec6d2eb21f2a317beea31714d84e4c3fe34ba0c62365be1e2dc9ea98806cd55
|
| 3 |
+
size 204
|
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769725536.209-20-158-64.31271.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:543d7327217546fd37d3fee5322ec0f7ccbb530b17e65a16f14250a52daa3a4a
|
| 3 |
+
size 1448
|
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769726189.209-20-158-64.32221.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d996a4ea809b4cd08593bf3563f579d16402dfa2a1f23aa03d3110beda00a0e
|
| 3 |
+
size 37198
|
sft_model_backup/ds_tensorboard_logs/step1_model_tensorboard/events.out.tfevents.1769727296.209-20-158-64.32989.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:17e264bbafe4a0f20c4f2b0df948cee4cc696f181bd12f2530404e8c71c06444
|
| 3 |
+
size 37198
|
sft_model_backup/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9956493572a7ae7ff86699c23789cba8d31a38d0a2d6333177d846b9d9cade23
|
| 3 |
+
size 8820191160
|
sft_model_backup/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
|
| 3 |
+
size 11422650
|
sft_model_backup/tokenizer_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": null,
|
| 5 |
+
"clean_up_tokenization_spaces": false,
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"errors": "replace",
|
| 8 |
+
"extra_special_tokens": [
|
| 9 |
+
"<|im_start|>",
|
| 10 |
+
"<|im_end|>",
|
| 11 |
+
"<|object_ref_start|>",
|
| 12 |
+
"<|object_ref_end|>",
|
| 13 |
+
"<|box_start|>",
|
| 14 |
+
"<|box_end|>",
|
| 15 |
+
"<|quad_start|>",
|
| 16 |
+
"<|quad_end|>",
|
| 17 |
+
"<|vision_start|>",
|
| 18 |
+
"<|vision_end|>",
|
| 19 |
+
"<|vision_pad|>",
|
| 20 |
+
"<|image_pad|>",
|
| 21 |
+
"<|video_pad|>"
|
| 22 |
+
],
|
| 23 |
+
"fast_tokenizer": true,
|
| 24 |
+
"is_local": true,
|
| 25 |
+
"model_max_length": 131072,
|
| 26 |
+
"pad_token": "<|im_end|>",
|
| 27 |
+
"split_special_tokens": false,
|
| 28 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 29 |
+
"unk_token": null
|
| 30 |
+
}
|
sft_model_backup/training.log
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
|
| 2 |
+
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
|
| 3 |
+
[2026-01-29 22:24:38,868] [WARNING] [runner.py:232:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
|
| 4 |
+
[2026-01-29 22:24:38,868] [INFO] [runner.py:630:main] cmd = /usr/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None --log_level=info main.py --model_name_or_path /workspace/Qwen3-4B --data_path /home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl --weight_decay 0.1 --dropout 0.0 --gradient_accumulation_steps 8 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --zero_stage 3 --offload --dtype bf16 --enable_tensorboard --tensorboard_path ./output_sft_en --deepspeed --output_dir ./output_sft_en
|
| 5 |
+
/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
|
| 6 |
+
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
|
| 7 |
+
[2026-01-29 22:24:45,395] [INFO] [launch.py:162:main] WORLD INFO DICT: {'localhost': [0]}
|
| 8 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:168:main] nnodes=1, num_local_procs=1, node_rank=0
|
| 9 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:179:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
|
| 10 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:180:main] dist_world_size=1
|
| 11 |
+
[2026-01-29 22:24:45,396] [INFO] [launch.py:184:main] Setting CUDA_VISIBLE_DEVICES=0
|
| 12 |
+
[2026-01-29 22:24:45,398] [INFO] [launch.py:272:main] process 31271 spawned with command: ['/usr/bin/python3', '-u', 'main.py', '--local_rank=0', '--model_name_or_path', '/workspace/Qwen3-4B', '--data_path', '/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl', '--weight_decay', '0.1', '--dropout', '0.0', '--gradient_accumulation_steps', '8', '--per_device_train_batch_size', '1', '--per_device_eval_batch_size', '1', '--zero_stage', '3', '--offload', '--dtype', 'bf16', '--enable_tensorboard', '--tensorboard_path', './output_sft_en', '--deepspeed', '--output_dir', './output_sft_en']
|
| 13 |
+
/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4
|
| 14 |
+
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
|
| 15 |
+
[rank0]:[W129 22:24:52.444107661 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
|
| 16 |
+
Setting model_config.attention_dropout to 0.0
|
| 17 |
+
args: Namespace(data_path=['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl'], data_split='6,2,2', sft_only_data_path=[], data_output_path='/tmp/data_files/', model_name_or_path='/workspace/Qwen3-4B', per_device_train_batch_size=1, per_device_eval_batch_size=1, max_seq_len=512, learning_rate=0.001, weight_decay=0.1, num_train_epochs=1, gradient_accumulation_steps=8, lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>, num_warmup_steps=0, output_dir='./output_sft_en', seed=1234, local_rank=0, gradient_checkpointing=False, dropout=0.0, offload=True, dtype='bf16', zero_stage=3, lora_dim=0, lora_module_name='decoder.layers.', only_optimize_lora=False, lora_learning_rate=0.0005, compute_fp32_loss=False, enable_tensorboard=True, tensorboard_path='./output_sft_en', add_eot_token=False, eot_token='<|endoftext|>', print_loss=False, deepspeed=True, deepspeed_config=None, deepscale=False, deepscale_config=None, global_rank=0)
|
| 18 |
+
data_path: ['/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/data/train.jsonl']
|
| 19 |
+
/usr/lib/python3/dist-packages/torch/utils/cpp_extension.py:2376: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
|
| 20 |
+
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
|
| 21 |
+
warnings.warn(
|
| 22 |
+
2026-01-29 22:25:34.798274: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
| 23 |
+
2026-01-29 22:25:34.808869: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
| 24 |
+
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
|
| 25 |
+
E0000 00:00:1769725534.821805 31271 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
| 26 |
+
E0000 00:00:1769725534.825823 31271 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
| 27 |
+
W0000 00:00:1769725534.835606 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 28 |
+
W0000 00:00:1769725534.835626 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 29 |
+
W0000 00:00:1769725534.835656 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 30 |
+
W0000 00:00:1769725534.835658 31271 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
|
| 31 |
+
2026-01-29 22:25:34.838493: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
| 32 |
+
To enable the following instructions: AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
| 33 |
+
Stage 3 initialize beginning
|
| 34 |
+
MA 0.72 GB Max_MA 2.9 GB CA 2.9 GB Max_CA 3 GB
|
| 35 |
+
CPU Virtual Memory: used = 16.26 GB, percent = 7.4%
|
| 36 |
+
DeepSpeedZeRoOffload initialize [begin]
|
| 37 |
+
MA 0.72 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
|
| 38 |
+
CPU Virtual Memory: used = 16.25 GB, percent = 7.3%
|
| 39 |
+
Parameter Offload - Persistent parameters statistics: param_count = 145, numel = 196096
|
| 40 |
+
DeepSpeedZeRoOffload initialize [end]
|
| 41 |
+
MA 0.0 GB Max_MA 0.72 GB CA 2.9 GB Max_CA 3 GB
|
| 42 |
+
CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
|
| 43 |
+
Before creating fp16 partitions
|
| 44 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 45 |
+
CPU Virtual Memory: used = 16.7 GB, percent = 7.6%
|
| 46 |
+
After creating fp16 partitions: 5
|
| 47 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 48 |
+
CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
|
| 49 |
+
Before creating fp32 partitions
|
| 50 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 51 |
+
CPU Virtual Memory: used = 19.89 GB, percent = 9.0%
|
| 52 |
+
After creating fp32 partitions
|
| 53 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 54 |
+
CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
|
| 55 |
+
Before initializing optimizer states
|
| 56 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 57 |
+
CPU Virtual Memory: used = 34.0 GB, percent = 15.4%
|
| 58 |
+
After initializing optimizer states
|
| 59 |
+
MA 0.0 GB Max_MA 0.0 GB CA 2.9 GB Max_CA 3 GB
|
| 60 |
+
CPU Virtual Memory: used = 49.09 GB, percent = 22.2%
|
| 61 |
+
After initializing ZeRO optimizer
|
| 62 |
+
MA 0.93 GB Max_MA 2.38 GB CA 3.83 GB Max_CA 4 GB
|
| 63 |
+
CPU Virtual Memory: used = 56.32 GB, percent = 25.5%
|
| 64 |
+
***** Running training *****
|
| 65 |
+
Beginning of Epoch 1/1, Total Micro Batches 5400
|
| 66 |
+
Model Parameters: 4.022 B, Latency: 2.91s, TFLOPs: 3.40, Samples/sec: 0.34, Time/seq 2.91s, Batch Size: 1, Sequence Length: 512
|
| 67 |
+
Model Parameters: 4.022 B, Latency: 3.07s, TFLOPs: 3.22, Samples/sec: 0.33, Time/seq 3.07s, Batch Size: 1, Sequence Length: 512
|
| 68 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 69 |
+
Model Parameters: 4.022 B, Latency: 2.35s, TFLOPs: 4.20, Samples/sec: 0.43, Time/seq 2.35s, Batch Size: 1, Sequence Length: 512
|
| 70 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 71 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 72 |
+
Model Parameters: 4.022 B, Latency: 2.33s, TFLOPs: 4.23, Samples/sec: 0.43, Time/seq 2.33s, Batch Size: 1, Sequence Length: 512
|
| 73 |
+
Model Parameters: 4.022 B, Latency: 6.18s, TFLOPs: 1.60, Samples/sec: 0.16, Time/seq 6.18s, Batch Size: 1, Sequence Length: 512
|
| 74 |
+
Model Parameters: 4.022 B, Latency: 2.11s, TFLOPs: 4.69, Samples/sec: 0.47, Time/seq 2.11s, Batch Size: 1, Sequence Length: 512
|
| 75 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 76 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 77 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 78 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 79 |
+
Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
|
| 80 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 81 |
+
Model Parameters: 4.022 B, Latency: 4.19s, TFLOPs: 2.36, Samples/sec: 0.24, Time/seq 4.19s, Batch Size: 1, Sequence Length: 512
|
| 82 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 83 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 84 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 85 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 86 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 87 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 88 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 89 |
+
Model Parameters: 4.022 B, Latency: 4.21s, TFLOPs: 2.35, Samples/sec: 0.24, Time/seq 4.21s, Batch Size: 1, Sequence Length: 512
|
| 90 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 91 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 92 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 93 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.02, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 94 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 5.00, Samples/sec: 0.51, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 95 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 96 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.01, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 97 |
+
Model Parameters: 4.022 B, Latency: 4.24s, TFLOPs: 2.33, Samples/sec: 0.24, Time/seq 4.24s, Batch Size: 1, Sequence Length: 512
|
| 98 |
+
Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.14, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
|
| 99 |
+
Model Parameters: 4.022 B, Latency: 2.36s, TFLOPs: 4.19, Samples/sec: 0.42, Time/seq 2.36s, Batch Size: 1, Sequence Length: 512
|
| 100 |
+
Model Parameters: 4.022 B, Latency: 2.31s, TFLOPs: 4.27, Samples/sec: 0.43, Time/seq 2.31s, Batch Size: 1, Sequence Length: 512
|
| 101 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 102 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 103 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.96, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 104 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 105 |
+
Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
|
| 106 |
+
Model Parameters: 4.022 B, Latency: 1.95s, TFLOPs: 5.06, Samples/sec: 0.51, Time/seq 1.95s, Batch Size: 1, Sequence Length: 512
|
| 107 |
+
Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
|
| 108 |
+
Model Parameters: 4.022 B, Latency: 1.93s, TFLOPs: 5.12, Samples/sec: 0.52, Time/seq 1.93s, Batch Size: 1, Sequence Length: 512
|
| 109 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 110 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 111 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 112 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.97, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 113 |
+
Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
|
| 114 |
+
Model Parameters: 4.022 B, Latency: 2.13s, TFLOPs: 4.64, Samples/sec: 0.47, Time/seq 2.13s, Batch Size: 1, Sequence Length: 512
|
| 115 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 116 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 117 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 118 |
+
Model Parameters: 4.022 B, Latency: 1.99s, TFLOPs: 4.98, Samples/sec: 0.50, Time/seq 1.99s, Batch Size: 1, Sequence Length: 512
|
| 119 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 120 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 121 |
+
Model Parameters: 4.022 B, Latency: 4.22s, TFLOPs: 2.34, Samples/sec: 0.24, Time/seq 4.22s, Batch Size: 1, Sequence Length: 512
|
| 122 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 123 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 124 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 125 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 126 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 127 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 128 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.75, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 129 |
+
Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
|
| 130 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 131 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 132 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 133 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 134 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 135 |
+
Model Parameters: 4.022 B, Latency: 2.25s, TFLOPs: 4.40, Samples/sec: 0.45, Time/seq 2.25s, Batch Size: 1, Sequence Length: 512
|
| 136 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 137 |
+
Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
|
| 138 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 139 |
+
Model Parameters: 4.022 B, Latency: 2.39s, TFLOPs: 4.13, Samples/sec: 0.42, Time/seq 2.39s, Batch Size: 1, Sequence Length: 512
|
| 140 |
+
Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
|
| 141 |
+
Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.18, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
|
| 142 |
+
Model Parameters: 4.022 B, Latency: 2.37s, TFLOPs: 4.17, Samples/sec: 0.42, Time/seq 2.37s, Batch Size: 1, Sequence Length: 512
|
| 143 |
+
Model Parameters: 4.022 B, Latency: 2.28s, TFLOPs: 4.33, Samples/sec: 0.44, Time/seq 2.28s, Batch Size: 1, Sequence Length: 512
|
| 144 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 145 |
+
Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
|
| 146 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 147 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 148 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 149 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 150 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 151 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 152 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 153 |
+
Model Parameters: 4.022 B, Latency: 4.27s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.27s, Batch Size: 1, Sequence Length: 512
|
| 154 |
+
Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.67, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
|
| 155 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 156 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 157 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 158 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 159 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 160 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 161 |
+
Model Parameters: 4.022 B, Latency: 4.37s, TFLOPs: 2.26, Samples/sec: 0.23, Time/seq 4.37s, Batch Size: 1, Sequence Length: 512
|
| 162 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.95, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 163 |
+
Model Parameters: 4.022 B, Latency: 1.96s, TFLOPs: 5.04, Samples/sec: 0.51, Time/seq 1.96s, Batch Size: 1, Sequence Length: 512
|
| 164 |
+
Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.08, Samples/sec: 0.51, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
|
| 165 |
+
Model Parameters: 4.022 B, Latency: 1.94s, TFLOPs: 5.09, Samples/sec: 0.52, Time/seq 1.94s, Batch Size: 1, Sequence Length: 512
|
| 166 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 167 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 168 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 169 |
+
Model Parameters: 4.022 B, Latency: 4.29s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.29s, Batch Size: 1, Sequence Length: 512
|
| 170 |
+
Model Parameters: 4.022 B, Latency: 2.17s, TFLOPs: 4.55, Samples/sec: 0.46, Time/seq 2.17s, Batch Size: 1, Sequence Length: 512
|
| 171 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 172 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 173 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.92, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 174 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 175 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 176 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 177 |
+
Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
|
| 178 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 179 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 180 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 181 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 182 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 183 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 184 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 185 |
+
Model Parameters: 4.022 B, Latency: 4.54s, TFLOPs: 2.17, Samples/sec: 0.22, Time/seq 4.54s, Batch Size: 1, Sequence Length: 512
|
| 186 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 187 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 188 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 189 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 190 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 191 |
+
Model Parameters: 4.022 B, Latency: 2.34s, TFLOPs: 4.22, Samples/sec: 0.43, Time/seq 2.34s, Batch Size: 1, Sequence Length: 512
|
| 192 |
+
Model Parameters: 4.022 B, Latency: 2.12s, TFLOPs: 4.66, Samples/sec: 0.47, Time/seq 2.12s, Batch Size: 1, Sequence Length: 512
|
| 193 |
+
Model Parameters: 4.022 B, Latency: 4.17s, TFLOPs: 2.37, Samples/sec: 0.24, Time/seq 4.17s, Batch Size: 1, Sequence Length: 512
|
| 194 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 195 |
+
Model Parameters: 4.022 B, Latency: 1.98s, TFLOPs: 4.99, Samples/sec: 0.50, Time/seq 1.98s, Batch Size: 1, Sequence Length: 512
|
| 196 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 197 |
+
Model Parameters: 4.022 B, Latency: 1.97s, TFLOPs: 5.03, Samples/sec: 0.51, Time/seq 1.97s, Batch Size: 1, Sequence Length: 512
|
| 198 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 199 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 200 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 201 |
+
Model Parameters: 4.022 B, Latency: 4.30s, TFLOPs: 2.30, Samples/sec: 0.23, Time/seq 4.30s, Batch Size: 1, Sequence Length: 512
|
| 202 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.74, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 203 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 204 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 205 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 206 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 207 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 208 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 209 |
+
Model Parameters: 4.022 B, Latency: 4.36s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.36s, Batch Size: 1, Sequence Length: 512
|
| 210 |
+
Model Parameters: 4.022 B, Latency: 2.09s, TFLOPs: 4.73, Samples/sec: 0.48, Time/seq 2.09s, Batch Size: 1, Sequence Length: 512
|
| 211 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 212 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 213 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 214 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.88, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 215 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 216 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 217 |
+
Model Parameters: 4.022 B, Latency: 4.26s, TFLOPs: 2.32, Samples/sec: 0.23, Time/seq 4.26s, Batch Size: 1, Sequence Length: 512
|
| 218 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 219 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 220 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 221 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 222 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 223 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 224 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.91, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 225 |
+
Model Parameters: 4.022 B, Latency: 4.31s, TFLOPs: 2.29, Samples/sec: 0.23, Time/seq 4.31s, Batch Size: 1, Sequence Length: 512
|
| 226 |
+
Model Parameters: 4.022 B, Latency: 2.15s, TFLOPs: 4.60, Samples/sec: 0.46, Time/seq 2.15s, Batch Size: 1, Sequence Length: 512
|
| 227 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 228 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 229 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 230 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 231 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 232 |
+
Model Parameters: 4.022 B, Latency: 2.52s, TFLOPs: 3.92, Samples/sec: 0.40, Time/seq 2.52s, Batch Size: 1, Sequence Length: 512
|
| 233 |
+
Model Parameters: 4.022 B, Latency: 4.34s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.34s, Batch Size: 1, Sequence Length: 512
|
| 234 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 235 |
+
Model Parameters: 4.022 B, Latency: 2.00s, TFLOPs: 4.94, Samples/sec: 0.50, Time/seq 2.00s, Batch Size: 1, Sequence Length: 512
|
| 236 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 237 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 238 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 239 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 240 |
+
Model Parameters: 4.022 B, Latency: 2.01s, TFLOPs: 4.93, Samples/sec: 0.50, Time/seq 2.01s, Batch Size: 1, Sequence Length: 512
|
| 241 |
+
Model Parameters: 4.022 B, Latency: 4.28s, TFLOPs: 2.31, Samples/sec: 0.23, Time/seq 4.28s, Batch Size: 1, Sequence Length: 512
|
| 242 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.48, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 243 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.89, Samples/sec: 0.49, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 244 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.78, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 245 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 246 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 247 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 248 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 249 |
+
Model Parameters: 4.022 B, Latency: 4.33s, TFLOPs: 2.28, Samples/sec: 0.23, Time/seq 4.33s, Batch Size: 1, Sequence Length: 512
|
| 250 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 251 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 252 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 253 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 254 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 255 |
+
Model Parameters: 4.022 B, Latency: 2.02s, TFLOPs: 4.90, Samples/sec: 0.50, Time/seq 2.02s, Batch Size: 1, Sequence Length: 512
|
| 256 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.82, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 257 |
+
Model Parameters: 4.022 B, Latency: 4.25s, TFLOPs: 2.32, Samples/sec: 0.24, Time/seq 4.25s, Batch Size: 1, Sequence Length: 512
|
| 258 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 259 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.83, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 260 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.79, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 261 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.84, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 262 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.87, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 263 |
+
Model Parameters: 4.022 B, Latency: 2.10s, TFLOPs: 4.71, Samples/sec: 0.48, Time/seq 2.10s, Batch Size: 1, Sequence Length: 512
|
| 264 |
+
Model Parameters: 4.022 B, Latency: 2.03s, TFLOPs: 4.86, Samples/sec: 0.49, Time/seq 2.03s, Batch Size: 1, Sequence Length: 512
|
| 265 |
+
Model Parameters: 4.022 B, Latency: 4.35s, TFLOPs: 2.27, Samples/sec: 0.23, Time/seq 4.35s, Batch Size: 1, Sequence Length: 512
|
| 266 |
+
Model Parameters: 4.022 B, Latency: 2.14s, TFLOPs: 4.61, Samples/sec: 0.47, Time/seq 2.14s, Batch Size: 1, Sequence Length: 512
|
| 267 |
+
Model Parameters: 4.022 B, Latency: 2.08s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.08s, Batch Size: 1, Sequence Length: 512
|
| 268 |
+
Model Parameters: 4.022 B, Latency: 2.06s, TFLOPs: 4.80, Samples/sec: 0.49, Time/seq 2.06s, Batch Size: 1, Sequence Length: 512
|
| 269 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.76, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 270 |
+
Model Parameters: 4.022 B, Latency: 2.04s, TFLOPs: 4.85, Samples/sec: 0.49, Time/seq 2.04s, Batch Size: 1, Sequence Length: 512
|
| 271 |
+
Model Parameters: 4.022 B, Latency: 2.07s, TFLOPs: 4.77, Samples/sec: 0.48, Time/seq 2.07s, Batch Size: 1, Sequence Length: 512
|
| 272 |
+
Model Parameters: 4.022 B, Latency: 2.05s, TFLOPs: 4.81, Samples/sec: 0.49, Time/seq 2.05s, Batch Size: 1, Sequence Length: 512
|
| 273 |
+
[2026-01-29 22:33:59,925] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
|
| 274 |
+
[rank0]: Traceback (most recent call last):
|
| 275 |
+
[rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 434, in <module>
|
| 276 |
+
[rank0]: main()
|
| 277 |
+
[rank0]: File "/home/ubuntu/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 387, in main
|
| 278 |
+
[rank0]: model.step()
|
| 279 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2690, in step
|
| 280 |
+
[rank0]: self._take_model_step(lr_kwargs)
|
| 281 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2585, in _take_model_step
|
| 282 |
+
[rank0]: self.optimizer.step()
|
| 283 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
|
| 284 |
+
[rank0]: ret_val = func(*args, **kwargs)
|
| 285 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2220, in step
|
| 286 |
+
[rank0]: self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
|
| 287 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
|
| 288 |
+
[rank0]: ret_val = func(*args, **kwargs)
|
| 289 |
+
[rank0]: File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2168, in _reassign_or_swap_out_partitioned_parameters
|
| 290 |
+
[rank0]: self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
|
| 291 |
+
[rank0]: KeyboardInterrupt
|
| 292 |
+
Traceback (most recent call last):
|
| 293 |
+
File "/home/ubuntu/.local/bin/deepspeed", line 6, in <module>
|
| 294 |
+
main()
|
| 295 |
+
File "/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/launcher/runner.py", line 646, in main
|
| 296 |
+
result.wait()
|
| 297 |
+
File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
|
| 298 |
+
return self._wait(timeout=timeout)
|
| 299 |
+
File "/usr/lib/python3.10/subprocess.py", line 1959, in _wait
|
| 300 |
+
(pid, sts) = self._try_wait(0)
|
| 301 |
+
File "/usr/lib/python3.10/subprocess.py", line 1917, in _try_wait
|
| 302 |
+
(pid, sts) = os.waitpid(self.pid, wait_flags)
|
| 303 |
+
KeyboardInterrupt
|
| 304 |
+
[2026-01-29 22:34:00,546] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
|
| 305 |
+
Exception ignored in atexit callback: <function shutdown_compile_workers at 0x7d22457a00d0>
|
| 306 |
+
Traceback (most recent call last):
|
| 307 |
+
File "/usr/lib/python3/dist-packages/torch/_inductor/async_compile.py", line 113, in shutdown_compile_workers
|
| 308 |
+
pool.shutdown()
|
| 309 |
+
File "/usr/lib/python3/dist-packages/torch/_inductor/compile_worker/subproc_pool.py", line 239, in shutdown
|
| 310 |
+
self.process.wait(300)
|
| 311 |
+
File "/usr/lib/python3.10/subprocess.py", line 1209, in wait
|
| 312 |
+
return self._wait(timeout=timeout)
|
| 313 |
+
File "/usr/lib/python3.10/subprocess.py", line 1953, in _wait
|
| 314 |
+
time.sleep(delay)
|
| 315 |
+
KeyboardInterrupt:
|
| 316 |
+
[2026-01-29 22:34:00,990] [INFO] [launch.py:335:sigkill_handler] Killing subprocess 31271
|
| 317 |
+
[2026-01-29 22:34:04,967] [INFO] [launch.py:344:sigkill_handler] Main process received SIGINT, exiting
|